diff --git a/tests/test_usage_store.py b/tests/test_usage_store.py
new file mode 100644
index 0000000..3e780e0
--- /dev/null
+++ b/tests/test_usage_store.py
@@ -0,0 +1,402 @@
+"""
+Tests for usage persistence layer.
+
+Tests UsageStore's SQLite persistence, querying, and billing summaries.
+"""
+
+from __future__ import annotations
+
+import asyncio
+import tempfile
+from datetime import datetime, timezone
+from pathlib import Path
+
+import pytest
+
+from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker, reset_usage_tracker
+from xml_pipeline.llm.usage_store import UsageStore, reset_usage_store
+
+
+@pytest.fixture
+def temp_db():
+    """Yield the path for a temporary SQLite database file.
+
+    The path lives inside a fresh temporary directory that is removed once
+    the test finishes, so the database (and anything SQLite creates next to
+    it) does not pile up in the system temp directory between runs.
+    """
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        yield str(Path(tmp_dir) / "usage.db")
+
+
+@pytest.fixture
+async def store(temp_db):
+    """Create and initialize a UsageStore with temp database."""
+    reset_usage_tracker()
+    reset_usage_store()
+
+    s = UsageStore(db_path=temp_db)
+    await s.initialize()
+    yield s
+    await s.close()
+
+
+class TestUsageStoreBasics:
+    """Test basic store operations."""
+
+    async def test_initialize_creates_table(self, temp_db):
+        """Store initialization creates the usage_events table."""
+        reset_usage_store()
+        store = UsageStore(db_path=temp_db)
+        await store.initialize()
+
+        # Query should work (table exists)
+        events = await store.query(limit=10)
+        assert events == []
+
+        await store.close()
+
+    async def test_store_initializes_once(self, store):
+        """Multiple initialize calls are idempotent."""
+        # Should not raise
+        await store.initialize()
+        await store.initialize()
+
+    async def test_empty_query(self, store):
+        """Query on empty store returns empty list."""
+        events = await store.query()
+        assert events == []
+
+    async def test_empty_count(self, store):
+        """Count on empty store returns 0."""
+        count = await store.count()
+        assert count == 0
+
+
+class TestEventPersistence:
+    """Test event persistence via
subscriber pattern.""" + + async def test_event_persisted_via_tracker(self, store): + """Events recorded in tracker are persisted to store.""" + tracker = get_usage_tracker() + + # Record an event + tracker.record( + thread_id="test-thread-1", + agent_id="greeter", + model="grok-4.1", + provider="xai", + prompt_tokens=100, + completion_tokens=50, + latency_ms=250.5, + metadata={"org_id": "org-123"}, + ) + + # Give background writer time to flush + await asyncio.sleep(1.5) + + # Query should find the event + events = await store.query() + assert len(events) == 1 + + event = events[0] + assert event["thread_id"] == "test-thread-1" + assert event["agent_id"] == "greeter" + assert event["model"] == "grok-4.1" + assert event["provider"] == "xai" + assert event["prompt_tokens"] == 100 + assert event["completion_tokens"] == 50 + assert event["total_tokens"] == 150 + assert event["latency_ms"] == 250.5 + + async def test_multiple_events_persisted(self, store): + """Multiple events are all persisted.""" + tracker = get_usage_tracker() + + for i in range(5): + tracker.record( + thread_id=f"thread-{i}", + agent_id="agent", + model="grok-4.1", + provider="xai", + prompt_tokens=100 * (i + 1), + completion_tokens=50 * (i + 1), + latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + events = await store.query() + assert len(events) == 5 + + async def test_event_with_cost_estimate(self, store): + """Estimated cost is persisted.""" + tracker = get_usage_tracker() + + # grok-4.1 has pricing in MODEL_COSTS + tracker.record( + thread_id="cost-thread", + agent_id="agent", + model="grok-4.1", + provider="xai", + prompt_tokens=1000, + completion_tokens=500, + latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + events = await store.query() + assert len(events) == 1 + assert events[0]["estimated_cost"] is not None + assert events[0]["estimated_cost"] > 0 + + +class TestQueryFiltering: + """Test query filtering options.""" + + async def test_filter_by_agent(self, store): + """Filter 
events by agent_id.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="greeter", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + tracker.record( + thread_id="t2", agent_id="shouter", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + events = await store.query(agent_id="greeter") + assert len(events) == 1 + assert events[0]["agent_id"] == "greeter" + + async def test_filter_by_model(self, store): + """Filter events by model.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + tracker.record( + thread_id="t2", agent_id="agent", model="claude-sonnet-4", + provider="anthropic", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + events = await store.query(model="grok-4.1") + assert len(events) == 1 + assert events[0]["model"] == "grok-4.1" + + async def test_filter_by_org_id(self, store): + """Filter events by org_id in metadata.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + metadata={"org_id": "org-A"}, + ) + tracker.record( + thread_id="t2", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + metadata={"org_id": "org-B"}, + ) + + await asyncio.sleep(1.5) + + events = await store.query(org_id="org-A") + assert len(events) == 1 + assert events[0]["thread_id"] == "t1" + + async def test_pagination(self, store): + """Test limit and offset for pagination.""" + tracker = get_usage_tracker() + + for i in range(10): + tracker.record( + thread_id=f"t{i:02d}", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, 
completion_tokens=50, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + # Get first page + page1 = await store.query(limit=3, offset=0) + assert len(page1) == 3 + + # Get second page + page2 = await store.query(limit=3, offset=3) + assert len(page2) == 3 + + # Different events + assert page1[0]["thread_id"] != page2[0]["thread_id"] + + +class TestBillingSummary: + """Test billing summary aggregation.""" + + async def test_billing_totals(self, store): + """Billing summary calculates correct totals.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0, + ) + tracker.record( + thread_id="t2", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + summary = await store.get_billing_summary() + + assert summary.total_tokens == 4500 # 1500 + 3000 + assert summary.prompt_tokens == 3000 + assert summary.completion_tokens == 1500 + assert summary.request_count == 2 + + async def test_billing_by_model(self, store): + """Billing summary includes breakdown by model.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0, + ) + tracker.record( + thread_id="t2", agent_id="agent", model="claude-sonnet-4", + provider="anthropic", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + summary = await store.get_billing_summary() + + assert "grok-4.1" in summary.by_model + assert "claude-sonnet-4" in summary.by_model + assert summary.by_model["grok-4.1"]["total_tokens"] == 1500 + assert summary.by_model["claude-sonnet-4"]["total_tokens"] == 3000 + + async def test_billing_by_agent(self, store): + """Billing summary includes breakdown by agent.""" + tracker = get_usage_tracker() + 
+ tracker.record( + thread_id="t1", agent_id="greeter", model="grok-4.1", + provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0, + ) + tracker.record( + thread_id="t2", agent_id="shouter", model="grok-4.1", + provider="xai", prompt_tokens=500, completion_tokens=250, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + summary = await store.get_billing_summary() + + assert "greeter" in summary.by_agent + assert "shouter" in summary.by_agent + assert summary.by_agent["greeter"]["total_tokens"] == 1500 + assert summary.by_agent["shouter"]["total_tokens"] == 750 + + async def test_billing_filtered_by_org(self, store): + """Billing summary can be filtered by org_id.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0, + metadata={"org_id": "org-A"}, + ) + tracker.record( + thread_id="t2", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0, + metadata={"org_id": "org-B"}, + ) + + await asyncio.sleep(1.5) + + summary_a = await store.get_billing_summary(org_id="org-A") + summary_b = await store.get_billing_summary(org_id="org-B") + + assert summary_a.total_tokens == 1500 + assert summary_b.total_tokens == 3000 + + +class TestDailyUsage: + """Test daily usage aggregation.""" + + async def test_daily_aggregation(self, store): + """Daily usage aggregates by date.""" + tracker = get_usage_tracker() + + # Record multiple events (all same day since timestamps are auto-generated) + for i in range(3): + tracker.record( + thread_id=f"t{i}", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + days = await store.get_daily_usage() + + assert len(days) == 1 # All same day + assert days[0]["total_tokens"] == 450 + assert days[0]["request_count"] == 3 + + async def 
test_daily_returns_date(self, store): + """Daily usage includes date field.""" + tracker = get_usage_tracker() + + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + + await asyncio.sleep(1.5) + + days = await store.get_daily_usage() + + assert len(days) == 1 + # Date should be in YYYY-MM-DD format + assert len(days[0]["date"]) == 10 + assert days[0]["date"].count("-") == 2 + + +class TestStoreLifecycle: + """Test store lifecycle management.""" + + async def test_close_flushes_pending(self, temp_db): + """Closing store flushes pending writes.""" + reset_usage_tracker() + reset_usage_store() + + store = UsageStore(db_path=temp_db) + await store.initialize() + + tracker = get_usage_tracker() + tracker.record( + thread_id="t1", agent_id="agent", model="grok-4.1", + provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0, + ) + + # Close immediately - should flush + await store.close() + + # Reopen and verify data persisted + store2 = UsageStore(db_path=temp_db) + await store2.initialize() + + events = await store2.query() + assert len(events) == 1 + + await store2.close() diff --git a/xml_pipeline/llm/__init__.py b/xml_pipeline/llm/__init__.py index 01aeb02..8bf5539 100644 --- a/xml_pipeline/llm/__init__.py +++ b/xml_pipeline/llm/__init__.py @@ -30,6 +30,20 @@ Usage Tracking: # Query totals totals = tracker.get_totals() + +Usage Persistence (for billing): + from xml_pipeline.llm import get_usage_store + + store = await get_usage_store() + + # Query historical usage + events = await store.query( + start_time="2025-01-01T00:00:00Z", + org_id="org-123", + ) + + # Get billing summary + summary = await store.get_billing_summary(org_id="org-123") """ from xml_pipeline.llm.router import ( @@ -46,6 +60,13 @@ from xml_pipeline.llm.usage_tracker import ( get_usage_tracker, reset_usage_tracker, ) +from xml_pipeline.llm.usage_store import ( + UsageStore, + 
BillingSummary,
+    get_usage_store,
+    close_usage_store,
+    reset_usage_store,
+)
 
 __all__ = [
     # Router
@@ -58,9 +79,15 @@ __all__ = [
     "LLMRequest",
     "LLMResponse",
     "BackendError",
-    # Usage tracking
+    # Usage tracking (in-memory)
     "UsageTracker",
     "UsageEvent",
     "get_usage_tracker",
     "reset_usage_tracker",
+    # Usage persistence (SQLite)
+    "UsageStore",
+    "BillingSummary",
+    "get_usage_store",
+    "close_usage_store",
+    "reset_usage_store",
 ]
diff --git a/xml_pipeline/llm/usage_store.py b/xml_pipeline/llm/usage_store.py
new file mode 100644
index 0000000..f33a19d
--- /dev/null
+++ b/xml_pipeline/llm/usage_store.py
@@ -0,0 +1,599 @@
+"""
+Usage Store — Persistent storage for billing and usage analytics.
+
+Stores UsageEvents to SQLite (via aiosqlite) for:
+- Historical billing queries
+- Usage analytics and reporting
+- Audit trails
+
+The store auto-subscribes to UsageTracker on initialization,
+persisting all events transparently.
+
+Example:
+    from xml_pipeline.llm.usage_store import get_usage_store
+
+    store = await get_usage_store()
+
+    # Query historical usage
+    events = await store.query(
+        start_time="2025-01-01T00:00:00Z",
+        end_time="2025-01-31T23:59:59Z",
+        org_id="org-123",
+    )
+
+    # Get billing summary
+    summary = await store.get_billing_summary(
+        org_id="org-123",
+        start_time="2025-01-01T00:00:00Z",
+    )
+"""
+
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+import os
+import threading
+from dataclasses import dataclass, asdict
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+try:
+    import aiosqlite
+    HAS_AIOSQLITE = True
+except ImportError:
+    HAS_AIOSQLITE = False  # UsageStore() raises ImportError when this is False
+
+from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker
+
+logger = logging.getLogger(__name__)
+
+
+# Default database path
+DEFAULT_DB_PATH = Path.home() / ".xml-pipeline" / "usage.db"
+
+
+@dataclass
+class BillingSummary:
+    """Aggregated billing data for a time
period.""" + org_id: Optional[str] + start_time: str + end_time: str + total_tokens: int + prompt_tokens: int + completion_tokens: int + request_count: int + total_cost: float + by_model: Dict[str, Dict[str, Any]] + by_agent: Dict[str, Dict[str, Any]] + + +class UsageStore: + """ + Persistent storage for usage events. + + Uses SQLite by default, with async I/O via aiosqlite. + Automatically subscribes to UsageTracker to capture all events. + """ + + def __init__(self, db_path: Optional[str] = None): + """ + Initialize the usage store. + + Args: + db_path: Path to SQLite database. Defaults to ~/.xml-pipeline/usage.db + """ + if not HAS_AIOSQLITE: + raise ImportError( + "aiosqlite is required for usage persistence. " + "Install with: pip install aiosqlite" + ) + + self._db_path = Path(db_path) if db_path else DEFAULT_DB_PATH + self._db_path.parent.mkdir(parents=True, exist_ok=True) + + self._initialized = False + self._init_lock = asyncio.Lock() + + # Queue for async persistence (events come from sync callback) + self._queue: asyncio.Queue[UsageEvent] = asyncio.Queue() + self._writer_task: Optional[asyncio.Task] = None + self._running = False + + async def initialize(self) -> None: + """Initialize database schema and start background writer.""" + async with self._init_lock: + if self._initialized: + return + + # Create tables + async with aiosqlite.connect(str(self._db_path)) as db: + await db.execute(""" + CREATE TABLE IF NOT EXISTS usage_events ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + thread_id TEXT NOT NULL, + agent_id TEXT, + model TEXT NOT NULL, + provider TEXT NOT NULL, + prompt_tokens INTEGER NOT NULL, + completion_tokens INTEGER NOT NULL, + total_tokens INTEGER NOT NULL, + latency_ms REAL NOT NULL, + estimated_cost REAL, + metadata TEXT, + + -- Denormalized for billing queries + org_id TEXT, + user_id TEXT + ) + """) + + # Indexes for common queries + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_usage_timestamp + ON 
usage_events(timestamp) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_usage_org_id + ON usage_events(org_id) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_usage_agent_id + ON usage_events(agent_id) + """) + await db.execute(""" + CREATE INDEX IF NOT EXISTS idx_usage_model + ON usage_events(model) + """) + + await db.commit() + + # Start background writer + self._running = True + self._writer_task = asyncio.create_task(self._writer_loop()) + + # Subscribe to usage tracker + tracker = get_usage_tracker() + tracker.subscribe(self._on_usage_event) + + self._initialized = True + logger.info(f"UsageStore initialized: {self._db_path}") + + def _on_usage_event(self, event: UsageEvent) -> None: + """ + Callback for UsageTracker events. + + This runs synchronously from the tracker, so we queue + the event for async persistence. + """ + try: + self._queue.put_nowait(event) + except asyncio.QueueFull: + logger.warning("Usage event queue full, dropping event") + + async def _writer_loop(self) -> None: + """Background task that writes queued events to database.""" + batch: List[UsageEvent] = [] + batch_timeout = 1.0 # Flush every second or 100 events + + while self._running or not self._queue.empty(): + try: + # Collect batch + try: + event = await asyncio.wait_for( + self._queue.get(), + timeout=batch_timeout + ) + batch.append(event) + + # Drain queue up to batch size + while len(batch) < 100: + try: + event = self._queue.get_nowait() + batch.append(event) + except asyncio.QueueEmpty: + break + + except asyncio.TimeoutError: + pass + + # Write batch + if batch: + await self._write_batch(batch) + batch = [] + + except Exception as e: + logger.error(f"Usage writer error: {e}") + await asyncio.sleep(1.0) + + async def _write_batch(self, events: List[UsageEvent]) -> None: + """Write a batch of events to database.""" + async with aiosqlite.connect(str(self._db_path)) as db: + await db.executemany( + """ + INSERT INTO usage_events ( + timestamp, 
thread_id, agent_id, model, provider, + prompt_tokens, completion_tokens, total_tokens, + latency_ms, estimated_cost, metadata, org_id, user_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + [ + ( + e.timestamp, + e.thread_id, + e.agent_id, + e.model, + e.provider, + e.prompt_tokens, + e.completion_tokens, + e.total_tokens, + e.latency_ms, + e.estimated_cost, + json.dumps(e.metadata) if e.metadata else None, + e.metadata.get("org_id") if e.metadata else None, + e.metadata.get("user_id") if e.metadata else None, + ) + for e in events + ] + ) + await db.commit() + + logger.debug(f"Persisted {len(events)} usage events") + + async def query( + self, + *, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + org_id: Optional[str] = None, + user_id: Optional[str] = None, + agent_id: Optional[str] = None, + model: Optional[str] = None, + limit: int = 1000, + offset: int = 0, + ) -> List[Dict[str, Any]]: + """ + Query historical usage events. + + Args: + start_time: ISO 8601 timestamp (inclusive) + end_time: ISO 8601 timestamp (inclusive) + org_id: Filter by organization + user_id: Filter by user + agent_id: Filter by agent + model: Filter by model + limit: Max results (default 1000) + offset: Pagination offset + + Returns: + List of usage event dicts + """ + conditions = [] + params = [] + + if start_time: + conditions.append("timestamp >= ?") + params.append(start_time) + if end_time: + conditions.append("timestamp <= ?") + params.append(end_time) + if org_id: + conditions.append("org_id = ?") + params.append(org_id) + if user_id: + conditions.append("user_id = ?") + params.append(user_id) + if agent_id: + conditions.append("agent_id = ?") + params.append(agent_id) + if model: + conditions.append("model = ?") + params.append(model) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + async with aiosqlite.connect(str(self._db_path)) as db: + db.row_factory = aiosqlite.Row + cursor = await db.execute( + f""" + SELECT * 
FROM usage_events + WHERE {where_clause} + ORDER BY timestamp DESC + LIMIT ? OFFSET ? + """, + params + [limit, offset] + ) + rows = await cursor.fetchall() + + return [ + { + "id": row["id"], + "timestamp": row["timestamp"], + "thread_id": row["thread_id"], + "agent_id": row["agent_id"], + "model": row["model"], + "provider": row["provider"], + "prompt_tokens": row["prompt_tokens"], + "completion_tokens": row["completion_tokens"], + "total_tokens": row["total_tokens"], + "latency_ms": row["latency_ms"], + "estimated_cost": row["estimated_cost"], + "metadata": json.loads(row["metadata"]) if row["metadata"] else {}, + } + for row in rows + ] + + async def get_billing_summary( + self, + *, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + org_id: Optional[str] = None, + ) -> BillingSummary: + """ + Get aggregated billing summary for a time period. + + Args: + start_time: ISO 8601 timestamp (inclusive) + end_time: ISO 8601 timestamp (inclusive) + org_id: Filter by organization + + Returns: + BillingSummary with totals and breakdowns + """ + conditions = [] + params = [] + + if start_time: + conditions.append("timestamp >= ?") + params.append(start_time) + if end_time: + conditions.append("timestamp <= ?") + params.append(end_time) + if org_id: + conditions.append("org_id = ?") + params.append(org_id) + + where_clause = " AND ".join(conditions) if conditions else "1=1" + + async with aiosqlite.connect(str(self._db_path)) as db: + # Overall totals + cursor = await db.execute( + f""" + SELECT + COALESCE(SUM(total_tokens), 0) as total_tokens, + COALESCE(SUM(prompt_tokens), 0) as prompt_tokens, + COALESCE(SUM(completion_tokens), 0) as completion_tokens, + COUNT(*) as request_count, + COALESCE(SUM(estimated_cost), 0) as total_cost + FROM usage_events + WHERE {where_clause} + """, + params + ) + row = await cursor.fetchone() + + totals = { + "total_tokens": row[0], + "prompt_tokens": row[1], + "completion_tokens": row[2], + "request_count": row[3], + 
"total_cost": round(row[4], 4), + } + + # By model + cursor = await db.execute( + f""" + SELECT + model, + SUM(total_tokens) as total_tokens, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + COUNT(*) as request_count, + SUM(estimated_cost) as total_cost + FROM usage_events + WHERE {where_clause} + GROUP BY model + """, + params + ) + by_model = { + row[0]: { + "total_tokens": row[1], + "prompt_tokens": row[2], + "completion_tokens": row[3], + "request_count": row[4], + "total_cost": round(row[5] or 0, 4), + } + for row in await cursor.fetchall() + } + + # By agent + cursor = await db.execute( + f""" + SELECT + agent_id, + SUM(total_tokens) as total_tokens, + SUM(prompt_tokens) as prompt_tokens, + SUM(completion_tokens) as completion_tokens, + COUNT(*) as request_count, + SUM(estimated_cost) as total_cost + FROM usage_events + WHERE {where_clause} AND agent_id IS NOT NULL + GROUP BY agent_id + """, + params + ) + by_agent = { + row[0]: { + "total_tokens": row[1], + "prompt_tokens": row[2], + "completion_tokens": row[3], + "request_count": row[4], + "total_cost": round(row[5] or 0, 4), + } + for row in await cursor.fetchall() + } + + return BillingSummary( + org_id=org_id, + start_time=start_time or "", + end_time=end_time or datetime.now(timezone.utc).isoformat(), + total_tokens=totals["total_tokens"], + prompt_tokens=totals["prompt_tokens"], + completion_tokens=totals["completion_tokens"], + request_count=totals["request_count"], + total_cost=totals["total_cost"], + by_model=by_model, + by_agent=by_agent, + ) + + async def get_daily_usage( + self, + *, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + org_id: Optional[str] = None, + ) -> List[Dict[str, Any]]: + """ + Get usage aggregated by day for charting. 
+
+        Returns:
+            List of {date, total_tokens, request_count, total_cost}
+        """
+        conditions = []
+        params = []
+
+        if start_time:
+            conditions.append("timestamp >= ?")
+            params.append(start_time)
+        if end_time:
+            conditions.append("timestamp <= ?")
+            params.append(end_time)
+        if org_id:
+            conditions.append("org_id = ?")
+            params.append(org_id)
+
+        where_clause = " AND ".join(conditions) if conditions else "1=1"
+
+        async with aiosqlite.connect(str(self._db_path)) as db:
+            cursor = await db.execute(
+                f"""
+                SELECT
+                    DATE(timestamp) as date,
+                    SUM(total_tokens) as total_tokens,
+                    COUNT(*) as request_count,
+                    SUM(estimated_cost) as total_cost
+                FROM usage_events
+                WHERE {where_clause}
+                GROUP BY DATE(timestamp)
+                ORDER BY date
+                """,
+                params
+            )
+            rows = await cursor.fetchall()
+
+            return [
+                {
+                    "date": row[0],
+                    "total_tokens": row[1],
+                    "request_count": row[2],
+                    "total_cost": round(row[3] or 0, 4),
+                }
+                for row in rows
+            ]
+
+    async def count(
+        self,
+        *,
+        start_time: Optional[str] = None,
+        end_time: Optional[str] = None,
+        org_id: Optional[str] = None,
+        user_id: Optional[str] = None,
+        agent_id: Optional[str] = None,
+        model: Optional[str] = None,
+    ) -> int:
+        """Count events matching the filters; accepts every filter query() does."""
+        filters = [  # (SQL condition, bound value); falsy values mean "no filter"
+            ("timestamp >= ?", start_time),  # ISO 8601, inclusive
+            ("timestamp <= ?", end_time),  # ISO 8601, inclusive
+            ("org_id = ?", org_id),
+            ("user_id = ?", user_id),
+            ("agent_id = ?", agent_id),
+            ("model = ?", model),
+        ]
+        conditions = [sql for sql, value in filters if value]
+        params = [value for _, value in filters if value]
+        where_clause = " AND ".join(conditions) if conditions else "1=1"
+
+        async with aiosqlite.connect(str(self._db_path)) as db:
+            cursor = await db.execute(
+                f"SELECT COUNT(*) FROM usage_events WHERE {where_clause}",
+                params
+            )
+            row = await cursor.fetchone()
+            return row[0]
+
+    async def close(self) -> None:
+        """Shutdown the store, flushing pending writes."""
+        self._running = False
+
+        # Unsubscribe from tracker
+        tracker = get_usage_tracker()
+        tracker.unsubscribe(self._on_usage_event)
+
+        # Wait for writer to finish
+        if self._writer_task:
+            await self._writer_task
+            self._writer_task = None
+        self._initialized = False  # let a later initialize() restart writer + subscription
+        logger.info("UsageStore closed")
+
+
+# =============================================================================
+# Global Instance
+# =============================================================================
+
+_store: Optional[UsageStore] = None
+_store_lock = threading.Lock()
+
+
+async def get_usage_store(db_path: Optional[str] = None) -> UsageStore:
+    """
+    Get the global usage store, initializing if needed.
+
+    Args:
+        db_path: Optional path to SQLite database
+
+    Returns:
+        Initialized UsageStore
+    """
+    global _store
+
+    if _store is None:
+        with _store_lock:
+            if _store is None:
+                _store = UsageStore(db_path)
+
+    if not _store._initialized:
+        await _store.initialize()
+
+    return _store
+
+
+async def close_usage_store() -> None:
+    """Close the global usage store."""
+    global _store
+    if _store is not None:
+        await _store.close()
+        _store = None
+
+
+def reset_usage_store() -> None:
+    """Reset global store (for testing)."""
+    global _store
+    with _store_lock:
+        _store = None
diff --git a/xml_pipeline/server/api.py b/xml_pipeline/server/api.py
index 45b2349..f265741 100644
--- a/xml_pipeline/server/api.py
+++ b/xml_pipeline/server/api.py
@@ -19,9 +19,12 @@ from xml_pipeline.server.models import (
     AgentInfo,
     AgentListResponse,
     AgentUsageInfo,
+    BillingSummaryResponse,
     CapabilityDetail,
     CapabilityInfo,
     CapabilityListResponse,
+    DailyUsagePoint,
+    DailyUsageResponse,
     ErrorResponse,
     InjectRequest,
     InjectResponse,
@@ -33,6 +36,8 @@ from xml_pipeline.server.models import (
     ThreadInfo,
     ThreadListResponse,
     ThreadStatus,
+    UsageEventInfo,
+    UsageHistoryResponse,
     UsageOverview,
     UsageResponse,
     UsageTotals,
@@ -425,6 +430,147 @@ def create_router(state: "ServerState") -> APIRouter:
         reset_usage_tracker()
         return {"success": True, "message": "Usage tracking reset"}
 
+    # =========================================================================
+    # Usage History Endpoints (Persistent)
+    # =========================================================================
+
+    @router.get("/usage/history", response_model=UsageHistoryResponse)
+    async def get_usage_history(
+        start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
+        end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
+        org_id: Optional[str] = Query(None, description="Filter by organization"),
+        agent_id: Optional[str] = Query(None, description="Filter by agent"),
+        model: Optional[str] = Query(None, description="Filter by model"),
+        limit: int = Query(100, ge=1, le=1000),
+        offset: int = Query(0, ge=0),
+    ) -> UsageHistoryResponse:
+        """
+        Query historical usage events from persistent storage.
+
+        Use for billing reconciliation, audit trails, and detailed analytics.
+        Events are stored in SQLite and persist across restarts.
+
+        ``count`` is the size of the returned page; ``total`` is the number of
+        stored events matching the same filters, ignoring limit and offset.
+        """
+        from xml_pipeline.llm.usage_store import get_usage_store
+
+        store = await get_usage_store()
+
+        events = await store.query(
+            start_time=start_time,
+            end_time=end_time,
+            org_id=org_id,
+            agent_id=agent_id,
+            model=model,
+            limit=limit,
+            offset=offset,
+        )
+
+        # Count with the same filters as the page query so that ``total``
+        # describes the filtered result set the client is paging through.
+        total = await store.count(
+            start_time=start_time,
+            end_time=end_time,
+            org_id=org_id,
+            agent_id=agent_id,
+            model=model,
+        )
+
+        return UsageHistoryResponse(
+            events=[
+                UsageEventInfo(
+                    id=e["id"],
+                    timestamp=e["timestamp"],
+                    thread_id=e["thread_id"],
+                    agent_id=e.get("agent_id"),
+                    model=e["model"],
+                    provider=e["provider"],
+                    prompt_tokens=e["prompt_tokens"],
+                    completion_tokens=e["completion_tokens"],
+                    total_tokens=e["total_tokens"],
+                    latency_ms=e["latency_ms"],
+                    estimated_cost=e.get("estimated_cost"),
+                    metadata=e.get("metadata", {}),
+                )
+                for e in events
+            ],
+            count=len(events),
+            total=total,
+            offset=offset,
+            limit=limit,
+        )
+
+    @router.get("/usage/billing", response_model=BillingSummaryResponse)
+    async def get_billing_summary(
+        start_time: Optional[str] = Query(None, description="ISO 8601 start time"),
+        end_time: Optional[str] = Query(None, description="ISO 8601 end time"),
+
org_id: Optional[str] = Query(None, description="Filter by organization"), + ) -> BillingSummaryResponse: + """ + Get aggregated billing summary for a time period. + + Returns total tokens, costs, and breakdowns by model and agent. + Use for invoicing and cost analysis. + """ + from xml_pipeline.llm.usage_store import get_usage_store + + store = await get_usage_store() + + summary = await store.get_billing_summary( + start_time=start_time, + end_time=end_time, + org_id=org_id, + ) + + return BillingSummaryResponse( + org_id=summary.org_id, + start_time=summary.start_time, + end_time=summary.end_time, + total_tokens=summary.total_tokens, + prompt_tokens=summary.prompt_tokens, + completion_tokens=summary.completion_tokens, + request_count=summary.request_count, + total_cost=summary.total_cost, + by_model=summary.by_model, + by_agent=summary.by_agent, + ) + + @router.get("/usage/daily", response_model=DailyUsageResponse) + async def get_daily_usage( + start_time: Optional[str] = Query(None, description="ISO 8601 start time"), + end_time: Optional[str] = Query(None, description="ISO 8601 end time"), + org_id: Optional[str] = Query(None, description="Filter by organization"), + ) -> DailyUsageResponse: + """ + Get usage aggregated by day for charting. + + Returns daily totals for tokens, requests, and costs. + Useful for dashboards and trend analysis. 
+ """ + from xml_pipeline.llm.usage_store import get_usage_store + + store = await get_usage_store() + + days = await store.get_daily_usage( + start_time=start_time, + end_time=end_time, + org_id=org_id, + ) + + return DailyUsageResponse( + days=[ + DailyUsagePoint( + date=d["date"], + total_tokens=d["total_tokens"], + request_count=d["request_count"], + total_cost=d["total_cost"], + ) + for d in days + ], + count=len(days), + ) + # ========================================================================= # Control Endpoints # ========================================================================= diff --git a/xml_pipeline/server/models.py b/xml_pipeline/server/models.py index 7aa6107..232e923 100644 --- a/xml_pipeline/server/models.py +++ b/xml_pipeline/server/models.py @@ -368,3 +368,66 @@ class ThreadBudgetListResponse(CamelModel): threads: List[ThreadBudgetInfo] count: int default_max_tokens: int = Field(alias="defaultMaxTokens") + + +# ============================================================================= +# Usage History Models (Persistent) +# ============================================================================= + + +class UsageEventInfo(CamelModel): + """A single usage event from history.""" + + id: int + timestamp: str + thread_id: str = Field(alias="threadId") + agent_id: Optional[str] = Field(None, alias="agentId") + model: str + provider: str + prompt_tokens: int = Field(alias="promptTokens") + completion_tokens: int = Field(alias="completionTokens") + total_tokens: int = Field(alias="totalTokens") + latency_ms: float = Field(alias="latencyMs") + estimated_cost: Optional[float] = Field(None, alias="estimatedCost") + metadata: dict = Field(default_factory=dict) + + +class UsageHistoryResponse(CamelModel): + """Response for GET /usage/history.""" + + events: List[UsageEventInfo] + count: int + total: int + offset: int + limit: int + + +class BillingSummaryResponse(CamelModel): + """Response for GET /usage/billing.""" + + org_id: 
Optional[str] = Field(None, alias="orgId") + start_time: str = Field(alias="startTime") + end_time: str = Field(alias="endTime") + total_tokens: int = Field(alias="totalTokens") + prompt_tokens: int = Field(alias="promptTokens") + completion_tokens: int = Field(alias="completionTokens") + request_count: int = Field(alias="requestCount") + total_cost: float = Field(alias="totalCost") + by_model: dict = Field(default_factory=dict, alias="byModel") + by_agent: dict = Field(default_factory=dict, alias="byAgent") + + +class DailyUsagePoint(CamelModel): + """A single day's usage for charting.""" + + date: str + total_tokens: int = Field(alias="totalTokens") + request_count: int = Field(alias="requestCount") + total_cost: float = Field(alias="totalCost") + + +class DailyUsageResponse(CamelModel): + """Response for GET /usage/daily.""" + + days: List[DailyUsagePoint] + count: int