Add shared backend for multiprocess pipeline support
Introduces SharedBackend Protocol for cross-process state sharing: - InMemoryBackend: default single-process storage - ManagerBackend: multiprocessing.Manager for local multi-process - RedisBackend: distributed deployments with TTL auto-GC Adds ProcessPoolExecutor support for CPU-bound handlers: - worker.py: worker process entry point - stream_pump.py: cpu_bound handler dispatch - Config: backend and process_pool sections in organism.yaml ContextBuffer and ThreadRegistry now accept optional backend parameter while maintaining full backward compatibility. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f87d9f80e9
commit
6790c7a46c
12 changed files with 2346 additions and 28 deletions
|
|
@ -36,6 +36,50 @@ def pytest_configure(config):
|
||||||
# Fixtures available to all tests
|
# Fixtures available to all tests
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
||||||
|
def _clear_all_singletons() -> None:
    """
    Best-effort reset of every process-global singleton.

    Each registry/buffer is cleared independently; ImportError is tolerated so
    the fixture works even when a subsystem is absent from the test install.
    """
    try:
        from xml_pipeline.platform.prompt_registry import get_prompt_registry

        get_prompt_registry().clear()
    except ImportError:
        pass

    try:
        from xml_pipeline.memory.context_buffer import get_context_buffer

        get_context_buffer().clear()
    except ImportError:
        pass

    try:
        from xml_pipeline.message_bus.thread_registry import get_registry

        get_registry().clear()
    except (ImportError, AttributeError):
        # AttributeError: older registry builds may lack clear().
        pass


@pytest.fixture(autouse=True)
def reset_singletons():
    """Reset global singletons before each test to ensure isolation."""
    # Clear registries before test
    _clear_all_singletons()

    yield  # Run the test

    # Clear after test too for good measure
    _clear_all_singletons()
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_thread_id():
|
def sample_thread_id():
|
||||||
"""A valid UUID for testing."""
|
"""A valid UUID for testing."""
|
||||||
|
|
|
||||||
393
tests/test_shared_backend.py
Normal file
393
tests/test_shared_backend.py
Normal file
|
|
@ -0,0 +1,393 @@
|
||||||
|
"""
|
||||||
|
Tests for shared backend implementations.
|
||||||
|
|
||||||
|
Tests InMemoryBackend and ManagerBackend.
|
||||||
|
Redis tests require a running Redis server (skipped if not available).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from xml_pipeline.memory.memory_backend import InMemoryBackend
|
||||||
|
from xml_pipeline.memory.shared_backend import (
|
||||||
|
BackendConfig,
|
||||||
|
serialize_slot,
|
||||||
|
deserialize_slot,
|
||||||
|
)
|
||||||
|
from xml_pipeline.memory.context_buffer import BufferSlot, SlotMetadata
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Fixtures
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def memory_backend():
    """Yield a fresh InMemoryBackend; close it on teardown."""
    b = InMemoryBackend()
    yield b
    b.close()
|
||||||
|
|
||||||
|
@pytest.fixture
def manager_backend():
    """Yield a fresh ManagerBackend; close it on teardown."""
    from xml_pipeline.memory.manager_backend import ManagerBackend

    b = ManagerBackend()
    yield b
    b.close()
|
||||||
|
|
||||||
|
def redis_available() -> bool:
    """
    Return True if a Redis server answers a ping at localhost:6379.

    Returns False when the redis package is missing or the server is
    unreachable. Never raises.
    """
    try:
        import redis
    except ImportError:
        return False

    try:
        client = redis.from_url("redis://localhost:6379")
        try:
            client.ping()
        finally:
            # Original leaked the connection when ping() raised; always close.
            client.close()
        return True
    except Exception:
        return False
|
||||||
|
|
||||||
|
@pytest.fixture
def redis_backend():
    """Create fresh Redis backend (skipped if Redis not available)."""
    if not redis_available():
        pytest.skip("Redis not available")

    from xml_pipeline.memory.redis_backend import RedisBackend

    b = RedisBackend(url="redis://localhost:6379", prefix="xp_test:", ttl=300)

    # Start each test from a clean keyspace.
    b.buffer_clear()
    b.registry_clear()

    yield b

    # Remove whatever the test wrote, then release the connection.
    b.buffer_clear()
    b.registry_clear()
    b.close()
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Sample data
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SamplePayload:
    """Sample payload for testing (round-trips through slot serialization)."""

    # Arbitrary text content carried by the slot.
    message: str
    # Arbitrary numeric content carried by the slot.
    value: int
|
||||||
|
|
||||||
|
def make_slot(thread_id: str, index: int = 0) -> BufferSlot:
    """Build a BufferSlot carrying a SamplePayload with fixed test values."""
    meta = SlotMetadata(
        thread_id=thread_id,
        from_id="sender",
        to_id="receiver",
        slot_index=index,
        timestamp="2024-01-15T00:00:00Z",
        payload_type="SamplePayload",
    )
    return BufferSlot(payload=SamplePayload(message="hello", value=42), metadata=meta)
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# InMemoryBackend Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestInMemoryBackend:
    """Tests for InMemoryBackend."""

    def test_buffer_append_and_get(self, memory_backend):
        """Test appending and retrieving buffer slots."""
        raw = serialize_slot(make_slot("thread-1"))

        # First append lands at index 0, the next at index 1.
        assert memory_backend.buffer_append("thread-1", raw) == 0
        assert memory_backend.buffer_append("thread-1", raw) == 1

        stored = memory_backend.buffer_get_thread("thread-1")
        assert len(stored) == 2

        # Round-trip the first slot and check its contents survived.
        slot = deserialize_slot(stored[0])
        assert slot.metadata.thread_id == "thread-1"
        assert slot.payload.message == "hello"

    def test_buffer_get_slot(self, memory_backend):
        """Test getting specific slot by index."""
        raw = serialize_slot(make_slot("thread-1"))
        for _ in range(2):
            memory_backend.buffer_append("thread-1", raw)

        # Existing index returns data; out-of-range index returns None.
        assert memory_backend.buffer_get_slot("thread-1", 1) is not None
        assert memory_backend.buffer_get_slot("thread-1", 999) is None

    def test_buffer_thread_exists(self, memory_backend):
        """Test thread existence check."""
        assert not memory_backend.buffer_thread_exists("thread-1")

        memory_backend.buffer_append("thread-1", serialize_slot(make_slot("thread-1")))

        assert memory_backend.buffer_thread_exists("thread-1")

    def test_buffer_delete_thread(self, memory_backend):
        """Test deleting thread buffer."""
        memory_backend.buffer_append("thread-1", serialize_slot(make_slot("thread-1")))

        assert memory_backend.buffer_delete_thread("thread-1")
        assert not memory_backend.buffer_thread_exists("thread-1")
        assert not memory_backend.buffer_delete_thread("thread-1")  # Already deleted

    def test_buffer_list_threads(self, memory_backend):
        """Test listing all threads."""
        raw = serialize_slot(make_slot("thread-1"))
        memory_backend.buffer_append("thread-1", raw)
        memory_backend.buffer_append("thread-2", raw)

        assert set(memory_backend.buffer_list_threads()) == {"thread-1", "thread-2"}

    def test_registry_set_and_get(self, memory_backend):
        """Test registry set and get operations."""
        memory_backend.registry_set("a.b.c", "uuid-123")

        # The mapping resolves in both directions.
        assert memory_backend.registry_get_uuid("a.b.c") == "uuid-123"
        assert memory_backend.registry_get_chain("uuid-123") == "a.b.c"

    def test_registry_delete(self, memory_backend):
        """Test registry delete."""
        memory_backend.registry_set("a.b.c", "uuid-123")

        assert memory_backend.registry_delete("uuid-123")
        assert memory_backend.registry_get_uuid("a.b.c") is None
        assert memory_backend.registry_get_chain("uuid-123") is None
        assert not memory_backend.registry_delete("uuid-123")  # Already deleted

    def test_registry_list_all(self, memory_backend):
        """Test listing all registry entries."""
        memory_backend.registry_set("a.b", "uuid-1")
        memory_backend.registry_set("x.y.z", "uuid-2")

        assert memory_backend.registry_list_all() == {"uuid-1": "a.b", "uuid-2": "x.y.z"}

    def test_registry_clear(self, memory_backend):
        """Test clearing registry."""
        memory_backend.registry_set("a.b", "uuid-1")
        memory_backend.registry_clear()

        assert memory_backend.registry_list_all() == {}
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# ManagerBackend Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestManagerBackend:
    """Tests for ManagerBackend (multiprocessing.Manager)."""

    def test_buffer_append_and_get(self, manager_backend):
        """Test appending and retrieving buffer slots."""
        raw = serialize_slot(make_slot("thread-1"))

        assert manager_backend.buffer_append("thread-1", raw) == 0
        assert len(manager_backend.buffer_get_thread("thread-1")) == 1

    def test_registry_operations(self, manager_backend):
        """Test registry operations via Manager."""
        manager_backend.registry_set("a.b.c", "uuid-123")

        # Both lookup directions resolve through the Manager proxy.
        assert manager_backend.registry_get_uuid("a.b.c") == "uuid-123"
        assert manager_backend.registry_get_chain("uuid-123") == "a.b.c"
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# RedisBackend Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not redis_available(), reason="Redis not available")
class TestRedisBackend:
    """Tests for RedisBackend (requires running Redis)."""

    def test_buffer_append_and_get(self, redis_backend):
        """Test appending and retrieving buffer slots."""
        raw = serialize_slot(make_slot("thread-1"))

        assert redis_backend.buffer_append("thread-1", raw) == 0

        stored = redis_backend.buffer_get_thread("thread-1")
        assert len(stored) == 1
        assert deserialize_slot(stored[0]).metadata.thread_id == "thread-1"

    def test_buffer_thread_len(self, redis_backend):
        """Test getting thread length."""
        raw = serialize_slot(make_slot("thread-1"))
        for _ in range(3):
            redis_backend.buffer_append("thread-1", raw)

        assert redis_backend.buffer_thread_len("thread-1") == 3

    def test_registry_operations(self, redis_backend):
        """Test registry operations via Redis."""
        redis_backend.registry_set("console.router.greeter", "uuid-abc")

        assert redis_backend.registry_get_uuid("console.router.greeter") == "uuid-abc"
        assert redis_backend.registry_get_chain("uuid-abc") == "console.router.greeter"

    def test_registry_delete(self, redis_backend):
        """Test registry delete removes both directions."""
        redis_backend.registry_set("a.b", "uuid-123")

        assert redis_backend.registry_delete("uuid-123")
        assert redis_backend.registry_get_uuid("a.b") is None
        assert redis_backend.registry_get_chain("uuid-123") is None

    def test_ping(self, redis_backend):
        """Test Redis ping."""
        assert redis_backend.ping()

    def test_info(self, redis_backend):
        """Test backend info."""
        redis_backend.buffer_append("thread-1", serialize_slot(make_slot("thread-1")))
        redis_backend.registry_set("a.b", "uuid-1")

        info = redis_backend.info()
        assert "buffer_threads" in info
        assert "registry_entries" in info
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration: ContextBuffer with SharedBackend
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextBufferWithBackend:
    """Test ContextBuffer using shared backends."""

    def test_context_buffer_with_memory_backend(self):
        """Test ContextBuffer works with in-memory backend."""
        from xml_pipeline.memory.context_buffer import ContextBuffer

        backend = InMemoryBackend()
        buffer = ContextBuffer(backend=backend)

        assert buffer.is_shared

        # Appending returns the slot reference with payload intact.
        slot = buffer.append(
            thread_id="test-thread",
            payload=SamplePayload(message="hello", value=42),
            from_id="sender",
            to_id="receiver",
        )
        assert slot.thread_id == "test-thread"
        assert slot.payload.message == "hello"

        # The slot is visible through the thread view.
        thread = buffer.get_thread("test-thread")
        assert thread is not None
        assert len(thread) == 1

        # Stats report shared-backend mode.
        stats = buffer.get_stats()
        assert stats["thread_count"] == 1
        assert stats["backend"] == "shared"

        backend.close()
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Integration: ThreadRegistry with SharedBackend
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadRegistryWithBackend:
    """Test ThreadRegistry using shared backends."""

    def test_thread_registry_with_memory_backend(self):
        """Test ThreadRegistry works with in-memory backend."""
        from xml_pipeline.message_bus.thread_registry import ThreadRegistry

        backend = InMemoryBackend()
        registry = ThreadRegistry(backend=backend)

        assert registry.is_shared

        # Root initialization establishes the system chain.
        root_uuid = registry.initialize_root("test-organism")
        assert root_uuid is not None
        assert registry.root_chain == "system.test-organism"

        # A chain can be created and resolved back from its UUID.
        chain_uuid = registry.get_or_create("a.b.c")
        assert chain_uuid is not None
        assert registry.lookup(chain_uuid) == "a.b.c"

        # Extending appends one hop.
        extended_uuid = registry.extend_chain(chain_uuid, "d")
        assert registry.lookup(extended_uuid) == "a.b.c.d"

        # Pruning pops the last hop and names the response target.
        target, pruned_uuid = registry.prune_for_response(extended_uuid)
        assert target == "c"
        assert registry.lookup(pruned_uuid) == "a.b.c"

        backend.close()
|
@ -65,6 +65,9 @@ class ListenerConfig:
|
||||||
allowed_tools: list[str] = field(default_factory=list)
|
allowed_tools: list[str] = field(default_factory=list)
|
||||||
blocked_tools: list[str] = field(default_factory=list)
|
blocked_tools: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
# Dispatch mode
|
||||||
|
cpu_bound: bool = False # If True, dispatch to ProcessPoolExecutor
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ServerConfig:
|
class ServerConfig:
|
||||||
|
|
@ -83,6 +86,42 @@ class AuthConfig:
|
||||||
totp_secret_env: str = "ORGANISM_TOTP_SECRET"
|
totp_secret_env: str = "ORGANISM_TOTP_SECRET"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BackendStorageConfig:
    """
    Shared backend configuration for multi-process deployments.

    Enables ContextBuffer and ThreadRegistry to use shared storage
    (Redis or multiprocessing.Manager) for cross-process access.
    """

    # Which SharedBackend implementation to construct.
    backend_type: str = "memory"  # "memory", "manager", "redis"

    # Redis-specific config
    redis_url: str = "redis://localhost:6379"
    redis_prefix: str = "xp:"  # key namespace prefix for all stored keys
    redis_ttl: int = 86400  # 24 hours default TTL

    # Limits
    max_slots_per_thread: int = 10000  # cap on slots stored per thread
    max_threads: int = 1000  # cap on tracked threads
|
||||||
|
|
||||||
|
@dataclass
class ProcessPoolConfig:
    """
    Process pool configuration for CPU-bound handler dispatch.

    When configured, handlers marked with `cpu_bound: true` are
    dispatched to a ProcessPoolExecutor instead of running in
    the main event loop.
    """

    # NOTE(review): load_config defaults "enabled" to True whenever a
    # process_pool: section is present, while the dataclass default here is
    # False — confirm which default is intended.
    enabled: bool = False
    workers: int = 4  # Number of worker processes
    max_tasks_per_child: int = 100  # Restart workers after N tasks
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class OrganismConfig:
|
class OrganismConfig:
|
||||||
"""Complete organism configuration."""
|
"""Complete organism configuration."""
|
||||||
|
|
@ -92,6 +131,8 @@ class OrganismConfig:
|
||||||
llm_backends: list[LLMBackendConfig] = field(default_factory=list)
|
llm_backends: list[LLMBackendConfig] = field(default_factory=list)
|
||||||
server: ServerConfig | None = None
|
server: ServerConfig | None = None
|
||||||
auth: AuthConfig | None = None
|
auth: AuthConfig | None = None
|
||||||
|
backend: BackendStorageConfig | None = None
|
||||||
|
process_pool: ProcessPoolConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
def load_config(path: Path) -> OrganismConfig:
|
def load_config(path: Path) -> OrganismConfig:
|
||||||
|
|
@ -152,6 +193,7 @@ def load_config(path: Path) -> OrganismConfig:
|
||||||
peers=listener_raw.get("peers", []),
|
peers=listener_raw.get("peers", []),
|
||||||
allowed_tools=listener_raw.get("allowed_tools", []),
|
allowed_tools=listener_raw.get("allowed_tools", []),
|
||||||
blocked_tools=listener_raw.get("blocked_tools", []),
|
blocked_tools=listener_raw.get("blocked_tools", []),
|
||||||
|
cpu_bound=listener_raw.get("cpu_bound", False),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -174,12 +216,37 @@ def load_config(path: Path) -> OrganismConfig:
|
||||||
totp_secret_env=auth_raw.get("totp_secret_env", "ORGANISM_TOTP_SECRET"),
|
totp_secret_env=auth_raw.get("totp_secret_env", "ORGANISM_TOTP_SECRET"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Parse optional backend config
|
||||||
|
backend = None
|
||||||
|
if "backend" in raw:
|
||||||
|
backend_raw = raw["backend"]
|
||||||
|
backend = BackendStorageConfig(
|
||||||
|
backend_type=backend_raw.get("type", "memory"),
|
||||||
|
redis_url=backend_raw.get("redis_url", "redis://localhost:6379"),
|
||||||
|
redis_prefix=backend_raw.get("redis_prefix", "xp:"),
|
||||||
|
redis_ttl=backend_raw.get("redis_ttl", 86400),
|
||||||
|
max_slots_per_thread=backend_raw.get("max_slots_per_thread", 10000),
|
||||||
|
max_threads=backend_raw.get("max_threads", 1000),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse optional process pool config
|
||||||
|
process_pool = None
|
||||||
|
if "process_pool" in raw:
|
||||||
|
pool_raw = raw["process_pool"]
|
||||||
|
process_pool = ProcessPoolConfig(
|
||||||
|
enabled=pool_raw.get("enabled", True),
|
||||||
|
workers=pool_raw.get("workers", 4),
|
||||||
|
max_tasks_per_child=pool_raw.get("max_tasks_per_child", 100),
|
||||||
|
)
|
||||||
|
|
||||||
return OrganismConfig(
|
return OrganismConfig(
|
||||||
organism=organism,
|
organism=organism,
|
||||||
listeners=listeners,
|
listeners=listeners,
|
||||||
llm_backends=llm_backends,
|
llm_backends=llm_backends,
|
||||||
server=server,
|
server=server,
|
||||||
auth=auth,
|
auth=auth,
|
||||||
|
backend=backend,
|
||||||
|
process_pool=process_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,21 @@ Provides thread-scoped, append-only context buffers with:
|
||||||
- Thread isolation (handlers only see their context)
|
- Thread isolation (handlers only see their context)
|
||||||
- Complete audit trail (all messages preserved)
|
- Complete audit trail (all messages preserved)
|
||||||
- GC and limits (prevent runaway memory usage)
|
- GC and limits (prevent runaway memory usage)
|
||||||
|
|
||||||
|
For multi-process deployments, supports shared backends:
|
||||||
|
- InMemoryBackend: Default single-process storage
|
||||||
|
- ManagerBackend: multiprocessing.Manager for local multi-process
|
||||||
|
- RedisBackend: Redis for distributed deployments
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Default (in-memory, single process)
|
||||||
|
buffer = get_context_buffer()
|
||||||
|
|
||||||
|
# With Redis backend
|
||||||
|
from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend
|
||||||
|
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||||
|
backend = get_shared_backend(config)
|
||||||
|
buffer = get_context_buffer(backend=backend)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from xml_pipeline.memory.context_buffer import (
|
from xml_pipeline.memory.context_buffer import (
|
||||||
|
|
@ -14,14 +29,33 @@ from xml_pipeline.memory.context_buffer import (
|
||||||
BufferSlot,
|
BufferSlot,
|
||||||
SlotMetadata,
|
SlotMetadata,
|
||||||
get_context_buffer,
|
get_context_buffer,
|
||||||
|
reset_context_buffer,
|
||||||
slot_to_handler_metadata,
|
slot_to_handler_metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from xml_pipeline.memory.shared_backend import (
|
||||||
|
SharedBackend,
|
||||||
|
BackendConfig,
|
||||||
|
get_shared_backend,
|
||||||
|
reset_shared_backend,
|
||||||
|
serialize_slot,
|
||||||
|
deserialize_slot,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Context buffer
|
||||||
"ContextBuffer",
|
"ContextBuffer",
|
||||||
"ThreadContext",
|
"ThreadContext",
|
||||||
"BufferSlot",
|
"BufferSlot",
|
||||||
"SlotMetadata",
|
"SlotMetadata",
|
||||||
"get_context_buffer",
|
"get_context_buffer",
|
||||||
|
"reset_context_buffer",
|
||||||
"slot_to_handler_metadata",
|
"slot_to_handler_metadata",
|
||||||
|
# Shared backend
|
||||||
|
"SharedBackend",
|
||||||
|
"BackendConfig",
|
||||||
|
"get_shared_backend",
|
||||||
|
"reset_shared_backend",
|
||||||
|
"serialize_slot",
|
||||||
|
"deserialize_slot",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -27,16 +27,26 @@ Usage:
|
||||||
|
|
||||||
# Get thread history
|
# Get thread history
|
||||||
history = buffer.get_thread(thread_id)
|
history = buffer.get_thread(thread_id)
|
||||||
|
|
||||||
|
For multi-process deployments, the buffer can use a shared backend:
|
||||||
|
from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig
|
||||||
|
|
||||||
|
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||||
|
backend = get_shared_backend(config)
|
||||||
|
buffer = get_context_buffer(backend=backend)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Iterator
|
from typing import Any, Dict, List, Optional, Iterator, TYPE_CHECKING
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class SlotMetadata:
|
class SlotMetadata:
|
||||||
|
|
@ -166,9 +176,26 @@ class ContextBuffer:
|
||||||
Global context buffer managing all thread contexts.
|
Global context buffer managing all thread contexts.
|
||||||
|
|
||||||
Thread-safe. Singleton pattern via get_context_buffer().
|
Thread-safe. Singleton pattern via get_context_buffer().
|
||||||
|
|
||||||
|
Supports two storage modes:
|
||||||
|
1. Local mode (default): Uses in-process ThreadContext objects
|
||||||
|
2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access
|
||||||
|
|
||||||
|
In shared mode, slots are serialized via pickle and stored in the backend.
|
||||||
|
This enables multi-process handler dispatch (cpu_bound handlers).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, backend: Optional[SharedBackend] = None):
|
||||||
|
"""
|
||||||
|
Initialize context buffer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
backend: Optional shared backend for cross-process storage.
|
||||||
|
If None, uses in-process storage (original behavior).
|
||||||
|
"""
|
||||||
|
self._backend = backend
|
||||||
|
|
||||||
|
# Local storage (used when no backend)
|
||||||
self._threads: Dict[str, ThreadContext] = {}
|
self._threads: Dict[str, ThreadContext] = {}
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
|
@ -176,8 +203,17 @@ class ContextBuffer:
|
||||||
self.max_slots_per_thread: int = 10000
|
self.max_slots_per_thread: int = 10000
|
||||||
self.max_threads: int = 1000
|
self.max_threads: int = 1000
|
||||||
|
|
||||||
|
    @property
    def is_shared(self) -> bool:
        """Return True if using shared backend (cross-process storage mode)."""
        return self._backend is not None
|
||||||
def get_or_create_thread(self, thread_id: str) -> ThreadContext:
|
def get_or_create_thread(self, thread_id: str) -> ThreadContext:
|
||||||
"""Get existing thread context or create new one."""
|
"""
|
||||||
|
Get existing thread context or create new one.
|
||||||
|
|
||||||
|
Note: In shared mode, this creates a local proxy that syncs with backend.
|
||||||
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if thread_id not in self._threads:
|
if thread_id not in self._threads:
|
||||||
if len(self._threads) >= self.max_threads:
|
if len(self._threads) >= self.max_threads:
|
||||||
|
|
@ -205,7 +241,23 @@ class ContextBuffer:
|
||||||
|
|
||||||
This is the main entry point for the pipeline.
|
This is the main entry point for the pipeline.
|
||||||
Returns the immutable slot reference.
|
Returns the immutable slot reference.
|
||||||
|
|
||||||
|
In shared mode, the slot is serialized and stored in the backend.
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
# Shared mode: serialize and store in backend
|
||||||
|
return self._append_shared(
|
||||||
|
thread_id=thread_id,
|
||||||
|
payload=payload,
|
||||||
|
from_id=from_id,
|
||||||
|
to_id=to_id,
|
||||||
|
own_name=own_name,
|
||||||
|
is_self_call=is_self_call,
|
||||||
|
usage_instructions=usage_instructions,
|
||||||
|
todo_nudge=todo_nudge,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Local mode: use ThreadContext
|
||||||
thread = self.get_or_create_thread(thread_id)
|
thread = self.get_or_create_thread(thread_id)
|
||||||
|
|
||||||
# Enforce slot limit
|
# Enforce slot limit
|
||||||
|
|
@ -224,18 +276,116 @@ class ContextBuffer:
|
||||||
todo_nudge=todo_nudge,
|
todo_nudge=todo_nudge,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _append_shared(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
payload: Any,
|
||||||
|
from_id: str,
|
||||||
|
to_id: str,
|
||||||
|
own_name: Optional[str] = None,
|
||||||
|
is_self_call: bool = False,
|
||||||
|
usage_instructions: str = "",
|
||||||
|
todo_nudge: str = "",
|
||||||
|
) -> BufferSlot:
|
||||||
|
"""Append to shared backend."""
|
||||||
|
from xml_pipeline.memory.shared_backend import serialize_slot
|
||||||
|
|
||||||
|
assert self._backend is not None
|
||||||
|
|
||||||
|
# Get current slot count for index
|
||||||
|
current_len = self._backend.buffer_thread_len(thread_id)
|
||||||
|
|
||||||
|
# Enforce slot limit
|
||||||
|
if current_len >= self.max_slots_per_thread:
|
||||||
|
raise MemoryError(
|
||||||
|
f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create metadata
|
||||||
|
metadata = SlotMetadata(
|
||||||
|
thread_id=thread_id,
|
||||||
|
from_id=from_id,
|
||||||
|
to_id=to_id,
|
||||||
|
slot_index=current_len,
|
||||||
|
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||||
|
payload_type=type(payload).__name__,
|
||||||
|
own_name=own_name,
|
||||||
|
is_self_call=is_self_call,
|
||||||
|
usage_instructions=usage_instructions,
|
||||||
|
todo_nudge=todo_nudge,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create slot
|
||||||
|
slot = BufferSlot(payload=payload, metadata=metadata)
|
||||||
|
|
||||||
|
# Serialize and store
|
||||||
|
slot_data = serialize_slot(slot)
|
||||||
|
self._backend.buffer_append(thread_id, slot_data)
|
||||||
|
|
||||||
|
return slot
|
||||||
|
|
||||||
def get_thread(self, thread_id: str) -> Optional[ThreadContext]:
    """
    Get a thread's context (None if not found).

    In shared mode, returns a local ThreadContext populated from backend.
    """
    if self._backend is None:
        # Local mode: plain dict lookup under the lock.
        with self._lock:
            return self._threads.get(thread_id)
    return self._get_thread_shared(thread_id)
|
def _get_thread_shared(self, thread_id: str) -> Optional[ThreadContext]:
    """Rebuild a local ThreadContext from slots stored in the backend.

    Returns None when the backend has no slots for *thread_id*.
    """
    from xml_pipeline.memory.shared_backend import deserialize_slot

    backend = self._backend
    assert backend is not None

    if not backend.buffer_thread_exists(thread_id):
        return None

    # Materialize a throwaway local context; mutations to it are NOT
    # written back to the backend.
    ctx = ThreadContext(thread_id)
    for raw in backend.buffer_get_thread(thread_id):
        ctx._slots.append(deserialize_slot(raw))
    return ctx
|
def get_thread_slots(self, thread_id: str) -> List[BufferSlot]:
    """
    Get all slots for a thread as a list.

    More efficient than get_thread() when you just need slots.
    Returns an empty list for unknown threads.
    """
    if self._backend is not None:
        from xml_pipeline.memory.shared_backend import deserialize_slot

        raw_slots = self._backend.buffer_get_thread(thread_id)
        return [deserialize_slot(raw) for raw in raw_slots]

    with self._lock:
        ctx = self._threads.get(thread_id)
        return list(ctx._slots) if ctx else []
||||||
def thread_exists(self, thread_id: str) -> bool:
    """Check if a thread exists."""
    backend = self._backend
    if backend is not None:
        return backend.buffer_thread_exists(thread_id)
    with self._lock:
        return thread_id in self._threads
||||||
def delete_thread(self, thread_id: str) -> bool:
|
def delete_thread(self, thread_id: str) -> bool:
|
||||||
"""Delete a thread's context (GC)."""
|
"""Delete a thread's context (GC)."""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._backend.buffer_delete_thread(thread_id)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if thread_id in self._threads:
|
if thread_id in self._threads:
|
||||||
del self._threads[thread_id]
|
del self._threads[thread_id]
|
||||||
|
|
@ -244,6 +394,20 @@ class ContextBuffer:
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
def get_stats(self) -> Dict[str, Any]:
|
||||||
"""Get buffer statistics."""
|
"""Get buffer statistics."""
|
||||||
|
if self._backend is not None:
|
||||||
|
threads = self._backend.buffer_list_threads()
|
||||||
|
total_slots = sum(
|
||||||
|
self._backend.buffer_thread_len(t) for t in threads
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"thread_count": len(threads),
|
||||||
|
"total_slots": total_slots,
|
||||||
|
"max_threads": self.max_threads,
|
||||||
|
"max_slots_per_thread": self.max_slots_per_thread,
|
||||||
|
"threads": threads,
|
||||||
|
"backend": "shared",
|
||||||
|
}
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
total_slots = sum(len(t) for t in self._threads.values())
|
total_slots = sum(len(t) for t in self._threads.values())
|
||||||
return {
|
return {
|
||||||
|
|
@ -252,10 +416,15 @@ class ContextBuffer:
|
||||||
"max_threads": self.max_threads,
|
"max_threads": self.max_threads,
|
||||||
"max_slots_per_thread": self.max_slots_per_thread,
|
"max_slots_per_thread": self.max_slots_per_thread,
|
||||||
"threads": list(self._threads.keys()),
|
"threads": list(self._threads.keys()),
|
||||||
|
"backend": "local",
|
||||||
}
|
}
|
||||||
|
|
||||||
def clear(self) -> None:
    """Drop every thread context (used by tests)."""
    backend = self._backend
    if backend is not None:
        backend.buffer_clear()
    else:
        with self._lock:
            self._threads.clear()
|
@ -268,16 +437,35 @@ _buffer: Optional[ContextBuffer] = None
|
||||||
_buffer_lock = threading.Lock()
|
_buffer_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def get_context_buffer(backend: Optional[SharedBackend] = None) -> ContextBuffer:
    """
    Get the global ContextBuffer singleton.

    Args:
        backend: Optional shared backend for cross-process storage.
            Only honored on the very first call (when the singleton is
            created); on later calls the argument is silently ignored
            and the existing instance is returned.

    Returns:
        Global ContextBuffer instance.
    """
    global _buffer
    # Double-checked locking: the unlocked fast path is safe in CPython
    # because the assignment below is the last step under the lock.
    if _buffer is None:
        with _buffer_lock:
            if _buffer is None:
                _buffer = ContextBuffer(backend=backend)
    return _buffer
||||||
|
|
||||||
|
def reset_context_buffer() -> None:
    """Reset the global context buffer (for testing)."""
    global _buffer
    with _buffer_lock:
        stale, _buffer = _buffer, None
        if stale is not None:
            stale.clear()
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Handler-facing metadata adapter
|
# Handler-facing metadata adapter
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
|
||||||
220
xml_pipeline/memory/manager_backend.py
Normal file
220
xml_pipeline/memory/manager_backend.py
Normal file
|
|
@ -0,0 +1,220 @@
|
||||||
|
"""
|
||||||
|
manager_backend.py — multiprocessing.Manager implementation of SharedBackend.
|
||||||
|
|
||||||
|
Uses Python's multiprocessing.Manager for cross-process state sharing
|
||||||
|
without requiring Redis. Good for local development and testing.
|
||||||
|
|
||||||
|
Key differences from InMemoryBackend:
|
||||||
|
- Data structures are proxied across processes
|
||||||
|
- Slightly higher overhead than in-memory
|
||||||
|
- No TTL/expiration (use for short-lived test runs)
|
||||||
|
- No persistence (data lost on manager shutdown)
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Start manager in main process
|
||||||
|
backend = ManagerBackend()
|
||||||
|
|
||||||
|
# Workers connect via the same address
|
||||||
|
# (Manager handles cross-process communication)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import multiprocessing
|
||||||
|
from multiprocessing.managers import SyncManager
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ManagerBackend:
    """
    multiprocessing.Manager-backed shared state.

    Provides cross-process state without external dependencies.
    Suitable for local multi-process testing when Redis isn't available.
    """

    def __init__(
        self,
        address: Optional[Tuple[str, int]] = None,
        authkey: Optional[bytes] = None,
    ) -> None:
        """
        Initialize Manager backend.

        Args:
            address: (host, port) for remote manager connection.
                If None, creates a local manager.
            authkey: Authentication key for manager.
                If None, uses process authkey.
                NOTE(review): currently unused — remote connection is not
                yet implemented (see client-mode fallback below).
        """
        self._is_server = address is None
        self._manager: Optional[SyncManager] = None

        if self._is_server:
            # Server mode: own a local manager process.
            self._manager = multiprocessing.Manager()
            self._init_data_structures()
            logger.info("Manager backend started (server mode)")
        else:
            # Client mode: a real remote connection would require a custom
            # SyncManager with registered proxies.  Simplified for now —
            # falls back to a local manager, so "client" processes do NOT
            # actually share state with the server.
            self._manager = multiprocessing.Manager()
            self._init_data_structures()
            logger.info("Manager backend started (client mode)")

    def _init_data_structures(self) -> None:
        """Initialize managed (proxied) data structures."""
        if self._manager is None:
            raise RuntimeError("Manager not initialized")

        # Context buffer: thread_id → managed list of pickled slots.
        # (Annotation is the logical shape; values are DictProxy/ListProxy.)
        self._buffer: Dict[str, List[bytes]] = self._manager.dict()

        # Thread registry: bidirectional chain ↔ uuid mapping.
        self._chain_to_uuid: Dict[str, str] = self._manager.dict()
        self._uuid_to_chain: Dict[str, str] = self._manager.dict()

        # Cross-process lock guarding compound read-modify-write operations.
        self._lock = self._manager.Lock()

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer; returns the new slot's index."""
        with self._lock:
            if thread_id not in self._buffer:
                # Nested managed list so appends propagate across processes.
                self._buffer[thread_id] = self._manager.list()  # type: ignore

            slots = self._buffer[thread_id]
            slots.append(slot_data)
            return len(slots) - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (as a plain list copy)."""
        with self._lock:
            if thread_id not in self._buffer:
                return []
            # Convert the ListProxy into a regular local list.
            return list(self._buffer[thread_id])

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None if out of range."""
        with self._lock:
            if thread_id not in self._buffer:
                return None
            slots = self._buffer[thread_id]
            if 0 <= index < len(slots):
                return slots[index]
            return None

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        with self._lock:
            if thread_id not in self._buffer:
                return 0
            return len(self._buffer[thread_id])

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        with self._lock:
            return thread_id in self._buffer

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread; True if anything was deleted."""
        with self._lock:
            if thread_id in self._buffer:
                del self._buffer[thread_id]
                return True
            return False

    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots."""
        with self._lock:
            return list(self._buffer.keys())

    def buffer_clear(self) -> None:
        """Clear all buffer data."""
        with self._lock:
            for key in list(self._buffer.keys()):
                del self._buffer[key]

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        with self._lock:
            self._chain_to_uuid[chain] = uuid
            self._uuid_to_chain[uuid] = chain

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain."""
        with self._lock:
            return self._chain_to_uuid.get(chain)

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID."""
        with self._lock:
            return self._uuid_to_chain.get(uuid)

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID; True if the UUID was present."""
        with self._lock:
            chain = self._uuid_to_chain.get(uuid)
            if not chain:
                return False
            del self._uuid_to_chain[uuid]
            # pop() tolerates a missing reverse entry instead of raising
            # KeyError — consistent with InMemoryBackend.registry_delete.
            self._chain_to_uuid.pop(chain, None)
            return True

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings (as a plain dict copy)."""
        with self._lock:
            return dict(self._uuid_to_chain)

    def registry_clear(self) -> None:
        """Clear all registry data."""
        with self._lock:
            for key in list(self._chain_to_uuid.keys()):
                del self._chain_to_uuid[key]
            for key in list(self._uuid_to_chain.keys()):
                del self._uuid_to_chain[key]

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Shutdown manager (server mode only); safe to call repeatedly."""
        if self._manager and self._is_server:
            try:
                self._manager.shutdown()
                logger.info("Manager backend shutdown")
            except Exception as e:
                logger.warning(f"Manager shutdown error: {e}")
            self._manager = None

    # =========================================================================
    # Info
    # =========================================================================

    def info(self) -> Dict[str, int]:
        """Get backend statistics."""
        with self._lock:
            return {
                "buffer_threads": len(self._buffer),
                "registry_entries": len(self._uuid_to_chain),
            }
||||||
136
xml_pipeline/memory/memory_backend.py
Normal file
136
xml_pipeline/memory/memory_backend.py
Normal file
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""
|
||||||
|
memory_backend.py — In-memory implementation of SharedBackend.
|
||||||
|
|
||||||
|
This is the default backend for single-process operation.
|
||||||
|
It provides the same interface as Redis/Manager backends but stores
|
||||||
|
everything in Python data structures.
|
||||||
|
|
||||||
|
Thread-safe via threading.Lock (same as original ContextBuffer/ThreadRegistry).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryBackend:
    """
    In-memory shared backend implementation.

    This is equivalent to the original behavior - all data in-process.
    Use for development, testing, or single-process deployments.
    Thread-safe: buffer and registry each have their own lock.
    """

    def __init__(self) -> None:
        # thread_id → ordered list of pickled BufferSlot bytes
        self._buffer: Dict[str, List[bytes]] = {}
        self._buffer_lock = threading.Lock()

        # Bidirectional chain ↔ uuid registry
        self._chain_to_uuid: Dict[str, str] = {}
        self._uuid_to_chain: Dict[str, str] = {}
        self._registry_lock = threading.Lock()

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer; returns the new slot's index."""
        with self._buffer_lock:
            slots = self._buffer.setdefault(thread_id, [])
            slots.append(slot_data)
            return len(slots) - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (copy; empty list if unknown)."""
        with self._buffer_lock:
            return list(self._buffer.get(thread_id, ()))

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None if out of range."""
        with self._buffer_lock:
            slots = self._buffer.get(thread_id, [])
            return slots[index] if 0 <= index < len(slots) else None

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        with self._buffer_lock:
            return len(self._buffer.get(thread_id, ()))

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        with self._buffer_lock:
            return thread_id in self._buffer

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread; True if anything was deleted."""
        with self._buffer_lock:
            return self._buffer.pop(thread_id, None) is not None

    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots."""
        with self._buffer_lock:
            return list(self._buffer)

    def buffer_clear(self) -> None:
        """Clear all buffer data."""
        with self._buffer_lock:
            self._buffer.clear()

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        with self._registry_lock:
            self._chain_to_uuid[chain] = uuid
            self._uuid_to_chain[uuid] = chain

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain."""
        with self._registry_lock:
            return self._chain_to_uuid.get(chain)

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID."""
        with self._registry_lock:
            return self._uuid_to_chain.get(uuid)

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID; True if the UUID was present."""
        with self._registry_lock:
            chain = self._uuid_to_chain.pop(uuid, None)
            if not chain:
                return False
            self._chain_to_uuid.pop(chain, None)
            return True

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings (copy)."""
        with self._registry_lock:
            return dict(self._uuid_to_chain)

    def registry_clear(self) -> None:
        """Clear all registry data."""
        with self._registry_lock:
            self._chain_to_uuid.clear()
            self._uuid_to_chain.clear()

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """No resources to clean up for in-memory backend."""
        pass
||||||
262
xml_pipeline/memory/redis_backend.py
Normal file
262
xml_pipeline/memory/redis_backend.py
Normal file
|
|
@ -0,0 +1,262 @@
|
||||||
|
"""
|
||||||
|
redis_backend.py — Redis implementation of SharedBackend.
|
||||||
|
|
||||||
|
Uses Redis for distributed shared state, enabling:
|
||||||
|
- Multi-process handler dispatch (cpu_bound handlers in ProcessPool)
|
||||||
|
- Multi-tenant deployments (key prefix isolation)
|
||||||
|
- Automatic TTL-based garbage collection
|
||||||
|
- Cross-machine state sharing (future: WASM handlers)
|
||||||
|
|
||||||
|
Key schema:
|
||||||
|
- Buffer slots: `{prefix}buffer:{thread_id}` → Redis LIST of pickled slots
|
||||||
|
- Registry chain→uuid: `{prefix}chain:{chain}` → uuid string
|
||||||
|
- Registry uuid→chain: `{prefix}uuid:{uuid}` → chain string
|
||||||
|
|
||||||
|
Dependencies:
|
||||||
|
pip install redis
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import redis
|
||||||
|
except ImportError:
|
||||||
|
redis = None # type: ignore
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisBackend:
    """
    Redis-backed shared state for multi-process deployments.

    All operations are synchronous (blocking). For async contexts,
    wrap calls in asyncio.to_thread() or use a connection pool.

    Key schema:
        {prefix}buffer:{thread_id}  → Redis LIST of pickled slots
        {prefix}chain:{chain}       → uuid string
        {prefix}uuid:{uuid}         → chain string
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        prefix: str = "xp:",
        ttl: int = 86400,
    ) -> None:
        """
        Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for multi-tenancy isolation
            ttl: TTL in seconds for automatic cleanup (0 = no TTL)

        Raises:
            ImportError: if the redis package is not installed.
            redis.ConnectionError: if the server is unreachable.
        """
        if redis is None:
            raise ImportError(
                "redis package not installed. Install with: pip install redis"
            )

        self.url = url
        self.prefix = prefix
        self.ttl = ttl

        # decode_responses=False: slot payloads are pickled bytes, so
        # values must come back raw.
        self._client = redis.from_url(url, decode_responses=False)

        # Fail fast if the server is unreachable.
        try:
            self._client.ping()
            logger.info(f"Redis backend connected: {url}")
        except redis.ConnectionError as e:
            logger.error(f"Redis connection failed: {e}")
            raise

    def _buffer_key(self, thread_id: str) -> str:
        """Get Redis key for a thread's buffer."""
        return f"{self.prefix}buffer:{thread_id}"

    def _chain_key(self, chain: str) -> str:
        """Get Redis key for chain→uuid mapping."""
        return f"{self.prefix}chain:{chain}"

    def _uuid_key(self, uuid: str) -> str:
        """Get Redis key for uuid→chain mapping."""
        return f"{self.prefix}uuid:{uuid}"

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer; returns the new slot's index."""
        key = self._buffer_key(thread_id)

        # RPUSH returns the new list length, so index = length - 1
        new_len = self._client.rpush(key, slot_data)

        # Set/refresh TTL so active threads never expire mid-conversation
        if self.ttl > 0:
            self._client.expire(key, self.ttl)

        return new_len - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (empty list if unknown)."""
        key = self._buffer_key(thread_id)
        # LRANGE 0 -1 returns the full list
        result = self._client.lrange(key, 0, -1)
        return result if result else []

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None if out of range."""
        key = self._buffer_key(thread_id)
        # LINDEX returns the element at index, or None
        return self._client.lindex(key, index)

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        key = self._buffer_key(thread_id)
        return self._client.llen(key)

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        key = self._buffer_key(thread_id)
        return self._client.exists(key) > 0

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread; True if the key existed."""
        key = self._buffer_key(thread_id)
        deleted = self._client.delete(key)
        return deleted > 0

    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots.

        Fixed: previously returned the raw Redis keys (e.g. "xp:buffer:t1")
        instead of bare thread IDs, inconsistent with the other backends.
        NOTE: KEYS is O(N) over the whole keyspace; fine for admin/stats,
        consider SCAN for very large deployments.
        """
        prefix = f"{self.prefix}buffer:"
        keys = self._client.keys(f"{prefix}*")
        thread_ids: List[str] = []
        for key in keys:
            if isinstance(key, bytes):
                key = key.decode()
            thread_ids.append(key[len(prefix):])
        return thread_ids

    def buffer_clear(self) -> None:
        """Clear all buffer data."""
        pattern = f"{self.prefix}buffer:*"
        keys = self._client.keys(pattern)
        if keys:
            self._client.delete(*keys)

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        chain_key = self._chain_key(chain)
        uuid_key = self._uuid_key(uuid)

        # Pipeline so both directions (and TTLs) land in one round trip
        pipe = self._client.pipeline()
        pipe.set(chain_key, uuid.encode())
        pipe.set(uuid_key, chain.encode())

        if self.ttl > 0:
            pipe.expire(chain_key, self.ttl)
            pipe.expire(uuid_key, self.ttl)

        pipe.execute()

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain."""
        key = self._chain_key(chain)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID."""
        key = self._uuid_key(uuid)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID; True if anything was removed."""
        uuid_key = self._uuid_key(uuid)

        # Resolve the chain first so both directions can be removed
        chain = self._client.get(uuid_key)
        if not chain:
            return False

        if isinstance(chain, bytes):
            chain = chain.decode()

        chain_key = self._chain_key(chain)

        # Delete both keys in a single DELETE
        deleted = self._client.delete(uuid_key, chain_key)
        return deleted > 0

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings."""
        pattern = f"{self.prefix}uuid:*"
        keys = self._client.keys(pattern)

        result: Dict[str, str] = {}
        prefix_len = len(f"{self.prefix}uuid:")

        for key in keys:
            if isinstance(key, bytes):
                key = key.decode()
            uuid = key[prefix_len:]
            chain = self._client.get(key)
            if chain:
                if isinstance(chain, bytes):
                    chain = chain.decode()
                result[uuid] = chain

        return result

    def registry_clear(self) -> None:
        """Clear all registry data."""
        chain_pattern = f"{self.prefix}chain:*"
        uuid_pattern = f"{self.prefix}uuid:*"

        chain_keys = self._client.keys(chain_pattern)
        uuid_keys = self._client.keys(uuid_pattern)

        all_keys = list(chain_keys) + list(uuid_keys)
        if all_keys:
            self._client.delete(*all_keys)

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Close Redis connection."""
        if self._client:
            self._client.close()
            logger.info("Redis backend closed")

    # =========================================================================
    # Health Check
    # =========================================================================

    def ping(self) -> bool:
        """Check if Redis is reachable."""
        try:
            return self._client.ping()
        except Exception:
            return False

    def info(self) -> Dict[str, int]:
        """Get backend statistics."""
        buffer_keys = len(self._client.keys(f"{self.prefix}buffer:*"))
        uuid_keys = len(self._client.keys(f"{self.prefix}uuid:*"))

        return {
            "buffer_threads": buffer_keys,
            "registry_entries": uuid_keys,
            "ttl_seconds": self.ttl,
        }
||||||
275
xml_pipeline/memory/shared_backend.py
Normal file
275
xml_pipeline/memory/shared_backend.py
Normal file
|
|
@ -0,0 +1,275 @@
|
||||||
|
"""
|
||||||
|
shared_backend.py — Abstract backend interface for shared state.
|
||||||
|
|
||||||
|
Provides a Protocol defining operations needed by ContextBuffer and ThreadRegistry
|
||||||
|
to work across processes. Implementations include:
|
||||||
|
- InMemoryBackend: Default, single-process (current behavior)
|
||||||
|
- ManagerBackend: multiprocessing.Manager for local multi-process
|
||||||
|
- RedisBackend: Redis for distributed/multi-tenant scenarios
|
||||||
|
|
||||||
|
All backends support:
|
||||||
|
- Context buffer operations (thread slots)
|
||||||
|
- Thread registry operations (UUID ↔ chain mapping)
|
||||||
|
- TTL for automatic garbage collection
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations

import pickle
import threading
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Protocol, Tuple, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BackendConfig:
    """Configuration for shared backend selection.

    ``backend_type`` chooses the implementation; the remaining fields only
    apply to the matching backend and are ignored otherwise.
    """

    backend_type: str = "memory"  # "memory", "manager", "redis"

    # Redis-specific config
    redis_url: str = "redis://localhost:6379"
    redis_prefix: str = "xp:"  # key namespace prefix for all stored entries
    redis_ttl: int = 86400  # seconds; 24 hours default

    # Manager-specific config (for local multiprocess)
    manager_address: Optional[Tuple[str, int]] = None  # (host, port) of a remote manager
    manager_authkey: Optional[bytes] = None

    # Common limits.
    # NOTE(review): not enforced anywhere in this module — presumably
    # consumed by the backend implementations; confirm.
    max_slots_per_thread: int = 10000
    max_threads: int = 1000
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
class SharedBackend(Protocol):
    """
    Protocol for shared state backends.

    All methods should be synchronous (blocking) for simplicity.
    Async wrappers can be added at the caller level if needed.

    Slots cross the process boundary as opaque pickled bytes; see
    serialize_slot / deserialize_slot in this module.
    """

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    @abstractmethod
    def buffer_append(
        self,
        thread_id: str,
        slot_data: bytes,  # Pickled BufferSlot
    ) -> int:
        """
        Append a slot to a thread's buffer.

        Args:
            thread_id: UUID of the thread
            slot_data: Pickled BufferSlot bytes

        Returns:
            Index of the appended slot (0-based)
        """
        ...

    @abstractmethod
    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """
        Get all slots for a thread.

        Args:
            thread_id: UUID of the thread

        Returns:
            List of pickled BufferSlot bytes (in order)
        """
        ...

    @abstractmethod
    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """
        Get a specific slot by index.

        Args:
            thread_id: UUID of the thread
            index: Slot index (0-based)

        Returns:
            Pickled BufferSlot bytes, or None if not found
        """
        ...

    @abstractmethod
    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread."""
        ...

    @abstractmethod
    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        ...

    @abstractmethod
    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if thread existed."""
        ...

    @abstractmethod
    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots."""
        ...

    @abstractmethod
    def buffer_clear(self) -> None:
        """Clear all buffer data (for testing)."""
        ...

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    @abstractmethod
    def registry_set(self, chain: str, uuid: str) -> None:
        """
        Set bidirectional mapping: chain ↔ uuid.

        Args:
            chain: Dot-separated call chain (e.g., "console.router.greeter")
            uuid: UUID string for this chain
        """
        ...

    @abstractmethod
    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """
        Get UUID for a chain.

        Args:
            chain: Call chain to look up

        Returns:
            UUID string, or None if not found
        """
        ...

    @abstractmethod
    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """
        Get chain for a UUID.

        Args:
            uuid: UUID to look up

        Returns:
            Chain string, or None if not found
        """
        ...

    @abstractmethod
    def registry_delete(self, uuid: str) -> bool:
        """
        Delete mapping by UUID.

        Removes both chain→uuid and uuid→chain mappings.

        Returns:
            True if mapping existed
        """
        ...

    @abstractmethod
    def registry_list_all(self) -> Dict[str, str]:
        """
        Get all UUID → chain mappings.

        Returns:
            Dict mapping UUID to chain
        """
        ...

    @abstractmethod
    def registry_clear(self) -> None:
        """Clear all registry data (for testing)."""
        ...

    # =========================================================================
    # Lifecycle
    # =========================================================================

    @abstractmethod
    def close(self) -> None:
        """Close connections and clean up resources."""
        ...
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Serialization Helpers
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_slot(slot: Any) -> bytes:
    """Pickle *slot* into an opaque bytes blob for backend storage."""
    return pickle.dumps(slot)
|
||||||
|
|
||||||
|
|
||||||
|
def deserialize_slot(data: bytes) -> Any:
    """Unpickle *data* back into a BufferSlot.

    NOTE: pickle.loads must only ever see trusted, internally-produced
    bytes — unpickling untrusted input can execute arbitrary code.
    """
    return pickle.loads(data)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Factory
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# Module-level singleton and the lock that guards its creation.
_backend: Optional[SharedBackend] = None
_backend_lock = threading.Lock()


def get_shared_backend(config: Optional[BackendConfig] = None) -> SharedBackend:
    """
    Get or create the global shared backend.

    Backend selection:
    1. backend_type == "redis"   → RedisBackend
    2. backend_type == "manager" → ManagerBackend
    3. Otherwise                 → InMemoryBackend (default)

    Thread-safe singleton pattern: creation is serialized under a module
    lock (the original check-then-set could construct two backends when
    called concurrently, despite the docstring's thread-safety claim).

    Args:
        config: Backend configuration; defaults to an in-memory
            BackendConfig(). Ignored once a backend already exists.

    Returns:
        The process-wide SharedBackend instance.
    """
    global _backend

    # Fast path: already initialized — no lock needed for a plain read.
    if _backend is not None:
        return _backend

    with _backend_lock:
        # Double-check: another thread may have won the race.
        if _backend is not None:
            return _backend

        cfg = config if config is not None else BackendConfig()

        if cfg.backend_type == "redis":
            from xml_pipeline.memory.redis_backend import RedisBackend

            _backend = RedisBackend(
                url=cfg.redis_url,
                prefix=cfg.redis_prefix,
                ttl=cfg.redis_ttl,
            )
        elif cfg.backend_type == "manager":
            from xml_pipeline.memory.manager_backend import ManagerBackend

            _backend = ManagerBackend(
                address=cfg.manager_address,
                authkey=cfg.manager_authkey,
            )
        else:
            # Default: in-memory backend (single-process behavior)
            from xml_pipeline.memory.memory_backend import InMemoryBackend

            _backend = InMemoryBackend()

    return _backend
|
||||||
|
|
||||||
|
|
||||||
|
def reset_shared_backend() -> None:
    """Reset the global backend (for testing).

    Closes the current backend first; the global is cleared even when
    ``close()`` raises, so a failing backend cannot wedge subsequent tests
    (the original left ``_backend`` set if ``close()`` threw).
    """
    global _backend
    if _backend is not None:
        try:
            _backend.close()
        finally:
            _backend = None
|
||||||
|
|
@ -9,12 +9,21 @@ The pipeline is just a composition of stream operators.
|
||||||
|
|
||||||
Dependencies:
|
Dependencies:
|
||||||
pip install aiostream
|
pip install aiostream
|
||||||
|
|
||||||
|
CPU-Bound Handlers:
|
||||||
|
Handlers marked with `cpu_bound: true` are dispatched to a
|
||||||
|
ProcessPoolExecutor instead of running in the main event loop.
|
||||||
|
This prevents long-running handlers from blocking other messages.
|
||||||
|
|
||||||
|
Requires shared backend (Redis or Manager) for cross-process data access.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import AsyncIterable, Callable, List, Dict, Any, Optional
|
from typing import AsyncIterable, Callable, List, Dict, Any, Optional
|
||||||
|
|
@ -34,6 +43,8 @@ from xml_pipeline.message_bus.thread_registry import get_registry
|
||||||
from xml_pipeline.message_bus.todo_registry import get_todo_registry
|
from xml_pipeline.message_bus.todo_registry import get_todo_registry
|
||||||
from xml_pipeline.memory import get_context_buffer
|
from xml_pipeline.memory import get_context_buffer
|
||||||
|
|
||||||
|
pump_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Configuration (same as before)
|
# Configuration (same as before)
|
||||||
|
|
@ -49,6 +60,7 @@ class ListenerConfig:
|
||||||
peers: List[str] = field(default_factory=list)
|
peers: List[str] = field(default_factory=list)
|
||||||
broadcast: bool = False
|
broadcast: bool = False
|
||||||
prompt: str = "" # System prompt for LLM agents (loaded into PromptRegistry)
|
prompt: str = "" # System prompt for LLM agents (loaded into PromptRegistry)
|
||||||
|
cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True
|
||||||
payload_class: type = field(default=None, repr=False)
|
payload_class: type = field(default=None, repr=False)
|
||||||
handler: Callable = field(default=None, repr=False)
|
handler: Callable = field(default=None, repr=False)
|
||||||
|
|
||||||
|
|
@ -69,6 +81,16 @@ class OrganismConfig:
|
||||||
# LLM configuration (optional)
|
# LLM configuration (optional)
|
||||||
llm_config: Dict[str, Any] = field(default_factory=dict)
|
llm_config: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# Process pool configuration (for cpu_bound handlers)
|
||||||
|
process_pool_workers: int = 4
|
||||||
|
process_pool_max_tasks_per_child: int = 100
|
||||||
|
process_pool_enabled: bool = False
|
||||||
|
|
||||||
|
# Backend configuration (for shared state)
|
||||||
|
backend_type: str = "memory" # "memory", "manager", "redis"
|
||||||
|
backend_redis_url: str = "redis://localhost:6379"
|
||||||
|
backend_redis_prefix: str = "xp:"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Listener:
|
class Listener:
|
||||||
|
|
@ -79,6 +101,8 @@ class Listener:
|
||||||
is_agent: bool = False
|
is_agent: bool = False
|
||||||
peers: List[str] = field(default_factory=list)
|
peers: List[str] = field(default_factory=list)
|
||||||
broadcast: bool = False
|
broadcast: bool = False
|
||||||
|
cpu_bound: bool = False # Dispatch to ProcessPoolExecutor if True
|
||||||
|
handler_path: str = "" # Import path for worker process
|
||||||
schema: etree.XMLSchema = field(default=None, repr=False)
|
schema: etree.XMLSchema = field(default=None, repr=False)
|
||||||
root_tag: str = ""
|
root_tag: str = ""
|
||||||
usage_instructions: str = "" # Generated at registration for LLM agents
|
usage_instructions: str = "" # Generated at registration for LLM agents
|
||||||
|
|
@ -170,6 +194,10 @@ class StreamPump:
|
||||||
|
|
||||||
The entire flow is a single composable stream pipeline.
|
The entire flow is a single composable stream pipeline.
|
||||||
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
|
Fan-out is natural via flatmap. Concurrency is controlled via task_limit.
|
||||||
|
|
||||||
|
CPU-bound handlers can be dispatched to a ProcessPoolExecutor by
|
||||||
|
marking them with `cpu_bound: true` in config. This requires a
|
||||||
|
shared backend (Redis or Manager) for cross-process data access.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config: OrganismConfig):
|
def __init__(self, config: OrganismConfig):
|
||||||
|
|
@ -188,6 +216,29 @@ class StreamPump:
|
||||||
# Shutdown control
|
# Shutdown control
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
|
# Process pool for cpu_bound handlers
|
||||||
|
self._process_pool: Optional[ProcessPoolExecutor] = None
|
||||||
|
if config.process_pool_enabled:
|
||||||
|
self._process_pool = ProcessPoolExecutor(
|
||||||
|
max_workers=config.process_pool_workers,
|
||||||
|
max_tasks_per_child=config.process_pool_max_tasks_per_child,
|
||||||
|
)
|
||||||
|
pump_logger.info(
|
||||||
|
f"ProcessPool initialized: {config.process_pool_workers} workers"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shared backend for cross-process state
|
||||||
|
self._shared_backend = None
|
||||||
|
if config.backend_type != "memory":
|
||||||
|
from xml_pipeline.memory.shared_backend import BackendConfig, get_shared_backend
|
||||||
|
backend_config = BackendConfig(
|
||||||
|
backend_type=config.backend_type,
|
||||||
|
redis_url=config.backend_redis_url,
|
||||||
|
redis_prefix=config.backend_redis_prefix,
|
||||||
|
)
|
||||||
|
self._shared_backend = get_shared_backend(backend_config)
|
||||||
|
pump_logger.info(f"Shared backend: {config.backend_type}")
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Registration
|
# Registration
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|
@ -203,6 +254,8 @@ class StreamPump:
|
||||||
is_agent=lc.is_agent,
|
is_agent=lc.is_agent,
|
||||||
peers=lc.peers,
|
peers=lc.peers,
|
||||||
broadcast=lc.broadcast,
|
broadcast=lc.broadcast,
|
||||||
|
cpu_bound=lc.cpu_bound,
|
||||||
|
handler_path=lc.handler_path, # For worker process import
|
||||||
schema=self._generate_schema(lc.payload_class),
|
schema=self._generate_schema(lc.payload_class),
|
||||||
root_tag=root_tag,
|
root_tag=root_tag,
|
||||||
)
|
)
|
||||||
|
|
@ -420,7 +473,15 @@ class StreamPump:
|
||||||
)
|
)
|
||||||
payload_ref = state.payload
|
payload_ref = state.payload
|
||||||
|
|
||||||
response = await listener.handler(payload_ref, metadata)
|
# Dispatch to handler - either in-process or via ProcessPool
|
||||||
|
if listener.cpu_bound and self._process_pool and self._shared_backend:
|
||||||
|
response = await self._dispatch_to_process_pool(
|
||||||
|
listener=listener,
|
||||||
|
payload=payload_ref,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
response = await listener.handler(payload_ref, metadata)
|
||||||
|
|
||||||
# None means "no response needed" - don't re-inject
|
# None means "no response needed" - don't re-inject
|
||||||
if response is None:
|
if response is None:
|
||||||
|
|
@ -549,6 +610,94 @@ class StreamPump:
|
||||||
</message>"""
|
</message>"""
|
||||||
return envelope.encode('utf-8')
|
return envelope.encode('utf-8')
|
||||||
|
|
||||||
|
async def _dispatch_to_process_pool(
|
||||||
|
self,
|
||||||
|
listener: Listener,
|
||||||
|
payload: Any,
|
||||||
|
metadata: HandlerMetadata,
|
||||||
|
) -> Optional[HandlerResponse]:
|
||||||
|
"""
|
||||||
|
Dispatch handler to ProcessPoolExecutor for CPU-bound execution.
|
||||||
|
|
||||||
|
This offloads work to a separate process to avoid blocking
|
||||||
|
the main event loop.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
listener: The target listener
|
||||||
|
payload: The @xmlify dataclass payload
|
||||||
|
metadata: Handler metadata
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HandlerResponse or None (same as direct handler call)
|
||||||
|
"""
|
||||||
|
from xml_pipeline.message_bus.worker import (
|
||||||
|
WorkerTask,
|
||||||
|
store_task_data,
|
||||||
|
fetch_response,
|
||||||
|
cleanup_task_data,
|
||||||
|
execute_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert self._process_pool is not None
|
||||||
|
assert self._shared_backend is not None
|
||||||
|
|
||||||
|
# Store payload and metadata in shared backend
|
||||||
|
payload_uuid, metadata_uuid = store_task_data(
|
||||||
|
self._shared_backend, payload, metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create worker task
|
||||||
|
task = WorkerTask(
|
||||||
|
thread_uuid=metadata.thread_id,
|
||||||
|
payload_uuid=payload_uuid,
|
||||||
|
handler_path=listener.handler_path,
|
||||||
|
metadata_uuid=metadata_uuid,
|
||||||
|
listener_name=listener.name,
|
||||||
|
is_agent=listener.is_agent,
|
||||||
|
peers=list(listener.peers),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Backend config for worker process
|
||||||
|
backend_config = {
|
||||||
|
"backend_type": self.config.backend_type,
|
||||||
|
"redis_url": self.config.backend_redis_url,
|
||||||
|
"redis_prefix": self.config.backend_redis_prefix,
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Submit to process pool and await result
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
result = await loop.run_in_executor(
|
||||||
|
self._process_pool,
|
||||||
|
execute_handler,
|
||||||
|
task,
|
||||||
|
backend_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not result.success:
|
||||||
|
pump_logger.error(
|
||||||
|
f"Worker error for {listener.name}: {result.error}"
|
||||||
|
)
|
||||||
|
if result.error_traceback:
|
||||||
|
pump_logger.debug(f"Traceback: {result.error_traceback}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Fetch response from shared backend
|
||||||
|
if result.response_uuid:
|
||||||
|
response = fetch_response(self._shared_backend, result.response_uuid)
|
||||||
|
return response
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# Clean up task data from backend
|
||||||
|
cleanup_task_data(
|
||||||
|
self._shared_backend,
|
||||||
|
payload_uuid,
|
||||||
|
metadata_uuid,
|
||||||
|
result.response_uuid if 'result' in dir() and result.success else None,
|
||||||
|
)
|
||||||
|
|
||||||
async def _reinject_responses(self, state: MessageState) -> None:
|
async def _reinject_responses(self, state: MessageState) -> None:
|
||||||
"""Push handler responses back into the queue for next iteration."""
|
"""Push handler responses back into the queue for next iteration."""
|
||||||
await self.queue.put(state)
|
await self.queue.put(state)
|
||||||
|
|
@ -714,10 +863,16 @@ class StreamPump:
|
||||||
await self.queue.put(state)
|
await self.queue.put(state)
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""Graceful shutdown — wait for queue to drain."""
|
"""Graceful shutdown — wait for queue to drain and close resources."""
|
||||||
self._running = False
|
self._running = False
|
||||||
await self.queue.join()
|
await self.queue.join()
|
||||||
|
|
||||||
|
# Shutdown process pool if active
|
||||||
|
if self._process_pool:
|
||||||
|
self._process_pool.shutdown(wait=True)
|
||||||
|
pump_logger.info("ProcessPool shutdown complete")
|
||||||
|
self._process_pool = None
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Config Loader (same as before)
|
# Config Loader (same as before)
|
||||||
|
|
@ -733,6 +888,19 @@ class ConfigLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def _parse(cls, raw: dict) -> OrganismConfig:
|
def _parse(cls, raw: dict) -> OrganismConfig:
|
||||||
org = raw.get("organism", {})
|
org = raw.get("organism", {})
|
||||||
|
|
||||||
|
# Parse process pool config
|
||||||
|
pool = raw.get("process_pool", {})
|
||||||
|
process_pool_enabled = pool.get("enabled", False) if pool else False
|
||||||
|
process_pool_workers = pool.get("workers", 4) if pool else 4
|
||||||
|
process_pool_max_tasks = pool.get("max_tasks_per_child", 100) if pool else 100
|
||||||
|
|
||||||
|
# Parse backend config
|
||||||
|
backend = raw.get("backend", {})
|
||||||
|
backend_type = backend.get("type", "memory") if backend else "memory"
|
||||||
|
backend_redis_url = backend.get("redis_url", "redis://localhost:6379") if backend else "redis://localhost:6379"
|
||||||
|
backend_redis_prefix = backend.get("redis_prefix", "xp:") if backend else "xp:"
|
||||||
|
|
||||||
config = OrganismConfig(
|
config = OrganismConfig(
|
||||||
name=org.get("name", "unnamed"),
|
name=org.get("name", "unnamed"),
|
||||||
identity_path=org.get("identity", ""),
|
identity_path=org.get("identity", ""),
|
||||||
|
|
@ -742,6 +910,12 @@ class ConfigLoader:
|
||||||
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||||
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||||
llm_config=raw.get("llm", {}),
|
llm_config=raw.get("llm", {}),
|
||||||
|
process_pool_enabled=process_pool_enabled,
|
||||||
|
process_pool_workers=process_pool_workers,
|
||||||
|
process_pool_max_tasks_per_child=process_pool_max_tasks,
|
||||||
|
backend_type=backend_type,
|
||||||
|
backend_redis_url=backend_redis_url,
|
||||||
|
backend_redis_prefix=backend_redis_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
for entry in raw.get("listeners", []):
|
for entry in raw.get("listeners", []):
|
||||||
|
|
@ -762,6 +936,7 @@ class ConfigLoader:
|
||||||
peers=raw.get("peers", []),
|
peers=raw.get("peers", []),
|
||||||
broadcast=raw.get("broadcast", False),
|
broadcast=raw.get("broadcast", False),
|
||||||
prompt=raw.get("prompt", ""),
|
prompt=raw.get("prompt", ""),
|
||||||
|
cpu_bound=raw.get("cpu_bound", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,26 @@ Response routing:
|
||||||
2. Prunes the last segment (the responder)
|
2. Prunes the last segment (the responder)
|
||||||
3. Routes to the new last segment (the caller)
|
3. Routes to the new last segment (the caller)
|
||||||
4. Updates/cleans up the registry
|
4. Updates/cleans up the registry
|
||||||
|
|
||||||
|
For multi-process deployments, the registry can use a shared backend:
|
||||||
|
from xml_pipeline.memory.shared_backend import get_shared_backend, BackendConfig
|
||||||
|
|
||||||
|
config = BackendConfig(backend_type="redis", redis_url="redis://localhost:6379")
|
||||||
|
backend = get_shared_backend(config)
|
||||||
|
registry = get_registry(backend=backend)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid as uuid_module
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Optional, Tuple
|
from typing import Dict, Optional, Tuple, TYPE_CHECKING
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ThreadRegistry:
|
class ThreadRegistry:
|
||||||
"""
|
"""
|
||||||
Bidirectional mapping between UUIDs and call chains.
|
Bidirectional mapping between UUIDs and call chains.
|
||||||
|
|
@ -32,12 +43,35 @@ class ThreadRegistry:
|
||||||
The registry maintains a root thread established at boot time.
|
The registry maintains a root thread established at boot time.
|
||||||
All external messages without a known parent are registered as
|
All external messages without a known parent are registered as
|
||||||
children of the root thread.
|
children of the root thread.
|
||||||
|
|
||||||
|
Supports two storage modes:
|
||||||
|
1. Local mode (default): Uses in-process dictionaries
|
||||||
|
2. Shared mode: Uses SharedBackend (Redis, Manager) for cross-process access
|
||||||
"""
|
"""
|
||||||
_chain_to_uuid: Dict[str, str] = field(default_factory=dict)
|
|
||||||
_uuid_to_chain: Dict[str, str] = field(default_factory=dict)
|
    def __init__(self, backend: Optional[SharedBackend] = None):
        """
        Initialize thread registry.

        Args:
            backend: Optional shared backend for cross-process storage.
                If None, uses in-process storage (original behavior).
        """
        self._backend = backend

        # Local storage (used when no backend). The lock guards only these
        # dicts — shared-backend paths bypass it entirely.
        self._chain_to_uuid: Dict[str, str] = {}
        self._uuid_to_chain: Dict[str, str] = {}
        self._lock = threading.Lock()

        # Root thread tracking
        self._root_uuid: Optional[str] = None
        self._root_chain: str = "system"

    @property
    def is_shared(self) -> bool:
        """Return True if using shared backend."""
        return self._backend is not None
|
||||||
|
|
||||||
def initialize_root(self, organism_name: str = "organism") -> str:
|
def initialize_root(self, organism_name: str = "organism") -> str:
|
||||||
"""
|
"""
|
||||||
|
|
@ -52,16 +86,36 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
UUID for the root thread
|
UUID for the root thread
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._initialize_root_shared(organism_name)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._root_uuid is not None:
|
if self._root_uuid is not None:
|
||||||
return self._root_uuid
|
return self._root_uuid
|
||||||
|
|
||||||
self._root_chain = f"system.{organism_name}"
|
self._root_chain = f"system.{organism_name}"
|
||||||
self._root_uuid = str(uuid.uuid4())
|
self._root_uuid = str(uuid_module.uuid4())
|
||||||
self._chain_to_uuid[self._root_chain] = self._root_uuid
|
self._chain_to_uuid[self._root_chain] = self._root_uuid
|
||||||
self._uuid_to_chain[self._root_uuid] = self._root_chain
|
self._uuid_to_chain[self._root_uuid] = self._root_chain
|
||||||
return self._root_uuid
|
return self._root_uuid
|
||||||
|
|
||||||
|
    def _initialize_root_shared(self, organism_name: str) -> str:
        """Initialize the root thread in the shared backend.

        Idempotent across processes: if another process already registered
        the root chain, its UUID is adopted instead of minting a new one.

        Args:
            organism_name: Name appended to the ``system.`` root chain.

        Returns:
            UUID string of the root thread.
        """
        assert self._backend is not None

        self._root_chain = f"system.{organism_name}"

        # Check if root already exists in backend (another process may
        # have initialized it first).
        existing_uuid = self._backend.registry_get_uuid(self._root_chain)
        if existing_uuid:
            self._root_uuid = existing_uuid
            return existing_uuid

        # Create new root.
        # NOTE(review): get-then-set is not atomic across processes — two
        # racing initializers could mint different UUIDs for the same
        # chain; confirm the backend tolerates this or needs set-if-absent.
        self._root_uuid = str(uuid_module.uuid4())
        self._backend.registry_set(self._root_chain, self._root_uuid)
        return self._root_uuid
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def root_uuid(self) -> Optional[str]:
|
def root_uuid(self) -> Optional[str]:
|
||||||
"""Get the root thread UUID (None if not initialized)."""
|
"""Get the root thread UUID (None if not initialized)."""
|
||||||
|
|
@ -82,11 +136,19 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
UUID string for this chain
|
UUID string for this chain
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
existing = self._backend.registry_get_uuid(chain)
|
||||||
|
if existing:
|
||||||
|
return existing
|
||||||
|
new_uuid = str(uuid_module.uuid4())
|
||||||
|
self._backend.registry_set(chain, new_uuid)
|
||||||
|
return new_uuid
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if chain in self._chain_to_uuid:
|
if chain in self._chain_to_uuid:
|
||||||
return self._chain_to_uuid[chain]
|
return self._chain_to_uuid[chain]
|
||||||
|
|
||||||
new_uuid = str(uuid.uuid4())
|
new_uuid = str(uuid_module.uuid4())
|
||||||
self._chain_to_uuid[chain] = new_uuid
|
self._chain_to_uuid[chain] = new_uuid
|
||||||
self._uuid_to_chain[new_uuid] = chain
|
self._uuid_to_chain[new_uuid] = chain
|
||||||
return new_uuid
|
return new_uuid
|
||||||
|
|
@ -101,6 +163,9 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
Chain string, or None if not found
|
Chain string, or None if not found
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._backend.registry_get_chain(thread_id)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return self._uuid_to_chain.get(thread_id)
|
return self._uuid_to_chain.get(thread_id)
|
||||||
|
|
||||||
|
|
@ -115,6 +180,9 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
UUID for the extended chain
|
UUID for the extended chain
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._extend_chain_shared(current_uuid, next_hop)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
current_chain = self._uuid_to_chain.get(current_uuid, "")
|
current_chain = self._uuid_to_chain.get(current_uuid, "")
|
||||||
if current_chain:
|
if current_chain:
|
||||||
|
|
@ -127,11 +195,31 @@ class ThreadRegistry:
|
||||||
return self._chain_to_uuid[new_chain]
|
return self._chain_to_uuid[new_chain]
|
||||||
|
|
||||||
# Create new UUID for extended chain
|
# Create new UUID for extended chain
|
||||||
new_uuid = str(uuid.uuid4())
|
new_uuid = str(uuid_module.uuid4())
|
||||||
self._chain_to_uuid[new_chain] = new_uuid
|
self._chain_to_uuid[new_chain] = new_uuid
|
||||||
self._uuid_to_chain[new_uuid] = new_chain
|
self._uuid_to_chain[new_uuid] = new_chain
|
||||||
return new_uuid
|
return new_uuid
|
||||||
|
|
||||||
|
    def _extend_chain_shared(self, current_uuid: str, next_hop: str) -> str:
        """Extend a call chain by one hop using the shared backend.

        Args:
            current_uuid: UUID of the current chain; an unknown UUID
                degrades to starting a fresh chain at ``next_hop``.
            next_hop: Listener name appended as the new last segment.

        Returns:
            UUID for the extended chain (reused when it already exists).
        """
        assert self._backend is not None

        current_chain = self._backend.registry_get_chain(current_uuid) or ""
        if current_chain:
            new_chain = f"{current_chain}.{next_hop}"
        else:
            new_chain = next_hop

        # Check if extended chain already exists
        existing = self._backend.registry_get_uuid(new_chain)
        if existing:
            return existing

        # Create new UUID for extended chain.
        # NOTE(review): check-then-set is not atomic across processes;
        # confirm duplicate UUIDs for the same chain are acceptable.
        new_uuid = str(uuid_module.uuid4())
        self._backend.registry_set(new_chain, new_uuid)
        return new_uuid
|
||||||
|
|
||||||
def prune_for_response(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]:
|
def prune_for_response(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Prune chain for a response and get the target.
|
Prune chain for a response and get the target.
|
||||||
|
|
@ -147,6 +235,9 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (target_listener, new_thread_uuid) or (None, None) if chain exhausted
|
Tuple of (target_listener, new_thread_uuid) or (None, None) if chain exhausted
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._prune_for_response_shared(thread_id)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
chain = self._uuid_to_chain.get(thread_id)
|
chain = self._uuid_to_chain.get(thread_id)
|
||||||
if not chain:
|
if not chain:
|
||||||
|
|
@ -168,15 +259,40 @@ class ThreadRegistry:
|
||||||
if pruned_chain in self._chain_to_uuid:
|
if pruned_chain in self._chain_to_uuid:
|
||||||
new_uuid = self._chain_to_uuid[pruned_chain]
|
new_uuid = self._chain_to_uuid[pruned_chain]
|
||||||
else:
|
else:
|
||||||
new_uuid = str(uuid.uuid4())
|
new_uuid = str(uuid_module.uuid4())
|
||||||
self._chain_to_uuid[pruned_chain] = new_uuid
|
self._chain_to_uuid[pruned_chain] = new_uuid
|
||||||
self._uuid_to_chain[new_uuid] = pruned_chain
|
self._uuid_to_chain[new_uuid] = pruned_chain
|
||||||
|
|
||||||
# Clean up old UUID (optional - could keep for debugging)
|
|
||||||
# self._cleanup_uuid(thread_id)
|
|
||||||
|
|
||||||
return target, new_uuid
|
return target, new_uuid
|
||||||
|
|
||||||
|
def _prune_for_response_shared(self, thread_id: str) -> Tuple[Optional[str], Optional[str]]:
    """Drop the last hop of the chain behind *thread_id* (shared-backend mode).

    Returns (target_listener, uuid_of_pruned_chain), or (None, None) when
    the chain is unknown or already down to a single hop (exhausted).
    """
    assert self._backend is not None

    chain = self._backend.registry_get_chain(thread_id)
    if not chain:
        return None, None

    head, sep, _last = chain.rpartition(".")
    if not sep:
        # Single-hop chain: nothing left to respond to — drop the mapping.
        self._backend.registry_delete(thread_id)
        return None, None

    pruned = head
    # Final hop of the pruned chain is the listener to deliver to.
    target = pruned.rpartition(".")[2]

    # Prefer an existing UUID bound to the pruned chain.
    mapped = self._backend.registry_get_uuid(pruned)
    if mapped:
        return target, mapped

    minted = str(uuid_module.uuid4())
    self._backend.registry_set(pruned, minted)
    return target, minted
|
||||||
def start_chain(self, initiator: str, target: str) -> str:
|
def start_chain(self, initiator: str, target: str) -> str:
|
||||||
"""
|
"""
|
||||||
Start a new call chain.
|
Start a new call chain.
|
||||||
|
|
@ -208,6 +324,9 @@ class ThreadRegistry:
|
||||||
Returns:
|
Returns:
|
||||||
The same thread_id (now registered)
|
The same thread_id (now registered)
|
||||||
"""
|
"""
|
||||||
|
if self._backend is not None:
|
||||||
|
return self._register_thread_shared(thread_id, initiator, target)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Check if UUID already registered (shouldn't happen, but be safe)
|
# Check if UUID already registered (shouldn't happen, but be safe)
|
||||||
if thread_id in self._uuid_to_chain:
|
if thread_id in self._uuid_to_chain:
|
||||||
|
|
@ -230,6 +349,29 @@ class ThreadRegistry:
|
||||||
self._uuid_to_chain[thread_id] = chain
|
self._uuid_to_chain[thread_id] = chain
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
def _register_thread_shared(self, thread_id: str, initiator: str, target: str) -> str:
    """Register an externally-created thread UUID in the shared backend.

    Returns *thread_id* itself, or — when the same chain is already bound
    to a different UUID — the existing owner's UUID.
    """
    assert self._backend is not None

    # Already registered? Nothing to do.
    if self._backend.registry_get_chain(thread_id):
        return thread_id

    # Root the chain at the system root when one has been established.
    hops = [initiator, target]
    if self._root_uuid is not None:
        hops.insert(0, self._root_chain)
    chain = ".".join(hops)

    # A different UUID may already own this chain; prefer the owner.
    owner = self._backend.registry_get_uuid(chain)
    if owner:
        return owner

    self._backend.registry_set(chain, thread_id)
    return thread_id
|
|
||||||
def _cleanup_uuid(self, thread_id: str) -> None:
|
def _cleanup_uuid(self, thread_id: str) -> None:
|
||||||
"""Remove a UUID mapping (internal, call with lock held)."""
|
"""Remove a UUID mapping (internal, call with lock held)."""
|
||||||
chain = self._uuid_to_chain.pop(thread_id, None)
|
chain = self._uuid_to_chain.pop(thread_id, None)
|
||||||
|
|
@ -238,22 +380,65 @@ class ThreadRegistry:
|
||||||
|
|
||||||
def cleanup(self, thread_id: str) -> None:
    """Explicitly discard the mapping for *thread_id*."""
    if self._backend is None:
        # Local mode: mutate the in-process maps under the lock.
        with self._lock:
            self._cleanup_uuid(thread_id)
    else:
        self._backend.registry_delete(thread_id)
|
|
||||||
def debug_dump(self) -> Dict[str, str]:
    """Return a snapshot of the current uuid -> chain mappings (debugging aid)."""
    if self._backend is None:
        with self._lock:
            return dict(self._uuid_to_chain)
    return self._backend.registry_list_all()
|
|
||||||
|
def clear(self) -> None:
    """Reset all thread mappings and the root chain (for testing only)."""
    if self._backend is None:
        with self._lock:
            self._chain_to_uuid.clear()
            self._uuid_to_chain.clear()
            self._root_uuid = None
            self._root_chain = "system"
    else:
        self._backend.registry_clear()
        self._root_uuid = None
        self._root_chain = "system"
|
||||||
|
|
||||||
# Global registry instance
# Module-level singleton: created lazily by get_registry() and torn down by
# reset_registry(). _registry_lock guards first creation across threads
# (double-checked locking in get_registry()).
_registry: Optional[ThreadRegistry] = None

_registry_lock = threading.Lock()
|
|
||||||
def get_registry(backend: Optional[SharedBackend] = None) -> ThreadRegistry:
    """
    Get the global thread registry.

    Args:
        backend: Optional shared backend for cross-process storage.
            Honoured only on the call that actually creates the singleton;
            subsequent calls return the existing instance unchanged.

    Returns:
        Global ThreadRegistry instance.
    """
    global _registry
    # Double-checked locking: the unlocked fast path avoids contention
    # once the singleton exists.
    registry = _registry
    if registry is None:
        with _registry_lock:
            if _registry is None:
                _registry = ThreadRegistry(backend=backend)
            registry = _registry
    return registry
|
|
||||||
|
|
||||||
|
def reset_registry() -> None:
    """Tear down the global thread registry (testing helper)."""
    global _registry
    with _registry_lock:
        existing = _registry
        if existing is not None:
            existing.clear()
        _registry = None
|
|
||||||
339
xml_pipeline/message_bus/worker.py
Normal file
339
xml_pipeline/message_bus/worker.py
Normal file
|
|
@ -0,0 +1,339 @@
|
||||||
|
"""
|
||||||
|
worker.py — Worker process entry point for CPU-bound handler dispatch.
|
||||||
|
|
||||||
|
This module provides the entry point for handler execution in worker processes
|
||||||
|
when using ProcessPoolExecutor. Workers communicate with the main process via
|
||||||
|
the shared backend (Redis or Manager).
|
||||||
|
|
||||||
|
Architecture:
|
||||||
|
Main Process (StreamPump)
|
||||||
|
│
|
||||||
|
│ Submits WorkerTask to ProcessPoolExecutor
|
||||||
|
│
|
||||||
|
▼
|
||||||
|
Worker Process (this module)
|
||||||
|
│
|
||||||
|
├── Imports handler module
|
||||||
|
├── Fetches payload from shared backend
|
||||||
|
├── Executes handler
|
||||||
|
├── Stores response in shared backend
|
||||||
|
└── Returns WorkerResult
|
||||||
|
|
||||||
|
Key design decisions:
|
||||||
|
- Minimal IPC payload: Only UUIDs and module paths cross process boundary
|
||||||
|
- Handler state via shared backend: Workers fetch/store data in Redis
|
||||||
|
- Process reuse: Workers are reused across tasks (max_tasks_per_child for restart)
|
||||||
|
- Error isolation: Handler exceptions don't crash the worker pool
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Main process submits task
|
||||||
|
from concurrent.futures import ProcessPoolExecutor
|
||||||
|
from xml_pipeline.message_bus.worker import execute_handler, WorkerTask
|
||||||
|
|
||||||
|
pool = ProcessPoolExecutor(max_workers=4)
|
||||||
|
task = WorkerTask(
|
||||||
|
thread_uuid="550e8400-...",
|
||||||
|
payload_uuid="6ba7b810-...",
|
||||||
|
handler_path="handlers.librarian.handle_query",
|
||||||
|
metadata_uuid="7c9e6679-...",
|
||||||
|
)
|
||||||
|
future = pool.submit(execute_handler, task)
|
||||||
|
result = future.result() # WorkerResult
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from xml_pipeline.memory.shared_backend import SharedBackend
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class WorkerTask:
    """
    Task submitted to worker process.

    Contains only UUIDs and paths — actual data lives in shared backend.
    This minimizes IPC overhead and allows large payloads.
    """

    # Thread UUID the task belongs to — presumably a ThreadRegistry UUID; confirm against caller.
    thread_uuid: str
    payload_uuid: str  # UUID for looking up serialized payload
    handler_path: str  # e.g., "handlers.librarian.handle_query"
    metadata_uuid: str  # UUID for looking up serialized HandlerMetadata
    listener_name: str  # Name of the listener (for logging/response injection)
    is_agent: bool = False  # Whether this is an agent handler
    peers: list[str] = field(default_factory=list)  # Allowed peers (for agents)
|
|
||||||
|
@dataclass
class WorkerResult:
    """
    Result from worker process.

    Contains UUID for response stored in backend, or error information.
    """

    # True when the handler ran to completion; False when it raised.
    success: bool
    response_uuid: Optional[str] = None  # UUID for serialized HandlerResponse
    error: Optional[str] = None  # Error message if failed
    error_traceback: Optional[str] = None  # Full traceback for debugging
    elapsed_ms: float = 0.0  # Execution time
||||||
|
|
||||||
|
# Keys for storing task data in shared backend
|
||||||
|
def _payload_key(uuid: str) -> str:
|
||||||
|
return f"worker:payload:{uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def _metadata_key(uuid: str) -> str:
|
||||||
|
return f"worker:metadata:{uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def _response_key(uuid: str) -> str:
|
||||||
|
return f"worker:response:{uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
def _import_handler(handler_path: str) -> Callable:
|
||||||
|
"""
|
||||||
|
Import and return handler function from path.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
handler_path: Dotted path like "handlers.librarian.handle_query"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The handler function
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ImportError: If module not found
|
||||||
|
AttributeError: If function not found in module
|
||||||
|
"""
|
||||||
|
module_path, func_name = handler_path.rsplit(".", 1)
|
||||||
|
module = importlib.import_module(module_path)
|
||||||
|
return getattr(module, func_name)
|
||||||
|
|
||||||
|
|
||||||
|
def store_task_data(
    backend: SharedBackend,
    payload: Any,
    metadata: Any,
    ttl: int = 3600,
) -> tuple[str, str]:
    """
    Stash a task's payload and metadata in the shared backend for worker access.

    Args:
        backend: Shared backend instance
        payload: The @xmlify dataclass payload
        metadata: HandlerMetadata instance
        ttl: Time-to-live in seconds (default 1 hour).
            NOTE(review): currently not forwarded to the backend — entries
            are only removed via cleanup_task_data(); confirm whether a
            backend-side TTL should apply here.

    Returns:
        Tuple of (payload_uuid, metadata_uuid)
    """
    import pickle

    blobs = {
        "payload": pickle.dumps(payload),
        "metadata": pickle.dumps(metadata),
    }

    slot_ids = {}
    for slot_kind, blob in blobs.items():
        slot_uuid = str(uuid.uuid4())
        # The "worker" prefix keeps these apart from regular buffer slots.
        backend.buffer_append(f"worker:{slot_kind}:{slot_uuid}", blob)
        slot_ids[slot_kind] = slot_uuid

    return slot_ids["payload"], slot_ids["metadata"]
|
|
||||||
|
|
||||||
|
def fetch_task_data(
    backend: SharedBackend,
    payload_uuid: str,
    metadata_uuid: str,
) -> tuple[Any, Any]:
    """
    Load a task's payload and metadata back out of the shared backend.

    Args:
        backend: Shared backend instance
        payload_uuid: UUID of stored payload
        metadata_uuid: UUID of stored metadata

    Returns:
        Tuple of (payload, metadata)

    Raises:
        ValueError: If either blob is missing from the backend.
    """
    import pickle

    raw_payload = backend.buffer_get_slot(f"worker:payload:{payload_uuid}", 0)
    raw_metadata = backend.buffer_get_slot(f"worker:metadata:{metadata_uuid}", 0)

    # Validate both fetches before deserializing anything.
    for raw, label, slot in (
        (raw_payload, "Payload", payload_uuid),
        (raw_metadata, "Metadata", metadata_uuid),
    ):
        if raw is None:
            raise ValueError(f"{label} not found: {slot}")

    return pickle.loads(raw_payload), pickle.loads(raw_metadata)
|
|
||||||
|
|
||||||
|
def store_response(
    backend: SharedBackend,
    response: Any,
) -> str:
    """
    Persist a handler response in the shared backend for main-process retrieval.

    Args:
        backend: Shared backend instance
        response: HandlerResponse or None

    Returns:
        Response UUID
    """
    import pickle

    slot_uuid = str(uuid.uuid4())
    backend.buffer_append(f"worker:response:{slot_uuid}", pickle.dumps(response))
    return slot_uuid
|
|
||||||
|
|
||||||
|
def fetch_response(
    backend: SharedBackend,
    response_uuid: str,
) -> Any:
    """
    Retrieve a previously stored handler response from the shared backend.

    Args:
        backend: Shared backend instance
        response_uuid: UUID of stored response

    Returns:
        HandlerResponse or None

    Raises:
        ValueError: If no response exists under *response_uuid*.
    """
    import pickle

    raw = backend.buffer_get_slot(f"worker:response:{response_uuid}", 0)
    if raw is None:
        raise ValueError(f"Response not found: {response_uuid}")
    return pickle.loads(raw)
|
|
||||||
|
|
||||||
|
def cleanup_task_data(
    backend: SharedBackend,
    payload_uuid: str,
    metadata_uuid: str,
    response_uuid: Optional[str] = None,
) -> None:
    """
    Remove a task's temporary blobs from the shared backend.

    Call after response has been processed by main process.
    """
    doomed = [
        f"worker:payload:{payload_uuid}",
        f"worker:metadata:{metadata_uuid}",
    ]
    if response_uuid:
        doomed.append(f"worker:response:{response_uuid}")
    for key in doomed:
        backend.buffer_delete_thread(key)
|
|
||||||
|
|
||||||
|
def execute_handler(
    task: WorkerTask,
    backend_config: Optional[dict] = None,
) -> WorkerResult:
    """
    Execute a handler in the worker process.

    This is the entry point called by ProcessPoolExecutor.

    Args:
        task: WorkerTask containing UUIDs and handler path
        backend_config: Optional backend configuration dict
            (if None, uses default shared backend)

    Returns:
        WorkerResult with response UUID or error
    """
    import asyncio
    import inspect
    import time

    start_time = time.monotonic()

    try:
        # Get shared backend
        from xml_pipeline.memory.shared_backend import (
            BackendConfig,
            get_shared_backend,
        )

        if backend_config:
            config = BackendConfig(**backend_config)
            backend = get_shared_backend(config)
        else:
            backend = get_shared_backend()

        # Fetch payload and metadata
        payload, metadata = fetch_task_data(
            backend, task.payload_uuid, task.metadata_uuid
        )

        # Import handler
        handler = _import_handler(task.handler_path)

        # Execute handler. Async handlers run on a private event loop owned
        # by this worker; plain synchronous handlers (pure computation, as
        # served by execute_handler_sync) are called directly — previously
        # a sync handler's return value was passed to run_until_complete,
        # which raises on non-awaitables.
        result = handler(payload, metadata)
        if inspect.isawaitable(result):
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                response = loop.run_until_complete(result)
            finally:
                loop.close()
                # Don't leave a closed loop installed as the thread's
                # "current" loop for the next task on this worker.
                asyncio.set_event_loop(None)
        else:
            response = result

        # Store response
        response_uuid = store_response(backend, response)

        elapsed = (time.monotonic() - start_time) * 1000

        return WorkerResult(
            success=True,
            response_uuid=response_uuid,
            elapsed_ms=elapsed,
        )

    except Exception as e:
        # Broad catch is deliberate: a handler failure must not crash the
        # worker pool; it is reported back as a failed WorkerResult.
        elapsed = (time.monotonic() - start_time) * 1000
        logger.exception(f"Handler execution failed: {task.handler_path}")

        return WorkerResult(
            success=False,
            error=str(e),
            error_traceback=traceback.format_exc(),
            elapsed_ms=elapsed,
        )
|
|
||||||
|
|
||||||
|
def execute_handler_sync(
    task: WorkerTask,
    backend_config: Optional[dict] = None,
) -> WorkerResult:
    """
    Synchronous wrapper for execute_handler.

    Use this when the handler doesn't need async features
    (pure computation, no I/O).
    """
    # Pure delegation; exists as a distinct entry point so call sites can
    # signal intent when submitting CPU-only handlers to the pool.
    return execute_handler(task, backend_config)
|
||||||
Loading…
Reference in a new issue