From 4e7693991b249dd14e4f77a74f3fd700375eb6cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pressl=2C=20=C5=A0t=C4=9Bp=C3=A1n?= Date: Sat, 27 Apr 2024 02:00:51 +0200 Subject: [PATCH] src/svgutils/compose.py: SVG.__init__ now accepts SVGFigure Two classmethods have been added - fromfile and fromstring. The fromfile method provides the same functionality as the previous constructor. The fromstring method allows for SVG construction from string. --- src/svgutils/compose.py | 46 +++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/src/svgutils/compose.py b/src/svgutils/compose.py index 0932fe5..5a085bf 100644 --- a/src/svgutils/compose.py +++ b/src/svgutils/compose.py @@ -18,6 +18,7 @@ from svgutils import transform as _transform from svgutils.common import Unit +from svgutils.templates import SVGFigure CONFIG = { "svg.file_path": ".", @@ -109,24 +110,37 @@ class SVG(Element): replace pt units with px units to fix files created with matplotlib """ - def __init__(self, fname=None, fix_mpl=False): + def __init__(self, svg: SVGFigure, fix_mpl=False): + if fix_mpl: + w, h = svg.get_size() + svg.set_size((w.replace("pt", ""), h.replace("pt", ""))) + super(SVG, self).__init__(svg.getroot().root) + + # if height/width is in % units, we can't store the absolute values + if svg.width.endswith("%"): + self._width = None + else: + self._width = Unit(svg.width).to("px") + if svg.height.endswith("%"): + self._height = None + else: + self._height = Unit(svg.height).to("px") + + @classmethod + def fromfile(cls, fname=None, fix_mpl=False): if fname: - fname = os.path.join(CONFIG["svg.file_path"], fname) svg = _transform.fromfile(fname) - if fix_mpl: - w, h = svg.get_size() - svg.set_size((w.replace("pt", ""), h.replace("pt", ""))) - super(SVG, self).__init__(svg.getroot().root) - - # if height/width is in % units, we can't store the absolute values - if svg.width.endswith("%"): - self._width = None - else: - self._width = Unit(svg.width).to("px") - if svg.height.endswith("%"): - self._height = None - else: - self._height = Unit(svg.height).to("px") + return SVG(svg, fix_mpl) + else: + raise TypeError('fname is None!') + + @classmethod + def fromstring(cls, xml_string: str=None, fix_mpl=False): + if xml_string: + svg = _transform.fromstring(xml_string) + return SVG(svg, fix_mpl) + else: + raise TypeError('xml_string is None!') @property def width(self):