xml-pipeline/xml_pipeline/llm/usage_store.py
dullfig d0d78a9f70 Add usage persistence for billing (SQLite)
- UsageStore with async SQLite persistence via aiosqlite
- Background batch writer for non-blocking event persistence
- Auto-subscribes to UsageTracker for transparent capture
- Query methods: query(), get_billing_summary(), get_daily_usage()
- REST API endpoints: /usage/history, /usage/billing, /usage/daily
- Filtering by org_id, agent_id, model, time range
- 18 new tests for persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:58:22 -08:00

599 lines
19 KiB
Python

"""
Usage Store — Persistent storage for billing and usage analytics.
Stores UsageEvents to SQLite (default) or PostgreSQL for:
- Historical billing queries
- Usage analytics and reporting
- Audit trails
The store auto-subscribes to UsageTracker on initialization,
persisting all events transparently.
Example:
from xml_pipeline.llm.usage_store import get_usage_store
store = await get_usage_store()
# Query historical usage
events = await store.query(
start_time="2025-01-01T00:00:00Z",
end_time="2025-01-31T23:59:59Z",
org_id="org-123",
)
# Get billing summary
summary = await store.get_billing_summary(
org_id="org-123",
start_time="2025-01-01T00:00:00Z",
)
"""
from __future__ import annotations

import asyncio
import json
import logging
import os
import threading
from dataclasses import dataclass, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional

# aiosqlite is an optional dependency: probe for it here so that
# UsageStore.__init__ can raise a helpful ImportError instead of the
# module failing to import at all.
try:
    import aiosqlite

    HAS_AIOSQLITE = True
except ImportError:
    HAS_AIOSQLITE = False

from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker

logger = logging.getLogger(__name__)

# Default database path
DEFAULT_DB_PATH = Path.home() / ".xml-pipeline" / "usage.db"
@dataclass
class BillingSummary:
    """Aggregated billing data for a time period."""

    # Organization the summary is scoped to; None = all organizations.
    org_id: Optional[str]
    # ISO 8601 bounds of the period; start_time is "" when no lower
    # bound was supplied, end_time defaults to "now" (UTC) upstream.
    start_time: str
    end_time: str
    # Aggregate token counts across all matching events.
    total_tokens: int
    prompt_tokens: int
    completion_tokens: int
    # Number of LLM requests in the period.
    request_count: int
    # Summed estimated cost over the period.
    total_cost: float
    # Breakdowns keyed by model name / agent id; each value is a dict
    # with the same aggregate fields as above.
    by_model: Dict[str, Dict[str, Any]]
    by_agent: Dict[str, Dict[str, Any]]
class UsageStore:
    """
    Persistent storage for usage events.

    Uses SQLite by default, with async I/O via aiosqlite.
    Automatically subscribes to UsageTracker to capture all events:
    the tracker's synchronous callback enqueues each event, and a
    background task persists them to the database in batches.
    """

    # Flush the pending batch at least this often (seconds).
    _BATCH_TIMEOUT = 1.0
    # Maximum number of events written per INSERT batch.
    _BATCH_SIZE = 100
    # Bound the queue so a stalled writer cannot grow memory without
    # limit; overflow events are dropped with a warning.
    _MAX_QUEUE_SIZE = 10_000

    def __init__(self, db_path: Optional[str] = None):
        """
        Initialize the usage store.

        Args:
            db_path: Path to SQLite database. Defaults to ~/.xml-pipeline/usage.db

        Raises:
            ImportError: If aiosqlite is not installed.
        """
        if not HAS_AIOSQLITE:
            raise ImportError(
                "aiosqlite is required for usage persistence. "
                "Install with: pip install aiosqlite"
            )
        self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH
        self._db_path.parent.mkdir(parents=True, exist_ok=True)
        self._initialized = False
        self._init_lock = asyncio.Lock()
        # Queue for async persistence (events arrive via a sync callback).
        self._queue: asyncio.Queue[UsageEvent] = asyncio.Queue(
            maxsize=self._MAX_QUEUE_SIZE
        )
        self._writer_task: Optional[asyncio.Task] = None
        self._running = False

    async def initialize(self) -> None:
        """Initialize database schema and start the background writer.

        Idempotent: repeated/concurrent calls are serialized by an async
        lock and only the first performs setup.
        """
        async with self._init_lock:
            if self._initialized:
                return
            # Create tables
            async with aiosqlite.connect(str(self._db_path)) as db:
                await db.execute("""
                    CREATE TABLE IF NOT EXISTS usage_events (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        timestamp TEXT NOT NULL,
                        thread_id TEXT NOT NULL,
                        agent_id TEXT,
                        model TEXT NOT NULL,
                        provider TEXT NOT NULL,
                        prompt_tokens INTEGER NOT NULL,
                        completion_tokens INTEGER NOT NULL,
                        total_tokens INTEGER NOT NULL,
                        latency_ms REAL NOT NULL,
                        estimated_cost REAL,
                        metadata TEXT,
                        -- Denormalized for billing queries
                        org_id TEXT,
                        user_id TEXT
                    )
                """)
                # Indexes for common queries
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_timestamp
                    ON usage_events(timestamp)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_org_id
                    ON usage_events(org_id)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_agent_id
                    ON usage_events(agent_id)
                """)
                await db.execute("""
                    CREATE INDEX IF NOT EXISTS idx_usage_model
                    ON usage_events(model)
                """)
                await db.commit()
            # Start background writer
            self._running = True
            self._writer_task = asyncio.create_task(self._writer_loop())
            # Subscribe to usage tracker
            tracker = get_usage_tracker()
            tracker.subscribe(self._on_usage_event)
            self._initialized = True
            logger.info("UsageStore initialized: %s", self._db_path)

    def _on_usage_event(self, event: UsageEvent) -> None:
        """
        Callback for UsageTracker events.

        This runs synchronously from the tracker, so we queue the event
        for async persistence by the background writer.
        """
        try:
            self._queue.put_nowait(event)
        except asyncio.QueueFull:
            # The queue is bounded (_MAX_QUEUE_SIZE); dropping is
            # preferable to blocking the tracker's callback path.
            logger.warning("Usage event queue full, dropping event")

    async def _writer_loop(self) -> None:
        """Background task that writes queued events to the database.

        Collects up to _BATCH_SIZE events per write and flushes at least
        every _BATCH_TIMEOUT seconds. On a write failure the batch is
        retained and retried next iteration (each write commits
        atomically, so no duplicates result from a retry).
        """
        batch: List[UsageEvent] = []
        while self._running or not self._queue.empty():
            try:
                # Collect batch
                try:
                    event = await asyncio.wait_for(
                        self._queue.get(),
                        timeout=self._BATCH_TIMEOUT
                    )
                    batch.append(event)
                    # Drain queue up to batch size
                    while len(batch) < self._BATCH_SIZE:
                        try:
                            batch.append(self._queue.get_nowait())
                        except asyncio.QueueEmpty:
                            break
                except asyncio.TimeoutError:
                    pass
                # Write batch
                if batch:
                    await self._write_batch(batch)
                    batch = []
            except Exception as e:
                logger.error("Usage writer error: %s", e)
                await asyncio.sleep(1.0)

    async def _write_batch(self, events: List[UsageEvent]) -> None:
        """Write a batch of events to the database in one transaction.

        org_id/user_id are denormalized out of event metadata into their
        own columns for efficient billing queries.
        """
        async with aiosqlite.connect(str(self._db_path)) as db:
            await db.executemany(
                """
                INSERT INTO usage_events (
                    timestamp, thread_id, agent_id, model, provider,
                    prompt_tokens, completion_tokens, total_tokens,
                    latency_ms, estimated_cost, metadata, org_id, user_id
                ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """,
                [
                    (
                        e.timestamp,
                        e.thread_id,
                        e.agent_id,
                        e.model,
                        e.provider,
                        e.prompt_tokens,
                        e.completion_tokens,
                        e.total_tokens,
                        e.latency_ms,
                        e.estimated_cost,
                        json.dumps(e.metadata) if e.metadata else None,
                        e.metadata.get("org_id") if e.metadata else None,
                        e.metadata.get("user_id") if e.metadata else None,
                    )
                    for e in events
                ]
            )
            await db.commit()
        logger.debug("Persisted %d usage events", len(events))

    @staticmethod
    def _build_where(
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
    ) -> tuple:
        """Build a SQL WHERE clause and parameter list from filters.

        Returns:
            (where_clause, params) — where_clause is "1=1" when no
            filter is given, so callers can always interpolate it.
        """
        conditions: List[str] = []
        params: List[Any] = []
        for condition, value in (
            ("timestamp >= ?", start_time),
            ("timestamp <= ?", end_time),
            ("org_id = ?", org_id),
            ("user_id = ?", user_id),
            ("agent_id = ?", agent_id),
            ("model = ?", model),
        ):
            # Truthiness (not `is not None`) matches the original
            # behavior: empty strings mean "no filter".
            if value:
                conditions.append(condition)
                params.append(value)
        return (" AND ".join(conditions) if conditions else "1=1", params)

    async def query(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
        limit: int = 1000,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """
        Query historical usage events.

        Args:
            start_time: ISO 8601 timestamp (inclusive)
            end_time: ISO 8601 timestamp (inclusive)
            org_id: Filter by organization
            user_id: Filter by user
            agent_id: Filter by agent
            model: Filter by model
            limit: Max results (default 1000)
            offset: Pagination offset

        Returns:
            List of usage event dicts, newest first.
        """
        where_clause, params = self._build_where(
            start_time=start_time,
            end_time=end_time,
            org_id=org_id,
            user_id=user_id,
            agent_id=agent_id,
            model=model,
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            db.row_factory = aiosqlite.Row
            cursor = await db.execute(
                f"""
                SELECT * FROM usage_events
                WHERE {where_clause}
                ORDER BY timestamp DESC
                LIMIT ? OFFSET ?
                """,
                params + [limit, offset]
            )
            rows = await cursor.fetchall()
            return [
                {
                    "id": row["id"],
                    "timestamp": row["timestamp"],
                    "thread_id": row["thread_id"],
                    "agent_id": row["agent_id"],
                    "model": row["model"],
                    "provider": row["provider"],
                    "prompt_tokens": row["prompt_tokens"],
                    "completion_tokens": row["completion_tokens"],
                    "total_tokens": row["total_tokens"],
                    "latency_ms": row["latency_ms"],
                    "estimated_cost": row["estimated_cost"],
                    "metadata": json.loads(row["metadata"]) if row["metadata"] else {},
                }
                for row in rows
            ]

    async def get_billing_summary(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
    ) -> BillingSummary:
        """
        Get aggregated billing summary for a time period.

        Args:
            start_time: ISO 8601 timestamp (inclusive)
            end_time: ISO 8601 timestamp (inclusive)
            org_id: Filter by organization

        Returns:
            BillingSummary with totals and per-model/per-agent breakdowns
        """
        where_clause, params = self._build_where(
            start_time=start_time, end_time=end_time, org_id=org_id
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            # Overall totals (COALESCE so an empty result set yields 0s)
            cursor = await db.execute(
                f"""
                SELECT
                    COALESCE(SUM(total_tokens), 0) as total_tokens,
                    COALESCE(SUM(prompt_tokens), 0) as prompt_tokens,
                    COALESCE(SUM(completion_tokens), 0) as completion_tokens,
                    COUNT(*) as request_count,
                    COALESCE(SUM(estimated_cost), 0) as total_cost
                FROM usage_events
                WHERE {where_clause}
                """,
                params
            )
            row = await cursor.fetchone()
            totals = {
                "total_tokens": row[0],
                "prompt_tokens": row[1],
                "completion_tokens": row[2],
                "request_count": row[3],
                "total_cost": round(row[4], 4),
            }
            # By model
            cursor = await db.execute(
                f"""
                SELECT
                    model,
                    SUM(total_tokens) as total_tokens,
                    SUM(prompt_tokens) as prompt_tokens,
                    SUM(completion_tokens) as completion_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause}
                GROUP BY model
                """,
                params
            )
            by_model = {
                row[0]: {
                    "total_tokens": row[1],
                    "prompt_tokens": row[2],
                    "completion_tokens": row[3],
                    "request_count": row[4],
                    "total_cost": round(row[5] or 0, 4),
                }
                for row in await cursor.fetchall()
            }
            # By agent (events with NULL agent_id are excluded)
            cursor = await db.execute(
                f"""
                SELECT
                    agent_id,
                    SUM(total_tokens) as total_tokens,
                    SUM(prompt_tokens) as prompt_tokens,
                    SUM(completion_tokens) as completion_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause} AND agent_id IS NOT NULL
                GROUP BY agent_id
                """,
                params
            )
            by_agent = {
                row[0]: {
                    "total_tokens": row[1],
                    "prompt_tokens": row[2],
                    "completion_tokens": row[3],
                    "request_count": row[4],
                    "total_cost": round(row[5] or 0, 4),
                }
                for row in await cursor.fetchall()
            }
            return BillingSummary(
                org_id=org_id,
                start_time=start_time or "",
                end_time=end_time or datetime.now(timezone.utc).isoformat(),
                total_tokens=totals["total_tokens"],
                prompt_tokens=totals["prompt_tokens"],
                completion_tokens=totals["completion_tokens"],
                request_count=totals["request_count"],
                total_cost=totals["total_cost"],
                by_model=by_model,
                by_agent=by_agent,
            )

    async def get_daily_usage(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Get usage aggregated by day for charting.

        Args:
            start_time: ISO 8601 timestamp (inclusive)
            end_time: ISO 8601 timestamp (inclusive)
            org_id: Filter by organization

        Returns:
            List of {date, total_tokens, request_count, total_cost},
            ordered by date ascending.
        """
        where_clause, params = self._build_where(
            start_time=start_time, end_time=end_time, org_id=org_id
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            cursor = await db.execute(
                f"""
                SELECT
                    DATE(timestamp) as date,
                    SUM(total_tokens) as total_tokens,
                    COUNT(*) as request_count,
                    SUM(estimated_cost) as total_cost
                FROM usage_events
                WHERE {where_clause}
                GROUP BY DATE(timestamp)
                ORDER BY date
                """,
                params
            )
            rows = await cursor.fetchall()
            return [
                {
                    "date": row[0],
                    "total_tokens": row[1],
                    "request_count": row[2],
                    "total_cost": round(row[3] or 0, 4),
                }
                for row in rows
            ]

    async def count(
        self,
        *,
        start_time: Optional[str] = None,
        end_time: Optional[str] = None,
        org_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        model: Optional[str] = None,
    ) -> int:
        """Get total count of events matching criteria.

        Accepts the same filters as query() (user_id/agent_id/model are
        backward-compatible additions) so callers can paginate query()
        results against a matching total.
        """
        where_clause, params = self._build_where(
            start_time=start_time,
            end_time=end_time,
            org_id=org_id,
            user_id=user_id,
            agent_id=agent_id,
            model=model,
        )
        async with aiosqlite.connect(str(self._db_path)) as db:
            cursor = await db.execute(
                f"SELECT COUNT(*) FROM usage_events WHERE {where_clause}",
                params
            )
            row = await cursor.fetchone()
            return row[0]

    async def close(self) -> None:
        """Shutdown the store, flushing pending writes."""
        self._running = False
        # Unsubscribe from tracker so no new events arrive
        tracker = get_usage_tracker()
        tracker.unsubscribe(self._on_usage_event)
        # Wait for writer to drain the queue and exit
        if self._writer_task:
            await self._writer_task
            self._writer_task = None
        # Allow this instance to be re-initialized later
        self._initialized = False
        logger.info("UsageStore closed")
# =============================================================================
# Global Instance
# =============================================================================

# Lazily-created module singleton, managed by get_usage_store(),
# close_usage_store() and reset_usage_store().
_store: Optional[UsageStore] = None
# Guards singleton *creation* across threads; schema initialization is
# separately serialized by UsageStore's own async lock.
_store_lock = threading.Lock()
async def get_usage_store(db_path: Optional[str] = None) -> UsageStore:
    """
    Get the global usage store, initializing if needed.

    Creation is protected by double-checked locking across threads;
    initialization is idempotent (UsageStore.initialize() guards itself
    with an async lock), so concurrent awaits are safe.

    Args:
        db_path: Optional path to SQLite database. Only honored on the
            first call that creates the store; ignored afterwards.

    Returns:
        Initialized UsageStore
    """
    global _store
    if _store is None:
        with _store_lock:
            # Re-check under the lock in case another thread created it.
            if _store is None:
                _store = UsageStore(db_path)
    if not _store._initialized:
        await _store.initialize()
    return _store
async def close_usage_store() -> None:
    """Close the global usage store, flushing pending writes.

    No-op when no store has been created.
    """
    global _store
    if _store is None:
        return
    await _store.close()
    _store = None
def reset_usage_store() -> None:
    """Reset global store (for testing).

    NOTE(review): this drops the reference without closing — the old
    store's background writer task and UsageTracker subscription remain
    alive. Prefer close_usage_store() when an event loop is running.
    """
    global _store
    with _store_lock:
        _store = None