From df646dc9281ba23f50e2f460fc49af5fea09d96d Mon Sep 17 00:00:00 2001 From: Marko Ristin-Kaufmann Date: Sat, 25 Oct 2025 15:54:26 +0200 Subject: [PATCH] Fix child invariants checked in `super().__init__` Previously, we determined the invariants based on the `self` passed to the function. However, in case of `super().__init__`, the invariants that need to be checked after the call are those belonging to to the super class, not the current (child) class. This change fixes the issue by passing in the class corresponding to the invariants alongside function and parameters, so that we can properly resolve which invariants need to be checked. Fixes #300. --- icontract/_checkers.py | 113 ++++++-- tests/test_inheritance_invariant.py | 397 +++++++++++++++++++++++++++- 2 files changed, 491 insertions(+), 19 deletions(-) diff --git a/icontract/_checkers.py b/icontract/_checkers.py index 0503317..24fa3d4 100644 --- a/icontract/_checkers.py +++ b/icontract/_checkers.py @@ -979,13 +979,14 @@ def wrapper(*args, **kwargs): # type: ignore return wrapper # type: ignore -def _decorate_with_invariants(func: CallableT, is_init: bool) -> CallableT: +def _decorate_with_invariants(func: CallableT, cls: ClassT, is_init: bool) -> CallableT: """ Decorate the method ``func`` with invariant checks. If the function has been already decorated with invariant checks, the function returns immediately. :param func: function to be wrapped + :param cls: class corresponding to the invariant and ``func`` :param is_init: True if the ``func`` is __init__ :return: function wrapped with invariant checks """ @@ -1027,7 +1028,25 @@ def wrapper(*args, **kwargs): # type: ignore try: result = func(*args, **kwargs) - for invariant in instance.__class__.__invariants__: + # NOTE (mristin): + # We go to the invariants corresponding to the class, not the instance, as we need to + # account also for a situation where super().__init__ is called. Here is an example: + # + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # def __init__(self) -> None: + # super().__init__() + # # ↖ After this call, only the invariants of A, but not B, have to be checked. + # # + # # However, the ``instance`` (i.e., resolved ``self``) in super().__init__ call points to + # # an instance of B, so instance.__cls__.__invariants__ refer to invariants of B, not A. + + # noinspection PyUnresolvedReferences + for invariant in cls.__invariants__: # type: ignore _assert_invariant(contract=invariant, instance=instance) return result @@ -1273,35 +1292,97 @@ def add_invariant_checks(cls: ClassT) -> None: ) ) - if init_func: - # We have to distinguish this special case which is used by named - # tuples and possibly other optimized data structures. - # In those cases, we have to wrap __new__ instead of __init__. - if init_func == object.__init__ and hasattr(cls, "__new__"): - new_func = getattr(cls, "__new__") - setattr(cls, "__new__", _decorate_new_with_invariants(new_func)) - else: - wrapper = _decorate_with_invariants(func=init_func, is_init=True) - setattr(cls, init_func.__name__, wrapper) + assert init_func is not None, "Every class in Python must have a constructor." + + # We must handle this special case which is used by named + # tuples and possibly other optimized data structures. + # In those cases, we have to wrap __new__ instead of __init__. + if init_func == object.__init__ and hasattr(cls, "__new__"): + new_func = getattr(cls, "__new__") + setattr(cls, "__new__", _decorate_new_with_invariants(new_func)) + else: + # NOTE (mristin): + # We have to create a new __init__ function so that the invariants of *this* class are checked. + # The problem arises due to two different cases related to inheritance which we can not distinguish in Python. + # Namely, we can not know whether we are dealing with invariants coming from ``super().__init__`` or + # an implicit call to ``__init__``. + # + # In both of these edge cases, the instance is of the child class, but the constructors are + # of the parent class. Checking the invariants attached to the class would break the second case, while checking + # the invariants attached to the instance (through ``self.__class__.__invariants__``) would break the first + # case. + # + # The following snippets depict the two cases. + # + # Case 1: ``super().__init__`` + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # def __init__(self) -> None: + # super().__init__() + # # ↖ After this call, only the invariants of A, but not B, have to be checked. + # # More code follows, and after this ``__init__``, invariants of B have to be checked. + # + # Case 2: Implicit ``__init__`` call + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # pass + # + # b = B() + # # ↖ After this call, the invariants of B have to be checked. + # # However, we only see the call to A.__init__, since there is no B.__init__. + # + # Therefore, to avert this problem, we have to create an ``__init__`` in the child class for the second + # case. This allows us to always check for invariants attached to the class in the case of constructors, so both + # cases can be successfully handled. + + if "__init__" not in cls.__dict__: + init_after_mro = ( + # NOTE (mristin): + # mypy gives us the following warning: + # Accessing "__init__" on an instance is unsound, since instance.__init__ could be from an incompatible + # subclass + # + # ... but this is exactly what we want here -- we want to look up the __init__ of the class at runtime. + cls.__init__ # type: ignore + ) # This is the constructor after MRO, pointing to one of the parent classes. + + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: + init_after_mro(self, *args, **kwargs) + + # NOTE (mristin): + # See the comment above corresponding to this mypy warning. + cls.__init__ = __init__ # type: ignore + init_func = __init__ + + wrapper = _decorate_with_invariants(func=init_func, cls=cls, is_init=True) + setattr(cls, init_func.__name__, wrapper) for name, func in names_funcs: - wrapper = _decorate_with_invariants(func=func, is_init=False) + wrapper = _decorate_with_invariants(func=func, cls=cls, is_init=False) setattr(cls, name, wrapper) for name, prop in names_properties: new_prop = property( fget=( - _decorate_with_invariants(func=prop.fget, is_init=False) + _decorate_with_invariants(func=prop.fget, cls=cls, is_init=False) if prop.fget else None ), fset=( - _decorate_with_invariants(func=prop.fset, is_init=False) + _decorate_with_invariants(func=prop.fset, cls=cls, is_init=False) if prop.fset else None ), fdel=( - _decorate_with_invariants(func=prop.fdel, is_init=False) + _decorate_with_invariants(func=prop.fdel, cls=cls, is_init=False) if prop.fdel else None ), diff --git a/tests/test_inheritance_invariant.py b/tests/test_inheritance_invariant.py index 4c9b48c..06ec522 100644 --- a/tests/test_inheritance_invariant.py +++ b/tests/test_inheritance_invariant.py @@ -162,6 +162,49 @@ class CTrue(B): # This produced an unexpected violation error. CTrue().do_something() + def test_calling_super_init_checks_invariants_of_parent_not_child(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: True) + class A(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: self.x > 0) + class B(A): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + + _ = B(2) + + def test_calling_multi_super_init_checks_invariants_of_parent_not_child( + self, + ) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: True) + class A1(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: True) + class A2(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: self.x > 0) + class B(A1, A2): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + + _ = B(2) + class TestViolation(unittest.TestCase): def test_inherited(self) -> None: @@ -242,12 +285,11 @@ def __init__(self) -> None: self.x = 10 def __repr__(self) -> str: - return "an instance of A" + return "an instance of {}".format(self.__class__.__name__) @icontract.invariant(lambda self: self.x > 100) class B(A): - def __repr__(self) -> str: - return "an instance of B" + pass violation_error = None # type: Optional[icontract.ViolationError] try: @@ -377,6 +419,355 @@ def __repr__(self) -> str: tests.error.wo_mandatory_location(str(violation_error)), ) + def test_violated_in_super_init(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + def __init__(self, x: int): # pylint: disable=useless-parent-delegation + super().__init__(x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(-1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after super().__init__ call, not after B.__init__. + textwrap.dedent( + """\ + self.x > 0: + self was an instance of B + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_violated_in_one_of_multi_super_init(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 2) + class A2(icontract.DBC): + def __init__(self, x: int): + self.x = x + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + def __init__(self, x: int): + A1.__init__(self, x) + A2.__init__(self, x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + # NOTE (mristin): + # This is expected to violate __init__ in A2 but not A1. + _ = B(1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # ``self.x > 2`` indiactes a violation in __init__ of A2 but not A1 and not B. + textwrap.dedent( + """\ + self.x > 2: + self was an instance of B + self.x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_super_init_passes_but_init_violated(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + def __init__(self, x: int): # pylint: disable=useless-parent-delegation + super().__init__(x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(2) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after B.__init__, but not earlier after super().__init__ call. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_multi_super_inits_pass_but_init_violated(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 2) + class A2(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + def __init__(self, x: int): + A1.__init__(self, x) + A2.__init__(self, x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after B.__init__, but not in any super __init__ calls. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_parent_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from A (``A.__init__``). However, after the call to the implicit constructor, the invariants of B, + # and not only A, should apply. + pass + + # NOTE (mristin): + # We check that the invariants of the parent are violated here. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(-1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after A.__init__ which is delegated to from B.__init__. + textwrap.dedent( + """\ + self.x > 0: + self was an instance of B + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_one_of_parents_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 5) + class A2(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + class B(A1, A2): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from one of the parents (``A1.__init__`` or ``A2.__init__``). However, after the call to the implicit + # constructor, the invariants of B, and not only of A1 and A2, should apply. + pass + + # NOTE (mristin): + # We check that the invariants of one of the parents are violated here. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after A2.__init__ which is delegated to from B.__init__. + textwrap.dedent( + """\ + self.x > 5: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_child_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from A (``A.__init__``). However, after the call to the implicit constructor, the invariants of B, + # and not only A, should apply. + pass + + # NOTE (mristin): + # We make sure explicitly that the invariants did not leak to the parent -- the invariants of B should not + # apply to an instance of A. + _ = A(3) + + # NOTE (mristin): + # We also double-check that the invariants pass correctly. + _ = B(11) + + # NOTE (mristin): + # Now we check that the invariants are violated which we exactly expect. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after B.__init__ even though B.__init__ is simply delegated to + # A.__init__. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_child_in_implicit_init_with_multiple_inheritance( + self, + ) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 5) + class A2(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from one of the super classes. However, after the call to the implicit constructor, the invariants of B, + # and not only of A1 and A2, should apply. + pass + + # NOTE (mristin): + # We make sure explicitly that the invariants did not leak to the parent -- the invariants of B should not + # apply to an instance of A1 or of A2. + _ = A1(3) + _ = A2(7) + + # NOTE (mristin): + # We also double-check that the invariant checks can pass correctly. + _ = B(11) + + # NOTE (mristin): + # Now we check that the invariants are violated which we exactly expect. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(8) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after B.__init__ even though B.__init__ is simply delegated to + # either A1.__init__ or A2.__init__. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 8""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + class TestProperty(unittest.TestCase): def test_inherited_getter(self) -> None: