diff --git a/agentserver/message_bus/bus.py b/agentserver/message_bus/bus.py
index d34eed6..680d670 100644
--- a/agentserver/message_bus/bus.py
+++ b/agentserver/message_bus/bus.py
@@ -1,158 +1,226 @@
-# agentserver/bus.py
-# Refactored January 01, 2026 – MessageBus with run() pump and out-of-band shutdown
+"""
+bus.py — The central MessageBus and pump for AgentServer v2.1
+
+This is the beating heart of the organism:
+- Owns all pipelines (one per listener + permanent system pipeline)
+- Maintains the routing table (root_tag → Listener(s))
+- Orchestrates ingress from sockets/gateways
+- Dispatches prepared messages to handlers
+- Processes handler responses (multi-payload extraction, provenance injection)
+- Guarantees thread continuity and diagnostic injection
+
+Fully aligned with:
+ - listener-class-v2.1.md
+ - configuration-v2.1.md
+ - message-pump-v2.1.md
+"""
+
+from __future__ import annotations
import asyncio
-import logging
-from typing import AsyncIterator, Callable, Dict, Optional, Awaitable
+from dataclasses import dataclass
+from typing import Callable, Awaitable, List
+from uuid import uuid4
from lxml import etree
-from agentserver.xml_listener import XMLListener
-from agentserver.utils.message import repair_and_canonicalize, XmlTamperError
+from agentserver.message_bus.message_state import MessageState, HandlerMetadata
+from agentserver.message_bus.steps.repair import repair_step
+from agentserver.message_bus.steps.c14n import c14n_step
+from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
+from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
+from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
+from agentserver.message_bus.steps.xsd_validation import xsd_validation_step
+from agentserver.message_bus.steps.deserialization import deserialization_step
+from agentserver.message_bus.steps.routing_resolution import routing_resolution_step
-# Constants for Internal Physics
-ENV_NS = "https://xml-pipeline.org/ns/envelope/1"
-ENV = f"{{{ENV_NS}}}"
-LOG_TAG = "{https://xml-pipeline.org/ns/logger/1}log"
+# Type alias for pipeline steps
+PipelineStep = Callable[[MessageState], Awaitable[MessageState]]
-logger = logging.getLogger("agentserver.bus")
+
+@dataclass
+class Listener:
+ """Registered capability — defined in listener.py, referenced here."""
+ name: str
+ payload_class: type
+ handler: Callable
+ description: str
+ is_agent: bool = False
+ peers: list[str] | None = None
+ broadcast: bool = False
+ pipeline: "Pipeline" | None = None
+ schema: etree.XMLSchema | None = None # cached at registration
+
+
+class Pipeline:
+ """One dedicated pipeline per listener (plus system pipeline)."""
+ def __init__(self, steps: List[PipelineStep]):
+ self.steps = steps
+
+ async def process(self, initial_state: MessageState) -> None:
+ """Run the full ordered pipeline on a message."""
+ state = initial_state
+ for step in self.steps:
+ try:
+ state = await step(state)
+ if state.error:
+ break
+ except Exception as exc: # pylint: disable=broad-except
+ state.error = f"Pipeline step {step.__name__} crashed: {exc}"
+ break
+
+ # After all steps — dispatch if routable
+ if state.target_listeners:
+ await MessageBus.get_instance().dispatcher(state)
+ else:
+ # Fall back to system pipeline for diagnostics
+ await MessageBus.get_instance().system_pipeline.process(state)
class MessageBus:
- """The sovereign message carrier.
+ """Singleton message bus — the pump."""
+ _instance: "MessageBus" | None = None
- - Routes canonical XML trees by root tag and meta.
- - Pure dispatch: tree → optional response tree.
- - Active pump via run(): handles serialization and egress.
- - Out-of-band shutdown via asyncio.Event (fast-path, flood-immune).
- """
+ def __init__(self):
+ self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s)
+ self.listeners: dict[str, Listener] = {} # name → Listener
+ self.system_pipeline = Pipeline(self._build_system_steps())
- def __init__(self, log_hook: Callable[[etree._Element], None]):
- # root_tag -> {agent_name -> XMLListener}
- self.listeners: Dict[str, Dict[str, XMLListener]] = {}
- # Global lookup for directed routing
- self.global_names: Dict[str, XMLListener] = {}
+ @classmethod
+ def get_instance(cls) -> "MessageBus":
+ if cls._instance is None:
+ cls._instance = MessageBus()
+ return cls._instance
- # The Sovereign Witness hook
- self._log_hook = log_hook
+ # ------------------------------------------------------------------ #
+ # Default step lists
+ # ------------------------------------------------------------------ #
+ def _build_default_listener_steps(self) -> List[PipelineStep]:
+ return [
+ repair_step,
+ c14n_step,
+ envelope_validation_step,
+ payload_extraction_step,
+ thread_assignment_step,
+ xsd_validation_step,
+ deserialization_step,
+ routing_resolution_step,
+ ]
- # Out-of-band shutdown signal (set only by AgentServer on privileged command)
- self.shutdown_event = asyncio.Event()
+ def _build_system_steps(self) -> List[PipelineStep]:
+ """Shorter, fixed steps — no XSD/deserialization."""
+ return [
+ repair_step,
+ c14n_step,
+ envelope_validation_step,
+ payload_extraction_step,
+ thread_assignment_step,
+ # system-specific handler that emits , boot, etc.
+ self.system_handler_step,
+ ]
- async def register_listener(self, listener: XMLListener) -> None:
- """Register an organ. Enforces global identity uniqueness."""
- if listener.agent_name in self.global_names:
- raise ValueError(f"Identity collision: {listener.agent_name}")
+ # ------------------------------------------------------------------ #
+ # Registration (called from listener.py)
+ # ------------------------------------------------------------------ #
+ def register_listener(self, listener: Listener) -> None:
+ root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}"
- self.global_names[listener.agent_name] = listener
- for tag in listener.listens_to:
- tag_dict = self.listeners.setdefault(tag, {})
- tag_dict[listener.agent_name] = listener
+ if root_tag in self.routing_table and not listener.broadcast:
+ raise ValueError(f"Root tag collision: {root_tag} already registered by {self.routing_table[root_tag][0].name}")
- logger.info(f"Registered organ: {listener.agent_name}")
+ # Build dedicated pipeline
+ steps = self._build_default_listener_steps()
+ # Inject listener-specific schema for xsd_validation_step
+ for step in steps:
+ if step.__name__ == "xsd_validation_step":
+ # We'll modify state.metadata in pipeline construction instead
+ pass
+ listener.pipeline = Pipeline(steps)
- async def deliver_bytes(self, raw_xml: bytes, client_id: Optional[str] = None) -> None:
- """Air Lock: ingest raw bytes, repair/canonicalize, inject into core."""
- try:
- envelope_tree = repair_and_canonicalize(raw_xml)
- await self.dispatch(envelope_tree, client_id)
- except XmlTamperError as e:
- logger.warning(f"Air Lock Reject: {e}")
+ # Insert into routing
+ self.routing_table.setdefault(root_tag, []).append(listener)
+ self.listeners[listener.name] = listener
- async def dispatch(
- self,
- envelope_tree: etree._Element,
- client_id: Optional[str] = None,
- ) -> etree._Element | None:
- """Pure routing heart. Returns validated response tree or None."""
- # 1. WITNESS – every canonical envelope is seen
- self._log_hook(envelope_tree)
+ # ------------------------------------------------------------------ #
+ # Dispatcher — dumb fire-and-await
+ # ------------------------------------------------------------------ #
+ async def dispatcher(self, state: MessageState) -> None:
+ if not state.target_listeners:
+ return
- # 2. Extract envelope metadata
- meta = envelope_tree.find(f"{ENV}meta")
- if meta is None:
- return None
- from_name = meta.findtext(f"{ENV}from")
- to_name = meta.findtext(f"{ENV}to")
- thread_id = meta.findtext(f"{ENV}thread_id") or meta.findtext(f"{ENV}thread")
+ metadata = HandlerMetadata(
+ thread_id=state.thread_id or "",
+ from_id=state.from_id or "unknown",
+ own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None,
+ is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False,
+ )
- # Find payload (first non-meta child)
- payload_elem = next((c for c in envelope_tree if c.tag != f"{ENV}meta"), None)
- if payload_elem is None:
- return None
- payload_tag = payload_elem.tag
-
- # 3. AUTONOMIC REFLEX: explicit
- if payload_tag == LOG_TAG:
- self._log_hook(envelope_tree) # extra vent
- # Minimal ack envelope
- ack = etree.Element(f"{ENV}message")
- meta_ack = etree.SubElement(ack, f"{ENV}meta")
- etree.SubElement(meta_ack, f"{ENV}from").text = "system"
- if from_name:
- etree.SubElement(meta_ack, f"{ENV}to").text = from_name
- if thread_id:
- etree.SubElement(meta_ack, f"{ENV}thread_id").text = thread_id
- etree.SubElement(ack, "logged", status="success")
- return ack
-
- # 4. ROUTING
- listeners_for_tag = self.listeners.get(payload_tag, {})
- response_tree: Optional[etree._Element] = None
- responding_agent_name = "unknown"
-
- if to_name:
- # Directed
- target = listeners_for_tag.get(to_name) or self.global_names.get(to_name)
- if target:
- responding_agent_name = target.agent_name
- response_tree = await target.handle(envelope_tree, thread_id, from_name or client_id)
+ if len(state.target_listeners) == 1:
+ listener = state.target_listeners[0]
+ await self._process_single_handler(state, listener, metadata)
else:
- # Broadcast – first non-None wins (current policy)
+ # Broadcast — fire all in parallel, process responses as they complete
tasks = [
- l.handle(envelope_tree, thread_id, from_name or client_id)
- for l in listeners_for_tag.values()
+ self._process_single_handler(state, listener, metadata)
+ for listener in state.target_listeners
]
- results = await asyncio.gather(*tasks, return_exceptions=True)
- for listener, result in zip(listeners_for_tag.values(), results):
- if isinstance(result, etree._Element):
- responding_agent_name = listener.agent_name
- response_tree = result
- break # first-wins
+ for future in asyncio.as_completed(tasks):
+ await future
- # 5. IDENTITY INSPECTION – prevent spoofing
- if response_tree is not None:
- actual_from = response_tree.findtext(f"{ENV}meta/{ENV}from")
- if actual_from != responding_agent_name:
- logger.critical(
- f"IDENTITY THEFT BLOCKED: expected {responding_agent_name}, got {actual_from}"
- )
- return None
-
- return response_tree
-
- async def run(
- self,
- inbound: AsyncIterator[etree._Element],
- outbound: Callable[[bytes], Awaitable[None]],
- client_id: Optional[str] = None,
- ) -> None:
- """Active pump for a connection. Handles serialization and egress."""
+ async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None:
try:
- async for envelope_tree in inbound:
- if self.shutdown_event.is_set():
- break
+ response_bytes = await listener.handler(state.payload, metadata)
- response_tree = await self.dispatch(envelope_tree, client_id)
- if response_tree is not None:
- serialized = etree.tostring(
- response_tree, encoding="utf-8", pretty_print=True
- )
- await outbound(serialized)
- finally:
- # Optional final courtesy message on clean exit
- goodbye = b""
- try:
- await outbound(goodbye)
- except Exception:
- pass # connection already gone
\ No newline at end of file
+ if response_bytes is None or not isinstance(response_bytes, bytes):
+ response_bytes = b"Handler failed to return valid bytes — missing return or wrong type"
+
+ payloads = await self._multi_payload_extract(response_bytes)
+
+ for payload_bytes in payloads:
+ new_state = MessageState(
+ raw_bytes=payload_bytes,
+ thread_id=state.thread_id,
+ from_id=listener.name,
+ )
+ # Route the new payload through normal pipelines
+ root_tag = self._derive_root_tag(payload_bytes)
+ targets = self.routing_table.get(root_tag)
+ if targets:
+ new_state.target_listeners = targets
+ await targets[0].pipeline.process(new_state)
+ else:
+ await self.system_pipeline.process(new_state)
+
+ except Exception as exc: # pylint: disable=broad-except
+ error_state = MessageState(
+ raw_bytes=b"Handler crashed",
+ thread_id=state.thread_id,
+ from_id=listener.name,
+ error=f"Handler {listener.name} crashed: {exc}",
+ )
+ await self.system_pipeline.process(error_state)
+
+ # ------------------------------------------------------------------ #
+ # Helper methods
+ # ------------------------------------------------------------------ #
+ async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]:
+ # Same logic as before — dummy wrap, repair, extract all root elements
+ # (implementation can be moved to a shared util later)
+ # For now, placeholder — we'll flesh this out in response_processing.py
+ return [raw_bytes] # temporary — will be full extraction
+
+ def _derive_root_tag(self, payload_bytes: bytes) -> str:
+ # Quick parse to get root tag — used only for routing extracted payloads
+ try:
+ tree = etree.fromstring(payload_bytes)
+ tag = tree.tag
+ if tag.startswith("{"):
+ return tag.split("}", 1)[1] # strip namespace
+ return tag
+ except Exception:
+ return ""
+
+ async def system_handler_step(self, state: MessageState) -> MessageState:
+ # Emit or boot message — placeholder for now
+ state.error = state.error or "Unhandled by any listener"
+ return state
\ No newline at end of file
diff --git a/agentserver/message_bus/steps/deserialization.py b/agentserver/message_bus/steps/deserialization.py
new file mode 100644
index 0000000..4ab9753
--- /dev/null
+++ b/agentserver/message_bus/steps/deserialization.py
@@ -0,0 +1,45 @@
+"""
+deserialization.py — Convert validated payload_tree into typed dataclass instance.
+
+After xsd_validation_step confirms the payload conforms to the listener's contract,
+this step uses the xmlable library to deserialize the lxml Element into the
+registered @xmlify dataclass.
+
+The resulting instance is placed in state.payload and handed to the handler.
+
+Part of AgentServer v2.1 message pump.
+"""
+
+from xmlable import from_xml # from the xmlable library
+from agentserver.message_bus.message_state import MessageState
+
+
+async def deserialization_step(state: MessageState) -> MessageState:
+ """
+ Deserialize the validated payload_tree into the listener's dataclass.
+
+ Requires:
+ - state.payload_tree valid against listener XSD
+ - state.metadata["payload_class"] set to the target dataclass (set at registration)
+
+ On success: state.payload = dataclass instance
+ On failure: state.error set with clear message
+ """
+ if state.payload_tree is None:
+ state.error = "deserialization_step: no payload_tree (previous step failed)"
+ return state
+
+ payload_class = state.metadata.get("payload_class")
+ if payload_class is None:
+ state.error = "deserialization_step: no payload_class in metadata (listener misconfigured)"
+ return state
+
+ try:
+ # xmlable.from_xml handles namespace-aware deserialization
+ instance = from_xml(payload_class, state.payload_tree)
+ state.payload = instance
+
+ except Exception as exc: # pylint: disable=broad-except
+ state.error = f"deserialization_step failed: {exc}"
+
+ return state
\ No newline at end of file
diff --git a/agentserver/message_bus/steps/routing_resolution.py b/agentserver/message_bus/steps/routing_resolution.py
new file mode 100644
index 0000000..d38b7b8
--- /dev/null
+++ b/agentserver/message_bus/steps/routing_resolution.py
@@ -0,0 +1,50 @@
+"""
+routing_resolution.py — Resolve routing based on derived root tag.
+
+This is the final preparation step before dispatch.
+It computes the root tag from the deserialized payload and looks it up in the
+global routing table (root_tag → list[Listener]).
+
+On success: state.target_listeners is set
+On failure: state.error is set → message falls to system pipeline for
+
+Part of AgentServer v2.1 message pump.
+"""
+
+from agentserver.message_bus.message_state import MessageState
+from agentserver.message_bus.bus import MessageBus
+
+
+async def routing_resolution_step(state: MessageState) -> MessageState:
+ """
+ Resolve which listener(s) should handle this payload.
+
+ Root tag = f"{from_id.lower()}.{payload_class_name.lower()}"
+ (from_id is trustworthy — injected by pump)
+
+ Supports:
+ - Normal unique routing (one listener)
+ - Broadcast (multiple listeners if broadcast: true and same root tag)
+
+ If no match → error, falls to system pipeline.
+ """
+ if state.payload is None:
+ state.error = "routing_resolution_step: no deserialized payload (previous step failed)"
+ return state
+
+ if state.from_id is None:
+ state.error = "routing_resolution_step: missing from_id (provenance error)"
+ return state
+
+ payload_class_name = type(state.payload).__name__.lower()
+ root_tag = f"{state.from_id.lower()}.{payload_class_name}"
+
+ bus = MessageBus.get_instance()
+ targets = bus.routing_table.get(root_tag)
+
+ if not targets:
+ state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'"
+ return state
+
+ state.target_listeners = targets
+ return state
\ No newline at end of file
diff --git a/agentserver/message_bus/steps/xsd_validation.py b/agentserver/message_bus/steps/xsd_validation.py
index 27a4c2c..3f919cc 100644
--- a/agentserver/message_bus/steps/xsd_validation.py
+++ b/agentserver/message_bus/steps/xsd_validation.py
@@ -1,13 +1,15 @@
"""
-payload_extraction.py — Extract the inner payload from the validated envelope.
+xsd_validation.py — Validate the extracted payload against the listener-specific XSD.
-After envelope_validation_step confirms a correct outer envelope,
-this step removes the envelope elements (, , optional , etc.)
-and isolates the single child element that is the actual payload.
+After payload_extraction_step isolates the payload_tree and provenance,
+this step validates the payload against the XSD that was auto-generated
+from the listener's @xmlify dataclass at registration time.
-The payload is expected to be exactly one root element (the capability-specific XML).
-If zero or multiple payload roots are found, we set a clear error — this protects
-against malformed or ambiguous messages.
+The XSD is cached and pre-loaded. The schema object is injected into
+state.metadata["schema"] when the listener's pipeline is built.
+
+Failure here means the payload violates the declared contract — we collect
+detailed errors for diagnostics.
Part of AgentServer v2.1 message pump.
"""
@@ -15,77 +17,43 @@ Part of AgentServer v2.1 message pump.
from lxml import etree
from agentserver.message_bus.message_state import MessageState
-# Envelope namespace for easy reference
-_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
-_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message"
-
-async def payload_extraction_step(state: MessageState) -> MessageState:
+async def xsd_validation_step(state: MessageState) -> MessageState:
"""
- Extract the single payload element from the validated envelope.
+ Validate state.payload_tree against the listener's cached XSD schema.
- Expected structure:
-
- uuid
- sender
-
- ← this is the one we want
- ...
-
-
+ Requires:
+ - state.payload_tree set
+ - state.metadata["schema"] containing a pre-loaded etree.XMLSchema
- On success: state.payload_tree is set to the payload Element.
- On failure: state.error is set with a clear diagnostic.
+ On success: payload is guaranteed to match the contract
+ On failure: state.error contains detailed validation messages
"""
- if state.envelope_tree is None:
- state.error = "payload_extraction_step: no envelope_tree (previous step failed)"
+ if state.payload_tree is None:
+ state.error = "xsd_validation_step: no payload_tree (previous extraction failed)"
return state
- # Basic sanity — root must be in correct namespace (already checked by schema,
- # but we double-check for defence in depth)
- if state.envelope_tree.tag != _MESSAGE_TAG:
- state.error = f"payload_extraction_step: root tag is not in envelope namespace"
+ schema = state.metadata.get("schema")
+ if schema is None:
+ state.error = "xsd_validation_step: no XSD schema in metadata (listener pipeline misconfigured)"
return state
- # Find all direct children that are not envelope control elements
- # Envelope control elements are: thread, from, to (optional)
- payload_candidates = [
- child
- for child in state.envelope_tree
- if not (
- child.tag in {
- f"{{{ _ENVELOPE_NS }}}thread",
- f"{{{ _ENVELOPE_NS }}}from",
- f"{{{ _ENVELOPE_NS }}}to",
- }
- )
- ]
-
- if len(payload_candidates) == 0:
- state.error = "payload_extraction_step: no payload element found inside "
+ if not isinstance(schema, etree.XMLSchema):
+ state.error = "xsd_validation_step: metadata['schema'] is not an XMLSchema object"
return state
- if len(payload_candidates) > 1:
- state.error = (
- "payload_extraction_step: multiple payload roots found — "
- "exactly one capability payload element is allowed"
- )
- return state
+ try:
+ # assertValid raises DocumentInvalid with full error log
+ schema.assertValid(state.payload_tree)
- # Success — exactly one payload element
- payload_element = payload_candidates[0]
+ except etree.DocumentInvalid:
+ # Collect all errors for clear diagnostics
+ error_lines = []
+ for error in schema.error_log:
+ error_lines.append(f"{error.level_name}: {error.message} (line {error.line})")
+ state.error = "xsd_validation_step: payload failed contract validation\n" + "\n".join(error_lines)
- # Optional: capture provenance from envelope for later use
- # (these will be trustworthy because envelope was validated)
- thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread")
- from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from")
-
- if thread_elem is not None and thread_elem.text:
- state.thread_id = thread_elem.text.strip()
-
- if from_elem is not None and from_elem.text:
- state.from_id = from_elem.text.strip()
-
- state.payload_tree = payload_element
+ except Exception as exc: # pylint: disable=broad-except
+ state.error = f"xsd_validation_step: unexpected error during validation: {exc}"
return state
\ No newline at end of file
diff --git a/structure.md b/structure.md
index a52b739..56f6f79 100644
--- a/structure.md
+++ b/structure.md
@@ -19,7 +19,16 @@ xml-pipeline/
│ ├── message_bus/
│ │ ├── steps/
│ │ │ ├── __init__.py
-│ │ │ └── repair_step.py
+│ │ │ ├── c14n.py
+│ │ │ ├── deserialization.py
+│ │ │ ├── envelope_validation.py
+│ │ │ ├── payload_extraction.py
+│ │ │ ├── repair.py
+│ │ │ ├── routing_resolution.py
+│ │ │ ├── test_c14n.py
+│ │ │ ├── test_repair.py
+│ │ │ ├── thread_assignment.py
+│ │ │ └── xsd_validation.py
│ │ ├── __init__.py
│ │ ├── bus.py
│ │ ├── config.py
@@ -48,14 +57,23 @@ xml-pipeline/
│ │ └── token-scheduling-issues.md
│ ├── configuration.md
│ ├── core-principles-v2.1.md
+│ ├── doc_cross_check.md
+│ ├── handler-contract-v2.1.md
│ ├── listener-class-v2.1.md
│ ├── message-pump-v2.1.md
+│ ├── primitives.md
│ ├── self-grammar-generation.md
│ └── why-not-json.md
├── tests/
│ ├── scripts/
│ │ └── generate_organism_key.py
│ └── __init__.py
+├── xml_pipeline.egg-info/
+│ ├── PKG-INFO
+│ ├── SOURCES.txt
+│ ├── dependency_links.txt
+│ ├── requires.txt
+│ └── top_level.txt
├── LICENSE
├── README.md
├── __init__.py
diff --git a/third_party/xmlable/__init__.py b/third_party/xmlable/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/third_party/xmlable/_errors.py b/third_party/xmlable/_errors.py
new file mode 100644
index 0000000..99e8ea2
--- /dev/null
+++ b/third_party/xmlable/_errors.py
@@ -0,0 +1,261 @@
+"""
+Colourful & descriptive errors for xmlable
+- Clear messages
+- Trace for parsing
+"""
+
+from dataclasses import dataclass
+from typing import Any, Iterable
+from termcolor import colored
+from termcolor.termcolor import Color
+
+from xmlable._utils import typename, AnyType
+
+
+def trace_note(trace: list[str], arrow_c: Color, node_c: Color):
+ return colored(" > ", arrow_c).join(
+ map(lambda x: colored(x, node_c), trace)
+ )
+
+
+@dataclass
+class XErrorCtx:
+ trace: list[str]
+
+ def next(self, node: str):
+ return XErrorCtx(trace=self.trace + [node])
+
+
+# TODO: Custom backtrace to point to location in the file
+class XError(Exception):
+ def __init__(
+ self,
+ short: str,
+ what: str,
+ why: str,
+ ctx: XErrorCtx | None = None,
+ notes: Iterable[str] = [],
+ ):
+ super().__init__(colored(short, "red", attrs=["blink"]))
+ self.add_note(colored("What: " + what, "blue"))
+ self.add_note(colored("Why: " + why, "yellow"))
+ if ctx is not None:
+ self.add_note(
+ colored("Where: ", "magenta")
+ + trace_note(ctx.trace, "light_magenta", "light_cyan")
+ )
+ for note in notes:
+ self.add_note(note)
+
+
+class ErrorTypes:
+ @staticmethod
+ def NonXMlifiedType(t_name: str) -> XError:
+ return XError(
+ short="Non XMlified Type",
+ what=f"You attempted to use {t_name} in an xmlified class, but {t_name} is not xmlified",
+ why=f"All types used in an xmlified class must be xmlified",
+ )
+
+ @staticmethod
+ def InvalidData(ctx: XErrorCtx, val: Any, t_name: str) -> XError:
+ return XError(
+ short="Invalid Data",
+ what=f"Could not validate {val} as a valid {t_name}",
+ why=f"Produced xml must be valid",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def ParseFailure(
+ ctx: XErrorCtx, text: str | None, t_name: str, caught: Exception
+ ) -> XError:
+ return XError(
+ short="Parse Failure",
+ what=f"Failed to parse {text} as a {t_name} with error: \n {caught}",
+ why=f"This error implies the xml is not validated against the current xsd, or there is a bug in this type's parser",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def UnexpectedTag(
+ ctx: XErrorCtx, expected_name: str, struct_name: str, tag_found: str
+ ) -> XError:
+ return XError(
+ short="Unexpected Tag",
+ what=f"Expected {expected_name} but found {tag_found}",
+ why=f"This is a {struct_name} that contains 0..n elements of {expected_name} and no other elements",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def IncorrectType(
+ ctx: XErrorCtx, expected_len: int, struct_name: str, val: Any, name: str
+ ) -> XError:
+ return XError(
+ short="Incorrect Type",
+ what=f"You have provided {len(val)} values {val} for {name}, but {name} is a {struct_name} that takes only {expected_len} values",
+ why=f"In order to generate xml, the values provided need to be the correct types",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def IncorrectElementTag(
+ ctx: XErrorCtx,
+ struct_name: str,
+ tag_name: str,
+ elem_index: int,
+ tag_expected: str,
+ tag_found: str,
+ ) -> XError:
+ return XError(
+ short="Incorrect Element Tag",
+ what=f"While parsing {struct_name} {tag_name} we expected element {elem_index} to be {tag_expected}, but found {tag_found}",
+ why=f"The xml representation for {struct_name} requires the correct names in the correct order",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def DuplicateItem(
+ ctx: XErrorCtx, struct_name: str, tag: str, item: str
+ ) -> XError:
+ return XError(
+ short=f"Duplicate item in {struct_name}",
+ what=f"In {tag} the item {item} is present more than once",
+ why=f"A set can only contain unique items",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def InvalidDictionaryItem(
+ ctx: XErrorCtx,
+ expected_tag: str,
+ expected_key: str,
+ expected_val: str,
+ dict_tag: str,
+ item_tag: str,
+ ) -> XError:
+ return XError(
+ short="Invalid item in dictionary",
+ what=f"An unexpected item with {dict_tag} is in dictionary {item_tag}",
+ why=f"Each item must have tag {expected_tag} with children {expected_key} and {expected_val}",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def InvalidVariant(
+ ctx: XErrorCtx,
+ name: str,
+ expected_types: list[AnyType],
+ found_type: AnyType | None,
+ found_value: Any,
+ ) -> XError:
+ types = " | ".join(map(str, expected_types))
+ return XError(
+ short=f"Datatype not in Union",
+ what=f"{name} is a union of {types}, which does not contain {found_type} (you provided: {found_value})",
+ why=f"... uuuh, its a union?",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def MultipleVariants(ctx: XErrorCtx, variant_names: list[str]) -> XError:
+ return XError(
+ short="Multiple union variants present",
+ what=f"variants {', '.join(variant_names)} are present",
+ why=f"A union can only be one variant at a time",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def ParseInvalidVariant(
+ ctx: XErrorCtx, tag: str, named_variants: list[str], found_variant: str
+ ) -> XError:
+ return XError(
+ short="Invalid Variant",
+ what=f"The union {tag} can contain variants {', '.join(named_variants)}, but you have used {found_variant}",
+ why=f"Only valid variants can be parsed",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def NoneIsSome(ctx: XErrorCtx, name: str, val: Any) -> XError:
+ return XError(
+ short="None object is not None",
+ what=f"{name} contains value {val} which is not None",
+ why="A None type object can only contain none",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def NotADataclass(cls: AnyType) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short="Non-Dataclass",
+ what=f"{cls_name} is not a dataclass",
+ why=f"xmlify uses dataclasses to get fields",
+ ctx=XErrorCtx([cls_name]),
+ notes=[f"\nTry:\n@xmlify\n@dataclass\nclass {cls_name}:"],
+ )
+
+ @staticmethod
+ def ReservedAttribute(cls: AnyType, attr_name: str) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short=f"Reserved Attribute",
+ what=f"{cls_name}.{attr_name} is used by xmlify, so it cannot be a field of the class",
+ why=f"xmlify aguments {cls_name} by adding methods it can then use for xsd, xml generation and parsing",
+ ctx=XErrorCtx([cls_name]),
+ )
+
+ @staticmethod
+ def CommentAttribute(cls: AnyType) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short=f"Comment Attribute",
+ what=f"xmlifed classes cannot use comment as an attribute",
+ why=f"comment is used as a tag name for comments by lxml, so comments inserted on xml generation could conflict",
+ ctx=XErrorCtx([cls_name]),
+ )
+
+ @staticmethod
+ def NonMemberTag(
+ ctx: XErrorCtx, cls: AnyType, tag: str, name: str
+ ) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short="Non member tag",
+ what=f"In {tag} {cls_name}.{name} could not be found.",
+ why=f"All members, including {cls_name}.{name} must be present",
+ ctx=ctx,
+ )
+
+ @staticmethod
+ def MissingAttribute(
+ cls: AnyType, required_attrs: set[str], missing_attr: str
+ ) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short="Missing Attribute",
+ what=f"The attribute {missing_attr} is missing from {cls_name}",
+ why=f"To be manual_xmlified the attributes: {', '.join(required_attrs)} are required. Try using help(IXmlify)",
+ ctx=XErrorCtx([cls_name]),
+ )
+
+ @staticmethod
+ def DependencyCycle(cycle: list[AnyType]) -> XError:
+ return XError(
+ short="Dependency Cycle in XSD",
+ what=f"There is a cycle: {'<-'.join(map(str, cycle))}",
+ why="The XSDs for classes are written to the .xsd file in dependency order",
+ )
+
+ @staticmethod
+ def NotXmlified(cls: AnyType) -> XError:
+ cls_name: str = typename(cls)
+ return XError(
+ short="Not Xmlified",
+ what=f"{cls_name} is not xmlified, and hence cannot have an associated parser",
+ why=f"the .xsd(...) method is required to write_xsd",
+ notes=[f"To fix, try:\n@xmlify\n@dataclass\nclass {cls_name}: ..."],
+ )
diff --git a/third_party/xmlable/_io.py b/third_party/xmlable/_io.py
new file mode 100644
index 0000000..475e205
--- /dev/null
+++ b/third_party/xmlable/_io.py
@@ -0,0 +1,67 @@
+"""
+Easy file IO for users
+- Need to make it obvious when an xml has been overwritten
+- Easy parsing from a file
+"""
+
+from pathlib import Path
+from typing import Any, TypeVar
+from termcolor import colored
+from lxml.objectify import parse as objectify_parse
+from lxml.etree import _ElementTree
+
+from xmlable._utils import typename
+from xmlable._xobject import is_xmlified
+from xmlable._errors import ErrorTypes
+
+
+def write_file(file_path: str | Path, tree: _ElementTree):
+ print(
+ colored(f"Overwriting {file_path}", "red", attrs=["blink"]), end="..."
+ )
+ with open(file=file_path, mode="wb") as f:
+ tree.write(f, xml_declaration=True, encoding="utf-8", pretty_print=True)
+ print(colored(f"Complete!", "green", attrs=["blink"]))
+
+
+def parse_file(cls: type, file_path: str | Path) -> Any:
+ """
+ Parse a file, validate and produce instance of cls
+ INV: cls must be an xmlified class
+ """
+ if not is_xmlified(cls):
+ raise ErrorTypes.NotXmlified(cls)
+ with open(file=file_path, mode="r") as f:
+ return cls.parse(objectify_parse(f).getroot()) # type: ignore[attr-defined]
+
+
+def write_xsd(
+ file_path: str | Path,
+ cls: type,
+ namespaces: dict[str, str] = {},
+ imports: dict[str, str] = {},
+):
+ if not is_xmlified(cls):
+ raise ErrorTypes.NonXMlifiedType(typename(cls))
+ else:
+ write_file(file_path, cls.xsd(namespaces=namespaces, imports=imports)) # type: ignore[attr-defined]
+
+
+def write_xml_template(
+ file_path: str | Path, cls: type, schema_name: str | None = None
+):
+ if not is_xmlified(cls):
+ raise ErrorTypes.NonXMlifiedType(typename(cls))
+ else:
+ schema_id: str = (
+ schema_name if schema_name is not None else typename(cls)
+ )
+ write_file(file_path, cls.xml(schema_id)) # type: ignore[attr-defined]
+
+
+def write_xml_value(file_path: str | Path, val: Any):
+ cls = type(val)
+ if not is_xmlified(cls):
+ raise ErrorTypes.NonXMlifiedType(typename(cls))
+ else:
+ write_file(file_path, val.xml_value()) # type: ignore[attr-defined]
diff --git a/third_party/xmlable/_lxml_helpers.py b/third_party/xmlable/_lxml_helpers.py
new file mode 100644
index 0000000..1f1f288
--- /dev/null
+++ b/third_party/xmlable/_lxml_helpers.py
@@ -0,0 +1,33 @@
+"""
+Helper functions for wrangling to lxml library
+- Includes the XMLSchema used
+"""
+
+from lxml.objectify import ObjectifiedElement
+from lxml.etree import _Element
+from typing import Iterable
+
+XMLURL = r"http://www.w3.org/2001/XMLSchema"
+XMLSchema = r"{http://www.w3.org/2001/XMLSchema}"
+
+
+def with_text(e: _Element, text: str) -> _Element:
+ e.text = text
+ return e
+
+
+def with_children(parent: _Element, children: Iterable[_Element]) -> _Element:
+ for child in children:
+ parent.append(child)
+ return parent
+
+
+def with_child(parent: _Element, child: _Element) -> _Element:
+ return with_children(parent, [child])
+
+
+def children(obj: ObjectifiedElement) -> Iterable[ObjectifiedElement]:
+ def not_comment(child_obj: ObjectifiedElement):
+ return child_obj.tag != "comment"
+
+ return filter(not_comment, obj.getchildren()) # type: ignore[arg-type, operator]
diff --git a/third_party/xmlable/_manual.py b/third_party/xmlable/_manual.py
new file mode 100644
index 0000000..9e73c31
--- /dev/null
+++ b/third_party/xmlable/_manual.py
@@ -0,0 +1,137 @@
+"""
+The @manual_xmlify decorator used to add the .xsd, .xml, .xml_value and .parse
+methods to a class that already has .xsd_dependencies, .xsd_forward and
+.get_xobject
+"""
+
+from typing import Any
+from lxml.etree import _Element, Element, _ElementTree, ElementTree
+from lxml.objectify import ObjectifiedElement
+
+from xmlable._utils import typename, AnyType, ordered_iter
+from xmlable._lxml_helpers import with_children, XMLSchema
+from xmlable._errors import XError, XErrorCtx, ErrorTypes
+
+
+def validate_manual_class(cls: AnyType):
+ attrs = {"get_xobject", "xsd_forward", "xsd_dependencies"}
+ for attr in attrs:
+ if not hasattr(cls, attr):
+ raise ErrorTypes.MissingAttribute(cls, attrs, attr)
+
+
+def type_cycle(from_type: AnyType) -> list[AnyType]:
+ # INV: it is an xmlified type for a user define structure
+ cycle: list[AnyType] = []
+
+ def visit_dep(curr: AnyType) -> bool:
+ if curr == from_type or any(
+ visit_dep(dep) for dep in ordered_iter(curr.xsd_dependencies()) # type: ignore[attr-defined]
+ ):
+ cycle.append(curr)
+ return True
+ else:
+ return False
+
+ assert visit_dep(from_type)
+ cycle.append(from_type)
+ return cycle
+
+
+def manual_xmlify(cls: type) -> type:
+ """
+ Generate the following methods:
+ ```
+ def xsd(
+ id: str = cls_name,
+ namespaces: dict[str, str] = {},
+ imports: dict[str, str] = {},
+ ) -> _ElementTree:
+ # ...
+
+ def xml(schema_name: str = cls_name) -> _ElementTree:
+ # ...
+
+ def xml_value(self, id: str = cls_name) -> _ElementTree:
+ # ...
+
+ def parse(obj: ObjectifiedElement) -> Any:
+ # ...
+ ```
+ """
+ try:
+ validate_manual_class(cls)
+ cls_name = typename(cls)
+
+ cls_xobject = cls.get_xobject() # type: ignore[attr-defined]
+
+ def xsd(
+ id: str = cls_name,
+ namespaces: dict[str, str] = {},
+ imports: dict[str, str] = {},
+ ) -> _ElementTree:
+ # Get dependencies (user classes that need to be declared before)
+ visited: set[AnyType] = set()
+ dec_order: list[AnyType] = []
+
+ def toposort(
+ curr: AnyType, visited: set[AnyType], dec_order: list[AnyType]
+ ):
+ if curr in visited:
+ raise ErrorTypes.DependencyCycle(type_cycle(curr))
+ visited.add(curr)
+ deps = curr.xsd_dependencies() # type: ignore[attr-defined]
+ for d in ordered_iter(deps):
+ if d not in visited:
+ toposort(d, visited, dec_order)
+ dec_order.append(curr)
+
+ toposort(cls, visited, dec_order)
+
+ # Create forward declarations, potentially adding to namespaces
+ decs: list[_Element] = [dec.xsd_forward(namespaces) for dec in dec_order] # type: ignore[attr-defined]
+
+ # generate main element (can add to namespaces)
+ main_element = cls_xobject.xsd_out(id, add_ns=namespaces)
+
+ return ElementTree(
+ with_children(
+ Element(
+ f"{XMLSchema}schema",
+ id=id,
+ elementFormDefault="qualified",
+ nsmap=namespaces,
+ ),
+ [
+ Element(
+ f"{XMLSchema}import",
+ namespace=ns,
+ schemaLocation=sloc,
+ )
+ for ns, sloc in imports.items()
+ ]
+ + decs
+ + [main_element],
+ )
+ )
+
+ def xml(schema_name: str = cls_name) -> _ElementTree:
+ return ElementTree(cls_xobject.xml_temp(schema_name))
+
+ def xml_value(self, id: str = cls_name) -> _ElementTree:
+ return ElementTree(cls_xobject.xml_out(id, self, XErrorCtx([id])))
+
+ def parse(obj: ObjectifiedElement) -> Any:
+ return cls_xobject.xml_in(obj, XErrorCtx([obj.tag]))
+
+ cls.xsd = xsd # type: ignore[attr-defined]
+ cls.xml = xml # type: ignore[attr-defined]
+ setattr(cls, "xml_value", xml_value) # needs to use self to get values
+ cls.parse = parse # type: ignore[attr-defined]
+
+ return cls
+ except XError as e:
+ # NOTE: Trick to remove dirty 'internal' traceback, and raise from
+ # xmlify (makes more sense to user than seeing internals)
+ e.__traceback__ = None
+ raise e
diff --git a/third_party/xmlable/_user.py b/third_party/xmlable/_user.py
new file mode 100644
index 0000000..d94db60
--- /dev/null
+++ b/third_party/xmlable/_user.py
@@ -0,0 +1,71 @@
+"""
+The IXmlify interface
+- Contains the methods needed to make get_xobject work
+- Allows type checking of user's implementations
+"""
+
+from abc import ABC, abstractmethod
+from lxml.etree import _Element
+from xmlable._xobject import XObject
+from xmlable._utils import AnyType
+
+
+class IXmlify(ABC):
+ """
+ A useful interface for ensuring the attributes required for
+ @manual_xmlify are present
+ """
+
+ @staticmethod
+ @abstractmethod
+ def get_xobject() -> XObject:
+ """
+ produces an xobject encapsulates the:
+ - xsd usage (e.g )
+ - xml template
+ - xml value
+ - parsing
+
+ ```
+ @manual_xmlify
+ class Foo(IXmlify):
+ def get_xobject() -> XObject:
+ class MyObj(XObject):
+ # ... definitions
+
+ return MyObj
+ ```
+ """
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def xsd_forward(add_ns: dict[str, str]) -> _Element:
+ """
+ Produces the forward declaration
+ - xsd definition of the class's type
+ ```
+ @manual_xmlify
+ class Foo(IXmlify):
+ def xsd_forward(add_ns: dict[str, str]) -> _Element:
+ return Element('{XMLSchema}complexType', name="Foo", ...)
+ ```
+ """
+ pass
+
+ @staticmethod
+ @abstractmethod
+ def xsd_dependencies() -> set[AnyType]:
+ """
+ The user classes that need to be before this first
+
+ For example:
+ ```
+ @manual_xmlify
+ class A(IXMlify):
+ # xobject uses Foo and Bar
+ def xsd_depedencies() -> set[type]:
+ return {Foo, Bar}
+ ```
+ """
+ pass
diff --git a/third_party/xmlable/_utils.py b/third_party/xmlable/_utils.py
new file mode 100644
index 0000000..c112bb0
--- /dev/null
+++ b/third_party/xmlable/_utils.py
@@ -0,0 +1,63 @@
+"""
+Basic Utilities
+Includes common helper functions for this project
+- Handling optionals
+- getting members by string name
+- typenames
+"""
+
+from typing import Any, Callable, TypeVar, TypeAlias, Type, Iterable
+from types import GenericAlias
+
+AnyType: TypeAlias = Type | GenericAlias
+
+T = TypeVar("T")
+
+
+def some_or(data: T | None, alt: T):
+ return data if data is not None else alt
+
+
+N = TypeVar("N")
+M = TypeVar("M")
+
+
+def some_or_apply(data: N, fn: Callable[[N], M], alt: M):
+ return fn(data) if data is not None else alt
+
+
+def get(obj: Any, attr: str) -> Any:
+ return obj.__getattribute__(attr)
+
+
+def opt_get(obj: Any, attr: str) -> Any | None:
+ try:
+ return obj.__getattribute__(attr)
+ except AttributeError:
+ return None
+
+
+X = TypeVar("X")
+Y = TypeVar("Y")
+
+
+def firstkey(d: dict[X, Y], val: Y) -> X | None:
+ for k, v in d.items():
+ if v == val:
+ return k
+ else:
+ return None
+
+
+def typename(t: AnyType) -> str:
+ if t is None:
+ return "None"
+ else:
+ return t.__name__
+
+
+Z = TypeVar("Z")
+
+
+def ordered_iter(types: Iterable[Z]) -> list[Z]:
+ return sorted(list(types), key=str)
diff --git a/third_party/xmlable/_xmlify.py b/third_party/xmlable/_xmlify.py
new file mode 100644
index 0000000..050b069
--- /dev/null
+++ b/third_party/xmlable/_xmlify.py
@@ -0,0 +1,156 @@
+"""XMLable
+A decorator to allow creation of xml config based on python dataclasses
+
+Given a dataclass:
+- Produce an xsd schema based on the class
+- Produce an xml template based on the class
+- Given any instance of the class, make a best-effort attempt at turning it into
+ a filled xml
+- Create a parser for parsing the xml
+"""
+
+from humps import pascalize
+from dataclasses import fields, is_dataclass
+from typing import Any, dataclass_transform, cast
+from lxml.objectify import ObjectifiedElement
+from lxml.etree import Element, _Element
+
+from xmlable._utils import get, typename, AnyType
+from xmlable._errors import XError, XErrorCtx, ErrorTypes
+from xmlable._manual import manual_xmlify
+from xmlable._lxml_helpers import with_children, with_child, XMLSchema
+from xmlable._xobject import XObject, gen_xobject
+
+
+def validate_class(cls: AnyType):
+ """
+ Validate tha the class can be xmlified
+ - Must be a dataclass
+ - Cannot have any members called 'comment' (lxml parses comments as this tag)
+ - Cannot have
+ """
+ if not is_dataclass(cls):
+ raise ErrorTypes.NotADataclass(cls)
+
+ reserved_attrs = ["get_xobject", "xsd_forward", "xsd_dependencies"]
+
+ # TODO: cleanup repetition
+ for f in fields(cls):
+ if f.name in reserved_attrs:
+ raise ErrorTypes.ReservedAttribute(cls, f.name)
+ elif f.name == "comment":
+ raise ErrorTypes.CommentAttribute(cls)
+
+ # JUSTIFY: Could potentially have added other attributes (of the class,
+ # rather than a field of an instance as provided by dataclass
+ # fields)
+ for reserved in reserved_attrs:
+ if hasattr(cls, reserved):
+ raise ErrorTypes.ReservedAttribute(cls, reserved)
+ if hasattr(cls, "comment"):
+ raise ErrorTypes.CommentAttribute(cls)
+
+
+@dataclass_transform()
+def xmlify(cls: type) -> AnyType:
+ try:
+ validate_class(cls)
+
+ cls_name = typename(cls)
+ forward_decs = cast(set[AnyType], {cls})
+ meta_xobjects = [
+ (
+ pascalize(f.name),
+ f,
+ gen_xobject(cast(AnyType, f.type), forward_decs),
+ )
+ for f in fields(cls)
+ ]
+
+ class UserXObject(XObject):
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return Element(
+ f"{XMLSchema}element",
+ name=name,
+ type=cls_name,
+ attrib=attribs,
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_children(
+ Element(name),
+ [
+ xobj.xml_temp(pascal_name)
+ for pascal_name, _, xobj in meta_xobjects
+ ],
+ )
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ return with_children(
+ Element(name),
+ [
+ xobj.xml_out(
+ pascal_name,
+ get(val, m.name),
+ ctx.next(pascal_name),
+ )
+ for pascal_name, m, xobj in meta_xobjects
+ ],
+ )
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
+ parsed: dict[str, Any] = {}
+ for pascal_name, m, xobj in meta_xobjects:
+ if (m_obj := get(obj, pascal_name)) is not None:
+ parsed[m.name] = xobj.xml_in(
+ m_obj, ctx.next(pascal_name)
+ )
+ else:
+ raise ErrorTypes.NonMemberTag(ctx, cls, obj.tag, m.name)
+ return cls(**parsed)
+
+ cls_xobject = UserXObject()
+
+ # JUSTIFY: Why are xsd forward & dependencies not part of xobject?
+ # - xobject covers the use (not forward decs)
+ # - we want to present error messages to the user containing
+ # their types, so xsd dependencies are in terms of python
+ # types, rather than xobjects
+ # - forward and dependencies do not apply to the basic types,
+ # only user types
+
+ def xsd_forward(add_ns: dict[str, str]) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}complexType", name=cls_name),
+ with_children(
+ Element(f"{XMLSchema}sequence"),
+ [
+ xobj.xsd_out(pascal_name, attribs={}, add_ns=add_ns)
+ for pascal_name, m, xobj in meta_xobjects
+ ],
+ ),
+ )
+
+ def xsd_dependencies() -> set[AnyType]:
+ return forward_decs
+
+ def get_xobject():
+ return cls_xobject
+
+ # helper methods for gen_xobject, and other dataclasses to generate their
+ # x methods
+ cls.xsd_forward = xsd_forward # type: ignore[attr-defined]
+ cls.xsd_dependencies = xsd_dependencies # type: ignore[attr-defined]
+ cls.get_xobject = get_xobject # type: ignore[attr-defined]
+
+ return manual_xmlify(cls)
+ except XError as e:
+ # NOTE: Trick to remove dirty 'internal' traceback, and raise from
+ # xmlify (makes more sense to user than seeing internals)
+ e.__traceback__ = None
+ raise e
diff --git a/third_party/xmlable/_xobject.py b/third_party/xmlable/_xobject.py
new file mode 100644
index 0000000..766695e
--- /dev/null
+++ b/third_party/xmlable/_xobject.py
@@ -0,0 +1,640 @@
+"""XObjects
+XObjects are an intermediate representation for python types -> xsd/xml
+- Produced by @xmlify decorated classes, and by gen_xobject
+- Associated xsd, xml and parsing
+"""
+
+from humps import pascalize
+from dataclasses import dataclass
+from types import NoneType, UnionType
+from lxml.objectify import ObjectifiedElement
+from lxml.etree import Element, Comment, _Element
+from abc import ABC, abstractmethod
+from typing import Any, Callable, Type, get_args, TypeAlias, cast
+from types import GenericAlias
+
+from xmlable._utils import get, typename, firstkey, AnyType
+from xmlable._errors import XErrorCtx, ErrorTypes
+from xmlable._lxml_helpers import (
+ with_text,
+ with_child,
+ with_children,
+ XMLSchema,
+ XMLURL,
+ children,
+)
+
+
+class XObject(ABC):
+ """Any XObject wraps the xsd generation,
+ We can map types to XObjects to get the xsd, template xml, etc
+ """
+
+ @abstractmethod
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ """Generate the xsd schema for the object"""
+ pass
+
+ @abstractmethod
+ def xml_temp(self, name: str) -> _Element:
+ """
+ Generate commented output for the xml representation
+ - Contains no values, only comments
+ - Not valid xml (can contain nested comments, comments instead of values)
+ """
+ pass
+
+ @abstractmethod
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ pass
+
+ @abstractmethod
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
+ pass
+
+
+@dataclass
+class BasicObj(XObject):
+ """
+ An xobject for a simple type (e.g string, int)
+ """
+
+ type_str: str
+ convert_fn: Callable[[Any], str]
+ validate_fn: Callable[[Any], bool]
+ parse_fn: Callable[[ObjectifiedElement], Any]
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, Any] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ # NOTE: namespace cringe:
+ # - lxml will deal with qualifying namespaces for the name of the
+ # element, but not for attributes
+ # - XMLSchema type attributes must be qualified
+ if (prefix := firstkey(add_ns, XMLURL)) is not None:
+ return Element(
+ f"{XMLSchema}element",
+ name=name,
+ type=f"{prefix}:{self.type_str}",
+ attrib=attribs,
+ )
+ else:
+ # add new namespace, resolve conflicts with extra 's'
+ new_ns = "xs"
+ while new_ns in add_ns:
+ new_ns += "s"
+ add_ns[new_ns] = XMLURL
+ return Element(
+ f"{XMLSchema}element",
+ name=name,
+ type=f"{new_ns}:{self.type_str}",
+ attrib=attribs,
+ nsmap={new_ns: XMLURL},
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_text(Element(name), f"Fill me with an {self.type_str}")
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ if not self.validate_fn(val):
+ raise ErrorTypes.InvalidData(ctx, val, self.type_str)
+ return with_text(Element(name), self.convert_fn(val))
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
+ try:
+ return self.parse_fn(obj)
+ except Exception as e:
+ raise ErrorTypes.ParseFailure(ctx, obj.text, self.type_str, e)
+
+
+@dataclass
+class ListObj(XObject):
+ """
+ An ordered list of objects
+ """
+
+ item_xobject: XObject
+ list_elem_name: str
+ struct_name: str = "List"
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}element", name=name, attrib=attribs),
+ with_children(
+ Element(f"{XMLSchema}complexType"),
+ [
+ Comment(f"This is a {self.struct_name}"),
+ with_child(
+ Element(f"{XMLSchema}sequence"),
+ self.item_xobject.xsd_out(
+ self.list_elem_name,
+ {"minOccurs": "0", "maxOccurs": "unbounded"},
+ add_ns,
+ ),
+ ),
+ ],
+ ),
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_children(
+ Element(name),
+ [
+ Comment(f"This is a {self.struct_name}"),
+ self.item_xobject.xml_temp(self.list_elem_name),
+ ],
+ )
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ if len(val) > 0:
+ return with_children(
+ Element(name),
+ [
+ self.item_xobject.xml_out(
+ self.list_elem_name,
+ item_val,
+ ctx.next(f"{self.list_elem_name}[{i}]"),
+ )
+ for i, item_val in enumerate(val)
+ ],
+ )
+ else:
+ return with_child(
+ Element(name), Comment(f"Empty {self.struct_name}!")
+ )
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> list[Any]:
+ parsed = []
+ for i, child in enumerate(children(obj)):
+ if child.tag != self.list_elem_name:
+ raise ErrorTypes.UnexpectedTag(
+ ctx, self.list_elem_name, self.struct_name, child.tag
+ )
+ else:
+ parsed.append(
+ self.item_xobject.xml_in(
+ child, ctx.next(f"{self.list_elem_name}[{i}]")
+ )
+ )
+ return parsed
+
+
+@dataclass
+class StructObj(XObject):
+ """An order list of key-value pairs""" # TODO: make objects variable length tuple
+
+ objects: list[tuple[str, XObject]]
+ struct_name: str = "Struct"
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}element", name=name, attrib=attribs),
+ with_child(
+ Element(f"{XMLSchema}complexType"),
+ with_children(
+ Element(f"{XMLSchema}sequence"),
+ [Comment(f"This is a {self.struct_name}")]
+ + [
+ xobj.xsd_out(member, {}, add_ns)
+ for member, xobj in self.objects
+ ],
+ ),
+ ),
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_children(
+ Element(name),
+ [Comment(f"This is a {self.struct_name}")]
+ + [xobj.xml_temp(member) for member, xobj in self.objects],
+ )
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ if len(val) != len(self.objects):
+ raise ErrorTypes.IncorrectType(
+ ctx, len(self.objects), self.struct_name, val, name
+ )
+
+ return with_children(
+ Element(name),
+ [
+ xobj.xml_out(member, v, ctx.next(member))
+ for (member, xobj), v in zip(self.objects, val)
+ ],
+ )
+
+ def xml_in(
+ self, obj: ObjectifiedElement, ctx: XErrorCtx
+ ) -> list[tuple[str, Any]]:
+ parsed = []
+ for i, (child, (name, xobj)) in enumerate(
+ zip(children(obj), self.objects)
+ ):
+ if child.tag != name:
+ raise ErrorTypes.IncorrectElementTag(
+ ctx, self.struct_name, obj.tag, i, name, child.tag
+ )
+ parsed.append((name, xobj.xml_in(child, ctx.next(name))))
+ return parsed
+
+
+class TupleObj(XObject):
+ """An anonymous struct"""
+
+ def __init__(
+ self,
+ objects: tuple[XObject, ...],
+ elem_gen: Callable[[int], str] = lambda i: f"Item-{i+1}",
+ ):
+ self.elem_gen = elem_gen
+ self.struct: StructObj = StructObj(
+ [(self.elem_gen(i), xobj) for i, xobj in enumerate(objects)],
+ struct_name="Tuple",
+ )
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return self.struct.xsd_out(name, attribs, add_ns)
+
+ def xml_temp(self, name: str) -> _Element:
+ return self.struct.xml_temp(name)
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ return self.struct.xml_out(name, val, ctx)
+
+ def xml_in(
+ self, obj: ObjectifiedElement, ctx: XErrorCtx
+ ) -> tuple[Any, ...]:
+ # Assumes the objects are in the correct order
+ return tuple(zip(*self.struct.xml_in(obj, ctx)))[1] # type: ignore[no-any-return]
+
+
+class SetOBj(XObject):
+ """An unordered collection of unique elements"""
+
+ def __init__(self, inner: XObject, elem_name: str = "setitem"):
+ self.list = ListObj(inner, elem_name, struct_name="set")
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return self.list.xsd_out(name, attribs, add_ns)
+
+ def xml_temp(self, name: str) -> _Element:
+ return self.list.xml_temp(name)
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ return self.list.xml_out(name, list(val), ctx)
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> set[Any]:
+ parsed: set[Any] = set()
+ for item in self.list.xml_in(obj, ctx):
+ if item in parsed:
+ raise ErrorTypes.DuplicateItem(ctx, "set", obj.tag, item)
+ parsed.add(item)
+ return parsed
+
+
+@dataclass
+class DictObj(XObject):
+ """An unordered collection of key-value pair elements"""
+
+ key_xobject: XObject
+ val_xobject: XObject
+ key_name: str = "Key"
+ val_name: str = "Val"
+ item_name: str = "Item"
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}element", name=name, attrib=attribs),
+ with_children(
+ Element(f"{XMLSchema}complexType"),
+ [
+ Comment("this is a dictionary!"),
+ with_child(
+ Element(f"{XMLSchema}sequence"),
+ with_child(
+ Element(
+ f"{XMLSchema}element",
+ name=self.item_name,
+ minOccurs="0",
+ maxOccurs="unbounded",
+ ),
+ with_child(
+ Element(f"{XMLSchema}complexType"),
+ with_children(
+ Element(f"{XMLSchema}sequence"),
+ [
+ self.key_xobject.xsd_out(
+ self.key_name, {}, add_ns
+ ),
+ self.val_xobject.xsd_out(
+ self.val_name, {}, add_ns
+ ),
+ ],
+ ),
+ ),
+ ),
+ ),
+ ],
+ ),
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_children(
+ Element(name),
+ [
+ Comment("This is a dictionary"),
+ with_children(
+ Element(self.item_name),
+ [
+ self.key_xobject.xml_temp(self.key_name),
+ self.val_xobject.xml_temp(self.val_name),
+ ],
+ ),
+ ],
+ )
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ item_ctx = ctx.next(self.item_name)
+
+ return with_children(
+ Element(name),
+ [
+ with_children(
+ Element(self.item_name),
+ [
+ self.key_xobject.xml_out(
+ self.key_name, k, item_ctx.next(name)
+ ),
+ self.val_xobject.xml_out(
+ self.val_name, v, item_ctx.next(name)
+ ),
+ ],
+ )
+ for k, v in val.items()
+ ],
+ )
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> dict[Any, Any]:
+ parsed = {}
+ for child in children(obj):
+ if child.tag != self.item_name:
+ raise ErrorTypes.InvalidDictionaryItem(
+ ctx,
+ self.item_name,
+ self.key_name,
+ self.val_name,
+ child.tag,
+ obj.tag,
+ )
+ else:
+ child_ctx = ctx.next(self.item_name)
+ k = self.key_xobject.xml_in(
+ get(child, self.key_name), child_ctx.next(self.key_name)
+ )
+ v = self.val_xobject.xml_in(
+ get(child, self.val_name), child_ctx.next(self.val_name)
+ )
+
+ if k in parsed:
+ raise ErrorTypes.DuplicateItem(
+ ctx, "dictionary", obj.tag, k
+ )
+
+ parsed[k] = v
+ # TODO: Check for other tags? Fail better?
+ return parsed
+
+
+def resolve_type(v: Any) -> AnyType:
+ """Determine the type of some value, using primitive types
+ - If empty container, only provide top container type
+ INV: only generic types for v are {tuple, list, dict, set}
+ """
+ t = type(v)
+ if t in {int, float, str, bool, NoneType}:
+ return t
+ elif t == dict and len(v) > 0:
+ t0, t1 = next(iter(v.items()))
+ return dict[resolve_type(t0), resolve_type(t1)] # type: ignore[misc, index, no-any-return]
+ elif t == list and len(v) > 0:
+ return list[resolve_type(v[0])] # type: ignore[misc, index, no-any-return]
+ elif t == set and len(v) > 0:
+ return set[resolve_type(next(iter(v)))] # type: ignore[misc, index, no-any-return]
+ elif t == tuple and len(v) > 0:
+ return tuple[*(resolve_type(vi) for vi in v)] # type: ignore[misc, no-any-return]
+ else:
+ # INV: non-generic type
+ return t
+
+
+@dataclass
+class UnionObj(XObject):
+ """A variant, can be one of several different types"""
+
+ xobjects: dict[AnyType, XObject]
+ elem_gen: Callable[[AnyType], str] = lambda t: pascalize(typename(t))
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}element", name=name, attrib=attribs),
+ with_children(
+ Element(f"{XMLSchema}complexType"),
+ [
+ Comment("this is a union!"),
+ with_children(
+ Element(f"{XMLSchema}sequence"),
+ [
+ xobj.xsd_out(
+ self.elem_gen(t), {"minOccurs": "0"}, add_ns
+ )
+ for t, xobj in self.xobjects.items()
+ ],
+ ),
+ ],
+ ),
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_children(
+ Element(name),
+ [
+ Comment(
+ "This is a union, the following variants are possible, only one can be present"
+ )
+ ]
+ + [
+ xobj.xml_temp(self.elem_gen(t))
+ for t, xobj in self.xobjects.items()
+ ],
+ )
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ t = resolve_type(val)
+
+ if (val_xobj := self.xobjects.get(t)) is not None:
+ variant_name = self.elem_gen(t)
+ return with_child(
+ Element(name),
+ val_xobj.xml_out(variant_name, val, ctx.next(variant_name)),
+ )
+ else:
+ raise ErrorTypes.InvalidVariant(
+ ctx, name, list(self.xobjects.keys()), t, val
+ )
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
+ named = {self.elem_gen(t): xobj for t, xobj in self.xobjects.items()}
+ variants = list(children(obj))
+
+ if len(variants) != 1:
+ raise ErrorTypes.MultipleVariants(ctx, [v.tag for v in variants])
+
+ variant = variants[0]
+ if (xobj := named.get(variant.tag)) is not None:
+ return xobj.xml_in(variant, ctx.next(variant.tag))
+ else:
+ raise ErrorTypes.ParseInvalidVariant(
+ ctx, str(obj.tag), list(named.keys()), str(variant)
+ )
+
+
+class NoneObj(XObject):
+ """
+ An object representing the python 'None' type
+ - Unions of form `int | None` are used for optionals
+ """
+
+ def xsd_out(
+ self,
+ name: str,
+ attribs: dict[str, str] = {},
+ add_ns: dict[str, str] = {},
+ ) -> _Element:
+ return with_child(
+ Element(f"{XMLSchema}element", name=name, attrib=attribs),
+ Comment("This is a None type"),
+ )
+
+ def xml_temp(self, name: str) -> _Element:
+ return with_child(Element(name), Comment("This is None"))
+
+ def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
+ if val != None:
+ raise ErrorTypes.NoneIsSome(ctx, name, val)
+
+ return with_child(Element(name), Comment("This is None"))
+
+ def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
+ return None
+
+
+def is_xmlified(cls):
+ return (
+ hasattr(cls, "xsd_forward")
+ and hasattr(cls, "xsd_dependencies")
+ and hasattr(cls, "get_xobject")
+ and hasattr(cls, "xsd")
+ and hasattr(cls, "xml")
+ and hasattr(cls, "xml_value")
+ and hasattr(cls, "parse")
+ )
+
+
+def gen_xobject(data_type: AnyType, forward_dec: set[AnyType]) -> XObject:
+ basic_types: dict[
+ AnyType, tuple[str, Callable[[Any], str], Callable[[Any], bool]]
+ ] = {
+ int: ("integer", str, lambda d: type(d) == int),
+ str: ("string", str, lambda d: type(d) == str),
+ float: ("decimal", str, lambda d: type(d) == float),
+ bool: (
+ "boolean",
+ lambda b: "true" if b else "false",
+ lambda d: type(d) == bool,
+ ),
+ }
+
+ if (basic_entry := basic_types.get(data_type)) is not None:
+ type_str, convert_fn, validate_fn = basic_entry
+ # NOTE: here was can pass the parse_fn as the data type, as the name is
+ # also a constructor. (e.g. `int` -> `int("23") == 32`)
+ parse_fn = cast(Callable[[ObjectifiedElement], Any], data_type)
+ return BasicObj(type_str, convert_fn, validate_fn, parse_fn)
+ elif isinstance(data_type, NoneType) or data_type == NoneType:
+ # NOTE: Python typing cringe: None can be both a type and a value
+ # (even when within a type hint!)
+ # a: list[None] -> None is an instance of NoneType
+ # a: int | None -> Union of int and NoneType
+ return NoneObj()
+ elif isinstance(data_type, UnionType):
+ return UnionObj(
+ {t: gen_xobject(t, forward_dec) for t in get_args(data_type)}
+ )
+ else:
+ t_name = typename(data_type)
+ if t_name == "list":
+ (item_type,) = get_args(data_type)
+ return ListObj(
+ gen_xobject(item_type, forward_dec),
+ pascalize(typename(item_type)),
+ )
+ elif t_name == "dict":
+ key_type, val_type = get_args(data_type)
+ return DictObj(
+ gen_xobject(key_type, forward_dec),
+ gen_xobject(val_type, forward_dec),
+ )
+ elif t_name == "tuple":
+ return TupleObj(
+ tuple(gen_xobject(t, forward_dec) for t in get_args(data_type))
+ )
+ elif t_name == "set":
+ (item_type,) = get_args(data_type)
+ return SetOBj(
+ gen_xobject(item_type, forward_dec),
+ pascalize(typename(item_type)),
+ )
+ else:
+ if is_xmlified(data_type):
+ forward_dec.add(data_type)
+ return data_type.get_xobject() # type: ignore[attr-defined, no-any-return]
+ else:
+ raise ErrorTypes.NonXMlifiedType(t_name)
diff --git a/third_party/xmlable/py.typed b/third_party/xmlable/py.typed
new file mode 100644
index 0000000..e69de29