xml-pipeline/tests/test_usage_store.py
dullfig d0d78a9f70 Add usage persistence for billing (SQLite)
- UsageStore with async SQLite persistence via aiosqlite
- Background batch writer for non-blocking event persistence
- Auto-subscribes to UsageTracker for transparent capture
- Query methods: query(), get_billing_summary(), get_daily_usage()
- REST API endpoints: /usage/history, /usage/billing, /usage/daily
- Filtering by org_id, agent_id, model, time range
- 18 new tests for persistence layer

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-27 21:58:22 -08:00

398 lines
12 KiB
Python

"""
Tests for usage persistence layer.
Tests UsageStore's SQLite persistence, querying, and billing summaries.
"""
from __future__ import annotations
import asyncio
import tempfile
from datetime import datetime, timezone
from pathlib import Path
import pytest
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker, reset_usage_tracker
from xml_pipeline.llm.usage_store import UsageStore, reset_usage_store
@pytest.fixture
def temp_db():
    """Yield the path of a throwaway SQLite database file.

    The path points inside a private temporary directory rather than at a
    ``NamedTemporaryFile(delete=False)``, which was never removed and was
    still held open while the test ran. Removing the whole directory on
    teardown also disposes of any sidecar files SQLite may create next to
    the database (journal / WAL).

    Yields:
        str: Filesystem path for the database. The file itself does not
        exist yet; SQLite creates it on first connect.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield str(Path(tmp_dir) / "usage.db")
    # Leaving the ``with`` block deletes the directory and everything in it.
@pytest.fixture
async def store(temp_db):
    """Provide an initialized UsageStore backed by the temp database.

    The usage tracker and usage store are reset first, then a fresh store
    is built on ``temp_db`` and initialized. The store is closed once the
    test that requested it has finished.
    """
    # Reset tracker and store state before building a new store.
    reset_usage_tracker()
    reset_usage_store()

    usage_store = UsageStore(db_path=temp_db)
    await usage_store.initialize()

    yield usage_store

    # Teardown: runs after the requesting test completes.
    await usage_store.close()
class TestUsageStoreBasics:
    """Test basic store operations."""

    async def test_initialize_creates_table(self, temp_db):
        """Store initialization creates the usage_events table."""
        reset_usage_store()
        store = UsageStore(db_path=temp_db)
        await store.initialize()
        try:
            # Query should work (table exists)
            events = await store.query(limit=10)
            assert events == []
        finally:
            # Close even when the assertion fails so the store (and its
            # database file) is not left open for the rest of the session.
            await store.close()

    async def test_store_initializes_once(self, store):
        """Multiple initialize calls are idempotent."""
        # The fixture already initialized the store; further calls
        # should not raise.
        await store.initialize()
        await store.initialize()

    async def test_empty_query(self, store):
        """Query on empty store returns empty list."""
        events = await store.query()
        assert events == []

    async def test_empty_count(self, store):
        """Count on empty store returns 0."""
        count = await store.count()
        assert count == 0
class TestEventPersistence:
    """Check that events recorded on the tracker reach the store."""

    async def test_event_persisted_via_tracker(self, store):
        """A single recorded event is persisted with its fields intact."""
        usage_tracker = get_usage_tracker()
        usage_tracker.record(
            thread_id="test-thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250.5,
            metadata={"org_id": "org-123"},
        )

        # Give the background writer time to flush before querying.
        await asyncio.sleep(1.5)

        rows = await store.query()
        assert len(rows) == 1

        # Expected column values, checked one by one against the stored row.
        expected_fields = {
            "thread_id": "test-thread-1",
            "agent_id": "greeter",
            "model": "grok-4.1",
            "provider": "xai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "total_tokens": 150,
            "latency_ms": 250.5,
        }
        row = rows[0]
        for field, expected in expected_fields.items():
            assert row[field] == expected

    async def test_multiple_events_persisted(self, store):
        """Every one of several recorded events is persisted."""
        usage_tracker = get_usage_tracker()

        # Five events; token counts grow with the 1-based position.
        for position in range(1, 6):
            usage_tracker.record(
                thread_id=f"thread-{position - 1}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100 * position,
                completion_tokens=50 * position,
                latency_ms=100.0,
            )

        await asyncio.sleep(1.5)

        rows = await store.query()
        assert len(rows) == 5

    async def test_event_with_cost_estimate(self, store):
        """The estimated cost of an event is persisted."""
        usage_tracker = get_usage_tracker()

        # grok-4.1 has pricing in MODEL_COSTS
        usage_tracker.record(
            thread_id="cost-thread",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )

        await asyncio.sleep(1.5)

        rows = await store.query()
        assert len(rows) == 1

        stored_cost = rows[0]["estimated_cost"]
        assert stored_cost is not None
        assert stored_cost > 0
class TestQueryFiltering:
    """Exercise the filter and pagination arguments of ``store.query``."""

    async def test_filter_by_agent(self, store):
        """Only events of the requested agent_id are returned."""
        usage_tracker = get_usage_tracker()

        # Both events share everything except thread and agent.
        shared = {
            "model": "grok-4.1",
            "provider": "xai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "latency_ms": 100.0,
        }
        usage_tracker.record(thread_id="t1", agent_id="greeter", **shared)
        usage_tracker.record(thread_id="t2", agent_id="shouter", **shared)

        await asyncio.sleep(1.5)

        matches = await store.query(agent_id="greeter")
        assert len(matches) == 1
        assert matches[0]["agent_id"] == "greeter"

    async def test_filter_by_model(self, store):
        """Only events of the requested model are returned."""
        usage_tracker = get_usage_tracker()

        shared = {
            "agent_id": "agent",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "latency_ms": 100.0,
        }
        usage_tracker.record(
            thread_id="t1", model="grok-4.1", provider="xai", **shared
        )
        usage_tracker.record(
            thread_id="t2", model="claude-sonnet-4", provider="anthropic", **shared
        )

        await asyncio.sleep(1.5)

        matches = await store.query(model="grok-4.1")
        assert len(matches) == 1
        assert matches[0]["model"] == "grok-4.1"

    async def test_filter_by_org_id(self, store):
        """Only events whose metadata carries the requested org_id are returned."""
        usage_tracker = get_usage_tracker()

        shared = {
            "agent_id": "agent",
            "model": "grok-4.1",
            "provider": "xai",
            "prompt_tokens": 100,
            "completion_tokens": 50,
            "latency_ms": 100.0,
        }
        usage_tracker.record(
            thread_id="t1", metadata={"org_id": "org-A"}, **shared
        )
        usage_tracker.record(
            thread_id="t2", metadata={"org_id": "org-B"}, **shared
        )

        await asyncio.sleep(1.5)

        matches = await store.query(org_id="org-A")
        assert len(matches) == 1
        assert matches[0]["thread_id"] == "t1"

    async def test_pagination(self, store):
        """``limit`` and ``offset`` page through the stored events."""
        usage_tracker = get_usage_tracker()

        for index in range(10):
            usage_tracker.record(
                thread_id=f"t{index:02d}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100,
                completion_tokens=50,
                latency_ms=100.0,
            )

        await asyncio.sleep(1.5)

        # First page
        first_page = await store.query(limit=3, offset=0)
        assert len(first_page) == 3

        # Second page
        second_page = await store.query(limit=3, offset=3)
        assert len(second_page) == 3

        # The two pages must start with different events.
        assert first_page[0]["thread_id"] != second_page[0]["thread_id"]
class TestBillingSummary:
    """Check the aggregates returned by ``store.get_billing_summary``."""

    async def test_billing_totals(self, store):
        """Token and request totals are summed across all events."""
        usage_tracker = get_usage_tracker()

        shared = {
            "agent_id": "agent",
            "model": "grok-4.1",
            "provider": "xai",
            "latency_ms": 100.0,
        }
        usage_tracker.record(
            thread_id="t1", prompt_tokens=1000, completion_tokens=500, **shared
        )
        usage_tracker.record(
            thread_id="t2", prompt_tokens=2000, completion_tokens=1000, **shared
        )

        await asyncio.sleep(1.5)

        totals = await store.get_billing_summary()
        assert totals.total_tokens == 4500  # 1500 + 3000
        assert totals.prompt_tokens == 3000
        assert totals.completion_tokens == 1500
        assert totals.request_count == 2

    async def test_billing_by_model(self, store):
        """The summary carries a per-model breakdown."""
        usage_tracker = get_usage_tracker()

        usage_tracker.record(
            thread_id="t1",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )
        usage_tracker.record(
            thread_id="t2",
            agent_id="agent",
            model="claude-sonnet-4",
            provider="anthropic",
            prompt_tokens=2000,
            completion_tokens=1000,
            latency_ms=100.0,
        )

        await asyncio.sleep(1.5)

        billing = await store.get_billing_summary()
        assert "grok-4.1" in billing.by_model
        assert "claude-sonnet-4" in billing.by_model
        assert billing.by_model["grok-4.1"]["total_tokens"] == 1500
        assert billing.by_model["claude-sonnet-4"]["total_tokens"] == 3000

    async def test_billing_by_agent(self, store):
        """The summary carries a per-agent breakdown."""
        usage_tracker = get_usage_tracker()

        shared = {"model": "grok-4.1", "provider": "xai", "latency_ms": 100.0}
        usage_tracker.record(
            thread_id="t1",
            agent_id="greeter",
            prompt_tokens=1000,
            completion_tokens=500,
            **shared,
        )
        usage_tracker.record(
            thread_id="t2",
            agent_id="shouter",
            prompt_tokens=500,
            completion_tokens=250,
            **shared,
        )

        await asyncio.sleep(1.5)

        billing = await store.get_billing_summary()
        assert "greeter" in billing.by_agent
        assert "shouter" in billing.by_agent
        assert billing.by_agent["greeter"]["total_tokens"] == 1500
        assert billing.by_agent["shouter"]["total_tokens"] == 750

    async def test_billing_filtered_by_org(self, store):
        """Passing org_id restricts the summary to that organisation."""
        usage_tracker = get_usage_tracker()

        shared = {
            "agent_id": "agent",
            "model": "grok-4.1",
            "provider": "xai",
            "latency_ms": 100.0,
        }
        usage_tracker.record(
            thread_id="t1",
            prompt_tokens=1000,
            completion_tokens=500,
            metadata={"org_id": "org-A"},
            **shared,
        )
        usage_tracker.record(
            thread_id="t2",
            prompt_tokens=2000,
            completion_tokens=1000,
            metadata={"org_id": "org-B"},
            **shared,
        )

        await asyncio.sleep(1.5)

        billing_a = await store.get_billing_summary(org_id="org-A")
        billing_b = await store.get_billing_summary(org_id="org-B")

        assert billing_a.total_tokens == 1500
        assert billing_b.total_tokens == 3000
class TestDailyUsage:
    """Test daily usage aggregation."""

    async def test_daily_aggregation(self, store):
        """Daily usage aggregates by date."""
        tracker = get_usage_tracker()

        # Record multiple events (all same day since timestamps are
        # auto-generated)
        for i in range(3):
            tracker.record(
                thread_id=f"t{i}", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50,
                latency_ms=100.0,
            )

        # Give background writer time to flush
        await asyncio.sleep(1.5)

        days = await store.get_daily_usage()
        assert len(days) == 1  # All same day
        assert days[0]["total_tokens"] == 450
        assert days[0]["request_count"] == 3

    async def test_daily_returns_date(self, store):
        """Daily usage includes a date field in YYYY-MM-DD format."""
        tracker = get_usage_tracker()
        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50,
            latency_ms=100.0,
        )

        await asyncio.sleep(1.5)

        days = await store.get_daily_usage()
        assert len(days) == 1

        # Date should be in YYYY-MM-DD format
        date_text = days[0]["date"]
        assert len(date_text) == 10
        assert date_text.count("-") == 2
        # Length and dash count alone would accept strings such as
        # "ab-cd-efgh"; parsing raises ValueError unless the value is a real
        # year-month-day date.
        datetime.strptime(date_text, "%Y-%m-%d")
class TestStoreLifecycle:
    """Test store lifecycle management."""

    async def test_close_flushes_pending(self, temp_db):
        """Closing store flushes pending writes."""
        reset_usage_tracker()
        reset_usage_store()

        store = UsageStore(db_path=temp_db)
        await store.initialize()

        tracker = get_usage_tracker()
        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50,
            latency_ms=100.0,
        )

        # Close immediately (no sleep) - close itself should flush
        await store.close()

        # Reopen and verify data persisted
        store2 = UsageStore(db_path=temp_db)
        await store2.initialize()
        try:
            events = await store2.query()
            assert len(events) == 1
        finally:
            # Close the reopened store even when the assertion fails so it
            # is not left open on the temp database.
            await store2.close()