import json
import os
import secrets
from typing import Optional
from uuid import UUID
from jwcrypto import jwt, jwk
from datetime import datetime, timedelta, UTC

from app.models.user import User
from app.schemas.auth_token import AccessToken
from app.schemas.user import Role
from app.util.errors import InsufficientPermissionsError, InvalidTokenAudienceError

# Symmetric key used to sign and verify every token issued by this module
# (tokens are created with the HS256 header below). The key is derived from
# the CS_TOKEN_SECRET environment variable; when that variable is unset, a
# random secret generated at import time is used instead, so the key is then
# different for every process/import of this module.
__signing_key = jwk.JWK.from_password(os.getenv("CS_TOKEN_SECRET", secrets.token_urlsafe(64)))

async def __create_token(claims: dict) -> str:
    """Sign *claims* into an HS256 JWT and return its serialized form.

    The issuer (``iss``, read from ``CS_TOKEN_ISSUER``) and issued-at
    (``iat``, current UTC timestamp) claims are always set here and take
    precedence over same-named entries supplied by the caller.
    """
    issuer = os.getenv("CS_TOKEN_ISSUER", "https://localhost:8000")
    issued_at = datetime.now(UTC).timestamp()
    # Later keys win, so the module-controlled claims override caller values.
    payload = {**claims, "iss": issuer, "iat": issued_at}
    signed = jwt.JWT(
        header={"alg": "HS256", "typ": "JWT", "kid": "default"},
        claims=payload,
    )
    signed.make_signed_token(__signing_key)
    return signed.serialize()


async def __verify_token(token: str, audience: str) -> dict | None:
    """Validate a serialized JWT and return its claims for *audience*.

    Args:
        token: Serialized JWT as received from the caller.
        audience: Value the token's ``aud`` claim must equal.

    Returns:
        The decoded claims dict, or ``None`` when the token cannot be
        parsed/verified with the module signing key, its claims are not a
        JSON object, or its ``aud`` claim does not equal *audience*.
    """
    try:
        parsed = jwt.JWT(jwt=token, key=__signing_key)
        claims = json.loads(parsed.claims)
    except Exception:
        # The token is untrusted input: fail closed on any parsing or
        # verification error instead of leaking the failure reason.
        return None
    # An audience mismatch is reported the same way as any other invalid
    # token; raising here would only be swallowed by the handler above.
    if not isinstance(claims, dict) or claims.get("aud") != audience:
        return None
    return claims


async def create_access_token(
    user: User, session_id: UUID
) -> tuple[str, datetime]:
    """Issue a signed access token for *user* bound to *session_id*.

    The lifetime in seconds is read from ``CS_ACCESS_TOKEN_LIFETIME_SECONDS``
    (default ``"300"``).

    Returns:
        A ``(token, expires_at)`` pair: the serialized token and the
        timezone-aware UTC datetime at which it expires.
    """
    lifetime = timedelta(
        seconds=float(os.getenv("CS_ACCESS_TOKEN_LIFETIME_SECONDS", "300"))
    )
    expires_at = datetime.now(UTC) + lifetime
    access_claims = {
        "aud": "access",
        "sub": str(user.id),
        "exp": expires_at.timestamp(),
        "session": str(session_id),
        "role": str(user.role),
    }
    signed_token = await __create_token(claims=access_claims)
    return signed_token, expires_at

async def verify_access_token(
    token: str, required_roles: Optional[list[str]] = None
) -> AccessToken | None:
    """Validate an access token and optionally enforce a role requirement.

    Args:
        token: Serialized JWT presented by the caller.
        required_roles: ``role`` claim values that are allowed to pass.
            ``None`` or an empty list disables the role check.

    Returns:
        An ``AccessToken`` built from the ``sub``, ``role`` and ``session``
        claims, or ``None`` when the token is not a valid access token.

    Raises:
        InsufficientPermissionsError: The token is valid but its ``role``
            claim is not one of *required_roles*.
    """
    try:
        claims = await __verify_token(token=token, audience="access")
    except InvalidTokenAudienceError:
        # A token minted for another audience is simply not a valid access
        # token; report it like any other verification failure.
        return None
    if not claims:
        return None

    role = claims.get("role")
    if required_roles and role not in required_roles:
        raise InsufficientPermissionsError
    return AccessToken(
        subject=claims.get("sub"),
        role=Role(role),
        session=claims.get("session"),
    )