From bf31b0d14e7a677a2cbea6afe5168d514944e670 Mon Sep 17 00:00:00 2001 From: dullfig Date: Tue, 27 Jan 2026 20:22:58 -0800 Subject: [PATCH] Add AgentServer REST/WebSocket API Implements the AgentServer API from docs/agentserver_api_spec.md: REST API (/api/v1): - Organism info and config endpoints - Agent listing, details, config, schema - Thread and message history with filtering - Control endpoints (inject, pause, resume, kill, stop) WebSocket: - /ws: Main control channel with state snapshot + real-time events - /ws/messages: Dedicated message stream with filtering Infrastructure: - Pydantic models with camelCase serialization - ServerState bridges StreamPump to API - Pump event hooks for real-time updates - CLI 'serve' command: xml-pipeline serve [config] --port 8080 35 new tests for models, state, REST, and WebSocket. Co-Authored-By: Claude Opus 4.5 --- pyproject.toml | 9 +- tests/test_server.py | 510 +++++++++++++++++++++++ xml_pipeline/cli.py | 70 ++++ xml_pipeline/message_bus/__init__.py | 12 + xml_pipeline/message_bus/stream_pump.py | 118 ++++++ xml_pipeline/server/__init__.py | 26 ++ xml_pipeline/server/api.py | 273 +++++++++++++ xml_pipeline/server/app.py | 148 +++++++ xml_pipeline/server/models.py | 261 ++++++++++++ xml_pipeline/server/state.py | 515 ++++++++++++++++++++++++ xml_pipeline/server/websocket.py | 316 +++++++++++++++ 11 files changed, 2257 insertions(+), 1 deletion(-) create mode 100644 tests/test_server.py create mode 100644 xml_pipeline/server/__init__.py create mode 100644 xml_pipeline/server/api.py create mode 100644 xml_pipeline/server/app.py create mode 100644 xml_pipeline/server/models.py create mode 100644 xml_pipeline/server/state.py create mode 100644 xml_pipeline/server/websocket.py diff --git a/pyproject.toml b/pyproject.toml index 7fc5ff4..ffe4d60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,13 @@ search = ["duckduckgo-search>=6.0"] # Web search tool # Console example (optional, for interactive use) console = 
["prompt_toolkit>=3.0"] +# API server (FastAPI + WebSocket) +server = [ + "fastapi>=0.109", + "uvicorn[standard]>=0.27", + "websockets>=12.0", +] + # All LLM providers llm = ["xml-pipeline[anthropic,openai]"] @@ -87,7 +94,7 @@ llm = ["xml-pipeline[anthropic,openai]"] tools = ["xml-pipeline[redis,search]"] # Everything (for local development) -all = ["xml-pipeline[llm,tools,console]"] +all = ["xml-pipeline[llm,tools,console,server]"] # Testing test = [ diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..7b18476 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,510 @@ +""" +Tests for the AgentServer API. + +Tests the REST API endpoints and WebSocket connections. +""" + +import asyncio +import pytest +from dataclasses import dataclass +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, AsyncMock, patch + +# Skip all tests if FastAPI not available +pytest.importorskip("fastapi") + +from fastapi.testclient import TestClient + +from xml_pipeline.server.models import ( + AgentState, + AgentInfo, + ThreadStatus, + ThreadInfo, + MessageInfo, + OrganismInfo, + OrganismStatus, +) +from xml_pipeline.server.state import ServerState, AgentRuntimeState, ThreadRuntimeState +from xml_pipeline.server.api import create_router +from xml_pipeline.server.app import create_app + + +# ============================================================================ +# Mock StreamPump +# ============================================================================ + +@dataclass +class MockListener: + """Mock listener for testing.""" + name: str + description: str + is_agent: bool = False + peers: List[str] = None + payload_class: type = None + schema: Any = None + + def __post_init__(self): + if self.peers is None: + self.peers = [] + if self.payload_class is None: + self.payload_class = type("MockPayload", (), {"__module__": "test", "__name__": "MockPayload"}) + + +@dataclass +class MockConfig: + """Mock config for 
testing.""" + name: str = "test-organism" + port: int = 8765 + thread_scheduling: str = "breadth-first" + max_concurrent_pipelines: int = 50 + max_concurrent_handlers: int = 20 + + +class MockStreamPump: + """Mock StreamPump for testing.""" + + def __init__(self): + self.config = MockConfig() + self.identity = None + self.listeners: Dict[str, MockListener] = { + "greeter": MockListener( + name="greeter", + description="Greeting agent", + is_agent=True, + peers=["shouter"], + ), + "shouter": MockListener( + name="shouter", + description="Shouting handler", + is_agent=False, + ), + } + self._event_callbacks = [] + + def subscribe_events(self, callback): + self._event_callbacks.append(callback) + + def unsubscribe_events(self, callback): + if callback in self._event_callbacks: + self._event_callbacks.remove(callback) + + +# ============================================================================ +# Test Fixtures +# ============================================================================ + +@pytest.fixture +def mock_pump(): + """Create a mock StreamPump.""" + return MockStreamPump() + + +@pytest.fixture +def server_state(mock_pump): + """Create ServerState with mock pump.""" + return ServerState(mock_pump) + + +@pytest.fixture +def test_client(mock_pump): + """Create FastAPI test client.""" + app = create_app(mock_pump) + return TestClient(app) + + +# ============================================================================ +# Test Models +# ============================================================================ + +class TestModels: + """Test Pydantic model serialization.""" + + def test_agent_info_camel_case(self): + """Test AgentInfo serializes to camelCase.""" + agent = AgentInfo( + name="greeter", + description="Test agent", + is_agent=True, + peers=["shouter"], + payload_class="test.Greeting", + state=AgentState.IDLE, + current_thread=None, + queue_depth=0, + message_count=5, + ) + data = agent.model_dump(by_alias=True) + assert "isAgent" in data + 
assert "payloadClass" in data + assert "currentThread" in data + assert "queueDepth" in data + assert "messageCount" in data + + def test_thread_info_camel_case(self): + """Test ThreadInfo serializes to camelCase.""" + from datetime import datetime, timezone + thread = ThreadInfo( + id="test-uuid", + status=ThreadStatus.ACTIVE, + participants=["greeter", "shouter"], + message_count=3, + created_at=datetime.now(timezone.utc), + ) + data = thread.model_dump(by_alias=True) + assert "messageCount" in data + assert "createdAt" in data + assert "lastActivity" in data + + def test_organism_info_camel_case(self): + """Test OrganismInfo serializes to camelCase.""" + info = OrganismInfo( + name="test-organism", + status=OrganismStatus.RUNNING, + uptime_seconds=3600.0, + agent_count=2, + active_threads=1, + total_messages=10, + identity_configured=False, + ) + data = info.model_dump(by_alias=True) + assert "uptimeSeconds" in data + assert "agentCount" in data + assert "activeThreads" in data + assert "totalMessages" in data + assert "identityConfigured" in data + + +# ============================================================================ +# Test ServerState +# ============================================================================ + +class TestServerState: + """Test ServerState functionality.""" + + def test_init_agents_from_pump(self, server_state): + """Test agents are initialized from pump listeners.""" + assert "greeter" in server_state._agents + assert "shouter" in server_state._agents + assert server_state._agents["greeter"].is_agent is True + assert server_state._agents["shouter"].is_agent is False + + def test_get_organism_info(self, server_state): + """Test getting organism info.""" + info = server_state.get_organism_info() + assert info.name == "test-organism" + assert info.agent_count == 2 + assert info.status == OrganismStatus.STARTING + + def test_set_running(self, server_state): + """Test setting running status.""" + server_state.set_running() + info 
= server_state.get_organism_info() + assert info.status == OrganismStatus.RUNNING + + def test_get_agents(self, server_state): + """Test getting all agents.""" + agents = server_state.get_agents() + assert len(agents) == 2 + names = [a.name for a in agents] + assert "greeter" in names + assert "shouter" in names + + def test_get_agent(self, server_state): + """Test getting single agent.""" + agent = server_state.get_agent("greeter") + assert agent is not None + assert agent.name == "greeter" + assert agent.is_agent is True + + def test_get_agent_not_found(self, server_state): + """Test getting non-existent agent.""" + agent = server_state.get_agent("nonexistent") + assert agent is None + + @pytest.mark.asyncio + async def test_record_message(self, server_state): + """Test recording a message.""" + msg_id = await server_state.record_message( + thread_id="test-thread", + from_id="greeter", + to_id="shouter", + payload_type="GreetingResponse", + payload={"message": "Hello!"}, + ) + assert msg_id is not None + + # Check message was recorded + messages, total = server_state.get_messages(thread_id="test-thread") + assert total == 1 + assert messages[0].from_id == "greeter" + assert messages[0].to_id == "shouter" + + @pytest.mark.asyncio + async def test_update_agent_state(self, server_state): + """Test updating agent state.""" + await server_state.update_agent_state("greeter", AgentState.PROCESSING, "thread-1") + + agent = server_state.get_agent("greeter") + assert agent.state == AgentState.PROCESSING + assert agent.current_thread == "thread-1" + + @pytest.mark.asyncio + async def test_complete_thread(self, server_state): + """Test completing a thread.""" + # First record a message to create the thread + await server_state.record_message( + thread_id="test-thread", + from_id="greeter", + to_id="shouter", + payload_type="Test", + payload={}, + ) + + # Complete the thread + await server_state.complete_thread("test-thread", ThreadStatus.COMPLETED) + + thread = 
server_state.get_thread("test-thread") + assert thread.status == ThreadStatus.COMPLETED + + def test_get_threads_with_filter(self, server_state): + """Test filtering threads by status.""" + # No threads initially + threads, total = server_state.get_threads(status=ThreadStatus.ACTIVE) + assert total == 0 + + def test_get_organism_config(self, server_state): + """Test getting organism config.""" + config = server_state.get_organism_config() + assert config["name"] == "test-organism" + assert config["port"] == 8765 + + +# ============================================================================ +# Test REST API +# ============================================================================ + +class TestRestAPI: + """Test REST API endpoints.""" + + def test_health_check(self, test_client): + """Test health check endpoint.""" + response = test_client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert "organism" in data + + def test_get_organism(self, test_client): + """Test GET /api/v1/organism.""" + response = test_client.get("/api/v1/organism") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "test-organism" + assert "agentCount" in data + + def test_get_organism_config(self, test_client): + """Test GET /api/v1/organism/config.""" + response = test_client.get("/api/v1/organism/config") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "test-organism" + + def test_list_agents(self, test_client): + """Test GET /api/v1/agents.""" + response = test_client.get("/api/v1/agents") + assert response.status_code == 200 + data = response.json() + assert "agents" in data + assert data["count"] == 2 + + def test_get_agent(self, test_client): + """Test GET /api/v1/agents/{name}.""" + response = test_client.get("/api/v1/agents/greeter") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "greeter" + 
assert data["isAgent"] is True + + def test_get_agent_not_found(self, test_client): + """Test GET /api/v1/agents/{name} with non-existent agent.""" + response = test_client.get("/api/v1/agents/nonexistent") + assert response.status_code == 404 + + def test_get_agent_config(self, test_client): + """Test GET /api/v1/agents/{name}/config.""" + response = test_client.get("/api/v1/agents/greeter/config") + assert response.status_code == 200 + data = response.json() + assert data["name"] == "greeter" + assert "isAgent" in data + + def test_list_threads(self, test_client): + """Test GET /api/v1/threads.""" + response = test_client.get("/api/v1/threads") + assert response.status_code == 200 + data = response.json() + assert "threads" in data + assert "count" in data + assert "total" in data + + def test_list_threads_with_invalid_status(self, test_client): + """Test GET /api/v1/threads with invalid status filter.""" + response = test_client.get("/api/v1/threads?status=invalid") + assert response.status_code == 400 + + def test_list_messages(self, test_client): + """Test GET /api/v1/messages.""" + response = test_client.get("/api/v1/messages") + assert response.status_code == 200 + data = response.json() + assert "messages" in data + assert "count" in data + + def test_inject_message(self, test_client): + """Test POST /api/v1/inject.""" + response = test_client.post( + "/api/v1/inject", + json={ + "to": "greeter", + "payload": {"name": "Dan"}, + }, + ) + assert response.status_code == 200 + data = response.json() + assert "threadId" in data + assert "messageId" in data + + def test_inject_message_unknown_agent(self, test_client): + """Test POST /api/v1/inject with unknown agent.""" + response = test_client.post( + "/api/v1/inject", + json={ + "to": "nonexistent", + "payload": {"name": "Dan"}, + }, + ) + assert response.status_code == 400 + + def test_pause_agent(self, test_client): + """Test POST /api/v1/agents/{name}/pause.""" + response = 
test_client.post("/api/v1/agents/greeter/pause") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["state"] == "paused" + + def test_resume_agent(self, test_client): + """Test POST /api/v1/agents/{name}/resume.""" + response = test_client.post("/api/v1/agents/greeter/resume") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert data["state"] == "idle" + + def test_stop_organism(self, test_client): + """Test POST /api/v1/organism/stop.""" + response = test_client.post("/api/v1/organism/stop") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + + +# ============================================================================ +# Test WebSocket +# ============================================================================ + +class TestWebSocket: + """Test WebSocket endpoints.""" + + def test_websocket_connect(self, test_client): + """Test WebSocket connection.""" + with test_client.websocket_connect("/ws") as websocket: + # Should receive connected event with state snapshot + data = websocket.receive_json() + assert data["event"] == "connected" + assert "organism" in data + assert "agents" in data + assert "threads" in data + + def test_websocket_subscribe(self, test_client): + """Test WebSocket subscribe command.""" + with test_client.websocket_connect("/ws") as websocket: + # Receive initial connected event + websocket.receive_json() + + # Send subscribe command + websocket.send_json({ + "cmd": "subscribe", + "agents": ["greeter"], + "events": ["message"], + }) + + # Should receive subscribed confirmation + data = websocket.receive_json() + assert data["event"] == "subscribed" + + def test_websocket_inject(self, test_client): + """Test WebSocket inject command.""" + with test_client.websocket_connect("/ws") as websocket: + # Receive initial connected event + websocket.receive_json() + + # Send inject command + 
websocket.send_json({ + "cmd": "inject", + "to": "greeter", + "payload": {"name": "Dan"}, + }) + + # Should receive injected confirmation + data = websocket.receive_json() + assert data["event"] == "injected" + assert "thread_id" in data + assert "message_id" in data + + def test_websocket_inject_unknown_agent(self, test_client): + """Test WebSocket inject with unknown agent.""" + with test_client.websocket_connect("/ws") as websocket: + # Receive initial connected event + websocket.receive_json() + + # Send inject command to unknown agent + websocket.send_json({ + "cmd": "inject", + "to": "nonexistent", + "payload": {}, + }) + + # Should receive error + data = websocket.receive_json() + assert data["event"] == "error" + + def test_websocket_unknown_command(self, test_client): + """Test WebSocket with unknown command.""" + with test_client.websocket_connect("/ws") as websocket: + # Receive initial connected event + websocket.receive_json() + + # Send unknown command + websocket.send_json({ + "cmd": "unknown_command", + }) + + # Should receive error + data = websocket.receive_json() + assert data["event"] == "error" + + def test_websocket_messages_stream(self, test_client): + """Test WebSocket messages stream endpoint.""" + with test_client.websocket_connect("/ws/messages") as websocket: + # Send subscribe command + websocket.send_json({ + "cmd": "subscribe", + "filter": { + "agents": ["greeter"], + }, + }) + + # Should receive subscribed confirmation + data = websocket.receive_json() + assert data["event"] == "subscribed" + assert "filter" in data diff --git a/xml_pipeline/cli.py b/xml_pipeline/cli.py index 53497c0..5dd4b9d 100644 --- a/xml_pipeline/cli.py +++ b/xml_pipeline/cli.py @@ -3,6 +3,7 @@ xml-pipeline CLI entry point. 
Usage: xml-pipeline run [config.yaml] Run an organism + xml-pipeline serve [config.yaml] Run organism with API server xml-pipeline init [name] Create new organism config xml-pipeline check [config.yaml] Validate config without running xml-pipeline version Show version info @@ -37,6 +38,68 @@ def cmd_run(args: argparse.Namespace) -> int: return 1 +def cmd_serve(args: argparse.Namespace) -> int: + """Run an organism with the AgentServer API.""" + try: + import uvicorn + except ImportError: + print("Error: uvicorn not installed.", file=sys.stderr) + print("Install with: pip install xml-pipeline[server]", file=sys.stderr) + return 1 + + from xml_pipeline.message_bus import bootstrap + + config_path = Path(args.config) + if not config_path.exists(): + print(f"Error: Config file not found: {config_path}", file=sys.stderr) + return 1 + + async def run_with_server(): + """Bootstrap pump and run with server.""" + from xml_pipeline.server import create_app + + # Bootstrap the pump + pump = await bootstrap(str(config_path)) + + # Create FastAPI app + app = create_app(pump) + + # Run uvicorn + config = uvicorn.Config( + app, + host=args.host, + port=args.port, + log_level="info", + ) + server = uvicorn.Server(config) + + # Run pump and server concurrently + pump_task = asyncio.create_task(pump.run()) + + try: + await server.serve() + finally: + await pump.shutdown() + pump_task.cancel() + try: + await pump_task + except asyncio.CancelledError: + pass + + try: + print(f"Starting AgentServer on http://{args.host}:{args.port}") + print(f" API docs: http://{args.host}:{args.port}/docs") + print(f" WebSocket: ws://{args.host}:{args.port}/ws") + asyncio.run(run_with_server()) + return 0 + except KeyboardInterrupt: + print("\nShutdown requested.") + return 0 + except Exception as e: + print(f"Error: {e}", file=sys.stderr) + return 1 + + def cmd_init(args: argparse.Namespace) -> int: """Initialize a new organism config.""" from xml_pipeline.config.template import 
create_organism_template @@ -149,6 +212,13 @@ def main() -> int: run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file") run_parser.set_defaults(func=cmd_run) + # serve + serve_parser = subparsers.add_parser("serve", help="Run organism with API server") + serve_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file") + serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind (default: 0.0.0.0)") + serve_parser.add_argument("--port", "-p", type=int, default=8080, help="Port to listen on (default: 8080)") + serve_parser.set_defaults(func=cmd_serve) + # init init_parser = subparsers.add_parser("init", help="Create new organism config") init_parser.add_argument("name", nargs="?", help="Organism name") diff --git a/xml_pipeline/message_bus/__init__.py b/xml_pipeline/message_bus/__init__.py index 8aaffa8..51fe135 100644 --- a/xml_pipeline/message_bus/__init__.py +++ b/xml_pipeline/message_bus/__init__.py @@ -33,6 +33,12 @@ from xml_pipeline.message_bus.stream_pump import ( get_stream_pump, set_stream_pump, reset_stream_pump, + # Event hooks + PumpEvent, + MessageReceivedEvent, + MessageSentEvent, + AgentStateEvent, + ThreadEvent, ) from xml_pipeline.message_bus.message_state import ( @@ -71,6 +77,12 @@ __all__ = [ "get_stream_pump", "set_stream_pump", "reset_stream_pump", + # Event hooks + "PumpEvent", + "MessageReceivedEvent", + "MessageSentEvent", + "AgentStateEvent", + "ThreadEvent", # Message state "MessageState", "HandlerMetadata", diff --git a/xml_pipeline/message_bus/stream_pump.py b/xml_pipeline/message_bus/stream_pump.py index 1e7adef..039da62 100644 --- a/xml_pipeline/message_bus/stream_pump.py +++ b/xml_pipeline/message_bus/stream_pump.py @@ -49,6 +49,56 @@ from xml_pipeline.memory import get_context_buffer pump_logger = logging.getLogger(__name__) +# ============================================================================ +# Event Hooks +# 
============================================================================ + +@dataclass +class PumpEvent: + """Base class for pump events.""" + pass + + +@dataclass +class MessageReceivedEvent(PumpEvent): + """Fired when a message is received by a handler.""" + thread_id: str + from_id: str + to_id: str + payload_type: str + payload: Any + + +@dataclass +class MessageSentEvent(PumpEvent): + """Fired when a handler sends a response.""" + thread_id: str + from_id: str + to_id: str + payload_type: str + payload: Any + + +@dataclass +class AgentStateEvent(PumpEvent): + """Fired when an agent's processing state changes.""" + agent_name: str + state: str # "idle", "processing", "waiting", "error" + thread_id: Optional[str] = None + + +@dataclass +class ThreadEvent(PumpEvent): + """Fired when a thread is created or completed.""" + thread_id: str + status: str # "created", "active", "completed", "error", "killed" + participants: List[str] = field(default_factory=list) + error: Optional[str] = None + + +EventCallback = Callable[[PumpEvent], None] + + # ============================================================================ # Configuration (same as before) # ============================================================================ @@ -233,6 +283,9 @@ class StreamPump: # Shutdown control self._running = False + # Event hooks for external observers (ServerState, etc.) 
+ self._event_callbacks: List[EventCallback] = [] + # Process pool for cpu_bound handlers self._process_pool: Optional[ProcessPoolExecutor] = None if config.process_pool_enabled: @@ -256,6 +309,27 @@ class StreamPump: self._shared_backend = get_shared_backend(backend_config) pump_logger.info(f"Shared backend: {config.backend_type}") + # ------------------------------------------------------------------ + # Event Hooks + # ------------------------------------------------------------------ + + def subscribe_events(self, callback: EventCallback) -> None: + """Subscribe to pump events (message flow, agent state, thread lifecycle).""" + self._event_callbacks.append(callback) + + def unsubscribe_events(self, callback: EventCallback) -> None: + """Unsubscribe from pump events.""" + if callback in self._event_callbacks: + self._event_callbacks.remove(callback) + + def _emit_event(self, event: PumpEvent) -> None: + """Emit an event to all subscribers (non-blocking).""" + for callback in self._event_callbacks: + try: + callback(event) + except Exception as e: + pump_logger.warning(f"Event callback error: {e}") + # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ @@ -493,6 +567,12 @@ class StreamPump: await semaphore.acquire() try: + # Emit agent state change event + self._emit_event(AgentStateEvent( + agent_name=listener.name, + state="processing", + thread_id=state.thread_id, + )) # Ensure we have a valid thread chain registry = get_registry() todo_registry = get_todo_registry() @@ -566,6 +646,15 @@ class StreamPump: ) payload_ref = state.payload + # Emit message received event + self._emit_event(MessageReceivedEvent( + thread_id=current_thread, + from_id=state.from_id or "", + to_id=listener.name, + payload_type=type(payload_ref).__name__, + payload=payload_ref, + )) + # Dispatch to handler - either in-process or via ProcessPool if listener.cpu_bound and self._process_pool 
and self._shared_backend: response = await self._dispatch_to_process_pool( @@ -578,6 +667,12 @@ class StreamPump: # None means "no response needed" - don't re-inject if response is None: + # Emit idle state + self._emit_event(AgentStateEvent( + agent_name=listener.name, + state="idle", + thread_id=current_thread, + )) continue # Handle clean HandlerResponse (preferred) @@ -653,6 +748,23 @@ class StreamPump: response_bytes = b"Handler returned invalid type" thread_id = state.thread_id + # Emit message sent event + if isinstance(response, HandlerResponse): + self._emit_event(MessageSentEvent( + thread_id=thread_id, + from_id=listener.name, + to_id=to_id, + payload_type=type(response.payload).__name__, + payload=response.payload, + )) + + # Emit agent state back to idle + self._emit_event(AgentStateEvent( + agent_name=listener.name, + state="idle", + thread_id=None, + )) + # Yield response — will be processed by next iteration yield MessageState( raw_bytes=response_bytes, @@ -665,6 +777,12 @@ class StreamPump: semaphore.release() except Exception as exc: + # Emit error state + self._emit_event(AgentStateEvent( + agent_name=listener.name, + state="error", + thread_id=state.thread_id, + )) yield MessageState( raw_bytes=f"Handler {listener.name} crashed: {exc}".encode(), thread_id=state.thread_id, diff --git a/xml_pipeline/server/__init__.py b/xml_pipeline/server/__init__.py new file mode 100644 index 0000000..62dfc58 --- /dev/null +++ b/xml_pipeline/server/__init__.py @@ -0,0 +1,26 @@ +""" +server — FastAPI-based AgentServer API for monitoring and controlling organisms. 
+ +Provides: +- REST API for querying organism state (agents, threads, messages) +- WebSocket for real-time events +- Message injection endpoint + +Usage: + from xml_pipeline.server import create_app, run_server + + # With existing pump + app = create_app(pump) + uvicorn.run(app, host="0.0.0.0", port=8080) + + # Or use CLI + xml-pipeline serve config/organism.yaml --port 8080 +""" + +from xml_pipeline.server.app import create_app, run_server, run_server_sync + +__all__ = [ + "create_app", + "run_server", + "run_server_sync", +] diff --git a/xml_pipeline/server/api.py b/xml_pipeline/server/api.py new file mode 100644 index 0000000..fc60399 --- /dev/null +++ b/xml_pipeline/server/api.py @@ -0,0 +1,273 @@ +""" +api.py — REST API routes for AgentServer. + +Provides endpoints for: +- Organism info and config +- Agent listing and details +- Thread listing and management +- Message injection +""" + +from __future__ import annotations + +import uuid +from typing import TYPE_CHECKING, Optional + +from fastapi import APIRouter, HTTPException, Query + +from xml_pipeline.server.models import ( + AgentInfo, + AgentListResponse, + ErrorResponse, + InjectRequest, + InjectResponse, + MessageListResponse, + OrganismInfo, + ThreadInfo, + ThreadListResponse, + ThreadStatus, +) + +if TYPE_CHECKING: + from xml_pipeline.server.state import ServerState + + +def create_router(state: "ServerState") -> APIRouter: + """Create API router with state dependency.""" + router = APIRouter(prefix="/api/v1") + + # ========================================================================= + # Organism Endpoints + # ========================================================================= + + @router.get("/organism", response_model=OrganismInfo) + async def get_organism() -> OrganismInfo: + """Get organism overview and stats.""" + return state.get_organism_info() + + @router.get("/organism/config") + async def get_organism_config() -> dict: + """Get sanitized organism configuration (no secrets).""" + 
return state.get_organism_config() + + # ========================================================================= + # Agent Endpoints + # ========================================================================= + + @router.get("/agents", response_model=AgentListResponse) + async def list_agents() -> AgentListResponse: + """List all agents with current state.""" + agents = state.get_agents() + return AgentListResponse(agents=agents, count=len(agents)) + + @router.get("/agents/{name}", response_model=AgentInfo) + async def get_agent(name: str) -> AgentInfo: + """Get single agent details.""" + agent = state.get_agent(name) + if agent is None: + raise HTTPException(status_code=404, detail=f"Agent not found: {name}") + return agent + + @router.get("/agents/{name}/config") + async def get_agent_config(name: str) -> dict: + """Get agent's YAML config section.""" + agent = state.get_agent(name) + if agent is None: + raise HTTPException(status_code=404, detail=f"Agent not found: {name}") + + # Return relevant config fields + return { + "name": agent.name, + "description": agent.description, + "isAgent": agent.is_agent, + "peers": agent.peers, + "payloadClass": agent.payload_class, + } + + @router.get("/agents/{name}/schema") + async def get_agent_schema(name: str) -> dict: + """Get agent's payload XML schema.""" + schema = state.get_agent_schema(name) + if schema is None: + raise HTTPException( + status_code=404, + detail=f"Schema not found for agent: {name}", + ) + return {"schema": schema, "contentType": "application/xml"} + + # ========================================================================= + # Thread Endpoints + # ========================================================================= + + @router.get("/threads", response_model=ThreadListResponse) + async def list_threads( + status: Optional[str] = Query(None, description="Filter by status"), + agent: Optional[str] = Query(None, description="Filter by participant agent"), + limit: int = Query(50, ge=1, 
le=100), + offset: int = Query(0, ge=0), + ) -> ThreadListResponse: + """List threads with optional filtering.""" + thread_status = None + if status: + try: + thread_status = ThreadStatus(status) + except ValueError: + raise HTTPException( + status_code=400, + detail=f"Invalid status: {status}. Valid values: {[s.value for s in ThreadStatus]}", + ) + + threads, total = state.get_threads( + status=thread_status, + agent=agent, + limit=limit, + offset=offset, + ) + return ThreadListResponse( + threads=threads, + count=len(threads), + total=total, + offset=offset, + limit=limit, + ) + + @router.get("/threads/{thread_id}", response_model=ThreadInfo) + async def get_thread(thread_id: str) -> ThreadInfo: + """Get thread details with message history.""" + thread = state.get_thread(thread_id) + if thread is None: + raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}") + return thread + + @router.get("/threads/{thread_id}/messages", response_model=MessageListResponse) + async def get_thread_messages( + thread_id: str, + limit: int = Query(50, ge=1, le=100), + offset: int = Query(0, ge=0), + ) -> MessageListResponse: + """Get messages in a specific thread.""" + thread = state.get_thread(thread_id) + if thread is None: + raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}") + + messages, total = state.get_messages( + thread_id=thread_id, + limit=limit, + offset=offset, + ) + return MessageListResponse( + messages=messages, + count=len(messages), + total=total, + offset=offset, + limit=limit, + ) + + @router.post("/threads/{thread_id}/kill") + async def kill_thread(thread_id: str) -> dict: + """Terminate a thread.""" + thread = state.get_thread(thread_id) + if thread is None: + raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}") + + await state.complete_thread(thread_id, status=ThreadStatus.KILLED) + return {"success": True, "threadId": thread_id} + + # 
========================================================================= + # Message Endpoints + # ========================================================================= + + @router.get("/messages", response_model=MessageListResponse) + async def list_messages( + agent: Optional[str] = Query(None, description="Filter by agent (sender or receiver)"), + limit: int = Query(50, ge=1, le=100), + offset: int = Query(0, ge=0), + ) -> MessageListResponse: + """Get global message history.""" + messages, total = state.get_messages( + agent=agent, + limit=limit, + offset=offset, + ) + return MessageListResponse( + messages=messages, + count=len(messages), + total=total, + offset=offset, + limit=limit, + ) + + # ========================================================================= + # Control Endpoints + # ========================================================================= + + @router.post("/inject", response_model=InjectResponse) + async def inject_message(request: InjectRequest) -> InjectResponse: + """Inject a message to an agent.""" + # Validate target exists + agent = state.get_agent(request.to) + if agent is None: + raise HTTPException( + status_code=400, + detail=f"Unknown target agent: {request.to}", + ) + + # Generate or use provided thread ID + thread_id = request.thread_id or str(uuid.uuid4()) + + # Build payload XML from dict + # For now, we construct a simple wrapper + payload_type = next(iter(request.payload.keys()), "Payload") + + # Record the message + msg_id = await state.record_message( + thread_id=thread_id, + from_id="api", + to_id=request.to, + payload_type=payload_type, + payload=request.payload, + ) + + # TODO: Actually inject into pump queue + # This requires building an envelope and calling pump.inject() + + return InjectResponse(thread_id=thread_id, message_id=msg_id) + + @router.post("/agents/{name}/pause") + async def pause_agent(name: str) -> dict: + """Pause an agent (stop processing new messages).""" + agent = state.get_agent(name) 
+ if agent is None: + raise HTTPException(status_code=404, detail=f"Agent not found: {name}") + + from xml_pipeline.server.models import AgentState + + await state.update_agent_state(name, AgentState.PAUSED) + return {"success": True, "agent": name, "state": "paused"} + + @router.post("/agents/{name}/resume") + async def resume_agent(name: str) -> dict: + """Resume a paused agent.""" + agent = state.get_agent(name) + if agent is None: + raise HTTPException(status_code=404, detail=f"Agent not found: {name}") + + from xml_pipeline.server.models import AgentState + + await state.update_agent_state(name, AgentState.IDLE) + return {"success": True, "agent": name, "state": "idle"} + + @router.post("/organism/reload") + async def reload_config() -> dict: + """Hot-reload organism configuration.""" + # TODO: Implement hot-reload + return {"success": False, "error": "Hot-reload not yet implemented"} + + @router.post("/organism/stop") + async def stop_organism() -> dict: + """Graceful shutdown.""" + state.set_stopping() + # TODO: Signal pump to stop + return {"success": True, "status": "stopping"} + + return router diff --git a/xml_pipeline/server/app.py b/xml_pipeline/server/app.py new file mode 100644 index 0000000..5da1428 --- /dev/null +++ b/xml_pipeline/server/app.py @@ -0,0 +1,148 @@ +""" +app.py — FastAPI application factory for AgentServer. + +Creates the FastAPI app that combines REST API and WebSocket endpoints. 
+""" + +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from xml_pipeline.server.api import create_router +from xml_pipeline.server.state import ServerState +from xml_pipeline.server.websocket import create_websocket_router + +if TYPE_CHECKING: + from xml_pipeline.message_bus.stream_pump import StreamPump + + +def create_app( + pump: "StreamPump", + *, + title: str = "AgentServer API", + version: str = "1.0.0", + cors_origins: Optional[list[str]] = None, +) -> FastAPI: + """ + Create FastAPI application with REST and WebSocket endpoints. + + Args: + pump: The StreamPump instance to wrap + title: API title for OpenAPI docs + version: API version + cors_origins: List of allowed CORS origins (default: all) + + Returns: + Configured FastAPI application + """ + # Create state manager + state = ServerState(pump) + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Manage app lifecycle - startup and shutdown.""" + # Startup + state.set_running() + yield + # Shutdown + state.set_stopping() + + app = FastAPI( + title=title, + version=version, + description="REST and WebSocket API for monitoring and controlling xml-pipeline organisms.", + lifespan=lifespan, + ) + + # CORS middleware + if cors_origins is None: + cors_origins = ["*"] + + app.add_middleware( + CORSMiddleware, + allow_origins=cors_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + # Include routers + app.include_router(create_router(state)) + app.include_router(create_websocket_router(state)) + + # Store state on app for access if needed + app.state.server_state = state + app.state.pump = pump + + @app.get("/health") + async def health_check() -> dict[str, Any]: + """Health check endpoint.""" + info = state.get_organism_info() + return 
{ + "status": "healthy", + "organism": info.name, + "uptime_seconds": info.uptime_seconds, + } + + return app + + +async def run_server( + pump: "StreamPump", + *, + host: str = "0.0.0.0", + port: int = 8080, + cors_origins: Optional[list[str]] = None, +) -> None: + """ + Run the AgentServer with uvicorn. + + Args: + pump: The StreamPump instance to wrap + host: Host to bind to + port: Port to listen on + cors_origins: List of allowed CORS origins + """ + try: + import uvicorn + except ImportError as e: + raise ImportError( + "uvicorn is required for the server. Install with: pip install xml-pipeline[server]" + ) from e + + app = create_app(pump, cors_origins=cors_origins) + + config = uvicorn.Config( + app, + host=host, + port=port, + log_level="info", + ) + server = uvicorn.Server(config) + await server.serve() + + +def run_server_sync( + pump: "StreamPump", + *, + host: str = "0.0.0.0", + port: int = 8080, + cors_origins: Optional[list[str]] = None, +) -> None: + """ + Run the AgentServer synchronously (blocking). + + This is a convenience wrapper for CLI usage. + + Args: + pump: The StreamPump instance to wrap + host: Host to bind to + port: Port to listen on + cors_origins: List of allowed CORS origins + """ + asyncio.run(run_server(pump, host=host, port=port, cors_origins=cors_origins)) diff --git a/xml_pipeline/server/models.py b/xml_pipeline/server/models.py new file mode 100644 index 0000000..9d3bf56 --- /dev/null +++ b/xml_pipeline/server/models.py @@ -0,0 +1,261 @@ +""" +models.py — Pydantic models for AgentServer API. + +These models define the JSON structure for API responses. +Uses camelCase for JSON keys (JavaScript convention). 
+""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +def to_camel(string: str) -> str: + """Convert snake_case to camelCase.""" + components = string.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +class CamelModel(BaseModel): + """Base model with camelCase JSON serialization.""" + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + +# ============================================================================= +# Enums +# ============================================================================= + + +class AgentState(str, Enum): + """Agent processing state.""" + + IDLE = "idle" + PROCESSING = "processing" + WAITING = "waiting" + ERROR = "error" + PAUSED = "paused" + + +class ThreadStatus(str, Enum): + """Thread lifecycle status.""" + + ACTIVE = "active" + COMPLETED = "completed" + ERROR = "error" + KILLED = "killed" + + +class OrganismStatus(str, Enum): + """Organism running status.""" + + STARTING = "starting" + RUNNING = "running" + STOPPING = "stopping" + STOPPED = "stopped" + + +# ============================================================================= +# API Response Models +# ============================================================================= + + +class AgentInfo(CamelModel): + """Agent information for API response.""" + + name: str + description: str + is_agent: bool = Field(alias="isAgent") + peers: List[str] = Field(default_factory=list) + payload_class: str = Field(alias="payloadClass") + state: AgentState = AgentState.IDLE + current_thread: Optional[str] = Field(None, alias="currentThread") + queue_depth: int = Field(0, alias="queueDepth") + last_activity: Optional[datetime] = Field(None, alias="lastActivity") + message_count: int = Field(0, alias="messageCount") + + +class MessageInfo(CamelModel): + """Message 
information for API response.""" + + id: str + thread_id: str = Field(alias="threadId") + from_id: str = Field(alias="from") + to_id: str = Field(alias="to") + payload_type: str = Field(alias="payloadType") + payload: Dict[str, Any] = Field(default_factory=dict) + timestamp: datetime + slot_index: Optional[int] = Field(None, alias="slotIndex") + + +class ThreadInfo(CamelModel): + """Thread information for API response.""" + + id: str + status: ThreadStatus = ThreadStatus.ACTIVE + participants: List[str] = Field(default_factory=list) + message_count: int = Field(0, alias="messageCount") + created_at: datetime = Field(alias="createdAt") + last_activity: Optional[datetime] = Field(None, alias="lastActivity") + error: Optional[str] = None + + +class OrganismInfo(CamelModel): + """Organism overview for API response.""" + + name: str + status: OrganismStatus = OrganismStatus.RUNNING + uptime_seconds: float = Field(0.0, alias="uptimeSeconds") + agent_count: int = Field(0, alias="agentCount") + active_threads: int = Field(0, alias="activeThreads") + total_messages: int = Field(0, alias="totalMessages") + identity_configured: bool = Field(False, alias="identityConfigured") + + +class OrganismConfig(CamelModel): + """Sanitized organism configuration for API response.""" + + name: str + port: int = 8765 + thread_scheduling: str = Field("breadth-first", alias="threadScheduling") + max_concurrent_pipelines: int = Field(50, alias="maxConcurrentPipelines") + max_concurrent_handlers: int = Field(20, alias="maxConcurrentHandlers") + listeners: List[str] = Field(default_factory=list) + # Note: Secrets like API keys are never exposed + + +# ============================================================================= +# Request Models +# ============================================================================= + + +class InjectRequest(CamelModel): + """Request body for POST /inject.""" + + to: str + payload: Dict[str, Any] + thread_id: Optional[str] = Field(None, 
alias="threadId") + + +class InjectResponse(CamelModel): + """Response body for POST /inject.""" + + thread_id: str = Field(alias="threadId") + message_id: str = Field(alias="messageId") + + +class SubscribeRequest(CamelModel): + """WebSocket subscription filter.""" + + threads: Optional[List[str]] = None + agents: Optional[List[str]] = None + payload_types: Optional[List[str]] = Field(None, alias="payloadTypes") + events: Optional[List[str]] = None + + +# ============================================================================= +# WebSocket Event Models +# ============================================================================= + + +class WSEvent(CamelModel): + """Base WebSocket event.""" + + event: str + + +class WSConnectedEvent(WSEvent): + """Sent on WebSocket connection with full state snapshot.""" + + event: str = "connected" + organism: OrganismInfo + agents: List[AgentInfo] + threads: List[ThreadInfo] + + +class WSAgentStateEvent(WSEvent): + """Agent state changed.""" + + event: str = "agent_state" + agent: str + state: AgentState + current_thread: Optional[str] = Field(None, alias="currentThread") + + +class WSMessageEvent(WSEvent): + """New message in the system.""" + + event: str = "message" + message: MessageInfo + + +class WSThreadCreatedEvent(WSEvent): + """New thread started.""" + + event: str = "thread_created" + thread: ThreadInfo + + +class WSThreadUpdatedEvent(WSEvent): + """Thread status changed.""" + + event: str = "thread_updated" + thread_id: str = Field(alias="threadId") + status: ThreadStatus + message_count: int = Field(alias="messageCount") + + +class WSErrorEvent(WSEvent): + """Error occurred.""" + + event: str = "error" + thread_id: Optional[str] = Field(None, alias="threadId") + agent: Optional[str] = None + error: str + timestamp: datetime + + +# ============================================================================= +# List Response Models +# ============================================================================= 
+ + +class AgentListResponse(CamelModel): + """Response for GET /agents.""" + + agents: List[AgentInfo] + count: int + + +class ThreadListResponse(CamelModel): + """Response for GET /threads.""" + + threads: List[ThreadInfo] + count: int + total: int + offset: int + limit: int + + +class MessageListResponse(CamelModel): + """Response for GET /messages or /threads/{id}/messages.""" + + messages: List[MessageInfo] + count: int + total: int + offset: int + limit: int + + +class ErrorResponse(CamelModel): + """Error response.""" + + error: str + detail: Optional[str] = None diff --git a/xml_pipeline/server/state.py b/xml_pipeline/server/state.py new file mode 100644 index 0000000..adeb3e5 --- /dev/null +++ b/xml_pipeline/server/state.py @@ -0,0 +1,515 @@ +""" +state.py — Server state manager for tracking organism runtime state. + +Maintains real-time state for API queries: +- Agent states (idle, processing, waiting, error) +- Active threads and their participants +- Message counts and activity timestamps + +This is the bridge between the StreamPump and the API. 
+""" + +from __future__ import annotations + +import asyncio +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set + +from xml_pipeline.server.models import ( + AgentInfo, + AgentState, + MessageInfo, + OrganismInfo, + OrganismStatus, + ThreadInfo, + ThreadStatus, +) + +if TYPE_CHECKING: + from xml_pipeline.message_bus.stream_pump import StreamPump + from xml_pipeline.message_bus import ( + PumpEvent, + MessageReceivedEvent, + MessageSentEvent, + AgentStateEvent, + ThreadEvent, + ) + + +@dataclass +class AgentRuntimeState: + """Runtime state for a single agent.""" + + name: str + description: str + is_agent: bool + peers: List[str] + payload_class: str + state: AgentState = AgentState.IDLE + current_thread: Optional[str] = None + last_activity: Optional[datetime] = None + message_count: int = 0 + queue_depth: int = 0 + + +@dataclass +class ThreadRuntimeState: + """Runtime state for a single thread.""" + + id: str + status: ThreadStatus = ThreadStatus.ACTIVE + participants: Set[str] = field(default_factory=set) + message_count: int = 0 + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + last_activity: Optional[datetime] = None + error: Optional[str] = None + + +@dataclass +class MessageRecord: + """Record of a message for history.""" + + id: str + thread_id: str + from_id: str + to_id: str + payload_type: str + payload: Dict[str, Any] + timestamp: datetime + slot_index: int + + +class ServerState: + """ + Centralized state manager for the API server. + + Tracks runtime state of agents, threads, and messages. + Provides methods for the API to query current state. 
+ """ + + def __init__(self, pump: StreamPump) -> None: + self.pump = pump + self._start_time = time.time() + self._status = OrganismStatus.STARTING + + # Runtime state + self._agents: Dict[str, AgentRuntimeState] = {} + self._threads: Dict[str, ThreadRuntimeState] = {} + self._messages: List[MessageRecord] = [] + self._message_count = 0 + + # Event subscribers (WebSocket connections) + self._subscribers: Set[Callable] = set() + + # Lock for thread-safe updates + self._lock = asyncio.Lock() + + # Initialize agent states from pump + self._init_agents() + + # Subscribe to pump events + self._subscribe_to_pump() + + def _init_agents(self) -> None: + """Initialize agent states from pump listeners.""" + for name, listener in self.pump.listeners.items(): + self._agents[name] = AgentRuntimeState( + name=name, + description=listener.description, + is_agent=listener.is_agent, + peers=list(listener.peers), + payload_class=f"{listener.payload_class.__module__}.{listener.payload_class.__name__}", + ) + + def _subscribe_to_pump(self) -> None: + """Subscribe to pump events for real-time state updates.""" + from xml_pipeline.message_bus import ( + PumpEvent, + MessageReceivedEvent, + MessageSentEvent, + AgentStateEvent as PumpAgentStateEvent, + ThreadEvent, + ) + + def handle_pump_event(event: "PumpEvent") -> None: + """Handle pump events synchronously (schedules async updates).""" + if isinstance(event, MessageReceivedEvent): + # Schedule async message recording + asyncio.create_task(self._handle_message_received(event)) + elif isinstance(event, MessageSentEvent): + # Schedule async message recording + asyncio.create_task(self._handle_message_sent(event)) + elif isinstance(event, PumpAgentStateEvent): + # Schedule async agent state update + asyncio.create_task(self._handle_agent_state(event)) + elif isinstance(event, ThreadEvent): + # Schedule async thread update + asyncio.create_task(self._handle_thread_event(event)) + + self.pump.subscribe_events(handle_pump_event) + + async 
def _handle_message_received(self, event: "MessageReceivedEvent") -> None: + """Handle message received by handler.""" + # Convert payload to dict representation + payload_dict = {} + if hasattr(event.payload, '__dataclass_fields__'): + import dataclasses + payload_dict = dataclasses.asdict(event.payload) + elif isinstance(event.payload, dict): + payload_dict = event.payload + + await self.record_message( + thread_id=event.thread_id, + from_id=event.from_id, + to_id=event.to_id, + payload_type=event.payload_type, + payload=payload_dict, + ) + + async def _handle_message_sent(self, event: "MessageSentEvent") -> None: + """Handle message sent by handler.""" + payload_dict = {} + if hasattr(event.payload, '__dataclass_fields__'): + import dataclasses + payload_dict = dataclasses.asdict(event.payload) + elif isinstance(event.payload, dict): + payload_dict = event.payload + + await self.record_message( + thread_id=event.thread_id, + from_id=event.from_id, + to_id=event.to_id, + payload_type=event.payload_type, + payload=payload_dict, + ) + + async def _handle_agent_state(self, event: "AgentStateEvent") -> None: + """Handle agent state change from pump.""" + # Map pump state strings to AgentState enum + state_map = { + "idle": AgentState.IDLE, + "processing": AgentState.PROCESSING, + "waiting": AgentState.WAITING, + "error": AgentState.ERROR, + "paused": AgentState.PAUSED, + } + state = state_map.get(event.state, AgentState.IDLE) + await self.update_agent_state(event.agent_name, state, event.thread_id) + + async def _handle_thread_event(self, event: "ThreadEvent") -> None: + """Handle thread lifecycle event from pump.""" + # Map status to ThreadStatus enum + status_map = { + "created": ThreadStatus.ACTIVE, + "active": ThreadStatus.ACTIVE, + "completed": ThreadStatus.COMPLETED, + "error": ThreadStatus.ERROR, + "killed": ThreadStatus.KILLED, + } + status = status_map.get(event.status, ThreadStatus.ACTIVE) + + if event.status in ("completed", "error", "killed"): + await 
self.complete_thread(event.thread_id, status, event.error) + + def set_running(self) -> None: + """Mark organism as running.""" + self._status = OrganismStatus.RUNNING + + def set_stopping(self) -> None: + """Mark organism as stopping.""" + self._status = OrganismStatus.STOPPING + + # ========================================================================= + # Event Recording (called by pump hooks) + # ========================================================================= + + async def record_message( + self, + thread_id: str, + from_id: str, + to_id: str, + payload_type: str, + payload: Dict[str, Any], + ) -> str: + """Record a message and update related state.""" + async with self._lock: + msg_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc) + self._message_count += 1 + + record = MessageRecord( + id=msg_id, + thread_id=thread_id, + from_id=from_id, + to_id=to_id, + payload_type=payload_type, + payload=payload, + timestamp=now, + slot_index=self._message_count, + ) + self._messages.append(record) + + # Update thread state + if thread_id not in self._threads: + self._threads[thread_id] = ThreadRuntimeState( + id=thread_id, + created_at=now, + ) + thread = self._threads[thread_id] + thread.participants.add(from_id) + thread.participants.add(to_id) + thread.message_count += 1 + thread.last_activity = now + + # Update agent states + if from_id in self._agents: + self._agents[from_id].message_count += 1 + self._agents[from_id].last_activity = now + + # Notify subscribers + await self._notify_message(record) + + return msg_id + + async def update_agent_state( + self, + agent_name: str, + state: AgentState, + current_thread: Optional[str] = None, + ) -> None: + """Update an agent's processing state.""" + async with self._lock: + if agent_name not in self._agents: + return + + agent = self._agents[agent_name] + old_state = agent.state + agent.state = state + agent.current_thread = current_thread + agent.last_activity = datetime.now(timezone.utc) + + if 
old_state != state: + await self._notify_agent_state(agent_name, state, current_thread) + + async def complete_thread( + self, + thread_id: str, + status: ThreadStatus = ThreadStatus.COMPLETED, + error: Optional[str] = None, + ) -> None: + """Mark a thread as completed or errored.""" + async with self._lock: + if thread_id not in self._threads: + return + + thread = self._threads[thread_id] + thread.status = status + thread.error = error + thread.last_activity = datetime.now(timezone.utc) + + await self._notify_thread_updated(thread) + + # ========================================================================= + # Query Methods (for API) + # ========================================================================= + + def get_organism_info(self) -> OrganismInfo: + """Get organism overview.""" + return OrganismInfo( + name=self.pump.config.name, + status=self._status, + uptime_seconds=time.time() - self._start_time, + agent_count=len(self._agents), + active_threads=sum( + 1 for t in self._threads.values() if t.status == ThreadStatus.ACTIVE + ), + total_messages=self._message_count, + identity_configured=self.pump.identity is not None, + ) + + def get_organism_config(self) -> Dict[str, Any]: + """Get sanitized organism config (no secrets).""" + return { + "name": self.pump.config.name, + "port": self.pump.config.port, + "thread_scheduling": self.pump.config.thread_scheduling, + "max_concurrent_pipelines": self.pump.config.max_concurrent_pipelines, + "max_concurrent_handlers": self.pump.config.max_concurrent_handlers, + "listeners": list(self.pump.listeners.keys()), + } + + def get_agents(self) -> List[AgentInfo]: + """Get all agents.""" + return [self._agent_to_info(a) for a in self._agents.values()] + + def get_agent(self, name: str) -> Optional[AgentInfo]: + """Get a single agent by name.""" + agent = self._agents.get(name) + if agent: + return self._agent_to_info(agent) + return None + + def get_agent_schema(self, name: str) -> Optional[str]: + """Get agent's 
XSD schema as string.""" + listener = self.pump.listeners.get(name) + if listener and listener.schema is not None: + from lxml import etree + + return etree.tostring(listener.schema, encoding="unicode", pretty_print=True) + return None + + def get_threads( + self, + status: Optional[ThreadStatus] = None, + agent: Optional[str] = None, + limit: int = 50, + offset: int = 0, + ) -> tuple[List[ThreadInfo], int]: + """Get threads with optional filtering.""" + threads = list(self._threads.values()) + + # Filter + if status: + threads = [t for t in threads if t.status == status] + if agent: + threads = [t for t in threads if agent in t.participants] + + # Sort by last activity (most recent first) + threads.sort(key=lambda t: t.last_activity or t.created_at, reverse=True) + + total = len(threads) + threads = threads[offset : offset + limit] + + return [self._thread_to_info(t) for t in threads], total + + def get_thread(self, thread_id: str) -> Optional[ThreadInfo]: + """Get a single thread by ID.""" + thread = self._threads.get(thread_id) + if thread: + return self._thread_to_info(thread) + return None + + def get_messages( + self, + thread_id: Optional[str] = None, + agent: Optional[str] = None, + limit: int = 50, + offset: int = 0, + ) -> tuple[List[MessageInfo], int]: + """Get messages with optional filtering.""" + messages = self._messages.copy() + + # Filter + if thread_id: + messages = [m for m in messages if m.thread_id == thread_id] + if agent: + messages = [m for m in messages if m.from_id == agent or m.to_id == agent] + + # Sort by timestamp (most recent first) + messages.sort(key=lambda m: m.timestamp, reverse=True) + + total = len(messages) + messages = messages[offset : offset + limit] + + return [self._message_to_info(m) for m in messages], total + + # ========================================================================= + # Subscription Management + # ========================================================================= + + def subscribe(self, 
callback: Callable) -> None: + """Subscribe to state events.""" + self._subscribers.add(callback) + + def unsubscribe(self, callback: Callable) -> None: + """Unsubscribe from state events.""" + self._subscribers.discard(callback) + + async def _notify_message(self, record: MessageRecord) -> None: + """Notify subscribers of new message.""" + event = { + "event": "message", + "message": self._message_to_info(record).model_dump(by_alias=True), + } + await self._broadcast(event) + + async def _notify_agent_state( + self, + agent_name: str, + state: AgentState, + current_thread: Optional[str], + ) -> None: + """Notify subscribers of agent state change.""" + event = { + "event": "agent_state", + "agent": agent_name, + "state": state.value, + "currentThread": current_thread, + } + await self._broadcast(event) + + async def _notify_thread_updated(self, thread: ThreadRuntimeState) -> None: + """Notify subscribers of thread update.""" + event = { + "event": "thread_updated", + "threadId": thread.id, + "status": thread.status.value, + "messageCount": thread.message_count, + } + await self._broadcast(event) + + async def _broadcast(self, event: Dict[str, Any]) -> None: + """Broadcast event to all subscribers.""" + for callback in list(self._subscribers): + try: + await callback(event) + except Exception: + # Remove failed subscribers + self._subscribers.discard(callback) + + # ========================================================================= + # Conversion Helpers + # ========================================================================= + + def _agent_to_info(self, agent: AgentRuntimeState) -> AgentInfo: + """Convert runtime state to API model.""" + return AgentInfo( + name=agent.name, + description=agent.description, + is_agent=agent.is_agent, + peers=agent.peers, + payload_class=agent.payload_class, + state=agent.state, + current_thread=agent.current_thread, + queue_depth=agent.queue_depth, + last_activity=agent.last_activity, + message_count=agent.message_count, 
+ ) + + def _thread_to_info(self, thread: ThreadRuntimeState) -> ThreadInfo: + """Convert runtime state to API model.""" + return ThreadInfo( + id=thread.id, + status=thread.status, + participants=list(thread.participants), + message_count=thread.message_count, + created_at=thread.created_at, + last_activity=thread.last_activity, + error=thread.error, + ) + + def _message_to_info(self, record: MessageRecord) -> MessageInfo: + """Convert record to API model.""" + return MessageInfo( + id=record.id, + thread_id=record.thread_id, + from_id=record.from_id, + to_id=record.to_id, + payload_type=record.payload_type, + payload=record.payload, + timestamp=record.timestamp, + slot_index=record.slot_index, + ) diff --git a/xml_pipeline/server/websocket.py b/xml_pipeline/server/websocket.py new file mode 100644 index 0000000..bd244b2 --- /dev/null +++ b/xml_pipeline/server/websocket.py @@ -0,0 +1,316 @@ +""" +websocket.py — WebSocket endpoints for AgentServer. + +Provides: +- /ws — Main control channel with state snapshot and real-time events +- /ws/messages — Dedicated message log stream with filtering +""" + +from __future__ import annotations + +import asyncio +import json +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect + +from xml_pipeline.server.models import ( + SubscribeRequest, + WSConnectedEvent, +) + +if TYPE_CHECKING: + from xml_pipeline.server.state import ServerState + + +class ConnectionManager: + """Manages WebSocket connections and subscriptions.""" + + def __init__(self) -> None: + self.active_connections: Set[WebSocket] = set() + self.subscriptions: Dict[WebSocket, SubscribeRequest] = {} + + async def connect(self, websocket: WebSocket) -> None: + """Accept a new WebSocket connection.""" + await websocket.accept() + self.active_connections.add(websocket) + self.subscriptions[websocket] = SubscribeRequest() # Default: all events + + def disconnect(self, websocket: WebSocket) -> 
None: + """Remove a WebSocket connection.""" + self.active_connections.discard(websocket) + self.subscriptions.pop(websocket, None) + + def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None: + """Update subscription filters for a connection.""" + self.subscriptions[websocket] = sub + + def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool: + """Check if an event should be sent to this connection based on filters.""" + sub = self.subscriptions.get(websocket) + if sub is None: + return True # No filters = send all + + event_type = event.get("event", "") + + # Filter by event type + if sub.events and event_type not in sub.events: + return False + + # Filter by thread + thread_id = event.get("thread_id") or event.get("threadId") + if sub.threads and thread_id and thread_id not in sub.threads: + return False + + # Filter by agent + agent = event.get("agent") + if sub.agents and agent and agent not in sub.agents: + return False + + # For message events, check from/to + if event_type == "message": + msg = event.get("message", {}) + from_id = msg.get("from") or msg.get("fromId") + to_id = msg.get("to") or msg.get("toId") + if sub.agents: + if from_id not in sub.agents and to_id not in sub.agents: + return False + + # Filter by payload type + if sub.payload_types: + payload_type = None + if event_type == "message": + payload_type = event.get("message", {}).get("payloadType") + if payload_type and payload_type not in sub.payload_types: + return False + + return True + + async def broadcast(self, event: Dict[str, Any]) -> None: + """Broadcast an event to all connections that match their subscription.""" + disconnected: List[WebSocket] = [] + + for websocket in self.active_connections: + if self.should_send(websocket, event): + try: + await websocket.send_json(event) + except Exception: + disconnected.append(websocket) + + # Clean up disconnected clients + for ws in disconnected: + self.disconnect(ws) + + +class 
MessageStreamManager: + """Manages WebSocket connections for /ws/messages endpoint.""" + + def __init__(self) -> None: + self.active_connections: Set[WebSocket] = set() + self.filters: Dict[WebSocket, Dict[str, Any]] = {} + + async def connect(self, websocket: WebSocket) -> None: + """Accept a new WebSocket connection.""" + await websocket.accept() + self.active_connections.add(websocket) + self.filters[websocket] = {} # Default: all messages + + def disconnect(self, websocket: WebSocket) -> None: + """Remove a WebSocket connection.""" + self.active_connections.discard(websocket) + self.filters.pop(websocket, None) + + def set_filter(self, websocket: WebSocket, filter_config: Dict[str, Any]) -> None: + """Update message filter for a connection.""" + self.filters[websocket] = filter_config + + def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool: + """Check if a message should be sent to this connection.""" + flt = self.filters.get(websocket, {}) + if not flt: + return True + + # Filter by agents + agents = flt.get("agents", []) + if agents: + from_id = message.get("from") or message.get("fromId") + to_id = message.get("to") or message.get("toId") + if from_id not in agents and to_id not in agents: + return False + + # Filter by threads + threads = flt.get("threads", []) + if threads: + thread_id = message.get("thread_id") or message.get("threadId") + if thread_id not in threads: + return False + + # Filter by payload types + payload_types = flt.get("payload_types", []) + if payload_types: + payload_type = message.get("payloadType") or message.get("payload_type") + if payload_type not in payload_types: + return False + + return True + + async def broadcast_message(self, message: Dict[str, Any]) -> None: + """Broadcast a message to all filtered connections.""" + disconnected: List[WebSocket] = [] + + for websocket in self.active_connections: + if self.should_send(websocket, message): + try: + await websocket.send_json(message) + except 
def create_websocket_router(state: "ServerState") -> APIRouter:
    """Create WebSocket router with state dependency.

    Registers two endpoints on a fresh ``APIRouter``:

    * ``/ws`` - control channel.  Sends a ``WSConnectedEvent`` snapshot on
      connect, then pushes state events; accepts the commands
      ``subscribe`` and ``inject``.
    * ``/ws/messages`` - message stream; accepts ``subscribe`` carrying a
      ``filter`` object.

    One callback is registered through ``state.subscribe``.  It forwards
    every event to the control-channel connections and additionally
    forwards the ``message`` part of ``"message"`` events to the message
    stream connections.

    Args:
        state: Server state the endpoints read from and record into.

    Returns:
        The router with both WebSocket routes attached.
    """
    # Imported once per router instead of on every ``inject`` command.
    import uuid

    router = APIRouter()
    manager = ConnectionManager()
    message_manager = MessageStreamManager()

    # Subscribe state to WebSocket broadcasting
    async def on_state_event(event: Dict[str, Any]) -> None:
        """Forward state events to WebSocket connections."""
        await manager.broadcast(event)

        # Also forward message events to the message stream
        if event.get("event") == "message":
            msg = event.get("message", {})
            await message_manager.broadcast_message(msg)

    state.subscribe(on_state_event)

    async def send_error(websocket: WebSocket, text: str) -> None:
        """Send an ``error`` event carrying *text* to one client."""
        await websocket.send_json({"event": "error", "error": text})

    async def handle_subscribe(websocket: WebSocket, data: Dict[str, Any]) -> None:
        """Replace the control-channel subscription filters of one client."""
        try:
            sub = SubscribeRequest(
                threads=data.get("threads"),
                agents=data.get("agents"),
                payload_types=data.get("payload_types"),
                events=data.get("events"),
            )
        except ValueError:
            # NOTE(review): assumes SubscribeRequest is a pydantic model, whose
            # validation errors derive from ValueError - confirm in models.py.
            await send_error(websocket, "Invalid subscription filters")
            return
        manager.set_subscription(websocket, sub)
        await websocket.send_json({"event": "subscribed", "filters": data})

    async def handle_inject(websocket: WebSocket, data: Dict[str, Any]) -> None:
        """Record a client-supplied message addressed to an agent."""
        target = data.get("to")
        payload = data.get("payload", {})
        thread_id = data.get("thread_id")

        if not target:
            await send_error(websocket, "Missing 'to' field")
            return

        agent = state.get_agent(target)
        if agent is None:
            await send_error(websocket, f"Unknown agent: {target}")
            return

        if not isinstance(payload, dict):
            # The payload type is derived from the first key, so anything
            # other than a JSON object cannot be processed.
            await send_error(websocket, "'payload' must be an object")
            return

        # Start a new thread when the client did not name one.
        thread_id = thread_id or str(uuid.uuid4())
        # The first key names the payload type; an empty payload falls
        # back to the generic "Payload".
        payload_type = next(iter(payload.keys()), "Payload")

        msg_id = await state.record_message(
            thread_id=thread_id,
            from_id="api",
            to_id=target,
            payload_type=payload_type,
            payload=payload,
        )

        await websocket.send_json(
            {
                "event": "injected",
                "thread_id": thread_id,
                "message_id": msg_id,
            }
        )

    @router.websocket("/ws")
    async def websocket_control(websocket: WebSocket) -> None:
        """
        Main control channel WebSocket.

        On connect, sends full state snapshot. Then pushes events as they occur.
        Accepts commands: subscribe, inject.
        """
        await manager.connect(websocket)

        try:
            # Send connected event with state snapshot
            threads, _ = state.get_threads(limit=100)
            connected_event = WSConnectedEvent(
                organism=state.get_organism_info(),
                agents=state.get_agents(),
                threads=threads,
            )
            await websocket.send_json(connected_event.model_dump(by_alias=True))

            # Listen for commands
            while True:
                try:
                    data = await websocket.receive_json()
                except Exception:
                    # Connection closed or invalid data
                    break

                if not isinstance(data, dict):
                    # Valid JSON, but not an object: there is no "cmd" to read.
                    await send_error(websocket, "Command must be a JSON object")
                    continue

                cmd = data.get("cmd", "")

                if cmd == "subscribe":
                    await handle_subscribe(websocket, data)
                elif cmd == "inject":
                    # Inject a message (same as REST /inject)
                    await handle_inject(websocket, data)
                else:
                    await send_error(websocket, f"Unknown command: {cmd}")

        except WebSocketDisconnect:
            pass
        finally:
            # Always unregister, whatever ended the session.
            manager.disconnect(websocket)

    @router.websocket("/ws/messages")
    async def websocket_messages(websocket: WebSocket) -> None:
        """
        Dedicated message log stream.

        Streams all messages flowing through the organism.
        Clients can filter by agents, threads, and payload types.
        """
        await message_manager.connect(websocket)

        try:
            while True:
                try:
                    data = await websocket.receive_json()
                except Exception:
                    # Connection closed or invalid data
                    break

                if not isinstance(data, dict):
                    await send_error(websocket, "Command must be a JSON object")
                    continue

                cmd = data.get("cmd", "")

                if cmd == "subscribe":
                    # Update filter
                    filter_config = data.get("filter", {})
                    if filter_config is not None and not isinstance(filter_config, dict):
                        # A non-object filter would be stored as-is and could
                        # break filtering on the next broadcast.
                        await send_error(websocket, "'filter' must be an object")
                        continue
                    message_manager.set_filter(websocket, filter_config)
                    await websocket.send_json({"event": "subscribed", "filter": filter_config})

                else:
                    await send_error(websocket, f"Unknown command: {cmd}")

        except WebSocketDisconnect:
            pass
        finally:
            message_manager.disconnect(websocket)

    return router