diff --git a/pyproject.toml b/pyproject.toml
index 7fc5ff4..ffe4d60 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -80,6 +80,13 @@ search = ["duckduckgo-search>=6.0"] # Web search tool
# Console example (optional, for interactive use)
console = ["prompt_toolkit>=3.0"]
+# API server (FastAPI + WebSocket)
+server = [
+ "fastapi>=0.109",
+ "uvicorn[standard]>=0.27",
+ "websockets>=12.0",
+]
+
# All LLM providers
llm = ["xml-pipeline[anthropic,openai]"]
@@ -87,7 +94,7 @@ llm = ["xml-pipeline[anthropic,openai]"]
tools = ["xml-pipeline[redis,search]"]
# Everything (for local development)
-all = ["xml-pipeline[llm,tools,console]"]
+all = ["xml-pipeline[llm,tools,console,server]"]
# Testing
test = [
diff --git a/tests/test_server.py b/tests/test_server.py
new file mode 100644
index 0000000..7b18476
--- /dev/null
+++ b/tests/test_server.py
@@ -0,0 +1,510 @@
+"""
+Tests for the AgentServer API.
+
+Tests the REST API endpoints and WebSocket connections.
+"""
+
+import asyncio
+import pytest
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
+from unittest.mock import MagicMock, AsyncMock, patch
+
+# Skip all tests if FastAPI not available
+pytest.importorskip("fastapi")
+
+from fastapi.testclient import TestClient
+
+from xml_pipeline.server.models import (
+ AgentState,
+ AgentInfo,
+ ThreadStatus,
+ ThreadInfo,
+ MessageInfo,
+ OrganismInfo,
+ OrganismStatus,
+)
+from xml_pipeline.server.state import ServerState, AgentRuntimeState, ThreadRuntimeState
+from xml_pipeline.server.api import create_router
+from xml_pipeline.server.app import create_app
+
+
+# ============================================================================
+# Mock StreamPump
+# ============================================================================
+
+@dataclass
+class MockListener:
+ """Mock listener for testing."""
+ name: str
+ description: str
+ is_agent: bool = False
+ peers: List[str] = None
+ payload_class: type = None
+ schema: Any = None
+
+ def __post_init__(self):
+ if self.peers is None:
+ self.peers = []
+ if self.payload_class is None:
+ self.payload_class = type("MockPayload", (), {"__module__": "test", "__name__": "MockPayload"})
+
+
+@dataclass
+class MockConfig:
+ """Mock config for testing."""
+ name: str = "test-organism"
+ port: int = 8765
+ thread_scheduling: str = "breadth-first"
+ max_concurrent_pipelines: int = 50
+ max_concurrent_handlers: int = 20
+
+
+class MockStreamPump:
+ """Mock StreamPump for testing."""
+
+ def __init__(self):
+ self.config = MockConfig()
+ self.identity = None
+ self.listeners: Dict[str, MockListener] = {
+ "greeter": MockListener(
+ name="greeter",
+ description="Greeting agent",
+ is_agent=True,
+ peers=["shouter"],
+ ),
+ "shouter": MockListener(
+ name="shouter",
+ description="Shouting handler",
+ is_agent=False,
+ ),
+ }
+ self._event_callbacks = []
+
+ def subscribe_events(self, callback):
+ self._event_callbacks.append(callback)
+
+ def unsubscribe_events(self, callback):
+ if callback in self._event_callbacks:
+ self._event_callbacks.remove(callback)
+
+
+# ============================================================================
+# Test Fixtures
+# ============================================================================
+
+@pytest.fixture
+def mock_pump():
+ """Create a mock StreamPump."""
+ return MockStreamPump()
+
+
+@pytest.fixture
+def server_state(mock_pump):
+ """Create ServerState with mock pump."""
+ return ServerState(mock_pump)
+
+
+@pytest.fixture
+def test_client(mock_pump):
+ """Create FastAPI test client."""
+ app = create_app(mock_pump)
+ return TestClient(app)
+
+
+# ============================================================================
+# Test Models
+# ============================================================================
+
+class TestModels:
+ """Test Pydantic model serialization."""
+
+ def test_agent_info_camel_case(self):
+ """Test AgentInfo serializes to camelCase."""
+ agent = AgentInfo(
+ name="greeter",
+ description="Test agent",
+ is_agent=True,
+ peers=["shouter"],
+ payload_class="test.Greeting",
+ state=AgentState.IDLE,
+ current_thread=None,
+ queue_depth=0,
+ message_count=5,
+ )
+ data = agent.model_dump(by_alias=True)
+ assert "isAgent" in data
+ assert "payloadClass" in data
+ assert "currentThread" in data
+ assert "queueDepth" in data
+ assert "messageCount" in data
+
+ def test_thread_info_camel_case(self):
+ """Test ThreadInfo serializes to camelCase."""
+ from datetime import datetime, timezone
+ thread = ThreadInfo(
+ id="test-uuid",
+ status=ThreadStatus.ACTIVE,
+ participants=["greeter", "shouter"],
+ message_count=3,
+ created_at=datetime.now(timezone.utc),
+ )
+ data = thread.model_dump(by_alias=True)
+ assert "messageCount" in data
+ assert "createdAt" in data
+ assert "lastActivity" in data
+
+ def test_organism_info_camel_case(self):
+ """Test OrganismInfo serializes to camelCase."""
+ info = OrganismInfo(
+ name="test-organism",
+ status=OrganismStatus.RUNNING,
+ uptime_seconds=3600.0,
+ agent_count=2,
+ active_threads=1,
+ total_messages=10,
+ identity_configured=False,
+ )
+ data = info.model_dump(by_alias=True)
+ assert "uptimeSeconds" in data
+ assert "agentCount" in data
+ assert "activeThreads" in data
+ assert "totalMessages" in data
+ assert "identityConfigured" in data
+
+
+# ============================================================================
+# Test ServerState
+# ============================================================================
+
+class TestServerState:
+ """Test ServerState functionality."""
+
+ def test_init_agents_from_pump(self, server_state):
+ """Test agents are initialized from pump listeners."""
+ assert "greeter" in server_state._agents
+ assert "shouter" in server_state._agents
+ assert server_state._agents["greeter"].is_agent is True
+ assert server_state._agents["shouter"].is_agent is False
+
+ def test_get_organism_info(self, server_state):
+ """Test getting organism info."""
+ info = server_state.get_organism_info()
+ assert info.name == "test-organism"
+ assert info.agent_count == 2
+ assert info.status == OrganismStatus.STARTING
+
+ def test_set_running(self, server_state):
+ """Test setting running status."""
+ server_state.set_running()
+ info = server_state.get_organism_info()
+ assert info.status == OrganismStatus.RUNNING
+
+ def test_get_agents(self, server_state):
+ """Test getting all agents."""
+ agents = server_state.get_agents()
+ assert len(agents) == 2
+ names = [a.name for a in agents]
+ assert "greeter" in names
+ assert "shouter" in names
+
+ def test_get_agent(self, server_state):
+ """Test getting single agent."""
+ agent = server_state.get_agent("greeter")
+ assert agent is not None
+ assert agent.name == "greeter"
+ assert agent.is_agent is True
+
+ def test_get_agent_not_found(self, server_state):
+ """Test getting non-existent agent."""
+ agent = server_state.get_agent("nonexistent")
+ assert agent is None
+
+ @pytest.mark.asyncio
+ async def test_record_message(self, server_state):
+ """Test recording a message."""
+ msg_id = await server_state.record_message(
+ thread_id="test-thread",
+ from_id="greeter",
+ to_id="shouter",
+ payload_type="GreetingResponse",
+ payload={"message": "Hello!"},
+ )
+ assert msg_id is not None
+
+ # Check message was recorded
+ messages, total = server_state.get_messages(thread_id="test-thread")
+ assert total == 1
+ assert messages[0].from_id == "greeter"
+ assert messages[0].to_id == "shouter"
+
+ @pytest.mark.asyncio
+ async def test_update_agent_state(self, server_state):
+ """Test updating agent state."""
+ await server_state.update_agent_state("greeter", AgentState.PROCESSING, "thread-1")
+
+ agent = server_state.get_agent("greeter")
+ assert agent.state == AgentState.PROCESSING
+ assert agent.current_thread == "thread-1"
+
+ @pytest.mark.asyncio
+ async def test_complete_thread(self, server_state):
+ """Test completing a thread."""
+ # First record a message to create the thread
+ await server_state.record_message(
+ thread_id="test-thread",
+ from_id="greeter",
+ to_id="shouter",
+ payload_type="Test",
+ payload={},
+ )
+
+ # Complete the thread
+ await server_state.complete_thread("test-thread", ThreadStatus.COMPLETED)
+
+ thread = server_state.get_thread("test-thread")
+ assert thread.status == ThreadStatus.COMPLETED
+
+ def test_get_threads_with_filter(self, server_state):
+ """Test filtering threads by status."""
+ # No threads initially
+ threads, total = server_state.get_threads(status=ThreadStatus.ACTIVE)
+ assert total == 0
+
+ def test_get_organism_config(self, server_state):
+ """Test getting organism config."""
+ config = server_state.get_organism_config()
+ assert config["name"] == "test-organism"
+ assert config["port"] == 8765
+
+
+# ============================================================================
+# Test REST API
+# ============================================================================
+
+class TestRestAPI:
+ """Test REST API endpoints."""
+
+ def test_health_check(self, test_client):
+ """Test health check endpoint."""
+ response = test_client.get("/health")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["status"] == "healthy"
+ assert "organism" in data
+
+ def test_get_organism(self, test_client):
+ """Test GET /api/v1/organism."""
+ response = test_client.get("/api/v1/organism")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "test-organism"
+ assert "agentCount" in data
+
+ def test_get_organism_config(self, test_client):
+ """Test GET /api/v1/organism/config."""
+ response = test_client.get("/api/v1/organism/config")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "test-organism"
+
+ def test_list_agents(self, test_client):
+ """Test GET /api/v1/agents."""
+ response = test_client.get("/api/v1/agents")
+ assert response.status_code == 200
+ data = response.json()
+ assert "agents" in data
+ assert data["count"] == 2
+
+ def test_get_agent(self, test_client):
+ """Test GET /api/v1/agents/{name}."""
+ response = test_client.get("/api/v1/agents/greeter")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "greeter"
+ assert data["isAgent"] is True
+
+ def test_get_agent_not_found(self, test_client):
+ """Test GET /api/v1/agents/{name} with non-existent agent."""
+ response = test_client.get("/api/v1/agents/nonexistent")
+ assert response.status_code == 404
+
+ def test_get_agent_config(self, test_client):
+ """Test GET /api/v1/agents/{name}/config."""
+ response = test_client.get("/api/v1/agents/greeter/config")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["name"] == "greeter"
+ assert "isAgent" in data
+
+ def test_list_threads(self, test_client):
+ """Test GET /api/v1/threads."""
+ response = test_client.get("/api/v1/threads")
+ assert response.status_code == 200
+ data = response.json()
+ assert "threads" in data
+ assert "count" in data
+ assert "total" in data
+
+ def test_list_threads_with_invalid_status(self, test_client):
+ """Test GET /api/v1/threads with invalid status filter."""
+ response = test_client.get("/api/v1/threads?status=invalid")
+ assert response.status_code == 400
+
+ def test_list_messages(self, test_client):
+ """Test GET /api/v1/messages."""
+ response = test_client.get("/api/v1/messages")
+ assert response.status_code == 200
+ data = response.json()
+ assert "messages" in data
+ assert "count" in data
+
+ def test_inject_message(self, test_client):
+ """Test POST /api/v1/inject."""
+ response = test_client.post(
+ "/api/v1/inject",
+ json={
+ "to": "greeter",
+ "payload": {"name": "Dan"},
+ },
+ )
+ assert response.status_code == 200
+ data = response.json()
+ assert "threadId" in data
+ assert "messageId" in data
+
+ def test_inject_message_unknown_agent(self, test_client):
+ """Test POST /api/v1/inject with unknown agent."""
+ response = test_client.post(
+ "/api/v1/inject",
+ json={
+ "to": "nonexistent",
+ "payload": {"name": "Dan"},
+ },
+ )
+ assert response.status_code == 400
+
+ def test_pause_agent(self, test_client):
+ """Test POST /api/v1/agents/{name}/pause."""
+ response = test_client.post("/api/v1/agents/greeter/pause")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["success"] is True
+ assert data["state"] == "paused"
+
+ def test_resume_agent(self, test_client):
+ """Test POST /api/v1/agents/{name}/resume."""
+ response = test_client.post("/api/v1/agents/greeter/resume")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["success"] is True
+ assert data["state"] == "idle"
+
+ def test_stop_organism(self, test_client):
+ """Test POST /api/v1/organism/stop."""
+ response = test_client.post("/api/v1/organism/stop")
+ assert response.status_code == 200
+ data = response.json()
+ assert data["success"] is True
+
+
+# ============================================================================
+# Test WebSocket
+# ============================================================================
+
+class TestWebSocket:
+ """Test WebSocket endpoints."""
+
+ def test_websocket_connect(self, test_client):
+ """Test WebSocket connection."""
+ with test_client.websocket_connect("/ws") as websocket:
+ # Should receive connected event with state snapshot
+ data = websocket.receive_json()
+ assert data["event"] == "connected"
+ assert "organism" in data
+ assert "agents" in data
+ assert "threads" in data
+
+ def test_websocket_subscribe(self, test_client):
+ """Test WebSocket subscribe command."""
+ with test_client.websocket_connect("/ws") as websocket:
+ # Receive initial connected event
+ websocket.receive_json()
+
+ # Send subscribe command
+ websocket.send_json({
+ "cmd": "subscribe",
+ "agents": ["greeter"],
+ "events": ["message"],
+ })
+
+ # Should receive subscribed confirmation
+ data = websocket.receive_json()
+ assert data["event"] == "subscribed"
+
+ def test_websocket_inject(self, test_client):
+ """Test WebSocket inject command."""
+ with test_client.websocket_connect("/ws") as websocket:
+ # Receive initial connected event
+ websocket.receive_json()
+
+ # Send inject command
+ websocket.send_json({
+ "cmd": "inject",
+ "to": "greeter",
+ "payload": {"name": "Dan"},
+ })
+
+ # Should receive injected confirmation
+ data = websocket.receive_json()
+ assert data["event"] == "injected"
+ assert "thread_id" in data
+ assert "message_id" in data
+
+ def test_websocket_inject_unknown_agent(self, test_client):
+ """Test WebSocket inject with unknown agent."""
+ with test_client.websocket_connect("/ws") as websocket:
+ # Receive initial connected event
+ websocket.receive_json()
+
+ # Send inject command to unknown agent
+ websocket.send_json({
+ "cmd": "inject",
+ "to": "nonexistent",
+ "payload": {},
+ })
+
+ # Should receive error
+ data = websocket.receive_json()
+ assert data["event"] == "error"
+
+ def test_websocket_unknown_command(self, test_client):
+ """Test WebSocket with unknown command."""
+ with test_client.websocket_connect("/ws") as websocket:
+ # Receive initial connected event
+ websocket.receive_json()
+
+ # Send unknown command
+ websocket.send_json({
+ "cmd": "unknown_command",
+ })
+
+ # Should receive error
+ data = websocket.receive_json()
+ assert data["event"] == "error"
+
+ def test_websocket_messages_stream(self, test_client):
+ """Test WebSocket messages stream endpoint."""
+ with test_client.websocket_connect("/ws/messages") as websocket:
+ # Send subscribe command
+ websocket.send_json({
+ "cmd": "subscribe",
+ "filter": {
+ "agents": ["greeter"],
+ },
+ })
+
+ # Should receive subscribed confirmation
+ data = websocket.receive_json()
+ assert data["event"] == "subscribed"
+ assert "filter" in data
diff --git a/xml_pipeline/cli.py b/xml_pipeline/cli.py
index 53497c0..5dd4b9d 100644
--- a/xml_pipeline/cli.py
+++ b/xml_pipeline/cli.py
@@ -3,6 +3,7 @@ xml-pipeline CLI entry point.
Usage:
xml-pipeline run [config.yaml] Run an organism
+ xml-pipeline serve [config.yaml] Run organism with API server
xml-pipeline init [name] Create new organism config
xml-pipeline check [config.yaml] Validate config without running
xml-pipeline version Show version info
@@ -37,6 +38,68 @@ def cmd_run(args: argparse.Namespace) -> int:
return 1
+def cmd_serve(args: argparse.Namespace) -> int:
+ """Run an organism with the AgentServer API."""
+ try:
+ import uvicorn
+ except ImportError:
+ print("Error: uvicorn not installed.", file=sys.stderr)
+ print("Install with: pip install xml-pipeline[server]", file=sys.stderr)
+ return 1
+
+ from xml_pipeline.message_bus import bootstrap
+
+ config_path = Path(args.config)
+ if not config_path.exists():
+ print(f"Error: Config file not found: {config_path}", file=sys.stderr)
+ return 1
+
+ async def run_with_server():
+ """Bootstrap pump and run with server."""
+ from xml_pipeline.server import create_app
+
+ # Bootstrap the pump
+ pump = await bootstrap(str(config_path))
+
+ # Create FastAPI app
+ app = create_app(pump)
+
+ # Run uvicorn
+ config = uvicorn.Config(
+ app,
+ host=args.host,
+ port=args.port,
+ log_level="info",
+ )
+ server = uvicorn.Server(config)
+
+ # Run pump and server concurrently
+ pump_task = asyncio.create_task(pump.run())
+
+ try:
+ await server.serve()
+ finally:
+ await pump.shutdown()
+ pump_task.cancel()
+ try:
+ await pump_task
+ except asyncio.CancelledError:
+ pass
+
+ try:
+ print(f"Starting AgentServer on http://{args.host}:{args.port}")
+ print(f" API docs: http://{args.host}:{args.port}/docs")
+ print(f" WebSocket: ws://{args.host}:{args.port}/ws")
+ asyncio.run(run_with_server())
+ return 0
+ except KeyboardInterrupt:
+ print("\nShutdown requested.")
+ return 0
+ except Exception as e:
+ print(f"Error: {e}", file=sys.stderr)
+ return 1
+
+
def cmd_init(args: argparse.Namespace) -> int:
"""Initialize a new organism config."""
from xml_pipeline.config.template import create_organism_template
@@ -149,6 +212,13 @@ def main() -> int:
run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
run_parser.set_defaults(func=cmd_run)
+ # serve
+ serve_parser = subparsers.add_parser("serve", help="Run organism with API server")
+ serve_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
+ serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind (default: 0.0.0.0)")
+ serve_parser.add_argument("--port", "-p", type=int, default=8080, help="Port to listen on (default: 8080)")
+ serve_parser.set_defaults(func=cmd_serve)
+
# init
init_parser = subparsers.add_parser("init", help="Create new organism config")
init_parser.add_argument("name", nargs="?", help="Organism name")
diff --git a/xml_pipeline/message_bus/__init__.py b/xml_pipeline/message_bus/__init__.py
index 8aaffa8..51fe135 100644
--- a/xml_pipeline/message_bus/__init__.py
+++ b/xml_pipeline/message_bus/__init__.py
@@ -33,6 +33,12 @@ from xml_pipeline.message_bus.stream_pump import (
get_stream_pump,
set_stream_pump,
reset_stream_pump,
+ # Event hooks
+ PumpEvent,
+ MessageReceivedEvent,
+ MessageSentEvent,
+ AgentStateEvent,
+ ThreadEvent,
)
from xml_pipeline.message_bus.message_state import (
@@ -71,6 +77,12 @@ __all__ = [
"get_stream_pump",
"set_stream_pump",
"reset_stream_pump",
+ # Event hooks
+ "PumpEvent",
+ "MessageReceivedEvent",
+ "MessageSentEvent",
+ "AgentStateEvent",
+ "ThreadEvent",
# Message state
"MessageState",
"HandlerMetadata",
diff --git a/xml_pipeline/message_bus/stream_pump.py b/xml_pipeline/message_bus/stream_pump.py
index 1e7adef..039da62 100644
--- a/xml_pipeline/message_bus/stream_pump.py
+++ b/xml_pipeline/message_bus/stream_pump.py
@@ -49,6 +49,56 @@ from xml_pipeline.memory import get_context_buffer
pump_logger = logging.getLogger(__name__)
+# ============================================================================
+# Event Hooks
+# ============================================================================
+
+@dataclass
+class PumpEvent:
+ """Base class for pump events."""
+ pass
+
+
+@dataclass
+class MessageReceivedEvent(PumpEvent):
+ """Fired when a message is received by a handler."""
+ thread_id: str
+ from_id: str
+ to_id: str
+ payload_type: str
+ payload: Any
+
+
+@dataclass
+class MessageSentEvent(PumpEvent):
+ """Fired when a handler sends a response."""
+ thread_id: str
+ from_id: str
+ to_id: str
+ payload_type: str
+ payload: Any
+
+
+@dataclass
+class AgentStateEvent(PumpEvent):
+ """Fired when an agent's processing state changes."""
+ agent_name: str
+ state: str # "idle", "processing", "waiting", "error"
+ thread_id: Optional[str] = None
+
+
+@dataclass
+class ThreadEvent(PumpEvent):
+ """Fired when a thread is created or completed."""
+ thread_id: str
+ status: str # "created", "active", "completed", "error", "killed"
+ participants: List[str] = field(default_factory=list)
+ error: Optional[str] = None
+
+
+EventCallback = Callable[[PumpEvent], None]
+
+
# ============================================================================
# Configuration (same as before)
# ============================================================================
@@ -233,6 +283,9 @@ class StreamPump:
# Shutdown control
self._running = False
+ # Event hooks for external observers (ServerState, etc.)
+ self._event_callbacks: List[EventCallback] = []
+
# Process pool for cpu_bound handlers
self._process_pool: Optional[ProcessPoolExecutor] = None
if config.process_pool_enabled:
@@ -256,6 +309,27 @@ class StreamPump:
self._shared_backend = get_shared_backend(backend_config)
pump_logger.info(f"Shared backend: {config.backend_type}")
+ # ------------------------------------------------------------------
+ # Event Hooks
+ # ------------------------------------------------------------------
+
+ def subscribe_events(self, callback: EventCallback) -> None:
+ """Subscribe to pump events (message flow, agent state, thread lifecycle)."""
+ self._event_callbacks.append(callback)
+
+ def unsubscribe_events(self, callback: EventCallback) -> None:
+ """Unsubscribe from pump events."""
+ if callback in self._event_callbacks:
+ self._event_callbacks.remove(callback)
+
+ def _emit_event(self, event: PumpEvent) -> None:
+ """Emit an event to all subscribers (non-blocking)."""
+ for callback in self._event_callbacks:
+ try:
+ callback(event)
+ except Exception as e:
+ pump_logger.warning(f"Event callback error: {e}")
+
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
@@ -493,6 +567,12 @@ class StreamPump:
await semaphore.acquire()
try:
+ # Emit agent state change event
+ self._emit_event(AgentStateEvent(
+ agent_name=listener.name,
+ state="processing",
+ thread_id=state.thread_id,
+ ))
# Ensure we have a valid thread chain
registry = get_registry()
todo_registry = get_todo_registry()
@@ -566,6 +646,15 @@ class StreamPump:
)
payload_ref = state.payload
+ # Emit message received event
+ self._emit_event(MessageReceivedEvent(
+ thread_id=current_thread,
+ from_id=state.from_id or "",
+ to_id=listener.name,
+ payload_type=type(payload_ref).__name__,
+ payload=payload_ref,
+ ))
+
# Dispatch to handler - either in-process or via ProcessPool
if listener.cpu_bound and self._process_pool and self._shared_backend:
response = await self._dispatch_to_process_pool(
@@ -578,6 +667,12 @@ class StreamPump:
# None means "no response needed" - don't re-inject
if response is None:
+ # Emit idle state
+ self._emit_event(AgentStateEvent(
+ agent_name=listener.name,
+ state="idle",
+ thread_id=current_thread,
+ ))
continue
# Handle clean HandlerResponse (preferred)
@@ -653,6 +748,23 @@ class StreamPump:
response_bytes = b"Handler returned invalid type"
thread_id = state.thread_id
+ # Emit message sent event
+ if isinstance(response, HandlerResponse):
+ self._emit_event(MessageSentEvent(
+ thread_id=thread_id,
+ from_id=listener.name,
+ to_id=to_id,
+ payload_type=type(response.payload).__name__,
+ payload=response.payload,
+ ))
+
+ # Emit agent state back to idle
+ self._emit_event(AgentStateEvent(
+ agent_name=listener.name,
+ state="idle",
+ thread_id=None,
+ ))
+
# Yield response — will be processed by next iteration
yield MessageState(
raw_bytes=response_bytes,
@@ -665,6 +777,12 @@ class StreamPump:
semaphore.release()
except Exception as exc:
+ # Emit error state
+ self._emit_event(AgentStateEvent(
+ agent_name=listener.name,
+ state="error",
+ thread_id=state.thread_id,
+ ))
yield MessageState(
raw_bytes=f"Handler {listener.name} crashed: {exc}".encode(),
thread_id=state.thread_id,
diff --git a/xml_pipeline/server/__init__.py b/xml_pipeline/server/__init__.py
new file mode 100644
index 0000000..62dfc58
--- /dev/null
+++ b/xml_pipeline/server/__init__.py
@@ -0,0 +1,26 @@
+"""
+server — FastAPI-based AgentServer API for monitoring and controlling organisms.
+
+Provides:
+- REST API for querying organism state (agents, threads, messages)
+- WebSocket for real-time events
+- Message injection endpoint
+
+Usage:
+ from xml_pipeline.server import create_app, run_server
+
+ # With existing pump
+ app = create_app(pump)
+ uvicorn.run(app, host="0.0.0.0", port=8080)
+
+ # Or use CLI
+ xml-pipeline serve config/organism.yaml --port 8080
+"""
+
+from xml_pipeline.server.app import create_app, run_server, run_server_sync
+
+__all__ = [
+ "create_app",
+ "run_server",
+ "run_server_sync",
+]
diff --git a/xml_pipeline/server/api.py b/xml_pipeline/server/api.py
new file mode 100644
index 0000000..fc60399
--- /dev/null
+++ b/xml_pipeline/server/api.py
@@ -0,0 +1,273 @@
+"""
+api.py — REST API routes for AgentServer.
+
+Provides endpoints for:
+- Organism info and config
+- Agent listing and details
+- Thread listing and management
+- Message injection
+"""
+
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING, Optional
+
+from fastapi import APIRouter, HTTPException, Query
+
+from xml_pipeline.server.models import (
+ AgentInfo,
+ AgentListResponse,
+ ErrorResponse,
+ InjectRequest,
+ InjectResponse,
+ MessageListResponse,
+ OrganismInfo,
+ ThreadInfo,
+ ThreadListResponse,
+ ThreadStatus,
+)
+
+if TYPE_CHECKING:
+ from xml_pipeline.server.state import ServerState
+
+
+def create_router(state: "ServerState") -> APIRouter:
+ """Create API router with state dependency."""
+ router = APIRouter(prefix="/api/v1")
+
+ # =========================================================================
+ # Organism Endpoints
+ # =========================================================================
+
+ @router.get("/organism", response_model=OrganismInfo)
+ async def get_organism() -> OrganismInfo:
+ """Get organism overview and stats."""
+ return state.get_organism_info()
+
+ @router.get("/organism/config")
+ async def get_organism_config() -> dict:
+ """Get sanitized organism configuration (no secrets)."""
+ return state.get_organism_config()
+
+ # =========================================================================
+ # Agent Endpoints
+ # =========================================================================
+
+ @router.get("/agents", response_model=AgentListResponse)
+ async def list_agents() -> AgentListResponse:
+ """List all agents with current state."""
+ agents = state.get_agents()
+ return AgentListResponse(agents=agents, count=len(agents))
+
+ @router.get("/agents/{name}", response_model=AgentInfo)
+ async def get_agent(name: str) -> AgentInfo:
+ """Get single agent details."""
+ agent = state.get_agent(name)
+ if agent is None:
+ raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
+ return agent
+
+ @router.get("/agents/{name}/config")
+ async def get_agent_config(name: str) -> dict:
+ """Get agent's YAML config section."""
+ agent = state.get_agent(name)
+ if agent is None:
+ raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
+
+ # Return relevant config fields
+ return {
+ "name": agent.name,
+ "description": agent.description,
+ "isAgent": agent.is_agent,
+ "peers": agent.peers,
+ "payloadClass": agent.payload_class,
+ }
+
+ @router.get("/agents/{name}/schema")
+ async def get_agent_schema(name: str) -> dict:
+ """Get agent's payload XML schema."""
+ schema = state.get_agent_schema(name)
+ if schema is None:
+ raise HTTPException(
+ status_code=404,
+ detail=f"Schema not found for agent: {name}",
+ )
+ return {"schema": schema, "contentType": "application/xml"}
+
+ # =========================================================================
+ # Thread Endpoints
+ # =========================================================================
+
+ @router.get("/threads", response_model=ThreadListResponse)
+ async def list_threads(
+ status: Optional[str] = Query(None, description="Filter by status"),
+ agent: Optional[str] = Query(None, description="Filter by participant agent"),
+ limit: int = Query(50, ge=1, le=100),
+ offset: int = Query(0, ge=0),
+ ) -> ThreadListResponse:
+ """List threads with optional filtering."""
+ thread_status = None
+ if status:
+ try:
+ thread_status = ThreadStatus(status)
+ except ValueError:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Invalid status: {status}. Valid values: {[s.value for s in ThreadStatus]}",
+ )
+
+ threads, total = state.get_threads(
+ status=thread_status,
+ agent=agent,
+ limit=limit,
+ offset=offset,
+ )
+ return ThreadListResponse(
+ threads=threads,
+ count=len(threads),
+ total=total,
+ offset=offset,
+ limit=limit,
+ )
+
+ @router.get("/threads/{thread_id}", response_model=ThreadInfo)
+ async def get_thread(thread_id: str) -> ThreadInfo:
+ """Get thread details with message history."""
+ thread = state.get_thread(thread_id)
+ if thread is None:
+ raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
+ return thread
+
+ @router.get("/threads/{thread_id}/messages", response_model=MessageListResponse)
+ async def get_thread_messages(
+ thread_id: str,
+ limit: int = Query(50, ge=1, le=100),
+ offset: int = Query(0, ge=0),
+ ) -> MessageListResponse:
+ """Get messages in a specific thread."""
+ thread = state.get_thread(thread_id)
+ if thread is None:
+ raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
+
+ messages, total = state.get_messages(
+ thread_id=thread_id,
+ limit=limit,
+ offset=offset,
+ )
+ return MessageListResponse(
+ messages=messages,
+ count=len(messages),
+ total=total,
+ offset=offset,
+ limit=limit,
+ )
+
+ @router.post("/threads/{thread_id}/kill")
+ async def kill_thread(thread_id: str) -> dict:
+ """Terminate a thread."""
+ thread = state.get_thread(thread_id)
+ if thread is None:
+ raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
+
+ await state.complete_thread(thread_id, status=ThreadStatus.KILLED)
+ return {"success": True, "threadId": thread_id}
+
+ # =========================================================================
+ # Message Endpoints
+ # =========================================================================
+
+ @router.get("/messages", response_model=MessageListResponse)
+ async def list_messages(
+ agent: Optional[str] = Query(None, description="Filter by agent (sender or receiver)"),
+ limit: int = Query(50, ge=1, le=100),
+ offset: int = Query(0, ge=0),
+ ) -> MessageListResponse:
+ """Get global message history."""
+ messages, total = state.get_messages(
+ agent=agent,
+ limit=limit,
+ offset=offset,
+ )
+ return MessageListResponse(
+ messages=messages,
+ count=len(messages),
+ total=total,
+ offset=offset,
+ limit=limit,
+ )
+
+ # =========================================================================
+ # Control Endpoints
+ # =========================================================================
+
+ @router.post("/inject", response_model=InjectResponse)
+ async def inject_message(request: InjectRequest) -> InjectResponse:
+ """Inject a message to an agent."""
+ # Validate target exists
+ agent = state.get_agent(request.to)
+ if agent is None:
+ raise HTTPException(
+ status_code=400,
+ detail=f"Unknown target agent: {request.to}",
+ )
+
+ # Generate or use provided thread ID
+ thread_id = request.thread_id or str(uuid.uuid4())
+
+ # Build payload XML from dict
+ # For now, we construct a simple wrapper
+ payload_type = next(iter(request.payload.keys()), "Payload")
+
+ # Record the message
+ msg_id = await state.record_message(
+ thread_id=thread_id,
+ from_id="api",
+ to_id=request.to,
+ payload_type=payload_type,
+ payload=request.payload,
+ )
+
+ # TODO: Actually inject into pump queue
+ # This requires building an envelope and calling pump.inject()
+
+ return InjectResponse(thread_id=thread_id, message_id=msg_id)
+
+ @router.post("/agents/{name}/pause")
+ async def pause_agent(name: str) -> dict:
+ """Pause an agent (stop processing new messages)."""
+ agent = state.get_agent(name)
+ if agent is None:
+ raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
+
+ from xml_pipeline.server.models import AgentState
+
+ await state.update_agent_state(name, AgentState.PAUSED)
+ return {"success": True, "agent": name, "state": "paused"}
+
+ @router.post("/agents/{name}/resume")
+ async def resume_agent(name: str) -> dict:
+ """Resume a paused agent."""
+ agent = state.get_agent(name)
+ if agent is None:
+ raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
+
+ from xml_pipeline.server.models import AgentState
+
+ await state.update_agent_state(name, AgentState.IDLE)
+ return {"success": True, "agent": name, "state": "idle"}
+
+ @router.post("/organism/reload")
+ async def reload_config() -> dict:
+ """Hot-reload organism configuration."""
+ # TODO: Implement hot-reload
+ return {"success": False, "error": "Hot-reload not yet implemented"}
+
+ @router.post("/organism/stop")
+ async def stop_organism() -> dict:
+ """Graceful shutdown."""
+ state.set_stopping()
+ # TODO: Signal pump to stop
+ return {"success": True, "status": "stopping"}
+
+ return router
diff --git a/xml_pipeline/server/app.py b/xml_pipeline/server/app.py
new file mode 100644
index 0000000..5da1428
--- /dev/null
+++ b/xml_pipeline/server/app.py
@@ -0,0 +1,148 @@
+"""
+app.py — FastAPI application factory for AgentServer.
+
+Creates the FastAPI app that combines REST API and WebSocket endpoints.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from contextlib import asynccontextmanager
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+from xml_pipeline.server.api import create_router
+from xml_pipeline.server.state import ServerState
+from xml_pipeline.server.websocket import create_websocket_router
+
+if TYPE_CHECKING:
+ from xml_pipeline.message_bus.stream_pump import StreamPump
+
+
+def create_app(
+ pump: "StreamPump",
+ *,
+ title: str = "AgentServer API",
+ version: str = "1.0.0",
+ cors_origins: Optional[list[str]] = None,
+) -> FastAPI:
+ """
+ Create FastAPI application with REST and WebSocket endpoints.
+
+ Args:
+ pump: The StreamPump instance to wrap
+ title: API title for OpenAPI docs
+ version: API version
+ cors_origins: List of allowed CORS origins (default: all)
+
+ Returns:
+ Configured FastAPI application
+ """
+ # Create state manager
+ state = ServerState(pump)
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
+ """Manage app lifecycle - startup and shutdown."""
+ # Startup
+ state.set_running()
+ yield
+ # Shutdown
+ state.set_stopping()
+
+ app = FastAPI(
+ title=title,
+ version=version,
+ description="REST and WebSocket API for monitoring and controlling xml-pipeline organisms.",
+ lifespan=lifespan,
+ )
+
+ # CORS middleware
+ if cors_origins is None:
+ cors_origins = ["*"]
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=cors_origins,
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ # Include routers
+ app.include_router(create_router(state))
+ app.include_router(create_websocket_router(state))
+
+ # Store state on app for access if needed
+ app.state.server_state = state
+ app.state.pump = pump
+
+ @app.get("/health")
+ async def health_check() -> dict[str, Any]:
+ """Health check endpoint."""
+ info = state.get_organism_info()
+ return {
+ "status": "healthy",
+ "organism": info.name,
+ "uptime_seconds": info.uptime_seconds,
+ }
+
+ return app
+
+
+async def run_server(
+ pump: "StreamPump",
+ *,
+ host: str = "0.0.0.0",
+ port: int = 8080,
+ cors_origins: Optional[list[str]] = None,
+) -> None:
+ """
+ Run the AgentServer with uvicorn.
+
+ Args:
+ pump: The StreamPump instance to wrap
+ host: Host to bind to
+ port: Port to listen on
+ cors_origins: List of allowed CORS origins
+ """
+ try:
+ import uvicorn
+ except ImportError as e:
+ raise ImportError(
+ "uvicorn is required for the server. Install with: pip install xml-pipeline[server]"
+ ) from e
+
+ app = create_app(pump, cors_origins=cors_origins)
+
+ config = uvicorn.Config(
+ app,
+ host=host,
+ port=port,
+ log_level="info",
+ )
+ server = uvicorn.Server(config)
+ await server.serve()
+
+
+def run_server_sync(
+ pump: "StreamPump",
+ *,
+ host: str = "0.0.0.0",
+ port: int = 8080,
+ cors_origins: Optional[list[str]] = None,
+) -> None:
+ """
+ Run the AgentServer synchronously (blocking).
+
+ This is a convenience wrapper for CLI usage.
+
+ Args:
+ pump: The StreamPump instance to wrap
+ host: Host to bind to
+ port: Port to listen on
+ cors_origins: List of allowed CORS origins
+ """
+ asyncio.run(run_server(pump, host=host, port=port, cors_origins=cors_origins))
diff --git a/xml_pipeline/server/models.py b/xml_pipeline/server/models.py
new file mode 100644
index 0000000..9d3bf56
--- /dev/null
+++ b/xml_pipeline/server/models.py
@@ -0,0 +1,261 @@
+"""
+models.py — Pydantic models for AgentServer API.
+
+These models define the JSON structure for API responses.
+Uses camelCase for JSON keys (JavaScript convention).
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+from enum import Enum
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+def to_camel(string: str) -> str:
+ """Convert snake_case to camelCase."""
+ components = string.split("_")
+ return components[0] + "".join(x.title() for x in components[1:])
+
+
+class CamelModel(BaseModel):
+ """Base model with camelCase JSON serialization."""
+
+ model_config = ConfigDict(
+ alias_generator=to_camel,
+ populate_by_name=True,
+ )
+
+
+# =============================================================================
+# Enums
+# =============================================================================
+
+
+class AgentState(str, Enum):
+ """Agent processing state."""
+
+ IDLE = "idle"
+ PROCESSING = "processing"
+ WAITING = "waiting"
+ ERROR = "error"
+ PAUSED = "paused"
+
+
+class ThreadStatus(str, Enum):
+ """Thread lifecycle status."""
+
+ ACTIVE = "active"
+ COMPLETED = "completed"
+ ERROR = "error"
+ KILLED = "killed"
+
+
+class OrganismStatus(str, Enum):
+ """Organism running status."""
+
+ STARTING = "starting"
+ RUNNING = "running"
+ STOPPING = "stopping"
+ STOPPED = "stopped"
+
+
+# =============================================================================
+# API Response Models
+# =============================================================================
+
+
+class AgentInfo(CamelModel):
+ """Agent information for API response."""
+
+ name: str
+ description: str
+ is_agent: bool = Field(alias="isAgent")
+ peers: List[str] = Field(default_factory=list)
+ payload_class: str = Field(alias="payloadClass")
+ state: AgentState = AgentState.IDLE
+ current_thread: Optional[str] = Field(None, alias="currentThread")
+ queue_depth: int = Field(0, alias="queueDepth")
+ last_activity: Optional[datetime] = Field(None, alias="lastActivity")
+ message_count: int = Field(0, alias="messageCount")
+
+
+class MessageInfo(CamelModel):
+ """Message information for API response."""
+
+ id: str
+ thread_id: str = Field(alias="threadId")
+ from_id: str = Field(alias="from")
+ to_id: str = Field(alias="to")
+ payload_type: str = Field(alias="payloadType")
+ payload: Dict[str, Any] = Field(default_factory=dict)
+ timestamp: datetime
+ slot_index: Optional[int] = Field(None, alias="slotIndex")
+
+
+class ThreadInfo(CamelModel):
+ """Thread information for API response."""
+
+ id: str
+ status: ThreadStatus = ThreadStatus.ACTIVE
+ participants: List[str] = Field(default_factory=list)
+ message_count: int = Field(0, alias="messageCount")
+ created_at: datetime = Field(alias="createdAt")
+ last_activity: Optional[datetime] = Field(None, alias="lastActivity")
+ error: Optional[str] = None
+
+
+class OrganismInfo(CamelModel):
+ """Organism overview for API response."""
+
+ name: str
+ status: OrganismStatus = OrganismStatus.RUNNING
+ uptime_seconds: float = Field(0.0, alias="uptimeSeconds")
+ agent_count: int = Field(0, alias="agentCount")
+ active_threads: int = Field(0, alias="activeThreads")
+ total_messages: int = Field(0, alias="totalMessages")
+ identity_configured: bool = Field(False, alias="identityConfigured")
+
+
+class OrganismConfig(CamelModel):
+ """Sanitized organism configuration for API response."""
+
+ name: str
+ port: int = 8765
+ thread_scheduling: str = Field("breadth-first", alias="threadScheduling")
+ max_concurrent_pipelines: int = Field(50, alias="maxConcurrentPipelines")
+ max_concurrent_handlers: int = Field(20, alias="maxConcurrentHandlers")
+ listeners: List[str] = Field(default_factory=list)
+ # Note: Secrets like API keys are never exposed
+
+
+# =============================================================================
+# Request Models
+# =============================================================================
+
+
+class InjectRequest(CamelModel):
+ """Request body for POST /inject."""
+
+ to: str
+ payload: Dict[str, Any]
+ thread_id: Optional[str] = Field(None, alias="threadId")
+
+
+class InjectResponse(CamelModel):
+ """Response body for POST /inject."""
+
+ thread_id: str = Field(alias="threadId")
+ message_id: str = Field(alias="messageId")
+
+
+class SubscribeRequest(CamelModel):
+ """WebSocket subscription filter."""
+
+ threads: Optional[List[str]] = None
+ agents: Optional[List[str]] = None
+ payload_types: Optional[List[str]] = Field(None, alias="payloadTypes")
+ events: Optional[List[str]] = None
+
+
+# =============================================================================
+# WebSocket Event Models
+# =============================================================================
+
+
+class WSEvent(CamelModel):
+ """Base WebSocket event."""
+
+ event: str
+
+
+class WSConnectedEvent(WSEvent):
+ """Sent on WebSocket connection with full state snapshot."""
+
+ event: str = "connected"
+ organism: OrganismInfo
+ agents: List[AgentInfo]
+ threads: List[ThreadInfo]
+
+
+class WSAgentStateEvent(WSEvent):
+ """Agent state changed."""
+
+ event: str = "agent_state"
+ agent: str
+ state: AgentState
+ current_thread: Optional[str] = Field(None, alias="currentThread")
+
+
+class WSMessageEvent(WSEvent):
+ """New message in the system."""
+
+ event: str = "message"
+ message: MessageInfo
+
+
+class WSThreadCreatedEvent(WSEvent):
+ """New thread started."""
+
+ event: str = "thread_created"
+ thread: ThreadInfo
+
+
+class WSThreadUpdatedEvent(WSEvent):
+ """Thread status changed."""
+
+ event: str = "thread_updated"
+ thread_id: str = Field(alias="threadId")
+ status: ThreadStatus
+ message_count: int = Field(alias="messageCount")
+
+
+class WSErrorEvent(WSEvent):
+ """Error occurred."""
+
+ event: str = "error"
+ thread_id: Optional[str] = Field(None, alias="threadId")
+ agent: Optional[str] = None
+ error: str
+ timestamp: datetime
+
+
+# =============================================================================
+# List Response Models
+# =============================================================================
+
+
+class AgentListResponse(CamelModel):
+ """Response for GET /agents."""
+
+ agents: List[AgentInfo]
+ count: int
+
+
+class ThreadListResponse(CamelModel):
+ """Response for GET /threads."""
+
+ threads: List[ThreadInfo]
+ count: int
+ total: int
+ offset: int
+ limit: int
+
+
+class MessageListResponse(CamelModel):
+ """Response for GET /messages or /threads/{id}/messages."""
+
+ messages: List[MessageInfo]
+ count: int
+ total: int
+ offset: int
+ limit: int
+
+
+class ErrorResponse(CamelModel):
+ """Error response."""
+
+ error: str
+ detail: Optional[str] = None
diff --git a/xml_pipeline/server/state.py b/xml_pipeline/server/state.py
new file mode 100644
index 0000000..adeb3e5
--- /dev/null
+++ b/xml_pipeline/server/state.py
@@ -0,0 +1,515 @@
+"""
+state.py — Server state manager for tracking organism runtime state.
+
+Maintains real-time state for API queries:
+- Agent states (idle, processing, waiting, error)
+- Active threads and their participants
+- Message counts and activity timestamps
+
+This is the bridge between the StreamPump and the API.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import time
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set
+
+from xml_pipeline.server.models import (
+ AgentInfo,
+ AgentState,
+ MessageInfo,
+ OrganismInfo,
+ OrganismStatus,
+ ThreadInfo,
+ ThreadStatus,
+)
+
+if TYPE_CHECKING:
+ from xml_pipeline.message_bus.stream_pump import StreamPump
+ from xml_pipeline.message_bus import (
+ PumpEvent,
+ MessageReceivedEvent,
+ MessageSentEvent,
+ AgentStateEvent,
+ ThreadEvent,
+ )
+
+
+@dataclass
+class AgentRuntimeState:
+ """Runtime state for a single agent."""
+
+ name: str
+ description: str
+ is_agent: bool
+ peers: List[str]
+ payload_class: str
+ state: AgentState = AgentState.IDLE
+ current_thread: Optional[str] = None
+ last_activity: Optional[datetime] = None
+ message_count: int = 0
+ queue_depth: int = 0
+
+
+@dataclass
+class ThreadRuntimeState:
+ """Runtime state for a single thread."""
+
+ id: str
+ status: ThreadStatus = ThreadStatus.ACTIVE
+ participants: Set[str] = field(default_factory=set)
+ message_count: int = 0
+ created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
+ last_activity: Optional[datetime] = None
+ error: Optional[str] = None
+
+
+@dataclass
+class MessageRecord:
+ """Record of a message for history."""
+
+ id: str
+ thread_id: str
+ from_id: str
+ to_id: str
+ payload_type: str
+ payload: Dict[str, Any]
+ timestamp: datetime
+ slot_index: int
+
+
+class ServerState:
+ """
+ Centralized state manager for the API server.
+
+ Tracks runtime state of agents, threads, and messages.
+ Provides methods for the API to query current state.
+ """
+
+ def __init__(self, pump: StreamPump) -> None:
+ self.pump = pump
+ self._start_time = time.time()
+ self._status = OrganismStatus.STARTING
+
+ # Runtime state
+ self._agents: Dict[str, AgentRuntimeState] = {}
+ self._threads: Dict[str, ThreadRuntimeState] = {}
+ self._messages: List[MessageRecord] = []
+ self._message_count = 0
+
+ # Event subscribers (WebSocket connections)
+ self._subscribers: Set[Callable] = set()
+
+ # Lock for thread-safe updates
+ self._lock = asyncio.Lock()
+
+ # Initialize agent states from pump
+ self._init_agents()
+
+ # Subscribe to pump events
+ self._subscribe_to_pump()
+
+ def _init_agents(self) -> None:
+ """Initialize agent states from pump listeners."""
+ for name, listener in self.pump.listeners.items():
+ self._agents[name] = AgentRuntimeState(
+ name=name,
+ description=listener.description,
+ is_agent=listener.is_agent,
+ peers=list(listener.peers),
+ payload_class=f"{listener.payload_class.__module__}.{listener.payload_class.__name__}",
+ )
+
+ def _subscribe_to_pump(self) -> None:
+ """Subscribe to pump events for real-time state updates."""
+ from xml_pipeline.message_bus import (
+ PumpEvent,
+ MessageReceivedEvent,
+ MessageSentEvent,
+ AgentStateEvent as PumpAgentStateEvent,
+ ThreadEvent,
+ )
+
+ def handle_pump_event(event: "PumpEvent") -> None:
+ """Handle pump events synchronously (schedules async updates)."""
+ if isinstance(event, MessageReceivedEvent):
+ # Schedule async message recording
+ asyncio.create_task(self._handle_message_received(event))
+ elif isinstance(event, MessageSentEvent):
+ # Schedule async message recording
+ asyncio.create_task(self._handle_message_sent(event))
+ elif isinstance(event, PumpAgentStateEvent):
+ # Schedule async agent state update
+ asyncio.create_task(self._handle_agent_state(event))
+ elif isinstance(event, ThreadEvent):
+ # Schedule async thread update
+ asyncio.create_task(self._handle_thread_event(event))
+
+ self.pump.subscribe_events(handle_pump_event)
+
+ async def _handle_message_received(self, event: "MessageReceivedEvent") -> None:
+ """Handle message received by handler."""
+ # Convert payload to dict representation
+ payload_dict = {}
+ if hasattr(event.payload, '__dataclass_fields__'):
+ import dataclasses
+ payload_dict = dataclasses.asdict(event.payload)
+ elif isinstance(event.payload, dict):
+ payload_dict = event.payload
+
+ await self.record_message(
+ thread_id=event.thread_id,
+ from_id=event.from_id,
+ to_id=event.to_id,
+ payload_type=event.payload_type,
+ payload=payload_dict,
+ )
+
+ async def _handle_message_sent(self, event: "MessageSentEvent") -> None:
+ """Handle message sent by handler."""
+ payload_dict = {}
+ if hasattr(event.payload, '__dataclass_fields__'):
+ import dataclasses
+ payload_dict = dataclasses.asdict(event.payload)
+ elif isinstance(event.payload, dict):
+ payload_dict = event.payload
+
+ await self.record_message(
+ thread_id=event.thread_id,
+ from_id=event.from_id,
+ to_id=event.to_id,
+ payload_type=event.payload_type,
+ payload=payload_dict,
+ )
+
+ async def _handle_agent_state(self, event: "AgentStateEvent") -> None:
+ """Handle agent state change from pump."""
+ # Map pump state strings to AgentState enum
+ state_map = {
+ "idle": AgentState.IDLE,
+ "processing": AgentState.PROCESSING,
+ "waiting": AgentState.WAITING,
+ "error": AgentState.ERROR,
+ "paused": AgentState.PAUSED,
+ }
+ state = state_map.get(event.state, AgentState.IDLE)
+ await self.update_agent_state(event.agent_name, state, event.thread_id)
+
+ async def _handle_thread_event(self, event: "ThreadEvent") -> None:
+ """Handle thread lifecycle event from pump."""
+ # Map status to ThreadStatus enum
+ status_map = {
+ "created": ThreadStatus.ACTIVE,
+ "active": ThreadStatus.ACTIVE,
+ "completed": ThreadStatus.COMPLETED,
+ "error": ThreadStatus.ERROR,
+ "killed": ThreadStatus.KILLED,
+ }
+ status = status_map.get(event.status, ThreadStatus.ACTIVE)
+
+ if event.status in ("completed", "error", "killed"):
+ await self.complete_thread(event.thread_id, status, event.error)
+
+ def set_running(self) -> None:
+ """Mark organism as running."""
+ self._status = OrganismStatus.RUNNING
+
+ def set_stopping(self) -> None:
+ """Mark organism as stopping."""
+ self._status = OrganismStatus.STOPPING
+
+ # =========================================================================
+ # Event Recording (called by pump hooks)
+ # =========================================================================
+
+ async def record_message(
+ self,
+ thread_id: str,
+ from_id: str,
+ to_id: str,
+ payload_type: str,
+ payload: Dict[str, Any],
+ ) -> str:
+ """Record a message and update related state."""
+ async with self._lock:
+ msg_id = str(uuid.uuid4())
+ now = datetime.now(timezone.utc)
+ self._message_count += 1
+
+ record = MessageRecord(
+ id=msg_id,
+ thread_id=thread_id,
+ from_id=from_id,
+ to_id=to_id,
+ payload_type=payload_type,
+ payload=payload,
+ timestamp=now,
+ slot_index=self._message_count,
+ )
+ self._messages.append(record)
+
+ # Update thread state
+ if thread_id not in self._threads:
+ self._threads[thread_id] = ThreadRuntimeState(
+ id=thread_id,
+ created_at=now,
+ )
+ thread = self._threads[thread_id]
+ thread.participants.add(from_id)
+ thread.participants.add(to_id)
+ thread.message_count += 1
+ thread.last_activity = now
+
+ # Update agent states
+ if from_id in self._agents:
+ self._agents[from_id].message_count += 1
+ self._agents[from_id].last_activity = now
+
+ # Notify subscribers
+ await self._notify_message(record)
+
+ return msg_id
+
+ async def update_agent_state(
+ self,
+ agent_name: str,
+ state: AgentState,
+ current_thread: Optional[str] = None,
+ ) -> None:
+ """Update an agent's processing state."""
+ async with self._lock:
+ if agent_name not in self._agents:
+ return
+
+ agent = self._agents[agent_name]
+ old_state = agent.state
+ agent.state = state
+ agent.current_thread = current_thread
+ agent.last_activity = datetime.now(timezone.utc)
+
+ if old_state != state:
+ await self._notify_agent_state(agent_name, state, current_thread)
+
+ async def complete_thread(
+ self,
+ thread_id: str,
+ status: ThreadStatus = ThreadStatus.COMPLETED,
+ error: Optional[str] = None,
+ ) -> None:
+ """Mark a thread as completed or errored."""
+ async with self._lock:
+ if thread_id not in self._threads:
+ return
+
+ thread = self._threads[thread_id]
+ thread.status = status
+ thread.error = error
+ thread.last_activity = datetime.now(timezone.utc)
+
+ await self._notify_thread_updated(thread)
+
+ # =========================================================================
+ # Query Methods (for API)
+ # =========================================================================
+
+ def get_organism_info(self) -> OrganismInfo:
+ """Get organism overview."""
+ return OrganismInfo(
+ name=self.pump.config.name,
+ status=self._status,
+ uptime_seconds=time.time() - self._start_time,
+ agent_count=len(self._agents),
+ active_threads=sum(
+ 1 for t in self._threads.values() if t.status == ThreadStatus.ACTIVE
+ ),
+ total_messages=self._message_count,
+ identity_configured=self.pump.identity is not None,
+ )
+
+ def get_organism_config(self) -> Dict[str, Any]:
+ """Get sanitized organism config (no secrets)."""
+ return {
+ "name": self.pump.config.name,
+ "port": self.pump.config.port,
+ "thread_scheduling": self.pump.config.thread_scheduling,
+ "max_concurrent_pipelines": self.pump.config.max_concurrent_pipelines,
+ "max_concurrent_handlers": self.pump.config.max_concurrent_handlers,
+ "listeners": list(self.pump.listeners.keys()),
+ }
+
+ def get_agents(self) -> List[AgentInfo]:
+ """Get all agents."""
+ return [self._agent_to_info(a) for a in self._agents.values()]
+
+ def get_agent(self, name: str) -> Optional[AgentInfo]:
+ """Get a single agent by name."""
+ agent = self._agents.get(name)
+ if agent:
+ return self._agent_to_info(agent)
+ return None
+
+ def get_agent_schema(self, name: str) -> Optional[str]:
+ """Get agent's XSD schema as string."""
+ listener = self.pump.listeners.get(name)
+ if listener and listener.schema is not None:
+ from lxml import etree
+
+ return etree.tostring(listener.schema, encoding="unicode", pretty_print=True)
+ return None
+
+ def get_threads(
+ self,
+ status: Optional[ThreadStatus] = None,
+ agent: Optional[str] = None,
+ limit: int = 50,
+ offset: int = 0,
+ ) -> tuple[List[ThreadInfo], int]:
+ """Get threads with optional filtering."""
+ threads = list(self._threads.values())
+
+ # Filter
+ if status:
+ threads = [t for t in threads if t.status == status]
+ if agent:
+ threads = [t for t in threads if agent in t.participants]
+
+ # Sort by last activity (most recent first)
+ threads.sort(key=lambda t: t.last_activity or t.created_at, reverse=True)
+
+ total = len(threads)
+ threads = threads[offset : offset + limit]
+
+ return [self._thread_to_info(t) for t in threads], total
+
+ def get_thread(self, thread_id: str) -> Optional[ThreadInfo]:
+ """Get a single thread by ID."""
+ thread = self._threads.get(thread_id)
+ if thread:
+ return self._thread_to_info(thread)
+ return None
+
+ def get_messages(
+ self,
+ thread_id: Optional[str] = None,
+ agent: Optional[str] = None,
+ limit: int = 50,
+ offset: int = 0,
+ ) -> tuple[List[MessageInfo], int]:
+ """Get messages with optional filtering."""
+ messages = self._messages.copy()
+
+ # Filter
+ if thread_id:
+ messages = [m for m in messages if m.thread_id == thread_id]
+ if agent:
+ messages = [m for m in messages if m.from_id == agent or m.to_id == agent]
+
+ # Sort by timestamp (most recent first)
+ messages.sort(key=lambda m: m.timestamp, reverse=True)
+
+ total = len(messages)
+ messages = messages[offset : offset + limit]
+
+ return [self._message_to_info(m) for m in messages], total
+
+ # =========================================================================
+ # Subscription Management
+ # =========================================================================
+
+ def subscribe(self, callback: Callable) -> None:
+ """Subscribe to state events."""
+ self._subscribers.add(callback)
+
+ def unsubscribe(self, callback: Callable) -> None:
+ """Unsubscribe from state events."""
+ self._subscribers.discard(callback)
+
+ async def _notify_message(self, record: MessageRecord) -> None:
+ """Notify subscribers of new message."""
+ event = {
+ "event": "message",
+ "message": self._message_to_info(record).model_dump(by_alias=True),
+ }
+ await self._broadcast(event)
+
+ async def _notify_agent_state(
+ self,
+ agent_name: str,
+ state: AgentState,
+ current_thread: Optional[str],
+ ) -> None:
+ """Notify subscribers of agent state change."""
+ event = {
+ "event": "agent_state",
+ "agent": agent_name,
+ "state": state.value,
+ "currentThread": current_thread,
+ }
+ await self._broadcast(event)
+
+ async def _notify_thread_updated(self, thread: ThreadRuntimeState) -> None:
+ """Notify subscribers of thread update."""
+ event = {
+ "event": "thread_updated",
+ "threadId": thread.id,
+ "status": thread.status.value,
+ "messageCount": thread.message_count,
+ }
+ await self._broadcast(event)
+
+ async def _broadcast(self, event: Dict[str, Any]) -> None:
+ """Broadcast event to all subscribers."""
+ for callback in list(self._subscribers):
+ try:
+ await callback(event)
+ except Exception:
+ # Remove failed subscribers
+ self._subscribers.discard(callback)
+
+ # =========================================================================
+ # Conversion Helpers
+ # =========================================================================
+
+ def _agent_to_info(self, agent: AgentRuntimeState) -> AgentInfo:
+ """Convert runtime state to API model."""
+ return AgentInfo(
+ name=agent.name,
+ description=agent.description,
+ is_agent=agent.is_agent,
+ peers=agent.peers,
+ payload_class=agent.payload_class,
+ state=agent.state,
+ current_thread=agent.current_thread,
+ queue_depth=agent.queue_depth,
+ last_activity=agent.last_activity,
+ message_count=agent.message_count,
+ )
+
+ def _thread_to_info(self, thread: ThreadRuntimeState) -> ThreadInfo:
+ """Convert runtime state to API model."""
+ return ThreadInfo(
+ id=thread.id,
+ status=thread.status,
+ participants=list(thread.participants),
+ message_count=thread.message_count,
+ created_at=thread.created_at,
+ last_activity=thread.last_activity,
+ error=thread.error,
+ )
+
+ def _message_to_info(self, record: MessageRecord) -> MessageInfo:
+ """Convert record to API model."""
+ return MessageInfo(
+ id=record.id,
+ thread_id=record.thread_id,
+ from_id=record.from_id,
+ to_id=record.to_id,
+ payload_type=record.payload_type,
+ payload=record.payload,
+ timestamp=record.timestamp,
+ slot_index=record.slot_index,
+ )
diff --git a/xml_pipeline/server/websocket.py b/xml_pipeline/server/websocket.py
new file mode 100644
index 0000000..bd244b2
--- /dev/null
+++ b/xml_pipeline/server/websocket.py
@@ -0,0 +1,316 @@
+"""
+websocket.py — WebSocket endpoints for AgentServer.
+
+Provides:
+- /ws — Main control channel with state snapshot and real-time events
+- /ws/messages — Dedicated message log stream with filtering
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+
+from fastapi import APIRouter, WebSocket, WebSocketDisconnect
+
+from xml_pipeline.server.models import (
+ SubscribeRequest,
+ WSConnectedEvent,
+)
+
+if TYPE_CHECKING:
+ from xml_pipeline.server.state import ServerState
+
+
+class ConnectionManager:
+ """Manages WebSocket connections and subscriptions."""
+
+ def __init__(self) -> None:
+ self.active_connections: Set[WebSocket] = set()
+ self.subscriptions: Dict[WebSocket, SubscribeRequest] = {}
+
+ async def connect(self, websocket: WebSocket) -> None:
+ """Accept a new WebSocket connection."""
+ await websocket.accept()
+ self.active_connections.add(websocket)
+ self.subscriptions[websocket] = SubscribeRequest() # Default: all events
+
+ def disconnect(self, websocket: WebSocket) -> None:
+ """Remove a WebSocket connection."""
+ self.active_connections.discard(websocket)
+ self.subscriptions.pop(websocket, None)
+
+ def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None:
+ """Update subscription filters for a connection."""
+ self.subscriptions[websocket] = sub
+
+ def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool:
+ """Check if an event should be sent to this connection based on filters."""
+ sub = self.subscriptions.get(websocket)
+ if sub is None:
+ return True # No filters = send all
+
+ event_type = event.get("event", "")
+
+ # Filter by event type
+ if sub.events and event_type not in sub.events:
+ return False
+
+ # Filter by thread
+ thread_id = event.get("thread_id") or event.get("threadId")
+ if sub.threads and thread_id and thread_id not in sub.threads:
+ return False
+
+ # Filter by agent
+ agent = event.get("agent")
+ if sub.agents and agent and agent not in sub.agents:
+ return False
+
+ # For message events, check from/to
+ if event_type == "message":
+ msg = event.get("message", {})
+ from_id = msg.get("from") or msg.get("fromId")
+ to_id = msg.get("to") or msg.get("toId")
+ if sub.agents:
+ if from_id not in sub.agents and to_id not in sub.agents:
+ return False
+
+ # Filter by payload type
+ if sub.payload_types:
+ payload_type = None
+ if event_type == "message":
+ payload_type = event.get("message", {}).get("payloadType")
+ if payload_type and payload_type not in sub.payload_types:
+ return False
+
+ return True
+
+ async def broadcast(self, event: Dict[str, Any]) -> None:
+ """Broadcast an event to all connections that match their subscription."""
+ disconnected: List[WebSocket] = []
+
+ for websocket in self.active_connections:
+ if self.should_send(websocket, event):
+ try:
+ await websocket.send_json(event)
+ except Exception:
+ disconnected.append(websocket)
+
+ # Clean up disconnected clients
+ for ws in disconnected:
+ self.disconnect(ws)
+
+
+class MessageStreamManager:
+ """Manages WebSocket connections for /ws/messages endpoint."""
+
+ def __init__(self) -> None:
+ self.active_connections: Set[WebSocket] = set()
+ self.filters: Dict[WebSocket, Dict[str, Any]] = {}
+
+ async def connect(self, websocket: WebSocket) -> None:
+ """Accept a new WebSocket connection."""
+ await websocket.accept()
+ self.active_connections.add(websocket)
+ self.filters[websocket] = {} # Default: all messages
+
+ def disconnect(self, websocket: WebSocket) -> None:
+ """Remove a WebSocket connection."""
+ self.active_connections.discard(websocket)
+ self.filters.pop(websocket, None)
+
+ def set_filter(self, websocket: WebSocket, filter_config: Dict[str, Any]) -> None:
+ """Update message filter for a connection."""
+ self.filters[websocket] = filter_config
+
+ def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
+ """Check if a message should be sent to this connection."""
+ flt = self.filters.get(websocket, {})
+ if not flt:
+ return True
+
+ # Filter by agents
+ agents = flt.get("agents", [])
+ if agents:
+ from_id = message.get("from") or message.get("fromId")
+ to_id = message.get("to") or message.get("toId")
+ if from_id not in agents and to_id not in agents:
+ return False
+
+ # Filter by threads
+ threads = flt.get("threads", [])
+ if threads:
+ thread_id = message.get("thread_id") or message.get("threadId")
+ if thread_id not in threads:
+ return False
+
+ # Filter by payload types
+ payload_types = flt.get("payload_types", [])
+ if payload_types:
+ payload_type = message.get("payloadType") or message.get("payload_type")
+ if payload_type not in payload_types:
+ return False
+
+ return True
+
+ async def broadcast_message(self, message: Dict[str, Any]) -> None:
+ """Broadcast a message to all filtered connections."""
+ disconnected: List[WebSocket] = []
+
+ for websocket in self.active_connections:
+ if self.should_send(websocket, message):
+ try:
+ await websocket.send_json(message)
+ except Exception:
+ disconnected.append(websocket)
+
+ for ws in disconnected:
+ self.disconnect(ws)
+
+
+def create_websocket_router(state: "ServerState") -> APIRouter:
+ """Create WebSocket router with state dependency."""
+ router = APIRouter()
+ manager = ConnectionManager()
+ message_manager = MessageStreamManager()
+
+ # Subscribe state to WebSocket broadcasting
+ async def on_state_event(event: Dict[str, Any]) -> None:
+ """Forward state events to WebSocket connections."""
+ await manager.broadcast(event)
+
+ # Also forward message events to the message stream
+ if event.get("event") == "message":
+ msg = event.get("message", {})
+ await message_manager.broadcast_message(msg)
+
+ state.subscribe(on_state_event)
+
+ @router.websocket("/ws")
+ async def websocket_control(websocket: WebSocket) -> None:
+ """
+ Main control channel WebSocket.
+
+ On connect, sends full state snapshot. Then pushes events as they occur.
+ Accepts commands: subscribe, inject.
+ """
+ await manager.connect(websocket)
+
+ try:
+ # Send connected event with state snapshot
+ threads, _ = state.get_threads(limit=100)
+ connected_event = WSConnectedEvent(
+ organism=state.get_organism_info(),
+ agents=state.get_agents(),
+ threads=threads,
+ )
+ await websocket.send_json(connected_event.model_dump(by_alias=True))
+
+ # Listen for commands
+ while True:
+ try:
+ data = await websocket.receive_json()
+ except Exception:
+ # Connection closed or invalid data
+ break
+
+ cmd = data.get("cmd", "")
+
+ if cmd == "subscribe":
+ # Update subscription filters
+ sub = SubscribeRequest(
+ threads=data.get("threads"),
+ agents=data.get("agents"),
+ payload_types=data.get("payload_types"),
+ events=data.get("events"),
+ )
+ manager.set_subscription(websocket, sub)
+ await websocket.send_json({"event": "subscribed", "filters": data})
+
+ elif cmd == "inject":
+ # Inject a message (same as REST /inject)
+ target = data.get("to")
+ payload = data.get("payload", {})
+ thread_id = data.get("thread_id")
+
+ if not target:
+ await websocket.send_json(
+ {"event": "error", "error": "Missing 'to' field"}
+ )
+ continue
+
+ agent = state.get_agent(target)
+ if agent is None:
+ await websocket.send_json(
+ {"event": "error", "error": f"Unknown agent: {target}"}
+ )
+ continue
+
+ import uuid
+
+ thread_id = thread_id or str(uuid.uuid4())
+ payload_type = next(iter(payload.keys()), "Payload")
+
+ msg_id = await state.record_message(
+ thread_id=thread_id,
+ from_id="api",
+ to_id=target,
+ payload_type=payload_type,
+ payload=payload,
+ )
+
+ await websocket.send_json(
+ {
+ "event": "injected",
+ "thread_id": thread_id,
+ "message_id": msg_id,
+ }
+ )
+
+ else:
+ await websocket.send_json(
+ {"event": "error", "error": f"Unknown command: {cmd}"}
+ )
+
+ except WebSocketDisconnect:
+ pass
+ finally:
+ manager.disconnect(websocket)
+
+ @router.websocket("/ws/messages")
+ async def websocket_messages(websocket: WebSocket) -> None:
+ """
+ Dedicated message log stream.
+
+ Streams all messages flowing through the organism.
+ Clients can filter by agents, threads, and payload types.
+ """
+ await message_manager.connect(websocket)
+
+ try:
+ while True:
+ try:
+ data = await websocket.receive_json()
+ except Exception:
+ break
+
+ cmd = data.get("cmd", "")
+
+ if cmd == "subscribe":
+ # Update filter
+ filter_config = data.get("filter", {})
+ message_manager.set_filter(websocket, filter_config)
+ await websocket.send_json({"event": "subscribed", "filter": filter_config})
+
+ else:
+ await websocket.send_json(
+ {"event": "error", "error": f"Unknown command: {cmd}"}
+ )
+
+ except WebSocketDisconnect:
+ pass
+ finally:
+ message_manager.disconnect(websocket)
+
+ return router