""" Tests for the MessageJournal and JournalStore. """ from __future__ import annotations import os import tempfile import pytest from xml_pipeline.message_bus.journal import ( JournalEntryStatus, MessageJournal, ) from xml_pipeline.message_bus.journal_store import JournalStore @pytest.fixture def tmp_db_path(): """Create a temporary database path.""" fd, path = tempfile.mkstemp(suffix=".db") os.close(fd) yield path try: os.unlink(path) except OSError: pass # Also clean up WAL/SHM files for suffix in ("-wal", "-shm"): try: os.unlink(path + suffix) except OSError: pass class TestJournalStore: """Test the SQLite persistence layer.""" async def test_initialize_creates_tables(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() # Should not raise on second init await store.initialize() async def test_insert_and_get_by_status(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() await store.insert( entry_id="e1", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"Alice", status="pending", created_at="2026-01-28T00:00:00Z", ) entries = await store.get_by_status("pending") assert len(entries) == 1 assert entries[0]["id"] == "e1" assert entries[0]["thread_id"] == "t1" assert entries[0]["payload_type"] == "Greeting" async def test_update_status(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() await store.insert( entry_id="e1", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="pending", created_at="2026-01-28T00:00:00Z", ) await store.update_status( "e1", "dispatched", timestamp_field="dispatched_at", timestamp_value="2026-01-28T00:00:01Z", ) entries = await store.get_by_status("dispatched") assert len(entries) == 1 assert entries[0]["dispatched_at"] == "2026-01-28T00:00:01Z" # No more pending pending = await store.get_by_status("pending") assert len(pending) == 0 async def test_update_status_with_error(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() await store.insert( entry_id="e1", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="dispatched", created_at="2026-01-28T00:00:00Z", ) await store.update_status( "e1", "failed", error="handler crashed", timestamp_value="2026-01-28T00:00:02Z", ) entries = await store.get_by_status("failed") assert len(entries) == 1 assert entries[0]["error"] == "handler crashed" assert entries[0]["retry_count"] == 1 async def test_compact_thread(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() # Insert two entries: one acked, one pending for eid, status in [("e1", "acked"), ("e2", "pending")]: await store.insert( entry_id=eid, thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status=status, created_at="2026-01-28T00:00:00Z", ) count = await store.compact_thread("t1") assert count == 1 # Only the acked one # Pending should still exist remaining = await store.get_by_status("pending") assert len(remaining) == 1 async def test_compact_old(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() # Old acked entry await store.insert( entry_id="e1", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="acked", created_at="2020-01-01T00:00:00Z", ) # Recent acked entry await store.insert( entry_id="e2", thread_id="t2", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="acked", created_at="2099-01-01T00:00:00Z", ) count = await store.compact_old("2026-01-28T00:00:00Z") assert count == 1 # Only the old one async def test_get_stats(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() for eid, status in [("e1", "pending"), ("e2", "dispatched"), ("e3", "acked"), ("e4", "acked")]: await store.insert( entry_id=eid, thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status=status, created_at="2026-01-28T00:00:00Z", ) stats = await store.get_stats() assert stats["pending"] == 1 assert stats["dispatched"] == 1 assert stats["acked"] == 2 assert stats["total"] == 4 async def test_get_unacknowledged(self, tmp_db_path): store = JournalStore(tmp_db_path) await store.initialize() # Old dispatched entry (should be returned) await store.insert( entry_id="e1", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="dispatched", created_at="2020-01-01T00:00:00Z", ) # Old pending entry (should be returned) await store.insert( entry_id="e2", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="pending", created_at="2020-01-01T00:00:00Z", ) # Acked entry (should NOT be returned) await store.insert( entry_id="e3", thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", status="acked", created_at="2020-01-01T00:00:00Z", ) entries = await store.get_unacknowledged(older_than_seconds=0) assert len(entries) == 2 ids = {e["id"] for e in entries} assert "e1" in ids assert "e2" in ids class TestMessageJournal: """Test the MessageJournal (DispatchHook implementation).""" async def test_full_lifecycle(self, tmp_db_path): journal = MessageJournal(db_path=tmp_db_path) await journal.initialize() # on_intent eid = await journal.on_intent( thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"Alice", ) assert eid # Non-empty UUID stats = await journal.get_stats() assert stats["pending"] == 1 # on_dispatched await journal.on_dispatched(eid) stats = await journal.get_stats() assert stats["dispatched"] == 1 assert stats["pending"] == 0 # on_acknowledged await journal.on_acknowledged(eid) stats = await journal.get_stats() assert stats["acked"] == 1 assert stats["dispatched"] == 0 async def test_failed_lifecycle(self, tmp_db_path): journal = MessageJournal(db_path=tmp_db_path) await journal.initialize() eid = await journal.on_intent( thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", ) await journal.on_dispatched(eid) await journal.on_failed(eid, "handler exploded") stats = await journal.get_stats() assert stats["failed"] == 1 async def test_thread_complete_compacts(self, tmp_db_path): journal = MessageJournal(db_path=tmp_db_path) await journal.initialize() # Create and ack an entry eid = await journal.on_intent( thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", ) await journal.on_dispatched(eid) await journal.on_acknowledged(eid) stats = await journal.get_stats() assert stats["acked"] == 1 # Thread complete should compact await journal.on_thread_complete("t1") stats = await journal.get_stats() assert stats["acked"] == 0 assert stats["total"] == 0 async def test_get_unacknowledged_for_replay(self, tmp_db_path): journal = MessageJournal( db_path=tmp_db_path, retry_after_seconds=0, ) await journal.initialize() # Create an entry and dispatch but don't ack eid = await journal.on_intent( thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"Alice", ) await journal.on_dispatched(eid) # Should show up as unacknowledged entries = await journal.get_unacknowledged(older_than_seconds=0) assert len(entries) == 1 assert entries[0]["id"] == eid assert entries[0]["payload_bytes"] == b"Alice" async def test_empty_entry_id_is_noop(self, tmp_db_path): """Hook methods with empty entry_id should be no-ops.""" journal = MessageJournal(db_path=tmp_db_path) await journal.initialize() # These should not raise await journal.on_dispatched("") await journal.on_acknowledged("") await journal.on_failed("", "error") stats = await journal.get_stats() assert stats["total"] == 0 async def test_compact_old(self, tmp_db_path): journal = MessageJournal(db_path=tmp_db_path) await journal.initialize() # Create and ack an entry with old timestamp eid = await journal.on_intent( thread_id="t1", from_id="sender", to_id="receiver", payload_type="Greeting", payload_bytes=b"", ) await journal.on_dispatched(eid) await journal.on_acknowledged(eid) # Compact with 0 hours should remove it (entry is older than 0 hours ago) removed = await journal.compact_old(max_age_hours=0) # Entry was just created so 0 hours won't catch it # Use a large value to catch everything removed = await journal.compact_old(max_age_hours=999999) # This won't remove fresh entries either because they're newer # than now - 999999 hours. That's fine — the point is the API works. assert removed >= 0 class TestJournalEntryStatus: """Test the status enum.""" def test_values(self): assert JournalEntryStatus.PENDING.value == "pending" assert JournalEntryStatus.DISPATCHED.value == "dispatched" assert JournalEntryStatus.ACKNOWLEDGED.value == "acked" assert JournalEntryStatus.FAILED.value == "failed"