import logging
from fastapi import WebSocket, WebSocketDisconnect
from websockets import ConnectionClosed

# Emit through the logger named "gunicorn.error" rather than a per-module
# logger, so these messages share that logger's handlers/configuration.
_GUNICORN_LOGGER_NAME = "gunicorn.error"
logger = logging.getLogger(_GUNICORN_LOGGER_NAME)

# Wrapper that adapts a FastAPI WebSocket to the recv/send/close interface of a `websockets` library connection
class WebSocketWrapper:
    """Expose a FastAPI ``WebSocket`` through ``recv``/``send``/``close``.

    A client disconnect observed while receiving is re-raised as
    ``websockets.ConnectionClosed``. Code written against the ``websockets``
    library can therefore handle it without knowing about FastAPI.
    Every text message received or sent is logged at INFO level.
    """

    def __init__(self, websocket: WebSocket):
        self._websocket = websocket

    @staticmethod
    def _connection_closed(code: int, reason: str) -> ConnectionClosed:
        """Build a ``ConnectionClosed`` describing a peer-initiated close.

        websockets >= 10 constructs the exception from the received and sent
        ``Close`` frames. Older releases take ``(code, reason)`` directly.
        Passing ``(code, reason)`` to the newer constructor produces a
        malformed exception, hence the version-dependent construction.
        """
        try:
            from websockets.frames import Close
        except ImportError:
            # websockets < 10: legacy ConnectionClosed(code, reason) signature.
            return ConnectionClosed(code, reason)
        # The close frame came from the peer; we have not sent one.
        return ConnectionClosed(Close(code, reason), None)

    async def recv(self) -> str:
        """Receive the next text message.

        Raises:
            ConnectionClosed: if the client disconnected (translated from
                FastAPI's ``WebSocketDisconnect``, which is chained as the
                cause).
        """
        try:
            text = await self._websocket.receive_text()
        except WebSocketDisconnect as e:
            raise self._connection_closed(e.code, 'WebSocketWrapper') from e
        logger.info("Message received: %s", text)
        return text

    async def send(self, msg: str) -> None:
        """Send ``msg`` as a text message."""
        await self._websocket.send_text(msg)
        # Log only after the send completed, so a send that raised is not
        # reported as sent.
        logger.info("Message sent: %s", msg)

    async def close(self, code: int = 1000, reason: str = '') -> None:
        """Close the connection with the given close ``code`` and ``reason``.

        Defaults mirror ``websockets``' ``close(code=1000, reason='')``.
        """
        # NOTE(review): forwarding ``reason`` requires a Starlette version whose
        # WebSocket.close accepts it — confirm against the pinned version.
        await self._websocket.close(code, reason)