diff --git a/tdom/__init__.py b/tdom/__init__.py
index 4503582..cd5a291 100644
--- a/tdom/__init__.py
+++ b/tdom/__init__.py
@@ -1,7 +1,7 @@
from markupsafe import Markup, escape
from .nodes import Comment, DocumentType, Element, Fragment, Node, Text
-from .processor import html
+from .processor import html, svg
# We consider `Markup` and `escape` to be part of this module's public API
@@ -13,6 +13,7 @@
"Fragment",
"html",
"Markup",
+ "svg",
"Node",
"Text",
]
diff --git a/tdom/parser.py b/tdom/parser.py
index 82de290..76f96c8 100644
--- a/tdom/parser.py
+++ b/tdom/parser.py
@@ -25,6 +25,106 @@
type HTMLAttribute = tuple[str, str | None]
type HTMLAttributesDict = dict[str, str | None]
+SVG_TAG_FIX = {
+ "altglyph": "altGlyph",
+ "altglyphdef": "altGlyphDef",
+ "altglyphitem": "altGlyphItem",
+ "animatecolor": "animateColor",
+ "animatemotion": "animateMotion",
+ "animatetransform": "animateTransform",
+ "clippath": "clipPath",
+ "feblend": "feBlend",
+ "fecolormatrix": "feColorMatrix",
+ "fecomponenttransfer": "feComponentTransfer",
+ "fecomposite": "feComposite",
+ "feconvolvematrix": "feConvolveMatrix",
+ "fediffuselighting": "feDiffuseLighting",
+ "fedisplacementmap": "feDisplacementMap",
+ "fedistantlight": "feDistantLight",
+ "fedropshadow": "feDropShadow",
+ "feflood": "feFlood",
+ "fefunca": "feFuncA",
+ "fefuncb": "feFuncB",
+ "fefuncg": "feFuncG",
+ "fefuncr": "feFuncR",
+ "fegaussianblur": "feGaussianBlur",
+ "feimage": "feImage",
+ "femerge": "feMerge",
+ "femergenode": "feMergeNode",
+ "femorphology": "feMorphology",
+ "feoffset": "feOffset",
+ "fepointlight": "fePointLight",
+ "fespecularlighting": "feSpecularLighting",
+ "fespotlight": "feSpotLight",
+ "fetile": "feTile",
+ "feturbulence": "feTurbulence",
+ "foreignobject": "foreignObject",
+ "glyphref": "glyphRef",
+ "lineargradient": "linearGradient",
+ "radialgradient": "radialGradient",
+ "textpath": "textPath",
+}
+
+SVG_CASE_FIX = {
+ "attributename": "attributeName",
+ "attributetype": "attributeType",
+ "basefrequency": "baseFrequency",
+ "baseprofile": "baseProfile",
+ "calcmode": "calcMode",
+ "clippathunits": "clipPathUnits",
+ "diffuseconstant": "diffuseConstant",
+ "edgemode": "edgeMode",
+ "filterunits": "filterUnits",
+ "glyphref": "glyphRef",
+ "gradienttransform": "gradientTransform",
+ "gradientunits": "gradientUnits",
+ "kernelmatrix": "kernelMatrix",
+ "kernelunitlength": "kernelUnitLength",
+ "keypoints": "keyPoints",
+ "keysplines": "keySplines",
+ "keytimes": "keyTimes",
+ "lengthadjust": "lengthAdjust",
+ "limitingconeangle": "limitingConeAngle",
+ "markerheight": "markerHeight",
+ "markerunits": "markerUnits",
+ "markerwidth": "markerWidth",
+ "maskcontentunits": "maskContentUnits",
+ "maskunits": "maskUnits",
+ "numoctaves": "numOctaves",
+ "pathlength": "pathLength",
+ "patterncontentunits": "patternContentUnits",
+ "patterntransform": "patternTransform",
+ "patternunits": "patternUnits",
+ "pointsatx": "pointsAtX",
+ "pointsaty": "pointsAtY",
+ "pointsatz": "pointsAtZ",
+ "preservealpha": "preserveAlpha",
+ "preserveaspectratio": "preserveAspectRatio",
+ "primitiveunits": "primitiveUnits",
+ "refx": "refX",
+ "refy": "refY",
+ "repeatcount": "repeatCount",
+ "repeatdur": "repeatDur",
+ "requiredextensions": "requiredExtensions",
+ "requiredfeatures": "requiredFeatures",
+ "specularconstant": "specularConstant",
+ "specularexponent": "specularExponent",
+ "spreadmethod": "spreadMethod",
+ "startoffset": "startOffset",
+ "stddeviation": "stdDeviation",
+ "stitchtiles": "stitchTiles",
+ "surfacescale": "surfaceScale",
+ "systemlanguage": "systemLanguage",
+ "tablevalues": "tableValues",
+ "targetx": "targetX",
+ "targety": "targetY",
+ "textlength": "textLength",
+ "viewbox": "viewBox",
+ "viewtarget": "viewTarget",
+ "xchannelselector": "xChannelSelector",
+ "ychannelselector": "yChannelSelector",
+ "zoomandpan": "zoomAndPan",
+}
@dataclass
class OpenTElement:
@@ -87,8 +187,10 @@ class TemplateParser(HTMLParser):
stack: list[OpenTag]
placeholders: PlaceholderState
source: SourceTracker | None
+ _svg_depth: int = 0
- def __init__(self, *, convert_charrefs: bool = True):
+ def __init__(self, *, convert_charrefs: bool = True, svg_context: bool = False):
+ self._initial_svg_depth = 1 if svg_context else 0
# This calls HTMLParser.reset() which we override to set up our state.
super().__init__(convert_charrefs=convert_charrefs)
@@ -108,10 +210,15 @@ def append_child(self, child: TNode) -> None:
# Attribute Helpers
# ------------------------------------------
- def make_tattr(self, attr: HTMLAttribute) -> TAttribute:
+ def make_tattr(
+ self, attr: HTMLAttribute, svg_context: bool = False
+ ) -> TAttribute:
"""Build a TAttribute from a raw attribute tuple."""
name, value = attr
+ if svg_context:
+ name = SVG_CASE_FIX.get(name, name)
+
name_ref = self.placeholders.remove_placeholders(name)
value_ref = (
self.placeholders.remove_placeholders(value) if value is not None else None
@@ -136,20 +243,26 @@ def make_tattr(self, attr: HTMLAttribute) -> TAttribute:
)
return TSpreadAttribute(i_index=name_ref.i_indexes[0])
- def make_tattrs(self, attrs: t.Sequence[HTMLAttribute]) -> tuple[TAttribute, ...]:
+ def make_tattrs(
+ self, attrs: t.Sequence[HTMLAttribute], svg_context: bool = False
+ ) -> tuple[TAttribute, ...]:
"""Build TAttributes from raw attribute tuples."""
- return tuple(self.make_tattr(attr) for attr in attrs)
+ return tuple(self.make_tattr(attr, svg_context) for attr in attrs)
# ------------------------------------------
# Tag Helpers
# ------------------------------------------
- def make_open_tag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> OpenTag:
+ def make_open_tag(
+ self, tag: str, attrs: t.Sequence[HTMLAttribute], svg_context: bool = False
+ ) -> OpenTag:
"""Build an OpenTag from a raw tag and attribute tuples."""
tag_ref = self.placeholders.remove_placeholders(tag)
if tag_ref.is_literal:
- return OpenTElement(tag=tag, attrs=self.make_tattrs(attrs))
+ return OpenTElement(
+ tag=tag, attrs=self.make_tattrs(attrs, svg_context)
+ )
if not tag_ref.is_singleton:
raise ValueError(
@@ -162,7 +275,7 @@ def make_open_tag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> OpenTag:
i_index = tag_ref.i_indexes[0]
return OpenTComponent(
start_i_index=i_index,
- attrs=self.make_tattrs(attrs),
+ attrs=self.make_tattrs(attrs, svg_context),
)
def finalize_tag(
@@ -225,7 +338,13 @@ def validate_end_tag(self, tag: str, open_tag: OpenTag) -> int | None:
# ------------------------------------------
def handle_starttag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None:
- open_tag = self.make_open_tag(tag, attrs)
+ if tag == "svg":
+ self._svg_depth += 1
+
+ if self._svg_depth > 0:
+ tag = SVG_TAG_FIX.get(tag, tag)
+
+ open_tag = self.make_open_tag(tag, attrs, svg_context=(self._svg_depth > 0))
if isinstance(open_tag, OpenTElement) and open_tag.tag in VOID_ELEMENTS:
final_tag = self.finalize_tag(open_tag)
self.append_child(final_tag)
@@ -234,7 +353,13 @@ def handle_starttag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None:
def handle_startendtag(self, tag: str, attrs: t.Sequence[HTMLAttribute]) -> None:
"""Dispatch a self-closing tag, `` to specialized handlers."""
- open_tag = self.make_open_tag(tag, attrs)
+ is_svg_tag = tag == "svg"
+ effective_svg_context = (self._svg_depth > 0) or is_svg_tag
+
+ if effective_svg_context:
+ tag = SVG_TAG_FIX.get(tag, tag)
+
+ open_tag = self.make_open_tag(tag, attrs, svg_context=effective_svg_context)
final_tag = self.finalize_tag(open_tag)
self.append_child(final_tag)
@@ -242,6 +367,12 @@ def handle_endtag(self, tag: str) -> None:
if not self.stack:
raise ValueError(f"Unexpected closing tag {tag}> with no open tag.")
+ if self._svg_depth > 0:
+ tag = SVG_TAG_FIX.get(tag, tag)
+
+ if tag == "svg":
+ self._svg_depth -= 1
+
open_tag = self.stack.pop()
endtag_i_index = self.validate_end_tag(tag, open_tag)
final_tag = self.finalize_tag(open_tag, endtag_i_index)
@@ -285,6 +416,7 @@ def reset(self):
self.stack = []
self.placeholders = PlaceholderState()
self.source = None
+ self._svg_depth = getattr(self, "_initial_svg_depth", 0)
def close(self) -> None:
if self.stack:
@@ -337,13 +469,16 @@ def feed_template(self, template: Template) -> None:
self.feed_str(template.strings[-1])
@staticmethod
- def parse(t: Template) -> TNode:
+ def parse(t: Template, *, svg_context: bool = False) -> TNode:
"""
Parse a Template containing valid HTML and substitutions and return
a TNode tree representing its structure. This cachable structure can later
be resolved against actual interpolation values to produce a Node tree.
+
+ Pass ``svg_context=True`` for SVG fragments that have no ``