Add AgentServer REST/WebSocket API

Implements the AgentServer API from docs/agentserver_api_spec.md:

REST API (/api/v1):
- Organism info and config endpoints
- Agent listing, details, config, schema
- Thread and message history with filtering
- Control endpoints (inject, pause, resume, kill, stop)

WebSocket:
- /ws: Main control channel with state snapshot + real-time events
- /ws/messages: Dedicated message stream with filtering

Infrastructure:
- Pydantic models with camelCase serialization
- ServerState bridges StreamPump to API
- Pump event hooks for real-time updates
- CLI 'serve' command: xml-pipeline serve [config] --port 8080

35 new tests for models, state, REST, and WebSocket.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
dullfig 2026-01-27 20:22:58 -08:00
parent 809862af35
commit bf31b0d14e
11 changed files with 2257 additions and 1 deletions

View file

@ -80,6 +80,13 @@ search = ["duckduckgo-search>=6.0"] # Web search tool
# Console example (optional, for interactive use) # Console example (optional, for interactive use)
console = ["prompt_toolkit>=3.0"] console = ["prompt_toolkit>=3.0"]
# API server (FastAPI + WebSocket)
server = [
"fastapi>=0.109",
"uvicorn[standard]>=0.27",
"websockets>=12.0",
]
# All LLM providers # All LLM providers
llm = ["xml-pipeline[anthropic,openai]"] llm = ["xml-pipeline[anthropic,openai]"]
@ -87,7 +94,7 @@ llm = ["xml-pipeline[anthropic,openai]"]
tools = ["xml-pipeline[redis,search]"] tools = ["xml-pipeline[redis,search]"]
# Everything (for local development) # Everything (for local development)
all = ["xml-pipeline[llm,tools,console]"] all = ["xml-pipeline[llm,tools,console,server]"]
# Testing # Testing
test = [ test = [

510
tests/test_server.py Normal file
View file

@ -0,0 +1,510 @@
"""
Tests for the AgentServer API.
Tests the REST API endpoints and WebSocket connections.
"""
import asyncio
import pytest
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, AsyncMock, patch
# Skip all tests if FastAPI not available
pytest.importorskip("fastapi")
from fastapi.testclient import TestClient
from xml_pipeline.server.models import (
AgentState,
AgentInfo,
ThreadStatus,
ThreadInfo,
MessageInfo,
OrganismInfo,
OrganismStatus,
)
from xml_pipeline.server.state import ServerState, AgentRuntimeState, ThreadRuntimeState
from xml_pipeline.server.api import create_router
from xml_pipeline.server.app import create_app
# ============================================================================
# Mock StreamPump
# ============================================================================
@dataclass
class MockListener:
    """Lightweight stand-in for a pump listener, used by the server tests.

    ``peers`` and ``payload_class`` accept ``None`` as "not provided";
    ``__post_init__`` then substitutes a fresh empty list and a throwaway
    ``MockPayload`` class, so instances never share mutable defaults.
    """

    name: str
    description: str
    is_agent: bool = False
    # None means "not provided"; replaced by a new list in __post_init__.
    peers: Optional[List[str]] = None
    # None means "not provided"; replaced by a generated class in __post_init__.
    payload_class: Optional[type] = None
    schema: Any = None

    def __post_init__(self):
        """Fill in per-instance defaults for fields that were left as None."""
        if self.peers is None:
            self.peers = []
        if self.payload_class is None:
            self.payload_class = type("MockPayload", (), {"__module__": "test", "__name__": "MockPayload"})
@dataclass
class MockConfig:
    """Default organism configuration values for the mock pump."""

    # Organism name and port
    name: str = "test-organism"
    port: int = 8765

    # Scheduling / concurrency settings
    thread_scheduling: str = "breadth-first"
    max_concurrent_pipelines: int = 50
    max_concurrent_handlers: int = 20
class MockStreamPump:
    """In-memory substitute for StreamPump exposing two listeners and event hooks."""

    def __init__(self):
        """Set up the default config, no identity, and the greeter/shouter listeners."""
        self.config = MockConfig()
        self.identity = None

        greeter = MockListener(
            name="greeter",
            description="Greeting agent",
            is_agent=True,
            peers=["shouter"],
        )
        shouter = MockListener(
            name="shouter",
            description="Shouting handler",
            is_agent=False,
        )
        # Keyed by listener name, greeter first.
        self.listeners: Dict[str, MockListener] = {
            greeter.name: greeter,
            shouter.name: shouter,
        }
        self._event_callbacks = []

    def subscribe_events(self, callback):
        """Register an event callback."""
        self._event_callbacks.append(callback)

    def unsubscribe_events(self, callback):
        """Drop an event callback; callbacks that were never registered are ignored."""
        try:
            self._event_callbacks.remove(callback)
        except ValueError:
            pass
# ============================================================================
# Test Fixtures
# ============================================================================
@pytest.fixture
def mock_pump():
    """Provide a newly constructed MockStreamPump."""
    pump = MockStreamPump()
    return pump
@pytest.fixture
def server_state(mock_pump):
    """Provide a ServerState built on top of the mock pump."""
    state = ServerState(mock_pump)
    return state
@pytest.fixture
def test_client(mock_pump):
    """Provide a FastAPI TestClient for an app created around the mock pump."""
    return TestClient(create_app(mock_pump))
# ============================================================================
# Test Models
# ============================================================================
class TestModels:
    """Verify that the Pydantic models dump camelCase keys when using aliases."""

    def test_agent_info_camel_case(self):
        """AgentInfo dumps its multi-word fields under camelCase keys."""
        dumped = AgentInfo(
            name="greeter",
            description="Test agent",
            is_agent=True,
            peers=["shouter"],
            payload_class="test.Greeting",
            state=AgentState.IDLE,
            current_thread=None,
            queue_depth=0,
            message_count=5,
        ).model_dump(by_alias=True)

        expected_keys = (
            "isAgent",
            "payloadClass",
            "currentThread",
            "queueDepth",
            "messageCount",
        )
        for key in expected_keys:
            assert key in dumped

    def test_thread_info_camel_case(self):
        """ThreadInfo dumps its multi-word fields under camelCase keys."""
        from datetime import datetime, timezone

        dumped = ThreadInfo(
            id="test-uuid",
            status=ThreadStatus.ACTIVE,
            participants=["greeter", "shouter"],
            message_count=3,
            created_at=datetime.now(timezone.utc),
        ).model_dump(by_alias=True)

        for key in ("messageCount", "createdAt", "lastActivity"):
            assert key in dumped

    def test_organism_info_camel_case(self):
        """OrganismInfo dumps its multi-word fields under camelCase keys."""
        dumped = OrganismInfo(
            name="test-organism",
            status=OrganismStatus.RUNNING,
            uptime_seconds=3600.0,
            agent_count=2,
            active_threads=1,
            total_messages=10,
            identity_configured=False,
        ).model_dump(by_alias=True)

        expected_keys = (
            "uptimeSeconds",
            "agentCount",
            "activeThreads",
            "totalMessages",
            "identityConfigured",
        )
        for key in expected_keys:
            assert key in dumped
# ============================================================================
# Test ServerState
# ============================================================================
class TestServerState:
    """Exercise ServerState directly against the mock pump."""

    def test_init_agents_from_pump(self, server_state):
        """Both mock listeners are mirrored into the state's agent table."""
        agents = server_state._agents
        assert "greeter" in agents
        assert "shouter" in agents
        assert agents["greeter"].is_agent is True
        assert agents["shouter"].is_agent is False

    def test_get_organism_info(self, server_state):
        """Organism info reflects the mock config and starts in STARTING."""
        organism = server_state.get_organism_info()
        assert organism.name == "test-organism"
        assert organism.agent_count == 2
        assert organism.status == OrganismStatus.STARTING

    def test_set_running(self, server_state):
        """set_running() flips the reported status to RUNNING."""
        server_state.set_running()
        assert server_state.get_organism_info().status == OrganismStatus.RUNNING

    def test_get_agents(self, server_state):
        """get_agents() returns an entry for each listener."""
        listed = server_state.get_agents()
        assert len(listed) == 2
        seen = {entry.name for entry in listed}
        assert "greeter" in seen
        assert "shouter" in seen

    def test_get_agent(self, server_state):
        """get_agent() returns the named agent."""
        found = server_state.get_agent("greeter")
        assert found is not None
        assert found.name == "greeter"
        assert found.is_agent is True

    def test_get_agent_not_found(self, server_state):
        """get_agent() returns None for an unknown name."""
        assert server_state.get_agent("nonexistent") is None

    @pytest.mark.asyncio
    async def test_record_message(self, server_state):
        """A recorded message gets an id and is retrievable by thread."""
        new_id = await server_state.record_message(
            thread_id="test-thread",
            from_id="greeter",
            to_id="shouter",
            payload_type="GreetingResponse",
            payload={"message": "Hello!"},
        )
        assert new_id is not None

        # The message must now show up in the thread's history.
        history, count = server_state.get_messages(thread_id="test-thread")
        assert count == 1
        first = history[0]
        assert first.from_id == "greeter"
        assert first.to_id == "shouter"

    @pytest.mark.asyncio
    async def test_update_agent_state(self, server_state):
        """update_agent_state() stores both the state and the current thread."""
        await server_state.update_agent_state("greeter", AgentState.PROCESSING, "thread-1")
        updated = server_state.get_agent("greeter")
        assert updated.state == AgentState.PROCESSING
        assert updated.current_thread == "thread-1"

    @pytest.mark.asyncio
    async def test_complete_thread(self, server_state):
        """complete_thread() sets the thread's final status."""
        # Recording a message is what brings the thread into existence.
        await server_state.record_message(
            thread_id="test-thread",
            from_id="greeter",
            to_id="shouter",
            payload_type="Test",
            payload={},
        )

        await server_state.complete_thread("test-thread", ThreadStatus.COMPLETED)

        finished = server_state.get_thread("test-thread")
        assert finished.status == ThreadStatus.COMPLETED

    def test_get_threads_with_filter(self, server_state):
        """Filtering by status on a fresh state yields no threads."""
        _threads, count = server_state.get_threads(status=ThreadStatus.ACTIVE)
        assert count == 0

    def test_get_organism_config(self, server_state):
        """The organism config exposes the name and port from the mock config."""
        cfg = server_state.get_organism_config()
        assert cfg["name"] == "test-organism"
        assert cfg["port"] == 8765
# ============================================================================
# Test REST API
# ============================================================================
class TestRestAPI:
    """Exercise the REST endpoints through the FastAPI test client."""

    def test_health_check(self, test_client):
        """GET /health reports a healthy status plus organism info."""
        resp = test_client.get("/health")
        assert resp.status_code == 200
        body = resp.json()
        assert body["status"] == "healthy"
        assert "organism" in body

    def test_get_organism(self, test_client):
        """GET /api/v1/organism returns the organism overview."""
        resp = test_client.get("/api/v1/organism")
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "test-organism"
        assert "agentCount" in body

    def test_get_organism_config(self, test_client):
        """GET /api/v1/organism/config returns the organism config."""
        resp = test_client.get("/api/v1/organism/config")
        assert resp.status_code == 200
        assert resp.json()["name"] == "test-organism"

    def test_list_agents(self, test_client):
        """GET /api/v1/agents lists both mock listeners."""
        resp = test_client.get("/api/v1/agents")
        assert resp.status_code == 200
        body = resp.json()
        assert "agents" in body
        assert body["count"] == 2

    def test_get_agent(self, test_client):
        """GET /api/v1/agents/{name} returns the named agent."""
        resp = test_client.get("/api/v1/agents/greeter")
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "greeter"
        assert body["isAgent"] is True

    def test_get_agent_not_found(self, test_client):
        """GET /api/v1/agents/{name} answers 404 for an unknown agent."""
        resp = test_client.get("/api/v1/agents/nonexistent")
        assert resp.status_code == 404

    def test_get_agent_config(self, test_client):
        """GET /api/v1/agents/{name}/config returns the agent's config fields."""
        resp = test_client.get("/api/v1/agents/greeter/config")
        assert resp.status_code == 200
        body = resp.json()
        assert body["name"] == "greeter"
        assert "isAgent" in body

    def test_list_threads(self, test_client):
        """GET /api/v1/threads returns the paging envelope."""
        resp = test_client.get("/api/v1/threads")
        assert resp.status_code == 200
        body = resp.json()
        for key in ("threads", "count", "total"):
            assert key in body

    def test_list_threads_with_invalid_status(self, test_client):
        """GET /api/v1/threads rejects an unknown status filter with 400."""
        resp = test_client.get("/api/v1/threads?status=invalid")
        assert resp.status_code == 400

    def test_list_messages(self, test_client):
        """GET /api/v1/messages returns the message list envelope."""
        resp = test_client.get("/api/v1/messages")
        assert resp.status_code == 200
        body = resp.json()
        assert "messages" in body
        assert "count" in body

    def test_inject_message(self, test_client):
        """POST /api/v1/inject returns the thread and message ids."""
        request_body = {
            "to": "greeter",
            "payload": {"name": "Dan"},
        }
        resp = test_client.post("/api/v1/inject", json=request_body)
        assert resp.status_code == 200
        body = resp.json()
        assert "threadId" in body
        assert "messageId" in body

    def test_inject_message_unknown_agent(self, test_client):
        """POST /api/v1/inject answers 400 when the target agent is unknown."""
        request_body = {
            "to": "nonexistent",
            "payload": {"name": "Dan"},
        }
        resp = test_client.post("/api/v1/inject", json=request_body)
        assert resp.status_code == 400

    def test_pause_agent(self, test_client):
        """POST /api/v1/agents/{name}/pause reports the paused state."""
        resp = test_client.post("/api/v1/agents/greeter/pause")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["state"] == "paused"

    def test_resume_agent(self, test_client):
        """POST /api/v1/agents/{name}/resume reports the idle state."""
        resp = test_client.post("/api/v1/agents/greeter/resume")
        assert resp.status_code == 200
        body = resp.json()
        assert body["success"] is True
        assert body["state"] == "idle"

    def test_stop_organism(self, test_client):
        """POST /api/v1/organism/stop reports success."""
        resp = test_client.post("/api/v1/organism/stop")
        assert resp.status_code == 200
        assert resp.json()["success"] is True
# ============================================================================
# Test WebSocket
# ============================================================================
class TestWebSocket:
    """Exercise the /ws and /ws/messages WebSocket endpoints."""

    def test_websocket_connect(self, test_client):
        """Connecting to /ws yields a 'connected' event carrying a state snapshot."""
        with test_client.websocket_connect("/ws") as ws:
            hello = ws.receive_json()
            assert hello["event"] == "connected"
            for key in ("organism", "agents", "threads"):
                assert key in hello

    def test_websocket_subscribe(self, test_client):
        """A subscribe command is acknowledged with a 'subscribed' event."""
        with test_client.websocket_connect("/ws") as ws:
            # Discard the initial 'connected' event.
            ws.receive_json()

            command = {
                "cmd": "subscribe",
                "agents": ["greeter"],
                "events": ["message"],
            }
            ws.send_json(command)

            reply = ws.receive_json()
            assert reply["event"] == "subscribed"

    def test_websocket_inject(self, test_client):
        """An inject command is acknowledged with thread and message ids."""
        with test_client.websocket_connect("/ws") as ws:
            # Discard the initial 'connected' event.
            ws.receive_json()

            command = {
                "cmd": "inject",
                "to": "greeter",
                "payload": {"name": "Dan"},
            }
            ws.send_json(command)

            reply = ws.receive_json()
            assert reply["event"] == "injected"
            assert "thread_id" in reply
            assert "message_id" in reply

    def test_websocket_inject_unknown_agent(self, test_client):
        """Injecting to an unknown agent produces an 'error' event."""
        with test_client.websocket_connect("/ws") as ws:
            # Discard the initial 'connected' event.
            ws.receive_json()

            command = {
                "cmd": "inject",
                "to": "nonexistent",
                "payload": {},
            }
            ws.send_json(command)

            reply = ws.receive_json()
            assert reply["event"] == "error"

    def test_websocket_unknown_command(self, test_client):
        """An unrecognised command produces an 'error' event."""
        with test_client.websocket_connect("/ws") as ws:
            # Discard the initial 'connected' event.
            ws.receive_json()

            ws.send_json({"cmd": "unknown_command"})

            reply = ws.receive_json()
            assert reply["event"] == "error"

    def test_websocket_messages_stream(self, test_client):
        """Subscribing on /ws/messages is acknowledged and echoes a filter."""
        with test_client.websocket_connect("/ws/messages") as ws:
            command = {
                "cmd": "subscribe",
                "filter": {
                    "agents": ["greeter"],
                },
            }
            ws.send_json(command)

            reply = ws.receive_json()
            assert reply["event"] == "subscribed"
            assert "filter" in reply

View file

@ -3,6 +3,7 @@ xml-pipeline CLI entry point.
Usage: Usage:
xml-pipeline run [config.yaml] Run an organism xml-pipeline run [config.yaml] Run an organism
xml-pipeline serve [config.yaml] Run organism with API server
xml-pipeline init [name] Create new organism config xml-pipeline init [name] Create new organism config
xml-pipeline check [config.yaml] Validate config without running xml-pipeline check [config.yaml] Validate config without running
xml-pipeline version Show version info xml-pipeline version Show version info
@ -37,6 +38,68 @@ def cmd_run(args: argparse.Namespace) -> int:
return 1 return 1
def cmd_serve(args: argparse.Namespace) -> int:
    """Serve an organism through the AgentServer API.

    Reads ``args.config``, ``args.host`` and ``args.port``.

    Returns:
        0 after a normal server exit or a keyboard interrupt; 1 when uvicorn
        is not installed, the config file is missing, or running raises.
    """
    try:
        import uvicorn
    except ImportError:
        print("Error: uvicorn not installed.", file=sys.stderr)
        print("Install with: pip install xml-pipeline[server]", file=sys.stderr)
        return 1

    from xml_pipeline.message_bus import bootstrap

    config_path = Path(args.config)
    if not config_path.exists():
        print(f"Error: Config file not found: {config_path}", file=sys.stderr)
        return 1

    async def run_with_server():
        """Bootstrap the pump, then run it alongside the uvicorn server."""
        from xml_pipeline.server import create_app

        pump = await bootstrap(str(config_path))

        # Wrap the pump in a FastAPI app and hand it to uvicorn.
        server = uvicorn.Server(
            uvicorn.Config(
                create_app(pump),
                host=args.host,
                port=args.port,
                log_level="info",
            )
        )

        # The pump runs as a background task while the server is in the foreground.
        pump_task = asyncio.create_task(pump.run())
        try:
            await server.serve()
        finally:
            # Shut the pump down, then cancel its task and wait for it to finish.
            await pump.shutdown()
            pump_task.cancel()
            try:
                await pump_task
            except asyncio.CancelledError:
                pass

    try:
        print(f"Starting AgentServer on http://{args.host}:{args.port}")
        print(f" API docs: http://{args.host}:{args.port}/docs")
        print(f" WebSocket: ws://{args.host}:{args.port}/ws")
        asyncio.run(run_with_server())
    except KeyboardInterrupt:
        print("\nShutdown requested.")
        return 0
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1
    return 0
def cmd_init(args: argparse.Namespace) -> int: def cmd_init(args: argparse.Namespace) -> int:
"""Initialize a new organism config.""" """Initialize a new organism config."""
from xml_pipeline.config.template import create_organism_template from xml_pipeline.config.template import create_organism_template
@ -149,6 +212,13 @@ def main() -> int:
run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file") run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
run_parser.set_defaults(func=cmd_run) run_parser.set_defaults(func=cmd_run)
# serve
serve_parser = subparsers.add_parser("serve", help="Run organism with API server")
serve_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind (default: 0.0.0.0)")
serve_parser.add_argument("--port", "-p", type=int, default=8080, help="Port to listen on (default: 8080)")
serve_parser.set_defaults(func=cmd_serve)
# init # init
init_parser = subparsers.add_parser("init", help="Create new organism config") init_parser = subparsers.add_parser("init", help="Create new organism config")
init_parser.add_argument("name", nargs="?", help="Organism name") init_parser.add_argument("name", nargs="?", help="Organism name")

View file

@ -33,6 +33,12 @@ from xml_pipeline.message_bus.stream_pump import (
get_stream_pump, get_stream_pump,
set_stream_pump, set_stream_pump,
reset_stream_pump, reset_stream_pump,
# Event hooks
PumpEvent,
MessageReceivedEvent,
MessageSentEvent,
AgentStateEvent,
ThreadEvent,
) )
from xml_pipeline.message_bus.message_state import ( from xml_pipeline.message_bus.message_state import (
@ -71,6 +77,12 @@ __all__ = [
"get_stream_pump", "get_stream_pump",
"set_stream_pump", "set_stream_pump",
"reset_stream_pump", "reset_stream_pump",
# Event hooks
"PumpEvent",
"MessageReceivedEvent",
"MessageSentEvent",
"AgentStateEvent",
"ThreadEvent",
# Message state # Message state
"MessageState", "MessageState",
"HandlerMetadata", "HandlerMetadata",

View file

@ -49,6 +49,56 @@ from xml_pipeline.memory import get_context_buffer
pump_logger = logging.getLogger(__name__) pump_logger = logging.getLogger(__name__)
# ============================================================================
# Event Hooks
# ============================================================================
@dataclass
class PumpEvent:
    """Common base type for the events delivered to pump event subscribers."""
@dataclass
class MessageReceivedEvent(PumpEvent):
    """Signals that a message has arrived at a handler."""

    # Thread the message belongs to
    thread_id: str
    # Sender name
    from_id: str
    # Receiving listener's name
    to_id: str
    # Class name of the payload object
    payload_type: str
    # The payload object itself
    payload: Any
@dataclass
class MessageSentEvent(PumpEvent):
    """Signals that a handler has produced a response message."""

    # Thread the response belongs to
    thread_id: str
    # Name of the listener that produced the response
    from_id: str
    # Name of the response's target
    to_id: str
    # Class name of the response payload
    payload_type: str
    # The response payload object itself
    payload: Any
@dataclass
class AgentStateEvent(PumpEvent):
    """Signals a change in an agent's processing state."""

    # Name of the agent whose state changed
    agent_name: str
    # One of: "idle", "processing", "waiting", "error"
    state: str
    # Thread associated with the state change, if any
    thread_id: Optional[str] = None
@dataclass
class ThreadEvent(PumpEvent):
    """Signals a thread lifecycle change such as creation or completion."""

    # Identifier of the affected thread
    thread_id: str
    # One of: "created", "active", "completed", "error", "killed"
    status: str
    # Names taking part in the thread; each event gets its own empty list by default
    participants: List[str] = field(default_factory=list)
    # Error text — presumably only set for the "error" status; confirm with emitters
    error: Optional[str] = None
EventCallback = Callable[[PumpEvent], None]
# ============================================================================ # ============================================================================
# Configuration (same as before) # Configuration (same as before)
# ============================================================================ # ============================================================================
@ -233,6 +283,9 @@ class StreamPump:
# Shutdown control # Shutdown control
self._running = False self._running = False
# Event hooks for external observers (ServerState, etc.)
self._event_callbacks: List[EventCallback] = []
# Process pool for cpu_bound handlers # Process pool for cpu_bound handlers
self._process_pool: Optional[ProcessPoolExecutor] = None self._process_pool: Optional[ProcessPoolExecutor] = None
if config.process_pool_enabled: if config.process_pool_enabled:
@ -256,6 +309,27 @@ class StreamPump:
self._shared_backend = get_shared_backend(backend_config) self._shared_backend = get_shared_backend(backend_config)
pump_logger.info(f"Shared backend: {config.backend_type}") pump_logger.info(f"Shared backend: {config.backend_type}")
# ------------------------------------------------------------------
# Event Hooks
# ------------------------------------------------------------------
def subscribe_events(self, callback: EventCallback) -> None:
    """Register *callback* to receive pump events.

    The callback is appended as-is; no check is made for an existing
    registration of the same callable.
    """
    observers = self._event_callbacks
    observers.append(callback)
def unsubscribe_events(self, callback: EventCallback) -> None:
    """Remove one registration of *callback*; do nothing if it is not registered."""
    try:
        self._event_callbacks.remove(callback)
    except ValueError:
        # Callback was never subscribed (or already removed).
        pass
def _emit_event(self, event: PumpEvent) -> None:
"""Emit an event to all subscribers (non-blocking)."""
for callback in self._event_callbacks:
try:
callback(event)
except Exception as e:
pump_logger.warning(f"Event callback error: {e}")
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Registration # Registration
# ------------------------------------------------------------------ # ------------------------------------------------------------------
@ -493,6 +567,12 @@ class StreamPump:
await semaphore.acquire() await semaphore.acquire()
try: try:
# Emit agent state change event
self._emit_event(AgentStateEvent(
agent_name=listener.name,
state="processing",
thread_id=state.thread_id,
))
# Ensure we have a valid thread chain # Ensure we have a valid thread chain
registry = get_registry() registry = get_registry()
todo_registry = get_todo_registry() todo_registry = get_todo_registry()
@ -566,6 +646,15 @@ class StreamPump:
) )
payload_ref = state.payload payload_ref = state.payload
# Emit message received event
self._emit_event(MessageReceivedEvent(
thread_id=current_thread,
from_id=state.from_id or "",
to_id=listener.name,
payload_type=type(payload_ref).__name__,
payload=payload_ref,
))
# Dispatch to handler - either in-process or via ProcessPool # Dispatch to handler - either in-process or via ProcessPool
if listener.cpu_bound and self._process_pool and self._shared_backend: if listener.cpu_bound and self._process_pool and self._shared_backend:
response = await self._dispatch_to_process_pool( response = await self._dispatch_to_process_pool(
@ -578,6 +667,12 @@ class StreamPump:
# None means "no response needed" - don't re-inject # None means "no response needed" - don't re-inject
if response is None: if response is None:
# Emit idle state
self._emit_event(AgentStateEvent(
agent_name=listener.name,
state="idle",
thread_id=current_thread,
))
continue continue
# Handle clean HandlerResponse (preferred) # Handle clean HandlerResponse (preferred)
@ -653,6 +748,23 @@ class StreamPump:
response_bytes = b"<huh>Handler returned invalid type</huh>" response_bytes = b"<huh>Handler returned invalid type</huh>"
thread_id = state.thread_id thread_id = state.thread_id
# Emit message sent event
if isinstance(response, HandlerResponse):
self._emit_event(MessageSentEvent(
thread_id=thread_id,
from_id=listener.name,
to_id=to_id,
payload_type=type(response.payload).__name__,
payload=response.payload,
))
# Emit agent state back to idle
self._emit_event(AgentStateEvent(
agent_name=listener.name,
state="idle",
thread_id=None,
))
# Yield response — will be processed by next iteration # Yield response — will be processed by next iteration
yield MessageState( yield MessageState(
raw_bytes=response_bytes, raw_bytes=response_bytes,
@ -665,6 +777,12 @@ class StreamPump:
semaphore.release() semaphore.release()
except Exception as exc: except Exception as exc:
# Emit error state
self._emit_event(AgentStateEvent(
agent_name=listener.name,
state="error",
thread_id=state.thread_id,
))
yield MessageState( yield MessageState(
raw_bytes=f"<huh>Handler {listener.name} crashed: {exc}</huh>".encode(), raw_bytes=f"<huh>Handler {listener.name} crashed: {exc}</huh>".encode(),
thread_id=state.thread_id, thread_id=state.thread_id,

View file

@ -0,0 +1,26 @@
"""
server — FastAPI-based AgentServer API for monitoring and controlling organisms.
Provides:
- REST API for querying organism state (agents, threads, messages)
- WebSocket for real-time events
- Message injection endpoint
Usage:
from xml_pipeline.server import create_app, run_server
# With existing pump
app = create_app(pump)
uvicorn.run(app, host="0.0.0.0", port=8080)
# Or use CLI
xml-pipeline serve config/organism.yaml --port 8080
"""
from xml_pipeline.server.app import create_app, run_server, run_server_sync
__all__ = [
"create_app",
"run_server",
"run_server_sync",
]

273
xml_pipeline/server/api.py Normal file
View file

@ -0,0 +1,273 @@
"""
api.py — REST API routes for AgentServer.
Provides endpoints for:
- Organism info and config
- Agent listing and details
- Thread listing and management
- Message injection
"""
from __future__ import annotations
import uuid
from typing import TYPE_CHECKING, Optional
from fastapi import APIRouter, HTTPException, Query
from xml_pipeline.server.models import (
AgentInfo,
AgentListResponse,
ErrorResponse,
InjectRequest,
InjectResponse,
MessageListResponse,
OrganismInfo,
ThreadInfo,
ThreadListResponse,
ThreadStatus,
)
if TYPE_CHECKING:
from xml_pipeline.server.state import ServerState
def create_router(state: "ServerState") -> APIRouter:
"""Create API router with state dependency."""
router = APIRouter(prefix="/api/v1")
# =========================================================================
# Organism Endpoints
# =========================================================================
@router.get("/organism", response_model=OrganismInfo)
async def get_organism() -> OrganismInfo:
"""Get organism overview and stats."""
return state.get_organism_info()
@router.get("/organism/config")
async def get_organism_config() -> dict:
"""Get sanitized organism configuration (no secrets)."""
return state.get_organism_config()
# =========================================================================
# Agent Endpoints
# =========================================================================
@router.get("/agents", response_model=AgentListResponse)
async def list_agents() -> AgentListResponse:
"""List all agents with current state."""
agents = state.get_agents()
return AgentListResponse(agents=agents, count=len(agents))
@router.get("/agents/{name}", response_model=AgentInfo)
async def get_agent(name: str) -> AgentInfo:
"""Get single agent details."""
agent = state.get_agent(name)
if agent is None:
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
return agent
@router.get("/agents/{name}/config")
async def get_agent_config(name: str) -> dict:
"""Get agent's YAML config section."""
agent = state.get_agent(name)
if agent is None:
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
# Return relevant config fields
return {
"name": agent.name,
"description": agent.description,
"isAgent": agent.is_agent,
"peers": agent.peers,
"payloadClass": agent.payload_class,
}
@router.get("/agents/{name}/schema")
async def get_agent_schema(name: str) -> dict:
"""Get agent's payload XML schema."""
schema = state.get_agent_schema(name)
if schema is None:
raise HTTPException(
status_code=404,
detail=f"Schema not found for agent: {name}",
)
return {"schema": schema, "contentType": "application/xml"}
# =========================================================================
# Thread Endpoints
# =========================================================================
@router.get("/threads", response_model=ThreadListResponse)
async def list_threads(
status: Optional[str] = Query(None, description="Filter by status"),
agent: Optional[str] = Query(None, description="Filter by participant agent"),
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
) -> ThreadListResponse:
"""List threads with optional filtering."""
thread_status = None
if status:
try:
thread_status = ThreadStatus(status)
except ValueError:
raise HTTPException(
status_code=400,
detail=f"Invalid status: {status}. Valid values: {[s.value for s in ThreadStatus]}",
)
threads, total = state.get_threads(
status=thread_status,
agent=agent,
limit=limit,
offset=offset,
)
return ThreadListResponse(
threads=threads,
count=len(threads),
total=total,
offset=offset,
limit=limit,
)
@router.get("/threads/{thread_id}", response_model=ThreadInfo)
async def get_thread(thread_id: str) -> ThreadInfo:
"""Get thread details with message history."""
thread = state.get_thread(thread_id)
if thread is None:
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
return thread
@router.get("/threads/{thread_id}/messages", response_model=MessageListResponse)
async def get_thread_messages(
thread_id: str,
limit: int = Query(50, ge=1, le=100),
offset: int = Query(0, ge=0),
) -> MessageListResponse:
"""Get messages in a specific thread."""
thread = state.get_thread(thread_id)
if thread is None:
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
messages, total = state.get_messages(
thread_id=thread_id,
limit=limit,
offset=offset,
)
return MessageListResponse(
messages=messages,
count=len(messages),
total=total,
offset=offset,
limit=limit,
)
@router.post("/threads/{thread_id}/kill")
async def kill_thread(thread_id: str) -> dict:
"""Terminate a thread."""
thread = state.get_thread(thread_id)
if thread is None:
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
await state.complete_thread(thread_id, status=ThreadStatus.KILLED)
return {"success": True, "threadId": thread_id}
# =========================================================================
# Message Endpoints
# =========================================================================
@router.get("/messages", response_model=MessageListResponse)
async def list_messages(
    agent: Optional[str] = Query(None, description="Filter by agent (sender or receiver)"),
    limit: int = Query(50, ge=1, le=100),
    offset: int = Query(0, ge=0),
) -> MessageListResponse:
    """Return one page of the global message history.

    When ``agent`` is given, only messages sent by or addressed to that agent
    are included. ``total`` counts the matches before paging.
    """
    page, matched = state.get_messages(agent=agent, limit=limit, offset=offset)
    return MessageListResponse(
        messages=page,
        count=len(page),
        total=matched,
        offset=offset,
        limit=limit,
    )
# =========================================================================
# Control Endpoints
# =========================================================================
@router.post("/inject", response_model=InjectResponse)
async def inject_message(request: InjectRequest) -> InjectResponse:
    """Record an API-originated message addressed to an agent.

    The message is recorded with sender ``"api"``. The payload type is taken
    from the first key of the payload dict, or ``"Payload"`` when it is empty.

    Raises:
        HTTPException: 400 when the target agent is unknown.
    """
    target = request.to
    if state.get_agent(target) is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown target agent: {target}",
        )

    # Continue the caller's thread when one was supplied, else open a new one.
    thread_id = request.thread_id if request.thread_id else str(uuid.uuid4())

    # For now the payload "type" is simply the first key of the dict.
    payload_type = next(iter(request.payload), "Payload")

    msg_id = await state.record_message(
        thread_id=thread_id,
        from_id="api",
        to_id=target,
        payload_type=payload_type,
        payload=request.payload,
    )

    # TODO: Actually inject into pump queue
    # This requires building an envelope and calling pump.inject()
    return InjectResponse(thread_id=thread_id, message_id=msg_id)
@router.post("/agents/{name}/pause")
async def pause_agent(name: str) -> dict:
    """Mark an agent as paused (stop processing new messages).

    Raises:
        HTTPException: 404 when the agent is unknown.
    """
    from xml_pipeline.server.models import AgentState

    if state.get_agent(name) is None:
        raise HTTPException(status_code=404, detail=f"Agent not found: {name}")

    await state.update_agent_state(name, AgentState.PAUSED)
    outcome = {"success": True, "agent": name, "state": "paused"}
    return outcome
@router.post("/agents/{name}/resume")
async def resume_agent(name: str) -> dict:
    """Resume a paused agent.

    Only an agent that is currently ``paused`` is moved back to ``idle``.
    For any other state the call is a successful no-op and the agent's actual
    state is reported, so resuming an agent that is busy does not overwrite
    its live state or clear its current thread.

    Raises:
        HTTPException: 404 when the agent is unknown.
    """
    from xml_pipeline.server.models import AgentState

    agent = state.get_agent(name)
    if agent is None:
        raise HTTPException(status_code=404, detail=f"Agent not found: {name}")

    if agent.state != AgentState.PAUSED:
        # Nothing to resume; report what the agent is really doing.
        return {"success": True, "agent": name, "state": agent.state.value}

    await state.update_agent_state(name, AgentState.IDLE)
    return {"success": True, "agent": name, "state": "idle"}
@router.post("/organism/reload")
async def reload_config() -> dict:
    """Placeholder for configuration hot-reload; currently always reports failure."""
    # TODO: Implement hot-reload
    not_implemented = {"success": False, "error": "Hot-reload not yet implemented"}
    return not_implemented
@router.post("/organism/stop")
async def stop_organism() -> dict:
    """Flag the organism as stopping and acknowledge the request."""
    # TODO: Signal pump to stop
    state.set_stopping()
    acknowledgement = {"success": True, "status": "stopping"}
    return acknowledgement
return router

148
xml_pipeline/server/app.py Normal file
View file

@ -0,0 +1,148 @@
"""
app.py FastAPI application factory for AgentServer.
Creates the FastAPI app that combines REST API and WebSocket endpoints.
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from xml_pipeline.server.api import create_router
from xml_pipeline.server.state import ServerState
from xml_pipeline.server.websocket import create_websocket_router
if TYPE_CHECKING:
from xml_pipeline.message_bus.stream_pump import StreamPump
def create_app(
    pump: "StreamPump",
    *,
    title: str = "AgentServer API",
    version: str = "1.0.0",
    cors_origins: Optional[list[str]] = None,
) -> FastAPI:
    """
    Create FastAPI application with REST and WebSocket endpoints.

    Args:
        pump: The StreamPump instance to wrap
        title: API title for OpenAPI docs
        version: API version
        cors_origins: List of allowed CORS origins (default: all). Credentialed
            cross-origin requests are only enabled when an explicit origin
            list (without ``"*"``) is supplied.

    Returns:
        Configured FastAPI application
    """
    # Create state manager
    state = ServerState(pump)

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
        """Manage app lifecycle - startup and shutdown."""
        state.set_running()
        try:
            yield
        finally:
            # Runs even when the serving phase ends with an exception.
            state.set_stopping()

    app = FastAPI(
        title=title,
        version=version,
        description="REST and WebSocket API for monitoring and controlling xml-pipeline organisms.",
        lifespan=lifespan,
    )

    # CORS middleware
    if cors_origins is None:
        cors_origins = ["*"]
    # A wildcard origin must not be combined with credentials: that would let
    # any website issue credentialed requests against this API.
    allow_credentials = "*" not in cors_origins
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=allow_credentials,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Include routers
    app.include_router(create_router(state))
    app.include_router(create_websocket_router(state))

    # Store state on app for access if needed
    app.state.server_state = state
    app.state.pump = pump

    @app.get("/health")
    async def health_check() -> dict[str, Any]:
        """Health check endpoint."""
        info = state.get_organism_info()
        return {
            "status": "healthy",
            "organism": info.name,
            "uptime_seconds": info.uptime_seconds,
        }

    return app
async def run_server(
    pump: "StreamPump",
    *,
    host: str = "0.0.0.0",
    port: int = 8080,
    cors_origins: Optional[list[str]] = None,
) -> None:
    """
    Serve the AgentServer app with uvicorn until the server exits.

    Args:
        pump: The StreamPump instance to wrap
        host: Host to bind to
        port: Port to listen on
        cors_origins: List of allowed CORS origins

    Raises:
        ImportError: If uvicorn is not installed.
    """
    # uvicorn is an optional dependency, so it is imported lazily here.
    try:
        import uvicorn
    except ImportError as e:
        raise ImportError(
            "uvicorn is required for the server. Install with: pip install xml-pipeline[server]"
        ) from e

    server = uvicorn.Server(
        uvicorn.Config(
            create_app(pump, cors_origins=cors_origins),
            host=host,
            port=port,
            log_level="info",
        )
    )
    await server.serve()
def run_server_sync(
    pump: "StreamPump",
    *,
    host: str = "0.0.0.0",
    port: int = 8080,
    cors_origins: Optional[list[str]] = None,
) -> None:
    """
    Blocking entry point: run the AgentServer on a fresh event loop.

    Convenience wrapper around ``run_server`` for CLI usage.

    Args:
        pump: The StreamPump instance to wrap
        host: Host to bind to
        port: Port to listen on
        cors_origins: List of allowed CORS origins
    """
    serving = run_server(pump, host=host, port=port, cors_origins=cors_origins)
    asyncio.run(serving)

View file

@ -0,0 +1,261 @@
"""
models.py Pydantic models for AgentServer API.
These models define the JSON structure for API responses.
Uses camelCase for JSON keys (JavaScript convention).
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
def to_camel(string: str) -> str:
    """Convert a snake_case identifier to camelCase.

    The text before the first underscore is kept unchanged; every later
    underscore-separated piece is title-cased and appended.
    """
    head, *tail = string.split("_")
    return head + "".join(map(str.title, tail))
class CamelModel(BaseModel):
    """Base model with camelCase JSON serialization.

    Aliases are generated from field names with ``to_camel``.
    ``populate_by_name=True`` allows construction with either the Python
    field name or the alias.
    """

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )
# =============================================================================
# Enums
# =============================================================================
class AgentState(str, Enum):
    """Agent processing state.

    Members are also ``str`` instances whose value is the lowercase name.
    """

    IDLE = "idle"
    PROCESSING = "processing"
    WAITING = "waiting"
    ERROR = "error"
    PAUSED = "paused"
class ThreadStatus(str, Enum):
    """Thread lifecycle status.

    ``ACTIVE`` is the only non-terminal value; the others describe how a
    thread ended.
    """

    ACTIVE = "active"
    COMPLETED = "completed"
    ERROR = "error"
    KILLED = "killed"
class OrganismStatus(str, Enum):
    """Organism running status, as reported in ``OrganismInfo.status``."""

    STARTING = "starting"
    RUNNING = "running"
    STOPPING = "stopping"
    STOPPED = "stopped"
# =============================================================================
# API Response Models
# =============================================================================
class AgentInfo(CamelModel):
    """Agent information for API response.

    The explicit aliases below are the camelCase forms of the field names.
    """

    # Static description of the listener.
    name: str
    description: str
    is_agent: bool = Field(alias="isAgent")
    peers: List[str] = Field(default_factory=list)
    payload_class: str = Field(alias="payloadClass")
    # Live runtime status; optional values default to None.
    state: AgentState = AgentState.IDLE
    current_thread: Optional[str] = Field(None, alias="currentThread")
    queue_depth: int = Field(0, alias="queueDepth")
    last_activity: Optional[datetime] = Field(None, alias="lastActivity")
    message_count: int = Field(0, alias="messageCount")
class MessageInfo(CamelModel):
    """Message information for API response.

    ``from_id`` and ``to_id`` are exposed under the aliases ``from`` and
    ``to`` rather than the generated camelCase names.
    """

    id: str
    thread_id: str = Field(alias="threadId")
    # ``from`` is a Python keyword, hence the *_id attribute names.
    from_id: str = Field(alias="from")
    to_id: str = Field(alias="to")
    payload_type: str = Field(alias="payloadType")
    payload: Dict[str, Any] = Field(default_factory=dict)
    timestamp: datetime
    slot_index: Optional[int] = Field(None, alias="slotIndex")
class ThreadInfo(CamelModel):
    """Thread information for API response.

    A summary record: it carries a message count but not the messages.
    """

    id: str
    status: ThreadStatus = ThreadStatus.ACTIVE
    participants: List[str] = Field(default_factory=list)
    message_count: int = Field(0, alias="messageCount")
    created_at: datetime = Field(alias="createdAt")
    last_activity: Optional[datetime] = Field(None, alias="lastActivity")
    # Optional error description; defaults to None.
    error: Optional[str] = None
class OrganismInfo(CamelModel):
    """Organism overview for API response (status, uptime and counters)."""

    name: str
    status: OrganismStatus = OrganismStatus.RUNNING
    uptime_seconds: float = Field(0.0, alias="uptimeSeconds")
    agent_count: int = Field(0, alias="agentCount")
    active_threads: int = Field(0, alias="activeThreads")
    total_messages: int = Field(0, alias="totalMessages")
    identity_configured: bool = Field(False, alias="identityConfigured")
class OrganismConfig(CamelModel):
    """Sanitized organism configuration for API response.

    Every field except ``name`` has a default.
    """

    name: str
    port: int = 8765
    thread_scheduling: str = Field("breadth-first", alias="threadScheduling")
    max_concurrent_pipelines: int = Field(50, alias="maxConcurrentPipelines")
    max_concurrent_handlers: int = Field(20, alias="maxConcurrentHandlers")
    # Listener names.
    listeners: List[str] = Field(default_factory=list)
    # Note: Secrets like API keys are never exposed
# =============================================================================
# Request Models
# =============================================================================
class InjectRequest(CamelModel):
    """Request body for POST /inject."""

    # Name of the target agent.
    to: str
    payload: Dict[str, Any]
    # Optional thread to continue; defaults to None.
    thread_id: Optional[str] = Field(None, alias="threadId")
class InjectResponse(CamelModel):
    """Response body for POST /inject: ids of the thread and recorded message."""

    thread_id: str = Field(alias="threadId")
    message_id: str = Field(alias="messageId")
class SubscribeRequest(CamelModel):
    """WebSocket subscription filter.

    Every filter is optional and defaults to ``None``.
    """

    threads: Optional[List[str]] = None
    agents: Optional[List[str]] = None
    payload_types: Optional[List[str]] = Field(None, alias="payloadTypes")
    events: Optional[List[str]] = None
# =============================================================================
# WebSocket Event Models
# =============================================================================
class WSEvent(CamelModel):
    """Base WebSocket event; ``event`` names the event type."""

    event: str
class WSConnectedEvent(WSEvent):
    """Sent on WebSocket connection with full state snapshot."""

    event: str = "connected"
    organism: OrganismInfo
    agents: List[AgentInfo]
    threads: List[ThreadInfo]
class WSAgentStateEvent(WSEvent):
    """Agent state changed."""

    event: str = "agent_state"
    # Name of the agent whose state changed.
    agent: str
    state: AgentState
    current_thread: Optional[str] = Field(None, alias="currentThread")
class WSMessageEvent(WSEvent):
    """New message in the system; carries the full message record."""

    event: str = "message"
    message: MessageInfo
class WSThreadCreatedEvent(WSEvent):
    """New thread started; carries the thread summary."""

    event: str = "thread_created"
    thread: ThreadInfo
class WSThreadUpdatedEvent(WSEvent):
    """Thread status changed.

    Carries only the thread id, new status and message count.
    """

    event: str = "thread_updated"
    thread_id: str = Field(alias="threadId")
    status: ThreadStatus
    message_count: int = Field(alias="messageCount")
class WSErrorEvent(WSEvent):
    """Error occurred.

    ``thread_id`` and ``agent`` are optional; ``error`` and ``timestamp``
    are required.
    """

    event: str = "error"
    thread_id: Optional[str] = Field(None, alias="threadId")
    agent: Optional[str] = None
    error: str
    timestamp: datetime
# =============================================================================
# List Response Models
# =============================================================================
class AgentListResponse(CamelModel):
    """Response for GET /agents."""

    agents: List[AgentInfo]
    count: int
class ThreadListResponse(CamelModel):
    """Response for GET /threads (one page of results)."""

    threads: List[ThreadInfo]
    # Number of items in this page.
    count: int
    # Number of matching items before paging.
    total: int
    # Paging parameters echoed from the request.
    offset: int
    limit: int
class MessageListResponse(CamelModel):
    """Response for GET /messages or /threads/{id}/messages (one page)."""

    messages: List[MessageInfo]
    # Number of items in this page.
    count: int
    # Number of matching items before paging.
    total: int
    # Paging parameters echoed from the request.
    offset: int
    limit: int
class ErrorResponse(CamelModel):
    """Error response: a short error string plus optional detail."""

    error: str
    detail: Optional[str] = None

View file

@ -0,0 +1,515 @@
"""
state.py Server state manager for tracking organism runtime state.
Maintains real-time state for API queries:
- Agent states (idle, processing, waiting, error)
- Active threads and their participants
- Message counts and activity timestamps
This is the bridge between the StreamPump and the API.
"""
from __future__ import annotations
import asyncio
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set
from xml_pipeline.server.models import (
AgentInfo,
AgentState,
MessageInfo,
OrganismInfo,
OrganismStatus,
ThreadInfo,
ThreadStatus,
)
if TYPE_CHECKING:
from xml_pipeline.message_bus.stream_pump import StreamPump
from xml_pipeline.message_bus import (
PumpEvent,
MessageReceivedEvent,
MessageSentEvent,
AgentStateEvent,
ThreadEvent,
)
@dataclass
class AgentRuntimeState:
    """Runtime state for a single agent.

    Mutable internal record; its fields correspond one-to-one to those of
    the ``AgentInfo`` API model.
    """

    # Static description of the listener.
    name: str
    description: str
    is_agent: bool
    peers: List[str]
    payload_class: str
    # Live status.
    state: AgentState = AgentState.IDLE
    current_thread: Optional[str] = None
    last_activity: Optional[datetime] = None
    message_count: int = 0
    queue_depth: int = 0
@dataclass
class ThreadRuntimeState:
    """Runtime state for a single thread.

    ``participants`` is a set, so each id is stored once.
    ``created_at`` defaults to the current UTC time.
    """

    id: str
    status: ThreadStatus = ThreadStatus.ACTIVE
    participants: Set[str] = field(default_factory=set)
    message_count: int = 0
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_activity: Optional[datetime] = None
    error: Optional[str] = None
@dataclass
class MessageRecord:
    """Record of a message for history.

    All fields are required; there are no defaults.
    """

    id: str
    thread_id: str
    from_id: str
    to_id: str
    payload_type: str
    payload: Dict[str, Any]
    timestamp: datetime
    slot_index: int
class ServerState:
"""
Centralized state manager for the API server.
Tracks runtime state of agents, threads, and messages.
Provides methods for the API to query current state.
"""
def __init__(self, pump: StreamPump) -> None:
self.pump = pump
self._start_time = time.time()
self._status = OrganismStatus.STARTING
# Runtime state
self._agents: Dict[str, AgentRuntimeState] = {}
self._threads: Dict[str, ThreadRuntimeState] = {}
self._messages: List[MessageRecord] = []
self._message_count = 0
# Event subscribers (WebSocket connections)
self._subscribers: Set[Callable] = set()
# Lock for thread-safe updates
self._lock = asyncio.Lock()
# Initialize agent states from pump
self._init_agents()
# Subscribe to pump events
self._subscribe_to_pump()
def _init_agents(self) -> None:
"""Initialize agent states from pump listeners."""
for name, listener in self.pump.listeners.items():
self._agents[name] = AgentRuntimeState(
name=name,
description=listener.description,
is_agent=listener.is_agent,
peers=list(listener.peers),
payload_class=f"{listener.payload_class.__module__}.{listener.payload_class.__name__}",
)
def _subscribe_to_pump(self) -> None:
"""Subscribe to pump events for real-time state updates."""
from xml_pipeline.message_bus import (
PumpEvent,
MessageReceivedEvent,
MessageSentEvent,
AgentStateEvent as PumpAgentStateEvent,
ThreadEvent,
)
def handle_pump_event(event: "PumpEvent") -> None:
"""Handle pump events synchronously (schedules async updates)."""
if isinstance(event, MessageReceivedEvent):
# Schedule async message recording
asyncio.create_task(self._handle_message_received(event))
elif isinstance(event, MessageSentEvent):
# Schedule async message recording
asyncio.create_task(self._handle_message_sent(event))
elif isinstance(event, PumpAgentStateEvent):
# Schedule async agent state update
asyncio.create_task(self._handle_agent_state(event))
elif isinstance(event, ThreadEvent):
# Schedule async thread update
asyncio.create_task(self._handle_thread_event(event))
self.pump.subscribe_events(handle_pump_event)
async def _handle_message_received(self, event: "MessageReceivedEvent") -> None:
"""Handle message received by handler."""
# Convert payload to dict representation
payload_dict = {}
if hasattr(event.payload, '__dataclass_fields__'):
import dataclasses
payload_dict = dataclasses.asdict(event.payload)
elif isinstance(event.payload, dict):
payload_dict = event.payload
await self.record_message(
thread_id=event.thread_id,
from_id=event.from_id,
to_id=event.to_id,
payload_type=event.payload_type,
payload=payload_dict,
)
async def _handle_message_sent(self, event: "MessageSentEvent") -> None:
"""Handle message sent by handler."""
payload_dict = {}
if hasattr(event.payload, '__dataclass_fields__'):
import dataclasses
payload_dict = dataclasses.asdict(event.payload)
elif isinstance(event.payload, dict):
payload_dict = event.payload
await self.record_message(
thread_id=event.thread_id,
from_id=event.from_id,
to_id=event.to_id,
payload_type=event.payload_type,
payload=payload_dict,
)
async def _handle_agent_state(self, event: "AgentStateEvent") -> None:
"""Handle agent state change from pump."""
# Map pump state strings to AgentState enum
state_map = {
"idle": AgentState.IDLE,
"processing": AgentState.PROCESSING,
"waiting": AgentState.WAITING,
"error": AgentState.ERROR,
"paused": AgentState.PAUSED,
}
state = state_map.get(event.state, AgentState.IDLE)
await self.update_agent_state(event.agent_name, state, event.thread_id)
async def _handle_thread_event(self, event: "ThreadEvent") -> None:
"""Handle thread lifecycle event from pump."""
# Map status to ThreadStatus enum
status_map = {
"created": ThreadStatus.ACTIVE,
"active": ThreadStatus.ACTIVE,
"completed": ThreadStatus.COMPLETED,
"error": ThreadStatus.ERROR,
"killed": ThreadStatus.KILLED,
}
status = status_map.get(event.status, ThreadStatus.ACTIVE)
if event.status in ("completed", "error", "killed"):
await self.complete_thread(event.thread_id, status, event.error)
def set_running(self) -> None:
"""Mark organism as running."""
self._status = OrganismStatus.RUNNING
def set_stopping(self) -> None:
"""Mark organism as stopping."""
self._status = OrganismStatus.STOPPING
# =========================================================================
# Event Recording (called by pump hooks)
# =========================================================================
async def record_message(
self,
thread_id: str,
from_id: str,
to_id: str,
payload_type: str,
payload: Dict[str, Any],
) -> str:
"""Record a message and update related state."""
async with self._lock:
msg_id = str(uuid.uuid4())
now = datetime.now(timezone.utc)
self._message_count += 1
record = MessageRecord(
id=msg_id,
thread_id=thread_id,
from_id=from_id,
to_id=to_id,
payload_type=payload_type,
payload=payload,
timestamp=now,
slot_index=self._message_count,
)
self._messages.append(record)
# Update thread state
if thread_id not in self._threads:
self._threads[thread_id] = ThreadRuntimeState(
id=thread_id,
created_at=now,
)
thread = self._threads[thread_id]
thread.participants.add(from_id)
thread.participants.add(to_id)
thread.message_count += 1
thread.last_activity = now
# Update agent states
if from_id in self._agents:
self._agents[from_id].message_count += 1
self._agents[from_id].last_activity = now
# Notify subscribers
await self._notify_message(record)
return msg_id
async def update_agent_state(
self,
agent_name: str,
state: AgentState,
current_thread: Optional[str] = None,
) -> None:
"""Update an agent's processing state."""
async with self._lock:
if agent_name not in self._agents:
return
agent = self._agents[agent_name]
old_state = agent.state
agent.state = state
agent.current_thread = current_thread
agent.last_activity = datetime.now(timezone.utc)
if old_state != state:
await self._notify_agent_state(agent_name, state, current_thread)
async def complete_thread(
self,
thread_id: str,
status: ThreadStatus = ThreadStatus.COMPLETED,
error: Optional[str] = None,
) -> None:
"""Mark a thread as completed or errored."""
async with self._lock:
if thread_id not in self._threads:
return
thread = self._threads[thread_id]
thread.status = status
thread.error = error
thread.last_activity = datetime.now(timezone.utc)
await self._notify_thread_updated(thread)
# =========================================================================
# Query Methods (for API)
# =========================================================================
def get_organism_info(self) -> OrganismInfo:
"""Get organism overview."""
return OrganismInfo(
name=self.pump.config.name,
status=self._status,
uptime_seconds=time.time() - self._start_time,
agent_count=len(self._agents),
active_threads=sum(
1 for t in self._threads.values() if t.status == ThreadStatus.ACTIVE
),
total_messages=self._message_count,
identity_configured=self.pump.identity is not None,
)
def get_organism_config(self) -> Dict[str, Any]:
"""Get sanitized organism config (no secrets)."""
return {
"name": self.pump.config.name,
"port": self.pump.config.port,
"thread_scheduling": self.pump.config.thread_scheduling,
"max_concurrent_pipelines": self.pump.config.max_concurrent_pipelines,
"max_concurrent_handlers": self.pump.config.max_concurrent_handlers,
"listeners": list(self.pump.listeners.keys()),
}
def get_agents(self) -> List[AgentInfo]:
"""Get all agents."""
return [self._agent_to_info(a) for a in self._agents.values()]
def get_agent(self, name: str) -> Optional[AgentInfo]:
"""Get a single agent by name."""
agent = self._agents.get(name)
if agent:
return self._agent_to_info(agent)
return None
def get_agent_schema(self, name: str) -> Optional[str]:
"""Get agent's XSD schema as string."""
listener = self.pump.listeners.get(name)
if listener and listener.schema is not None:
from lxml import etree
return etree.tostring(listener.schema, encoding="unicode", pretty_print=True)
return None
def get_threads(
self,
status: Optional[ThreadStatus] = None,
agent: Optional[str] = None,
limit: int = 50,
offset: int = 0,
) -> tuple[List[ThreadInfo], int]:
"""Get threads with optional filtering."""
threads = list(self._threads.values())
# Filter
if status:
threads = [t for t in threads if t.status == status]
if agent:
threads = [t for t in threads if agent in t.participants]
# Sort by last activity (most recent first)
threads.sort(key=lambda t: t.last_activity or t.created_at, reverse=True)
total = len(threads)
threads = threads[offset : offset + limit]
return [self._thread_to_info(t) for t in threads], total
def get_thread(self, thread_id: str) -> Optional[ThreadInfo]:
"""Get a single thread by ID."""
thread = self._threads.get(thread_id)
if thread:
return self._thread_to_info(thread)
return None
def get_messages(
self,
thread_id: Optional[str] = None,
agent: Optional[str] = None,
limit: int = 50,
offset: int = 0,
) -> tuple[List[MessageInfo], int]:
"""Get messages with optional filtering."""
messages = self._messages.copy()
# Filter
if thread_id:
messages = [m for m in messages if m.thread_id == thread_id]
if agent:
messages = [m for m in messages if m.from_id == agent or m.to_id == agent]
# Sort by timestamp (most recent first)
messages.sort(key=lambda m: m.timestamp, reverse=True)
total = len(messages)
messages = messages[offset : offset + limit]
return [self._message_to_info(m) for m in messages], total
# =========================================================================
# Subscription Management
# =========================================================================
def subscribe(self, callback: Callable) -> None:
"""Subscribe to state events."""
self._subscribers.add(callback)
def unsubscribe(self, callback: Callable) -> None:
"""Unsubscribe from state events."""
self._subscribers.discard(callback)
async def _notify_message(self, record: MessageRecord) -> None:
"""Notify subscribers of new message."""
event = {
"event": "message",
"message": self._message_to_info(record).model_dump(by_alias=True),
}
await self._broadcast(event)
async def _notify_agent_state(
self,
agent_name: str,
state: AgentState,
current_thread: Optional[str],
) -> None:
"""Notify subscribers of agent state change."""
event = {
"event": "agent_state",
"agent": agent_name,
"state": state.value,
"currentThread": current_thread,
}
await self._broadcast(event)
async def _notify_thread_updated(self, thread: ThreadRuntimeState) -> None:
"""Notify subscribers of thread update."""
event = {
"event": "thread_updated",
"threadId": thread.id,
"status": thread.status.value,
"messageCount": thread.message_count,
}
await self._broadcast(event)
async def _broadcast(self, event: Dict[str, Any]) -> None:
"""Broadcast event to all subscribers."""
for callback in list(self._subscribers):
try:
await callback(event)
except Exception:
# Remove failed subscribers
self._subscribers.discard(callback)
# =========================================================================
# Conversion Helpers
# =========================================================================
def _agent_to_info(self, agent: AgentRuntimeState) -> AgentInfo:
"""Convert runtime state to API model."""
return AgentInfo(
name=agent.name,
description=agent.description,
is_agent=agent.is_agent,
peers=agent.peers,
payload_class=agent.payload_class,
state=agent.state,
current_thread=agent.current_thread,
queue_depth=agent.queue_depth,
last_activity=agent.last_activity,
message_count=agent.message_count,
)
def _thread_to_info(self, thread: ThreadRuntimeState) -> ThreadInfo:
"""Convert runtime state to API model."""
return ThreadInfo(
id=thread.id,
status=thread.status,
participants=list(thread.participants),
message_count=thread.message_count,
created_at=thread.created_at,
last_activity=thread.last_activity,
error=thread.error,
)
def _message_to_info(self, record: MessageRecord) -> MessageInfo:
"""Convert record to API model."""
return MessageInfo(
id=record.id,
thread_id=record.thread_id,
from_id=record.from_id,
to_id=record.to_id,
payload_type=record.payload_type,
payload=record.payload,
timestamp=record.timestamp,
slot_index=record.slot_index,
)

View file

@ -0,0 +1,316 @@
"""
websocket.py WebSocket endpoints for AgentServer.
Provides:
- /ws Main control channel with state snapshot and real-time events
- /ws/messages Dedicated message log stream with filtering
"""
from __future__ import annotations
import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from xml_pipeline.server.models import (
SubscribeRequest,
WSConnectedEvent,
)
if TYPE_CHECKING:
from xml_pipeline.server.state import ServerState
class ConnectionManager:
    """Manages WebSocket connections and subscriptions.

    Each connection has a ``SubscribeRequest`` that decides which events it
    receives; a fresh connection gets an empty request, i.e. no filtering.
    """

    def __init__(self) -> None:
        self.active_connections: Set[WebSocket] = set()
        self.subscriptions: Dict[WebSocket, SubscribeRequest] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.subscriptions[websocket] = SubscribeRequest()  # Default: all events

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket connection."""
        self.active_connections.discard(websocket)
        self.subscriptions.pop(websocket, None)

    def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None:
        """Update subscription filters for a connection."""
        self.subscriptions[websocket] = sub

    def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool:
        """Check if an event should be sent to this connection based on filters.

        An event passes when it satisfies every filter that is set. A filter
        is skipped when the event does not carry the corresponding value.
        For ``message`` events the thread id, sender/receiver and payload
        type are read from the nested ``message`` dict.
        """
        sub = self.subscriptions.get(websocket)
        if sub is None:
            return True  # No filters = send all

        event_type = event.get("event", "")

        # Filter by event type
        if sub.events and event_type not in sub.events:
            return False

        is_message = event_type == "message"
        message = event.get("message", {}) if is_message else {}

        # Filter by thread. Message events keep the thread id inside the
        # nested message, so fall back to it.
        thread_id = (
            event.get("thread_id")
            or event.get("threadId")
            or message.get("threadId")
            or message.get("thread_id")
        )
        if sub.threads and thread_id and thread_id not in sub.threads:
            return False

        # Filter by agent
        agent = event.get("agent")
        if sub.agents and agent and agent not in sub.agents:
            return False

        if is_message:
            # For message events, either endpoint may match the agent filter
            if sub.agents:
                from_id = message.get("from") or message.get("fromId")
                to_id = message.get("to") or message.get("toId")
                if from_id not in sub.agents and to_id not in sub.agents:
                    return False

            # Filter by payload type
            if sub.payload_types:
                payload_type = message.get("payloadType")
                if payload_type and payload_type not in sub.payload_types:
                    return False

        return True

    async def broadcast(self, event: Dict[str, Any]) -> None:
        """Broadcast an event to all connections that match their subscription.

        Connections whose send fails are removed afterwards.
        """
        disconnected: List[WebSocket] = []

        # Iterate over a snapshot: clients can connect or disconnect while a
        # send is awaited, and mutating a set during iteration raises.
        for websocket in tuple(self.active_connections):
            if not self.should_send(websocket, event):
                continue
            try:
                await websocket.send_json(event)
            except Exception:
                disconnected.append(websocket)

        # Clean up disconnected clients
        for ws in disconnected:
            self.disconnect(ws)
class MessageStreamManager:
    """Manages WebSocket connections for /ws/messages endpoint.

    Each connection has a plain-dict filter; an empty filter means every
    message is delivered.
    """

    def __init__(self) -> None:
        self.active_connections: Set[WebSocket] = set()
        self.filters: Dict[WebSocket, Dict[str, Any]] = {}

    async def connect(self, websocket: WebSocket) -> None:
        """Accept a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.add(websocket)
        self.filters[websocket] = {}  # Default: all messages

    def disconnect(self, websocket: WebSocket) -> None:
        """Remove a WebSocket connection."""
        self.active_connections.discard(websocket)
        self.filters.pop(websocket, None)

    def set_filter(self, websocket: WebSocket, filter_config: Dict[str, Any]) -> None:
        """Update message filter for a connection."""
        self.filters[websocket] = filter_config

    def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
        """Check if a message should be sent to this connection.

        Recognised filter keys are ``agents``, ``threads`` and
        ``payload_types`` (``payloadTypes`` is accepted as well). A message
        must satisfy every filter that is set.
        """
        flt = self.filters.get(websocket, {})
        if not flt:
            return True

        # Filter by agents (sender or receiver may match)
        agents = flt.get("agents", [])
        if agents:
            from_id = message.get("from") or message.get("fromId")
            to_id = message.get("to") or message.get("toId")
            if from_id not in agents and to_id not in agents:
                return False

        # Filter by threads
        threads = flt.get("threads", [])
        if threads:
            thread_id = message.get("thread_id") or message.get("threadId")
            if thread_id not in threads:
                return False

        # Filter by payload types; accept both key spellings, as is already
        # done for the message fields.
        payload_types = flt.get("payload_types") or flt.get("payloadTypes") or []
        if payload_types:
            payload_type = message.get("payloadType") or message.get("payload_type")
            if payload_type not in payload_types:
                return False

        return True

    async def broadcast_message(self, message: Dict[str, Any]) -> None:
        """Broadcast a message to all filtered connections.

        Connections whose send fails are removed afterwards.
        """
        disconnected: List[WebSocket] = []

        # Iterate over a snapshot: the set can change while a send is awaited.
        for websocket in tuple(self.active_connections):
            if not self.should_send(websocket, message):
                continue
            try:
                await websocket.send_json(message)
            except Exception:
                disconnected.append(websocket)

        for ws in disconnected:
            self.disconnect(ws)
def create_websocket_router(state: "ServerState") -> APIRouter:
"""Create WebSocket router with state dependency."""
router = APIRouter()
manager = ConnectionManager()
message_manager = MessageStreamManager()
# Subscribe state to WebSocket broadcasting
async def on_state_event(event: Dict[str, Any]) -> None:
    """Relay a state event to the control channel and, for messages, the message stream."""
    await manager.broadcast(event)

    # Only message events are mirrored onto /ws/messages.
    if event.get("event") != "message":
        return
    await message_manager.broadcast_message(event.get("message", {}))
state.subscribe(on_state_event)
@router.websocket("/ws")
async def websocket_control(websocket: WebSocket) -> None:
    """
    Main control channel WebSocket.

    On connect, sends full state snapshot. Then pushes events as they occur.
    Accepts commands: subscribe, inject.

    Commands are JSON objects with a ``cmd`` key. Filter and thread keys are
    accepted in snake_case (``payload_types``, ``thread_id``) and in
    camelCase (``payloadTypes``, ``threadId``). Malformed commands are
    answered with an ``error`` event and the connection stays open.
    """
    import uuid  # local import, done once per connection

    async def send_error(text: str) -> None:
        """Send an error event to this client."""
        await websocket.send_json({"event": "error", "error": text})

    await manager.connect(websocket)

    try:
        # Send connected event with state snapshot
        threads, _ = state.get_threads(limit=100)
        connected_event = WSConnectedEvent(
            organism=state.get_organism_info(),
            agents=state.get_agents(),
            threads=threads,
        )
        # mode="json" renders datetimes as strings; a plain dump would hand
        # datetime objects to send_json, which cannot encode them.
        await websocket.send_json(
            connected_event.model_dump(by_alias=True, mode="json")
        )

        # Listen for commands
        while True:
            try:
                data = await websocket.receive_json()
            except Exception:
                # Connection closed or invalid data
                break

            if not isinstance(data, dict):
                await send_error("Commands must be JSON objects")
                continue

            cmd = data.get("cmd", "")

            if cmd == "subscribe":
                # Update subscription filters
                try:
                    sub = SubscribeRequest(
                        threads=data.get("threads"),
                        agents=data.get("agents"),
                        payload_types=data.get(
                            "payload_types", data.get("payloadTypes")
                        ),
                        events=data.get("events"),
                    )
                except ValueError as exc:
                    # pydantic validation errors are ValueError subclasses
                    await send_error(f"Invalid subscription: {exc}")
                    continue
                manager.set_subscription(websocket, sub)
                await websocket.send_json({"event": "subscribed", "filters": data})

            elif cmd == "inject":
                # Inject a message (same as REST /inject)
                target = data.get("to")
                payload = data.get("payload", {})
                thread_id = data.get("thread_id") or data.get("threadId")

                if not target or not isinstance(target, str):
                    await send_error("Missing 'to' field")
                    continue

                if not isinstance(payload, dict):
                    await send_error("'payload' must be an object")
                    continue

                if state.get_agent(target) is None:
                    await send_error(f"Unknown agent: {target}")
                    continue

                thread_id = thread_id or str(uuid.uuid4())
                payload_type = next(iter(payload), "Payload")

                msg_id = await state.record_message(
                    thread_id=thread_id,
                    from_id="api",
                    to_id=target,
                    payload_type=payload_type,
                    payload=payload,
                )

                await websocket.send_json(
                    {
                        "event": "injected",
                        "thread_id": thread_id,
                        "message_id": msg_id,
                    }
                )

            else:
                await send_error(f"Unknown command: {cmd}")

    except WebSocketDisconnect:
        pass
    finally:
        manager.disconnect(websocket)
@router.websocket("/ws/messages")
async def websocket_messages(websocket: WebSocket) -> None:
"""
Dedicated message log stream.
Streams all messages flowing through the organism.
Clients can filter by agents, threads, and payload types.
"""
await message_manager.connect(websocket)
try:
while True:
try:
data = await websocket.receive_json()
except Exception:
break
cmd = data.get("cmd", "")
if cmd == "subscribe":
# Update filter
filter_config = data.get("filter", {})
message_manager.set_filter(websocket, filter_config)
await websocket.send_json({"event": "subscribed", "filter": filter_config})
else:
await websocket.send_json(
{"event": "error", "error": f"Unknown command: {cmd}"}
)
except WebSocketDisconnect:
pass
finally:
message_manager.disconnect(websocket)
return router