diff --git a/lib/yaml/constructor.py b/lib/yaml/constructor.py index 619acd307..1ffd3d5f7 100644 --- a/lib/yaml/constructor.py +++ b/lib/yaml/constructor.py @@ -20,6 +20,7 @@ class BaseConstructor: yaml_constructors = {} yaml_multi_constructors = {} + yaml_flatteners = {} def __init__(self): self.constructed_objects = {} @@ -168,6 +169,12 @@ def add_multi_constructor(cls, tag_prefix, multi_constructor): cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy() cls.yaml_multi_constructors[tag_prefix] = multi_constructor + @classmethod + def add_flattener(cls, tag, flattener): + if not 'yaml_flatteners' in cls.__dict__: + cls.yaml_flatteners = cls.yaml_flatteners.copy() + cls.yaml_flatteners[tag] = flattener + class SafeConstructor(BaseConstructor): def construct_scalar(self, node): @@ -179,38 +186,22 @@ def construct_scalar(self, node): def flatten_mapping(self, node): merge = [] - index = 0 - while index < len(node.value): - key_node, value_node = node.value[index] + for key_node, value_node in node.value: if key_node.tag == 'tag:yaml.org,2002:merge': - del node.value[index] - if isinstance(value_node, MappingNode): - self.flatten_mapping(value_node) - merge.extend(value_node.value) - elif isinstance(value_node, SequenceNode): - submerge = [] - for subnode in value_node.value: - if not isinstance(subnode, MappingNode): - raise ConstructorError("while constructing a mapping", - node.start_mark, - "expected a mapping for merging, but found %s" - % subnode.id, subnode.start_mark) - self.flatten_mapping(subnode) - submerge.append(subnode.value) - submerge.reverse() - for value in submerge: + flattener = self.yaml_flatteners.get(value_node.tag) + if flattener: + for value in flattener(self, value_node): merge.extend(value) else: raise ConstructorError("while constructing a mapping", node.start_mark, "expected a mapping or list of mappings for merging, but found %s" % value_node.id, value_node.start_mark) - elif key_node.tag == 'tag:yaml.org,2002:value': - key_node.tag = 'tag:yaml.org,2002:str' - index += 1 else: - index += 1 - if merge: - node.value = merge + node.value + if key_node.tag == 'tag:yaml.org,2002:value': + key_node.tag = 'tag:yaml.org,2002:str' + merge.append((key_node, value_node)) + + node.value = merge def construct_mapping(self, node, deep=False): if isinstance(node, MappingNode): @@ -428,6 +419,30 @@ def construct_undefined(self, node): "could not determine a constructor for the tag %r" % node.tag, node.start_mark) + def flatten_yaml_seq(self, node): + submerge = [] + for subnode in node.value: + # we need to flatten each item in the seq, most likely they'll be mappings, + # but we need to allow for custom flatteners as well. + flattener = self.yaml_flatteners.get(subnode.tag) + if flattener: + for value in flattener(self, subnode): + submerge.append(value) + else: + raise ConstructorError("while constructing a mapping", + node.start_mark, + "expected a mapping for merging, but found %s" + % subnode.id, subnode.start_mark) + + submerge.reverse() + for value in submerge: + yield value + + def flatten_yaml_map(self, node): + self.flatten_mapping(node) + yield node.value + + SafeConstructor.add_constructor( 'tag:yaml.org,2002:null', SafeConstructor.construct_yaml_null) @@ -479,6 +494,14 @@ def construct_undefined(self, node): SafeConstructor.add_constructor(None, SafeConstructor.construct_undefined) +SafeConstructor.add_flattener( + 'tag:yaml.org,2002:seq', + SafeConstructor.flatten_yaml_seq) + +SafeConstructor.add_flattener( + 'tag:yaml.org,2002:map', + SafeConstructor.flatten_yaml_map) + class FullConstructor(SafeConstructor): # 'extend' is blacklisted because it is used by # construct_python_object_apply to add `listitems` to a newly generate diff --git a/tests/test_constructor_flattener.py b/tests/test_constructor_flattener.py new file mode 100644 index 000000000..24d3b6f2a --- /dev/null +++ b/tests/test_constructor_flattener.py @@ -0,0 +1,138 @@ +import pytest +import yaml +from io import StringIO + + +def test_custom_flattener_with_compose(): + """ + Test that a custom flattener can use compose to parse embedded YAML. + """ + + def flatten_yaml_string(constructor, node): + """ + Flattener that parses YAML from a string scalar. + Used to test that custom flatteners can call compose. + """ + if isinstance(node, yaml.ScalarNode): + yaml_string = constructor.construct_scalar(node) + parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader) + + if isinstance(parsed_node, yaml.MappingNode): + constructor.flatten_mapping(parsed_node) + yield parsed_node.value + else: + raise yaml.ConstructorError( + "while constructing a mapping", + node.start_mark, + "expected a string containing YAML for !yaml_string tag", + node.start_mark + ) + + class CustomLoader(yaml.SafeLoader): + pass + + CustomLoader.add_flattener('!yaml_string', flatten_yaml_string) + + test_yaml = """ +base: + x: 1 + y: 2 + +merged: + <<: !yaml_string "x: 10\\ny: 20\\nz: 30" + label: test +""" + + result = yaml.load(test_yaml, CustomLoader) + + assert result['merged']['x'] == 10 + assert result['merged']['y'] == 20 + assert result['merged']['z'] == 30 + assert result['merged']['label'] == 'test' + + +def test_custom_flattener_with_sequence(): + """ + Test custom flattener with sequence of YAML strings (multiple merges). + """ + + def flatten_yaml_string(constructor, node): + if isinstance(node, yaml.ScalarNode): + yaml_string = constructor.construct_scalar(node) + parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader) + + if isinstance(parsed_node, yaml.MappingNode): + constructor.flatten_mapping(parsed_node) + yield parsed_node.value + else: + raise yaml.ConstructorError( + "while constructing a mapping", + node.start_mark, + "expected a string containing YAML for !yaml_string tag", + node.start_mark + ) + + class CustomLoader(yaml.SafeLoader): + pass + + CustomLoader.add_flattener('!yaml_string', flatten_yaml_string) + + test_yaml = """ +merged: + <<: [!yaml_string "x: 1", !yaml_string "y: 2", !yaml_string "z: 3"] + label: multi +""" + + result = yaml.load(test_yaml, CustomLoader) + + assert result['merged']['x'] == 1 + assert result['merged']['y'] == 2 + assert result['merged']['z'] == 3 + assert result['merged']['label'] == 'multi' + + +def test_custom_flattener_override(): + """ + Test that later values override earlier ones in merge sequence. + + It also lightly verifies that the original/standard MappingNode + inside of a sequence functions as expected. + """ + + def flatten_yaml_string(constructor, node): + if isinstance(node, yaml.ScalarNode): + yaml_string = constructor.construct_scalar(node) + parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader) + + if isinstance(parsed_node, yaml.MappingNode): + constructor.flatten_mapping(parsed_node) + yield parsed_node.value + else: + raise yaml.ConstructorError( + "while constructing a mapping", + node.start_mark, + "expected a string containing YAML for !yaml_string tag", + node.start_mark + ) + + class CustomLoader(yaml.SafeLoader): + pass + + CustomLoader.add_flattener('!yaml_string', flatten_yaml_string) + + test_yaml = """ +merged: + v: 0 + w: overwritten + <<: [{w: 1, x: overwritten, y: 3}, !yaml_string "x: 10"] + x: 2 + z: 4 +""" + + result = yaml.load(test_yaml, CustomLoader) + + assert result['merged']['v'] == 0 + assert result['merged']['w'] == 1 + assert result['merged']['x'] == 2 + assert result['merged']['y'] == 3 + assert result['merged']['z'] == 4 \ No newline at end of file