- Create BudgetWarning primitive payload (75%, 90%, 95% thresholds) - Add threshold tracking to ThreadBudget with triggered_thresholds set - Change consume() to return (budget, crossed_thresholds) tuple - Wire warning injection in LLM router when thresholds crossed - Add 15 new tests for threshold detection and warning injection Agents now receive BudgetWarning messages when approaching their token limit, allowing them to design contingencies (summarize, escalate, save state). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
353 lines
10 KiB
Python
353 lines
10 KiB
Python
"""
|
|
Thread Budget Registry — Enforces per-thread token limits.

Each thread has a token budget that tracks:
- Total tokens consumed (prompt + completion)
- Requests made
- Remaining budget

When a thread exhausts its budget, LLM calls are blocked.

Example config:
    organism:
      max_tokens_per_thread: 100000  # 100k tokens per thread
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import threading
|
|
from dataclasses import dataclass, field
|
|
from typing import Dict, List, Optional, Set, Tuple
|
|
|
|
|
|
# Warning levels used when a caller does not supply its own mapping.
# Keys are percent-of-budget marks, values are the severity label reported
# when that mark is reached:
#   75 -> early warning, 90 -> wrap up soon, 95 -> last chance.
DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = {75: "warning", 90: "critical", 95: "final"}
|
|
|
|
|
|
@dataclass
class BudgetThresholdCrossed:
    """Snapshot describing one warning threshold that has just been reached."""

    # The configured mark (in percent) that was reached.
    threshold_percent: int
    # Severity label associated with that mark.
    severity: str
    # Budget usage, in percent, at the moment of detection.
    percent_used: float
    # Token counters captured at the same moment.
    tokens_used: int
    tokens_remaining: int
    max_tokens: int
|
|
|
|
|
|
@dataclass
class ThreadBudget:
    """Track token usage for a single thread.

    Holds the token cap together with running counters, and remembers which
    warning thresholds have already been reported so each is reported once.
    """

    # Upper bound on prompt + completion tokens for this thread.
    max_tokens: int
    # Running totals, increased by consume().
    prompt_tokens: int = 0
    completion_tokens: int = 0
    # Number of consume() calls recorded so far.
    request_count: int = 0
    # Threshold percents already returned by check_thresholds().
    triggered_thresholds: Set[int] = field(default_factory=set)

    @property
    def total_tokens(self) -> int:
        """Total tokens consumed (prompt + completion)."""
        return self.prompt_tokens + self.completion_tokens

    @property
    def remaining(self) -> int:
        """Remaining token budget, never negative."""
        return max(0, self.max_tokens - self.total_tokens)

    @property
    def is_exhausted(self) -> bool:
        """True once consumption has reached or passed ``max_tokens``."""
        return self.total_tokens >= self.max_tokens

    @property
    def percent_used(self) -> float:
        """Percentage of the budget consumed.

        Returns 0.0 when ``max_tokens`` is not positive, which also avoids a
        division by zero. The value can exceed 100 if consumption overshoots
        the cap.
        """
        if self.max_tokens <= 0:
            return 0.0
        return (self.total_tokens / self.max_tokens) * 100

    def can_consume(self, estimated_tokens: int) -> bool:
        """Return True if ``estimated_tokens`` more tokens still fit the cap."""
        return self.total_tokens + estimated_tokens <= self.max_tokens

    def consume(
        self,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> None:
        """Record token consumption for one request.

        Args:
            prompt_tokens: Prompt tokens to add to the running total.
            completion_tokens: Completion tokens to add to the running total.
        """
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens
        self.request_count += 1

    def check_thresholds(
        self,
        thresholds: Optional[Dict[int, str]] = None,
    ) -> List[BudgetThresholdCrossed]:
        """Report thresholds that are reached now and were not reported before.

        Every threshold returned is added to ``triggered_thresholds``, so a
        later call will not return it again.

        Args:
            thresholds: Dict of percent -> severity. Defaults to
                DEFAULT_WARNING_THRESHOLDS when None.

        Returns:
            List of newly crossed thresholds, ordered by ascending percent.
        """
        if thresholds is None:
            thresholds = DEFAULT_WARNING_THRESHOLDS

        crossed: List[BudgetThresholdCrossed] = []
        # Read once so every record produced by this call shares one snapshot.
        current_percent = self.percent_used

        for threshold_percent, severity in sorted(thresholds.items()):
            if current_percent < threshold_percent:
                continue
            if threshold_percent in self.triggered_thresholds:
                continue
            self.triggered_thresholds.add(threshold_percent)
            crossed.append(
                BudgetThresholdCrossed(
                    threshold_percent=threshold_percent,
                    severity=severity,
                    percent_used=round(current_percent, 1),
                    tokens_used=self.total_tokens,
                    tokens_remaining=self.remaining,
                    max_tokens=self.max_tokens,
                )
            )

        return crossed
|
|
|
|
|
|
class BudgetExhaustedError(Exception):
    """Signals that a thread has no token budget left for the request."""

    def __init__(self, thread_id: str, used: int, max_tokens: int):
        # Only the first eight characters of the thread id go in the message.
        message = (
            f"Thread {thread_id[:8]}... budget exhausted: "
            f"{used}/{max_tokens} tokens used"
        )
        super().__init__(message)
        # Full values stay available to handlers as attributes.
        self.thread_id = thread_id
        self.used = used
        self.max_tokens = max_tokens
|
|
|
|
|
|
class ThreadBudgetRegistry:
    """
    Manages token budgets per thread.

    Every registry method that reads or changes the budget table does so
    while holding one internal lock. The ThreadBudget objects handed back to
    callers are plain mutable objects and are not synchronized themselves.

    Usage:
        registry = get_budget_registry()
        registry.configure(max_tokens_per_thread=100000)

        # Before LLM call
        registry.check_budget(thread_id, estimated_tokens=1000)

        # After LLM call
        budget, crossed = registry.consume(
            thread_id, prompt_tokens=500, completion_tokens=300
        )

        # Get usage
        budget = registry.get_budget(thread_id)
        print(f"Used: {budget.total_tokens}, Remaining: {budget.remaining}")
    """

    def __init__(self, max_tokens_per_thread: int = 100_000):
        """
        Initialize budget registry.

        Args:
            max_tokens_per_thread: Default budget for new threads.
        """
        self._max_tokens_per_thread = max_tokens_per_thread
        self._budgets: Dict[str, ThreadBudget] = {}
        # Non-reentrant: helpers with a ``_locked`` suffix expect it held.
        self._lock = threading.Lock()

    def configure(self, max_tokens_per_thread: int) -> None:
        """
        Update default max tokens for new threads.

        Existing threads keep their current budgets.
        """
        with self._lock:
            self._max_tokens_per_thread = max_tokens_per_thread

    @property
    def max_tokens_per_thread(self) -> int:
        """Get the default max tokens per thread."""
        return self._max_tokens_per_thread

    def _get_or_create_locked(self, thread_id: str) -> ThreadBudget:
        """Return the budget for ``thread_id``, creating it if absent.

        The caller must already hold ``self._lock``.
        """
        budget = self._budgets.get(thread_id)
        if budget is None:
            budget = ThreadBudget(max_tokens=self._max_tokens_per_thread)
            self._budgets[thread_id] = budget
        return budget

    @staticmethod
    def _usage_snapshot(budget: ThreadBudget) -> Dict[str, int]:
        """Build the usage dict shared by get_usage() and get_all_usage()."""
        return {
            "prompt_tokens": budget.prompt_tokens,
            "completion_tokens": budget.completion_tokens,
            "total_tokens": budget.total_tokens,
            "remaining": budget.remaining,
            "max_tokens": budget.max_tokens,
            "request_count": budget.request_count,
        }

    def get_budget(self, thread_id: str) -> ThreadBudget:
        """
        Get or create budget for a thread.

        Args:
            thread_id: Thread UUID

        Returns:
            ThreadBudget instance
        """
        with self._lock:
            return self._get_or_create_locked(thread_id)

    def check_budget(
        self,
        thread_id: str,
        estimated_tokens: int = 0,
    ) -> bool:
        """
        Check if thread has budget for the estimated tokens.

        The budget is created if the thread has none yet. Lookup and check
        happen under a single lock acquisition.

        Args:
            thread_id: Thread UUID
            estimated_tokens: Estimated tokens for the request

        Returns:
            True if budget available

        Raises:
            BudgetExhaustedError if the budget is already exhausted or the
            estimated tokens would not fit in what remains
        """
        with self._lock:
            budget = self._get_or_create_locked(thread_id)
            # Exhaustion is tested first, so an exhausted budget is reported
            # before the estimate is evaluated.
            if budget.is_exhausted or not budget.can_consume(estimated_tokens):
                raise BudgetExhaustedError(
                    thread_id=thread_id,
                    used=budget.total_tokens,
                    max_tokens=budget.max_tokens,
                )
        return True

    def consume(
        self,
        thread_id: str,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]:
        """
        Record token consumption for a thread.

        Lookup, update and threshold check run under one lock acquisition, so
        the budget that is updated is the one registered at that moment.

        Args:
            thread_id: Thread UUID
            prompt_tokens: Prompt tokens used
            completion_tokens: Completion tokens used

        Returns:
            Tuple of (Updated ThreadBudget, List of newly crossed thresholds)
        """
        with self._lock:
            budget = self._get_or_create_locked(thread_id)
            budget.consume(prompt_tokens, completion_tokens)
            crossed = budget.check_thresholds()
        return budget, crossed

    def has_budget(self, thread_id: str) -> bool:
        """Check if a thread has a budget entry (without creating one)."""
        with self._lock:
            return thread_id in self._budgets

    def get_usage(self, thread_id: str) -> Optional[Dict[str, int]]:
        """
        Get usage stats for a thread.

        Returns:
            Dict with prompt_tokens, completion_tokens, total_tokens,
            remaining, max_tokens, request_count.
            Returns None if thread has no budget.
        """
        with self._lock:
            budget = self._budgets.get(thread_id)
            if budget is None:
                return None
            return self._usage_snapshot(budget)

    def get_all_usage(self) -> Dict[str, Dict[str, int]]:
        """Get usage stats for all threads, keyed by thread id."""
        with self._lock:
            return {
                thread_id: self._usage_snapshot(budget)
                for thread_id, budget in self._budgets.items()
            }

    def reset_thread(self, thread_id: str) -> None:
        """Reset budget for a specific thread by dropping its entry."""
        with self._lock:
            self._budgets.pop(thread_id, None)

    def cleanup_thread(self, thread_id: str) -> Optional[ThreadBudget]:
        """
        Remove budget when thread is pruned/completed.

        Returns the final budget for logging/billing, or None if not found.
        """
        with self._lock:
            return self._budgets.pop(thread_id, None)

    def clear(self) -> None:
        """Clear all budgets (for testing)."""
        with self._lock:
            self._budgets.clear()
|
|
|
|
|
|
# =============================================================================
|
|
# Global Instance
|
|
# =============================================================================
|
|
|
|
# Serializes creation and teardown of the shared registry instance.
_registry_lock = threading.Lock()
# Shared instance; None until get_budget_registry() creates it.
_registry: Optional[ThreadBudgetRegistry] = None
|
|
|
|
|
|
def get_budget_registry() -> ThreadBudgetRegistry:
    """Return the shared budget registry, building it on first use."""
    global _registry
    # Already built: hand it back without touching the lock.
    if _registry is not None:
        return _registry
    with _registry_lock:
        # Checked again under the lock before constructing the instance.
        if _registry is None:
            _registry = ThreadBudgetRegistry()
    return _registry
|
|
|
|
|
|
def configure_budget_registry(max_tokens_per_thread: int) -> ThreadBudgetRegistry:
    """Apply a new default per-thread limit to the shared registry and return it."""
    shared = get_budget_registry()
    shared.configure(max_tokens_per_thread)
    return shared
|
|
|
|
|
|
def reset_budget_registry() -> None:
    """Empty and discard the shared registry (for testing)."""
    global _registry
    with _registry_lock:
        current = _registry
        # Empty the old instance first, then drop the module reference.
        if current is not None:
            current.clear()
        _registry = None
|