Skip to content

Commit 331e2b0

Browse files
fix(pipeline): Typing in oauth2 (#111754)
1 parent ebfe30a commit 331e2b0

File tree

2 files changed

+12
-10
lines changed

2 files changed

+12
-10
lines changed

src/sentry/identity/oauth2.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
IntegrationPipelineViewEvent,
3333
IntegrationPipelineViewType,
3434
)
35+
from sentry.pipeline.base import Pipeline
3536
from sentry.pipeline.types import PipelineStepResult
3637
from sentry.pipeline.views.base import PipelineView
3738
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
@@ -282,7 +283,7 @@ def __init__(
282283
self.bind_key = bind_key
283284
self.extra_authorize_params = extra_authorize_params or {}
284285

285-
def get_step_data(self, pipeline: Any, request: HttpRequest) -> dict[str, str]:
286+
def get_step_data(self, pipeline: Pipeline[Any, Any], request: HttpRequest) -> dict[str, str]:
286287
params = urlencode(
287288
{
288289
"client_id": self.client_id,
@@ -301,7 +302,7 @@ def get_serializer_cls(self) -> type:
301302
def handle_post(
302303
self,
303304
validated_data: dict[str, str],
304-
pipeline: Any,
305+
pipeline: Pipeline[Any, Any],
305306
request: HttpRequest,
306307
) -> PipelineStepResult:
307308
code = validated_data["code"]

tests/sentry/identity/test_oauth2.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from collections import namedtuple
44
from functools import cached_property
5-
from typing import Any
5+
from typing import Any, cast
66
from unittest import TestCase
77
from unittest.mock import MagicMock, patch
88
from urllib.parse import parse_qs, parse_qsl, urlparse
@@ -16,6 +16,7 @@
1616
from sentry.identity.pipeline import IdentityPipeline
1717
from sentry.identity.providers.dummy import DummyProvider
1818
from sentry.integrations.types import EventLifecycleOutcome
19+
from sentry.pipeline.base import Pipeline
1920
from sentry.pipeline.types import PipelineStepAction
2021
from sentry.shared_integrations.exceptions import ApiUnauthorized
2122
from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric
@@ -245,7 +246,7 @@ def step(self) -> OAuth2ApiStep:
245246
)
246247

247248
def test_returns_oauth_url(self) -> None:
248-
ctx = _FakePipelineContext(signature="abc123")
249+
ctx = cast(Pipeline, _FakePipelineContext(signature="abc123"))
249250
request = RequestFactory().get("/")
250251
data = self.step.get_step_data(ctx, request)
251252

@@ -299,7 +300,7 @@ def test_exchange_token_success(self) -> None:
299300
"https://example.org/oauth/token",
300301
json={"access_token": "a-fake-token"},
301302
)
302-
ctx = _FakePipelineContext(signature="valid-state")
303+
ctx = cast(Pipeline, _FakePipelineContext(signature="valid-state"))
303304
result = self.step.handle_post(
304305
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
305306
)
@@ -315,7 +316,7 @@ def test_exchange_token_success(self) -> None:
315316
assert data["client_secret"] == "secret-value"
316317

317318
def test_invalid_state(self) -> None:
318-
ctx = _FakePipelineContext(signature="correct-state")
319+
ctx = cast(Pipeline, _FakePipelineContext(signature="correct-state"))
319320
result = self.step.handle_post(
320321
{"code": "auth-code", "state": "wrong-state"}, ctx, self.request
321322
)
@@ -331,7 +332,7 @@ def ssl_error(request):
331332
responses.add_callback(
332333
responses.POST, "https://example.org/oauth/token", callback=ssl_error
333334
)
334-
ctx = _FakePipelineContext(signature="valid-state")
335+
ctx = cast(Pipeline, _FakePipelineContext(signature="valid-state"))
335336
result = self.step.handle_post(
336337
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
337338
)
@@ -347,7 +348,7 @@ def connection_error(request):
347348
responses.add_callback(
348349
responses.POST, "https://example.org/oauth/token", callback=connection_error
349350
)
350-
ctx = _FakePipelineContext(signature="valid-state")
351+
ctx = cast(Pipeline, _FakePipelineContext(signature="valid-state"))
351352
result = self.step.handle_post(
352353
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
353354
)
@@ -358,7 +359,7 @@ def connection_error(request):
358359
@responses.activate
359360
def test_empty_response_body(self) -> None:
360361
responses.add(responses.POST, "https://example.org/oauth/token", body="")
361-
ctx = _FakePipelineContext(signature="valid-state")
362+
ctx = cast(Pipeline, _FakePipelineContext(signature="valid-state"))
362363
result = self.step.handle_post(
363364
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
364365
)
@@ -374,7 +375,7 @@ def test_api_error_401(self) -> None:
374375
json={"error": "unauthorized"},
375376
status=401,
376377
)
377-
ctx = _FakePipelineContext(signature="valid-state")
378+
ctx = cast(Pipeline, _FakePipelineContext(signature="valid-state"))
378379
result = self.step.handle_post(
379380
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
380381
)

0 commit comments

Comments
 (0)