"""Voice-in/voice-out server for the Shop Bob machine shop assistant.

STT (faster-whisper), LLM (Ollama), TTS (Piper) with sentence-level audio
streaming over WebSocket for low-latency responses.
"""
import json
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
from . import llm, stt, tts
|
|
from .config import settings
|
|
from .connection_manager import ConnectionManager
|
|
from .pipeline import process_request
|
|
|
|
# Configure root logging once at import time so every module logger in the
# package shares the same timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Process-wide registry of live WebSocket clients; read by /health for the
# connection count and used by /ws for register/cleanup.
manager = ConnectionManager()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load speech models and verify Ollama before serving; log shutdown.

    Everything before ``yield`` runs at startup, everything after it runs
    when the server shuts down.
    """
    logger.info("Starting Shop Bob server...")
    stt.load_model()
    tts.load_model()
    ollama_up = await llm.check_ollama()
    if not ollama_up:
        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)
    yield
    logger.info("Shutting down Shop Bob server...")
|
|
|
|
|
|
# Application object; model loading/teardown is handled by `lifespan`.
app = FastAPI(title="Shop Bob", lifespan=lifespan)

# CORS is wide open (all origins/methods/headers) — acceptable for a
# LAN-only shop assistant, but tighten `allow_origins` before exposing
# this server beyond the local network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: reports live WebSocket count and Ollama reachability."""
    return {
        "status": "ok",
        "active_connections": manager.active_count,
        "ollama": "ok" if await llm.check_ollama() else "unreachable",
    }
|
|
|
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Bidirectional voice pipeline endpoint.

    Protocol:
      1. Client sends a JSON ``audio_start`` message with ``client_id`` and
         an optional ``sample_rate``.
      2. Client streams raw audio as binary frames, then sends ``audio_end``.
      3. Server runs the STT -> LLM -> TTS pipeline via ``process_request``,
         streaming results back on this socket, then waits for the next
         utterance.
    """
    client_id: str | None = None
    try:
        # The WebSocket handshake must be accepted before any receive/send;
        # the original code received first, which Starlette rejects.
        # NOTE(review): assumes ConnectionManager.connect() only registers the
        # socket and does not call accept() itself — confirm.
        await websocket.accept()

        # Wait for the first message which should be audio_start
        raw = await websocket.receive_text()
        msg = json.loads(raw)

        if msg.get("type") != "audio_start":
            await websocket.close(code=1008, reason="Expected audio_start message")
            return

        client_id = msg.get("client_id", "unknown")
        sample_rate = msg.get("sample_rate", settings.stt_sample_rate)

        await manager.connect(client_id, websocket)

        # Main message loop: one iteration per utterance.
        while True:
            audio_chunks: list[bytes] = []

            # Collect binary audio frames until audio_end
            while True:
                message = await websocket.receive()

                # Raw receive() does NOT raise on client disconnect — it
                # returns a "websocket.disconnect" message. Without this
                # check, the next receive() raises RuntimeError and gets
                # logged as a spurious server error instead of taking the
                # clean WebSocketDisconnect path below.
                if message.get("type") == "websocket.disconnect":
                    raise WebSocketDisconnect(code=message.get("code") or 1000)

                if "text" in message:
                    data = json.loads(message["text"])
                    if data.get("type") == "audio_end":
                        break
                    elif data.get("type") == "audio_start":
                        # New utterance — update sample rate if provided
                        sample_rate = data.get("sample_rate", sample_rate)
                        audio_chunks = []
                        continue
                elif "bytes" in message:
                    audio_chunks.append(message["bytes"])

            if audio_chunks:
                audio_bytes = b"".join(audio_chunks)
                await process_request(audio_bytes, sample_rate, websocket)

    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
    except Exception:
        logger.exception("WebSocket error for client %s", client_id)
    finally:
        if client_id:
            manager.disconnect(client_id)
|
|
|
|
|
|
if __name__ == "__main__":
    # Dev convenience: run the ASGI app under uvicorn when executed directly.
    import uvicorn

    run_opts = {
        "host": settings.host,
        "port": settings.port,
        "log_level": "info",
    }
    uvicorn.run("server.main:app", **run_opts)
|