# Introduces SharedBackend Protocol for cross-process state sharing:
#   - InMemoryBackend: default single-process storage
#   - ManagerBackend: multiprocessing.Manager for local multi-process
#   - RedisBackend: distributed deployments with TTL auto-GC
# Adds ProcessPoolExecutor support for CPU-bound handlers:
#   - worker.py: worker process entry point
#   - stream_pump.py: cpu_bound handler dispatch
#   - Config: backend and process_pool sections in organism.yaml
# ContextBuffer and ThreadRegistry now accept an optional backend parameter
# while maintaining full backward compatibility.
# Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
"""
|
|
redis_backend.py — Redis implementation of SharedBackend.
|
|
|
|
Uses Redis for distributed shared state, enabling:
|
|
- Multi-process handler dispatch (cpu_bound handlers in ProcessPool)
|
|
- Multi-tenant deployments (key prefix isolation)
|
|
- Automatic TTL-based garbage collection
|
|
- Cross-machine state sharing (future: WASM handlers)
|
|
|
|
Key schema:
|
|
- Buffer slots: `{prefix}buffer:{thread_id}` → Redis LIST of pickled slots
|
|
- Registry chain→uuid: `{prefix}chain:{chain}` → uuid string
|
|
- Registry uuid→chain: `{prefix}uuid:{uuid}` → chain string
|
|
|
|
Dependencies:
|
|
pip install redis
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import Dict, List, Optional
|
|
|
|
try:
|
|
import redis
|
|
except ImportError:
|
|
redis = None # type: ignore
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class RedisBackend:
    """
    Redis-backed shared state for multi-process deployments.

    All operations are synchronous (blocking). For async contexts,
    wrap calls in asyncio.to_thread() or use a connection pool.

    Key schema (every key carries ``self.prefix`` for tenant isolation):

    - Buffer slots:   ``{prefix}buffer:{thread_id}`` → Redis LIST of pickled slots
    - chain → uuid:   ``{prefix}chain:{chain}``      → uuid string
    - uuid → chain:   ``{prefix}uuid:{uuid}``        → chain string
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        prefix: str = "xp:",
        ttl: int = 86400,
    ) -> None:
        """
        Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for multi-tenancy isolation
            ttl: TTL in seconds for automatic cleanup (0 = no TTL)

        Raises:
            ImportError: if the ``redis`` package is not installed.
            redis.ConnectionError: if the server cannot be reached.
        """
        if redis is None:
            raise ImportError(
                "redis package not installed. Install with: pip install redis"
            )

        self.url = url
        self.prefix = prefix
        self.ttl = ttl

        # decode_responses=False because buffer slots are raw pickled bytes;
        # string replies are decoded manually via _decode() where needed.
        self._client = redis.from_url(url, decode_responses=False)

        # Fail fast on an unreachable server rather than at first use.
        try:
            self._client.ping()
            logger.info(f"Redis backend connected: {url}")
        except redis.ConnectionError as e:
            logger.error(f"Redis connection failed: {e}")
            raise

    # =========================================================================
    # Key helpers
    # =========================================================================

    @staticmethod
    def _decode(value) -> str:
        """Normalize a Redis reply to str (client returns raw bytes)."""
        return value.decode() if isinstance(value, bytes) else value

    def _buffer_key(self, thread_id: str) -> str:
        """Get Redis key for a thread's buffer."""
        return f"{self.prefix}buffer:{thread_id}"

    def _chain_key(self, chain: str) -> str:
        """Get Redis key for chain→uuid mapping."""
        return f"{self.prefix}chain:{chain}"

    def _uuid_key(self, uuid: str) -> str:
        """Get Redis key for uuid→chain mapping."""
        return f"{self.prefix}uuid:{uuid}"

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """Append a slot to a thread's buffer and return its index."""
        key = self._buffer_key(thread_id)

        # RPUSH returns the new list length, so the new slot's index = length - 1.
        new_len = self._client.rpush(key, slot_data)

        # Refresh TTL on every write so actively-used threads never expire.
        if self.ttl > 0:
            self._client.expire(key, self.ttl)

        return new_len - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (empty list if the thread is unknown)."""
        key = self._buffer_key(thread_id)
        # LRANGE 0 -1 returns the full list.
        result = self._client.lrange(key, 0, -1)
        return result if result else []

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None if out of range."""
        key = self._buffer_key(thread_id)
        # LINDEX returns the element at index, or None.
        return self._client.lindex(key, index)

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 for unknown threads)."""
        key = self._buffer_key(thread_id)
        return self._client.llen(key)

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        key = self._buffer_key(thread_id)
        return self._client.exists(key) > 0

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if anything was deleted."""
        key = self._buffer_key(thread_id)
        deleted = self._client.delete(key)
        return deleted > 0

    def buffer_list_threads(self) -> List[str]:
        """List all thread IDs with slots (bare IDs, prefix stripped)."""
        key_prefix = f"{self.prefix}buffer:"
        # SCAN (cursor-based) instead of KEYS: KEYS blocks the server on large
        # keyspaces and is discouraged in production.
        # BUG FIX: the original computed prefix_len but returned full Redis
        # keys; strip the prefix so callers get bare thread IDs as documented.
        return [
            self._decode(k)[len(key_prefix):]
            for k in self._client.scan_iter(match=f"{key_prefix}*")
        ]

    def buffer_clear(self) -> None:
        """Clear all buffer data for this prefix."""
        keys = list(self._client.scan_iter(match=f"{self.prefix}buffer:*"))
        if keys:
            self._client.delete(*keys)

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """Set bidirectional mapping: chain ↔ uuid."""
        chain_key = self._chain_key(chain)
        uuid_key = self._uuid_key(uuid)

        # Pipeline so both directions (and their TTLs) are sent as one
        # round-trip and applied together.
        pipe = self._client.pipeline()
        pipe.set(chain_key, uuid.encode())
        pipe.set(uuid_key, chain.encode())

        if self.ttl > 0:
            pipe.expire(chain_key, self.ttl)
            pipe.expire(uuid_key, self.ttl)

        pipe.execute()

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain, or None if unmapped."""
        value = self._client.get(self._chain_key(chain))
        return self._decode(value) if value else None

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID, or None if unmapped."""
        value = self._client.get(self._uuid_key(uuid))
        return self._decode(value) if value else None

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID. Returns True if anything was deleted."""
        uuid_key = self._uuid_key(uuid)

        # Look up the chain first so both directions can be removed.
        chain = self._client.get(uuid_key)
        if not chain:
            return False

        chain_key = self._chain_key(self._decode(chain))

        # Delete both keys in one call.
        deleted = self._client.delete(uuid_key, chain_key)
        return deleted > 0

    def registry_list_all(self) -> Dict[str, str]:
        """Get all UUID → chain mappings."""
        key_prefix = f"{self.prefix}uuid:"
        keys = list(self._client.scan_iter(match=f"{key_prefix}*"))
        if not keys:
            return {}

        # One MGET round-trip instead of a GET per key.
        values = self._client.mget(keys)

        result: Dict[str, str] = {}
        for key, chain in zip(keys, values):
            # A key may expire between SCAN and MGET; skip vanished entries.
            if chain:
                result[self._decode(key)[len(key_prefix):]] = self._decode(chain)
        return result

    def registry_clear(self) -> None:
        """Clear all registry data (both chain and uuid directions)."""
        all_keys = list(self._client.scan_iter(match=f"{self.prefix}chain:*"))
        all_keys += list(self._client.scan_iter(match=f"{self.prefix}uuid:*"))
        if all_keys:
            self._client.delete(*all_keys)

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Close Redis connection."""
        if self._client:
            self._client.close()
            logger.info("Redis backend closed")

    # =========================================================================
    # Health Check
    # =========================================================================

    def ping(self) -> bool:
        """Check if Redis is reachable."""
        try:
            return self._client.ping()
        except Exception:
            # Broad catch is intentional: any failure means "not reachable".
            return False

    def info(self) -> Dict[str, int]:
        """Get backend statistics."""
        buffer_keys = sum(
            1 for _ in self._client.scan_iter(match=f"{self.prefix}buffer:*")
        )
        uuid_keys = sum(
            1 for _ in self._client.scan_iter(match=f"{self.prefix}uuid:*")
        )

        return {
            "buffer_threads": buffer_keys,
            "registry_entries": uuid_keys,
            "ttl_seconds": self.ttl,
        }