"""
test_pipeline_steps.py — Unit tests for individual pipeline steps
Run with: pytest tests/test_pipeline_steps.py -v
Each step is tested in isolation with known inputs and expected outputs.
This makes debugging much easier than testing the full pipeline.
Install test dependencies:
pip install -e ".[test]"
"""
import pytest
import asyncio
from dataclasses import dataclass
from lxml import etree
# Import the message state
from xml_pipeline.message_bus.message_state import MessageState, HandlerMetadata
# Import individual steps
from xml_pipeline.message_bus.steps.repair import repair_step
from xml_pipeline.message_bus.steps.c14n import c14n_step
from xml_pipeline.message_bus.steps.envelope_validation import envelope_validation_step
from xml_pipeline.message_bus.steps.payload_extraction import payload_extraction_step
from xml_pipeline.message_bus.steps.thread_assignment import thread_assignment_step
# Check for optional dependencies
try:
import aiostream
HAS_AIOSTREAM = True
except ImportError:
HAS_AIOSTREAM = False
requires_aiostream = pytest.mark.skipif(
not HAS_AIOSTREAM,
reason="aiostream not installed (pip install aiostream)"
)
# Check for stream_pump dependencies
try:
from xml_pipeline.message_bus.stream_pump import StreamPump, Listener
from xml_pipeline.message_bus.steps.routing_resolution import make_routing_step
HAS_STREAM_PUMP = True
except ImportError:
HAS_STREAM_PUMP = False
requires_stream_pump = pytest.mark.skipif(
not HAS_STREAM_PUMP,
reason="stream_pump dependencies not available"
)
# ============================================================================
# Test Fixtures
# ============================================================================
@pytest.fixture
def valid_envelope_bytes():
"""A well-formed message envelope matching envelope.xsd."""
return b'''
calculator.add
550e8400-e29b-41d4-a716-446655440000
5
3
'''
@pytest.fixture
def malformed_xml_bytes():
"""Malformed XML that lxml can partially recover."""
return b'content'
@pytest.fixture
def completely_broken_bytes():
"""Not XML at all."""
return b'this is not xml at all { json: "maybe" }'
@pytest.fixture
def multi_payload_response():
"""Handler response with multiple payloads."""
return b'''
42
12
I should also check...
'''
@pytest.fixture
def empty_state():
"""Fresh MessageState with no data."""
return MessageState()
@pytest.fixture
def state_with_bytes(valid_envelope_bytes):
"""MessageState with raw_bytes populated."""
return MessageState(raw_bytes=valid_envelope_bytes)
# ============================================================================
# repair_step Tests
# ============================================================================
class TestRepairStep:
"""Tests for the XML repair/recovery step."""
@pytest.mark.asyncio
async def test_valid_xml_passes_through(self, valid_envelope_bytes):
"""Valid XML should parse without error."""
state = MessageState(raw_bytes=valid_envelope_bytes)
result = await repair_step(state)
assert result.error is None
assert result.envelope_tree is not None
assert result.envelope_tree.tag == "{https://xml-pipeline.org/ns/envelope/v1}message"
@pytest.mark.asyncio
async def test_malformed_xml_recovered(self, malformed_xml_bytes):
"""Malformed XML should be recovered if possible."""
state = MessageState(raw_bytes=malformed_xml_bytes)
result = await repair_step(state)
# lxml recovery mode should produce something
# May or may not have error depending on severity
assert result.envelope_tree is not None or result.error is not None
@pytest.mark.asyncio
async def test_no_bytes_sets_error(self, empty_state):
"""Missing raw_bytes should set an error."""
result = await repair_step(empty_state)
assert result.error is not None
assert "no raw_bytes" in result.error
@pytest.mark.asyncio
async def test_clears_raw_bytes_after_parse(self, valid_envelope_bytes):
"""raw_bytes should be cleared after successful parse (memory optimization)."""
state = MessageState(raw_bytes=valid_envelope_bytes)
result = await repair_step(state)
assert result.raw_bytes is None
assert result.envelope_tree is not None
# ============================================================================
# c14n_step Tests
# ============================================================================
class TestC14nStep:
"""Tests for the canonicalization step."""
@pytest.mark.asyncio
async def test_normalizes_whitespace(self):
"""C14N should normalize whitespace."""
xml_with_whitespace = b'''
value
'''
state = MessageState(raw_bytes=xml_with_whitespace)
state = await repair_step(state)
result = await c14n_step(state)
assert result.error is None
assert result.envelope_tree is not None
@pytest.mark.asyncio
async def test_normalizes_attribute_order(self):
"""C14N should produce consistent attribute ordering."""
xml_a = b''
xml_b = b''
state_a = MessageState(raw_bytes=xml_a)
state_a = await repair_step(state_a)
state_a = await c14n_step(state_a)
state_b = MessageState(raw_bytes=xml_b)
state_b = await repair_step(state_b)
state_b = await c14n_step(state_b)
# Both should produce identical canonical form
c14n_a = etree.tostring(state_a.envelope_tree, method="c14n")
c14n_b = etree.tostring(state_b.envelope_tree, method="c14n")
assert c14n_a == c14n_b
@pytest.mark.asyncio
async def test_no_tree_sets_error(self, empty_state):
"""Missing envelope_tree should set error."""
result = await c14n_step(empty_state)
assert result.error is not None
assert "no envelope_tree" in result.error
# ============================================================================
# payload_extraction_step Tests
# ============================================================================
class TestPayloadExtractionStep:
"""Tests for extracting payload from envelope."""
@pytest.mark.asyncio
async def test_extracts_payload_element(self, valid_envelope_bytes):
"""Should extract the payload element from envelope."""
state = MessageState(raw_bytes=valid_envelope_bytes)
state = await repair_step(state)
state = await c14n_step(state)
# Skip envelope validation for this test
result = await payload_extraction_step(state)
assert result.error is None
assert result.payload_tree is not None
# Tag may include namespace prefix
assert "addpayload" in result.payload_tree.tag
@pytest.mark.asyncio
async def test_extracts_thread_id(self, valid_envelope_bytes):
"""Should extract thread ID from envelope."""
state = MessageState(raw_bytes=valid_envelope_bytes)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.thread_id == "550e8400-e29b-41d4-a716-446655440000"
@pytest.mark.asyncio
async def test_extracts_from_id(self, valid_envelope_bytes):
"""Should extract sender ID from envelope."""
state = MessageState(raw_bytes=valid_envelope_bytes)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.from_id == "calculator.add"
@pytest.mark.asyncio
async def test_multiple_payloads_error(self):
"""Multiple payload elements should error."""
multi_payload = b'''
test
uuid-here
data
more data
'''
state = MessageState(raw_bytes=multi_payload)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.error is not None
assert "multiple payload" in result.error.lower()
@pytest.mark.asyncio
async def test_no_payload_error(self):
"""Missing payload element should error."""
no_payload = b'''
test
uuid-here
'''
state = MessageState(raw_bytes=no_payload)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.error is not None
assert "no payload" in result.error.lower()
@pytest.mark.asyncio
async def test_missing_meta_error(self):
"""Missing block should error."""
no_meta = b'''
data
'''
state = MessageState(raw_bytes=no_meta)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.error is not None
assert "meta" in result.error.lower()
@pytest.mark.asyncio
async def test_missing_from_error(self):
"""Missing in should error."""
no_from = b'''
uuid-here
data
'''
state = MessageState(raw_bytes=no_from)
state = await repair_step(state)
state = await c14n_step(state)
result = await payload_extraction_step(state)
assert result.error is not None
assert "from" in result.error.lower()
# ============================================================================
# thread_assignment_step Tests
# ============================================================================
class TestThreadAssignmentStep:
"""Tests for thread UUID assignment."""
@pytest.mark.asyncio
async def test_valid_uuid_preserved(self):
"""Valid UUID should be preserved."""
valid_uuid = "550e8400-e29b-41d4-a716-446655440000"
state = MessageState(thread_id=valid_uuid)
result = await thread_assignment_step(state)
assert result.thread_id == valid_uuid
@pytest.mark.asyncio
async def test_missing_uuid_generated(self, empty_state):
"""Missing UUID should generate a new one."""
result = await thread_assignment_step(empty_state)
assert result.thread_id is not None
assert len(result.thread_id) == 36 # UUID format
@pytest.mark.asyncio
async def test_invalid_uuid_replaced(self):
"""Invalid UUID should be replaced with a new one."""
state = MessageState(thread_id="not-a-valid-uuid")
result = await thread_assignment_step(state)
assert result.thread_id != "not-a-valid-uuid"
assert len(result.thread_id) == 36
@pytest.mark.asyncio
async def test_replacement_logged_in_metadata(self):
"""Replaced UUIDs should be logged in metadata."""
state = MessageState(thread_id="bad-uuid")
result = await thread_assignment_step(state)
diagnostics = result.metadata.get("diagnostics", [])
assert len(diagnostics) > 0
assert "bad-uuid" in diagnostics[0]
# ============================================================================
# Multi-Payload Extraction Tests (standalone, no aiostream required)
# ============================================================================
class TestPayloadExtractionLogic:
"""Test the core payload extraction logic without aiostream."""
def test_extract_single_payload(self):
"""Single root element should extract cleanly."""
raw = b"42"
wrapped = b"" + raw + b""
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
children = list(tree)
assert len(children) == 1
assert children[0].tag == "result"
assert children[0].text == "42"
def test_extract_multiple_payloads(self, multi_payload_response):
"""Multiple root elements should all be extracted."""
wrapped = b"" + multi_payload_response + b""
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
children = list(tree)
assert len(children) == 3
tags = [c.tag for c in children]
assert "search.result" in tags
assert "calculator.add.addpayload" in tags
assert "thought" in tags
def test_extract_preserves_content(self):
"""Extracted payloads should preserve their content."""
raw = b"value"
wrapped = b"" + raw + b""
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
children = list(tree)
assert len(children) == 1
# Re-serialize and check
extracted = etree.tostring(children[0])
assert b"value" in extracted
def test_empty_response_no_crash(self):
"""Empty response should not crash."""
wrapped = b""
tree = etree.fromstring(wrapped)
children = list(tree)
assert len(children) == 0
def test_malformed_response_recovers(self):
"""Malformed XML should be recovered if possible."""
raw = b"text"
wrapped = b"" + raw + b""
# With recovery parser
tree = etree.fromstring(wrapped, parser=etree.XMLParser(recover=True))
# Should get something, exact result depends on lxml recovery
assert tree is not None
# ============================================================================
# Multi-Payload Extraction Tests (from stream_pump.py)
# ============================================================================
@requires_aiostream
class TestMultiPayloadExtraction:
"""Tests for the fan-out payload extraction."""
@pytest.mark.asyncio
async def test_single_payload_yields_one(self):
"""Single payload should yield one state."""
from xml_pipeline.message_bus.stream_pump import extract_payloads
state = MessageState(
raw_bytes=b"42",
thread_id="test-thread",
from_id="test-sender",
)
results = [s async for s in extract_payloads(state)]
assert len(results) == 1
assert b"" in results[0].raw_bytes
assert results[0].thread_id == "test-thread"
assert results[0].from_id == "test-sender"
@pytest.mark.asyncio
async def test_multiple_payloads_yields_many(self, multi_payload_response):
"""Multiple payloads should yield multiple states."""
from xml_pipeline.message_bus.stream_pump import extract_payloads
state = MessageState(
raw_bytes=multi_payload_response,
thread_id="test-thread",
from_id="agent",
)
results = [s async for s in extract_payloads(state)]
assert len(results) == 3
# Each result should have the same thread_id and from_id
for r in results:
assert r.thread_id == "test-thread"
assert r.from_id == "agent"
@pytest.mark.asyncio
async def test_empty_response_yields_original(self):
"""Empty response should yield original state."""
from xml_pipeline.message_bus.stream_pump import extract_payloads
state = MessageState(
raw_bytes=b"",
thread_id="test",
from_id="test",
)
results = [s async for s in extract_payloads(state)]
# Should yield something (original or empty handling)
assert len(results) >= 1
@pytest.mark.asyncio
async def test_preserves_metadata(self):
"""Extracted payloads should preserve metadata."""
from xml_pipeline.message_bus.stream_pump import extract_payloads
state = MessageState(
raw_bytes=b"",
thread_id="test",
from_id="test",
metadata={"custom": "value"},
)
results = [s async for s in extract_payloads(state)]
for r in results:
assert r.metadata.get("custom") == "value"
# ============================================================================
# Step Factory Tests
# ============================================================================
@requires_stream_pump
class TestStepFactories:
"""Tests for the step factory functions."""
@pytest.mark.asyncio
async def test_xsd_validation_direct(self):
"""XSD validation via lxml schema."""
# Create a simple schema
xsd_str = '''
'''
schema = etree.XMLSchema(etree.fromstring(xsd_str.encode()))
# Valid payload
valid_xml = etree.fromstring(b"42")
assert schema.validate(valid_xml)
# Invalid payload
invalid_xml = etree.fromstring(b"not-an-int")
assert not schema.validate(invalid_xml)
@pytest.mark.asyncio
async def test_routing_factory(self):
"""Routing step should use injected routing table."""
from xml_pipeline.message_bus.steps.routing_resolution import make_routing_step
from xml_pipeline.message_bus.stream_pump import Listener
# Create mock listener
mock_listener = Listener(
name="calculator.add",
payload_class=type("AddPayload", (), {}),
handler=lambda x, m: b"",
description="test",
)
routing_table = {
"calculator.add.addpayload": [mock_listener]
}
step = make_routing_step(routing_table)
# Create a mock payload instance
@dataclass
class AddPayload:
a: int = 0
b: int = 0
state = MessageState(
payload=AddPayload(a=1, b=2),
to_id="calculator.add",
)
result = await step(state)
assert result.error is None
assert result.target_listeners == [mock_listener]
# ============================================================================
# Pipeline Integration Tests (lightweight)
# ============================================================================
class TestPipelineIntegration:
"""Integration tests for step sequences."""
@pytest.mark.asyncio
async def test_repair_through_extraction(self, valid_envelope_bytes):
"""Test repair → c14n → extraction chain."""
state = MessageState(raw_bytes=valid_envelope_bytes)
state = await repair_step(state)
assert state.error is None, f"repair failed: {state.error}"
state = await c14n_step(state)
assert state.error is None, f"c14n failed: {state.error}"
state = await payload_extraction_step(state)
assert state.error is None, f"extraction failed: {state.error}"
assert state.payload_tree is not None
assert state.thread_id is not None
assert state.from_id is not None
@pytest.mark.asyncio
async def test_error_short_circuits(self):
"""Errors should prevent downstream steps from running."""
call_log = []
async def step_a(state):
call_log.append("a")
state.error = "Intentional error"
return state
async def step_b(state):
call_log.append("b")
return state
# Simple pipeline runner (same logic as StreamPump uses)
async def run_pipeline(steps, state):
for step in steps:
state = await step(state)
if state.error:
break
return state
result = await run_pipeline([step_a, step_b], MessageState())
assert call_log == ["a"] # step_b should not have been called
assert result.error == "Intentional error"
# ============================================================================
# Run with pytest
# ============================================================================
if __name__ == "__main__":
pytest.main([__file__, "-v"])