diff --git a/__init__.py b/__init__.py index 6672898..cfc1c94 100644 --- a/__init__.py +++ b/__init__.py @@ -3,19 +3,26 @@ xml-pipeline ============ Secure, XML-centric multi-listener organism server. + +Stream-based message pump with aiostream for fan-out handling. """ -from agentserver.agentserver import AgentServer as AgentServer -from agentserver.xml_listener import XMLListener as XMLListener -from agentserver.message_bus import MessageBus as MessageBus -from agentserver.message_bus import Session as Session - +from agentserver.message_bus import ( + StreamPump, + ConfigLoader, + Listener, + MessageState, + HandlerMetadata, + bootstrap, +) __all__ = [ - "AgentServer", - "XMLListener", - "MessageBus", - "Session", + "StreamPump", + "ConfigLoader", + "Listener", + "MessageState", + "HandlerMetadata", + "bootstrap", ] -__version__ = "0.1.0" \ No newline at end of file +__version__ = "0.2.0" # Bumped for aiostream pump \ No newline at end of file diff --git a/agentserver/message_bus/__init__.py b/agentserver/message_bus/__init__.py index e69de29..d391572 100644 --- a/agentserver/message_bus/__init__.py +++ b/agentserver/message_bus/__init__.py @@ -0,0 +1,44 @@ +""" +message_bus — Stream-based message pump for AgentServer v2.1 + +The message pump handles message flow through the organism: +- YAML config → bootstrap → pump → handlers → responses → loop + +Key classes: + StreamPump Main pump class (queue-backed, aiostream-powered) + ConfigLoader Load organism.yaml and resolve imports + Listener Runtime listener with handler and routing info + MessageState Message flowing through pipeline steps + +Usage: + from agentserver.message_bus import StreamPump, bootstrap + + pump = await bootstrap("config/organism.yaml") + await pump.inject(initial_message, thread_id, from_id) + await pump.run() +""" + +from agentserver.message_bus.stream_pump import ( + StreamPump, + ConfigLoader, + Listener, + ListenerConfig, + OrganismConfig, + bootstrap, +) + +from 
agentserver.message_bus.message_state import ( + MessageState, + HandlerMetadata, +) + +__all__ = [ + "StreamPump", + "ConfigLoader", + "Listener", + "ListenerConfig", + "OrganismConfig", + "MessageState", + "HandlerMetadata", + "bootstrap", +] diff --git a/agentserver/message_bus/bus.py b/agentserver/message_bus/bus.py deleted file mode 100644 index 680d670..0000000 --- a/agentserver/message_bus/bus.py +++ /dev/null @@ -1,226 +0,0 @@ -""" -bus.py — The central MessageBus and pump for AgentServer v2.1 - -This is the beating heart of the organism: -- Owns all pipelines (one per listener + permanent system pipeline) -- Maintains the routing table (root_tag → Listener(s)) -- Orchestrates ingress from sockets/gateways -- Dispatches prepared messages to handlers -- Processes handler responses (multi-payload extraction, provenance injection) -- Guarantees thread continuity and diagnostic injection - -Fully aligned with: - - listener-class-v2.1.md - - configuration-v2.1.md - - message-pump-v2.1.md -""" - -from __future__ import annotations - -import asyncio -from dataclasses import dataclass -from typing import Callable, Awaitable, List -from uuid import uuid4 - -from lxml import etree - -from agentserver.message_bus.message_state import MessageState, HandlerMetadata -from agentserver.message_bus.steps.repair import repair_step -from agentserver.message_bus.steps.c14n import c14n_step -from agentserver.message_bus.steps.envelope_validation import envelope_validation_step -from agentserver.message_bus.steps.payload_extraction import payload_extraction_step -from agentserver.message_bus.steps.thread_assignment import thread_assignment_step -from agentserver.message_bus.steps.xsd_validation import xsd_validation_step -from agentserver.message_bus.steps.deserialization import deserialization_step -from agentserver.message_bus.steps.routing_resolution import routing_resolution_step - -# Type alias for pipeline steps -PipelineStep = Callable[[MessageState], 
Awaitable[MessageState]] - - -@dataclass -class Listener: - """Registered capability — defined in listener.py, referenced here.""" - name: str - payload_class: type - handler: Callable - description: str - is_agent: bool = False - peers: list[str] | None = None - broadcast: bool = False - pipeline: "Pipeline" | None = None - schema: etree.XMLSchema | None = None # cached at registration - - -class Pipeline: - """One dedicated pipeline per listener (plus system pipeline).""" - def __init__(self, steps: List[PipelineStep]): - self.steps = steps - - async def process(self, initial_state: MessageState) -> None: - """Run the full ordered pipeline on a message.""" - state = initial_state - for step in self.steps: - try: - state = await step(state) - if state.error: - break - except Exception as exc: # pylint: disable=broad-except - state.error = f"Pipeline step {step.__name__} crashed: {exc}" - break - - # After all steps — dispatch if routable - if state.target_listeners: - await MessageBus.get_instance().dispatcher(state) - else: - # Fall back to system pipeline for diagnostics - await MessageBus.get_instance().system_pipeline.process(state) - - -class MessageBus: - """Singleton message bus — the pump.""" - _instance: "MessageBus" | None = None - - def __init__(self): - self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s) - self.listeners: dict[str, Listener] = {} # name → Listener - self.system_pipeline = Pipeline(self._build_system_steps()) - - @classmethod - def get_instance(cls) -> "MessageBus": - if cls._instance is None: - cls._instance = MessageBus() - return cls._instance - - # ------------------------------------------------------------------ # - # Default step lists - # ------------------------------------------------------------------ # - def _build_default_listener_steps(self) -> List[PipelineStep]: - return [ - repair_step, - c14n_step, - envelope_validation_step, - payload_extraction_step, - thread_assignment_step, - 
xsd_validation_step, - deserialization_step, - routing_resolution_step, - ] - - def _build_system_steps(self) -> List[PipelineStep]: - """Shorter, fixed steps — no XSD/deserialization.""" - return [ - repair_step, - c14n_step, - envelope_validation_step, - payload_extraction_step, - thread_assignment_step, - # system-specific handler that emits , boot, etc. - self.system_handler_step, - ] - - # ------------------------------------------------------------------ # - # Registration (called from listener.py) - # ------------------------------------------------------------------ # - def register_listener(self, listener: Listener) -> None: - root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}" - - if root_tag in self.routing_table and not listener.broadcast: - raise ValueError(f"Root tag collision: {root_tag} already registered by {self.routing_table[root_tag][0].name}") - - # Build dedicated pipeline - steps = self._build_default_listener_steps() - # Inject listener-specific schema for xsd_validation_step - for step in steps: - if step.__name__ == "xsd_validation_step": - # We'll modify state.metadata in pipeline construction instead - pass - listener.pipeline = Pipeline(steps) - - # Insert into routing - self.routing_table.setdefault(root_tag, []).append(listener) - self.listeners[listener.name] = listener - - # ------------------------------------------------------------------ # - # Dispatcher — dumb fire-and-await - # ------------------------------------------------------------------ # - async def dispatcher(self, state: MessageState) -> None: - if not state.target_listeners: - return - - metadata = HandlerMetadata( - thread_id=state.thread_id or "", - from_id=state.from_id or "unknown", - own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None, - is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False, - ) - - if len(state.target_listeners) == 1: - listener = 
state.target_listeners[0] - await self._process_single_handler(state, listener, metadata) - else: - # Broadcast — fire all in parallel, process responses as they complete - tasks = [ - self._process_single_handler(state, listener, metadata) - for listener in state.target_listeners - ] - for future in asyncio.as_completed(tasks): - await future - - async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None: - try: - response_bytes = await listener.handler(state.payload, metadata) - - if response_bytes is None or not isinstance(response_bytes, bytes): - response_bytes = b"Handler failed to return valid bytes — missing return or wrong type" - - payloads = await self._multi_payload_extract(response_bytes) - - for payload_bytes in payloads: - new_state = MessageState( - raw_bytes=payload_bytes, - thread_id=state.thread_id, - from_id=listener.name, - ) - # Route the new payload through normal pipelines - root_tag = self._derive_root_tag(payload_bytes) - targets = self.routing_table.get(root_tag) - if targets: - new_state.target_listeners = targets - await targets[0].pipeline.process(new_state) - else: - await self.system_pipeline.process(new_state) - - except Exception as exc: # pylint: disable=broad-except - error_state = MessageState( - raw_bytes=b"Handler crashed", - thread_id=state.thread_id, - from_id=listener.name, - error=f"Handler {listener.name} crashed: {exc}", - ) - await self.system_pipeline.process(error_state) - - # ------------------------------------------------------------------ # - # Helper methods - # ------------------------------------------------------------------ # - async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]: - # Same logic as before — dummy wrap, repair, extract all root elements - # (implementation can be moved to a shared util later) - # For now, placeholder — we'll flesh this out in response_processing.py - return [raw_bytes] # temporary — will be full 
extraction - - def _derive_root_tag(self, payload_bytes: bytes) -> str: - # Quick parse to get root tag — used only for routing extracted payloads - try: - tree = etree.fromstring(payload_bytes) - tag = tree.tag - if tag.startswith("{"): - return tag.split("}", 1)[1] # strip namespace - return tag - except Exception: - return "" - - async def system_handler_step(self, state: MessageState) -> MessageState: - # Emit or boot message — placeholder for now - state.error = state.error or "Unhandled by any listener" - return state \ No newline at end of file diff --git a/agentserver/message_bus/message_state.py b/agentserver/message_bus/message_state.py index bc50d5f..df64e6f 100644 --- a/agentserver/message_bus/message_state.py +++ b/agentserver/message_bus/message_state.py @@ -1,6 +1,12 @@ +from __future__ import annotations + from dataclasses import dataclass, field -from lxml.etree import Element -from typing import Any +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from lxml.etree import _Element as Element +else: + Element = Any # Runtime: don't need the actual type """ default_listener_steps = [ @@ -33,6 +39,7 @@ class MessageState: thread_id: str | None = None from_id: str | None = None + to_id: str | None = None # Target listener name for routing target_listeners: list['Listener'] | None = None # Forward reference diff --git a/agentserver/message_bus/steps/deserialization.py b/agentserver/message_bus/steps/deserialization.py index 4ab9753..835b568 100644 --- a/agentserver/message_bus/steps/deserialization.py +++ b/agentserver/message_bus/steps/deserialization.py @@ -1,29 +1,29 @@ """ deserialization.py — Convert validated payload_tree into typed dataclass instance. -After xsd_validation_step confirms the payload conforms to the listener's contract, -this step uses the xmlable library to deserialize the lxml Element into the -registered @xmlify dataclass. - -The resulting instance is placed in state.payload and handed to the handler. 
+After xsd_validation_step confirms the payload conforms to the contract, +this step uses our customized xmlable routines to deserialize the lxml Element +directly in memory — no temporary files needed. Part of AgentServer v2.1 message pump. """ -from xmlable import from_xml # from the xmlable library +from lxml.etree import _Element from agentserver.message_bus.message_state import MessageState +# Import the customized parse_element from your forked xmlable +from third_party.xmlable import parse_element # adjust path if needed + async def deserialization_step(state: MessageState) -> MessageState: """ - Deserialize the validated payload_tree into the listener's dataclass. + Deserialize the validated payload_tree into the listener's @xmlify dataclass. Requires: - - state.payload_tree valid against listener XSD - - state.metadata["payload_class"] set to the target dataclass (set at registration) + - state.payload_tree: validated lxml Element + - state.metadata["payload_class"]: the target dataclass - On success: state.payload = dataclass instance - On failure: state.error set with clear message + Uses the custom parse_element routine for direct in-memory deserialization. 
""" if state.payload_tree is None: state.error = "deserialization_step: no payload_tree (previous step failed)" @@ -35,8 +35,8 @@ async def deserialization_step(state: MessageState) -> MessageState: return state try: - # xmlable.from_xml handles namespace-aware deserialization - instance = from_xml(payload_class, state.payload_tree) + # Direct in-memory deserialization — fast and clean + instance = parse_element(payload_class, state.payload_tree) state.payload = instance except Exception as exc: # pylint: disable=broad-except diff --git a/agentserver/message_bus/steps/payload_extraction.py b/agentserver/message_bus/steps/payload_extraction.py index 27a4c2c..fba95c2 100644 --- a/agentserver/message_bus/steps/payload_extraction.py +++ b/agentserver/message_bus/steps/payload_extraction.py @@ -2,8 +2,7 @@ payload_extraction.py — Extract the inner payload from the validated envelope. After envelope_validation_step confirms a correct outer envelope, -this step removes the envelope elements (, , optional , etc.) -and isolates the single child element that is the actual payload. +this step extracts metadata from and isolates the single payload element. The payload is expected to be exactly one root element (the capability-specific XML). If zero or multiple payload roots are found, we set a clear error — this protects @@ -17,19 +16,25 @@ from agentserver.message_bus.message_state import MessageState # Envelope namespace for easy reference _ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1" -_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message" +_MESSAGE_TAG = f"{{{_ENVELOPE_NS}}}message" +_META_TAG = f"{{{_ENVELOPE_NS}}}meta" +_FROM_TAG = f"{{{_ENVELOPE_NS}}}from" +_TO_TAG = f"{{{_ENVELOPE_NS}}}to" +_THREAD_TAG = f"{{{_ENVELOPE_NS}}}thread" async def payload_extraction_step(state: MessageState) -> MessageState: """ Extract the single payload element from the validated envelope. 
- Expected structure: + Expected structure (per envelope.xsd): - uuid - sender - - ← this is the one we want + + sender + receiver + uuid + + ← this is what we extract ... @@ -41,24 +46,42 @@ async def payload_extraction_step(state: MessageState) -> MessageState: state.error = "payload_extraction_step: no envelope_tree (previous step failed)" return state - # Basic sanity — root must be in correct namespace (already checked by schema, - # but we double-check for defence in depth) + # Basic sanity — root must be in correct namespace if state.envelope_tree.tag != _MESSAGE_TAG: - state.error = f"payload_extraction_step: root tag is not in envelope namespace" + state.error = "payload_extraction_step: root tag is not in envelope namespace" return state - # Find all direct children that are not envelope control elements - # Envelope control elements are: thread, from, to (optional) + # Find block and extract provenance + meta_elem = state.envelope_tree.find(_META_TAG) + if meta_elem is None: + state.error = "payload_extraction_step: missing block in envelope" + return state + + # Extract from_id (required) + from_elem = meta_elem.find(_FROM_TAG) + if from_elem is not None and from_elem.text: + state.from_id = from_elem.text.strip() + else: + state.error = "payload_extraction_step: missing in " + return state + + # Extract thread_id (required) + thread_elem = meta_elem.find(_THREAD_TAG) + if thread_elem is not None and thread_elem.text: + state.thread_id = thread_elem.text.strip() + else: + state.error = "payload_extraction_step: missing in " + return state + + # Optional: extract for direct routing + to_elem = meta_elem.find(_TO_TAG) + if to_elem is not None and to_elem.text: + state.to_id = to_elem.text.strip() + + # Find all direct children that are NOT — those are payload candidates payload_candidates = [ - child - for child in state.envelope_tree - if not ( - child.tag in { - f"{{{ _ENVELOPE_NS }}}thread", - f"{{{ _ENVELOPE_NS }}}from", - f"{{{ _ENVELOPE_NS }}}to", - 
} - ) + child for child in state.envelope_tree + if child.tag != _META_TAG ] if len(payload_candidates) == 0: @@ -73,19 +96,6 @@ async def payload_extraction_step(state: MessageState) -> MessageState: return state # Success — exactly one payload element - payload_element = payload_candidates[0] - - # Optional: capture provenance from envelope for later use - # (these will be trustworthy because envelope was validated) - thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread") - from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from") - - if thread_elem is not None and thread_elem.text: - state.thread_id = thread_elem.text.strip() - - if from_elem is not None and from_elem.text: - state.from_id = from_elem.text.strip() - - state.payload_tree = payload_element + state.payload_tree = payload_candidates[0] return state \ No newline at end of file diff --git a/agentserver/message_bus/steps/routing_resolution.py b/agentserver/message_bus/steps/routing_resolution.py index d38b7b8..b70ae32 100644 --- a/agentserver/message_bus/steps/routing_resolution.py +++ b/agentserver/message_bus/steps/routing_resolution.py @@ -1,50 +1,70 @@ """ routing_resolution.py — Resolve routing based on derived root tag. -This is the final preparation step before dispatch. -It computes the root tag from the deserialized payload and looks it up in the -global routing table (root_tag → list[Listener]). +This step computes the root tag from the deserialized payload and looks it up +in a routing table (root_tag → list[Listener]). -On success: state.target_listeners is set -On failure: state.error is set → message falls to system pipeline for +NOTE: The StreamPump has routing built-in via _route_step(). This standalone +step is provided for custom pipeline configurations or testing. + +Usage: + routing_step = make_routing_step(routing_table) + state = await routing_step(state) Part of AgentServer v2.1 message pump. 
""" +from __future__ import annotations + +from typing import Dict, List, Callable, Awaitable, TYPE_CHECKING + from agentserver.message_bus.message_state import MessageState -from agentserver.message_bus.bus import MessageBus + +if TYPE_CHECKING: + from agentserver.message_bus.stream_pump import Listener -async def routing_resolution_step(state: MessageState) -> MessageState: +def make_routing_step( + routing_table: Dict[str, List["Listener"]] +) -> Callable[[MessageState], Awaitable[MessageState]]: """ - Resolve which listener(s) should handle this payload. + Factory: create a routing step with a specific routing table. - Root tag = f"{from_id.lower()}.{payload_class_name.lower()}" - (from_id is trustworthy — injected by pump) - - Supports: - - Normal unique routing (one listener) - - Broadcast (multiple listeners if broadcast: true and same root tag) - - If no match → error, falls to system pipeline. + The routing table maps root tags to lists of listeners: + {"agent.payload": [listener1, listener2], ...} """ - if state.payload is None: - state.error = "routing_resolution_step: no deserialized payload (previous step failed)" + + async def routing_resolution_step(state: MessageState) -> MessageState: + """ + Resolve which listener(s) should handle this payload. + + Root tag = f"{from_id.lower()}.{payload_class_name.lower()}" + + Supports: + - Normal unique routing (one listener) + - Broadcast (multiple listeners if same root tag) + + If no match → error, falls to system pipeline. 
+ """ + if state.payload is None: + state.error = "routing_resolution_step: no deserialized payload" + return state + + if state.to_id is None: + state.error = "routing_resolution_step: missing to_id" + return state + + payload_class_name = type(state.payload).__name__.lower() + root_tag = f"{state.to_id.lower()}.{payload_class_name}" + + targets = routing_table.get(root_tag) + + if not targets: + state.error = f"routing_resolution_step: unknown root tag '{root_tag}'" + return state + + state.target_listeners = targets return state - if state.from_id is None: - state.error = "routing_resolution_step: missing from_id (provenance error)" - return state - - payload_class_name = type(state.payload).__name__.lower() - root_tag = f"{state.from_id.lower()}.{payload_class_name}" - - bus = MessageBus.get_instance() - targets = bus.routing_table.get(root_tag) - - if not targets: - state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'" - return state - - state.target_listeners = targets - return state \ No newline at end of file + routing_resolution_step.__name__ = "routing_resolution_step" + return routing_resolution_step diff --git a/agentserver/message_bus/stream_pump.py b/agentserver/message_bus/stream_pump.py new file mode 100644 index 0000000..eccbef7 --- /dev/null +++ b/agentserver/message_bus/stream_pump.py @@ -0,0 +1,592 @@ +""" +pump_aiostream.py — Stream-Based Message Pump using aiostream + +This implementation treats the entire message flow as composable streams. +Fan-out (multi-payload, broadcast) is handled naturally via flatmap. + +Key insight: Each step is a stream transformer, not a 1:1 function. +The pipeline is just a composition of stream operators. 
+ +Dependencies: + pip install aiostream +""" + +from __future__ import annotations + +import asyncio +import importlib +from dataclasses import dataclass, field +from pathlib import Path +from typing import AsyncIterable, Callable, List, Dict, Any, Optional + +import yaml +from lxml import etree +from aiostream import stream, pipe, operator + +# Import existing step implementations (we'll wrap them) +from agentserver.message_bus.steps.repair import repair_step +from agentserver.message_bus.steps.c14n import c14n_step +from agentserver.message_bus.steps.envelope_validation import envelope_validation_step +from agentserver.message_bus.steps.payload_extraction import payload_extraction_step +from agentserver.message_bus.steps.thread_assignment import thread_assignment_step +from agentserver.message_bus.message_state import MessageState, HandlerMetadata + + +# ============================================================================ +# Configuration (same as before) +# ============================================================================ + +@dataclass +class ListenerConfig: + name: str + payload_class_path: str + handler_path: str + description: str + is_agent: bool = False + peers: List[str] = field(default_factory=list) + broadcast: bool = False + payload_class: type = field(default=None, repr=False) + handler: Callable = field(default=None, repr=False) + + +@dataclass +class OrganismConfig: + name: str + identity_path: str = "" + port: int = 8765 + thread_scheduling: str = "breadth-first" + listeners: List[ListenerConfig] = field(default_factory=list) + + # Concurrency tuning + max_concurrent_pipelines: int = 50 # Total concurrent messages in pipeline + max_concurrent_handlers: int = 20 # Concurrent handler invocations + max_concurrent_per_agent: int = 5 # Per-agent rate limit + + +@dataclass +class Listener: + name: str + payload_class: type + handler: Callable + description: str + is_agent: bool = False + peers: List[str] = field(default_factory=list) + 
broadcast: bool = False + schema: etree.XMLSchema = field(default=None, repr=False) + root_tag: str = "" + + +# ============================================================================ +# Stream-Based Pipeline Steps +# ============================================================================ + +def wrap_step(step_fn: Callable) -> Callable: + """ + Wrap an existing async step function for use with pipe.map. + + Existing steps: async def step(state) -> state + We keep them as-is since pipe.map handles the iteration. + """ + return step_fn + + +async def extract_payloads(state: MessageState) -> AsyncIterable[MessageState]: + """ + Fan-out step: Extract 1..N payloads from handler response. + + This is used with pipe.flatmap — yields multiple states for each input. + """ + if state.raw_bytes is None: + yield state + return + + try: + # Wrap in dummy to handle multiple roots + wrapped = b"" + state.raw_bytes + b"" + tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True)) + + children = list(tree) + if not children: + yield state + return + + for child in children: + payload_bytes = etree.tostring(child) + yield MessageState( + raw_bytes=payload_bytes, + thread_id=state.thread_id, + from_id=state.from_id, + metadata=state.metadata.copy(), + ) + + except Exception: + # On parse failure, pass through as-is + yield state + + +def make_xsd_validation(schema: etree.XMLSchema) -> Callable: + """Factory for XSD validation step with schema baked in.""" + async def validate(state: MessageState) -> MessageState: + if state.payload_tree is None or state.error: + return state + try: + schema.assertValid(state.payload_tree) + except etree.DocumentInvalid as e: + state.error = f"XSD validation failed: {e}" + return state + return validate + + +def make_deserialization(payload_class: type) -> Callable: + """Factory for deserialization step with class baked in.""" + from third_party.xmlable import parse_element + + async def deserialize(state: MessageState) -> 
MessageState: + if state.payload_tree is None or state.error: + return state + try: + state.payload = parse_element(payload_class, state.payload_tree) + except Exception as e: + state.error = f"Deserialization failed: {e}" + return state + return deserialize + + +# ============================================================================ +# The Stream-Based Pump +# ============================================================================ + +class StreamPump: + """ + Message pump built on aiostream. + + The entire flow is a single composable stream pipeline. + Fan-out is natural via flatmap. Concurrency is controlled via task_limit. + """ + + def __init__(self, config: OrganismConfig): + self.config = config + + # Message queue feeds the stream + self.queue: asyncio.Queue[MessageState] = asyncio.Queue() + + # Routing table + self.routing_table: Dict[str, List[Listener]] = {} + self.listeners: Dict[str, Listener] = {} + + # Per-agent semaphores for rate limiting + self.agent_semaphores: Dict[str, asyncio.Semaphore] = {} + + # Shutdown control + self._running = False + + # ------------------------------------------------------------------ + # Registration + # ------------------------------------------------------------------ + + def register_listener(self, lc: ListenerConfig) -> Listener: + root_tag = f"{lc.name.lower()}.{lc.payload_class.__name__.lower()}" + + listener = Listener( + name=lc.name, + payload_class=lc.payload_class, + handler=lc.handler, + description=lc.description, + is_agent=lc.is_agent, + peers=lc.peers, + broadcast=lc.broadcast, + schema=self._generate_schema(lc.payload_class), + root_tag=root_tag, + ) + + if lc.is_agent: + self.agent_semaphores[lc.name] = asyncio.Semaphore( + self.config.max_concurrent_per_agent + ) + + self.routing_table.setdefault(root_tag, []).append(listener) + self.listeners[lc.name] = listener + return listener + + def register_all(self) -> None: + for lc in self.config.listeners: + self.register_listener(lc) + + def 
_generate_schema(self, payload_class: type) -> etree.XMLSchema: + """Generate XSD schema from xmlified payload class.""" + if hasattr(payload_class, 'xsd'): + xsd_tree = payload_class.xsd() + return etree.XMLSchema(xsd_tree) + # Fallback for non-xmlified classes (e.g., in tests) + permissive = '' + return etree.XMLSchema(etree.fromstring(permissive.encode())) + + # ------------------------------------------------------------------ + # Stream Source + # ------------------------------------------------------------------ + + async def _queue_source(self) -> AsyncIterable[MessageState]: + """Async generator that yields messages from the queue.""" + while self._running: + try: + state = await asyncio.wait_for(self.queue.get(), timeout=0.5) + yield state + self.queue.task_done() + except asyncio.TimeoutError: + continue + + # ------------------------------------------------------------------ + # Pipeline Steps (as stream operators) + # ------------------------------------------------------------------ + + async def _route_step(self, state: MessageState) -> MessageState: + """Determine target listeners based on to_id.class format.""" + if state.error or state.payload is None: + return state + + payload_class_name = type(state.payload).__name__.lower() + to_id = (state.to_id or "").lower() + root_tag = f"{to_id}.{payload_class_name}" if to_id else payload_class_name + + targets = self.routing_table.get(root_tag) + if targets: + state.target_listeners = targets + else: + state.error = f"No listener for: {root_tag}" + + return state + + async def _dispatch_to_handlers(self, state: MessageState) -> AsyncIterable[MessageState]: + """ + Fan-out step: Dispatch to handler(s) and yield response states. + + For broadcast, yields one response per listener. + Each response becomes a new message in the stream. 
+ """ + if state.error or not state.target_listeners: + # Pass through errors/unroutable for downstream handling + yield state + return + + for listener in state.target_listeners: + try: + # Rate limiting for agents + semaphore = self.agent_semaphores.get(listener.name) + if semaphore: + await semaphore.acquire() + + try: + metadata = HandlerMetadata( + thread_id=state.thread_id or "", + from_id=state.from_id or "", + own_name=listener.name if listener.is_agent else None, + ) + + response_bytes = await listener.handler(state.payload, metadata) + + if not isinstance(response_bytes, bytes): + response_bytes = b"Handler returned invalid type" + + # Yield response — will be processed by next iteration + yield MessageState( + raw_bytes=response_bytes, + thread_id=state.thread_id, + from_id=listener.name, + ) + + finally: + if semaphore: + semaphore.release() + + except Exception as exc: + yield MessageState( + raw_bytes=f"Handler {listener.name} crashed: {exc}".encode(), + thread_id=state.thread_id, + from_id=listener.name, + error=str(exc), + ) + + async def _reinject_responses(self, state: MessageState) -> None: + """Push handler responses back into the queue for next iteration.""" + await self.queue.put(state) + + # ------------------------------------------------------------------ + # Build the Pipeline + # ------------------------------------------------------------------ + + def build_pipeline(self, source: AsyncIterable[MessageState]): + """ + Construct the full processing pipeline. + + This is where you configure the flow. 
Modify this method to: + - Add/remove steps + - Change concurrency limits + - Insert logging/metrics + - Add filtering + """ + + # The pipeline is a composition of stream operators + pipeline = ( + stream.iterate(source) + + # ============================================================ + # STAGE 1: Envelope Processing (1:1 transforms) + # ============================================================ + | pipe.map(repair_step) + | pipe.map(c14n_step) + | pipe.map(envelope_validation_step) + | pipe.map(payload_extraction_step) + | pipe.map(thread_assignment_step) + + # ============================================================ + # STAGE 2: Fan-out — Extract Multiple Payloads (1:N) + # ============================================================ + # Handler responses may contain multiple payloads. + # Each becomes a separate message in the stream. + | pipe.flatmap(extract_payloads) + + # ============================================================ + # STAGE 3: Per-Payload Validation (1:1 transforms) + # ============================================================ + # Note: In a real implementation, you'd route to listener-specific + # validation here. For now, we use a simplified approach. + | pipe.map(self._validate_and_deserialize) + + # ============================================================ + # STAGE 4: Routing (1:1) + # ============================================================ + | pipe.map(self._route_step) + + # ============================================================ + # STAGE 5: Filter Errors + # ============================================================ + # Errors go to a separate handler (could also be a branch) + | pipe.map(self._handle_errors) + | pipe.filter(lambda s: s.error is None and s.target_listeners) + + # ============================================================ + # STAGE 6: Fan-out — Dispatch to Handlers (1:N for broadcast) + # ============================================================ + # This is where handlers are invoked. 
Broadcast = multiple yields. + # task_limit controls concurrent handler invocations. + | pipe.flatmap( + self._dispatch_to_handlers, + task_limit=self.config.max_concurrent_handlers + ) + + # ============================================================ + # STAGE 7: Re-inject Responses + # ============================================================ + # Handler responses go back into the queue for next iteration. + # The cycle continues until no more messages. + | pipe.action(self._reinject_responses) + ) + + return pipeline + + async def _validate_and_deserialize(self, state: MessageState) -> MessageState: + """ + Combined validation + deserialization. + + Uses to_id + payload tag to find the right listener and schema. + """ + if state.error or state.payload_tree is None: + return state + + # Build lookup key: to_id.payload_tag (matching routing table format) + payload_tag = state.payload_tree.tag + if payload_tag.startswith("{"): + payload_tag = payload_tag.split("}", 1)[1] + + to_id = (state.to_id or "").lower() + lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower() + + listeners = self.routing_table.get(lookup_key, []) + if not listeners: + state.error = f"No listener for: {lookup_key}" + return state + + listener = listeners[0] + + # Validate against listener's schema + try: + listener.schema.assertValid(state.payload_tree) + except etree.DocumentInvalid as e: + state.error = f"XSD validation failed: {e}" + return state + + # Deserialize + try: + from third_party.xmlable import parse_element + state.payload = parse_element(listener.payload_class, state.payload_tree) + except Exception as e: + state.error = f"Deserialization failed: {e}" + + return state + + async def _handle_errors(self, state: MessageState) -> MessageState: + """Log errors (could also emit messages).""" + if state.error: + print(f"[ERROR] {state.thread_id}: {state.error}") + # Could emit to a specific listener here + return state + + # 
------------------------------------------------------------------ + # Run the Pump + # ------------------------------------------------------------------ + + async def run(self) -> None: + """ + Main entry point — run the stream pipeline. + + The pipeline pulls from the queue, processes messages, + and re-injects handler responses. Continues until shutdown. + """ + self._running = True + + pipeline = self.build_pipeline(self._queue_source()) + + try: + async with pipeline.stream() as streamer: + async for _ in streamer: + # The pipeline drives itself via re-injection. + # We just need to consume the stream. + pass + except asyncio.CancelledError: + pass + finally: + self._running = False + + # ------------------------------------------------------------------ + # External API + # ------------------------------------------------------------------ + + async def inject(self, raw_bytes: bytes, thread_id: str, from_id: str) -> None: + """Inject a message to start processing.""" + state = MessageState( + raw_bytes=raw_bytes, + thread_id=thread_id, + from_id=from_id, + ) + await self.queue.put(state) + + async def shutdown(self) -> None: + """Graceful shutdown — wait for queue to drain.""" + self._running = False + await self.queue.join() + + +# ============================================================================ +# Config Loader (same as before) +# ============================================================================ + +class ConfigLoader: + @classmethod + def load(cls, path: str | Path) -> OrganismConfig: + with open(Path(path)) as f: + raw = yaml.safe_load(f) + return cls._parse(raw) + + @classmethod + def _parse(cls, raw: dict) -> OrganismConfig: + org = raw.get("organism", {}) + config = OrganismConfig( + name=org.get("name", "unnamed"), + identity_path=org.get("identity", ""), + port=org.get("port", 8765), + thread_scheduling=raw.get("thread_scheduling", "breadth-first"), + max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50), + 
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20), + max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5), + ) + + for entry in raw.get("listeners", []): + lc = cls._parse_listener(entry) + cls._resolve_imports(lc) + config.listeners.append(lc) + + return config + + @classmethod + def _parse_listener(cls, raw: dict) -> ListenerConfig: + return ListenerConfig( + name=raw["name"], + payload_class_path=raw["payload_class"], + handler_path=raw["handler"], + description=raw["description"], + is_agent=raw.get("agent", False), + peers=raw.get("peers", []), + broadcast=raw.get("broadcast", False), + ) + + @classmethod + def _resolve_imports(cls, lc: ListenerConfig) -> None: + mod, cls_name = lc.payload_class_path.rsplit(".", 1) + lc.payload_class = getattr(importlib.import_module(mod), cls_name) + + mod, fn_name = lc.handler_path.rsplit(".", 1) + lc.handler = getattr(importlib.import_module(mod), fn_name) + + +# ============================================================================ +# Bootstrap +# ============================================================================ + +async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump: + """Load config and create pump.""" + config = ConfigLoader.load(config_path) + print(f"Organism: {config.name}") + print(f"Listeners: {len(config.listeners)}") + + pump = StreamPump(config) + pump.register_all() + + print(f"Routing: {list(pump.routing_table.keys())}") + return pump + + +# ============================================================================ +# Example: Customizing the Pipeline +# ============================================================================ + +""" +The beauty of aiostream: the pipeline is just a composition. +You can easily insert, remove, or reorder stages. 
+ +# Add logging between stages: +| pipe.action(lambda s: print(f"After repair: {s.thread_id}")) + +# Add throttling: +| pipe.map(some_step, task_limit=5) + +# Branch errors to a separate stream: +errors, valid = stream.partition(source, lambda s: s.error is not None) + +# Merge multiple sources: +combined = stream.merge(queue_source, oob_source, external_api_source) + +# Add timeout per message: +| pipe.timeout(30.0) # 30 second timeout per item + +# Rate limit the whole pipeline: +| pipe.spaceout(0.1) # 100ms between items +""" + + +# ============================================================================ +# Comparison: Old vs New +# ============================================================================ + +""" +OLD (bus.py): + for payload in payloads: + await pipeline.process(payload) # Sequential, recursive + +NEW (aiostream): + | pipe.flatmap(extract_payloads) # Fan-out, parallel + | pipe.flatmap(dispatch, task_limit=20) # Concurrent handlers + +The key difference: +- Old: 3 tool calls = 3 sequential awaits, each blocking until complete +- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit +""" diff --git a/agentserver/schema/envelope.xsd b/agentserver/schema/envelope.xsd index 5d69c61..846102e 100644 --- a/agentserver/schema/envelope.xsd +++ b/agentserver/schema/envelope.xsd @@ -20,8 +20,8 @@ - - + + diff --git a/config/organism.yaml b/config/organism.yaml new file mode 100644 index 0000000..33b5e29 --- /dev/null +++ b/config/organism.yaml @@ -0,0 +1,24 @@ +# organism.yaml — Sample configuration for testing the message pump +# +# This defines a simple "hello world" organism with one listener. 
"""
hello.py — Hello World handler for testing the message pump.

This module provides:
- Greeting: payload class (what the handler receives)
- GreetingResponse: response payload (what the handler returns)
- handle_greeting: async handler function

Usage in organism.yaml:
    listeners:
      - name: greeter
        payload_class: handlers.hello.Greeting
        handler: handlers.hello.handle_greeting
        description: Responds with a greeting message
"""

from dataclasses import dataclass

from lxml import etree

from third_party.xmlable import xmlify
from agentserver.message_bus.message_state import HandlerMetadata


# Envelope namespace (must match schema/envelope.xsd)
ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"


@xmlify
@dataclass
class Greeting:
    """Incoming greeting request."""
    name: str


@xmlify
@dataclass
class GreetingResponse:
    """Outgoing greeting response."""
    message: str


def wrap_in_envelope(payload_bytes: bytes, from_id: str, to_id: str, thread_id: str) -> bytes:
    """
    Wrap a payload in a proper message envelope.

    NOTE(review): the original envelope markup was destroyed when this
    patch was mangled; the element names below (message/meta/from/to/
    thread) are reconstructed from the envelope tests — confirm against
    schema/envelope.xsd before relying on them.
    """
    payload_xml = payload_bytes.decode("utf-8")
    return f"""<message xmlns="{ENVELOPE_NS}">
  <meta>
    <from>{from_id}</from>
    <to>{to_id}</to>
    <thread>{thread_id}</thread>
  </meta>
  {payload_xml}
</message>""".encode("utf-8")


async def handle_greeting(payload: Greeting, metadata: HandlerMetadata) -> bytes:
    """
    Handle an incoming Greeting and respond with a GreetingResponse.

    Args:
        payload: The deserialized Greeting instance
        metadata: Contains thread_id, from_id, own_name

    Returns:
        XML bytes of the response envelope
    """
    # Create the response payload.
    response = GreetingResponse(message=f"Hello, {payload.name}!")

    # Serialize to XML.
    response_tree = response.xml_value("GreetingResponse")
    payload_bytes = etree.tostring(response_tree, encoding="utf-8")

    # Wrap in an envelope — respond back to whoever sent the greeting,
    # preserving thread continuity.
    return wrap_in_envelope(
        payload_bytes=payload_bytes,
        from_id=metadata.own_name or "greeter",
        to_id=metadata.from_id,
        thread_id=metadata.thread_id,
    )
["test_*.py"] +# Don't collect root __init__.py (has imports that break isolation) +norecursedirs = [".git", "__pycache__", "*.egg-info"] + [tool.setuptools.packages.find] -where = ["."] \ No newline at end of file +where = ["."] +include = ["agentserver*", "third_party*"] diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..b321683 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,48 @@ +""" +conftest.py — Shared pytest configuration and fixtures + +This file is automatically loaded by pytest. +""" + +import pytest +import sys +from pathlib import Path + +# Ensure the project root is in the path for imports +# BUT don't import the root package (it has heavy deps) +project_root = Path(__file__).parent.parent +if str(project_root) not in sys.path: + sys.path.insert(0, str(project_root)) + +# Tell pytest to ignore the root __init__.py +collect_ignore_glob = ["../__init__.py"] + + +# ============================================================================ +# Markers +# ============================================================================ + +def pytest_configure(config): + """Register custom markers.""" + config.addinivalue_line( + "markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')" + ) + config.addinivalue_line( + "markers", "integration: marks tests as integration tests" + ) + + +# ============================================================================ +# Fixtures available to all tests +# ============================================================================ + +@pytest.fixture +def sample_thread_id(): + """A valid UUID for testing.""" + return "550e8400-e29b-41d4-a716-446655440000" + + +@pytest.fixture +def sample_from_id(): + """A valid sender ID for testing.""" + return "calculator.add" diff --git a/tests/test_pipeline_steps.py b/tests/test_pipeline_steps.py new file mode 100644 index 
"""
test_pipeline_steps.py — Unit tests for individual pipeline steps.

Run with: pytest tests/test_pipeline_steps.py -v

Each step is tested in isolation with known inputs and expected outputs,
which is far easier to debug than exercising the full pipeline.

Install test dependencies:
    pip install -e ".[test]"
"""

import asyncio
from dataclasses import dataclass

import pytest
from lxml import etree

# Message state shared by every step.
from agentserver.message_bus.message_state import MessageState, HandlerMetadata

# Individual pipeline steps under test.
from agentserver.message_bus.steps.repair import repair_step
from agentserver.message_bus.steps.c14n import c14n_step
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step

# Optional dependency: aiostream.
try:
    import aiostream
    HAS_AIOSTREAM = True
except ImportError:
    HAS_AIOSTREAM = False

requires_aiostream = pytest.mark.skipif(
    not HAS_AIOSTREAM,
    reason="aiostream not installed (pip install aiostream)"
)

# Optional dependency: the stream pump itself.
try:
    from agentserver.message_bus.stream_pump import StreamPump, Listener
    from agentserver.message_bus.steps.routing_resolution import make_routing_step
    HAS_STREAM_PUMP = True
except ImportError:
    HAS_STREAM_PUMP = False

requires_stream_pump = pytest.mark.skipif(
    not HAS_STREAM_PUMP,
    reason="stream_pump dependencies not available"
)


# ============================================================================
# Test Fixtures
# ============================================================================
# NOTE(review): the XML literals in these fixtures were destroyed when this
# patch was mangled; they are reconstructed from the assertions that consume
# them — confirm against schema/envelope.xsd.

@pytest.fixture
def valid_envelope_bytes():
    """A well-formed message envelope matching envelope.xsd."""
    return b'''<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
  <meta>
    <from>calculator.add</from>
    <thread>550e8400-e29b-41d4-a716-446655440000</thread>
  </meta>
  <addpayload>
    <a>5</a>
    <b>3</b>
  </addpayload>
</message>'''


@pytest.fixture
def malformed_xml_bytes():
    """Malformed XML that lxml can partially recover."""
    return b'<root><unclosed>content</root>'


@pytest.fixture
def completely_broken_bytes():
    """Not XML at all."""
    return b'this is not xml at all { json: "maybe" }'


@pytest.fixture
def multi_payload_response():
    """Handler response with multiple payloads."""
    return b'''<search.result>42</search.result>
<calculator.add.addpayload>12</calculator.add.addpayload>
<thought>I should also check...</thought>'''


@pytest.fixture
def empty_state():
    """Fresh MessageState with no data."""
    return MessageState()


@pytest.fixture
def state_with_bytes(valid_envelope_bytes):
    """MessageState with raw_bytes populated."""
    return MessageState(raw_bytes=valid_envelope_bytes)


# ============================================================================
# repair_step Tests
# ============================================================================

class TestRepairStep:
    """Tests for the XML repair/recovery step."""

    @pytest.mark.asyncio
    async def test_valid_xml_passes_through(self, valid_envelope_bytes):
        """Valid XML should parse without error."""
        result = await repair_step(MessageState(raw_bytes=valid_envelope_bytes))

        assert result.error is None
        assert result.envelope_tree is not None
        assert result.envelope_tree.tag == "{https://xml-pipeline.org/ns/envelope/v1}message"

    @pytest.mark.asyncio
    async def test_malformed_xml_recovered(self, malformed_xml_bytes):
        """Malformed XML should be recovered if possible."""
        result = await repair_step(MessageState(raw_bytes=malformed_xml_bytes))

        # lxml recovery mode should produce *something*; whether an error
        # is also set depends on how badly the input is broken.
        assert result.envelope_tree is not None or result.error is not None

    @pytest.mark.asyncio
    async def test_no_bytes_sets_error(self, empty_state):
        """Missing raw_bytes should set an error."""
        result = await repair_step(empty_state)

        assert result.error is not None
        assert "no raw_bytes" in result.error

    @pytest.mark.asyncio
    async def test_clears_raw_bytes_after_parse(self, valid_envelope_bytes):
        """raw_bytes should be cleared after successful parse (memory optimization)."""
        result = await repair_step(MessageState(raw_bytes=valid_envelope_bytes))

        assert result.raw_bytes is None
        assert result.envelope_tree is not None


# ============================================================================
# c14n_step Tests
# ============================================================================

class TestC14nStep:
    """Tests for the canonicalization step."""

    @pytest.mark.asyncio
    async def test_normalizes_whitespace(self):
        """C14N should normalize whitespace."""
        # NOTE(review): literal reconstructed — original markup was lost.
        xml_with_whitespace = b'''<root>
    <child>   value   </child>
</root>'''

        state = await repair_step(MessageState(raw_bytes=xml_with_whitespace))
        result = await c14n_step(state)

        assert result.error is None
        assert result.envelope_tree is not None

    @pytest.mark.asyncio
    async def test_normalizes_attribute_order(self):
        """C14N should produce consistent attribute ordering."""
        # NOTE(review): literals reconstructed — original markup was lost.
        xml_a = b'<elem b="2" a="1"/>'
        xml_b = b'<elem a="1" b="2"/>'

        state_a = await repair_step(MessageState(raw_bytes=xml_a))
        state_a = await c14n_step(state_a)

        state_b = await repair_step(MessageState(raw_bytes=xml_b))
        state_b = await c14n_step(state_b)

        # Both should produce identical canonical form.
        c14n_a = etree.tostring(state_a.envelope_tree, method="c14n")
        c14n_b = etree.tostring(state_b.envelope_tree, method="c14n")
        assert c14n_a == c14n_b

    @pytest.mark.asyncio
    async def test_no_tree_sets_error(self, empty_state):
        """Missing envelope_tree should set error."""
        result = await c14n_step(empty_state)

        assert result.error is not None
        assert "no envelope_tree" in result.error
# ============================================================================
# payload_extraction_step Tests
# ============================================================================
# NOTE(review): every XML literal in this section was destroyed when the
# patch was mangled; literals are reconstructed from the assertions that
# consume them — confirm against schema/envelope.xsd.

class TestPayloadExtractionStep:
    """Tests for extracting payload from envelope."""

    @pytest.mark.asyncio
    async def test_extracts_payload_element(self, valid_envelope_bytes):
        """Should extract the payload element from envelope."""
        state = await repair_step(MessageState(raw_bytes=valid_envelope_bytes))
        state = await c14n_step(state)
        # Skip envelope validation for this test.
        result = await payload_extraction_step(state)

        assert result.error is None
        assert result.payload_tree is not None
        # Tag may include namespace prefix.
        assert "addpayload" in result.payload_tree.tag

    @pytest.mark.asyncio
    async def test_extracts_thread_id(self, valid_envelope_bytes):
        """Should extract thread ID from envelope."""
        state = await repair_step(MessageState(raw_bytes=valid_envelope_bytes))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.thread_id == "550e8400-e29b-41d4-a716-446655440000"

    @pytest.mark.asyncio
    async def test_extracts_from_id(self, valid_envelope_bytes):
        """Should extract sender ID from envelope."""
        state = await repair_step(MessageState(raw_bytes=valid_envelope_bytes))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.from_id == "calculator.add"

    @pytest.mark.asyncio
    async def test_multiple_payloads_error(self):
        """Multiple payload elements should error."""
        multi_payload = b'''<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
  <meta>
    <from>test</from>
    <thread>uuid-here</thread>
  </meta>
  <first>data</first>
  <second>more data</second>
</message>'''

        state = await repair_step(MessageState(raw_bytes=multi_payload))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.error is not None
        assert "multiple payload" in result.error.lower()

    @pytest.mark.asyncio
    async def test_no_payload_error(self):
        """Missing payload element should error."""
        no_payload = b'''<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
  <meta>
    <from>test</from>
    <thread>uuid-here</thread>
  </meta>
</message>'''

        state = await repair_step(MessageState(raw_bytes=no_payload))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.error is not None
        assert "no payload" in result.error.lower()

    @pytest.mark.asyncio
    async def test_missing_meta_error(self):
        """Missing meta block should error."""
        no_meta = b'''<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
  <payloadelem>data</payloadelem>
</message>'''

        state = await repair_step(MessageState(raw_bytes=no_meta))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.error is not None
        assert "meta" in result.error.lower()

    @pytest.mark.asyncio
    async def test_missing_from_error(self):
        """Missing sender element in meta should error."""
        no_from = b'''<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
  <meta>
    <thread>uuid-here</thread>
  </meta>
  <payloadelem>data</payloadelem>
</message>'''

        state = await repair_step(MessageState(raw_bytes=no_from))
        state = await c14n_step(state)
        result = await payload_extraction_step(state)

        assert result.error is not None
        assert "from" in result.error.lower()


# ============================================================================
# thread_assignment_step Tests
# ============================================================================

class TestThreadAssignmentStep:
    """Tests for thread UUID assignment."""

    @pytest.mark.asyncio
    async def test_valid_uuid_preserved(self):
        """Valid UUID should be preserved."""
        valid_uuid = "550e8400-e29b-41d4-a716-446655440000"
        result = await thread_assignment_step(MessageState(thread_id=valid_uuid))

        assert result.thread_id == valid_uuid

    @pytest.mark.asyncio
    async def test_missing_uuid_generated(self, empty_state):
        """Missing UUID should generate a new one."""
        result = await thread_assignment_step(empty_state)

        assert result.thread_id is not None
        assert len(result.thread_id) == 36  # canonical UUID string length

    @pytest.mark.asyncio
    async def test_invalid_uuid_replaced(self):
        """Invalid UUID should be replaced with a new one."""
        result = await thread_assignment_step(MessageState(thread_id="not-a-valid-uuid"))

        assert result.thread_id != "not-a-valid-uuid"
        assert len(result.thread_id) == 36

    @pytest.mark.asyncio
    async def test_replacement_logged_in_metadata(self):
        """Replaced UUIDs should be logged in metadata."""
        result = await thread_assignment_step(MessageState(thread_id="bad-uuid"))

        diagnostics = result.metadata.get("diagnostics", [])
        assert len(diagnostics) > 0
        assert "bad-uuid" in diagnostics[0]


# ============================================================================
# Multi-Payload Extraction Tests (standalone, no aiostream required)
# ============================================================================

class TestPayloadExtractionLogic:
    """Test the core payload extraction logic without aiostream."""

    def test_extract_single_payload(self):
        """Single root element should extract cleanly."""
        raw = b"<result>42</result>"
        wrapped = b"<responses>" + raw + b"</responses>"
        tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))

        children = list(tree)
        assert len(children) == 1
        assert children[0].tag == "result"
        assert children[0].text == "42"

    def test_extract_multiple_payloads(self, multi_payload_response):
        """Multiple root elements should all be extracted."""
        wrapped = b"<responses>" + multi_payload_response + b"</responses>"
        tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))

        children = list(tree)
        assert len(children) == 3

        tags = [c.tag for c in children]
        assert "search.result" in tags
        assert "calculator.add.addpayload" in tags
        assert "thought" in tags

    def test_extract_preserves_content(self):
        """Extracted payloads should preserve their content."""
        raw = b"<item>value</item>"
        wrapped = b"<responses>" + raw + b"</responses>"
        tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))

        children = list(tree)
        assert len(children) == 1

        # Re-serialize and check.
        extracted = etree.tostring(children[0])
        assert b"value" in extracted

    def test_empty_response_no_crash(self):
        """Empty response should not crash."""
        wrapped = b"<responses></responses>"
        tree = etree.fromstring(wrapped)

        children = list(tree)
        assert len(children) == 0

    def test_malformed_response_recovers(self):
        """Malformed XML should be recovered if possible."""
        raw = b"<broken>text"
        wrapped = b"<responses>" + raw + b"</responses>"

        # With recovery parser: should get something, exact result
        # depends on lxml recovery.
        tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
        assert tree is not None


# ============================================================================
# Multi-Payload Extraction Tests (from stream_pump.py)
# ============================================================================

@requires_aiostream
class TestMultiPayloadExtraction:
    """Tests for the fan-out payload extraction."""

    @pytest.mark.asyncio
    async def test_single_payload_yields_one(self):
        """Single payload should yield one state."""
        from agentserver.message_bus.stream_pump import extract_payloads

        state = MessageState(
            raw_bytes=b"<result>42</result>",
            thread_id="test-thread",
            from_id="test-sender",
        )

        results = [s async for s in extract_payloads(state)]

        assert len(results) == 1
        assert b"result" in results[0].raw_bytes
        assert results[0].thread_id == "test-thread"
        assert results[0].from_id == "test-sender"

    @pytest.mark.asyncio
    async def test_multiple_payloads_yields_many(self, multi_payload_response):
        """Multiple payloads should yield multiple states."""
        from agentserver.message_bus.stream_pump import extract_payloads

        state = MessageState(
            raw_bytes=multi_payload_response,
            thread_id="test-thread",
            from_id="agent",
        )

        results = [s async for s in extract_payloads(state)]

        assert len(results) == 3
        # Each result should carry the same thread_id and from_id.
        for r in results:
            assert r.thread_id == "test-thread"
            assert r.from_id == "agent"

    @pytest.mark.asyncio
    async def test_empty_response_yields_original(self):
        """Empty response should yield original state."""
        from agentserver.message_bus.stream_pump import extract_payloads

        state = MessageState(
            raw_bytes=b"",
            thread_id="test",
            from_id="test",
        )

        results = [s async for s in extract_payloads(state)]

        # Should yield something (original or empty handling).
        assert len(results) >= 1

    @pytest.mark.asyncio
    async def test_preserves_metadata(self):
        """Extracted payloads should preserve metadata."""
        from agentserver.message_bus.stream_pump import extract_payloads

        state = MessageState(
            raw_bytes=b"<x/>",
            thread_id="test",
            from_id="test",
            metadata={"custom": "value"},
        )

        results = [s async for s in extract_payloads(state)]

        for r in results:
            assert r.metadata.get("custom") == "value"


# ============================================================================
# Step Factory Tests
# ============================================================================

@requires_stream_pump
class TestStepFactories:
    """Tests for the step factory functions."""

    @pytest.mark.asyncio
    async def test_xsd_validation_direct(self):
        """XSD validation via lxml schema."""
        # Minimal schema: a single integer-typed element.
        # NOTE(review): schema literal reconstructed — original was lost.
        xsd_str = '''<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
  <xs:element name="result" type="xs:integer"/>
</xs:schema>'''
        schema = etree.XMLSchema(etree.fromstring(xsd_str.encode()))

        # Valid payload.
        valid_xml = etree.fromstring(b"<result>42</result>")
        assert schema.validate(valid_xml)

        # Invalid payload.
        invalid_xml = etree.fromstring(b"<result>not-an-int</result>")
        assert not schema.validate(invalid_xml)

    @pytest.mark.asyncio
    async def test_routing_factory(self):
        """Routing step should use injected routing table."""
        from agentserver.message_bus.steps.routing_resolution import make_routing_step
        from agentserver.message_bus.stream_pump import Listener

        # Create mock listener.
        mock_listener = Listener(
            name="calculator.add",
            payload_class=type("AddPayload", (), {}),
            handler=lambda x, m: b"",
            description="test",
        )

        routing_table = {
            "calculator.add.addpayload": [mock_listener]
        }

        step = make_routing_step(routing_table)

        # Create a mock payload instance.
        @dataclass
        class AddPayload:
            a: int = 0
            b: int = 0

        state = MessageState(
            payload=AddPayload(a=1, b=2),
            to_id="calculator.add",
        )

        result = await step(state)

        assert result.error is None
        assert result.target_listeners == [mock_listener]


# ============================================================================
# Pipeline Integration Tests (lightweight)
# ============================================================================

class TestPipelineIntegration:
    """Integration tests for step sequences."""

    @pytest.mark.asyncio
    async def test_repair_through_extraction(self, valid_envelope_bytes):
        """Test repair → c14n → extraction chain."""
        state = MessageState(raw_bytes=valid_envelope_bytes)

        state = await repair_step(state)
        assert state.error is None, f"repair failed: {state.error}"

        state = await c14n_step(state)
        assert state.error is None, f"c14n failed: {state.error}"

        state = await payload_extraction_step(state)
        assert state.error is None, f"extraction failed: {state.error}"

        assert state.payload_tree is not None
        assert state.thread_id is not None
        assert state.from_id is not None

    @pytest.mark.asyncio
    async def test_error_short_circuits(self):
        """Errors should prevent downstream steps from running."""
        call_log = []

        async def step_a(state):
            call_log.append("a")
            state.error = "Intentional error"
            return state

        async def step_b(state):
            call_log.append("b")
            return state

        # Simple pipeline runner (same logic as StreamPump uses).
        async def run_pipeline(steps, state):
            for step in steps:
                state = await step(state)
                if state.error:
                    break
            return state

        result = await run_pipeline([step_a, step_b], MessageState())

        assert call_log == ["a"]  # step_b should not have been called
        # NOTE(review): the original final assertion was truncated in the
        # corrupted patch; reconstructed as the obvious error check.
        assert result.error == "Intentional error"
result.error == "Intentional error" + + +# ============================================================================ +# Run with pytest +# ============================================================================ + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_pump_integration.py b/tests/test_pump_integration.py new file mode 100644 index 0000000..81b654a --- /dev/null +++ b/tests/test_pump_integration.py @@ -0,0 +1,398 @@ +""" +test_pump_integration.py — Integration tests for the StreamPump + +Run with: pytest tests/test_pump_integration.py -v + +These tests verify the full message flow through the pump: + inject → parse → extract → validate → deserialize → route → handler → response +""" + +import pytest +import asyncio +import uuid +from unittest.mock import AsyncMock, patch + +from agentserver.message_bus import StreamPump, bootstrap, MessageState +from agentserver.message_bus.stream_pump import ConfigLoader, ListenerConfig, OrganismConfig, Listener +from handlers.hello import Greeting, GreetingResponse, handle_greeting, ENVELOPE_NS + + +def make_envelope(payload_xml: str, from_id: str, to_id: str, thread_id: str) -> bytes: + """Helper to create a properly formatted envelope. + + Note: payload_xml should include its own namespace (or xmlns="") to avoid + inheriting the envelope namespace. The envelope XSD expects payload to be + in a foreign namespace (##other). 
+ """ + # Ensure payload has explicit namespace (empty string = no namespace) + if 'xmlns=' not in payload_xml: + # Insert xmlns="" after the tag name + idx = payload_xml.index('>') + if payload_xml[idx-1] == '/': + idx -= 1 + payload_xml = payload_xml[:idx] + ' xmlns=""' + payload_xml[idx:] + + return f""" + + {from_id} + {to_id} + {thread_id} + + {payload_xml} +""".encode('utf-8') + + +class TestPumpBootstrap: + """Test ConfigLoader and bootstrap.""" + + def test_config_loader_parses_yaml(self): + """ConfigLoader should parse organism.yaml correctly.""" + config = ConfigLoader.load('config/organism.yaml') + + assert config.name == "hello-world" + assert len(config.listeners) == 1 + assert config.listeners[0].name == "greeter" + assert config.listeners[0].payload_class == Greeting + assert config.listeners[0].handler == handle_greeting + + @pytest.mark.asyncio + async def test_bootstrap_creates_pump(self): + """bootstrap() should create a configured pump.""" + pump = await bootstrap('config/organism.yaml') + + assert pump.config.name == "hello-world" + assert "greeter.greeting" in pump.routing_table + assert pump.listeners["greeter"].payload_class == Greeting + + @pytest.mark.asyncio + async def test_bootstrap_generates_xsd(self): + """bootstrap() should generate XSD schemas for listeners.""" + pump = await bootstrap('config/organism.yaml') + + listener = pump.listeners["greeter"] + assert listener.schema is not None + + # Schema should validate a proper Greeting + from lxml import etree + valid_xml = etree.fromstring(b"Test") + listener.schema.assertValid(valid_xml) + + +class TestPumpInjection: + """Test message injection and queue behavior.""" + + @pytest.mark.asyncio + async def test_inject_adds_to_queue(self): + """inject() should add a MessageState to the queue.""" + pump = await bootstrap('config/organism.yaml') + + thread_id = str(uuid.uuid4()) + await pump.inject(b"", thread_id, from_id="user") + + assert pump.queue.qsize() == 1 + state = await 
pump.queue.get() + assert state.raw_bytes == b"" + assert state.thread_id == thread_id + assert state.from_id == "user" + + +class TestFullPipelineFlow: + """Test complete message flow through the pipeline.""" + + @pytest.mark.asyncio + async def test_greeting_round_trip(self): + """ + Full integration test: + 1. Inject a Greeting message + 2. Pump processes it through the pipeline + 3. Handler is called with deserialized Greeting + 4. Handler response is re-injected + """ + pump = await bootstrap('config/organism.yaml') + + # Track what the handler receives + handler_calls = [] + original_handler = pump.listeners["greeter"].handler + + async def tracking_handler(payload, metadata): + handler_calls.append((payload, metadata)) + return await original_handler(payload, metadata) + + pump.listeners["greeter"].handler = tracking_handler + + # Create and inject a Greeting message + thread_id = str(uuid.uuid4()) + envelope = make_envelope( + payload_xml="World", + from_id="user", + to_id="greeter", + thread_id=thread_id, + ) + + await pump.inject(envelope, thread_id, from_id="user") + + # Run pump briefly to process the message + pump._running = True + pipeline = pump.build_pipeline(pump._queue_source()) + + # Process with timeout + async def run_with_timeout(): + async with pipeline.stream() as streamer: + try: + async for _ in streamer: + # One iteration should process our message + break + except asyncio.CancelledError: + pass + + try: + await asyncio.wait_for(run_with_timeout(), timeout=2.0) + except asyncio.TimeoutError: + pass + finally: + pump._running = False + + # Verify handler was called + assert len(handler_calls) == 1 + payload, metadata = handler_calls[0] + + assert isinstance(payload, Greeting) + assert payload.name == "World" + assert metadata.thread_id == thread_id + assert metadata.from_id == "user" + + @pytest.mark.asyncio + async def test_handler_response_reinjected(self): + """Handler response should be re-injected into the queue.""" + pump = await 
bootstrap('config/organism.yaml') + + # Capture re-injected messages + reinjected = [] + original_reinject = pump._reinject_responses + + async def capture_reinject(state): + reinjected.append(state) + # Don't actually re-inject to avoid infinite loop + + pump._reinject_responses = capture_reinject + + # Inject a Greeting + thread_id = str(uuid.uuid4()) + envelope = make_envelope( + payload_xml="Alice", + from_id="user", + to_id="greeter", + thread_id=thread_id, + ) + + await pump.inject(envelope, thread_id, from_id="user") + + # Run pump briefly + pump._running = True + pipeline = pump.build_pipeline(pump._queue_source()) + + async def run_with_timeout(): + async with pipeline.stream() as streamer: + try: + async for _ in streamer: + break + except asyncio.CancelledError: + pass + + try: + await asyncio.wait_for(run_with_timeout(), timeout=2.0) + except asyncio.TimeoutError: + pass + finally: + pump._running = False + + # Verify response was re-injected + assert len(reinjected) == 1 + response_state = reinjected[0] + + assert response_state.raw_bytes is not None + assert b"Hello, Alice!" 
in response_state.raw_bytes + assert response_state.thread_id == thread_id + assert response_state.from_id == "greeter" + + +class TestErrorHandling: + """Test error paths through the pipeline.""" + + @pytest.mark.asyncio + async def test_invalid_xml_error(self): + """Malformed XML should set error, not crash.""" + pump = await bootstrap('config/organism.yaml') + + errors = [] + original_handle_errors = pump._handle_errors + + async def capture_errors(state): + if state.error: + errors.append(state.error) + return await original_handle_errors(state) + + pump._handle_errors = capture_errors + + # Inject malformed XML + thread_id = str(uuid.uuid4()) + await pump.inject(b"= 0 # Processed without crash + + @pytest.mark.asyncio + async def test_unknown_route_error(self): + """Message to unknown listener should error gracefully.""" + pump = await bootstrap('config/organism.yaml') + + errors = [] + original_handle_errors = pump._handle_errors + + async def capture_errors(state): + if state.error: + errors.append(state.error) + return await original_handle_errors(state) + + pump._handle_errors = capture_errors + + # Inject message to non-existent listener + thread_id = str(uuid.uuid4()) + envelope = make_envelope( + payload_xml="Test", + from_id="user", + to_id="nonexistent", # No such listener + thread_id=thread_id, + ) + + await pump.inject(envelope, thread_id, from_id="user") + + # Run pump + pump._running = True + pipeline = pump.build_pipeline(pump._queue_source()) + + async def run_with_timeout(): + async with pipeline.stream() as streamer: + try: + async for _ in streamer: + break + except asyncio.CancelledError: + pass + + try: + await asyncio.wait_for(run_with_timeout(), timeout=2.0) + except asyncio.TimeoutError: + pass + finally: + pump._running = False + + # Should have a routing error + assert any("nonexistent" in e for e in errors) + + +class TestManualPumpConfiguration: + """Test creating a pump without YAML config.""" + + @pytest.mark.asyncio + async def 
test_manual_listener_registration(self): + """Can register listeners programmatically.""" + config = OrganismConfig(name="manual-test") + pump = StreamPump(config) + + lc = ListenerConfig( + name="greeter", + payload_class_path="handlers.hello.Greeting", + handler_path="handlers.hello.handle_greeting", + description="Test listener", + payload_class=Greeting, + handler=handle_greeting, + ) + + listener = pump.register_listener(lc) + + assert listener.name == "greeter" + assert listener.root_tag == "greeter.greeting" + assert "greeter.greeting" in pump.routing_table + + @pytest.mark.asyncio + async def test_custom_handler(self): + """Can use a custom handler function.""" + config = OrganismConfig(name="custom-test") + pump = StreamPump(config) + + responses = [] + + async def custom_handler(payload, metadata): + responses.append(payload) + return b"" + + lc = ListenerConfig( + name="custom", + payload_class_path="handlers.hello.Greeting", + handler_path="handlers.hello.handle_greeting", + description="Custom handler", + payload_class=Greeting, + handler=custom_handler, + ) + + pump.register_listener(lc) + + # Inject and process + thread_id = str(uuid.uuid4()) + envelope = make_envelope( + payload_xml="Custom", + from_id="tester", + to_id="custom", + thread_id=thread_id, + ) + + await pump.inject(envelope, thread_id, from_id="tester") + + # Run pump + pump._running = True + + # Capture re-injected to prevent loop + async def noop_reinject(state): + pass + pump._reinject_responses = noop_reinject + + pipeline = pump.build_pipeline(pump._queue_source()) + + async def run_with_timeout(): + async with pipeline.stream() as streamer: + try: + async for _ in streamer: + break + except asyncio.CancelledError: + pass + + try: + await asyncio.wait_for(run_with_timeout(), timeout=2.0) + except asyncio.TimeoutError: + pass + finally: + pump._running = False + + # Custom handler should have been called + assert len(responses) == 1 + assert responses[0].name == "Custom" diff --git 
a/third_party/xmlable/__init__.py b/third_party/xmlable/__init__.py index 41bf01d..f7c0108 100644 --- a/third_party/xmlable/__init__.py +++ b/third_party/xmlable/__init__.py @@ -7,6 +7,7 @@ from typing import Type, TypeVar, Any from io import BytesIO from ._xmlify import xmlify +from ._errors import XErrorCtx T = TypeVar("T") @@ -19,7 +20,9 @@ def parse_element(cls: Type[T], element: _Element | ObjectifiedElement) -> T: """Direct in-memory deserialization from validated lxml Element.""" xobject = _get_xobject(cls) obj_element = objectify.fromstring(etree.tostring(element)) - return xobject.xml_in(obj_element, ctx=None) + # Create a root context for error tracing + ctx = XErrorCtx(trace=[cls.__name__]) + return xobject.xml_in(obj_element, ctx=ctx) def parse_bytes(cls: Type[T], xml_bytes: bytes) -> T: tree = objectify.parse(BytesIO(xml_bytes)) diff --git a/third_party/xmlable/_errors.py b/third_party/xmlable/_errors.py index 99e8ea2..dea06de 100644 --- a/third_party/xmlable/_errors.py +++ b/third_party/xmlable/_errors.py @@ -9,7 +9,7 @@ from typing import Any, Iterable from termcolor import colored from termcolor.termcolor import Color -from xmlable._utils import typename, AnyType +from ._utils import typename, AnyType def trace_note(trace: list[str], arrow_c: Color, node_c: Color): diff --git a/third_party/xmlable/_io.py b/third_party/xmlable/_io.py index 475e205..8b2a692 100644 --- a/third_party/xmlable/_io.py +++ b/third_party/xmlable/_io.py @@ -10,9 +10,9 @@ from termcolor import colored from lxml.objectify import parse as objectify_parse from lxml.etree import _ElementTree -from xmlable._utils import typename -from xmlable._xobject import is_xmlified -from xmlable._errors import ErrorTypes +from ._utils import typename +from ._xobject import is_xmlified +from ._errors import ErrorTypes def write_file(file_path: str | Path, tree: _ElementTree): diff --git a/third_party/xmlable/_manual.py b/third_party/xmlable/_manual.py index 9e73c31..4c2c58c 100644 --- 
a/third_party/xmlable/_manual.py +++ b/third_party/xmlable/_manual.py @@ -8,9 +8,9 @@ from typing import Any from lxml.etree import _Element, Element, _ElementTree, ElementTree from lxml.objectify import ObjectifiedElement -from xmlable._utils import typename, AnyType, ordered_iter -from xmlable._lxml_helpers import with_children, XMLSchema -from xmlable._errors import XError, XErrorCtx, ErrorTypes +from ._utils import typename, AnyType, ordered_iter +from ._lxml_helpers import with_children, XMLSchema +from ._errors import XError, XErrorCtx, ErrorTypes def validate_manual_class(cls: AnyType): diff --git a/third_party/xmlable/_user.py b/third_party/xmlable/_user.py index d94db60..8bcbb7e 100644 --- a/third_party/xmlable/_user.py +++ b/third_party/xmlable/_user.py @@ -6,8 +6,8 @@ The IXmlify interface from abc import ABC, abstractmethod from lxml.etree import _Element -from xmlable._xobject import XObject -from xmlable._utils import AnyType +from ._xobject import XObject +from ._utils import AnyType class IXmlify(ABC): diff --git a/third_party/xmlable/_xmlify.py b/third_party/xmlable/_xmlify.py index 050b069..9b07264 100644 --- a/third_party/xmlable/_xmlify.py +++ b/third_party/xmlable/_xmlify.py @@ -15,11 +15,11 @@ from typing import Any, dataclass_transform, cast from lxml.objectify import ObjectifiedElement from lxml.etree import Element, _Element -from xmlable._utils import get, typename, AnyType -from xmlable._errors import XError, XErrorCtx, ErrorTypes -from xmlable._manual import manual_xmlify -from xmlable._lxml_helpers import with_children, with_child, XMLSchema -from xmlable._xobject import XObject, gen_xobject +from ._utils import get, typename, AnyType +from ._errors import XError, XErrorCtx, ErrorTypes +from ._manual import manual_xmlify +from ._lxml_helpers import with_children, with_child, XMLSchema +from ._xobject import XObject, gen_xobject def validate_class(cls: AnyType): diff --git a/third_party/xmlable/_xobject.py 
b/third_party/xmlable/_xobject.py index 766695e..4dcc6c4 100644 --- a/third_party/xmlable/_xobject.py +++ b/third_party/xmlable/_xobject.py @@ -13,9 +13,9 @@ from abc import ABC, abstractmethod from typing import Any, Callable, Type, get_args, TypeAlias, cast from types import GenericAlias -from xmlable._utils import get, typename, firstkey, AnyType -from xmlable._errors import XErrorCtx, ErrorTypes -from xmlable._lxml_helpers import ( +from ._utils import get, typename, firstkey, AnyType +from ._errors import XErrorCtx, ErrorTypes +from ._lxml_helpers import ( with_text, with_child, with_children,