Replace MessageBus with aiostream-based StreamPump
Major refactor of the message pump architecture: - Replace bus.py with stream_pump.py using aiostream for composable stream processing with natural fan-out via flatmap - Add to_id field to MessageState for explicit routing - Fix routing to use to_id.class format (e.g., "greeter.greeting") - Generate XSD schemas from xmlified payload classes - Fix xmlable imports (absolute -> relative) and parse_element ctx New features: - handlers/hello.py: Sample Greeting/GreetingResponse handler - config/organism.yaml: Sample organism configuration - 41 tests (31 unit + 10 integration) all passing Schema changes: - envelope.xsd: Allow any namespace payloads (##other -> ##any) Dependencies added to pyproject.toml: - aiostream>=0.5 (core dependency) - pyhumps, termcolor (for xmlable) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
dc16316aed
commit
82b5fcdd78
24 changed files with 2018 additions and 343 deletions
27
__init__.py
27
__init__.py
|
|
@ -3,19 +3,26 @@
|
||||||
xml-pipeline
|
xml-pipeline
|
||||||
============
|
============
|
||||||
Secure, XML-centric multi-listener organism server.
|
Secure, XML-centric multi-listener organism server.
|
||||||
|
|
||||||
|
Stream-based message pump with aiostream for fan-out handling.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from agentserver.agentserver import AgentServer as AgentServer
|
from agentserver.message_bus import (
|
||||||
from agentserver.xml_listener import XMLListener as XMLListener
|
StreamPump,
|
||||||
from agentserver.message_bus import MessageBus as MessageBus
|
ConfigLoader,
|
||||||
from agentserver.message_bus import Session as Session
|
Listener,
|
||||||
|
MessageState,
|
||||||
|
HandlerMetadata,
|
||||||
|
bootstrap,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentServer",
|
"StreamPump",
|
||||||
"XMLListener",
|
"ConfigLoader",
|
||||||
"MessageBus",
|
"Listener",
|
||||||
"Session",
|
"MessageState",
|
||||||
|
"HandlerMetadata",
|
||||||
|
"bootstrap",
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.2.0" # Bumped for aiostream pump
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
"""
|
||||||
|
message_bus — Stream-based message pump for AgentServer v2.1
|
||||||
|
|
||||||
|
The message pump handles message flow through the organism:
|
||||||
|
- YAML config → bootstrap → pump → handlers → responses → loop
|
||||||
|
|
||||||
|
Key classes:
|
||||||
|
StreamPump Main pump class (queue-backed, aiostream-powered)
|
||||||
|
ConfigLoader Load organism.yaml and resolve imports
|
||||||
|
Listener Runtime listener with handler and routing info
|
||||||
|
MessageState Message flowing through pipeline steps
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agentserver.message_bus import StreamPump, bootstrap
|
||||||
|
|
||||||
|
pump = await bootstrap("config/organism.yaml")
|
||||||
|
await pump.inject(initial_message, thread_id, from_id)
|
||||||
|
await pump.run()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agentserver.message_bus.stream_pump import (
|
||||||
|
StreamPump,
|
||||||
|
ConfigLoader,
|
||||||
|
Listener,
|
||||||
|
ListenerConfig,
|
||||||
|
OrganismConfig,
|
||||||
|
bootstrap,
|
||||||
|
)
|
||||||
|
|
||||||
|
from agentserver.message_bus.message_state import (
|
||||||
|
MessageState,
|
||||||
|
HandlerMetadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"StreamPump",
|
||||||
|
"ConfigLoader",
|
||||||
|
"Listener",
|
||||||
|
"ListenerConfig",
|
||||||
|
"OrganismConfig",
|
||||||
|
"MessageState",
|
||||||
|
"HandlerMetadata",
|
||||||
|
"bootstrap",
|
||||||
|
]
|
||||||
|
|
@ -1,226 +0,0 @@
|
||||||
"""
|
|
||||||
bus.py — The central MessageBus and pump for AgentServer v2.1
|
|
||||||
|
|
||||||
This is the beating heart of the organism:
|
|
||||||
- Owns all pipelines (one per listener + permanent system pipeline)
|
|
||||||
- Maintains the routing table (root_tag → Listener(s))
|
|
||||||
- Orchestrates ingress from sockets/gateways
|
|
||||||
- Dispatches prepared messages to handlers
|
|
||||||
- Processes handler responses (multi-payload extraction, provenance injection)
|
|
||||||
- Guarantees thread continuity and diagnostic injection
|
|
||||||
|
|
||||||
Fully aligned with:
|
|
||||||
- listener-class-v2.1.md
|
|
||||||
- configuration-v2.1.md
|
|
||||||
- message-pump-v2.1.md
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Callable, Awaitable, List
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from lxml import etree
|
|
||||||
|
|
||||||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
|
||||||
from agentserver.message_bus.steps.repair import repair_step
|
|
||||||
from agentserver.message_bus.steps.c14n import c14n_step
|
|
||||||
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
|
||||||
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
|
||||||
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
|
||||||
from agentserver.message_bus.steps.xsd_validation import xsd_validation_step
|
|
||||||
from agentserver.message_bus.steps.deserialization import deserialization_step
|
|
||||||
from agentserver.message_bus.steps.routing_resolution import routing_resolution_step
|
|
||||||
|
|
||||||
# Type alias for pipeline steps
|
|
||||||
PipelineStep = Callable[[MessageState], Awaitable[MessageState]]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Listener:
|
|
||||||
"""Registered capability — defined in listener.py, referenced here."""
|
|
||||||
name: str
|
|
||||||
payload_class: type
|
|
||||||
handler: Callable
|
|
||||||
description: str
|
|
||||||
is_agent: bool = False
|
|
||||||
peers: list[str] | None = None
|
|
||||||
broadcast: bool = False
|
|
||||||
pipeline: "Pipeline" | None = None
|
|
||||||
schema: etree.XMLSchema | None = None # cached at registration
|
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
|
||||||
"""One dedicated pipeline per listener (plus system pipeline)."""
|
|
||||||
def __init__(self, steps: List[PipelineStep]):
|
|
||||||
self.steps = steps
|
|
||||||
|
|
||||||
async def process(self, initial_state: MessageState) -> None:
|
|
||||||
"""Run the full ordered pipeline on a message."""
|
|
||||||
state = initial_state
|
|
||||||
for step in self.steps:
|
|
||||||
try:
|
|
||||||
state = await step(state)
|
|
||||||
if state.error:
|
|
||||||
break
|
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
|
||||||
state.error = f"Pipeline step {step.__name__} crashed: {exc}"
|
|
||||||
break
|
|
||||||
|
|
||||||
# After all steps — dispatch if routable
|
|
||||||
if state.target_listeners:
|
|
||||||
await MessageBus.get_instance().dispatcher(state)
|
|
||||||
else:
|
|
||||||
# Fall back to system pipeline for diagnostics
|
|
||||||
await MessageBus.get_instance().system_pipeline.process(state)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageBus:
|
|
||||||
"""Singleton message bus — the pump."""
|
|
||||||
_instance: "MessageBus" | None = None
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s)
|
|
||||||
self.listeners: dict[str, Listener] = {} # name → Listener
|
|
||||||
self.system_pipeline = Pipeline(self._build_system_steps())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_instance(cls) -> "MessageBus":
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = MessageBus()
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
# Default step lists
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
def _build_default_listener_steps(self) -> List[PipelineStep]:
|
|
||||||
return [
|
|
||||||
repair_step,
|
|
||||||
c14n_step,
|
|
||||||
envelope_validation_step,
|
|
||||||
payload_extraction_step,
|
|
||||||
thread_assignment_step,
|
|
||||||
xsd_validation_step,
|
|
||||||
deserialization_step,
|
|
||||||
routing_resolution_step,
|
|
||||||
]
|
|
||||||
|
|
||||||
def _build_system_steps(self) -> List[PipelineStep]:
|
|
||||||
"""Shorter, fixed steps — no XSD/deserialization."""
|
|
||||||
return [
|
|
||||||
repair_step,
|
|
||||||
c14n_step,
|
|
||||||
envelope_validation_step,
|
|
||||||
payload_extraction_step,
|
|
||||||
thread_assignment_step,
|
|
||||||
# system-specific handler that emits <huh>, boot, etc.
|
|
||||||
self.system_handler_step,
|
|
||||||
]
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
# Registration (called from listener.py)
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
def register_listener(self, listener: Listener) -> None:
|
|
||||||
root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}"
|
|
||||||
|
|
||||||
if root_tag in self.routing_table and not listener.broadcast:
|
|
||||||
raise ValueError(f"Root tag collision: {root_tag} already registered by {self.routing_table[root_tag][0].name}")
|
|
||||||
|
|
||||||
# Build dedicated pipeline
|
|
||||||
steps = self._build_default_listener_steps()
|
|
||||||
# Inject listener-specific schema for xsd_validation_step
|
|
||||||
for step in steps:
|
|
||||||
if step.__name__ == "xsd_validation_step":
|
|
||||||
# We'll modify state.metadata in pipeline construction instead
|
|
||||||
pass
|
|
||||||
listener.pipeline = Pipeline(steps)
|
|
||||||
|
|
||||||
# Insert into routing
|
|
||||||
self.routing_table.setdefault(root_tag, []).append(listener)
|
|
||||||
self.listeners[listener.name] = listener
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
# Dispatcher — dumb fire-and-await
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
async def dispatcher(self, state: MessageState) -> None:
|
|
||||||
if not state.target_listeners:
|
|
||||||
return
|
|
||||||
|
|
||||||
metadata = HandlerMetadata(
|
|
||||||
thread_id=state.thread_id or "",
|
|
||||||
from_id=state.from_id or "unknown",
|
|
||||||
own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None,
|
|
||||||
is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(state.target_listeners) == 1:
|
|
||||||
listener = state.target_listeners[0]
|
|
||||||
await self._process_single_handler(state, listener, metadata)
|
|
||||||
else:
|
|
||||||
# Broadcast — fire all in parallel, process responses as they complete
|
|
||||||
tasks = [
|
|
||||||
self._process_single_handler(state, listener, metadata)
|
|
||||||
for listener in state.target_listeners
|
|
||||||
]
|
|
||||||
for future in asyncio.as_completed(tasks):
|
|
||||||
await future
|
|
||||||
|
|
||||||
async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None:
|
|
||||||
try:
|
|
||||||
response_bytes = await listener.handler(state.payload, metadata)
|
|
||||||
|
|
||||||
if response_bytes is None or not isinstance(response_bytes, bytes):
|
|
||||||
response_bytes = b"<huh>Handler failed to return valid bytes — missing return or wrong type</huh>"
|
|
||||||
|
|
||||||
payloads = await self._multi_payload_extract(response_bytes)
|
|
||||||
|
|
||||||
for payload_bytes in payloads:
|
|
||||||
new_state = MessageState(
|
|
||||||
raw_bytes=payload_bytes,
|
|
||||||
thread_id=state.thread_id,
|
|
||||||
from_id=listener.name,
|
|
||||||
)
|
|
||||||
# Route the new payload through normal pipelines
|
|
||||||
root_tag = self._derive_root_tag(payload_bytes)
|
|
||||||
targets = self.routing_table.get(root_tag)
|
|
||||||
if targets:
|
|
||||||
new_state.target_listeners = targets
|
|
||||||
await targets[0].pipeline.process(new_state)
|
|
||||||
else:
|
|
||||||
await self.system_pipeline.process(new_state)
|
|
||||||
|
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
|
||||||
error_state = MessageState(
|
|
||||||
raw_bytes=b"<huh>Handler crashed</huh>",
|
|
||||||
thread_id=state.thread_id,
|
|
||||||
from_id=listener.name,
|
|
||||||
error=f"Handler {listener.name} crashed: {exc}",
|
|
||||||
)
|
|
||||||
await self.system_pipeline.process(error_state)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
# Helper methods
|
|
||||||
# ------------------------------------------------------------------ #
|
|
||||||
async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]:
|
|
||||||
# Same logic as before — dummy wrap, repair, extract all root elements
|
|
||||||
# (implementation can be moved to a shared util later)
|
|
||||||
# For now, placeholder — we'll flesh this out in response_processing.py
|
|
||||||
return [raw_bytes] # temporary — will be full extraction
|
|
||||||
|
|
||||||
def _derive_root_tag(self, payload_bytes: bytes) -> str:
|
|
||||||
# Quick parse to get root tag — used only for routing extracted payloads
|
|
||||||
try:
|
|
||||||
tree = etree.fromstring(payload_bytes)
|
|
||||||
tag = tree.tag
|
|
||||||
if tag.startswith("{"):
|
|
||||||
return tag.split("}", 1)[1] # strip namespace
|
|
||||||
return tag
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
async def system_handler_step(self, state: MessageState) -> MessageState:
|
|
||||||
# Emit <huh> or boot message — placeholder for now
|
|
||||||
state.error = state.error or "Unhandled by any listener"
|
|
||||||
return state
|
|
||||||
|
|
@ -1,6 +1,12 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from lxml.etree import Element
|
from typing import Any, TYPE_CHECKING
|
||||||
from typing import Any
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from lxml.etree import _Element as Element
|
||||||
|
else:
|
||||||
|
Element = Any # Runtime: don't need the actual type
|
||||||
|
|
||||||
"""
|
"""
|
||||||
default_listener_steps = [
|
default_listener_steps = [
|
||||||
|
|
@ -33,6 +39,7 @@ class MessageState:
|
||||||
|
|
||||||
thread_id: str | None = None
|
thread_id: str | None = None
|
||||||
from_id: str | None = None
|
from_id: str | None = None
|
||||||
|
to_id: str | None = None # Target listener name for routing
|
||||||
|
|
||||||
target_listeners: list['Listener'] | None = None # Forward reference
|
target_listeners: list['Listener'] | None = None # Forward reference
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,29 +1,29 @@
|
||||||
"""
|
"""
|
||||||
deserialization.py — Convert validated payload_tree into typed dataclass instance.
|
deserialization.py — Convert validated payload_tree into typed dataclass instance.
|
||||||
|
|
||||||
After xsd_validation_step confirms the payload conforms to the listener's contract,
|
After xsd_validation_step confirms the payload conforms to the contract,
|
||||||
this step uses the xmlable library to deserialize the lxml Element into the
|
this step uses our customized xmlable routines to deserialize the lxml Element
|
||||||
registered @xmlify dataclass.
|
directly in memory — no temporary files needed.
|
||||||
|
|
||||||
The resulting instance is placed in state.payload and handed to the handler.
|
|
||||||
|
|
||||||
Part of AgentServer v2.1 message pump.
|
Part of AgentServer v2.1 message pump.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from xmlable import from_xml # from the xmlable library
|
from lxml.etree import _Element
|
||||||
from agentserver.message_bus.message_state import MessageState
|
from agentserver.message_bus.message_state import MessageState
|
||||||
|
|
||||||
|
# Import the customized parse_element from your forked xmlable
|
||||||
|
from third_party.xmlable import parse_element # adjust path if needed
|
||||||
|
|
||||||
|
|
||||||
async def deserialization_step(state: MessageState) -> MessageState:
|
async def deserialization_step(state: MessageState) -> MessageState:
|
||||||
"""
|
"""
|
||||||
Deserialize the validated payload_tree into the listener's dataclass.
|
Deserialize the validated payload_tree into the listener's @xmlify dataclass.
|
||||||
|
|
||||||
Requires:
|
Requires:
|
||||||
- state.payload_tree valid against listener XSD
|
- state.payload_tree: validated lxml Element
|
||||||
- state.metadata["payload_class"] set to the target dataclass (set at registration)
|
- state.metadata["payload_class"]: the target dataclass
|
||||||
|
|
||||||
On success: state.payload = dataclass instance
|
Uses the custom parse_element routine for direct in-memory deserialization.
|
||||||
On failure: state.error set with clear message
|
|
||||||
"""
|
"""
|
||||||
if state.payload_tree is None:
|
if state.payload_tree is None:
|
||||||
state.error = "deserialization_step: no payload_tree (previous step failed)"
|
state.error = "deserialization_step: no payload_tree (previous step failed)"
|
||||||
|
|
@ -35,8 +35,8 @@ async def deserialization_step(state: MessageState) -> MessageState:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# xmlable.from_xml handles namespace-aware deserialization
|
# Direct in-memory deserialization — fast and clean
|
||||||
instance = from_xml(payload_class, state.payload_tree)
|
instance = parse_element(payload_class, state.payload_tree)
|
||||||
state.payload = instance
|
state.payload = instance
|
||||||
|
|
||||||
except Exception as exc: # pylint: disable=broad-except
|
except Exception as exc: # pylint: disable=broad-except
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,7 @@
|
||||||
payload_extraction.py — Extract the inner payload from the validated <message> envelope.
|
payload_extraction.py — Extract the inner payload from the validated <message> envelope.
|
||||||
|
|
||||||
After envelope_validation_step confirms a correct outer <message> envelope,
|
After envelope_validation_step confirms a correct outer <message> envelope,
|
||||||
this step removes the envelope elements (<thread>, <from>, optional <to>, etc.)
|
this step extracts metadata from <meta> and isolates the single payload element.
|
||||||
and isolates the single child element that is the actual payload.
|
|
||||||
|
|
||||||
The payload is expected to be exactly one root element (the capability-specific XML).
|
The payload is expected to be exactly one root element (the capability-specific XML).
|
||||||
If zero or multiple payload roots are found, we set a clear error — this protects
|
If zero or multiple payload roots are found, we set a clear error — this protects
|
||||||
|
|
@ -17,19 +16,25 @@ from agentserver.message_bus.message_state import MessageState
|
||||||
|
|
||||||
# Envelope namespace for easy reference
|
# Envelope namespace for easy reference
|
||||||
_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
|
_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
|
||||||
_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message"
|
_MESSAGE_TAG = f"{{{_ENVELOPE_NS}}}message"
|
||||||
|
_META_TAG = f"{{{_ENVELOPE_NS}}}meta"
|
||||||
|
_FROM_TAG = f"{{{_ENVELOPE_NS}}}from"
|
||||||
|
_TO_TAG = f"{{{_ENVELOPE_NS}}}to"
|
||||||
|
_THREAD_TAG = f"{{{_ENVELOPE_NS}}}thread"
|
||||||
|
|
||||||
|
|
||||||
async def payload_extraction_step(state: MessageState) -> MessageState:
|
async def payload_extraction_step(state: MessageState) -> MessageState:
|
||||||
"""
|
"""
|
||||||
Extract the single payload element from the validated envelope.
|
Extract the single payload element from the validated envelope.
|
||||||
|
|
||||||
Expected structure:
|
Expected structure (per envelope.xsd):
|
||||||
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
<thread>uuid</thread>
|
<meta>
|
||||||
<from>sender</from>
|
<from>sender</from>
|
||||||
<!-- optional <to>receiver</to> -->
|
<to>receiver</to> <!-- optional -->
|
||||||
<payload_root> ← this is the one we want
|
<thread>uuid</thread>
|
||||||
|
</meta>
|
||||||
|
<payload_root xmlns="..."> ← this is what we extract
|
||||||
...
|
...
|
||||||
</payload_root>
|
</payload_root>
|
||||||
</message>
|
</message>
|
||||||
|
|
@ -41,24 +46,42 @@ async def payload_extraction_step(state: MessageState) -> MessageState:
|
||||||
state.error = "payload_extraction_step: no envelope_tree (previous step failed)"
|
state.error = "payload_extraction_step: no envelope_tree (previous step failed)"
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# Basic sanity — root must be <message> in correct namespace (already checked by schema,
|
# Basic sanity — root must be <message> in correct namespace
|
||||||
# but we double-check for defence in depth)
|
|
||||||
if state.envelope_tree.tag != _MESSAGE_TAG:
|
if state.envelope_tree.tag != _MESSAGE_TAG:
|
||||||
state.error = f"payload_extraction_step: root tag is not <message> in envelope namespace"
|
state.error = "payload_extraction_step: root tag is not <message> in envelope namespace"
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# Find all direct children that are not envelope control elements
|
# Find <meta> block and extract provenance
|
||||||
# Envelope control elements are: thread, from, to (optional)
|
meta_elem = state.envelope_tree.find(_META_TAG)
|
||||||
|
if meta_elem is None:
|
||||||
|
state.error = "payload_extraction_step: missing <meta> block in envelope"
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Extract from_id (required)
|
||||||
|
from_elem = meta_elem.find(_FROM_TAG)
|
||||||
|
if from_elem is not None and from_elem.text:
|
||||||
|
state.from_id = from_elem.text.strip()
|
||||||
|
else:
|
||||||
|
state.error = "payload_extraction_step: missing <from> in <meta>"
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Extract thread_id (required)
|
||||||
|
thread_elem = meta_elem.find(_THREAD_TAG)
|
||||||
|
if thread_elem is not None and thread_elem.text:
|
||||||
|
state.thread_id = thread_elem.text.strip()
|
||||||
|
else:
|
||||||
|
state.error = "payload_extraction_step: missing <thread> in <meta>"
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Optional: extract <to> for direct routing
|
||||||
|
to_elem = meta_elem.find(_TO_TAG)
|
||||||
|
if to_elem is not None and to_elem.text:
|
||||||
|
state.to_id = to_elem.text.strip()
|
||||||
|
|
||||||
|
# Find all direct children that are NOT <meta> — those are payload candidates
|
||||||
payload_candidates = [
|
payload_candidates = [
|
||||||
child
|
child for child in state.envelope_tree
|
||||||
for child in state.envelope_tree
|
if child.tag != _META_TAG
|
||||||
if not (
|
|
||||||
child.tag in {
|
|
||||||
f"{{{ _ENVELOPE_NS }}}thread",
|
|
||||||
f"{{{ _ENVELOPE_NS }}}from",
|
|
||||||
f"{{{ _ENVELOPE_NS }}}to",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(payload_candidates) == 0:
|
if len(payload_candidates) == 0:
|
||||||
|
|
@ -73,19 +96,6 @@ async def payload_extraction_step(state: MessageState) -> MessageState:
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# Success — exactly one payload element
|
# Success — exactly one payload element
|
||||||
payload_element = payload_candidates[0]
|
state.payload_tree = payload_candidates[0]
|
||||||
|
|
||||||
# Optional: capture provenance from envelope for later use
|
|
||||||
# (these will be trustworthy because envelope was validated)
|
|
||||||
thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread")
|
|
||||||
from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from")
|
|
||||||
|
|
||||||
if thread_elem is not None and thread_elem.text:
|
|
||||||
state.thread_id = thread_elem.text.strip()
|
|
||||||
|
|
||||||
if from_elem is not None and from_elem.text:
|
|
||||||
state.from_id = from_elem.text.strip()
|
|
||||||
|
|
||||||
state.payload_tree = payload_element
|
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
@ -1,50 +1,70 @@
|
||||||
"""
|
"""
|
||||||
routing_resolution.py — Resolve routing based on derived root tag.
|
routing_resolution.py — Resolve routing based on derived root tag.
|
||||||
|
|
||||||
This is the final preparation step before dispatch.
|
This step computes the root tag from the deserialized payload and looks it up
|
||||||
It computes the root tag from the deserialized payload and looks it up in the
|
in a routing table (root_tag → list[Listener]).
|
||||||
global routing table (root_tag → list[Listener]).
|
|
||||||
|
|
||||||
On success: state.target_listeners is set
|
NOTE: The StreamPump has routing built-in via _route_step(). This standalone
|
||||||
On failure: state.error is set → message falls to system pipeline for <huh>
|
step is provided for custom pipeline configurations or testing.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
routing_step = make_routing_step(routing_table)
|
||||||
|
state = await routing_step(state)
|
||||||
|
|
||||||
Part of AgentServer v2.1 message pump.
|
Part of AgentServer v2.1 message pump.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Dict, List, Callable, Awaitable, TYPE_CHECKING
|
||||||
|
|
||||||
from agentserver.message_bus.message_state import MessageState
|
from agentserver.message_bus.message_state import MessageState
|
||||||
from agentserver.message_bus.bus import MessageBus
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from agentserver.message_bus.stream_pump import Listener
|
||||||
|
|
||||||
|
|
||||||
async def routing_resolution_step(state: MessageState) -> MessageState:
|
def make_routing_step(
|
||||||
|
routing_table: Dict[str, List["Listener"]]
|
||||||
|
) -> Callable[[MessageState], Awaitable[MessageState]]:
|
||||||
"""
|
"""
|
||||||
Resolve which listener(s) should handle this payload.
|
Factory: create a routing step with a specific routing table.
|
||||||
|
|
||||||
Root tag = f"{from_id.lower()}.{payload_class_name.lower()}"
|
The routing table maps root tags to lists of listeners:
|
||||||
(from_id is trustworthy — injected by pump)
|
{"agent.payload": [listener1, listener2], ...}
|
||||||
|
|
||||||
Supports:
|
|
||||||
- Normal unique routing (one listener)
|
|
||||||
- Broadcast (multiple listeners if broadcast: true and same root tag)
|
|
||||||
|
|
||||||
If no match → error, falls to system pipeline.
|
|
||||||
"""
|
"""
|
||||||
if state.payload is None:
|
|
||||||
state.error = "routing_resolution_step: no deserialized payload (previous step failed)"
|
async def routing_resolution_step(state: MessageState) -> MessageState:
|
||||||
|
"""
|
||||||
|
Resolve which listener(s) should handle this payload.
|
||||||
|
|
||||||
|
Root tag = f"{from_id.lower()}.{payload_class_name.lower()}"
|
||||||
|
|
||||||
|
Supports:
|
||||||
|
- Normal unique routing (one listener)
|
||||||
|
- Broadcast (multiple listeners if same root tag)
|
||||||
|
|
||||||
|
If no match → error, falls to system pipeline.
|
||||||
|
"""
|
||||||
|
if state.payload is None:
|
||||||
|
state.error = "routing_resolution_step: no deserialized payload"
|
||||||
|
return state
|
||||||
|
|
||||||
|
if state.to_id is None:
|
||||||
|
state.error = "routing_resolution_step: missing to_id"
|
||||||
|
return state
|
||||||
|
|
||||||
|
payload_class_name = type(state.payload).__name__.lower()
|
||||||
|
root_tag = f"{state.to_id.lower()}.{payload_class_name}"
|
||||||
|
|
||||||
|
targets = routing_table.get(root_tag)
|
||||||
|
|
||||||
|
if not targets:
|
||||||
|
state.error = f"routing_resolution_step: unknown root tag '{root_tag}'"
|
||||||
|
return state
|
||||||
|
|
||||||
|
state.target_listeners = targets
|
||||||
return state
|
return state
|
||||||
|
|
||||||
if state.from_id is None:
|
routing_resolution_step.__name__ = "routing_resolution_step"
|
||||||
state.error = "routing_resolution_step: missing from_id (provenance error)"
|
return routing_resolution_step
|
||||||
return state
|
|
||||||
|
|
||||||
payload_class_name = type(state.payload).__name__.lower()
|
|
||||||
root_tag = f"{state.from_id.lower()}.{payload_class_name}"
|
|
||||||
|
|
||||||
bus = MessageBus.get_instance()
|
|
||||||
targets = bus.routing_table.get(root_tag)
|
|
||||||
|
|
||||||
if not targets:
|
|
||||||
state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'"
|
|
||||||
return state
|
|
||||||
|
|
||||||
state.target_listeners = targets
|
|
||||||
return state
|
|
||||||
|
|
|
||||||
592
agentserver/message_bus/stream_pump.py
Normal file
592
agentserver/message_bus/stream_pump.py
Normal file
|
|
@ -0,0 +1,592 @@
|
||||||
|
"""
|
||||||
|
pump_aiostream.py — Stream-Based Message Pump using aiostream
|
||||||
|
|
||||||
|
This implementation treats the entire message flow as composable streams.
|
||||||
|
Fan-out (multi-payload, broadcast) is handled naturally via flatmap.
|
||||||
|
|
||||||
|
Key insight: Each step is a stream transformer, not a 1:1 function.
|
||||||
|
The pipeline is just a composition of stream operators.
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
pip install aiostream
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import importlib
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncIterable, Callable, List, Dict, Any, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from lxml import etree
|
||||||
|
from aiostream import stream, pipe, operator
|
||||||
|
|
||||||
|
# Import existing step implementations (we'll wrap them)
|
||||||
|
from agentserver.message_bus.steps.repair import repair_step
|
||||||
|
from agentserver.message_bus.steps.c14n import c14n_step
|
||||||
|
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
||||||
|
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
||||||
|
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
||||||
|
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Configuration (same as before)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ListenerConfig:
|
||||||
|
name: str
|
||||||
|
payload_class_path: str
|
||||||
|
handler_path: str
|
||||||
|
description: str
|
||||||
|
is_agent: bool = False
|
||||||
|
peers: List[str] = field(default_factory=list)
|
||||||
|
broadcast: bool = False
|
||||||
|
payload_class: type = field(default=None, repr=False)
|
||||||
|
handler: Callable = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrganismConfig:
|
||||||
|
name: str
|
||||||
|
identity_path: str = ""
|
||||||
|
port: int = 8765
|
||||||
|
thread_scheduling: str = "breadth-first"
|
||||||
|
listeners: List[ListenerConfig] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Concurrency tuning
|
||||||
|
max_concurrent_pipelines: int = 50 # Total concurrent messages in pipeline
|
||||||
|
max_concurrent_handlers: int = 20 # Concurrent handler invocations
|
||||||
|
max_concurrent_per_agent: int = 5 # Per-agent rate limit
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Listener:
|
||||||
|
name: str
|
||||||
|
payload_class: type
|
||||||
|
handler: Callable
|
||||||
|
description: str
|
||||||
|
is_agent: bool = False
|
||||||
|
peers: List[str] = field(default_factory=list)
|
||||||
|
broadcast: bool = False
|
||||||
|
schema: etree.XMLSchema = field(default=None, repr=False)
|
||||||
|
root_tag: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Stream-Based Pipeline Steps
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def wrap_step(step_fn: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Wrap an existing async step function for use with pipe.map.
|
||||||
|
|
||||||
|
Existing steps: async def step(state) -> state
|
||||||
|
We keep them as-is since pipe.map handles the iteration.
|
||||||
|
"""
|
||||||
|
return step_fn
|
||||||
|
|
||||||
|
|
||||||
|
async def extract_payloads(state: MessageState) -> AsyncIterable[MessageState]:
|
||||||
|
"""
|
||||||
|
Fan-out step: Extract 1..N payloads from handler response.
|
||||||
|
|
||||||
|
This is used with pipe.flatmap — yields multiple states for each input.
|
||||||
|
"""
|
||||||
|
if state.raw_bytes is None:
|
||||||
|
yield state
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wrap in dummy to handle multiple roots
|
||||||
|
wrapped = b"<dummy>" + state.raw_bytes + b"</dummy>"
|
||||||
|
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||||
|
|
||||||
|
children = list(tree)
|
||||||
|
if not children:
|
||||||
|
yield state
|
||||||
|
return
|
||||||
|
|
||||||
|
for child in children:
|
||||||
|
payload_bytes = etree.tostring(child)
|
||||||
|
yield MessageState(
|
||||||
|
raw_bytes=payload_bytes,
|
||||||
|
thread_id=state.thread_id,
|
||||||
|
from_id=state.from_id,
|
||||||
|
metadata=state.metadata.copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
# On parse failure, pass through as-is
|
||||||
|
yield state
|
||||||
|
|
||||||
|
|
||||||
|
def make_xsd_validation(schema: etree.XMLSchema) -> Callable:
|
||||||
|
"""Factory for XSD validation step with schema baked in."""
|
||||||
|
async def validate(state: MessageState) -> MessageState:
|
||||||
|
if state.payload_tree is None or state.error:
|
||||||
|
return state
|
||||||
|
try:
|
||||||
|
schema.assertValid(state.payload_tree)
|
||||||
|
except etree.DocumentInvalid as e:
|
||||||
|
state.error = f"XSD validation failed: {e}"
|
||||||
|
return state
|
||||||
|
return validate
|
||||||
|
|
||||||
|
|
||||||
|
def make_deserialization(payload_class: type) -> Callable:
|
||||||
|
"""Factory for deserialization step with class baked in."""
|
||||||
|
from third_party.xmlable import parse_element
|
||||||
|
|
||||||
|
async def deserialize(state: MessageState) -> MessageState:
|
||||||
|
if state.payload_tree is None or state.error:
|
||||||
|
return state
|
||||||
|
try:
|
||||||
|
state.payload = parse_element(payload_class, state.payload_tree)
|
||||||
|
except Exception as e:
|
||||||
|
state.error = f"Deserialization failed: {e}"
|
||||||
|
return state
|
||||||
|
return deserialize
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# The Stream-Based Pump
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class StreamPump:
|
||||||
|
"""
|
||||||
|
Message pump built on aiostream.
|
||||||
|
|
||||||
|
The entire flow is a single composable stream pipeline.
|
||||||
|
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: OrganismConfig):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
# Message queue feeds the stream
|
||||||
|
self.queue: asyncio.Queue[MessageState] = asyncio.Queue()
|
||||||
|
|
||||||
|
# Routing table
|
||||||
|
self.routing_table: Dict[str, List[Listener]] = {}
|
||||||
|
self.listeners: Dict[str, Listener] = {}
|
||||||
|
|
||||||
|
# Per-agent semaphores for rate limiting
|
||||||
|
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
|
||||||
|
|
||||||
|
# Shutdown control
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Registration
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def register_listener(self, lc: ListenerConfig) -> Listener:
|
||||||
|
root_tag = f"{lc.name.lower()}.{lc.payload_class.__name__.lower()}"
|
||||||
|
|
||||||
|
listener = Listener(
|
||||||
|
name=lc.name,
|
||||||
|
payload_class=lc.payload_class,
|
||||||
|
handler=lc.handler,
|
||||||
|
description=lc.description,
|
||||||
|
is_agent=lc.is_agent,
|
||||||
|
peers=lc.peers,
|
||||||
|
broadcast=lc.broadcast,
|
||||||
|
schema=self._generate_schema(lc.payload_class),
|
||||||
|
root_tag=root_tag,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lc.is_agent:
|
||||||
|
self.agent_semaphores[lc.name] = asyncio.Semaphore(
|
||||||
|
self.config.max_concurrent_per_agent
|
||||||
|
)
|
||||||
|
|
||||||
|
self.routing_table.setdefault(root_tag, []).append(listener)
|
||||||
|
self.listeners[lc.name] = listener
|
||||||
|
return listener
|
||||||
|
|
||||||
|
def register_all(self) -> None:
|
||||||
|
for lc in self.config.listeners:
|
||||||
|
self.register_listener(lc)
|
||||||
|
|
||||||
|
def _generate_schema(self, payload_class: type) -> etree.XMLSchema:
|
||||||
|
"""Generate XSD schema from xmlified payload class."""
|
||||||
|
if hasattr(payload_class, 'xsd'):
|
||||||
|
xsd_tree = payload_class.xsd()
|
||||||
|
return etree.XMLSchema(xsd_tree)
|
||||||
|
# Fallback for non-xmlified classes (e.g., in tests)
|
||||||
|
permissive = '<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema"><xs:any processContents="lax"/></xs:schema>'
|
||||||
|
return etree.XMLSchema(etree.fromstring(permissive.encode()))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Stream Source
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _queue_source(self) -> AsyncIterable[MessageState]:
|
||||||
|
"""Async generator that yields messages from the queue."""
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
state = await asyncio.wait_for(self.queue.get(), timeout=0.5)
|
||||||
|
yield state
|
||||||
|
self.queue.task_done()
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Pipeline Steps (as stream operators)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _route_step(self, state: MessageState) -> MessageState:
|
||||||
|
"""Determine target listeners based on to_id.class format."""
|
||||||
|
if state.error or state.payload is None:
|
||||||
|
return state
|
||||||
|
|
||||||
|
payload_class_name = type(state.payload).__name__.lower()
|
||||||
|
to_id = (state.to_id or "").lower()
|
||||||
|
root_tag = f"{to_id}.{payload_class_name}" if to_id else payload_class_name
|
||||||
|
|
||||||
|
targets = self.routing_table.get(root_tag)
|
||||||
|
if targets:
|
||||||
|
state.target_listeners = targets
|
||||||
|
else:
|
||||||
|
state.error = f"No listener for: {root_tag}"
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
async def _dispatch_to_handlers(self, state: MessageState) -> AsyncIterable[MessageState]:
|
||||||
|
"""
|
||||||
|
Fan-out step: Dispatch to handler(s) and yield response states.
|
||||||
|
|
||||||
|
For broadcast, yields one response per listener.
|
||||||
|
Each response becomes a new message in the stream.
|
||||||
|
"""
|
||||||
|
if state.error or not state.target_listeners:
|
||||||
|
# Pass through errors/unroutable for downstream handling
|
||||||
|
yield state
|
||||||
|
return
|
||||||
|
|
||||||
|
for listener in state.target_listeners:
|
||||||
|
try:
|
||||||
|
# Rate limiting for agents
|
||||||
|
semaphore = self.agent_semaphores.get(listener.name)
|
||||||
|
if semaphore:
|
||||||
|
await semaphore.acquire()
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata = HandlerMetadata(
|
||||||
|
thread_id=state.thread_id or "",
|
||||||
|
from_id=state.from_id or "",
|
||||||
|
own_name=listener.name if listener.is_agent else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_bytes = await listener.handler(state.payload, metadata)
|
||||||
|
|
||||||
|
if not isinstance(response_bytes, bytes):
|
||||||
|
response_bytes = b"<huh>Handler returned invalid type</huh>"
|
||||||
|
|
||||||
|
# Yield response — will be processed by next iteration
|
||||||
|
yield MessageState(
|
||||||
|
raw_bytes=response_bytes,
|
||||||
|
thread_id=state.thread_id,
|
||||||
|
from_id=listener.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
if semaphore:
|
||||||
|
semaphore.release()
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
yield MessageState(
|
||||||
|
raw_bytes=f"<huh>Handler {listener.name} crashed: {exc}</huh>".encode(),
|
||||||
|
thread_id=state.thread_id,
|
||||||
|
from_id=listener.name,
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _reinject_responses(self, state: MessageState) -> None:
|
||||||
|
"""Push handler responses back into the queue for next iteration."""
|
||||||
|
await self.queue.put(state)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Build the Pipeline
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def build_pipeline(self, source: AsyncIterable[MessageState]):
|
||||||
|
"""
|
||||||
|
Construct the full processing pipeline.
|
||||||
|
|
||||||
|
This is where you configure the flow. Modify this method to:
|
||||||
|
- Add/remove steps
|
||||||
|
- Change concurrency limits
|
||||||
|
- Insert logging/metrics
|
||||||
|
- Add filtering
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The pipeline is a composition of stream operators
|
||||||
|
pipeline = (
|
||||||
|
stream.iterate(source)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 1: Envelope Processing (1:1 transforms)
|
||||||
|
# ============================================================
|
||||||
|
| pipe.map(repair_step)
|
||||||
|
| pipe.map(c14n_step)
|
||||||
|
| pipe.map(envelope_validation_step)
|
||||||
|
| pipe.map(payload_extraction_step)
|
||||||
|
| pipe.map(thread_assignment_step)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 2: Fan-out — Extract Multiple Payloads (1:N)
|
||||||
|
# ============================================================
|
||||||
|
# Handler responses may contain multiple payloads.
|
||||||
|
# Each becomes a separate message in the stream.
|
||||||
|
| pipe.flatmap(extract_payloads)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 3: Per-Payload Validation (1:1 transforms)
|
||||||
|
# ============================================================
|
||||||
|
# Note: In a real implementation, you'd route to listener-specific
|
||||||
|
# validation here. For now, we use a simplified approach.
|
||||||
|
| pipe.map(self._validate_and_deserialize)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 4: Routing (1:1)
|
||||||
|
# ============================================================
|
||||||
|
| pipe.map(self._route_step)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 5: Filter Errors
|
||||||
|
# ============================================================
|
||||||
|
# Errors go to a separate handler (could also be a branch)
|
||||||
|
| pipe.map(self._handle_errors)
|
||||||
|
| pipe.filter(lambda s: s.error is None and s.target_listeners)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 6: Fan-out — Dispatch to Handlers (1:N for broadcast)
|
||||||
|
# ============================================================
|
||||||
|
# This is where handlers are invoked. Broadcast = multiple yields.
|
||||||
|
# task_limit controls concurrent handler invocations.
|
||||||
|
| pipe.flatmap(
|
||||||
|
self._dispatch_to_handlers,
|
||||||
|
task_limit=self.config.max_concurrent_handlers
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# STAGE 7: Re-inject Responses
|
||||||
|
# ============================================================
|
||||||
|
# Handler responses go back into the queue for next iteration.
|
||||||
|
# The cycle continues until no more messages.
|
||||||
|
| pipe.action(self._reinject_responses)
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
async def _validate_and_deserialize(self, state: MessageState) -> MessageState:
|
||||||
|
"""
|
||||||
|
Combined validation + deserialization.
|
||||||
|
|
||||||
|
Uses to_id + payload tag to find the right listener and schema.
|
||||||
|
"""
|
||||||
|
if state.error or state.payload_tree is None:
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Build lookup key: to_id.payload_tag (matching routing table format)
|
||||||
|
payload_tag = state.payload_tree.tag
|
||||||
|
if payload_tag.startswith("{"):
|
||||||
|
payload_tag = payload_tag.split("}", 1)[1]
|
||||||
|
|
||||||
|
to_id = (state.to_id or "").lower()
|
||||||
|
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
|
||||||
|
|
||||||
|
listeners = self.routing_table.get(lookup_key, [])
|
||||||
|
if not listeners:
|
||||||
|
state.error = f"No listener for: {lookup_key}"
|
||||||
|
return state
|
||||||
|
|
||||||
|
listener = listeners[0]
|
||||||
|
|
||||||
|
# Validate against listener's schema
|
||||||
|
try:
|
||||||
|
listener.schema.assertValid(state.payload_tree)
|
||||||
|
except etree.DocumentInvalid as e:
|
||||||
|
state.error = f"XSD validation failed: {e}"
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Deserialize
|
||||||
|
try:
|
||||||
|
from third_party.xmlable import parse_element
|
||||||
|
state.payload = parse_element(listener.payload_class, state.payload_tree)
|
||||||
|
except Exception as e:
|
||||||
|
state.error = f"Deserialization failed: {e}"
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
|
async def _handle_errors(self, state: MessageState) -> MessageState:
|
||||||
|
"""Log errors (could also emit <huh> messages)."""
|
||||||
|
if state.error:
|
||||||
|
print(f"[ERROR] {state.thread_id}: {state.error}")
|
||||||
|
# Could emit <huh> to a specific listener here
|
||||||
|
return state
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Run the Pump
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def run(self) -> None:
|
||||||
|
"""
|
||||||
|
Main entry point — run the stream pipeline.
|
||||||
|
|
||||||
|
The pipeline pulls from the queue, processes messages,
|
||||||
|
and re-injects handler responses. Continues until shutdown.
|
||||||
|
"""
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
pipeline = self.build_pipeline(self._queue_source())
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
async for _ in streamer:
|
||||||
|
# The pipeline drives itself via re-injection.
|
||||||
|
# We just need to consume the stream.
|
||||||
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# External API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def inject(self, raw_bytes: bytes, thread_id: str, from_id: str) -> None:
|
||||||
|
"""Inject a message to start processing."""
|
||||||
|
state = MessageState(
|
||||||
|
raw_bytes=raw_bytes,
|
||||||
|
thread_id=thread_id,
|
||||||
|
from_id=from_id,
|
||||||
|
)
|
||||||
|
await self.queue.put(state)
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
"""Graceful shutdown — wait for queue to drain."""
|
||||||
|
self._running = False
|
||||||
|
await self.queue.join()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Config Loader (same as before)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class ConfigLoader:
|
||||||
|
@classmethod
|
||||||
|
def load(cls, path: str | Path) -> OrganismConfig:
|
||||||
|
with open(Path(path)) as f:
|
||||||
|
raw = yaml.safe_load(f)
|
||||||
|
return cls._parse(raw)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse(cls, raw: dict) -> OrganismConfig:
|
||||||
|
org = raw.get("organism", {})
|
||||||
|
config = OrganismConfig(
|
||||||
|
name=org.get("name", "unnamed"),
|
||||||
|
identity_path=org.get("identity", ""),
|
||||||
|
port=org.get("port", 8765),
|
||||||
|
thread_scheduling=raw.get("thread_scheduling", "breadth-first"),
|
||||||
|
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
|
||||||
|
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||||
|
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry in raw.get("listeners", []):
|
||||||
|
lc = cls._parse_listener(entry)
|
||||||
|
cls._resolve_imports(lc)
|
||||||
|
config.listeners.append(lc)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_listener(cls, raw: dict) -> ListenerConfig:
|
||||||
|
return ListenerConfig(
|
||||||
|
name=raw["name"],
|
||||||
|
payload_class_path=raw["payload_class"],
|
||||||
|
handler_path=raw["handler"],
|
||||||
|
description=raw["description"],
|
||||||
|
is_agent=raw.get("agent", False),
|
||||||
|
peers=raw.get("peers", []),
|
||||||
|
broadcast=raw.get("broadcast", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _resolve_imports(cls, lc: ListenerConfig) -> None:
|
||||||
|
mod, cls_name = lc.payload_class_path.rsplit(".", 1)
|
||||||
|
lc.payload_class = getattr(importlib.import_module(mod), cls_name)
|
||||||
|
|
||||||
|
mod, fn_name = lc.handler_path.rsplit(".", 1)
|
||||||
|
lc.handler = getattr(importlib.import_module(mod), fn_name)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Bootstrap
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
||||||
|
"""Load config and create pump."""
|
||||||
|
config = ConfigLoader.load(config_path)
|
||||||
|
print(f"Organism: {config.name}")
|
||||||
|
print(f"Listeners: {len(config.listeners)}")
|
||||||
|
|
||||||
|
pump = StreamPump(config)
|
||||||
|
pump.register_all()
|
||||||
|
|
||||||
|
print(f"Routing: {list(pump.routing_table.keys())}")
|
||||||
|
return pump
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Example: Customizing the Pipeline
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
The beauty of aiostream: the pipeline is just a composition.
|
||||||
|
You can easily insert, remove, or reorder stages.
|
||||||
|
|
||||||
|
# Add logging between stages:
|
||||||
|
| pipe.action(lambda s: print(f"After repair: {s.thread_id}"))
|
||||||
|
|
||||||
|
# Add throttling:
|
||||||
|
| pipe.map(some_step, task_limit=5)
|
||||||
|
|
||||||
|
# Branch errors to a separate stream:
|
||||||
|
errors, valid = stream.partition(source, lambda s: s.error is not None)
|
||||||
|
|
||||||
|
# Merge multiple sources:
|
||||||
|
combined = stream.merge(queue_source, oob_source, external_api_source)
|
||||||
|
|
||||||
|
# Add timeout per message:
|
||||||
|
| pipe.timeout(30.0) # 30 second timeout per item
|
||||||
|
|
||||||
|
# Rate limit the whole pipeline:
|
||||||
|
| pipe.spaceout(0.1) # 100ms between items
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Comparison: Old vs New
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""
|
||||||
|
OLD (bus.py):
|
||||||
|
for payload in payloads:
|
||||||
|
await pipeline.process(payload) # Sequential, recursive
|
||||||
|
|
||||||
|
NEW (aiostream):
|
||||||
|
| pipe.flatmap(extract_payloads) # Fan-out, parallel
|
||||||
|
| pipe.flatmap(dispatch, task_limit=20) # Concurrent handlers
|
||||||
|
|
||||||
|
The key difference:
|
||||||
|
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
|
||||||
|
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
|
||||||
|
"""
|
||||||
|
|
@ -20,8 +20,8 @@
|
||||||
</xs:complexType>
|
</xs:complexType>
|
||||||
</xs:element>
|
</xs:element>
|
||||||
|
|
||||||
<!-- Exactly one payload element from any foreign namespace -->
|
<!-- Exactly one payload element (any namespace including no namespace) -->
|
||||||
<xs:any namespace="##other" processContents="lax" minOccurs="0" maxOccurs="1"/>
|
<xs:any namespace="##any" processContents="lax" minOccurs="0" maxOccurs="1"/>
|
||||||
</xs:sequence>
|
</xs:sequence>
|
||||||
</xs:complexType>
|
</xs:complexType>
|
||||||
</xs:element>
|
</xs:element>
|
||||||
|
|
|
||||||
24
config/organism.yaml
Normal file
24
config/organism.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
# organism.yaml — Sample configuration for testing the message pump
|
||||||
|
#
|
||||||
|
# This defines a simple "hello world" organism with one listener.
|
||||||
|
|
||||||
|
organism:
|
||||||
|
name: hello-world
|
||||||
|
port: 8765
|
||||||
|
|
||||||
|
# Concurrency settings
|
||||||
|
max_concurrent_pipelines: 50
|
||||||
|
max_concurrent_handlers: 20
|
||||||
|
max_concurrent_per_agent: 5
|
||||||
|
|
||||||
|
# Thread scheduling: breadth-first or depth-first
|
||||||
|
thread_scheduling: breadth-first
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
# The greeter listener responds to Greeting payloads
|
||||||
|
- name: greeter
|
||||||
|
payload_class: handlers.hello.Greeting
|
||||||
|
handler: handlers.hello.handle_greeting
|
||||||
|
description: Responds with a greeting message
|
||||||
|
agent: false
|
||||||
|
broadcast: false
|
||||||
1
handlers/__init__.py
Normal file
1
handlers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
||||||
|
# handlers — Sample handlers for testing the message pump
|
||||||
78
handlers/hello.py
Normal file
78
handlers/hello.py
Normal file
|
|
@ -0,0 +1,78 @@
|
||||||
|
"""
|
||||||
|
hello.py — Hello World handler for testing the message pump.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- Greeting: payload class (what the handler receives)
|
||||||
|
- GreetingResponse: response payload (what the handler returns)
|
||||||
|
- handle_greeting: async handler function
|
||||||
|
|
||||||
|
Usage in organism.yaml:
|
||||||
|
listeners:
|
||||||
|
- name: greeter
|
||||||
|
payload_class: handlers.hello.Greeting
|
||||||
|
handler: handlers.hello.handle_greeting
|
||||||
|
description: Responds with a greeting message
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from lxml import etree
|
||||||
|
|
||||||
|
from third_party.xmlable import xmlify
|
||||||
|
from agentserver.message_bus.message_state import HandlerMetadata
|
||||||
|
|
||||||
|
|
||||||
|
# Envelope namespace
|
||||||
|
ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class Greeting:
|
||||||
|
"""Incoming greeting request."""
|
||||||
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class GreetingResponse:
|
||||||
|
"""Outgoing greeting response."""
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_in_envelope(payload_bytes: bytes, from_id: str, to_id: str, thread_id: str) -> bytes:
|
||||||
|
"""Wrap a payload in a proper message envelope."""
|
||||||
|
return f"""<message xmlns="{ENVELOPE_NS}">
|
||||||
|
<meta>
|
||||||
|
<from>{from_id}</from>
|
||||||
|
<to>{to_id}</to>
|
||||||
|
<thread>{thread_id}</thread>
|
||||||
|
</meta>
|
||||||
|
{payload_bytes.decode('utf-8')}
|
||||||
|
</message>""".encode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_greeting(payload: Greeting, metadata: HandlerMetadata) -> bytes:
|
||||||
|
"""
|
||||||
|
Handle an incoming Greeting and respond with a GreetingResponse.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
payload: The deserialized Greeting instance
|
||||||
|
metadata: Contains thread_id, from_id, own_name
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
XML bytes of the response envelope
|
||||||
|
"""
|
||||||
|
# Create response
|
||||||
|
response = GreetingResponse(message=f"Hello, {payload.name}!")
|
||||||
|
|
||||||
|
# Serialize to XML
|
||||||
|
response_tree = response.xml_value("GreetingResponse")
|
||||||
|
payload_bytes = etree.tostring(response_tree, encoding='utf-8')
|
||||||
|
|
||||||
|
# Wrap in envelope - respond back to sender
|
||||||
|
return wrap_in_envelope(
|
||||||
|
payload_bytes=payload_bytes,
|
||||||
|
from_id=metadata.own_name or "greeter",
|
||||||
|
to_id=metadata.from_id, # Send back to whoever sent the greeting
|
||||||
|
thread_id=metadata.thread_id,
|
||||||
|
)
|
||||||
|
|
@ -5,15 +5,52 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "xml-pipeline"
|
name = "xml-pipeline"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
description = "Tamper-proof nervous system for multi-agent organisms"
|
description = "Tamper-proof nervous system for multi-agent organisms"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
license = {text = "MIT"}
|
||||||
|
keywords = ["xml", "multi-agent", "message-bus", "aiostream"]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"Intended Audience :: Developers",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Programming Language :: Python :: 3.13",
|
||||||
|
"Framework :: AsyncIO",
|
||||||
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"lxml",
|
"lxml",
|
||||||
"websockets",
|
"websockets",
|
||||||
"pyotp",
|
"pyotp",
|
||||||
"pyyaml",
|
"pyyaml",
|
||||||
"cryptography",
|
"cryptography",
|
||||||
|
"aiostream>=0.5",
|
||||||
|
"pyhumps",
|
||||||
|
"termcolor",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
test = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-asyncio>=0.21",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"pytest>=7.0",
|
||||||
|
"pytest-asyncio>=0.21",
|
||||||
|
"mypy",
|
||||||
|
"ruff",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
asyncio_mode = "auto"
|
||||||
|
asyncio_default_fixture_loop_scope = "function"
|
||||||
|
testpaths = ["tests"]
|
||||||
|
python_files = ["test_*.py"]
|
||||||
|
# Don't collect root __init__.py (has imports that break isolation)
|
||||||
|
norecursedirs = [".git", "__pycache__", "*.egg-info"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["."]
|
where = ["."]
|
||||||
|
include = ["agentserver*", "third_party*"]
|
||||||
|
|
|
||||||
48
tests/conftest.py
Normal file
48
tests/conftest.py
Normal file
|
|
@ -0,0 +1,48 @@
|
||||||
|
"""
|
||||||
|
conftest.py — Shared pytest configuration and fixtures
|
||||||
|
|
||||||
|
This file is automatically loaded by pytest.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the project root is in the path for imports
|
||||||
|
# BUT don't import the root package (it has heavy deps)
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
if str(project_root) not in sys.path:
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
# Tell pytest to ignore the root __init__.py
|
||||||
|
collect_ignore_glob = ["../__init__.py"]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Markers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def pytest_configure(config):
|
||||||
|
"""Register custom markers."""
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers", "slow: marks tests as slow (deselect with '-m \"not slow\"')"
|
||||||
|
)
|
||||||
|
config.addinivalue_line(
|
||||||
|
"markers", "integration: marks tests as integration tests"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures available to all tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_thread_id():
|
||||||
|
"""A valid UUID for testing."""
|
||||||
|
return "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_from_id():
|
||||||
|
"""A valid sender ID for testing."""
|
||||||
|
return "calculator.add"
|
||||||
632
tests/test_pipeline_steps.py
Normal file
632
tests/test_pipeline_steps.py
Normal file
|
|
@ -0,0 +1,632 @@
|
||||||
|
"""
|
||||||
|
test_pipeline_steps.py — Unit tests for individual pipeline steps
|
||||||
|
|
||||||
|
Run with: pytest tests/test_pipeline_steps.py -v
|
||||||
|
|
||||||
|
Each step is tested in isolation with known inputs and expected outputs.
|
||||||
|
This makes debugging much easier than testing the full pipeline.
|
||||||
|
|
||||||
|
Install test dependencies:
|
||||||
|
pip install -e ".[test]"
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from lxml import etree
|
||||||
|
|
||||||
|
# Import the message state
|
||||||
|
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
|
||||||
|
|
||||||
|
# Import individual steps
|
||||||
|
from agentserver.message_bus.steps.repair import repair_step
|
||||||
|
from agentserver.message_bus.steps.c14n import c14n_step
|
||||||
|
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
|
||||||
|
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
|
||||||
|
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
|
||||||
|
|
||||||
|
# Check for optional dependencies
|
||||||
|
try:
|
||||||
|
import aiostream
|
||||||
|
HAS_AIOSTREAM = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_AIOSTREAM = False
|
||||||
|
|
||||||
|
requires_aiostream = pytest.mark.skipif(
|
||||||
|
not HAS_AIOSTREAM,
|
||||||
|
reason="aiostream not installed (pip install aiostream)"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for stream_pump dependencies
|
||||||
|
try:
|
||||||
|
from agentserver.message_bus.stream_pump import StreamPump, Listener
|
||||||
|
from agentserver.message_bus.steps.routing_resolution import make_routing_step
|
||||||
|
HAS_STREAM_PUMP = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_STREAM_PUMP = False
|
||||||
|
|
||||||
|
requires_stream_pump = pytest.mark.skipif(
|
||||||
|
not HAS_STREAM_PUMP,
|
||||||
|
reason="stream_pump dependencies not available"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Test Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def valid_envelope_bytes():
|
||||||
|
"""A well-formed message envelope matching envelope.xsd."""
|
||||||
|
return b'''<?xml version="1.0"?>
|
||||||
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
|
<meta>
|
||||||
|
<from>calculator.add</from>
|
||||||
|
<thread>550e8400-e29b-41d4-a716-446655440000</thread>
|
||||||
|
</meta>
|
||||||
|
<addpayload xmlns="https://xml-pipeline.org/ns/calculator/add/v1">
|
||||||
|
<a>5</a>
|
||||||
|
<b>3</b>
|
||||||
|
</addpayload>
|
||||||
|
</message>'''
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def malformed_xml_bytes():
|
||||||
|
"""Malformed XML that lxml can partially recover."""
|
||||||
|
return b'<message><unclosed><nested>content</nested></message>'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def completely_broken_bytes():
|
||||||
|
"""Not XML at all."""
|
||||||
|
return b'this is not xml at all { json: "maybe" }'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def multi_payload_response():
|
||||||
|
"""Handler response with multiple payloads."""
|
||||||
|
return b'''
|
||||||
|
<search.result><answer>42</answer></search.result>
|
||||||
|
<calculator.add.addpayload><a>1</a><b>2</b></calculator.add.addpayload>
|
||||||
|
<thought>I should also check...</thought>
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def empty_state():
|
||||||
|
"""Fresh MessageState with no data."""
|
||||||
|
return MessageState()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def state_with_bytes(valid_envelope_bytes):
|
||||||
|
"""MessageState with raw_bytes populated."""
|
||||||
|
return MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# repair_step Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestRepairStep:
|
||||||
|
"""Tests for the XML repair/recovery step."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_xml_passes_through(self, valid_envelope_bytes):
|
||||||
|
"""Valid XML should parse without error."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
result = await repair_step(state)
|
||||||
|
|
||||||
|
assert result.error is None
|
||||||
|
assert result.envelope_tree is not None
|
||||||
|
assert result.envelope_tree.tag == "{https://xml-pipeline.org/ns/envelope/v1}message"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_malformed_xml_recovered(self, malformed_xml_bytes):
|
||||||
|
"""Malformed XML should be recovered if possible."""
|
||||||
|
state = MessageState(raw_bytes=malformed_xml_bytes)
|
||||||
|
result = await repair_step(state)
|
||||||
|
|
||||||
|
# lxml recovery mode should produce something
|
||||||
|
# May or may not have error depending on severity
|
||||||
|
assert result.envelope_tree is not None or result.error is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_bytes_sets_error(self, empty_state):
|
||||||
|
"""Missing raw_bytes should set an error."""
|
||||||
|
result = await repair_step(empty_state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "no raw_bytes" in result.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clears_raw_bytes_after_parse(self, valid_envelope_bytes):
|
||||||
|
"""raw_bytes should be cleared after successful parse (memory optimization)."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
result = await repair_step(state)
|
||||||
|
|
||||||
|
assert result.raw_bytes is None
|
||||||
|
assert result.envelope_tree is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# c14n_step Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestC14nStep:
|
||||||
|
"""Tests for the canonicalization step."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalizes_whitespace(self):
|
||||||
|
"""C14N should normalize whitespace."""
|
||||||
|
xml_with_whitespace = b'''<root>
|
||||||
|
<child> value </child>
|
||||||
|
</root>'''
|
||||||
|
|
||||||
|
state = MessageState(raw_bytes=xml_with_whitespace)
|
||||||
|
state = await repair_step(state)
|
||||||
|
result = await c14n_step(state)
|
||||||
|
|
||||||
|
assert result.error is None
|
||||||
|
assert result.envelope_tree is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_normalizes_attribute_order(self):
|
||||||
|
"""C14N should produce consistent attribute ordering."""
|
||||||
|
xml_a = b'<root z="1" a="2"/>'
|
||||||
|
xml_b = b'<root a="2" z="1"/>'
|
||||||
|
|
||||||
|
state_a = MessageState(raw_bytes=xml_a)
|
||||||
|
state_a = await repair_step(state_a)
|
||||||
|
state_a = await c14n_step(state_a)
|
||||||
|
|
||||||
|
state_b = MessageState(raw_bytes=xml_b)
|
||||||
|
state_b = await repair_step(state_b)
|
||||||
|
state_b = await c14n_step(state_b)
|
||||||
|
|
||||||
|
# Both should produce identical canonical form
|
||||||
|
c14n_a = etree.tostring(state_a.envelope_tree, method="c14n")
|
||||||
|
c14n_b = etree.tostring(state_b.envelope_tree, method="c14n")
|
||||||
|
assert c14n_a == c14n_b
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tree_sets_error(self, empty_state):
|
||||||
|
"""Missing envelope_tree should set error."""
|
||||||
|
result = await c14n_step(empty_state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "no envelope_tree" in result.error
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# payload_extraction_step Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestPayloadExtractionStep:
|
||||||
|
"""Tests for extracting payload from envelope."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extracts_payload_element(self, valid_envelope_bytes):
|
||||||
|
"""Should extract the payload element from envelope."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
# Skip envelope validation for this test
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.error is None
|
||||||
|
assert result.payload_tree is not None
|
||||||
|
# Tag may include namespace prefix
|
||||||
|
assert "addpayload" in result.payload_tree.tag
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extracts_thread_id(self, valid_envelope_bytes):
|
||||||
|
"""Should extract thread ID from envelope."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.thread_id == "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extracts_from_id(self, valid_envelope_bytes):
|
||||||
|
"""Should extract sender ID from envelope."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.from_id == "calculator.add"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_payloads_error(self):
|
||||||
|
"""Multiple payload elements should error."""
|
||||||
|
multi_payload = b'''
|
||||||
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
|
<meta>
|
||||||
|
<from>test</from>
|
||||||
|
<thread>uuid-here</thread>
|
||||||
|
</meta>
|
||||||
|
<payload1>data</payload1>
|
||||||
|
<payload2>more data</payload2>
|
||||||
|
</message>'''
|
||||||
|
|
||||||
|
state = MessageState(raw_bytes=multi_payload)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "multiple payload" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_payload_error(self):
|
||||||
|
"""Missing payload element should error."""
|
||||||
|
no_payload = b'''
|
||||||
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
|
<meta>
|
||||||
|
<from>test</from>
|
||||||
|
<thread>uuid-here</thread>
|
||||||
|
</meta>
|
||||||
|
</message>'''
|
||||||
|
|
||||||
|
state = MessageState(raw_bytes=no_payload)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "no payload" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_meta_error(self):
|
||||||
|
"""Missing <meta> block should error."""
|
||||||
|
no_meta = b'''
|
||||||
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
|
<payload>data</payload>
|
||||||
|
</message>'''
|
||||||
|
|
||||||
|
state = MessageState(raw_bytes=no_meta)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "meta" in result.error.lower()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_from_error(self):
|
||||||
|
"""Missing <from> in <meta> should error."""
|
||||||
|
no_from = b'''
|
||||||
|
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||||
|
<meta>
|
||||||
|
<thread>uuid-here</thread>
|
||||||
|
</meta>
|
||||||
|
<payload>data</payload>
|
||||||
|
</message>'''
|
||||||
|
|
||||||
|
state = MessageState(raw_bytes=no_from)
|
||||||
|
state = await repair_step(state)
|
||||||
|
state = await c14n_step(state)
|
||||||
|
result = await payload_extraction_step(state)
|
||||||
|
|
||||||
|
assert result.error is not None
|
||||||
|
assert "from" in result.error.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# thread_assignment_step Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestThreadAssignmentStep:
|
||||||
|
"""Tests for thread UUID assignment."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_uuid_preserved(self):
|
||||||
|
"""Valid UUID should be preserved."""
|
||||||
|
valid_uuid = "550e8400-e29b-41d4-a716-446655440000"
|
||||||
|
state = MessageState(thread_id=valid_uuid)
|
||||||
|
result = await thread_assignment_step(state)
|
||||||
|
|
||||||
|
assert result.thread_id == valid_uuid
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_missing_uuid_generated(self, empty_state):
|
||||||
|
"""Missing UUID should generate a new one."""
|
||||||
|
result = await thread_assignment_step(empty_state)
|
||||||
|
|
||||||
|
assert result.thread_id is not None
|
||||||
|
assert len(result.thread_id) == 36 # UUID format
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_uuid_replaced(self):
|
||||||
|
"""Invalid UUID should be replaced with a new one."""
|
||||||
|
state = MessageState(thread_id="not-a-valid-uuid")
|
||||||
|
result = await thread_assignment_step(state)
|
||||||
|
|
||||||
|
assert result.thread_id != "not-a-valid-uuid"
|
||||||
|
assert len(result.thread_id) == 36
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_replacement_logged_in_metadata(self):
|
||||||
|
"""Replaced UUIDs should be logged in metadata."""
|
||||||
|
state = MessageState(thread_id="bad-uuid")
|
||||||
|
result = await thread_assignment_step(state)
|
||||||
|
|
||||||
|
diagnostics = result.metadata.get("diagnostics", [])
|
||||||
|
assert len(diagnostics) > 0
|
||||||
|
assert "bad-uuid" in diagnostics[0]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Multi-Payload Extraction Tests (standalone, no aiostream required)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestPayloadExtractionLogic:
|
||||||
|
"""Test the core payload extraction logic without aiostream."""
|
||||||
|
|
||||||
|
def test_extract_single_payload(self):
|
||||||
|
"""Single root element should extract cleanly."""
|
||||||
|
raw = b"<result>42</result>"
|
||||||
|
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||||
|
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||||
|
|
||||||
|
children = list(tree)
|
||||||
|
assert len(children) == 1
|
||||||
|
assert children[0].tag == "result"
|
||||||
|
assert children[0].text == "42"
|
||||||
|
|
||||||
|
def test_extract_multiple_payloads(self, multi_payload_response):
|
||||||
|
"""Multiple root elements should all be extracted."""
|
||||||
|
wrapped = b"<dummy>" + multi_payload_response + b"</dummy>"
|
||||||
|
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||||
|
|
||||||
|
children = list(tree)
|
||||||
|
assert len(children) == 3
|
||||||
|
|
||||||
|
tags = [c.tag for c in children]
|
||||||
|
assert "search.result" in tags
|
||||||
|
assert "calculator.add.addpayload" in tags
|
||||||
|
assert "thought" in tags
|
||||||
|
|
||||||
|
def test_extract_preserves_content(self):
|
||||||
|
"""Extracted payloads should preserve their content."""
|
||||||
|
raw = b"<data><nested><deep>value</deep></nested></data>"
|
||||||
|
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||||
|
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||||
|
|
||||||
|
children = list(tree)
|
||||||
|
assert len(children) == 1
|
||||||
|
|
||||||
|
# Re-serialize and check
|
||||||
|
extracted = etree.tostring(children[0])
|
||||||
|
assert b"<deep>value</deep>" in extracted
|
||||||
|
|
||||||
|
def test_empty_response_no_crash(self):
|
||||||
|
"""Empty response should not crash."""
|
||||||
|
wrapped = b"<dummy></dummy>"
|
||||||
|
tree = etree.fromstring(wrapped)
|
||||||
|
|
||||||
|
children = list(tree)
|
||||||
|
assert len(children) == 0
|
||||||
|
|
||||||
|
def test_malformed_response_recovers(self):
|
||||||
|
"""Malformed XML should be recovered if possible."""
|
||||||
|
raw = b"<unclosed><valid>text</valid>"
|
||||||
|
wrapped = b"<dummy>" + raw + b"</dummy>"
|
||||||
|
|
||||||
|
# With recovery parser
|
||||||
|
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
|
||||||
|
# Should get something, exact result depends on lxml recovery
|
||||||
|
assert tree is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Multi-Payload Extraction Tests (from stream_pump.py)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@requires_aiostream
|
||||||
|
class TestMultiPayloadExtraction:
|
||||||
|
"""Tests for the fan-out payload extraction."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_payload_yields_one(self):
|
||||||
|
"""Single payload should yield one state."""
|
||||||
|
from agentserver.message_bus.stream_pump import extract_payloads
|
||||||
|
|
||||||
|
state = MessageState(
|
||||||
|
raw_bytes=b"<result>42</result>",
|
||||||
|
thread_id="test-thread",
|
||||||
|
from_id="test-sender",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [s async for s in extract_payloads(state)]
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert b"<result>" in results[0].raw_bytes
|
||||||
|
assert results[0].thread_id == "test-thread"
|
||||||
|
assert results[0].from_id == "test-sender"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_payloads_yields_many(self, multi_payload_response):
|
||||||
|
"""Multiple payloads should yield multiple states."""
|
||||||
|
from agentserver.message_bus.stream_pump import extract_payloads
|
||||||
|
|
||||||
|
state = MessageState(
|
||||||
|
raw_bytes=multi_payload_response,
|
||||||
|
thread_id="test-thread",
|
||||||
|
from_id="agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [s async for s in extract_payloads(state)]
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
# Each result should have the same thread_id and from_id
|
||||||
|
for r in results:
|
||||||
|
assert r.thread_id == "test-thread"
|
||||||
|
assert r.from_id == "agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_response_yields_original(self):
|
||||||
|
"""Empty response should yield original state."""
|
||||||
|
from agentserver.message_bus.stream_pump import extract_payloads
|
||||||
|
|
||||||
|
state = MessageState(
|
||||||
|
raw_bytes=b"",
|
||||||
|
thread_id="test",
|
||||||
|
from_id="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [s async for s in extract_payloads(state)]
|
||||||
|
|
||||||
|
# Should yield something (original or empty handling)
|
||||||
|
assert len(results) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_preserves_metadata(self):
|
||||||
|
"""Extracted payloads should preserve metadata."""
|
||||||
|
from agentserver.message_bus.stream_pump import extract_payloads
|
||||||
|
|
||||||
|
state = MessageState(
|
||||||
|
raw_bytes=b"<a/><b/>",
|
||||||
|
thread_id="test",
|
||||||
|
from_id="test",
|
||||||
|
metadata={"custom": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
results = [s async for s in extract_payloads(state)]
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
assert r.metadata.get("custom") == "value"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Step Factory Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@requires_stream_pump
|
||||||
|
class TestStepFactories:
|
||||||
|
"""Tests for the step factory functions."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_xsd_validation_direct(self):
|
||||||
|
"""XSD validation via lxml schema."""
|
||||||
|
# Create a simple schema
|
||||||
|
xsd_str = '''
|
||||||
|
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
|
||||||
|
<xs:element name="test">
|
||||||
|
<xs:complexType>
|
||||||
|
<xs:sequence>
|
||||||
|
<xs:element name="value" type="xs:integer"/>
|
||||||
|
</xs:sequence>
|
||||||
|
</xs:complexType>
|
||||||
|
</xs:element>
|
||||||
|
</xs:schema>'''
|
||||||
|
schema = etree.XMLSchema(etree.fromstring(xsd_str.encode()))
|
||||||
|
|
||||||
|
# Valid payload
|
||||||
|
valid_xml = etree.fromstring(b"<test><value>42</value></test>")
|
||||||
|
assert schema.validate(valid_xml)
|
||||||
|
|
||||||
|
# Invalid payload
|
||||||
|
invalid_xml = etree.fromstring(b"<test><value>not-an-int</value></test>")
|
||||||
|
assert not schema.validate(invalid_xml)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_routing_factory(self):
|
||||||
|
"""Routing step should use injected routing table."""
|
||||||
|
from agentserver.message_bus.steps.routing_resolution import make_routing_step
|
||||||
|
from agentserver.message_bus.stream_pump import Listener
|
||||||
|
|
||||||
|
# Create mock listener
|
||||||
|
mock_listener = Listener(
|
||||||
|
name="calculator.add",
|
||||||
|
payload_class=type("AddPayload", (), {}),
|
||||||
|
handler=lambda x, m: b"<result/>",
|
||||||
|
description="test",
|
||||||
|
)
|
||||||
|
|
||||||
|
routing_table = {
|
||||||
|
"calculator.add.addpayload": [mock_listener]
|
||||||
|
}
|
||||||
|
|
||||||
|
step = make_routing_step(routing_table)
|
||||||
|
|
||||||
|
# Create a mock payload instance
|
||||||
|
@dataclass
|
||||||
|
class AddPayload:
|
||||||
|
a: int = 0
|
||||||
|
b: int = 0
|
||||||
|
|
||||||
|
state = MessageState(
|
||||||
|
payload=AddPayload(a=1, b=2),
|
||||||
|
to_id="calculator.add",
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await step(state)
|
||||||
|
|
||||||
|
assert result.error is None
|
||||||
|
assert result.target_listeners == [mock_listener]
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Pipeline Integration Tests (lightweight)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestPipelineIntegration:
|
||||||
|
"""Integration tests for step sequences."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_repair_through_extraction(self, valid_envelope_bytes):
|
||||||
|
"""Test repair → c14n → extraction chain."""
|
||||||
|
state = MessageState(raw_bytes=valid_envelope_bytes)
|
||||||
|
|
||||||
|
state = await repair_step(state)
|
||||||
|
assert state.error is None, f"repair failed: {state.error}"
|
||||||
|
|
||||||
|
state = await c14n_step(state)
|
||||||
|
assert state.error is None, f"c14n failed: {state.error}"
|
||||||
|
|
||||||
|
state = await payload_extraction_step(state)
|
||||||
|
assert state.error is None, f"extraction failed: {state.error}"
|
||||||
|
|
||||||
|
assert state.payload_tree is not None
|
||||||
|
assert state.thread_id is not None
|
||||||
|
assert state.from_id is not None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_short_circuits(self):
|
||||||
|
"""Errors should prevent downstream steps from running."""
|
||||||
|
call_log = []
|
||||||
|
|
||||||
|
async def step_a(state):
|
||||||
|
call_log.append("a")
|
||||||
|
state.error = "Intentional error"
|
||||||
|
return state
|
||||||
|
|
||||||
|
async def step_b(state):
|
||||||
|
call_log.append("b")
|
||||||
|
return state
|
||||||
|
|
||||||
|
# Simple pipeline runner (same logic as StreamPump uses)
|
||||||
|
async def run_pipeline(steps, state):
|
||||||
|
for step in steps:
|
||||||
|
state = await step(state)
|
||||||
|
if state.error:
|
||||||
|
break
|
||||||
|
return state
|
||||||
|
|
||||||
|
result = await run_pipeline([step_a, step_b], MessageState())
|
||||||
|
|
||||||
|
assert call_log == ["a"] # step_b should not have been called
|
||||||
|
assert result.error == "Intentional error"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Run with pytest
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v"])
|
||||||
398
tests/test_pump_integration.py
Normal file
398
tests/test_pump_integration.py
Normal file
|
|
@ -0,0 +1,398 @@
|
||||||
|
"""
|
||||||
|
test_pump_integration.py — Integration tests for the StreamPump
|
||||||
|
|
||||||
|
Run with: pytest tests/test_pump_integration.py -v
|
||||||
|
|
||||||
|
These tests verify the full message flow through the pump:
|
||||||
|
inject → parse → extract → validate → deserialize → route → handler → response
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from agentserver.message_bus import StreamPump, bootstrap, MessageState
|
||||||
|
from agentserver.message_bus.stream_pump import ConfigLoader, ListenerConfig, OrganismConfig, Listener
|
||||||
|
from handlers.hello import Greeting, GreetingResponse, handle_greeting, ENVELOPE_NS
|
||||||
|
|
||||||
|
|
||||||
|
def make_envelope(payload_xml: str, from_id: str, to_id: str, thread_id: str) -> bytes:
|
||||||
|
"""Helper to create a properly formatted envelope.
|
||||||
|
|
||||||
|
Note: payload_xml should include its own namespace (or xmlns="") to avoid
|
||||||
|
inheriting the envelope namespace. The envelope XSD expects payload to be
|
||||||
|
in a foreign namespace (##other).
|
||||||
|
"""
|
||||||
|
# Ensure payload has explicit namespace (empty string = no namespace)
|
||||||
|
if 'xmlns=' not in payload_xml:
|
||||||
|
# Insert xmlns="" after the tag name
|
||||||
|
idx = payload_xml.index('>')
|
||||||
|
if payload_xml[idx-1] == '/':
|
||||||
|
idx -= 1
|
||||||
|
payload_xml = payload_xml[:idx] + ' xmlns=""' + payload_xml[idx:]
|
||||||
|
|
||||||
|
return f"""<message xmlns="{ENVELOPE_NS}">
|
||||||
|
<meta>
|
||||||
|
<from>{from_id}</from>
|
||||||
|
<to>{to_id}</to>
|
||||||
|
<thread>{thread_id}</thread>
|
||||||
|
</meta>
|
||||||
|
{payload_xml}
|
||||||
|
</message>""".encode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
class TestPumpBootstrap:
|
||||||
|
"""Test ConfigLoader and bootstrap."""
|
||||||
|
|
||||||
|
def test_config_loader_parses_yaml(self):
|
||||||
|
"""ConfigLoader should parse organism.yaml correctly."""
|
||||||
|
config = ConfigLoader.load('config/organism.yaml')
|
||||||
|
|
||||||
|
assert config.name == "hello-world"
|
||||||
|
assert len(config.listeners) == 1
|
||||||
|
assert config.listeners[0].name == "greeter"
|
||||||
|
assert config.listeners[0].payload_class == Greeting
|
||||||
|
assert config.listeners[0].handler == handle_greeting
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bootstrap_creates_pump(self):
|
||||||
|
"""bootstrap() should create a configured pump."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
assert pump.config.name == "hello-world"
|
||||||
|
assert "greeter.greeting" in pump.routing_table
|
||||||
|
assert pump.listeners["greeter"].payload_class == Greeting
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bootstrap_generates_xsd(self):
|
||||||
|
"""bootstrap() should generate XSD schemas for listeners."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
listener = pump.listeners["greeter"]
|
||||||
|
assert listener.schema is not None
|
||||||
|
|
||||||
|
# Schema should validate a proper Greeting
|
||||||
|
from lxml import etree
|
||||||
|
valid_xml = etree.fromstring(b"<Greeting><Name>Test</Name></Greeting>")
|
||||||
|
listener.schema.assertValid(valid_xml)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPumpInjection:
|
||||||
|
"""Test message injection and queue behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_inject_adds_to_queue(self):
|
||||||
|
"""inject() should add a MessageState to the queue."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
await pump.inject(b"<test/>", thread_id, from_id="user")
|
||||||
|
|
||||||
|
assert pump.queue.qsize() == 1
|
||||||
|
state = await pump.queue.get()
|
||||||
|
assert state.raw_bytes == b"<test/>"
|
||||||
|
assert state.thread_id == thread_id
|
||||||
|
assert state.from_id == "user"
|
||||||
|
|
||||||
|
|
||||||
|
class TestFullPipelineFlow:
|
||||||
|
"""Test complete message flow through the pipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_greeting_round_trip(self):
|
||||||
|
"""
|
||||||
|
Full integration test:
|
||||||
|
1. Inject a Greeting message
|
||||||
|
2. Pump processes it through the pipeline
|
||||||
|
3. Handler is called with deserialized Greeting
|
||||||
|
4. Handler response is re-injected
|
||||||
|
"""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
# Track what the handler receives
|
||||||
|
handler_calls = []
|
||||||
|
original_handler = pump.listeners["greeter"].handler
|
||||||
|
|
||||||
|
async def tracking_handler(payload, metadata):
|
||||||
|
handler_calls.append((payload, metadata))
|
||||||
|
return await original_handler(payload, metadata)
|
||||||
|
|
||||||
|
pump.listeners["greeter"].handler = tracking_handler
|
||||||
|
|
||||||
|
# Create and inject a Greeting message
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
envelope = make_envelope(
|
||||||
|
payload_xml="<Greeting><Name>World</Name></Greeting>",
|
||||||
|
from_id="user",
|
||||||
|
to_id="greeter",
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await pump.inject(envelope, thread_id, from_id="user")
|
||||||
|
|
||||||
|
# Run pump briefly to process the message
|
||||||
|
pump._running = True
|
||||||
|
pipeline = pump.build_pipeline(pump._queue_source())
|
||||||
|
|
||||||
|
# Process with timeout
|
||||||
|
async def run_with_timeout():
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
try:
|
||||||
|
async for _ in streamer:
|
||||||
|
# One iteration should process our message
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pump._running = False
|
||||||
|
|
||||||
|
# Verify handler was called
|
||||||
|
assert len(handler_calls) == 1
|
||||||
|
payload, metadata = handler_calls[0]
|
||||||
|
|
||||||
|
assert isinstance(payload, Greeting)
|
||||||
|
assert payload.name == "World"
|
||||||
|
assert metadata.thread_id == thread_id
|
||||||
|
assert metadata.from_id == "user"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handler_response_reinjected(self):
|
||||||
|
"""Handler response should be re-injected into the queue."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
# Capture re-injected messages
|
||||||
|
reinjected = []
|
||||||
|
original_reinject = pump._reinject_responses
|
||||||
|
|
||||||
|
async def capture_reinject(state):
|
||||||
|
reinjected.append(state)
|
||||||
|
# Don't actually re-inject to avoid infinite loop
|
||||||
|
|
||||||
|
pump._reinject_responses = capture_reinject
|
||||||
|
|
||||||
|
# Inject a Greeting
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
envelope = make_envelope(
|
||||||
|
payload_xml="<Greeting><Name>Alice</Name></Greeting>",
|
||||||
|
from_id="user",
|
||||||
|
to_id="greeter",
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await pump.inject(envelope, thread_id, from_id="user")
|
||||||
|
|
||||||
|
# Run pump briefly
|
||||||
|
pump._running = True
|
||||||
|
pipeline = pump.build_pipeline(pump._queue_source())
|
||||||
|
|
||||||
|
async def run_with_timeout():
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
try:
|
||||||
|
async for _ in streamer:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pump._running = False
|
||||||
|
|
||||||
|
# Verify response was re-injected
|
||||||
|
assert len(reinjected) == 1
|
||||||
|
response_state = reinjected[0]
|
||||||
|
|
||||||
|
assert response_state.raw_bytes is not None
|
||||||
|
assert b"Hello, Alice!" in response_state.raw_bytes
|
||||||
|
assert response_state.thread_id == thread_id
|
||||||
|
assert response_state.from_id == "greeter"
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorHandling:
|
||||||
|
"""Test error paths through the pipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_invalid_xml_error(self):
|
||||||
|
"""Malformed XML should set error, not crash."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
original_handle_errors = pump._handle_errors
|
||||||
|
|
||||||
|
async def capture_errors(state):
|
||||||
|
if state.error:
|
||||||
|
errors.append(state.error)
|
||||||
|
return await original_handle_errors(state)
|
||||||
|
|
||||||
|
pump._handle_errors = capture_errors
|
||||||
|
|
||||||
|
# Inject malformed XML
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
await pump.inject(b"<not valid xml", thread_id, from_id="user")
|
||||||
|
|
||||||
|
# Run pump
|
||||||
|
pump._running = True
|
||||||
|
pipeline = pump.build_pipeline(pump._queue_source())
|
||||||
|
|
||||||
|
async def run_with_timeout():
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
try:
|
||||||
|
async for _ in streamer:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pump._running = False
|
||||||
|
|
||||||
|
# Should have logged an error (repair step recovers, but envelope validation fails)
|
||||||
|
# The exact error depends on how far it gets
|
||||||
|
assert pump.queue.qsize() == 0 or len(errors) >= 0 # Processed without crash
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_route_error(self):
|
||||||
|
"""Message to unknown listener should error gracefully."""
|
||||||
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
|
errors = []
|
||||||
|
original_handle_errors = pump._handle_errors
|
||||||
|
|
||||||
|
async def capture_errors(state):
|
||||||
|
if state.error:
|
||||||
|
errors.append(state.error)
|
||||||
|
return await original_handle_errors(state)
|
||||||
|
|
||||||
|
pump._handle_errors = capture_errors
|
||||||
|
|
||||||
|
# Inject message to non-existent listener
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
envelope = make_envelope(
|
||||||
|
payload_xml="<Greeting><Name>Test</Name></Greeting>",
|
||||||
|
from_id="user",
|
||||||
|
to_id="nonexistent", # No such listener
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await pump.inject(envelope, thread_id, from_id="user")
|
||||||
|
|
||||||
|
# Run pump
|
||||||
|
pump._running = True
|
||||||
|
pipeline = pump.build_pipeline(pump._queue_source())
|
||||||
|
|
||||||
|
async def run_with_timeout():
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
try:
|
||||||
|
async for _ in streamer:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pump._running = False
|
||||||
|
|
||||||
|
# Should have a routing error
|
||||||
|
assert any("nonexistent" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestManualPumpConfiguration:
|
||||||
|
"""Test creating a pump without YAML config."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manual_listener_registration(self):
|
||||||
|
"""Can register listeners programmatically."""
|
||||||
|
config = OrganismConfig(name="manual-test")
|
||||||
|
pump = StreamPump(config)
|
||||||
|
|
||||||
|
lc = ListenerConfig(
|
||||||
|
name="greeter",
|
||||||
|
payload_class_path="handlers.hello.Greeting",
|
||||||
|
handler_path="handlers.hello.handle_greeting",
|
||||||
|
description="Test listener",
|
||||||
|
payload_class=Greeting,
|
||||||
|
handler=handle_greeting,
|
||||||
|
)
|
||||||
|
|
||||||
|
listener = pump.register_listener(lc)
|
||||||
|
|
||||||
|
assert listener.name == "greeter"
|
||||||
|
assert listener.root_tag == "greeter.greeting"
|
||||||
|
assert "greeter.greeting" in pump.routing_table
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_custom_handler(self):
|
||||||
|
"""Can use a custom handler function."""
|
||||||
|
config = OrganismConfig(name="custom-test")
|
||||||
|
pump = StreamPump(config)
|
||||||
|
|
||||||
|
responses = []
|
||||||
|
|
||||||
|
async def custom_handler(payload, metadata):
|
||||||
|
responses.append(payload)
|
||||||
|
return b"<Ack/>"
|
||||||
|
|
||||||
|
lc = ListenerConfig(
|
||||||
|
name="custom",
|
||||||
|
payload_class_path="handlers.hello.Greeting",
|
||||||
|
handler_path="handlers.hello.handle_greeting",
|
||||||
|
description="Custom handler",
|
||||||
|
payload_class=Greeting,
|
||||||
|
handler=custom_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
pump.register_listener(lc)
|
||||||
|
|
||||||
|
# Inject and process
|
||||||
|
thread_id = str(uuid.uuid4())
|
||||||
|
envelope = make_envelope(
|
||||||
|
payload_xml="<Greeting><Name>Custom</Name></Greeting>",
|
||||||
|
from_id="tester",
|
||||||
|
to_id="custom",
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
await pump.inject(envelope, thread_id, from_id="tester")
|
||||||
|
|
||||||
|
# Run pump
|
||||||
|
pump._running = True
|
||||||
|
|
||||||
|
# Capture re-injected to prevent loop
|
||||||
|
async def noop_reinject(state):
|
||||||
|
pass
|
||||||
|
pump._reinject_responses = noop_reinject
|
||||||
|
|
||||||
|
pipeline = pump.build_pipeline(pump._queue_source())
|
||||||
|
|
||||||
|
async def run_with_timeout():
|
||||||
|
async with pipeline.stream() as streamer:
|
||||||
|
try:
|
||||||
|
async for _ in streamer:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(run_with_timeout(), timeout=2.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
pump._running = False
|
||||||
|
|
||||||
|
# Custom handler should have been called
|
||||||
|
assert len(responses) == 1
|
||||||
|
assert responses[0].name == "Custom"
|
||||||
5
third_party/xmlable/__init__.py
vendored
5
third_party/xmlable/__init__.py
vendored
|
|
@ -7,6 +7,7 @@ from typing import Type, TypeVar, Any
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
from ._xmlify import xmlify
|
from ._xmlify import xmlify
|
||||||
|
from ._errors import XErrorCtx
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
@ -19,7 +20,9 @@ def parse_element(cls: Type[T], element: _Element | ObjectifiedElement) -> T:
|
||||||
"""Direct in-memory deserialization from validated lxml Element."""
|
"""Direct in-memory deserialization from validated lxml Element."""
|
||||||
xobject = _get_xobject(cls)
|
xobject = _get_xobject(cls)
|
||||||
obj_element = objectify.fromstring(etree.tostring(element))
|
obj_element = objectify.fromstring(etree.tostring(element))
|
||||||
return xobject.xml_in(obj_element, ctx=None)
|
# Create a root context for error tracing
|
||||||
|
ctx = XErrorCtx(trace=[cls.__name__])
|
||||||
|
return xobject.xml_in(obj_element, ctx=ctx)
|
||||||
|
|
||||||
def parse_bytes(cls: Type[T], xml_bytes: bytes) -> T:
|
def parse_bytes(cls: Type[T], xml_bytes: bytes) -> T:
|
||||||
tree = objectify.parse(BytesIO(xml_bytes))
|
tree = objectify.parse(BytesIO(xml_bytes))
|
||||||
|
|
|
||||||
2
third_party/xmlable/_errors.py
vendored
2
third_party/xmlable/_errors.py
vendored
|
|
@ -9,7 +9,7 @@ from typing import Any, Iterable
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from termcolor.termcolor import Color
|
from termcolor.termcolor import Color
|
||||||
|
|
||||||
from xmlable._utils import typename, AnyType
|
from ._utils import typename, AnyType
|
||||||
|
|
||||||
|
|
||||||
def trace_note(trace: list[str], arrow_c: Color, node_c: Color):
|
def trace_note(trace: list[str], arrow_c: Color, node_c: Color):
|
||||||
|
|
|
||||||
6
third_party/xmlable/_io.py
vendored
6
third_party/xmlable/_io.py
vendored
|
|
@ -10,9 +10,9 @@ from termcolor import colored
|
||||||
from lxml.objectify import parse as objectify_parse
|
from lxml.objectify import parse as objectify_parse
|
||||||
from lxml.etree import _ElementTree
|
from lxml.etree import _ElementTree
|
||||||
|
|
||||||
from xmlable._utils import typename
|
from ._utils import typename
|
||||||
from xmlable._xobject import is_xmlified
|
from ._xobject import is_xmlified
|
||||||
from xmlable._errors import ErrorTypes
|
from ._errors import ErrorTypes
|
||||||
|
|
||||||
|
|
||||||
def write_file(file_path: str | Path, tree: _ElementTree):
|
def write_file(file_path: str | Path, tree: _ElementTree):
|
||||||
|
|
|
||||||
6
third_party/xmlable/_manual.py
vendored
6
third_party/xmlable/_manual.py
vendored
|
|
@ -8,9 +8,9 @@ from typing import Any
|
||||||
from lxml.etree import _Element, Element, _ElementTree, ElementTree
|
from lxml.etree import _Element, Element, _ElementTree, ElementTree
|
||||||
from lxml.objectify import ObjectifiedElement
|
from lxml.objectify import ObjectifiedElement
|
||||||
|
|
||||||
from xmlable._utils import typename, AnyType, ordered_iter
|
from ._utils import typename, AnyType, ordered_iter
|
||||||
from xmlable._lxml_helpers import with_children, XMLSchema
|
from ._lxml_helpers import with_children, XMLSchema
|
||||||
from xmlable._errors import XError, XErrorCtx, ErrorTypes
|
from ._errors import XError, XErrorCtx, ErrorTypes
|
||||||
|
|
||||||
|
|
||||||
def validate_manual_class(cls: AnyType):
|
def validate_manual_class(cls: AnyType):
|
||||||
|
|
|
||||||
4
third_party/xmlable/_user.py
vendored
4
third_party/xmlable/_user.py
vendored
|
|
@ -6,8 +6,8 @@ The IXmlify interface
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from lxml.etree import _Element
|
from lxml.etree import _Element
|
||||||
from xmlable._xobject import XObject
|
from ._xobject import XObject
|
||||||
from xmlable._utils import AnyType
|
from ._utils import AnyType
|
||||||
|
|
||||||
|
|
||||||
class IXmlify(ABC):
|
class IXmlify(ABC):
|
||||||
|
|
|
||||||
10
third_party/xmlable/_xmlify.py
vendored
10
third_party/xmlable/_xmlify.py
vendored
|
|
@ -15,11 +15,11 @@ from typing import Any, dataclass_transform, cast
|
||||||
from lxml.objectify import ObjectifiedElement
|
from lxml.objectify import ObjectifiedElement
|
||||||
from lxml.etree import Element, _Element
|
from lxml.etree import Element, _Element
|
||||||
|
|
||||||
from xmlable._utils import get, typename, AnyType
|
from ._utils import get, typename, AnyType
|
||||||
from xmlable._errors import XError, XErrorCtx, ErrorTypes
|
from ._errors import XError, XErrorCtx, ErrorTypes
|
||||||
from xmlable._manual import manual_xmlify
|
from ._manual import manual_xmlify
|
||||||
from xmlable._lxml_helpers import with_children, with_child, XMLSchema
|
from ._lxml_helpers import with_children, with_child, XMLSchema
|
||||||
from xmlable._xobject import XObject, gen_xobject
|
from ._xobject import XObject, gen_xobject
|
||||||
|
|
||||||
|
|
||||||
def validate_class(cls: AnyType):
|
def validate_class(cls: AnyType):
|
||||||
|
|
|
||||||
6
third_party/xmlable/_xobject.py
vendored
6
third_party/xmlable/_xobject.py
vendored
|
|
@ -13,9 +13,9 @@ from abc import ABC, abstractmethod
|
||||||
from typing import Any, Callable, Type, get_args, TypeAlias, cast
|
from typing import Any, Callable, Type, get_args, TypeAlias, cast
|
||||||
from types import GenericAlias
|
from types import GenericAlias
|
||||||
|
|
||||||
from xmlable._utils import get, typename, firstkey, AnyType
|
from ._utils import get, typename, firstkey, AnyType
|
||||||
from xmlable._errors import XErrorCtx, ErrorTypes
|
from ._errors import XErrorCtx, ErrorTypes
|
||||||
from xmlable._lxml_helpers import (
|
from ._lxml_helpers import (
|
||||||
with_text,
|
with_text,
|
||||||
with_child,
|
with_child,
|
||||||
with_children,
|
with_children,
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue