Add context buffer - virtual memory manager for AI agents

Implements thread-scoped, append-only storage for validated payloads.
Handlers receive immutable references; messages cannot be modified
after insertion.

Core components:
- BufferSlot: frozen dataclass holding payload + metadata
- ThreadContext: append-only buffer per thread
- ContextBuffer: global manager with GC and limits

Design parallels OS virtual memory:
- Thread ID = virtual address space
- Buffer slot = memory page
- Immutable reference = read-only mapping
- Thread isolation = process isolation

Integration:
- Incoming messages appended after pipeline validation
- Outgoing responses appended before serialization
- Full audit trail preserved

This is incremental - handlers still receive copies for backward
compatibility. Next step: skip serialization for internal routing.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
dullfig 2026-01-10 17:20:43 -08:00
parent a5e2ab22da
commit fc7170a02e
4 changed files with 784 additions and 0 deletions

View file

@ -0,0 +1,27 @@
"""
memory Virtual memory management for AI agents.
Provides thread-scoped, append-only context buffers with:
- Immutable slots (handlers can't modify messages)
- Thread isolation (handlers only see their context)
- Complete audit trail (all messages preserved)
- GC and limits (prevent runaway memory usage)
"""
from agentserver.memory.context_buffer import (
ContextBuffer,
ThreadContext,
BufferSlot,
SlotMetadata,
get_context_buffer,
slot_to_handler_metadata,
)
__all__ = [
"ContextBuffer",
"ThreadContext",
"BufferSlot",
"SlotMetadata",
"get_context_buffer",
"slot_to_handler_metadata",
]

View file

@ -0,0 +1,299 @@
"""
context_buffer.py Virtual memory manager for AI agents.
Provides thread-scoped, append-only storage for validated message payloads.
Handlers receive immutable references to buffer slots, never copies.
Design principles:
- Append-only: Messages cannot be modified after insertion
- Thread isolation: Handlers only see their thread's context
- Immutable references: Handlers get read-only views
- Complete audit trail: All messages preserved in order
Analogous to OS virtual memory:
- Thread ID = virtual address space
- Buffer slot = memory page
- Thread registry = page table
- Immutable reference = read-only mapping
Usage:
buffer = get_context_buffer()
# Append validated payload (returns slot reference)
ref = buffer.append(thread_id, payload, metadata)
# Handler receives reference
handler(ref.payload, ref.metadata)
# Get thread history
history = buffer.get_thread(thread_id)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Iterator
from datetime import datetime, timezone
import threading
import uuid
@dataclass(frozen=True)
class SlotMetadata:
"""Immutable metadata for a buffer slot."""
thread_id: str
from_id: str
to_id: str
slot_index: int
timestamp: str
payload_type: str
# Handler-facing metadata (subset exposed to handlers)
own_name: Optional[str] = None
is_self_call: bool = False
usage_instructions: str = ""
todo_nudge: str = ""
@dataclass(frozen=True)
class BufferSlot:
"""
Immutable slot in the context buffer.
frozen=True ensures the slot cannot be modified after creation.
Handlers receive this directly - they cannot mutate it.
"""
payload: Any # The validated @xmlify dataclass (immutable reference)
metadata: SlotMetadata # Immutable slot metadata
@property
def thread_id(self) -> str:
return self.metadata.thread_id
@property
def from_id(self) -> str:
return self.metadata.from_id
@property
def to_id(self) -> str:
return self.metadata.to_id
@property
def index(self) -> int:
return self.metadata.slot_index
class ThreadContext:
"""
Append-only context buffer for a single thread.
All slots are immutable once appended.
"""
def __init__(self, thread_id: str):
self.thread_id = thread_id
self._slots: List[BufferSlot] = []
self._lock = threading.Lock()
self._created_at = datetime.now(timezone.utc)
def append(
self,
payload: Any,
from_id: str,
to_id: str,
own_name: Optional[str] = None,
is_self_call: bool = False,
usage_instructions: str = "",
todo_nudge: str = "",
) -> BufferSlot:
"""
Append a validated payload to this thread's context.
Returns the immutable slot reference.
"""
with self._lock:
slot_index = len(self._slots)
metadata = SlotMetadata(
thread_id=self.thread_id,
from_id=from_id,
to_id=to_id,
slot_index=slot_index,
timestamp=datetime.now(timezone.utc).isoformat(),
payload_type=type(payload).__name__,
own_name=own_name,
is_self_call=is_self_call,
usage_instructions=usage_instructions,
todo_nudge=todo_nudge,
)
slot = BufferSlot(payload=payload, metadata=metadata)
self._slots.append(slot)
return slot
def __len__(self) -> int:
with self._lock:
return len(self._slots)
def __getitem__(self, index: int) -> BufferSlot:
with self._lock:
return self._slots[index]
def __iter__(self) -> Iterator[BufferSlot]:
# Return a copy of the list to avoid mutation during iteration
with self._lock:
return iter(list(self._slots))
def get_slice(self, start: int = 0, end: Optional[int] = None) -> List[BufferSlot]:
"""Get a slice of the context (for paging/windowing)."""
with self._lock:
return list(self._slots[start:end])
def get_by_type(self, payload_type: str) -> List[BufferSlot]:
"""Get all slots with a specific payload type."""
with self._lock:
return [s for s in self._slots if s.metadata.payload_type == payload_type]
def get_from(self, from_id: str) -> List[BufferSlot]:
"""Get all slots from a specific sender."""
with self._lock:
return [s for s in self._slots if s.from_id == from_id]
class ContextBuffer:
"""
Global context buffer managing all thread contexts.
Thread-safe. Singleton pattern via get_context_buffer().
"""
def __init__(self):
self._threads: Dict[str, ThreadContext] = {}
self._lock = threading.Lock()
# Limits (can be configured)
self.max_slots_per_thread: int = 10000
self.max_threads: int = 1000
def get_or_create_thread(self, thread_id: str) -> ThreadContext:
"""Get existing thread context or create new one."""
with self._lock:
if thread_id not in self._threads:
if len(self._threads) >= self.max_threads:
# GC: remove oldest thread (simple strategy)
oldest = min(self._threads.values(), key=lambda t: t._created_at)
del self._threads[oldest.thread_id]
self._threads[thread_id] = ThreadContext(thread_id)
return self._threads[thread_id]
def append(
self,
thread_id: str,
payload: Any,
from_id: str,
to_id: str,
own_name: Optional[str] = None,
is_self_call: bool = False,
usage_instructions: str = "",
todo_nudge: str = "",
) -> BufferSlot:
"""
Append a validated payload to a thread's context.
This is the main entry point for the pipeline.
Returns the immutable slot reference.
"""
thread = self.get_or_create_thread(thread_id)
# Enforce slot limit
if len(thread) >= self.max_slots_per_thread:
raise MemoryError(
f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})"
)
return thread.append(
payload=payload,
from_id=from_id,
to_id=to_id,
own_name=own_name,
is_self_call=is_self_call,
usage_instructions=usage_instructions,
todo_nudge=todo_nudge,
)
def get_thread(self, thread_id: str) -> Optional[ThreadContext]:
"""Get a thread's context (None if not found)."""
with self._lock:
return self._threads.get(thread_id)
def thread_exists(self, thread_id: str) -> bool:
"""Check if a thread exists."""
with self._lock:
return thread_id in self._threads
def delete_thread(self, thread_id: str) -> bool:
"""Delete a thread's context (GC)."""
with self._lock:
if thread_id in self._threads:
del self._threads[thread_id]
return True
return False
def get_stats(self) -> Dict[str, Any]:
"""Get buffer statistics."""
with self._lock:
total_slots = sum(len(t) for t in self._threads.values())
return {
"thread_count": len(self._threads),
"total_slots": total_slots,
"max_threads": self.max_threads,
"max_slots_per_thread": self.max_slots_per_thread,
}
def clear(self):
"""Clear all contexts (for testing)."""
with self._lock:
self._threads.clear()
# ============================================================================
# Singleton
# ============================================================================
_buffer: Optional[ContextBuffer] = None
_buffer_lock = threading.Lock()
def get_context_buffer() -> ContextBuffer:
"""Get the global ContextBuffer singleton."""
global _buffer
if _buffer is None:
with _buffer_lock:
if _buffer is None:
_buffer = ContextBuffer()
return _buffer
# ============================================================================
# Handler-facing metadata adapter
# ============================================================================
def slot_to_handler_metadata(slot: BufferSlot) -> 'HandlerMetadata':
"""
Convert SlotMetadata to HandlerMetadata for backward compatibility.
Handlers still receive HandlerMetadata, but it's derived from the slot.
"""
from agentserver.message_bus.message_state import HandlerMetadata
return HandlerMetadata(
thread_id=slot.metadata.thread_id,
from_id=slot.metadata.from_id,
own_name=slot.metadata.own_name,
is_self_call=slot.metadata.is_self_call,
usage_instructions=slot.metadata.usage_instructions,
todo_nudge=slot.metadata.todo_nudge,
)

View file

@ -32,6 +32,7 @@ from agentserver.message_bus.steps.thread_assignment import thread_assignment_st
from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR
from agentserver.message_bus.thread_registry import get_registry from agentserver.message_bus.thread_registry import get_registry
from agentserver.message_bus.todo_registry import get_todo_registry from agentserver.message_bus.todo_registry import get_todo_registry
from agentserver.memory import get_context_buffer
# ============================================================================ # ============================================================================
@ -348,6 +349,7 @@ class StreamPump:
# Ensure we have a valid thread chain # Ensure we have a valid thread chain
registry = get_registry() registry = get_registry()
todo_registry = get_todo_registry() todo_registry = get_todo_registry()
context_buffer = get_context_buffer()
current_thread = state.thread_id or "" current_thread = state.thread_id or ""
# Check if thread exists in registry; if not, register it # Check if thread exists in registry; if not, register it
@ -377,6 +379,27 @@ class StreamPump:
raised = todo_registry.get_raised_for(current_thread, listener.name) raised = todo_registry.get_raised_for(current_thread, listener.name)
todo_nudge = todo_registry.format_nudge(raised) todo_nudge = todo_registry.format_nudge(raised)
# === CONTEXT BUFFER: Record incoming message ===
# Append validated payload to thread's context buffer
if current_thread and state.payload:
try:
context_buffer.append(
thread_id=current_thread,
payload=state.payload,
from_id=state.from_id or "unknown",
to_id=listener.name,
own_name=listener.name if listener.is_agent else None,
is_self_call=is_self_call,
usage_instructions=listener.usage_instructions,
todo_nudge=todo_nudge,
)
except MemoryError:
# Thread exceeded max slots - log and continue
import logging
logging.getLogger(__name__).warning(
f"Thread {current_thread[:8]}... exceeded context buffer limit"
)
metadata = HandlerMetadata( metadata = HandlerMetadata(
thread_id=current_thread, thread_id=current_thread,
from_id=state.from_id or "", from_id=state.from_id or "",
@ -435,6 +458,22 @@ class StreamPump:
to_id = requested_to to_id = requested_to
thread_id = registry.extend_chain(current_thread, to_id) thread_id = registry.extend_chain(current_thread, to_id)
# === CONTEXT BUFFER: Record outgoing response ===
# Append handler's response to the target thread's buffer
# This happens BEFORE serialization - the buffer holds the clean payload
try:
context_buffer.append(
thread_id=thread_id,
payload=response.payload,
from_id=listener.name,
to_id=to_id,
)
except MemoryError:
import logging
logging.getLogger(__name__).warning(
f"Thread {thread_id[:8]}... exceeded context buffer limit"
)
response_bytes = self._wrap_in_envelope( response_bytes = self._wrap_in_envelope(
payload=response.payload, payload=response.payload,
from_id=listener.name, from_id=listener.name,

View file

@ -0,0 +1,419 @@
"""
test_context_buffer.py Tests for the AI agent virtual memory manager.
Tests:
1. Append-only semantics
2. Immutability guarantees
3. Thread isolation
4. Slot references
5. GC and limits
"""
import pytest
import uuid
from dataclasses import dataclass, FrozenInstanceError
from agentserver.memory.context_buffer import (
ContextBuffer,
ThreadContext,
BufferSlot,
SlotMetadata,
get_context_buffer,
slot_to_handler_metadata,
)
# Test payload classes
@dataclass
class TestPayload:
message: str
value: int = 0
@dataclass(frozen=True)
class FrozenPayload:
message: str
class TestBufferSlotImmutability:
"""Test that buffer slots cannot be modified."""
def test_slot_is_frozen(self):
"""BufferSlot should be frozen (immutable)."""
metadata = SlotMetadata(
thread_id="t1",
from_id="sender",
to_id="receiver",
slot_index=0,
timestamp="2024-01-01T00:00:00Z",
payload_type="TestPayload",
)
slot = BufferSlot(payload=TestPayload(message="hello"), metadata=metadata)
# Cannot modify slot attributes
with pytest.raises(FrozenInstanceError):
slot.metadata = None
def test_slot_metadata_is_frozen(self):
"""SlotMetadata should be frozen (immutable)."""
metadata = SlotMetadata(
thread_id="t1",
from_id="sender",
to_id="receiver",
slot_index=0,
timestamp="2024-01-01T00:00:00Z",
payload_type="TestPayload",
)
with pytest.raises(FrozenInstanceError):
metadata.thread_id = "modified"
def test_payload_reference_preserved(self):
"""Slot should preserve reference to original payload."""
payload = TestPayload(message="original")
metadata = SlotMetadata(
thread_id="t1",
from_id="sender",
to_id="receiver",
slot_index=0,
timestamp="2024-01-01T00:00:00Z",
payload_type="TestPayload",
)
slot = BufferSlot(payload=payload, metadata=metadata)
# Same reference
assert slot.payload is payload
class TestThreadContext:
"""Test single-thread context buffer."""
def test_append_creates_slot(self):
"""Appending returns a BufferSlot."""
ctx = ThreadContext("thread-1")
slot = ctx.append(
payload=TestPayload(message="test"),
from_id="sender",
to_id="receiver",
)
assert isinstance(slot, BufferSlot)
assert slot.payload.message == "test"
assert slot.from_id == "sender"
assert slot.to_id == "receiver"
assert slot.index == 0
def test_append_increments_index(self):
"""Each append gets a new index."""
ctx = ThreadContext("thread-1")
slot1 = ctx.append(TestPayload("a"), "s", "r")
slot2 = ctx.append(TestPayload("b"), "s", "r")
slot3 = ctx.append(TestPayload("c"), "s", "r")
assert slot1.index == 0
assert slot2.index == 1
assert slot3.index == 2
assert len(ctx) == 3
def test_getitem_returns_slot(self):
"""Can access slots by index."""
ctx = ThreadContext("thread-1")
ctx.append(TestPayload("first"), "s", "r")
ctx.append(TestPayload("second"), "s", "r")
assert ctx[0].payload.message == "first"
assert ctx[1].payload.message == "second"
def test_iteration(self):
"""Can iterate over all slots."""
ctx = ThreadContext("thread-1")
ctx.append(TestPayload("a"), "s", "r")
ctx.append(TestPayload("b"), "s", "r")
ctx.append(TestPayload("c"), "s", "r")
messages = [slot.payload.message for slot in ctx]
assert messages == ["a", "b", "c"]
def test_get_by_type(self):
"""Can filter slots by payload type."""
ctx = ThreadContext("thread-1")
ctx.append(TestPayload("test"), "s", "r")
ctx.append(FrozenPayload("frozen"), "s", "r")
ctx.append(TestPayload("test2"), "s", "r")
test_slots = ctx.get_by_type("TestPayload")
assert len(test_slots) == 2
frozen_slots = ctx.get_by_type("FrozenPayload")
assert len(frozen_slots) == 1
def test_get_from(self):
"""Can filter slots by sender."""
ctx = ThreadContext("thread-1")
ctx.append(TestPayload("a"), "alice", "r")
ctx.append(TestPayload("b"), "bob", "r")
ctx.append(TestPayload("c"), "alice", "r")
alice_slots = ctx.get_from("alice")
assert len(alice_slots) == 2
bob_slots = ctx.get_from("bob")
assert len(bob_slots) == 1
class TestContextBuffer:
"""Test global context buffer."""
def test_get_or_create_thread(self):
"""get_or_create_thread creates new thread if needed."""
buffer = ContextBuffer()
ctx1 = buffer.get_or_create_thread("thread-1")
ctx2 = buffer.get_or_create_thread("thread-1")
ctx3 = buffer.get_or_create_thread("thread-2")
assert ctx1 is ctx2 # Same thread
assert ctx1 is not ctx3 # Different thread
def test_append_to_thread(self):
"""append() adds to correct thread."""
buffer = ContextBuffer()
slot1 = buffer.append("t1", TestPayload("a"), "s", "r")
slot2 = buffer.append("t2", TestPayload("b"), "s", "r")
slot3 = buffer.append("t1", TestPayload("c"), "s", "r")
t1 = buffer.get_thread("t1")
t2 = buffer.get_thread("t2")
assert len(t1) == 2
assert len(t2) == 1
assert t1[0].payload.message == "a"
assert t1[1].payload.message == "c"
assert t2[0].payload.message == "b"
def test_thread_isolation(self):
"""Threads cannot see each other's slots."""
buffer = ContextBuffer()
buffer.append("thread-a", TestPayload("secret-a"), "s", "r")
buffer.append("thread-b", TestPayload("secret-b"), "s", "r")
thread_a = buffer.get_thread("thread-a")
thread_b = buffer.get_thread("thread-b")
# Each thread only sees its own messages
a_messages = [s.payload.message for s in thread_a]
b_messages = [s.payload.message for s in thread_b]
assert a_messages == ["secret-a"]
assert b_messages == ["secret-b"]
def test_delete_thread(self):
"""delete_thread removes thread context."""
buffer = ContextBuffer()
buffer.append("t1", TestPayload("test"), "s", "r")
assert buffer.thread_exists("t1")
result = buffer.delete_thread("t1")
assert result is True
assert not buffer.thread_exists("t1")
def test_max_slots_limit(self):
"""Exceeding max_slots_per_thread raises MemoryError."""
buffer = ContextBuffer()
buffer.max_slots_per_thread = 3
buffer.append("t1", TestPayload("1"), "s", "r")
buffer.append("t1", TestPayload("2"), "s", "r")
buffer.append("t1", TestPayload("3"), "s", "r")
with pytest.raises(MemoryError) as exc_info:
buffer.append("t1", TestPayload("4"), "s", "r")
assert "exceeded max slots" in str(exc_info.value)
def test_max_threads_gc(self):
"""Exceeding max_threads triggers GC of oldest thread."""
buffer = ContextBuffer()
buffer.max_threads = 2
buffer.append("t1", TestPayload("first"), "s", "r")
buffer.append("t2", TestPayload("second"), "s", "r")
# Adding third thread should GC the oldest
buffer.append("t3", TestPayload("third"), "s", "r")
stats = buffer.get_stats()
assert stats["thread_count"] == 2
# t1 should be gone (oldest)
assert not buffer.thread_exists("t1")
assert buffer.thread_exists("t2")
assert buffer.thread_exists("t3")
def test_get_stats(self):
"""get_stats returns buffer statistics."""
buffer = ContextBuffer()
buffer.append("t1", TestPayload("a"), "s", "r")
buffer.append("t1", TestPayload("b"), "s", "r")
buffer.append("t2", TestPayload("c"), "s", "r")
stats = buffer.get_stats()
assert stats["thread_count"] == 2
assert stats["total_slots"] == 3
class TestHandlerMetadataAdapter:
"""Test conversion from SlotMetadata to HandlerMetadata."""
def test_slot_to_handler_metadata(self):
"""slot_to_handler_metadata converts correctly."""
buffer = ContextBuffer()
slot = buffer.append(
thread_id="t1",
payload=TestPayload("test"),
from_id="sender",
to_id="receiver",
own_name="test-agent",
is_self_call=True,
usage_instructions="instructions here",
todo_nudge="nudge here",
)
metadata = slot_to_handler_metadata(slot)
assert metadata.thread_id == "t1"
assert metadata.from_id == "sender"
assert metadata.own_name == "test-agent"
assert metadata.is_self_call is True
assert metadata.usage_instructions == "instructions here"
assert metadata.todo_nudge == "nudge here"
class TestSingleton:
"""Test singleton behavior."""
def test_get_context_buffer_singleton(self):
"""get_context_buffer returns same instance."""
buf1 = get_context_buffer()
buf2 = get_context_buffer()
assert buf1 is buf2
def test_clear_resets_state(self):
"""clear() removes all threads."""
buffer = get_context_buffer()
buffer.append("test-thread", TestPayload("test"), "s", "r")
buffer.clear()
assert not buffer.thread_exists("test-thread")
assert buffer.get_stats()["thread_count"] == 0
class TestPumpIntegration:
"""Test context buffer integration with StreamPump."""
@pytest.mark.asyncio
async def test_buffer_records_messages_during_flow(self):
"""Context buffer should record messages as they flow through pump."""
from unittest.mock import AsyncMock, patch
from agentserver.message_bus.stream_pump import StreamPump, ListenerConfig, OrganismConfig
from agentserver.message_bus.message_state import HandlerResponse
from agentserver.llm.backend import LLMResponse
# Import handlers
from handlers.hello import Greeting, GreetingResponse, handle_greeting, handle_shout
from handlers.console import ShoutedResponse
# Clear buffer
buffer = get_context_buffer()
buffer.clear()
# Create pump with greeter and shouter
config = OrganismConfig(name="buffer-test")
pump = StreamPump(config)
pump.register_listener(ListenerConfig(
name="greeter",
payload_class_path="handlers.hello.Greeting",
handler_path="handlers.hello.handle_greeting",
description="Greeting agent",
is_agent=True,
peers=["shouter"],
payload_class=Greeting,
handler=handle_greeting,
))
pump.register_listener(ListenerConfig(
name="shouter",
payload_class_path="handlers.hello.GreetingResponse",
handler_path="handlers.hello.handle_shout",
description="Shouts",
payload_class=GreetingResponse,
handler=handle_shout,
))
# Mock LLM
mock_llm = LLMResponse(
content="Hello there!",
model="mock",
usage={"total_tokens": 5},
finish_reason="stop",
)
# Prevent re-injection loops
async def noop_reinject(state):
pass
pump._reinject_responses = noop_reinject
with patch('agentserver.llm.complete', new=AsyncMock(return_value=mock_llm)):
# Create envelope for Greeting
thread_id = str(uuid.uuid4())
envelope = f"""<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
<meta><from>user</from><to>greeter</to><thread>{thread_id}</thread></meta>
<Greeting xmlns=""><Name>Alice</Name></Greeting>
</message>""".encode()
await pump.inject(envelope, thread_id, from_id="user")
# Run pump to process
pump._running = True
pipeline = pump.build_pipeline(pump._queue_source())
async def run_chain():
async with pipeline.stream() as streamer:
count = 0
async for _ in streamer:
count += 1
if count >= 3:
break
import asyncio
try:
await asyncio.wait_for(run_chain(), timeout=3.0)
except asyncio.TimeoutError:
pass
finally:
pump._running = False
# Verify buffer recorded the messages
thread_ctx = buffer.get_thread(thread_id)
assert thread_ctx is not None, "Thread should exist in buffer"
assert len(thread_ctx) >= 1, "Buffer should have at least one message"
# Check that we recorded a Greeting
greeting_slots = thread_ctx.get_by_type("Greeting")
assert len(greeting_slots) >= 1, "Should have recorded Greeting"
assert greeting_slots[0].payload.name == "Alice"