xml-pipeline/xml_pipeline/server/websocket.py
dullfig bf31b0d14e Add AgentServer REST/WebSocket API
Implements the AgentServer API from docs/agentserver_api_spec.md:

REST API (/api/v1):
- Organism info and config endpoints
- Agent listing, details, config, schema
- Thread and message history with filtering
- Control endpoints (inject, pause, resume, kill, stop)

WebSocket:
- /ws: Main control channel with state snapshot + real-time events
- /ws/messages: Dedicated message stream with filtering

Infrastructure:
- Pydantic models with camelCase serialization
- ServerState bridges StreamPump to API
- Pump event hooks for real-time updates
- CLI 'serve' command: xml-pipeline serve [config] --port 8080

35 new tests for models, state, REST, and WebSocket.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:22:58 -08:00

316 lines
11 KiB
Python

"""
websocket.py — WebSocket endpoints for AgentServer.
Provides:
- /ws — Main control channel with state snapshot and real-time events
- /ws/messages — Dedicated message log stream with filtering
"""
from __future__ import annotations
import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from xml_pipeline.server.models import (
SubscribeRequest,
WSConnectedEvent,
)
if TYPE_CHECKING:
from xml_pipeline.server.state import ServerState
class ConnectionManager:
    """Manage /ws control-channel connections and their subscription filters.

    Every accepted connection starts with a default ``SubscribeRequest()``
    and receives events according to that subscription until the client sends
    a ``subscribe`` command that replaces it.
    """

    def __init__(self) -> None:
        # Connections that currently receive broadcasts.
        self.active_connections: Set[WebSocket] = set()
        # Per-connection filters consulted by should_send().
        self.subscriptions: Dict[WebSocket, SubscribeRequest] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept *websocket* and register it with a default subscription."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.subscriptions[websocket] = SubscribeRequest()  # Default: all events

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget *websocket*; calling this more than once is harmless."""
        self.active_connections.discard(websocket)
        self.subscriptions.pop(websocket, None)

    def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None:
        """Replace the subscription filters for *websocket*."""
        self.subscriptions[websocket] = sub

    def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool:
        """Return True if *event* passes the subscription filters of *websocket*.

        Filters are only applied when both the filter list is non-empty and the
        event carries the corresponding field; an absent field never excludes
        an event. A connection with no recorded subscription gets everything.
        """
        sub = self.subscriptions.get(websocket)
        if sub is None:
            return True  # No filters = send all

        event_type = event.get("event", "")
        if sub.events and event_type not in sub.events:
            return False

        thread_id = event.get("thread_id") or event.get("threadId")
        if sub.threads and thread_id and thread_id not in sub.threads:
            return False

        agent = event.get("agent")
        if sub.agents and agent and agent not in sub.agents:
            return False

        # The remaining filters only inspect message events.
        if event_type != "message":
            return True

        # `or {}` also covers an explicit "message": None.
        message = event.get("message") or {}

        if sub.agents:
            from_id = message.get("from") or message.get("fromId")
            to_id = message.get("to") or message.get("toId")
            if from_id not in sub.agents and to_id not in sub.agents:
                return False

        if sub.payload_types:
            # Accept both spellings, matching MessageStreamManager.should_send.
            payload_type = message.get("payloadType") or message.get("payload_type")
            if payload_type and payload_type not in sub.payload_types:
                return False

        return True

    async def broadcast(self, event: Dict[str, Any]) -> None:
        """Send *event* to every connection whose subscription matches it.

        Connections whose send fails are dropped after the broadcast.
        """
        failed: List[WebSocket] = []
        # Iterate over a snapshot: each send_json() suspends this coroutine,
        # and another task (e.g. a handler's finally -> disconnect()) may
        # mutate the set meanwhile, which would otherwise raise
        # "RuntimeError: Set changed size during iteration".
        for websocket in list(self.active_connections):
            if websocket not in self.active_connections:
                continue  # Disconnected while an earlier send was in flight.
            if not self.should_send(websocket, event):
                continue
            try:
                await websocket.send_json(event)
            except Exception:
                # Best effort: a dead client must not stop the broadcast.
                failed.append(websocket)
        for websocket in failed:
            self.disconnect(websocket)
class MessageStreamManager:
    """Manage /ws/messages connections and their per-connection message filters.

    A filter is a dict with optional ``"agents"``, ``"threads"`` and
    ``"payload_types"`` collections. An empty filter lets every message through.
    """

    # Filter keys whose values are tested with ``in`` and must be collections.
    _LIST_FILTER_KEYS = ("agents", "threads", "payload_types")

    def __init__(self) -> None:
        # Connections that currently receive streamed messages.
        self.active_connections: Set[WebSocket] = set()
        # Per-connection filter dicts consulted by should_send().
        self.filters: Dict[WebSocket, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept *websocket* and register it with an empty (pass-all) filter."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.filters[websocket] = {}  # Default: all messages

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget *websocket*; calling this more than once is harmless."""
        self.active_connections.discard(websocket)
        self.filters.pop(websocket, None)

    def set_filter(self, websocket: WebSocket, filter_config: Dict[str, Any]) -> None:
        """Replace the message filter for *websocket*.

        Raises:
            TypeError: if *filter_config* is not a dict, or if one of
                ``agents`` / ``threads`` / ``payload_types`` holds something
                other than a list, tuple, set or None. The filter usually comes
                straight from a client. Rejecting it here keeps it out of
                broadcast_message(). There, an exception from should_send()
                would abort delivery to every other connection.
        """
        if not isinstance(filter_config, dict):
            raise TypeError(
                f"filter must be an object, got {type(filter_config).__name__}"
            )
        for key in self._LIST_FILTER_KEYS:
            value = filter_config.get(key)
            if value is not None and not isinstance(value, (list, tuple, set, frozenset)):
                raise TypeError(
                    f"filter {key!r} must be a list, got {type(value).__name__}"
                )
        self.filters[websocket] = filter_config

    def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
        """Return True if *message* passes the filter registered for *websocket*.

        Both snake_case and camelCase message keys are recognised. Every
        non-empty filter list must match for the message to be sent.
        """
        flt = self.filters.get(websocket) or {}
        if not flt:
            return True

        agents = flt.get("agents") or []
        if agents:
            from_id = message.get("from") or message.get("fromId")
            to_id = message.get("to") or message.get("toId")
            if from_id not in agents and to_id not in agents:
                return False

        threads = flt.get("threads") or []
        if threads:
            thread_id = message.get("thread_id") or message.get("threadId")
            if thread_id not in threads:
                return False

        payload_types = flt.get("payload_types") or []
        if payload_types:
            payload_type = message.get("payloadType") or message.get("payload_type")
            if payload_type not in payload_types:
                return False

        return True

    async def broadcast_message(self, message: Dict[str, Any]) -> None:
        """Send *message* to every connection whose filter accepts it.

        Connections whose send fails are dropped after the broadcast.
        """
        failed: List[WebSocket] = []
        # Snapshot the set: send_json() suspends, and a concurrent
        # disconnect() would otherwise break iteration of the live set.
        for websocket in list(self.active_connections):
            if websocket not in self.active_connections:
                continue  # Disconnected while an earlier send was in flight.
            if not self.should_send(websocket, message):
                continue
            try:
                await websocket.send_json(message)
            except Exception:
                # Best effort: a dead client must not stop the broadcast.
                failed.append(websocket)
        for websocket in failed:
            self.disconnect(websocket)
def create_websocket_router(state: "ServerState") -> APIRouter:
    """Build the router that serves the ``/ws`` and ``/ws/messages`` endpoints.

    The router owns one ConnectionManager (control channel) and one
    MessageStreamManager (message log). It registers a listener through
    ``state.subscribe`` that forwards every state event to matching ``/ws``
    clients. For events whose ``"event"`` is ``"message"``, it also forwards
    the nested ``"message"`` dict to ``/ws/messages`` clients.

    Args:
        state: Server state. It supplies the snapshot sent on connect
            (``get_organism_info``, ``get_agents``, ``get_threads``), agent
            lookup and message recording for ``inject``, and ``subscribe``
            for event delivery.

    Returns:
        An APIRouter with both WebSocket routes registered.
    """
    import uuid  # Only needed to mint thread ids for ``inject``.

    router = APIRouter()
    manager = ConnectionManager()
    message_manager = MessageStreamManager()

    async def on_state_event(event: Dict[str, Any]) -> None:
        """Forward a state event to control-channel and message-stream clients."""
        await manager.broadcast(event)
        if event.get("event") == "message":
            await message_manager.broadcast_message(event.get("message", {}))

    state.subscribe(on_state_event)

    async def send_error(websocket: WebSocket, error: str) -> None:
        """Reply with the error envelope shared by both endpoints."""
        await websocket.send_json({"event": "error", "error": error})

    async def handle_subscribe(websocket: WebSocket, data: Dict[str, Any]) -> None:
        """Replace the caller's /ws subscription filters and acknowledge."""
        try:
            sub = SubscribeRequest(
                threads=data.get("threads"),
                agents=data.get("agents"),
                payload_types=data.get("payload_types"),
                events=data.get("events"),
            )
        except ValueError as exc:
            # pydantic's ValidationError subclasses ValueError. A malformed
            # filter should produce an error reply, not kill the connection.
            await send_error(websocket, f"Invalid subscription: {exc}")
            return
        manager.set_subscription(websocket, sub)
        await websocket.send_json({"event": "subscribed", "filters": data})

    async def handle_inject(websocket: WebSocket, data: Dict[str, Any]) -> None:
        """Record a message from ``"api"`` to an agent (same as REST /inject)."""
        target = data.get("to")
        payload = data.get("payload", {})
        if not target:
            await send_error(websocket, "Missing 'to' field")
            return
        if not isinstance(target, str):
            await send_error(websocket, "'to' must be a string")
            return
        if state.get_agent(target) is None:
            await send_error(websocket, f"Unknown agent: {target}")
            return
        if not isinstance(payload, dict):
            await send_error(websocket, "'payload' must be a JSON object")
            return

        thread_id = data.get("thread_id") or str(uuid.uuid4())
        # The first top-level payload key names the payload type.
        payload_type = next(iter(payload), "Payload")
        msg_id = await state.record_message(
            thread_id=thread_id,
            from_id="api",
            to_id=target,
            payload_type=payload_type,
            payload=payload,
        )
        await websocket.send_json(
            {
                "event": "injected",
                "thread_id": thread_id,
                "message_id": msg_id,
            }
        )

    @router.websocket("/ws")
    async def websocket_control(websocket: WebSocket) -> None:
        """
        Main control channel WebSocket.

        On connect, sends a state snapshot (organism, agents, up to 100
        threads), then pushes events as they occur.
        Accepts commands: subscribe, inject.
        """
        await manager.connect(websocket)
        try:
            threads, _ = state.get_threads(limit=100)
            connected_event = WSConnectedEvent(
                organism=state.get_organism_info(),
                agents=state.get_agents(),
                threads=threads,
            )
            # NOTE(review): model_dump() without mode="json" returns Python
            # objects; if any snapshot field is e.g. a datetime, send_json's
            # json.dumps will fail. Confirm the fields are JSON-native.
            await websocket.send_json(connected_event.model_dump(by_alias=True))

            while True:
                try:
                    data = await websocket.receive_json()
                except Exception:
                    # Connection closed or undecodable frame: end the session.
                    break

                if not isinstance(data, dict):
                    await send_error(websocket, "Command must be a JSON object")
                    continue

                cmd = data.get("cmd", "")
                if cmd == "subscribe":
                    await handle_subscribe(websocket, data)
                elif cmd == "inject":
                    await handle_inject(websocket, data)
                else:
                    await send_error(websocket, f"Unknown command: {cmd}")
        except WebSocketDisconnect:
            pass
        finally:
            manager.disconnect(websocket)

    @router.websocket("/ws/messages")
    async def websocket_messages(websocket: WebSocket) -> None:
        """
        Dedicated message log stream.

        Streams all messages flowing through the organism.
        Clients can filter by agents, threads, and payload types.
        """
        await message_manager.connect(websocket)
        try:
            while True:
                try:
                    data = await websocket.receive_json()
                except Exception:
                    # Connection closed or undecodable frame: end the session.
                    break

                if not isinstance(data, dict):
                    await send_error(websocket, "Command must be a JSON object")
                    continue

                cmd = data.get("cmd", "")
                if cmd == "subscribe":
                    # Falsy values (missing, null, {}) mean "no filter", which
                    # is how an empty filter already behaves in should_send().
                    filter_config = data.get("filter") or {}
                    if not isinstance(filter_config, dict):
                        await send_error(websocket, "'filter' must be a JSON object")
                        continue
                    try:
                        message_manager.set_filter(websocket, filter_config)
                    except TypeError as exc:
                        await send_error(websocket, f"Invalid filter: {exc}")
                        continue
                    await websocket.send_json(
                        {"event": "subscribed", "filter": filter_config}
                    )
                else:
                    await send_error(websocket, f"Unknown command: {cmd}")
        except WebSocketDisconnect:
            pass
        finally:
            message_manager.disconnect(websocket)

    return router