Implements the AgentServer API from docs/agentserver_api_spec.md.

REST API (/api/v1):
- Organism info and config endpoints
- Agent listing, details, config, and schema
- Thread and message history with filtering
- Control endpoints (inject, pause, resume, kill, stop)

WebSocket:
- /ws: main control channel with a state snapshot and real-time events
- /ws/messages: dedicated message stream with filtering

Infrastructure:
- Pydantic models with camelCase serialization
- ServerState bridges StreamPump to the API
- Pump event hooks for real-time updates
- CLI 'serve' command: xml-pipeline serve [config] --port 8080

35 new tests for models, state, REST, and WebSocket.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
316 lines
11 KiB
Python
316 lines
11 KiB
Python
"""
|
|
websocket.py — WebSocket endpoints for AgentServer.
|
|
|
|
Provides:
|
|
- /ws — Main control channel with state snapshot and real-time events
|
|
- /ws/messages — Dedicated message log stream with filtering
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
|
|
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
|
|
from xml_pipeline.server.models import (
|
|
SubscribeRequest,
|
|
WSConnectedEvent,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from xml_pipeline.server.state import ServerState
|
|
|
|
|
|
class ConnectionManager:
    """Track /ws control-channel connections and their subscription filters.

    Every registered connection has a ``SubscribeRequest``; its non-empty
    ``events``, ``threads``, ``agents`` and ``payload_types`` fields narrow
    which broadcast events the connection receives (see ``should_send``).
    """

    def __init__(self) -> None:
        self.active_connections: Set[WebSocket] = set()
        self.subscriptions: Dict[WebSocket, SubscribeRequest] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept ``websocket`` and register it with an unfiltered subscription."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.subscriptions[websocket] = SubscribeRequest()  # Default: all events

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget ``websocket``; a no-op if it is not registered."""
        self.active_connections.discard(websocket)
        self.subscriptions.pop(websocket, None)

    def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None:
        """Replace the subscription filters for ``websocket``."""
        self.subscriptions[websocket] = sub

    def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool:
        """Return True if ``event`` passes ``websocket``'s subscription filters.

        A connection without a recorded subscription receives everything.
        Otherwise every non-empty filter must pass:

        - ``events``: the event's ``"event"`` type must be listed.
        - ``threads``: the event's thread id must be listed. The id is read
          from the top-level ``thread_id``/``threadId`` or, for ``message``
          events, from the nested message. Events without a thread id are
          not thread-filtered.
        - ``agents``: a top-level ``agent`` field, when present, must be
          listed; for ``message`` events the sender or recipient must be
          listed.
        - ``payload_types``: applies only to ``message`` events that carry a
          payload type (``payloadType`` or ``payload_type``).
        """
        sub = self.subscriptions.get(websocket)
        if sub is None:
            return True  # No filters = send all

        event_type = event.get("event", "")
        is_message = event_type == "message"
        # Tolerate an explicit ``"message": None`` as well as a missing key.
        message = (event.get("message") or {}) if is_message else {}

        if sub.events and event_type not in sub.events:
            return False

        # Message events may carry their thread id only inside the nested
        # message dict (the same place /ws/messages filtering reads it from),
        # so fall back to it when there is no top-level id.
        thread_id = (
            event.get("thread_id")
            or event.get("threadId")
            or message.get("thread_id")
            or message.get("threadId")
        )
        if sub.threads and thread_id and thread_id not in sub.threads:
            return False

        agent = event.get("agent")
        if sub.agents and agent and agent not in sub.agents:
            return False

        if is_message and sub.agents:
            from_id = message.get("from") or message.get("fromId")
            to_id = message.get("to") or message.get("toId")
            if from_id not in sub.agents and to_id not in sub.agents:
                return False

        if is_message and sub.payload_types:
            payload_type = message.get("payloadType") or message.get("payload_type")
            if payload_type and payload_type not in sub.payload_types:
                return False

        return True

    async def broadcast(self, event: Dict[str, Any]) -> None:
        """Send ``event`` to every connection whose subscription accepts it.

        Delivery is best-effort: a connection whose send raises is dropped
        from the manager after the broadcast completes.
        """
        disconnected: List[WebSocket] = []

        # Iterate over a snapshot: each send awaits, and other coroutines may
        # connect/disconnect meanwhile, which would otherwise raise
        # "Set changed size during iteration".
        for websocket in list(self.active_connections):
            if websocket not in self.active_connections:
                continue  # disconnected while an earlier send was in flight
            if not self.should_send(websocket, event):
                continue
            try:
                await websocket.send_json(event)
            except Exception:
                # Any send failure is treated as a dead connection.
                disconnected.append(websocket)

        # Clean up disconnected clients
        for ws in disconnected:
            self.disconnect(ws)
|
|
|
|
|
|
class MessageStreamManager:
    """Track /ws/messages connections and their per-connection message filters.

    A filter is a plain dict with optional ``agents``, ``threads`` and
    ``payload_types`` collections; an empty filter lets every message through.
    """

    def __init__(self) -> None:
        self.active_connections: Set[WebSocket] = set()
        self.filters: Dict[WebSocket, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept ``websocket`` and register it with an empty (pass-all) filter."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.filters[websocket] = {}  # Default: all messages

    def disconnect(self, websocket: WebSocket) -> None:
        """Forget ``websocket``; a no-op if it is not registered."""
        self.active_connections.discard(websocket)
        self.filters.pop(websocket, None)

    def set_filter(
        self, websocket: WebSocket, filter_config: Optional[Dict[str, Any]]
    ) -> None:
        """Replace the message filter for ``websocket``.

        ``None`` is treated as an empty (pass-all) filter.

        Raises:
            TypeError: if ``filter_config`` is not a dict, or if one of its
                ``agents``/``threads``/``payload_types`` entries is present
                but is not a list, tuple or set. Filters are evaluated for
                every client inside the shared broadcast loop, so malformed
                ones are rejected here. Otherwise they would raise there, or
                (for strings) silently match substrings.
        """
        if filter_config is None:
            filter_config = {}
        if not isinstance(filter_config, dict):
            raise TypeError(
                f"message filter must be a dict, got {type(filter_config).__name__}"
            )
        for key in ("agents", "threads", "payload_types"):
            value = filter_config.get(key)
            if value is not None and not isinstance(value, (list, tuple, set, frozenset)):
                raise TypeError(
                    f"message filter field {key!r} must be a list, "
                    f"got {type(value).__name__}"
                )
        self.filters[websocket] = filter_config

    def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
        """Return True if ``message`` passes ``websocket``'s filter.

        Each non-empty filter field must match. Unlike the /ws control
        channel, a message lacking the filtered field is rejected:

        - ``agents``: sender (``from``/``fromId``) or recipient
          (``to``/``toId``) must be listed.
        - ``threads``: ``thread_id``/``threadId`` must be listed.
        - ``payload_types``: ``payloadType``/``payload_type`` must be listed.
        """
        flt = self.filters.get(websocket, {})
        if not flt:
            return True

        agents = flt.get("agents", [])
        if agents:
            from_id = message.get("from") or message.get("fromId")
            to_id = message.get("to") or message.get("toId")
            if from_id not in agents and to_id not in agents:
                return False

        threads = flt.get("threads", [])
        if threads:
            thread_id = message.get("thread_id") or message.get("threadId")
            if thread_id not in threads:
                return False

        payload_types = flt.get("payload_types", [])
        if payload_types:
            payload_type = message.get("payloadType") or message.get("payload_type")
            if payload_type not in payload_types:
                return False

        return True

    async def broadcast_message(self, message: Dict[str, Any]) -> None:
        """Send ``message`` to every connection whose filter accepts it.

        Delivery is best-effort: a connection whose send raises is dropped
        from the manager after the broadcast completes.
        """
        disconnected: List[WebSocket] = []

        # Iterate over a snapshot: each send awaits, and other coroutines may
        # connect/disconnect meanwhile, which would otherwise raise
        # "Set changed size during iteration".
        for websocket in list(self.active_connections):
            if websocket not in self.active_connections:
                continue  # disconnected while an earlier send was in flight
            if not self.should_send(websocket, message):
                continue
            try:
                await websocket.send_json(message)
            except Exception:
                # Any send failure is treated as a dead connection.
                disconnected.append(websocket)

        for ws in disconnected:
            self.disconnect(ws)
|
|
|
|
|
|
async def _send_error(websocket: WebSocket, error: str) -> None:
    """Send an ``{"event": "error", "error": ...}`` frame to the client."""
    await websocket.send_json({"event": "error", "error": error})


async def _read_command(websocket: WebSocket) -> Optional[Dict[str, Any]]:
    """Wait for the next client command and return it as a dict.

    A frame that is not valid JSON, or whose JSON is not an object, gets an
    error event and is skipped, so one malformed frame does not end the
    session. Returns None when receiving fails for any other reason
    (including a client disconnect), telling the caller to stop.
    """
    while True:
        try:
            data = await websocket.receive_json()
        except json.JSONDecodeError:
            await _send_error(websocket, "Invalid JSON")
            continue
        except Exception:
            # Connection closed (WebSocketDisconnect) or otherwise unusable.
            return None
        if isinstance(data, dict):
            return data
        await _send_error(websocket, "Commands must be JSON objects")


def _message_filter_error(filter_config: Any) -> Optional[str]:
    """Return why a /ws/messages filter is unusable, or None if it is valid.

    The filter must be an object. Its ``agents``, ``threads`` and
    ``payload_types`` fields, when present, must be lists. This keeps
    malformed client input out of ``MessageStreamManager``, whose broadcast
    loop evaluates every client's filter.
    """
    if not isinstance(filter_config, dict):
        return "'filter' must be an object"
    for key in ("agents", "threads", "payload_types"):
        value = filter_config.get(key)
        if value is not None and not isinstance(value, list):
            return f"'filter.{key}' must be a list"
    return None


async def _handle_subscribe(
    manager: ConnectionManager, websocket: WebSocket, data: Dict[str, Any]
) -> None:
    """Apply a ``subscribe`` command on the /ws control channel.

    Replaces the connection's filters and echoes the raw command back in a
    ``subscribed`` event. Invalid field values produce an error event.
    """
    try:
        sub = SubscribeRequest(
            threads=data.get("threads"),
            agents=data.get("agents"),
            payload_types=data.get("payload_types"),
            events=data.get("events"),
        )
    except ValueError as exc:
        # SubscribeRequest is presumably a pydantic model; pydantic's
        # ValidationError subclasses ValueError. Report it instead of
        # letting it escape and tear down the connection.
        await _send_error(websocket, f"Invalid subscription: {exc}")
        return
    manager.set_subscription(websocket, sub)
    await websocket.send_json({"event": "subscribed", "filters": data})


async def _handle_inject(
    state: "ServerState", websocket: WebSocket, data: Dict[str, Any]
) -> None:
    """Apply an ``inject`` command (WebSocket counterpart of REST /inject).

    Records a message from ``"api"`` to the agent named in ``to``. When no
    thread id is given (``thread_id``, or the ``threadId`` spelling used in
    the API's camelCase output), a new UUID thread id is used. Replies with
    an ``injected`` event, or an ``error`` event for invalid input.
    """
    import uuid

    target = data.get("to")
    if not target:
        await _send_error(websocket, "Missing 'to' field")
        return
    if not isinstance(target, str):
        await _send_error(websocket, "'to' must be a string")
        return
    if state.get_agent(target) is None:
        await _send_error(websocket, f"Unknown agent: {target}")
        return

    payload = data.get("payload")
    if payload is None:
        payload = {}
    elif not isinstance(payload, dict):
        await _send_error(websocket, "'payload' must be an object")
        return

    thread_id = data.get("thread_id") or data.get("threadId")
    if thread_id is not None and not isinstance(thread_id, str):
        await _send_error(websocket, "'thread_id' must be a string")
        return
    thread_id = thread_id or str(uuid.uuid4())

    # The payload's first top-level key is used as its type name; an empty
    # payload falls back to "Payload".
    payload_type = next(iter(payload), "Payload")

    msg_id = await state.record_message(
        thread_id=thread_id,
        from_id="api",
        to_id=target,
        payload_type=payload_type,
        payload=payload,
    )
    await websocket.send_json(
        {
            "event": "injected",
            "thread_id": thread_id,
            "message_id": msg_id,
        }
    )


def create_websocket_router(state: "ServerState") -> APIRouter:
    """Create the WebSocket router bound to ``state``.

    Registers two endpoints:

    - ``/ws``: control channel. On connect it sends a ``WSConnectedEvent``
      snapshot (organism info, agents, and threads from
      ``state.get_threads(limit=100)``). It then forwards every state event
      that passes the client's subscription filters. It accepts the
      ``subscribe`` and ``inject`` commands.
    - ``/ws/messages``: forwards the ``message`` body of ``message`` events,
      filtered per client. It accepts the ``subscribe`` command.

    Side effect: registers a forwarding callback with ``state.subscribe``.

    Malformed client commands (invalid JSON, non-object frames, bad field
    values) are answered with an ``{"event": "error"}`` frame instead of
    ending the connection.
    """
    router = APIRouter()
    manager = ConnectionManager()
    message_manager = MessageStreamManager()

    async def on_state_event(event: Dict[str, Any]) -> None:
        """Forward a state event to /ws clients, and message bodies to /ws/messages."""
        await manager.broadcast(event)

        if event.get("event") == "message":
            # Tolerate an explicit ``"message": None`` as well as a missing key.
            await message_manager.broadcast_message(event.get("message") or {})

    state.subscribe(on_state_event)

    @router.websocket("/ws")
    async def websocket_control(websocket: WebSocket) -> None:
        """
        Main control channel WebSocket.

        On connect, sends full state snapshot. Then pushes events as they occur.
        Accepts commands: subscribe, inject.
        """
        await manager.connect(websocket)

        try:
            threads, _ = state.get_threads(limit=100)
            connected_event = WSConnectedEvent(
                organism=state.get_organism_info(),
                agents=state.get_agents(),
                threads=threads,
            )
            await websocket.send_json(connected_event.model_dump(by_alias=True))

            while True:
                data = await _read_command(websocket)
                if data is None:
                    break

                cmd = data.get("cmd", "")
                if cmd == "subscribe":
                    await _handle_subscribe(manager, websocket, data)
                elif cmd == "inject":
                    await _handle_inject(state, websocket, data)
                else:
                    await _send_error(websocket, f"Unknown command: {cmd}")

        except WebSocketDisconnect:
            pass
        finally:
            manager.disconnect(websocket)

    @router.websocket("/ws/messages")
    async def websocket_messages(websocket: WebSocket) -> None:
        """
        Dedicated message log stream.

        Streams all messages flowing through the organism.
        Clients can filter by agents, threads, and payload types.
        """
        await message_manager.connect(websocket)

        try:
            while True:
                data = await _read_command(websocket)
                if data is None:
                    break

                cmd = data.get("cmd", "")
                if cmd == "subscribe":
                    filter_config = data.get("filter")
                    if filter_config is None:
                        filter_config = {}
                    problem = _message_filter_error(filter_config)
                    if problem is not None:
                        await _send_error(websocket, problem)
                        continue
                    message_manager.set_filter(websocket, filter_config)
                    await websocket.send_json(
                        {"event": "subscribed", "filter": filter_config}
                    )
                else:
                    await _send_error(websocket, f"Unknown command: {cmd}")

        except WebSocketDisconnect:
            pass
        finally:
            message_manager.disconnect(websocket)

    return router
|