Add BudgetWarning system alerts for token budget thresholds
- Create BudgetWarning primitive payload (75%, 90%, 95% thresholds) - Add threshold tracking to ThreadBudget with triggered_thresholds set - Change consume() to return (budget, crossed_thresholds) tuple - Wire warning injection in LLM router when thresholds crossed - Add 15 new tests for threshold detection and warning injection Agents now receive BudgetWarning messages when approaching their token limit, allowing them to design contingencies (summarize, escalate, save state). Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
f98a21f96b
commit
e6697f0ea2
5 changed files with 630 additions and 21 deletions
|
|
@ -8,12 +8,14 @@ Tests:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
from unittest.mock import Mock, AsyncMock, patch, MagicMock
|
||||||
|
|
||||||
from xml_pipeline.message_bus.budget_registry import (
|
from xml_pipeline.message_bus.budget_registry import (
|
||||||
ThreadBudget,
|
ThreadBudget,
|
||||||
ThreadBudgetRegistry,
|
ThreadBudgetRegistry,
|
||||||
BudgetExhaustedError,
|
BudgetExhaustedError,
|
||||||
|
BudgetThresholdCrossed,
|
||||||
|
DEFAULT_WARNING_THRESHOLDS,
|
||||||
get_budget_registry,
|
get_budget_registry,
|
||||||
configure_budget_registry,
|
configure_budget_registry,
|
||||||
reset_budget_registry,
|
reset_budget_registry,
|
||||||
|
|
@ -122,7 +124,7 @@ class TestThreadBudgetRegistry:
|
||||||
def test_check_budget_exhausted(self):
|
def test_check_budget_exhausted(self):
|
||||||
"""check_budget should raise when budget exhausted."""
|
"""check_budget should raise when budget exhausted."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
registry.consume("thread-1", prompt_tokens=1000)
|
budget, _ = registry.consume("thread-1", prompt_tokens=1000)
|
||||||
|
|
||||||
with pytest.raises(BudgetExhaustedError) as exc_info:
|
with pytest.raises(BudgetExhaustedError) as exc_info:
|
||||||
registry.check_budget("thread-1", estimated_tokens=100)
|
registry.check_budget("thread-1", estimated_tokens=100)
|
||||||
|
|
@ -135,25 +137,26 @@ class TestThreadBudgetRegistry:
|
||||||
def test_check_budget_would_exceed(self):
|
def test_check_budget_would_exceed(self):
|
||||||
"""check_budget should raise when estimate would exceed."""
|
"""check_budget should raise when estimate would exceed."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
registry.consume("thread-1", prompt_tokens=600)
|
budget, _ = registry.consume("thread-1", prompt_tokens=600)
|
||||||
|
|
||||||
with pytest.raises(BudgetExhaustedError):
|
with pytest.raises(BudgetExhaustedError):
|
||||||
registry.check_budget("thread-1", estimated_tokens=500)
|
registry.check_budget("thread-1", estimated_tokens=500)
|
||||||
|
|
||||||
def test_consume_returns_budget(self):
|
def test_consume_returns_budget_and_thresholds(self):
|
||||||
"""consume() should return updated budget."""
|
"""consume() should return (budget, crossed_thresholds) tuple."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
|
|
||||||
budget = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
|
budget, crossed = registry.consume("thread-1", prompt_tokens=100, completion_tokens=50)
|
||||||
|
|
||||||
assert budget.total_tokens == 150
|
assert budget.total_tokens == 150
|
||||||
assert budget.request_count == 1
|
assert budget.request_count == 1
|
||||||
|
assert crossed == [] # 1.5% - no threshold crossed
|
||||||
|
|
||||||
def test_get_usage(self):
|
def test_get_usage(self):
|
||||||
"""get_usage should return dict with all stats."""
|
"""get_usage should return dict with all stats."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
||||||
registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
|
budget, _ = registry.consume("thread-1", prompt_tokens=300, completion_tokens=100)
|
||||||
|
|
||||||
usage = registry.get_usage("thread-1")
|
usage = registry.get_usage("thread-1")
|
||||||
|
|
||||||
|
|
@ -166,8 +169,8 @@ class TestThreadBudgetRegistry:
|
||||||
def test_get_all_usage(self):
|
def test_get_all_usage(self):
|
||||||
"""get_all_usage should return all threads."""
|
"""get_all_usage should return all threads."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
registry.consume("thread-1", prompt_tokens=100)
|
budget, _ = registry.consume("thread-1", prompt_tokens=100)
|
||||||
registry.consume("thread-2", prompt_tokens=200)
|
budget, _ = registry.consume("thread-2", prompt_tokens=200)
|
||||||
|
|
||||||
all_usage = registry.get_all_usage()
|
all_usage = registry.get_all_usage()
|
||||||
|
|
||||||
|
|
@ -178,7 +181,7 @@ class TestThreadBudgetRegistry:
|
||||||
def test_reset_thread(self):
|
def test_reset_thread(self):
|
||||||
"""reset_thread should remove budget for thread."""
|
"""reset_thread should remove budget for thread."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
registry.consume("thread-1", prompt_tokens=500)
|
budget, _ = registry.consume("thread-1", prompt_tokens=500)
|
||||||
registry.reset_thread("thread-1")
|
registry.reset_thread("thread-1")
|
||||||
|
|
||||||
# Getting budget should create new one with zero usage
|
# Getting budget should create new one with zero usage
|
||||||
|
|
@ -188,7 +191,7 @@ class TestThreadBudgetRegistry:
|
||||||
def test_cleanup_thread(self):
|
def test_cleanup_thread(self):
|
||||||
"""cleanup_thread should return and remove budget."""
|
"""cleanup_thread should return and remove budget."""
|
||||||
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
registry.consume("thread-1", prompt_tokens=500)
|
budget, _ = registry.consume("thread-1", prompt_tokens=500)
|
||||||
|
|
||||||
final_budget = registry.cleanup_thread("thread-1")
|
final_budget = registry.cleanup_thread("thread-1")
|
||||||
|
|
||||||
|
|
@ -520,7 +523,7 @@ class TestLLMRouterBudgetIntegration:
|
||||||
# Configure small budget and exhaust it
|
# Configure small budget and exhaust it
|
||||||
configure_budget_registry(max_tokens_per_thread=100)
|
configure_budget_registry(max_tokens_per_thread=100)
|
||||||
budget_registry = get_budget_registry()
|
budget_registry = get_budget_registry()
|
||||||
budget_registry.consume("test-thread", prompt_tokens=100)
|
budget, _ = budget_registry.consume("test-thread", prompt_tokens=100)
|
||||||
|
|
||||||
mock_backend = Mock()
|
mock_backend = Mock()
|
||||||
mock_backend.name = "mock"
|
mock_backend.name = "mock"
|
||||||
|
|
@ -590,7 +593,7 @@ class TestBudgetCleanup:
|
||||||
def test_cleanup_thread_returns_budget(self):
|
def test_cleanup_thread_returns_budget(self):
|
||||||
"""cleanup_thread should return the budget before removing it."""
|
"""cleanup_thread should return the budget before removing it."""
|
||||||
registry = ThreadBudgetRegistry()
|
registry = ThreadBudgetRegistry()
|
||||||
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
||||||
|
|
||||||
final = registry.cleanup_thread("thread-1")
|
final = registry.cleanup_thread("thread-1")
|
||||||
|
|
||||||
|
|
@ -602,7 +605,7 @@ class TestBudgetCleanup:
|
||||||
def test_cleanup_thread_removes_budget(self):
|
def test_cleanup_thread_removes_budget(self):
|
||||||
"""cleanup_thread should remove the budget from registry."""
|
"""cleanup_thread should remove the budget from registry."""
|
||||||
registry = ThreadBudgetRegistry()
|
registry = ThreadBudgetRegistry()
|
||||||
registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
budget, _ = registry.consume("thread-1", prompt_tokens=500, completion_tokens=200)
|
||||||
|
|
||||||
registry.cleanup_thread("thread-1")
|
registry.cleanup_thread("thread-1")
|
||||||
|
|
||||||
|
|
@ -624,7 +627,7 @@ class TestBudgetCleanup:
|
||||||
registry = get_budget_registry()
|
registry = get_budget_registry()
|
||||||
|
|
||||||
# Consume some tokens
|
# Consume some tokens
|
||||||
registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500)
|
budget, _ = registry.consume("test-thread", prompt_tokens=1000, completion_tokens=500)
|
||||||
assert registry.has_budget("test-thread")
|
assert registry.has_budget("test-thread")
|
||||||
|
|
||||||
# Cleanup
|
# Cleanup
|
||||||
|
|
@ -632,3 +635,389 @@ class TestBudgetCleanup:
|
||||||
|
|
||||||
assert final.total_tokens == 1500
|
assert final.total_tokens == 1500
|
||||||
assert not registry.has_budget("test-thread")
|
assert not registry.has_budget("test-thread")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Budget Threshold Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestBudgetThresholds:
|
||||||
|
"""Test budget threshold crossing detection."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_all(self):
|
||||||
|
"""Reset all global registries."""
|
||||||
|
reset_budget_registry()
|
||||||
|
yield
|
||||||
|
reset_budget_registry()
|
||||||
|
|
||||||
|
def test_no_threshold_crossed_below_first(self):
|
||||||
|
"""No thresholds crossed when under 75%."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# 70% usage - below first threshold
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=700)
|
||||||
|
|
||||||
|
assert crossed == []
|
||||||
|
assert budget.percent_used == 70.0
|
||||||
|
|
||||||
|
def test_warning_threshold_crossed_at_75(self):
|
||||||
|
"""75% threshold should trigger 'warning' severity."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# 75% usage - exactly at threshold
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=750)
|
||||||
|
|
||||||
|
assert len(crossed) == 1
|
||||||
|
assert crossed[0].threshold_percent == 75
|
||||||
|
assert crossed[0].severity == "warning"
|
||||||
|
assert crossed[0].percent_used == 75.0
|
||||||
|
assert crossed[0].tokens_used == 750
|
||||||
|
assert crossed[0].tokens_remaining == 250
|
||||||
|
|
||||||
|
def test_critical_threshold_crossed_at_90(self):
|
||||||
|
"""90% threshold should trigger 'critical' severity."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# First consume to 76% (triggers warning)
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=760)
|
||||||
|
assert len(crossed) == 1
|
||||||
|
assert crossed[0].severity == "warning"
|
||||||
|
|
||||||
|
# Now consume to 91% (triggers critical)
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=150)
|
||||||
|
assert len(crossed) == 1
|
||||||
|
assert crossed[0].threshold_percent == 90
|
||||||
|
assert crossed[0].severity == "critical"
|
||||||
|
|
||||||
|
def test_final_threshold_crossed_at_95(self):
|
||||||
|
"""95% threshold should trigger 'final' severity."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Jump directly to 96%
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=960)
|
||||||
|
|
||||||
|
# Should have crossed all three thresholds
|
||||||
|
assert len(crossed) == 3
|
||||||
|
severities = {c.severity for c in crossed}
|
||||||
|
assert severities == {"warning", "critical", "final"}
|
||||||
|
|
||||||
|
def test_thresholds_only_triggered_once(self):
|
||||||
|
"""Each threshold should only be triggered once per thread."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Cross the 75% threshold
|
||||||
|
budget, crossed1 = registry.consume("thread-1", prompt_tokens=760)
|
||||||
|
assert len(crossed1) == 1
|
||||||
|
|
||||||
|
# Stay at 76% with another consume - should not trigger again
|
||||||
|
budget, crossed2 = registry.consume("thread-1", prompt_tokens=0)
|
||||||
|
assert crossed2 == []
|
||||||
|
|
||||||
|
# Consume more to stay between 75-90% - should not trigger
|
||||||
|
budget, crossed3 = registry.consume("thread-1", prompt_tokens=50) # Now at 81%
|
||||||
|
assert crossed3 == []
|
||||||
|
|
||||||
|
def test_multiple_thresholds_in_single_consume(self):
|
||||||
|
"""A single large consume can cross multiple thresholds."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Single consume from 0% to 92% - crosses 75% and 90%
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=920)
|
||||||
|
|
||||||
|
assert len(crossed) == 2
|
||||||
|
thresholds = {c.threshold_percent for c in crossed}
|
||||||
|
assert thresholds == {75, 90}
|
||||||
|
|
||||||
|
def test_threshold_data_accuracy(self):
|
||||||
|
"""Threshold crossing data should be accurate."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=10000)
|
||||||
|
|
||||||
|
# Consume to 75.5%
|
||||||
|
budget, crossed = registry.consume("thread-1", prompt_tokens=7550)
|
||||||
|
|
||||||
|
assert len(crossed) == 1
|
||||||
|
threshold = crossed[0]
|
||||||
|
assert threshold.threshold_percent == 75
|
||||||
|
assert threshold.percent_used == 75.5
|
||||||
|
assert threshold.tokens_used == 7550
|
||||||
|
assert threshold.tokens_remaining == 2450
|
||||||
|
assert threshold.max_tokens == 10000
|
||||||
|
|
||||||
|
def test_percent_used_property(self):
|
||||||
|
"""percent_used property should calculate correctly."""
|
||||||
|
budget = ThreadBudget(max_tokens=1000)
|
||||||
|
|
||||||
|
assert budget.percent_used == 0.0
|
||||||
|
|
||||||
|
budget.consume(prompt_tokens=500)
|
||||||
|
assert budget.percent_used == 50.0
|
||||||
|
|
||||||
|
budget.consume(prompt_tokens=250)
|
||||||
|
assert budget.percent_used == 75.0
|
||||||
|
|
||||||
|
def test_percent_used_with_zero_max(self):
|
||||||
|
"""percent_used should handle zero max_tokens gracefully."""
|
||||||
|
budget = ThreadBudget(max_tokens=0)
|
||||||
|
assert budget.percent_used == 0.0
|
||||||
|
|
||||||
|
def test_triggered_thresholds_tracked_per_thread(self):
|
||||||
|
"""Different threads should track thresholds independently."""
|
||||||
|
registry = ThreadBudgetRegistry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Thread 1 crosses 75%
|
||||||
|
budget1, crossed1 = registry.consume("thread-1", prompt_tokens=760)
|
||||||
|
assert len(crossed1) == 1
|
||||||
|
|
||||||
|
# Thread 2 crosses 75% independently
|
||||||
|
budget2, crossed2 = registry.consume("thread-2", prompt_tokens=800)
|
||||||
|
assert len(crossed2) == 1
|
||||||
|
|
||||||
|
# Both should have triggered their thresholds
|
||||||
|
assert 75 in budget1.triggered_thresholds
|
||||||
|
assert 75 in budget2.triggered_thresholds
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# BudgetWarning Injection Tests
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
class TestBudgetWarningInjection:
|
||||||
|
"""Test BudgetWarning messages are injected when thresholds crossed."""
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_all(self):
|
||||||
|
"""Reset all global registries."""
|
||||||
|
reset_budget_registry()
|
||||||
|
reset_usage_tracker()
|
||||||
|
yield
|
||||||
|
reset_budget_registry()
|
||||||
|
reset_usage_tracker()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_warning_injected_on_threshold_crossing(self):
|
||||||
|
"""LLM complete should inject BudgetWarning when threshold crossed."""
|
||||||
|
from xml_pipeline.llm.router import LLMRouter
|
||||||
|
from xml_pipeline.llm.backend import LLMResponse
|
||||||
|
from xml_pipeline.primitives.budget_warning import BudgetWarning
|
||||||
|
|
||||||
|
# Configure budget for 1000 tokens
|
||||||
|
configure_budget_registry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Create mock backend that returns 760 tokens (crosses 75% threshold)
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_backend.name = "mock"
|
||||||
|
mock_backend.provider = "test"
|
||||||
|
mock_backend.serves_model = Mock(return_value=True)
|
||||||
|
mock_backend.priority = 1
|
||||||
|
mock_backend.load = 0
|
||||||
|
mock_backend.complete = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Hello!",
|
||||||
|
model="test-model",
|
||||||
|
usage={"prompt_tokens": 700, "completion_tokens": 60, "total_tokens": 760},
|
||||||
|
finish_reason="stop",
|
||||||
|
))
|
||||||
|
|
||||||
|
router = LLMRouter()
|
||||||
|
router.backends.append(mock_backend)
|
||||||
|
|
||||||
|
# Mock the pump to capture injected warnings
|
||||||
|
injected_messages = []
|
||||||
|
|
||||||
|
async def mock_inject(raw_bytes, thread_id, from_id):
|
||||||
|
injected_messages.append({
|
||||||
|
"raw_bytes": raw_bytes,
|
||||||
|
"thread_id": thread_id,
|
||||||
|
"from_id": from_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
mock_pump = Mock()
|
||||||
|
mock_pump.inject = AsyncMock(side_effect=mock_inject)
|
||||||
|
mock_pump._wrap_in_envelope = Mock(return_value=b"<envelope>warning</envelope>")
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
await router.complete(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
thread_id="test-thread",
|
||||||
|
agent_id="greeter",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have injected one warning (75% threshold)
|
||||||
|
assert len(injected_messages) == 1
|
||||||
|
assert injected_messages[0]["from_id"] == "system.budget"
|
||||||
|
assert injected_messages[0]["thread_id"] == "test-thread"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_warnings_injected(self):
|
||||||
|
"""Multiple thresholds crossed should inject multiple warnings."""
|
||||||
|
from xml_pipeline.llm.router import LLMRouter
|
||||||
|
from xml_pipeline.llm.backend import LLMResponse
|
||||||
|
|
||||||
|
configure_budget_registry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
# Backend returns 960 tokens (crosses 75%, 90%, and 95%)
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_backend.name = "mock"
|
||||||
|
mock_backend.provider = "test"
|
||||||
|
mock_backend.serves_model = Mock(return_value=True)
|
||||||
|
mock_backend.priority = 1
|
||||||
|
mock_backend.load = 0
|
||||||
|
mock_backend.complete = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Big response",
|
||||||
|
model="test-model",
|
||||||
|
usage={"prompt_tokens": 900, "completion_tokens": 60, "total_tokens": 960},
|
||||||
|
finish_reason="stop",
|
||||||
|
))
|
||||||
|
|
||||||
|
router = LLMRouter()
|
||||||
|
router.backends.append(mock_backend)
|
||||||
|
|
||||||
|
injected_messages = []
|
||||||
|
|
||||||
|
mock_pump = Mock()
|
||||||
|
mock_pump.inject = AsyncMock(side_effect=lambda *args, **kwargs: injected_messages.append(args))
|
||||||
|
mock_pump._wrap_in_envelope = Mock(return_value=b"<envelope>warning</envelope>")
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
await router.complete(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
thread_id="test-thread",
|
||||||
|
agent_id="agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have injected 3 warnings (75%, 90%, 95%)
|
||||||
|
assert len(injected_messages) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_warning_when_below_threshold(self):
|
||||||
|
"""No warning should be injected when below all thresholds."""
|
||||||
|
from xml_pipeline.llm.router import LLMRouter
|
||||||
|
from xml_pipeline.llm.backend import LLMResponse
|
||||||
|
|
||||||
|
configure_budget_registry(max_tokens_per_thread=10000)
|
||||||
|
|
||||||
|
# Backend returns 100 tokens (only 1%)
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_backend.name = "mock"
|
||||||
|
mock_backend.provider = "test"
|
||||||
|
mock_backend.serves_model = Mock(return_value=True)
|
||||||
|
mock_backend.priority = 1
|
||||||
|
mock_backend.load = 0
|
||||||
|
mock_backend.complete = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Small response",
|
||||||
|
model="test-model",
|
||||||
|
usage={"prompt_tokens": 80, "completion_tokens": 20, "total_tokens": 100},
|
||||||
|
finish_reason="stop",
|
||||||
|
))
|
||||||
|
|
||||||
|
router = LLMRouter()
|
||||||
|
router.backends.append(mock_backend)
|
||||||
|
|
||||||
|
mock_pump = Mock()
|
||||||
|
mock_pump.inject = AsyncMock()
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
await router.complete(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
# No warnings should be injected
|
||||||
|
mock_pump.inject.assert_not_called()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_warning_includes_correct_severity(self):
|
||||||
|
"""Warning messages should have correct severity based on threshold."""
|
||||||
|
from xml_pipeline.llm.router import LLMRouter
|
||||||
|
from xml_pipeline.llm.backend import LLMResponse
|
||||||
|
from xml_pipeline.primitives.budget_warning import BudgetWarning
|
||||||
|
|
||||||
|
configure_budget_registry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_backend.name = "mock"
|
||||||
|
mock_backend.provider = "test"
|
||||||
|
mock_backend.serves_model = Mock(return_value=True)
|
||||||
|
mock_backend.priority = 1
|
||||||
|
mock_backend.load = 0
|
||||||
|
mock_backend.complete = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Response",
|
||||||
|
model="test-model",
|
||||||
|
usage={"prompt_tokens": 910, "completion_tokens": 0, "total_tokens": 910},
|
||||||
|
finish_reason="stop",
|
||||||
|
))
|
||||||
|
|
||||||
|
router = LLMRouter()
|
||||||
|
router.backends.append(mock_backend)
|
||||||
|
|
||||||
|
captured_payloads = []
|
||||||
|
|
||||||
|
def capture_wrap(payload, from_id, to_id, thread_id):
|
||||||
|
captured_payloads.append({
|
||||||
|
"payload": payload,
|
||||||
|
"from_id": from_id,
|
||||||
|
"to_id": to_id,
|
||||||
|
})
|
||||||
|
return b"<envelope/>"
|
||||||
|
|
||||||
|
mock_pump = Mock()
|
||||||
|
mock_pump.inject = AsyncMock()
|
||||||
|
mock_pump._wrap_in_envelope = Mock(side_effect=capture_wrap)
|
||||||
|
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', return_value=mock_pump):
|
||||||
|
await router.complete(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
thread_id="test-thread",
|
||||||
|
agent_id="agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should have captured 2 payloads (75% and 90%)
|
||||||
|
assert len(captured_payloads) == 2
|
||||||
|
|
||||||
|
severities = [p["payload"].severity for p in captured_payloads]
|
||||||
|
assert "warning" in severities
|
||||||
|
assert "critical" in severities
|
||||||
|
|
||||||
|
# Check the critical warning has appropriate message
|
||||||
|
critical = next(p for p in captured_payloads if p["payload"].severity == "critical")
|
||||||
|
assert "WARNING" in critical["payload"].message
|
||||||
|
assert "Consider wrapping up" in critical["payload"].message
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_warning_graceful_without_pump(self):
|
||||||
|
"""Warning injection should gracefully handle missing pump."""
|
||||||
|
from xml_pipeline.llm.router import LLMRouter
|
||||||
|
from xml_pipeline.llm.backend import LLMResponse
|
||||||
|
|
||||||
|
configure_budget_registry(max_tokens_per_thread=1000)
|
||||||
|
|
||||||
|
mock_backend = Mock()
|
||||||
|
mock_backend.name = "mock"
|
||||||
|
mock_backend.provider = "test"
|
||||||
|
mock_backend.serves_model = Mock(return_value=True)
|
||||||
|
mock_backend.priority = 1
|
||||||
|
mock_backend.load = 0
|
||||||
|
mock_backend.complete = AsyncMock(return_value=LLMResponse(
|
||||||
|
content="Response",
|
||||||
|
model="test-model",
|
||||||
|
usage={"prompt_tokens": 800, "completion_tokens": 0, "total_tokens": 800},
|
||||||
|
finish_reason="stop",
|
||||||
|
))
|
||||||
|
|
||||||
|
router = LLMRouter()
|
||||||
|
router.backends.append(mock_backend)
|
||||||
|
|
||||||
|
# Pump not initialized - should not raise
|
||||||
|
with patch('xml_pipeline.message_bus.stream_pump.get_stream_pump', side_effect=RuntimeError("Not initialized")):
|
||||||
|
response = await router.complete(
|
||||||
|
model="test-model",
|
||||||
|
messages=[{"role": "user", "content": "Hi"}],
|
||||||
|
thread_id="test-thread",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should still complete successfully
|
||||||
|
assert response.content == "Response"
|
||||||
|
|
|
||||||
|
|
@ -210,15 +210,24 @@ class LLMRouter:
|
||||||
usage.request_count += 1
|
usage.request_count += 1
|
||||||
|
|
||||||
# Record to thread budget (enforcement)
|
# Record to thread budget (enforcement)
|
||||||
|
crossed_thresholds = []
|
||||||
if thread_id:
|
if thread_id:
|
||||||
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
from xml_pipeline.message_bus.budget_registry import get_budget_registry
|
||||||
budget_registry = get_budget_registry()
|
budget_registry = get_budget_registry()
|
||||||
budget_registry.consume(
|
_budget, crossed_thresholds = budget_registry.consume(
|
||||||
thread_id,
|
thread_id,
|
||||||
prompt_tokens=prompt_tokens,
|
prompt_tokens=prompt_tokens,
|
||||||
completion_tokens=completion_tokens,
|
completion_tokens=completion_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Inject BudgetWarning messages for any newly crossed thresholds
|
||||||
|
if crossed_thresholds:
|
||||||
|
await self._inject_budget_warnings(
|
||||||
|
thread_id=thread_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
crossed=crossed_thresholds,
|
||||||
|
)
|
||||||
|
|
||||||
# Emit usage event (for billing)
|
# Emit usage event (for billing)
|
||||||
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
from xml_pipeline.llm.usage_tracker import get_usage_tracker
|
||||||
tracker = get_usage_tracker()
|
tracker = get_usage_tracker()
|
||||||
|
|
@ -265,6 +274,80 @@ class LLMRouter:
|
||||||
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
jitter = delay * 0.25 * (random.random() * 2 - 1)
|
||||||
return delay + jitter
|
return delay + jitter
|
||||||
|
|
||||||
|
async def _inject_budget_warnings(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
agent_id: Optional[str],
|
||||||
|
crossed: List, # List[BudgetThresholdCrossed]
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Inject BudgetWarning messages when thresholds are crossed.
|
||||||
|
|
||||||
|
Each crossed threshold generates a separate warning message
|
||||||
|
sent to the agent via the stream pump.
|
||||||
|
"""
|
||||||
|
from xml_pipeline.primitives.budget_warning import BudgetWarning
|
||||||
|
|
||||||
|
# Get the stream pump (may not be initialized in tests)
|
||||||
|
try:
|
||||||
|
from xml_pipeline.message_bus.stream_pump import get_stream_pump
|
||||||
|
pump = get_stream_pump()
|
||||||
|
except RuntimeError:
|
||||||
|
# Pump not initialized - log and skip
|
||||||
|
logger.warning(
|
||||||
|
f"Cannot inject BudgetWarning: StreamPump not initialized. "
|
||||||
|
f"Thread {thread_id[:8]}... crossed thresholds: {[c.threshold_percent for c in crossed]}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
for threshold in crossed:
|
||||||
|
# Build human-readable message based on severity
|
||||||
|
if threshold.severity == "final":
|
||||||
|
message = (
|
||||||
|
f"CRITICAL: {threshold.percent_used:.0f}% of token budget used. "
|
||||||
|
f"Only {threshold.tokens_remaining:,} tokens remaining. "
|
||||||
|
f"Wrap up immediately or your next request may fail."
|
||||||
|
)
|
||||||
|
elif threshold.severity == "critical":
|
||||||
|
message = (
|
||||||
|
f"WARNING: {threshold.percent_used:.0f}% of token budget used. "
|
||||||
|
f"{threshold.tokens_remaining:,} tokens remaining. "
|
||||||
|
f"Consider wrapping up your current task soon."
|
||||||
|
)
|
||||||
|
else: # warning
|
||||||
|
message = (
|
||||||
|
f"Note: {threshold.percent_used:.0f}% of token budget used. "
|
||||||
|
f"{threshold.tokens_remaining:,} tokens remaining."
|
||||||
|
)
|
||||||
|
|
||||||
|
warning_payload = BudgetWarning(
|
||||||
|
percent_used=threshold.percent_used,
|
||||||
|
tokens_used=threshold.tokens_used,
|
||||||
|
tokens_remaining=threshold.tokens_remaining,
|
||||||
|
max_tokens=threshold.max_tokens,
|
||||||
|
severity=threshold.severity,
|
||||||
|
message=message,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create envelope for the warning
|
||||||
|
# Target is the agent that made the LLM call
|
||||||
|
target = agent_id if agent_id else "system"
|
||||||
|
|
||||||
|
envelope = pump._wrap_in_envelope(
|
||||||
|
payload=warning_payload,
|
||||||
|
from_id="system.budget",
|
||||||
|
to_id=target,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inject into the pump's queue
|
||||||
|
await pump.inject(envelope, thread_id=thread_id, from_id="system.budget")
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"BudgetWarning sent to {target}: {threshold.severity} "
|
||||||
|
f"({threshold.percent_used:.0f}% used, {threshold.tokens_remaining:,} remaining)"
|
||||||
|
)
|
||||||
|
|
||||||
def get_agent_usage(self, agent_id: str) -> AgentUsage:
|
def get_agent_usage(self, agent_id: str) -> AgentUsage:
|
||||||
"""Get usage stats for an agent."""
|
"""Get usage stats for an agent."""
|
||||||
return self._agent_usage.get(agent_id, AgentUsage())
|
return self._agent_usage.get(agent_id, AgentUsage())
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,26 @@ from __future__ import annotations
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Dict, Optional
|
from typing import Dict, List, Optional, Set, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# Default warning thresholds (percent -> severity)
|
||||||
|
DEFAULT_WARNING_THRESHOLDS: Dict[int, str] = {
|
||||||
|
75: "warning", # 75% - early warning
|
||||||
|
90: "critical", # 90% - wrap up soon
|
||||||
|
95: "final", # 95% - last chance
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BudgetThresholdCrossed:
|
||||||
|
"""Info about a threshold that was just crossed."""
|
||||||
|
threshold_percent: int
|
||||||
|
severity: str
|
||||||
|
percent_used: float
|
||||||
|
tokens_used: int
|
||||||
|
tokens_remaining: int
|
||||||
|
max_tokens: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -28,6 +47,7 @@ class ThreadBudget:
|
||||||
prompt_tokens: int = 0
|
prompt_tokens: int = 0
|
||||||
completion_tokens: int = 0
|
completion_tokens: int = 0
|
||||||
request_count: int = 0
|
request_count: int = 0
|
||||||
|
triggered_thresholds: Set[int] = field(default_factory=set)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def total_tokens(self) -> int:
|
def total_tokens(self) -> int:
|
||||||
|
|
@ -44,6 +64,13 @@ class ThreadBudget:
|
||||||
"""True if budget is exhausted."""
|
"""True if budget is exhausted."""
|
||||||
return self.total_tokens >= self.max_tokens
|
return self.total_tokens >= self.max_tokens
|
||||||
|
|
||||||
|
@property
|
||||||
|
def percent_used(self) -> float:
|
||||||
|
"""Percentage of budget consumed (0-100)."""
|
||||||
|
if self.max_tokens <= 0:
|
||||||
|
return 0.0
|
||||||
|
return (self.total_tokens / self.max_tokens) * 100
|
||||||
|
|
||||||
def can_consume(self, estimated_tokens: int) -> bool:
|
def can_consume(self, estimated_tokens: int) -> bool:
|
||||||
"""Check if we can consume the given tokens without exceeding budget."""
|
"""Check if we can consume the given tokens without exceeding budget."""
|
||||||
return self.total_tokens + estimated_tokens <= self.max_tokens
|
return self.total_tokens + estimated_tokens <= self.max_tokens
|
||||||
|
|
@ -58,6 +85,42 @@ class ThreadBudget:
|
||||||
self.completion_tokens += completion_tokens
|
self.completion_tokens += completion_tokens
|
||||||
self.request_count += 1
|
self.request_count += 1
|
||||||
|
|
||||||
|
def check_thresholds(
|
||||||
|
self,
|
||||||
|
thresholds: Dict[int, str] = None,
|
||||||
|
) -> List[BudgetThresholdCrossed]:
|
||||||
|
"""
|
||||||
|
Check if any thresholds were crossed that haven't been triggered yet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
thresholds: Dict of percent -> severity. Defaults to DEFAULT_WARNING_THRESHOLDS.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of newly crossed thresholds (sorted by percent)
|
||||||
|
"""
|
||||||
|
if thresholds is None:
|
||||||
|
thresholds = DEFAULT_WARNING_THRESHOLDS
|
||||||
|
|
||||||
|
crossed = []
|
||||||
|
current_percent = self.percent_used
|
||||||
|
|
||||||
|
for threshold_percent, severity in sorted(thresholds.items()):
|
||||||
|
if (
|
||||||
|
current_percent >= threshold_percent
|
||||||
|
and threshold_percent not in self.triggered_thresholds
|
||||||
|
):
|
||||||
|
self.triggered_thresholds.add(threshold_percent)
|
||||||
|
crossed.append(BudgetThresholdCrossed(
|
||||||
|
threshold_percent=threshold_percent,
|
||||||
|
severity=severity,
|
||||||
|
percent_used=round(current_percent, 1),
|
||||||
|
tokens_used=self.total_tokens,
|
||||||
|
tokens_remaining=self.remaining,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
|
))
|
||||||
|
|
||||||
|
return crossed
|
||||||
|
|
||||||
|
|
||||||
class BudgetExhaustedError(Exception):
|
class BudgetExhaustedError(Exception):
|
||||||
"""Raised when a thread's token budget is exhausted."""
|
"""Raised when a thread's token budget is exhausted."""
|
||||||
|
|
@ -176,7 +239,7 @@ class ThreadBudgetRegistry:
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
prompt_tokens: int = 0,
|
prompt_tokens: int = 0,
|
||||||
completion_tokens: int = 0,
|
completion_tokens: int = 0,
|
||||||
) -> ThreadBudget:
|
) -> Tuple[ThreadBudget, List[BudgetThresholdCrossed]]:
|
||||||
"""
|
"""
|
||||||
Record token consumption for a thread.
|
Record token consumption for a thread.
|
||||||
|
|
||||||
|
|
@ -186,12 +249,13 @@ class ThreadBudgetRegistry:
|
||||||
completion_tokens: Completion tokens used
|
completion_tokens: Completion tokens used
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Updated ThreadBudget
|
Tuple of (Updated ThreadBudget, List of newly crossed thresholds)
|
||||||
"""
|
"""
|
||||||
budget = self.get_budget(thread_id)
|
budget = self.get_budget(thread_id)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
budget.consume(prompt_tokens, completion_tokens)
|
budget.consume(prompt_tokens, completion_tokens)
|
||||||
return budget
|
crossed = budget.check_thresholds()
|
||||||
|
return budget, crossed
|
||||||
|
|
||||||
def has_budget(self, thread_id: str) -> bool:
|
def has_budget(self, thread_id: str) -> bool:
|
||||||
"""Check if a thread has a budget entry (without creating one)."""
|
"""Check if a thread has a budget entry (without creating one)."""
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,10 @@ from xml_pipeline.primitives.buffer import (
|
||||||
BufferError,
|
BufferError,
|
||||||
handle_buffer_start,
|
handle_buffer_start,
|
||||||
)
|
)
|
||||||
|
from xml_pipeline.primitives.budget_warning import (
|
||||||
|
BudgetWarning,
|
||||||
|
DEFAULT_THRESHOLDS,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Boot
|
# Boot
|
||||||
|
|
@ -56,4 +60,7 @@ __all__ = [
|
||||||
"BufferDispatched",
|
"BufferDispatched",
|
||||||
"BufferError",
|
"BufferError",
|
||||||
"handle_buffer_start",
|
"handle_buffer_start",
|
||||||
|
# Budget warnings
|
||||||
|
"BudgetWarning",
|
||||||
|
"DEFAULT_THRESHOLDS",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
66
xml_pipeline/primitives/budget_warning.py
Normal file
66
xml_pipeline/primitives/budget_warning.py
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
"""
|
||||||
|
BudgetWarning — System alerts for token budget thresholds.
|
||||||
|
|
||||||
|
When a thread approaches its token budget limit, the system injects
|
||||||
|
BudgetWarning messages to give agents a chance to wrap up gracefully.
|
||||||
|
|
||||||
|
Thresholds:
|
||||||
|
- 75%: Early warning - consider wrapping up
|
||||||
|
- 90%: Critical warning - finish current task
|
||||||
|
- 95%: Final warning - immediate action required
|
||||||
|
|
||||||
|
Agents can design contingencies:
|
||||||
|
- Summarize progress and respond early
|
||||||
|
- Escalate to a supervisor agent
|
||||||
|
- Save state for continuation
|
||||||
|
- Request budget increase (if supported)
|
||||||
|
|
||||||
|
Example handler pattern:
|
||||||
|
async def my_agent(payload, metadata):
|
||||||
|
if isinstance(payload, BudgetWarning):
|
||||||
|
if payload.severity == "critical":
|
||||||
|
# Wrap up immediately
|
||||||
|
return HandlerResponse.respond(
|
||||||
|
payload=Summary(progress="Reached 90% budget, stopping here...")
|
||||||
|
)
|
||||||
|
# Otherwise note it and continue
|
||||||
|
...
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Note: Do NOT use `from __future__ import annotations` here
|
||||||
|
# as it breaks the xmlify decorator which needs concrete types
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from third_party.xmlable import xmlify
|
||||||
|
|
||||||
|
|
||||||
|
@xmlify
|
||||||
|
@dataclass
|
||||||
|
class BudgetWarning:
|
||||||
|
"""
|
||||||
|
System warning about token budget consumption.
|
||||||
|
|
||||||
|
Sent to agents when their thread crosses budget thresholds.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
percent_used: Current percentage of budget consumed (0-100)
|
||||||
|
tokens_used: Total tokens consumed so far
|
||||||
|
tokens_remaining: Tokens remaining before exhaustion
|
||||||
|
max_tokens: Total budget for this thread
|
||||||
|
severity: Warning level (warning, critical, final)
|
||||||
|
message: Human-readable description
|
||||||
|
"""
|
||||||
|
percent_used: float
|
||||||
|
tokens_used: int
|
||||||
|
tokens_remaining: int
|
||||||
|
max_tokens: int
|
||||||
|
severity: str # "warning" (75%), "critical" (90%), "final" (95%)
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
# Default thresholds (can be configured)
|
||||||
|
DEFAULT_THRESHOLDS = {
|
||||||
|
75: "warning", # 75% - early warning
|
||||||
|
90: "critical", # 90% - wrap up soon
|
||||||
|
95: "final", # 95% - last chance
|
||||||
|
}
|
||||||
Loading…
Reference in a new issue