From fc7170a02e4cd46672a8b40a027c707921be8552 Mon Sep 17 00:00:00 2001 From: dullfig Date: Sat, 10 Jan 2026 17:20:43 -0800 Subject: [PATCH] Add context buffer - virtual memory manager for AI agents Implements thread-scoped, append-only storage for validated payloads. Handlers receive immutable references; messages cannot be modified after insertion. Core components: - BufferSlot: frozen dataclass holding payload + metadata - ThreadContext: append-only buffer per thread - ContextBuffer: global manager with GC and limits Design parallels OS virtual memory: - Thread ID = virtual address space - Buffer slot = memory page - Immutable reference = read-only mapping - Thread isolation = process isolation Integration: - Incoming messages appended after pipeline validation - Outgoing responses appended before serialization - Full audit trail preserved This is incremental - handlers still receive copies for backward compatibility. Next step: skip serialization for internal routing. Co-Authored-By: Claude Opus 4.5 --- agentserver/memory/__init__.py | 27 ++ agentserver/memory/context_buffer.py | 299 ++++++++++++++++++ agentserver/message_bus/stream_pump.py | 39 +++ tests/test_context_buffer.py | 419 +++++++++++++++++++++++++ 4 files changed, 784 insertions(+) create mode 100644 agentserver/memory/__init__.py create mode 100644 agentserver/memory/context_buffer.py create mode 100644 tests/test_context_buffer.py diff --git a/agentserver/memory/__init__.py b/agentserver/memory/__init__.py new file mode 100644 index 0000000..171f5a1 --- /dev/null +++ b/agentserver/memory/__init__.py @@ -0,0 +1,27 @@ +""" +memory — Virtual memory management for AI agents. + +Provides thread-scoped, append-only context buffers with: +- Immutable slots (handlers can't modify messages) +- Thread isolation (handlers only see their context) +- Complete audit trail (all messages preserved) +- GC and limits (prevent runaway memory usage) +""" + +from agentserver.memory.context_buffer import ( + ContextBuffer, + ThreadContext, + BufferSlot, + SlotMetadata, + get_context_buffer, + slot_to_handler_metadata, +) + +__all__ = [ + "ContextBuffer", + "ThreadContext", + "BufferSlot", + "SlotMetadata", + "get_context_buffer", + "slot_to_handler_metadata", +] diff --git a/agentserver/memory/context_buffer.py b/agentserver/memory/context_buffer.py new file mode 100644 index 0000000..6982e0c --- /dev/null +++ b/agentserver/memory/context_buffer.py @@ -0,0 +1,299 @@ +""" +context_buffer.py — Virtual memory manager for AI agents. + +Provides thread-scoped, append-only storage for validated message payloads. +Handlers receive immutable references to buffer slots, never copies. + +Design principles: +- Append-only: Messages cannot be modified after insertion +- Thread isolation: Handlers only see their thread's context +- Immutable references: Handlers get read-only views +- Complete audit trail: All messages preserved in order + +Analogous to OS virtual memory: +- Thread ID = virtual address space +- Buffer slot = memory page +- Thread registry = page table +- Immutable reference = read-only mapping + +Usage: + buffer = get_context_buffer() + + # Append validated payload (returns slot reference) + ref = buffer.append(thread_id, payload, metadata) + + # Handler receives reference + handler(ref.payload, ref.metadata) + + # Get thread history + history = buffer.get_thread(thread_id) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Iterator +from datetime import datetime, timezone +import threading +import uuid + + +@dataclass(frozen=True) +class SlotMetadata: + """Immutable metadata for a buffer slot.""" + thread_id: str + from_id: str + to_id: str + slot_index: int + timestamp: str + payload_type: str + + # Handler-facing metadata (subset exposed to handlers) + own_name: Optional[str] = None + is_self_call: bool = False + usage_instructions: str = "" + todo_nudge: str = "" + + +@dataclass(frozen=True) +class BufferSlot: + """ + Immutable slot in the context buffer. + + frozen=True ensures the slot cannot be modified after creation. + Handlers receive this directly - they cannot mutate it. + """ + payload: Any # The validated @xmlify dataclass (immutable reference) + metadata: SlotMetadata # Immutable slot metadata + + @property + def thread_id(self) -> str: + return self.metadata.thread_id + + @property + def from_id(self) -> str: + return self.metadata.from_id + + @property + def to_id(self) -> str: + return self.metadata.to_id + + @property + def index(self) -> int: + return self.metadata.slot_index + + +class ThreadContext: + """ + Append-only context buffer for a single thread. + + All slots are immutable once appended. + """ + + def __init__(self, thread_id: str): + self.thread_id = thread_id + self._slots: List[BufferSlot] = [] + self._lock = threading.Lock() + self._created_at = datetime.now(timezone.utc) + + def append( + self, + payload: Any, + from_id: str, + to_id: str, + own_name: Optional[str] = None, + is_self_call: bool = False, + usage_instructions: str = "", + todo_nudge: str = "", + ) -> BufferSlot: + """ + Append a validated payload to this thread's context. + + Returns the immutable slot reference. + """ + with self._lock: + slot_index = len(self._slots) + + metadata = SlotMetadata( + thread_id=self.thread_id, + from_id=from_id, + to_id=to_id, + slot_index=slot_index, + timestamp=datetime.now(timezone.utc).isoformat(), + payload_type=type(payload).__name__, + own_name=own_name, + is_self_call=is_self_call, + usage_instructions=usage_instructions, + todo_nudge=todo_nudge, + ) + + slot = BufferSlot(payload=payload, metadata=metadata) + self._slots.append(slot) + + return slot + + def __len__(self) -> int: + with self._lock: + return len(self._slots) + + def __getitem__(self, index: int) -> BufferSlot: + with self._lock: + return self._slots[index] + + def __iter__(self) -> Iterator[BufferSlot]: + # Return a copy of the list to avoid mutation during iteration + with self._lock: + return iter(list(self._slots)) + + def get_slice(self, start: int = 0, end: Optional[int] = None) -> List[BufferSlot]: + """Get a slice of the context (for paging/windowing).""" + with self._lock: + return list(self._slots[start:end]) + + def get_by_type(self, payload_type: str) -> List[BufferSlot]: + """Get all slots with a specific payload type.""" + with self._lock: + return [s for s in self._slots if s.metadata.payload_type == payload_type] + + def get_from(self, from_id: str) -> List[BufferSlot]: + """Get all slots from a specific sender.""" + with self._lock: + return [s for s in self._slots if s.from_id == from_id] + + +class ContextBuffer: + """ + Global context buffer managing all thread contexts. + + Thread-safe. Singleton pattern via get_context_buffer(). + """ + + def __init__(self): + self._threads: Dict[str, ThreadContext] = {} + self._lock = threading.Lock() + + # Limits (can be configured) + self.max_slots_per_thread: int = 10000 + self.max_threads: int = 1000 + + def get_or_create_thread(self, thread_id: str) -> ThreadContext: + """Get existing thread context or create new one.""" + with self._lock: + if thread_id not in self._threads: + if len(self._threads) >= self.max_threads: + # GC: remove oldest thread (simple strategy) + oldest = min(self._threads.values(), key=lambda t: t._created_at) + del self._threads[oldest.thread_id] + + self._threads[thread_id] = ThreadContext(thread_id) + + return self._threads[thread_id] + + def append( + self, + thread_id: str, + payload: Any, + from_id: str, + to_id: str, + own_name: Optional[str] = None, + is_self_call: bool = False, + usage_instructions: str = "", + todo_nudge: str = "", + ) -> BufferSlot: + """ + Append a validated payload to a thread's context. + + This is the main entry point for the pipeline. + Returns the immutable slot reference. + """ + thread = self.get_or_create_thread(thread_id) + + # Enforce slot limit + if len(thread) >= self.max_slots_per_thread: + raise MemoryError( + f"Thread {thread_id} exceeded max slots ({self.max_slots_per_thread})" + ) + + return thread.append( + payload=payload, + from_id=from_id, + to_id=to_id, + own_name=own_name, + is_self_call=is_self_call, + usage_instructions=usage_instructions, + todo_nudge=todo_nudge, + ) + + def get_thread(self, thread_id: str) -> Optional[ThreadContext]: + """Get a thread's context (None if not found).""" + with self._lock: + return self._threads.get(thread_id) + + def thread_exists(self, thread_id: str) -> bool: + """Check if a thread exists.""" + with self._lock: + return thread_id in self._threads + + def delete_thread(self, thread_id: str) -> bool: + """Delete a thread's context (GC).""" + with self._lock: + if thread_id in self._threads: + del self._threads[thread_id] + return True + return False + + def get_stats(self) -> Dict[str, Any]: + """Get buffer statistics.""" + with self._lock: + total_slots = sum(len(t) for t in self._threads.values()) + return { + "thread_count": len(self._threads), + "total_slots": total_slots, + "max_threads": self.max_threads, + "max_slots_per_thread": self.max_slots_per_thread, + } + + def clear(self): + """Clear all contexts (for testing).""" + with self._lock: + self._threads.clear() + + +# ============================================================================ +# Singleton +# ============================================================================ + +_buffer: Optional[ContextBuffer] = None +_buffer_lock = threading.Lock() + + +def get_context_buffer() -> ContextBuffer: + """Get the global ContextBuffer singleton.""" + global _buffer + if _buffer is None: + with _buffer_lock: + if _buffer is None: + _buffer = ContextBuffer() + return _buffer + + +# ============================================================================ +# Handler-facing metadata adapter +# ============================================================================ + +def slot_to_handler_metadata(slot: BufferSlot) -> 'HandlerMetadata': + """ + Convert SlotMetadata to HandlerMetadata for backward compatibility. + + Handlers still receive HandlerMetadata, but it's derived from the slot. + """ + from agentserver.message_bus.message_state import HandlerMetadata + + return HandlerMetadata( + thread_id=slot.metadata.thread_id, + from_id=slot.metadata.from_id, + own_name=slot.metadata.own_name, + is_self_call=slot.metadata.is_self_call, + usage_instructions=slot.metadata.usage_instructions, + todo_nudge=slot.metadata.todo_nudge, + ) diff --git a/agentserver/message_bus/stream_pump.py b/agentserver/message_bus/stream_pump.py index f076af4..583c379 100644 --- a/agentserver/message_bus/stream_pump.py +++ b/agentserver/message_bus/stream_pump.py @@ -32,6 +32,7 @@ from agentserver.message_bus.steps.thread_assignment import thread_assignment_st from agentserver.message_bus.message_state import MessageState, HandlerMetadata, HandlerResponse, SystemError, ROUTING_ERROR from agentserver.message_bus.thread_registry import get_registry from agentserver.message_bus.todo_registry import get_todo_registry +from agentserver.memory import get_context_buffer # ============================================================================ @@ -348,6 +349,7 @@ class StreamPump: # Ensure we have a valid thread chain registry = get_registry() todo_registry = get_todo_registry() + context_buffer = get_context_buffer() current_thread = state.thread_id or "" # Check if thread exists in registry; if not, register it @@ -377,6 +379,27 @@ class StreamPump: raised = todo_registry.get_raised_for(current_thread, listener.name) todo_nudge = todo_registry.format_nudge(raised) + # === CONTEXT BUFFER: Record incoming message === + # Append validated payload to thread's context buffer + if current_thread and state.payload: + try: + context_buffer.append( + thread_id=current_thread, + payload=state.payload, + from_id=state.from_id or "unknown", + to_id=listener.name, + own_name=listener.name if listener.is_agent else None, + is_self_call=is_self_call, + usage_instructions=listener.usage_instructions, + todo_nudge=todo_nudge, + ) + except MemoryError: + # Thread exceeded max slots - log and continue + import logging + logging.getLogger(__name__).warning( + f"Thread {current_thread[:8]}... exceeded context buffer limit" + ) + metadata = HandlerMetadata( thread_id=current_thread, from_id=state.from_id or "", @@ -435,6 +458,22 @@ class StreamPump: to_id = requested_to thread_id = registry.extend_chain(current_thread, to_id) + # === CONTEXT BUFFER: Record outgoing response === + # Append handler's response to the target thread's buffer + # This happens BEFORE serialization - the buffer holds the clean payload + try: + context_buffer.append( + thread_id=thread_id, + payload=response.payload, + from_id=listener.name, + to_id=to_id, + ) + except MemoryError: + import logging + logging.getLogger(__name__).warning( + f"Thread {thread_id[:8]}... exceeded context buffer limit" + ) + response_bytes = self._wrap_in_envelope( payload=response.payload, from_id=listener.name, diff --git a/tests/test_context_buffer.py b/tests/test_context_buffer.py new file mode 100644 index 0000000..9ca6704 --- /dev/null +++ b/tests/test_context_buffer.py @@ -0,0 +1,419 @@ +""" +test_context_buffer.py — Tests for the AI agent virtual memory manager. + +Tests: +1. Append-only semantics +2. Immutability guarantees +3. Thread isolation +4. Slot references +5. GC and limits +""" + +import pytest +import uuid +from dataclasses import dataclass, FrozenInstanceError + +from agentserver.memory.context_buffer import ( + ContextBuffer, + ThreadContext, + BufferSlot, + SlotMetadata, + get_context_buffer, + slot_to_handler_metadata, +) + + +# Test payload classes +@dataclass +class TestPayload: + message: str + value: int = 0 + + +@dataclass(frozen=True) +class FrozenPayload: + message: str + + +class TestBufferSlotImmutability: + """Test that buffer slots cannot be modified.""" + + def test_slot_is_frozen(self): + """BufferSlot should be frozen (immutable).""" + metadata = SlotMetadata( + thread_id="t1", + from_id="sender", + to_id="receiver", + slot_index=0, + timestamp="2024-01-01T00:00:00Z", + payload_type="TestPayload", + ) + slot = BufferSlot(payload=TestPayload(message="hello"), metadata=metadata) + + # Cannot modify slot attributes + with pytest.raises(FrozenInstanceError): + slot.metadata = None + + def test_slot_metadata_is_frozen(self): + """SlotMetadata should be frozen (immutable).""" + metadata = SlotMetadata( + thread_id="t1", + from_id="sender", + to_id="receiver", + slot_index=0, + timestamp="2024-01-01T00:00:00Z", + payload_type="TestPayload", + ) + + with pytest.raises(FrozenInstanceError): + metadata.thread_id = "modified" + + def test_payload_reference_preserved(self): + """Slot should preserve reference to original payload.""" + payload = TestPayload(message="original") + metadata = SlotMetadata( + thread_id="t1", + from_id="sender", + to_id="receiver", + slot_index=0, + timestamp="2024-01-01T00:00:00Z", + payload_type="TestPayload", + ) + slot = BufferSlot(payload=payload, metadata=metadata) + + # Same reference + assert slot.payload is payload + + +class TestThreadContext: + """Test single-thread context buffer.""" + + def test_append_creates_slot(self): + """Appending returns a BufferSlot.""" + ctx = ThreadContext("thread-1") + + slot = ctx.append( + payload=TestPayload(message="test"), + from_id="sender", + to_id="receiver", + ) + + assert isinstance(slot, BufferSlot) + assert slot.payload.message == "test" + assert slot.from_id == "sender" + assert slot.to_id == "receiver" + assert slot.index == 0 + + def test_append_increments_index(self): + """Each append gets a new index.""" + ctx = ThreadContext("thread-1") + + slot1 = ctx.append(TestPayload("a"), "s", "r") + slot2 = ctx.append(TestPayload("b"), "s", "r") + slot3 = ctx.append(TestPayload("c"), "s", "r") + + assert slot1.index == 0 + assert slot2.index == 1 + assert slot3.index == 2 + assert len(ctx) == 3 + + def test_getitem_returns_slot(self): + """Can access slots by index.""" + ctx = ThreadContext("thread-1") + + ctx.append(TestPayload("first"), "s", "r") + ctx.append(TestPayload("second"), "s", "r") + + assert ctx[0].payload.message == "first" + assert ctx[1].payload.message == "second" + + def test_iteration(self): + """Can iterate over all slots.""" + ctx = ThreadContext("thread-1") + + ctx.append(TestPayload("a"), "s", "r") + ctx.append(TestPayload("b"), "s", "r") + ctx.append(TestPayload("c"), "s", "r") + + messages = [slot.payload.message for slot in ctx] + assert messages == ["a", "b", "c"] + + def test_get_by_type(self): + """Can filter slots by payload type.""" + ctx = ThreadContext("thread-1") + + ctx.append(TestPayload("test"), "s", "r") + ctx.append(FrozenPayload("frozen"), "s", "r") + ctx.append(TestPayload("test2"), "s", "r") + + test_slots = ctx.get_by_type("TestPayload") + assert len(test_slots) == 2 + + frozen_slots = ctx.get_by_type("FrozenPayload") + assert len(frozen_slots) == 1 + + def test_get_from(self): + """Can filter slots by sender.""" + ctx = ThreadContext("thread-1") + + ctx.append(TestPayload("a"), "alice", "r") + ctx.append(TestPayload("b"), "bob", "r") + ctx.append(TestPayload("c"), "alice", "r") + + alice_slots = ctx.get_from("alice") + assert len(alice_slots) == 2 + + bob_slots = ctx.get_from("bob") + assert len(bob_slots) == 1 + + +class TestContextBuffer: + """Test global context buffer.""" + + def test_get_or_create_thread(self): + """get_or_create_thread creates new thread if needed.""" + buffer = ContextBuffer() + + ctx1 = buffer.get_or_create_thread("thread-1") + ctx2 = buffer.get_or_create_thread("thread-1") + ctx3 = buffer.get_or_create_thread("thread-2") + + assert ctx1 is ctx2 # Same thread + assert ctx1 is not ctx3 # Different thread + + def test_append_to_thread(self): + """append() adds to correct thread.""" + buffer = ContextBuffer() + + slot1 = buffer.append("t1", TestPayload("a"), "s", "r") + slot2 = buffer.append("t2", TestPayload("b"), "s", "r") + slot3 = buffer.append("t1", TestPayload("c"), "s", "r") + + t1 = buffer.get_thread("t1") + t2 = buffer.get_thread("t2") + + assert len(t1) == 2 + assert len(t2) == 1 + assert t1[0].payload.message == "a" + assert t1[1].payload.message == "c" + assert t2[0].payload.message == "b" + + def test_thread_isolation(self): + """Threads cannot see each other's slots.""" + buffer = ContextBuffer() + + buffer.append("thread-a", TestPayload("secret-a"), "s", "r") + buffer.append("thread-b", TestPayload("secret-b"), "s", "r") + + thread_a = buffer.get_thread("thread-a") + thread_b = buffer.get_thread("thread-b") + + # Each thread only sees its own messages + a_messages = [s.payload.message for s in thread_a] + b_messages = [s.payload.message for s in thread_b] + + assert a_messages == ["secret-a"] + assert b_messages == ["secret-b"] + + def test_delete_thread(self): + """delete_thread removes thread context.""" + buffer = ContextBuffer() + + buffer.append("t1", TestPayload("test"), "s", "r") + assert buffer.thread_exists("t1") + + result = buffer.delete_thread("t1") + assert result is True + assert not buffer.thread_exists("t1") + + def test_max_slots_limit(self): + """Exceeding max_slots_per_thread raises MemoryError.""" + buffer = ContextBuffer() + buffer.max_slots_per_thread = 3 + + buffer.append("t1", TestPayload("1"), "s", "r") + buffer.append("t1", TestPayload("2"), "s", "r") + buffer.append("t1", TestPayload("3"), "s", "r") + + with pytest.raises(MemoryError) as exc_info: + buffer.append("t1", TestPayload("4"), "s", "r") + + assert "exceeded max slots" in str(exc_info.value) + + def test_max_threads_gc(self): + """Exceeding max_threads triggers GC of oldest thread.""" + buffer = ContextBuffer() + buffer.max_threads = 2 + + buffer.append("t1", TestPayload("first"), "s", "r") + buffer.append("t2", TestPayload("second"), "s", "r") + + # Adding third thread should GC the oldest + buffer.append("t3", TestPayload("third"), "s", "r") + + stats = buffer.get_stats() + assert stats["thread_count"] == 2 + + # t1 should be gone (oldest) + assert not buffer.thread_exists("t1") + assert buffer.thread_exists("t2") + assert buffer.thread_exists("t3") + + def test_get_stats(self): + """get_stats returns buffer statistics.""" + buffer = ContextBuffer() + + buffer.append("t1", TestPayload("a"), "s", "r") + buffer.append("t1", TestPayload("b"), "s", "r") + buffer.append("t2", TestPayload("c"), "s", "r") + + stats = buffer.get_stats() + + assert stats["thread_count"] == 2 + assert stats["total_slots"] == 3 + + +class TestHandlerMetadataAdapter: + """Test conversion from SlotMetadata to HandlerMetadata.""" + + def test_slot_to_handler_metadata(self): + """slot_to_handler_metadata converts correctly.""" + buffer = ContextBuffer() + + slot = buffer.append( + thread_id="t1", + payload=TestPayload("test"), + from_id="sender", + to_id="receiver", + own_name="test-agent", + is_self_call=True, + usage_instructions="instructions here", + todo_nudge="nudge here", + ) + + metadata = slot_to_handler_metadata(slot) + + assert metadata.thread_id == "t1" + assert metadata.from_id == "sender" + assert metadata.own_name == "test-agent" + assert metadata.is_self_call is True + assert metadata.usage_instructions == "instructions here" + assert metadata.todo_nudge == "nudge here" + + +class TestSingleton: + """Test singleton behavior.""" + + def test_get_context_buffer_singleton(self): + """get_context_buffer returns same instance.""" + buf1 = get_context_buffer() + buf2 = get_context_buffer() + + assert buf1 is buf2 + + def test_clear_resets_state(self): + """clear() removes all threads.""" + buffer = get_context_buffer() + buffer.append("test-thread", TestPayload("test"), "s", "r") + + buffer.clear() + + assert not buffer.thread_exists("test-thread") + assert buffer.get_stats()["thread_count"] == 0 + + +class TestPumpIntegration: + """Test context buffer integration with StreamPump.""" + + @pytest.mark.asyncio + async def test_buffer_records_messages_during_flow(self): + """Context buffer should record messages as they flow through pump.""" + from unittest.mock import AsyncMock, patch + from agentserver.message_bus.stream_pump import StreamPump, ListenerConfig, OrganismConfig + from agentserver.message_bus.message_state import HandlerResponse + from agentserver.llm.backend import LLMResponse + + # Import handlers + from handlers.hello import Greeting, GreetingResponse, handle_greeting, handle_shout + from handlers.console import ShoutedResponse + + # Clear buffer + buffer = get_context_buffer() + buffer.clear() + + # Create pump with greeter and shouter + config = OrganismConfig(name="buffer-test") + pump = StreamPump(config) + + pump.register_listener(ListenerConfig( + name="greeter", + payload_class_path="handlers.hello.Greeting", + handler_path="handlers.hello.handle_greeting", + description="Greeting agent", + is_agent=True, + peers=["shouter"], + payload_class=Greeting, + handler=handle_greeting, + )) + + pump.register_listener(ListenerConfig( + name="shouter", + payload_class_path="handlers.hello.GreetingResponse", + handler_path="handlers.hello.handle_shout", + description="Shouts", + payload_class=GreetingResponse, + handler=handle_shout, + )) + + # Mock LLM + mock_llm = LLMResponse( + content="Hello there!", + model="mock", + usage={"total_tokens": 5}, + finish_reason="stop", + ) + + # Prevent re-injection loops + async def noop_reinject(state): + pass + pump._reinject_responses = noop_reinject + + with patch('agentserver.llm.complete', new=AsyncMock(return_value=mock_llm)): + # Create envelope for Greeting + thread_id = str(uuid.uuid4()) + envelope = f""" + usergreeter{thread_id} + Alice + """.encode() + + await pump.inject(envelope, thread_id, from_id="user") + + # Run pump to process + pump._running = True + pipeline = pump.build_pipeline(pump._queue_source()) + + async def run_chain(): + async with pipeline.stream() as streamer: + count = 0 + async for _ in streamer: + count += 1 + if count >= 3: + break + + import asyncio + try: + await asyncio.wait_for(run_chain(), timeout=3.0) + except asyncio.TimeoutError: + pass + finally: + pump._running = False + + # Verify buffer recorded the messages + thread_ctx = buffer.get_thread(thread_id) + assert thread_ctx is not None, "Thread should exist in buffer" + assert len(thread_ctx) >= 1, "Buffer should have at least one message" + + # Check that we recorded a Greeting + greeting_slots = thread_ctx.get_by_type("Greeting") + assert len(greeting_slots) >= 1, "Should have recorded Greeting" + assert greeting_slots[0].payload.name == "Alice"