diff --git a/rest_framework_condition/decorators.py b/rest_framework_condition/decorators.py index 1bbae44..92307dc 100644 --- a/rest_framework_condition/decorators.py +++ b/rest_framework_condition/decorators.py @@ -1,36 +1,90 @@ import functools -from django.views.decorators.http import condition as django_condition +import warnings +from calendar import timegm -def condition(etag_func=None, last_modified_func=None): +from django.utils.cache import get_conditional_response +from django.utils.http import http_date, quote_etag + + +def condition(etag_func=None, last_modified_func=None, use_self=False): """ - Decorator to support conditional retrieval (or change) - for a Django Rest Framework's ViewSet. + Decorator to support conditional retrieval (or change) for a Django Rest + Framework's ViewSet. + + This decorator emulates Django's original decorator by wrapping the + underlying functionality where possible but handles the Django Rest + Framework request object. - It calls Django's original decorator but pass correct request object to it. - Django's original decorator doesn't work with DRF request object. + See: django.views.decorators.http.condition """ + + if not use_self: + warnings.warn( + 'The etag_func and last_modified_func should accept a "self" ' + 'argument which matches how Django Rest Framework calls ' + 'view/viewset methods.\n\n' + 'After updating the handlers pass "use_self" to the condition ' + 'decorator to enable the future functionality and silence this ' + 'warning.', + DeprecationWarning) + def decorator(func): @functools.wraps(func) - def wrapper(obj_self, request, *args, **kwargs): - drf_request = request - wsgi_request = request._request + def wrapper(self, request, *args, **kwargs): + if etag_func: + if use_self: + etag = etag_func(self, request, *args, **kwargs) + else: + etag = etag_func(request, *args, **kwargs) + + # The value from etag_func() could be quoted or unquoted. + if etag: + etag = quote_etag(etag) + else: + etag = None + + if last_modified_func: + if use_self: + last_modified = last_modified_func( + self, request, *args, **kwargs) + else: + last_modified = last_modified_func( + request, *args, **kwargs) + + if last_modified: + last_modified = timegm(last_modified.utctimetuple()) + else: + last_modified = None + + # pass the wrapped WSGI request for Django + response = get_conditional_response( + request._request, + etag=etag, + last_modified=last_modified, + ) + + if response is None: + response = func(self, request, *args, **kwargs) + + # Set relevant headers on the response if they don't already exist + # and if the request method is safe. + if request.method in ('GET', 'HEAD'): + if last_modified and not response.has_header('Last-Modified'): + response['Last-Modified'] = http_date(last_modified) + if etag: + response.setdefault('ETag', etag) - def patched_viewset_method(*_args, **_kwargs): - """Call original viewset method with correct type of request""" - return func(obj_self, drf_request, *args, **kwargs) + return response - django_decorator = django_condition(etag_func, last_modified_func) - decorated_viewset_method = django_decorator(patched_viewset_method) - return decorated_viewset_method(wsgi_request, *args, **kwargs) return wrapper return decorator # Shortcut decorators for common cases based on ETag or Last-Modified only -def etag(etag_func): - return condition(etag_func=etag_func) +def etag(etag_func, use_self=False): + return condition(etag_func=etag_func, use_self=use_self) -def last_modified(last_modified_func): - return condition(last_modified_func=last_modified_func) +def last_modified(last_modified_func, use_self=False): + return condition(last_modified_func=last_modified_func, use_self=use_self) diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 7d5d013..2853a46 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,4 +1,5 @@ import calendar +import json from datetime import datetime from django.urls import reverse @@ -209,3 +210,33 @@ def test_etag_has_access_to_kwargs_from_view(self): assert response.status_code == status.HTTP_200_OK assert response['ETag'] == '"hash-42"' assert response.data == {'data': 'etag', 'pk': '42'} + + +class TestDecoratorMatchesBuiltin(APITestCase): + def check_responses(self, builtin_url, api_url): + builtin_response = self.client.get(builtin_url) + api_response = self.client.get(api_url) + + assert builtin_response.status_code == api_response.status_code + assert json.loads(builtin_response.content) == api_response.data + + # Check the headers added, but DRF is allowed to add additional + # headers, and the content length may differ. + for key in builtin_response._headers: + if key.lower() != 'content-length': + assert builtin_response[key] == api_response[key] + + def test_etag(self): + self.check_responses( + builtin_url=reverse('builtin-view-etag'), + api_url=reverse('api-view-etag')) + + def test_etag_with_kwargs(self): + self.check_responses( + builtin_url=reverse('builtin-view-etag-kwargs', args=[42]), + api_url=reverse('etag-kwargs-detail', args=[42])) + + def test_last_modified(self): + self.check_responses( + builtin_url=reverse('builtin-view-last-modified'), + api_url=reverse('api-view-last-modified')) diff --git a/tests/urls.py b/tests/urls.py index 32823b4..5a42ab0 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -2,6 +2,9 @@ from rest_framework.routers import DefaultRouter from tests.views import ( + builtin_etag_kwargs_view, + builtin_etag_view, + builtin_last_modified_view, ETagApiView, EtagFromKwargsViewSet, EtagViewSet, @@ -20,15 +23,26 @@ urlpatterns = [ path( - "^api-view/no-condition/$", + "api-view/no-condition/", NoConditionApiView.as_view(), name="api-view-no-condition", ), path( - "^api-view/last-modified/$", + "api-view/last-modified/", LastModifiedApiView.as_view(), name="api-view-last-modified", ), - path("^api-view/etag/$", ETagApiView.as_view(), name="api-view-etag"), - path("^view-set/", include(router.urls)), + path("api-view/etag/", ETagApiView.as_view(), name="api-view-etag"), + path( + "builtin/last-modified/", + builtin_last_modified_view, + name="builtin-view-last-modified" + ), + path("builtin/etag/", builtin_etag_view, name="builtin-view-etag"), + path( + "builtin/etag//", + builtin_etag_kwargs_view, + name="builtin-view-etag-kwargs", + ), + path("view-set/", include(router.urls)), ] diff --git a/tests/views.py b/tests/views.py index 8db70e7..0b8d4d7 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,36 +1,57 @@ from datetime import datetime +from django.http import JsonResponse from rest_framework import views, viewsets from rest_framework.response import Response +from django.views.decorators.http import ( + etag as builtin_etag, + last_modified as builtin_last_modified, +) from rest_framework_condition import etag, last_modified -def my_last_modified(request, *args, **kwargs): +def my_last_modified(self, request, *args, **kwargs): return datetime(2019, 1, 1) -def my_etag(request, *args, **kwargs): +def my_etag(self, request, *args, **kwargs): return 'hash123' -def etag_from_kwargs(request, *args, **kwargs): +def etag_from_kwargs(self, request, *args, **kwargs): return 'hash-{}'.format(kwargs['pk']) +@builtin_last_modified(lambda request: my_last_modified(None, request)) +def builtin_last_modified_view(request): + return JsonResponse({'data': '2019'}) + + +@builtin_etag(lambda request: my_etag(None, request)) +def builtin_etag_view(request): + return JsonResponse({'data': 'etag'}) + + +@builtin_etag( + lambda request, **kwargs: etag_from_kwargs(None, request, **kwargs)) +def builtin_etag_kwargs_view(request, pk): + return JsonResponse({'data': 'etag', 'pk': pk}) + + class NoConditionApiView(views.APIView): def get(self, request): return Response({'data': 'no-condition'}) class LastModifiedApiView(views.APIView): - @last_modified(my_last_modified) + @last_modified(my_last_modified, use_self=True) def get(self, request): return Response({'data': '2019'}) class ETagApiView(views.APIView): - @etag(my_etag) + @etag(my_etag, use_self=True) def get(self, request): return Response({'data': 'etag'}) @@ -44,26 +65,26 @@ def retrieve(self, request, pk=None): class LastModifiedViewSet(viewsets.ViewSet): - @last_modified(my_last_modified) + @last_modified(my_last_modified, use_self=True) def list(self, request): return Response({'data': '2019'}) - @last_modified(my_last_modified) + @last_modified(my_last_modified, use_self=True) def retrieve(self, request, pk=None): return Response({'data': '2019', 'pk': pk}) class EtagViewSet(viewsets.ViewSet): - @etag(my_etag) + @etag(my_etag, use_self=True) def list(self, request): return Response({'data': 'etag'}) - @etag(my_etag) + @etag(my_etag, use_self=True) def retrieve(self, request, pk=None): return Response({'data': 'etag', 'pk': pk}) class EtagFromKwargsViewSet(viewsets.ViewSet): - @etag(etag_from_kwargs) + @etag(etag_from_kwargs, use_self=True) def retrieve(self, request, pk=None): return Response({'data': 'etag', 'pk': pk}) diff --git a/tox.ini b/tox.ini index 2eb5447..fa6c252 100644 --- a/tox.ini +++ b/tox.ini @@ -12,4 +12,4 @@ deps = pytest-cov commands = # NOTE: you can run any command line tool here - not just tests - pytest --cov=rest_framework_condition/ + pytest --cov=rest_framework_condition/ {posargs}