From 95e6f662e293c1320942b93ac8b6edae8476e06e Mon Sep 17 00:00:00 2001 From: Matt Hoffman Date: Mon, 8 Jun 2020 14:30:42 -0400 Subject: [PATCH 1/2] uses wrapt to support classmethods --- before_after/__init__.py | 43 ++++++++++++------------- before_after/tests/test_before_after.py | 10 ++++++ before_after/tests/test_functions.py | 11 +++++-- setup.py | 2 +- 4 files changed, 41 insertions(+), 25 deletions(-) diff --git a/before_after/__init__.py b/before_after/__init__.py index 1081684..ecdc31d 100644 --- a/before_after/__init__.py +++ b/before_after/__init__.py @@ -14,6 +14,7 @@ from contextlib import contextmanager from functools import wraps +import wrapt def before(target, fn, **kwargs): @@ -27,28 +28,26 @@ def after(target, fn, **kwargs): @contextmanager def before_after( target, before_fn=None, after_fn=None, once=True, **kwargs): - def before_after_wrap(fn): - called = [] - - @wraps(fn) - def inner(*a, **k): - # If once is True, then don't call if this function has already - # been called - if once: - if called: - return fn(*a, **k) - else: - # Hack for lack of nonlocal keyword in Python 2: append to - # list to maked called truthy - called.append(True) - - if before_fn: - before_fn(*a, **k) - ret = fn(*a, **k) - if after_fn: - after_fn(*a, **k) - return ret - return inner + called = [] + @wrapt.decorator + def before_after_wrap(fn, instance, a, k): + + # If once is True, then don't call if this function has already + # been called + if once: + if called: + return fn(*a, **k) + else: + # Hack for lack of nonlocal keyword in Python 2: append to + # list to maked called truthy + called.append(True) + + if before_fn: + before_fn(*a, **k) + ret = fn(*a, **k) + if after_fn: + after_fn(*a, **k) + return ret from mock import patch diff --git a/before_after/tests/test_before_after.py b/before_after/tests/test_before_after.py index 017df1e..cd65562 100644 --- a/before_after/tests/test_before_after.py +++ b/before_after/tests/test_before_after.py @@ -9,6 +9,7 @@ class TestBeforeAfter(TestCase): def setUp(self): test_functions.reset_test_list() + test_functions.CLASS_LIST = [] super(TestBeforeAfter, self).setUp() def test_before(self): @@ -92,3 +93,12 @@ def before_fn(self, *a): sample_instance.method(2) self.assertEqual(sample_instance.instance_list, [1, 2]) + + def test_before_classmethod(self): + def before_fn(*a): + test_functions.Sample.CLASS_LIST.append(1) + + with before('before_after.tests.test_functions.Sample.class_method', before_fn): + test_functions.Sample.class_method(2) + + self.assertEqual(test_functions.Sample.CLASS_LIST, [1, 2]) diff --git a/before_after/tests/test_functions.py b/before_after/tests/test_functions.py index b4698cd..c329fcc 100644 --- a/before_after/tests/test_functions.py +++ b/before_after/tests/test_functions.py @@ -4,14 +4,21 @@ def reset_test_list(): del test_list[:] def sample_fn(arg): - print 'sample', arg + print('sample', arg) test_list.append(arg) class Sample(object): + CLASS_LIST = [] + def __init__(self): self.instance_list = [] def method(self, arg): - print 'Sample.method', arg + print('Sample.method', arg) self.instance_list.append(arg) + + @classmethod + def class_method(cls, arg): + print('Sample.class_method', arg) + cls.CLASS_LIST.append(arg) diff --git a/setup.py b/setup.py index 69e3ad9..00c80b0 100644 --- a/setup.py +++ b/setup.py @@ -34,5 +34,5 @@ ], keywords = ['testing', 'race conditions'], - install_requires=['mock==1.0.1'], + install_requires=['mock==1.0.1', 'wrapt'], ) From cbf3fff6b13ff8e3fb807d685517f1a07dbcc47e Mon Sep 17 00:00:00 2001 From: Matt Hoffman Date: Mon, 8 Jun 2020 14:41:36 -0400 Subject: [PATCH 2/2] be sure to call after function even if an exception is raised --- before_after/__init__.py | 9 ++++++--- before_after/tests/test_before_after.py | 15 +++++++++++++++ before_after/tests/test_functions.py | 4 ++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/before_after/__init__.py b/before_after/__init__.py index ecdc31d..5b25d73 100644 --- a/before_after/__init__.py +++ b/before_after/__init__.py @@ -44,9 +44,12 @@ def before_after_wrap(fn, instance, a, k): if before_fn: before_fn(*a, **k) - ret = fn(*a, **k) - if after_fn: - after_fn(*a, **k) + try: + ret = fn(*a, **k) + finally: + if after_fn: + after_fn(*a, **k) + return ret from mock import patch diff --git a/before_after/tests/test_before_after.py b/before_after/tests/test_before_after.py index cd65562..4025be8 100644 --- a/before_after/tests/test_before_after.py +++ b/before_after/tests/test_before_after.py @@ -102,3 +102,18 @@ def before_fn(*a): test_functions.Sample.class_method(2) self.assertEqual(test_functions.Sample.CLASS_LIST, [1, 2]) + + def test_after_called_exception(self): + sample_instance = test_functions.Sample() + + def after_fn(self, *a): + sample_instance.instance_list.append(2) + + with after('before_after.tests.test_functions.Sample.method_with_exception', after_fn): + try: + sample_instance.method_with_exception(1) + self.fail("Expected exception to be raised!") + except: + pass + + self.assertEqual(sample_instance.instance_list, [1, 2]) diff --git a/before_after/tests/test_functions.py b/before_after/tests/test_functions.py index c329fcc..e405ec9 100644 --- a/before_after/tests/test_functions.py +++ b/before_after/tests/test_functions.py @@ -18,6 +18,10 @@ def method(self, arg): print('Sample.method', arg) self.instance_list.append(arg) + def method_with_exception(self, arg): + self.method(arg) + raise Exception + @classmethod def class_method(cls, arg): print('Sample.class_method', arg)