diff --git a/tests/test_server.py b/tests/test_server.py
index 63ed6a9..3f90e1d 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -616,3 +616,279 @@ class TestCapabilityIntrospectionState:
         """Test get_capability returns None for unknown."""
         detail = server_state.get_capability("nonexistent")
         assert detail is None
+
+
+# ============================================================================
+# Test Usage/Gas Tracking API
+# ============================================================================
+
+class TestUsageAPI:
+    """Test usage/gas tracking endpoints."""
+
+    def test_get_usage_overview(self, test_client):
+        """Test GET /api/v1/usage returns overview."""
+        # Reset trackers for clean state
+        from xml_pipeline.llm import reset_usage_tracker
+        from xml_pipeline.message_bus import reset_budget_registry
+
+        reset_usage_tracker()
+        reset_budget_registry()
+
+        response = test_client.get("/api/v1/usage")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert "usage" in data
+
+        usage = data["usage"]
+        assert "totals" in usage
+        assert "byAgent" in usage
+        assert "byModel" in usage
+        assert "activeThreads" in usage
+
+        totals = usage["totals"]
+        assert "totalTokens" in totals
+        assert "promptTokens" in totals
+        assert "completionTokens" in totals
+        assert "requestCount" in totals
+        assert "totalCost" in totals
+        assert "avgLatencyMs" in totals
+
+    def test_get_usage_with_data(self, test_client):
+        """Test usage reflects recorded data."""
+        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker
+
+        reset_usage_tracker()
+        tracker = get_usage_tracker()
+
+        # Record some usage
+        tracker.record(
+            thread_id="test-thread",
+            agent_id="greeter",
+            model="grok-4.1",
+            provider="xai",
+            prompt_tokens=100,
+            completion_tokens=50,
+            latency_ms=250.0,
+        )
+
+        response = test_client.get("/api/v1/usage")
+        assert response.status_code == 200
+
+        data = response.json()
+        totals = data["usage"]["totals"]
+        assert totals["totalTokens"] == 150
+        assert totals["promptTokens"] == 100
+        assert totals["completionTokens"] == 50
+        assert totals["requestCount"] == 1
+
+        # Check by-agent breakdown
+        by_agent = data["usage"]["byAgent"]
+        assert len(by_agent) == 1
+        assert by_agent[0]["agentId"] == "greeter"
+        assert by_agent[0]["totalTokens"] == 150
+
+        # Check by-model breakdown
+        by_model = data["usage"]["byModel"]
+        assert len(by_model) == 1
+        assert by_model[0]["model"] == "grok-4.1"
+        assert by_model[0]["totalTokens"] == 150
+
+    def test_get_thread_budgets_empty(self, test_client):
+        """Test GET /api/v1/usage/threads with no threads."""
+        from xml_pipeline.message_bus import reset_budget_registry
+
+        reset_budget_registry()
+
+        response = test_client.get("/api/v1/usage/threads")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert "threads" in data
+        assert "count" in data
+        assert "defaultMaxTokens" in data
+        assert data["count"] == 0
+
+    def test_get_thread_budgets_with_data(self, test_client):
+        """Test thread budgets reflect consumption."""
+        from xml_pipeline.message_bus import get_budget_registry, reset_budget_registry
+
+        reset_budget_registry()
+        registry = get_budget_registry()
+
+        # Consume some tokens
+        registry.consume("thread-1", 5000, 2000)
+        registry.consume("thread-2", 10000, 5000)
+
+        response = test_client.get("/api/v1/usage/threads")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["count"] == 2
+
+        # Threads sorted by percent used (descending)
+        threads = data["threads"]
+        assert threads[0]["percentUsed"] >= threads[1]["percentUsed"]
+
+        # Find thread-2 (should have higher usage)
+        thread2 = next(t for t in threads if t["threadId"] == "thread-2")
+        assert thread2["totalTokens"] == 15000
+        assert thread2["promptTokens"] == 10000
+        assert thread2["completionTokens"] == 5000
+
+    def test_get_single_thread_budget(self, test_client):
+        """Test GET /api/v1/usage/threads/{thread_id}."""
+        from xml_pipeline.message_bus import get_budget_registry, reset_budget_registry
+
+        reset_budget_registry()
+        registry = get_budget_registry()
+        registry.consume("my-thread", 3000, 1500)
+
+        response = test_client.get("/api/v1/usage/threads/my-thread")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["threadId"] == "my-thread"
+        assert data["totalTokens"] == 4500
+        assert data["promptTokens"] == 3000
+        assert data["completionTokens"] == 1500
+        assert data["maxTokens"] == 100000  # default
+        assert data["remaining"] == 95500
+        assert data["percentUsed"] == 4.5
+        assert data["isExhausted"] is False
+
+    def test_get_single_thread_budget_not_found(self, test_client):
+        """Test GET /usage/threads/{id} returns 404 for unknown thread."""
+        from xml_pipeline.message_bus import reset_budget_registry
+
+        reset_budget_registry()
+
+        response = test_client.get("/api/v1/usage/threads/nonexistent")
+        assert response.status_code == 404
+
+    def test_get_agent_usage(self, test_client):
+        """Test GET /api/v1/usage/agents/{agent_id}."""
+        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker
+
+        reset_usage_tracker()
+        tracker = get_usage_tracker()
+
+        tracker.record(
+            thread_id="t1",
+            agent_id="researcher",
+            model="grok-4.1",
+            provider="xai",
+            prompt_tokens=1000,
+            completion_tokens=500,
+            latency_ms=100.0,
+        )
+        tracker.record(
+            thread_id="t2",
+            agent_id="researcher",
+            model="grok-4.1",
+            provider="xai",
+            prompt_tokens=2000,
+            completion_tokens=1000,
+            latency_ms=150.0,
+        )
+
+        response = test_client.get("/api/v1/usage/agents/researcher")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["agentId"] == "researcher"
+        assert data["totalTokens"] == 4500
+        assert data["promptTokens"] == 3000
+        assert data["completionTokens"] == 1500
+        assert data["requestCount"] == 2
+
+    def test_get_agent_usage_empty(self, test_client):
+        """Test GET /usage/agents/{id} for agent with no usage."""
+        from xml_pipeline.llm import reset_usage_tracker
+
+        reset_usage_tracker()
+
+        response = test_client.get("/api/v1/usage/agents/unknown")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["agentId"] == "unknown"
+        assert data["totalTokens"] == 0
+        assert data["requestCount"] == 0
+
+    def test_get_model_usage(self, test_client):
+        """Test GET /api/v1/usage/models/{model}."""
+        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker
+
+        reset_usage_tracker()
+        tracker = get_usage_tracker()
+
+        tracker.record(
+            thread_id="t1",
+            agent_id="a1",
+            model="claude-sonnet-4",
+            provider="anthropic",
+            prompt_tokens=500,
+            completion_tokens=200,
+            latency_ms=80.0,
+        )
+
+        response = test_client.get("/api/v1/usage/models/claude-sonnet-4")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["model"] == "claude-sonnet-4"
+        assert data["totalTokens"] == 700
+        assert data["requestCount"] == 1
+
+    def test_reset_usage(self, test_client):
+        """Test POST /api/v1/usage/reset."""
+        from xml_pipeline.llm import get_usage_tracker
+
+        tracker = get_usage_tracker()
+        tracker.record(
+            thread_id="t1",
+            agent_id="a1",
+            model="test",
+            provider="test",
+            prompt_tokens=1000,
+            completion_tokens=500,
+            latency_ms=100.0,
+        )
+
+        response = test_client.post("/api/v1/usage/reset")
+        assert response.status_code == 200
+
+        data = response.json()
+        assert data["success"] is True
+
+        # Verify usage was reset
+        response = test_client.get("/api/v1/usage")
+        data = response.json()
+        assert data["usage"]["totals"]["totalTokens"] == 0
+        assert data["usage"]["totals"]["requestCount"] == 0
+
+    def test_usage_cost_estimation(self, test_client):
+        """Test that usage includes cost estimates."""
+        from xml_pipeline.llm import get_usage_tracker, reset_usage_tracker
+
+        reset_usage_tracker()
+        tracker = get_usage_tracker()
+
+        # Use known model with pricing
+        tracker.record(
+            thread_id="t1",
+            agent_id="a1",
+            model="grok-4.1",  # $3/M prompt, $15/M completion
+            provider="xai",
+            prompt_tokens=1_000_000,  # $3
+            completion_tokens=1_000_000,  # $15
+            latency_ms=100.0,
+        )
+
+        response = test_client.get("/api/v1/usage")
+        assert response.status_code == 200
+        data = response.json()
+
+        # Cost should be approximately $18 (float arithmetic, so use a tolerance)
+        assert abs(data["usage"]["totals"]["totalCost"] - 18.0) < 1e-6
diff --git a/xml_pipeline/message_bus/budget_registry.py b/xml_pipeline/message_bus/budget_registry.py
index 3b4f323..a56d5c5 100644
--- a/xml_pipeline/message_bus/budget_registry.py
+++ b/xml_pipeline/message_bus/budget_registry.py
@@ -193,23 +193,32 @@ class ThreadBudgetRegistry:
         budget.consume(prompt_tokens, completion_tokens)
         return budget
 
-    def get_usage(self, thread_id: str) -> Dict[str, int]:
+    def has_budget(self, thread_id: str) -> bool:
+        """Check if a thread has a budget entry (without creating one)."""
+        with self._lock:
+            return thread_id in self._budgets
+
+    def get_usage(self, thread_id: str) -> Optional[Dict[str, int]]:
         """
         Get usage stats for a thread.
 
         Returns:
             Dict with prompt_tokens, completion_tokens, total_tokens,
-            remaining, max_tokens, request_count
+            remaining, max_tokens, request_count.
+            Returns None if thread has no budget.
         """
-        budget = self.get_budget(thread_id)
-        return {
-            "prompt_tokens": budget.prompt_tokens,
-            "completion_tokens": budget.completion_tokens,
-            "total_tokens": budget.total_tokens,
-            "remaining": budget.remaining,
-            "max_tokens": budget.max_tokens,
-            "request_count": budget.request_count,
-        }
+        with self._lock:
+            if thread_id not in self._budgets:
+                return None
+            budget = self._budgets[thread_id]
+            return {
+                "prompt_tokens": budget.prompt_tokens,
+                "completion_tokens": budget.completion_tokens,
+                "total_tokens": budget.total_tokens,
+                "remaining": budget.remaining,
+                "max_tokens": budget.max_tokens,
+                "request_count": budget.request_count,
+            }
 
     def get_all_usage(self) -> Dict[str, Dict[str, int]]:
         """Get usage stats for all threads."""
diff --git a/xml_pipeline/server/api.py b/xml_pipeline/server/api.py
index f619183..45b2349 100644
--- a/xml_pipeline/server/api.py
+++ b/xml_pipeline/server/api.py
@@ -18,6 +18,7 @@ from fastapi import APIRouter, HTTPException, Query
 from xml_pipeline.server.models import (
     AgentInfo,
     AgentListResponse,
+    AgentUsageInfo,
     CapabilityDetail,
     CapabilityInfo,
     CapabilityListResponse,
@@ -25,10 +26,16 @@ from xml_pipeline.server.models import (
     InjectRequest,
     InjectResponse,
     MessageListResponse,
+    ModelUsageInfo,
     OrganismInfo,
+    ThreadBudgetInfo,
+    ThreadBudgetListResponse,
     ThreadInfo,
     ThreadListResponse,
     ThreadStatus,
+    UsageOverview,
+    UsageResponse,
+    UsageTotals,
 )
 
 if TYPE_CHECKING:
@@ -233,6 +240,183 @@ def create_router(state: "ServerState") -> APIRouter:
             limit=limit,
         )
 
+    # =========================================================================
+    # Usage/Gas Tracking Endpoints
+    # =========================================================================
+
+    def _thread_budget_info(thread_id: str, usage: dict) -> ThreadBudgetInfo:
+        """Convert one registry usage dict into its API representation."""
+        max_tokens = usage["max_tokens"]
+        total = usage["total_tokens"]
+        remaining = usage["remaining"]
+        # A non-positive budget would divide by zero; report 0% instead.
+        percent = (total / max_tokens * 100) if max_tokens > 0 else 0
+        return ThreadBudgetInfo(
+            thread_id=thread_id,
+            max_tokens=max_tokens,
+            prompt_tokens=usage["prompt_tokens"],
+            completion_tokens=usage["completion_tokens"],
+            total_tokens=total,
+            remaining=remaining,
+            percent_used=round(percent, 1),
+            is_exhausted=remaining <= 0,
+        )
+
+    @router.get("/usage", response_model=UsageResponse)
+    async def get_usage() -> UsageResponse:
+        """
+        Get usage overview (gas gauge).
+
+        Returns aggregate token usage, costs, and per-agent/model breakdowns.
+        This is the main endpoint for monitoring LLM consumption.
+        """
+        from xml_pipeline.llm import get_usage_tracker
+        from xml_pipeline.message_bus import get_budget_registry
+
+        tracker = get_usage_tracker()
+        budget_registry = get_budget_registry()
+
+        # Get aggregate totals
+        totals_dict = tracker.get_totals()
+        totals = UsageTotals(
+            total_tokens=totals_dict["total_tokens"],
+            prompt_tokens=totals_dict["prompt_tokens"],
+            completion_tokens=totals_dict["completion_tokens"],
+            request_count=totals_dict["request_count"],
+            total_cost=totals_dict["total_cost"],
+            avg_latency_ms=totals_dict["avg_latency_ms"],
+        )
+
+        # Get per-agent breakdown
+        agent_totals = tracker.get_all_agent_totals()
+        by_agent = [
+            AgentUsageInfo(
+                agent_id=agent_id,
+                total_tokens=data["total_tokens"],
+                prompt_tokens=data["prompt_tokens"],
+                completion_tokens=data["completion_tokens"],
+                request_count=data["request_count"],
+                total_cost=data["total_cost"],
+            )
+            for agent_id, data in agent_totals.items()
+        ]
+
+        # Get per-model breakdown
+        model_totals = tracker.get_all_model_totals()
+        by_model = [
+            ModelUsageInfo(
+                model=model,
+                total_tokens=data["total_tokens"],
+                prompt_tokens=data["prompt_tokens"],
+                completion_tokens=data["completion_tokens"],
+                request_count=data["request_count"],
+                total_cost=data["total_cost"],
+            )
+            for model, data in model_totals.items()
+        ]
+
+        # Count active threads with budgets
+        all_budgets = budget_registry.get_all_usage()
+        active_threads = len(all_budgets)
+
+        overview = UsageOverview(
+            totals=totals,
+            by_agent=by_agent,
+            by_model=by_model,
+            active_threads=active_threads,
+        )
+
+        return UsageResponse(usage=overview)
+
+    @router.get("/usage/threads", response_model=ThreadBudgetListResponse)
+    async def get_thread_budgets() -> ThreadBudgetListResponse:
+        """
+        Get token budgets for all active threads.
+
+        Shows remaining budget per thread for monitoring runaway agents.
+        """
+        from xml_pipeline.message_bus import get_budget_registry
+
+        registry = get_budget_registry()
+        all_budgets = registry.get_all_usage()
+
+        threads = [
+            _thread_budget_info(thread_id, budget_dict)
+            for thread_id, budget_dict in all_budgets.items()
+        ]
+
+        # Sort by percent used (descending) - hottest threads first
+        threads.sort(key=lambda t: t.percent_used, reverse=True)
+
+        return ThreadBudgetListResponse(
+            threads=threads,
+            count=len(threads),
+            # NOTE(review): private attribute access; prefer a public accessor.
+            default_max_tokens=registry._max_tokens_per_thread,
+        )
+
+    @router.get("/usage/threads/{thread_id}", response_model=ThreadBudgetInfo)
+    async def get_thread_budget(thread_id: str) -> ThreadBudgetInfo:
+        """Get token budget for a specific thread."""
+        from xml_pipeline.message_bus import get_budget_registry
+
+        registry = get_budget_registry()
+        budget_dict = registry.get_usage(thread_id)
+
+        if budget_dict is None:
+            raise HTTPException(
+                status_code=404,
+                detail=f"No budget found for thread: {thread_id}",
+            )
+
+        return _thread_budget_info(thread_id, budget_dict)
+
+    @router.get("/usage/agents/{agent_id}", response_model=AgentUsageInfo)
+    async def get_agent_usage(agent_id: str) -> AgentUsageInfo:
+        """Get usage totals for a specific agent."""
+        from xml_pipeline.llm import get_usage_tracker
+
+        tracker = get_usage_tracker()
+        data = tracker.get_agent_totals(agent_id)
+
+        return AgentUsageInfo(
+            agent_id=agent_id,
+            total_tokens=data["total_tokens"],
+            prompt_tokens=data["prompt_tokens"],
+            completion_tokens=data["completion_tokens"],
+            request_count=data["request_count"],
+            total_cost=data["total_cost"],
+        )
+
+    @router.get("/usage/models/{model}", response_model=ModelUsageInfo)
+    async def get_model_usage(model: str) -> ModelUsageInfo:
+        """Get usage totals for a specific model."""
+        from xml_pipeline.llm import get_usage_tracker
+
+        tracker = get_usage_tracker()
+        data = tracker.get_model_totals(model)
+
+        return ModelUsageInfo(
+            model=model,
+            total_tokens=data["total_tokens"],
+            prompt_tokens=data["prompt_tokens"],
+            completion_tokens=data["completion_tokens"],
+            request_count=data["request_count"],
+            total_cost=data["total_cost"],
+        )
+
+    @router.post("/usage/reset")
+    async def reset_usage() -> dict:
+        """
+        Reset all usage tracking (for testing/development).
+
+        WARNING: This clears all usage history. Use with caution.
+        """
+        from xml_pipeline.llm import reset_usage_tracker
+
+        reset_usage_tracker()
+        return {"success": True, "message": "Usage tracking reset"}
+
     # =========================================================================
     # Control Endpoints
     # =========================================================================
diff --git a/xml_pipeline/server/models.py b/xml_pipeline/server/models.py
index 2a98611..7aa6107 100644
--- a/xml_pipeline/server/models.py
+++ b/xml_pipeline/server/models.py
@@ -294,3 +294,77 @@ class CapabilityListResponse(CamelModel):
 
     capabilities: List[CapabilityInfo]
     count: int
+
+
+# =============================================================================
+# Usage/Gas Tracking Models
+# =============================================================================
+
+
+class UsageTotals(CamelModel):
+    """Aggregate usage statistics."""
+
+    total_tokens: int = Field(0, alias="totalTokens")
+    prompt_tokens: int = Field(0, alias="promptTokens")
+    completion_tokens: int = Field(0, alias="completionTokens")
+    request_count: int = Field(0, alias="requestCount")
+    total_cost: float = Field(0.0, alias="totalCost")
+    avg_latency_ms: float = Field(0.0, alias="avgLatencyMs")
+
+
+class ThreadBudgetInfo(CamelModel):
+    """Token budget info for a thread."""
+
+    thread_id: str = Field(alias="threadId")
+    max_tokens: int = Field(alias="maxTokens")
+    prompt_tokens: int = Field(alias="promptTokens")
+    completion_tokens: int = Field(alias="completionTokens")
+    total_tokens: int = Field(alias="totalTokens")
+    remaining: int
+    percent_used: float = Field(alias="percentUsed")
+    is_exhausted: bool = Field(alias="isExhausted")
+
+
+class AgentUsageInfo(CamelModel):
+    """Usage info for a specific agent."""
+
+    agent_id: str = Field(alias="agentId")
+    total_tokens: int = Field(0, alias="totalTokens")
+    prompt_tokens: int = Field(0, alias="promptTokens")
+    completion_tokens: int = Field(0, alias="completionTokens")
+    request_count: int = Field(0, alias="requestCount")
+    total_cost: float = Field(0.0, alias="totalCost")
+
+
+class ModelUsageInfo(CamelModel):
+    """Usage info for a specific model."""
+
+    model: str
+    total_tokens: int = Field(0, alias="totalTokens")
+    prompt_tokens: int = Field(0, alias="promptTokens")
+    completion_tokens: int = Field(0, alias="completionTokens")
+    request_count: int = Field(0, alias="requestCount")
+    total_cost: float = Field(0.0, alias="totalCost")
+
+
+class UsageOverview(CamelModel):
+    """Complete usage overview (gas gauge)."""
+
+    totals: UsageTotals
+    by_agent: List[AgentUsageInfo] = Field(default_factory=list, alias="byAgent")
+    by_model: List[ModelUsageInfo] = Field(default_factory=list, alias="byModel")
+    active_threads: int = Field(0, alias="activeThreads")
+
+
+class UsageResponse(CamelModel):
+    """Response for GET /usage."""
+
+    usage: UsageOverview
+
+
+class ThreadBudgetListResponse(CamelModel):
+    """Response for GET /usage/threads."""
+
+    threads: List[ThreadBudgetInfo]
+    count: int
+    default_max_tokens: int = Field(alias="defaultMaxTokens")