diff --git a/lib/yaml/representer.py b/lib/yaml/representer.py index 808ca06df..601ed6304 100644 --- a/lib/yaml/representer.py +++ b/lib/yaml/representer.py @@ -5,7 +5,11 @@ from .error import * from .nodes import * -import datetime, copyreg, types, base64, collections +import datetime, copyreg, types, base64, collections, sys + +if sys.version_info >= (3, 11): + from enum import IntEnum, StrEnum + class RepresenterError(YAMLError): pass @@ -33,6 +37,9 @@ def represent(self, data): def represent_data(self, data): if self.ignore_aliases(data): self.alias_key = None + if sys.version_info >= (3, 11): + if isinstance(data, (IntEnum, StrEnum)): + data = data.value else: self.alias_key = id(data) if self.alias_key is not None: diff --git a/tests/test_dump_load.py b/tests/test_dump_load.py index 8c4352bd1..1cbcc9d6f 100644 --- a/tests/test_dump_load.py +++ b/tests/test_dump_load.py @@ -1,3 +1,5 @@ +import sys + import pytest import yaml @@ -13,3 +15,13 @@ def test_load_no_loader(): def test_load_safeloader(): assert yaml.load("- foo\n", Loader=yaml.SafeLoader) + + +@pytest.mark.skipif(sys.version_info < (3, 11), reason="Requires Python 3.11 or higher") +def test_dump_str_enum(): + from enum import StrEnum + + class ContentType(StrEnum): + YAML = "YAML" + + assert yaml.safe_load(yaml.safe_dump(ContentType.YAML)) == ContentType.YAML