diff --git a/.gitignore b/.gitignore
index a8560ec..a74e04a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -33,3 +33,7 @@ xml_pipeline/config/*.signed.xml
 # OS
 Thumbs.db
 .DS_Store
+
+# BloxServer local dev
+bloxserver.db
+bloxserver/.env
diff --git a/bloxserver/.env.example b/bloxserver/.env.example
new file mode 100644
index 0000000..ab5bcda
--- /dev/null
+++ b/bloxserver/.env.example
@@ -0,0 +1,54 @@
+# BloxServer API Environment Variables
+# Copy this file to .env and fill in the values
+
+# =============================================================================
+# Environment
+# =============================================================================
+ENV=development
+# ENV=production
+
+# =============================================================================
+# Database (PostgreSQL)
+# =============================================================================
+DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/bloxserver
+
+# Set to true to auto-create tables on startup (disable in production)
+AUTO_CREATE_TABLES=true
+
+# =============================================================================
+# Clerk Authentication
+# =============================================================================
+CLERK_ISSUER=https://your-clerk-instance.clerk.accounts.dev
+CLERK_AUDIENCE=your-clerk-audience
+
+# =============================================================================
+# Stripe Billing
+# =============================================================================
+STRIPE_SECRET_KEY=sk_test_...
+STRIPE_WEBHOOK_SECRET=whsec_...
+
+# =============================================================================
+# API Key Encryption
+# =============================================================================
+# Generate with: python -c "from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())"
+API_KEY_ENCRYPTION_KEY=your-fernet-key-here
+
+# =============================================================================
+# CORS
+# =============================================================================
+CORS_ORIGINS=http://localhost:3000,https://app.openblox.ai
+
+# =============================================================================
+# Webhooks
+# =============================================================================
+WEBHOOK_BASE_URL=https://api.openblox.ai/webhooks
+
+# =============================================================================
+# Redis (optional, for caching/rate limiting)
+# =============================================================================
+# REDIS_URL=redis://localhost:6379
+
+# =============================================================================
+# Docs
+# =============================================================================
+ENABLE_DOCS=true
diff --git a/bloxserver/Dockerfile b/bloxserver/Dockerfile
new file mode 100644
index 0000000..57f0393
--- /dev/null
+++ b/bloxserver/Dockerfile
@@ -0,0 +1,58 @@
+# BloxServer API Dockerfile
+# Multi-stage build for smaller production image
+
+# =============================================================================
+# Build stage
+# =============================================================================
+FROM python:3.12-slim AS builder
+
+WORKDIR /app
+
+# Install build dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy requirements first for layer caching
+COPY requirements.txt .
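+
+# Building wheels here keeps compilers out of the runtime image; the
+# production stage below installs from /app/wheels and never needs build tools.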
+RUN pip wheel --no-cache-dir --wheel-dir /app/wheels -r requirements.txt
+
+# =============================================================================
+# Production stage
+# =============================================================================
+FROM python:3.12-slim AS production
+
+WORKDIR /app
+
+# Create non-root user
+RUN groupadd --gid 1000 bloxserver \
+    && useradd --uid 1000 --gid bloxserver --shell /bin/bash --create-home bloxserver
+
+# Install runtime dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy wheels from builder and install
+COPY --from=builder /app/wheels /wheels
+RUN pip install --no-cache-dir /wheels/* && rm -rf /wheels
+
+# Copy application code
+COPY --chown=bloxserver:bloxserver . /app/bloxserver
+
+# Set Python path
+ENV PYTHONPATH=/app
+ENV PYTHONUNBUFFERED=1
+
+# Switch to non-root user
+USER bloxserver
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=5s --start-period=5s --retries=3 \
+    CMD curl -f http://localhost:8000/health/live || exit 1
+
+# Expose port
+EXPOSE 8000
+
+# Run with uvicorn
+CMD ["uvicorn", "bloxserver.api.main:app", "--host", "0.0.0.0", "--port", "8000"]
diff --git a/bloxserver/README.md b/bloxserver/README.md
new file mode 100644
index 0000000..7321edd
--- /dev/null
+++ b/bloxserver/README.md
@@ -0,0 +1,203 @@
+# BloxServer API
+
+Backend API for BloxServer (OpenBlox.ai) - Visual AI Agent Workflow Builder.
+
+## Quick Start
+
+### With Docker Compose (Recommended)
+
+```bash
+cd bloxserver
+
+# Start PostgreSQL, Redis, and API
+docker-compose up -d
+
+# Check logs
+docker-compose logs -f api
+
+# API available at http://localhost:8000
+# Docs at http://localhost:8000/docs
+```
+
+### Local Development
+
+```bash
+cd bloxserver
+
+# Create virtual environment
+python -m venv .venv
+source .venv/bin/activate  # Linux/macOS
+# .venv\Scripts\activate   # Windows
+
+# Install dependencies
+pip install -r requirements.txt
+
+# Copy environment variables
+cp .env.example .env
+# Edit .env with your settings
+
+# Start PostgreSQL and Redis (or use Docker)
+docker-compose up -d postgres redis
+
+# Run the API
+python -m bloxserver.api.main
+# Or with uvicorn directly:
+uvicorn bloxserver.api.main:app --reload
+```
+
+## API Endpoints
+
+### Health
+
+- `GET /health` - Basic health check
+- `GET /health/ready` - Readiness check (includes DB)
+- `GET /health/live` - Liveness check
+
+### Flows
+
+- `GET /api/v1/flows` - List flows
+- `POST /api/v1/flows` - Create flow
+- `GET /api/v1/flows/{id}` - Get flow
+- `PATCH /api/v1/flows/{id}` - Update flow
+- `DELETE /api/v1/flows/{id}` - Delete flow
+- `POST /api/v1/flows/{id}/start` - Start flow
+- `POST /api/v1/flows/{id}/stop` - Stop flow
+
+### Triggers
+
+- `GET /api/v1/flows/{flow_id}/triggers` - List triggers
+- `POST /api/v1/flows/{flow_id}/triggers` - Create trigger
+- `GET /api/v1/flows/{flow_id}/triggers/{id}` - Get trigger
+- `DELETE /api/v1/flows/{flow_id}/triggers/{id}` - Delete trigger
+- `POST /api/v1/flows/{flow_id}/triggers/{id}/regenerate-token` - Regenerate webhook token
+
+### Executions
+
+- `GET /api/v1/flows/{flow_id}/executions` - List executions
+- `GET /api/v1/flows/{flow_id}/executions/{id}` - Get execution
+- `POST /api/v1/flows/{flow_id}/executions/run` - Manual trigger
+- `GET /api/v1/flows/{flow_id}/executions/stats` - Get stats
+
+### Webhooks
+
+- `POST /webhooks/{token}` - Trigger flow via webhook
+- `GET /webhooks/{token}/test` - Test webhook token
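+
+A minimal sketch of invoking a webhook trigger from Python (the token is a
+placeholder; use the `webhookUrl` returned when the trigger is created):
+
+```python
+import httpx
+
+# The token in the path is the only credential the endpoint requires.
+resp = httpx.post(
+    "https://api.openblox.ai/webhooks/<your-webhook-token>",
+    json={"message": "hello"},
+)
+print(resp.json())  # {"status": "accepted", "executionId": "...", ...}
+```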
+
+## Project Structure
+
+```
+bloxserver/
+├── api/
+│   ├── __init__.py
+│   ├── main.py              # FastAPI app entry point
+│   ├── dependencies.py      # Auth, DB session dependencies
+│   ├── schemas.py           # Pydantic request/response models
+│   ├── models/
+│   │   ├── __init__.py
+│   │   ├── database.py      # SQLAlchemy engine/session
+│   │   └── tables.py        # ORM table definitions
+│   └── routes/
+│       ├── __init__.py
+│       ├── flows.py         # Flow CRUD
+│       ├── triggers.py      # Trigger CRUD
+│       ├── executions.py    # Execution history
+│       ├── webhooks.py      # Webhook handler
+│       └── health.py        # Health checks
+├── requirements.txt
+├── Dockerfile
+├── docker-compose.yml
+├── .env.example
+└── README.md
+```
+
+## Authentication
+
+Uses Clerk for JWT authentication. All `/api/v1/*` endpoints require a valid JWT.
+
+```bash
+curl -H "Authorization: Bearer <your-jwt>" \
+  http://localhost:8000/api/v1/flows
+```
+
+## Environment Variables
+
+See `.env.example` for all configuration options.
+
+Key variables:
+- `DATABASE_URL` - PostgreSQL connection string
+- `CLERK_ISSUER` - Clerk JWT issuer URL
+- `STRIPE_SECRET_KEY` - Stripe API key
+- `API_KEY_ENCRYPTION_KEY` - Fernet key for encrypting user API keys
+
+## Database Migrations
+
+Using Alembic for migrations (not yet set up):
+
+```bash
+# Initialize (first time)
+alembic init alembic
+
+# Create migration
+alembic revision --autogenerate -m "description"
+
+# Apply migrations
+alembic upgrade head
+```
+
+## Testing
+
+```bash
+# Install test dependencies
+pip install pytest pytest-asyncio httpx
+
+# Run tests
+pytest tests/ -v
+```
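+
+A minimal example test (hypothetical `tests/test_health.py`; assumes the dev
+defaults, so no external services are required):
+
+```python
+import pytest
+from httpx import ASGITransport, AsyncClient
+
+from bloxserver.api.main import app
+
+
+@pytest.mark.asyncio
+async def test_health_endpoint():
+    # Drive the ASGI app in-process; no server needs to be running.
+    transport = ASGITransport(app=app)
+    async with AsyncClient(transport=transport, base_url="http://test") as client:
+        resp = await client.get("/health")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "healthy"
+```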
+""" + +from __future__ import annotations + +import os +from typing import Annotated +from uuid import UUID + +import httpx +from fastapi import Depends, HTTPException, Request, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from bloxserver.api.models.database import get_db +from bloxserver.api.models.tables import UserRecord + +# Dev mode - skip auth for local testing +DEV_MODE = os.getenv("ENV", "development") == "development" and not os.getenv("CLERK_ISSUER") + +# Clerk configuration +CLERK_ISSUER = os.getenv("CLERK_ISSUER", "") +CLERK_JWKS_URL = f"{CLERK_ISSUER}/.well-known/jwks.json" if CLERK_ISSUER else "" + +# Security scheme +security = HTTPBearer(auto_error=False) + + +# ============================================================================= +# JWT Validation (Clerk) +# ============================================================================= + + +async def get_clerk_jwks() -> dict: + """Fetch Clerk's JWKS for JWT validation.""" + async with httpx.AsyncClient() as client: + response = await client.get(CLERK_JWKS_URL) + response.raise_for_status() + return response.json() + + +async def validate_clerk_token(token: str) -> dict: + """ + Validate a Clerk JWT token and return the payload. + + In production, use a proper JWT library with caching. + This is a simplified version for the scaffold. + """ + import jwt + from jwt import PyJWKClient + + try: + # Get signing key from Clerk's JWKS + jwks_client = PyJWKClient(CLERK_JWKS_URL) + signing_key = jwks_client.get_signing_key_from_jwt(token) + + # Decode and validate + payload = jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + audience=os.getenv("CLERK_AUDIENCE"), + issuer=CLERK_ISSUER, + ) + + return payload + + except jwt.ExpiredSignatureError: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token has expired", + ) + except jwt.InvalidTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Invalid token: {e}", + ) + + +# ============================================================================= +# Current User Dependency +# ============================================================================= + + +class CurrentUser: + """Authenticated user context.""" + + def __init__(self, user: UserRecord, clerk_payload: dict): + self.user = user + self.clerk_payload = clerk_payload + + @property + def id(self) -> UUID: + return self.user.id + + @property + def clerk_id(self) -> str: + return self.user.clerk_id + + @property + def email(self) -> str: + return self.user.email + + @property + def tier(self) -> str: + return self.user.tier.value + + +async def get_current_user( + request: Request, + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> CurrentUser: + """ + Dependency that validates the JWT and returns the current user. + + Creates the user record if this is their first request (synced from Clerk). + In DEV_MODE without Clerk configured, returns a test user. 
+ """ + # Dev mode - create/return a test user without auth + if DEV_MODE: + dev_clerk_id = "dev_user_001" + result = await db.execute( + select(UserRecord).where(UserRecord.clerk_id == dev_clerk_id) + ) + user = result.scalar_one_or_none() + + if not user: + from bloxserver.api.models.tables import Tier + user = UserRecord( + clerk_id=dev_clerk_id, + email="dev@localhost", + name="Dev User", + tier=Tier.PRO, # Give dev user Pro access + ) + db.add(user) + await db.flush() + + return CurrentUser(user=user, clerk_payload={"sub": dev_clerk_id, "dev": True}) + + # Production mode - require Clerk auth + if not credentials: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Missing authentication token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # Validate JWT + payload = await validate_clerk_token(credentials.credentials) + clerk_id = payload.get("sub") + + if not clerk_id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid token: missing subject", + ) + + # Look up or create user + result = await db.execute( + select(UserRecord).where(UserRecord.clerk_id == clerk_id) + ) + user = result.scalar_one_or_none() + + if not user: + # First login - create user record from Clerk data + user = UserRecord( + clerk_id=clerk_id, + email=payload.get("email", f"{clerk_id}@unknown"), + name=payload.get("name"), + avatar_url=payload.get("image_url"), + ) + db.add(user) + await db.flush() # Get the ID without committing + + return CurrentUser(user=user, clerk_payload=payload) + + +# Type alias for cleaner route signatures +AuthenticatedUser = Annotated[CurrentUser, Depends(get_current_user)] +DbSession = Annotated[AsyncSession, Depends(get_db)] + + +# ============================================================================= +# Optional Auth (for public endpoints) +# ============================================================================= + + +async def get_optional_user( + request: Request, + credentials: Annotated[HTTPAuthorizationCredentials | None, Depends(security)], + db: Annotated[AsyncSession, Depends(get_db)], +) -> CurrentUser | None: + """ + Like get_current_user, but returns None instead of raising if not authenticated. + """ + if not credentials: + return None + + try: + return await get_current_user(request, credentials, db) + except HTTPException: + return None + + +OptionalUser = Annotated[CurrentUser | None, Depends(get_optional_user)] + + +# ============================================================================= +# Tier Checks +# ============================================================================= + + +def require_tier(*allowed_tiers: str): + """ + Dependency factory that requires the user to be on one of the allowed tiers. + + Usage: + @router.post("/wasm", dependencies=[Depends(require_tier("pro", "enterprise"))]) + """ + async def check_tier(user: AuthenticatedUser) -> None: + if user.tier not in allowed_tiers: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"This feature requires one of: {', '.join(allowed_tiers)}", + ) + + return check_tier + + +RequirePro = Depends(require_tier("pro", "enterprise", "high_frequency")) +RequireEnterprise = Depends(require_tier("enterprise", "high_frequency")) diff --git a/bloxserver/api/main.py b/bloxserver/api/main.py new file mode 100644 index 0000000..b8912ed --- /dev/null +++ b/bloxserver/api/main.py @@ -0,0 +1,166 @@ +""" +BloxServer API - FastAPI Application + +Main entry point for the BloxServer backend API. 
+""" + +from __future__ import annotations + +import os +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from fastapi import FastAPI, Request, status +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse + +from bloxserver.api.models.database import init_db +from bloxserver.api.routes import executions, flows, health, triggers, webhooks +from bloxserver.api.schemas import ApiError + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: + """Application lifespan - startup and shutdown events.""" + # Startup + print("Starting BloxServer API...") + + # Initialize database tables + if os.getenv("AUTO_CREATE_TABLES", "true").lower() == "true": + await init_db() + print("Database tables initialized") + + yield + + # Shutdown + print("Shutting down BloxServer API...") + + +# Create FastAPI app +app = FastAPI( + title="BloxServer API", + description="Backend API for BloxServer - Visual AI Agent Workflow Builder", + version="0.1.0", + lifespan=lifespan, + docs_url="/docs" if os.getenv("ENABLE_DOCS", "true").lower() == "true" else None, + redoc_url="/redoc" if os.getenv("ENABLE_DOCS", "true").lower() == "true" else None, +) + + +# ============================================================================= +# CORS Middleware +# ============================================================================= + +# Allowed origins (configure via environment) +CORS_ORIGINS = os.getenv( + "CORS_ORIGINS", + "http://localhost:3000,https://app.openblox.ai", +).split(",") + +app.add_middleware( + CORSMiddleware, + allow_origins=CORS_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# ============================================================================= +# Exception Handlers +# ============================================================================= + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler( + request: Request, exc: RequestValidationError +) -> JSONResponse: + """Convert validation errors to standard API error format.""" + errors = exc.errors() + details = { + ".".join(str(loc) for loc in err["loc"]): err["msg"] + for err in errors + } + + return JSONResponse( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + content=ApiError( + code="validation_error", + message="Request validation failed", + details=details, + ).model_dump(by_alias=True), + ) + + +@app.exception_handler(Exception) +async def general_exception_handler( + request: Request, exc: Exception +) -> JSONResponse: + """Catch-all exception handler.""" + # In production, don't expose internal errors + if os.getenv("ENV", "development") == "production": + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ApiError( + code="internal_error", + message="An unexpected error occurred", + ).model_dump(by_alias=True), + ) + + # In development, include error details + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content=ApiError( + code="internal_error", + message=str(exc), + details={"type": type(exc).__name__}, + ).model_dump(by_alias=True), + ) + + +# ============================================================================= +# Routes +# ============================================================================= + +# Health checks (no auth) +app.include_router(health.router) + +# Webhook endpoint (token-based auth) 
+
+
+# =============================================================================
+# Routes
+# =============================================================================
+
+# Health checks (no auth)
+app.include_router(health.router)
+
+# Webhook endpoint (token-based auth)
+app.include_router(webhooks.router)
+
+# Protected API routes
+app.include_router(flows.router, prefix="/api/v1")
+app.include_router(triggers.router, prefix="/api/v1")
+app.include_router(executions.router, prefix="/api/v1")
+
+
+# =============================================================================
+# Root endpoint
+# =============================================================================
+
+
+@app.get("/")
+async def root() -> dict:
+    """Root endpoint - API info."""
+    return {
+        "name": "BloxServer API",
+        "version": "0.1.0",
+        "docs": "/docs",
+        "health": "/health",
+    }
+
+
+# =============================================================================
+# Run with uvicorn
+# =============================================================================
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(
+        "bloxserver.api.main:app",
+        host=os.getenv("HOST", "0.0.0.0"),
+        port=int(os.getenv("PORT", "8000")),
+        reload=os.getenv("ENV", "development") == "development",
+    )
diff --git a/bloxserver/api/models/__init__.py b/bloxserver/api/models/__init__.py
new file mode 100644
index 0000000..f8a248d
--- /dev/null
+++ b/bloxserver/api/models/__init__.py
@@ -0,0 +1,23 @@
+"""Database and Pydantic models."""
+
+from bloxserver.api.models.database import Base, get_db, init_db
+from bloxserver.api.models.tables import (
+    ExecutionRecord,
+    FlowRecord,
+    TriggerRecord,
+    UserApiKeyRecord,
+    UserRecord,
+    UsageRecord,
+)
+
+__all__ = [
+    "Base",
+    "get_db",
+    "init_db",
+    "UserRecord",
+    "FlowRecord",
+    "TriggerRecord",
+    "ExecutionRecord",
+    "UserApiKeyRecord",
+    "UsageRecord",
+]
diff --git a/bloxserver/api/models/database.py b/bloxserver/api/models/database.py
new file mode 100644
index 0000000..15d4430
--- /dev/null
+++ b/bloxserver/api/models/database.py
@@ -0,0 +1,84 @@
+"""
+Database connection and session management.
+
+Uses SQLAlchemy async with PostgreSQL.
+"""
+
+from __future__ import annotations
+
+import os
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
+from sqlalchemy.orm import DeclarativeBase
+
+
+class Base(DeclarativeBase):
+    """Base class for all ORM models."""
+
+    pass
+
+
+# Database URL from environment
+# Supports both PostgreSQL and SQLite (for local testing)
+DATABASE_URL = os.getenv(
+    "DATABASE_URL",
+    "sqlite+aiosqlite:///./bloxserver.db",  # SQLite default for easy local testing
+)
+
+# Create async engine with appropriate settings
+_is_sqlite = DATABASE_URL.startswith("sqlite")
+
+if _is_sqlite:
+    # SQLite doesn't support pool settings
+    engine = create_async_engine(
+        DATABASE_URL,
+        echo=os.getenv("SQL_ECHO", "false").lower() == "true",
+        connect_args={"check_same_thread": False},
+    )
+else:
+    # PostgreSQL with connection pooling
+    engine = create_async_engine(
+        DATABASE_URL,
+        echo=os.getenv("SQL_ECHO", "false").lower() == "true",
+        pool_pre_ping=True,
+        pool_size=10,
+        max_overflow=20,
+    )
+
+# Session factory
+async_session_maker = async_sessionmaker(
+    engine,
+    class_=AsyncSession,
+    expire_on_commit=False,
+)
+
+
+async def init_db() -> None:
+    """Create all tables. Call once at startup."""
+    async with engine.begin() as conn:
+        await conn.run_sync(Base.metadata.create_all)
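+
+
+# Note: sessions from get_db() below commit automatically when the request
+# handler returns successfully, so route code only needs db.flush() when it
+# wants server-generated IDs mid-request.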
+async def get_db() -> AsyncGenerator[AsyncSession, None]:
+    """Dependency for FastAPI routes. Yields a database session."""
+    async with async_session_maker() as session:
+        try:
+            yield session
+            await session.commit()
+        except Exception:
+            await session.rollback()
+            raise
+
+
+@asynccontextmanager
+async def get_db_context() -> AsyncGenerator[AsyncSession, None]:
+    """Context manager for use outside of FastAPI routes."""
+    async with async_session_maker() as session:
+        try:
+            yield session
+            await session.commit()
+        except Exception:
+            await session.rollback()
+            raise
diff --git a/bloxserver/api/models/tables.py b/bloxserver/api/models/tables.py
new file mode 100644
index 0000000..a3c33f6
--- /dev/null
+++ b/bloxserver/api/models/tables.py
@@ -0,0 +1,381 @@
+"""
+SQLAlchemy ORM models for BloxServer.
+
+These map to the Pydantic models in schemas.py and TypeScript types in types.ts.
+"""
+
+from __future__ import annotations
+
+import enum
+from datetime import datetime
+from typing import Any
+from uuid import uuid4
+
+from sqlalchemy import (
+    JSON,
+    UUID,  # core dialect-agnostic UUID type, so the SQLite dev default works too
+    Boolean,
+    DateTime,
+    Enum,
+    ForeignKey,
+    Index,
+    Integer,
+    LargeBinary,
+    Numeric,
+    String,
+    Text,
+    func,
+)
+from sqlalchemy.orm import Mapped, mapped_column, relationship
+
+from bloxserver.api.models.database import Base
+
+
+# =============================================================================
+# Enums
+# =============================================================================
+
+
+class Tier(str, enum.Enum):
+    """User subscription tier."""
+
+    FREE = "free"
+    PRO = "pro"
+    ENTERPRISE = "enterprise"
+    HIGH_FREQUENCY = "high_frequency"
+
+
+class BillingStatus(str, enum.Enum):
+    """Subscription billing status."""
+
+    ACTIVE = "active"
+    TRIALING = "trialing"
+    PAST_DUE = "past_due"
+    CANCELED = "canceled"
+    CANCELING = "canceling"
+
+
+class FlowStatus(str, enum.Enum):
+    """Flow runtime status."""
+
+    STOPPED = "stopped"
+    STARTING = "starting"
+    RUNNING = "running"
+    STOPPING = "stopping"
+    ERROR = "error"
+
+
+class TriggerType(str, enum.Enum):
+    """How a flow can be triggered."""
+
+    WEBHOOK = "webhook"
+    SCHEDULE = "schedule"
+    MANUAL = "manual"
+
+
+class ExecutionStatus(str, enum.Enum):
+    """Status of a flow execution."""
+
+    RUNNING = "running"
+    SUCCESS = "success"
+    ERROR = "error"
+    TIMEOUT = "timeout"
+
+
+# =============================================================================
+# Users (synced from Clerk)
+# =============================================================================
+
+
+class UserRecord(Base):
+    """User account, synced from Clerk."""
+
+    __tablename__ = "users"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    clerk_id: Mapped[str] = mapped_column(String(255), unique=True, nullable=False)
+    email: Mapped[str] = mapped_column(String(255), nullable=False)
+    name: Mapped[str | None] = mapped_column(String(255))
+    avatar_url: Mapped[str | None] = mapped_column(Text)
+
+    # Stripe integration
+    stripe_customer_id: Mapped[str | None] = mapped_column(String(255), unique=True)
+    stripe_subscription_id: Mapped[str | None] = mapped_column(String(255))
+    stripe_subscription_item_id: Mapped[str | None] = mapped_column(String(255))
+
+    # Billing state (cached from Stripe)
+    tier: Mapped[Tier] = mapped_column(Enum(Tier), default=Tier.FREE)
+    billing_status: Mapped[BillingStatus] = mapped_column(
+        Enum(BillingStatus), default=BillingStatus.ACTIVE
+    )
+    trial_ends_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+    current_period_start: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+    current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+    )
+
+    # Relationships
+    flows: Mapped[list[FlowRecord]] = relationship(back_populates="user", cascade="all, delete-orphan")
+    api_keys: Mapped[list[UserApiKeyRecord]] = relationship(back_populates="user", cascade="all, delete-orphan")
+    usage_records: Mapped[list[UsageRecord]] = relationship(back_populates="user", cascade="all, delete-orphan")
+
+    __table_args__ = (
+        Index("idx_users_clerk_id", "clerk_id"),
+        Index("idx_users_stripe_customer", "stripe_customer_id"),
+    )
+
+
+# =============================================================================
+# Flows
+# =============================================================================
+
+
+class FlowRecord(Base):
+    """A user's workflow/flow."""
+
+    __tablename__ = "flows"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    user_id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
+    )
+    name: Mapped[str] = mapped_column(String(100), nullable=False)
+    description: Mapped[str | None] = mapped_column(String(500))
+
+    # The actual workflow definition
+    organism_yaml: Mapped[str] = mapped_column(Text, nullable=False, default="")
+
+    # React Flow canvas state (JSON)
+    canvas_state: Mapped[dict[str, Any] | None] = mapped_column(JSON)
+
+    # Runtime state
+    status: Mapped[FlowStatus] = mapped_column(Enum(FlowStatus), default=FlowStatus.STOPPED)
+    container_id: Mapped[str | None] = mapped_column(String(255))
+    error_message: Mapped[str | None] = mapped_column(Text)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+    )
+
+    # Relationships
+    user: Mapped[UserRecord] = relationship(back_populates="flows")
+    triggers: Mapped[list[TriggerRecord]] = relationship(back_populates="flow", cascade="all, delete-orphan")
+    executions: Mapped[list[ExecutionRecord]] = relationship(back_populates="flow", cascade="all, delete-orphan")
+
+    __table_args__ = (
+        Index("idx_flows_user_id", "user_id"),
+        Index("idx_flows_status", "status"),
+    )
+
+
+# =============================================================================
+# Triggers
+# =============================================================================
+
+
+class TriggerRecord(Base):
+    """A trigger that can start a flow."""
+
+    __tablename__ = "triggers"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    flow_id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("flows.id", ondelete="CASCADE"), nullable=False
+    )
+    type: Mapped[TriggerType] = mapped_column(Enum(TriggerType), nullable=False)
+    name: Mapped[str] = mapped_column(String(100), nullable=False)
+
+    # Trigger configuration (JSON)
+    config: Mapped[dict[str, Any]] = mapped_column(JSON, nullable=False, default=dict)
+
+    # Webhook-specific fields
+    webhook_token: Mapped[str | None] = mapped_column(String(64), unique=True)
+    webhook_url: Mapped[str | None] = mapped_column(Text)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    # Relationships
+    flow: Mapped[FlowRecord] = relationship(back_populates="triggers")
+    executions: Mapped[list[ExecutionRecord]] = relationship(back_populates="trigger")
+
+    __table_args__ = (
+        Index("idx_triggers_flow_id", "flow_id"),
+        Index("idx_triggers_webhook_token", "webhook_token"),
+    )
+
+
+# =============================================================================
+# Executions
+# =============================================================================
+
+
+class ExecutionRecord(Base):
+    """A single execution/run of a flow."""
+
+    __tablename__ = "executions"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    flow_id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("flows.id", ondelete="CASCADE"), nullable=False
+    )
+    trigger_id: Mapped[UUID | None] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("triggers.id", ondelete="SET NULL")
+    )
+    trigger_type: Mapped[TriggerType] = mapped_column(Enum(TriggerType), nullable=False)
+
+    # Execution state
+    status: Mapped[ExecutionStatus] = mapped_column(
+        Enum(ExecutionStatus), default=ExecutionStatus.RUNNING
+    )
+    error_message: Mapped[str | None] = mapped_column(Text)
+
+    # Payloads (JSON strings for flexibility)
+    input_payload: Mapped[str | None] = mapped_column(Text)
+    output_payload: Mapped[str | None] = mapped_column(Text)
+
+    # Timing
+    started_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+    duration_ms: Mapped[int | None] = mapped_column(Integer)
+
+    # Relationships
+    flow: Mapped[FlowRecord] = relationship(back_populates="executions")
+    trigger: Mapped[TriggerRecord | None] = relationship(back_populates="executions")
+
+    __table_args__ = (
+        Index("idx_executions_flow_id", "flow_id"),
+        Index("idx_executions_started_at", "started_at"),
+        Index("idx_executions_status", "status"),
+    )
+
+
+# =============================================================================
+# User API Keys (BYOK)
+# =============================================================================
+
+
+class UserApiKeyRecord(Base):
+    """User's own API keys for BYOK (Bring Your Own Key)."""
+
+    __tablename__ = "user_api_keys"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    user_id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
+    )
+    provider: Mapped[str] = mapped_column(String(50), nullable=False)
+
+    # Encrypted API key
+    encrypted_key: Mapped[bytes] = mapped_column(LargeBinary, nullable=False)
+    key_hint: Mapped[str | None] = mapped_column(String(20))  # Last few chars for display
+
+    # Validation state
+    is_valid: Mapped[bool] = mapped_column(Boolean, default=True)
+    last_error: Mapped[str | None] = mapped_column(String(255))
+    last_used_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+
+    # Relationships
+    user: Mapped[UserRecord] = relationship(back_populates="api_keys")
+
+    __table_args__ = (
+        Index("idx_user_api_keys_user_provider", "user_id", "provider", unique=True),
+    )
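+
+
+# A minimal encryption sketch (an assumption for illustration; the service code
+# that writes these rows is not part of this change). It uses the Fernet key
+# from API_KEY_ENCRYPTION_KEY in .env.example:
+#
+#     from cryptography.fernet import Fernet
+#     fernet = Fernet(os.environ["API_KEY_ENCRYPTION_KEY"])
+#     record.encrypted_key = fernet.encrypt(raw_key.encode())
+#     record.key_hint = raw_key[-4:]
+#     raw_key = fernet.decrypt(record.encrypted_key).decode()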
+
+
+# =============================================================================
+# Usage Tracking
+# =============================================================================
+
+
+class UsageRecord(Base):
+    """Usage tracking for billing."""
+
+    __tablename__ = "usage_records"
+
+    id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), primary_key=True, default=uuid4
+    )
+    user_id: Mapped[UUID] = mapped_column(
+        UUID(as_uuid=True), ForeignKey("users.id", ondelete="CASCADE"), nullable=False
+    )
+    period_start: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), nullable=False
+    )
+
+    # Metrics
+    workflow_runs: Mapped[int] = mapped_column(Integer, default=0)
+    llm_tokens_in: Mapped[int] = mapped_column(Integer, default=0)
+    llm_tokens_out: Mapped[int] = mapped_column(Integer, default=0)
+    wasm_cpu_seconds: Mapped[float] = mapped_column(Numeric(10, 2), default=0)
+    storage_gb_hours: Mapped[float] = mapped_column(Numeric(10, 2), default=0)
+
+    # Stripe sync state
+    last_synced_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
+    last_synced_runs: Mapped[int] = mapped_column(Integer, default=0)
+
+    # Timestamps
+    created_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+    updated_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
+    )
+
+    # Relationships
+    user: Mapped[UserRecord] = relationship(back_populates="usage_records")
+
+    __table_args__ = (
+        Index("idx_usage_user_period", "user_id", "period_start", unique=True),
+    )
+
+
+# =============================================================================
+# Stripe Events (Idempotency)
+# =============================================================================
+
+
+class StripeEventRecord(Base):
+    """Processed Stripe webhook events for idempotency."""
+
+    __tablename__ = "stripe_events"
+
+    event_id: Mapped[str] = mapped_column(String(255), primary_key=True)
+    event_type: Mapped[str] = mapped_column(String(100), nullable=False)
+    processed_at: Mapped[datetime] = mapped_column(
+        DateTime(timezone=True), server_default=func.now()
+    )
+    payload: Mapped[dict[str, Any] | None] = mapped_column(JSON)
+
+    __table_args__ = (
+        Index("idx_stripe_events_processed", "processed_at"),
+    )
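+
+
+# A minimal idempotency sketch (an assumption; the Stripe webhook handler is a
+# listed follow-up, not part of this change):
+#
+#     if await db.get(StripeEventRecord, event["id"]):
+#         return  # duplicate delivery - already processed
+#     ...handle the event...
+#     db.add(StripeEventRecord(event_id=event["id"], event_type=event["type"]))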
+""" + +from __future__ import annotations + +from datetime import datetime +from uuid import UUID + +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import func, select + +from bloxserver.api.dependencies import AuthenticatedUser, DbSession +from bloxserver.api.models.tables import ( + ExecutionRecord, + ExecutionStatus, + FlowRecord, + TriggerType, +) +from bloxserver.api.schemas import Execution, ExecutionSummary, PaginatedResponse + +router = APIRouter(prefix="/flows/{flow_id}/executions", tags=["executions"]) + + +@router.get("", response_model=PaginatedResponse[ExecutionSummary]) +async def list_executions( + flow_id: UUID, + user: AuthenticatedUser, + db: DbSession, + page: int = 1, + page_size: int = 50, + status_filter: ExecutionStatus | None = None, +) -> PaginatedResponse[ExecutionSummary]: + """List execution history for a flow.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + offset = (page - 1) * page_size + + # Build query + base_query = select(ExecutionRecord).where(ExecutionRecord.flow_id == flow_id) + if status_filter: + base_query = base_query.where(ExecutionRecord.status == status_filter) + + # Get total count + count_query = select(func.count()).select_from(base_query.subquery()) + total = (await db.execute(count_query)).scalar() or 0 + + # Get page + query = base_query.order_by(ExecutionRecord.started_at.desc()).offset(offset).limit(page_size) + result = await db.execute(query) + executions = result.scalars().all() + + return PaginatedResponse( + items=[ExecutionSummary.model_validate(e) for e in executions], + total=total, + page=page, + page_size=page_size, + has_more=offset + len(executions) < total, + ) + + +@router.get("/{execution_id}", response_model=Execution) +async def get_execution( + flow_id: UUID, + execution_id: UUID, + user: AuthenticatedUser, + db: DbSession, +) -> Execution: + """Get details of a single execution.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + # Get execution + query = select(ExecutionRecord).where( + ExecutionRecord.id == execution_id, + ExecutionRecord.flow_id == flow_id, + ) + result = await db.execute(query) + execution = result.scalar_one_or_none() + + if not execution: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Execution not found", + ) + + return Execution.model_validate(execution) + + +@router.post("/run", response_model=Execution, status_code=status.HTTP_201_CREATED) +async def run_flow_manually( + flow_id: UUID, + user: AuthenticatedUser, + db: DbSession, + input_payload: str | None = None, +) -> Execution: + """ + Manually trigger a flow execution. + + The flow must be in 'running' state. 
+ """ + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + if flow.status != "running": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Flow must be running to execute (current: {flow.status})", + ) + + # Create execution record + execution = ExecutionRecord( + flow_id=flow_id, + trigger_type=TriggerType.MANUAL, + status=ExecutionStatus.RUNNING, + input_payload=input_payload, + ) + db.add(execution) + await db.flush() + + # TODO: Actually dispatch to the running container + # For now, just return the execution record + + return Execution.model_validate(execution) + + +# ============================================================================= +# Stats endpoint +# ============================================================================= + + +@router.get("/stats", response_model=dict) +async def get_execution_stats( + flow_id: UUID, + user: AuthenticatedUser, + db: DbSession, +) -> dict: + """Get execution statistics for a flow.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + # Calculate stats + stats_query = select( + func.count().label("total"), + func.count().filter(ExecutionRecord.status == ExecutionStatus.SUCCESS).label("success"), + func.count().filter(ExecutionRecord.status == ExecutionStatus.ERROR).label("error"), + func.avg(ExecutionRecord.duration_ms).label("avg_duration_ms"), + func.max(ExecutionRecord.started_at).label("last_executed_at"), + ).where(ExecutionRecord.flow_id == flow_id) + + result = await db.execute(stats_query) + row = result.one() + + return { + "flowId": str(flow_id), + "executionsTotal": row.total or 0, + "executionsSuccess": row.success or 0, + "executionsError": row.error or 0, + "avgDurationMs": float(row.avg_duration_ms) if row.avg_duration_ms else 0, + "lastExecutedAt": row.last_executed_at.isoformat() if row.last_executed_at else None, + } diff --git a/bloxserver/api/routes/flows.py b/bloxserver/api/routes/flows.py new file mode 100644 index 0000000..89cf109 --- /dev/null +++ b/bloxserver/api/routes/flows.py @@ -0,0 +1,269 @@ +""" +Flow CRUD endpoints. + +Flows are the core entity - a user's workflow definition. 
+""" + +from __future__ import annotations + +from uuid import UUID + +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import func, select + +from bloxserver.api.dependencies import AuthenticatedUser, DbSession +from bloxserver.api.models.tables import FlowRecord, Tier +from bloxserver.api.schemas import ( + CreateFlowRequest, + Flow, + FlowSummary, + PaginatedResponse, + UpdateFlowRequest, +) + +router = APIRouter(prefix="/flows", tags=["flows"]) + +# Default organism.yaml template for new flows +DEFAULT_ORGANISM_YAML = """organism: + name: my-flow + +listeners: + - name: greeter + payload_class: handlers.hello.Greeting + handler: handlers.hello.handle_greeting + description: A friendly greeter agent + agent: true + peers: [] +""" + +# Tier limits +TIER_FLOW_LIMITS = { + Tier.FREE: 1, + Tier.PRO: 100, # Effectively unlimited for most users + Tier.ENTERPRISE: 1000, + Tier.HIGH_FREQUENCY: 1000, +} + + +@router.get("", response_model=PaginatedResponse[FlowSummary]) +async def list_flows( + user: AuthenticatedUser, + db: DbSession, + page: int = 1, + page_size: int = 20, +) -> PaginatedResponse[FlowSummary]: + """List all flows for the current user.""" + offset = (page - 1) * page_size + + # Get total count + count_query = select(func.count()).select_from(FlowRecord).where( + FlowRecord.user_id == user.id + ) + total = (await db.execute(count_query)).scalar() or 0 + + # Get page of flows + query = ( + select(FlowRecord) + .where(FlowRecord.user_id == user.id) + .order_by(FlowRecord.updated_at.desc()) + .offset(offset) + .limit(page_size) + ) + result = await db.execute(query) + flows = result.scalars().all() + + return PaginatedResponse( + items=[FlowSummary.model_validate(f) for f in flows], + total=total, + page=page, + page_size=page_size, + has_more=offset + len(flows) < total, + ) + + +@router.post("", response_model=Flow, status_code=status.HTTP_201_CREATED) +async def create_flow( + user: AuthenticatedUser, + db: DbSession, + request: CreateFlowRequest, +) -> Flow: + """Create a new flow.""" + # Check tier limits + count_query = select(func.count()).select_from(FlowRecord).where( + FlowRecord.user_id == user.id + ) + current_count = (await db.execute(count_query)).scalar() or 0 + limit = TIER_FLOW_LIMITS.get(user.user.tier, 1) + + if current_count >= limit: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Flow limit reached ({limit}). 
+
+
+@router.post("", response_model=Flow, status_code=status.HTTP_201_CREATED)
+async def create_flow(
+    user: AuthenticatedUser,
+    db: DbSession,
+    request: CreateFlowRequest,
+) -> Flow:
+    """Create a new flow."""
+    # Check tier limits
+    count_query = select(func.count()).select_from(FlowRecord).where(
+        FlowRecord.user_id == user.id
+    )
+    current_count = (await db.execute(count_query)).scalar() or 0
+    limit = TIER_FLOW_LIMITS.get(user.user.tier, 1)
+
+    if current_count >= limit:
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=f"Flow limit reached ({limit}). Upgrade to create more flows.",
+        )
+
+    # Create flow
+    flow = FlowRecord(
+        user_id=user.id,
+        name=request.name,
+        description=request.description,
+        organism_yaml=request.organism_yaml or DEFAULT_ORGANISM_YAML,
+    )
+    db.add(flow)
+    await db.flush()
+
+    return Flow.model_validate(flow)
+
+
+@router.get("/{flow_id}", response_model=Flow)
+async def get_flow(
+    flow_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+) -> Flow:
+    """Get a single flow by ID."""
+    query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    result = await db.execute(query)
+    flow = result.scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    return Flow.model_validate(flow)
+
+
+@router.patch("/{flow_id}", response_model=Flow)
+async def update_flow(
+    flow_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+    request: UpdateFlowRequest,
+) -> Flow:
+    """Update a flow."""
+    query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    result = await db.execute(query)
+    flow = result.scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    # Update fields that were provided
+    if request.name is not None:
+        flow.name = request.name
+    if request.description is not None:
+        flow.description = request.description
+    if request.organism_yaml is not None:
+        flow.organism_yaml = request.organism_yaml
+    if request.canvas_state is not None:
+        flow.canvas_state = request.canvas_state.model_dump()
+
+    await db.flush()
+    return Flow.model_validate(flow)
+
+
+@router.delete("/{flow_id}", status_code=status.HTTP_204_NO_CONTENT)
+async def delete_flow(
+    flow_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+) -> None:
+    """Delete a flow."""
+    query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    result = await db.execute(query)
+    flow = result.scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    await db.delete(flow)
+
+
+# =============================================================================
+# Flow Actions (Start/Stop)
+# =============================================================================
+
+
+@router.post("/{flow_id}/start", response_model=Flow)
+async def start_flow(
+    flow_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+) -> Flow:
+    """Start a flow (deploy container)."""
+    query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    result = await db.execute(query)
+    flow = result.scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    if flow.status not in (FlowStatus.STOPPED, FlowStatus.ERROR):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Cannot start flow in {flow.status.value} state",
+        )
+
+    # TODO: Actually start the container
+    # This is where we'd call the container orchestration layer
+    # For now, just update the status
+    flow.status = FlowStatus.STARTING
+    flow.error_message = None
+
+    await db.flush()
+    return Flow.model_validate(flow)
+
+
+@router.post("/{flow_id}/stop", response_model=Flow)
+async def stop_flow(
+    flow_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+) -> Flow:
+    """Stop a running flow."""
+    query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    result = await db.execute(query)
+    flow = result.scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    if flow.status not in (FlowStatus.RUNNING, FlowStatus.STARTING, FlowStatus.ERROR):
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=f"Cannot stop flow in {flow.status.value} state",
+        )
+
+    # TODO: Actually stop the container
+    flow.status = FlowStatus.STOPPING
+
+    await db.flush()
+    return Flow.model_validate(flow)
diff --git a/bloxserver/api/routes/health.py b/bloxserver/api/routes/health.py
new file mode 100644
index 0000000..dd587ff
--- /dev/null
+++ b/bloxserver/api/routes/health.py
@@ -0,0 +1,77 @@
+"""
+Health check and status endpoints.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+
+from fastapi import APIRouter
+from sqlalchemy import text
+
+from bloxserver.api.models.database import async_session_maker
+
+router = APIRouter(tags=["health"])
+
+
+@router.get("/health")
+async def health_check() -> dict:
+    """
+    Basic health check.
+
+    Returns 200 if the service is running.
+    """
+    return {
+        "status": "healthy",
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "service": "bloxserver-api",
+    }
+
+
+@router.get("/health/ready")
+async def readiness_check() -> dict:
+    """
+    Readiness check - verifies database connectivity.
+
+    Used by Kubernetes/load balancers to determine if the service
+    is ready to receive traffic.
+    """
+    errors = []
+
+    # Check database
+    try:
+        async with async_session_maker() as session:
+            await session.execute(text("SELECT 1"))
+    except Exception as e:
+        errors.append(f"database: {e}")
+
+    # TODO: Check Redis
+    # TODO: Check other dependencies
+
+    if errors:
+        return {
+            "status": "unhealthy",
+            "timestamp": datetime.now(timezone.utc).isoformat(),
+            "errors": errors,
+        }
+
+    return {
+        "status": "ready",
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+        "checks": {
+            "database": "ok",
+        },
+    }
+
+
+@router.get("/health/live")
+async def liveness_check() -> dict:
+    """
+    Liveness check - just confirms the process is running.
+
+    If this fails, Kubernetes should restart the pod.
+    """
+    return {
+        "status": "alive",
+        "timestamp": datetime.now(timezone.utc).isoformat(),
+    }
diff --git a/bloxserver/api/routes/triggers.py b/bloxserver/api/routes/triggers.py
new file mode 100644
index 0000000..ae8e435
--- /dev/null
+++ b/bloxserver/api/routes/triggers.py
@@ -0,0 +1,221 @@
+"""
+Trigger CRUD endpoints.
+
+Triggers define how flows are started: webhook, schedule, or manual.
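+
+Webhook triggers receive a random URL token (handled by the public endpoint in
+webhooks.py); treat the generated webhook URL as a secret.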
+""" + +from __future__ import annotations + +import secrets +from uuid import UUID + +from fastapi import APIRouter, HTTPException, status +from sqlalchemy import select + +from bloxserver.api.dependencies import AuthenticatedUser, DbSession +from bloxserver.api.models.tables import FlowRecord, TriggerRecord, TriggerType +from bloxserver.api.schemas import CreateTriggerRequest, Trigger + +router = APIRouter(prefix="/flows/{flow_id}/triggers", tags=["triggers"]) + +# Base URL for webhooks (configured via environment) +import os +WEBHOOK_BASE_URL = os.getenv("WEBHOOK_BASE_URL", "https://api.openblox.ai/webhooks") + + +def generate_webhook_token() -> str: + """Generate a secure random token for webhook URLs.""" + return secrets.token_urlsafe(32) + + +@router.get("", response_model=list[Trigger]) +async def list_triggers( + flow_id: UUID, + user: AuthenticatedUser, + db: DbSession, +) -> list[Trigger]: + """List all triggers for a flow.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + # Get triggers + query = select(TriggerRecord).where(TriggerRecord.flow_id == flow_id) + result = await db.execute(query) + triggers = result.scalars().all() + + return [Trigger.model_validate(t) for t in triggers] + + +@router.post("", response_model=Trigger, status_code=status.HTTP_201_CREATED) +async def create_trigger( + flow_id: UUID, + user: AuthenticatedUser, + db: DbSession, + request: CreateTriggerRequest, +) -> Trigger: + """Create a new trigger for a flow.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + # Create trigger + trigger = TriggerRecord( + flow_id=flow_id, + type=TriggerType(request.type.value), + name=request.name, + config=request.config, + ) + + # Generate webhook URL for webhook triggers + if request.type == TriggerType.WEBHOOK: + trigger.webhook_token = generate_webhook_token() + trigger.webhook_url = f"{WEBHOOK_BASE_URL}/{trigger.webhook_token}" + + db.add(trigger) + await db.flush() + + return Trigger.model_validate(trigger) + + +@router.get("/{trigger_id}", response_model=Trigger) +async def get_trigger( + flow_id: UUID, + trigger_id: UUID, + user: AuthenticatedUser, + db: DbSession, +) -> Trigger: + """Get a single trigger by ID.""" + # Verify flow ownership + flow_query = select(FlowRecord).where( + FlowRecord.id == flow_id, + FlowRecord.user_id == user.id, + ) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + # Get trigger + query = select(TriggerRecord).where( + TriggerRecord.id == trigger_id, + TriggerRecord.flow_id == flow_id, + ) + result = await db.execute(query) + trigger = result.scalar_one_or_none() + + if not trigger: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Trigger not found", + ) + + return Trigger.model_validate(trigger) + + +@router.delete("/{trigger_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_trigger( + flow_id: UUID, + trigger_id: UUID, + user: AuthenticatedUser, + db: 
+) -> None:
+    """Delete a trigger."""
+    # Verify flow ownership
+    flow_query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    flow = (await db.execute(flow_query)).scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    # Get and delete trigger
+    query = select(TriggerRecord).where(
+        TriggerRecord.id == trigger_id,
+        TriggerRecord.flow_id == flow_id,
+    )
+    result = await db.execute(query)
+    trigger = result.scalar_one_or_none()
+
+    if not trigger:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Trigger not found",
+        )
+
+    await db.delete(trigger)
+
+
+@router.post("/{trigger_id}/regenerate-token", response_model=Trigger)
+async def regenerate_webhook_token(
+    flow_id: UUID,
+    trigger_id: UUID,
+    user: AuthenticatedUser,
+    db: DbSession,
+) -> Trigger:
+    """Regenerate the webhook token for a webhook trigger."""
+    # Verify flow ownership
+    flow_query = select(FlowRecord).where(
+        FlowRecord.id == flow_id,
+        FlowRecord.user_id == user.id,
+    )
+    flow = (await db.execute(flow_query)).scalar_one_or_none()
+
+    if not flow:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Flow not found",
+        )
+
+    # Get trigger
+    query = select(TriggerRecord).where(
+        TriggerRecord.id == trigger_id,
+        TriggerRecord.flow_id == flow_id,
+    )
+    result = await db.execute(query)
+    trigger = result.scalar_one_or_none()
+
+    if not trigger:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Trigger not found",
+        )
+
+    if trigger.type != TriggerType.WEBHOOK:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Can only regenerate token for webhook triggers",
+        )
+
+    # Regenerate
+    trigger.webhook_token = generate_webhook_token()
+    trigger.webhook_url = f"{WEBHOOK_BASE_URL}/{trigger.webhook_token}"
+
+    await db.flush()
+    return Trigger.model_validate(trigger)
diff --git a/bloxserver/api/routes/webhooks.py b/bloxserver/api/routes/webhooks.py
new file mode 100644
index 0000000..e22bc23
--- /dev/null
+++ b/bloxserver/api/routes/webhooks.py
@@ -0,0 +1,125 @@
+"""
+Webhook trigger endpoint.
+
+This handles incoming webhook requests that trigger flows.
+"""
+
+from __future__ import annotations
+
+from fastapi import APIRouter, HTTPException, Request, status
+from sqlalchemy import select
+
+from bloxserver.api.models.database import get_db_context
+from bloxserver.api.models.tables import (
+    ExecutionRecord,
+    ExecutionStatus,
+    FlowRecord,
+    FlowStatus,
+    TriggerRecord,
+    TriggerType,
+)
+
+router = APIRouter(prefix="/webhooks", tags=["webhooks"])
+
+
+@router.post("/{webhook_token}")
+async def handle_webhook(
+    webhook_token: str,
+    request: Request,
+) -> dict:
+    """
+    Handle incoming webhook request.
+
+    This endpoint is public (no auth) - the token IS the authentication.
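+
+    Tokens come from secrets.token_urlsafe(32) (~256 bits of entropy) and can
+    be rotated via the regenerate-token endpoint in triggers.py.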
+ """ + async with get_db_context() as db: + # Look up trigger by token + query = select(TriggerRecord).where( + TriggerRecord.webhook_token == webhook_token, + TriggerRecord.type == TriggerType.WEBHOOK, + ) + result = await db.execute(query) + trigger = result.scalar_one_or_none() + + if not trigger: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Webhook not found", + ) + + # Get the flow + flow_query = select(FlowRecord).where(FlowRecord.id == trigger.flow_id) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + if not flow: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Flow not found", + ) + + if flow.status != "running": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Flow is not running (status: {flow.status})", + ) + + # Get request body + try: + body = await request.body() + input_payload = body.decode("utf-8") if body else None + except Exception: + input_payload = None + + # Create execution record + execution = ExecutionRecord( + flow_id=flow.id, + trigger_id=trigger.id, + trigger_type=TriggerType.WEBHOOK, + status=ExecutionStatus.RUNNING, + input_payload=input_payload, + ) + db.add(execution) + await db.commit() + + # TODO: Actually dispatch to the running container + # This would send the payload to the flow's container + + return { + "status": "accepted", + "executionId": str(execution.id), + "message": "Webhook received and execution started", + } + + +@router.get("/{webhook_token}/test") +async def test_webhook(webhook_token: str) -> dict: + """ + Test that a webhook token is valid. + + Returns info about the trigger without actually executing. + """ + async with get_db_context() as db: + query = select(TriggerRecord).where( + TriggerRecord.webhook_token == webhook_token, + TriggerRecord.type == TriggerType.WEBHOOK, + ) + result = await db.execute(query) + trigger = result.scalar_one_or_none() + + if not trigger: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Webhook not found", + ) + + # Get the flow + flow_query = select(FlowRecord).where(FlowRecord.id == trigger.flow_id) + flow = (await db.execute(flow_query)).scalar_one_or_none() + + return { + "valid": True, + "triggerName": trigger.name, + "flowName": flow.name if flow else None, + "flowStatus": flow.status.value if flow else None, + } diff --git a/bloxserver/api/schemas.py b/bloxserver/api/schemas.py new file mode 100644 index 0000000..acf7f1b --- /dev/null +++ b/bloxserver/api/schemas.py @@ -0,0 +1,322 @@ +""" +Pydantic schemas for API request/response validation. + +These match the TypeScript types in types.ts for frontend compatibility. +Uses camelCase aliases for JSON serialization. 
+""" + +from __future__ import annotations + +from datetime import datetime +from enum import Enum +from typing import Any, Generic, Literal, TypeVar +from uuid import UUID + +from pydantic import BaseModel, ConfigDict, Field + + +# ============================================================================= +# Config for camelCase serialization +# ============================================================================= + + +def to_camel(string: str) -> str: + """Convert snake_case to camelCase.""" + components = string.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +class CamelModel(BaseModel): + """Base model with camelCase JSON serialization.""" + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + from_attributes=True, + ) + + +# ============================================================================= +# Common Types +# ============================================================================= + +T = TypeVar("T") + + +class PaginatedResponse(CamelModel, Generic[T]): + """Paginated list response.""" + + items: list[T] + total: int + page: int + page_size: int + has_more: bool + + +class ApiError(CamelModel): + """API error response.""" + + code: str + message: str + details: dict[str, Any] | None = None + + +# ============================================================================= +# Enums +# ============================================================================= + + +class Tier(str, Enum): + FREE = "free" + PRO = "pro" + ENTERPRISE = "enterprise" + HIGH_FREQUENCY = "high_frequency" + + +class FlowStatus(str, Enum): + STOPPED = "stopped" + STARTING = "starting" + RUNNING = "running" + STOPPING = "stopping" + ERROR = "error" + + +class TriggerType(str, Enum): + WEBHOOK = "webhook" + SCHEDULE = "schedule" + MANUAL = "manual" + + +class ExecutionStatus(str, Enum): + RUNNING = "running" + SUCCESS = "success" + ERROR = "error" + TIMEOUT = "timeout" + + +# ============================================================================= +# User +# ============================================================================= + + +class User(CamelModel): + """User account (synced from Clerk).""" + + id: UUID + clerk_id: str + email: str + name: str | None = None + avatar_url: str | None = None + tier: Tier = Tier.FREE + created_at: datetime + + +# ============================================================================= +# Canvas State (React Flow) +# ============================================================================= + + +class CanvasNode(CamelModel): + """A node in the React Flow canvas.""" + + id: str + type: str + position: dict[str, float] + data: dict[str, Any] + + +class CanvasEdge(CamelModel): + """An edge connecting nodes in the canvas.""" + + id: str + source: str + target: str + source_handle: str | None = None + target_handle: str | None = None + + +class CanvasState(CamelModel): + """React Flow canvas state.""" + + nodes: list[CanvasNode] + edges: list[CanvasEdge] + viewport: dict[str, float] + + +# ============================================================================= +# Flows +# ============================================================================= + + +class Flow(CamelModel): + """A user's workflow/flow.""" + + id: UUID + user_id: UUID + name: str + description: str | None = None + organism_yaml: str + canvas_state: CanvasState | None = None + status: FlowStatus = FlowStatus.STOPPED + container_id: str | None = None + error_message: str | None = None + 
created_at: datetime + updated_at: datetime + + +class FlowSummary(CamelModel): + """Abbreviated flow for list views.""" + + id: UUID + name: str + description: str | None = None + status: FlowStatus + updated_at: datetime + + +class CreateFlowRequest(CamelModel): + """Request to create a new flow.""" + + name: str = Field(min_length=1, max_length=100) + description: str | None = Field(default=None, max_length=500) + organism_yaml: str | None = None + + +class UpdateFlowRequest(CamelModel): + """Request to update a flow.""" + + name: str | None = Field(default=None, min_length=1, max_length=100) + description: str | None = Field(default=None, max_length=500) + organism_yaml: str | None = None + canvas_state: CanvasState | None = None + + +# ============================================================================= +# Triggers +# ============================================================================= + + +class WebhookTriggerConfig(CamelModel): + """Config for webhook triggers.""" + + type: Literal["webhook"] = "webhook" + + +class ScheduleTriggerConfig(CamelModel): + """Config for scheduled triggers.""" + + type: Literal["schedule"] = "schedule" + cron: str = Field(description="Cron expression") + timezone: str = "UTC" + + +class ManualTriggerConfig(CamelModel): + """Config for manual triggers.""" + + type: Literal["manual"] = "manual" + + +TriggerConfig = WebhookTriggerConfig | ScheduleTriggerConfig | ManualTriggerConfig + + +class Trigger(CamelModel): + """A trigger that can start a flow.""" + + id: UUID + flow_id: UUID + type: TriggerType + name: str + config: dict[str, Any] + webhook_token: str | None = None + webhook_url: str | None = None + created_at: datetime + + +class CreateTriggerRequest(CamelModel): + """Request to create a trigger.""" + + type: TriggerType + name: str = Field(min_length=1, max_length=100) + config: dict[str, Any] + + +# ============================================================================= +# Executions +# ============================================================================= + + +class Execution(CamelModel): + """A single execution/run of a flow.""" + + id: UUID + flow_id: UUID + trigger_id: UUID | None = None + trigger_type: TriggerType + status: ExecutionStatus + started_at: datetime + completed_at: datetime | None = None + duration_ms: int | None = None + error_message: str | None = None + input_payload: str | None = None + output_payload: str | None = None + + +class ExecutionSummary(CamelModel): + """Abbreviated execution for list views.""" + + id: UUID + status: ExecutionStatus + trigger_type: TriggerType + started_at: datetime + duration_ms: int | None = None + + +# ============================================================================= +# Usage & Stats +# ============================================================================= + + +class UsageDashboard(CamelModel): + """Current usage for user dashboard.""" + + period_start: datetime + period_end: datetime | None + runs_used: int + runs_limit: int + runs_percentage: float + tokens_used: int + estimated_overage: float + days_remaining: int + + +class FlowStats(CamelModel): + """Statistics for a single flow.""" + + flow_id: UUID + executions_total: int + executions_success: int + executions_error: int + avg_duration_ms: float + last_executed_at: datetime | None = None + + +# ============================================================================= +# API Keys (BYOK) +# ============================================================================= + + +class 
ApiKeyInfo(CamelModel): + """Info about a stored API key (never exposes the key itself).""" + + provider: str + key_hint: str | None # Last few chars: "...abc123" + is_valid: bool + last_used_at: datetime | None + created_at: datetime + + +class AddApiKeyRequest(CamelModel): + """Request to add a user's API key.""" + + provider: str = Field(description="Provider name: openai, anthropic, xai") + api_key: str = Field(min_length=10, description="The API key") diff --git a/bloxserver/docker-compose.yml b/bloxserver/docker-compose.yml new file mode 100644 index 0000000..c3f39a0 --- /dev/null +++ b/bloxserver/docker-compose.yml @@ -0,0 +1,72 @@ +# BloxServer Development Docker Compose +# Run with: docker-compose up -d + +version: '3.8' + +services: + # ========================================================================== + # PostgreSQL Database + # ========================================================================== + postgres: + image: postgres:16-alpine + container_name: bloxserver-postgres + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: bloxserver + ports: + - "5432:5432" + volumes: + - postgres_data:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U postgres"] + interval: 10s + timeout: 5s + retries: 5 + + # ========================================================================== + # Redis (for caching, rate limiting, queues) + # ========================================================================== + redis: + image: redis:7-alpine + container_name: bloxserver-redis + ports: + - "6379:6379" + volumes: + - redis_data:/data + healthcheck: + test: ["CMD", "redis-cli", "ping"] + interval: 10s + timeout: 5s + retries: 5 + + # ========================================================================== + # BloxServer API + # ========================================================================== + api: + build: + context: . 
+ dockerfile: Dockerfile + container_name: bloxserver-api + ports: + - "8000:8000" + environment: + - ENV=development + - DATABASE_URL=postgresql+asyncpg://postgres:postgres@postgres:5432/bloxserver + - REDIS_URL=redis://redis:6379 + - AUTO_CREATE_TABLES=true + - ENABLE_DOCS=true + - CORS_ORIGINS=http://localhost:3000 + depends_on: + postgres: + condition: service_healthy + redis: + condition: service_healthy + volumes: + # Mount source for hot reload in development + - .:/app/bloxserver:ro + command: uvicorn bloxserver.api.main:app --host 0.0.0.0 --port 8000 --reload + +volumes: + postgres_data: + redis_data: diff --git a/bloxserver/requirements.txt b/bloxserver/requirements.txt new file mode 100644 index 0000000..8ec147a --- /dev/null +++ b/bloxserver/requirements.txt @@ -0,0 +1,31 @@ +# BloxServer API Dependencies + +# Web framework +fastapi>=0.109.0 +uvicorn[standard]>=0.27.0 + +# Database +sqlalchemy[asyncio]>=2.0.0 +asyncpg>=0.29.0 +alembic>=1.13.0 + +# Authentication (Clerk JWT validation) +pyjwt[crypto]>=2.8.0 +httpx>=0.27.0 + +# Validation & serialization +pydantic>=2.5.0 +pydantic-settings>=2.1.0 + +# Utilities +python-dotenv>=1.0.0 +humps>=0.2.2 + +# Stripe billing +stripe>=8.0.0 + +# Redis (for caching/rate limiting) +redis>=5.0.0 + +# Cryptography (for API key encryption) +cryptography>=42.0.0 diff --git a/docs/bloxserver-billing.md b/docs/bloxserver-billing.md new file mode 100644 index 0000000..b375e2e --- /dev/null +++ b/docs/bloxserver-billing.md @@ -0,0 +1,668 @@ +# BloxServer Billing Integration — Stripe + +**Status:** Design +**Date:** January 2026 + +## Overview + +BloxServer uses Stripe for subscription management, usage-based billing, and payment processing. This document specifies the integration architecture, webhook handlers, and usage tracking system. 
+ +## Pricing Tiers + +| Tier | Price | Runs/Month | Features | +|------|-------|------------|----------| +| **Free** | $0 | 1,000 | 1 workflow, built-in tools, community support | +| **Pro** | $29 | 100,000 | Unlimited workflows, marketplace, WASM, project memory, priority support | +| **Enterprise** | Custom | Unlimited | SSO/SAML, SLA, dedicated support, private marketplace | + +### Overage Pricing (Pro) + +| Metric | Included | Overage Rate | +|--------|----------|--------------| +| Workflow runs | 100K/mo | $0.50 per 1K | +| Storage | 10 GB | $0.10 per GB | +| WASM execution | 1000 CPU-sec | $0.01 per CPU-sec | + +## Stripe Product Structure + +``` +Products: +├── bloxserver_free +│ └── price_free_monthly ($0/month, metered runs) +├── bloxserver_pro +│ ├── price_pro_monthly ($29/month base) +│ ├── price_pro_runs_overage (metered, $0.50/1K) +│ └── price_pro_storage_overage (metered, $0.10/GB) +└── bloxserver_enterprise + └── price_enterprise_custom (quoted per customer) +``` + +### Stripe Configuration + +```python +# One-time setup (or via Stripe Dashboard) + +# Free tier product +free_product = stripe.Product.create( + name="BloxServer Free", + description="Build AI agent swarms, visually", +) + +free_price = stripe.Price.create( + product=free_product.id, + unit_amount=0, + currency="usd", + recurring={"interval": "month"}, + metadata={"tier": "free", "runs_included": "1000"} +) + +# Pro tier product +pro_product = stripe.Product.create( + name="BloxServer Pro", + description="Unlimited workflows, marketplace access, custom WASM", +) + +pro_base_price = stripe.Price.create( + product=pro_product.id, + unit_amount=2900, # $29.00 + currency="usd", + recurring={"interval": "month"}, + metadata={"tier": "pro", "runs_included": "100000"} +) + +pro_runs_overage = stripe.Price.create( + product=pro_product.id, + currency="usd", + recurring={ + "interval": "month", + "usage_type": "metered", + "aggregate_usage": "sum", + }, + unit_amount_decimal="0.05", # $0.0005 per run = $0.50 per 1K + metadata={"type": "runs_overage"} +) +``` + +## Database Schema + +```sql +-- Users table (synced from Clerk + Stripe) +CREATE TABLE users ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + clerk_id VARCHAR(255) UNIQUE NOT NULL, + email VARCHAR(255) NOT NULL, + name VARCHAR(255), + + -- Stripe fields + stripe_customer_id VARCHAR(255) UNIQUE, + stripe_subscription_id VARCHAR(255), + stripe_subscription_item_id VARCHAR(255), -- For usage reporting + + -- Billing state (cached from Stripe) + tier VARCHAR(50) DEFAULT 'free', -- free, pro, enterprise + billing_status VARCHAR(50) DEFAULT 'active', -- active, past_due, canceled + trial_ends_at TIMESTAMPTZ, + current_period_start TIMESTAMPTZ, + current_period_end TIMESTAMPTZ, + + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() +); + +-- Usage tracking (local, for dashboard + Stripe sync) +CREATE TABLE usage_records ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID REFERENCES users(id), + period_start DATE NOT NULL, -- Billing period start + + -- Metrics + workflow_runs INT DEFAULT 0, + llm_tokens_in INT DEFAULT 0, + llm_tokens_out INT DEFAULT 0, + wasm_cpu_seconds DECIMAL(10,2) DEFAULT 0, + storage_gb_hours DECIMAL(10,2) DEFAULT 0, + + -- Stripe sync state + last_synced_at TIMESTAMPTZ, + last_synced_runs INT DEFAULT 0, + + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW(), + + UNIQUE(user_id, period_start) +); + +-- Stripe webhook events (idempotency) +CREATE TABLE stripe_events ( + event_id 
VARCHAR(255) PRIMARY KEY, + event_type VARCHAR(100) NOT NULL, + processed_at TIMESTAMPTZ DEFAULT NOW(), + payload JSONB +); + +-- Index for cleanup +CREATE INDEX idx_stripe_events_processed ON stripe_events(processed_at); +``` + +## Usage Tracking + +### Real-Time Counting (Redis) + +```python +# On every workflow execution +async def record_workflow_run(user_id: str): + """Increment run counter in Redis.""" + key = f"usage:{user_id}:runs:{get_current_period()}" + await redis.incr(key) + await redis.expire(key, 86400 * 35) # 35 days TTL + + # Track users with usage for batch sync + await redis.sadd("users:with_usage", user_id) + +async def record_llm_tokens(user_id: str, tokens_in: int, tokens_out: int): + """Track LLM token usage.""" + period = get_current_period() + await redis.incrby(f"usage:{user_id}:tokens_in:{period}", tokens_in) + await redis.incrby(f"usage:{user_id}:tokens_out:{period}", tokens_out) +``` + +### Periodic Sync to Stripe (Hourly) + +```python +async def sync_usage_to_stripe(): + """Hourly job: push usage increments to Stripe.""" + + user_ids = await redis.smembers("users:with_usage") + + for user_id in user_ids: + user = await get_user(user_id) + if not user.stripe_subscription_item_id: + continue # Free tier without Stripe subscription + + # Get usage since last sync + period = get_current_period() + runs_key = f"usage:{user_id}:runs:{period}" + + current_runs = int(await redis.get(runs_key) or 0) + last_synced = await get_last_synced_runs(user_id, period) + + delta = current_runs - last_synced + if delta <= 0: + continue + + # Check if over included limit + tier_limit = get_tier_runs_limit(user.tier) # 1000 or 100000 + if current_runs <= tier_limit: + # Still within included runs, just track locally + await update_last_synced(user_id, period, current_runs) + continue + + # Calculate overage to report + overage_start = max(last_synced, tier_limit) + overage_runs = current_runs - overage_start + + if overage_runs > 0: + # Report to Stripe + await stripe.subscription_items.create_usage_record( + user.stripe_subscription_item_id, + quantity=overage_runs, + timestamp=int(time.time()), + action='increment' + ) + + await update_last_synced(user_id, period, current_runs) + + # Clear the tracking set (will rebuild next hour) + await redis.delete("users:with_usage") +``` + +### Dashboard Query + +```python +async def get_usage_dashboard(user_id: str) -> UsageDashboard: + """Get current usage for user dashboard.""" + user = await get_user(user_id) + period = get_current_period() + + # Get real-time counts from Redis + runs = int(await redis.get(f"usage:{user_id}:runs:{period}") or 0) + tokens_in = int(await redis.get(f"usage:{user_id}:tokens_in:{period}") or 0) + tokens_out = int(await redis.get(f"usage:{user_id}:tokens_out:{period}") or 0) + + tier_limit = get_tier_runs_limit(user.tier) + + return UsageDashboard( + period_start=period, + period_end=user.current_period_end, + + runs_used=runs, + runs_limit=tier_limit, + runs_percentage=min(100, (runs / tier_limit) * 100), + + tokens_used=tokens_in + tokens_out, + + estimated_overage=calculate_overage_cost(runs, tier_limit), + + days_remaining=(user.current_period_end - datetime.now()).days, + ) +``` + +## Subscription Lifecycle + +### Signup Flow + +``` +User clicks "Start Free Trial" + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ 1. 
Create Stripe Customer │ +│ │ +│ customer = stripe.Customer.create( │ +│ email=user.email, │ +│ metadata={"clerk_id": user.clerk_id} │ +│ ) │ +└───────────────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ 2. Create Checkout Session (hosted payment page) │ +│ │ +│ session = stripe.checkout.Session.create( │ +│ customer=customer.id, │ +│ mode='subscription', │ +│ line_items=[{ │ +│ 'price': 'price_pro_monthly', │ +│ 'quantity': 1 │ +│ }, { │ +│ 'price': 'price_pro_runs_overage', # metered │ +│ }], │ +│ subscription_data={ │ +│ 'trial_period_days': 14, │ +│ }, │ +│ success_url='https://app.openblox.ai/welcome', │ +│ cancel_url='https://app.openblox.ai/pricing', │ +│ ) │ +│ │ +│ → Redirect user to session.url │ +└───────────────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ 3. User enters payment details on Stripe Checkout │ +│ │ +│ Card validated but NOT charged (trial) │ +└───────────────────────────────────────────────────────────┘ + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ 4. Webhook: checkout.session.completed │ +│ │ +│ → Update user with stripe_customer_id │ +│ → Update user with stripe_subscription_id │ +│ → Set tier = 'pro' │ +│ → Set trial_ends_at │ +└───────────────────────────────────────────────────────────┘ +``` + +### Trial End + +``` +Day 11 of 14-day trial + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Scheduled job: Trial ending soon emails │ +│ │ +│ SELECT * FROM users │ +│ WHERE trial_ends_at BETWEEN NOW() AND NOW() + INTERVAL '3d'│ +│ AND billing_status = 'trialing' │ +│ │ +│ → Send "Your trial ends in 3 days" email │ +└───────────────────────────────────────────────────────────┘ + │ + ▼ +Day 14: Trial ends + │ + ▼ +┌───────────────────────────────────────────────────────────┐ +│ Stripe automatically: │ +│ 1. Charges the card on file │ +│ 2. Sends invoice.payment_succeeded webhook │ +│ │ +│ Our webhook handler: │ +│ → Update billing_status = 'active' │ +│ → Send "Welcome to Pro!" 
email                                     │
+└───────────────────────────────────────────────────────────┘
+```
+
+### Cancellation
+
+```python
+# User clicks "Cancel subscription" in Customer Portal
+# Stripe sends webhook
+
+@webhook("customer.subscription.updated")
+async def handle_subscription_updated(event):
+    subscription = event.data.object
+    user = await get_user_by_stripe_subscription(subscription.id)
+
+    if subscription.cancel_at_period_end:
+        # User requested cancellation (takes effect at period end)
+        await send_email(user, "subscription_canceled", {
+            "effective_date": subscription.current_period_end
+        })
+        await db.execute("""
+            UPDATE users
+            SET billing_status = 'canceling',
+                updated_at = NOW()
+            WHERE id = $1
+        """, user.id)
+
+@webhook("customer.subscription.deleted")
+async def handle_subscription_deleted(event):
+    subscription = event.data.object
+    user = await get_user_by_stripe_subscription(subscription.id)
+
+    # Subscription actually ended
+    await db.execute("""
+        UPDATE users
+        SET tier = 'free',
+            billing_status = 'canceled',
+            stripe_subscription_id = NULL,
+            stripe_subscription_item_id = NULL,
+            updated_at = NOW()
+        WHERE id = $1
+    """, user.id)
+
+    await send_email(user, "downgraded_to_free")
+```
+
+## Webhook Handlers
+
+### Endpoint Setup
+
+```python
+from fastapi import FastAPI, Request, HTTPException
+import stripe
+
+app = FastAPI()
+
+@app.post("/webhooks/stripe")
+async def stripe_webhook(request: Request):
+    payload = await request.body()
+    sig_header = request.headers.get("stripe-signature")
+
+    try:
+        event = stripe.Webhook.construct_event(
+            payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
+        )
+    except ValueError:
+        raise HTTPException(400, "Invalid payload")
+    except stripe.error.SignatureVerificationError:
+        raise HTTPException(400, "Invalid signature")
+
+    # Idempotency check
+    if await is_event_processed(event.id):
+        return {"status": "already_processed"}
+
+    # Route to handler
+    handler = WEBHOOK_HANDLERS.get(event.type)
+    if handler:
+        await handler(event)
+    else:
+        logger.info(f"Unhandled webhook: {event.type}")
+
+    # Mark processed
+    await mark_event_processed(event)
+
+    return {"status": "success"}
+```
+
+### Handler Registry
+
+```python
+WEBHOOK_HANDLERS = {
+    # Checkout
+    "checkout.session.completed": handle_checkout_completed,
+
+    # Subscriptions
+    "customer.subscription.created": handle_subscription_created,
+    "customer.subscription.updated": handle_subscription_updated,
+    "customer.subscription.deleted": handle_subscription_deleted,
+    "customer.subscription.trial_will_end": handle_trial_ending,
+
+    # Payments
+    "invoice.payment_succeeded": handle_payment_succeeded,
+    "invoice.payment_failed": handle_payment_failed,
+    "invoice.upcoming": handle_invoice_upcoming,
+
+    # Customer
+    "customer.updated": handle_customer_updated,
+}
+```
+
+### Key Handlers
+
+```python
+@webhook("checkout.session.completed")
+async def handle_checkout_completed(event):
+    """User completed checkout - provision their account."""
+    session = event.data.object
+
+    # Look up the user via the clerk_id stored in Customer metadata at signup
+    customer = await stripe.Customer.retrieve(session.customer)
+    user = await get_user_by_clerk_id(customer.metadata["clerk_id"])
+
+    # Update with Stripe IDs
+    subscription = await stripe.Subscription.retrieve(session.subscription)
+
+    await db.execute("""
+        UPDATE users SET
+            stripe_customer_id = $1,
+            stripe_subscription_id = $2,
+            stripe_subscription_item_id = $3,
+            tier = $4,
+            billing_status = $5,
+            trial_ends_at = $6,
+            current_period_start = $7,
+            current_period_end = $8,
+            updated_at = NOW()
+        WHERE id = $9
+    """,
+        session.customer,
+        subscription.id,
+        
subscription['items'].data[0].id, # First item for usage reporting + 'pro', + subscription.status, # 'trialing' or 'active' + datetime.fromtimestamp(subscription.trial_end) if subscription.trial_end else None, + datetime.fromtimestamp(subscription.current_period_start), + datetime.fromtimestamp(subscription.current_period_end), + user.id + ) + + +@webhook("invoice.payment_failed") +async def handle_payment_failed(event): + """Payment failed - notify user, potentially downgrade.""" + invoice = event.data.object + user = await get_user_by_stripe_customer(invoice.customer) + + attempt_count = invoice.attempt_count + + if attempt_count == 1: + # First failure - soft warning + await send_email(user, "payment_failed_soft", { + "amount": invoice.amount_due / 100, + "update_url": await get_customer_portal_url(user) + }) + + elif attempt_count == 2: + # Second failure - stronger warning + await send_email(user, "payment_failed_warning", { + "amount": invoice.amount_due / 100, + "days_until_downgrade": 3 + }) + + else: + # Final failure - downgrade + await db.execute(""" + UPDATE users SET + tier = 'free', + billing_status = 'past_due', + updated_at = NOW() + WHERE id = $1 + """, user.id) + + await send_email(user, "downgraded_payment_failed") + + +@webhook("customer.subscription.trial_will_end") +async def handle_trial_ending(event): + """Trial ending in 3 days - Stripe sends this automatically.""" + subscription = event.data.object + user = await get_user_by_stripe_subscription(subscription.id) + + await send_email(user, "trial_ending", { + "trial_end_date": datetime.fromtimestamp(subscription.trial_end), + "amount": 29.00, # Pro price + "manage_url": await get_customer_portal_url(user) + }) +``` + +## Customer Portal + +Stripe's hosted portal for self-service billing management. + +```python +async def get_customer_portal_url(user: User) -> str: + """Generate a portal session URL for the user.""" + session = await stripe.billing_portal.Session.create( + customer=user.stripe_customer_id, + return_url="https://app.openblox.ai/settings/billing" + ) + return session.url +``` + +**Portal capabilities:** +- Update payment method +- View invoices and receipts +- Cancel subscription +- Upgrade/downgrade plan (if configured) + +## Email Templates + +| Trigger | Template | Content | +|---------|----------|---------| +| Trial started | `trial_started` | Welcome, trial ends on X | +| Trial ending (3 days) | `trial_ending` | Your trial ends soon, card will be charged | +| Trial converted | `trial_converted` | Welcome to Pro! | +| Payment succeeded | `payment_succeeded` | Receipt attached | +| Payment failed (1st) | `payment_failed_soft` | Please update your card | +| Payment failed (2nd) | `payment_failed_warning` | Service will be interrupted | +| Payment failed (final) | `downgraded_payment_failed` | You've been downgraded | +| Subscription canceled | `subscription_canceled` | Access until period end | +| Downgraded | `downgraded_to_free` | You're now on Free | + +## Rate Limiting & Abuse Prevention + +### Soft Limits (Warning) + +```python +async def check_usage_limits(user_id: str) -> UsageLimitResult: + """Check if user is approaching limits.""" + usage = await get_current_usage(user_id) + user = await get_user(user_id) + tier_limit = get_tier_runs_limit(user.tier) + + percentage = (usage.runs / tier_limit) * 100 + + if percentage >= 100: + return UsageLimitResult( + allowed=True, # Still allow, but warn + warning="You've exceeded your included runs. 
Overage charges apply.", + overage_rate="$0.50 per 1,000 runs" + ) + elif percentage >= 80: + return UsageLimitResult( + allowed=True, + warning=f"You've used {percentage:.0f}% of your monthly runs." + ) + + return UsageLimitResult(allowed=True) +``` + +### Hard Limits (Free Tier) + +```python +async def enforce_free_tier_limits(user_id: str) -> bool: + """Free tier has hard limits - no overage allowed.""" + user = await get_user(user_id) + if user.tier != "free": + return True # Paid tiers have soft limits + + usage = await get_current_usage(user_id) + if usage.runs >= 1000: + raise UsageLimitExceeded( + "You've reached the Free tier limit of 1,000 runs/month. " + "Upgrade to Pro for unlimited workflows." + ) + + return True +``` + +## Testing + +### Test Mode + +Stripe provides test mode with test API keys and test card numbers. + +```python +# .env +STRIPE_SECRET_KEY=sk_test_... # Test mode +STRIPE_WEBHOOK_SECRET=whsec_... + +# Test cards +# 4242424242424242 - Succeeds +# 4000000000000002 - Declined +# 4000002500003155 - Requires 3D Secure +``` + +### Webhook Testing + +```bash +# Use Stripe CLI to forward webhooks locally +stripe listen --forward-to localhost:8000/webhooks/stripe + +# Trigger test events +stripe trigger invoice.payment_succeeded +stripe trigger customer.subscription.trial_will_end +``` + +## Monitoring & Alerts + +| Metric | Alert Threshold | +|--------|-----------------| +| Webhook processing time | > 5 seconds | +| Webhook failure rate | > 1% | +| Payment failure rate | > 5% | +| Usage sync lag | > 2 hours | +| Stripe API errors | Any 5xx | + +## Security Checklist + +- [ ] Webhook signature verification +- [ ] Idempotent event processing +- [ ] API keys in environment variables (never in code) +- [ ] Customer portal for sensitive operations (not custom UI) +- [ ] PCI compliance via Stripe Checkout (no card data touches our servers) +- [ ] Audit log for billing events + +--- + +## References + +- [Stripe Billing](https://stripe.com/docs/billing) +- [Stripe Webhooks](https://stripe.com/docs/webhooks) +- [Stripe Checkout](https://stripe.com/docs/payments/checkout) +- [Stripe Customer Portal](https://stripe.com/docs/billing/subscriptions/customer-portal) +- [Metered Billing](https://stripe.com/docs/billing/subscriptions/metered-billing) diff --git a/docs/bloxserver-llm-layer.md b/docs/bloxserver-llm-layer.md new file mode 100644 index 0000000..5361b3f --- /dev/null +++ b/docs/bloxserver-llm-layer.md @@ -0,0 +1,961 @@ +# BloxServer LLM Abstraction Layer — Resilient Multi-Provider Architecture + +**Status:** Design +**Date:** January 2026 + +## Overview + +The LLM abstraction layer is the critical path for all AI operations in BloxServer. It must handle: + +- **Viral growth**: 100 → 10,000 users overnight +- **Provider outages**: Single provider down ≠ platform down +- **Fair access**: Paid users prioritized, free users served fairly +- **Cost control**: Platform keys vs BYOK (Bring Your Own Key) +- **Low latency**: Sub-second for simple calls, reasonable for complex + +This document specifies the defense-in-depth architecture that survives success. 
+ +## Architecture + +``` +┌─────────────────────────────────────────────────────────────────┐ +│ LLM Abstraction Layer │ +│ │ +│ Request → [Rate Limit] → [Cache Check] → [Queue] → [Dispatch] │ +│ │ │ │ │ │ +│ ▼ ▼ ▼ ▼ │ +│ Per-user Semantic Priority Provider │ +│ per-tier cache queues pool + │ +│ limits (30%+ hits) (by tier) failover │ +│ │ +│ ┌─────────────────────────────────────────────────────────────┐│ +│ │ BYOK (Bring Your Own Key) ││ +│ │ Pro+ users with own API keys bypass platform limits ││ +│ └─────────────────────────────────────────────────────────────┘│ +│ │ +│ ┌─────────────────────────────────────────────────────────────┐│ +│ │ High Frequency Tier ││ +│ │ Dedicated capacity, custom SLA — contact sales ││ +│ └─────────────────────────────────────────────────────────────┘│ +└─────────────────────────────────────────────────────────────────┘ +``` + +## Tier Limits + +| Tier | Price | Requests/min | Tokens/min | Concurrent | Latency SLA | +|------|-------|--------------|------------|------------|-------------| +| **Free** | $0 | 10 | 10,000 | 2 | Best effort | +| **Pro** | $29/mo | 60 | 100,000 | 10 | < 30s P95 | +| **Enterprise** | Custom | 300 | 500,000 | 50 | < 10s P95 | +| **High Frequency** | Custom | Custom | Custom | Dedicated | Custom SLA | +| **BYOK** (any tier) | — | Unlimited* | Unlimited* | 20 | User's provider | + +*BYOK users are limited only by their own provider's rate limits. + +### High Frequency Tier + +For users requiring: +- **Low latency**: Sub-second response times +- **High throughput**: Thousands of requests per minute +- **Guaranteed capacity**: Dedicated provider allocations +- **Custom models**: Fine-tuned or private deployments + +**Use cases:** +- Real-time trading signals +- Live customer support at scale +- High-volume content generation +- Latency-sensitive applications + +**Pricing:** Custom — based on capacity reservation, SLA requirements, and volume. + +**Landing page CTA:** +``` +┌─────────────────────────────────────────────────────────────┐ +│ │ +│ Need High Frequency? │ +│ │ +│ Building something that needs thousands of requests per │ +│ minute with sub-second latency? Let's talk dedicated │ +│ capacity and custom SLAs. │ +│ │ +│ [Contact Sales →] │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Layer 1: Intake Rate Limiting + +First line of defense. Rejects requests before they consume resources. 
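+
+From the client's point of view, a rejection surfaces as HTTP 429 with a
+`Retry-After` header. A minimal FastAPI sketch of that mapping (the
+`RateLimitExceeded` exception mirrors the one raised in the main entry point
+later in this document; the handler itself is an illustrative assumption, not
+part of the spec):
+
+```python
+from fastapi import FastAPI, Request
+from fastapi.responses import JSONResponse
+
+app = FastAPI()
+
+class RateLimitExceeded(Exception):
+    def __init__(self, message: str, retry_after: int):
+        self.message = message
+        self.retry_after = retry_after
+        super().__init__(message)
+
+@app.exception_handler(RateLimitExceeded)
+async def rate_limit_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
+    # 429 + Retry-After lets well-behaved clients back off automatically
+    return JSONResponse(
+        status_code=429,
+        content={"code": "rate_limited", "message": exc.message},
+        headers={"Retry-After": str(exc.retry_after)},
+    )
+```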
+
+### Implementation
+
+```python
+from dataclasses import dataclass
+from enum import Enum
+import time
+
+class Tier(Enum):
+    FREE = "free"
+    PRO = "pro"
+    ENTERPRISE = "enterprise"
+    HIGH_FREQUENCY = "high_frequency"
+
+@dataclass
+class TierLimits:
+    requests_per_minute: int
+    tokens_per_minute: int
+    max_concurrent: int
+
+TIER_LIMITS = {
+    Tier.FREE: TierLimits(10, 10_000, 2),
+    Tier.PRO: TierLimits(60, 100_000, 10),
+    Tier.ENTERPRISE: TierLimits(300, 500_000, 50),
+    Tier.HIGH_FREQUENCY: TierLimits(10_000, 10_000_000, 500),  # Custom per customer
+}
+
+@dataclass
+class RateLimitResult:
+    allowed: bool
+    use_user_key: bool = False
+    retry_after: int | None = None
+    reason: str | None = None
+    concurrent_key: str | None = None
+
+async def rate_limit_check(user: User, request: LLMRequest) -> RateLimitResult:
+    """Check if user can make this request."""
+
+    # BYOK users bypass platform limits
+    if user.has_own_api_key(request.provider):
+        return RateLimitResult(allowed=True, use_user_key=True)
+
+    limits = TIER_LIMITS[user.tier]
+
+    # Check requests per minute (sliding window)
+    rpm_key = f"ratelimit:{user.id}:rpm"
+    now = time.time()
+    window_start = now - 60
+
+    # Remove old entries, add new one, count
+    pipe = redis.pipeline()
+    pipe.zremrangebyscore(rpm_key, 0, window_start)
+    pipe.zadd(rpm_key, {str(now): now})
+    pipe.zcard(rpm_key)
+    pipe.expire(rpm_key, 120)
+    _, _, current_rpm, _ = await pipe.execute()
+
+    if current_rpm > limits.requests_per_minute:
+        # Retry once the oldest request still in the window ages out
+        oldest = await redis.zrange(rpm_key, 0, 0, withscores=True)
+        retry_after = max(1, int(oldest[0][1] + 60 - now)) if oldest else 60
+        return RateLimitResult(
+            allowed=False,
+            retry_after=retry_after,
+            reason=f"Rate limit: {limits.requests_per_minute} requests/minute"
+        )
+
+    # Check concurrent requests
+    concurrent_key = f"ratelimit:{user.id}:concurrent"
+    current_concurrent = await redis.incr(concurrent_key)
+    await redis.expire(concurrent_key, 300)  # 5 min TTL as safety
+
+    if current_concurrent > limits.max_concurrent:
+        await redis.decr(concurrent_key)
+        return RateLimitResult(
+            allowed=False,
+            retry_after=1,
+            reason=f"Max concurrent: {limits.max_concurrent} requests"
+        )
+
+    return RateLimitResult(allowed=True, concurrent_key=concurrent_key)
+
+async def release_concurrent(concurrent_key: str):
+    """Release concurrent slot after request completes."""
+    if concurrent_key:
+        await redis.decr(concurrent_key)
+```
+
+### Rate Limit Headers
+
+Return standard headers so clients can self-regulate:
+
+```python
+async def rate_limit_headers(user: User) -> dict:
+    limits = TIER_LIMITS[user.tier]
+    current = await get_current_usage(user.id)
+
+    return {
+        "X-RateLimit-Limit": str(limits.requests_per_minute),
+        "X-RateLimit-Remaining": str(max(0, limits.requests_per_minute - current.rpm)),
+        "X-RateLimit-Reset": str(int(time.time()) + 60),
+    }
+```
+
+## Layer 2: Semantic Cache
+
+Identical requests return cached responses. Reduces load and cost.
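+
+The cost side is easy to estimate. A back-of-envelope sketch (request volume
+and per-call cost here are illustrative assumptions, not BloxServer figures):
+
+```python
+# Rough cache savings estimate (illustrative numbers only).
+requests_per_day = 1_000_000
+hit_rate = 0.35            # mid-range of the 30-40% aggregate estimate below
+avg_cost_per_call = 0.002  # assumed blended provider cost, USD
+
+calls_avoided = requests_per_day * hit_rate
+daily_savings = calls_avoided * avg_cost_per_call
+print(f"{calls_avoided:,.0f} calls avoided, ~${daily_savings:,.0f}/day saved")
+# 350,000 calls avoided, ~$700/day saved
+```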
+ +### Cache Key Generation + +```python +import hashlib +import json + +def hash_request(request: LLMRequest) -> str: + """Generate deterministic cache key for request.""" + + # Include all parameters that affect output + cache_input = { + "model": request.model, + "messages": [ + {"role": m.role, "content": m.content} + for m in request.messages + ], + "temperature": request.temperature, + "max_tokens": request.max_tokens, + "tools": request.tools, # Tool definitions matter + # Exclude: user_id, timestamps, request_id + } + + serialized = json.dumps(cache_input, sort_keys=True) + return hashlib.sha256(serialized.encode()).hexdigest()[:32] +``` + +### Cache Logic + +```python +@dataclass +class CachedResponse: + response: LLMResponse + cached_at: float + hit_count: int + +async def check_semantic_cache(request: LLMRequest) -> LLMResponse | None: + """Check if we've seen this exact request before.""" + + cache_key = f"llmcache:{hash_request(request)}" + cached = await redis.get(cache_key) + + if cached: + data = json.loads(cached) + + # Update hit count for analytics + await redis.hincrby(f"llmcache:stats", "hits", 1) + + return LLMResponse( + content=data["content"], + model=data["model"], + usage=data["usage"], + cached=True, + ) + + await redis.hincrby(f"llmcache:stats", "misses", 1) + return None + +async def cache_response(request: LLMRequest, response: LLMResponse): + """Cache response with TTL based on determinism.""" + + # Don't cache errors or empty responses + if response.error or not response.content: + return + + cache_key = f"llmcache:{hash_request(request)}" + + # TTL based on temperature (determinism) + if request.temperature == 0: + ttl = 86400 # 24 hours for deterministic + elif request.temperature < 0.3: + ttl = 3600 # 1 hour + elif request.temperature < 0.7: + ttl = 300 # 5 minutes + else: + return # Don't cache high-temperature responses + + cache_data = { + "content": response.content, + "model": response.model, + "usage": response.usage, + "cached_at": time.time(), + } + + await redis.setex(cache_key, ttl, json.dumps(cache_data)) +``` + +### Expected Cache Performance + +| Use Case | Temperature | Expected Hit Rate | +|----------|-------------|-------------------| +| Tool calls (same inputs) | 0 | 70-90% | +| Structured extraction | 0-0.3 | 50-70% | +| Agent reasoning | 0.5-0.7 | 20-40% | +| Creative content | 0.8-1.0 | ~0% | + +**Aggregate impact:** 30-40% reduction in API calls for typical workloads. + +## Layer 3: Priority Queues + +Paid users get priority. Free users are served fairly but can be shed under load. 
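+
+A quick worked example of the composite-score ordering defined in the next
+section (timestamps are illustrative):
+
+```python
+# score = priority * 1e10 + unix timestamp; lower score pops first.
+PRIORITY_MULTIPLIER = 10_000_000_000
+
+free_score = 3 * PRIORITY_MULTIPLIER + 1_760_000_000  # Free, enqueued first
+pro_score = 2 * PRIORITY_MULTIPLIER + 1_760_000_060   # Pro, 60s later
+
+# The Pro request (2.176e10) still beats the earlier Free request
+# (3.176e10), while two requests in the same tier pop in arrival order.
+assert pro_score < free_score
+```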
+
+### Queue Structure
+
+```python
+# Redis sorted set with composite score
+# Score = (priority * 1e10) + unix timestamp
+# The multiplier must exceed any unix timestamp (~1.7e9), so that
+# priority always dominates arrival time and the timestamp can be
+# recovered with a modulo.
+# Lower score = higher priority + earlier arrival
+
+QUEUE_PRIORITIES = {
+    Tier.HIGH_FREQUENCY: 0,  # Highest priority (dedicated customers)
+    Tier.ENTERPRISE: 1,
+    Tier.PRO: 2,
+    "trial": 2,              # Trials get Pro priority (first impression)
+    Tier.FREE: 3,            # Lowest priority
+}
+
+PRIORITY_MULTIPLIER = 10_000_000_000  # 1e10
+
+@dataclass
+class QueuedRequest:
+    ticket_id: str
+    user_id: str
+    tier: str
+    request: LLMRequest
+    enqueued_at: float
+    use_user_key: bool = False
+
+async def enqueue_request(user: User, request: LLMRequest, use_user_key: bool) -> str:
+    """Add request to priority queue, return ticket ID."""
+
+    ticket_id = f"ticket:{uuid.uuid4().hex}"
+    priority = QUEUE_PRIORITIES.get(user.tier, 3)
+
+    # Composite score: priority (tens of billions) + timestamp (seconds)
+    score = priority * PRIORITY_MULTIPLIER + time.time()
+
+    queued = QueuedRequest(
+        ticket_id=ticket_id,
+        user_id=str(user.id),
+        tier=user.tier,
+        request=request,
+        enqueued_at=time.time(),
+        use_user_key=use_user_key,
+    )
+
+    await redis.zadd("llm:queue", {json.dumps(asdict(queued)): score})
+
+    # Set a result placeholder
+    await redis.setex(f"llm:result:{ticket_id}", 300, "pending")
+
+    return ticket_id
+```
+
+### Queue Workers
+
+```python
+async def queue_worker():
+    """Process requests from the queue."""
+
+    while True:
+        # Get highest priority item (lowest score)
+        items = await redis.zpopmin("llm:queue", count=1)
+
+        if not items:
+            await asyncio.sleep(0.1)  # Brief pause if queue empty
+            continue
+
+        data, score = items[0]
+        queued = QueuedRequest(**json.loads(data))
+
+        try:
+            # Select provider and execute
+            response = await execute_llm_request(queued)
+
+            # Store result
+            await redis.setex(
+                f"llm:result:{queued.ticket_id}",
+                300,
+                json.dumps({"status": "success", "response": asdict(response)})
+            )
+
+        except Exception as e:
+            await redis.setex(
+                f"llm:result:{queued.ticket_id}",
+                300,
+                json.dumps({"status": "error", "error": str(e)})
+            )
+
+async def wait_for_result(ticket_id: str, timeout: float = 120) -> LLMResponse:
+    """Wait for queued request to complete."""
+
+    deadline = time.time() + timeout
+
+    while time.time() < deadline:
+        result = await redis.get(f"llm:result:{ticket_id}")
+
+        if result and result != "pending":
+            data = json.loads(result)
+            if data["status"] == "success":
+                return LLMResponse(**data["response"])
+            else:
+                raise LLMError(data["error"])
+
+        await asyncio.sleep(0.1)
+
+    raise RequestTimeout("Request timed out")
+```
+
+### Queue Health Monitoring
+
+```python
+@dataclass
+class QueueHealth:
+    size: int
+    oldest_wait_seconds: float
+    by_tier: dict[str, int]
+    status: str  # healthy, degraded, critical
+
+async def get_queue_health() -> QueueHealth:
+    """Get queue metrics for monitoring and load shedding."""
+
+    queue_size = await redis.zcard("llm:queue")
+
+    # Get oldest item
+    oldest = await redis.zrange("llm:queue", 0, 0, withscores=True)
+    if oldest:
+        oldest_score = oldest[0][1]
+        oldest_time = oldest_score % PRIORITY_MULTIPLIER
+        wait_time = time.time() - oldest_time
+    else:
+        wait_time = 0
+
+    # Count by tier
+    all_items = await redis.zrange("llm:queue", 0, -1)
+    by_tier = {}
+    for item in all_items:
+        data = json.loads(item)
+        tier = data.get("tier", "unknown")
+        by_tier[tier] = by_tier.get(tier, 0) + 1
+
+    # Determine status
+    if queue_size < 500:
+        status = "healthy"
+    elif queue_size < 2000:
+        status = "degraded"
+    else:
+        status = "critical"
+
+    return QueueHealth(
+        size=queue_size,
+        
oldest_wait_seconds=wait_time, + by_tier=by_tier, + status=status, + ) +``` + +## Layer 4: Multi-Provider Pool with Circuit Breakers + +Never depend on a single provider. + +### Provider Configuration + +```python +@dataclass +class ProviderConfig: + name: str + base_url: str + api_key_env: str + models: list[str] + max_concurrent: int + priority: int # Lower = preferred + timeout: float = 60.0 + +PROVIDERS = { + "anthropic": ProviderConfig( + name="anthropic", + base_url="https://api.anthropic.com/v1", + api_key_env="ANTHROPIC_API_KEY", + models=["claude-sonnet-4-20250514", "claude-opus-4-20250514", "claude-haiku-3"], + max_concurrent=100, + priority=1, + ), + "openai": ProviderConfig( + name="openai", + base_url="https://api.openai.com/v1", + api_key_env="OPENAI_API_KEY", + models=["gpt-4o", "gpt-4o-mini", "o1", "o3-mini"], + max_concurrent=50, + priority=2, + ), + "xai": ProviderConfig( + name="xai", + base_url="https://api.x.ai/v1", + api_key_env="XAI_API_KEY", + models=["grok-3", "grok-3-mini"], + max_concurrent=50, + priority=1, + ), + "together": ProviderConfig( + name="together", + base_url="https://api.together.xyz/v1", + api_key_env="TOGETHER_API_KEY", + models=["llama-3-70b", "mixtral-8x7b"], + max_concurrent=100, + priority=3, # Fallback + ), +} +``` + +### Circuit Breaker State + +```python +@dataclass +class CircuitState: + provider: str + healthy: bool = True + failures: int = 0 + successes: int = 0 + last_failure: float = 0 + circuit_open_until: float = 0 + current_load: int = 0 + +# In-memory state (could be Redis for distributed) +CIRCUIT_STATES: dict[str, CircuitState] = { + name: CircuitState(provider=name) + for name in PROVIDERS +} + +CIRCUIT_CONFIG = { + "failure_threshold": 5, # Failures before opening + "success_threshold": 3, # Successes before closing + "open_duration": 30, # Seconds circuit stays open + "half_open_requests": 1, # Requests allowed in half-open state +} + +async def record_success(provider: str): + """Record successful request.""" + state = CIRCUIT_STATES[provider] + state.successes += 1 + state.failures = 0 + + if not state.healthy and state.successes >= CIRCUIT_CONFIG["success_threshold"]: + state.healthy = True + logger.info(f"Circuit closed for {provider}") + +async def record_failure(provider: str, error: Exception): + """Record failed request, potentially open circuit.""" + state = CIRCUIT_STATES[provider] + state.failures += 1 + state.successes = 0 + state.last_failure = time.time() + + if state.failures >= CIRCUIT_CONFIG["failure_threshold"]: + state.healthy = False + state.circuit_open_until = time.time() + CIRCUIT_CONFIG["open_duration"] + logger.error(f"Circuit opened for {provider}: {error}") + await alert_ops(f"LLM provider {provider} circuit opened") + +def is_provider_available(provider: str) -> bool: + """Check if provider can accept requests.""" + state = CIRCUIT_STATES[provider] + config = PROVIDERS[provider] + + # Circuit open? + if not state.healthy: + if time.time() < state.circuit_open_until: + return False + # Half-open: allow limited requests to probe + + # At capacity? 
+ if state.current_load >= config.max_concurrent: + return False + + return True +``` + +### Provider Selection + +```python +def get_providers_for_model(model: str) -> list[str]: + """Get providers that support this model.""" + return [ + name for name, config in PROVIDERS.items() + if model in config.models or any(model.startswith(m.split("-")[0]) for m in config.models) + ] + +async def select_provider(request: LLMRequest, user_key: str | None = None) -> tuple[str, str]: + """Select best available provider, return (provider_name, api_key).""" + + candidates = get_providers_for_model(request.model) + + if not candidates: + raise UnsupportedModel(f"No provider supports model: {request.model}") + + # Filter to available providers + available = [p for p in candidates if is_provider_available(p)] + + if not available: + raise NoProvidersAvailable( + "All providers for this model are currently unavailable. " + "Please try again in a few seconds." + ) + + # Sort by priority, then by current load + available.sort(key=lambda p: ( + PROVIDERS[p].priority, + CIRCUIT_STATES[p].current_load / PROVIDERS[p].max_concurrent + )) + + selected = available[0] + + # Determine API key + if user_key: + api_key = user_key + else: + api_key = os.environ[PROVIDERS[selected].api_key_env] + + return selected, api_key +``` + +## Layer 5: BYOK (Bring Your Own Key) + +Pro+ users can add their own API keys to bypass platform limits. + +### Database Schema + +```sql +CREATE TABLE user_api_keys ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID REFERENCES users(id) ON DELETE CASCADE, + provider VARCHAR(50) NOT NULL, + encrypted_key BYTEA NOT NULL, + key_hint VARCHAR(20), -- Last 4 chars for display: "...abc123" + is_valid BOOLEAN DEFAULT true, + last_used_at TIMESTAMPTZ, + last_error VARCHAR(255), + created_at TIMESTAMPTZ DEFAULT NOW(), + + UNIQUE(user_id, provider) +); + +CREATE INDEX idx_user_api_keys_user ON user_api_keys(user_id); +``` + +### Key Encryption + +```python +from cryptography.fernet import Fernet + +# Platform encryption key (from environment, rotated periodically) +ENCRYPTION_KEY = Fernet(os.environ["API_KEY_ENCRYPTION_KEY"]) + +def encrypt_api_key(key: str) -> bytes: + """Encrypt user's API key for storage.""" + return ENCRYPTION_KEY.encrypt(key.encode()) + +def decrypt_api_key(encrypted: bytes) -> str: + """Decrypt user's API key for use.""" + return ENCRYPTION_KEY.decrypt(encrypted).decode() + +async def store_user_api_key(user_id: str, provider: str, api_key: str): + """Store encrypted API key for user.""" + + # Validate key format + if not validate_key_format(provider, api_key): + raise InvalidAPIKey(f"Invalid {provider} API key format") + + # Test the key + if not await test_api_key(provider, api_key): + raise InvalidAPIKey(f"API key validation failed for {provider}") + + encrypted = encrypt_api_key(api_key) + key_hint = f"...{api_key[-6:]}" + + await db.execute(""" + INSERT INTO user_api_keys (user_id, provider, encrypted_key, key_hint) + VALUES ($1, $2, $3, $4) + ON CONFLICT (user_id, provider) + DO UPDATE SET encrypted_key = $3, key_hint = $4, is_valid = true, last_error = NULL + """, user_id, provider, encrypted, key_hint) + +async def get_user_api_key(user_id: str, provider: str) -> str | None: + """Get decrypted API key for user, if they have one.""" + + row = await db.fetchrow(""" + SELECT encrypted_key, is_valid + FROM user_api_keys + WHERE user_id = $1 AND provider = $2 + """, user_id, provider) + + if not row or not row["is_valid"]: + return None + + return 
decrypt_api_key(row["encrypted_key"])
+```
+
+### BYOK Request Flow
+
+```python
+async def execute_with_byok(user: User, request: LLMRequest) -> LLMResponse:
+    """Execute request, preferring user's own key if available."""
+
+    # Check for user's key
+    provider = get_provider_for_model(request.model)
+    user_key = await get_user_api_key(user.id, provider)
+
+    if user_key:
+        # Use user's key - bypass platform rate limits
+        try:
+            response = await call_provider_direct(request, user_key)
+
+            # Update last used
+            await db.execute("""
+                UPDATE user_api_keys
+                SET last_used_at = NOW(), last_error = NULL
+                WHERE user_id = $1 AND provider = $2
+            """, user.id, provider)
+
+            return response
+
+        except AuthenticationError:
+            # Key is invalid - mark it and fall back to platform
+            await db.execute("""
+                UPDATE user_api_keys
+                SET is_valid = false, last_error = 'Authentication failed'
+                WHERE user_id = $1 AND provider = $2
+            """, user.id, provider)
+
+            # Notify user
+            await send_notification(user, "api_key_invalid", {
+                "provider": provider
+            })
+
+            # Fall through to platform key
+
+    # Use platform key (with rate limiting)
+    return await execute_with_platform_key(user, request)
+```
+
+## Layer 6: Backpressure & Graceful Degradation
+
+When overwhelmed, fail gracefully and prioritize paid users.
+
+### Load Shedding
+
+```python
+import random
+
+async def should_shed_load(user: User, queue_health: QueueHealth) -> bool:
+    """Determine if this request should be rejected to protect the system."""
+
+    # High Frequency and Enterprise never shed
+    if user.tier in [Tier.HIGH_FREQUENCY, Tier.ENTERPRISE]:
+        return False
+
+    # Pro sheds only in critical, and more gently than free
+    if user.tier == Tier.PRO:
+        if queue_health.status != "critical":
+            return False
+        shed_probability = min(0.5, (queue_health.size - 2000) / 4000)
+        return random.random() < shed_probability
+
+    # Free tier sheds in degraded or critical
+    if user.tier == Tier.FREE and queue_health.status in ["degraded", "critical"]:
+        # Probabilistic shedding based on queue size
+        shed_probability = min(0.9, (queue_health.size - 500) / 2000)
+        return random.random() < shed_probability
+
+    return False
+```
+
+### Graceful Error Messages
+
+```python
+class ServiceDegraded(Exception):
+    """Raised when load shedding rejects a request."""
+
+    def __init__(self, tier: str, queue_health: QueueHealth):
+        if tier == Tier.FREE:
+            message = (
+                "We're experiencing high demand. Free tier requests are "
+                "temporarily paused. Upgrade to Pro for priority access, "
+                "or try again in a few minutes."
+            )
+            retry_after = 60
+        else:
+            message = (
+                "High demand is causing delays. Your request has been queued. "
+                "Expected wait time: ~{} seconds."
+            ).format(int(queue_health.oldest_wait_seconds * 1.5))
+            retry_after = 30
+
+        self.message = message
+        self.retry_after = retry_after
+        super().__init__(message)
+```
+
+### Timeout Handling
+
+```python
+async def execute_with_timeout(request: LLMRequest, provider: str, api_key: str) -> LLMResponse:
+    """Execute request with appropriate timeout."""
+
+    # Timeout based on expected response size
+    if request.max_tokens and request.max_tokens > 2000:
+        timeout = 120  # Long responses need more time
+    else:
+        timeout = 60
+
+    try:
+        async with asyncio.timeout(timeout):
+            return await call_provider(request, provider, api_key)
+    except asyncio.TimeoutError:
+        await record_failure(provider, TimeoutError("Request timed out"))
+        raise RequestTimeout(
+            f"Request timed out after {timeout}s. "
+            "Try reducing max_tokens or simplifying the prompt."
+        )
+```
+
+## Main Entry Point
+
+```python
+async def handle_llm_request(user: User, request: LLMRequest) -> LLMResponse:
+    """
+    Main entry point for all LLM requests.
+    Implements full defense-in-depth stack.
+    """
+
+    concurrent_key = None
+
+    try:
+        # Layer 1: Rate limiting
+        rate_result = await rate_limit_check(user, request)
+        if not rate_result.allowed:
+            raise RateLimitExceeded(
+                message=rate_result.reason,
+                retry_after=rate_result.retry_after
+            )
+        concurrent_key = rate_result.concurrent_key
+
+        # Layer 2: Semantic cache
+        cached = await check_semantic_cache(request)
+        if cached:
+            return cached
+
+        # Layer 6: Check queue health for load shedding
+        queue_health = await get_queue_health()
+        if await should_shed_load(user, queue_health):
+            raise ServiceDegraded(user.tier, queue_health)
+
+        # Layer 3: Enqueue with priority (Layers 4-5, provider selection
+        # and BYOK, run inside the queue workers)
+        ticket_id = await enqueue_request(user, request, rate_result.use_user_key)
+
+        # Wait for the worker's result
+        response = await wait_for_result(ticket_id, timeout=120)
+
+        # Feed the semantic cache (Layer 2) with the successful response
+        await cache_response(request, response)
+
+        return response
+
+    finally:
+        # Always release concurrent slot
+        if concurrent_key:
+            await release_concurrent(concurrent_key)
+```
+
+## Monitoring & Alerts
+
+### Key Metrics
+
+| Metric | Source | Warning | Critical |
+|--------|--------|---------|----------|
+| Queue depth | Redis ZCARD | > 500 | > 2000 |
+| P50 latency | Request timing | > 10s | > 30s |
+| P99 latency | Request timing | > 60s | > 120s |
+| Cache hit rate | Redis stats | < 25% | < 10% |
+| Provider error rate | Circuit state | > 5% | > 20% |
+| Circuit breaker open | Circuit state | Any | Multiple |
+| Free tier rejection rate | Load shedding | > 20% | > 50% |
+
+### Alerting
+
+```python
+# PagerDuty / Slack alerts
+ALERTS = {
+    "queue_critical": {
+        "condition": lambda h: h.size > 2000,
+        "severity": "critical",
+        "message": "LLM queue depth critical: {size} requests backed up"
+    },
+    "provider_down": {
+        "condition": lambda p: not p.healthy,
+        "severity": "warning",
+        "message": "Provider {name} circuit breaker open"
+    },
+    "all_providers_down": {
+        "condition": lambda: all(not s.healthy for s in CIRCUIT_STATES.values()),
+        "severity": "critical",
+        "message": "ALL LLM providers are down!"
+ }, +} +``` + +### Dashboard Queries + +```sql +-- Requests per minute by tier +SELECT + date_trunc('minute', created_at) as minute, + tier, + COUNT(*) as requests +FROM llm_requests +WHERE created_at > NOW() - INTERVAL '1 hour' +GROUP BY 1, 2 +ORDER BY 1 DESC; + +-- Error rate by provider +SELECT + provider, + COUNT(*) FILTER (WHERE status = 'error') * 100.0 / COUNT(*) as error_rate +FROM llm_requests +WHERE created_at > NOW() - INTERVAL '1 hour' +GROUP BY provider; + +-- BYOK adoption +SELECT + tier, + COUNT(*) FILTER (WHERE used_user_key) * 100.0 / COUNT(*) as byok_percentage +FROM llm_requests +WHERE created_at > NOW() - INTERVAL '24 hours' +GROUP BY tier; +``` + +## Viral Day Playbook + +What to do when that tweet hits: + +### Hour 0-1: Detection +- Alert: Queue depth > 500 +- Action: Monitor, no intervention needed + +### Hour 1-2: Escalation +- Alert: Queue depth > 1000, latency spiking +- Action: + - Verify all provider circuits are healthy + - Check cache hit rate (should be climbing) + - Prepare to enable aggressive load shedding + +### Hour 2-4: Peak +- Alert: Queue depth > 2000, free tier rejections > 30% +- Action: + - Enable aggressive load shedding for free tier + - Send "high demand" email to free users with upgrade CTA + - Monitor Pro/Enterprise latency (must stay < 30s) + - Tweet acknowledgment: "We're experiencing high demand due to [reason]. Pro users unaffected." + +### Hour 4-8: Stabilization +- Queue draining as cache warms and load shedding works +- Many users convert to Pro or add BYOK keys +- Circuits recovering as providers stabilize + +### Post-Mortem +- Review metrics: peak queue, rejection rate, conversion rate +- Adjust tier limits if needed +- Consider adding provider capacity for sustained growth + +--- + +## References + +- [Stripe-style rate limiting](https://stripe.com/docs/rate-limits) +- [Circuit breaker pattern](https://martinfowler.com/bliki/CircuitBreaker.html) +- [Token bucket algorithm](https://en.wikipedia.org/wiki/Token_bucket) +- [BloxServer Billing](bloxserver-billing.md) — Tier definitions and pricing diff --git a/docs/librarian-architecture.md b/docs/librarian-architecture.md new file mode 100644 index 0000000..6d2cf53 --- /dev/null +++ b/docs/librarian-architecture.md @@ -0,0 +1,513 @@ +# Librarian Architecture — RLM-Powered Document Intelligence + +**Status:** Design +**Date:** January 2026 + +## Overview + +The Librarian is an agent that ingests, indexes, and queries large document collections using the **Recursive Language Model (RLM)** pattern. It can handle codebases, documentation, and structured data at scales far beyond LLM context windows (10M+ tokens). + +Key insight from [MIT RLM research](https://arxiv.org/abs/...): Long contexts should be loaded as **variables in a REPL environment**, not fed directly to the neural network. The LLM writes code to examine, decompose, and recursively query chunks. 
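+
+Concretely, "loading as a variable" means the model only ever sees a handle
+plus whatever slices it explicitly asks for. A minimal sketch of such an
+environment (the `DocRef`/`peek` helpers here are illustrative assumptions,
+not a fixed API):
+
+```python
+from dataclasses import dataclass
+
+@dataclass
+class DocRef:
+    """Handle to an ingested document; the content never enters the prompt."""
+    doc_id: str
+    num_tokens: int
+
+async def peek(ref: DocRef, start: int, end: int) -> str:
+    """Fetch a small slice of the document for the model to inspect."""
+    return await load_chunk_range(ref.doc_id, start, end)  # storage lookup (not shown)
+
+# The orchestrating LLM is told only:
+#   "codebase is a DocRef with 12,400,000 tokens; use peek() and llm_query()"
+# and then writes code such as:
+#   header = await peek(codebase, 0, 2_000)
+#   answer = await llm_query(header, "Which build system does this use?")
+```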
+
+## Architecture
+
+```
+┌─────────────────────────────────────────────────────────────────┐
+│                      RLM-Powered Librarian                      │
+│                                                                 │
+│  ┌───────────────────────────────────────────────────────────┐  │
+│  │                    Ingestion Pipeline                     │  │
+│  │                                                           │  │
+│  │   Source → Detect Type → Select Chunker → Index → Store   │  │
+│  └───────────────────────────────────────────────────────────┘  │
+│                                                                 │
+│  ┌───────────────────────────────────────────────────────────┐  │
+│  │                Query Engine (RLM Pattern)                 │  │
+│  │                                                           │  │
+│  │  Query → Search → Filter → Recursive Sub-Query → Answer   │  │
+│  └───────────────────────────────────────────────────────────┘  │
+│                                                                 │
+│  ┌───────────────────────────────────────────────────────────┐  │
+│  │                       Storage Layer                       │  │
+│  │                                                           │  │
+│  │   eXist-db (XML) + Vector Embeddings + Dependency Graph   │  │
+│  └───────────────────────────────────────────────────────────┘  │
+└─────────────────────────────────────────────────────────────────┘
+```
+
+## The RLM Pattern
+
+Traditional LLM usage stuffs entire documents into the prompt. This fails at scale:
+- Context windows have hard limits (128K-1M tokens)
+- Performance degrades with context length ("context rot")
+- Cost scales linearly with input size
+
+**RLM approach:**
+
+1. **Load as Variable**: Documents become references, not inline content
+2. **Programmatic Access**: LLM writes code to peek into chunks
+3. **Recursive Sub-Queries**: `llm_query(chunk, question)` for focused analysis
+4. **Aggregation**: Combine sub-query results into final answer
+
+```python
+# RLM-style pseudocode
+async def handle_query(query: str, codebase: CodebaseRef) -> str:
+    # 1. Search index for relevant chunks (not full content)
+    hits = await search_index(codebase, query)
+
+    # 2. Filter if too many results
+    if len(hits) > 10:
+        hits = await llm_filter(hits, query)  # LLM picks most relevant
+
+    # 3. Recursive sub-queries on each chunk
+    findings = []
+    for hit in hits:
+        chunk = await load_chunk(hit)
+        result = await llm_query(chunk, f"Analyze this for: {query}")
+        findings.append(result)
+
+    # 4. Aggregate into final answer
+    return await llm_synthesize(findings, query)
+```
+
+## Hybrid Chunking Architecture
+
+Chunking is domain-specific. A C++ class should stay together; a legal clause shouldn't be split mid-sentence. We use a hybrid approach:
+
+### Built-in Chunkers (Fast Path)
+
+| Chunker | File Types | Strategy | Implementation |
+|---------|------------|----------|----------------|
+| **Code** | .c, .cpp, .py, .js, .rs, ... | AST-aware splitting | tree-sitter |
+| **Markdown/Docs** | .md, .rst, .txt | Heading hierarchy | Custom parser |
+| **Structured Data** | .json, .xml, .yaml | Schema-aware | lxml + json |
+| **Plain Text** | emails, logs, notes | Semantic paragraphs | Sentence boundaries |
+
+These cover ~90% of use cases with optimized, predictable behavior.
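+
+As a flavor of the fast path, here is a minimal sketch of the Markdown chunker's heading-hierarchy strategy. The `Chunk` shape and the splitting rules are simplified assumptions, not the production interface:
+
+```python
+import re
+from dataclasses import dataclass
+
+@dataclass
+class Chunk:
+    heading: str
+    level: int            # 1 = '#', 2 = '##', ...
+    content: str = ""
+
+HEADING = re.compile(r"^(#{1,6})\s+(.*)$")
+
+def chunk_markdown(text: str) -> list[Chunk]:
+    """One chunk per heading section; heading depth preserved as `level`."""
+    chunks = [Chunk(heading="(preamble)", level=0)]
+    for line in text.splitlines():
+        if m := HEADING.match(line):
+            chunks.append(Chunk(heading=m.group(2), level=len(m.group(1))))
+        else:
+            chunks[-1].content += line + "\n"
+    return [c for c in chunks if c.level or c.content.strip()]
+
+sections = chunk_markdown("# API\nIntro\n## Flows\nList flows...\n")
+print([(c.level, c.heading) for c in sections])   # [(1, 'API'), (2, 'Flows')]
+```
+
+A production chunker would additionally enforce `max_chunk_size` and carry parent/child links, as in the WASM interface below.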
+
+### WASM Factory (Fallback for Unknown Types)
+
+For novel formats, the AI generates a custom chunker:
+
+```
+User uploads proprietary format
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 1: Sample Analysis                                   │
+│                                                           │
+│ AI examines sample files:                                 │
+│ - Structure patterns                                      │
+│ - Record boundaries                                       │
+│ - Semantic units                                          │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 2: Generate Chunker (Rust → WASM)                    │
+│                                                           │
+│ AI writes Rust code implementing the chunker interface    │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 3: Compile & Validate                                │
+│                                                           │
+│ cargo build --target wasm32-wasi                          │
+│ Test on sample files                                      │
+│ AI reviews output quality                                 │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 4: Deploy                                            │
+│                                                           │
+│ Store in user's WASM modules                              │
+│ Optional: publish to marketplace                          │
+└───────────────────────────────────────────────────────────┘
+```
+
+### WASM Chunker Interface (WIT)
+
+```wit
+// chunker.wit
+interface chunker {
+    record chunk {
+        id: string,
+        content: string,
+        metadata: list<tuple<string, string>>,
+        parent-id: option<string>,
+        children: list<string>,
+    }
+
+    record chunker-config {
+        file-type: string,
+        max-chunk-size: u32,
+        preserve-context: bool,
+        custom-params: list<tuple<string, string>>,
+    }
+
+    // Analyze sample data, return chunking config
+    analyze: func(sample: string, file-type: string) -> chunker-config
+
+    // Chunk a file using the config
+    chunk-file: func(content: string, config: chunker-config) -> list<chunk>
+}
+```
+
+## Ingestion Pipeline
+
+### Step 1: Source Acquisition
+
+```python
+@dataclass
+class IngestionSource:
+    type: Literal["git", "upload", "url", "s3"]
+    location: str
+    filter: str | None = None  # e.g., "*.cpp", "docs/**/*.md"
+```
+
+Supported sources:
+- **Git repository**: Clone and track branches
+- **File upload**: Direct upload via UI
+- **URL**: Fetch remote documents
+- **S3/Cloud storage**: Enterprise integrations
+
+### Step 2: Type Detection
+
+```python
+def detect_type(file_path: str, content: bytes) -> FileType:
+    # 1. Check extension
+    ext = Path(file_path).suffix.lower()
+    if ext in CODE_EXTENSIONS:
+        return FileType.CODE
+
+    # 2. Check magic bytes
+    if content.startswith(b'%PDF'):
+        return FileType.PDF
+
+    # 3. Content analysis
+    if looks_like_markdown(content):
+        return FileType.MARKDOWN
+
+    return FileType.PLAIN_TEXT
+```
+
+### Step 3: Chunking
+
+```python
+def select_chunker(file_path: str, file_type: FileType, user_config: ChunkerConfig) -> Chunker:
+    # User override
+    if user_config.custom_wasm:
+        return WasmChunker(user_config.custom_wasm)
+
+    # Built-in chunkers
+    match file_type:
+        case FileType.CODE:
+            # Language detection needs the path/extension, not the coarse FileType
+            return TreeSitterChunker(language=detect_language(file_path))
+        case FileType.MARKDOWN:
+            return MarkdownChunker()
+        case FileType.JSON | FileType.XML | FileType.YAML:
+            return StructuredDataChunker()
+        case _:
+            return PlainTextChunker()
+```
+
+### Step 4: Indexing
+
+Each chunk is indexed in multiple ways:
+
+| Index Type | Purpose | Implementation |
+|------------|---------|----------------|
+| **Full-text** | Keyword search | eXist-db Lucene |
+| **Vector** | Semantic similarity | Embeddings (OpenAI/local) |
+| **Graph** | Relationships | Class hierarchy, imports, references |
+| **Metadata** | Filtering | File path, type, timestamp |
+
+### Step 5: Storage
+
+```xml
+<chunk>
+  <id>opencascade:BRepBuilderAPI_MakeEdge:constructor_1</id>
+  <source>
+    <collection>opencascade</collection>
+    <path>src/BRepBuilderAPI/BRepBuilderAPI_MakeEdge.cxx</path>
+  </source>
+  <type>function</type>
+  <content>
+    <name>BRepBuilderAPI_MakeEdge</name>
+    <visibility>public</visibility>
+    <parameters>const TopoDS_Vertex&amp;, const TopoDS_Vertex&amp;</parameters>
+  </content>
+  <embedding>[0.023, -0.041, 0.089, ...]</embedding>
+</chunk>
+```
+
+## Query Engine
+
+### Query Flow
+
+```
+User: "How does BRepBuilderAPI_MakeEdge handle degenerate curves?"
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 1: Search                                            │
+│                                                           │
+│ - Vector search: find semantically similar chunks         │
+│ - Keyword search: "BRepBuilderAPI_MakeEdge" + "degenerate"│
+│ - Graph traversal: class hierarchy, method calls          │
+│                                                           │
+│ Result: 47 potentially relevant chunks                    │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 2: Filter (LLM-assisted)                             │
+│                                                           │
+│ Too many chunks for direct analysis.                      │
+│ LLM reviews summaries, picks top 8 most relevant.         │
+│                                                           │
+│ Selected:                                                 │
+│ - BRepBuilderAPI_MakeEdge constructors (3 chunks)         │
+│ - Edge validation methods (2 chunks)                      │
+│ - Degenerate curve handling (2 chunks)                    │
+│ - Error reporting (1 chunk)                               │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 3: Recursive Sub-Queries                             │
+│                                                           │
+│ For each chunk, focused LLM query:                        │
+│                                                           │
+│ llm_query(chunk_1, "How does this handle degenerate...")  │
+│ llm_query(chunk_2, "What validation happens here...")     │
+│ llm_query(chunk_3, "What errors are raised for...")       │
+│ ...                                                       │
+│                                                           │
+│ 8 parallel sub-queries → 8 focused findings               │
+└───────────────────────────────────────────────────────────┘
+                              │
+                              ▼
+┌───────────────────────────────────────────────────────────┐
+│ Step 4: Synthesize                                        │
+│                                                           │
+│ LLM combines findings into coherent answer:               │
+│                                                           │
+│ "BRepBuilderAPI_MakeEdge handles degenerate curves by:    │
+│  1. Checking curve bounds in the constructor...           │
+│  2. Calling BRepCheck_Edge for validation...              │
+│  3. Setting myError to BRepBuilderAPI_CurveTooSmall..."   
│ +└───────────────────────────────────────────────────────────┘ +``` + +### Handler Implementation + +```python +@xmlify +@dataclass +class LibrarianQuery: + """Query the librarian for information.""" + collection: str # Which indexed collection + question: str # Natural language question + max_chunks: int = 10 # Limit for recursive queries + include_sources: bool = True + +@xmlify +@dataclass +class LibrarianResponse: + """Response from librarian with sources.""" + answer: str + sources: list[SourceReference] + confidence: float + +async def handle_librarian_query( + payload: LibrarianQuery, + metadata: HandlerMetadata +) -> HandlerResponse: + """RLM-style query handler.""" + + # 1. Search for relevant chunks + hits = await search_collection( + payload.collection, + payload.question, + limit=50 # Cast wide net + ) + + # 2. Filter if needed + if len(hits) > payload.max_chunks: + hits = await llm_filter_chunks( + hits, + payload.question, + limit=payload.max_chunks + ) + + # 3. Recursive sub-queries + findings = await asyncio.gather(*[ + llm_analyze_chunk(chunk, payload.question) + for chunk in hits + ]) + + # 4. Synthesize answer + answer = await llm_synthesize(findings, payload.question) + + # 5. Build response + sources = [ + SourceReference( + path=hit.source_path, + lines=(hit.start_line, hit.end_line), + relevance=hit.score + ) + for hit in hits + ] + + return HandlerResponse.respond( + payload=LibrarianResponse( + answer=answer, + sources=sources if payload.include_sources else [], + confidence=calculate_confidence(findings) + ) + ) +``` + +## Storage Layer + +### eXist-db (Primary Store) + +XML-native database for chunk storage and XQuery retrieval. + +**Why eXist-db:** +- Native XQuery for complex queries +- Full-text search with Lucene +- XML validation against schemas +- Transactional updates + +**Collections structure:** +``` +/db/librarian/ +├── collections/ +│ ├── {user_id}/ +│ │ ├── {collection_id}/ +│ │ │ ├── metadata.xml +│ │ │ ├── chunks/ +│ │ │ │ ├── chunk_001.xml +│ │ │ │ ├── chunk_002.xml +│ │ │ │ └── ... +│ │ │ └── index/ +│ │ │ └── embeddings.bin +``` + +### Vector Embeddings + +For semantic search, chunks are embedded using: +- OpenAI `text-embedding-3-small` (cloud) +- Sentence Transformers (local/self-hosted) + +Embeddings stored alongside chunks or in dedicated vector DB (Qdrant/Pinecone for scale). + +### Dependency Graph + +For code collections, track relationships: +- **Class hierarchy**: inheritance, interfaces +- **Imports**: file dependencies +- **Call graph**: function → function references + +Stored in eXist-db as XML or external graph DB for complex traversals. 
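+
+The vector-search step above ultimately reduces to a nearest-neighbor scan over those embeddings. A minimal sketch, assuming the embeddings are loaded into a NumPy matrix (in production this sits behind eXist-db or the vector DB):
+
+```python
+# Sketch: brute-force cosine similarity over an (n_chunks, dim) matrix.
+import numpy as np
+
+def top_k_chunks(query_vec: np.ndarray, embeddings: np.ndarray, k: int = 10) -> list[int]:
+    """Indices of the k chunks most cosine-similar to the query vector."""
+    q = query_vec / np.linalg.norm(query_vec)
+    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
+    return np.argsort(-(e @ q))[:k].tolist()
+
+rng = np.random.default_rng(0)
+embeddings = rng.normal(size=(1000, 1536))  # text-embedding-3-small is 1536-d
+query_vec = rng.normal(size=1536)
+print(top_k_chunks(query_vec, embeddings, k=5))
+```
+
+Brute force is fine up to roughly the "Medium" tier in the scaling table below; beyond that, an ANN index (the Qdrant/Pinecone route) takes over.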
+ +## Configuration + +### organism.yaml + +```yaml +listeners: + - name: librarian + handler: xml_pipeline.tools.librarian.handle_librarian_query + payload_class: xml_pipeline.tools.librarian.LibrarianQuery + description: Query indexed document collections + agent: true + peers: [] # Terminal handler + config: + exist_db: + url: "http://localhost:8080/exist" + user_env: EXIST_USER + password_env: EXIST_PASSWORD + embeddings: + provider: openai # or "local" + model: text-embedding-3-small + chunkers: + code: + max_chunk_size: 2000 + overlap: 200 + markdown: + split_on_headings: true + min_heading_level: 2 +``` + +### Ingestion API + +```python +# Ingest a git repository +await librarian.ingest( + source=GitSource( + url="https://github.com/Open-Cascade-SAS/OCCT", + branch="master", + filter="src/**/*.cxx" + ), + collection="opencascade", + chunker_config=CodeChunkerConfig( + language="cpp", + max_chunk_size=2000 + ) +) + +# Query the collection +response = await librarian.query( + collection="opencascade", + question="How does BRepBuilderAPI_MakeEdge handle curves?" +) +``` + +## Scaling Considerations + +| Scale | Storage | Search | Compute | +|-------|---------|--------|---------| +| Small (<10K chunks) | eXist-db local | In-DB Lucene | Single node | +| Medium (10K-1M) | eXist-db cluster | + Vector DB | Multi-worker | +| Large (1M+) | Sharded storage | Distributed search | GPU embeddings | + +## Security + +- **Collection isolation**: Users can only query their own collections +- **WASM sandbox**: Custom chunkers run in isolated WASM runtime +- **Rate limiting**: Prevent abuse of recursive queries +- **Audit logging**: Track all queries for compliance + +## Future Enhancements + +1. **Incremental updates**: Re-index only changed files +2. **Cross-collection queries**: Search across multiple codebases +3. **Collaborative collections**: Shared team libraries +4. **Query caching**: Cache common sub-queries +5. **Streaming ingestion**: Real-time updates from git webhooks + +--- + +## References + +- [Recursive Language Models (MIT)](docs/mit-paper.pdf) — Foundational research on RLM pattern +- [tree-sitter](https://tree-sitter.github.io/) — AST-aware code parsing +- [eXist-db](http://exist-db.org/) — XML-native database +- [BloxServer Architecture](bloxserver-architecture.md) — Platform overview