diff --git a/pybond/checks.py b/pybond/checks.py new file mode 100644 index 0000000..96edca1 --- /dev/null +++ b/pybond/checks.py @@ -0,0 +1,108 @@ +import warnings + +from pybond.types import SpyableClass, SpyableFunction +from pybond.util import ( + function_signatures_match, + is_wrapped_function, + list_class_attributes, + list_class_methods, +) + + +def _function_signatures_match( + originalf: SpyableFunction, + stubf: SpyableFunction, +) -> bool: + """ + Supports both regular functions and decorated functions using + functools.wraps + """ + return ( + ( + is_wrapped_function(originalf) + and function_signatures_match(originalf.__wrapped__, stubf) + ) or ( + not is_wrapped_function(originalf) + and function_signatures_match(originalf, stubf) + ) + ) + + +def check_if_function_is_instrumentable( + original_obj: SpyableFunction, + stub_obj: SpyableFunction, + strict: bool = True, +) -> None: + if strict and not _function_signatures_match(original_obj, stub_obj): + raise ValueError( + f"Stub does not match the signature of {original_obj.__name__}." + ) + + +def _check_if_class_methods_are_instrumentable( + method_names: list[str], + original_obj: SpyableClass, + stub_obj: SpyableClass, +) -> None: + unsupported_callables = [] + for method_name in method_names: + try: + if not _function_signatures_match( + getattr(original_obj, method_name), + getattr(stub_obj, method_name), + ): + raise ValueError( + f"Stub method {stub_obj.__name__}.{method_name} does not " + "match the signature of the original " + f"{original_obj.__name__}.{method_name} class method. " + "Please ensure the implementation of the provided stub " + "matches that of the original class, or set the 'strict' " + "option to False." + ) + except TypeError as e: + if str(e) == "unsupported callable": + unsupported_callables.append(method_name) + + if len(unsupported_callables) > 0: + PYBOND_WARNING__unsupported_callables = ( + "The following methods' signatures cannot be checked: " + f"{unsupported_callables}." + ) + warnings.warn(PYBOND_WARNING__unsupported_callables) + + +def check_if_class_is_instrumentable( + original_obj: SpyableClass, + stub_obj: SpyableClass, + strict: bool = True, +) -> None: + original_obj_attributes = list_class_attributes(original_obj) + original_obj_methods = list_class_methods(original_obj) + stub_obj_attributes = list_class_attributes(stub_obj) + stub_obj_methods = list_class_methods(stub_obj) + if strict: + if set(original_obj_attributes) != set(stub_obj_attributes): + raise ValueError( + f"Stub object '{stub_obj.__name__}' does not have the same set " + f"of attributes as the original '{original_obj.__name__}' " + "class. Please ensure the implementation of the provided stub " + "matches that of the original class, or set the 'strict' " + "option to False.\n" + f"Original: {original_obj_attributes}\n" + f"Provided: {stub_obj_attributes}" + ) + if set(original_obj_methods) != set(stub_obj_methods): + raise ValueError( + f"Stub object '{stub_obj.__name__}' does not have the same set " + f"of methods as the original '{original_obj.__name__}' class. " + "Please ensure the implementation of the provided stub matches " + "that of the original class, or set the 'strict' option to " + "False.\n" + f"Original: {original_obj_methods}\n" + f"Provided: {stub_obj_methods}" + ) + _check_if_class_methods_are_instrumentable( + method_names=original_obj_methods, + original_obj=original_obj, + stub_obj=stub_obj, + ) diff --git a/pybond/james.py b/pybond/james.py index ba31aed..bf3c4b7 100644 --- a/pybond/james.py +++ b/pybond/james.py @@ -1,17 +1,27 @@ """This module is inspired by clojure's bond library.""" -import sys from contextlib import contextmanager from copy import deepcopy from functools import wraps from inspect import isclass -from typing import Callable +from sys import exc_info from pytest import MonkeyPatch +from pybond.checks import ( + check_if_class_is_instrumentable, + check_if_function_is_instrumentable, +) from pybond.memory import replace_bound_references_in_memory -from pybond.util import function_signatures_match, is_wrapped_function -from pybond.types import FunctionCall, Spyable, SpyTarget, StubTarget +from pybond.types import ( + FunctionCall, + Spyable, + SpyableClass, + SpyableFunction, + SpyTarget, + StubTarget, +) +from pybond.util import list_class_methods def _function_call(args, kwargs, error, return_value) -> FunctionCall: @@ -23,7 +33,7 @@ def _function_call(args, kwargs, error, return_value) -> FunctionCall: } -def _spy_function(f: Callable) -> Spyable: +def _spy_function(f: SpyableFunction) -> Spyable: """ Wrap f, returning a new function that keeps track of its call count and arguments. @@ -54,7 +64,7 @@ def handle_function_call(*args, **kwargs): _function_call( args=non_mutated_args, kwargs=non_mutated_kwargs, - error=sys.exc_info(), + error=exc_info(), return_value=None, ) ) @@ -82,40 +92,11 @@ def calls(f: Spyable) -> list[FunctionCall]: ) -def _function_signatures_match(originalf: Callable, stubf: Callable) -> bool: - """ - Supports both regular functions and decorated functions using - functools.wraps - """ - return ( - ( - is_wrapped_function(originalf) - and function_signatures_match(originalf.__wrapped__, stubf) - ) or ( - not is_wrapped_function(originalf) - and function_signatures_match(originalf, stubf) - ) - ) - - -def _check_if_class_is_instrumentable( - original_obj: Spyable, - stub_obj: Spyable, - strict: bool = True, -) -> None: - # TODO: implement spying on classes and class methods - return - - -def _check_if_function_is_instrumentable( - original_obj: Callable, - stub_obj: Callable, - strict: bool = True, -) -> None: - if strict and not _function_signatures_match(original_obj, stub_obj): - raise ValueError( - f"Stub does not match the signature of {original_obj.__name__}." - ) +def _spy_all_methods(obj: SpyableClass) -> SpyableClass: + obj_methods = list_class_methods(obj) + for method_name in obj_methods: + setattr(obj, method_name, _spy_function(getattr(obj, method_name))) + return obj def _instrumented_obj( @@ -123,13 +104,17 @@ def _instrumented_obj( stub_obj: Spyable, strict: bool = True, ) -> Spyable: - if isclass(original_obj): - # TODO: implement spying on classes and class methods - _check_if_class_is_instrumentable(original_obj, stub_obj, strict) - return stub_obj + if isclass(original_obj) and isclass(stub_obj): + check_if_class_is_instrumentable(original_obj, stub_obj, strict) + return _spy_all_methods(stub_obj) elif callable(original_obj) and callable(stub_obj): - _check_if_function_is_instrumentable(original_obj, stub_obj, strict) + check_if_function_is_instrumentable(original_obj, stub_obj, strict) return _spy_function(stub_obj) + elif isclass(original_obj) and not isclass(stub_obj): + raise ValueError( + f"Provided stub for class {original_obj.__name__} of type " + f"{type(stub_obj)} is invalid: pybond expected a class instance." + ) elif callable(original_obj) and not callable(stub_obj): raise ValueError( f"Provided stub for Callable {original_obj.__name__} of type " diff --git a/pybond/types.py b/pybond/types.py index 7553ebf..abb49b0 100644 --- a/pybond/types.py +++ b/pybond/types.py @@ -11,6 +11,8 @@ } ) -Spyable: TypeAlias = Callable | object +SpyableClass: TypeAlias = Any +SpyableFunction: TypeAlias = Callable +Spyable: TypeAlias = SpyableFunction | SpyableClass SpyTarget: TypeAlias = Tuple[ModuleType, str] StubTarget: TypeAlias = Tuple[ModuleType, str, Spyable] diff --git a/pybond/util.py b/pybond/util.py index 3b223d1..ef573e3 100644 --- a/pybond/util.py +++ b/pybond/util.py @@ -54,7 +54,9 @@ def function_signatures_match(f, g): # Python. For example, in CPython, some built-in functions defined in C # provide no metadata about their arguments. if str(e) == "unsupported callable": - if [f.__module__, f.__name__] == ["time", "time"]: + fmodule = getattr(f, "__module__", None) + fname = getattr(f, "__name__", None) + if [fmodule, fname] == ["time", "time"]: return function_signatures_match(_fn_with_zero_arguments, g) # Add other specific cases here else: @@ -65,3 +67,19 @@ def function_signatures_match(f, g): def is_wrapped_function(f: Callable) -> bool: return hasattr(f, "__wrapped__") + + +def list_class_attributes(obj: object) -> list[str]: + return [ + attr for attr in dir(obj) + if not callable(getattr(obj, attr)) + and not (attr.startswith("__") and attr.endswith("__")) + ] + + +def list_class_methods(obj: object) -> list[str]: + return [ + attr for attr in dir(obj) + if callable(getattr(obj, attr)) + and not (attr.startswith("__") and attr.endswith("__")) + ] diff --git a/tests/sample_code/mocks.py b/tests/sample_code/mocks.py index abdca46..61ca802 100644 --- a/tests/sample_code/mocks.py +++ b/tests/sample_code/mocks.py @@ -1,3 +1,6 @@ +from datetime import datetime + + def mock_write_to_disk(x): return "Wrote to disk!" @@ -14,9 +17,9 @@ def mock_make_a_network_request( def create_mock_datetime(mock_now): - class MockDatetime(): - @staticmethod - def now(tz=None): + class MockDatetime(datetime): + @classmethod + def now(cls, tz=None): return mock_now return MockDatetime