|
17 | 17 | from uuid import UUID, uuid4 |
18 | 18 | from zlib import compress |
19 | 19 |
|
| 20 | +import httpx |
20 | 21 | import pytest |
21 | 22 | import requests |
22 | 23 | import responses |
|
36 | 37 | from django.utils import timezone |
37 | 38 | from django.utils.functional import cached_property |
38 | 39 | from google.protobuf.timestamp_pb2 import Timestamp |
39 | | -from requests.utils import CaseInsensitiveDict, get_encoding_from_headers |
40 | 40 | from rest_framework import status |
41 | 41 | from rest_framework.request import Request |
42 | 42 | from rest_framework.response import Response |
@@ -678,36 +678,45 @@ def get_cursor_headers(self, response): |
678 | 678 | def api_gateway_proxy_stubbed(self): |
679 | 679 | """Mocks a fake api gateway proxy that redirects via Client objects""" |
680 | 680 |
|
681 | | - def proxy_raw_request( |
682 | | - method: str, |
683 | | - url: str, |
684 | | - headers: Mapping[str, str], |
685 | | - params: Mapping[str, str] | None, |
686 | | - data: Any, |
687 | | - **kwds: Any, |
688 | | - ) -> requests.Response: |
689 | | - from django.test.client import Client |
690 | | - |
691 | | - client = Client() |
692 | | - extra: Mapping[str, Any] = { |
693 | | - f"HTTP_{k.replace('-', '_').upper()}": v for k, v in headers.items() |
694 | | - } |
695 | | - if params: |
696 | | - url += "?" + urlencode(params) |
697 | | - with assume_test_silo_mode(SiloMode.CELL): |
698 | | - resp = getattr(client, method.lower())( |
699 | | - url, b"".join(data), headers["Content-Type"], **extra |
700 | | - ) |
701 | | - response = requests.Response() |
702 | | - response.status_code = resp.status_code |
703 | | - response.headers = CaseInsensitiveDict(resp.headers) |
704 | | - response.encoding = get_encoding_from_headers(response.headers) |
705 | | - response.raw = BytesIO(resp.content) |
706 | | - return response |
| 681 | + from asgiref.sync import sync_to_async |
| 682 | + from django.test.client import Client |
| 683 | + |
| 684 | + class MockedProxy: |
| 685 | + def __init__(self): |
| 686 | + self.client = Client() |
| 687 | + |
| 688 | + @staticmethod |
| 689 | + async def _consume_body(content): |
| 690 | + ret = b"" |
| 691 | + async for chunk in content: |
| 692 | + ret += chunk |
| 693 | + return ret |
| 694 | + |
| 695 | + def build_request(self, method, url, headers, params, content, timeout): |
| 696 | + assert not params |
| 697 | + target = getattr(self.client, method.lower()) |
| 698 | + content_type = headers.pop("Content-Type", "application/octet-stream") |
| 699 | + extra: Mapping[str, Any] = { |
| 700 | + f"HTTP_{k.replace('-', '_').upper()}": v for k, v in headers.items() |
| 701 | + } |
| 702 | + return target, (url, content, content_type), extra |
| 703 | + |
| 704 | + async def send(self, req, stream, follow_redirects): |
| 705 | + with assume_test_silo_mode(SiloMode.CELL): |
| 706 | + url, content, content_type = req[1] |
| 707 | + content = await self._consume_body(content) |
| 708 | + resp = await sync_to_async(req[0])(url, content, content_type, **req[2]) |
| 709 | + wresp = httpx.Response( |
| 710 | + status_code=resp.status_code, |
| 711 | + headers=dict(resp.headers), |
| 712 | + content=resp.content, |
| 713 | + ) |
| 714 | + return wresp |
707 | 715 |
|
| 716 | + mock_client = MockedProxy() |
708 | 717 | with mock.patch( |
709 | | - "sentry.hybridcloud.apigateway.proxy.external_request", |
710 | | - new=proxy_raw_request, |
| 718 | + "sentry.hybridcloud.apigateway.proxy.proxy_client", |
| 719 | + new=mock_client, |
711 | 720 | ): |
712 | 721 | yield |
713 | 722 |
|
|
0 commit comments