diff --git a/icontract/_checkers.py b/icontract/_checkers.py index 0503317..24fa3d4 100644 --- a/icontract/_checkers.py +++ b/icontract/_checkers.py @@ -979,13 +979,14 @@ def wrapper(*args, **kwargs): # type: ignore return wrapper # type: ignore -def _decorate_with_invariants(func: CallableT, is_init: bool) -> CallableT: +def _decorate_with_invariants(func: CallableT, cls: ClassT, is_init: bool) -> CallableT: """ Decorate the method ``func`` with invariant checks. If the function has been already decorated with invariant checks, the function returns immediately. :param func: function to be wrapped + :param cls: class corresponding to the invariant and ``func`` :param is_init: True if the ``func`` is __init__ :return: function wrapped with invariant checks """ @@ -1027,7 +1028,25 @@ def wrapper(*args, **kwargs): # type: ignore try: result = func(*args, **kwargs) - for invariant in instance.__class__.__invariants__: + # NOTE (mristin): + # We go to the invariants corresponding to the class, not the instance, as we need to + # account also for a situation where super().__init__ is called. Here is an example: + # + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # def __init__(self) -> None: + # super().__init__() + # # ↖ After this call, only the invariants of A, but not B, have to be checked. + # # + # # However, the ``instance`` (i.e., resolved ``self``) in super().__init__ call points to + # # an instance of B, so instance.__cls__.__invariants__ refer to invariants of B, not A. + + # noinspection PyUnresolvedReferences + for invariant in cls.__invariants__: # type: ignore _assert_invariant(contract=invariant, instance=instance) return result @@ -1273,35 +1292,97 @@ def add_invariant_checks(cls: ClassT) -> None: ) ) - if init_func: - # We have to distinguish this special case which is used by named - # tuples and possibly other optimized data structures. - # In those cases, we have to wrap __new__ instead of __init__. - if init_func == object.__init__ and hasattr(cls, "__new__"): - new_func = getattr(cls, "__new__") - setattr(cls, "__new__", _decorate_new_with_invariants(new_func)) - else: - wrapper = _decorate_with_invariants(func=init_func, is_init=True) - setattr(cls, init_func.__name__, wrapper) + assert init_func is not None, "Every class in Python must have a constructor." + + # We must handle this special case which is used by named + # tuples and possibly other optimized data structures. + # In those cases, we have to wrap __new__ instead of __init__. + if init_func == object.__init__ and hasattr(cls, "__new__"): + new_func = getattr(cls, "__new__") + setattr(cls, "__new__", _decorate_new_with_invariants(new_func)) + else: + # NOTE (mristin): + # We have to create a new __init__ function so that the invariants of *this* class are checked. + # The problem arises due to two different cases related to inheritance which we can not distinguish in Python. + # Namely, we can not know whether we are dealing with invariants coming from ``super().__init__`` or + # an implicit call to ``__init__``. + # + # In both of these edge cases, the instance is of the child class, but the constructors are + # of the parent class. Checking the invariants attached to the class would break the second case, while checking + # the invariants attached to the instance (through ``self.__class__.__invariants__``) would break the first + # case. + # + # The following snippets depict the two cases. + # + # Case 1: ``super().__init__`` + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # def __init__(self) -> None: + # super().__init__() + # # ↖ After this call, only the invariants of A, but not B, have to be checked. + # # More code follows, and after this ``__init__``, invariants of B have to be checked. + # + # Case 2: Implicit ``__init__`` call + # @invariant(lambda self: ...) + # class A(DBC): + # pass + # + # @invariant(lambda self: ...) + # class B(A): + # pass + # + # b = B() + # # ↖ After this call, the invariants of B have to be checked. + # # However, we only see the call to A.__init__, since there is no B.__init__. + # + # Therefore, to avert this problem, we have to create an ``__init__`` in the child class for the second + # case. This allows us to always check for invariants attached to the class in the case of constructors, so both + # cases can be successfully handled. + + if "__init__" not in cls.__dict__: + init_after_mro = ( + # NOTE (mristin): + # mypy gives us the following warning: + # Accessing "__init__" on an instance is unsound, since instance.__init__ could be from an incompatible + # subclass + # + # ... but this is exactly what we want here -- we want to look up the __init__ of the class at runtime. + cls.__init__ # type: ignore + ) # This is the constructor after MRO, pointing to one of the parent classes. + + def __init__(self: Any, *args: Any, **kwargs: Any) -> None: + init_after_mro(self, *args, **kwargs) + + # NOTE (mristin): + # See the comment above corresponding to this mypy warning. + cls.__init__ = __init__ # type: ignore + init_func = __init__ + + wrapper = _decorate_with_invariants(func=init_func, cls=cls, is_init=True) + setattr(cls, init_func.__name__, wrapper) for name, func in names_funcs: - wrapper = _decorate_with_invariants(func=func, is_init=False) + wrapper = _decorate_with_invariants(func=func, cls=cls, is_init=False) setattr(cls, name, wrapper) for name, prop in names_properties: new_prop = property( fget=( - _decorate_with_invariants(func=prop.fget, is_init=False) + _decorate_with_invariants(func=prop.fget, cls=cls, is_init=False) if prop.fget else None ), fset=( - _decorate_with_invariants(func=prop.fset, is_init=False) + _decorate_with_invariants(func=prop.fset, cls=cls, is_init=False) if prop.fset else None ), fdel=( - _decorate_with_invariants(func=prop.fdel, is_init=False) + _decorate_with_invariants(func=prop.fdel, cls=cls, is_init=False) if prop.fdel else None ), diff --git a/tests/test_inheritance_invariant.py b/tests/test_inheritance_invariant.py index 4c9b48c..06ec522 100644 --- a/tests/test_inheritance_invariant.py +++ b/tests/test_inheritance_invariant.py @@ -162,6 +162,49 @@ class CTrue(B): # This produced an unexpected violation error. CTrue().do_something() + def test_calling_super_init_checks_invariants_of_parent_not_child(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: True) + class A(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: self.x > 0) + class B(A): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + + _ = B(2) + + def test_calling_multi_super_init_checks_invariants_of_parent_not_child( + self, + ) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: True) + class A1(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: True) + class A2(icontract.DBC): + def __init__(self) -> None: + pass + + @icontract.invariant(lambda self: self.x > 0) + class B(A1, A2): + def __init__(self, x: int) -> None: + super().__init__() + self.x = x + + _ = B(2) + class TestViolation(unittest.TestCase): def test_inherited(self) -> None: @@ -242,12 +285,11 @@ def __init__(self) -> None: self.x = 10 def __repr__(self) -> str: - return "an instance of A" + return "an instance of {}".format(self.__class__.__name__) @icontract.invariant(lambda self: self.x > 100) class B(A): - def __repr__(self) -> str: - return "an instance of B" + pass violation_error = None # type: Optional[icontract.ViolationError] try: @@ -377,6 +419,355 @@ def __repr__(self) -> str: tests.error.wo_mandatory_location(str(violation_error)), ) + def test_violated_in_super_init(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + def __init__(self, x: int): # pylint: disable=useless-parent-delegation + super().__init__(x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(-1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after super().__init__ call, not after B.__init__. + textwrap.dedent( + """\ + self.x > 0: + self was an instance of B + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_violated_in_one_of_multi_super_init(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 2) + class A2(icontract.DBC): + def __init__(self, x: int): + self.x = x + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + def __init__(self, x: int): + A1.__init__(self, x) + A2.__init__(self, x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + # NOTE (mristin): + # This is expected to violate __init__ in A2 but not A1. + _ = B(1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # ``self.x > 2`` indiactes a violation in __init__ of A2 but not A1 and not B. + textwrap.dedent( + """\ + self.x > 2: + self was an instance of B + self.x was 1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_super_init_passes_but_init_violated(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + def __init__(self, x: int): # pylint: disable=useless-parent-delegation + super().__init__(x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(2) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after B.__init__, but not earlier after super().__init__ call. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 2""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_multi_super_inits_pass_but_init_violated(self) -> None: + # NOTE (mristin): + # This test is a regression test from: + # https://github.com/Parquery/icontract/issues/300 + + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 2) + class A2(icontract.DBC): + def __init__(self, x: int): + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + def __init__(self, x: int): + A1.__init__(self, x) + A2.__init__(self, x) + + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is an exception expected to be raised after B.__init__, but not in any super __init__ calls. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_parent_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from A (``A.__init__``). However, after the call to the implicit constructor, the invariants of B, + # and not only A, should apply. + pass + + # NOTE (mristin): + # We check that the invariants of the parent are violated here. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(-1) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after A.__init__ which is delegated to from B.__init__. + textwrap.dedent( + """\ + self.x > 0: + self was an instance of B + self.x was -1""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_one_of_parents_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 5) + class A2(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + class B(A1, A2): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from one of the parents (``A1.__init__`` or ``A2.__init__``). However, after the call to the implicit + # constructor, the invariants of B, and not only of A1 and A2, should apply. + pass + + # NOTE (mristin): + # We check that the invariants of one of the parents are violated here. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after A2.__init__ which is delegated to from B.__init__. + textwrap.dedent( + """\ + self.x > 5: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_child_in_implicit_init(self) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 10) + class B(A): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from A (``A.__init__``). However, after the call to the implicit constructor, the invariants of B, + # and not only A, should apply. + pass + + # NOTE (mristin): + # We make sure explicitly that the invariants did not leak to the parent -- the invariants of B should not + # apply to an instance of A. + _ = A(3) + + # NOTE (mristin): + # We also double-check that the invariants pass correctly. + _ = B(11) + + # NOTE (mristin): + # Now we check that the invariants are violated which we exactly expect. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(3) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after B.__init__ even though B.__init__ is simply delegated to + # A.__init__. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 3""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + + def test_invariant_violated_in_child_in_implicit_init_with_multiple_inheritance( + self, + ) -> None: + @icontract.invariant(lambda self: self.x > 0) + class A1(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + def __repr__(self) -> str: + return "an instance of {}".format(self.__class__.__name__) + + @icontract.invariant(lambda self: self.x > 5) + class A2(icontract.DBC): + def __init__(self, x: int) -> None: + self.x = x + + @icontract.invariant(lambda self: self.x > 10) + class B(A1, A2): + # NOTE (mristin): + # We explicitly define no constructor (``__init__``) here. The constructor is implicitly called + # from one of the super classes. However, after the call to the implicit constructor, the invariants of B, + # and not only of A1 and A2, should apply. + pass + + # NOTE (mristin): + # We make sure explicitly that the invariants did not leak to the parent -- the invariants of B should not + # apply to an instance of A1 or of A2. + _ = A1(3) + _ = A2(7) + + # NOTE (mristin): + # We also double-check that the invariant checks can pass correctly. + _ = B(11) + + # NOTE (mristin): + # Now we check that the invariants are violated which we exactly expect. + violation_error = None # type: Optional[icontract.ViolationError] + try: + _ = B(8) + except icontract.ViolationError as err: + violation_error = err + + self.assertIsNotNone(violation_error) + self.assertEqual( + # NOTE (mristin): + # This is a violation expected to be raised after B.__init__ even though B.__init__ is simply delegated to + # either A1.__init__ or A2.__init__. + textwrap.dedent( + """\ + self.x > 10: + self was an instance of B + self.x was 8""" + ), + tests.error.wo_mandatory_location(str(violation_error)), + ) + class TestProperty(unittest.TestCase): def test_inherited_getter(self) -> None: