diff --git a/src/sentry/features/temporary.py b/src/sentry/features/temporary.py index 99e3eb8baea852..09ccef03d03778 100644 --- a/src/sentry/features/temporary.py +++ b/src/sentry/features/temporary.py @@ -160,6 +160,7 @@ def register_temporary_features(manager: FeatureManager) -> None: manager.add("organizations:integration-api-pipeline-gitlab", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) manager.add("organizations:integration-api-pipeline-slack", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) manager.add("organizations:integration-api-pipeline-bitbucket", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) + manager.add("organizations:integration-api-pipeline-aws-lambda", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) # Project Management Integrations Feature Parity Flags manager.add("organizations:integrations-github_enterprise-project-management", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) manager.add("organizations:integrations-gitlab-project-management", OrganizationFeature, FeatureHandlerStrategy.FLAGPOLE, api_expose=True) diff --git a/src/sentry/integrations/aws_lambda/integration.py b/src/sentry/integrations/aws_lambda/integration.py index 6530a7f13b495f..7681be867a32fc 100644 --- a/src/sentry/integrations/aws_lambda/integration.py +++ b/src/sentry/integrations/aws_lambda/integration.py @@ -8,9 +8,12 @@ from django.http.request import HttpRequest from django.http.response import HttpResponseBase from django.utils.translation import gettext_lazy as _ +from rest_framework import serializers +from rest_framework.fields import CharField, ChoiceField, IntegerField, ListField from sentry import analytics, options from sentry.analytics.events.integration_serverless_setup import IntegrationServerlessSetup +from sentry.api.serializers.rest_framework.base import CamelSnakeSerializer from sentry.integrations.base import ( FeatureDescription, IntegrationData, @@ -25,7 +28,8 @@ from sentry.integrations.pipeline import IntegrationPipeline from sentry.organizations.services.organization import organization_service from sentry.organizations.services.organization.model import RpcOrganization -from sentry.pipeline.views.base import PipelineView, render_react_view +from sentry.pipeline.types import PipelineStepResult +from sentry.pipeline.views.base import ApiPipelineSteps, PipelineView, render_react_view from sentry.projects.services.project import project_service from sentry.silo.base import control_silo_function from sentry.users.models.user import User @@ -198,6 +202,209 @@ def update_function_to_latest_version(self, target): return self.get_serialized_lambda_function(target) +class ProjectSelectSerializer(CamelSnakeSerializer): + project_id = IntegerField(required=True) + + +class ProjectSelectApiStep: + step_name = "project_select" + + def get_step_data(self, pipeline: IntegrationPipeline, request: HttpRequest) -> dict[str, Any]: + return {} + + def get_serializer_cls(self) -> type: + return ProjectSelectSerializer + + def handle_post( + self, + validated_data: dict[str, Any], + pipeline: IntegrationPipeline, + request: HttpRequest, + ) -> PipelineStepResult: + project_id = validated_data["project_id"] + + assert pipeline.organization is not None + valid_project_ids = {p.id for p in pipeline.organization.projects} + if project_id not in valid_project_ids: + return PipelineStepResult.error("Invalid project") + + pipeline.bind_state("project_id", project_id) + return PipelineStepResult.advance() + + +class CloudFormationSerializer(CamelSnakeSerializer): + account_number = CharField(required=True) + region = ChoiceField(choices=[(r, r) for r in ALL_AWS_REGIONS], required=True) + aws_external_id = CharField(required=True) + + def validate_account_number(self, value: str) -> str: + if not value.isdigit() or len(value) != 12: + raise serializers.ValidationError("Must be a 12-digit AWS account number") + return value + + +class CloudFormationApiStep: + step_name = "cloudformation" + + def get_step_data(self, pipeline: IntegrationPipeline, request: HttpRequest) -> dict[str, Any]: + template_url = options.get("aws-lambda.cloudformation-url") + return { + "baseCloudformationUrl": "https://console.aws.amazon.com/cloudformation/home#/stacks/create/review", + "templateUrl": template_url, + "stackName": "Sentry-Monitoring-Stack", + "regionList": ALL_AWS_REGIONS, + } + + def get_serializer_cls(self) -> type: + return CloudFormationSerializer + + def handle_post( + self, + validated_data: dict[str, Any], + pipeline: IntegrationPipeline, + request: HttpRequest, + ) -> PipelineStepResult: + account_number = validated_data["account_number"] + region = validated_data["region"] + aws_external_id = validated_data["aws_external_id"] + + pipeline.bind_state("account_number", account_number) + pipeline.bind_state("region", region) + pipeline.bind_state("aws_external_id", aws_external_id) + + try: + gen_aws_client(account_number, region, aws_external_id) + except ClientError: + return PipelineStepResult.error( + "Please validate the Cloudformation stack was created successfully" + ) + except ConfigurationError: + raise + except Exception as e: + logger.warning( + "CloudFormationApiStep.unexpected_error", + extra={"error": str(e)}, + ) + return PipelineStepResult.error("Unknown error") + + return PipelineStepResult.advance() + + +class FunctionSelectSerializer(CamelSnakeSerializer): + enabled_functions = ListField(child=CharField(), required=True) + + +class InstrumentationApiStep: + step_name = "instrumentation" + + def get_step_data(self, pipeline: IntegrationPipeline, request: HttpRequest) -> dict[str, Any]: + account_number = pipeline.fetch_state("account_number") + region = pipeline.fetch_state("region") + aws_external_id = pipeline.fetch_state("aws_external_id") + + lambda_client = gen_aws_client(account_number, region, aws_external_id) + lambda_functions = get_supported_functions(lambda_client) + lambda_functions.sort(key=lambda x: x["FunctionName"].lower()) + + return { + "functions": [ + { + "name": fn["FunctionName"], + "runtime": fn["Runtime"], + "description": fn.get("Description", ""), + } + for fn in lambda_functions + ] + } + + def get_serializer_cls(self) -> type: + return FunctionSelectSerializer + + def handle_post( + self, + validated_data: dict[str, Any], + pipeline: IntegrationPipeline, + request: HttpRequest, + ) -> PipelineStepResult: + assert pipeline.organization is not None + organization = pipeline.organization + + account_number = pipeline.fetch_state("account_number") + region = pipeline.fetch_state("region") + project_id = pipeline.fetch_state("project_id") + aws_external_id = pipeline.fetch_state("aws_external_id") + + enabled_functions = validated_data["enabled_functions"] + enabled_lambdas = {name: True for name in enabled_functions} + + sentry_project_dsn = get_dsn_for_project(organization.id, project_id) + + lambda_client = gen_aws_client(account_number, region, aws_external_id) + lambda_functions = get_supported_functions(lambda_client) + lambda_functions.sort(key=lambda x: x["FunctionName"].lower()) + + lambda_functions = [ + fn for fn in lambda_functions if enabled_lambdas.get(fn["FunctionName"]) + ] + + def _enable_lambda(function): + try: + enable_single_lambda(lambda_client, function, sentry_project_dsn) + return (True, function, None) + except Exception as e: + return (False, function, e) + + failures: list[dict[str, Any]] = [] + success_count = 0 + + with ContextPropagatingThreadPoolExecutor( + max_workers=options.get("aws-lambda.thread-count") + ) as _lambda_setup_thread_pool: + for success, function, e in _lambda_setup_thread_pool.map( + _enable_lambda, lambda_functions + ): + name = function["FunctionName"] + if success: + success_count += 1 + else: + err_message: str | _StrPromise = str(e) + is_custom_err, err_message = get_sentry_err_message(err_message) + if not is_custom_err: + capture_exception(e) + err_message = _("Unknown Error") + failures.append({"name": name, "error": str(err_message)}) + logger.info( + "update_function_configuration.error", + extra={ + "organization_id": organization.id, + "lambda_name": name, + "account_number": account_number, + "region": region, + "error": str(e), + }, + ) + + analytics.record( + IntegrationServerlessSetup( + user_id=request.user.id, + organization_id=organization.id, + integration="aws_lambda", + success_count=success_count, + failure_count=len(failures), + ) + ) + + if failures: + return PipelineStepResult.stay( + data={ + "failures": failures, + "successCount": success_count, + } + ) + + return PipelineStepResult.advance() + + class AwsLambdaIntegrationProvider(IntegrationProvider): key = "aws_lambda" name = "AWS Lambda" @@ -213,6 +420,13 @@ def get_pipeline_views(self) -> list[PipelineView[IntegrationPipeline]]: AwsLambdaSetupLayerPipelineView(), ] + def get_pipeline_api_steps(self) -> ApiPipelineSteps[IntegrationPipeline]: + return [ + ProjectSelectApiStep(), + CloudFormationApiStep(), + InstrumentationApiStep(), + ] + @control_silo_function def build_integration(self, state: Mapping[str, Any]) -> IntegrationData: region = state["region"] diff --git a/tests/sentry/integrations/aws_lambda/test_integration.py b/tests/sentry/integrations/aws_lambda/test_integration.py index 5bc745a87b8688..27068acc21b5e0 100644 --- a/tests/sentry/integrations/aws_lambda/test_integration.py +++ b/tests/sentry/integrations/aws_lambda/test_integration.py @@ -1,19 +1,22 @@ +from typing import Any from unittest.mock import ANY, MagicMock, patch from urllib.parse import urlencode from botocore.exceptions import ClientError from django.http import HttpResponse +from django.urls import reverse from sentry.integrations.aws_lambda import AwsLambdaIntegrationProvider from sentry.integrations.aws_lambda import integration as aws_lambda_integration from sentry.integrations.aws_lambda.utils import ALL_AWS_REGIONS from sentry.integrations.models.integration import Integration from sentry.integrations.models.organization_integration import OrganizationIntegration +from sentry.integrations.pipeline import IntegrationPipeline from sentry.models.projectkey import ProjectKey from sentry.organizations.services.organization import organization_service from sentry.projects.services.project import project_service from sentry.silo.base import SiloMode -from sentry.testutils.cases import IntegrationTestCase +from sentry.testutils.cases import APITestCase, IntegrationTestCase from sentry.testutils.helpers.options import override_options from sentry.testutils.silo import assume_test_silo_mode, control_silo_test from sentry.users.services.user.serial import serialize_rpc_user @@ -25,6 +28,7 @@ account_number = "599817902985" region = "us-east-2" +aws_external_id = "test-external-id-1234" @control_silo_test @@ -530,3 +534,274 @@ class MockException(Exception): mock_react_view.assert_called_with( ANY, "awsLambdaFailureDetails", {"lambdaFunctionFailures": failures, "successCount": 0} ) + + +@control_silo_test +class AwsLambdaApiPipelineTest(APITestCase): + endpoint = "sentry-api-0-organization-pipeline" + method = "post" + + def setUp(self) -> None: + super().setUp() + self.login_as(self.user) + self.projectA = self.create_project(organization=self.organization, slug="projA") + self.projectB = self.create_project(organization=self.organization, slug="projB") + + def _get_pipeline_url(self) -> str: + return reverse( + self.endpoint, + args=[self.organization.slug, IntegrationPipeline.pipeline_name], + ) + + def _initialize_pipeline(self) -> Any: + return self.client.post( + self._get_pipeline_url(), + data={"action": "initialize", "provider": "aws_lambda"}, + format="json", + ) + + def _advance_step(self, data: dict[str, Any]) -> Any: + return self.client.post(self._get_pipeline_url(), data=data, format="json") + + def _get_step_info(self) -> Any: + return self.client.get(self._get_pipeline_url()) + + def test_initialize_pipeline(self) -> None: + resp = self._initialize_pipeline() + assert resp.status_code == 200 + assert resp.data["step"] == "project_select" + assert resp.data["stepIndex"] == 0 + assert resp.data["totalSteps"] == 3 + assert resp.data["provider"] == "aws_lambda" + + def test_project_select_step_data(self) -> None: + self._initialize_pipeline() + resp = self._get_step_info() + assert resp.status_code == 200 + assert resp.data["step"] == "project_select" + assert resp.data["data"] == {} + + def test_project_select_advance(self) -> None: + self._initialize_pipeline() + resp = self._advance_step({"projectId": self.projectA.id}) + assert resp.status_code == 200 + assert resp.data["status"] == "advance" + assert resp.data["step"] == "cloudformation" + + def test_project_select_invalid_project(self) -> None: + self._initialize_pipeline() + resp = self._advance_step({"projectId": 99999}) + assert resp.status_code == 400 + assert resp.data["status"] == "error" + + def test_project_select_missing_project_id(self) -> None: + self._initialize_pipeline() + resp = self._advance_step({}) + assert resp.status_code == 400 + + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_cloudformation_step_data(self, mock_gen_aws_client: MagicMock) -> None: + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + resp = self._get_step_info() + assert resp.status_code == 200 + assert resp.data["step"] == "cloudformation" + data = resp.data["data"] + assert "templateUrl" in data + assert "regionList" in data + assert data["stackName"] == "Sentry-Monitoring-Stack" + + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_cloudformation_advance(self, mock_gen_aws_client: MagicMock) -> None: + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + resp = self._advance_step( + { + "accountNumber": account_number, + "region": region, + "awsExternalId": aws_external_id, + } + ) + assert resp.status_code == 200 + assert resp.data["status"] == "advance" + assert resp.data["step"] == "instrumentation" + + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_cloudformation_invalid_region(self, mock_gen_aws_client: MagicMock) -> None: + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + resp = self._advance_step( + { + "accountNumber": account_number, + "region": "invalid-region", + "awsExternalId": aws_external_id, + } + ) + assert resp.status_code == 400 + assert "region" in resp.data + + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_cloudformation_invalid_account_number(self, mock_gen_aws_client: MagicMock) -> None: + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + resp = self._advance_step( + { + "accountNumber": "bad", + "region": region, + "awsExternalId": aws_external_id, + } + ) + assert resp.status_code == 400 + assert "accountNumber" in resp.data + + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_cloudformation_client_error(self, mock_gen_aws_client: MagicMock) -> None: + mock_gen_aws_client.side_effect = ClientError({"Error": {}}, "assume_role") + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + resp = self._advance_step( + { + "accountNumber": account_number, + "region": region, + "awsExternalId": aws_external_id, + } + ) + assert resp.status_code == 400 + assert resp.data["status"] == "error" + assert "Cloudformation" in resp.data["data"]["detail"] + + @patch("sentry.integrations.aws_lambda.integration.get_supported_functions") + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_instrumentation_step_data( + self, + mock_gen_aws_client: MagicMock, + mock_get_supported_functions: MagicMock, + ) -> None: + mock_get_supported_functions.return_value = [ + {"FunctionName": "lambdaB", "Runtime": "nodejs12.x", "Description": "B func"}, + {"FunctionName": "lambdaA", "Runtime": "python3.9", "Description": "A func"}, + ] + + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + self._advance_step( + { + "accountNumber": account_number, + "region": region, + "awsExternalId": aws_external_id, + } + ) + resp = self._get_step_info() + assert resp.status_code == 200 + assert resp.data["step"] == "instrumentation" + functions = resp.data["data"]["functions"] + assert len(functions) == 2 + assert functions[0]["name"] == "lambdaA" + assert functions[1]["name"] == "lambdaB" + + @patch("sentry.integrations.aws_lambda.integration.get_supported_functions") + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_full_api_pipeline_success( + self, + mock_gen_aws_client: MagicMock, + mock_get_supported_functions: MagicMock, + ) -> None: + mock_client = mock_gen_aws_client.return_value + mock_client.update_function_configuration = MagicMock() + mock_client.describe_account = MagicMock(return_value={"Account": {"Name": "my_name"}}) + + mock_get_supported_functions.return_value = [ + { + "FunctionName": "lambdaA", + "Runtime": "nodejs12.x", + "FunctionArn": f"arn:aws:lambda:{region}:{account_number}:function:lambdaA", + }, + ] + + with assume_test_silo_mode(SiloMode.CELL): + sentry_project_dsn = ProjectKey.get_default(project=self.projectA).get_dsn(public=True) + + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + self._advance_step( + { + "accountNumber": account_number, + "region": region, + "awsExternalId": aws_external_id, + } + ) + resp = self._advance_step({"enabledFunctions": ["lambdaA"]}) + assert resp.status_code == 200 + assert resp.data["status"] == "complete" + + mock_client.update_function_configuration.assert_called_once() + call_kwargs = mock_client.update_function_configuration.call_args[1] + assert call_kwargs["FunctionName"] == "lambdaA" + assert call_kwargs["Environment"]["Variables"]["SENTRY_DSN"] == sentry_project_dsn + + integration = Integration.objects.get(provider="aws_lambda") + assert integration.name == f"my_name {region}" + assert integration.external_id == f"{account_number}-{region}" + assert integration.metadata["account_number"] == account_number + assert integration.metadata["region"] == region + assert "aws_external_id" in integration.metadata + assert OrganizationIntegration.objects.filter( + integration=integration, organization_id=self.organization.id + ).exists() + + @patch("sentry.integrations.aws_lambda.integration.get_supported_functions") + @patch("sentry.integrations.aws_lambda.integration.gen_aws_client") + def test_instrumentation_with_failures( + self, + mock_gen_aws_client: MagicMock, + mock_get_supported_functions: MagicMock, + ) -> None: + class MockException(Exception): + pass + + bad_layer = "arn:aws:lambda:us-east-2:546545:layer:another-layer:5" + mock_client = mock_gen_aws_client.return_value + mock_client.update_function_configuration = MagicMock( + side_effect=Exception(f"Layer version {bad_layer} does not exist") + ) + mock_client.describe_account = MagicMock(return_value={"Account": {"Name": "my_name"}}) + mock_client.exceptions = MagicMock() + mock_client.exceptions.ResourceConflictException = MockException + + mock_get_supported_functions.return_value = [ + { + "FunctionName": "lambdaA", + "Runtime": "nodejs12.x", + "FunctionArn": f"arn:aws:lambda:{region}:{account_number}:function:lambdaA", + }, + ] + + self._initialize_pipeline() + self._advance_step({"projectId": self.projectA.id}) + self._advance_step( + { + "accountNumber": account_number, + "region": region, + "awsExternalId": aws_external_id, + } + ) + + resp = self._advance_step({"enabledFunctions": ["lambdaA"]}) + assert resp.status_code == 200 + assert resp.data["status"] == "stay" + assert resp.data["data"]["successCount"] == 0 + assert len(resp.data["data"]["failures"]) == 1 + assert resp.data["data"]["failures"][0]["name"] == "lambdaA" + assert "another-layer" in resp.data["data"]["failures"][0]["error"] + + # User retries (or deselects failed functions), pipeline finishes + mock_client.update_function_configuration = MagicMock() + resp = self._advance_step({"enabledFunctions": ["lambdaA"]}) + assert resp.status_code == 200 + assert resp.data["status"] == "complete" + + integration = Integration.objects.get(provider="aws_lambda") + assert integration.external_id == f"{account_number}-{region}" + + integration = Integration.objects.get(provider="aws_lambda") + assert integration.external_id == f"{account_number}-{region}"