Add Sequence and Buffer orchestration primitives
Implement two virtual node patterns for message flow orchestration: - Sequence: Chains listeners in order (A→B→C), feeding each step's output as input to the next. Uses ephemeral listeners to intercept step results without modifying core pump behavior. - Buffer: Fan-out to parallel worker threads with optional result collection. Supports fire-and-forget mode (collect=False) for non-blocking dispatch. New files: - sequence_registry.py / buffer_registry.py: State tracking - sequence.py / buffer.py: Payloads and handlers - test_sequence.py / test_buffer.py: 52 new tests Pump additions: - register_generic_listener(): Accept any payload type - unregister_listener(): Cleanup ephemeral listeners - Global singleton accessors for pump instance Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a69eae79c5
commit
a623c534d5
10 changed files with 2465 additions and 2 deletions
635
tests/test_buffer.py
Normal file
635
tests/test_buffer.py
Normal file
|
|
@ -0,0 +1,635 @@
|
||||||
|
"""
|
||||||
|
test_buffer.py — Tests for the Buffer (fan-out) orchestration primitives.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
1. BufferRegistry basic operations
|
||||||
|
2. BufferStart handler
|
||||||
|
3. Result collection
|
||||||
|
4. Fire-and-forget mode
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.buffer_registry import (
|
||||||
|
BufferRegistry,
|
||||||
|
BufferState,
|
||||||
|
BufferItemResult,
|
||||||
|
get_buffer_registry,
|
||||||
|
reset_buffer_registry,
|
||||||
|
)
|
||||||
|
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
|
||||||
|
from xml_pipeline.primitives.buffer import (
|
||||||
|
BufferStart,
|
||||||
|
BufferComplete,
|
||||||
|
BufferDispatched,
|
||||||
|
BufferError,
|
||||||
|
handle_buffer_start,
|
||||||
|
_extract_worker_index,
|
||||||
|
_format_buffer_results,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferRegistry:
|
||||||
|
"""Test BufferRegistry basic operations."""
|
||||||
|
|
||||||
|
def test_create_buffer_state(self):
|
||||||
|
"""create() should create a buffer state with correct fields."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
state = registry.create(
|
||||||
|
buffer_id="buf001",
|
||||||
|
total_items=5,
|
||||||
|
return_to="caller",
|
||||||
|
thread_id="thread-123",
|
||||||
|
from_id="console",
|
||||||
|
target="worker",
|
||||||
|
collect=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.buffer_id == "buf001"
|
||||||
|
assert state.total_items == 5
|
||||||
|
assert state.return_to == "caller"
|
||||||
|
assert state.thread_id == "thread-123"
|
||||||
|
assert state.from_id == "console"
|
||||||
|
assert state.target == "worker"
|
||||||
|
assert state.collect is True
|
||||||
|
assert state.completed_count == 0
|
||||||
|
assert state.successful_count == 0
|
||||||
|
assert state.is_complete is False
|
||||||
|
assert state.pending_count == 5
|
||||||
|
|
||||||
|
def test_get_returns_buffer(self):
|
||||||
|
"""get() should return buffer by ID."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
buffer_id="buf002",
|
||||||
|
total_items=3,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = registry.get("buf002")
|
||||||
|
assert state is not None
|
||||||
|
assert state.buffer_id == "buf002"
|
||||||
|
|
||||||
|
# Non-existent returns None
|
||||||
|
assert registry.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_record_result_stores_result(self):
|
||||||
|
"""record_result() should store result and update counts."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
buffer_id="buf003",
|
||||||
|
total_items=3,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record first result (success)
|
||||||
|
state = registry.record_result(
|
||||||
|
buffer_id="buf003",
|
||||||
|
index=0,
|
||||||
|
result="<Result0/>",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.completed_count == 1
|
||||||
|
assert state.successful_count == 1
|
||||||
|
assert state.is_complete is False
|
||||||
|
assert 0 in state.results
|
||||||
|
assert state.results[0].result == "<Result0/>"
|
||||||
|
assert state.results[0].success is True
|
||||||
|
|
||||||
|
# Record second result (failure)
|
||||||
|
state = registry.record_result(
|
||||||
|
buffer_id="buf003",
|
||||||
|
index=1,
|
||||||
|
result="<Error/>",
|
||||||
|
success=False,
|
||||||
|
error="Timeout",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.completed_count == 2
|
||||||
|
assert state.successful_count == 1 # Still just 1
|
||||||
|
assert state.results[1].success is False
|
||||||
|
assert state.results[1].error == "Timeout"
|
||||||
|
|
||||||
|
# Record third result - now complete
|
||||||
|
state = registry.record_result(
|
||||||
|
buffer_id="buf003",
|
||||||
|
index=2,
|
||||||
|
result="<Result2/>",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.completed_count == 3
|
||||||
|
assert state.successful_count == 2
|
||||||
|
assert state.is_complete is True
|
||||||
|
assert state.pending_count == 0
|
||||||
|
|
||||||
|
def test_record_result_ignores_duplicates(self):
|
||||||
|
"""record_result() should not count same index twice."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
buffer_id="buf004",
|
||||||
|
total_items=2,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record index 0
|
||||||
|
state = registry.record_result("buf004", 0, "<R/>", True)
|
||||||
|
assert state.completed_count == 1
|
||||||
|
|
||||||
|
# Try to record index 0 again
|
||||||
|
state = registry.record_result("buf004", 0, "<Duplicate/>", True)
|
||||||
|
assert state.completed_count == 1 # Should not increment
|
||||||
|
assert state.results[0].result == "<R/>" # Original preserved
|
||||||
|
|
||||||
|
def test_remove_deletes_buffer(self):
|
||||||
|
"""remove() should delete buffer from registry."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
buffer_id="buf005",
|
||||||
|
total_items=1,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert registry.get("buf005") is not None
|
||||||
|
result = registry.remove("buf005")
|
||||||
|
assert result is True
|
||||||
|
assert registry.get("buf005") is None
|
||||||
|
|
||||||
|
# Remove non-existent returns False
|
||||||
|
assert registry.remove("nonexistent") is False
|
||||||
|
|
||||||
|
def test_list_active(self):
|
||||||
|
"""list_active() should return all active buffer IDs."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create("buf-a", 1, "x", "t", "c", "w")
|
||||||
|
registry.create("buf-b", 2, "x", "t", "c", "w")
|
||||||
|
registry.create("buf-c", 3, "x", "t", "c", "w")
|
||||||
|
|
||||||
|
active = registry.list_active()
|
||||||
|
assert set(active) == {"buf-a", "buf-b", "buf-c"}
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
"""clear() should remove all buffers."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create("buf-1", 1, "x", "t", "c", "w")
|
||||||
|
registry.create("buf-2", 2, "x", "t", "c", "w")
|
||||||
|
|
||||||
|
registry.clear()
|
||||||
|
|
||||||
|
assert registry.list_active() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferStateProperties:
|
||||||
|
"""Test BufferState computed properties."""
|
||||||
|
|
||||||
|
def test_is_complete_when_all_received(self):
|
||||||
|
"""is_complete should be True when all items are received."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=3,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
completed_count=3,
|
||||||
|
)
|
||||||
|
assert state.is_complete is True
|
||||||
|
|
||||||
|
def test_pending_count(self):
|
||||||
|
"""pending_count should reflect remaining items."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=5,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
completed_count=2,
|
||||||
|
)
|
||||||
|
assert state.pending_count == 3
|
||||||
|
|
||||||
|
def test_get_ordered_results(self):
|
||||||
|
"""get_ordered_results should return results in order."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=3,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
results={
|
||||||
|
0: BufferItemResult(0, "<R0/>", True),
|
||||||
|
2: BufferItemResult(2, "<R2/>", True),
|
||||||
|
# Index 1 missing
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ordered = state.get_ordered_results()
|
||||||
|
assert len(ordered) == 3
|
||||||
|
assert ordered[0] is not None
|
||||||
|
assert ordered[0].result == "<R0/>"
|
||||||
|
assert ordered[1] is None # Missing
|
||||||
|
assert ordered[2] is not None
|
||||||
|
assert ordered[2].result == "<R2/>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferStartPayload:
|
||||||
|
"""Test BufferStart payload."""
|
||||||
|
|
||||||
|
def test_buffer_start_fields(self):
|
||||||
|
"""BufferStart should have expected fields."""
|
||||||
|
payload = BufferStart(
|
||||||
|
target="worker",
|
||||||
|
items="item1\nitem2\nitem3",
|
||||||
|
collect=True,
|
||||||
|
return_to="caller",
|
||||||
|
buffer_id="custom-id",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload.target == "worker"
|
||||||
|
assert payload.items == "item1\nitem2\nitem3"
|
||||||
|
assert payload.collect is True
|
||||||
|
assert payload.return_to == "caller"
|
||||||
|
assert payload.buffer_id == "custom-id"
|
||||||
|
|
||||||
|
def test_buffer_start_default_values(self):
|
||||||
|
"""BufferStart should have sensible defaults."""
|
||||||
|
payload = BufferStart()
|
||||||
|
|
||||||
|
assert payload.target == ""
|
||||||
|
assert payload.items == ""
|
||||||
|
assert payload.collect is True
|
||||||
|
assert payload.return_to == ""
|
||||||
|
assert payload.buffer_id == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferCompletePayload:
|
||||||
|
"""Test BufferComplete payload."""
|
||||||
|
|
||||||
|
def test_buffer_complete_fields(self):
|
||||||
|
"""BufferComplete should have expected fields."""
|
||||||
|
payload = BufferComplete(
|
||||||
|
buffer_id="buf123",
|
||||||
|
total=5,
|
||||||
|
successful=4,
|
||||||
|
results="<results>...</results>",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload.buffer_id == "buf123"
|
||||||
|
assert payload.total == 5
|
||||||
|
assert payload.successful == 4
|
||||||
|
assert payload.results == "<results>...</results>"
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferDispatchedPayload:
|
||||||
|
"""Test BufferDispatched payload (fire-and-forget mode)."""
|
||||||
|
|
||||||
|
def test_buffer_dispatched_fields(self):
|
||||||
|
"""BufferDispatched should have expected fields."""
|
||||||
|
payload = BufferDispatched(
|
||||||
|
buffer_id="buf456",
|
||||||
|
total=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload.buffer_id == "buf456"
|
||||||
|
assert payload.total == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferErrorPayload:
|
||||||
|
"""Test BufferError payload."""
|
||||||
|
|
||||||
|
def test_buffer_error_fields(self):
|
||||||
|
"""BufferError should have expected fields."""
|
||||||
|
payload = BufferError(
|
||||||
|
buffer_id="buf789",
|
||||||
|
error="Unknown target listener",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload.buffer_id == "buf789"
|
||||||
|
assert payload.error == "Unknown target listener"
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtractWorkerIndex:
|
||||||
|
"""Test _extract_worker_index helper."""
|
||||||
|
|
||||||
|
def test_extracts_index_from_chain(self):
|
||||||
|
"""Should extract worker index from thread chain."""
|
||||||
|
chain = "root.parent.buffer_abc123_w5"
|
||||||
|
index = _extract_worker_index(chain, "abc123")
|
||||||
|
assert index == 5
|
||||||
|
|
||||||
|
def test_extracts_double_digit_index(self):
|
||||||
|
"""Should handle double-digit indices."""
|
||||||
|
chain = "x.buffer_xyz_w42"
|
||||||
|
index = _extract_worker_index(chain, "xyz")
|
||||||
|
assert index == 42
|
||||||
|
|
||||||
|
def test_returns_none_for_no_match(self):
|
||||||
|
"""Should return None when pattern doesn't match."""
|
||||||
|
chain = "something.else.entirely"
|
||||||
|
index = _extract_worker_index(chain, "abc")
|
||||||
|
assert index is None
|
||||||
|
|
||||||
|
def test_returns_none_for_wrong_buffer_id(self):
|
||||||
|
"""Should return None when buffer ID doesn't match."""
|
||||||
|
chain = "root.buffer_other_w3"
|
||||||
|
index = _extract_worker_index(chain, "abc")
|
||||||
|
assert index is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatBufferResults:
|
||||||
|
"""Test _format_buffer_results helper."""
|
||||||
|
|
||||||
|
def test_formats_complete_results(self):
|
||||||
|
"""Should format all results as XML."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=2,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
results={
|
||||||
|
0: BufferItemResult(0, "<Result>A</Result>", True),
|
||||||
|
1: BufferItemResult(1, "<Result>B</Result>", True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
xml = _format_buffer_results(state)
|
||||||
|
|
||||||
|
assert "<results>" in xml
|
||||||
|
assert "</results>" in xml
|
||||||
|
assert 'index="0"' in xml
|
||||||
|
assert 'index="1"' in xml
|
||||||
|
assert 'success="true"' in xml
|
||||||
|
assert "<Result>A</Result>" in xml
|
||||||
|
assert "<Result>B</Result>" in xml
|
||||||
|
|
||||||
|
def test_formats_partial_failure(self):
|
||||||
|
"""Should format mixed success/failure results."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=2,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
results={
|
||||||
|
0: BufferItemResult(0, "<Good/>", True),
|
||||||
|
1: BufferItemResult(1, "<Error/>", False, "timeout"),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
xml = _format_buffer_results(state)
|
||||||
|
|
||||||
|
assert 'success="true"' in xml
|
||||||
|
assert 'success="false"' in xml
|
||||||
|
|
||||||
|
def test_formats_missing_results(self):
|
||||||
|
"""Should handle missing results."""
|
||||||
|
state = BufferState(
|
||||||
|
buffer_id="test",
|
||||||
|
total_items=3,
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
target="w",
|
||||||
|
results={
|
||||||
|
0: BufferItemResult(0, "<R/>", True),
|
||||||
|
# Index 1 missing
|
||||||
|
2: BufferItemResult(2, "<R/>", True),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
xml = _format_buffer_results(state)
|
||||||
|
|
||||||
|
assert 'index="1"' in xml
|
||||||
|
assert "missing" in xml
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleBufferStartValidation:
|
||||||
|
"""Test validation in handle_buffer_start."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup(self):
|
||||||
|
"""Reset registries before each test."""
|
||||||
|
reset_buffer_registry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_items_returns_error(self):
|
||||||
|
"""handle_buffer_start should return error for empty items."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_pump = MagicMock()
|
||||||
|
mock_pump.listeners = {}
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
payload = BufferStart(
|
||||||
|
target="worker",
|
||||||
|
items="", # Empty
|
||||||
|
return_to="caller",
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = HandlerMetadata(
|
||||||
|
thread_id=str(uuid.uuid4()),
|
||||||
|
from_id="console",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_buffer_start(payload, metadata)
|
||||||
|
|
||||||
|
assert isinstance(response, HandlerResponse)
|
||||||
|
assert isinstance(response.payload, BufferError)
|
||||||
|
assert "No items" in response.payload.error
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_target_returns_error(self):
|
||||||
|
"""handle_buffer_start should return error for unknown target."""
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
mock_pump = MagicMock()
|
||||||
|
mock_pump.listeners = {} # No listeners
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
payload = BufferStart(
|
||||||
|
target="unknown_worker",
|
||||||
|
items="item1\nitem2",
|
||||||
|
return_to="caller",
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = HandlerMetadata(
|
||||||
|
thread_id=str(uuid.uuid4()),
|
||||||
|
from_id="console",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_buffer_start(payload, metadata)
|
||||||
|
|
||||||
|
assert isinstance(response, HandlerResponse)
|
||||||
|
assert isinstance(response.payload, BufferError)
|
||||||
|
assert "Unknown target" in response.payload.error
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferRegistrySingleton:
|
||||||
|
"""Test singleton pattern for BufferRegistry."""
|
||||||
|
|
||||||
|
def test_get_buffer_registry_returns_singleton(self):
|
||||||
|
"""get_buffer_registry should return same instance."""
|
||||||
|
reset_buffer_registry()
|
||||||
|
|
||||||
|
reg1 = get_buffer_registry()
|
||||||
|
reg2 = get_buffer_registry()
|
||||||
|
|
||||||
|
assert reg1 is reg2
|
||||||
|
|
||||||
|
def test_reset_creates_new_instance(self):
|
||||||
|
"""reset_buffer_registry should clear singleton."""
|
||||||
|
reg1 = get_buffer_registry()
|
||||||
|
reg1.create("test", 1, "x", "t", "c", "w")
|
||||||
|
|
||||||
|
reset_buffer_registry()
|
||||||
|
reg2 = get_buffer_registry()
|
||||||
|
|
||||||
|
assert reg2.get("test") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferCollectVsFireAndForget:
|
||||||
|
"""Test collect=True vs collect=False behavior."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_collect_mode_returns_none(self):
|
||||||
|
"""With collect=True, handler should return None (wait for results)."""
|
||||||
|
from unittest.mock import patch, MagicMock, AsyncMock
|
||||||
|
|
||||||
|
mock_pump = MagicMock()
|
||||||
|
mock_pump.listeners = {"worker": MagicMock()}
|
||||||
|
mock_pump.register_generic_listener = MagicMock()
|
||||||
|
mock_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
|
||||||
|
mock_pump.inject = AsyncMock()
|
||||||
|
|
||||||
|
# Mock thread registry
|
||||||
|
mock_thread_registry = MagicMock()
|
||||||
|
mock_thread_registry.lookup = MagicMock(return_value="root.parent")
|
||||||
|
mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid")
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
|
||||||
|
payload = BufferStart(
|
||||||
|
target="worker",
|
||||||
|
items="item1\nitem2",
|
||||||
|
return_to="caller",
|
||||||
|
collect=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = HandlerMetadata(
|
||||||
|
thread_id=str(uuid.uuid4()),
|
||||||
|
from_id="console",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_buffer_start(payload, metadata)
|
||||||
|
|
||||||
|
# With collect=True, returns None (ephemeral handler will send BufferComplete)
|
||||||
|
assert response is None
|
||||||
|
|
||||||
|
# Ephemeral listener should be registered
|
||||||
|
mock_pump.register_generic_listener.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fire_and_forget_returns_dispatched(self):
|
||||||
|
"""With collect=False, handler should return BufferDispatched immediately."""
|
||||||
|
from unittest.mock import patch, MagicMock, AsyncMock
|
||||||
|
|
||||||
|
mock_pump = MagicMock()
|
||||||
|
mock_pump.listeners = {"worker": MagicMock()}
|
||||||
|
mock_pump.register_generic_listener = MagicMock()
|
||||||
|
mock_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
|
||||||
|
mock_pump.inject = AsyncMock()
|
||||||
|
|
||||||
|
mock_thread_registry = MagicMock()
|
||||||
|
mock_thread_registry.lookup = MagicMock(return_value="root.parent")
|
||||||
|
mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid")
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
|
||||||
|
payload = BufferStart(
|
||||||
|
target="worker",
|
||||||
|
items="item1\nitem2\nitem3",
|
||||||
|
return_to="caller",
|
||||||
|
collect=False, # Fire-and-forget
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata = HandlerMetadata(
|
||||||
|
thread_id=str(uuid.uuid4()),
|
||||||
|
from_id="console",
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await handle_buffer_start(payload, metadata)
|
||||||
|
|
||||||
|
# With collect=False, returns BufferDispatched
|
||||||
|
assert isinstance(response, HandlerResponse)
|
||||||
|
assert isinstance(response.payload, BufferDispatched)
|
||||||
|
assert response.payload.total == 3
|
||||||
|
|
||||||
|
# No ephemeral listener registered for fire-and-forget
|
||||||
|
mock_pump.register_generic_listener.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferResultCollection:
|
||||||
|
"""Test result collection behavior."""
|
||||||
|
|
||||||
|
def test_result_collection_partial_success(self):
|
||||||
|
"""Buffer should track partial success correctly."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create("partial", 5, "x", "t", "c", "w")
|
||||||
|
|
||||||
|
# 3 successes, 2 failures
|
||||||
|
registry.record_result("partial", 0, "<R/>", True)
|
||||||
|
registry.record_result("partial", 1, "<E/>", False, "error")
|
||||||
|
registry.record_result("partial", 2, "<R/>", True)
|
||||||
|
registry.record_result("partial", 3, "<E/>", False, "error")
|
||||||
|
registry.record_result("partial", 4, "<R/>", True)
|
||||||
|
|
||||||
|
state = registry.get("partial")
|
||||||
|
|
||||||
|
assert state.is_complete is True
|
||||||
|
assert state.completed_count == 5
|
||||||
|
assert state.successful_count == 3
|
||||||
|
|
||||||
|
def test_results_out_of_order(self):
|
||||||
|
"""Buffer should handle results arriving out of order."""
|
||||||
|
registry = BufferRegistry()
|
||||||
|
|
||||||
|
registry.create("ooo", 3, "x", "t", "c", "w")
|
||||||
|
|
||||||
|
# Results arrive out of order
|
||||||
|
registry.record_result("ooo", 2, "<R2/>", True)
|
||||||
|
registry.record_result("ooo", 0, "<R0/>", True)
|
||||||
|
registry.record_result("ooo", 1, "<R1/>", True)
|
||||||
|
|
||||||
|
state = registry.get("ooo")
|
||||||
|
|
||||||
|
assert state.is_complete is True
|
||||||
|
ordered = state.get_ordered_results()
|
||||||
|
assert ordered[0].result == "<R0/>"
|
||||||
|
assert ordered[1].result == "<R1/>"
|
||||||
|
assert ordered[2].result == "<R2/>"
|
||||||
|
|
@ -66,7 +66,7 @@ class TestPumpBootstrap:
|
||||||
pump = await bootstrap('config/organism.yaml')
|
pump = await bootstrap('config/organism.yaml')
|
||||||
|
|
||||||
assert pump.config.name == "hello-world"
|
assert pump.config.name == "hello-world"
|
||||||
assert len(pump.routing_table) == 6 # 3 user listeners + 3 system (boot, todo, todo-complete)
|
assert len(pump.routing_table) == 8 # 3 user listeners + 5 system (boot, todo, todo-complete, sequence, buffer)
|
||||||
assert "greeter.greeting" in pump.routing_table
|
assert "greeter.greeting" in pump.routing_table
|
||||||
assert "shouter.greetingresponse" in pump.routing_table
|
assert "shouter.greetingresponse" in pump.routing_table
|
||||||
assert "response-handler.shoutedresponse" in pump.routing_table
|
assert "response-handler.shoutedresponse" in pump.routing_table
|
||||||
|
|
|
||||||
464
tests/test_sequence.py
Normal file
464
tests/test_sequence.py
Normal file
|
|
@ -0,0 +1,464 @@
|
||||||
|
"""
|
||||||
|
test_sequence.py — Tests for the Sequence orchestration primitives.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
1. SequenceRegistry basic operations
|
||||||
|
2. SequenceStart handler
|
||||||
|
3. Step result handling
|
||||||
|
4. Error propagation
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.sequence_registry import (
|
||||||
|
SequenceRegistry,
|
||||||
|
SequenceState,
|
||||||
|
get_sequence_registry,
|
||||||
|
reset_sequence_registry,
|
||||||
|
)
|
||||||
|
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
|
||||||
|
from xml_pipeline.primitives.sequence import (
|
||||||
|
SequenceStart,
|
||||||
|
SequenceComplete,
|
||||||
|
SequenceError,
|
||||||
|
handle_sequence_start,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceRegistry:
|
||||||
|
"""Test SequenceRegistry basic operations."""
|
||||||
|
|
||||||
|
def test_create_sequence_state(self):
|
||||||
|
"""create() should create a sequence state with correct fields."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
state = registry.create(
|
||||||
|
sequence_id="seq001",
|
||||||
|
steps=["step1", "step2", "step3"],
|
||||||
|
return_to="caller",
|
||||||
|
thread_id="thread-123",
|
||||||
|
from_id="console",
|
||||||
|
initial_payload="<TestPayload/>",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert state.sequence_id == "seq001"
|
||||||
|
assert state.steps == ["step1", "step2", "step3"]
|
||||||
|
assert state.return_to == "caller"
|
||||||
|
assert state.thread_id == "thread-123"
|
||||||
|
assert state.from_id == "console"
|
||||||
|
assert state.current_index == 0
|
||||||
|
assert state.is_complete is False
|
||||||
|
assert state.current_step == "step1"
|
||||||
|
assert state.remaining_steps == ["step1", "step2", "step3"]
|
||||||
|
assert state.last_result == "<TestPayload/>"
|
||||||
|
|
||||||
|
def test_get_returns_sequence(self):
|
||||||
|
"""get() should return sequence by ID."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
sequence_id="seq002",
|
||||||
|
steps=["a", "b"],
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = registry.get("seq002")
|
||||||
|
assert state is not None
|
||||||
|
assert state.sequence_id == "seq002"
|
||||||
|
|
||||||
|
# Non-existent returns None
|
||||||
|
assert registry.get("nonexistent") is None
|
||||||
|
|
||||||
|
def test_advance_increments_index(self):
|
||||||
|
"""advance() should increment index and store result."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
sequence_id="seq003",
|
||||||
|
steps=["a", "b", "c"],
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance first step
|
||||||
|
state = registry.advance("seq003", "<ResultA/>")
|
||||||
|
assert state.current_index == 1
|
||||||
|
assert state.results == ["<ResultA/>"]
|
||||||
|
assert state.last_result == "<ResultA/>"
|
||||||
|
assert state.current_step == "b"
|
||||||
|
assert state.is_complete is False
|
||||||
|
|
||||||
|
# Advance second step
|
||||||
|
state = registry.advance("seq003", "<ResultB/>")
|
||||||
|
assert state.current_index == 2
|
||||||
|
assert state.results == ["<ResultA/>", "<ResultB/>"]
|
||||||
|
assert state.current_step == "c"
|
||||||
|
|
||||||
|
# Advance third step - now complete
|
||||||
|
state = registry.advance("seq003", "<ResultC/>")
|
||||||
|
assert state.current_index == 3
|
||||||
|
assert state.is_complete is True
|
||||||
|
assert state.current_step is None
|
||||||
|
assert state.remaining_steps == []
|
||||||
|
|
||||||
|
def test_mark_failed(self):
|
||||||
|
"""mark_failed() should set failed state."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
sequence_id="seq004",
|
||||||
|
steps=["a", "b"],
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = registry.mark_failed("seq004", "a", "XSD validation failed")
|
||||||
|
|
||||||
|
assert state.failed is True
|
||||||
|
assert state.failed_step == "a"
|
||||||
|
assert state.error == "XSD validation failed"
|
||||||
|
assert state.is_complete is False # Not complete, but failed
|
||||||
|
assert state.current_step is None # No current step when failed
|
||||||
|
assert state.remaining_steps == [] # No remaining steps when failed
|
||||||
|
|
||||||
|
def test_remove_deletes_sequence(self):
|
||||||
|
"""remove() should delete sequence from registry."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create(
|
||||||
|
sequence_id="seq005",
|
||||||
|
steps=["a"],
|
||||||
|
return_to="x",
|
||||||
|
thread_id="t",
|
||||||
|
from_id="c",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert registry.get("seq005") is not None
|
||||||
|
result = registry.remove("seq005")
|
||||||
|
assert result is True
|
||||||
|
assert registry.get("seq005") is None
|
||||||
|
|
||||||
|
# Remove non-existent returns False
|
||||||
|
assert registry.remove("nonexistent") is False
|
||||||
|
|
||||||
|
def test_list_active(self):
|
||||||
|
"""list_active() should return all active sequence IDs."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create("seq-a", ["1"], "x", "t", "c")
|
||||||
|
registry.create("seq-b", ["2"], "x", "t", "c")
|
||||||
|
registry.create("seq-c", ["3"], "x", "t", "c")
|
||||||
|
|
||||||
|
active = registry.list_active()
|
||||||
|
assert set(active) == {"seq-a", "seq-b", "seq-c"}
|
||||||
|
|
||||||
|
def test_clear(self):
|
||||||
|
"""clear() should remove all sequences."""
|
||||||
|
registry = SequenceRegistry()
|
||||||
|
|
||||||
|
registry.create("seq-1", ["a"], "x", "t", "c")
|
||||||
|
registry.create("seq-2", ["b"], "x", "t", "c")
|
||||||
|
|
||||||
|
registry.clear()
|
||||||
|
|
||||||
|
assert registry.list_active() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceStateProperties:
    """Test SequenceState computed properties."""

    @staticmethod
    def _make_state(**overrides):
        """Build a SequenceState with common defaults, overridden per test."""
        kwargs = {
            "sequence_id": "test",
            "steps": ["a", "b"],
            "return_to": "x",
            "thread_id": "t",
            "from_id": "c",
        }
        kwargs.update(overrides)
        return SequenceState(**kwargs)

    def test_is_complete_after_all_steps(self):
        """is_complete should be True when all steps are done."""
        state = self._make_state(current_index=2)  # past all steps
        assert state.is_complete is True

    def test_not_complete_when_failed(self):
        """is_complete should be False when failed."""
        state = self._make_state(current_index=2, failed=True)
        assert state.is_complete is False

    def test_current_step_none_when_complete(self):
        """current_step should be None when sequence is complete."""
        state = self._make_state(steps=["a"], current_index=1)
        assert state.current_step is None

    def test_remaining_steps_empty_when_complete(self):
        """remaining_steps should be empty when complete."""
        state = self._make_state(current_index=2)
        assert state.remaining_steps == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceStartPayload:
    """Test SequenceStart payload serialization."""

    def test_sequence_start_fields(self):
        """SequenceStart should have expected fields."""
        expected = {
            "steps": "step1,step2",
            "payload": "<Test/>",
            "return_to": "caller",
            "sequence_id": "custom-id",
        }
        payload = SequenceStart(**expected)

        # Every constructor argument should round-trip onto the payload.
        for attr, value in expected.items():
            assert getattr(payload, attr) == value

    def test_sequence_start_default_values(self):
        """SequenceStart should have sensible defaults."""
        payload = SequenceStart()

        # All fields default to the empty string.
        for attr in ("steps", "payload", "return_to", "sequence_id"):
            assert getattr(payload, attr) == ""
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceCompletePayload:
    """Test SequenceComplete payload."""

    def test_sequence_complete_fields(self):
        """SequenceComplete should have expected fields."""
        result_xml = "<Result>42</Result>"
        payload = SequenceComplete(
            sequence_id="seq123",
            final_result=result_xml,
            step_count=3,
        )

        # Constructor arguments should be reflected verbatim.
        assert payload.step_count == 3
        assert payload.sequence_id == "seq123"
        assert payload.final_result == result_xml
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceErrorPayload:
    """Test SequenceError payload."""

    def test_sequence_error_fields(self):
        """SequenceError should have expected fields."""
        message = "Validation failed"
        payload = SequenceError(
            sequence_id="seq456",
            failed_step="bad-step",
            step_index=1,
            error=message,
        )

        # All four fields should round-trip through the constructor.
        assert payload.error == message
        assert payload.step_index == 1
        assert payload.sequence_id == "seq456"
        assert payload.failed_step == "bad-step"
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandleSequenceStartValidation:
    """Test validation in handle_sequence_start."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Reset registries before each test."""
        # autouse: runs before every test method in this class so no
        # sequence state leaks between tests.
        reset_sequence_registry()

    @pytest.mark.asyncio
    async def test_empty_steps_returns_error(self):
        """handle_sequence_start should return error for empty steps."""
        from unittest.mock import patch, MagicMock

        # Mock get_stream_pump to avoid dependency
        mock_pump = MagicMock()
        mock_pump.listeners = {}

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="",  # Empty
                payload="<Test/>",
                return_to="caller",
            )

            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )

            response = await handle_sequence_start(payload, metadata)

            # An empty step list is rejected up-front with a SequenceError.
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, SequenceError)
            assert "No steps" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_step_returns_error(self):
        """handle_sequence_start should return error for unknown step."""
        from unittest.mock import patch, MagicMock

        # Mock get_stream_pump with no listeners
        mock_pump = MagicMock()
        mock_pump.listeners = {}  # No listeners registered

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="unknown_step",
                payload="<Test/>",
                return_to="caller",
            )

            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )

            response = await handle_sequence_start(payload, metadata)

            # A step name that matches no registered listener fails fast,
            # reporting both the error and which step was bad.
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, SequenceError)
            assert "Unknown listener" in response.payload.error
            assert "unknown_step" in response.payload.failed_step
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceRegistrySingleton:
    """Test singleton pattern for SequenceRegistry."""

    def test_get_sequence_registry_returns_singleton(self):
        """get_sequence_registry should return same instance."""
        reset_sequence_registry()

        first = get_sequence_registry()
        second = get_sequence_registry()

        # Repeated accessor calls must hand back the identical object.
        assert first is second

    def test_reset_creates_new_instance(self):
        """reset_sequence_registry should clear singleton."""
        before = get_sequence_registry()
        before.create("test", ["a"], "x", "t", "c")

        reset_sequence_registry()
        after = get_sequence_registry()

        # The post-reset registry must not see pre-reset state.
        assert after.get("test") is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceMultipleSteps:
    """Test sequences with multiple steps."""

    def test_three_step_sequence(self):
        """A three-step sequence should advance through all steps."""
        registry = SequenceRegistry()

        registry.create(
            sequence_id="multi",
            steps=["add", "multiply", "format"],
            return_to="caller",
            thread_id="t",
            from_id="c",
            initial_payload="<Input>5</Input>",
        )

        # (step name, result produced by that step) in execution order.
        script = [
            ("add", "<Sum>8</Sum>"),
            ("multiply", "<Product>40</Product>"),
            ("format", "<Formatted>Result: 40</Formatted>"),
        ]

        state = registry.get("multi")
        for step_name, step_result in script:
            assert state.current_step == step_name
            state = registry.advance("multi", step_result)
            assert state.last_result == step_result

        # All steps done; results are collected in order.
        assert state.is_complete is True
        assert state.results == [result for _, result in script]

    def test_failure_at_middle_step(self):
        """Failure at middle step should stop sequence."""
        registry = SequenceRegistry()

        registry.create(
            sequence_id="fail-mid",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="t",
            from_id="c",
        )

        # Step 1 succeeds, then step 2 reports a failure.
        registry.advance("fail-mid", "<R1/>")
        state = registry.mark_failed("fail-mid", "step2", "Connection timeout")

        assert state.failed is True
        assert state.failed_step == "step2"
        assert state.current_index == 1  # was at step 2 (index 1)
        assert len(state.results) == 1  # only step 1 produced a result
|
||||||
|
|
||||||
|
|
||||||
|
class TestSequenceWithRealSteps:
    """Integration-style tests with mock handlers."""

    @pytest.mark.asyncio
    async def test_sequence_creates_ephemeral_listener(self):
        """Starting a sequence should create an ephemeral listener."""
        # Fix: dropped the unused AsyncMock import and the unused
        # `response` binding — only patch/MagicMock are needed here.
        from unittest.mock import patch, MagicMock

        mock_pump = MagicMock()
        mock_pump.listeners = {"step1": MagicMock(), "step2": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="step1,step2",
                payload="<Input/>",
                return_to="caller",
            )

            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )

            await handle_sequence_start(payload, metadata)

            # Should have registered an ephemeral listener; accept either
            # keyword or positional call style for the name argument.
            mock_pump.register_generic_listener.assert_called_once()
            call_args = mock_pump.register_generic_listener.call_args
            name_arg = call_args.kwargs.get('name') or call_args.args[0]
            assert name_arg.startswith("sequence_")
|
||||||
|
|
@ -30,6 +30,9 @@ from xml_pipeline.message_bus.stream_pump import (
|
||||||
ListenerConfig,
|
ListenerConfig,
|
||||||
OrganismConfig,
|
OrganismConfig,
|
||||||
bootstrap,
|
bootstrap,
|
||||||
|
get_stream_pump,
|
||||||
|
set_stream_pump,
|
||||||
|
reset_stream_pump,
|
||||||
)
|
)
|
||||||
|
|
||||||
from xml_pipeline.message_bus.message_state import (
|
from xml_pipeline.message_bus.message_state import (
|
||||||
|
|
@ -42,15 +45,47 @@ from xml_pipeline.message_bus.system_pipeline import (
|
||||||
ExternalMessage,
|
ExternalMessage,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.sequence_registry import (
|
||||||
|
SequenceState,
|
||||||
|
SequenceRegistry,
|
||||||
|
get_sequence_registry,
|
||||||
|
reset_sequence_registry,
|
||||||
|
)
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.buffer_registry import (
|
||||||
|
BufferState,
|
||||||
|
BufferItemResult,
|
||||||
|
BufferRegistry,
|
||||||
|
get_buffer_registry,
|
||||||
|
reset_buffer_registry,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Pump
|
||||||
"StreamPump",
|
"StreamPump",
|
||||||
"ConfigLoader",
|
"ConfigLoader",
|
||||||
"Listener",
|
"Listener",
|
||||||
"ListenerConfig",
|
"ListenerConfig",
|
||||||
"OrganismConfig",
|
"OrganismConfig",
|
||||||
|
"bootstrap",
|
||||||
|
"get_stream_pump",
|
||||||
|
"set_stream_pump",
|
||||||
|
"reset_stream_pump",
|
||||||
|
# Message state
|
||||||
"MessageState",
|
"MessageState",
|
||||||
"HandlerMetadata",
|
"HandlerMetadata",
|
||||||
"bootstrap",
|
# System pipeline
|
||||||
"SystemPipeline",
|
"SystemPipeline",
|
||||||
"ExternalMessage",
|
"ExternalMessage",
|
||||||
|
# Sequence registry
|
||||||
|
"SequenceState",
|
||||||
|
"SequenceRegistry",
|
||||||
|
"get_sequence_registry",
|
||||||
|
"reset_sequence_registry",
|
||||||
|
# Buffer registry
|
||||||
|
"BufferState",
|
||||||
|
"BufferItemResult",
|
||||||
|
"BufferRegistry",
|
||||||
|
"get_buffer_registry",
|
||||||
|
"reset_buffer_registry",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
230
xml_pipeline/message_bus/buffer_registry.py
Normal file
230
xml_pipeline/message_bus/buffer_registry.py
Normal file
|
|
@ -0,0 +1,230 @@
|
||||||
|
"""
|
||||||
|
buffer_registry.py — State storage for Buffer (fan-out) orchestration.
|
||||||
|
|
||||||
|
Tracks active buffer executions that fan-out to parallel workers.
|
||||||
|
When a buffer starts, N items are dispatched in parallel. Results are
|
||||||
|
collected here. When all results are in (or timeout), BufferComplete is sent.
|
||||||
|
|
||||||
|
Design:
|
||||||
|
- Thread-safe (same pattern as TodoRegistry, SequenceRegistry)
|
||||||
|
- Keyed by buffer_id (short UUID)
|
||||||
|
- Tracks: total items, received results, success/failure per item
|
||||||
|
- Supports fire-and-forget mode (collect=False)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
registry = get_buffer_registry()
|
||||||
|
|
||||||
|
# Start a buffer
|
||||||
|
registry.create(
|
||||||
|
buffer_id="abc123",
|
||||||
|
total_items=5,
|
||||||
|
return_to="greeter",
|
||||||
|
thread_id="...",
|
||||||
|
collect=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record result from worker
|
||||||
|
state = registry.record_result(
|
||||||
|
buffer_id="abc123",
|
||||||
|
index=2,
|
||||||
|
result="<SearchResult>...</SearchResult>",
|
||||||
|
success=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if state.is_complete:
|
||||||
|
# All workers done
|
||||||
|
final_results = state.results
|
||||||
|
registry.remove(buffer_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BufferItemResult:
    """Result from a single buffer item (worker)."""

    index: int  # 0-based position of this item in the dispatch order
    result: str  # XML result
    success: bool = True  # whether the worker succeeded
    error: Optional[str] = None  # error message supplied on failure
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BufferState:
    """Mutable tracking record for one in-flight buffer (fan-out) run."""

    buffer_id: str
    total_items: int  # How many items were dispatched
    return_to: str  # Where to send BufferComplete
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the buffer
    target: str  # Target listener for items
    collect: bool = True  # Whether to wait for results

    results: Dict[int, BufferItemResult] = field(default_factory=dict)
    completed_count: int = 0
    successful_count: int = 0

    @property
    def is_complete(self) -> bool:
        """True once every dispatched item has reported back."""
        return not (self.completed_count < self.total_items)

    @property
    def pending_count(self) -> int:
        """How many items have not reported yet."""
        outstanding = self.total_items - self.completed_count
        return outstanding

    def get_ordered_results(self) -> List[Optional[BufferItemResult]]:
        """Results by item index, with None filling any missing slots."""
        return list(map(self.results.get, range(self.total_items)))
|
||||||
|
|
||||||
|
|
||||||
|
class BufferRegistry:
    """
    Tracks in-flight buffer executions, keyed by buffer_id.

    Every mutation happens under a single lock, so instances are safe to
    share across threads. The process-wide instance is obtained via
    get_buffer_registry().
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._buffers: Dict[str, BufferState] = {}

    def create(
        self,
        buffer_id: str,
        total_items: int,
        return_to: str,
        thread_id: str,
        from_id: str,
        target: str,
        collect: bool = True,
    ) -> BufferState:
        """
        Register a new buffer execution.

        Args:
            buffer_id: Unique ID for this buffer
            total_items: Number of items being dispatched
            return_to: Listener to send BufferComplete to
            thread_id: Thread UUID for routing
            from_id: Who initiated the buffer
            target: Target listener for each item
            collect: Whether to wait for and collect results

        Returns:
            The freshly created BufferState.
        """
        new_state = BufferState(
            buffer_id=buffer_id,
            total_items=total_items,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            target=target,
            collect=collect,
        )
        with self._lock:
            self._buffers[buffer_id] = new_state
        return new_state

    def get(self, buffer_id: str) -> Optional[BufferState]:
        """Look up buffer state by ID; None when unknown."""
        with self._lock:
            return self._buffers.get(buffer_id)

    def record_result(
        self,
        buffer_id: str,
        index: int,
        result: str,
        success: bool = True,
        error: Optional[str] = None,
    ) -> Optional[BufferState]:
        """
        Store one worker's result and bump the completion counters.

        A second report for the same index is ignored (no double count).

        Args:
            buffer_id: Buffer this result belongs to
            index: Which item index (0-based)
            result: XML result from the worker
            success: Whether the worker succeeded
            error: Error message if failed

        Returns:
            Updated BufferState, or None if buffer not found
        """
        with self._lock:
            state = self._buffers.get(buffer_id)
            if state is None:
                return None

            if index not in state.results:
                state.results[index] = BufferItemResult(
                    index=index,
                    result=result,
                    success=success,
                    error=error,
                )
                state.completed_count += 1
                if success:
                    state.successful_count += 1

            return state

    def remove(self, buffer_id: str) -> bool:
        """
        Drop a buffer after completion.

        Returns:
            True if found and removed, False if not found
        """
        with self._lock:
            return self._buffers.pop(buffer_id, None) is not None

    def list_active(self) -> List[str]:
        """IDs of every buffer still being tracked."""
        with self._lock:
            return [*self._buffers]

    def clear(self) -> None:
        """Forget every buffer. Useful for testing."""
        with self._lock:
            self._buffers.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Singleton
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Process-wide singleton instance, guarded by _registry_lock.
_registry: Optional[BufferRegistry] = None
_registry_lock = threading.Lock()


def get_buffer_registry() -> BufferRegistry:
    """Get the global BufferRegistry singleton (created lazily)."""
    global _registry
    # Double-checked locking: lock-free fast path once initialized,
    # lock taken only for first-time creation.
    if _registry is None:
        with _registry_lock:
            if _registry is None:
                _registry = BufferRegistry()
    return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_buffer_registry() -> None:
    """Reset the global buffer registry (for testing).

    Clears the current instance's contents before dropping the reference
    so any stale holders see an empty registry rather than old state.
    """
    global _registry
    with _registry_lock:
        if _registry is not None:
            _registry.clear()
        _registry = None
|
||||||
228
xml_pipeline/message_bus/sequence_registry.py
Normal file
228
xml_pipeline/message_bus/sequence_registry.py
Normal file
|
|
@ -0,0 +1,228 @@
|
||||||
|
"""
|
||||||
|
sequence_registry.py — State storage for Sequence orchestration.
|
||||||
|
|
||||||
|
Tracks active sequence executions across handler invocations.
|
||||||
|
When a sequence starts, its state is registered here. As steps complete,
|
||||||
|
the state is updated. When all steps are done, the state is cleaned up.
|
||||||
|
|
||||||
|
Design:
|
||||||
|
- Thread-safe (same pattern as TodoRegistry)
|
||||||
|
- Keyed by sequence_id (short UUID)
|
||||||
|
- Tracks: steps list, current index, collected results
|
||||||
|
- Auto-cleanup when sequence completes or errors
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
registry = get_sequence_registry()
|
||||||
|
|
||||||
|
# Start a sequence
|
||||||
|
registry.create(
|
||||||
|
sequence_id="abc123",
|
||||||
|
steps=["calculator.add", "calculator.multiply"],
|
||||||
|
return_to="greeter",
|
||||||
|
thread_id="...",
|
||||||
|
initial_payload="<AddPayload>...</AddPayload>",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance on step completion
|
||||||
|
state = registry.advance(sequence_id, step_result="<AddResult>42</AddResult>")
|
||||||
|
if state.is_complete:
|
||||||
|
# All steps done
|
||||||
|
registry.remove(sequence_id)
|
||||||
|
|
||||||
|
# On error
|
||||||
|
registry.mark_failed(sequence_id, step="calculator.add", error="XSD validation failed")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
import threading
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SequenceState:
    """Mutable tracking record for one in-flight sequence execution."""

    sequence_id: str
    steps: List[str]  # Ordered list of listener names
    return_to: str  # Where to send final result
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the sequence

    current_index: int = 0  # Which step we're on (0-based)
    results: List[str] = field(default_factory=list)  # XML results from each step
    last_result: Optional[str] = None  # Most recent step result (for chaining)

    failed: bool = False
    failed_step: Optional[str] = None
    error: Optional[str] = None

    @property
    def is_complete(self) -> bool:
        """All steps finished and the sequence never failed."""
        past_last_step = self.current_index >= len(self.steps)
        return past_last_step and not self.failed

    @property
    def current_step(self) -> Optional[str]:
        """Name of the step now executing, or None once done/failed."""
        if not self.failed and self.current_index < len(self.steps):
            return self.steps[self.current_index]
        return None

    @property
    def remaining_steps(self) -> List[str]:
        """Steps not yet executed (empty once the sequence has failed)."""
        return [] if self.failed else self.steps[self.current_index:]
|
||||||
|
|
||||||
|
|
||||||
|
class SequenceRegistry:
    """
    Tracks in-flight sequence executions, keyed by sequence_id.

    Every mutation happens under a single lock, so instances are safe to
    share across threads. The process-wide instance is obtained via
    get_sequence_registry().
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sequences: Dict[str, SequenceState] = {}

    def create(
        self,
        sequence_id: str,
        steps: List[str],
        return_to: str,
        thread_id: str,
        from_id: str,
        initial_payload: str = "",
    ) -> SequenceState:
        """
        Register a new sequence execution.

        Args:
            sequence_id: Unique ID for this sequence
            steps: Ordered list of listener names to call
            return_to: Listener to send SequenceComplete to
            thread_id: Thread UUID for routing
            from_id: Who initiated the sequence
            initial_payload: XML payload for first step

        Returns:
            The freshly created SequenceState.
        """
        new_state = SequenceState(
            sequence_id=sequence_id,
            steps=steps,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            # An empty initial payload is normalized to None.
            last_result=initial_payload or None,
        )
        with self._lock:
            self._sequences[sequence_id] = new_state
        return new_state

    def get(self, sequence_id: str) -> Optional[SequenceState]:
        """Look up sequence state by ID; None when unknown."""
        with self._lock:
            return self._sequences.get(sequence_id)

    def advance(self, sequence_id: str, step_result: str) -> Optional[SequenceState]:
        """
        Record a completed step's result and move to the next step.

        A sequence already marked as failed is returned unchanged.

        Args:
            sequence_id: Sequence to advance
            step_result: XML result from the completed step

        Returns:
            Updated SequenceState, or None if not found
        """
        with self._lock:
            state = self._sequences.get(sequence_id)
            if state is not None and not state.failed:
                state.results.append(step_result)
                state.last_result = step_result
                state.current_index += 1
            return state

    def mark_failed(
        self,
        sequence_id: str,
        step: str,
        error: str,
    ) -> Optional[SequenceState]:
        """
        Flag a sequence as failed.

        Args:
            sequence_id: Sequence that failed
            step: Which step failed
            error: Error message

        Returns:
            Updated SequenceState, or None if not found
        """
        with self._lock:
            state = self._sequences.get(sequence_id)
            if state is None:
                return None

            state.failed = True
            state.failed_step = step
            state.error = error
            return state

    def remove(self, sequence_id: str) -> bool:
        """
        Drop a sequence after completion.

        Returns:
            True if found and removed, False if not found
        """
        with self._lock:
            return self._sequences.pop(sequence_id, None) is not None

    def list_active(self) -> List[str]:
        """IDs of every sequence still being tracked."""
        with self._lock:
            return [*self._sequences]

    def clear(self) -> None:
        """Forget every sequence. Useful for testing."""
        with self._lock:
            self._sequences.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Singleton
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
# Process-wide singleton instance, guarded by _registry_lock.
_registry: Optional[SequenceRegistry] = None
_registry_lock = threading.Lock()


def get_sequence_registry() -> SequenceRegistry:
    """Get the global SequenceRegistry singleton (created lazily)."""
    global _registry
    # Double-checked locking: lock-free fast path once initialized,
    # lock taken only for first-time creation.
    if _registry is None:
        with _registry_lock:
            if _registry is None:
                _registry = SequenceRegistry()
    return _registry
|
||||||
|
|
||||||
|
|
||||||
|
def reset_sequence_registry() -> None:
    """Reset the global sequence registry (for testing).

    Clears the current instance's contents before dropping the reference
    so any stale holders see an empty registry rather than old state.
    """
    global _registry
    with _registry_lock:
        if _registry is not None:
            _registry.clear()
        _registry = None
|
||||||
|
|
@ -210,6 +210,10 @@ class StreamPump:
|
||||||
self.routing_table: Dict[str, List[Listener]] = {}
|
self.routing_table: Dict[str, List[Listener]] = {}
|
||||||
self.listeners: Dict[str, Listener] = {}
|
self.listeners: Dict[str, Listener] = {}
|
||||||
|
|
||||||
|
# Generic listeners (accept any payload type)
|
||||||
|
# Used for ephemeral orchestration handlers (sequences, buffers)
|
||||||
|
self._generic_listeners: Dict[str, Listener] = {}
|
||||||
|
|
||||||
# Per-agent semaphores for rate limiting
|
# Per-agent semaphores for rate limiting
|
||||||
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
|
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
|
||||||
|
|
||||||
|
|
@ -269,6 +273,82 @@ class StreamPump:
|
||||||
self.listeners[lc.name] = listener
|
self.listeners[lc.name] = listener
|
||||||
return listener
|
return listener
|
||||||
|
|
||||||
|
    def register_generic_listener(
        self,
        name: str,
        handler: Callable,
        description: str = "",
    ) -> Listener:
        """
        Register a generic listener that accepts any payload type.

        Used for ephemeral orchestration handlers (sequences, buffers)
        that need to receive responses from various step types.

        Generic listeners:
        - Are NOT added to the routing table (no root_tag)
        - Are looked up by name (to_id) as a fallback in routing
        - Receive payload_tree directly (no XSD validation/deserialization)

        Args:
            name: Unique listener name (e.g., "sequence_abc123")
            handler: Async handler function (receives payload_tree, metadata)
            description: Human-readable description

        Returns:
            Listener object
        """
        listener = Listener(
            name=name,
            payload_class=object,  # Placeholder - not used
            handler=handler,
            description=description,
            is_agent=False,
            root_tag="*",  # Wildcard marker
        )

        # NOTE(review): stored under a lowercased key, but the routing
        # fallback indexes _generic_listeners with the raw to_id — confirm
        # to_id is lowercased upstream, else mixed-case names never match.
        self._generic_listeners[name.lower()] = listener
        self.listeners[name] = listener

        pump_logger.debug(f"Registered generic listener: {name}")
        return listener
|
||||||
|
|
||||||
|
    def unregister_listener(self, name: str) -> bool:
        """
        Remove a listener by name.

        Used to clean up ephemeral listeners after orchestration completes.

        Args:
            name: Listener name to remove

        Returns:
            True if found and removed, False if not found
        """
        name_lower = name.lower()
        removed = False

        # Remove from generic listeners (keyed by lowercased name).
        if name_lower in self._generic_listeners:
            del self._generic_listeners[name_lower]
            removed = True
            pump_logger.debug(f"Unregistered generic listener: {name}")

        # Remove from main listeners dict (keyed by original-case name).
        if name in self.listeners:
            listener = self.listeners.pop(name)
            removed = True

            # Remove from routing table; "*" marks a generic (wildcard)
            # listener that was never routed by tag.
            if listener.root_tag and listener.root_tag != "*":
                listeners_for_tag = self.routing_table.get(listener.root_tag, [])
                if listener in listeners_for_tag:
                    listeners_for_tag.remove(listener)
                    # Drop the now-empty bucket so stale tags don't linger.
                    if not listeners_for_tag:
                        del self.routing_table[listener.root_tag]

        return removed
|
||||||
|
|
||||||
def register_all(self) -> None:
|
def register_all(self) -> None:
|
||||||
# First pass: register all listeners
|
# First pass: register all listeners
|
||||||
for lc in self.config.listeners:
|
for lc in self.config.listeners:
|
||||||
|
|
@ -781,6 +861,8 @@ class StreamPump:
|
||||||
Combined validation + deserialization.
|
Combined validation + deserialization.
|
||||||
|
|
||||||
Uses to_id + payload tag to find the right listener and schema.
|
Uses to_id + payload tag to find the right listener and schema.
|
||||||
|
Falls back to generic listeners (ephemeral orchestration handlers)
|
||||||
|
when no regular listener matches.
|
||||||
"""
|
"""
|
||||||
if state.error or state.payload_tree is None:
|
if state.error or state.payload_tree is None:
|
||||||
return state
|
return state
|
||||||
|
|
@ -794,6 +876,19 @@ class StreamPump:
|
||||||
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
|
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
|
||||||
|
|
||||||
listeners = self.routing_table.get(lookup_key, [])
|
listeners = self.routing_table.get(lookup_key, [])
|
||||||
|
|
||||||
|
# Fallback: check for generic listener by to_id
|
||||||
|
# Generic listeners accept any payload type (for orchestration)
|
||||||
|
if not listeners and to_id:
|
||||||
|
generic_listener = self._generic_listeners.get(to_id)
|
||||||
|
if generic_listener:
|
||||||
|
# Generic listener: skip XSD validation and deserialization
|
||||||
|
# Pass the raw payload_tree to the handler
|
||||||
|
state.payload = state.payload_tree # Handler receives Element
|
||||||
|
state.target_listeners = [generic_listener]
|
||||||
|
state.metadata["generic_handler"] = True
|
||||||
|
return state
|
||||||
|
|
||||||
if not listeners:
|
if not listeners:
|
||||||
state.error = f"No listener for: {lookup_key}"
|
state.error = f"No listener for: {lookup_key}"
|
||||||
return state
|
return state
|
||||||
|
|
@ -1008,6 +1103,36 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
||||||
)
|
)
|
||||||
pump.register_listener(todo_complete_config)
|
pump.register_listener(todo_complete_config)
|
||||||
|
|
||||||
|
# Register Sequence primitives (orchestration)
|
||||||
|
from xml_pipeline.primitives.sequence import (
|
||||||
|
SequenceStart, handle_sequence_start,
|
||||||
|
)
|
||||||
|
sequence_config = ListenerConfig(
|
||||||
|
name="system.sequence",
|
||||||
|
payload_class_path="xml_pipeline.primitives.sequence.SequenceStart",
|
||||||
|
handler_path="xml_pipeline.primitives.sequence.handle_sequence_start",
|
||||||
|
description="System sequence handler - chains listeners in order",
|
||||||
|
is_agent=False,
|
||||||
|
payload_class=SequenceStart,
|
||||||
|
handler=handle_sequence_start,
|
||||||
|
)
|
||||||
|
pump.register_listener(sequence_config)
|
||||||
|
|
||||||
|
# Register Buffer primitives (fan-out orchestration)
|
||||||
|
from xml_pipeline.primitives.buffer import (
|
||||||
|
BufferStart, handle_buffer_start,
|
||||||
|
)
|
||||||
|
buffer_config = ListenerConfig(
|
||||||
|
name="system.buffer",
|
||||||
|
payload_class_path="xml_pipeline.primitives.buffer.BufferStart",
|
||||||
|
handler_path="xml_pipeline.primitives.buffer.handle_buffer_start",
|
||||||
|
description="System buffer handler - fan-out to parallel workers",
|
||||||
|
is_agent=False,
|
||||||
|
payload_class=BufferStart,
|
||||||
|
handler=handle_buffer_start,
|
||||||
|
)
|
||||||
|
pump.register_listener(buffer_config)
|
||||||
|
|
||||||
# Register all user-defined listeners
|
# Register all user-defined listeners
|
||||||
pump.register_all()
|
pump.register_all()
|
||||||
|
|
||||||
|
|
@ -1061,6 +1186,9 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
||||||
# Inject boot message (will be processed when pump.run() is called)
|
# Inject boot message (will be processed when pump.run() is called)
|
||||||
await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system")
|
await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system")
|
||||||
|
|
||||||
|
# Set global pump instance for get_stream_pump()
|
||||||
|
set_stream_pump(pump)
|
||||||
|
|
||||||
print(f"Routing: {list(pump.routing_table.keys())}")
|
print(f"Routing: {list(pump.routing_table.keys())}")
|
||||||
return pump
|
return pump
|
||||||
|
|
||||||
|
|
@ -1110,3 +1238,45 @@ The key difference:
|
||||||
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
|
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
|
||||||
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
|
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Global Singleton
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
_pump: Optional[StreamPump] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_stream_pump() -> StreamPump:
|
||||||
|
"""
|
||||||
|
Get the global StreamPump instance.
|
||||||
|
|
||||||
|
The pump is initialized via bootstrap() and set here.
|
||||||
|
Raises RuntimeError if called before bootstrap.
|
||||||
|
"""
|
||||||
|
global _pump
|
||||||
|
if _pump is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"StreamPump not initialized. Call bootstrap() first."
|
||||||
|
)
|
||||||
|
return _pump
|
||||||
|
|
||||||
|
|
||||||
|
def set_stream_pump(pump: StreamPump) -> None:
|
||||||
|
"""
|
||||||
|
Set the global StreamPump instance.
|
||||||
|
|
||||||
|
Called by bootstrap() after creating the pump.
|
||||||
|
"""
|
||||||
|
global _pump
|
||||||
|
_pump = pump
|
||||||
|
|
||||||
|
|
||||||
|
def reset_stream_pump() -> None:
|
||||||
|
"""
|
||||||
|
Reset the global StreamPump instance.
|
||||||
|
|
||||||
|
Useful for testing.
|
||||||
|
"""
|
||||||
|
global _pump
|
||||||
|
_pump = None
|
||||||
|
|
|
||||||
|
|
@ -15,16 +15,45 @@ from xml_pipeline.primitives.todo import (
|
||||||
handle_todo_complete,
|
handle_todo_complete,
|
||||||
)
|
)
|
||||||
from xml_pipeline.primitives.text_input import TextInput, TextOutput
|
from xml_pipeline.primitives.text_input import TextInput, TextOutput
|
||||||
|
from xml_pipeline.primitives.sequence import (
|
||||||
|
SequenceStart,
|
||||||
|
SequenceComplete,
|
||||||
|
SequenceError,
|
||||||
|
handle_sequence_start,
|
||||||
|
)
|
||||||
|
from xml_pipeline.primitives.buffer import (
|
||||||
|
BufferStart,
|
||||||
|
BufferItem,
|
||||||
|
BufferComplete,
|
||||||
|
BufferDispatched,
|
||||||
|
BufferError,
|
||||||
|
handle_buffer_start,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Boot
|
||||||
"Boot",
|
"Boot",
|
||||||
"handle_boot",
|
"handle_boot",
|
||||||
|
# Todo
|
||||||
"TodoUntil",
|
"TodoUntil",
|
||||||
"TodoComplete",
|
"TodoComplete",
|
||||||
"TodoRegistered",
|
"TodoRegistered",
|
||||||
"TodoClosed",
|
"TodoClosed",
|
||||||
"handle_todo_until",
|
"handle_todo_until",
|
||||||
"handle_todo_complete",
|
"handle_todo_complete",
|
||||||
|
# Text I/O
|
||||||
"TextInput",
|
"TextInput",
|
||||||
"TextOutput",
|
"TextOutput",
|
||||||
|
# Sequence orchestration
|
||||||
|
"SequenceStart",
|
||||||
|
"SequenceComplete",
|
||||||
|
"SequenceError",
|
||||||
|
"handle_sequence_start",
|
||||||
|
# Buffer orchestration
|
||||||
|
"BufferStart",
|
||||||
|
"BufferItem",
|
||||||
|
"BufferComplete",
|
||||||
|
"BufferDispatched",
|
||||||
|
"BufferError",
|
||||||
|
"handle_buffer_start",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
373
xml_pipeline/primitives/buffer.py
Normal file
373
xml_pipeline/primitives/buffer.py
Normal file
|
|
@ -0,0 +1,373 @@
|
||||||
|
"""
|
||||||
|
buffer.py — Buffer (fan-out) orchestration primitives.
|
||||||
|
|
||||||
|
Buffers fan-out to parallel workers, sending N items to the same listener
|
||||||
|
concurrently. Results are collected and returned when all complete.
|
||||||
|
|
||||||
|
Usage by an agent:
|
||||||
|
# Fan-out search queries to web_search
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=BufferStart(
|
||||||
|
target="web_search",
|
||||||
|
items="python async\\nrust memory\\ngo concurrency",
|
||||||
|
return_to="my-agent",
|
||||||
|
collect=True,
|
||||||
|
),
|
||||||
|
to="system.buffer",
|
||||||
|
)
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. system.buffer receives BufferStart with N items
|
||||||
|
2. Creates ephemeral listener buffer_{id} to receive results
|
||||||
|
3. Creates N sibling threads via ThreadRegistry
|
||||||
|
4. Sends BufferItem to each worker FROM buffer_{id}
|
||||||
|
5. Workers process and respond → routes to buffer_{id}
|
||||||
|
6. Ephemeral handler collects results
|
||||||
|
7. When all workers done, sends BufferComplete to return_to
|
||||||
|
8. Cleans up ephemeral listener
|
||||||
|
|
||||||
|
Fire-and-forget mode (collect=False):
|
||||||
|
- Returns immediately after dispatching
|
||||||
|
- No result collection
|
||||||
|
- Useful for async side effects
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
import uuid as uuid_module
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
from third_party.xmlable import xmlify
|
||||||
|
from xml_pipeline.message_bus.message_state import (
|
||||||
|
HandlerMetadata,
|
||||||
|
HandlerResponse,
|
||||||
|
)
|
||||||
|
from xml_pipeline.message_bus.buffer_registry import get_buffer_registry
|
||||||
|
from xml_pipeline.message_bus.thread_registry import get_registry as get_thread_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Payloads
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BufferStart:
|
||||||
|
"""
|
||||||
|
Start a new buffer (fan-out) execution.
|
||||||
|
|
||||||
|
Sent to system.buffer to begin parallel processing.
|
||||||
|
"""
|
||||||
|
target: str = "" # Listener to fan-out to
|
||||||
|
items: str = "" # Newline-separated payloads (raw XML)
|
||||||
|
collect: bool = True # Wait for all results?
|
||||||
|
return_to: str = "" # Where to send BufferComplete
|
||||||
|
buffer_id: str = "" # Auto-generated if empty
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BufferItem:
|
||||||
|
"""
|
||||||
|
Individual item being processed by a worker.
|
||||||
|
|
||||||
|
Wraps the actual payload with buffer metadata.
|
||||||
|
Note: This is an internal type - workers receive the raw payload,
|
||||||
|
not BufferItem directly.
|
||||||
|
"""
|
||||||
|
buffer_id: str = ""
|
||||||
|
index: int = 0
|
||||||
|
payload: str = "" # The actual XML payload
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BufferComplete:
|
||||||
|
"""
|
||||||
|
Buffer completed - all workers finished.
|
||||||
|
|
||||||
|
Sent to return_to when all items are processed.
|
||||||
|
"""
|
||||||
|
buffer_id: str = ""
|
||||||
|
total: int = 0
|
||||||
|
successful: int = 0
|
||||||
|
results: str = "" # XML array of results
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BufferDispatched:
|
||||||
|
"""
|
||||||
|
Buffer dispatched (fire-and-forget mode).
|
||||||
|
|
||||||
|
Sent immediately after items are dispatched when collect=False.
|
||||||
|
"""
|
||||||
|
buffer_id: str = ""
|
||||||
|
total: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BufferError:
|
||||||
|
"""
|
||||||
|
Buffer failed to start.
|
||||||
|
|
||||||
|
Sent when buffer initialization fails.
|
||||||
|
"""
|
||||||
|
buffer_id: str = ""
|
||||||
|
error: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Handlers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
async def handle_buffer_start(
|
||||||
|
payload: BufferStart,
|
||||||
|
metadata: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""
|
||||||
|
Handle BufferStart — begin a fan-out execution.
|
||||||
|
|
||||||
|
Creates N sibling threads, dispatches items to workers,
|
||||||
|
and sets up result collection.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.message_bus.stream_pump import get_stream_pump
|
||||||
|
|
||||||
|
# Parse items
|
||||||
|
items = [item.strip() for item in payload.items.split("\n") if item.strip()]
|
||||||
|
if not items:
|
||||||
|
logger.error("BufferStart with no items")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=BufferError(
|
||||||
|
buffer_id=payload.buffer_id or "unknown",
|
||||||
|
error="No items specified",
|
||||||
|
),
|
||||||
|
to=payload.return_to or metadata.from_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate target exists
|
||||||
|
pump = get_stream_pump()
|
||||||
|
if payload.target not in pump.listeners:
|
||||||
|
logger.error(f"BufferStart: unknown target '{payload.target}'")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=BufferError(
|
||||||
|
buffer_id=payload.buffer_id or "unknown",
|
||||||
|
error=f"Unknown target listener: {payload.target}",
|
||||||
|
),
|
||||||
|
to=payload.return_to or metadata.from_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate buffer ID if not provided
|
||||||
|
buf_id = payload.buffer_id or str(uuid_module.uuid4())[:8]
|
||||||
|
|
||||||
|
# Create buffer state
|
||||||
|
buffer_registry = get_buffer_registry()
|
||||||
|
state = buffer_registry.create(
|
||||||
|
buffer_id=buf_id,
|
||||||
|
total_items=len(items),
|
||||||
|
return_to=payload.return_to or metadata.from_id,
|
||||||
|
thread_id=metadata.thread_id,
|
||||||
|
from_id=metadata.from_id,
|
||||||
|
target=payload.target,
|
||||||
|
collect=payload.collect,
|
||||||
|
)
|
||||||
|
|
||||||
|
# For fire-and-forget, we still track but don't wait
|
||||||
|
ephemeral_name = f"buffer_{buf_id}"
|
||||||
|
|
||||||
|
if payload.collect:
|
||||||
|
# Create ephemeral handler for result collection
|
||||||
|
async def buffer_handler(
|
||||||
|
payload_tree: etree._Element,
|
||||||
|
meta: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""Ephemeral handler that collects worker results."""
|
||||||
|
return await _handle_buffer_result(buf_id, payload_tree, meta)
|
||||||
|
|
||||||
|
# Register ephemeral listener (generic mode - accepts any payload)
|
||||||
|
pump.register_generic_listener(
|
||||||
|
name=ephemeral_name,
|
||||||
|
handler=buffer_handler,
|
||||||
|
description=f"Ephemeral buffer handler for {buf_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Buffer {buf_id} starting: {len(items)} items to {payload.target}, "
|
||||||
|
f"collect={payload.collect}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Dispatch all items in parallel
|
||||||
|
thread_registry = get_thread_registry()
|
||||||
|
parent_chain = thread_registry.lookup(metadata.thread_id) or metadata.thread_id
|
||||||
|
|
||||||
|
for i, item_payload in enumerate(items):
|
||||||
|
# Create sibling thread for this worker
|
||||||
|
worker_chain = f"{parent_chain}.{ephemeral_name}_w{i}"
|
||||||
|
worker_uuid = thread_registry.get_or_create(worker_chain)
|
||||||
|
|
||||||
|
# Inject the item to the target
|
||||||
|
# The item is sent FROM the ephemeral listener so .respond() comes back
|
||||||
|
await _inject_buffer_item(
|
||||||
|
pump=pump,
|
||||||
|
target=payload.target,
|
||||||
|
payload_xml=item_payload,
|
||||||
|
thread_id=worker_uuid,
|
||||||
|
from_id=ephemeral_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Buffer {buf_id}: dispatched item {i} to {payload.target}")
|
||||||
|
|
||||||
|
# Fire-and-forget: return immediately
|
||||||
|
if not payload.collect:
|
||||||
|
logger.info(f"Buffer {buf_id}: fire-and-forget mode, {len(items)} items dispatched")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=BufferDispatched(
|
||||||
|
buffer_id=buf_id,
|
||||||
|
total=len(items),
|
||||||
|
),
|
||||||
|
to=payload.return_to or metadata.from_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect mode: wait for results (handled by ephemeral listener)
|
||||||
|
# Return None - the ephemeral listener will send BufferComplete
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_buffer_result(
|
||||||
|
buf_id: str,
|
||||||
|
payload_tree: etree._Element,
|
||||||
|
metadata: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""
|
||||||
|
Handle a worker result in the buffer.
|
||||||
|
|
||||||
|
Called by the ephemeral listener when a worker responds.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.message_bus.stream_pump import get_stream_pump
|
||||||
|
|
||||||
|
buffer_registry = get_buffer_registry()
|
||||||
|
state = buffer_registry.get(buf_id)
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
logger.warning(f"Buffer {buf_id} not found in registry (result dropped)")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Extract worker index from thread chain
|
||||||
|
# Chain format: parent.buffer_xyz_wN where N is the index
|
||||||
|
thread_registry = get_thread_registry()
|
||||||
|
chain = thread_registry.lookup(metadata.thread_id) or ""
|
||||||
|
|
||||||
|
worker_index = _extract_worker_index(chain, buf_id)
|
||||||
|
if worker_index is None:
|
||||||
|
logger.warning(f"Buffer {buf_id}: could not determine worker index from chain")
|
||||||
|
worker_index = state.completed_count # Fallback to count-based
|
||||||
|
|
||||||
|
# Serialize the result
|
||||||
|
result_xml = etree.tostring(payload_tree, encoding="unicode")
|
||||||
|
|
||||||
|
# Check for errors
|
||||||
|
is_error = payload_tree.tag.lower() in ("huh", "systemerror")
|
||||||
|
|
||||||
|
# Record result
|
||||||
|
state = buffer_registry.record_result(
|
||||||
|
buffer_id=buf_id,
|
||||||
|
index=worker_index,
|
||||||
|
result=result_xml,
|
||||||
|
success=not is_error,
|
||||||
|
error=result_xml[:200] if is_error else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Buffer {buf_id}: received result {state.completed_count}/{state.total_items}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if all done
|
||||||
|
if state.is_complete:
|
||||||
|
# Clean up
|
||||||
|
pump = get_stream_pump()
|
||||||
|
pump.unregister_listener(f"buffer_{buf_id}")
|
||||||
|
|
||||||
|
# Format results as XML array
|
||||||
|
results_xml = _format_buffer_results(state)
|
||||||
|
buffer_registry.remove(buf_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Buffer {buf_id} completed: {state.successful_count}/{state.total_items} successful"
|
||||||
|
)
|
||||||
|
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=BufferComplete(
|
||||||
|
buffer_id=buf_id,
|
||||||
|
total=state.total_items,
|
||||||
|
successful=state.successful_count,
|
||||||
|
results=results_xml,
|
||||||
|
),
|
||||||
|
to=state.return_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
# More results pending
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_worker_index(chain: str, buf_id: str) -> Optional[int]:
|
||||||
|
"""Extract worker index from thread chain."""
|
||||||
|
# Look for pattern: buffer_{id}_wN
|
||||||
|
import re
|
||||||
|
pattern = rf"buffer_{buf_id}_w(\d+)"
|
||||||
|
match = re.search(pattern, chain)
|
||||||
|
if match:
|
||||||
|
return int(match.group(1))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _format_buffer_results(state) -> str:
|
||||||
|
"""Format buffer results as XML array."""
|
||||||
|
lines = ["<results>"]
|
||||||
|
for i in range(state.total_items):
|
||||||
|
result = state.results.get(i)
|
||||||
|
if result:
|
||||||
|
success = "true" if result.success else "false"
|
||||||
|
lines.append(f' <item index="{i}" success="{success}">')
|
||||||
|
# Indent the result content
|
||||||
|
for line in result.result.split("\n"):
|
||||||
|
lines.append(f" {line}")
|
||||||
|
lines.append(" </item>")
|
||||||
|
else:
|
||||||
|
lines.append(f' <item index="{i}" success="false">missing</item>')
|
||||||
|
lines.append("</results>")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
async def _inject_buffer_item(
|
||||||
|
pump,
|
||||||
|
target: str,
|
||||||
|
payload_xml: str,
|
||||||
|
thread_id: str,
|
||||||
|
from_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Inject a buffer item directly into the pump."""
|
||||||
|
# Wrap the payload in an envelope
|
||||||
|
envelope = pump._wrap_in_envelope(
|
||||||
|
payload=_RawXmlPayload(payload_xml),
|
||||||
|
from_id=from_id,
|
||||||
|
to_id=target,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inject into pump
|
||||||
|
await pump.inject(envelope, thread_id=thread_id, from_id=from_id)
|
||||||
|
|
||||||
|
|
||||||
|
class _RawXmlPayload:
|
||||||
|
"""Carrier for raw XML that bypasses serialization."""
|
||||||
|
|
||||||
|
def __init__(self, xml: str):
|
||||||
|
self.xml = xml
|
||||||
|
|
||||||
|
def to_xml(self) -> str:
|
||||||
|
"""Return raw XML for envelope wrapping."""
|
||||||
|
return self.xml
|
||||||
299
xml_pipeline/primitives/sequence.py
Normal file
299
xml_pipeline/primitives/sequence.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""
|
||||||
|
sequence.py — Sequence orchestration primitives.
|
||||||
|
|
||||||
|
Sequences chain multiple listeners in order, feeding the output of one step
|
||||||
|
as input to the next. Steps remain transparent - they don't know they're
|
||||||
|
part of a sequence.
|
||||||
|
|
||||||
|
Usage by an agent:
|
||||||
|
# Start a sequence: add two numbers, then multiply
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=SequenceStart(
|
||||||
|
steps="calculator.add,calculator.multiply",
|
||||||
|
payload='<AddPayload><a>5</a><b>3</b></AddPayload>',
|
||||||
|
return_to="my-agent",
|
||||||
|
),
|
||||||
|
to="system.sequence",
|
||||||
|
)
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. system.sequence receives SequenceStart
|
||||||
|
2. Creates ephemeral listener sequence_{id} to receive step results
|
||||||
|
3. Sends initial payload to first step FROM sequence_{id}
|
||||||
|
4. Step processes and responds → routes to sequence_{id}
|
||||||
|
5. Ephemeral handler advances, sends to next step
|
||||||
|
6. When all steps complete, sends SequenceComplete to return_to
|
||||||
|
7. Cleans up ephemeral listener
|
||||||
|
|
||||||
|
Key insight: Steps use normal .respond() - the ephemeral listener IS the
|
||||||
|
caller in the thread chain, so responses naturally route back to it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
import uuid as uuid_module
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from lxml import etree
|
||||||
|
from third_party.xmlable import xmlify
|
||||||
|
from xml_pipeline.message_bus.message_state import (
|
||||||
|
HandlerMetadata,
|
||||||
|
HandlerResponse,
|
||||||
|
MessageState,
|
||||||
|
)
|
||||||
|
from xml_pipeline.message_bus.sequence_registry import get_sequence_registry
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Payloads
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class SequenceStart:
|
||||||
|
"""
|
||||||
|
Start a new sequence execution.
|
||||||
|
|
||||||
|
Sent to system.sequence to begin chaining steps.
|
||||||
|
"""
|
||||||
|
steps: str = "" # Comma-separated listener names
|
||||||
|
payload: str = "" # Initial XML payload for first step
|
||||||
|
return_to: str = "" # Where to send final result
|
||||||
|
sequence_id: str = "" # Auto-generated if empty
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class SequenceComplete:
|
||||||
|
"""
|
||||||
|
Sequence completed successfully.
|
||||||
|
|
||||||
|
Sent to the return_to listener when all steps finish.
|
||||||
|
"""
|
||||||
|
sequence_id: str = ""
|
||||||
|
final_result: str = "" # XML result from last step
|
||||||
|
step_count: int = 0 # How many steps were executed
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class SequenceError:
|
||||||
|
"""
|
||||||
|
Sequence failed at a step.
|
||||||
|
|
||||||
|
Sent to return_to when a step fails.
|
||||||
|
"""
|
||||||
|
sequence_id: str = ""
|
||||||
|
failed_step: str = "" # Which step failed
|
||||||
|
step_index: int = 0 # 0-based index of failed step
|
||||||
|
error: str = "" # Error message
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Handlers
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
async def handle_sequence_start(
|
||||||
|
payload: SequenceStart,
|
||||||
|
metadata: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""
|
||||||
|
Handle SequenceStart — begin a sequence execution.
|
||||||
|
|
||||||
|
Creates an ephemeral listener for this sequence, stores state,
|
||||||
|
and kicks off the first step.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.message_bus.stream_pump import get_stream_pump
|
||||||
|
|
||||||
|
# Parse and validate
|
||||||
|
steps = [s.strip() for s in payload.steps.split(",") if s.strip()]
|
||||||
|
if not steps:
|
||||||
|
logger.error("SequenceStart with no steps")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=SequenceError(
|
||||||
|
sequence_id=payload.sequence_id or "unknown",
|
||||||
|
failed_step="",
|
||||||
|
step_index=0,
|
||||||
|
error="No steps specified",
|
||||||
|
),
|
||||||
|
to=payload.return_to or metadata.from_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate sequence ID if not provided
|
||||||
|
seq_id = payload.sequence_id or str(uuid_module.uuid4())[:8]
|
||||||
|
|
||||||
|
# Validate all steps exist
|
||||||
|
pump = get_stream_pump()
|
||||||
|
for step in steps:
|
||||||
|
if step not in pump.listeners:
|
||||||
|
logger.error(f"SequenceStart: unknown step '{step}'")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=SequenceError(
|
||||||
|
sequence_id=seq_id,
|
||||||
|
failed_step=step,
|
||||||
|
step_index=steps.index(step),
|
||||||
|
error=f"Unknown listener: {step}",
|
||||||
|
),
|
||||||
|
to=payload.return_to or metadata.from_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create sequence state
|
||||||
|
registry = get_sequence_registry()
|
||||||
|
state = registry.create(
|
||||||
|
sequence_id=seq_id,
|
||||||
|
steps=steps,
|
||||||
|
return_to=payload.return_to or metadata.from_id,
|
||||||
|
thread_id=metadata.thread_id,
|
||||||
|
from_id=metadata.from_id,
|
||||||
|
initial_payload=payload.payload,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create ephemeral handler for this sequence
|
||||||
|
ephemeral_name = f"sequence_{seq_id}"
|
||||||
|
|
||||||
|
async def sequence_handler(
|
||||||
|
payload_tree: etree._Element,
|
||||||
|
meta: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""Ephemeral handler that processes step results."""
|
||||||
|
return await _handle_sequence_step_result(seq_id, payload_tree, meta)
|
||||||
|
|
||||||
|
# Register ephemeral listener (generic mode - accepts any payload)
|
||||||
|
pump.register_generic_listener(
|
||||||
|
name=ephemeral_name,
|
||||||
|
handler=sequence_handler,
|
||||||
|
description=f"Ephemeral sequence handler for {seq_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Sequence {seq_id} started: {len(steps)} steps, "
|
||||||
|
f"return_to={state.return_to}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Kick off first step
|
||||||
|
first_step = steps[0]
|
||||||
|
return _create_step_message(
|
||||||
|
seq_id=seq_id,
|
||||||
|
target=first_step,
|
||||||
|
payload_xml=payload.payload,
|
||||||
|
from_name=ephemeral_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_sequence_step_result(
|
||||||
|
seq_id: str,
|
||||||
|
payload_tree: etree._Element,
|
||||||
|
metadata: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""
|
||||||
|
Handle a step result in the sequence.
|
||||||
|
|
||||||
|
Called by the ephemeral listener when a step responds.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.message_bus.stream_pump import get_stream_pump
|
||||||
|
|
||||||
|
registry = get_sequence_registry()
|
||||||
|
state = registry.get(seq_id)
|
||||||
|
|
||||||
|
if state is None:
|
||||||
|
logger.error(f"Sequence {seq_id} not found in registry")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Serialize the result for storage
|
||||||
|
result_xml = etree.tostring(payload_tree, encoding="unicode")
|
||||||
|
|
||||||
|
# Check for error responses
|
||||||
|
if payload_tree.tag.lower() in ("huh", "systemerror"):
|
||||||
|
# Step failed
|
||||||
|
error_text = payload_tree.text or etree.tostring(payload_tree, encoding="unicode")
|
||||||
|
registry.mark_failed(seq_id, state.current_step or "unknown", error_text)
|
||||||
|
|
||||||
|
# Clean up and send error
|
||||||
|
pump = get_stream_pump()
|
||||||
|
pump.unregister_listener(f"sequence_{seq_id}")
|
||||||
|
registry.remove(seq_id)
|
||||||
|
|
||||||
|
logger.warning(f"Sequence {seq_id} failed at step {state.current_index}")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=SequenceError(
|
||||||
|
sequence_id=seq_id,
|
||||||
|
failed_step=state.current_step or "unknown",
|
||||||
|
step_index=state.current_index,
|
||||||
|
error=error_text[:200], # Truncate long errors
|
||||||
|
),
|
||||||
|
to=state.return_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Advance to next step
|
||||||
|
state = registry.advance(seq_id, result_xml)
|
||||||
|
|
||||||
|
if state.is_complete:
|
||||||
|
# All steps done - send completion
|
||||||
|
pump = get_stream_pump()
|
||||||
|
pump.unregister_listener(f"sequence_{seq_id}")
|
||||||
|
registry.remove(seq_id)
|
||||||
|
|
||||||
|
logger.info(f"Sequence {seq_id} completed: {len(state.steps)} steps")
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=SequenceComplete(
|
||||||
|
sequence_id=seq_id,
|
||||||
|
final_result=result_xml,
|
||||||
|
step_count=len(state.steps),
|
||||||
|
),
|
||||||
|
to=state.return_to,
|
||||||
|
)
|
||||||
|
|
||||||
|
# More steps to go - send to next step
|
||||||
|
next_step = state.current_step
|
||||||
|
logger.debug(
|
||||||
|
f"Sequence {seq_id} advancing to step {state.current_index}: {next_step}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return _create_step_message(
|
||||||
|
seq_id=seq_id,
|
||||||
|
target=next_step,
|
||||||
|
payload_xml=result_xml,
|
||||||
|
from_name=f"sequence_{seq_id}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_step_message(
|
||||||
|
seq_id: str,
|
||||||
|
target: str,
|
||||||
|
payload_xml: str,
|
||||||
|
from_name: str,
|
||||||
|
) -> HandlerResponse:
|
||||||
|
"""
|
||||||
|
Create a HandlerResponse to send payload to a step.
|
||||||
|
|
||||||
|
We need to inject the message with the ephemeral listener as the sender,
|
||||||
|
so that .respond() routes back to us.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.primitives.sequence import _RawPayloadCarrier
|
||||||
|
|
||||||
|
# Return a special carrier that tells the pump to:
|
||||||
|
# 1. Use the raw XML bytes directly
|
||||||
|
# 2. Set from_id to from_name (the ephemeral listener)
|
||||||
|
return HandlerResponse(
|
||||||
|
payload=_RawPayloadCarrier(xml=payload_xml, from_override=from_name),
|
||||||
|
to=target,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _RawPayloadCarrier:
|
||||||
|
"""
|
||||||
|
Internal carrier for raw XML that bypasses normal serialization.
|
||||||
|
|
||||||
|
When the pump sees this, it uses the raw XML directly instead of
|
||||||
|
serializing a dataclass.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, xml: str, from_override: Optional[str] = None):
|
||||||
|
self.xml = xml
|
||||||
|
self.from_override = from_override
|
||||||
|
|
||||||
|
def to_xml(self) -> str:
|
||||||
|
"""Return raw XML for envelope wrapping."""
|
||||||
|
return self.xml
|
||||||
Loading…
Reference in a new issue