xml-pipeline/xml_pipeline/llm/usage_tracker.py
dullfig 8b11323a8b Add token budget enforcement and usage tracking
Token Budget System:
- ThreadBudgetRegistry tracks per-thread token usage with configurable limits
- BudgetExhaustedError raised when thread exceeds max_tokens_per_thread
- Integrates with LLMRouter to block LLM calls when budget exhausted
- Automatic cleanup when threads are pruned

Usage Tracking (for production billing):
- UsageTracker emits events after each LLM completion
- Subscribers receive UsageEvent with tokens, latency, estimated cost
- Cost estimation for common models (Grok, Claude, GPT, etc.)
- Aggregate stats by agent, model, and totals

Configuration:
- max_tokens_per_thread in organism.yaml (default 100k)
- LLMRouter.complete() accepts thread_id and metadata parameters

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:07:43 -08:00

346 lines
11 KiB
Python

"""
Usage Tracker — Production billing and gas usage metering.
This module provides hooks for tracking LLM usage at the platform level.
External billing systems can subscribe to usage events for metering.
Usage Tracking Layers:
1. Per-agent (LLMRouter._agent_usage) — Internal token tracking
2. Per-thread (ThreadBudgetRegistry) — Enforcement limits
3. Platform (UsageTracker) — Production billing/metering
Example:
from xml_pipeline.llm.usage_tracker import get_usage_tracker
tracker = get_usage_tracker()
# Subscribe to usage events (for billing webhook, database, etc.)
def record_usage(event: UsageEvent):
billing_db.record(
org_id=event.metadata.get("org_id"),
tokens=event.total_tokens,
cost=event.estimated_cost,
)
tracker.subscribe(record_usage)
# Query aggregate usage
totals = tracker.get_totals()
print(f"Total tokens: {totals['total_tokens']}")
"""
from __future__ import annotations
import threading
import time
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional
@dataclass
class UsageEvent:
    """
    Usage event emitted after each LLM completion.

    This is the main interface for billing systems: one record per
    request carrying identity, token counts, timing, and an optional
    cost estimate.
    """

    # Request identification
    thread_id: str
    agent_id: Optional[str]  # None when the call is not attributed to an agent
    model: str
    provider: str
    # Token usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int  # prompt_tokens + completion_tokens (precomputed by the recorder)
    # Timing
    timestamp: str  # ISO 8601
    latency_ms: float  # Request duration
    # Cost estimation (if available)
    estimated_cost: Optional[float] = None  # USD; None when model pricing is unknown
    # Extensible metadata (org_id, user_id, etc.)
    metadata: Dict[str, Any] = field(default_factory=dict)
# Cost per 1M tokens for common models (approximate, update as needed)
MODEL_COSTS: Dict[str, Dict[str, float]] = {
    # xAI Grok
    "grok-4.1": {"prompt": 3.0, "completion": 15.0},
    "grok-3": {"prompt": 3.0, "completion": 15.0},
    # Anthropic Claude
    "claude-opus-4": {"prompt": 15.0, "completion": 75.0},
    "claude-sonnet-4": {"prompt": 3.0, "completion": 15.0},
    "claude-sonnet-3-5": {"prompt": 3.0, "completion": 15.0},
    # OpenAI
    "gpt-4o": {"prompt": 2.5, "completion": 10.0},
    "gpt-4o-mini": {"prompt": 0.15, "completion": 0.6},
    "o1": {"prompt": 15.0, "completion": 60.0},
    "o3-mini": {"prompt": 1.1, "completion": 4.4},
}


def estimate_cost(
    model: str,
    prompt_tokens: int,
    completion_tokens: int,
) -> Optional[float]:
    """
    Estimate cost in USD for a completion.

    Pricing is looked up by case-insensitive prefix match against
    MODEL_COSTS, preferring the longest matching prefix so that e.g.
    "gpt-4o-mini" wins over "gpt-4o".

    Returns None if model pricing is unknown.
    """
    name = model.lower()
    # Collect every pricing entry whose prefix matches, then keep the
    # most specific one (longest prefix). Two distinct equal-length
    # prefixes cannot both match, so the choice is unambiguous.
    matches = [
        (len(prefix), rates)
        for prefix, rates in MODEL_COSTS.items()
        if name.startswith(prefix.lower())
    ]
    if not matches:
        return None
    _, rates = max(matches, key=lambda m: m[0])
    # Cost = (tokens / 1M) * cost_per_million
    prompt_cost = (prompt_tokens / 1_000_000) * rates["prompt"]
    completion_cost = (completion_tokens / 1_000_000) * rates["completion"]
    return round(prompt_cost + completion_cost, 6)
# Signature for usage-event subscribers (see UsageTracker.subscribe).
UsageCallback = Callable[[UsageEvent], None]


@dataclass
class UsageTotals:
    """Aggregate usage statistics (running sums over recorded events)."""

    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0
    total_cost: float = 0.0  # USD; sums only events with a known estimated cost
    total_latency_ms: float = 0.0
class UsageTracker:
    """
    Platform-level usage tracking for billing and metering.

    Thread-safe. Supports multiple subscribers for real-time event streaming.

    Integration points:
      - Webhook to billing API
      - Database for usage records
      - Metrics/observability (Prometheus, DataDog)
      - Real-time dashboard (WebSocket)
    """

    def __init__(self):
        self._callbacks: List[UsageCallback] = []
        self._lock = threading.Lock()
        # Aggregate tracking
        self._totals = UsageTotals()
        self._per_agent: Dict[str, UsageTotals] = {}
        self._per_model: Dict[str, UsageTotals] = {}

    def subscribe(self, callback: UsageCallback) -> None:
        """
        Subscribe to usage events.

        Callbacks are invoked synchronously after each LLM completion.
        For async processing, use a queue in your callback.
        """
        with self._lock:
            self._callbacks.append(callback)

    def unsubscribe(self, callback: UsageCallback) -> None:
        """Unsubscribe from usage events (no-op if not subscribed)."""
        with self._lock:
            if callback in self._callbacks:
                self._callbacks.remove(callback)

    def record(
        self,
        thread_id: str,
        agent_id: Optional[str],
        model: str,
        provider: str,
        prompt_tokens: int,
        completion_tokens: int,
        latency_ms: float,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> UsageEvent:
        """
        Record a usage event and notify subscribers.

        Called by LLMRouter after each completion.

        Returns:
            The created UsageEvent (for chaining/logging)
        """
        total_tokens = prompt_tokens + completion_tokens
        event = UsageEvent(
            thread_id=thread_id,
            agent_id=agent_id,
            model=model,
            provider=provider,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
            timestamp=datetime.now(timezone.utc).isoformat(),
            latency_ms=latency_ms,
            estimated_cost=estimate_cost(model, prompt_tokens, completion_tokens),
            metadata=metadata or {},
        )
        # Update aggregates
        with self._lock:
            self._update_totals(self._totals, event)
            if agent_id:
                self._update_totals(
                    self._per_agent.setdefault(agent_id, UsageTotals()), event
                )
            self._update_totals(
                self._per_model.setdefault(model, UsageTotals()), event
            )
            # Copy callbacks to avoid holding lock during invocation
            callbacks = list(self._callbacks)
        # Notify subscribers (outside lock)
        for callback in callbacks:
            try:
                callback(event)
            except Exception:
                # Don't let subscriber errors break tracking
                pass
        return event

    def _update_totals(self, totals: UsageTotals, event: UsageEvent) -> None:
        """Fold one event into a running UsageTotals (caller holds the lock)."""
        totals.total_tokens += event.total_tokens
        totals.prompt_tokens += event.prompt_tokens
        totals.completion_tokens += event.completion_tokens
        totals.request_count += 1
        totals.total_latency_ms += event.latency_ms
        # estimated_cost is None for models with unknown pricing; a real
        # 0.0 cost (free tier) still passes this check.
        if event.estimated_cost is not None:
            totals.total_cost += event.estimated_cost

    @staticmethod
    def _snapshot(totals: UsageTotals) -> Dict[str, Any]:
        """Serialize a UsageTotals into the public stats-dict shape."""
        return {
            "total_tokens": totals.total_tokens,
            "prompt_tokens": totals.prompt_tokens,
            "completion_tokens": totals.completion_tokens,
            "request_count": totals.request_count,
            "total_cost": round(totals.total_cost, 4),
        }

    def get_totals(self) -> Dict[str, Any]:
        """Get aggregate usage totals across all agents and models."""
        with self._lock:
            stats = self._snapshot(self._totals)
            stats["avg_latency_ms"] = (
                self._totals.total_latency_ms / self._totals.request_count
                if self._totals.request_count > 0
                else 0
            )
            return stats

    def get_agent_totals(self, agent_id: str) -> Dict[str, Any]:
        """Get usage totals for a specific agent (zeros if never seen)."""
        with self._lock:
            return self._snapshot(self._per_agent.get(agent_id, UsageTotals()))

    def get_model_totals(self, model: str) -> Dict[str, Any]:
        """Get usage totals for a specific model (zeros if never seen)."""
        with self._lock:
            return self._snapshot(self._per_model.get(model, UsageTotals()))

    def get_all_agent_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all agents, keyed by agent_id."""
        with self._lock:
            return {
                agent_id: self._snapshot(t)
                for agent_id, t in self._per_agent.items()
            }

    def get_all_model_totals(self) -> Dict[str, Dict[str, Any]]:
        """Get usage totals for all models, keyed by model name."""
        with self._lock:
            return {model: self._snapshot(t) for model, t in self._per_model.items()}

    def reset(self) -> None:
        """Reset all tracking (for testing)."""
        with self._lock:
            self._totals = UsageTotals()
            self._per_agent.clear()
            self._per_model.clear()
# =============================================================================
# Global Instance
# =============================================================================

# Lazily-created process-wide tracker; creation/teardown guarded by
# _tracker_lock (see get_usage_tracker / reset_usage_tracker).
_tracker: Optional[UsageTracker] = None
_tracker_lock = threading.Lock()
def get_usage_tracker() -> UsageTracker:
    """Return the process-wide UsageTracker, creating it on first use."""
    global _tracker
    # Fast path: already initialized, no lock needed.
    if _tracker is not None:
        return _tracker
    with _tracker_lock:
        # Re-check under the lock in case another thread won the race.
        if _tracker is None:
            _tracker = UsageTracker()
    return _tracker
def reset_usage_tracker() -> None:
    """Reset and discard the global tracker (testing hook)."""
    global _tracker
    with _tracker_lock:
        current = _tracker
        if current is not None:
            current.reset()
            _tracker = None