Token Budget System:
- ThreadBudgetRegistry tracks per-thread token usage with configurable limits
- BudgetExhaustedError raised when thread exceeds max_tokens_per_thread
- Integrates with LLMRouter to block LLM calls when budget exhausted
- Automatic cleanup when threads are pruned

Usage Tracking (for production billing):
- UsageTracker emits events after each LLM completion
- Subscribers receive UsageEvent with tokens, latency, estimated cost
- Cost estimation for common models (Grok, Claude, GPT, etc.)
- Aggregate stats by agent, model, and totals

Configuration:
- max_tokens_per_thread in organism.yaml (default 100k)
- LLMRouter.complete() accepts thread_id and metadata parameters

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
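A sketch of the configuration described above, assuming max_tokens_per_thread sits at the top level of organism.yaml alongside the llm block (the llm block mirrors the format documented in the router docstrings below; the exact placement of max_tokens_per_thread is an assumption):

    # organism.yaml (sketch)
    max_tokens_per_thread: 100000   # default 100k, per the notes above

    llm:
      strategy: failover
      retries: 3
      backends:
        - provider: xai
          api_key_env: XAI_API_KEY
        - provider: anthropic
          api_key_env: ANTHROPIC_API_KEY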
"""
|
|
LLM Router - the main entry point for LLM calls.
|
|
|
|
Agents just call:
|
|
response = await router.complete(model="grok-4.1", messages=[...])
|
|
|
|
The router handles:
|
|
- Finding backends that serve the model
|
|
- Load balancing (failover, round-robin, least-loaded)
|
|
- Retries with exponential backoff
|
|
- Token tracking per agent
|
|
- Thread budget enforcement
|
|
- Usage event emission for billing
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import random
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import List, Dict, Any, Optional
|
|
|
|
from xml_pipeline.llm.backend import (
|
|
Backend,
|
|
LLMRequest,
|
|
LLMResponse,
|
|
BackendError,
|
|
RateLimitError,
|
|
ProviderError,
|
|
create_backend,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Strategy(Enum):
|
|
FAILOVER = "failover" # Try in priority order, fail over on error
|
|
ROUND_ROBIN = "round-robin" # Rotate through backends
|
|
LEAST_LOADED = "least-loaded" # Pick backend with lowest current load
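# A minimal sketch (illustrative values) of choosing a strategy when building a
# router directly; in production the router is normally built via
# configure_router() at the bottom of this module:
#
#     router = LLMRouter(strategy=Strategy.ROUND_ROBIN, retries=2)
#     router.add_backend(create_backend({"provider": "xai", "api_key_env": "XAI_API_KEY"}))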


@dataclass
class AgentUsage:
    """Track token usage per agent."""
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0


@dataclass
class LLMRouter:
    """
    Routes LLM requests to appropriate backends.

    Config example:
        llm:
          strategy: failover
          retries: 3
          backends:
            - provider: xai
              api_key_env: XAI_API_KEY
            - provider: anthropic
              api_key_env: ANTHROPIC_API_KEY
    """
    backends: List[Backend] = field(default_factory=list)
    strategy: Strategy = Strategy.FAILOVER
    retries: int = 3
    retry_base_delay: float = 1.0
    retry_max_delay: float = 60.0

    # Per-agent token tracking
    _agent_usage: Dict[str, AgentUsage] = field(default_factory=dict, repr=False)

    # Round-robin state
    _rr_index: int = field(default=0, repr=False)
    _rr_lock: Optional[asyncio.Lock] = field(default=None, repr=False)

    def __post_init__(self):
        self._rr_lock = asyncio.Lock()

    def add_backend(self, backend: Backend) -> None:
        """Add a backend to the router."""
        self.backends.append(backend)
        logger.info(f"Added backend: {backend.name} ({backend.provider})")

    def _find_backends(self, model: str) -> List[Backend]:
        """Find all backends that can serve this model."""
        candidates = [b for b in self.backends if b.serves_model(model)]

        if not candidates:
            raise ValueError(
                f"No backend available for model '{model}'. "
                f"Available backends: {[b.name for b in self.backends]}"
            )

        return candidates

    async def _select_backend(self, candidates: List[Backend]) -> Backend:
        """Select a backend based on strategy."""
        if self.strategy == Strategy.FAILOVER:
            # Sort by priority (lower = preferred)
            return sorted(candidates, key=lambda b: b.priority)[0]

        elif self.strategy == Strategy.ROUND_ROBIN:
            async with self._rr_lock:
                # Round-robin among just the candidates
                idx = self._rr_index % len(candidates)
                self._rr_index += 1
                return candidates[idx]

        elif self.strategy == Strategy.LEAST_LOADED:
            # Pick backend with lowest current load
            return min(candidates, key=lambda b: b.load)

        else:
            return candidates[0]

    async def complete(
        self,
        model: str,
        messages: List[Dict[str, str]],
        *,
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        tools: Optional[List[Dict]] = None,
        agent_id: Optional[str] = None,
        thread_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> LLMResponse:
        """
        Execute a completion request.

        Args:
            model: Model name (e.g., "grok-4.1", "claude-sonnet-4")
            messages: Chat messages
            temperature: Sampling temperature
            max_tokens: Max tokens in response
            tools: Tool definitions for function calling
            agent_id: Optional agent ID for usage tracking
            thread_id: Optional thread ID for budget enforcement
            metadata: Optional metadata for usage events (org_id, user_id, etc.)

        Returns:
            LLMResponse with content and usage stats

        Raises:
            BudgetExhaustedError: If thread has no remaining budget
            BackendError: If all backends fail
        """
        # Estimate tokens for the budget check (rough heuristic: ~4 chars per
        # token, so e.g. 2,000 characters of messages ≈ 500 tokens)
        estimated_tokens = sum(len(m.get("content", "")) for m in messages) // 4
        estimated_tokens = max(estimated_tokens, 100)  # minimum estimate

        # Check thread budget before proceeding
        if thread_id:
            # Imported lazily to avoid a circular import at module load time
            from xml_pipeline.message_bus.budget_registry import get_budget_registry

            budget_registry = get_budget_registry()
            # This raises BudgetExhaustedError if over budget
            budget_registry.check_budget(thread_id, estimated_tokens)

        candidates = self._find_backends(model)
        request = LLMRequest(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            tools=tools,
        )

        last_error = None
        tried_backends = set()
        start_time = time.monotonic()

        for attempt in range(self.retries + 1):
            # Select backend (different selection on retry for failover)
            if self.strategy == Strategy.FAILOVER and tried_backends:
                # Filter out already-tried backends
                remaining = [b for b in candidates if b.name not in tried_backends]
                if not remaining:
                    # All backends tried; start over from the top of the priority list
                    remaining = candidates
                backend = sorted(remaining, key=lambda b: b.priority)[0]
            else:
                backend = await self._select_backend(candidates)

            tried_backends.add(backend.name)

            try:
                logger.debug(f"Attempting {model} on {backend.name} (attempt {attempt + 1})")
                response = await backend.complete(request)

                # Calculate latency
                latency_ms = (time.monotonic() - start_time) * 1000

                # Extract usage
                prompt_tokens = response.usage.get("prompt_tokens", 0)
                completion_tokens = response.usage.get("completion_tokens", 0)
                total_tokens = response.usage.get("total_tokens", 0)

                # Track per-agent usage (internal)
                if agent_id:
                    usage = self._agent_usage.setdefault(agent_id, AgentUsage())
                    usage.total_tokens += total_tokens
                    usage.prompt_tokens += prompt_tokens
                    usage.completion_tokens += completion_tokens
                    usage.request_count += 1

                # Record to thread budget (enforcement)
                if thread_id:
                    from xml_pipeline.message_bus.budget_registry import get_budget_registry

                    budget_registry = get_budget_registry()
                    budget_registry.consume(
                        thread_id,
                        prompt_tokens=prompt_tokens,
                        completion_tokens=completion_tokens,
                    )

                # Emit usage event (for billing)
                from xml_pipeline.llm.usage_tracker import get_usage_tracker

                tracker = get_usage_tracker()
                tracker.record(
                    thread_id=thread_id or "",
                    agent_id=agent_id,
                    model=response.model,
                    provider=backend.provider,
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    latency_ms=latency_ms,
                    metadata=metadata,
                )

                return response

            except RateLimitError as e:
                last_error = e
                if attempt < self.retries:
                    delay = e.retry_after or self._backoff_delay(attempt)
                    logger.warning(f"Rate limited on {backend.name}, waiting {delay:.1f}s")
                    await asyncio.sleep(delay)

            except ProviderError as e:
                last_error = e
                if attempt < self.retries:
                    delay = self._backoff_delay(attempt)
                    logger.warning(f"Provider error on {backend.name}: {e}, retrying in {delay:.1f}s")
                    await asyncio.sleep(delay)

            except Exception as e:
                last_error = e
                logger.error(f"Unexpected error on {backend.name}: {e}")
                if attempt < self.retries:
                    delay = self._backoff_delay(attempt)
                    await asyncio.sleep(delay)

        # All retries exhausted
        raise BackendError(f"All backends failed for {model}: {last_error}") from last_error
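
    # Example call (sketch): thread_id enables budget enforcement, agent_id
    # enables per-agent tracking. The identifiers below are placeholders.
    #
    #     response = await router.complete(
    #         "grok-4.1",
    #         [{"role": "user", "content": "hello"}],
    #         agent_id="scheduler",
    #         thread_id="thread-123",
    #         metadata={"org_id": "acme"},
    #     )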

    def _backoff_delay(self, attempt: int) -> float:
        """Calculate exponential backoff with jitter."""
        delay = self.retry_base_delay * (2 ** attempt)
        delay = min(delay, self.retry_max_delay)
        # Add jitter (±25%)
        jitter = delay * 0.25 * (random.random() * 2 - 1)
        return delay + jitter
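
    # Worked example with the defaults (retry_base_delay=1.0, retry_max_delay=60.0):
    # attempt 0 -> 1s, 1 -> 2s, 2 -> 4s, 3 -> 8s, 4 -> 16s, 5 -> 32s, 6 -> 60s (capped),
    # each then shifted by up to ±25% jitter so concurrent callers don't retry in lockstep.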

    def get_agent_usage(self, agent_id: str) -> AgentUsage:
        """Get usage stats for an agent."""
        return self._agent_usage.get(agent_id, AgentUsage())

    def get_all_usage(self) -> Dict[str, AgentUsage]:
        """Get usage stats for all agents."""
        return dict(self._agent_usage)

    def reset_agent_usage(self, agent_id: Optional[str] = None) -> None:
        """Reset usage stats for one or all agents."""
        if agent_id:
            self._agent_usage.pop(agent_id, None)
        else:
            self._agent_usage.clear()

    async def close(self) -> None:
        """Clean up all backends."""
        for backend in self.backends:
            await backend.close()


# =============================================================================
# Global Router Instance
# =============================================================================

_router: Optional[LLMRouter] = None


def get_router() -> LLMRouter:
    """Get the global router instance."""
    global _router
    if _router is None:
        _router = LLMRouter()
    return _router


def configure_router(config: Dict[str, Any]) -> LLMRouter:
    """
    Configure the global router from config dict.

    Config format:
        llm:
          strategy: failover
          retries: 3
          backends:
            - provider: xai
              api_key_env: XAI_API_KEY
              rate_limit_tpm: 100000
            - provider: anthropic
              api_key_env: ANTHROPIC_API_KEY
    """
    global _router

    # Normalize e.g. "ROUND_ROBIN" / "round_robin" to the enum's "round-robin"
    strategy_str = config.get("strategy", "failover").lower().replace("_", "-")
    try:
        strategy = Strategy(strategy_str)
    except ValueError:
        strategy = Strategy.FAILOVER

    _router = LLMRouter(
        strategy=strategy,
        retries=config.get("retries", 3),
        retry_base_delay=config.get("retry_base_delay", 1.0),
        retry_max_delay=config.get("retry_max_delay", 60.0),
    )

    for backend_config in config.get("backends", []):
        backend = create_backend(backend_config)
        _router.add_backend(backend)

    return _router
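
# Example (sketch): configuring the global router from a parsed config file,
# assuming the "llm" block of organism.yaml holds the format documented above.
#
#     import yaml
#
#     with open("organism.yaml") as f:
#         config = yaml.safe_load(f)
#     router = configure_router(config["llm"])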


async def complete(
    model: str,
    messages: List[Dict[str, str]],
    *,
    thread_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> LLMResponse:
    """
    Convenience function - calls get_router().complete().

    Usage:
        from xml_pipeline.llm import router

        response = await router.complete(
            "grok-4.1",
            messages,
            thread_id=metadata.thread_id,
            agent_id=metadata.own_name,
        )

    Args:
        model: Model name
        messages: Chat messages
        thread_id: Thread UUID for budget enforcement
        agent_id: Agent name for usage tracking
        metadata: Extra metadata for billing events
        **kwargs: Additional arguments (temperature, max_tokens, tools)

    Returns:
        LLMResponse with content and usage stats

    Raises:
        BudgetExhaustedError: If thread budget exhausted
    """
    return await get_router().complete(
        model,
        messages,
        thread_id=thread_id,
        agent_id=agent_id,
        metadata=metadata,
        **kwargs,
    )
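
# Example (sketch): handling budget exhaustion at the call site. The import
# path for BudgetExhaustedError is an assumption, based on the budget registry
# module this file already imports from.
#
#     from xml_pipeline.llm import router
#     from xml_pipeline.message_bus.budget_registry import BudgetExhaustedError
#
#     try:
#         response = await router.complete("grok-4.1", messages, thread_id=tid)
#     except BudgetExhaustedError:
#         ...  # thread is over max_tokens_per_thread; stop or escalate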