diff --git a/tdom/__init__.py b/tdom/__init__.py index 4503582..cd5a291 100644 --- a/tdom/__init__.py +++ b/tdom/__init__.py @@ -1,7 +1,7 @@ from markupsafe import Markup, escape from .nodes import Comment, DocumentType, Element, Fragment, Node, Text -from .processor import html +from .processor import html, svg # We consider `Markup` and `escape` to be part of this module's public API @@ -13,6 +13,7 @@ "Fragment", "html", "Markup", + "svg", "Node", "Text", ] diff --git a/tdom/parser.py b/tdom/parser.py index 82de290..76f96c8 100644 --- a/tdom/parser.py +++ b/tdom/parser.py @@ -25,6 +25,106 @@ type HTMLAttribute = tuple[str, str | None] type HTMLAttributesDict = dict[str, str | None] +SVG_TAG_FIX = { + "altglyph": "altGlyph", + "altglyphdef": "altGlyphDef", + "altglyphitem": "altGlyphItem", + "animatecolor": "animateColor", + "animatemotion": "animateMotion", + "animatetransform": "animateTransform", + "clippath": "clipPath", + "feblend": "feBlend", + "fecolormatrix": "feColorMatrix", + "fecomponenttransfer": "feComponentTransfer", + "fecomposite": "feComposite", + "feconvolvematrix": "feConvolveMatrix", + "fediffuselighting": "feDiffuseLighting", + "fedisplacementmap": "feDisplacementMap", + "fedistantlight": "feDistantLight", + "fedropshadow": "feDropShadow", + "feflood": "feFlood", + "fefunca": "feFuncA", + "fefuncb": "feFuncB", + "fefuncg": "feFuncG", + "fefuncr": "feFuncR", + "fegaussianblur": "feGaussianBlur", + "feimage": "feImage", + "femerge": "feMerge", + "femergenode": "feMergeNode", + "femorphology": "feMorphology", + "feoffset": "feOffset", + "fepointlight": "fePointLight", + "fespecularlighting": "feSpecularLighting", + "fespotlight": "feSpotLight", + "fetile": "feTile", + "feturbulence": "feTurbulence", + "foreignobject": "foreignObject", + "glyphref": "glyphRef", + "lineargradient": "linearGradient", + "radialgradient": "radialGradient", + "textpath": "textPath", +} + +SVG_CASE_FIX = { + "attributename": "attributeName", + "attributetype": "attributeType", + "basefrequency": "baseFrequency", + "baseprofile": "baseProfile", + "calcmode": "calcMode", + "clippathunits": "clipPathUnits", + "diffuseconstant": "diffuseConstant", + "edgemode": "edgeMode", + "filterunits": "filterUnits", + "glyphref": "glyphRef", + "gradienttransform": "gradientTransform", + "gradientunits": "gradientUnits", + "kernelmatrix": "kernelMatrix", + "kernelunitlength": "kernelUnitLength", + "keypoints": "keyPoints", + "keysplines": "keySplines", + "keytimes": "keyTimes", + "lengthadjust": "lengthAdjust", + "limitingconeangle": "limitingConeAngle", + "markerheight": "markerHeight", + "markerunits": "markerUnits", + "markerwidth": "markerWidth", + "maskcontentunits": "maskContentUnits", + "maskunits": "maskUnits", + "numoctaves": "numOctaves", + "pathlength": "pathLength", + "patterncontentunits": "patternContentUnits", + "patterntransform": "patternTransform", + "patternunits": "patternUnits", + "pointsatx": "pointsAtX", + "pointsaty": "pointsAtY", + "pointsatz": "pointsAtZ", + "preservealpha": "preserveAlpha", + "preserveaspectratio": "preserveAspectRatio", + "primitiveunits": "primitiveUnits", + "refx": "refX", + "refy": "refY", + "repeatcount": "repeatCount", + "repeatdur": "repeatDur", + "requiredextensions": "requiredExtensions", + "requiredfeatures": "requiredFeatures", + "specularconstant": "specularConstant", + "specularexponent": "specularExponent", + "spreadmethod": "spreadMethod", + "startoffset": "startOffset", + "stddeviation": "stdDeviation", + "stitchtiles": "stitchTiles", + "surfacescale": "surfaceScale", + "systemlanguage": "systemLanguage", + "tablevalues": "tableValues", + "targetx": "targetX", + "targety": "targetY", + "textlength": "textLength", + "viewbox": "viewBox", + "viewtarget": "viewTarget", + "xchannelselector": "xChannelSelector", + "ychannelselector": "yChannelSelector", + "zoomandpan": "zoomAndPan", +} @dataclass class OpenTElement: @@ -87,8 +187,10 @@ class TemplateParser(HTMLParser): stack: list[OpenTag] placeholders: PlaceholderState source: SourceTracker | None + _svg_depth: int = 0 - def __init__(self, *, convert_charrefs: bool = True): + def __init__(self, *, convert_charrefs: bool = True, svg_context: bool = False): + self._initial_svg_depth = 1 if svg_context else 0 # This calls HTMLParser.reset() which we override to set up our state. super().__init__(convert_charrefs=convert_charrefs) @@ -108,10 +210,15 @@ def append_child(self, child: TNode) -> None: # Attribute Helpers # ------------------------------------------ - def make_tattr(self, attr: HTMLAttribute) -> TAttribute: + def make_tattr( + self, attr: HTMLAttribute, svg_context: bool = False + ) -> TAttribute: """Build a TAttribute from a raw attribute tuple.""" name, value = attr + if svg_context: + name = SVG_CASE_FIX.get(name, name) + name_ref = self.placeholders.remove_placeholders(name) value_ref = ( self.placeholders.remove_placeholders(value) if value is not None else None @@ -136,20 +243,26 @@ def make_tattr(self, attr: HTMLAttribute) -> TAttribute: ) return TSpreadAttribute(i_index=name_ref.i_indexes[0]) - def make_tattrs(self, attrs: t.Sequence[HTMLAttribute]) -> tuple[TAttribute, ...]: + def make_tattrs( + self, attrs: t.Sequence[HTMLAttribute], svg_context: bool = False + ) -> tuple[TAttribute, ...]: """Build TAttributes from raw attribute tuples.""" - return tuple(self.make_tattr(attr) for attr in attrs) + return tuple(self.make_tattr(attr, svg_context) for attr in attrs) # ------------------------------------------ # Tag Helpers # ------------------------------------------ - def make_open_tag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> OpenTag: + def make_open_tag( + self, tag: str, attrs: t.Sequence[HTMLAttribute], svg_context: bool = False + ) -> OpenTag: """Build an OpenTag from a raw tag and attribute tuples.""" tag_ref = self.placeholders.remove_placeholders(tag) if tag_ref.is_literal: - return OpenTElement(tag=tag, attrs=self.make_tattrs(attrs)) + return OpenTElement( + tag=tag, attrs=self.make_tattrs(attrs, svg_context) + ) if not tag_ref.is_singleton: raise ValueError( @@ -162,7 +275,7 @@ def make_open_tag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> OpenTag: i_index = tag_ref.i_indexes[0] return OpenTComponent( start_i_index=i_index, - attrs=self.make_tattrs(attrs), + attrs=self.make_tattrs(attrs, svg_context), ) def finalize_tag( @@ -225,7 +338,13 @@ def validate_end_tag(self, tag: str, open_tag: OpenTag) -> int | None: # ------------------------------------------ def handle_starttag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None: - open_tag = self.make_open_tag(tag, attrs) + if tag == "svg": + self._svg_depth += 1 + + if self._svg_depth > 0: + tag = SVG_TAG_FIX.get(tag, tag) + + open_tag = self.make_open_tag(tag, attrs, svg_context=(self._svg_depth > 0)) if isinstance(open_tag, OpenTElement) and open_tag.tag in VOID_ELEMENTS: final_tag = self.finalize_tag(open_tag) self.append_child(final_tag) @@ -234,7 +353,13 @@ def handle_starttag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None: def handle_startendtag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None: """Dispatch a self-closing tag, `` to specialized handlers.""" - open_tag = self.make_open_tag(tag, attrs) + is_svg_tag = tag == "svg" + effective_svg_context = (self._svg_depth > 0) or is_svg_tag + + if effective_svg_context: + tag = SVG_TAG_FIX.get(tag, tag) + + open_tag = self.make_open_tag(tag, attrs, svg_context=effective_svg_context) final_tag = self.finalize_tag(open_tag) self.append_child(final_tag) @@ -242,6 +367,12 @@ def handle_endtag(self, tag: str) -> None: if not self.stack: raise ValueError(f"Unexpected closing tag with no open tag.") + if self._svg_depth > 0: + tag = SVG_TAG_FIX.get(tag, tag) + + if tag == "svg": + self._svg_depth -= 1 + open_tag = self.stack.pop() endtag_i_index = self.validate_end_tag(tag, open_tag) final_tag = self.finalize_tag(open_tag, endtag_i_index) @@ -285,6 +416,7 @@ def reset(self): self.stack = [] self.placeholders = PlaceholderState() self.source = None + self._svg_depth = getattr(self, "_initial_svg_depth", 0) def close(self) -> None: if self.stack: @@ -337,13 +469,16 @@ def feed_template(self, template: Template) -> None: self.feed_str(template.strings[-1]) @staticmethod - def parse(t: Template) -> TNode: + def parse(t: Template, *, svg_context: bool = False) -> TNode: """ Parse a Template containing valid HTML and substitutions and return a TNode tree representing its structure. This cachable structure can later be resolved against actual interpolation values to produce a Node tree. + + Pass ``svg_context=True`` for SVG fragments that have no ```` + wrapper, so that tag and attribute case-fixing applies from the root. """ - parser = TemplateParser() + parser = TemplateParser(svg_context=svg_context) parser.feed_template(t) parser.close() return parser.get_tnode() diff --git a/tdom/processor.py b/tdom/processor.py index 1e2528d..2da8286 100644 --- a/tdom/processor.py +++ b/tdom/processor.py @@ -45,7 +45,7 @@ def __html__(self) -> str: ... # pragma: no cover @lru_cache(maxsize=0 if "pytest" in sys.modules else 512) def _parse_and_cache(cachable: CachableTemplate) -> TNode: - return TemplateParser.parse(cachable.template) + return TemplateParser.parse(cachable.template, svg_context=cachable.svg_context) type Attribute = tuple[str, object] @@ -470,6 +470,7 @@ def _invoke_component( and passed as keyword arguments if the callable accepts them (or has **kwargs). Attributes that don't match parameters are silently ignored. """ + component_name = interpolation.expression or "unknown component" value = format_interpolation(interpolation) if not callable(value): raise TypeError( @@ -478,9 +479,11 @@ def _invoke_component( callable_info = get_callable_info(value) if callable_info.requires_positional: - raise TypeError( + err = TypeError( "Component callables cannot have required positional arguments." ) + err.add_note(f"While invoking component: {component_name}") + raise err kwargs: AttributesDict = {} @@ -497,11 +500,17 @@ def _invoke_component( # Check to make sure we've fully satisfied the callable's requirements missing = callable_info.required_named_params - kwargs.keys() if missing: - raise TypeError( + err = TypeError( f"Missing required parameters for component: {', '.join(missing)}" ) - - result = value(**kwargs) + err.add_note(f"While invoking component: {component_name}") + raise err + + try: + result = value(**kwargs) + except TypeError as e: + e.add_note(f"While invoking component: {component_name}") + raise return _node_from_value(result) @@ -591,7 +600,22 @@ def _resolve_t_node(t_node: TNode, interpolations: tuple[Interpolation, ...]) -> def html(template: Template) -> Node: - """Parse an HTML t-string, substitue values, and return a tree of Nodes.""" + """Parse an HTML t-string, substitute values, and return a tree of Nodes.""" cachable = CachableTemplate(template) t_node = _parse_and_cache(cachable) return _resolve_t_node(t_node, template.interpolations) + + +def svg(template: Template) -> Node: + """Parse a standalone SVG fragment and return a tree of Nodes. + + Use when the template does not contain an ```` wrapper element. + Tag and attribute case-fixing (e.g. ``clipPath``, ``viewBox``) are applied + from the root, exactly as they would be inside ``html(t"...")``. + + When the template does contain ````, use ``html()`` — the SVG context + is detected automatically. + """ + cachable = CachableTemplate(template, svg_context=True) + t_node = _parse_and_cache(cachable) + return _resolve_t_node(t_node, template.interpolations) diff --git a/tdom/svg_test.py b/tdom/svg_test.py new file mode 100644 index 0000000..e5810f5 --- /dev/null +++ b/tdom/svg_test.py @@ -0,0 +1,104 @@ +import pytest + +from tdom import html, svg + + +# svg() — tag case-fixing + +def test_svg_clippath_case_fixed(): + node = svg(t"") + assert str(node) == '' + + +def test_svg_lineargradient_case_fixed(): + node = svg(t"") + assert str(node) == '' + + +def test_svg_femergenode_self_closing_case_fixed(): + node = svg(t"") + assert str(node) == "" + + +def test_svg_nested_tags_case_fixed(): + node = svg(t"") + assert str(node) == '' + + +# ------------------------------ +# svg() — attribute case-fixing +# ------------------------------ + + +def test_svg_viewbox_attr_case_fixed(): + node = svg(t'') + assert str(node) == '' + +def test_svg_case_sensitivity(): + # SVG attributes like viewBox are case-sensitive + node = html(t'') + # We expect viewBox, not viewbox + assert 'viewBox' in str(node) + +def test_svg_tag_case_sensitivity(): + # SVG tags like linearGradient are case-sensitive + node = html(t'') + assert 'linearGradient' in str(node) + +def test_svg_tag_case_sensitivity_outside_svg(): + # Outside SVG, tags should be lowercased + node = html(t'') + assert 'lineargradient' in str(node) + +def test_svg_attr_case_sensitivity_outside_svg(): + # Outside SVG, attributes should be lowercased + node = html(t'
') + assert 'viewbox' in str(node) + +def test_svg_interpolated_attr(): + cx, cy, r = 50, 50, 40 + node = svg(t'') + assert str(node) == '' + + +def test_svg_interpolated_child(): + label = "hello" + node = svg(t"{label}") + assert str(node) == "hello" + + +def test_svg_fragment_multiple_roots(): + node = svg(t"") + assert str(node) == "" + + +# --------------------------------------------------------- +# svg() vs html() — same strings, distinct parse results +# --------------------------------------------------------- + + +def test_svg_and_html_produce_different_results_for_same_strings(): + # html() lowercases clipPath (no SVG context); svg() preserves it. + html_node = html(t"") + svg_node = svg(t"") + assert str(html_node) == "" + assert str(svg_node) == "" + + +def test_html_full_svg_document_still_works(): + # html() auto-detects SVG context when is present — no regression. + node = html(t"") + assert str(node) == '' + + +# ------------------------------- +# svg() composable inside html() +# ------------------------------- + + +def test_svg_fragment_embedded_in_html(): + def icon(): + return svg(t'') + + node = html(t'
{icon()}
') + assert str(node) == '
' diff --git a/tdom/utils.py b/tdom/utils.py index 5f69c83..a2e298b 100644 --- a/tdom/utils.py +++ b/tdom/utils.py @@ -15,13 +15,14 @@ class CachableTemplate: # CONSIDER: what about interpolation format specs, convsersions, etc.? - def __init__(self, template: Template) -> None: + def __init__(self, template: Template, svg_context: bool = False) -> None: self.template = template + self.svg_context = svg_context def __eq__(self, other: object) -> bool: if not isinstance(other, CachableTemplate): return NotImplemented - return self.template.strings == other.template.strings + return self.template.strings == other.template.strings and self.svg_context == other.svg_context def __hash__(self) -> int: - return hash(self.template.strings) + return hash((self.template.strings, self.svg_context))