- UsageStore with async SQLite persistence via aiosqlite - Background batch writer for non-blocking event persistence - Auto-subscribes to UsageTracker for transparent capture - Query methods: query(), get_billing_summary(), get_daily_usage() - REST API endpoints: /usage/history, /usage/billing, /usage/daily - Filtering by org_id, agent_id, model, time range - 18 new tests for persistence layer Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
398 lines
12 KiB
Python
398 lines
12 KiB
Python
"""
|
|
Tests for usage persistence layer.
|
|
|
|
Tests UsageStore's SQLite persistence, querying, and billing summaries.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import tempfile
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from xml_pipeline.llm.usage_tracker import UsageEvent, get_usage_tracker, reset_usage_tracker
|
|
from xml_pipeline.llm.usage_store import UsageStore, reset_usage_store
|
|
|
|
|
|
@pytest.fixture
def temp_db(tmp_path):
    """Return the path (as ``str``) of a fresh SQLite database file for one test.

    The file lives in pytest's per-test ``tmp_path`` directory, so pytest manages
    (and eventually prunes) it, along with any sidecar files SQLite may create
    next to it. The previous ``NamedTemporaryFile(delete=False)`` version never
    removed the file even though its comment claimed cleanup was automatic.
    """
    db_path = tmp_path / "usage.db"
    # Pre-create an empty file, as the previous fixture did; SQLite treats a
    # zero-length file as an empty database.
    db_path.touch()
    return str(db_path)
|
|
|
|
|
|
@pytest.fixture
async def store(temp_db):
    """Yield an initialized ``UsageStore`` writing to the per-test database.

    ``reset_usage_tracker()`` and ``reset_usage_store()`` are called before the
    store is built; the store is closed again during fixture teardown.
    """
    reset_usage_tracker()
    reset_usage_store()

    usage_store = UsageStore(db_path=temp_db)
    await usage_store.initialize()

    # Hand the ready store to the test; code after the yield is teardown.
    yield usage_store

    await usage_store.close()
|
|
|
|
|
|
class TestUsageStoreBasics:
    """Test basic store operations."""

    async def test_initialize_creates_table(self, temp_db):
        """Store initialization creates the usage_events table."""
        # Reset both global singletons, as the `store` fixture does, so this
        # hand-built store starts from the same clean state.
        reset_usage_tracker()
        reset_usage_store()

        store = UsageStore(db_path=temp_db)
        await store.initialize()
        try:
            # query() can only succeed if initialize() created the table.
            events = await store.query(limit=10)
            assert events == []
        finally:
            # Close even when the assertion fails so the store does not
            # outlive the test.
            await store.close()

    async def test_store_initializes_once(self, store):
        """Multiple initialize calls are idempotent."""
        # Repeated initialization must neither raise...
        await store.initialize()
        await store.initialize()

        # ...nor disturb the already-initialized store.
        assert await store.count() == 0

    async def test_empty_query(self, store):
        """Query on empty store returns empty list."""
        events = await store.query()
        assert events == []

    async def test_empty_count(self, store):
        """Count on empty store returns 0."""
        count = await store.count()
        assert count == 0
|
|
|
|
|
|
async def _wait_for_count(store, expected, timeout=5.0, interval=0.05):
    """Wait until ``store.count()`` reports at least ``expected`` events.

    Events recorded on the tracker reach the database through the store's
    background writer, so tests must wait before querying. Polling replaces the
    previous fixed ``asyncio.sleep(1.5)``. It returns as soon as the rows are
    visible, which keeps the suite fast. It tolerates a slower flush for up to
    ``timeout`` seconds, which reduces flakiness. On timeout it returns quietly
    so the caller's own assertions report the mismatch.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while await store.count() < expected:
        if loop.time() >= deadline:
            return
        await asyncio.sleep(interval)


class TestEventPersistence:
    """Test event persistence via subscriber pattern."""

    async def test_event_persisted_via_tracker(self, store):
        """Events recorded in tracker are persisted to store."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="test-thread-1",
            agent_id="greeter",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=100,
            completion_tokens=50,
            latency_ms=250.5,
            metadata={"org_id": "org-123"},
        )

        # Wait for the background writer to flush the event.
        await _wait_for_count(store, 1)

        events = await store.query()
        assert len(events) == 1

        event = events[0]
        assert event["thread_id"] == "test-thread-1"
        assert event["agent_id"] == "greeter"
        assert event["model"] == "grok-4.1"
        assert event["provider"] == "xai"
        assert event["prompt_tokens"] == 100
        assert event["completion_tokens"] == 50
        assert event["total_tokens"] == 150
        assert event["latency_ms"] == 250.5

    async def test_multiple_events_persisted(self, store):
        """Multiple events are all persisted."""
        tracker = get_usage_tracker()

        for i in range(5):
            tracker.record(
                thread_id=f"thread-{i}",
                agent_id="agent",
                model="grok-4.1",
                provider="xai",
                prompt_tokens=100 * (i + 1),
                completion_tokens=50 * (i + 1),
                latency_ms=100.0,
            )

        await _wait_for_count(store, 5)

        events = await store.query()
        assert len(events) == 5

    async def test_event_with_cost_estimate(self, store):
        """Estimated cost is persisted."""
        tracker = get_usage_tracker()

        # grok-4.1 has pricing in MODEL_COSTS
        tracker.record(
            thread_id="cost-thread",
            agent_id="agent",
            model="grok-4.1",
            provider="xai",
            prompt_tokens=1000,
            completion_tokens=500,
            latency_ms=100.0,
        )

        await _wait_for_count(store, 1)

        events = await store.query()
        assert len(events) == 1
        assert events[0]["estimated_cost"] is not None
        assert events[0]["estimated_cost"] > 0


class TestQueryFiltering:
    """Test query filtering options."""

    async def test_filter_by_agent(self, store):
        """Filter events by agent_id."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="greeter", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2", agent_id="shouter", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
        )

        await _wait_for_count(store, 2)

        events = await store.query(agent_id="greeter")
        assert len(events) == 1
        assert events[0]["agent_id"] == "greeter"

    async def test_filter_by_model(self, store):
        """Filter events by model."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2", agent_id="agent", model="claude-sonnet-4",
            provider="anthropic", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
        )

        await _wait_for_count(store, 2)

        events = await store.query(model="grok-4.1")
        assert len(events) == 1
        assert events[0]["model"] == "grok-4.1"

    async def test_filter_by_org_id(self, store):
        """Filter events by org_id in metadata."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
            metadata={"org_id": "org-A"},
        )
        tracker.record(
            thread_id="t2", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
            metadata={"org_id": "org-B"},
        )

        await _wait_for_count(store, 2)

        events = await store.query(org_id="org-A")
        assert len(events) == 1
        assert events[0]["thread_id"] == "t1"

    async def test_pagination(self, store):
        """Test limit and offset for pagination."""
        tracker = get_usage_tracker()

        for i in range(10):
            tracker.record(
                thread_id=f"t{i:02d}", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
            )

        await _wait_for_count(store, 10)

        # First page
        page1 = await store.query(limit=3, offset=0)
        assert len(page1) == 3

        # Second page
        page2 = await store.query(limit=3, offset=3)
        assert len(page2) == 3

        # Pages must not overlap. Every thread_id is unique, so no id may
        # appear on both pages. (Comparing only the first rows would miss
        # partial overlap.)
        ids_page1 = {event["thread_id"] for event in page1}
        ids_page2 = {event["thread_id"] for event in page2}
        assert ids_page1.isdisjoint(ids_page2)


class TestBillingSummary:
    """Test billing summary aggregation."""

    async def test_billing_totals(self, store):
        """Billing summary calculates correct totals."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0,
        )

        await _wait_for_count(store, 2)

        summary = await store.get_billing_summary()

        assert summary.total_tokens == 4500  # 1500 + 3000
        assert summary.prompt_tokens == 3000
        assert summary.completion_tokens == 1500
        assert summary.request_count == 2

    async def test_billing_by_model(self, store):
        """Billing summary includes breakdown by model."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2", agent_id="agent", model="claude-sonnet-4",
            provider="anthropic", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0,
        )

        await _wait_for_count(store, 2)

        summary = await store.get_billing_summary()

        assert "grok-4.1" in summary.by_model
        assert "claude-sonnet-4" in summary.by_model
        assert summary.by_model["grok-4.1"]["total_tokens"] == 1500
        assert summary.by_model["claude-sonnet-4"]["total_tokens"] == 3000

    async def test_billing_by_agent(self, store):
        """Billing summary includes breakdown by agent."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="greeter", model="grok-4.1",
            provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0,
        )
        tracker.record(
            thread_id="t2", agent_id="shouter", model="grok-4.1",
            provider="xai", prompt_tokens=500, completion_tokens=250, latency_ms=100.0,
        )

        await _wait_for_count(store, 2)

        summary = await store.get_billing_summary()

        assert "greeter" in summary.by_agent
        assert "shouter" in summary.by_agent
        assert summary.by_agent["greeter"]["total_tokens"] == 1500
        assert summary.by_agent["shouter"]["total_tokens"] == 750

    async def test_billing_filtered_by_org(self, store):
        """Billing summary can be filtered by org_id."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=1000, completion_tokens=500, latency_ms=100.0,
            metadata={"org_id": "org-A"},
        )
        tracker.record(
            thread_id="t2", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=2000, completion_tokens=1000, latency_ms=100.0,
            metadata={"org_id": "org-B"},
        )

        await _wait_for_count(store, 2)

        summary_a = await store.get_billing_summary(org_id="org-A")
        summary_b = await store.get_billing_summary(org_id="org-B")

        assert summary_a.total_tokens == 1500
        assert summary_b.total_tokens == 3000


class TestDailyUsage:
    """Test daily usage aggregation."""

    async def test_daily_aggregation(self, store):
        """Daily usage aggregates by date."""
        tracker = get_usage_tracker()

        # Record multiple events (all same day since timestamps are auto-generated)
        for i in range(3):
            tracker.record(
                thread_id=f"t{i}", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
            )

        await _wait_for_count(store, 3)

        days = await store.get_daily_usage()

        assert len(days) == 1  # All same day
        assert days[0]["total_tokens"] == 450
        assert days[0]["request_count"] == 3

    async def test_daily_returns_date(self, store):
        """Daily usage includes date field."""
        tracker = get_usage_tracker()

        tracker.record(
            thread_id="t1", agent_id="agent", model="grok-4.1",
            provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
        )

        await _wait_for_count(store, 1)

        days = await store.get_daily_usage()

        assert len(days) == 1
        # Date must be a zero-padded YYYY-MM-DD string. The length check
        # enforces the padding; strptime raises ValueError if the value is
        # not a real calendar date in that layout.
        date_str = days[0]["date"]
        assert len(date_str) == 10
        datetime.strptime(date_str, "%Y-%m-%d")
|
|
|
|
|
|
class TestStoreLifecycle:
    """Test store lifecycle management."""

    async def test_close_flushes_pending(self, temp_db):
        """Closing store flushes pending writes."""
        reset_usage_tracker()
        reset_usage_store()

        store = UsageStore(db_path=temp_db)
        await store.initialize()
        try:
            tracker = get_usage_tracker()
            tracker.record(
                thread_id="t1", agent_id="agent", model="grok-4.1",
                provider="xai", prompt_tokens=100, completion_tokens=50, latency_ms=100.0,
            )
        finally:
            # Close immediately, without waiting for the background writer:
            # close() itself is expected to flush the queued event.
            await store.close()

        # Reopen the same database file with a fresh store to verify that the
        # event was actually written.
        reopened = UsageStore(db_path=temp_db)
        await reopened.initialize()
        try:
            events = await reopened.query()
            assert len(events) == 1
            assert events[0]["thread_id"] == "t1"
        finally:
            # Release the second store even if an assertion fails.
            await reopened.close()
|