diff --git a/src/sentry/seer/signed_seer_api.py b/src/sentry/seer/signed_seer_api.py index f69c3bb11a0d61..a88218af5efae3 100644 --- a/src/sentry/seer/signed_seer_api.py +++ b/src/sentry/seer/signed_seer_api.py @@ -386,6 +386,22 @@ class SupergroupsGetByGroupIdsRequest(TypedDict): group_ids: list[int] +class SupergroupDetailData(TypedDict): + id: int + title: str + summary: str + error_type: str + code_area: str + group_ids: list[int] + project_ids: list[int] + created_at: str + updated_at: str + + +class SupergroupsByGroupIdsResponse(TypedDict): + data: list[SupergroupDetailData] + + class ServiceMapUpdateRequest(TypedDict): organization_id: int nodes: list[dict[str, Any]] diff --git a/src/sentry/seer/supergroups/endpoints/organization_supergroups_by_group.py b/src/sentry/seer/supergroups/endpoints/organization_supergroups_by_group.py index c3cadbddde864c..fc2c50de7a59be 100644 --- a/src/sentry/seer/supergroups/endpoints/organization_supergroups_by_group.py +++ b/src/sentry/seer/supergroups/endpoints/organization_supergroups_by_group.py @@ -16,6 +16,7 @@ from sentry.models.organization import Organization from sentry.seer.signed_seer_api import ( SeerViewerContext, + SupergroupsByGroupIdsResponse, make_supergroups_get_by_group_ids_request, ) @@ -55,21 +56,19 @@ def get(self, request: Request, organization: Organization) -> Response: status=status_codes.HTTP_400_BAD_REQUEST, ) - group_qs = Group.objects.filter( - id__in=group_ids, - project__organization=organization, - ) - status_param = request.GET.get("status") - if status_param is not None: - if status_param not in STATUS_QUERY_CHOICES: - return Response( - {"detail": "Invalid status parameter"}, - status=status_codes.HTTP_400_BAD_REQUEST, - ) - group_qs = group_qs.filter(status=STATUS_QUERY_CHOICES[status_param]) - - valid_group_ids = set(group_qs.values_list("id", flat=True)) + if status_param is not None and status_param not in STATUS_QUERY_CHOICES: + return Response( + {"detail": "Invalid status parameter"}, + status=status_codes.HTTP_400_BAD_REQUEST, + ) + + valid_group_ids = set( + Group.objects.filter( + id__in=group_ids, + project__organization=organization, + ).values_list("id", flat=True) + ) group_ids = [gid for gid in group_ids if gid in valid_group_ids] if not group_ids: @@ -90,4 +89,31 @@ def get(self, request: Request, organization: Organization) -> Response: status=response.status, ) - return Response(orjson.loads(response.data)) + data: SupergroupsByGroupIdsResponse = orjson.loads(response.data) + + if not status_param: + return Response(data) + + # Seer returns all group_ids per supergroup regardless of status. + # We can't filter before the Seer call because Seer expands group_ids + # to include the full supergroup membership, not just the requested IDs. + # Instead, collect every group_id from the response, check status in + # bulk, and strip out non-matching ones. + all_response_group_ids: list[int] = [] + for sg in data["data"]: + all_response_group_ids.extend(sg["group_ids"]) + + matching_ids = set( + Group.objects.filter( + id__in=all_response_group_ids, + project__organization=organization, + status=STATUS_QUERY_CHOICES[status_param], + ).values_list("id", flat=True) + ) + + for sg in data["data"]: + sg["group_ids"] = [gid for gid in sg["group_ids"] if gid in matching_ids] + # Drop supergroups that have no matching groups after filtering + data["data"] = [sg for sg in data["data"] if sg["group_ids"]] + + return Response(data) diff --git a/tests/sentry/seer/supergroups/endpoints/test_organization_supergroups_by_group.py b/tests/sentry/seer/supergroups/endpoints/test_organization_supergroups_by_group.py index 7f239d45207ef3..4419d824602847 100644 --- a/tests/sentry/seer/supergroups/endpoints/test_organization_supergroups_by_group.py +++ b/tests/sentry/seer/supergroups/endpoints/test_organization_supergroups_by_group.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Any from unittest.mock import MagicMock, patch import orjson @@ -8,7 +9,7 @@ from sentry.testutils.cases import APITestCase -def mock_seer_response(data): +def mock_seer_response(data: dict[str, Any]) -> MagicMock: response = MagicMock() response.status = 200 response.data = orjson.dumps(data) @@ -29,18 +30,41 @@ def setUp(self): @patch( "sentry.seer.supergroups.endpoints.organization_supergroups_by_group.make_supergroups_get_by_group_ids_request" ) - def test_status_filter(self, mock_seer): - mock_seer.return_value = mock_seer_response({"supergroups": []}) + def test_status_filter_strips_resolved_from_response(self, mock_seer): + extra_unresolved = self.create_group(project=self.project, status=GroupStatus.UNRESOLVED) + mock_seer.return_value = mock_seer_response( + { + "data": [ + { + "id": 1, + "group_ids": [ + self.unresolved_group.id, + self.resolved_group.id, + extra_unresolved.id, + ], + "title": "kept", + }, + { + "id": 2, + "group_ids": [self.resolved_group.id], + "title": "dropped", + }, + ] + } + ) with self.feature("organizations:top-issues-ui"): - self.get_success_response( + response = self.get_success_response( self.organization.slug, group_id=[self.unresolved_group.id, self.resolved_group.id], status="unresolved", ) - body = mock_seer.call_args[0][0] - assert body["group_ids"] == [self.unresolved_group.id] + assert len(response.data["data"]) == 1 + assert response.data["data"][0]["group_ids"] == [ + self.unresolved_group.id, + extra_unresolved.id, + ] def test_status_filter_invalid(self): with self.feature("organizations:top-issues-ui"): @@ -50,12 +74,3 @@ def test_status_filter_invalid(self): status="bogus", status_code=400, ) - - def test_status_filter_all_filtered_out(self): - with self.feature("organizations:top-issues-ui"): - self.get_error_response( - self.organization.slug, - group_id=[self.resolved_group.id], - status="unresolved", - status_code=404, - )