83 lines
2.2 KiB
Python
83 lines
2.2 KiB
Python
|
from collections.abc import Sequence
from datetime import datetime, UTC
import secrets
from uuid import UUID

from sqlalchemy import select, delete
from sqlalchemy.orm import Session as SqlaSession

from app.models.session import Session
from app.models.user import User
from app.util.errors import NotFoundError
|
||
|
|
||
|
|
||
|
async def get_sessions(
    db: SqlaSession, skip: int = 0, limit: int = 20
) -> Sequence[Session]:
    """Return one page of sessions across all users.

    Args:
        db: SQLAlchemy ORM session used to run the query.
        skip: Number of rows to skip before the page starts (SQL OFFSET).
        limit: Maximum number of rows in the page (SQL LIMIT).

    Returns:
        The ``Session`` rows of the requested page, ordered by ``Session.id``.
    """
    # OFFSET/LIMIT without ORDER BY leaves row order up to the database, so
    # consecutive pages could overlap or skip rows; pin a deterministic order.
    stmt = select(Session).order_by(Session.id).offset(skip).limit(limit)
    # NOTE(review): ``db`` is used synchronously here, so this coroutine blocks
    # the event loop while the query runs — confirm that is acceptable.
    return db.execute(stmt).scalars().all()
|
||
|
|
||
|
|
||
|
async def get_sessions_by_user(
    db: SqlaSession, user_id: UUID
) -> Sequence[Session]:
    """Return every session belonging to ``user_id``.

    Args:
        db: SQLAlchemy ORM session used to run the query.
        user_id: Owner whose sessions are fetched.

    Returns:
        All matching ``Session`` rows (possibly empty). This is a sequence of
        arbitrary length, not a fixed-size tuple.
    """
    stmt = select(Session).where(Session.user_id == user_id)
    return db.execute(stmt).scalars().all()
|
||
|
|
||
|
|
||
|
async def create_session(db: SqlaSession, user: User, useragent: str) -> Session:
    """Create, persist and return a new session for ``user``.

    The session is named after the client's user-agent string, receives a
    refresh token from ``secrets.token_urlsafe(64)``, and has ``last_used``
    set to the current UTC time. The row is committed and then refreshed
    from the database before it is returned.
    """
    token = secrets.token_urlsafe(64)
    now = datetime.now(UTC)

    new_session = Session(
        name=useragent,
        refresh_token=token,
        last_used=now,
        user_id=user.id,
    )

    db.add(new_session)
    db.commit()
    # Reload so database-populated columns are present on the returned object.
    db.refresh(new_session)
    return new_session
|
||
|
|
||
|
|
||
|
async def remove_session(db: SqlaSession, id: UUID):
    """Delete the session identified by primary key ``id`` and commit.

    Raises:
        NotFoundError: if no session with that key exists.
    """
    target = db.get(Session, id)
    if target:
        db.delete(target)
        db.commit()
        return
    raise NotFoundError
|
||
|
|
||
|
|
||
|
async def remove_session_for_user(
    db: SqlaSession, id: UUID, user_id: UUID
):
    """Delete session ``id``, but only if it is owned by ``user_id``.

    Args:
        db: SQLAlchemy ORM session; committed on success.
        id: Identifier of the session to delete.
        user_id: The user who must own that session.

    Raises:
        NotFoundError: if no session with ``id`` belongs to ``user_id``.
    """
    # Both criteria must go to where() (which AND-s them in SQL). Combining the
    # clauses with Python `and` short-circuits on the first clause's truthiness
    # and silently drops the ownership filter, letting any user delete any
    # session by id.
    stmt = select(Session).where(Session.id == id, Session.user_id == user_id)
    session = db.execute(stmt).scalars().first()
    if session is None:
        raise NotFoundError
    db.delete(session)
    db.commit()
|
||
|
|
||
|
|
||
|
async def remove_all_sessions_for_user(db: SqlaSession, user_id: UUID):
    """Delete every session owned by ``user_id`` in one DELETE, then commit.

    No error is raised when the user has no sessions.
    """
    db.execute(delete(Session).where(Session.user_id == user_id))
    db.commit()
|
||
|
|
||
|
|
||
|
async def remove_all_sessions(db: SqlaSession):
    """Delete every row of the session table, for all users, and commit."""
    db.execute(delete(Session))
    db.commit()
|
||
|
|
||
|
|
||
|
async def validate_and_rotate_refresh_token(
    db: SqlaSession, refresh_token: str
) -> Session:
    """Find the session holding ``refresh_token`` and give it a new token.

    On a match, the stored token is replaced with a fresh
    ``secrets.token_urlsafe(64)`` value and ``last_used`` is set to the
    current UTC time. The change is committed and the session is returned,
    so the presented token no longer matches that session afterwards.

    Raises:
        NotFoundError: if no session currently holds ``refresh_token``.
    """
    lookup = select(Session).where(Session.refresh_token == refresh_token)
    owner_session = db.execute(lookup).scalars().first()
    if not owner_session:
        raise NotFoundError

    # NOTE(review): the row is not locked between lookup and commit, so two
    # concurrent requests presenting the same token may both pass the lookup
    # before either commits — confirm whether that reuse window matters.
    owner_session.refresh_token = secrets.token_urlsafe(64)
    owner_session.last_used = datetime.now(UTC)

    db.commit()
    return owner_session
|