Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 48 additions & 25 deletions lib/yaml/constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class BaseConstructor:

yaml_constructors = {}
yaml_multi_constructors = {}
yaml_flatteners = {}

def __init__(self):
self.constructed_objects = {}
Expand Down Expand Up @@ -168,6 +169,12 @@ def add_multi_constructor(cls, tag_prefix, multi_constructor):
cls.yaml_multi_constructors = cls.yaml_multi_constructors.copy()
cls.yaml_multi_constructors[tag_prefix] = multi_constructor

@classmethod
def add_flattener(cls, tag, flattener):
if not 'yaml_flatteners' in cls.__dict__:
cls.yaml_flatteners = cls.yaml_flatteners.copy()
cls.yaml_flatteners[tag] = flattener

class SafeConstructor(BaseConstructor):

def construct_scalar(self, node):
Expand All @@ -179,38 +186,22 @@ def construct_scalar(self, node):

def flatten_mapping(self, node):
merge = []
index = 0
while index < len(node.value):
key_node, value_node = node.value[index]
for key_node, value_node in node.value:
if key_node.tag == 'tag:yaml.org,2002:merge':
del node.value[index]
if isinstance(value_node, MappingNode):
self.flatten_mapping(value_node)
merge.extend(value_node.value)
elif isinstance(value_node, SequenceNode):
submerge = []
for subnode in value_node.value:
if not isinstance(subnode, MappingNode):
raise ConstructorError("while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s"
% subnode.id, subnode.start_mark)
self.flatten_mapping(subnode)
submerge.append(subnode.value)
submerge.reverse()
for value in submerge:
flattener = self.yaml_flatteners.get(value_node.tag)
if flattener:
for value in flattener(self, value_node):
merge.extend(value)
else:
raise ConstructorError("while constructing a mapping", node.start_mark,
"expected a mapping or list of mappings for merging, but found %s"
% value_node.id, value_node.start_mark)
elif key_node.tag == 'tag:yaml.org,2002:value':
key_node.tag = 'tag:yaml.org,2002:str'
index += 1
else:
index += 1
if merge:
node.value = merge + node.value
if key_node.tag == 'tag:yaml.org,2002:value':
key_node.tag = 'tag:yaml.org,2002:str'
merge.append((key_node, value_node))

node.value = merge

def construct_mapping(self, node, deep=False):
if isinstance(node, MappingNode):
Expand Down Expand Up @@ -428,6 +419,30 @@ def construct_undefined(self, node):
"could not determine a constructor for the tag %r" % node.tag,
node.start_mark)

def flatten_yaml_seq(self, node):
submerge = []
for subnode in node.value:
# we need to flatten each item in the seq, most likely they'll be mappings,
# but we need to allow for custom flatteners as well.
flattener = self.yaml_flatteners.get(subnode.tag)
if flattener:
for value in flattener(self, subnode):
submerge.append(value)
else:
raise ConstructorError("while constructing a mapping",
node.start_mark,
"expected a mapping for merging, but found %s"
% subnode.id, subnode.start_mark)

submerge.reverse()
for value in submerge:
yield value

def flatten_yaml_map(self, node):
self.flatten_mapping(node)
yield node.value


SafeConstructor.add_constructor(
'tag:yaml.org,2002:null',
SafeConstructor.construct_yaml_null)
Expand Down Expand Up @@ -479,6 +494,14 @@ def construct_undefined(self, node):
SafeConstructor.add_constructor(None,
SafeConstructor.construct_undefined)

SafeConstructor.add_flattener(
'tag:yaml.org,2002:seq',
SafeConstructor.flatten_yaml_seq)

SafeConstructor.add_flattener(
'tag:yaml.org,2002:map',
SafeConstructor.flatten_yaml_map)

class FullConstructor(SafeConstructor):
# 'extend' is blacklisted because it is used by
# construct_python_object_apply to add `listitems` to a newly generate
Expand Down
138 changes: 138 additions & 0 deletions tests/test_constructor_flattener.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import pytest
import yaml
from io import StringIO


def test_custom_flattener_with_compose():
"""
Test that a custom flattener can use compose to parse embedded YAML.
"""

def flatten_yaml_string(constructor, node):
"""
Flattener that parses YAML from a string scalar.
Used to test that custom flatteners can call compose.
"""
if isinstance(node, yaml.ScalarNode):
yaml_string = constructor.construct_scalar(node)
parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader)

if isinstance(parsed_node, yaml.MappingNode):
constructor.flatten_mapping(parsed_node)
yield parsed_node.value
else:
raise yaml.ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a string containing YAML for !yaml_string tag",
node.start_mark
)

class CustomLoader(yaml.SafeLoader):
pass

CustomLoader.add_flattener('!yaml_string', flatten_yaml_string)

test_yaml = """
base:
x: 1
y: 2

merged:
<<: !yaml_string "x: 10\\ny: 20\\nz: 30"
label: test
"""

result = yaml.load(test_yaml, CustomLoader)

assert result['merged']['x'] == 10
assert result['merged']['y'] == 20
assert result['merged']['z'] == 30
assert result['merged']['label'] == 'test'


def test_custom_flattener_with_sequence():
"""
Test custom flattener with sequence of YAML strings (multiple merges).
"""

def flatten_yaml_string(constructor, node):
if isinstance(node, yaml.ScalarNode):
yaml_string = constructor.construct_scalar(node)
parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader)

if isinstance(parsed_node, yaml.MappingNode):
constructor.flatten_mapping(parsed_node)
yield parsed_node.value
else:
raise yaml.ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a string containing YAML for !yaml_string tag",
node.start_mark
)

class CustomLoader(yaml.SafeLoader):
pass

CustomLoader.add_flattener('!yaml_string', flatten_yaml_string)

test_yaml = """
merged:
<<: [!yaml_string "x: 1", !yaml_string "y: 2", !yaml_string "z: 3"]
label: multi
"""

result = yaml.load(test_yaml, CustomLoader)

assert result['merged']['x'] == 1
assert result['merged']['y'] == 2
assert result['merged']['z'] == 3
assert result['merged']['label'] == 'multi'


def test_custom_flattener_override():
"""
Test that later values override earlier ones in merge sequence.

It also lightly verifies that the original/standard MappingNode
inside of a sequence functions as expected.
"""

def flatten_yaml_string(constructor, node):
if isinstance(node, yaml.ScalarNode):
yaml_string = constructor.construct_scalar(node)
parsed_node = yaml.compose(StringIO(yaml_string), yaml.SafeLoader)

if isinstance(parsed_node, yaml.MappingNode):
constructor.flatten_mapping(parsed_node)
yield parsed_node.value
else:
raise yaml.ConstructorError(
"while constructing a mapping",
node.start_mark,
"expected a string containing YAML for !yaml_string tag",
node.start_mark
)

class CustomLoader(yaml.SafeLoader):
pass

CustomLoader.add_flattener('!yaml_string', flatten_yaml_string)

test_yaml = """
merged:
v: 0
w: overwritten
<<: [{w: 1, x: overwritten, y: 3}, !yaml_string "x: 10"]
x: 2
z: 4
"""

result = yaml.load(test_yaml, CustomLoader)

assert result['merged']['v'] == 0
assert result['merged']['w'] == 1
assert result['merged']['x'] == 2
assert result['merged']['y'] == 3
assert result['merged']['z'] == 4