"""
xml-pipeline CLI entry point.

Usage:
    xml-pipeline run [config.yaml]      Run an organism
    xml-pipeline init [name]            Create new organism config
    xml-pipeline check [config.yaml]    Validate config without running
    xml-pipeline version                Show version info
"""

import argparse
import asyncio
import sys
from pathlib import Path


def cmd_run(args: argparse.Namespace) -> int:
    """Run an organism from config.

    Returns a process exit code (0 on success / clean shutdown, 1 on error).
    """
    # Imported lazily so `--help` and `version` stay fast and work even when
    # optional runtime dependencies are missing.
    from agentserver.config.loader import load_config
    from agentserver.message_bus import bootstrap

    config_path = Path(args.config)
    if not config_path.exists():
        print(f"Error: Config file not found: {config_path}", file=sys.stderr)
        return 1

    try:
        config = load_config(config_path)
        asyncio.run(bootstrap(config))
        return 0
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop an organism, not an error.
        print("\nShutdown requested.")
        return 0
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        return 1


def cmd_init(args: argparse.Namespace) -> int:
    """Initialize a new organism config file from the starter template."""
    from agentserver.config.template import create_organism_template

    name = args.name or "my-organism"
    output = Path(args.output or f"{name}.yaml")

    # Refuse to clobber an existing file unless the user opts in.
    if output.exists() and not args.force:
        print(f"Error: {output} already exists. Use --force to overwrite.", file=sys.stderr)
        return 1

    template = create_organism_template(name)
    output.write_text(template)
    print(f"Created {output}")
    print("\nNext steps:")
    print(f"  1. Edit {output} to configure your agents")
    print("  2. Set your LLM API key: export XAI_API_KEY=...")
    print(f"  3. Run: xml-pipeline run {output}")
    return 0


def cmd_check(args: argparse.Namespace) -> int:
    """Validate config without running the organism."""
    from agentserver.config.loader import load_config, ConfigError

    config_path = Path(args.config)
    if not config_path.exists():
        print(f"Error: Config file not found: {config_path}", file=sys.stderr)
        return 1

    try:
        config = load_config(config_path)
    except ConfigError as e:
        print(f"Config error: {e}", file=sys.stderr)
        return 1
    except Exception as e:
        # e.g. malformed YAML raises yaml.YAMLError, which is not a
        # ConfigError — report it cleanly instead of dumping a traceback.
        print(f"Error: {e}", file=sys.stderr)
        return 1

    print(f"Config valid: {config.organism.name}")
    print(f"  Listeners: {len(config.listeners)}")
    print(f"  LLM backends: {len(config.llm_backends)}")

    # Report optional features the config needs but that are not installed.
    from agentserver.config.features import check_features
    features = check_features(config)
    if features.missing:
        print("\nOptional features needed:")
        for feature, reason in features.missing.items():
            print(f"  pip install xml-pipeline[{feature}]  # {reason}")

    return 0


def cmd_version(args: argparse.Namespace) -> int:
    """Show version and which optional features are installed."""
    from agentserver import __version__
    from agentserver.config.features import get_available_features

    print(f"xml-pipeline {__version__}")
    print()
    print("Installed features:")
    for feature, available in get_available_features().items():
        status = "yes" if available else "no"
        print(f"  {feature}: {status}")
    return 0


def main() -> int:
    """CLI entry point: parse arguments and dispatch to the subcommand."""
    parser = argparse.ArgumentParser(
        prog="xml-pipeline",
        description="Tamper-proof nervous system for multi-agent organisms",
    )
    subparsers = parser.add_subparsers(dest="command", required=True)

    # run
    run_parser = subparsers.add_parser("run", help="Run an organism")
    run_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
    run_parser.set_defaults(func=cmd_run)

    # init
    init_parser = subparsers.add_parser("init", help="Create new organism config")
    init_parser.add_argument("name", nargs="?", help="Organism name")
    init_parser.add_argument("-o", "--output", help="Output file path")
    init_parser.add_argument("-f", "--force", action="store_true", help="Overwrite existing")
    init_parser.set_defaults(func=cmd_init)

    # check
    check_parser = subparsers.add_parser("check", help="Validate config")
    check_parser.add_argument("config", nargs="?", default="organism.yaml", help="Config file")
    check_parser.set_defaults(func=cmd_check)

    # version
    version_parser = subparsers.add_parser("version", help="Show version info")
    version_parser.set_defaults(func=cmd_version)

    args = parser.parse_args()
    return args.func(args)


if __name__ == "__main__":
    sys.exit(main())
+""" + +from dataclasses import dataclass, field +from importlib.util import find_spec +from typing import Callable + + +def _check_import(module: str) -> bool: + """Check if a module can be imported.""" + return find_spec(module) is not None + + +# Feature registry: feature_name -> (check_function, description) +FEATURES: dict[str, tuple[Callable[[], bool], str]] = { + "anthropic": (lambda: _check_import("anthropic"), "Anthropic Claude SDK"), + "openai": (lambda: _check_import("openai"), "OpenAI SDK"), + "redis": (lambda: _check_import("redis"), "Redis for distributed keyvalue"), + "search": (lambda: _check_import("duckduckgo_search"), "DuckDuckGo search"), + "auth": (lambda: _check_import("pyotp") and _check_import("argon2"), "TOTP auth"), + "server": (lambda: _check_import("websockets"), "WebSocket server"), +} + + +def get_available_features() -> dict[str, bool]: + """Return dict of feature_name -> is_available.""" + return {name: check() for name, (check, _) in FEATURES.items()} + + +def is_feature_available(feature: str) -> bool: + """Check if a specific feature is available.""" + if feature not in FEATURES: + return False + check, _ = FEATURES[feature] + return check() + + +def require_feature(feature: str) -> None: + """Raise ImportError if feature is not available.""" + if not is_feature_available(feature): + _, description = FEATURES.get(feature, (None, feature)) + raise ImportError( + f"Feature '{feature}' is not installed. " + f"Install with: pip install xml-pipeline[{feature}]" + ) + + +@dataclass +class FeatureCheck: + """Result of checking features against a config.""" + + available: dict[str, bool] = field(default_factory=dict) + missing: dict[str, str] = field(default_factory=dict) # feature -> reason needed + + +def check_features(config) -> FeatureCheck: + """ + Check which optional features are needed for a config. + + Returns FeatureCheck with available features and missing ones needed by config. 
+ """ + result = FeatureCheck(available=get_available_features()) + + # Check LLM backends + for backend in getattr(config, "llm_backends", []): + provider = getattr(backend, "provider", "").lower() + if provider == "anthropic" and not result.available.get("anthropic"): + result.missing["anthropic"] = f"LLM backend '{backend.name}' uses Anthropic" + if provider == "openai" and not result.available.get("openai"): + result.missing["openai"] = f"LLM backend '{backend.name}' uses OpenAI" + + # Check tools + for listener in getattr(config, "listeners", []): + # If listener uses keyvalue tool and redis is configured + # This would need more sophisticated detection based on tool config + pass + + # Check if auth is needed (multi-tenant mode) + if getattr(config, "auth", None): + if not result.available.get("auth"): + result.missing["auth"] = "Config has auth enabled" + + # Check if websocket server is needed + if getattr(config, "server", None): + if not result.available.get("server"): + result.missing["server"] = "Config has server enabled" + + return result + + +# Lazy import helpers for optional dependencies +def get_redis_client(): + """Get Redis client, or raise helpful error.""" + require_feature("redis") + import redis + return redis + + +def get_anthropic_client(): + """Get Anthropic client, or raise helpful error.""" + require_feature("anthropic") + import anthropic + return anthropic + + +def get_openai_client(): + """Get OpenAI client, or raise helpful error.""" + require_feature("openai") + import openai + return openai diff --git a/agentserver/config/loader.py b/agentserver/config/loader.py new file mode 100644 index 0000000..93870b9 --- /dev/null +++ b/agentserver/config/loader.py @@ -0,0 +1,216 @@ +""" +Configuration loader for xml-pipeline. + +Loads and validates organism.yaml configuration files. 
+""" + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import yaml + + +class ConfigError(Exception): + """Configuration validation error.""" + + pass + + +@dataclass +class OrganismMeta: + """Organism metadata.""" + + name: str + version: str = "0.1.0" + description: str = "" + + +@dataclass +class LLMBackendConfig: + """LLM backend configuration.""" + + name: str + provider: str # xai, anthropic, openai, ollama + model: str + api_key_env: str | None = None + base_url: str | None = None + priority: int = 0 + + +@dataclass +class ListenerConfig: + """Listener configuration.""" + + name: str + description: str = "" + + # Type flags + agent: bool = False + tool: bool = False + gateway: bool = False + + # Handler (for tools) + handler: str | None = None + payload_class: str | None = None + + # Agent config + prompt: str | None = None + model: str | None = None + + # Routing + peers: list[str] = field(default_factory=list) + + # Tool permissions (for agents) + allowed_tools: list[str] = field(default_factory=list) + blocked_tools: list[str] = field(default_factory=list) + + +@dataclass +class ServerConfig: + """WebSocket server configuration (optional).""" + + enabled: bool = False + host: str = "127.0.0.1" + port: int = 8765 + + +@dataclass +class AuthConfig: + """Authentication configuration (optional).""" + + enabled: bool = False + totp_secret_env: str = "ORGANISM_TOTP_SECRET" + + +@dataclass +class OrganismConfig: + """Complete organism configuration.""" + + organism: OrganismMeta + listeners: list[ListenerConfig] = field(default_factory=list) + llm_backends: list[LLMBackendConfig] = field(default_factory=list) + server: ServerConfig | None = None + auth: AuthConfig | None = None + + +def load_config(path: Path) -> OrganismConfig: + """Load and validate organism configuration from YAML file.""" + with open(path) as f: + raw = yaml.safe_load(f) + + if not isinstance(raw, dict): + raise ConfigError(f"Config must be a 
YAML mapping, got {type(raw)}") + + # Parse organism metadata + org_raw = raw.get("organism", {}) + if not org_raw.get("name"): + raise ConfigError("organism.name is required") + + organism = OrganismMeta( + name=org_raw["name"], + version=org_raw.get("version", "0.1.0"), + description=org_raw.get("description", ""), + ) + + # Parse LLM backends + llm_backends = [] + for backend_raw in raw.get("llm_backends", []): + if not backend_raw.get("name"): + raise ConfigError("llm_backends[].name is required") + if not backend_raw.get("provider"): + raise ConfigError(f"llm_backends[{backend_raw['name']}].provider is required") + + llm_backends.append( + LLMBackendConfig( + name=backend_raw["name"], + provider=backend_raw["provider"], + model=backend_raw.get("model", ""), + api_key_env=backend_raw.get("api_key_env"), + base_url=backend_raw.get("base_url"), + priority=backend_raw.get("priority", 0), + ) + ) + + # Parse listeners + listeners = [] + for listener_raw in raw.get("listeners", []): + if not listener_raw.get("name"): + raise ConfigError("listeners[].name is required") + + listeners.append( + ListenerConfig( + name=listener_raw["name"], + description=listener_raw.get("description", ""), + agent=listener_raw.get("agent", False), + tool=listener_raw.get("tool", False), + gateway=listener_raw.get("gateway", False), + handler=listener_raw.get("handler"), + payload_class=listener_raw.get("payload_class"), + prompt=listener_raw.get("prompt"), + model=listener_raw.get("model"), + peers=listener_raw.get("peers", []), + allowed_tools=listener_raw.get("allowed_tools", []), + blocked_tools=listener_raw.get("blocked_tools", []), + ) + ) + + # Parse optional server config + server = None + if "server" in raw: + server_raw = raw["server"] + server = ServerConfig( + enabled=server_raw.get("enabled", True), + host=server_raw.get("host", "127.0.0.1"), + port=server_raw.get("port", 8765), + ) + + # Parse optional auth config + auth = None + if "auth" in raw: + auth_raw = raw["auth"] + 
auth = AuthConfig( + enabled=auth_raw.get("enabled", True), + totp_secret_env=auth_raw.get("totp_secret_env", "ORGANISM_TOTP_SECRET"), + ) + + return OrganismConfig( + organism=organism, + listeners=listeners, + llm_backends=llm_backends, + server=server, + auth=auth, + ) + + +def validate_config(config: OrganismConfig) -> list[str]: + """ + Validate config for common issues. + + Returns list of warning messages (empty if valid). + """ + warnings = [] + + # Check for at least one listener + if not config.listeners: + warnings.append("No listeners defined") + + # Check for LLM backend if agents exist + agents = [l for l in config.listeners if l.agent] + if agents and not config.llm_backends: + warnings.append( + f"Config has {len(agents)} agent(s) but no llm_backends defined" + ) + + # Check peer references + listener_names = {l.name for l in config.listeners} + for listener in config.listeners: + for peer in listener.peers: + # Peer can be "listener_name" or "listener_name.payload_type" + peer_name = peer.split(".")[0] + if peer_name not in listener_names: + warnings.append( + f"Listener '{listener.name}' references unknown peer '{peer_name}'" + ) + + return warnings diff --git a/agentserver/config/template.py b/agentserver/config/template.py new file mode 100644 index 0000000..336ab68 --- /dev/null +++ b/agentserver/config/template.py @@ -0,0 +1,112 @@ +""" +Configuration templates for xml-pipeline. + +Generates starter organism.yaml files. +""" + + +def create_organism_template(name: str = "my-organism") -> str: + """Create a starter organism.yaml configuration.""" + return f'''# {name} - xml-pipeline organism configuration +# Documentation: https://github.com/yourorg/xml-pipeline + +organism: + name: {name} + version: "0.1.0" + description: "A multi-agent organism" + +# ============================================================================= +# LLM BACKENDS +# Configure which LLM providers to use. Agents will use these for inference. 
"""
Configuration templates for xml-pipeline.

Generates starter organism.yaml files.
"""


def create_organism_template(name: str = "my-organism") -> str:
    """Create a starter organism.yaml configuration.

    The organism name is emitted as a quoted YAML scalar so that names
    containing YAML-special characters (':', '#') or boolean-like words
    ('on', 'yes') survive a round-trip through yaml.safe_load.
    Names containing double quotes are not supported.
    """
    return f'''# {name} - xml-pipeline organism configuration
# Documentation: https://github.com/yourorg/xml-pipeline

organism:
  name: "{name}"
  version: "0.1.0"
  description: "A multi-agent organism"

# =============================================================================
# LLM BACKENDS
# Configure which LLM providers to use. Agents will use these for inference.
# API keys are read from environment variables.
# =============================================================================
llm_backends:
  - name: primary
    provider: xai  # xai, anthropic, openai, ollama
    model: grok-2
    api_key_env: XAI_API_KEY
    priority: 0  # Lower = preferred

  # Uncomment to add fallback backends:
  # - name: fallback
  #   provider: anthropic
  #   model: claude-3-sonnet-20240229
  #   api_key_env: ANTHROPIC_API_KEY
  #   priority: 1

# =============================================================================
# LISTENERS
# Define agents, tools, and gateways that make up your organism.
# =============================================================================
listeners:
  # ---------------------------------------------------------------------------
  # Example agent - an LLM-powered assistant
  # ---------------------------------------------------------------------------
  - name: assistant
    agent: true
    description: "A helpful assistant that can use tools"
    prompt: |
      You are a helpful assistant. You can use tools to help users.
      Always be concise and accurate.
    model: grok-2  # Override default model (optional)
    peers:
      - calculator  # Can call calculator tool
      - fetcher     # Can call fetch tool
    allowed_tools:
      - calculate
      - fetch

  # ---------------------------------------------------------------------------
  # Example tool - calculator
  # ---------------------------------------------------------------------------
  - name: calculator
    tool: true
    description: "Evaluates mathematical expressions"
    handler: agentserver.tools.calculate:calculate_handler
    payload_class: agentserver.tools.calculate:CalculateRequest

  # ---------------------------------------------------------------------------
  # Example tool - HTTP fetcher
  # ---------------------------------------------------------------------------
  - name: fetcher
    tool: true
    description: "Fetches content from URLs"
    handler: agentserver.tools.fetch:fetch_handler
    payload_class: agentserver.tools.fetch:FetchRequest

# =============================================================================
# OPTIONAL: WebSocket server for remote connections
# Uncomment to enable. Requires: pip install xml-pipeline[server]
# =============================================================================
# server:
#   enabled: true
#   host: "127.0.0.1"
#   port: 8765

# =============================================================================
# OPTIONAL: Authentication for privileged operations
# Uncomment to enable. Requires: pip install xml-pipeline[auth]
# =============================================================================
# auth:
#   enabled: true
#   totp_secret_env: ORGANISM_TOTP_SECRET
'''


def create_minimal_template(name: str = "simple") -> str:
    """Create a minimal organism.yaml with just one agent.

    The name is quoted for the same reason as in create_organism_template.
    """
    return f'''organism:
  name: "{name}"

llm_backends:
  - name: default
    provider: xai
    model: grok-2
    api_key_env: XAI_API_KEY

listeners:
  - name: assistant
    agent: true
    description: "A simple assistant"
    prompt: "You are a helpful assistant."
'''
+""" + +from agentserver.llm.router import get_router, configure_router, LLMRouter +from agentserver.llm.backend import ( + LLMRequest, + LLMResponse, + Backend, + BackendError, + RateLimitError, + ProviderError, +) + +__all__ = [ + "llm_pool", + "LLMRequest", + "LLMResponse", + "Backend", + "BackendError", + "RateLimitError", + "ProviderError", + "configure_router", +] -@dataclass -class LLMRequest: - """Standardized request shape passed to all providers.""" - messages: List[Dict[str, str]] - model: Optional[str] = None # provider may ignore if fixed in config - temperature: float = 0.7 - max_tokens: Optional[int] = None - tools: Optional[List[Dict]] = None - stream: bool = False - # extra provider-specific kwargs - extra: Dict[str, Any] = None - - -@dataclass -class LLMResponse: - """Unified response shape.""" - content: str - usage: Dict[str, int] # prompt_tokens, completion_tokens, total_tokens - finish_reason: str - raw: Any = None # provider-specific raw response for debugging - - -class LLMConnection(ABC): - """Abstract base class for all LLM providers.""" - - def __init__(self, name: str, config: dict): - self.name = name - self.config = config - self.rate_limit_tpm: Optional[int] = config.get("rate-limit", {}).get("tokens-per-minute") - self.max_concurrent: Optional[int] = config.get("max-concurrent-requests") - self._semaphore = asyncio.Semaphore(self.max_concurrent or 20) - self._token_bucket = None # optional token bucket impl later - - @abstractmethod - async def chat_completion(self, request: LLMRequest) -> LLMResponse: - """Non-streaming completion.""" - pass - - @abstractmethod - async def stream_completion(self, request: LLMRequest): - """Async generator yielding partial content strings.""" - pass - - async def __aenter__(self): - await self._semaphore.acquire() - return self - - async def __aexit__(self, exc_type, exc, tb): - self._semaphore.release() - - -class LLMConnectionPool: +class LLMPool: """ - Global, owner-controlled pool of LLM 
connections. - Populated at boot or via signed privileged-command. + Wrapper around the LLM router that provides a simpler interface for listeners. + + Usage: + from agentserver.listeners.llm_connection import llm_pool + + response = await llm_pool.complete( + model="grok-2", + messages=[{"role": "user", "content": "Hello"}], + temperature=0.7, + ) """ def __init__(self): - self._pools: Dict[str, LLMConnection] = {} - self._lock = asyncio.Lock() + self._router: LLMRouter | None = None - async def register(self, name: str, config: dict) -> None: + @property + def router(self) -> LLMRouter: + """Get or create the router instance.""" + if self._router is None: + self._router = get_router() + return self._router + + def configure(self, config: dict) -> None: + """Configure the underlying router.""" + self._router = configure_router(config) + + async def complete( + self, + model: str, + messages: list[dict[str, str]], + *, + temperature: float = 0.7, + max_tokens: int | None = None, + tools: list[dict] | None = None, + agent_id: str | None = None, + ) -> str: """ - Add or replace a pool entry. - Called only from boot config or validated privileged-command handler. + Execute a completion and return just the content string. + + This is the simplified interface for listeners - returns just the + response text, not the full LLMResponse object. 
""" - async with self._lock: - provider_type = config.get("provider") - if provider_type == "xai": - connection = XAIConnection(name, config) - elif provider_type == "anthropic": - connection = AnthropicConnection(name, config) - elif provider_type == "ollama" or provider_type == "local": - connection = OllamaConnection(name, config) - else: - raise ValueError(f"Unknown LLM provider: {provider_type}") + response = await self.router.complete( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + tools=tools, + agent_id=agent_id, + ) + return response.content - old = self._pools.get(name) - if old: - logger.info(f"Replacing LLM pool '{name}'") - else: - logger.info(f"Adding LLM pool '{name}'") + async def complete_full( + self, + model: str, + messages: list[dict[str, str]], + **kwargs, + ) -> LLMResponse: + """ + Execute a completion and return the full LLMResponse. - self._pools[name] = connection + Use this when you need access to usage stats, finish_reason, etc. 
+ """ + return await self.router.complete(model=model, messages=messages, **kwargs) - async def remove(self, name: str) -> None: - async with self._lock: - if name in self._pools: - del self._pools[name] - logger.info(f"Removed LLM pool '{name}'") + def get_usage(self, agent_id: str): + """Get usage stats for an agent.""" + return self.router.get_agent_usage(agent_id) - def get(self, name: str) -> LLMConnection: - """Synchronous get — safe because pools don't change mid-request.""" - try: - return self._pools[name] - except KeyError: - raise KeyError(f"LLM pool '{name}' not configured") from None - - def list_names(self) -> List[str]: - return list(self._pools.keys()) + async def close(self): + """Clean up resources.""" + if self._router: + await self._router.close() -# Example concrete providers (stubs — flesh out with real HTTP later) - -class XAIConnection(LLMConnection): - async def chat_completion(self, request: LLMRequest) -> LLMResponse: - # TODO: real async httpx to https://api.x.ai/v1/chat/completions - raise NotImplementedError - - async def stream_completion(self, request: LLMRequest): - # yield partial deltas - yield "streaming not yet implemented" - - -class AnthropicConnection(LLMConnection): - async def chat_completion(self, request: LLMRequest) -> LLMResponse: - raise NotImplementedError - - async def stream_completion(self, request: LLMRequest): - raise NotImplementedError - - -class OllamaConnection(LLMConnection): - async def chat_completion(self, request: LLMRequest) -> LLMResponse: - raise NotImplementedError - - async def stream_completion(self, request: LLMRequest): - raise NotImplementedError \ No newline at end of file +# Global instance +llm_pool = LLMPool() diff --git a/agentserver/listeners/wasm_listener.py b/agentserver/listeners/wasm_listener.py new file mode 100644 index 0000000..801b1cc --- /dev/null +++ b/agentserver/listeners/wasm_listener.py @@ -0,0 +1,161 @@ +""" +WASM Listener support (STUB). 
+ +Enables custom listeners implemented in WebAssembly/AssemblyScript. +See docs/wasm-listeners.md for specification. + +Status: NOT IMPLEMENTED - interface documented, awaiting implementation. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +logger = logging.getLogger(__name__) + + +class WasmNotImplementedError(NotImplementedError): + """WASM listener support is not yet implemented.""" + + def __init__(self): + super().__init__( + "WASM listener support is not yet implemented. " + "See docs/wasm-listeners.md for the planned interface. " + "For now, implement listeners in Python." + ) + + +@dataclass +class WasmListenerConfig: + """Configuration for a WASM listener.""" + + name: str + wasm_path: Path + wit_path: Path + memory_limit_mb: int = 64 + timeout_seconds: float = 5.0 + keep_hot: bool = True # Keep instance loaded between calls + + +@dataclass +class WasmInstance: + """A loaded WASM module instance (STUB).""" + + config: WasmListenerConfig + # Future: wasmtime.Instance or wasmer.Instance + _module: Any = field(default=None, repr=False) + _instance: Any = field(default=None, repr=False) + + def call(self, handler: str, input_json: str) -> str: + """Call a handler with JSON input, return JSON output.""" + raise WasmNotImplementedError() + + def close(self) -> None: + """Release WASM instance resources.""" + pass + + +class WasmListenerRegistry: + """ + Registry for WASM listeners (STUB). 
+ + Usage: + from agentserver.listeners.wasm_listener import wasm_registry + + wasm_registry.register( + name="calculator", + wasm_path=Path("/uploads/calculator.wasm"), + wit_path=Path("/uploads/calculator.wit"), + ) + """ + + def __init__(self): + self._listeners: dict[str, WasmListenerConfig] = {} + self._instances: dict[str, WasmInstance] = {} # thread_id -> instance + + def register( + self, + name: str, + wasm_path: Path, + wit_path: Path, + **config, + ) -> None: + """ + Register a WASM listener. + + Args: + name: Listener name (must be unique) + wasm_path: Path to .wasm file + wit_path: Path to .wit interface file + **config: Additional config (memory_limit_mb, timeout_seconds, etc.) + """ + raise WasmNotImplementedError() + + def unregister(self, name: str) -> None: + """Remove a WASM listener.""" + raise WasmNotImplementedError() + + def get_instance(self, name: str, thread_id: str) -> WasmInstance: + """Get or create a WASM instance for a thread.""" + raise WasmNotImplementedError() + + def prune_thread(self, thread_id: str) -> None: + """Release WASM instances for a pruned thread.""" + # This will be called by thread registry on cleanup + instances_to_remove = [ + key for key in self._instances + if key.endswith(f":{thread_id}") + ] + for key in instances_to_remove: + instance = self._instances.pop(key, None) + if instance: + instance.close() + + def list_listeners(self) -> list[str]: + """List registered WASM listener names.""" + return list(self._listeners.keys()) + + +# Global registry instance +wasm_registry = WasmListenerRegistry() + + +def register_wasm_listener( + name: str, + wasm_path: str | Path, + wit_path: str | Path, + **config, +) -> None: + """ + Convenience function to register a WASM listener. + + See docs/wasm-listeners.md for full specification. 
+ + Args: + name: Unique listener name + wasm_path: Path to .wasm module + wit_path: Path to .wit interface definition + **config: Optional config overrides + + Raises: + WasmNotImplementedError: WASM support not yet implemented + """ + wasm_registry.register( + name=name, + wasm_path=Path(wasm_path), + wit_path=Path(wit_path), + **config, + ) + + +async def create_wasm_handler(config: WasmListenerConfig): + """ + Create an async handler function for a WASM listener. + + Returns a handler compatible with the standard listener interface: + async def handler(payload: DataClass, metadata: HandlerMetadata) -> HandlerResponse + """ + raise WasmNotImplementedError() diff --git a/agentserver/tools/__init__.py b/agentserver/tools/__init__.py index d1d1e37..feece43 100644 --- a/agentserver/tools/__init__.py +++ b/agentserver/tools/__init__.py @@ -8,11 +8,12 @@ to interact with the outside world. from .base import Tool, ToolResult, tool, get_tool_registry from .calculate import calculate from .fetch import fetch_url -from .files import read_file, write_file, list_dir -from .shell import run_command -from .search import web_search +from .files import read_file, write_file, list_dir, delete_file, configure_allowed_paths +from .shell import run_command, configure_allowed_commands, configure_blocked_commands +from .search import web_search, configure_search from .keyvalue import key_value_get, key_value_set, key_value_delete -from .librarian import librarian_store, librarian_get, librarian_query, librarian_search +from .librarian import librarian_store, librarian_get, librarian_query, librarian_search, configure_librarian +from .convert import xml_to_json, json_to_xml, xml_extract __all__ = [ # Base @@ -20,12 +21,19 @@ __all__ = [ "ToolResult", "tool", "get_tool_registry", + # Configuration + "configure_allowed_paths", + "configure_allowed_commands", + "configure_blocked_commands", + "configure_search", + "configure_librarian", # Tools "calculate", "fetch_url", "read_file", 
"write_file", "list_dir", + "delete_file", "run_command", "web_search", "key_value_get", @@ -35,4 +43,8 @@ __all__ = [ "librarian_get", "librarian_query", "librarian_search", + # Conversion + "xml_to_json", + "json_to_xml", + "xml_extract", ] diff --git a/agentserver/tools/base.py b/agentserver/tools/base.py index 0b21829..a048069 100644 --- a/agentserver/tools/base.py +++ b/agentserver/tools/base.py @@ -84,7 +84,14 @@ def tool(func: Callable) -> Callable: for param_name, param in sig.parameters.items(): param_info = {"name": param_name} if param.annotation != inspect.Parameter.empty: - param_info["type"] = param.annotation.__name__ + ann = param.annotation + # Handle both string annotations (from __future__ import annotations) and type objects + if isinstance(ann, str): + param_info["type"] = ann + elif hasattr(ann, "__name__"): + param_info["type"] = ann.__name__ + else: + param_info["type"] = str(ann) if param.default != inspect.Parameter.empty: param_info["default"] = param.default parameters[param_name] = param_info diff --git a/agentserver/tools/calculate.py b/agentserver/tools/calculate.py index ddc98b4..a37ec03 100644 --- a/agentserver/tools/calculate.py +++ b/agentserver/tools/calculate.py @@ -1,34 +1,157 @@ """ -Calculate tool - evaluate mathematical expressions. +Calculate tool - evaluate mathematical expressions safely. -Uses simpleeval for safe expression evaluation with Python syntax. +Uses a restricted AST evaluator for safe expression evaluation. +No external dependencies required. 
""" +from __future__ import annotations + +import ast +import math +import operator +from typing import Any, Union + from .base import tool, ToolResult -# TODO: pip install simpleeval -# from simpleeval import simple_eval -# import math +# Allowed operations +OPERATORS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.FloorDiv: operator.floordiv, + ast.Mod: operator.mod, + ast.Pow: operator.pow, + ast.USub: operator.neg, + ast.UAdd: operator.pos, +} + +COMPARISONS = { + ast.Eq: operator.eq, + ast.NotEq: operator.ne, + ast.Lt: operator.lt, + ast.LtE: operator.le, + ast.Gt: operator.gt, + ast.GtE: operator.ge, +} + +# Allowed functions MATH_FUNCTIONS = { - # "abs": abs, - # "round": round, - # "min": min, - # "max": max, - # "sqrt": math.sqrt, - # "sin": math.sin, - # "cos": math.cos, - # "tan": math.tan, - # "log": math.log, - # "log10": math.log10, + "abs": abs, + "round": round, + "min": min, + "max": max, + "sqrt": math.sqrt, + "sin": math.sin, + "cos": math.cos, + "tan": math.tan, + "asin": math.asin, + "acos": math.acos, + "atan": math.atan, + "log": math.log, + "log10": math.log10, + "log2": math.log2, + "exp": math.exp, + "floor": math.floor, + "ceil": math.ceil, + "pow": pow, } +# Allowed constants MATH_CONSTANTS = { - # "pi": math.pi, - # "e": math.e, + "pi": math.pi, + "e": math.e, + "tau": math.tau, + "inf": math.inf, } +class SafeEvaluator(ast.NodeVisitor): + """Safely evaluate mathematical expressions using AST.""" + + def visit(self, node: ast.AST) -> Any: + """Visit a node.""" + method = f"visit_{node.__class__.__name__}" + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node: ast.AST) -> None: + """Reject unknown node types.""" + raise ValueError(f"Unsupported operation: {node.__class__.__name__}") + + def visit_Expression(self, node: ast.Expression) -> Any: + return self.visit(node.body) + + def visit_Constant(self, node: 
ast.Constant) -> Union[int, float]: + if isinstance(node.value, (int, float)): + return node.value + raise ValueError(f"Unsupported constant type: {type(node.value)}") + + def visit_Num(self, node: ast.Num) -> Union[int, float]: + # Python 3.7 compatibility + return node.n + + def visit_Name(self, node: ast.Name) -> Any: + if node.id in MATH_CONSTANTS: + return MATH_CONSTANTS[node.id] + raise ValueError(f"Unknown variable: {node.id}") + + def visit_BinOp(self, node: ast.BinOp) -> Any: + op_type = type(node.op) + if op_type not in OPERATORS: + raise ValueError(f"Unsupported operator: {op_type.__name__}") + left = self.visit(node.left) + right = self.visit(node.right) + return OPERATORS[op_type](left, right) + + def visit_UnaryOp(self, node: ast.UnaryOp) -> Any: + op_type = type(node.op) + if op_type not in OPERATORS: + raise ValueError(f"Unsupported operator: {op_type.__name__}") + operand = self.visit(node.operand) + return OPERATORS[op_type](operand) + + def visit_Compare(self, node: ast.Compare) -> bool: + left = self.visit(node.left) + for op, comparator in zip(node.ops, node.comparators): + op_type = type(op) + if op_type not in COMPARISONS: + raise ValueError(f"Unsupported comparison: {op_type.__name__}") + right = self.visit(comparator) + if not COMPARISONS[op_type](left, right): + return False + left = right + return True + + def visit_Call(self, node: ast.Call) -> Any: + if not isinstance(node.func, ast.Name): + raise ValueError("Only named function calls are allowed") + func_name = node.func.id + if func_name not in MATH_FUNCTIONS: + raise ValueError(f"Unknown function: {func_name}") + args = [self.visit(arg) for arg in node.args] + return MATH_FUNCTIONS[func_name](*args) + + def visit_IfExp(self, node: ast.IfExp) -> Any: + # Support ternary: a if condition else b + test = self.visit(node.test) + if test: + return self.visit(node.body) + return self.visit(node.orelse) + + +def safe_eval(expression: str) -> Any: + """Safely evaluate a mathematical 
expression.""" + try: + tree = ast.parse(expression, mode="eval") + except SyntaxError as e: + raise ValueError(f"Invalid syntax: {e}") + evaluator = SafeEvaluator() + return evaluator.visit(tree) + + @tool async def calculate(expression: str) -> ToolResult: """ @@ -37,24 +160,19 @@ async def calculate(expression: str) -> ToolResult: Supported: - Basic ops: + - * / // % ** - Comparisons: < > <= >= == != - - Functions: abs, round, min, max, sqrt, sin, cos, tan, log, log10 - - Constants: pi, e + - Functions: abs, round, min, max, sqrt, sin, cos, tan, log, log10, exp, floor, ceil + - Constants: pi, e, tau, inf - Parentheses for grouping + - Ternary expressions: a if condition else b Examples: - "2 + 2" → 4 - "(10 + 5) * 3" → 45 - "sqrt(16) + pi" → 7.141592... + - "max(1, 2, 3)" → 3 """ - # TODO: Implement with simpleeval - # try: - # result = simple_eval( - # expression, - # functions=MATH_FUNCTIONS, - # names=MATH_CONSTANTS, - # ) - # return ToolResult(success=True, data=result) - # except Exception as e: - # return ToolResult(success=False, error=str(e)) - - return ToolResult(success=False, error="Not implemented - install simpleeval") + try: + result = safe_eval(expression) + return ToolResult(success=True, data=result) + except Exception as e: + return ToolResult(success=False, error=str(e)) diff --git a/agentserver/tools/convert.py b/agentserver/tools/convert.py new file mode 100644 index 0000000..2d921fe --- /dev/null +++ b/agentserver/tools/convert.py @@ -0,0 +1,213 @@ +""" +XML/JSON conversion tools. + +Enables agents to interoperate with JSON-based APIs and tools (n8n, webhooks, REST APIs). 
+""" + +from __future__ import annotations + +import json +import re +import xml.etree.ElementTree as ET +from typing import Any + +from .base import tool, ToolResult + + +def _xml_to_dict(element: ET.Element) -> dict | str | list: + """Recursively convert XML element to dict.""" + # If element has no children, return text content + if len(element) == 0: + text = (element.text or "").strip() + # Try to parse as number/bool + if text.lower() == "true": + return True + if text.lower() == "false": + return False + if text == "": + return None + try: + if "." in text: + return float(text) + return int(text) + except ValueError: + return text + + result = {} + + # Add attributes with @ prefix + for key, value in element.attrib.items(): + result[f"@{key}"] = value + + # Process children + for child in element: + child_data = _xml_to_dict(child) + tag = child.tag + + # Handle multiple children with same tag -> array + if tag in result: + if not isinstance(result[tag], list): + result[tag] = [result[tag]] + result[tag].append(child_data) + else: + result[tag] = child_data + + return result + + +def _dict_to_xml(data: Any, tag: str = "item", parent: ET.Element | None = None) -> ET.Element: + """Recursively convert dict to XML element.""" + if parent is None: + elem = ET.Element(tag) + else: + elem = ET.SubElement(parent, tag) + + if isinstance(data, dict): + for key, value in data.items(): + if key.startswith("@"): + # Attribute + elem.set(key[1:], str(value)) + elif isinstance(value, list): + # Multiple children + for item in value: + _dict_to_xml(item, key, elem) + elif isinstance(value, dict): + # Nested object + _dict_to_xml(value, key, elem) + else: + # Simple value as child element + child = ET.SubElement(elem, key) + if value is not None: + child.text = str(value) + elif isinstance(data, list): + for item in data: + _dict_to_xml(item, "item", elem) + else: + if data is not None: + elem.text = str(data) + + return elem + + +@tool +async def xml_to_json( + xml_string: 
str, + strip_root: bool = True, +) -> ToolResult: + """ + Convert XML to JSON. + + Use this to prepare data for JSON APIs, webhooks, n8n workflows, etc. + + Args: + xml_string: XML content to convert + strip_root: If True, unwrap single root element (default: True) + + Returns: + json: The JSON string + data: The parsed data as dict + + Example: + Alice30 + → {"name": "Alice", "age": 30} + """ + try: + # Parse XML + root = ET.fromstring(xml_string.strip()) + data = _xml_to_dict(root) + + # Optionally strip the root element wrapper + if strip_root and isinstance(data, dict) and len(data) == 1: + # Check if we should unwrap + pass # Keep as-is, root is already stripped by _xml_to_dict + + # Wrap with root tag name if it's meaningful + result = {root.tag: data} if not strip_root else data + + return ToolResult(success=True, data={ + "json": json.dumps(result, indent=2), + "data": result, + }) + except ET.ParseError as e: + return ToolResult(success=False, error=f"Invalid XML: {e}") + except Exception as e: + return ToolResult(success=False, error=f"Conversion error: {e}") + + +@tool +async def json_to_xml( + json_string: str, + root_tag: str = "data", + pretty: bool = True, +) -> ToolResult: + """ + Convert JSON to XML. + + Use this to convert responses from JSON APIs back to XML format. 
+ + Args: + json_string: JSON content to convert + root_tag: Name for the root XML element (default: "data") + pretty: Pretty-print with indentation (default: True) + + Returns: + xml: The XML string + + Example: + {"name": "Alice", "age": 30} + → Alice30 + """ + try: + data = json.loads(json_string) + root = _dict_to_xml(data, root_tag) + + if pretty: + ET.indent(root) + + xml_str = ET.tostring(root, encoding="unicode") + + return ToolResult(success=True, data={ + "xml": xml_str, + }) + except json.JSONDecodeError as e: + return ToolResult(success=False, error=f"Invalid JSON: {e}") + except Exception as e: + return ToolResult(success=False, error=f"Conversion error: {e}") + + +@tool +async def xml_extract( + xml_string: str, + xpath: str, +) -> ToolResult: + """ + Extract data from XML using XPath. + + Args: + xml_string: XML content + xpath: XPath expression (e.g., ".//item", "./users/user[@id='1']") + + Returns: + matches: List of matching elements as dicts + count: Number of matches + """ + try: + root = ET.fromstring(xml_string.strip()) + elements = root.findall(xpath) + + matches = [] + for elem in elements: + matches.append({ + "tag": elem.tag, + "attributes": dict(elem.attrib), + "data": _xml_to_dict(elem), + }) + + return ToolResult(success=True, data={ + "matches": matches, + "count": len(matches), + }) + except ET.ParseError as e: + return ToolResult(success=False, error=f"Invalid XML: {e}") + except Exception as e: + return ToolResult(success=False, error=f"XPath error: {e}") diff --git a/agentserver/tools/fetch.py b/agentserver/tools/fetch.py index ff12856..afed8cb 100644 --- a/agentserver/tools/fetch.py +++ b/agentserver/tools/fetch.py @@ -1,12 +1,82 @@ """ -Fetch tool - HTTP requests. +Fetch tool - HTTP requests with security controls. Uses aiohttp for async HTTP operations. 
""" +from __future__ import annotations + +import ipaddress +import socket from typing import Optional, Dict +from urllib.parse import urlparse + from .base import tool, ToolResult +# Try to import aiohttp - optional dependency +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + + +# Security configuration +MAX_RESPONSE_SIZE = 10 * 1024 * 1024 # 10 MB +DEFAULT_TIMEOUT = 30 +ALLOWED_SCHEMES = {"http", "https"} +BLOCKED_HOSTS = { + "localhost", + "127.0.0.1", + "0.0.0.0", + "::1", + "metadata.google.internal", # GCP metadata + "169.254.169.254", # AWS/Azure/GCP metadata +} + + +def _is_private_ip(hostname: str) -> bool: + """Check if hostname resolves to a private/internal IP.""" + try: + # Try to parse as IP address first + try: + ip = ipaddress.ip_address(hostname) + return ip.is_private or ip.is_loopback or ip.is_link_local + except ValueError: + pass + + # Resolve hostname to IP + ip_str = socket.gethostbyname(hostname) + ip = ipaddress.ip_address(ip_str) + return ip.is_private or ip.is_loopback or ip.is_link_local + except (socket.gaierror, socket.herror): + # Can't resolve - block by default for security + return True + + +def _validate_url(url: str, allow_internal: bool = False) -> Optional[str]: + """Validate URL for security. Returns error message or None if OK.""" + try: + parsed = urlparse(url) + except Exception: + return "Invalid URL format" + + if parsed.scheme not in ALLOWED_SCHEMES: + return f"Scheme '{parsed.scheme}' not allowed. Use http or https." 
+ + if not parsed.netloc: + return "URL must have a host" + + hostname = parsed.hostname or "" + + if hostname in BLOCKED_HOSTS: + return f"Host '{hostname}' is blocked" + + if not allow_internal and _is_private_ip(hostname): + return f"Access to internal/private IPs is not allowed" + + return None + @tool async def fetch_url( @@ -14,35 +84,91 @@ async def fetch_url( method: str = "GET", headers: Optional[Dict[str, str]] = None, body: Optional[str] = None, - timeout: int = 30, + timeout: int = DEFAULT_TIMEOUT, + allow_internal: bool = False, ) -> ToolResult: """ Fetch content from a URL. Args: url: The URL to fetch - method: HTTP method (GET, POST, PUT, DELETE) + method: HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD) headers: Optional HTTP headers - body: Optional request body for POST/PUT - timeout: Request timeout in seconds + body: Optional request body for POST/PUT/PATCH + timeout: Request timeout in seconds (default: 30, max: 300) + allow_internal: Allow internal/private IPs (default: false) Returns: - status_code, headers, body + status_code, headers, body, url (final URL after redirects) Security: - - URL allowlist/blocklist configurable + - Only http/https schemes allowed + - No access to localhost, metadata endpoints, or private IPs by default + - Response size limited to 10 MB - Timeout enforced - - Response size limit - - No file:// or internal IPs by default """ - # TODO: Implement with aiohttp - # import aiohttp - # async with aiohttp.ClientSession() as session: - # async with session.request(method, url, headers=headers, data=body, timeout=timeout) as resp: - # return ToolResult(success=True, data={ - # "status_code": resp.status, - # "headers": dict(resp.headers), - # "body": await resp.text(), - # }) + if not AIOHTTP_AVAILABLE: + return ToolResult( + success=False, + error="aiohttp not installed. 
Install with: pip install xml-pipeline[server]" + ) - return ToolResult(success=False, error="Not implemented") + # Validate URL + if error := _validate_url(url, allow_internal): + return ToolResult(success=False, error=error) + + # Validate method + method = method.upper() + allowed_methods = {"GET", "POST", "PUT", "DELETE", "PATCH", "HEAD", "OPTIONS"} + if method not in allowed_methods: + return ToolResult(success=False, error=f"Method '{method}' not allowed") + + # Clamp timeout + timeout = min(max(1, timeout), 300) + + try: + client_timeout = aiohttp.ClientTimeout(total=timeout) + async with aiohttp.ClientSession(timeout=client_timeout) as session: + async with session.request( + method, + url, + headers=headers, + data=body, + ) as resp: + # Check response size before reading + content_length = resp.headers.get("Content-Length") + if content_length and int(content_length) > MAX_RESPONSE_SIZE: + return ToolResult( + success=False, + error=f"Response too large: {content_length} bytes (max: {MAX_RESPONSE_SIZE})" + ) + + # Read response with size limit + body_bytes = await resp.content.read(MAX_RESPONSE_SIZE + 1) + if len(body_bytes) > MAX_RESPONSE_SIZE: + return ToolResult( + success=False, + error=f"Response exceeded {MAX_RESPONSE_SIZE} bytes" + ) + + # Try to decode as text + try: + body_text = body_bytes.decode("utf-8") + except UnicodeDecodeError: + # Return base64 for binary content + import base64 + body_text = base64.b64encode(body_bytes).decode("ascii") + + return ToolResult(success=True, data={ + "status_code": resp.status, + "headers": dict(resp.headers), + "body": body_text, + "url": str(resp.url), # Final URL after redirects + }) + + except aiohttp.ClientError as e: + return ToolResult(success=False, error=f"HTTP error: {e}") + except TimeoutError: + return ToolResult(success=False, error=f"Request timed out after {timeout}s") + except Exception as e: + return ToolResult(success=False, error=f"Fetch error: {e}") diff --git a/agentserver/tools/files.py 
b/agentserver/tools/files.py index e003226..3dff5f3 100644 --- a/agentserver/tools/files.py +++ b/agentserver/tools/files.py @@ -1,25 +1,48 @@ """ File tools - sandboxed file system operations. + +All paths are validated against configured allowed directories. """ -from typing import Optional, List +from __future__ import annotations + +import base64 from pathlib import Path +from typing import Optional, List + from .base import tool, ToolResult -# TODO: Configure allowed paths -ALLOWED_PATHS: List[Path] = [] +# Security configuration +_allowed_paths: List[Path] = [] +MAX_FILE_SIZE = 10 * 1024 * 1024 # 10 MB +MAX_LISTING_ENTRIES = 1000 -def _validate_path(path: str) -> Optional[str]: - """Validate path is within allowed directories.""" - # TODO: Implement chroot validation - # resolved = Path(path).resolve() - # for allowed in ALLOWED_PATHS: - # if resolved.is_relative_to(allowed): - # return None - # return f"Path {path} not in allowed directories" - return None # Stub: allow all for now +def configure_allowed_paths(paths: List[str | Path]) -> None: + global _allowed_paths + _allowed_paths = [Path(p).resolve() for p in paths] + + +def _validate_path(path: str) -> tuple[Optional[str], Optional[Path]]: + if not _allowed_paths: + try: + return None, Path(path).resolve() + except Exception as e: + return f"Invalid path: {e}", None + try: + resolved = Path(path).resolve() + except Exception as e: + return f"Invalid path: {e}", None + if ".." in str(path): + return "Path traversal (..) not allowed", None + for allowed in _allowed_paths: + try: + resolved.relative_to(allowed) + return None, resolved + except ValueError: + continue + return "Path not in allowed directories", None @tool @@ -27,36 +50,43 @@ async def read_file( path: str, encoding: str = "utf-8", binary: bool = False, + offset: int = 0, + limit: Optional[int] = None, ) -> ToolResult: - """ - Read contents of a file. 
- - Args: - path: Path to file - encoding: Text encoding (default: utf-8) - binary: Return base64 if true (default: false) - - Security: - - Chroot to allowed directories - - No path traversal (..) - - Size limit enforced - """ - if error := _validate_path(path): + error, resolved = _validate_path(path) + if error: return ToolResult(success=False, error=error) - - # TODO: Implement - # try: - # p = Path(path) - # if binary: - # import base64 - # content = base64.b64encode(p.read_bytes()).decode() - # else: - # content = p.read_text(encoding=encoding) - # return ToolResult(success=True, data=content) - # except Exception as e: - # return ToolResult(success=False, error=str(e)) - - return ToolResult(success=False, error="Not implemented") + if not resolved.exists(): + return ToolResult(success=False, error=f"File not found: {path}") + if not resolved.is_file(): + return ToolResult(success=False, error=f"Not a file: {path}") + try: + file_size = resolved.stat().st_size + read_size = min(limit or MAX_FILE_SIZE, MAX_FILE_SIZE) + if binary: + with open(resolved, "rb") as f: + if offset: + f.seek(offset) + content = f.read(read_size) + return ToolResult(success=True, data={ + "content": base64.b64encode(content).decode("ascii"), + "size": file_size, + "encoding": "base64", + }) + else: + with open(resolved, "r", encoding=encoding) as f: + if offset: + f.seek(offset) + content = f.read(read_size) + return ToolResult(success=True, data={ + "content": content, + "size": file_size, + "encoding": encoding, + }) + except UnicodeDecodeError: + return ToolResult(success=False, error=f"Cannot decode as {encoding}. Try binary=true.") + except Exception as e: + return ToolResult(success=False, error=f"Read error: {e}") @tool @@ -65,71 +95,98 @@ async def write_file( content: str, mode: str = "overwrite", encoding: str = "utf-8", + binary: bool = False, + create_dirs: bool = False, ) -> ToolResult: - """ - Write content to a file. 
- - Args: - path: Path to file - content: Content to write - mode: "overwrite" or "append" (default: overwrite) - encoding: Text encoding (default: utf-8) - - Security: - - Chroot to allowed directories - - No path traversal - - Max file size enforced - """ - if error := _validate_path(path): + error, resolved = _validate_path(path) + if error: return ToolResult(success=False, error=error) - - # TODO: Implement - # try: - # p = Path(path) - # if mode == "append": - # with open(p, "a", encoding=encoding) as f: - # f.write(content) - # else: - # p.write_text(content, encoding=encoding) - # return ToolResult(success=True, data={"bytes_written": len(content.encode(encoding))}) - # except Exception as e: - # return ToolResult(success=False, error=str(e)) - - return ToolResult(success=False, error="Not implemented") + if binary: + try: + data = base64.b64decode(content) + except Exception as e: + return ToolResult(success=False, error=f"Invalid base64: {e}") + else: + data = content.encode(encoding) + if len(data) > MAX_FILE_SIZE: + return ToolResult(success=False, error=f"Content too large: {len(data)} bytes") + try: + if create_dirs: + resolved.parent.mkdir(parents=True, exist_ok=True) + if binary: + write_mode = "ab" if mode == "append" else "wb" + with open(resolved, write_mode) as f: + f.write(data) + else: + if mode == "append": + with open(resolved, "a", encoding=encoding) as f: + f.write(content) + else: + resolved.write_text(content, encoding=encoding) + return ToolResult(success=True, data={"bytes_written": len(data), "path": str(resolved)}) + except Exception as e: + return ToolResult(success=False, error=f"Write error: {e}") @tool async def list_dir( path: str, pattern: str = "*", + recursive: bool = False, + include_hidden: bool = False, ) -> ToolResult: - """ - List directory contents. 
- - Args: - path: Directory path - pattern: Glob pattern filter (default: *) - - Returns: - Array of {name, type, size, modified} - """ - if error := _validate_path(path): + error, resolved = _validate_path(path) + if error: return ToolResult(success=False, error=error) + if not resolved.exists(): + return ToolResult(success=False, error=f"Directory not found: {path}") + if not resolved.is_dir(): + return ToolResult(success=False, error=f"Not a directory: {path}") + try: + entries = [] + glob_method = resolved.rglob if recursive else resolved.glob + for entry in glob_method(pattern): + if not include_hidden and entry.name.startswith("."): + continue + try: + stat = entry.stat() + entries.append({ + "name": str(entry.relative_to(resolved)), + "type": "dir" if entry.is_dir() else "file", + "size": stat.st_size if entry.is_file() else None, + "modified": stat.st_mtime, + }) + except (OSError, PermissionError): + continue + if len(entries) >= MAX_LISTING_ENTRIES: + break + entries.sort(key=lambda e: e["name"]) + return ToolResult(success=True, data={ + "entries": entries, + "count": len(entries), + "truncated": len(entries) >= MAX_LISTING_ENTRIES, + }) + except Exception as e: + return ToolResult(success=False, error=f"List error: {e}") - # TODO: Implement - # try: - # p = Path(path) - # entries = [] - # for entry in p.glob(pattern): - # stat = entry.stat() - # entries.append({ - # "name": entry.name, - # "type": "dir" if entry.is_dir() else "file", - # "size": stat.st_size, - # "modified": stat.st_mtime, - # }) - # return ToolResult(success=True, data=entries) - # except Exception as e: - # return ToolResult(success=False, error=str(e)) - return ToolResult(success=False, error="Not implemented") +@tool +async def delete_file(path: str, recursive: bool = False) -> ToolResult: + error, resolved = _validate_path(path) + if error: + return ToolResult(success=False, error=error) + if not resolved.exists(): + return ToolResult(success=False, error=f"Path not found: {path}") 
+ try: + if resolved.is_file(): + resolved.unlink() + elif resolved.is_dir(): + if not recursive: + return ToolResult(success=False, error="Cannot delete directory without recursive=true") + import shutil + shutil.rmtree(resolved) + else: + return ToolResult(success=False, error=f"Unknown file type: {path}") + return ToolResult(success=True, data={"deleted": True, "path": str(resolved)}) + except Exception as e: + return ToolResult(success=False, error=f"Delete error: {e}") diff --git a/agentserver/tools/librarian.py b/agentserver/tools/librarian.py index 975d175..8ab97e0 100644 --- a/agentserver/tools/librarian.py +++ b/agentserver/tools/librarian.py @@ -1,128 +1,135 @@ """ Librarian tools - exist-db XML database integration. -Provides XQuery-based document storage and retrieval. +Provides XQuery-based document storage and retrieval for long-term memory. +Requires exist-db to be running and configured. """ -from typing import Optional, Dict, List +from __future__ import annotations + +from typing import Optional, Dict +from dataclasses import dataclass + from .base import tool, ToolResult -# TODO: Configure exist-db connection -EXISTDB_URL = "http://localhost:8080/exist/rest" -EXISTDB_USER = "admin" -EXISTDB_PASS = "" # Configure via env +try: + import aiohttp + AIOHTTP_AVAILABLE = True +except ImportError: + AIOHTTP_AVAILABLE = False + + +@dataclass +class ExistDBConfig: + url: str = "http://localhost:8080/exist/rest" + username: str = "admin" + password: str = "" + default_collection: str = "/db/agents" + + +_config: Optional[ExistDBConfig] = None + + +def configure_librarian( + url: str = "http://localhost:8080/exist/rest", + username: str = "admin", + password: str = "", + default_collection: str = "/db/agents", +) -> None: + global _config + _config = ExistDBConfig(url=url, username=username, password=password, default_collection=default_collection) + + +def _check_config() -> Optional[str]: + if not AIOHTTP_AVAILABLE: + return "aiohttp not installed. 
Install with: pip install xml-pipeline[server]" + if not _config: + return "Librarian not configured. Call configure_librarian() first." + return None + + +def _resolve_path(path: str) -> str: + if path.startswith("/"): + return path + return f"{_config.default_collection}/{path}" @tool -async def librarian_store( - collection: str, - document_name: str, - content: str, -) -> ToolResult: - """ - Store an XML document in exist-db. - - Args: - collection: Target collection path (e.g., "/db/agents/greeter") - document_name: Document filename (e.g., "conversation-001.xml") - content: XML content - - Returns: - path: Full path to stored document - """ - # TODO: Implement with exist-db REST API - # import aiohttp - # url = f"{EXISTDB_URL}{collection}/{document_name}" - # async with aiohttp.ClientSession() as session: - # async with session.put( - # url, - # data=content, - # headers={"Content-Type": "application/xml"}, - # auth=aiohttp.BasicAuth(EXISTDB_USER, EXISTDB_PASS), - # ) as resp: - # if resp.status in (200, 201): - # return ToolResult(success=True, data={"path": f"{collection}/{document_name}"}) - # return ToolResult(success=False, error=await resp.text()) - - return ToolResult(success=False, error="Not implemented - configure exist-db") +async def librarian_store(collection: str, document_name: str, content: str) -> ToolResult: + """Store an XML document in exist-db.""" + if error := _check_config(): + return ToolResult(success=False, error=error) + collection = _resolve_path(collection) + url = f"{_config.url}{collection}/{document_name}" + try: + auth = aiohttp.BasicAuth(_config.username, _config.password) + async with aiohttp.ClientSession() as session: + async with session.put(url, data=content.encode("utf-8"), + headers={"Content-Type": "application/xml"}, auth=auth) as resp: + if resp.status in (200, 201): + return ToolResult(success=True, data={"path": f"{collection}/{document_name}"}) + return ToolResult(success=False, error=f"exist-db error 
{resp.status}: {await resp.text()}") + except Exception as e: + return ToolResult(success=False, error=f"Store error: {e}") @tool -async def librarian_get( - path: str, -) -> ToolResult: - """ - Retrieve a document by path. - - Args: - path: Full document path (e.g., "/db/agents/greeter/conversation-001.xml") - - Returns: - content: XML content - """ - # TODO: Implement with exist-db REST API - # import aiohttp - # url = f"{EXISTDB_URL}{path}" - # async with aiohttp.ClientSession() as session: - # async with session.get( - # url, - # auth=aiohttp.BasicAuth(EXISTDB_USER, EXISTDB_PASS), - # ) as resp: - # if resp.status == 200: - # return ToolResult(success=True, data=await resp.text()) - # return ToolResult(success=False, error=f"Not found: {path}") - - return ToolResult(success=False, error="Not implemented - configure exist-db") +async def librarian_get(path: str) -> ToolResult: + """Retrieve a document by path.""" + if error := _check_config(): + return ToolResult(success=False, error=error) + path = _resolve_path(path) + url = f"{_config.url}{path}" + try: + auth = aiohttp.BasicAuth(_config.username, _config.password) + async with aiohttp.ClientSession() as session: + async with session.get(url, auth=auth) as resp: + if resp.status == 200: + return ToolResult(success=True, data={"content": await resp.text(), "path": path}) + elif resp.status == 404: + return ToolResult(success=False, error=f"Not found: {path}") + return ToolResult(success=False, error=f"exist-db error {resp.status}") + except Exception as e: + return ToolResult(success=False, error=f"Get error: {e}") @tool -async def librarian_query( - query: str, - collection: Optional[str] = None, - variables: Optional[Dict[str, str]] = None, -) -> ToolResult: - """ - Execute an XQuery against exist-db. 
- - Args: - query: XQuery expression - collection: Limit to collection (optional) - variables: External variables to bind (optional) - - Returns: - results: Array of matching XML fragments - - Examples: - - '//message[@from="greeter"]' - - 'for $m in //message where $m/@timestamp > $since return $m' - """ - # TODO: Implement with exist-db REST API - # The exist-db REST API accepts XQuery via POST to /exist/rest/db - # with _query parameter or as request body - - return ToolResult(success=False, error="Not implemented - configure exist-db") +async def librarian_query(query: str, collection: Optional[str] = None, variables: Optional[Dict[str, str]] = None) -> ToolResult: + """Execute an XQuery against exist-db.""" + if error := _check_config(): + return ToolResult(success=False, error=error) + base_path = _resolve_path(collection) if collection else "/db" + url = f"{_config.url}{base_path}" + full_query = query + if variables: + var_decls = "\n".join(f'declare variable ${k} external := "{v}";' for k, v in variables.items()) + full_query = f"{var_decls}\n{query}" + try: + auth = aiohttp.BasicAuth(_config.username, _config.password) + async with aiohttp.ClientSession() as session: + async with session.post(url, data={"_query": full_query}, auth=auth) as resp: + if resp.status == 200: + return ToolResult(success=True, data={"results": await resp.text(), "collection": base_path}) + return ToolResult(success=False, error=f"XQuery error {resp.status}: {await resp.text()}") + except Exception as e: + return ToolResult(success=False, error=f"Query error: {e}") @tool -async def librarian_search( - query: str, - collection: Optional[str] = None, - num_results: int = 10, -) -> ToolResult: - """ - Full-text search across documents. 
- - Args: - query: Search terms - collection: Limit to collection (optional) - num_results: Max results (default: 10) - - Returns: - results: Array of {path, score, snippet} - """ - # TODO: Implement with exist-db full-text search - # exist-db supports Lucene-based full-text indexing - # Query using ft:query() function in XQuery - - return ToolResult(success=False, error="Not implemented - configure exist-db") +async def librarian_search(query: str, collection: Optional[str] = None, num_results: int = 10) -> ToolResult: + """Full-text search across documents using Lucene.""" + if error := _check_config(): + return ToolResult(success=False, error=error) + base_path = _resolve_path(collection) if collection else _config.default_collection + xquery = f'import module namespace ft="http://exist-db.org/xquery/lucene"; for $hit in collection("{base_path}")//*[ft:query(., "{query}")] let $score := ft:score($hit) order by $score descending return {{document-uri(root($hit))}}{{$score}}' + url = f"{_config.url}{base_path}" + try: + auth = aiohttp.BasicAuth(_config.username, _config.password) + async with aiohttp.ClientSession() as session: + async with session.post(url, data={"_query": xquery, "_howmany": str(num_results)}, auth=auth) as resp: + if resp.status == 200: + return ToolResult(success=True, data={"results": await resp.text(), "query": query}) + return ToolResult(success=False, error=f"Search error {resp.status}: {await resp.text()}") + except Exception as e: + return ToolResult(success=False, error=f"Search error: {e}") diff --git a/agentserver/tools/search.py b/agentserver/tools/search.py index 18f11ac..70bc3af 100644 --- a/agentserver/tools/search.py +++ b/agentserver/tools/search.py @@ -1,10 +1,143 @@ """ -Search tool - web search. +Search tool - web search integration. + +Requires configuration of a search provider API. +Supported providers: SerpAPI, Google Custom Search, Bing Search. 
@dataclass
class SearchConfig:
    """Configuration for a web-search provider.

    Attributes:
        provider: Backend name — "serpapi", "google", or "bing".
        api_key: API key for the provider.
        engine_id: Google Custom Search engine id (Google only).
    """
    provider: str  # "serpapi", "google", "bing"
    api_key: str
    engine_id: Optional[str] = None  # For Google Custom Search


# Global config - set via configure_search()
_config: Optional[SearchConfig] = None


def configure_search(
    provider: str,
    api_key: str,
    engine_id: Optional[str] = None,
) -> None:
    """
    Configure the search provider.

    Args:
        provider: "serpapi", "google", or "bing"
        api_key: API key for the provider
        engine_id: Required for Google Custom Search

    Example:
        configure_search("serpapi", os.environ["SERPAPI_KEY"])
    """
    global _config
    _config = SearchConfig(
        provider=provider,
        api_key=api_key,
        engine_id=engine_id,
    )


def _as_results(items: List[dict], title_key: str, url_key: str) -> List[dict]:
    """Map raw provider items to the common {title, url, snippet} shape."""
    return [
        {
            "title": item.get(title_key, ""),
            "url": item.get(url_key, ""),
            "snippet": item.get("snippet", ""),
        }
        for item in items
    ]


async def _search_serpapi(query: str, num_results: int) -> List[dict]:
    """Search using SerpAPI.

    Raises:
        RuntimeError: On a non-200 HTTP response.
    """
    async with aiohttp.ClientSession() as session:
        params = {
            "q": query,
            "api_key": _config.api_key,
            "num": num_results,
            "engine": "google",
        }
        async with session.get(
            "https://serpapi.com/search",
            params=params,
        ) as resp:
            if resp.status != 200:
                # RuntimeError (not bare Exception): still caught by the
                # `except Exception` in web_search, but more precise.
                raise RuntimeError(f"SerpAPI error: {resp.status}")
            data = await resp.json()

    return _as_results(data.get("organic_results", [])[:num_results], "title", "link")


async def _search_google(query: str, num_results: int) -> List[dict]:
    """Search using Google Custom Search API.

    Raises:
        RuntimeError: If engine_id is missing or the API returns non-200.
    """
    if not _config.engine_id:
        raise RuntimeError("Google Custom Search requires engine_id")

    async with aiohttp.ClientSession() as session:
        params = {
            "q": query,
            "key": _config.api_key,
            "cx": _config.engine_id,
            "num": min(num_results, 10),  # API max is 10
        }
        async with session.get(
            "https://www.googleapis.com/customsearch/v1",
            params=params,
        ) as resp:
            if resp.status != 200:
                raise RuntimeError(f"Google API error: {resp.status}")
            data = await resp.json()

    return _as_results(data.get("items", []), "title", "link")


async def _search_bing(query: str, num_results: int) -> List[dict]:
    """Search using Bing Search API.

    Raises:
        RuntimeError: On a non-200 HTTP response.
    """
    async with aiohttp.ClientSession() as session:
        headers = {"Ocp-Apim-Subscription-Key": _config.api_key}
        params = {
            "q": query,
            "count": num_results,
        }
        async with session.get(
            "https://api.bing.microsoft.com/v7.0/search",
            headers=headers,
            params=params,
        ) as resp:
            if resp.status != 200:
                raise RuntimeError(f"Bing API error: {resp.status}")
            data = await resp.json()

    return _as_results(data.get("webPages", {}).get("value", []), "name", "url")
@tool
async def web_search(
    query: str,
    num_results: int = 5,
) -> ToolResult:
    """
    Search the web via the configured provider.

    Args:
        query: Search query
        num_results: Number of results (default: 5, max: 20)

    Returns:
        results: Array of {title, url, snippet}

    Configuration:
        Call configure_search() before use:

        from agentserver.tools.search import configure_search
        configure_search("serpapi", "your-api-key")
    """
    # Guard: aiohttp is an optional dependency.
    if not AIOHTTP_AVAILABLE:
        return ToolResult(
            success=False,
            error="aiohttp not installed. Install with: pip install xml-pipeline[server]"
        )

    # Guard: a provider must have been configured up front.
    if not _config:
        return ToolResult(
            success=False,
            error="Search not configured. Call configure_search() first."
        )

    # Keep the requested result count within [1, 20].
    num_results = min(max(1, num_results), 20)

    # Dispatch table instead of an if/elif chain.
    providers = {
        "serpapi": _search_serpapi,
        "google": _search_google,
        "bing": _search_bing,
    }
    search_fn = providers.get(_config.provider)
    if search_fn is None:
        return ToolResult(
            success=False,
            error=f"Unknown provider: {_config.provider}"
        )

    try:
        hits = await search_fn(query, num_results)
    except Exception as e:
        return ToolResult(success=False, error=f"Search error: {e}")

    return ToolResult(success=True, data={
        "query": query,
        "results": hits,
        "count": len(hits),
    })
# Security configuration
BLOCKED_COMMANDS: List[str] = [
    # Destructive commands
    "rm", "rmdir", "del", "erase", "format", "mkfs", "dd",
    # System modification
    "shutdown", "reboot", "init", "systemctl",
    # Network tools that could be abused
    "nc", "netcat", "ncat",
    # Privilege escalation
    "sudo", "su", "doas", "runas",
    # Shell escapes
    "bash", "sh", "zsh", "fish", "cmd", "powershell", "pwsh",
]
DEFAULT_TIMEOUT = 30            # seconds
MAX_TIMEOUT = 300               # seconds; hard cap on caller-supplied timeouts
MAX_OUTPUT_SIZE = 1024 * 1024   # 1 MB cap on captured stdout/stderr


def configure_allowed_commands(commands: List[str]) -> None:
    """Set an allowlist of commands (empty = blocklist mode)."""
    global ALLOWED_COMMANDS
    ALLOWED_COMMANDS = commands


def configure_blocked_commands(commands: List[str]) -> None:
    """Replace the blocked-command list.

    Note: this REPLACES the list (it does not extend it); include the
    defaults in `commands` if you still want them blocked.
    """
    global BLOCKED_COMMANDS
    BLOCKED_COMMANDS = commands


def _validate_command(command: str) -> Optional[str]:
    """Validate command against allow/block lists.

    Returns an error message, or None if the command is acceptable.
    """
    try:
        # Parse command to get the executable
        parts = shlex.split(command)
        if not parts:
            return "Empty command"

        executable = parts[0].lower()

        # Strip any path prefix to get the bare command name
        if "/" in executable or "\\" in executable:
            executable = executable.split("/")[-1].split("\\")[-1]

        # Allowlist (if configured) takes precedence; compare
        # case-insensitively since the executable was lowercased above.
        if ALLOWED_COMMANDS:
            if executable not in {c.lower() for c in ALLOWED_COMMANDS}:
                return f"Command '{executable}' not in allowlist"

        # Check blocklist
        if executable in BLOCKED_COMMANDS:
            return f"Command '{executable}' is blocked for security"

        # Reject shell operators outright; the command runs without a
        # shell, but these would be passed as literal arguments and
        # likely indicate an injection attempt.
        dangerous_operators = [";", "&&", "||", "|", "`", "$(", "${"]
        for op in dangerous_operators:
            if op in command:
                return f"Shell operator '{op}' not allowed"

        return None
    except ValueError as e:
        return f"Invalid command syntax: {e}"


@tool
async def run_command(
    command: str,
    timeout: int = DEFAULT_TIMEOUT,
    cwd: Optional[str] = None,
    env: Optional[dict] = None,
) -> ToolResult:
    """
    Execute a shell command (sandboxed).

    Args:
        command: Command to execute
        timeout: Timeout in seconds (default: 30, max: 300)
        cwd: Working directory (optional)
        env: Environment variables to add (optional)

    Returns:
        exit_code: Process exit code
        stdout: Standard output
        stderr: Standard error
        timed_out: True if command was killed due to timeout

    Security:
        - Dangerous commands are blocked
        - Shell operators (;, &&, |, etc.) are blocked
        - Timeout enforced
        - Output size limited to 1 MB
    """
    import os  # local import: only needed for the env merge below

    # Validate command
    if error := _validate_command(command):
        return ToolResult(success=False, error=error)

    # Clamp timeout
    timeout = min(max(1, timeout), MAX_TIMEOUT)

    try:
        # Parse command into args (no shell)
        args = shlex.split(command)

        # Merge caller env on top of the current environment. Passing a
        # bare dict to create_subprocess_exec REPLACES the environment,
        # which contradicts the documented "variables to add" contract.
        merged_env = {**os.environ, **env} if env else None

        proc = await asyncio.create_subprocess_exec(
            *args,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
            cwd=cwd,
            env=merged_env,
        )

        timed_out = False
        try:
            stdout, stderr = await asyncio.wait_for(
                proc.communicate(),
                timeout=timeout
            )
        except asyncio.TimeoutError:
            proc.kill()
            await proc.wait()  # reap the killed process
            timed_out = True
            stdout = b""
            stderr = b"Command timed out"

        # Decode output, truncating if too large
        stdout_str = stdout[:MAX_OUTPUT_SIZE].decode("utf-8", errors="replace")
        stderr_str = stderr[:MAX_OUTPUT_SIZE].decode("utf-8", errors="replace")

        truncated = len(stdout) > MAX_OUTPUT_SIZE or len(stderr) > MAX_OUTPUT_SIZE

        return ToolResult(
            success=proc.returncode == 0 and not timed_out,
            data={
                "exit_code": proc.returncode,
                "stdout": stdout_str,
                "stderr": stderr_str,
                "timed_out": timed_out,
                "truncated": truncated,
            }
        )
    except FileNotFoundError:
        return ToolResult(success=False, error=f"Command not found: {args[0]}")
    except PermissionError:
        return ToolResult(success=False, error=f"Permission denied: {args[0]}")
    except Exception as e:
        return ToolResult(success=False, error=f"Execution error: {e}")
FileNotFoundError: + return ToolResult(success=False, error=f"Command not found: {args[0]}") + except PermissionError: + return ToolResult(success=False, error=f"Permission denied: {args[0]}") + except Exception as e: + return ToolResult(success=False, error=f"Execution error: {e}") diff --git a/docs/wasm-listeners.md b/docs/wasm-listeners.md new file mode 100644 index 0000000..fbdcff2 --- /dev/null +++ b/docs/wasm-listeners.md @@ -0,0 +1,230 @@ +# WASM Listeners Specification + +Custom listeners can be implemented in WebAssembly (WASM) using AssemblyScript or any language that compiles to WASM. This enables power users to deploy sandboxed, portable handlers. + +## Overview + +``` +User uploads: handler.wasm + handler.wit +System: Validates, generates wrappers, registers handlers +Runtime: Python wrapper calls WASM for compute, handles I/O +``` + +## Upload Requirements + +### 1. WIT File (Required) + +The WIT file describes the interface - input/output types for each handler. + +```wit +// calculator.wit +package myorg:calculator@1.0.0; + +interface calculate { + record calculate-request { + expression: string, + } + + record calculate-response { + result: f64, + error: option, + } +} + +interface factorial { + record factorial-request { + n: u32, + } + + record factorial-response { + result: u64, + } +} +``` + +### 2. WASM Module (Required) + +The WASM module must export: + +| Export | Signature | Description | +|--------|-----------|-------------| +| `handle_{interface}` | `(ptr: i32, len: i32) -> i32` | Handler for each WIT interface | +| `alloc` | `(size: i32) -> i32` | Allocate memory for input | +| `free` | `(ptr: i32) -> void` | Free memory | + +**Calling convention:** +- Input: JSON string at `ptr` with length `len` +- Output: Returns pointer to JSON result (length-prefixed or null-terminated) + +### 3. 
Example AssemblyScript + +```typescript +// calculator.ts +import { JSON } from "assemblyscript-json"; + +class CalculateRequest { + expression: string = ""; +} + +class CalculateResponse { + result: f64 = 0; + error: string | null = null; + _to: string | null = null; // Optional: routing target +} + +export function handle_calculate(ptr: i32, len: i32): i32 { + // Parse input + const input = parseJson(ptr, len); + + // Process + const response = new CalculateResponse(); + try { + response.result = evaluate(input.expression); + } catch (e) { + response.error = e.message; + } + + // Return JSON pointer + return toJsonPtr(response); +} + +// Memory management - required exports +const allocations = new Map(); + +export function alloc(size: i32): i32 { + const ptr = heap.alloc(size); + allocations.set(ptr, size); + return ptr; +} + +export function free(ptr: i32): void { + if (allocations.has(ptr)) { + heap.free(ptr); + allocations.delete(ptr); + } +} +``` + +Compile with: +```bash +asc calculator.ts -o calculator.wasm --optimize +``` + +## Registration Flow + +1. **Parse WIT** → Extract interface definitions +2. **Load WASM** → Validate exports match WIT interfaces +3. **Generate wrappers** → Create @xmlify dataclasses from WIT +4. 
**Register handlers** → Add to listener routing table + +```python +# Pseudocode +from agentserver.wasm import register_wasm_listener + +register_wasm_listener( + name="calculator", + wasm_path="/uploads/calculator.wasm", + wit_path="/uploads/calculator.wit", + config={ + "memory_limit_mb": 64, + "timeout_seconds": 5, + } +) +``` + +## Data Flow + +``` +Message arrives (XML) + │ + ▼ +Python wrapper deserializes to @xmlify dataclass + │ + ▼ +Convert to JSON string + │ + ▼ +Allocate WASM memory, copy JSON + │ + ▼ +Call handle_{interface}(ptr, len) + │ + ▼ +WASM processes synchronously + │ + ▼ +Read result JSON from returned pointer + │ + ▼ +Free WASM memory + │ + ▼ +Convert JSON to @xmlify response dataclass + │ + ▼ +Extract routing target from _to field (if present) + │ + ▼ +Return HandlerResponse +``` + +## Routing + +WASM handlers signal routing via the `_to` field in the response JSON: + +```json +{ + "result": 42, + "_to": "logger" +} +``` + +- `_to` present → forward to named listener +- `_to` absent/null → respond to caller + +## Resource Limits + +| Resource | Default | Configurable | +|----------|---------|--------------| +| Memory | 64 MB | Yes | +| CPU time | 5 seconds | Yes | +| Stack depth | WASM default | No | + +Exceeding limits results in termination and `SystemError` response. + +## Security Model + +- **Sandboxed**: WASM linear memory is isolated +- **No I/O**: WASM cannot access filesystem, network, or system +- **No imports**: Host functions are not exposed (pure compute only) +- **Timeout enforced**: Long-running handlers are terminated + +For I/O, WASM handlers should: +1. Return a response indicating what I/O is needed +2. Let the agent orchestrate I/O via Python tools +3. 
Receive I/O results in subsequent messages + +## Lifecycle + +WASM instances are kept "hot" (loaded) per thread: + +- **Created**: On first message to handler in thread +- **Reused**: Subsequent messages in same thread reuse instance +- **Destroyed**: When thread context is pruned (GC) + +This amortizes instantiation cost for multi-turn conversations. + +## Limitations (v1) + +- No streaming responses +- No async/await inside WASM +- No host function imports (pure compute only) +- No direct tool invocation from WASM +- JSON serialization overhead at boundary + +## Future Considerations + +- WASI support for controlled I/O +- Component Model for richer interfaces +- Streaming via multiple response chunks +- Direct tool imports for trusted WASM diff --git a/examples/mcp-servers/reddit-sentiment/README.md b/examples/mcp-servers/reddit-sentiment/README.md new file mode 100644 index 0000000..dba5f13 --- /dev/null +++ b/examples/mcp-servers/reddit-sentiment/README.md @@ -0,0 +1,76 @@ +# Reddit Sentiment MCP Server + +An MCP server that provides Reddit sentiment analysis for stock tickers. + +## Installation + +```bash +cd examples/mcp-servers/reddit-sentiment +pip install -e . +``` + +## Usage with Claude Code + +Add to your Claude Code settings (`~/.claude/settings.json`): + +```json +{ + "mcpServers": { + "reddit-sentiment": { + "command": "python", + "args": ["-m", "reddit_sentiment"], + "cwd": "/path/to/reddit-sentiment" + } + } +} +``` + +## Tools + +### reddit_trending_tickers + +Get the most mentioned stock tickers across Reddit finance subreddits. + +``` +Trending tickers from r/wallstreetbets, r/stocks, r/investing +``` + +### reddit_ticker_sentiment + +Get sentiment analysis for a specific ticker. + +``` +What's the Reddit sentiment on $TSLA? +``` + +### reddit_wsb_summary + +Get a summary of current WallStreetBets activity. + +``` +What's happening on WSB right now? +``` + +## How It Works + +1. 
Fetches posts from Reddit's public JSON API (no auth required) +2. Extracts ticker symbols using regex ($TSLA, TSLA, etc.) +3. Analyzes sentiment using bullish/bearish word matching +4. Returns structured JSON with mentions, scores, and sentiment + +## Limitations + +- Reddit rate limits (~60 requests/minute without auth) +- Simple word-based sentiment (no ML) +- Only scans post titles and selftext (not comments) +- Ticker extraction may have false positives + +## Subreddits Scanned + +- r/wallstreetbets +- r/stocks +- r/investing +- r/options +- r/stockmarket +- r/thetagang +- r/smallstreetbets diff --git a/examples/mcp-servers/reddit-sentiment/__init__.py b/examples/mcp-servers/reddit-sentiment/__init__.py new file mode 100644 index 0000000..444b177 --- /dev/null +++ b/examples/mcp-servers/reddit-sentiment/__init__.py @@ -0,0 +1,4 @@ +"""Reddit Sentiment MCP Server.""" +from .reddit_sentiment import main, server + +__all__ = ["main", "server"] diff --git a/examples/mcp-servers/reddit-sentiment/__main__.py b/examples/mcp-servers/reddit-sentiment/__main__.py new file mode 100644 index 0000000..f228f38 --- /dev/null +++ b/examples/mcp-servers/reddit-sentiment/__main__.py @@ -0,0 +1,6 @@ +"""Allow running as python -m reddit_sentiment.""" +import asyncio +from .reddit_sentiment import main + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/mcp-servers/reddit-sentiment/pyproject.toml b/examples/mcp-servers/reddit-sentiment/pyproject.toml new file mode 100644 index 0000000..46d8b89 --- /dev/null +++ b/examples/mcp-servers/reddit-sentiment/pyproject.toml @@ -0,0 +1,16 @@ +[project] +name = "reddit-sentiment-mcp" +version = "0.1.0" +description = "MCP server for Reddit stock sentiment analysis" +requires-python = ">=3.10" +dependencies = [ + "mcp>=0.9.0", + "aiohttp>=3.8.0", +] + +[project.scripts] +reddit-sentiment = "reddit_sentiment:main" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" diff --git 
# Subreddits scanned by the sentiment tools.
FINANCE_SUBREDDITS = [
    "wallstreetbets", "stocks", "investing", "options",
    "stockmarket", "thetagang", "smallstreetbets",
]

# Matches $TSLA / TSLA style symbols. The leading \b is required: without
# it the pattern matches *suffixes* of longer words (e.g. "STING" inside
# "TESTING"), flooding the counts with false positives.
TICKER_PATTERN = re.compile(r'\$?\b([A-Z]{2,5})\b')

# Common uppercase words/acronyms that match the ticker pattern but are
# not tickers; filtered out by extract_tickers().
TICKER_BLACKLIST = {
    "I", "A", "THE", "FOR", "AND", "BUT", "NOT", "YOU", "ALL",
    "CAN", "HAD", "HER", "WAS", "ONE", "OUR", "OUT", "ARE", "HAS",
    "HIS", "HOW", "ITS", "MAY", "NEW", "NOW", "OLD", "SEE", "WAY",
    "WHO", "BOY", "DID", "GET", "HIM", "LET", "PUT", "SAY", "SHE",
    "TOO", "USE", "CEO", "USD", "USA", "ETF", "IPO", "GDP", "FBI",
    "SEC", "FDA", "NYSE", "IMO", "YOLO", "FOMO", "HODL", "TLDR",
    "LOL", "WTF", "FYI", "EDIT", "POST", "JUST", "LIKE", "THIS",
    "THAT", "WITH", "FROM", "HAVE", "BEEN", "MORE", "WHEN", "WILL",
}

# Keyword lexicons for the simple word-overlap sentiment score.
BULLISH_WORDS = {
    "moon", "rocket", "bull", "calls", "long", "buy", "buying",
    "pump", "tendies", "gains", "profit", "up", "green", "bullish",
    "squeeze", "breakout", "diamond", "hands", "hold", "holding",
}

BEARISH_WORDS = {
    "puts", "short", "sell", "selling", "dump", "crash", "bear",
    "down", "red", "bearish", "loss", "losses", "overvalued",
    "bubble", "drop", "tank", "drill", "bag", "bagholder", "rip",
}
@dataclass
class RedditPost:
    """A single Reddit submission (the subset of fields we analyze)."""
    title: str
    score: int
    num_comments: int
    created_utc: float
    subreddit: str
    selftext: str = ""
    url: str = ""


async def fetch_subreddit(subreddit: str, sort: str = "hot", limit: int = 25):
    """Fetch posts from a subreddit's public JSON listing.

    Returns a list of RedditPost; an empty list on any non-200 response.
    """
    if not aiohttp:
        raise RuntimeError("aiohttp not installed")

    url = f"https://www.reddit.com/r/{subreddit}/{sort}.json?limit={limit}"
    headers = {"User-Agent": "reddit-sentiment-mcp/1.0"}

    async with aiohttp.ClientSession() as session:
        async with session.get(url, headers=headers) as resp:
            if resp.status != 200:
                return []
            data = await resp.json()

    children = data.get("data", {}).get("children", [])
    return [
        RedditPost(
            title=item.get("title", ""),
            score=item.get("score", 0),
            num_comments=item.get("num_comments", 0),
            created_utc=item.get("created_utc", 0),
            subreddit=item.get("subreddit", subreddit),
            selftext=item.get("selftext", ""),
            url=f"https://reddit.com{item.get('permalink', '')}",
        )
        for item in (child.get("data", {}) for child in children)
    ]


def extract_tickers(text):
    """Return candidate ticker symbols in *text*, blacklist-filtered."""
    candidates = TICKER_PATTERN.findall(text.upper())
    return [sym for sym in candidates if sym not in TICKER_BLACKLIST and len(sym) >= 2]


def analyze_sentiment(text):
    """Count bullish/bearish keyword hits; returns (bullish, bearish)."""
    vocab = set(text.lower().split())
    return len(vocab & BULLISH_WORDS), len(vocab & BEARISH_WORDS)


# MCP server instance; tool handlers are registered via decorators below.
server = Server("reddit-sentiment")
@server.call_tool()
async def call_tool(name: str, arguments: dict):
    """Dispatch an MCP tool call to its handler by tool name."""
    handlers = {
        "reddit_trending_tickers": trending_tickers,
        "reddit_ticker_sentiment": ticker_sentiment,
        "reddit_wsb_summary": wsb_summary,
    }
    handler = handlers.get(name)
    if handler is None:
        return [TextContent(type="text", text=f"Unknown tool: {name}")]
    return await handler(arguments)


async def trending_tickers(args):
    """Count ticker mentions across the requested subreddits (hot posts)."""
    from datetime import timezone  # local import: only needed for the timestamp

    subs = args.get("subreddits", ["wallstreetbets", "stocks", "investing"])
    limit = min(args.get("limit", 25), 100)

    counts, scores = Counter(), Counter()
    for sub in subs:
        try:
            for post in await fetch_subreddit(sub, "hot", limit):
                for t in extract_tickers(f"{post.title} {post.selftext}"):
                    counts[t] += 1
                    scores[t] += post.score
        except Exception:
            # Best-effort: one failing subreddit shouldn't sink the scan.
            pass

    result = {
        "subreddits": subs,
        "trending": [
            {"ticker": t, "mentions": c, "total_score": scores[t]}
            for t, c in counts.most_common(20)
        ],
        # Timezone-aware replacement for deprecated datetime.utcnow().
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    return [TextContent(type="text", text=json.dumps(result, indent=2))]


async def ticker_sentiment(args):
    """Scan finance subreddits for one ticker and score its sentiment."""
    ticker = args.get("ticker", "").upper().replace("$", "")
    if not ticker:
        return [TextContent(type="text", text="Error: ticker required")]

    mentions, bull, bear = [], 0, 0
    for sub in FINANCE_SUBREDDITS:
        try:
            for post in await fetch_subreddit(sub, "hot", 50):
                text = f"{post.title} {post.selftext}"
                if ticker in extract_tickers(text):
                    b, br = analyze_sentiment(text)
                    bull += b
                    bear += br
                    mentions.append({
                        "sub": post.subreddit, "title": post.title[:80],
                        "score": post.score, "url": post.url,
                    })
        except Exception:
            # Best-effort: skip subreddits that fail to load.
            pass

    # Normalized score in [-1, 1]; max(..., 1) avoids division by zero.
    score = (bull - bear) / max(bull + bear, 1)
    result = {
        "ticker": ticker,
        "mentions": len(mentions),
        "sentiment": "bullish" if score > 0.2 else ("bearish" if score < -0.2 else "neutral"),
        "score": round(score, 2),
        "bullish": bull, "bearish": bear,
        "posts": mentions[:10],
    }
    return [TextContent(type="text", text=json.dumps(result, indent=2))]


async def wsb_summary(args):
    """Summarize current r/wallstreetbets activity (tickers, mood, hot posts)."""
    sort = args.get("sort", "hot")
    posts = await fetch_subreddit("wallstreetbets", sort, 50)

    counts, bull, bear = Counter(), 0, 0
    hot = []
    for post in posts:
        text = f"{post.title} {post.selftext}"
        for t in extract_tickers(text):
            counts[t] += 1
        b, br = analyze_sentiment(text)
        bull += b
        bear += br
        if post.score > 100:
            hot.append({"title": post.title[:60], "score": post.score})

    result = {
        "top_tickers": [{"ticker": t, "mentions": c} for t, c in counts.most_common(10)],
        "mood": "bullish" if bull > bear else "bearish",
        "hot_posts": hot[:5],
    }
    return [TextContent(type="text", text=json.dumps(result, indent=2))]


async def main():
    """Run the MCP server over stdio."""
    async with stdio_server() as (read, write):
        await server.run(read, write)

if __name__ == "__main__":
    asyncio.run(main())
============================================================================= +# OPTIONAL DEPENDENCIES - user opts into what they need +# ============================================================================= [project.optional-dependencies] + +# LLM provider SDKs (alternative to raw httpx) +anthropic = ["anthropic>=0.39"] +openai = ["openai>=1.0"] + +# Tool backends +redis = ["redis>=5.0"] # For distributed keyvalue +search = ["duckduckgo-search"] # For search tool + +# Auth (only for multi-tenant/remote deployments) +auth = [ + "pyotp", # TOTP for privileged channel + "argon2-cffi", # Password hashing +] + +# WebSocket server (for remote connections) +server = ["websockets"] + +# All optional features +all = [ + "xml-pipeline[anthropic,openai,redis,search,auth,server]", +] + +# Development test = [ "pytest>=7.0", "pytest-asyncio>=0.21", ] dev = [ - "pytest>=7.0", - "pytest-asyncio>=0.21", + "xml-pipeline[test,all]", "mypy", "ruff", ] +# ============================================================================= +# CLI ENTRY POINTS +# ============================================================================= +[project.scripts] +xml-pipeline = "agentserver.cli:main" +xp = "agentserver.cli:main" + +# ============================================================================= +# TOOL CONFIGURATION +# ============================================================================= [tool.pytest.ini_options] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" testpaths = ["tests"] python_files = ["test_*.py"] -# Don't collect root __init__.py (has imports that break isolation) norecursedirs = [".git", "__pycache__", "*.egg-info"] [tool.setuptools.packages.find] where = ["."] include = ["agentserver*", "third_party*"] + +[tool.ruff] +line-length = 100 +target-version = "py311" + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP"] +ignore = ["E501"] + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_ignores = 
true +disallow_untyped_defs = true