Replace MessageBus with aiostream-based StreamPump
Major refactor of the message pump architecture: - Replace bus.py with stream_pump.py using aiostream for composable stream processing with natural fan-out via flatmap - Add to_id field to MessageState for explicit routing - Fix routing to use to_id.class format (e.g., "greeter.greeting") - Generate XSD schemas from xmlified payload classes - Fix xmlable imports (absolute -> relative) and parse_element ctx New features: - handlers/hello.py: Sample Greeting/GreetingResponse handler - config/organism.yaml: Sample organism configuration - 41 tests (31 unit + 10 integration) all passing Schema changes: - envelope.xsd: Allow any namespace payloads (##other -> ##any) Dependencies added to pyproject.toml: - aiostream>=0.5 (core dependency) - pyhumps, termcolor (for xmlable) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
dc16316aed
commit
82b5fcdd78
24 changed files with 2018 additions and 343 deletions
27
__init__.py
27
__init__.py
|
|
@ -3,19 +3,26 @@
|
|||
xml-pipeline
|
||||
============
|
||||
Secure, XML-centric multi-listener organism server.
|
||||
|
||||
Stream-based message pump with aiostream for fan-out handling.
|
||||
"""
|
||||
|
||||
from agentserver.agentserver import AgentServer as AgentServer
|
||||
from agentserver.xml_listener import XMLListener as XMLListener
|
||||
from agentserver.message_bus import MessageBus as MessageBus
|
||||
from agentserver.message_bus import Session as Session
|
||||
|
||||
from agentserver.message_bus import (
|
||||
StreamPump,
|
||||
ConfigLoader,
|
||||
Listener,
|
||||
MessageState,
|
||||
HandlerMetadata,
|
||||
bootstrap,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentServer",
|
||||
"XMLListener",
|
||||
"MessageBus",
|
||||
"Session",
|
||||
"StreamPump",
|
||||
"ConfigLoader",
|
||||
"Listener",
|
||||
"MessageState",
|
||||
"HandlerMetadata",
|
||||
"bootstrap",
|
||||
]
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__version__ = "0.2.0" # Bumped for aiostream pump
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
"""
|
||||
message_bus — Stream-based message pump for AgentServer v2.1
|
||||
|
||||
The message pump handles message flow through the organism:
|
||||
- YAML config → bootstrap → pump → handlers → responses → loop
|
||||
|
||||
Key classes:
|
||||
StreamPump Main pump class (queue-backed, aiostream-powered)
|
||||
ConfigLoader Load organism.yaml and resolve imports
|
||||
Listener Runtime listener with handler and routing info
|
||||
MessageState Message flowing through pipeline steps
|
||||
|
||||
Usage:
|
||||
from agentserver.message_bus import StreamPump, bootstrap
|
||||
|
||||
pump = await bootstrap("config/organism.yaml")
|
||||
await pump.inject(initial_message, thread_id, from_id)
|
||||
await pump.run()
|
||||
"""
|
||||
|
||||
from agentserver.message_bus.stream_pump import (
|
||||
StreamPump,
|
||||
ConfigLoader,
|
||||
Listener,
|
||||
ListenerConfig,
|
||||
OrganismConfig,
|
||||
bootstrap,
|
||||
)
|
||||
|
||||
from agentserver.message_bus.message_state import (
|
||||
MessageState,
|
||||
HandlerMetadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"StreamPump",
|
||||
"ConfigLoader",
|
||||
"Listener",
|
||||
"ListenerConfig",
|
||||
"OrganismConfig",
|
||||
"MessageState",
|
||||
"HandlerMetadata",
|
||||
"bootstrap",
|
||||
]
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
"""
|
||||
bus.py — The central MessageBus and pump for AgentServer v2.1
|
||||
|
||||
This is the beating heart of the organism:
|
||||
- Owns all pipelines (one per listener + permanent system pipeline)
|
||||
- Maintains the routing table (root_tag → Listener(s))
|
||||
- Orchestrates ingress from sockets/gateways
|
||||
- Dispatches prepared messages to handlers
|
||||
- Processes handler responses (multi-payload extraction, provenance injection)
|
||||
- Guarantees thread continuity and diagnostic injection
|
||||
|
||||
Fully aligned with:
|
||||
- listener-class-v2.1.md
|
||||
- configuration-v2.1.md
|
||||
- message-pump-v2.1.md
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Awaitable, List
|
||||
from uuid import uuid4
|
||||
|
||||
from lxml import etree
|
||||
|
||||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
||||
from agentserver.message_bus.steps.repair import repair_step
|
||||
from agentserver.message_bus.steps.c14n import c14n_step
|
||||
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
||||
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
||||
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
||||
from agentserver.message_bus.steps.xsd_validation import xsd_validation_step
|
||||
from agentserver.message_bus.steps.deserialization import deserialization_step
|
||||
from agentserver.message_bus.steps.routing_resolution import routing_resolution_step
|
||||
|
||||
# Type alias for pipeline steps
|
||||
PipelineStep = Callable[[MessageState], Awaitable[MessageState]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Listener:
|
||||
"""Registered capability — defined in listener.py, referenced here."""
|
||||
name: str
|
||||
payload_class: type
|
||||
handler: Callable
|
||||
description: str
|
||||
is_agent: bool = False
|
||||
peers: list[str] | None = None
|
||||
broadcast: bool = False
|
||||
pipeline: "Pipeline" | None = None
|
||||
schema: etree.XMLSchema | None = None # cached at registration
|
||||
|
||||
|
||||
class Pipeline:
|
||||
"""One dedicated pipeline per listener (plus system pipeline)."""
|
||||
def __init__(self, steps: List[PipelineStep]):
|
||||
self.steps = steps
|
||||
|
||||
async def process(self, initial_state: MessageState) -> None:
|
||||
"""Run the full ordered pipeline on a message."""
|
||||
state = initial_state
|
||||
for step in self.steps:
|
||||
try:
|
||||
state = await step(state)
|
||||
if state.error:
|
||||
break
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
state.error = f"Pipeline step {step.__name__} crashed: {exc}"
|
||||
break
|
||||
|
||||
# After all steps — dispatch if routable
|
||||
if state.target_listeners:
|
||||
await MessageBus.get_instance().dispatcher(state)
|
||||
else:
|
||||
# Fall back to system pipeline for diagnostics
|
||||
await MessageBus.get_instance().system_pipeline.process(state)
|
||||
|
||||
|
||||
class MessageBus:
|
||||
"""Singleton message bus — the pump."""
|
||||
_instance: "MessageBus" | None = None
|
||||
|
||||
def __init__(self):
|
||||
self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s)
|
||||
self.listeners: dict[str, Listener] = {} # name → Listener
|
||||
self.system_pipeline = Pipeline(self._build_system_steps())
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls) -> "MessageBus":
|
||||
if cls._instance is None:
|
||||
cls._instance = MessageBus()
|
||||
return cls._instance
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Default step lists
|
||||
# ------------------------------------------------------------------ #
|
||||
def _build_default_listener_steps(self) -> List[PipelineStep]:
|
||||
return [
|
||||
repair_step,
|
||||
c14n_step,
|
||||
envelope_validation_step,
|
||||
payload_extraction_step,
|
||||
thread_assignment_step,
|
||||
xsd_validation_step,
|
||||
deserialization_step,
|
||||
routing_resolution_step,
|
||||
]
|
||||
|
||||
def _build_system_steps(self) -> List[PipelineStep]:
|
||||
"""Shorter, fixed steps — no XSD/deserialization."""
|
||||
return [
|
||||
repair_step,
|
||||
c14n_step,
|
||||
envelope_validation_step,
|
||||
payload_extraction_step,
|
||||
thread_assignment_step,
|
||||
# system-specific handler that emits <huh>, boot, etc.
|
||||
self.system_handler_step,
|
||||
]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Registration (called from listener.py)
|
||||
# ------------------------------------------------------------------ #
|
||||
def register_listener(self, listener: Listener) -> None:
|
||||
root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}"
|
||||
|
||||
if root_tag in self.routing_table and not listener.broadcast:
|
||||
raise ValueError(f"Root tag collision: {root_tag} already registered by {self.routing_table[root_tag][0].name}")
|
||||
|
||||
# Build dedicated pipeline
|
||||
steps = self._build_default_listener_steps()
|
||||
# Inject listener-specific schema for xsd_validation_step
|
||||
for step in steps:
|
||||
if step.__name__ == "xsd_validation_step":
|
||||
# We'll modify state.metadata in pipeline construction instead
|
||||
pass
|
||||
listener.pipeline = Pipeline(steps)
|
||||
|
||||
# Insert into routing
|
||||
self.routing_table.setdefault(root_tag, []).append(listener)
|
||||
self.listeners[listener.name] = listener
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Dispatcher — dumb fire-and-await
|
||||
# ------------------------------------------------------------------ #
|
||||
async def dispatcher(self, state: MessageState) -> None:
|
||||
if not state.target_listeners:
|
||||
return
|
||||
|
||||
metadata = HandlerMetadata(
|
||||
thread_id=state.thread_id or "",
|
||||
from_id=state.from_id or "unknown",
|
||||
own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None,
|
||||
is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False,
|
||||
)
|
||||
|
||||
if len(state.target_listeners) == 1:
|
||||
listener = state.target_listeners[0]
|
||||
await self._process_single_handler(state, listener, metadata)
|
||||
else:
|
||||
# Broadcast — fire all in parallel, process responses as they complete
|
||||
tasks = [
|
||||
self._process_single_handler(state, listener, metadata)
|
||||
for listener in state.target_listeners
|
||||
]
|
||||
for future in asyncio.as_completed(tasks):
|
||||
await future
|
||||
|
||||
async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None:
|
||||
try:
|
||||
response_bytes = await listener.handler(state.payload, metadata)
|
||||
|
||||
if response_bytes is None or not isinstance(response_bytes, bytes):
|
||||
response_bytes = b"<huh>Handler failed to return valid bytes — missing return or wrong type</huh>"
|
||||
|
||||
payloads = await self._multi_payload_extract(response_bytes)
|
||||
|
||||
for payload_bytes in payloads:
|
||||
new_state = MessageState(
|
||||
raw_bytes=payload_bytes,
|
||||
thread_id=state.thread_id,
|
||||
from_id=listener.name,
|
||||
)
|
||||
# Route the new payload through normal pipelines
|
||||
root_tag = self._derive_root_tag(payload_bytes)
|
||||
targets = self.routing_table.get(root_tag)
|
||||
if targets:
|
||||
new_state.target_listeners = targets
|
||||
await targets[0].pipeline.process(new_state)
|
||||
else:
|
||||
await self.system_pipeline.process(new_state)
|
||||
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
error_state = MessageState(
|
||||
raw_bytes=b"<huh>Handler crashed</huh>",
|
||||
thread_id=state.thread_id,
|
||||
from_id=listener.name,
|
||||
error=f"Handler {listener.name} crashed: {exc}",
|
||||
)
|
||||
await self.system_pipeline.process(error_state)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Helper methods
|
||||
# ------------------------------------------------------------------ #
|
||||
async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]:
|
||||
# Same logic as before — dummy wrap, repair, extract all root elements
|
||||
# (implementation can be moved to a shared util later)
|
||||
# For now, placeholder — we'll flesh this out in response_processing.py
|
||||
return [raw_bytes] # temporary — will be full extraction
|
||||
|
||||
def _derive_root_tag(self, payload_bytes: bytes) -> str:
|
||||
# Quick parse to get root tag — used only for routing extracted payloads
|
||||
try:
|
||||
tree = etree.fromstring(payload_bytes)
|
||||
tag = tree.tag
|
||||
if tag.startswith("{"):
|
||||
return tag.split("}", 1)[1] # strip namespace
|
||||
return tag
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
async def system_handler_step(self, state: MessageState) -> MessageState:
|
||||
# Emit <huh> or boot message — placeholder for now
|
||||
state.error = state.error or "Unhandled by any listener"
|
||||
return state
|
||||
|
|
@ -1,6 +1,12 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from lxml.etree import Element
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lxml.etree import _Element as Element
|
||||
else:
|
||||
Element = Any # Runtime: don't need the actual type
|
||||
|
||||
"""
|
||||
default_listener_steps = [
|
||||
|
|
@ -33,6 +39,7 @@ class MessageState:
|
|||
|
||||
thread_id: str | None = None
|
||||
from_id: str | None = None
|
||||
to_id: str | None = None # Target listener name for routing
|
||||
|
||||
target_listeners: list['Listener'] | None = None # Forward reference
|
||||
|
||||
|
|
|
|||
|
|
@ -1,29 +1,29 @@
|
|||
"""
|
||||
deserialization.py — Convert validated payload_tree into typed dataclass instance.
|
||||
|
||||
After xsd_validation_step confirms the payload conforms to the listener's contract,
|
||||
this step uses the xmlable library to deserialize the lxml Element into the
|
||||
registered @xmlify dataclass.
|
||||
|
||||
The resulting instance is placed in state.payload and handed to the handler.
|
||||
After xsd_validation_step confirms the payload conforms to the contract,
|
||||
this step uses our customized xmlable routines to deserialize the lxml Element
|
||||
directly in memory — no temporary files needed.
|
||||
|
||||
Part of AgentServer v2.1 message pump.
|
||||
"""
|
||||
|
||||
from xmlable import from_xml # from the xmlable library
|
||||
from lxml.etree import _Element
|
||||
from agentserver.message_bus.message_state import MessageState
|
||||
|
||||
# Import the customized parse_element from your forked xmlable
|
||||
from third_party.xmlable import parse_element # adjust path if needed
|
||||
|
||||
|
||||
async def deserialization_step(state: MessageState) -> MessageState:
|
||||
"""
|
||||
Deserialize the validated payload_tree into the listener's dataclass.
|
||||
Deserialize the validated payload_tree into the listener's @xmlify dataclass.
|
||||
|
||||
Requires:
|
||||
- state.payload_tree valid against listener XSD
|
||||
- state.metadata["payload_class"] set to the target dataclass (set at registration)
|
||||
- state.payload_tree: validated lxml Element
|
||||
- state.metadata["payload_class"]: the target dataclass
|
||||
|
||||
On success: state.payload = dataclass instance
|
||||
On failure: state.error set with clear message
|
||||
Uses the custom parse_element routine for direct in-memory deserialization.
|
||||
"""
|
||||
if state.payload_tree is None:
|
||||
state.error = "deserialization_step: no payload_tree (previous step failed)"
|
||||
|
|
@ -35,8 +35,8 @@ async def deserialization_step(state: MessageState) -> MessageState:
|
|||
return state
|
||||
|
||||
try:
|
||||
# xmlable.from_xml handles namespace-aware deserialization
|
||||
instance = from_xml(payload_class, state.payload_tree)
|
||||
# Direct in-memory deserialization — fast and clean
|
||||
instance = parse_element(payload_class, state.payload_tree)
|
||||
state.payload = instance
|
||||
|
||||
except Exception as exc: # pylint: disable=broad-except
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
payload_extraction.py — Extract the inner payload from the validated <message> envelope.
|
||||
|
||||
After envelope_validation_step confirms a correct outer <message> envelope,
|
||||
this step removes the envelope elements (<thread>, <from>, optional <to>, etc.)
|
||||
and isolates the single child element that is the actual payload.
|
||||
this step extracts metadata from <meta> and isolates the single payload element.
|
||||
|
||||
The payload is expected to be exactly one root element (the capability-specific XML).
|
||||
If zero or multiple payload roots are found, we set a clear error — this protects
|
||||
|
|
@ -17,19 +16,25 @@ from agentserver.message_bus.message_state import MessageState
|
|||
|
||||
# Envelope namespace for easy reference
|
||||
_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
|
||||
_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message"
|
||||
_MESSAGE_TAG = f"{{{_ENVELOPE_NS}}}message"
|
||||
_META_TAG = f"{{{_ENVELOPE_NS}}}meta"
|
||||
_FROM_TAG = f"{{{_ENVELOPE_NS}}}from"
|
||||
_TO_TAG = f"{{{_ENVELOPE_NS}}}to"
|
||||
_THREAD_TAG = f"{{{_ENVELOPE_NS}}}thread"
|
||||
|
||||
|
||||
async def payload_extraction_step(state: MessageState) -> MessageState:
|
||||
"""
|
||||
Extract the single payload element from the validated envelope.
|
||||
|
||||
Expected structure:
|
||||
Expected structure (per envelope.xsd):
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<thread>uuid</thread>
|
||||
<meta>
|
||||
<from>sender</from>
|
||||
<!-- optional <to>receiver</to> -->
|
||||
<payload_root> ← this is the one we want
|
||||
<to>receiver</to> <!-- optional -->
|
||||
<thread>uuid</thread>
|
||||
</meta>
|
||||
<payload_root xmlns="..."> ← this is what we extract
|
||||
...
|
||||
</payload_root>
|
||||
</message>
|
||||
|
|
@ -41,24 +46,42 @@ async def payload_extraction_step(state: MessageState) -> MessageState:
|
|||
state.error = "payload_extraction_step: no envelope_tree (previous step failed)"
|
||||
return state
|
||||
|
||||
# Basic sanity — root must be <message> in correct namespace (already checked by schema,
|
||||
# but we double-check for defence in depth)
|
||||
# Basic sanity — root must be <message> in correct namespace
|
||||
if state.envelope_tree.tag != _MESSAGE_TAG:
|
||||
state.error = f"payload_extraction_step: root tag is not <message> in envelope namespace"
|
||||
state.error = "payload_extraction_step: root tag is not <message> in envelope namespace"
|
||||
return state
|
||||
|
||||
# Find all direct children that are not envelope control elements
|
||||
# Envelope control elements are: thread, from, to (optional)
|
||||
# Find <meta> block and extract provenance
|
||||
meta_elem = state.envelope_tree.find(_META_TAG)
|
||||
if meta_elem is None:
|
||||
state.error = "payload_extraction_step: missing <meta> block in envelope"
|
||||
return state
|
||||
|
||||
# Extract from_id (required)
|
||||
from_elem = meta_elem.find(_FROM_TAG)
|
||||
if from_elem is not None and from_elem.text:
|
||||
state.from_id = from_elem.text.strip()
|
||||
else:
|
||||
state.error = "payload_extraction_step: missing <from> in <meta>"
|
||||
return state
|
||||
|
||||
# Extract thread_id (required)
|
||||
thread_elem = meta_elem.find(_THREAD_TAG)
|
||||
if thread_elem is not None and thread_elem.text:
|
||||
state.thread_id = thread_elem.text.strip()
|
||||
else:
|
||||
state.error = "payload_extraction_step: missing <thread> in <meta>"
|
||||
return state
|
||||
|
||||
# Optional: extract <to> for direct routing
|
||||
to_elem = meta_elem.find(_TO_TAG)
|
||||
if to_elem is not None and to_elem.text:
|
||||
state.to_id = to_elem.text.strip()
|
||||
|
||||
# Find all direct children that are NOT <meta> — those are payload candidates
|
||||
payload_candidates = [
|
||||
child
|
||||
for child in state.envelope_tree
|
||||
if not (
|
||||
child.tag in {
|
||||
f"{{{ _ENVELOPE_NS }}}thread",
|
||||
f"{{{ _ENVELOPE_NS }}}from",
|
||||
f"{{{ _ENVELOPE_NS }}}to",
|
||||
}
|
||||
)
|
||||
child for child in state.envelope_tree
|
||||
if child.tag != _META_TAG
|
||||
]
|
||||
|
||||
if len(payload_candidates) == 0:
|
||||
|
|
@ -73,19 +96,6 @@ async def payload_extraction_step(state: MessageState) -> MessageState:
|
|||
return state
|
||||
|
||||
# Success — exactly one payload element
|
||||
payload_element = payload_candidates[0]
|
||||
|
||||
# Optional: capture provenance from envelope for later use
|
||||
# (these will be trustworthy because envelope was validated)
|
||||
thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread")
|
||||
from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from")
|
||||
|
||||
if thread_elem is not None and thread_elem.text:
|
||||
state.thread_id = thread_elem.text.strip()
|
||||
|
||||
if from_elem is not None and from_elem.text:
|
||||
state.from_id = from_elem.text.strip()
|
||||
|
||||
state.payload_tree = payload_element
|
||||
state.payload_tree = payload_candidates[0]
|
||||
|
||||
return state
|
||||
|
|
@ -1,50 +1,70 @@
|
|||
"""
|
||||
routing_resolution.py — Resolve routing based on derived root tag.
|
||||
|
||||
This is the final preparation step before dispatch.
|
||||
It computes the root tag from the deserialized payload and looks it up in the
|
||||
global routing table (root_tag → list[Listener]).
|
||||
This step computes the root tag from the deserialized payload and looks it up
|
||||
in a routing table (root_tag → list[Listener]).
|
||||
|
||||
On success: state.target_listeners is set
|
||||
On failure: state.error is set → message falls to system pipeline for <huh>
|
||||
NOTE: The StreamPump has routing built-in via _route_step(). This standalone
|
||||
step is provided for custom pipeline configurations or testing.
|
||||
|
||||
Usage:
|
||||
routing_step = make_routing_step(routing_table)
|
||||
state = await routing_step(state)
|
||||
|
||||
Part of AgentServer v2.1 message pump.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Callable, Awaitable, TYPE_CHECKING
|
||||
|
||||
from agentserver.message_bus.message_state import MessageState
|
||||
from agentserver.message_bus.bus import MessageBus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agentserver.message_bus.stream_pump import Listener
|
||||
|
||||
|
||||
async def routing_resolution_step(state: MessageState) -> MessageState:
|
||||
def make_routing_step(
|
||||
routing_table: Dict[str, List["Listener"]]
|
||||
) -> Callable[[MessageState], Awaitable[MessageState]]:
|
||||
"""
|
||||
Factory: create a routing step with a specific routing table.
|
||||
|
||||
The routing table maps root tags to lists of listeners:
|
||||
{"agent.payload": [listener1, listener2], ...}
|
||||
"""
|
||||
|
||||
async def routing_resolution_step(state: MessageState) -> MessageState:
|
||||
"""
|
||||
Resolve which listener(s) should handle this payload.
|
||||
|
||||
Root tag = f"{from_id.lower()}.{payload_class_name.lower()}"
|
||||
(from_id is trustworthy — injected by pump)
|
||||
|
||||
Supports:
|
||||
- Normal unique routing (one listener)
|
||||
- Broadcast (multiple listeners if broadcast: true and same root tag)
|
||||
- Broadcast (multiple listeners if same root tag)
|
||||
|
||||
If no match → error, falls to system pipeline.
|
||||
"""
|
||||
if state.payload is None:
|
||||
state.error = "routing_resolution_step: no deserialized payload (previous step failed)"
|
||||
state.error = "routing_resolution_step: no deserialized payload"
|
||||
return state
|
||||
|
||||
if state.from_id is None:
|
||||
state.error = "routing_resolution_step: missing from_id (provenance error)"
|
||||
if state.to_id is None:
|
||||
state.error = "routing_resolution_step: missing to_id"
|
||||
return state
|
||||
|
||||
payload_class_name = type(state.payload).__name__.lower()
|
||||
root_tag = f"{state.from_id.lower()}.{payload_class_name}"
|
||||
root_tag = f"{state.to_id.lower()}.{payload_class_name}"
|
||||
|
||||
bus = MessageBus.get_instance()
|
||||
targets = bus.routing_table.get(root_tag)
|
||||
targets = routing_table.get(root_tag)
|
||||
|
||||
if not targets:
|
||||
state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'"
|
||||
state.error = f"routing_resolution_step: unknown root tag '{root_tag}'"
|
||||
return state
|
||||
|
||||
state.target_listeners = targets
|
||||
return state
|
||||
|
||||
routing_resolution_step.__name__ = "routing_resolution_step"
|
||||
return routing_resolution_step
|
||||
|
|
|
|||
592
agentserver/message_bus/stream_pump.py
Normal file
592
agentserver/message_bus/stream_pump.py
Normal file
|
|
@ -0,0 +1,592 @@
|
|||
"""
|
||||
pump_aiostream.py — Stream-Based Message Pump using aiostream
|
||||
|
||||
This implementation treats the entire message flow as composable streams.
|
||||
Fan-out (multi-payload, broadcast) is handled naturally via flatmap.
|
||||
|
||||
Key insight: Each step is a stream transformer, not a 1:1 function.
|
||||
The pipeline is just a composition of stream operators.
|
||||
|
||||
Dependencies:
|
||||
pip install aiostream
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterable, Callable, List, Dict, Any, Optional
|
||||
|
||||
import yaml
|
||||
from lxml import etree
|
||||
from aiostream import stream, pipe, operator
|
||||
|
||||
# Import existing step implementations (we'll wrap them)
|
||||
from agentserver.message_bus.steps.repair import repair_step
|
||||
from agentserver.message_bus.steps.c14n import c14n_step
|
||||
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
||||
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
||||
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
||||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration (same as before)
|
||||
# ============================================================================
|
||||
|
||||
@dataclass
|
||||
class ListenerConfig:
|
||||
name: str
|
||||
payload_class_path: str
|
||||
handler_path: str
|
||||
description: str
|
||||
is_agent: bool = False
|
||||
peers: List[str] = field(default_factory=list)
|
||||
broadcast: bool = False
|
||||
payload_class: type = field(default=None, repr=False)
|
||||
handler: Callable = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrganismConfig:
|
||||
name: str
|
||||
identity_path: str = ""
|
||||
port: int = 8765
|
||||
thread_scheduling: str = "breadth-first"
|
||||
listeners: List[ListenerConfig] = field(default_factory=list)
|
||||
|
||||
# Concurrency tuning
|
||||
max_concurrent_pipelines: int = 50 # Total concurrent messages in pipeline
|
||||
max_concurrent_handlers: int = 20 # Concurrent handler invocations
|
||||
max_concurrent_per_agent: int = 5 # Per-agent rate limit
|
||||
|
||||
|
||||
@dataclass
|
||||
class Listener:
|
||||
name: str
|
||||
payload_class: type
|
||||
handler: Callable
|
||||
description: str
|
||||
is_agent: bool = False
|
||||
peers: List[str] = field(default_factory=list)
|
||||
broadcast: bool = False
|
||||
schema: etree.XMLSchema = field(default=None, repr=False)
|
||||
root_tag: str = ""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Stream-Based Pipeline Steps
|
||||
# ============================================================================
|
||||
|
||||
def wrap_step(step_fn: Callable) -> Callable:
|
||||
"""
|
||||
Wrap an existing async step function for use with pipe.map.
|
||||
|
||||
Existing steps: async def step(state) -> state
|
||||
We keep them as-is since pipe.map handles the iteration.
|
||||
"""
|
||||
return step_fn
|
||||
|
||||
|
||||
async def extract_payloads(state: MessageState) -> AsyncIterable[MessageState]:
|
||||
"""
|
||||
Fan-out step: Extract 1..N payloads from handler response.
|
||||
|
||||
This is used with pipe.flatmap — yields multiple states for each input.
|
||||
"""
|
||||
if state.raw_bytes is None:
|
||||
yield state
|
||||
return
|
||||
|
||||
try:
|
||||
# Wrap in dummy to handle multiple roots
|
||||
wrapped = b"<dummy>" + state.raw_bytes + b"</dummy>"
|
||||
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||
|
||||
children = list(tree)
|
||||
if not children:
|
||||
yield state
|
||||
return
|
||||
|
||||
for child in children:
|
||||
payload_bytes = etree.tostring(child)
|
||||
yield MessageState(
|
||||
raw_bytes=payload_bytes,
|
||||
thread_id=state.thread_id,
|
||||
from_id=state.from_id,
|
||||
metadata=state.metadata.copy(),
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# On parse failure, pass through as-is
|
||||
yield state
|
||||
|
||||
|
||||
def make_xsd_validation(schema: etree.XMLSchema) -> Callable:
|
||||
"""Factory for XSD validation step with schema baked in."""
|
||||
async def validate(state: MessageState) -> MessageState:
|
||||
if state.payload_tree is None or state.error:
|
||||
return state
|
||||
try:
|
||||
schema.assertValid(state.payload_tree)
|
||||
except etree.DocumentInvalid as e:
|
||||
state.error = f"XSD validation failed: {e}"
|
||||
return state
|
||||
return validate
|
||||
|
||||
|
||||
def make_deserialization(payload_class: type) -> Callable:
|
||||
"""Factory for deserialization step with class baked in."""
|
||||
from third_party.xmlable import parse_element
|
||||
|
||||
async def deserialize(state: MessageState) -> MessageState:
|
||||
if state.payload_tree is None or state.error:
|
||||
return state
|
||||
try:
|
||||
state.payload = parse_element(payload_class, state.payload_tree)
|
||||
except Exception as e:
|
||||
state.error = f"Deserialization failed: {e}"
|
||||
return state
|
||||
return deserialize
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# The Stream-Based Pump
|
||||
# ============================================================================
|
||||
|
||||
class StreamPump:
|
||||
"""
|
||||
Message pump built on aiostream.
|
||||
|
||||
The entire flow is a single composable stream pipeline.
|
||||
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OrganismConfig):
|
||||
self.config = config
|
||||
|
||||
# Message queue feeds the stream
|
||||
self.queue: asyncio.Queue[MessageState] = asyncio.Queue()
|
||||
|
||||
# Routing table
|
||||
self.routing_table: Dict[str, List[Listener]] = {}
|
||||
self.listeners: Dict[str, Listener] = {}
|
||||
|
||||
# Per-agent semaphores for rate limiting
|
||||
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
|
||||
|
||||
# Shutdown control
|
||||
self._running = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def register_listener(self, lc: ListenerConfig) -> Listener:
|
||||
root_tag = f"{lc.name.lower()}.{lc.payload_class.__name__.lower()}"
|
||||
|
||||
listener = Listener(
|
||||
name=lc.name,
|
||||
payload_class=lc.payload_class,
|
||||
handler=lc.handler,
|
||||
description=lc.description,
|
||||
is_agent=lc.is_agent,
|
||||
peers=lc.peers,
|
||||
broadcast=lc.broadcast,
|
||||
schema=self._generate_schema(lc.payload_class),
|
||||
root_tag=root_tag,
|
||||
)
|
||||
|
||||
if lc.is_agent:
|
||||
self.agent_semaphores[lc.name] = asyncio.Semaphore(
|
||||
self.config.max_concurrent_per_agent
|
||||
)
|
||||
|
||||
self.routing_table.setdefault(root_tag, []).append(listener)
|
||||
self.listeners[lc.name] = listener
|
||||
return listener
|
||||
|
||||
def register_all(self) -> None:
|
||||
for lc in self.config.listeners:
|
||||
self.register_listener(lc)
|
||||
|
||||
def _generate_schema(self, payload_class: type) -> etree.XMLSchema:
|
||||
"""Generate XSD schema from xmlified payload class."""
|
||||
if hasattr(payload_class, 'xsd'):
|
||||
xsd_tree = payload_class.xsd()
|
||||
return etree.XMLSchema(xsd_tree)
|
||||
# Fallback for non-xmlified classes (e.g., in tests)
|
||||
permissive = '<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema"><xs:any processContents="lax"/></xs:schema>'
|
||||
return etree.XMLSchema(etree.fromstring(permissive.encode()))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Stream Source
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _queue_source(self) -> AsyncIterable[MessageState]:
|
||||
"""Async generator that yields messages from the queue."""
|
||||
while self._running:
|
||||
try:
|
||||
state = await asyncio.wait_for(self.queue.get(), timeout=0.5)
|
||||
yield state
|
||||
self.queue.task_done()
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Pipeline Steps (as stream operators)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _route_step(self, state: MessageState) -> MessageState:
|
||||
"""Determine target listeners based on to_id.class format."""
|
||||
if state.error or state.payload is None:
|
||||
return state
|
||||
|
||||
payload_class_name = type(state.payload).__name__.lower()
|
||||
to_id = (state.to_id or "").lower()
|
||||
root_tag = f"{to_id}.{payload_class_name}" if to_id else payload_class_name
|
||||
|
||||
targets = self.routing_table.get(root_tag)
|
||||
if targets:
|
||||
state.target_listeners = targets
|
||||
else:
|
||||
state.error = f"No listener for: {root_tag}"
|
||||
|
||||
return state
|
||||
|
||||
async def _dispatch_to_handlers(self, state: MessageState) -> AsyncIterable[MessageState]:
|
||||
"""
|
||||
Fan-out step: Dispatch to handler(s) and yield response states.
|
||||
|
||||
For broadcast, yields one response per listener.
|
||||
Each response becomes a new message in the stream.
|
||||
"""
|
||||
if state.error or not state.target_listeners:
|
||||
# Pass through errors/unroutable for downstream handling
|
||||
yield state
|
||||
return
|
||||
|
||||
for listener in state.target_listeners:
|
||||
try:
|
||||
# Rate limiting for agents
|
||||
semaphore = self.agent_semaphores.get(listener.name)
|
||||
if semaphore:
|
||||
await semaphore.acquire()
|
||||
|
||||
try:
|
||||
metadata = HandlerMetadata(
|
||||
thread_id=state.thread_id or "",
|
||||
from_id=state.from_id or "",
|
||||
own_name=listener.name if listener.is_agent else None,
|
||||
)
|
||||
|
||||
response_bytes = await listener.handler(state.payload, metadata)
|
||||
|
||||
if not isinstance(response_bytes, bytes):
|
||||
response_bytes = b"<huh>Handler returned invalid type</huh>"
|
||||
|
||||
# Yield response — will be processed by next iteration
|
||||
yield MessageState(
|
||||
raw_bytes=response_bytes,
|
||||
thread_id=state.thread_id,
|
||||
from_id=listener.name,
|
||||
)
|
||||
|
||||
finally:
|
||||
if semaphore:
|
||||
semaphore.release()
|
||||
|
||||
except Exception as exc:
|
||||
yield MessageState(
|
||||
raw_bytes=f"<huh>Handler {listener.name} crashed: {exc}</huh>".encode(),
|
||||
thread_id=state.thread_id,
|
||||
from_id=listener.name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
async def _reinject_responses(self, state: MessageState) -> None:
|
||||
"""Push handler responses back into the queue for next iteration."""
|
||||
await self.queue.put(state)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build the Pipeline
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def build_pipeline(self, source: AsyncIterable[MessageState]):
|
||||
"""
|
||||
Construct the full processing pipeline.
|
||||
|
||||
This is where you configure the flow. Modify this method to:
|
||||
- Add/remove steps
|
||||
- Change concurrency limits
|
||||
- Insert logging/metrics
|
||||
- Add filtering
|
||||
"""
|
||||
|
||||
# The pipeline is a composition of stream operators
|
||||
pipeline = (
|
||||
stream.iterate(source)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 1: Envelope Processing (1:1 transforms)
|
||||
# ============================================================
|
||||
| pipe.map(repair_step)
|
||||
| pipe.map(c14n_step)
|
||||
| pipe.map(envelope_validation_step)
|
||||
| pipe.map(payload_extraction_step)
|
||||
| pipe.map(thread_assignment_step)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 2: Fan-out — Extract Multiple Payloads (1:N)
|
||||
# ============================================================
|
||||
# Handler responses may contain multiple payloads.
|
||||
# Each becomes a separate message in the stream.
|
||||
| pipe.flatmap(extract_payloads)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 3: Per-Payload Validation (1:1 transforms)
|
||||
# ============================================================
|
||||
# Note: In a real implementation, you'd route to listener-specific
|
||||
# validation here. For now, we use a simplified approach.
|
||||
| pipe.map(self._validate_and_deserialize)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 4: Routing (1:1)
|
||||
# ============================================================
|
||||
| pipe.map(self._route_step)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 5: Filter Errors
|
||||
# ============================================================
|
||||
# Errors go to a separate handler (could also be a branch)
|
||||
| pipe.map(self._handle_errors)
|
||||
| pipe.filter(lambda s: s.error is None and s.target_listeners)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 6: Fan-out — Dispatch to Handlers (1:N for broadcast)
|
||||
# ============================================================
|
||||
# This is where handlers are invoked. Broadcast = multiple yields.
|
||||
# task_limit controls concurrent handler invocations.
|
||||
| pipe.flatmap(
|
||||
self._dispatch_to_handlers,
|
||||
task_limit=self.config.max_concurrent_handlers
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# STAGE 7: Re-inject Responses
|
||||
# ============================================================
|
||||
# Handler responses go back into the queue for next iteration.
|
||||
# The cycle continues until no more messages.
|
||||
| pipe.action(self._reinject_responses)
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
async def _validate_and_deserialize(self, state: MessageState) -> MessageState:
|
||||
"""
|
||||
Combined validation + deserialization.
|
||||
|
||||
Uses to_id + payload tag to find the right listener and schema.
|
||||
"""
|
||||
if state.error or state.payload_tree is None:
|
||||
return state
|
||||
|
||||
# Build lookup key: to_id.payload_tag (matching routing table format)
|
||||
payload_tag = state.payload_tree.tag
|
||||
if payload_tag.startswith("{"):
|
||||
payload_tag = payload_tag.split("}", 1)[1]
|
||||
|
||||
to_id = (state.to_id or "").lower()
|
||||
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
|
||||
|
||||
listeners = self.routing_table.get(lookup_key, [])
|
||||
if not listeners:
|
||||
state.error = f"No listener for: {lookup_key}"
|
||||
return state
|
||||
|
||||
listener = listeners[0]
|
||||
|
||||
# Validate against listener's schema
|
||||
try:
|
||||
listener.schema.assertValid(state.payload_tree)
|
||||
except etree.DocumentInvalid as e:
|
||||
state.error = f"XSD validation failed: {e}"
|
||||
return state
|
||||
|
||||
# Deserialize
|
||||
try:
|
||||
from third_party.xmlable import parse_element
|
||||
state.payload = parse_element(listener.payload_class, state.payload_tree)
|
||||
except Exception as e:
|
||||
state.error = f"Deserialization failed: {e}"
|
||||
|
||||
return state
|
||||
|
||||
async def _handle_errors(self, state: MessageState) -> MessageState:
|
||||
"""Log errors (could also emit <huh> messages)."""
|
||||
if state.error:
|
||||
print(f"[ERROR] {state.thread_id}: {state.error}")
|
||||
# Could emit <huh> to a specific listener here
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Run the Pump
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def run(self) -> None:
|
||||
"""
|
||||
Main entry point — run the stream pipeline.
|
||||
|
||||
The pipeline pulls from the queue, processes messages,
|
||||
and re-injects handler responses. Continues until shutdown.
|
||||
"""
|
||||
self._running = True
|
||||
|
||||
pipeline = self.build_pipeline(self._queue_source())
|
||||
|
||||
try:
|
||||
async with pipeline.stream() as streamer:
|
||||
async for _ in streamer:
|
||||
# The pipeline drives itself via re-injection.
|
||||
# We just need to consume the stream.
|
||||
pass
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
self._running = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External API
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def inject(self, raw_bytes: bytes, thread_id: str, from_id: str) -> None:
|
||||
"""Inject a message to start processing."""
|
||||
state = MessageState(
|
||||
raw_bytes=raw_bytes,
|
||||
thread_id=thread_id,
|
||||
from_id=from_id,
|
||||
)
|
||||
await self.queue.put(state)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Graceful shutdown — wait for queue to drain."""
|
||||
self._running = False
|
||||
await self.queue.join()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Loader (same as before)
|
||||
# ============================================================================
|
||||
|
||||
class ConfigLoader:
|
||||
@classmethod
|
||||
def load(cls, path: str | Path) -> OrganismConfig:
|
||||
with open(Path(path)) as f:
|
||||
raw = yaml.safe_load(f)
|
||||
return cls._parse(raw)
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, raw: dict) -> OrganismConfig:
|
||||
org = raw.get("organism", {})
|
||||
config = OrganismConfig(
|
||||
name=org.get("name", "unnamed"),
|
||||
identity_path=org.get("identity", ""),
|
||||
port=org.get("port", 8765),
|
||||
thread_scheduling=raw.get("thread_scheduling", "breadth-first"),
|
||||
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
|
||||
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||
)
|
||||
|
||||
for entry in raw.get("listeners", []):
|
||||
lc = cls._parse_listener(entry)
|
||||
cls._resolve_imports(lc)
|
||||
config.listeners.append(lc)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def _parse_listener(cls, raw: dict) -> ListenerConfig:
|
||||
return ListenerConfig(
|
||||
name=raw["name"],
|
||||
payload_class_path=raw["payload_class"],
|
||||
handler_path=raw["handler"],
|
||||
description=raw["description"],
|
||||
is_agent=raw.get("agent", False),
|
||||
peers=raw.get("peers", []),
|
||||
broadcast=raw.get("broadcast", False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _resolve_imports(cls, lc: ListenerConfig) -> None:
|
||||
mod, cls_name = lc.payload_class_path.rsplit(".", 1)
|
||||
lc.payload_class = getattr(importlib.import_module(mod), cls_name)
|
||||
|
||||
mod, fn_name = lc.handler_path.rsplit(".", 1)
|
||||
lc.handler = getattr(importlib.import_module(mod), fn_name)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Bootstrap
|
||||
# ============================================================================
|
||||
|
||||
async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
||||
"""Load config and create pump."""
|
||||
config = ConfigLoader.load(config_path)
|
||||
print(f"Organism: {config.name}")
|
||||
print(f"Listeners: {len(config.listeners)}")
|
||||
|
||||
pump = StreamPump(config)
|
||||
pump.register_all()
|
||||
|
||||
print(f"Routing: {list(pump.routing_table.keys())}")
|
||||
return pump
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Example: Customizing the Pipeline
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
The beauty of aiostream: the pipeline is just a composition.
|
||||
You can easily insert, remove, or reorder stages.
|
||||
|
||||
# Add logging between stages:
|
||||
| pipe.action(lambda s: print(f"After repair: {s.thread_id}"))
|
||||
|
||||
# Add throttling:
|
||||
| pipe.map(some_step, task_limit=5)
|
||||
|
||||
# Branch errors to a separate stream:
|
||||
errors, valid = stream.partition(source, lambda s: s.error is not None)
|
||||
|
||||
# Merge multiple sources:
|
||||
combined = stream.merge(queue_source, oob_source, external_api_source)
|
||||
|
||||
# Add timeout per message:
|
||||
| pipe.timeout(30.0) # 30 second timeout per item
|
||||
|
||||
# Rate limit the whole pipeline:
|
||||
| pipe.spaceout(0.1) # 100ms between items
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Comparison: Old vs New
|
||||
# ============================================================================
|
||||
|
||||
"""
|
||||
OLD (bus.py):
|
||||
for payload in payloads:
|
||||
await pipeline.process(payload) # Sequential, recursive
|
||||
|
||||
NEW (aiostream):
|
||||
| pipe.flatmap(extract_payloads) # Fan-out, parallel
|
||||
| pipe.flatmap(dispatch, task_limit=20) # Concurrent handlers
|
||||
|
||||
The key difference:
|
||||
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
|
||||
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
|
||||
"""
|
||||
|
|
@ -20,8 +20,8 @@
|
|||
</xs:complexType>
|
||||
</xs:element>
|
||||
|
||||
<!-- Exactly one payload element from any foreign namespace -->
|
||||
<xs:any namespace="##other" processContents="lax" minOccurs="0" maxOccurs="1"/>
|
||||
<!-- Exactly one payload element (any namespace including no namespace) -->
|
||||
<xs:any namespace="##any" processContents="lax" minOccurs="0" maxOccurs="1"/>
|
||||
</xs:sequence>
|
||||
</xs:complexType>
|
||||
</xs:element>
|
||||
|
|
|
|||
24
config/organism.yaml
Normal file
24
config/organism.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# organism.yaml — Sample configuration for testing the message pump
|
||||
#
|
||||
# This defines a simple "hello world" organism with one listener.
|
||||
|
||||
organism:
|
||||
name: hello-world
|
||||
port: 8765
|
||||
|
||||
# Concurrency settings
|
||||
max_concurrent_pipelines: 50
|
||||
max_concurrent_handlers: 20
|
||||
max_concurrent_per_agent: 5
|
||||
|
||||
# Thread scheduling: breadth-first or depth-first
|
||||
thread_scheduling: breadth-first
|
||||
|
||||
listeners:
|
||||
# The greeter listener responds to Greeting payloads
|
||||
- name: greeter
|
||||
payload_class: handlers.hello.Greeting
|
||||
handler: handlers.hello.handle_greeting
|
||||
description: Responds with a greeting message
|
||||
agent: false
|
||||
broadcast: false
|
||||
1
handlers/__init__.py
Normal file
1
handlers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
# handlers — Sample handlers for testing the message pump
|
||||
78
handlers/hello.py
Normal file
78
handlers/hello.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
hello.py — Hello World handler for testing the message pump.
|
||||
|
||||
This module provides:
|
||||
- Greeting: payload class (what the handler receives)
|
||||
- GreetingResponse: response payload (what the handler returns)
|
||||
- handle_greeting: async handler function
|
||||
|
||||
Usage in organism.yaml:
|
||||
listeners:
|
||||
- name: greeter
|
||||
payload_class: handlers.hello.Greeting
|
||||
handler: handlers.hello.handle_greeting
|
||||
description: Responds with a greeting message
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from lxml import etree
|
||||
|
||||
from third_party.xmlable import xmlify
|
||||
from agentserver.message_bus.message_state import HandlerMetadata
|
||||
|
||||
|
||||
# Envelope namespace
|
||||
ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
|
||||
|
||||
|
||||
@xmlify
|
||||
@dataclass
|
||||
class Greeting:
|
||||
"""Incoming greeting request."""
|
||||
name: str
|
||||
|
||||
|
||||
@xmlify
|
||||
@dataclass
|
||||
class GreetingResponse:
|
||||
"""Outgoing greeting response."""
|
||||
message: str
|
||||
|
||||
|
||||
def wrap_in_envelope(payload_bytes: bytes, from_id: str, to_id: str, thread_id: str) -> bytes:
|
||||
"""Wrap a payload in a proper message envelope."""
|
||||
return f"""<message xmlns="{ENVELOPE_NS}">
|
||||
<meta>
|
||||
<from>{from_id}</from>
|
||||
<to>{to_id}</to>
|
||||
<thread>{thread_id}</thread>
|
||||
</meta>
|
||||
{payload_bytes.decode('utf-8')}
|
||||
</message>""".encode('utf-8')
|
||||
|
||||
|
||||
async def handle_greeting(payload: Greeting, metadata: HandlerMetadata) -> bytes:
|
||||
"""
|
||||
Handle an incoming Greeting and respond with a GreetingResponse.
|
||||
|
||||
Args:
|
||||
payload: The deserialized Greeting instance
|
||||
metadata: Contains thread_id, from_id, own_name
|
||||
|
||||
Returns:
|
||||
XML bytes of the response envelope
|
||||
"""
|
||||
# Create response
|
||||
response = GreetingResponse(message=f"Hello, {payload.name}!")
|
||||
|
||||
# Serialize to XML
|
||||
response_tree = response.xml_value("GreetingResponse")
|
||||
payload_bytes = etree.tostring(response_tree, encoding='utf-8')
|
||||
|
||||
# Wrap in envelope - respond back to sender
|
||||
return wrap_in_envelope(
|
||||
payload_bytes=payload_bytes,
|
||||
from_id=metadata.own_name or "greeter",
|
||||
to_id=metadata.from_id, # Send back to whoever sent the greeting
|
||||
thread_id=metadata.thread_id,
|
||||
)
|
||||
|
|
@ -5,15 +5,52 @@ build-backend = "setuptools.build_meta"
|
|||
|
||||
[project]
|
||||
name = "xml-pipeline"
|
||||
version = "0.1.0"
|
||||
version = "0.2.0"
|
||||
description = "Tamper-proof nervous system for multi-agent organisms"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11"
|
||||
license = {text = "MIT"}
|
||||
keywords = ["xml", "multi-agent", "message-bus", "aiostream"]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
"Framework :: AsyncIO",
|
||||
]
|
||||
dependencies = [
|
||||
"lxml",
|
||||
"websockets",
|
||||
"pyotp",
|
||||
"pyyaml",
|
||||
"cryptography",
|
||||
"aiostream>=0.5",
|
||||
"pyhumps",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = [
|
||||
"pytest>=7.0",
|
||||
"pytest-asyncio>=0.21",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=7.0",
|
||||
"pytest-asyncio>=0.21",
|
||||
"mypy",
|
||||
"ruff",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "function"
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py"]
|
||||
# Don't collect root __init__.py (has imports that break isolation)
|
||||
norecursedirs = [".git", "__pycache__", "*.egg-info"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["agentserver*", "third_party*"]
|
||||
|
|
|
|||
48
tests/conftest.py
Normal file
48
tests/conftest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""
|
||||
conftest.py — Shared pytest configuration and fixtures
|
||||
|
||||
This file is automatically loaded by pytest.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the project root is in the path for imports
|
||||
# BUT don't import the root package (it has heavy deps)
|
||||
project_root = Path(__file__).parent.parent
|
||||
if str(project_root) not in sys.path:
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Tell pytest to ignore the root __init__.py
|
||||
collect_ignore_glob = ["../__init__.py"]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Markers
|
||||
# ============================================================================
|
||||
|
||||
def pytest_configure(config):
|
||||
"""Register custom markers."""
|
||||
config.addinivalue_line(
|
||||
"markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
||||
)
|
||||
config.addinivalue_line(
|
||||
"markers", "integration: marks tests as integration tests"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures available to all tests
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def sample_thread_id():
|
||||
"""A valid UUID for testing."""
|
||||
return "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_from_id():
|
||||
"""A valid sender ID for testing."""
|
||||
return "calculator.add"
|
||||
632
tests/test_pipeline_steps.py
Normal file
632
tests/test_pipeline_steps.py
Normal file
|
|
@ -0,0 +1,632 @@
|
|||
"""
|
||||
test_pipeline_steps.py — Unit tests for individual pipeline steps
|
||||
|
||||
Run with: pytest tests/test_pipeline_steps.py -v
|
||||
|
||||
Each step is tested in isolation with known inputs and expected outputs.
|
||||
This makes debugging much easier than testing the full pipeline.
|
||||
|
||||
Install test dependencies:
|
||||
pip install -e ".[test]"
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from lxml import etree
|
||||
|
||||
# Import the message state
|
||||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
||||
|
||||
# Import individual steps
|
||||
from agentserver.message_bus.steps.repair import repair_step
|
||||
from agentserver.message_bus.steps.c14n import c14n_step
|
||||
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
||||
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
||||
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
||||
|
||||
# Check for optional dependencies
|
||||
try:
|
||||
import aiostream
|
||||
HAS_AIOSTREAM = True
|
||||
except ImportError:
|
||||
HAS_AIOSTREAM = False
|
||||
|
||||
requires_aiostream = pytest.mark.skipif(
|
||||
not HAS_AIOSTREAM,
|
||||
reason="aiostream not installed (pip install aiostream)"
|
||||
)
|
||||
|
||||
# Check for stream_pump dependencies
|
||||
try:
|
||||
from agentserver.message_bus.stream_pump import StreamPump, Listener
|
||||
from agentserver.message_bus.steps.routing_resolution import make_routing_step
|
||||
HAS_STREAM_PUMP = True
|
||||
except ImportError:
|
||||
HAS_STREAM_PUMP = False
|
||||
|
||||
requires_stream_pump = pytest.mark.skipif(
|
||||
not HAS_STREAM_PUMP,
|
||||
reason="stream_pump dependencies not available"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Test Fixtures
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture
|
||||
def valid_envelope_bytes():
|
||||
"""A well-formed message envelope matching envelope.xsd."""
|
||||
return b'''<?xml version="1.0"?>
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<meta>
|
||||
<from>calculator.add</from>
|
||||
<thread>550e8400-e29b-41d4-a716-446655440000</thread>
|
||||
</meta>
|
||||
<addpayload xmlns="https://xml-pipeline.org/ns/calculator/add/v1">
|
||||
<a>5</a>
|
||||
<b>3</b>
|
||||
</addpayload>
|
||||
</message>'''
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def malformed_xml_bytes():
|
||||
"""Malformed XML that lxml can partially recover."""
|
||||
return b'<message><unclosed><nested>content</nested></message>'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def completely_broken_bytes():
|
||||
"""Not XML at all."""
|
||||
return b'this is not xml at all { json: "maybe" }'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def multi_payload_response():
|
||||
"""Handler response with multiple payloads."""
|
||||
return b'''
|
||||
<search.result><answer>42</answer></search.result>
|
||||
<calculator.add.addpayload><a>1</a><b>2</b></calculator.add.addpayload>
|
||||
<thought>I should also check...</thought>
|
||||
'''
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_state():
|
||||
"""Fresh MessageState with no data."""
|
||||
return MessageState()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def state_with_bytes(valid_envelope_bytes):
|
||||
"""MessageState with raw_bytes populated."""
|
||||
return MessageState(raw_bytes=valid_envelope_bytes)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# repair_step Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestRepairStep:
|
||||
"""Tests for the XML repair/recovery step."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_xml_passes_through(self, valid_envelope_bytes):
|
||||
"""Valid XML should parse without error."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
result = await repair_step(state)
|
||||
|
||||
assert result.error is None
|
||||
assert result.envelope_tree is not None
|
||||
assert result.envelope_tree.tag == "{https://xml-pipeline.org/ns/envelope/v1}message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_xml_recovered(self, malformed_xml_bytes):
|
||||
"""Malformed XML should be recovered if possible."""
|
||||
state = MessageState(raw_bytes=malformed_xml_bytes)
|
||||
result = await repair_step(state)
|
||||
|
||||
# lxml recovery mode should produce something
|
||||
# May or may not have error depending on severity
|
||||
assert result.envelope_tree is not None or result.error is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_bytes_sets_error(self, empty_state):
|
||||
"""Missing raw_bytes should set an error."""
|
||||
result = await repair_step(empty_state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "no raw_bytes" in result.error
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clears_raw_bytes_after_parse(self, valid_envelope_bytes):
|
||||
"""raw_bytes should be cleared after successful parse (memory optimization)."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
result = await repair_step(state)
|
||||
|
||||
assert result.raw_bytes is None
|
||||
assert result.envelope_tree is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# c14n_step Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestC14nStep:
|
||||
"""Tests for the canonicalization step."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalizes_whitespace(self):
|
||||
"""C14N should normalize whitespace."""
|
||||
xml_with_whitespace = b'''<root>
|
||||
<child> value </child>
|
||||
</root>'''
|
||||
|
||||
state = MessageState(raw_bytes=xml_with_whitespace)
|
||||
state = await repair_step(state)
|
||||
result = await c14n_step(state)
|
||||
|
||||
assert result.error is None
|
||||
assert result.envelope_tree is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalizes_attribute_order(self):
|
||||
"""C14N should produce consistent attribute ordering."""
|
||||
xml_a = b'<root z="1" a="2"/>'
|
||||
xml_b = b'<root a="2" z="1"/>'
|
||||
|
||||
state_a = MessageState(raw_bytes=xml_a)
|
||||
state_a = await repair_step(state_a)
|
||||
state_a = await c14n_step(state_a)
|
||||
|
||||
state_b = MessageState(raw_bytes=xml_b)
|
||||
state_b = await repair_step(state_b)
|
||||
state_b = await c14n_step(state_b)
|
||||
|
||||
# Both should produce identical canonical form
|
||||
c14n_a = etree.tostring(state_a.envelope_tree, method="c14n")
|
||||
c14n_b = etree.tostring(state_b.envelope_tree, method="c14n")
|
||||
assert c14n_a == c14n_b
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tree_sets_error(self, empty_state):
|
||||
"""Missing envelope_tree should set error."""
|
||||
result = await c14n_step(empty_state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "no envelope_tree" in result.error
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# payload_extraction_step Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestPayloadExtractionStep:
|
||||
"""Tests for extracting payload from envelope."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_payload_element(self, valid_envelope_bytes):
|
||||
"""Should extract the payload element from envelope."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
# Skip envelope validation for this test
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.error is None
|
||||
assert result.payload_tree is not None
|
||||
# Tag may include namespace prefix
|
||||
assert "addpayload" in result.payload_tree.tag
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_thread_id(self, valid_envelope_bytes):
|
||||
"""Should extract thread ID from envelope."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.thread_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extracts_from_id(self, valid_envelope_bytes):
|
||||
"""Should extract sender ID from envelope."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.from_id == "calculator.add"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_payloads_error(self):
|
||||
"""Multiple payload elements should error."""
|
||||
multi_payload = b'''
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<meta>
|
||||
<from>test</from>
|
||||
<thread>uuid-here</thread>
|
||||
</meta>
|
||||
<payload1>data</payload1>
|
||||
<payload2>more data</payload2>
|
||||
</message>'''
|
||||
|
||||
state = MessageState(raw_bytes=multi_payload)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "multiple payload" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_payload_error(self):
|
||||
"""Missing payload element should error."""
|
||||
no_payload = b'''
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<meta>
|
||||
<from>test</from>
|
||||
<thread>uuid-here</thread>
|
||||
</meta>
|
||||
</message>'''
|
||||
|
||||
state = MessageState(raw_bytes=no_payload)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "no payload" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_meta_error(self):
|
||||
"""Missing <meta> block should error."""
|
||||
no_meta = b'''
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<payload>data</payload>
|
||||
</message>'''
|
||||
|
||||
state = MessageState(raw_bytes=no_meta)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "meta" in result.error.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_from_error(self):
|
||||
"""Missing <from> in <meta> should error."""
|
||||
no_from = b'''
|
||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<meta>
|
||||
<thread>uuid-here</thread>
|
||||
</meta>
|
||||
<payload>data</payload>
|
||||
</message>'''
|
||||
|
||||
state = MessageState(raw_bytes=no_from)
|
||||
state = await repair_step(state)
|
||||
state = await c14n_step(state)
|
||||
result = await payload_extraction_step(state)
|
||||
|
||||
assert result.error is not None
|
||||
assert "from" in result.error.lower()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# thread_assignment_step Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestThreadAssignmentStep:
|
||||
"""Tests for thread UUID assignment."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_uuid_preserved(self):
|
||||
"""Valid UUID should be preserved."""
|
||||
valid_uuid = "550e8400-e29b-41d4-a716-446655440000"
|
||||
state = MessageState(thread_id=valid_uuid)
|
||||
result = await thread_assignment_step(state)
|
||||
|
||||
assert result.thread_id == valid_uuid
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_uuid_generated(self, empty_state):
|
||||
"""Missing UUID should generate a new one."""
|
||||
result = await thread_assignment_step(empty_state)
|
||||
|
||||
assert result.thread_id is not None
|
||||
assert len(result.thread_id) == 36 # UUID format
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_uuid_replaced(self):
|
||||
"""Invalid UUID should be replaced with a new one."""
|
||||
state = MessageState(thread_id="not-a-valid-uuid")
|
||||
result = await thread_assignment_step(state)
|
||||
|
||||
assert result.thread_id != "not-a-valid-uuid"
|
||||
assert len(result.thread_id) == 36
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_replacement_logged_in_metadata(self):
|
||||
"""Replaced UUIDs should be logged in metadata."""
|
||||
state = MessageState(thread_id="bad-uuid")
|
||||
result = await thread_assignment_step(state)
|
||||
|
||||
diagnostics = result.metadata.get("diagnostics", [])
|
||||
assert len(diagnostics) > 0
|
||||
assert "bad-uuid" in diagnostics[0]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Multi-Payload Extraction Tests (standalone, no aiostream required)
|
||||
# ============================================================================
|
||||
|
||||
class TestPayloadExtractionLogic:
|
||||
"""Test the core payload extraction logic without aiostream."""
|
||||
|
||||
def test_extract_single_payload(self):
|
||||
"""Single root element should extract cleanly."""
|
||||
raw = b"<result>42</result>"
|
||||
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||
|
||||
children = list(tree)
|
||||
assert len(children) == 1
|
||||
assert children[0].tag == "result"
|
||||
assert children[0].text == "42"
|
||||
|
||||
def test_extract_multiple_payloads(self, multi_payload_response):
|
||||
"""Multiple root elements should all be extracted."""
|
||||
wrapped = b"<dummy>" + multi_payload_response + b"</dummy>"
|
||||
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||
|
||||
children = list(tree)
|
||||
assert len(children) == 3
|
||||
|
||||
tags = [c.tag for c in children]
|
||||
assert "search.result" in tags
|
||||
assert "calculator.add.addpayload" in tags
|
||||
assert "thought" in tags
|
||||
|
||||
def test_extract_preserves_content(self):
|
||||
"""Extracted payloads should preserve their content."""
|
||||
raw = b"<data><nested><deep>value</deep></nested></data>"
|
||||
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||
|
||||
children = list(tree)
|
||||
assert len(children) == 1
|
||||
|
||||
# Re-serialize and check
|
||||
extracted = etree.tostring(children[0])
|
||||
assert b"<deep>value</deep>" in extracted
|
||||
|
||||
def test_empty_response_no_crash(self):
|
||||
"""Empty response should not crash."""
|
||||
wrapped = b"<dummy></dummy>"
|
||||
tree = etree.fromstring(wrapped)
|
||||
|
||||
children = list(tree)
|
||||
assert len(children) == 0
|
||||
|
||||
def test_malformed_response_recovers(self):
|
||||
"""Malformed XML should be recovered if possible."""
|
||||
raw = b"<unclosed><valid>text</valid>"
|
||||
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||
|
||||
# With recovery parser
|
||||
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||
# Should get something, exact result depends on lxml recovery
|
||||
assert tree is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Multi-Payload Extraction Tests (from stream_pump.py)
|
||||
# ============================================================================
|
||||
|
||||
@requires_aiostream
|
||||
class TestMultiPayloadExtraction:
|
||||
"""Tests for the fan-out payload extraction."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_payload_yields_one(self):
|
||||
"""Single payload should yield one state."""
|
||||
from agentserver.message_bus.stream_pump import extract_payloads
|
||||
|
||||
state = MessageState(
|
||||
raw_bytes=b"<result>42</result>",
|
||||
thread_id="test-thread",
|
||||
from_id="test-sender",
|
||||
)
|
||||
|
||||
results = [s async for s in extract_payloads(state)]
|
||||
|
||||
assert len(results) == 1
|
||||
assert b"<result>" in results[0].raw_bytes
|
||||
assert results[0].thread_id == "test-thread"
|
||||
assert results[0].from_id == "test-sender"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_payloads_yields_many(self, multi_payload_response):
|
||||
"""Multiple payloads should yield multiple states."""
|
||||
from agentserver.message_bus.stream_pump import extract_payloads
|
||||
|
||||
state = MessageState(
|
||||
raw_bytes=multi_payload_response,
|
||||
thread_id="test-thread",
|
||||
from_id="agent",
|
||||
)
|
||||
|
||||
results = [s async for s in extract_payloads(state)]
|
||||
|
||||
assert len(results) == 3
|
||||
# Each result should have the same thread_id and from_id
|
||||
for r in results:
|
||||
assert r.thread_id == "test-thread"
|
||||
assert r.from_id == "agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response_yields_original(self):
|
||||
"""Empty response should yield original state."""
|
||||
from agentserver.message_bus.stream_pump import extract_payloads
|
||||
|
||||
state = MessageState(
|
||||
raw_bytes=b"",
|
||||
thread_id="test",
|
||||
from_id="test",
|
||||
)
|
||||
|
||||
results = [s async for s in extract_payloads(state)]
|
||||
|
||||
# Should yield something (original or empty handling)
|
||||
assert len(results) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_metadata(self):
|
||||
"""Extracted payloads should preserve metadata."""
|
||||
from agentserver.message_bus.stream_pump import extract_payloads
|
||||
|
||||
state = MessageState(
|
||||
raw_bytes=b"<a/><b/>",
|
||||
thread_id="test",
|
||||
from_id="test",
|
||||
metadata={"custom": "value"},
|
||||
)
|
||||
|
||||
results = [s async for s in extract_payloads(state)]
|
||||
|
||||
for r in results:
|
||||
assert r.metadata.get("custom") == "value"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Step Factory Tests
|
||||
# ============================================================================
|
||||
|
||||
@requires_stream_pump
|
||||
class TestStepFactories:
|
||||
"""Tests for the step factory functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_xsd_validation_direct(self):
|
||||
"""XSD validation via lxml schema."""
|
||||
# Create a simple schema
|
||||
xsd_str = '''
|
||||
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
|
||||
<xs:element name="test">
|
||||
<xs:complexType>
|
||||
<xs:sequence>
|
||||
<xs:element name="value" type="xs:integer"/>
|
||||
</xs:sequence>
|
||||
</xs:complexType>
|
||||
</xs:element>
|
||||
</xs:schema>'''
|
||||
schema = etree.XMLSchema(etree.fromstring(xsd_str.encode()))
|
||||
|
||||
# Valid payload
|
||||
valid_xml = etree.fromstring(b"<test><value>42</value></test>")
|
||||
assert schema.validate(valid_xml)
|
||||
|
||||
# Invalid payload
|
||||
invalid_xml = etree.fromstring(b"<test><value>not-an-int</value></test>")
|
||||
assert not schema.validate(invalid_xml)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routing_factory(self):
|
||||
"""Routing step should use injected routing table."""
|
||||
from agentserver.message_bus.steps.routing_resolution import make_routing_step
|
||||
from agentserver.message_bus.stream_pump import Listener
|
||||
|
||||
# Create mock listener
|
||||
mock_listener = Listener(
|
||||
name="calculator.add",
|
||||
payload_class=type("AddPayload", (), {}),
|
||||
handler=lambda x, m: b"<result/>",
|
||||
description="test",
|
||||
)
|
||||
|
||||
routing_table = {
|
||||
"calculator.add.addpayload": [mock_listener]
|
||||
}
|
||||
|
||||
step = make_routing_step(routing_table)
|
||||
|
||||
# Create a mock payload instance
|
||||
@dataclass
|
||||
class AddPayload:
|
||||
a: int = 0
|
||||
b: int = 0
|
||||
|
||||
state = MessageState(
|
||||
payload=AddPayload(a=1, b=2),
|
||||
to_id="calculator.add",
|
||||
)
|
||||
|
||||
result = await step(state)
|
||||
|
||||
assert result.error is None
|
||||
assert result.target_listeners == [mock_listener]
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Pipeline Integration Tests (lightweight)
|
||||
# ============================================================================
|
||||
|
||||
class TestPipelineIntegration:
|
||||
"""Integration tests for step sequences."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repair_through_extraction(self, valid_envelope_bytes):
|
||||
"""Test repair → c14n → extraction chain."""
|
||||
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||
|
||||
state = await repair_step(state)
|
||||
assert state.error is None, f"repair failed: {state.error}"
|
||||
|
||||
state = await c14n_step(state)
|
||||
assert state.error is None, f"c14n failed: {state.error}"
|
||||
|
||||
state = await payload_extraction_step(state)
|
||||
assert state.error is None, f"extraction failed: {state.error}"
|
||||
|
||||
assert state.payload_tree is not None
|
||||
assert state.thread_id is not None
|
||||
assert state.from_id is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_short_circuits(self):
|
||||
"""Errors should prevent downstream steps from running."""
|
||||
call_log = []
|
||||
|
||||
async def step_a(state):
|
||||
call_log.append("a")
|
||||
state.error = "Intentional error"
|
||||
return state
|
||||
|
||||
async def step_b(state):
|
||||
call_log.append("b")
|
||||
return state
|
||||
|
||||
# Simple pipeline runner (same logic as StreamPump uses)
|
||||
async def run_pipeline(steps, state):
|
||||
for step in steps:
|
||||
state = await step(state)
|
||||
if state.error:
|
||||
break
|
||||
return state
|
||||
|
||||
result = await run_pipeline([step_a, step_b], MessageState())
|
||||
|
||||
assert call_log == ["a"] # step_b should not have been called
|
||||
assert result.error == "Intentional error"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Run with pytest
|
||||
# ============================================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
398
tests/test_pump_integration.py
Normal file
398
tests/test_pump_integration.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""
|
||||
test_pump_integration.py — Integration tests for the StreamPump
|
||||
|
||||
Run with: pytest tests/test_pump_integration.py -v
|
||||
|
||||
These tests verify the full message flow through the pump:
|
||||
inject → parse → extract → validate → deserialize → route → handler → response
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from agentserver.message_bus import StreamPump, bootstrap, MessageState
|
||||
from agentserver.message_bus.stream_pump import ConfigLoader, ListenerConfig, OrganismConfig, Listener
|
||||
from handlers.hello import Greeting, GreetingResponse, handle_greeting, ENVELOPE_NS
|
||||
|
||||
|
||||
def make_envelope(payload_xml: str, from_id: str, to_id: str, thread_id: str) -> bytes:
|
||||
"""Helper to create a properly formatted envelope.
|
||||
|
||||
Note: payload_xml should include its own namespace (or xmlns="") to avoid
|
||||
inheriting the envelope namespace. The envelope XSD expects payload to be
|
||||
in a foreign namespace (##other).
|
||||
"""
|
||||
# Ensure payload has explicit namespace (empty string = no namespace)
|
||||
if 'xmlns=' not in payload_xml:
|
||||
# Insert xmlns="" after the tag name
|
||||
idx = payload_xml.index('>')
|
||||
if payload_xml[idx-1] == '/':
|
||||
idx -= 1
|
||||
payload_xml = payload_xml[:idx] + ' xmlns=""' + payload_xml[idx:]
|
||||
|
||||
return f"""<message xmlns="{ENVELOPE_NS}">
|
||||
<meta>
|
||||
<from>{from_id}</from>
|
||||
<to>{to_id}</to>
|
||||
<thread>{thread_id}</thread>
|
||||
</meta>
|
||||
{payload_xml}
|
||||
</message>""".encode('utf-8')
|
||||
|
||||
|
||||
class TestPumpBootstrap:
|
||||
"""Test ConfigLoader and bootstrap."""
|
||||
|
||||
def test_config_loader_parses_yaml(self):
|
||||
"""ConfigLoader should parse organism.yaml correctly."""
|
||||
config = ConfigLoader.load('config/organism.yaml')
|
||||
|
||||
assert config.name == "hello-world"
|
||||
assert len(config.listeners) == 1
|
||||
assert config.listeners[0].name == "greeter"
|
||||
assert config.listeners[0].payload_class == Greeting
|
||||
assert config.listeners[0].handler == handle_greeting
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_creates_pump(self):
|
||||
"""bootstrap() should create a configured pump."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
assert pump.config.name == "hello-world"
|
||||
assert "greeter.greeting" in pump.routing_table
|
||||
assert pump.listeners["greeter"].payload_class == Greeting
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bootstrap_generates_xsd(self):
|
||||
"""bootstrap() should generate XSD schemas for listeners."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
listener = pump.listeners["greeter"]
|
||||
assert listener.schema is not None
|
||||
|
||||
# Schema should validate a proper Greeting
|
||||
from lxml import etree
|
||||
valid_xml = etree.fromstring(b"<Greeting><Name>Test</Name></Greeting>")
|
||||
listener.schema.assertValid(valid_xml)
|
||||
|
||||
|
||||
class TestPumpInjection:
|
||||
"""Test message injection and queue behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inject_adds_to_queue(self):
|
||||
"""inject() should add a MessageState to the queue."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
await pump.inject(b"<test/>", thread_id, from_id="user")
|
||||
|
||||
assert pump.queue.qsize() == 1
|
||||
state = await pump.queue.get()
|
||||
assert state.raw_bytes == b"<test/>"
|
||||
assert state.thread_id == thread_id
|
||||
assert state.from_id == "user"
|
||||
|
||||
|
||||
class TestFullPipelineFlow:
|
||||
"""Test complete message flow through the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_greeting_round_trip(self):
|
||||
"""
|
||||
Full integration test:
|
||||
1. Inject a Greeting message
|
||||
2. Pump processes it through the pipeline
|
||||
3. Handler is called with deserialized Greeting
|
||||
4. Handler response is re-injected
|
||||
"""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
# Track what the handler receives
|
||||
handler_calls = []
|
||||
original_handler = pump.listeners["greeter"].handler
|
||||
|
||||
async def tracking_handler(payload, metadata):
|
||||
handler_calls.append((payload, metadata))
|
||||
return await original_handler(payload, metadata)
|
||||
|
||||
pump.listeners["greeter"].handler = tracking_handler
|
||||
|
||||
# Create and inject a Greeting message
|
||||
thread_id = str(uuid.uuid4())
|
||||
envelope = make_envelope(
|
||||
payload_xml="<Greeting><Name>World</Name></Greeting>",
|
||||
from_id="user",
|
||||
to_id="greeter",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
await pump.inject(envelope, thread_id, from_id="user")
|
||||
|
||||
# Run pump briefly to process the message
|
||||
pump._running = True
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
# Process with timeout
|
||||
async def run_with_timeout():
|
||||
async with pipeline.stream() as streamer:
|
||||
try:
|
||||
async for _ in streamer:
|
||||
# One iteration should process our message
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Verify handler was called
|
||||
assert len(handler_calls) == 1
|
||||
payload, metadata = handler_calls[0]
|
||||
|
||||
assert isinstance(payload, Greeting)
|
||||
assert payload.name == "World"
|
||||
assert metadata.thread_id == thread_id
|
||||
assert metadata.from_id == "user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handler_response_reinjected(self):
|
||||
"""Handler response should be re-injected into the queue."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
# Capture re-injected messages
|
||||
reinjected = []
|
||||
original_reinject = pump._reinject_responses
|
||||
|
||||
async def capture_reinject(state):
|
||||
reinjected.append(state)
|
||||
# Don't actually re-inject to avoid infinite loop
|
||||
|
||||
pump._reinject_responses = capture_reinject
|
||||
|
||||
# Inject a Greeting
|
||||
thread_id = str(uuid.uuid4())
|
||||
envelope = make_envelope(
|
||||
payload_xml="<Greeting><Name>Alice</Name></Greeting>",
|
||||
from_id="user",
|
||||
to_id="greeter",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
await pump.inject(envelope, thread_id, from_id="user")
|
||||
|
||||
# Run pump briefly
|
||||
pump._running = True
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
async def run_with_timeout():
|
||||
async with pipeline.stream() as streamer:
|
||||
try:
|
||||
async for _ in streamer:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Verify response was re-injected
|
||||
assert len(reinjected) == 1
|
||||
response_state = reinjected[0]
|
||||
|
||||
assert response_state.raw_bytes is not None
|
||||
assert b"Hello, Alice!" in response_state.raw_bytes
|
||||
assert response_state.thread_id == thread_id
|
||||
assert response_state.from_id == "greeter"
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error paths through the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_xml_error(self):
|
||||
"""Malformed XML should set error, not crash."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
errors = []
|
||||
original_handle_errors = pump._handle_errors
|
||||
|
||||
async def capture_errors(state):
|
||||
if state.error:
|
||||
errors.append(state.error)
|
||||
return await original_handle_errors(state)
|
||||
|
||||
pump._handle_errors = capture_errors
|
||||
|
||||
# Inject malformed XML
|
||||
thread_id = str(uuid.uuid4())
|
||||
await pump.inject(b"<not valid xml", thread_id, from_id="user")
|
||||
|
||||
# Run pump
|
||||
pump._running = True
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
async def run_with_timeout():
|
||||
async with pipeline.stream() as streamer:
|
||||
try:
|
||||
async for _ in streamer:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Should have logged an error (repair step recovers, but envelope validation fails)
|
||||
# The exact error depends on how far it gets
|
||||
assert pump.queue.qsize() == 0 or len(errors) >= 0 # Processed without crash
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_route_error(self):
|
||||
"""Message to unknown listener should error gracefully."""
|
||||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
errors = []
|
||||
original_handle_errors = pump._handle_errors
|
||||
|
||||
async def capture_errors(state):
|
||||
if state.error:
|
||||
errors.append(state.error)
|
||||
return await original_handle_errors(state)
|
||||
|
||||
pump._handle_errors = capture_errors
|
||||
|
||||
# Inject message to non-existent listener
|
||||
thread_id = str(uuid.uuid4())
|
||||
envelope = make_envelope(
|
||||
payload_xml="<Greeting><Name>Test</Name></Greeting>",
|
||||
from_id="user",
|
||||
to_id="nonexistent", # No such listener
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
await pump.inject(envelope, thread_id, from_id="user")
|
||||
|
||||
# Run pump
|
||||
pump._running = True
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
async def run_with_timeout():
|
||||
async with pipeline.stream() as streamer:
|
||||
try:
|
||||
async for _ in streamer:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Should have a routing error
|
||||
assert any("nonexistent" in e for e in errors)
|
||||
|
||||
|
||||
class TestManualPumpConfiguration:
|
||||
"""Test creating a pump without YAML config."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_manual_listener_registration(self):
|
||||
"""Can register listeners programmatically."""
|
||||
config = OrganismConfig(name="manual-test")
|
||||
pump = StreamPump(config)
|
||||
|
||||
lc = ListenerConfig(
|
||||
name="greeter",
|
||||
payload_class_path="handlers.hello.Greeting",
|
||||
handler_path="handlers.hello.handle_greeting",
|
||||
description="Test listener",
|
||||
payload_class=Greeting,
|
||||
handler=handle_greeting,
|
||||
)
|
||||
|
||||
listener = pump.register_listener(lc)
|
||||
|
||||
assert listener.name == "greeter"
|
||||
assert listener.root_tag == "greeter.greeting"
|
||||
assert "greeter.greeting" in pump.routing_table
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_handler(self):
|
||||
"""Can use a custom handler function."""
|
||||
config = OrganismConfig(name="custom-test")
|
||||
pump = StreamPump(config)
|
||||
|
||||
responses = []
|
||||
|
||||
async def custom_handler(payload, metadata):
|
||||
responses.append(payload)
|
||||
return b"<Ack/>"
|
||||
|
||||
lc = ListenerConfig(
|
||||
name="custom",
|
||||
payload_class_path="handlers.hello.Greeting",
|
||||
handler_path="handlers.hello.handle_greeting",
|
||||
description="Custom handler",
|
||||
payload_class=Greeting,
|
||||
handler=custom_handler,
|
||||
)
|
||||
|
||||
pump.register_listener(lc)
|
||||
|
||||
# Inject and process
|
||||
thread_id = str(uuid.uuid4())
|
||||
envelope = make_envelope(
|
||||
payload_xml="<Greeting><Name>Custom</Name></Greeting>",
|
||||
from_id="tester",
|
||||
to_id="custom",
|
||||
thread_id=thread_id,
|
||||
)
|
||||
|
||||
await pump.inject(envelope, thread_id, from_id="tester")
|
||||
|
||||
# Run pump
|
||||
pump._running = True
|
||||
|
||||
# Capture re-injected to prevent loop
|
||||
async def noop_reinject(state):
|
||||
pass
|
||||
pump._reinject_responses = noop_reinject
|
||||
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
async def run_with_timeout():
|
||||
async with pipeline.stream() as streamer:
|
||||
try:
|
||||
async for _ in streamer:
|
||||
break
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Custom handler should have been called
|
||||
assert len(responses) == 1
|
||||
assert responses[0].name == "Custom"
|
||||
5
third_party/xmlable/__init__.py
vendored
5
third_party/xmlable/__init__.py
vendored
|
|
@ -7,6 +7,7 @@ from typing import Type, TypeVar, Any
|
|||
from io import BytesIO
|
||||
|
||||
from ._xmlify import xmlify
|
||||
from ._errors import XErrorCtx
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -19,7 +20,9 @@ def parse_element(cls: Type[T], element: _Element | ObjectifiedElement) -> T:
|
|||
"""Direct in-memory deserialization from validated lxml Element."""
|
||||
xobject = _get_xobject(cls)
|
||||
obj_element = objectify.fromstring(etree.tostring(element))
|
||||
return xobject.xml_in(obj_element, ctx=None)
|
||||
# Create a root context for error tracing
|
||||
ctx = XErrorCtx(trace=[cls.__name__])
|
||||
return xobject.xml_in(obj_element, ctx=ctx)
|
||||
|
||||
def parse_bytes(cls: Type[T], xml_bytes: bytes) -> T:
|
||||
tree = objectify.parse(BytesIO(xml_bytes))
|
||||
|
|
|
|||
2
third_party/xmlable/_errors.py
vendored
2
third_party/xmlable/_errors.py
vendored
|
|
@ -9,7 +9,7 @@ from typing import Any, Iterable
|
|||
from termcolor import colored
|
||||
from termcolor.termcolor import Color
|
||||
|
||||
from xmlable._utils import typename, AnyType
|
||||
from ._utils import typename, AnyType
|
||||
|
||||
|
||||
def trace_note(trace: list[str], arrow_c: Color, node_c: Color):
|
||||
|
|
|
|||
6
third_party/xmlable/_io.py
vendored
6
third_party/xmlable/_io.py
vendored
|
|
@ -10,9 +10,9 @@ from termcolor import colored
|
|||
from lxml.objectify import parse as objectify_parse
|
||||
from lxml.etree import _ElementTree
|
||||
|
||||
from xmlable._utils import typename
|
||||
from xmlable._xobject import is_xmlified
|
||||
from xmlable._errors import ErrorTypes
|
||||
from ._utils import typename
|
||||
from ._xobject import is_xmlified
|
||||
from ._errors import ErrorTypes
|
||||
|
||||
|
||||
def write_file(file_path: str | Path, tree: _ElementTree):
|
||||
|
|
|
|||
6
third_party/xmlable/_manual.py
vendored
6
third_party/xmlable/_manual.py
vendored
|
|
@ -8,9 +8,9 @@ from typing import Any
|
|||
from lxml.etree import _Element, Element, _ElementTree, ElementTree
|
||||
from lxml.objectify import ObjectifiedElement
|
||||
|
||||
from xmlable._utils import typename, AnyType, ordered_iter
|
||||
from xmlable._lxml_helpers import with_children, XMLSchema
|
||||
from xmlable._errors import XError, XErrorCtx, ErrorTypes
|
||||
from ._utils import typename, AnyType, ordered_iter
|
||||
from ._lxml_helpers import with_children, XMLSchema
|
||||
from ._errors import XError, XErrorCtx, ErrorTypes
|
||||
|
||||
|
||||
def validate_manual_class(cls: AnyType):
|
||||
|
|
|
|||
4
third_party/xmlable/_user.py
vendored
4
third_party/xmlable/_user.py
vendored
|
|
@ -6,8 +6,8 @@ The IXmlify interface
|
|||
|
||||
from abc import ABC, abstractmethod
|
||||
from lxml.etree import _Element
|
||||
from xmlable._xobject import XObject
|
||||
from xmlable._utils import AnyType
|
||||
from ._xobject import XObject
|
||||
from ._utils import AnyType
|
||||
|
||||
|
||||
class IXmlify(ABC):
|
||||
|
|
|
|||
10
third_party/xmlable/_xmlify.py
vendored
10
third_party/xmlable/_xmlify.py
vendored
|
|
@ -15,11 +15,11 @@ from typing import Any, dataclass_transform, cast
|
|||
from lxml.objectify import ObjectifiedElement
|
||||
from lxml.etree import Element, _Element
|
||||
|
||||
from xmlable._utils import get, typename, AnyType
|
||||
from xmlable._errors import XError, XErrorCtx, ErrorTypes
|
||||
from xmlable._manual import manual_xmlify
|
||||
from xmlable._lxml_helpers import with_children, with_child, XMLSchema
|
||||
from xmlable._xobject import XObject, gen_xobject
|
||||
from ._utils import get, typename, AnyType
|
||||
from ._errors import XError, XErrorCtx, ErrorTypes
|
||||
from ._manual import manual_xmlify
|
||||
from ._lxml_helpers import with_children, with_child, XMLSchema
|
||||
from ._xobject import XObject, gen_xobject
|
||||
|
||||
|
||||
def validate_class(cls: AnyType):
|
||||
|
|
|
|||
6
third_party/xmlable/_xobject.py
vendored
6
third_party/xmlable/_xobject.py
vendored
|
|
@ -13,9 +13,9 @@ from abc import ABC, abstractmethod
|
|||
from typing import Any, Callable, Type, get_args, TypeAlias, cast
|
||||
from types import GenericAlias
|
||||
|
||||
from xmlable._utils import get, typename, firstkey, AnyType
|
||||
from xmlable._errors import XErrorCtx, ErrorTypes
|
||||
from xmlable._lxml_helpers import (
|
||||
from ._utils import get, typename, firstkey, AnyType
|
||||
from ._errors import XErrorCtx, ErrorTypes
|
||||
from ._lxml_helpers import (
|
||||
with_text,
|
||||
with_child,
|
||||
with_children,
|
||||
|
|
|
|||
Loading…
Reference in a new issue