Implement two virtual node patterns for message flow orchestration: - Sequence: Chains listeners in order (A→B→C), feeding each step's output as input to the next. Uses ephemeral listeners to intercept step results without modifying core pump behavior. - Buffer: Fan-out to parallel worker threads with optional result collection. Supports fire-and-forget mode (collect=False) for non-blocking dispatch. New files: - sequence_registry.py / buffer_registry.py: State tracking - sequence.py / buffer.py: Payloads and handlers - test_sequence.py / test_buffer.py: 52 new tests Pump additions: - register_generic_listener(): Accept any payload type - unregister_listener(): Cleanup ephemeral listeners - Global singleton accessors for pump instance Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
635 lines
20 KiB
Python
635 lines
20 KiB
Python
"""
test_buffer.py — Tests for the Buffer (fan-out) orchestration primitives.

Covers:
    1. BufferRegistry basic operations (create, get, record_result, remove,
       list_active, clear)
    2. BufferState computed properties (is_complete, pending_count,
       get_ordered_results)
    3. Payload types: BufferStart, BufferComplete, BufferDispatched, BufferError
    4. Helpers: _extract_worker_index and _format_buffer_results
    5. handle_buffer_start validation (empty items, unknown target)
    6. Singleton accessors get_buffer_registry / reset_buffer_registry
    7. Result-collection mode (collect=True) vs fire-and-forget (collect=False)
    8. Result collection with partial failures and out-of-order arrival
"""
|
|
|
|
import pytest
|
|
import uuid
|
|
|
|
from xml_pipeline.message_bus.buffer_registry import (
|
|
BufferRegistry,
|
|
BufferState,
|
|
BufferItemResult,
|
|
get_buffer_registry,
|
|
reset_buffer_registry,
|
|
)
|
|
from xml_pipeline.message_bus.message_state import HandlerMetadata, HandlerResponse
|
|
from xml_pipeline.primitives.buffer import (
|
|
BufferStart,
|
|
BufferComplete,
|
|
BufferDispatched,
|
|
BufferError,
|
|
handle_buffer_start,
|
|
_extract_worker_index,
|
|
_format_buffer_results,
|
|
)
|
|
|
|
|
|
class TestBufferRegistry:
    """Exercise the core create/get/record/remove API of BufferRegistry."""

    def test_create_buffer_state(self):
        """create() returns a state populated from its arguments with zeroed tallies."""
        reg = BufferRegistry()

        created = reg.create(
            buffer_id="buf001",
            total_items=5,
            return_to="caller",
            thread_id="thread-123",
            from_id="console",
            target="worker",
            collect=True,
        )

        expected_fields = {
            "buffer_id": "buf001",
            "total_items": 5,
            "return_to": "caller",
            "thread_id": "thread-123",
            "from_id": "console",
            "target": "worker",
            "completed_count": 0,
            "successful_count": 0,
            "pending_count": 5,
        }
        for attr, value in expected_fields.items():
            assert getattr(created, attr) == value, attr

        # Boolean flags are checked by identity, not mere truthiness.
        assert created.collect is True
        assert created.is_complete is False

    def test_get_returns_buffer(self):
        """get() finds a buffer by ID and yields None for an unknown ID."""
        reg = BufferRegistry()
        reg.create(
            buffer_id="buf002",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        found = reg.get("buf002")
        assert found is not None
        assert found.buffer_id == "buf002"

        # An ID that was never created is simply absent.
        assert reg.get("nonexistent") is None

    def test_record_result_stores_result(self):
        """Each record_result() call stores the outcome and updates the tallies."""
        reg = BufferRegistry()
        reg.create(
            buffer_id="buf003",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        # Index 0 succeeds.
        snap = reg.record_result(
            buffer_id="buf003",
            index=0,
            result="<Result0/>",
            success=True,
        )
        assert (snap.completed_count, snap.successful_count) == (1, 1)
        assert snap.is_complete is False
        assert 0 in snap.results
        first = snap.results[0]
        assert first.result == "<Result0/>"
        assert first.success is True

        # Index 1 fails: only the completion tally moves.
        snap = reg.record_result(
            buffer_id="buf003",
            index=1,
            result="<Error/>",
            success=False,
            error="Timeout",
        )
        assert (snap.completed_count, snap.successful_count) == (2, 1)
        second = snap.results[1]
        assert second.success is False
        assert second.error == "Timeout"

        # Index 2 succeeds and fills the buffer.
        snap = reg.record_result(
            buffer_id="buf003",
            index=2,
            result="<Result2/>",
            success=True,
        )
        assert (snap.completed_count, snap.successful_count) == (3, 2)
        assert snap.is_complete is True
        assert snap.pending_count == 0

    def test_record_result_ignores_duplicates(self):
        """Recording an index that was already recorded is a no-op."""
        reg = BufferRegistry()
        reg.create(
            buffer_id="buf004",
            total_items=2,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )

        first = reg.record_result("buf004", 0, "<R/>", True)
        assert first.completed_count == 1

        # A second record for index 0 must neither bump the tally nor overwrite.
        again = reg.record_result("buf004", 0, "<Duplicate/>", True)
        assert again.completed_count == 1
        assert again.results[0].result == "<R/>"

    def test_remove_deletes_buffer(self):
        """remove() drops a known buffer (True) and reports False for unknown IDs."""
        reg = BufferRegistry()
        reg.create(
            buffer_id="buf005",
            total_items=1,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
        )
        assert reg.get("buf005") is not None

        removed = reg.remove("buf005")
        assert removed is True
        assert reg.get("buf005") is None

        assert reg.remove("nonexistent") is False

    def test_list_active(self):
        """list_active() reports the ID of every buffer that has been created."""
        reg = BufferRegistry()
        for buf_id, count in (("buf-a", 1), ("buf-b", 2), ("buf-c", 3)):
            reg.create(buf_id, count, "x", "t", "c", "w")

        assert set(reg.list_active()) == {"buf-a", "buf-b", "buf-c"}

    def test_clear(self):
        """clear() empties the registry entirely."""
        reg = BufferRegistry()
        for buf_id, count in (("buf-1", 1), ("buf-2", 2)):
            reg.create(buf_id, count, "x", "t", "c", "w")

        reg.clear()

        assert reg.list_active() == []
|
|
|
|
|
|
class TestBufferStateProperties:
    """Test BufferState computed properties.

    Covers is_complete (both sides), pending_count and get_ordered_results.
    """

    def test_is_complete_when_all_received(self):
        """is_complete should be True once completed_count reaches total_items."""
        state = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            completed_count=3,
        )
        assert state.is_complete is True

    def test_is_not_complete_while_items_pending(self):
        """is_complete should stay False while fewer than total_items have completed."""
        state = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            completed_count=2,
        )
        assert state.is_complete is False
        # One item is still outstanding.
        assert state.pending_count == 1

    def test_pending_count(self):
        """pending_count should reflect remaining items (total - completed)."""
        state = BufferState(
            buffer_id="test",
            total_items=5,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            completed_count=2,
        )
        assert state.pending_count == 3

    def test_get_ordered_results(self):
        """get_ordered_results should return results by index, with None for gaps."""
        state = BufferState(
            buffer_id="test",
            total_items=3,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results={
                0: BufferItemResult(0, "<R0/>", True),
                2: BufferItemResult(2, "<R2/>", True),
                # Index 1 missing
            },
        )

        ordered = state.get_ordered_results()
        assert len(ordered) == 3
        assert ordered[0] is not None
        assert ordered[0].result == "<R0/>"
        assert ordered[1] is None  # Missing
        assert ordered[2] is not None
        assert ordered[2].result == "<R2/>"
|
|
|
|
|
|
class TestBufferStartPayload:
    """Checks for the BufferStart request payload."""

    def test_buffer_start_fields(self):
        """Explicit constructor arguments are exposed unchanged as attributes."""
        start = BufferStart(
            target="worker",
            items="item1\nitem2\nitem3",
            collect=True,
            return_to="caller",
            buffer_id="custom-id",
        )

        observed = (start.target, start.items, start.return_to, start.buffer_id)
        assert observed == ("worker", "item1\nitem2\nitem3", "caller", "custom-id")
        assert start.collect is True

    def test_buffer_start_default_values(self):
        """A bare BufferStart() has empty strings for every text field and collect on."""
        start = BufferStart()

        for attr in ("target", "items", "return_to", "buffer_id"):
            assert getattr(start, attr) == "", attr
        assert start.collect is True
|
|
|
|
|
|
class TestBufferCompletePayload:
    """Checks for the BufferComplete result payload."""

    def test_buffer_complete_fields(self):
        """Constructor arguments round-trip through the attributes."""
        kwargs = {
            "buffer_id": "buf123",
            "total": 5,
            "successful": 4,
            "results": "<results>...</results>",
        }
        complete = BufferComplete(**kwargs)

        for name, expected in kwargs.items():
            assert getattr(complete, name) == expected, name
|
|
|
|
|
|
class TestBufferDispatchedPayload:
    """Checks for BufferDispatched, the reply used in fire-and-forget mode."""

    def test_buffer_dispatched_fields(self):
        """buffer_id and total are stored exactly as given."""
        dispatched = BufferDispatched(buffer_id="buf456", total=10)

        assert (dispatched.buffer_id, dispatched.total) == ("buf456", 10)
|
|
|
|
|
|
class TestBufferErrorPayload:
    """Checks for the BufferError failure payload."""

    def test_buffer_error_fields(self):
        """buffer_id and the error message are stored exactly as given."""
        err = BufferError(buffer_id="buf789", error="Unknown target listener")

        assert (err.buffer_id, err.error) == ("buf789", "Unknown target listener")
|
|
|
|
|
|
class TestExtractWorkerIndex:
    """Checks for the _extract_worker_index helper.

    The fixtures use thread chains containing a ``buffer_<id>_w<N>`` segment.
    """

    def test_extracts_index_from_chain(self):
        """The worker number N is pulled out of a matching segment."""
        assert _extract_worker_index("root.parent.buffer_abc123_w5", "abc123") == 5

    def test_extracts_double_digit_index(self):
        """Multi-digit worker numbers are parsed in full."""
        assert _extract_worker_index("x.buffer_xyz_w42", "xyz") == 42

    def test_returns_none_for_no_match(self):
        """A chain with no buffer segment at all yields None."""
        assert _extract_worker_index("something.else.entirely", "abc") is None

    def test_returns_none_for_wrong_buffer_id(self):
        """A buffer segment belonging to a different buffer ID yields None."""
        assert _extract_worker_index("root.buffer_other_w3", "abc") is None
|
|
|
|
|
|
class TestFormatBufferResults:
    """Checks for the _format_buffer_results XML renderer."""

    @staticmethod
    def _state(total_items, results):
        """Build a BufferState with fixed routing fields and the given results."""
        return BufferState(
            buffer_id="test",
            total_items=total_items,
            return_to="x",
            thread_id="t",
            from_id="c",
            target="w",
            results=results,
        )

    def test_formats_complete_results(self):
        """Output has the <results> wrapper, each index, the success flag and payloads."""
        xml = _format_buffer_results(
            self._state(
                2,
                {
                    0: BufferItemResult(0, "<Result>A</Result>", True),
                    1: BufferItemResult(1, "<Result>B</Result>", True),
                },
            )
        )

        expected_fragments = (
            "<results>",
            "</results>",
            'index="0"',
            'index="1"',
            'success="true"',
            "<Result>A</Result>",
            "<Result>B</Result>",
        )
        for fragment in expected_fragments:
            assert fragment in xml, fragment

    def test_formats_partial_failure(self):
        """Both success="true" and success="false" markers appear for mixed results."""
        xml = _format_buffer_results(
            self._state(
                2,
                {
                    0: BufferItemResult(0, "<Good/>", True),
                    1: BufferItemResult(1, "<Error/>", False, "timeout"),
                },
            )
        )

        for fragment in ('success="true"', 'success="false"'):
            assert fragment in xml, fragment

    def test_formats_missing_results(self):
        """A gap in the results still gets an entry for its index, marked missing."""
        xml = _format_buffer_results(
            self._state(
                3,
                {
                    0: BufferItemResult(0, "<R/>", True),
                    # Index 1 intentionally absent.
                    2: BufferItemResult(2, "<R/>", True),
                },
            )
        )

        for fragment in ('index="1"', "missing"):
            assert fragment in xml, fragment
|
|
|
|
|
|
class TestHandleBufferStartValidation:
    """Test input validation in handle_buffer_start.

    The stream pump accessor is patched to return a MagicMock, so no real
    pump is required.
    """

    @pytest.fixture(autouse=True)
    def isolate_buffer_registry(self):
        """Reset the global buffer registry before and after each test.

        Resetting on teardown as well as setup keeps anything recorded in the
        module-level singleton from leaking into tests that run later. The
        fixture deliberately avoids the name ``setup``, which older pytest
        versions treat as a nose/xunit-style hook.
        """
        reset_buffer_registry()
        yield
        reset_buffer_registry()

    @staticmethod
    def _make_metadata():
        """Return HandlerMetadata for a fresh thread originating from "console"."""
        return HandlerMetadata(
            thread_id=str(uuid.uuid4()),
            from_id="console",
        )

    @pytest.mark.asyncio
    async def test_empty_items_returns_error(self):
        """handle_buffer_start should return a BufferError for empty items."""
        from unittest.mock import MagicMock, patch

        mock_pump = MagicMock()
        # Register the target so the only invalid input is the empty item
        # list; the test then holds regardless of which check runs first.
        mock_pump.listeners = {"worker": MagicMock()}

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = BufferStart(
                target="worker",
                items="",  # Empty
                return_to="caller",
            )
            response = await handle_buffer_start(payload, self._make_metadata())

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, BufferError)
        assert "No items" in response.payload.error

    @pytest.mark.asyncio
    async def test_unknown_target_returns_error(self):
        """handle_buffer_start should return a BufferError for an unknown target."""
        from unittest.mock import MagicMock, patch

        mock_pump = MagicMock()
        mock_pump.listeners = {}  # No listeners

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            payload = BufferStart(
                target="unknown_worker",
                items="item1\nitem2",
                return_to="caller",
            )
            response = await handle_buffer_start(payload, self._make_metadata())

        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, BufferError)
        assert "Unknown target" in response.payload.error
|
|
|
|
|
|
class TestBufferRegistrySingleton:
    """Checks for the module-level get/reset accessors of BufferRegistry."""

    def test_get_buffer_registry_returns_singleton(self):
        """Consecutive get_buffer_registry() calls hand back one shared object."""
        reset_buffer_registry()

        first, second = get_buffer_registry(), get_buffer_registry()

        assert first is second

    def test_reset_creates_new_instance(self):
        """After reset_buffer_registry(), previously created buffers are gone."""
        before = get_buffer_registry()
        before.create("test", 1, "x", "t", "c", "w")

        reset_buffer_registry()

        assert get_buffer_registry().get("test") is None
|
|
|
|
|
|
class TestBufferCollectVsFireAndForget:
    """Test collect=True vs collect=False behavior of handle_buffer_start.

    The stream pump and thread registry accessors are patched with mocks.
    The buffer registry is NOT patched.
    """

    @pytest.fixture(autouse=True)
    def isolate_buffer_registry(self):
        """Reset the global buffer registry before and after each test.

        The buffer registry is not mocked in this class. Any state
        handle_buffer_start records in it (presumably via the module-level
        singleton; that is not visible from this file) would otherwise leak
        between tests and into the rest of the session.
        """
        reset_buffer_registry()
        yield
        reset_buffer_registry()

    @staticmethod
    def _make_mocks():
        """Build the pump and thread-registry mocks shared by both tests.

        Returns:
            A ``(mock_pump, mock_thread_registry)`` tuple. The pump exposes a
            single "worker" listener, returns ``b"<envelope/>"`` from
            ``_wrap_in_envelope`` and has an awaitable ``inject``. The thread
            registry returns "root.parent" from ``lookup`` and "worker-uuid"
            from ``get_or_create``.
        """
        from unittest.mock import AsyncMock, MagicMock

        mock_pump = MagicMock()
        mock_pump.listeners = {"worker": MagicMock()}
        mock_pump.register_generic_listener = MagicMock()
        mock_pump._wrap_in_envelope = MagicMock(return_value=b"<envelope/>")
        mock_pump.inject = AsyncMock()

        mock_thread_registry = MagicMock()
        mock_thread_registry.lookup = MagicMock(return_value="root.parent")
        mock_thread_registry.get_or_create = MagicMock(return_value="worker-uuid")

        return mock_pump, mock_thread_registry

    @pytest.mark.asyncio
    async def test_collect_mode_returns_none(self):
        """With collect=True, handler should return None (wait for results)."""
        from unittest.mock import patch

        mock_pump, mock_thread_registry = self._make_mocks()

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
                payload = BufferStart(
                    target="worker",
                    items="item1\nitem2",
                    return_to="caller",
                    collect=True,
                )
                metadata = HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                )
                response = await handle_buffer_start(payload, metadata)

        # With collect=True, returns None (ephemeral handler will send BufferComplete)
        assert response is None
        # Ephemeral listener should be registered exactly once.
        mock_pump.register_generic_listener.assert_called_once()

    @pytest.mark.asyncio
    async def test_fire_and_forget_returns_dispatched(self):
        """With collect=False, handler should return BufferDispatched immediately."""
        from unittest.mock import patch

        mock_pump, mock_thread_registry = self._make_mocks()

        with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
            with patch('xml_pipeline.message_bus.thread_registry.get_registry', return_value=mock_thread_registry):
                payload = BufferStart(
                    target="worker",
                    items="item1\nitem2\nitem3",
                    return_to="caller",
                    collect=False,  # Fire-and-forget
                )
                metadata = HandlerMetadata(
                    thread_id=str(uuid.uuid4()),
                    from_id="console",
                )
                response = await handle_buffer_start(payload, metadata)

        # With collect=False, returns BufferDispatched carrying the item count.
        assert isinstance(response, HandlerResponse)
        assert isinstance(response.payload, BufferDispatched)
        assert response.payload.total == 3

        # No ephemeral listener registered for fire-and-forget
        mock_pump.register_generic_listener.assert_not_called()
|
|
|
|
|
|
class TestBufferResultCollection:
    """Scenarios for aggregating worker results in a BufferRegistry."""

    def test_result_collection_partial_success(self):
        """Failures count toward completion but not toward successful_count."""
        reg = BufferRegistry()
        reg.create("partial", 5, "x", "t", "c", "w")

        # Even indices succeed, odd indices fail: 3 successes, 2 failures.
        for idx in range(5):
            if idx % 2 == 0:
                reg.record_result("partial", idx, "<R/>", True)
            else:
                reg.record_result("partial", idx, "<E/>", False, "error")

        snapshot = reg.get("partial")

        assert snapshot.is_complete is True
        assert snapshot.completed_count == 5
        assert snapshot.successful_count == 3

    def test_results_out_of_order(self):
        """get_ordered_results() lines results up by index regardless of arrival."""
        reg = BufferRegistry()
        reg.create("ooo", 3, "x", "t", "c", "w")

        # Deliver index 2 first, then 0, then 1.
        arrivals = ((2, "<R2/>"), (0, "<R0/>"), (1, "<R1/>"))
        for idx, text in arrivals:
            reg.record_result("ooo", idx, text, True)

        snapshot = reg.get("ooo")
        assert snapshot.is_complete is True

        ordered = snapshot.get_ordered_results()
        assert [item.result for item in ordered[:3]] == ["<R0/>", "<R1/>", "<R2/>"]
|