diff --git a/src/sentry/tasks/llm_issue_detection/__init__.py b/src/sentry/tasks/llm_issue_detection/__init__.py index d9e1c162da20d4..5f28bf5b5cf894 100644 --- a/src/sentry/tasks/llm_issue_detection/__init__.py +++ b/src/sentry/tasks/llm_issue_detection/__init__.py @@ -1,13 +1,13 @@ from sentry.tasks.llm_issue_detection.detection import ( DetectedIssue, create_issue_occurrence_from_detection, - detect_llm_issues_for_project, + detect_llm_issues_for_org, run_llm_issue_detection, ) __all__ = [ "DetectedIssue", "create_issue_occurrence_from_detection", - "detect_llm_issues_for_project", + "detect_llm_issues_for_org", "run_llm_issue_detection", ] diff --git a/src/sentry/tasks/llm_issue_detection/detection.py b/src/sentry/tasks/llm_issue_detection/detection.py index 357ad48341214d..18b637b863f88e 100644 --- a/src/sentry/tasks/llm_issue_detection/detection.py +++ b/src/sentry/tasks/llm_issue_detection/detection.py @@ -13,7 +13,7 @@ from urllib3 import BaseHTTPResponse from sentry import features, options -from sentry.constants import VALID_PLATFORMS +from sentry.constants import VALID_PLATFORMS, ObjectStatus from sentry.issues.grouptype import ( AIDetectedCodeHealthGroupType, AIDetectedDBGroupType, @@ -25,12 +25,15 @@ ) from sentry.issues.issue_occurrence import IssueEvidence, IssueOccurrence from sentry.issues.producer import PayloadType, produce_occurrence_to_kafka +from sentry.models.organization import Organization, OrganizationStatus from sentry.models.project import Project from sentry.net.http import connection_from_url from sentry.seer.explorer.utils import normalize_description from sentry.seer.signed_seer_api import SeerViewerContext, make_signed_seer_api_request from sentry.tasks.base import instrumented_task from sentry.taskworker.namespaces import issues_tasks +from sentry.utils.hashlib import md5_text +from sentry.utils.query import RangeQuerySetWrapper from sentry.utils.redis import redis_clusters logger = logging.getLogger("sentry.tasks.llm_issue_detection") @@ -39,13 +42,14 @@ SEER_TIMEOUT_S = 10 START_TIME_DELTA_MINUTES = 60 TRANSACTION_BATCH_SIZE = 50 -NUM_TRANSACTIONS_TO_PROCESS = 10 TRACE_PROCESSING_TTL_SECONDS = 7200 -# Character limit for LLM-generated fields to protect against abuse. -# Word limits are enforced by Seer's prompt (see seer/automation/issue_detection/analyze.py). -# This limit prevents excessively long outputs from malicious or malfunctioning LLMs. MAX_LLM_FIELD_LENGTH = 2000 +DISPATCH_INTERVAL_MINUTES = 15 +NUM_DISPATCH_SLOTS = 10 +MAX_ORGS_PER_CYCLE = 500 +ORG_DISPATCH_STAGGER_SECONDS = 15 + seer_issue_detection_connection_pool = connection_from_url( settings.SEER_AUTOFIX_URL, @@ -260,13 +264,16 @@ def create_issue_occurrence_from_detection( ) -def get_enabled_project_ids() -> list[int]: - """ - Get the list of project IDs that are explicitly enabled for LLM detection. +def _get_current_dispatch_slot() -> int: + """Return the current time slot index for bucketed dispatch.""" + now = datetime.now(UTC) + minutes_since_epoch = int(now.timestamp()) // 60 + return (minutes_since_epoch // DISPATCH_INTERVAL_MINUTES) % NUM_DISPATCH_SLOTS - Returns the allowlist from system options. - """ - return options.get("issue-detection.llm-detection.projects-allowlist") + +def _org_in_slot(org_id: int, slot: int) -> bool: + """Check if an org's hash-assigned slot matches the given slot.""" + return int(md5_text(str(org_id)).hexdigest(), 16) % NUM_DISPATCH_SLOTS == slot @instrumented_task( @@ -277,45 +284,64 @@ def get_enabled_project_ids() -> list[int]: def run_llm_issue_detection() -> None: """ Main scheduled task for LLM issue detection. + + Uses md5 hash bucketing to spread org dispatches across time slots. + Each 15-minute cycle processes one slot's worth of orgs. """ if not options.get("issue-detection.llm-detection.enabled"): return - enabled_project_ids = get_enabled_project_ids() - if not enabled_project_ids: - return - - # Spawn a sub-task for each project with staggered delays - for index, project_id in enumerate(enabled_project_ids): - detect_llm_issues_for_project.apply_async( - args=[project_id], - countdown=index * 90, + current_slot = _get_current_dispatch_slot() + dispatched = 0 + + for org in RangeQuerySetWrapper( + Organization.objects.filter(status=OrganizationStatus.ACTIVE), + ): + if dispatched >= MAX_ORGS_PER_CYCLE: + break + + if ( + not _org_in_slot(org.id, current_slot) + or not features.has("organizations:ai-issue-detection", org) + or not features.has("organizations:gen-ai-features", org) + or org.get_option("sentry:hide_ai_features") + ): + continue + + detect_llm_issues_for_org.apply_async( + args=[org.id], + countdown=dispatched * ORG_DISPATCH_STAGGER_SECONDS, headers={"sentry-propagate-traces": False}, ) + dispatched += 1 + + sentry_sdk.metrics.count( + "llm_issue_detection.orgs_dispatched", + dispatched, + attributes={"slot": current_slot}, + ) @instrumented_task( - name="sentry.tasks.llm_issue_detection.detect_llm_issues_for_project", + name="sentry.tasks.llm_issue_detection.detect_llm_issues_for_org", namespace=issues_tasks, processing_deadline_duration=180, # 3 minutes ) -def detect_llm_issues_for_project(project_id: int) -> None: +def detect_llm_issues_for_org(org_id: int) -> None: """ - Process a single project for LLM issue detection. + Process a single organization for LLM issue detection. - Gets the project's top TRANSACTION_BATCH_SIZE transaction spans from the last START_TIME_DELTA_MINUTES, sorted by -sum(span.duration). - From those transactions, dedupes on normalized transaction_name. - For each deduped transaction, gets first trace_id from the start of time window, which has small random variation. - Sends these trace_ids to seer, which uses get_trace_waterfall to construct an EAPTrace to analyze. + Picks one random active project, selects 1 trace, and sends to Seer. + Budget enforcement happens on the Seer side. """ from sentry.tasks.llm_issue_detection.trace_data import ( # circular imports get_project_top_transaction_traces_for_llm_detection, ) - project = Project.objects.get_from_cache(id=project_id) - organization = project.organization - organization_id = organization.id - organization_slug = organization.slug + try: + organization = Organization.objects.get_from_cache(id=org_id) + except Organization.DoesNotExist: + return has_access = features.has("organizations:gen-ai-features", organization) and not bool( organization.get_option("sentry:hide_ai_features") @@ -323,6 +349,18 @@ def detect_llm_issues_for_project(project_id: int) -> None: if not has_access: return + projects = list( + Project.objects.filter( + organization_id=org_id, + status=ObjectStatus.ACTIVE, + ).values_list("id", flat=True) + ) + if not projects: + return + + project_id = random.choice(projects) + + project = Project.objects.get_from_cache(id=project_id) perf_settings = project.get_option("sentry:performance_issue_settings", default={}) if not perf_settings.get("ai_issue_detection_enabled", True): return @@ -333,20 +371,17 @@ def detect_llm_issues_for_project(project_id: int) -> None: if not evidence_traces: return - # Shuffle to randomize selection random.shuffle(evidence_traces) - # Bulk check which traces are already processed all_trace_ids = [t.trace_id for t in evidence_traces] unprocessed_ids = _get_unprocessed_traces(all_trace_ids) skipped = len(all_trace_ids) - len(unprocessed_ids) if skipped: sentry_sdk.metrics.count("llm_issue_detection.trace.skipped", skipped) - # Take up to NUM_TRANSACTIONS_TO_PROCESS traces_to_send: list[TraceMetadataWithSpanCount] = [ t for t in evidence_traces if t.trace_id in unprocessed_ids - ][:NUM_TRANSACTIONS_TO_PROCESS] + ][:1] if not traces_to_send: return @@ -359,12 +394,12 @@ def detect_llm_issues_for_project(project_id: int) -> None: seer_request = IssueDetectionRequest( traces=traces_to_send, - organization_id=organization_id, + organization_id=org_id, project_id=project_id, - org_slug=organization_slug, + org_slug=organization.slug, ) - viewer_context = SeerViewerContext(organization_id=organization_id) + viewer_context = SeerViewerContext(organization_id=org_id) response = make_issue_detection_request( seer_request, timeout=SEER_TIMEOUT_S, @@ -374,25 +409,14 @@ def detect_llm_issues_for_project(project_id: int) -> None: if response.status == 202: mark_traces_as_processed([trace.trace_id for trace in traces_to_send]) - - logger.info( - "llm_issue_detection.request_accepted", - extra={ - "project_id": project_id, - "organization_id": organization_id, - "trace_count": len(traces_to_send), - }, - ) return - # Log (+ send to sentry) unexpected responses logger.error( "llm_issue_detection.unexpected_response", extra={ "status_code": response.status, "response_data": response.data, "project_id": project_id, - "organization_id": organization_id, - "trace_count": len(traces_to_send), + "organization_id": org_id, }, ) diff --git a/tests/sentry/tasks/test_llm_issue_detection.py b/tests/sentry/tasks/test_llm_issue_detection.py index 73dfc9e7d86a38..6c599e89570016 100644 --- a/tests/sentry/tasks/test_llm_issue_detection.py +++ b/tests/sentry/tasks/test_llm_issue_detection.py @@ -8,13 +8,15 @@ from sentry.tasks.llm_issue_detection import ( DetectedIssue, create_issue_occurrence_from_detection, - detect_llm_issues_for_project, + detect_llm_issues_for_org, run_llm_issue_detection, ) from sentry.tasks.llm_issue_detection.detection import ( - START_TIME_DELTA_MINUTES, + NUM_DISPATCH_SLOTS, TRANSACTION_BATCH_SIZE, + TraceMetadataWithSpanCount, _get_unprocessed_traces, + _org_in_slot, mark_traces_as_processed, ) from sentry.tasks.llm_issue_detection.trace_data import ( @@ -23,61 +25,173 @@ ) from sentry.testutils.cases import APITransactionTestCase, SnubaTestCase, SpanTestCase, TestCase from sentry.testutils.helpers.datetime import before_now -from sentry.testutils.helpers.features import with_feature +from sentry.testutils.pytest.fixtures import django_db_all + + +class TestDispatchSlotBucketing(TestCase): + def test_org_in_slot_deterministic(self): + slot = next(s for s in range(NUM_DISPATCH_SLOTS) if _org_in_slot(12345, s)) + assert _org_in_slot(12345, slot) is True + for s in range(NUM_DISPATCH_SLOTS): + if s != slot: + assert _org_in_slot(12345, s) is False + + def test_org_in_slot_distributes_evenly(self): + slot_counts: dict[int, int] = {s: 0 for s in range(NUM_DISPATCH_SLOTS)} + for org_id in range(1, 10001): + for s in range(NUM_DISPATCH_SLOTS): + if _org_in_slot(org_id, s): + slot_counts[s] += 1 + break + + expected_per_slot = 10000 / NUM_DISPATCH_SLOTS + for count in slot_counts.values(): + assert count > expected_per_slot * 0.8 + assert count < expected_per_slot * 1.2 + + +@django_db_all +class TestRunLLMIssueDetection(TestCase): + @patch("sentry.tasks.llm_issue_detection.detection.detect_llm_issues_for_org.apply_async") + @patch("sentry.tasks.llm_issue_detection.detection._org_in_slot") + def test_dispatches_orgs_in_current_slot(self, mock_org_in_slot, mock_apply_async): + org = self.create_organization() + mock_org_in_slot.return_value = True + + with ( + self.options({"issue-detection.llm-detection.enabled": True}), + self.feature( + { + "organizations:ai-issue-detection": [org.slug], + "organizations:gen-ai-features": [org.slug], + } + ), + ): + run_llm_issue_detection() + mock_apply_async.assert_called_once() + assert mock_apply_async.call_args.kwargs["args"] == [org.id] + + @patch("sentry.tasks.llm_issue_detection.detection.detect_llm_issues_for_org.apply_async") + @patch("sentry.tasks.llm_issue_detection.detection._org_in_slot") + def test_skips_orgs_not_in_current_slot(self, mock_org_in_slot, mock_apply_async): + org = self.create_organization() + mock_org_in_slot.return_value = False + + with ( + self.options({"issue-detection.llm-detection.enabled": True}), + self.feature( + { + "organizations:ai-issue-detection": [org.slug], + "organizations:gen-ai-features": [org.slug], + } + ), + ): + run_llm_issue_detection() -class LLMIssueDetectionTest(TestCase): - @patch("sentry.tasks.llm_issue_detection.detection.detect_llm_issues_for_project.apply_async") - def test_run_detection_dispatches_sub_tasks(self, mock_apply_async): - project = self.create_project() + mock_apply_async.assert_not_called() - with self.options( - { - "issue-detection.llm-detection.enabled": True, - "issue-detection.llm-detection.projects-allowlist": [project.id], - } + @patch("sentry.tasks.llm_issue_detection.detection.detect_llm_issues_for_org.apply_async") + @patch("sentry.tasks.llm_issue_detection.detection._org_in_slot") + def test_skips_orgs_without_feature_flag(self, mock_org_in_slot, mock_apply_async): + self.create_organization() + mock_org_in_slot.return_value = True + + with self.options({"issue-detection.llm-detection.enabled": True}): + run_llm_issue_detection() + + mock_apply_async.assert_not_called() + + @patch("sentry.tasks.llm_issue_detection.detection.detect_llm_issues_for_org.apply_async") + @patch("sentry.tasks.llm_issue_detection.detection._org_in_slot") + def test_skips_orgs_with_hidden_ai(self, mock_org_in_slot, mock_apply_async): + org = self.create_organization() + org.update_option("sentry:hide_ai_features", True) + mock_org_in_slot.return_value = True + + with ( + self.options({"issue-detection.llm-detection.enabled": True}), + self.feature( + { + "organizations:ai-issue-detection": [org.slug], + "organizations:gen-ai-features": [org.slug], + } + ), ): run_llm_issue_detection() - mock_apply_async.assert_called_once_with( - args=[project.id], countdown=0, headers={"sentry-propagate-traces": False} - ) + mock_apply_async.assert_not_called() + - @with_feature("organizations:gen-ai-features") +class TestDetectLLMIssuesForOrg(TestCase): + @patch("sentry.tasks.llm_issue_detection.detection.mark_traces_as_processed") + @patch("sentry.tasks.llm_issue_detection.detection._get_unprocessed_traces") @patch("sentry.tasks.llm_issue_detection.detection.make_issue_detection_request") @patch( "sentry.tasks.llm_issue_detection.trace_data.get_project_top_transaction_traces_for_llm_detection" ) - def test_detect_llm_issues_no_transactions(self, mock_get_transactions, mock_seer_request): - mock_get_transactions.return_value = [] + @patch("sentry.tasks.llm_issue_detection.detection.Project.objects.filter") + def test_sends_one_trace_to_seer( + self, + mock_project_filter, + mock_get_transactions, + mock_seer_request, + mock_get_unprocessed, + mock_mark_processed, + ): + mock_project_filter.return_value.values_list.return_value = [self.project.id] + mock_get_transactions.return_value = [ + TraceMetadataWithSpanCount(trace_id="trace_1", span_count=50), + TraceMetadataWithSpanCount(trace_id="trace_2", span_count=100), + ] + mock_get_unprocessed.return_value = {"trace_1", "trace_2"} - detect_llm_issues_for_project(self.project.id) + mock_response = Mock() + mock_response.status = 202 + mock_seer_request.return_value = mock_response - mock_get_transactions.assert_called_once_with( - self.project.id, - limit=TRANSACTION_BATCH_SIZE, - start_time_delta_minutes=START_TIME_DELTA_MINUTES, - ) - mock_seer_request.assert_not_called() + with self.feature({"organizations:gen-ai-features": True}): + detect_llm_issues_for_org(self.organization.id) - @with_feature("organizations:gen-ai-features") - @patch("sentry.tasks.llm_issue_detection.trace_data.Spans.run_table_query") + assert mock_seer_request.call_count == 1 + seer_request = mock_seer_request.call_args[0][0] + assert len(seer_request.traces) == 1 + assert seer_request.organization_id == self.organization.id + mock_mark_processed.assert_called_once() + + @patch("sentry.tasks.llm_issue_detection.detection.mark_traces_as_processed") @patch("sentry.tasks.llm_issue_detection.detection.make_issue_detection_request") - def test_detect_llm_issues_no_traces(self, mock_seer_request, mock_spans_query): - mock_spans_query.side_effect = [ - # First call: Return a transaction - { - "data": [{"transaction": "transaction_name", "sum(span.duration)": 1000}], - "meta": {}, - }, - # Second call (trace query): return empty - {"data": [], "meta": {}}, + @patch("sentry.tasks.llm_issue_detection.detection._get_unprocessed_traces") + @patch( + "sentry.tasks.llm_issue_detection.trace_data.get_project_top_transaction_traces_for_llm_detection" + ) + @patch("sentry.tasks.llm_issue_detection.detection.Project.objects.filter") + def test_does_not_mark_processed_on_seer_error( + self, + mock_project_filter, + mock_get_transactions, + mock_get_unprocessed, + mock_seer_request, + mock_mark_processed, + ): + mock_project_filter.return_value.values_list.return_value = [self.project.id] + mock_get_transactions.return_value = [ + TraceMetadataWithSpanCount(trace_id="trace_1", span_count=50), ] + mock_get_unprocessed.return_value = {"trace_1"} + + mock_response = Mock() + mock_response.status = 500 + mock_response.data = b"Internal Server Error" + mock_seer_request.return_value = mock_response - detect_llm_issues_for_project(self.project.id) + with self.feature({"organizations:gen-ai-features": True}): + detect_llm_issues_for_org(self.organization.id) - mock_seer_request.assert_not_called() + mock_mark_processed.assert_not_called() + +class LLMIssueDetectionTest(TestCase): @patch("sentry.tasks.llm_issue_detection.detection.produce_occurrence_to_kafka") def test_create_issue_occurrence_from_detection(self, mock_produce_occurrence): detected_issue = DetectedIssue( @@ -98,52 +212,14 @@ def test_create_issue_occurrence_from_detection(self, mock_produce_occurrence): ) assert mock_produce_occurrence.called - call_kwargs = mock_produce_occurrence.call_args.kwargs - - assert call_kwargs["payload_type"].value == "occurrence" - - occurrence = call_kwargs["occurrence"] + occurrence = mock_produce_occurrence.call_args.kwargs["occurrence"] assert occurrence.type == AIDetectedGeneralGroupType assert occurrence.issue_title == "Slow Database Query" - assert occurrence.subtitle == "Your application is running out of database connections" - assert occurrence.project_id == self.project.id - assert occurrence.culprit == "test_transaction" - assert occurrence.level == "warning" - assert occurrence.fingerprint == ["llm-detected-slow-database-query"] - - assert occurrence.evidence_data["trace_id"] == "abc123xyz" - assert occurrence.evidence_data["transaction"] == "test_transaction" - assert ( - occurrence.evidence_data["explanation"] - == "Your application is running out of database connections" - ) - assert occurrence.evidence_data["impact"] == "High - may cause request failures" - - evidence_display = occurrence.evidence_display - assert len(evidence_display) == 3 - - assert evidence_display[0].name == "Explanation" - assert ( - evidence_display[0].value == "Your application is running out of database connections" - ) - assert evidence_display[1].name == "Impact" - assert evidence_display[1].value == "High - may cause request failures" - assert evidence_display[2].name == "Evidence" - assert evidence_display[2].value == "Connection pool at 95% capacity" - - event_data = call_kwargs["event_data"] - assert event_data["project_id"] == self.project.id - assert event_data["platform"] == "other" - assert event_data["contexts"]["trace"]["trace_id"] == "abc123xyz" - assert "event_id" in event_data - assert "received" in event_data - assert "timestamp" in event_data + assert occurrence.project_id == self.project.id @patch("sentry.tasks.llm_issue_detection.detection.produce_occurrence_to_kafka") - def test_create_issue_occurrence_uses_group_for_fingerprint_when_set( - self, mock_produce_occurrence - ): + def test_create_issue_occurrence_maps_group_type(self, mock_produce_occurrence): detected_issue = DetectedIssue( title="Inefficient Database Queries", explanation="Multiple queries in loop", @@ -163,130 +239,14 @@ def test_create_issue_occurrence_uses_group_for_fingerprint_when_set( assert occurrence.fingerprint == ["llm-detected-n+1-database-queries"] assert occurrence.type == AIDetectedDBGroupType - @with_feature("organizations:gen-ai-features") - @patch("sentry.tasks.llm_issue_detection.detection.mark_traces_as_processed") - @patch("sentry.tasks.llm_issue_detection.detection._get_unprocessed_traces") - @patch("sentry.tasks.llm_issue_detection.detection.make_issue_detection_request") - @patch("sentry.tasks.llm_issue_detection.trace_data.Spans.run_table_query") - @patch("sentry.tasks.llm_issue_detection.detection.random.shuffle") - def test_detect_llm_issues_full_flow( - self, - mock_shuffle, - mock_spans_query, - mock_seer_request, - mock_get_unprocessed, - mock_mark_processed, - ): - mock_shuffle.return_value = None # shuffles in-place, mock to prevent reordering - mock_get_unprocessed.return_value = {"trace_id_1", "trace_id_2"} # All unprocessed - - mock_spans_query.side_effect = [ - # First call: transaction spans - { - "data": [ - {"transaction": "POST /some/thing", "sum(span.duration)": 1007}, - {"transaction": "GET /another/", "sum(span.duration)": 1003}, - ], - "meta": {}, - }, - # Second call: trace for transaction 1 - { - "data": [ - {"trace": "trace_id_1", "precise.start_ts": 1234}, - ], - "meta": {}, - }, - # Third call: trace for transaction 2 - { - "data": [ - {"trace": "trace_id_2", "precise.start_ts": 1234}, - ], - "meta": {}, - }, - # Fourth call: span count query - { - "data": [ - {"trace": "trace_id_1", "count()": 50}, - {"trace": "trace_id_2", "count()": 100}, - ], - "meta": {}, - }, - ] - - # Seer returns 202 for async processing - mock_accepted_response = Mock() - mock_accepted_response.status = 202 - mock_seer_request.return_value = mock_accepted_response - - detect_llm_issues_for_project(self.project.id) - - assert mock_spans_query.call_count == 4 # 1 transactions, 2 traces, 1 span count - assert mock_seer_request.call_count == 1 # Single batch request - - seer_request = mock_seer_request.call_args[0][0] - assert seer_request.project_id == self.project.id - assert seer_request.organization_id == self.project.organization_id - assert len(seer_request.traces) == 2 - trace_ids = {t.trace_id for t in seer_request.traces} - assert trace_ids == {"trace_id_1", "trace_id_2"} - - assert mock_mark_processed.call_count == 1 - mock_mark_processed.assert_called_once_with(["trace_id_1", "trace_id_2"]) - - @with_feature("organizations:gen-ai-features") - @patch("sentry.tasks.llm_issue_detection.detection.mark_traces_as_processed") - @patch("sentry.tasks.llm_issue_detection.detection._get_unprocessed_traces") - @patch("sentry.tasks.llm_issue_detection.detection.make_issue_detection_request") - @patch("sentry.tasks.llm_issue_detection.trace_data.Spans.run_table_query") - @patch("sentry.tasks.llm_issue_detection.detection.random.shuffle") - @patch("sentry.tasks.llm_issue_detection.detection.logger.error") - def test_detect_llm_issues_seer_error_no_traces_marked( - self, - mock_logger_error, - mock_shuffle, - mock_spans_query, - mock_seer_request, - mock_get_unprocessed, - mock_mark_processed, - ): - mock_shuffle.return_value = None - mock_get_unprocessed.return_value = {"trace_id_1"} - - mock_spans_query.side_effect = [ - { - "data": [ - {"transaction": "POST /some/thing", "sum(span.duration)": 1007}, - ], - "meta": {}, - }, - {"data": [{"trace": "trace_id_1", "precise.start_ts": 1234}], "meta": {}}, - {"data": [{"trace": "trace_id_1", "count()": 50}], "meta": {}}, - ] - - mock_error_response = Mock() - mock_error_response.status = 500 - mock_error_response.data = b"Internal Server Error" - mock_seer_request.return_value = mock_error_response - - detect_llm_issues_for_project(self.project.id) - - assert mock_seer_request.call_count == 1 - assert mock_logger_error.call_count == 1 - # Traces NOT marked as processed on error - will be retried next run - assert mock_mark_processed.call_count == 0 - class TestTraceProcessingFunctions: @pytest.mark.parametrize( ("trace_ids", "mget_return", "expected"), [ - # All unprocessed (mget returns None for missing keys) (["a", "b", "c"], [None, None, None], {"a", "b", "c"}), - # Some processed (mget returns "1" for existing keys) (["a", "b", "c"], ["1", None, "1"], {"b"}), - # All processed (["a", "b"], ["1", "1"], set()), - # Empty input ([], [], set()), ], ) @@ -297,17 +257,14 @@ def test_get_unprocessed_traces( mock_cluster = Mock() mock_redis_clusters.get.return_value = mock_cluster mock_cluster.mget.return_value = mget_return - - result = _get_unprocessed_traces(trace_ids) - - assert result == expected + assert _get_unprocessed_traces(trace_ids) == expected @pytest.mark.parametrize( ("trace_ids", "expected_set_calls"), [ - (["trace_123"], 1), # Single trace - (["trace_1", "trace_2", "trace_3"], 3), # Multiple traces - ([], 0), # Empty list - early return, no pipeline calls + (["trace_123"], 1), + (["trace_1", "trace_2", "trace_3"], 3), + ([], 0), ], ) @patch("sentry.tasks.llm_issue_detection.detection.redis_clusters") @@ -323,9 +280,7 @@ def test_mark_traces_as_processed( mark_traces_as_processed(trace_ids) assert mock_pipeline.set.call_count == expected_set_calls - if expected_set_calls == 0: - mock_cluster.pipeline.assert_not_called() - else: + if expected_set_calls > 0: mock_pipeline.execute.assert_called_once() @@ -333,22 +288,18 @@ class TestGetValidTraceIdsBySpanCount: @pytest.mark.parametrize( ("query_result", "expected"), [ - # All valid ( {"data": [{"trace": "a", "count()": 50}, {"trace": "b", "count()": 100}]}, {"a": 50, "b": 100}, ), - # Some below lower limit ( {"data": [{"trace": "a", "count()": 10}, {"trace": "b", "count()": 50}]}, {"b": 50}, ), - # Some above upper limit ( {"data": [{"trace": "a", "count()": 50}, {"trace": "b", "count()": 600}]}, {"a": 50}, ), - # Empty result ({"data": []}, {}), ], ) @@ -357,13 +308,7 @@ def test_filters_by_span_count( self, mock_spans_query: Mock, query_result: dict, expected: dict[str, int] ) -> None: mock_spans_query.return_value = query_result - mock_snuba_params = Mock() - mock_config = Mock() - - result = get_valid_trace_ids_by_span_count( - ["a", "b", "c", "d"], mock_snuba_params, mock_config - ) - + result = get_valid_trace_ids_by_span_count(["a", "b", "c", "d"], Mock(), Mock()) assert result == expected @@ -376,13 +321,12 @@ def setUp(self) -> None: @patch("sentry.tasks.llm_issue_detection.trace_data.get_valid_trace_ids_by_span_count") def test_returns_deduped_transaction_traces(self, mock_span_count) -> None: - # Mock span count check to return all traces as valid mock_span_count.side_effect = lambda trace_ids, *args: {tid: 50 for tid in trace_ids} trace_id_1 = uuid.uuid4().hex span1 = self.create_span( { - "description": "GET /api/users/123456", # will dedupe + "description": "GET /api/users/123456", "sentry_tags": {"transaction": "GET /api/users/123456"}, "trace_id": trace_id_1, "is_segment": True, @@ -395,12 +339,12 @@ def test_returns_deduped_transaction_traces(self, mock_span_count) -> None: trace_id_2 = uuid.uuid4().hex span2 = self.create_span( { - "description": "GET /api/users/789012", # will dedupe + "description": "GET /api/users/789012", "sentry_tags": {"transaction": "GET /api/users/789012"}, "trace_id": trace_id_2, "is_segment": True, "exclusive_time_ms": 200, - "duration_ms": 200, # will return before span1 in transaction query + "duration_ms": 200, }, start_ts=self.ten_mins_ago + timedelta(seconds=1), ) @@ -425,7 +369,5 @@ def test_returns_deduped_transaction_traces(self, mock_span_count) -> None: ) assert len(evidence_traces) == 2 - - # trace_id_2 prevails over trace_id_1 because transaction span duration was higher assert evidence_traces[0].trace_id == trace_id_2 assert evidence_traces[1].trace_id == trace_id_3