Add server component: FastAPI + WebSocket speech pipeline
Voice-in/voice-out server for the Shop Bob machine shop assistant. STT (faster-whisper), LLM (Ollama), TTS (Piper) with sentence-level audio streaming over WebSocket for low-latency responses. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
commit
98310bf062
9 changed files with 441 additions and 0 deletions
0
server/__init__.py
Normal file
0
server/__init__.py
Normal file
39
server/config.py
Normal file
39
server/config.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Server configuration.

    Every field can be overridden with a BOB_-prefixed environment variable
    (e.g. BOB_PORT=9000) or an entry in a local `.env` file.
    """

    model_config = {"env_prefix": "BOB_", "env_file": ".env", "env_file_encoding": "utf-8"}

    # Networking
    host: str = "0.0.0.0"  # bind on all interfaces by default
    port: int = 8765

    # Whisper STT
    whisper_model: str = "large-v3"
    whisper_device: str = "cuda"
    whisper_compute_type: str = "float16"
    # Rate assumed for client audio when the client doesn't announce one.
    stt_sample_rate: int = 16000
    # Cap on simultaneous blocking transcriptions (also sizes the STT thread pool).
    max_concurrent_transcriptions: int = 2

    # Ollama LLM
    ollama_url: str = "http://localhost:11434"
    llm_model: str = "llama3.1:8b"
    # Cap on simultaneous in-flight Ollama chat requests.
    max_concurrent_llm: int = 3

    # Piper TTS
    piper_model: str = "en_US-lessac-medium"
    tts_sample_rate: int = 22050

    # System prompt for the machine shop assistant
    system_prompt: str = (
        "You are Bob, a knowledgeable machine shop assistant. "
        "Give concise, direct answers about machining, tooling, materials, "
        "feeds and speeds, and shop processes. "
        "Always prioritize safety — if a question involves a potentially "
        "dangerous operation, lead with the safety considerations. "
        "Keep answers short and practical — shop floor workers need quick info, "
        "not essays."
    )


# Module-level singleton imported by the rest of the server.
settings = Settings()
|
||||
26
server/connection_manager.py
Normal file
26
server/connection_manager.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import logging
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConnectionManager:
    """Registry of live WebSocket clients, keyed by client id."""

    def __init__(self) -> None:
        self._clients: dict[str, WebSocket] = {}

    async def connect(self, client_id: str, websocket: WebSocket) -> None:
        """Complete the WebSocket handshake and register the client."""
        await websocket.accept()
        self._clients[client_id] = websocket
        logger.info("Client connected: %s (total: %d)", client_id, len(self._clients))

    def disconnect(self, client_id: str) -> None:
        """Remove a client; harmless if the id is unknown."""
        self._clients.pop(client_id, None)
        logger.info("Client disconnected: %s (total: %d)", client_id, len(self._clients))

    def get_active_connections(self) -> dict[str, WebSocket]:
        """Return a snapshot copy so callers cannot mutate internal state."""
        return dict(self._clients)

    @property
    def active_count(self) -> int:
        """Number of currently-connected clients."""
        return len(self._clients)
|
||||
59
server/llm.py
Normal file
59
server/llm.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from .config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Global cap on simultaneous streaming chat requests to Ollama.
_semaphore = asyncio.Semaphore(settings.max_concurrent_llm)
|
||||
|
||||
|
||||
async def check_ollama() -> bool:
    """Verify Ollama is reachable.

    Probes the /api/tags endpoint with a short timeout; any failure
    (network error, non-2xx status) is logged and reported as False.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{settings.ollama_url}/api/tags", timeout=5)
            response.raise_for_status()
    except Exception as exc:
        logger.error("Ollama not reachable at %s: %s", settings.ollama_url, exc)
        return False
    return True
|
||||
|
||||
|
||||
async def generate_response(
    transcript: str,
    system_prompt: str | None = None,
) -> AsyncGenerator[str, None]:
    """Stream text tokens from Ollama for the given user transcript.

    Yields content tokens as they arrive from the /api/chat NDJSON stream.
    Falls back to the configured system prompt when none is supplied.
    """
    messages = [
        {"role": "system", "content": system_prompt or settings.system_prompt},
        {"role": "user", "content": transcript},
    ]
    request_body = {
        "model": settings.llm_model,
        "messages": messages,
        "stream": True,
    }
    timeout = httpx.Timeout(120.0, connect=10.0)

    # The semaphore bounds concurrent LLM requests server-wide.
    async with _semaphore:
        async with httpx.AsyncClient(timeout=timeout) as client:
            async with client.stream(
                "POST",
                f"{settings.ollama_url}/api/chat",
                json=request_body,
            ) as response:
                response.raise_for_status()
                # Ollama streams one JSON object per line.
                async for raw_line in response.aiter_lines():
                    if not raw_line:
                        continue
                    chunk = json.loads(raw_line)
                    piece = chunk.get("message", {}).get("content", "")
                    if piece:
                        yield piece
                    if chunk.get("done"):
                        break
|
||||
114
server/main.py
Normal file
114
server/main.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from . import llm, stt, tts
|
||||
from .config import settings
|
||||
from .connection_manager import ConnectionManager
|
||||
from .pipeline import process_request
|
||||
|
||||
# Process-wide logging configuration for the server entry point.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Single shared registry of connected WebSocket clients.
manager = ConnectionManager()
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load models on startup, log on shutdown.

    STT and TTS models are loaded eagerly so the first request doesn't pay
    the load cost; an unreachable Ollama is only a warning since it may
    come up later.
    """
    logger.info("Starting Shop Bob server...")
    stt.load_model()
    tts.load_model()
    ollama_up = await llm.check_ollama()
    if not ollama_up:
        logger.warning("Ollama is not reachable — LLM calls will fail until it's up")
    logger.info("Shop Bob server ready on %s:%d", settings.host, settings.port)

    yield

    logger.info("Shutting down Shop Bob server...")
|
||||
|
||||
|
||||
app = FastAPI(title="Shop Bob", lifespan=lifespan)

# NOTE(review): wide-open CORS ("*" origins/methods/headers) — presumably the
# clients are shop-floor devices on a trusted LAN; confirm this server is
# never exposed publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
ollama_ok = await llm.check_ollama()
|
||||
return {
|
||||
"status": "ok",
|
||||
"active_connections": manager.active_count,
|
||||
"ollama": "ok" if ollama_ok else "unreachable",
|
||||
}
|
||||
|
||||
|
||||
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    """Voice session endpoint.

    Protocol (client → server):
      1. JSON text message {"type": "audio_start", "client_id": ..., "sample_rate": ...}
      2. Binary PCM audio frames
      3. JSON text message {"type": "audio_end"} — triggers the pipeline
      Steps 1–3 repeat for each utterance on the same connection.

    Responses (transcript, status, audio, response_end) are sent by the
    pipeline in process_request().
    """
    client_id: str | None = None
    try:
        # Wait for the first message which should be audio_start
        raw = await websocket.receive_text()
        msg = json.loads(raw)

        if msg.get("type") != "audio_start":
            await websocket.close(code=1008, reason="Expected audio_start message")
            return

        client_id = msg.get("client_id", "unknown")
        sample_rate = msg.get("sample_rate", settings.stt_sample_rate)

        await manager.connect(client_id, websocket)

        # Main message loop — one iteration per utterance.
        while True:
            audio_chunks: list[bytes] = []

            # Collect binary audio frames until audio_end
            while True:
                message = await websocket.receive()

                # Bug fix: raw receive() does NOT raise on disconnect — it
                # returns a {"type": "websocket.disconnect"} message that has
                # neither "text" nor "bytes". Without this check the loop
                # called receive() again after the socket was gone, raising a
                # RuntimeError instead of taking the clean disconnect path.
                if message.get("type") == "websocket.disconnect":
                    raise WebSocketDisconnect(code=message.get("code", 1000))

                if "text" in message:
                    data = json.loads(message["text"])
                    if data.get("type") == "audio_end":
                        break
                    elif data.get("type") == "audio_start":
                        # New utterance — update sample rate if provided
                        sample_rate = data.get("sample_rate", sample_rate)
                        audio_chunks = []
                        continue
                elif "bytes" in message:
                    audio_chunks.append(message["bytes"])

            if audio_chunks:
                audio_bytes = b"".join(audio_chunks)
                await process_request(audio_bytes, sample_rate, websocket)

    except WebSocketDisconnect:
        logger.info("Client %s disconnected", client_id)
    except Exception:
        logger.exception("WebSocket error for client %s", client_id)
    finally:
        if client_id:
            manager.disconnect(client_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Direct-run entry point for local development (python -m server.main).
    import uvicorn

    uvicorn.run("server.main:app", host=settings.host, port=settings.port, log_level="info")
|
||||
85
server/pipeline.py
Normal file
85
server/pipeline.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
from . import llm, stt, tts
|
||||
|
||||
logger = logging.getLogger(__name__)

# Regex to split text on sentence boundaries while keeping the delimiters:
# it matches the whitespace AFTER '.', '!' or '?', so the punctuation stays
# attached to the preceding sentence when split() is applied.
_SENTENCE_RE = re.compile(r"(?<=[.!?])\s+")
|
||||
|
||||
|
||||
async def _send_status(ws: WebSocket, state: str) -> None:
    """Notify the client of a pipeline state change (e.g. "transcribing")."""
    payload = {"type": "status", "state": state}
    await ws.send_text(json.dumps(payload))
|
||||
|
||||
|
||||
async def process_request(
    audio_bytes: bytes,
    sample_rate: int,
    websocket: WebSocket,
) -> None:
    """Run the full speech-in → text-out → speech-out pipeline.

    Message sequence sent to the client: status updates, the transcript,
    binary audio chunks (one per synthesized sentence), the full response
    text, and finally a response_end marker. On any failure an error
    message plus response_end is attempted instead.
    """
    try:
        # --- STT ---
        await _send_status(websocket, "transcribing")
        transcript = await stt.transcribe(audio_bytes, sample_rate)

        if not transcript.strip():
            # Nothing intelligible: report an empty transcript and finish.
            await websocket.send_text(json.dumps({"type": "transcript", "text": ""}))
            await websocket.send_text(json.dumps({"type": "response_end"}))
            return

        await websocket.send_text(json.dumps({"type": "transcript", "text": transcript}))

        # --- LLM with sentence-level TTS streaming ---
        await _send_status(websocket, "thinking")
        full_response = ""
        pending = ""
        await _send_status(websocket, "speaking")

        async for token in llm.generate_response(transcript):
            full_response += token
            pending += token

            pieces = _SENTENCE_RE.split(pending)
            if len(pieces) < 2:
                continue
            # Everything except the trailing fragment is a finished sentence:
            # synthesize and ship each one immediately for low latency.
            for candidate in pieces[:-1]:
                candidate = candidate.strip()
                if candidate:
                    await websocket.send_bytes(await tts.synthesize(candidate))
            # Carry the incomplete remainder into the next iteration.
            pending = pieces[-1]

        # Flush whatever incomplete sentence is left after the stream ends.
        leftover = pending.strip()
        if leftover:
            await websocket.send_bytes(await tts.synthesize(leftover))

        # Deliver the complete text and signal the end of the response.
        await websocket.send_text(json.dumps({"type": "response_text", "text": full_response}))
        await websocket.send_text(json.dumps({"type": "response_end"}))

    except Exception:
        logger.exception("Pipeline error")
        try:
            await websocket.send_text(
                json.dumps({"type": "error", "text": "Internal processing error"})
            )
            await websocket.send_text(json.dumps({"type": "response_end"}))
        except Exception:
            pass  # Client already disconnected
|
||||
8
server/requirements.txt
Normal file
8
server/requirements.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
fastapi>=0.104
|
||||
uvicorn[standard]>=0.24
|
||||
websockets>=12.0
|
||||
faster-whisper>=1.0
|
||||
httpx>=0.25
|
||||
piper-tts>=1.2
|
||||
numpy>=1.24
|
||||
pydantic-settings>=2.0
|
||||
62
server/stt.py
Normal file
62
server/stt.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
import asyncio
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from .config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Populated by load_model() at startup; None until then.
_model: WhisperModel | None = None
# Thread pool and semaphore are sized together so at most
# max_concurrent_transcriptions blocking transcriptions run at once.
_executor = ThreadPoolExecutor(max_workers=settings.max_concurrent_transcriptions)
_semaphore = asyncio.Semaphore(settings.max_concurrent_transcriptions)
|
||||
|
||||
|
||||
def load_model() -> None:
    """Load the faster-whisper model into the module-level slot.

    Must be called once at startup before transcribe() is used.
    """
    global _model
    model_name = settings.whisper_model
    device = settings.whisper_device
    compute = settings.whisper_compute_type

    logger.info("Loading Whisper model %s on %s (%s)...", model_name, device, compute)
    _model = WhisperModel(model_name, device=device, compute_type=compute)
    logger.info("Whisper model loaded.")
|
||||
|
||||
|
||||
def _transcribe_sync(audio_bytes: bytes, sample_rate: int) -> str:
|
||||
assert _model is not None, "Whisper model not loaded — call load_model() first"
|
||||
|
||||
# Convert raw PCM 16-bit mono bytes to float32 numpy array
|
||||
audio = np.frombuffer(audio_bytes, dtype=np.int16).astype(np.float32) / 32768.0
|
||||
|
||||
if sample_rate != 16000:
|
||||
# faster-whisper expects 16kHz — resample via simple linear interpolation
|
||||
duration = len(audio) / sample_rate
|
||||
target_len = int(duration * 16000)
|
||||
audio = np.interp(
|
||||
np.linspace(0, len(audio) - 1, target_len),
|
||||
np.arange(len(audio)),
|
||||
audio,
|
||||
).astype(np.float32)
|
||||
|
||||
segments, info = _model.transcribe(audio, beam_size=5)
|
||||
text = " ".join(seg.text.strip() for seg in segments)
|
||||
logger.info("Transcribed %.1fs audio → %d chars", info.duration, len(text))
|
||||
return text
|
||||
|
||||
|
||||
async def transcribe(audio_bytes: bytes, sample_rate: int) -> str:
    """Transcribe audio off the event loop, bounded by the STT semaphore."""
    async with _semaphore:
        job = partial(_transcribe_sync, audio_bytes, sample_rate)
        return await asyncio.get_running_loop().run_in_executor(_executor, job)
|
||||
48
server/tts.py
Normal file
48
server/tts.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
import asyncio
|
||||
import io
|
||||
import logging
|
||||
import wave
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
from piper.voice import PiperVoice
|
||||
|
||||
from .config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)

# Populated by load_model() at startup; None until then.
_voice: PiperVoice | None = None
# Small pool for blocking Piper synthesis calls.
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
def load_model() -> None:
    """Load the Piper voice into the module-level slot (call once at startup)."""
    global _voice
    voice_name = settings.piper_model
    logger.info("Loading Piper TTS voice %s...", voice_name)
    _voice = PiperVoice.load(voice_name)
    logger.info("Piper TTS loaded.")
|
||||
|
||||
|
||||
def _synthesize_sync(text: str) -> bytes:
    """Synthesize text to raw PCM 16-bit mono audio bytes (blocking)."""
    assert _voice is not None, "Piper voice not loaded — call load_model() first"

    # Piper writes a full WAV stream; capture it in memory...
    wav_buffer = io.BytesIO()
    with wave.open(wav_buffer, "wb") as writer:
        _voice.synthesize(text, writer)

    # ...then strip the container and keep only the PCM frames.
    wav_buffer.seek(0)
    with wave.open(wav_buffer, "rb") as reader:
        pcm_data = reader.readframes(reader.getnframes())

    logger.debug("Synthesized %d chars → %d bytes PCM", len(text), len(pcm_data))
    return pcm_data
|
||||
|
||||
|
||||
async def synthesize(text: str) -> bytes:
    """Run Piper synthesis in the worker pool without blocking the event loop."""
    job = partial(_synthesize_sync, text)
    return await asyncio.get_running_loop().run_in_executor(_executor, job)
|
||||
Loading…
Reference in a new issue