fixing docs

This commit is contained in:
dullfig 2026-01-08 15:35:36 -08:00
parent ab207d8f0b
commit a1e1b9a1c0
15 changed files with 1775 additions and 198 deletions

View file

@ -1,158 +1,226 @@
# agentserver/bus.py
# Refactored January 01, 2026 MessageBus with run() pump and out-of-band shutdown
"""
bus.py — The central MessageBus and pump for AgentServer v2.1
This is the beating heart of the organism:
- Owns all pipelines (one per listener + permanent system pipeline)
- Maintains the routing table (root_tag → Listener(s))
- Orchestrates ingress from sockets/gateways
- Dispatches prepared messages to handlers
- Processes handler responses (multi-payload extraction, provenance injection)
- Guarantees thread continuity and diagnostic injection
Fully aligned with:
- listener-class-v2.1.md
- configuration-v2.1.md
- message-pump-v2.1.md
"""
from __future__ import annotations
import asyncio
import logging
from typing import AsyncIterator, Callable, Dict, Optional, Awaitable
from dataclasses import dataclass
from typing import Callable, Awaitable, List
from uuid import uuid4
from lxml import etree
from agentserver.xml_listener import XMLListener
from agentserver.utils.message import repair_and_canonicalize, XmlTamperError
from agentserver.message_bus.message_state import MessageState, HandlerMetadata
from agentserver.message_bus.steps.repair import repair_step
from agentserver.message_bus.steps.c14n import c14n_step
from agentserver.message_bus.steps.envelope_validation import envelope_validation_step
from agentserver.message_bus.steps.payload_extraction import payload_extraction_step
from agentserver.message_bus.steps.thread_assignment import thread_assignment_step
from agentserver.message_bus.steps.xsd_validation import xsd_validation_step
from agentserver.message_bus.steps.deserialization import deserialization_step
from agentserver.message_bus.steps.routing_resolution import routing_resolution_step
# Constants for Internal Physics
ENV_NS = "https://xml-pipeline.org/ns/envelope/1"
ENV = f"{{{ENV_NS}}}"
LOG_TAG = "{https://xml-pipeline.org/ns/logger/1}log"
# Type alias for pipeline steps
PipelineStep = Callable[[MessageState], Awaitable[MessageState]]
logger = logging.getLogger("agentserver.bus")
@dataclass
class Listener:
    """Registered capability — defined in listener.py, referenced here.

    One record per registered listener; MessageBus.register_listener stores it
    in the routing table and attaches its dedicated pipeline.
    """

    # Listener identity; lower-cased to form the first half of the root tag.
    name: str
    # Payload class; its lower-cased __name__ forms the second half of the root tag.
    payload_class: type
    # Callable awaited by the bus with (payload, metadata).
    handler: Callable
    # Free-text description (not read anywhere in this module).
    description: str
    # When true, the dispatcher passes the listener's own name in HandlerMetadata.
    is_agent: bool = False
    # NOTE(review): presumably the names this listener may address — confirm in listener.py.
    peers: list[str] | None = None
    # When true, registering onto an already-claimed root tag is permitted.
    broadcast: bool = False
    # Dedicated pipeline, assigned by MessageBus.register_listener.
    pipeline: "Pipeline" | None = None
    # cached at registration
    schema: etree.XMLSchema | None = None
class Pipeline:
    """One dedicated pipeline per listener (plus system pipeline).

    A pipeline is an ordered list of async steps; each step receives the
    current MessageState and returns the (possibly updated) state.
    """

    def __init__(self, steps: List[PipelineStep]):
        """Store the ordered steps that process() will run."""
        self.steps = steps

    async def process(self, initial_state: MessageState) -> None:
        """Run the full ordered pipeline on a message, then hand it off.

        Steps run in order until one sets ``state.error`` or raises; a raised
        exception is recorded in ``state.error`` instead of propagating.

        Afterwards:
        - an error-free state with ``target_listeners`` goes to the bus dispatcher;
        - anything else falls back to the system pipeline for diagnostics;
        - if this *is* the system pipeline, processing simply ends.
        """
        state = initial_state
        for step in self.steps:
            try:
                state = await step(state)
                if state.error:
                    break
            except Exception as exc:  # pylint: disable=broad-except
                # Not every callable carries __name__ (e.g. functools.partial).
                step_name = getattr(step, "__name__", repr(step))
                state.error = f"Pipeline step {step_name} crashed: {exc}"
                break

        bus = MessageBus.get_instance()
        if state.target_listeners and not state.error:
            # Only clean, routable messages may reach a handler; targets can be
            # pre-set by the caller, so a failed step must still block dispatch.
            await bus.dispatcher(state)
        elif self is not bus.system_pipeline:
            # Fall back to system pipeline for diagnostics
            await bus.system_pipeline.process(state)
        # else: already in the system pipeline. Its steps never set targets, so
        # falling back again would recurse without end.
class MessageBus:
"""The sovereign message carrier.
"""Singleton message bus — the pump."""
_instance: "MessageBus" | None = None
- Routes canonical XML trees by root tag and <to/> meta.
- Pure dispatch: tree optional response tree.
- Active pump via run(): handles serialization and egress.
- Out-of-band shutdown via asyncio.Event (fast-path, flood-immune).
"""
def __init__(self):
self.routing_table: dict[str, List[Listener]] = {} # root_tag → listener(s)
self.listeners: dict[str, Listener] = {} # name → Listener
self.system_pipeline = Pipeline(self._build_system_steps())
def __init__(self, log_hook: Callable[[etree._Element], None]):
# root_tag -> {agent_name -> XMLListener}
self.listeners: Dict[str, Dict[str, XMLListener]] = {}
# Global lookup for directed <to/> routing
self.global_names: Dict[str, XMLListener] = {}
@classmethod
def get_instance(cls) -> "MessageBus":
if cls._instance is None:
cls._instance = MessageBus()
return cls._instance
# The Sovereign Witness hook
self._log_hook = log_hook
# Out-of-band shutdown signal (set only by AgentServer on privileged command)
self.shutdown_event = asyncio.Event()
async def register_listener(self, listener: XMLListener) -> None:
"""Register an organ. Enforces global identity uniqueness."""
if listener.agent_name in self.global_names:
raise ValueError(f"Identity collision: {listener.agent_name}")
self.global_names[listener.agent_name] = listener
for tag in listener.listens_to:
tag_dict = self.listeners.setdefault(tag, {})
tag_dict[listener.agent_name] = listener
logger.info(f"Registered organ: {listener.agent_name}")
async def deliver_bytes(self, raw_xml: bytes, client_id: Optional[str] = None) -> None:
"""Air Lock: ingest raw bytes, repair/canonicalize, inject into core."""
try:
envelope_tree = repair_and_canonicalize(raw_xml)
await self.dispatch(envelope_tree, client_id)
except XmlTamperError as e:
logger.warning(f"Air Lock Reject: {e}")
async def dispatch(
self,
envelope_tree: etree._Element,
client_id: Optional[str] = None,
) -> etree._Element | None:
"""Pure routing heart. Returns validated response tree or None."""
# 1. WITNESS every canonical envelope is seen
self._log_hook(envelope_tree)
# 2. Extract envelope metadata
meta = envelope_tree.find(f"{ENV}meta")
if meta is None:
return None
from_name = meta.findtext(f"{ENV}from")
to_name = meta.findtext(f"{ENV}to")
thread_id = meta.findtext(f"{ENV}thread_id") or meta.findtext(f"{ENV}thread")
# Find payload (first non-meta child)
payload_elem = next((c for c in envelope_tree if c.tag != f"{ENV}meta"), None)
if payload_elem is None:
return None
payload_tag = payload_elem.tag
# 3. AUTONOMIC REFLEX: explicit <log/>
if payload_tag == LOG_TAG:
self._log_hook(envelope_tree) # extra vent
# Minimal ack envelope
ack = etree.Element(f"{ENV}message")
meta_ack = etree.SubElement(ack, f"{ENV}meta")
etree.SubElement(meta_ack, f"{ENV}from").text = "system"
if from_name:
etree.SubElement(meta_ack, f"{ENV}to").text = from_name
if thread_id:
etree.SubElement(meta_ack, f"{ENV}thread_id").text = thread_id
etree.SubElement(ack, "logged", status="success")
return ack
# 4. ROUTING
listeners_for_tag = self.listeners.get(payload_tag, {})
response_tree: Optional[etree._Element] = None
responding_agent_name = "unknown"
if to_name:
# Directed
target = listeners_for_tag.get(to_name) or self.global_names.get(to_name)
if target:
responding_agent_name = target.agent_name
response_tree = await target.handle(envelope_tree, thread_id, from_name or client_id)
else:
# Broadcast first non-None wins (current policy)
tasks = [
l.handle(envelope_tree, thread_id, from_name or client_id)
for l in listeners_for_tag.values()
# ------------------------------------------------------------------ #
# Default step lists
# ------------------------------------------------------------------ #
def _build_default_listener_steps(self) -> List[PipelineStep]:
    """Return a new list holding the standard listener steps in execution order."""
    ordered_steps: List[PipelineStep] = [
        repair_step,
        c14n_step,
        envelope_validation_step,
        payload_extraction_step,
        thread_assignment_step,
        xsd_validation_step,
        deserialization_step,
        routing_resolution_step,
    ]
    return ordered_steps
results = await asyncio.gather(*tasks, return_exceptions=True)
for listener, result in zip(listeners_for_tag.values(), results):
if isinstance(result, etree._Element):
responding_agent_name = listener.agent_name
response_tree = result
break # first-wins
# 5. IDENTITY INSPECTION prevent spoofing
if response_tree is not None:
actual_from = response_tree.findtext(f"{ENV}meta/{ENV}from")
if actual_from != responding_agent_name:
logger.critical(
f"IDENTITY THEFT BLOCKED: expected {responding_agent_name}, got {actual_from}"
def _build_system_steps(self) -> List[PipelineStep]:
    """Return the system pipeline's steps: shorter and fixed, with no XSD/deserialization."""
    envelope_steps: List[PipelineStep] = [
        repair_step,
        c14n_step,
        envelope_validation_step,
        payload_extraction_step,
        thread_assignment_step,
    ]
    # system-specific handler that emits <huh>, boot, etc.
    return envelope_steps + [self.system_handler_step]
# ------------------------------------------------------------------ #
# Registration (called from listener.py)
# ------------------------------------------------------------------ #
def register_listener(self, listener: Listener) -> None:
    """Register a listener under its derived root tag and give it a pipeline.

    The root tag is "<listener name>.<payload class name>", both lower-cased.

    Raises:
        ValueError: the root tag is already registered and ``listener`` is
            not flagged ``broadcast``.
    """
    root_tag = f"{listener.name.lower()}.{listener.payload_class.__name__.lower()}"
    already_registered = self.routing_table.get(root_tag)
    if already_registered and not listener.broadcast:
        raise ValueError(f"Root tag collision: {root_tag} already registered by {already_registered[0].name}")

    # Build dedicated pipeline.
    # The listener-specific schema is not injected here; xsd_validation_step
    # reads it from state.metadata["schema"].
    listener.pipeline = Pipeline(self._build_default_listener_steps())

    # Insert into routing
    self.routing_table.setdefault(root_tag, []).append(listener)
    self.listeners[listener.name] = listener
# ------------------------------------------------------------------ #
# Dispatcher — dumb fire-and-await
# ------------------------------------------------------------------ #
async def dispatcher(self, state: MessageState) -> None:
if not state.target_listeners:
return
metadata = HandlerMetadata(
thread_id=state.thread_id or "",
from_id=state.from_id or "unknown",
own_name=state.target_listeners[0].name if state.target_listeners[0].is_agent else None,
is_self_call=(state.from_id == state.target_listeners[0].name) if state.from_id else False,
)
return None
return response_tree
if len(state.target_listeners) == 1:
listener = state.target_listeners[0]
await self._process_single_handler(state, listener, metadata)
else:
# Broadcast — fire all in parallel, process responses as they complete
tasks = [
self._process_single_handler(state, listener, metadata)
for listener in state.target_listeners
]
for future in asyncio.as_completed(tasks):
await future
async def run(
self,
inbound: AsyncIterator[etree._Element],
outbound: Callable[[bytes], Awaitable[None]],
client_id: Optional[str] = None,
) -> None:
"""Active pump for a connection. Handles serialization and egress."""
async def _process_single_handler(self, state: MessageState, listener: Listener, metadata: HandlerMetadata) -> None:
try:
async for envelope_tree in inbound:
if self.shutdown_event.is_set():
break
response_bytes = await listener.handler(state.payload, metadata)
response_tree = await self.dispatch(envelope_tree, client_id)
if response_tree is not None:
serialized = etree.tostring(
response_tree, encoding="utf-8", pretty_print=True
if response_bytes is None or not isinstance(response_bytes, bytes):
response_bytes = b"<huh>Handler failed to return valid bytes — missing return or wrong type</huh>"
payloads = await self._multi_payload_extract(response_bytes)
for payload_bytes in payloads:
new_state = MessageState(
raw_bytes=payload_bytes,
thread_id=state.thread_id,
from_id=listener.name,
)
await outbound(serialized)
finally:
# Optional final courtesy message on clean exit
goodbye = b"<message xmlns='https://xml-pipeline.org/ns/envelope/1'><goodbye reason='connection-closed'/></message>"
# Route the new payload through normal pipelines
root_tag = self._derive_root_tag(payload_bytes)
targets = self.routing_table.get(root_tag)
if targets:
new_state.target_listeners = targets
await targets[0].pipeline.process(new_state)
else:
await self.system_pipeline.process(new_state)
except Exception as exc: # pylint: disable=broad-except
error_state = MessageState(
raw_bytes=b"<huh>Handler crashed</huh>",
thread_id=state.thread_id,
from_id=listener.name,
error=f"Handler {listener.name} crashed: {exc}",
)
await self.system_pipeline.process(error_state)
# ------------------------------------------------------------------ #
# Helper methods
# ------------------------------------------------------------------ #
async def _multi_payload_extract(self, raw_bytes: bytes) -> List[bytes]:
# Same logic as before — dummy wrap, repair, extract all root elements
# (implementation can be moved to a shared util later)
# For now, placeholder — we'll flesh this out in response_processing.py
return [raw_bytes] # temporary — will be full extraction
def _derive_root_tag(self, payload_bytes: bytes) -> str:
# Quick parse to get root tag — used only for routing extracted payloads
try:
await outbound(goodbye)
tree = etree.fromstring(payload_bytes)
tag = tree.tag
if tag.startswith("{"):
return tag.split("}", 1)[1] # strip namespace
return tag
except Exception:
pass # connection already gone
return ""
async def system_handler_step(self, state: MessageState) -> MessageState:
    """Final system-pipeline step: make sure the state carries an error.

    Placeholder for emitting <huh> or a boot message. An existing error is
    kept; otherwise a generic "unhandled" error is recorded.
    """
    if not state.error:
        state.error = "Unhandled by any listener"
    return state

View file

@ -0,0 +1,45 @@
"""
deserialization.py — Convert validated payload_tree into typed dataclass instance.
After xsd_validation_step confirms the payload conforms to the listener's contract,
this step uses the xmlable library to deserialize the lxml Element into the
registered @xmlify dataclass.
The resulting instance is placed in state.payload and handed to the handler.
Part of AgentServer v2.1 message pump.
"""
from xmlable import from_xml # from the xmlable library
from agentserver.message_bus.message_state import MessageState
async def deserialization_step(state: MessageState) -> MessageState:
    """
    Turn the validated payload_tree into an instance of the listener's dataclass.

    Reads:
      - state.payload_tree — the payload element
      - state.metadata["payload_class"] — the target dataclass

    On success: state.payload holds the dataclass instance.
    On failure: state.error holds a message; nothing is raised.
    """
    tree = state.payload_tree
    if tree is None:
        state.error = "deserialization_step: no payload_tree (previous step failed)"
        return state

    target_class = state.metadata.get("payload_class")
    if target_class is None:
        state.error = "deserialization_step: no payload_class in metadata (listener misconfigured)"
        return state

    try:
        # xmlable.from_xml handles namespace-aware deserialization
        state.payload = from_xml(target_class, tree)
    except Exception as exc:  # pylint: disable=broad-except
        state.error = f"deserialization_step failed: {exc}"
    return state

View file

@ -0,0 +1,50 @@
"""
routing_resolution.py — Resolve routing based on derived root tag.
This is the final preparation step before dispatch.
It computes the root tag from the deserialized payload and looks it up in the
global routing table (root_tag → list[Listener]).
On success: state.target_listeners is set
On failure: state.error is set — message falls to system pipeline for <huh>
Part of AgentServer v2.1 message pump.
"""
from agentserver.message_bus.message_state import MessageState
from agentserver.message_bus.bus import MessageBus
async def routing_resolution_step(state: MessageState) -> MessageState:
    """
    Look up the listener(s) registered for this payload's root tag.

    Root tag = f"{from_id.lower()}.{payload_class_name.lower()}", where
    payload_class_name is the class name of state.payload.

    The routing table maps each root tag to a list, so a single entry yields
    one target and a broadcast entry yields several.

    On success: state.target_listeners is set to the registered list.
    On failure (no payload, no from_id, or unknown root tag): state.error is set.
    """
    if state.payload is None:
        state.error = "routing_resolution_step: no deserialized payload (previous step failed)"
        return state

    if state.from_id is None:
        state.error = "routing_resolution_step: missing from_id (provenance error)"
        return state

    class_part = type(state.payload).__name__.lower()
    root_tag = f"{state.from_id.lower()}.{class_part}"

    targets = MessageBus.get_instance().routing_table.get(root_tag)
    if targets:
        state.target_listeners = targets
    else:
        state.error = f"routing_resolution_step: unknown capability root tag '{root_tag}'"
    return state

View file

@ -1,13 +1,15 @@
"""
payload_extraction.py Extract the inner payload from the validated <message> envelope.
xsd_validation.py Validate the extracted payload against the listener-specific XSD.
After envelope_validation_step confirms a correct outer <message> envelope,
this step removes the envelope elements (<thread>, <from>, optional <to>, etc.)
and isolates the single child element that is the actual payload.
After payload_extraction_step isolates the payload_tree and provenance,
this step validates the payload against the XSD that was auto-generated
from the listener's @xmlify dataclass at registration time.
The payload is expected to be exactly one root element (the capability-specific XML).
If zero or multiple payload roots are found, we set a clear error this protects
against malformed or ambiguous messages.
The XSD is cached and pre-loaded. The schema object is injected into
state.metadata["schema"] when the listener's pipeline is built.
Failure here means the payload violates the declared contract we collect
detailed errors for diagnostics.
Part of AgentServer v2.1 message pump.
"""
@ -15,77 +17,43 @@ Part of AgentServer v2.1 message pump.
from lxml import etree
from agentserver.message_bus.message_state import MessageState
# Envelope namespace for easy reference
_ENVELOPE_NS = "https://xml-pipeline.org/ns/envelope/v1"
_MESSAGE_TAG = f"{{{ _ENVELOPE_NS }}}message"
async def payload_extraction_step(state: MessageState) -> MessageState:
async def xsd_validation_step(state: MessageState) -> MessageState:
"""
Extract the single payload element from the validated envelope.
Validate state.payload_tree against the listener's cached XSD schema.
Expected structure:
<message xmlns="https://xml-pipeline.org/ns/envelope/v1">
<thread>uuid</thread>
<from>sender</from>
<!-- optional <to>receiver</to> -->
<payload_root> this is the one we want
...
</payload_root>
</message>
Requires:
- state.payload_tree set
- state.metadata["schema"] containing a pre-loaded etree.XMLSchema
On success: state.payload_tree is set to the payload Element.
On failure: state.error is set with a clear diagnostic.
On success: payload is guaranteed to match the contract
On failure: state.error contains detailed validation messages
"""
if state.envelope_tree is None:
state.error = "payload_extraction_step: no envelope_tree (previous step failed)"
if state.payload_tree is None:
state.error = "xsd_validation_step: no payload_tree (previous extraction failed)"
return state
# Basic sanity — root must be <message> in correct namespace (already checked by schema,
# but we double-check for defence in depth)
if state.envelope_tree.tag != _MESSAGE_TAG:
state.error = f"payload_extraction_step: root tag is not <message> in envelope namespace"
schema = state.metadata.get("schema")
if schema is None:
state.error = "xsd_validation_step: no XSD schema in metadata (listener pipeline misconfigured)"
return state
# Find all direct children that are not envelope control elements
# Envelope control elements are: thread, from, to (optional)
payload_candidates = [
child
for child in state.envelope_tree
if not (
child.tag in {
f"{{{ _ENVELOPE_NS }}}thread",
f"{{{ _ENVELOPE_NS }}}from",
f"{{{ _ENVELOPE_NS }}}to",
}
)
]
if len(payload_candidates) == 0:
state.error = "payload_extraction_step: no payload element found inside <message>"
if not isinstance(schema, etree.XMLSchema):
state.error = "xsd_validation_step: metadata['schema'] is not an XMLSchema object"
return state
if len(payload_candidates) > 1:
state.error = (
"payload_extraction_step: multiple payload roots found — "
"exactly one capability payload element is allowed"
)
return state
try:
# assertValid raises DocumentInvalid with full error log
schema.assertValid(state.payload_tree)
# Success — exactly one payload element
payload_element = payload_candidates[0]
except etree.DocumentInvalid:
# Collect all errors for clear diagnostics
error_lines = []
for error in schema.error_log:
error_lines.append(f"{error.level_name}: {error.message} (line {error.line})")
state.error = "xsd_validation_step: payload failed contract validation\n" + "\n".join(error_lines)
# Optional: capture provenance from envelope for later use
# (these will be trustworthy because envelope was validated)
thread_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}thread")
from_elem = state.envelope_tree.find(f"{{{ _ENVELOPE_NS }}}from")
if thread_elem is not None and thread_elem.text:
state.thread_id = thread_elem.text.strip()
if from_elem is not None and from_elem.text:
state.from_id = from_elem.text.strip()
state.payload_tree = payload_element
except Exception as exc: # pylint: disable=broad-except
state.error = f"xsd_validation_step: unexpected error during validation: {exc}"
return state

View file

@ -19,7 +19,16 @@ xml-pipeline/
│ ├── message_bus/
│ │ ├── steps/
│ │ │ ├── __init__.py
│ │ │ └── repair_step.py
│ │ │ ├── c14n.py
│ │ │ ├── deserialization.py
│ │ │ ├── envelope_validation.py
│ │ │ ├── payload_extraction.py
│ │ │ ├── repair.py
│ │ │ ├── routing_resolution.py
│ │ │ ├── test_c14n.py
│ │ │ ├── test_repair.py
│ │ │ ├── thread_assignment.py
│ │ │ └── xsd_validation.py
│ │ ├── __init__.py
│ │ ├── bus.py
│ │ ├── config.py
@ -48,14 +57,23 @@ xml-pipeline/
│ │ └── token-scheduling-issues.md
│ ├── configuration.md
│ ├── core-principles-v2.1.md
│ ├── doc_cross_check.md
│ ├── handler-contract-v2.1.md
│ ├── listener-class-v2.1.md
│ ├── message-pump-v2.1.md
│ ├── primitives.md
│ ├── self-grammar-generation.md
│ └── why-not-json.md
├── tests/
│ ├── scripts/
│ │ └── generate_organism_key.py
│ └── __init__.py
├── xml_pipeline.egg-info/
│ ├── PKG-INFO
│ ├── SOURCES.txt
│ ├── dependency_links.txt
│ ├── requires.txt
│ └── top_level.txt
├── LICENSE
├── README.md
├── __init__.py

0
third_party/xmlable/__init__.py vendored Normal file
View file

261
third_party/xmlable/_errors.py vendored Normal file
View file

@ -0,0 +1,261 @@
"""
Colourful & descriptive errors for xmlable
- Clear messages
- Trace for parsing
"""
from dataclasses import dataclass
from typing import Any, Iterable
from termcolor import colored
from termcolor.termcolor import Color
from xmlable._utils import typename, AnyType
def trace_note(trace: list[str], arrow_c: Color, node_c: Color):
    """Render a trace as coloured node names joined by a coloured " > " arrow."""
    arrow = colored(" > ", arrow_c)
    return arrow.join(colored(node, node_c) for node in trace)
@dataclass
class XErrorCtx:
    """Error context: the path of node names leading to the current position."""

    trace: list[str]

    def next(self, node: str):
        """Return a new context one node deeper; this context is left unchanged."""
        extended = self.trace + [node]
        return XErrorCtx(trace=extended)
# TODO: Custom backtrace to point to location in the file
class XError(Exception):
    """Exception whose message and attached notes are colourised.

    The exception message is ``short``; ``what``, ``why``, the optional trace
    from ``ctx`` and any extra ``notes`` are attached with ``add_note``.
    """

    def __init__(
        self,
        short: str,
        what: str,
        why: str,
        ctx: XErrorCtx | None = None,
        notes: Iterable[str] = (),
    ):
        """Build the error.

        Args:
            short: headline used as the exception message.
            what: description of what went wrong.
            why: explanation of the rule that was violated.
            ctx: optional trace context; adds a "Where:" note when given.
            notes: extra notes appended verbatim after the standard ones.
        """
        # The default for `notes` is an immutable tuple, not a shared list.
        super().__init__(colored(short, "red", attrs=["blink"]))
        self.add_note(colored("What: " + what, "blue"))
        self.add_note(colored("Why: " + why, "yellow"))
        if ctx is not None:
            self.add_note(
                colored("Where: ", "magenta")
                + trace_note(ctx.trace, "light_magenta", "light_cyan")
            )
        for note in notes:
            self.add_note(note)
class ErrorTypes:
    """Namespace of factory functions, each building one kind of XError.

    Every factory only constructs and returns the error; raising is left to
    the caller.
    """

    @staticmethod
    def NonXMlifiedType(t_name: str) -> XError:
        """Error for a non-xmlified type used inside an xmlified class."""
        what = f"You attempted to use {t_name} in an xmlified class, but {t_name} is not xmlified"
        why = "All types used in an xmlified class must be xmlified"
        return XError(short="Non XMlified Type", what=what, why=why)

    @staticmethod
    def InvalidData(ctx: XErrorCtx, val: Any, t_name: str) -> XError:
        """Error for a value that could not be validated as ``t_name``."""
        what = f"Could not validate {val} as a valid {t_name}"
        why = "Produced xml must be valid"
        return XError(short="Invalid Data", what=what, why=why, ctx=ctx)

    @staticmethod
    def ParseFailure(
        ctx: XErrorCtx, text: str | None, t_name: str, caught: Exception
    ) -> XError:
        """Error wrapping an exception caught while parsing ``text`` as ``t_name``."""
        what = f"Failed to parse {text} as a {t_name} with error: \n {caught}"
        why = "This error implies the xml is not validated against the current xsd, or there is a bug in this type's parser"
        return XError(short="Parse Failure", what=what, why=why, ctx=ctx)

    @staticmethod
    def UnexpectedTag(
        ctx: XErrorCtx, expected_name: str, struct_name: str, tag_found: str
    ) -> XError:
        """Error for a child tag other than the single expected element name."""
        what = f"Expected {expected_name} but found {tag_found}"
        why = f"This is a {struct_name} that contains 0..n elements of {expected_name} and no other elements"
        return XError(short="Unexpected Tag", what=what, why=why, ctx=ctx)

    @staticmethod
    def IncorrectType(
        ctx: XErrorCtx, expected_len: int, struct_name: str, val: Any, name: str
    ) -> XError:
        """Error for a value whose length differs from what the structure takes."""
        what = f"You have provided {len(val)} values {val} for {name}, but {name} is a {struct_name} that takes only {expected_len} values"
        why = "In order to generate xml, the values provided need to be the correct types"
        return XError(short="Incorrect Type", what=what, why=why, ctx=ctx)

    @staticmethod
    def IncorrectElementTag(
        ctx: XErrorCtx,
        struct_name: str,
        tag_name: str,
        elem_index: int,
        tag_expected: str,
        tag_found: str,
    ) -> XError:
        """Error for an element at a given index carrying the wrong tag."""
        what = f"While parsing {struct_name} {tag_name} we expected element {elem_index} to be {tag_expected}, but found {tag_found}"
        why = f"The xml representation for {struct_name} requires the correct names in the correct order"
        return XError(short="Incorrect Element Tag", what=what, why=why, ctx=ctx)

    @staticmethod
    def DuplicateItem(
        ctx: XErrorCtx, struct_name: str, tag: str, item: str
    ) -> XError:
        """Error for an item present more than once."""
        short = f"Duplicate item in {struct_name}"
        what = f"In {tag} the item {item} is present more than once"
        why = "A set can only contain unique items"
        return XError(short=short, what=what, why=why, ctx=ctx)

    @staticmethod
    def InvalidDictionaryItem(
        ctx: XErrorCtx,
        expected_tag: str,
        expected_key: str,
        expected_val: str,
        dict_tag: str,
        item_tag: str,
    ) -> XError:
        """Error for a dictionary entry that does not have the expected shape."""
        what = f"An unexpected item with {dict_tag} is in dictionary {item_tag}"
        why = f"Each item must have tag {expected_tag} with children {expected_key} and {expected_val}"
        return XError(short="Invalid item in dictionary", what=what, why=why, ctx=ctx)

    @staticmethod
    def InvalidVariant(
        ctx: XErrorCtx,
        name: str,
        expected_types: list[AnyType],
        found_type: AnyType | None,
        found_value: Any,
    ) -> XError:
        """Error for a value whose type is not one of the union's members."""
        types = " | ".join(str(member) for member in expected_types)
        what = f"{name} is a union of {types}, which does not contain {found_type} (you provided: {found_value})"
        why = "... uuuh, its a union?"
        return XError(short="Datatype not in Union", what=what, why=why, ctx=ctx)

    @staticmethod
    def MultipleVariants(ctx: XErrorCtx, variant_names: list[str]) -> XError:
        """Error for several union variants being present at once."""
        listed = ", ".join(variant_names)
        what = f"variants {listed} are present"
        why = "A union can only be one variant at a time"
        return XError(short="Multiple union variants present", what=what, why=why, ctx=ctx)

    @staticmethod
    def ParseInvalidVariant(
        ctx: XErrorCtx, tag: str, named_variants: list[str], found_variant: str
    ) -> XError:
        """Error for a parsed variant that the union does not list."""
        listed = ", ".join(named_variants)
        what = f"The union {tag} can contain variants {listed}, but you have used {found_variant}"
        why = "Only valid variants can be parsed"
        return XError(short="Invalid Variant", what=what, why=why, ctx=ctx)

    @staticmethod
    def NoneIsSome(ctx: XErrorCtx, name: str, val: Any) -> XError:
        """Error for a None-typed object holding a value that is not None."""
        what = f"{name} contains value {val} which is not None"
        why = "A None type object can only contain none"
        return XError(short="None object is not None", what=what, why=why, ctx=ctx)

    @staticmethod
    def NotADataclass(cls: AnyType) -> XError:
        """Error for a class that is not a dataclass, with a usage hint attached."""
        cls_name: str = typename(cls)
        hint = f"\nTry:\n@xmlify\n@dataclass\nclass {cls_name}:"
        return XError(
            short="Non-Dataclass",
            what=f"{cls_name} is not a dataclass",
            why="xmlify uses dataclasses to get fields",
            ctx=XErrorCtx([cls_name]),
            notes=[hint],
        )

    @staticmethod
    def ReservedAttribute(cls: AnyType, attr_name: str) -> XError:
        """Error for a class field whose name xmlify uses itself."""
        cls_name: str = typename(cls)
        what = f"{cls_name}.{attr_name} is used by xmlify, so it cannot be a field of the class"
        why = f"xmlify aguments {cls_name} by adding methods it can then use for xsd, xml generation and parsing"
        return XError(short="Reserved Attribute", what=what, why=why, ctx=XErrorCtx([cls_name]))

    @staticmethod
    def CommentAttribute(cls: AnyType) -> XError:
        """Error for a class that uses ``comment`` as an attribute name."""
        cls_name: str = typename(cls)
        what = "xmlifed classes cannot use comment as an attribute"
        why = "comment is used as a tag name for comments by lxml, so comments inserted on xml generation could conflict"
        return XError(short="Comment Attribute", what=what, why=why, ctx=XErrorCtx([cls_name]))

    @staticmethod
    def NonMemberTag(
        ctx: XErrorCtx, cls: AnyType, tag: str, name: str
    ) -> XError:
        """Error for a class member that could not be found in ``tag``."""
        cls_name: str = typename(cls)
        what = f"In {tag} {cls_name}.{name} could not be found."
        why = f"All members, including {cls_name}.{name} must be present"
        return XError(short="Non member tag", what=what, why=why, ctx=ctx)

    @staticmethod
    def MissingAttribute(
        cls: AnyType, required_attrs: set[str], missing_attr: str
    ) -> XError:
        """Error for a class lacking an attribute that manual_xmlify requires."""
        cls_name: str = typename(cls)
        listed = ", ".join(required_attrs)
        what = f"The attribute {missing_attr} is missing from {cls_name}"
        why = f"To be manual_xmlified the attributes: {listed} are required. Try using help(IXmlify)"
        return XError(short="Missing Attribute", what=what, why=why, ctx=XErrorCtx([cls_name]))

    @staticmethod
    def DependencyCycle(cycle: list[AnyType]) -> XError:
        """Error for a dependency cycle, listing its members joined by "<-"."""
        chain = "<-".join(str(member) for member in cycle)
        what = f"There is a cycle: {chain}"
        why = "The XSDs for classes are written to the .xsd file in dependency order"
        return XError(short="Dependency Cycle in XSD", what=what, why=why)

    @staticmethod
    def NotXmlified(cls: AnyType) -> XError:
        """Error for a class that is not xmlified, with a usage hint attached."""
        cls_name: str = typename(cls)
        hint = f"To fix, try:\n@xmlify\n@dataclass\nclass {cls_name}: ..."
        return XError(
            short="Not Xmlified",
            what=f"{cls_name} is not xmlified, and hence cannot have an associated parser",
            why="the .xsd(...) method is required to write_xsd",
            notes=[hint],
        )

67
third_party/xmlable/_io.py vendored Normal file
View file

@ -0,0 +1,67 @@
"""
Easy file IO for users
- Need to make it obvious when an xml has been overwritten
- Easy parsing from a file
"""
from pathlib import Path
from typing import Any, TypeVar
from termcolor import colored
from lxml.objectify import parse as objectify_parse
from lxml.etree import _ElementTree
from xmlable._utils import typename
from xmlable._xobject import is_xmlified
from xmlable._errors import ErrorTypes
def write_file(file_path: str | Path, tree: _ElementTree):
    """Write ``tree`` to ``file_path`` as pretty-printed UTF-8 XML with a declaration.

    Prints a coloured "Overwriting ..." notice before the write and
    "Complete!" after it, so the overwrite is visible to the user.
    """
    notice = colored(f"Overwriting {file_path}", "red", attrs=["blink"])
    print(notice, end="...")
    with open(file_path, "wb") as handle:
        tree.write(handle, xml_declaration=True, encoding="utf-8", pretty_print=True)
    done = colored("Complete!", "green", attrs=["blink"])
    print(done)
def parse_file(cls: type, file_path: str | Path) -> Any:
    """
    Parse a file, validate and produce instance of cls

    INV: cls must be an xmlified class

    Raises:
        XError: ``cls`` is not xmlified (built by ErrorTypes.NotXmlified).
    """
    if not is_xmlified(cls):
        raise ErrorTypes.NotXmlified(cls)

    # Open in binary so the XML parser decodes the bytes itself. Text mode
    # would decode with the platform's default encoding first, which can
    # differ from the encoding the file was written in.
    with open(file=file_path, mode="rb") as f:
        return cls.parse(objectify_parse(f).getroot())  # type: ignore[attr-defined]
def write_xsd(
    file_path: str | Path,
    cls: type,
    namespaces: dict[str, str] = {},
    imports: dict[str, str] = {},
):
    """Write the XSD produced by ``cls.xsd(...)`` to ``file_path``.

    ``namespaces`` and ``imports`` are forwarded to ``cls.xsd`` unchanged.

    Raises:
        XError: ``cls`` is not xmlified.
    """
    if not is_xmlified(cls):
        raise ErrorTypes.NonXMlifiedType(typename(cls))
    schema_tree = cls.xsd(namespaces=namespaces, imports=imports)  # type: ignore[attr-defined]
    write_file(file_path, schema_tree)
def write_xml_template(
    file_path: str | Path, cls: type, schema_name: str | None = None
):
    """Write the XML template produced by ``cls.xml(...)`` to ``file_path``.

    ``schema_name`` defaults to the type name of ``cls`` when not given.

    Raises:
        XError: ``cls`` is not xmlified.
    """
    if not is_xmlified(cls):
        raise ErrorTypes.NonXMlifiedType(typename(cls))
    schema_id: str = typename(cls) if schema_name is None else schema_name
    template_tree = cls.xml(schema_id)  # type: ignore[attr-defined]
    write_file(file_path, template_tree)
def write_xml_value(file_path: str | Path, val: Any):
    """Write the XML produced by ``val.xml_value()`` to ``file_path``.

    Raises:
        XError: the type of ``val`` is not xmlified.
    """
    value_type = type(val)
    if not is_xmlified(value_type):
        raise ErrorTypes.NonXMlifiedType(typename(value_type))
    write_file(file_path, val.xml_value())  # type: ignore[attr-defined]

33
third_party/xmlable/_lxml_helpers.py vendored Normal file
View file

@ -0,0 +1,33 @@
"""
Helper functions for wrangling to lxml library
- Includes the XMLSchema used
"""
from lxml.objectify import ObjectifiedElement
from lxml.etree import _Element
from typing import Iterable
XMLURL = r"http://www.w3.org/2001/XMLSchema"
XMLSchema = r"{http://www.w3.org/2001/XMLSchema}"
def with_text(e: _Element, text: str) -> _Element:
    """Set ``e.text`` to ``text`` and return the same element, for call chaining."""
    element = e
    element.text = text
    return element
def with_children(parent: _Element, children: Iterable[_Element]) -> _Element:
    """Append every item of ``children`` to ``parent``, in order, and return ``parent``."""
    for node in children:
        parent.append(node)
    return parent
def with_child(parent: _Element, child: _Element) -> _Element:
    """Append a single ``child`` to ``parent`` via with_children and return ``parent``."""
    only_child = [child]
    return with_children(parent, only_child)
def children(obj: ObjectifiedElement) -> Iterable[ObjectifiedElement]:
    """Lazily yield the direct children of ``obj`` whose tag is not "comment"."""
    # NOTE(review): this compares the tag with the string "comment"; confirm
    # that is how comment nodes show up here, or they pass through unfiltered.
    return filter(lambda node: node.tag != "comment", obj.getchildren())  # type: ignore[arg-type, operator]

137
third_party/xmlable/_manual.py vendored Normal file
View file

@ -0,0 +1,137 @@
"""
The @manual_xmlify decorator used to add the .xsd, .xml, .xml_value and .parse
methods to a class that already has .xsd_dependencies, .xsd_forward and
.get_xobject
"""
from typing import Any
from lxml.etree import _Element, Element, _ElementTree, ElementTree
from lxml.objectify import ObjectifiedElement
from xmlable._utils import typename, AnyType, ordered_iter
from xmlable._lxml_helpers import with_children, XMLSchema
from xmlable._errors import XError, XErrorCtx, ErrorTypes
def validate_manual_class(cls: AnyType):
    """
    Check that the class provides every attribute @manual_xmlify relies on
    - Raises ErrorTypes.MissingAttribute for the first attribute found to be
      missing
    """
    required = {"get_xobject", "xsd_forward", "xsd_dependencies"}
    absent = next((name for name in required if not hasattr(cls, name)), None)
    if absent is not None:
        raise ErrorTypes.MissingAttribute(cls, required, absent)
def type_cycle(from_type: AnyType) -> list[AnyType]:
    """
    Build the list of types reported for a dependency cycle through from_type
    - from_type must provide xsd_dependencies
    - The returned list ends with from_type

    NOTE(review): the search starts at from_type itself, so the
    `curr == from_type` test succeeds immediately and the dependencies are
    never walked; the result is therefore [from_type, from_type] - confirm
    whether the intermediate types were meant to be included
    """
    # INV: it is an xmlified type for a user defined structure
    cycle: list[AnyType] = []

    def visit_dep(curr: AnyType) -> bool:
        if curr == from_type or any(
            visit_dep(dep) for dep in ordered_iter(curr.xsd_dependencies())  # type: ignore[attr-defined]
        ):
            cycle.append(curr)
            return True
        else:
            return False

    # JUSTIFY: the search is run outside of the assert
    #          - visit_dep fills in `cycle` as a side effect, and asserts are
    #            removed entirely when python runs with -O
    found = visit_dep(from_type)
    assert found
    cycle.append(from_type)
    return cycle
def manual_xmlify(cls: type) -> type:
    """
    Generate the following methods:
    ```
    def xsd(
        id: str = cls_name,
        namespaces: dict[str, str] | None = None,
        imports: dict[str, str] | None = None,
    ) -> _ElementTree:
        # ...
    def xml(schema_name: str = cls_name) -> _ElementTree:
        # ...
    def xml_value(self, id: str = cls_name) -> _ElementTree:
        # ...
    def parse(obj: ObjectifiedElement) -> Any:
        # ...
    ```
    - cls must already provide get_xobject, xsd_forward and xsd_dependencies
      (checked by validate_manual_class)
    - The methods are attached to cls, and cls is returned
    """
    try:
        validate_manual_class(cls)
        cls_name = typename(cls)
        cls_xobject = cls.get_xobject()  # type: ignore[attr-defined]

        def xsd(
            id: str = cls_name,
            namespaces: dict[str, str] | None = None,
            imports: dict[str, str] | None = None,
        ) -> _ElementTree:
            # JUSTIFY: None defaults rather than {}
            #          - forward declarations and the main element can add to
            #            namespaces, a shared {} default would carry those
            #            additions over into later calls
            if namespaces is None:
                namespaces = {}
            if imports is None:
                imports = {}

            # Get dependencies (user classes that need to be declared before)
            visited: set[AnyType] = set()
            dec_order: list[AnyType] = []

            def toposort(
                curr: AnyType, visited: set[AnyType], dec_order: list[AnyType]
            ):
                if curr in visited:
                    raise ErrorTypes.DependencyCycle(type_cycle(curr))
                visited.add(curr)
                deps = curr.xsd_dependencies()  # type: ignore[attr-defined]
                for d in ordered_iter(deps):
                    if d not in visited:
                        toposort(d, visited, dec_order)
                # dependencies are declared before the type that uses them
                dec_order.append(curr)

            toposort(cls, visited, dec_order)

            # Create forward declarations, potentially adding to namespaces
            decs: list[_Element] = [dec.xsd_forward(namespaces) for dec in dec_order]  # type: ignore[attr-defined]

            # generate main element (can add to namespaces)
            main_element = cls_xobject.xsd_out(id, add_ns=namespaces)

            # NOTE: the schema root is created last, so its nsmap contains
            #       the namespaces added by the declarations above
            return ElementTree(
                with_children(
                    Element(
                        f"{XMLSchema}schema",
                        id=id,
                        elementFormDefault="qualified",
                        nsmap=namespaces,
                    ),
                    [
                        Element(
                            f"{XMLSchema}import",
                            namespace=ns,
                            schemaLocation=sloc,
                        )
                        for ns, sloc in imports.items()
                    ]
                    + decs
                    + [main_element],
                )
            )

        def xml(schema_name: str = cls_name) -> _ElementTree:
            return ElementTree(cls_xobject.xml_temp(schema_name))

        def xml_value(self, id: str = cls_name) -> _ElementTree:
            return ElementTree(cls_xobject.xml_out(id, self, XErrorCtx([id])))

        def parse(obj: ObjectifiedElement) -> Any:
            return cls_xobject.xml_in(obj, XErrorCtx([obj.tag]))

        cls.xsd = xsd  # type: ignore[attr-defined]
        cls.xml = xml  # type: ignore[attr-defined]
        setattr(cls, "xml_value", xml_value)  # needs to use self to get values
        cls.parse = parse  # type: ignore[attr-defined]

        return cls
    except XError as e:
        # NOTE: Trick to remove dirty 'internal' traceback, and raise from
        #       xmlify (makes more sense to user than seeing internals)
        e.__traceback__ = None
        raise e

71
third_party/xmlable/_user.py vendored Normal file
View file

@ -0,0 +1,71 @@
"""
The IXmlify interface
- Contains the methods needed to make get_xobject work
- Allows type checking of user's implementations
"""
from abc import ABC, abstractmethod
from lxml.etree import _Element
from xmlable._xobject import XObject
from xmlable._utils import AnyType
class IXmlify(ABC):
    """
    An interface that ensures the attributes required by @manual_xmlify are
    present
    - Every method is an abstract static method, to be implemented by the
      user's class
    """

    @staticmethod
    @abstractmethod
    def get_xobject() -> XObject:
        """
        Produces an xobject that encapsulates the:
        - xsd usage (e.g <xs:element name="..." type="thisclass!"/>)
        - xml template
        - xml value
        - parsing
        ```
        @manual_xmlify
        class Foo(IXmlify):
            def get_xobject() -> XObject:
                class MyObj(XObject):
                    # ... definitions
                return MyObj
        ```
        """

    @staticmethod
    @abstractmethod
    def xsd_forward(add_ns: dict[str, str]) -> _Element:
        """
        Produces the forward declaration
        - xsd definition of the class's type
        ```
        @manual_xmlify
        class Foo(IXmlify):
            def xsd_forward(add_ns: dict[str, str]) -> _Element:
                return Element('{XMLSchema}complexType', name="Foo", ...)
        ```
        """

    @staticmethod
    @abstractmethod
    def xsd_dependencies() -> set[AnyType]:
        """
        The user classes that need to be declared before this one
        For example:
        ```
        @manual_xmlify
        class A(IXmlify):
            # xobject uses Foo and Bar
            def xsd_dependencies() -> set[type]:
                return {Foo, Bar}
        ```
        """

63
third_party/xmlable/_utils.py vendored Normal file
View file

@ -0,0 +1,63 @@
"""
Basic Utilities
Includes common helper functions for this project
- Handling optionals
- getting members by string name
- typenames
"""
from typing import Any, Callable, TypeVar, TypeAlias, Type, Iterable
from types import GenericAlias
# Either a plain type or a generic alias such as list[int]
AnyType: TypeAlias = Type | GenericAlias
T = TypeVar("T")
def some_or(data: T | None, alt: T):
return data if data is not None else alt
N = TypeVar("N")
M = TypeVar("M")


def some_or_apply(data: N, fn: Callable[[N], M], alt: M):
    """
    Apply fn to data and return the result, unless data is None in which
    case alt is returned (and fn is not called)
    """
    if data is None:
        return alt
    return fn(data)
def get(obj: Any, attr: str) -> Any:
    """
    Get the attribute named attr from obj
    - The AttributeError for a missing attribute is not caught
    """
    lookup = obj.__getattribute__
    return lookup(attr)
def opt_get(obj: Any, attr: str) -> Any | None:
try:
return obj.__getattribute__(attr)
except AttributeError:
return None
X = TypeVar("X")
Y = TypeVar("Y")
def firstkey(d: dict[X, Y], val: Y) -> X | None:
for k, v in d.items():
if v == val:
return k
else:
return None
def typename(t: AnyType) -> str:
    """
    Get the name of a type
    - "None" for None itself, otherwise the type's __name__
    """
    return "None" if t is None else t.__name__
Z = TypeVar("Z")


def ordered_iter(types: Iterable[Z]) -> list[Z]:
    """
    Collect the items into a list, ordered by their str representation
    - Gives a repeatable order for collections such as sets
    """
    ordered = list(types)
    ordered.sort(key=str)
    return ordered

156
third_party/xmlable/_xmlify.py vendored Normal file
View file

@ -0,0 +1,156 @@
"""XMLable
A decorator to allow creation of xml config based on python dataclasses
Given a dataclass:
- Produce an xsd schema based on the class
- Produce an xml template based on the class
- Given any instance of the class, make a best-effort attempt at turning it into
a filled xml
- Create a parser for parsing the xml
"""
from humps import pascalize
from dataclasses import fields, is_dataclass
from typing import Any, dataclass_transform, cast
from lxml.objectify import ObjectifiedElement
from lxml.etree import Element, _Element
from xmlable._utils import get, typename, AnyType
from xmlable._errors import XError, XErrorCtx, ErrorTypes
from xmlable._manual import manual_xmlify
from xmlable._lxml_helpers import with_children, with_child, XMLSchema
from xmlable._xobject import XObject, gen_xobject
def validate_class(cls: AnyType):
    """
    Validate that the class can be xmlified
    - Must be a dataclass
    - Cannot have any members called 'comment' (lxml parses comments as this tag)
    - Cannot have any members called 'get_xobject', 'xsd_forward' or
      'xsd_dependencies' (these are added by xmlify)
    """
    if not is_dataclass(cls):
        raise ErrorTypes.NotADataclass(cls)

    reserved_attrs = ["get_xobject", "xsd_forward", "xsd_dependencies"]

    # TODO: cleanup repetition
    for member in fields(cls):
        member_name = member.name
        if member_name in reserved_attrs:
            raise ErrorTypes.ReservedAttribute(cls, member_name)
        if member_name == "comment":
            raise ErrorTypes.CommentAttribute(cls)

    # JUSTIFY: Could potentially have added other attributes (of the class,
    #          rather than a field of an instance as provided by dataclass
    #          fields)
    clash = next((attr for attr in reserved_attrs if hasattr(cls, attr)), None)
    if clash is not None:
        raise ErrorTypes.ReservedAttribute(cls, clash)
    if hasattr(cls, "comment"):
        raise ErrorTypes.CommentAttribute(cls)
@dataclass_transform()
def xmlify(cls: type) -> AnyType:
    """
    Decorator adding the xsd, xml, xml_value and parse methods to a dataclass
    - Validates the class (see validate_class)
    - Generates an xobject for each field, each field's element is named with
      the pascal case of the field's name
    - Attaches xsd_forward, xsd_dependencies and get_xobject to the class,
      then hands it to manual_xmlify to generate the remaining methods
    """
    try:
        validate_class(cls)

        cls_name = typename(cls)
        forward_decs = cast(set[AnyType], {cls})
        # (element name, dataclass field, xobject for the field's type)
        meta_xobjects = [
            (
                pascalize(f.name),
                f,
                gen_xobject(cast(AnyType, f.type), forward_decs),
            )
            for f in fields(cls)
        ]

        class UserXObject(XObject):
            def xsd_out(
                self,
                name: str,
                attribs: dict[str, str] | None = None,
                add_ns: dict[str, str] | None = None,
            ) -> _Element:
                # NOTE: None defaults avoid sharing one dict between calls
                return Element(
                    f"{XMLSchema}element",
                    name=name,
                    type=cls_name,
                    attrib={} if attribs is None else attribs,
                )

            def xml_temp(self, name: str) -> _Element:
                return with_children(
                    Element(name),
                    [
                        xobj.xml_temp(pascal_name)
                        for pascal_name, _, xobj in meta_xobjects
                    ],
                )

            def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
                return with_children(
                    Element(name),
                    [
                        xobj.xml_out(
                            pascal_name,
                            get(val, m.name),
                            ctx.next(pascal_name),
                        )
                        for pascal_name, m, xobj in meta_xobjects
                    ],
                )

            def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
                parsed: dict[str, Any] = {}
                for pascal_name, m, xobj in meta_xobjects:
                    # JUSTIFY: get raises AttributeError for a missing
                    #          member rather than returning None, it is
                    #          caught here so that the NonMemberTag error
                    #          below is the one presented to the user
                    try:
                        m_obj = get(obj, pascal_name)
                    except AttributeError:
                        m_obj = None
                    if m_obj is None:
                        raise ErrorTypes.NonMemberTag(ctx, cls, obj.tag, m.name)
                    parsed[m.name] = xobj.xml_in(m_obj, ctx.next(pascal_name))
                return cls(**parsed)

        cls_xobject = UserXObject()

        # JUSTIFY: Why are xsd forward & dependencies not part of xobject?
        #          - xobject covers the use (not forward decs)
        #          - we want to present error messages to the user containing
        #            their types, so xsd dependencies are in terms of python
        #            types, rather than xobjects
        #          - forward and dependencies do not apply to the basic types,
        #            only user types

        def xsd_forward(add_ns: dict[str, str]) -> _Element:
            return with_child(
                Element(f"{XMLSchema}complexType", name=cls_name),
                with_children(
                    Element(f"{XMLSchema}sequence"),
                    [
                        xobj.xsd_out(pascal_name, attribs={}, add_ns=add_ns)
                        for pascal_name, m, xobj in meta_xobjects
                    ],
                ),
            )

        def xsd_dependencies() -> set[AnyType]:
            return forward_decs

        def get_xobject():
            return cls_xobject

        # helper methods for gen_xobject, and other dataclasses to generate their
        # x methods
        cls.xsd_forward = xsd_forward  # type: ignore[attr-defined]
        cls.xsd_dependencies = xsd_dependencies  # type: ignore[attr-defined]
        cls.get_xobject = get_xobject  # type: ignore[attr-defined]

        return manual_xmlify(cls)
    except XError as e:
        # NOTE: Trick to remove dirty 'internal' traceback, and raise from
        #       xmlify (makes more sense to user than seeing internals)
        e.__traceback__ = None
        raise e

640
third_party/xmlable/_xobject.py vendored Normal file
View file

@ -0,0 +1,640 @@
"""XObjects
XObjects are an intermediate representation for python types -> xsd/xml
- Produced by @xmlify decorated classes, and by gen_xobject
- Associated xsd, xml and parsing
"""
from humps import pascalize
from dataclasses import dataclass
from types import NoneType, UnionType
from lxml.objectify import ObjectifiedElement
from lxml.etree import Element, Comment, _Element
from abc import ABC, abstractmethod
from typing import Any, Callable, Type, get_args, TypeAlias, cast
from types import GenericAlias
from xmlable._utils import get, typename, firstkey, AnyType
from xmlable._errors import XErrorCtx, ErrorTypes
from xmlable._lxml_helpers import (
with_text,
with_child,
with_children,
XMLSchema,
XMLURL,
children,
)
class XObject(ABC):
    """
    The common interface of all xobjects
    - A python type is mapped to an XObject, which then provides the xsd,
      the template xml, the xml for a value, and the parsing
    """

    @abstractmethod
    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] = {},
        add_ns: dict[str, str] = {},
    ) -> _Element:
        """Generate the xsd schema for the object"""

    @abstractmethod
    def xml_temp(self, name: str) -> _Element:
        """
        Generate commented output for the xml representation
        - Contains no values, only comments
        - Not valid xml (can contain nested comments, comments instead of values)
        """

    @abstractmethod
    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """Generate the xml element called name for the value val"""

    @abstractmethod
    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
        """Parse the value from the element obj"""
@dataclass
class BasicObj(XObject):
    """
    An xobject for a simple type (e.g string, int)
    """

    # name of the type in XML Schema (used qualified, e.g. "xs:integer")
    type_str: str
    # converts a value to the text placed in its element
    convert_fn: Callable[[Any], str]
    # checks a value before it is converted
    validate_fn: Callable[[Any], bool]
    # produces the value from a parsed element
    parse_fn: Callable[[ObjectifiedElement], Any]

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, Any] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """
        Generate the xsd element using the (qualified) basic type
        - add_ns: prefix -> namespace url, a prefix for XMLURL is added to it
          when it does not already contain one
        """
        # JUSTIFY: None defaults rather than {}
        #          - add_ns is written to below, a shared {} default would
        #            remember the added prefix and change the output of
        #            every later call made without add_ns
        if attribs is None:
            attribs = {}
        if add_ns is None:
            add_ns = {}

        # NOTE: namespace cringe:
        #       - lxml will deal with qualifying namespaces for the name of the
        #         element, but not for attributes
        #       - XMLSchema type attributes must be qualified
        if (prefix := firstkey(add_ns, XMLURL)) is not None:
            return Element(
                f"{XMLSchema}element",
                name=name,
                type=f"{prefix}:{self.type_str}",
                attrib=attribs,
            )
        else:
            # add new namespace, resolve conflicts with extra 's'
            new_ns = "xs"
            while new_ns in add_ns:
                new_ns += "s"
            add_ns[new_ns] = XMLURL
            return Element(
                f"{XMLSchema}element",
                name=name,
                type=f"{new_ns}:{self.type_str}",
                attrib=attribs,
                nsmap={new_ns: XMLURL},
            )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template element containing a hint of the type to fill in"""
        return with_text(Element(name), f"Fill me with an {self.type_str}")

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """
        Generate the element containing the converted value
        - Raises ErrorTypes.InvalidData when validate_fn rejects the value
        """
        if not self.validate_fn(val):
            raise ErrorTypes.InvalidData(ctx, val, self.type_str)
        return with_text(Element(name), self.convert_fn(val))

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
        """
        Parse the value using parse_fn
        - Any exception from parse_fn is reported as ErrorTypes.ParseFailure
        """
        try:
            return self.parse_fn(obj)
        except Exception as e:
            raise ErrorTypes.ParseFailure(ctx, obj.text, self.type_str, e)
@dataclass
class ListObj(XObject):
    """
    An ordered list of objects
    """

    # xobject used for every item of the list
    item_xobject: XObject
    # name of the element used for each item
    list_elem_name: str
    # name used in comments and error messages
    struct_name: str = "List"

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """
        Generate the xsd element, a sequence of any number of items
        - add_ns is passed on to the item's xobject (which can add to it)
        """
        # JUSTIFY: None defaults rather than {}
        #          - add_ns is handed to the item xobject which can write to
        #            it, a shared {} default would leak between calls
        if attribs is None:
            attribs = {}
        if add_ns is None:
            add_ns = {}
        return with_child(
            Element(f"{XMLSchema}element", name=name, attrib=attribs),
            with_children(
                Element(f"{XMLSchema}complexType"),
                [
                    Comment(f"This is a {self.struct_name}"),
                    with_child(
                        Element(f"{XMLSchema}sequence"),
                        self.item_xobject.xsd_out(
                            self.list_elem_name,
                            {"minOccurs": "0", "maxOccurs": "unbounded"},
                            add_ns,
                        ),
                    ),
                ],
            ),
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template containing a comment and one template item"""
        return with_children(
            Element(name),
            [
                Comment(f"This is a {self.struct_name}"),
                self.item_xobject.xml_temp(self.list_elem_name),
            ],
        )

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """
        Generate the element containing one child per item (in order)
        - An empty collection produces an element containing only a comment
        """
        if len(val) > 0:
            return with_children(
                Element(name),
                [
                    self.item_xobject.xml_out(
                        self.list_elem_name,
                        item_val,
                        ctx.next(f"{self.list_elem_name}[{i}]"),
                    )
                    for i, item_val in enumerate(val)
                ],
            )
        else:
            return with_child(
                Element(name), Comment(f"Empty {self.struct_name}!")
            )

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> list[Any]:
        """
        Parse every child of obj as an item
        - Raises ErrorTypes.UnexpectedTag for a child that is not named
          list_elem_name
        """
        parsed = []
        for i, child in enumerate(children(obj)):
            if child.tag != self.list_elem_name:
                raise ErrorTypes.UnexpectedTag(
                    ctx, self.list_elem_name, self.struct_name, child.tag
                )
            else:
                parsed.append(
                    self.item_xobject.xml_in(
                        child, ctx.next(f"{self.list_elem_name}[{i}]")
                    )
                )
        return parsed
@dataclass
class StructObj(XObject):
    """An ordered list of key-value pairs"""  # TODO: make objects variable length tuple

    # (member name, xobject of the member) in the order they are written
    objects: list[tuple[str, XObject]]
    # name used in comments and error messages
    struct_name: str = "Struct"

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """
        Generate the xsd element, a sequence with one element per member
        - add_ns is passed on to the members' xobjects (which can add to it)
        """
        # JUSTIFY: None defaults rather than {}
        #          - add_ns is handed to the member xobjects which can write
        #            to it, a shared {} default would leak between calls
        if attribs is None:
            attribs = {}
        if add_ns is None:
            add_ns = {}
        return with_child(
            Element(f"{XMLSchema}element", name=name, attrib=attribs),
            with_child(
                Element(f"{XMLSchema}complexType"),
                with_children(
                    Element(f"{XMLSchema}sequence"),
                    [Comment(f"This is a {self.struct_name}")]
                    + [
                        xobj.xsd_out(member, {}, add_ns)
                        for member, xobj in self.objects
                    ],
                ),
            ),
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template containing a comment and each member's template"""
        return with_children(
            Element(name),
            [Comment(f"This is a {self.struct_name}")]
            + [xobj.xml_temp(member) for member, xobj in self.objects],
        )

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """
        Generate the element, pairing the values of val with the members in
        order
        - Raises ErrorTypes.IncorrectType when the number of values differs
          from the number of members
        """
        if len(val) != len(self.objects):
            raise ErrorTypes.IncorrectType(
                ctx, len(self.objects), self.struct_name, val, name
            )

        return with_children(
            Element(name),
            [
                xobj.xml_out(member, v, ctx.next(member))
                for (member, xobj), v in zip(self.objects, val)
            ],
        )

    def xml_in(
        self, obj: ObjectifiedElement, ctx: XErrorCtx
    ) -> list[tuple[str, Any]]:
        """
        Parse the children of obj against the members, in order
        - Raises ErrorTypes.IncorrectElementTag when a child's tag is not the
          name of the member at the same position
        """
        # NOTE(review): zip stops at the shorter of the two, so missing or
        #               additional children are not reported - confirm
        #               whether that is intended
        parsed = []
        for i, (child, (name, xobj)) in enumerate(
            zip(children(obj), self.objects)
        ):
            if child.tag != name:
                raise ErrorTypes.IncorrectElementTag(
                    ctx, self.struct_name, obj.tag, i, name, child.tag
                )
            parsed.append((name, xobj.xml_in(child, ctx.next(name))))
        return parsed
class TupleObj(XObject):
    """
    An anonymous struct
    - Implemented as a StructObj, with member names generated from the
      position of each item
    """

    def __init__(
        self,
        objects: tuple[XObject, ...],
        elem_gen: Callable[[int], str] = lambda i: f"Item-{i+1}",
    ):
        self.elem_gen = elem_gen
        self.struct: StructObj = StructObj(
            [(self.elem_gen(i), xobj) for i, xobj in enumerate(objects)],
            struct_name="Tuple",
        )

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """Generate the xsd element of the underlying struct"""
        # JUSTIFY: None defaults rather than {}
        #          - add_ns ends up with the items' xobjects which can write
        #            to it, a shared {} default would leak between calls
        return self.struct.xsd_out(
            name,
            {} if attribs is None else attribs,
            {} if add_ns is None else add_ns,
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate the template of the underlying struct"""
        return self.struct.xml_temp(name)

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """Generate the element for the tuple val using the underlying struct"""
        return self.struct.xml_out(name, val, ctx)

    def xml_in(
        self, obj: ObjectifiedElement, ctx: XErrorCtx
    ) -> tuple[Any, ...]:
        """Parse using the underlying struct, keeping only the values"""
        # Assumes the objects are in the correct order
        # NOTE: taking the values pair by pair also covers the case where
        #       nothing was parsed (which results in an empty tuple)
        return tuple(value for _, value in self.struct.xml_in(obj, ctx))
class SetOBj(XObject):
    """
    An unordered collection of unique elements
    - Implemented as a ListObj, with duplicates rejected when parsing
    """

    def __init__(self, inner: XObject, elem_name: str = "setitem"):
        self.list = ListObj(inner, elem_name, struct_name="set")

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """Generate the xsd element of the underlying list"""
        # JUSTIFY: None defaults rather than {}
        #          - add_ns ends up with the item's xobject which can write
        #            to it, a shared {} default would leak between calls
        return self.list.xsd_out(
            name,
            {} if attribs is None else attribs,
            {} if add_ns is None else add_ns,
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate the template of the underlying list"""
        return self.list.xml_temp(name)

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """Generate the element for val, written as a list of its items"""
        return self.list.xml_out(name, list(val), ctx)

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> set[Any]:
        """
        Parse the items into a set
        - Raises ErrorTypes.DuplicateItem when an item occurs more than once
        """
        parsed: set[Any] = set()
        for item in self.list.xml_in(obj, ctx):
            if item in parsed:
                raise ErrorTypes.DuplicateItem(ctx, "set", obj.tag, item)
            parsed.add(item)
        return parsed
@dataclass
class DictObj(XObject):
    """An unordered collection of key-value pair elements"""

    # xobjects used for the keys and for the values
    key_xobject: XObject
    val_xobject: XObject
    # names of the key and value elements inside each item element
    key_name: str = "Key"
    val_name: str = "Val"
    # name of the element wrapping one key-value pair
    item_name: str = "Item"

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """
        Generate the xsd element, a sequence of any number of items that each
        contain a key and a value
        - add_ns is passed on to the key and value xobjects (which can add
          to it)
        """
        # JUSTIFY: None defaults rather than {}
        #          - add_ns is handed to the key/value xobjects which can
        #            write to it, a shared {} default would leak between calls
        if attribs is None:
            attribs = {}
        if add_ns is None:
            add_ns = {}
        return with_child(
            Element(f"{XMLSchema}element", name=name, attrib=attribs),
            with_children(
                Element(f"{XMLSchema}complexType"),
                [
                    Comment("this is a dictionary!"),
                    with_child(
                        Element(f"{XMLSchema}sequence"),
                        with_child(
                            Element(
                                f"{XMLSchema}element",
                                name=self.item_name,
                                minOccurs="0",
                                maxOccurs="unbounded",
                            ),
                            with_child(
                                Element(f"{XMLSchema}complexType"),
                                with_children(
                                    Element(f"{XMLSchema}sequence"),
                                    [
                                        self.key_xobject.xsd_out(
                                            self.key_name, {}, add_ns
                                        ),
                                        self.val_xobject.xsd_out(
                                            self.val_name, {}, add_ns
                                        ),
                                    ],
                                ),
                            ),
                        ),
                    ),
                ],
            ),
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template containing a comment and one template item"""
        return with_children(
            Element(name),
            [
                Comment("This is a dictionary"),
                with_children(
                    Element(self.item_name),
                    [
                        self.key_xobject.xml_temp(self.key_name),
                        self.val_xobject.xml_temp(self.val_name),
                    ],
                ),
            ],
        )

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """Generate the element containing one item per key-value pair of val"""
        item_ctx = ctx.next(self.item_name)
        return with_children(
            Element(name),
            [
                with_children(
                    Element(self.item_name),
                    [
                        # NOTE: the context is extended with the name of the
                        #       element being written (as done in xml_in), so
                        #       errors point at the key or the value
                        self.key_xobject.xml_out(
                            self.key_name, k, item_ctx.next(self.key_name)
                        ),
                        self.val_xobject.xml_out(
                            self.val_name, v, item_ctx.next(self.val_name)
                        ),
                    ],
                )
                for k, v in val.items()
            ],
        )

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> dict[Any, Any]:
        """
        Parse every child of obj as a key-value item
        - Raises ErrorTypes.InvalidDictionaryItem for a child that is not
          named item_name
        - Raises ErrorTypes.DuplicateItem when a key occurs more than once
        """
        parsed = {}
        for child in children(obj):
            if child.tag != self.item_name:
                raise ErrorTypes.InvalidDictionaryItem(
                    ctx,
                    self.item_name,
                    self.key_name,
                    self.val_name,
                    child.tag,
                    obj.tag,
                )
            else:
                child_ctx = ctx.next(self.item_name)
                k = self.key_xobject.xml_in(
                    get(child, self.key_name), child_ctx.next(self.key_name)
                )
                v = self.val_xobject.xml_in(
                    get(child, self.val_name), child_ctx.next(self.val_name)
                )
                if k in parsed:
                    raise ErrorTypes.DuplicateItem(
                        ctx, "dictionary", obj.tag, k
                    )
                parsed[k] = v
                # TODO: Check for other tags? Fail better?
        return parsed
def resolve_type(v: Any) -> AnyType:
"""Determine the type of some value, using primitive types
- If empty container, only provide top container type
INV: only generic types for v are {tuple, list, dict, set}
"""
t = type(v)
if t in {int, float, str, bool, NoneType}:
return t
elif t == dict and len(v) > 0:
t0, t1 = next(iter(v.items()))
return dict[resolve_type(t0), resolve_type(t1)] # type: ignore[misc, index, no-any-return]
elif t == list and len(v) > 0:
return list[resolve_type(v[0])] # type: ignore[misc, index, no-any-return]
elif t == set and len(v) > 0:
return set[resolve_type(next(iter(v)))] # type: ignore[misc, index, no-any-return]
elif t == tuple and len(v) > 0:
return tuple[*(resolve_type(vi) for vi in v)] # type: ignore[misc, no-any-return]
else:
# INV: non-generic type
return t
@dataclass
class UnionObj(XObject):
    """A variant, can be one of several different types"""

    # the possible types, each with the xobject used for it
    xobjects: dict[AnyType, XObject]
    # generates the element name used for a variant from its type
    elem_gen: Callable[[AnyType], str] = lambda t: pascalize(typename(t))

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """
        Generate the xsd element, a sequence with an optional element for
        each variant
        - add_ns is passed on to the variants' xobjects (which can add to it)
        """
        # JUSTIFY: None defaults rather than {}
        #          - add_ns is handed to the variant xobjects which can write
        #            to it, a shared {} default would leak between calls
        if attribs is None:
            attribs = {}
        if add_ns is None:
            add_ns = {}
        return with_child(
            Element(f"{XMLSchema}element", name=name, attrib=attribs),
            with_children(
                Element(f"{XMLSchema}complexType"),
                [
                    Comment("this is a union!"),
                    with_children(
                        Element(f"{XMLSchema}sequence"),
                        [
                            xobj.xsd_out(
                                self.elem_gen(t), {"minOccurs": "0"}, add_ns
                            )
                            for t, xobj in self.xobjects.items()
                        ],
                    ),
                ],
            ),
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template containing a comment and every variant's template"""
        return with_children(
            Element(name),
            [
                Comment(
                    "This is a union, the following variants are possible, only one can be present"
                )
            ]
            + [
                xobj.xml_temp(self.elem_gen(t))
                for t, xobj in self.xobjects.items()
            ],
        )

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """
        Generate the element containing the variant matching the type of val
        - The type of val is determined with resolve_type
        - Raises ErrorTypes.InvalidVariant when no variant has that type
        """
        t = resolve_type(val)

        if (val_xobj := self.xobjects.get(t)) is not None:
            variant_name = self.elem_gen(t)
            return with_child(
                Element(name),
                val_xobj.xml_out(variant_name, val, ctx.next(variant_name)),
            )
        else:
            raise ErrorTypes.InvalidVariant(
                ctx, name, list(self.xobjects.keys()), t, val
            )

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
        """
        Parse the single variant contained in obj
        - Raises ErrorTypes.MultipleVariants unless obj has exactly one child
        - Raises ErrorTypes.ParseInvalidVariant when the child's tag is not
          the element name of any variant
        """
        # element name -> xobject of the variant
        named = {self.elem_gen(t): xobj for t, xobj in self.xobjects.items()}

        variants = list(children(obj))

        if len(variants) != 1:
            raise ErrorTypes.MultipleVariants(ctx, [v.tag for v in variants])

        variant = variants[0]

        if (xobj := named.get(variant.tag)) is not None:
            return xobj.xml_in(variant, ctx.next(variant.tag))
        else:
            raise ErrorTypes.ParseInvalidVariant(
                ctx, str(obj.tag), list(named.keys()), str(variant)
            )
class NoneObj(XObject):
    """
    An object representing the python 'None' type
    - Unions of form `int | None` are used for optionals
    """

    def xsd_out(
        self,
        name: str,
        attribs: dict[str, str] | None = None,
        add_ns: dict[str, str] | None = None,
    ) -> _Element:
        """Generate the xsd element, which only contains a comment"""
        # NOTE: None defaults avoid sharing one dict between calls, add_ns
        #       is not needed as no type has to be qualified
        return with_child(
            Element(
                f"{XMLSchema}element",
                name=name,
                attrib={} if attribs is None else attribs,
            ),
            Comment("This is a None type"),
        )

    def xml_temp(self, name: str) -> _Element:
        """Generate a template element, which only contains a comment"""
        return with_child(Element(name), Comment("This is None"))

    def xml_out(self, name: str, val: Any, ctx: XErrorCtx) -> _Element:
        """
        Generate the element for None, which only contains a comment
        - Raises ErrorTypes.NoneIsSome when val is not None
        """
        # NOTE: identity check, so a value whose comparison operators claim
        #       equality with None is still rejected
        if val is not None:
            raise ErrorTypes.NoneIsSome(ctx, name, val)
        return with_child(Element(name), Comment("This is None"))

    def xml_in(self, obj: ObjectifiedElement, ctx: XErrorCtx) -> Any:
        """Always parses to None, the content of obj is not inspected"""
        return None
def is_xmlified(cls):
    """
    Check if a class has all of the attributes of an xmlified class
    - The attributes provided to manual_xmlify, and the methods it generates
    """
    expected = (
        "xsd_forward",
        "xsd_dependencies",
        "get_xobject",
        "xsd",
        "xml",
        "xml_value",
        "parse",
    )
    return all(hasattr(cls, attr) for attr in expected)
def gen_xobject(data_type: AnyType, forward_dec: set[AnyType]) -> XObject:
    """
    Generate the xobject for a type
    - int, str, float and bool result in a BasicObj
    - None results in a NoneObj, and a union in a UnionObj of its variants
    - list, dict, tuple and set result in the matching container xobject,
      with the xobjects of their type arguments generated recursively
    - An xmlified class is added to forward_dec, and its own xobject is used
    - Raises ErrorTypes.NonXMlifiedType for any other type
    """
    primitives: dict[
        AnyType, tuple[str, Callable[[Any], str], Callable[[Any], bool]]
    ] = {
        int: ("integer", str, lambda d: type(d) == int),
        str: ("string", str, lambda d: type(d) == str),
        float: ("decimal", str, lambda d: type(d) == float),
        bool: (
            "boolean",
            lambda b: "true" if b else "false",
            lambda d: type(d) == bool,
        ),
    }

    primitive = primitives.get(data_type)
    if primitive is not None:
        xsd_name, to_text, is_valid = primitive
        # NOTE: here we can pass the parse_fn as the data type, as the name is
        #       also a constructor. (e.g. `int` -> `int("23") == 23`)
        from_elem = cast(Callable[[ObjectifiedElement], Any], data_type)
        return BasicObj(xsd_name, to_text, is_valid, from_elem)

    if isinstance(data_type, NoneType) or data_type == NoneType:
        # NOTE: Python typing cringe: None can be both a type and a value
        #       (even when within a type hint!)
        #       a: list[None] -> None is an instance of NoneType
        #       a: int | None -> Union of int and NoneType
        return NoneObj()

    if isinstance(data_type, UnionType):
        variants = {t: gen_xobject(t, forward_dec) for t in get_args(data_type)}
        return UnionObj(variants)

    container = typename(data_type)

    if container == "list":
        (elem_type,) = get_args(data_type)
        return ListObj(
            gen_xobject(elem_type, forward_dec),
            pascalize(typename(elem_type)),
        )

    if container == "dict":
        k_type, v_type = get_args(data_type)
        return DictObj(
            gen_xobject(k_type, forward_dec),
            gen_xobject(v_type, forward_dec),
        )

    if container == "tuple":
        members = tuple(gen_xobject(t, forward_dec) for t in get_args(data_type))
        return TupleObj(members)

    if container == "set":
        (elem_type,) = get_args(data_type)
        return SetOBj(
            gen_xobject(elem_type, forward_dec),
            pascalize(typename(elem_type)),
        )

    # Only user defined types remain, these must already be xmlified
    if not is_xmlified(data_type):
        raise ErrorTypes.NonXMlifiedType(container)
    forward_dec.add(data_type)
    return data_type.get_xobject()  # type: ignore[attr-defined, no-any-return]

0
third_party/xmlable/py.typed vendored Normal file
View file