xml-pipeline/xml_pipeline/memory/redis_backend.py
dullfig 6790c7a46c Add shared backend for multiprocess pipeline support
Introduces SharedBackend Protocol for cross-process state sharing:
- InMemoryBackend: default single-process storage
- ManagerBackend: multiprocessing.Manager for local multi-process
- RedisBackend: distributed deployments with TTL auto-GC

Adds ProcessPoolExecutor support for CPU-bound handlers:
- worker.py: worker process entry point
- stream_pump.py: cpu_bound handler dispatch
- Config: backend and process_pool sections in organism.yaml

ContextBuffer and ThreadRegistry now accept optional backend
parameter while maintaining full backward compatibility.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 20:18:22 -08:00

262 lines
8.3 KiB
Python

"""
redis_backend.py — Redis implementation of SharedBackend.
Uses Redis for distributed shared state, enabling:
- Multi-process handler dispatch (cpu_bound handlers in ProcessPool)
- Multi-tenant deployments (key prefix isolation)
- Automatic TTL-based garbage collection
- Cross-machine state sharing (future: WASM handlers)
Key schema:
- Buffer slots: `{prefix}buffer:{thread_id}` → Redis LIST of pickled slots
- Registry chain→uuid: `{prefix}chain:{chain}` → uuid string
- Registry uuid→chain: `{prefix}uuid:{uuid}` → chain string
Dependencies:
pip install redis
"""
from __future__ import annotations
import logging
from typing import Dict, List, Optional
try:
import redis
except ImportError:
redis = None # type: ignore
logger = logging.getLogger(__name__)
class RedisBackend:
    """
    Redis-backed shared state for multi-process deployments.

    All operations are synchronous (blocking). For async contexts,
    wrap calls in asyncio.to_thread() or use a connection pool.
    """

    def __init__(
        self,
        url: str = "redis://localhost:6379",
        prefix: str = "xp:",
        ttl: int = 86400,
    ) -> None:
        """
        Initialize Redis backend.

        Args:
            url: Redis connection URL
            prefix: Key prefix for multi-tenancy isolation
            ttl: TTL in seconds for automatic cleanup (0 = no TTL)

        Raises:
            ImportError: if the ``redis`` package is not installed.
            redis.ConnectionError: if the initial PING fails.
        """
        if redis is None:
            raise ImportError(
                "redis package not installed. Install with: pip install redis"
            )
        self.url = url
        self.prefix = prefix
        self.ttl = ttl
        # decode_responses=False: buffer slots hold pickled bytes, so we
        # must receive raw bytes and decode strings ourselves where needed.
        self._client = redis.from_url(url, decode_responses=False)
        # Fail fast: verify the server is reachable before first real use.
        try:
            self._client.ping()
            logger.info(f"Redis backend connected: {url}")
        except redis.ConnectionError as e:
            logger.error(f"Redis connection failed: {e}")
            raise

    # -------------------------------------------------------------------------
    # Key construction
    # -------------------------------------------------------------------------

    def _buffer_key(self, thread_id: str) -> str:
        """Get Redis key for a thread's buffer."""
        return f"{self.prefix}buffer:{thread_id}"

    def _chain_key(self, chain: str) -> str:
        """Get Redis key for chain→uuid mapping."""
        return f"{self.prefix}chain:{chain}"

    def _uuid_key(self, uuid: str) -> str:
        """Get Redis key for uuid→chain mapping."""
        return f"{self.prefix}uuid:{uuid}"

    # =========================================================================
    # Context Buffer Operations
    # =========================================================================

    def buffer_append(self, thread_id: str, slot_data: bytes) -> int:
        """
        Append a slot to a thread's buffer.

        Returns:
            Zero-based index of the newly appended slot.
        """
        key = self._buffer_key(thread_id)
        # RPUSH returns the new list length, so the new slot's index is len-1.
        new_len = self._client.rpush(key, slot_data)
        # Refresh TTL on every write so active threads never expire mid-use.
        if self.ttl > 0:
            self._client.expire(key, self.ttl)
        return new_len - 1

    def buffer_get_thread(self, thread_id: str) -> List[bytes]:
        """Get all slots for a thread (empty list if the thread is unknown)."""
        key = self._buffer_key(thread_id)
        # LRANGE 0 -1 returns the full list.
        result = self._client.lrange(key, 0, -1)
        return result if result else []

    def buffer_get_slot(self, thread_id: str, index: int) -> Optional[bytes]:
        """Get a specific slot by index, or None if out of range / missing."""
        key = self._buffer_key(thread_id)
        return self._client.lindex(key, index)

    def buffer_thread_len(self, thread_id: str) -> int:
        """Get number of slots in a thread (0 if the key does not exist)."""
        key = self._buffer_key(thread_id)
        return self._client.llen(key)

    def buffer_thread_exists(self, thread_id: str) -> bool:
        """Check if a thread has any slots."""
        key = self._buffer_key(thread_id)
        return self._client.exists(key) > 0

    def buffer_delete_thread(self, thread_id: str) -> bool:
        """Delete all slots for a thread. Returns True if anything was removed."""
        key = self._buffer_key(thread_id)
        deleted = self._client.delete(key)
        return deleted > 0

    def buffer_list_threads(self) -> List[str]:
        """
        List all thread IDs with slots.

        Returns:
            Bare thread IDs (key prefix stripped), not raw Redis keys.

        NOTE: KEYS is O(N) and blocks the Redis server; acceptable for
        admin/introspection use, but prefer SCAN if this lands on a hot path.
        """
        pattern = f"{self.prefix}buffer:*"
        keys = self._client.keys(pattern)
        prefix_len = len(f"{self.prefix}buffer:")
        threads: List[str] = []
        for k in keys:
            if isinstance(k, bytes):
                k = k.decode()
            # BUGFIX: strip the "{prefix}buffer:" prefix so callers get
            # thread IDs, not raw Redis keys.
            threads.append(k[prefix_len:])
        return threads

    def buffer_clear(self) -> None:
        """Clear all buffer data under this backend's prefix."""
        pattern = f"{self.prefix}buffer:*"
        keys = self._client.keys(pattern)
        if keys:
            self._client.delete(*keys)

    # =========================================================================
    # Thread Registry Operations
    # =========================================================================

    def registry_set(self, chain: str, uuid: str) -> None:
        """
        Set bidirectional mapping: chain ↔ uuid.

        NOTE: if *chain* was previously mapped to a different uuid, the old
        uuid→chain key is not removed here; it is left to expire via TTL.
        """
        chain_key = self._chain_key(chain)
        uuid_key = self._uuid_key(uuid)
        # Pipeline so both directions (and their TTLs) are sent as one batch.
        pipe = self._client.pipeline()
        pipe.set(chain_key, uuid.encode())
        pipe.set(uuid_key, chain.encode())
        if self.ttl > 0:
            pipe.expire(chain_key, self.ttl)
            pipe.expire(uuid_key, self.ttl)
        pipe.execute()

    def registry_get_uuid(self, chain: str) -> Optional[str]:
        """Get UUID for a chain, or None if unmapped."""
        key = self._chain_key(chain)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_get_chain(self, uuid: str) -> Optional[str]:
        """Get chain for a UUID, or None if unmapped."""
        key = self._uuid_key(uuid)
        value = self._client.get(key)
        if value:
            return value.decode() if isinstance(value, bytes) else value
        return None

    def registry_delete(self, uuid: str) -> bool:
        """Delete mapping by UUID. Returns True if the mapping existed."""
        uuid_key = self._uuid_key(uuid)
        # Look up the chain first so we can delete both directions.
        chain = self._client.get(uuid_key)
        if not chain:
            return False
        if isinstance(chain, bytes):
            chain = chain.decode()
        chain_key = self._chain_key(chain)
        deleted = self._client.delete(uuid_key, chain_key)
        return deleted > 0

    def registry_list_all(self) -> Dict[str, str]:
        """
        Get all UUID → chain mappings.

        NOTE: uses KEYS plus one GET per entry; fine for introspection,
        but O(N) round trips for large registries.
        """
        pattern = f"{self.prefix}uuid:*"
        keys = self._client.keys(pattern)
        result: Dict[str, str] = {}
        prefix_len = len(f"{self.prefix}uuid:")
        for key in keys:
            if isinstance(key, bytes):
                key = key.decode()
            uuid = key[prefix_len:]
            chain = self._client.get(key)
            # Key may have expired between KEYS and GET; skip silently.
            if chain:
                if isinstance(chain, bytes):
                    chain = chain.decode()
                result[uuid] = chain
        return result

    def registry_clear(self) -> None:
        """Clear all registry data (both chain→uuid and uuid→chain keys)."""
        chain_pattern = f"{self.prefix}chain:*"
        uuid_pattern = f"{self.prefix}uuid:*"
        chain_keys = self._client.keys(chain_pattern)
        uuid_keys = self._client.keys(uuid_pattern)
        all_keys = list(chain_keys) + list(uuid_keys)
        if all_keys:
            self._client.delete(*all_keys)

    # =========================================================================
    # Lifecycle
    # =========================================================================

    def close(self) -> None:
        """Close Redis connection."""
        if self._client:
            self._client.close()
            logger.info("Redis backend closed")

    # =========================================================================
    # Health Check
    # =========================================================================

    def ping(self) -> bool:
        """Check if Redis is reachable (never raises)."""
        try:
            return self._client.ping()
        except Exception:
            # Any transport/protocol failure means "not reachable".
            return False

    def info(self) -> Dict[str, int]:
        """Get backend statistics (key counts under this prefix, plus TTL)."""
        buffer_keys = len(self._client.keys(f"{self.prefix}buffer:*"))
        uuid_keys = len(self._client.keys(f"{self.prefix}uuid:*"))
        return {
            "buffer_threads": buffer_keys,
            "registry_entries": uuid_keys,
            "ttl_seconds": self.ttl,
        }