diff --git a/tests/test_token_budget.py b/tests/test_token_budget.py new file mode 100644 index 0000000..1b2b712 --- /dev/null +++ b/tests/test_token_budget.py @@ -0,0 +1,573 @@ +""" +test_token_budget.py — Tests for token budget and usage tracking. + +Tests: +1. ThreadBudgetRegistry - per-thread token limits +2. UsageTracker - billing/gas usage events +3. LLMRouter integration - budget enforcement +""" + +import pytest +from unittest.mock import Mock, AsyncMock, patch + +from xml_pipeline.message_bus.budget_registry import ( + ThreadBudget, + ThreadBudgetRegistry, + BudgetExhaustedError, + get_budget_registry, + configure_budget_registry, + reset_budget_registry, +) +from xml_pipeline.llm.usage_tracker import ( + UsageEvent, + UsageTracker, + UsageTotals, + estimate_cost, + get_usage_tracker, + reset_usage_tracker, +) + + +# ============================================================================ +# ThreadBudget Tests +# ============================================================================ + +class TestThreadBudget: + """Test ThreadBudget dataclass.""" + + def test_initial_state(self): + """New budget should have zero usage.""" + budget = ThreadBudget(max_tokens=10000) + assert budget.total_tokens == 0 + assert budget.remaining == 10000 + assert budget.is_exhausted is False + + def test_consume_tokens(self): + """Consuming tokens should update totals.""" + budget = ThreadBudget(max_tokens=10000) + budget.consume(prompt_tokens=500, completion_tokens=300) + + assert budget.prompt_tokens == 500 + assert budget.completion_tokens == 300 + assert budget.total_tokens == 800 + assert budget.remaining == 9200 + assert budget.request_count == 1 + + def test_can_consume_within_budget(self): + """can_consume should return True if within budget.""" + budget = ThreadBudget(max_tokens=1000) + budget.consume(prompt_tokens=400) + + assert budget.can_consume(500) is True + assert budget.can_consume(600) is True + assert budget.can_consume(601) is False + + def 
test_is_exhausted(self): + """is_exhausted should return True when budget exceeded.""" + budget = ThreadBudget(max_tokens=1000) + budget.consume(prompt_tokens=1000) + + assert budget.is_exhausted is True + assert budget.remaining == 0 + + def test_remaining_never_negative(self): + """remaining should never go negative.""" + budget = ThreadBudget(max_tokens=100) + budget.consume(prompt_tokens=200) + + assert budget.remaining == 0 + assert budget.total_tokens == 200 + + +# ============================================================================ +# ThreadBudgetRegistry Tests +# ============================================================================ + +class TestThreadBudgetRegistry: + """Test ThreadBudgetRegistry.""" + + @pytest.fixture(autouse=True) + def reset(self): + """Reset global registry before each test.""" + reset_budget_registry() + yield + reset_budget_registry() + + def test_default_budget_creation(self): + """Getting budget for new thread should create one.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=50000) + budget = registry.get_budget("thread-1") + + assert budget.max_tokens == 50000 + assert budget.total_tokens == 0 + + def test_configure_max_tokens(self): + """configure() should update default for new threads.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + budget1 = registry.get_budget("thread-1") + + registry.configure(max_tokens_per_thread=20000) + budget2 = registry.get_budget("thread-2") + + assert budget1.max_tokens == 10000 # Original unchanged + assert budget2.max_tokens == 20000 # New default + + def test_check_budget_success(self): + """check_budget should pass when within budget.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + + result = registry.check_budget("thread-1", estimated_tokens=5000) + assert result is True + + def test_check_budget_exhausted(self): + """check_budget should raise when budget exhausted.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) 
+ registry.consume("thread-1", prompt_tokens=1000) + + with pytest.raises(BudgetExhaustedError) as exc_info: + registry.check_budget("thread-1", estimated_tokens=100) + + assert "budget exhausted" in str(exc_info.value) + assert exc_info.value.thread_id == "thread-1" + assert exc_info.value.used == 1000 + assert exc_info.value.max_tokens == 1000 + + def test_check_budget_would_exceed(self): + """check_budget should raise when estimate would exceed.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + registry.consume("thread-1", prompt_tokens=600) + + with pytest.raises(BudgetExhaustedError): + registry.check_budget("thread-1", estimated_tokens=500) + + def test_consume_returns_budget(self): + """consume() should return updated budget.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + + budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50) + + assert budget.total_tokens == 150 + assert budget.request_count == 1 + + def test_get_usage(self): + """get_usage should return dict with all stats.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) + registry.consume("thread-1", prompt_tokens=300, completion_tokens=100) + + usage = registry.get_usage("thread-1") + + assert usage["prompt_tokens"] == 800 + assert usage["completion_tokens"] == 300 + assert usage["total_tokens"] == 1100 + assert usage["remaining"] == 8900 + assert usage["request_count"] == 2 + + def test_get_all_usage(self): + """get_all_usage should return all threads.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + registry.consume("thread-1", prompt_tokens=100) + registry.consume("thread-2", prompt_tokens=200) + + all_usage = registry.get_all_usage() + + assert len(all_usage) == 2 + assert "thread-1" in all_usage + assert "thread-2" in all_usage + + def test_reset_thread(self): + """reset_thread should remove budget for thread.""" + registry = 
ThreadBudgetRegistry(max_tokens_per_thread=10000) + registry.consume("thread-1", prompt_tokens=500) + registry.reset_thread("thread-1") + + # Getting budget should create new one with zero usage + budget = registry.get_budget("thread-1") + assert budget.total_tokens == 0 + + def test_cleanup_thread(self): + """cleanup_thread should return and remove budget.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + registry.consume("thread-1", prompt_tokens=500) + + final_budget = registry.cleanup_thread("thread-1") + + assert final_budget.total_tokens == 500 + assert registry.cleanup_thread("thread-1") is None # Already cleaned + + def test_global_registry(self): + """Global registry should be singleton.""" + registry1 = get_budget_registry() + registry2 = get_budget_registry() + + assert registry1 is registry2 + + def test_global_configure(self): + """configure_budget_registry should update global.""" + configure_budget_registry(max_tokens_per_thread=75000) + registry = get_budget_registry() + + budget = registry.get_budget("new-thread") + assert budget.max_tokens == 75000 + + +# ============================================================================ +# UsageTracker Tests +# ============================================================================ + +class TestUsageTracker: + """Test UsageTracker for billing/metering.""" + + @pytest.fixture(autouse=True) + def reset(self): + """Reset global tracker before each test.""" + reset_usage_tracker() + yield + reset_usage_tracker() + + def test_record_creates_event(self): + """record() should create and return UsageEvent.""" + tracker = UsageTracker() + + event = tracker.record( + thread_id="thread-1", + agent_id="greeter", + model="grok-4.1", + provider="xai", + prompt_tokens=500, + completion_tokens=200, + latency_ms=150.5, + ) + + assert event.thread_id == "thread-1" + assert event.agent_id == "greeter" + assert event.model == "grok-4.1" + assert event.total_tokens == 700 + assert event.timestamp is 
not None + + def test_record_estimates_cost(self): + """record() should estimate cost for known models.""" + tracker = UsageTracker() + + event = tracker.record( + thread_id="thread-1", + agent_id="agent", + model="grok-4.1", + provider="xai", + prompt_tokens=1_000_000, # 1M prompt + completion_tokens=1_000_000, # 1M completion + latency_ms=1000, + ) + + # grok-4.1: $3/1M prompt + $15/1M completion = $18 + assert event.estimated_cost == 18.0 + + def test_subscriber_receives_events(self): + """Subscribers should receive events on record.""" + tracker = UsageTracker() + received = [] + + tracker.subscribe(lambda e: received.append(e)) + + tracker.record( + thread_id="t1", + agent_id="agent", + model="gpt-4o", + provider="openai", + prompt_tokens=100, + completion_tokens=50, + latency_ms=50, + ) + + assert len(received) == 1 + assert received[0].thread_id == "t1" + + def test_unsubscribe(self): + """unsubscribe should stop receiving events.""" + tracker = UsageTracker() + received = [] + callback = lambda e: received.append(e) + + tracker.subscribe(callback) + tracker.record(thread_id="t1", agent_id=None, model="m", provider="p", + prompt_tokens=10, completion_tokens=10, latency_ms=10) + + tracker.unsubscribe(callback) + tracker.record(thread_id="t2", agent_id=None, model="m", provider="p", + prompt_tokens=10, completion_tokens=10, latency_ms=10) + + assert len(received) == 1 + + def test_get_totals(self): + """get_totals should return aggregate stats.""" + tracker = UsageTracker() + + tracker.record(thread_id="t1", agent_id="a1", model="m1", provider="p", + prompt_tokens=100, completion_tokens=50, latency_ms=100) + tracker.record(thread_id="t2", agent_id="a2", model="m2", provider="p", + prompt_tokens=200, completion_tokens=100, latency_ms=200) + + totals = tracker.get_totals() + + assert totals["prompt_tokens"] == 300 + assert totals["completion_tokens"] == 150 + assert totals["total_tokens"] == 450 + assert totals["request_count"] == 2 + assert 
totals["avg_latency_ms"] == 150.0 + + def test_get_agent_totals(self): + """get_agent_totals should return per-agent stats.""" + tracker = UsageTracker() + + tracker.record(thread_id="t1", agent_id="greeter", model="m", provider="p", + prompt_tokens=100, completion_tokens=50, latency_ms=100) + tracker.record(thread_id="t2", agent_id="greeter", model="m", provider="p", + prompt_tokens=100, completion_tokens=50, latency_ms=100) + tracker.record(thread_id="t3", agent_id="shouter", model="m", provider="p", + prompt_tokens=200, completion_tokens=100, latency_ms=200) + + greeter = tracker.get_agent_totals("greeter") + shouter = tracker.get_agent_totals("shouter") + + assert greeter["total_tokens"] == 300 + assert greeter["request_count"] == 2 + assert shouter["total_tokens"] == 300 + assert shouter["request_count"] == 1 + + def test_get_model_totals(self): + """get_model_totals should return per-model stats.""" + tracker = UsageTracker() + + tracker.record(thread_id="t1", agent_id="a", model="grok-4.1", provider="xai", + prompt_tokens=1000, completion_tokens=500, latency_ms=100) + tracker.record(thread_id="t2", agent_id="a", model="claude-sonnet-4", provider="anthropic", + prompt_tokens=500, completion_tokens=250, latency_ms=100) + + grok = tracker.get_model_totals("grok-4.1") + claude = tracker.get_model_totals("claude-sonnet-4") + + assert grok["total_tokens"] == 1500 + assert claude["total_tokens"] == 750 + + def test_metadata_passed_through(self): + """Metadata should be included in events.""" + tracker = UsageTracker() + received = [] + tracker.subscribe(lambda e: received.append(e)) + + tracker.record( + thread_id="t1", + agent_id="a", + model="m", + provider="p", + prompt_tokens=10, + completion_tokens=10, + latency_ms=10, + metadata={"org_id": "org-123", "user_id": "user-456"}, + ) + + assert received[0].metadata["org_id"] == "org-123" + assert received[0].metadata["user_id"] == "user-456" + + +# 
============================================================================ +# Cost Estimation Tests +# ============================================================================ + +class TestCostEstimation: + """Test cost estimation for various models.""" + + def test_grok_cost(self): + """Grok models should use correct pricing.""" + cost = estimate_cost("grok-4.1", prompt_tokens=1_000_000, completion_tokens=1_000_000) + # $3/1M prompt + $15/1M completion = $18 + assert cost == 18.0 + + def test_claude_opus_cost(self): + """Claude Opus should use correct pricing.""" + cost = estimate_cost("claude-opus-4", prompt_tokens=1_000_000, completion_tokens=1_000_000) + # $15/1M prompt + $75/1M completion = $90 + assert cost == 90.0 + + def test_gpt4o_cost(self): + """GPT-4o should use correct pricing.""" + cost = estimate_cost("gpt-4o", prompt_tokens=1_000_000, completion_tokens=1_000_000) + # $2.5/1M prompt + $10/1M completion = $12.5 + assert cost == 12.5 + + def test_unknown_model_returns_none(self): + """Unknown model should return None.""" + cost = estimate_cost("unknown-model", prompt_tokens=1000, completion_tokens=500) + assert cost is None + + def test_small_usage_cost(self): + """Small token counts should produce fractional costs.""" + cost = estimate_cost("gpt-4o-mini", prompt_tokens=1000, completion_tokens=500) + # 1000 tokens * $0.15/1M = $0.00015 + # 500 tokens * $0.6/1M = $0.0003 + # Total = $0.00045 + assert cost == pytest.approx(0.00045, rel=1e-4) + + +# ============================================================================ +# LLMRouter Integration Tests (Mocked) +# ============================================================================ + +class TestLLMRouterBudgetIntegration: + """Test LLMRouter budget enforcement.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + """Reset all global registries.""" + reset_budget_registry() + reset_usage_tracker() + yield + reset_budget_registry() + reset_usage_tracker() + + @pytest.mark.asyncio + 
async def test_complete_consumes_budget(self): + """LLM complete should consume from thread budget.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + # Create mock backend + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Hello!", + model="test-model", + usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}, + finish_reason="stop", + )) + + # Configure budget + configure_budget_registry(max_tokens_per_thread=10000) + budget_registry = get_budget_registry() + + # Create router with mock backend + router = LLMRouter() + router.backends.append(mock_backend) + + # Make request + response = await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread-123", + ) + + assert response.content == "Hello!" 
+ + # Verify budget consumed + usage = budget_registry.get_usage("test-thread-123") + assert usage["prompt_tokens"] == 100 + assert usage["completion_tokens"] == 50 + assert usage["total_tokens"] == 150 + + @pytest.mark.asyncio + async def test_complete_emits_usage_event(self): + """LLM complete should emit usage event.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Hello!", + model="test-model", + usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}, + finish_reason="stop", + )) + + # Subscribe to usage events + tracker = get_usage_tracker() + received_events = [] + tracker.subscribe(lambda e: received_events.append(e)) + + # Create router and make request + router = LLMRouter() + router.backends.append(mock_backend) + + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + agent_id="greeter", + metadata={"org_id": "test-org"}, + ) + + # Verify event emitted + assert len(received_events) == 1 + event = received_events[0] + assert event.thread_id == "test-thread" + assert event.agent_id == "greeter" + assert event.total_tokens == 150 + assert event.metadata["org_id"] == "test-org" + + @pytest.mark.asyncio + async def test_complete_raises_when_budget_exhausted(self): + """LLM complete should raise when budget exhausted.""" + from xml_pipeline.llm.router import LLMRouter + + # Configure small budget and exhaust it + configure_budget_registry(max_tokens_per_thread=100) + budget_registry = get_budget_registry() + budget_registry.consume("test-thread", prompt_tokens=100) + + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.serves_model = 
Mock(return_value=True) + mock_backend.priority = 1 + + router = LLMRouter() + router.backends.append(mock_backend) + + with pytest.raises(BudgetExhaustedError) as exc_info: + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + ) + + assert "budget exhausted" in str(exc_info.value) + # Backend should NOT have been called + mock_backend.complete.assert_not_called() + + @pytest.mark.asyncio + async def test_complete_without_thread_id_skips_budget(self): + """LLM complete without thread_id should skip budget check.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Hello!", + model="test-model", + usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + # Should not raise - no budget checking + response = await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + # No thread_id + ) + + assert response.content == "Hello!" 
diff --git a/xml_pipeline/llm/__init__.py b/xml_pipeline/llm/__init__.py index c4429d3..01aeb02 100644 --- a/xml_pipeline/llm/__init__.py +++ b/xml_pipeline/llm/__init__.py @@ -16,7 +16,20 @@ Usage: response = await router.complete( model="grok-4.1", messages=[{"role": "user", "content": "Hello"}], + thread_id=metadata.thread_id, # For budget enforcement + agent_id=metadata.own_name, # For usage tracking ) + +Usage Tracking: + from xml_pipeline.llm import get_usage_tracker + + tracker = get_usage_tracker() + + # Subscribe to events for billing + tracker.subscribe(lambda event: billing_api.record(event)) + + # Query totals + totals = tracker.get_totals() """ from xml_pipeline.llm.router import ( @@ -27,14 +40,27 @@ from xml_pipeline.llm.router import ( Strategy, ) from xml_pipeline.llm.backend import LLMRequest, LLMResponse, BackendError +from xml_pipeline.llm.usage_tracker import ( + UsageTracker, + UsageEvent, + get_usage_tracker, + reset_usage_tracker, +) __all__ = [ + # Router "LLMRouter", "get_router", "configure_router", "complete", "Strategy", + # Backend "LLMRequest", "LLMResponse", "BackendError", + # Usage tracking + "UsageTracker", + "UsageEvent", + "get_usage_tracker", + "reset_usage_tracker", ] diff --git a/xml_pipeline/llm/router.py b/xml_pipeline/llm/router.py index b58d474..6c25048 100644 --- a/xml_pipeline/llm/router.py +++ b/xml_pipeline/llm/router.py @@ -9,6 +9,8 @@ The router handles: - Load balancing (failover, round-robin, least-loaded) - Retries with exponential backoff - Token tracking per agent +- Thread budget enforcement +- Usage event emission for billing """ from __future__ import annotations @@ -16,6 +18,7 @@ from __future__ import annotations import asyncio import logging import random +import time from dataclasses import dataclass, field from enum import Enum from typing import List, Dict, Any, Optional @@ -125,6 +128,8 @@ class LLMRouter: max_tokens: int = None, tools: List[Dict] = None, agent_id: str = None, + thread_id: str = None, 
+ metadata: Dict[str, Any] = None, ) -> LLMResponse: """ Execute a completion request. @@ -136,10 +141,27 @@ class LLMRouter: max_tokens: Max tokens in response tools: Tool definitions for function calling agent_id: Optional agent ID for usage tracking + thread_id: Optional thread ID for budget enforcement + metadata: Optional metadata for usage events (org_id, user_id, etc.) Returns: LLMResponse with content and usage stats + + Raises: + BudgetExhaustedError: If thread has no remaining budget + BackendError: If all backends fail """ + # Estimate tokens for budget check (rough: 4 chars per token) + estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4 + estimated_tokens = max(estimated_tokens, 100) # minimum estimate + + # Check thread budget before proceeding + if thread_id: + from xml_pipeline.message_bus.budget_registry import get_budget_registry + budget_registry = get_budget_registry() + # This raises BudgetExhaustedError if over budget + budget_registry.check_budget(thread_id, estimated_tokens) + candidates = self._find_backends(model) request = LLMRequest( model=model, @@ -151,6 +173,7 @@ class LLMRouter: last_error = None tried_backends = set() + start_time = time.monotonic() for attempt in range(self.retries + 1): # Select backend (different selection on retry for failover) @@ -170,14 +193,46 @@ class LLMRouter: logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})") response = await backend.complete(request) - # Track usage + # Calculate latency + latency_ms = (time.monotonic() - start_time) * 1000 + + # Extract usage + prompt_tokens = response.usage.get("prompt_tokens", 0) + completion_tokens = response.usage.get("completion_tokens", 0) + total_tokens = response.usage.get("total_tokens", 0) + + # Track per-agent usage (internal) if agent_id: usage = self._agent_usage.setdefault(agent_id, AgentUsage()) - usage.total_tokens += response.usage.get("total_tokens", 0) - usage.prompt_tokens += 
response.usage.get("prompt_tokens", 0) - usage.completion_tokens += response.usage.get("completion_tokens", 0) + usage.total_tokens += total_tokens + usage.prompt_tokens += prompt_tokens + usage.completion_tokens += completion_tokens usage.request_count += 1 + # Record to thread budget (enforcement) + if thread_id: + from xml_pipeline.message_bus.budget_registry import get_budget_registry + budget_registry = get_budget_registry() + budget_registry.consume( + thread_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + # Emit usage event (for billing) + from xml_pipeline.llm.usage_tracker import get_usage_tracker + tracker = get_usage_tracker() + tracker.record( + thread_id=thread_id or "", + agent_id=agent_id, + model=response.model, + provider=backend.provider, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + latency_ms=latency_ms, + metadata=metadata, + ) + return response except RateLimitError as e: @@ -286,6 +341,10 @@ def configure_router(config: Dict[str, Any]) -> LLMRouter: async def complete( model: str, messages: List[Dict[str, str]], + *, + thread_id: str = None, + agent_id: str = None, + metadata: Dict[str, Any] = None, **kwargs, ) -> LLMResponse: """ @@ -293,6 +352,32 @@ async def complete( Usage: from xml_pipeline.llm import router - response = await router.complete("grok-4.1", messages) + response = await router.complete( + "grok-4.1", + messages, + thread_id=metadata.thread_id, + agent_id=metadata.own_name, + ) + + Args: + model: Model name + messages: Chat messages + thread_id: Thread UUID for budget enforcement + agent_id: Agent name for usage tracking + metadata: Extra metadata for billing events + **kwargs: Additional arguments (temperature, max_tokens, tools) + + Returns: + LLMResponse with content and usage stats + + Raises: + BudgetExhaustedError: If thread budget exhausted """ - return await get_router().complete(model, messages, **kwargs) + return await get_router().complete( + model, + 
messages, + thread_id=thread_id, + agent_id=agent_id, + metadata=metadata, + **kwargs, + ) diff --git a/xml_pipeline/llm/usage_tracker.py b/xml_pipeline/llm/usage_tracker.py new file mode 100644 index 0000000..3360cef --- /dev/null +++ b/xml_pipeline/llm/usage_tracker.py @@ -0,0 +1,346 @@ +""" +Usage Tracker — Production billing and gas usage metering. + +This module provides hooks for tracking LLM usage at the platform level. +External billing systems can subscribe to usage events for metering. + +Usage Tracking Layers: +1. Per-agent (LLMRouter._agent_usage) — Internal token tracking +2. Per-thread (ThreadBudgetRegistry) — Enforcement limits +3. Platform (UsageTracker) — Production billing/metering + +Example: + from xml_pipeline.llm.usage_tracker import get_usage_tracker + + tracker = get_usage_tracker() + + # Subscribe to usage events (for billing webhook, database, etc.) + def record_usage(event: UsageEvent): + billing_db.record( + org_id=event.metadata.get("org_id"), + tokens=event.total_tokens, + cost=event.estimated_cost, + ) + + tracker.subscribe(record_usage) + + # Query aggregate usage + totals = tracker.get_totals() + print(f"Total tokens: {totals['total_tokens']}") +""" + +from __future__ import annotations + +import threading +import time +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Callable, Dict, List, Optional + + +@dataclass +class UsageEvent: + """ + Usage event emitted after each LLM completion. + + This is the main interface for billing systems. + """ + + # Request identification + thread_id: str + agent_id: Optional[str] + model: str + provider: str + + # Token usage + prompt_tokens: int + completion_tokens: int + total_tokens: int + + # Timing + timestamp: str # ISO 8601 + latency_ms: float # Request duration + + # Cost estimation (if available) + estimated_cost: Optional[float] = None + + # Extensible metadata (org_id, user_id, etc.) 
+ metadata: Dict[str, Any] = field(default_factory=dict) + + +# Cost per 1M tokens for common models (approximate, update as needed) +MODEL_COSTS: Dict[str, Dict[str, float]] = { + # xAI Grok + "grok-4.1": {"prompt": 3.0, "completion": 15.0}, + "grok-3": {"prompt": 3.0, "completion": 15.0}, + # Anthropic Claude + "claude-opus-4": {"prompt": 15.0, "completion": 75.0}, + "claude-sonnet-4": {"prompt": 3.0, "completion": 15.0}, + "claude-sonnet-3-5": {"prompt": 3.0, "completion": 15.0}, + # OpenAI + "gpt-4o": {"prompt": 2.5, "completion": 10.0}, + "gpt-4o-mini": {"prompt": 0.15, "completion": 0.6}, + "o1": {"prompt": 15.0, "completion": 60.0}, + "o3-mini": {"prompt": 1.1, "completion": 4.4}, +} + + +def estimate_cost( + model: str, + prompt_tokens: int, + completion_tokens: int, +) -> Optional[float]: + """ + Estimate cost in USD for a completion. + + Returns None if model pricing is unknown. + """ + # Normalize model name for lookup + model_lower = model.lower() + + # Find matching pricing (prefer longest prefix match) + pricing = None + best_match_len = 0 + + for model_prefix, costs in MODEL_COSTS.items(): + prefix_lower = model_prefix.lower() + if model_lower.startswith(prefix_lower): + if len(prefix_lower) > best_match_len: + pricing = costs + best_match_len = len(prefix_lower) + + if pricing is None: + return None + + # Cost = (tokens / 1M) * cost_per_million + prompt_cost = (prompt_tokens / 1_000_000) * pricing["prompt"] + completion_cost = (completion_tokens / 1_000_000) * pricing["completion"] + + return round(prompt_cost + completion_cost, 6) + + +UsageCallback = Callable[[UsageEvent], None] + + +@dataclass +class UsageTotals: + """Aggregate usage statistics.""" + + total_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + request_count: int = 0 + total_cost: float = 0.0 + total_latency_ms: float = 0.0 + + +class UsageTracker: + """ + Platform-level usage tracking for billing and metering. + + Thread-safe. 
Supports multiple subscribers for real-time event streaming. + + Integration points: + - Webhook to billing API + - Database for usage records + - Metrics/observability (Prometheus, DataDog) + - Real-time dashboard (WebSocket) + """ + + def __init__(self): + self._callbacks: List[UsageCallback] = [] + self._lock = threading.Lock() + + # Aggregate tracking + self._totals = UsageTotals() + self._per_agent: Dict[str, UsageTotals] = {} + self._per_model: Dict[str, UsageTotals] = {} + + def subscribe(self, callback: UsageCallback) -> None: + """ + Subscribe to usage events. + + Callbacks are invoked synchronously after each LLM completion. + For async processing, use a queue in your callback. + """ + with self._lock: + self._callbacks.append(callback) + + def unsubscribe(self, callback: UsageCallback) -> None: + """Unsubscribe from usage events.""" + with self._lock: + if callback in self._callbacks: + self._callbacks.remove(callback) + + def record( + self, + thread_id: str, + agent_id: Optional[str], + model: str, + provider: str, + prompt_tokens: int, + completion_tokens: int, + latency_ms: float, + metadata: Optional[Dict[str, Any]] = None, + ) -> UsageEvent: + """ + Record a usage event and notify subscribers. + + Called by LLMRouter after each completion. 
+ + Returns: + The created UsageEvent (for chaining/logging) + """ + total_tokens = prompt_tokens + completion_tokens + + event = UsageEvent( + thread_id=thread_id, + agent_id=agent_id, + model=model, + provider=provider, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + timestamp=datetime.now(timezone.utc).isoformat(), + latency_ms=latency_ms, + estimated_cost=estimate_cost(model, prompt_tokens, completion_tokens), + metadata=metadata or {}, + ) + + # Update aggregates + with self._lock: + self._update_totals(self._totals, event) + + if agent_id: + if agent_id not in self._per_agent: + self._per_agent[agent_id] = UsageTotals() + self._update_totals(self._per_agent[agent_id], event) + + if model not in self._per_model: + self._per_model[model] = UsageTotals() + self._update_totals(self._per_model[model], event) + + # Copy callbacks to avoid holding lock during invocation + callbacks = list(self._callbacks) + + # Notify subscribers (outside lock) + for callback in callbacks: + try: + callback(event) + except Exception: + # Don't let subscriber errors break tracking + pass + + return event + + def _update_totals(self, totals: UsageTotals, event: UsageEvent) -> None: + """Update aggregate totals from an event.""" + totals.total_tokens += event.total_tokens + totals.prompt_tokens += event.prompt_tokens + totals.completion_tokens += event.completion_tokens + totals.request_count += 1 + totals.total_latency_ms += event.latency_ms + if event.estimated_cost: + totals.total_cost += event.estimated_cost + + def get_totals(self) -> Dict[str, Any]: + """Get aggregate usage totals.""" + with self._lock: + return { + "total_tokens": self._totals.total_tokens, + "prompt_tokens": self._totals.prompt_tokens, + "completion_tokens": self._totals.completion_tokens, + "request_count": self._totals.request_count, + "total_cost": round(self._totals.total_cost, 4), + "avg_latency_ms": ( + self._totals.total_latency_ms / 
self._totals.request_count + if self._totals.request_count > 0 + else 0 + ), + } + + def get_agent_totals(self, agent_id: str) -> Dict[str, Any]: + """Get usage totals for a specific agent.""" + with self._lock: + totals = self._per_agent.get(agent_id, UsageTotals()) + return { + "total_tokens": totals.total_tokens, + "prompt_tokens": totals.prompt_tokens, + "completion_tokens": totals.completion_tokens, + "request_count": totals.request_count, + "total_cost": round(totals.total_cost, 4), + } + + def get_model_totals(self, model: str) -> Dict[str, Any]: + """Get usage totals for a specific model.""" + with self._lock: + totals = self._per_model.get(model, UsageTotals()) + return { + "total_tokens": totals.total_tokens, + "prompt_tokens": totals.prompt_tokens, + "completion_tokens": totals.completion_tokens, + "request_count": totals.request_count, + "total_cost": round(totals.total_cost, 4), + } + + def get_all_agent_totals(self) -> Dict[str, Dict[str, Any]]: + """Get usage totals for all agents.""" + with self._lock: + return { + agent_id: { + "total_tokens": t.total_tokens, + "prompt_tokens": t.prompt_tokens, + "completion_tokens": t.completion_tokens, + "request_count": t.request_count, + "total_cost": round(t.total_cost, 4), + } + for agent_id, t in self._per_agent.items() + } + + def get_all_model_totals(self) -> Dict[str, Dict[str, Any]]: + """Get usage totals for all models.""" + with self._lock: + return { + model: { + "total_tokens": t.total_tokens, + "prompt_tokens": t.prompt_tokens, + "completion_tokens": t.completion_tokens, + "request_count": t.request_count, + "total_cost": round(t.total_cost, 4), + } + for model, t in self._per_model.items() + } + + def reset(self) -> None: + """Reset all tracking (for testing).""" + with self._lock: + self._totals = UsageTotals() + self._per_agent.clear() + self._per_model.clear() + + +# ============================================================================= +# Global Instance +# 
# =============================================================================
# Global Instance
# =============================================================================

_tracker: Optional[UsageTracker] = None   # lazily created process-wide singleton
_tracker_lock = threading.Lock()          # guards creation/teardown of _tracker


def get_usage_tracker() -> UsageTracker:
    """Return the global usage tracker, creating it on first use.

    Double-checked locking: the unlocked fast path avoids contention once
    the singleton exists; the re-check under the lock prevents two threads
    from both constructing a tracker.
    """
    global _tracker
    if _tracker is not None:          # fast path — already built
        return _tracker
    with _tracker_lock:
        if _tracker is None:          # re-check under the lock
            _tracker = UsageTracker()
        return _tracker


def reset_usage_tracker() -> None:
    """Reset and discard the global tracker (for testing)."""
    global _tracker
    with _tracker_lock:
        if _tracker is not None:
            _tracker.reset()
        _tracker = None
@dataclass
class ThreadBudget:
    """Mutable token-usage ledger for a single conversation thread."""

    max_tokens: int             # hard cap for this thread (prompt + completion)
    prompt_tokens: int = 0      # prompt-side tokens consumed so far
    completion_tokens: int = 0  # completion-side tokens consumed so far
    request_count: int = 0      # number of consume() calls recorded

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed (prompt + completion)."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def remaining(self) -> int:
        """Tokens left in the budget, clamped so it never goes negative."""
        headroom = self.max_tokens - self.total_tokens
        return headroom if headroom > 0 else 0

    @property
    def is_exhausted(self) -> bool:
        """True once consumption has reached (or overshot) max_tokens."""
        # remaining == 0  <=>  total_tokens >= max_tokens
        return self.remaining == 0

    def can_consume(self, estimated_tokens: int) -> bool:
        """Would consuming *estimated_tokens* more still fit under the cap?"""
        # Unclamped headroom on purpose: an already-overshot budget must
        # reject even a zero-token estimate, matching total+est <= max.
        return self.max_tokens - self.total_tokens >= estimated_tokens

    def consume(self, prompt_tokens: int = 0, completion_tokens: int = 0) -> None:
        """Record one request's token usage (totals may exceed the cap)."""
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.request_count += 1


class BudgetExhaustedError(Exception):
    """Raised when a thread's token budget is exhausted."""

    def __init__(self, thread_id: str, used: int, max_tokens: int):
        self.thread_id = thread_id
        self.used = used
        self.max_tokens = max_tokens
        message = (
            f"Thread {thread_id[:8]}... budget exhausted: "
            f"{used}/{max_tokens} tokens used"
        )
        super().__init__(message)
def __init__(self, max_tokens_per_thread: int = 100_000):
    """
    Initialize budget registry.

    Args:
        max_tokens_per_thread: Default budget for new threads.
    """
    self._max_tokens_per_thread = max_tokens_per_thread
    self._budgets: Dict[str, ThreadBudget] = {}   # thread_id -> budget
    self._lock = threading.Lock()                 # guards both fields above

def configure(self, max_tokens_per_thread: int) -> None:
    """Update the default budget applied to *new* threads.

    Budgets already handed out keep their original max_tokens.
    """
    with self._lock:
        self._max_tokens_per_thread = max_tokens_per_thread

@property
def max_tokens_per_thread(self) -> int:
    """Current default budget for newly created threads."""
    return self._max_tokens_per_thread

def get_budget(self, thread_id: str) -> ThreadBudget:
    """Return the budget for *thread_id*, creating one on first sight.

    Args:
        thread_id: Thread UUID

    Returns:
        The (shared, mutable) ThreadBudget for this thread.
    """
    with self._lock:
        budget = self._budgets.get(thread_id)
        if budget is None:
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        return budget
def consume(
    self,
    thread_id: str,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
) -> ThreadBudget:
    """
    Record token consumption for a thread.

    Lookup/creation and mutation happen under a single lock acquisition.
    The previous implementation called get_budget() (acquire + release)
    and then re-acquired self._lock to mutate, leaving a window in which
    a concurrent reset_thread()/cleanup_thread() could pop the budget so
    the consumption was recorded on an orphaned object.

    Args:
        thread_id: Thread UUID
        prompt_tokens: Prompt tokens used
        completion_tokens: Completion tokens used

    Returns:
        Updated ThreadBudget
    """
    with self._lock:
        budget = self._budgets.get(thread_id)
        if budget is None:
            # Same lazy-creation policy as get_budget(), inlined here so
            # create-and-consume is atomic with respect to removal.
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        budget.consume(prompt_tokens, completion_tokens)
    return budget
def get_all_usage(self) -> Dict[str, Dict[str, int]]:
    """Get usage stats for all threads, keyed by thread id.

    Each value mirrors get_usage(): prompt_tokens, completion_tokens,
    total_tokens, remaining, max_tokens, request_count.
    """
    with self._lock:
        # Built under the lock so every entry is a consistent snapshot.
        snapshot: Dict[str, Dict[str, int]] = {}
        for thread_id, budget in self._budgets.items():
            snapshot[thread_id] = {
                "prompt_tokens": budget.prompt_tokens,
                "completion_tokens": budget.completion_tokens,
                "total_tokens": budget.total_tokens,
                "remaining": budget.remaining,
                "max_tokens": budget.max_tokens,
                "request_count": budget.request_count,
            }
        return snapshot

def reset_thread(self, thread_id: str) -> None:
    """Drop the budget for one thread; it is recreated lazily on next use."""
    with self._lock:
        self._budgets.pop(thread_id, None)
+ """ + with self._lock: + return self._budgets.pop(thread_id, None) + + def clear(self) -> None: + """Clear all budgets (for testing).""" + with self._lock: + self._budgets.clear() + + +# ============================================================================= +# Global Instance +# ============================================================================= + +_registry: Optional[ThreadBudgetRegistry] = None +_registry_lock = threading.Lock() + + +def get_budget_registry() -> ThreadBudgetRegistry: + """Get the global budget registry.""" + global _registry + if _registry is None: + with _registry_lock: + if _registry is None: + _registry = ThreadBudgetRegistry() + return _registry + + +def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry: + """Configure the global budget registry.""" + registry = get_budget_registry() + registry.configure(max_tokens_per_thread) + return registry + + +def reset_budget_registry() -> None: + """Reset the global registry (for testing).""" + global _registry + with _registry_lock: + if _registry is not None: + _registry.clear() + _registry = None diff --git a/xml_pipeline/message_bus/stream_pump.py b/xml_pipeline/message_bus/stream_pump.py index f86aa54..50dea7f 100644 --- a/xml_pipeline/message_bus/stream_pump.py +++ b/xml_pipeline/message_bus/stream_pump.py @@ -141,6 +141,9 @@ class OrganismConfig: max_concurrent_handlers: int = 20 # Concurrent handler invocations max_concurrent_per_agent: int = 5 # Per-agent rate limit + # Token budget enforcement + max_tokens_per_thread: int = 100_000 # Max tokens per conversation thread + # LLM configuration (optional) llm_config: Dict[str, Any] = field(default_factory=dict) @@ -1271,6 +1274,7 @@ class ConfigLoader: max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50), max_concurrent_handlers=raw.get("max_concurrent_handlers", 20), max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5), + max_tokens_per_thread=raw.get("max_tokens_per_thread", 
100_000), llm_config=raw.get("llm", {}), process_pool_enabled=process_pool_enabled, process_pool_workers=process_pool_workers, @@ -1430,6 +1434,11 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump: configure_router(config.llm_config) print(f"LLM backends: {len(config.llm_config.get('backends', []))}") + # Configure thread budget registry + from xml_pipeline.message_bus.budget_registry import configure_budget_registry + configure_budget_registry(config.max_tokens_per_thread) + print(f"Token budget: {config.max_tokens_per_thread:,} per thread") + # Initialize root thread in registry registry = get_registry() root_uuid = registry.initialize_root(config.name)