Add token budget enforcement and usage tracking
Token Budget System: - ThreadBudgetRegistry tracks per-thread token usage with configurable limits - BudgetExhaustedError raised when thread exceeds max_tokens_per_thread - Integrates with LLMRouter to block LLM calls when budget exhausted - Automatic cleanup when threads are pruned Usage Tracking (for production billing): - UsageTracker emits events after each LLM completion - Subscribers receive UsageEvent with tokens, latency, estimated cost - Cost estimation for common models (Grok, Claude, GPT, etc.) - Aggregate stats by agent, model, and totals Configuration: - max_tokens_per_thread in organism.yaml (default 100k) - LLMRouter.complete() accepts thread_id and metadata parameters Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
4530c06835
commit
8b11323a8b
7 changed files with 1341 additions and 6 deletions
573
tests/test_token_budget.py
Normal file
573
tests/test_token_budget.py
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
"""
|
||||
test_token_budget.py — Tests for token budget and usage tracking.
|
||||
|
||||
Tests:
|
||||
1. ThreadBudgetRegistry - per-thread token limits
|
||||
2. UsageTracker - billing/gas usage events
|
||||
3. LLMRouter integration - budget enforcement
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from xml_pipeline.message_bus.budget_registry import (
|
||||
ThreadBudget,
|
||||
ThreadBudgetRegistry,
|
||||
BudgetExhaustedError,
|
||||
get_budget_registry,
|
||||
configure_budget_registry,
|
||||
reset_budget_registry,
|
||||
)
|
||||
from xml_pipeline.llm.usage_tracker import (
|
||||
UsageEvent,
|
||||
UsageTracker,
|
||||
UsageTotals,
|
||||
estimate_cost,
|
||||
get_usage_tracker,
|
||||
reset_usage_tracker,
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ThreadBudget Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestThreadBudget:
    """Unit tests for the ThreadBudget dataclass."""

    def test_initial_state(self):
        """New budget should have zero usage."""
        fresh = ThreadBudget(max_tokens=10000)

        assert fresh.total_tokens == 0
        assert fresh.remaining == 10000
        assert fresh.is_exhausted is False

    def test_consume_tokens(self):
        """Consuming tokens should update totals."""
        tracked = ThreadBudget(max_tokens=10000)
        tracked.consume(prompt_tokens=500, completion_tokens=300)

        assert tracked.request_count == 1
        assert tracked.prompt_tokens == 500
        assert tracked.completion_tokens == 300
        assert tracked.total_tokens == 800
        assert tracked.remaining == 9200

    def test_can_consume_within_budget(self):
        """can_consume should return True if within budget."""
        tracked = ThreadBudget(max_tokens=1000)
        tracked.consume(prompt_tokens=400)

        # Exactly 600 tokens of headroom remain after the 400-token spend.
        assert tracked.can_consume(500) is True
        assert tracked.can_consume(600) is True
        assert tracked.can_consume(601) is False

    def test_is_exhausted(self):
        """is_exhausted should return True when budget exceeded."""
        tracked = ThreadBudget(max_tokens=1000)
        tracked.consume(prompt_tokens=1000)

        assert tracked.is_exhausted is True
        assert tracked.remaining == 0

    def test_remaining_never_negative(self):
        """remaining should never go negative."""
        tracked = ThreadBudget(max_tokens=100)
        # Overshoot the cap: usage keeps counting but remaining clamps at 0.
        tracked.consume(prompt_tokens=200)

        assert tracked.remaining == 0
        assert tracked.total_tokens == 200
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ThreadBudgetRegistry Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestThreadBudgetRegistry:
    """Test ThreadBudgetRegistry."""

    @pytest.fixture(autouse=True)
    def reset(self):
        """Reset global registry before each test."""
        # Reset on both sides of the test so a failing test cannot leak
        # state into its neighbours through the module-level singleton.
        reset_budget_registry()
        yield
        reset_budget_registry()

    def test_default_budget_creation(self):
        """Getting budget for new thread should create one."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=50000)
        budget = registry.get_budget("thread-1")

        # Fresh budget inherits the registry-wide default cap.
        assert budget.max_tokens == 50000
        assert budget.total_tokens == 0

    def test_configure_max_tokens(self):
        """configure() should update default for new threads."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        budget1 = registry.get_budget("thread-1")

        registry.configure(max_tokens_per_thread=20000)
        budget2 = registry.get_budget("thread-2")

        assert budget1.max_tokens == 10000  # Original unchanged
        assert budget2.max_tokens == 20000  # New default

    def test_check_budget_success(self):
        """check_budget should pass when within budget."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)

        result = registry.check_budget("thread-1", estimated_tokens=5000)
        assert result is True

    def test_check_budget_exhausted(self):
        """check_budget should raise when budget exhausted."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
        registry.consume("thread-1", prompt_tokens=1000)

        with pytest.raises(BudgetExhaustedError) as exc_info:
            registry.check_budget("thread-1", estimated_tokens=100)

        # The error carries enough context for callers to report the limit.
        assert "budget exhausted" in str(exc_info.value)
        assert exc_info.value.thread_id == "thread-1"
        assert exc_info.value.used == 1000
        assert exc_info.value.max_tokens == 1000

    def test_check_budget_would_exceed(self):
        """check_budget should raise when estimate would exceed."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
        registry.consume("thread-1", prompt_tokens=600)

        # 600 used + 500 estimated > 1000 cap, so the check must fail
        # even though the budget is not yet exhausted.
        with pytest.raises(BudgetExhaustedError):
            registry.check_budget("thread-1", estimated_tokens=500)

    def test_consume_returns_budget(self):
        """consume() should return updated budget."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)

        budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)

        assert budget.total_tokens == 150
        assert budget.request_count == 1

    def test_get_usage(self):
        """get_usage should return dict with all stats."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
        registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)

        usage = registry.get_usage("thread-1")

        # Stats accumulate across both consume() calls.
        assert usage["prompt_tokens"] == 800
        assert usage["completion_tokens"] == 300
        assert usage["total_tokens"] == 1100
        assert usage["remaining"] == 8900
        assert usage["request_count"] == 2

    def test_get_all_usage(self):
        """get_all_usage should return all threads."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        registry.consume("thread-1", prompt_tokens=100)
        registry.consume("thread-2", prompt_tokens=200)

        all_usage = registry.get_all_usage()

        assert len(all_usage) == 2
        assert "thread-1" in all_usage
        assert "thread-2" in all_usage

    def test_reset_thread(self):
        """reset_thread should remove budget for thread."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        registry.consume("thread-1", prompt_tokens=500)
        registry.reset_thread("thread-1")

        # Getting budget should create new one with zero usage
        budget = registry.get_budget("thread-1")
        assert budget.total_tokens == 0

    def test_cleanup_thread(self):
        """cleanup_thread should return and remove budget."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
        registry.consume("thread-1", prompt_tokens=500)

        final_budget = registry.cleanup_thread("thread-1")

        # Final snapshot is returned once; a second cleanup finds nothing.
        assert final_budget.total_tokens == 500
        assert registry.cleanup_thread("thread-1") is None  # Already cleaned

    def test_global_registry(self):
        """Global registry should be singleton."""
        registry1 = get_budget_registry()
        registry2 = get_budget_registry()

        assert registry1 is registry2

    def test_global_configure(self):
        """configure_budget_registry should update global."""
        configure_budget_registry(max_tokens_per_thread=75000)
        registry = get_budget_registry()

        budget = registry.get_budget("new-thread")
        assert budget.max_tokens == 75000
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# UsageTracker Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestUsageTracker:
    """Test UsageTracker for billing/metering."""

    @pytest.fixture(autouse=True)
    def reset(self):
        """Reset global tracker before each test."""
        # Reset before and after so singleton state never crosses tests.
        reset_usage_tracker()
        yield
        reset_usage_tracker()

    def test_record_creates_event(self):
        """record() should create and return UsageEvent."""
        tracker = UsageTracker()

        event = tracker.record(
            thread_id="thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=500,
            completion_tokens=200,
            latency_ms=150.5,
        )

        assert event.thread_id == "thread-1"
        assert event.agent_id == "greeter"
        assert event.model == "grok-4.1"
        # total_tokens is derived from prompt + completion (500 + 200).
        assert event.total_tokens == 700
        assert event.timestamp is not None

    def test_record_estimates_cost(self):
        """record() should estimate cost for known models."""
        tracker = UsageTracker()

        event = tracker.record(
            thread_id="thread-1",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1_000_000,  # 1M prompt
            completion_tokens=1_000_000,  # 1M completion
            latency_ms=1000,
        )

        # grok-4.1: $3/1M prompt + $15/1M completion = $18
        assert event.estimated_cost == 18.0

    def test_subscriber_receives_events(self):
        """Subscribers should receive events on record."""
        tracker = UsageTracker()
        received = []

        tracker.subscribe(lambda e: received.append(e))

        tracker.record(
            thread_id="t1",
            agent_id="agent",
            model="gpt-4o",
            provider="openai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=50,
        )

        assert len(received) == 1
        assert received[0].thread_id == "t1"

    def test_unsubscribe(self):
        """unsubscribe should stop receiving events."""
        tracker = UsageTracker()
        received = []
        callback = lambda e: received.append(e)

        tracker.subscribe(callback)
        tracker.record(thread_id="t1", agent_id=None, model="m", provider="p",
                       prompt_tokens=10, completion_tokens=10, latency_ms=10)

        tracker.unsubscribe(callback)
        tracker.record(thread_id="t2", agent_id=None, model="m", provider="p",
                       prompt_tokens=10, completion_tokens=10, latency_ms=10)

        # Only the event recorded before unsubscribe should have arrived.
        assert len(received) == 1

    def test_get_totals(self):
        """get_totals should return aggregate stats."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="a1", model="m1", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="a2", model="m2", provider="p",
                       prompt_tokens=200, completion_tokens=100, latency_ms=200)

        totals = tracker.get_totals()

        assert totals["prompt_tokens"] == 300
        assert totals["completion_tokens"] == 150
        assert totals["total_tokens"] == 450
        assert totals["request_count"] == 2
        # Mean of the two latencies: (100 + 200) / 2.
        assert totals["avg_latency_ms"] == 150.0

    def test_get_agent_totals(self):
        """get_agent_totals should return per-agent stats."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="greeter", model="m", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="greeter", model="m", provider="p",
                       prompt_tokens=100, completion_tokens=50, latency_ms=100)
        tracker.record(thread_id="t3", agent_id="shouter", model="m", provider="p",
                       prompt_tokens=200, completion_tokens=100, latency_ms=200)

        greeter = tracker.get_agent_totals("greeter")
        shouter = tracker.get_agent_totals("shouter")

        # Both agents total 300 tokens, but via different request counts.
        assert greeter["total_tokens"] == 300
        assert greeter["request_count"] == 2
        assert shouter["total_tokens"] == 300
        assert shouter["request_count"] == 1

    def test_get_model_totals(self):
        """get_model_totals should return per-model stats."""
        tracker = UsageTracker()

        tracker.record(thread_id="t1", agent_id="a", model="grok-4.1", provider="xai",
                       prompt_tokens=1000, completion_tokens=500, latency_ms=100)
        tracker.record(thread_id="t2", agent_id="a", model="claude-sonnet-4", provider="anthropic",
                       prompt_tokens=500, completion_tokens=250, latency_ms=100)

        grok = tracker.get_model_totals("grok-4.1")
        claude = tracker.get_model_totals("claude-sonnet-4")

        assert grok["total_tokens"] == 1500
        assert claude["total_tokens"] == 750

    def test_metadata_passed_through(self):
        """Metadata should be included in events."""
        tracker = UsageTracker()
        received = []
        tracker.subscribe(lambda e: received.append(e))

        tracker.record(
            thread_id="t1",
            agent_id="a",
            model="m",
            provider="p",
            prompt_tokens=10,
            completion_tokens=10,
            latency_ms=10,
            metadata={"org_id": "org-123", "user_id": "user-456"},
        )

        # Arbitrary billing metadata must survive unchanged on the event.
        assert received[0].metadata["org_id"] == "org-123"
        assert received[0].metadata["user_id"] == "user-456"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Cost Estimation Tests
|
||||
# ============================================================================
|
||||
|
||||
class TestCostEstimation:
    """Tests for estimate_cost() across known and unknown models."""

    def test_grok_cost(self):
        """Grok models should use correct pricing."""
        # $3/1M prompt + $15/1M completion = $18
        dollars = estimate_cost(
            "grok-4.1",
            prompt_tokens=1_000_000,
            completion_tokens=1_000_000,
        )
        assert dollars == 18.0

    def test_claude_opus_cost(self):
        """Claude Opus should use correct pricing."""
        # $15/1M prompt + $75/1M completion = $90
        dollars = estimate_cost(
            "claude-opus-4",
            prompt_tokens=1_000_000,
            completion_tokens=1_000_000,
        )
        assert dollars == 90.0

    def test_gpt4o_cost(self):
        """GPT-4o should use correct pricing."""
        # $2.5/1M prompt + $10/1M completion = $12.5
        dollars = estimate_cost(
            "gpt-4o",
            prompt_tokens=1_000_000,
            completion_tokens=1_000_000,
        )
        assert dollars == 12.5

    def test_unknown_model_returns_none(self):
        """Unknown model should return None."""
        assert estimate_cost("unknown-model", prompt_tokens=1000, completion_tokens=500) is None

    def test_small_usage_cost(self):
        """Small token counts should produce fractional costs."""
        # 1000 tokens * $0.15/1M = $0.00015
        # 500 tokens * $0.6/1M = $0.0003
        # Total = $0.00045
        dollars = estimate_cost("gpt-4o-mini", prompt_tokens=1000, completion_tokens=500)
        assert dollars == pytest.approx(0.00045, rel=1e-4)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# LLMRouter Integration Tests (Mocked)
|
||||
# ============================================================================
|
||||
|
||||
class TestLLMRouterBudgetIntegration:
    """Test LLMRouter budget enforcement."""

    @pytest.fixture(autouse=True)
    def reset_all(self):
        """Reset all global registries."""
        # Both singletons are shared process-wide; reset around every test.
        reset_budget_registry()
        reset_usage_tracker()
        yield
        reset_budget_registry()
        reset_usage_tracker()

    @pytest.mark.asyncio
    async def test_complete_consumes_budget(self):
        """LLM complete should consume from thread budget."""
        from xml_pipeline.llm.router import LLMRouter
        from xml_pipeline.llm.backend import LLMResponse

        # Create mock backend
        # The mock mirrors the backend interface the router touches:
        # name/provider/priority/load for selection, serves_model for
        # routing, and an async complete() returning a canned response.
        mock_backend = Mock()
        mock_backend.name = "mock"
        mock_backend.provider = "test"
        mock_backend.serves_model = Mock(return_value=True)
        mock_backend.priority = 1
        mock_backend.load = 0
        mock_backend.complete = AsyncMock(return_value=LLMResponse(
            content="Hello!",
            model="test-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
            finish_reason="stop",
        ))

        # Configure budget
        configure_budget_registry(max_tokens_per_thread=10000)
        budget_registry = get_budget_registry()

        # Create router with mock backend
        router = LLMRouter()
        router.backends.append(mock_backend)

        # Make request
        response = await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            thread_id="test-thread-123",
        )

        assert response.content == "Hello!"

        # Verify budget consumed
        # Numbers come straight from the mocked response's usage dict.
        usage = budget_registry.get_usage("test-thread-123")
        assert usage["prompt_tokens"] == 100
        assert usage["completion_tokens"] == 50
        assert usage["total_tokens"] == 150

    @pytest.mark.asyncio
    async def test_complete_emits_usage_event(self):
        """LLM complete should emit usage event."""
        from xml_pipeline.llm.router import LLMRouter
        from xml_pipeline.llm.backend import LLMResponse

        mock_backend = Mock()
        mock_backend.name = "mock"
        mock_backend.provider = "test"
        mock_backend.serves_model = Mock(return_value=True)
        mock_backend.priority = 1
        mock_backend.load = 0
        mock_backend.complete = AsyncMock(return_value=LLMResponse(
            content="Hello!",
            model="test-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
            finish_reason="stop",
        ))

        # Subscribe to usage events
        tracker = get_usage_tracker()
        received_events = []
        tracker.subscribe(lambda e: received_events.append(e))

        # Create router and make request
        router = LLMRouter()
        router.backends.append(mock_backend)

        await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            thread_id="test-thread",
            agent_id="greeter",
            metadata={"org_id": "test-org"},
        )

        # Verify event emitted
        # thread/agent/metadata must pass through to the billing event.
        assert len(received_events) == 1
        event = received_events[0]
        assert event.thread_id == "test-thread"
        assert event.agent_id == "greeter"
        assert event.total_tokens == 150
        assert event.metadata["org_id"] == "test-org"

    @pytest.mark.asyncio
    async def test_complete_raises_when_budget_exhausted(self):
        """LLM complete should raise when budget exhausted."""
        from xml_pipeline.llm.router import LLMRouter

        # Configure small budget and exhaust it
        configure_budget_registry(max_tokens_per_thread=100)
        budget_registry = get_budget_registry()
        budget_registry.consume("test-thread", prompt_tokens=100)

        mock_backend = Mock()
        mock_backend.name = "mock"
        mock_backend.serves_model = Mock(return_value=True)
        mock_backend.priority = 1

        router = LLMRouter()
        router.backends.append(mock_backend)

        with pytest.raises(BudgetExhaustedError) as exc_info:
            await router.complete(
                model="test-model",
                messages=[{"role": "user", "content": "Hi"}],
                thread_id="test-thread",
            )

        assert "budget exhausted" in str(exc_info.value)
        # Backend should NOT have been called
        # (budget check must happen before any backend request).
        mock_backend.complete.assert_not_called()

    @pytest.mark.asyncio
    async def test_complete_without_thread_id_skips_budget(self):
        """LLM complete without thread_id should skip budget check."""
        from xml_pipeline.llm.router import LLMRouter
        from xml_pipeline.llm.backend import LLMResponse

        mock_backend = Mock()
        mock_backend.name = "mock"
        mock_backend.provider = "test"
        mock_backend.serves_model = Mock(return_value=True)
        mock_backend.priority = 1
        mock_backend.load = 0
        mock_backend.complete = AsyncMock(return_value=LLMResponse(
            content="Hello!",
            model="test-model",
            usage={"prompt_tokens": 100, "completion_tokens": 50, "total_tokens": 150},
            finish_reason="stop",
        ))

        router = LLMRouter()
        router.backends.append(mock_backend)

        # Should not raise - no budget checking
        response = await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            # No thread_id
        )

        assert response.content == "Hello!"
|
||||
|
|
@ -16,7 +16,20 @@ Usage:
|
|||
response = await router.complete(
|
||||
model="grok-4.1",
|
||||
messages=[{"role": "user", "content": "Hello"}],
|
||||
thread_id=metadata.thread_id, # For budget enforcement
|
||||
agent_id=metadata.own_name, # For usage tracking
|
||||
)
|
||||
|
||||
Usage Tracking:
|
||||
from xml_pipeline.llm import get_usage_tracker
|
||||
|
||||
tracker = get_usage_tracker()
|
||||
|
||||
# Subscribe to events for billing
|
||||
tracker.subscribe(lambda event: billing_api.record(event))
|
||||
|
||||
# Query totals
|
||||
totals = tracker.get_totals()
|
||||
"""
|
||||
|
||||
from xml_pipeline.llm.router import (
|
||||
|
|
@ -27,14 +40,27 @@ from xml_pipeline.llm.router import (
|
|||
Strategy,
|
||||
)
|
||||
from xml_pipeline.llm.backend import LLMRequest, LLMResponse, BackendError
|
||||
from xml_pipeline.llm.usage_tracker import (
|
||||
UsageTracker,
|
||||
UsageEvent,
|
||||
get_usage_tracker,
|
||||
reset_usage_tracker,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Router
|
||||
"LLMRouter",
|
||||
"get_router",
|
||||
"configure_router",
|
||||
"complete",
|
||||
"Strategy",
|
||||
# Backend
|
||||
"LLMRequest",
|
||||
"LLMResponse",
|
||||
"BackendError",
|
||||
# Usage tracking
|
||||
"UsageTracker",
|
||||
"UsageEvent",
|
||||
"get_usage_tracker",
|
||||
"reset_usage_tracker",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ The router handles:
|
|||
- Load balancing (failover, round-robin, least-loaded)
|
||||
- Retries with exponential backoff
|
||||
- Token tracking per agent
|
||||
- Thread budget enforcement
|
||||
- Usage event emission for billing
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -16,6 +18,7 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
|
@ -125,6 +128,8 @@ class LLMRouter:
|
|||
max_tokens: int = None,
|
||||
tools: List[Dict] = None,
|
||||
agent_id: str = None,
|
||||
thread_id: str = None,
|
||||
metadata: Dict[str, Any] = None,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
Execute a completion request.
|
||||
|
|
@ -136,10 +141,27 @@ class LLMRouter:
|
|||
max_tokens: Max tokens in response
|
||||
tools: Tool definitions for function calling
|
||||
agent_id: Optional agent ID for usage tracking
|
||||
thread_id: Optional thread ID for budget enforcement
|
||||
metadata: Optional metadata for usage events (org_id, user_id, etc.)
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and usage stats
|
||||
|
||||
Raises:
|
||||
BudgetExhaustedError: If thread has no remaining budget
|
||||
BackendError: If all backends fail
|
||||
"""
|
||||
# Estimate tokens for budget check (rough: 4 chars per token)
|
||||
estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
|
||||
estimated_tokens = max(estimated_tokens, 100) # minimum estimate
|
||||
|
||||
# Check thread budget before proceeding
|
||||
if thread_id:
|
||||
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
||||
budget_registry = get_budget_registry()
|
||||
# This raises BudgetExhaustedError if over budget
|
||||
budget_registry.check_budget(thread_id, estimated_tokens)
|
||||
|
||||
candidates = self._find_backends(model)
|
||||
request = LLMRequest(
|
||||
model=model,
|
||||
|
|
@ -151,6 +173,7 @@ class LLMRouter:
|
|||
|
||||
last_error = None
|
||||
tried_backends = set()
|
||||
start_time = time.monotonic()
|
||||
|
||||
for attempt in range(self.retries + 1):
|
||||
# Select backend (different selection on retry for failover)
|
||||
|
|
@ -170,14 +193,46 @@ class LLMRouter:
|
|||
logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})")
|
||||
response = await backend.complete(request)
|
||||
|
||||
# Track usage
|
||||
# Calculate latency
|
||||
latency_ms = (time.monotonic() - start_time) * 1000
|
||||
|
||||
# Extract usage
|
||||
prompt_tokens = response.usage.get("prompt_tokens", 0)
|
||||
completion_tokens = response.usage.get("completion_tokens", 0)
|
||||
total_tokens = response.usage.get("total_tokens", 0)
|
||||
|
||||
# Track per-agent usage (internal)
|
||||
if agent_id:
|
||||
usage = self._agent_usage.setdefault(agent_id, AgentUsage())
|
||||
usage.total_tokens += response.usage.get("total_tokens", 0)
|
||||
usage.prompt_tokens += response.usage.get("prompt_tokens", 0)
|
||||
usage.completion_tokens += response.usage.get("completion_tokens", 0)
|
||||
usage.total_tokens += total_tokens
|
||||
usage.prompt_tokens += prompt_tokens
|
||||
usage.completion_tokens += completion_tokens
|
||||
usage.request_count += 1
|
||||
|
||||
# Record to thread budget (enforcement)
|
||||
if thread_id:
|
||||
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
||||
budget_registry = get_budget_registry()
|
||||
budget_registry.consume(
|
||||
thread_id,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
# Emit usage event (for billing)
|
||||
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
||||
tracker = get_usage_tracker()
|
||||
tracker.record(
|
||||
thread_id=thread_id or "",
|
||||
agent_id=agent_id,
|
||||
model=response.model,
|
||||
provider=backend.provider,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
latency_ms=latency_ms,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except RateLimitError as e:
|
||||
|
|
@ -286,6 +341,10 @@ def configure_router(config: Dict[str, Any]) -> LLMRouter:
|
|||
async def complete(
|
||||
model: str,
|
||||
messages: List[Dict[str, str]],
|
||||
*,
|
||||
thread_id: str = None,
|
||||
agent_id: str = None,
|
||||
metadata: Dict[str, Any] = None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
"""
|
||||
|
|
@ -293,6 +352,32 @@ async def complete(
|
|||
|
||||
Usage:
|
||||
from xml_pipeline.llm import router
|
||||
response = await router.complete("grok-4.1", messages)
|
||||
response = await router.complete(
|
||||
"grok-4.1",
|
||||
messages,
|
||||
thread_id=metadata.thread_id,
|
||||
agent_id=metadata.own_name,
|
||||
)
|
||||
|
||||
Args:
|
||||
model: Model name
|
||||
messages: Chat messages
|
||||
thread_id: Thread UUID for budget enforcement
|
||||
agent_id: Agent name for usage tracking
|
||||
metadata: Extra metadata for billing events
|
||||
**kwargs: Additional arguments (temperature, max_tokens, tools)
|
||||
|
||||
Returns:
|
||||
LLMResponse with content and usage stats
|
||||
|
||||
Raises:
|
||||
BudgetExhaustedError: If thread budget exhausted
|
||||
"""
|
||||
return await get_router().complete(model, messages, **kwargs)
|
||||
return await get_router().complete(
|
||||
model,
|
||||
messages,
|
||||
thread_id=thread_id,
|
||||
agent_id=agent_id,
|
||||
metadata=metadata,
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
|||
346
xml_pipeline/llm/usage_tracker.py
Normal file
346
xml_pipeline/llm/usage_tracker.py
Normal file
|
|
@ -0,0 +1,346 @@
|
|||
"""
|
||||
Usage Tracker — Production billing and gas usage metering.
|
||||
|
||||
This module provides hooks for tracking LLM usage at the platform level.
|
||||
External billing systems can subscribe to usage events for metering.
|
||||
|
||||
Usage Tracking Layers:
|
||||
1. Per-agent (LLMRouter._agent_usage) — Internal token tracking
|
||||
2. Per-thread (ThreadBudgetRegistry) — Enforcement limits
|
||||
3. Platform (UsageTracker) — Production billing/metering
|
||||
|
||||
Example:
|
||||
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
||||
|
||||
tracker = get_usage_tracker()
|
||||
|
||||
# Subscribe to usage events (for billing webhook, database, etc.)
|
||||
def record_usage(event: UsageEvent):
|
||||
billing_db.record(
|
||||
org_id=event.metadata.get("org_id"),
|
||||
tokens=event.total_tokens,
|
||||
cost=event.estimated_cost,
|
||||
)
|
||||
|
||||
tracker.subscribe(record_usage)
|
||||
|
||||
# Query aggregate usage
|
||||
totals = tracker.get_totals()
|
||||
print(f"Total tokens: {totals['total_tokens']}")
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
class UsageEvent:
    """
    Usage event emitted after each LLM completion.

    This is the main interface for billing systems: subscribers registered
    on a UsageTracker receive one of these per recorded completion.
    """

    # Request identification
    thread_id: str          # conversation thread that triggered the call
    agent_id: Optional[str]  # emitting agent, if known
    model: str              # model name used for the completion
    provider: str           # backend provider (e.g. "xai", "openai")

    # Token usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int       # prompt + completion

    # Timing
    timestamp: str  # ISO 8601
    latency_ms: float  # Request duration

    # Cost estimation (if available)
    # None when the model has no entry in MODEL_COSTS.
    estimated_cost: Optional[float] = None

    # Extensible metadata (org_id, user_id, etc.)
    metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Cost per 1M tokens for common models (approximate, update as needed)
MODEL_COSTS: Dict[str, Dict[str, float]] = {
    # xAI Grok
    "grok-4.1": {"prompt": 3.0, "completion": 15.0},
    "grok-3": {"prompt": 3.0, "completion": 15.0},
    # Anthropic Claude
    "claude-opus-4": {"prompt": 15.0, "completion": 75.0},
    "claude-sonnet-4": {"prompt": 3.0, "completion": 15.0},
    "claude-sonnet-3-5": {"prompt": 3.0, "completion": 15.0},
    # OpenAI
    "gpt-4o": {"prompt": 2.5, "completion": 10.0},
    "gpt-4o-mini": {"prompt": 0.15, "completion": 0.6},
    "o1": {"prompt": 15.0, "completion": 60.0},
    "o3-mini": {"prompt": 1.1, "completion": 4.4},
}


def estimate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> Optional[float]:
    """
    Estimate cost in USD for a completion.

    Pricing is looked up by case-insensitive prefix match against
    MODEL_COSTS, preferring the longest matching prefix (so a
    "gpt-4o-mini-..." model string resolves to gpt-4o-mini pricing,
    not gpt-4o).

    Returns None if model pricing is unknown.
    """
    name = model.lower()

    # Gather every pricing entry whose prefix matches, then keep the
    # longest one; the first maximal entry wins on equal lengths.
    candidates = [
        (prefix, costs)
        for prefix, costs in MODEL_COSTS.items()
        if name.startswith(prefix.lower())
    ]
    if not candidates:
        return None
    _, pricing = max(candidates, key=lambda entry: len(entry[0]))

    # Rates are quoted per 1M tokens; round to micro-dollar precision.
    dollars = (
        (prompt_tokens / 1_000_000) * pricing["prompt"]
        + (completion_tokens / 1_000_000) * pricing["completion"]
    )
    return round(dollars, 6)
|
||||
|
||||
|
||||
# Signature for usage-event subscribers; invoked synchronously per completion.
UsageCallback = Callable[[UsageEvent], None]


@dataclass
class UsageTotals:
    """Aggregate usage statistics.

    One instance per aggregation bucket (global, per-agent, per-model);
    mutated in place by ``UsageTracker._update_totals``.
    """

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0
    total_cost: float = 0.0  # USD; only events with known pricing contribute
    total_latency_ms: float = 0.0  # summed latency; divide by request_count for the mean
|
||||
|
||||
|
||||
class UsageTracker:
    """
    Platform-level usage tracking for billing and metering.

    Thread-safe. Supports multiple subscribers for real-time event streaming.

    Integration points:
      - Webhook to billing API
      - Database for usage records
      - Metrics/observability (Prometheus, DataDog)
      - Real-time dashboard (WebSocket)
    """

    def __init__(self):
        # Subscriber callbacks, invoked synchronously after each record().
        self._callbacks: List[UsageCallback] = []
        self._lock = threading.Lock()

        # Aggregate tracking: global, per-agent, and per-model buckets.
        self._totals = UsageTotals()
        self._per_agent: Dict[str, UsageTotals] = {}
        self._per_model: Dict[str, UsageTotals] = {}

    def subscribe(self, callback: UsageCallback) -> None:
        """
        Subscribe to usage events.

        Callbacks are invoked synchronously after each LLM completion.
        For async processing, use a queue in your callback.
        """
        with self._lock:
            self._callbacks.append(callback)

    def unsubscribe(self, callback: UsageCallback) -> None:
        """Unsubscribe from usage events (no-op if not subscribed)."""
        with self._lock:
            if callback in self._callbacks:
                self._callbacks.remove(callback)

    def record(
        self,
        thread_id: str,
        agent_id: Optional[str],
        model: str,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int,
        latency_ms: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> UsageEvent:
        """
        Record a usage event and notify subscribers.

        Called by LLMRouter after each completion.

        Args:
            thread_id: Thread UUID the request belongs to.
            agent_id: Originating agent, or None if unknown.
            model: Model name (also used for cost estimation).
            provider: Backend provider identifier.
            prompt_tokens: Prompt tokens used.
            completion_tokens: Completion tokens generated.
            latency_ms: Request duration in milliseconds.
            metadata: Optional extra fields (org_id, user_id, ...).

        Returns:
            The created UsageEvent (for chaining/logging)
        """
        total_tokens = prompt_tokens + completion_tokens

        event = UsageEvent(
            thread_id=thread_id,
            agent_id=agent_id,
            model=model,
            provider=provider,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            timestamp=datetime.now(timezone.utc).isoformat(),
            latency_ms=latency_ms,
            estimated_cost=estimate_cost(model, prompt_tokens, completion_tokens),
            metadata=metadata or {},
        )

        # Update aggregates under the lock, then release before invoking
        # callbacks so a slow subscriber cannot block other recorders.
        with self._lock:
            self._update_totals(self._totals, event)

            if agent_id:
                self._update_totals(
                    self._per_agent.setdefault(agent_id, UsageTotals()), event
                )

            self._update_totals(
                self._per_model.setdefault(model, UsageTotals()), event
            )

            # Copy callbacks to avoid holding lock during invocation
            callbacks = list(self._callbacks)

        # Notify subscribers (outside lock)
        for callback in callbacks:
            try:
                callback(event)
            except Exception:
                # Don't let subscriber errors break tracking
                pass

        return event

    def _update_totals(self, totals: UsageTotals, event: UsageEvent) -> None:
        """Fold one event into an aggregate bucket (caller holds the lock)."""
        totals.total_tokens += event.total_tokens
        totals.prompt_tokens += event.prompt_tokens
        totals.completion_tokens += event.completion_tokens
        totals.request_count += 1
        totals.total_latency_ms += event.latency_ms
        if event.estimated_cost is not None:
            totals.total_cost += event.estimated_cost

    @staticmethod
    def _snapshot(totals: UsageTotals) -> Dict[str, Any]:
        """Serialize a bucket to the public dict shape (caller holds the lock)."""
        return {
            "total_tokens": totals.total_tokens,
            "prompt_tokens": totals.prompt_tokens,
            "completion_tokens": totals.completion_tokens,
            "request_count": totals.request_count,
            "total_cost": round(totals.total_cost, 4),
        }

    def get_totals(self) -> Dict[str, Any]:
        """Get aggregate usage totals, including mean request latency."""
        with self._lock:
            result = self._snapshot(self._totals)
            result["avg_latency_ms"] = (
                self._totals.total_latency_ms / self._totals.request_count
                if self._totals.request_count > 0
                else 0
            )
            return result

    def get_agent_totals(self, agent_id: str) -> Dict[str, Any]:
        """Get usage totals for a specific agent (all zeros if unknown)."""
        with self._lock:
            return self._snapshot(self._per_agent.get(agent_id, UsageTotals()))

    def get_model_totals(self, model: str) -> Dict[str, Any]:
        """Get usage totals for a specific model (all zeros if unknown)."""
        with self._lock:
            return self._snapshot(self._per_model.get(model, UsageTotals()))

    def get_all_agent_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all agents."""
        with self._lock:
            return {
                agent_id: self._snapshot(t)
                for agent_id, t in self._per_agent.items()
            }

    def get_all_model_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all models."""
        with self._lock:
            return {
                model: self._snapshot(t)
                for model, t in self._per_model.items()
            }

    def reset(self) -> None:
        """Reset all tracking (for testing)."""
        with self._lock:
            self._totals = UsageTotals()
            self._per_agent.clear()
            self._per_model.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Instance
|
||||
# =============================================================================
|
||||
|
||||
_tracker: Optional[UsageTracker] = None
_tracker_lock = threading.Lock()


def get_usage_tracker() -> UsageTracker:
    """Return the process-wide usage tracker, creating it on first use."""
    global _tracker
    if _tracker is not None:  # fast path: already initialized
        return _tracker
    with _tracker_lock:
        # Re-check under the lock so only one thread constructs the tracker.
        if _tracker is None:
            _tracker = UsageTracker()
        return _tracker


def reset_usage_tracker() -> None:
    """Drop the global tracker and clear its aggregates (testing helper)."""
    global _tracker
    with _tracker_lock:
        tracker, _tracker = _tracker, None
        if tracker is not None:
            tracker.reset()
|
||||
|
|
@ -67,6 +67,15 @@ from xml_pipeline.message_bus.buffer_registry import (
|
|||
reset_buffer_registry,
|
||||
)
|
||||
|
||||
from xml_pipeline.message_bus.budget_registry import (
|
||||
ThreadBudget,
|
||||
ThreadBudgetRegistry,
|
||||
BudgetExhaustedError,
|
||||
get_budget_registry,
|
||||
configure_budget_registry,
|
||||
reset_budget_registry,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Pump
|
||||
"StreamPump",
|
||||
|
|
@ -102,4 +111,11 @@ __all__ = [
|
|||
"BufferRegistry",
|
||||
"get_buffer_registry",
|
||||
"reset_buffer_registry",
|
||||
# Budget registry
|
||||
"ThreadBudget",
|
||||
"ThreadBudgetRegistry",
|
||||
"BudgetExhaustedError",
|
||||
"get_budget_registry",
|
||||
"configure_budget_registry",
|
||||
"reset_budget_registry",
|
||||
]
|
||||
|
|
|
|||
280
xml_pipeline/message_bus/budget_registry.py
Normal file
280
xml_pipeline/message_bus/budget_registry.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
"""
|
||||
Thread Budget Registry — Enforces per-thread token limits.
|
||||
|
||||
Each thread has a token budget that tracks:
|
||||
- Total tokens consumed (prompt + completion)
|
||||
- Requests made
|
||||
- Remaining budget
|
||||
|
||||
When a thread exhausts its budget, LLM calls are blocked.
|
||||
|
||||
Example config:
|
||||
organism:
|
||||
max_tokens_per_thread: 100000 # 100k tokens per thread
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@dataclass
class ThreadBudget:
    """Per-thread token accounting against a fixed ceiling."""

    max_tokens: int
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed so far (prompt + completion)."""
        return self.completion_tokens + self.prompt_tokens

    @property
    def remaining(self) -> int:
        """Remaining token budget; clamped so it never goes negative."""
        leftover = self.max_tokens - self.total_tokens
        return leftover if leftover > 0 else 0

    @property
    def is_exhausted(self) -> bool:
        """True once consumption has reached (or passed) the ceiling."""
        return not self.total_tokens < self.max_tokens

    def can_consume(self, estimated_tokens: int) -> bool:
        """Whether ``estimated_tokens`` more would still fit under the ceiling."""
        return estimated_tokens <= self.max_tokens - self.total_tokens

    def consume(
        self,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> None:
        """Add one request's token counts to the running totals."""
        self.request_count += 1
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
|
||||
|
||||
|
||||
class BudgetExhaustedError(Exception):
    """Raised when a thread's token budget is exhausted."""

    def __init__(self, thread_id: str, used: int, max_tokens: int):
        # Build the message first, then stash the fields for programmatic access.
        message = (
            f"Thread {thread_id[:8]}... budget exhausted: "
            f"{used}/{max_tokens} tokens used"
        )
        super().__init__(message)
        self.thread_id = thread_id
        self.used = used
        self.max_tokens = max_tokens
|
||||
|
||||
|
||||
class ThreadBudgetRegistry:
    """
    Manages token budgets per thread.

    Thread-safe for concurrent access: get-or-create, consumption, and
    usage snapshots each happen atomically under a single lock.

    Usage:
        registry = get_budget_registry()
        registry.configure(max_tokens_per_thread=100000)

        # Before LLM call
        registry.check_budget(thread_id, estimated_tokens=1000)

        # After LLM call
        registry.consume(thread_id, prompt=500, completion=300)

        # Get usage
        budget = registry.get_budget(thread_id)
        print(f"Used: {budget.total_tokens}, Remaining: {budget.remaining}")
    """

    def __init__(self, max_tokens_per_thread: int = 100_000):
        """
        Initialize budget registry.

        Args:
            max_tokens_per_thread: Default budget for new threads.
        """
        self._max_tokens_per_thread = max_tokens_per_thread
        self._budgets: Dict[str, ThreadBudget] = {}
        self._lock = threading.Lock()

    def configure(self, max_tokens_per_thread: int) -> None:
        """
        Update default max tokens for new threads.

        Existing threads keep their current budgets.
        """
        with self._lock:
            self._max_tokens_per_thread = max_tokens_per_thread

    @property
    def max_tokens_per_thread(self) -> int:
        """Get the default max tokens per thread."""
        return self._max_tokens_per_thread

    def _get_or_create(self, thread_id: str) -> ThreadBudget:
        """Return the thread's budget, creating it if absent.

        Caller MUST hold ``self._lock``.
        """
        budget = self._budgets.get(thread_id)
        if budget is None:
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        return budget

    def get_budget(self, thread_id: str) -> ThreadBudget:
        """
        Get or create budget for a thread.

        Args:
            thread_id: Thread UUID

        Returns:
            ThreadBudget instance
        """
        with self._lock:
            return self._get_or_create(thread_id)

    def check_budget(
        self,
        thread_id: str,
        estimated_tokens: int = 0,
    ) -> bool:
        """
        Check if thread has budget for the estimated tokens.

        Args:
            thread_id: Thread UUID
            estimated_tokens: Estimated tokens for the request

        Returns:
            True if budget available

        Raises:
            BudgetExhaustedError if budget is exhausted
        """
        with self._lock:
            budget = self._get_or_create(thread_id)
            # A single combined check: already over the ceiling, or the
            # estimated request would push us over it.
            if budget.is_exhausted or not budget.can_consume(estimated_tokens):
                raise BudgetExhaustedError(
                    thread_id=thread_id,
                    used=budget.total_tokens,
                    max_tokens=budget.max_tokens,
                )

        return True

    def consume(
        self,
        thread_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> ThreadBudget:
        """
        Record token consumption for a thread.

        Lookup/creation and mutation happen under one lock acquisition so
        concurrent consumers cannot interleave between the two steps.

        Args:
            thread_id: Thread UUID
            prompt_tokens: Prompt tokens used
            completion_tokens: Completion tokens used

        Returns:
            Updated ThreadBudget
        """
        with self._lock:
            budget = self._get_or_create(thread_id)
            budget.consume(prompt_tokens, completion_tokens)
        return budget

    @staticmethod
    def _usage_snapshot(budget: ThreadBudget) -> Dict[str, int]:
        """Serialize one budget to the public usage dict shape."""
        return {
            "prompt_tokens": budget.prompt_tokens,
            "completion_tokens": budget.completion_tokens,
            "total_tokens": budget.total_tokens,
            "remaining": budget.remaining,
            "max_tokens": budget.max_tokens,
            "request_count": budget.request_count,
        }

    def get_usage(self, thread_id: str) -> Dict[str, int]:
        """
        Get usage stats for a thread.

        Returns:
            Dict with prompt_tokens, completion_tokens, total_tokens,
            remaining, max_tokens, request_count
        """
        with self._lock:
            # Snapshot under the lock for a consistent view.
            return self._usage_snapshot(self._get_or_create(thread_id))

    def get_all_usage(self) -> Dict[str, Dict[str, int]]:
        """Get usage stats for all threads."""
        with self._lock:
            return {
                thread_id: self._usage_snapshot(b)
                for thread_id, b in self._budgets.items()
            }

    def reset_thread(self, thread_id: str) -> None:
        """Reset budget for a specific thread."""
        with self._lock:
            self._budgets.pop(thread_id, None)

    def cleanup_thread(self, thread_id: str) -> Optional[ThreadBudget]:
        """
        Remove budget when thread is pruned/completed.

        Returns the final budget for logging/billing, or None if not found.
        """
        with self._lock:
            return self._budgets.pop(thread_id, None)

    def clear(self) -> None:
        """Clear all budgets (for testing)."""
        with self._lock:
            self._budgets.clear()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Instance
|
||||
# =============================================================================
|
||||
|
||||
_registry: Optional[ThreadBudgetRegistry] = None
_registry_lock = threading.Lock()


def get_budget_registry() -> ThreadBudgetRegistry:
    """Return the process-wide budget registry, creating it on first use."""
    global _registry
    if _registry is not None:  # fast path: already initialized
        return _registry
    with _registry_lock:
        # Re-check under the lock so only one thread constructs the registry.
        if _registry is None:
            _registry = ThreadBudgetRegistry()
        return _registry


def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry:
    """Configure the global budget registry's default per-thread limit."""
    registry = get_budget_registry()
    registry.configure(max_tokens_per_thread)
    return registry


def reset_budget_registry() -> None:
    """Drop the global registry and clear all budgets (testing helper)."""
    global _registry
    with _registry_lock:
        registry, _registry = _registry, None
        if registry is not None:
            registry.clear()
|
||||
|
|
@ -141,6 +141,9 @@ class OrganismConfig:
|
|||
max_concurrent_handlers: int = 20 # Concurrent handler invocations
|
||||
max_concurrent_per_agent: int = 5 # Per-agent rate limit
|
||||
|
||||
# Token budget enforcement
|
||||
max_tokens_per_thread: int = 100_000 # Max tokens per conversation thread
|
||||
|
||||
# LLM configuration (optional)
|
||||
llm_config: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
|
@ -1271,6 +1274,7 @@ class ConfigLoader:
|
|||
max_concurrent_pipelines=raw.get("max_concurrent_pipelines", 50),
|
||||
max_concurrent_handlers=raw.get("max_concurrent_handlers", 20),
|
||||
max_concurrent_per_agent=raw.get("max_concurrent_per_agent", 5),
|
||||
max_tokens_per_thread=raw.get("max_tokens_per_thread", 100_000),
|
||||
llm_config=raw.get("llm", {}),
|
||||
process_pool_enabled=process_pool_enabled,
|
||||
process_pool_workers=process_pool_workers,
|
||||
|
|
@ -1430,6 +1434,11 @@ async def bootstrap(config_path: str = "config/organism.yaml") -> StreamPump:
|
|||
configure_router(config.llm_config)
|
||||
print(f"LLM backends: {len(config.llm_config.get('backends', []))}")
|
||||
|
||||
# Configure thread budget registry
|
||||
from xml_pipeline.message_bus.budget_registry import configure_budget_registry
|
||||
configure_budget_registry(config.max_tokens_per_thread)
|
||||
print(f"Token budget: {config.max_tokens_per_thread:,} per thread")
|
||||
|
||||
# Initialize root thread in registry
|
||||
registry = get_registry()
|
||||
root_uuid = registry.initialize_root(config.name)
|
||||
|
|
|
|||
Loading…
Reference in a new issue