Add authentication system and HTTP/WebSocket server
- auth/users.py: User store with Argon2id password hashing - auth/sessions.py: Token-based session management with expiry - server/app.py: aiohttp server with auth middleware and WebSocket - console/client.py: SSH-style login console client Server endpoints: /auth/login, /auth/logout, /auth/me, /health, /ws Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
986db2e79b
commit
ebf72c1f8c
7 changed files with 888 additions and 4 deletions
|
|
@ -0,0 +1,19 @@
|
|||
"""
|
||||
Authentication and authorization for xml-pipeline.
|
||||
|
||||
Provides:
|
||||
- UserStore: User management with Argon2id password hashing
|
||||
- SessionManager: Token-based session management
|
||||
"""
|
||||
|
||||
from .users import User, UserStore, get_user_store
|
||||
from .sessions import Session, SessionManager, get_session_manager
|
||||
|
||||
__all__ = [
|
||||
"User",
|
||||
"UserStore",
|
||||
"get_user_store",
|
||||
"Session",
|
||||
"SessionManager",
|
||||
"get_session_manager",
|
||||
]
|
||||
197
agentserver/auth/sessions.py
Normal file
197
agentserver/auth/sessions.py
Normal file
|
|
@ -0,0 +1,197 @@
|
|||
"""
|
||||
Session management with token-based authentication.
|
||||
|
||||
Tokens are random hex strings stored in memory with expiry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Default session lifetime
|
||||
DEFAULT_SESSION_LIFETIME = timedelta(hours=8)
|
||||
|
||||
# Token length in bytes (32 bytes = 64 hex chars)
|
||||
TOKEN_BYTES = 32
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
"""An authenticated session."""
|
||||
token: str
|
||||
username: str
|
||||
role: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
last_activity: datetime
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""Check if session has expired."""
|
||||
return datetime.now(timezone.utc) > self.expires_at
|
||||
|
||||
def touch(self) -> None:
|
||||
"""Update last activity time."""
|
||||
self.last_activity = datetime.now(timezone.utc)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Convert to dict for API responses."""
|
||||
return {
|
||||
"token": self.token,
|
||||
"username": self.username,
|
||||
"role": self.role,
|
||||
"expires_at": self.expires_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""
|
||||
Manages authenticated sessions.
|
||||
|
||||
Thread-safe for concurrent access.
|
||||
|
||||
Usage:
|
||||
manager = SessionManager()
|
||||
|
||||
# Create session after successful login
|
||||
session = manager.create("admin", "admin")
|
||||
|
||||
# Validate token on subsequent requests
|
||||
session = manager.validate(token)
|
||||
if session:
|
||||
print(f"Welcome back {session.username}")
|
||||
|
||||
# Logout
|
||||
manager.revoke(token)
|
||||
"""
|
||||
|
||||
def __init__(self, lifetime: timedelta = DEFAULT_SESSION_LIFETIME):
|
||||
self.lifetime = lifetime
|
||||
self._sessions: dict[str, Session] = {}
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def create(
|
||||
self,
|
||||
username: str,
|
||||
role: str,
|
||||
lifetime: Optional[timedelta] = None,
|
||||
) -> Session:
|
||||
"""
|
||||
Create a new session.
|
||||
|
||||
Args:
|
||||
username: Authenticated username
|
||||
role: User's role
|
||||
lifetime: Optional custom lifetime
|
||||
|
||||
Returns:
|
||||
New Session with token
|
||||
"""
|
||||
token = secrets.token_hex(TOKEN_BYTES)
|
||||
now = datetime.now(timezone.utc)
|
||||
expires = now + (lifetime or self.lifetime)
|
||||
|
||||
session = Session(
|
||||
token=token,
|
||||
username=username,
|
||||
role=role,
|
||||
created_at=now,
|
||||
expires_at=expires,
|
||||
last_activity=now,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._sessions[token] = session
|
||||
self._cleanup_expired()
|
||||
|
||||
return session
|
||||
|
||||
def validate(self, token: str) -> Optional[Session]:
|
||||
"""
|
||||
Validate a session token.
|
||||
|
||||
Args:
|
||||
token: Session token from client
|
||||
|
||||
Returns:
|
||||
Session if valid, None if invalid/expired
|
||||
"""
|
||||
with self._lock:
|
||||
session = self._sessions.get(token)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
if session.is_expired():
|
||||
del self._sessions[token]
|
||||
return None
|
||||
|
||||
session.touch()
|
||||
return session
|
||||
|
||||
def revoke(self, token: str) -> bool:
|
||||
"""
|
||||
Revoke a session (logout).
|
||||
|
||||
Returns:
|
||||
True if session was revoked, False if not found
|
||||
"""
|
||||
with self._lock:
|
||||
if token in self._sessions:
|
||||
del self._sessions[token]
|
||||
return True
|
||||
return False
|
||||
|
||||
def revoke_user(self, username: str) -> int:
|
||||
"""
|
||||
Revoke all sessions for a user.
|
||||
|
||||
Returns:
|
||||
Number of sessions revoked
|
||||
"""
|
||||
with self._lock:
|
||||
to_revoke = [
|
||||
token for token, session in self._sessions.items()
|
||||
if session.username == username
|
||||
]
|
||||
for token in to_revoke:
|
||||
del self._sessions[token]
|
||||
return len(to_revoke)
|
||||
|
||||
def get_user_sessions(self, username: str) -> list[Session]:
|
||||
"""Get all active sessions for a user."""
|
||||
with self._lock:
|
||||
return [
|
||||
s for s in self._sessions.values()
|
||||
if s.username == username and not s.is_expired()
|
||||
]
|
||||
|
||||
def _cleanup_expired(self) -> None:
|
||||
"""Remove expired sessions. Must hold lock."""
|
||||
expired = [
|
||||
token for token, session in self._sessions.items()
|
||||
if session.is_expired()
|
||||
]
|
||||
for token in expired:
|
||||
del self._sessions[token]
|
||||
|
||||
def active_count(self) -> int:
|
||||
"""Count active sessions."""
|
||||
with self._lock:
|
||||
self._cleanup_expired()
|
||||
return len(self._sessions)
|
||||
|
||||
|
||||
# Global instance
|
||||
_manager: Optional[SessionManager] = None
|
||||
|
||||
|
||||
def get_session_manager() -> SessionManager:
|
||||
"""Get the global session manager."""
|
||||
global _manager
|
||||
if _manager is None:
|
||||
_manager = SessionManager()
|
||||
return _manager
|
||||
227
agentserver/auth/users.py
Normal file
227
agentserver/auth/users.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
"""
|
||||
User store with Argon2id password hashing.
|
||||
|
||||
Users are stored in ~/.xml-pipeline/users.yaml with hashed passwords.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import yaml
|
||||
from argon2 import PasswordHasher
|
||||
from argon2.exceptions import VerifyMismatchError
|
||||
|
||||
|
||||
CONFIG_DIR = Path.home() / ".xml-pipeline"
|
||||
USERS_FILE = CONFIG_DIR / "users.yaml"
|
||||
|
||||
|
||||
@dataclass
|
||||
class User:
|
||||
"""A user account."""
|
||||
username: str
|
||||
password_hash: str
|
||||
role: str = "operator" # admin, operator, viewer
|
||||
created_at: str = ""
|
||||
last_login: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"username": self.username,
|
||||
"password_hash": self.password_hash,
|
||||
"role": self.role,
|
||||
"created_at": self.created_at,
|
||||
"last_login": self.last_login,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> User:
|
||||
return cls(
|
||||
username=data["username"],
|
||||
password_hash=data["password_hash"],
|
||||
role=data.get("role", "operator"),
|
||||
created_at=data.get("created_at", ""),
|
||||
last_login=data.get("last_login"),
|
||||
)
|
||||
|
||||
|
||||
class UserStore:
|
||||
"""
|
||||
Manages user accounts with secure password storage.
|
||||
|
||||
Usage:
|
||||
store = UserStore()
|
||||
store.create_user("admin", "secretpass", role="admin")
|
||||
|
||||
user = store.authenticate("admin", "secretpass")
|
||||
if user:
|
||||
print(f"Welcome {user.username}!")
|
||||
"""
|
||||
|
||||
def __init__(self, users_file: Path = USERS_FILE):
|
||||
self.users_file = users_file
|
||||
self.hasher = PasswordHasher()
|
||||
self._users: dict[str, User] = {}
|
||||
self._load()
|
||||
|
||||
def _ensure_dir(self) -> None:
|
||||
"""Create config directory if needed."""
|
||||
self.users_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load(self) -> None:
|
||||
"""Load users from file."""
|
||||
if not self.users_file.exists():
|
||||
return
|
||||
try:
|
||||
with open(self.users_file) as f:
|
||||
data = yaml.safe_load(f) or {}
|
||||
for username, user_data in data.get("users", {}).items():
|
||||
user_data["username"] = username
|
||||
self._users[username] = User.from_dict(user_data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save users to file."""
|
||||
self._ensure_dir()
|
||||
|
||||
data = {
|
||||
"users": {
|
||||
username: {
|
||||
"password_hash": user.password_hash,
|
||||
"role": user.role,
|
||||
"created_at": user.created_at,
|
||||
"last_login": user.last_login,
|
||||
}
|
||||
for username, user in self._users.items()
|
||||
}
|
||||
}
|
||||
|
||||
with open(self.users_file, "w") as f:
|
||||
yaml.dump(data, f, default_flow_style=False)
|
||||
|
||||
# Set file permissions to 600
|
||||
if sys.platform != "win32":
|
||||
os.chmod(self.users_file, stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
def has_users(self) -> bool:
|
||||
"""Check if any users exist."""
|
||||
return len(self._users) > 0
|
||||
|
||||
def get_user(self, username: str) -> Optional[User]:
|
||||
"""Get user by username."""
|
||||
return self._users.get(username)
|
||||
|
||||
def list_users(self) -> list[str]:
|
||||
"""List all usernames."""
|
||||
return list(self._users.keys())
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
username: str,
|
||||
password: str,
|
||||
role: str = "operator",
|
||||
) -> User:
|
||||
"""
|
||||
Create a new user.
|
||||
|
||||
Args:
|
||||
username: Unique username
|
||||
password: Plain text password (will be hashed)
|
||||
role: User role (admin, operator, viewer)
|
||||
|
||||
Returns:
|
||||
The created User
|
||||
|
||||
Raises:
|
||||
ValueError: If username already exists
|
||||
"""
|
||||
if username in self._users:
|
||||
raise ValueError(f"User already exists: {username}")
|
||||
|
||||
if len(password) < 4:
|
||||
raise ValueError("Password must be at least 4 characters")
|
||||
|
||||
user = User(
|
||||
username=username,
|
||||
password_hash=self.hasher.hash(password),
|
||||
role=role,
|
||||
created_at=datetime.now(timezone.utc).isoformat(),
|
||||
)
|
||||
|
||||
self._users[username] = user
|
||||
self._save()
|
||||
return user
|
||||
|
||||
def authenticate(self, username: str, password: str) -> Optional[User]:
|
||||
"""
|
||||
Authenticate user with password.
|
||||
|
||||
Returns:
|
||||
User if authentication successful, None otherwise
|
||||
"""
|
||||
user = self._users.get(username)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
self.hasher.verify(user.password_hash, password)
|
||||
|
||||
# Update last login
|
||||
user.last_login = datetime.now(timezone.utc).isoformat()
|
||||
self._save()
|
||||
|
||||
return user
|
||||
except VerifyMismatchError:
|
||||
return None
|
||||
|
||||
def change_password(self, username: str, new_password: str) -> bool:
|
||||
"""Change user's password."""
|
||||
user = self._users.get(username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
if len(new_password) < 4:
|
||||
raise ValueError("Password must be at least 4 characters")
|
||||
|
||||
user.password_hash = self.hasher.hash(new_password)
|
||||
self._save()
|
||||
return True
|
||||
|
||||
def delete_user(self, username: str) -> bool:
|
||||
"""Delete a user."""
|
||||
if username not in self._users:
|
||||
return False
|
||||
|
||||
del self._users[username]
|
||||
self._save()
|
||||
return True
|
||||
|
||||
def set_role(self, username: str, role: str) -> bool:
|
||||
"""Change user's role."""
|
||||
user = self._users.get(username)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
user.role = role
|
||||
self._save()
|
||||
return True
|
||||
|
||||
|
||||
# Global instance
|
||||
_store: Optional[UserStore] = None
|
||||
|
||||
|
||||
def get_user_store() -> UserStore:
|
||||
"""Get the global user store."""
|
||||
global _store
|
||||
if _store is None:
|
||||
_store = UserStore()
|
||||
return _store
|
||||
|
|
@ -1,10 +1,12 @@
|
|||
"""
|
||||
console — Secure console interface for organism operators.
|
||||
console — Console interfaces for xml-pipeline.
|
||||
|
||||
Provides password-protected access to privileged operations
|
||||
via local keyboard input only (no network exposure).
|
||||
Provides:
|
||||
- SecureConsole: Local keyboard-only console (no network)
|
||||
- ConsoleClient: Network client connecting to server with auth
|
||||
"""
|
||||
|
||||
from agentserver.console.secure_console import SecureConsole, PasswordManager
|
||||
from agentserver.console.client import ConsoleClient
|
||||
|
||||
__all__ = ["SecureConsole", "PasswordManager"]
|
||||
__all__ = ["SecureConsole", "PasswordManager", "ConsoleClient"]
|
||||
|
|
|
|||
266
agentserver/console/client.py
Normal file
266
agentserver/console/client.py
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
"""
|
||||
Console client that connects to the agent server.
|
||||
|
||||
Provides SSH-style login with username/password authentication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import getpass
|
||||
import json
|
||||
import sys
|
||||
from typing import Optional
|
||||
|
||||
try:
|
||||
import aiohttp
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.history import InMemoryHistory
|
||||
from prompt_toolkit.styles import Style
|
||||
PROMPT_TOOLKIT_AVAILABLE = True
|
||||
except ImportError:
|
||||
PROMPT_TOOLKIT_AVAILABLE = False
|
||||
|
||||
|
||||
DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 8765
|
||||
MAX_LOGIN_ATTEMPTS = 3
|
||||
|
||||
|
||||
class ConsoleClient:
|
||||
"""
|
||||
Text-based console client for the agent server.
|
||||
|
||||
Usage:
|
||||
client = ConsoleClient()
|
||||
asyncio.run(client.run())
|
||||
"""
|
||||
|
||||
def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.base_url = f"http://{host}:{port}"
|
||||
self.ws_url = f"ws://{host}:{port}/ws"
|
||||
self.token: Optional[str] = None
|
||||
self.username: Optional[str] = None
|
||||
self.session: Optional[aiohttp.ClientSession] = None
|
||||
self.ws: Optional[aiohttp.ClientWebSocketResponse] = None
|
||||
self.running = False
|
||||
|
||||
async def login(self) -> bool:
|
||||
"""
|
||||
Perform SSH-style login.
|
||||
|
||||
Returns:
|
||||
True if login successful, False otherwise
|
||||
"""
|
||||
print(f"Connecting to {self.host}:{self.port}...")
|
||||
|
||||
for attempt in range(1, MAX_LOGIN_ATTEMPTS + 1):
|
||||
try:
|
||||
username = input("Username: ")
|
||||
password = getpass.getpass("Password: ")
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
print("\nLogin cancelled.")
|
||||
return False
|
||||
|
||||
if not username or not password:
|
||||
print("Username and password required.")
|
||||
continue
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/auth/login",
|
||||
json={"username": username, "password": password},
|
||||
) as resp:
|
||||
data = await resp.json()
|
||||
|
||||
if resp.status == 200:
|
||||
self.token = data["token"]
|
||||
self.username = username
|
||||
print(f"Welcome, {username}!")
|
||||
return True
|
||||
else:
|
||||
error = data.get("error", "Authentication failed")
|
||||
remaining = MAX_LOGIN_ATTEMPTS - attempt
|
||||
if remaining > 0:
|
||||
print(f"{error}. {remaining} attempt(s) remaining.")
|
||||
else:
|
||||
print(f"{error}. No attempts remaining.")
|
||||
except aiohttp.ClientError as e:
|
||||
print(f"Connection error: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
async def connect_ws(self) -> bool:
|
||||
"""Connect to WebSocket after authentication."""
|
||||
if not self.token:
|
||||
return False
|
||||
|
||||
try:
|
||||
self.session = aiohttp.ClientSession(
|
||||
headers={"Authorization": f"Bearer {self.token}"}
|
||||
)
|
||||
self.ws = await self.session.ws_connect(self.ws_url)
|
||||
|
||||
# Wait for connected message
|
||||
msg = await self.ws.receive_json()
|
||||
if msg.get("type") == "connected":
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"WebSocket connection failed: {e}")
|
||||
return False
|
||||
|
||||
async def send_command(self, cmd: str) -> Optional[dict]:
|
||||
"""Send a command via WebSocket and get response."""
|
||||
if not self.ws:
|
||||
return None
|
||||
|
||||
await self.ws.send_json(cmd)
|
||||
return await self.ws.receive_json()
|
||||
|
||||
def print_help(self):
|
||||
"""Print available commands."""
|
||||
print("""
|
||||
Available commands:
|
||||
/help - Show this help
|
||||
/status - Show server status
|
||||
/listeners - List active listeners
|
||||
/quit - Disconnect and exit
|
||||
|
||||
Send messages:
|
||||
@listener message - Send message to a listener
|
||||
message - Send to default listener
|
||||
""")
|
||||
|
||||
async def handle_command(self, line: str) -> bool:
|
||||
"""
|
||||
Handle a command line.
|
||||
|
||||
Returns:
|
||||
False if should quit, True otherwise
|
||||
"""
|
||||
line = line.strip()
|
||||
if not line:
|
||||
return True
|
||||
|
||||
if line == "/help":
|
||||
self.print_help()
|
||||
elif line == "/quit" or line == "/exit":
|
||||
return False
|
||||
elif line == "/status":
|
||||
resp = await self.send_command({"type": "status"})
|
||||
if resp:
|
||||
threads = resp.get("threads", 0)
|
||||
print(f"Active threads: {threads}")
|
||||
elif line == "/listeners":
|
||||
resp = await self.send_command({"type": "listeners"})
|
||||
if resp:
|
||||
listeners = resp.get("listeners", [])
|
||||
if listeners:
|
||||
print("Active listeners:")
|
||||
for name in listeners:
|
||||
print(f" - {name}")
|
||||
else:
|
||||
print("No active listeners")
|
||||
elif line.startswith("/"):
|
||||
print(f"Unknown command: {line}")
|
||||
else:
|
||||
# Send as message
|
||||
# TODO: Implement message sending when pump is connected
|
||||
print(f"Message sending not yet implemented: {line}")
|
||||
|
||||
return True
|
||||
|
||||
async def run(self):
|
||||
"""Main client loop."""
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
print("Error: aiohttp not installed")
|
||||
sys.exit(1)
|
||||
|
||||
# Login
|
||||
if not await self.login():
|
||||
print("Authentication failed.")
|
||||
sys.exit(1)
|
||||
|
||||
# Connect WebSocket
|
||||
if not await self.connect_ws():
|
||||
print("Failed to connect to server.")
|
||||
sys.exit(1)
|
||||
|
||||
print("Connected. Type /help for commands, /quit to exit.")
|
||||
|
||||
self.running = True
|
||||
|
||||
try:
|
||||
if PROMPT_TOOLKIT_AVAILABLE:
|
||||
await self._run_prompt_toolkit()
|
||||
else:
|
||||
await self._run_simple()
|
||||
finally:
|
||||
await self.cleanup()
|
||||
|
||||
async def _run_prompt_toolkit(self):
|
||||
"""Run with prompt_toolkit for better UX."""
|
||||
style = Style.from_dict({
|
||||
"prompt": "ansicyan bold",
|
||||
})
|
||||
|
||||
session = PromptSession(
|
||||
history=InMemoryHistory(),
|
||||
style=style,
|
||||
)
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
line = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: session.prompt(f"{self.username}> ")
|
||||
)
|
||||
if not await self.handle_command(line):
|
||||
break
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
break
|
||||
|
||||
async def _run_simple(self):
|
||||
"""Run with simple input (fallback)."""
|
||||
while self.running:
|
||||
try:
|
||||
line = input(f"{self.username}> ")
|
||||
if not await self.handle_command(line):
|
||||
break
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
break
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up connections."""
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
print("Disconnected.")
|
||||
|
||||
|
||||
def main():
|
||||
"""Entry point."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="XML Pipeline Console")
|
||||
parser.add_argument("--host", default=DEFAULT_HOST, help="Server host")
|
||||
parser.add_argument("--port", type=int, default=DEFAULT_PORT, help="Server port")
|
||||
args = parser.parse_args()
|
||||
|
||||
client = ConsoleClient(host=args.host, port=args.port)
|
||||
asyncio.run(client.run())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
11
agentserver/server/__init__.py
Normal file
11
agentserver/server/__init__.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""
|
||||
HTTP/WebSocket server for xml-pipeline.
|
||||
|
||||
Provides:
|
||||
- REST API for auth and management
|
||||
- WebSocket for console and GUI clients
|
||||
"""
|
||||
|
||||
from .app import create_app, run_server
|
||||
|
||||
__all__ = ["create_app", "run_server"]
|
||||
162
agentserver/server/app.py
Normal file
162
agentserver/server/app.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
"""
|
||||
aiohttp-based HTTP/WebSocket server.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional, Callable
|
||||
|
||||
try:
|
||||
from aiohttp import web, WSMsgType
|
||||
AIOHTTP_AVAILABLE = True
|
||||
except ImportError:
|
||||
AIOHTTP_AVAILABLE = False
|
||||
web = None
|
||||
WSMsgType = None
|
||||
|
||||
from ..auth.users import get_user_store, UserStore
|
||||
from ..auth.sessions import get_session_manager, SessionManager, Session
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..message_bus.stream_pump import StreamPump
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def auth_middleware():
|
||||
@web.middleware
|
||||
async def middleware(request, handler):
|
||||
if request.path in ("/auth/login", "/health"):
|
||||
return await handler(request)
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return web.json_response({"error": "Missing Authorization"}, status=401)
|
||||
|
||||
token = auth_header[7:]
|
||||
session = request.app["session_manager"].validate(token)
|
||||
|
||||
if not session:
|
||||
return web.json_response({"error": "Invalid token"}, status=401)
|
||||
|
||||
request["session"] = session
|
||||
return await handler(request)
|
||||
|
||||
return middleware
|
||||
|
||||
|
||||
async def handle_login(request):
|
||||
try:
|
||||
data = await request.json()
|
||||
except:
|
||||
return web.json_response({"error": "Invalid JSON"}, status=400)
|
||||
|
||||
username = data.get("username", "")
|
||||
password = data.get("password", "")
|
||||
|
||||
if not username or not password:
|
||||
return web.json_response({"error": "Credentials required"}, status=400)
|
||||
|
||||
user = request.app["user_store"].authenticate(username, password)
|
||||
if not user:
|
||||
return web.json_response({"error": "Invalid credentials"}, status=401)
|
||||
|
||||
session = request.app["session_manager"].create(user.username, user.role)
|
||||
return web.json_response(session.to_dict())
|
||||
|
||||
|
||||
async def handle_logout(request):
|
||||
session = request["session"]
|
||||
request.app["session_manager"].revoke(session.token)
|
||||
return web.json_response({"message": "Logged out"})
|
||||
|
||||
|
||||
async def handle_me(request):
|
||||
session = request["session"]
|
||||
return web.json_response({
|
||||
"username": session.username,
|
||||
"role": session.role,
|
||||
"expires_at": session.expires_at.isoformat(),
|
||||
})
|
||||
|
||||
|
||||
async def handle_health(request):
|
||||
return web.json_response({"status": "ok"})
|
||||
|
||||
|
||||
async def handle_websocket(request):
|
||||
session = request["session"]
|
||||
pump = request.app.get("pump")
|
||||
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
|
||||
await ws.send_json({"type": "connected", "username": session.username})
|
||||
|
||||
async for msg in ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
try:
|
||||
data = json.loads(msg.data)
|
||||
resp = await handle_ws_msg(data, session, pump)
|
||||
await ws.send_json(resp)
|
||||
except Exception as e:
|
||||
await ws.send_json({"type": "error", "error": str(e)})
|
||||
|
||||
return ws
|
||||
|
||||
|
||||
async def handle_ws_msg(data, session, pump):
|
||||
t = data.get("type", "")
|
||||
|
||||
if t == "ping":
|
||||
return {"type": "pong"}
|
||||
elif t == "status":
|
||||
from ..memory import get_context_buffer
|
||||
stats = get_context_buffer().get_stats()
|
||||
return {"type": "status", "threads": stats["thread_count"]}
|
||||
elif t == "listeners":
|
||||
if not pump:
|
||||
return {"type": "listeners", "listeners": []}
|
||||
return {"type": "listeners", "listeners": list(pump.listeners.keys())}
|
||||
|
||||
return {"type": "error", "error": f"Unknown: {t}"}
|
||||
|
||||
|
||||
def create_app(pump=None):
|
||||
if not AIOHTTP_AVAILABLE:
|
||||
raise RuntimeError("aiohttp not installed")
|
||||
|
||||
app = web.Application(middlewares=[auth_middleware()])
|
||||
app["user_store"] = get_user_store()
|
||||
app["session_manager"] = get_session_manager()
|
||||
app["pump"] = pump
|
||||
|
||||
app.router.add_post("/auth/login", handle_login)
|
||||
app.router.add_post("/auth/logout", handle_logout)
|
||||
app.router.add_get("/auth/me", handle_me)
|
||||
app.router.add_get("/health", handle_health)
|
||||
app.router.add_get("/ws", handle_websocket)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(pump=None, host="127.0.0.1", port=8765):
|
||||
app = create_app(pump)
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
|
||||
site = web.TCPSite(runner, host, port)
|
||||
await site.start()
|
||||
|
||||
print(f"Server on http://{host}:{port}")
|
||||
|
||||
try:
|
||||
while True:
|
||||
await asyncio.sleep(3600)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
Loading…
Reference in a new issue