Add usage persistence for billing (SQLite)

- UsageStore with async SQLite persistence via aiosqlite
- Background batch writer for non-blocking event persistence
- Auto-subscribes to UsageTracker for transparent capture
- Query methods: query(), get_billing_summary(), get_daily_usage()
- REST API endpoints: /usage/history, /usage/billing, /usage/daily
- Filtering by org_id, agent_id, model, time range
- 18 new tests for persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
dullfig 2026-01-27 21:58:22 -08:00
parent e6697f0ea2
commit d0d78a9f70
5 changed files with 1227 additions and 1 deletions

398
tests/test_usage_store.py Normal file
View file

@ -0,0 +1,398 @@
"""
Tests for usage persistence layer.
Tests UsageStore's SQLite persistence, querying, and billing summaries.
"""
from __future__ import annotations
import asyncio
import tempfile
from datetime import datetime, timezone
from pathlib import Path
import pytest
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker, reset_usage_tracker
from xml_pipeline.llm.usage_store import UsageStore, reset_usage_store
@pytest.fixture
def temp_db():
    """Create a temporary database file path; the file is removed after the test.

    NamedTemporaryFile(delete=False) keeps the file alive past the context
    manager, so it must be unlinked explicitly — the original comment claimed
    cleanup was automatic, which it was not.
    """
    with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as f:
        path = f.name
    # Yield with the handle already closed so the store can open the file
    # freely (also required on Windows, where open files can't be reopened).
    try:
        yield path
    finally:
        Path(path).unlink(missing_ok=True)
@pytest.fixture
async def store(temp_db):
    """Yield an initialized UsageStore backed by the temporary database."""
    # Reset global singletons so each test starts from a clean slate.
    reset_usage_tracker()
    reset_usage_store()
    usage_store = UsageStore(db_path=temp_db)
    await usage_store.initialize()
    yield usage_store
    await usage_store.close()
class TestUsageStoreBasics:
    """Basic store operations: initialization, empty query, empty count."""

    async def test_initialize_creates_table(self, temp_db):
        """Store initialization creates the usage_events table."""
        reset_usage_store()
        usage_store = UsageStore(db_path=temp_db)
        await usage_store.initialize()
        # If the table exists, an empty query succeeds and returns no rows.
        assert await usage_store.query(limit=10) == []
        await usage_store.close()

    async def test_store_initializes_once(self, store):
        """Repeated initialize() calls are idempotent and do not raise."""
        for _ in range(2):
            await store.initialize()

    async def test_empty_query(self, store):
        """query() on a fresh store yields no events."""
        assert await store.query() == []

    async def test_empty_count(self, store):
        """count() on a fresh store is zero."""
        assert await store.count() == 0
class TestEventPersistence:
    """Events recorded on the tracker are persisted via the subscriber."""

    async def test_event_persisted_via_tracker(self, store):
        """A recorded event appears in the store with all fields intact."""
        get_usage_tracker().record(
            thread_id="test-thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250.5,
            metadata={"org_id": "org-123"},
        )
        # Allow the background batch writer one flush interval.
        await asyncio.sleep(1.5)
        rows = await store.query()
        assert len(rows) == 1
        expected = {
            "thread_id": "test-thread-1",
            "agent_id": "greeter",
            "model": "grok-4.1",
            "provider": "xai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "latency_ms": 250.5,
        }
        for field, value in expected.items():
            assert rows[0][field] == value

    async def test_multiple_events_persisted(self, store):
        """Every recorded event is written, not just the first."""
        tracker = get_usage_tracker()
        for i in range(5):
            tracker.record(
                thread_id=f"thread-{i}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100 * (i + 1),
                completion_tokens=50 * (i + 1),
                latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        assert len(await store.query()) == 5

    async def test_event_with_cost_estimate(self, store):
        """A model with known pricing gets a positive estimated cost."""
        # grok-4.1 has pricing in MODEL_COSTS
        get_usage_tracker().record(
            thread_id="cost-thread",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )
        await asyncio.sleep(1.5)
        rows = await store.query()
        assert len(rows) == 1
        cost = rows[0]["estimated_cost"]
        assert cost is not None
        assert cost > 0
class TestQueryFiltering:
    """Filtering options on UsageStore.query()."""

    async def test_filter_by_agent(self, store):
        """Only events for the requested agent_id come back."""
        tracker = get_usage_tracker()
        for thread, agent in (("t1", "greeter"), ("t2", "shouter")):
            tracker.record(
                thread_id=thread, agent_id=agent, model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        rows = await store.query(agent_id="greeter")
        assert [r["agent_id"] for r in rows] == ["greeter"]

    async def test_filter_by_model(self, store):
        """Only events for the requested model come back."""
        tracker = get_usage_tracker()
        for thread, model, provider in (
            ("t1", "grok-4.1", "xai"),
            ("t2", "claude-sonnet-4", "anthropic"),
        ):
            tracker.record(
                thread_id=thread, agent_id="agent", model=model,
                provider=provider, prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        rows = await store.query(model="grok-4.1")
        assert [r["model"] for r in rows] == ["grok-4.1"]

    async def test_filter_by_org_id(self, store):
        """Events are filterable by the org_id stored in metadata."""
        tracker = get_usage_tracker()
        for thread, org in (("t1", "org-A"), ("t2", "org-B")):
            tracker.record(
                thread_id=thread, agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0, metadata={"org_id": org},
            )
        await asyncio.sleep(1.5)
        rows = await store.query(org_id="org-A")
        assert [r["thread_id"] for r in rows] == ["t1"]

    async def test_pagination(self, store):
        """limit/offset page through results without overlap."""
        tracker = get_usage_tracker()
        for i in range(10):
            tracker.record(
                thread_id=f"t{i:02d}", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        first_page = await store.query(limit=3, offset=0)
        second_page = await store.query(limit=3, offset=3)
        assert len(first_page) == 3
        assert len(second_page) == 3
        # Pages must not start with the same event.
        assert first_page[0]["thread_id"] != second_page[0]["thread_id"]
class TestBillingSummary:
    """Aggregation via get_billing_summary()."""

    async def test_billing_totals(self, store):
        """Summary totals are the sums over all recorded events."""
        tracker = get_usage_tracker()
        for thread, prompt, completion in (("t1", 1000, 500), ("t2", 2000, 1000)):
            tracker.record(
                thread_id=thread, agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=prompt,
                completion_tokens=completion, latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        summary = await store.get_billing_summary()
        assert summary.prompt_tokens == 3000
        assert summary.completion_tokens == 1500
        assert summary.total_tokens == 4500  # 1500 + 3000
        assert summary.request_count == 2

    async def test_billing_by_model(self, store):
        """Summary breaks totals down per model."""
        tracker = get_usage_tracker()
        for thread, model, provider, prompt, completion in (
            ("t1", "grok-4.1", "xai", 1000, 500),
            ("t2", "claude-sonnet-4", "anthropic", 2000, 1000),
        ):
            tracker.record(
                thread_id=thread, agent_id="agent", model=model,
                provider=provider, prompt_tokens=prompt,
                completion_tokens=completion, latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        summary = await store.get_billing_summary()
        assert summary.by_model["grok-4.1"]["total_tokens"] == 1500
        assert summary.by_model["claude-sonnet-4"]["total_tokens"] == 3000

    async def test_billing_by_agent(self, store):
        """Summary breaks totals down per agent."""
        tracker = get_usage_tracker()
        for thread, agent, prompt, completion in (
            ("t1", "greeter", 1000, 500),
            ("t2", "shouter", 500, 250),
        ):
            tracker.record(
                thread_id=thread, agent_id=agent, model="grok-4.1",
                provider="xai", prompt_tokens=prompt,
                completion_tokens=completion, latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        summary = await store.get_billing_summary()
        assert summary.by_agent["greeter"]["total_tokens"] == 1500
        assert summary.by_agent["shouter"]["total_tokens"] == 750

    async def test_billing_filtered_by_org(self, store):
        """Each org's summary covers only its own events."""
        tracker = get_usage_tracker()
        for thread, org, prompt, completion in (
            ("t1", "org-A", 1000, 500),
            ("t2", "org-B", 2000, 1000),
        ):
            tracker.record(
                thread_id=thread, agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=prompt,
                completion_tokens=completion, latency_ms=100.0,
                metadata={"org_id": org},
            )
        await asyncio.sleep(1.5)
        assert (await store.get_billing_summary(org_id="org-A")).total_tokens == 1500
        assert (await store.get_billing_summary(org_id="org-B")).total_tokens == 3000
class TestDailyUsage:
    """Aggregation via get_daily_usage()."""

    async def test_daily_aggregation(self, store):
        """Events sharing a day collapse into a single daily row."""
        tracker = get_usage_tracker()
        # Timestamps are auto-generated "now", so all events land on one day.
        for i in range(3):
            tracker.record(
                thread_id=f"t{i}", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0,
            )
        await asyncio.sleep(1.5)
        days = await store.get_daily_usage()
        assert len(days) == 1
        assert days[0]["total_tokens"] == 450
        assert days[0]["request_count"] == 3

    async def test_daily_returns_date(self, store):
        """Each daily row carries a YYYY-MM-DD date string."""
        get_usage_tracker().record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50,
            latency_ms=100.0,
        )
        await asyncio.sleep(1.5)
        days = await store.get_daily_usage()
        assert len(days) == 1
        date_str = days[0]["date"]
        # YYYY-MM-DD: ten characters with two dashes.
        assert len(date_str) == 10
        assert date_str.count("-") == 2
class TestStoreLifecycle:
    """Shutdown semantics of the store."""

    async def test_close_flushes_pending(self, temp_db):
        """close() flushes queued events before shutting down."""
        reset_usage_tracker()
        reset_usage_store()
        first = UsageStore(db_path=temp_db)
        await first.initialize()
        get_usage_tracker().record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50,
            latency_ms=100.0,
        )
        # Close immediately, without waiting for the writer's flush interval.
        await first.close()
        # A fresh store over the same file must see the persisted event.
        reopened = UsageStore(db_path=temp_db)
        await reopened.initialize()
        assert len(await reopened.query()) == 1
        await reopened.close()

View file

@ -30,6 +30,20 @@ Usage Tracking:
# Query totals
totals = tracker.get_totals()
Usage Persistence (for billing):
from xml_pipeline.llm import get_usage_store
store = await get_usage_store()
# Query historical usage
events = await store.query(
start_time="2025-01-01T00:00:00Z",
org_id="org-123",
)
# Get billing summary
summary = await store.get_billing_summary(org_id="org-123")
"""
from xml_pipeline.llm.router import (
@ -46,6 +60,13 @@ from xml_pipeline.llm.usage_tracker import (
get_usage_tracker,
reset_usage_tracker,
)
from xml_pipeline.llm.usage_store import (
UsageStore,
BillingSummary,
get_usage_store,
close_usage_store,
reset_usage_store,
)
__all__ = [
# Router
@ -58,9 +79,15 @@ __all__ = [
"LLMRequest",
"LLMResponse",
"BackendError",
# Usage tracking
# Usage tracking (in-memory)
"UsageTracker",
"UsageEvent",
"get_usage_tracker",
"reset_usage_tracker",
# Usage persistence (SQLite)
"UsageStore",
"BillingSummary",
"get_usage_store",
"close_usage_store",
"reset_usage_store",
]

View file

@ -0,0 +1,599 @@
"""
Usage Store — persistent storage for billing and usage analytics.
Stores UsageEvents to SQLite (default) or PostgreSQL for:
- Historical billing queries
- Usage analytics and reporting
- Audit trails
The store auto-subscribes to UsageTracker on initialization,
persisting all events transparently.
Example:
from xml_pipeline.llm.usage_store import get_usage_store
store = await get_usage_store()
# Query historical usage
events = await store.query(
start_time="2025-01-01T00:00:00Z",
end_time="2025-01-31T23:59:59Z",
org_id="org-123",
)
# Get billing summary
summary = await store.get_billing_summary(
org_id="org-123",
start_time="2025-01-01T00:00:00Z",
)
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import threading
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
try:
import aiosqlite
HAS_AIOSQLITE = True
except ImportError:
HAS_AIOSQLITE = False
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker
logger = logging.getLogger(__name__)
# Default database path
DEFAULT_DB_PATH = Path.home() / ".xml-pipeline" / "usage.db"
@dataclass
class BillingSummary:
    """Aggregated billing data for a time period.

    Produced by UsageStore.get_billing_summary(): token and request totals
    are SQL sums over the matching usage_events rows, and by_model/by_agent
    hold per-key sub-totals with the same field shape.
    """
    org_id: Optional[str]                 # filter used for the summary, if any
    start_time: str                       # ISO 8601; "" when unbounded
    end_time: str                         # ISO 8601; defaults to "now" when unbounded
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    request_count: int
    total_cost: float                     # rounded to 4 decimal places
    by_model: Dict[str, Dict[str, Any]]   # model -> {total_tokens, ..., total_cost}
    by_agent: Dict[str, Dict[str, Any]]   # agent_id -> same shape; NULL agents excluded
class UsageStore:
    """
    Persistent storage for usage events.

    Uses SQLite by default, with async I/O via aiosqlite.
    Automatically subscribes to UsageTracker to capture all events.

    Events arrive through a synchronous tracker callback, are queued, and a
    background writer task flushes them to the database in batches (up to
    100 events, at least once per second).
    """

    # Upper bound on queued-but-unwritten events. Bounding the queue caps
    # memory if the writer stalls and makes the QueueFull drop path in
    # _on_usage_event reachable — an unbounded asyncio.Queue never raises
    # QueueFull, so the original handler was dead code.
    MAX_PENDING_EVENTS = 10_000

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize the usage store.

        Args:
            db_path: Path to SQLite database. Defaults to ~/.xml-pipeline/usage.db

        Raises:
            ImportError: If aiosqlite is not installed.
        """
        if not HAS_AIOSQLITE:
            raise ImportError(
                "aiosqlite is required for usage persistence. "
                "Install with: pip install aiosqlite"
            )
        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialized = False
        self._init_lock = asyncio.Lock()
        # Queue for async persistence (events come from a sync callback).
        self._queue: asyncio.Queue[UsageEvent] = asyncio.Queue(
            maxsize=self.MAX_PENDING_EVENTS
        )
        self._writer_task: Optional[asyncio.Task] = None
        self._running = False

    async def initialize(self) -> None:
        """Initialize database schema and start the background writer.

        Idempotent: concurrent or repeated calls perform the setup once
        (guarded by an asyncio lock plus the _initialized flag).
        """
        async with self._init_lock:
            if self._initialized:
                return
            # Create tables
            async with aiosqlite.connect(str(self._db_path)) as db:
                await db.execute("""
                    CREATE TABLE IF NOT EXISTS usage_events (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        thread_id TEXT NOT NULL,
                        agent_id TEXT,
                        model TEXT NOT NULL,
                        provider TEXT NOT NULL,
                        prompt_tokens INTEGER NOT NULL,
                        completion_tokens INTEGER NOT NULL,
                        total_tokens INTEGER NOT NULL,
                        latency_ms REAL NOT NULL,
                        estimated_cost REAL,
                        metadata TEXT,
                        -- Denormalized for billing queries
                        org_id TEXT,
                        user_id TEXT
                    )
                """)
                # Indexes for common queries
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_timestamp
                    ON usage_events(timestamp)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_org_id
                    ON usage_events(org_id)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_agent_id
                    ON usage_events(agent_id)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_model
                    ON usage_events(model)
                """)
                await db.commit()
            # Start background writer
            self._running = True
            self._writer_task = asyncio.create_task(self._writer_loop())
            # Subscribe to usage tracker so every record() call reaches us.
            tracker = get_usage_tracker()
            tracker.subscribe(self._on_usage_event)
            self._initialized = True
            logger.info(f"UsageStore initialized: {self._db_path}")

    def _on_usage_event(self, event: UsageEvent) -> None:
        """
        Callback for UsageTracker events.

        This runs synchronously from the tracker, so we queue the event for
        async persistence. If the bounded queue is full (writer stalled or
        backlogged past MAX_PENDING_EVENTS), the event is dropped with a
        warning rather than blocking the caller.
        """
        try:
            self._queue.put_nowait(event)
        except asyncio.QueueFull:
            logger.warning("Usage event queue full, dropping event")

    async def _writer_loop(self) -> None:
        """Background task that writes queued events to the database.

        Flushes at least once per second, or as soon as 100 events are
        available. Keeps draining after _running goes False so close() can
        rely on pending events being flushed.
        """
        batch: List[UsageEvent] = []
        batch_timeout = 1.0  # Flush every second or 100 events
        while self._running or not self._queue.empty():
            try:
                # Collect batch
                try:
                    event = await asyncio.wait_for(
                        self._queue.get(),
                        timeout=batch_timeout
                    )
                    batch.append(event)
                    # Drain queue up to batch size
                    while len(batch) < 100:
                        try:
                            event = self._queue.get_nowait()
                            batch.append(event)
                        except asyncio.QueueEmpty:
                            break
                except asyncio.TimeoutError:
                    pass
                # Write batch
                if batch:
                    await self._write_batch(batch)
                    batch = []
            except Exception as e:
                # On write failure the batch is kept and retried next
                # iteration. NOTE(review): if a commit partially succeeded
                # this could duplicate rows — confirm acceptable for billing.
                logger.error(f"Usage writer error: {e}")
                await asyncio.sleep(1.0)

    async def _write_batch(self, events: List[UsageEvent]) -> None:
        """Write a batch of events to the database in one transaction.

        org_id/user_id are denormalized out of metadata so billing queries
        can filter on indexed columns instead of parsing JSON.
        """
        async with aiosqlite.connect(str(self._db_path)) as db:
            await db.executemany(
                """
                INSERT INTO usage_events (
                    timestamp, thread_id, agent_id, model, provider,
                    prompt_tokens, completion_tokens, total_tokens,
                    latency_ms, estimated_cost, metadata, org_id, user_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    (
                        e.timestamp,
                        e.thread_id,
                        e.agent_id,
                        e.model,
                        e.provider,
                        e.prompt_tokens,
                        e.completion_tokens,
                        e.total_tokens,
                        e.latency_ms,
                        e.estimated_cost,
                        json.dumps(e.metadata) if e.metadata else None,
                        e.metadata.get("org_id") if e.metadata else None,
                        e.metadata.get("user_id") if e.metadata else None,
                    )
                    for e in events
                ]
            )
            await db.commit()
            logger.debug(f"Persisted {len(events)} usage events")

    @staticmethod
    def _build_where(
        start_time: Optional[str],
        end_time: Optional[str],
        org_id: Optional[str],
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
    ) -> tuple:
        """Build a (where_clause, params) pair from the optional filters.

        Shared by query/get_billing_summary/get_daily_usage/count so the
        filter semantics cannot drift between them.
        """
        conditions: List[str] = []
        params: List[Any] = []
        for clause, value in (
            ("timestamp >= ?", start_time),
            ("timestamp <= ?", end_time),
            ("org_id = ?", org_id),
            ("user_id = ?", user_id),
            ("agent_id = ?", agent_id),
            ("model = ?", model),
        ):
            if value:
                conditions.append(clause)
                params.append(value)
        where_clause = " AND ".join(conditions) if conditions else "1=1"
        return where_clause, params

    async def query(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
        limit: int = 1000,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Query historical usage events, newest first.

        Args:
            start_time: ISO 8601 timestamp (inclusive)
            end_time: ISO 8601 timestamp (inclusive)
            org_id: Filter by organization
            user_id: Filter by user
            agent_id: Filter by agent
            model: Filter by model
            limit: Max results (default 1000)
            offset: Pagination offset

        Returns:
            List of usage event dicts (metadata JSON decoded back to a dict).
        """
        where_clause, params = self._build_where(
            start_time, end_time, org_id, user_id, agent_id, model
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                f"""
                SELECT * FROM usage_events
                WHERE {where_clause}
                ORDER BY timestamp DESC
                LIMIT ? OFFSET ?
                """,
                params + [limit, offset]
            )
            rows = await cursor.fetchall()
            return [
                {
                    "id": row["id"],
                    "timestamp": row["timestamp"],
                    "thread_id": row["thread_id"],
                    "agent_id": row["agent_id"],
                    "model": row["model"],
                    "provider": row["provider"],
                    "prompt_tokens": row["prompt_tokens"],
                    "completion_tokens": row["completion_tokens"],
                    "total_tokens": row["total_tokens"],
                    "latency_ms": row["latency_ms"],
                    "estimated_cost": row["estimated_cost"],
                    "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
                }
                for row in rows
            ]

    async def get_billing_summary(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
    ) -> BillingSummary:
        """
        Get an aggregated billing summary for a time period.

        Args:
            start_time: ISO 8601 timestamp (inclusive)
            end_time: ISO 8601 timestamp (inclusive)
            org_id: Filter by organization

        Returns:
            BillingSummary with overall totals plus by-model and by-agent
            breakdowns (events with NULL agent_id are excluded from by_agent).
        """
        where_clause, params = self._build_where(start_time, end_time, org_id)
        async with aiosqlite.connect(str(self._db_path)) as db:
            # Overall totals (COALESCE so an empty table yields zeros).
            cursor = await db.execute(
                f"""
                SELECT
                    COALESCE(SUM(total_tokens), 0) as total_tokens,
                    COALESCE(SUM(prompt_tokens), 0) as prompt_tokens,
                    COALESCE(SUM(completion_tokens), 0) as completion_tokens,
                    COUNT(*) as request_count,
                    COALESCE(SUM(estimated_cost), 0) as total_cost
                FROM usage_events
                WHERE {where_clause}
                """,
                params
            )
            row = await cursor.fetchone()
            totals = {
                "total_tokens": row[0],
                "prompt_tokens": row[1],
                "completion_tokens": row[2],
                "request_count": row[3],
                "total_cost": round(row[4], 4),
            }
            # By model
            cursor = await db.execute(
                f"""
                SELECT
                    model,
                    SUM(total_tokens) as total_tokens,
                    SUM(prompt_tokens) as prompt_tokens,
                    SUM(completion_tokens) as completion_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause}
                GROUP BY model
                """,
                params
            )
            by_model = {
                row[0]: {
                    "total_tokens": row[1],
                    "prompt_tokens": row[2],
                    "completion_tokens": row[3],
                    "request_count": row[4],
                    "total_cost": round(row[5] or 0, 4),
                }
                for row in await cursor.fetchall()
            }
            # By agent (skip rows with no agent so None isn't a dict key).
            cursor = await db.execute(
                f"""
                SELECT
                    agent_id,
                    SUM(total_tokens) as total_tokens,
                    SUM(prompt_tokens) as prompt_tokens,
                    SUM(completion_tokens) as completion_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause} AND agent_id IS NOT NULL
                GROUP BY agent_id
                """,
                params
            )
            by_agent = {
                row[0]: {
                    "total_tokens": row[1],
                    "prompt_tokens": row[2],
                    "completion_tokens": row[3],
                    "request_count": row[4],
                    "total_cost": round(row[5] or 0, 4),
                }
                for row in await cursor.fetchall()
            }
            return BillingSummary(
                org_id=org_id,
                start_time=start_time or "",
                end_time=end_time or datetime.now(timezone.utc).isoformat(),
                total_tokens=totals["total_tokens"],
                prompt_tokens=totals["prompt_tokens"],
                completion_tokens=totals["completion_tokens"],
                request_count=totals["request_count"],
                total_cost=totals["total_cost"],
                by_model=by_model,
                by_agent=by_agent,
            )

    async def get_daily_usage(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get usage aggregated by day (ascending date order) for charting.

        Returns:
            List of {date, total_tokens, request_count, total_cost}, where
            date is SQLite's DATE() of the ISO timestamp (YYYY-MM-DD).
        """
        where_clause, params = self._build_where(start_time, end_time, org_id)
        async with aiosqlite.connect(str(self._db_path)) as db:
            cursor = await db.execute(
                f"""
                SELECT
                    DATE(timestamp) as date,
                    SUM(total_tokens) as total_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause}
                GROUP BY DATE(timestamp)
                ORDER BY date
                """,
                params
            )
            rows = await cursor.fetchall()
            return [
                {
                    "date": row[0],
                    "total_tokens": row[1],
                    "request_count": row[2],
                    "total_cost": round(row[3] or 0, 4),
                }
                for row in rows
            ]

    async def count(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
    ) -> int:
        """Get the total count of events matching the criteria.

        Accepts the same filters as query() (user_id/agent_id/model are
        backward-compatible additions) so paginated callers can compute a
        total consistent with their filtered result set.
        """
        where_clause, params = self._build_where(
            start_time, end_time, org_id, user_id, agent_id, model
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            cursor = await db.execute(
                f"SELECT COUNT(*) FROM usage_events WHERE {where_clause}",
                params
            )
            row = await cursor.fetchone()
            return row[0]

    async def close(self) -> None:
        """Shutdown the store, flushing pending writes.

        Stops the writer loop (which drains the queue before exiting),
        and unsubscribes from the tracker so no further events arrive.
        """
        self._running = False
        # Unsubscribe from tracker
        tracker = get_usage_tracker()
        tracker.unsubscribe(self._on_usage_event)
        # Wait for writer to finish draining and exit.
        if self._writer_task:
            await self._writer_task
            self._writer_task = None
        logger.info("UsageStore closed")
# =============================================================================
# Global Instance
# =============================================================================
# Process-wide singleton store. The threading lock guards creation in case
# get_usage_store() is entered from more than one thread/event loop.
_store: Optional[UsageStore] = None
_store_lock = threading.Lock()
async def get_usage_store(db_path: Optional[str] = None) -> UsageStore:
    """
    Get the global usage store, initializing it if needed.

    Args:
        db_path: Optional path to SQLite database (only honored when the
            singleton is first created).

    Returns:
        Initialized UsageStore
    """
    global _store
    if _store is None:
        # Double-checked locking: cheap unlocked check, then re-check under
        # the lock before constructing the singleton.
        with _store_lock:
            if _store is None:
                _store = UsageStore(db_path)
    store = _store
    # initialize() is itself idempotent, so a racing second call is safe.
    if not store._initialized:
        await store.initialize()
    return store
async def close_usage_store() -> None:
    """Close the global usage store and discard the singleton, if one exists."""
    global _store
    store = _store
    if store is None:
        return
    await store.close()
    _store = None
def reset_usage_store() -> None:
    """Reset global store (for testing).

    NOTE(review): this drops the instance without awaiting close(), so an
    initialized store's writer task and tracker subscription are left
    running — callers should close first (see close_usage_store). Confirm
    this is acceptable for the test scenarios it serves.
    """
    global _store
    with _store_lock:
        _store = None

View file

@ -19,9 +19,12 @@ from xml_pipeline.server.models import (
AgentInfo,
AgentListResponse,
AgentUsageInfo,
BillingSummaryResponse,
CapabilityDetail,
CapabilityInfo,
CapabilityListResponse,
DailyUsagePoint,
DailyUsageResponse,
ErrorResponse,
InjectRequest,
InjectResponse,
@ -33,6 +36,8 @@ from xml_pipeline.server.models import (
ThreadInfo,
ThreadListResponse,
ThreadStatus,
UsageEventInfo,
UsageHistoryResponse,
UsageOverview,
UsageResponse,
UsageTotals,
@ -425,6 +430,140 @@ def create_router(state: "ServerState") -> APIRouter:
reset_usage_tracker()
return {"success": True, "message": "Usage tracking reset"}
# =========================================================================
# Usage History Endpoints (Persistent)
# =========================================================================
@router.get("/usage/history", response_model=UsageHistoryResponse)
async def get_usage_history(
    start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
    end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
    org_id: Optional[str] = Query(None, description="Filter by organization"),
    agent_id: Optional[str] = Query(None, description="Filter by agent"),
    model: Optional[str] = Query(None, description="Filter by model"),
    limit: int = Query(100, ge=1, le=1000),
    offset: int = Query(0, ge=0),
) -> UsageHistoryResponse:
    """
    Query historical usage events from persistent storage.
    Use for billing reconciliation, audit trails, and detailed analytics.
    Events are stored in SQLite and persist across restarts.
    """
    # Imported lazily so the server can start even when the optional
    # persistence dependency (aiosqlite) is absent.
    from xml_pipeline.llm.usage_store import get_usage_store
    store = await get_usage_store()
    events = await store.query(
        start_time=start_time,
        end_time=end_time,
        org_id=org_id,
        agent_id=agent_id,
        model=model,
        limit=limit,
        offset=offset,
    )
    # NOTE(review): count() is only filtered by time range and org_id, so
    # `total` can exceed the matching-event count when agent_id/model
    # filters are set — confirm whether that is intentional.
    total = await store.count(
        start_time=start_time,
        end_time=end_time,
        org_id=org_id,
    )
    return UsageHistoryResponse(
        events=[
            UsageEventInfo(
                id=e["id"],
                timestamp=e["timestamp"],
                thread_id=e["thread_id"],
                agent_id=e.get("agent_id"),
                model=e["model"],
                provider=e["provider"],
                prompt_tokens=e["prompt_tokens"],
                completion_tokens=e["completion_tokens"],
                total_tokens=e["total_tokens"],
                latency_ms=e["latency_ms"],
                estimated_cost=e.get("estimated_cost"),
                metadata=e.get("metadata", {}),
            )
            for e in events
        ],
        count=len(events),   # size of this page
        total=total,         # size of the (time/org-filtered) result set
        offset=offset,
        limit=limit,
    )
@router.get("/usage/billing", response_model=BillingSummaryResponse)
async def get_billing_summary(
    start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
    end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
    org_id: Optional[str] = Query(None, description="Filter by organization"),
) -> BillingSummaryResponse:
    """
    Get aggregated billing summary for a time period.
    Returns total tokens, costs, and breakdowns by model and agent.
    Use for invoicing and cost analysis.
    """
    # Imported lazily so the server can start even when the optional
    # persistence dependency (aiosqlite) is absent.
    from xml_pipeline.llm.usage_store import get_usage_store
    store = await get_usage_store()
    summary = await store.get_billing_summary(
        start_time=start_time,
        end_time=end_time,
        org_id=org_id,
    )
    # Straight field-for-field mapping of the BillingSummary dataclass onto
    # the camelCase API response model.
    return BillingSummaryResponse(
        org_id=summary.org_id,
        start_time=summary.start_time,
        end_time=summary.end_time,
        total_tokens=summary.total_tokens,
        prompt_tokens=summary.prompt_tokens,
        completion_tokens=summary.completion_tokens,
        request_count=summary.request_count,
        total_cost=summary.total_cost,
        by_model=summary.by_model,
        by_agent=summary.by_agent,
    )
@router.get("/usage/daily", response_model=DailyUsageResponse)
async def get_daily_usage(
    start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
    end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
    org_id: Optional[str] = Query(None, description="Filter by organization"),
) -> DailyUsageResponse:
    """
    Get usage aggregated by day for charting.
    Returns daily totals for tokens, requests, and costs.
    Useful for dashboards and trend analysis.
    """
    # Imported lazily so the server can start even when the optional
    # persistence dependency (aiosqlite) is absent.
    from xml_pipeline.llm.usage_store import get_usage_store
    store = await get_usage_store()
    days = await store.get_daily_usage(
        start_time=start_time,
        end_time=end_time,
        org_id=org_id,
    )
    return DailyUsageResponse(
        days=[
            DailyUsagePoint(
                date=d["date"],
                total_tokens=d["total_tokens"],
                request_count=d["request_count"],
                total_cost=d["total_cost"],
            )
            for d in days
        ],
        count=len(days),
    )
# =========================================================================
# Control Endpoints
# =========================================================================

View file

@ -368,3 +368,66 @@ class ThreadBudgetListResponse(CamelModel):
threads: List[ThreadBudgetInfo]
count: int
default_max_tokens: int = Field(alias="defaultMaxTokens")
# =============================================================================
# Usage History Models (Persistent)
# =============================================================================
class UsageEventInfo(CamelModel):
    """A single usage event from history.

    Mirrors one row of the usage_events table; Field aliases expose
    camelCase names in the JSON API.
    """
    id: int                # database row id
    timestamp: str         # ISO 8601
    thread_id: str = Field(alias="threadId")
    agent_id: Optional[str] = Field(None, alias="agentId")
    model: str
    provider: str
    prompt_tokens: int = Field(alias="promptTokens")
    completion_tokens: int = Field(alias="completionTokens")
    total_tokens: int = Field(alias="totalTokens")
    latency_ms: float = Field(alias="latencyMs")
    estimated_cost: Optional[float] = Field(None, alias="estimatedCost")
    metadata: dict = Field(default_factory=dict)  # original event metadata (JSON-decoded)
class UsageHistoryResponse(CamelModel):
    """Response for GET /usage/history (one page of events plus paging info)."""
    events: List[UsageEventInfo]
    count: int    # number of events in this page
    total: int    # total matching events across all pages
    offset: int   # pagination offset echoed back
    limit: int    # pagination limit echoed back
class BillingSummaryResponse(CamelModel):
    """Response for GET /usage/billing.

    Field-for-field projection of the store's BillingSummary dataclass,
    with camelCase aliases for the JSON API.
    """
    org_id: Optional[str] = Field(None, alias="orgId")
    start_time: str = Field(alias="startTime")
    end_time: str = Field(alias="endTime")
    total_tokens: int = Field(alias="totalTokens")
    prompt_tokens: int = Field(alias="promptTokens")
    completion_tokens: int = Field(alias="completionTokens")
    request_count: int = Field(alias="requestCount")
    total_cost: float = Field(alias="totalCost")
    by_model: dict = Field(default_factory=dict, alias="byModel")   # model -> sub-totals
    by_agent: dict = Field(default_factory=dict, alias="byAgent")   # agent_id -> sub-totals
class DailyUsagePoint(CamelModel):
    """A single day's usage for charting."""
    date: str  # YYYY-MM-DD
    total_tokens: int = Field(alias="totalTokens")
    request_count: int = Field(alias="requestCount")
    total_cost: float = Field(alias="totalCost")
class DailyUsageResponse(CamelModel):
    """Response for GET /usage/daily (ascending by date)."""
    days: List[DailyUsagePoint]
    count: int  # number of daily points returned