Skip to content

Commit 0b43e92

Browse files
committed
fix: port middleware/test_proxy.py to async context
1 parent 2d49863 commit 0b43e92

File tree

3 files changed

+66
-35
lines changed

3 files changed

+66
-35
lines changed

src/sentry/testutils/asserts.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from django.db import models
44
from django.db.models.functions import Cast
5-
from django.http import StreamingHttpResponse
65

76
from sentry.constants import ObjectStatus
87
from sentry.integrations.types import EventLifecycleOutcome
@@ -44,10 +43,7 @@ def assert_commit_shape(commit):
4443
def assert_status_code(response, minimum: int, maximum: int | None = None):
4544
# Omit max to assert status_code == minimum.
4645
maximum = maximum or minimum + 1
47-
assert minimum <= response.status_code < maximum, (
48-
response.status_code,
49-
response.getvalue() if isinstance(response, StreamingHttpResponse) else response.content,
50-
)
46+
assert minimum <= response.status_code < maximum, response
5147

5248

5349
def assert_existing_projects_status(

src/sentry/testutils/cases.py

Lines changed: 38 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from uuid import UUID, uuid4
1818
from zlib import compress
1919

20+
import httpx
2021
import pytest
2122
import requests
2223
import responses
@@ -36,7 +37,6 @@
3637
from django.utils import timezone
3738
from django.utils.functional import cached_property
3839
from google.protobuf.timestamp_pb2 import Timestamp
39-
from requests.utils import CaseInsensitiveDict, get_encoding_from_headers
4040
from rest_framework import status
4141
from rest_framework.request import Request
4242
from rest_framework.response import Response
@@ -678,36 +678,45 @@ def get_cursor_headers(self, response):
678678
def api_gateway_proxy_stubbed(self):
679679
"""Mocks a fake api gateway proxy that redirects via Client objects"""
680680

681-
def proxy_raw_request(
682-
method: str,
683-
url: str,
684-
headers: Mapping[str, str],
685-
params: Mapping[str, str] | None,
686-
data: Any,
687-
**kwds: Any,
688-
) -> requests.Response:
689-
from django.test.client import Client
690-
691-
client = Client()
692-
extra: Mapping[str, Any] = {
693-
f"HTTP_{k.replace('-', '_').upper()}": v for k, v in headers.items()
694-
}
695-
if params:
696-
url += "?" + urlencode(params)
697-
with assume_test_silo_mode(SiloMode.CELL):
698-
resp = getattr(client, method.lower())(
699-
url, b"".join(data), headers["Content-Type"], **extra
700-
)
701-
response = requests.Response()
702-
response.status_code = resp.status_code
703-
response.headers = CaseInsensitiveDict(resp.headers)
704-
response.encoding = get_encoding_from_headers(response.headers)
705-
response.raw = BytesIO(resp.content)
706-
return response
681+
from asgiref.sync import sync_to_async
682+
from django.test.client import Client
683+
684+
class MockedProxy:
685+
def __init__(self):
686+
self.client = Client()
687+
688+
@staticmethod
689+
async def _consume_body(content):
690+
ret = b""
691+
async for chunk in content:
692+
ret += chunk
693+
return ret
694+
695+
def build_request(self, method, url, headers, params, content, timeout):
696+
assert not params
697+
target = getattr(self.client, method.lower())
698+
content_type = headers.pop("Content-Type", "application/octet-stream")
699+
extra: Mapping[str, Any] = {
700+
f"HTTP_{k.replace('-', '_').upper()}": v for k, v in headers.items()
701+
}
702+
return target, (url, content, content_type), extra
703+
704+
async def send(self, req, stream, follow_redirects):
705+
with assume_test_silo_mode(SiloMode.CELL):
706+
url, content, content_type = req[1]
707+
content = await self._consume_body(content)
708+
resp = await sync_to_async(req[0])(url, content, content_type, **req[2])
709+
wresp = httpx.Response(
710+
status_code=resp.status_code,
711+
headers=dict(resp.headers),
712+
content=resp.content,
713+
)
714+
return wresp
707715

716+
mock_client = MockedProxy()
708717
with mock.patch(
709-
"sentry.hybridcloud.apigateway.proxy.external_request",
710-
new=proxy_raw_request,
718+
"sentry.hybridcloud.apigateway.proxy.proxy_client",
719+
new=mock_client,
711720
):
712721
yield
713722

tests/sentry/middleware/test_proxy.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
from __future__ import annotations
22

3+
import asyncio
34
from functools import cached_property
5+
from unittest.mock import patch
46

57
from django.http import HttpRequest
68

79
from sentry.middleware.proxy import SetRemoteAddrFromForwardedFor
810
from sentry.models.team import Team
911
from sentry.silo.base import SiloMode
1012
from sentry.testutils.cases import APITestCase, TestCase
13+
from sentry.testutils.helpers.response import close_streaming_response
1114
from sentry.testutils.silo import assume_test_silo_mode, control_silo_test
1215
from sentry.types.cell import Cell, RegionCategory
1316
from sentry.utils import json
@@ -48,6 +51,29 @@ class FakedAPIProxyTest(APITestCase):
4851
endpoint = "sentry-api-0-organization-teams"
4952
method = "post"
5053

54+
def setUp(self) -> None:
55+
super().setUp()
56+
57+
from sentry.hybridcloud.apigateway.middleware import ApiGatewayMiddleware
58+
59+
_original_middleware = ApiGatewayMiddleware._process_view_inner
60+
61+
def _process_view_match(self, request, view_func, view_args, view_kwargs):
62+
try:
63+
asyncio.get_running_loop()
64+
return self._process_view_inner(request, view_func, view_args, view_kwargs)
65+
except RuntimeError:
66+
return self._process_view_sync(request, view_func, view_args, view_kwargs)
67+
68+
self._middleware_patch = patch.object(
69+
ApiGatewayMiddleware, "_process_view_match", _process_view_match
70+
)
71+
self._middleware_patch.start()
72+
73+
def tearDown(self) -> None:
74+
self._middleware_patch.stop()
75+
super().tearDown()
76+
5177
def test_through_api_gateway(self) -> None:
5278
if SiloMode.get_current_mode() == SiloMode.MONOLITH:
5379
return
@@ -62,7 +88,7 @@ def test_through_api_gateway(self) -> None:
6288
status_code=201,
6389
)
6490

65-
result = json.loads(resp.getvalue())
91+
result = json.loads(close_streaming_response(resp))
6692
with assume_test_silo_mode(SiloMode.CELL):
6793
team = Team.objects.get(id=result["id"])
6894
assert team.idp_provisioned

0 commit comments

Comments
 (0)