""" audit.py — SQLite-backed audit log for security events. Records: - Tool invocations (who called what tool with what params) - Peer constraint violations (blocked routing attempts) - Security events (unauthorized access, egress blocks, etc.) - Config changes (hot-reload events) """ from __future__ import annotations import json import logging import time from dataclasses import dataclass from pathlib import Path from typing import Any, Optional logger = logging.getLogger(__name__) # In-memory audit log (SQLite backing added when aiosqlite is available) _audit_entries: list[dict[str, Any]] = [] _max_memory_entries: int = 10000 @dataclass class AuditEntry: """A single audit log entry.""" timestamp: float event_type: str # "tool_invocation", "peer_violation", "security_event", "config_change" listener_name: str thread_id: Optional[str] details: dict[str, Any] severity: str = "info" # "info", "warning", "error", "critical" def record_event( event_type: str, listener_name: str, details: dict[str, Any], *, thread_id: Optional[str] = None, severity: str = "info", ) -> None: """ Record an audit event. Args: event_type: Category of event listener_name: Which listener triggered this details: Event-specific data thread_id: Associated thread UUID (if any) severity: Event severity level """ entry = { "timestamp": time.time(), "event_type": event_type, "listener_name": listener_name, "thread_id": thread_id, "details": details, "severity": severity, } _audit_entries.append(entry) # Trim old entries if over limit if len(_audit_entries) > _max_memory_entries: _audit_entries[:] = _audit_entries[-_max_memory_entries:] # Log security events at appropriate level if severity == "critical": logger.critical(f"AUDIT [{event_type}] {listener_name}: {details}") elif severity == "error": logger.error(f"AUDIT [{event_type}] {listener_name}: {details}") elif severity == "warning": logger.warning(f"AUDIT [{event_type}] {listener_name}: {details}") else: logger.debug(f"AUDIT [{event_type}] {listener_name}: {details}") def record_tool_invocation( listener_name: str, tool_name: str, params: dict[str, Any], success: bool, *, thread_id: Optional[str] = None, error: Optional[str] = None, ) -> None: """Record a tool invocation.""" record_event( "tool_invocation", listener_name, { "tool": tool_name, "params": _sanitize_params(params), "success": success, "error": error, }, thread_id=thread_id, ) def record_peer_violation( listener_name: str, target: str, *, thread_id: Optional[str] = None, ) -> None: """Record a peer constraint violation.""" record_event( "peer_violation", listener_name, {"attempted_target": target}, thread_id=thread_id, severity="warning", ) def record_security_event( listener_name: str, description: str, details: Optional[dict[str, Any]] = None, *, thread_id: Optional[str] = None, severity: str = "warning", ) -> None: """Record a security event.""" record_event( "security_event", listener_name, {"description": description, **(details or {})}, thread_id=thread_id, severity=severity, ) def get_entries( *, event_type: Optional[str] = None, listener_name: Optional[str] = None, severity: Optional[str] = None, since: Optional[float] = None, limit: int = 100, offset: int = 0, ) -> list[dict[str, Any]]: """ Query audit log entries with optional filtering. Returns entries in reverse chronological order (newest first). """ filtered = _audit_entries if event_type: filtered = [e for e in filtered if e["event_type"] == event_type] if listener_name: filtered = [e for e in filtered if e["listener_name"] == listener_name] if severity: filtered = [e for e in filtered if e["severity"] == severity] if since: filtered = [e for e in filtered if e["timestamp"] >= since] # Reverse chronological filtered = list(reversed(filtered)) return filtered[offset : offset + limit] def get_stats() -> dict[str, Any]: """Get audit log statistics.""" total = len(_audit_entries) by_type: dict[str, int] = {} by_severity: dict[str, int] = {} for entry in _audit_entries: by_type[entry["event_type"]] = by_type.get(entry["event_type"], 0) + 1 by_severity[entry["severity"]] = by_severity.get(entry["severity"], 0) + 1 return { "total_entries": total, "by_type": by_type, "by_severity": by_severity, "oldest": _audit_entries[0]["timestamp"] if _audit_entries else None, "newest": _audit_entries[-1]["timestamp"] if _audit_entries else None, } def clear() -> None: """Clear the audit log (for testing).""" _audit_entries.clear() def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: """Remove sensitive values from tool parameters before logging.""" sanitized = {} sensitive_keys = {"api_key", "password", "secret", "token", "credential"} for key, value in params.items(): if any(s in key.lower() for s in sensitive_keys): sanitized[key] = "***" elif isinstance(value, str) and len(value) > 500: sanitized[key] = value[:500] + "...(truncated)" else: sanitized[key] = value return sanitized