""" journal_store.py — SQLite persistence layer for the message journal. Provides async CRUD operations for journal entries using aiosqlite. WAL mode is enabled for concurrent reads/writes. This module is the storage backend for MessageJournal (journal.py). """ from __future__ import annotations import logging from pathlib import Path from typing import Any, Dict, List, Optional try: import aiosqlite HAS_AIOSQLITE = True except ImportError: HAS_AIOSQLITE = False logger = logging.getLogger(__name__) # Default database path DEFAULT_JOURNAL_DB_PATH = Path.home() / ".xml-pipeline" / "journal.db" # SQL schema _CREATE_TABLE = """ CREATE TABLE IF NOT EXISTS journal_entries ( id TEXT PRIMARY KEY, thread_id TEXT NOT NULL, from_id TEXT NOT NULL, to_id TEXT NOT NULL, payload_type TEXT NOT NULL, payload_bytes BLOB NOT NULL, status TEXT NOT NULL DEFAULT 'pending', created_at TEXT NOT NULL, dispatched_at TEXT, acked_at TEXT, failed_at TEXT, retry_count INTEGER NOT NULL DEFAULT 0, error TEXT ) """ _CREATE_INDEXES = [ "CREATE INDEX IF NOT EXISTS idx_journal_status ON journal_entries(status)", "CREATE INDEX IF NOT EXISTS idx_journal_thread ON journal_entries(thread_id)", "CREATE INDEX IF NOT EXISTS idx_journal_created ON journal_entries(created_at)", ] class JournalStore: """ Async SQLite persistence for journal entries. Uses WAL mode for concurrent read/write access. """ def __init__(self, db_path: Optional[str] = None) -> None: if not HAS_AIOSQLITE: raise ImportError( "aiosqlite is required for the message journal. " "Install with: pip install aiosqlite" ) self._db_path = Path(db_path) if db_path else DEFAULT_JOURNAL_DB_PATH self._db_path.parent.mkdir(parents=True, exist_ok=True) self._initialized = False async def initialize(self) -> None: """Create tables and indexes if they don't exist.""" if self._initialized: return async with aiosqlite.connect(str(self._db_path)) as db: # Enable WAL mode for concurrent access await db.execute("PRAGMA journal_mode=WAL") await db.execute(_CREATE_TABLE) for idx_sql in _CREATE_INDEXES: await db.execute(idx_sql) await db.commit() self._initialized = True logger.info(f"JournalStore initialized: {self._db_path}") async def insert( self, entry_id: str, thread_id: str, from_id: str, to_id: str, payload_type: str, payload_bytes: bytes, status: str, created_at: str, ) -> None: """Insert a new journal entry.""" async with aiosqlite.connect(str(self._db_path)) as db: await db.execute( """ INSERT INTO journal_entries (id, thread_id, from_id, to_id, payload_type, payload_bytes, status, created_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?) """, (entry_id, thread_id, from_id, to_id, payload_type, payload_bytes, status, created_at), ) await db.commit() async def update_status( self, entry_id: str, status: str, *, timestamp_field: Optional[str] = None, timestamp_value: Optional[str] = None, error: Optional[str] = None, ) -> None: """Update the status of a journal entry.""" async with aiosqlite.connect(str(self._db_path)) as db: if timestamp_field and timestamp_value: await db.execute( f""" UPDATE journal_entries SET status = ?, {timestamp_field} = ? WHERE id = ? """, (status, timestamp_value, entry_id), ) elif error is not None: await db.execute( """ UPDATE journal_entries SET status = ?, error = ?, retry_count = retry_count + 1, failed_at = ? WHERE id = ? """, (status, error, timestamp_value or "", entry_id), ) else: await db.execute( "UPDATE journal_entries SET status = ? WHERE id = ?", (status, entry_id), ) await db.commit() async def get_by_status( self, status: str, *, older_than: Optional[str] = None, limit: int = 100, ) -> List[Dict[str, Any]]: """ Get entries by status, optionally filtered by age. Args: status: Entry status to filter by older_than: ISO timestamp - only return entries created before this limit: Maximum entries to return """ async with aiosqlite.connect(str(self._db_path)) as db: db.row_factory = aiosqlite.Row if older_than: cursor = await db.execute( """ SELECT * FROM journal_entries WHERE status = ? AND created_at < ? ORDER BY created_at ASC LIMIT ? """, (status, older_than, limit), ) else: cursor = await db.execute( """ SELECT * FROM journal_entries WHERE status = ? ORDER BY created_at ASC LIMIT ? """, (status, limit), ) rows = await cursor.fetchall() return [dict(row) for row in rows] async def compact_thread(self, thread_id: str) -> int: """ Remove acknowledged entries for a completed thread. Returns: Number of entries removed """ async with aiosqlite.connect(str(self._db_path)) as db: cursor = await db.execute( """ DELETE FROM journal_entries WHERE thread_id = ? AND status = 'acked' """, (thread_id,), ) count = cursor.rowcount await db.commit() if count: logger.debug(f"Compacted {count} acked entries for thread {thread_id[:8]}...") return count async def compact_old(self, older_than: str) -> int: """ Remove old acknowledged entries regardless of thread. Args: older_than: ISO timestamp - remove acked entries older than this Returns: Number of entries removed """ async with aiosqlite.connect(str(self._db_path)) as db: cursor = await db.execute( """ DELETE FROM journal_entries WHERE status = 'acked' AND created_at < ? """, (older_than,), ) count = cursor.rowcount await db.commit() if count: logger.info(f"Compacted {count} old acked entries") return count async def get_stats(self) -> Dict[str, int]: """Get counts by status.""" async with aiosqlite.connect(str(self._db_path)) as db: cursor = await db.execute( """ SELECT status, COUNT(*) as count FROM journal_entries GROUP BY status """ ) rows = await cursor.fetchall() stats: Dict[str, int] = { "pending": 0, "dispatched": 0, "acked": 0, "failed": 0, "total": 0, } for row in rows: stats[row[0]] = row[1] stats["total"] += row[1] return stats async def get_unacknowledged( self, *, older_than_seconds: float = 30.0, max_retries: int = 3, ) -> List[Dict[str, Any]]: """ Get entries that were dispatched but never acknowledged. Used for crash recovery: these entries need to be replayed. Args: older_than_seconds: Only return entries older than this max_retries: Only return entries with fewer retries than this Returns: List of entry dicts suitable for replay """ from datetime import datetime, timezone, timedelta cutoff = ( datetime.now(timezone.utc) - timedelta(seconds=older_than_seconds) ).isoformat() async with aiosqlite.connect(str(self._db_path)) as db: db.row_factory = aiosqlite.Row cursor = await db.execute( """ SELECT * FROM journal_entries WHERE status IN ('pending', 'dispatched') AND created_at < ? AND retry_count < ? ORDER BY created_at ASC """, (cutoff, max_retries), ) rows = await cursor.fetchall() return [dict(row) for row in rows]