Add context buffer - virtual memory manager for AI agents
Implements thread-scoped, append-only storage for validated payloads. Handlers receive immutable references; messages cannot be modified after insertion. Core components: - BufferSlot: frozen dataclass holding payload + metadata - ThreadContext: append-only buffer per thread - ContextBuffer: global manager with GC and limits Design parallels OS virtual memory: - Thread ID = virtual address space - Buffer slot = memory page - Immutable reference = read-only mapping - Thread isolation = process isolation Integration: - Incoming messages appended after pipeline validation - Outgoing responses appended before serialization - Full audit trail preserved This is incremental - handlers still receive copies for backward compatibility. Next step: skip serialization for internal routing. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
a5e2ab22da
commit
fc7170a02e
4 changed files with 784 additions and 0 deletions
27
agentserver/memory/__init__.py
Normal file
27
agentserver/memory/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""
|
||||
memory — Virtual memory management for AI agents.
|
||||
|
||||
Provides thread-scoped, append-only context buffers with:
|
||||
- Immutable slots (handlers can't modify messages)
|
||||
- Thread isolation (handlers only see their context)
|
||||
- Complete audit trail (all messages preserved)
|
||||
- GC and limits (prevent runaway memory usage)
|
||||
"""
|
||||
|
||||
from agentserver.memory.context_buffer import (
|
||||
ContextBuffer,
|
||||
ThreadContext,
|
||||
BufferSlot,
|
||||
SlotMetadata,
|
||||
get_context_buffer,
|
||||
slot_to_handler_metadata,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ContextBuffer",
|
||||
"ThreadContext",
|
||||
"BufferSlot",
|
||||
"SlotMetadata",
|
||||
"get_context_buffer",
|
||||
"slot_to_handler_metadata",
|
||||
]
|
||||
299
agentserver/memory/context_buffer.py
Normal file
299
agentserver/memory/context_buffer.py
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
"""
|
||||
context_buffer.py — Virtual memory manager for AI agents.
|
||||
|
||||
Provides thread-scoped, append-only storage for validated message payloads.
|
||||
Handlers receive immutable references to buffer slots, never copies.
|
||||
|
||||
Design principles:
|
||||
- Append-only: Messages cannot be modified after insertion
|
||||
- Thread isolation: Handlers only see their thread's context
|
||||
- Immutable references: Handlers get read-only views
|
||||
- Complete audit trail: All messages preserved in order
|
||||
|
||||
Analogous to OS virtual memory:
|
||||
- Thread ID = virtual address space
|
||||
- Buffer slot = memory page
|
||||
- Thread registry = page table
|
||||
- Immutable reference = read-only mapping
|
||||
|
||||
Usage:
|
||||
buffer = get_context_buffer()
|
||||
|
||||
# Append validated payload (returns slot reference)
|
||||
ref = buffer.append(thread_id, payload, metadata)
|
||||
|
||||
# Handler receives reference
|
||||
handler(ref.payload, ref.metadata)
|
||||
|
||||
# Get thread history
|
||||
history = buffer.get_thread(thread_id)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Iterator
|
||||
from datetime import datetime, timezone
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SlotMetadata:
|
||||
"""Immutable metadata for a buffer slot."""
|
||||
thread_id: str
|
||||
from_id: str
|
||||
to_id: str
|
||||
slot_index: int
|
||||
timestamp: str
|
||||
payload_type: str
|
||||
|
||||
# Handler-facing metadata (subset exposed to handlers)
|
||||
own_name: Optional[str] = None
|
||||
is_self_call: bool = False
|
||||
usage_instructions: str = ""
|
||||
todo_nudge: str = ""
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BufferSlot:
|
||||
"""
|
||||
Immutable slot in the context buffer.
|
||||
|
||||
frozen=True ensures the slot cannot be modified after creation.
|
||||
Handlers receive this directly - they cannot mutate it.
|
||||
"""
|
||||
payload: Any # The validated @xmlify dataclass (immutable reference)
|
||||
metadata: SlotMetadata # Immutable slot metadata
|
||||
|
||||
@property
|
||||
def thread_id(self) -> str:
|
||||
return self.metadata.thread_id
|
||||
|
||||
@property
|
||||
def from_id(self) -> str:
|
||||
return self.metadata.from_id
|
||||
|
||||
@property
|
||||
def to_id(self) -> str:
|
||||
return self.metadata.to_id
|
||||
|
||||
@property
|
||||
def index(self) -> int:
|
||||
return self.metadata.slot_index
|
||||
|
||||
|
||||
class ThreadContext:
|
||||
"""
|
||||
Append-only context buffer for a single thread.
|
||||
|
||||
All slots are immutable once appended.
|
||||
"""
|
||||
|
||||
def __init__(self, thread_id: str):
|
||||
self.thread_id = thread_id
|
||||
self._slots: List[BufferSlot] = []
|
||||
self._lock = threading.Lock()
|
||||
self._created_at = datetime.now(timezone.utc)
|
||||
|
||||
def append(
|
||||
self,
|
||||
payload: Any,
|
||||
from_id: str,
|
||||
to_id: str,
|
||||
own_name: Optional[str] = None,
|
||||
is_self_call: bool = False,
|
||||
usage_instructions: str = "",
|
||||
todo_nudge: str = "",
|
||||
) -> BufferSlot:
|
||||
"""
|
||||
Append a validated payload to this thread's context.
|
||||
|
||||
Returns the immutable slot reference.
|
||||
"""
|
||||
with self._lock:
|
||||
slot_index = len(self._slots)
|
||||
|
||||
metadata = SlotMetadata(
|
||||
thread_id=self.thread_id,
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
slot_index=slot_index,
|
||||
timestamp=datetime.now(timezone.utc).isoformat(),
|
||||
payload_type=type(payload).__name__,
|
||||
own_name=own_name,
|
||||
is_self_call=is_self_call,
|
||||
usage_instructions=usage_instructions,
|
||||
todo_nudge=todo_nudge,
|
||||
)
|
||||
|
||||
slot = BufferSlot(payload=payload, metadata=metadata)
|
||||
self._slots.append(slot)
|
||||
|
||||
return slot
|
||||
|
||||
def __len__(self) -> int:
|
||||
with self._lock:
|
||||
return len(self._slots)
|
||||
|
||||
def __getitem__(self, index: int) -> BufferSlot:
|
||||
with self._lock:
|
||||
return self._slots[index]
|
||||
|
||||
def __iter__(self) -> Iterator[BufferSlot]:
|
||||
# Return a copy of the list to avoid mutation during iteration
|
||||
with self._lock:
|
||||
return iter(list(self._slots))
|
||||
|
||||
def get_slice(self, start: int = 0, end: Optional[int] = None) -> List[BufferSlot]:
|
||||
"""Get a slice of the context (for paging/windowing)."""
|
||||
with self._lock:
|
||||
return list(self._slots[start:end])
|
||||
|
||||
def get_by_type(self, payload_type: str) -> List[BufferSlot]:
|
||||
"""Get all slots with a specific payload type."""
|
||||
with self._lock:
|
||||
return [s for s in self._slots if s.metadata.payload_type == payload_type]
|
||||
|
||||
def get_from(self, from_id: str) -> List[BufferSlot]:
|
||||
"""Get all slots from a specific sender."""
|
||||
with self._lock:
|
||||
return [s for s in self._slots if s.from_id == from_id]
|
||||
|
||||
|
||||
class ContextBuffer:
|
||||
"""
|
||||
Global context buffer managing all thread contexts.
|
||||
|
||||
Thread-safe. Singleton pattern via get_context_buffer().
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._threads: Dict[str, ThreadContext] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Limits (can be configured)
|
||||
self.max_slots_per_thread: int = 10000
|
||||
self.max_threads: int = 1000
|
||||
|
||||
def get_or_create_thread(self, thread_id: str) -> ThreadContext:
|
||||
"""Get existing thread context or create new one."""
|
||||
with self._lock:
|
||||
if thread_id not in self._threads:
|
||||
if len(self._threads) >= self.max_threads:
|
||||
# GC: remove oldest thread (simple strategy)
|
||||
oldest = min(self._threads.values(), key=lambda t: t._created_at)
|
||||
del self._threads[oldest.thread_id]
|
||||
|
||||
self._threads[thread_id] = ThreadContext(thread_id)
|
||||
|
||||
return self._threads[thread_id]
|
||||
|
||||
def append(
|
||||
self,
|
||||
thread_id: str,
|
||||
payload: Any,
|
||||
from_id: str,
|
||||
to_id: str,
|
||||
own_name: Optional[str] = None,
|
||||
is_self_call: bool = False,
|
||||
usage_instructions: str = "",
|
||||
todo_nudge: str = "",
|
||||
) -> BufferSlot:
|
||||
"""
|
||||
Append a validated payload to a thread's context.
|
||||
|
||||
This is the main entry point for the pipeline.
|
||||
Returns the immutable slot reference.
|
||||
"""
|
||||
thread = self.get_or_create_thread(thread_id)
|
||||
|
||||
# Enforce slot limit
|
||||
if len(thread) >= self.max_slots_per_thread:
|
||||
raise MemoryError(
|
||||
f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})"
|
||||
)
|
||||
|
||||
return thread.append(
|
||||
payload=payload,
|
||||
from_id=from_id,
|
||||
to_id=to_id,
|
||||
own_name=own_name,
|
||||
is_self_call=is_self_call,
|
||||
usage_instructions=usage_instructions,
|
||||
todo_nudge=todo_nudge,
|
||||
)
|
||||
|
||||
def get_thread(self, thread_id: str) -> Optional[ThreadContext]:
|
||||
"""Get a thread's context (None if not found)."""
|
||||
with self._lock:
|
||||
return self._threads.get(thread_id)
|
||||
|
||||
def thread_exists(self, thread_id: str) -> bool:
|
||||
"""Check if a thread exists."""
|
||||
with self._lock:
|
||||
return thread_id in self._threads
|
||||
|
||||
def delete_thread(self, thread_id: str) -> bool:
|
||||
"""Delete a thread's context (GC)."""
|
||||
with self._lock:
|
||||
if thread_id in self._threads:
|
||||
del self._threads[thread_id]
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get buffer statistics."""
|
||||
with self._lock:
|
||||
total_slots = sum(len(t) for t in self._threads.values())
|
||||
return {
|
||||
"thread_count": len(self._threads),
|
||||
"total_slots": total_slots,
|
||||
"max_threads": self.max_threads,
|
||||
"max_slots_per_thread": self.max_slots_per_thread,
|
||||
}
|
||||
|
||||
def clear(self):
|
||||
"""Clear all contexts (for testing)."""
|
||||
with self._lock:
|
||||
self._threads.clear()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Singleton
|
||||
# ============================================================================
|
||||
|
||||
_buffer: Optional[ContextBuffer] = None
|
||||
_buffer_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_context_buffer() -> ContextBuffer:
|
||||
"""Get the global ContextBuffer singleton."""
|
||||
global _buffer
|
||||
if _buffer is None:
|
||||
with _buffer_lock:
|
||||
if _buffer is None:
|
||||
_buffer = ContextBuffer()
|
||||
return _buffer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Handler-facing metadata adapter
|
||||
# ============================================================================
|
||||
|
||||
def slot_to_handler_metadata(slot: BufferSlot) -> 'HandlerMetadata':
|
||||
"""
|
||||
Convert SlotMetadata to HandlerMetadata for backward compatibility.
|
||||
|
||||
Handlers still receive HandlerMetadata, but it's derived from the slot.
|
||||
"""
|
||||
from agentserver.message_bus.message_state import HandlerMetadata
|
||||
|
||||
return HandlerMetadata(
|
||||
thread_id=slot.metadata.thread_id,
|
||||
from_id=slot.metadata.from_id,
|
||||
own_name=slot.metadata.own_name,
|
||||
is_self_call=slot.metadata.is_self_call,
|
||||
usage_instructions=slot.metadata.usage_instructions,
|
||||
todo_nudge=slot.metadata.todo_nudge,
|
||||
)
|
||||
|
|
@ -32,6 +32,7 @@ from agentserver.message_bus.steps.thread_assignment import thread_assignment_st
|
|||
from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR
|
||||
from agentserver.message_bus.thread_registry import get_registry
|
||||
from agentserver.message_bus.todo_registry import get_todo_registry
|
||||
from agentserver.memory import get_context_buffer
|
||||
|
||||
|
||||
# ============================================================================
|
||||
|
|
@ -348,6 +349,7 @@ class StreamPump:
|
|||
# Ensure we have a valid thread chain
|
||||
registry = get_registry()
|
||||
todo_registry = get_todo_registry()
|
||||
context_buffer = get_context_buffer()
|
||||
current_thread = state.thread_id or ""
|
||||
|
||||
# Check if thread exists in registry; if not, register it
|
||||
|
|
@ -377,6 +379,27 @@ class StreamPump:
|
|||
raised = todo_registry.get_raised_for(current_thread, listener.name)
|
||||
todo_nudge = todo_registry.format_nudge(raised)
|
||||
|
||||
# === CONTEXT BUFFER: Record incoming message ===
|
||||
# Append validated payload to thread's context buffer
|
||||
if current_thread and state.payload:
|
||||
try:
|
||||
context_buffer.append(
|
||||
thread_id=current_thread,
|
||||
payload=state.payload,
|
||||
from_id=state.from_id or "unknown",
|
||||
to_id=listener.name,
|
||||
own_name=listener.name if listener.is_agent else None,
|
||||
is_self_call=is_self_call,
|
||||
usage_instructions=listener.usage_instructions,
|
||||
todo_nudge=todo_nudge,
|
||||
)
|
||||
except MemoryError:
|
||||
# Thread exceeded max slots - log and continue
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Thread {current_thread[:8]}... exceeded context buffer limit"
|
||||
)
|
||||
|
||||
metadata = HandlerMetadata(
|
||||
thread_id=current_thread,
|
||||
from_id=state.from_id or "",
|
||||
|
|
@ -435,6 +458,22 @@ class StreamPump:
|
|||
to_id = requested_to
|
||||
thread_id = registry.extend_chain(current_thread, to_id)
|
||||
|
||||
# === CONTEXT BUFFER: Record outgoing response ===
|
||||
# Append handler's response to the target thread's buffer
|
||||
# This happens BEFORE serialization - the buffer holds the clean payload
|
||||
try:
|
||||
context_buffer.append(
|
||||
thread_id=thread_id,
|
||||
payload=response.payload,
|
||||
from_id=listener.name,
|
||||
to_id=to_id,
|
||||
)
|
||||
except MemoryError:
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
f"Thread {thread_id[:8]}... exceeded context buffer limit"
|
||||
)
|
||||
|
||||
response_bytes = self._wrap_in_envelope(
|
||||
payload=response.payload,
|
||||
from_id=listener.name,
|
||||
|
|
|
|||
419
tests/test_context_buffer.py
Normal file
419
tests/test_context_buffer.py
Normal file
|
|
@ -0,0 +1,419 @@
|
|||
"""
|
||||
test_context_buffer.py — Tests for the AI agent virtual memory manager.
|
||||
|
||||
Tests:
|
||||
1. Append-only semantics
|
||||
2. Immutability guarantees
|
||||
3. Thread isolation
|
||||
4. Slot references
|
||||
5. GC and limits
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import uuid
|
||||
from dataclasses import dataclass, FrozenInstanceError
|
||||
|
||||
from agentserver.memory.context_buffer import (
|
||||
ContextBuffer,
|
||||
ThreadContext,
|
||||
BufferSlot,
|
||||
SlotMetadata,
|
||||
get_context_buffer,
|
||||
slot_to_handler_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Test payload classes
|
||||
@dataclass
|
||||
class TestPayload:
|
||||
message: str
|
||||
value: int = 0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FrozenPayload:
|
||||
message: str
|
||||
|
||||
|
||||
class TestBufferSlotImmutability:
|
||||
"""Test that buffer slots cannot be modified."""
|
||||
|
||||
def test_slot_is_frozen(self):
|
||||
"""BufferSlot should be frozen (immutable)."""
|
||||
metadata = SlotMetadata(
|
||||
thread_id="t1",
|
||||
from_id="sender",
|
||||
to_id="receiver",
|
||||
slot_index=0,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
payload_type="TestPayload",
|
||||
)
|
||||
slot = BufferSlot(payload=TestPayload(message="hello"), metadata=metadata)
|
||||
|
||||
# Cannot modify slot attributes
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
slot.metadata = None
|
||||
|
||||
def test_slot_metadata_is_frozen(self):
|
||||
"""SlotMetadata should be frozen (immutable)."""
|
||||
metadata = SlotMetadata(
|
||||
thread_id="t1",
|
||||
from_id="sender",
|
||||
to_id="receiver",
|
||||
slot_index=0,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
payload_type="TestPayload",
|
||||
)
|
||||
|
||||
with pytest.raises(FrozenInstanceError):
|
||||
metadata.thread_id = "modified"
|
||||
|
||||
def test_payload_reference_preserved(self):
|
||||
"""Slot should preserve reference to original payload."""
|
||||
payload = TestPayload(message="original")
|
||||
metadata = SlotMetadata(
|
||||
thread_id="t1",
|
||||
from_id="sender",
|
||||
to_id="receiver",
|
||||
slot_index=0,
|
||||
timestamp="2024-01-01T00:00:00Z",
|
||||
payload_type="TestPayload",
|
||||
)
|
||||
slot = BufferSlot(payload=payload, metadata=metadata)
|
||||
|
||||
# Same reference
|
||||
assert slot.payload is payload
|
||||
|
||||
|
||||
class TestThreadContext:
|
||||
"""Test single-thread context buffer."""
|
||||
|
||||
def test_append_creates_slot(self):
|
||||
"""Appending returns a BufferSlot."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
slot = ctx.append(
|
||||
payload=TestPayload(message="test"),
|
||||
from_id="sender",
|
||||
to_id="receiver",
|
||||
)
|
||||
|
||||
assert isinstance(slot, BufferSlot)
|
||||
assert slot.payload.message == "test"
|
||||
assert slot.from_id == "sender"
|
||||
assert slot.to_id == "receiver"
|
||||
assert slot.index == 0
|
||||
|
||||
def test_append_increments_index(self):
|
||||
"""Each append gets a new index."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
slot1 = ctx.append(TestPayload("a"), "s", "r")
|
||||
slot2 = ctx.append(TestPayload("b"), "s", "r")
|
||||
slot3 = ctx.append(TestPayload("c"), "s", "r")
|
||||
|
||||
assert slot1.index == 0
|
||||
assert slot2.index == 1
|
||||
assert slot3.index == 2
|
||||
assert len(ctx) == 3
|
||||
|
||||
def test_getitem_returns_slot(self):
|
||||
"""Can access slots by index."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
ctx.append(TestPayload("first"), "s", "r")
|
||||
ctx.append(TestPayload("second"), "s", "r")
|
||||
|
||||
assert ctx[0].payload.message == "first"
|
||||
assert ctx[1].payload.message == "second"
|
||||
|
||||
def test_iteration(self):
|
||||
"""Can iterate over all slots."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
ctx.append(TestPayload("a"), "s", "r")
|
||||
ctx.append(TestPayload("b"), "s", "r")
|
||||
ctx.append(TestPayload("c"), "s", "r")
|
||||
|
||||
messages = [slot.payload.message for slot in ctx]
|
||||
assert messages == ["a", "b", "c"]
|
||||
|
||||
def test_get_by_type(self):
|
||||
"""Can filter slots by payload type."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
ctx.append(TestPayload("test"), "s", "r")
|
||||
ctx.append(FrozenPayload("frozen"), "s", "r")
|
||||
ctx.append(TestPayload("test2"), "s", "r")
|
||||
|
||||
test_slots = ctx.get_by_type("TestPayload")
|
||||
assert len(test_slots) == 2
|
||||
|
||||
frozen_slots = ctx.get_by_type("FrozenPayload")
|
||||
assert len(frozen_slots) == 1
|
||||
|
||||
def test_get_from(self):
|
||||
"""Can filter slots by sender."""
|
||||
ctx = ThreadContext("thread-1")
|
||||
|
||||
ctx.append(TestPayload("a"), "alice", "r")
|
||||
ctx.append(TestPayload("b"), "bob", "r")
|
||||
ctx.append(TestPayload("c"), "alice", "r")
|
||||
|
||||
alice_slots = ctx.get_from("alice")
|
||||
assert len(alice_slots) == 2
|
||||
|
||||
bob_slots = ctx.get_from("bob")
|
||||
assert len(bob_slots) == 1
|
||||
|
||||
|
||||
class TestContextBuffer:
|
||||
"""Test global context buffer."""
|
||||
|
||||
def test_get_or_create_thread(self):
|
||||
"""get_or_create_thread creates new thread if needed."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
ctx1 = buffer.get_or_create_thread("thread-1")
|
||||
ctx2 = buffer.get_or_create_thread("thread-1")
|
||||
ctx3 = buffer.get_or_create_thread("thread-2")
|
||||
|
||||
assert ctx1 is ctx2 # Same thread
|
||||
assert ctx1 is not ctx3 # Different thread
|
||||
|
||||
def test_append_to_thread(self):
|
||||
"""append() adds to correct thread."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
slot1 = buffer.append("t1", TestPayload("a"), "s", "r")
|
||||
slot2 = buffer.append("t2", TestPayload("b"), "s", "r")
|
||||
slot3 = buffer.append("t1", TestPayload("c"), "s", "r")
|
||||
|
||||
t1 = buffer.get_thread("t1")
|
||||
t2 = buffer.get_thread("t2")
|
||||
|
||||
assert len(t1) == 2
|
||||
assert len(t2) == 1
|
||||
assert t1[0].payload.message == "a"
|
||||
assert t1[1].payload.message == "c"
|
||||
assert t2[0].payload.message == "b"
|
||||
|
||||
def test_thread_isolation(self):
|
||||
"""Threads cannot see each other's slots."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
buffer.append("thread-a", TestPayload("secret-a"), "s", "r")
|
||||
buffer.append("thread-b", TestPayload("secret-b"), "s", "r")
|
||||
|
||||
thread_a = buffer.get_thread("thread-a")
|
||||
thread_b = buffer.get_thread("thread-b")
|
||||
|
||||
# Each thread only sees its own messages
|
||||
a_messages = [s.payload.message for s in thread_a]
|
||||
b_messages = [s.payload.message for s in thread_b]
|
||||
|
||||
assert a_messages == ["secret-a"]
|
||||
assert b_messages == ["secret-b"]
|
||||
|
||||
def test_delete_thread(self):
|
||||
"""delete_thread removes thread context."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
buffer.append("t1", TestPayload("test"), "s", "r")
|
||||
assert buffer.thread_exists("t1")
|
||||
|
||||
result = buffer.delete_thread("t1")
|
||||
assert result is True
|
||||
assert not buffer.thread_exists("t1")
|
||||
|
||||
def test_max_slots_limit(self):
|
||||
"""Exceeding max_slots_per_thread raises MemoryError."""
|
||||
buffer = ContextBuffer()
|
||||
buffer.max_slots_per_thread = 3
|
||||
|
||||
buffer.append("t1", TestPayload("1"), "s", "r")
|
||||
buffer.append("t1", TestPayload("2"), "s", "r")
|
||||
buffer.append("t1", TestPayload("3"), "s", "r")
|
||||
|
||||
with pytest.raises(MemoryError) as exc_info:
|
||||
buffer.append("t1", TestPayload("4"), "s", "r")
|
||||
|
||||
assert "exceeded max slots" in str(exc_info.value)
|
||||
|
||||
def test_max_threads_gc(self):
|
||||
"""Exceeding max_threads triggers GC of oldest thread."""
|
||||
buffer = ContextBuffer()
|
||||
buffer.max_threads = 2
|
||||
|
||||
buffer.append("t1", TestPayload("first"), "s", "r")
|
||||
buffer.append("t2", TestPayload("second"), "s", "r")
|
||||
|
||||
# Adding third thread should GC the oldest
|
||||
buffer.append("t3", TestPayload("third"), "s", "r")
|
||||
|
||||
stats = buffer.get_stats()
|
||||
assert stats["thread_count"] == 2
|
||||
|
||||
# t1 should be gone (oldest)
|
||||
assert not buffer.thread_exists("t1")
|
||||
assert buffer.thread_exists("t2")
|
||||
assert buffer.thread_exists("t3")
|
||||
|
||||
def test_get_stats(self):
|
||||
"""get_stats returns buffer statistics."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
buffer.append("t1", TestPayload("a"), "s", "r")
|
||||
buffer.append("t1", TestPayload("b"), "s", "r")
|
||||
buffer.append("t2", TestPayload("c"), "s", "r")
|
||||
|
||||
stats = buffer.get_stats()
|
||||
|
||||
assert stats["thread_count"] == 2
|
||||
assert stats["total_slots"] == 3
|
||||
|
||||
|
||||
class TestHandlerMetadataAdapter:
|
||||
"""Test conversion from SlotMetadata to HandlerMetadata."""
|
||||
|
||||
def test_slot_to_handler_metadata(self):
|
||||
"""slot_to_handler_metadata converts correctly."""
|
||||
buffer = ContextBuffer()
|
||||
|
||||
slot = buffer.append(
|
||||
thread_id="t1",
|
||||
payload=TestPayload("test"),
|
||||
from_id="sender",
|
||||
to_id="receiver",
|
||||
own_name="test-agent",
|
||||
is_self_call=True,
|
||||
usage_instructions="instructions here",
|
||||
todo_nudge="nudge here",
|
||||
)
|
||||
|
||||
metadata = slot_to_handler_metadata(slot)
|
||||
|
||||
assert metadata.thread_id == "t1"
|
||||
assert metadata.from_id == "sender"
|
||||
assert metadata.own_name == "test-agent"
|
||||
assert metadata.is_self_call is True
|
||||
assert metadata.usage_instructions == "instructions here"
|
||||
assert metadata.todo_nudge == "nudge here"
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
"""Test singleton behavior."""
|
||||
|
||||
def test_get_context_buffer_singleton(self):
|
||||
"""get_context_buffer returns same instance."""
|
||||
buf1 = get_context_buffer()
|
||||
buf2 = get_context_buffer()
|
||||
|
||||
assert buf1 is buf2
|
||||
|
||||
def test_clear_resets_state(self):
|
||||
"""clear() removes all threads."""
|
||||
buffer = get_context_buffer()
|
||||
buffer.append("test-thread", TestPayload("test"), "s", "r")
|
||||
|
||||
buffer.clear()
|
||||
|
||||
assert not buffer.thread_exists("test-thread")
|
||||
assert buffer.get_stats()["thread_count"] == 0
|
||||
|
||||
|
||||
class TestPumpIntegration:
|
||||
"""Test context buffer integration with StreamPump."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_buffer_records_messages_during_flow(self):
|
||||
"""Context buffer should record messages as they flow through pump."""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from agentserver.message_bus.stream_pump import StreamPump, ListenerConfig, OrganismConfig
|
||||
from agentserver.message_bus.message_state import HandlerResponse
|
||||
from agentserver.llm.backend import LLMResponse
|
||||
|
||||
# Import handlers
|
||||
from handlers.hello import Greeting, GreetingResponse, handle_greeting, handle_shout
|
||||
from handlers.console import ShoutedResponse
|
||||
|
||||
# Clear buffer
|
||||
buffer = get_context_buffer()
|
||||
buffer.clear()
|
||||
|
||||
# Create pump with greeter and shouter
|
||||
config = OrganismConfig(name="buffer-test")
|
||||
pump = StreamPump(config)
|
||||
|
||||
pump.register_listener(ListenerConfig(
|
||||
name="greeter",
|
||||
payload_class_path="handlers.hello.Greeting",
|
||||
handler_path="handlers.hello.handle_greeting",
|
||||
description="Greeting agent",
|
||||
is_agent=True,
|
||||
peers=["shouter"],
|
||||
payload_class=Greeting,
|
||||
handler=handle_greeting,
|
||||
))
|
||||
|
||||
pump.register_listener(ListenerConfig(
|
||||
name="shouter",
|
||||
payload_class_path="handlers.hello.GreetingResponse",
|
||||
handler_path="handlers.hello.handle_shout",
|
||||
description="Shouts",
|
||||
payload_class=GreetingResponse,
|
||||
handler=handle_shout,
|
||||
))
|
||||
|
||||
# Mock LLM
|
||||
mock_llm = LLMResponse(
|
||||
content="Hello there!",
|
||||
model="mock",
|
||||
usage={"total_tokens": 5},
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
# Prevent re-injection loops
|
||||
async def noop_reinject(state):
|
||||
pass
|
||||
pump._reinject_responses = noop_reinject
|
||||
|
||||
with patch('agentserver.llm.complete', new=AsyncMock(return_value=mock_llm)):
|
||||
# Create envelope for Greeting
|
||||
thread_id = str(uuid.uuid4())
|
||||
envelope = f"""<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
|
||||
<meta><from>user</from><to>greeter</to><thread>{thread_id}</thread></meta>
|
||||
<Greeting xmlns=""><Name>Alice</Name></Greeting>
|
||||
</message>""".encode()
|
||||
|
||||
await pump.inject(envelope, thread_id, from_id="user")
|
||||
|
||||
# Run pump to process
|
||||
pump._running = True
|
||||
pipeline = pump.build_pipeline(pump._queue_source())
|
||||
|
||||
async def run_chain():
|
||||
async with pipeline.stream() as streamer:
|
||||
count = 0
|
||||
async for _ in streamer:
|
||||
count += 1
|
||||
if count >= 3:
|
||||
break
|
||||
|
||||
import asyncio
|
||||
try:
|
||||
await asyncio.wait_for(run_chain(), timeout=3.0)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
pump._running = False
|
||||
|
||||
# Verify buffer recorded the messages
|
||||
thread_ctx = buffer.get_thread(thread_id)
|
||||
assert thread_ctx is not None, "Thread should exist in buffer"
|
||||
assert len(thread_ctx) >= 1, "Buffer should have at least one message"
|
||||
|
||||
# Check that we recorded a Greeting
|
||||
greeting_slots = thread_ctx.get_by_type("Greeting")
|
||||
assert len(greeting_slots) >= 1, "Should have recorded Greeting"
|
||||
assert greeting_slots[0].payload.name == "Alice"
|
||||
Loading…
Reference in a new issue