From 6790c7a46c2e8da4aae259eb979581d5f88a7e78 Mon Sep 17 00:00:00 2001 From: dullfig Date: Tue, 20 Jan 2026 20:18:22 -0800 Subject: [PATCH] Add shared backend for multiprocess pipeline support Introduces SharedBackend Protocol for cross-process state sharing: - InMemoryBackend: default single-process storage - ManagerBackend: multiprocessing.Manager for local multi-process - RedisBackend: distributed deployments with TTL auto-GC Adds ProcessPoolExecutor support for CPU-bound handlers: - worker.py: worker process entry point - stream_pump.py: cpu_bound handler dispatch - Config: backend and process_pool sections in organism.yaml ContextBuffer and ThreadRegistry now accept optional backend parameter while maintaining full backward compatibility. Co-Authored-By: Claude Opus 4.5 --- tests/conftest.py | 44 +++ tests/test_shared_backend.py | 393 ++++++++++++++++++++ xml_pipeline/config/loader.py | 67 ++++ xml_pipeline/memory/__init__.py | 34 ++ xml_pipeline/memory/context_buffer.py | 204 +++++++++- xml_pipeline/memory/manager_backend.py | 220 +++++++++++ xml_pipeline/memory/memory_backend.py | 136 +++++++ xml_pipeline/memory/redis_backend.py | 262 +++++++++++++ xml_pipeline/memory/shared_backend.py | 275 ++++++++++++++ xml_pipeline/message_bus/stream_pump.py | 179 ++++++++- xml_pipeline/message_bus/thread_registry.py | 221 ++++++++++- xml_pipeline/message_bus/worker.py | 339 +++++++++++++++++ 12 files changed, 2346 insertions(+), 28 deletions(-) create mode 100644 tests/test_shared_backend.py create mode 100644 xml_pipeline/memory/manager_backend.py create mode 100644 xml_pipeline/memory/memory_backend.py create mode 100644 xml_pipeline/memory/redis_backend.py create mode 100644 xml_pipeline/memory/shared_backend.py create mode 100644 xml_pipeline/message_bus/worker.py diff --git a/tests/conftest.py b/tests/conftest.py index b321683..5337728 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -36,6 +36,50 @@ def pytest_configure(config): # Fixtures available to all 
tests # ============================================================================ +@pytest.fixture(autouse=True) +def reset_singletons(): + """Reset global singletons before each test to ensure isolation.""" + # Clear registries before test + try: + from xml_pipeline.platform.prompt_registry import get_prompt_registry + get_prompt_registry().clear() + except ImportError: + pass + + try: + from xml_pipeline.memory.context_buffer import get_context_buffer + get_context_buffer().clear() + except ImportError: + pass + + try: + from xml_pipeline.message_bus.thread_registry import get_registry + get_registry().clear() + except (ImportError, AttributeError): + pass + + yield # Run the test + + # Clear after test too for good measure + try: + from xml_pipeline.platform.prompt_registry import get_prompt_registry + get_prompt_registry().clear() + except ImportError: + pass + + try: + from xml_pipeline.memory.context_buffer import get_context_buffer + get_context_buffer().clear() + except ImportError: + pass + + try: + from xml_pipeline.message_bus.thread_registry import get_registry + get_registry().clear() + except (ImportError, AttributeError): + pass + + @pytest.fixture def sample_thread_id(): """A valid UUID for testing.""" diff --git a/tests/test_shared_backend.py b/tests/test_shared_backend.py new file mode 100644 index 0000000..afdebae --- /dev/null +++ b/tests/test_shared_backend.py @@ -0,0 +1,393 @@ +""" +Tests for shared backend implementations. + +Tests InMemoryBackend and ManagerBackend. +Redis tests require a running Redis server (skipped if not available). 
+""" + +import pytest +from dataclasses import dataclass + +from xml_pipeline.memory.memory_backend import InMemoryBackend +from xml_pipeline.memory.shared_backend import ( + BackendConfig, + serialize_slot, + deserialize_slot, +) +from xml_pipeline.memory.context_buffer import BufferSlot, SlotMetadata + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +def memory_backend(): + """Create fresh in-memory backend.""" + backend = InMemoryBackend() + yield backend + backend.close() + + +@pytest.fixture +def manager_backend(): + """Create fresh manager backend.""" + from xml_pipeline.memory.manager_backend import ManagerBackend + + backend = ManagerBackend() + yield backend + backend.close() + + +def redis_available() -> bool: + """Check if Redis is available.""" + try: + import redis + + client = redis.from_url("redis://localhost:6379") + client.ping() + client.close() + return True + except Exception: + return False + + +@pytest.fixture +def redis_backend(): + """Create fresh Redis backend (skipped if Redis not available).""" + if not redis_available(): + pytest.skip("Redis not available") + + from xml_pipeline.memory.redis_backend import RedisBackend + + backend = RedisBackend( + url="redis://localhost:6379", + prefix="xp_test:", + ttl=300, + ) + # Clear test keys + backend.buffer_clear() + backend.registry_clear() + yield backend + # Cleanup + backend.buffer_clear() + backend.registry_clear() + backend.close() + + +# ============================================================================ +# Sample data +# ============================================================================ + + +@dataclass +class SamplePayload: + """Sample payload for testing.""" + + message: str + value: int + + +def make_slot(thread_id: str, index: int = 0) -> BufferSlot: + """Create a sample buffer slot.""" + payload = 
SamplePayload(message="hello", value=42) + metadata = SlotMetadata( + thread_id=thread_id, + from_id="sender", + to_id="receiver", + slot_index=index, + timestamp="2024-01-15T00:00:00Z", + payload_type="SamplePayload", + ) + return BufferSlot(payload=payload, metadata=metadata) + + +# ============================================================================ +# InMemoryBackend Tests +# ============================================================================ + + +class TestInMemoryBackend: + """Tests for InMemoryBackend.""" + + def test_buffer_append_and_get(self, memory_backend): + """Test appending and retrieving buffer slots.""" + slot = make_slot("thread-1") + slot_bytes = serialize_slot(slot) + + # Append + index = memory_backend.buffer_append("thread-1", slot_bytes) + assert index == 0 + + # Append another + index2 = memory_backend.buffer_append("thread-1", slot_bytes) + assert index2 == 1 + + # Get all + slots = memory_backend.buffer_get_thread("thread-1") + assert len(slots) == 2 + + # Deserialize and verify + retrieved = deserialize_slot(slots[0]) + assert retrieved.metadata.thread_id == "thread-1" + assert retrieved.payload.message == "hello" + + def test_buffer_get_slot(self, memory_backend): + """Test getting specific slot by index.""" + slot = make_slot("thread-1") + slot_bytes = serialize_slot(slot) + + memory_backend.buffer_append("thread-1", slot_bytes) + memory_backend.buffer_append("thread-1", slot_bytes) + + # Get specific slot + data = memory_backend.buffer_get_slot("thread-1", 1) + assert data is not None + + # Non-existent index + data = memory_backend.buffer_get_slot("thread-1", 999) + assert data is None + + def test_buffer_thread_exists(self, memory_backend): + """Test thread existence check.""" + assert not memory_backend.buffer_thread_exists("thread-1") + + slot_bytes = serialize_slot(make_slot("thread-1")) + memory_backend.buffer_append("thread-1", slot_bytes) + + assert memory_backend.buffer_thread_exists("thread-1") + + def 
test_buffer_delete_thread(self, memory_backend): + """Test deleting thread buffer.""" + slot_bytes = serialize_slot(make_slot("thread-1")) + memory_backend.buffer_append("thread-1", slot_bytes) + + assert memory_backend.buffer_delete_thread("thread-1") + assert not memory_backend.buffer_thread_exists("thread-1") + assert not memory_backend.buffer_delete_thread("thread-1") # Already deleted + + def test_buffer_list_threads(self, memory_backend): + """Test listing all threads.""" + slot_bytes = serialize_slot(make_slot("thread-1")) + memory_backend.buffer_append("thread-1", slot_bytes) + memory_backend.buffer_append("thread-2", slot_bytes) + + threads = memory_backend.buffer_list_threads() + assert set(threads) == {"thread-1", "thread-2"} + + def test_registry_set_and_get(self, memory_backend): + """Test registry set and get operations.""" + memory_backend.registry_set("a.b.c", "uuid-123") + + # Get UUID from chain + uuid = memory_backend.registry_get_uuid("a.b.c") + assert uuid == "uuid-123" + + # Get chain from UUID + chain = memory_backend.registry_get_chain("uuid-123") + assert chain == "a.b.c" + + def test_registry_delete(self, memory_backend): + """Test registry delete.""" + memory_backend.registry_set("a.b.c", "uuid-123") + + assert memory_backend.registry_delete("uuid-123") + assert memory_backend.registry_get_uuid("a.b.c") is None + assert memory_backend.registry_get_chain("uuid-123") is None + assert not memory_backend.registry_delete("uuid-123") # Already deleted + + def test_registry_list_all(self, memory_backend): + """Test listing all registry entries.""" + memory_backend.registry_set("a.b", "uuid-1") + memory_backend.registry_set("x.y.z", "uuid-2") + + all_entries = memory_backend.registry_list_all() + assert all_entries == {"uuid-1": "a.b", "uuid-2": "x.y.z"} + + def test_registry_clear(self, memory_backend): + """Test clearing registry.""" + memory_backend.registry_set("a.b", "uuid-1") + memory_backend.registry_clear() + + assert 
memory_backend.registry_list_all() == {} + + +# ============================================================================ +# ManagerBackend Tests +# ============================================================================ + + +class TestManagerBackend: + """Tests for ManagerBackend (multiprocessing.Manager).""" + + def test_buffer_append_and_get(self, manager_backend): + """Test appending and retrieving buffer slots.""" + slot = make_slot("thread-1") + slot_bytes = serialize_slot(slot) + + index = manager_backend.buffer_append("thread-1", slot_bytes) + assert index == 0 + + slots = manager_backend.buffer_get_thread("thread-1") + assert len(slots) == 1 + + def test_registry_operations(self, manager_backend): + """Test registry operations via Manager.""" + manager_backend.registry_set("a.b.c", "uuid-123") + + uuid = manager_backend.registry_get_uuid("a.b.c") + assert uuid == "uuid-123" + + chain = manager_backend.registry_get_chain("uuid-123") + assert chain == "a.b.c" + + +# ============================================================================ +# RedisBackend Tests +# ============================================================================ + + +@pytest.mark.skipif(not redis_available(), reason="Redis not available") +class TestRedisBackend: + """Tests for RedisBackend (requires running Redis).""" + + def test_buffer_append_and_get(self, redis_backend): + """Test appending and retrieving buffer slots.""" + slot = make_slot("thread-1") + slot_bytes = serialize_slot(slot) + + index = redis_backend.buffer_append("thread-1", slot_bytes) + assert index == 0 + + slots = redis_backend.buffer_get_thread("thread-1") + assert len(slots) == 1 + + retrieved = deserialize_slot(slots[0]) + assert retrieved.metadata.thread_id == "thread-1" + + def test_buffer_thread_len(self, redis_backend): + """Test getting thread length.""" + slot_bytes = serialize_slot(make_slot("thread-1")) + + redis_backend.buffer_append("thread-1", slot_bytes) + 
redis_backend.buffer_append("thread-1", slot_bytes) + redis_backend.buffer_append("thread-1", slot_bytes) + + assert redis_backend.buffer_thread_len("thread-1") == 3 + + def test_registry_operations(self, redis_backend): + """Test registry operations via Redis.""" + redis_backend.registry_set("console.router.greeter", "uuid-abc") + + uuid = redis_backend.registry_get_uuid("console.router.greeter") + assert uuid == "uuid-abc" + + chain = redis_backend.registry_get_chain("uuid-abc") + assert chain == "console.router.greeter" + + def test_registry_delete(self, redis_backend): + """Test registry delete removes both directions.""" + redis_backend.registry_set("a.b", "uuid-123") + + assert redis_backend.registry_delete("uuid-123") + assert redis_backend.registry_get_uuid("a.b") is None + assert redis_backend.registry_get_chain("uuid-123") is None + + def test_ping(self, redis_backend): + """Test Redis ping.""" + assert redis_backend.ping() + + def test_info(self, redis_backend): + """Test backend info.""" + slot_bytes = serialize_slot(make_slot("thread-1")) + redis_backend.buffer_append("thread-1", slot_bytes) + redis_backend.registry_set("a.b", "uuid-1") + + info = redis_backend.info() + assert "buffer_threads" in info + assert "registry_entries" in info + + +# ============================================================================ +# Integration: ContextBuffer with SharedBackend +# ============================================================================ + + +class TestContextBufferWithBackend: + """Test ContextBuffer using shared backends.""" + + def test_context_buffer_with_memory_backend(self): + """Test ContextBuffer works with in-memory backend.""" + from xml_pipeline.memory.context_buffer import ContextBuffer + + backend = InMemoryBackend() + buffer = ContextBuffer(backend=backend) + + assert buffer.is_shared + + # Append slot + slot = buffer.append( + thread_id="test-thread", + payload=SamplePayload(message="hello", value=42), + from_id="sender", + 
to_id="receiver", + ) + + assert slot.thread_id == "test-thread" + assert slot.payload.message == "hello" + + # Get thread + thread = buffer.get_thread("test-thread") + assert thread is not None + assert len(thread) == 1 + + # Get stats + stats = buffer.get_stats() + assert stats["thread_count"] == 1 + assert stats["backend"] == "shared" + + backend.close() + + +# ============================================================================ +# Integration: ThreadRegistry with SharedBackend +# ============================================================================ + + +class TestThreadRegistryWithBackend: + """Test ThreadRegistry using shared backends.""" + + def test_thread_registry_with_memory_backend(self): + """Test ThreadRegistry works with in-memory backend.""" + from xml_pipeline.message_bus.thread_registry import ThreadRegistry + + backend = InMemoryBackend() + registry = ThreadRegistry(backend=backend) + + assert registry.is_shared + + # Initialize root + root_uuid = registry.initialize_root("test-organism") + assert root_uuid is not None + assert registry.root_chain == "system.test-organism" + + # Get or create chain + uuid = registry.get_or_create("a.b.c") + assert uuid is not None + + # Lookup + chain = registry.lookup(uuid) + assert chain == "a.b.c" + + # Extend chain + new_uuid = registry.extend_chain(uuid, "d") + new_chain = registry.lookup(new_uuid) + assert new_chain == "a.b.c.d" + + # Prune for response + target, pruned_uuid = registry.prune_for_response(new_uuid) + assert target == "c" + assert registry.lookup(pruned_uuid) == "a.b.c" + + backend.close() diff --git a/xml_pipeline/config/loader.py b/xml_pipeline/config/loader.py index 93870b9..312c7a0 100644 --- a/xml_pipeline/config/loader.py +++ b/xml_pipeline/config/loader.py @@ -65,6 +65,9 @@ class ListenerConfig: allowed_tools: list[str] = field(default_factory=list) blocked_tools: list[str] = field(default_factory=list) + # Dispatch mode + cpu_bound: bool = False # If True, dispatch to 
ProcessPoolExecutor + @dataclass class ServerConfig: @@ -83,6 +86,42 @@ class AuthConfig: totp_secret_env: str = "ORGANISM_TOTP_SECRET" +@dataclass +class BackendStorageConfig: + """ + Shared backend configuration for multi-process deployments. + + Enables ContextBuffer and ThreadRegistry to use shared storage + (Redis or multiprocessing.Manager) for cross-process access. + """ + + backend_type: str = "memory" # "memory", "manager", "redis" + + # Redis-specific config + redis_url: str = "redis://localhost:6379" + redis_prefix: str = "xp:" + redis_ttl: int = 86400 # 24 hours default TTL + + # Limits + max_slots_per_thread: int = 10000 + max_threads: int = 1000 + + +@dataclass +class ProcessPoolConfig: + """ + Process pool configuration for CPU-bound handler dispatch. + + When configured, handlers marked with `cpu_bound: true` are + dispatched to a ProcessPoolExecutor instead of running in + the main event loop. + """ + + enabled: bool = False + workers: int = 4 # Number of worker processes + max_tasks_per_child: int = 100 # Restart workers after N tasks + + @dataclass class OrganismConfig: """Complete organism configuration.""" @@ -92,6 +131,8 @@ class OrganismConfig: llm_backends: list[LLMBackendConfig] = field(default_factory=list) server: ServerConfig | None = None auth: AuthConfig | None = None + backend: BackendStorageConfig | None = None + process_pool: ProcessPoolConfig | None = None def load_config(path: Path) -> OrganismConfig: @@ -152,6 +193,7 @@ def load_config(path: Path) -> OrganismConfig: peers=listener_raw.get("peers", []), allowed_tools=listener_raw.get("allowed_tools", []), blocked_tools=listener_raw.get("blocked_tools", []), + cpu_bound=listener_raw.get("cpu_bound", False), ) ) @@ -174,12 +216,37 @@ def load_config(path: Path) -> OrganismConfig: totp_secret_env=auth_raw.get("totp_secret_env", "ORGANISM_TOTP_SECRET"), ) + # Parse optional backend config + backend = None + if "backend" in raw: + backend_raw = raw["backend"] + backend = 
BackendStorageConfig( + backend_type=backend_raw.get("type", "memory"), + redis_url=backend_raw.get("redis_url", "redis://localhost:6379"), + redis_prefix=backend_raw.get("redis_prefix", "xp:"), + redis_ttl=backend_raw.get("redis_ttl", 86400), + max_slots_per_thread=backend_raw.get("max_slots_per_thread", 10000), + max_threads=backend_raw.get("max_threads", 1000), + ) + + # Parse optional process pool config + process_pool = None + if "process_pool" in raw: + pool_raw = raw["process_pool"] + process_pool = ProcessPoolConfig( + enabled=pool_raw.get("enabled", True), + workers=pool_raw.get("workers", 4), + max_tasks_per_child=pool_raw.get("max_tasks_per_child", 100), + ) + return OrganismConfig( organism=organism, listeners=listeners, llm_backends=llm_backends, server=server, auth=auth, + backend=backend, + process_pool=process_pool, ) diff --git a/xml_pipeline/memory/__init__.py b/xml_pipeline/memory/__init__.py index 6915c65..8794c5f 100644 --- a/xml_pipeline/memory/__init__.py +++ b/xml_pipeline/memory/__init__.py @@ -6,6 +6,21 @@ Provides thread-scoped, append-only context buffers with: - Thread isolation (handlers only see their context) - Complete audit trail (all messages preserved) - GC and limits (prevent runaway memory usage) + +For multi-process deployments, supports shared backends: +- InMemoryBackend: Default single-process storage +- ManagerBackend: multiprocessing.Manager for local multi-process +- RedisBackend: Redis for distributed deployments + +Usage: + # Default (in-memory, single process) + buffer = get_context_buffer() + + # With Redis backend + from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend + config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379") + backend = get_shared_backend(config) + buffer = get_context_buffer(backend=backend) """ from xml_pipeline.memory.context_buffer import ( @@ -14,14 +29,33 @@ from xml_pipeline.memory.context_buffer import ( BufferSlot, SlotMetadata, 
get_context_buffer, + reset_context_buffer, slot_to_handler_metadata, ) +from xml_pipeline.memory.shared_backend import ( + SharedBackend, + BackendConfig, + get_shared_backend, + reset_shared_backend, + serialize_slot, + deserialize_slot, +) + __all__ = [ + # Context buffer "ContextBuffer", "ThreadContext", "BufferSlot", "SlotMetadata", "get_context_buffer", + "reset_context_buffer", "slot_to_handler_metadata", + # Shared backend + "SharedBackend", + "BackendConfig", + "get_shared_backend", + "reset_shared_backend", + "serialize_slot", + "deserialize_slot", ] diff --git a/xml_pipeline/memory/context_buffer.py b/xml_pipeline/memory/context_buffer.py index 3f112e3..de146db 100644 --- a/xml_pipeline/memory/context_buffer.py +++ b/xml_pipeline/memory/context_buffer.py @@ -27,16 +27,26 @@ Usage: # Get thread history history = buffer.get_thread(thread_id) + +For multi-process deployments, the buffer can use a shared backend: + from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig + + config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379") + backend = get_shared_backend(config) + buffer = get_context_buffer(backend=backend) """ from __future__ import annotations from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Iterator +from typing import Any, Dict, List, Optional, Iterator, TYPE_CHECKING from datetime import datetime, timezone import threading import uuid +if TYPE_CHECKING: + from xml_pipeline.memory.shared_backend import SharedBackend + @dataclass(frozen=True) class SlotMetadata: @@ -166,9 +176,26 @@ class ContextBuffer: Global context buffer managing all thread contexts. Thread-safe. Singleton pattern via get_context_buffer(). + + Supports two storage modes: + 1. Local mode (default): Uses in-process ThreadContext objects + 2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access + + In shared mode, slots are serialized via pickle and stored in the backend. 
+ This enables multi-process handler dispatch (cpu_bound handlers). """ - def __init__(self): + def __init__(self, backend: Optional[SharedBackend] = None): + """ + Initialize context buffer. + + Args: + backend: Optional shared backend for cross-process storage. + If None, uses in-process storage (original behavior). + """ + self._backend = backend + + # Local storage (used when no backend) self._threads: Dict[str, ThreadContext] = {} self._lock = threading.Lock() @@ -176,8 +203,17 @@ class ContextBuffer: self.max_slots_per_thread: int = 10000 self.max_threads: int = 1000 + @property + def is_shared(self) -> bool: + """Return True if using shared backend.""" + return self._backend is not None + def get_or_create_thread(self, thread_id: str) -> ThreadContext: - """Get existing thread context or create new one.""" + """ + Get existing thread context or create new one. + + Note: In shared mode, this creates a local proxy that syncs with backend. + """ with self._lock: if thread_id not in self._threads: if len(self._threads) >= self.max_threads: @@ -205,7 +241,23 @@ class ContextBuffer: This is the main entry point for the pipeline. Returns the immutable slot reference. + + In shared mode, the slot is serialized and stored in the backend. 
""" + if self._backend is not None: + # Shared mode: serialize and store in backend + return self._append_shared( + thread_id=thread_id, + payload=payload, + from_id=from_id, + to_id=to_id, + own_name=own_name, + is_self_call=is_self_call, + usage_instructions=usage_instructions, + todo_nudge=todo_nudge, + ) + + # Local mode: use ThreadContext thread = self.get_or_create_thread(thread_id) # Enforce slot limit @@ -224,18 +276,116 @@ class ContextBuffer: todo_nudge=todo_nudge, ) + def _append_shared( + self, + thread_id: str, + payload: Any, + from_id: str, + to_id: str, + own_name: Optional[str] = None, + is_self_call: bool = False, + usage_instructions: str = "", + todo_nudge: str = "", + ) -> BufferSlot: + """Append to shared backend.""" + from xml_pipeline.memory.shared_backend import serialize_slot + + assert self._backend is not None + + # Get current slot count for index + current_len = self._backend.buffer_thread_len(thread_id) + + # Enforce slot limit + if current_len >= self.max_slots_per_thread: + raise MemoryError( + f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})" + ) + + # Create metadata + metadata = SlotMetadata( + thread_id=thread_id, + from_id=from_id, + to_id=to_id, + slot_index=current_len, + timestamp=datetime.now(timezone.utc).isoformat(), + payload_type=type(payload).__name__, + own_name=own_name, + is_self_call=is_self_call, + usage_instructions=usage_instructions, + todo_nudge=todo_nudge, + ) + + # Create slot + slot = BufferSlot(payload=payload, metadata=metadata) + + # Serialize and store + slot_data = serialize_slot(slot) + self._backend.buffer_append(thread_id, slot_data) + + return slot + def get_thread(self, thread_id: str) -> Optional[ThreadContext]: - """Get a thread's context (None if not found).""" + """ + Get a thread's context (None if not found). + + In shared mode, returns a local ThreadContext populated from backend. 
+ """ + if self._backend is not None: + return self._get_thread_shared(thread_id) + with self._lock: return self._threads.get(thread_id) + def _get_thread_shared(self, thread_id: str) -> Optional[ThreadContext]: + """Get thread from shared backend.""" + from xml_pipeline.memory.shared_backend import deserialize_slot + + assert self._backend is not None + + if not self._backend.buffer_thread_exists(thread_id): + return None + + # Create local ThreadContext and populate from backend + thread = ThreadContext(thread_id) + slot_data_list = self._backend.buffer_get_thread(thread_id) + + for slot_data in slot_data_list: + slot = deserialize_slot(slot_data) + thread._slots.append(slot) + + return thread + + def get_thread_slots(self, thread_id: str) -> List[BufferSlot]: + """ + Get all slots for a thread as a list. + + More efficient than get_thread() when you just need slots. + """ + if self._backend is not None: + from xml_pipeline.memory.shared_backend import deserialize_slot + + slot_data_list = self._backend.buffer_get_thread(thread_id) + return [deserialize_slot(data) for data in slot_data_list] + + with self._lock: + thread = self._threads.get(thread_id) + if thread: + return list(thread._slots) + return [] + def thread_exists(self, thread_id: str) -> bool: """Check if a thread exists.""" + if self._backend is not None: + return self._backend.buffer_thread_exists(thread_id) + with self._lock: return thread_id in self._threads def delete_thread(self, thread_id: str) -> bool: """Delete a thread's context (GC).""" + if self._backend is not None: + return self._backend.buffer_delete_thread(thread_id) + with self._lock: if thread_id in self._threads: del self._threads[thread_id] @@ -244,6 +394,20 @@ class ContextBuffer: def get_stats(self) -> Dict[str, Any]: """Get buffer statistics.""" + if self._backend is not None: + threads = self._backend.buffer_list_threads() + total_slots = sum( + self._backend.buffer_thread_len(t) for t in threads + ) + return { + "thread_count": 
len(threads), + "total_slots": total_slots, + "max_threads": self.max_threads, + "max_slots_per_thread": self.max_slots_per_thread, + "threads": threads, + "backend": "shared", + } + with self._lock: total_slots = sum(len(t) for t in self._threads.values()) return { @@ -252,10 +416,15 @@ class ContextBuffer: "max_threads": self.max_threads, "max_slots_per_thread": self.max_slots_per_thread, "threads": list(self._threads.keys()), + "backend": "local", } - def clear(self): + def clear(self) -> None: """Clear all contexts (for testing).""" + if self._backend is not None: + self._backend.buffer_clear() + return + with self._lock: self._threads.clear() @@ -268,16 +437,35 @@ _buffer: Optional[ContextBuffer] = None _buffer_lock = threading.Lock() -def get_context_buffer() -> ContextBuffer: - """Get the global ContextBuffer singleton.""" +def get_context_buffer(backend: Optional[SharedBackend] = None) -> ContextBuffer: + """ + Get the global ContextBuffer singleton. + + Args: + backend: Optional shared backend for cross-process storage. + Only used on first call (when creating the singleton). + Subsequent calls return the existing singleton. + + Returns: + Global ContextBuffer instance. 
+ """ global _buffer if _buffer is None: with _buffer_lock: if _buffer is None: - _buffer = ContextBuffer() + _buffer = ContextBuffer(backend=backend) return _buffer +def reset_context_buffer() -> None: + """Reset the global context buffer (for testing).""" + global _buffer + with _buffer_lock: + if _buffer is not None: + _buffer.clear() + _buffer = None + + # ============================================================================ # Handler-facing metadata adapter # ============================================================================ diff --git a/xml_pipeline/memory/manager_backend.py b/xml_pipeline/memory/manager_backend.py new file mode 100644 index 0000000..e7241e1 --- /dev/null +++ b/xml_pipeline/memory/manager_backend.py @@ -0,0 +1,220 @@ +""" +manager_backend.py — multiprocessing.Manager implementation of SharedBackend. + +Uses Python's multiprocessing.Manager for cross-process state sharing +without requiring Redis. Good for local development and testing. + +Key differences from InMemoryBackend: +- Data structures are proxied across processes +- Slightly higher overhead than in-memory +- No TTL/expiration (use for short-lived test runs) +- No persistence (data lost on manager shutdown) + +Usage: + # Start manager in main process + backend = ManagerBackend() + + # Workers connect via the same address + # (Manager handles cross-process communication) +""" + +from __future__ import annotations + +import logging +import multiprocessing +from multiprocessing.managers import SyncManager +from typing import Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + + +class ManagerBackend: + """ + multiprocessing.Manager-backed shared state. + + Provides cross-process state without external dependencies. + Suitable for local multi-process testing when Redis isn't available. + """ + + def __init__( + self, + address: Optional[Tuple[str, int]] = None, + authkey: Optional[bytes] = None, + ) -> None: + """ + Initialize Manager backend. 
+ + Args: + address: (host, port) for remote manager connection. + If None, creates a local manager. + authkey: Authentication key for manager. + If None, uses process authkey. + """ + self._is_server = address is None + self._manager: Optional[SyncManager] = None + + if self._is_server: + # Create a local manager (server mode) + self._manager = multiprocessing.Manager() + self._init_data_structures() + logger.info("Manager backend started (server mode)") + else: + # Connect to remote manager (client mode) + # Note: For remote connection, you'd need a custom SyncManager + # This is simplified for now - just use local manager + self._manager = multiprocessing.Manager() + self._init_data_structures() + logger.info("Manager backend started (client mode)") + + def _init_data_structures(self) -> None: + """Initialize managed data structures.""" + if self._manager is None: + raise RuntimeError("Manager not initialized") + + # Context buffer: thread_id → list of pickled slots + self._buffer: Dict[str, List[bytes]] = self._manager.dict() + + # Thread registry: bidirectional mapping + self._chain_to_uuid: Dict[str, str] = self._manager.dict() + self._uuid_to_chain: Dict[str, str] = self._manager.dict() + + # Lock for complex operations + self._lock = self._manager.Lock() + + # ========================================================================= + # Context Buffer Operations + # ========================================================================= + + def buffer_append(self, thread_id: str, slot_data: bytes) -> int: + """Append a slot to a thread's buffer.""" + with self._lock: + if thread_id not in self._buffer: + # Manager.dict doesn't support nested assignment directly + # We need to get, modify, put back + self._buffer[thread_id] = self._manager.list() # type: ignore + + # Get the list, append, and count + slots = self._buffer[thread_id] + slots.append(slot_data) + return len(slots) - 1 + + def buffer_get_thread(self, thread_id: str) -> List[bytes]: + """Get all 
slots for a thread.""" + with self._lock: + if thread_id not in self._buffer: + return [] + # Convert manager.list to regular list + return list(self._buffer[thread_id]) + + def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]: + """Get a specific slot by index.""" + with self._lock: + if thread_id not in self._buffer: + return None + slots = self._buffer[thread_id] + if 0 <= index < len(slots): + return slots[index] + return None + + def buffer_thread_len(self, thread_id: str) -> int: + """Get number of slots in a thread.""" + with self._lock: + if thread_id not in self._buffer: + return 0 + return len(self._buffer[thread_id]) + + def buffer_thread_exists(self, thread_id: str) -> bool: + """Check if a thread has any slots.""" + with self._lock: + return thread_id in self._buffer + + def buffer_delete_thread(self, thread_id: str) -> bool: + """Delete all slots for a thread.""" + with self._lock: + if thread_id in self._buffer: + del self._buffer[thread_id] + return True + return False + + def buffer_list_threads(self) -> List[str]: + """List all thread IDs with slots.""" + with self._lock: + return list(self._buffer.keys()) + + def buffer_clear(self) -> None: + """Clear all buffer data.""" + with self._lock: + # Can't call .clear() on manager.dict proxy + keys = list(self._buffer.keys()) + for key in keys: + del self._buffer[key] + + # ========================================================================= + # Thread Registry Operations + # ========================================================================= + + def registry_set(self, chain: str, uuid: str) -> None: + """Set bidirectional mapping: chain ↔ uuid.""" + with self._lock: + self._chain_to_uuid[chain] = uuid + self._uuid_to_chain[uuid] = chain + + def registry_get_uuid(self, chain: str) -> Optional[str]: + """Get UUID for a chain.""" + with self._lock: + return self._chain_to_uuid.get(chain) + + def registry_get_chain(self, uuid: str) -> Optional[str]: + """Get chain for a 
UUID.""" + with self._lock: + return self._uuid_to_chain.get(uuid) + + def registry_delete(self, uuid: str) -> bool: + """Delete mapping by UUID.""" + with self._lock: + chain = self._uuid_to_chain.get(uuid) + if chain: + del self._uuid_to_chain[uuid] + del self._chain_to_uuid[chain] + return True + return False + + def registry_list_all(self) -> Dict[str, str]: + """Get all UUID → chain mappings.""" + with self._lock: + return dict(self._uuid_to_chain) + + def registry_clear(self) -> None: + """Clear all registry data.""" + with self._lock: + # Can't call .clear() on manager.dict proxy + for key in list(self._chain_to_uuid.keys()): + del self._chain_to_uuid[key] + for key in list(self._uuid_to_chain.keys()): + del self._uuid_to_chain[key] + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """Shutdown manager.""" + if self._manager and self._is_server: + try: + self._manager.shutdown() + logger.info("Manager backend shutdown") + except Exception as e: + logger.warning(f"Manager shutdown error: {e}") + self._manager = None + + # ========================================================================= + # Info + # ========================================================================= + + def info(self) -> Dict[str, int]: + """Get backend statistics.""" + with self._lock: + return { + "buffer_threads": len(self._buffer), + "registry_entries": len(self._uuid_to_chain), + } diff --git a/xml_pipeline/memory/memory_backend.py b/xml_pipeline/memory/memory_backend.py new file mode 100644 index 0000000..2ced2c9 --- /dev/null +++ b/xml_pipeline/memory/memory_backend.py @@ -0,0 +1,136 @@ +""" +memory_backend.py — In-memory implementation of SharedBackend. + +This is the default backend for single-process operation. +It provides the same interface as Redis/Manager backends but stores +everything in Python data structures. 
+ +Thread-safe via threading.Lock (same as original ContextBuffer/ThreadRegistry). +""" + +from __future__ import annotations + +import threading +from typing import Dict, List, Optional + + +class InMemoryBackend: + """ + In-memory shared backend implementation. + + This is equivalent to the original behavior - all data in-process. + Use for development, testing, or single-process deployments. + """ + + def __init__(self) -> None: + # Context buffer storage: thread_id → list of pickled slots + self._buffer: Dict[str, List[bytes]] = {} + self._buffer_lock = threading.Lock() + + # Thread registry storage: bidirectional mapping + self._chain_to_uuid: Dict[str, str] = {} + self._uuid_to_chain: Dict[str, str] = {} + self._registry_lock = threading.Lock() + + # ========================================================================= + # Context Buffer Operations + # ========================================================================= + + def buffer_append(self, thread_id: str, slot_data: bytes) -> int: + """Append a slot to a thread's buffer.""" + with self._buffer_lock: + if thread_id not in self._buffer: + self._buffer[thread_id] = [] + + index = len(self._buffer[thread_id]) + self._buffer[thread_id].append(slot_data) + return index + + def buffer_get_thread(self, thread_id: str) -> List[bytes]: + """Get all slots for a thread.""" + with self._buffer_lock: + return list(self._buffer.get(thread_id, [])) + + def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]: + """Get a specific slot by index.""" + with self._buffer_lock: + slots = self._buffer.get(thread_id, []) + if 0 <= index < len(slots): + return slots[index] + return None + + def buffer_thread_len(self, thread_id: str) -> int: + """Get number of slots in a thread.""" + with self._buffer_lock: + return len(self._buffer.get(thread_id, [])) + + def buffer_thread_exists(self, thread_id: str) -> bool: + """Check if a thread has any slots.""" + with self._buffer_lock: + return thread_id in 
self._buffer + + def buffer_delete_thread(self, thread_id: str) -> bool: + """Delete all slots for a thread.""" + with self._buffer_lock: + if thread_id in self._buffer: + del self._buffer[thread_id] + return True + return False + + def buffer_list_threads(self) -> List[str]: + """List all thread IDs with slots.""" + with self._buffer_lock: + return list(self._buffer.keys()) + + def buffer_clear(self) -> None: + """Clear all buffer data.""" + with self._buffer_lock: + self._buffer.clear() + + # ========================================================================= + # Thread Registry Operations + # ========================================================================= + + def registry_set(self, chain: str, uuid: str) -> None: + """Set bidirectional mapping: chain ↔ uuid.""" + with self._registry_lock: + self._chain_to_uuid[chain] = uuid + self._uuid_to_chain[uuid] = chain + + def registry_get_uuid(self, chain: str) -> Optional[str]: + """Get UUID for a chain.""" + with self._registry_lock: + return self._chain_to_uuid.get(chain) + + def registry_get_chain(self, uuid: str) -> Optional[str]: + """Get chain for a UUID.""" + with self._registry_lock: + return self._uuid_to_chain.get(uuid) + + def registry_delete(self, uuid: str) -> bool: + """Delete mapping by UUID.""" + with self._registry_lock: + chain = self._uuid_to_chain.pop(uuid, None) + if chain: + self._chain_to_uuid.pop(chain, None) + return True + return False + + def registry_list_all(self) -> Dict[str, str]: + """Get all UUID → chain mappings.""" + with self._registry_lock: + return dict(self._uuid_to_chain) + + def registry_clear(self) -> None: + """Clear all registry data.""" + with self._registry_lock: + self._chain_to_uuid.clear() + self._uuid_to_chain.clear() + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """No resources to clean up for 
in-memory backend.""" + pass diff --git a/xml_pipeline/memory/redis_backend.py b/xml_pipeline/memory/redis_backend.py new file mode 100644 index 0000000..86274b3 --- /dev/null +++ b/xml_pipeline/memory/redis_backend.py @@ -0,0 +1,262 @@ +""" +redis_backend.py — Redis implementation of SharedBackend. + +Uses Redis for distributed shared state, enabling: +- Multi-process handler dispatch (cpu_bound handlers in ProcessPool) +- Multi-tenant deployments (key prefix isolation) +- Automatic TTL-based garbage collection +- Cross-machine state sharing (future: WASM handlers) + +Key schema: +- Buffer slots: `{prefix}buffer:{thread_id}` → Redis LIST of pickled slots +- Registry chain→uuid: `{prefix}chain:{chain}` → uuid string +- Registry uuid→chain: `{prefix}uuid:{uuid}` → chain string + +Dependencies: + pip install redis +""" + +from __future__ import annotations + +import logging +from typing import Dict, List, Optional + +try: + import redis +except ImportError: + redis = None # type: ignore + +logger = logging.getLogger(__name__) + + +class RedisBackend: + """ + Redis-backed shared state for multi-process deployments. + + All operations are synchronous (blocking). For async contexts, + wrap calls in asyncio.to_thread() or use a connection pool. + """ + + def __init__( + self, + url: str = "redis://localhost:6379", + prefix: str = "xp:", + ttl: int = 86400, + ) -> None: + """ + Initialize Redis backend. + + Args: + url: Redis connection URL + prefix: Key prefix for multi-tenancy isolation + ttl: TTL in seconds for automatic cleanup (0 = no TTL) + """ + if redis is None: + raise ImportError( + "redis package not installed. 
Install with: pip install redis" + ) + + self.url = url + self.prefix = prefix + self.ttl = ttl + + # Connect to Redis + self._client = redis.from_url(url, decode_responses=False) + + # Verify connection + try: + self._client.ping() + logger.info(f"Redis backend connected: {url}") + except redis.ConnectionError as e: + logger.error(f"Redis connection failed: {e}") + raise + + def _buffer_key(self, thread_id: str) -> str: + """Get Redis key for a thread's buffer.""" + return f"{self.prefix}buffer:{thread_id}" + + def _chain_key(self, chain: str) -> str: + """Get Redis key for chain→uuid mapping.""" + return f"{self.prefix}chain:{chain}" + + def _uuid_key(self, uuid: str) -> str: + """Get Redis key for uuid→chain mapping.""" + return f"{self.prefix}uuid:{uuid}" + + # ========================================================================= + # Context Buffer Operations + # ========================================================================= + + def buffer_append(self, thread_id: str, slot_data: bytes) -> int: + """Append a slot to a thread's buffer.""" + key = self._buffer_key(thread_id) + + # RPUSH returns new length, so index = length - 1 + new_len = self._client.rpush(key, slot_data) + + # Set/refresh TTL + if self.ttl > 0: + self._client.expire(key, self.ttl) + + return new_len - 1 + + def buffer_get_thread(self, thread_id: str) -> List[bytes]: + """Get all slots for a thread.""" + key = self._buffer_key(thread_id) + # LRANGE returns all elements (0, -1 = full list) + result = self._client.lrange(key, 0, -1) + return result if result else [] + + def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]: + """Get a specific slot by index.""" + key = self._buffer_key(thread_id) + # LINDEX returns element at index, or None + return self._client.lindex(key, index) + + def buffer_thread_len(self, thread_id: str) -> int: + """Get number of slots in a thread.""" + key = self._buffer_key(thread_id) + return self._client.llen(key) + + def 
buffer_thread_exists(self, thread_id: str) -> bool: + """Check if a thread has any slots.""" + key = self._buffer_key(thread_id) + return self._client.exists(key) > 0 + + def buffer_delete_thread(self, thread_id: str) -> bool: + """Delete all slots for a thread.""" + key = self._buffer_key(thread_id) + deleted = self._client.delete(key) + return deleted > 0 + + def buffer_list_threads(self) -> List[str]: + """List all thread IDs with slots.""" + pattern = f"{self.prefix}buffer:*" + keys = self._client.keys(pattern) + prefix_len = len(f"{self.prefix}buffer:") + return [(k.decode() if isinstance(k, bytes) else k)[prefix_len:] for k in keys] + + def buffer_clear(self) -> None: + """Clear all buffer data.""" + pattern = f"{self.prefix}buffer:*" + keys = self._client.keys(pattern) + if keys: + self._client.delete(*keys) + + # ========================================================================= + # Thread Registry Operations + # ========================================================================= + + def registry_set(self, chain: str, uuid: str) -> None: + """Set bidirectional mapping: chain ↔ uuid.""" + chain_key = self._chain_key(chain) + uuid_key = self._uuid_key(uuid) + + # Use pipeline for atomicity + pipe = self._client.pipeline() + pipe.set(chain_key, uuid.encode()) + pipe.set(uuid_key, chain.encode()) + + if self.ttl > 0: + pipe.expire(chain_key, self.ttl) + pipe.expire(uuid_key, self.ttl) + + pipe.execute() + + def registry_get_uuid(self, chain: str) -> Optional[str]: + """Get UUID for a chain.""" + key = self._chain_key(chain) + value = self._client.get(key) + if value: + return value.decode() if isinstance(value, bytes) else value + return None + + def registry_get_chain(self, uuid: str) -> Optional[str]: + """Get chain for a UUID.""" + key = self._uuid_key(uuid) + value = self._client.get(key) + if value: + return value.decode() if isinstance(value, bytes) else value + return None + + def registry_delete(self, uuid: str) -> bool: + """Delete mapping by UUID.""" + 
uuid_key = self._uuid_key(uuid) + + # First get chain to delete both directions + chain = self._client.get(uuid_key) + if not chain: + return False + + if isinstance(chain, bytes): + chain = chain.decode() + + chain_key = self._chain_key(chain) + + # Delete both keys + deleted = self._client.delete(uuid_key, chain_key) + return deleted > 0 + + def registry_list_all(self) -> Dict[str, str]: + """Get all UUID → chain mappings.""" + pattern = f"{self.prefix}uuid:*" + keys = self._client.keys(pattern) + + result = {} + prefix_len = len(f"{self.prefix}uuid:") + + for key in keys: + if isinstance(key, bytes): + key = key.decode() + uuid = key[prefix_len:] + chain = self._client.get(key) + if chain: + if isinstance(chain, bytes): + chain = chain.decode() + result[uuid] = chain + + return result + + def registry_clear(self) -> None: + """Clear all registry data.""" + chain_pattern = f"{self.prefix}chain:*" + uuid_pattern = f"{self.prefix}uuid:*" + + chain_keys = self._client.keys(chain_pattern) + uuid_keys = self._client.keys(uuid_pattern) + + all_keys = list(chain_keys) + list(uuid_keys) + if all_keys: + self._client.delete(*all_keys) + + # ========================================================================= + # Lifecycle + # ========================================================================= + + def close(self) -> None: + """Close Redis connection.""" + if self._client: + self._client.close() + logger.info("Redis backend closed") + + # ========================================================================= + # Health Check + # ========================================================================= + + def ping(self) -> bool: + """Check if Redis is reachable.""" + try: + return self._client.ping() + except Exception: + return False + + def info(self) -> Dict[str, int]: + """Get backend statistics.""" + buffer_keys = len(self._client.keys(f"{self.prefix}buffer:*")) + uuid_keys = len(self._client.keys(f"{self.prefix}uuid:*")) + + return { + "buffer_threads": 
buffer_keys, + "registry_entries": uuid_keys, + "ttl_seconds": self.ttl, + } diff --git a/xml_pipeline/memory/shared_backend.py b/xml_pipeline/memory/shared_backend.py new file mode 100644 index 0000000..13750e7 --- /dev/null +++ b/xml_pipeline/memory/shared_backend.py @@ -0,0 +1,275 @@ +""" +shared_backend.py — Abstract backend interface for shared state. + +Provides a Protocol defining operations needed by ContextBuffer and ThreadRegistry +to work across processes. Implementations include: +- InMemoryBackend: Default, single-process (current behavior) +- ManagerBackend: multiprocessing.Manager for local multi-process +- RedisBackend: Redis for distributed/multi-tenant scenarios + +All backends support: +- Context buffer operations (thread slots) +- Thread registry operations (UUID ↔ chain mapping) +- TTL for automatic garbage collection +""" + +from __future__ import annotations + +import pickle +from abc import abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable + + +@dataclass +class BackendConfig: + """Configuration for shared backend selection.""" + + backend_type: str = "memory" # "memory", "manager", "redis" + + # Redis-specific config + redis_url: str = "redis://localhost:6379" + redis_prefix: str = "xp:" + redis_ttl: int = 86400 # 24 hours default + + # Manager-specific config (for local multiprocess) + manager_address: Optional[Tuple[str, int]] = None + manager_authkey: Optional[bytes] = None + + # Common limits + max_slots_per_thread: int = 10000 + max_threads: int = 1000 + + +@runtime_checkable +class SharedBackend(Protocol): + """ + Protocol for shared state backends. + + All methods should be synchronous (blocking) for simplicity. + Async wrappers can be added at the caller level if needed. 
+ """ + + # ========================================================================= + # Context Buffer Operations + # ========================================================================= + + @abstractmethod + def buffer_append( + self, + thread_id: str, + slot_data: bytes, # Pickled BufferSlot + ) -> int: + """ + Append a slot to a thread's buffer. + + Args: + thread_id: UUID of the thread + slot_data: Pickled BufferSlot bytes + + Returns: + Index of the appended slot (0-based) + """ + ... + + @abstractmethod + def buffer_get_thread(self, thread_id: str) -> List[bytes]: + """ + Get all slots for a thread. + + Args: + thread_id: UUID of the thread + + Returns: + List of pickled BufferSlot bytes (in order) + """ + ... + + @abstractmethod + def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]: + """ + Get a specific slot by index. + + Args: + thread_id: UUID of the thread + index: Slot index (0-based) + + Returns: + Pickled BufferSlot bytes, or None if not found + """ + ... + + @abstractmethod + def buffer_thread_len(self, thread_id: str) -> int: + """Get number of slots in a thread.""" + ... + + @abstractmethod + def buffer_thread_exists(self, thread_id: str) -> bool: + """Check if a thread has any slots.""" + ... + + @abstractmethod + def buffer_delete_thread(self, thread_id: str) -> bool: + """Delete all slots for a thread. Returns True if thread existed.""" + ... + + @abstractmethod + def buffer_list_threads(self) -> List[str]: + """List all thread IDs with slots.""" + ... + + @abstractmethod + def buffer_clear(self) -> None: + """Clear all buffer data (for testing).""" + ... + + # ========================================================================= + # Thread Registry Operations + # ========================================================================= + + @abstractmethod + def registry_set(self, chain: str, uuid: str) -> None: + """ + Set bidirectional mapping: chain ↔ uuid. 
+ + Args: + chain: Dot-separated call chain (e.g., "console.router.greeter") + uuid: UUID string for this chain + """ + ... + + @abstractmethod + def registry_get_uuid(self, chain: str) -> Optional[str]: + """ + Get UUID for a chain. + + Args: + chain: Call chain to look up + + Returns: + UUID string, or None if not found + """ + ... + + @abstractmethod + def registry_get_chain(self, uuid: str) -> Optional[str]: + """ + Get chain for a UUID. + + Args: + uuid: UUID to look up + + Returns: + Chain string, or None if not found + """ + ... + + @abstractmethod + def registry_delete(self, uuid: str) -> bool: + """ + Delete mapping by UUID. + + Removes both chain→uuid and uuid→chain mappings. + + Returns: + True if mapping existed + """ + ... + + @abstractmethod + def registry_list_all(self) -> Dict[str, str]: + """ + Get all UUID → chain mappings. + + Returns: + Dict mapping UUID to chain + """ + ... + + @abstractmethod + def registry_clear(self) -> None: + """Clear all registry data (for testing).""" + ... + + # ========================================================================= + # Lifecycle + # ========================================================================= + + @abstractmethod + def close(self) -> None: + """Close connections and clean up resources.""" + ... 
+ + +# ============================================================================= +# Serialization Helpers +# ============================================================================= + + +def serialize_slot(slot: Any) -> bytes: + """Serialize a BufferSlot to bytes using pickle.""" + return pickle.dumps(slot) + + +def deserialize_slot(data: bytes) -> Any: + """Deserialize bytes back to a BufferSlot.""" + return pickle.loads(data) + + +# ============================================================================= +# Factory +# ============================================================================= + +_backend: Optional[SharedBackend] = None + + +def get_shared_backend(config: Optional[BackendConfig] = None) -> SharedBackend: + """ + Get or create the global shared backend. + + Backend selection: + 1. If Redis URL configured and redis available → RedisBackend + 2. If Manager configured → ManagerBackend + 3. Otherwise → InMemoryBackend (default) + + Thread-safe singleton pattern. 
+ """ + global _backend + + if _backend is not None: + return _backend + + if config is None: + config = BackendConfig() + + if config.backend_type == "redis": + from xml_pipeline.memory.redis_backend import RedisBackend + + _backend = RedisBackend( + url=config.redis_url, + prefix=config.redis_prefix, + ttl=config.redis_ttl, + ) + elif config.backend_type == "manager": + from xml_pipeline.memory.manager_backend import ManagerBackend + + _backend = ManagerBackend( + address=config.manager_address, + authkey=config.manager_authkey, + ) + else: + # Default: in-memory backend + from xml_pipeline.memory.memory_backend import InMemoryBackend + + _backend = InMemoryBackend() + + return _backend + + +def reset_shared_backend() -> None: + """Reset the global backend (for testing).""" + global _backend + if _backend is not None: + _backend.close() + _backend = None diff --git a/xml_pipeline/message_bus/stream_pump.py b/xml_pipeline/message_bus/stream_pump.py index 15b7bb6..ce0bb5c 100644 --- a/xml_pipeline/message_bus/stream_pump.py +++ b/xml_pipeline/message_bus/stream_pump.py @@ -9,12 +9,21 @@ The pipeline is just a composition of stream operators. Dependencies: pip install aiostream + +CPU-Bound Handlers: + Handlers marked with `cpu_bound: true` are dispatched to a + ProcessPoolExecutor instead of running in the main event loop. + This prevents long-running handlers from blocking other messages. + + Requires shared backend (Redis or Manager) for cross-process data access. 
""" from __future__ import annotations import asyncio import importlib +import logging +from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass, field from pathlib import Path from typing import AsyncIterable, Callable, List, Dict, Any, Optional @@ -34,6 +43,8 @@ from xml_pipeline.message_bus.thread_registry import get_registry from xml_pipeline.message_bus.todo_registry import get_todo_registry from xml_pipeline.memory import get_context_buffer +pump_logger = logging.getLogger(__name__) + # ============================================================================ # Configuration (same as before) @@ -49,6 +60,7 @@ class ListenerConfig: peers: List[str] = field(default_factory=list) broadcast: bool = False prompt: str = "" # System prompt for LLM agents (loaded into PromptRegistry) + cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True payload_class: type = field(default=None, repr=False) handler: Callable = field(default=None, repr=False) @@ -69,6 +81,16 @@ class OrganismConfig: # LLM configuration (optional) llm_config: Dict[str, Any] = field(default_factory=dict) + # Process pool configuration (for cpu_bound handlers) + process_pool_workers: int = 4 + process_pool_max_tasks_per_child: int = 100 + process_pool_enabled: bool = False + + # Backend configuration (for shared state) + backend_type: str = "memory" # "memory", "manager", "redis" + backend_redis_url: str = "redis://localhost:6379" + backend_redis_prefix: str = "xp:" + @dataclass class Listener: @@ -79,6 +101,8 @@ class Listener: is_agent: bool = False peers: List[str] = field(default_factory=list) broadcast: bool = False + cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True + handler_path: str = "" # Import path for worker process schema: etree.XMLSchema = field(default=None, repr=False) root_tag: str = "" usage_instructions: str = "" # Generated at registration for LLM agents @@ -170,6 +194,10 @@ class StreamPump: The entire flow is a single 
composable stream pipeline. Fan-out is natural via flatmap. Concurrency is controlled via task_limit. + + CPU-bound handlers can be dispatched to a ProcessPoolExecutor by + marking them with `cpu_bound: true` in config. This requires a + shared backend (Redis or Manager) for cross-process data access. """ def __init__(self, config: OrganismConfig): @@ -188,6 +216,29 @@ class StreamPump: # Shutdown control self._running = False + # Process pool for cpu_bound handlers + self._process_pool: Optional[ProcessPoolExecutor] = None + if config.process_pool_enabled: + self._process_pool = ProcessPoolExecutor( + max_workers=config.process_pool_workers, + max_tasks_per_child=config.process_pool_max_tasks_per_child, + ) + pump_logger.info( + f"ProcessPool initialized: {config.process_pool_workers} workers" + ) + + # Shared backend for cross-process state + self._shared_backend = None + if config.backend_type != "memory": + from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend + backend_config = BackendConfig( + backend_type=config.backend_type, + redis_url=config.backend_redis_url, + redis_prefix=config.backend_redis_prefix, + ) + self._shared_backend = get_shared_backend(backend_config) + pump_logger.info(f"Shared backend: {config.backend_type}") + # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ @@ -203,6 +254,8 @@ class StreamPump: is_agent=lc.is_agent, peers=lc.peers, broadcast=lc.broadcast, + cpu_bound=lc.cpu_bound, + handler_path=lc.handler_path, # For worker process import schema=self._generate_schema(lc.payload_class), root_tag=root_tag, ) @@ -420,7 +473,15 @@ class StreamPump: ) payload_ref = state.payload - response = await listener.handler(payload_ref, metadata) + # Dispatch to handler - either in-process or via ProcessPool + if listener.cpu_bound and self._process_pool and self._shared_backend: + response = await 
self._dispatch_to_process_pool( + listener=listener, + payload=payload_ref, + metadata=metadata, + ) + else: + response = await listener.handler(payload_ref, metadata) # None means "no response needed" - don't re-inject if response is None: @@ -549,6 +610,94 @@ class StreamPump: """ return envelope.encode('utf-8') + async def _dispatch_to_process_pool( + self, + listener: Listener, + payload: Any, + metadata: HandlerMetadata, + ) -> Optional[HandlerResponse]: + """ + Dispatch handler to ProcessPoolExecutor for CPU-bound execution. + + This offloads work to a separate process to avoid blocking + the main event loop. + + Args: + listener: The target listener + payload: The @xmlify dataclass payload + metadata: Handler metadata + + Returns: + HandlerResponse or None (same as direct handler call) + """ + from xml_pipeline.message_bus.worker import ( + WorkerTask, + store_task_data, + fetch_response, + cleanup_task_data, + execute_handler, + ) + + assert self._process_pool is not None + assert self._shared_backend is not None + + # Store payload and metadata in shared backend + payload_uuid, metadata_uuid = store_task_data( + self._shared_backend, payload, metadata + ) + + # Create worker task + task = WorkerTask( + thread_uuid=metadata.thread_id, + payload_uuid=payload_uuid, + handler_path=listener.handler_path, + metadata_uuid=metadata_uuid, + listener_name=listener.name, + is_agent=listener.is_agent, + peers=list(listener.peers), + ) + + # Backend config for worker process + backend_config = { + "backend_type": self.config.backend_type, + "redis_url": self.config.backend_redis_url, + "redis_prefix": self.config.backend_redis_prefix, + } + + try: + # Submit to process pool and await result + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + self._process_pool, + execute_handler, + task, + backend_config, + ) + + if not result.success: + pump_logger.error( + f"Worker error for {listener.name}: {result.error}" + ) + if result.error_traceback: + 
pump_logger.debug(f"Traceback: {result.error_traceback}") + return None + + # Fetch response from shared backend + if result.response_uuid: + response = fetch_response(self._shared_backend, result.response_uuid) + return response + + return None + + finally: + # Clean up task data from backend + cleanup_task_data( + self._shared_backend, + payload_uuid, + metadata_uuid, + result.response_uuid if 'result' in dir() and result.success else None, + ) + async def _reinject_responses(self, state: MessageState) -> None: """Push handler responses back into the queue for next iteration.""" await self.queue.put(state) @@ -714,10 +863,16 @@ class StreamPump: await self.queue.put(state) async def shutdown(self) -> None: - """Graceful shutdown — wait for queue to drain.""" + """Graceful shutdown — wait for queue to drain and close resources.""" self._running = False await self.queue.join() + # Shutdown process pool if active + if self._process_pool: + self._process_pool.shutdown(wait=True) + pump_logger.info("ProcessPool shutdown complete") + self._process_pool = None + # ============================================================================ # Config Loader (same as before) @@ -733,6 +888,19 @@ class ConfigLoader: @classmethod def _parse(cls, raw: dict) -> OrganismConfig: org = raw.get("organism", {}) + + # Parse process pool config + pool = raw.get("process_pool", {}) + process_pool_enabled = pool.get("enabled", False) if pool else False + process_pool_workers = pool.get("workers", 4) if pool else 4 + process_pool_max_tasks = pool.get("max_tasks_per_child", 100) if pool else 100 + + # Parse backend config + backend = raw.get("backend", {}) + backend_type = backend.get("type", "memory") if backend else "memory" + backend_redis_url = backend.get("redis_url", "redis://localhost:6379") if backend else "redis://localhost:6379" + backend_redis_prefix = backend.get("redis_prefix", "xp:") if backend else "xp:" + config = OrganismConfig( name=org.get("name", "unnamed"), 
identity_path=org.get("identity", ""), @@ -742,6 +910,12 @@ class ConfigLoader: max_concurrent_handlers=raw.get("max_concurrent_handlers", 20), max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5), llm_config=raw.get("llm", {}), + process_pool_enabled=process_pool_enabled, + process_pool_workers=process_pool_workers, + process_pool_max_tasks_per_child=process_pool_max_tasks, + backend_type=backend_type, + backend_redis_url=backend_redis_url, + backend_redis_prefix=backend_redis_prefix, ) for entry in raw.get("listeners", []): @@ -762,6 +936,7 @@ class ConfigLoader: peers=raw.get("peers", []), broadcast=raw.get("broadcast", False), prompt=raw.get("prompt", ""), + cpu_bound=raw.get("cpu_bound", False), ) @classmethod diff --git a/xml_pipeline/message_bus/thread_registry.py b/xml_pipeline/message_bus/thread_registry.py index 580eed9..4f34c64 100644 --- a/xml_pipeline/message_bus/thread_registry.py +++ b/xml_pipeline/message_bus/thread_registry.py @@ -14,15 +14,26 @@ Response routing: 2. Prunes the last segment (the responder) 3. Routes to the new last segment (the caller) 4. Updates/cleans up the registry + +For multi-process deployments, the registry can use a shared backend: + from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig + + config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379") + backend = get_shared_backend(config) + registry = get_registry(backend=backend) """ -import uuid +from __future__ import annotations + +import uuid as uuid_module from dataclasses import dataclass, field -from typing import Dict, Optional, Tuple +from typing import Dict, Optional, Tuple, TYPE_CHECKING import threading +if TYPE_CHECKING: + from xml_pipeline.memory.shared_backend import SharedBackend + -@dataclass class ThreadRegistry: """ Bidirectional mapping between UUIDs and call chains. @@ -32,12 +43,35 @@ class ThreadRegistry: The registry maintains a root thread established at boot time. 
All external messages without a known parent are registered as children of the root thread. + + Supports two storage modes: + 1. Local mode (default): Uses in-process dictionaries + 2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access """ - _chain_to_uuid: Dict[str, str] = field(default_factory=dict) - _uuid_to_chain: Dict[str, str] = field(default_factory=dict) - _lock: threading.Lock = field(default_factory=threading.Lock) - _root_uuid: Optional[str] = field(default=None) - _root_chain: str = field(default="system") + + def __init__(self, backend: Optional[SharedBackend] = None): + """ + Initialize thread registry. + + Args: + backend: Optional shared backend for cross-process storage. + If None, uses in-process storage (original behavior). + """ + self._backend = backend + + # Local storage (used when no backend) + self._chain_to_uuid: Dict[str, str] = {} + self._uuid_to_chain: Dict[str, str] = {} + self._lock = threading.Lock() + + # Root thread tracking + self._root_uuid: Optional[str] = None + self._root_chain: str = "system" + + @property + def is_shared(self) -> bool: + """Return True if using shared backend.""" + return self._backend is not None def initialize_root(self, organism_name: str = "organism") -> str: """ @@ -52,16 +86,36 @@ class ThreadRegistry: Returns: UUID for the root thread """ + if self._backend is not None: + return self._initialize_root_shared(organism_name) + with self._lock: if self._root_uuid is not None: return self._root_uuid self._root_chain = f"system.{organism_name}" - self._root_uuid = str(uuid.uuid4()) + self._root_uuid = str(uuid_module.uuid4()) self._chain_to_uuid[self._root_chain] = self._root_uuid self._uuid_to_chain[self._root_uuid] = self._root_chain return self._root_uuid + def _initialize_root_shared(self, organism_name: str) -> str: + """Initialize root in shared backend.""" + assert self._backend is not None + + self._root_chain = f"system.{organism_name}" + + # Check if root already exists in 
backend + existing_uuid = self._backend.registry_get_uuid(self._root_chain) + if existing_uuid: + self._root_uuid = existing_uuid + return existing_uuid + + # Create new root + self._root_uuid = str(uuid_module.uuid4()) + self._backend.registry_set(self._root_chain, self._root_uuid) + return self._root_uuid + @property def root_uuid(self) -> Optional[str]: """Get the root thread UUID (None if not initialized).""" @@ -82,11 +136,19 @@ class ThreadRegistry: Returns: UUID string for this chain """ + if self._backend is not None: + existing = self._backend.registry_get_uuid(chain) + if existing: + return existing + new_uuid = str(uuid_module.uuid4()) + self._backend.registry_set(chain, new_uuid) + return new_uuid + with self._lock: if chain in self._chain_to_uuid: return self._chain_to_uuid[chain] - new_uuid = str(uuid.uuid4()) + new_uuid = str(uuid_module.uuid4()) self._chain_to_uuid[chain] = new_uuid self._uuid_to_chain[new_uuid] = chain return new_uuid @@ -101,6 +163,9 @@ class ThreadRegistry: Returns: Chain string, or None if not found """ + if self._backend is not None: + return self._backend.registry_get_chain(thread_id) + with self._lock: return self._uuid_to_chain.get(thread_id) @@ -115,6 +180,9 @@ class ThreadRegistry: Returns: UUID for the extended chain """ + if self._backend is not None: + return self._extend_chain_shared(current_uuid, next_hop) + with self._lock: current_chain = self._uuid_to_chain.get(current_uuid, "") if current_chain: @@ -127,11 +195,31 @@ class ThreadRegistry: return self._chain_to_uuid[new_chain] # Create new UUID for extended chain - new_uuid = str(uuid.uuid4()) + new_uuid = str(uuid_module.uuid4()) self._chain_to_uuid[new_chain] = new_uuid self._uuid_to_chain[new_uuid] = new_chain return new_uuid + def _extend_chain_shared(self, current_uuid: str, next_hop: str) -> str: + """Extend chain in shared backend.""" + assert self._backend is not None + + current_chain = self._backend.registry_get_chain(current_uuid) or "" + if 
current_chain: + new_chain = f"{current_chain}.{next_hop}" + else: + new_chain = next_hop + + # Check if extended chain already exists + existing = self._backend.registry_get_uuid(new_chain) + if existing: + return existing + + # Create new UUID for extended chain + new_uuid = str(uuid_module.uuid4()) + self._backend.registry_set(new_chain, new_uuid) + return new_uuid + def prune_for_response(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]: """ Prune chain for a response and get the target. @@ -147,6 +235,9 @@ class ThreadRegistry: Returns: Tuple of (target_listener, new_thread_uuid) or (None, None) if chain exhausted """ + if self._backend is not None: + return self._prune_for_response_shared(thread_id) + with self._lock: chain = self._uuid_to_chain.get(thread_id) if not chain: @@ -168,15 +259,40 @@ class ThreadRegistry: if pruned_chain in self._chain_to_uuid: new_uuid = self._chain_to_uuid[pruned_chain] else: - new_uuid = str(uuid.uuid4()) + new_uuid = str(uuid_module.uuid4()) self._chain_to_uuid[pruned_chain] = new_uuid self._uuid_to_chain[new_uuid] = pruned_chain - # Clean up old UUID (optional - could keep for debugging) - # self._cleanup_uuid(thread_id) - return target, new_uuid + def _prune_for_response_shared(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]: + """Prune chain in shared backend.""" + assert self._backend is not None + + chain = self._backend.registry_get_chain(thread_id) + if not chain: + return None, None + + parts = chain.split(".") + if len(parts) <= 1: + # Chain exhausted + self._backend.registry_delete(thread_id) + return None, None + + # Prune last segment + pruned_parts = parts[:-1] + target = pruned_parts[-1] + pruned_chain = ".".join(pruned_parts) + + # Get or create UUID for pruned chain + existing = self._backend.registry_get_uuid(pruned_chain) + if existing: + return target, existing + + new_uuid = str(uuid_module.uuid4()) + self._backend.registry_set(pruned_chain, new_uuid) + return target, new_uuid + def 
start_chain(self, initiator: str, target: str) -> str: """ Start a new call chain. @@ -208,6 +324,9 @@ class ThreadRegistry: Returns: The same thread_id (now registered) """ + if self._backend is not None: + return self._register_thread_shared(thread_id, initiator, target) + with self._lock: # Check if UUID already registered (shouldn't happen, but be safe) if thread_id in self._uuid_to_chain: @@ -230,6 +349,29 @@ class ThreadRegistry: self._uuid_to_chain[thread_id] = chain return thread_id + def _register_thread_shared(self, thread_id: str, initiator: str, target: str) -> str: + """Register thread in shared backend.""" + assert self._backend is not None + + # Check if UUID already registered + if self._backend.registry_get_chain(thread_id): + return thread_id + + # Build chain rooted at system root + if self._root_uuid is not None: + chain = f"{self._root_chain}.{initiator}.{target}" + else: + chain = f"{initiator}.{target}" + + # Check if chain already has a different UUID + existing = self._backend.registry_get_uuid(chain) + if existing: + return existing + + # Register the external UUID to this chain + self._backend.registry_set(chain, thread_id) + return thread_id + def _cleanup_uuid(self, thread_id: str) -> None: """Remove a UUID mapping (internal, call with lock held).""" chain = self._uuid_to_chain.pop(thread_id, None) @@ -238,22 +380,65 @@ class ThreadRegistry: def cleanup(self, thread_id: str) -> None: """Explicitly clean up a thread UUID.""" + if self._backend is not None: + self._backend.registry_delete(thread_id) + return + with self._lock: self._cleanup_uuid(thread_id) def debug_dump(self) -> Dict[str, str]: """Return current mappings for debugging.""" + if self._backend is not None: + return self._backend.registry_list_all() + with self._lock: return dict(self._uuid_to_chain) + def clear(self) -> None: + """Clear all thread mappings (for testing only).""" + if self._backend is not None: + self._backend.registry_clear() + self._root_uuid = None + 
self._root_chain = "system" + return + + with self._lock: + self._chain_to_uuid.clear() + self._uuid_to_chain.clear() + self._root_uuid = None + self._root_chain = "system" + # Global registry instance _registry: Optional[ThreadRegistry] = None +_registry_lock = threading.Lock() -def get_registry() -> ThreadRegistry: - """Get the global thread registry.""" +def get_registry(backend: Optional[SharedBackend] = None) -> ThreadRegistry: + """ + Get the global thread registry. + + Args: + backend: Optional shared backend for cross-process storage. + Only used on first call (when creating the singleton). + Subsequent calls return the existing singleton. + + Returns: + Global ThreadRegistry instance. + """ global _registry if _registry is None: - _registry = ThreadRegistry() + with _registry_lock: + if _registry is None: + _registry = ThreadRegistry(backend=backend) return _registry + + +def reset_registry() -> None: + """Reset the global thread registry (for testing).""" + global _registry + with _registry_lock: + if _registry is not None: + _registry.clear() + _registry = None diff --git a/xml_pipeline/message_bus/worker.py b/xml_pipeline/message_bus/worker.py new file mode 100644 index 0000000..74fadda --- /dev/null +++ b/xml_pipeline/message_bus/worker.py @@ -0,0 +1,339 @@ +""" +worker.py — Worker process entry point for CPU-bound handler dispatch. + +This module provides the entry point for handler execution in worker processes +when using ProcessPoolExecutor. Workers communicate with the main process via +the shared backend (Redis or Manager). 
+ +Architecture: + Main Process (StreamPump) + │ + │ Submits WorkerTask to ProcessPoolExecutor + │ + ▼ + Worker Process (this module) + │ + ├── Imports handler module + ├── Fetches payload from shared backend + ├── Executes handler + ├── Stores response in shared backend + └── Returns WorkerResult + +Key design decisions: +- Minimal IPC payload: Only UUIDs and module paths cross process boundary +- Handler state via shared backend: Workers fetch/store data in Redis +- Process reuse: Workers are reused across tasks (max_tasks_per_child for restart) +- Error isolation: Handler exceptions don't crash the worker pool + +Usage: + # Main process submits task + from concurrent.futures import ProcessPoolExecutor + from xml_pipeline.message_bus.worker import execute_handler, WorkerTask + + pool = ProcessPoolExecutor(max_workers=4) + task = WorkerTask( + thread_uuid="550e8400-...", + payload_uuid="6ba7b810-...", + handler_path="handlers.librarian.handle_query", + metadata_uuid="7c9e6679-...", + ) + future = pool.submit(execute_handler, task) + result = future.result() # WorkerResult +""" + +from __future__ import annotations + +import importlib +import logging +import traceback +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from xml_pipeline.memory.shared_backend import SharedBackend + +logger = logging.getLogger(__name__) + + +@dataclass +class WorkerTask: + """ + Task submitted to worker process. + + Contains only UUIDs and paths — actual data lives in shared backend. + This minimizes IPC overhead and allows large payloads. 
+ """ + + thread_uuid: str + payload_uuid: str # UUID for looking up serialized payload + handler_path: str # e.g., "handlers.librarian.handle_query" + metadata_uuid: str # UUID for looking up serialized HandlerMetadata + listener_name: str # Name of the listener (for logging/response injection) + is_agent: bool = False # Whether this is an agent handler + peers: list[str] = field(default_factory=list) # Allowed peers (for agents) + + +@dataclass +class WorkerResult: + """ + Result from worker process. + + Contains UUID for response stored in backend, or error information. + """ + + success: bool + response_uuid: Optional[str] = None # UUID for serialized HandlerResponse + error: Optional[str] = None # Error message if failed + error_traceback: Optional[str] = None # Full traceback for debugging + elapsed_ms: float = 0.0 # Execution time + + +# Keys for storing task data in shared backend +def _payload_key(uuid: str) -> str: + return f"worker:payload:{uuid}" + + +def _metadata_key(uuid: str) -> str: + return f"worker:metadata:{uuid}" + + +def _response_key(uuid: str) -> str: + return f"worker:response:{uuid}" + + +def _import_handler(handler_path: str) -> Callable: + """ + Import and return handler function from path. + + Args: + handler_path: Dotted path like "handlers.librarian.handle_query" + + Returns: + The handler function + + Raises: + ImportError: If module not found + AttributeError: If function not found in module + """ + module_path, func_name = handler_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, func_name) + + +def store_task_data( + backend: SharedBackend, + payload: Any, + metadata: Any, + ttl: int = 3600, +) -> tuple[str, str]: + """ + Store payload and metadata in shared backend for worker access. 
+ + Args: + backend: Shared backend instance + payload: The @xmlify dataclass payload + metadata: HandlerMetadata instance + ttl: Time-to-live in seconds (default 1 hour) + + Returns: + Tuple of (payload_uuid, metadata_uuid) + """ + import pickle + + payload_uuid = str(uuid.uuid4()) + metadata_uuid = str(uuid.uuid4()) + + # Serialize and store + payload_bytes = pickle.dumps(payload) + metadata_bytes = pickle.dumps(metadata) + + # Store as buffer slots (reusing the backend interface) + # We use a special "worker" prefix to distinguish from regular slots + backend.buffer_append(f"worker:payload:{payload_uuid}", payload_bytes) + backend.buffer_append(f"worker:metadata:{metadata_uuid}", metadata_bytes) + + return payload_uuid, metadata_uuid + + +def fetch_task_data( + backend: SharedBackend, + payload_uuid: str, + metadata_uuid: str, +) -> tuple[Any, Any]: + """ + Fetch payload and metadata from shared backend. + + Args: + backend: Shared backend instance + payload_uuid: UUID of stored payload + metadata_uuid: UUID of stored metadata + + Returns: + Tuple of (payload, metadata) + """ + import pickle + + payload_data = backend.buffer_get_slot(f"worker:payload:{payload_uuid}", 0) + metadata_data = backend.buffer_get_slot(f"worker:metadata:{metadata_uuid}", 0) + + if payload_data is None: + raise ValueError(f"Payload not found: {payload_uuid}") + if metadata_data is None: + raise ValueError(f"Metadata not found: {metadata_uuid}") + + payload = pickle.loads(payload_data) + metadata = pickle.loads(metadata_data) + + return payload, metadata + + +def store_response( + backend: SharedBackend, + response: Any, +) -> str: + """ + Store handler response in shared backend for main process retrieval. 
+ + Args: + backend: Shared backend instance + response: HandlerResponse or None + + Returns: + Response UUID + """ + import pickle + + response_uuid = str(uuid.uuid4()) + response_bytes = pickle.dumps(response) + backend.buffer_append(f"worker:response:{response_uuid}", response_bytes) + + return response_uuid + + +def fetch_response( + backend: SharedBackend, + response_uuid: str, +) -> Any: + """ + Fetch handler response from shared backend. + + Args: + backend: Shared backend instance + response_uuid: UUID of stored response + + Returns: + HandlerResponse or None + """ + import pickle + + response_data = backend.buffer_get_slot(f"worker:response:{response_uuid}", 0) + if response_data is None: + raise ValueError(f"Response not found: {response_uuid}") + + return pickle.loads(response_data) + + +def cleanup_task_data( + backend: SharedBackend, + payload_uuid: str, + metadata_uuid: str, + response_uuid: Optional[str] = None, +) -> None: + """ + Clean up temporary task data from shared backend. + + Call after response has been processed by main process. + """ + backend.buffer_delete_thread(f"worker:payload:{payload_uuid}") + backend.buffer_delete_thread(f"worker:metadata:{metadata_uuid}") + if response_uuid: + backend.buffer_delete_thread(f"worker:response:{response_uuid}") + + +def execute_handler( + task: WorkerTask, + backend_config: Optional[dict] = None, +) -> WorkerResult: + """ + Execute a handler in the worker process. + + This is the entry point called by ProcessPoolExecutor. 
+ + Args: + task: WorkerTask containing UUIDs and handler path + backend_config: Optional backend configuration dict + (if None, uses default shared backend) + + Returns: + WorkerResult with response UUID or error + """ + import asyncio + import time + + start_time = time.monotonic() + + try: + # Get shared backend + from xml_pipeline.memory.shared_backend import ( + BackendConfig, + get_shared_backend, + ) + + if backend_config: + config = BackendConfig(**backend_config) + backend = get_shared_backend(config) + else: + backend = get_shared_backend() + + # Fetch payload and metadata + payload, metadata = fetch_task_data( + backend, task.payload_uuid, task.metadata_uuid + ) + + # Import handler + handler = _import_handler(task.handler_path) + + # Execute handler (async) + # Workers run their own event loop for async handlers + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(handler(payload, metadata)) + finally: + loop.close() + + # Store response + response_uuid = store_response(backend, response) + + elapsed = (time.monotonic() - start_time) * 1000 + + return WorkerResult( + success=True, + response_uuid=response_uuid, + elapsed_ms=elapsed, + ) + + except Exception as e: + elapsed = (time.monotonic() - start_time) * 1000 + logger.exception(f"Handler execution failed: {task.handler_path}") + + return WorkerResult( + success=False, + error=str(e), + error_traceback=traceback.format_exc(), + elapsed_ms=elapsed, + ) + + +def execute_handler_sync( + task: WorkerTask, + backend_config: Optional[dict] = None, +) -> WorkerResult: + """ + Synchronous wrapper for execute_handler. + + Use this when the handler doesn't need async features + (pure computation, no I/O). + """ + return execute_handler(task, backend_config)