"""
test_sequence.py — Tests for the Sequence orchestration primitives.

Covers:
    1. SequenceRegistry basic operations (create / get / advance /
       mark_failed / remove / list_active / clear)
    2. SequenceState computed properties
    3. SequenceStart / SequenceComplete / SequenceError payload fields
    4. handle_sequence_start validation and ephemeral-listener registration
    5. The module-level registry singleton helpers
"""

import uuid
from unittest.mock import MagicMock, patch

import pytest

from xml_pipeline.message_bus.sequence_registry import (
    SequenceRegistry,
    SequenceState,
    get_sequence_registry,
    reset_sequence_registry,
)
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
from xml_pipeline.primitives.sequence import (
    SequenceStart,
    SequenceComplete,
    SequenceError,
    handle_sequence_start,
)

# Patch target for the stream-pump accessor, patched at its defining module.
# NOTE(review): this only takes effect if handle_sequence_start looks
# get_stream_pump up from that module at call time — confirm in the handler.
_STREAM_PUMP_TARGET = "xml_pipeline.message_bus.stream_pump.get_stream_pump"


@pytest.fixture(autouse=True)
def _isolate_sequence_registry():
    """Reset the global sequence registry before and after every test.

    Some tests create sequences through get_sequence_registry(), and
    handle_sequence_start presumably uses the same global registry. Resetting
    on both sides keeps one test's sequences from leaking into the next.
    """
    reset_sequence_registry()
    yield
    reset_sequence_registry()


def _make_metadata() -> HandlerMetadata:
    """Build handler metadata with a fresh thread id, sent from "console"."""
    return HandlerMetadata(thread_id=str(uuid.uuid4()), from_id="console")


def _make_pump(listener_names=()) -> MagicMock:
    """Return a mock stream pump whose ``listeners`` dict holds *listener_names*."""
    pump = MagicMock()
    pump.listeners = {name: MagicMock() for name in listener_names}
    return pump


class TestSequenceRegistry:
    """Test SequenceRegistry basic operations."""

    def test_create_sequence_state(self):
        """create() should create a sequence state with correct fields."""
        registry = SequenceRegistry()
        state = registry.create(
            sequence_id="seq001",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            initial_payload="",
        )

        assert state.sequence_id == "seq001"
        assert state.steps == ["step1", "step2", "step3"]
        assert state.return_to == "caller"
        assert state.thread_id == "thread-123"
        assert state.from_id == "console"
        assert state.current_index == 0
        assert state.is_complete is False
        assert state.current_step == "step1"
        assert state.remaining_steps == ["step1", "step2", "step3"]
        assert state.last_result == ""

    def test_get_returns_sequence(self):
        """get() should return sequence by ID."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq002",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        state = registry.get("seq002")
        assert state is not None
        assert state.sequence_id == "seq002"

        # Non-existent returns None
        assert registry.get("nonexistent") is None

    def test_advance_increments_index(self):
        """advance() should increment index and store result."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq003",
            steps=["a", "b", "c"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        # Distinct result values so the assertions can tell which value was
        # stored (an empty string would be indistinguishable from a default).
        state = registry.advance("seq003", "result-a")
        assert state.current_index == 1
        assert state.results == ["result-a"]
        assert state.last_result == "result-a"
        assert state.current_step == "b"
        assert state.is_complete is False

        # Advance second step
        state = registry.advance("seq003", "result-b")
        assert state.current_index == 2
        assert state.results == ["result-a", "result-b"]
        assert state.last_result == "result-b"
        assert state.current_step == "c"

        # Advance third step - now complete
        state = registry.advance("seq003", "result-c")
        assert state.current_index == 3
        assert state.results == ["result-a", "result-b", "result-c"]
        assert state.is_complete is True
        assert state.current_step is None
        assert state.remaining_steps == []

    def test_mark_failed(self):
        """mark_failed() should set failed state."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq004",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        state = registry.mark_failed("seq004", "a", "XSD validation failed")
        assert state.failed is True
        assert state.failed_step == "a"
        assert state.error == "XSD validation failed"
        assert state.is_complete is False  # Not complete, but failed
        assert state.current_step is None  # No current step when failed
        assert state.remaining_steps == []  # No remaining steps when failed

    def test_remove_deletes_sequence(self):
        """remove() should delete sequence from registry."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="seq005",
            steps=["a"],
            return_to="x",
            thread_id="t",
            from_id="c",
        )

        assert registry.get("seq005") is not None
        assert registry.remove("seq005") is True
        assert registry.get("seq005") is None

        # Remove non-existent returns False
        assert registry.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() should return all active sequence IDs."""
        registry = SequenceRegistry()
        registry.create("seq-a", ["1"], "x", "t", "c")
        registry.create("seq-b", ["2"], "x", "t", "c")
        registry.create("seq-c", ["3"], "x", "t", "c")

        active = registry.list_active()
        assert set(active) == {"seq-a", "seq-b", "seq-c"}

    def test_clear(self):
        """clear() should remove all sequences."""
        registry = SequenceRegistry()
        registry.create("seq-1", ["a"], "x", "t", "c")
        registry.create("seq-2", ["b"], "x", "t", "c")

        registry.clear()
        assert registry.list_active() == []


class TestSequenceStateProperties:
    """Test SequenceState computed properties."""

    def test_is_complete_after_all_steps(self):
        """is_complete should be True when all steps are done."""
        state = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,  # Past all steps
        )
        assert state.is_complete is True

    def test_not_complete_when_failed(self):
        """is_complete should be False when failed."""
        state = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,
            failed=True,
        )
        assert state.is_complete is False

    def test_current_step_none_when_complete(self):
        """current_step should be None when sequence is complete."""
        state = SequenceState(
            sequence_id="test",
            steps=["a"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=1,
        )
        assert state.current_step is None

    def test_remaining_steps_empty_when_complete(self):
        """remaining_steps should be empty when complete."""
        state = SequenceState(
            sequence_id="test",
            steps=["a", "b"],
            return_to="x",
            thread_id="t",
            from_id="c",
            current_index=2,
        )
        assert state.remaining_steps == []


class TestSequenceStartPayload:
    """Test SequenceStart payload serialization."""

    def test_sequence_start_fields(self):
        """SequenceStart should have expected fields."""
        # Non-empty payload so the assertion distinguishes "set" from default.
        payload = SequenceStart(
            steps="step1,step2",
            payload="payload-data",
            return_to="caller",
            sequence_id="custom-id",
        )
        assert payload.steps == "step1,step2"
        assert payload.payload == "payload-data"
        assert payload.return_to == "caller"
        assert payload.sequence_id == "custom-id"

    def test_sequence_start_default_values(self):
        """SequenceStart should have sensible defaults."""
        payload = SequenceStart()
        assert payload.steps == ""
        assert payload.payload == ""
        assert payload.return_to == ""
        assert payload.sequence_id == ""


class TestSequenceCompletePayload:
    """Test SequenceComplete payload."""

    def test_sequence_complete_fields(self):
        """SequenceComplete should have expected fields."""
        payload = SequenceComplete(
            sequence_id="seq123",
            final_result="42",
            step_count=3,
        )
        assert payload.sequence_id == "seq123"
        assert payload.final_result == "42"
        assert payload.step_count == 3


class TestSequenceErrorPayload:
    """Test SequenceError payload."""

    def test_sequence_error_fields(self):
        """SequenceError should have expected fields."""
        payload = SequenceError(
            sequence_id="seq456",
            failed_step="bad-step",
            step_index=1,
            error="Validation failed",
        )
        assert payload.sequence_id == "seq456"
        assert payload.failed_step == "bad-step"
        assert payload.step_index == 1
        assert payload.error == "Validation failed"


class TestHandleSequenceStartValidation:
    """Test validation in handle_sequence_start.

    The global registry is reset around each test by the module-level
    autouse fixture.
    """

    @pytest.mark.asyncio
    async def test_empty_steps_returns_error(self):
        """handle_sequence_start should return error for empty steps."""
        # Mocked pump keeps the test independent of a real stream pump.
        with patch(_STREAM_PUMP_TARGET, return_value=_make_pump()):
            payload = SequenceStart(
                steps="",  # Empty
                payload="",
                return_to="caller",
            )
            response = await handle_sequence_start(payload, _make_metadata())

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, SequenceError)
        assert "No steps" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_step_returns_error(self):
        """handle_sequence_start should return error for unknown step."""
        # Pump with no listeners registered, so every step name is unknown.
        with patch(_STREAM_PUMP_TARGET, return_value=_make_pump()):
            payload = SequenceStart(
                steps="unknown_step",
                payload="",
                return_to="caller",
            )
            response = await handle_sequence_start(payload, _make_metadata())

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, SequenceError)
        assert "Unknown listener" in response.payload.error
        assert "unknown_step" in response.payload.failed_step


class TestSequenceRegistrySingleton:
    """Test singleton pattern for SequenceRegistry."""

    def test_get_sequence_registry_returns_singleton(self):
        """get_sequence_registry should return same instance."""
        reset_sequence_registry()
        reg1 = get_sequence_registry()
        reg2 = get_sequence_registry()
        assert reg1 is reg2

    def test_reset_creates_new_instance(self):
        """reset_sequence_registry should clear singleton."""
        reg1 = get_sequence_registry()
        reg1.create("test", ["a"], "x", "t", "c")

        reset_sequence_registry()
        reg2 = get_sequence_registry()

        assert reg2.get("test") is None


class TestSequenceMultipleSteps:
    """Test sequences with multiple steps."""

    def test_three_step_sequence(self):
        """A three-step sequence should advance through all steps."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="multi",
            steps=["add", "multiply", "format"],
            return_to="caller",
            thread_id="t",
            from_id="c",
            initial_payload="5",
        )

        # Step 1: add
        state = registry.get("multi")
        assert state.current_step == "add"
        state = registry.advance("multi", "8")
        assert state.last_result == "8"

        # Step 2: multiply
        assert state.current_step == "multiply"
        state = registry.advance("multi", "40")

        # Step 3: format
        assert state.current_step == "format"
        state = registry.advance("multi", "Result: 40")

        # Complete
        assert state.is_complete is True
        assert state.results == ["8", "40", "Result: 40"]

    def test_failure_at_middle_step(self):
        """Failure at middle step should stop sequence."""
        registry = SequenceRegistry()
        registry.create(
            sequence_id="fail-mid",
            steps=["step1", "step2", "step3"],
            return_to="caller",
            thread_id="t",
            from_id="c",
        )

        # Step 1 succeeds
        registry.advance("fail-mid", "step1-result")

        # Step 2 fails
        state = registry.mark_failed("fail-mid", "step2", "Connection timeout")

        assert state.failed is True
        assert state.failed_step == "step2"
        assert state.current_index == 1  # Was at step 2 (index 1)
        assert state.results == ["step1-result"]  # Only step 1 result


class TestSequenceWithRealSteps:
    """Integration-style tests with mock handlers."""

    @pytest.mark.asyncio
    async def test_sequence_creates_ephemeral_listener(self):
        """Starting a sequence should create an ephemeral listener."""
        mock_pump = _make_pump(["step1", "step2"])
        mock_pump.register_generic_listener = MagicMock()

        with patch(_STREAM_PUMP_TARGET, return_value=mock_pump):
            payload = SequenceStart(
                steps="step1,step2",
                payload="",
                return_to="caller",
            )
            await handle_sequence_start(payload, _make_metadata())

        # Should have registered an ephemeral listener
        mock_pump.register_generic_listener.assert_called_once()
        call_args = mock_pump.register_generic_listener.call_args
        # The listener name may be passed by keyword or as the first positional arg.
        name_arg = call_args.kwargs.get("name") or call_args.args[0]
        assert name_arg.startswith("sequence_")