Add context buffer - virtual memory manager for AI agents
Implements thread-scoped, append-only storage for validated payloads. Handlers receive immutable references; messages cannot be modified after insertion. Core components: - BufferSlot: frozen dataclass holding payload + metadata - ThreadContext: append-only buffer per thread - ContextBuffer: global manager with GC and limits Design parallels OS virtual memory: - Thread ID = virtual address space - Buffer slot = memory page - Immutable reference = read-only mapping - Thread isolation = process isolation Integration: - Incoming messages appended after pipeline validation - Outgoing responses appended before serialization - Full audit trail preserved This is incremental - handlers still receive copies for backward compatibility. Next step: skip serialization for internal routing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a5e2ab22da
commit
fc7170a02e
4 changed files with 784 additions and 0 deletions
27
agentserver/memory/__init__.py
Normal file
27
agentserver/memory/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
"""
|
||||||
|
memory — Virtual memory management for AI agents.
|
||||||
|
|
||||||
|
Provides thread-scoped, append-only context buffers with:
|
||||||
|
- Immutable slots (handlers can't modify messages)
|
||||||
|
- Thread isolation (handlers only see their context)
|
||||||
|
- Complete audit trail (all messages preserved)
|
||||||
|
- GC and limits (prevent runaway memory usage)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from agentserver.memory.context_buffer import (
|
||||||
|
ContextBuffer,
|
||||||
|
ThreadContext,
|
||||||
|
BufferSlot,
|
||||||
|
SlotMetadata,
|
||||||
|
get_context_buffer,
|
||||||
|
slot_to_handler_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ContextBuffer",
|
||||||
|
"ThreadContext",
|
||||||
|
"BufferSlot",
|
||||||
|
"SlotMetadata",
|
||||||
|
"get_context_buffer",
|
||||||
|
"slot_to_handler_metadata",
|
||||||
|
]
|
||||||
299
agentserver/memory/context_buffer.py
Normal file
299
agentserver/memory/context_buffer.py
Normal file
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""
|
||||||
|
context_buffer.py — Virtual memory manager for AI agents.
|
||||||
|
|
||||||
|
Provides thread-scoped, append-only storage for validated message payloads.
|
||||||
|
Handlers are meant to receive immutable references to buffer slots rather than
copies (for now the pump still passes handlers copies for backward compatibility).
|
||||||
|
|
||||||
|
Design principles:
|
||||||
|
- Append-only: Messages cannot be modified after insertion
|
||||||
|
- Thread isolation: Handlers only see their thread's context
|
||||||
|
- Immutable references: Handlers get read-only views
|
||||||
|
- Complete audit trail: All messages preserved in order
|
||||||
|
|
||||||
|
Analogous to OS virtual memory:
|
||||||
|
- Thread ID = virtual address space
|
||||||
|
- Buffer slot = memory page
|
||||||
|
- Thread registry = page table
|
||||||
|
- Immutable reference = read-only mapping
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
buffer = get_context_buffer()
|
||||||
|
|
||||||
|
# Append validated payload (returns slot reference)
|
||||||
|
ref = buffer.append(thread_id, payload, metadata)
|
||||||
|
|
||||||
|
# Handler receives reference
|
||||||
|
handler(ref.payload, ref.metadata)
|
||||||
|
|
||||||
|
# Get thread history
|
||||||
|
history = buffer.get_thread(thread_id)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Dict, List, Optional, Iterator
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SlotMetadata:
    """
    Read-only bookkeeping record attached to one buffer slot.

    The first six fields are required and describe where the slot lives and
    what it carries.  The remaining four have defaults and hold the subset of
    information that is exposed to handlers.
    """

    # --- required: identity and routing of the slot -------------------------
    thread_id: str       # id of the thread the slot belongs to
    from_id: str         # sender of the payload
    to_id: str           # recipient of the payload
    slot_index: int      # position of the slot inside its thread
    timestamp: str       # creation time as text (ThreadContext.append stores ISO-8601 UTC)
    payload_type: str    # class name of the payload (ThreadContext.append uses type(payload).__name__)

    # --- optional: handler-facing subset ------------------------------------
    own_name: Optional[str] = None
    is_self_call: bool = False
    usage_instructions: str = ""
    todo_nudge: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class BufferSlot:
    """
    One entry of a thread's context buffer: a payload plus its metadata.

    Because the dataclass is frozen, ``payload`` and ``metadata`` cannot be
    re-assigned once the slot exists.  The freeze is shallow: the slot keeps
    a reference to the payload object, not a copy, so it does not by itself
    stop a mutable payload from being changed through that reference.

    The properties are read-only shortcuts into ``metadata``.
    """

    payload: Any             # the validated payload object, stored by reference
    metadata: SlotMetadata   # frozen bookkeeping record for this slot

    @property
    def thread_id(self) -> str:
        """Id of the owning thread, taken from the metadata."""
        meta = self.metadata
        return meta.thread_id

    @property
    def from_id(self) -> str:
        """Sender id, taken from the metadata."""
        meta = self.metadata
        return meta.from_id

    @property
    def to_id(self) -> str:
        """Recipient id, taken from the metadata."""
        meta = self.metadata
        return meta.to_id

    @property
    def index(self) -> int:
        """Position of this slot in its thread (``metadata.slot_index``)."""
        meta = self.metadata
        return meta.slot_index
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadContext:
    """
    Per-thread, append-only sequence of buffer slots.

    Slots are only ever added at the end (see ``append``); the class offers
    no operation that removes, reorders or replaces one.  Every read and
    write of the underlying list happens while holding the instance lock.
    """

    def __init__(self, thread_id: str):
        self.thread_id = thread_id
        self._slots: List[BufferSlot] = []
        # Guards every access to ``_slots``.
        self._lock = threading.Lock()
        # UTC creation time; ContextBuffer reads it to find the oldest thread.
        self._created_at = datetime.now(timezone.utc)

    def append(
        self,
        payload: Any,
        from_id: str,
        to_id: str,
        own_name: Optional[str] = None,
        is_self_call: bool = False,
        usage_instructions: str = "",
        todo_nudge: str = "",
    ) -> BufferSlot:
        """
        Wrap ``payload`` in a new slot and add it at the end of this thread.

        The slot's index is the number of slots already stored, its timestamp
        is the current UTC time in ISO format, and its payload type is the
        payload's class name.  The freshly created slot is returned.
        """
        with self._lock:
            new_slot = BufferSlot(
                payload=payload,
                metadata=SlotMetadata(
                    thread_id=self.thread_id,
                    from_id=from_id,
                    to_id=to_id,
                    slot_index=len(self._slots),
                    timestamp=datetime.now(timezone.utc).isoformat(),
                    payload_type=type(payload).__name__,
                    own_name=own_name,
                    is_self_call=is_self_call,
                    usage_instructions=usage_instructions,
                    todo_nudge=todo_nudge,
                ),
            )
            self._slots.append(new_slot)
        return new_slot

    def __len__(self) -> int:
        with self._lock:
            count = len(self._slots)
        return count

    def __getitem__(self, index: int) -> BufferSlot:
        with self._lock:
            found = self._slots[index]
        return found

    def __iter__(self) -> Iterator[BufferSlot]:
        # Iterate over a snapshot so later appends cannot disturb the caller.
        with self._lock:
            snapshot = self._slots[:]
        return iter(snapshot)

    def get_slice(self, start: int = 0, end: Optional[int] = None) -> List[BufferSlot]:
        """Return a new list with the slots in ``[start:end]`` (paging/windowing)."""
        with self._lock:
            window = self._slots[start:end]
        return window

    def get_by_type(self, payload_type: str) -> List[BufferSlot]:
        """Return, in order, the slots whose recorded payload type equals ``payload_type``."""
        matches: List[BufferSlot] = []
        with self._lock:
            for candidate in self._slots:
                if candidate.metadata.payload_type == payload_type:
                    matches.append(candidate)
        return matches

    def get_from(self, from_id: str) -> List[BufferSlot]:
        """Return, in order, the slots whose sender equals ``from_id``."""
        matches: List[BufferSlot] = []
        with self._lock:
            for candidate in self._slots:
                if candidate.from_id == from_id:
                    matches.append(candidate)
        return matches
|
||||||
|
|
||||||
|
|
||||||
|
class ContextBuffer:
    """
    Global context buffer managing all thread contexts.

    Holds a mapping of thread id -> ThreadContext and enforces two limits:

    - ``max_threads``: when a new thread has to be created and the limit is
      reached, the oldest-created threads are evicted to make room.
    - ``max_slots_per_thread``: ``append`` raises MemoryError instead of
      growing a thread beyond this size.

    Every access to the mapping is guarded by one buffer-level lock.
    ``append`` and ``get_stats`` take that lock first and a ThreadContext's
    own lock second; ThreadContext never takes the buffer lock, so the order
    is never reversed.

    Normally used as a singleton via get_context_buffer().
    """

    def __init__(self):
        self._threads: Dict[str, ThreadContext] = {}
        self._lock = threading.Lock()

        # Limits (can be configured by assigning to the attributes)
        self.max_slots_per_thread: int = 10000
        self.max_threads: int = 1000

    def _get_or_create_locked(self, thread_id: str) -> ThreadContext:
        """Return the context for ``thread_id``, creating it if needed.

        The caller must already hold ``self._lock``.
        """
        thread = self._threads.get(thread_id)
        if thread is not None:
            return thread

        # GC: drop oldest-created threads until there is room for one more.
        # The emptiness guard keeps min() from failing when max_threads <= 0,
        # and looping (rather than evicting once) brings the buffer back
        # under a limit that was lowered after threads were created.
        while self._threads and len(self._threads) >= self.max_threads:
            oldest_id = min(
                self._threads,
                key=lambda tid: self._threads[tid]._created_at,
            )
            del self._threads[oldest_id]

        thread = ThreadContext(thread_id)
        self._threads[thread_id] = thread
        return thread

    def get_or_create_thread(self, thread_id: str) -> ThreadContext:
        """Get existing thread context or create new one.

        Creating a thread may evict the oldest-created thread(s) so that the
        number of threads stays within ``max_threads``.
        """
        with self._lock:
            return self._get_or_create_locked(thread_id)

    def append(
        self,
        thread_id: str,
        payload: Any,
        from_id: str,
        to_id: str,
        own_name: Optional[str] = None,
        is_self_call: bool = False,
        usage_instructions: str = "",
        todo_nudge: str = "",
    ) -> BufferSlot:
        """
        Append a validated payload to a thread's context.

        This is the main entry point for the pipeline.  The thread is created
        on first use.  Lookup, limit check and append all happen under the
        buffer lock, so concurrent callers cannot push a thread past
        ``max_slots_per_thread`` and the thread cannot be evicted between the
        lookup and the append.

        Returns:
            The newly created slot.

        Raises:
            MemoryError: the thread already holds ``max_slots_per_thread``
                slots; nothing is appended in that case.
        """
        with self._lock:
            thread = self._get_or_create_locked(thread_id)

            # Enforce slot limit
            if len(thread) >= self.max_slots_per_thread:
                raise MemoryError(
                    f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})"
                )

            return thread.append(
                payload=payload,
                from_id=from_id,
                to_id=to_id,
                own_name=own_name,
                is_self_call=is_self_call,
                usage_instructions=usage_instructions,
                todo_nudge=todo_nudge,
            )

    def get_thread(self, thread_id: str) -> Optional[ThreadContext]:
        """Get a thread's context (None if not found)."""
        with self._lock:
            return self._threads.get(thread_id)

    def thread_exists(self, thread_id: str) -> bool:
        """Check if a thread exists."""
        with self._lock:
            return thread_id in self._threads

    def delete_thread(self, thread_id: str) -> bool:
        """Delete a thread's context (GC).

        Returns True if the thread existed and was removed, False otherwise.
        """
        with self._lock:
            if thread_id in self._threads:
                del self._threads[thread_id]
                return True
            return False

    def get_stats(self) -> Dict[str, Any]:
        """Get buffer statistics.

        Returns a dict with the current thread count, the total number of
        slots across all threads, and the two configured limits.
        """
        with self._lock:
            total_slots = sum(len(t) for t in self._threads.values())
            return {
                "thread_count": len(self._threads),
                "total_slots": total_slots,
                "max_threads": self.max_threads,
                "max_slots_per_thread": self.max_slots_per_thread,
            }

    def clear(self):
        """Clear all contexts (for testing)."""
        with self._lock:
            self._threads.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Singleton
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
_buffer: Optional[ContextBuffer] = None
|
||||||
|
_buffer_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_context_buffer() -> ContextBuffer:
    """
    Return the process-wide ContextBuffer, creating it on first use.

    The module-level ``_buffer`` is checked once without the lock so that
    calls after initialisation return immediately; if it is still unset, it
    is checked again under ``_buffer_lock`` before an instance is created.
    """
    global _buffer

    existing = _buffer
    if existing is not None:
        return existing

    with _buffer_lock:
        # Another caller may have created the instance while we waited.
        if _buffer is None:
            _buffer = ContextBuffer()
        return _buffer
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Handler-facing metadata adapter
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
def slot_to_handler_metadata(slot: BufferSlot) -> 'HandlerMetadata':
    """
    Build a HandlerMetadata from the metadata stored on ``slot``.

    Exists for backward compatibility: handlers keep receiving a
    HandlerMetadata, now populated from the buffer slot.  Only thread_id,
    from_id, own_name, is_self_call, usage_instructions and todo_nudge are
    copied; the slot's other metadata fields are not passed on.
    """
    # Imported here rather than at module top, as in the original.
    from agentserver.message_bus.message_state import HandlerMetadata

    source = slot.metadata
    handler_fields = {
        "thread_id": source.thread_id,
        "from_id": source.from_id,
        "own_name": source.own_name,
        "is_self_call": source.is_self_call,
        "usage_instructions": source.usage_instructions,
        "todo_nudge": source.todo_nudge,
    }
    return HandlerMetadata(**handler_fields)
|
||||||
|
|
@ -32,6 +32,7 @@ from agentserver.message_bus.steps.thread_assignment import thread_assignment_st
|
||||||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR
|
from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR
|
||||||
from agentserver.message_bus.thread_registry import get_registry
|
from agentserver.message_bus.thread_registry import get_registry
|
||||||
from agentserver.message_bus.todo_registry import get_todo_registry
|
from agentserver.message_bus.todo_registry import get_todo_registry
|
||||||
|
from agentserver.memory import get_context_buffer
|
||||||
|
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|
@ -348,6 +349,7 @@ class StreamPump:
|
||||||
# Ensure we have a valid thread chain
|
# Ensure we have a valid thread chain
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
todo_registry = get_todo_registry()
|
todo_registry = get_todo_registry()
|
||||||
|
context_buffer = get_context_buffer()
|
||||||
current_thread = state.thread_id or ""
|
current_thread = state.thread_id or ""
|
||||||
|
|
||||||
# Check if thread exists in registry; if not, register it
|
# Check if thread exists in registry; if not, register it
|
||||||
|
|
@ -377,6 +379,27 @@ class StreamPump:
|
||||||
raised = todo_registry.get_raised_for(current_thread, listener.name)
|
raised = todo_registry.get_raised_for(current_thread, listener.name)
|
||||||
todo_nudge = todo_registry.format_nudge(raised)
|
todo_nudge = todo_registry.format_nudge(raised)
|
||||||
|
|
||||||
|
# === CONTEXT BUFFER: Record incoming message ===
|
||||||
|
# Append validated payload to thread's context buffer
|
||||||
|
if current_thread and state.payload:
|
||||||
|
try:
|
||||||
|
context_buffer.append(
|
||||||
|
thread_id=current_thread,
|
||||||
|
payload=state.payload,
|
||||||
|
from_id=state.from_id or "unknown",
|
||||||
|
to_id=listener.name,
|
||||||
|
own_name=listener.name if listener.is_agent else None,
|
||||||
|
is_self_call=is_self_call,
|
||||||
|
usage_instructions=listener.usage_instructions,
|
||||||
|
todo_nudge=todo_nudge,
|
||||||
|
)
|
||||||
|
except MemoryError:
|
||||||
|
# Thread exceeded max slots - log and continue
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
f"Thread {current_thread[:8]}... exceeded context buffer limit"
|
||||||
|
)
|
||||||
|
|
||||||
metadata = HandlerMetadata(
|
metadata = HandlerMetadata(
|
||||||
thread_id=current_thread,
|
thread_id=current_thread,
|
||||||
from_id=state.from_id or "",
|
from_id=state.from_id or "",
|
||||||
|
|
@ -435,6 +458,22 @@ class StreamPump:
|
||||||
to_id = requested_to
|
to_id = requested_to
|
||||||
thread_id = registry.extend_chain(current_thread, to_id)
|
thread_id = registry.extend_chain(current_thread, to_id)
|
||||||
|
|
||||||
|
# === CONTEXT BUFFER: Record outgoing response ===
|
||||||
|
# Append handler's response to the target thread's buffer
|
||||||
|
# This happens BEFORE serialization - the buffer holds the clean payload
|
||||||
|
try:
|
||||||
|
context_buffer.append(
|
||||||
|
thread_id=thread_id,
|
||||||
|
payload=response.payload,
|
||||||
|
from_id=listener.name,
|
||||||
|
to_id=to_id,
|
||||||
|
)
|
||||||
|
except MemoryError:
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
f"Thread {thread_id[:8]}... exceeded context buffer limit"
|
||||||
|
)
|
||||||
|
|
||||||
response_bytes = self._wrap_in_envelope(
|
response_bytes = self._wrap_in_envelope(
|
||||||
payload=response.payload,
|
payload=response.payload,
|
||||||
from_id=listener.name,
|
from_id=listener.name,
|
||||||
|
|
|
||||||
419
tests/test_context_buffer.py
Normal file
419
tests/test_context_buffer.py
Normal file
|
|
@ -0,0 +1,419 @@
|
||||||
|
"""
|
||||||
|
test_context_buffer.py — Tests for the AI agent virtual memory manager.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
1. Append-only semantics
|
||||||
|
2. Immutability guarantees
|
||||||
|
3. Thread isolation
|
||||||
|
4. Slot references
|
||||||
|
5. GC and limits
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, FrozenInstanceError
|
||||||
|
|
||||||
|
from agentserver.memory.context_buffer import (
|
||||||
|
ContextBuffer,
|
||||||
|
ThreadContext,
|
||||||
|
BufferSlot,
|
||||||
|
SlotMetadata,
|
||||||
|
get_context_buffer,
|
||||||
|
slot_to_handler_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test payload classes
|
||||||
|
@dataclass
class TestPayload:
    """Mutable sample payload used as test data (not a test case itself)."""

    # The class name matches pytest's ``Test*`` collection pattern, but the
    # dataclass-generated __init__ makes it uncollectable and pytest warns
    # about it.  Opting out explicitly keeps the test run free of that
    # warning.  Without an annotation this is a plain class attribute, not a
    # dataclass field.
    __test__ = False

    message: str
    value: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class FrozenPayload:
    """Immutable sample payload with a single text field, used as test data."""

    message: str  # the only field; required
|
||||||
|
|
||||||
|
|
||||||
|
class TestBufferSlotImmutability:
    """Slots and their metadata must reject attribute assignment."""

    @staticmethod
    def _sample_metadata():
        """Build the SlotMetadata value shared by the tests in this class."""
        return SlotMetadata(
            thread_id="t1",
            from_id="sender",
            to_id="receiver",
            slot_index=0,
            timestamp="2024-01-01T00:00:00Z",
            payload_type="TestPayload",
        )

    def test_slot_is_frozen(self):
        """Re-binding a field of a BufferSlot raises FrozenInstanceError."""
        sample_meta = self._sample_metadata()
        frozen_slot = BufferSlot(
            payload=TestPayload(message="hello"),
            metadata=sample_meta,
        )

        # Assigning to any slot attribute must be refused.
        with pytest.raises(FrozenInstanceError):
            frozen_slot.metadata = None

    def test_slot_metadata_is_frozen(self):
        """Re-binding a field of SlotMetadata raises FrozenInstanceError."""
        sample_meta = self._sample_metadata()

        with pytest.raises(FrozenInstanceError):
            sample_meta.thread_id = "modified"

    def test_payload_reference_preserved(self):
        """The slot stores the very payload object it was given."""
        original = TestPayload(message="original")
        holder = BufferSlot(payload=original, metadata=self._sample_metadata())

        # Identity, not just equality.
        assert holder.payload is original
|
||||||
|
|
||||||
|
|
||||||
|
class TestThreadContext:
    """Behaviour of a single thread's append-only context."""

    def test_append_creates_slot(self):
        """append() hands back a populated BufferSlot."""
        context = ThreadContext("thread-1")

        created = context.append(
            payload=TestPayload(message="test"),
            from_id="sender",
            to_id="receiver",
        )

        assert isinstance(created, BufferSlot)
        assert created.payload.message == "test"
        assert created.from_id == "sender"
        assert created.to_id == "receiver"
        assert created.index == 0

    def test_append_increments_index(self):
        """Consecutive appends are numbered 0, 1, 2, ..."""
        context = ThreadContext("thread-1")

        first = context.append(TestPayload("a"), "s", "r")
        second = context.append(TestPayload("b"), "s", "r")
        third = context.append(TestPayload("c"), "s", "r")

        assert first.index == 0
        assert second.index == 1
        assert third.index == 2
        assert len(context) == 3

    def test_getitem_returns_slot(self):
        """Slots are reachable through integer indexing."""
        context = ThreadContext("thread-1")

        for text in ("first", "second"):
            context.append(TestPayload(text), "s", "r")

        assert context[0].payload.message == "first"
        assert context[1].payload.message == "second"

    def test_iteration(self):
        """Iterating yields the slots in insertion order."""
        context = ThreadContext("thread-1")

        for text in ("a", "b", "c"):
            context.append(TestPayload(text), "s", "r")

        seen = []
        for entry in context:
            seen.append(entry.payload.message)
        assert seen == ["a", "b", "c"]

    def test_get_by_type(self):
        """get_by_type() filters on the payload's class name."""
        context = ThreadContext("thread-1")

        context.append(TestPayload("test"), "s", "r")
        context.append(FrozenPayload("frozen"), "s", "r")
        context.append(TestPayload("test2"), "s", "r")

        plain = context.get_by_type("TestPayload")
        assert len(plain) == 2

        frozen = context.get_by_type("FrozenPayload")
        assert len(frozen) == 1

    def test_get_from(self):
        """get_from() filters on the sender id."""
        context = ThreadContext("thread-1")

        for text, sender in (("a", "alice"), ("b", "bob"), ("c", "alice")):
            context.append(TestPayload(text), sender, "r")

        from_alice = context.get_from("alice")
        assert len(from_alice) == 2

        from_bob = context.get_from("bob")
        assert len(from_bob) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextBuffer:
    """Behaviour of the buffer that manages all thread contexts."""

    def test_get_or_create_thread(self):
        """The same id maps to the same context; a new id gets a new one."""
        store = ContextBuffer()

        first = store.get_or_create_thread("thread-1")
        again = store.get_or_create_thread("thread-1")
        other = store.get_or_create_thread("thread-2")

        assert first is again  # Same thread
        assert first is not other  # Different thread

    def test_append_to_thread(self):
        """append() routes each payload to the thread that was named."""
        store = ContextBuffer()

        for tid, text in (("t1", "a"), ("t2", "b"), ("t1", "c")):
            store.append(tid, TestPayload(text), "s", "r")

        ctx_one = store.get_thread("t1")
        ctx_two = store.get_thread("t2")

        assert len(ctx_one) == 2
        assert len(ctx_two) == 1
        assert ctx_one[0].payload.message == "a"
        assert ctx_one[1].payload.message == "c"
        assert ctx_two[0].payload.message == "b"

    def test_thread_isolation(self):
        """A thread's context contains only that thread's slots."""
        store = ContextBuffer()

        store.append("thread-a", TestPayload("secret-a"), "s", "r")
        store.append("thread-b", TestPayload("secret-b"), "s", "r")

        ctx_a = store.get_thread("thread-a")
        ctx_b = store.get_thread("thread-b")

        # Collect what each context exposes.
        seen_in_a = [entry.payload.message for entry in ctx_a]
        seen_in_b = [entry.payload.message for entry in ctx_b]

        assert seen_in_a == ["secret-a"]
        assert seen_in_b == ["secret-b"]

    def test_delete_thread(self):
        """delete_thread() drops the context and reports success."""
        store = ContextBuffer()

        store.append("t1", TestPayload("test"), "s", "r")
        assert store.thread_exists("t1")

        removed = store.delete_thread("t1")
        assert removed is True
        assert not store.thread_exists("t1")

    def test_max_slots_limit(self):
        """Appending past max_slots_per_thread raises MemoryError."""
        store = ContextBuffer()
        store.max_slots_per_thread = 3

        for text in ("1", "2", "3"):
            store.append("t1", TestPayload(text), "s", "r")

        with pytest.raises(MemoryError) as exc_info:
            store.append("t1", TestPayload("4"), "s", "r")

        assert "exceeded max slots" in str(exc_info.value)

    def test_max_threads_gc(self):
        """Creating a thread beyond max_threads evicts the oldest one."""
        store = ContextBuffer()
        store.max_threads = 2

        store.append("t1", TestPayload("first"), "s", "r")
        store.append("t2", TestPayload("second"), "s", "r")

        # A third thread does not fit, so the oldest has to go.
        store.append("t3", TestPayload("third"), "s", "r")

        snapshot = store.get_stats()
        assert snapshot["thread_count"] == 2

        # t1 was created first, so it is the one that was evicted.
        assert not store.thread_exists("t1")
        assert store.thread_exists("t2")
        assert store.thread_exists("t3")

    def test_get_stats(self):
        """get_stats() reports thread and slot totals."""
        store = ContextBuffer()

        for tid, text in (("t1", "a"), ("t1", "b"), ("t2", "c")):
            store.append(tid, TestPayload(text), "s", "r")

        snapshot = store.get_stats()

        assert snapshot["thread_count"] == 2
        assert snapshot["total_slots"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
class TestHandlerMetadataAdapter:
    """Conversion of a slot's metadata into HandlerMetadata."""

    def test_slot_to_handler_metadata(self):
        """Every handler-facing field is carried over unchanged."""
        store = ContextBuffer()

        recorded = store.append(
            thread_id="t1",
            payload=TestPayload("test"),
            from_id="sender",
            to_id="receiver",
            own_name="test-agent",
            is_self_call=True,
            usage_instructions="instructions here",
            todo_nudge="nudge here",
        )

        converted = slot_to_handler_metadata(recorded)

        assert converted.thread_id == "t1"
        assert converted.from_id == "sender"
        assert converted.own_name == "test-agent"
        assert converted.is_self_call is True
        assert converted.usage_instructions == "instructions here"
        assert converted.todo_nudge == "nudge here"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingleton:
    """Behaviour of the shared buffer returned by get_context_buffer()."""

    def test_get_context_buffer_singleton(self):
        """Two calls yield the very same object."""
        first_call = get_context_buffer()
        second_call = get_context_buffer()

        assert first_call is second_call

    def test_clear_resets_state(self):
        """After clear() no thread is left in the buffer."""
        shared = get_context_buffer()
        shared.append("test-thread", TestPayload("test"), "s", "r")

        shared.clear()

        assert not shared.thread_exists("test-thread")
        remaining = shared.get_stats()["thread_count"]
        assert remaining == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestPumpIntegration:
    """Test context buffer integration with StreamPump."""

    @pytest.mark.asyncio
    async def test_buffer_records_messages_during_flow(self):
        """Context buffer should record messages as they flow through pump.

        Injects one Greeting envelope into a pump with two listeners, lets the
        pipeline run for at most three items or three seconds, then checks
        that the shared context buffer holds a Greeting slot for the injected
        thread id.
        """
        # Project imports are kept local to this test.
        from unittest.mock import AsyncMock, patch
        from agentserver.message_bus.stream_pump import StreamPump, ListenerConfig, OrganismConfig
        from agentserver.message_bus.message_state import HandlerResponse
        from agentserver.llm.backend import LLMResponse

        # Import handlers
        from handlers.hello import Greeting, GreetingResponse, handle_greeting, handle_shout
        from handlers.console import ShoutedResponse

        # Clear buffer (it is the shared singleton, so earlier tests may have
        # left threads in it)
        buffer = get_context_buffer()
        buffer.clear()

        # Create pump with greeter and shouter
        config = OrganismConfig(name="buffer-test")
        pump = StreamPump(config)

        pump.register_listener(ListenerConfig(
            name="greeter",
            payload_class_path="handlers.hello.Greeting",
            handler_path="handlers.hello.handle_greeting",
            description="Greeting agent",
            is_agent=True,
            peers=["shouter"],
            payload_class=Greeting,
            handler=handle_greeting,
        ))

        pump.register_listener(ListenerConfig(
            name="shouter",
            payload_class_path="handlers.hello.GreetingResponse",
            handler_path="handlers.hello.handle_shout",
            description="Shouts",
            payload_class=GreetingResponse,
            handler=handle_shout,
        ))

        # Mock LLM: canned response returned by the patched completion call
        mock_llm = LLMResponse(
            content="Hello there!",
            model="mock",
            usage={"total_tokens": 5},
            finish_reason="stop",
        )

        # Prevent re-injection loops by replacing the pump's re-injection hook
        # with a coroutine that does nothing
        async def noop_reinject(state):
            pass
        pump._reinject_responses = noop_reinject

        # NOTE(review): indentation was lost in this view; the extent of the
        # patch() block below is reconstructed — confirm against the repo.
        with patch('agentserver.llm.complete', new=AsyncMock(return_value=mock_llm)):
            # Create envelope for Greeting, addressed from "user" to "greeter"
            # on a fresh random thread id
            thread_id = str(uuid.uuid4())
            envelope = f"""<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
<meta><from>user</from><to>greeter</to><thread>{thread_id}</thread></meta>
<Greeting xmlns=""><Name>Alice</Name></Greeting>
</message>""".encode()

            await pump.inject(envelope, thread_id, from_id="user")

            # Run pump to process
            pump._running = True
            pipeline = pump.build_pipeline(pump._queue_source())

            async def run_chain():
                # Consume at most three pipeline items, then stop.
                async with pipeline.stream() as streamer:
                    count = 0
                    async for _ in streamer:
                        count += 1
                        if count >= 3:
                            break

            import asyncio
            try:
                # A timeout is tolerated: the assertions below only need the
                # injected message to have been recorded.
                await asyncio.wait_for(run_chain(), timeout=3.0)
            except asyncio.TimeoutError:
                pass
            finally:
                pump._running = False

        # Verify buffer recorded the messages
        thread_ctx = buffer.get_thread(thread_id)
        assert thread_ctx is not None, "Thread should exist in buffer"
        assert len(thread_ctx) >= 1, "Buffer should have at least one message"

        # Check that we recorded a Greeting
        greeting_slots = thread_ctx.get_by_type("Greeting")
        assert len(greeting_slots) >= 1, "Should have recorded Greeting"
        assert greeting_slots[0].payload.name == "Alice"
|
||||||
Loading…
Reference in a new issue