xml-pipeline/xml_pipeline/message_bus/budget_registry.py
dullfig e6697f0ea2 Add BudgetWarning system alerts for token budget thresholds
- Create BudgetWarning primitive payload (75%, 90%, 95% thresholds)
- Add threshold tracking to ThreadBudget with triggered_thresholds set
- Change consume() to return (budget, crossed_thresholds) tuple
- Wire warning injection in LLM router when thresholds crossed
- Add 15 new tests for threshold detection and warning injection

Agents now receive BudgetWarning messages when approaching their token limit,
allowing them to design contingencies (summarize, escalate, save state).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:41:34 -08:00

353 lines
10 KiB
Python

"""
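Thread Budget Registry — enforces per-thread token limits.

Each thread has a token budget that tracks:
- Total tokens consumed (prompt + completion)
- Requests made
- Remaining budget

When a thread exhausts its budget, LLM calls are blocked: the pre-call budget
check raises BudgetExhaustedError. As usage crosses warning thresholds (75%,
90% and 95% of the budget by default), each newly crossed threshold is
reported once by ThreadBudgetRegistry.consume().

Example config:
    organism:
      max_tokens_per_thread: 100000  # 100k tokens per thread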
"""
from __future__ import annotations
import threading
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set, Tuple
# Thresholds used by ThreadBudget.check_thresholds() when no custom mapping is
# given: percent of the thread budget consumed -> severity label.
DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = {
    75: "warning",  # early heads-up
    90: "critical",  # time to wrap up
    95: "final",  # last chance before the budget runs out
}
@dataclass()
class BudgetThresholdCrossed:
    """Snapshot of a warning threshold that a thread's usage has just crossed.

    Produced by ThreadBudget.check_thresholds().
    """

    threshold_percent: int  # the threshold that was crossed, in percent
    severity: str  # severity label mapped to that threshold
    percent_used: float  # budget consumed at the time of the check
    tokens_used: int  # total tokens consumed so far
    tokens_remaining: int  # tokens left in the budget (never below 0)
    max_tokens: int  # the thread's total token budget
@dataclass
class ThreadBudget:
    """Track token usage for a single thread.

    Attributes:
        max_tokens: Total token budget for the thread.
        prompt_tokens: Prompt tokens recorded so far.
        completion_tokens: Completion tokens recorded so far.
        request_count: Number of consume() calls recorded.
        triggered_thresholds: Warning thresholds (in percent) already reported
            by check_thresholds(); each threshold is reported at most once.
    """

    max_tokens: int
    prompt_tokens: int = 0
    completion_tokens: int = 0
    request_count: int = 0
    triggered_thresholds: Set[int] = field(default_factory=set)

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed (prompt + completion)."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def remaining(self) -> int:
        """Remaining token budget, never below 0."""
        return max(0, self.max_tokens - self.total_tokens)

    @property
    def is_exhausted(self) -> bool:
        """True once total consumption has reached ``max_tokens``."""
        return self.total_tokens >= self.max_tokens

    @property
    def percent_used(self) -> float:
        """Percentage of budget consumed.

        Returns 0.0 for a non-positive ``max_tokens``; otherwise it is not
        clamped, so it can exceed 100 if consumption overshoots the budget.
        """
        if self.max_tokens <= 0:
            return 0.0
        return (self.total_tokens / self.max_tokens) * 100

    def can_consume(self, estimated_tokens: int) -> bool:
        """Return True if ``estimated_tokens`` more would stay within budget."""
        return self.total_tokens + estimated_tokens <= self.max_tokens

    def consume(
        self,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> None:
        """Record one request's token consumption.

        Always increments ``request_count`` by one. Does not enforce the
        budget; use can_consume() / is_exhausted for that.
        """
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.request_count += 1

    def _has_reached(self, threshold_percent: float) -> bool:
        """Return True if usage is at or above ``threshold_percent`` of budget.

        Compares with exact integer arithmetic rather than ``percent_used``:
        the float division can land just below an exact boundary (for example
        ``29 / 100 * 100 == 28.999999999999996``), which would silently skip
        a threshold that has in fact been reached.
        """
        if self.max_tokens <= 0:
            # Match percent_used, which reports 0.0 for a non-positive budget.
            return threshold_percent <= 0
        return self.total_tokens * 100 >= threshold_percent * self.max_tokens

    def check_thresholds(
        self,
        thresholds: Optional[Dict[int, str]] = None,
    ) -> List[BudgetThresholdCrossed]:
        """
        Report thresholds that have been reached but not reported before.

        Each newly crossed threshold is added to ``triggered_thresholds`` so
        later calls do not report it again. If usage has jumped past several
        thresholds at once, all of them are reported.

        Args:
            thresholds: Mapping of percent -> severity. Defaults to
                DEFAULT_WARNING_THRESHOLDS.

        Returns:
            Newly crossed thresholds, sorted by percent (ascending).
        """
        if thresholds is None:
            thresholds = DEFAULT_WARNING_THRESHOLDS

        # Snapshot reported with every threshold crossed in this call.
        reported_percent = round(self.percent_used, 1)

        crossed: List[BudgetThresholdCrossed] = []
        for threshold_percent, severity in sorted(thresholds.items()):
            if threshold_percent in self.triggered_thresholds:
                continue
            if not self._has_reached(threshold_percent):
                continue
            self.triggered_thresholds.add(threshold_percent)
            crossed.append(BudgetThresholdCrossed(
                threshold_percent=threshold_percent,
                severity=severity,
                percent_used=reported_percent,
                tokens_used=self.total_tokens,
                tokens_remaining=self.remaining,
                max_tokens=self.max_tokens,
            ))
        return crossed
class BudgetExhaustedError(Exception):
    """Raised when a thread's token budget is exhausted or would be exceeded.

    Attributes:
        thread_id: ID of the thread whose budget was hit.
        used: Tokens consumed when the error was raised.
        max_tokens: The thread's total token budget.
    """

    def __init__(self, thread_id: str, used: int, max_tokens: int):
        self.thread_id = thread_id
        self.used = used
        self.max_tokens = max_tokens
        # Only the first 8 characters of the thread ID are shown in the message.
        super().__init__(
            f"Thread {thread_id[:8]}... budget exhausted: "
            f"{used}/{max_tokens} tokens used"
        )

    def __reduce__(self):
        # BaseException's default __reduce__ rebuilds the instance from
        # self.args, which holds only the formatted message, not the three
        # constructor arguments, so unpickling/copying would raise TypeError.
        # Rebuild from the real arguments and carry over any extra attributes.
        return (
            type(self),
            (self.thread_id, self.used, self.max_tokens),
            self.__dict__,
        )
class ThreadBudgetRegistry:
    """
    Manages token budgets per thread.

    Registry operations are serialized by a single internal lock. The
    ThreadBudget objects handed out by get_budget()/consume() are the live
    instances; reading or mutating them outside the registry is not
    synchronized.

    Usage:
        registry = get_budget_registry()
        registry.configure(max_tokens_per_thread=100000)

        # Before LLM call
        registry.check_budget(thread_id, estimated_tokens=1000)

        # After LLM call
        budget, crossed = registry.consume(
            thread_id, prompt_tokens=500, completion_tokens=300
        )

        # Get usage
        budget = registry.get_budget(thread_id)
        print(f"Used: {budget.total_tokens}, Remaining: {budget.remaining}")
    """

    def __init__(self, max_tokens_per_thread: int = 100_000):
        """
        Initialize budget registry.

        Args:
            max_tokens_per_thread: Default budget for new threads.
        """
        self._max_tokens_per_thread = max_tokens_per_thread
        self._budgets: Dict[str, ThreadBudget] = {}
        self._lock = threading.Lock()

    def _get_or_create_locked(self, thread_id: str) -> ThreadBudget:
        """Return the budget for ``thread_id``, creating it if absent.

        Caller must already hold ``self._lock`` (the lock is not reentrant).
        """
        budget = self._budgets.get(thread_id)
        if budget is None:
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        return budget

    @staticmethod
    def _usage_dict(budget: ThreadBudget) -> Dict[str, int]:
        """Build the usage-stats dict reported by get_usage()/get_all_usage()."""
        return {
            "prompt_tokens": budget.prompt_tokens,
            "completion_tokens": budget.completion_tokens,
            "total_tokens": budget.total_tokens,
            "remaining": budget.remaining,
            "max_tokens": budget.max_tokens,
            "request_count": budget.request_count,
        }

    def configure(self, max_tokens_per_thread: int) -> None:
        """
        Update default max tokens for new threads.

        Existing threads keep their current budgets.
        """
        with self._lock:
            self._max_tokens_per_thread = max_tokens_per_thread

    @property
    def max_tokens_per_thread(self) -> int:
        """Get the default max tokens per thread."""
        return self._max_tokens_per_thread

    def get_budget(self, thread_id: str) -> ThreadBudget:
        """
        Get or create budget for a thread.

        Args:
            thread_id: Thread UUID

        Returns:
            The live ThreadBudget instance for the thread.
        """
        with self._lock:
            return self._get_or_create_locked(thread_id)

    def check_budget(
        self,
        thread_id: str,
        estimated_tokens: int = 0,
    ) -> bool:
        """
        Check if thread has budget for the estimated tokens.

        Creates the thread's budget if it does not exist yet.

        Args:
            thread_id: Thread UUID
            estimated_tokens: Estimated tokens for the request

        Returns:
            True if budget available

        Raises:
            BudgetExhaustedError: if the budget is already exhausted, or if
                ``estimated_tokens`` more would exceed it.
        """
        with self._lock:
            budget = self._get_or_create_locked(thread_id)
            # Snapshot under the lock so a concurrent consume() can't be
            # observed half-applied (prompt updated, completion not yet).
            used = budget.total_tokens
            max_tokens = budget.max_tokens
            # is_exhausted is tested on its own so that a negative estimate
            # cannot mask an already-exhausted budget.
            over_budget = (
                budget.is_exhausted or not budget.can_consume(estimated_tokens)
            )
        if over_budget:
            raise BudgetExhaustedError(
                thread_id=thread_id,
                used=used,
                max_tokens=max_tokens,
            )
        return True

    def consume(
        self,
        thread_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]:
        """
        Record token consumption for a thread.

        Creates the thread's budget if needed, records the tokens, then
        checks the default warning thresholds.

        Args:
            thread_id: Thread UUID
            prompt_tokens: Prompt tokens used
            completion_tokens: Completion tokens used

        Returns:
            Tuple of (updated ThreadBudget, list of newly crossed thresholds)
        """
        with self._lock:
            # Lookup and update share one lock acquisition, so a concurrent
            # reset_thread()/cleanup_thread() cannot drop the entry in between
            # (which would record tokens on an orphaned budget object).
            budget = self._get_or_create_locked(thread_id)
            budget.consume(prompt_tokens, completion_tokens)
            crossed = budget.check_thresholds()
        return budget, crossed

    def has_budget(self, thread_id: str) -> bool:
        """Check if a thread has a budget entry (without creating one)."""
        with self._lock:
            return thread_id in self._budgets

    def get_usage(self, thread_id: str) -> Optional[Dict[str, int]]:
        """
        Get usage stats for a thread.

        Returns:
            Dict with prompt_tokens, completion_tokens, total_tokens,
            remaining, max_tokens, request_count.
            Returns None if thread has no budget (none is created).
        """
        with self._lock:
            budget = self._budgets.get(thread_id)
            if budget is None:
                return None
            return self._usage_dict(budget)

    def get_all_usage(self) -> Dict[str, Dict[str, int]]:
        """Get usage stats for all threads, keyed by thread ID."""
        with self._lock:
            return {
                thread_id: self._usage_dict(budget)
                for thread_id, budget in self._budgets.items()
            }

    def reset_thread(self, thread_id: str) -> None:
        """Reset budget for a specific thread (no-op if it has none)."""
        with self._lock:
            self._budgets.pop(thread_id, None)

    def cleanup_thread(self, thread_id: str) -> Optional[ThreadBudget]:
        """
        Remove budget when thread is pruned/completed.

        Returns the final budget for logging/billing, or None if not found.
        """
        with self._lock:
            return self._budgets.pop(thread_id, None)

    def clear(self) -> None:
        """Clear all budgets (for testing)."""
        with self._lock:
            self._budgets.clear()
# =============================================================================
# Global Instance
# =============================================================================
# Serializes lazy creation and reset of the module-level registry below.
_registry_lock = threading.Lock()
# Module-level singleton; built on first get_budget_registry() call.
_registry: Optional[ThreadBudgetRegistry] = None
def get_budget_registry() -> ThreadBudgetRegistry:
    """Return the module-level budget registry, creating it on first use."""
    global _registry
    # Fast path: once the singleton exists, no locking is needed.
    if _registry is not None:
        return _registry
    with _registry_lock:
        # Re-check under the lock: another thread may have won the race.
        if _registry is None:
            _registry = ThreadBudgetRegistry()
    return _registry
def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry:
    """Set the global registry's default per-thread budget and return it.

    Threads that already have a budget keep it; only budgets created after
    this call use the new default.
    """
    global_registry = get_budget_registry()
    global_registry.configure(max_tokens_per_thread=max_tokens_per_thread)
    return global_registry
def reset_budget_registry() -> None:
    """Discard the global registry (for testing).

    Any existing registry is cleared first; the next get_budget_registry()
    call then builds a fresh one.
    """
    global _registry
    with _registry_lock:
        existing = _registry
        if existing is not None:
            # Empty it so stale references held elsewhere see no budgets.
            existing.clear()
        _registry = None