Skip to content

Commit c64d13e

Browse files
committed
feat(identity): Add OAuth2ApiStep for API-driven OAuth2 flows
Add a generic API-mode step that handles the full OAuth2 authorization code flow in a single step: returns the authorize URL for the frontend to open in a popup, then accepts code/state from the trampoline callback and exchanges it for an access token. Includes a configurable bind_key for controlling where token data is stored in pipeline state. Refs VDY-37
1 parent 36f6e8a commit c64d13e

File tree

2 files changed

+318
-2
lines changed

2 files changed

+318
-2
lines changed

src/sentry/identity/oauth2.py

Lines changed: 145 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from django.views.decorators.csrf import csrf_exempt
1717
from requests import Response
1818
from requests.exceptions import HTTPError, SSLError
19+
from rest_framework.fields import CharField
20+
from rest_framework.serializers import Serializer
1921

2022
from sentry.auth.exceptions import IdentityNotValid
2123
from sentry.exceptions import NotRegistered
@@ -30,20 +32,27 @@
3032
IntegrationPipelineViewEvent,
3133
IntegrationPipelineViewType,
3234
)
35+
from sentry.pipeline.types import PipelineStepResult
3336
from sentry.pipeline.views.base import PipelineView
3437
from sentry.shared_integrations.exceptions import ApiError, ApiInvalidRequestError, ApiUnauthorized
3538
from sentry.users.models.identity import Identity
3639
from sentry.utils.http import absolute_uri
3740

3841
from .base import Provider
3942

40-
__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView"]
43+
__all__ = ["OAuth2Provider", "OAuth2CallbackView", "OAuth2LoginView", "OAuth2ApiStep"]
4144

4245
logger = logging.getLogger(__name__)
4346
ERR_INVALID_STATE = "An error occurred while validating your request."
4447
ERR_TOKEN_RETRIEVAL = "Failed to retrieve token from the upstream service."
4548

4649

50+
class OAuth2ApiStepError(Exception):
51+
"""Raised when the OAuth2 API step encounters an error during token exchange."""
52+
53+
pass
54+
55+
4756
def _redirect_url(pipeline: IdentityPipeline) -> str:
4857
associate_url = reverse(
4958
"sentry-extension-setup",
@@ -137,6 +146,23 @@ def get_pipeline_views(self) -> list[PipelineView[IdentityPipeline]]:
137146
),
138147
]
139148

149+
def get_pipeline_api_steps(self) -> list[OAuth2ApiStep]:
150+
redirect_url = self.config.get(
151+
"redirect_url",
152+
reverse("sentry-extension-setup", kwargs={"provider_id": "default"}),
153+
)
154+
return [
155+
OAuth2ApiStep(
156+
authorize_url=self.get_oauth_authorize_url(),
157+
client_id=self.get_oauth_client_id(),
158+
client_secret=self.get_oauth_client_secret(),
159+
access_token_url=self.get_oauth_access_token_url(),
160+
scope=" ".join(self.get_oauth_scopes()),
161+
redirect_url=redirect_url,
162+
verify_ssl=self.config.get("verify_ssl", True),
163+
),
164+
]
165+
140166
def get_refresh_token_params(
141167
self, refresh_token: str, identity: Identity | RpcIdentity, **kwargs: Any
142168
) -> dict[str, str | None]:
@@ -214,6 +240,124 @@ def record_event(event: IntegrationPipelineViewType, provider: str):
214240
)
215241

216242

243+
class OAuth2ApiSerializer(Serializer):
244+
code = CharField(required=True)
245+
state = CharField(required=True)
246+
247+
248+
class OAuth2ApiStep:
249+
"""
250+
Generic API-mode step for OAuth2 identity authentication.
251+
252+
Handles the full OAuth2 authorization code flow in a single API step:
253+
254+
- GET (get_step_data): returns the OAuth authorize URL for the frontend to
255+
open in a popup.
256+
- POST (handle_post): receives the callback params (code, state) relayed by
257+
the trampoline via postMessage, validates state, exchanges the code for an
258+
access token, and binds the token data to pipeline state.
259+
"""
260+
261+
step_name = "oauth_login"
262+
263+
def __init__(
264+
self,
265+
authorize_url: str,
266+
client_id: str,
267+
client_secret: str,
268+
access_token_url: str,
269+
scope: str,
270+
redirect_url: str,
271+
verify_ssl: bool = True,
272+
bind_key: str = "data",
273+
extra_authorize_params: dict[str, str] | None = None,
274+
) -> None:
275+
self.authorize_url = authorize_url
276+
self.client_id = client_id
277+
self.client_secret = client_secret
278+
self.access_token_url = access_token_url
279+
self.scope = scope
280+
self.redirect_url = redirect_url
281+
self.verify_ssl = verify_ssl
282+
self.bind_key = bind_key
283+
self.extra_authorize_params = extra_authorize_params or {}
284+
285+
def get_step_data(self, pipeline: Any, request: HttpRequest) -> dict[str, str]:
286+
params = urlencode(
287+
{
288+
"client_id": self.client_id,
289+
"response_type": "code",
290+
"scope": self.scope,
291+
"state": pipeline.signature,
292+
"redirect_uri": absolute_uri(self.redirect_url),
293+
**self.extra_authorize_params,
294+
}
295+
)
296+
return {"oauthUrl": f"{self.authorize_url}?{params}"}
297+
298+
def get_serializer_cls(self) -> type:
299+
return OAuth2ApiSerializer
300+
301+
def handle_post(
302+
self,
303+
validated_data: dict[str, str],
304+
pipeline: Any,
305+
request: HttpRequest,
306+
) -> PipelineStepResult:
307+
code = validated_data["code"]
308+
state = validated_data["state"]
309+
310+
if state != pipeline.signature:
311+
return PipelineStepResult.error(ERR_INVALID_STATE)
312+
313+
try:
314+
data = self._exchange_token(code)
315+
except OAuth2ApiStepError as e:
316+
logger.info("identity.token-exchange-error", extra={"error": str(e)})
317+
return PipelineStepResult.error(str(e))
318+
319+
pipeline.bind_state(self.bind_key, data)
320+
return PipelineStepResult.advance()
321+
322+
def _exchange_token(self, code: str) -> dict[str, Any]:
323+
"""Exchange an authorization code for an access token.
324+
325+
Raises OAuth2ApiStepError on failure.
326+
"""
327+
token_params = {
328+
"grant_type": "authorization_code",
329+
"code": code,
330+
"redirect_uri": absolute_uri(self.redirect_url),
331+
"client_id": self.client_id,
332+
"client_secret": self.client_secret,
333+
}
334+
try:
335+
req = safe_urlopen(self.access_token_url, data=token_params, verify_ssl=self.verify_ssl)
336+
req.raise_for_status()
337+
except HTTPError as e:
338+
error_resp = e.response
339+
exc = ApiError.from_response(error_resp, url=self.access_token_url)
340+
sentry_sdk.capture_exception(exc)
341+
raise OAuth2ApiStepError(
342+
f"Could not retrieve access token. Received {exc.code}: {exc.text}"
343+
) from e
344+
except SSLError as e:
345+
raise OAuth2ApiStepError(
346+
f"Could not verify SSL certificate for {self.access_token_url}"
347+
) from e
348+
except ConnectionError as e:
349+
raise OAuth2ApiStepError(f"Could not connect to {self.access_token_url}") from e
350+
351+
try:
352+
body = safe_urlread(req)
353+
content_type = req.headers.get("Content-Type", "").lower()
354+
if content_type.startswith("application/x-www-form-urlencoded"):
355+
return dict(parse_qsl(body))
356+
return orjson.loads(body)
357+
except orjson.JSONDecodeError as e:
358+
raise OAuth2ApiStepError("Could not decode a JSON response, please try again.") from e
359+
360+
217361
class OAuth2LoginView:
218362
authorize_url: str | None = None
219363
client_id: str | None = None

tests/sentry/identity/test_oauth2.py

Lines changed: 173 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from __future__ import annotations
2+
13
from collections import namedtuple
24
from functools import cached_property
5+
from typing import Any
36
from unittest import TestCase
47
from unittest.mock import MagicMock, patch
58
from urllib.parse import parse_qs, parse_qsl, urlparse
@@ -9,10 +12,11 @@
912
from requests.exceptions import SSLError
1013

1114
import sentry.identity
12-
from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView
15+
from sentry.identity.oauth2 import OAuth2ApiStep, OAuth2CallbackView, OAuth2LoginView
1316
from sentry.identity.pipeline import IdentityPipeline
1417
from sentry.identity.providers.dummy import DummyProvider
1518
from sentry.integrations.types import EventLifecycleOutcome
19+
from sentry.pipeline.types import PipelineStepAction
1620
from sentry.shared_integrations.exceptions import ApiUnauthorized
1721
from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric
1822
from sentry.testutils.silo import control_silo_test
@@ -209,3 +213,171 @@ def test_customer_domains(self) -> None:
209213
assert query["response_type"][0] == "code"
210214
assert query["scope"][0] == "all-the-things"
211215
assert "state" in query
216+
217+
218+
class _FakePipelineContext:
219+
"""Minimal pipeline-like object for testing OAuth2ApiStep."""
220+
221+
def __init__(self, signature: str = "test-signature") -> None:
222+
self.signature = signature
223+
self._state: dict[str, Any] = {}
224+
225+
def bind_state(self, key: str, value: Any) -> None:
226+
self._state[key] = value
227+
228+
def fetch_state(self, key: str | None = None) -> Any:
229+
if key is None:
230+
return self._state
231+
return self._state.get(key)
232+
233+
234+
@control_silo_test
235+
class OAuth2ApiStepGetStepDataTest(TestCase):
236+
@cached_property
237+
def step(self) -> OAuth2ApiStep:
238+
return OAuth2ApiStep(
239+
authorize_url="https://example.org/oauth2/authorize",
240+
client_id="123456",
241+
client_secret="secret-value",
242+
access_token_url="https://example.org/oauth/token",
243+
scope="all-the-things",
244+
redirect_url="/extensions/default/setup/",
245+
)
246+
247+
def test_returns_oauth_url(self) -> None:
248+
ctx = _FakePipelineContext(signature="abc123")
249+
request = RequestFactory().get("/")
250+
data = self.step.get_step_data(ctx, request)
251+
252+
assert "oauthUrl" in data
253+
url = urlparse(data["oauthUrl"])
254+
assert url.scheme == "https"
255+
assert url.hostname == "example.org"
256+
assert url.path == "/oauth2/authorize"
257+
258+
query = parse_qs(url.query)
259+
assert query["client_id"] == ["123456"]
260+
assert query["response_type"] == ["code"]
261+
assert query["scope"] == ["all-the-things"]
262+
assert query["state"] == ["abc123"]
263+
assert "redirect_uri" in query
264+
265+
def test_serializer_requires_code_and_state(self) -> None:
266+
ser_cls = self.step.get_serializer_cls()
267+
assert ser_cls is not None
268+
269+
ser = ser_cls(data={})
270+
assert not ser.is_valid()
271+
assert "code" in ser.errors
272+
assert "state" in ser.errors
273+
274+
ser = ser_cls(data={"code": "abc", "state": "xyz"})
275+
assert ser.is_valid()
276+
277+
278+
@control_silo_test
279+
class OAuth2ApiStepHandlePostTest(TestCase):
280+
def setUp(self) -> None:
281+
super().setUp()
282+
self.request = RequestFactory().get("/")
283+
284+
@cached_property
285+
def step(self) -> OAuth2ApiStep:
286+
return OAuth2ApiStep(
287+
authorize_url="https://example.org/oauth2/authorize",
288+
client_id="123456",
289+
client_secret="secret-value",
290+
access_token_url="https://example.org/oauth/token",
291+
scope="all-the-things",
292+
redirect_url="/extensions/default/setup/",
293+
)
294+
295+
@responses.activate
296+
def test_exchange_token_success(self) -> None:
297+
responses.add(
298+
responses.POST,
299+
"https://example.org/oauth/token",
300+
json={"access_token": "a-fake-token"},
301+
)
302+
ctx = _FakePipelineContext(signature="valid-state")
303+
result = self.step.handle_post(
304+
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
305+
)
306+
307+
assert result.action == PipelineStepAction.ADVANCE
308+
assert ctx.fetch_state("data") == {"access_token": "a-fake-token"}
309+
310+
assert len(responses.calls) == 1
311+
data = dict(parse_qsl(responses.calls[0].request.body))
312+
assert data["grant_type"] == "authorization_code"
313+
assert data["code"] == "auth-code"
314+
assert data["client_id"] == "123456"
315+
assert data["client_secret"] == "secret-value"
316+
317+
def test_invalid_state(self) -> None:
318+
ctx = _FakePipelineContext(signature="correct-state")
319+
result = self.step.handle_post(
320+
{"code": "auth-code", "state": "wrong-state"}, ctx, self.request
321+
)
322+
323+
assert result.action == PipelineStepAction.ERROR
324+
assert "detail" in result.data
325+
326+
@responses.activate
327+
def test_ssl_error(self) -> None:
328+
def ssl_error(request):
329+
raise SSLError("Could not build connection")
330+
331+
responses.add_callback(
332+
responses.POST, "https://example.org/oauth/token", callback=ssl_error
333+
)
334+
ctx = _FakePipelineContext(signature="valid-state")
335+
result = self.step.handle_post(
336+
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
337+
)
338+
339+
assert result.action == PipelineStepAction.ERROR
340+
assert "SSL" in result.data["detail"]
341+
342+
@responses.activate
343+
def test_connection_error(self) -> None:
344+
def connection_error(request):
345+
raise ConnectionError("Name or service not known")
346+
347+
responses.add_callback(
348+
responses.POST, "https://example.org/oauth/token", callback=connection_error
349+
)
350+
ctx = _FakePipelineContext(signature="valid-state")
351+
result = self.step.handle_post(
352+
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
353+
)
354+
355+
assert result.action == PipelineStepAction.ERROR
356+
assert "connect" in result.data["detail"].lower()
357+
358+
@responses.activate
359+
def test_empty_response_body(self) -> None:
360+
responses.add(responses.POST, "https://example.org/oauth/token", body="")
361+
ctx = _FakePipelineContext(signature="valid-state")
362+
result = self.step.handle_post(
363+
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
364+
)
365+
366+
assert result.action == PipelineStepAction.ERROR
367+
assert "json" in result.data["detail"].lower()
368+
369+
@responses.activate
370+
def test_api_error_401(self) -> None:
371+
responses.add(
372+
responses.POST,
373+
"https://example.org/oauth/token",
374+
json={"error": "unauthorized"},
375+
status=401,
376+
)
377+
ctx = _FakePipelineContext(signature="valid-state")
378+
result = self.step.handle_post(
379+
{"code": "auth-code", "state": "valid-state"}, ctx, self.request
380+
)
381+
382+
assert result.action == PipelineStepAction.ERROR
383+
assert "401" in result.data["detail"]

0 commit comments

Comments
 (0)