"""
pump_aiostream.py — Stream-Based Message Pump using aiostream
This implementation treats the entire message flow as composable streams.
Fan-out (multi-payload, broadcast) is handled naturally via flatmap.
Key insight: Each step is a stream transformer, not a 1:1 function.
The pipeline is just a composition of stream operators.
Dependencies:
pip install aiostream
CPU-Bound Handlers:
Handlers marked with `cpu_bound: true` are dispatched to a
ProcessPoolExecutor instead of running in the main event loop.
This prevents long-running handlers from blocking other messages.
Requires shared backend (Redis or Manager) for cross-process data access.
"""
from __future__ import annotations
import asyncio
import importlib
import logging
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from pathlib import Path
from typing import AsyncIterable, Callable, List, Dict, Any, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from xml_pipeline.crypto.identity import Identity
import yaml
from lxml import etree
from aiostream import stream, pipe, operator
# Import existing step implementations (we'll wrap them)
from xml_pipeline.message_bus.steps.repair import repair_step
from xml_pipeline.message_bus.steps.c14n import c14n_step
from xml_pipeline.message_bus.steps.envelope_validation import envelope_validation_step
from xml_pipeline.message_bus.steps.payload_extraction import payload_extraction_step
from xml_pipeline.message_bus.steps.thread_assignment import thread_assignment_step
from xml_pipeline.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR
from xml_pipeline.message_bus.thread_registry import get_registry
from xml_pipeline.message_bus.todo_registry import get_todo_registry
from xml_pipeline.message_bus.budget_registry import get_budget_registry
from xml_pipeline.message_bus.dispatch_hook import DispatchHook
from xml_pipeline.memory import get_context_buffer
# Module-level logger shared by all pump components.
pump_logger = logging.getLogger(__name__)
# ============================================================================
# Event Hooks
# ============================================================================
@dataclass
class PumpEvent:
    """Common base type for all pump events; carries no data of its own."""
@dataclass
class MessageReceivedEvent(PumpEvent):
    """Fired when a message is received by a handler."""
    thread_id: str    # conversation thread UUID
    from_id: str      # sender listener name (or "external" for outside input)
    to_id: str        # receiving listener name
    payload_type: str # class name of the deserialized payload
    payload: Any      # the deserialized payload object itself
@dataclass
class MessageSentEvent(PumpEvent):
    """Fired when a handler sends a response."""
    thread_id: str    # thread the response belongs to (may be an extended chain)
    from_id: str      # responding listener name
    to_id: str        # target listener name
    payload_type: str # class name of the response payload
    payload: Any      # the response payload object
@dataclass
class AgentStateEvent(PumpEvent):
    """Fired when an agent's processing state changes."""
    agent_name: str
    state: str  # "idle", "processing", "waiting", "error"
    thread_id: Optional[str] = None  # thread that triggered the change, if any
@dataclass
class ThreadEvent(PumpEvent):
    """Fired when a thread is created or completed."""
    thread_id: str
    status: str  # "created", "active", "completed", "error", "killed"
    participants: List[str] = field(default_factory=list)  # listener names involved
    error: Optional[str] = None  # set only for error/killed status
@dataclass
class ReloadEvent(PumpEvent):
    """Fired when organism configuration is reloaded."""
    success: bool
    added_listeners: List[str] = field(default_factory=list)    # newly registered names
    removed_listeners: List[str] = field(default_factory=list)  # names dropped by reload
    updated_listeners: List[str] = field(default_factory=list)  # names whose config changed
    error: Optional[str] = None  # reason when success is False
# Signature for pump-event subscribers; invoked synchronously with each event.
EventCallback = Callable[[PumpEvent], None]
# ============================================================================
# Configuration (same as before)
# ============================================================================
@dataclass
class ListenerConfig:
    """Raw listener configuration as parsed from the organism config.

    Carries import paths; the resolved ``payload_class`` / ``handler`` /
    ``output_class`` objects are populated after the paths are imported.
    """
    name: str
    payload_class_path: str  # import path to the payload dataclass
    handler_path: str        # import path to the async handler callable
    description: str
    is_agent: bool = False
    peers: List[str] = field(default_factory=list)  # names this agent may send to
    broadcast: bool = False
    prompt: str = ""  # System prompt for LLM agents (loaded into PromptRegistry)
    cpu_bound: bool = False  # Dispatch to ProcessPoolExecutor if True
    payload_class: Optional[type] = field(default=None, repr=False)  # resolved after import
    handler: Optional[Callable] = field(default=None, repr=False)    # resolved after import
    # Output type for sequence factory validation
    output_class_path: str = ""  # Import path to output dataclass
    output_class: Optional[type] = field(default=None, repr=False)   # resolved after import
@dataclass
class OrganismConfig:
    """Top-level organism configuration (listeners, limits, backends)."""
    name: str
    identity_path: str = ""  # path to signing identity; empty = envelopes unsigned
    port: int = 8765
    thread_scheduling: str = "breadth-first"
    listeners: List[ListenerConfig] = field(default_factory=list)
    # Concurrency tuning
    max_concurrent_pipelines: int = 50  # Total concurrent messages in pipeline
    max_concurrent_handlers: int = 20  # Concurrent handler invocations
    max_concurrent_per_agent: int = 5  # Per-agent rate limit
    # Token budget enforcement
    max_tokens_per_thread: int = 100_000  # Max tokens per conversation thread
    # LLM configuration (optional)
    llm_config: Dict[str, Any] = field(default_factory=dict)
    # Process pool configuration (for cpu_bound handlers)
    process_pool_workers: int = 4
    process_pool_max_tasks_per_child: int = 100
    process_pool_enabled: bool = False
    # Backend configuration (for shared state)
    backend_type: str = "memory"  # "memory", "manager", "redis"
    backend_redis_url: str = "redis://localhost:6379"
    backend_redis_prefix: str = "xp:"
@dataclass
class Listener:
    """A registered message recipient with its compiled routing artifacts."""
    name: str
    payload_class: type  # payload dataclass this listener accepts
    handler: Callable    # async handler(payload, metadata)
    description: str
    is_agent: bool = False
    peers: List[str] = field(default_factory=list)  # allowed forwarding targets
    broadcast: bool = False
    cpu_bound: bool = False  # Dispatch to ProcessPoolExecutor if True
    handler_path: str = ""  # Import path for worker process
    schema: Optional[etree.XMLSchema] = field(default=None, repr=False)  # compiled XSD
    root_tag: str = ""  # routing key "name.payloadclass"; "*" marks generic listeners
    usage_instructions: str = ""  # Generated at registration for LLM agents
    output_class: Optional[type] = field(default=None, repr=False)  # Output type for sequence validation
# ============================================================================
# Stream-Based Pipeline Steps
# ============================================================================
def wrap_step(step_fn: Callable) -> Callable:
    """Adapt an existing async step (state -> state) for use with pipe.map.

    The existing step signature already matches what pipe.map expects, so
    this is a pass-through, kept as an explicit seam for future
    instrumentation (logging, metrics).
    """
    return step_fn
async def extract_payloads(state: MessageState) -> AsyncIterable[MessageState]:
    """
    Fan-out step: Extract 1..N payloads from handler response.

    This is used with pipe.flatmap — yields multiple states for each input.

    Fix: the dummy-root wrapper bytes had been reduced to empty strings,
    so handler responses containing multiple top-level payloads failed to
    parse as one document (and a single payload was split into its child
    elements). The raw bytes are now wrapped in a <dummy> root so each
    direct child of the wrapper is one payload.
    """
    if state.raw_bytes is None:
        yield state
        return
    try:
        # Wrap in dummy root to handle multiple top-level elements
        wrapped = b"<dummy>" + state.raw_bytes + b"</dummy>"
        tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
        children = list(tree)
        if not children:
            yield state
            return
        for child in children:
            payload_bytes = etree.tostring(child)
            yield MessageState(
                raw_bytes=payload_bytes,
                thread_id=state.thread_id,
                from_id=state.from_id,
                metadata=state.metadata.copy(),
            )
    except Exception:
        # On parse failure, pass through as-is
        yield state
def make_xsd_validation(schema: etree.XMLSchema) -> Callable:
    """Build an async validation step bound to a specific XSD *schema*.

    The returned coroutine checks ``state.payload_tree`` against the schema
    and records any failure on ``state.error``; already-failed or tree-less
    states pass through untouched.
    """
    async def validate(state: MessageState) -> MessageState:
        # Skip states with no tree to check, or that already carry an error.
        if state.error or state.payload_tree is None:
            return state
        try:
            schema.assertValid(state.payload_tree)
        except etree.DocumentInvalid as exc:
            state.error = f"XSD validation failed: {exc}"
        return state
    return validate
def make_deserialization(payload_class: type) -> Callable:
    """Build an async deserialization step bound to *payload_class*.

    The returned coroutine parses ``state.payload_tree`` into an instance
    of the class, storing it on ``state.payload``; parse failures are
    recorded on ``state.error`` instead of raising.
    """
    from third_party.xmlable import parse_element

    async def deserialize(state: MessageState) -> MessageState:
        # Nothing to do for failed or tree-less states.
        if state.error or state.payload_tree is None:
            return state
        try:
            state.payload = parse_element(payload_class, state.payload_tree)
        except Exception as exc:
            state.error = f"Deserialization failed: {exc}"
        return state
    return deserialize
# ============================================================================
# The Stream-Based Pump
# ============================================================================
class StreamPump:
"""
Message pump built on aiostream.
The entire flow is a single composable stream pipeline.
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
CPU-bound handlers can be dispatched to a ProcessPoolExecutor by
marking them with `cpu_bound: true` in config. This requires a
shared backend (Redis or Manager) for cross-process data access.
"""
    def __init__(self, config: OrganismConfig, config_path: str = ""):
        """Initialize pump state from an organism configuration.

        Args:
            config: Parsed OrganismConfig (listeners, limits, backends).
            config_path: Path the config was loaded from; retained so the
                organism can be hot-reloaded later.
        """
        self.config = config
        self.config_path = config_path  # Store path for hot-reload
        # Message queue feeds the stream
        self.queue: asyncio.Queue[MessageState] = asyncio.Queue()
        # Routing table
        self.routing_table: Dict[str, List[Listener]] = {}
        self.listeners: Dict[str, Listener] = {}
        # Identity for envelope signing (optional)
        self.identity: Optional["Identity"] = None
        if config.identity_path:
            try:
                from xml_pipeline.crypto import Identity
                self.identity = Identity.load(config.identity_path)
                pump_logger.info(f"Identity loaded: {config.identity_path}")
            except Exception as e:
                # Signing is optional — run unsigned rather than fail startup.
                pump_logger.warning(f"Failed to load identity: {e}")
        # Generic listeners (accept any payload type)
        # Used for ephemeral orchestration handlers (sequences, buffers)
        self._generic_listeners: Dict[str, Listener] = {}
        # Per-agent semaphores for rate limiting
        self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
        # Shutdown control
        self._running = False
        # Event hooks for external observers (ServerState, etc.)
        self._event_callbacks: List[EventCallback] = []
        # Dispatch hooks for journaling/auditing (DispatchHook protocol)
        self.dispatch_hooks: List[DispatchHook] = []
        # Process pool for cpu_bound handlers
        self._process_pool: Optional[ProcessPoolExecutor] = None
        if config.process_pool_enabled:
            self._process_pool = ProcessPoolExecutor(
                max_workers=config.process_pool_workers,
                max_tasks_per_child=config.process_pool_max_tasks_per_child,
            )
            pump_logger.info(
                f"ProcessPool initialized: {config.process_pool_workers} workers"
            )
        # Shared backend for cross-process state
        self._shared_backend = None
        if config.backend_type != "memory":
            # Imported lazily so purely in-memory organisms avoid the dependency.
            from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend
            backend_config = BackendConfig(
                backend_type=config.backend_type,
                redis_url=config.backend_redis_url,
                redis_prefix=config.backend_redis_prefix,
            )
            self._shared_backend = get_shared_backend(backend_config)
            pump_logger.info(f"Shared backend: {config.backend_type}")
# ------------------------------------------------------------------
# Event Hooks
# ------------------------------------------------------------------
def subscribe_events(self, callback: EventCallback) -> None:
"""Subscribe to pump events (message flow, agent state, thread lifecycle)."""
self._event_callbacks.append(callback)
def unsubscribe_events(self, callback: EventCallback) -> None:
"""Unsubscribe from pump events."""
if callback in self._event_callbacks:
self._event_callbacks.remove(callback)
def _emit_event(self, event: PumpEvent) -> None:
"""Emit an event to all subscribers (non-blocking)."""
for callback in self._event_callbacks:
try:
callback(event)
except Exception as e:
pump_logger.warning(f"Event callback error: {e}")
# ------------------------------------------------------------------
# Dispatch Hooks
# ------------------------------------------------------------------
def register_dispatch_hook(self, hook: DispatchHook) -> None:
"""
Register a dispatch lifecycle hook (journaling, auditing, etc.).
Hooks are called at key dispatch lifecycle points:
on_intent, on_dispatched, on_acknowledged, on_failed, on_thread_complete.
"""
self.dispatch_hooks.append(hook)
def unregister_dispatch_hook(self, hook: DispatchHook) -> None:
"""Remove a dispatch hook."""
if hook in self.dispatch_hooks:
self.dispatch_hooks.remove(hook)
# ------------------------------------------------------------------
# Registration
# ------------------------------------------------------------------
def register_listener(self, lc: ListenerConfig) -> Listener:
root_tag = f"{lc.name.lower()}.{lc.payload_class.__name__.lower()}"
listener = Listener(
name=lc.name,
payload_class=lc.payload_class,
handler=lc.handler,
description=lc.description,
is_agent=lc.is_agent,
peers=lc.peers,
broadcast=lc.broadcast,
cpu_bound=lc.cpu_bound,
handler_path=lc.handler_path, # For worker process import
schema=self._generate_schema(lc.payload_class),
root_tag=root_tag,
output_class=lc.output_class, # For sequence factory validation
)
if lc.is_agent:
self.agent_semaphores[lc.name] = asyncio.Semaphore(
self.config.max_concurrent_per_agent
)
self.routing_table.setdefault(root_tag, []).append(listener)
self.listeners[lc.name] = listener
return listener
def register_generic_listener(
self,
name: str,
handler: Callable,
description: str = "",
) -> Listener:
"""
Register a generic listener that accepts any payload type.
Used for ephemeral orchestration handlers (sequences, buffers)
that need to receive responses from various step types.
Generic listeners:
- Are NOT added to the routing table (no root_tag)
- Are looked up by name (to_id) as a fallback in routing
- Receive payload_tree directly (no XSD validation/deserialization)
Args:
name: Unique listener name (e.g., "sequence_abc123")
handler: Async handler function (receives payload_tree, metadata)
description: Human-readable description
Returns:
Listener object
"""
listener = Listener(
name=name,
payload_class=object, # Placeholder - not used
handler=handler,
description=description,
is_agent=False,
root_tag="*", # Wildcard marker
)
self._generic_listeners[name.lower()] = listener
self.listeners[name] = listener
pump_logger.debug(f"Registered generic listener: {name}")
return listener
def unregister_listener(self, name: str) -> bool:
"""
Remove a listener by name.
Used to clean up ephemeral listeners after orchestration completes.
Args:
name: Listener name to remove
Returns:
True if found and removed, False if not found
"""
name_lower = name.lower()
removed = False
# Remove from generic listeners
if name_lower in self._generic_listeners:
del self._generic_listeners[name_lower]
removed = True
pump_logger.debug(f"Unregistered generic listener: {name}")
# Remove from main listeners dict
if name in self.listeners:
listener = self.listeners.pop(name)
removed = True
# Remove from routing table
if listener.root_tag and listener.root_tag != "*":
listeners_for_tag = self.routing_table.get(listener.root_tag, [])
if listener in listeners_for_tag:
listeners_for_tag.remove(listener)
if not listeners_for_tag:
del self.routing_table[listener.root_tag]
return removed
def register_all(self) -> None:
# First pass: register all listeners
for lc in self.config.listeners:
self.register_listener(lc)
# Second pass: build usage instructions (needs all listeners registered)
for listener in self.listeners.values():
if listener.is_agent and listener.peers:
listener.usage_instructions = self._build_usage_instructions(listener)
def _build_usage_instructions(self, agent: Listener) -> str:
"""
Build LLM system prompt instructions from peer schemas.
Generates human-readable documentation of what messages
this agent can send to its peers.
"""
lines = [
f"You are the {agent.name} agent.",
f"Description: {agent.description}",
"",
"You can send messages to the following peers:",
]
for peer_name in agent.peers:
peer = self.listeners.get(peer_name)
if not peer:
lines.append(f"\n## {peer_name} (not registered)")
continue
lines.append(f"\n## {peer_name}")
lines.append(f"Description: {peer.description}")
# Get XSD schema as readable XML
if hasattr(peer.payload_class, 'xsd'):
xsd_tree = peer.payload_class.xsd()
xsd_str = etree.tostring(xsd_tree, pretty_print=True, encoding='unicode')
lines.append(f"Expected payload schema:\n```xml\n{xsd_str}```")
# Also show a simple example structure
if hasattr(peer.payload_class, '__dataclass_fields__'):
fields = peer.payload_class.__dataclass_fields__
example_lines = [f"<{peer.payload_class.__name__}>"]
for fname, finfo in fields.items():
example_lines.append(f" <{fname}>...{fname}>")
example_lines.append(f"{peer.payload_class.__name__}>")
lines.append(f"Example structure:\n```xml\n" + "\n".join(example_lines) + "\n```")
lines.append("\n---")
lines.append("## Important: Response Semantics")
lines.append("")
lines.append("When you RESPOND (return to your caller), your call chain is pruned.")
lines.append("This means:")
lines.append("- Any sub-agents you called are effectively terminated")
lines.append("- Their state/context is lost (e.g., calculator memory, scratch space)")
lines.append("- You cannot call them again in the same context after responding")
lines.append("")
lines.append("Therefore: Complete ALL sub-tasks before responding to your caller.")
lines.append("If you need results from a peer, wait for their response before you respond.")
return "\n".join(lines)
def _generate_schema(self, payload_class: type) -> etree.XMLSchema:
"""Generate XSD schema from xmlified payload class."""
if hasattr(payload_class, 'xsd'):
xsd_tree = payload_class.xsd()
return etree.XMLSchema(xsd_tree)
# Fallback for non-xmlified classes (e.g., in tests)
permissive = ''
return etree.XMLSchema(etree.fromstring(permissive.encode()))
# ------------------------------------------------------------------
# Stream Source
# ------------------------------------------------------------------
async def _queue_source(self) -> AsyncIterable[MessageState]:
"""Async generator that yields messages from the queue."""
while self._running:
try:
state = await asyncio.wait_for(self.queue.get(), timeout=0.5)
yield state
self.queue.task_done()
except asyncio.TimeoutError:
continue
# ------------------------------------------------------------------
# Pipeline Steps (as stream operators)
# ------------------------------------------------------------------
async def _route_step(self, state: MessageState) -> MessageState:
"""Determine target listeners based on to_id.class format."""
if state.error or state.payload is None:
return state
payload_class_name = type(state.payload).__name__.lower()
to_id = (state.to_id or "").lower()
root_tag = f"{to_id}.{payload_class_name}" if to_id else payload_class_name
targets = self.routing_table.get(root_tag)
if targets:
state.target_listeners = targets
else:
state.error = f"No listener for: {root_tag}"
return state
    async def _dispatch_to_handlers(self, state: MessageState) -> AsyncIterable[MessageState]:
        """
        Fan-out step: Dispatch to handler(s) and yield response states.
        For broadcast, yields one response per listener.
        Each response becomes a new message in the stream.

        Handlers can return:
        - None: no response needed
        - HandlerResponse(payload, to): clean dataclass + target (preferred)
        - bytes: raw envelope XML (legacy, for backwards compatibility)

        Side effects per dispatch: thread registration, todo checks,
        context-buffer appends, budget cleanup, pump events and dispatch
        hooks — statement order here is load-bearing; do not reorder.
        """
        if state.error or not state.target_listeners:
            # Pass through errors/unroutable for downstream handling
            yield state
            return
        for listener in state.target_listeners:
            # Pre-bind so the outer except can safely iterate hook_entries
            # even when a failure happens before the hook section rebinds it.
            hook_entries: list[tuple[DispatchHook, str]] = []
            try:
                # Rate limiting for agents
                semaphore = self.agent_semaphores.get(listener.name)
                if semaphore:
                    await semaphore.acquire()
                try:
                    # Emit agent state change event
                    self._emit_event(AgentStateEvent(
                        agent_name=listener.name,
                        state="processing",
                        thread_id=state.thread_id,
                    ))
                    # Ensure we have a valid thread chain
                    registry = get_registry()
                    todo_registry = get_todo_registry()
                    context_buffer = get_context_buffer()
                    current_thread = state.thread_id or ""
                    # Check if thread exists in registry; if not, register it
                    if current_thread and not registry.lookup(current_thread):
                        # New conversation - register existing UUID to chain
                        # The UUID was assigned by thread_assignment_step
                        from_id = state.from_id or "external"
                        registry.register_thread(current_thread, from_id, listener.name)
                    # Check for todo matches on this message
                    # This may raise eyebrows on watchers for this thread
                    if current_thread and state.payload:
                        payload_type = type(state.payload).__name__
                        todo_registry.check(
                            thread_id=current_thread,
                            payload_type=payload_type,
                            from_id=state.from_id or "",
                            payload=state.payload,
                        )
                    # Detect self-calls (agent sending to itself)
                    is_self_call = (state.from_id or "") == listener.name
                    # Get any raised eyebrows for this agent (for nagging)
                    todo_nudge = ""
                    if listener.is_agent and current_thread:
                        raised = todo_registry.get_raised_for(current_thread, listener.name)
                        todo_nudge = todo_registry.format_nudge(raised)
                    # === CONTEXT BUFFER: Record incoming message ===
                    # Append validated payload to thread's context buffer
                    # The returned BufferSlot becomes the single source of truth
                    slot = None
                    if current_thread and state.payload:
                        try:
                            slot = context_buffer.append(
                                thread_id=current_thread,
                                payload=state.payload,
                                from_id=state.from_id or "unknown",
                                to_id=listener.name,
                                own_name=listener.name if listener.is_agent else None,
                                is_self_call=is_self_call,
                                usage_instructions=listener.usage_instructions,
                                todo_nudge=todo_nudge,
                            )
                        except MemoryError:
                            # Thread exceeded max slots - log and continue
                            import logging
                            logging.getLogger(__name__).warning(
                                f"Thread {current_thread[:8]}... exceeded context buffer limit"
                            )
                    # Derive metadata from slot (single source of truth)
                    # Fall back to manual construction if no slot (e.g., buffer overflow)
                    if slot:
                        from xml_pipeline.memory import slot_to_handler_metadata
                        metadata = slot_to_handler_metadata(slot)
                        payload_ref = slot.payload  # Same reference as in buffer
                    else:
                        metadata = HandlerMetadata(
                            thread_id=current_thread,
                            from_id=state.from_id or "",
                            own_name=listener.name if listener.is_agent else None,
                            is_self_call=is_self_call,
                            usage_instructions=listener.usage_instructions,
                            todo_nudge=todo_nudge,
                        )
                        payload_ref = state.payload
                    # Emit message received event
                    self._emit_event(MessageReceivedEvent(
                        thread_id=current_thread,
                        from_id=state.from_id or "",
                        to_id=listener.name,
                        payload_type=type(payload_ref).__name__,
                        payload=payload_ref,
                    ))
                    # --- Dispatch hooks: on_intent + on_dispatched ---
                    hook_entries: list[tuple[DispatchHook, str]] = []
                    if self.dispatch_hooks:
                        payload_bytes_for_hook = b""
                        if state.raw_bytes:
                            payload_bytes_for_hook = state.raw_bytes
                        for hook in self.dispatch_hooks:
                            try:
                                eid = await hook.on_intent(
                                    thread_id=current_thread,
                                    from_id=state.from_id or "",
                                    to_id=listener.name,
                                    payload_type=type(payload_ref).__name__,
                                    payload_bytes=payload_bytes_for_hook,
                                )
                                hook_entries.append((hook, eid))
                            except Exception as he:
                                pump_logger.warning(f"Dispatch hook on_intent error: {he}")
                        for hook, eid in hook_entries:
                            try:
                                await hook.on_dispatched(eid)
                            except Exception as he:
                                pump_logger.warning(f"Dispatch hook on_dispatched error: {he}")
                    # Dispatch to handler - either in-process or via ProcessPool
                    if listener.cpu_bound and self._process_pool and self._shared_backend:
                        response = await self._dispatch_to_process_pool(
                            listener=listener,
                            payload=payload_ref,
                            metadata=metadata,
                        )
                    else:
                        response = await listener.handler(payload_ref, metadata)
                    # None means "no response needed" - don't re-inject
                    if response is None:
                        # --- Dispatch hooks: on_acknowledged + on_thread_complete ---
                        for hook, eid in hook_entries:
                            try:
                                await hook.on_acknowledged(eid)
                            except Exception as he:
                                pump_logger.warning(f"Dispatch hook on_acknowledged error: {he}")
                        for hook in self.dispatch_hooks:
                            try:
                                await hook.on_thread_complete(current_thread)
                            except Exception as he:
                                pump_logger.warning(f"Dispatch hook on_thread_complete error: {he}")
                        # Thread terminates here - cleanup budget
                        budget_registry = get_budget_registry()
                        final_budget = budget_registry.cleanup_thread(current_thread)
                        if final_budget:
                            pump_logger.debug(
                                f"Thread {current_thread[:8]}... completed: "
                                f"{final_budget.total_tokens} tokens used"
                            )
                        # Emit idle state
                        self._emit_event(AgentStateEvent(
                            agent_name=listener.name,
                            state="idle",
                            thread_id=current_thread,
                        ))
                        continue
                    # Handle clean HandlerResponse (preferred)
                    if isinstance(response, HandlerResponse):
                        registry = get_registry()
                        if response.is_response:
                            # Response back to caller - prune chain
                            target, new_thread_id = registry.prune_for_response(current_thread)
                            if target is None:
                                # Chain exhausted - nowhere to respond to
                                # --- Dispatch hooks: on_acknowledged + on_thread_complete ---
                                for hook, eid in hook_entries:
                                    try:
                                        await hook.on_acknowledged(eid)
                                    except Exception as he:
                                        pump_logger.warning(f"Dispatch hook on_acknowledged error: {he}")
                                for hook in self.dispatch_hooks:
                                    try:
                                        await hook.on_thread_complete(current_thread)
                                    except Exception as he:
                                        pump_logger.warning(f"Dispatch hook on_thread_complete error: {he}")
                                # Cleanup thread budget
                                budget_registry = get_budget_registry()
                                final_budget = budget_registry.cleanup_thread(current_thread)
                                if final_budget:
                                    pump_logger.debug(
                                        f"Thread {current_thread[:8]}... chain exhausted: "
                                        f"{final_budget.total_tokens} tokens used"
                                    )
                                continue
                            to_id = target
                            thread_id = new_thread_id
                        else:
                            # Forward to named target - validate against peers
                            requested_to = response.to
                            # Enforce peer constraints for agents
                            if listener.is_agent and listener.peers:
                                if requested_to not in listener.peers:
                                    # Agent trying to send to non-peer - send generic error back to agent
                                    # Log details internally but don't reveal to agent
                                    import logging
                                    logging.getLogger(__name__).warning(
                                        f"Peer violation: {listener.name} -> {requested_to} (allowed: {listener.peers})"
                                    )
                                    # Send SystemError back to the agent (keeps thread alive)
                                    error_bytes = self._wrap_in_envelope(
                                        payload=ROUTING_ERROR,
                                        from_id="system",
                                        to_id=listener.name,
                                        thread_id=current_thread,
                                    )
                                    yield MessageState(
                                        raw_bytes=error_bytes,
                                        thread_id=current_thread,
                                        from_id="system",
                                    )
                                    continue
                            to_id = requested_to
                            thread_id = registry.extend_chain(current_thread, to_id)
                        # === CONTEXT BUFFER: Record outgoing response ===
                        # Append handler's response to the target thread's buffer
                        # This happens BEFORE serialization - the buffer holds the clean payload
                        try:
                            context_buffer.append(
                                thread_id=thread_id,
                                payload=response.payload,
                                from_id=listener.name,
                                to_id=to_id,
                            )
                        except MemoryError:
                            import logging
                            logging.getLogger(__name__).warning(
                                f"Thread {thread_id[:8]}... exceeded context buffer limit"
                            )
                        response_bytes = self._wrap_in_envelope(
                            payload=response.payload,
                            from_id=listener.name,
                            to_id=to_id,
                            thread_id=thread_id,
                        )
                    # Legacy: raw bytes (backwards compatible)
                    elif isinstance(response, bytes):
                        response_bytes = response
                        thread_id = state.thread_id
                    else:
                        response_bytes = b"Handler returned invalid type"
                        thread_id = state.thread_id
                    # Emit message sent event
                    if isinstance(response, HandlerResponse):
                        self._emit_event(MessageSentEvent(
                            thread_id=thread_id,
                            from_id=listener.name,
                            to_id=to_id,
                            payload_type=type(response.payload).__name__,
                            payload=response.payload,
                        ))
                    # --- Dispatch hooks: on_acknowledged ---
                    for hook, eid in hook_entries:
                        try:
                            await hook.on_acknowledged(eid)
                        except Exception as he:
                            pump_logger.warning(f"Dispatch hook on_acknowledged error: {he}")
                    # Emit agent state back to idle
                    self._emit_event(AgentStateEvent(
                        agent_name=listener.name,
                        state="idle",
                        thread_id=None,
                    ))
                    # Yield response — will be processed by next iteration
                    yield MessageState(
                        raw_bytes=response_bytes,
                        thread_id=thread_id,
                        from_id=listener.name,
                    )
                finally:
                    if semaphore:
                        semaphore.release()
            except Exception as exc:
                # --- Dispatch hooks: on_failed ---
                for hook, eid in hook_entries:
                    try:
                        await hook.on_failed(eid, str(exc))
                    except Exception as he:
                        pump_logger.warning(f"Dispatch hook on_failed error: {he}")
                # Emit error state
                self._emit_event(AgentStateEvent(
                    agent_name=listener.name,
                    state="error",
                    thread_id=state.thread_id,
                ))
                yield MessageState(
                    raw_bytes=f"Handler {listener.name} crashed: {exc}".encode(),
                    thread_id=state.thread_id,
                    from_id=listener.name,
                    error=str(exc),
                )
def _wrap_in_envelope(self, payload: Any, from_id: str, to_id: str, thread_id: str) -> bytes:
"""Wrap a dataclass payload in a message envelope, optionally signed."""
# Serialize payload to XML
if hasattr(payload, 'to_xml'):
# SystemError and similar have manual to_xml()
payload_str = payload.to_xml()
elif hasattr(payload, 'xml_value'):
# @xmlify dataclasses
payload_class_name = type(payload).__name__
payload_tree = payload.xml_value(payload_class_name)
payload_str = etree.tostring(payload_tree, encoding='unicode')
else:
# Fallback for non-xmlify classes
payload_class_name = type(payload).__name__
payload_str = f"<{payload_class_name}>{payload}{payload_class_name}>"
# Add xmlns="" to keep payload out of envelope namespace
if 'xmlns=' not in payload_str:
idx = payload_str.index('>')
payload_str = payload_str[:idx] + ' xmlns=""' + payload_str[idx:]
envelope_str = f"""
{from_id}
{to_id}
{thread_id}
{payload_str}
"""
# Sign if identity is configured
if self.identity is not None:
try:
from xml_pipeline.crypto.signing import sign_envelope
envelope_tree = etree.fromstring(envelope_str.encode('utf-8'))
signed_tree = sign_envelope(envelope_tree, self.identity, in_place=True)
return etree.tostring(signed_tree, encoding='utf-8', xml_declaration=True)
except Exception as e:
pump_logger.warning(f"Failed to sign envelope: {e}")
# Fall through to unsigned
return envelope_str.encode('utf-8')
async def _dispatch_to_process_pool(
self,
listener: Listener,
payload: Any,
metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
"""
Dispatch handler to ProcessPoolExecutor for CPU-bound execution.
This offloads work to a separate process to avoid blocking
the main event loop.
Args:
listener: The target listener
payload: The @xmlify dataclass payload
metadata: Handler metadata
Returns:
HandlerResponse or None (same as direct handler call)
"""
from xml_pipeline.message_bus.worker import (
WorkerTask,
store_task_data,
fetch_response,
cleanup_task_data,
execute_handler,
)
assert self._process_pool is not None
assert self._shared_backend is not None
# Store payload and metadata in shared backend
payload_uuid, metadata_uuid = store_task_data(
self._shared_backend, payload, metadata
)
# Create worker task
task = WorkerTask(
thread_uuid=metadata.thread_id,
payload_uuid=payload_uuid,
handler_path=listener.handler_path,
metadata_uuid=metadata_uuid,
listener_name=listener.name,
is_agent=listener.is_agent,
peers=list(listener.peers),
)
# Backend config for worker process
backend_config = {
"backend_type": self.config.backend_type,
"redis_url": self.config.backend_redis_url,
"redis_prefix": self.config.backend_redis_prefix,
}
try:
# Submit to process pool and await result
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
self._process_pool,
execute_handler,
task,
backend_config,
)
if not result.success:
pump_logger.error(
f"Worker error for {listener.name}: {result.error}"
)
if result.error_traceback:
pump_logger.debug(f"Traceback: {result.error_traceback}")
return None
# Fetch response from shared backend
if result.response_uuid:
response = fetch_response(self._shared_backend, result.response_uuid)
return response
return None
finally:
# Clean up task data from backend
cleanup_task_data(
self._shared_backend,
payload_uuid,
metadata_uuid,
result.response_uuid if 'result' in dir() and result.success else None,
)
async def _reinject_responses(self, state: MessageState) -> None:
"""Push handler responses back into the queue for next iteration."""
await self.queue.put(state)
# ------------------------------------------------------------------
# Build the Pipeline
# ------------------------------------------------------------------
    def build_pipeline(self, source: AsyncIterable[MessageState]):
        """
        Construct the full processing pipeline.

        The result is a lazy aiostream composition over *source*; nothing
        executes until the returned stream is consumed (see run()).
        This is where you configure the flow. Modify this method to:
        - Add/remove steps
        - Change concurrency limits
        - Insert logging/metrics
        - Add filtering
        """
        # The pipeline is a composition of stream operators
        pipeline = (
            stream.iterate(source)
            # ============================================================
            # STAGE 1: Envelope Processing (1:1 transforms)
            # ============================================================
            # Each step maps MessageState -> MessageState; failures are
            # recorded on state.error and filtered out in STAGE 5.
            | pipe.map(repair_step)
            | pipe.map(c14n_step)
            | pipe.map(envelope_validation_step)
            | pipe.map(payload_extraction_step)
            | pipe.map(thread_assignment_step)
            # ============================================================
            # STAGE 2: Fan-out — Extract Multiple Payloads (1:N)
            # ============================================================
            # Handler responses may contain multiple payloads.
            # Each becomes a separate message in the stream.
            | pipe.flatmap(extract_payloads)
            # ============================================================
            # STAGE 3: Per-Payload Validation (1:1 transforms)
            # ============================================================
            # Note: In a real implementation, you'd route to listener-specific
            # validation here. For now, we use a simplified approach.
            | pipe.map(self._validate_and_deserialize)
            # ============================================================
            # STAGE 4: Routing (1:1)
            # ============================================================
            | pipe.map(self._route_step)
            # ============================================================
            # STAGE 5: Filter Errors
            # ============================================================
            # Errors go to a separate handler (could also be a branch).
            # States with no error but also no target listeners are
            # dropped here as well.
            | pipe.map(self._handle_errors)
            | pipe.filter(lambda s: s.error is None and s.target_listeners)
            # ============================================================
            # STAGE 6: Fan-out — Dispatch to Handlers (1:N for broadcast)
            # ============================================================
            # This is where handlers are invoked. Broadcast = multiple yields.
            # task_limit controls concurrent handler invocations.
            | pipe.flatmap(
                self._dispatch_to_handlers,
                task_limit=self.config.max_concurrent_handlers
            )
            # ============================================================
            # STAGE 7: Re-inject Responses
            # ============================================================
            # Handler responses go back into the queue for next iteration.
            # The cycle continues until no more messages.
            | pipe.action(self._reinject_responses)
        )
        return pipeline
async def _validate_and_deserialize(self, state: MessageState) -> MessageState:
"""
Combined validation + deserialization.
Uses to_id + payload tag to find the right listener and schema.
Falls back to generic listeners (ephemeral orchestration handlers)
when no regular listener matches.
"""
if state.error or state.payload_tree is None:
return state
# Build lookup key: to_id.payload_tag (matching routing table format)
payload_tag = state.payload_tree.tag
if payload_tag.startswith("{"):
payload_tag = payload_tag.split("}", 1)[1]
to_id = (state.to_id or "").lower()
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
listeners = self.routing_table.get(lookup_key, [])
# Fallback: check for generic listener by to_id
# Generic listeners accept any payload type (for orchestration)
if not listeners and to_id:
generic_listener = self._generic_listeners.get(to_id)
if generic_listener:
# Generic listener: skip XSD validation and deserialization
# Pass the raw payload_tree to the handler
state.payload = state.payload_tree # Handler receives Element
state.target_listeners = [generic_listener]
state.metadata["generic_handler"] = True
return state
if not listeners:
state.error = f"No listener for: {lookup_key}"
return state
listener = listeners[0]
# Validate against listener's schema
try:
listener.schema.assertValid(state.payload_tree)
except etree.DocumentInvalid as e:
state.error = f"XSD validation failed: {e}"
return state
# Deserialize
try:
from third_party.xmlable import parse_element
state.payload = parse_element(listener.payload_class, state.payload_tree)
except Exception as e:
state.error = f"Deserialization failed: {e}"
return state
async def _handle_errors(self, state: MessageState) -> MessageState:
"""Log errors (could also emit messages)."""
if state.error:
print(f"[ERROR] {state.thread_id}: {state.error}")
# Could emit to a specific listener here
return state
# ------------------------------------------------------------------
# Run the Pump
# ------------------------------------------------------------------
    async def run(self) -> None:
        """
        Main entry point — run the stream pipeline.
        The pipeline pulls from the queue, processes messages,
        and re-injects handler responses. Continues until shutdown.
        """
        self._running = True
        # Build the (lazy) pipeline over the queue-backed source; nothing
        # executes until the stream is consumed below.
        pipeline = self.build_pipeline(self._queue_source())
        try:
            # Entering stream() starts execution; iterating drains items
            # as the final pipeline stage yields them.
            async with pipeline.stream() as streamer:
                async for _ in streamer:
                    # The pipeline drives itself via re-injection.
                    # We just need to consume the stream.
                    pass
        except asyncio.CancelledError:
            # Task cancellation is the expected shutdown path; swallow it
            # so callers can cancel cleanly.
            pass
        finally:
            self._running = False
# ------------------------------------------------------------------
# External API
# ------------------------------------------------------------------
async def inject(self, raw_bytes: bytes, thread_id: str, from_id: str) -> None:
"""Inject a message to start processing."""
state = MessageState(
raw_bytes=raw_bytes,
thread_id=thread_id,
from_id=from_id,
)
await self.queue.put(state)
async def shutdown(self) -> None:
"""Graceful shutdown — wait for queue to drain and close resources."""
self._running = False
await self.queue.join()
# Shutdown process pool if active
if self._process_pool:
self._process_pool.shutdown(wait=True)
pump_logger.info("ProcessPool shutdown complete")
self._process_pool = None
    def reload_config(self, config_path: Optional[str] = None) -> ReloadEvent:
        """
        Hot-reload organism configuration.
        Re-reads the config file and updates listeners:
        - New listeners are registered
        - Removed listeners are unregistered
        - Changed listeners are updated (handler, peers, description)
        Args:
            config_path: Path to config file. Uses stored path if not provided.
        Returns:
            ReloadEvent with details of what changed
        """
        path = config_path or self.config_path
        if not path:
            return ReloadEvent(
                success=False,
                error="No config path available for reload",
            )
        try:
            # Re-read config
            new_config = ConfigLoader.load(path)
            # Track changes
            added: List[str] = []
            removed: List[str] = []
            updated: List[str] = []
            # Get current listener names (excluding system listeners).
            # "system.*" listeners are registered in bootstrap(), not in the
            # config file, so they must never be diffed away on reload.
            current_names = {
                name for name in self.listeners.keys()
                if not name.startswith("system.")
            }
            # Get new listener names
            new_names = {lc.name for lc in new_config.listeners}
            # Find removed listeners
            for name in current_names - new_names:
                if self.unregister_listener(name):
                    removed.append(name)
                    pump_logger.info(f"Hot-reload: removed listener '{name}'")
            # Find new and updated listeners
            for lc in new_config.listeners:
                # Resolve imports for the listener config (loads the payload
                # class and handler callables from their dotted paths)
                ConfigLoader._resolve_imports(lc)
                if lc.name in current_names:
                    # Check if changed
                    existing = self.listeners.get(lc.name)
                    if existing and self._listener_changed(existing, lc):
                        # Remove old and re-register — unregister first so
                        # the routing table holds no stale entries.
                        self.unregister_listener(lc.name)
                        self.register_listener(lc)
                        updated.append(lc.name)
                        pump_logger.info(f"Hot-reload: updated listener '{lc.name}'")
                else:
                    # New listener
                    self.register_listener(lc)
                    added.append(lc.name)
                    pump_logger.info(f"Hot-reload: added listener '{lc.name}'")
            # Rebuild usage instructions for all agents (peers may have changed)
            for listener in self.listeners.values():
                if listener.is_agent and listener.peers:
                    listener.usage_instructions = self._build_usage_instructions(listener)
            # Update stored config
            self.config = new_config
            # Emit reload event
            event = ReloadEvent(
                success=True,
                added_listeners=added,
                removed_listeners=removed,
                updated_listeners=updated,
            )
            self._emit_event(event)
            pump_logger.info(
                f"Hot-reload complete: +{len(added)} -{len(removed)} ~{len(updated)}"
            )
            return event
        except Exception as e:
            # Any failure (unreadable file, bad import path, parse error) is
            # reported via the returned event instead of raising; listeners
            # registered before the failure remain in effect.
            pump_logger.error(f"Hot-reload failed: {e}")
            event = ReloadEvent(success=False, error=str(e))
            self._emit_event(event)
            return event
def _listener_changed(self, existing: Listener, new_config: ListenerConfig) -> bool:
"""Check if listener config has changed."""
# Compare key fields
if existing.handler != new_config.handler:
return True
if existing.payload_class != new_config.payload_class:
return True
if existing.description != new_config.description:
return True
if existing.is_agent != new_config.is_agent:
return True
if set(existing.peers) != set(new_config.peers):
return True
if existing.broadcast != new_config.broadcast:
return True
if existing.cpu_bound != new_config.cpu_bound:
return True
return False
# ============================================================================
# Config Loader (same as before)
# ============================================================================
class ConfigLoader:
@classmethod
def load(cls, path: str | Path) -> OrganismConfig:
with open(Path(path)) as f:
raw = yaml.safe_load(f)
return cls._parse(raw)
@classmethod
def _parse(cls, raw: dict) -> OrganismConfig:
org = raw.get("organism", {})
# Parse process pool config
pool = raw.get("process_pool", {})
process_pool_enabled = pool.get("enabled", False) if pool else False
process_pool_workers = pool.get("workers", 4) if pool else 4
process_pool_max_tasks = pool.get("max_tasks_per_child", 100) if pool else 100
# Parse backend config
backend = raw.get("backend", {})
backend_type = backend.get("type", "memory") if backend else "memory"
backend_redis_url = backend.get("redis_url", "redis://localhost:6379") if backend else "redis://localhost:6379"
backend_redis_prefix = backend.get("redis_prefix", "xp:") if backend else "xp:"
config = OrganismConfig(
name=org.get("name", "unnamed"),
identity_path=org.get("identity", ""),
port=org.get("port", 8765),
thread_scheduling=raw.get("thread_scheduling", "breadth-first"),
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
max_tokens_per_thread=raw.get("max_tokens_per_thread", 100_000),
llm_config=raw.get("llm", {}),
process_pool_enabled=process_pool_enabled,
process_pool_workers=process_pool_workers,
process_pool_max_tasks_per_child=process_pool_max_tasks,
backend_type=backend_type,
backend_redis_url=backend_redis_url,
backend_redis_prefix=backend_redis_prefix,
)
for entry in raw.get("listeners", []):
lc = cls._parse_listener(entry)
cls._resolve_imports(lc)
config.listeners.append(lc)
return config
@classmethod
def _parse_listener(cls, raw: dict) -> ListenerConfig:
return ListenerConfig(
name=raw["name"],
payload_class_path=raw["payload_class"],
handler_path=raw["handler"],
description=raw["description"],
is_agent=raw.get("agent", False),
peers=raw.get("peers", []),
broadcast=raw.get("broadcast", False),
prompt=raw.get("prompt", ""),
cpu_bound=raw.get("cpu_bound", False),
output_class_path=raw.get("output_class", ""),
)
@classmethod
def _resolve_imports(cls, lc: ListenerConfig) -> None:
mod, cls_name = lc.payload_class_path.rsplit(".", 1)
lc.payload_class = getattr(importlib.import_module(mod), cls_name)
mod, fn_name = lc.handler_path.rsplit(".", 1)
lc.handler = getattr(importlib.import_module(mod), fn_name)
# Load output class if specified (for sequence factory validation)
if lc.output_class_path:
mod, cls_name = lc.output_class_path.rsplit(".", 1)
lc.output_class = getattr(importlib.import_module(mod), cls_name)
# ============================================================================
# Bootstrap
# ============================================================================
async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
    """Load config, create pump, initialize root thread, and inject boot message.

    Startup order matters here:
    1. System ("system.*") listeners are registered before user listeners.
    2. register_all() registers the config-defined listeners.
    3. Agent prompts are registered, then the prompt registry is frozen.
    4. The root thread is created and a Boot envelope is injected; it is
       only processed once the caller runs pump.run().

    Args:
        config_path: Path to the organism YAML config.

    Returns:
        The fully wired StreamPump (also installed as the global singleton).
    """
    from datetime import datetime, timezone
    from dotenv import load_dotenv
    from xml_pipeline.primitives import Boot, handle_boot
    from xml_pipeline.primitives import (
        TodoUntil, TodoComplete,
        handle_todo_until, handle_todo_complete,
    )
    from xml_pipeline.platform import get_prompt_registry
    # Load .env file if present
    load_dotenv()
    config = ConfigLoader.load(config_path)
    print(f"Organism: {config.name}")
    print(f"Listeners: {len(config.listeners)}")
    pump = StreamPump(config, config_path=config_path)
    # Register system listeners first.
    # These ListenerConfigs pass payload_class/handler directly, so no
    # _resolve_imports step is needed for them.
    boot_listener_config = ListenerConfig(
        name="system.boot",
        payload_class_path="xml_pipeline.primitives.Boot",
        handler_path="xml_pipeline.primitives.handle_boot",
        description="System boot handler - initializes organism",
        is_agent=False,
        payload_class=Boot,
        handler=handle_boot,
    )
    pump.register_listener(boot_listener_config)
    # Register TodoUntil handler (agents register watchers)
    todo_until_config = ListenerConfig(
        name="system.todo",
        payload_class_path="xml_pipeline.primitives.TodoUntil",
        handler_path="xml_pipeline.primitives.handle_todo_until",
        description="System todo handler - registers watchers",
        is_agent=False,
        payload_class=TodoUntil,
        handler=handle_todo_until,
    )
    pump.register_listener(todo_until_config)
    # Register TodoComplete handler (agents close watchers)
    todo_complete_config = ListenerConfig(
        name="system.todo-complete",
        payload_class_path="xml_pipeline.primitives.TodoComplete",
        handler_path="xml_pipeline.primitives.handle_todo_complete",
        description="System todo handler - closes watchers",
        is_agent=False,
        payload_class=TodoComplete,
        handler=handle_todo_complete,
    )
    pump.register_listener(todo_complete_config)
    # Register Sequence primitives (orchestration)
    from xml_pipeline.primitives.sequence import (
        SequenceStart, handle_sequence_start,
    )
    sequence_config = ListenerConfig(
        name="system.sequence",
        payload_class_path="xml_pipeline.primitives.sequence.SequenceStart",
        handler_path="xml_pipeline.primitives.sequence.handle_sequence_start",
        description="System sequence handler - chains listeners in order",
        is_agent=False,
        payload_class=SequenceStart,
        handler=handle_sequence_start,
    )
    pump.register_listener(sequence_config)
    # Register Sequence Factory (validates sequences before execution)
    from xml_pipeline.sequence import (
        SequenceRequest, handle_sequence_request,
    )
    sequence_factory_config = ListenerConfig(
        name="system.sequence.factory",
        payload_class_path="xml_pipeline.sequence.primitives.SequenceRequest",
        handler_path="xml_pipeline.sequence.handler.handle_sequence_request",
        description="Validates and builds sequences - agents request here first",
        is_agent=False,
        payload_class=SequenceRequest,
        handler=handle_sequence_request,
    )
    pump.register_listener(sequence_factory_config)
    # Register Buffer primitives (fan-out orchestration)
    from xml_pipeline.primitives.buffer import (
        BufferStart, handle_buffer_start,
    )
    buffer_config = ListenerConfig(
        name="system.buffer",
        payload_class_path="xml_pipeline.primitives.buffer.BufferStart",
        handler_path="xml_pipeline.primitives.buffer.handle_buffer_start",
        description="System buffer handler - fan-out to parallel workers",
        is_agent=False,
        payload_class=BufferStart,
        handler=handle_buffer_start,
    )
    pump.register_listener(buffer_config)
    # Register all user-defined listeners
    pump.register_all()
    # Load prompts into PromptRegistry (platform-managed, immutable)
    prompt_registry = get_prompt_registry()
    prompt_count = 0
    for listener in pump.listeners.values():
        if listener.is_agent:
            # Get prompt from config (may be empty); system listeners are
            # never agents, so only config-defined entries match here.
            lc = next((l for l in config.listeners if l.name == listener.name), None)
            system_prompt = lc.prompt if lc else ""
            # Register prompt with peer schemas (usage_instructions)
            prompt_registry.register(
                agent_name=listener.name,
                system_prompt=system_prompt,
                peer_schemas=listener.usage_instructions,
            )
            prompt_count += 1
    # Freeze registry - no more registrations allowed
    prompt_registry.freeze()
    print(f"Prompts: {prompt_count} agents registered, registry frozen")
    # Configure LLM router if llm section present
    if config.llm_config:
        from xml_pipeline.llm import configure_router
        configure_router(config.llm_config)
        print(f"LLM backends: {len(config.llm_config.get('backends', []))}")
    # Configure thread budget registry
    from xml_pipeline.message_bus.budget_registry import configure_budget_registry
    configure_budget_registry(config.max_tokens_per_thread)
    print(f"Token budget: {config.max_tokens_per_thread:,} per thread")
    # Initialize root thread in registry
    registry = get_registry()
    root_uuid = registry.initialize_root(config.name)
    print(f"Root thread: {root_uuid} ({registry.root_chain})")
    # Create and inject the boot message
    boot_payload = Boot(
        organism_name=config.name,
        timestamp=datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        listener_count=len(pump.listeners),
    )
    # Wrap boot payload in envelope
    boot_envelope = pump._wrap_in_envelope(
        payload=boot_payload,
        from_id="system",
        to_id="system.boot",
        thread_id=root_uuid,
    )
    # Inject boot message (will be processed when pump.run() is called)
    await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system")
    # Set global pump instance for get_stream_pump()
    set_stream_pump(pump)
    print(f"Routing: {list(pump.routing_table.keys())}")
    return pump
# ============================================================================
# Example: Customizing the Pipeline
# ============================================================================
"""
The beauty of aiostream: the pipeline is just a composition.
You can easily insert, remove, or reorder stages.
# Add logging between stages:
| pipe.action(lambda s: print(f"After repair: {s.thread_id}"))
# Add throttling:
| pipe.map(some_step, task_limit=5)
# Branch errors to a separate stream:
errors, valid = stream.partition(source, lambda s: s.error is not None)
# Merge multiple sources:
combined = stream.merge(queue_source, oob_source, external_api_source)
# Add timeout per message:
| pipe.timeout(30.0) # 30 second timeout per item
# Rate limit the whole pipeline:
| pipe.spaceout(0.1) # 100ms between items
"""
# ============================================================================
# Comparison: Old vs New
# ============================================================================
"""
OLD (bus.py):
for payload in payloads:
await pipeline.process(payload) # Sequential, recursive
NEW (aiostream):
| pipe.flatmap(extract_payloads) # Fan-out, parallel
| pipe.flatmap(dispatch, task_limit=20) # Concurrent handlers
The key difference:
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
"""
# ============================================================================
# Global Singleton
# ============================================================================
# Process-wide singleton, installed by bootstrap() via set_stream_pump().
_pump: Optional["StreamPump"] = None


def get_stream_pump() -> "StreamPump":
    """
    Return the global StreamPump instance.

    The pump is initialized via bootstrap() and set here.

    Raises:
        RuntimeError: If called before bootstrap().
    """
    if _pump is None:
        raise RuntimeError(
            "StreamPump not initialized. Call bootstrap() first."
        )
    return _pump
def set_stream_pump(pump: "StreamPump") -> None:
    """
    Install *pump* as the global StreamPump instance.

    Called by bootstrap() after creating the pump.
    """
    global _pump
    _pump = pump
def reset_stream_pump() -> None:
    """
    Clear the global StreamPump instance.

    Useful for testing.
    """
    global _pump
    _pump = None