111 lines
3 KiB
Python
111 lines
3 KiB
Python
from uuid import UUID
|
|
import uuid
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
from argon2 import PasswordHasher
|
|
from argon2.exceptions import VerifyMismatchError
|
|
|
|
from app.models.user import User
|
|
from app.schemas.user import (
|
|
UserCreate,
|
|
UserUpdate,
|
|
AdministrativeUserUpdate,
|
|
PasswordUpdate,
|
|
LoginRequest,
|
|
)
|
|
from app.util.errors import InvalidStateError, NotFoundError
|
|
|
|
# Module-wide argon2 hasher shared by every function in this file.
# argon2-cffi expresses memory_cost in KiB, so this is 100 MiB per hash;
# time_cost, parallelism, etc. are left at the library defaults.
hasher = PasswordHasher(memory_cost=100 * 1024)
|
|
|
|
async def get_user(db: Session, id: UUID) -> User | None:
    """Fetch a single user by primary key.

    Args:
        db: Active SQLAlchemy session.
        id: Primary-key value of the user to load.

    Returns:
        The ``User`` instance, or ``None`` when no row has that key.
    """
    # Session.get may answer from the session's identity map without a SELECT.
    found = db.get(User, id)
    return found
|
|
|
|
|
|
async def get_user_by_email(db: Session, email: str) -> User | None:
    """Return the first user whose email equals ``email`` exactly.

    Args:
        db: Active SQLAlchemy session.
        email: Address to match with SQL equality (no wildcard handling).

    Returns:
        The matching ``User``, or ``None`` if there is none.
    """
    query = select(User).where(User.email == email)
    return db.execute(query).scalars().first()
|
|
|
|
|
|
async def get_users(
    db: Session, skip: int = 0, limit: int = 20, email: str | None = None
):
    """Return one page of users, optionally filtered by an email pattern.

    Args:
        db: Active SQLAlchemy session.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).
        email: Optional SQL ``LIKE`` pattern matched against ``User.email``.
            ``%`` and ``_`` act as wildcards. NOTE(review): the pattern is
            passed through unescaped; if callers forward raw user input that
            is meant as an exact address, those characters will still match
            as wildcards.

    Returns:
        A list of ``User`` instances, ordered by email.
    """
    stmt = select(User)
    if email is not None:
        stmt = stmt.where(User.email.like(email))
    # OFFSET/LIMIT without ORDER BY leaves row order up to the database, so
    # successive pages could overlap or skip users. Email is the natural key
    # here (create_user refuses duplicates), giving a stable page order.
    stmt = stmt.order_by(User.email).offset(skip).limit(limit)
    return db.execute(stmt).scalars().all()
|
|
|
|
|
|
async def create_user(db: Session, user: UserCreate) -> User:
    """Register a new user with an argon2-hashed password.

    Args:
        db: Active SQLAlchemy session; committed on success.
        user: Incoming registration data (friendly name, email, plaintext
            password).

    Returns:
        The persisted ``User``, refreshed from the database after commit.

    Raises:
        InvalidStateError: if a user with the same email already exists.
    """
    # NOTE(review): this check-then-insert is not atomic; two concurrent
    # registrations could both pass it. Confirm a DB-level unique constraint
    # on email backs it up.
    existing = await get_user_by_email(db=db, email=user.email)
    if existing:
        raise InvalidStateError

    new_user = User(
        friendly_name=user.friendly_name,
        email=user.email,
        password=hasher.hash(user.password),  # never store the plaintext
    )
    db.add(new_user)
    db.commit()
    # Reload DB-generated columns so the caller sees the stored row.
    db.refresh(new_user)
    return new_user
|
|
|
|
|
|
async def update_user(
    db: Session,
    id: UUID,
    update: UserUpdate | AdministrativeUserUpdate,
) -> User:
    """Apply the explicitly-set fields of ``update`` to a user and commit.

    Only fields present in the request are applied (``exclude_unset=True``),
    so omitted fields keep their stored values.

    Args:
        db: Active SQLAlchemy session; committed on success.
        id: Primary key of the user to modify.
        update: Partial update payload.

    Returns:
        The modified ``User``.

    Raises:
        NotFoundError: if no user has primary key ``id``.
        AttributeError: if ``update`` carries a field that the user object
            does not have. This is checked before any attribute is written,
            so the session is never left holding a half-applied update.
    """
    db_user = await get_user(db, id)
    if db_user is None:
        raise NotFoundError

    changes = update.model_dump(exclude_unset=True)

    # Validate every key first: failing midway through the setattr loop
    # would leave earlier fields dirty on the session, where a later commit
    # could persist a partial update.
    for key in changes:
        if not hasattr(db_user, key):
            raise AttributeError(
                f"{type(db_user).__name__} has no attribute {key!r}"
            )

    # NOTE(review): values are written verbatim. If either update schema can
    # carry a password field it would be stored unhashed; confirm it cannot.
    for key, value in changes.items():
        setattr(db_user, key, value)

    db.commit()
    return db_user
|
|
|
|
|
|
async def change_user_password(db: Session, id: UUID, update: PasswordUpdate):
    """Replace a user's password after verifying the current one.

    Args:
        db: Active SQLAlchemy session; committed on success.
        id: Primary key of the user whose password changes.
        update: Carries ``old_password`` (must match the stored hash) and
            ``new_password`` (hashed before storage).

    Raises:
        NotFoundError: if no user has primary key ``id``.
        InvalidStateError: if ``old_password`` does not match.
    """
    db_user = await get_user(db, id)
    if db_user is None:
        raise NotFoundError

    # Only verify() can raise VerifyMismatchError; keep the try that narrow.
    try:
        hasher.verify(hash=db_user.password, password=update.old_password)
    except VerifyMismatchError:
        # The argon2 exception carries no extra information for callers.
        raise InvalidStateError from None

    db_user.password = hasher.hash(update.new_password)
    db.commit()
|
|
|
|
|
|
async def remove_user(db: Session, id: UUID):
    """Delete a user by primary key and commit the deletion.

    Args:
        db: Active SQLAlchemy session.
        id: Primary key of the user to delete.

    Raises:
        NotFoundError: if no user has primary key ``id``.
    """
    target = await get_user(db, id)
    if target is not None:
        db.delete(target)
        db.commit()
        return
    raise NotFoundError
|
|
|
|
|
|
# Lazily computed argon2 hash, used only to spend comparable work when a
# login names an email that is not registered (see validate_login).
_dummy_password_hash: str | None = None


def _verify_against_dummy_hash(password: str) -> None:
    """Run one argon2 verification whose outcome is deliberately ignored.

    Makes the "unknown email" path of ``validate_login`` cost roughly as much
    as the "wrong password" path, so response time does not reveal which
    emails are registered.
    """
    global _dummy_password_hash
    if _dummy_password_hash is None:
        # Hashed with the module hasher so it carries the same cost parameters.
        _dummy_password_hash = hasher.hash("timing-equalization-placeholder")
    try:
        hasher.verify(_dummy_password_hash, password)
    except VerifyMismatchError:
        pass  # expected; only the time spent matters


async def validate_login(db: Session, login: LoginRequest) -> User | None:
    """Check login credentials and return the authenticated user.

    Args:
        db: Active SQLAlchemy session; committed on every non-raising path.
        login: Email/password pair supplied by the client.

    Returns:
        The matching ``User`` when the email exists, the password verifies
        and ``is_active`` is truthy; otherwise ``None``. All failure cases
        return ``None`` indistinguishably.

    Side effects:
        If the stored hash was produced with outdated argon2 parameters
        (``check_needs_rehash``), it is replaced by a fresh hash of the
        supplied password before the commit.
    """
    db_user = await get_user_by_email(db=db, email=login.email)

    if db_user is None:
        # Without this, unknown emails return far faster than wrong
        # passwords (no argon2 work at all), enabling account enumeration.
        _verify_against_dummy_hash(login.password)
        db.commit()
        return None

    try:
        hasher.verify(hash=db_user.password, password=login.password)
    except VerifyMismatchError:
        db.commit()
        return None

    if hasher.check_needs_rehash(db_user.password):
        db_user.password = hasher.hash(login.password)

    # Commit before the active check so an upgraded hash is persisted for
    # inactive users as well.
    db.commit()
    return db_user if db_user.is_active else None
|