Add server component: FastAPI + WebSocket speech pipeline

Voice-in/voice-out server for the Shop Bob machine shop assistant.
STT (faster-whisper), LLM (Ollama), TTS (Piper) with sentence-level
audio streaming over WebSocket for low-latency responses.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
commit 98310bf062
dan, 2026-02-05 13:23:01 -08:00
9 changed files with 441 additions and 0 deletions

server/__init__.py (new file, empty)

server/config.py (new file, +39)
@@ -0,0 +1,39 @@
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    model_config = {"env_prefix": "BOB_", "env_file": ".env", "env_file_encoding": "utf-8"}

    # Networking
    host: str = "0.0.0.0"
    port: int = 8765

    # Whisper STT
    whisper_model: str = "large-v3"
    whisper_device: str = "cuda"
    whisper_compute_type: str = "float16"
    stt_sample_rate: int = 16000
    max_concurrent_transcriptions: int = 2

    # Ollama LLM
    ollama_url: str = "http://localhost:11434"
    llm_model: str = "llama3.1:8b"
    max_concurrent_llm: int = 3

    # Piper TTS
    piper_model: str = "en_US-lessac-medium"
    tts_sample_rate: int = 22050

    # System prompt for the machine shop assistant
    system_prompt: str = (
        "You are Bob, a knowledgeable machine shop assistant. "
        "Give concise, direct answers about machining, tooling, materials, "
        "feeds and speeds, and shop processes. "
        "Always prioritize safety — if a question involves a potentially "
        "dangerous operation, lead with the safety considerations. "
        "Keep answers short and practical — shop floor workers need quick info, "
        "not essays."
    )


settings = Settings()
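
For reference, a minimal sketch of overriding these settings through the environment: pydantic-settings maps each field to a BOB_-prefixed variable (the values below are illustrative, not project defaults).

# Sketch: override Settings fields via BOB_-prefixed environment variables.
# Values here are illustrative only.
import os

os.environ["BOB_PORT"] = "9001"
os.environ["BOB_WHISPER_DEVICE"] = "cpu"

from server.config import Settings

assert Settings().port == 9001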

server/connection_manager.py (new file, +26)
@@ -0,0 +1,26 @@
import logging

from fastapi import WebSocket

logger = logging.getLogger(__name__)


class ConnectionManager:
    def __init__(self) -> None:
        self._connections: dict[str, WebSocket] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> None:
        await websocket.accept()
        self._connections[client_id] = websocket
        logger.info("Client connected: %s (total: %d)", client_id, len(self._connections))

    def disconnect(self, client_id: str) -> None:
        self._connections.pop(client_id, None)
        logger.info("Client disconnected: %s (total: %d)", client_id, len(self._connections))

    def get_active_connections(self) -> dict[str, WebSocket]:
        return dict(self._connections)

    @property
    def active_count(self) -> int:
        return len(self._connections)

server/llm.py (new file, +59)
@@ -0,0 +1,59 @@
import asyncio
import json
import logging
from collections.abc import AsyncGenerator

import httpx

from .config import settings

logger = logging.getLogger(__name__)

_semaphore = asyncio.Semaphore(settings.max_concurrent_llm)


async def check_ollama() -> bool:
    """Verify Ollama is reachable."""
    try:
        async with httpx.AsyncClient() as client:
            resp = await client.get(f"{settings.ollama_url}/api/tags", timeout=5)
            resp.raise_for_status()
            return True
    except Exception as e:
        logger.error("Ollama not reachable at %s: %s", settings.ollama_url, e)
        return False


async def generate_response(
    transcript: str,
    system_prompt: str | None = None,
) -> AsyncGenerator[str, None]:
    """Stream text tokens from Ollama for the given user transcript."""
    prompt = system_prompt or settings.system_prompt
    payload = {
        "model": settings.llm_model,
        "messages": [
            {"role": "system", "content": prompt},
            {"role": "user", "content": transcript},
        ],
        "stream": True,
    }
    async with _semaphore:
        async with httpx.AsyncClient(timeout=httpx.Timeout(120.0, connect=10.0)) as client:
            async with client.stream(
                "POST",
                f"{settings.ollama_url}/api/chat",
                json=payload,
            ) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line:
                        continue
                    data = json.loads(line)
                    token = data.get("message", {}).get("content", "")
                    if token:
                        yield token
                    if data.get("done"):
                        break
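
As a quick reference, a sketch of consuming the token stream from generate_response(); it assumes Ollama is running and settings.llm_model has already been pulled.

# Sketch: print streamed tokens as they arrive from Ollama.
# Assumes a reachable Ollama instance with the configured model pulled.
import asyncio

from server import llm


async def demo() -> None:
    async for token in llm.generate_response("What RPM for drilling 304 stainless?"):
        print(token, end="", flush=True)
    print()


asyncio.run(demo())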

server/main.py (new file, +114)
@@ -0,0 +1,114 @@
import json
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware

from . import llm, stt, tts
from .config import settings
from .connection_manager import ConnectionManager
from .pipeline import process_request

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

manager = ConnectionManager()


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Starting Shop Bob server...")
    stt.load_model()
    tts.load_model()
    if not await llm.check_ollama():
        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)
    yield
    # Shutdown
    logger.info("Shutting down Shop Bob server...")


app = FastAPI(title="Shop Bob", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    ollama_ok = await llm.check_ollama()
    return {
        "status": "ok",
        "active_connections": manager.active_count,
        "ollama": "ok" if ollama_ok else "unreachable",
    }


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    client_id: str | None = None
    try:
        # Wait for the first message, which should be audio_start
        raw = await websocket.receive_text()
        msg = json.loads(raw)
        if msg.get("type") != "audio_start":
            await websocket.close(code=1008, reason="Expected audio_start message")
            return
        client_id = msg.get("client_id", "unknown")
        sample_rate = msg.get("sample_rate", settings.stt_sample_rate)
        await manager.connect(client_id, websocket)

        # Main message loop
        while True:
            audio_chunks: list[bytes] = []
            # Collect binary audio frames until audio_end
            while True:
                message = await websocket.receive()
                if message["type"] == "websocket.disconnect":
                    # Raw receive() does not raise on disconnect, so surface it here
                    raise WebSocketDisconnect(message.get("code", 1000))
                if "text" in message:
                    data = json.loads(message["text"])
                    if data.get("type") == "audio_end":
                        break
                    elif data.get("type") == "audio_start":
                        # New utterance — update sample rate if provided
                        sample_rate = data.get("sample_rate", sample_rate)
                        audio_chunks = []
                elif "bytes" in message:
                    audio_chunks.append(message["bytes"])
            if audio_chunks:
                audio_bytes = b"".join(audio_chunks)
                await process_request(audio_bytes, sample_rate, websocket)
    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
    except Exception:
        logger.exception("WebSocket error for client %s", client_id)
    finally:
        if client_id:
            manager.disconnect(client_id)


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "server.main:app",
        host=settings.host,
        port=settings.port,
        log_level="info",
    )
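
For context, a minimal test-client sketch for the /ws protocol above, using the websockets package from requirements.txt. The file "utterance.raw" is hypothetical; it assumes 16 kHz, 16-bit mono PCM.

# Sketch: exercise the /ws protocol end to end.
# "utterance.raw" is a hypothetical 16 kHz, 16-bit mono PCM capture.
import asyncio
import json

import websockets


async def main() -> None:
    async with websockets.connect("ws://localhost:8765/ws") as ws:
        await ws.send(json.dumps({"type": "audio_start", "client_id": "test", "sample_rate": 16000}))
        with open("utterance.raw", "rb") as f:
            await ws.send(f.read())  # binary audio frame(s)
        await ws.send(json.dumps({"type": "audio_end"}))

        while True:
            msg = await ws.recv()
            if isinstance(msg, bytes):
                # Binary frames are synthesized speech, one chunk per sentence
                print(f"<audio chunk: {len(msg)} bytes PCM>")
                continue
            event = json.loads(msg)
            print(event)
            if event.get("type") == "response_end":
                break


asyncio.run(main())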

server/pipeline.py (new file, +85)
@@ -0,0 +1,85 @@
import json
import logging
import re

from fastapi import WebSocket

from . import llm, stt, tts

logger = logging.getLogger(__name__)

# Regex to split text on sentence boundaries while keeping the delimiters
_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")


async def _send_status(ws: WebSocket, state: str) -> None:
    await ws.send_text(json.dumps({"type": "status", "state": state}))


async def process_request(
    audio_bytes: bytes,
    sample_rate: int,
    websocket: WebSocket,
) -> None:
    """Run the full speech-in → text-out → speech-out pipeline."""
    try:
        # --- STT ---
        await _send_status(websocket, "transcribing")
        transcript = await stt.transcribe(audio_bytes, sample_rate)
        if not transcript.strip():
            await websocket.send_text(json.dumps({"type": "transcript", "text": ""}))
            await websocket.send_text(json.dumps({"type": "response_end"}))
            return
        await websocket.send_text(json.dumps({"type": "transcript", "text": transcript}))

        # --- LLM ---
        await _send_status(websocket, "thinking")
        full_response = ""
        sentence_buffer = ""

        # --- Sentence-level TTS streaming ---
        await _send_status(websocket, "speaking")
        async for token in llm.generate_response(transcript):
            full_response += token
            sentence_buffer += token
            # Check if we have one or more complete sentences
            parts = _SENTENCE_RE.split(sentence_buffer)
            if len(parts) > 1:
                # All parts except the last are complete sentences
                for sentence in parts[:-1]:
                    sentence = sentence.strip()
                    if sentence:
                        audio_chunk = await tts.synthesize(sentence)
                        await websocket.send_bytes(audio_chunk)
                # Keep the incomplete remainder
                sentence_buffer = parts[-1]

        # Flush any remaining text
        sentence_buffer = sentence_buffer.strip()
        if sentence_buffer:
            audio_chunk = await tts.synthesize(sentence_buffer)
            await websocket.send_bytes(audio_chunk)

        # Send the full text response and signal completion
        await websocket.send_text(json.dumps({"type": "response_text", "text": full_response}))
        await websocket.send_text(json.dumps({"type": "response_end"}))
    except Exception:
        logger.exception("Pipeline error")
        try:
            await websocket.send_text(json.dumps({"type": "error", "text": "Internal processing error"}))
            await websocket.send_text(json.dumps({"type": "response_end"}))
        except Exception:
            pass  # Client already disconnected
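
A quick standalone illustration of how the sentence buffer splits as tokens accumulate (no server needed):

# Sketch: _SENTENCE_RE keeps terminators attached via the lookbehind,
# so parts[:-1] are complete sentences and parts[-1] is the remainder.
import re

_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")

buffer = "Use flood coolant. Back off the feed if it chatters! Also che"
parts = _SENTENCE_RE.split(buffer)
print(parts)
# ['Use flood coolant.', 'Back off the feed if it chatters!', 'Also che']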

server/requirements.txt (new file, +8)
@@ -0,0 +1,8 @@
fastapi>=0.104
uvicorn[standard]>=0.24
websockets>=12.0
faster-whisper>=1.0
httpx>=0.25
piper-tts>=1.2
numpy>=1.24
pydantic-settings>=2.0

server/stt.py (new file, +62)
@@ -0,0 +1,62 @@
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from functools import partial

import numpy as np
from faster_whisper import WhisperModel

from .config import settings

logger = logging.getLogger(__name__)

_model: WhisperModel | None = None
_executor = ThreadPoolExecutor(max_workers=settings.max_concurrent_transcriptions)
_semaphore = asyncio.Semaphore(settings.max_concurrent_transcriptions)


def load_model() -> None:
    global _model
    logger.info(
        "Loading Whisper model %s on %s (%s)...",
        settings.whisper_model,
        settings.whisper_device,
        settings.whisper_compute_type,
    )
    _model = WhisperModel(
        settings.whisper_model,
        device=settings.whisper_device,
        compute_type=settings.whisper_compute_type,
    )
    logger.info("Whisper model loaded.")


def _transcribe_sync(audio_bytes: bytes, sample_rate: int) -> str:
    assert _model is not None, "Whisper model not loaded — call load_model() first"
    # Convert raw PCM 16-bit mono bytes to float32 numpy array
    audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
    if sample_rate != 16000:
        # faster-whisper expects 16kHz — resample via simple linear interpolation
        duration = len(audio) / sample_rate
        target_len = int(duration * 16000)
        audio = np.interp(
            np.linspace(0, len(audio) - 1, target_len),
            np.arange(len(audio)),
            audio,
        ).astype(np.float32)
    segments, info = _model.transcribe(audio, beam_size=5)
    text = " ".join(seg.text.strip() for seg in segments)
    logger.info("Transcribed %.1fs audio → %d chars", info.duration, len(text))
    return text


async def transcribe(audio_bytes: bytes, sample_rate: int) -> str:
    async with _semaphore:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            _executor,
            partial(_transcribe_sync, audio_bytes, sample_rate),
        )
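
A sketch of feeding a WAV file through transcribe(); "clip.wav" is hypothetical and assumed to be 16-bit mono PCM:

# Sketch: transcribe a 16-bit mono PCM WAV file ("clip.wav" is hypothetical).
import asyncio
import wave

from server import stt


async def demo() -> None:
    stt.load_model()
    with wave.open("clip.wav", "rb") as wf:
        assert wf.getsampwidth() == 2 and wf.getnchannels() == 1
        pcm = wf.readframes(wf.getnframes())
        rate = wf.getframerate()
    print(await stt.transcribe(pcm, rate))


asyncio.run(demo())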

server/tts.py (new file, +48)
@@ -0,0 +1,48 @@
import asyncio
import io
import logging
import wave
from concurrent.futures import ThreadPoolExecutor
from functools import partial

from piper.voice import PiperVoice

from .config import settings

logger = logging.getLogger(__name__)

_voice: PiperVoice | None = None
_executor = ThreadPoolExecutor(max_workers=2)


def load_model() -> None:
    global _voice
    logger.info("Loading Piper TTS voice %s...", settings.piper_model)
    _voice = PiperVoice.load(settings.piper_model)
    logger.info("Piper TTS loaded.")


def _synthesize_sync(text: str) -> bytes:
    """Synthesize text to raw PCM 16-bit mono audio bytes."""
    assert _voice is not None, "Piper voice not loaded — call load_model() first"
    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        _voice.synthesize(text, wf)
    # Extract raw PCM from the WAV container
    buf.seek(0)
    with wave.open(buf, "rb") as wf:
        pcm_data = wf.readframes(wf.getnframes())
    logger.debug("Synthesized %d chars → %d bytes PCM", len(text), len(pcm_data))
    return pcm_data


async def synthesize(text: str) -> bytes:
    """Async wrapper — runs Piper in a thread pool."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        _executor,
        partial(_synthesize_sync, text),
    )
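
Finally, a sketch of saving synthesize() output to a playable WAV; "out.wav" is hypothetical, and the frame rate assumes the voice's native rate matches settings.tts_sample_rate (22050 Hz for en_US-lessac-medium):

# Sketch: wrap the raw PCM returned by synthesize() into a playable WAV file.
# "out.wav" is hypothetical; assumes the voice's native sample rate matches
# settings.tts_sample_rate.
import asyncio
import wave

from server import tts
from server.config import settings


async def demo() -> None:
    tts.load_model()
    pcm = await tts.synthesize("Spindle up. Check your workholding first.")
    with wave.open("out.wav", "wb") as wf:
        wf.setnchannels(1)   # mono
        wf.setsampwidth(2)   # 16-bit samples
        wf.setframerate(settings.tts_sample_rate)
        wf.writeframes(pcm)


asyncio.run(demo())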