diff --git a/src/packaging/errors.py b/src/packaging/errors.py new file mode 100644 index 00000000..d1d47cf6 --- /dev/null +++ b/src/packaging/errors.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import contextlib +import dataclasses +import sys +import typing + +__all__ = ["ExceptionGroup"] + + +def __dir__() -> list[str]: + return __all__ + + +if sys.version_info >= (3, 11): # pragma: no cover + from builtins import ExceptionGroup +else: # pragma: no cover + + class ExceptionGroup(Exception): + """A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11. + + If :external:exc:`ExceptionGroup` is already defined by Python itself, + that version is used instead. + """ + + message: str + exceptions: list[Exception] + + def __init__(self, message: str, exceptions: list[Exception]) -> None: + self.message = message + self.exceptions = exceptions + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})" + + +@dataclasses.dataclass +class _ErrorCollector: + """ + Collect errors into ExceptionGroups. + + Used like this: + + collector = _ErrorCollector() + # Add a single exception + collector.error(ValueError("one")) + + # Supports nesting, including combining ExceptionGroups + with collector.collect(): + raise ValueError("two") + collector.finalize("Found some errors") + + Since making a collector and then calling finalize later is a common pattern, + a convenience method ``on_exit`` is provided. + """ + + errors: list[Exception] = dataclasses.field(default_factory=list, init=False) + + def finalize(self, msg: str) -> None: + """Raise a group exception if there are any errors.""" + if self.errors: + raise ExceptionGroup(msg, self.errors) + + @contextlib.contextmanager + def on_exit(self, msg: str) -> typing.Generator[_ErrorCollector, None, None]: + """ + Calls finalize if no uncollected errors were present. + + Uncollected errors are raised normally. + """ + yield self + self.finalize(msg) + + @contextlib.contextmanager + def collect(self, *err_cls: type[Exception]) -> typing.Generator[None, None, None]: + """ + Context manager to collect errors into the error list. + + Must be inside loops, as only one error can be collected at a time. + """ + error_classes = err_cls or (Exception,) + try: + yield + except ExceptionGroup as error: + self.errors.extend(error.exceptions) + except error_classes as error: + self.errors.append(error) + + def error( + self, + error: Exception, + ) -> None: + """Add an error to the list.""" + self.errors.append(error) diff --git a/src/packaging/metadata.py b/src/packaging/metadata.py index 4dd08f42..5eb488bb 100644 --- a/src/packaging/metadata.py +++ b/src/packaging/metadata.py @@ -6,7 +6,6 @@ import email.policy import keyword import pathlib -import sys import typing from typing import ( Any, @@ -19,6 +18,7 @@ from . import licenses, requirements, specifiers, utils from . import version as version_module +from .errors import ExceptionGroup, _ErrorCollector if typing.TYPE_CHECKING: from .licenses import NormalizedLicenseExpression @@ -26,28 +26,6 @@ T = typing.TypeVar("T") -if sys.version_info >= (3, 11): # pragma: no cover - ExceptionGroup = ExceptionGroup # noqa: F821 -else: # pragma: no cover - - class ExceptionGroup(Exception): - """A minimal implementation of :external:exc:`ExceptionGroup` from Python 3.11. - - If :external:exc:`ExceptionGroup` is already defined by Python itself, - that version is used instead. - """ - - message: str - exceptions: list[Exception] - - def __init__(self, message: str, exceptions: list[Exception]) -> None: - self.message = message - self.exceptions = exceptions - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.message!r}, {self.exceptions!r})" - - __all__ = [ "InvalidMetadata", "Metadata", @@ -797,13 +775,11 @@ def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> Metadata: ins._raw = data.copy() # Mutations occur due to caching enriched values. if validate: - exceptions: list[Exception] = [] - try: + collector = _ErrorCollector() + metadata_version = None + with collector.collect(InvalidMetadata): metadata_version = ins.metadata_version metadata_age = _VALID_METADATA_VERSIONS.index(metadata_version) - except InvalidMetadata as metadata_version_exc: - exceptions.append(metadata_version_exc) - metadata_version = None # Make sure to check for the fields that are present, the required # fields (so their absence can be reported). @@ -820,7 +796,7 @@ def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> Metadata: field_metadata_version = cls.__dict__[key].added except KeyError: exc = InvalidMetadata(key, f"unrecognized field: {key!r}") - exceptions.append(exc) + collector.error(exc) continue field_age = _VALID_METADATA_VERSIONS.index( field_metadata_version @@ -832,14 +808,13 @@ def from_raw(cls, data: RawMetadata, *, validate: bool = True) -> Metadata: f"{field} introduced in metadata version " f"{field_metadata_version}, not {metadata_version}", ) - exceptions.append(exc) + collector.error(exc) continue getattr(ins, key) except InvalidMetadata as exc: - exceptions.append(exc) + collector.error(exc) - if exceptions: - raise ExceptionGroup("invalid metadata", exceptions) + collector.finalize("invalid metadata") return ins @@ -853,16 +828,13 @@ def from_email(cls, data: bytes | str, *, validate: bool = True) -> Metadata: raw, unparsed = parse_email(data) if validate: - exceptions: list[Exception] = [] - for unparsed_key in unparsed: - if unparsed_key in _EMAIL_TO_RAW_MAPPING: - message = f"{unparsed_key!r} has invalid data" - else: - message = f"unrecognized field: {unparsed_key!r}" - exceptions.append(InvalidMetadata(unparsed_key, message)) - - if exceptions: - raise ExceptionGroup("unparsed", exceptions) + with _ErrorCollector().on_exit("unparsed") as collector: + for unparsed_key in unparsed: + if unparsed_key in _EMAIL_TO_RAW_MAPPING: + message = f"{unparsed_key!r} has invalid data" + else: + message = f"unrecognized field: {unparsed_key!r}" + collector.error(InvalidMetadata(unparsed_key, message)) try: return cls.from_raw(raw, validate=validate) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 00000000..d138f65d --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,108 @@ +import pytest + +import packaging.errors + + +def test_error_collector_collect() -> None: + collector = packaging.errors._ErrorCollector() + + with collector.collect(): + raise ValueError("first error") + + with collector.collect(): + raise KeyError("second error") + + collector.error(TypeError("third error")) + + with pytest.raises(packaging.errors.ExceptionGroup) as exc_info: + collector.finalize("collected errors") + + exception_group = exc_info.value + assert exception_group.message == "collected errors" + assert len(exception_group.exceptions) == 3 + assert isinstance(exception_group.exceptions[0], ValueError) + assert str(exception_group.exceptions[0]) == "first error" + assert isinstance(exception_group.exceptions[1], KeyError) + assert str(exception_group.exceptions[1]) == "'second error'" + assert isinstance(exception_group.exceptions[2], TypeError) + assert str(exception_group.exceptions[2]) == "third error" + + +def test_error_collector_no_errors() -> None: + collector = packaging.errors._ErrorCollector() + + with collector.collect(): + pass # No error + + collector.finalize("no errors") # Should not raise + + +def test_error_collector_exception_group() -> None: + collector = packaging.errors._ErrorCollector() + + with collector.collect(): + raise packaging.errors.ExceptionGroup( + "inner group", + [ValueError("inner error 1"), KeyError("inner error 2")], + ) + + with pytest.raises(packaging.errors.ExceptionGroup) as exc_info: + collector.finalize("outer group") + + exception_group = exc_info.value + assert exception_group.message == "outer group" + assert len(exception_group.exceptions) == 2 + assert isinstance(exception_group.exceptions[0], ValueError) + assert str(exception_group.exceptions[0]) == "inner error 1" + assert isinstance(exception_group.exceptions[1], KeyError) + assert str(exception_group.exceptions[1]) == "'inner error 2'" + + +def test_error_collector_on_exit() -> None: + collector = packaging.errors._ErrorCollector() + + with pytest.raises(packaging.errors.ExceptionGroup) as exc_info, collector.on_exit( + "exiting" + ): + collector.error(ValueError("an error")) + + exception_group = exc_info.value + assert exception_group.message == "exiting" + assert len(exception_group.exceptions) == 1 + assert isinstance(exception_group.exceptions[0], ValueError) + assert str(exception_group.exceptions[0]) == "an error" + + +def test_error_collector_on_exit_no_errors() -> None: + collector = packaging.errors._ErrorCollector() + + with collector.on_exit("exiting"): + pass # No errors added + + +def test_error_collector_collect_specific_exception() -> None: + collector = packaging.errors._ErrorCollector() + + with collector.collect(KeyError): + raise KeyError("a key error") + + with pytest.raises(packaging.errors.ExceptionGroup) as exc_info: + collector.finalize("collected errors") + + exception_group = exc_info.value + assert exception_group.message == "collected errors" + assert len(exception_group.exceptions) == 1 + assert isinstance(exception_group.exceptions[0], KeyError) + assert str(exception_group.exceptions[0]) == "'a key error'" + + +def test_error_collector_collect_unmatched_exception() -> None: + collector = packaging.errors._ErrorCollector() + + # Now test that other exceptions are not collected + with pytest.raises( + ValueError, match="a value error" + ) as exc_info, collector.collect(KeyError): + raise ValueError("a value error") + + assert str(exc_info.value) == "a value error" diff --git a/tests/test_metadata.py b/tests/test_metadata.py index db7e46ea..0af214d2 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1,14 +1,18 @@ from __future__ import annotations -import email.message +import email import inspect import pathlib import textwrap +import typing import pytest from packaging import metadata, requirements, specifiers, utils, version -from packaging.metadata import ExceptionGroup, RawMetadata +from packaging.errors import ExceptionGroup + +if typing.TYPE_CHECKING: + from packaging.metadata import RawMetadata class TestRawMetadata: @@ -259,13 +263,13 @@ def test_complete(self) -> None: class TestExceptionGroup: def test_attributes(self) -> None: individual_exception = Exception("not important") - exc = metadata.ExceptionGroup("message", [individual_exception]) + exc = ExceptionGroup("message", [individual_exception]) assert exc.message == "message" assert list(exc.exceptions) == [individual_exception] def test_repr(self) -> None: individual_exception = RuntimeError("not important") - exc = metadata.ExceptionGroup("message", [individual_exception]) + exc = ExceptionGroup("message", [individual_exception]) assert individual_exception.__class__.__name__ in repr(exc)