diff --git a/before_after/__init__.py b/before_after/__init__.py index 1081684..5b25d73 100644 --- a/before_after/__init__.py +++ b/before_after/__init__.py @@ -14,6 +14,7 @@ from contextlib import contextmanager from functools import wraps +import wrapt def before(target, fn, **kwargs): @@ -27,28 +28,29 @@ def after(target, fn, **kwargs): @contextmanager def before_after( target, before_fn=None, after_fn=None, once=True, **kwargs): - def before_after_wrap(fn): - called = [] - - @wraps(fn) - def inner(*a, **k): - # If once is True, then don't call if this function has already - # been called - if once: - if called: - return fn(*a, **k) - else: - # Hack for lack of nonlocal keyword in Python 2: append to - # list to maked called truthy - called.append(True) - - if before_fn: - before_fn(*a, **k) + called = [] + @wrapt.decorator + def before_after_wrap(fn, instance, a, k): + + # If once is True, then don't call if this function has already + # been called + if once: + if called: + return fn(*a, **k) + else: + # Hack for lack of nonlocal keyword in Python 2: append to + # list to maked called truthy + called.append(True) + + if before_fn: + before_fn(*a, **k) + try: ret = fn(*a, **k) + finally: if after_fn: after_fn(*a, **k) - return ret - return inner + + return ret from mock import patch diff --git a/before_after/tests/test_before_after.py b/before_after/tests/test_before_after.py index 017df1e..4025be8 100644 --- a/before_after/tests/test_before_after.py +++ b/before_after/tests/test_before_after.py @@ -9,6 +9,7 @@ class TestBeforeAfter(TestCase): def setUp(self): test_functions.reset_test_list() + test_functions.CLASS_LIST = [] super(TestBeforeAfter, self).setUp() def test_before(self): @@ -92,3 +93,27 @@ def before_fn(self, *a): sample_instance.method(2) self.assertEqual(sample_instance.instance_list, [1, 2]) + + def test_before_classmethod(self): + def before_fn(*a): + test_functions.Sample.CLASS_LIST.append(1) + + with before('before_after.tests.test_functions.Sample.class_method', before_fn): + test_functions.Sample.class_method(2) + + self.assertEqual(test_functions.Sample.CLASS_LIST, [1, 2]) + + def test_after_called_exception(self): + sample_instance = test_functions.Sample() + + def after_fn(self, *a): + sample_instance.instance_list.append(2) + + with after('before_after.tests.test_functions.Sample.method_with_exception', after_fn): + try: + sample_instance.method_with_exception(1) + self.fail("Expected exception to be raised!") + except: + pass + + self.assertEqual(sample_instance.instance_list, [1, 2]) diff --git a/before_after/tests/test_functions.py b/before_after/tests/test_functions.py index b4698cd..e405ec9 100644 --- a/before_after/tests/test_functions.py +++ b/before_after/tests/test_functions.py @@ -4,14 +4,25 @@ def reset_test_list(): del test_list[:] def sample_fn(arg): - print 'sample', arg + print('sample', arg) test_list.append(arg) class Sample(object): + CLASS_LIST = [] + def __init__(self): self.instance_list = [] def method(self, arg): - print 'Sample.method', arg + print('Sample.method', arg) self.instance_list.append(arg) + + def method_with_exception(self, arg): + self.method(arg) + raise Exception + + @classmethod + def class_method(cls, arg): + print('Sample.class_method', arg) + cls.CLASS_LIST.append(arg) diff --git a/setup.py b/setup.py index 69e3ad9..00c80b0 100644 --- a/setup.py +++ b/setup.py @@ -34,5 +34,5 @@ ], keywords = ['testing', 'race conditions'], - install_requires=['mock==1.0.1'], + install_requires=['mock==1.0.1', 'wrapt'], )