Implement user authentication and permissions
This commit is contained in:
parent
5e9d90ed0b
commit
ac8303378a
26 changed files with 1182 additions and 172 deletions
83
app/services/session_service.py
Normal file
83
app/services/session_service.py
Normal file
|
@ -0,0 +1,83 @@
|
|||
from datetime import datetime, UTC
|
||||
import secrets
|
||||
from uuid import UUID
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.orm import Session as SqlaSession
|
||||
|
||||
from app.models.session import Session
|
||||
from app.models.user import User
|
||||
from app.util.errors import NotFoundError
|
||||
|
||||
|
||||
async def get_sessions(
|
||||
db: SqlaSession, skip: int = 0, limit: int = 20
|
||||
) -> tuple[Session]:
|
||||
stmt = select(Session).offset(skip).limit(limit)
|
||||
result = db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def get_sessions_by_user(db: SqlaSession, user_id: UUID) -> tuple[Session]:
|
||||
stmt = select(Session).where(Session.user_id == user_id)
|
||||
result = db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create_session(db: SqlaSession, user: User, useragent: str) -> Session:
|
||||
session = Session(
|
||||
name=useragent,
|
||||
refresh_token=secrets.token_urlsafe(64),
|
||||
last_used=datetime.now(UTC),
|
||||
user_id=user.id,
|
||||
)
|
||||
db.add(session)
|
||||
db.commit()
|
||||
db.refresh(session)
|
||||
return session
|
||||
|
||||
|
||||
async def remove_session(db: SqlaSession, id: UUID):
|
||||
session = db.get(Session, id)
|
||||
if not session:
|
||||
raise NotFoundError
|
||||
db.delete(session)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def remove_session_for_user(
|
||||
db: SqlaSession, id: UUID, user_id: UUID
|
||||
):
|
||||
stmt = select(Session).where(Session.id == id and Session.user_id == user_id)
|
||||
result = db.execute(stmt)
|
||||
session = result.scalars().first()
|
||||
if not session:
|
||||
raise NotFoundError
|
||||
db.delete(session)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def remove_all_sessions_for_user(db: SqlaSession, user_id: UUID):
|
||||
stmt = delete(Session).where(Session.user_id == user_id)
|
||||
db.execute(stmt)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def remove_all_sessions(db: SqlaSession):
|
||||
stmt = delete(Session)
|
||||
db.execute(stmt)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def validate_and_rotate_refresh_token(
|
||||
db: SqlaSession, refresh_token: str
|
||||
) -> Session:
|
||||
stmt = select(Session).where(Session.refresh_token == refresh_token)
|
||||
result = db.execute(stmt)
|
||||
session = result.scalars().first()
|
||||
if not session:
|
||||
raise NotFoundError
|
||||
session.refresh_token = secrets.token_urlsafe(64)
|
||||
session.last_used = datetime.now(UTC)
|
||||
|
||||
db.commit()
|
||||
return session
|
69
app/services/token_service.py
Normal file
69
app/services/token_service.py
Normal file
|
@ -0,0 +1,69 @@
|
|||
import json
|
||||
import os
|
||||
import secrets
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
from jwcrypto import jwt, jwk
|
||||
from datetime import datetime, timedelta, UTC
|
||||
|
||||
from app.models.user import User
|
||||
from app.schemas.auth_token import AccessToken
|
||||
from app.schemas.user import Role
|
||||
from app.util.errors import InsufficientPermissionsError, InvalidTokenAudienceError
|
||||
|
||||
__signing_key = jwk.JWK.from_password(os.getenv("CS_TOKEN_SECRET", secrets.token_urlsafe(64)))
|
||||
|
||||
async def __create_token(claims: dict) -> str:
|
||||
default_claims = {
|
||||
"iss": os.getenv("CS_TOKEN_ISSUER", "https://localhost:8000"),
|
||||
"iat": datetime.now(UTC).timestamp(),
|
||||
}
|
||||
header = {"alg": "HS256", "typ": "JWT", "kid": "default"}
|
||||
token = jwt.JWT(header=header, claims=(claims | default_claims))
|
||||
token.make_signed_token(__signing_key)
|
||||
return token.serialize()
|
||||
|
||||
|
||||
async def __verify_token(token: str, audience: str) -> dict | None:
|
||||
try:
|
||||
token = jwt.JWT(jwt=token, key=__signing_key)
|
||||
claims = json.loads(token.claims)
|
||||
if claims.get("aud") == audience:
|
||||
return claims
|
||||
else:
|
||||
raise InvalidTokenAudienceError
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def create_access_token(
|
||||
user: User, session_id: UUID
|
||||
) -> tuple[str, datetime]:
|
||||
token_lifetime = float(os.getenv("CS_ACCESS_TOKEN_LIFETIME_SECONDS", "300"))
|
||||
exp_time = datetime.now(UTC) + timedelta(seconds=token_lifetime)
|
||||
claims = {
|
||||
"aud": "access",
|
||||
"sub": str(user.id),
|
||||
"exp": exp_time.timestamp(),
|
||||
"session": str(session_id),
|
||||
"role": str(user.role),
|
||||
}
|
||||
return await __create_token(claims=claims), exp_time
|
||||
|
||||
async def verify_access_token(
|
||||
token: str, required_roles: Optional[list[str]] = None
|
||||
) -> AccessToken | None:
|
||||
try:
|
||||
claims = await __verify_token(token=token, audience="access")
|
||||
if not claims:
|
||||
return None
|
||||
if not required_roles or claims.get("role") in required_roles:
|
||||
return AccessToken(
|
||||
subject=claims.get("sub"),
|
||||
role=Role(claims.get("role")),
|
||||
session=claims.get("session"),
|
||||
)
|
||||
else:
|
||||
raise InsufficientPermissionsError
|
||||
except InvalidTokenAudienceError:
|
||||
pass
|
111
app/services/user_service.py
Normal file
111
app/services/user_service.py
Normal file
|
@ -0,0 +1,111 @@
|
|||
from uuid import UUID
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
from app.models.user import User
|
||||
from app.schemas.user import (
|
||||
UserCreate,
|
||||
UserUpdate,
|
||||
AdministrativeUserUpdate,
|
||||
PasswordUpdate,
|
||||
LoginRequest,
|
||||
)
|
||||
from app.util.errors import InvalidStateError, NotFoundError
|
||||
|
||||
hasher = PasswordHasher(memory_cost=102400)
|
||||
|
||||
async def get_user(db: Session, id: UUID):
|
||||
return db.get(User, id)
|
||||
|
||||
|
||||
async def get_user_by_email(db: Session, email: str):
|
||||
stmt = select(User).where(User.email == email)
|
||||
result = db.execute(stmt)
|
||||
return result.scalars().first()
|
||||
|
||||
|
||||
async def get_users(
|
||||
db: Session, skip: int = 0, limit: int = 20, email: str = None
|
||||
):
|
||||
stmt = select(User)
|
||||
if email is not None:
|
||||
stmt = stmt.where(User.email.like(email))
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
result = db.execute(stmt)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
async def create_user(db: Session, user: UserCreate) -> User:
|
||||
if await get_user_by_email(db=db, email=user.email):
|
||||
raise InvalidStateError
|
||||
hashed_password = hasher.hash(user.password)
|
||||
db_user = User(
|
||||
friendly_name=user.friendly_name, email=user.email, password=hashed_password
|
||||
)
|
||||
db.add(db_user)
|
||||
db.commit()
|
||||
db.refresh(db_user)
|
||||
return db_user
|
||||
|
||||
|
||||
async def update_user(
|
||||
db: Session,
|
||||
id: UUID,
|
||||
update: UserUpdate | AdministrativeUserUpdate,
|
||||
) -> User:
|
||||
db_user = await get_user(db, id)
|
||||
if db_user is None:
|
||||
raise NotFoundError
|
||||
|
||||
changed_attributes = dict()
|
||||
for key, value in update.model_dump(exclude_unset=True).items():
|
||||
changed_attributes[key] = {"old": getattr(db_user, key), "new": value}
|
||||
setattr(db_user, key, value)
|
||||
db.commit()
|
||||
return db_user
|
||||
|
||||
|
||||
async def change_user_password(db: Session, id: UUID, update: PasswordUpdate):
|
||||
db_user = await get_user(db, id)
|
||||
if db_user is None:
|
||||
raise NotFoundError
|
||||
try:
|
||||
hasher.verify(hash=db_user.password, password=update.old_password)
|
||||
db_user.password = hasher.hash(update.new_password)
|
||||
db.commit()
|
||||
except VerifyMismatchError:
|
||||
raise InvalidStateError
|
||||
|
||||
|
||||
async def remove_user(db: Session, id: UUID):
|
||||
db_user = await get_user(db, id)
|
||||
if db_user is None:
|
||||
raise NotFoundError
|
||||
db.delete(db_user)
|
||||
db.commit()
|
||||
|
||||
|
||||
async def validate_login(db: Session, login: LoginRequest) -> User | None:
|
||||
stmt = select(User).where(User.email == login.email)
|
||||
result = db.execute(stmt)
|
||||
db_user = result.scalars().first()
|
||||
if db_user is None:
|
||||
db.commit()
|
||||
return None
|
||||
try:
|
||||
hasher.verify(hash=db_user.password, password=login.password)
|
||||
if hasher.check_needs_rehash(db_user.password):
|
||||
db_user.password = hasher.hash(login.password)
|
||||
if db_user.is_active:
|
||||
db.commit()
|
||||
return db_user
|
||||
else:
|
||||
db.commit()
|
||||
return None
|
||||
except VerifyMismatchError:
|
||||
db.commit()
|
||||
return None
|
Loading…
Add table
Add a link
Reference in a new issue