Add Sequence and Buffer orchestration primitives

Implement two virtual node patterns for message flow orchestration:

- Sequence: Chains listeners in order (A→B→C), feeding each step's
  output as input to the next. Uses ephemeral listeners to intercept
  step results without modifying core pump behavior.

- Buffer: Fan-out to parallel worker threads with optional result
  collection. Supports fire-and-forget mode (collect=False) for
  non-blocking dispatch.

New files:
- sequence_registry.py / buffer_registry.py: State tracking
- sequence.py / buffer.py: Payloads and handlers
- test_sequence.py / test_buffer.py: 52 new tests

Pump additions:
- register_generic_listener(): Accept any payload type
- unregister_listener(): Cleanup ephemeral listeners
- Global singleton accessors for pump instance

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
dullfig 2026-01-25 14:56:15 -08:00
parent a69eae79c5
commit a623c534d5
10 changed files with 2465 additions and 2 deletions

635
tests/test_buffer.py Normal file
View file

@ -0,0 +1,635 @@
"""
test_buffer.py Tests for the Buffer (fan-out) orchestration primitives.
Tests:
1. BufferRegistry basic operations
2. BufferStart handler
3. Result collection
4. Fire-and-forget mode
"""
import pytest
import uuid
from xml_pipeline.message_bus.buffer_registry import (
BufferRegistry,
BufferState,
BufferItemResult,
get_buffer_registry,
reset_buffer_registry,
)
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
from xml_pipeline.primitives.buffer import (
BufferStart,
BufferComplete,
BufferDispatched,
BufferError,
handle_buffer_start,
_extract_worker_index,
_format_buffer_results,
)
class TestBufferRegistry:
    """Test BufferRegistry basic operations."""

    def test_create_buffer_state(self):
        """create() should create a buffer state with correct fields."""
        registry = BufferRegistry()
        state = registry.create(
            buffer_id="buf001",
            total_items=5,
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            target="worker",
            collect=True,
        )
        assert state.buffer_id == "buf001"
        assert state.total_items == 5
        assert state.return_to == "caller"
        assert state.thread_id == "thread-123"
        assert state.from_id == "console"
        assert state.target == "worker"
        assert state.collect is True
        # Freshly created buffers have no results yet.
        assert state.completed_count == 0
        assert state.successful_count == 0
        assert state.is_complete is False
        assert state.pending_count == 5

    def test_get_returns_buffer(self):
        """get() should return buffer by ID."""
        registry = BufferRegistry()
        registry.create(
            buffer_id="buf002",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        state = registry.get("buf002")
        assert state is not None
        assert state.buffer_id == "buf002"
        # Non-existent returns None
        assert registry.get("nonexistent") is None

    def test_record_result_stores_result(self):
        """record_result() should store result and update counts."""
        registry = BufferRegistry()
        registry.create(
            buffer_id="buf003",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        # Record first result (success)
        state = registry.record_result(
            buffer_id="buf003",
            index=0,
            result="<Result0/>",
            success=True,
        )
        assert state.completed_count == 1
        assert state.successful_count == 1
        assert state.is_complete is False
        assert 0 in state.results
        assert state.results[0].result == "<Result0/>"
        assert state.results[0].success is True
        # Record second result (failure)
        state = registry.record_result(
            buffer_id="buf003",
            index=1,
            result="<Error/>",
            success=False,
            error="Timeout",
        )
        assert state.completed_count == 2
        assert state.successful_count == 1  # Still just 1
        assert state.results[1].success is False
        assert state.results[1].error == "Timeout"
        # Record third result - now complete
        state = registry.record_result(
            buffer_id="buf003",
            index=2,
            result="<Result2/>",
            success=True,
        )
        assert state.completed_count == 3
        assert state.successful_count == 2
        assert state.is_complete is True
        assert state.pending_count == 0

    def test_record_result_ignores_duplicates(self):
        """record_result() should not count same index twice."""
        registry = BufferRegistry()
        registry.create(
            buffer_id="buf004",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        # Record index 0
        state = registry.record_result("buf004", 0, "<R/>", True)
        assert state.completed_count == 1
        # Try to record index 0 again
        state = registry.record_result("buf004", 0, "<Duplicate/>", True)
        assert state.completed_count == 1  # Should not increment
        assert state.results[0].result == "<R/>"  # Original preserved

    def test_remove_deletes_buffer(self):
        """remove() should delete buffer from registry."""
        registry = BufferRegistry()
        registry.create(
            buffer_id="buf005",
            total_items=1,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        assert registry.get("buf005") is not None
        result = registry.remove("buf005")
        assert result is True
        assert registry.get("buf005") is None
        # Remove non-existent returns False
        assert registry.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() should return all active buffer IDs."""
        registry = BufferRegistry()
        registry.create("buf-a", 1, "x", "t", "c", "w")
        registry.create("buf-b", 2, "x", "t", "c", "w")
        registry.create("buf-c", 3, "x", "t", "c", "w")
        active = registry.list_active()
        assert set(active) == {"buf-a", "buf-b", "buf-c"}

    def test_clear(self):
        """clear() should remove all buffers."""
        registry = BufferRegistry()
        registry.create("buf-1", 1, "x", "t", "c", "w")
        registry.create("buf-2", 2, "x", "t", "c", "w")
        registry.clear()
        assert registry.list_active() == []
class TestBufferStateProperties:
    """Test BufferState computed properties."""

    def _make_state(self, **overrides):
        """Build a BufferState with the routing boilerplate pre-filled."""
        params = dict(
            buffer_id="test",
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        params.update(overrides)
        return BufferState(**params)

    def test_is_complete_when_all_received(self):
        """is_complete should be True when all items are received."""
        buf = self._make_state(total_items=3, completed_count=3)
        assert buf.is_complete is True

    def test_pending_count(self):
        """pending_count should reflect remaining items."""
        buf = self._make_state(total_items=5, completed_count=2)
        assert buf.pending_count == 3

    def test_get_ordered_results(self):
        """get_ordered_results should return results in order."""
        buf = self._make_state(
            total_items=3,
            results={
                0: BufferItemResult(0, "<R0/>", True),
                2: BufferItemResult(2, "<R2/>", True),
                # Index 1 intentionally never recorded.
            },
        )
        ordered = buf.get_ordered_results()
        assert len(ordered) == 3
        assert ordered[0] is not None
        assert ordered[0].result == "<R0/>"
        assert ordered[1] is None  # Missing
        assert ordered[2] is not None
        assert ordered[2].result == "<R2/>"
class TestBufferStartPayload:
    """Test BufferStart payload."""

    def test_buffer_start_fields(self):
        """BufferStart should have expected fields."""
        given = dict(
            target="worker",
            items="item1\nitem2\nitem3",
            collect=True,
            return_to="caller",
            buffer_id="custom-id",
        )
        msg = BufferStart(**given)
        for name, expected in given.items():
            assert getattr(msg, name) == expected
        assert msg.collect is True

    def test_buffer_start_default_values(self):
        """BufferStart should have sensible defaults."""
        msg = BufferStart()
        assert (msg.target, msg.items, msg.return_to, msg.buffer_id) == ("", "", "", "")
        assert msg.collect is True


class TestBufferCompletePayload:
    """Test BufferComplete payload."""

    def test_buffer_complete_fields(self):
        """BufferComplete should have expected fields."""
        done = BufferComplete(
            buffer_id="buf123",
            total=5,
            successful=4,
            results="<results>...</results>",
        )
        assert done.buffer_id == "buf123"
        assert done.total == 5
        assert done.successful == 4
        assert done.results == "<results>...</results>"


class TestBufferDispatchedPayload:
    """Test BufferDispatched payload (fire-and-forget mode)."""

    def test_buffer_dispatched_fields(self):
        """BufferDispatched should have expected fields."""
        ack = BufferDispatched(buffer_id="buf456", total=10)
        assert ack.buffer_id == "buf456"
        assert ack.total == 10


class TestBufferErrorPayload:
    """Test BufferError payload."""

    def test_buffer_error_fields(self):
        """BufferError should have expected fields."""
        err = BufferError(buffer_id="buf789", error="Unknown target listener")
        assert err.buffer_id == "buf789"
        assert err.error == "Unknown target listener"
class TestExtractWorkerIndex:
    """Test _extract_worker_index helper."""

    def test_extracts_index_from_chain(self):
        """Should extract worker index from thread chain."""
        assert _extract_worker_index("root.parent.buffer_abc123_w5", "abc123") == 5

    def test_extracts_double_digit_index(self):
        """Should handle double-digit indices."""
        assert _extract_worker_index("x.buffer_xyz_w42", "xyz") == 42

    def test_returns_none_for_no_match(self):
        """Should return None when pattern doesn't match."""
        assert _extract_worker_index("something.else.entirely", "abc") is None

    def test_returns_none_for_wrong_buffer_id(self):
        """Should return None when buffer ID doesn't match."""
        assert _extract_worker_index("root.buffer_other_w3", "abc") is None
class TestFormatBufferResults:
    """Test _format_buffer_results helper."""

    def test_formats_complete_results(self):
        """Should format all results as XML."""
        state = BufferState(
            buffer_id="test",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<Result>A</Result>", True),
                1: BufferItemResult(1, "<Result>B</Result>", True),
            },
        )
        xml = _format_buffer_results(state)
        assert "<results>" in xml
        assert "</results>" in xml
        assert 'index="0"' in xml
        assert 'index="1"' in xml
        assert 'success="true"' in xml
        # Worker payloads are embedded verbatim in the aggregate document.
        assert "<Result>A</Result>" in xml
        assert "<Result>B</Result>" in xml

    def test_formats_partial_failure(self):
        """Should format mixed success/failure results."""
        state = BufferState(
            buffer_id="test",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<Good/>", True),
                1: BufferItemResult(1, "<Error/>", False, "timeout"),
            },
        )
        xml = _format_buffer_results(state)
        assert 'success="true"' in xml
        assert 'success="false"' in xml

    def test_formats_missing_results(self):
        """Should handle missing results."""
        state = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<R/>", True),
                # Index 1 missing
                2: BufferItemResult(2, "<R/>", True),
            },
        )
        xml = _format_buffer_results(state)
        # A never-reported index still gets an entry, marked "missing".
        assert 'index="1"' in xml
        assert "missing" in xml
class TestHandleBufferStartValidation:
    """Test validation in handle_buffer_start."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Reset registries before each test."""
        reset_buffer_registry()

    @pytest.mark.asyncio
    async def test_empty_items_returns_error(self):
        """handle_buffer_start should return error for empty items."""
        from unittest.mock import patch, MagicMock

        mock_pump = MagicMock()
        mock_pump.listeners = {}
        # NOTE(review): patch target assumes the handler resolves
        # get_stream_pump through the stream_pump module at call time; if
        # buffer.py binds the name at import, the target must be the buffer
        # module instead -- confirm.
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = BufferStart(
                target="worker",
                items="",  # Empty
                return_to="caller",
            )
            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )
            response = await handle_buffer_start(payload, metadata)
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, BufferError)
            assert "No items" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_target_returns_error(self):
        """handle_buffer_start should return error for unknown target."""
        from unittest.mock import patch, MagicMock

        mock_pump = MagicMock()
        mock_pump.listeners = {}  # No listeners
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = BufferStart(
                target="unknown_worker",
                items="item1\nitem2",
                return_to="caller",
            )
            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )
            response = await handle_buffer_start(payload, metadata)
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, BufferError)
            assert "Unknown target" in response.payload.error
class TestBufferRegistrySingleton:
    """Test singleton pattern for BufferRegistry."""

    def test_get_buffer_registry_returns_singleton(self):
        """get_buffer_registry should return same instance."""
        reset_buffer_registry()
        first, second = get_buffer_registry(), get_buffer_registry()
        assert first is second

    def test_reset_creates_new_instance(self):
        """reset_buffer_registry should clear singleton."""
        stale = get_buffer_registry()
        stale.create("test", 1, "x", "t", "c", "w")
        reset_buffer_registry()
        assert get_buffer_registry().get("test") is None
class TestBufferCollectVsFireAndForget:
    """Test collect=True vs collect=False behavior."""

    @pytest.mark.asyncio
    async def test_collect_mode_returns_none(self):
        """With collect=True, handler should return None (wait for results)."""
        from unittest.mock import patch, MagicMock, AsyncMock

        mock_pump = MagicMock()
        mock_pump.listeners = {"worker": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()
        mock_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
        # inject must be awaitable: the handler dispatches each item via it.
        mock_pump.inject = AsyncMock()
        # Mock thread registry
        mock_thread_registry = MagicMock()
        mock_thread_registry.lookup = MagicMock(return_value="root.parent")
        mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid")
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
                payload = BufferStart(
                    target="worker",
                    items="item1\nitem2",
                    return_to="caller",
                    collect=True,
                )
                metadata = HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                )
                response = await handle_buffer_start(payload, metadata)
                # With collect=True, returns None (ephemeral handler will send BufferComplete)
                assert response is None
                # Ephemeral listener should be registered
                mock_pump.register_generic_listener.assert_called_once()

    @pytest.mark.asyncio
    async def test_fire_and_forget_returns_dispatched(self):
        """With collect=False, handler should return BufferDispatched immediately."""
        from unittest.mock import patch, MagicMock, AsyncMock

        mock_pump = MagicMock()
        mock_pump.listeners = {"worker": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()
        mock_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
        mock_pump.inject = AsyncMock()
        mock_thread_registry = MagicMock()
        mock_thread_registry.lookup = MagicMock(return_value="root.parent")
        mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid")
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
                payload = BufferStart(
                    target="worker",
                    items="item1\nitem2\nitem3",
                    return_to="caller",
                    collect=False,  # Fire-and-forget
                )
                metadata = HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                )
                response = await handle_buffer_start(payload, metadata)
                # With collect=False, returns BufferDispatched
                assert isinstance(response, HandlerResponse)
                assert isinstance(response.payload, BufferDispatched)
                assert response.payload.total == 3
                # No ephemeral listener registered for fire-and-forget
                mock_pump.register_generic_listener.assert_not_called()
class TestBufferResultCollection:
    """Test result collection behavior."""

    def test_result_collection_partial_success(self):
        """Buffer should track partial success correctly."""
        registry = BufferRegistry()
        registry.create("partial", 5, "x", "t", "c", "w")
        # Indices 0, 2, 4 succeed; 1 and 3 fail.
        for idx, ok in enumerate([True, False, True, False, True]):
            if ok:
                registry.record_result("partial", idx, "<R/>", True)
            else:
                registry.record_result("partial", idx, "<E/>", False, "error")
        state = registry.get("partial")
        assert state.is_complete is True
        assert state.completed_count == 5
        assert state.successful_count == 3

    def test_results_out_of_order(self):
        """Buffer should handle results arriving out of order."""
        registry = BufferRegistry()
        registry.create("ooo", 3, "x", "t", "c", "w")
        # Results arrive out of order.
        for idx in (2, 0, 1):
            registry.record_result("ooo", idx, f"<R{idx}/>", True)
        state = registry.get("ooo")
        assert state.is_complete is True
        ordered = state.get_ordered_results()
        assert [r.result for r in ordered] == ["<R0/>", "<R1/>", "<R2/>"]

View file

@ -66,7 +66,7 @@ class TestPumpBootstrap:
pump = await bootstrap('config/organism.yaml')
assert pump.config.name == "hello-world"
assert len(pump.routing_table) == 6 # 3 user listeners + 3 system (boot, todo, todo-complete)
assert len(pump.routing_table) == 8 # 3 user listeners + 5 system (boot, todo, todo-complete, sequence, buffer)
assert "greeter.greeting" in pump.routing_table
assert "shouter.greetingresponse" in pump.routing_table
assert "response-handler.shoutedresponse" in pump.routing_table

464
tests/test_sequence.py Normal file
View file

@ -0,0 +1,464 @@
"""
test_sequence.py Tests for the Sequence orchestration primitives.
Tests:
1. SequenceRegistry basic operations
2. SequenceStart handler
3. Step result handling
4. Error propagation
"""
import pytest
import uuid
from xml_pipeline.message_bus.sequence_registry import (
SequenceRegistry,
SequenceState,
get_sequence_registry,
reset_sequence_registry,
)
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
from xml_pipeline.primitives.sequence import (
SequenceStart,
SequenceComplete,
SequenceError,
handle_sequence_start,
)
class TestSequenceRegistry:
    """Test SequenceRegistry basic operations."""

    def test_create_sequence_state(self):
        """create() should create a sequence state with correct fields."""
        registry = SequenceRegistry()
        state = registry.create(
            sequence_id="seq001",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            initial_payload="<TestPayload/>",
        )
        assert state.sequence_id == "seq001"
        assert state.steps == ["step1", "step2", "step3"]
        assert state.return_to == "caller"
        assert state.thread_id == "thread-123"
        assert state.from_id == "console"
        assert state.current_index == 0
        assert state.is_complete is False
        assert state.current_step == "step1"
        assert state.remaining_steps == ["step1", "step2", "step3"]
        # The initial payload is surfaced as last_result so step 1 receives it.
        assert state.last_result == "<TestPayload/>"

    def test_get_returns_sequence(self):
        """get() should return sequence by ID."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq002",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )
        state = registry.get("seq002")
        assert state is not None
        assert state.sequence_id == "seq002"
        # Non-existent returns None
        assert registry.get("nonexistent") is None

    def test_advance_increments_index(self):
        """advance() should increment index and store result."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq003",
            steps=["a", "b", "c"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )
        # Advance first step
        state = registry.advance("seq003", "<ResultA/>")
        assert state.current_index == 1
        assert state.results == ["<ResultA/>"]
        assert state.last_result == "<ResultA/>"
        assert state.current_step == "b"
        assert state.is_complete is False
        # Advance second step
        state = registry.advance("seq003", "<ResultB/>")
        assert state.current_index == 2
        assert state.results == ["<ResultA/>", "<ResultB/>"]
        assert state.current_step == "c"
        # Advance third step - now complete
        state = registry.advance("seq003", "<ResultC/>")
        assert state.current_index == 3
        assert state.is_complete is True
        assert state.current_step is None
        assert state.remaining_steps == []

    def test_mark_failed(self):
        """mark_failed() should set failed state."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq004",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )
        state = registry.mark_failed("seq004", "a", "XSD validation failed")
        assert state.failed is True
        assert state.failed_step == "a"
        assert state.error == "XSD validation failed"
        assert state.is_complete is False  # Not complete, but failed
        assert state.current_step is None  # No current step when failed
        assert state.remaining_steps == []  # No remaining steps when failed

    def test_remove_deletes_sequence(self):
        """remove() should delete sequence from registry."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq005",
            steps=["a"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )
        assert registry.get("seq005") is not None
        result = registry.remove("seq005")
        assert result is True
        assert registry.get("seq005") is None
        # Remove non-existent returns False
        assert registry.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() should return all active sequence IDs."""
        registry = SequenceRegistry()
        registry.create("seq-a", ["1"], "x", "t", "c")
        registry.create("seq-b", ["2"], "x", "t", "c")
        registry.create("seq-c", ["3"], "x", "t", "c")
        active = registry.list_active()
        assert set(active) == {"seq-a", "seq-b", "seq-c"}

    def test_clear(self):
        """clear() should remove all sequences."""
        registry = SequenceRegistry()
        registry.create("seq-1", ["a"], "x", "t", "c")
        registry.create("seq-2", ["b"], "x", "t", "c")
        registry.clear()
        assert registry.list_active() == []
class TestSequenceStateProperties:
    """Test SequenceState computed properties."""

    def _state(self, **overrides):
        """Build a SequenceState with the routing boilerplate pre-filled."""
        params = dict(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )
        params.update(overrides)
        return SequenceState(**params)

    def test_is_complete_after_all_steps(self):
        """is_complete should be True when all steps are done."""
        seq = self._state(current_index=2)  # Index is past the last step.
        assert seq.is_complete is True

    def test_not_complete_when_failed(self):
        """is_complete should be False when failed."""
        seq = self._state(current_index=2, failed=True)
        assert seq.is_complete is False

    def test_current_step_none_when_complete(self):
        """current_step should be None when sequence is complete."""
        seq = self._state(steps=["a"], current_index=1)
        assert seq.current_step is None

    def test_remaining_steps_empty_when_complete(self):
        """remaining_steps should be empty when complete."""
        seq = self._state(current_index=2)
        assert seq.remaining_steps == []
class TestSequenceStartPayload:
    """Test SequenceStart payload serialization."""

    def test_sequence_start_fields(self):
        """SequenceStart should have expected fields."""
        given = dict(
            steps="step1,step2",
            payload="<Test/>",
            return_to="caller",
            sequence_id="custom-id",
        )
        msg = SequenceStart(**given)
        for name, expected in given.items():
            assert getattr(msg, name) == expected

    def test_sequence_start_default_values(self):
        """SequenceStart should have sensible defaults."""
        msg = SequenceStart()
        assert (msg.steps, msg.payload, msg.return_to, msg.sequence_id) == ("", "", "", "")


class TestSequenceCompletePayload:
    """Test SequenceComplete payload."""

    def test_sequence_complete_fields(self):
        """SequenceComplete should have expected fields."""
        done = SequenceComplete(
            sequence_id="seq123",
            final_result="<Result>42</Result>",
            step_count=3,
        )
        assert done.sequence_id == "seq123"
        assert done.final_result == "<Result>42</Result>"
        assert done.step_count == 3


class TestSequenceErrorPayload:
    """Test SequenceError payload."""

    def test_sequence_error_fields(self):
        """SequenceError should have expected fields."""
        err = SequenceError(
            sequence_id="seq456",
            failed_step="bad-step",
            step_index=1,
            error="Validation failed",
        )
        assert err.sequence_id == "seq456"
        assert err.failed_step == "bad-step"
        assert err.step_index == 1
        assert err.error == "Validation failed"
class TestHandleSequenceStartValidation:
    """Test validation in handle_sequence_start."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Reset registries before each test."""
        reset_sequence_registry()

    @pytest.mark.asyncio
    async def test_empty_steps_returns_error(self):
        """handle_sequence_start should return error for empty steps."""
        from unittest.mock import patch, MagicMock

        # Mock get_stream_pump to avoid dependency
        mock_pump = MagicMock()
        mock_pump.listeners = {}
        # NOTE(review): assumes the handler resolves get_stream_pump via the
        # stream_pump module at call time -- confirm against sequence.py.
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="",  # Empty
                payload="<Test/>",
                return_to="caller",
            )
            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )
            response = await handle_sequence_start(payload, metadata)
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, SequenceError)
            assert "No steps" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_step_returns_error(self):
        """handle_sequence_start should return error for unknown step."""
        from unittest.mock import patch, MagicMock

        # Mock get_stream_pump with no listeners
        mock_pump = MagicMock()
        mock_pump.listeners = {}  # No listeners registered
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="unknown_step",
                payload="<Test/>",
                return_to="caller",
            )
            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )
            response = await handle_sequence_start(payload, metadata)
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, SequenceError)
            assert "Unknown listener" in response.payload.error
            assert "unknown_step" in response.payload.failed_step
class TestSequenceRegistrySingleton:
    """Test singleton pattern for SequenceRegistry."""

    def test_get_sequence_registry_returns_singleton(self):
        """get_sequence_registry should return same instance."""
        reset_sequence_registry()
        first, second = get_sequence_registry(), get_sequence_registry()
        assert first is second

    def test_reset_creates_new_instance(self):
        """reset_sequence_registry should clear singleton."""
        stale = get_sequence_registry()
        stale.create("test", ["a"], "x", "t", "c")
        reset_sequence_registry()
        assert get_sequence_registry().get("test") is None
class TestSequenceMultipleSteps:
    """Test sequences with multiple steps."""

    def test_three_step_sequence(self):
        """A three-step sequence should advance through all steps."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="multi",
            steps=["add", "multiply", "format"],
            return_to="caller",
            thread_id="t",
            from_id="c",
            initial_payload="<Input>5</Input>",
        )
        # Step 1: add
        state = registry.get("multi")
        assert state.current_step == "add"
        state = registry.advance("multi", "<Sum>8</Sum>")
        assert state.last_result == "<Sum>8</Sum>"
        # Step 2: multiply
        assert state.current_step == "multiply"
        state = registry.advance("multi", "<Product>40</Product>")
        # Step 3: format
        assert state.current_step == "format"
        state = registry.advance("multi", "<Formatted>Result: 40</Formatted>")
        # Complete
        assert state.is_complete is True
        assert len(state.results) == 3
        assert state.results[0] == "<Sum>8</Sum>"
        assert state.results[1] == "<Product>40</Product>"
        assert state.results[2] == "<Formatted>Result: 40</Formatted>"

    def test_failure_at_middle_step(self):
        """Failure at middle step should stop sequence."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="fail-mid",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="t",
            from_id="c",
        )
        # Step 1 succeeds
        registry.advance("fail-mid", "<R1/>")
        # Step 2 fails
        state = registry.mark_failed("fail-mid", "step2", "Connection timeout")
        assert state.failed is True
        assert state.failed_step == "step2"
        assert state.current_index == 1  # Was at step 2 (index 1)
        assert len(state.results) == 1  # Only step 1 result
class TestSequenceWithRealSteps:
    """Integration-style tests with mock handlers."""

    @pytest.mark.asyncio
    async def test_sequence_creates_ephemeral_listener(self):
        """Starting a sequence should create an ephemeral listener."""
        from unittest.mock import patch, MagicMock, AsyncMock

        mock_pump = MagicMock()
        mock_pump.listeners = {"step1": MagicMock(), "step2": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()
        # NOTE(review): unlike the buffer tests, pump.inject is not set to an
        # AsyncMock and the thread registry is not patched -- if the handler
        # awaits inject() before returning, this test needs the same AsyncMock
        # setup; confirm against handle_sequence_start.
        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="step1,step2",
                payload="<Input/>",
                return_to="caller",
            )
            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )
            response = await handle_sequence_start(payload, metadata)
            # Should have registered an ephemeral listener
            mock_pump.register_generic_listener.assert_called_once()
            call_args = mock_pump.register_generic_listener.call_args
            # Listener name may be passed positionally or as name=...
            name_arg = call_args.kwargs.get('name') or call_args.args[0]
            assert name_arg.startswith("sequence_")

View file

@ -30,6 +30,9 @@ from xml_pipeline.message_bus.stream_pump import (
ListenerConfig,
OrganismConfig,
bootstrap,
get_stream_pump,
set_stream_pump,
reset_stream_pump,
)
from xml_pipeline.message_bus.message_state import (
@ -42,15 +45,47 @@ from xml_pipeline.message_bus.system_pipeline import (
ExternalMessage,
)
from xml_pipeline.message_bus.sequence_registry import (
SequenceState,
SequenceRegistry,
get_sequence_registry,
reset_sequence_registry,
)
from xml_pipeline.message_bus.buffer_registry import (
BufferState,
BufferItemResult,
BufferRegistry,
get_buffer_registry,
reset_buffer_registry,
)
# Public API of the message_bus package.
# Fix: "bootstrap" was listed twice (under "Pump" and again after the
# message-state names); `from ... import *` semantics are unchanged by
# removing the duplicate.
__all__ = [
    # Pump
    "StreamPump",
    "ConfigLoader",
    "Listener",
    "ListenerConfig",
    "OrganismConfig",
    "bootstrap",
    "get_stream_pump",
    "set_stream_pump",
    "reset_stream_pump",
    # Message state
    "MessageState",
    "HandlerMetadata",
    # System pipeline
    "SystemPipeline",
    "ExternalMessage",
    # Sequence registry
    "SequenceState",
    "SequenceRegistry",
    "get_sequence_registry",
    "reset_sequence_registry",
    # Buffer registry
    "BufferState",
    "BufferItemResult",
    "BufferRegistry",
    "get_buffer_registry",
    "reset_buffer_registry",
]

View file

@ -0,0 +1,230 @@
"""
buffer_registry.py State storage for Buffer (fan-out) orchestration.
Tracks active buffer executions that fan-out to parallel workers.
When a buffer starts, N items are dispatched in parallel. Results are
collected here. When all results are in (or timeout), BufferComplete is sent.
Design:
- Thread-safe (same pattern as TodoRegistry, SequenceRegistry)
- Keyed by buffer_id (short UUID)
- Tracks: total items, received results, success/failure per item
- Supports fire-and-forget mode (collect=False)
Usage:
registry = get_buffer_registry()
# Start a buffer
registry.create(
buffer_id="abc123",
total_items=5,
return_to="greeter",
thread_id="...",
collect=True,
)
# Record result from worker
state = registry.record_result(
buffer_id="abc123",
index=2,
result="<SearchResult>...</SearchResult>",
success=True,
)
if state.is_complete:
# All workers done
final_results = state.results
registry.remove(buffer_id)
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
import threading
@dataclass
class BufferItemResult:
    """Result from a single buffer item (worker)."""

    index: int
    result: str  # XML result
    success: bool = True
    error: Optional[str] = None


@dataclass
class BufferState:
    """State for an active buffer execution."""

    buffer_id: str
    total_items: int  # How many items were dispatched
    return_to: str  # Where to send BufferComplete
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the buffer
    target: str  # Target listener for items
    collect: bool = True  # Whether to wait for results
    results: Dict[int, BufferItemResult] = field(default_factory=dict)
    completed_count: int = 0
    successful_count: int = 0

    @property
    def is_complete(self) -> bool:
        """True when all items have reported back."""
        return self.completed_count >= self.total_items

    @property
    def pending_count(self) -> int:
        """Number of items still pending."""
        return self.total_items - self.completed_count

    def get_ordered_results(self) -> List[Optional[BufferItemResult]]:
        """Get results in order (None for missing indices)."""
        ordered: List[Optional[BufferItemResult]] = []
        for item_index in range(self.total_items):
            ordered.append(self.results.get(item_index))
        return ordered
class BufferRegistry:
    """Thread-safe store of in-flight buffer executions.

    One BufferState per active fan-out, keyed by buffer_id. Obtain the
    process-wide instance via get_buffer_registry().
    """

    def __init__(self) -> None:
        self._guard = threading.Lock()
        self._states: Dict[str, BufferState] = {}

    def create(
        self,
        buffer_id: str,
        total_items: int,
        return_to: str,
        thread_id: str,
        from_id: str,
        target: str,
        collect: bool = True,
    ) -> BufferState:
        """Register a new buffer execution and return its tracking state.

        Args:
            buffer_id: Unique ID for this buffer.
            total_items: Number of items being dispatched.
            return_to: Listener to send BufferComplete to.
            thread_id: Thread UUID for routing.
            from_id: Who initiated the buffer.
            target: Target listener for each item.
            collect: Whether to wait for and collect results.
        """
        fresh = BufferState(
            buffer_id=buffer_id,
            total_items=total_items,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            target=target,
            collect=collect,
        )
        with self._guard:
            self._states[buffer_id] = fresh
        return fresh

    def get(self, buffer_id: str) -> Optional[BufferState]:
        """Return the state for buffer_id, or None if unknown."""
        with self._guard:
            return self._states.get(buffer_id)

    def record_result(
        self,
        buffer_id: str,
        index: int,
        result: str,
        success: bool = True,
        error: Optional[str] = None,
    ) -> Optional[BufferState]:
        """Store one worker result and bump the completion counters.

        Duplicate reports for the same index are ignored.

        Args:
            buffer_id: Buffer this result belongs to.
            index: Item index (0-based).
            result: XML result from the worker.
            success: Whether the worker succeeded.
            error: Error message if failed.

        Returns:
            The (possibly updated) BufferState, or None when the buffer
            is unknown.
        """
        with self._guard:
            tracked = self._states.get(buffer_id)
            if tracked is None:
                return None
            if index not in tracked.results:
                tracked.results[index] = BufferItemResult(
                    index=index,
                    result=result,
                    success=success,
                    error=error,
                )
                tracked.completed_count += 1
                if success:
                    tracked.successful_count += 1
            return tracked

    def remove(self, buffer_id: str) -> bool:
        """Drop a buffer after completion; True if it was present."""
        with self._guard:
            return self._states.pop(buffer_id, None) is not None

    def list_active(self) -> List[str]:
        """Snapshot of all active buffer IDs."""
        with self._guard:
            return list(self._states.keys())

    def clear(self) -> None:
        """Remove every tracked buffer (test helper)."""
        with self._guard:
            self._states.clear()
# ============================================================================
# Singleton
# ============================================================================
_registry: Optional[BufferRegistry] = None
_registry_lock = threading.Lock()


def get_buffer_registry() -> BufferRegistry:
    """Return the process-wide BufferRegistry, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    # Double-checked locking: only the first caller pays for the lock.
    with _registry_lock:
        if _registry is None:
            _registry = BufferRegistry()
        return _registry


def reset_buffer_registry() -> None:
    """Discard the singleton (and its contents) so tests start clean."""
    global _registry
    with _registry_lock:
        if _registry is not None:
            _registry.clear()
        _registry = None

View file

@ -0,0 +1,228 @@
"""
sequence_registry.py — State storage for Sequence orchestration.
Tracks active sequence executions across handler invocations.
When a sequence starts, its state is registered here. As steps complete,
the state is updated. When all steps are done, the state is cleaned up.
Design:
- Thread-safe (same pattern as TodoRegistry)
- Keyed by sequence_id (short UUID)
- Tracks: steps list, current index, collected results
- Auto-cleanup when sequence completes or errors
Usage:
registry = get_sequence_registry()
# Start a sequence
registry.create(
sequence_id="abc123",
steps=["calculator.add", "calculator.multiply"],
return_to="greeter",
thread_id="...",
initial_payload="<AddPayload>...</AddPayload>",
)
# Advance on step completion
state = registry.advance(sequence_id, step_result="<AddResult>42</AddResult>")
if state.is_complete:
# All steps done
registry.remove(sequence_id)
# On error
registry.mark_failed(sequence_id, step="calculator.add", error="XSD validation failed")
"""
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
import threading
@dataclass
class SequenceState:
    """State for an active sequence execution.

    Created by SequenceRegistry.create() and advanced one step at a time
    via SequenceRegistry.advance(); mark_failed() freezes it on error.
    """
    sequence_id: str  # Short unique ID for this sequence
    steps: List[str]  # Ordered list of listener names
    return_to: str  # Where to send final result
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the sequence
    current_index: int = 0  # Which step we're on (0-based)
    results: List[str] = field(default_factory=list)  # XML results from each step
    last_result: Optional[str] = None  # Most recent step result (for chaining)
    failed: bool = False  # Set by mark_failed(); freezes the sequence
    failed_step: Optional[str] = None  # Which step failed, if any
    error: Optional[str] = None  # Error message, if failed

    @property
    def is_complete(self) -> bool:
        """True when all steps have been executed successfully."""
        return not self.failed and self.current_index >= len(self.steps)

    @property
    def current_step(self) -> Optional[str]:
        """Current step name, or None once complete/failed."""
        if self.failed or self.current_index >= len(self.steps):
            return None
        return self.steps[self.current_index]

    @property
    def remaining_steps(self) -> List[str]:
        """Steps not yet executed (empty once failed)."""
        if self.failed:
            return []
        return self.steps[self.current_index:]
class SequenceRegistry:
    """Thread-safe store of in-flight sequence executions.

    One SequenceState per active chain, keyed by sequence_id. Obtain the
    process-wide instance via get_sequence_registry().
    """

    def __init__(self) -> None:
        self._guard = threading.Lock()
        self._states: Dict[str, SequenceState] = {}

    def create(
        self,
        sequence_id: str,
        steps: List[str],
        return_to: str,
        thread_id: str,
        from_id: str,
        initial_payload: str = "",
    ) -> SequenceState:
        """Register a new sequence and return its tracking state.

        Args:
            sequence_id: Unique ID for this sequence.
            steps: Ordered list of listener names to call.
            return_to: Listener to send SequenceComplete to.
            thread_id: Thread UUID for routing.
            from_id: Who initiated the sequence.
            initial_payload: XML payload for the first step.
        """
        fresh = SequenceState(
            sequence_id=sequence_id,
            steps=steps,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            # Empty string means "no initial payload" - store None instead.
            last_result=initial_payload or None,
        )
        with self._guard:
            self._states[sequence_id] = fresh
        return fresh

    def get(self, sequence_id: str) -> Optional[SequenceState]:
        """Return the state for sequence_id, or None if unknown."""
        with self._guard:
            return self._states.get(sequence_id)

    def advance(self, sequence_id: str, step_result: str) -> Optional[SequenceState]:
        """Record a completed step and move to the next one.

        A failed sequence is returned untouched; an unknown sequence_id
        yields None.

        Args:
            sequence_id: Sequence to advance.
            step_result: XML result from the completed step.
        """
        with self._guard:
            tracked = self._states.get(sequence_id)
            if tracked is None or tracked.failed:
                return tracked
            tracked.results.append(step_result)
            tracked.last_result = step_result
            tracked.current_index += 1
            return tracked

    def mark_failed(
        self,
        sequence_id: str,
        step: str,
        error: str,
    ) -> Optional[SequenceState]:
        """Flag a sequence as failed.

        Args:
            sequence_id: Sequence that failed.
            step: Which step failed.
            error: Error message.

        Returns:
            Updated SequenceState, or None if not found.
        """
        with self._guard:
            tracked = self._states.get(sequence_id)
            if tracked is None:
                return None
            tracked.failed = True
            tracked.failed_step = step
            tracked.error = error
            return tracked

    def remove(self, sequence_id: str) -> bool:
        """Drop a sequence after completion; True if it was present."""
        with self._guard:
            return self._states.pop(sequence_id, None) is not None

    def list_active(self) -> List[str]:
        """Snapshot of all active sequence IDs."""
        with self._guard:
            return list(self._states.keys())

    def clear(self) -> None:
        """Remove every tracked sequence (test helper)."""
        with self._guard:
            self._states.clear()
# ============================================================================
# Singleton
# ============================================================================
_registry: Optional[SequenceRegistry] = None
_registry_lock = threading.Lock()


def get_sequence_registry() -> SequenceRegistry:
    """Return the process-wide SequenceRegistry, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    # Double-checked locking: only the first caller pays for the lock.
    with _registry_lock:
        if _registry is None:
            _registry = SequenceRegistry()
        return _registry


def reset_sequence_registry() -> None:
    """Discard the singleton (and its contents) so tests start clean."""
    global _registry
    with _registry_lock:
        if _registry is not None:
            _registry.clear()
        _registry = None

View file

@ -210,6 +210,10 @@ class StreamPump:
self.routing_table: Dict[str, List[Listener]] = {}
self.listeners: Dict[str, Listener] = {}
# Generic listeners (accept any payload type)
# Used for ephemeral orchestration handlers (sequences, buffers)
self._generic_listeners: Dict[str, Listener] = {}
# Per-agent semaphores for rate limiting
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
@ -269,6 +273,82 @@ class StreamPump:
self.listeners[lc.name] = listener
return listener
def register_generic_listener(
    self,
    name: str,
    handler: Callable,
    description: str = "",
) -> Listener:
    """
    Register a generic listener that accepts any payload type.

    Used for ephemeral orchestration handlers (sequences, buffers)
    that need to receive responses from various step types.

    Generic listeners:
    - Are NOT added to the routing table (no root_tag)
    - Are looked up by name (to_id) as a fallback in routing
    - Receive payload_tree directly (no XSD validation/deserialization)

    Args:
        name: Unique listener name (e.g., "sequence_abc123")
        handler: Async handler function (receives payload_tree, metadata)
        description: Human-readable description

    Returns:
        Listener object
    """
    listener = Listener(
        name=name,
        payload_class=object,  # Placeholder - not used
        handler=handler,
        description=description,
        is_agent=False,
        root_tag="*",  # Wildcard marker
    )
    # NOTE(review): the fallback map is keyed on name.lower(), but the
    # routing fallback looks up the raw to_id - confirm callers always
    # pass lowercase names, otherwise mixed-case names never match.
    self._generic_listeners[name.lower()] = listener
    # Also listed under the original-case name so target validation
    # (e.g. "target in pump.listeners") can see it.
    self.listeners[name] = listener
    pump_logger.debug(f"Registered generic listener: {name}")
    return listener
def unregister_listener(self, name: str) -> bool:
    """
    Remove a listener by name.

    Used to clean up ephemeral listeners after orchestration completes.
    Clears the generic-listener map, the main listener dict, and any
    routing-table entry the listener occupied.

    Args:
        name: Listener name to remove

    Returns:
        True if found and removed, False if not found
    """
    key = name.lower()
    removed = False

    # Drop from the generic (wildcard) map, if present.
    if self._generic_listeners.pop(key, None) is not None:
        removed = True
        pump_logger.debug(f"Unregistered generic listener: {name}")

    # Drop from the main listener dict and its routing-table slot.
    listener = self.listeners.pop(name, None)
    if listener is not None:
        removed = True
        tag = listener.root_tag
        if tag and tag != "*":
            remaining = self.routing_table.get(tag, [])
            if listener in remaining:
                remaining.remove(listener)
                # Delete the key only when we emptied it ourselves.
                if not remaining:
                    del self.routing_table[tag]
    return removed
def register_all(self) -> None:
# First pass: register all listeners
for lc in self.config.listeners:
@ -781,6 +861,8 @@ class StreamPump:
Combined validation + deserialization.
Uses to_id + payload tag to find the right listener and schema.
Falls back to generic listeners (ephemeral orchestration handlers)
when no regular listener matches.
"""
if state.error or state.payload_tree is None:
return state
@ -794,6 +876,19 @@ class StreamPump:
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
listeners = self.routing_table.get(lookup_key, [])
# Fallback: check for generic listener by to_id
# Generic listeners accept any payload type (for orchestration)
if not listeners and to_id:
generic_listener = self._generic_listeners.get(to_id)
if generic_listener:
# Generic listener: skip XSD validation and deserialization
# Pass the raw payload_tree to the handler
state.payload = state.payload_tree # Handler receives Element
state.target_listeners = [generic_listener]
state.metadata["generic_handler"] = True
return state
if not listeners:
state.error = f"No listener for: {lookup_key}"
return state
@ -1008,6 +1103,36 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
)
pump.register_listener(todo_complete_config)
# Register Sequence primitives (orchestration)
from xml_pipeline.primitives.sequence import (
SequenceStart, handle_sequence_start,
)
sequence_config = ListenerConfig(
name="system.sequence",
payload_class_path="xml_pipeline.primitives.sequence.SequenceStart",
handler_path="xml_pipeline.primitives.sequence.handle_sequence_start",
description="System sequence handler - chains listeners in order",
is_agent=False,
payload_class=SequenceStart,
handler=handle_sequence_start,
)
pump.register_listener(sequence_config)
# Register Buffer primitives (fan-out orchestration)
from xml_pipeline.primitives.buffer import (
BufferStart, handle_buffer_start,
)
buffer_config = ListenerConfig(
name="system.buffer",
payload_class_path="xml_pipeline.primitives.buffer.BufferStart",
handler_path="xml_pipeline.primitives.buffer.handle_buffer_start",
description="System buffer handler - fan-out to parallel workers",
is_agent=False,
payload_class=BufferStart,
handler=handle_buffer_start,
)
pump.register_listener(buffer_config)
# Register all user-defined listeners
pump.register_all()
@ -1061,6 +1186,9 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
# Inject boot message (will be processed when pump.run() is called)
await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system")
# Set global pump instance for get_stream_pump()
set_stream_pump(pump)
print(f"Routing: {list(pump.routing_table.keys())}")
return pump
@ -1110,3 +1238,45 @@ The key difference:
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
"""
# ============================================================================
# Global Singleton
# ============================================================================
_pump: Optional[StreamPump] = None


def get_stream_pump() -> StreamPump:
    """
    Return the global StreamPump instance.

    bootstrap() creates the pump and installs it via set_stream_pump();
    calling this before bootstrap is a programming error.

    Raises:
        RuntimeError: If no pump has been installed yet.
    """
    if _pump is None:
        raise RuntimeError(
            "StreamPump not initialized. Call bootstrap() first."
        )
    return _pump


def set_stream_pump(pump: StreamPump) -> None:
    """Install *pump* as the global instance (called by bootstrap())."""
    global _pump
    _pump = pump


def reset_stream_pump() -> None:
    """Clear the global instance; mainly for test isolation."""
    global _pump
    _pump = None

View file

@ -15,16 +15,45 @@ from xml_pipeline.primitives.todo import (
handle_todo_complete,
)
from xml_pipeline.primitives.text_input import TextInput, TextOutput
from xml_pipeline.primitives.sequence import (
SequenceStart,
SequenceComplete,
SequenceError,
handle_sequence_start,
)
from xml_pipeline.primitives.buffer import (
BufferStart,
BufferItem,
BufferComplete,
BufferDispatched,
BufferError,
handle_buffer_start,
)
# Explicit public surface of the primitives package; keeps star-imports
# and API docs limited to the names below.
__all__ = [
    # Boot
    "Boot",
    "handle_boot",
    # Todo
    "TodoUntil",
    "TodoComplete",
    "TodoRegistered",
    "TodoClosed",
    "handle_todo_until",
    "handle_todo_complete",
    # Text I/O
    "TextInput",
    "TextOutput",
    # Sequence orchestration
    "SequenceStart",
    "SequenceComplete",
    "SequenceError",
    "handle_sequence_start",
    # Buffer orchestration
    "BufferStart",
    "BufferItem",
    "BufferComplete",
    "BufferDispatched",
    "BufferError",
    "handle_buffer_start",
]

View file

@ -0,0 +1,373 @@
"""
buffer.py — Buffer (fan-out) orchestration primitives.
Buffers fan-out to parallel workers, sending N items to the same listener
concurrently. Results are collected and returned when all complete.
Usage by an agent:
# Fan-out search queries to web_search
return HandlerResponse(
payload=BufferStart(
target="web_search",
items="python async\\nrust memory\\ngo concurrency",
return_to="my-agent",
collect=True,
),
to="system.buffer",
)
Flow:
1. system.buffer receives BufferStart with N items
2. Creates ephemeral listener buffer_{id} to receive results
3. Creates N sibling threads via ThreadRegistry
4. Sends BufferItem to each worker FROM buffer_{id}
5. Workers process and respond — the response routes to buffer_{id}
6. Ephemeral handler collects results
7. When all workers done, sends BufferComplete to return_to
8. Cleans up ephemeral listener
Fire-and-forget mode (collect=False):
- Returns immediately after dispatching
- No result collection
- Useful for async side effects
"""
from dataclasses import dataclass
from typing import Optional
import uuid as uuid_module
import logging
from lxml import etree
from third_party.xmlable import xmlify
from xml_pipeline.message_bus.message_state import (
HandlerMetadata,
HandlerResponse,
)
from xml_pipeline.message_bus.buffer_registry import get_buffer_registry
from xml_pipeline.message_bus.thread_registry import get_registry as get_thread_registry
logger = logging.getLogger(__name__)
# ============================================================================
# Payloads
# ============================================================================
@xmlify
@dataclass
class BufferStart:
    """
    Start a new buffer (fan-out) execution.

    Sent to system.buffer to begin parallel processing.
    """
    target: str = ""  # Listener to fan-out to
    items: str = ""  # Newline-separated payloads (raw XML), one item per line
    collect: bool = True  # Wait for all results?
    return_to: str = ""  # Where to send BufferComplete (defaults to sender)
    buffer_id: str = ""  # Auto-generated (short UUID) if empty


@xmlify
@dataclass
class BufferItem:
    """
    Individual item being processed by a worker.

    Wraps the actual payload with buffer metadata.

    Note: This is an internal type - workers receive the raw payload,
    not BufferItem directly.
    """
    buffer_id: str = ""  # Owning buffer
    index: int = 0  # 0-based position within the dispatch
    payload: str = ""  # The actual XML payload


@xmlify
@dataclass
class BufferComplete:
    """
    Buffer completed - all workers finished.

    Sent to return_to when all items are processed (collect mode only).
    """
    buffer_id: str = ""
    total: int = 0  # Items dispatched
    successful: int = 0  # Items that completed without error
    results: str = ""  # XML array of results, in dispatch order


@xmlify
@dataclass
class BufferDispatched:
    """
    Buffer dispatched (fire-and-forget mode).

    Sent immediately after items are dispatched when collect=False.
    """
    buffer_id: str = ""
    total: int = 0  # Items dispatched


@xmlify
@dataclass
class BufferError:
    """
    Buffer failed to start.

    Sent when buffer initialization fails (no items, unknown target).
    """
    buffer_id: str = ""  # "unknown" when no ID was assigned yet
    error: str = ""  # Human-readable reason
# ============================================================================
# Handlers
# ============================================================================
async def handle_buffer_start(
    payload: BufferStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle BufferStart - begin a fan-out execution.

    Splits payload.items on newlines, validates the target listener,
    registers tracking state, and dispatches each item on its own sibling
    thread. In collect mode an ephemeral listener (buffer_{id}) gathers
    worker results and later emits BufferComplete; with collect=False a
    BufferDispatched acknowledgement is returned immediately.

    Args:
        payload: The BufferStart request.
        metadata: Routing metadata (thread_id, from_id).

    Returns:
        BufferError on validation failure, BufferDispatched when
        collect=False, or None in collect mode (completion arrives
        asynchronously via the ephemeral listener).
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse items - one work item per non-blank line.
    items = [item.strip() for item in payload.items.split("\n") if item.strip()]
    if not items:
        logger.error("BufferStart with no items")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error="No items specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Validate the target exists before creating any state.
    pump = get_stream_pump()
    if payload.target not in pump.listeners:
        logger.error(f"BufferStart: unknown target '{payload.target}'")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error=f"Unknown target listener: {payload.target}",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate buffer ID if not provided
    buf_id = payload.buffer_id or str(uuid_module.uuid4())[:8]

    # Create buffer state
    buffer_registry = get_buffer_registry()
    buffer_registry.create(
        buffer_id=buf_id,
        total_items=len(items),
        return_to=payload.return_to or metadata.from_id,
        thread_id=metadata.thread_id,
        from_id=metadata.from_id,
        target=payload.target,
        collect=payload.collect,
    )

    # Items are sent FROM this name so worker .respond() routes back here.
    ephemeral_name = f"buffer_{buf_id}"
    if payload.collect:
        # Create ephemeral handler for result collection
        async def buffer_handler(
            payload_tree: etree._Element,
            meta: HandlerMetadata,
        ) -> Optional[HandlerResponse]:
            """Ephemeral handler that collects worker results."""
            return await _handle_buffer_result(buf_id, payload_tree, meta)

        # Register ephemeral listener (generic mode - accepts any payload)
        pump.register_generic_listener(
            name=ephemeral_name,
            handler=buffer_handler,
            description=f"Ephemeral buffer handler for {buf_id}",
        )

    logger.info(
        f"Buffer {buf_id} starting: {len(items)} items to {payload.target}, "
        f"collect={payload.collect}"
    )

    # Dispatch all items in parallel, each on its own sibling thread.
    thread_registry = get_thread_registry()
    parent_chain = thread_registry.lookup(metadata.thread_id) or metadata.thread_id
    for i, item_payload in enumerate(items):
        # Create sibling thread for this worker
        worker_chain = f"{parent_chain}.{ephemeral_name}_w{i}"
        worker_uuid = thread_registry.get_or_create(worker_chain)
        # Inject the item to the target
        # The item is sent FROM the ephemeral listener so .respond() comes back
        await _inject_buffer_item(
            pump=pump,
            target=payload.target,
            payload_xml=item_payload,
            thread_id=worker_uuid,
            from_id=ephemeral_name,
        )
        logger.debug(f"Buffer {buf_id}: dispatched item {i} to {payload.target}")

    # Fire-and-forget: return immediately
    if not payload.collect:
        # BUG FIX: with no ephemeral listener registered, nothing ever
        # reports back, so the registry entry created above would leak
        # forever. Drop it now that dispatch is done.
        buffer_registry.remove(buf_id)
        logger.info(f"Buffer {buf_id}: fire-and-forget mode, {len(items)} items dispatched")
        return HandlerResponse(
            payload=BufferDispatched(
                buffer_id=buf_id,
                total=len(items),
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Collect mode: wait for results (handled by ephemeral listener)
    # Return None - the ephemeral listener will send BufferComplete
    return None
async def _handle_buffer_result(
    buf_id: str,
    payload_tree: etree._Element,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle a worker result in the buffer.

    Called by the ephemeral listener when a worker responds. Records the
    result; when the last worker reports, tears down the ephemeral
    listener and sends BufferComplete to the buffer's return_to.

    Args:
        buf_id: Buffer this result belongs to.
        payload_tree: Raw XML element the worker responded with.
        metadata: Routing metadata (used to recover the worker index).

    Returns:
        BufferComplete when this was the final result, otherwise None.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    buffer_registry = get_buffer_registry()
    state = buffer_registry.get(buf_id)
    if state is None:
        logger.warning(f"Buffer {buf_id} not found in registry (result dropped)")
        return None

    # Extract worker index from thread chain
    # Chain format: parent.buffer_xyz_wN where N is the index
    thread_registry = get_thread_registry()
    chain = thread_registry.lookup(metadata.thread_id) or ""
    worker_index = _extract_worker_index(chain, buf_id)
    if worker_index is None:
        logger.warning(f"Buffer {buf_id}: could not determine worker index from chain")
        worker_index = state.completed_count  # Fallback to count-based

    # Serialize the result
    result_xml = etree.tostring(payload_tree, encoding="unicode")

    # Error responses from workers are recorded, not fatal to the buffer.
    is_error = payload_tree.tag.lower() in ("huh", "systemerror")

    # Record result
    state = buffer_registry.record_result(
        buffer_id=buf_id,
        index=worker_index,
        result=result_xml,
        success=not is_error,
        error=result_xml[:200] if is_error else None,
    )
    if state is None:
        # BUG FIX: record_result() returns None when the buffer was removed
        # between our get() above and now (e.g. a concurrent completion).
        # The original dereferenced the result unconditionally and crashed.
        logger.warning(f"Buffer {buf_id} removed while recording result")
        return None

    logger.debug(
        f"Buffer {buf_id}: received result {state.completed_count}/{state.total_items}"
    )

    # Check if all done
    if state.is_complete:
        # Clean up the ephemeral listener before reporting.
        pump = get_stream_pump()
        pump.unregister_listener(f"buffer_{buf_id}")
        # Format results as XML array
        results_xml = _format_buffer_results(state)
        buffer_registry.remove(buf_id)
        logger.info(
            f"Buffer {buf_id} completed: {state.successful_count}/{state.total_items} successful"
        )
        return HandlerResponse(
            payload=BufferComplete(
                buffer_id=buf_id,
                total=state.total_items,
                successful=state.successful_count,
                results=results_xml,
            ),
            to=state.return_to,
        )

    # More results pending
    return None
def _extract_worker_index(chain: str, buf_id: str) -> Optional[int]:
"""Extract worker index from thread chain."""
# Look for pattern: buffer_{id}_wN
import re
pattern = rf"buffer_{buf_id}_w(\d+)"
match = re.search(pattern, chain)
if match:
return int(match.group(1))
return None
def _format_buffer_results(state) -> str:
"""Format buffer results as XML array."""
lines = ["<results>"]
for i in range(state.total_items):
result = state.results.get(i)
if result:
success = "true" if result.success else "false"
lines.append(f' <item index="{i}" success="{success}">')
# Indent the result content
for line in result.result.split("\n"):
lines.append(f" {line}")
lines.append(" </item>")
else:
lines.append(f' <item index="{i}" success="false">missing</item>')
lines.append("</results>")
return "\n".join(lines)
async def _inject_buffer_item(
    pump,
    target: str,
    payload_xml: str,
    thread_id: str,
    from_id: str,
) -> None:
    """
    Inject a single buffer item directly into the pump.

    The raw XML is wrapped in _RawXmlPayload (which bypasses dataclass
    serialization) and enveloped with from_id set to the ephemeral buffer
    listener, so the worker's .respond() routes back to it.

    NOTE(review): relies on the pump-private _wrap_in_envelope; assumes it
    accepts any object exposing to_xml() - confirm against StreamPump.
    """
    # Wrap the payload in an envelope
    envelope = pump._wrap_in_envelope(
        payload=_RawXmlPayload(payload_xml),
        from_id=from_id,
        to_id=target,
        thread_id=thread_id,
    )
    # Inject into pump
    await pump.inject(envelope, thread_id=thread_id, from_id=from_id)
class _RawXmlPayload:
    """
    Carrier for raw XML that bypasses serialization.

    Duck-types the payload interface expected by envelope wrapping:
    to_xml() returns the stored string unchanged.
    """

    def __init__(self, xml: str):
        # Pre-serialized XML string supplied by the caller.
        self.xml = xml

    def to_xml(self) -> str:
        """Return raw XML for envelope wrapping."""
        return self.xml

View file

@ -0,0 +1,299 @@
"""
sequence.py — Sequence orchestration primitives.
Sequences chain multiple listeners in order, feeding the output of one step
as input to the next. Steps remain transparent - they don't know they're
part of a sequence.
Usage by an agent:
# Start a sequence: add two numbers, then multiply
return HandlerResponse(
payload=SequenceStart(
steps="calculator.add,calculator.multiply",
payload='<AddPayload><a>5</a><b>3</b></AddPayload>',
return_to="my-agent",
),
to="system.sequence",
)
Flow:
1. system.sequence receives SequenceStart
2. Creates ephemeral listener sequence_{id} to receive step results
3. Sends initial payload to first step FROM sequence_{id}
4. Step processes and responds — the response routes to sequence_{id}
5. Ephemeral handler advances, sends to next step
6. When all steps complete, sends SequenceComplete to return_to
7. Cleans up ephemeral listener
Key insight: Steps use normal .respond() - the ephemeral listener IS the
caller in the thread chain, so responses naturally route back to it.
"""
from dataclasses import dataclass
from typing import Optional
import uuid as uuid_module
import logging
from lxml import etree
from third_party.xmlable import xmlify
from xml_pipeline.message_bus.message_state import (
HandlerMetadata,
HandlerResponse,
MessageState,
)
from xml_pipeline.message_bus.sequence_registry import get_sequence_registry
logger = logging.getLogger(__name__)
# ============================================================================
# Payloads
# ============================================================================
@xmlify
@dataclass
class SequenceStart:
    """
    Start a new sequence execution.

    Sent to system.sequence to begin chaining steps.
    """
    steps: str = ""  # Comma-separated listener names, executed in order
    payload: str = ""  # Initial XML payload for the first step
    return_to: str = ""  # Where to send final result (defaults to sender)
    sequence_id: str = ""  # Auto-generated (short UUID) if empty


@xmlify
@dataclass
class SequenceComplete:
    """
    Sequence completed successfully.

    Sent to the return_to listener when all steps finish.
    """
    sequence_id: str = ""
    final_result: str = ""  # XML result from the last step
    step_count: int = 0  # How many steps were executed


@xmlify
@dataclass
class SequenceError:
    """
    Sequence failed at a step.

    Sent to return_to when a step fails or validation rejects the request.
    """
    sequence_id: str = ""  # "unknown" when no ID was assigned yet
    failed_step: str = ""  # Which step failed (empty if none started)
    step_index: int = 0  # 0-based index of failed step
    error: str = ""  # Error message (truncated by the handler)
# ============================================================================
# Handlers
# ============================================================================
async def handle_sequence_start(
    payload: SequenceStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle SequenceStart - begin a sequence execution.

    Parses the comma-separated step list, validates every step up front,
    registers tracking state, installs an ephemeral listener to receive
    step results, and dispatches the initial payload to the first step.

    Args:
        payload: The SequenceStart request.
        metadata: Routing metadata (thread_id, from_id).

    Returns:
        SequenceError on validation failure, otherwise the message that
        kicks off the first step.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse and validate
    steps = [s.strip() for s in payload.steps.split(",") if s.strip()]
    if not steps:
        logger.error("SequenceStart with no steps")
        return HandlerResponse(
            payload=SequenceError(
                sequence_id=payload.sequence_id or "unknown",
                failed_step="",
                step_index=0,
                error="No steps specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate sequence ID if not provided
    seq_id = payload.sequence_id or str(uuid_module.uuid4())[:8]

    # Validate all steps exist before touching any state.
    # BUG FIX: use enumerate instead of steps.index(step) - index() returns
    # the FIRST occurrence, reporting the wrong position when the same
    # listener appears twice in the chain (and is O(n) per step besides).
    pump = get_stream_pump()
    for step_index, step in enumerate(steps):
        if step not in pump.listeners:
            logger.error(f"SequenceStart: unknown step '{step}'")
            return HandlerResponse(
                payload=SequenceError(
                    sequence_id=seq_id,
                    failed_step=step,
                    step_index=step_index,
                    error=f"Unknown listener: {step}",
                ),
                to=payload.return_to or metadata.from_id,
            )

    # Create sequence state
    registry = get_sequence_registry()
    state = registry.create(
        sequence_id=seq_id,
        steps=steps,
        return_to=payload.return_to or metadata.from_id,
        thread_id=metadata.thread_id,
        from_id=metadata.from_id,
        initial_payload=payload.payload,
    )

    # Step results route back to this name via the generic-listener fallback.
    ephemeral_name = f"sequence_{seq_id}"

    async def sequence_handler(
        payload_tree: etree._Element,
        meta: HandlerMetadata,
    ) -> Optional[HandlerResponse]:
        """Ephemeral handler that processes step results."""
        return await _handle_sequence_step_result(seq_id, payload_tree, meta)

    # Register ephemeral listener (generic mode - accepts any payload)
    pump.register_generic_listener(
        name=ephemeral_name,
        handler=sequence_handler,
        description=f"Ephemeral sequence handler for {seq_id}",
    )

    logger.info(
        f"Sequence {seq_id} started: {len(steps)} steps, "
        f"return_to={state.return_to}"
    )

    # Kick off first step
    first_step = steps[0]
    return _create_step_message(
        seq_id=seq_id,
        target=first_step,
        payload_xml=payload.payload,
        from_name=ephemeral_name,
    )
async def _handle_sequence_step_result(
    seq_id: str,
    payload_tree: etree._Element,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle a step result in the sequence.

    Called by the ephemeral listener when a step responds. On error
    responses (Huh / SystemError) the sequence is torn down and a
    SequenceError is sent; otherwise the sequence advances and either
    completes (SequenceComplete) or forwards the result to the next step.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    registry = get_sequence_registry()
    state = registry.get(seq_id)
    if state is None:
        logger.error(f"Sequence {seq_id} not found in registry")
        return None

    # Serialize the result for storage
    result_xml = etree.tostring(payload_tree, encoding="unicode")

    # Check for error responses
    if payload_tree.tag.lower() in ("huh", "systemerror"):
        # Step failed
        error_text = payload_tree.text or etree.tostring(payload_tree, encoding="unicode")
        registry.mark_failed(seq_id, state.current_step or "unknown", error_text)
        # Clean up and send error
        pump = get_stream_pump()
        pump.unregister_listener(f"sequence_{seq_id}")
        registry.remove(seq_id)
        logger.warning(f"Sequence {seq_id} failed at step {state.current_index}")
        return HandlerResponse(
            payload=SequenceError(
                sequence_id=seq_id,
                failed_step=state.current_step or "unknown",
                step_index=state.current_index,
                error=error_text[:200],  # Truncate long errors
            ),
            to=state.return_to,
        )

    # Advance to next step
    state = registry.advance(seq_id, result_xml)
    if state is None:
        # BUG FIX: advance() returns None when the sequence has vanished
        # (e.g. concurrent cleanup). The original dereferenced the result
        # unconditionally and crashed with AttributeError.
        logger.warning(f"Sequence {seq_id} vanished during advance; result dropped")
        return None

    if state.is_complete:
        # All steps done - send completion
        pump = get_stream_pump()
        pump.unregister_listener(f"sequence_{seq_id}")
        registry.remove(seq_id)
        logger.info(f"Sequence {seq_id} completed: {len(state.steps)} steps")
        return HandlerResponse(
            payload=SequenceComplete(
                sequence_id=seq_id,
                final_result=result_xml,
                step_count=len(state.steps),
            ),
            to=state.return_to,
        )

    # More steps to go - send to next step
    next_step = state.current_step
    if next_step is None:
        # BUG FIX: current_step is None when the sequence was marked failed
        # concurrently; forwarding with target=None would misroute the message.
        logger.warning(f"Sequence {seq_id} has no next step; result dropped")
        return None
    logger.debug(
        f"Sequence {seq_id} advancing to step {state.current_index}: {next_step}"
    )
    return _create_step_message(
        seq_id=seq_id,
        target=next_step,
        payload_xml=result_xml,
        from_name=f"sequence_{seq_id}",
    )
def _create_step_message(
    seq_id: str,
    target: str,
    payload_xml: str,
    from_name: str,
) -> HandlerResponse:
    """
    Build the HandlerResponse that forwards *payload_xml* to *target*.

    The payload is wrapped in _RawPayloadCarrier so the pump uses the raw
    XML bytes as-is and stamps the envelope's from_id with the ephemeral
    listener name - that way the step's .respond() routes back to the
    sequence handler.

    Note: seq_id is currently unused here; kept for signature symmetry
    with the callers.
    """
    # BUG FIX: _RawPayloadCarrier is defined in this very module - the
    # original re-imported it from itself (a needless self-import).
    return HandlerResponse(
        payload=_RawPayloadCarrier(xml=payload_xml, from_override=from_name),
        to=target,
    )
class _RawPayloadCarrier:
    """
    Internal carrier for raw XML that bypasses normal serialization.

    When the pump sees this, it uses the raw XML directly instead of
    serializing a dataclass. from_override is intended to make the pump
    stamp the envelope's from_id with the ephemeral listener name so the
    step's .respond() routes back to the sequence handler.

    NOTE(review): confirm the pump actually honors from_override - only
    to_xml() is clearly consumed from this module's view.
    """

    def __init__(self, xml: str, from_override: Optional[str] = None):
        self.xml = xml  # Pre-serialized XML payload
        self.from_override = from_override  # Sender name to stamp, if any

    def to_xml(self) -> str:
        """Return raw XML for envelope wrapping."""
        return self.xml