From 98310bf062ed2c86b52218a839691b0eda0e47a1 Mon Sep 17 00:00:00 2001
From: dan
Date: Thu, 5 Feb 2026 13:23:01 -0800
Subject: [PATCH] Add server component: FastAPI + WebSocket speech pipeline

Voice-in/voice-out server for the Shop Bob machine shop assistant.
STT (faster-whisper), LLM (Ollama), TTS (Piper) with sentence-level
audio streaming over WebSocket for low-latency responses.

Co-Authored-By: Claude Opus 4.5
---
 server/__init__.py           |   0
 server/config.py             |  39 ++++++++++++
 server/connection_manager.py |  26 ++++
 server/llm.py                |  59 ++++++++++++++++++
 server/main.py               | 121 +++++++++++++++++++++++++++++++++++
 server/pipeline.py           |  85 ++++++++++++++++++++++++++
 server/requirements.txt      |   8 +++
 server/stt.py                |  62 +++++++++++++++++++
 server/tts.py                |  48 +++++++++++++++
 9 files changed, 448 insertions(+)
 create mode 100644 server/__init__.py
 create mode 100644 server/config.py
 create mode 100644 server/connection_manager.py
 create mode 100644 server/llm.py
 create mode 100644 server/main.py
 create mode 100644 server/pipeline.py
 create mode 100644 server/requirements.txt
 create mode 100644 server/stt.py
 create mode 100644 server/tts.py

diff --git a/server/__init__.py b/server/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/server/config.py b/server/config.py
new file mode 100644
index 0000000..556249c
--- /dev/null
+++ b/server/config.py
@@ -0,0 +1,39 @@
+from pydantic_settings import BaseSettings
+
+
+class Settings(BaseSettings):
+    model_config = {"env_prefix": "BOB_", "env_file": ".env", "env_file_encoding": "utf-8"}
+
+    # Networking
+    host: str = "0.0.0.0"
+    port: int = 8765
+
+    # Whisper STT
+    whisper_model: str = "large-v3"
+    whisper_device: str = "cuda"
+    whisper_compute_type: str = "float16"
+    stt_sample_rate: int = 16000
+    max_concurrent_transcriptions: int = 2
+
+    # Ollama LLM
+    ollama_url: str = "http://localhost:11434"
+    llm_model: str = "llama3.1:8b"
+    max_concurrent_llm: int = 3
+
+    # Piper TTS
+    piper_model: str = "en_US-lessac-medium"
+    tts_sample_rate: int = 22050
+
+    # System prompt for the machine shop assistant
+    system_prompt: str = (
+        "You are Bob, a knowledgeable machine shop assistant. "
+        "Give concise, direct answers about machining, tooling, materials, "
+        "feeds and speeds, and shop processes. "
+        "Always prioritize safety — if a question involves a potentially "
+        "dangerous operation, lead with the safety considerations. "
+        "Keep answers short and practical — shop floor workers need quick info, "
+        "not essays."
+    )
+
+
+settings = Settings()
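
Usage sketch (not part of the patch): every Settings field can be
overridden through a BOB_-prefixed environment variable or a .env file,
via pydantic-settings. The override values below are illustrative:

    import os
    os.environ["BOB_WHISPER_DEVICE"] = "cpu"         # e.g. a box with no GPU
    os.environ["BOB_WHISPER_COMPUTE_TYPE"] = "int8"

    from server.config import Settings
    s = Settings()
    print(s.whisper_device, s.whisper_compute_type)  # cpu int8
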
diff --git a/server/connection_manager.py b/server/connection_manager.py
new file mode 100644
index 0000000..37764c2
--- /dev/null
+++ b/server/connection_manager.py
@@ -0,0 +1,26 @@
+import logging
+
+from fastapi import WebSocket
+
+logger = logging.getLogger(__name__)
+
+
+class ConnectionManager:
+    def __init__(self) -> None:
+        self._connections: dict[str, WebSocket] = {}
+
+    async def connect(self, client_id: str, websocket: WebSocket) -> None:
+        # The endpoint accepts the websocket before registering it here
+        self._connections[client_id] = websocket
+        logger.info("Client connected: %s (total: %d)", client_id, len(self._connections))
+
+    def disconnect(self, client_id: str) -> None:
+        self._connections.pop(client_id, None)
+        logger.info("Client disconnected: %s (total: %d)", client_id, len(self._connections))
+
+    def get_active_connections(self) -> dict[str, WebSocket]:
+        return dict(self._connections)
+
+    @property
+    def active_count(self) -> int:
+        return len(self._connections)
diff --git a/server/llm.py b/server/llm.py
new file mode 100644
index 0000000..67c5dd9
--- /dev/null
+++ b/server/llm.py
@@ -0,0 +1,59 @@
+import asyncio
+import json
+import logging
+from collections.abc import AsyncGenerator
+
+import httpx
+
+from .config import settings
+
+logger = logging.getLogger(__name__)
+
+_semaphore = asyncio.Semaphore(settings.max_concurrent_llm)
+
+
+async def check_ollama() -> bool:
+    """Verify Ollama is reachable."""
+    try:
+        async with httpx.AsyncClient() as client:
+            resp = await client.get(f"{settings.ollama_url}/api/tags", timeout=5)
+            resp.raise_for_status()
+            return True
+    except Exception as e:
+        logger.error("Ollama not reachable at %s: %s", settings.ollama_url, e)
+        return False
+
+
+async def generate_response(
+    transcript: str,
+    system_prompt: str | None = None,
+) -> AsyncGenerator[str, None]:
+    """Stream text tokens from Ollama for the given user transcript."""
+    prompt = system_prompt or settings.system_prompt
+
+    payload = {
+        "model": settings.llm_model,
+        "messages": [
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": transcript},
+        ],
+        "stream": True,
+    }
+
+    async with _semaphore:
+        async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
+            async with client.stream(
+                "POST",
+                f"{settings.ollama_url}/api/chat",
+                json=payload,
+            ) as resp:
+                resp.raise_for_status()
+                async for line in resp.aiter_lines():
+                    if not line:
+                        continue
+                    data = json.loads(line)
+                    token = data.get("message", {}).get("content", "")
+                    if token:
+                        yield token
+                    if data.get("done"):
+                        break
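
A minimal way to exercise the streamer above (not part of the patch),
assuming Ollama is already serving the configured model locally; the
question text is made up:

    import asyncio

    from server import llm

    async def demo() -> None:
        async for token in llm.generate_response(
            "What RPM for a 1/4 inch HSS end mill in 6061 aluminum?"
        ):
            print(token, end="", flush=True)
        print()

    asyncio.run(demo())
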
diff --git a/server/main.py b/server/main.py
new file mode 100644
index 0000000..be5108b
--- /dev/null
+++ b/server/main.py
@@ -0,0 +1,121 @@
+import json
+import logging
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from fastapi.middleware.cors import CORSMiddleware
+
+from . import llm, stt, tts
+from .config import settings
+from .connection_manager import ConnectionManager
+from .pipeline import process_request
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
+)
+logger = logging.getLogger(__name__)
+
+manager = ConnectionManager()
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup
+    logger.info("Starting Shop Bob server...")
+    stt.load_model()
+    tts.load_model()
+    if not await llm.check_ollama():
+        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
+    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)
+    yield
+    # Shutdown
+    logger.info("Shutting down Shop Bob server...")
+
+
+app = FastAPI(title="Shop Bob", lifespan=lifespan)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+@app.get("/health")
+async def health():
+    ollama_ok = await llm.check_ollama()
+    return {
+        "status": "ok",
+        "active_connections": manager.active_count,
+        "ollama": "ok" if ollama_ok else "unreachable",
+    }
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    client_id: str | None = None
+    try:
+        # Accept the socket before any receive; Starlette requires it
+        await websocket.accept()
+
+        # Wait for the first message, which should be audio_start
+        raw = await websocket.receive_text()
+        msg = json.loads(raw)
+
+        if msg.get("type") != "audio_start":
+            await websocket.close(code=1008, reason="Expected audio_start message")
+            return
+
+        client_id = msg.get("client_id", "unknown")
+        sample_rate = msg.get("sample_rate", settings.stt_sample_rate)
+
+        await manager.connect(client_id, websocket)
+
+        # Main message loop
+        while True:
+            audio_chunks: list[bytes] = []
+
+            # Collect binary audio frames until audio_end
+            while True:
+                message = await websocket.receive()
+
+                # Raw receive() reports disconnects as a message instead of raising
+                if message["type"] == "websocket.disconnect":
+                    raise WebSocketDisconnect(message.get("code", 1000))
+
+                if "text" in message:
+                    data = json.loads(message["text"])
+                    if data.get("type") == "audio_end":
+                        break
+                    elif data.get("type") == "audio_start":
+                        # New utterance — update sample rate if provided
+                        sample_rate = data.get("sample_rate", sample_rate)
+                        audio_chunks = []
+                        continue
+                elif "bytes" in message:
+                    audio_chunks.append(message["bytes"])
+
+            if audio_chunks:
+                audio_bytes = b"".join(audio_chunks)
+                await process_request(audio_bytes, sample_rate, websocket)
+
+    except WebSocketDisconnect:
+        logger.info("Client %s disconnected", client_id)
+    except Exception:
+        logger.exception("WebSocket error for client %s", client_id)
+    finally:
+        if client_id:
+            manager.disconnect(client_id)
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "server.main:app",
+        host=settings.host,
+        port=settings.port,
+        log_level="info",
+    )
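
Client-side protocol sketch (not part of the patch), using the websockets
package pinned in requirements.txt; the URL and client_id are illustrative,
and pcm16 is 16 kHz 16-bit mono audio captured by the caller:

    import asyncio
    import json

    import websockets

    async def ask(pcm16: bytes) -> None:
        async with websockets.connect("ws://localhost:8765/ws") as ws:
            await ws.send(json.dumps(
                {"type": "audio_start", "client_id": "bench-3", "sample_rate": 16000}
            ))
            await ws.send(pcm16)  # one or more binary frames
            await ws.send(json.dumps({"type": "audio_end"}))
            while True:
                msg = await ws.recv()
                if isinstance(msg, bytes):
                    continue  # 22.05 kHz mono PCM; queue for playback
                if json.loads(msg).get("type") == "response_end":
                    break
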
diff --git a/server/pipeline.py b/server/pipeline.py
new file mode 100644
index 0000000..9199c1d
--- /dev/null
+++ b/server/pipeline.py
@@ -0,0 +1,85 @@
+import json
+import logging
+import re
+
+from fastapi import WebSocket
+
+from . import llm, stt, tts
+
+logger = logging.getLogger(__name__)
+
+# Regex to split text on sentence boundaries while keeping the delimiters
+_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
+
+
+async def _send_status(ws: WebSocket, state: str) -> None:
+    await ws.send_text(json.dumps({"type": "status", "state": state}))
+
+
+async def process_request(
+    audio_bytes: bytes,
+    sample_rate: int,
+    websocket: WebSocket,
+) -> None:
+    """Run the full speech-in → text-out → speech-out pipeline."""
+    try:
+        # --- STT ---
+        await _send_status(websocket, "transcribing")
+        transcript = await stt.transcribe(audio_bytes, sample_rate)
+
+        if not transcript.strip():
+            await websocket.send_text(
+                json.dumps({"type": "transcript", "text": ""})
+            )
+            await websocket.send_text(json.dumps({"type": "response_end"}))
+            return
+
+        await websocket.send_text(
+            json.dumps({"type": "transcript", "text": transcript})
+        )
+
+        # --- LLM ---
+        await _send_status(websocket, "thinking")
+        full_response = ""
+        sentence_buffer = ""
+
+        # --- Sentence-level TTS streaming ---
+        await _send_status(websocket, "speaking")
+
+        async for token in llm.generate_response(transcript):
+            full_response += token
+            sentence_buffer += token
+
+            # Check if we have one or more complete sentences
+            parts = _SENTENCE_RE.split(sentence_buffer)
+            if len(parts) > 1:
+                # All parts except the last are complete sentences
+                for sentence in parts[:-1]:
+                    sentence = sentence.strip()
+                    if sentence:
+                        audio_chunk = await tts.synthesize(sentence)
+                        await websocket.send_bytes(audio_chunk)
+                # Keep the incomplete remainder
+                sentence_buffer = parts[-1]
+
+        # Flush any remaining text
+        sentence_buffer = sentence_buffer.strip()
+        if sentence_buffer:
+            audio_chunk = await tts.synthesize(sentence_buffer)
+            await websocket.send_bytes(audio_chunk)
+
+        # Send the full text response and signal completion
+        await websocket.send_text(
+            json.dumps({"type": "response_text", "text": full_response})
+        )
+        await websocket.send_text(json.dumps({"type": "response_end"}))
+
+    except Exception:
+        logger.exception("Pipeline error")
+        try:
+            await websocket.send_text(
+                json.dumps({"type": "error", "text": "Internal processing error"})
+            )
+            await websocket.send_text(json.dumps({"type": "response_end"}))
+        except Exception:
+            pass  # Client already disconnected
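
How the splitter behaves on a partial token stream (standalone
illustration, not part of the patch):

    import re

    _SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
    parts = _SENTENCE_RE.split("Use coolant. Feed at 12 ipm. Then che")
    # parts == ["Use coolant.", "Feed at 12 ipm.", "Then che"]
    # parts[:-1] are synthesized immediately; parts[-1] stays buffered
    # until more tokens arrive.
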
diff --git a/server/requirements.txt b/server/requirements.txt
new file mode 100644
index 0000000..92ec998
--- /dev/null
+++ b/server/requirements.txt
@@ -0,0 +1,8 @@
+fastapi>=0.104
+uvicorn[standard]>=0.24
+websockets>=12.0
+faster-whisper>=1.0
+httpx>=0.25
+piper-tts>=1.2
+numpy>=1.24
+pydantic-settings>=2.0
diff --git a/server/stt.py b/server/stt.py
new file mode 100644
index 0000000..1d6ed3e
--- /dev/null
+++ b/server/stt.py
@@ -0,0 +1,62 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+import numpy as np
+from faster_whisper import WhisperModel
+
+from .config import settings
+
+logger = logging.getLogger(__name__)
+
+_model: WhisperModel | None = None
+_executor = ThreadPoolExecutor(max_workers=settings.max_concurrent_transcriptions)
+_semaphore = asyncio.Semaphore(settings.max_concurrent_transcriptions)
+
+
+def load_model() -> None:
+    global _model
+    logger.info(
+        "Loading Whisper model %s on %s (%s)...",
+        settings.whisper_model,
+        settings.whisper_device,
+        settings.whisper_compute_type,
+    )
+    _model = WhisperModel(
+        settings.whisper_model,
+        device=settings.whisper_device,
+        compute_type=settings.whisper_compute_type,
+    )
+    logger.info("Whisper model loaded.")
+
+
+def _transcribe_sync(audio_bytes: bytes, sample_rate: int) -> str:
+    assert _model is not None, "Whisper model not loaded — call load_model() first"
+
+    # Convert raw PCM 16-bit mono bytes to float32 numpy array
+    audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
+
+    if sample_rate != 16000:
+        # faster-whisper expects 16 kHz — resample via simple linear interpolation
+        duration = len(audio) / sample_rate
+        target_len = int(duration * 16000)
+        audio = np.interp(
+            np.linspace(0, len(audio) - 1, target_len),
+            np.arange(len(audio)),
+            audio,
+        ).astype(np.float32)
+
+    segments, info = _model.transcribe(audio, beam_size=5)
+    text = " ".join(seg.text.strip() for seg in segments)
+    logger.info("Transcribed %.1fs audio → %d chars", info.duration, len(text))
+    return text
+
+
+async def transcribe(audio_bytes: bytes, sample_rate: int) -> str:
+    async with _semaphore:
+        loop = asyncio.get_running_loop()
+        return await loop.run_in_executor(
+            _executor,
+            partial(_transcribe_sync, audio_bytes, sample_rate),
+        )
diff --git a/server/tts.py b/server/tts.py
new file mode 100644
index 0000000..36dd093
--- /dev/null
+++ b/server/tts.py
@@ -0,0 +1,48 @@
+import asyncio
+import io
+import logging
+import wave
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+from piper.voice import PiperVoice
+
+from .config import settings
+
+logger = logging.getLogger(__name__)
+
+_voice: PiperVoice | None = None
+_executor = ThreadPoolExecutor(max_workers=2)
+
+
+def load_model() -> None:
+    global _voice
+    logger.info("Loading Piper TTS voice %s...", settings.piper_model)
+    _voice = PiperVoice.load(settings.piper_model)
+    logger.info("Piper TTS loaded.")
+
+
+def _synthesize_sync(text: str) -> bytes:
+    """Synthesize text to raw PCM 16-bit mono audio bytes."""
+    assert _voice is not None, "Piper voice not loaded — call load_model() first"
+
+    buf = io.BytesIO()
+    with wave.open(buf, "wb") as wf:
+        _voice.synthesize(text, wf)
+
+    # Extract raw PCM from the WAV container
+    buf.seek(0)
+    with wave.open(buf, "rb") as wf:
+        pcm_data = wf.readframes(wf.getnframes())
+
+    logger.debug("Synthesized %d chars → %d bytes PCM", len(text), len(pcm_data))
+    return pcm_data
+
+
+async def synthesize(text: str) -> bytes:
+    """Async wrapper — runs Piper in a thread pool."""
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(
+        _executor,
+        partial(_synthesize_sync, text),
+    )
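
Quick end-to-end smoke test for the two thread-pool wrappers (not part of
the patch; assumes the Whisper and Piper model files are in place, and the
spoken sentence is arbitrary):

    import asyncio

    from server import stt, tts

    stt.load_model()
    tts.load_model()

    async def roundtrip() -> None:
        pcm = await tts.synthesize("Spindle warm-up complete.")
        print(len(pcm) / 2 / 22050, "seconds of audio")  # int16 mono at 22.05 kHz
        # Feeding Piper's 22.05 kHz output back through STT also exercises
        # the linear-interpolation resampler in stt.py
        print(await stt.transcribe(pcm, sample_rate=22050))

    asyncio.run(roundtrip())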