import json
import logging
import uuid
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from . import llm, stt, tts
from .config import settings
from .connection_manager import ConnectionManager
from .pipeline import process_request

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Registry of live WebSocket sessions; connect()/disconnect() below take the
# client id as their key.
manager = ConnectionManager()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the STT/TTS models and probe Ollama at startup; log on shutdown.

    An unreachable Ollama only produces a warning, so the server still starts
    (and /health can report the problem) while the LLM backend comes up.
    """
    # Startup
    logger.info("Starting Shop Bob server...")
    stt.load_model()
    tts.load_model()
    if not await llm.check_ollama():
        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)
    yield
    # Shutdown
    logger.info("Shutting down Shop Bob server...")


app = FastAPI(title="Shop Bob", lifespan=lifespan)
# Wide-open CORS for the HTTP routes. NOTE(review): Starlette's CORSMiddleware
# only inspects HTTP requests; it does not restrict the /ws handshake.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    """Report liveness, the number of active sessions and Ollama reachability."""
    ollama_ok = await llm.check_ollama()
    return {
        "status": "ok",
        "active_connections": manager.active_count,
        "ollama": "ok" if ollama_ok else "unreachable",
    }


async def _receive_frame(websocket: WebSocket) -> dict:
    """Receive one raw ASGI message, raising WebSocketDisconnect on disconnect.

    ``WebSocket.receive()`` *returns* a ``websocket.disconnect`` message instead
    of raising, and any later ``receive()`` raises ``RuntimeError``. Converting
    it here lets the endpoint handle client disconnects in one place.
    """
    message = await websocket.receive()
    if message["type"] == "websocket.disconnect":
        raise WebSocketDisconnect(code=message.get("code", 1000))
    return message


def _parse_control(text: str | None) -> dict | None:
    """Decode a text frame as a JSON object; return None if it is not one."""
    if text is None:
        return None
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return None
    return data if isinstance(data, dict) else None


def _sample_rate_or(value: object, fallback: int) -> int:
    """Return *value* if it is a positive integer sample rate, else *fallback*.

    ``None`` (field absent) falls back silently; any other invalid value is
    logged, since it comes straight from the client.
    """
    if value is None:
        return fallback
    if isinstance(value, int) and not isinstance(value, bool) and value > 0:
        return value
    logger.warning("Ignoring invalid sample_rate %r; using %s", value, fallback)
    return fallback


async def _collect_utterance(
    websocket: WebSocket, client_id: str, sample_rate: int
) -> tuple[bytes, int]:
    """Buffer binary audio frames until an ``audio_end`` control message.

    Returns the concatenated audio and the sample rate in effect for it. An
    ``audio_start`` received mid-utterance discards what was buffered so far
    and may update the sample rate. Malformed control frames are logged and
    skipped; other message types are ignored.

    Raises:
        WebSocketDisconnect: if the client disconnects while we are waiting.
    """
    chunks: list[bytes] = []
    while True:
        message = await _receive_frame(websocket)
        # Per the ASGI spec both keys may be present with one set to None,
        # so test the values rather than key membership.
        text = message.get("text")
        if text is None:
            chunk = message.get("bytes")
            if chunk:
                chunks.append(chunk)
            continue

        data = _parse_control(text)
        if data is None:
            logger.warning("Ignoring malformed control message from client %s", client_id)
            continue

        msg_type = data.get("type")
        if msg_type == "audio_end":
            return b"".join(chunks), sample_rate
        if msg_type == "audio_start":
            # New utterance — update sample rate if provided
            sample_rate = _sample_rate_or(data.get("sample_rate"), sample_rate)
            chunks = []


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Serve one voice session over a WebSocket.

    Protocol:
      1. The first frame must be a JSON text frame with ``"type": "audio_start"``;
         it may carry ``client_id`` and ``sample_rate``. Anything else closes
         the socket with code 1008.
      2. Binary frames carry audio; ``{"type": "audio_end"}`` hands the buffered
         audio to ``process_request``.
      3. Step 2 repeats for each utterance until the client disconnects.
    """
    client_id: str | None = None
    try:
        # Starlette refuses receive*() until the handshake has been accepted.
        # NOTE(review): if ConnectionManager.connect() also calls
        # websocket.accept(), drop it there — a second accept raises RuntimeError.
        await websocket.accept()

        first = _parse_control((await _receive_frame(websocket)).get("text"))
        if first is None or first.get("type") != "audio_start":
            await websocket.close(code=1008, reason="Expected audio_start message")
            return

        # disconnect() is called with the id alone, so anonymous clients
        # sharing one fallback id (the old "unknown") could presumably evict
        # each other's registration; give each one a unique id instead.
        raw_id = first.get("client_id")
        client_id = str(raw_id) if raw_id else f"anonymous-{uuid.uuid4().hex}"
        sample_rate = _sample_rate_or(first.get("sample_rate"), settings.stt_sample_rate)
        await manager.connect(client_id, websocket)

        # Main message loop: one pipeline request per non-empty utterance.
        while True:
            audio_bytes, sample_rate = await _collect_utterance(
                websocket, client_id, sample_rate
            )
            if audio_bytes:
                await process_request(audio_bytes, sample_rate, websocket)
    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
    except Exception:
        logger.exception("WebSocket error for client %s", client_id)
    finally:
        # client_id is only assigned right before manager.connect(), so this
        # unregisters exactly the sessions that were registered.
        if client_id is not None:
            manager.disconnect(client_id)


if __name__ == "__main__":
    # The relative imports above need package context: launch with
    # `python -m server.main`, not as a plain script.
    import uvicorn

    uvicorn.run(
        "server.main:app",
        host=settings.host,
        port=settings.port,
        log_level="info",
    )