Add shared backend for multiprocess pipeline support
Introduces SharedBackend Protocol for cross-process state sharing: - InMemoryBackend: default single-process storage - ManagerBackend: multiprocessing.Manager for local multi-process - RedisBackend: distributed deployments with TTL auto-GC Adds ProcessPoolExecutor support for CPU-bound handlers: - worker.py: worker process entry point - stream_pump.py: cpu_bound handler dispatch - Config: backend and process_pool sections in organism.yaml ContextBuffer and ThreadRegistry now accept optional backend parameter while maintaining full backward compatibility. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f87d9f80e9
commit
6790c7a46c
12 changed files with 2346 additions and 28 deletions
|
|
@ -36,6 +36,50 @@ def pytest_configure(config):
|
|||
# Fixtures available to all tests
|
||||
# ============================================================================
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_singletons():
    """Reset global singletons before and after each test for isolation.

    Each registry lives behind a module-level singleton; clearing both
    before and after each test prevents state leaking between tests.
    Imports are guarded so the fixture still works when an optional
    module is absent.
    """

    def _clear_all() -> None:
        # Prompt registry singleton.
        try:
            from xml_pipeline.platform.prompt_registry import get_prompt_registry
            get_prompt_registry().clear()
        except ImportError:
            pass

        # Context buffer singleton.
        try:
            from xml_pipeline.memory.context_buffer import get_context_buffer
            get_context_buffer().clear()
        except ImportError:
            pass

        # Thread registry singleton. AttributeError is tolerated because a
        # registry implementation without clear() should not break tests.
        try:
            from xml_pipeline.message_bus.thread_registry import get_registry
            get_registry().clear()
        except (ImportError, AttributeError):
            pass

    # Clear registries before test
    _clear_all()

    yield  # Run the test

    # Clear after test too for good measure
    _clear_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_thread_id():
|
||||
"""A valid UUID for testing."""
|
||||
|
|
|
|||
393
tests/test_shared_backend.py
Normal file
393
tests/test_shared_backend.py
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
"""
|
||||
Tests for shared backend implementations.
|
||||
|
||||
Tests InMemoryBackend and ManagerBackend.
|
||||
Redis tests require a running Redis server (skipped if not available).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from dataclasses import dataclass
|
||||
|
||||
from xml_pipeline.memory.memory_backend import InMemoryBackend
|
||||
from xml_pipeline.memory.shared_backend import (
|
||||
BackendConfig,
|
||||
serialize_slot,
|
||||
deserialize_slot,
|
||||
)
|
||||
from xml_pipeline.memory.context_buffer import BufferSlot, SlotMetadata
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
def memory_backend():
    """Yield a fresh InMemoryBackend; close it on teardown."""
    instance = InMemoryBackend()
    yield instance
    instance.close()
|
||||
|
||||
|
||||
@pytest.fixture
def manager_backend():
    """Yield a fresh ManagerBackend; close it on teardown."""
    from xml_pipeline.memory.manager_backend import ManagerBackend

    instance = ManagerBackend()
    yield instance
    instance.close()
|
||||
|
||||
|
||||
def redis_available() -> bool:
    """Return True when a Redis server answers a ping on localhost:6379.

    Any failure — missing redis package, connection refused, timeout —
    is treated as "not available" rather than raised.
    """
    try:
        import redis

        client = redis.from_url("redis://localhost:6379")
        client.ping()
        client.close()
    except Exception:
        return False
    return True
|
||||
|
||||
|
||||
@pytest.fixture
def redis_backend():
    """Yield a fresh RedisBackend (skipped if Redis is not reachable)."""
    if not redis_available():
        pytest.skip("Redis not available")

    from xml_pipeline.memory.redis_backend import RedisBackend

    instance = RedisBackend(
        url="redis://localhost:6379",
        prefix="xp_test:",
        ttl=300,
    )
    # Start from a clean test keyspace.
    instance.buffer_clear()
    instance.registry_clear()

    yield instance

    # Remove test keys and release the connection.
    instance.buffer_clear()
    instance.registry_clear()
    instance.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Sample data
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@dataclass
class SamplePayload:
    """Sample payload for testing."""

    # Arbitrary fields used by serialization round-trip assertions.
    message: str
    value: int
|
||||
|
||||
|
||||
def make_slot(thread_id: str, index: int = 0) -> BufferSlot:
    """Build a BufferSlot carrying a fixed SamplePayload for *thread_id*."""
    meta = SlotMetadata(
        thread_id=thread_id,
        from_id="sender",
        to_id="receiver",
        slot_index=index,
        timestamp="2024-01-15T00:00:00Z",
        payload_type="SamplePayload",
    )
    return BufferSlot(
        payload=SamplePayload(message="hello", value=42),
        metadata=meta,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# InMemoryBackend Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestInMemoryBackend:
    """Tests for InMemoryBackend.

    Exercises both halves of the backend surface: the per-thread buffer
    operations and the chain<->UUID registry operations.
    """

    def test_buffer_append_and_get(self, memory_backend):
        """Test appending and retrieving buffer slots."""
        slot = make_slot("thread-1")
        slot_bytes = serialize_slot(slot)

        # Append — first slot gets index 0
        index = memory_backend.buffer_append("thread-1", slot_bytes)
        assert index == 0

        # Append another — indices are sequential
        index2 = memory_backend.buffer_append("thread-1", slot_bytes)
        assert index2 == 1

        # Get all
        slots = memory_backend.buffer_get_thread("thread-1")
        assert len(slots) == 2

        # Deserialize and verify the round trip preserved the data
        retrieved = deserialize_slot(slots[0])
        assert retrieved.metadata.thread_id == "thread-1"
        assert retrieved.payload.message == "hello"

    def test_buffer_get_slot(self, memory_backend):
        """Test getting specific slot by index."""
        slot = make_slot("thread-1")
        slot_bytes = serialize_slot(slot)

        memory_backend.buffer_append("thread-1", slot_bytes)
        memory_backend.buffer_append("thread-1", slot_bytes)

        # Get specific slot
        data = memory_backend.buffer_get_slot("thread-1", 1)
        assert data is not None

        # Non-existent index returns None rather than raising
        data = memory_backend.buffer_get_slot("thread-1", 999)
        assert data is None

    def test_buffer_thread_exists(self, memory_backend):
        """Test thread existence check."""
        assert not memory_backend.buffer_thread_exists("thread-1")

        slot_bytes = serialize_slot(make_slot("thread-1"))
        memory_backend.buffer_append("thread-1", slot_bytes)

        assert memory_backend.buffer_thread_exists("thread-1")

    def test_buffer_delete_thread(self, memory_backend):
        """Test deleting thread buffer."""
        slot_bytes = serialize_slot(make_slot("thread-1"))
        memory_backend.buffer_append("thread-1", slot_bytes)

        assert memory_backend.buffer_delete_thread("thread-1")
        assert not memory_backend.buffer_thread_exists("thread-1")
        assert not memory_backend.buffer_delete_thread("thread-1")  # Already deleted

    def test_buffer_list_threads(self, memory_backend):
        """Test listing all threads."""
        slot_bytes = serialize_slot(make_slot("thread-1"))
        memory_backend.buffer_append("thread-1", slot_bytes)
        memory_backend.buffer_append("thread-2", slot_bytes)

        threads = memory_backend.buffer_list_threads()
        assert set(threads) == {"thread-1", "thread-2"}

    def test_registry_set_and_get(self, memory_backend):
        """Test registry set and get operations (both lookup directions)."""
        memory_backend.registry_set("a.b.c", "uuid-123")

        # Get UUID from chain
        uuid = memory_backend.registry_get_uuid("a.b.c")
        assert uuid == "uuid-123"

        # Get chain from UUID
        chain = memory_backend.registry_get_chain("uuid-123")
        assert chain == "a.b.c"

    def test_registry_delete(self, memory_backend):
        """Test registry delete removes both lookup directions."""
        memory_backend.registry_set("a.b.c", "uuid-123")

        assert memory_backend.registry_delete("uuid-123")
        assert memory_backend.registry_get_uuid("a.b.c") is None
        assert memory_backend.registry_get_chain("uuid-123") is None
        assert not memory_backend.registry_delete("uuid-123")  # Already deleted

    def test_registry_list_all(self, memory_backend):
        """Test listing all registry entries (keyed by UUID)."""
        memory_backend.registry_set("a.b", "uuid-1")
        memory_backend.registry_set("x.y.z", "uuid-2")

        all_entries = memory_backend.registry_list_all()
        assert all_entries == {"uuid-1": "a.b", "uuid-2": "x.y.z"}

    def test_registry_clear(self, memory_backend):
        """Test clearing registry."""
        memory_backend.registry_set("a.b", "uuid-1")
        memory_backend.registry_clear()

        assert memory_backend.registry_list_all() == {}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ManagerBackend Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestManagerBackend:
    """Tests for ManagerBackend (multiprocessing.Manager)."""

    def test_buffer_append_and_get(self, manager_backend):
        """Test appending and retrieving buffer slots."""
        raw = serialize_slot(make_slot("thread-1"))

        assert manager_backend.buffer_append("thread-1", raw) == 0
        assert len(manager_backend.buffer_get_thread("thread-1")) == 1

    def test_registry_operations(self, manager_backend):
        """Test registry operations via Manager."""
        manager_backend.registry_set("a.b.c", "uuid-123")

        assert manager_backend.registry_get_uuid("a.b.c") == "uuid-123"
        assert manager_backend.registry_get_chain("uuid-123") == "a.b.c"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RedisBackend Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# NOTE: the skipif condition runs at collection time, so a Redis
# connection attempt is made once when this module is imported.
@pytest.mark.skipif(not redis_available(), reason="Redis not available")
class TestRedisBackend:
    """Tests for RedisBackend (requires running Redis)."""

    def test_buffer_append_and_get(self, redis_backend):
        """Test appending and retrieving buffer slots."""
        slot = make_slot("thread-1")
        slot_bytes = serialize_slot(slot)

        index = redis_backend.buffer_append("thread-1", slot_bytes)
        assert index == 0

        slots = redis_backend.buffer_get_thread("thread-1")
        assert len(slots) == 1

        # Round trip through Redis must preserve metadata
        retrieved = deserialize_slot(slots[0])
        assert retrieved.metadata.thread_id == "thread-1"

    def test_buffer_thread_len(self, redis_backend):
        """Test getting thread length."""
        slot_bytes = serialize_slot(make_slot("thread-1"))

        redis_backend.buffer_append("thread-1", slot_bytes)
        redis_backend.buffer_append("thread-1", slot_bytes)
        redis_backend.buffer_append("thread-1", slot_bytes)

        assert redis_backend.buffer_thread_len("thread-1") == 3

    def test_registry_operations(self, redis_backend):
        """Test registry operations via Redis (both lookup directions)."""
        redis_backend.registry_set("console.router.greeter", "uuid-abc")

        uuid = redis_backend.registry_get_uuid("console.router.greeter")
        assert uuid == "uuid-abc"

        chain = redis_backend.registry_get_chain("uuid-abc")
        assert chain == "console.router.greeter"

    def test_registry_delete(self, redis_backend):
        """Test registry delete removes both directions."""
        redis_backend.registry_set("a.b", "uuid-123")

        assert redis_backend.registry_delete("uuid-123")
        assert redis_backend.registry_get_uuid("a.b") is None
        assert redis_backend.registry_get_chain("uuid-123") is None

    def test_ping(self, redis_backend):
        """Test Redis ping."""
        assert redis_backend.ping()

    def test_info(self, redis_backend):
        """Test backend info reports buffer and registry counts."""
        slot_bytes = serialize_slot(make_slot("thread-1"))
        redis_backend.buffer_append("thread-1", slot_bytes)
        redis_backend.registry_set("a.b", "uuid-1")

        info = redis_backend.info()
        assert "buffer_threads" in info
        assert "registry_entries" in info
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration: ContextBuffer with SharedBackend
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestContextBufferWithBackend:
    """Test ContextBuffer using shared backends."""

    def test_context_buffer_with_memory_backend(self):
        """Test ContextBuffer works with in-memory backend."""
        from xml_pipeline.memory.context_buffer import ContextBuffer

        backend = InMemoryBackend()
        buffer = ContextBuffer(backend=backend)

        # Any backend — even in-memory — puts the buffer in shared mode
        assert buffer.is_shared

        # Append slot
        slot = buffer.append(
            thread_id="test-thread",
            payload=SamplePayload(message="hello", value=42),
            from_id="sender",
            to_id="receiver",
        )

        assert slot.thread_id == "test-thread"
        assert slot.payload.message == "hello"

        # Get thread
        thread = buffer.get_thread("test-thread")
        assert thread is not None
        assert len(thread) == 1

        # Get stats — shared mode is reported in the stats payload
        stats = buffer.get_stats()
        assert stats["thread_count"] == 1
        assert stats["backend"] == "shared"

        backend.close()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration: ThreadRegistry with SharedBackend
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestThreadRegistryWithBackend:
    """Test ThreadRegistry using shared backends."""

    def test_thread_registry_with_memory_backend(self):
        """Test ThreadRegistry works with in-memory backend."""
        from xml_pipeline.message_bus.thread_registry import ThreadRegistry

        backend = InMemoryBackend()
        registry = ThreadRegistry(backend=backend)

        assert registry.is_shared

        # Initialize root — root chain is namespaced under "system."
        root_uuid = registry.initialize_root("test-organism")
        assert root_uuid is not None
        assert registry.root_chain == "system.test-organism"

        # Get or create chain
        uuid = registry.get_or_create("a.b.c")
        assert uuid is not None

        # Lookup resolves UUID back to the chain
        chain = registry.lookup(uuid)
        assert chain == "a.b.c"

        # Extend chain appends one segment
        new_uuid = registry.extend_chain(uuid, "d")
        new_chain = registry.lookup(new_uuid)
        assert new_chain == "a.b.c.d"

        # Prune for response: returns the last remaining segment as target
        # and a UUID for the shortened chain
        target, pruned_uuid = registry.prune_for_response(new_uuid)
        assert target == "c"
        assert registry.lookup(pruned_uuid) == "a.b.c"

        backend.close()
|
||||
|
|
@ -65,6 +65,9 @@ class ListenerConfig:
|
|||
allowed_tools: list[str] = field(default_factory=list)
|
||||
blocked_tools: list[str] = field(default_factory=list)
|
||||
|
||||
# Dispatch mode
|
||||
cpu_bound: bool = False # If True, dispatch to ProcessPoolExecutor
|
||||
|
||||
|
||||
@dataclass
|
||||
class ServerConfig:
|
||||
|
|
@ -83,6 +86,42 @@ class AuthConfig:
|
|||
totp_secret_env: str = "ORGANISM_TOTP_SECRET"
|
||||
|
||||
|
||||
@dataclass
class BackendStorageConfig:
    """
    Shared backend configuration for multi-process deployments.

    Enables ContextBuffer and ThreadRegistry to use shared storage
    (Redis or multiprocessing.Manager) for cross-process access.
    """

    # Selects the SharedBackend implementation: "memory", "manager", "redis"
    backend_type: str = "memory"

    # Redis-specific config (ignored for other backend types)
    redis_url: str = "redis://localhost:6379"
    redis_prefix: str = "xp:"  # key namespace prefix for all stored keys
    redis_ttl: int = 86400  # 24 hours default TTL

    # Limits guarding against runaway memory usage
    max_slots_per_thread: int = 10000
    max_threads: int = 1000
|
||||
|
||||
|
||||
@dataclass
class ProcessPoolConfig:
    """
    Process pool configuration for CPU-bound handler dispatch.

    When configured, handlers marked with `cpu_bound: true` are
    dispatched to a ProcessPoolExecutor instead of running in
    the main event loop.
    """

    # NOTE(review): load_config() defaults `enabled` to True whenever a
    # process_pool section is present in the YAML, while the dataclass
    # default here is False — confirm which default is intended.
    enabled: bool = False
    workers: int = 4  # Number of worker processes
    max_tasks_per_child: int = 100  # Restart workers after N tasks
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrganismConfig:
|
||||
"""Complete organism configuration."""
|
||||
|
|
@ -92,6 +131,8 @@ class OrganismConfig:
|
|||
llm_backends: list[LLMBackendConfig] = field(default_factory=list)
|
||||
server: ServerConfig | None = None
|
||||
auth: AuthConfig | None = None
|
||||
backend: BackendStorageConfig | None = None
|
||||
process_pool: ProcessPoolConfig | None = None
|
||||
|
||||
|
||||
def load_config(path: Path) -> OrganismConfig:
|
||||
|
|
@ -152,6 +193,7 @@ def load_config(path: Path) -> OrganismConfig:
|
|||
peers=listener_raw.get("peers", []),
|
||||
allowed_tools=listener_raw.get("allowed_tools", []),
|
||||
blocked_tools=listener_raw.get("blocked_tools", []),
|
||||
cpu_bound=listener_raw.get("cpu_bound", False),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -174,12 +216,37 @@ def load_config(path: Path) -> OrganismConfig:
|
|||
totp_secret_env=auth_raw.get("totp_secret_env", "ORGANISM_TOTP_SECRET"),
|
||||
)
|
||||
|
||||
# Parse optional backend config
|
||||
backend = None
|
||||
if "backend" in raw:
|
||||
backend_raw = raw["backend"]
|
||||
backend = BackendStorageConfig(
|
||||
backend_type=backend_raw.get("type", "memory"),
|
||||
redis_url=backend_raw.get("redis_url", "redis://localhost:6379"),
|
||||
redis_prefix=backend_raw.get("redis_prefix", "xp:"),
|
||||
redis_ttl=backend_raw.get("redis_ttl", 86400),
|
||||
max_slots_per_thread=backend_raw.get("max_slots_per_thread", 10000),
|
||||
max_threads=backend_raw.get("max_threads", 1000),
|
||||
)
|
||||
|
||||
# Parse optional process pool config
|
||||
process_pool = None
|
||||
if "process_pool" in raw:
|
||||
pool_raw = raw["process_pool"]
|
||||
process_pool = ProcessPoolConfig(
|
||||
enabled=pool_raw.get("enabled", True),
|
||||
workers=pool_raw.get("workers", 4),
|
||||
max_tasks_per_child=pool_raw.get("max_tasks_per_child", 100),
|
||||
)
|
||||
|
||||
return OrganismConfig(
|
||||
organism=organism,
|
||||
listeners=listeners,
|
||||
llm_backends=llm_backends,
|
||||
server=server,
|
||||
auth=auth,
|
||||
backend=backend,
|
||||
process_pool=process_pool,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,21 @@ Provides thread-scoped, append-only context buffers with:
|
|||
- Thread isolation (handlers only see their context)
|
||||
- Complete audit trail (all messages preserved)
|
||||
- GC and limits (prevent runaway memory usage)
|
||||
|
||||
For multi-process deployments, supports shared backends:
|
||||
- InMemoryBackend: Default single-process storage
|
||||
- ManagerBackend: multiprocessing.Manager for local multi-process
|
||||
- RedisBackend: Redis for distributed deployments
|
||||
|
||||
Usage:
|
||||
# Default (in-memory, single process)
|
||||
buffer = get_context_buffer()
|
||||
|
||||
# With Redis backend
|
||||
from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend
|
||||
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||
backend = get_shared_backend(config)
|
||||
buffer = get_context_buffer(backend=backend)
|
||||
"""
|
||||
|
||||
from xml_pipeline.memory.context_buffer import (
|
||||
|
|
@ -14,14 +29,33 @@ from xml_pipeline.memory.context_buffer import (
|
|||
BufferSlot,
|
||||
SlotMetadata,
|
||||
get_context_buffer,
|
||||
reset_context_buffer,
|
||||
slot_to_handler_metadata,
|
||||
)
|
||||
|
||||
from xml_pipeline.memory.shared_backend import (
|
||||
SharedBackend,
|
||||
BackendConfig,
|
||||
get_shared_backend,
|
||||
reset_shared_backend,
|
||||
serialize_slot,
|
||||
deserialize_slot,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Context buffer
|
||||
"ContextBuffer",
|
||||
"ThreadContext",
|
||||
"BufferSlot",
|
||||
"SlotMetadata",
|
||||
"get_context_buffer",
|
||||
"reset_context_buffer",
|
||||
"slot_to_handler_metadata",
|
||||
# Shared backend
|
||||
"SharedBackend",
|
||||
"BackendConfig",
|
||||
"get_shared_backend",
|
||||
"reset_shared_backend",
|
||||
"serialize_slot",
|
||||
"deserialize_slot",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -27,16 +27,26 @@ Usage:
|
|||
|
||||
# Get thread history
|
||||
history = buffer.get_thread(thread_id)
|
||||
|
||||
For multi-process deployments, the buffer can use a shared backend:
|
||||
from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig
|
||||
|
||||
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||
backend = get_shared_backend(config)
|
||||
buffer = get_context_buffer(backend=backend)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Iterator
|
||||
from typing import Any, Dict, List, Optional, Iterator, TYPE_CHECKING
|
||||
from datetime import datetime, timezone
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SlotMetadata:
|
||||
|
|
@ -166,9 +176,26 @@ class ContextBuffer:
|
|||
Global context buffer managing all thread contexts.
|
||||
|
||||
Thread-safe. Singleton pattern via get_context_buffer().
|
||||
|
||||
Supports two storage modes:
|
||||
1. Local mode (default): Uses in-process ThreadContext objects
|
||||
2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access
|
||||
|
||||
In shared mode, slots are serialized via pickle and stored in the backend.
|
||||
This enables multi-process handler dispatch (cpu_bound handlers).
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, backend: Optional[SharedBackend] = None):
|
||||
"""
|
||||
Initialize context buffer.
|
||||
|
||||
Args:
|
||||
backend: Optional shared backend for cross-process storage.
|
||||
If None, uses in-process storage (original behavior).
|
||||
"""
|
||||
self._backend = backend
|
||||
|
||||
# Local storage (used when no backend)
|
||||
self._threads: Dict[str, ThreadContext] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
|
|
@ -176,8 +203,17 @@ class ContextBuffer:
|
|||
self.max_slots_per_thread: int = 10000
|
||||
self.max_threads: int = 1000
|
||||
|
||||
    @property
    def is_shared(self) -> bool:
        """Return True if using shared backend."""
        # Shared mode is implied simply by having been constructed
        # with a backend; there is no separate mode flag.
        return self._backend is not None
|
||||
|
||||
def get_or_create_thread(self, thread_id: str) -> ThreadContext:
|
||||
"""Get existing thread context or create new one."""
|
||||
"""
|
||||
Get existing thread context or create new one.
|
||||
|
||||
Note: In shared mode, this creates a local proxy that syncs with backend.
|
||||
"""
|
||||
with self._lock:
|
||||
if thread_id not in self._threads:
|
||||
if len(self._threads) >= self.max_threads:
|
||||
|
|
@ -205,7 +241,23 @@ class ContextBuffer:
|
|||
|
||||
This is the main entry point for the pipeline.
|
||||
Returns the immutable slot reference.
|
||||
|
||||
In shared mode, the slot is serialized and stored in the backend.
|
||||
"""
|
||||
if self._backend is not None:
|
||||
# Shared mode: serialize and store in backend
|
||||
return self._append_shared(
|
||||
thread_id=thread_id,
|
||||
payload=payload,
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
own_name=own_name,
|
||||
is_self_call=is_self_call,
|
||||
usage_instructions=usage_instructions,
|
||||
todo_nudge=todo_nudge,
|
||||
)
|
||||
|
||||
# Local mode: use ThreadContext
|
||||
thread = self.get_or_create_thread(thread_id)
|
||||
|
||||
# Enforce slot limit
|
||||
|
|
@ -224,18 +276,116 @@ class ContextBuffer:
|
|||
todo_nudge=todo_nudge,
|
||||
)
|
||||
|
||||
    def _append_shared(
        self,
        thread_id: str,
        payload: Any,
        from_id: str,
        to_id: str,
        own_name: Optional[str] = None,
        is_self_call: bool = False,
        usage_instructions: str = "",
        todo_nudge: str = "",
    ) -> BufferSlot:
        """Append a slot to the shared backend.

        Serializes the slot (pickle, via serialize_slot) and stores it in
        the configured backend under *thread_id*.

        Raises:
            MemoryError: when the thread already holds max_slots_per_thread
                slots.
        """
        # Local import avoids a module-level import cycle with shared_backend.
        from xml_pipeline.memory.shared_backend import serialize_slot

        assert self._backend is not None

        # Get current slot count for index.
        # NOTE(review): buffer_thread_len() followed by buffer_append() is
        # not atomic across processes, so two writers could observe the same
        # index — confirm whether backends tolerate this.
        current_len = self._backend.buffer_thread_len(thread_id)

        # Enforce slot limit
        if current_len >= self.max_slots_per_thread:
            raise MemoryError(
                f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})"
            )

        # Create metadata; slot_index mirrors the backend's current length
        metadata = SlotMetadata(
            thread_id=thread_id,
            from_id=from_id,
            to_id=to_id,
            slot_index=current_len,
            timestamp=datetime.now(timezone.utc).isoformat(),
            payload_type=type(payload).__name__,
            own_name=own_name,
            is_self_call=is_self_call,
            usage_instructions=usage_instructions,
            todo_nudge=todo_nudge,
        )

        # Create slot
        slot = BufferSlot(payload=payload, metadata=metadata)

        # Serialize and store
        slot_data = serialize_slot(slot)
        self._backend.buffer_append(thread_id, slot_data)

        return slot
|
||||
|
||||
def get_thread(self, thread_id: str) -> Optional[ThreadContext]:
|
||||
"""Get a thread's context (None if not found)."""
|
||||
"""
|
||||
Get a thread's context (None if not found).
|
||||
|
||||
In shared mode, returns a local ThreadContext populated from backend.
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._get_thread_shared(thread_id)
|
||||
|
||||
with self._lock:
|
||||
return self._threads.get(thread_id)
|
||||
|
||||
    def _get_thread_shared(self, thread_id: str) -> Optional[ThreadContext]:
        """Rebuild a local ThreadContext from the shared backend.

        Returns None when the backend has no slots for *thread_id*.
        """
        # Local import avoids a module-level import cycle with shared_backend.
        from xml_pipeline.memory.shared_backend import deserialize_slot

        assert self._backend is not None

        if not self._backend.buffer_thread_exists(thread_id):
            return None

        # Create local ThreadContext and populate from backend.
        # The result is a snapshot; it is not kept in sync with the backend.
        thread = ThreadContext(thread_id)
        slot_data_list = self._backend.buffer_get_thread(thread_id)

        for slot_data in slot_data_list:
            slot = deserialize_slot(slot_data)
            thread._slots.append(slot)

        return thread
|
||||
|
||||
def get_thread_slots(self, thread_id: str) -> List[BufferSlot]:
|
||||
"""
|
||||
Get all slots for a thread as a list.
|
||||
|
||||
More efficient than get_thread() when you just need slots.
|
||||
"""
|
||||
if self._backend is not None:
|
||||
from xml_pipeline.memory.shared_backend import deserialize_slot
|
||||
|
||||
slot_data_list = self._backend.buffer_get_thread(thread_id)
|
||||
return [deserialize_slot(data) for data in slot_data_list]
|
||||
|
||||
with self._lock:
|
||||
thread = self._threads.get(thread_id)
|
||||
if thread:
|
||||
return list(thread._slots)
|
||||
return []
|
||||
|
||||
def thread_exists(self, thread_id: str) -> bool:
|
||||
"""Check if a thread exists."""
|
||||
if self._backend is not None:
|
||||
return self._backend.buffer_thread_exists(thread_id)
|
||||
|
||||
with self._lock:
|
||||
return thread_id in self._threads
|
||||
|
||||
def delete_thread(self, thread_id: str) -> bool:
|
||||
"""Delete a thread's context (GC)."""
|
||||
if self._backend is not None:
|
||||
return self._backend.buffer_delete_thread(thread_id)
|
||||
|
||||
with self._lock:
|
||||
if thread_id in self._threads:
|
||||
del self._threads[thread_id]
|
||||
|
|
@ -244,6 +394,20 @@ class ContextBuffer:
|
|||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get buffer statistics."""
|
||||
if self._backend is not None:
|
||||
threads = self._backend.buffer_list_threads()
|
||||
total_slots = sum(
|
||||
self._backend.buffer_thread_len(t) for t in threads
|
||||
)
|
||||
return {
|
||||
"thread_count": len(threads),
|
||||
"total_slots": total_slots,
|
||||
"max_threads": self.max_threads,
|
||||
"max_slots_per_thread": self.max_slots_per_thread,
|
||||
"threads": threads,
|
||||
"backend": "shared",
|
||||
}
|
||||
|
||||
with self._lock:
|
||||
total_slots = sum(len(t) for t in self._threads.values())
|
||||
return {
|
||||
|
|
@ -252,10 +416,15 @@ class ContextBuffer:
|
|||
"max_threads": self.max_threads,
|
||||
"max_slots_per_thread": self.max_slots_per_thread,
|
||||
"threads": list(self._threads.keys()),
|
||||
"backend": "local",
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
def clear(self) -> None:
|
||||
"""Clear all contexts (for testing)."""
|
||||
if self._backend is not None:
|
||||
self._backend.buffer_clear()
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._threads.clear()
|
||||
|
||||
|
|
@ -268,16 +437,35 @@ _buffer: Optional[ContextBuffer] = None
|
|||
_buffer_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_context_buffer(backend: Optional[SharedBackend] = None) -> ContextBuffer:
    """
    Return the global ContextBuffer singleton.

    Args:
        backend: Optional shared backend for cross-process storage.
            Honored only on the call that creates the singleton;
            subsequent calls return the existing instance unchanged.

    Returns:
        Global ContextBuffer instance.
    """
    global _buffer
    if _buffer is not None:
        return _buffer
    # Double-checked locking: re-test under the lock before constructing.
    with _buffer_lock:
        if _buffer is None:
            _buffer = ContextBuffer(backend=backend)
    return _buffer
|
||||
|
||||
|
||||
def reset_context_buffer() -> None:
    """Reset the global context buffer (for testing)."""
    global _buffer
    with _buffer_lock:
        existing = _buffer
        if existing is not None:
            existing.clear()
        _buffer = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Handler-facing metadata adapter
|
||||
# ============================================================================
|
||||
|
|
|
|||
220
xml_pipeline/memory/manager_backend.py
Normal file
220
xml_pipeline/memory/manager_backend.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""
|
||||
manager_backend.py — multiprocessing.Manager implementation of SharedBackend.
|
||||
|
||||
Uses Python's multiprocessing.Manager for cross-process state sharing
|
||||
without requiring Redis. Good for local development and testing.
|
||||
|
||||
Key differences from InMemoryBackend:
|
||||
- Data structures are proxied across processes
|
||||
- Slightly higher overhead than in-memory
|
||||
- No TTL/expiration (use for short-lived test runs)
|
||||
- No persistence (data lost on manager shutdown)
|
||||
|
||||
Usage:
|
||||
# Start manager in main process
|
||||
backend = ManagerBackend()
|
||||
|
||||
# Workers connect via the same address
|
||||
# (Manager handles cross-process communication)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import multiprocessing
|
||||
from multiprocessing.managers import SyncManager
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ManagerBackend:
    """
    multiprocessing.Manager-backed shared state.

    Provides cross-process state without external dependencies.
    Suitable for local multi-process testing when Redis isn't available.

    All containers are Manager proxies, so every read/write round-trips
    through the manager process — expect higher per-call overhead than
    InMemoryBackend.  No TTL/expiration, no persistence: data is lost
    when the manager shuts down.
    """

    def __init__(
        self,
        address: Optional[Tuple[str, int]] = None,
        authkey: Optional[bytes] = None,
    ) -> None:
        """
        Initialize Manager backend.

        Args:
            address: (host, port) for remote manager connection.
                If None, creates a local manager.
            authkey: Authentication key for manager.
                If None, uses process authkey.

        NOTE(review): remote (client-mode) connection is not actually
        implemented — a local Manager is started either way, and
        `address`/`authkey` currently only select the log message.
        """
        self._is_server = address is None
        self._manager: Optional[SyncManager] = None

        if self._is_server:
            # Create a local manager (server mode)
            self._manager = multiprocessing.Manager()
            self._init_data_structures()
            logger.info("Manager backend started (server mode)")
        else:
            # Connect to remote manager (client mode)
            # Note: For remote connection, you'd need a custom SyncManager
            # This is simplified for now - just use local manager
            self._manager = multiprocessing.Manager()
            self._init_data_structures()
            logger.info("Manager backend started (client mode)")

    def _init_data_structures(self) -> None:
        """Initialize managed (proxied) data structures and the shared lock."""
        if self._manager is None:
            raise RuntimeError("Manager not initialized")

        # Context buffer: thread_id → managed list of pickled slots
        self._buffer: Dict[str, List[bytes]] = self._manager.dict()

        # Thread registry: bidirectional mapping
        self._chain_to_uuid: Dict[str, str] = self._manager.dict()
        self._uuid_to_chain: Dict[str, str] = self._manager.dict()

        # Lock for multi-step operations (shared across processes)
        self._lock = self._manager.Lock()

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer; return its 0-based index."""
        with self._lock:
            if thread_id not in self._buffer:
                # Nested proxies: the per-thread list must itself be a
                # managed list so appends propagate across processes.
                self._buffer[thread_id] = self._manager.list()  # type: ignore

            slots = self._buffer[thread_id]
            slots.append(slot_data)
            return len(slots) - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (copied out of the proxy)."""
        with self._lock:
            if thread_id not in self._buffer:
                return []
            # Convert manager.list proxy to a plain list
            return list(self._buffer[thread_id])

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None when out of range."""
        with self._lock:
            if thread_id not in self._buffer:
                return None
            slots = self._buffer[thread_id]
            if 0 <= index < len(slots):
                return slots[index]
            return None

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        with self._lock:
            if thread_id not in self._buffer:
                return 0
            return len(self._buffer[thread_id])

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        with self._lock:
            return thread_id in self._buffer

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if it existed."""
        with self._lock:
            if thread_id in self._buffer:
                del self._buffer[thread_id]
                return True
            return False

    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots."""
        with self._lock:
            return list(self._buffer.keys())

    def buffer_clear(self) -> None:
        """Clear all buffer data."""
        with self._lock:
            # Delete key-by-key; avoids relying on proxy .clear() support
            for key in list(self._buffer.keys()):
                del self._buffer[key]

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        with self._lock:
            self._chain_to_uuid[chain] = uuid
            self._uuid_to_chain[uuid] = chain

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain, or None."""
        with self._lock:
            return self._chain_to_uuid.get(chain)

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID, or None."""
        with self._lock:
            return self._uuid_to_chain.get(uuid)

    def registry_delete(self, uuid: str) -> bool:
        """
        Delete mapping by UUID; returns True if the UUID was mapped.

        Uses pop() with a default for the chain side so a stale or
        missing chain entry (e.g. the chain was re-registered to another
        UUID) cannot raise KeyError — matches InMemoryBackend behavior.
        """
        with self._lock:
            chain = self._uuid_to_chain.get(uuid)
            if chain:
                del self._uuid_to_chain[uuid]
                self._chain_to_uuid.pop(chain, None)
                return True
            return False

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings as a plain dict."""
        with self._lock:
            return dict(self._uuid_to_chain)

    def registry_clear(self) -> None:
        """Clear all registry data."""
        with self._lock:
            # Delete key-by-key; avoids relying on proxy .clear() support
            for key in list(self._chain_to_uuid.keys()):
                del self._chain_to_uuid[key]
            for key in list(self._uuid_to_chain.keys()):
                del self._uuid_to_chain[key]

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Shut down the manager process (server mode only)."""
        if self._manager and self._is_server:
            try:
                self._manager.shutdown()
                logger.info("Manager backend shutdown")
            except Exception as e:
                # Best-effort teardown; a dead manager is already "closed"
                logger.warning(f"Manager shutdown error: {e}")
        self._manager = None

    # =========================================================================
    # Info
    # =========================================================================

    def info(self) -> Dict[str, int]:
        """Get backend statistics (thread and registry entry counts)."""
        with self._lock:
            return {
                "buffer_threads": len(self._buffer),
                "registry_entries": len(self._uuid_to_chain),
            }
|
||||
136
xml_pipeline/memory/memory_backend.py
Normal file
136
xml_pipeline/memory/memory_backend.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
"""
|
||||
memory_backend.py — In-memory implementation of SharedBackend.
|
||||
|
||||
This is the default backend for single-process operation.
|
||||
It provides the same interface as Redis/Manager backends but stores
|
||||
everything in Python data structures.
|
||||
|
||||
Thread-safe via threading.Lock (same as original ContextBuffer/ThreadRegistry).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
class InMemoryBackend:
    """
    Process-local implementation of the SharedBackend protocol.

    Stores every slot and registry entry in ordinary Python dicts,
    guarded by two independent locks (one for the buffer, one for the
    registry).  Equivalent to the original in-process behavior — use
    for development, testing, or single-process deployments.
    """

    def __init__(self) -> None:
        # thread_id → ordered list of pickled BufferSlot payloads
        self._buffer: Dict[str, List[bytes]] = {}
        self._buffer_lock = threading.Lock()

        # Bidirectional chain ↔ uuid maps, kept in sync under one lock
        self._chain_to_uuid: Dict[str, str] = {}
        self._uuid_to_chain: Dict[str, str] = {}
        self._registry_lock = threading.Lock()

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append *slot_data* to the thread's slot list; return its index."""
        with self._buffer_lock:
            slots = self._buffer.setdefault(thread_id, [])
            slots.append(slot_data)
            return len(slots) - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Return a copy of every slot stored for *thread_id* (may be empty)."""
        with self._buffer_lock:
            return list(self._buffer.get(thread_id, []))

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Return the slot at *index*, or None for unknown thread / bad index."""
        with self._buffer_lock:
            slots = self._buffer.get(thread_id, [])
            return slots[index] if 0 <= index < len(slots) else None

    def buffer_thread_len(self, thread_id: str) -> int:
        """Number of slots currently held for *thread_id* (0 if unknown)."""
        with self._buffer_lock:
            return len(self._buffer.get(thread_id, []))

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """True when *thread_id* has at least one stored slot."""
        with self._buffer_lock:
            return thread_id in self._buffer

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Drop every slot for *thread_id*; True if the thread existed."""
        with self._buffer_lock:
            removed = self._buffer.pop(thread_id, None)
            return removed is not None

    def buffer_list_threads(self) -> List[str]:
        """All thread IDs that currently have slots."""
        with self._buffer_lock:
            return list(self._buffer)

    def buffer_clear(self) -> None:
        """Discard all buffered slots (test helper)."""
        with self._buffer_lock:
            self._buffer.clear()

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Record the chain ↔ uuid mapping in both directions."""
        with self._registry_lock:
            self._chain_to_uuid[chain] = uuid
            self._uuid_to_chain[uuid] = chain

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """UUID registered for *chain*, or None."""
        with self._registry_lock:
            return self._chain_to_uuid.get(chain)

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Chain registered for *uuid*, or None."""
        with self._registry_lock:
            return self._uuid_to_chain.get(uuid)

    def registry_delete(self, uuid: str) -> bool:
        """Remove both directions of the mapping; True if *uuid* was known."""
        with self._registry_lock:
            chain = self._uuid_to_chain.pop(uuid, None)
            if not chain:
                return False
            self._chain_to_uuid.pop(chain, None)
            return True

    def registry_list_all(self) -> Dict[str, str]:
        """Snapshot of all UUID → chain mappings."""
        with self._registry_lock:
            return dict(self._uuid_to_chain)

    def registry_clear(self) -> None:
        """Discard all registry mappings (test helper)."""
        with self._registry_lock:
            self._chain_to_uuid.clear()
            self._uuid_to_chain.clear()

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Nothing to release — all state is plain process memory."""
        pass
|
||||
262
xml_pipeline/memory/redis_backend.py
Normal file
262
xml_pipeline/memory/redis_backend.py
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
"""
|
||||
redis_backend.py — Redis implementation of SharedBackend.
|
||||
|
||||
Uses Redis for distributed shared state, enabling:
|
||||
- Multi-process handler dispatch (cpu_bound handlers in ProcessPool)
|
||||
- Multi-tenant deployments (key prefix isolation)
|
||||
- Automatic TTL-based garbage collection
|
||||
- Cross-machine state sharing (future: WASM handlers)
|
||||
|
||||
Key schema:
|
||||
- Buffer slots: `{prefix}buffer:{thread_id}` → Redis LIST of pickled slots
|
||||
- Registry chain→uuid: `{prefix}chain:{chain}` → uuid string
|
||||
- Registry uuid→chain: `{prefix}uuid:{uuid}` → chain string
|
||||
|
||||
Dependencies:
|
||||
pip install redis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
try:
|
||||
import redis
|
||||
except ImportError:
|
||||
redis = None # type: ignore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RedisBackend:
    """
    Redis-backed shared state for multi-process deployments.

    All operations are synchronous (blocking). For async contexts,
    wrap calls in asyncio.to_thread() or use a connection pool.

    Key schema:
        {prefix}buffer:{thread_id} → Redis LIST of pickled slots
        {prefix}chain:{chain}      → uuid string
        {prefix}uuid:{uuid}        → chain string
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        prefix: str = "xp:",
        ttl: int = 86400,
    ) -> None:
        """
        Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for multi-tenancy isolation
            ttl: TTL in seconds for automatic cleanup (0 = no TTL)

        Raises:
            ImportError: if the redis package is not installed.
            redis.ConnectionError: if the initial ping fails.
        """
        if redis is None:
            raise ImportError(
                "redis package not installed. Install with: pip install redis"
            )

        self.url = url
        self.prefix = prefix
        self.ttl = ttl

        # decode_responses=False: slot payloads are raw pickled bytes
        self._client = redis.from_url(url, decode_responses=False)

        # Fail fast if Redis is unreachable
        try:
            self._client.ping()
            logger.info(f"Redis backend connected: {url}")
        except redis.ConnectionError as e:
            logger.error(f"Redis connection failed: {e}")
            raise

    def _buffer_key(self, thread_id: str) -> str:
        """Get Redis key for a thread's buffer."""
        return f"{self.prefix}buffer:{thread_id}"

    def _chain_key(self, chain: str) -> str:
        """Get Redis key for chain→uuid mapping."""
        return f"{self.prefix}chain:{chain}"

    def _uuid_key(self, uuid: str) -> str:
        """Get Redis key for uuid→chain mapping."""
        return f"{self.prefix}uuid:{uuid}"

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer; return its 0-based index."""
        key = self._buffer_key(thread_id)

        # RPUSH returns new length, so index = length - 1
        new_len = self._client.rpush(key, slot_data)

        # Set/refresh TTL so idle threads are garbage-collected
        if self.ttl > 0:
            self._client.expire(key, self.ttl)

        return new_len - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (empty list if none)."""
        key = self._buffer_key(thread_id)
        # LRANGE 0 -1 returns the full list
        result = self._client.lrange(key, 0, -1)
        return result if result else []

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index (LINDEX returns None if absent)."""
        key = self._buffer_key(thread_id)
        return self._client.lindex(key, index)

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (LLEN is 0 for missing keys)."""
        key = self._buffer_key(thread_id)
        return self._client.llen(key)

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        key = self._buffer_key(thread_id)
        return self._client.exists(key) > 0

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if it existed."""
        key = self._buffer_key(thread_id)
        deleted = self._client.delete(key)
        return deleted > 0

    def buffer_list_threads(self) -> List[str]:
        """
        List all thread IDs with slots.

        Strips the "{prefix}buffer:" key prefix so callers get bare
        thread IDs, consistent with the other backends.

        Note: uses KEYS, which scans the whole keyspace — fine for
        test/admin use, avoid on hot paths of large deployments.
        """
        pattern = f"{self.prefix}buffer:*"
        keys = self._client.keys(pattern)
        prefix_len = len(f"{self.prefix}buffer:")

        thread_ids: List[str] = []
        for k in keys:
            if isinstance(k, bytes):
                k = k.decode()
            thread_ids.append(k[prefix_len:])
        return thread_ids

    def buffer_clear(self) -> None:
        """Clear all buffer data (for testing)."""
        pattern = f"{self.prefix}buffer:*"
        keys = self._client.keys(pattern)
        if keys:
            self._client.delete(*keys)

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        chain_key = self._chain_key(chain)
        uuid_key = self._uuid_key(uuid)

        # Pipeline batches both SETs (and TTLs) into one round trip
        pipe = self._client.pipeline()
        pipe.set(chain_key, uuid.encode())
        pipe.set(uuid_key, chain.encode())

        if self.ttl > 0:
            pipe.expire(chain_key, self.ttl)
            pipe.expire(uuid_key, self.ttl)

        pipe.execute()

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain, or None."""
        key = self._chain_key(chain)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID, or None."""
        key = self._uuid_key(uuid)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID (both directions). True if it existed."""
        uuid_key = self._uuid_key(uuid)

        # Need the chain value first so both directions can be deleted
        chain = self._client.get(uuid_key)
        if not chain:
            return False

        if isinstance(chain, bytes):
            chain = chain.decode()

        chain_key = self._chain_key(chain)

        deleted = self._client.delete(uuid_key, chain_key)
        return deleted > 0

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings (KEYS scan; admin/test use)."""
        pattern = f"{self.prefix}uuid:*"
        keys = self._client.keys(pattern)

        result = {}
        prefix_len = len(f"{self.prefix}uuid:")

        for key in keys:
            if isinstance(key, bytes):
                key = key.decode()
            uuid = key[prefix_len:]
            chain = self._client.get(key)
            if chain:
                if isinstance(chain, bytes):
                    chain = chain.decode()
                result[uuid] = chain

        return result

    def registry_clear(self) -> None:
        """Clear all registry data (for testing)."""
        chain_pattern = f"{self.prefix}chain:*"
        uuid_pattern = f"{self.prefix}uuid:*"

        chain_keys = self._client.keys(chain_pattern)
        uuid_keys = self._client.keys(uuid_pattern)

        all_keys = list(chain_keys) + list(uuid_keys)
        if all_keys:
            self._client.delete(*all_keys)

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Close Redis connection."""
        if self._client:
            self._client.close()
            logger.info("Redis backend closed")

    # =========================================================================
    # Health Check
    # =========================================================================

    def ping(self) -> bool:
        """Check if Redis is reachable."""
        try:
            return self._client.ping()
        except Exception:
            return False

    def info(self) -> Dict[str, int]:
        """Get backend statistics."""
        buffer_keys = len(self._client.keys(f"{self.prefix}buffer:*"))
        uuid_keys = len(self._client.keys(f"{self.prefix}uuid:*"))

        return {
            "buffer_threads": buffer_keys,
            "registry_entries": uuid_keys,
            "ttl_seconds": self.ttl,
        }
|
||||
275
xml_pipeline/memory/shared_backend.py
Normal file
275
xml_pipeline/memory/shared_backend.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
"""
|
||||
shared_backend.py — Abstract backend interface for shared state.
|
||||
|
||||
Provides a Protocol defining operations needed by ContextBuffer and ThreadRegistry
|
||||
to work across processes. Implementations include:
|
||||
- InMemoryBackend: Default, single-process (current behavior)
|
||||
- ManagerBackend: multiprocessing.Manager for local multi-process
|
||||
- RedisBackend: Redis for distributed/multi-tenant scenarios
|
||||
|
||||
All backends support:
|
||||
- Context buffer operations (thread slots)
|
||||
- Thread registry operations (UUID ↔ chain mapping)
|
||||
- TTL for automatic garbage collection
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import pickle
import threading
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable
|
||||
|
||||
|
||||
@dataclass
class BackendConfig:
    """
    Configuration for shared backend selection.

    Consumed by get_shared_backend(); only the fields matching the
    chosen backend_type are read when constructing that backend.
    """

    backend_type: str = "memory"  # "memory", "manager", "redis"

    # Redis-specific config (used only when backend_type == "redis")
    redis_url: str = "redis://localhost:6379"
    redis_prefix: str = "xp:"  # key-prefix isolation for multi-tenancy
    redis_ttl: int = 86400  # 24 hours default; 0 disables expiry

    # Manager-specific config (for local multiprocess)
    manager_address: Optional[Tuple[str, int]] = None  # (host, port); None = spawn local manager
    manager_authkey: Optional[bytes] = None  # None = inherit the process authkey

    # Common limits
    # NOTE(review): not consumed by the visible factory code — confirm
    # where these limits are enforced (presumably ContextBuffer).
    max_slots_per_thread: int = 10000
    max_threads: int = 1000
||||
|
||||
|
||||
@runtime_checkable
class SharedBackend(Protocol):
    """
    Protocol for shared state backends.

    All methods should be synchronous (blocking) for simplicity.
    Async wrappers can be added at the caller level if needed.

    Structural typing: implementations need not inherit from this class;
    any object providing these methods satisfies isinstance() checks
    (via @runtime_checkable, which only verifies method presence, not
    signatures).
    """

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    @abstractmethod
    def buffer_append(
        self,
        thread_id: str,
        slot_data: bytes,  # Pickled BufferSlot
    ) -> int:
        """
        Append a slot to a thread's buffer.

        Args:
            thread_id: UUID of the thread
            slot_data: Pickled BufferSlot bytes

        Returns:
            Index of the appended slot (0-based)
        """
        ...

    @abstractmethod
    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """
        Get all slots for a thread.

        Args:
            thread_id: UUID of the thread

        Returns:
            List of pickled BufferSlot bytes (in append order)
        """
        ...

    @abstractmethod
    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """
        Get a specific slot by index.

        Args:
            thread_id: UUID of the thread
            index: Slot index (0-based)

        Returns:
            Pickled BufferSlot bytes, or None if not found
        """
        ...

    @abstractmethod
    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        ...

    @abstractmethod
    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        ...

    @abstractmethod
    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if thread existed."""
        ...

    @abstractmethod
    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots (bare IDs, no backend prefixes)."""
        ...

    @abstractmethod
    def buffer_clear(self) -> None:
        """Clear all buffer data (for testing)."""
        ...

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    @abstractmethod
    def registry_set(self, chain: str, uuid: str) -> None:
        """
        Set bidirectional mapping: chain ↔ uuid.

        Args:
            chain: Dot-separated call chain (e.g., "console.router.greeter")
            uuid: UUID string for this chain
        """
        ...

    @abstractmethod
    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """
        Get UUID for a chain.

        Args:
            chain: Call chain to look up

        Returns:
            UUID string, or None if not found
        """
        ...

    @abstractmethod
    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """
        Get chain for a UUID.

        Args:
            uuid: UUID to look up

        Returns:
            Chain string, or None if not found
        """
        ...

    @abstractmethod
    def registry_delete(self, uuid: str) -> bool:
        """
        Delete mapping by UUID.

        Removes both chain→uuid and uuid→chain mappings.

        Returns:
            True if mapping existed
        """
        ...

    @abstractmethod
    def registry_list_all(self) -> Dict[str, str]:
        """
        Get all UUID → chain mappings.

        Returns:
            Dict mapping UUID to chain
        """
        ...

    @abstractmethod
    def registry_clear(self) -> None:
        """Clear all registry data (for testing)."""
        ...

    # =========================================================================
    # Lifecycle
    # =========================================================================

    @abstractmethod
    def close(self) -> None:
        """Close connections and clean up resources."""
        ...
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Serialization Helpers
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def serialize_slot(slot: Any) -> bytes:
    """Pickle *slot* (a BufferSlot) into a bytes payload for backend storage."""
    # NOTE: pickle is only safe for trusted, in-cluster data — never feed
    # these payloads from untrusted sources.
    return pickle.dumps(slot)
|
||||
|
||||
|
||||
def deserialize_slot(data: bytes) -> Any:
    """Inverse of serialize_slot(): unpickle *data* back into a BufferSlot."""
    # NOTE: unpickling executes arbitrary code — only use on trusted payloads.
    return pickle.loads(data)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Factory
|
||||
# =============================================================================
|
||||
|
||||
_backend: Optional[SharedBackend] = None
_backend_lock = threading.Lock()


def get_shared_backend(config: Optional[BackendConfig] = None) -> SharedBackend:
    """
    Get or create the global shared backend.

    Backend selection (from config.backend_type):
        "redis"   → RedisBackend
        "manager" → ManagerBackend
        otherwise → InMemoryBackend (default)

    The *config* argument is honored only on the first call (singleton
    creation); subsequent calls return the existing backend.

    Thread-safe: creation is serialized under a module-level lock so two
    threads racing on the first call cannot construct (and leak) two
    backends.
    """
    global _backend

    # Fast path: already created.
    if _backend is not None:
        return _backend

    with _backend_lock:
        # Re-check under the lock: another thread may have won the race.
        if _backend is not None:
            return _backend

        cfg = config if config is not None else BackendConfig()

        # Imports are deferred so optional dependencies (redis) are only
        # required when that backend is actually selected.
        if cfg.backend_type == "redis":
            from xml_pipeline.memory.redis_backend import RedisBackend

            _backend = RedisBackend(
                url=cfg.redis_url,
                prefix=cfg.redis_prefix,
                ttl=cfg.redis_ttl,
            )
        elif cfg.backend_type == "manager":
            from xml_pipeline.memory.manager_backend import ManagerBackend

            _backend = ManagerBackend(
                address=cfg.manager_address,
                authkey=cfg.manager_authkey,
            )
        else:
            # Default: in-memory backend
            from xml_pipeline.memory.memory_backend import InMemoryBackend

            _backend = InMemoryBackend()

    return _backend
|
||||
|
||||
|
||||
def reset_shared_backend() -> None:
    """Close and discard the global backend singleton (test helper)."""
    global _backend
    backend = _backend
    if backend is not None:
        backend.close()
        _backend = None
|
||||
|
|
@ -9,12 +9,21 @@ The pipeline is just a composition of stream operators.
|
|||
|
||||
Dependencies:
|
||||
pip install aiostream
|
||||
|
||||
CPU-Bound Handlers:
|
||||
Handlers marked with `cpu_bound: true` are dispatched to a
|
||||
ProcessPoolExecutor instead of running in the main event loop.
|
||||
This prevents long-running handlers from blocking other messages.
|
||||
|
||||
Requires shared backend (Redis or Manager) for cross-process data access.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import logging
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import AsyncIterable, Callable, List, Dict, Any, Optional
|
||||
|
|
@ -34,6 +43,8 @@ from xml_pipeline.message_bus.thread_registry import get_registry
|
|||
from xml_pipeline.message_bus.todo_registry import get_todo_registry
|
||||
from xml_pipeline.memory import get_context_buffer
|
||||
|
||||
pump_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Configuration (same as before)
|
||||
|
|
@ -49,6 +60,7 @@ class ListenerConfig:
|
|||
peers: List[str] = field(default_factory=list)
|
||||
broadcast: bool = False
|
||||
prompt: str = "" # System prompt for LLM agents (loaded into PromptRegistry)
|
||||
cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True
|
||||
payload_class: type = field(default=None, repr=False)
|
||||
handler: Callable = field(default=None, repr=False)
|
||||
|
||||
|
|
@ -69,6 +81,16 @@ class OrganismConfig:
|
|||
# LLM configuration (optional)
|
||||
llm_config: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Process pool configuration (for cpu_bound handlers)
|
||||
process_pool_workers: int = 4
|
||||
process_pool_max_tasks_per_child: int = 100
|
||||
process_pool_enabled: bool = False
|
||||
|
||||
# Backend configuration (for shared state)
|
||||
backend_type: str = "memory" # "memory", "manager", "redis"
|
||||
backend_redis_url: str = "redis://localhost:6379"
|
||||
backend_redis_prefix: str = "xp:"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Listener:
|
||||
|
|
@ -79,6 +101,8 @@ class Listener:
|
|||
is_agent: bool = False
|
||||
peers: List[str] = field(default_factory=list)
|
||||
broadcast: bool = False
|
||||
cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True
|
||||
handler_path: str = "" # Import path for worker process
|
||||
schema: etree.XMLSchema = field(default=None, repr=False)
|
||||
root_tag: str = ""
|
||||
usage_instructions: str = "" # Generated at registration for LLM agents
|
||||
|
|
@ -170,6 +194,10 @@ class StreamPump:
|
|||
|
||||
The entire flow is a single composable stream pipeline.
|
||||
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
|
||||
|
||||
CPU-bound handlers can be dispatched to a ProcessPoolExecutor by
|
||||
marking them with `cpu_bound: true` in config. This requires a
|
||||
shared backend (Redis or Manager) for cross-process data access.
|
||||
"""
|
||||
|
||||
def __init__(self, config: OrganismConfig):
|
||||
|
|
@ -188,6 +216,29 @@ class StreamPump:
|
|||
# Shutdown control
|
||||
self._running = False
|
||||
|
||||
# Process pool for cpu_bound handlers
|
||||
self._process_pool: Optional[ProcessPoolExecutor] = None
|
||||
if config.process_pool_enabled:
|
||||
self._process_pool = ProcessPoolExecutor(
|
||||
max_workers=config.process_pool_workers,
|
||||
max_tasks_per_child=config.process_pool_max_tasks_per_child,
|
||||
)
|
||||
pump_logger.info(
|
||||
f"ProcessPool initialized: {config.process_pool_workers} workers"
|
||||
)
|
||||
|
||||
# Shared backend for cross-process state
|
||||
self._shared_backend = None
|
||||
if config.backend_type != "memory":
|
||||
from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend
|
||||
backend_config = BackendConfig(
|
||||
backend_type=config.backend_type,
|
||||
redis_url=config.backend_redis_url,
|
||||
redis_prefix=config.backend_redis_prefix,
|
||||
)
|
||||
self._shared_backend = get_shared_backend(backend_config)
|
||||
pump_logger.info(f"Shared backend: {config.backend_type}")
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Registration
|
||||
# ------------------------------------------------------------------
|
||||
|
|
@ -203,6 +254,8 @@ class StreamPump:
|
|||
is_agent=lc.is_agent,
|
||||
peers=lc.peers,
|
||||
broadcast=lc.broadcast,
|
||||
cpu_bound=lc.cpu_bound,
|
||||
handler_path=lc.handler_path, # For worker process import
|
||||
schema=self._generate_schema(lc.payload_class),
|
||||
root_tag=root_tag,
|
||||
)
|
||||
|
|
@ -420,7 +473,15 @@ class StreamPump:
|
|||
)
|
||||
payload_ref = state.payload
|
||||
|
||||
response = await listener.handler(payload_ref, metadata)
|
||||
# Dispatch to handler - either in-process or via ProcessPool
|
||||
if listener.cpu_bound and self._process_pool and self._shared_backend:
|
||||
response = await self._dispatch_to_process_pool(
|
||||
listener=listener,
|
||||
payload=payload_ref,
|
||||
metadata=metadata,
|
||||
)
|
||||
else:
|
||||
response = await listener.handler(payload_ref, metadata)
|
||||
|
||||
# None means "no response needed" - don't re-inject
|
||||
if response is None:
|
||||
|
|
@ -549,6 +610,94 @@ class StreamPump:
|
|||
</message>"""
|
||||
return envelope.encode('utf-8')
|
||||
|
||||
async def _dispatch_to_process_pool(
|
||||
self,
|
||||
listener: Listener,
|
||||
payload: Any,
|
||||
metadata: HandlerMetadata,
|
||||
) -> Optional[HandlerResponse]:
|
||||
"""
|
||||
Dispatch handler to ProcessPoolExecutor for CPU-bound execution.
|
||||
|
||||
This offloads work to a separate process to avoid blocking
|
||||
the main event loop.
|
||||
|
||||
Args:
|
||||
listener: The target listener
|
||||
payload: The @xmlify dataclass payload
|
||||
metadata: Handler metadata
|
||||
|
||||
Returns:
|
||||
HandlerResponse or None (same as direct handler call)
|
||||
"""
|
||||
from xml_pipeline.message_bus.worker import (
|
||||
WorkerTask,
|
||||
store_task_data,
|
||||
fetch_response,
|
||||
cleanup_task_data,
|
||||
execute_handler,
|
||||
)
|
||||
|
||||
assert self._process_pool is not None
|
||||
assert self._shared_backend is not None
|
||||
|
||||
# Store payload and metadata in shared backend
|
||||
payload_uuid, metadata_uuid = store_task_data(
|
||||
self._shared_backend, payload, metadata
|
||||
)
|
||||
|
||||
# Create worker task
|
||||
task = WorkerTask(
|
||||
thread_uuid=metadata.thread_id,
|
||||
payload_uuid=payload_uuid,
|
||||
handler_path=listener.handler_path,
|
||||
metadata_uuid=metadata_uuid,
|
||||
listener_name=listener.name,
|
||||
is_agent=listener.is_agent,
|
||||
peers=list(listener.peers),
|
||||
)
|
||||
|
||||
# Backend config for worker process
|
||||
backend_config = {
|
||||
"backend_type": self.config.backend_type,
|
||||
"redis_url": self.config.backend_redis_url,
|
||||
"redis_prefix": self.config.backend_redis_prefix,
|
||||
}
|
||||
|
||||
try:
|
||||
# Submit to process pool and await result
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
self._process_pool,
|
||||
execute_handler,
|
||||
task,
|
||||
backend_config,
|
||||
)
|
||||
|
||||
if not result.success:
|
||||
pump_logger.error(
|
||||
f"Worker error for {listener.name}: {result.error}"
|
||||
)
|
||||
if result.error_traceback:
|
||||
pump_logger.debug(f"Traceback: {result.error_traceback}")
|
||||
return None
|
||||
|
||||
# Fetch response from shared backend
|
||||
if result.response_uuid:
|
||||
response = fetch_response(self._shared_backend, result.response_uuid)
|
||||
return response
|
||||
|
||||
return None
|
||||
|
||||
finally:
|
||||
# Clean up task data from backend
|
||||
cleanup_task_data(
|
||||
self._shared_backend,
|
||||
payload_uuid,
|
||||
metadata_uuid,
|
||||
result.response_uuid if 'result' in dir() and result.success else None,
|
||||
)
|
||||
|
||||
async def _reinject_responses(self, state: MessageState) -> None:
|
||||
"""Push handler responses back into the queue for next iteration."""
|
||||
await self.queue.put(state)
|
||||
|
|
@ -714,10 +863,16 @@ class StreamPump:
|
|||
await self.queue.put(state)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""Graceful shutdown — wait for queue to drain."""
|
||||
"""Graceful shutdown — wait for queue to drain and close resources."""
|
||||
self._running = False
|
||||
await self.queue.join()
|
||||
|
||||
# Shutdown process pool if active
|
||||
if self._process_pool:
|
||||
self._process_pool.shutdown(wait=True)
|
||||
pump_logger.info("ProcessPool shutdown complete")
|
||||
self._process_pool = None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Config Loader (same as before)
|
||||
|
|
@ -733,6 +888,19 @@ class ConfigLoader:
|
|||
@classmethod
|
||||
def _parse(cls, raw: dict) -> OrganismConfig:
|
||||
org = raw.get("organism", {})
|
||||
|
||||
# Parse process pool config
|
||||
pool = raw.get("process_pool", {})
|
||||
process_pool_enabled = pool.get("enabled", False) if pool else False
|
||||
process_pool_workers = pool.get("workers", 4) if pool else 4
|
||||
process_pool_max_tasks = pool.get("max_tasks_per_child", 100) if pool else 100
|
||||
|
||||
# Parse backend config
|
||||
backend = raw.get("backend", {})
|
||||
backend_type = backend.get("type", "memory") if backend else "memory"
|
||||
backend_redis_url = backend.get("redis_url", "redis://localhost:6379") if backend else "redis://localhost:6379"
|
||||
backend_redis_prefix = backend.get("redis_prefix", "xp:") if backend else "xp:"
|
||||
|
||||
config = OrganismConfig(
|
||||
name=org.get("name", "unnamed"),
|
||||
identity_path=org.get("identity", ""),
|
||||
|
|
@ -742,6 +910,12 @@ class ConfigLoader:
|
|||
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||
llm_config=raw.get("llm", {}),
|
||||
process_pool_enabled=process_pool_enabled,
|
||||
process_pool_workers=process_pool_workers,
|
||||
process_pool_max_tasks_per_child=process_pool_max_tasks,
|
||||
backend_type=backend_type,
|
||||
backend_redis_url=backend_redis_url,
|
||||
backend_redis_prefix=backend_redis_prefix,
|
||||
)
|
||||
|
||||
for entry in raw.get("listeners", []):
|
||||
|
|
@ -762,6 +936,7 @@ class ConfigLoader:
|
|||
peers=raw.get("peers", []),
|
||||
broadcast=raw.get("broadcast", False),
|
||||
prompt=raw.get("prompt", ""),
|
||||
cpu_bound=raw.get("cpu_bound", False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -14,15 +14,26 @@ Response routing:
|
|||
2. Prunes the last segment (the responder)
|
||||
3. Routes to the new last segment (the caller)
|
||||
4. Updates/cleans up the registry
|
||||
|
||||
For multi-process deployments, the registry can use a shared backend:
|
||||
from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig
|
||||
|
||||
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||
backend = get_shared_backend(config)
|
||||
registry = get_registry(backend=backend)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as uuid_module
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple, TYPE_CHECKING
|
||||
import threading
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||
|
||||
|
||||
@dataclass
|
||||
class ThreadRegistry:
|
||||
"""
|
||||
Bidirectional mapping between UUIDs and call chains.
|
||||
|
|
@ -32,12 +43,35 @@ class ThreadRegistry:
|
|||
The registry maintains a root thread established at boot time.
|
||||
All external messages without a known parent are registered as
|
||||
children of the root thread.
|
||||
|
||||
Supports two storage modes:
|
||||
1. Local mode (default): Uses in-process dictionaries
|
||||
2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access
|
||||
"""
|
||||
_chain_to_uuid: Dict[str, str] = field(default_factory=dict)
|
||||
_uuid_to_chain: Dict[str, str] = field(default_factory=dict)
|
||||
_lock: threading.Lock = field(default_factory=threading.Lock)
|
||||
_root_uuid: Optional[str] = field(default=None)
|
||||
_root_chain: str = field(default="system")
|
||||
|
||||
def __init__(self, backend: Optional[SharedBackend] = None):
|
||||
"""
|
||||
Initialize thread registry.
|
||||
|
||||
Args:
|
||||
backend: Optional shared backend for cross-process storage.
|
||||
If None, uses in-process storage (original behavior).
|
||||
"""
|
||||
self._backend = backend
|
||||
|
||||
# Local storage (used when no backend)
|
||||
self._chain_to_uuid: Dict[str, str] = {}
|
||||
self._uuid_to_chain: Dict[str, str] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Root thread tracking
|
||||
self._root_uuid: Optional[str] = None
|
||||
self._root_chain: str = "system"
|
||||
|
||||
@property
|
||||
def is_shared(self) -> bool:
|
||||
"""Return True if using shared backend."""
|
||||
return self._backend is not None
|
||||
|
||||
def initialize_root(self, organism_name: str = "organism") -> str:
|
||||
"""
|
||||
|
|
@ -52,16 +86,36 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
UUID for the root thread
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._initialize_root_shared(organism_name)
|
||||
|
||||
with self._lock:
|
||||
if self._root_uuid is not None:
|
||||
return self._root_uuid
|
||||
|
||||
self._root_chain = f"system.{organism_name}"
|
||||
self._root_uuid = str(uuid.uuid4())
|
||||
self._root_uuid = str(uuid_module.uuid4())
|
||||
self._chain_to_uuid[self._root_chain] = self._root_uuid
|
||||
self._uuid_to_chain[self._root_uuid] = self._root_chain
|
||||
return self._root_uuid
|
||||
|
||||
def _initialize_root_shared(self, organism_name: str) -> str:
|
||||
"""Initialize root in shared backend."""
|
||||
assert self._backend is not None
|
||||
|
||||
self._root_chain = f"system.{organism_name}"
|
||||
|
||||
# Check if root already exists in backend
|
||||
existing_uuid = self._backend.registry_get_uuid(self._root_chain)
|
||||
if existing_uuid:
|
||||
self._root_uuid = existing_uuid
|
||||
return existing_uuid
|
||||
|
||||
# Create new root
|
||||
self._root_uuid = str(uuid_module.uuid4())
|
||||
self._backend.registry_set(self._root_chain, self._root_uuid)
|
||||
return self._root_uuid
|
||||
|
||||
@property
|
||||
def root_uuid(self) -> Optional[str]:
|
||||
"""Get the root thread UUID (None if not initialized)."""
|
||||
|
|
@ -82,11 +136,19 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
UUID string for this chain
|
||||
"""
|
||||
if self._backend is not None:
|
||||
existing = self._backend.registry_get_uuid(chain)
|
||||
if existing:
|
||||
return existing
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._backend.registry_set(chain, new_uuid)
|
||||
return new_uuid
|
||||
|
||||
with self._lock:
|
||||
if chain in self._chain_to_uuid:
|
||||
return self._chain_to_uuid[chain]
|
||||
|
||||
new_uuid = str(uuid.uuid4())
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._chain_to_uuid[chain] = new_uuid
|
||||
self._uuid_to_chain[new_uuid] = chain
|
||||
return new_uuid
|
||||
|
|
@ -101,6 +163,9 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
Chain string, or None if not found
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._backend.registry_get_chain(thread_id)
|
||||
|
||||
with self._lock:
|
||||
return self._uuid_to_chain.get(thread_id)
|
||||
|
||||
|
|
@ -115,6 +180,9 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
UUID for the extended chain
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._extend_chain_shared(current_uuid, next_hop)
|
||||
|
||||
with self._lock:
|
||||
current_chain = self._uuid_to_chain.get(current_uuid, "")
|
||||
if current_chain:
|
||||
|
|
@ -127,11 +195,31 @@ class ThreadRegistry:
|
|||
return self._chain_to_uuid[new_chain]
|
||||
|
||||
# Create new UUID for extended chain
|
||||
new_uuid = str(uuid.uuid4())
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._chain_to_uuid[new_chain] = new_uuid
|
||||
self._uuid_to_chain[new_uuid] = new_chain
|
||||
return new_uuid
|
||||
|
||||
def _extend_chain_shared(self, current_uuid: str, next_hop: str) -> str:
|
||||
"""Extend chain in shared backend."""
|
||||
assert self._backend is not None
|
||||
|
||||
current_chain = self._backend.registry_get_chain(current_uuid) or ""
|
||||
if current_chain:
|
||||
new_chain = f"{current_chain}.{next_hop}"
|
||||
else:
|
||||
new_chain = next_hop
|
||||
|
||||
# Check if extended chain already exists
|
||||
existing = self._backend.registry_get_uuid(new_chain)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# Create new UUID for extended chain
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._backend.registry_set(new_chain, new_uuid)
|
||||
return new_uuid
|
||||
|
||||
def prune_for_response(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
Prune chain for a response and get the target.
|
||||
|
|
@ -147,6 +235,9 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
Tuple of (target_listener, new_thread_uuid) or (None, None) if chain exhausted
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._prune_for_response_shared(thread_id)
|
||||
|
||||
with self._lock:
|
||||
chain = self._uuid_to_chain.get(thread_id)
|
||||
if not chain:
|
||||
|
|
@ -168,15 +259,40 @@ class ThreadRegistry:
|
|||
if pruned_chain in self._chain_to_uuid:
|
||||
new_uuid = self._chain_to_uuid[pruned_chain]
|
||||
else:
|
||||
new_uuid = str(uuid.uuid4())
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._chain_to_uuid[pruned_chain] = new_uuid
|
||||
self._uuid_to_chain[new_uuid] = pruned_chain
|
||||
|
||||
# Clean up old UUID (optional - could keep for debugging)
|
||||
# self._cleanup_uuid(thread_id)
|
||||
|
||||
return target, new_uuid
|
||||
|
||||
def _prune_for_response_shared(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Prune chain in shared backend."""
|
||||
assert self._backend is not None
|
||||
|
||||
chain = self._backend.registry_get_chain(thread_id)
|
||||
if not chain:
|
||||
return None, None
|
||||
|
||||
parts = chain.split(".")
|
||||
if len(parts) <= 1:
|
||||
# Chain exhausted
|
||||
self._backend.registry_delete(thread_id)
|
||||
return None, None
|
||||
|
||||
# Prune last segment
|
||||
pruned_parts = parts[:-1]
|
||||
target = pruned_parts[-1]
|
||||
pruned_chain = ".".join(pruned_parts)
|
||||
|
||||
# Get or create UUID for pruned chain
|
||||
existing = self._backend.registry_get_uuid(pruned_chain)
|
||||
if existing:
|
||||
return target, existing
|
||||
|
||||
new_uuid = str(uuid_module.uuid4())
|
||||
self._backend.registry_set(pruned_chain, new_uuid)
|
||||
return target, new_uuid
|
||||
|
||||
def start_chain(self, initiator: str, target: str) -> str:
|
||||
"""
|
||||
Start a new call chain.
|
||||
|
|
@ -208,6 +324,9 @@ class ThreadRegistry:
|
|||
Returns:
|
||||
The same thread_id (now registered)
|
||||
"""
|
||||
if self._backend is not None:
|
||||
return self._register_thread_shared(thread_id, initiator, target)
|
||||
|
||||
with self._lock:
|
||||
# Check if UUID already registered (shouldn't happen, but be safe)
|
||||
if thread_id in self._uuid_to_chain:
|
||||
|
|
@ -230,6 +349,29 @@ class ThreadRegistry:
|
|||
self._uuid_to_chain[thread_id] = chain
|
||||
return thread_id
|
||||
|
||||
def _register_thread_shared(self, thread_id: str, initiator: str, target: str) -> str:
|
||||
"""Register thread in shared backend."""
|
||||
assert self._backend is not None
|
||||
|
||||
# Check if UUID already registered
|
||||
if self._backend.registry_get_chain(thread_id):
|
||||
return thread_id
|
||||
|
||||
# Build chain rooted at system root
|
||||
if self._root_uuid is not None:
|
||||
chain = f"{self._root_chain}.{initiator}.{target}"
|
||||
else:
|
||||
chain = f"{initiator}.{target}"
|
||||
|
||||
# Check if chain already has a different UUID
|
||||
existing = self._backend.registry_get_uuid(chain)
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
# Register the external UUID to this chain
|
||||
self._backend.registry_set(chain, thread_id)
|
||||
return thread_id
|
||||
|
||||
def _cleanup_uuid(self, thread_id: str) -> None:
|
||||
"""Remove a UUID mapping (internal, call with lock held)."""
|
||||
chain = self._uuid_to_chain.pop(thread_id, None)
|
||||
|
|
@ -238,22 +380,65 @@ class ThreadRegistry:
|
|||
|
||||
def cleanup(self, thread_id: str) -> None:
|
||||
"""Explicitly clean up a thread UUID."""
|
||||
if self._backend is not None:
|
||||
self._backend.registry_delete(thread_id)
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._cleanup_uuid(thread_id)
|
||||
|
||||
def debug_dump(self) -> Dict[str, str]:
|
||||
"""Return current mappings for debugging."""
|
||||
if self._backend is not None:
|
||||
return self._backend.registry_list_all()
|
||||
|
||||
with self._lock:
|
||||
return dict(self._uuid_to_chain)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all thread mappings (for testing only)."""
|
||||
if self._backend is not None:
|
||||
self._backend.registry_clear()
|
||||
self._root_uuid = None
|
||||
self._root_chain = "system"
|
||||
return
|
||||
|
||||
with self._lock:
|
||||
self._chain_to_uuid.clear()
|
||||
self._uuid_to_chain.clear()
|
||||
self._root_uuid = None
|
||||
self._root_chain = "system"
|
||||
|
||||
|
||||
# Global registry instance
|
||||
_registry: Optional[ThreadRegistry] = None
|
||||
_registry_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_registry() -> ThreadRegistry:
|
||||
"""Get the global thread registry."""
|
||||
def get_registry(backend: Optional[SharedBackend] = None) -> ThreadRegistry:
|
||||
"""
|
||||
Get the global thread registry.
|
||||
|
||||
Args:
|
||||
backend: Optional shared backend for cross-process storage.
|
||||
Only used on first call (when creating the singleton).
|
||||
Subsequent calls return the existing singleton.
|
||||
|
||||
Returns:
|
||||
Global ThreadRegistry instance.
|
||||
"""
|
||||
global _registry
|
||||
if _registry is None:
|
||||
_registry = ThreadRegistry()
|
||||
with _registry_lock:
|
||||
if _registry is None:
|
||||
_registry = ThreadRegistry(backend=backend)
|
||||
return _registry
|
||||
|
||||
|
||||
def reset_registry() -> None:
|
||||
"""Reset the global thread registry (for testing)."""
|
||||
global _registry
|
||||
with _registry_lock:
|
||||
if _registry is not None:
|
||||
_registry.clear()
|
||||
_registry = None
|
||||
|
|
|
|||
339
xml_pipeline/message_bus/worker.py
Normal file
339
xml_pipeline/message_bus/worker.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
"""
|
||||
worker.py — Worker process entry point for CPU-bound handler dispatch.
|
||||
|
||||
This module provides the entry point for handler execution in worker processes
|
||||
when using ProcessPoolExecutor. Workers communicate with the main process via
|
||||
the shared backend (Redis or Manager).
|
||||
|
||||
Architecture:
|
||||
Main Process (StreamPump)
|
||||
│
|
||||
│ Submits WorkerTask to ProcessPoolExecutor
|
||||
│
|
||||
▼
|
||||
Worker Process (this module)
|
||||
│
|
||||
├── Imports handler module
|
||||
├── Fetches payload from shared backend
|
||||
├── Executes handler
|
||||
├── Stores response in shared backend
|
||||
└── Returns WorkerResult
|
||||
|
||||
Key design decisions:
|
||||
- Minimal IPC payload: Only UUIDs and module paths cross process boundary
|
||||
- Handler state via shared backend: Workers fetch/store data in Redis
|
||||
- Process reuse: Workers are reused across tasks (max_tasks_per_child for restart)
|
||||
- Error isolation: Handler exceptions don't crash the worker pool
|
||||
|
||||
Usage:
|
||||
# Main process submits task
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from xml_pipeline.message_bus.worker import execute_handler, WorkerTask
|
||||
|
||||
pool = ProcessPoolExecutor(max_workers=4)
|
||||
task = WorkerTask(
|
||||
thread_uuid="550e8400-...",
|
||||
payload_uuid="6ba7b810-...",
|
||||
handler_path="handlers.librarian.handle_query",
|
||||
metadata_uuid="7c9e6679-...",
|
||||
)
|
||||
future = pool.submit(execute_handler, task)
|
||||
result = future.result() # WorkerResult
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerTask:
|
||||
"""
|
||||
Task submitted to worker process.
|
||||
|
||||
Contains only UUIDs and paths — actual data lives in shared backend.
|
||||
This minimizes IPC overhead and allows large payloads.
|
||||
"""
|
||||
|
||||
thread_uuid: str
|
||||
payload_uuid: str # UUID for looking up serialized payload
|
||||
handler_path: str # e.g., "handlers.librarian.handle_query"
|
||||
metadata_uuid: str # UUID for looking up serialized HandlerMetadata
|
||||
listener_name: str # Name of the listener (for logging/response injection)
|
||||
is_agent: bool = False # Whether this is an agent handler
|
||||
peers: list[str] = field(default_factory=list) # Allowed peers (for agents)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkerResult:
|
||||
"""
|
||||
Result from worker process.
|
||||
|
||||
Contains UUID for response stored in backend, or error information.
|
||||
"""
|
||||
|
||||
success: bool
|
||||
response_uuid: Optional[str] = None # UUID for serialized HandlerResponse
|
||||
error: Optional[str] = None # Error message if failed
|
||||
error_traceback: Optional[str] = None # Full traceback for debugging
|
||||
elapsed_ms: float = 0.0 # Execution time
|
||||
|
||||
|
||||
# Keys for storing task data in shared backend
|
||||
def _payload_key(uuid: str) -> str:
|
||||
return f"worker:payload:{uuid}"
|
||||
|
||||
|
||||
def _metadata_key(uuid: str) -> str:
|
||||
return f"worker:metadata:{uuid}"
|
||||
|
||||
|
||||
def _response_key(uuid: str) -> str:
|
||||
return f"worker:response:{uuid}"
|
||||
|
||||
|
||||
def _import_handler(handler_path: str) -> Callable:
|
||||
"""
|
||||
Import and return handler function from path.
|
||||
|
||||
Args:
|
||||
handler_path: Dotted path like "handlers.librarian.handle_query"
|
||||
|
||||
Returns:
|
||||
The handler function
|
||||
|
||||
Raises:
|
||||
ImportError: If module not found
|
||||
AttributeError: If function not found in module
|
||||
"""
|
||||
module_path, func_name = handler_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, func_name)
|
||||
|
||||
|
||||
def store_task_data(
|
||||
backend: SharedBackend,
|
||||
payload: Any,
|
||||
metadata: Any,
|
||||
ttl: int = 3600,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Store payload and metadata in shared backend for worker access.
|
||||
|
||||
Args:
|
||||
backend: Shared backend instance
|
||||
payload: The @xmlify dataclass payload
|
||||
metadata: HandlerMetadata instance
|
||||
ttl: Time-to-live in seconds (default 1 hour)
|
||||
|
||||
Returns:
|
||||
Tuple of (payload_uuid, metadata_uuid)
|
||||
"""
|
||||
import pickle
|
||||
|
||||
payload_uuid = str(uuid.uuid4())
|
||||
metadata_uuid = str(uuid.uuid4())
|
||||
|
||||
# Serialize and store
|
||||
payload_bytes = pickle.dumps(payload)
|
||||
metadata_bytes = pickle.dumps(metadata)
|
||||
|
||||
# Store as buffer slots (reusing the backend interface)
|
||||
# We use a special "worker" prefix to distinguish from regular slots
|
||||
backend.buffer_append(f"worker:payload:{payload_uuid}", payload_bytes)
|
||||
backend.buffer_append(f"worker:metadata:{metadata_uuid}", metadata_bytes)
|
||||
|
||||
return payload_uuid, metadata_uuid
|
||||
|
||||
|
||||
def fetch_task_data(
|
||||
backend: SharedBackend,
|
||||
payload_uuid: str,
|
||||
metadata_uuid: str,
|
||||
) -> tuple[Any, Any]:
|
||||
"""
|
||||
Fetch payload and metadata from shared backend.
|
||||
|
||||
Args:
|
||||
backend: Shared backend instance
|
||||
payload_uuid: UUID of stored payload
|
||||
metadata_uuid: UUID of stored metadata
|
||||
|
||||
Returns:
|
||||
Tuple of (payload, metadata)
|
||||
"""
|
||||
import pickle
|
||||
|
||||
payload_data = backend.buffer_get_slot(f"worker:payload:{payload_uuid}", 0)
|
||||
metadata_data = backend.buffer_get_slot(f"worker:metadata:{metadata_uuid}", 0)
|
||||
|
||||
if payload_data is None:
|
||||
raise ValueError(f"Payload not found: {payload_uuid}")
|
||||
if metadata_data is None:
|
||||
raise ValueError(f"Metadata not found: {metadata_uuid}")
|
||||
|
||||
payload = pickle.loads(payload_data)
|
||||
metadata = pickle.loads(metadata_data)
|
||||
|
||||
return payload, metadata
|
||||
|
||||
|
||||
def store_response(
|
||||
backend: SharedBackend,
|
||||
response: Any,
|
||||
) -> str:
|
||||
"""
|
||||
Store handler response in shared backend for main process retrieval.
|
||||
|
||||
Args:
|
||||
backend: Shared backend instance
|
||||
response: HandlerResponse or None
|
||||
|
||||
Returns:
|
||||
Response UUID
|
||||
"""
|
||||
import pickle
|
||||
|
||||
response_uuid = str(uuid.uuid4())
|
||||
response_bytes = pickle.dumps(response)
|
||||
backend.buffer_append(f"worker:response:{response_uuid}", response_bytes)
|
||||
|
||||
return response_uuid
|
||||
|
||||
|
||||
def fetch_response(
|
||||
backend: SharedBackend,
|
||||
response_uuid: str,
|
||||
) -> Any:
|
||||
"""
|
||||
Fetch handler response from shared backend.
|
||||
|
||||
Args:
|
||||
backend: Shared backend instance
|
||||
response_uuid: UUID of stored response
|
||||
|
||||
Returns:
|
||||
HandlerResponse or None
|
||||
"""
|
||||
import pickle
|
||||
|
||||
response_data = backend.buffer_get_slot(f"worker:response:{response_uuid}", 0)
|
||||
if response_data is None:
|
||||
raise ValueError(f"Response not found: {response_uuid}")
|
||||
|
||||
return pickle.loads(response_data)
|
||||
|
||||
|
||||
def cleanup_task_data(
|
||||
backend: SharedBackend,
|
||||
payload_uuid: str,
|
||||
metadata_uuid: str,
|
||||
response_uuid: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Clean up temporary task data from shared backend.
|
||||
|
||||
Call after response has been processed by main process.
|
||||
"""
|
||||
backend.buffer_delete_thread(f"worker:payload:{payload_uuid}")
|
||||
backend.buffer_delete_thread(f"worker:metadata:{metadata_uuid}")
|
||||
if response_uuid:
|
||||
backend.buffer_delete_thread(f"worker:response:{response_uuid}")
|
||||
|
||||
|
||||
def execute_handler(
|
||||
task: WorkerTask,
|
||||
backend_config: Optional[dict] = None,
|
||||
) -> WorkerResult:
|
||||
"""
|
||||
Execute a handler in the worker process.
|
||||
|
||||
This is the entry point called by ProcessPoolExecutor.
|
||||
|
||||
Args:
|
||||
task: WorkerTask containing UUIDs and handler path
|
||||
backend_config: Optional backend configuration dict
|
||||
(if None, uses default shared backend)
|
||||
|
||||
Returns:
|
||||
WorkerResult with response UUID or error
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
start_time = time.monotonic()
|
||||
|
||||
try:
|
||||
# Get shared backend
|
||||
from xml_pipeline.memory.shared_backend import (
|
||||
BackendConfig,
|
||||
get_shared_backend,
|
||||
)
|
||||
|
||||
if backend_config:
|
||||
config = BackendConfig(**backend_config)
|
||||
backend = get_shared_backend(config)
|
||||
else:
|
||||
backend = get_shared_backend()
|
||||
|
||||
# Fetch payload and metadata
|
||||
payload, metadata = fetch_task_data(
|
||||
backend, task.payload_uuid, task.metadata_uuid
|
||||
)
|
||||
|
||||
# Import handler
|
||||
handler = _import_handler(task.handler_path)
|
||||
|
||||
# Execute handler (async)
|
||||
# Workers run their own event loop for async handlers
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
response = loop.run_until_complete(handler(payload, metadata))
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
# Store response
|
||||
response_uuid = store_response(backend, response)
|
||||
|
||||
elapsed = (time.monotonic() - start_time) * 1000
|
||||
|
||||
return WorkerResult(
|
||||
success=True,
|
||||
response_uuid=response_uuid,
|
||||
elapsed_ms=elapsed,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
elapsed = (time.monotonic() - start_time) * 1000
|
||||
logger.exception(f"Handler execution failed: {task.handler_path}")
|
||||
|
||||
return WorkerResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
error_traceback=traceback.format_exc(),
|
||||
elapsed_ms=elapsed,
|
||||
)
|
||||
|
||||
|
||||
def execute_handler_sync(
|
||||
task: WorkerTask,
|
||||
backend_config: Optional[dict] = None,
|
||||
) -> WorkerResult:
|
||||
"""
|
||||
Synchronous wrapper for execute_handler.
|
||||
|
||||
Use this when the handler doesn't need async features
|
||||
(pure computation, no I/O).
|
||||
"""
|
||||
return execute_handler(task, backend_config)
|
||||
Loading…
Reference in a new issue