diff --git a/tests/test_buffer.py b/tests/test_buffer.py new file mode 100644 index 0000000..2f9a889 --- /dev/null +++ b/tests/test_buffer.py @@ -0,0 +1,635 @@ +""" +test_buffer.py — Tests for the Buffer (fan-out) orchestration primitives. + +Tests: +1. BufferRegistry basic operations +2. BufferStart handler +3. Result collection +4. Fire-and-forget mode +""" + +import pytest +import uuid + +from xml_pipeline.message_bus.buffer_registry import ( + BufferRegistry, + BufferState, + BufferItemResult, + get_buffer_registry, + reset_buffer_registry, +) +from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse +from xml_pipeline.primitives.buffer import ( + BufferStart, + BufferComplete, + BufferDispatched, + BufferError, + handle_buffer_start, + _extract_worker_index, + _format_buffer_results, +) + + +class TestBufferRegistry: + """Test BufferRegistry basic operations.""" + + def test_create_buffer_state(self): + """create() should create a buffer state with correct fields.""" + registry = BufferRegistry() + + state = registry.create( + buffer_id="buf001", + total_items=5, + return_to="caller", + thread_id="thread-123", + from_id="console", + target="worker", + collect=True, + ) + + assert state.buffer_id == "buf001" + assert state.total_items == 5 + assert state.return_to == "caller" + assert state.thread_id == "thread-123" + assert state.from_id == "console" + assert state.target == "worker" + assert state.collect is True + assert state.completed_count == 0 + assert state.successful_count == 0 + assert state.is_complete is False + assert state.pending_count == 5 + + def test_get_returns_buffer(self): + """get() should return buffer by ID.""" + registry = BufferRegistry() + + registry.create( + buffer_id="buf002", + total_items=3, + return_to="x", + thread_id="t", + from_id="c", + target="w", + ) + + state = registry.get("buf002") + assert state is not None + assert state.buffer_id == "buf002" + + # Non-existent returns None + assert 
registry.get("nonexistent") is None + + def test_record_result_stores_result(self): + """record_result() should store result and update counts.""" + registry = BufferRegistry() + + registry.create( + buffer_id="buf003", + total_items=3, + return_to="x", + thread_id="t", + from_id="c", + target="w", + ) + + # Record first result (success) + state = registry.record_result( + buffer_id="buf003", + index=0, + result="", + success=True, + ) + + assert state.completed_count == 1 + assert state.successful_count == 1 + assert state.is_complete is False + assert 0 in state.results + assert state.results[0].result == "" + assert state.results[0].success is True + + # Record second result (failure) + state = registry.record_result( + buffer_id="buf003", + index=1, + result="", + success=False, + error="Timeout", + ) + + assert state.completed_count == 2 + assert state.successful_count == 1 # Still just 1 + assert state.results[1].success is False + assert state.results[1].error == "Timeout" + + # Record third result - now complete + state = registry.record_result( + buffer_id="buf003", + index=2, + result="", + success=True, + ) + + assert state.completed_count == 3 + assert state.successful_count == 2 + assert state.is_complete is True + assert state.pending_count == 0 + + def test_record_result_ignores_duplicates(self): + """record_result() should not count same index twice.""" + registry = BufferRegistry() + + registry.create( + buffer_id="buf004", + total_items=2, + return_to="x", + thread_id="t", + from_id="c", + target="w", + ) + + # Record index 0 + state = registry.record_result("buf004", 0, "", True) + assert state.completed_count == 1 + + # Try to record index 0 again + state = registry.record_result("buf004", 0, "", True) + assert state.completed_count == 1 # Should not increment + assert state.results[0].result == "" # Original preserved + + def test_remove_deletes_buffer(self): + """remove() should delete buffer from registry.""" + registry = BufferRegistry() + + 
registry.create( + buffer_id="buf005", + total_items=1, + return_to="x", + thread_id="t", + from_id="c", + target="w", + ) + + assert registry.get("buf005") is not None + result = registry.remove("buf005") + assert result is True + assert registry.get("buf005") is None + + # Remove non-existent returns False + assert registry.remove("nonexistent") is False + + def test_list_active(self): + """list_active() should return all active buffer IDs.""" + registry = BufferRegistry() + + registry.create("buf-a", 1, "x", "t", "c", "w") + registry.create("buf-b", 2, "x", "t", "c", "w") + registry.create("buf-c", 3, "x", "t", "c", "w") + + active = registry.list_active() + assert set(active) == {"buf-a", "buf-b", "buf-c"} + + def test_clear(self): + """clear() should remove all buffers.""" + registry = BufferRegistry() + + registry.create("buf-1", 1, "x", "t", "c", "w") + registry.create("buf-2", 2, "x", "t", "c", "w") + + registry.clear() + + assert registry.list_active() == [] + + +class TestBufferStateProperties: + """Test BufferState computed properties.""" + + def test_is_complete_when_all_received(self): + """is_complete should be True when all items are received.""" + state = BufferState( + buffer_id="test", + total_items=3, + return_to="x", + thread_id="t", + from_id="c", + target="w", + completed_count=3, + ) + assert state.is_complete is True + + def test_pending_count(self): + """pending_count should reflect remaining items.""" + state = BufferState( + buffer_id="test", + total_items=5, + return_to="x", + thread_id="t", + from_id="c", + target="w", + completed_count=2, + ) + assert state.pending_count == 3 + + def test_get_ordered_results(self): + """get_ordered_results should return results in order.""" + state = BufferState( + buffer_id="test", + total_items=3, + return_to="x", + thread_id="t", + from_id="c", + target="w", + results={ + 0: BufferItemResult(0, "", True), + 2: BufferItemResult(2, "", True), + # Index 1 missing + }, + ) + + ordered = 
state.get_ordered_results() + assert len(ordered) == 3 + assert ordered[0] is not None + assert ordered[0].result == "" + assert ordered[1] is None # Missing + assert ordered[2] is not None + assert ordered[2].result == "" + + +class TestBufferStartPayload: + """Test BufferStart payload.""" + + def test_buffer_start_fields(self): + """BufferStart should have expected fields.""" + payload = BufferStart( + target="worker", + items="item1\nitem2\nitem3", + collect=True, + return_to="caller", + buffer_id="custom-id", + ) + + assert payload.target == "worker" + assert payload.items == "item1\nitem2\nitem3" + assert payload.collect is True + assert payload.return_to == "caller" + assert payload.buffer_id == "custom-id" + + def test_buffer_start_default_values(self): + """BufferStart should have sensible defaults.""" + payload = BufferStart() + + assert payload.target == "" + assert payload.items == "" + assert payload.collect is True + assert payload.return_to == "" + assert payload.buffer_id == "" + + +class TestBufferCompletePayload: + """Test BufferComplete payload.""" + + def test_buffer_complete_fields(self): + """BufferComplete should have expected fields.""" + payload = BufferComplete( + buffer_id="buf123", + total=5, + successful=4, + results="...", + ) + + assert payload.buffer_id == "buf123" + assert payload.total == 5 + assert payload.successful == 4 + assert payload.results == "..." 
+ + +class TestBufferDispatchedPayload: + """Test BufferDispatched payload (fire-and-forget mode).""" + + def test_buffer_dispatched_fields(self): + """BufferDispatched should have expected fields.""" + payload = BufferDispatched( + buffer_id="buf456", + total=10, + ) + + assert payload.buffer_id == "buf456" + assert payload.total == 10 + + +class TestBufferErrorPayload: + """Test BufferError payload.""" + + def test_buffer_error_fields(self): + """BufferError should have expected fields.""" + payload = BufferError( + buffer_id="buf789", + error="Unknown target listener", + ) + + assert payload.buffer_id == "buf789" + assert payload.error == "Unknown target listener" + + +class TestExtractWorkerIndex: + """Test _extract_worker_index helper.""" + + def test_extracts_index_from_chain(self): + """Should extract worker index from thread chain.""" + chain = "root.parent.buffer_abc123_w5" + index = _extract_worker_index(chain, "abc123") + assert index == 5 + + def test_extracts_double_digit_index(self): + """Should handle double-digit indices.""" + chain = "x.buffer_xyz_w42" + index = _extract_worker_index(chain, "xyz") + assert index == 42 + + def test_returns_none_for_no_match(self): + """Should return None when pattern doesn't match.""" + chain = "something.else.entirely" + index = _extract_worker_index(chain, "abc") + assert index is None + + def test_returns_none_for_wrong_buffer_id(self): + """Should return None when buffer ID doesn't match.""" + chain = "root.buffer_other_w3" + index = _extract_worker_index(chain, "abc") + assert index is None + + +class TestFormatBufferResults: + """Test _format_buffer_results helper.""" + + def test_formats_complete_results(self): + """Should format all results as XML.""" + state = BufferState( + buffer_id="test", + total_items=2, + return_to="x", + thread_id="t", + from_id="c", + target="w", + results={ + 0: BufferItemResult(0, "A", True), + 1: BufferItemResult(1, "B", True), + }, + ) + + xml = _format_buffer_results(state) + + 
assert "" in xml + assert "" in xml + assert 'index="0"' in xml + assert 'index="1"' in xml + assert 'success="true"' in xml + assert "A" in xml + assert "B" in xml + + def test_formats_partial_failure(self): + """Should format mixed success/failure results.""" + state = BufferState( + buffer_id="test", + total_items=2, + return_to="x", + thread_id="t", + from_id="c", + target="w", + results={ + 0: BufferItemResult(0, "", True), + 1: BufferItemResult(1, "", False, "timeout"), + }, + ) + + xml = _format_buffer_results(state) + + assert 'success="true"' in xml + assert 'success="false"' in xml + + def test_formats_missing_results(self): + """Should handle missing results.""" + state = BufferState( + buffer_id="test", + total_items=3, + return_to="x", + thread_id="t", + from_id="c", + target="w", + results={ + 0: BufferItemResult(0, "", True), + # Index 1 missing + 2: BufferItemResult(2, "", True), + }, + ) + + xml = _format_buffer_results(state) + + assert 'index="1"' in xml + assert "missing" in xml + + +class TestHandleBufferStartValidation: + """Test validation in handle_buffer_start.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Reset registries before each test.""" + reset_buffer_registry() + + @pytest.mark.asyncio + async def test_empty_items_returns_error(self): + """handle_buffer_start should return error for empty items.""" + from unittest.mock import patch, MagicMock + + mock_pump = MagicMock() + mock_pump.listeners = {} + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + payload = BufferStart( + target="worker", + items="", # Empty + return_to="caller", + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_buffer_start(payload, metadata) + + assert isinstance(response, HandlerResponse) + assert isinstance(response.payload, BufferError) + assert "No items" in response.payload.error + + @pytest.mark.asyncio + async def 
test_unknown_target_returns_error(self): + """handle_buffer_start should return error for unknown target.""" + from unittest.mock import patch, MagicMock + + mock_pump = MagicMock() + mock_pump.listeners = {} # No listeners + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + payload = BufferStart( + target="unknown_worker", + items="item1\nitem2", + return_to="caller", + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_buffer_start(payload, metadata) + + assert isinstance(response, HandlerResponse) + assert isinstance(response.payload, BufferError) + assert "Unknown target" in response.payload.error + + +class TestBufferRegistrySingleton: + """Test singleton pattern for BufferRegistry.""" + + def test_get_buffer_registry_returns_singleton(self): + """get_buffer_registry should return same instance.""" + reset_buffer_registry() + + reg1 = get_buffer_registry() + reg2 = get_buffer_registry() + + assert reg1 is reg2 + + def test_reset_creates_new_instance(self): + """reset_buffer_registry should clear singleton.""" + reg1 = get_buffer_registry() + reg1.create("test", 1, "x", "t", "c", "w") + + reset_buffer_registry() + reg2 = get_buffer_registry() + + assert reg2.get("test") is None + + +class TestBufferCollectVsFireAndForget: + """Test collect=True vs collect=False behavior.""" + + @pytest.mark.asyncio + async def test_collect_mode_returns_none(self): + """With collect=True, handler should return None (wait for results).""" + from unittest.mock import patch, MagicMock, AsyncMock + + mock_pump = MagicMock() + mock_pump.listeners = {"worker": MagicMock()} + mock_pump.register_generic_listener = MagicMock() + mock_pump._wrap_in_envelope = MagicMock(return_value=b"") + mock_pump.inject = AsyncMock() + + # Mock thread registry + mock_thread_registry = MagicMock() + mock_thread_registry.lookup = MagicMock(return_value="root.parent") + 
mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid") + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry): + payload = BufferStart( + target="worker", + items="item1\nitem2", + return_to="caller", + collect=True, + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_buffer_start(payload, metadata) + + # With collect=True, returns None (ephemeral handler will send BufferComplete) + assert response is None + + # Ephemeral listener should be registered + mock_pump.register_generic_listener.assert_called_once() + + @pytest.mark.asyncio + async def test_fire_and_forget_returns_dispatched(self): + """With collect=False, handler should return BufferDispatched immediately.""" + from unittest.mock import patch, MagicMock, AsyncMock + + mock_pump = MagicMock() + mock_pump.listeners = {"worker": MagicMock()} + mock_pump.register_generic_listener = MagicMock() + mock_pump._wrap_in_envelope = MagicMock(return_value=b"") + mock_pump.inject = AsyncMock() + + mock_thread_registry = MagicMock() + mock_thread_registry.lookup = MagicMock(return_value="root.parent") + mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid") + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry): + payload = BufferStart( + target="worker", + items="item1\nitem2\nitem3", + return_to="caller", + collect=False, # Fire-and-forget + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_buffer_start(payload, metadata) + + # With collect=False, returns BufferDispatched + assert isinstance(response, HandlerResponse) + assert isinstance(response.payload, 
BufferDispatched) + assert response.payload.total == 3 + + # No ephemeral listener registered for fire-and-forget + mock_pump.register_generic_listener.assert_not_called() + + +class TestBufferResultCollection: + """Test result collection behavior.""" + + def test_result_collection_partial_success(self): + """Buffer should track partial success correctly.""" + registry = BufferRegistry() + + registry.create("partial", 5, "x", "t", "c", "w") + + # 3 successes, 2 failures + registry.record_result("partial", 0, "", True) + registry.record_result("partial", 1, "", False, "error") + registry.record_result("partial", 2, "", True) + registry.record_result("partial", 3, "", False, "error") + registry.record_result("partial", 4, "", True) + + state = registry.get("partial") + + assert state.is_complete is True + assert state.completed_count == 5 + assert state.successful_count == 3 + + def test_results_out_of_order(self): + """Buffer should handle results arriving out of order.""" + registry = BufferRegistry() + + registry.create("ooo", 3, "x", "t", "c", "w") + + # Results arrive out of order + registry.record_result("ooo", 2, "", True) + registry.record_result("ooo", 0, "", True) + registry.record_result("ooo", 1, "", True) + + state = registry.get("ooo") + + assert state.is_complete is True + ordered = state.get_ordered_results() + assert ordered[0].result == "" + assert ordered[1].result == "" + assert ordered[2].result == "" diff --git a/tests/test_pump_integration.py b/tests/test_pump_integration.py index 6e401f9..a44f3cf 100644 --- a/tests/test_pump_integration.py +++ b/tests/test_pump_integration.py @@ -66,7 +66,7 @@ class TestPumpBootstrap: pump = await bootstrap('config/organism.yaml') assert pump.config.name == "hello-world" - assert len(pump.routing_table) == 6 # 3 user listeners + 3 system (boot, todo, todo-complete) + assert len(pump.routing_table) == 8 # 3 user listeners + 5 system (boot, todo, todo-complete, sequence, buffer) assert "greeter.greeting" in 
pump.routing_table assert "shouter.greetingresponse" in pump.routing_table assert "response-handler.shoutedresponse" in pump.routing_table diff --git a/tests/test_sequence.py b/tests/test_sequence.py new file mode 100644 index 0000000..4895071 --- /dev/null +++ b/tests/test_sequence.py @@ -0,0 +1,464 @@ +""" +test_sequence.py — Tests for the Sequence orchestration primitives. + +Tests: +1. SequenceRegistry basic operations +2. SequenceStart handler +3. Step result handling +4. Error propagation +""" + +import pytest +import uuid + +from xml_pipeline.message_bus.sequence_registry import ( + SequenceRegistry, + SequenceState, + get_sequence_registry, + reset_sequence_registry, +) +from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse +from xml_pipeline.primitives.sequence import ( + SequenceStart, + SequenceComplete, + SequenceError, + handle_sequence_start, +) + + +class TestSequenceRegistry: + """Test SequenceRegistry basic operations.""" + + def test_create_sequence_state(self): + """create() should create a sequence state with correct fields.""" + registry = SequenceRegistry() + + state = registry.create( + sequence_id="seq001", + steps=["step1", "step2", "step3"], + return_to="caller", + thread_id="thread-123", + from_id="console", + initial_payload="", + ) + + assert state.sequence_id == "seq001" + assert state.steps == ["step1", "step2", "step3"] + assert state.return_to == "caller" + assert state.thread_id == "thread-123" + assert state.from_id == "console" + assert state.current_index == 0 + assert state.is_complete is False + assert state.current_step == "step1" + assert state.remaining_steps == ["step1", "step2", "step3"] + assert state.last_result == "" + + def test_get_returns_sequence(self): + """get() should return sequence by ID.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="seq002", + steps=["a", "b"], + return_to="x", + thread_id="t", + from_id="c", + ) + + state = registry.get("seq002") + 
assert state is not None + assert state.sequence_id == "seq002" + + # Non-existent returns None + assert registry.get("nonexistent") is None + + def test_advance_increments_index(self): + """advance() should increment index and store result.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="seq003", + steps=["a", "b", "c"], + return_to="x", + thread_id="t", + from_id="c", + ) + + # Advance first step + state = registry.advance("seq003", "") + assert state.current_index == 1 + assert state.results == [""] + assert state.last_result == "" + assert state.current_step == "b" + assert state.is_complete is False + + # Advance second step + state = registry.advance("seq003", "") + assert state.current_index == 2 + assert state.results == ["", ""] + assert state.current_step == "c" + + # Advance third step - now complete + state = registry.advance("seq003", "") + assert state.current_index == 3 + assert state.is_complete is True + assert state.current_step is None + assert state.remaining_steps == [] + + def test_mark_failed(self): + """mark_failed() should set failed state.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="seq004", + steps=["a", "b"], + return_to="x", + thread_id="t", + from_id="c", + ) + + state = registry.mark_failed("seq004", "a", "XSD validation failed") + + assert state.failed is True + assert state.failed_step == "a" + assert state.error == "XSD validation failed" + assert state.is_complete is False # Not complete, but failed + assert state.current_step is None # No current step when failed + assert state.remaining_steps == [] # No remaining steps when failed + + def test_remove_deletes_sequence(self): + """remove() should delete sequence from registry.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="seq005", + steps=["a"], + return_to="x", + thread_id="t", + from_id="c", + ) + + assert registry.get("seq005") is not None + result = registry.remove("seq005") + assert result is True + 
assert registry.get("seq005") is None + + # Remove non-existent returns False + assert registry.remove("nonexistent") is False + + def test_list_active(self): + """list_active() should return all active sequence IDs.""" + registry = SequenceRegistry() + + registry.create("seq-a", ["1"], "x", "t", "c") + registry.create("seq-b", ["2"], "x", "t", "c") + registry.create("seq-c", ["3"], "x", "t", "c") + + active = registry.list_active() + assert set(active) == {"seq-a", "seq-b", "seq-c"} + + def test_clear(self): + """clear() should remove all sequences.""" + registry = SequenceRegistry() + + registry.create("seq-1", ["a"], "x", "t", "c") + registry.create("seq-2", ["b"], "x", "t", "c") + + registry.clear() + + assert registry.list_active() == [] + + +class TestSequenceStateProperties: + """Test SequenceState computed properties.""" + + def test_is_complete_after_all_steps(self): + """is_complete should be True when all steps are done.""" + state = SequenceState( + sequence_id="test", + steps=["a", "b"], + return_to="x", + thread_id="t", + from_id="c", + current_index=2, # Past all steps + ) + assert state.is_complete is True + + def test_not_complete_when_failed(self): + """is_complete should be False when failed.""" + state = SequenceState( + sequence_id="test", + steps=["a", "b"], + return_to="x", + thread_id="t", + from_id="c", + current_index=2, + failed=True, + ) + assert state.is_complete is False + + def test_current_step_none_when_complete(self): + """current_step should be None when sequence is complete.""" + state = SequenceState( + sequence_id="test", + steps=["a"], + return_to="x", + thread_id="t", + from_id="c", + current_index=1, + ) + assert state.current_step is None + + def test_remaining_steps_empty_when_complete(self): + """remaining_steps should be empty when complete.""" + state = SequenceState( + sequence_id="test", + steps=["a", "b"], + return_to="x", + thread_id="t", + from_id="c", + current_index=2, + ) + assert state.remaining_steps == [] + + 
+class TestSequenceStartPayload: + """Test SequenceStart payload serialization.""" + + def test_sequence_start_fields(self): + """SequenceStart should have expected fields.""" + payload = SequenceStart( + steps="step1,step2", + payload="", + return_to="caller", + sequence_id="custom-id", + ) + + assert payload.steps == "step1,step2" + assert payload.payload == "" + assert payload.return_to == "caller" + assert payload.sequence_id == "custom-id" + + def test_sequence_start_default_values(self): + """SequenceStart should have sensible defaults.""" + payload = SequenceStart() + + assert payload.steps == "" + assert payload.payload == "" + assert payload.return_to == "" + assert payload.sequence_id == "" + + +class TestSequenceCompletePayload: + """Test SequenceComplete payload.""" + + def test_sequence_complete_fields(self): + """SequenceComplete should have expected fields.""" + payload = SequenceComplete( + sequence_id="seq123", + final_result="42", + step_count=3, + ) + + assert payload.sequence_id == "seq123" + assert payload.final_result == "42" + assert payload.step_count == 3 + + +class TestSequenceErrorPayload: + """Test SequenceError payload.""" + + def test_sequence_error_fields(self): + """SequenceError should have expected fields.""" + payload = SequenceError( + sequence_id="seq456", + failed_step="bad-step", + step_index=1, + error="Validation failed", + ) + + assert payload.sequence_id == "seq456" + assert payload.failed_step == "bad-step" + assert payload.step_index == 1 + assert payload.error == "Validation failed" + + +class TestHandleSequenceStartValidation: + """Test validation in handle_sequence_start.""" + + @pytest.fixture(autouse=True) + def setup(self): + """Reset registries before each test.""" + reset_sequence_registry() + + @pytest.mark.asyncio + async def test_empty_steps_returns_error(self): + """handle_sequence_start should return error for empty steps.""" + from unittest.mock import patch, MagicMock + + # Mock get_stream_pump to avoid 
dependency + mock_pump = MagicMock() + mock_pump.listeners = {} + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + payload = SequenceStart( + steps="", # Empty + payload="", + return_to="caller", + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_sequence_start(payload, metadata) + + assert isinstance(response, HandlerResponse) + assert isinstance(response.payload, SequenceError) + assert "No steps" in response.payload.error + + @pytest.mark.asyncio + async def test_unknown_step_returns_error(self): + """handle_sequence_start should return error for unknown step.""" + from unittest.mock import patch, MagicMock + + # Mock get_stream_pump with no listeners + mock_pump = MagicMock() + mock_pump.listeners = {} # No listeners registered + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + payload = SequenceStart( + steps="unknown_step", + payload="", + return_to="caller", + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_sequence_start(payload, metadata) + + assert isinstance(response, HandlerResponse) + assert isinstance(response.payload, SequenceError) + assert "Unknown listener" in response.payload.error + assert "unknown_step" in response.payload.failed_step + + +class TestSequenceRegistrySingleton: + """Test singleton pattern for SequenceRegistry.""" + + def test_get_sequence_registry_returns_singleton(self): + """get_sequence_registry should return same instance.""" + reset_sequence_registry() + + reg1 = get_sequence_registry() + reg2 = get_sequence_registry() + + assert reg1 is reg2 + + def test_reset_creates_new_instance(self): + """reset_sequence_registry should clear singleton.""" + reg1 = get_sequence_registry() + reg1.create("test", ["a"], "x", "t", "c") + + reset_sequence_registry() + reg2 = get_sequence_registry() + + assert 
reg2.get("test") is None + + +class TestSequenceMultipleSteps: + """Test sequences with multiple steps.""" + + def test_three_step_sequence(self): + """A three-step sequence should advance through all steps.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="multi", + steps=["add", "multiply", "format"], + return_to="caller", + thread_id="t", + from_id="c", + initial_payload="5", + ) + + # Step 1: add + state = registry.get("multi") + assert state.current_step == "add" + state = registry.advance("multi", "8") + assert state.last_result == "8" + + # Step 2: multiply + assert state.current_step == "multiply" + state = registry.advance("multi", "40") + + # Step 3: format + assert state.current_step == "format" + state = registry.advance("multi", "Result: 40") + + # Complete + assert state.is_complete is True + assert len(state.results) == 3 + assert state.results[0] == "8" + assert state.results[1] == "40" + assert state.results[2] == "Result: 40" + + def test_failure_at_middle_step(self): + """Failure at middle step should stop sequence.""" + registry = SequenceRegistry() + + registry.create( + sequence_id="fail-mid", + steps=["step1", "step2", "step3"], + return_to="caller", + thread_id="t", + from_id="c", + ) + + # Step 1 succeeds + registry.advance("fail-mid", "") + + # Step 2 fails + state = registry.mark_failed("fail-mid", "step2", "Connection timeout") + + assert state.failed is True + assert state.failed_step == "step2" + assert state.current_index == 1 # Was at step 2 (index 1) + assert len(state.results) == 1 # Only step 1 result + + +class TestSequenceWithRealSteps: + """Integration-style tests with mock handlers.""" + + @pytest.mark.asyncio + async def test_sequence_creates_ephemeral_listener(self): + """Starting a sequence should create an ephemeral listener.""" + from unittest.mock import patch, MagicMock, AsyncMock + + mock_pump = MagicMock() + mock_pump.listeners = {"step1": MagicMock(), "step2": MagicMock()} + 
mock_pump.register_generic_listener = MagicMock() + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + payload = SequenceStart( + steps="step1,step2", + payload="", + return_to="caller", + ) + + metadata = HandlerMetadata( + thread_id=str(uuid.uuid4()), + from_id="console", + ) + + response = await handle_sequence_start(payload, metadata) + + # Should have registered an ephemeral listener + mock_pump.register_generic_listener.assert_called_once() + call_args = mock_pump.register_generic_listener.call_args + name_arg = call_args.kwargs.get('name') or call_args.args[0] + assert name_arg.startswith("sequence_") diff --git a/xml_pipeline/message_bus/__init__.py b/xml_pipeline/message_bus/__init__.py index dfdc985..8aaffa8 100644 --- a/xml_pipeline/message_bus/__init__.py +++ b/xml_pipeline/message_bus/__init__.py @@ -30,6 +30,9 @@ from xml_pipeline.message_bus.stream_pump import ( ListenerConfig, OrganismConfig, bootstrap, + get_stream_pump, + set_stream_pump, + reset_stream_pump, ) from xml_pipeline.message_bus.message_state import ( @@ -42,15 +45,47 @@ from xml_pipeline.message_bus.system_pipeline import ( ExternalMessage, ) +from xml_pipeline.message_bus.sequence_registry import ( + SequenceState, + SequenceRegistry, + get_sequence_registry, + reset_sequence_registry, +) + +from xml_pipeline.message_bus.buffer_registry import ( + BufferState, + BufferItemResult, + BufferRegistry, + get_buffer_registry, + reset_buffer_registry, +) + __all__ = [ + # Pump "StreamPump", "ConfigLoader", "Listener", "ListenerConfig", "OrganismConfig", + "bootstrap", + "get_stream_pump", + "set_stream_pump", + "reset_stream_pump", + # Message state "MessageState", "HandlerMetadata", - "bootstrap", + # System pipeline "SystemPipeline", "ExternalMessage", + # Sequence registry + "SequenceState", + "SequenceRegistry", + "get_sequence_registry", + "reset_sequence_registry", + # Buffer registry + "BufferState", + "BufferItemResult", + 
"BufferRegistry", + "get_buffer_registry", + "reset_buffer_registry", ] diff --git a/xml_pipeline/message_bus/buffer_registry.py b/xml_pipeline/message_bus/buffer_registry.py new file mode 100644 index 0000000..9054cf2 --- /dev/null +++ b/xml_pipeline/message_bus/buffer_registry.py @@ -0,0 +1,230 @@ +""" +buffer_registry.py — State storage for Buffer (fan-out) orchestration. + +Tracks active buffer executions that fan-out to parallel workers. +When a buffer starts, N items are dispatched in parallel. Results are +collected here. When all results are in (or timeout), BufferComplete is sent. + +Design: +- Thread-safe (same pattern as TodoRegistry, SequenceRegistry) +- Keyed by buffer_id (short UUID) +- Tracks: total items, received results, success/failure per item +- Supports fire-and-forget mode (collect=False) + +Usage: + registry = get_buffer_registry() + + # Start a buffer + registry.create( + buffer_id="abc123", + total_items=5, + return_to="greeter", + thread_id="...", + collect=True, + ) + + # Record result from worker + state = registry.record_result( + buffer_id="abc123", + index=2, + result="...", + success=True, + ) + + if state.is_complete: + # All workers done + final_results = state.results + registry.remove(buffer_id) +""" + +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Any +import threading + + +@dataclass +class BufferItemResult: + """Result from a single buffer item (worker).""" + + index: int + result: str # XML result + success: bool = True + error: Optional[str] = None + + +@dataclass +class BufferState: + """State for an active buffer execution.""" + + buffer_id: str + total_items: int # How many items were dispatched + return_to: str # Where to send BufferComplete + thread_id: str # Original thread for returning + from_id: str # Who started the buffer + target: str # Target listener for items + collect: bool = True # Whether to wait for results + + results: Dict[int, BufferItemResult] = 
field(default_factory=dict) + completed_count: int = 0 + successful_count: int = 0 + + @property + def is_complete(self) -> bool: + """True when all items have reported back.""" + return self.completed_count >= self.total_items + + @property + def pending_count(self) -> int: + """Number of items still pending.""" + return self.total_items - self.completed_count + + def get_ordered_results(self) -> List[Optional[BufferItemResult]]: + """Get results in order (None for missing indices).""" + return [self.results.get(i) for i in range(self.total_items)] + + +class BufferRegistry: + """ + Registry for active buffer executions. + + Thread-safe. Singleton pattern via get_buffer_registry(). + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._buffers: Dict[str, BufferState] = {} + + def create( + self, + buffer_id: str, + total_items: int, + return_to: str, + thread_id: str, + from_id: str, + target: str, + collect: bool = True, + ) -> BufferState: + """ + Create a new buffer execution. + + Args: + buffer_id: Unique ID for this buffer + total_items: Number of items being dispatched + return_to: Listener to send BufferComplete to + thread_id: Thread UUID for routing + from_id: Who initiated the buffer + target: Target listener for each item + collect: Whether to wait for and collect results + + Returns: + BufferState for tracking + """ + state = BufferState( + buffer_id=buffer_id, + total_items=total_items, + return_to=return_to, + thread_id=thread_id, + from_id=from_id, + target=target, + collect=collect, + ) + + with self._lock: + self._buffers[buffer_id] = state + + return state + + def get(self, buffer_id: str) -> Optional[BufferState]: + """Get buffer state by ID.""" + with self._lock: + return self._buffers.get(buffer_id) + + def record_result( + self, + buffer_id: str, + index: int, + result: str, + success: bool = True, + error: Optional[str] = None, + ) -> Optional[BufferState]: + """ + Record a result from a worker. 
+ + Args: + buffer_id: Buffer this result belongs to + index: Which item index (0-based) + result: XML result from the worker + success: Whether the worker succeeded + error: Error message if failed + + Returns: + Updated BufferState, or None if buffer not found + """ + with self._lock: + state = self._buffers.get(buffer_id) + if state is None: + return None + + # Don't double-count results for same index + if index in state.results: + return state + + item_result = BufferItemResult( + index=index, + result=result, + success=success, + error=error, + ) + state.results[index] = item_result + state.completed_count += 1 + if success: + state.successful_count += 1 + + return state + + def remove(self, buffer_id: str) -> bool: + """ + Remove a buffer (cleanup after completion). + + Returns: + True if found and removed, False if not found + """ + with self._lock: + return self._buffers.pop(buffer_id, None) is not None + + def list_active(self) -> List[str]: + """List all active buffer IDs.""" + with self._lock: + return list(self._buffers.keys()) + + def clear(self) -> None: + """Clear all buffers. 
Useful for testing.""" + with self._lock: + self._buffers.clear() + + +# ============================================================================ +# Singleton +# ============================================================================ + +_registry: Optional[BufferRegistry] = None +_registry_lock = threading.Lock() + + +def get_buffer_registry() -> BufferRegistry: + """Get the global BufferRegistry singleton.""" + global _registry + if _registry is None: + with _registry_lock: + if _registry is None: + _registry = BufferRegistry() + return _registry + + +def reset_buffer_registry() -> None: + """Reset the global buffer registry (for testing).""" + global _registry + with _registry_lock: + if _registry is not None: + _registry.clear() + _registry = None diff --git a/xml_pipeline/message_bus/sequence_registry.py b/xml_pipeline/message_bus/sequence_registry.py new file mode 100644 index 0000000..1fb36a3 --- /dev/null +++ b/xml_pipeline/message_bus/sequence_registry.py @@ -0,0 +1,228 @@ +""" +sequence_registry.py — State storage for Sequence orchestration. + +Tracks active sequence executions across handler invocations. +When a sequence starts, its state is registered here. As steps complete, +the state is updated. When all steps are done, the state is cleaned up. 
"""
sequence_registry.py — State storage for Sequence orchestration.

Holds the state of sequences that are in flight across handler
invocations: the ordered step list, the cursor position, the collected
per-step results, and failure information. Thread-safe (lock-guarded
dict), keyed by a short sequence_id, with a process-wide singleton
accessor. Completed or failed sequences are removed by their owners.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Optional
import threading


@dataclass
class SequenceState:
    """Mutable tracking record for one active sequence execution."""

    sequence_id: str
    steps: List[str]        # ordered listener names to call
    return_to: str          # listener that receives the final result
    thread_id: str          # originating thread, for routing the reply
    from_id: str            # who started the sequence

    current_index: int = 0                              # 0-based cursor
    results: List[str] = field(default_factory=list)    # XML result per step
    last_result: Optional[str] = None                   # latest result, for chaining

    failed: bool = False
    failed_step: Optional[str] = None
    error: Optional[str] = None

    @property
    def is_complete(self) -> bool:
        """True when every step ran and nothing failed."""
        return (not self.failed) and self.current_index >= len(self.steps)

    @property
    def current_step(self) -> Optional[str]:
        """Name of the step about to run; None once finished or failed."""
        if self.failed or self.current_index >= len(self.steps):
            return None
        return self.steps[self.current_index]

    @property
    def remaining_steps(self) -> List[str]:
        """Steps that have not executed yet (empty after a failure)."""
        return [] if self.failed else self.steps[self.current_index:]


class SequenceRegistry:
    """
    Thread-safe registry of active sequence executions.

    Access the shared instance through get_sequence_registry().
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._sequences: Dict[str, SequenceState] = {}

    def create(
        self,
        sequence_id: str,
        steps: List[str],
        return_to: str,
        thread_id: str,
        from_id: str,
        initial_payload: str = "",
    ) -> SequenceState:
        """
        Register a new sequence execution and return its tracking state.

        Args:
            sequence_id: Unique ID for this sequence.
            steps: Ordered list of listener names to call.
            return_to: Listener to send SequenceComplete to.
            thread_id: Thread UUID for routing.
            from_id: Who initiated the sequence.
            initial_payload: XML payload for the first step (empty => None).
        """
        fresh = SequenceState(
            sequence_id=sequence_id,
            steps=steps,
            return_to=return_to,
            thread_id=thread_id,
            from_id=from_id,
            last_result=initial_payload if initial_payload else None,
        )
        with self._lock:
            self._sequences[sequence_id] = fresh
        return fresh

    def get(self, sequence_id: str) -> Optional[SequenceState]:
        """Look up a sequence by ID; None when unknown."""
        with self._lock:
            return self._sequences.get(sequence_id)

    def advance(self, sequence_id: str, step_result: str) -> Optional[SequenceState]:
        """
        Record a completed step's result and move the cursor forward.

        A failed sequence is returned untouched; an unknown ID yields None.
        """
        with self._lock:
            state = self._sequences.get(sequence_id)
            if state is None or state.failed:
                return state

            state.results.append(step_result)
            state.last_result = step_result
            state.current_index += 1
            return state

    def mark_failed(
        self,
        sequence_id: str,
        step: str,
        error: str,
    ) -> Optional[SequenceState]:
        """
        Flag a sequence as failed at the given step.

        Returns:
            The updated SequenceState, or None if not found.
        """
        with self._lock:
            state = self._sequences.get(sequence_id)
            if state is None:
                return None

            state.failed = True
            state.failed_step = step
            state.error = error
            return state

    def remove(self, sequence_id: str) -> bool:
        """Drop a sequence (post-completion cleanup); True if it existed."""
        with self._lock:
            return self._sequences.pop(sequence_id, None) is not None

    def list_active(self) -> List[str]:
        """IDs of all sequences currently tracked."""
        with self._lock:
            return [*self._sequences]

    def clear(self) -> None:
        """Forget every sequence. Useful for testing."""
        with self._lock:
            self._sequences.clear()


# ============================================================================
# Singleton
# ============================================================================

_registry: Optional[SequenceRegistry] = None
_registry_lock = threading.Lock()


def get_sequence_registry() -> SequenceRegistry:
    """Return the global SequenceRegistry, creating it on first use."""
    global _registry
    if _registry is None:
        with _registry_lock:
            if _registry is None:
                _registry = SequenceRegistry()
    return _registry


def reset_sequence_registry() -> None:
    """Tear down the global sequence registry (for testing)."""
    global _registry
    with _registry_lock:
        if _registry is not None:
            _registry.clear()
        _registry = None
orchestration handlers (sequences, buffers) + self._generic_listeners: Dict[str, Listener] = {} + # Per-agent semaphores for rate limiting self.agent_semaphores: Dict[str, asyncio.Semaphore] = {} @@ -269,6 +273,82 @@ class StreamPump: self.listeners[lc.name] = listener return listener + def register_generic_listener( + self, + name: str, + handler: Callable, + description: str = "", + ) -> Listener: + """ + Register a generic listener that accepts any payload type. + + Used for ephemeral orchestration handlers (sequences, buffers) + that need to receive responses from various step types. + + Generic listeners: + - Are NOT added to the routing table (no root_tag) + - Are looked up by name (to_id) as a fallback in routing + - Receive payload_tree directly (no XSD validation/deserialization) + + Args: + name: Unique listener name (e.g., "sequence_abc123") + handler: Async handler function (receives payload_tree, metadata) + description: Human-readable description + + Returns: + Listener object + """ + listener = Listener( + name=name, + payload_class=object, # Placeholder - not used + handler=handler, + description=description, + is_agent=False, + root_tag="*", # Wildcard marker + ) + + self._generic_listeners[name.lower()] = listener + self.listeners[name] = listener + + pump_logger.debug(f"Registered generic listener: {name}") + return listener + + def unregister_listener(self, name: str) -> bool: + """ + Remove a listener by name. + + Used to clean up ephemeral listeners after orchestration completes. 
+ + Args: + name: Listener name to remove + + Returns: + True if found and removed, False if not found + """ + name_lower = name.lower() + removed = False + + # Remove from generic listeners + if name_lower in self._generic_listeners: + del self._generic_listeners[name_lower] + removed = True + pump_logger.debug(f"Unregistered generic listener: {name}") + + # Remove from main listeners dict + if name in self.listeners: + listener = self.listeners.pop(name) + removed = True + + # Remove from routing table + if listener.root_tag and listener.root_tag != "*": + listeners_for_tag = self.routing_table.get(listener.root_tag, []) + if listener in listeners_for_tag: + listeners_for_tag.remove(listener) + if not listeners_for_tag: + del self.routing_table[listener.root_tag] + + return removed + def register_all(self) -> None: # First pass: register all listeners for lc in self.config.listeners: @@ -781,6 +861,8 @@ class StreamPump: Combined validation + deserialization. Uses to_id + payload tag to find the right listener and schema. + Falls back to generic listeners (ephemeral orchestration handlers) + when no regular listener matches. 
""" if state.error or state.payload_tree is None: return state @@ -794,6 +876,19 @@ class StreamPump: lookup_key = f"{to_id}.{payload_tag.lower()}" if to_id else payload_tag.lower() listeners = self.routing_table.get(lookup_key, []) + + # Fallback: check for generic listener by to_id + # Generic listeners accept any payload type (for orchestration) + if not listeners and to_id: + generic_listener = self._generic_listeners.get(to_id) + if generic_listener: + # Generic listener: skip XSD validation and deserialization + # Pass the raw payload_tree to the handler + state.payload = state.payload_tree # Handler receives Element + state.target_listeners = [generic_listener] + state.metadata["generic_handler"] = True + return state + if not listeners: state.error = f"No listener for: {lookup_key}" return state @@ -1008,6 +1103,36 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump: ) pump.register_listener(todo_complete_config) + # Register Sequence primitives (orchestration) + from xml_pipeline.primitives.sequence import ( + SequenceStart, handle_sequence_start, + ) + sequence_config = ListenerConfig( + name="system.sequence", + payload_class_path="xml_pipeline.primitives.sequence.SequenceStart", + handler_path="xml_pipeline.primitives.sequence.handle_sequence_start", + description="System sequence handler - chains listeners in order", + is_agent=False, + payload_class=SequenceStart, + handler=handle_sequence_start, + ) + pump.register_listener(sequence_config) + + # Register Buffer primitives (fan-out orchestration) + from xml_pipeline.primitives.buffer import ( + BufferStart, handle_buffer_start, + ) + buffer_config = ListenerConfig( + name="system.buffer", + payload_class_path="xml_pipeline.primitives.buffer.BufferStart", + handler_path="xml_pipeline.primitives.buffer.handle_buffer_start", + description="System buffer handler - fan-out to parallel workers", + is_agent=False, + payload_class=BufferStart, + handler=handle_buffer_start, + ) + 
pump.register_listener(buffer_config) + # Register all user-defined listeners pump.register_all() @@ -1061,6 +1186,9 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump: # Inject boot message (will be processed when pump.run() is called) await pump.inject(boot_envelope, thread_id=root_uuid, from_id="system") + # Set global pump instance for get_stream_pump() + set_stream_pump(pump) + print(f"Routing: {list(pump.routing_table.keys())}") return pump @@ -1110,3 +1238,45 @@ The key difference: - Old: 3 tool calls = 3 sequential awaits, each blocking until complete - New: 3 tool calls = 3 items in stream, processed concurrently up to task_limit """ + + +# ============================================================================ +# Global Singleton +# ============================================================================ + +_pump: Optional[StreamPump] = None + + +def get_stream_pump() -> StreamPump: + """ + Get the global StreamPump instance. + + The pump is initialized via bootstrap() and set here. + Raises RuntimeError if called before bootstrap. + """ + global _pump + if _pump is None: + raise RuntimeError( + "StreamPump not initialized. Call bootstrap() first." + ) + return _pump + + +def set_stream_pump(pump: StreamPump) -> None: + """ + Set the global StreamPump instance. + + Called by bootstrap() after creating the pump. + """ + global _pump + _pump = pump + + +def reset_stream_pump() -> None: + """ + Reset the global StreamPump instance. + + Useful for testing. 
+ """ + global _pump + _pump = None diff --git a/xml_pipeline/primitives/__init__.py b/xml_pipeline/primitives/__init__.py index aaecc0f..a130ce0 100644 --- a/xml_pipeline/primitives/__init__.py +++ b/xml_pipeline/primitives/__init__.py @@ -15,16 +15,45 @@ from xml_pipeline.primitives.todo import ( handle_todo_complete, ) from xml_pipeline.primitives.text_input import TextInput, TextOutput +from xml_pipeline.primitives.sequence import ( + SequenceStart, + SequenceComplete, + SequenceError, + handle_sequence_start, +) +from xml_pipeline.primitives.buffer import ( + BufferStart, + BufferItem, + BufferComplete, + BufferDispatched, + BufferError, + handle_buffer_start, +) __all__ = [ + # Boot "Boot", "handle_boot", + # Todo "TodoUntil", "TodoComplete", "TodoRegistered", "TodoClosed", "handle_todo_until", "handle_todo_complete", + # Text I/O "TextInput", "TextOutput", + # Sequence orchestration + "SequenceStart", + "SequenceComplete", + "SequenceError", + "handle_sequence_start", + # Buffer orchestration + "BufferStart", + "BufferItem", + "BufferComplete", + "BufferDispatched", + "BufferError", + "handle_buffer_start", ] diff --git a/xml_pipeline/primitives/buffer.py b/xml_pipeline/primitives/buffer.py new file mode 100644 index 0000000..65603a8 --- /dev/null +++ b/xml_pipeline/primitives/buffer.py @@ -0,0 +1,373 @@ +""" +buffer.py — Buffer (fan-out) orchestration primitives. + +Buffers fan-out to parallel workers, sending N items to the same listener +concurrently. Results are collected and returned when all complete. + +Usage by an agent: + # Fan-out search queries to web_search + return HandlerResponse( + payload=BufferStart( + target="web_search", + items="python async\\nrust memory\\ngo concurrency", + return_to="my-agent", + collect=True, + ), + to="system.buffer", + ) + +Flow: + 1. system.buffer receives BufferStart with N items + 2. Creates ephemeral listener buffer_{id} to receive results + 3. Creates N sibling threads via ThreadRegistry + 4. 
from dataclasses import dataclass
from typing import Optional
import uuid as uuid_module
import logging

from lxml import etree
from third_party.xmlable import xmlify
from xml_pipeline.message_bus.message_state import (
    HandlerMetadata,
    HandlerResponse,
)
from xml_pipeline.message_bus.buffer_registry import get_buffer_registry
from xml_pipeline.message_bus.thread_registry import get_registry as get_thread_registry

logger = logging.getLogger(__name__)


# ============================================================================
# Payloads
# ============================================================================

@xmlify
@dataclass
class BufferStart:
    """
    Start a new buffer (fan-out) execution.

    Sent to system.buffer to begin parallel processing.
    """
    target: str = ""        # Listener to fan-out to
    items: str = ""         # Newline-separated payloads (raw XML)
    collect: bool = True    # Wait for all results?
    return_to: str = ""     # Where to send BufferComplete
    buffer_id: str = ""     # Auto-generated if empty


@xmlify
@dataclass
class BufferItem:
    """
    Individual item being processed by a worker.

    Wraps the actual payload with buffer metadata. Internal type —
    workers receive the raw payload, not BufferItem directly.
    """
    buffer_id: str = ""
    index: int = 0
    payload: str = ""   # The actual XML payload


@xmlify
@dataclass
class BufferComplete:
    """
    Buffer completed — all workers finished.

    Sent to return_to when all items are processed.
    """
    buffer_id: str = ""
    total: int = 0
    successful: int = 0
    results: str = ""   # XML array of results


@xmlify
@dataclass
class BufferDispatched:
    """
    Buffer dispatched (fire-and-forget mode).

    Sent immediately after items are dispatched when collect=False.
    """
    buffer_id: str = ""
    total: int = 0


@xmlify
@dataclass
class BufferError:
    """
    Buffer failed to start.

    Sent when buffer initialization fails.
    """
    buffer_id: str = ""
    error: str = ""


# ============================================================================
# Handlers
# ============================================================================

async def handle_buffer_start(
    payload: BufferStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle BufferStart — begin a fan-out execution.

    Validates the request, creates one sibling thread per item, and
    dispatches the items to the target listener in parallel. In collect
    mode an ephemeral listener gathers results and later emits
    BufferComplete; in fire-and-forget mode BufferDispatched is returned
    immediately.

    Returns:
        BufferError on bad input, BufferDispatched when collect=False,
        or None when collect=True (the ephemeral listener replies).
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse items (newline-separated raw XML payloads)
    items = [item.strip() for item in payload.items.split("\n") if item.strip()]
    if not items:
        logger.error("BufferStart with no items")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error="No items specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Validate target exists before dispatching anything
    pump = get_stream_pump()
    if payload.target not in pump.listeners:
        logger.error(f"BufferStart: unknown target '{payload.target}'")
        return HandlerResponse(
            payload=BufferError(
                buffer_id=payload.buffer_id or "unknown",
                error=f"Unknown target listener: {payload.target}",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate buffer ID if not provided
    buf_id = payload.buffer_id or str(uuid_module.uuid4())[:8]
    ephemeral_name = f"buffer_{buf_id}"

    if payload.collect:
        # BUGFIX: only track state in collect mode. Fire-and-forget
        # buffers have no collector to call registry.remove(), so
        # registering them would leak BufferState entries forever.
        buffer_registry = get_buffer_registry()
        buffer_registry.create(
            buffer_id=buf_id,
            total_items=len(items),
            return_to=payload.return_to or metadata.from_id,
            thread_id=metadata.thread_id,
            from_id=metadata.from_id,
            target=payload.target,
            collect=payload.collect,
        )

        # Create ephemeral handler for result collection
        async def buffer_handler(
            payload_tree: etree._Element,
            meta: HandlerMetadata,
        ) -> Optional[HandlerResponse]:
            """Ephemeral handler that collects worker results."""
            return await _handle_buffer_result(buf_id, payload_tree, meta)

        # Register ephemeral listener (generic mode - accepts any payload)
        pump.register_generic_listener(
            name=ephemeral_name,
            handler=buffer_handler,
            description=f"Ephemeral buffer handler for {buf_id}",
        )

    logger.info(
        f"Buffer {buf_id} starting: {len(items)} items to {payload.target}, "
        f"collect={payload.collect}"
    )

    # Dispatch all items in parallel, each on its own sibling thread so
    # worker responses route back to the ephemeral listener.
    thread_registry = get_thread_registry()
    parent_chain = thread_registry.lookup(metadata.thread_id) or metadata.thread_id

    for i, item_payload in enumerate(items):
        # Create sibling thread for this worker
        worker_chain = f"{parent_chain}.{ephemeral_name}_w{i}"
        worker_uuid = thread_registry.get_or_create(worker_chain)

        # The item is sent FROM the ephemeral listener so .respond() comes back
        await _inject_buffer_item(
            pump=pump,
            target=payload.target,
            payload_xml=item_payload,
            thread_id=worker_uuid,
            from_id=ephemeral_name,
        )

        logger.debug(f"Buffer {buf_id}: dispatched item {i} to {payload.target}")

    # Fire-and-forget: return immediately (nothing was registered above)
    if not payload.collect:
        logger.info(f"Buffer {buf_id}: fire-and-forget mode, {len(items)} items dispatched")
        return HandlerResponse(
            payload=BufferDispatched(
                buffer_id=buf_id,
                total=len(items),
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Collect mode: the ephemeral listener will send BufferComplete
    return None


async def _handle_buffer_result(
    buf_id: str,
    payload_tree: etree._Element,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle a worker result in the buffer.

    Called by the ephemeral listener when a worker responds. Records the
    result by worker index and, once all items are in, unregisters the
    ephemeral listener and returns BufferComplete to the originator.
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    buffer_registry = get_buffer_registry()
    state = buffer_registry.get(buf_id)

    if state is None:
        logger.warning(f"Buffer {buf_id} not found in registry (result dropped)")
        return None

    # Extract worker index from thread chain
    # Chain format: parent.buffer_xyz_wN where N is the index
    thread_registry = get_thread_registry()
    chain = thread_registry.lookup(metadata.thread_id) or ""

    worker_index = _extract_worker_index(chain, buf_id)
    if worker_index is None:
        logger.warning(f"Buffer {buf_id}: could not determine worker index from chain")
        worker_index = state.completed_count  # Fallback to count-based

    # Serialize the result for storage
    result_xml = etree.tostring(payload_tree, encoding="unicode")

    # Error replies are detected by their root tag
    is_error = payload_tree.tag.lower() in ("huh", "systemerror")

    state = buffer_registry.record_result(
        buffer_id=buf_id,
        index=worker_index,
        result=result_xml,
        success=not is_error,
        error=result_xml[:200] if is_error else None,
    )

    logger.debug(
        f"Buffer {buf_id}: received result {state.completed_count}/{state.total_items}"
    )

    if state.is_complete:
        # All workers done: tear down the ephemeral listener and state
        pump = get_stream_pump()
        pump.unregister_listener(f"buffer_{buf_id}")

        results_xml = _format_buffer_results(state)
        buffer_registry.remove(buf_id)

        logger.info(
            f"Buffer {buf_id} completed: {state.successful_count}/{state.total_items} successful"
        )

        return HandlerResponse(
            payload=BufferComplete(
                buffer_id=buf_id,
                total=state.total_items,
                successful=state.successful_count,
                results=results_xml,
            ),
            to=state.return_to,
        )

    # More results pending
    return None
), + to=state.return_to, + ) + + # More results pending + return None + + +def _extract_worker_index(chain: str, buf_id: str) -> Optional[int]: + """Extract worker index from thread chain.""" + # Look for pattern: buffer_{id}_wN + import re + pattern = rf"buffer_{buf_id}_w(\d+)" + match = re.search(pattern, chain) + if match: + return int(match.group(1)) + return None + + +def _format_buffer_results(state) -> str: + """Format buffer results as XML array.""" + lines = [""] + for i in range(state.total_items): + result = state.results.get(i) + if result: + success = "true" if result.success else "false" + lines.append(f' ') + # Indent the result content + for line in result.result.split("\n"): + lines.append(f" {line}") + lines.append(" ") + else: + lines.append(f' missing') + lines.append("") + return "\n".join(lines) + + +async def _inject_buffer_item( + pump, + target: str, + payload_xml: str, + thread_id: str, + from_id: str, +) -> None: + """Inject a buffer item directly into the pump.""" + # Wrap the payload in an envelope + envelope = pump._wrap_in_envelope( + payload=_RawXmlPayload(payload_xml), + from_id=from_id, + to_id=target, + thread_id=thread_id, + ) + + # Inject into pump + await pump.inject(envelope, thread_id=thread_id, from_id=from_id) + + +class _RawXmlPayload: + """Carrier for raw XML that bypasses serialization.""" + + def __init__(self, xml: str): + self.xml = xml + + def to_xml(self) -> str: + """Return raw XML for envelope wrapping.""" + return self.xml diff --git a/xml_pipeline/primitives/sequence.py b/xml_pipeline/primitives/sequence.py new file mode 100644 index 0000000..01afa50 --- /dev/null +++ b/xml_pipeline/primitives/sequence.py @@ -0,0 +1,299 @@ +""" +sequence.py — Sequence orchestration primitives. + +Sequences chain multiple listeners in order, feeding the output of one step +as input to the next. Steps remain transparent - they don't know they're +part of a sequence. 
from dataclasses import dataclass
from typing import Optional
import uuid as uuid_module
import logging

from lxml import etree
from third_party.xmlable import xmlify
from xml_pipeline.message_bus.message_state import (
    HandlerMetadata,
    HandlerResponse,
    MessageState,
)
from xml_pipeline.message_bus.sequence_registry import get_sequence_registry

logger = logging.getLogger(__name__)


# ============================================================================
# Payloads
# ============================================================================

@xmlify
@dataclass
class SequenceStart:
    """
    Start a new sequence execution.

    Sent to system.sequence to begin chaining steps.
    """
    steps: str = ""         # Comma-separated listener names
    payload: str = ""       # Initial XML payload for first step
    return_to: str = ""     # Where to send final result
    sequence_id: str = ""   # Auto-generated if empty


@xmlify
@dataclass
class SequenceComplete:
    """
    Sequence completed successfully.

    Sent to the return_to listener when all steps finish.
    """
    sequence_id: str = ""
    final_result: str = ""  # XML result from last step
    step_count: int = 0     # How many steps were executed


@xmlify
@dataclass
class SequenceError:
    """
    Sequence failed at a step.

    Sent to return_to when a step fails.
    """
    sequence_id: str = ""
    failed_step: str = ""   # Which step failed
    step_index: int = 0     # 0-based index of failed step
    error: str = ""         # Error message


# ============================================================================
# Handlers
# ============================================================================

async def handle_sequence_start(
    payload: SequenceStart,
    metadata: HandlerMetadata,
) -> Optional[HandlerResponse]:
    """
    Handle SequenceStart — begin a sequence execution.

    Validates the step list, registers sequence state and an ephemeral
    listener that will receive each step's response, then kicks off the
    first step.

    Returns:
        SequenceError on bad input, otherwise the message that starts
        the first step (sent FROM the ephemeral listener).
    """
    from xml_pipeline.message_bus.stream_pump import get_stream_pump

    # Parse and validate the comma-separated step list
    steps = [s.strip() for s in payload.steps.split(",") if s.strip()]
    if not steps:
        logger.error("SequenceStart with no steps")
        return HandlerResponse(
            payload=SequenceError(
                sequence_id=payload.sequence_id or "unknown",
                failed_step="",
                step_index=0,
                error="No steps specified",
            ),
            to=payload.return_to or metadata.from_id,
        )

    # Generate sequence ID if not provided
    seq_id = payload.sequence_id or str(uuid_module.uuid4())[:8]

    # Validate all steps exist before starting anything.
    # BUGFIX: use enumerate instead of steps.index(step) — index() reports
    # the FIRST occurrence, so a duplicated step name at a later position
    # would be reported with the wrong step_index (and rescans the list).
    pump = get_stream_pump()
    for idx, step in enumerate(steps):
        if step not in pump.listeners:
            logger.error(f"SequenceStart: unknown step '{step}'")
            return HandlerResponse(
                payload=SequenceError(
                    sequence_id=seq_id,
                    failed_step=step,
                    step_index=idx,
                    error=f"Unknown listener: {step}",
                ),
                to=payload.return_to or metadata.from_id,
            )

    # Create sequence state
    registry = get_sequence_registry()
    state = registry.create(
        sequence_id=seq_id,
        steps=steps,
        return_to=payload.return_to or metadata.from_id,
        thread_id=metadata.thread_id,
        from_id=metadata.from_id,
        initial_payload=payload.payload,
    )

    # Create ephemeral handler for this sequence
    ephemeral_name = f"sequence_{seq_id}"

    async def sequence_handler(
        payload_tree: etree._Element,
        meta: HandlerMetadata,
    ) -> Optional[HandlerResponse]:
        """Ephemeral handler that processes step results."""
        return await _handle_sequence_step_result(seq_id, payload_tree, meta)

    # Register ephemeral listener (generic mode - accepts any payload)
    pump.register_generic_listener(
        name=ephemeral_name,
        handler=sequence_handler,
        description=f"Ephemeral sequence handler for {seq_id}",
    )

    logger.info(
        f"Sequence {seq_id} started: {len(steps)} steps, "
        f"return_to={state.return_to}"
    )

    # Kick off first step (sent FROM the ephemeral listener so that the
    # step's .respond() routes back to it)
    first_step = steps[0]
    return _create_step_message(
        seq_id=seq_id,
        target=first_step,
        payload_xml=payload.payload,
        from_name=ephemeral_name,
    )
+ """ + from xml_pipeline.message_bus.stream_pump import get_stream_pump + + registry = get_sequence_registry() + state = registry.get(seq_id) + + if state is None: + logger.error(f"Sequence {seq_id} not found in registry") + return None + + # Serialize the result for storage + result_xml = etree.tostring(payload_tree, encoding="unicode") + + # Check for error responses + if payload_tree.tag.lower() in ("huh", "systemerror"): + # Step failed + error_text = payload_tree.text or etree.tostring(payload_tree, encoding="unicode") + registry.mark_failed(seq_id, state.current_step or "unknown", error_text) + + # Clean up and send error + pump = get_stream_pump() + pump.unregister_listener(f"sequence_{seq_id}") + registry.remove(seq_id) + + logger.warning(f"Sequence {seq_id} failed at step {state.current_index}") + return HandlerResponse( + payload=SequenceError( + sequence_id=seq_id, + failed_step=state.current_step or "unknown", + step_index=state.current_index, + error=error_text[:200], # Truncate long errors + ), + to=state.return_to, + ) + + # Advance to next step + state = registry.advance(seq_id, result_xml) + + if state.is_complete: + # All steps done - send completion + pump = get_stream_pump() + pump.unregister_listener(f"sequence_{seq_id}") + registry.remove(seq_id) + + logger.info(f"Sequence {seq_id} completed: {len(state.steps)} steps") + return HandlerResponse( + payload=SequenceComplete( + sequence_id=seq_id, + final_result=result_xml, + step_count=len(state.steps), + ), + to=state.return_to, + ) + + # More steps to go - send to next step + next_step = state.current_step + logger.debug( + f"Sequence {seq_id} advancing to step {state.current_index}: {next_step}" + ) + + return _create_step_message( + seq_id=seq_id, + target=next_step, + payload_xml=result_xml, + from_name=f"sequence_{seq_id}", + ) + + +def _create_step_message( + seq_id: str, + target: str, + payload_xml: str, + from_name: str, +) -> HandlerResponse: + """ + Create a HandlerResponse to send 
payload to a step. + + We need to inject the message with the ephemeral listener as the sender, + so that .respond() routes back to us. + """ + from xml_pipeline.primitives.sequence import _RawPayloadCarrier + + # Return a special carrier that tells the pump to: + # 1. Use the raw XML bytes directly + # 2. Set from_id to from_name (the ephemeral listener) + return HandlerResponse( + payload=_RawPayloadCarrier(xml=payload_xml, from_override=from_name), + to=target, + ) + + +class _RawPayloadCarrier: + """ + Internal carrier for raw XML that bypasses normal serialization. + + When the pump sees this, it uses the raw XML directly instead of + serializing a dataclass. + """ + + def __init__(self, xml: str, from_override: Optional[str] = None): + self.xml = xml + self.from_override = from_override + + def to_xml(self) -> str: + """Return raw XML for envelope wrapping.""" + return self.xml