diff --git a/tests/test_token_budget.py b/tests/test_token_budget.py
index 1c93438..94b551c 100644
--- a/tests/test_token_budget.py
+++ b/tests/test_token_budget.py
@@ -8,12 +8,14 @@ Tests:
"""
import pytest
-from unittest.mock import Mock, AsyncMock, patch
+from unittest.mock import Mock, AsyncMock, patch, MagicMock
from xml_pipeline.message_bus.budget_registry import (
ThreadBudget,
ThreadBudgetRegistry,
BudgetExhaustedError,
+ BudgetThresholdCrossed,
+ DEFAULT_WARNING_THRESHOLDS,
get_budget_registry,
configure_budget_registry,
reset_budget_registry,
@@ -122,7 +124,7 @@ class TestThreadBudgetRegistry:
def test_check_budget_exhausted(self):
"""check_budget should raise when budget exhausted."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
- registry.consume("thread-1", prompt_tokens=1000)
+ budget, _ = registry.consume("thread-1", prompt_tokens=1000)
with pytest.raises(BudgetExhaustedError) as exc_info:
registry.check_budget("thread-1", estimated_tokens=100)
@@ -135,25 +137,26 @@ class TestThreadBudgetRegistry:
def test_check_budget_would_exceed(self):
"""check_budget should raise when estimate would exceed."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
- registry.consume("thread-1", prompt_tokens=600)
+ budget, _ = registry.consume("thread-1", prompt_tokens=600)
with pytest.raises(BudgetExhaustedError):
registry.check_budget("thread-1", estimated_tokens=500)
- def test_consume_returns_budget(self):
- """consume() should return updated budget."""
+ def test_consume_returns_budget_and_thresholds(self):
+ """consume() should return (budget, crossed_thresholds) tuple."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
- budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
+ budget, crossed = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
assert budget.total_tokens == 150
assert budget.request_count == 1
+ assert crossed == [] # 1.5% - no threshold crossed
def test_get_usage(self):
"""get_usage should return dict with all stats."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
- registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
- registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
+ budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
+ budget, _ = registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
usage = registry.get_usage("thread-1")
@@ -166,8 +169,8 @@ class TestThreadBudgetRegistry:
def test_get_all_usage(self):
"""get_all_usage should return all threads."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
- registry.consume("thread-1", prompt_tokens=100)
- registry.consume("thread-2", prompt_tokens=200)
+ budget, _ = registry.consume("thread-1", prompt_tokens=100)
+ budget, _ = registry.consume("thread-2", prompt_tokens=200)
all_usage = registry.get_all_usage()
@@ -178,7 +181,7 @@ class TestThreadBudgetRegistry:
def test_reset_thread(self):
"""reset_thread should remove budget for thread."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
- registry.consume("thread-1", prompt_tokens=500)
+ budget, _ = registry.consume("thread-1", prompt_tokens=500)
registry.reset_thread("thread-1")
# Getting budget should create new one with zero usage
@@ -188,7 +191,7 @@ class TestThreadBudgetRegistry:
def test_cleanup_thread(self):
"""cleanup_thread should return and remove budget."""
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
- registry.consume("thread-1", prompt_tokens=500)
+ budget, _ = registry.consume("thread-1", prompt_tokens=500)
final_budget = registry.cleanup_thread("thread-1")
@@ -520,7 +523,7 @@ class TestLLMRouterBudgetIntegration:
# Configure small budget and exhaust it
configure_budget_registry(max_tokens_per_thread=100)
budget_registry = get_budget_registry()
- budget_registry.consume("test-thread", prompt_tokens=100)
+ budget, _ = budget_registry.consume("test-thread", prompt_tokens=100)
mock_backend = Mock()
mock_backend.name = "mock"
@@ -590,7 +593,7 @@ class TestBudgetCleanup:
def test_cleanup_thread_returns_budget(self):
"""cleanup_thread should return the budget before removing it."""
registry = ThreadBudgetRegistry()
- registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
+ budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
final = registry.cleanup_thread("thread-1")
@@ -602,7 +605,7 @@ class TestBudgetCleanup:
def test_cleanup_thread_removes_budget(self):
"""cleanup_thread should remove the budget from registry."""
registry = ThreadBudgetRegistry()
- registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
+ budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
registry.cleanup_thread("thread-1")
@@ -624,7 +627,7 @@ class TestBudgetCleanup:
registry = get_budget_registry()
# Consume some tokens
- registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500)
+ budget, _ = registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500)
assert registry.has_budget("test-thread")
# Cleanup
@@ -632,3 +635,389 @@ class TestBudgetCleanup:
assert final.total_tokens == 1500
assert not registry.has_budget("test-thread")
+
+
+# ============================================================================
+# Budget Threshold Tests
+# ============================================================================
+
+class TestBudgetThresholds:
+ """Test budget threshold crossing detection."""
+
+ @pytest.fixture(autouse=True)
+ def reset_all(self):
+ """Reset all global registries."""
+ reset_budget_registry()
+ yield
+ reset_budget_registry()
+
+ def test_no_threshold_crossed_below_first(self):
+ """No thresholds crossed when under 75%."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # 70% usage - below first threshold
+ budget, crossed = registry.consume("thread-1", prompt_tokens=700)
+
+ assert crossed == []
+ assert budget.percent_used == 70.0
+
+ def test_warning_threshold_crossed_at_75(self):
+ """75% threshold should trigger 'warning' severity."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # 75% usage - exactly at threshold
+ budget, crossed = registry.consume("thread-1", prompt_tokens=750)
+
+ assert len(crossed) == 1
+ assert crossed[0].threshold_percent == 75
+ assert crossed[0].severity == "warning"
+ assert crossed[0].percent_used == 75.0
+ assert crossed[0].tokens_used == 750
+ assert crossed[0].tokens_remaining == 250
+
+ def test_critical_threshold_crossed_at_90(self):
+ """90% threshold should trigger 'critical' severity."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # First consume to 76% (triggers warning)
+ budget, crossed = registry.consume("thread-1", prompt_tokens=760)
+ assert len(crossed) == 1
+ assert crossed[0].severity == "warning"
+
+ # Now consume to 91% (triggers critical)
+ budget, crossed = registry.consume("thread-1", prompt_tokens=150)
+ assert len(crossed) == 1
+ assert crossed[0].threshold_percent == 90
+ assert crossed[0].severity == "critical"
+
+ def test_final_threshold_crossed_at_95(self):
+ """95% threshold should trigger 'final' severity."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # Jump directly to 96%
+ budget, crossed = registry.consume("thread-1", prompt_tokens=960)
+
+ # Should have crossed all three thresholds
+ assert len(crossed) == 3
+ severities = {c.severity for c in crossed}
+ assert severities == {"warning", "critical", "final"}
+
+ def test_thresholds_only_triggered_once(self):
+ """Each threshold should only be triggered once per thread."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # Cross the 75% threshold
+ budget, crossed1 = registry.consume("thread-1", prompt_tokens=760)
+ assert len(crossed1) == 1
+
+ # Stay at 76% with another consume - should not trigger again
+ budget, crossed2 = registry.consume("thread-1", prompt_tokens=0)
+ assert crossed2 == []
+
+ # Consume more to stay between 75-90% - should not trigger
+ budget, crossed3 = registry.consume("thread-1", prompt_tokens=50) # Now at 81%
+ assert crossed3 == []
+
+ def test_multiple_thresholds_in_single_consume(self):
+ """A single large consume can cross multiple thresholds."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # Single consume from 0% to 92% - crosses 75% and 90%
+ budget, crossed = registry.consume("thread-1", prompt_tokens=920)
+
+ assert len(crossed) == 2
+ thresholds = {c.threshold_percent for c in crossed}
+ assert thresholds == {75, 90}
+
+ def test_threshold_data_accuracy(self):
+ """Threshold crossing data should be accurate."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
+
+ # Consume to 75.5%
+ budget, crossed = registry.consume("thread-1", prompt_tokens=7550)
+
+ assert len(crossed) == 1
+ threshold = crossed[0]
+ assert threshold.threshold_percent == 75
+ assert threshold.percent_used == 75.5
+ assert threshold.tokens_used == 7550
+ assert threshold.tokens_remaining == 2450
+ assert threshold.max_tokens == 10000
+
+ def test_percent_used_property(self):
+ """percent_used property should calculate correctly."""
+ budget = ThreadBudget(max_tokens=1000)
+
+ assert budget.percent_used == 0.0
+
+ budget.consume(prompt_tokens=500)
+ assert budget.percent_used == 50.0
+
+ budget.consume(prompt_tokens=250)
+ assert budget.percent_used == 75.0
+
+ def test_percent_used_with_zero_max(self):
+ """percent_used should handle zero max_tokens gracefully."""
+ budget = ThreadBudget(max_tokens=0)
+ assert budget.percent_used == 0.0
+
+ def test_triggered_thresholds_tracked_per_thread(self):
+ """Different threads should track thresholds independently."""
+ registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
+
+ # Thread 1 crosses 75%
+ budget1, crossed1 = registry.consume("thread-1", prompt_tokens=760)
+ assert len(crossed1) == 1
+
+ # Thread 2 crosses 75% independently
+ budget2, crossed2 = registry.consume("thread-2", prompt_tokens=800)
+ assert len(crossed2) == 1
+
+ # Both should have triggered their thresholds
+ assert 75 in budget1.triggered_thresholds
+ assert 75 in budget2.triggered_thresholds
+
+
+# ============================================================================
+# BudgetWarning Injection Tests
+# ============================================================================
+
+class TestBudgetWarningInjection:
+ """Test BudgetWarning messages are injected when thresholds crossed."""
+
+ @pytest.fixture(autouse=True)
+ def reset_all(self):
+ """Reset all global registries."""
+ reset_budget_registry()
+ reset_usage_tracker()
+ yield
+ reset_budget_registry()
+ reset_usage_tracker()
+
+ @pytest.mark.asyncio
+ async def test_warning_injected_on_threshold_crossing(self):
+ """LLM complete should inject BudgetWarning when threshold crossed."""
+ from xml_pipeline.llm.router import LLMRouter
+ from xml_pipeline.llm.backend import LLMResponse
+ from xml_pipeline.primitives.budget_warning import BudgetWarning
+
+ # Configure budget for 1000 tokens
+ configure_budget_registry(max_tokens_per_thread=1000)
+
+ # Create mock backend that returns 760 tokens (crosses 75% threshold)
+ mock_backend = Mock()
+ mock_backend.name = "mock"
+ mock_backend.provider = "test"
+ mock_backend.serves_model = Mock(return_value=True)
+ mock_backend.priority = 1
+ mock_backend.load = 0
+ mock_backend.complete = AsyncMock(return_value=LLMResponse(
+ content="Hello!",
+ model="test-model",
+ usage={"prompt_tokens": 700, "completion_tokens": 60, "total_tokens": 760},
+ finish_reason="stop",
+ ))
+
+ router = LLMRouter()
+ router.backends.append(mock_backend)
+
+ # Mock the pump to capture injected warnings
+ injected_messages = []
+
+ async def mock_inject(raw_bytes, thread_id, from_id):
+ injected_messages.append({
+ "raw_bytes": raw_bytes,
+ "thread_id": thread_id,
+ "from_id": from_id,
+ })
+
+ mock_pump = Mock()
+ mock_pump.inject = AsyncMock(side_effect=mock_inject)
+ mock_pump._wrap_in_envelope = Mock(return_value=b"warning")
+
+ with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
+ await router.complete(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hi"}],
+ thread_id="test-thread",
+ agent_id="greeter",
+ )
+
+ # Should have injected one warning (75% threshold)
+ assert len(injected_messages) == 1
+ assert injected_messages[0]["from_id"] == "system.budget"
+ assert injected_messages[0]["thread_id"] == "test-thread"
+
+ @pytest.mark.asyncio
+ async def test_multiple_warnings_injected(self):
+ """Multiple thresholds crossed should inject multiple warnings."""
+ from xml_pipeline.llm.router import LLMRouter
+ from xml_pipeline.llm.backend import LLMResponse
+
+ configure_budget_registry(max_tokens_per_thread=1000)
+
+ # Backend returns 960 tokens (crosses 75%, 90%, and 95%)
+ mock_backend = Mock()
+ mock_backend.name = "mock"
+ mock_backend.provider = "test"
+ mock_backend.serves_model = Mock(return_value=True)
+ mock_backend.priority = 1
+ mock_backend.load = 0
+ mock_backend.complete = AsyncMock(return_value=LLMResponse(
+ content="Big response",
+ model="test-model",
+ usage={"prompt_tokens": 900, "completion_tokens": 60, "total_tokens": 960},
+ finish_reason="stop",
+ ))
+
+ router = LLMRouter()
+ router.backends.append(mock_backend)
+
+ injected_messages = []
+
+ mock_pump = Mock()
+ mock_pump.inject = AsyncMock(side_effect=lambda *args, **kwargs: injected_messages.append(args))
+ mock_pump._wrap_in_envelope = Mock(return_value=b"warning")
+
+ with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
+ await router.complete(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hi"}],
+ thread_id="test-thread",
+ agent_id="agent",
+ )
+
+ # Should have injected 3 warnings (75%, 90%, 95%)
+ assert len(injected_messages) == 3
+
+ @pytest.mark.asyncio
+ async def test_no_warning_when_below_threshold(self):
+ """No warning should be injected when below all thresholds."""
+ from xml_pipeline.llm.router import LLMRouter
+ from xml_pipeline.llm.backend import LLMResponse
+
+ configure_budget_registry(max_tokens_per_thread=10000)
+
+ # Backend returns 100 tokens (only 1%)
+ mock_backend = Mock()
+ mock_backend.name = "mock"
+ mock_backend.provider = "test"
+ mock_backend.serves_model = Mock(return_value=True)
+ mock_backend.priority = 1
+ mock_backend.load = 0
+ mock_backend.complete = AsyncMock(return_value=LLMResponse(
+ content="Small response",
+ model="test-model",
+ usage={"prompt_tokens": 80, "completion_tokens": 20, "total_tokens": 100},
+ finish_reason="stop",
+ ))
+
+ router = LLMRouter()
+ router.backends.append(mock_backend)
+
+ mock_pump = Mock()
+ mock_pump.inject = AsyncMock()
+
+ with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
+ await router.complete(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hi"}],
+ thread_id="test-thread",
+ )
+
+ # No warnings should be injected
+ mock_pump.inject.assert_not_called()
+
+ @pytest.mark.asyncio
+ async def test_warning_includes_correct_severity(self):
+ """Warning messages should have correct severity based on threshold."""
+ from xml_pipeline.llm.router import LLMRouter
+ from xml_pipeline.llm.backend import LLMResponse
+ from xml_pipeline.primitives.budget_warning import BudgetWarning
+
+ configure_budget_registry(max_tokens_per_thread=1000)
+
+ mock_backend = Mock()
+ mock_backend.name = "mock"
+ mock_backend.provider = "test"
+ mock_backend.serves_model = Mock(return_value=True)
+ mock_backend.priority = 1
+ mock_backend.load = 0
+ mock_backend.complete = AsyncMock(return_value=LLMResponse(
+ content="Response",
+ model="test-model",
+ usage={"prompt_tokens": 910, "completion_tokens": 0, "total_tokens": 910},
+ finish_reason="stop",
+ ))
+
+ router = LLMRouter()
+ router.backends.append(mock_backend)
+
+ captured_payloads = []
+
+ def capture_wrap(payload, from_id, to_id, thread_id):
+ captured_payloads.append({
+ "payload": payload,
+ "from_id": from_id,
+ "to_id": to_id,
+ })
+ return b""
+
+ mock_pump = Mock()
+ mock_pump.inject = AsyncMock()
+ mock_pump._wrap_in_envelope = Mock(side_effect=capture_wrap)
+
+ with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
+ await router.complete(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hi"}],
+ thread_id="test-thread",
+ agent_id="agent",
+ )
+
+ # Should have captured 2 payloads (75% and 90%)
+ assert len(captured_payloads) == 2
+
+ severities = [p["payload"].severity for p in captured_payloads]
+ assert "warning" in severities
+ assert "critical" in severities
+
+ # Check the critical warning has appropriate message
+ critical = next(p for p in captured_payloads if p["payload"].severity == "critical")
+ assert "WARNING" in critical["payload"].message
+ assert "Consider wrapping up" in critical["payload"].message
+
+ @pytest.mark.asyncio
+ async def test_warning_graceful_without_pump(self):
+ """Warning injection should gracefully handle missing pump."""
+ from xml_pipeline.llm.router import LLMRouter
+ from xml_pipeline.llm.backend import LLMResponse
+
+ configure_budget_registry(max_tokens_per_thread=1000)
+
+ mock_backend = Mock()
+ mock_backend.name = "mock"
+ mock_backend.provider = "test"
+ mock_backend.serves_model = Mock(return_value=True)
+ mock_backend.priority = 1
+ mock_backend.load = 0
+ mock_backend.complete = AsyncMock(return_value=LLMResponse(
+ content="Response",
+ model="test-model",
+ usage={"prompt_tokens": 800, "completion_tokens": 0, "total_tokens": 800},
+ finish_reason="stop",
+ ))
+
+ router = LLMRouter()
+ router.backends.append(mock_backend)
+
+ # Pump not initialized - should not raise
+ with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', side_effect=RuntimeError("Not initialized")):
+ response = await router.complete(
+ model="test-model",
+ messages=[{"role": "user", "content": "Hi"}],
+ thread_id="test-thread",
+ )
+
+ # Should still complete successfully
+ assert response.content == "Response"
diff --git a/xml_pipeline/llm/router.py b/xml_pipeline/llm/router.py
index 6c25048..61b60c5 100644
--- a/xml_pipeline/llm/router.py
+++ b/xml_pipeline/llm/router.py
@@ -210,15 +210,24 @@ class LLMRouter:
usage.request_count += 1
# Record to thread budget (enforcement)
+ crossed_thresholds = []
if thread_id:
from xml_pipeline.message_bus.budget_registry import get_budget_registry
budget_registry = get_budget_registry()
- budget_registry.consume(
+ _budget, crossed_thresholds = budget_registry.consume(
thread_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
+ # Inject BudgetWarning messages for any newly crossed thresholds
+ if crossed_thresholds:
+ await self._inject_budget_warnings(
+ thread_id=thread_id,
+ agent_id=agent_id,
+ crossed=crossed_thresholds,
+ )
+
# Emit usage event (for billing)
from xml_pipeline.llm.usage_tracker import get_usage_tracker
tracker = get_usage_tracker()
@@ -265,6 +274,80 @@ class LLMRouter:
jitter = delay * 0.25 * (random.random() * 2 - 1)
return delay + jitter
+ async def _inject_budget_warnings(
+ self,
+ thread_id: str,
+ agent_id: Optional[str],
+ crossed: List, # List[BudgetThresholdCrossed]
+ ) -> None:
+ """
+ Inject BudgetWarning messages when thresholds are crossed.
+
+ Each crossed threshold generates a separate warning message
+ sent to the agent via the stream pump.
+ """
+ from xml_pipeline.primitives.budget_warning import BudgetWarning
+
+ # Get the stream pump (may not be initialized in tests)
+ try:
+ from xml_pipeline.message_bus.stream_pump import get_stream_pump
+ pump = get_stream_pump()
+ except RuntimeError:
+ # Pump not initialized - log and skip
+ logger.warning(
+ f"Cannot inject BudgetWarning: StreamPump not initialized. "
+ f"Thread {thread_id[:8]}... crossed thresholds: {[c.threshold_percent for c in crossed]}"
+ )
+ return
+
+ for threshold in crossed:
+ # Build human-readable message based on severity
+ if threshold.severity == "final":
+ message = (
+ f"CRITICAL: {threshold.percent_used:.0f}% of token budget used. "
+ f"Only {threshold.tokens_remaining:,} tokens remaining. "
+ f"Wrap up immediately or your next request may fail."
+ )
+ elif threshold.severity == "critical":
+ message = (
+ f"WARNING: {threshold.percent_used:.0f}% of token budget used. "
+ f"{threshold.tokens_remaining:,} tokens remaining. "
+ f"Consider wrapping up your current task soon."
+ )
+ else: # warning
+ message = (
+ f"Note: {threshold.percent_used:.0f}% of token budget used. "
+ f"{threshold.tokens_remaining:,} tokens remaining."
+ )
+
+ warning_payload = BudgetWarning(
+ percent_used=threshold.percent_used,
+ tokens_used=threshold.tokens_used,
+ tokens_remaining=threshold.tokens_remaining,
+ max_tokens=threshold.max_tokens,
+ severity=threshold.severity,
+ message=message,
+ )
+
+ # Create envelope for the warning
+ # Target is the agent that made the LLM call
+ target = agent_id if agent_id else "system"
+
+ envelope = pump._wrap_in_envelope(
+ payload=warning_payload,
+ from_id="system.budget",
+ to_id=target,
+ thread_id=thread_id,
+ )
+
+ # Inject into the pump's queue
+ await pump.inject(envelope, thread_id=thread_id, from_id="system.budget")
+
+ logger.info(
+ f"BudgetWarning sent to {target}: {threshold.severity} "
+ f"({threshold.percent_used:.0f}% used, {threshold.tokens_remaining:,} remaining)"
+ )
+
def get_agent_usage(self, agent_id: str) -> AgentUsage:
"""Get usage stats for an agent."""
return self._agent_usage.get(agent_id, AgentUsage())
diff --git a/xml_pipeline/message_bus/budget_registry.py b/xml_pipeline/message_bus/budget_registry.py
index a56d5c5..e30843e 100644
--- a/xml_pipeline/message_bus/budget_registry.py
+++ b/xml_pipeline/message_bus/budget_registry.py
@@ -17,7 +17,26 @@ from __future__ import annotations
import threading
from dataclasses import dataclass, field
-from typing import Dict, Optional
+from typing import Dict, List, Optional, Set, Tuple
+
+
+# Default warning thresholds (percent -> severity)
+DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = {
+ 75: "warning", # 75% - early warning
+ 90: "critical", # 90% - wrap up soon
+ 95: "final", # 95% - last chance
+}
+
+
+@dataclass
+class BudgetThresholdCrossed:
+ """Info about a threshold that was just crossed."""
+ threshold_percent: int
+ severity: str
+ percent_used: float
+ tokens_used: int
+ tokens_remaining: int
+ max_tokens: int
@dataclass
@@ -28,6 +47,7 @@ class ThreadBudget:
prompt_tokens: int = 0
completion_tokens: int = 0
request_count: int = 0
+ triggered_thresholds: Set[int] = field(default_factory=set)
@property
def total_tokens(self) -> int:
@@ -44,6 +64,13 @@ class ThreadBudget:
"""True if budget is exhausted."""
return self.total_tokens >= self.max_tokens
+ @property
+ def percent_used(self) -> float:
+ """Percentage of budget consumed (0-100)."""
+ if self.max_tokens <= 0:
+ return 0.0
+ return (self.total_tokens / self.max_tokens) * 100
+
def can_consume(self, estimated_tokens: int) -> bool:
"""Check if we can consume the given tokens without exceeding budget."""
return self.total_tokens + estimated_tokens <= self.max_tokens
@@ -58,6 +85,42 @@ class ThreadBudget:
self.completion_tokens += completion_tokens
self.request_count += 1
+ def check_thresholds(
+ self,
+ thresholds: Dict[int, str] = None,
+ ) -> List[BudgetThresholdCrossed]:
+ """
+ Check if any thresholds were crossed that haven't been triggered yet.
+
+ Args:
+ thresholds: Dict of percent -> severity. Defaults to DEFAULT_WARNING_THRESHOLDS.
+
+ Returns:
+ List of newly crossed thresholds (sorted by percent)
+ """
+ if thresholds is None:
+ thresholds = DEFAULT_WARNING_THRESHOLDS
+
+ crossed = []
+ current_percent = self.percent_used
+
+ for threshold_percent, severity in sorted(thresholds.items()):
+ if (
+ current_percent >= threshold_percent
+ and threshold_percent not in self.triggered_thresholds
+ ):
+ self.triggered_thresholds.add(threshold_percent)
+ crossed.append(BudgetThresholdCrossed(
+ threshold_percent=threshold_percent,
+ severity=severity,
+ percent_used=round(current_percent, 1),
+ tokens_used=self.total_tokens,
+ tokens_remaining=self.remaining,
+ max_tokens=self.max_tokens,
+ ))
+
+ return crossed
+
class BudgetExhaustedError(Exception):
"""Raised when a thread's token budget is exhausted."""
@@ -176,7 +239,7 @@ class ThreadBudgetRegistry:
thread_id: str,
prompt_tokens: int = 0,
completion_tokens: int = 0,
- ) -> ThreadBudget:
+ ) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]:
"""
Record token consumption for a thread.
@@ -186,12 +249,13 @@ class ThreadBudgetRegistry:
completion_tokens: Completion tokens used
Returns:
- Updated ThreadBudget
+ Tuple of (Updated ThreadBudget, List of newly crossed thresholds)
"""
budget = self.get_budget(thread_id)
with self._lock:
budget.consume(prompt_tokens, completion_tokens)
- return budget
+ crossed = budget.check_thresholds()
+ return budget, crossed
def has_budget(self, thread_id: str) -> bool:
"""Check if a thread has a budget entry (without creating one)."""
diff --git a/xml_pipeline/primitives/__init__.py b/xml_pipeline/primitives/__init__.py
index a130ce0..15dfd56 100644
--- a/xml_pipeline/primitives/__init__.py
+++ b/xml_pipeline/primitives/__init__.py
@@ -29,6 +29,10 @@ from xml_pipeline.primitives.buffer import (
BufferError,
handle_buffer_start,
)
+from xml_pipeline.primitives.budget_warning import (
+ BudgetWarning,
+ DEFAULT_THRESHOLDS,
+)
__all__ = [
# Boot
@@ -56,4 +60,7 @@ __all__ = [
"BufferDispatched",
"BufferError",
"handle_buffer_start",
+ # Budget warnings
+ "BudgetWarning",
+ "DEFAULT_THRESHOLDS",
]
diff --git a/xml_pipeline/primitives/budget_warning.py b/xml_pipeline/primitives/budget_warning.py
new file mode 100644
index 0000000..54e31b8
--- /dev/null
+++ b/xml_pipeline/primitives/budget_warning.py
@@ -0,0 +1,66 @@
+"""
+BudgetWarning — System alerts for token budget thresholds.
+
+When a thread approaches its token budget limit, the system injects
+BudgetWarning messages to give agents a chance to wrap up gracefully.
+
+Thresholds:
+- 75%: Early warning - consider wrapping up
+- 90%: Critical warning - finish current task
+- 95%: Final warning - immediate action required
+
+Agents can design contingencies:
+- Summarize progress and respond early
+- Escalate to a supervisor agent
+- Save state for continuation
+- Request budget increase (if supported)
+
+Example handler pattern:
+ async def my_agent(payload, metadata):
+ if isinstance(payload, BudgetWarning):
+ if payload.severity == "critical":
+ # Wrap up immediately
+ return HandlerResponse.respond(
+ payload=Summary(progress="Reached 90% budget, stopping here...")
+ )
+ # Otherwise note it and continue
+ ...
+"""
+
+# Note: Do NOT use `from __future__ import annotations` here
+# as it breaks the xmlify decorator which needs concrete types
+
+from dataclasses import dataclass
+from third_party.xmlable import xmlify
+
+
+@xmlify
+@dataclass
+class BudgetWarning:
+ """
+ System warning about token budget consumption.
+
+ Sent to agents when their thread crosses budget thresholds.
+
+ Attributes:
+ percent_used: Current percentage of budget consumed (0-100)
+ tokens_used: Total tokens consumed so far
+ tokens_remaining: Tokens remaining before exhaustion
+ max_tokens: Total budget for this thread
+ severity: Warning level (warning, critical, final)
+ message: Human-readable description
+ """
+ percent_used: float
+ tokens_used: int
+ tokens_remaining: int
+ max_tokens: int
+ severity: str # "warning" (75%), "critical" (90%), "final" (95%)
+ message: str
+
+
+# Default thresholds (can be configured)
+DEFAULT_THRESHOLDS = {
+ 75: "warning", # 75% - early warning
+ 90: "critical", # 90% - wrap up soon
+ 95: "final", # 95% - last chance
+}