diff --git a/tests/test_token_budget.py b/tests/test_token_budget.py index 1c93438..94b551c 100644 --- a/tests/test_token_budget.py +++ b/tests/test_token_budget.py @@ -8,12 +8,14 @@ Tests: """ import pytest -from unittest.mock import Mock, AsyncMock, patch +from unittest.mock import Mock, AsyncMock, patch, MagicMock from xml_pipeline.message_bus.budget_registry import ( ThreadBudget, ThreadBudgetRegistry, BudgetExhaustedError, + BudgetThresholdCrossed, + DEFAULT_WARNING_THRESHOLDS, get_budget_registry, configure_budget_registry, reset_budget_registry, @@ -122,7 +124,7 @@ class TestThreadBudgetRegistry: def test_check_budget_exhausted(self): """check_budget should raise when budget exhausted.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) - registry.consume("thread-1", prompt_tokens=1000) + budget, _ = registry.consume("thread-1", prompt_tokens=1000) with pytest.raises(BudgetExhaustedError) as exc_info: registry.check_budget("thread-1", estimated_tokens=100) @@ -135,25 +137,26 @@ class TestThreadBudgetRegistry: def test_check_budget_would_exceed(self): """check_budget should raise when estimate would exceed.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) - registry.consume("thread-1", prompt_tokens=600) + budget, _ = registry.consume("thread-1", prompt_tokens=600) with pytest.raises(BudgetExhaustedError): registry.check_budget("thread-1", estimated_tokens=500) - def test_consume_returns_budget(self): - """consume() should return updated budget.""" + def test_consume_returns_budget_and_thresholds(self): + """consume() should return (budget, crossed_thresholds) tuple.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) - budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50) + budget, crossed = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50) assert budget.total_tokens == 150 assert budget.request_count == 1 + assert crossed == [] # 1.5% - no threshold crossed def test_get_usage(self): """get_usage should return dict with all stats.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) - registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) - registry.consume("thread-1", prompt_tokens=300, completion_tokens=100) + budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) + budget, _ = registry.consume("thread-1", prompt_tokens=300, completion_tokens=100) usage = registry.get_usage("thread-1") @@ -166,8 +169,8 @@ class TestThreadBudgetRegistry: def test_get_all_usage(self): """get_all_usage should return all threads.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) - registry.consume("thread-1", prompt_tokens=100) - registry.consume("thread-2", prompt_tokens=200) + budget, _ = registry.consume("thread-1", prompt_tokens=100) + budget, _ = registry.consume("thread-2", prompt_tokens=200) all_usage = registry.get_all_usage() @@ -178,7 +181,7 @@ class TestThreadBudgetRegistry: def test_reset_thread(self): """reset_thread should remove budget for thread.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) - registry.consume("thread-1", prompt_tokens=500) + budget, _ = registry.consume("thread-1", prompt_tokens=500) registry.reset_thread("thread-1") # Getting budget should create new one with zero usage @@ -188,7 +191,7 @@ class TestThreadBudgetRegistry: def test_cleanup_thread(self): """cleanup_thread should return and remove budget.""" registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) - registry.consume("thread-1", prompt_tokens=500) + budget, _ = registry.consume("thread-1", prompt_tokens=500) final_budget = registry.cleanup_thread("thread-1") @@ -520,7 +523,7 @@ class TestLLMRouterBudgetIntegration: # Configure small budget and exhaust it configure_budget_registry(max_tokens_per_thread=100) budget_registry = get_budget_registry() - budget_registry.consume("test-thread", prompt_tokens=100) + budget, _ = budget_registry.consume("test-thread", prompt_tokens=100) mock_backend = Mock() mock_backend.name = "mock" @@ -590,7 +593,7 @@ class TestBudgetCleanup: def test_cleanup_thread_returns_budget(self): """cleanup_thread should return the budget before removing it.""" registry = ThreadBudgetRegistry() - registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) + budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) final = registry.cleanup_thread("thread-1") @@ -602,7 +605,7 @@ class TestBudgetCleanup: def test_cleanup_thread_removes_budget(self): """cleanup_thread should remove the budget from registry.""" registry = ThreadBudgetRegistry() - registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) + budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200) registry.cleanup_thread("thread-1") @@ -624,7 +627,7 @@ class TestBudgetCleanup: registry = get_budget_registry() # Consume some tokens - registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500) + budget, _ = registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500) assert registry.has_budget("test-thread") # Cleanup @@ -632,3 +635,389 @@ class TestBudgetCleanup: assert final.total_tokens == 1500 assert not registry.has_budget("test-thread") + + +# ============================================================================ +# Budget Threshold Tests +# ============================================================================ + +class TestBudgetThresholds: + """Test budget threshold crossing detection.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + """Reset all global registries.""" + reset_budget_registry() + yield + reset_budget_registry() + + def test_no_threshold_crossed_below_first(self): + """No thresholds crossed when under 75%.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # 70% usage - below first threshold + budget, crossed = registry.consume("thread-1", prompt_tokens=700) + + assert crossed == [] + assert budget.percent_used == 70.0 + + def test_warning_threshold_crossed_at_75(self): + """75% threshold should trigger 'warning' severity.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # 75% usage - exactly at threshold + budget, crossed = registry.consume("thread-1", prompt_tokens=750) + + assert len(crossed) == 1 + assert crossed[0].threshold_percent == 75 + assert crossed[0].severity == "warning" + assert crossed[0].percent_used == 75.0 + assert crossed[0].tokens_used == 750 + assert crossed[0].tokens_remaining == 250 + + def test_critical_threshold_crossed_at_90(self): + """90% threshold should trigger 'critical' severity.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # First consume to 76% (triggers warning) + budget, crossed = registry.consume("thread-1", prompt_tokens=760) + assert len(crossed) == 1 + assert crossed[0].severity == "warning" + + # Now consume to 91% (triggers critical) + budget, crossed = registry.consume("thread-1", prompt_tokens=150) + assert len(crossed) == 1 + assert crossed[0].threshold_percent == 90 + assert crossed[0].severity == "critical" + + def test_final_threshold_crossed_at_95(self): + """95% threshold should trigger 'final' severity.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # Jump directly to 96% + budget, crossed = registry.consume("thread-1", prompt_tokens=960) + + # Should have crossed all three thresholds + assert len(crossed) == 3 + severities = {c.severity for c in crossed} + assert severities == {"warning", "critical", "final"} + + def test_thresholds_only_triggered_once(self): + """Each threshold should only be triggered once per thread.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # Cross the 75% threshold + budget, crossed1 = registry.consume("thread-1", prompt_tokens=760) + assert len(crossed1) == 1 + + # Stay at 76% with another consume - should not trigger again + budget, crossed2 = registry.consume("thread-1", prompt_tokens=0) + assert crossed2 == [] + + # Consume more to stay between 75-90% - should not trigger + budget, crossed3 = registry.consume("thread-1", prompt_tokens=50) # Now at 81% + assert crossed3 == [] + + def test_multiple_thresholds_in_single_consume(self): + """A single large consume can cross multiple thresholds.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # Single consume from 0% to 92% - crosses 75% and 90% + budget, crossed = registry.consume("thread-1", prompt_tokens=920) + + assert len(crossed) == 2 + thresholds = {c.threshold_percent for c in crossed} + assert thresholds == {75, 90} + + def test_threshold_data_accuracy(self): + """Threshold crossing data should be accurate.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=10000) + + # Consume to 75.5% + budget, crossed = registry.consume("thread-1", prompt_tokens=7550) + + assert len(crossed) == 1 + threshold = crossed[0] + assert threshold.threshold_percent == 75 + assert threshold.percent_used == 75.5 + assert threshold.tokens_used == 7550 + assert threshold.tokens_remaining == 2450 + assert threshold.max_tokens == 10000 + + def test_percent_used_property(self): + """percent_used property should calculate correctly.""" + budget = ThreadBudget(max_tokens=1000) + + assert budget.percent_used == 0.0 + + budget.consume(prompt_tokens=500) + assert budget.percent_used == 50.0 + + budget.consume(prompt_tokens=250) + assert budget.percent_used == 75.0 + + def test_percent_used_with_zero_max(self): + """percent_used should handle zero max_tokens gracefully.""" + budget = ThreadBudget(max_tokens=0) + assert budget.percent_used == 0.0 + + def test_triggered_thresholds_tracked_per_thread(self): + """Different threads should track thresholds independently.""" + registry = ThreadBudgetRegistry(max_tokens_per_thread=1000) + + # Thread 1 crosses 75% + budget1, crossed1 = registry.consume("thread-1", prompt_tokens=760) + assert len(crossed1) == 1 + + # Thread 2 crosses 75% independently + budget2, crossed2 = registry.consume("thread-2", prompt_tokens=800) + assert len(crossed2) == 1 + + # Both should have triggered their thresholds + assert 75 in budget1.triggered_thresholds + assert 75 in budget2.triggered_thresholds + + +# ============================================================================ +# BudgetWarning Injection Tests +# ============================================================================ + +class TestBudgetWarningInjection: + """Test BudgetWarning messages are injected when thresholds crossed.""" + + @pytest.fixture(autouse=True) + def reset_all(self): + """Reset all global registries.""" + reset_budget_registry() + reset_usage_tracker() + yield + reset_budget_registry() + reset_usage_tracker() + + @pytest.mark.asyncio + async def test_warning_injected_on_threshold_crossing(self): + """LLM complete should inject BudgetWarning when threshold crossed.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + from xml_pipeline.primitives.budget_warning import BudgetWarning + + # Configure budget for 1000 tokens + configure_budget_registry(max_tokens_per_thread=1000) + + # Create mock backend that returns 760 tokens (crosses 75% threshold) + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Hello!", + model="test-model", + usage={"prompt_tokens": 700, "completion_tokens": 60, "total_tokens": 760}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + # Mock the pump to capture injected warnings + injected_messages = [] + + async def mock_inject(raw_bytes, thread_id, from_id): + injected_messages.append({ + "raw_bytes": raw_bytes, + "thread_id": thread_id, + "from_id": from_id, + }) + + mock_pump = Mock() + mock_pump.inject = AsyncMock(side_effect=mock_inject) + mock_pump._wrap_in_envelope = Mock(return_value=b"warning") + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + agent_id="greeter", + ) + + # Should have injected one warning (75% threshold) + assert len(injected_messages) == 1 + assert injected_messages[0]["from_id"] == "system.budget" + assert injected_messages[0]["thread_id"] == "test-thread" + + @pytest.mark.asyncio + async def test_multiple_warnings_injected(self): + """Multiple thresholds crossed should inject multiple warnings.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + configure_budget_registry(max_tokens_per_thread=1000) + + # Backend returns 960 tokens (crosses 75%, 90%, and 95%) + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Big response", + model="test-model", + usage={"prompt_tokens": 900, "completion_tokens": 60, "total_tokens": 960}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + injected_messages = [] + + mock_pump = Mock() + mock_pump.inject = AsyncMock(side_effect=lambda *args, **kwargs: injected_messages.append(args)) + mock_pump._wrap_in_envelope = Mock(return_value=b"warning") + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + agent_id="agent", + ) + + # Should have injected 3 warnings (75%, 90%, 95%) + assert len(injected_messages) == 3 + + @pytest.mark.asyncio + async def test_no_warning_when_below_threshold(self): + """No warning should be injected when below all thresholds.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + configure_budget_registry(max_tokens_per_thread=10000) + + # Backend returns 100 tokens (only 1%) + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Small response", + model="test-model", + usage={"prompt_tokens": 80, "completion_tokens": 20, "total_tokens": 100}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + mock_pump = Mock() + mock_pump.inject = AsyncMock() + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + ) + + # No warnings should be injected + mock_pump.inject.assert_not_called() + + @pytest.mark.asyncio + async def test_warning_includes_correct_severity(self): + """Warning messages should have correct severity based on threshold.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + from xml_pipeline.primitives.budget_warning import BudgetWarning + + configure_budget_registry(max_tokens_per_thread=1000) + + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Response", + model="test-model", + usage={"prompt_tokens": 910, "completion_tokens": 0, "total_tokens": 910}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + captured_payloads = [] + + def capture_wrap(payload, from_id, to_id, thread_id): + captured_payloads.append({ + "payload": payload, + "from_id": from_id, + "to_id": to_id, + }) + return b"" + + mock_pump = Mock() + mock_pump.inject = AsyncMock() + mock_pump._wrap_in_envelope = Mock(side_effect=capture_wrap) + + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump): + await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + agent_id="agent", + ) + + # Should have captured 2 payloads (75% and 90%) + assert len(captured_payloads) == 2 + + severities = [p["payload"].severity for p in captured_payloads] + assert "warning" in severities + assert "critical" in severities + + # Check the critical warning has appropriate message + critical = next(p for p in captured_payloads if p["payload"].severity == "critical") + assert "WARNING" in critical["payload"].message + assert "Consider wrapping up" in critical["payload"].message + + @pytest.mark.asyncio + async def test_warning_graceful_without_pump(self): + """Warning injection should gracefully handle missing pump.""" + from xml_pipeline.llm.router import LLMRouter + from xml_pipeline.llm.backend import LLMResponse + + configure_budget_registry(max_tokens_per_thread=1000) + + mock_backend = Mock() + mock_backend.name = "mock" + mock_backend.provider = "test" + mock_backend.serves_model = Mock(return_value=True) + mock_backend.priority = 1 + mock_backend.load = 0 + mock_backend.complete = AsyncMock(return_value=LLMResponse( + content="Response", + model="test-model", + usage={"prompt_tokens": 800, "completion_tokens": 0, "total_tokens": 800}, + finish_reason="stop", + )) + + router = LLMRouter() + router.backends.append(mock_backend) + + # Pump not initialized - should not raise + with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', side_effect=RuntimeError("Not initialized")): + response = await router.complete( + model="test-model", + messages=[{"role": "user", "content": "Hi"}], + thread_id="test-thread", + ) + + # Should still complete successfully + assert response.content == "Response" diff --git a/xml_pipeline/llm/router.py b/xml_pipeline/llm/router.py index 6c25048..61b60c5 100644 --- a/xml_pipeline/llm/router.py +++ b/xml_pipeline/llm/router.py @@ -210,15 +210,24 @@ class LLMRouter: usage.request_count += 1 # Record to thread budget (enforcement) + crossed_thresholds = [] if thread_id: from xml_pipeline.message_bus.budget_registry import get_budget_registry budget_registry = get_budget_registry() - budget_registry.consume( + _budget, crossed_thresholds = budget_registry.consume( thread_id, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) + # Inject BudgetWarning messages for any newly crossed thresholds + if crossed_thresholds: + await self._inject_budget_warnings( + thread_id=thread_id, + agent_id=agent_id, + crossed=crossed_thresholds, + ) + # Emit usage event (for billing) from xml_pipeline.llm.usage_tracker import get_usage_tracker tracker = get_usage_tracker() @@ -265,6 +274,80 @@ class LLMRouter: jitter = delay * 0.25 * (random.random() * 2 - 1) return delay + jitter + async def _inject_budget_warnings( + self, + thread_id: str, + agent_id: Optional[str], + crossed: List, # List[BudgetThresholdCrossed] + ) -> None: + """ + Inject BudgetWarning messages when thresholds are crossed. + + Each crossed threshold generates a separate warning message + sent to the agent via the stream pump. + """ + from xml_pipeline.primitives.budget_warning import BudgetWarning + + # Get the stream pump (may not be initialized in tests) + try: + from xml_pipeline.message_bus.stream_pump import get_stream_pump + pump = get_stream_pump() + except RuntimeError: + # Pump not initialized - log and skip + logger.warning( + f"Cannot inject BudgetWarning: StreamPump not initialized. " + f"Thread {thread_id[:8]}... crossed thresholds: {[c.threshold_percent for c in crossed]}" + ) + return + + for threshold in crossed: + # Build human-readable message based on severity + if threshold.severity == "final": + message = ( + f"CRITICAL: {threshold.percent_used:.0f}% of token budget used. " + f"Only {threshold.tokens_remaining:,} tokens remaining. " + f"Wrap up immediately or your next request may fail." + ) + elif threshold.severity == "critical": + message = ( + f"WARNING: {threshold.percent_used:.0f}% of token budget used. " + f"{threshold.tokens_remaining:,} tokens remaining. " + f"Consider wrapping up your current task soon." + ) + else: # warning + message = ( + f"Note: {threshold.percent_used:.0f}% of token budget used. " + f"{threshold.tokens_remaining:,} tokens remaining." + ) + + warning_payload = BudgetWarning( + percent_used=threshold.percent_used, + tokens_used=threshold.tokens_used, + tokens_remaining=threshold.tokens_remaining, + max_tokens=threshold.max_tokens, + severity=threshold.severity, + message=message, + ) + + # Create envelope for the warning + # Target is the agent that made the LLM call + target = agent_id if agent_id else "system" + + envelope = pump._wrap_in_envelope( + payload=warning_payload, + from_id="system.budget", + to_id=target, + thread_id=thread_id, + ) + + # Inject into the pump's queue + await pump.inject(envelope, thread_id=thread_id, from_id="system.budget") + + logger.info( + f"BudgetWarning sent to {target}: {threshold.severity} " + f"({threshold.percent_used:.0f}% used, {threshold.tokens_remaining:,} remaining)" + ) + def get_agent_usage(self, agent_id: str) -> AgentUsage: """Get usage stats for an agent.""" return self._agent_usage.get(agent_id, AgentUsage()) diff --git a/xml_pipeline/message_bus/budget_registry.py b/xml_pipeline/message_bus/budget_registry.py index a56d5c5..e30843e 100644 --- a/xml_pipeline/message_bus/budget_registry.py +++ b/xml_pipeline/message_bus/budget_registry.py @@ -17,7 +17,26 @@ from __future__ import annotations import threading from dataclasses import dataclass, field -from typing import Dict, Optional +from typing import Dict, List, Optional, Set, Tuple + + +# Default warning thresholds (percent -> severity) +DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = { + 75: "warning", # 75% - early warning + 90: "critical", # 90% - wrap up soon + 95: "final", # 95% - last chance +} + + +@dataclass +class BudgetThresholdCrossed: + """Info about a threshold that was just crossed.""" + threshold_percent: int + severity: str + percent_used: float + tokens_used: int + tokens_remaining: int + max_tokens: int @dataclass @@ -28,6 +47,7 @@ class ThreadBudget: prompt_tokens: int = 0 completion_tokens: int = 0 request_count: int = 0 + triggered_thresholds: Set[int] = field(default_factory=set) @property def total_tokens(self) -> int: @@ -44,6 +64,13 @@ class ThreadBudget: """True if budget is exhausted.""" return self.total_tokens >= self.max_tokens + @property + def percent_used(self) -> float: + """Percentage of budget consumed (0-100).""" + if self.max_tokens <= 0: + return 0.0 + return (self.total_tokens / self.max_tokens) * 100 + def can_consume(self, estimated_tokens: int) -> bool: """Check if we can consume the given tokens without exceeding budget.""" return self.total_tokens + estimated_tokens <= self.max_tokens @@ -58,6 +85,42 @@ class ThreadBudget: self.completion_tokens += completion_tokens self.request_count += 1 + def check_thresholds( + self, + thresholds: Dict[int, str] = None, + ) -> List[BudgetThresholdCrossed]: + """ + Check if any thresholds were crossed that haven't been triggered yet. + + Args: + thresholds: Dict of percent -> severity. Defaults to DEFAULT_WARNING_THRESHOLDS. + + Returns: + List of newly crossed thresholds (sorted by percent) + """ + if thresholds is None: + thresholds = DEFAULT_WARNING_THRESHOLDS + + crossed = [] + current_percent = self.percent_used + + for threshold_percent, severity in sorted(thresholds.items()): + if ( + current_percent >= threshold_percent + and threshold_percent not in self.triggered_thresholds + ): + self.triggered_thresholds.add(threshold_percent) + crossed.append(BudgetThresholdCrossed( + threshold_percent=threshold_percent, + severity=severity, + percent_used=round(current_percent, 1), + tokens_used=self.total_tokens, + tokens_remaining=self.remaining, + max_tokens=self.max_tokens, + )) + + return crossed + class BudgetExhaustedError(Exception): """Raised when a thread's token budget is exhausted.""" @@ -176,7 +239,7 @@ class ThreadBudgetRegistry: thread_id: str, prompt_tokens: int = 0, completion_tokens: int = 0, - ) -> ThreadBudget: + ) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]: """ Record token consumption for a thread. @@ -186,12 +249,13 @@ class ThreadBudgetRegistry: completion_tokens: Completion tokens used Returns: - Updated ThreadBudget + Tuple of (Updated ThreadBudget, List of newly crossed thresholds) """ budget = self.get_budget(thread_id) with self._lock: budget.consume(prompt_tokens, completion_tokens) - return budget + crossed = budget.check_thresholds() + return budget, crossed def has_budget(self, thread_id: str) -> bool: """Check if a thread has a budget entry (without creating one).""" diff --git a/xml_pipeline/primitives/__init__.py b/xml_pipeline/primitives/__init__.py index a130ce0..15dfd56 100644 --- a/xml_pipeline/primitives/__init__.py +++ b/xml_pipeline/primitives/__init__.py @@ -29,6 +29,10 @@ from xml_pipeline.primitives.buffer import ( BufferError, handle_buffer_start, ) +from xml_pipeline.primitives.budget_warning import ( + BudgetWarning, + DEFAULT_THRESHOLDS, +) __all__ = [ # Boot @@ -56,4 +60,7 @@ __all__ = [ "BufferDispatched", "BufferError", "handle_buffer_start", + # Budget warnings + "BudgetWarning", + "DEFAULT_THRESHOLDS", ] diff --git a/xml_pipeline/primitives/budget_warning.py b/xml_pipeline/primitives/budget_warning.py new file mode 100644 index 0000000..54e31b8 --- /dev/null +++ b/xml_pipeline/primitives/budget_warning.py @@ -0,0 +1,66 @@ +""" +BudgetWarning — System alerts for token budget thresholds. + +When a thread approaches its token budget limit, the system injects +BudgetWarning messages to give agents a chance to wrap up gracefully. + +Thresholds: +- 75%: Early warning - consider wrapping up +- 90%: Critical warning - finish current task +- 95%: Final warning - immediate action required + +Agents can design contingencies: +- Summarize progress and respond early +- Escalate to a supervisor agent +- Save state for continuation +- Request budget increase (if supported) + +Example handler pattern: + async def my_agent(payload, metadata): + if isinstance(payload, BudgetWarning): + if payload.severity == "critical": + # Wrap up immediately + return HandlerResponse.respond( + payload=Summary(progress="Reached 90% budget, stopping here...") + ) + # Otherwise note it and continue + ... +""" + +# Note: Do NOT use `from __future__ import annotations` here +# as it breaks the xmlify decorator which needs concrete types + +from dataclasses import dataclass +from third_party.xmlable import xmlify + + +@xmlify +@dataclass +class BudgetWarning: + """ + System warning about token budget consumption. + + Sent to agents when their thread crosses budget thresholds. + + Attributes: + percent_used: Current percentage of budget consumed (0-100) + tokens_used: Total tokens consumed so far + tokens_remaining: Tokens remaining before exhaustion + max_tokens: Total budget for this thread + severity: Warning level (warning, critical, final) + message: Human-readable description + """ + percent_used: float + tokens_used: int + tokens_remaining: int + max_tokens: int + severity: str # "warning" (75%), "critical" (90%), "final" (95%) + message: str + + +# Default thresholds (can be configured) +DEFAULT_THRESHOLDS = { + 75: "warning", # 75% - early warning + 90: "critical", # 90% - wrap up soon + 95: "final", # 95% - last chance +}