Add token budget enforcement and usage tracking

Token Budget System:
- ThreadBudgetRegistry tracks per-thread token usage with configurable limits
- BudgetExhaustedError raised when thread exceeds max_tokens_per_thread
- Integrates with LLMRouter to block LLM calls when budget exhausted
- Automatic cleanup when threads are pruned
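
A minimal sketch of the enforcement flow, using names introduced in this
commit (the thread ID and token counts are illustrative):

    from xml_pipeline.message_bus.budget_registry import (
        BudgetExhaustedError,
        get_budget_registry,
    )

    registry = get_budget_registry()
    try:
        # Raises BudgetExhaustedError if the estimate would exceed the budget
        registry.check_budget("thread-1", estimated_tokens=1000)
    except BudgetExhaustedError as e:
        print(f"Blocked: {e.used}/{e.max_tokens} tokens used")
    else:
        # After the LLM call completes, record actual usage
        registry.consume("thread-1", prompt_tokens=600, completion_tokens=300)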

Usage Tracking (for production billing):
- UsageTracker emits events after each LLM completion
- Subscribers receive UsageEvent with tokens, latency, estimated cost
- Cost estimation for common models (Grok, Claude, GPT, etc.)
- Aggregate stats by agent, model, and totals
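
A sketch of wiring a billing subscriber (record_usage is a hypothetical
callback; the fields it reads are the ones UsageEvent carries):

    from xml_pipeline.llm import UsageEvent, get_usage_tracker

    def record_usage(event: UsageEvent) -> None:
        # Tokens, latency, and estimated cost for one completion
        print(event.model, event.total_tokens, event.latency_ms, event.estimated_cost)

    tracker = get_usage_tracker()
    tracker.subscribe(record_usage)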

Configuration:
- max_tokens_per_thread in organism.yaml (default 100k)
- LLMRouter.complete() accepts thread_id and metadata parameters
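
For example, in organism.yaml (the 50k limit is illustrative):

    organism:
      max_tokens_per_thread: 50000

and a call that opts into enforcement and tracking (illustrative IDs):

    response = await router.complete(
        model="grok-4.1",
        messages=[{"role": "user", "content": "Hello"}],
        thread_id=thread_id,   # budget enforcement
        agent_id="greeter",    # usage attribution
    )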

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
dullfig 2026-01-27 21:07:43 -08:00
parent 4530c06835
commit 8b11323a8b
7 changed files with 1341 additions and 6 deletions

tests/test_token_budget.py (new file, 573 lines)

@@ -0,0 +1,573 @@
"""
test_token_budget.py - Tests for token budget and usage tracking.
Tests:
1. ThreadBudgetRegistry - per-thread token limits
2. UsageTracker - billing/gas usage events
3. LLMRouter integration - budget enforcement
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch
from xml_pipeline.message_bus.budget_registry import (
ThreadBudget,
ThreadBudgetRegistry,
BudgetExhaustedError,
get_budget_registry,
configure_budget_registry,
reset_budget_registry,
)
from xml_pipeline.llm.usage_tracker import (
UsageEvent,
UsageTracker,
UsageTotals,
estimate_cost,
get_usage_tracker,
reset_usage_tracker,
)
# ============================================================================
# ThreadBudget Tests
# ============================================================================
class TestThreadBudget:
"""Test ThreadBudget dataclass."""
def test_initial_state(self):
"""New budget should have zero usage."""
budget = ThreadBudget(max_tokens=10000)
assert budget.total_tokens == 0
assert budget.remaining == 10000
assert budget.is_exhausted is False
def test_consume_tokens(self):
"""Consuming tokens should update totals."""
budget = ThreadBudget(max_tokens=10000)
budget.consume(prompt_tokens=500, completion_tokens=300)
assert budget.prompt_tokens == 500
assert budget.completion_tokens == 300
assert budget.total_tokens == 800
assert budget.remaining == 9200
assert budget.request_count == 1
def test_can_consume_within_budget(self):
"""can_consume should return True if within budget."""
budget = ThreadBudget(max_tokens=1000)
budget.consume(prompt_tokens=400)
assert budget.can_consume(500) is True
assert budget.can_consume(600) is True
assert budget.can_consume(601) is False
def test_is_exhausted(self):
"""is_exhausted should return True when budget exceeded."""
budget = ThreadBudget(max_tokens=1000)
budget.consume(prompt_tokens=1000)
assert budget.is_exhausted is True
assert budget.remaining == 0
def test_remaining_never_negative(self):
"""remaining should never go negative."""
budget = ThreadBudget(max_tokens=100)
budget.consume(prompt_tokens=200)
assert budget.remaining == 0
assert budget.total_tokens == 200
# ============================================================================
# ThreadBudgetRegistry Tests
# ============================================================================
class TestThreadBudgetRegistry:
"""Test ThreadBudgetRegistry."""
@pytest.fixture(autouse=True)
def reset(self):
"""Reset global registry before each test."""
reset_budget_registry()
yield
reset_budget_registry()
def test_default_budget_creation(self):
"""Getting budget for new thread should create one."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=50000)
budget = registry.get_budget("thread-1")
assert budget.max_tokens == 50000
assert budget.total_tokens == 0
def test_configure_max_tokens(self):
"""configure() should update default for new threads."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
budget1 = registry.get_budget("thread-1")
registry.configure(max_tokens_per_thread=20000)
budget2 = registry.get_budget("thread-2")
assert budget1.max_tokens == 10000 # Original unchanged
assert budget2.max_tokens == 20000 # New default
def test_check_budget_success(self):
"""check_budget should pass when within budget."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
result = registry.check_budget("thread-1", estimated_tokens=5000)
assert result is True
def test_check_budget_exhausted(self):
"""check_budget should raise when budget exhausted."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
registry.consume("thread-1", prompt_tokens=1000)
with pytest.raises(BudgetExhaustedError) as exc_info:
registry.check_budget("thread-1", estimated_tokens=100)
assert "budget exhausted" in str(exc_info.value)
assert exc_info.value.thread_id == "thread-1"
assert exc_info.value.used == 1000
assert exc_info.value.max_tokens == 1000
def test_check_budget_would_exceed(self):
"""check_budget should raise when estimate would exceed."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
registry.consume("thread-1", prompt_tokens=600)
with pytest.raises(BudgetExhaustedError):
registry.check_budget("thread-1", estimated_tokens=500)
def test_consume_returns_budget(self):
"""consume() should return updated budget."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
assert budget.total_tokens == 150
assert budget.request_count == 1
def test_get_usage(self):
"""get_usage should return dict with all stats."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
usage = registry.get_usage("thread-1")
assert usage["prompt_tokens"] == 800
assert usage["completion_tokens"] == 300
assert usage["total_tokens"] == 1100
assert usage["remaining"] == 8900
assert usage["request_count"] == 2
def test_get_all_usage(self):
"""get_all_usage should return all threads."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=100)
registry.consume("thread-2", prompt_tokens=200)
all_usage = registry.get_all_usage()
assert len(all_usage) == 2
assert "thread-1" in all_usage
assert "thread-2" in all_usage
def test_reset_thread(self):
"""reset_thread should remove budget for thread."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500)
registry.reset_thread("thread-1")
# Getting budget should create new one with zero usage
budget = registry.get_budget("thread-1")
assert budget.total_tokens == 0
def test_cleanup_thread(self):
"""cleanup_thread should return and remove budget."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500)
final_budget = registry.cleanup_thread("thread-1")
assert final_budget.total_tokens == 500
assert registry.cleanup_thread("thread-1") is None # Already cleaned
def test_global_registry(self):
"""Global registry should be singleton."""
registry1 = get_budget_registry()
registry2 = get_budget_registry()
assert registry1 is registry2
def test_global_configure(self):
"""configure_budget_registry should update global."""
configure_budget_registry(max_tokens_per_thread=75000)
registry = get_budget_registry()
budget = registry.get_budget("new-thread")
assert budget.max_tokens == 75000
# ============================================================================
# UsageTracker Tests
# ============================================================================
class TestUsageTracker:
"""Test UsageTracker for billing/metering."""
@pytest.fixture(autouse=True)
def reset(self):
"""Reset global tracker before each test."""
reset_usage_tracker()
yield
reset_usage_tracker()
def test_record_creates_event(self):
"""record() should create and return UsageEvent."""
tracker = UsageTracker()
event = tracker.record(
thread_id="thread-1",
agent_id="greeter",
model="grok-4.1",
provider="xai",
prompt_tokens=500,
completion_tokens=200,
latency_ms=150.5,
)
assert event.thread_id == "thread-1"
assert event.agent_id == "greeter"
assert event.model == "grok-4.1"
assert event.total_tokens == 700
assert event.timestamp is not None
def test_record_estimates_cost(self):
"""record() should estimate cost for known models."""
tracker = UsageTracker()
event = tracker.record(
thread_id="thread-1",
agent_id="agent",
model="grok-4.1",
provider="xai",
prompt_tokens=1_000_000, # 1M prompt
completion_tokens=1_000_000, # 1M completion
latency_ms=1000,
)
# grok-4.1: $3/1M prompt + $15/1M completion = $18
assert event.estimated_cost == 18.0
def test_subscriber_receives_events(self):
"""Subscribers should receive events on record."""
tracker = UsageTracker()
received = []
tracker.subscribe(lambda e: received.append(e))
tracker.record(
thread_id="t1",
agent_id="agent",
model="gpt-4o",
provider="openai",
prompt_tokens=100,
completion_tokens=50,
latency_ms=50,
)
assert len(received) == 1
assert received[0].thread_id == "t1"
def test_unsubscribe(self):
"""unsubscribe should stop receiving events."""
tracker = UsageTracker()
received = []
callback = lambda e: received.append(e)
tracker.subscribe(callback)
tracker.record(thread_id="t1", agent_id=None, model="m", provider="p",
prompt_tokens=10, completion_tokens=10, latency_ms=10)
tracker.unsubscribe(callback)
tracker.record(thread_id="t2", agent_id=None, model="m", provider="p",
prompt_tokens=10, completion_tokens=10, latency_ms=10)
assert len(received) == 1
def test_get_totals(self):
"""get_totals should return aggregate stats."""
tracker = UsageTracker()
tracker.record(thread_id="t1", agent_id="a1", model="m1", provider="p",
prompt_tokens=100, completion_tokens=50, latency_ms=100)
tracker.record(thread_id="t2", agent_id="a2", model="m2", provider="p",
prompt_tokens=200, completion_tokens=100, latency_ms=200)
totals = tracker.get_totals()
assert totals["prompt_tokens"] == 300
assert totals["completion_tokens"] == 150
assert totals["total_tokens"] == 450
assert totals["request_count"] == 2
assert totals["avg_latency_ms"] == 150.0
def test_get_agent_totals(self):
"""get_agent_totals should return per-agent stats."""
tracker = UsageTracker()
tracker.record(thread_id="t1", agent_id="greeter", model="m", provider="p",
prompt_tokens=100, completion_tokens=50, latency_ms=100)
tracker.record(thread_id="t2", agent_id="greeter", model="m", provider="p",
prompt_tokens=100, completion_tokens=50, latency_ms=100)
tracker.record(thread_id="t3", agent_id="shouter", model="m", provider="p",
prompt_tokens=200, completion_tokens=100, latency_ms=200)
greeter = tracker.get_agent_totals("greeter")
shouter = tracker.get_agent_totals("shouter")
assert greeter["total_tokens"] == 300
assert greeter["request_count"] == 2
assert shouter["total_tokens"] == 300
assert shouter["request_count"] == 1
def test_get_model_totals(self):
"""get_model_totals should return per-model stats."""
tracker = UsageTracker()
tracker.record(thread_id="t1", agent_id="a", model="grok-4.1", provider="xai",
prompt_tokens=1000, completion_tokens=500, latency_ms=100)
tracker.record(thread_id="t2", agent_id="a", model="claude-sonnet-4", provider="anthropic",
prompt_tokens=500, completion_tokens=250, latency_ms=100)
grok = tracker.get_model_totals("grok-4.1")
claude = tracker.get_model_totals("claude-sonnet-4")
assert grok["total_tokens"] == 1500
assert claude["total_tokens"] == 750
def test_metadata_passed_through(self):
"""Metadata should be included in events."""
tracker = UsageTracker()
received = []
tracker.subscribe(lambda e: received.append(e))
tracker.record(
thread_id="t1",
agent_id="a",
model="m",
provider="p",
prompt_tokens=10,
completion_tokens=10,
latency_ms=10,
metadata={"org_id": "org-123", "user_id": "user-456"},
)
assert received[0].metadata["org_id"] == "org-123"
assert received[0].metadata["user_id"] == "user-456"
# ============================================================================
# Cost Estimation Tests
# ============================================================================
class TestCostEstimation:
"""Test cost estimation for various models."""
def test_grok_cost(self):
"""Grok models should use correct pricing."""
cost = estimate_cost("grok-4.1", prompt_tokens=1_000_000, completion_tokens=1_000_000)
# $3/1M prompt + $15/1M completion = $18
assert cost == 18.0
def test_claude_opus_cost(self):
"""Claude Opus should use correct pricing."""
cost = estimate_cost("claude-opus-4", prompt_tokens=1_000_000, completion_tokens=1_000_000)
# $15/1M prompt + $75/1M completion = $90
assert cost == 90.0
def test_gpt4o_cost(self):
"""GPT-4o should use correct pricing."""
cost = estimate_cost("gpt-4o", prompt_tokens=1_000_000, completion_tokens=1_000_000)
# $2.5/1M prompt + $10/1M completion = $12.5
assert cost == 12.5
def test_unknown_model_returns_none(self):
"""Unknown model should return None."""
cost = estimate_cost("unknown-model", prompt_tokens=1000, completion_tokens=500)
assert cost is None
def test_small_usage_cost(self):
"""Small token counts should produce fractional costs."""
cost = estimate_cost("gpt-4o-mini", prompt_tokens=1000, completion_tokens=500)
# 1000 tokens * $0.15/1M = $0.00015
# 500 tokens * $0.6/1M = $0.0003
# Total = $0.00045
assert cost == pytest.approx(0.00045, rel=1e-4)
# ============================================================================
# LLMRouter Integration Tests (Mocked)
# ============================================================================
class TestLLMRouterBudgetIntegration:
"""Test LLMRouter budget enforcement."""
@pytest.fixture(autouse=True)
def reset_all(self):
"""Reset all global registries."""
reset_budget_registry()
reset_usage_tracker()
yield
reset_budget_registry()
reset_usage_tracker()
@pytest.mark.asyncio
async def test_complete_consumes_budget(self):
"""LLM complete should consume from thread budget."""
from xml_pipeline.llm.router import LLMRouter
from xml_pipeline.llm.backend import LLMResponse
# Create mock backend
mock_backend = Mock()
mock_backend.name = "mock"
mock_backend.provider = "test"
mock_backend.serves_model = Mock(return_value=True)
mock_backend.priority = 1
mock_backend.load = 0
mock_backend.complete = AsyncMock(return_value=LLMResponse(
content="Hello!",
model="test-model",
usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
finish_reason="stop",
))
# Configure budget
configure_budget_registry(max_tokens_per_thread=10000)
budget_registry = get_budget_registry()
# Create router with mock backend
router = LLMRouter()
router.backends.append(mock_backend)
# Make request
response = await router.complete(
model="test-model",
messages=[{"role": "user", "content": "Hi"}],
thread_id="test-thread-123",
)
assert response.content == "Hello!"
# Verify budget consumed
usage = budget_registry.get_usage("test-thread-123")
assert usage["prompt_tokens"] == 100
assert usage["completion_tokens"] == 50
assert usage["total_tokens"] == 150
@pytest.mark.asyncio
async def test_complete_emits_usage_event(self):
"""LLM complete should emit usage event."""
from xml_pipeline.llm.router import LLMRouter
from xml_pipeline.llm.backend import LLMResponse
mock_backend = Mock()
mock_backend.name = "mock"
mock_backend.provider = "test"
mock_backend.serves_model = Mock(return_value=True)
mock_backend.priority = 1
mock_backend.load = 0
mock_backend.complete = AsyncMock(return_value=LLMResponse(
content="Hello!",
model="test-model",
usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
finish_reason="stop",
))
# Subscribe to usage events
tracker = get_usage_tracker()
received_events = []
tracker.subscribe(lambda e: received_events.append(e))
# Create router and make request
router = LLMRouter()
router.backends.append(mock_backend)
await router.complete(
model="test-model",
messages=[{"role": "user", "content": "Hi"}],
thread_id="test-thread",
agent_id="greeter",
metadata={"org_id": "test-org"},
)
# Verify event emitted
assert len(received_events) == 1
event = received_events[0]
assert event.thread_id == "test-thread"
assert event.agent_id == "greeter"
assert event.total_tokens == 150
assert event.metadata["org_id"] == "test-org"
@pytest.mark.asyncio
async def test_complete_raises_when_budget_exhausted(self):
"""LLM complete should raise when budget exhausted."""
from xml_pipeline.llm.router import LLMRouter
# Configure small budget and exhaust it
configure_budget_registry(max_tokens_per_thread=100)
budget_registry = get_budget_registry()
budget_registry.consume("test-thread", prompt_tokens=100)
mock_backend = Mock()
mock_backend.name = "mock"
mock_backend.serves_model = Mock(return_value=True)
mock_backend.priority = 1
router = LLMRouter()
router.backends.append(mock_backend)
with pytest.raises(BudgetExhaustedError) as exc_info:
await router.complete(
model="test-model",
messages=[{"role": "user", "content": "Hi"}],
thread_id="test-thread",
)
assert "budget exhausted" in str(exc_info.value)
# Backend should NOT have been called
mock_backend.complete.assert_not_called()
@pytest.mark.asyncio
async def test_complete_without_thread_id_skips_budget(self):
"""LLM complete without thread_id should skip budget check."""
from xml_pipeline.llm.router import LLMRouter
from xml_pipeline.llm.backend import LLMResponse
mock_backend = Mock()
mock_backend.name = "mock"
mock_backend.provider = "test"
mock_backend.serves_model = Mock(return_value=True)
mock_backend.priority = 1
mock_backend.load = 0
mock_backend.complete = AsyncMock(return_value=LLMResponse(
content="Hello!",
model="test-model",
usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
finish_reason="stop",
))
router = LLMRouter()
router.backends.append(mock_backend)
# Should not raise - no budget checking
response = await router.complete(
model="test-model",
messages=[{"role": "user", "content": "Hi"}],
# No thread_id
)
assert response.content == "Hello!"

xml_pipeline/llm/__init__.py

@@ -16,7 +16,20 @@ Usage:
     response = await router.complete(
         model="grok-4.1",
         messages=[{"role": "user", "content": "Hello"}],
+        thread_id=metadata.thread_id,  # For budget enforcement
+        agent_id=metadata.own_name,    # For usage tracking
     )
+
+Usage Tracking:
+    from xml_pipeline.llm import get_usage_tracker
+
+    tracker = get_usage_tracker()
+
+    # Subscribe to events for billing
+    tracker.subscribe(lambda event: billing_api.record(event))
+
+    # Query totals
+    totals = tracker.get_totals()
 """

 from xml_pipeline.llm.router import (
@@ -27,14 +40,27 @@ from xml_pipeline.llm.router import (
     Strategy,
 )
 from xml_pipeline.llm.backend import LLMRequest, LLMResponse, BackendError
+from xml_pipeline.llm.usage_tracker import (
+    UsageTracker,
+    UsageEvent,
+    get_usage_tracker,
+    reset_usage_tracker,
+)

 __all__ = [
+    # Router
     "LLMRouter",
     "get_router",
     "configure_router",
     "complete",
     "Strategy",
+    # Backend
     "LLMRequest",
     "LLMResponse",
     "BackendError",
+    # Usage tracking
+    "UsageTracker",
+    "UsageEvent",
+    "get_usage_tracker",
+    "reset_usage_tracker",
 ]

xml_pipeline/llm/router.py

@@ -9,6 +9,8 @@ The router handles:
 - Load balancing (failover, round-robin, least-loaded)
 - Retries with exponential backoff
 - Token tracking per agent
+- Thread budget enforcement
+- Usage event emission for billing
 """

 from __future__ import annotations
@@ -16,6 +18,7 @@ from __future__ import annotations
 import asyncio
 import logging
 import random
+import time

 from dataclasses import dataclass, field
 from enum import Enum
 from typing import List, Dict, Any, Optional
@@ -125,6 +128,8 @@ class LLMRouter:
         max_tokens: int = None,
         tools: List[Dict] = None,
         agent_id: str = None,
+        thread_id: str = None,
+        metadata: Dict[str, Any] = None,
     ) -> LLMResponse:
         """
         Execute a completion request.
@@ -136,10 +141,27 @@
             max_tokens: Max tokens in response
             tools: Tool definitions for function calling
             agent_id: Optional agent ID for usage tracking
+            thread_id: Optional thread ID for budget enforcement
+            metadata: Optional metadata for usage events (org_id, user_id, etc.)

         Returns:
             LLMResponse with content and usage stats
+
+        Raises:
+            BudgetExhaustedError: If thread has no remaining budget
+            BackendError: If all backends fail
         """
+        # Estimate tokens for budget check (rough: 4 chars per token)
+        estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
+        estimated_tokens = max(estimated_tokens, 100)  # minimum estimate
+
+        # Check thread budget before proceeding
+        if thread_id:
+            from xml_pipeline.message_bus.budget_registry import get_budget_registry
+            budget_registry = get_budget_registry()
+            # This raises BudgetExhaustedError if over budget
+            budget_registry.check_budget(thread_id, estimated_tokens)
+
         candidates = self._find_backends(model)
         request = LLMRequest(
             model=model,
@@ -151,6 +173,7 @@
         last_error = None
         tried_backends = set()
+        start_time = time.monotonic()

         for attempt in range(self.retries + 1):
             # Select backend (different selection on retry for failover)
@@ -170,14 +193,46 @@
                 logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})")
                 response = await backend.complete(request)

-                # Track usage
+                # Calculate latency
+                latency_ms = (time.monotonic() - start_time) * 1000
+
+                # Extract usage
+                prompt_tokens = response.usage.get("prompt_tokens", 0)
+                completion_tokens = response.usage.get("completion_tokens", 0)
+                total_tokens = response.usage.get("total_tokens", 0)
+
+                # Track per-agent usage (internal)
                 if agent_id:
                     usage = self._agent_usage.setdefault(agent_id, AgentUsage())
-                    usage.total_tokens += response.usage.get("total_tokens", 0)
-                    usage.prompt_tokens += response.usage.get("prompt_tokens", 0)
-                    usage.completion_tokens += response.usage.get("completion_tokens", 0)
+                    usage.total_tokens += total_tokens
+                    usage.prompt_tokens += prompt_tokens
+                    usage.completion_tokens += completion_tokens
                     usage.request_count += 1

+                # Record to thread budget (enforcement)
+                if thread_id:
+                    from xml_pipeline.message_bus.budget_registry import get_budget_registry
+                    budget_registry = get_budget_registry()
+                    budget_registry.consume(
+                        thread_id,
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                    )
+
+                # Emit usage event (for billing)
+                from xml_pipeline.llm.usage_tracker import get_usage_tracker
+                tracker = get_usage_tracker()
+                tracker.record(
+                    thread_id=thread_id or "",
+                    agent_id=agent_id,
+                    model=response.model,
+                    provider=backend.provider,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    latency_ms=latency_ms,
+                    metadata=metadata,
+                )
+
                 return response

             except RateLimitError as e:
@@ -286,6 +341,10 @@ def configure_router(config: Dict[str, Any]) -> LLMRouter:
 async def complete(
     model: str,
     messages: List[Dict[str, str]],
+    *,
+    thread_id: str = None,
+    agent_id: str = None,
+    metadata: Dict[str, Any] = None,
     **kwargs,
 ) -> LLMResponse:
     """
@@ -293,6 +352,32 @@ async def complete(
     Usage:
         from xml_pipeline.llm import router

-        response = await router.complete("grok-4.1", messages)
+        response = await router.complete(
+            "grok-4.1",
+            messages,
+            thread_id=metadata.thread_id,
+            agent_id=metadata.own_name,
+        )
+
+    Args:
+        model: Model name
+        messages: Chat messages
+        thread_id: Thread UUID for budget enforcement
+        agent_id: Agent name for usage tracking
+        metadata: Extra metadata for billing events
+        **kwargs: Additional arguments (temperature, max_tokens, tools)
+
+    Returns:
+        LLMResponse with content and usage stats
+
+    Raises:
+        BudgetExhaustedError: If thread budget exhausted
     """
-    return await get_router().complete(model, messages, **kwargs)
+    return await get_router().complete(
+        model,
+        messages,
+        thread_id=thread_id,
+        agent_id=agent_id,
+        metadata=metadata,
+        **kwargs,
+    )

xml_pipeline/llm/usage_tracker.py (new file, 346 lines)

@@ -0,0 +1,346 @@
"""
Usage Tracker - Production billing and gas usage metering.
This module provides hooks for tracking LLM usage at the platform level.
External billing systems can subscribe to usage events for metering.
Usage Tracking Layers:
1. Per-agent (LLMRouter._agent_usage) - Internal token tracking
2. Per-thread (ThreadBudgetRegistry) - Enforcement limits
3. Platform (UsageTracker) - Production billing/metering
Example:
from xml_pipeline.llm.usage_tracker import get_usage_tracker
tracker = get_usage_tracker()
# Subscribe to usage events (for billing webhook, database, etc.)
def record_usage(event: UsageEvent):
billing_db.record(
org_id=event.metadata.get("org_id"),
tokens=event.total_tokens,
cost=event.estimated_cost,
)
tracker.subscribe(record_usage)
# Query aggregate usage
totals = tracker.get_totals()
print(f"Total tokens: {totals['total_tokens']}")
"""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional
@dataclass
class UsageEvent:
"""
Usage event emitted after each LLM completion.
This is the main interface for billing systems.
"""
# Request identification
thread_id: str
agent_id: Optional[str]
model: str
provider: str
# Token usage
prompt_tokens: int
completion_tokens: int
total_tokens: int
# Timing
timestamp: str # ISO 8601
latency_ms: float # Request duration
# Cost estimation (if available)
estimated_cost: Optional[float] = None
# Extensible metadata (org_id, user_id, etc.)
metadata: Dict[str, Any] = field(default_factory=dict)
# Cost per 1M tokens for common models (approximate, update as needed)
MODEL_COSTS: Dict[str, Dict[str, float]] = {
# xAI Grok
"grok-4.1": {"prompt": 3.0, "completion": 15.0},
"grok-3": {"prompt": 3.0, "completion": 15.0},
# Anthropic Claude
"claude-opus-4": {"prompt": 15.0, "completion": 75.0},
"claude-sonnet-4": {"prompt": 3.0, "completion": 15.0},
"claude-sonnet-3-5": {"prompt": 3.0, "completion": 15.0},
# OpenAI
"gpt-4o": {"prompt": 2.5, "completion": 10.0},
"gpt-4o-mini": {"prompt": 0.15, "completion": 0.6},
"o1": {"prompt": 15.0, "completion": 60.0},
"o3-mini": {"prompt": 1.1, "completion": 4.4},
}
def estimate_cost(
model: str,
prompt_tokens: int,
completion_tokens: int,
) -> Optional[float]:
"""
Estimate cost in USD for a completion.
Returns None if model pricing is unknown.
"""
# Normalize model name for lookup
model_lower = model.lower()
# Find matching pricing (prefer longest prefix match)
pricing = None
best_match_len = 0
for model_prefix, costs in MODEL_COSTS.items():
prefix_lower = model_prefix.lower()
if model_lower.startswith(prefix_lower):
if len(prefix_lower) > best_match_len:
pricing = costs
best_match_len = len(prefix_lower)
if pricing is None:
return None
# Cost = (tokens / 1M) * cost_per_million
prompt_cost = (prompt_tokens / 1_000_000) * pricing["prompt"]
completion_cost = (completion_tokens / 1_000_000) * pricing["completion"]
return round(prompt_cost + completion_cost, 6)
UsageCallback = Callable[[UsageEvent], None]
@dataclass
class UsageTotals:
"""Aggregate usage statistics."""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
request_count: int = 0
total_cost: float = 0.0
total_latency_ms: float = 0.0
class UsageTracker:
"""
Platform-level usage tracking for billing and metering.
Thread-safe. Supports multiple subscribers for real-time event streaming.
Integration points:
- Webhook to billing API
- Database for usage records
- Metrics/observability (Prometheus, DataDog)
- Real-time dashboard (WebSocket)
"""
def __init__(self):
self._callbacks: List[UsageCallback] = []
self._lock = threading.Lock()
# Aggregate tracking
self._totals = UsageTotals()
self._per_agent: Dict[str, UsageTotals] = {}
self._per_model: Dict[str, UsageTotals] = {}
def subscribe(self, callback: UsageCallback) -> None:
"""
Subscribe to usage events.
Callbacks are invoked synchronously after each LLM completion.
For async processing, use a queue in your callback.
"""
with self._lock:
self._callbacks.append(callback)
def unsubscribe(self, callback: UsageCallback) -> None:
"""Unsubscribe from usage events."""
with self._lock:
if callback in self._callbacks:
self._callbacks.remove(callback)
def record(
self,
thread_id: str,
agent_id: Optional[str],
model: str,
provider: str,
prompt_tokens: int,
completion_tokens: int,
latency_ms: float,
metadata: Optional[Dict[str, Any]] = None,
) -> UsageEvent:
"""
Record a usage event and notify subscribers.
Called by LLMRouter after each completion.
Returns:
The created UsageEvent (for chaining/logging)
"""
total_tokens = prompt_tokens + completion_tokens
event = UsageEvent(
thread_id=thread_id,
agent_id=agent_id,
model=model,
provider=provider,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
timestamp=datetime.now(timezone.utc).isoformat(),
latency_ms=latency_ms,
estimated_cost=estimate_cost(model, prompt_tokens, completion_tokens),
metadata=metadata or {},
)
# Update aggregates
with self._lock:
self._update_totals(self._totals, event)
if agent_id:
if agent_id not in self._per_agent:
self._per_agent[agent_id] = UsageTotals()
self._update_totals(self._per_agent[agent_id], event)
if model not in self._per_model:
self._per_model[model] = UsageTotals()
self._update_totals(self._per_model[model], event)
# Copy callbacks to avoid holding lock during invocation
callbacks = list(self._callbacks)
# Notify subscribers (outside lock)
for callback in callbacks:
try:
callback(event)
except Exception:
# Don't let subscriber errors break tracking
pass
return event
def _update_totals(self, totals: UsageTotals, event: UsageEvent) -> None:
"""Update aggregate totals from an event."""
totals.total_tokens += event.total_tokens
totals.prompt_tokens += event.prompt_tokens
totals.completion_tokens += event.completion_tokens
totals.request_count += 1
totals.total_latency_ms += event.latency_ms
if event.estimated_cost:
totals.total_cost += event.estimated_cost
def get_totals(self) -> Dict[str, Any]:
"""Get aggregate usage totals."""
with self._lock:
return {
"total_tokens": self._totals.total_tokens,
"prompt_tokens": self._totals.prompt_tokens,
"completion_tokens": self._totals.completion_tokens,
"request_count": self._totals.request_count,
"total_cost": round(self._totals.total_cost, 4),
"avg_latency_ms": (
self._totals.total_latency_ms / self._totals.request_count
if self._totals.request_count > 0
else 0
),
}
def get_agent_totals(self, agent_id: str) -> Dict[str, Any]:
"""Get usage totals for a specific agent."""
with self._lock:
totals = self._per_agent.get(agent_id, UsageTotals())
return {
"total_tokens": totals.total_tokens,
"prompt_tokens": totals.prompt_tokens,
"completion_tokens": totals.completion_tokens,
"request_count": totals.request_count,
"total_cost": round(totals.total_cost, 4),
}
def get_model_totals(self, model: str) -> Dict[str, Any]:
"""Get usage totals for a specific model."""
with self._lock:
totals = self._per_model.get(model, UsageTotals())
return {
"total_tokens": totals.total_tokens,
"prompt_tokens": totals.prompt_tokens,
"completion_tokens": totals.completion_tokens,
"request_count": totals.request_count,
"total_cost": round(totals.total_cost, 4),
}
def get_all_agent_totals(self) -> Dict[str, Dict[str, Any]]:
"""Get usage totals for all agents."""
with self._lock:
return {
agent_id: {
"total_tokens": t.total_tokens,
"prompt_tokens": t.prompt_tokens,
"completion_tokens": t.completion_tokens,
"request_count": t.request_count,
"total_cost": round(t.total_cost, 4),
}
for agent_id, t in self._per_agent.items()
}
def get_all_model_totals(self) -> Dict[str, Dict[str, Any]]:
"""Get usage totals for all models."""
with self._lock:
return {
model: {
"total_tokens": t.total_tokens,
"prompt_tokens": t.prompt_tokens,
"completion_tokens": t.completion_tokens,
"request_count": t.request_count,
"total_cost": round(t.total_cost, 4),
}
for model, t in self._per_model.items()
}
def reset(self) -> None:
"""Reset all tracking (for testing)."""
with self._lock:
self._totals = UsageTotals()
self._per_agent.clear()
self._per_model.clear()
# =============================================================================
# Global Instance
# =============================================================================
_tracker: Optional[UsageTracker] = None
_tracker_lock = threading.Lock()
def get_usage_tracker() -> UsageTracker:
"""Get the global usage tracker."""
global _tracker
if _tracker is None:
with _tracker_lock:
if _tracker is None:
_tracker = UsageTracker()
return _tracker
def reset_usage_tracker() -> None:
"""Reset the global tracker (for testing)."""
global _tracker
with _tracker_lock:
if _tracker is not None:
_tracker.reset()
_tracker = None

xml_pipeline/message_bus/__init__.py

@@ -67,6 +67,15 @@ from xml_pipeline.message_bus.buffer_registry import (
     reset_buffer_registry,
 )
+from xml_pipeline.message_bus.budget_registry import (
+    ThreadBudget,
+    ThreadBudgetRegistry,
+    BudgetExhaustedError,
+    get_budget_registry,
+    configure_budget_registry,
+    reset_budget_registry,
+)

 __all__ = [
     # Pump
     "StreamPump",
@@ -102,4 +111,11 @@ __all__ = [
     "BufferRegistry",
     "get_buffer_registry",
     "reset_buffer_registry",
+    # Budget registry
+    "ThreadBudget",
+    "ThreadBudgetRegistry",
+    "BudgetExhaustedError",
+    "get_budget_registry",
+    "configure_budget_registry",
+    "reset_budget_registry",
 ]

xml_pipeline/message_bus/budget_registry.py (new file, 280 lines)

@@ -0,0 +1,280 @@
"""
Thread Budget Registry - Enforces per-thread token limits.
Each thread has a token budget that tracks:
- Total tokens consumed (prompt + completion)
- Requests made
- Remaining budget
When a thread exhausts its budget, LLM calls are blocked.
Example config:
organism:
max_tokens_per_thread: 100000 # 100k tokens per thread
"""
from __future__ import annotations
import threading
from dataclasses import dataclass, field
from typing import Dict, Optional
@dataclass
class ThreadBudget:
"""Track token usage for a single thread."""
max_tokens: int
prompt_tokens: int = 0
completion_tokens: int = 0
request_count: int = 0
@property
def total_tokens(self) -> int:
"""Total tokens consumed."""
return self.prompt_tokens + self.completion_tokens
@property
def remaining(self) -> int:
"""Remaining token budget."""
return max(0, self.max_tokens - self.total_tokens)
@property
def is_exhausted(self) -> bool:
"""True if budget is exhausted."""
return self.total_tokens >= self.max_tokens
def can_consume(self, estimated_tokens: int) -> bool:
"""Check if we can consume the given tokens without exceeding budget."""
return self.total_tokens + estimated_tokens <= self.max_tokens
def consume(
self,
prompt_tokens: int = 0,
completion_tokens: int = 0,
) -> None:
"""Record token consumption."""
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
self.request_count += 1
class BudgetExhaustedError(Exception):
"""Raised when a thread's token budget is exhausted."""
def __init__(self, thread_id: str, used: int, max_tokens: int):
self.thread_id = thread_id
self.used = used
self.max_tokens = max_tokens
super().__init__(
f"Thread {thread_id[:8]}... budget exhausted: "
f"{used}/{max_tokens} tokens used"
)
class ThreadBudgetRegistry:
"""
Manages token budgets per thread.
Thread-safe for concurrent access.
Usage:
registry = get_budget_registry()
registry.configure(max_tokens_per_thread=100000)
# Before LLM call
registry.check_budget(thread_id, estimated_tokens=1000)
# After LLM call
registry.consume(thread_id, prompt_tokens=500, completion_tokens=300)
# Get usage
budget = registry.get_budget(thread_id)
print(f"Used: {budget.total_tokens}, Remaining: {budget.remaining}")
"""
def __init__(self, max_tokens_per_thread: int = 100_000):
"""
Initialize budget registry.
Args:
max_tokens_per_thread: Default budget for new threads.
"""
self._max_tokens_per_thread = max_tokens_per_thread
self._budgets: Dict[str, ThreadBudget] = {}
self._lock = threading.Lock()
def configure(self, max_tokens_per_thread: int) -> None:
"""
Update default max tokens for new threads.
Existing threads keep their current budgets.
"""
with self._lock:
self._max_tokens_per_thread = max_tokens_per_thread
@property
def max_tokens_per_thread(self) -> int:
"""Get the default max tokens per thread."""
return self._max_tokens_per_thread
def get_budget(self, thread_id: str) -> ThreadBudget:
"""
Get or create budget for a thread.
Args:
thread_id: Thread UUID
Returns:
ThreadBudget instance
"""
with self._lock:
if thread_id not in self._budgets:
self._budgets[thread_id] = ThreadBudget(
max_tokens=self._max_tokens_per_thread
)
return self._budgets[thread_id]
def check_budget(
self,
thread_id: str,
estimated_tokens: int = 0,
) -> bool:
"""
Check if thread has budget for the estimated tokens.
Args:
thread_id: Thread UUID
estimated_tokens: Estimated tokens for the request
Returns:
True if budget available
Raises:
BudgetExhaustedError if budget is exhausted
"""
budget = self.get_budget(thread_id)
if budget.is_exhausted:
raise BudgetExhaustedError(
thread_id=thread_id,
used=budget.total_tokens,
max_tokens=budget.max_tokens,
)
if not budget.can_consume(estimated_tokens):
raise BudgetExhaustedError(
thread_id=thread_id,
used=budget.total_tokens,
max_tokens=budget.max_tokens,
)
return True
def consume(
self,
thread_id: str,
prompt_tokens: int = 0,
completion_tokens: int = 0,
) -> ThreadBudget:
"""
Record token consumption for a thread.
Args:
thread_id: Thread UUID
prompt_tokens: Prompt tokens used
completion_tokens: Completion tokens used
Returns:
Updated ThreadBudget
"""
budget = self.get_budget(thread_id)
with self._lock:
budget.consume(prompt_tokens, completion_tokens)
return budget
def get_usage(self, thread_id: str) -> Dict[str, int]:
"""
Get usage stats for a thread.
Returns:
Dict with prompt_tokens, completion_tokens, total_tokens,
remaining, max_tokens, request_count
"""
budget = self.get_budget(thread_id)
return {
"prompt_tokens": budget.prompt_tokens,
"completion_tokens": budget.completion_tokens,
"total_tokens": budget.total_tokens,
"remaining": budget.remaining,
"max_tokens": budget.max_tokens,
"request_count": budget.request_count,
}
def get_all_usage(self) -> Dict[str, Dict[str, int]]:
"""Get usage stats for all threads."""
with self._lock:
return {
thread_id: {
"prompt_tokens": b.prompt_tokens,
"completion_tokens": b.completion_tokens,
"total_tokens": b.total_tokens,
"remaining": b.remaining,
"max_tokens": b.max_tokens,
"request_count": b.request_count,
}
for thread_id, b in self._budgets.items()
}
def reset_thread(self, thread_id: str) -> None:
"""Reset budget for a specific thread."""
with self._lock:
self._budgets.pop(thread_id, None)
def cleanup_thread(self, thread_id: str) -> Optional[ThreadBudget]:
"""
Remove budget when thread is pruned/completed.
Returns the final budget for logging/billing, or None if not found.
"""
with self._lock:
return self._budgets.pop(thread_id, None)
def clear(self) -> None:
"""Clear all budgets (for testing)."""
with self._lock:
self._budgets.clear()
# =============================================================================
# Global Instance
# =============================================================================
_registry: Optional[ThreadBudgetRegistry] = None
_registry_lock = threading.Lock()
def get_budget_registry() -> ThreadBudgetRegistry:
"""Get the global budget registry."""
global _registry
if _registry is None:
with _registry_lock:
if _registry is None:
_registry = ThreadBudgetRegistry()
return _registry
def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry:
"""Configure the global budget registry."""
registry = get_budget_registry()
registry.configure(max_tokens_per_thread)
return registry
def reset_budget_registry() -> None:
"""Reset the global registry (for testing)."""
global _registry
with _registry_lock:
if _registry is not None:
_registry.clear()
_registry = None


@@ -141,6 +141,9 @@ class OrganismConfig:
     max_concurrent_handlers: int = 20  # Concurrent handler invocations
     max_concurrent_per_agent: int = 5  # Per-agent rate limit

+    # Token budget enforcement
+    max_tokens_per_thread: int = 100_000  # Max tokens per conversation thread
+
     # LLM configuration (optional)
     llm_config: Dict[str, Any] = field(default_factory=dict)

@@ -1271,6 +1274,7 @@ class ConfigLoader:
             max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
             max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
             max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
+            max_tokens_per_thread=raw.get("max_tokens_per_thread", 100_000),
             llm_config=raw.get("llm", {}),
             process_pool_enabled=process_pool_enabled,
             process_pool_workers=process_pool_workers,

@@ -1430,6 +1434,11 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
         configure_router(config.llm_config)
         print(f"LLM backends: {len(config.llm_config.get('backends', []))}")

+    # Configure thread budget registry
+    from xml_pipeline.message_bus.budget_registry import configure_budget_registry
+    configure_budget_registry(config.max_tokens_per_thread)
+    print(f"Token budget: {config.max_tokens_per_thread:,} per thread")
+
     # Initialize root thread in registry
     registry = get_registry()
     root_uuid = registry.initialize_root(config.name)