Endpoints:
- GET /api/v1/usage - Overview with totals, per-agent, per-model breakdown
- GET /api/v1/usage/threads - List all thread budgets sorted by usage
- GET /api/v1/usage/threads/{id} - Single thread budget details
- GET /api/v1/usage/agents/{id} - Usage totals for specific agent
- GET /api/v1/usage/models/{model} - Usage totals for specific model
- POST /api/v1/usage/reset - Reset all usage tracking
Models:
- UsageTotals, UsageOverview, UsageResponse
- ThreadBudgetInfo, ThreadBudgetListResponse
- AgentUsageInfo, ModelUsageInfo
Also adds has_budget() method to ThreadBudgetRegistry for checking
if a thread exists without auto-creating it.
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
893 lines
30 KiB
Python
893 lines
30 KiB
Python
"""
|
|
Tests for the AgentServer API.
|
|
|
|
Tests the REST API endpoints and WebSocket connections.
|
|
"""
|
|
|
|
import asyncio
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock, AsyncMock, patch

import pytest

# Skip all tests if FastAPI not available
pytest.importorskip("fastapi")

from fastapi.testclient import TestClient

from xml_pipeline.server.models import (
    AgentState,
    AgentInfo,
    ThreadStatus,
    ThreadInfo,
    MessageInfo,
    OrganismInfo,
    OrganismStatus,
)
from xml_pipeline.server.state import ServerState, AgentRuntimeState, ThreadRuntimeState
from xml_pipeline.server.api import create_router
from xml_pipeline.server.app import create_app


# ============================================================================
# Mock StreamPump
# ============================================================================

@dataclass
class MockListener:
    """Mock listener for testing.

    Only ``name`` and ``description`` are required.  Any of ``peers``,
    ``payload_class`` and ``root_tag`` that are left unset (or passed as
    ``None`` / empty) are filled in by ``__post_init__``.
    """

    name: str
    description: str
    is_agent: bool = False
    # Each instance gets its own list; an explicit ``None`` is still accepted
    # and normalised to ``[]`` in ``__post_init__``.
    peers: Optional[List[str]] = field(default_factory=list)
    # ``None`` means "generate a placeholder payload class".
    payload_class: Optional[type] = None
    schema: Any = None
    # Empty string means "derive from the listener name".
    root_tag: str = ""

    def __post_init__(self) -> None:
        """Fill in defaults for fields the caller did not provide."""
        if self.peers is None:
            self.peers = []
        if self.payload_class is None:
            # Throwaway class whose module/name are fixed to known values.
            self.payload_class = type("MockPayload", (), {"__module__": "test", "__name__": "MockPayload"})
        if not self.root_tag:
            self.root_tag = f"{self.name.lower()}.mockpayload"


@dataclass
class MockConfig:
    """Mock config for testing.

    Every field has a default, so ``MockConfig()`` is a complete config.
    """

    name: str = "test-organism"
    port: int = 8765
    thread_scheduling: str = "breadth-first"
    max_concurrent_pipelines: int = 50
    max_concurrent_handlers: int = 20


class MockStreamPump:
    """Mock StreamPump for testing.

    Exposes ``config``, ``identity`` and ``listeners`` attributes plus the
    ``subscribe_events`` / ``unsubscribe_events`` pair.  Two listeners are
    pre-registered: ``greeter`` (an agent with ``shouter`` as a peer) and
    ``shouter`` (a non-agent handler).
    """

    def __init__(self):
        self.config = MockConfig()
        self.identity = None
        self.listeners: Dict[str, MockListener] = {
            "greeter": MockListener(
                name="greeter",
                description="Greeting agent",
                is_agent=True,
                peers=["shouter"],
            ),
            "shouter": MockListener(
                name="shouter",
                description="Shouting handler",
                is_agent=False,
            ),
        }
        # Registered callbacks, in subscription order.
        self._event_callbacks = []

    def subscribe_events(self, callback) -> None:
        """Append ``callback`` to the list of registered event callbacks."""
        self._event_callbacks.append(callback)

    def unsubscribe_events(self, callback) -> None:
        """Remove ``callback`` if it is registered; otherwise do nothing."""
        if callback in self._event_callbacks:
            self._event_callbacks.remove(callback)


# ============================================================================
# Test Fixtures
# ============================================================================

@pytest.fixture
def mock_pump():
    """Create a mock StreamPump."""
    return MockStreamPump()


@pytest.fixture
def server_state(mock_pump):
    """Create ServerState with mock pump."""
    return ServerState(mock_pump)


@pytest.fixture
def test_client(mock_pump):
    """Create FastAPI test client."""
    app = create_app(mock_pump)
    return TestClient(app)


# ============================================================================
# Test Models
# ============================================================================

class TestModels:
    """Test Pydantic model serialization."""

    def test_agent_info_camel_case(self):
        """Test AgentInfo serializes to camelCase."""
        agent = AgentInfo(
            name="greeter",
            description="Test agent",
            is_agent=True,
            peers=["shouter"],
            payload_class="test.Greeting",
            state=AgentState.IDLE,
            current_thread=None,
            queue_depth=0,
            message_count=5,
        )
        # Dump with aliases so the keys are the serialized (camelCase) names.
        data = agent.model_dump(by_alias=True)
        assert "isAgent" in data
        assert "payloadClass" in data
        assert "currentThread" in data
        assert "queueDepth" in data
        assert "messageCount" in data

    def test_thread_info_camel_case(self):
        """Test ThreadInfo serializes to camelCase."""
        from datetime import datetime, timezone

        thread = ThreadInfo(
            id="test-uuid",
            status=ThreadStatus.ACTIVE,
            participants=["greeter", "shouter"],
            message_count=3,
            created_at=datetime.now(timezone.utc),
        )
        data = thread.model_dump(by_alias=True)
        assert "messageCount" in data
        assert "createdAt" in data
        # Not passed to the constructor above, but still expected in the dump.
        assert "lastActivity" in data

    def test_organism_info_camel_case(self):
        """Test OrganismInfo serializes to camelCase."""
        info = OrganismInfo(
            name="test-organism",
            status=OrganismStatus.RUNNING,
            uptime_seconds=3600.0,
            agent_count=2,
            active_threads=1,
            total_messages=10,
            identity_configured=False,
        )
        data = info.model_dump(by_alias=True)
        assert "uptimeSeconds" in data
        assert "agentCount" in data
        assert "activeThreads" in data
        assert "totalMessages" in data
        assert "identityConfigured" in data


# ============================================================================
# Test ServerState
# ============================================================================

class TestServerState:
    """Test ServerState functionality."""

    def test_init_agents_from_pump(self, server_state):
        """Test agents are initialized from pump listeners."""
        assert "greeter" in server_state._agents
        assert "shouter" in server_state._agents
        assert server_state._agents["greeter"].is_agent is True
        assert server_state._agents["shouter"].is_agent is False

    def test_get_organism_info(self, server_state):
        """Test getting organism info."""
        info = server_state.get_organism_info()
        assert info.name == "test-organism"
        assert info.agent_count == 2
        # set_running() has not been called on this fresh state.
        assert info.status == OrganismStatus.STARTING

    def test_set_running(self, server_state):
        """Test setting running status."""
        server_state.set_running()
        info = server_state.get_organism_info()
        assert info.status == OrganismStatus.RUNNING

    def test_get_agents(self, server_state):
        """Test getting all agents."""
        agents = server_state.get_agents()
        assert len(agents) == 2
        names = [a.name for a in agents]
        assert "greeter" in names
        assert "shouter" in names

    def test_get_agent(self, server_state):
        """Test getting single agent."""
        agent = server_state.get_agent("greeter")
        assert agent is not None
        assert agent.name == "greeter"
        assert agent.is_agent is True

    def test_get_agent_not_found(self, server_state):
        """Test getting non-existent agent."""
        agent = server_state.get_agent("nonexistent")
        assert agent is None

    @pytest.mark.asyncio
    async def test_record_message(self, server_state):
        """Test recording a message."""
        msg_id = await server_state.record_message(
            thread_id="test-thread",
            from_id="greeter",
            to_id="shouter",
            payload_type="GreetingResponse",
            payload={"message": "Hello!"},
        )
        assert msg_id is not None

        # Check message was recorded
        messages, total = server_state.get_messages(thread_id="test-thread")
        assert total == 1
        assert messages[0].from_id == "greeter"
        assert messages[0].to_id == "shouter"

    @pytest.mark.asyncio
    async def test_update_agent_state(self, server_state):
        """Test updating agent state."""
        await server_state.update_agent_state("greeter", AgentState.PROCESSING, "thread-1")

        agent = server_state.get_agent("greeter")
        # Fail with a clear assertion rather than an AttributeError below.
        assert agent is not None
        assert agent.state == AgentState.PROCESSING
        assert agent.current_thread == "thread-1"

    @pytest.mark.asyncio
    async def test_complete_thread(self, server_state):
        """Test completing a thread."""
        # First record a message to create the thread
        await server_state.record_message(
            thread_id="test-thread",
            from_id="greeter",
            to_id="shouter",
            payload_type="Test",
            payload={},
        )

        # Complete the thread
        await server_state.complete_thread("test-thread", ThreadStatus.COMPLETED)

        thread = server_state.get_thread("test-thread")
        # Fail with a clear assertion rather than an AttributeError below.
        assert thread is not None
        assert thread.status == ThreadStatus.COMPLETED

    def test_get_threads_with_filter(self, server_state):
        """Test filtering threads by status."""
        # No threads initially
        _, total = server_state.get_threads(status=ThreadStatus.ACTIVE)
        assert total == 0

    def test_get_organism_config(self, server_state):
        """Test getting organism config."""
        config = server_state.get_organism_config()
        assert config["name"] == "test-organism"
        assert config["port"] == 8765


# ============================================================================
# Test REST API
# ============================================================================

class TestRestAPI:
    """Test REST API endpoints."""

    def test_health_check(self, test_client):
        """Test health check endpoint."""
        response = test_client.get("/health")
        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"
        assert "organism" in data

    def test_get_organism(self, test_client):
        """Test GET /api/v1/organism."""
        response = test_client.get("/api/v1/organism")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "test-organism"
        assert "agentCount" in data

    def test_get_organism_config(self, test_client):
        """Test GET /api/v1/organism/config."""
        response = test_client.get("/api/v1/organism/config")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "test-organism"

    def test_list_agents(self, test_client):
        """Test GET /api/v1/agents."""
        response = test_client.get("/api/v1/agents")
        assert response.status_code == 200
        data = response.json()
        assert "agents" in data
        assert data["count"] == 2

    def test_get_agent(self, test_client):
        """Test GET /api/v1/agents/{name}."""
        response = test_client.get("/api/v1/agents/greeter")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "greeter"
        assert data["isAgent"] is True

    def test_get_agent_not_found(self, test_client):
        """Test GET /api/v1/agents/{name} with non-existent agent."""
        response = test_client.get("/api/v1/agents/nonexistent")
        assert response.status_code == 404

    def test_get_agent_config(self, test_client):
        """Test GET /api/v1/agents/{name}/config."""
        response = test_client.get("/api/v1/agents/greeter/config")
        assert response.status_code == 200
        data = response.json()
        assert data["name"] == "greeter"
        assert "isAgent" in data

    def test_list_threads(self, test_client):
        """Test GET /api/v1/threads."""
        response = test_client.get("/api/v1/threads")
        assert response.status_code == 200
        data = response.json()
        assert "threads" in data
        assert "count" in data
        assert "total" in data

    def test_list_threads_with_invalid_status(self, test_client):
        """Test GET /api/v1/threads with invalid status filter."""
        response = test_client.get("/api/v1/threads?status=invalid")
        assert response.status_code == 400

    def test_list_messages(self, test_client):
        """Test GET /api/v1/messages."""
        response = test_client.get("/api/v1/messages")
        assert response.status_code == 200
        data = response.json()
        assert "messages" in data
        assert "count" in data

    def test_inject_message(self, test_client):
        """Test POST /api/v1/inject."""
        response = test_client.post(
            "/api/v1/inject",
            json={
                "to": "greeter",
                "payload": {"name": "Dan"},
            },
        )
        assert response.status_code == 200
        data = response.json()
        assert "threadId" in data
        assert "messageId" in data

    def test_inject_message_unknown_agent(self, test_client):
        """Test POST /api/v1/inject with unknown agent."""
        response = test_client.post(
            "/api/v1/inject",
            json={
                "to": "nonexistent",
                "payload": {"name": "Dan"},
            },
        )
        assert response.status_code == 400

    def test_pause_agent(self, test_client):
        """Test POST /api/v1/agents/{name}/pause."""
        response = test_client.post("/api/v1/agents/greeter/pause")
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["state"] == "paused"

    def test_resume_agent(self, test_client):
        """Test POST /api/v1/agents/{name}/resume."""
        response = test_client.post("/api/v1/agents/greeter/resume")
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        assert data["state"] == "idle"

    def test_stop_organism(self, test_client):
        """Test POST /api/v1/organism/stop."""
        response = test_client.post("/api/v1/organism/stop")
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True


# ============================================================================
# Test WebSocket
# ============================================================================

class TestWebSocket:
    """Test WebSocket endpoints."""

    def test_websocket_connect(self, test_client):
        """Test WebSocket connection."""
        with test_client.websocket_connect("/ws") as websocket:
            # Should receive connected event with state snapshot
            data = websocket.receive_json()
            assert data["event"] == "connected"
            assert "organism" in data
            assert "agents" in data
            assert "threads" in data

    def test_websocket_subscribe(self, test_client):
        """Test WebSocket subscribe command."""
        with test_client.websocket_connect("/ws") as websocket:
            # Receive initial connected event
            websocket.receive_json()

            # Send subscribe command
            websocket.send_json({
                "cmd": "subscribe",
                "agents": ["greeter"],
                "events": ["message"],
            })

            # Should receive subscribed confirmation
            data = websocket.receive_json()
            assert data["event"] == "subscribed"

    def test_websocket_inject(self, test_client):
        """Test WebSocket inject command."""
        with test_client.websocket_connect("/ws") as websocket:
            # Receive initial connected event
            websocket.receive_json()

            # Send inject command
            websocket.send_json({
                "cmd": "inject",
                "to": "greeter",
                "payload": {"name": "Dan"},
            })

            # Should receive injected confirmation
            data = websocket.receive_json()
            assert data["event"] == "injected"
            assert "thread_id" in data
            assert "message_id" in data

    def test_websocket_inject_unknown_agent(self, test_client):
        """Test WebSocket inject with unknown agent."""
        with test_client.websocket_connect("/ws") as websocket:
            # Receive initial connected event
            websocket.receive_json()

            # Send inject command to unknown agent
            websocket.send_json({
                "cmd": "inject",
                "to": "nonexistent",
                "payload": {},
            })

            # Should receive error
            data = websocket.receive_json()
            assert data["event"] == "error"

    def test_websocket_unknown_command(self, test_client):
        """Test WebSocket with unknown command."""
        with test_client.websocket_connect("/ws") as websocket:
            # Receive initial connected event
            websocket.receive_json()

            # Send unknown command
            websocket.send_json({
                "cmd": "unknown_command",
            })

            # Should receive error
            data = websocket.receive_json()
            assert data["event"] == "error"

    def test_websocket_messages_stream(self, test_client):
        """Test WebSocket messages stream endpoint."""
        with test_client.websocket_connect("/ws/messages") as websocket:
            # Unlike the /ws tests, no initial event is read before sending.
            # Send subscribe command
            websocket.send_json({
                "cmd": "subscribe",
                "filter": {
                    "agents": ["greeter"],
                },
            })

            # Should receive subscribed confirmation
            data = websocket.receive_json()
            assert data["event"] == "subscribed"
            assert "filter" in data


# ============================================================================
# Test Capability Introspection
# ============================================================================

class TestCapabilityIntrospection:
    """Test capability introspection endpoints."""

    def test_list_capabilities(self, test_client):
        """Test GET /capabilities lists all registered listeners."""
        response = test_client.get("/api/v1/capabilities")
        assert response.status_code == 200

        data = response.json()
        assert "capabilities" in data
        assert "count" in data
        assert data["count"] == 2  # greeter and shouter

        names = [c["name"] for c in data["capabilities"]]
        assert "greeter" in names
        assert "shouter" in names

    def test_list_capabilities_includes_details(self, test_client):
        """Test capability listing includes required fields."""
        response = test_client.get("/api/v1/capabilities")
        # Check the status first so a failing request is reported as such
        # instead of as a KeyError on the body.
        assert response.status_code == 200
        data = response.json()

        # Find greeter
        greeter = next(c for c in data["capabilities"] if c["name"] == "greeter")

        assert "description" in greeter
        assert "isAgent" in greeter
        assert greeter["isAgent"] is True
        assert "peers" in greeter
        assert "shouter" in greeter["peers"]
        assert "rootTag" in greeter

    def test_list_capabilities_sorted(self, test_client):
        """Test capabilities are sorted by name."""
        response = test_client.get("/api/v1/capabilities")
        assert response.status_code == 200
        data = response.json()

        names = [c["name"] for c in data["capabilities"]]
        assert names == sorted(names)

    def test_get_capability_detail(self, test_client):
        """Test GET /capabilities/{name} returns detailed info."""
        response = test_client.get("/api/v1/capabilities/greeter")
        assert response.status_code == 200

        data = response.json()
        assert data["name"] == "greeter"
        assert data["description"] == "Greeting agent"
        assert data["isAgent"] is True
        assert "shouter" in data["peers"]
        assert "payloadClass" in data
        assert "rootTag" in data

    def test_get_capability_includes_example(self, test_client):
        """Test capability detail includes the example XML field."""
        response = test_client.get("/api/v1/capabilities/greeter")
        assert response.status_code == 200
        data = response.json()

        # Only the key is asserted: the example itself may or may not be
        # populated depending on the payload class.
        assert "exampleXml" in data

    def test_get_capability_not_found(self, test_client):
        """Test GET /capabilities/{name} returns 404 for unknown capability."""
        response = test_client.get("/api/v1/capabilities/nonexistent")
        assert response.status_code == 404

    def test_get_capability_shouter(self, test_client):
        """Test non-agent capability details."""
        response = test_client.get("/api/v1/capabilities/shouter")
        assert response.status_code == 200

        data = response.json()
        assert data["name"] == "shouter"
        assert data["isAgent"] is False
        assert data["peers"] == []


class TestCapabilityIntrospectionState:
    """Test capability introspection via ServerState."""

    def test_get_capabilities_returns_list(self, server_state):
        """Test get_capabilities returns CapabilityInfo list."""
        capabilities = server_state.get_capabilities()
        assert len(capabilities) == 2
        # Duck-typed check: every entry exposes the attributes used here.
        assert all(hasattr(c, "name") for c in capabilities)
        assert all(hasattr(c, "root_tag") for c in capabilities)

    def test_get_capability_returns_detail(self, server_state):
        """Test get_capability returns CapabilityDetail."""
        detail = server_state.get_capability("greeter")
        assert detail is not None
        assert detail.name == "greeter"
        assert detail.is_agent is True
        assert "shouter" in detail.peers

    def test_get_capability_not_found_returns_none(self, server_state):
        """Test get_capability returns None for unknown."""
        detail = server_state.get_capability("nonexistent")
        assert detail is None


# ============================================================================
# Test Usage/Gas Tracking API
# ============================================================================

class TestUsageAPI:
    """Test usage/gas tracking endpoints.

    Most tests call ``reset_usage_tracker()`` and/or
    ``reset_budget_registry()`` first so they start from a clean state.
    """

    def test_get_usage_overview(self, test_client):
        """Test GET /api/v1/usage returns overview."""
        # Reset trackers for clean state
        from xml_pipeline.llm import reset_usage_tracker
        from xml_pipeline.message_bus import reset_budget_registry

        reset_usage_tracker()
        reset_budget_registry()

        response = test_client.get("/api/v1/usage")
        assert response.status_code == 200

        data = response.json()
        assert "usage" in data

        usage = data["usage"]
        assert "totals" in usage
        assert "byAgent" in usage
        assert "byModel" in usage
        assert "activeThreads" in usage

        totals = usage["totals"]
        assert "totalTokens" in totals
        assert "promptTokens" in totals
        assert "completionTokens" in totals
        assert "requestCount" in totals
        assert "totalCost" in totals
        assert "avgLatencyMs" in totals

    def test_get_usage_with_data(self, test_client):
        """Test usage reflects recorded data."""
        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker

        reset_usage_tracker()
        tracker = get_usage_tracker()

        # Record some usage
        tracker.record(
            thread_id="test-thread",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250.0,
        )

        response = test_client.get("/api/v1/usage")
        assert response.status_code == 200

        data = response.json()
        totals = data["usage"]["totals"]
        assert totals["totalTokens"] == 150
        assert totals["promptTokens"] == 100
        assert totals["completionTokens"] == 50
        assert totals["requestCount"] == 1

        # Check by-agent breakdown
        by_agent = data["usage"]["byAgent"]
        assert len(by_agent) == 1
        assert by_agent[0]["agentId"] == "greeter"
        assert by_agent[0]["totalTokens"] == 150

        # Check by-model breakdown
        by_model = data["usage"]["byModel"]
        assert len(by_model) == 1
        assert by_model[0]["model"] == "grok-4.1"
        assert by_model[0]["totalTokens"] == 150

    def test_get_thread_budgets_empty(self, test_client):
        """Test GET /api/v1/usage/threads with no threads."""
        from xml_pipeline.message_bus import reset_budget_registry

        reset_budget_registry()

        response = test_client.get("/api/v1/usage/threads")
        assert response.status_code == 200

        data = response.json()
        assert "threads" in data
        assert "count" in data
        assert "defaultMaxTokens" in data
        assert data["count"] == 0

    def test_get_thread_budgets_with_data(self, test_client):
        """Test thread budgets reflect consumption."""
        from xml_pipeline.message_bus import get_budget_registry, reset_budget_registry

        reset_budget_registry()
        registry = get_budget_registry()

        # Consume some tokens
        registry.consume("thread-1", 5000, 2000)
        registry.consume("thread-2", 10000, 5000)

        response = test_client.get("/api/v1/usage/threads")
        assert response.status_code == 200

        data = response.json()
        assert data["count"] == 2

        # Threads sorted by percent used (descending)
        threads = data["threads"]
        assert threads[0]["percentUsed"] >= threads[1]["percentUsed"]

        # Find thread-2 (should have higher usage)
        thread2 = next(t for t in threads if t["threadId"] == "thread-2")
        assert thread2["totalTokens"] == 15000
        assert thread2["promptTokens"] == 10000
        assert thread2["completionTokens"] == 5000

    def test_get_single_thread_budget(self, test_client):
        """Test GET /api/v1/usage/threads/{thread_id}."""
        from xml_pipeline.message_bus import get_budget_registry, reset_budget_registry

        reset_budget_registry()
        registry = get_budget_registry()
        registry.consume("my-thread", 3000, 1500)

        response = test_client.get("/api/v1/usage/threads/my-thread")
        assert response.status_code == 200

        data = response.json()
        assert data["threadId"] == "my-thread"
        assert data["totalTokens"] == 4500
        assert data["promptTokens"] == 3000
        assert data["completionTokens"] == 1500
        assert data["maxTokens"] == 100000  # default
        assert data["remaining"] == 95500
        # A percentage is a computed float; compare with a tolerance.
        assert data["percentUsed"] == pytest.approx(4.5)
        assert data["isExhausted"] is False

    def test_get_single_thread_budget_not_found(self, test_client):
        """Test GET /usage/threads/{id} returns 404 for unknown thread."""
        from xml_pipeline.message_bus import reset_budget_registry

        reset_budget_registry()

        response = test_client.get("/api/v1/usage/threads/nonexistent")
        assert response.status_code == 404

    def test_get_agent_usage(self, test_client):
        """Test GET /api/v1/usage/agents/{agent_id}."""
        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker

        reset_usage_tracker()
        tracker = get_usage_tracker()

        # Two records for the same agent on different threads; the endpoint
        # is expected to report their sum.
        tracker.record(
            thread_id="t1",
            agent_id="researcher",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2",
            agent_id="researcher",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=2000,
            completion_tokens=1000,
            latency_ms=150.0,
        )

        response = test_client.get("/api/v1/usage/agents/researcher")
        assert response.status_code == 200

        data = response.json()
        assert data["agentId"] == "researcher"
        assert data["totalTokens"] == 4500
        assert data["promptTokens"] == 3000
        assert data["completionTokens"] == 1500
        assert data["requestCount"] == 2

    def test_get_agent_usage_empty(self, test_client):
        """Test GET /usage/agents/{id} for agent with no usage."""
        from xml_pipeline.llm import reset_usage_tracker

        reset_usage_tracker()

        response = test_client.get("/api/v1/usage/agents/unknown")
        # An agent with no recorded usage yields zeroed totals, not a 404.
        assert response.status_code == 200

        data = response.json()
        assert data["agentId"] == "unknown"
        assert data["totalTokens"] == 0
        assert data["requestCount"] == 0

    def test_get_model_usage(self, test_client):
        """Test GET /api/v1/usage/models/{model}."""
        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker

        reset_usage_tracker()
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1",
            agent_id="a1",
            model="claude-sonnet-4",
            provider="anthropic",
            prompt_tokens=500,
            completion_tokens=200,
            latency_ms=80.0,
        )

        response = test_client.get("/api/v1/usage/models/claude-sonnet-4")
        assert response.status_code == 200

        data = response.json()
        assert data["model"] == "claude-sonnet-4"
        assert data["totalTokens"] == 700
        assert data["requestCount"] == 1

    def test_reset_usage(self, test_client):
        """Test POST /api/v1/usage/reset."""
        from xml_pipeline.llm import get_usage_tracker

        # Record something first so the reset has data to clear.
        tracker = get_usage_tracker()
        tracker.record(
            thread_id="t1",
            agent_id="a1",
            model="test",
            provider="test",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )

        response = test_client.post("/api/v1/usage/reset")
        assert response.status_code == 200

        data = response.json()
        assert data["success"] is True

        # Verify usage was reset
        response = test_client.get("/api/v1/usage")
        assert response.status_code == 200
        data = response.json()
        assert data["usage"]["totals"]["totalTokens"] == 0
        assert data["usage"]["totals"]["requestCount"] == 0

    def test_usage_cost_estimation(self, test_client):
        """Test that usage includes cost estimates."""
        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker

        reset_usage_tracker()
        tracker = get_usage_tracker()

        # Use known model with pricing
        tracker.record(
            thread_id="t1",
            agent_id="a1",
            model="grok-4.1",  # $3/M prompt, $15/M completion
            provider="xai",
            prompt_tokens=1_000_000,  # $3
            completion_tokens=1_000_000,  # $15
            latency_ms=100.0,
        )

        response = test_client.get("/api/v1/usage")
        assert response.status_code == 200
        data = response.json()

        # Cost should be approximately $18; compare with a tolerance because
        # the value is a computed float.
        assert data["usage"]["totals"]["totalCost"] == pytest.approx(18.0)