Add Sequence and Buffer orchestration primitives
Implement two virtual node patterns for message flow orchestration: - Sequence: Chains listeners in order (A→B→C), feeding each step's output as input to the next. Uses ephemeral listeners to intercept step results without modifying core pump behavior. - Buffer: Fan-out to parallel worker threads with optional result collection. Supports fire-and-forget mode (collect=False) for non-blocking dispatch. New files: - sequence_registry.py / buffer_registry.py: State tracking - sequence.py / buffer.py: Payloads and handlers - test_sequence.py / test_buffer.py: 52 new tests Pump additions: - register_generic_listener(): Accept any payload type - unregister_listener(): Cleanup ephemeral listeners - Global singleton accessors for pump instance Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a69eae79c5
commit
a623c534d5
10 changed files with 2465 additions and 2 deletions
635
tests/test_buffer.py
Normal file
635
tests/test_buffer.py
Normal file
|
|
@ -0,0 +1,635 @@
|
|||
"""
|
||||
test_buffer.py — Tests for the Buffer (fan-out) orchestration primitives.
|
||||
|
||||
Tests:
|
||||
1. BufferRegistry basic operations
|
||||
2. BufferStart handler
|
||||
3. Result collection
|
||||
4. Fire-and-forget mode
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
from xml_pipeline.message_bus.buffer_registry import (
|
||||
BufferRegistry,
|
||||
BufferState,
|
||||
BufferItemResult,
|
||||
get_buffer_registry,
|
||||
reset_buffer_registry,
|
||||
)
|
||||
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
|
||||
from xml_pipeline.primitives.buffer import (
|
||||
BufferStart,
|
||||
BufferComplete,
|
||||
BufferDispatched,
|
||||
BufferError,
|
||||
handle_buffer_start,
|
||||
_extract_worker_index,
|
||||
_format_buffer_results,
|
||||
)
|
||||
|
||||
|
||||
class TestBufferRegistry:
    """Exercise the core CRUD surface of BufferRegistry."""

    def test_create_buffer_state(self):
        """A freshly created buffer carries its arguments and zeroed counters."""
        reg = BufferRegistry()

        st = reg.create(
            buffer_id="buf001",
            total_items=5,
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            target="worker",
            collect=True,
        )

        # Constructor arguments are stored verbatim.
        assert st.buffer_id == "buf001"
        assert st.total_items == 5
        assert st.return_to == "caller"
        assert st.thread_id == "thread-123"
        assert st.from_id == "console"
        assert st.target == "worker"
        assert st.collect is True
        # Progress counters begin at zero with everything pending.
        assert st.completed_count == 0
        assert st.successful_count == 0
        assert st.is_complete is False
        assert st.pending_count == 5

    def test_get_returns_buffer(self):
        """get() retrieves by ID and yields None for unknown IDs."""
        reg = BufferRegistry()

        reg.create(
            buffer_id="buf002",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        found = reg.get("buf002")
        assert found is not None
        assert found.buffer_id == "buf002"

        # Unknown IDs come back as None rather than raising.
        assert reg.get("nonexistent") is None

    def test_record_result_stores_result(self):
        """record_result() stores each result and keeps counters in sync."""
        reg = BufferRegistry()

        reg.create(
            buffer_id="buf003",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        # First item succeeds.
        st = reg.record_result(
            buffer_id="buf003",
            index=0,
            result="<Result0/>",
            success=True,
        )

        assert st.completed_count == 1
        assert st.successful_count == 1
        assert st.is_complete is False
        assert 0 in st.results
        assert st.results[0].result == "<Result0/>"
        assert st.results[0].success is True

        # Second item fails: completed advances, successful does not.
        st = reg.record_result(
            buffer_id="buf003",
            index=1,
            result="<Error/>",
            success=False,
            error="Timeout",
        )

        assert st.completed_count == 2
        assert st.successful_count == 1
        assert st.results[1].success is False
        assert st.results[1].error == "Timeout"

        # Third item finishes the buffer.
        st = reg.record_result(
            buffer_id="buf003",
            index=2,
            result="<Result2/>",
            success=True,
        )

        assert st.completed_count == 3
        assert st.successful_count == 2
        assert st.is_complete is True
        assert st.pending_count == 0

    def test_record_result_ignores_duplicates(self):
        """A second result for the same index is a no-op."""
        reg = BufferRegistry()

        reg.create(
            buffer_id="buf004",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        st = reg.record_result("buf004", 0, "<R/>", True)
        assert st.completed_count == 1

        # Replaying index 0 must neither bump counters nor overwrite data.
        st = reg.record_result("buf004", 0, "<Duplicate/>", True)
        assert st.completed_count == 1
        assert st.results[0].result == "<R/>"

    def test_remove_deletes_buffer(self):
        """remove() drops the buffer and reports whether it existed."""
        reg = BufferRegistry()

        reg.create(
            buffer_id="buf005",
            total_items=1,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        assert reg.get("buf005") is not None
        assert reg.remove("buf005") is True
        assert reg.get("buf005") is None

        # Removing an unknown ID reports False rather than raising.
        assert reg.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() enumerates every live buffer ID."""
        reg = BufferRegistry()

        for name, count in (("buf-a", 1), ("buf-b", 2), ("buf-c", 3)):
            reg.create(name, count, "x", "t", "c", "w")

        assert set(reg.list_active()) == {"buf-a", "buf-b", "buf-c"}

    def test_clear(self):
        """clear() wipes every buffer at once."""
        reg = BufferRegistry()

        reg.create("buf-1", 1, "x", "t", "c", "w")
        reg.create("buf-2", 2, "x", "t", "c", "w")

        reg.clear()

        assert reg.list_active() == []
|
||||
|
||||
|
||||
class TestBufferStateProperties:
    """Check the derived (computed) properties on BufferState."""

    def test_is_complete_when_all_received(self):
        """All items received -> is_complete flips to True."""
        st = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            completed_count=3,
        )
        assert st.is_complete is True

    def test_pending_count(self):
        """pending_count is total_items minus completed_count."""
        st = BufferState(
            buffer_id="test",
            total_items=5,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            completed_count=2,
        )
        assert st.pending_count == 3

    def test_get_ordered_results(self):
        """Results come back index-ordered, with None for gaps."""
        st = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<R0/>", True),
                2: BufferItemResult(2, "<R2/>", True),
                # slot 1 intentionally never reported
            },
        )

        ordered = st.get_ordered_results()

        assert len(ordered) == 3
        assert ordered[0] is not None
        assert ordered[0].result == "<R0/>"
        assert ordered[1] is None  # the gap surfaces as None
        assert ordered[2] is not None
        assert ordered[2].result == "<R2/>"
|
||||
|
||||
|
||||
class TestBufferStartPayload:
    """Field handling on the BufferStart payload."""

    def test_buffer_start_fields(self):
        """Explicit constructor arguments are stored as-is."""
        p = BufferStart(
            target="worker",
            items="item1\nitem2\nitem3",
            collect=True,
            return_to="caller",
            buffer_id="custom-id",
        )

        assert p.target == "worker"
        assert p.items == "item1\nitem2\nitem3"
        assert p.collect is True
        assert p.return_to == "caller"
        assert p.buffer_id == "custom-id"

    def test_buffer_start_default_values(self):
        """Omitting every argument yields empty strings and collect=True."""
        p = BufferStart()

        assert p.target == ""
        assert p.items == ""
        assert p.collect is True  # collection is the default mode
        assert p.return_to == ""
        assert p.buffer_id == ""
|
||||
|
||||
|
||||
class TestBufferCompletePayload:
    """Field handling on the BufferComplete payload."""

    def test_buffer_complete_fields(self):
        """Constructor arguments round-trip through the payload attributes."""
        p = BufferComplete(
            buffer_id="buf123",
            total=5,
            successful=4,
            results="<results>...</results>",
        )

        assert p.buffer_id == "buf123"
        assert p.total == 5
        assert p.successful == 4
        assert p.results == "<results>...</results>"
|
||||
|
||||
|
||||
class TestBufferDispatchedPayload:
    """Field handling on BufferDispatched (fire-and-forget acknowledgement)."""

    def test_buffer_dispatched_fields(self):
        """Constructor arguments round-trip through the payload attributes."""
        p = BufferDispatched(buffer_id="buf456", total=10)

        assert p.buffer_id == "buf456"
        assert p.total == 10
|
||||
|
||||
|
||||
class TestBufferErrorPayload:
    """Field handling on the BufferError payload."""

    def test_buffer_error_fields(self):
        """Constructor arguments round-trip through the payload attributes."""
        p = BufferError(
            buffer_id="buf789",
            error="Unknown target listener",
        )

        assert p.buffer_id == "buf789"
        assert p.error == "Unknown target listener"
|
||||
|
||||
|
||||
class TestExtractWorkerIndex:
    """Behaviour of the _extract_worker_index helper."""

    def test_extracts_index_from_chain(self):
        """A matching buffer marker in the chain yields its worker index."""
        assert _extract_worker_index("root.parent.buffer_abc123_w5", "abc123") == 5

    def test_extracts_double_digit_index(self):
        """Multi-digit worker indices parse in full, not just one digit."""
        assert _extract_worker_index("x.buffer_xyz_w42", "xyz") == 42

    def test_returns_none_for_no_match(self):
        """Chains without any buffer marker produce None."""
        assert _extract_worker_index("something.else.entirely", "abc") is None

    def test_returns_none_for_wrong_buffer_id(self):
        """A marker belonging to a different buffer ID is not a match."""
        assert _extract_worker_index("root.buffer_other_w3", "abc") is None
|
||||
|
||||
|
||||
class TestFormatBufferResults:
    """Behaviour of the _format_buffer_results helper."""

    def test_formats_complete_results(self):
        """Every collected result appears inside the <results> XML."""
        st = BufferState(
            buffer_id="test",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<Result>A</Result>", True),
                1: BufferItemResult(1, "<Result>B</Result>", True),
            },
        )

        xml = _format_buffer_results(st)

        # Wrapper element, per-item attributes, and raw payloads all present.
        for fragment in (
            "<results>",
            "</results>",
            'index="0"',
            'index="1"',
            'success="true"',
            "<Result>A</Result>",
            "<Result>B</Result>",
        ):
            assert fragment in xml

    def test_formats_partial_failure(self):
        """Success and failure flags are rendered per item."""
        st = BufferState(
            buffer_id="test",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<Good/>", True),
                1: BufferItemResult(1, "<Error/>", False, "timeout"),
            },
        )

        xml = _format_buffer_results(st)

        assert 'success="true"' in xml
        assert 'success="false"' in xml

    def test_formats_missing_results(self):
        """Gaps in the result map are rendered as missing entries."""
        st = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<R/>", True),
                # index 1 never arrives
                2: BufferItemResult(2, "<R/>", True),
            },
        )

        xml = _format_buffer_results(st)

        # The absent slot still shows up, flagged as missing.
        assert 'index="1"' in xml
        assert "missing" in xml
|
||||
|
||||
|
||||
class TestHandleBufferStartValidation:
    """Input validation performed by handle_buffer_start."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Start every test from a clean buffer registry."""
        reset_buffer_registry()

    @pytest.mark.asyncio
    async def test_empty_items_returns_error(self):
        """An empty item list is rejected with a BufferError."""
        from unittest.mock import patch, MagicMock

        fake_pump = MagicMock()
        fake_pump.listeners = {}

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump):
            response = await handle_buffer_start(
                BufferStart(
                    target="worker",
                    items="",  # nothing to fan out
                    return_to="caller",
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, BufferError)
        assert "No items" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_target_returns_error(self):
        """A target with no registered listener is rejected."""
        from unittest.mock import patch, MagicMock

        fake_pump = MagicMock()
        fake_pump.listeners = {}  # deliberately empty: no listeners at all

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump):
            response = await handle_buffer_start(
                BufferStart(
                    target="unknown_worker",
                    items="item1\nitem2",
                    return_to="caller",
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, BufferError)
        assert "Unknown target" in response.payload.error
|
||||
|
||||
|
||||
class TestBufferRegistrySingleton:
    """Module-level singleton accessors for BufferRegistry."""

    def test_get_buffer_registry_returns_singleton(self):
        """Two consecutive lookups yield the identical object."""
        reset_buffer_registry()

        assert get_buffer_registry() is get_buffer_registry()

    def test_reset_creates_new_instance(self):
        """After a reset, previously created buffers are gone."""
        before = get_buffer_registry()
        before.create("test", 1, "x", "t", "c", "w")

        reset_buffer_registry()
        after = get_buffer_registry()

        # The fresh singleton must not see state from the old one.
        assert after.get("test") is None
|
||||
|
||||
|
||||
class TestBufferCollectVsFireAndForget:
    """Contrast collect=True with fire-and-forget (collect=False) dispatch."""

    @pytest.mark.asyncio
    async def test_collect_mode_returns_none(self):
        """With collect=True the handler defers: it returns None and
        registers an ephemeral listener to gather worker results."""
        from unittest.mock import patch, MagicMock, AsyncMock

        fake_pump = MagicMock()
        fake_pump.listeners = {"worker": MagicMock()}
        fake_pump.register_generic_listener = MagicMock()
        fake_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
        fake_pump.inject = AsyncMock()

        # Stand-in thread registry so dispatch can build worker chains.
        fake_threads = MagicMock()
        fake_threads.lookup = MagicMock(return_value="root.parent")
        fake_threads.get_or_create = MagicMock(return_value="worker-uuid")

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump), \
             patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=fake_threads):
            response = await handle_buffer_start(
                BufferStart(
                    target="worker",
                    items="item1\nitem2",
                    return_to="caller",
                    collect=True,
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

            # The ephemeral listener sends BufferComplete later, so the
            # handler itself replies with nothing.
            assert response is None
            fake_pump.register_generic_listener.assert_called_once()

    @pytest.mark.asyncio
    async def test_fire_and_forget_returns_dispatched(self):
        """With collect=False the handler acknowledges immediately with
        BufferDispatched and skips the ephemeral listener entirely."""
        from unittest.mock import patch, MagicMock, AsyncMock

        fake_pump = MagicMock()
        fake_pump.listeners = {"worker": MagicMock()}
        fake_pump.register_generic_listener = MagicMock()
        fake_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
        fake_pump.inject = AsyncMock()

        fake_threads = MagicMock()
        fake_threads.lookup = MagicMock(return_value="root.parent")
        fake_threads.get_or_create = MagicMock(return_value="worker-uuid")

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump), \
             patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=fake_threads):
            response = await handle_buffer_start(
                BufferStart(
                    target="worker",
                    items="item1\nitem2\nitem3",
                    return_to="caller",
                    collect=False,  # fire-and-forget
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

            # Immediate acknowledgement carrying the dispatched item count.
            assert isinstance(response, HandlerResponse)
            assert isinstance(response.payload, BufferDispatched)
            assert response.payload.total == 3

            # No results are collected, so no ephemeral listener is needed.
            fake_pump.register_generic_listener.assert_not_called()
|
||||
|
||||
|
||||
class TestBufferResultCollection:
    """Aggregate result bookkeeping across a whole buffer."""

    def test_result_collection_partial_success(self):
        """Mixed outcomes: completion counts all items, success only some."""
        reg = BufferRegistry()
        reg.create("partial", 5, "x", "t", "c", "w")

        # Interleave three successes with two failures.
        reg.record_result("partial", 0, "<R/>", True)
        reg.record_result("partial", 1, "<E/>", False, "error")
        reg.record_result("partial", 2, "<R/>", True)
        reg.record_result("partial", 3, "<E/>", False, "error")
        reg.record_result("partial", 4, "<R/>", True)

        st = reg.get("partial")

        assert st.is_complete is True
        assert st.completed_count == 5
        assert st.successful_count == 3

    def test_results_out_of_order(self):
        """Arrival order does not affect the final ordered view."""
        reg = BufferRegistry()
        reg.create("ooo", 3, "x", "t", "c", "w")

        # Deliver results in the order 2, 0, 1.
        for idx in (2, 0, 1):
            reg.record_result("ooo", idx, f"<R{idx}/>", True)

        st = reg.get("ooo")

        assert st.is_complete is True
        ordered = st.get_ordered_results()
        assert ordered[0].result == "<R0/>"
        assert ordered[1].result == "<R1/>"
        assert ordered[2].result == "<R2/>"
|
||||
|
|
@ -66,7 +66,7 @@ class TestPumpBootstrap:
|
|||
pump = await bootstrap('config/organism.yaml')
|
||||
|
||||
assert pump.config.name == "hello-world"
|
||||
assert len(pump.routing_table) == 6 # 3 user listeners + 3 system (boot, todo, todo-complete)
|
||||
assert len(pump.routing_table) == 8 # 3 user listeners + 5 system (boot, todo, todo-complete, sequence, buffer)
|
||||
assert "greeter.greeting" in pump.routing_table
|
||||
assert "shouter.greetingresponse" in pump.routing_table
|
||||
assert "response-handler.shoutedresponse" in pump.routing_table
|
||||
|
|
|
|||
464
tests/test_sequence.py
Normal file
464
tests/test_sequence.py
Normal file
|
|
@ -0,0 +1,464 @@
|
|||
"""
|
||||
test_sequence.py — Tests for the Sequence orchestration primitives.
|
||||
|
||||
Tests:
|
||||
1. SequenceRegistry basic operations
|
||||
2. SequenceStart handler
|
||||
3. Step result handling
|
||||
4. Error propagation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
|
||||
from xml_pipeline.message_bus.sequence_registry import (
|
||||
SequenceRegistry,
|
||||
SequenceState,
|
||||
get_sequence_registry,
|
||||
reset_sequence_registry,
|
||||
)
|
||||
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
|
||||
from xml_pipeline.primitives.sequence import (
|
||||
SequenceStart,
|
||||
SequenceComplete,
|
||||
SequenceError,
|
||||
handle_sequence_start,
|
||||
)
|
||||
|
||||
|
||||
class TestSequenceRegistry:
    """Exercise the core CRUD surface of SequenceRegistry."""

    def test_create_sequence_state(self):
        """A new sequence starts at step 0 seeded with the initial payload."""
        reg = SequenceRegistry()

        st = reg.create(
            sequence_id="seq001",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            initial_payload="<TestPayload/>",
        )

        # Constructor arguments are stored verbatim.
        assert st.sequence_id == "seq001"
        assert st.steps == ["step1", "step2", "step3"]
        assert st.return_to == "caller"
        assert st.thread_id == "thread-123"
        assert st.from_id == "console"
        # Cursor sits on the first step with the seed payload queued.
        assert st.current_index == 0
        assert st.is_complete is False
        assert st.current_step == "step1"
        assert st.remaining_steps == ["step1", "step2", "step3"]
        assert st.last_result == "<TestPayload/>"

    def test_get_returns_sequence(self):
        """get() retrieves by ID and yields None for unknown IDs."""
        reg = SequenceRegistry()

        reg.create(
            sequence_id="seq002",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        found = reg.get("seq002")
        assert found is not None
        assert found.sequence_id == "seq002"

        # Unknown IDs come back as None rather than raising.
        assert reg.get("nonexistent") is None

    def test_advance_increments_index(self):
        """advance() moves the cursor and accumulates step results."""
        reg = SequenceRegistry()

        reg.create(
            sequence_id="seq003",
            steps=["a", "b", "c"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        # Step one.
        st = reg.advance("seq003", "<ResultA/>")
        assert st.current_index == 1
        assert st.results == ["<ResultA/>"]
        assert st.last_result == "<ResultA/>"
        assert st.current_step == "b"
        assert st.is_complete is False

        # Step two.
        st = reg.advance("seq003", "<ResultB/>")
        assert st.current_index == 2
        assert st.results == ["<ResultA/>", "<ResultB/>"]
        assert st.current_step == "c"

        # Step three exhausts the step list.
        st = reg.advance("seq003", "<ResultC/>")
        assert st.current_index == 3
        assert st.is_complete is True
        assert st.current_step is None
        assert st.remaining_steps == []

    def test_mark_failed(self):
        """mark_failed() records the failing step without completing."""
        reg = SequenceRegistry()

        reg.create(
            sequence_id="seq004",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        st = reg.mark_failed("seq004", "a", "XSD validation failed")

        assert st.failed is True
        assert st.failed_step == "a"
        assert st.error == "XSD validation failed"
        assert st.is_complete is False  # failed, never complete
        assert st.current_step is None  # failure halts stepping
        assert st.remaining_steps == []  # nothing left to run

    def test_remove_deletes_sequence(self):
        """remove() drops the sequence and reports whether it existed."""
        reg = SequenceRegistry()

        reg.create(
            sequence_id="seq005",
            steps=["a"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        assert reg.get("seq005") is not None
        assert reg.remove("seq005") is True
        assert reg.get("seq005") is None

        # Removing an unknown ID reports False rather than raising.
        assert reg.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() enumerates every live sequence ID."""
        reg = SequenceRegistry()

        for name, steps in (("seq-a", ["1"]), ("seq-b", ["2"]), ("seq-c", ["3"])):
            reg.create(name, steps, "x", "t", "c")

        assert set(reg.list_active()) == {"seq-a", "seq-b", "seq-c"}

    def test_clear(self):
        """clear() wipes every sequence at once."""
        reg = SequenceRegistry()

        reg.create("seq-1", ["a"], "x", "t", "c")
        reg.create("seq-2", ["b"], "x", "t", "c")

        reg.clear()

        assert reg.list_active() == []
|
||||
|
||||
|
||||
class TestSequenceStateProperties:
    """Check the derived (computed) properties on SequenceState."""

    def test_is_complete_after_all_steps(self):
        """Cursor past the final step -> complete."""
        st = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,  # one beyond the last step
        )
        assert st.is_complete is True

    def test_not_complete_when_failed(self):
        """A failed sequence never reports complete, even past the end."""
        st = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,
            failed=True,
        )
        assert st.is_complete is False

    def test_current_step_none_when_complete(self):
        """No current step once the sequence has finished."""
        st = SequenceState(
            sequence_id="test",
            steps=["a"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=1,
        )
        assert st.current_step is None

    def test_remaining_steps_empty_when_complete(self):
        """No remaining steps once the sequence has finished."""
        st = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,
        )
        assert st.remaining_steps == []
|
||||
|
||||
|
||||
class TestSequenceStartPayload:
    """Field handling on the SequenceStart payload."""

    def test_sequence_start_fields(self):
        """Explicit constructor arguments are stored as-is."""
        p = SequenceStart(
            steps="step1,step2",
            payload="<Test/>",
            return_to="caller",
            sequence_id="custom-id",
        )

        assert p.steps == "step1,step2"
        assert p.payload == "<Test/>"
        assert p.return_to == "caller"
        assert p.sequence_id == "custom-id"

    def test_sequence_start_default_values(self):
        """Omitting every argument yields empty strings throughout."""
        p = SequenceStart()

        assert p.steps == ""
        assert p.payload == ""
        assert p.return_to == ""
        assert p.sequence_id == ""
|
||||
|
||||
|
||||
class TestSequenceCompletePayload:
    """Field handling on the SequenceComplete payload."""

    def test_sequence_complete_fields(self):
        """Constructor arguments round-trip through the payload attributes."""
        p = SequenceComplete(
            sequence_id="seq123",
            final_result="<Result>42</Result>",
            step_count=3,
        )

        assert p.sequence_id == "seq123"
        assert p.final_result == "<Result>42</Result>"
        assert p.step_count == 3
|
||||
|
||||
|
||||
class TestSequenceErrorPayload:
    """Field handling on the SequenceError payload."""

    def test_sequence_error_fields(self):
        """Constructor arguments round-trip through the payload attributes."""
        p = SequenceError(
            sequence_id="seq456",
            failed_step="bad-step",
            step_index=1,
            error="Validation failed",
        )

        assert p.sequence_id == "seq456"
        assert p.failed_step == "bad-step"
        assert p.step_index == 1
        assert p.error == "Validation failed"
|
||||
|
||||
|
||||
class TestHandleSequenceStartValidation:
    """Input validation performed by handle_sequence_start."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Start every test from a clean sequence registry."""
        reset_sequence_registry()

    @pytest.mark.asyncio
    async def test_empty_steps_returns_error(self):
        """An empty step list is rejected with a SequenceError."""
        from unittest.mock import patch, MagicMock

        # Stub pump so the handler has no external dependency.
        fake_pump = MagicMock()
        fake_pump.listeners = {}

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump):
            response = await handle_sequence_start(
                SequenceStart(
                    steps="",  # no steps at all
                    payload="<Test/>",
                    return_to="caller",
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, SequenceError)
        assert "No steps" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_step_returns_error(self):
        """A step naming an unregistered listener is rejected."""
        from unittest.mock import patch, MagicMock

        fake_pump = MagicMock()
        fake_pump.listeners = {}  # deliberately empty: nothing registered

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=fake_pump):
            response = await handle_sequence_start(
                SequenceStart(
                    steps="unknown_step",
                    payload="<Test/>",
                    return_to="caller",
                ),
                HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                ),
            )

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, SequenceError)
        assert "Unknown listener" in response.payload.error
        assert "unknown_step" in response.payload.failed_step
|
||||
|
||||
|
||||
class TestSequenceRegistrySingleton:
    """Tests for the module-level SequenceRegistry singleton."""

    def test_get_sequence_registry_returns_singleton(self):
        """Repeated lookups should yield the very same registry object."""
        reset_sequence_registry()

        first, second = get_sequence_registry(), get_sequence_registry()

        assert first is second

    def test_reset_creates_new_instance(self):
        """After a reset, previously stored sequences must be gone."""
        get_sequence_registry().create("test", ["a"], "x", "t", "c")

        reset_sequence_registry()

        assert get_sequence_registry().get("test") is None
|
||||
|
||||
|
||||
class TestSequenceMultipleSteps:
    """Tests for sequences that span more than one step."""

    def test_three_step_sequence(self):
        """A three-step sequence should advance through all steps."""
        registry = SequenceRegistry()

        registry.create(
            sequence_id="multi",
            steps=["add", "multiply", "format"],
            return_to="caller",
            thread_id="t",
            from_id="c",
            initial_payload="<Input>5</Input>",
        )

        # (expected current step, result it produces), in execution order.
        script = [
            ("add", "<Sum>8</Sum>"),
            ("multiply", "<Product>40</Product>"),
            ("format", "<Formatted>Result: 40</Formatted>"),
        ]

        state = registry.get("multi")
        for expected_step, result_xml in script:
            assert state.current_step == expected_step
            state = registry.advance("multi", result_xml)
            assert state.last_result == result_xml

        # After the final advance the sequence is done and every
        # intermediate result was retained, in order.
        assert state.is_complete is True
        assert state.results == [xml for _, xml in script]

    def test_failure_at_middle_step(self):
        """Failure at middle step should stop sequence."""
        registry = SequenceRegistry()

        registry.create(
            sequence_id="fail-mid",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="t",
            from_id="c",
        )

        # Step 1 succeeds, then step 2 reports a failure.
        registry.advance("fail-mid", "<R1/>")
        state = registry.mark_failed("fail-mid", "step2", "Connection timeout")

        assert state.failed is True
        assert state.failed_step == "step2"
        assert state.current_index == 1  # still pointing at step2 (index 1)
        assert len(state.results) == 1  # only step1 produced a result
|
||||
|
||||
|
||||
class TestSequenceWithRealSteps:
    """Integration-style tests with mock handlers."""

    @pytest.mark.asyncio
    async def test_sequence_creates_ephemeral_listener(self):
        """Starting a sequence should create an ephemeral listener."""
        # NOTE(review): AsyncMock is imported but unused here — presumably a
        # leftover; confirm before removing.
        from unittest.mock import patch, MagicMock, AsyncMock

        # Pump knows both steps, so validation passes and the handler
        # proceeds to set up its result-collection listener.
        mock_pump = MagicMock()
        mock_pump.listeners = {"step1": MagicMock(), "step2": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = SequenceStart(
                steps="step1,step2",
                payload="<Input/>",
                return_to="caller",
            )

            metadata = HandlerMetadata(
                thread_id=str(uuid.uuid4()),
                from_id="console",
            )

            # Return value is intentionally unchecked; the assertion is on
            # the side effect (listener registration).
            response = await handle_sequence_start(payload, metadata)

            # Should have registered an ephemeral listener
            mock_pump.register_generic_listener.assert_called_once()
            call_args = mock_pump.register_generic_listener.call_args
            # Accept either keyword or positional `name` argument.
            name_arg = call_args.kwargs.get('name') or call_args.args[0]
            assert name_arg.startswith("sequence_")
|
||||
|
|
@ -30,6 +30,9 @@ from xml_pipeline.message_bus.stream_pump import (
|
|||
ListenerConfig,
|
||||
OrganismConfig,
|
||||
bootstrap,
|
||||
get_stream_pump,
|
||||
set_stream_pump,
|
||||
reset_stream_pump,
|
||||
)
|
||||
|
||||
from xml_pipeline.message_bus.message_state import (
|
||||
|
|
@ -42,15 +45,47 @@ from xml_pipeline.message_bus.system_pipeline import (
|
|||
ExternalMessage,
|
||||
)
|
||||
|
||||
from xml_pipeline.message_bus.sequence_registry import (
|
||||
SequenceState,
|
||||
SequenceRegistry,
|
||||
get_sequence_registry,
|
||||
reset_sequence_registry,
|
||||
)
|
||||
|
||||
from xml_pipeline.message_bus.buffer_registry import (
|
||||
BufferState,
|
||||
BufferItemResult,
|
||||
BufferRegistry,
|
||||
get_buffer_registry,
|
||||
reset_buffer_registry,
|
||||
)
|
||||
|
||||
# Public API of the message_bus package.
__all__ = [
    # Pump
    "StreamPump",
    "ConfigLoader",
    "Listener",
    "ListenerConfig",
    "OrganismConfig",
    "bootstrap",
    "get_stream_pump",
    "set_stream_pump",
    "reset_stream_pump",
    # Message state
    "MessageState",
    "HandlerMetadata",
    # NOTE(review): "bootstrap" was listed a second time here (duplicate of
    # the Pump entry above) — removed. Possibly "HandlerResponse" was
    # intended; confirm against message_state's public API.
    # System pipeline
    "SystemPipeline",
    "ExternalMessage",
    # Sequence registry
    "SequenceState",
    "SequenceRegistry",
    "get_sequence_registry",
    "reset_sequence_registry",
    # Buffer registry
    "BufferState",
    "BufferItemResult",
    "BufferRegistry",
    "get_buffer_registry",
    "reset_buffer_registry",
]
|
||||
|
|
|
|||
230
xml_pipeline/message_bus/buffer_registry.py
Normal file
230
xml_pipeline/message_bus/buffer_registry.py
Normal file
|
|
@ -0,0 +1,230 @@
|
|||
"""
|
||||
buffer_registry.py — State storage for Buffer (fan-out) orchestration.
|
||||
|
||||
Tracks active buffer executions that fan-out to parallel workers.
|
||||
When a buffer starts, N items are dispatched in parallel. Results are
|
||||
collected here. When all results are in (or timeout), BufferComplete is sent.
|
||||
|
||||
Design:
|
||||
- Thread-safe (same pattern as TodoRegistry, SequenceRegistry)
|
||||
- Keyed by buffer_id (short UUID)
|
||||
- Tracks: total items, received results, success/failure per item
|
||||
- Supports fire-and-forget mode (collect=False)
|
||||
|
||||
Usage:
|
||||
registry = get_buffer_registry()
|
||||
|
||||
# Start a buffer
|
||||
registry.create(
|
||||
buffer_id="abc123",
|
||||
total_items=5,
|
||||
return_to="greeter",
|
||||
thread_id="...",
|
||||
collect=True,
|
||||
)
|
||||
|
||||
# Record result from worker
|
||||
state = registry.record_result(
|
||||
buffer_id="abc123",
|
||||
index=2,
|
||||
result="<SearchResult>...</SearchResult>",
|
||||
success=True,
|
||||
)
|
||||
|
||||
if state.is_complete:
|
||||
# All workers done
|
||||
final_results = state.results
|
||||
registry.remove(buffer_id)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
class BufferItemResult:
    """Outcome reported by a single worker for one buffer item."""

    index: int
    result: str  # XML result produced by the worker
    success: bool = True
    error: Optional[str] = None


@dataclass
class BufferState:
    """Mutable bookkeeping for one in-flight buffer (fan-out) execution."""

    buffer_id: str
    total_items: int  # How many items were dispatched
    return_to: str  # Where to send BufferComplete
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the buffer
    target: str  # Target listener for items
    collect: bool = True  # Whether to wait for results

    results: Dict[int, BufferItemResult] = field(default_factory=dict)
    completed_count: int = 0
    successful_count: int = 0

    @property
    def is_complete(self) -> bool:
        """All dispatched items have reported back."""
        return not (self.completed_count < self.total_items)

    @property
    def pending_count(self) -> int:
        """Items still awaiting a report."""
        outstanding = self.total_items - self.completed_count
        return outstanding

    def get_ordered_results(self) -> List[Optional[BufferItemResult]]:
        """Results ordered by item index; None marks a missing index."""
        return list(map(self.results.get, range(self.total_items)))
|
||||
|
||||
|
||||
class BufferRegistry:
    """
    Thread-safe store of in-flight buffer (fan-out) executions.

    Keyed by buffer_id; obtain the process-wide instance via
    get_buffer_registry().
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._buffers: Dict[str, BufferState] = {}

    def create(
        self,
        buffer_id: str,
        total_items: int,
        return_to: str,
        thread_id: str,
        from_id: str,
        target: str,
        collect: bool = True,
    ) -> BufferState:
        """
        Register a new buffer execution.

        Args:
            buffer_id: Unique ID for this buffer
            total_items: Number of items being dispatched
            return_to: Listener to send BufferComplete to
            thread_id: Thread UUID for routing
            from_id: Who initiated the buffer
            target: Target listener for each item
            collect: Whether to wait for and collect results

        Returns:
            The freshly created BufferState.
        """
        entry = BufferState(
            buffer_id=buffer_id,
            total_items=total_items,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            target=target,
            collect=collect,
        )
        with self._lock:
            self._buffers[buffer_id] = entry
        return entry

    def get(self, buffer_id: str) -> Optional[BufferState]:
        """Look up a buffer by ID (None if unknown)."""
        with self._lock:
            return self._buffers.get(buffer_id)

    def record_result(
        self,
        buffer_id: str,
        index: int,
        result: str,
        success: bool = True,
        error: Optional[str] = None,
    ) -> Optional[BufferState]:
        """
        Record one worker's result.

        A repeated report for the same index is ignored, so a worker
        can never be double-counted.

        Args:
            buffer_id: Buffer this result belongs to
            index: Which item index (0-based)
            result: XML result from the worker
            success: Whether the worker succeeded
            error: Error message if failed

        Returns:
            Updated BufferState, or None if the buffer is unknown.
        """
        with self._lock:
            entry = self._buffers.get(buffer_id)
            if entry is None:
                return None
            if index not in entry.results:
                entry.results[index] = BufferItemResult(
                    index=index,
                    result=result,
                    success=success,
                    error=error,
                )
                entry.completed_count += 1
                if success:
                    entry.successful_count += 1
            return entry

    def remove(self, buffer_id: str) -> bool:
        """Drop a finished buffer; True if it existed."""
        with self._lock:
            return self._buffers.pop(buffer_id, None) is not None

    def list_active(self) -> List[str]:
        """IDs of all buffers currently tracked."""
        with self._lock:
            return [*self._buffers]

    def clear(self) -> None:
        """Drop all tracked buffers (test helper)."""
        with self._lock:
            self._buffers.clear()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton
|
||||
# ============================================================================
|
||||
|
||||
_registry: Optional[BufferRegistry] = None
_registry_lock = threading.Lock()


def get_buffer_registry() -> BufferRegistry:
    """Return the process-wide BufferRegistry, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    # Double-checked locking: only first-time callers pay for the lock.
    with _registry_lock:
        if _registry is None:
            _registry = BufferRegistry()
        return _registry


def reset_buffer_registry() -> None:
    """Discard the singleton (test helper); its state is cleared first."""
    global _registry
    with _registry_lock:
        stale = _registry
        _registry = None
        if stale is not None:
            stale.clear()
|
||||
228
xml_pipeline/message_bus/sequence_registry.py
Normal file
228
xml_pipeline/message_bus/sequence_registry.py
Normal file
|
|
@ -0,0 +1,228 @@
|
|||
"""
|
||||
sequence_registry.py — State storage for Sequence orchestration.
|
||||
|
||||
Tracks active sequence executions across handler invocations.
|
||||
When a sequence starts, its state is registered here. As steps complete,
|
||||
the state is updated. When all steps are done, the state is cleaned up.
|
||||
|
||||
Design:
|
||||
- Thread-safe (same pattern as TodoRegistry)
|
||||
- Keyed by sequence_id (short UUID)
|
||||
- Tracks: steps list, current index, collected results
|
||||
- Auto-cleanup when sequence completes or errors
|
||||
|
||||
Usage:
|
||||
registry = get_sequence_registry()
|
||||
|
||||
# Start a sequence
|
||||
registry.create(
|
||||
sequence_id="abc123",
|
||||
steps=["calculator.add", "calculator.multiply"],
|
||||
return_to="greeter",
|
||||
thread_id="...",
|
||||
initial_payload="<AddPayload>...</AddPayload>",
|
||||
)
|
||||
|
||||
# Advance on step completion
|
||||
state = registry.advance(sequence_id, step_result="<AddResult>42</AddResult>")
|
||||
if state.is_complete:
|
||||
# All steps done
|
||||
registry.remove(sequence_id)
|
||||
|
||||
# On error
|
||||
registry.mark_failed(sequence_id, step="calculator.add", error="XSD validation failed")
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Any
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
class SequenceState:
    """Mutable bookkeeping for one in-flight sequence execution."""

    sequence_id: str
    steps: List[str]  # Ordered list of listener names
    return_to: str  # Where to send final result
    thread_id: str  # Original thread for returning
    from_id: str  # Who started the sequence

    current_index: int = 0  # Which step we're on (0-based)
    results: List[str] = field(default_factory=list)  # XML results from each step
    last_result: Optional[str] = None  # Most recent step result (for chaining)

    failed: bool = False
    failed_step: Optional[str] = None
    error: Optional[str] = None

    @property
    def is_complete(self) -> bool:
        """All steps ran to completion without a failure."""
        if self.failed:
            return False
        return self.current_index >= len(self.steps)

    @property
    def current_step(self) -> Optional[str]:
        """Name of the step in flight (None once finished or failed)."""
        in_range = not self.failed and self.current_index < len(self.steps)
        return self.steps[self.current_index] if in_range else None

    @property
    def remaining_steps(self) -> List[str]:
        """Steps that have not run yet (empty after a failure)."""
        return [] if self.failed else self.steps[self.current_index:]
|
||||
|
||||
|
||||
class SequenceRegistry:
    """
    Thread-safe store of in-flight sequence executions.

    Keyed by sequence_id; obtain the process-wide instance via
    get_sequence_registry().
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sequences: Dict[str, SequenceState] = {}

    def create(
        self,
        sequence_id: str,
        steps: List[str],
        return_to: str,
        thread_id: str,
        from_id: str,
        initial_payload: str = "",
    ) -> SequenceState:
        """
        Register a new sequence execution.

        Args:
            sequence_id: Unique ID for this sequence
            steps: Ordered list of listener names to call
            return_to: Listener to send SequenceComplete to
            thread_id: Thread UUID for routing
            from_id: Who initiated the sequence
            initial_payload: XML payload for first step

        Returns:
            The freshly created SequenceState.
        """
        entry = SequenceState(
            sequence_id=sequence_id,
            steps=steps,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            # Empty string means "no initial payload" — normalize to None.
            last_result=initial_payload or None,
        )
        with self._lock:
            self._sequences[sequence_id] = entry
        return entry

    def get(self, sequence_id: str) -> Optional[SequenceState]:
        """Look up a sequence by ID (None if unknown)."""
        with self._lock:
            return self._sequences.get(sequence_id)

    def advance(self, sequence_id: str, step_result: str) -> Optional[SequenceState]:
        """
        Record a completed step's result and move to the next step.

        A failed sequence is returned unchanged; an unknown ID yields None.

        Args:
            sequence_id: Sequence to advance
            step_result: XML result from the completed step

        Returns:
            Updated SequenceState, or None if not found.
        """
        with self._lock:
            entry = self._sequences.get(sequence_id)
            if entry is not None and not entry.failed:
                entry.results.append(step_result)
                entry.last_result = step_result
                entry.current_index += 1
            return entry

    def mark_failed(
        self,
        sequence_id: str,
        step: str,
        error: str,
    ) -> Optional[SequenceState]:
        """
        Mark a sequence as failed at *step* with *error*.

        Args:
            sequence_id: Sequence that failed
            step: Which step failed
            error: Error message

        Returns:
            Updated SequenceState, or None if not found.
        """
        with self._lock:
            entry = self._sequences.get(sequence_id)
            if entry is None:
                return None
            entry.failed = True
            entry.failed_step = step
            entry.error = error
            return entry

    def remove(self, sequence_id: str) -> bool:
        """Drop a finished sequence; True if it existed."""
        with self._lock:
            return self._sequences.pop(sequence_id, None) is not None

    def list_active(self) -> List[str]:
        """IDs of all sequences currently tracked."""
        with self._lock:
            return [*self._sequences]

    def clear(self) -> None:
        """Drop all tracked sequences (test helper)."""
        with self._lock:
            self._sequences.clear()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton
|
||||
# ============================================================================
|
||||
|
||||
_registry: Optional[SequenceRegistry] = None
_registry_lock = threading.Lock()


def get_sequence_registry() -> SequenceRegistry:
    """Return the process-wide SequenceRegistry, creating it on first use."""
    global _registry
    if _registry is not None:
        return _registry
    # Double-checked locking: only first-time callers pay for the lock.
    with _registry_lock:
        if _registry is None:
            _registry = SequenceRegistry()
        return _registry


def reset_sequence_registry() -> None:
    """Discard the singleton (test helper); its state is cleared first."""
    global _registry
    with _registry_lock:
        stale = _registry
        _registry = None
        if stale is not None:
            stale.clear()
|
||||
|
|
@ -210,6 +210,10 @@ class StreamPump:
|
|||
self.routing_table: Dict[str, List[Listener]] = {}
|
||||
self.listeners: Dict[str, Listener] = {}
|
||||
|
||||
# Generic listeners (accept any payload type)
|
||||
# Used for ephemeral orchestration handlers (sequences, buffers)
|
||||
self._generic_listeners: Dict[str, Listener] = {}
|
||||
|
||||
# Per-agent semaphores for rate limiting
|
||||
self.agent_semaphores: Dict[str, asyncio.Semaphore] = {}
|
||||
|
||||
|
|
@ -269,6 +273,82 @@ class StreamPump:
|
|||
self.listeners[lc.name] = listener
|
||||
return listener
|
||||
|
||||
def register_generic_listener(
    self,
    name: str,
    handler: Callable,
    description: str = "",
) -> Listener:
    """
    Register a generic listener that accepts any payload type.

    Used for ephemeral orchestration handlers (sequences, buffers)
    that need to receive responses from various step types.

    Generic listeners:
    - Are NOT added to the routing table (no root_tag)
    - Are looked up by name (to_id) as a fallback in routing
    - Receive payload_tree directly (no XSD validation/deserialization)

    Args:
        name: Unique listener name (e.g., "sequence_abc123")
        handler: Async handler function (receives payload_tree, metadata)
        description: Human-readable description

    Returns:
        Listener object
    """
    listener = Listener(
        name=name,
        payload_class=object,  # Placeholder - not used (generic listeners skip deserialization)
        handler=handler,
        description=description,
        is_agent=False,
        root_tag="*",  # Wildcard marker
    )

    # Indexed twice: case-folded in _generic_listeners (routing fallback)
    # and raw-cased in self.listeners (name-based lookup/unregistration).
    # NOTE(review): the routing fallback appears to look up by the raw
    # to_id — confirm to_id is already lowercase or lower-cased at the
    # lookup site, otherwise mixed-case names will never match.
    self._generic_listeners[name.lower()] = listener
    self.listeners[name] = listener

    pump_logger.debug(f"Registered generic listener: {name}")
    return listener
|
||||
|
||||
def unregister_listener(self, name: str) -> bool:
    """
    Remove a listener by name.

    Used to clean up ephemeral listeners after orchestration completes.
    Clears the entry from the generic-listener index, the main listener
    dict, and (for non-wildcard listeners) the routing table.

    Args:
        name: Listener name to remove

    Returns:
        True if found and removed, False if not found
    """
    name_lower = name.lower()
    removed = False

    # Remove from generic listeners (stored case-folded at registration)
    if name_lower in self._generic_listeners:
        del self._generic_listeners[name_lower]
        removed = True
        pump_logger.debug(f"Unregistered generic listener: {name}")

    # Remove from main listeners dict (keyed by raw name)
    if name in self.listeners:
        listener = self.listeners.pop(name)
        removed = True

        # Remove from routing table; "*" marks generic (unrouted) listeners.
        # NOTE(review): routing entries appear to be keyed by a composite
        # "to_id.tag" string elsewhere in this class — removal by the bare
        # root_tag may leave composite keys behind; confirm against
        # register_listener.
        if listener.root_tag and listener.root_tag != "*":
            listeners_for_tag = self.routing_table.get(listener.root_tag, [])
            if listener in listeners_for_tag:
                listeners_for_tag.remove(listener)
                # Drop the now-empty bucket so routing stays tidy.
                if not listeners_for_tag:
                    del self.routing_table[listener.root_tag]

    return removed
|
||||
|
||||
def register_all(self) -> None:
|
||||
# First pass: register all listeners
|
||||
for lc in self.config.listeners:
|
||||
|
|
@ -781,6 +861,8 @@ class StreamPump:
|
|||
Combined validation + deserialization.
|
||||
|
||||
Uses to_id + payload tag to find the right listener and schema.
|
||||
Falls back to generic listeners (ephemeral orchestration handlers)
|
||||
when no regular listener matches.
|
||||
"""
|
||||
if state.error or state.payload_tree is None:
|
||||
return state
|
||||
|
|
@ -794,6 +876,19 @@ class StreamPump:
|
|||
lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower()
|
||||
|
||||
listeners = self.routing_table.get(lookup_key, [])
|
||||
|
||||
# Fallback: check for generic listener by to_id
|
||||
# Generic listeners accept any payload type (for orchestration)
|
||||
if not listeners and to_id:
|
||||
generic_listener = self._generic_listeners.get(to_id)
|
||||
if generic_listener:
|
||||
# Generic listener: skip XSD validation and deserialization
|
||||
# Pass the raw payload_tree to the handler
|
||||
state.payload = state.payload_tree # Handler receives Element
|
||||
state.target_listeners = [generic_listener]
|
||||
state.metadata["generic_handler"] = True
|
||||
return state
|
||||
|
||||
if not listeners:
|
||||
state.error = f"No listener for: {lookup_key}"
|
||||
return state
|
||||
|
|
@ -1008,6 +1103,36 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
|||
)
|
||||
pump.register_listener(todo_complete_config)
|
||||
|
||||
# Register Sequence primitives (orchestration)
|
||||
from xml_pipeline.primitives.sequence import (
|
||||
SequenceStart, handle_sequence_start,
|
||||
)
|
||||
sequence_config = ListenerConfig(
|
||||
name="system.sequence",
|
||||
payload_class_path="xml_pipeline.primitives.sequence.SequenceStart",
|
||||
handler_path="xml_pipeline.primitives.sequence.handle_sequence_start",
|
||||
description="System sequence handler - chains listeners in order",
|
||||
is_agent=False,
|
||||
payload_class=SequenceStart,
|
||||
handler=handle_sequence_start,
|
||||
)
|
||||
pump.register_listener(sequence_config)
|
||||
|
||||
# Register Buffer primitives (fan-out orchestration)
|
||||
from xml_pipeline.primitives.buffer import (
|
||||
BufferStart, handle_buffer_start,
|
||||
)
|
||||
buffer_config = ListenerConfig(
|
||||
name="system.buffer",
|
||||
payload_class_path="xml_pipeline.primitives.buffer.BufferStart",
|
||||
handler_path="xml_pipeline.primitives.buffer.handle_buffer_start",
|
||||
description="System buffer handler - fan-out to parallel workers",
|
||||
is_agent=False,
|
||||
payload_class=BufferStart,
|
||||
handler=handle_buffer_start,
|
||||
)
|
||||
pump.register_listener(buffer_config)
|
||||
|
||||
# Register all user-defined listeners
|
||||
pump.register_all()
|
||||
|
||||
|
|
@ -1061,6 +1186,9 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
|||
# Inject boot message (will be processed when pump.run() is called)
|
||||
await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system")
|
||||
|
||||
# Set global pump instance for get_stream_pump()
|
||||
set_stream_pump(pump)
|
||||
|
||||
print(f"Routing: {list(pump.routing_table.keys())}")
|
||||
return pump
|
||||
|
||||
|
|
@ -1110,3 +1238,45 @@ The key difference:
|
|||
- Old: 3 tool calls = 3 sequential awaits, each blocking until complete
|
||||
- New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit
|
||||
"""
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Global Singleton
|
||||
# ============================================================================
|
||||
|
||||
_pump: Optional[StreamPump] = None


def get_stream_pump() -> StreamPump:
    """
    Return the global StreamPump instance.

    bootstrap() creates the pump and installs it via set_stream_pump();
    calling this beforehand is a programming error.

    Raises:
        RuntimeError: if no pump has been set yet.
    """
    # Read-only access — no `global` statement required.
    if _pump is None:
        raise RuntimeError(
            "StreamPump not initialized. Call bootstrap() first."
        )
    return _pump


def set_stream_pump(pump: StreamPump) -> None:
    """Install *pump* as the global instance (called by bootstrap())."""
    global _pump
    _pump = pump


def reset_stream_pump() -> None:
    """Discard the global instance (test helper)."""
    global _pump
    _pump = None
|
||||
|
|
|
|||
|
|
@ -15,16 +15,45 @@ from xml_pipeline.primitives.todo import (
|
|||
handle_todo_complete,
|
||||
)
|
||||
from xml_pipeline.primitives.text_input import TextInput, TextOutput
|
||||
from xml_pipeline.primitives.sequence import (
|
||||
SequenceStart,
|
||||
SequenceComplete,
|
||||
SequenceError,
|
||||
handle_sequence_start,
|
||||
)
|
||||
from xml_pipeline.primitives.buffer import (
|
||||
BufferStart,
|
||||
BufferItem,
|
||||
BufferComplete,
|
||||
BufferDispatched,
|
||||
BufferError,
|
||||
handle_buffer_start,
|
||||
)
|
||||
|
||||
# Public API of the primitives package; star-imports pick up exactly
# these names.
__all__ = [
    # Boot
    "Boot",
    "handle_boot",
    # Todo
    "TodoUntil",
    "TodoComplete",
    "TodoRegistered",
    "TodoClosed",
    "handle_todo_until",
    "handle_todo_complete",
    # Text I/O
    "TextInput",
    "TextOutput",
    # Sequence orchestration
    "SequenceStart",
    "SequenceComplete",
    "SequenceError",
    "handle_sequence_start",
    # Buffer orchestration
    "BufferStart",
    "BufferItem",
    "BufferComplete",
    "BufferDispatched",
    "BufferError",
    "handle_buffer_start",
]
|
||||
|
|
|
|||
373
xml_pipeline/primitives/buffer.py
Normal file
373
xml_pipeline/primitives/buffer.py
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
"""
|
||||
buffer.py — Buffer (fan-out) orchestration primitives.
|
||||
|
||||
Buffers fan-out to parallel workers, sending N items to the same listener
|
||||
concurrently. Results are collected and returned when all complete.
|
||||
|
||||
Usage by an agent:
|
||||
# Fan-out search queries to web_search
|
||||
return HandlerResponse(
|
||||
payload=BufferStart(
|
||||
target="web_search",
|
||||
items="python async\\nrust memory\\ngo concurrency",
|
||||
return_to="my-agent",
|
||||
collect=True,
|
||||
),
|
||||
to="system.buffer",
|
||||
)
|
||||
|
||||
Flow:
|
||||
1. system.buffer receives BufferStart with N items
|
||||
2. Creates ephemeral listener buffer_{id} to receive results
|
||||
3. Creates N sibling threads via ThreadRegistry
|
||||
4. Sends BufferItem to each worker FROM buffer_{id}
|
||||
5. Workers process and respond → routes to buffer_{id}
|
||||
6. Ephemeral handler collects results
|
||||
7. When all workers done, sends BufferComplete to return_to
|
||||
8. Cleans up ephemeral listener
|
||||
|
||||
Fire-and-forget mode (collect=False):
|
||||
- Returns immediately after dispatching
|
||||
- No result collection
|
||||
- Useful for async side effects
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import uuid as uuid_module
|
||||
import logging
|
||||
|
||||
from lxml import etree
|
||||
from third_party.xmlable import xmlify
|
||||
from xml_pipeline.message_bus.message_state import (
|
||||
HandlerMetadata,
|
||||
HandlerResponse,
|
||||
)
|
||||
from xml_pipeline.message_bus.buffer_registry import get_buffer_registry
|
||||
from xml_pipeline.message_bus.thread_registry import get_registry as get_thread_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Payloads
|
||||
# ============================================================================
|
||||
|
||||
@xmlify
@dataclass
class BufferStart:
    """
    Start a new buffer (fan-out) execution.

    Sent to system.buffer to begin parallel processing. The handler
    splits `items` on newlines and dispatches one worker per
    non-empty line.
    """
    target: str = ""  # Listener to fan-out to
    items: str = ""  # Newline-separated payloads (raw XML), one worker per line
    collect: bool = True  # Wait for all results? False = fire-and-forget
    return_to: str = ""  # Where to send BufferComplete (or BufferDispatched when collect=False)
    buffer_id: str = ""  # Auto-generated if empty
|
||||
|
||||
|
||||
@xmlify
@dataclass
class BufferItem:
    """
    Individual item being processed by a worker.

    Wraps the actual payload with buffer metadata so a result can be
    correlated back to its index.

    Note: This is an internal type - workers receive the raw payload,
    not BufferItem directly.
    """
    buffer_id: str = ""  # Buffer this item belongs to
    index: int = 0  # 0-based position within the fan-out
    payload: str = ""  # The actual XML payload handed to the worker
|
||||
|
||||
|
||||
@xmlify
@dataclass
class BufferComplete:
    """
    Buffer completed - all workers finished.

    Sent to return_to when all items are processed.
    """
    buffer_id: str = ""  # Correlation ID of the finished buffer
    total: int = 0       # Total number of items that were dispatched
    successful: int = 0  # How many workers responded without an error payload
    results: str = ""    # XML array of results (see _format_buffer_results)
|
||||
|
||||
|
||||
@xmlify
@dataclass
class BufferDispatched:
    """
    Buffer dispatched (fire-and-forget mode).

    Sent immediately after items are dispatched when collect=False;
    no results will follow.
    """
    buffer_id: str = ""  # Correlation ID of the dispatched buffer
    total: int = 0       # Number of items that were dispatched
|
||||
|
||||
|
||||
@xmlify
@dataclass
class BufferError:
    """
    Buffer failed to start.

    Sent when buffer initialization fails (no items, unknown target).
    """
    buffer_id: str = ""  # Correlation ID, or "unknown" if none was assigned
    error: str = ""      # Human-readable failure description
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Handlers
|
||||
# ============================================================================
|
||||
|
||||
async def handle_buffer_start(
    payload: BufferStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle BufferStart — begin a fan-out execution.

    Creates one sibling thread per item, dispatches each item to the
    target listener, and (in collect mode) registers an ephemeral
    listener that gathers the workers' responses.

    Returns:
        BufferError to return_to/sender on bad input,
        BufferDispatched immediately when collect=False,
        None when collect=True — BufferComplete is sent later by the
        ephemeral listener once all workers have responded.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse items: one raw-XML payload per non-empty line.
    items = [item.strip() for item in payload.items.split("\n") if item.strip()]
    if not items:
        logger.error("BufferStart with no items")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error="No items specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Validate target exists before creating any state.
    pump = get_stream_pump()
    if payload.target not in pump.listeners:
        logger.error(f"BufferStart: unknown target '{payload.target}'")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error=f"Unknown target listener: {payload.target}",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate buffer ID if not provided.
    buf_id = payload.buffer_id or str(uuid_module.uuid4())[:8]

    # Create buffer state (tracks per-worker results and completion).
    buffer_registry = get_buffer_registry()
    buffer_registry.create(
        buffer_id=buf_id,
        total_items=len(items),
        return_to=payload.return_to or metadata.from_id,
        thread_id=metadata.thread_id,
        from_id=metadata.from_id,
        target=payload.target,
        collect=payload.collect,
    )

    ephemeral_name = f"buffer_{buf_id}"

    if payload.collect:
        # Ephemeral handler that collects worker results; unregistered by
        # _handle_buffer_result once the last worker has responded.
        async def buffer_handler(
            payload_tree: etree._Element,
            meta: HandlerMetadata,
        ) -> Optional[HandlerResponse]:
            """Ephemeral handler that collects worker results."""
            return await _handle_buffer_result(buf_id, payload_tree, meta)

        # Register ephemeral listener (generic mode - accepts any payload).
        pump.register_generic_listener(
            name=ephemeral_name,
            handler=buffer_handler,
            description=f"Ephemeral buffer handler for {buf_id}",
        )

    logger.info(
        f"Buffer {buf_id} starting: {len(items)} items to {payload.target}, "
        f"collect={payload.collect}"
    )

    # Dispatch all items; each worker runs on its own sibling thread whose
    # chain encodes the worker index (parsed later by _extract_worker_index).
    thread_registry = get_thread_registry()
    parent_chain = thread_registry.lookup(metadata.thread_id) or metadata.thread_id

    for i, item_payload in enumerate(items):
        # Create sibling thread for this worker.
        worker_chain = f"{parent_chain}.{ephemeral_name}_w{i}"
        worker_uuid = thread_registry.get_or_create(worker_chain)

        # The item is sent FROM the ephemeral listener so a worker's
        # .respond() routes back to it (collect mode); in fire-and-forget
        # mode the responses are simply dropped.
        await _inject_buffer_item(
            pump=pump,
            target=payload.target,
            payload_xml=item_payload,
            thread_id=worker_uuid,
            from_id=ephemeral_name,
        )

        logger.debug(f"Buffer {buf_id}: dispatched item {i} to {payload.target}")

    # Fire-and-forget: nothing will ever collect results or clean up,
    # so drop the registry entry now (the original leaked it) and
    # acknowledge the dispatch immediately.
    if not payload.collect:
        buffer_registry.remove(buf_id)
        logger.info(f"Buffer {buf_id}: fire-and-forget mode, {len(items)} items dispatched")
        return HandlerResponse(
            payload=BufferDispatched(
                buffer_id=buf_id,
                total=len(items),
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Collect mode: wait for results (handled by ephemeral listener).
    # Return None - the ephemeral listener will send BufferComplete.
    return None
|
||||
|
||||
|
||||
async def _handle_buffer_result(
    buf_id: str,
    payload_tree: etree._Element,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle a worker result in the buffer.

    Called by the ephemeral listener when a worker responds. Records the
    result against the worker's index, and when the last worker reports,
    tears down the ephemeral listener and sends BufferComplete.

    Returns:
        BufferComplete (to state.return_to) when all workers are done,
        otherwise None while results are still pending.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    buffer_registry = get_buffer_registry()
    state = buffer_registry.get(buf_id)

    if state is None:
        # Late or duplicate result after the buffer was torn down.
        logger.warning(f"Buffer {buf_id} not found in registry (result dropped)")
        return None

    # Extract worker index from thread chain
    # Chain format: parent.buffer_xyz_wN where N is the index
    thread_registry = get_thread_registry()
    chain = thread_registry.lookup(metadata.thread_id) or ""

    worker_index = _extract_worker_index(chain, buf_id)
    if worker_index is None:
        # NOTE(review): count-based fallback can collide with a real index
        # that arrives later and overwrite it — confirm record_result
        # semantics if this path ever fires in practice.
        logger.warning(f"Buffer {buf_id}: could not determine worker index from chain")
        worker_index = state.completed_count  # Fallback to count-based

    # Serialize the result
    result_xml = etree.tostring(payload_tree, encoding="unicode")

    # Check for errors — by convention, error payloads use these tags.
    is_error = payload_tree.tag.lower() in ("huh", "systemerror")

    # Record result; record_result returns the updated state.
    state = buffer_registry.record_result(
        buffer_id=buf_id,
        index=worker_index,
        result=result_xml,
        success=not is_error,
        error=result_xml[:200] if is_error else None,
    )

    logger.debug(
        f"Buffer {buf_id}: received result {state.completed_count}/{state.total_items}"
    )

    # Check if all done
    if state.is_complete:
        # Clean up the ephemeral listener before emitting completion.
        pump = get_stream_pump()
        pump.unregister_listener(f"buffer_{buf_id}")

        # Format results as XML array, then drop the registry entry.
        results_xml = _format_buffer_results(state)
        buffer_registry.remove(buf_id)

        logger.info(
            f"Buffer {buf_id} completed: {state.successful_count}/{state.total_items} successful"
        )

        return HandlerResponse(
            payload=BufferComplete(
                buffer_id=buf_id,
                total=state.total_items,
                successful=state.successful_count,
                results=results_xml,
            ),
            to=state.return_to,
        )

    # More results pending
    return None
|
||||
|
||||
|
||||
def _extract_worker_index(chain: str, buf_id: str) -> Optional[int]:
|
||||
"""Extract worker index from thread chain."""
|
||||
# Look for pattern: buffer_{id}_wN
|
||||
import re
|
||||
pattern = rf"buffer_{buf_id}_w(\d+)"
|
||||
match = re.search(pattern, chain)
|
||||
if match:
|
||||
return int(match.group(1))
|
||||
return None
|
||||
|
||||
|
||||
def _format_buffer_results(state) -> str:
|
||||
"""Format buffer results as XML array."""
|
||||
lines = ["<results>"]
|
||||
for i in range(state.total_items):
|
||||
result = state.results.get(i)
|
||||
if result:
|
||||
success = "true" if result.success else "false"
|
||||
lines.append(f' <item index="{i}" success="{success}">')
|
||||
# Indent the result content
|
||||
for line in result.result.split("\n"):
|
||||
lines.append(f" {line}")
|
||||
lines.append(" </item>")
|
||||
else:
|
||||
lines.append(f' <item index="{i}" success="false">missing</item>')
|
||||
lines.append("</results>")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def _inject_buffer_item(
    pump,
    target: str,
    payload_xml: str,
    thread_id: str,
    from_id: str,
) -> None:
    """Push one buffer item straight into the pump.

    The raw XML is wrapped in an envelope addressed to `target`, with
    `from_id` set to the ephemeral buffer listener so worker responses
    route back to it.
    """
    carrier = _RawXmlPayload(payload_xml)
    wrapped = pump._wrap_in_envelope(
        payload=carrier,
        from_id=from_id,
        to_id=target,
        thread_id=thread_id,
    )
    await pump.inject(wrapped, thread_id=thread_id, from_id=from_id)
|
||||
|
||||
|
||||
class _RawXmlPayload:
    """Carrier for raw XML that bypasses serialization.

    Used by _inject_buffer_item: the pump's envelope wrapper calls
    to_xml() and embeds the returned string verbatim, so an
    already-serialized payload is not re-encoded.
    """

    def __init__(self, xml: str):
        # Stored verbatim; not parsed or validated here.
        self.xml = xml

    def to_xml(self) -> str:
        """Return raw XML for envelope wrapping."""
        return self.xml
|
||||
299
xml_pipeline/primitives/sequence.py
Normal file
299
xml_pipeline/primitives/sequence.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
sequence.py — Sequence orchestration primitives.
|
||||
|
||||
Sequences chain multiple listeners in order, feeding the output of one step
|
||||
as input to the next. Steps remain transparent - they don't know they're
|
||||
part of a sequence.
|
||||
|
||||
Usage by an agent:
|
||||
# Start a sequence: add two numbers, then multiply
|
||||
return HandlerResponse(
|
||||
payload=SequenceStart(
|
||||
steps="calculator.add,calculator.multiply",
|
||||
payload='<AddPayload><a>5</a><b>3</b></AddPayload>',
|
||||
return_to="my-agent",
|
||||
),
|
||||
to="system.sequence",
|
||||
)
|
||||
|
||||
Flow:
|
||||
1. system.sequence receives SequenceStart
|
||||
2. Creates ephemeral listener sequence_{id} to receive step results
|
||||
3. Sends initial payload to first step FROM sequence_{id}
|
||||
4. Step processes and responds → routes to sequence_{id}
|
||||
5. Ephemeral handler advances, sends to next step
|
||||
6. When all steps complete, sends SequenceComplete to return_to
|
||||
7. Cleans up ephemeral listener
|
||||
|
||||
Key insight: Steps use normal .respond() - the ephemeral listener IS the
|
||||
caller in the thread chain, so responses naturally route back to it.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
import uuid as uuid_module
|
||||
import logging
|
||||
|
||||
from lxml import etree
|
||||
from third_party.xmlable import xmlify
|
||||
from xml_pipeline.message_bus.message_state import (
|
||||
HandlerMetadata,
|
||||
HandlerResponse,
|
||||
MessageState,
|
||||
)
|
||||
from xml_pipeline.message_bus.sequence_registry import get_sequence_registry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Payloads
|
||||
# ============================================================================
|
||||
|
||||
@xmlify
@dataclass
class SequenceStart:
    """
    Start a new sequence execution.

    Sent to system.sequence to begin chaining steps: each step's output
    becomes the next step's input.
    """
    steps: str = ""        # Comma-separated listener names, in execution order
    payload: str = ""      # Initial XML payload for first step
    return_to: str = ""    # Where to send final result (defaults to sender)
    sequence_id: str = ""  # Correlation ID; auto-generated if empty
|
||||
|
||||
|
||||
@xmlify
@dataclass
class SequenceComplete:
    """
    Sequence completed successfully.

    Sent to the return_to listener when all steps finish.
    """
    sequence_id: str = ""   # Correlation ID of the finished sequence
    final_result: str = ""  # XML result from last step
    step_count: int = 0     # How many steps were executed
|
||||
|
||||
|
||||
@xmlify
@dataclass
class SequenceError:
    """
    Sequence failed at a step.

    Sent to return_to when a step fails (or when the sequence could not
    start at all).
    """
    sequence_id: str = ""  # Correlation ID, or "unknown" if none was assigned
    failed_step: str = ""  # Which step failed ("" if failure preceded any step)
    step_index: int = 0    # 0-based index of failed step
    error: str = ""        # Error message (truncated to 200 chars by the caller)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Handlers
|
||||
# ============================================================================
|
||||
|
||||
async def handle_sequence_start(
    payload: SequenceStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle SequenceStart — begin a sequence execution.

    Validates the step list, stores sequence state, registers an
    ephemeral listener that receives each step's response, and kicks
    off the first step.

    Returns:
        SequenceError to return_to/sender on bad input; otherwise the
        HandlerResponse that sends the initial payload to the first step
        (sent FROM the ephemeral listener so responses route back to it).
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse and validate the comma-separated step list.
    steps = [s.strip() for s in payload.steps.split(",") if s.strip()]
    if not steps:
        logger.error("SequenceStart with no steps")
        return HandlerResponse(
            payload=SequenceError(
                sequence_id=payload.sequence_id or "unknown",
                failed_step="",
                step_index=0,
                error="No steps specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate sequence ID if not provided.
    seq_id = payload.sequence_id or str(uuid_module.uuid4())[:8]

    # Validate all steps exist before creating any state.
    # enumerate (rather than steps.index) is O(n) and reports the correct
    # position even if the same listener name appears more than once.
    pump = get_stream_pump()
    for idx, step in enumerate(steps):
        if step not in pump.listeners:
            logger.error(f"SequenceStart: unknown step '{step}'")
            return HandlerResponse(
                payload=SequenceError(
                    sequence_id=seq_id,
                    failed_step=step,
                    step_index=idx,
                    error=f"Unknown listener: {step}",
                ),
                to=payload.return_to or metadata.from_id,
            )

    # Create sequence state (tracks current step and accumulated result).
    registry = get_sequence_registry()
    state = registry.create(
        sequence_id=seq_id,
        steps=steps,
        return_to=payload.return_to or metadata.from_id,
        thread_id=metadata.thread_id,
        from_id=metadata.from_id,
        initial_payload=payload.payload,
    )

    # Ephemeral listener for this sequence; cleaned up on completion or
    # failure by _handle_sequence_step_result.
    ephemeral_name = f"sequence_{seq_id}"

    async def sequence_handler(
        payload_tree: etree._Element,
        meta: HandlerMetadata,
    ) -> Optional[HandlerResponse]:
        """Ephemeral handler that processes step results."""
        return await _handle_sequence_step_result(seq_id, payload_tree, meta)

    # Register ephemeral listener (generic mode - accepts any payload).
    pump.register_generic_listener(
        name=ephemeral_name,
        handler=sequence_handler,
        description=f"Ephemeral sequence handler for {seq_id}",
    )

    logger.info(
        f"Sequence {seq_id} started: {len(steps)} steps, "
        f"return_to={state.return_to}"
    )

    # Kick off first step.
    return _create_step_message(
        seq_id=seq_id,
        target=steps[0],
        payload_xml=payload.payload,
        from_name=ephemeral_name,
    )
|
||||
|
||||
|
||||
async def _handle_sequence_step_result(
    seq_id: str,
    payload_tree: etree._Element,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle a step result in the sequence.

    Called by the ephemeral listener when a step responds. On error it
    tears the sequence down and reports SequenceError; otherwise it
    advances the state and either completes the sequence or forwards the
    result to the next step.

    Returns:
        SequenceError / SequenceComplete (to state.return_to), the next
        step's message, or None if the sequence is no longer registered.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    registry = get_sequence_registry()
    state = registry.get(seq_id)

    if state is None:
        # Late or duplicate result after the sequence was torn down.
        logger.error(f"Sequence {seq_id} not found in registry")
        return None

    # Serialize the result for storage
    result_xml = etree.tostring(payload_tree, encoding="unicode")

    # Check for error responses — by convention these tags signal failure.
    if payload_tree.tag.lower() in ("huh", "systemerror"):
        # Step failed; prefer the element text, fall back to full XML.
        error_text = payload_tree.text or etree.tostring(payload_tree, encoding="unicode")
        registry.mark_failed(seq_id, state.current_step or "unknown", error_text)

        # Clean up the ephemeral listener and registry entry, then report.
        pump = get_stream_pump()
        pump.unregister_listener(f"sequence_{seq_id}")
        registry.remove(seq_id)

        logger.warning(f"Sequence {seq_id} failed at step {state.current_index}")
        return HandlerResponse(
            payload=SequenceError(
                sequence_id=seq_id,
                failed_step=state.current_step or "unknown",
                step_index=state.current_index,
                error=error_text[:200],  # Truncate long errors
            ),
            to=state.return_to,
        )

    # Advance to next step; advance returns the updated state.
    state = registry.advance(seq_id, result_xml)

    if state.is_complete:
        # All steps done - clean up and send completion.
        pump = get_stream_pump()
        pump.unregister_listener(f"sequence_{seq_id}")
        registry.remove(seq_id)

        logger.info(f"Sequence {seq_id} completed: {len(state.steps)} steps")
        return HandlerResponse(
            payload=SequenceComplete(
                sequence_id=seq_id,
                final_result=result_xml,
                step_count=len(state.steps),
            ),
            to=state.return_to,
        )

    # More steps to go - feed this step's output to the next step.
    next_step = state.current_step
    logger.debug(
        f"Sequence {seq_id} advancing to step {state.current_index}: {next_step}"
    )

    return _create_step_message(
        seq_id=seq_id,
        target=next_step,
        payload_xml=result_xml,
        from_name=f"sequence_{seq_id}",
    )
|
||||
|
||||
|
||||
def _create_step_message(
    seq_id: str,
    target: str,
    payload_xml: str,
    from_name: str,
) -> HandlerResponse:
    """
    Create a HandlerResponse that sends raw XML to a step.

    The payload is wrapped in _RawPayloadCarrier, which tells the pump to:
    1. use the raw XML bytes directly (no dataclass serialization), and
    2. set from_id to from_name (the ephemeral listener),
    so the step's .respond() routes back to the sequence handler.

    Args:
        seq_id: Sequence correlation ID (kept for interface stability;
            not used in message construction itself).
        target: Listener name of the step to invoke.
        payload_xml: Raw XML payload for the step.
        from_name: Ephemeral listener name to impersonate as sender.
    """
    # _RawPayloadCarrier is defined below in this module; the original
    # self-import (from xml_pipeline.primitives.sequence import ...) was
    # redundant and has been removed.
    return HandlerResponse(
        payload=_RawPayloadCarrier(xml=payload_xml, from_override=from_name),
        to=target,
    )
|
||||
|
||||
|
||||
class _RawPayloadCarrier:
    """
    Internal carrier for raw XML that bypasses normal serialization.

    When the pump sees this, it uses the raw XML directly instead of
    serializing a dataclass, and (when from_override is set) stamps the
    envelope's sender with that name instead of the real sender.
    """

    def __init__(self, xml: str, from_override: Optional[str] = None):
        # Raw XML stored verbatim; not parsed or validated here.
        self.xml = xml
        # Listener name to report as the sender; None keeps the default.
        self.from_override = from_override

    def to_xml(self) -> str:
        """Return raw XML for envelope wrapping."""
        return self.xml
|
||||
Loading…
Reference in a new issue