Add BudgetWarning system alerts for token budget thresholds

- Create BudgetWarning primitive payload (75%, 90%, 95% thresholds)
- Add threshold tracking to ThreadBudget with triggered_thresholds set
- Change consume() to return (budget, crossed_thresholds) tuple
- Wire warning injection in LLM router when thresholds crossed
- Add 15 new tests for threshold detection and warning injection

Agents now receive BudgetWarning messages when approaching their token limit,
allowing them to design contingencies (summarize, escalate, save state).

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
dullfig 2026-01-27 21:41:34 -08:00
parent f98a21f96b
commit e6697f0ea2
5 changed files with 630 additions and 21 deletions

View file

@ -8,12 +8,14 @@ Tests:
""" """
import pytest import pytest
from unittest.mock import Mock, AsyncMock, patch from unittest.mock import Mock, AsyncMock, patch, MagicMock
from xml_pipeline.message_bus.budget_registry import ( from xml_pipeline.message_bus.budget_registry import (
ThreadBudget, ThreadBudget,
ThreadBudgetRegistry, ThreadBudgetRegistry,
BudgetExhaustedError, BudgetExhaustedError,
BudgetThresholdCrossed,
DEFAULT_WARNING_THRESHOLDS,
get_budget_registry, get_budget_registry,
configure_budget_registry, configure_budget_registry,
reset_budget_registry, reset_budget_registry,
@ -122,7 +124,7 @@ class TestThreadBudgetRegistry:
def test_check_budget_exhausted(self): def test_check_budget_exhausted(self):
"""check_budget should raise when budget exhausted.""" """check_budget should raise when budget exhausted."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
registry.consume("thread-1", prompt_tokens=1000) budget, _ = registry.consume("thread-1", prompt_tokens=1000)
with pytest.raises(BudgetExhaustedError) as exc_info: with pytest.raises(BudgetExhaustedError) as exc_info:
registry.check_budget("thread-1", estimated_tokens=100) registry.check_budget("thread-1", estimated_tokens=100)
@ -135,25 +137,26 @@ class TestThreadBudgetRegistry:
def test_check_budget_would_exceed(self): def test_check_budget_would_exceed(self):
"""check_budget should raise when estimate would exceed.""" """check_budget should raise when estimate would exceed."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
registry.consume("thread-1", prompt_tokens=600) budget, _ = registry.consume("thread-1", prompt_tokens=600)
with pytest.raises(BudgetExhaustedError): with pytest.raises(BudgetExhaustedError):
registry.check_budget("thread-1", estimated_tokens=500) registry.check_budget("thread-1", estimated_tokens=500)
def test_consume_returns_budget(self): def test_consume_returns_budget_and_thresholds(self):
"""consume() should return updated budget.""" """consume() should return (budget, crossed_thresholds) tuple."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50) budget, crossed = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
assert budget.total_tokens == 150 assert budget.total_tokens == 150
assert budget.request_count == 1 assert budget.request_count == 1
assert crossed == [] # 1.5% - no threshold crossed
def test_get_usage(self): def test_get_usage(self):
"""get_usage should return dict with all stats.""" """get_usage should return dict with all stats."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
registry.consume("thread-1", prompt_tokens=300, completion_tokens=100) budget, _ = registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
usage = registry.get_usage("thread-1") usage = registry.get_usage("thread-1")
@ -166,8 +169,8 @@ class TestThreadBudgetRegistry:
def test_get_all_usage(self): def test_get_all_usage(self):
"""get_all_usage should return all threads.""" """get_all_usage should return all threads."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=100) budget, _ = registry.consume("thread-1", prompt_tokens=100)
registry.consume("thread-2", prompt_tokens=200) budget, _ = registry.consume("thread-2", prompt_tokens=200)
all_usage = registry.get_all_usage() all_usage = registry.get_all_usage()
@ -178,7 +181,7 @@ class TestThreadBudgetRegistry:
def test_reset_thread(self): def test_reset_thread(self):
"""reset_thread should remove budget for thread.""" """reset_thread should remove budget for thread."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500) budget, _ = registry.consume("thread-1", prompt_tokens=500)
registry.reset_thread("thread-1") registry.reset_thread("thread-1")
# Getting budget should create new one with zero usage # Getting budget should create new one with zero usage
@ -188,7 +191,7 @@ class TestThreadBudgetRegistry:
def test_cleanup_thread(self): def test_cleanup_thread(self):
"""cleanup_thread should return and remove budget.""" """cleanup_thread should return and remove budget."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
registry.consume("thread-1", prompt_tokens=500) budget, _ = registry.consume("thread-1", prompt_tokens=500)
final_budget = registry.cleanup_thread("thread-1") final_budget = registry.cleanup_thread("thread-1")
@ -520,7 +523,7 @@ class TestLLMRouterBudgetIntegration:
# Configure small budget and exhaust it # Configure small budget and exhaust it
configure_budget_registry(max_tokens_per_thread=100) configure_budget_registry(max_tokens_per_thread=100)
budget_registry = get_budget_registry() budget_registry = get_budget_registry()
budget_registry.consume("test-thread", prompt_tokens=100) budget, _ = budget_registry.consume("test-thread", prompt_tokens=100)
mock_backend = Mock() mock_backend = Mock()
mock_backend.name = "mock" mock_backend.name = "mock"
@ -590,7 +593,7 @@ class TestBudgetCleanup:
def test_cleanup_thread_returns_budget(self): def test_cleanup_thread_returns_budget(self):
"""cleanup_thread should return the budget before removing it.""" """cleanup_thread should return the budget before removing it."""
registry = ThreadBudgetRegistry() registry = ThreadBudgetRegistry()
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
final = registry.cleanup_thread("thread-1") final = registry.cleanup_thread("thread-1")
@ -602,7 +605,7 @@ class TestBudgetCleanup:
def test_cleanup_thread_removes_budget(self): def test_cleanup_thread_removes_budget(self):
"""cleanup_thread should remove the budget from registry.""" """cleanup_thread should remove the budget from registry."""
registry = ThreadBudgetRegistry() registry = ThreadBudgetRegistry()
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
registry.cleanup_thread("thread-1") registry.cleanup_thread("thread-1")
@ -624,7 +627,7 @@ class TestBudgetCleanup:
registry = get_budget_registry() registry = get_budget_registry()
# Consume some tokens # Consume some tokens
registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500) budget, _ = registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500)
assert registry.has_budget("test-thread") assert registry.has_budget("test-thread")
# Cleanup # Cleanup
@ -632,3 +635,389 @@ class TestBudgetCleanup:
assert final.total_tokens == 1500 assert final.total_tokens == 1500
assert not registry.has_budget("test-thread") assert not registry.has_budget("test-thread")
# ============================================================================
# Budget Threshold Tests
# ============================================================================
class TestBudgetThresholds:
    """Threshold-crossing detection for per-thread token budgets.

    Each test drives usage through ``consume()`` and inspects the list of
    newly crossed thresholds it returns. Percentages are floats computed
    from token counts, so they are compared with ``pytest.approx`` rather
    than exact equality.
    """

    @pytest.fixture(autouse=True)
    def reset_all(self):
        """Reset the global budget registry before and after every test."""
        reset_budget_registry()
        yield
        reset_budget_registry()

    def test_no_threshold_crossed_below_first(self):
        """No thresholds crossed when under 75%."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # 70% usage - below first threshold
        budget, crossed = registry.consume("thread-1", prompt_tokens=700)

        assert crossed == []
        assert budget.percent_used == pytest.approx(70.0)

    def test_warning_threshold_crossed_at_75(self):
        """75% threshold should trigger 'warning' severity."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # 75% usage - exactly at threshold
        budget, crossed = registry.consume("thread-1", prompt_tokens=750)

        assert len(crossed) == 1
        hit = crossed[0]
        assert hit.threshold_percent == 75
        assert hit.severity == "warning"
        assert hit.percent_used == pytest.approx(75.0)
        assert hit.tokens_used == 750
        assert hit.tokens_remaining == 250

    def test_critical_threshold_crossed_at_90(self):
        """90% threshold should trigger 'critical' severity."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # First consume to 76% (triggers warning)
        budget, crossed = registry.consume("thread-1", prompt_tokens=760)
        assert len(crossed) == 1
        assert crossed[0].severity == "warning"

        # Now consume to 91% (triggers critical)
        budget, crossed = registry.consume("thread-1", prompt_tokens=150)
        assert len(crossed) == 1
        assert crossed[0].threshold_percent == 90
        assert crossed[0].severity == "critical"

    def test_final_threshold_crossed_at_95(self):
        """95% threshold should trigger 'final' severity."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # Jump directly to 96%
        budget, crossed = registry.consume("thread-1", prompt_tokens=960)

        # Should have crossed all three thresholds
        assert len(crossed) == 3
        severities = {c.severity for c in crossed}
        assert severities == {"warning", "critical", "final"}

    def test_thresholds_only_triggered_once(self):
        """Each threshold should only be triggered once per thread."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # Cross the 75% threshold
        budget, crossed1 = registry.consume("thread-1", prompt_tokens=760)
        assert len(crossed1) == 1

        # Stay at 76% with another consume - should not trigger again
        budget, crossed2 = registry.consume("thread-1", prompt_tokens=0)
        assert crossed2 == []

        # Consume more to stay between 75-90% - should not trigger
        budget, crossed3 = registry.consume("thread-1", prompt_tokens=50)  # Now at 81%
        assert crossed3 == []

    def test_multiple_thresholds_in_single_consume(self):
        """A single large consume can cross multiple thresholds."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # Single consume from 0% to 92% - crosses 75% and 90%
        budget, crossed = registry.consume("thread-1", prompt_tokens=920)

        assert len(crossed) == 2
        thresholds = {c.threshold_percent for c in crossed}
        assert thresholds == {75, 90}

    def test_threshold_data_accuracy(self):
        """Threshold crossing data should be accurate."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)

        # Consume to 75.5%
        budget, crossed = registry.consume("thread-1", prompt_tokens=7550)

        assert len(crossed) == 1
        threshold = crossed[0]
        assert threshold.threshold_percent == 75
        assert threshold.percent_used == pytest.approx(75.5)
        assert threshold.tokens_used == 7550
        assert threshold.tokens_remaining == 2450
        assert threshold.max_tokens == 10000

    def test_percent_used_property(self):
        """percent_used property should calculate correctly."""
        budget = ThreadBudget(max_tokens=1000)
        assert budget.percent_used == pytest.approx(0.0)

        budget.consume(prompt_tokens=500)
        assert budget.percent_used == pytest.approx(50.0)

        budget.consume(prompt_tokens=250)
        assert budget.percent_used == pytest.approx(75.0)

    def test_percent_used_with_zero_max(self):
        """percent_used should handle zero max_tokens gracefully."""
        budget = ThreadBudget(max_tokens=0)
        assert budget.percent_used == pytest.approx(0.0)

    def test_triggered_thresholds_tracked_per_thread(self):
        """Different threads should track thresholds independently."""
        registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)

        # Thread 1 crosses 75%
        budget1, crossed1 = registry.consume("thread-1", prompt_tokens=760)
        assert len(crossed1) == 1

        # Thread 2 crosses 75% independently
        budget2, crossed2 = registry.consume("thread-2", prompt_tokens=800)
        assert len(crossed2) == 1

        # Both should have triggered their thresholds
        assert 75 in budget1.triggered_thresholds
        assert 75 in budget2.triggered_thresholds
# ============================================================================
# BudgetWarning Injection Tests
# ============================================================================
class TestBudgetWarningInjection:
    """Test BudgetWarning messages are injected when thresholds crossed.

    Every test builds a router around a single mock backend whose reported
    token usage decides which thresholds are crossed, then patches the
    stream pump accessor to observe (or deny) warning injection.
    """

    # Patch target: the accessor the router imports inside
    # _inject_budget_warnings at call time.
    PUMP_ACCESSOR = "xml_pipeline.message_bus.stream_pump.get_stream_pump"

    @pytest.fixture(autouse=True)
    def reset_all(self):
        """Reset all global registries."""
        reset_budget_registry()
        reset_usage_tracker()
        yield
        reset_budget_registry()
        reset_usage_tracker()

    @staticmethod
    def _make_router(content, prompt_tokens, completion_tokens):
        """Return an LLMRouter with one mock backend reporting the given usage."""
        from xml_pipeline.llm.router import LLMRouter
        from xml_pipeline.llm.backend import LLMResponse

        backend = Mock()
        backend.name = "mock"
        backend.provider = "test"
        backend.serves_model = Mock(return_value=True)
        backend.priority = 1
        backend.load = 0
        backend.complete = AsyncMock(return_value=LLMResponse(
            content=content,
            model="test-model",
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens,
            },
            finish_reason="stop",
        ))

        router = LLMRouter()
        router.backends.append(backend)
        return router

    @staticmethod
    async def _complete(router, **extra):
        """Run one completion on the shared test thread; ``extra`` is forwarded."""
        return await router.complete(
            model="test-model",
            messages=[{"role": "user", "content": "Hi"}],
            thread_id="test-thread",
            **extra,
        )

    @pytest.mark.asyncio
    async def test_warning_injected_on_threshold_crossing(self):
        """LLM complete should inject BudgetWarning when threshold crossed."""
        # Configure budget for 1000 tokens
        configure_budget_registry(max_tokens_per_thread=1000)

        # Backend reports 760 tokens (crosses 75% threshold)
        router = self._make_router("Hello!", prompt_tokens=700, completion_tokens=60)

        # Mock the pump to capture injected warnings
        injected_messages = []

        async def mock_inject(raw_bytes, thread_id, from_id):
            injected_messages.append({
                "raw_bytes": raw_bytes,
                "thread_id": thread_id,
                "from_id": from_id,
            })

        mock_pump = Mock()
        mock_pump.inject = AsyncMock(side_effect=mock_inject)
        mock_pump._wrap_in_envelope = Mock(return_value=b"<envelope>warning</envelope>")

        with patch(self.PUMP_ACCESSOR, return_value=mock_pump):
            await self._complete(router, agent_id="greeter")

        # Should have injected one warning (75% threshold)
        assert len(injected_messages) == 1
        assert injected_messages[0]["from_id"] == "system.budget"
        assert injected_messages[0]["thread_id"] == "test-thread"

    @pytest.mark.asyncio
    async def test_multiple_warnings_injected(self):
        """Multiple thresholds crossed should inject multiple warnings."""
        configure_budget_registry(max_tokens_per_thread=1000)

        # Backend reports 960 tokens (crosses 75%, 90%, and 95%)
        router = self._make_router("Big response", prompt_tokens=900, completion_tokens=60)

        mock_pump = Mock()
        mock_pump.inject = AsyncMock()
        mock_pump._wrap_in_envelope = Mock(return_value=b"<envelope>warning</envelope>")

        with patch(self.PUMP_ACCESSOR, return_value=mock_pump):
            await self._complete(router, agent_id="agent")

        # Should have injected 3 warnings (75%, 90%, 95%)
        assert mock_pump.inject.await_count == 3

    @pytest.mark.asyncio
    async def test_no_warning_when_below_threshold(self):
        """No warning should be injected when below all thresholds."""
        configure_budget_registry(max_tokens_per_thread=10000)

        # Backend reports 100 tokens (only 1%)
        router = self._make_router("Small response", prompt_tokens=80, completion_tokens=20)

        mock_pump = Mock()
        mock_pump.inject = AsyncMock()

        with patch(self.PUMP_ACCESSOR, return_value=mock_pump):
            await self._complete(router)

        # No warnings should be injected
        mock_pump.inject.assert_not_called()

    @pytest.mark.asyncio
    async def test_warning_includes_correct_severity(self):
        """Warning messages should have correct severity based on threshold."""
        configure_budget_registry(max_tokens_per_thread=1000)

        # Backend reports 910 tokens (crosses 75% and 90%)
        router = self._make_router("Response", prompt_tokens=910, completion_tokens=0)

        captured_payloads = []

        def capture_wrap(payload, from_id, to_id, thread_id):
            captured_payloads.append({
                "payload": payload,
                "from_id": from_id,
                "to_id": to_id,
            })
            return b"<envelope/>"

        mock_pump = Mock()
        mock_pump.inject = AsyncMock()
        mock_pump._wrap_in_envelope = Mock(side_effect=capture_wrap)

        with patch(self.PUMP_ACCESSOR, return_value=mock_pump):
            await self._complete(router, agent_id="agent")

        # Should have captured 2 payloads (75% and 90%)
        assert len(captured_payloads) == 2

        severities = [p["payload"].severity for p in captured_payloads]
        assert "warning" in severities
        assert "critical" in severities

        # Check the critical warning has appropriate message
        critical = next(p for p in captured_payloads if p["payload"].severity == "critical")
        assert "WARNING" in critical["payload"].message
        assert "Consider wrapping up" in critical["payload"].message

    @pytest.mark.asyncio
    async def test_warning_graceful_without_pump(self):
        """Warning injection should gracefully handle missing pump."""
        configure_budget_registry(max_tokens_per_thread=1000)

        # Backend reports 800 tokens (crosses 75%)
        router = self._make_router("Response", prompt_tokens=800, completion_tokens=0)

        # Pump not initialized - should not raise
        with patch(self.PUMP_ACCESSOR, side_effect=RuntimeError("Not initialized")):
            response = await self._complete(router)

        # Should still complete successfully
        assert response.content == "Response"

View file

@ -210,15 +210,24 @@ class LLMRouter:
usage.request_count += 1 usage.request_count += 1
# Record to thread budget (enforcement) # Record to thread budget (enforcement)
crossed_thresholds = []
if thread_id: if thread_id:
from xml_pipeline.message_bus.budget_registry import get_budget_registry from xml_pipeline.message_bus.budget_registry import get_budget_registry
budget_registry = get_budget_registry() budget_registry = get_budget_registry()
budget_registry.consume( _budget, crossed_thresholds = budget_registry.consume(
thread_id, thread_id,
prompt_tokens=prompt_tokens, prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens, completion_tokens=completion_tokens,
) )
# Inject BudgetWarning messages for any newly crossed thresholds
if crossed_thresholds:
await self._inject_budget_warnings(
thread_id=thread_id,
agent_id=agent_id,
crossed=crossed_thresholds,
)
# Emit usage event (for billing) # Emit usage event (for billing)
from xml_pipeline.llm.usage_tracker import get_usage_tracker from xml_pipeline.llm.usage_tracker import get_usage_tracker
tracker = get_usage_tracker() tracker = get_usage_tracker()
@ -265,6 +274,80 @@ class LLMRouter:
jitter = delay * 0.25 * (random.random() * 2 - 1) jitter = delay * 0.25 * (random.random() * 2 - 1)
return delay + jitter return delay + jitter
async def _inject_budget_warnings(
    self,
    thread_id: str,
    agent_id: Optional[str],
    crossed: List,  # List[BudgetThresholdCrossed]
) -> None:
    """
    Inject BudgetWarning messages when thresholds are crossed.

    Each crossed threshold generates a separate warning message
    sent to the agent via the stream pump. Delivery is best-effort:
    this runs after the LLM response has been received and recorded,
    so a missing pump or a failed injection is logged and skipped
    rather than raised.

    Args:
        thread_id: Thread whose budget crossed the thresholds.
        agent_id: Agent that made the LLM call; warnings are addressed
            to "system" when this is empty or None.
        crossed: Newly crossed thresholds, as returned by the budget
            registry's consume().
    """
    if not crossed:
        return

    from xml_pipeline.primitives.budget_warning import BudgetWarning

    # Get the stream pump (may not be initialized in tests)
    try:
        from xml_pipeline.message_bus.stream_pump import get_stream_pump
        pump = get_stream_pump()
    except RuntimeError:
        # Pump not initialized - log and skip
        logger.warning(
            f"Cannot inject BudgetWarning: StreamPump not initialized. "
            f"Thread {thread_id[:8]}... crossed thresholds: {[c.threshold_percent for c in crossed]}"
        )
        return

    # Target is the agent that made the LLM call (same for every warning)
    target = agent_id or "system"

    for threshold in crossed:
        # Build human-readable message based on severity
        if threshold.severity == "final":
            message = (
                f"CRITICAL: {threshold.percent_used:.0f}% of token budget used. "
                f"Only {threshold.tokens_remaining:,} tokens remaining. "
                f"Wrap up immediately or your next request may fail."
            )
        elif threshold.severity == "critical":
            message = (
                f"WARNING: {threshold.percent_used:.0f}% of token budget used. "
                f"{threshold.tokens_remaining:,} tokens remaining. "
                f"Consider wrapping up your current task soon."
            )
        else:  # warning
            message = (
                f"Note: {threshold.percent_used:.0f}% of token budget used. "
                f"{threshold.tokens_remaining:,} tokens remaining."
            )

        warning_payload = BudgetWarning(
            percent_used=threshold.percent_used,
            tokens_used=threshold.tokens_used,
            tokens_remaining=threshold.tokens_remaining,
            max_tokens=threshold.max_tokens,
            severity=threshold.severity,
            message=message,
        )

        try:
            # Create envelope for the warning and inject it into the pump's queue
            envelope = pump._wrap_in_envelope(
                payload=warning_payload,
                from_id="system.budget",
                to_id=target,
                thread_id=thread_id,
            )
            await pump.inject(envelope, thread_id=thread_id, from_id="system.budget")
        except Exception:
            # A failed advisory warning must not discard the completed LLM
            # call; record the failure and try the remaining thresholds.
            logger.exception(
                "Failed to inject BudgetWarning (%s) for thread %s...",
                threshold.severity,
                thread_id[:8],
            )
            continue

        logger.info(
            f"BudgetWarning sent to {target}: {threshold.severity} "
            f"({threshold.percent_used:.0f}% used, {threshold.tokens_remaining:,} remaining)"
        )
def get_agent_usage(self, agent_id: str) -> AgentUsage: def get_agent_usage(self, agent_id: str) -> AgentUsage:
"""Get usage stats for an agent.""" """Get usage stats for an agent."""
return self._agent_usage.get(agent_id, AgentUsage()) return self._agent_usage.get(agent_id, AgentUsage())

View file

@ -17,7 +17,26 @@ from __future__ import annotations
import threading import threading
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, Optional from typing import Dict, List, Optional, Set, Tuple
# Percent-of-budget thresholds mapped to the severity label reported when
# usage reaches them: 75 = early warning, 90 = wrap up soon, 95 = last chance.
DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = {75: "warning", 90: "critical", 95: "final"}
@dataclass
class BudgetThresholdCrossed:
    """Snapshot of a thread's budget at the moment a threshold was newly reached."""

    # The threshold that was reached, as a whole percentage (e.g. 75).
    threshold_percent: int
    # Severity label associated with that threshold (e.g. "warning").
    severity: str
    # Percentage of the budget consumed when the crossing was detected.
    percent_used: float
    # Total tokens consumed so far.
    tokens_used: int
    # Tokens left in the budget.
    tokens_remaining: int
    # Total token budget.
    max_tokens: int
@dataclass @dataclass
@ -28,6 +47,7 @@ class ThreadBudget:
prompt_tokens: int = 0 prompt_tokens: int = 0
completion_tokens: int = 0 completion_tokens: int = 0
request_count: int = 0 request_count: int = 0
triggered_thresholds: Set[int] = field(default_factory=set)
@property @property
def total_tokens(self) -> int: def total_tokens(self) -> int:
@ -44,6 +64,13 @@ class ThreadBudget:
"""True if budget is exhausted.""" """True if budget is exhausted."""
return self.total_tokens >= self.max_tokens return self.total_tokens >= self.max_tokens
@property
def percent_used(self) -> float:
    """Percentage of budget consumed (0 when max_tokens is not positive).

    The value may exceed 100 if more tokens were consumed than the budget
    allows.
    """
    if self.max_tokens <= 0:
        return 0.0
    # Multiply before dividing: both operands are ints, so usage that lands
    # exactly on a whole percent yields that exact float (58/100*100 would
    # give 57.99999999999999 and miss a ">= 58" threshold check).
    return self.total_tokens * 100 / self.max_tokens
def can_consume(self, estimated_tokens: int) -> bool: def can_consume(self, estimated_tokens: int) -> bool:
"""Check if we can consume the given tokens without exceeding budget.""" """Check if we can consume the given tokens without exceeding budget."""
return self.total_tokens + estimated_tokens <= self.max_tokens return self.total_tokens + estimated_tokens <= self.max_tokens
@ -58,6 +85,42 @@ class ThreadBudget:
self.completion_tokens += completion_tokens self.completion_tokens += completion_tokens
self.request_count += 1 self.request_count += 1
def check_thresholds(
    self,
    thresholds: Optional[Dict[int, str]] = None,
) -> List[BudgetThresholdCrossed]:
    """
    Check if any thresholds were crossed that haven't been triggered yet.

    Every threshold reported here is also recorded in
    ``triggered_thresholds``, so it is returned at most once for this
    budget.

    Args:
        thresholds: Dict of percent -> severity. Defaults to DEFAULT_WARNING_THRESHOLDS.

    Returns:
        List of newly crossed thresholds (sorted by percent)
    """
    if thresholds is None:
        thresholds = DEFAULT_WARNING_THRESHOLDS

    current_percent = self.percent_used
    crossed = []

    for threshold_percent, severity in sorted(thresholds.items()):
        if threshold_percent in self.triggered_thresholds:
            continue
        if current_percent < threshold_percent:
            # Thresholds are visited in ascending order, so no later one
            # can have been reached either.
            break

        self.triggered_thresholds.add(threshold_percent)
        crossed.append(BudgetThresholdCrossed(
            threshold_percent=threshold_percent,
            severity=severity,
            # Rounded for reporting only; the comparison above uses the raw value.
            percent_used=round(current_percent, 1),
            tokens_used=self.total_tokens,
            tokens_remaining=self.remaining,
            max_tokens=self.max_tokens,
        ))

    return crossed
class BudgetExhaustedError(Exception): class BudgetExhaustedError(Exception):
"""Raised when a thread's token budget is exhausted.""" """Raised when a thread's token budget is exhausted."""
@ -176,7 +239,7 @@ class ThreadBudgetRegistry:
thread_id: str, thread_id: str,
prompt_tokens: int = 0, prompt_tokens: int = 0,
completion_tokens: int = 0, completion_tokens: int = 0,
) -> ThreadBudget: ) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]:
""" """
Record token consumption for a thread. Record token consumption for a thread.
@ -186,12 +249,13 @@ class ThreadBudgetRegistry:
completion_tokens: Completion tokens used completion_tokens: Completion tokens used
Returns: Returns:
Updated ThreadBudget Tuple of (Updated ThreadBudget, List of newly crossed thresholds)
""" """
budget = self.get_budget(thread_id) budget = self.get_budget(thread_id)
with self._lock: with self._lock:
budget.consume(prompt_tokens, completion_tokens) budget.consume(prompt_tokens, completion_tokens)
return budget crossed = budget.check_thresholds()
return budget, crossed
def has_budget(self, thread_id: str) -> bool: def has_budget(self, thread_id: str) -> bool:
"""Check if a thread has a budget entry (without creating one).""" """Check if a thread has a budget entry (without creating one)."""

View file

@ -29,6 +29,10 @@ from xml_pipeline.primitives.buffer import (
BufferError, BufferError,
handle_buffer_start, handle_buffer_start,
) )
from xml_pipeline.primitives.budget_warning import (
BudgetWarning,
DEFAULT_THRESHOLDS,
)
__all__ = [ __all__ = [
# Boot # Boot
@ -56,4 +60,7 @@ __all__ = [
"BufferDispatched", "BufferDispatched",
"BufferError", "BufferError",
"handle_buffer_start", "handle_buffer_start",
# Budget warnings
"BudgetWarning",
"DEFAULT_THRESHOLDS",
] ]

View file

@ -0,0 +1,66 @@
"""
BudgetWarning System alerts for token budget thresholds.
When a thread approaches its token budget limit, the system injects
BudgetWarning messages to give agents a chance to wrap up gracefully.
Thresholds:
- 75% ("warning"): Early warning - informational note on budget usage
- 90% ("critical"): Critical warning - consider wrapping up the current task soon
- 95% ("final"): Final warning - wrap up immediately, the next request may fail
Agents can design contingencies:
- Summarize progress and respond early
- Escalate to a supervisor agent
- Save state for continuation
- Request budget increase (if supported)
Example handler pattern:
    async def my_agent(payload, metadata):
        if isinstance(payload, BudgetWarning):
            if payload.severity in ("critical", "final"):
                # Little budget is left: wrap up now
                return HandlerResponse.respond(
                    payload=Summary(progress=f"Reached {payload.percent_used:.0f}% of budget, stopping here...")
                )
            # "warning" severity: note it and continue
        ...
"""
# Note: Do NOT use `from __future__ import annotations` here
# as it breaks the xmlify decorator which needs concrete types
from dataclasses import dataclass
from third_party.xmlable import xmlify
@xmlify
@dataclass
class BudgetWarning:
    """
    Payload notifying an agent how much of its thread's token budget is used.

    One instance is built per crossed threshold and carries a snapshot of
    the thread's usage together with a ready-to-display message.
    """

    # Percentage of the budget consumed so far (0-100).
    percent_used: float
    # Total tokens consumed so far.
    tokens_used: int
    # Tokens left before the budget is exhausted.
    tokens_remaining: int
    # Total token budget for the thread.
    max_tokens: int
    # Warning level: "warning" (75%), "critical" (90%) or "final" (95%).
    severity: str
    # Human-readable description of the situation.
    message: str
# Default percent-used thresholds mapped to BudgetWarning severity labels:
# 75 = early warning, 90 = wrap up soon, 95 = last chance.
# NOTE(review): same values as DEFAULT_WARNING_THRESHOLDS in
# message_bus.budget_registry - keep the two in sync.
DEFAULT_THRESHOLDS = {75: "warning", 90: "critical", 95: "final"}