Add AgentServer REST/WebSocket API
Implements the AgentServer API from docs/agentserver_api_spec.md: REST API (/api/v1): - Organism info and config endpoints - Agent listing, details, config, schema - Thread and message history with filtering - Control endpoints (inject, pause, resume, kill, stop) WebSocket: - /ws: Main control channel with state snapshot + real-time events - /ws/messages: Dedicated message stream with filtering Infrastructure: - Pydantic models with camelCase serialization - ServerState bridges StreamPump to API - Pump event hooks for real-time updates - CLI 'serve' command: xml-pipeline serve [config] --port 8080 35 new tests for models, state, REST, and WebSocket. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
809862af35
commit
bf31b0d14e
11 changed files with 2257 additions and 1 deletions
|
|
@ -80,6 +80,13 @@ search = ["duckduckgo-search>=6.0"] # Web search tool
|
|||
# Console example (optional, for interactive use)
|
||||
console = ["prompt_toolkit>=3.0"]
|
||||
|
||||
# API server (FastAPI + WebSocket)
|
||||
server = [
|
||||
"fastapi>=0.109",
|
||||
"uvicorn[standard]>=0.27",
|
||||
"websockets>=12.0",
|
||||
]
|
||||
|
||||
# All LLM providers
|
||||
llm = ["xml-pipeline[anthropic,openai]"]
|
||||
|
||||
|
|
@ -87,7 +94,7 @@ llm = ["xml-pipeline[anthropic,openai]"]
|
|||
tools = ["xml-pipeline[redis,search]"]
|
||||
|
||||
# Everything (for local development)
|
||||
all = ["xml-pipeline[llm,tools,console]"]
|
||||
all = ["xml-pipeline[llm,tools,console,server]"]
|
||||
|
||||
# Testing
|
||||
test = [
|
||||
|
|
|
|||
510
tests/test_server.py
Normal file
510
tests/test_server.py
Normal file
|
|
@ -0,0 +1,510 @@
|
|||
"""
|
||||
Tests for the AgentServer API.
|
||||
|
||||
Tests the REST API endpoints and WebSocket connections.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
# Skip all tests if FastAPI not available
|
||||
pytest.importorskip("fastapi")
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from xml_pipeline.server.models import (
|
||||
AgentState,
|
||||
AgentInfo,
|
||||
ThreadStatus,
|
||||
ThreadInfo,
|
||||
MessageInfo,
|
||||
OrganismInfo,
|
||||
OrganismStatus,
|
||||
)
|
||||
from xml_pipeline.server.state import ServerState, AgentRuntimeState, ThreadRuntimeState
|
||||
from xml_pipeline.server.api import create_router
|
||||
from xml_pipeline.server.app import create_app
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Mock StreamPump
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class MockListener:
|
||||
"""Mock listener for testing."""
|
||||
name: str
|
||||
description: str
|
||||
is_agent: bool = False
|
||||
peers: List[str] = None
|
||||
payload_class: type = None
|
||||
schema: Any = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.peers is None:
|
||||
self.peers = []
|
||||
if self.payload_class is None:
|
||||
self.payload_class = type("MockPayload", (), {"__module__": "test", "__name__": "MockPayload"})
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockConfig:
|
||||
"""Mock config for testing."""
|
||||
name: str = "test-organism"
|
||||
port: int = 8765
|
||||
thread_scheduling: str = "breadth-first"
|
||||
max_concurrent_pipelines: int = 50
|
||||
max_concurrent_handlers: int = 20
|
||||
|
||||
|
||||
class MockStreamPump:
|
||||
"""Mock StreamPump for testing."""
|
||||
|
||||
def __init__(self):
|
||||
self.config = MockConfig()
|
||||
self.identity = None
|
||||
self.listeners: Dict[str, MockListener] = {
|
||||
"greeter": MockListener(
|
||||
name="greeter",
|
||||
description="Greeting agent",
|
||||
is_agent=True,
|
||||
peers=["shouter"],
|
||||
),
|
||||
"shouter": MockListener(
|
||||
name="shouter",
|
||||
description="Shouting handler",
|
||||
is_agent=False,
|
||||
),
|
||||
}
|
||||
self._event_callbacks = []
|
||||
|
||||
def subscribe_events(self, callback):
|
||||
self._event_callbacks.append(callback)
|
||||
|
||||
def unsubscribe_events(self, callback):
|
||||
if callback in self._event_callbacks:
|
||||
self._event_callbacks.remove(callback)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def mock_pump():
|
||||
"""Create a mock StreamPump."""
|
||||
return MockStreamPump()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server_state(mock_pump):
|
||||
"""Create ServerState with mock pump."""
|
||||
return ServerState(mock_pump)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_client(mock_pump):
|
||||
"""Create FastAPI test client."""
|
||||
app = create_app(mock_pump)
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Models
|
||||
# ============================================================================
|
||||
|
||||
class TestModels:
|
||||
"""Test Pydantic model serialization."""
|
||||
|
||||
def test_agent_info_camel_case(self):
|
||||
"""Test AgentInfo serializes to camelCase."""
|
||||
agent = AgentInfo(
|
||||
name="greeter",
|
||||
description="Test agent",
|
||||
is_agent=True,
|
||||
peers=["shouter"],
|
||||
payload_class="test.Greeting",
|
||||
state=AgentState.IDLE,
|
||||
current_thread=None,
|
||||
queue_depth=0,
|
||||
message_count=5,
|
||||
)
|
||||
data = agent.model_dump(by_alias=True)
|
||||
assert "isAgent" in data
|
||||
assert "payloadClass" in data
|
||||
assert "currentThread" in data
|
||||
assert "queueDepth" in data
|
||||
assert "messageCount" in data
|
||||
|
||||
def test_thread_info_camel_case(self):
|
||||
"""Test ThreadInfo serializes to camelCase."""
|
||||
from datetime import datetime, timezone
|
||||
thread = ThreadInfo(
|
||||
id="test-uuid",
|
||||
status=ThreadStatus.ACTIVE,
|
||||
participants=["greeter", "shouter"],
|
||||
message_count=3,
|
||||
created_at=datetime.now(timezone.utc),
|
||||
)
|
||||
data = thread.model_dump(by_alias=True)
|
||||
assert "messageCount" in data
|
||||
assert "createdAt" in data
|
||||
assert "lastActivity" in data
|
||||
|
||||
def test_organism_info_camel_case(self):
|
||||
"""Test OrganismInfo serializes to camelCase."""
|
||||
info = OrganismInfo(
|
||||
name="test-organism",
|
||||
status=OrganismStatus.RUNNING,
|
||||
uptime_seconds=3600.0,
|
||||
agent_count=2,
|
||||
active_threads=1,
|
||||
total_messages=10,
|
||||
identity_configured=False,
|
||||
)
|
||||
data = info.model_dump(by_alias=True)
|
||||
assert "uptimeSeconds" in data
|
||||
assert "agentCount" in data
|
||||
assert "activeThreads" in data
|
||||
assert "totalMessages" in data
|
||||
assert "identityConfigured" in data
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test ServerState
|
||||
# ============================================================================
|
||||
|
||||
class TestServerState:
|
||||
"""Test ServerState functionality."""
|
||||
|
||||
def test_init_agents_from_pump(self, server_state):
|
||||
"""Test agents are initialized from pump listeners."""
|
||||
assert "greeter" in server_state._agents
|
||||
assert "shouter" in server_state._agents
|
||||
assert server_state._agents["greeter"].is_agent is True
|
||||
assert server_state._agents["shouter"].is_agent is False
|
||||
|
||||
def test_get_organism_info(self, server_state):
|
||||
"""Test getting organism info."""
|
||||
info = server_state.get_organism_info()
|
||||
assert info.name == "test-organism"
|
||||
assert info.agent_count == 2
|
||||
assert info.status == OrganismStatus.STARTING
|
||||
|
||||
def test_set_running(self, server_state):
|
||||
"""Test setting running status."""
|
||||
server_state.set_running()
|
||||
info = server_state.get_organism_info()
|
||||
assert info.status == OrganismStatus.RUNNING
|
||||
|
||||
def test_get_agents(self, server_state):
|
||||
"""Test getting all agents."""
|
||||
agents = server_state.get_agents()
|
||||
assert len(agents) == 2
|
||||
names = [a.name for a in agents]
|
||||
assert "greeter" in names
|
||||
assert "shouter" in names
|
||||
|
||||
def test_get_agent(self, server_state):
|
||||
"""Test getting single agent."""
|
||||
agent = server_state.get_agent("greeter")
|
||||
assert agent is not None
|
||||
assert agent.name == "greeter"
|
||||
assert agent.is_agent is True
|
||||
|
||||
def test_get_agent_not_found(self, server_state):
|
||||
"""Test getting non-existent agent."""
|
||||
agent = server_state.get_agent("nonexistent")
|
||||
assert agent is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_message(self, server_state):
|
||||
"""Test recording a message."""
|
||||
msg_id = await server_state.record_message(
|
||||
thread_id="test-thread",
|
||||
from_id="greeter",
|
||||
to_id="shouter",
|
||||
payload_type="GreetingResponse",
|
||||
payload={"message": "Hello!"},
|
||||
)
|
||||
assert msg_id is not None
|
||||
|
||||
# Check message was recorded
|
||||
messages, total = server_state.get_messages(thread_id="test-thread")
|
||||
assert total == 1
|
||||
assert messages[0].from_id == "greeter"
|
||||
assert messages[0].to_id == "shouter"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_agent_state(self, server_state):
|
||||
"""Test updating agent state."""
|
||||
await server_state.update_agent_state("greeter", AgentState.PROCESSING, "thread-1")
|
||||
|
||||
agent = server_state.get_agent("greeter")
|
||||
assert agent.state == AgentState.PROCESSING
|
||||
assert agent.current_thread == "thread-1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_thread(self, server_state):
|
||||
"""Test completing a thread."""
|
||||
# First record a message to create the thread
|
||||
await server_state.record_message(
|
||||
thread_id="test-thread",
|
||||
from_id="greeter",
|
||||
to_id="shouter",
|
||||
payload_type="Test",
|
||||
payload={},
|
||||
)
|
||||
|
||||
# Complete the thread
|
||||
await server_state.complete_thread("test-thread", ThreadStatus.COMPLETED)
|
||||
|
||||
thread = server_state.get_thread("test-thread")
|
||||
assert thread.status == ThreadStatus.COMPLETED
|
||||
|
||||
def test_get_threads_with_filter(self, server_state):
|
||||
"""Test filtering threads by status."""
|
||||
# No threads initially
|
||||
threads, total = server_state.get_threads(status=ThreadStatus.ACTIVE)
|
||||
assert total == 0
|
||||
|
||||
def test_get_organism_config(self, server_state):
|
||||
"""Test getting organism config."""
|
||||
config = server_state.get_organism_config()
|
||||
assert config["name"] == "test-organism"
|
||||
assert config["port"] == 8765
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test REST API
|
||||
# ============================================================================
|
||||
|
||||
class TestRestAPI:
|
||||
"""Test REST API endpoints."""
|
||||
|
||||
def test_health_check(self, test_client):
|
||||
"""Test health check endpoint."""
|
||||
response = test_client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
assert "organism" in data
|
||||
|
||||
def test_get_organism(self, test_client):
|
||||
"""Test GET /api/v1/organism."""
|
||||
response = test_client.get("/api/v1/organism")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "test-organism"
|
||||
assert "agentCount" in data
|
||||
|
||||
def test_get_organism_config(self, test_client):
|
||||
"""Test GET /api/v1/organism/config."""
|
||||
response = test_client.get("/api/v1/organism/config")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "test-organism"
|
||||
|
||||
def test_list_agents(self, test_client):
|
||||
"""Test GET /api/v1/agents."""
|
||||
response = test_client.get("/api/v1/agents")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "agents" in data
|
||||
assert data["count"] == 2
|
||||
|
||||
def test_get_agent(self, test_client):
|
||||
"""Test GET /api/v1/agents/{name}."""
|
||||
response = test_client.get("/api/v1/agents/greeter")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "greeter"
|
||||
assert data["isAgent"] is True
|
||||
|
||||
def test_get_agent_not_found(self, test_client):
|
||||
"""Test GET /api/v1/agents/{name} with non-existent agent."""
|
||||
response = test_client.get("/api/v1/agents/nonexistent")
|
||||
assert response.status_code == 404
|
||||
|
||||
def test_get_agent_config(self, test_client):
|
||||
"""Test GET /api/v1/agents/{name}/config."""
|
||||
response = test_client.get("/api/v1/agents/greeter/config")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "greeter"
|
||||
assert "isAgent" in data
|
||||
|
||||
def test_list_threads(self, test_client):
|
||||
"""Test GET /api/v1/threads."""
|
||||
response = test_client.get("/api/v1/threads")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "threads" in data
|
||||
assert "count" in data
|
||||
assert "total" in data
|
||||
|
||||
def test_list_threads_with_invalid_status(self, test_client):
|
||||
"""Test GET /api/v1/threads with invalid status filter."""
|
||||
response = test_client.get("/api/v1/threads?status=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_list_messages(self, test_client):
|
||||
"""Test GET /api/v1/messages."""
|
||||
response = test_client.get("/api/v1/messages")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "messages" in data
|
||||
assert "count" in data
|
||||
|
||||
def test_inject_message(self, test_client):
|
||||
"""Test POST /api/v1/inject."""
|
||||
response = test_client.post(
|
||||
"/api/v1/inject",
|
||||
json={
|
||||
"to": "greeter",
|
||||
"payload": {"name": "Dan"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "threadId" in data
|
||||
assert "messageId" in data
|
||||
|
||||
def test_inject_message_unknown_agent(self, test_client):
|
||||
"""Test POST /api/v1/inject with unknown agent."""
|
||||
response = test_client.post(
|
||||
"/api/v1/inject",
|
||||
json={
|
||||
"to": "nonexistent",
|
||||
"payload": {"name": "Dan"},
|
||||
},
|
||||
)
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_pause_agent(self, test_client):
|
||||
"""Test POST /api/v1/agents/{name}/pause."""
|
||||
response = test_client.post("/api/v1/agents/greeter/pause")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["state"] == "paused"
|
||||
|
||||
def test_resume_agent(self, test_client):
|
||||
"""Test POST /api/v1/agents/{name}/resume."""
|
||||
response = test_client.post("/api/v1/agents/greeter/resume")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
assert data["state"] == "idle"
|
||||
|
||||
def test_stop_organism(self, test_client):
|
||||
"""Test POST /api/v1/organism/stop."""
|
||||
response = test_client.post("/api/v1/organism/stop")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["success"] is True
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test WebSocket
|
||||
# ============================================================================
|
||||
|
||||
class TestWebSocket:
|
||||
"""Test WebSocket endpoints."""
|
||||
|
||||
def test_websocket_connect(self, test_client):
|
||||
"""Test WebSocket connection."""
|
||||
with test_client.websocket_connect("/ws") as websocket:
|
||||
# Should receive connected event with state snapshot
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "connected"
|
||||
assert "organism" in data
|
||||
assert "agents" in data
|
||||
assert "threads" in data
|
||||
|
||||
def test_websocket_subscribe(self, test_client):
|
||||
"""Test WebSocket subscribe command."""
|
||||
with test_client.websocket_connect("/ws") as websocket:
|
||||
# Receive initial connected event
|
||||
websocket.receive_json()
|
||||
|
||||
# Send subscribe command
|
||||
websocket.send_json({
|
||||
"cmd": "subscribe",
|
||||
"agents": ["greeter"],
|
||||
"events": ["message"],
|
||||
})
|
||||
|
||||
# Should receive subscribed confirmation
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "subscribed"
|
||||
|
||||
def test_websocket_inject(self, test_client):
|
||||
"""Test WebSocket inject command."""
|
||||
with test_client.websocket_connect("/ws") as websocket:
|
||||
# Receive initial connected event
|
||||
websocket.receive_json()
|
||||
|
||||
# Send inject command
|
||||
websocket.send_json({
|
||||
"cmd": "inject",
|
||||
"to": "greeter",
|
||||
"payload": {"name": "Dan"},
|
||||
})
|
||||
|
||||
# Should receive injected confirmation
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "injected"
|
||||
assert "thread_id" in data
|
||||
assert "message_id" in data
|
||||
|
||||
def test_websocket_inject_unknown_agent(self, test_client):
|
||||
"""Test WebSocket inject with unknown agent."""
|
||||
with test_client.websocket_connect("/ws") as websocket:
|
||||
# Receive initial connected event
|
||||
websocket.receive_json()
|
||||
|
||||
# Send inject command to unknown agent
|
||||
websocket.send_json({
|
||||
"cmd": "inject",
|
||||
"to": "nonexistent",
|
||||
"payload": {},
|
||||
})
|
||||
|
||||
# Should receive error
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "error"
|
||||
|
||||
def test_websocket_unknown_command(self, test_client):
|
||||
"""Test WebSocket with unknown command."""
|
||||
with test_client.websocket_connect("/ws") as websocket:
|
||||
# Receive initial connected event
|
||||
websocket.receive_json()
|
||||
|
||||
# Send unknown command
|
||||
websocket.send_json({
|
||||
"cmd": "unknown_command",
|
||||
})
|
||||
|
||||
# Should receive error
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "error"
|
||||
|
||||
def test_websocket_messages_stream(self, test_client):
|
||||
"""Test WebSocket messages stream endpoint."""
|
||||
with test_client.websocket_connect("/ws/messages") as websocket:
|
||||
# Send subscribe command
|
||||
websocket.send_json({
|
||||
"cmd": "subscribe",
|
||||
"filter": {
|
||||
"agents": ["greeter"],
|
||||
},
|
||||
})
|
||||
|
||||
# Should receive subscribed confirmation
|
||||
data = websocket.receive_json()
|
||||
assert data["event"] == "subscribed"
|
||||
assert "filter" in data
|
||||
|
|
@ -3,6 +3,7 @@ xml-pipeline CLI entry point.
|
|||
|
||||
Usage:
|
||||
xml-pipeline run [config.yaml] Run an organism
|
||||
xml-pipeline serve [config.yaml] Run organism with API server
|
||||
xml-pipeline init [name] Create new organism config
|
||||
xml-pipeline check [config.yaml] Validate config without running
|
||||
xml-pipeline version Show version info
|
||||
|
|
@ -37,6 +38,68 @@ def cmd_run(args: argparse.Namespace) -> int:
|
|||
return 1
|
||||
|
||||
|
||||
def cmd_serve(args: argparse.Namespace) -> int:
|
||||
"""Run an organism with the AgentServer API."""
|
||||
try:
|
||||
import uvicorn
|
||||
except ImportError:
|
||||
print("Error: uvicorn not installed.", file=sys.stderr)
|
||||
print("Install with: pip install xml-pipeline[server]", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
from xml_pipeline.message_bus import bootstrap
|
||||
|
||||
config_path = Path(args.config)
|
||||
if not config_path.exists():
|
||||
print(f"Error: Config file not found: {config_path}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
async def run_with_server():
|
||||
"""Bootstrap pump and run with server."""
|
||||
from xml_pipeline.server import create_app
|
||||
|
||||
# Bootstrap the pump
|
||||
pump = await bootstrap(str(config_path))
|
||||
|
||||
# Create FastAPI app
|
||||
app = create_app(pump)
|
||||
|
||||
# Run uvicorn
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="info",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
|
||||
# Run pump and server concurrently
|
||||
pump_task = asyncio.create_task(pump.run())
|
||||
|
||||
try:
|
||||
await server.serve()
|
||||
finally:
|
||||
await pump.shutdown()
|
||||
pump_task.cancel()
|
||||
try:
|
||||
await pump_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
print(f"Starting AgentServer on http://{args.host}:{args.port}")
|
||||
print(f" API docs: http://{args.host}:{args.port}/docs")
|
||||
print(f" WebSocket: ws://{args.host}:{args.port}/ws")
|
||||
asyncio.run(run_with_server())
|
||||
return 0
|
||||
except KeyboardInterrupt:
|
||||
print("\nShutdown requested.")
|
||||
return 0
|
||||
except Exception as e:
|
||||
print(f"Error: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
def cmd_init(args: argparse.Namespace) -> int:
|
||||
"""Initialize a new organism config."""
|
||||
from xml_pipeline.config.template import create_organism_template
|
||||
|
|
@ -149,6 +212,13 @@ def main() -> int:
|
|||
run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
|
||||
run_parser.set_defaults(func=cmd_run)
|
||||
|
||||
# serve
|
||||
serve_parser = subparsers.add_parser("serve", help="Run organism with API server")
|
||||
serve_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
|
||||
serve_parser.add_argument("--host", default="0.0.0.0", help="Host to bind (default: 0.0.0.0)")
|
||||
serve_parser.add_argument("--port", "-p", type=int, default=8080, help="Port to listen on (default: 8080)")
|
||||
serve_parser.set_defaults(func=cmd_serve)
|
||||
|
||||
# init
|
||||
init_parser = subparsers.add_parser("init", help="Create new organism config")
|
||||
init_parser.add_argument("name", nargs="?", help="Organism name")
|
||||
|
|
|
|||
|
|
@ -33,6 +33,12 @@ from xml_pipeline.message_bus.stream_pump import (
|
|||
get_stream_pump,
|
||||
set_stream_pump,
|
||||
reset_stream_pump,
|
||||
# Event hooks
|
||||
PumpEvent,
|
||||
MessageReceivedEvent,
|
||||
MessageSentEvent,
|
||||
AgentStateEvent,
|
||||
ThreadEvent,
|
||||
)
|
||||
|
||||
from xml_pipeline.message_bus.message_state import (
|
||||
|
|
@ -71,6 +77,12 @@ __all__ = [
|
|||
"get_stream_pump",
|
||||
"set_stream_pump",
|
||||
"reset_stream_pump",
|
||||
# Event hooks
|
||||
"PumpEvent",
|
||||
"MessageReceivedEvent",
|
||||
"MessageSentEvent",
|
||||
"AgentStateEvent",
|
||||
"ThreadEvent",
|
||||
# Message state
|
||||
"MessageState",
|
||||
"HandlerMetadata",
|
||||
|
|
|
|||
|
|
@ -49,6 +49,56 @@ from xml_pipeline.memory import get_context_buffer
|
|||
pump_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Event Hooks
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class PumpEvent:
|
||||
"""Base class for pump events."""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReceivedEvent(PumpEvent):
|
||||
"""Fired when a message is received by a handler."""
|
||||
thread_id: str
|
||||
from_id: str
|
||||
to_id: str
|
||||
payload_type: str
|
||||
payload: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageSentEvent(PumpEvent):
|
||||
"""Fired when a handler sends a response."""
|
||||
thread_id: str
|
||||
from_id: str
|
||||
to_id: str
|
||||
payload_type: str
|
||||
payload: Any
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentStateEvent(PumpEvent):
|
||||
"""Fired when an agent's processing state changes."""
|
||||
agent_name: str
|
||||
state: str # "idle", "processing", "waiting", "error"
|
||||
thread_id: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreadEvent(PumpEvent):
|
||||
"""Fired when a thread is created or completed."""
|
||||
thread_id: str
|
||||
status: str # "created", "active", "completed", "error", "killed"
|
||||
participants: List[str] = field(default_factory=list)
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
EventCallback = Callable[[PumpEvent], None]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration (same as before)
|
||||
# ============================================================================
|
||||
|
|
@ -233,6 +283,9 @@ class StreamPump:
|
|||
# Shutdown control
|
||||
self._running = False
|
||||
|
||||
# Event hooks for external observers (ServerState, etc.)
|
||||
self._event_callbacks: List[EventCallback] = []
|
||||
|
||||
# Process pool for cpu_bound handlers
|
||||
self._process_pool: Optional[ProcessPoolExecutor] = None
|
||||
if config.process_pool_enabled:
|
||||
|
|
@ -256,6 +309,27 @@ class StreamPump:
|
|||
self._shared_backend = get_shared_backend(backend_config)
|
||||
pump_logger.info(f"Shared backend: {config.backend_type}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Event Hooks
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def subscribe_events(self, callback: EventCallback) -> None:
|
||||
"""Subscribe to pump events (message flow, agent state, thread lifecycle)."""
|
||||
self._event_callbacks.append(callback)
|
||||
|
||||
def unsubscribe_events(self, callback: EventCallback) -> None:
|
||||
"""Unsubscribe from pump events."""
|
||||
if callback in self._event_callbacks:
|
||||
self._event_callbacks.remove(callback)
|
||||
|
||||
def _emit_event(self, event: PumpEvent) -> None:
|
||||
"""Emit an event to all subscribers (non-blocking)."""
|
||||
for callback in self._event_callbacks:
|
||||
try:
|
||||
callback(event)
|
||||
except Exception as e:
|
||||
pump_logger.warning(f"Event callback error: {e}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -493,6 +567,12 @@ class StreamPump:
|
|||
await semaphore.acquire()
|
||||
|
||||
try:
|
||||
# Emit agent state change event
|
||||
self._emit_event(AgentStateEvent(
|
||||
agent_name=listener.name,
|
||||
state="processing",
|
||||
thread_id=state.thread_id,
|
||||
))
|
||||
# Ensure we have a valid thread chain
|
||||
registry = get_registry()
|
||||
todo_registry = get_todo_registry()
|
||||
|
|
@ -566,6 +646,15 @@ class StreamPump:
|
|||
)
|
||||
payload_ref = state.payload
|
||||
|
||||
# Emit message received event
|
||||
self._emit_event(MessageReceivedEvent(
|
||||
thread_id=current_thread,
|
||||
from_id=state.from_id or "",
|
||||
to_id=listener.name,
|
||||
payload_type=type(payload_ref).__name__,
|
||||
payload=payload_ref,
|
||||
))
|
||||
|
||||
# Dispatch to handler - either in-process or via ProcessPool
|
||||
if listener.cpu_bound and self._process_pool and self._shared_backend:
|
||||
response = await self._dispatch_to_process_pool(
|
||||
|
|
@ -578,6 +667,12 @@ class StreamPump:
|
|||
|
||||
# None means "no response needed" - don't re-inject
|
||||
if response is None:
|
||||
# Emit idle state
|
||||
self._emit_event(AgentStateEvent(
|
||||
agent_name=listener.name,
|
||||
state="idle",
|
||||
thread_id=current_thread,
|
||||
))
|
||||
continue
|
||||
|
||||
# Handle clean HandlerResponse (preferred)
|
||||
|
|
@ -653,6 +748,23 @@ class StreamPump:
|
|||
response_bytes = b"<huh>Handler returned invalid type</huh>"
|
||||
thread_id = state.thread_id
|
||||
|
||||
# Emit message sent event
|
||||
if isinstance(response, HandlerResponse):
|
||||
self._emit_event(MessageSentEvent(
|
||||
thread_id=thread_id,
|
||||
from_id=listener.name,
|
||||
to_id=to_id,
|
||||
payload_type=type(response.payload).__name__,
|
||||
payload=response.payload,
|
||||
))
|
||||
|
||||
# Emit agent state back to idle
|
||||
self._emit_event(AgentStateEvent(
|
||||
agent_name=listener.name,
|
||||
state="idle",
|
||||
thread_id=None,
|
||||
))
|
||||
|
||||
# Yield response — will be processed by next iteration
|
||||
yield MessageState(
|
||||
raw_bytes=response_bytes,
|
||||
|
|
@ -665,6 +777,12 @@ class StreamPump:
|
|||
semaphore.release()
|
||||
|
||||
except Exception as exc:
|
||||
# Emit error state
|
||||
self._emit_event(AgentStateEvent(
|
||||
agent_name=listener.name,
|
||||
state="error",
|
||||
thread_id=state.thread_id,
|
||||
))
|
||||
yield MessageState(
|
||||
raw_bytes=f"<huh>Handler {listener.name} crashed: {exc}</huh>".encode(),
|
||||
thread_id=state.thread_id,
|
||||
|
|
|
|||
26
xml_pipeline/server/__init__.py
Normal file
26
xml_pipeline/server/__init__.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""
|
||||
server — FastAPI-based AgentServer API for monitoring and controlling organisms.
|
||||
|
||||
Provides:
|
||||
- REST API for querying organism state (agents, threads, messages)
|
||||
- WebSocket for real-time events
|
||||
- Message injection endpoint
|
||||
|
||||
Usage:
|
||||
from xml_pipeline.server import create_app, run_server
|
||||
|
||||
# With existing pump
|
||||
app = create_app(pump)
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
|
||||
# Or use CLI
|
||||
xml-pipeline serve config/organism.yaml --port 8080
|
||||
"""
|
||||
|
||||
from xml_pipeline.server.app import create_app, run_server, run_server_sync
|
||||
|
||||
__all__ = [
|
||||
"create_app",
|
||||
"run_server",
|
||||
"run_server_sync",
|
||||
]
|
||||
273
xml_pipeline/server/api.py
Normal file
273
xml_pipeline/server/api.py
Normal file
|
|
@ -0,0 +1,273 @@
|
|||
"""
|
||||
api.py — REST API routes for AgentServer.
|
||||
|
||||
Provides endpoints for:
|
||||
- Organism info and config
|
||||
- Agent listing and details
|
||||
- Thread listing and management
|
||||
- Message injection
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
|
||||
from xml_pipeline.server.models import (
|
||||
AgentInfo,
|
||||
AgentListResponse,
|
||||
ErrorResponse,
|
||||
InjectRequest,
|
||||
InjectResponse,
|
||||
MessageListResponse,
|
||||
OrganismInfo,
|
||||
ThreadInfo,
|
||||
ThreadListResponse,
|
||||
ThreadStatus,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.server.state import ServerState
|
||||
|
||||
|
||||
def create_router(state: "ServerState") -> APIRouter:
|
||||
"""Create API router with state dependency."""
|
||||
router = APIRouter(prefix="/api/v1")
|
||||
|
||||
# =========================================================================
|
||||
# Organism Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/organism", response_model=OrganismInfo)
|
||||
async def get_organism() -> OrganismInfo:
|
||||
"""Get organism overview and stats."""
|
||||
return state.get_organism_info()
|
||||
|
||||
@router.get("/organism/config")
|
||||
async def get_organism_config() -> dict:
|
||||
"""Get sanitized organism configuration (no secrets)."""
|
||||
return state.get_organism_config()
|
||||
|
||||
# =========================================================================
|
||||
# Agent Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/agents", response_model=AgentListResponse)
|
||||
async def list_agents() -> AgentListResponse:
|
||||
"""List all agents with current state."""
|
||||
agents = state.get_agents()
|
||||
return AgentListResponse(agents=agents, count=len(agents))
|
||||
|
||||
@router.get("/agents/{name}", response_model=AgentInfo)
|
||||
async def get_agent(name: str) -> AgentInfo:
|
||||
"""Get single agent details."""
|
||||
agent = state.get_agent(name)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
|
||||
return agent
|
||||
|
||||
@router.get("/agents/{name}/config")
|
||||
async def get_agent_config(name: str) -> dict:
|
||||
"""Get agent's YAML config section."""
|
||||
agent = state.get_agent(name)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
|
||||
|
||||
# Return relevant config fields
|
||||
return {
|
||||
"name": agent.name,
|
||||
"description": agent.description,
|
||||
"isAgent": agent.is_agent,
|
||||
"peers": agent.peers,
|
||||
"payloadClass": agent.payload_class,
|
||||
}
|
||||
|
||||
@router.get("/agents/{name}/schema")
|
||||
async def get_agent_schema(name: str) -> dict:
|
||||
"""Get agent's payload XML schema."""
|
||||
schema = state.get_agent_schema(name)
|
||||
if schema is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Schema not found for agent: {name}",
|
||||
)
|
||||
return {"schema": schema, "contentType": "application/xml"}
|
||||
|
||||
# =========================================================================
|
||||
# Thread Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/threads", response_model=ThreadListResponse)
|
||||
async def list_threads(
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
agent: Optional[str] = Query(None, description="Filter by participant agent"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> ThreadListResponse:
|
||||
"""List threads with optional filtering."""
|
||||
thread_status = None
|
||||
if status:
|
||||
try:
|
||||
thread_status = ThreadStatus(status)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid status: {status}. Valid values: {[s.value for s in ThreadStatus]}",
|
||||
)
|
||||
|
||||
threads, total = state.get_threads(
|
||||
status=thread_status,
|
||||
agent=agent,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return ThreadListResponse(
|
||||
threads=threads,
|
||||
count=len(threads),
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@router.get("/threads/{thread_id}", response_model=ThreadInfo)
|
||||
async def get_thread(thread_id: str) -> ThreadInfo:
|
||||
"""Get thread details with message history."""
|
||||
thread = state.get_thread(thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
|
||||
return thread
|
||||
|
||||
@router.get("/threads/{thread_id}/messages", response_model=MessageListResponse)
|
||||
async def get_thread_messages(
|
||||
thread_id: str,
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> MessageListResponse:
|
||||
"""Get messages in a specific thread."""
|
||||
thread = state.get_thread(thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
|
||||
|
||||
messages, total = state.get_messages(
|
||||
thread_id=thread_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return MessageListResponse(
|
||||
messages=messages,
|
||||
count=len(messages),
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@router.post("/threads/{thread_id}/kill")
|
||||
async def kill_thread(thread_id: str) -> dict:
|
||||
"""Terminate a thread."""
|
||||
thread = state.get_thread(thread_id)
|
||||
if thread is None:
|
||||
raise HTTPException(status_code=404, detail=f"Thread not found: {thread_id}")
|
||||
|
||||
await state.complete_thread(thread_id, status=ThreadStatus.KILLED)
|
||||
return {"success": True, "threadId": thread_id}
|
||||
|
||||
# =========================================================================
|
||||
# Message Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/messages", response_model=MessageListResponse)
|
||||
async def list_messages(
|
||||
agent: Optional[str] = Query(None, description="Filter by agent (sender or receiver)"),
|
||||
limit: int = Query(50, ge=1, le=100),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> MessageListResponse:
|
||||
"""Get global message history."""
|
||||
messages, total = state.get_messages(
|
||||
agent=agent,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
return MessageListResponse(
|
||||
messages=messages,
|
||||
count=len(messages),
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Control Endpoints
|
||||
# =========================================================================
|
||||
|
||||
@router.post("/inject", response_model=InjectResponse)
|
||||
async def inject_message(request: InjectRequest) -> InjectResponse:
|
||||
"""Inject a message to an agent."""
|
||||
# Validate target exists
|
||||
agent = state.get_agent(request.to)
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unknown target agent: {request.to}",
|
||||
)
|
||||
|
||||
# Generate or use provided thread ID
|
||||
thread_id = request.thread_id or str(uuid.uuid4())
|
||||
|
||||
# Build payload XML from dict
|
||||
# For now, we construct a simple wrapper
|
||||
payload_type = next(iter(request.payload.keys()), "Payload")
|
||||
|
||||
# Record the message
|
||||
msg_id = await state.record_message(
|
||||
thread_id=thread_id,
|
||||
from_id="api",
|
||||
to_id=request.to,
|
||||
payload_type=payload_type,
|
||||
payload=request.payload,
|
||||
)
|
||||
|
||||
# TODO: Actually inject into pump queue
|
||||
# This requires building an envelope and calling pump.inject()
|
||||
|
||||
return InjectResponse(thread_id=thread_id, message_id=msg_id)
|
||||
|
||||
@router.post("/agents/{name}/pause")
|
||||
async def pause_agent(name: str) -> dict:
|
||||
"""Pause an agent (stop processing new messages)."""
|
||||
agent = state.get_agent(name)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
|
||||
|
||||
from xml_pipeline.server.models import AgentState
|
||||
|
||||
await state.update_agent_state(name, AgentState.PAUSED)
|
||||
return {"success": True, "agent": name, "state": "paused"}
|
||||
|
||||
@router.post("/agents/{name}/resume")
|
||||
async def resume_agent(name: str) -> dict:
|
||||
"""Resume a paused agent."""
|
||||
agent = state.get_agent(name)
|
||||
if agent is None:
|
||||
raise HTTPException(status_code=404, detail=f"Agent not found: {name}")
|
||||
|
||||
from xml_pipeline.server.models import AgentState
|
||||
|
||||
await state.update_agent_state(name, AgentState.IDLE)
|
||||
return {"success": True, "agent": name, "state": "idle"}
|
||||
|
||||
@router.post("/organism/reload")
|
||||
async def reload_config() -> dict:
|
||||
"""Hot-reload organism configuration."""
|
||||
# TODO: Implement hot-reload
|
||||
return {"success": False, "error": "Hot-reload not yet implemented"}
|
||||
|
||||
@router.post("/organism/stop")
|
||||
async def stop_organism() -> dict:
|
||||
"""Graceful shutdown."""
|
||||
state.set_stopping()
|
||||
# TODO: Signal pump to stop
|
||||
return {"success": True, "status": "stopping"}
|
||||
|
||||
return router
|
||||
148
xml_pipeline/server/app.py
Normal file
148
xml_pipeline/server/app.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
"""
|
||||
app.py — FastAPI application factory for AgentServer.
|
||||
|
||||
Creates the FastAPI app that combines REST API and WebSocket endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from xml_pipeline.server.api import create_router
|
||||
from xml_pipeline.server.state import ServerState
|
||||
from xml_pipeline.server.websocket import create_websocket_router
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.message_bus.stream_pump import StreamPump
|
||||
|
||||
|
||||
def create_app(
|
||||
pump: "StreamPump",
|
||||
*,
|
||||
title: str = "AgentServer API",
|
||||
version: str = "1.0.0",
|
||||
cors_origins: Optional[list[str]] = None,
|
||||
) -> FastAPI:
|
||||
"""
|
||||
Create FastAPI application with REST and WebSocket endpoints.
|
||||
|
||||
Args:
|
||||
pump: The StreamPump instance to wrap
|
||||
title: API title for OpenAPI docs
|
||||
version: API version
|
||||
cors_origins: List of allowed CORS origins (default: all)
|
||||
|
||||
Returns:
|
||||
Configured FastAPI application
|
||||
"""
|
||||
# Create state manager
|
||||
state = ServerState(pump)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
"""Manage app lifecycle - startup and shutdown."""
|
||||
# Startup
|
||||
state.set_running()
|
||||
yield
|
||||
# Shutdown
|
||||
state.set_stopping()
|
||||
|
||||
app = FastAPI(
|
||||
title=title,
|
||||
version=version,
|
||||
description="REST and WebSocket API for monitoring and controlling xml-pipeline organisms.",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
# CORS middleware
|
||||
if cors_origins is None:
|
||||
cors_origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=cors_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Include routers
|
||||
app.include_router(create_router(state))
|
||||
app.include_router(create_websocket_router(state))
|
||||
|
||||
# Store state on app for access if needed
|
||||
app.state.server_state = state
|
||||
app.state.pump = pump
|
||||
|
||||
@app.get("/health")
|
||||
async def health_check() -> dict[str, Any]:
|
||||
"""Health check endpoint."""
|
||||
info = state.get_organism_info()
|
||||
return {
|
||||
"status": "healthy",
|
||||
"organism": info.name,
|
||||
"uptime_seconds": info.uptime_seconds,
|
||||
}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(
|
||||
pump: "StreamPump",
|
||||
*,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8080,
|
||||
cors_origins: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the AgentServer with uvicorn.
|
||||
|
||||
Args:
|
||||
pump: The StreamPump instance to wrap
|
||||
host: Host to bind to
|
||||
port: Port to listen on
|
||||
cors_origins: List of allowed CORS origins
|
||||
"""
|
||||
try:
|
||||
import uvicorn
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"uvicorn is required for the server. Install with: pip install xml-pipeline[server]"
|
||||
) from e
|
||||
|
||||
app = create_app(pump, cors_origins=cors_origins)
|
||||
|
||||
config = uvicorn.Config(
|
||||
app,
|
||||
host=host,
|
||||
port=port,
|
||||
log_level="info",
|
||||
)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
def run_server_sync(
|
||||
pump: "StreamPump",
|
||||
*,
|
||||
host: str = "0.0.0.0",
|
||||
port: int = 8080,
|
||||
cors_origins: Optional[list[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the AgentServer synchronously (blocking).
|
||||
|
||||
This is a convenience wrapper for CLI usage.
|
||||
|
||||
Args:
|
||||
pump: The StreamPump instance to wrap
|
||||
host: Host to bind to
|
||||
port: Port to listen on
|
||||
cors_origins: List of allowed CORS origins
|
||||
"""
|
||||
asyncio.run(run_server(pump, host=host, port=port, cors_origins=cors_origins))
|
||||
261
xml_pipeline/server/models.py
Normal file
261
xml_pipeline/server/models.py
Normal file
|
|
@ -0,0 +1,261 @@
|
|||
"""
|
||||
models.py — Pydantic models for AgentServer API.
|
||||
|
||||
These models define the JSON structure for API responses.
|
||||
Uses camelCase for JSON keys (JavaScript convention).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
def to_camel(string: str) -> str:
|
||||
"""Convert snake_case to camelCase."""
|
||||
components = string.split("_")
|
||||
return components[0] + "".join(x.title() for x in components[1:])
|
||||
|
||||
|
||||
class CamelModel(BaseModel):
|
||||
"""Base model with camelCase JSON serialization."""
|
||||
|
||||
model_config = ConfigDict(
|
||||
alias_generator=to_camel,
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Enums
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AgentState(str, Enum):
|
||||
"""Agent processing state."""
|
||||
|
||||
IDLE = "idle"
|
||||
PROCESSING = "processing"
|
||||
WAITING = "waiting"
|
||||
ERROR = "error"
|
||||
PAUSED = "paused"
|
||||
|
||||
|
||||
class ThreadStatus(str, Enum):
|
||||
"""Thread lifecycle status."""
|
||||
|
||||
ACTIVE = "active"
|
||||
COMPLETED = "completed"
|
||||
ERROR = "error"
|
||||
KILLED = "killed"
|
||||
|
||||
|
||||
class OrganismStatus(str, Enum):
|
||||
"""Organism running status."""
|
||||
|
||||
STARTING = "starting"
|
||||
RUNNING = "running"
|
||||
STOPPING = "stopping"
|
||||
STOPPED = "stopped"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# API Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AgentInfo(CamelModel):
|
||||
"""Agent information for API response."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
is_agent: bool = Field(alias="isAgent")
|
||||
peers: List[str] = Field(default_factory=list)
|
||||
payload_class: str = Field(alias="payloadClass")
|
||||
state: AgentState = AgentState.IDLE
|
||||
current_thread: Optional[str] = Field(None, alias="currentThread")
|
||||
queue_depth: int = Field(0, alias="queueDepth")
|
||||
last_activity: Optional[datetime] = Field(None, alias="lastActivity")
|
||||
message_count: int = Field(0, alias="messageCount")
|
||||
|
||||
|
||||
class MessageInfo(CamelModel):
|
||||
"""Message information for API response."""
|
||||
|
||||
id: str
|
||||
thread_id: str = Field(alias="threadId")
|
||||
from_id: str = Field(alias="from")
|
||||
to_id: str = Field(alias="to")
|
||||
payload_type: str = Field(alias="payloadType")
|
||||
payload: Dict[str, Any] = Field(default_factory=dict)
|
||||
timestamp: datetime
|
||||
slot_index: Optional[int] = Field(None, alias="slotIndex")
|
||||
|
||||
|
||||
class ThreadInfo(CamelModel):
|
||||
"""Thread information for API response."""
|
||||
|
||||
id: str
|
||||
status: ThreadStatus = ThreadStatus.ACTIVE
|
||||
participants: List[str] = Field(default_factory=list)
|
||||
message_count: int = Field(0, alias="messageCount")
|
||||
created_at: datetime = Field(alias="createdAt")
|
||||
last_activity: Optional[datetime] = Field(None, alias="lastActivity")
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class OrganismInfo(CamelModel):
|
||||
"""Organism overview for API response."""
|
||||
|
||||
name: str
|
||||
status: OrganismStatus = OrganismStatus.RUNNING
|
||||
uptime_seconds: float = Field(0.0, alias="uptimeSeconds")
|
||||
agent_count: int = Field(0, alias="agentCount")
|
||||
active_threads: int = Field(0, alias="activeThreads")
|
||||
total_messages: int = Field(0, alias="totalMessages")
|
||||
identity_configured: bool = Field(False, alias="identityConfigured")
|
||||
|
||||
|
||||
class OrganismConfig(CamelModel):
|
||||
"""Sanitized organism configuration for API response."""
|
||||
|
||||
name: str
|
||||
port: int = 8765
|
||||
thread_scheduling: str = Field("breadth-first", alias="threadScheduling")
|
||||
max_concurrent_pipelines: int = Field(50, alias="maxConcurrentPipelines")
|
||||
max_concurrent_handlers: int = Field(20, alias="maxConcurrentHandlers")
|
||||
listeners: List[str] = Field(default_factory=list)
|
||||
# Note: Secrets like API keys are never exposed
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Request Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class InjectRequest(CamelModel):
|
||||
"""Request body for POST /inject."""
|
||||
|
||||
to: str
|
||||
payload: Dict[str, Any]
|
||||
thread_id: Optional[str] = Field(None, alias="threadId")
|
||||
|
||||
|
||||
class InjectResponse(CamelModel):
|
||||
"""Response body for POST /inject."""
|
||||
|
||||
thread_id: str = Field(alias="threadId")
|
||||
message_id: str = Field(alias="messageId")
|
||||
|
||||
|
||||
class SubscribeRequest(CamelModel):
|
||||
"""WebSocket subscription filter."""
|
||||
|
||||
threads: Optional[List[str]] = None
|
||||
agents: Optional[List[str]] = None
|
||||
payload_types: Optional[List[str]] = Field(None, alias="payloadTypes")
|
||||
events: Optional[List[str]] = None
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# WebSocket Event Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class WSEvent(CamelModel):
|
||||
"""Base WebSocket event."""
|
||||
|
||||
event: str
|
||||
|
||||
|
||||
class WSConnectedEvent(WSEvent):
|
||||
"""Sent on WebSocket connection with full state snapshot."""
|
||||
|
||||
event: str = "connected"
|
||||
organism: OrganismInfo
|
||||
agents: List[AgentInfo]
|
||||
threads: List[ThreadInfo]
|
||||
|
||||
|
||||
class WSAgentStateEvent(WSEvent):
|
||||
"""Agent state changed."""
|
||||
|
||||
event: str = "agent_state"
|
||||
agent: str
|
||||
state: AgentState
|
||||
current_thread: Optional[str] = Field(None, alias="currentThread")
|
||||
|
||||
|
||||
class WSMessageEvent(WSEvent):
|
||||
"""New message in the system."""
|
||||
|
||||
event: str = "message"
|
||||
message: MessageInfo
|
||||
|
||||
|
||||
class WSThreadCreatedEvent(WSEvent):
|
||||
"""New thread started."""
|
||||
|
||||
event: str = "thread_created"
|
||||
thread: ThreadInfo
|
||||
|
||||
|
||||
class WSThreadUpdatedEvent(WSEvent):
|
||||
"""Thread status changed."""
|
||||
|
||||
event: str = "thread_updated"
|
||||
thread_id: str = Field(alias="threadId")
|
||||
status: ThreadStatus
|
||||
message_count: int = Field(alias="messageCount")
|
||||
|
||||
|
||||
class WSErrorEvent(WSEvent):
|
||||
"""Error occurred."""
|
||||
|
||||
event: str = "error"
|
||||
thread_id: Optional[str] = Field(None, alias="threadId")
|
||||
agent: Optional[str] = None
|
||||
error: str
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# List Response Models
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class AgentListResponse(CamelModel):
|
||||
"""Response for GET /agents."""
|
||||
|
||||
agents: List[AgentInfo]
|
||||
count: int
|
||||
|
||||
|
||||
class ThreadListResponse(CamelModel):
|
||||
"""Response for GET /threads."""
|
||||
|
||||
threads: List[ThreadInfo]
|
||||
count: int
|
||||
total: int
|
||||
offset: int
|
||||
limit: int
|
||||
|
||||
|
||||
class MessageListResponse(CamelModel):
|
||||
"""Response for GET /messages or /threads/{id}/messages."""
|
||||
|
||||
messages: List[MessageInfo]
|
||||
count: int
|
||||
total: int
|
||||
offset: int
|
||||
limit: int
|
||||
|
||||
|
||||
class ErrorResponse(CamelModel):
|
||||
"""Error response."""
|
||||
|
||||
error: str
|
||||
detail: Optional[str] = None
|
||||
515
xml_pipeline/server/state.py
Normal file
515
xml_pipeline/server/state.py
Normal file
|
|
@ -0,0 +1,515 @@
|
|||
"""
|
||||
state.py — Server state manager for tracking organism runtime state.
|
||||
|
||||
Maintains real-time state for API queries:
|
||||
- Agent states (idle, processing, waiting, error)
|
||||
- Active threads and their participants
|
||||
- Message counts and activity timestamps
|
||||
|
||||
This is the bridge between the StreamPump and the API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set
|
||||
|
||||
from xml_pipeline.server.models import (
|
||||
AgentInfo,
|
||||
AgentState,
|
||||
MessageInfo,
|
||||
OrganismInfo,
|
||||
OrganismStatus,
|
||||
ThreadInfo,
|
||||
ThreadStatus,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.message_bus.stream_pump import StreamPump
|
||||
from xml_pipeline.message_bus import (
|
||||
PumpEvent,
|
||||
MessageReceivedEvent,
|
||||
MessageSentEvent,
|
||||
AgentStateEvent,
|
||||
ThreadEvent,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentRuntimeState:
|
||||
"""Runtime state for a single agent."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
is_agent: bool
|
||||
peers: List[str]
|
||||
payload_class: str
|
||||
state: AgentState = AgentState.IDLE
|
||||
current_thread: Optional[str] = None
|
||||
last_activity: Optional[datetime] = None
|
||||
message_count: int = 0
|
||||
queue_depth: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreadRuntimeState:
|
||||
"""Runtime state for a single thread."""
|
||||
|
||||
id: str
|
||||
status: ThreadStatus = ThreadStatus.ACTIVE
|
||||
participants: Set[str] = field(default_factory=set)
|
||||
message_count: int = 0
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
last_activity: Optional[datetime] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRecord:
|
||||
"""Record of a message for history."""
|
||||
|
||||
id: str
|
||||
thread_id: str
|
||||
from_id: str
|
||||
to_id: str
|
||||
payload_type: str
|
||||
payload: Dict[str, Any]
|
||||
timestamp: datetime
|
||||
slot_index: int
|
||||
|
||||
|
||||
class ServerState:
|
||||
"""
|
||||
Centralized state manager for the API server.
|
||||
|
||||
Tracks runtime state of agents, threads, and messages.
|
||||
Provides methods for the API to query current state.
|
||||
"""
|
||||
|
||||
def __init__(self, pump: StreamPump) -> None:
|
||||
self.pump = pump
|
||||
self._start_time = time.time()
|
||||
self._status = OrganismStatus.STARTING
|
||||
|
||||
# Runtime state
|
||||
self._agents: Dict[str, AgentRuntimeState] = {}
|
||||
self._threads: Dict[str, ThreadRuntimeState] = {}
|
||||
self._messages: List[MessageRecord] = []
|
||||
self._message_count = 0
|
||||
|
||||
# Event subscribers (WebSocket connections)
|
||||
self._subscribers: Set[Callable] = set()
|
||||
|
||||
# Lock for thread-safe updates
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
# Initialize agent states from pump
|
||||
self._init_agents()
|
||||
|
||||
# Subscribe to pump events
|
||||
self._subscribe_to_pump()
|
||||
|
||||
def _init_agents(self) -> None:
|
||||
"""Initialize agent states from pump listeners."""
|
||||
for name, listener in self.pump.listeners.items():
|
||||
self._agents[name] = AgentRuntimeState(
|
||||
name=name,
|
||||
description=listener.description,
|
||||
is_agent=listener.is_agent,
|
||||
peers=list(listener.peers),
|
||||
payload_class=f"{listener.payload_class.__module__}.{listener.payload_class.__name__}",
|
||||
)
|
||||
|
||||
def _subscribe_to_pump(self) -> None:
|
||||
"""Subscribe to pump events for real-time state updates."""
|
||||
from xml_pipeline.message_bus import (
|
||||
PumpEvent,
|
||||
MessageReceivedEvent,
|
||||
MessageSentEvent,
|
||||
AgentStateEvent as PumpAgentStateEvent,
|
||||
ThreadEvent,
|
||||
)
|
||||
|
||||
def handle_pump_event(event: "PumpEvent") -> None:
|
||||
"""Handle pump events synchronously (schedules async updates)."""
|
||||
if isinstance(event, MessageReceivedEvent):
|
||||
# Schedule async message recording
|
||||
asyncio.create_task(self._handle_message_received(event))
|
||||
elif isinstance(event, MessageSentEvent):
|
||||
# Schedule async message recording
|
||||
asyncio.create_task(self._handle_message_sent(event))
|
||||
elif isinstance(event, PumpAgentStateEvent):
|
||||
# Schedule async agent state update
|
||||
asyncio.create_task(self._handle_agent_state(event))
|
||||
elif isinstance(event, ThreadEvent):
|
||||
# Schedule async thread update
|
||||
asyncio.create_task(self._handle_thread_event(event))
|
||||
|
||||
self.pump.subscribe_events(handle_pump_event)
|
||||
|
||||
async def _handle_message_received(self, event: "MessageReceivedEvent") -> None:
|
||||
"""Handle message received by handler."""
|
||||
# Convert payload to dict representation
|
||||
payload_dict = {}
|
||||
if hasattr(event.payload, '__dataclass_fields__'):
|
||||
import dataclasses
|
||||
payload_dict = dataclasses.asdict(event.payload)
|
||||
elif isinstance(event.payload, dict):
|
||||
payload_dict = event.payload
|
||||
|
||||
await self.record_message(
|
||||
thread_id=event.thread_id,
|
||||
from_id=event.from_id,
|
||||
to_id=event.to_id,
|
||||
payload_type=event.payload_type,
|
||||
payload=payload_dict,
|
||||
)
|
||||
|
||||
async def _handle_message_sent(self, event: "MessageSentEvent") -> None:
|
||||
"""Handle message sent by handler."""
|
||||
payload_dict = {}
|
||||
if hasattr(event.payload, '__dataclass_fields__'):
|
||||
import dataclasses
|
||||
payload_dict = dataclasses.asdict(event.payload)
|
||||
elif isinstance(event.payload, dict):
|
||||
payload_dict = event.payload
|
||||
|
||||
await self.record_message(
|
||||
thread_id=event.thread_id,
|
||||
from_id=event.from_id,
|
||||
to_id=event.to_id,
|
||||
payload_type=event.payload_type,
|
||||
payload=payload_dict,
|
||||
)
|
||||
|
||||
async def _handle_agent_state(self, event: "AgentStateEvent") -> None:
|
||||
"""Handle agent state change from pump."""
|
||||
# Map pump state strings to AgentState enum
|
||||
state_map = {
|
||||
"idle": AgentState.IDLE,
|
||||
"processing": AgentState.PROCESSING,
|
||||
"waiting": AgentState.WAITING,
|
||||
"error": AgentState.ERROR,
|
||||
"paused": AgentState.PAUSED,
|
||||
}
|
||||
state = state_map.get(event.state, AgentState.IDLE)
|
||||
await self.update_agent_state(event.agent_name, state, event.thread_id)
|
||||
|
||||
async def _handle_thread_event(self, event: "ThreadEvent") -> None:
|
||||
"""Handle thread lifecycle event from pump."""
|
||||
# Map status to ThreadStatus enum
|
||||
status_map = {
|
||||
"created": ThreadStatus.ACTIVE,
|
||||
"active": ThreadStatus.ACTIVE,
|
||||
"completed": ThreadStatus.COMPLETED,
|
||||
"error": ThreadStatus.ERROR,
|
||||
"killed": ThreadStatus.KILLED,
|
||||
}
|
||||
status = status_map.get(event.status, ThreadStatus.ACTIVE)
|
||||
|
||||
if event.status in ("completed", "error", "killed"):
|
||||
await self.complete_thread(event.thread_id, status, event.error)
|
||||
|
||||
def set_running(self) -> None:
|
||||
"""Mark organism as running."""
|
||||
self._status = OrganismStatus.RUNNING
|
||||
|
||||
def set_stopping(self) -> None:
|
||||
"""Mark organism as stopping."""
|
||||
self._status = OrganismStatus.STOPPING
|
||||
|
||||
# =========================================================================
|
||||
# Event Recording (called by pump hooks)
|
||||
# =========================================================================
|
||||
|
||||
async def record_message(
|
||||
self,
|
||||
thread_id: str,
|
||||
from_id: str,
|
||||
to_id: str,
|
||||
payload_type: str,
|
||||
payload: Dict[str, Any],
|
||||
) -> str:
|
||||
"""Record a message and update related state."""
|
||||
async with self._lock:
|
||||
msg_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc)
|
||||
self._message_count += 1
|
||||
|
||||
record = MessageRecord(
|
||||
id=msg_id,
|
||||
thread_id=thread_id,
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
payload_type=payload_type,
|
||||
payload=payload,
|
||||
timestamp=now,
|
||||
slot_index=self._message_count,
|
||||
)
|
||||
self._messages.append(record)
|
||||
|
||||
# Update thread state
|
||||
if thread_id not in self._threads:
|
||||
self._threads[thread_id] = ThreadRuntimeState(
|
||||
id=thread_id,
|
||||
created_at=now,
|
||||
)
|
||||
thread = self._threads[thread_id]
|
||||
thread.participants.add(from_id)
|
||||
thread.participants.add(to_id)
|
||||
thread.message_count += 1
|
||||
thread.last_activity = now
|
||||
|
||||
# Update agent states
|
||||
if from_id in self._agents:
|
||||
self._agents[from_id].message_count += 1
|
||||
self._agents[from_id].last_activity = now
|
||||
|
||||
# Notify subscribers
|
||||
await self._notify_message(record)
|
||||
|
||||
return msg_id
|
||||
|
||||
async def update_agent_state(
|
||||
self,
|
||||
agent_name: str,
|
||||
state: AgentState,
|
||||
current_thread: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Update an agent's processing state."""
|
||||
async with self._lock:
|
||||
if agent_name not in self._agents:
|
||||
return
|
||||
|
||||
agent = self._agents[agent_name]
|
||||
old_state = agent.state
|
||||
agent.state = state
|
||||
agent.current_thread = current_thread
|
||||
agent.last_activity = datetime.now(timezone.utc)
|
||||
|
||||
if old_state != state:
|
||||
await self._notify_agent_state(agent_name, state, current_thread)
|
||||
|
||||
async def complete_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
status: ThreadStatus = ThreadStatus.COMPLETED,
|
||||
error: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Mark a thread as completed or errored."""
|
||||
async with self._lock:
|
||||
if thread_id not in self._threads:
|
||||
return
|
||||
|
||||
thread = self._threads[thread_id]
|
||||
thread.status = status
|
||||
thread.error = error
|
||||
thread.last_activity = datetime.now(timezone.utc)
|
||||
|
||||
await self._notify_thread_updated(thread)
|
||||
|
||||
# =========================================================================
|
||||
# Query Methods (for API)
|
||||
# =========================================================================
|
||||
|
||||
def get_organism_info(self) -> OrganismInfo:
|
||||
"""Get organism overview."""
|
||||
return OrganismInfo(
|
||||
name=self.pump.config.name,
|
||||
status=self._status,
|
||||
uptime_seconds=time.time() - self._start_time,
|
||||
agent_count=len(self._agents),
|
||||
active_threads=sum(
|
||||
1 for t in self._threads.values() if t.status == ThreadStatus.ACTIVE
|
||||
),
|
||||
total_messages=self._message_count,
|
||||
identity_configured=self.pump.identity is not None,
|
||||
)
|
||||
|
||||
def get_organism_config(self) -> Dict[str, Any]:
|
||||
"""Get sanitized organism config (no secrets)."""
|
||||
return {
|
||||
"name": self.pump.config.name,
|
||||
"port": self.pump.config.port,
|
||||
"thread_scheduling": self.pump.config.thread_scheduling,
|
||||
"max_concurrent_pipelines": self.pump.config.max_concurrent_pipelines,
|
||||
"max_concurrent_handlers": self.pump.config.max_concurrent_handlers,
|
||||
"listeners": list(self.pump.listeners.keys()),
|
||||
}
|
||||
|
||||
def get_agents(self) -> List[AgentInfo]:
|
||||
"""Get all agents."""
|
||||
return [self._agent_to_info(a) for a in self._agents.values()]
|
||||
|
||||
def get_agent(self, name: str) -> Optional[AgentInfo]:
|
||||
"""Get a single agent by name."""
|
||||
agent = self._agents.get(name)
|
||||
if agent:
|
||||
return self._agent_to_info(agent)
|
||||
return None
|
||||
|
||||
def get_agent_schema(self, name: str) -> Optional[str]:
|
||||
"""Get agent's XSD schema as string."""
|
||||
listener = self.pump.listeners.get(name)
|
||||
if listener and listener.schema is not None:
|
||||
from lxml import etree
|
||||
|
||||
return etree.tostring(listener.schema, encoding="unicode", pretty_print=True)
|
||||
return None
|
||||
|
||||
def get_threads(
|
||||
self,
|
||||
status: Optional[ThreadStatus] = None,
|
||||
agent: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[List[ThreadInfo], int]:
|
||||
"""Get threads with optional filtering."""
|
||||
threads = list(self._threads.values())
|
||||
|
||||
# Filter
|
||||
if status:
|
||||
threads = [t for t in threads if t.status == status]
|
||||
if agent:
|
||||
threads = [t for t in threads if agent in t.participants]
|
||||
|
||||
# Sort by last activity (most recent first)
|
||||
threads.sort(key=lambda t: t.last_activity or t.created_at, reverse=True)
|
||||
|
||||
total = len(threads)
|
||||
threads = threads[offset : offset + limit]
|
||||
|
||||
return [self._thread_to_info(t) for t in threads], total
|
||||
|
||||
def get_thread(self, thread_id: str) -> Optional[ThreadInfo]:
|
||||
"""Get a single thread by ID."""
|
||||
thread = self._threads.get(thread_id)
|
||||
if thread:
|
||||
return self._thread_to_info(thread)
|
||||
return None
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
thread_id: Optional[str] = None,
|
||||
agent: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> tuple[List[MessageInfo], int]:
|
||||
"""Get messages with optional filtering."""
|
||||
messages = self._messages.copy()
|
||||
|
||||
# Filter
|
||||
if thread_id:
|
||||
messages = [m for m in messages if m.thread_id == thread_id]
|
||||
if agent:
|
||||
messages = [m for m in messages if m.from_id == agent or m.to_id == agent]
|
||||
|
||||
# Sort by timestamp (most recent first)
|
||||
messages.sort(key=lambda m: m.timestamp, reverse=True)
|
||||
|
||||
total = len(messages)
|
||||
messages = messages[offset : offset + limit]
|
||||
|
||||
return [self._message_to_info(m) for m in messages], total
|
||||
|
||||
# =========================================================================
|
||||
# Subscription Management
|
||||
# =========================================================================
|
||||
|
||||
def subscribe(self, callback: Callable) -> None:
|
||||
"""Subscribe to state events."""
|
||||
self._subscribers.add(callback)
|
||||
|
||||
def unsubscribe(self, callback: Callable) -> None:
|
||||
"""Unsubscribe from state events."""
|
||||
self._subscribers.discard(callback)
|
||||
|
||||
async def _notify_message(self, record: MessageRecord) -> None:
|
||||
"""Notify subscribers of new message."""
|
||||
event = {
|
||||
"event": "message",
|
||||
"message": self._message_to_info(record).model_dump(by_alias=True),
|
||||
}
|
||||
await self._broadcast(event)
|
||||
|
||||
async def _notify_agent_state(
|
||||
self,
|
||||
agent_name: str,
|
||||
state: AgentState,
|
||||
current_thread: Optional[str],
|
||||
) -> None:
|
||||
"""Notify subscribers of agent state change."""
|
||||
event = {
|
||||
"event": "agent_state",
|
||||
"agent": agent_name,
|
||||
"state": state.value,
|
||||
"currentThread": current_thread,
|
||||
}
|
||||
await self._broadcast(event)
|
||||
|
||||
async def _notify_thread_updated(self, thread: ThreadRuntimeState) -> None:
|
||||
"""Notify subscribers of thread update."""
|
||||
event = {
|
||||
"event": "thread_updated",
|
||||
"threadId": thread.id,
|
||||
"status": thread.status.value,
|
||||
"messageCount": thread.message_count,
|
||||
}
|
||||
await self._broadcast(event)
|
||||
|
||||
async def _broadcast(self, event: Dict[str, Any]) -> None:
|
||||
"""Broadcast event to all subscribers."""
|
||||
for callback in list(self._subscribers):
|
||||
try:
|
||||
await callback(event)
|
||||
except Exception:
|
||||
# Remove failed subscribers
|
||||
self._subscribers.discard(callback)
|
||||
|
||||
# =========================================================================
|
||||
# Conversion Helpers
|
||||
# =========================================================================
|
||||
|
||||
def _agent_to_info(self, agent: AgentRuntimeState) -> AgentInfo:
|
||||
"""Convert runtime state to API model."""
|
||||
return AgentInfo(
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
is_agent=agent.is_agent,
|
||||
peers=agent.peers,
|
||||
payload_class=agent.payload_class,
|
||||
state=agent.state,
|
||||
current_thread=agent.current_thread,
|
||||
queue_depth=agent.queue_depth,
|
||||
last_activity=agent.last_activity,
|
||||
message_count=agent.message_count,
|
||||
)
|
||||
|
||||
def _thread_to_info(self, thread: ThreadRuntimeState) -> ThreadInfo:
|
||||
"""Convert runtime state to API model."""
|
||||
return ThreadInfo(
|
||||
id=thread.id,
|
||||
status=thread.status,
|
||||
participants=list(thread.participants),
|
||||
message_count=thread.message_count,
|
||||
created_at=thread.created_at,
|
||||
last_activity=thread.last_activity,
|
||||
error=thread.error,
|
||||
)
|
||||
|
||||
def _message_to_info(self, record: MessageRecord) -> MessageInfo:
|
||||
"""Convert record to API model."""
|
||||
return MessageInfo(
|
||||
id=record.id,
|
||||
thread_id=record.thread_id,
|
||||
from_id=record.from_id,
|
||||
to_id=record.to_id,
|
||||
payload_type=record.payload_type,
|
||||
payload=record.payload,
|
||||
timestamp=record.timestamp,
|
||||
slot_index=record.slot_index,
|
||||
)
|
||||
316
xml_pipeline/server/websocket.py
Normal file
316
xml_pipeline/server/websocket.py
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
"""
|
||||
websocket.py — WebSocket endpoints for AgentServer.
|
||||
|
||||
Provides:
|
||||
- /ws — Main control channel with state snapshot and real-time events
|
||||
- /ws/messages — Dedicated message log stream with filtering
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
|
||||
from xml_pipeline.server.models import (
|
||||
SubscribeRequest,
|
||||
WSConnectedEvent,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.server.state import ServerState
|
||||
|
||||
|
||||
class ConnectionManager:
|
||||
"""Manages WebSocket connections and subscriptions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
self.subscriptions: Dict[WebSocket, SubscribeRequest] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
self.subscriptions[websocket] = SubscribeRequest() # Default: all events
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
self.active_connections.discard(websocket)
|
||||
self.subscriptions.pop(websocket, None)
|
||||
|
||||
def set_subscription(self, websocket: WebSocket, sub: SubscribeRequest) -> None:
|
||||
"""Update subscription filters for a connection."""
|
||||
self.subscriptions[websocket] = sub
|
||||
|
||||
def should_send(self, websocket: WebSocket, event: Dict[str, Any]) -> bool:
|
||||
"""Check if an event should be sent to this connection based on filters."""
|
||||
sub = self.subscriptions.get(websocket)
|
||||
if sub is None:
|
||||
return True # No filters = send all
|
||||
|
||||
event_type = event.get("event", "")
|
||||
|
||||
# Filter by event type
|
||||
if sub.events and event_type not in sub.events:
|
||||
return False
|
||||
|
||||
# Filter by thread
|
||||
thread_id = event.get("thread_id") or event.get("threadId")
|
||||
if sub.threads and thread_id and thread_id not in sub.threads:
|
||||
return False
|
||||
|
||||
# Filter by agent
|
||||
agent = event.get("agent")
|
||||
if sub.agents and agent and agent not in sub.agents:
|
||||
return False
|
||||
|
||||
# For message events, check from/to
|
||||
if event_type == "message":
|
||||
msg = event.get("message", {})
|
||||
from_id = msg.get("from") or msg.get("fromId")
|
||||
to_id = msg.get("to") or msg.get("toId")
|
||||
if sub.agents:
|
||||
if from_id not in sub.agents and to_id not in sub.agents:
|
||||
return False
|
||||
|
||||
# Filter by payload type
|
||||
if sub.payload_types:
|
||||
payload_type = None
|
||||
if event_type == "message":
|
||||
payload_type = event.get("message", {}).get("payloadType")
|
||||
if payload_type and payload_type not in sub.payload_types:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def broadcast(self, event: Dict[str, Any]) -> None:
|
||||
"""Broadcast an event to all connections that match their subscription."""
|
||||
disconnected: List[WebSocket] = []
|
||||
|
||||
for websocket in self.active_connections:
|
||||
if self.should_send(websocket, event):
|
||||
try:
|
||||
await websocket.send_json(event)
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
# Clean up disconnected clients
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
class MessageStreamManager:
|
||||
"""Manages WebSocket connections for /ws/messages endpoint."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
self.filters: Dict[WebSocket, Dict[str, Any]] = {}
|
||||
|
||||
async def connect(self, websocket: WebSocket) -> None:
|
||||
"""Accept a new WebSocket connection."""
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
self.filters[websocket] = {} # Default: all messages
|
||||
|
||||
def disconnect(self, websocket: WebSocket) -> None:
|
||||
"""Remove a WebSocket connection."""
|
||||
self.active_connections.discard(websocket)
|
||||
self.filters.pop(websocket, None)
|
||||
|
||||
def set_filter(self, websocket: WebSocket, filter_config: Dict[str, Any]) -> None:
|
||||
"""Update message filter for a connection."""
|
||||
self.filters[websocket] = filter_config
|
||||
|
||||
def should_send(self, websocket: WebSocket, message: Dict[str, Any]) -> bool:
|
||||
"""Check if a message should be sent to this connection."""
|
||||
flt = self.filters.get(websocket, {})
|
||||
if not flt:
|
||||
return True
|
||||
|
||||
# Filter by agents
|
||||
agents = flt.get("agents", [])
|
||||
if agents:
|
||||
from_id = message.get("from") or message.get("fromId")
|
||||
to_id = message.get("to") or message.get("toId")
|
||||
if from_id not in agents and to_id not in agents:
|
||||
return False
|
||||
|
||||
# Filter by threads
|
||||
threads = flt.get("threads", [])
|
||||
if threads:
|
||||
thread_id = message.get("thread_id") or message.get("threadId")
|
||||
if thread_id not in threads:
|
||||
return False
|
||||
|
||||
# Filter by payload types
|
||||
payload_types = flt.get("payload_types", [])
|
||||
if payload_types:
|
||||
payload_type = message.get("payloadType") or message.get("payload_type")
|
||||
if payload_type not in payload_types:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
async def broadcast_message(self, message: Dict[str, Any]) -> None:
|
||||
"""Broadcast a message to all filtered connections."""
|
||||
disconnected: List[WebSocket] = []
|
||||
|
||||
for websocket in self.active_connections:
|
||||
if self.should_send(websocket, message):
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
except Exception:
|
||||
disconnected.append(websocket)
|
||||
|
||||
for ws in disconnected:
|
||||
self.disconnect(ws)
|
||||
|
||||
|
||||
def create_websocket_router(state: "ServerState") -> APIRouter:
|
||||
"""Create WebSocket router with state dependency."""
|
||||
router = APIRouter()
|
||||
manager = ConnectionManager()
|
||||
message_manager = MessageStreamManager()
|
||||
|
||||
# Subscribe state to WebSocket broadcasting
|
||||
async def on_state_event(event: Dict[str, Any]) -> None:
|
||||
"""Forward state events to WebSocket connections."""
|
||||
await manager.broadcast(event)
|
||||
|
||||
# Also forward message events to the message stream
|
||||
if event.get("event") == "message":
|
||||
msg = event.get("message", {})
|
||||
await message_manager.broadcast_message(msg)
|
||||
|
||||
state.subscribe(on_state_event)
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def websocket_control(websocket: WebSocket) -> None:
|
||||
"""
|
||||
Main control channel WebSocket.
|
||||
|
||||
On connect, sends full state snapshot. Then pushes events as they occur.
|
||||
Accepts commands: subscribe, inject.
|
||||
"""
|
||||
await manager.connect(websocket)
|
||||
|
||||
try:
|
||||
# Send connected event with state snapshot
|
||||
threads, _ = state.get_threads(limit=100)
|
||||
connected_event = WSConnectedEvent(
|
||||
organism=state.get_organism_info(),
|
||||
agents=state.get_agents(),
|
||||
threads=threads,
|
||||
)
|
||||
await websocket.send_json(connected_event.model_dump(by_alias=True))
|
||||
|
||||
# Listen for commands
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_json()
|
||||
except Exception:
|
||||
# Connection closed or invalid data
|
||||
break
|
||||
|
||||
cmd = data.get("cmd", "")
|
||||
|
||||
if cmd == "subscribe":
|
||||
# Update subscription filters
|
||||
sub = SubscribeRequest(
|
||||
threads=data.get("threads"),
|
||||
agents=data.get("agents"),
|
||||
payload_types=data.get("payload_types"),
|
||||
events=data.get("events"),
|
||||
)
|
||||
manager.set_subscription(websocket, sub)
|
||||
await websocket.send_json({"event": "subscribed", "filters": data})
|
||||
|
||||
elif cmd == "inject":
|
||||
# Inject a message (same as REST /inject)
|
||||
target = data.get("to")
|
||||
payload = data.get("payload", {})
|
||||
thread_id = data.get("thread_id")
|
||||
|
||||
if not target:
|
||||
await websocket.send_json(
|
||||
{"event": "error", "error": "Missing 'to' field"}
|
||||
)
|
||||
continue
|
||||
|
||||
agent = state.get_agent(target)
|
||||
if agent is None:
|
||||
await websocket.send_json(
|
||||
{"event": "error", "error": f"Unknown agent: {target}"}
|
||||
)
|
||||
continue
|
||||
|
||||
import uuid
|
||||
|
||||
thread_id = thread_id or str(uuid.uuid4())
|
||||
payload_type = next(iter(payload.keys()), "Payload")
|
||||
|
||||
msg_id = await state.record_message(
|
||||
thread_id=thread_id,
|
||||
from_id="api",
|
||||
to_id=target,
|
||||
payload_type=payload_type,
|
||||
payload=payload,
|
||||
)
|
||||
|
||||
await websocket.send_json(
|
||||
{
|
||||
"event": "injected",
|
||||
"thread_id": thread_id,
|
||||
"message_id": msg_id,
|
||||
}
|
||||
)
|
||||
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{"event": "error", "error": f"Unknown command: {cmd}"}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
manager.disconnect(websocket)
|
||||
|
||||
@router.websocket("/ws/messages")
|
||||
async def websocket_messages(websocket: WebSocket) -> None:
|
||||
"""
|
||||
Dedicated message log stream.
|
||||
|
||||
Streams all messages flowing through the organism.
|
||||
Clients can filter by agents, threads, and payload types.
|
||||
"""
|
||||
await message_manager.connect(websocket)
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
data = await websocket.receive_json()
|
||||
except Exception:
|
||||
break
|
||||
|
||||
cmd = data.get("cmd", "")
|
||||
|
||||
if cmd == "subscribe":
|
||||
# Update filter
|
||||
filter_config = data.get("filter", {})
|
||||
message_manager.set_filter(websocket, filter_config)
|
||||
await websocket.send_json({"event": "subscribed", "filter": filter_config})
|
||||
|
||||
else:
|
||||
await websocket.send_json(
|
||||
{"event": "error", "error": f"Unknown command: {cmd}"}
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
finally:
|
||||
message_manager.disconnect(websocket)
|
||||
|
||||
return router
|
||||
Loading…
Reference in a new issue