diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml
new file mode 100644
index 0000000..238bab6
--- /dev/null
+++ b/.github/workflows/pythonapp.yml
@@ -0,0 +1,37 @@
+# This workflow will install Python dependencies, run tests and lint with a single version of Python
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Python application
+
+on:
+  push:
+    branches: [ master ]
+  pull_request:
+    branches: [ master ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python 3.8
+      uses: actions/setup-python@v1
+      with:
+        python-version: 3.8
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install flake8 pytest
+        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
+        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+    - name: Test with pytest
+      run: |
+        pip install git+https://github.com/brookman1/retry.git
+        pytest
diff --git a/README.rst b/README.rst
index e1d0e9c..0e82bfe 100644
--- a/README.rst
+++ b/README.rst
@@ -1,160 +1,149 @@
-retry
-=====
+# reretry
-.. image:: https://img.shields.io/pypi/dm/retry.svg?maxAge=2592000
-    :target: https://pypi.python.org/pypi/retry/
+![](https://img.shields.io/pypi/dm/reretry.svg?maxAge=2592000)
+![](https://img.shields.io/pypi/v/reretry.svg?maxAge=2592000)
+![](https://img.shields.io/pypi/l/reretry.svg?maxAge=2592000)
-.. image:: https://img.shields.io/pypi/v/retry.svg?maxAge=2592000
-    :target: https://pypi.python.org/pypi/retry/
+An easy-to-use retry decorator.
-.. image:: https://img.shields.io/pypi/l/retry.svg?maxAge=2592000
-    :target: https://pypi.python.org/pypi/retry/
+This package is a fork of the [`retry`](https://github.com/invl/retry) package, with some added community-sourced features.
-Easy to use retry decorator.
+## Features
-
-Features
---------
-
-- No external dependency (stdlib only).
+From original `retry`:
+- Retry on specific exceptions.
+- Set a maximum number of retries.
+- Set a delay between retries.
+- Set a maximum delay between retries.
+- Set backoff and jitter parameters.
+- Use a custom logger.
+- No external dependencies (stdlib only).
- (Optionally) Preserve function signatures (`pip install decorator`).
-Original traceback, easy to debug.
-
-
-Installation
-------------
-.. code-block:: bash
+New features in `reretry`:
+- Log the traceback of an error that led to a failed attempt.
+- Call a custom callback after each failed attempt.
+- Can be used with async functions.
-    $ pip install retry
+## Installation
-API
----
+```bash
+$ pip install reretry
+```
-retry decorator
-^^^^^^^^^^^^^^^
+## API
+### The @retry decorator
-.. code:: python
+#### Usage
+`@retry(exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0, show_traceback=False, logger=logging_logger, fail_callback=None, condition=threading.Condition())`
-    def retry(exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0, logger=logging_logger):
-        """Return a retry decorator.
+#### Arguments
+- `exceptions`: An exception or a tuple of exceptions to catch. Default: Exception.
-    :param exceptions: an exception or a tuple of exceptions to catch. default: Exception.
-    :param tries: the maximum number of attempts. default: -1 (infinite).
-    :param delay: initial delay between attempts. default: 0.
-    :param max_delay: the maximum value of delay. default: None (no limit).
-    :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff).
-    :param jitter: extra seconds added to delay between attempts. default: 0.
-                   fixed if a number, random if a range tuple (min, max)
-    :param logger: logger.warning(fmt, error, delay) will be called on failed attempts.
-                   default: retry.logging_logger. if None, logging is disabled.
-    """
+- `tries`: The maximum number of attempts. default: -1 (infinite).
-Various retrying logic can be achieved by combination of arguments.
+- `delay`: Initial delay between attempts (in seconds). default: 0.
+- `max_delay`: The maximum value of delay (in seconds). default: None (no limit).
-Examples
-""""""""
+- `backoff`: Multiplier applied to delay between attempts. default: 1 (no backoff).
-.. code:: python
+- `jitter`: Extra seconds added to delay between attempts. default: 0. Fixed if a number, random if a range tuple (min, max).
-    from retry import retry
+- `show_traceback`: Print the traceback before retrying (Python 3 only). default: False.
-.. code:: python
+- `logger`: `logger.warning(fmt, error, delay)` will be called on failed attempts. default: reretry.logging_logger. If None, logging is disabled.
-    @retry()
-    def make_trouble():
-        '''Retry until succeed'''
+- `fail_callback`: `fail_callback(e)` will be called after each failed attempt.
+- `condition`: An object providing `acquire()`, `release()` and `wait(n_seconds)` methods (e.g. `threading.Condition`), used to wait between attempts. default: a `threading.Condition()`.
-.. code:: python
-    @retry(ZeroDivisionError, tries=3, delay=2)
-    def make_trouble():
-        '''Retry on ZeroDivisionError, raise error after 3 attempts, sleep 2 seconds between attempts.'''
+#### Examples
+```python
+from reretry import retry
-.. code:: python
+@retry()
+def make_trouble():
+    '''Retry until it succeeds'''
-    @retry((ValueError, TypeError), delay=1, backoff=2)
-    def make_trouble():
-        '''Retry on ValueError or TypeError, sleep 1, 2, 4, 8, ... seconds between attempts.'''
+@retry()
+async def async_make_trouble():
+    '''Retry an async function until it succeeds'''
-.. code:: python
+@retry(ZeroDivisionError, tries=3, delay=2)
+def make_trouble():
+    '''Retry on ZeroDivisionError, raise error after 3 attempts,
+    sleep 2 seconds between attempts.'''
-    @retry((ValueError, TypeError), delay=1, backoff=2, max_delay=4)
-    def make_trouble():
-        '''Retry on ValueError or TypeError, sleep 1, 2, 4, 4, ... seconds between attempts.'''
+@retry((ValueError, TypeError), delay=1, backoff=2)
+def make_trouble():
+    '''Retry on ValueError or TypeError, sleep 1, 2, 4, 8, ... seconds between attempts.'''
-.. code:: python
+@retry((ValueError, TypeError), delay=1, backoff=2, max_delay=4)
+def make_trouble():
+    '''Retry on ValueError or TypeError, sleep 1, 2, 4, 4, ... seconds between attempts.'''
-    @retry(ValueError, delay=1, jitter=1)
-    def make_trouble():
-        '''Retry on ValueError, sleep 1, 2, 3, 4, ... seconds between attempts.'''
+@retry(ValueError, delay=1, jitter=1)
+def make_trouble():
+    '''Retry on ValueError, sleep 1, 2, 3, 4, ... seconds between attempts.'''
-.. code:: python
+def callback(e: Exception):
+    '''Print error message'''
+    print(e)
-    # If you enable logging, you can get warnings like 'ValueError, retrying in
-    # 1 seconds'
-    if __name__ == '__main__':
-        import logging
-        logging.basicConfig()
-        make_trouble()
+@retry(ValueError, fail_callback=callback)
+def make_trouble():
+    '''Retry on ValueError, between attempts call callback(e)
+    (where e is the Exception raised).'''
-retry_call
-^^^^^^^^^^
+# If you enable logging, you can get warnings like 'ValueError, retrying in
+# 1 seconds'
+if __name__ == '__main__':
+    import logging
+    logging.basicConfig()
+    make_trouble()
+```
-.. code:: python
-
-    def retry_call(f, fargs=None, fkwargs=None, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1,
-               jitter=0,
-               logger=logging_logger):
-        """
-        Calls a function and re-executes it if it failed.
-
-        :param f: the function to execute.
-        :param fargs: the positional arguments of the function to execute.
-        :param fkwargs: the named arguments of the function to execute.
-        :param exceptions: an exception or a tuple of exceptions to catch. default: Exception.
-        :param tries: the maximum number of attempts. default: -1 (infinite).
-        :param delay: initial delay between attempts. default: 0.
-        :param max_delay: the maximum value of delay. default: None (no limit).
-        :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff).
-        :param jitter: extra seconds added to delay between attempts. default: 0.
-                   fixed if a number, random if a range tuple (min, max)
-        :param logger: logger.warning(fmt, error, delay) will be called on failed attempts.
-                       default: retry.logging_logger. if None, logging is disabled.
-        :returns: the result of the f function.
-        """
+### The `retry_call` function
+Calls a function and re-executes it if it fails. It is very similar to the decorator, except that it takes the function and its arguments as parameters; its use case is adjusting the retry arguments dynamically at call time.
-..
code:: python - - import requests - - from retry.api import retry_call - - - def make_trouble(service, info=None): - if not info: - info = '' - r = requests.get(service + info) - return r.text - - - def what_is_my_ip(approach=None): - if approach == "optimistic": - tries = 1 - elif approach == "conservative": - tries = 3 - else: - # skeptical - tries = -1 - result = retry_call(make_trouble, fargs=["http://ipinfo.io/"], fkwargs={"info": "ip"}, tries=tries) - print(result) - - what_is_my_ip("conservative") - - - +#### Usage +`retry_call(f, fargs=None, fkwargs=None, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0, show_traceback=False, logger=logging_logger, fail_callback=None, condition=threading.Condition())` + +#### Example +```python +import requests + +from reretry.api import retry_call + + +def make_trouble(service, info=None): + if not info: + info = '' + r = requests.get(service + info) + return r.text + + +def what_is_my_ip(approach=None): + if approach == "optimistic": + tries = 1 + elif approach == "conservative": + tries = 3 + else: + # skeptical + tries = -1 + result = retry_call( + make_trouble, + fargs=["http://ipinfo.io/"], + fkwargs={"info": "ip"}, + tries=tries + ) + print(result) + +what_is_my_ip("conservative") +``` diff --git a/retry/__init__.py b/reretry/__init__.py similarity index 100% rename from retry/__init__.py rename to reretry/__init__.py diff --git a/reretry/api.py b/reretry/api.py new file mode 100644 index 0000000..7d73a9f --- /dev/null +++ b/reretry/api.py @@ -0,0 +1,279 @@ +import asyncio +from contextlib import contextmanager, AbstractContextManager, ExitStack +import logging +import random +import traceback +import threading +from functools import partial + +import inspect + +from .compat import decorator + +logging_logger = logging.getLogger(__name__) + +import inspect + +from .compat import decorator + +logging_logger = logging.getLogger(__name__) + +class ConditionWrappedWait(AbstractContextManager): + def __init__(self, c): + self.c = c + + def __enter__(self): + self.c.acquire() + return self + + def wait(self, n): + self.c.wait(n) + + @contextmanager + def _cleanup_on_error(self): + with ExitStack() as stack: + stack.push(self) + yield + # The validation check passed and didn't raise an exception + # Accordingly, we want to keep the resource, and pass it + # back to our caller + stack.pop_all() + + def __exit__(self, exc_type, exc_value, exc_tb): + self.c.release() + + +def __retry_internal( + f, + exceptions=Exception, + tries=-1, + delay=0, + max_delay=None, + backoff=1, + jitter=0, + show_traceback=False, + logger=logging_logger, + fail_callback=None, + condition=None, +): + _tries, _delay = tries, delay + + while _tries: + try: + return f() + + except exceptions as e: + _tries -= 1 + + if logger: + _log_attempt(tries, show_traceback, logger, _tries, _delay, e) + + if not _tries: + raise + + if fail_callback is not None: + fail_callback(e) + + with ConditionWrappedWait(condition) as _condition: + _condition.wait(_delay) + + _delay = _new_delay(max_delay, backoff, jitter, _delay) + + +async def __retry_internal_async( + f, + exceptions=Exception, + tries=-1, + delay=0, + max_delay=None, + backoff=1, + jitter=0, + show_traceback=False, + logger=logging_logger, + fail_callback=None, +): + _tries, _delay = tries, delay + + while _tries: + try: + return await f() + + except exceptions as e: + _tries -= 1 + + if logger: + _log_attempt(tries, show_traceback, logger, _tries, _delay, e) + + if not _tries: + raise + + if 
fail_callback is not None: + await fail_callback(e) + + await asyncio.sleep(_delay) + + _delay = _new_delay(max_delay, backoff, jitter, _delay) + + +def _log_attempt(tries, show_traceback, logger, _tries, _delay, e): + if _tries: + if show_traceback: + tb_str = "".join(traceback.format_exception(None, e, e.__traceback__)) + logger.warning(tb_str) + + logger.warning( + "%s, attempt %s/%s failed - retrying in %s seconds...", + e, + tries - _tries, + tries, + _delay, + ) + + elif tries > 1: + logger.warning( + "%s, attempt %s/%s failed - giving up!", e, tries - _tries, tries + ) + + +def _new_delay(max_delay, backoff, jitter, _delay): + _delay *= backoff + _delay += random.uniform(*jitter) if isinstance(jitter, tuple) else jitter + + if max_delay is not None: + _delay = min(_delay, max_delay) + + return _delay + + +def _is_async(f): + return asyncio.iscoroutinefunction(f) and not inspect.isgeneratorfunction(f) + + +def _get_internal_function(f): + return __retry_internal_async if _is_async(f) else __retry_internal + + +def _check_params(f, show_traceback=False, logger=logging_logger, fail_callback=None): + assert not show_traceback or logger is not None, "`show_traceback` needs `logger`" + + assert not fail_callback or ( + (_is_async(f) and _is_async(fail_callback)) + or (not _is_async(f) and not _is_async(fail_callback)) + ), "If the retried function is async, fail_callback needs to be async as well or vice versa" + + +def retry( + exceptions=Exception, + tries=-1, + delay=0, + max_delay=None, + backoff=1, + jitter=0, + show_traceback=False, + logger=logging_logger, + fail_callback=None, + condition=threading.Condition(), +): + """Returns a retry decorator. + + :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. + :param tries: the maximum number of attempts. default: -1 (infinite). + :param delay: initial delay between attempts (in seconds). default: 0. + :param max_delay: the maximum value of delay (in seconds). default: None (no limit). + :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). + :param jitter: extra seconds added to delay between attempts. default: 0. + fixed if a number, random if a range tuple (min, max) + :param show_traceback: if True, the traceback of the exception will be logged. + :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. + default: retry.logging_logger. if None, logging is disabled. + :param fail_callback: fail_callback(e) will be called on failed attempts. + :param condition: threading.Condition variable used to .wait, a default one is + provided. time.sleep is not used since it locks up all threads simultaneously. + :returns: a retry decorator. + """ + + @decorator + def retry_decorator(f, *fargs, **fkwargs): + return retry_call( + f, + fargs, + fkwargs, + exceptions, + tries, + delay, + max_delay, + backoff, + jitter, + show_traceback, + logger, + fail_callback, + condition=condition, + ) + + return retry_decorator + + +def retry_call( + f, + fargs=None, + fkwargs=None, + exceptions=Exception, + tries=-1, + delay=0, + max_delay=None, + backoff=1, + jitter=0, + show_traceback=False, + logger=logging_logger, + fail_callback=None, + condition=None, +): + """ + Calls a function and re-executes it if it failed. + + :param f: the function to execute. + :param fargs: the positional arguments of the function to execute. + :param fkwargs: the named arguments of the function to execute. + :param exceptions: an exception or a tuple of exceptions to catch. 
default: Exception. + :param tries: the maximum number of attempts. default: -1 (infinite). + :param delay: initial delay between attempts (in seconds). default: 0. + :param max_delay: the maximum value of delay (in seconds). default: None (no limit). + :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). + :param jitter: extra seconds added to delay between attempts. default: 0. + fixed if a number, random if a range tuple (min, max) + :param show_traceback: if True, the traceback of the exception will be logged. + :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. + default: retry.logging_logger. if None, logging is disabled. + :param fail_callback: fail_callback(e) will be called on failed attempts. + :returns: the result of the f function. + """ + + args = fargs or list() + kwargs = fkwargs or dict() + + _check_params(f, show_traceback, logger, fail_callback) + + func = _get_internal_function(f) + args_list = [] + path_through_args = dict( + exceptions=exceptions, + tries=tries, + delay=delay, + max_delay=max_delay, + backoff=backoff, + jitter=jitter, + show_traceback=show_traceback, + logger=logger, + fail_callback=fail_callback, + ) + if hasattr(inspect, 'getfullargspec'): + args_list = inspect.getfullargspec(func).args + else: + args_list = inspect.getargspec(func).args + + if 'condition' in args_list: + path_through_args['condition'] = condition + return func( + partial(f, *args, **kwargs), + **path_through_args, + ) diff --git a/retry/compat.py b/reretry/compat.py similarity index 100% rename from retry/compat.py rename to reretry/compat.py diff --git a/retry/api.py b/retry/api.py deleted file mode 100644 index 4a404b9..0000000 --- a/retry/api.py +++ /dev/null @@ -1,101 +0,0 @@ -import logging -import random -import time - -from functools import partial - -from .compat import decorator - - -logging_logger = logging.getLogger(__name__) - - -def __retry_internal(f, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0, - logger=logging_logger): - """ - Executes a function and retries it if it failed. - - :param f: the function to execute. - :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. - :param tries: the maximum number of attempts. default: -1 (infinite). - :param delay: initial delay between attempts. default: 0. - :param max_delay: the maximum value of delay. default: None (no limit). - :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). - :param jitter: extra seconds added to delay between attempts. default: 0. - fixed if a number, random if a range tuple (min, max) - :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. - default: retry.logging_logger. if None, logging is disabled. - :returns: the result of the f function. - """ - _tries, _delay = tries, delay - while _tries: - try: - return f() - except exceptions as e: - _tries -= 1 - if not _tries: - raise - - if logger is not None: - logger.warning('%s, retrying in %s seconds...', e, _delay) - - time.sleep(_delay) - _delay *= backoff - - if isinstance(jitter, tuple): - _delay += random.uniform(*jitter) - else: - _delay += jitter - - if max_delay is not None: - _delay = min(_delay, max_delay) - - -def retry(exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, jitter=0, logger=logging_logger): - """Returns a retry decorator. - - :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. 
- :param tries: the maximum number of attempts. default: -1 (infinite). - :param delay: initial delay between attempts. default: 0. - :param max_delay: the maximum value of delay. default: None (no limit). - :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). - :param jitter: extra seconds added to delay between attempts. default: 0. - fixed if a number, random if a range tuple (min, max) - :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. - default: retry.logging_logger. if None, logging is disabled. - :returns: a retry decorator. - """ - - @decorator - def retry_decorator(f, *fargs, **fkwargs): - args = fargs if fargs else list() - kwargs = fkwargs if fkwargs else dict() - return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, - logger) - - return retry_decorator - - -def retry_call(f, fargs=None, fkwargs=None, exceptions=Exception, tries=-1, delay=0, max_delay=None, backoff=1, - jitter=0, - logger=logging_logger): - """ - Calls a function and re-executes it if it failed. - - :param f: the function to execute. - :param fargs: the positional arguments of the function to execute. - :param fkwargs: the named arguments of the function to execute. - :param exceptions: an exception or a tuple of exceptions to catch. default: Exception. - :param tries: the maximum number of attempts. default: -1 (infinite). - :param delay: initial delay between attempts. default: 0. - :param max_delay: the maximum value of delay. default: None (no limit). - :param backoff: multiplier applied to delay between attempts. default: 1 (no backoff). - :param jitter: extra seconds added to delay between attempts. default: 0. - fixed if a number, random if a range tuple (min, max) - :param logger: logger.warning(fmt, error, delay) will be called on failed attempts. - default: retry.logging_logger. if None, logging is disabled. - :returns: the result of the f function. - """ - args = fargs if fargs else list() - kwargs = fkwargs if fkwargs else dict() - return __retry_internal(partial(f, *args, **kwargs), exceptions, tries, delay, max_delay, backoff, jitter, logger) diff --git a/tests/test_retry.py b/tests/test_retry.py index 64f45cd..736a47a 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -1,28 +1,89 @@ -try: - from unittest.mock import create_autospec -except ImportError: - from mock import create_autospec +import asyncio +from importlib import reload +from unittest.mock import MagicMock +import pytest +import time +import logging -try: - from unittest.mock import MagicMock -except ImportError: - from mock import MagicMock +logging_logger = logging.Logger(__name__) -import time +class MockCondition(object): + mock_sleep_time = [0] + def acquire(self): + logging_logger.warning('in acquire') + return True + + def release(self): + logging_logger.warning('in release') + + def wait(self, seconds): + logging_logger.warning('in wait') + print('\n'*10+'in wait'+'\n') + MockCondition.mock_sleep_time[0] += seconds + + +import threading +threading.Condition = MockCondition +from reretry.api import _is_async, retry, retry_call +import reretry.api + +thread_loops = 0 +total_thread_loops = 1000 + +def test_threaded_retry(monkeypatch): + ''' + For this thread actual threading is used, threading condition is + remocked at the end. 
+ ''' + reload(threading) + hit = [0] -import pytest + tries = 5 + delay = 0.01 + backoff = 2 -from retry.api import retry_call -from retry.api import retry + @retry(tries=tries, delay=delay, backoff=backoff) + def f(): + hit[0] += 1 + 1 / 0 + def run(): + global thread_loops + c = threading.Condition() + c.acquire() + thread_delay = delay / total_thread_loops + for a in range(total_thread_loops): + c.wait(thread_delay) + thread_loops += 1 + c.release() + + test_thread = threading.Thread(target=run) + test_thread.start() + c = threading.Condition() + # waiting for the test to start + c.acquire() + c.wait(.01) + c.release() + + num_loops = thread_loops + with pytest.raises(ZeroDivisionError): + f() + + num_loops = thread_loops - num_loops + test_thread.join() + + assert num_loops > 0 and num_loops < total_thread_loops + threading.Condition = MockCondition + assert hit[0] == tries + assert mock_sleep_time[0] == sum(delay * backoff**i for i in range(tries - 1)) -def test_retry(monkeypatch): - mock_sleep_time = [0] - def mock_sleep(seconds): - mock_sleep_time[0] += seconds - monkeypatch.setattr(time, 'sleep', mock_sleep) + +def test_retry(monkeypatch): + MockCondition.mock_sleep_time = [0] + + monkeypatch.setattr(reretry.api.threading, 'Condition', MockCondition) hit = [0] @@ -38,21 +99,22 @@ def f(): with pytest.raises(ZeroDivisionError): f() assert hit[0] == tries - assert mock_sleep_time[0] == sum( - delay * backoff ** i for i in range(tries - 1)) + assert MockCondition.mock_sleep_time[0] == \ + sum(delay * backoff**i for i in range(tries - 1)) def test_tries_inf(): hit = [0] target = 10 - @retry(tries=float('inf')) + @retry(tries=float("inf")) def f(): hit[0] += 1 if hit[0] == target: return target else: raise ValueError + assert f() == target @@ -67,16 +129,16 @@ def f(): return target else: raise ValueError + assert f() == target def test_max_delay(monkeypatch): - mock_sleep_time = [0] + MockCondition.mock_sleep_time = [0] - def mock_sleep(seconds): - mock_sleep_time[0] += seconds - - monkeypatch.setattr(time, 'sleep', mock_sleep) + monkeypatch.setattr(reretry.api.threading, 'Condition', MockCondition) + + print(':: condition', reretry.api.threading.Condition) hit = [0] @@ -93,16 +155,13 @@ def f(): with pytest.raises(ZeroDivisionError): f() assert hit[0] == tries - assert mock_sleep_time[0] == delay * (tries - 1) + assert MockCondition.mock_sleep_time[0] == delay * (tries - 1) def test_fixed_jitter(monkeypatch): - mock_sleep_time = [0] + MockCondition.mock_sleep_time = [0] - def mock_sleep(seconds): - mock_sleep_time[0] += seconds - - monkeypatch.setattr(time, 'sleep', mock_sleep) + monkeypatch.setattr(reretry.api.threading, 'Condition', MockCondition) hit = [0] @@ -117,14 +176,16 @@ def f(): with pytest.raises(ZeroDivisionError): f() assert hit[0] == tries - assert mock_sleep_time[0] == sum(range(tries - 1)) + assert MockCondition.mock_sleep_time[0] == sum(range(tries - 1)) def test_retry_call(): f_mock = MagicMock(side_effect=RuntimeError) tries = 2 try: - retry_call(f_mock, exceptions=RuntimeError, tries=tries) + retry_call(f_mock, exceptions=RuntimeError, tries=tries + , condition=MockCondition() + ) except RuntimeError: pass @@ -137,7 +198,9 @@ def test_retry_call_2(): tries = 5 result = None try: - result = retry_call(f_mock, exceptions=RuntimeError, tries=tries) + result = retry_call(f_mock, exceptions=RuntimeError, tries=tries + , condition=MockCondition() + ) except RuntimeError: pass @@ -146,7 +209,6 @@ def test_retry_call_2(): def test_retry_call_with_args(): - def 
f(value=0): if value < 0: return value @@ -157,7 +219,9 @@ def f(value=0): result = None f_mock = MagicMock(spec=f, return_value=return_value) try: - result = retry_call(f_mock, fargs=[return_value]) + result = retry_call(f_mock, fargs=[return_value] + , condition=MockCondition() + ) except RuntimeError: pass @@ -166,20 +230,91 @@ def f(value=0): def test_retry_call_with_kwargs(): - def f(value=0): if value < 0: return value else: raise RuntimeError - kwargs = {'value': -1} + kwargs = {"value": -1} result = None - f_mock = MagicMock(spec=f, return_value=kwargs['value']) + f_mock = MagicMock(spec=f, return_value=kwargs["value"]) try: - result = retry_call(f_mock, fkwargs=kwargs) + result = retry_call(f_mock, fkwargs=kwargs + , condition=MockCondition() + ) except RuntimeError: pass - assert result == kwargs['value'] + assert result == kwargs["value"] assert f_mock.call_count == 1 + + +def test_retry_call_with_fail_callback(): + def f(): + raise RuntimeError + + def cb(error): + pass + + callback_mock = MagicMock(spec=cb) + try: + retry_call(f, fail_callback=callback_mock, tries=2 + , condition=MockCondition() + ) + except RuntimeError: + pass + + assert callback_mock.called + + +def test_is_async(): + async def async_func(): + pass + + def non_async_func(): + pass + + def generator(): + yield + + + assert _is_async(async_func) + assert not _is_async(non_async_func) + assert not _is_async(generator) + assert not _is_async(generator()) + assert not _is_async(MagicMock(spec=non_async_func, return_value=-1)) + + +@pytest.mark.asyncio +async def test_async(): + attempts = 1 + raised = False + + @retry(tries=2) + async def f(): + await asyncio.sleep(0.1) + nonlocal attempts, raised + if attempts: + raised = True + attempts -= 1 + raise RuntimeError + return True + + assert await f() + assert raised + assert attempts == 0 + + +def test_check_params(): + with pytest.raises(AssertionError): + retry_call(lambda: None, show_traceback=True, logger=None) + + async def async_func(): + pass + + def non_async_func(): + pass + + with pytest.raises(AssertionError): + retry_call(async_func, fail_callback=non_async_func)
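The delay schedule implied by `delay`, `backoff`, `jitter` and `max_delay` can be previewed without running a decorated function. Below is a minimal sketch of that schedule logic, mirroring the `_new_delay` update in `reretry/api.py`; the helper name `preview_delays` is illustrative and not part of the package:

```python
import random


def preview_delays(tries, delay=0, backoff=1, jitter=0, max_delay=None):
    """Yield the wait used after each failed attempt (all but the last)."""
    _delay = delay
    for _ in range(tries - 1):
        yield _delay
        # Same update as reretry's _new_delay: scale, add jitter, then cap.
        _delay *= backoff
        _delay += random.uniform(*jitter) if isinstance(jitter, tuple) else jitter
        if max_delay is not None:
            _delay = min(_delay, max_delay)


# README example: delay=1, backoff=2, max_delay=4 -> waits of 1, 2, 4, 4 seconds
print(list(preview_delays(tries=5, delay=1, backoff=2, max_delay=4)))
```

For the README example `delay=1, backoff=2, max_delay=4` this prints `[1, 2, 4, 4]`, matching the docstrings above.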
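Because the synchronous retry path waits via `condition.wait(...)` (through `ConditionWrappedWait`) rather than `time.sleep()`, a caller that passes its own condition object to `retry_call` can, in principle, wake a pending retry early, for example on shutdown. This is a sketch under that assumption; `stop_waiting`, `flaky` and `worker` are illustrative names:

```python
import threading
import time

from reretry.api import retry_call

stop_waiting = threading.Condition()


def flaky():
    raise RuntimeError("still failing")


def worker():
    try:
        # The single wait between the two attempts blocks on `stop_waiting`
        # (for up to 60 seconds) instead of calling time.sleep().
        retry_call(flaky, exceptions=RuntimeError, tries=2, delay=60,
                   condition=stop_waiting)
    except RuntimeError:
        print("gave up")


t = threading.Thread(target=worker)
t.start()

time.sleep(1)  # let the first attempt fail and start waiting
with stop_waiting:
    stop_waiting.notify_all()  # cut the 60-second wait short
t.join()
```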
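`_check_params` requires `fail_callback` to match the retried function: an async function needs an async callback, and a plain function a plain one. A small sketch of the async variant (function names are illustrative):

```python
import asyncio

from reretry import retry


async def log_failure(e: Exception):
    # Awaited by the async retry path after each failed attempt that will be retried.
    print(f"attempt failed: {e!r}")


@retry(ValueError, tries=3, delay=0.1, fail_callback=log_failure)
async def fetch():
    raise ValueError("still failing")


try:
    asyncio.run(fetch())
except ValueError:
    print("gave up after 3 attempts")
```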