Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/sentry/integrations/discord/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from django.http.request import HttpRequest
from django.http.response import HttpResponseBase
from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField

from sentry import options
from sentry.api.serializers.rest_framework.base import CamelSnakeSerializer
from sentry.constants import ObjectStatus
from sentry.integrations.base import (
FeatureDescription,
Expand All @@ -32,7 +34,8 @@
)
from sentry.notifications.platform.target import IntegrationNotificationTarget
from sentry.organizations.services.organization.model import RpcOrganization
from sentry.pipeline.views.base import PipelineView
from sentry.pipeline.types import PipelineStepResult
from sentry.pipeline.views.base import ApiPipelineSteps, PipelineView
from sentry.shared_integrations.exceptions import ApiError, IntegrationError
from sentry.utils.http import absolute_uri

Expand Down Expand Up @@ -141,6 +144,64 @@ def uninstall(self) -> None:
return


class DiscordOAuthApiSerializer(CamelSnakeSerializer):
code = CharField(required=True)
state = CharField(required=True)
guild_id = CharField(required=True)


class DiscordOAuthApiStep:
"""API-mode OAuth step for Discord integration setup.

Discord's OAuth flow is unique: the authorize URL includes bot permissions,
and the callback returns a guild_id alongside the authorization code.
This step handles both, binding guild_id and code to pipeline state.
"""

step_name = "oauth_login"

def __init__(
self,
client_id: str,
permissions: int,
scopes: frozenset[str],
redirect_url: str,
) -> None:
self.client_id = client_id
self.permissions = permissions
self.scopes = scopes
self.redirect_url = redirect_url

def get_step_data(self, pipeline: IntegrationPipeline, request: HttpRequest) -> dict[str, str]:
params = urlencode(
{
"client_id": self.client_id,
"permissions": self.permissions,
"scope": " ".join(self.scopes),
"response_type": "code",
"state": pipeline.signature,
"redirect_uri": self.redirect_url,
}
)
return {"oauthUrl": f"https://discord.com/api/oauth2/authorize?{params}"}

def get_serializer_cls(self) -> type:
return DiscordOAuthApiSerializer

def handle_post(
self,
validated_data: dict[str, str],
pipeline: IntegrationPipeline,
request: HttpRequest,
) -> PipelineStepResult:
if validated_data["state"] != pipeline.signature:
return PipelineStepResult.error("An error occurred while validating your request.")

pipeline.bind_state("guild_id", validated_data["guild_id"])
pipeline.bind_state("code", validated_data["code"])
return PipelineStepResult.advance()


class DiscordIntegrationProvider(IntegrationProvider):
key = IntegrationProviderSlug.DISCORD.value
name = "Discord"
Expand Down Expand Up @@ -176,6 +237,16 @@ def __init__(self) -> None:
def get_pipeline_views(self) -> Sequence[PipelineView[IntegrationPipeline]]:
return [DiscordInstallPipeline(self.get_params_for_oauth())]

def get_pipeline_api_steps(self) -> ApiPipelineSteps[IntegrationPipeline]:
return [
DiscordOAuthApiStep(
client_id=self.application_id,
permissions=self.bot_permissions,
scopes=self.oauth_scopes,
redirect_url=self.setup_url,
),
]

def build_integration(self, state: Mapping[str, Any]) -> IntegrationData:
guild_id = str(state.get("guild_id"))

Expand Down
133 changes: 132 additions & 1 deletion tests/sentry/integrations/discord/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from __future__ import annotations

from typing import Any
from unittest import mock
from urllib.parse import parse_qs, urlencode, urlparse

import pytest
import responses
from django.urls import reverse
from responses.matchers import header_matcher, json_params_matcher

from sentry import audit_log, options
Expand All @@ -18,6 +22,8 @@
DiscordIntegrationProvider,
)
from sentry.integrations.models.integration import Integration
from sentry.integrations.models.organization_integration import OrganizationIntegration
from sentry.integrations.pipeline import IntegrationPipeline
from sentry.models.auditlogentry import AuditLogEntry
from sentry.notifications.platform.discord.provider import DiscordRenderable
from sentry.notifications.platform.target import IntegrationNotificationTarget
Expand All @@ -30,7 +36,7 @@
IntegrationConfigurationError,
IntegrationError,
)
from sentry.testutils.cases import IntegrationTestCase, TestCase
from sentry.testutils.cases import APITestCase, IntegrationTestCase, TestCase
from sentry.testutils.silo import control_silo_test
from sentry.utils import json

Expand Down Expand Up @@ -552,3 +558,128 @@ def test_send_notification_api_error(self, mock_send: mock.MagicMock) -> None:
self.installation.send_notification(target=self.target, payload=payload)

assert str(e.value) == error_payload


@control_silo_test
class DiscordApiPipelineTest(APITestCase):
endpoint = "sentry-api-0-organization-pipeline"
method = "post"

guild_id = "1234567890"
guild_name = "Cool server"

def setUp(self) -> None:
super().setUp()
self.login_as(self.user)
self.application_id = "application-id"
self.public_key = "public-key"
self.bot_token = "bot-token"
self.client_secret = "client-secret"
options.set("discord.application-id", self.application_id)
options.set("discord.public-key", self.public_key)
options.set("discord.bot-token", self.bot_token)
options.set("discord.client-secret", self.client_secret)

def tearDown(self) -> None:
responses.reset()
super().tearDown()

def _get_pipeline_url(self) -> str:
return reverse(
self.endpoint,
args=[self.organization.slug, IntegrationPipeline.pipeline_name],
)

def _initialize_pipeline(self) -> Any:
return self.client.post(
self._get_pipeline_url(),
data={"action": "initialize", "provider": "discord"},
format="json",
)

def _advance_step(self, data: dict[str, Any]) -> Any:
return self.client.post(self._get_pipeline_url(), data=data, format="json")

def _get_pipeline_signature(self, resp: Any) -> str:
return resp.data["data"]["oauthUrl"].split("state=")[1].split("&")[0]

@responses.activate
def test_initialize_pipeline(self) -> None:
resp = self._initialize_pipeline()
assert resp.status_code == 200
assert resp.data["step"] == "oauth_login"
assert resp.data["stepIndex"] == 0
assert resp.data["totalSteps"] == 1
assert resp.data["provider"] == "discord"
oauth_url = resp.data["data"]["oauthUrl"]
assert "discord.com/api/oauth2/authorize" in oauth_url
assert "permissions=" in oauth_url

parsed = urlparse(oauth_url)
params = parse_qs(parsed.query)
assert params["client_id"] == [self.application_id]
assert params["permissions"] == [str(DiscordIntegrationProvider.bot_permissions)]
requested_scopes = set(params["scope"][0].split(" "))
assert requested_scopes == DiscordIntegrationProvider.oauth_scopes

@responses.activate
def test_oauth_step_missing_guild_id(self) -> None:
resp = self._initialize_pipeline()
pipeline_signature = self._get_pipeline_signature(resp)
resp = self._advance_step({"code": "auth-code", "state": pipeline_signature})
assert resp.status_code == 400

@responses.activate
@mock.patch("sentry.integrations.discord.client.DiscordClient.set_application_command")
def test_full_pipeline_flow(self, mock_set_application_command: mock.MagicMock) -> None:
responses.add(
responses.GET,
url=f"{DiscordClient.base_url}{GUILD_URL.format(guild_id=self.guild_id)}",
match=[header_matcher({"Authorization": f"Bot {self.bot_token}"})],
json={"id": self.guild_id, "name": self.guild_name},
)
responses.add(
responses.GET,
url=f"{DiscordClient.base_url}{APPLICATION_COMMANDS_URL.format(application_id=self.application_id)}",
match=[header_matcher({"Authorization": f"Bot {self.bot_token}"})],
json=COMMANDS,
)
responses.add(
responses.POST,
url=f"{DISCORD_BASE_URL}/oauth2/token",
json={"access_token": "access_token"},
)
responses.add(
responses.GET,
url=f"{DiscordClient.base_url}/users/@me",
json={"id": "user_1234"},
)
responses.add(
responses.GET,
url=f"{DiscordClient.base_url}/users/@me/guilds/{self.guild_id}/member",
json={},
)

resp = self._initialize_pipeline()
assert resp.data["step"] == "oauth_login"
pipeline_signature = self._get_pipeline_signature(resp)

resp = self._advance_step(
{
"code": "discord-auth-code",
"state": pipeline_signature,
"guildId": self.guild_id,
}
)
assert resp.status_code == 200
assert resp.data["status"] == "complete"
assert "data" in resp.data

integration = Integration.objects.get(provider="discord")
assert integration.external_id == self.guild_id
assert integration.name == self.guild_name

assert OrganizationIntegration.objects.filter(
organization_id=self.organization.id,
integration=integration,
).exists()
Loading