|
| 1 | +from __future__ import annotations |
| 2 | + |
1 | 3 | from collections import namedtuple |
2 | 4 | from functools import cached_property |
| 5 | +from typing import Any |
3 | 6 | from unittest import TestCase |
4 | 7 | from unittest.mock import MagicMock, patch |
5 | 8 | from urllib.parse import parse_qs, parse_qsl, urlparse |
|
9 | 12 | from requests.exceptions import SSLError |
10 | 13 |
|
11 | 14 | import sentry.identity |
12 | | -from sentry.identity.oauth2 import OAuth2CallbackView, OAuth2LoginView |
| 15 | +from sentry.identity.oauth2 import OAuth2ApiStep, OAuth2CallbackView, OAuth2LoginView |
13 | 16 | from sentry.identity.pipeline import IdentityPipeline |
14 | 17 | from sentry.identity.providers.dummy import DummyProvider |
15 | 18 | from sentry.integrations.types import EventLifecycleOutcome |
| 19 | +from sentry.pipeline.types import PipelineStepAction |
16 | 20 | from sentry.shared_integrations.exceptions import ApiUnauthorized |
17 | 21 | from sentry.testutils.asserts import assert_failure_metric, assert_slo_metric |
18 | 22 | from sentry.testutils.silo import control_silo_test |
@@ -209,3 +213,171 @@ def test_customer_domains(self) -> None: |
209 | 213 | assert query["response_type"][0] == "code" |
210 | 214 | assert query["scope"][0] == "all-the-things" |
211 | 215 | assert "state" in query |
| 216 | + |
| 217 | + |
| 218 | +class _FakePipelineContext: |
| 219 | + """Minimal pipeline-like object for testing OAuth2ApiStep.""" |
| 220 | + |
| 221 | + def __init__(self, signature: str = "test-signature") -> None: |
| 222 | + self.signature = signature |
| 223 | + self._state: dict[str, Any] = {} |
| 224 | + |
| 225 | + def bind_state(self, key: str, value: Any) -> None: |
| 226 | + self._state[key] = value |
| 227 | + |
| 228 | + def fetch_state(self, key: str | None = None) -> Any: |
| 229 | + if key is None: |
| 230 | + return self._state |
| 231 | + return self._state.get(key) |
| 232 | + |
| 233 | + |
| 234 | +@control_silo_test |
| 235 | +class OAuth2ApiStepGetStepDataTest(TestCase): |
| 236 | + @cached_property |
| 237 | + def step(self) -> OAuth2ApiStep: |
| 238 | + return OAuth2ApiStep( |
| 239 | + authorize_url="https://example.org/oauth2/authorize", |
| 240 | + client_id="123456", |
| 241 | + client_secret="secret-value", |
| 242 | + access_token_url="https://example.org/oauth/token", |
| 243 | + scope="all-the-things", |
| 244 | + redirect_url="/extensions/default/setup/", |
| 245 | + ) |
| 246 | + |
| 247 | + def test_returns_oauth_url(self) -> None: |
| 248 | + ctx = _FakePipelineContext(signature="abc123") |
| 249 | + request = RequestFactory().get("/") |
| 250 | + data = self.step.get_step_data(ctx, request) |
| 251 | + |
| 252 | + assert "oauthUrl" in data |
| 253 | + url = urlparse(data["oauthUrl"]) |
| 254 | + assert url.scheme == "https" |
| 255 | + assert url.hostname == "example.org" |
| 256 | + assert url.path == "/oauth2/authorize" |
| 257 | + |
| 258 | + query = parse_qs(url.query) |
| 259 | + assert query["client_id"] == ["123456"] |
| 260 | + assert query["response_type"] == ["code"] |
| 261 | + assert query["scope"] == ["all-the-things"] |
| 262 | + assert query["state"] == ["abc123"] |
| 263 | + assert "redirect_uri" in query |
| 264 | + |
| 265 | + def test_serializer_requires_code_and_state(self) -> None: |
| 266 | + ser_cls = self.step.get_serializer_cls() |
| 267 | + assert ser_cls is not None |
| 268 | + |
| 269 | + ser = ser_cls(data={}) |
| 270 | + assert not ser.is_valid() |
| 271 | + assert "code" in ser.errors |
| 272 | + assert "state" in ser.errors |
| 273 | + |
| 274 | + ser = ser_cls(data={"code": "abc", "state": "xyz"}) |
| 275 | + assert ser.is_valid() |
| 276 | + |
| 277 | + |
| 278 | +@control_silo_test |
| 279 | +class OAuth2ApiStepHandlePostTest(TestCase): |
| 280 | + def setUp(self) -> None: |
| 281 | + super().setUp() |
| 282 | + self.request = RequestFactory().get("/") |
| 283 | + |
| 284 | + @cached_property |
| 285 | + def step(self) -> OAuth2ApiStep: |
| 286 | + return OAuth2ApiStep( |
| 287 | + authorize_url="https://example.org/oauth2/authorize", |
| 288 | + client_id="123456", |
| 289 | + client_secret="secret-value", |
| 290 | + access_token_url="https://example.org/oauth/token", |
| 291 | + scope="all-the-things", |
| 292 | + redirect_url="/extensions/default/setup/", |
| 293 | + ) |
| 294 | + |
| 295 | + @responses.activate |
| 296 | + def test_exchange_token_success(self) -> None: |
| 297 | + responses.add( |
| 298 | + responses.POST, |
| 299 | + "https://example.org/oauth/token", |
| 300 | + json={"access_token": "a-fake-token"}, |
| 301 | + ) |
| 302 | + ctx = _FakePipelineContext(signature="valid-state") |
| 303 | + result = self.step.handle_post( |
| 304 | + {"code": "auth-code", "state": "valid-state"}, ctx, self.request |
| 305 | + ) |
| 306 | + |
| 307 | + assert result.action == PipelineStepAction.ADVANCE |
| 308 | + assert ctx.fetch_state("data") == {"access_token": "a-fake-token"} |
| 309 | + |
| 310 | + assert len(responses.calls) == 1 |
| 311 | + data = dict(parse_qsl(responses.calls[0].request.body)) |
| 312 | + assert data["grant_type"] == "authorization_code" |
| 313 | + assert data["code"] == "auth-code" |
| 314 | + assert data["client_id"] == "123456" |
| 315 | + assert data["client_secret"] == "secret-value" |
| 316 | + |
| 317 | + def test_invalid_state(self) -> None: |
| 318 | + ctx = _FakePipelineContext(signature="correct-state") |
| 319 | + result = self.step.handle_post( |
| 320 | + {"code": "auth-code", "state": "wrong-state"}, ctx, self.request |
| 321 | + ) |
| 322 | + |
| 323 | + assert result.action == PipelineStepAction.ERROR |
| 324 | + assert "detail" in result.data |
| 325 | + |
| 326 | + @responses.activate |
| 327 | + def test_ssl_error(self) -> None: |
| 328 | + def ssl_error(request): |
| 329 | + raise SSLError("Could not build connection") |
| 330 | + |
| 331 | + responses.add_callback( |
| 332 | + responses.POST, "https://example.org/oauth/token", callback=ssl_error |
| 333 | + ) |
| 334 | + ctx = _FakePipelineContext(signature="valid-state") |
| 335 | + result = self.step.handle_post( |
| 336 | + {"code": "auth-code", "state": "valid-state"}, ctx, self.request |
| 337 | + ) |
| 338 | + |
| 339 | + assert result.action == PipelineStepAction.ERROR |
| 340 | + assert "SSL" in result.data["detail"] |
| 341 | + |
| 342 | + @responses.activate |
| 343 | + def test_connection_error(self) -> None: |
| 344 | + def connection_error(request): |
| 345 | + raise ConnectionError("Name or service not known") |
| 346 | + |
| 347 | + responses.add_callback( |
| 348 | + responses.POST, "https://example.org/oauth/token", callback=connection_error |
| 349 | + ) |
| 350 | + ctx = _FakePipelineContext(signature="valid-state") |
| 351 | + result = self.step.handle_post( |
| 352 | + {"code": "auth-code", "state": "valid-state"}, ctx, self.request |
| 353 | + ) |
| 354 | + |
| 355 | + assert result.action == PipelineStepAction.ERROR |
| 356 | + assert "connect" in result.data["detail"].lower() |
| 357 | + |
| 358 | + @responses.activate |
| 359 | + def test_empty_response_body(self) -> None: |
| 360 | + responses.add(responses.POST, "https://example.org/oauth/token", body="") |
| 361 | + ctx = _FakePipelineContext(signature="valid-state") |
| 362 | + result = self.step.handle_post( |
| 363 | + {"code": "auth-code", "state": "valid-state"}, ctx, self.request |
| 364 | + ) |
| 365 | + |
| 366 | + assert result.action == PipelineStepAction.ERROR |
| 367 | + assert "json" in result.data["detail"].lower() |
| 368 | + |
| 369 | + @responses.activate |
| 370 | + def test_api_error_401(self) -> None: |
| 371 | + responses.add( |
| 372 | + responses.POST, |
| 373 | + "https://example.org/oauth/token", |
| 374 | + json={"error": "unauthorized"}, |
| 375 | + status=401, |
| 376 | + ) |
| 377 | + ctx = _FakePipelineContext(signature="valid-state") |
| 378 | + result = self.step.handle_post( |
| 379 | + {"code": "auth-code", "state": "valid-state"}, ctx, self.request |
| 380 | + ) |
| 381 | + |
| 382 | + assert result.action == PipelineStepAction.ERROR |
| 383 | + assert "401" in result.data["detail"] |
0 commit comments