"""Persistence and authentication helpers for ``User`` records.

Every function receives a SQLAlchemy ``Session`` and performs its database
work synchronously through it.  NOTE(review): the functions are declared
``async`` but never await any I/O, so each database call blocks the event
loop -- confirm this is intended, or move to an ``AsyncSession``.
"""

import contextlib
import functools
import uuid
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.orm import Session
from argon2 import PasswordHasher
from argon2.exceptions import VerifyMismatchError

from app.models.user import User
from app.schemas.user import (
    UserCreate,
    UserUpdate,
    AdministrativeUserUpdate,
    PasswordUpdate,
    LoginRequest,
)
from app.util.errors import InvalidStateError, NotFoundError

# Argon2 hasher shared by every password operation in this module.
# argon2-cffi expresses memory_cost in KiB, so 102400 is 100 MiB per hash.
hasher = PasswordHasher(memory_cost=102400)


@functools.cache
def _dummy_password_hash() -> str:
    """Return a throwaway Argon2 hash used to equalize login timing.

    Computed on first use and then cached, so importing the module does not
    pay for an Argon2 hash. The hashed value is a random UUID that is never
    stored or returned, so no real password is tied to it.
    """
    return hasher.hash(uuid.uuid4().hex)


async def get_user(db: Session, id: UUID) -> User | None:
    """Return the user whose primary key is ``id``, or ``None`` if absent."""
    return db.get(User, id)


async def get_user_by_email(db: Session, email: str) -> User | None:
    """Return the first user whose email equals ``email`` exactly, or ``None``."""
    stmt = select(User).where(User.email == email)
    return db.execute(stmt).scalars().first()


async def get_users(
    db: Session, skip: int = 0, limit: int = 20, email: str | None = None
):
    """Return up to ``limit`` users after skipping the first ``skip`` rows.

    Args:
        db: Session used to run the query.
        skip: Number of rows to skip (SQL ``OFFSET``).
        limit: Maximum number of rows to return (SQL ``LIMIT``).
        email: Optional filter applied with SQL ``LIKE``; ``%`` and ``_`` in
            the caller's string therefore act as wildcards.
    """
    stmt = select(User)
    if email is not None:
        stmt = stmt.where(User.email.like(email))
    # NOTE(review): there is no ORDER BY, so the database may return rows in
    # any order and consecutive pages can overlap or skip rows -- consider
    # ordering by a stable column.
    stmt = stmt.offset(skip).limit(limit)
    return db.execute(stmt).scalars().all()


async def create_user(db: Session, user: UserCreate) -> User:
    """Persist a new user, storing an Argon2 hash of the supplied password.

    Returns:
        The committed ``User`` instance, refreshed from the database.

    Raises:
        InvalidStateError: A user with the same email already exists.
    """
    # NOTE(review): this check-then-insert is racy under concurrent requests;
    # confirm a unique constraint on email backs it up at the database level.
    if await get_user_by_email(db=db, email=user.email):
        raise InvalidStateError
    db_user = User(
        friendly_name=user.friendly_name,
        email=user.email,
        password=hasher.hash(user.password),
    )
    db.add(db_user)
    db.commit()
    # Reload the row's attributes from the database after the commit.
    db.refresh(db_user)
    return db_user


async def update_user(
    db: Session,
    id: UUID,
    update: UserUpdate | AdministrativeUserUpdate,
) -> User:
    """Copy the fields explicitly set on ``update`` onto user ``id`` and commit.

    Only fields the caller actually set are applied
    (``model_dump(exclude_unset=True)``), so omitted fields keep their
    current values.

    Raises:
        NotFoundError: No user with this id exists.
    """
    db_user = await get_user(db, id)
    if db_user is None:
        raise NotFoundError
    # NOTE(review): values are written verbatim; confirm neither update schema
    # can carry a ``password`` field, which would bypass hashing.
    for key, value in update.model_dump(exclude_unset=True).items():
        setattr(db_user, key, value)
    db.commit()
    return db_user


async def change_user_password(db: Session, id: UUID, update: PasswordUpdate):
    """Replace user ``id``'s password after verifying the current one.

    Raises:
        NotFoundError: No user with this id exists.
        InvalidStateError: ``update.old_password`` does not match the stored hash.
    """
    db_user = await get_user(db, id)
    if db_user is None:
        raise NotFoundError
    try:
        hasher.verify(hash=db_user.password, password=update.old_password)
    except VerifyMismatchError:
        # A mismatch is the expected failure mode; the argon2 exception adds
        # nothing useful, so don't chain it.
        raise InvalidStateError from None
    db_user.password = hasher.hash(update.new_password)
    db.commit()


async def remove_user(db: Session, id: UUID):
    """Delete user ``id`` and commit.

    Raises:
        NotFoundError: No user with this id exists.
    """
    db_user = await get_user(db, id)
    if db_user is None:
        raise NotFoundError
    db.delete(db_user)
    db.commit()


async def validate_login(db: Session, login: LoginRequest) -> User | None:
    """Return the user matching ``login``'s credentials, or ``None``.

    ``None`` is returned when the email is unknown, when the password does not
    match, or when the matching account is not active. After a successful
    password check, a stored hash whose Argon2 parameters are outdated is
    re-hashed with the current ones. The session is committed once on every
    path that returns.
    """
    db_user = await get_user_by_email(db=db, email=login.email)
    if db_user is None:
        # Pay the same Argon2 verification cost as for a real account so the
        # response time does not reveal whether the email is registered.
        with contextlib.suppress(VerifyMismatchError):
            hasher.verify(hash=_dummy_password_hash(), password=login.password)
        db.commit()
        return None

    try:
        hasher.verify(hash=db_user.password, password=login.password)
    except VerifyMismatchError:
        db.commit()
        return None

    if hasher.check_needs_rehash(db_user.password):
        # The plaintext is only available here, so upgrade the hash now.
        db_user.password = hasher.hash(login.password)
    # Read the flag before committing, so the decision uses the value loaded
    # in this transaction.
    is_active = db_user.is_active
    db.commit()
    return db_user if is_active else None