import json
import os
import secrets
from datetime import datetime, timedelta, UTC
from typing import Optional
from uuid import UUID

from jwcrypto import jwt, jwk

from app.models.user import User
from app.schemas.auth_token import AccessToken
from app.schemas.user import Role
from app.util.errors import InsufficientPermissionsError, InvalidTokenAudienceError

# The single JWS algorithm this module issues; verification is pinned to it so
# the header written by __create_token and the check in __verify_token agree.
__TOKEN_ALGORITHM = "HS256"

# Symmetric signing key derived from CS_TOKEN_SECRET. When the variable is not
# set, a random secret is generated at import time, so the key differs on every
# process start and tokens signed by one process will not verify in another.
__signing_key = jwk.JWK.from_password(
    os.getenv("CS_TOKEN_SECRET", secrets.token_urlsafe(64))
)


async def __create_token(claims: dict) -> str:
    """Sign *claims* into a compact serialized JWT.

    The ``iss`` claim (from ``CS_TOKEN_ISSUER``, defaulting to
    ``https://localhost:8000``) and the ``iat`` claim (current UTC timestamp)
    are merged in on top of the supplied claims.

    Args:
        claims: Claims supplied by the caller (audience, subject, expiry, ...).

    Returns:
        The signed token in its serialized string form.
    """
    default_claims = {
        "iss": os.getenv("CS_TOKEN_ISSUER", "https://localhost:8000"),
        "iat": datetime.now(UTC).timestamp(),
    }
    header = {"alg": __TOKEN_ALGORITHM, "typ": "JWT", "kid": "default"}
    # Right-hand operand wins on key collisions: a caller-supplied "iss" or
    # "iat" is overwritten by the values computed here.
    token = jwt.JWT(header=header, claims=(claims | default_claims))
    token.make_signed_token(__signing_key)
    return token.serialize()


async def __verify_token(token: str, audience: str) -> dict | None:
    """Verify *token* and return its claims if it targets *audience*.

    Args:
        token: Serialized JWT to verify against the module signing key.
        audience: Value the token's ``aud`` claim must equal.

    Returns:
        The decoded claims dict, or ``None`` when the token cannot be parsed or
        verified, or when its ``aud`` claim differs from *audience*.
    """
    try:
        # Only accept the algorithm this module signs with.
        parsed = jwt.JWT(jwt=token, key=__signing_key, algs=[__TOKEN_ALGORITHM])
        claims = json.loads(parsed.claims)
        # Kept inside the try so a claims payload that is not a JSON object is
        # also reported as "invalid token" rather than raising.
        return claims if claims.get("aud") == audience else None
    except Exception:
        # Deliberately broad: any parsing or verification failure means the
        # token is not usable, and callers only need a yes/no answer.
        return None


async def create_access_token(
    user: User, session_id: UUID
) -> tuple[str, datetime]:
    """Issue an access token for *user* bound to *session_id*.

    The lifetime is read from ``CS_ACCESS_TOKEN_LIFETIME_SECONDS`` (default
    ``300`` seconds).

    Args:
        user: The user the token is issued for; ``user.id`` and ``user.role``
            are embedded as the ``sub`` and ``role`` claims.
        session_id: Session identifier stored in the ``session`` claim.

    Returns:
        A ``(token, expiry)`` tuple: the serialized token and the UTC datetime
        at which it expires.
    """
    token_lifetime = float(os.getenv("CS_ACCESS_TOKEN_LIFETIME_SECONDS", "300"))
    exp_time = datetime.now(UTC) + timedelta(seconds=token_lifetime)
    claims = {
        "aud": "access",
        "sub": str(user.id),
        "exp": exp_time.timestamp(),
        "session": str(session_id),
        # NOTE(review): verify_access_token rebuilds the role with Role(<claim>);
        # this round-trips only if str(role) yields the enum value — confirm
        # against the Role definition.
        "role": str(user.role),
    }
    return await __create_token(claims=claims), exp_time


async def verify_access_token(
    token: str, required_roles: Optional[list[str]] = None
) -> AccessToken | None:
    """Verify an access token and optionally enforce a role requirement.

    Args:
        token: Serialized JWT presented by the client.
        required_roles: Role claim values that are allowed. When ``None`` or
            empty, any role is accepted.

    Returns:
        An ``AccessToken`` built from the ``sub``, ``role`` and ``session``
        claims, or ``None`` when the token is invalid or not an access token.

    Raises:
        InsufficientPermissionsError: The token is valid but its ``role`` claim
            is not one of *required_roles*.
    """
    try:
        claims = await __verify_token(token=token, audience="access")
        if not claims:
            return None
        if required_roles and claims.get("role") not in required_roles:
            raise InsufficientPermissionsError
        return AccessToken(
            subject=claims.get("sub"),
            role=Role(claims.get("role")),
            session=claims.get("session"),
        )
    except InvalidTokenAudienceError:
        # __verify_token reports an audience mismatch by returning None; this
        # handler is retained so the outcome stays "no token" if that changes.
        return None