Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions pybond/checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import warnings

from pybond.types import SpyableClass, SpyableFunction
from pybond.util import (
function_signatures_match,
is_wrapped_function,
list_class_attributes,
list_class_methods,
)


def _function_signatures_match(
originalf: SpyableFunction,
stubf: SpyableFunction,
) -> bool:
"""
Supports both regular functions and decorated functions using
functools.wraps
"""
return (
(
is_wrapped_function(originalf)
and function_signatures_match(originalf.__wrapped__, stubf)
) or (
not is_wrapped_function(originalf)
and function_signatures_match(originalf, stubf)
)
)


def check_if_function_is_instrumentable(
original_obj: SpyableFunction,
stub_obj: SpyableFunction,
strict: bool = True,
) -> None:
if strict and not _function_signatures_match(original_obj, stub_obj):
raise ValueError(
f"Stub does not match the signature of {original_obj.__name__}."
)


def _check_if_class_methods_are_instrumentable(
method_names: list[str],
original_obj: SpyableClass,
stub_obj: SpyableClass,
) -> None:
unsupported_callables = []
for method_name in method_names:
try:
if not _function_signatures_match(
getattr(original_obj, method_name),
getattr(stub_obj, method_name),
):
raise ValueError(
f"Stub method {stub_obj.__name__}.{method_name} does not "
"match the signature of the original "
f"{original_obj.__name__}.{method_name} class method. "
"Please ensure the implementation of the provided stub "
"matches that of the original class, or set the 'strict' "
"option to False."
)
except TypeError as e:
if str(e) == "unsupported callable":
unsupported_callables.append(method_name)

if len(unsupported_callables) > 0:
PYBOND_WARNING__unsupported_callables = (
"The following methods' signatures cannot be checked: "
f"{unsupported_callables}."
)
warnings.warn(PYBOND_WARNING__unsupported_callables)


def check_if_class_is_instrumentable(
original_obj: SpyableClass,
stub_obj: SpyableClass,
strict: bool = True,
) -> None:
original_obj_attributes = list_class_attributes(original_obj)
original_obj_methods = list_class_methods(original_obj)
stub_obj_attributes = list_class_attributes(stub_obj)
stub_obj_methods = list_class_methods(stub_obj)
if strict:
if set(original_obj_attributes) != set(stub_obj_attributes):
raise ValueError(
f"Stub object '{stub_obj.__name__}' does not have the same set "
f"of attributes as the original '{original_obj.__name__}' "
"class. Please ensure the implementation of the provided stub "
"matches that of the original class, or set the 'strict' "
"option to False.\n"
f"Original: {original_obj_attributes}\n"
f"Provided: {stub_obj_attributes}"
)
if set(original_obj_methods) != set(stub_obj_methods):
raise ValueError(
f"Stub object '{stub_obj.__name__}' does not have the same set "
f"of methods as the original '{original_obj.__name__}' class. "
"Please ensure the implementation of the provided stub matches "
"that of the original class, or set the 'strict' option to "
"False.\n"
f"Original: {original_obj_methods}\n"
f"Provided: {stub_obj_methods}"
)
_check_if_class_methods_are_instrumentable(
method_names=original_obj_methods,
original_obj=original_obj,
stub_obj=stub_obj,
)
75 changes: 30 additions & 45 deletions pybond/james.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
"""This module is inspired by clojure's bond library."""

import sys
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from inspect import isclass
from typing import Callable
from sys import exc_info

from pytest import MonkeyPatch

from pybond.checks import (
check_if_class_is_instrumentable,
check_if_function_is_instrumentable,
)
from pybond.memory import replace_bound_references_in_memory
from pybond.util import function_signatures_match, is_wrapped_function
from pybond.types import FunctionCall, Spyable, SpyTarget, StubTarget
from pybond.types import (
FunctionCall,
Spyable,
SpyableClass,
SpyableFunction,
SpyTarget,
StubTarget,
)
from pybond.util import list_class_methods


def _function_call(args, kwargs, error, return_value) -> FunctionCall:
Expand All @@ -23,7 +33,7 @@ def _function_call(args, kwargs, error, return_value) -> FunctionCall:
}


def _spy_function(f: Callable) -> Spyable:
def _spy_function(f: SpyableFunction) -> Spyable:
"""
Wrap f, returning a new function that keeps track of its call count and
arguments.
Expand Down Expand Up @@ -54,7 +64,7 @@ def handle_function_call(*args, **kwargs):
_function_call(
args=non_mutated_args,
kwargs=non_mutated_kwargs,
error=sys.exc_info(),
error=exc_info(),
return_value=None,
)
)
Expand Down Expand Up @@ -82,54 +92,29 @@ def calls(f: Spyable) -> list[FunctionCall]:
)


def _function_signatures_match(originalf: Callable, stubf: Callable) -> bool:
"""
Supports both regular functions and decorated functions using
functools.wraps
"""
return (
(
is_wrapped_function(originalf)
and function_signatures_match(originalf.__wrapped__, stubf)
) or (
not is_wrapped_function(originalf)
and function_signatures_match(originalf, stubf)
)
)


def _check_if_class_is_instrumentable(
original_obj: Spyable,
stub_obj: Spyable,
strict: bool = True,
) -> None:
# TODO: implement spying on classes and class methods
return


def _check_if_function_is_instrumentable(
original_obj: Callable,
stub_obj: Callable,
strict: bool = True,
) -> None:
if strict and not _function_signatures_match(original_obj, stub_obj):
raise ValueError(
f"Stub does not match the signature of {original_obj.__name__}."
)
def _spy_all_methods(obj: SpyableClass) -> SpyableClass:
obj_methods = list_class_methods(obj)
for method_name in obj_methods:
setattr(obj, method_name, _spy_function(getattr(obj, method_name)))
return obj


def _instrumented_obj(
original_obj: Spyable,
stub_obj: Spyable,
strict: bool = True,
) -> Spyable:
if isclass(original_obj):
# TODO: implement spying on classes and class methods
_check_if_class_is_instrumentable(original_obj, stub_obj, strict)
return stub_obj
if isclass(original_obj) and isclass(stub_obj):
check_if_class_is_instrumentable(original_obj, stub_obj, strict)
return _spy_all_methods(stub_obj)
elif callable(original_obj) and callable(stub_obj):
_check_if_function_is_instrumentable(original_obj, stub_obj, strict)
check_if_function_is_instrumentable(original_obj, stub_obj, strict)
return _spy_function(stub_obj)
elif isclass(original_obj) and not isclass(stub_obj):
raise ValueError(
f"Provided stub for class {original_obj.__name__} of type "
f"{type(stub_obj)} is invalid: pybond expected a class instance."
)
elif callable(original_obj) and not callable(stub_obj):
raise ValueError(
f"Provided stub for Callable {original_obj.__name__} of type "
Expand Down
4 changes: 3 additions & 1 deletion pybond/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
}
)

Spyable: TypeAlias = Callable | object
SpyableClass: TypeAlias = Any
SpyableFunction: TypeAlias = Callable
Spyable: TypeAlias = SpyableFunction | SpyableClass
SpyTarget: TypeAlias = Tuple[ModuleType, str]
StubTarget: TypeAlias = Tuple[ModuleType, str, Spyable]
20 changes: 19 additions & 1 deletion pybond/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ def function_signatures_match(f, g):
# Python. For example, in CPython, some built-in functions defined in C
# provide no metadata about their arguments.
if str(e) == "unsupported callable":
if [f.__module__, f.__name__] == ["time", "time"]:
fmodule = getattr(f, "__module__", None)
fname = getattr(f, "__name__", None)
if [fmodule, fname] == ["time", "time"]:
return function_signatures_match(_fn_with_zero_arguments, g)
# Add other specific cases here
else:
Expand All @@ -65,3 +67,19 @@ def function_signatures_match(f, g):

def is_wrapped_function(f: Callable) -> bool:
return hasattr(f, "__wrapped__")


def list_class_attributes(obj: object) -> list[str]:
return [
attr for attr in dir(obj)
if not callable(getattr(obj, attr))
and not (attr.startswith("__") and attr.endswith("__"))
]


def list_class_methods(obj: object) -> list[str]:
return [
attr for attr in dir(obj)
if callable(getattr(obj, attr))
and not (attr.startswith("__") and attr.endswith("__"))
]
9 changes: 6 additions & 3 deletions tests/sample_code/mocks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from datetime import datetime


def mock_write_to_disk(x):
return "Wrote to disk!"

Expand All @@ -14,9 +17,9 @@ def mock_make_a_network_request(


def create_mock_datetime(mock_now):
class MockDatetime():
@staticmethod
def now(tz=None):
class MockDatetime(datetime):
@classmethod
def now(cls, tz=None):
return mock_now

return MockDatetime