Add server component: FastAPI + WebSocket speech pipeline
Voice-in/voice-out server for the Shop Bob machine shop assistant. STT (faster-whisper), LLM (Ollama), TTS (Piper) with sentence-level audio streaming over WebSocket for low-latency responses. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
98310bf062
9 changed files with 441 additions and 0 deletions
0
server/__init__.py
Normal file
0
server/__init__.py
Normal file
39
server/config.py
Normal file
39
server/config.py
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
    """Runtime configuration for the Shop Bob voice server.

    Every field may be overridden via an environment variable with the
    ``BOB_`` prefix (e.g. ``BOB_PORT=9000``) or via a local ``.env`` file.
    """

    model_config = {"env_prefix": "BOB_", "env_file": ".env", "env_file_encoding": "utf-8"}

    # Networking
    host: str = "0.0.0.0"  # bind address; 0.0.0.0 listens on all interfaces
    port: int = 8765

    # Whisper STT
    whisper_model: str = "large-v3"
    whisper_device: str = "cuda"  # set to "cpu" on machines without a GPU
    whisper_compute_type: str = "float16"
    stt_sample_rate: int = 16000  # default client capture rate in Hz
    max_concurrent_transcriptions: int = 2  # bounds the STT thread pool / semaphore

    # Ollama LLM
    ollama_url: str = "http://localhost:11434"
    llm_model: str = "llama3.1:8b"
    max_concurrent_llm: int = 3  # bounds in-flight Ollama chat requests

    # Piper TTS
    piper_model: str = "en_US-lessac-medium"
    tts_sample_rate: int = 22050  # NOTE(review): must match the Piper voice's native rate — confirm

    # System prompt for the machine shop assistant
    system_prompt: str = (
        "You are Bob, a knowledgeable machine shop assistant. "
        "Give concise, direct answers about machining, tooling, materials, "
        "feeds and speeds, and shop processes. "
        "Always prioritize safety — if a question involves a potentially "
        "dangerous operation, lead with the safety considerations. "
        "Keep answers short and practical — shop floor workers need quick info, "
        "not essays."
    )


# Module-level singleton imported by the rest of the server.
settings = Settings()
|
||||||
26
server/connection_manager.py
Normal file
26
server/connection_manager.py
Normal file
|
|
@ -0,0 +1,26 @@
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionManager:
    """Registry of the WebSocket clients currently attached to the server."""

    def __init__(self) -> None:
        # Maps client_id -> live, accepted WebSocket.
        self._connections: dict[str, WebSocket] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> None:
        """Complete the WebSocket handshake and track the socket under *client_id*."""
        await websocket.accept()
        self._connections[client_id] = websocket
        logger.info("Client connected: %s (total: %d)", client_id, len(self._connections))

    def disconnect(self, client_id: str) -> None:
        """Remove a client from the registry; unknown ids are silently ignored."""
        self._connections.pop(client_id, None)
        logger.info("Client disconnected: %s (total: %d)", client_id, len(self._connections))

    def get_active_connections(self) -> dict[str, WebSocket]:
        """Return a shallow snapshot so callers can iterate without racing mutation."""
        return self._connections.copy()

    @property
    def active_count(self) -> int:
        """Number of clients currently registered."""
        return len(self._connections)
|
||||||
59
server/llm.py
Normal file
59
server/llm.py
Normal file
|
|
@ -0,0 +1,59 @@
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from .config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Caps the number of in-flight Ollama chat requests across all clients.
_semaphore = asyncio.Semaphore(settings.max_concurrent_llm)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_ollama() -> bool:
    """Return True when the Ollama HTTP API responds to a tag listing."""
    tags_url = f"{settings.ollama_url}/api/tags"
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(tags_url, timeout=5)
        response.raise_for_status()
    except Exception as e:
        logger.error("Ollama not reachable at %s: %s", settings.ollama_url, e)
        return False
    return True
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_response(
    transcript: str,
    system_prompt: str | None = None,
) -> AsyncGenerator[str, None]:
    """Stream text tokens from Ollama for the given user transcript.

    Yields each content fragment as it arrives; stops when Ollama reports
    the generation is done. Concurrency is bounded by the module semaphore.
    """
    body = {
        "model": settings.llm_model,
        "messages": [
            {"role": "system", "content": system_prompt or settings.system_prompt},
            {"role": "user", "content": transcript},
        ],
        "stream": True,
    }
    timeout = httpx.Timeout(120.0, connect=10.0)

    async with _semaphore:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                "POST",
                f"{settings.ollama_url}/api/chat",
                json=body,
            ) as resp:
                resp.raise_for_status()
                # Ollama streams one JSON object per line.
                async for raw_line in resp.aiter_lines():
                    if not raw_line:
                        continue
                    chunk = json.loads(raw_line)
                    piece = chunk.get("message", {}).get("content", "")
                    if piece:
                        yield piece
                    if chunk.get("done"):
                        break
|
||||||
114
server/main.py
Normal file
114
server/main.py
Normal file
|
|
@ -0,0 +1,114 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from . import llm, stt, tts
|
||||||
|
from .config import settings
|
||||||
|
from .connection_manager import ConnectionManager
|
||||||
|
from .pipeline import process_request
|
||||||
|
|
||||||
|
# Root logging configuration for the whole server process.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Single registry of live WebSocket clients, shared by all handlers.
manager = ConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load the heavyweight models before serving.

    STT and TTS models are loaded synchronously so the server only begins
    accepting traffic once it can actually process requests.
    """
    # Startup
    logger.info("Starting Shop Bob server...")
    stt.load_model()
    tts.load_model()
    # Ollama being down is non-fatal: the server still starts, and /health
    # will report the LLM backend as unreachable.
    if not await llm.check_ollama():
        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)
    yield
    # Shutdown
    logger.info("Shutting down Shop Bob server...")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Shop Bob", lifespan=lifespan)

# NOTE(review): CORS is wide open (any origin/method/header) — acceptable on a
# trusted shop LAN, but tighten before exposing this server more broadly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
ollama_ok = await llm.check_ollama()
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"active_connections": manager.active_count,
|
||||||
|
"ollama": "ok" if ollama_ok else "unreachable",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.websocket("/ws")
|
||||||
|
async def websocket_endpoint(websocket: WebSocket):
|
||||||
|
client_id: str | None = None
|
||||||
|
try:
|
||||||
|
# Wait for the first message which should be audio_start
|
||||||
|
raw = await websocket.receive_text()
|
||||||
|
msg = json.loads(raw)
|
||||||
|
|
||||||
|
if msg.get("type") != "audio_start":
|
||||||
|
await websocket.close(code=1008, reason="Expected audio_start message")
|
||||||
|
return
|
||||||
|
|
||||||
|
client_id = msg.get("client_id", "unknown")
|
||||||
|
sample_rate = msg.get("sample_rate", settings.stt_sample_rate)
|
||||||
|
|
||||||
|
await manager.connect(client_id, websocket)
|
||||||
|
|
||||||
|
# Main message loop
|
||||||
|
while True:
|
||||||
|
audio_chunks: list[bytes] = []
|
||||||
|
|
||||||
|
# Collect binary audio frames until audio_end
|
||||||
|
while True:
|
||||||
|
message = await websocket.receive()
|
||||||
|
|
||||||
|
if "text" in message:
|
||||||
|
data = json.loads(message["text"])
|
||||||
|
if data.get("type") == "audio_end":
|
||||||
|
break
|
||||||
|
elif data.get("type") == "audio_start":
|
||||||
|
# New utterance — update sample rate if provided
|
||||||
|
sample_rate = data.get("sample_rate", sample_rate)
|
||||||
|
audio_chunks = []
|
||||||
|
continue
|
||||||
|
elif "bytes" in message:
|
||||||
|
audio_chunks.append(message["bytes"])
|
||||||
|
|
||||||
|
if audio_chunks:
|
||||||
|
audio_bytes = b"".join(audio_chunks)
|
||||||
|
await process_request(audio_bytes, sample_rate, websocket)
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
logger.info("Client %s disconnected", client_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("WebSocket error for client %s", client_id)
|
||||||
|
finally:
|
||||||
|
if client_id:
|
||||||
|
manager.disconnect(client_id)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
"server.main:app",
|
||||||
|
host=settings.host,
|
||||||
|
port=settings.port,
|
||||||
|
log_level="info",
|
||||||
|
)
|
||||||
85
server/pipeline.py
Normal file
85
server/pipeline.py
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
from . import llm, stt, tts
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Splits on the whitespace that follows terminal punctuation (. ! ?), so each
# sentence keeps its own ending punctuation attached.
_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_status(ws: WebSocket, state: str) -> None:
    """Notify the client which pipeline stage is currently running."""
    payload = json.dumps({"type": "status", "state": state})
    await ws.send_text(payload)
|
||||||
|
|
||||||
|
|
||||||
|
async def process_request(
    audio_bytes: bytes,
    sample_rate: int,
    websocket: WebSocket,
) -> None:
    """Run the full speech-in → text-out → speech-out pipeline.

    Message order on success: a "transcript" text frame, binary TTS audio
    (one chunk per completed sentence), a "response_text" frame with the
    whole LLM reply, then "response_end". "status" frames mark each stage.

    Args:
        audio_bytes: Raw PCM audio for one utterance.
        sample_rate: Sample rate of ``audio_bytes`` in Hz.
        websocket: Open client connection to stream results to.
    """
    try:
        # --- STT ---
        await _send_status(websocket, "transcribing")
        transcript = await stt.transcribe(audio_bytes, sample_rate)

        # Nothing intelligible: send an empty transcript and end the turn.
        if not transcript.strip():
            await websocket.send_text(
                json.dumps({"type": "transcript", "text": ""})
            )
            await websocket.send_text(json.dumps({"type": "response_end"}))
            return

        await websocket.send_text(
            json.dumps({"type": "transcript", "text": transcript})
        )

        # --- LLM ---
        await _send_status(websocket, "thinking")
        full_response = ""
        sentence_buffer = ""

        # --- Sentence-level TTS streaming ---
        # NOTE(review): "speaking" is sent before the first token arrives, so
        # clients see "thinking" only momentarily — confirm this is intended.
        await _send_status(websocket, "speaking")

        async for token in llm.generate_response(transcript):
            full_response += token
            sentence_buffer += token

            # Check if we have one or more complete sentences
            parts = _SENTENCE_RE.split(sentence_buffer)
            if len(parts) > 1:
                # All parts except the last are complete sentences; synthesize
                # and ship each one immediately for low perceived latency.
                for sentence in parts[:-1]:
                    sentence = sentence.strip()
                    if sentence:
                        audio_chunk = await tts.synthesize(sentence)
                        await websocket.send_bytes(audio_chunk)
                # Keep the incomplete remainder
                sentence_buffer = parts[-1]

        # Flush any remaining text
        sentence_buffer = sentence_buffer.strip()
        if sentence_buffer:
            audio_chunk = await tts.synthesize(sentence_buffer)
            await websocket.send_bytes(audio_chunk)

        # Send the full text response and signal completion
        await websocket.send_text(
            json.dumps({"type": "response_text", "text": full_response})
        )
        await websocket.send_text(json.dumps({"type": "response_end"}))

    except Exception:
        logger.exception("Pipeline error")
        # Best-effort error report; the client may already be gone.
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "text": "Internal processing error"})
            )
            await websocket.send_text(json.dumps({"type": "response_end"}))
        except Exception:
            pass  # Client already disconnected
|
||||||
8
server/requirements.txt
Normal file
8
server/requirements.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
fastapi>=0.104
|
||||||
|
uvicorn[standard]>=0.24
|
||||||
|
websockets>=12.0
|
||||||
|
faster-whisper>=1.0
|
||||||
|
httpx>=0.25
|
||||||
|
piper-tts>=1.2
|
||||||
|
numpy>=1.24
|
||||||
|
pydantic-settings>=2.0
|
||||||
62
server/stt.py
Normal file
62
server/stt.py
Normal file
|
|
@ -0,0 +1,62 @@
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
from .config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Loaded once at startup by load_model(); None until then.
_model: WhisperModel | None = None
# Transcription is blocking, so it runs on this thread pool; the semaphore
# keeps the number of queued/in-flight jobs at the pool width.
_executor = ThreadPoolExecutor(max_workers=settings.max_concurrent_transcriptions)
_semaphore = asyncio.Semaphore(settings.max_concurrent_transcriptions)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model() -> None:
    """Eagerly construct the Whisper model so the first request pays no load cost."""
    global _model
    model_name = settings.whisper_model
    device = settings.whisper_device
    compute = settings.whisper_compute_type
    logger.info(
        "Loading Whisper model %s on %s (%s)...",
        model_name,
        device,
        compute,
    )
    _model = WhisperModel(model_name, device=device, compute_type=compute)
    logger.info("Whisper model loaded.")
|
||||||
|
|
||||||
|
|
||||||
|
def _transcribe_sync(audio_bytes: bytes, sample_rate: int) -> str:
    """Transcribe raw PCM 16-bit mono audio to text (blocking).

    Args:
        audio_bytes: Little-endian int16 PCM samples, single channel.
        sample_rate: Sample rate of the incoming audio in Hz.

    Returns:
        The transcript as a single space-joined string ("" for empty audio).

    Raises:
        RuntimeError: If load_model() has not been called yet.
    """
    # An assert here would be stripped under `python -O`; raise explicitly.
    if _model is None:
        raise RuntimeError("Whisper model not loaded — call load_model() first")

    # Convert raw PCM 16-bit mono bytes to float32 in [-1.0, 1.0)
    audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0

    # Guard: np.interp below raises on an empty sample array, and there is
    # nothing to transcribe anyway.
    if audio.size == 0:
        return ""

    if sample_rate != 16000:
        # faster-whisper expects 16kHz — resample via simple linear interpolation
        duration = len(audio) / sample_rate
        target_len = int(duration * 16000)
        audio = np.interp(
            np.linspace(0, len(audio) - 1, target_len),
            np.arange(len(audio)),
            audio,
        ).astype(np.float32)

    segments, info = _model.transcribe(audio, beam_size=5)
    text = " ".join(seg.text.strip() for seg in segments)
    logger.info("Transcribed %.1fs audio → %d chars", info.duration, len(text))
    return text
|
||||||
|
|
||||||
|
|
||||||
|
async def transcribe(audio_bytes: bytes, sample_rate: int) -> str:
    """Transcribe one utterance without blocking the event loop.

    Concurrency is bounded by the module semaphore; the blocking work runs
    on the dedicated thread pool.
    """
    async with _semaphore:
        job = partial(_transcribe_sync, audio_bytes, sample_rate)
        return await asyncio.get_running_loop().run_in_executor(_executor, job)
|
||||||
48
server/tts.py
Normal file
48
server/tts.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
import asyncio
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import wave
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from piper.voice import PiperVoice
|
||||||
|
|
||||||
|
from .config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Loaded once at startup by load_model(); None until then.
_voice: PiperVoice | None = None
# Piper synthesis is blocking — run it on a small dedicated thread pool.
_executor = ThreadPoolExecutor(max_workers=2)
|
||||||
|
|
||||||
|
|
||||||
|
def load_model() -> None:
    """Load the configured Piper voice up front so synthesis is ready at startup."""
    global _voice
    voice_name = settings.piper_model
    logger.info("Loading Piper TTS voice %s...", voice_name)
    _voice = PiperVoice.load(voice_name)
    logger.info("Piper TTS loaded.")
|
||||||
|
|
||||||
|
|
||||||
|
def _synthesize_sync(text: str) -> bytes:
    """Synthesize text to raw PCM 16-bit mono audio bytes (blocking).

    Piper writes a WAV container; we round-trip through an in-memory buffer
    and strip the header so the client receives bare PCM frames.

    Raises:
        RuntimeError: If load_model() has not been called yet.
    """
    # An assert here would be stripped under `python -O`; raise explicitly.
    if _voice is None:
        raise RuntimeError("Piper voice not loaded — call load_model() first")

    buf = io.BytesIO()
    with wave.open(buf, "wb") as wf:
        _voice.synthesize(text, wf)

    # Extract raw PCM from the WAV container
    buf.seek(0)
    with wave.open(buf, "rb") as wf:
        pcm_data = wf.readframes(wf.getnframes())

    logger.debug("Synthesized %d chars → %d bytes PCM", len(text), len(pcm_data))
    return pcm_data
|
||||||
|
|
||||||
|
|
||||||
|
async def synthesize(text: str) -> bytes:
    """Async wrapper — runs Piper synthesis on the module thread pool."""
    job = partial(_synthesize_sync, text)
    return await asyncio.get_running_loop().run_in_executor(_executor, job)
|
||||||
Loading…
Reference in a new issue