Add usage persistence for billing (SQLite)
- UsageStore with async SQLite persistence via aiosqlite
- Background batch writer for non-blocking event persistence
- Auto-subscribes to UsageTracker for transparent capture
- Query methods: query(), get_billing_summary(), get_daily_usage()
- REST API endpoints: /usage/history, /usage/billing, /usage/daily
- Filtering by org_id, agent_id, model, time range
- 18 new tests for persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
e6697f0ea2
commit
d0d78a9f70
5 changed files with 1227 additions and 1 deletions
398
tests/test_usage_store.py
Normal file
398
tests/test_usage_store.py
Normal file
|
|
@ -0,0 +1,398 @@
|
|||
"""
|
||||
Tests for usage persistence layer.
|
||||
|
||||
Tests UsageStore's SQLite persistence, querying, and billing summaries.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker, reset_usage_tracker
|
||||
from xml_pipeline.llm.usage_store import UsageStore, reset_usage_store
|
||||
|
||||
|
||||
@pytest.fixture
def temp_db():
    """Yield the path to a fresh temporary SQLite database file.

    The file is created with ``delete=False`` so it can be reopened by path
    (aiosqlite connects by path; Windows cannot reopen a still-open
    NamedTemporaryFile). That means nothing cleans it up automatically --
    the previous "Cleanup happens automatically" comment was wrong and the
    fixture leaked one .db file per test. We remove it explicitly here.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        path = f.name
    try:
        yield path
    finally:
        # Best-effort removal; the file may already be gone.
        try:
            Path(path).unlink()
        except OSError:
            pass
|
||||
|
||||
@pytest.fixture
async def store(temp_db):
    """Provide an initialized UsageStore backed by the temp database.

    Global tracker/store singletons are reset first so each test observes
    only its own events. The store is closed in a ``finally`` so teardown
    runs even when the test body raises (the original skipped ``close()``
    on failure, leaking the background writer task).
    """
    reset_usage_tracker()
    reset_usage_store()

    usage_store = UsageStore(db_path=temp_db)
    await usage_store.initialize()
    try:
        yield usage_store
    finally:
        await usage_store.close()
||||
|
||||
|
||||
class TestUsageStoreBasics:
    """Basic store operations: initialization, empty queries, counting."""

    async def test_initialize_creates_table(self, temp_db):
        """Store initialization creates the usage_events table."""
        reset_usage_store()
        s = UsageStore(db_path=temp_db)
        await s.initialize()

        # An empty query succeeding proves the table exists.
        assert await s.query(limit=10) == []

        await s.close()

    async def test_store_initializes_once(self, store):
        """Multiple initialize calls are idempotent."""
        for _ in range(2):
            # Must not raise on repeated calls.
            await store.initialize()

    async def test_empty_query(self, store):
        """Query on empty store returns empty list."""
        assert await store.query() == []

    async def test_empty_count(self, store):
        """Count on empty store returns 0."""
        assert await store.count() == 0
||||
|
||||
|
||||
class TestEventPersistence:
    """Events recorded in the tracker are persisted via the subscriber."""

    async def test_event_persisted_via_tracker(self, store):
        """Events recorded in tracker are persisted to store."""
        get_usage_tracker().record(
            thread_id="test-thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250.5,
            metadata={"org_id": "org-123"},
        )

        # The background writer flushes roughly once a second.
        await asyncio.sleep(1.5)

        events = await store.query()
        assert len(events) == 1

        expected = {
            "thread_id": "test-thread-1",
            "agent_id": "greeter",
            "model": "grok-4.1",
            "provider": "xai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "latency_ms": 250.5,
        }
        for field, value in expected.items():
            assert events[0][field] == value

    async def test_multiple_events_persisted(self, store):
        """Multiple events are all persisted."""
        tracker = get_usage_tracker()
        for i in range(5):
            tracker.record(
                thread_id=f"thread-{i}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100 * (i + 1),
                completion_tokens=50 * (i + 1),
                latency_ms=100.0,
            )

        await asyncio.sleep(1.5)

        assert len(await store.query()) == 5

    async def test_event_with_cost_estimate(self, store):
        """Estimated cost is persisted."""
        # grok-4.1 has pricing in MODEL_COSTS, so a cost should be computed.
        get_usage_tracker().record(
            thread_id="cost-thread",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )

        await asyncio.sleep(1.5)

        events = await store.query()
        assert len(events) == 1
        cost = events[0]["estimated_cost"]
        assert cost is not None
        assert cost > 0
||||
|
||||
|
||||
class TestQueryFiltering:
    """Query filtering options: agent, model, org, pagination."""

    @staticmethod
    def _record(**overrides):
        """Record one event with sensible defaults, overridden per test."""
        fields = dict(
            thread_id="t",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=100.0,
        )
        fields.update(overrides)
        get_usage_tracker().record(**fields)

    async def test_filter_by_agent(self, store):
        """Filter events by agent_id."""
        self._record(thread_id="t1", agent_id="greeter")
        self._record(thread_id="t2", agent_id="shouter")

        await asyncio.sleep(1.5)

        events = await store.query(agent_id="greeter")
        assert len(events) == 1
        assert events[0]["agent_id"] == "greeter"

    async def test_filter_by_model(self, store):
        """Filter events by model."""
        self._record(thread_id="t1", model="grok-4.1")
        self._record(thread_id="t2", model="claude-sonnet-4", provider="anthropic")

        await asyncio.sleep(1.5)

        events = await store.query(model="grok-4.1")
        assert len(events) == 1
        assert events[0]["model"] == "grok-4.1"

    async def test_filter_by_org_id(self, store):
        """Filter events by org_id in metadata."""
        self._record(thread_id="t1", metadata={"org_id": "org-A"})
        self._record(thread_id="t2", metadata={"org_id": "org-B"})

        await asyncio.sleep(1.5)

        events = await store.query(org_id="org-A")
        assert len(events) == 1
        assert events[0]["thread_id"] == "t1"

    async def test_pagination(self, store):
        """Test limit and offset for pagination."""
        for i in range(10):
            self._record(thread_id=f"t{i:02d}")

        await asyncio.sleep(1.5)

        page1 = await store.query(limit=3, offset=0)
        page2 = await store.query(limit=3, offset=3)
        assert len(page1) == 3
        assert len(page2) == 3
        # The offset moved past page1, so the pages lead with different events.
        assert page1[0]["thread_id"] != page2[0]["thread_id"]
||||
|
||||
|
||||
class TestBillingSummary:
    """Billing summary aggregation: totals, breakdowns, org filtering."""

    @staticmethod
    def _record(**overrides):
        """Record one event with defaults tuned for billing math."""
        fields = dict(
            thread_id="t",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )
        fields.update(overrides)
        get_usage_tracker().record(**fields)

    async def test_billing_totals(self, store):
        """Billing summary calculates correct totals."""
        self._record(thread_id="t1", prompt_tokens=1000, completion_tokens=500)
        self._record(thread_id="t2", prompt_tokens=2000, completion_tokens=1000)

        await asyncio.sleep(1.5)

        summary = await store.get_billing_summary()
        assert summary.total_tokens == 4500  # 1500 + 3000
        assert summary.prompt_tokens == 3000
        assert summary.completion_tokens == 1500
        assert summary.request_count == 2

    async def test_billing_by_model(self, store):
        """Billing summary includes breakdown by model."""
        self._record(thread_id="t1", model="grok-4.1")
        self._record(
            thread_id="t2",
            model="claude-sonnet-4",
            provider="anthropic",
            prompt_tokens=2000,
            completion_tokens=1000,
        )

        await asyncio.sleep(1.5)

        summary = await store.get_billing_summary()
        for model_name in ("grok-4.1", "claude-sonnet-4"):
            assert model_name in summary.by_model
        assert summary.by_model["grok-4.1"]["total_tokens"] == 1500
        assert summary.by_model["claude-sonnet-4"]["total_tokens"] == 3000

    async def test_billing_by_agent(self, store):
        """Billing summary includes breakdown by agent."""
        self._record(thread_id="t1", agent_id="greeter")
        self._record(
            thread_id="t2",
            agent_id="shouter",
            prompt_tokens=500,
            completion_tokens=250,
        )

        await asyncio.sleep(1.5)

        summary = await store.get_billing_summary()
        assert "greeter" in summary.by_agent
        assert "shouter" in summary.by_agent
        assert summary.by_agent["greeter"]["total_tokens"] == 1500
        assert summary.by_agent["shouter"]["total_tokens"] == 750

    async def test_billing_filtered_by_org(self, store):
        """Billing summary can be filtered by org_id."""
        self._record(thread_id="t1", metadata={"org_id": "org-A"})
        self._record(
            thread_id="t2",
            prompt_tokens=2000,
            completion_tokens=1000,
            metadata={"org_id": "org-B"},
        )

        await asyncio.sleep(1.5)

        summary_a = await store.get_billing_summary(org_id="org-A")
        summary_b = await store.get_billing_summary(org_id="org-B")
        assert summary_a.total_tokens == 1500
        assert summary_b.total_tokens == 3000
||||
|
||||
|
||||
class TestDailyUsage:
    """Daily usage aggregation."""

    async def test_daily_aggregation(self, store):
        """Daily usage aggregates by date."""
        tracker = get_usage_tracker()
        # Timestamps are auto-generated, so every event lands on today.
        for i in range(3):
            tracker.record(
                thread_id=f"t{i}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100,
                completion_tokens=50,
                latency_ms=100.0,
            )

        await asyncio.sleep(1.5)

        days = await store.get_daily_usage()
        assert len(days) == 1  # all on the same day
        assert days[0]["total_tokens"] == 450
        assert days[0]["request_count"] == 3

    async def test_daily_returns_date(self, store):
        """Daily usage includes date field."""
        get_usage_tracker().record(
            thread_id="t1",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=100.0,
        )

        await asyncio.sleep(1.5)

        days = await store.get_daily_usage()
        assert len(days) == 1
        # SQLite DATE() yields YYYY-MM-DD: ten characters, two dashes.
        day = days[0]["date"]
        assert len(day) == 10
        assert day.count("-") == 2
|
||||
|
||||
class TestStoreLifecycle:
    """Store lifecycle management."""

    async def test_close_flushes_pending(self, temp_db):
        """Closing store flushes pending writes."""
        reset_usage_tracker()
        reset_usage_store()

        first = UsageStore(db_path=temp_db)
        await first.initialize()

        get_usage_tracker().record(
            thread_id="t1",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=100.0,
        )

        # Close without waiting out the flush interval; close() must drain.
        await first.close()

        # A second store over the same file proves the row reached disk.
        second = UsageStore(db_path=temp_db)
        await second.initialize()
        assert len(await second.query()) == 1
        await second.close()
||||
|
|
@ -30,6 +30,20 @@ Usage Tracking:
|
|||
|
||||
# Query totals
|
||||
totals = tracker.get_totals()
|
||||
|
||||
Usage Persistence (for billing):
|
||||
from xml_pipeline.llm import get_usage_store
|
||||
|
||||
store = await get_usage_store()
|
||||
|
||||
# Query historical usage
|
||||
events = await store.query(
|
||||
start_time="2025-01-01T00:00:00Z",
|
||||
org_id="org-123",
|
||||
)
|
||||
|
||||
# Get billing summary
|
||||
summary = await store.get_billing_summary(org_id="org-123")
|
||||
"""
|
||||
|
||||
from xml_pipeline.llm.router import (
|
||||
|
|
@ -46,6 +60,13 @@ from xml_pipeline.llm.usage_tracker import (
|
|||
get_usage_tracker,
|
||||
reset_usage_tracker,
|
||||
)
|
||||
from xml_pipeline.llm.usage_store import (
|
||||
UsageStore,
|
||||
BillingSummary,
|
||||
get_usage_store,
|
||||
close_usage_store,
|
||||
reset_usage_store,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Router
|
||||
|
|
@ -58,9 +79,15 @@ __all__ = [
|
|||
"LLMRequest",
|
||||
"LLMResponse",
|
||||
"BackendError",
|
||||
# Usage tracking
|
||||
# Usage tracking (in-memory)
|
||||
"UsageTracker",
|
||||
"UsageEvent",
|
||||
"get_usage_tracker",
|
||||
"reset_usage_tracker",
|
||||
# Usage persistence (SQLite)
|
||||
"UsageStore",
|
||||
"BillingSummary",
|
||||
"get_usage_store",
|
||||
"close_usage_store",
|
||||
"reset_usage_store",
|
||||
]
|
||||
|
|
|
|||
599
xml_pipeline/llm/usage_store.py
Normal file
599
xml_pipeline/llm/usage_store.py
Normal file
|
|
@ -0,0 +1,599 @@
|
|||
"""
|
||||
Usage Store — Persistent storage for billing and usage analytics.
|
||||
|
||||
Stores UsageEvents to SQLite (default) or PostgreSQL for:
|
||||
- Historical billing queries
|
||||
- Usage analytics and reporting
|
||||
- Audit trails
|
||||
|
||||
The store auto-subscribes to UsageTracker on initialization,
|
||||
persisting all events transparently.
|
||||
|
||||
Example:
|
||||
from xml_pipeline.llm.usage_store import get_usage_store
|
||||
|
||||
store = await get_usage_store()
|
||||
|
||||
# Query historical usage
|
||||
events = await store.query(
|
||||
start_time="2025-01-01T00:00:00Z",
|
||||
end_time="2025-01-31T23:59:59Z",
|
||||
org_id="org-123",
|
||||
)
|
||||
|
||||
# Get billing summary
|
||||
summary = await store.get_billing_summary(
|
||||
org_id="org-123",
|
||||
start_time="2025-01-01T00:00:00Z",
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from dataclasses import dataclass, asdict
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
import aiosqlite
|
||||
HAS_AIOSQLITE = True
|
||||
except ImportError:
|
||||
HAS_AIOSQLITE = False
|
||||
|
||||
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default database path
|
||||
DEFAULT_DB_PATH = Path.home() / ".xml-pipeline" / "usage.db"
|
||||
|
||||
|
||||
@dataclass
class BillingSummary:
    """Aggregated billing data for a time period.

    Produced by UsageStore.get_billing_summary(); token counts, request
    count and cost are summed over every event matching the filters.
    """

    org_id: Optional[str]                  # None when not scoped to an org
    start_time: str                        # ISO 8601 lower bound ("" if unbounded)
    end_time: str                          # ISO 8601 upper bound
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    request_count: int
    total_cost: float                      # summed estimated_cost, rounded to 4 dp
    by_model: Dict[str, Dict[str, Any]]    # model name -> per-model totals
    by_agent: Dict[str, Dict[str, Any]]    # agent id -> per-agent totals
|
||||
|
||||
class UsageStore:
|
||||
"""
|
||||
Persistent storage for usage events.
|
||||
|
||||
Uses SQLite by default, with async I/O via aiosqlite.
|
||||
Automatically subscribes to UsageTracker to capture all events.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Optional[str] = None):
|
||||
"""
|
||||
Initialize the usage store.
|
||||
|
||||
Args:
|
||||
db_path: Path to SQLite database. Defaults to ~/.xml-pipeline/usage.db
|
||||
"""
|
||||
if not HAS_AIOSQLITE:
|
||||
raise ImportError(
|
||||
"aiosqlite is required for usage persistence. "
|
||||
"Install with: pip install aiosqlite"
|
||||
)
|
||||
|
||||
self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
|
||||
self._db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._initialized = False
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
# Queue for async persistence (events come from sync callback)
|
||||
self._queue: asyncio.Queue[UsageEvent] = asyncio.Queue()
|
||||
self._writer_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize database schema and start background writer."""
|
||||
async with self._init_lock:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
# Create tables
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
await db.execute("""
|
||||
CREATE TABLE IF NOT EXISTS usage_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
timestamp TEXT NOT NULL,
|
||||
thread_id TEXT NOT NULL,
|
||||
agent_id TEXT,
|
||||
model TEXT NOT NULL,
|
||||
provider TEXT NOT NULL,
|
||||
prompt_tokens INTEGER NOT NULL,
|
||||
completion_tokens INTEGER NOT NULL,
|
||||
total_tokens INTEGER NOT NULL,
|
||||
latency_ms REAL NOT NULL,
|
||||
estimated_cost REAL,
|
||||
metadata TEXT,
|
||||
|
||||
-- Denormalized for billing queries
|
||||
org_id TEXT,
|
||||
user_id TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
# Indexes for common queries
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_timestamp
|
||||
ON usage_events(timestamp)
|
||||
""")
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_org_id
|
||||
ON usage_events(org_id)
|
||||
""")
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_agent_id
|
||||
ON usage_events(agent_id)
|
||||
""")
|
||||
await db.execute("""
|
||||
CREATE INDEX IF NOT EXISTS idx_usage_model
|
||||
ON usage_events(model)
|
||||
""")
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Start background writer
|
||||
self._running = True
|
||||
self._writer_task = asyncio.create_task(self._writer_loop())
|
||||
|
||||
# Subscribe to usage tracker
|
||||
tracker = get_usage_tracker()
|
||||
tracker.subscribe(self._on_usage_event)
|
||||
|
||||
self._initialized = True
|
||||
logger.info(f"UsageStore initialized: {self._db_path}")
|
||||
|
||||
def _on_usage_event(self, event: UsageEvent) -> None:
|
||||
"""
|
||||
Callback for UsageTracker events.
|
||||
|
||||
This runs synchronously from the tracker, so we queue
|
||||
the event for async persistence.
|
||||
"""
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("Usage event queue full, dropping event")
|
||||
|
||||
async def _writer_loop(self) -> None:
|
||||
"""Background task that writes queued events to database."""
|
||||
batch: List[UsageEvent] = []
|
||||
batch_timeout = 1.0 # Flush every second or 100 events
|
||||
|
||||
while self._running or not self._queue.empty():
|
||||
try:
|
||||
# Collect batch
|
||||
try:
|
||||
event = await asyncio.wait_for(
|
||||
self._queue.get(),
|
||||
timeout=batch_timeout
|
||||
)
|
||||
batch.append(event)
|
||||
|
||||
# Drain queue up to batch size
|
||||
while len(batch) < 100:
|
||||
try:
|
||||
event = self._queue.get_nowait()
|
||||
batch.append(event)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
|
||||
# Write batch
|
||||
if batch:
|
||||
await self._write_batch(batch)
|
||||
batch = []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Usage writer error: {e}")
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
async def _write_batch(self, events: List[UsageEvent]) -> None:
|
||||
"""Write a batch of events to database."""
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
await db.executemany(
|
||||
"""
|
||||
INSERT INTO usage_events (
|
||||
timestamp, thread_id, agent_id, model, provider,
|
||||
prompt_tokens, completion_tokens, total_tokens,
|
||||
latency_ms, estimated_cost, metadata, org_id, user_id
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
[
|
||||
(
|
||||
e.timestamp,
|
||||
e.thread_id,
|
||||
e.agent_id,
|
||||
e.model,
|
||||
e.provider,
|
||||
e.prompt_tokens,
|
||||
e.completion_tokens,
|
||||
e.total_tokens,
|
||||
e.latency_ms,
|
||||
e.estimated_cost,
|
||||
json.dumps(e.metadata) if e.metadata else None,
|
||||
e.metadata.get("org_id") if e.metadata else None,
|
||||
e.metadata.get("user_id") if e.metadata else None,
|
||||
)
|
||||
for e in events
|
||||
]
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
logger.debug(f"Persisted {len(events)} usage events")
|
||||
|
||||
async def query(
|
||||
self,
|
||||
*,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
limit: int = 1000,
|
||||
offset: int = 0,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Query historical usage events.
|
||||
|
||||
Args:
|
||||
start_time: ISO 8601 timestamp (inclusive)
|
||||
end_time: ISO 8601 timestamp (inclusive)
|
||||
org_id: Filter by organization
|
||||
user_id: Filter by user
|
||||
agent_id: Filter by agent
|
||||
model: Filter by model
|
||||
limit: Max results (default 1000)
|
||||
offset: Pagination offset
|
||||
|
||||
Returns:
|
||||
List of usage event dicts
|
||||
"""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= ?")
|
||||
params.append(end_time)
|
||||
if org_id:
|
||||
conditions.append("org_id = ?")
|
||||
params.append(org_id)
|
||||
if user_id:
|
||||
conditions.append("user_id = ?")
|
||||
params.append(user_id)
|
||||
if agent_id:
|
||||
conditions.append("agent_id = ?")
|
||||
params.append(agent_id)
|
||||
if model:
|
||||
conditions.append("model = ?")
|
||||
params.append(model)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
cursor = await db.execute(
|
||||
f"""
|
||||
SELECT * FROM usage_events
|
||||
WHERE {where_clause}
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ? OFFSET ?
|
||||
""",
|
||||
params + [limit, offset]
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"id": row["id"],
|
||||
"timestamp": row["timestamp"],
|
||||
"thread_id": row["thread_id"],
|
||||
"agent_id": row["agent_id"],
|
||||
"model": row["model"],
|
||||
"provider": row["provider"],
|
||||
"prompt_tokens": row["prompt_tokens"],
|
||||
"completion_tokens": row["completion_tokens"],
|
||||
"total_tokens": row["total_tokens"],
|
||||
"latency_ms": row["latency_ms"],
|
||||
"estimated_cost": row["estimated_cost"],
|
||||
"metadata": json.loads(row["metadata"]) if row["metadata"] else {},
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def get_billing_summary(
|
||||
self,
|
||||
*,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
) -> BillingSummary:
|
||||
"""
|
||||
Get aggregated billing summary for a time period.
|
||||
|
||||
Args:
|
||||
start_time: ISO 8601 timestamp (inclusive)
|
||||
end_time: ISO 8601 timestamp (inclusive)
|
||||
org_id: Filter by organization
|
||||
|
||||
Returns:
|
||||
BillingSummary with totals and breakdowns
|
||||
"""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= ?")
|
||||
params.append(end_time)
|
||||
if org_id:
|
||||
conditions.append("org_id = ?")
|
||||
params.append(org_id)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
# Overall totals
|
||||
cursor = await db.execute(
|
||||
f"""
|
||||
SELECT
|
||||
COALESCE(SUM(total_tokens), 0) as total_tokens,
|
||||
COALESCE(SUM(prompt_tokens), 0) as prompt_tokens,
|
||||
COALESCE(SUM(completion_tokens), 0) as completion_tokens,
|
||||
COUNT(*) as request_count,
|
||||
COALESCE(SUM(estimated_cost), 0) as total_cost
|
||||
FROM usage_events
|
||||
WHERE {where_clause}
|
||||
""",
|
||||
params
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
|
||||
totals = {
|
||||
"total_tokens": row[0],
|
||||
"prompt_tokens": row[1],
|
||||
"completion_tokens": row[2],
|
||||
"request_count": row[3],
|
||||
"total_cost": round(row[4], 4),
|
||||
}
|
||||
|
||||
# By model
|
||||
cursor = await db.execute(
|
||||
f"""
|
||||
SELECT
|
||||
model,
|
||||
SUM(total_tokens) as total_tokens,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
COUNT(*) as request_count,
|
||||
SUM(estimated_cost) as total_cost
|
||||
FROM usage_events
|
||||
WHERE {where_clause}
|
||||
GROUP BY model
|
||||
""",
|
||||
params
|
||||
)
|
||||
by_model = {
|
||||
row[0]: {
|
||||
"total_tokens": row[1],
|
||||
"prompt_tokens": row[2],
|
||||
"completion_tokens": row[3],
|
||||
"request_count": row[4],
|
||||
"total_cost": round(row[5] or 0, 4),
|
||||
}
|
||||
for row in await cursor.fetchall()
|
||||
}
|
||||
|
||||
# By agent
|
||||
cursor = await db.execute(
|
||||
f"""
|
||||
SELECT
|
||||
agent_id,
|
||||
SUM(total_tokens) as total_tokens,
|
||||
SUM(prompt_tokens) as prompt_tokens,
|
||||
SUM(completion_tokens) as completion_tokens,
|
||||
COUNT(*) as request_count,
|
||||
SUM(estimated_cost) as total_cost
|
||||
FROM usage_events
|
||||
WHERE {where_clause} AND agent_id IS NOT NULL
|
||||
GROUP BY agent_id
|
||||
""",
|
||||
params
|
||||
)
|
||||
by_agent = {
|
||||
row[0]: {
|
||||
"total_tokens": row[1],
|
||||
"prompt_tokens": row[2],
|
||||
"completion_tokens": row[3],
|
||||
"request_count": row[4],
|
||||
"total_cost": round(row[5] or 0, 4),
|
||||
}
|
||||
for row in await cursor.fetchall()
|
||||
}
|
||||
|
||||
return BillingSummary(
|
||||
org_id=org_id,
|
||||
start_time=start_time or "",
|
||||
end_time=end_time or datetime.now(timezone.utc).isoformat(),
|
||||
total_tokens=totals["total_tokens"],
|
||||
prompt_tokens=totals["prompt_tokens"],
|
||||
completion_tokens=totals["completion_tokens"],
|
||||
request_count=totals["request_count"],
|
||||
total_cost=totals["total_cost"],
|
||||
by_model=by_model,
|
||||
by_agent=by_agent,
|
||||
)
|
||||
|
||||
async def get_daily_usage(
|
||||
self,
|
||||
*,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get usage aggregated by day for charting.
|
||||
|
||||
Returns:
|
||||
List of {date, total_tokens, request_count, total_cost}
|
||||
"""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= ?")
|
||||
params.append(end_time)
|
||||
if org_id:
|
||||
conditions.append("org_id = ?")
|
||||
params.append(org_id)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
f"""
|
||||
SELECT
|
||||
DATE(timestamp) as date,
|
||||
SUM(total_tokens) as total_tokens,
|
||||
COUNT(*) as request_count,
|
||||
SUM(estimated_cost) as total_cost
|
||||
FROM usage_events
|
||||
WHERE {where_clause}
|
||||
GROUP BY DATE(timestamp)
|
||||
ORDER BY date
|
||||
""",
|
||||
params
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
|
||||
return [
|
||||
{
|
||||
"date": row[0],
|
||||
"total_tokens": row[1],
|
||||
"request_count": row[2],
|
||||
"total_cost": round(row[3] or 0, 4),
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
async def count(
|
||||
self,
|
||||
*,
|
||||
start_time: Optional[str] = None,
|
||||
end_time: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get total count of events matching criteria."""
|
||||
conditions = []
|
||||
params = []
|
||||
|
||||
if start_time:
|
||||
conditions.append("timestamp >= ?")
|
||||
params.append(start_time)
|
||||
if end_time:
|
||||
conditions.append("timestamp <= ?")
|
||||
params.append(end_time)
|
||||
if org_id:
|
||||
conditions.append("org_id = ?")
|
||||
params.append(org_id)
|
||||
|
||||
where_clause = " AND ".join(conditions) if conditions else "1=1"
|
||||
|
||||
async with aiosqlite.connect(str(self._db_path)) as db:
|
||||
cursor = await db.execute(
|
||||
f"SELECT COUNT(*) FROM usage_events WHERE {where_clause}",
|
||||
params
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
return row[0]
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Shutdown the store, flushing pending writes."""
|
||||
self._running = False
|
||||
|
||||
# Unsubscribe from tracker
|
||||
tracker = get_usage_tracker()
|
||||
tracker.unsubscribe(self._on_usage_event)
|
||||
|
||||
# Wait for writer to finish
|
||||
if self._writer_task:
|
||||
await self._writer_task
|
||||
self._writer_task = None
|
||||
|
||||
logger.info("UsageStore closed")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Global Instance
|
||||
# =============================================================================
|
||||
|
||||
# Process-wide singleton plus a lock guarding its construction.
_store: Optional[UsageStore] = None
_store_lock = threading.Lock()


async def get_usage_store(db_path: Optional[str] = None) -> UsageStore:
    """
    Get the global usage store, initializing if needed.

    Args:
        db_path: Optional path to SQLite database

    Returns:
        Initialized UsageStore
    """
    global _store

    # Double-checked construction: the threading lock only guards creating
    # the instance. Initialization is awaited outside it -- initialize()
    # is idempotent and serialized by its own asyncio lock.
    if _store is None:
        with _store_lock:
            if _store is None:
                _store = UsageStore(db_path)

    if not _store._initialized:
        await _store.initialize()

    return _store
||||
|
||||
|
||||
async def close_usage_store() -> None:
    """
    Close the global usage store, flushing pending writes.

    The global reference is detached while holding ``_store_lock`` — matching
    ``reset_usage_store`` and ``get_usage_store``, which the original did not,
    leaving a window where a concurrent ``get_usage_store`` could hand out a
    store that was mid-shutdown. The async ``close()`` itself runs outside the
    lock, because awaiting while holding a threading lock could stall the
    event loop for any other acquirer.
    """
    global _store
    with _store_lock:
        # Atomically take ownership of the instance and clear the global.
        store, _store = _store, None
    if store is not None:
        await store.close()
|
||||
|
||||
|
||||
def reset_usage_store() -> None:
    """Reset global store (for testing)."""
    # NOTE(review): this drops the reference without awaiting close(), so a
    # previously-initialized store's background writer task keeps running.
    # Acceptable in tests; production code should use close_usage_store().
    global _store
    with _store_lock:
        _store = None
|
||||
|
|
@ -19,9 +19,12 @@ from xml_pipeline.server.models import (
|
|||
AgentInfo,
|
||||
AgentListResponse,
|
||||
AgentUsageInfo,
|
||||
BillingSummaryResponse,
|
||||
CapabilityDetail,
|
||||
CapabilityInfo,
|
||||
CapabilityListResponse,
|
||||
DailyUsagePoint,
|
||||
DailyUsageResponse,
|
||||
ErrorResponse,
|
||||
InjectRequest,
|
||||
InjectResponse,
|
||||
|
|
@ -33,6 +36,8 @@ from xml_pipeline.server.models import (
|
|||
ThreadInfo,
|
||||
ThreadListResponse,
|
||||
ThreadStatus,
|
||||
UsageEventInfo,
|
||||
UsageHistoryResponse,
|
||||
UsageOverview,
|
||||
UsageResponse,
|
||||
UsageTotals,
|
||||
|
|
@ -425,6 +430,140 @@ def create_router(state: "ServerState") -> APIRouter:
|
|||
reset_usage_tracker()
|
||||
return {"success": True, "message": "Usage tracking reset"}
|
||||
|
||||
# =========================================================================
|
||||
# Usage History Endpoints (Persistent)
|
||||
# =========================================================================
|
||||
|
||||
@router.get("/usage/history", response_model=UsageHistoryResponse)
|
||||
async def get_usage_history(
|
||||
start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
|
||||
end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
|
||||
org_id: Optional[str] = Query(None, description="Filter by organization"),
|
||||
agent_id: Optional[str] = Query(None, description="Filter by agent"),
|
||||
model: Optional[str] = Query(None, description="Filter by model"),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
offset: int = Query(0, ge=0),
|
||||
) -> UsageHistoryResponse:
|
||||
"""
|
||||
Query historical usage events from persistent storage.
|
||||
|
||||
Use for billing reconciliation, audit trails, and detailed analytics.
|
||||
Events are stored in SQLite and persist across restarts.
|
||||
"""
|
||||
from xml_pipeline.llm.usage_store import get_usage_store
|
||||
|
||||
store = await get_usage_store()
|
||||
|
||||
events = await store.query(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
total = await store.count(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
return UsageHistoryResponse(
|
||||
events=[
|
||||
UsageEventInfo(
|
||||
id=e["id"],
|
||||
timestamp=e["timestamp"],
|
||||
thread_id=e["thread_id"],
|
||||
agent_id=e.get("agent_id"),
|
||||
model=e["model"],
|
||||
provider=e["provider"],
|
||||
prompt_tokens=e["prompt_tokens"],
|
||||
completion_tokens=e["completion_tokens"],
|
||||
total_tokens=e["total_tokens"],
|
||||
latency_ms=e["latency_ms"],
|
||||
estimated_cost=e.get("estimated_cost"),
|
||||
metadata=e.get("metadata", {}),
|
||||
)
|
||||
for e in events
|
||||
],
|
||||
count=len(events),
|
||||
total=total,
|
||||
offset=offset,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@router.get("/usage/billing", response_model=BillingSummaryResponse)
|
||||
async def get_billing_summary(
|
||||
start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
|
||||
end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
|
||||
org_id: Optional[str] = Query(None, description="Filter by organization"),
|
||||
) -> BillingSummaryResponse:
|
||||
"""
|
||||
Get aggregated billing summary for a time period.
|
||||
|
||||
Returns total tokens, costs, and breakdowns by model and agent.
|
||||
Use for invoicing and cost analysis.
|
||||
"""
|
||||
from xml_pipeline.llm.usage_store import get_usage_store
|
||||
|
||||
store = await get_usage_store()
|
||||
|
||||
summary = await store.get_billing_summary(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
return BillingSummaryResponse(
|
||||
org_id=summary.org_id,
|
||||
start_time=summary.start_time,
|
||||
end_time=summary.end_time,
|
||||
total_tokens=summary.total_tokens,
|
||||
prompt_tokens=summary.prompt_tokens,
|
||||
completion_tokens=summary.completion_tokens,
|
||||
request_count=summary.request_count,
|
||||
total_cost=summary.total_cost,
|
||||
by_model=summary.by_model,
|
||||
by_agent=summary.by_agent,
|
||||
)
|
||||
|
||||
@router.get("/usage/daily", response_model=DailyUsageResponse)
|
||||
async def get_daily_usage(
|
||||
start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
|
||||
end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
|
||||
org_id: Optional[str] = Query(None, description="Filter by organization"),
|
||||
) -> DailyUsageResponse:
|
||||
"""
|
||||
Get usage aggregated by day for charting.
|
||||
|
||||
Returns daily totals for tokens, requests, and costs.
|
||||
Useful for dashboards and trend analysis.
|
||||
"""
|
||||
from xml_pipeline.llm.usage_store import get_usage_store
|
||||
|
||||
store = await get_usage_store()
|
||||
|
||||
days = await store.get_daily_usage(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
org_id=org_id,
|
||||
)
|
||||
|
||||
return DailyUsageResponse(
|
||||
days=[
|
||||
DailyUsagePoint(
|
||||
date=d["date"],
|
||||
total_tokens=d["total_tokens"],
|
||||
request_count=d["request_count"],
|
||||
total_cost=d["total_cost"],
|
||||
)
|
||||
for d in days
|
||||
],
|
||||
count=len(days),
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# Control Endpoints
|
||||
# =========================================================================
|
||||
|
|
|
|||
|
|
@ -368,3 +368,66 @@ class ThreadBudgetListResponse(CamelModel):
|
|||
threads: List[ThreadBudgetInfo]
|
||||
count: int
|
||||
default_max_tokens: int = Field(alias="defaultMaxTokens")
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Usage History Models (Persistent)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class UsageEventInfo(CamelModel):
    """A single usage event from history, serialized with camelCase aliases."""

    # Persistent row identifier assigned by the store.
    id: int
    # Event time as a string (presumably ISO 8601 — confirm against the store schema).
    timestamp: str
    thread_id: str = Field(alias="threadId")
    # Agent that made the LLM call, when known.
    agent_id: Optional[str] = Field(None, alias="agentId")
    model: str
    provider: str
    prompt_tokens: int = Field(alias="promptTokens")
    completion_tokens: int = Field(alias="completionTokens")
    total_tokens: int = Field(alias="totalTokens")
    # Request latency in milliseconds.
    latency_ms: float = Field(alias="latencyMs")
    # Cost estimate; None when pricing is unavailable. NOTE(review): currency/units not shown here — confirm.
    estimated_cost: Optional[float] = Field(None, alias="estimatedCost")
    # Free-form metadata captured with the event; defaults to empty.
    metadata: dict = Field(default_factory=dict)
|
||||
|
||||
|
||||
class UsageHistoryResponse(CamelModel):
    """Response for GET /usage/history."""

    # One page of events matching the query filters.
    events: List[UsageEventInfo]
    # Number of events in this page (len(events)).
    count: int
    # Total number of matching events across all pages.
    total: int
    # Pagination parameters echoed back from the request.
    offset: int
    limit: int
|
||||
|
||||
|
||||
class BillingSummaryResponse(CamelModel):
    """Response for GET /usage/billing."""

    # Organization the summary was filtered to; None means all orgs.
    org_id: Optional[str] = Field(None, alias="orgId")
    # Time window the aggregation covers.
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")
    # Aggregate token counts over the window.
    total_tokens: int = Field(alias="totalTokens")
    prompt_tokens: int = Field(alias="promptTokens")
    completion_tokens: int = Field(alias="completionTokens")
    # Number of LLM requests included in the aggregation.
    request_count: int = Field(alias="requestCount")
    total_cost: float = Field(alias="totalCost")
    # Per-model and per-agent breakdowns (shape defined by the store's summary).
    by_model: dict = Field(default_factory=dict, alias="byModel")
    by_agent: dict = Field(default_factory=dict, alias="byAgent")
|
||||
|
||||
|
||||
class DailyUsagePoint(CamelModel):
    """A single day's usage for charting."""

    # Calendar day the totals belong to (string; presumably YYYY-MM-DD — confirm against the store).
    date: str
    # Totals accumulated over that day.
    total_tokens: int = Field(alias="totalTokens")
    request_count: int = Field(alias="requestCount")
    total_cost: float = Field(alias="totalCost")
|
||||
|
||||
|
||||
class DailyUsageResponse(CamelModel):
    """Response for GET /usage/daily."""

    # Chronological list of per-day usage points.
    days: List[DailyUsagePoint]
    # Number of days returned (len(days)).
    count: int
|
||||
|
|
|
|||
Loading…
Reference in a new issue