xml-pipeline/xml_pipeline/server/state.py
dullfig 3ff399e849 Add hot-reload for organism configuration
Implements runtime configuration reload via POST /api/v1/organism/reload:
- StreamPump.reload_config() re-reads organism.yaml
- Adds new listeners, removes old ones, updates changed ones
- System listeners (system.*) are protected from removal
- ReloadEvent emitted to notify WebSocket subscribers
- ServerState.reload_config() refreshes agent runtime state

14 new tests covering add/remove/update scenarios.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 20:38:48 -08:00

569 lines
19 KiB
Python

"""
state.py — Server state manager for tracking organism runtime state.
Maintains real-time state for API queries:
- Agent states (idle, processing, waiting, error)
- Active threads and their participants
- Message counts and activity timestamps
This is the bridge between the StreamPump and the API.
"""
from __future__ import annotations

import asyncio
import dataclasses
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set

from xml_pipeline.server.models import (
    AgentInfo,
    AgentState,
    MessageInfo,
    OrganismInfo,
    OrganismStatus,
    ThreadInfo,
    ThreadStatus,
)

if TYPE_CHECKING:
    from xml_pipeline.message_bus.stream_pump import StreamPump
    from xml_pipeline.message_bus import (
        PumpEvent,
        MessageReceivedEvent,
        MessageSentEvent,
        AgentStateEvent,
        ThreadEvent,
    )


@dataclass
class AgentRuntimeState:
    """Mutable runtime snapshot of one pump listener, as tracked by the server."""

    # --- Static description, copied from the pump listener ---
    name: str
    description: str
    is_agent: bool
    peers: List[str]
    # Dotted "module.ClassName" path of the listener's payload class.
    payload_class: str

    # --- Live status, updated as events arrive ---
    state: AgentState = AgentState.IDLE
    # Thread associated with the most recent state update, if any.
    current_thread: Optional[str] = None
    last_activity: Optional[datetime] = None
    message_count: int = 0
    queue_depth: int = 0


@dataclass
class ThreadRuntimeState:
    """Mutable runtime snapshot of one conversation thread."""

    id: str
    status: ThreadStatus = ThreadStatus.ACTIVE
    # Names seen as sender or recipient of any message on this thread.
    participants: Set[str] = field(default_factory=set)
    message_count: int = 0
    # Defaults to the (timezone-aware, UTC) moment the instance is created.
    created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    last_activity: Optional[datetime] = None
    # Error text supplied when the thread was completed, if any.
    error: Optional[str] = None


@dataclass
class MessageRecord:
    """One entry in the server's in-memory message history."""

    id: str
    thread_id: str
    from_id: str
    to_id: str
    # Type name of the payload, alongside its dict representation.
    payload_type: str
    payload: Dict[str, Any]
    timestamp: datetime
    slot_index: int


class ServerState:
    """
    Centralized state manager for the API server.

    Tracks runtime state of agents, threads, and messages, and provides
    methods for the API to query current state. The state is kept up to
    date by subscribing to events emitted by the pump given at construction.

    NOTE(review): the message history list grows without bound for the
    lifetime of the process; confirm whether a retention limit is wanted.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self, pump: StreamPump) -> None:
        """
        Create the state manager and hook it up to ``pump``.

        Args:
            pump: The pump whose listeners and events drive this state.
        """
        self.pump = pump
        self._start_time = time.time()
        self._status = OrganismStatus.STARTING

        # Runtime state
        self._agents: Dict[str, AgentRuntimeState] = {}
        self._threads: Dict[str, ThreadRuntimeState] = {}
        self._messages: List[MessageRecord] = []
        self._message_count = 0

        # Event subscribers (WebSocket connections)
        self._subscribers: Set[Callable] = set()

        # Strong references to in-flight event-handler tasks. The event loop
        # only keeps weak references, so an unreferenced task may be
        # garbage-collected before it finishes.
        self._pending_tasks: Set[asyncio.Task] = set()

        # Serializes mutations of the runtime state across coroutines
        self._lock = asyncio.Lock()

        # Initialize agent states from pump
        self._init_agents()

        # Subscribe to pump events
        self._subscribe_to_pump()

    # =========================================================================
    # Pump Integration
    # =========================================================================

    @staticmethod
    def _runtime_state_for(name: str, listener: Any) -> AgentRuntimeState:
        """Build a fresh AgentRuntimeState describing one pump listener."""
        payload_cls = listener.payload_class
        return AgentRuntimeState(
            name=name,
            description=listener.description,
            is_agent=listener.is_agent,
            peers=list(listener.peers),
            payload_class=f"{payload_cls.__module__}.{payload_cls.__name__}",
        )

    def _init_agents(self) -> None:
        """Initialize agent states from pump listeners."""
        for name, listener in self.pump.listeners.items():
            self._agents[name] = self._runtime_state_for(name, listener)

    def _spawn(self, coro: Any) -> None:
        """Schedule ``coro`` as a task and keep it referenced until it is done."""
        task = asyncio.create_task(coro)
        self._pending_tasks.add(task)
        task.add_done_callback(self._on_task_done)

    def _on_task_done(self, task: asyncio.Task) -> None:
        """Release a finished handler task and log any exception it raised."""
        self._pending_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            self._logger.error("Pump event handler failed", exc_info=exc)

    def _subscribe_to_pump(self) -> None:
        """Subscribe to pump events for real-time state updates."""
        # Imported here rather than at module level; the module-level import
        # of these names is guarded by TYPE_CHECKING.
        from xml_pipeline.message_bus import (
            AgentStateEvent as PumpAgentStateEvent,
            MessageReceivedEvent,
            MessageSentEvent,
            ThreadEvent,
        )

        def handle_pump_event(event: "PumpEvent") -> None:
            """Handle pump events synchronously (schedules async updates)."""
            if isinstance(event, MessageReceivedEvent):
                self._spawn(self._handle_message_received(event))
            elif isinstance(event, MessageSentEvent):
                self._spawn(self._handle_message_sent(event))
            elif isinstance(event, PumpAgentStateEvent):
                self._spawn(self._handle_agent_state(event))
            elif isinstance(event, ThreadEvent):
                self._spawn(self._handle_thread_event(event))

        self.pump.subscribe_events(handle_pump_event)

    @staticmethod
    def _payload_to_dict(payload: Any) -> Dict[str, Any]:
        """
        Convert an event payload to a dict representation.

        Dataclass instances are converted with ``dataclasses.asdict``; dicts
        are returned as-is (not copied); anything else yields an empty dict.
        """
        # is_dataclass() is also true for dataclass *types*, which asdict()
        # rejects with TypeError, so only instances are converted.
        if dataclasses.is_dataclass(payload) and not isinstance(payload, type):
            return dataclasses.asdict(payload)
        if isinstance(payload, dict):
            return payload
        return {}

    async def _record_event_message(self, event: Any) -> None:
        """Record the message carried by a received/sent pump event."""
        await self.record_message(
            thread_id=event.thread_id,
            from_id=event.from_id,
            to_id=event.to_id,
            payload_type=event.payload_type,
            payload=self._payload_to_dict(event.payload),
        )

    async def _handle_message_received(self, event: "MessageReceivedEvent") -> None:
        """Handle message received by handler."""
        await self._record_event_message(event)

    async def _handle_message_sent(self, event: "MessageSentEvent") -> None:
        """Handle message sent by handler."""
        await self._record_event_message(event)

    async def _handle_agent_state(self, event: "AgentStateEvent") -> None:
        """Handle agent state change from pump."""
        # Map pump state strings to AgentState enum
        state_map = {
            "idle": AgentState.IDLE,
            "processing": AgentState.PROCESSING,
            "waiting": AgentState.WAITING,
            "error": AgentState.ERROR,
            "paused": AgentState.PAUSED,
        }
        # Unrecognized state strings are treated as idle.
        state = state_map.get(event.state, AgentState.IDLE)
        await self.update_agent_state(event.agent_name, state, event.thread_id)

    async def _handle_thread_event(self, event: "ThreadEvent") -> None:
        """
        Handle thread lifecycle event from pump.

        Only terminal statuses ("completed", "error", "killed") change state
        here; any other status is ignored.
        """
        terminal_statuses = {
            "completed": ThreadStatus.COMPLETED,
            "error": ThreadStatus.ERROR,
            "killed": ThreadStatus.KILLED,
        }
        status = terminal_statuses.get(event.status)
        if status is not None:
            await self.complete_thread(event.thread_id, status, event.error)

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def set_running(self) -> None:
        """Mark organism as running."""
        self._status = OrganismStatus.RUNNING

    def set_stopping(self) -> None:
        """Mark organism as stopping."""
        self._status = OrganismStatus.STOPPING

    async def reload_config(self, config_path: Optional[str] = None) -> dict:
        """
        Hot-reload organism configuration.

        Calls pump.reload_config() and, if it reports success, updates the
        local agent state and broadcasts a "reload" event to subscribers.
        Agents that were added or updated get a fresh runtime state, so
        their counters and status start over.

        Args:
            config_path: Optional path to config file, passed through to the pump.

        Returns:
            Dict with keys "success", "added", "removed", "updated", "error".
        """
        # Call pump's reload
        event = self.pump.reload_config(config_path)

        if event.success:
            # Refresh agent states from pump
            async with self._lock:
                # Remove agents that were removed from pump
                for name in event.removed_listeners:
                    self._agents.pop(name, None)

                # Add/update agents
                for name in event.added_listeners + event.updated_listeners:
                    listener = self.pump.listeners.get(name)
                    if listener:
                        self._agents[name] = self._runtime_state_for(name, listener)

            # Notify subscribers of reload
            await self._broadcast({
                "event": "reload",
                "success": True,
                "added": event.added_listeners,
                "removed": event.removed_listeners,
                "updated": event.updated_listeners,
            })

        return {
            "success": event.success,
            "added": event.added_listeners,
            "removed": event.removed_listeners,
            "updated": event.updated_listeners,
            "error": event.error,
        }

    # =========================================================================
    # Event Recording (called by pump hooks)
    # =========================================================================

    async def record_message(
        self,
        thread_id: str,
        from_id: str,
        to_id: str,
        payload_type: str,
        payload: Dict[str, Any],
    ) -> str:
        """
        Record a message and update related state.

        Appends the message to the history, creates the thread entry on first
        sight, adds both endpoints to the thread's participants, and bumps the
        sender's counters if the sender is a known agent. Subscribers are then
        notified.

        Returns:
            The generated message ID (a UUID4 string).
        """
        async with self._lock:
            msg_id = str(uuid.uuid4())
            now = datetime.now(timezone.utc)

            # The running total doubles as the record's slot index.
            self._message_count += 1
            record = MessageRecord(
                id=msg_id,
                thread_id=thread_id,
                from_id=from_id,
                to_id=to_id,
                payload_type=payload_type,
                payload=payload,
                timestamp=now,
                slot_index=self._message_count,
            )
            self._messages.append(record)

            # Update thread state
            thread = self._threads.get(thread_id)
            if thread is None:
                thread = ThreadRuntimeState(id=thread_id, created_at=now)
                self._threads[thread_id] = thread
            thread.participants.update((from_id, to_id))
            thread.message_count += 1
            thread.last_activity = now

            # Update agent states (sender only)
            sender = self._agents.get(from_id)
            if sender is not None:
                sender.message_count += 1
                sender.last_activity = now

        # Notify subscribers
        await self._notify_message(record)
        return msg_id

    async def update_agent_state(
        self,
        agent_name: str,
        state: AgentState,
        current_thread: Optional[str] = None,
    ) -> None:
        """
        Update an agent's processing state.

        Unknown agent names are ignored. Subscribers are notified only when
        the state value actually changes.
        """
        async with self._lock:
            agent = self._agents.get(agent_name)
            if agent is None:
                return
            old_state = agent.state
            agent.state = state
            agent.current_thread = current_thread
            agent.last_activity = datetime.now(timezone.utc)

        if old_state != state:
            await self._notify_agent_state(agent_name, state, current_thread)

    async def complete_thread(
        self,
        thread_id: str,
        status: ThreadStatus = ThreadStatus.COMPLETED,
        error: Optional[str] = None,
    ) -> None:
        """
        Mark a thread as completed or errored.

        Unknown thread IDs are ignored; otherwise subscribers are notified
        of the updated thread.
        """
        async with self._lock:
            thread = self._threads.get(thread_id)
            if thread is None:
                return
            thread.status = status
            thread.error = error
            thread.last_activity = datetime.now(timezone.utc)

        await self._notify_thread_updated(thread)

    # =========================================================================
    # Query Methods (for API)
    # =========================================================================

    def get_organism_info(self) -> OrganismInfo:
        """Get organism overview."""
        active_threads = sum(
            1 for t in self._threads.values() if t.status == ThreadStatus.ACTIVE
        )
        return OrganismInfo(
            name=self.pump.config.name,
            status=self._status,
            uptime_seconds=time.time() - self._start_time,
            agent_count=len(self._agents),
            active_threads=active_threads,
            total_messages=self._message_count,
            identity_configured=self.pump.identity is not None,
        )

    def get_organism_config(self) -> Dict[str, Any]:
        """Get sanitized organism config (no secrets)."""
        config = self.pump.config
        return {
            "name": config.name,
            "port": config.port,
            "thread_scheduling": config.thread_scheduling,
            "max_concurrent_pipelines": config.max_concurrent_pipelines,
            "max_concurrent_handlers": config.max_concurrent_handlers,
            "listeners": list(self.pump.listeners.keys()),
        }

    def get_agents(self) -> List[AgentInfo]:
        """Get all agents."""
        return [self._agent_to_info(a) for a in self._agents.values()]

    def get_agent(self, name: str) -> Optional[AgentInfo]:
        """Get a single agent by name, or None if it is not tracked."""
        agent = self._agents.get(name)
        if agent:
            return self._agent_to_info(agent)
        return None

    def get_agent_schema(self, name: str) -> Optional[str]:
        """
        Get agent's XSD schema as string.

        Returns None when the pump has no such listener or the listener has
        no schema; otherwise the schema serialized with lxml.
        """
        listener = self.pump.listeners.get(name)
        if listener and listener.schema is not None:
            # Imported lazily so lxml is only loaded when a schema is requested.
            from lxml import etree

            return etree.tostring(listener.schema, encoding="unicode", pretty_print=True)
        return None

    def get_threads(
        self,
        status: Optional[ThreadStatus] = None,
        agent: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[List[ThreadInfo], int]:
        """
        Get threads with optional filtering.

        Args:
            status: If given, keep only threads with this status.
            agent: If given, keep only threads this name participates in.
            limit: Maximum number of threads in the returned page.
            offset: Number of matching threads to skip.

        Returns:
            (page of threads, total number of matches before paging),
            most recently active first.
        """
        threads = list(self._threads.values())

        # Filter
        if status:
            threads = [t for t in threads if t.status == status]
        if agent:
            threads = [t for t in threads if agent in t.participants]

        # Sort by last activity (most recent first), falling back to creation time
        threads.sort(key=lambda t: t.last_activity or t.created_at, reverse=True)

        total = len(threads)
        page = threads[offset : offset + limit]
        return [self._thread_to_info(t) for t in page], total

    def get_thread(self, thread_id: str) -> Optional[ThreadInfo]:
        """Get a single thread by ID, or None if it is not tracked."""
        thread = self._threads.get(thread_id)
        if thread:
            return self._thread_to_info(thread)
        return None

    def get_messages(
        self,
        thread_id: Optional[str] = None,
        agent: Optional[str] = None,
        limit: int = 50,
        offset: int = 0,
    ) -> tuple[List[MessageInfo], int]:
        """
        Get messages with optional filtering.

        Args:
            thread_id: If given, keep only messages on this thread.
            agent: If given, keep only messages sent by or to this name.
            limit: Maximum number of messages in the returned page.
            offset: Number of matching messages to skip.

        Returns:
            (page of messages, total number of matches before paging),
            newest first.
        """
        # Work on a copy so the stored history is never reordered.
        messages = self._messages.copy()

        # Filter
        if thread_id:
            messages = [m for m in messages if m.thread_id == thread_id]
        if agent:
            messages = [m for m in messages if m.from_id == agent or m.to_id == agent]

        # Sort by timestamp (most recent first)
        messages.sort(key=lambda m: m.timestamp, reverse=True)

        total = len(messages)
        page = messages[offset : offset + limit]
        return [self._message_to_info(m) for m in page], total

    # =========================================================================
    # Subscription Management
    # =========================================================================

    def subscribe(self, callback: Callable) -> None:
        """
        Subscribe to state events.

        ``callback`` is awaited with a single event dict for every broadcast.
        """
        self._subscribers.add(callback)

    def unsubscribe(self, callback: Callable) -> None:
        """Unsubscribe from state events (no-op if not subscribed)."""
        self._subscribers.discard(callback)

    async def _notify_message(self, record: MessageRecord) -> None:
        """Notify subscribers of new message."""
        await self._broadcast({
            "event": "message",
            "message": self._message_to_info(record).model_dump(by_alias=True),
        })

    async def _notify_agent_state(
        self,
        agent_name: str,
        state: AgentState,
        current_thread: Optional[str],
    ) -> None:
        """Notify subscribers of agent state change."""
        await self._broadcast({
            "event": "agent_state",
            "agent": agent_name,
            "state": state.value,
            "currentThread": current_thread,
        })

    async def _notify_thread_updated(self, thread: ThreadRuntimeState) -> None:
        """Notify subscribers of thread update."""
        await self._broadcast({
            "event": "thread_updated",
            "threadId": thread.id,
            "status": thread.status.value,
            "messageCount": thread.message_count,
        })

    async def _broadcast(self, event: Dict[str, Any]) -> None:
        """
        Broadcast event to all subscribers.

        Delivery is best-effort: a subscriber whose callback raises is logged
        and removed, and delivery continues with the remaining subscribers.
        """
        # Iterate over a snapshot: the set may change while awaiting callbacks.
        for callback in list(self._subscribers):
            try:
                await callback(event)
            except Exception:
                self._logger.warning(
                    "Removing subscriber %r after failed delivery",
                    callback,
                    exc_info=True,
                )
                self._subscribers.discard(callback)

    # =========================================================================
    # Conversion Helpers
    # =========================================================================

    def _agent_to_info(self, agent: AgentRuntimeState) -> AgentInfo:
        """Convert runtime state to API model."""
        return AgentInfo(
            name=agent.name,
            description=agent.description,
            is_agent=agent.is_agent,
            peers=agent.peers,
            payload_class=agent.payload_class,
            state=agent.state,
            current_thread=agent.current_thread,
            queue_depth=agent.queue_depth,
            last_activity=agent.last_activity,
            message_count=agent.message_count,
        )

    def _thread_to_info(self, thread: ThreadRuntimeState) -> ThreadInfo:
        """Convert runtime state to API model."""
        return ThreadInfo(
            id=thread.id,
            status=thread.status,
            participants=list(thread.participants),
            message_count=thread.message_count,
            created_at=thread.created_at,
            last_activity=thread.last_activity,
            error=thread.error,
        )

    def _message_to_info(self, record: MessageRecord) -> MessageInfo:
        """Convert record to API model."""
        return MessageInfo(
            id=record.id,
            thread_id=record.thread_id,
            from_id=record.from_id,
            to_id=record.to_id,
            payload_type=record.payload_type,
            payload=record.payload,
            timestamp=record.timestamp,
            slot_index=record.slot_index,
        )