import secrets
from collections.abc import Sequence
from datetime import datetime, UTC
from uuid import UUID

from sqlalchemy import select, delete
from sqlalchemy.orm import Session as SqlaSession

from app.models.session import Session
from app.models.user import User
from app.util.errors import NotFoundError


async def get_sessions(
    db: SqlaSession, skip: int = 0, limit: int = 20
) -> Sequence[Session]:
    """Return one page of sessions across all users.

    Args:
        db: Database session used to run the query.
        skip: Number of rows to skip (SQL OFFSET).
        limit: Maximum number of rows to return (SQL LIMIT).

    Returns:
        The sessions on the requested page, ordered by ``Session.id``.
    """
    # OFFSET/LIMIT without ORDER BY lets the database return rows in any
    # order, so consecutive pages may overlap or miss rows. Ordering by the
    # primary key makes pagination stable.
    stmt = select(Session).order_by(Session.id).offset(skip).limit(limit)
    return db.execute(stmt).scalars().all()


async def get_sessions_by_user(
    db: SqlaSession, user_id: UUID
) -> Sequence[Session]:
    """Return every session owned by ``user_id``.

    Args:
        db: Database session used to run the query.
        user_id: Owner whose sessions are fetched.

    Returns:
        All matching sessions, possibly empty. No ORDER BY is applied, so
        the row order is whatever the database chooses.
    """
    stmt = select(Session).where(Session.user_id == user_id)
    return db.execute(stmt).scalars().all()


async def create_session(db: SqlaSession, user: User, useragent: str) -> Session:
    """Create, persist and return a new login session for ``user``.

    The session is named after the client's user-agent string. It gets a
    freshly generated URL-safe refresh token (``secrets.token_urlsafe(64)``),
    and ``last_used`` is stamped with the current UTC time.

    The row is committed immediately, then refreshed so the returned object's
    attributes are reloaded from the stored row.
    """
    new_session = Session(
        user_id=user.id,
        name=useragent,
        last_used=datetime.now(UTC),
        refresh_token=secrets.token_urlsafe(64),
    )

    db.add(new_session)
    db.commit()
    db.refresh(new_session)

    return new_session


async def remove_session(db: SqlaSession, id: UUID):
    """Delete the session whose primary key is ``id`` and commit.

    Raises:
        NotFoundError: no session with that primary key exists.
    """
    target = db.get(Session, id)
    if target:
        db.delete(target)
        db.commit()
        return
    raise NotFoundError

async def remove_session_for_user(
    db: SqlaSession, id: UUID, user_id: UUID
):
    """Delete session ``id``, but only if it belongs to ``user_id``.

    Both filters are passed to ``where()`` as separate criteria, which
    SQLAlchemy combines with SQL ``AND``. Python's ``and`` operator must not
    be used for this. It evaluates the first comparison in a boolean context
    and hands back only one of the two expressions, so the ownership check
    never reaches the query.

    Raises:
        NotFoundError: no session with that id exists for this user. This
            includes the case where the session exists but has another owner.
    """
    stmt = select(Session).where(Session.id == id, Session.user_id == user_id)
    session = db.execute(stmt).scalars().first()
    if not session:
        raise NotFoundError
    db.delete(session)
    db.commit()


async def remove_all_sessions_for_user(db: SqlaSession, user_id: UUID):
    """Bulk-delete every session owned by ``user_id`` and commit.

    Issues a single DELETE statement. Nothing checks the affected row count,
    so a user with no sessions is not an error.
    """
    db.execute(delete(Session).where(Session.user_id == user_id))
    db.commit()


async def remove_all_sessions(db: SqlaSession):
    """Delete every session row for all users and commit.

    Issues one unfiltered DELETE statement against the sessions table.
    """
    db.execute(delete(Session))
    db.commit()


async def validate_and_rotate_refresh_token(
    db: SqlaSession, refresh_token: str
) -> Session:
    """Find the session holding ``refresh_token`` and rotate its token.

    On success, the session gets a new random refresh token and its
    ``last_used`` timestamp is set to the current UTC time. The change is
    committed and the session is returned. The presented token is no longer
    stored afterwards, so it cannot be looked up again.

    Raises:
        NotFoundError: no session currently holds ``refresh_token``.
    """
    lookup = select(Session).where(Session.refresh_token == refresh_token)
    matched = db.execute(lookup).scalars().first()
    if not matched:
        raise NotFoundError

    # Replace the token so the one just presented becomes single-use.
    # NOTE(review): two concurrent requests with the same token could both
    # read the row before either commits; confirm whether that matters here.
    matched.refresh_token = secrets.token_urlsafe(64)
    matched.last_used = datetime.now(UTC)
    db.commit()

    return matched