diff --git a/src/svgutils/compose.py b/src/svgutils/compose.py index 0932fe5..5a085bf 100644 --- a/src/svgutils/compose.py +++ b/src/svgutils/compose.py @@ -18,6 +18,7 @@ from svgutils import transform as _transform from svgutils.common import Unit +from svgutils.templates import SVGFigure CONFIG = { "svg.file_path": ".", @@ -109,24 +110,37 @@ class SVG(Element): replace pt units with px units to fix files created with matplotlib """ - def __init__(self, fname=None, fix_mpl=False): + def __init__(self, svg: SVGFigure, fix_mpl=False): + if fix_mpl: + w, h = svg.get_size() + svg.set_size((w.replace("pt", ""), h.replace("pt", ""))) + super(SVG, self).__init__(svg.getroot().root) + + # if height/width is in % units, we can't store the absolute values + if svg.width.endswith("%"): + self._width = None + else: + self._width = Unit(svg.width).to("px") + if svg.height.endswith("%"): + self._height = None + else: + self._height = Unit(svg.height).to("px") + + @classmethod + def fromfile(cls, fname=None, fix_mpl=False): if fname: - fname = os.path.join(CONFIG["svg.file_path"], fname) svg = _transform.fromfile(fname) - if fix_mpl: - w, h = svg.get_size() - svg.set_size((w.replace("pt", ""), h.replace("pt", ""))) - super(SVG, self).__init__(svg.getroot().root) - - # if height/width is in % units, we can't store the absolute values - if svg.width.endswith("%"): - self._width = None - else: - self._width = Unit(svg.width).to("px") - if svg.height.endswith("%"): - self._height = None - else: - self._height = Unit(svg.height).to("px") + return SVG(svg, fix_mpl) + else: + raise TypeError('fname is None!') + + @classmethod + def fromstring(cls, xml_string: str=None, fix_mpl=False): + if xml_string: + svg = _transform.fromstring(xml_string) + return SVG(svg, fix_mpl) + else: + raise TypeError('xml_string is None!') @property def width(self):