diff --git a/AUTHORS.rst b/AUTHORS.rst index 42a456c..38c0b7c 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -32,3 +32,4 @@ Patches and Suggestions - Jonathan Herriott - Job Evers - Cyrus Durgin +- Boden Russell \ No newline at end of file diff --git a/README.rst b/README.rst index ab5d6bd..bdfe4d1 100644 --- a/README.rst +++ b/README.rst @@ -144,6 +144,7 @@ We can also use the result of the function to alter the behavior of retrying. Any combination of stop, wait, etc. is also supported to give you the freedom to mix and match. + Contribute ---------- diff --git a/retrying.py b/retrying.py index bcb7a9d..3302aed 100644 --- a/retrying.py +++ b/retrying.py @@ -12,6 +12,7 @@ ## See the License for the specific language governing permissions and ## limitations under the License. +import inspect import random import six import sys @@ -111,29 +112,29 @@ def __init__(self, else: self.stop = getattr(self, stop) - # TODO add chaining of wait behaviors - # wait behavior - wait_funcs = [lambda *args, **kwargs: 0] + wait_funcs = CallChain(lambda *args, **kwargs: 0) if wait_fixed is not None: - wait_funcs.append(self.fixed_sleep) + wait_funcs += self.fixed_sleep if wait_random_min is not None or wait_random_max is not None: - wait_funcs.append(self.random_sleep) + wait_funcs += self.random_sleep if wait_incrementing_start is not None or wait_incrementing_increment is not None: - wait_funcs.append(self.incrementing_sleep) + wait_funcs += self.incrementing_sleep if wait_exponential_multiplier is not None or wait_exponential_max is not None: - wait_funcs.append(self.exponential_sleep) + wait_funcs += self.exponential_sleep if wait_func is not None: - self.wait = wait_func + wait_funcs += wait_func elif wait is None: - self.wait = lambda attempts, delay: max(f(attempts, delay) for f in wait_funcs) + wait_funcs += lambda attempts, delay, chain_results=None: max(chain_results) else: - self.wait = getattr(self, wait) + wait_funcs += getattr(self, wait) + + self.wait = wait_funcs # retry on exception filter if retry_on_exception is None: @@ -283,6 +284,35 @@ def __repr__(self): return "Attempts: {0}, Value: {1}".format(self.attempt_number, self.value) +class CallChain(object): + + def __init__(self, *fns): + self._chain = [] + for fn in fns: + self._assert_callable(fn) + self._chain.append(fn) + + def __add__(self, other): + self._assert_callable(object) + self._chain.append(other) + return self + + def __call__(self, *args, **kwargs): + results = [] + for fn in self._chain: + if 'chain_results' in inspect.getargspec(fn).args: + fn_kwargs = dict(kwargs) + fn_kwargs['chain_results'] = results + results.append(fn(*args, **fn_kwargs)) + else: + results.append(fn(*args, **kwargs)) + return results[-1] if results else None + + def _assert_callable(self, fn): + if not hasattr(fn, '__call__'): + raise TypeError("'%s' is not callable" % fn) + + class RetryError(Exception): """ A RetryError encapsulates the last Attempt instance right before giving up. diff --git a/test_retrying.py b/test_retrying.py index 8ce4ac3..4ac634c 100644 --- a/test_retrying.py +++ b/test_retrying.py @@ -434,6 +434,7 @@ def test_defaults(self): self.assertTrue(_retryable_default(NoCustomErrorAfterCount(5))) self.assertTrue(_retryable_default_f(NoCustomErrorAfterCount(5))) + class TestBeforeAfterAttempts(unittest.TestCase): _attempt_number = 0 @@ -468,5 +469,6 @@ def _test_after(): self.assertTrue(TestBeforeAfterAttempts._attempt_number is 2) + if __name__ == '__main__': unittest.main()