Add token budget enforcement and usage tracking
Token Budget System: - ThreadBudgetRegistry tracks per-thread token usage with configurable limits - BudgetExhaustedError raised when thread exceeds max_tokens_per_thread - Integrates with LLMRouter to block LLM calls when budget exhausted - Automatic cleanup when threads are pruned Usage Tracking (for production billing): - UsageTracker emits events after each LLM completion - Subscribers receive UsageEvent with tokens, latency, estimated cost - Cost estimation for common models (Grok, Claude, GPT, etc.) - Aggregate stats by agent, model, and totals Configuration: - max_tokens_per_thread in organism.yaml (default 100k) - LLMRouter.complete() accepts thread_id and metadata parameters Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
4530c06835
commit
8b11323a8b
7 changed files with 1341 additions and 6 deletions
573
tests/test_token_budget.py
Normal file
573
tests/test_token_budget.py
Normal file
|
|
@ -0,0 +1,573 @@
|
||||||
|
"""
|
||||||
|
test_token_budget.py — Tests for token budget and usage tracking.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
1. ThreadBudgetRegistry - per-thread token limits
|
||||||
|
2. UsageTracker - billing/gas usage events
|
||||||
|
3. LLMRouter integration - budget enforcement
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, AsyncMock, patch
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.budget_registry import (
|
||||||
|
ThreadBudget,
|
||||||
|
ThreadBudgetRegistry,
|
||||||
|
BudgetExhaustedError,
|
||||||
|
get_budget_registry,
|
||||||
|
configure_budget_registry,
|
||||||
|
reset_budget_registry,
|
||||||
|
)
|
||||||
|
from xml_pipeline.llm.usage_tracker import (
|
||||||
|
UsageEvent,
|
||||||
|
UsageTracker,
|
||||||
|
UsageTotals,
|
||||||
|
estimate_cost,
|
||||||
|
get_usage_tracker,
|
||||||
|
reset_usage_tracker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# ThreadBudget Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestThreadBudget:
    """Unit tests for the ThreadBudget dataclass."""

    def test_initial_state(self):
        """A brand-new budget starts with nothing consumed."""
        fresh = ThreadBudget(max_tokens=10000)
        assert fresh.total_tokens == 0
        assert fresh.remaining == 10000
        assert fresh.is_exhausted is False

    def test_consume_tokens(self):
        """consume() updates every counter and bumps the request count."""
        b = ThreadBudget(max_tokens=10000)
        b.consume(prompt_tokens=500, completion_tokens=300)
        assert b.prompt_tokens == 500
        assert b.completion_tokens == 300
        assert b.total_tokens == 800
        assert b.remaining == 9200
        assert b.request_count == 1

    def test_can_consume_within_budget(self):
        """can_consume() is True up to (and including) the exact remainder."""
        b = ThreadBudget(max_tokens=1000)
        b.consume(prompt_tokens=400)
        # 600 tokens remain: 500 and 600 fit, 601 does not.
        assert b.can_consume(500) is True
        assert b.can_consume(600) is True
        assert b.can_consume(601) is False

    def test_is_exhausted(self):
        """Spending the entire allowance flips is_exhausted."""
        b = ThreadBudget(max_tokens=1000)
        b.consume(prompt_tokens=1000)
        assert b.is_exhausted is True
        assert b.remaining == 0

    def test_remaining_never_negative(self):
        """Over-consumption clamps remaining at zero while totals keep counting."""
        b = ThreadBudget(max_tokens=100)
        b.consume(prompt_tokens=200)
        assert b.remaining == 0
        assert b.total_tokens == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# ThreadBudgetRegistry Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestThreadBudgetRegistry:
    """Behavioural tests for ThreadBudgetRegistry."""

    @pytest.fixture(autouse=True)
    def reset(self):
        """Tear down the process-global registry around every test."""
        reset_budget_registry()
        yield
        reset_budget_registry()

    def test_default_budget_creation(self):
        """First get_budget() for an unknown thread lazily creates its budget."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=50000)
        created = reg.get_budget("thread-1")
        assert created.max_tokens == 50000
        assert created.total_tokens == 0

    def test_configure_max_tokens(self):
        """configure() changes the default only for threads created afterwards."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        before = reg.get_budget("thread-1")
        reg.configure(max_tokens_per_thread=20000)
        after = reg.get_budget("thread-2")
        # Pre-existing budget keeps its limit; the new thread gets the new default.
        assert before.max_tokens == 10000
        assert after.max_tokens == 20000

    def test_check_budget_success(self):
        """check_budget passes while the estimate fits the limit."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        assert reg.check_budget("thread-1", estimated_tokens=5000) is True

    def test_check_budget_exhausted(self):
        """A fully spent thread makes check_budget raise BudgetExhaustedError."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=1000)
        reg.consume("thread-1", prompt_tokens=1000)

        with pytest.raises(BudgetExhaustedError) as exc_info:
            reg.check_budget("thread-1", estimated_tokens=100)

        err = exc_info.value
        assert "budget exhausted" in str(err)
        assert err.thread_id == "thread-1"
        assert err.used == 1000
        assert err.max_tokens == 1000

    def test_check_budget_would_exceed(self):
        """An estimate that would push usage past the limit also raises."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=1000)
        reg.consume("thread-1", prompt_tokens=600)

        with pytest.raises(BudgetExhaustedError):
            reg.check_budget("thread-1", estimated_tokens=500)

    def test_consume_returns_budget(self):
        """consume() hands back the updated ThreadBudget."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        updated = reg.consume("thread-1", prompt_tokens=100, completion_tokens=50)
        assert updated.total_tokens == 150
        assert updated.request_count == 1

    def test_get_usage(self):
        """get_usage() aggregates stats across multiple consume() calls."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        reg.consume("thread-1", prompt_tokens=500, completion_tokens=200)
        reg.consume("thread-1", prompt_tokens=300, completion_tokens=100)

        stats = reg.get_usage("thread-1")

        assert stats["prompt_tokens"] == 800
        assert stats["completion_tokens"] == 300
        assert stats["total_tokens"] == 1100
        assert stats["remaining"] == 8900
        assert stats["request_count"] == 2

    def test_get_all_usage(self):
        """get_all_usage() reports every tracked thread."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        reg.consume("thread-1", prompt_tokens=100)
        reg.consume("thread-2", prompt_tokens=200)

        everything = reg.get_all_usage()

        assert len(everything) == 2
        assert "thread-1" in everything
        assert "thread-2" in everything

    def test_reset_thread(self):
        """reset_thread() drops the thread's budget entirely."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        reg.consume("thread-1", prompt_tokens=500)
        reg.reset_thread("thread-1")

        # A subsequent lookup lazily creates a brand-new, empty budget.
        assert reg.get_budget("thread-1").total_tokens == 0

    def test_cleanup_thread(self):
        """cleanup_thread() returns the final budget once, then None."""
        reg = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        reg.consume("thread-1", prompt_tokens=500)

        final = reg.cleanup_thread("thread-1")

        assert final.total_tokens == 500
        assert reg.cleanup_thread("thread-1") is None  # Already cleaned

    def test_global_registry(self):
        """get_budget_registry() always yields the same singleton object."""
        assert get_budget_registry() is get_budget_registry()

    def test_global_configure(self):
        """configure_budget_registry() reconfigures the global singleton."""
        configure_budget_registry(max_tokens_per_thread=75000)
        fresh = get_budget_registry().get_budget("new-thread")
        assert fresh.max_tokens == 75000
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# UsageTracker Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestUsageTracker:
    """Behavioural tests for UsageTracker (billing/metering)."""

    @pytest.fixture(autouse=True)
    def reset(self):
        """Tear down the process-global tracker around every test."""
        reset_usage_tracker()
        yield
        reset_usage_tracker()

    def test_record_creates_event(self):
        """record() builds a UsageEvent from the call arguments and returns it."""
        tracker = UsageTracker()

        evt = tracker.record(
            thread_id="thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=500,
            completion_tokens=200,
            latency_ms=150.5,
        )

        assert evt.thread_id == "thread-1"
        assert evt.agent_id == "greeter"
        assert evt.model == "grok-4.1"
        assert evt.total_tokens == 700
        assert evt.timestamp is not None

    def test_record_estimates_cost(self):
        """record() attaches an estimated dollar cost for known models."""
        tracker = UsageTracker()

        evt = tracker.record(
            thread_id="thread-1",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1_000_000,  # 1M prompt
            completion_tokens=1_000_000,  # 1M completion
            latency_ms=1000,
        )

        # grok-4.1 pricing: $3/1M prompt + $15/1M completion = $18
        assert evt.estimated_cost == 18.0

    def test_subscriber_receives_events(self):
        """Each record() call is pushed to every subscriber."""
        tracker = UsageTracker()
        received = []
        tracker.subscribe(received.append)

        tracker.record(
            thread_id="t1",
            agent_id="agent",
            model="gpt-4o",
            provider="openai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=50,
        )

        assert len(received) == 1
        assert received[0].thread_id == "t1"

    def test_unsubscribe(self):
        """After unsubscribe, the callback sees no further events."""
        tracker = UsageTracker()
        received = []
        callback = received.append

        tracker.subscribe(callback)
        tracker.record(thread_id="t1", agent_id=None, model="m", provider="p",
                       prompt_tokens=10, completion_tokens=10, latency_ms=10)

        tracker.unsubscribe(callback)
        tracker.record(thread_id="t2", agent_id=None, model="m", provider="p",
                       prompt_tokens=10, completion_tokens=10, latency_ms=10)

        # Only the event recorded while subscribed arrived.
        assert len(received) == 1

    def test_get_totals(self):
        """get_totals() aggregates tokens, counts and average latency."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="a1", model="m1", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="a2", model="m2", provider="p",
                       prompt_tokens=200, completion_tokens=100, latency_ms=200)

        summary = tracker.get_totals()

        assert summary["prompt_tokens"] == 300
        assert summary["completion_tokens"] == 150
        assert summary["total_tokens"] == 450
        assert summary["request_count"] == 2
        assert summary["avg_latency_ms"] == 150.0

    def test_get_agent_totals(self):
        """get_agent_totals() buckets usage per agent id."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="greeter", model="m", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="greeter", model="m", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t3", agent_id="shouter", model="m", provider="p",
                       prompt_tokens=200, completion_tokens=100, latency_ms=200)

        greeter_stats = tracker.get_agent_totals("greeter")
        shouter_stats = tracker.get_agent_totals("shouter")

        # Two small greeter calls equal one large shouter call in tokens.
        assert greeter_stats["total_tokens"] == 300
        assert greeter_stats["request_count"] == 2
        assert shouter_stats["total_tokens"] == 300
        assert shouter_stats["request_count"] == 1

    def test_get_model_totals(self):
        """get_model_totals() buckets usage per model name."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="a", model="grok-4.1", provider="xai",
                       prompt_tokens=1000, completion_tokens=500, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="a", model="claude-sonnet-4", provider="anthropic",
                       prompt_tokens=500, completion_tokens=250, latency_ms=100)

        assert tracker.get_model_totals("grok-4.1")["total_tokens"] == 1500
        assert tracker.get_model_totals("claude-sonnet-4")["total_tokens"] == 750

    def test_metadata_passed_through(self):
        """Arbitrary metadata supplied to record() rides along on the event."""
        tracker = UsageTracker()
        received = []
        tracker.subscribe(received.append)

        tracker.record(
            thread_id="t1",
            agent_id="a",
            model="m",
            provider="p",
            prompt_tokens=10,
            completion_tokens=10,
            latency_ms=10,
            metadata={"org_id": "org-123", "user_id": "user-456"},
        )

        assert received[0].metadata["org_id"] == "org-123"
        assert received[0].metadata["user_id"] == "user-456"
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Cost Estimation Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestCostEstimation:
    """Pricing-table checks for estimate_cost()."""

    def test_grok_cost(self):
        """Grok pricing: $3/1M prompt + $15/1M completion."""
        dollars = estimate_cost("grok-4.1", prompt_tokens=1_000_000, completion_tokens=1_000_000)
        assert dollars == 18.0

    def test_claude_opus_cost(self):
        """Claude Opus pricing: $15/1M prompt + $75/1M completion."""
        dollars = estimate_cost("claude-opus-4", prompt_tokens=1_000_000, completion_tokens=1_000_000)
        assert dollars == 90.0

    def test_gpt4o_cost(self):
        """GPT-4o pricing: $2.5/1M prompt + $10/1M completion."""
        dollars = estimate_cost("gpt-4o", prompt_tokens=1_000_000, completion_tokens=1_000_000)
        assert dollars == 12.5

    def test_unknown_model_returns_none(self):
        """Models missing from the pricing table yield None, not a guess."""
        assert estimate_cost("unknown-model", prompt_tokens=1000, completion_tokens=500) is None

    def test_small_usage_cost(self):
        """Tiny token counts produce proportionally tiny fractional costs."""
        dollars = estimate_cost("gpt-4o-mini", prompt_tokens=1000, completion_tokens=500)
        # 1000 * $0.15/1M = $0.00015; 500 * $0.6/1M = $0.0003; total $0.00045.
        assert dollars == pytest.approx(0.00045, rel=1e-4)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# LLMRouter Integration Tests (Mocked)
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestLLMRouterBudgetIntegration:
    """Test LLMRouter budget enforcement and usage-event emission (mocked backend).

    The mock-backend construction was previously copy-pasted into three tests
    (with the third copy inconsistently missing ``provider``/``load``); it is
    factored into ``_make_mock_backend`` so all tests share one definition.
    """

    @pytest.fixture(autouse=True)
    def reset_all(self):
        """Reset all global registries around each test."""
        reset_budget_registry()
        reset_usage_tracker()
        yield
        reset_budget_registry()
        reset_usage_tracker()

    @staticmethod
    def _make_mock_backend(with_response=True):
        """Build a router-compatible mock backend.

        Args:
            with_response: When True, wire ``complete`` as an AsyncMock that
                returns a canned LLMResponse (150 total tokens). When False,
                leave ``complete`` as a bare auto-attribute so tests can assert
                it was never called.

        Returns:
            A Mock configured like a backend the router can select.
        """
        from xml_pipeline.llm.backend import LLMResponse

        backend = Mock()
        backend.name = "mock"
        backend.provider = "test"
        backend.serves_model = Mock(return_value=True)
        backend.priority = 1
        backend.load = 0
        if with_response:
            backend.complete = AsyncMock(return_value=LLMResponse(
                content="Hello!",
                model="test-model",
                usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
                finish_reason="stop",
            ))
        return backend

    @pytest.mark.asyncio
    async def test_complete_consumes_budget(self):
        """LLM complete should consume from thread budget."""
        from xml_pipeline.llm.router import LLMRouter

        mock_backend = self._make_mock_backend()

        # Configure budget
        configure_budget_registry(max_tokens_per_thread=10000)
        budget_registry = get_budget_registry()

        # Create router with mock backend
        router = LLMRouter()
        router.backends.append(mock_backend)

        # Make request
        response = await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            thread_id="test-thread-123",
        )

        assert response.content == "Hello!"

        # Verify budget consumed exactly what the mocked response reported
        usage = budget_registry.get_usage("test-thread-123")
        assert usage["prompt_tokens"] == 100
        assert usage["completion_tokens"] == 50
        assert usage["total_tokens"] == 150

    @pytest.mark.asyncio
    async def test_complete_emits_usage_event(self):
        """LLM complete should emit a usage event carrying agent + metadata."""
        from xml_pipeline.llm.router import LLMRouter

        mock_backend = self._make_mock_backend()

        # Subscribe to usage events
        tracker = get_usage_tracker()
        received_events = []
        tracker.subscribe(received_events.append)

        # Create router and make request
        router = LLMRouter()
        router.backends.append(mock_backend)

        await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            thread_id="test-thread",
            agent_id="greeter",
            metadata={"org_id": "test-org"},
        )

        # Verify event emitted
        assert len(received_events) == 1
        event = received_events[0]
        assert event.thread_id == "test-thread"
        assert event.agent_id == "greeter"
        assert event.total_tokens == 150
        assert event.metadata["org_id"] == "test-org"

    @pytest.mark.asyncio
    async def test_complete_raises_when_budget_exhausted(self):
        """LLM complete should raise before touching the backend when over budget."""
        from xml_pipeline.llm.router import LLMRouter

        # Configure small budget and exhaust it
        configure_budget_registry(max_tokens_per_thread=100)
        budget_registry = get_budget_registry()
        budget_registry.consume("test-thread", prompt_tokens=100)

        # No canned response wired: the backend must never be invoked.
        mock_backend = self._make_mock_backend(with_response=False)

        router = LLMRouter()
        router.backends.append(mock_backend)

        with pytest.raises(BudgetExhaustedError) as exc_info:
            await router.complete(
                model="test-model",
                messages=[{"role": "user", "content": "Hi"}],
                thread_id="test-thread",
            )

        assert "budget exhausted" in str(exc_info.value)
        # Backend should NOT have been called
        mock_backend.complete.assert_not_called()

    @pytest.mark.asyncio
    async def test_complete_without_thread_id_skips_budget(self):
        """LLM complete without thread_id should skip budget checking."""
        from xml_pipeline.llm.router import LLMRouter

        mock_backend = self._make_mock_backend()

        router = LLMRouter()
        router.backends.append(mock_backend)

        # Should not raise - no budget checking without a thread_id
        response = await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            # No thread_id
        )

        assert response.content == "Hello!"
|
||||||
|
|
@ -16,7 +16,20 @@ Usage:
|
||||||
response = await router.complete(
|
response = await router.complete(
|
||||||
model="grok-4.1",
|
model="grok-4.1",
|
||||||
messages=[{"role": "user", "content": "Hello"}],
|
messages=[{"role": "user", "content": "Hello"}],
|
||||||
|
thread_id=metadata.thread_id, # For budget enforcement
|
||||||
|
agent_id=metadata.own_name, # For usage tracking
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Usage Tracking:
|
||||||
|
from xml_pipeline.llm import get_usage_tracker
|
||||||
|
|
||||||
|
tracker = get_usage_tracker()
|
||||||
|
|
||||||
|
# Subscribe to events for billing
|
||||||
|
tracker.subscribe(lambda event: billing_api.record(event))
|
||||||
|
|
||||||
|
# Query totals
|
||||||
|
totals = tracker.get_totals()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from xml_pipeline.llm.router import (
|
from xml_pipeline.llm.router import (
|
||||||
|
|
@ -27,14 +40,27 @@ from xml_pipeline.llm.router import (
|
||||||
Strategy,
|
Strategy,
|
||||||
)
|
)
|
||||||
from xml_pipeline.llm.backend import LLMRequest, LLMResponse, BackendError
|
from xml_pipeline.llm.backend import LLMRequest, LLMResponse, BackendError
|
||||||
|
from xml_pipeline.llm.usage_tracker import (
|
||||||
|
UsageTracker,
|
||||||
|
UsageEvent,
|
||||||
|
get_usage_tracker,
|
||||||
|
reset_usage_tracker,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Router
|
||||||
"LLMRouter",
|
"LLMRouter",
|
||||||
"get_router",
|
"get_router",
|
||||||
"configure_router",
|
"configure_router",
|
||||||
"complete",
|
"complete",
|
||||||
"Strategy",
|
"Strategy",
|
||||||
|
# Backend
|
||||||
"LLMRequest",
|
"LLMRequest",
|
||||||
"LLMResponse",
|
"LLMResponse",
|
||||||
"BackendError",
|
"BackendError",
|
||||||
|
# Usage tracking
|
||||||
|
"UsageTracker",
|
||||||
|
"UsageEvent",
|
||||||
|
"get_usage_tracker",
|
||||||
|
"reset_usage_tracker",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,8 @@ The router handles:
|
||||||
- Load balancing (failover, round-robin, least-loaded)
|
- Load balancing (failover, round-robin, least-loaded)
|
||||||
- Retries with exponential backoff
|
- Retries with exponential backoff
|
||||||
- Token tracking per agent
|
- Token tracking per agent
|
||||||
|
- Thread budget enforcement
|
||||||
|
- Usage event emission for billing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -16,6 +18,7 @@ from __future__ import annotations
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
import time
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
@ -125,6 +128,8 @@ class LLMRouter:
|
||||||
max_tokens: int = None,
|
max_tokens: int = None,
|
||||||
tools: List[Dict] = None,
|
tools: List[Dict] = None,
|
||||||
agent_id: str = None,
|
agent_id: str = None,
|
||||||
|
thread_id: str = None,
|
||||||
|
metadata: Dict[str, Any] = None,
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
"""
|
"""
|
||||||
Execute a completion request.
|
Execute a completion request.
|
||||||
|
|
@ -136,10 +141,27 @@ class LLMRouter:
|
||||||
max_tokens: Max tokens in response
|
max_tokens: Max tokens in response
|
||||||
tools: Tool definitions for function calling
|
tools: Tool definitions for function calling
|
||||||
agent_id: Optional agent ID for usage tracking
|
agent_id: Optional agent ID for usage tracking
|
||||||
|
thread_id: Optional thread ID for budget enforcement
|
||||||
|
metadata: Optional metadata for usage events (org_id, user_id, etc.)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LLMResponse with content and usage stats
|
LLMResponse with content and usage stats
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
BudgetExhaustedError: If thread has no remaining budget
|
||||||
|
BackendError: If all backends fail
|
||||||
"""
|
"""
|
||||||
|
# Estimate tokens for budget check (rough: 4 chars per token)
|
||||||
|
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
|
||||||
|
estimated_tokens = max(estimated_tokens, 100) # minimum estimate
|
||||||
|
|
||||||
|
# Check thread budget before proceeding
|
||||||
|
if thread_id:
|
||||||
|
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
||||||
|
budget_registry = get_budget_registry()
|
||||||
|
# This raises BudgetExhaustedError if over budget
|
||||||
|
budget_registry.check_budget(thread_id, estimated_tokens)
|
||||||
|
|
||||||
candidates = self._find_backends(model)
|
candidates = self._find_backends(model)
|
||||||
request = LLMRequest(
|
request = LLMRequest(
|
||||||
model=model,
|
model=model,
|
||||||
|
|
@ -151,6 +173,7 @@ class LLMRouter:
|
||||||
|
|
||||||
last_error = None
|
last_error = None
|
||||||
tried_backends = set()
|
tried_backends = set()
|
||||||
|
start_time = time.monotonic()
|
||||||
|
|
||||||
for attempt in range(self.retries + 1):
|
for attempt in range(self.retries + 1):
|
||||||
# Select backend (different selection on retry for failover)
|
# Select backend (different selection on retry for failover)
|
||||||
|
|
@ -170,14 +193,46 @@ class LLMRouter:
|
||||||
logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})")
|
logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})")
|
||||||
response = await backend.complete(request)
|
response = await backend.complete(request)
|
||||||
|
|
||||||
# Track usage
|
# Calculate latency
|
||||||
|
latency_ms = (time.monotonic() - start_time) * 1000
|
||||||
|
|
||||||
|
# Extract usage
|
||||||
|
prompt_tokens = response.usage.get("prompt_tokens", 0)
|
||||||
|
completion_tokens = response.usage.get("completion_tokens", 0)
|
||||||
|
total_tokens = response.usage.get("total_tokens", 0)
|
||||||
|
|
||||||
|
# Track per-agent usage (internal)
|
||||||
if agent_id:
|
if agent_id:
|
||||||
usage = self._agent_usage.setdefault(agent_id, AgentUsage())
|
usage = self._agent_usage.setdefault(agent_id, AgentUsage())
|
||||||
usage.total_tokens += response.usage.get("total_tokens", 0)
|
usage.total_tokens += total_tokens
|
||||||
usage.prompt_tokens += response.usage.get("prompt_tokens", 0)
|
usage.prompt_tokens += prompt_tokens
|
||||||
usage.completion_tokens += response.usage.get("completion_tokens", 0)
|
usage.completion_tokens += completion_tokens
|
||||||
usage.request_count += 1
|
usage.request_count += 1
|
||||||
|
|
||||||
|
# Record to thread budget (enforcement)
|
||||||
|
if thread_id:
|
||||||
|
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
||||||
|
budget_registry = get_budget_registry()
|
||||||
|
budget_registry.consume(
|
||||||
|
thread_id,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Emit usage event (for billing)
|
||||||
|
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
||||||
|
tracker = get_usage_tracker()
|
||||||
|
tracker.record(
|
||||||
|
thread_id=thread_id or "",
|
||||||
|
agent_id=agent_id,
|
||||||
|
model=response.model,
|
||||||
|
provider=backend.provider,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
|
completion_tokens=completion_tokens,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
except RateLimitError as e:
|
except RateLimitError as e:
|
||||||
|
|
@ -286,6 +341,10 @@ def configure_router(config: Dict[str, Any]) -> LLMRouter:
|
||||||
async def complete(
    model: str,
    messages: List[Dict[str, str]],
    *,
    thread_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> LLMResponse:
    """
    Module-level convenience wrapper that delegates to the global router.

    Usage:
        from xml_pipeline.llm import router
        response = await router.complete(
            "grok-4.1",
            messages,
            thread_id=metadata.thread_id,
            agent_id=metadata.own_name,
        )

    Args:
        model: Model name
        messages: Chat messages
        thread_id: Thread UUID for budget enforcement
        agent_id: Agent name for usage tracking
        metadata: Extra metadata for billing events
        **kwargs: Additional arguments (temperature, max_tokens, tools)

    Returns:
        LLMResponse with content and usage stats

    Raises:
        BudgetExhaustedError: If thread budget exhausted
    """
    # NOTE: the keyword-only params were previously annotated `str = None` /
    # `Dict[...] = None`; implicit Optional is invalid under PEP 484, so the
    # annotations are now explicit `Optional[...]`. Runtime behavior unchanged.
    return await get_router().complete(
        model,
        messages,
        thread_id=thread_id,
        agent_id=agent_id,
        metadata=metadata,
        **kwargs,
    )
|
||||||
|
|
|
||||||
346
xml_pipeline/llm/usage_tracker.py
Normal file
346
xml_pipeline/llm/usage_tracker.py
Normal file
|
|
@ -0,0 +1,346 @@
|
||||||
|
"""
|
||||||
|
Usage Tracker — Production billing and gas usage metering.
|
||||||
|
|
||||||
|
This module provides hooks for tracking LLM usage at the platform level.
|
||||||
|
External billing systems can subscribe to usage events for metering.
|
||||||
|
|
||||||
|
Usage Tracking Layers:
|
||||||
|
1. Per-agent (LLMRouter._agent_usage) — Internal token tracking
|
||||||
|
2. Per-thread (ThreadBudgetRegistry) — Enforcement limits
|
||||||
|
3. Platform (UsageTracker) — Production billing/metering
|
||||||
|
|
||||||
|
Example:
|
||||||
|
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
||||||
|
|
||||||
|
tracker = get_usage_tracker()
|
||||||
|
|
||||||
|
# Subscribe to usage events (for billing webhook, database, etc.)
|
||||||
|
def record_usage(event: UsageEvent):
|
||||||
|
billing_db.record(
|
||||||
|
org_id=event.metadata.get("org_id"),
|
||||||
|
tokens=event.total_tokens,
|
||||||
|
cost=event.estimated_cost,
|
||||||
|
)
|
||||||
|
|
||||||
|
tracker.subscribe(record_usage)
|
||||||
|
|
||||||
|
# Query aggregate usage
|
||||||
|
totals = tracker.get_totals()
|
||||||
|
print(f"Total tokens: {totals['total_tokens']}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any, Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class UsageEvent:
    """
    Usage event emitted after each LLM completion.

    This is the main interface for billing systems: UsageTracker.record()
    builds one of these per completion and hands it to every subscriber.
    """

    # Request identification
    thread_id: str  # thread UUID; "" when the call had no thread context
    agent_id: Optional[str]  # agent name, if the caller supplied one
    model: str  # model name as requested
    provider: str  # backend provider that served the request

    # Token usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # prompt_tokens + completion_tokens

    # Timing
    timestamp: str  # ISO 8601 (UTC)
    latency_ms: float  # Request duration

    # Cost estimation (if available)
    estimated_cost: Optional[float] = None  # USD; None when model pricing is unknown

    # Extensible metadata (org_id, user_id, etc.)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Cost per 1M tokens for common models (approximate, update as needed).
# Keys act as model-name prefixes: estimate_cost() picks the longest
# case-insensitive prefix match, so "gpt-4o-mini" wins over "gpt-4o".
MODEL_COSTS: Dict[str, Dict[str, float]] = {
    # xAI Grok
    "grok-4.1": {"prompt": 3.0, "completion": 15.0},
    "grok-3": {"prompt": 3.0, "completion": 15.0},
    # Anthropic Claude
    "claude-opus-4": {"prompt": 15.0, "completion": 75.0},
    "claude-sonnet-4": {"prompt": 3.0, "completion": 15.0},
    "claude-sonnet-3-5": {"prompt": 3.0, "completion": 15.0},
    # OpenAI
    "gpt-4o": {"prompt": 2.5, "completion": 10.0},
    "gpt-4o-mini": {"prompt": 0.15, "completion": 0.6},
    "o1": {"prompt": 15.0, "completion": 60.0},
    "o3-mini": {"prompt": 1.1, "completion": 4.4},
}


def estimate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> Optional[float]:
    """
    Estimate cost in USD for a completion.

    Returns None if model pricing is unknown.
    """
    name = model.lower()

    # Collect every pricing entry whose prefix matches, then keep the
    # most specific one (longest prefix wins).
    matches = [
        (len(prefix), costs)
        for prefix, costs in MODEL_COSTS.items()
        if name.startswith(prefix.lower())
    ]
    if not matches:
        return None
    _, pricing = max(matches, key=lambda pair: pair[0])

    # Cost = (tokens / 1M) * cost_per_million
    usd = (
        (prompt_tokens / 1_000_000) * pricing["prompt"]
        + (completion_tokens / 1_000_000) * pricing["completion"]
    )
    return round(usd, 6)
|
||||||
|
|
||||||
|
|
||||||
|
# Signature for usage-event subscribers; invoked synchronously by
# UsageTracker.record() after each completion.
UsageCallback = Callable[[UsageEvent], None]


@dataclass
class UsageTotals:
    """Aggregate usage statistics (one bucket per scope: global/agent/model)."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0
    total_cost: float = 0.0  # sum of estimated costs (USD); events with unknown pricing add nothing
    total_latency_ms: float = 0.0  # sum of per-request latencies
|
||||||
|
|
||||||
|
|
||||||
|
class UsageTracker:
    """
    Platform-level usage tracking for billing and metering.

    Thread-safe. Supports multiple subscribers for real-time event streaming.

    Integration points:
    - Webhook to billing API
    - Database for usage records
    - Metrics/observability (Prometheus, DataDog)
    - Real-time dashboard (WebSocket)
    """

    def __init__(self):
        self._callbacks: List[UsageCallback] = []
        self._lock = threading.Lock()

        # Aggregate tracking: one global bucket plus per-agent / per-model maps.
        self._totals = UsageTotals()
        self._per_agent: Dict[str, UsageTotals] = {}
        self._per_model: Dict[str, UsageTotals] = {}

    def subscribe(self, callback: UsageCallback) -> None:
        """
        Subscribe to usage events.

        Callbacks are invoked synchronously after each LLM completion.
        For async processing, use a queue in your callback.
        """
        with self._lock:
            self._callbacks.append(callback)

    def unsubscribe(self, callback: UsageCallback) -> None:
        """Unsubscribe from usage events (no-op if not subscribed)."""
        with self._lock:
            if callback in self._callbacks:
                self._callbacks.remove(callback)

    def record(
        self,
        thread_id: str,
        agent_id: Optional[str],
        model: str,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int,
        latency_ms: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> UsageEvent:
        """
        Record a usage event and notify subscribers.

        Called by LLMRouter after each completion.

        Returns:
            The created UsageEvent (for chaining/logging)
        """
        total_tokens = prompt_tokens + completion_tokens

        event = UsageEvent(
            thread_id=thread_id,
            agent_id=agent_id,
            model=model,
            provider=provider,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            timestamp=datetime.now(timezone.utc).isoformat(),
            latency_ms=latency_ms,
            estimated_cost=estimate_cost(model, prompt_tokens, completion_tokens),
            metadata=metadata or {},
        )

        # Update aggregates under the lock.
        with self._lock:
            self._update_totals(self._totals, event)

            if agent_id:
                self._update_totals(
                    self._per_agent.setdefault(agent_id, UsageTotals()), event
                )

            self._update_totals(
                self._per_model.setdefault(model, UsageTotals()), event
            )

            # Copy callbacks to avoid holding lock during invocation.
            callbacks = list(self._callbacks)

        # Notify subscribers (outside lock).
        for callback in callbacks:
            try:
                callback(event)
            except Exception:
                # Don't let subscriber errors break tracking (deliberate
                # best-effort: billing hooks must never fail the LLM path).
                pass

        return event

    def _update_totals(self, totals: UsageTotals, event: UsageEvent) -> None:
        """Fold one event into an aggregate bucket. Caller must hold self._lock."""
        totals.total_tokens += event.total_tokens
        totals.prompt_tokens += event.prompt_tokens
        totals.completion_tokens += event.completion_tokens
        totals.request_count += 1
        totals.total_latency_ms += event.latency_ms
        if event.estimated_cost:
            totals.total_cost += event.estimated_cost

    @staticmethod
    def _stats_dict(totals: UsageTotals) -> Dict[str, Any]:
        """
        Serialize a UsageTotals bucket to the public stats shape.

        Single source of truth for the dict layout that was previously
        duplicated across five getter methods.
        """
        return {
            "total_tokens": totals.total_tokens,
            "prompt_tokens": totals.prompt_tokens,
            "completion_tokens": totals.completion_tokens,
            "request_count": totals.request_count,
            "total_cost": round(totals.total_cost, 4),
        }

    def get_totals(self) -> Dict[str, Any]:
        """Get aggregate usage totals (includes average latency)."""
        with self._lock:
            stats = self._stats_dict(self._totals)
            stats["avg_latency_ms"] = (
                self._totals.total_latency_ms / self._totals.request_count
                if self._totals.request_count > 0
                else 0
            )
            return stats

    def get_agent_totals(self, agent_id: str) -> Dict[str, Any]:
        """Get usage totals for a specific agent (all-zero dict if unknown)."""
        with self._lock:
            return self._stats_dict(self._per_agent.get(agent_id, UsageTotals()))

    def get_model_totals(self, model: str) -> Dict[str, Any]:
        """Get usage totals for a specific model (all-zero dict if unknown)."""
        with self._lock:
            return self._stats_dict(self._per_model.get(model, UsageTotals()))

    def get_all_agent_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all agents."""
        with self._lock:
            return {
                agent_id: self._stats_dict(t)
                for agent_id, t in self._per_agent.items()
            }

    def get_all_model_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all models."""
        with self._lock:
            return {
                model: self._stats_dict(t) for model, t in self._per_model.items()
            }

    def reset(self) -> None:
        """Reset all tracking (for testing)."""
        with self._lock:
            self._totals = UsageTotals()
            self._per_agent.clear()
            self._per_model.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
# Global Instance
# =============================================================================

_tracker: Optional[UsageTracker] = None
_tracker_lock = threading.Lock()


def get_usage_tracker() -> UsageTracker:
    """Return the process-wide UsageTracker, creating it lazily (double-checked locking)."""
    global _tracker
    tracker = _tracker
    if tracker is None:
        with _tracker_lock:
            if _tracker is None:
                _tracker = UsageTracker()
            tracker = _tracker
    return tracker


def reset_usage_tracker() -> None:
    """Reset the global tracker (for testing)."""
    global _tracker
    with _tracker_lock:
        if _tracker is None:
            return
        _tracker.reset()
        _tracker = None
|
||||||
|
|
@ -67,6 +67,15 @@ from xml_pipeline.message_bus.buffer_registry import (
|
||||||
reset_buffer_registry,
|
reset_buffer_registry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from xml_pipeline.message_bus.budget_registry import (
|
||||||
|
ThreadBudget,
|
||||||
|
ThreadBudgetRegistry,
|
||||||
|
BudgetExhaustedError,
|
||||||
|
get_budget_registry,
|
||||||
|
configure_budget_registry,
|
||||||
|
reset_budget_registry,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Pump
|
# Pump
|
||||||
"StreamPump",
|
"StreamPump",
|
||||||
|
|
@ -102,4 +111,11 @@ __all__ = [
|
||||||
"BufferRegistry",
|
"BufferRegistry",
|
||||||
"get_buffer_registry",
|
"get_buffer_registry",
|
||||||
"reset_buffer_registry",
|
"reset_buffer_registry",
|
||||||
|
# Budget registry
|
||||||
|
"ThreadBudget",
|
||||||
|
"ThreadBudgetRegistry",
|
||||||
|
"BudgetExhaustedError",
|
||||||
|
"get_budget_registry",
|
||||||
|
"configure_budget_registry",
|
||||||
|
"reset_budget_registry",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
280
xml_pipeline/message_bus/budget_registry.py
Normal file
280
xml_pipeline/message_bus/budget_registry.py
Normal file
|
|
@ -0,0 +1,280 @@
|
||||||
|
"""
|
||||||
|
Thread Budget Registry — Enforces per-thread token limits.
|
||||||
|
|
||||||
|
Each thread has a token budget that tracks:
|
||||||
|
- Total tokens consumed (prompt + completion)
|
||||||
|
- Requests made
|
||||||
|
- Remaining budget
|
||||||
|
|
||||||
|
When a thread exhausts its budget, LLM calls are blocked.
|
||||||
|
|
||||||
|
Example config:
|
||||||
|
organism:
|
||||||
|
max_tokens_per_thread: 100000 # 100k tokens per thread
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ThreadBudget:
    """Track token usage for a single thread against a fixed cap."""

    max_tokens: int
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed so far (prompt + completion)."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def remaining(self) -> int:
        """Remaining token budget; never negative."""
        left = self.max_tokens - self.total_tokens
        return left if left > 0 else 0

    @property
    def is_exhausted(self) -> bool:
        """True once consumption has reached or passed the cap."""
        return not self.total_tokens < self.max_tokens

    def can_consume(self, estimated_tokens: int) -> bool:
        """Would consuming estimated_tokens stay within the budget?"""
        return self.total_tokens + estimated_tokens <= self.max_tokens

    def consume(
        self,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> None:
        """Record token consumption; every call counts as one request."""
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.request_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetExhaustedError(Exception):
    """Raised when a thread's token budget is exhausted."""

    def __init__(self, thread_id: str, used: int, max_tokens: int):
        # Build the message first, then stash the raw numbers so callers
        # can inspect them programmatically.
        message = (
            f"Thread {thread_id[:8]}... budget exhausted: "
            f"{used}/{max_tokens} tokens used"
        )
        super().__init__(message)
        self.thread_id = thread_id
        self.used = used
        self.max_tokens = max_tokens
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadBudgetRegistry:
    """
    Manages token budgets per thread.

    Thread-safe for concurrent access.

    Usage:
        registry = get_budget_registry()
        registry.configure(max_tokens_per_thread=100000)

        # Before LLM call
        registry.check_budget(thread_id, estimated_tokens=1000)

        # After LLM call
        registry.consume(thread_id, prompt_tokens=500, completion_tokens=300)

        # Get usage
        budget = registry.get_budget(thread_id)
        print(f"Used: {budget.total_tokens}, Remaining: {budget.remaining}")
    """

    def __init__(self, max_tokens_per_thread: int = 100_000):
        """
        Initialize budget registry.

        Args:
            max_tokens_per_thread: Default budget for new threads.
        """
        self._max_tokens_per_thread = max_tokens_per_thread
        self._budgets: Dict[str, ThreadBudget] = {}
        self._lock = threading.Lock()

    def configure(self, max_tokens_per_thread: int) -> None:
        """
        Update default max tokens for new threads.

        Existing threads keep their current budgets.
        """
        with self._lock:
            self._max_tokens_per_thread = max_tokens_per_thread

    @property
    def max_tokens_per_thread(self) -> int:
        """Get the default max tokens per thread."""
        return self._max_tokens_per_thread

    def _get_or_create(self, thread_id: str) -> ThreadBudget:
        """Look up or create the budget for a thread. Caller must hold self._lock."""
        budget = self._budgets.get(thread_id)
        if budget is None:
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        return budget

    def get_budget(self, thread_id: str) -> ThreadBudget:
        """
        Get or create budget for a thread.

        Args:
            thread_id: Thread UUID

        Returns:
            ThreadBudget instance
        """
        with self._lock:
            return self._get_or_create(thread_id)

    def check_budget(
        self,
        thread_id: str,
        estimated_tokens: int = 0,
    ) -> bool:
        """
        Check if thread has budget for the estimated tokens.

        Args:
            thread_id: Thread UUID
            estimated_tokens: Estimated tokens for the request

        Returns:
            True if budget available

        Raises:
            BudgetExhaustedError: If the budget is already exhausted, or the
                estimate would push the thread over its cap.
        """
        budget = self.get_budget(thread_id)

        # Both failure modes raise the same error carrying current usage;
        # the is_exhausted check also covers estimated_tokens == 0 at the
        # exact cap, which can_consume alone would allow.
        if budget.is_exhausted or not budget.can_consume(estimated_tokens):
            raise BudgetExhaustedError(
                thread_id=thread_id,
                used=budget.total_tokens,
                max_tokens=budget.max_tokens,
            )
        return True

    def consume(
        self,
        thread_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> ThreadBudget:
        """
        Record token consumption for a thread.

        Args:
            thread_id: Thread UUID
            prompt_tokens: Prompt tokens used
            completion_tokens: Completion tokens used

        Returns:
            Updated ThreadBudget
        """
        # Single critical section: the previous version acquired the lock
        # once inside get_budget and a second time around the mutation,
        # leaving an unlocked window between lookup and update.
        with self._lock:
            budget = self._get_or_create(thread_id)
            budget.consume(prompt_tokens, completion_tokens)
            return budget

    @staticmethod
    def _usage_dict(budget: ThreadBudget) -> Dict[str, int]:
        """Serialize a ThreadBudget to the public usage-stats shape."""
        return {
            "prompt_tokens": budget.prompt_tokens,
            "completion_tokens": budget.completion_tokens,
            "total_tokens": budget.total_tokens,
            "remaining": budget.remaining,
            "max_tokens": budget.max_tokens,
            "request_count": budget.request_count,
        }

    def get_usage(self, thread_id: str) -> Dict[str, int]:
        """
        Get usage stats for a thread.

        Returns:
            Dict with prompt_tokens, completion_tokens, total_tokens,
            remaining, max_tokens, request_count
        """
        # Snapshot under the lock so the fields are mutually consistent
        # (the previous version read them unlocked, unlike get_all_usage).
        with self._lock:
            return self._usage_dict(self._get_or_create(thread_id))

    def get_all_usage(self) -> Dict[str, Dict[str, int]]:
        """Get usage stats for all threads."""
        with self._lock:
            return {
                thread_id: self._usage_dict(b)
                for thread_id, b in self._budgets.items()
            }

    def reset_thread(self, thread_id: str) -> None:
        """Reset budget for a specific thread."""
        with self._lock:
            self._budgets.pop(thread_id, None)

    def cleanup_thread(self, thread_id: str) -> Optional[ThreadBudget]:
        """
        Remove budget when thread is pruned/completed.

        Returns the final budget for logging/billing, or None if not found.
        """
        with self._lock:
            return self._budgets.pop(thread_id, None)

    def clear(self) -> None:
        """Clear all budgets (for testing)."""
        with self._lock:
            self._budgets.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
# Global Instance
# =============================================================================

_registry: Optional[ThreadBudgetRegistry] = None
_registry_lock = threading.Lock()


def get_budget_registry() -> ThreadBudgetRegistry:
    """Return the process-wide ThreadBudgetRegistry, creating it lazily (double-checked locking)."""
    global _registry
    registry = _registry
    if registry is None:
        with _registry_lock:
            if _registry is None:
                _registry = ThreadBudgetRegistry()
            registry = _registry
    return registry


def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry:
    """Configure the global budget registry and return it."""
    registry = get_budget_registry()
    registry.configure(max_tokens_per_thread)
    return registry


def reset_budget_registry() -> None:
    """Reset the global registry (for testing)."""
    global _registry
    with _registry_lock:
        if _registry is None:
            return
        _registry.clear()
        _registry = None
|
||||||
|
|
@ -141,6 +141,9 @@ class OrganismConfig:
|
||||||
max_concurrent_handlers: int = 20 # Concurrent handler invocations
|
max_concurrent_handlers: int = 20 # Concurrent handler invocations
|
||||||
max_concurrent_per_agent: int = 5 # Per-agent rate limit
|
max_concurrent_per_agent: int = 5 # Per-agent rate limit
|
||||||
|
|
||||||
|
# Token budget enforcement
|
||||||
|
max_tokens_per_thread: int = 100_000 # Max tokens per conversation thread
|
||||||
|
|
||||||
# LLM configuration (optional)
|
# LLM configuration (optional)
|
||||||
llm_config: Dict[str, Any] = field(default_factory=dict)
|
llm_config: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
@ -1271,6 +1274,7 @@ class ConfigLoader:
|
||||||
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
|
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
|
||||||
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||||
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||||
|
max_tokens_per_thread=raw.get("max_tokens_per_thread", 100_000),
|
||||||
llm_config=raw.get("llm", {}),
|
llm_config=raw.get("llm", {}),
|
||||||
process_pool_enabled=process_pool_enabled,
|
process_pool_enabled=process_pool_enabled,
|
||||||
process_pool_workers=process_pool_workers,
|
process_pool_workers=process_pool_workers,
|
||||||
|
|
@ -1430,6 +1434,11 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
||||||
configure_router(config.llm_config)
|
configure_router(config.llm_config)
|
||||||
print(f"LLM backends: {len(config.llm_config.get('backends', []))}")
|
print(f"LLM backends: {len(config.llm_config.get('backends', []))}")
|
||||||
|
|
||||||
|
# Configure thread budget registry
|
||||||
|
from xml_pipeline.message_bus.budget_registry import configure_budget_registry
|
||||||
|
configure_budget_registry(config.max_tokens_per_thread)
|
||||||
|
print(f"Token budget: {config.max_tokens_per_thread:,} per thread")
|
||||||
|
|
||||||
# Initialize root thread in registry
|
# Initialize root thread in registry
|
||||||
registry = get_registry()
|
registry = get_registry()
|
||||||
root_uuid = registry.initialize_root(config.name)
|
root_uuid = registry.initialize_root(config.name)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue