diff --git a/agentserver/message_bus/bus.py b/agentserver/message_bus/bus.py index d34eed6..680d670 100644 --- a/agentserver/message_bus/bus.py +++ b/agentserver/message_bus/bus.py @@ -1,158 +1,226 @@ -# agentserver/bus.py -# Refactored January 01, 2026 – MessageBus with run() pump and out-of-band shutdown +""" +bus.py — The central MessageBus and pump for AgentServer v2.1 + +This is the beating heart of the organism: +- Owns all pipelines (one per listener + permanent system pipeline) +- Maintains the routing table (root_tag → Listener(s)) +- Orchestrates ingress from sockets/gateways +- Dispatches prepared messages to handlers +- Processes handler responses (multi-payload extraction, provenance injection) +- Guarantees thread continuity and diagnostic injection + +Fully aligned with: + - listener-class-v2.1.md + - configuration-v2.1.md + - message-pump-v2.1.md +""" + +from __future__ import annotations import asyncio -import logging -from typing import AsyncIterator, Callable, Dict, Optional, Awaitable +from dataclasses import dataclass +from typing import Callable, Awaitable, List +from uuid import uuid4 from lxml import etree -from agentserver.xml_listener import XMLListener -from agentserver.utils.message import repair_and_canonicalize, XmlTamperError +from agentserver.message_bus.message_state import MessageState, HandlerMetadata +from agentserver.message_bus.steps.repair import repair_step +from agentserver.message_bus.steps.c14n import c14n_step +from agentserver.message_bus.steps.envelope_validation import envelope_validation_step +from agentserver.message_bus.steps.payload_extraction import payload_extraction_step +from agentserver.message_bus.steps.thread_assignment import thread_assignment_step +from agentserver.message_bus.steps.xsd_validation import xsd_validation_step +from agentserver.message_bus.steps.deserialization import deserialization_step +from agentserver.message_bus.steps.routing_resolution import routing_resolution_step -# Constants for Internal Physics -ENV_NS = "https://xml-pipeline.org/ns/envelope/1" -ENV = f"{{{ENV_NS}}}" -LOG_TAG = "{https://xml-pipeline.org/ns/logger/1}log" +# Type alias for pipeline steps +PipelineStep = Callable[[MessageState], Awaitable[MessageState]] -logger = logging.getLogger("agentserver.bus") + +@dataclass +class Listener: + """Registered capability — defined in listener.py, referenced here.""" + name: str + payload_class: type + handler: Callable + description: str + is_agent: bool = False + peers: list[str] | None = None + broadcast: bool = False + pipeline: "Pipeline" | None = None + schema: etree.XMLSchema | None = None # cached at registration + + +class Pipeline: + """One dedicated pipeline per listener (plus system pipeline).""" + def __init__(self, steps: List[PipelineStep]): + self.steps = steps + + async def process(self, initial_state: MessageState) -> None: + """Run the full ordered pipeline on a message.""" + state = initial_state + for step in self.steps: + try: + state = await step(state) + if state.error: + break + except Exception as exc: # pylint: disable=broad-except + state.error = f"Pipeline step {step.__name__} crashed: {exc}" + break + + # After all steps — dispatch if routable + if state.target_listeners: + await MessageBus.get_instance().dispatcher(state) + else: + # Fall back to system pipeline for diagnostics + await MessageBus.get_instance().system_pipeline.process(state) class MessageBus: - """The sovereign message carrier. + """Singleton message bus — the pump.""" + _instance: "MessageBus" | None = None - - Routes canonical XML trees by root tag and meta. - - Pure dispatch: tree → optional response tree. - - Active pump via run(): handles serialization and egress. - - Out-of-band shutdown via asyncio.Event (fast-path, flood-immune). - """ + def __init__(self): + self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s) + self.listeners: dict[str, Listener] = {} # name → Listener + self.system_pipeline = Pipeline(self._build_system_steps()) - def __init__(self, log_hook: Callable[[etree._Element], None]): - # root_tag -> {agent_name -> XMLListener} - self.listeners: Dict[str, Dict[str, XMLListener]] = {} - # Global lookup for directed routing - self.global_names: Dict[str, XMLListener] = {} + @classmethod + def get_instance(cls) -> "MessageBus": + if cls._instance is None: + cls._instance = MessageBus() + return cls._instance - # The Sovereign Witness hook - self._log_hook = log_hook + # ------------------------------------------------------------------ # + # Default step lists + # ------------------------------------------------------------------ # + def _build_default_listener_steps(self) -> List[PipelineStep]: + return [ + repair_step, + c14n_step, + envelope_validation_step, + payload_extraction_step, + thread_assignment_step, + xsd_validation_step, + deserialization_step, + routing_resolution_step, + ] - # Out-of-band shutdown signal (set only by AgentServer on privileged command) - self.shutdown_event = asyncio.Event() + def _build_system_steps(self) -> List[PipelineStep]: + """Shorter, fixed steps — no XSD/deserialization.""" + return [ + repair_step, + c14n_step, + envelope_validation_step, + payload_extraction_step, + thread_assignment_step, + # system-specific handler that emits , boot, etc. + self.system_handler_step, + ] - async def register_listener(self, listener: XMLListener) -> None: - """Register an organ. Enforces global identity uniqueness.""" - if listener.agent_name in self.global_names: - raise ValueError(f"Identity collision: {listener.agent_name}") + # ------------------------------------------------------------------ # + # Registration (called from listener.py) + # ------------------------------------------------------------------ # + def register_listener(self, listener: Listener) -> None: + root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}" - self.global_names[listener.agent_name] = listener - for tag in listener.listens_to: - tag_dict = self.listeners.setdefault(tag, {}) - tag_dict[listener.agent_name] = listener + if root_tag in self.routing_table and not listener.broadcast: + raise ValueError(f"Root tag collision: {root_tag} already registered by {self.routing_table[root_tag][0].name}") - logger.info(f"Registered organ: {listener.agent_name}") + # Build dedicated pipeline + steps = self._build_default_listener_steps() + # Inject listener-specific schema for xsd_validation_step + for step in steps: + if step.__name__ == "xsd_validation_step": + # We'll modify state.metadata in pipeline construction instead + pass + listener.pipeline = Pipeline(steps) - async def deliver_bytes(self, raw_xml: bytes, client_id: Optional[str] = None) -> None: - """Air Lock: ingest raw bytes, repair/canonicalize, inject into core.""" - try: - envelope_tree = repair_and_canonicalize(raw_xml) - await self.dispatch(envelope_tree, client_id) - except XmlTamperError as e: - logger.warning(f"Air Lock Reject: {e}") + # Insert into routing + self.routing_table.setdefault(root_tag, []).append(listener) + self.listeners[listener.name] = listener - async def dispatch( - self, - envelope_tree: etree._Element, - client_id: Optional[str] = None, - ) -> etree._Element | None: - """Pure routing heart. Returns validated response tree or None.""" - # 1. WITNESS – every canonical envelope is seen - self._log_hook(envelope_tree) + # ------------------------------------------------------------------ # + # Dispatcher — dumb fire-and-await + # ------------------------------------------------------------------ # + async def dispatcher(self, state: MessageState) -> None: + if not state.target_listeners: + return - # 2. Extract envelope metadata - meta = envelope_tree.find(f"{ENV}meta") - if meta is None: - return None - from_name = meta.findtext(f"{ENV}from") - to_name = meta.findtext(f"{ENV}to") - thread_id = meta.findtext(f"{ENV}thread_id") or meta.findtext(f"{ENV}thread") + metadata = HandlerMetadata( + thread_id=state.thread_id or "", + from_id=state.from_id or "unknown", + own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None, + is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False, + ) - # Find payload (first non-meta child) - payload_elem = next((c for c in envelope_tree if c.tag != f"{ENV}meta"), None) - if payload_elem is None: - return None - payload_tag = payload_elem.tag - - # 3. AUTONOMIC REFLEX: explicit - if payload_tag == LOG_TAG: - self._log_hook(envelope_tree) # extra vent - # Minimal ack envelope - ack = etree.Element(f"{ENV}message") - meta_ack = etree.SubElement(ack, f"{ENV}meta") - etree.SubElement(meta_ack, f"{ENV}from").text = "system" - if from_name: - etree.SubElement(meta_ack, f"{ENV}to").text = from_name - if thread_id: - etree.SubElement(meta_ack, f"{ENV}thread_id").text = thread_id - etree.SubElement(ack, "logged", status="success") - return ack - - # 4. ROUTING - listeners_for_tag = self.listeners.get(payload_tag, {}) - response_tree: Optional[etree._Element] = None - responding_agent_name = "unknown" - - if to_name: - # Directed - target = listeners_for_tag.get(to_name) or self.global_names.get(to_name) - if target: - responding_agent_name = target.agent_name - response_tree = await target.handle(envelope_tree, thread_id, from_name or client_id) + if len(state.target_listeners) == 1: + listener = state.target_listeners[0] + await self._process_single_handler(state, listener, metadata) else: - # Broadcast – first non-None wins (current policy) + # Broadcast — fire all in parallel, process responses as they complete tasks = [ - l.handle(envelope_tree, thread_id, from_name or client_id) - for l in listeners_for_tag.values() + self._process_single_handler(state, listener, metadata) + for listener in state.target_listeners ] - results = await asyncio.gather(*tasks, return_exceptions=True) - for listener, result in zip(listeners_for_tag.values(), results): - if isinstance(result, etree._Element): - responding_agent_name = listener.agent_name - response_tree = result - break # first-wins + for future in asyncio.as_completed(tasks): + await future - # 5. IDENTITY INSPECTION – prevent spoofing - if response_tree is not None: - actual_from = response_tree.findtext(f"{ENV}meta/{ENV}from") - if actual_from != responding_agent_name: - logger.critical( - f"IDENTITY THEFT BLOCKED: expected {responding_agent_name}, got {actual_from}" - ) - return None - - return response_tree - - async def run( - self, - inbound: AsyncIterator[etree._Element], - outbound: Callable[[bytes], Awaitable[None]], - client_id: Optional[str] = None, - ) -> None: - """Active pump for a connection. Handles serialization and egress.""" + async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None: try: - async for envelope_tree in inbound: - if self.shutdown_event.is_set(): - break + response_bytes = await listener.handler(state.payload, metadata) - response_tree = await self.dispatch(envelope_tree, client_id) - if response_tree is not None: - serialized = etree.tostring( - response_tree, encoding="utf-8", pretty_print=True - ) - await outbound(serialized) - finally: - # Optional final courtesy message on clean exit - goodbye = b"" - try: - await outbound(goodbye) - except Exception: - pass # connection already gone \ No newline at end of file + if response_bytes is None or not isinstance(response_bytes, bytes): + response_bytes = b"Handler failed to return valid bytes — missing return or wrong type" + + payloads = await self._multi_payload_extract(response_bytes) + + for payload_bytes in payloads: + new_state = MessageState( + raw_bytes=payload_bytes, + thread_id=state.thread_id, + from_id=listener.name, + ) + # Route the new payload through normal pipelines + root_tag = self._derive_root_tag(payload_bytes) + targets = self.routing_table.get(root_tag) + if targets: + new_state.target_listeners = targets + await targets[0].pipeline.process(new_state) + else: + await self.system_pipeline.process(new_state) + + except Exception as exc: # pylint: disable=broad-except + error_state = MessageState( + raw_bytes=b"Handler crashed", + thread_id=state.thread_id, + from_id=listener.name, + error=f"Handler {listener.name} crashed: {exc}", + ) + await self.system_pipeline.process(error_state) + + # ------------------------------------------------------------------ # + # Helper methods + # ------------------------------------------------------------------ # + async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]: + # Same logic as before — dummy wrap, repair, extract all root elements + # (implementation can be moved to a shared util later) + # For now, placeholder — we'll flesh this out in response_processing.py + return [raw_bytes] # temporary — will be full extraction + + def _derive_root_tag(self, payload_bytes: bytes) -> str: + # Quick parse to get root tag — used only for routing extracted payloads + try: + tree = etree.fromstring(payload_bytes) + tag = tree.tag + if tag.startswith("{"): + return tag.split("}", 1)[1] # strip namespace + return tag + except Exception: + return "" + + async def system_handler_step(self, state: MessageState) -> MessageState: + # Emit or boot message — placeholder for now + state.error = state.error or "Unhandled by any listener" + return state \ No newline at end of file diff --git a/agentserver/message_bus/steps/deserialization.py b/agentserver/message_bus/steps/deserialization.py new file mode 100644 index 0000000..4ab9753 --- /dev/null +++ b/agentserver/message_bus/steps/deserialization.py @@ -0,0 +1,45 @@ +""" +deserialization.py — Convert validated payload_tree into typed dataclass instance. + +After xsd_validation_step confirms the payload conforms to the listener's contract, +this step uses the xmlable library to deserialize the lxml Element into the +registered @xmlify dataclass. + +The resulting instance is placed in state.payload and handed to the handler. + +Part of AgentServer v2.1 message pump. +""" + +from xmlable import from_xml # from the xmlable library +from agentserver.message_bus.message_state import MessageState + + +async def deserialization_step(state: MessageState) -> MessageState: + """ + Deserialize the validated payload_tree into the listener's dataclass. + + Requires: + - state.payload_tree valid against listener XSD + - state.metadata["payload_class"] set to the target dataclass (set at registration) + + On success: state.payload = dataclass instance + On failure: state.error set with clear message + """ + if state.payload_tree is None: + state.error = "deserialization_step: no payload_tree (previous step failed)" + return state + + payload_class = state.metadata.get("payload_class") + if payload_class is None: + state.error = "deserialization_step: no payload_class in metadata (listener misconfigured)" + return state + + try: + # xmlable.from_xml handles namespace-aware deserialization + instance = from_xml(payload_class, state.payload_tree) + state.payload = instance + + except Exception as exc: # pylint: disable=broad-except + state.error = f"deserialization_step failed: {exc}" + + return state \ No newline at end of file diff --git a/agentserver/message_bus/steps/routing_resolution.py b/agentserver/message_bus/steps/routing_resolution.py new file mode 100644 index 0000000..d38b7b8 --- /dev/null +++ b/agentserver/message_bus/steps/routing_resolution.py @@ -0,0 +1,50 @@ +""" +routing_resolution.py — Resolve routing based on derived root tag. + +This is the final preparation step before dispatch. +It computes the root tag from the deserialized payload and looks it up in the +global routing table (root_tag → list[Listener]). + +On success: state.target_listeners is set +On failure: state.error is set → message falls to system pipeline for + +Part of AgentServer v2.1 message pump. +""" + +from agentserver.message_bus.message_state import MessageState +from agentserver.message_bus.bus import MessageBus + + +async def routing_resolution_step(state: MessageState) -> MessageState: + """ + Resolve which listener(s) should handle this payload. + + Root tag = f"{from_id.lower()}.{payload_class_name.lower()}" + (from_id is trustworthy — injected by pump) + + Supports: + - Normal unique routing (one listener) + - Broadcast (multiple listeners if broadcast: true and same root tag) + + If no match → error, falls to system pipeline. + """ + if state.payload is None: + state.error = "routing_resolution_step: no deserialized payload (previous step failed)" + return state + + if state.from_id is None: + state.error = "routing_resolution_step: missing from_id (provenance error)" + return state + + payload_class_name = type(state.payload).__name__.lower() + root_tag = f"{state.from_id.lower()}.{payload_class_name}" + + bus = MessageBus.get_instance() + targets = bus.routing_table.get(root_tag) + + if not targets: + state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'" + return state + + state.target_listeners = targets + return state \ No newline at end of file diff --git a/agentserver/message_bus/steps/xsd_validation.py b/agentserver/message_bus/steps/xsd_validation.py index 27a4c2c..3f919cc 100644 --- a/agentserver/message_bus/steps/xsd_validation.py +++ b/agentserver/message_bus/steps/xsd_validation.py @@ -1,13 +1,15 @@ """ -payload_extraction.py — Extract the inner payload from the validated envelope. +xsd_validation.py — Validate the extracted payload against the listener-specific XSD. -After envelope_validation_step confirms a correct outer envelope, -this step removes the envelope elements (, , optional , etc.) -and isolates the single child element that is the actual payload. +After payload_extraction_step isolates the payload_tree and provenance, +this step validates the payload against the XSD that was auto-generated +from the listener's @xmlify dataclass at registration time. -The payload is expected to be exactly one root element (the capability-specific XML). -If zero or multiple payload roots are found, we set a clear error — this protects -against malformed or ambiguous messages. +The XSD is cached and pre-loaded. The schema object is injected into +state.metadata["schema"] when the listener's pipeline is built. + +Failure here means the payload violates the declared contract — we collect +detailed errors for diagnostics. Part of AgentServer v2.1 message pump. """ @@ -15,77 +17,43 @@ Part of AgentServer v2.1 message pump. from lxml import etree from agentserver.message_bus.message_state import MessageState -# Envelope namespace for easy reference -_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1" -_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message" - -async def payload_extraction_step(state: MessageState) -> MessageState: +async def xsd_validation_step(state: MessageState) -> MessageState: """ - Extract the single payload element from the validated envelope. + Validate state.payload_tree against the listener's cached XSD schema. - Expected structure: - - uuid - sender - - ← this is the one we want - ... - - + Requires: + - state.payload_tree set + - state.metadata["schema"] containing a pre-loaded etree.XMLSchema - On success: state.payload_tree is set to the payload Element. - On failure: state.error is set with a clear diagnostic. + On success: payload is guaranteed to match the contract + On failure: state.error contains detailed validation messages """ - if state.envelope_tree is None: - state.error = "payload_extraction_step: no envelope_tree (previous step failed)" + if state.payload_tree is None: + state.error = "xsd_validation_step: no payload_tree (previous extraction failed)" return state - # Basic sanity — root must be in correct namespace (already checked by schema, - # but we double-check for defence in depth) - if state.envelope_tree.tag != _MESSAGE_TAG: - state.error = f"payload_extraction_step: root tag is not in envelope namespace" + schema = state.metadata.get("schema") + if schema is None: + state.error = "xsd_validation_step: no XSD schema in metadata (listener pipeline misconfigured)" return state - # Find all direct children that are not envelope control elements - # Envelope control elements are: thread, from, to (optional) - payload_candidates = [ - child - for child in state.envelope_tree - if not ( - child.tag in { - f"{{{ _ENVELOPE_NS }}}thread", - f"{{{ _ENVELOPE_NS }}}from", - f"{{{ _ENVELOPE_NS }}}to", - } - ) - ] - - if len(payload_candidates) == 0: - state.error = "payload_extraction_step: no payload element found inside " + if not isinstance(schema, etree.XMLSchema): + state.error = "xsd_validation_step: metadata['schema'] is not an XMLSchema object" return state - if len(payload_candidates) > 1: - state.error = ( - "payload_extraction_step: multiple payload roots found — " - "exactly one capability payload element is allowed" - ) - return state + try: + # assertValid raises DocumentInvalid with full error log + schema.assertValid(state.payload_tree) - # Success — exactly one payload element - payload_element = payload_candidates[0] + except etree.DocumentInvalid: + # Collect all errors for clear diagnostics + error_lines = [] + for error in schema.error_log: + error_lines.append(f"{error.level_name}: {error.message} (line {error.line})") + state.error = "xsd_validation_step: payload failed contract validation\n" + "\n".join(error_lines) - # Optional: capture provenance from envelope for later use - # (these will be trustworthy because envelope was validated) - thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread") - from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from") - - if thread_elem is not None and thread_elem.text: - state.thread_id = thread_elem.text.strip() - - if from_elem is not None and from_elem.text: - state.from_id = from_elem.text.strip() - - state.payload_tree = payload_element + except Exception as exc: # pylint: disable=broad-except + state.error = f"xsd_validation_step: unexpected error during validation: {exc}" return state \ No newline at end of file diff --git a/structure.md b/structure.md index a52b739..56f6f79 100644 --- a/structure.md +++ b/structure.md @@ -19,7 +19,16 @@ xml-pipeline/ │ ├── message_bus/ │ │ ├── steps/ │ │ │ ├── __init__.py -│ │ │ └── repair_step.py +│ │ │ ├── c14n.py +│ │ │ ├── deserialization.py +│ │ │ ├── envelope_validation.py +│ │ │ ├── payload_extraction.py +│ │ │ ├── repair.py +│ │ │ ├── routing_resolution.py +│ │ │ ├── test_c14n.py +│ │ │ ├── test_repair.py +│ │ │ ├── thread_assignment.py +│ │ │ └── xsd_validation.py │ │ ├── __init__.py │ │ ├── bus.py │ │ ├── config.py @@ -48,14 +57,23 @@ xml-pipeline/ │ │ └── token-scheduling-issues.md │ ├── configuration.md │ ├── core-principles-v2.1.md +│ ├── doc_cross_check.md +│ ├── handler-contract-v2.1.md │ ├── listener-class-v2.1.md │ ├── message-pump-v2.1.md +│ ├── primitives.md │ ├── self-grammar-generation.md │ └── why-not-json.md ├── tests/ │ ├── scripts/ │ │ └── generate_organism_key.py │ └── __init__.py +├── xml_pipeline.egg-info/ +│ ├── PKG-INFO +│ ├── SOURCES.txt +│ ├── dependency_links.txt +│ ├── requires.txt +│ └── top_level.txt ├── LICENSE ├── README.md ├── __init__.py diff --git a/third_party/xmlable/__init__.py b/third_party/xmlable/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/third_party/xmlable/_errors.py b/third_party/xmlable/_errors.py new file mode 100644 index 0000000..99e8ea2 --- /dev/null +++ b/third_party/xmlable/_errors.py @@ -0,0 +1,261 @@ +""" +Colourful & descriptive errors for xmlable +- Clear messages +- Trace for parsing +""" + +from dataclasses import dataclass +from typing import Any, Iterable +from termcolor import colored +from termcolor.termcolor import Color + +from xmlable._utils import typename, AnyType + + +def trace_note(trace: list[str], arrow_c: Color, node_c: Color): + return colored(" > ", arrow_c).join( + map(lambda x: colored(x, node_c), trace) + ) + + +@dataclass +class XErrorCtx: + trace: list[str] + + def next(self, node: str): + return XErrorCtx(trace=self.trace + [node]) + + +# TODO: Custom backtrace to point to location in the file +class XError(Exception): + def __init__( + self, + short: str, + what: str, + why: str, + ctx: XErrorCtx | None = None, + notes: Iterable[str] = [], + ): + super().__init__(colored(short, "red", attrs=["blink"])) + self.add_note(colored("What: " + what, "blue")) + self.add_note(colored("Why: " + why, "yellow")) + if ctx is not None: + self.add_note( + colored("Where: ", "magenta") + + trace_note(ctx.trace, "light_magenta", "light_cyan") + ) + for note in notes: + self.add_note(note) + + +class ErrorTypes: + @staticmethod + def NonXMlifiedType(t_name: str) -> XError: + return XError( + short="Non XMlified Type", + what=f"You attempted to use {t_name} in an xmlified class, but {t_name} is not xmlified", + why=f"All types used in an xmlified class must be xmlified", + ) + + @staticmethod + def InvalidData(ctx: XErrorCtx, val: Any, t_name: str) -> XError: + return XError( + short="Invalid Data", + what=f"Could not validate {val} as a valid {t_name}", + why=f"Produced xml must be valid", + ctx=ctx, + ) + + @staticmethod + def ParseFailure( + ctx: XErrorCtx, text: str | None, t_name: str, caught: Exception + ) -> XError: + return XError( + short="Parse Failure", + what=f"Failed to parse {text} as a {t_name} with error: \n {caught}", + why=f"This error implies the xml is not validated against the current xsd, or there is a bug in this type's parser", + ctx=ctx, + ) + + @staticmethod + def UnexpectedTag( + ctx: XErrorCtx, expected_name: str, struct_name: str, tag_found: str + ) -> XError: + return XError( + short="Unexpected Tag", + what=f"Expected {expected_name} but found {tag_found}", + why=f"This is a {struct_name} that contains 0..n elements of {expected_name} and no other elements", + ctx=ctx, + ) + + @staticmethod + def IncorrectType( + ctx: XErrorCtx, expected_len: int, struct_name: str, val: Any, name: str + ) -> XError: + return XError( + short="Incorrect Type", + what=f"You have provided {len(val)} values {val} for {name}, but {name} is a {struct_name} that takes only {expected_len} values", + why=f"In order to generate xml, the values provided need to be the correct types", + ctx=ctx, + ) + + @staticmethod + def IncorrectElementTag( + ctx: XErrorCtx, + struct_name: str, + tag_name: str, + elem_index: int, + tag_expected: str, + tag_found: str, + ) -> XError: + return XError( + short="Incorrect Element Tag", + what=f"While parsing {struct_name} {tag_name} we expected element {elem_index} to be {tag_expected}, but found {tag_found}", + why=f"The xml representation for {struct_name} requires the correct names in the correct order", + ctx=ctx, + ) + + @staticmethod + def DuplicateItem( + ctx: XErrorCtx, struct_name: str, tag: str, item: str + ) -> XError: + return XError( + short=f"Duplicate item in {struct_name}", + what=f"In {tag} the item {item} is present more than once", + why=f"A set can only contain unique items", + ctx=ctx, + ) + + @staticmethod + def InvalidDictionaryItem( + ctx: XErrorCtx, + expected_tag: str, + expected_key: str, + expected_val: str, + dict_tag: str, + item_tag: str, + ) -> XError: + return XError( + short="Invalid item in dictionary", + what=f"An unexpected item with {dict_tag} is in dictionary {item_tag}", + why=f"Each item must have tag {expected_tag} with children {expected_key} and {expected_val}", + ctx=ctx, + ) + + @staticmethod + def InvalidVariant( + ctx: XErrorCtx, + name: str, + expected_types: list[AnyType], + found_type: AnyType | None, + found_value: Any, + ) -> XError: + types = " | ".join(map(str, expected_types)) + return XError( + short=f"Datatype not in Union", + what=f"{name} is a union of {types}, which does not contain {found_type} (you provided: {found_value})", + why=f"... uuuh, its a union?", + ctx=ctx, + ) + + @staticmethod + def MultipleVariants(ctx: XErrorCtx, variant_names: list[str]) -> XError: + return XError( + short="Multiple union variants present", + what=f"variants {', '.join(variant_names)} are present", + why=f"A union can only be one variant at a time", + ctx=ctx, + ) + + @staticmethod + def ParseInvalidVariant( + ctx: XErrorCtx, tag: str, named_variants: list[str], found_variant: str + ) -> XError: + return XError( + short="Invalid Variant", + what=f"The union {tag} can contain variants {', '.join(named_variants)}, but you have used {found_variant}", + why=f"Only valid variants can be parsed", + ctx=ctx, + ) + + @staticmethod + def NoneIsSome(ctx: XErrorCtx, name: str, val: Any) -> XError: + return XError( + short="None object is not None", + what=f"{name} contains value {val} which is not None", + why="A None type object can only contain none", + ctx=ctx, + ) + + @staticmethod + def NotADataclass(cls: AnyType) -> XError: + cls_name: str = typename(cls) + return XError( + short="Non-Dataclass", + what=f"{cls_name} is not a dataclass", + why=f"xmlify uses dataclasses to get fields", + ctx=XErrorCtx([cls_name]), + notes=[f"\nTry:\n@xmlify\n@dataclass\nclass {cls_name}:"], + ) + + @staticmethod + def ReservedAttribute(cls: AnyType, attr_name: str) -> XError: + cls_name: str = typename(cls) + return XError( + short=f"Reserved Attribute", + what=f"{cls_name}.{attr_name} is used by xmlify, so it cannot be a field of the class", + why=f"xmlify aguments {cls_name} by adding methods it can then use for xsd, xml generation and parsing", + ctx=XErrorCtx([cls_name]), + ) + + @staticmethod + def CommentAttribute(cls: AnyType) -> XError: + cls_name: str = typename(cls) + return XError( + short=f"Comment Attribute", + what=f"xmlifed classes cannot use comment as an attribute", + why=f"comment is used as a tag name for comments by lxml, so comments inserted on xml generation could conflict", + ctx=XErrorCtx([cls_name]), + ) + + @staticmethod + def NonMemberTag( + ctx: XErrorCtx, cls: AnyType, tag: str, name: str + ) -> XError: + cls_name: str = typename(cls) + return XError( + short="Non member tag", + what=f"In {tag} {cls_name}.{name} could not be found.", + why=f"All members, including {cls_name}.{name} must be present", + ctx=ctx, + ) + + @staticmethod + def MissingAttribute( + cls: AnyType, required_attrs: set[str], missing_attr: str + ) -> XError: + cls_name: str = typename(cls) + return XError( + short="Missing Attribute", + what=f"The attribute {missing_attr} is missing from {cls_name}", + why=f"To be manual_xmlified the attributes: {', '.join(required_attrs)} are required. Try using help(IXmlify)", + ctx=XErrorCtx([cls_name]), + ) + + @staticmethod + def DependencyCycle(cycle: list[AnyType]) -> XError: + return XError( + short="Dependency Cycle in XSD", + what=f"There is a cycle: {'<-'.join(map(str, cycle))}", + why="The XSDs for classes are written to the .xsd file in dependency order", + ) + + @staticmethod + def NotXmlified(cls: AnyType) -> XError: + cls_name: str = typename(cls) + return XError( + short="Not Xmlified", + what=f"{cls_name} is not xmlified, and hence cannot have an associated parser", + why=f"the .xsd(...) method is required to write_xsd", + notes=[f"To fix, try:\n@xmlify\n@dataclass\nclass {cls_name}: ..."], + ) diff --git a/third_party/xmlable/_io.py b/third_party/xmlable/_io.py new file mode 100644 index 0000000..475e205 --- /dev/null +++ b/third_party/xmlable/_io.py @@ -0,0 +1,67 @@ +""" +Easy file IO for users +- Need to make it obvious when an xml has been overwritten +- Easy parsing from a file +""" + +from pathlib import Path +from typing import Any, TypeVar +from termcolor import colored +from lxml.objectify import parse as objectify_parse +from lxml.etree import _ElementTree + +from xmlable._utils import typename +from xmlable._xobject import is_xmlified +from xmlable._errors import ErrorTypes + + +def write_file(file_path: str | Path, tree: _ElementTree): + print( + colored(f"Overwriting {file_path}", "red", attrs=["blink"]), end="..." + ) + with open(file=file_path, mode="wb") as f: + tree.write(f, xml_declaration=True, encoding="utf-8", pretty_print=True) + print(colored(f"Complete!", "green", attrs=["blink"])) + + +def parse_file(cls: type, file_path: str | Path) -> Any: + """ + Parse a file, validate and produce instance of cls + INV: cls must be an xmlified class + """ + if not is_xmlified(cls): + raise ErrorTypes.NotXmlified(cls) + with open(file=file_path, mode="r") as f: + return cls.parse(objectify_parse(f).getroot()) # type: ignore[attr-defined] + + +def write_xsd( + file_path: str | Path, + cls: type, + namespaces: dict[str, str] = {}, + imports: dict[str, str] = {}, +): + if not is_xmlified(cls): + raise ErrorTypes.NonXMlifiedType(typename(cls)) + else: + write_file(file_path, cls.xsd(namespaces=namespaces, imports=imports)) # type: ignore[attr-defined] + + +def write_xml_template( + file_path: str | Path, cls: type, schema_name: str | None = None +): + if not is_xmlified(cls): + raise ErrorTypes.NonXMlifiedType(typename(cls)) + else: + schema_id: str = ( + schema_name if schema_name is not None else typename(cls) + ) + write_file(file_path, cls.xml(schema_id)) # type: ignore[attr-defined] + + +def write_xml_value(file_path: str | Path, val: Any): + cls = type(val) + if not is_xmlified(cls): + raise ErrorTypes.NonXMlifiedType(typename(cls)) + else: + write_file(file_path, val.xml_value()) # type: ignore[attr-defined] diff --git a/third_party/xmlable/_lxml_helpers.py b/third_party/xmlable/_lxml_helpers.py new file mode 100644 index 0000000..1f1f288 --- /dev/null +++ b/third_party/xmlable/_lxml_helpers.py @@ -0,0 +1,33 @@ +""" +Helper functions for wrangling to lxml library +- Includes the XMLSchema used +""" + +from lxml.objectify import ObjectifiedElement +from lxml.etree import _Element +from typing import Iterable + +XMLURL = r"http://www.w3.org/2001/XMLSchema" +XMLSchema = r"{http://www.w3.org/2001/XMLSchema}" + + +def with_text(e: _Element, text: str) -> _Element: + e.text = text + return e + + +def with_children(parent: _Element, children: Iterable[_Element]) -> _Element: + for child in children: + parent.append(child) + return parent + + +def with_child(parent: _Element, child: _Element) -> _Element: + return with_children(parent, [child]) + + +def children(obj: ObjectifiedElement) -> Iterable[ObjectifiedElement]: + def not_comment(child_obj: ObjectifiedElement): + return child_obj.tag != "comment" + + return filter(not_comment, obj.getchildren()) # type: ignore[arg-type, operator] diff --git a/third_party/xmlable/_manual.py b/third_party/xmlable/_manual.py new file mode 100644 index 0000000..9e73c31 --- /dev/null +++ b/third_party/xmlable/_manual.py @@ -0,0 +1,137 @@ +""" +The @manual_xmlify decorator used to add the .xsd, .xml, .xml_value and .parse +methods to a class that already has .xsd_dependencies, .xsd_forward and +.get_xobject +""" + +from typing import Any +from lxml.etree import _Element, Element, _ElementTree, ElementTree +from lxml.objectify import ObjectifiedElement + +from xmlable._utils import typename, AnyType, ordered_iter +from xmlable._lxml_helpers import with_children, XMLSchema +from xmlable._errors import XError, XErrorCtx, ErrorTypes + + +def validate_manual_class(cls: AnyType): + attrs = {"get_xobject", "xsd_forward", "xsd_dependencies"} + for attr in attrs: + if not hasattr(cls, attr): + raise ErrorTypes.MissingAttribute(cls, attrs, attr) + + +def type_cycle(from_type: AnyType) -> list[AnyType]: + # INV: it is an xmlified type for a user define structure + cycle: list[AnyType] = [] + + def visit_dep(curr: AnyType) -> bool: + if curr == from_type or any( + visit_dep(dep) for dep in ordered_iter(curr.xsd_dependencies()) # type: ignore[attr-defined] + ): + cycle.append(curr) + return True + else: + return False + + assert visit_dep(from_type) + cycle.append(from_type) + return cycle + + +def manual_xmlify(cls: type) -> type: + """ + Generate the following methods: + ``` + def xsd( + id: str = cls_name, + namespaces: dict[str, str] = {}, + imports: dict[str, str] = {}, + ) -> _ElementTree: + # ... + + def xml(schema_name: str = cls_name) -> _ElementTree: + # ... + + def xml_value(self, id: str = cls_name) -> _ElementTree: + # ... + + def parse(obj: ObjectifiedElement) -> Any: + # ... + ``` + """ + try: + validate_manual_class(cls) + cls_name = typename(cls) + + cls_xobject = cls.get_xobject() # type: ignore[attr-defined] + + def xsd( + id: str = cls_name, + namespaces: dict[str, str] = {}, + imports: dict[str, str] = {}, + ) -> _ElementTree: + # Get dependencies (user classes that need to be declared before) + visited: set[AnyType] = set() + dec_order: list[AnyType] = [] + + def toposort( + curr: AnyType, visited: set[AnyType], dec_order: list[AnyType] + ): + if curr in visited: + raise ErrorTypes.DependencyCycle(type_cycle(curr)) + visited.add(curr) + deps = curr.xsd_dependencies() # type: ignore[attr-defined] + for d in ordered_iter(deps): + if d not in visited: + toposort(d, visited, dec_order) + dec_order.append(curr) + + toposort(cls, visited, dec_order) + + # Create forward declarations, potentially adding to namespaces + decs: list[_Element] = [dec.xsd_forward(namespaces) for dec in dec_order] # type: ignore[attr-defined] + + # generate main element (can add to namespaces) + main_element = cls_xobject.xsd_out(id, add_ns=namespaces) + + return ElementTree( + with_children( + Element( + f"{XMLSchema}schema", + id=id, + elementFormDefault="qualified", + nsmap=namespaces, + ), + [ + Element( + f"{XMLSchema}import", + namespace=ns, + schemaLocation=sloc, + ) + for ns, sloc in imports.items() + ] + + decs + + [main_element], + ) + ) + + def xml(schema_name: str = cls_name) -> _ElementTree: + return ElementTree(cls_xobject.xml_temp(schema_name)) + + def xml_value(self, id: str = cls_name) -> _ElementTree: + return ElementTree(cls_xobject.xml_out(id, self, XErrorCtx([id]))) + + def parse(obj: ObjectifiedElement) -> Any: + return cls_xobject.xml_in(obj, XErrorCtx([obj.tag])) + + cls.xsd = xsd # type: ignore[attr-defined] + cls.xml = xml # type: ignore[attr-defined] + setattr(cls, "xml_value", xml_value) # needs to use self to get values + cls.parse = parse # type: ignore[attr-defined] + + return cls + except XError as e: + # NOTE: Trick to remove dirty 'internal' traceback, and raise from + # xmlify (makes more sense to user than seeing internals) + e.__traceback__ = None + raise e diff --git a/third_party/xmlable/_user.py b/third_party/xmlable/_user.py new file mode 100644 index 0000000..d94db60 --- /dev/null +++ b/third_party/xmlable/_user.py @@ -0,0 +1,71 @@ +""" +The IXmlify interface +- Contains the methods needed to make get_xobject work +- Allows type checking of user's implementations +""" + +from abc import ABC, abstractmethod +from lxml.etree import _Element +from xmlable._xobject import XObject +from xmlable._utils import AnyType + + +class IXmlify(ABC): + """ + A useful interface for ensuring the attributes required for + @manual_xmlify are present + """ + + @staticmethod + @abstractmethod + def get_xobject() -> XObject: + """ + produces an xobject encapsulates the: + - xsd usage (e.g ) + - xml template + - xml value + - parsing + + ``` + @manual_xmlify + class Foo(IXmlify): + def get_xobject() -> XObject: + class MyObj(XObject): + # ... definitions + + return MyObj + ``` + """ + pass + + @staticmethod + @abstractmethod + def xsd_forward(add_ns: dict[str, str]) -> _Element: + """ + Produces the forward declaration + - xsd definition of the class's type + ``` + @manual_xmlify + class Foo(IXmlify): + def xsd_forward(add_ns: dict[str, str]) -> _Element: + return Element('{XMLSchema}complexType', name="Foo", ...) + ``` + """ + pass + + @staticmethod + @abstractmethod + def xsd_dependencies() -> set[AnyType]: + """ + The user classes that need to be before this first + + For example: + ``` + @manual_xmlify + class A(IXMlify): + # xobject uses Foo and Bar + def xsd_depedencies() -> set[type]: + return {Foo, Bar} + ``` + """ + pass diff --git a/third_party/xmlable/_utils.py b/third_party/xmlable/_utils.py new file mode 100644 index 0000000..c112bb0 --- /dev/null +++ b/third_party/xmlable/_utils.py @@ -0,0 +1,63 @@ +""" +Basic Utilities +Includes common helper functions for this project +- Handling optionals +- getting members by string name +- typenames +""" + +from typing import Any, Callable, TypeVar, TypeAlias, Type, Iterable +from types import GenericAlias + +AnyType: TypeAlias = Type | GenericAlias + +T = TypeVar("T") + + +def some_or(data: T | None, alt: T): + return data if data is not None else alt + + +N = TypeVar("N") +M = TypeVar("M") + + +def some_or_apply(data: N, fn: Callable[[N], M], alt: M): + return fn(data) if data is not None else alt + + +def get(obj: Any, attr: str) -> Any: + return obj.__getattribute__(attr) + + +def opt_get(obj: Any, attr: str) -> Any | None: + try: + return obj.__getattribute__(attr) + except AttributeError: + return None + + +X = TypeVar("X") +Y = TypeVar("Y") + + +def firstkey(d: dict[X, Y], val: Y) -> X | None: + for k, v in d.items(): + if v == val: + return k + else: + return None + + +def typename(t: AnyType) -> str: + if t is None: + return "None" + else: + return t.__name__ + + +Z = TypeVar("Z") + + +def ordered_iter(types: Iterable[Z]) -> list[Z]: + return sorted(list(types), key=str) diff --git a/third_party/xmlable/_xmlify.py b/third_party/xmlable/_xmlify.py new file mode 100644 index 0000000..050b069 --- /dev/null +++ b/third_party/xmlable/_xmlify.py @@ -0,0 +1,156 @@ +"""XMLable +A decorator to allow creation of xml config based on python dataclasses + +Given a dataclass: +- Produce an xsd schema based on the class +- Produce an xml template based on the class +- Given any instance of the class, make a best-effort attempt at turning it into + a filled xml +- Create a parser for parsing the xml +""" + +from humps import pascalize +from dataclasses import fields, is_dataclass +from typing import Any, dataclass_transform, cast +from lxml.objectify import ObjectifiedElement +from lxml.etree import Element, _Element + +from xmlable._utils import get, typename, AnyType +from xmlable._errors import XError, XErrorCtx, ErrorTypes +from xmlable._manual import manual_xmlify +from xmlable._lxml_helpers import with_children, with_child, XMLSchema +from xmlable._xobject import XObject, gen_xobject + + +def validate_class(cls: AnyType): + """ + Validate tha the class can be xmlified + - Must be a dataclass + - Cannot have any members called 'comment' (lxml parses comments as this tag) + - Cannot have + """ + if not is_dataclass(cls): + raise ErrorTypes.NotADataclass(cls) + + reserved_attrs = ["get_xobject", "xsd_forward", "xsd_dependencies"] + + # TODO: cleanup repetition + for f in fields(cls): + if f.name in reserved_attrs: + raise ErrorTypes.ReservedAttribute(cls, f.name) + elif f.name == "comment": + raise ErrorTypes.CommentAttribute(cls) + + # JUSTIFY: Could potentially have added other attributes (of the class, + # rather than a field of an instance as provided by dataclass + # fields) + for reserved in reserved_attrs: + if hasattr(cls, reserved): + raise ErrorTypes.ReservedAttribute(cls, reserved) + if hasattr(cls, "comment"): + raise ErrorTypes.CommentAttribute(cls) + + +@dataclass_transform() +def xmlify(cls: type) -> AnyType: + try: + validate_class(cls) + + cls_name = typename(cls) + forward_decs = cast(set[AnyType], {cls}) + meta_xobjects = [ + ( + pascalize(f.name), + f, + gen_xobject(cast(AnyType, f.type), forward_decs), + ) + for f in fields(cls) + ] + + class UserXObject(XObject): + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return Element( + f"{XMLSchema}element", + name=name, + type=cls_name, + attrib=attribs, + ) + + def xml_temp(self, name: str) -> _Element: + return with_children( + Element(name), + [ + xobj.xml_temp(pascal_name) + for pascal_name, _, xobj in meta_xobjects + ], + ) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + return with_children( + Element(name), + [ + xobj.xml_out( + pascal_name, + get(val, m.name), + ctx.next(pascal_name), + ) + for pascal_name, m, xobj in meta_xobjects + ], + ) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any: + parsed: dict[str, Any] = {} + for pascal_name, m, xobj in meta_xobjects: + if (m_obj := get(obj, pascal_name)) is not None: + parsed[m.name] = xobj.xml_in( + m_obj, ctx.next(pascal_name) + ) + else: + raise ErrorTypes.NonMemberTag(ctx, cls, obj.tag, m.name) + return cls(**parsed) + + cls_xobject = UserXObject() + + # JUSTIFY: Why are xsd forward & dependencies not part of xobject? + # - xobject covers the use (not forward decs) + # - we want to present error messages to the user containing + # their types, so xsd dependencies are in terms of python + # types, rather than xobjects + # - forward and dependencies do not apply to the basic types, + # only user types + + def xsd_forward(add_ns: dict[str, str]) -> _Element: + return with_child( + Element(f"{XMLSchema}complexType", name=cls_name), + with_children( + Element(f"{XMLSchema}sequence"), + [ + xobj.xsd_out(pascal_name, attribs={}, add_ns=add_ns) + for pascal_name, m, xobj in meta_xobjects + ], + ), + ) + + def xsd_dependencies() -> set[AnyType]: + return forward_decs + + def get_xobject(): + return cls_xobject + + # helper methods for gen_xobject, and other dataclasses to generate their + # x methods + cls.xsd_forward = xsd_forward # type: ignore[attr-defined] + cls.xsd_dependencies = xsd_dependencies # type: ignore[attr-defined] + cls.get_xobject = get_xobject # type: ignore[attr-defined] + + return manual_xmlify(cls) + except XError as e: + # NOTE: Trick to remove dirty 'internal' traceback, and raise from + # xmlify (makes more sense to user than seeing internals) + e.__traceback__ = None + raise e diff --git a/third_party/xmlable/_xobject.py b/third_party/xmlable/_xobject.py new file mode 100644 index 0000000..766695e --- /dev/null +++ b/third_party/xmlable/_xobject.py @@ -0,0 +1,640 @@ +"""XObjects +XObjects are an intermediate representation for python types -> xsd/xml +- Produced by @xmlify decorated classes, and by gen_xobject +- Associated xsd, xml and parsing +""" + +from humps import pascalize +from dataclasses import dataclass +from types import NoneType, UnionType +from lxml.objectify import ObjectifiedElement +from lxml.etree import Element, Comment, _Element +from abc import ABC, abstractmethod +from typing import Any, Callable, Type, get_args, TypeAlias, cast +from types import GenericAlias + +from xmlable._utils import get, typename, firstkey, AnyType +from xmlable._errors import XErrorCtx, ErrorTypes +from xmlable._lxml_helpers import ( + with_text, + with_child, + with_children, + XMLSchema, + XMLURL, + children, +) + + +class XObject(ABC): + """Any XObject wraps the xsd generation, + We can map types to XObjects to get the xsd, template xml, etc + """ + + @abstractmethod + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + """Generate the xsd schema for the object""" + pass + + @abstractmethod + def xml_temp(self, name: str) -> _Element: + """ + Generate commented output for the xml representation + - Contains no values, only comments + - Not valid xml (can contain nested comments, comments instead of values) + """ + pass + + @abstractmethod + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + pass + + @abstractmethod + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any: + pass + + +@dataclass +class BasicObj(XObject): + """ + An xobject for a simple type (e.g string, int) + """ + + type_str: str + convert_fn: Callable[[Any], str] + validate_fn: Callable[[Any], bool] + parse_fn: Callable[[ObjectifiedElement], Any] + + def xsd_out( + self, + name: str, + attribs: dict[str, Any] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + # NOTE: namespace cringe: + # - lxml will deal with qualifying namespaces for the name of the + # element, but not for attributes + # - XMLSchema type attributes must be qualified + if (prefix := firstkey(add_ns, XMLURL)) is not None: + return Element( + f"{XMLSchema}element", + name=name, + type=f"{prefix}:{self.type_str}", + attrib=attribs, + ) + else: + # add new namespace, resolve conflicts with extra 's' + new_ns = "xs" + while new_ns in add_ns: + new_ns += "s" + add_ns[new_ns] = XMLURL + return Element( + f"{XMLSchema}element", + name=name, + type=f"{new_ns}:{self.type_str}", + attrib=attribs, + nsmap={new_ns: XMLURL}, + ) + + def xml_temp(self, name: str) -> _Element: + return with_text(Element(name), f"Fill me with an {self.type_str}") + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + if not self.validate_fn(val): + raise ErrorTypes.InvalidData(ctx, val, self.type_str) + return with_text(Element(name), self.convert_fn(val)) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any: + try: + return self.parse_fn(obj) + except Exception as e: + raise ErrorTypes.ParseFailure(ctx, obj.text, self.type_str, e) + + +@dataclass +class ListObj(XObject): + """ + An ordered list of objects + """ + + item_xobject: XObject + list_elem_name: str + struct_name: str = "List" + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return with_child( + Element(f"{XMLSchema}element", name=name, attrib=attribs), + with_children( + Element(f"{XMLSchema}complexType"), + [ + Comment(f"This is a {self.struct_name}"), + with_child( + Element(f"{XMLSchema}sequence"), + self.item_xobject.xsd_out( + self.list_elem_name, + {"minOccurs": "0", "maxOccurs": "unbounded"}, + add_ns, + ), + ), + ], + ), + ) + + def xml_temp(self, name: str) -> _Element: + return with_children( + Element(name), + [ + Comment(f"This is a {self.struct_name}"), + self.item_xobject.xml_temp(self.list_elem_name), + ], + ) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + if len(val) > 0: + return with_children( + Element(name), + [ + self.item_xobject.xml_out( + self.list_elem_name, + item_val, + ctx.next(f"{self.list_elem_name}[{i}]"), + ) + for i, item_val in enumerate(val) + ], + ) + else: + return with_child( + Element(name), Comment(f"Empty {self.struct_name}!") + ) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> list[Any]: + parsed = [] + for i, child in enumerate(children(obj)): + if child.tag != self.list_elem_name: + raise ErrorTypes.UnexpectedTag( + ctx, self.list_elem_name, self.struct_name, child.tag + ) + else: + parsed.append( + self.item_xobject.xml_in( + child, ctx.next(f"{self.list_elem_name}[{i}]") + ) + ) + return parsed + + +@dataclass +class StructObj(XObject): + """An order list of key-value pairs""" # TODO: make objects variable length tuple + + objects: list[tuple[str, XObject]] + struct_name: str = "Struct" + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return with_child( + Element(f"{XMLSchema}element", name=name, attrib=attribs), + with_child( + Element(f"{XMLSchema}complexType"), + with_children( + Element(f"{XMLSchema}sequence"), + [Comment(f"This is a {self.struct_name}")] + + [ + xobj.xsd_out(member, {}, add_ns) + for member, xobj in self.objects + ], + ), + ), + ) + + def xml_temp(self, name: str) -> _Element: + return with_children( + Element(name), + [Comment(f"This is a {self.struct_name}")] + + [xobj.xml_temp(member) for member, xobj in self.objects], + ) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + if len(val) != len(self.objects): + raise ErrorTypes.IncorrectType( + ctx, len(self.objects), self.struct_name, val, name + ) + + return with_children( + Element(name), + [ + xobj.xml_out(member, v, ctx.next(member)) + for (member, xobj), v in zip(self.objects, val) + ], + ) + + def xml_in( + self, obj: ObjectifiedElement, ctx: XErrorCtx + ) -> list[tuple[str, Any]]: + parsed = [] + for i, (child, (name, xobj)) in enumerate( + zip(children(obj), self.objects) + ): + if child.tag != name: + raise ErrorTypes.IncorrectElementTag( + ctx, self.struct_name, obj.tag, i, name, child.tag + ) + parsed.append((name, xobj.xml_in(child, ctx.next(name)))) + return parsed + + +class TupleObj(XObject): + """An anonymous struct""" + + def __init__( + self, + objects: tuple[XObject, ...], + elem_gen: Callable[[int], str] = lambda i: f"Item-{i+1}", + ): + self.elem_gen = elem_gen + self.struct: StructObj = StructObj( + [(self.elem_gen(i), xobj) for i, xobj in enumerate(objects)], + struct_name="Tuple", + ) + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return self.struct.xsd_out(name, attribs, add_ns) + + def xml_temp(self, name: str) -> _Element: + return self.struct.xml_temp(name) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + return self.struct.xml_out(name, val, ctx) + + def xml_in( + self, obj: ObjectifiedElement, ctx: XErrorCtx + ) -> tuple[Any, ...]: + # Assumes the objects are in the correct order + return tuple(zip(*self.struct.xml_in(obj, ctx)))[1] # type: ignore[no-any-return] + + +class SetOBj(XObject): + """An unordered collection of unique elements""" + + def __init__(self, inner: XObject, elem_name: str = "setitem"): + self.list = ListObj(inner, elem_name, struct_name="set") + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return self.list.xsd_out(name, attribs, add_ns) + + def xml_temp(self, name: str) -> _Element: + return self.list.xml_temp(name) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + return self.list.xml_out(name, list(val), ctx) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> set[Any]: + parsed: set[Any] = set() + for item in self.list.xml_in(obj, ctx): + if item in parsed: + raise ErrorTypes.DuplicateItem(ctx, "set", obj.tag, item) + parsed.add(item) + return parsed + + +@dataclass +class DictObj(XObject): + """An unordered collection of key-value pair elements""" + + key_xobject: XObject + val_xobject: XObject + key_name: str = "Key" + val_name: str = "Val" + item_name: str = "Item" + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return with_child( + Element(f"{XMLSchema}element", name=name, attrib=attribs), + with_children( + Element(f"{XMLSchema}complexType"), + [ + Comment("this is a dictionary!"), + with_child( + Element(f"{XMLSchema}sequence"), + with_child( + Element( + f"{XMLSchema}element", + name=self.item_name, + minOccurs="0", + maxOccurs="unbounded", + ), + with_child( + Element(f"{XMLSchema}complexType"), + with_children( + Element(f"{XMLSchema}sequence"), + [ + self.key_xobject.xsd_out( + self.key_name, {}, add_ns + ), + self.val_xobject.xsd_out( + self.val_name, {}, add_ns + ), + ], + ), + ), + ), + ), + ], + ), + ) + + def xml_temp(self, name: str) -> _Element: + return with_children( + Element(name), + [ + Comment("This is a dictionary"), + with_children( + Element(self.item_name), + [ + self.key_xobject.xml_temp(self.key_name), + self.val_xobject.xml_temp(self.val_name), + ], + ), + ], + ) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + item_ctx = ctx.next(self.item_name) + + return with_children( + Element(name), + [ + with_children( + Element(self.item_name), + [ + self.key_xobject.xml_out( + self.key_name, k, item_ctx.next(name) + ), + self.val_xobject.xml_out( + self.val_name, v, item_ctx.next(name) + ), + ], + ) + for k, v in val.items() + ], + ) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> dict[Any, Any]: + parsed = {} + for child in children(obj): + if child.tag != self.item_name: + raise ErrorTypes.InvalidDictionaryItem( + ctx, + self.item_name, + self.key_name, + self.val_name, + child.tag, + obj.tag, + ) + else: + child_ctx = ctx.next(self.item_name) + k = self.key_xobject.xml_in( + get(child, self.key_name), child_ctx.next(self.key_name) + ) + v = self.val_xobject.xml_in( + get(child, self.val_name), child_ctx.next(self.val_name) + ) + + if k in parsed: + raise ErrorTypes.DuplicateItem( + ctx, "dictionary", obj.tag, k + ) + + parsed[k] = v + # TODO: Check for other tags? Fail better? + return parsed + + +def resolve_type(v: Any) -> AnyType: + """Determine the type of some value, using primitive types + - If empty container, only provide top container type + INV: only generic types for v are {tuple, list, dict, set} + """ + t = type(v) + if t in {int, float, str, bool, NoneType}: + return t + elif t == dict and len(v) > 0: + t0, t1 = next(iter(v.items())) + return dict[resolve_type(t0), resolve_type(t1)] # type: ignore[misc, index, no-any-return] + elif t == list and len(v) > 0: + return list[resolve_type(v[0])] # type: ignore[misc, index, no-any-return] + elif t == set and len(v) > 0: + return set[resolve_type(next(iter(v)))] # type: ignore[misc, index, no-any-return] + elif t == tuple and len(v) > 0: + return tuple[*(resolve_type(vi) for vi in v)] # type: ignore[misc, no-any-return] + else: + # INV: non-generic type + return t + + +@dataclass +class UnionObj(XObject): + """A variant, can be one of several different types""" + + xobjects: dict[AnyType, XObject] + elem_gen: Callable[[AnyType], str] = lambda t: pascalize(typename(t)) + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return with_child( + Element(f"{XMLSchema}element", name=name, attrib=attribs), + with_children( + Element(f"{XMLSchema}complexType"), + [ + Comment("this is a union!"), + with_children( + Element(f"{XMLSchema}sequence"), + [ + xobj.xsd_out( + self.elem_gen(t), {"minOccurs": "0"}, add_ns + ) + for t, xobj in self.xobjects.items() + ], + ), + ], + ), + ) + + def xml_temp(self, name: str) -> _Element: + return with_children( + Element(name), + [ + Comment( + "This is a union, the following variants are possible, only one can be present" + ) + ] + + [ + xobj.xml_temp(self.elem_gen(t)) + for t, xobj in self.xobjects.items() + ], + ) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + t = resolve_type(val) + + if (val_xobj := self.xobjects.get(t)) is not None: + variant_name = self.elem_gen(t) + return with_child( + Element(name), + val_xobj.xml_out(variant_name, val, ctx.next(variant_name)), + ) + else: + raise ErrorTypes.InvalidVariant( + ctx, name, list(self.xobjects.keys()), t, val + ) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any: + named = {self.elem_gen(t): xobj for t, xobj in self.xobjects.items()} + variants = list(children(obj)) + + if len(variants) != 1: + raise ErrorTypes.MultipleVariants(ctx, [v.tag for v in variants]) + + variant = variants[0] + if (xobj := named.get(variant.tag)) is not None: + return xobj.xml_in(variant, ctx.next(variant.tag)) + else: + raise ErrorTypes.ParseInvalidVariant( + ctx, str(obj.tag), list(named.keys()), str(variant) + ) + + +class NoneObj(XObject): + """ + An object representing the python 'None' type + - Unions of form `int | None` are used for optionals + """ + + def xsd_out( + self, + name: str, + attribs: dict[str, str] = {}, + add_ns: dict[str, str] = {}, + ) -> _Element: + return with_child( + Element(f"{XMLSchema}element", name=name, attrib=attribs), + Comment("This is a None type"), + ) + + def xml_temp(self, name: str) -> _Element: + return with_child(Element(name), Comment("This is None")) + + def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element: + if val != None: + raise ErrorTypes.NoneIsSome(ctx, name, val) + + return with_child(Element(name), Comment("This is None")) + + def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any: + return None + + +def is_xmlified(cls): + return ( + hasattr(cls, "xsd_forward") + and hasattr(cls, "xsd_dependencies") + and hasattr(cls, "get_xobject") + and hasattr(cls, "xsd") + and hasattr(cls, "xml") + and hasattr(cls, "xml_value") + and hasattr(cls, "parse") + ) + + +def gen_xobject(data_type: AnyType, forward_dec: set[AnyType]) -> XObject: + basic_types: dict[ + AnyType, tuple[str, Callable[[Any], str], Callable[[Any], bool]] + ] = { + int: ("integer", str, lambda d: type(d) == int), + str: ("string", str, lambda d: type(d) == str), + float: ("decimal", str, lambda d: type(d) == float), + bool: ( + "boolean", + lambda b: "true" if b else "false", + lambda d: type(d) == bool, + ), + } + + if (basic_entry := basic_types.get(data_type)) is not None: + type_str, convert_fn, validate_fn = basic_entry + # NOTE: here was can pass the parse_fn as the data type, as the name is + # also a constructor. (e.g. `int` -> `int("23") == 32`) + parse_fn = cast(Callable[[ObjectifiedElement], Any], data_type) + return BasicObj(type_str, convert_fn, validate_fn, parse_fn) + elif isinstance(data_type, NoneType) or data_type == NoneType: + # NOTE: Python typing cringe: None can be both a type and a value + # (even when within a type hint!) + # a: list[None] -> None is an instance of NoneType + # a: int | None -> Union of int and NoneType + return NoneObj() + elif isinstance(data_type, UnionType): + return UnionObj( + {t: gen_xobject(t, forward_dec) for t in get_args(data_type)} + ) + else: + t_name = typename(data_type) + if t_name == "list": + (item_type,) = get_args(data_type) + return ListObj( + gen_xobject(item_type, forward_dec), + pascalize(typename(item_type)), + ) + elif t_name == "dict": + key_type, val_type = get_args(data_type) + return DictObj( + gen_xobject(key_type, forward_dec), + gen_xobject(val_type, forward_dec), + ) + elif t_name == "tuple": + return TupleObj( + tuple(gen_xobject(t, forward_dec) for t in get_args(data_type)) + ) + elif t_name == "set": + (item_type,) = get_args(data_type) + return SetOBj( + gen_xobject(item_type, forward_dec), + pascalize(typename(item_type)), + ) + else: + if is_xmlified(data_type): + forward_dec.add(data_type) + return data_type.get_xobject() # type: ignore[attr-defined, no-any-return] + else: + raise ErrorTypes.NonXMlifiedType(t_name) diff --git a/third_party/xmlable/py.typed b/third_party/xmlable/py.typed new file mode 100644 index 0000000..e69de29