def _merge_result_into_distributions(result, distributions):
    """Fold one annotation/prediction ``result`` (a list of labeling items)
    into the ``distributions`` accumulator, in place.

    Items are grouped by their ``from_name`` control tag. Label-like controls
    (``*labels``, ``choices``, ``taxonomy``, ``pairwise``) increment per-label
    counters; numeric controls (``rating``, ``number``) append raw values for
    later averaging. Malformed items (non-dict entries, non-list payloads)
    are skipped silently, matching the tolerant client-side behavior.
    """
    if not result or not isinstance(result, list):
        return
    for item in result:
        if not isinstance(item, dict):
            continue
        from_name = item.get('from_name', '')
        result_type = item.get('type', '')
        value = item.get('value', {})

        # Every seen control gets an entry, even for unknown result types,
        # so the client can render an N/A cell for it.
        dist = distributions.setdefault(
            from_name, {'type': result_type, 'labels': {}, 'values': []}
        )
        labels = dist['labels']

        if result_type.endswith('labels'):
            # e.g. rectanglelabels -> value['rectanglelabels'] is a list of labels
            selected = value.get(result_type, [])
            if isinstance(selected, list):
                for label in selected:
                    labels[label] = labels.get(label, 0) + 1

        elif result_type == 'choices':
            choices = value.get('choices', [])
            if isinstance(choices, list):
                for choice in choices:
                    labels[choice] = labels.get(choice, 0) + 1

        elif result_type == 'rating':
            rating = value.get('rating')
            if rating is not None:
                dist['values'].append(rating)

        elif result_type == 'number':
            number = value.get('number')
            if number is not None:
                dist['values'].append(number)

        elif result_type == 'taxonomy':
            # Each taxonomy selection is a path; count only the leaf node.
            taxonomy = value.get('taxonomy', [])
            if isinstance(taxonomy, list):
                for path in taxonomy:
                    if isinstance(path, list) and path:
                        leaf = path[-1]
                        labels[leaf] = labels.get(leaf, 0) + 1

        elif result_type == 'pairwise':
            selected = value.get('selected')
            if selected:
                labels[selected] = labels.get(selected, 0) + 1


def _finalize_distributions(distributions):
    """Replace raw numeric value lists with ``average``/``count`` summaries.

    BUG FIX: the ``values`` list is now removed unconditionally. Previously it
    was deleted only when non-empty, so label-type controls leaked a useless
    ``'values': []`` entry into the response — breaking exact-equality
    assertions in the endpoint's own tests and bloating the payload.
    """
    for dist in distributions.values():
        values = dist.pop('values', [])
        if values:
            dist['average'] = sum(values) / len(values)
            dist['count'] = len(values)


@method_decorator(
    name='get',
    decorator=extend_schema(
        tags=['Tasks'],
        summary='Get task label distribution',
        description='Get aggregated label distribution across all annotations for a task. '
        'Returns counts of each label value grouped by control tag. '
        'This is an efficient endpoint that avoids N+1 queries.',
        responses={
            '200': OpenApiResponse(
                description='Label distribution data',
                examples=[
                    OpenApiExample(
                        name='response',
                        value={
                            'total_annotations': 100,
                            'distributions': {
                                'label': {
                                    'type': 'rectanglelabels',
                                    'labels': {'Car': 45, 'Person': 30, 'Dog': 25},
                                },
                            },
                        },
                        media_type='application/json',
                    )
                ],
            )
        },
        extensions={
            'x-fern-audiences': ['internal'],
        },
    ),
)
class TaskAgreementAPI(generics.RetrieveAPIView):
    """
    Efficient endpoint for getting label distribution without fetching all annotations.

    This endpoint aggregates annotation results at the database level to avoid N+1 queries.
    It returns pre-computed label counts for the Distribution row in the Summary view.
    """

    permission_required = ViewClassPermission(GET=all_permissions.tasks_view)
    queryset = Task.objects.all()

    def get(self, request, pk):
        """Return ``{'total_annotations': int, 'distributions': {...}}`` for task ``pk``.

        Raises ``PermissionDenied`` (403) when the feature flag is off or the
        user lacks project access; returns 404 when the task does not exist.
        ``total_annotations`` counts non-cancelled annotations only, while
        ``distributions`` also merges prediction results so the aggregate
        matches the client-side computation (develop / FF off).
        """
        # This endpoint is gated by feature flag
        if not flag_set('fflag_fix_all_fit_720_lazy_load_annotations', user=request.user):
            raise PermissionDenied('Feature not enabled')

        try:
            # select_related avoids a second query for the permission check below
            task = Task.objects.select_related('project').get(pk=pk)
        except Task.DoesNotExist:
            return Response({'error': 'Task not found'}, status=404)

        # Check project access using LSO's native permission check
        if not task.project.has_permission(request.user):
            raise PermissionDenied('You do not have permission to view this task')

        # Single query: only the JSON results of non-cancelled annotations
        annotation_results = list(
            Annotation.objects.filter(task=task, was_cancelled=False).values_list('result', flat=True)
        )

        distributions = {}
        for result in annotation_results:
            _merge_result_into_distributions(result, distributions)

        # Include prediction results in distribution counts so aggregate matches
        # client-side (develop / FF off). total_annotations stays annotation count only.
        for result in Prediction.objects.filter(task=task).values_list('result', flat=True):
            # Prediction.result can be list (same as annotation) or dict;
            # only the list form is aggregated.
            if isinstance(result, list):
                _merge_result_into_distributions(result, distributions)

        # Post-process: calculate averages for numeric types and strip the
        # internal raw-value lists to keep the response lightweight.
        _finalize_distributions(distributions)

        return Response(
            {
                'total_annotations': len(annotation_results),
                'distributions': distributions,
            }
        )
Always run this test.""" + + @classmethod + def setUpTestData(cls): + cls.organization = OrganizationFactory() + cls.project = ProjectFactory(organization=cls.organization) + cls.user = cls.organization.created_by + + @patch('tasks.api.flag_set') + def test_distribution_returns_403_when_feature_flag_disabled(self, mock_flag_set): + mock_flag_set.return_value = False + task = TaskFactory(project=self.project) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 403 + assert 'detail' in response.json() or 'error' in response.json() + + +@unittest.skipUnless( + flag_set('fflag_fix_all_fit_720_lazy_load_annotations', user=None), + 'Agreement API tests require fflag_fix_all_fit_720_lazy_load_annotations to be on', +) +class TestTaskAgreementAPI(APITestCase): + """Tests for TaskAgreementAPI (GET /api/tasks//agreement/). Run only when feature flag is on.""" + + @classmethod + def setUpTestData(cls): + cls.organization = OrganizationFactory() + cls.project = ProjectFactory(organization=cls.organization) + cls.user = cls.organization.created_by + + @patch('tasks.api.flag_set') + def test_distribution_returns_404_for_nonexistent_task(self, mock_flag_set): + mock_flag_set.return_value = True + self.client.force_authenticate(user=self.user) + response = self.client.get('/api/tasks/99999/agreement/') + assert response.status_code == 404 + assert response.json() == {'error': 'Task not found'} + + @patch('tasks.api.flag_set') + @patch.object(Project, 'has_permission') + def test_distribution_permission_denied_for_other_project(self, mock_has_permission, mock_flag_set): + mock_flag_set.return_value = True + other_org = OrganizationFactory() + other_project = ProjectFactory(organization=other_org) + task = TaskFactory(project=other_project) + # In OSS Project.has_permission is a stub that always returns True; patch so other_project denies access + def has_perm(project, user): + return 
project.id != other_project.id + + mock_has_permission.side_effect = has_perm + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 403 + + @patch('tasks.api.flag_set') + def test_distribution_empty_task_returns_zero_annotations(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 0 + assert data['distributions'] == {} + + @patch('tasks.api.flag_set') + def test_distribution_with_rectanglelabels(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Car', 'Car']}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Person']}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 2 + assert data['distributions']['label'] == { + 'type': 'rectanglelabels', + 'labels': {'Car': 2, 'Person': 1}, + } + + @patch('tasks.api.flag_set') + def test_distribution_with_choices(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'sentiment', + 'to_name': 'text', + 'type': 'choices', + 'value': {'choices': ['Positive']}, + } + ], + ) + AnnotationFactory( + task=task, 
+ project=self.project, + result=[ + { + 'from_name': 'sentiment', + 'to_name': 'text', + 'type': 'choices', + 'value': {'choices': ['Negative']}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'sentiment', + 'to_name': 'text', + 'type': 'choices', + 'value': {'choices': ['Positive']}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 3 + assert data['distributions']['sentiment'] == { + 'type': 'choices', + 'labels': {'Positive': 2, 'Negative': 1}, + } + + @patch('tasks.api.flag_set') + def test_distribution_with_rating(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'rating', + 'to_name': 'text', + 'type': 'rating', + 'value': {'rating': 4}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'rating', + 'to_name': 'text', + 'type': 'rating', + 'value': {'rating': 5}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 2 + assert data['distributions']['rating']['type'] == 'rating' + assert data['distributions']['rating']['average'] == 4.5 + assert data['distributions']['rating']['count'] == 2 + assert 'values' not in data['distributions']['rating'] + + @patch('tasks.api.flag_set') + def test_distribution_with_number(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'count', + 'to_name': 'text', + 'type': 'number', + 'value': {'number': 10}, + 
} + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'count', + 'to_name': 'text', + 'type': 'number', + 'value': {'number': 20}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 2 + assert data['distributions']['count']['type'] == 'number' + assert data['distributions']['count']['average'] == 15.0 + assert data['distributions']['count']['count'] == 2 + + @patch('tasks.api.flag_set') + def test_distribution_with_taxonomy(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'tax', + 'to_name': 'text', + 'type': 'taxonomy', + 'value': {'taxonomy': [['Animals', 'Dog']]}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'tax', + 'to_name': 'text', + 'type': 'taxonomy', + 'value': {'taxonomy': [['Animals', 'Cat']]}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 2 + assert data['distributions']['tax'] == { + 'type': 'taxonomy', + 'labels': {'Dog': 1, 'Cat': 1}, + } + + @patch('tasks.api.flag_set') + def test_distribution_with_pairwise(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'pair', + 'to_name': 'text', + 'type': 'pairwise', + 'value': {'selected': 'left'}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'pair', + 'to_name': 'text', + 'type': 'pairwise', + 'value': {'selected': 'right'}, + } 
+ ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 2 + assert data['distributions']['pair'] == { + 'type': 'pairwise', + 'labels': {'left': 1, 'right': 1}, + } + + @patch('tasks.api.flag_set') + def test_distribution_excludes_cancelled_annotations(self, mock_flag_set): + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Car']}, + } + ], + ) + AnnotationFactory( + task=task, + project=self.project, + was_cancelled=True, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Person']}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response = self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 1 + assert data['distributions']['label']['labels'] == {'Car': 1} + + @patch('tasks.api.flag_set') + def test_distribution_includes_predictions_in_label_counts(self, mock_flag_set): + """Predictions are merged into distributions so aggregate matches client-side (develop / FF off).""" + mock_flag_set.return_value = True + task = TaskFactory(project=self.project) + AnnotationFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Car', 'Car']}, + } + ], + ) + PredictionFactory( + task=task, + project=self.project, + result=[ + { + 'from_name': 'label', + 'to_name': 'image', + 'type': 'rectanglelabels', + 'value': {'rectanglelabels': ['Car']}, + } + ], + ) + self.client.force_authenticate(user=self.user) + response 
= self.client.get(f'/api/tasks/{task.id}/agreement/') + assert response.status_code == 200 + data = response.json() + assert data['total_annotations'] == 1 + assert data['distributions']['label']['labels'] == {'Car': 3} diff --git a/label_studio/tasks/urls.py b/label_studio/tasks/urls.py index 251ec76ec4b3..340b4b88b0d2 100644 --- a/label_studio/tasks/urls.py +++ b/label_studio/tasks/urls.py @@ -21,6 +21,8 @@ api.AnnotationDraftListAPI.as_view(), name='task-annotations-drafts', ), + # Agreement endpoint for Summary view + path('/agreement/', api.TaskAgreementAPI.as_view(), name='task-agreement'), ] _api_annotations_urlpatterns = [ diff --git a/poetry.lock b/poetry.lock index 547039f5131c..95079a1bd040 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2150,7 +2150,7 @@ optional = false python-versions = ">=3.10,<4" groups = ["main"] files = [ - {file = "c4b8cffdd50f2ca70f37a3504b1c6d4457c0b11e.zip", hash = "sha256:72b3fd5f93c23acfc8627f528aa33b3ba92c7bae6fca778d8385946de0e45b9c"}, + {file = "ca54296fd090cf9b1da7f4b8cd60f85430d774c5.zip", hash = "sha256:bd260eabdbaf666bd0ea98935d496a8d289c10ca86ee9a2e78613477ac115bd5"}, ] [package.dependencies] @@ -2178,7 +2178,7 @@ xmljson = "0.2.1" [package.source] type = "url" -url = "https://github.com/HumanSignal/label-studio-sdk/archive/c4b8cffdd50f2ca70f37a3504b1c6d4457c0b11e.zip" +url = "https://github.com/HumanSignal/label-studio-sdk/archive/ca54296fd090cf9b1da7f4b8cd60f85430d774c5.zip" [[package]] name = "launchdarkly-server-sdk" @@ -5148,4 +5148,4 @@ uwsgi = ["pyuwsgi", "uwsgitop"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4" -content-hash = "c44130874f845c7565d9edd7b080de79a5d0eabbe3fffa9099933ed89c265624" +content-hash = "be6cf8e314ceac5b5de695c635348879ac597cd190f71ff2c1c75f06f6e9876c" diff --git a/pyproject.toml b/pyproject.toml index 92adf12111d3..4b4f8156eed5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "tldextract (>=5.1.3)", "uuid-utils (>=0.11.0,<1.0.0)", ## 
HumanSignal repo dependencies :start - "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/c4b8cffdd50f2ca70f37a3504b1c6d4457c0b11e.zip", + "label-studio-sdk @ https://github.com/HumanSignal/label-studio-sdk/archive/ca54296fd090cf9b1da7f4b8cd60f85430d774c5.zip", ## HumanSignal repo dependencies :end ] diff --git a/web/libs/editor/src/components/App/App.jsx b/web/libs/editor/src/components/App/App.jsx index 245be11c93d1..9eda93a0a793 100644 --- a/web/libs/editor/src/components/App/App.jsx +++ b/web/libs/editor/src/components/App/App.jsx @@ -32,6 +32,8 @@ import { sanitizeHtml } from "../../utils/html"; import { reactCleaner } from "../../utils/reactCleaner"; import { guidGenerator } from "../../utils/unique"; import { isDefined, sortAnnotations } from "../../utils/utilities"; +import { QueryClientProvider } from "@tanstack/react-query"; +import { queryClient } from "@humansignal/core/lib/utils/query-client"; import { ToastProvider, ToastViewport } from "@humansignal/ui/lib/toast/toast"; /** @@ -256,46 +258,61 @@ class App extends Component { className={cn("editor").mod({ fullscreen: settings.fullscreen }).toClassName()} ref={isFF(FF_LSDV_4620_3_ML) ? reactCleaner(this) : null} > - - - - {newUIEnabled ? ( - store.toggleDescription()} - title={store.hasInterface("review") ? "Review Instructions" : "Labeling Instructions"} - > - {store.description} - - ) : ( - <> - {store.showingDescription && ( -
- {/* biome-ignore lint/security/noDangerouslySetInnerHtml: we need html here and it's sanitized */} -
-
- )} - - )} - - {isDefined(store) && store.hasInterface("topbar") && } -
+ + + + {newUIEnabled ? ( - isBulkMode || !store.hasInterface("side-column") ? ( - <> - {mainContent} - {store.hasInterface("topbar") && } - + store.toggleDescription()} + title={store.hasInterface("review") ? "Review Instructions" : "Labeling Instructions"} + > + {store.description} + + ) : ( + <> + {store.showingDescription && ( +
+ {/* biome-ignore lint/security/noDangerouslySetInnerHtml: we need html here and it's sanitized */} +
+
+ )} + + )} + + {isDefined(store) && store.hasInterface("topbar") && } +
+ {newUIEnabled ? ( + isBulkMode || !store.hasInterface("side-column") ? ( + <> + {mainContent} + {store.hasInterface("topbar") && } + + ) : ( + + {mainContent} + {store.hasInterface("topbar") && } + + ) + ) : isBulkMode || !store.hasInterface("side-column") ? ( + mainContent ) : ( - {mainContent} - {store.hasInterface("topbar") && } - - ) - ) : isBulkMode || !store.hasInterface("side-column") ? ( - mainContent - ) : ( - - {mainContent} - - )} -
- - - - {store.hasInterface("debug") && } + + )} +
+ +
+
+ {store.hasInterface("debug") && } +
); } diff --git a/web/libs/editor/src/components/App/Grid.jsx b/web/libs/editor/src/components/App/Grid.jsx index 6fe9440c6506..3e092984e87b 100644 --- a/web/libs/editor/src/components/App/Grid.jsx +++ b/web/libs/editor/src/components/App/Grid.jsx @@ -1,6 +1,6 @@ /** * Grid component for Compare view - renders annotation panels side-by-side - * FIT-720: Added virtualization support for large annotation counts + * Added virtualization support for large annotation counts */ import React, { Component, useCallback, useMemo, useRef, useState } from "react"; diff --git a/web/libs/editor/src/components/TaskSummary/Aggregation.tsx b/web/libs/editor/src/components/TaskSummary/Aggregation.tsx index 3da2aeb886e2..91e25d5f6098 100644 --- a/web/libs/editor/src/components/TaskSummary/Aggregation.tsx +++ b/web/libs/editor/src/components/TaskSummary/Aggregation.tsx @@ -1,13 +1,36 @@ import { useLayoutEffect, useRef, useState } from "react"; import { cnm, IconChevronDown } from "@humansignal/ui"; import type { Header } from "@tanstack/react-table"; +import { useQuery } from "@tanstack/react-query"; import type { RawResult } from "../../stores/types"; import { Chip } from "./Chip"; import type { AnnotationSummary, ControlTag } from "./types"; import { getLabelCounts } from "./utils"; +import { isActive, FF_FIT_720_LAZY_LOAD_ANNOTATIONS } from "@humansignal/core/lib/utils/feature-flags"; import styles from "./TaskSummary.module.scss"; +type DistributionData = { + total_annotations: number; + distributions: Record< + string, + { + type: string; + labels: Record; + average?: number; + count?: number; + } + >; +}; + +const fetchDistribution = async (taskId: number | string): Promise => { + const response = await fetch(`/api/tasks/${taskId}/agreement/`); + if (!response.ok) { + throw new Error("Failed to load distribution"); + } + return response.json(); +}; + const resultValue = (result: RawResult) => { if (result.type === "textarea") { return result.value.text; @@ -25,10 +48,11 @@ 
export const AggregationCell = ({ isExpanded, }: { control: ControlTag; annotations: AnnotationSummary[]; isExpanded: boolean }) => { const allResults = annotations.flatMap((ann) => ann.results.filter((r) => r.from_name === control.name)); - const totalAnnotations = annotations.length; + // Exclude predictions for percentage denominator to match backend TaskAgreementAPI + const totalAnnotations = annotations.filter((a) => a.type === "annotation").length; if (!allResults.length) { - return No data; + return N/A; } // Handle labels-type controls (rectanglelabels, polygonlabels, labels, etc.) @@ -138,12 +162,12 @@ export const AggregationCell = ({ ); } - // Handle rating - calculate average rating across all annotations + // Handle rating - average over annotations that have a value (matches backend TaskAgreementAPI) if (control.type === "rating") { const ratings = allResults.map((r) => resultValue(r)).filter(Boolean); if (!ratings.length) return No ratings; - const avgRating = ratings.reduce((sum, val) => sum + val, 0) / totalAnnotations; + const avgRating = ratings.reduce((sum, val) => sum + val, 0) / ratings.length; return ( Avg: {avgRating.toFixed(1)} @@ -151,12 +175,12 @@ export const AggregationCell = ({ ); } - // Handle number - calculate average number value across all annotations + // Handle number - average over annotations that have a value (matches backend TaskAgreementAPI) if (control.type === "number") { const numbers = allResults.map((r) => resultValue(r)).filter((v) => v !== null && v !== undefined); if (!numbers.length) return No data; - const avg = numbers.reduce((sum, val) => sum + Number(val), 0) / totalAnnotations; + const avg = numbers.reduce((sum, val) => sum + Number(val), 0) / numbers.length; return ( Avg: {avg.toFixed(1)} @@ -168,24 +192,118 @@ export const AggregationCell = ({ return N/A; }; +const DistributionSkeleton = () => ( +
+
+
+
+); + +const ApiAggregationCell = ({ + control, + distribution, + totalAnnotations, + isExpanded, +}: { + control: ControlTag; + distribution?: { type: string; labels: Record; average?: number; count?: number }; + totalAnnotations: number; + isExpanded: boolean; +}) => { + if (!distribution || Object.keys(distribution.labels).length === 0) { + // Check if it's a numeric type with average + if (distribution?.average !== undefined) { + return ( + + Avg: {distribution.average.toFixed(1)} + {distribution.type === "rating" && } + + ); + } + return N/A; + } + + // Sort labels by count descending + const sortedLabels = Object.entries(distribution.labels).sort(([, a], [, b]) => b - a); + + // Handle choices/taxonomy with percentages + if (distribution.type === "choices" || distribution.type === "taxonomy") { + return ( +
+ {sortedLabels.map(([label, count]) => ( + + {label} + + ))} +
+ ); + } + + // Handle labels and other types with counts + return ( +
+ {sortedLabels.map(([label, count]) => ( + + {label} + + ))} +
+ ); +}; + /** * Renders the complete aggregation/distribution row across all columns. * Includes a toggle button in the first cell that only appears when content overflows. * The toggle expands/collapses the cells to show full content. + * + * With lazy loading, fetches distribution from dedicated API endpoint + * for efficient aggregation without N+1 queries. */ export const AggregationTableRow = ({ headers, controls, annotations, + taskId, }: { headers: Header[]; controls: ControlTag[]; annotations: AnnotationSummary[]; + taskId?: number | string; }) => { const [isExpanded, setIsExpanded] = useState(false); const [hasOverflow, setHasOverflow] = useState(false); const rowRef = useRef(null); + // For non-lazy loading mode, compute from annotations as before + const useApiData = isActive(FF_FIT_720_LAZY_LOAD_ANNOTATIONS) && taskId; + + const { + data: distributionData, + isLoading, + error, + } = useQuery({ + queryKey: ["task-agreement", taskId], + queryFn: () => fetchDistribution(taskId!), + enabled: useApiData && !!taskId, + staleTime: 30000, // Consider data fresh for 30 seconds + gcTime: 5 * 60 * 1000, // Keep in cache for 5 minutes (formerly cacheTime) + }); + useLayoutEffect(() => { if (!rowRef.current) return; @@ -196,7 +314,7 @@ export const AggregationTableRow = ({ }); setHasOverflow(hasOverflowingCells); - }, [annotations, controls]); + }, [annotations, controls, distributionData]); return ( @@ -210,18 +328,26 @@ export const AggregationTableRow = ({ )} style={{ width: header.getSize() }} > - {hasOverflow ? ( - - ) : ( - Distribution - )} +
+ {hasOverflow ? ( + + ) : ( + Distribution + )} + {/* Show total count from API */} + {useApiData && distributionData && ( + + {distributionData.total_annotations} annotations + + )} +
) : ( - + {useApiData && isLoading ? ( + + ) : useApiData && error ? ( + Failed to load + ) : useApiData && distributionData ? ( + + ) : ( + + )} ), )} diff --git a/web/libs/editor/src/components/TaskSummary/LabelingSummary.tsx b/web/libs/editor/src/components/TaskSummary/LabelingSummary.tsx index d9745b83266b..a57a8f20b7dc 100644 --- a/web/libs/editor/src/components/TaskSummary/LabelingSummary.tsx +++ b/web/libs/editor/src/components/TaskSummary/LabelingSummary.tsx @@ -27,6 +27,7 @@ type Props = { controls: ControlTag[]; onSelect: (entity: AnnotationSummary) => void; hideInfo: boolean; + taskId?: number | string; }; const cellFn = (control: ControlTag, render: RendererType) => (props: { row: Row }) => { @@ -57,7 +58,7 @@ const convertPredictionResult = (result: MSTResult) => { const columnHelper = createColumnHelper(); -export const LabelingSummary = ({ hideInfo, annotations: all, controls, onSelect }: Props) => { +export const LabelingSummary = ({ hideInfo, annotations: all, controls, onSelect, taskId }: Props) => { const currentUser = window.APP_SETTINGS?.user; const [columnWidths, setColumnWidths] = useState>({}); const tableRef = useRef(null); @@ -218,6 +219,7 @@ export const LabelingSummary = ({ hideInfo, annotations: all, controls, onSelect headers={table.getHeaderGroups()[0]?.headers ?? []} controls={controls} annotations={annotations} + taskId={taskId} /> )} {/* Annotation Rows */} diff --git a/web/libs/editor/src/components/TaskSummary/TaskSummary.tsx b/web/libs/editor/src/components/TaskSummary/TaskSummary.tsx index e73c02bf910f..74aad025e407 100644 --- a/web/libs/editor/src/components/TaskSummary/TaskSummary.tsx +++ b/web/libs/editor/src/components/TaskSummary/TaskSummary.tsx @@ -117,6 +117,7 @@ const TaskSummary = ({ annotations: all, store: annotationStore }: TaskSummaryPr controls={controls} onSelect={onSelect} hideInfo={annotationStore.store.hasInterface("annotations:hide-info")} + taskId={task?.id} />
diff --git a/web/libs/editor/src/components/TaskSummary/__tests__/TaskSummary.test.tsx b/web/libs/editor/src/components/TaskSummary/__tests__/TaskSummary.test.tsx index bf0f46aa7c43..8e658bc984bf 100644 --- a/web/libs/editor/src/components/TaskSummary/__tests__/TaskSummary.test.tsx +++ b/web/libs/editor/src/components/TaskSummary/__tests__/TaskSummary.test.tsx @@ -1,7 +1,21 @@ +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; +import type { ReactElement } from "react"; import { render, screen } from "@testing-library/react"; import type { MSTAnnotation, MSTStore } from "../../../stores/types"; import TaskSummary from "../TaskSummary"; +const createTestQueryClient = () => + new QueryClient({ + defaultOptions: { + queries: { retry: false }, + }, + }); + +const renderWithQueryClient = (ui: ReactElement) => { + const queryClient = createTestQueryClient(); + return render({ui}); +}; + // Polyfill for Object.groupBy which may not be available in test environment if (!Object.groupBy) { Object.groupBy = ( @@ -165,7 +179,7 @@ describe("TaskSummary", () => { const annotations = [createMockAnnotation()]; const store = createMockStore(); - render(); + renderWithQueryClient(); expect(screen.getByText("Task Summary")).toBeInTheDocument(); expect(screen.getByText("Task Data")).toBeInTheDocument(); @@ -181,7 +195,7 @@ describe("TaskSummary", () => { }, }); - render(); + renderWithQueryClient(); expect(screen.getByText("Agreement")).toBeInTheDocument(); expect(screen.getByText("85.5%")).toBeInTheDocument(); @@ -197,7 +211,7 @@ describe("TaskSummary", () => { }, }); - render(); + renderWithQueryClient(); // Backend controls agreement visibility, so if we have a number, show it expect(screen.getByText("Agreement")).toBeInTheDocument(); @@ -210,7 +224,7 @@ describe("TaskSummary", () => { project: null, }); - render(); + renderWithQueryClient(); // Backend controls agreement visibility, so if we have a number, show it 
expect(screen.getByText("Agreement")).toBeInTheDocument(); @@ -226,7 +240,7 @@ describe("TaskSummary", () => { ]; const store = createMockStore(); - render(); + renderWithQueryClient(); expect(screen.getByText("Annotations")).toBeInTheDocument(); expect(screen.getByText("2")).toBeInTheDocument(); // Only submitted annotations @@ -241,7 +255,7 @@ describe("TaskSummary", () => { ]; const store = createMockStore(); - render(); + renderWithQueryClient(); expect(screen.getByText("Predictions")).toBeInTheDocument(); expect(screen.getByText("2")).toBeInTheDocument(); // Only submitted predictions @@ -266,7 +280,7 @@ describe("TaskSummary", () => { ]), }); - render(); + renderWithQueryClient(); expect(screen.getByText("Annotator")).toBeInTheDocument(); expect(screen.getByText("sentiment")).toBeInTheDocument(); @@ -288,7 +302,7 @@ describe("TaskSummary", () => { ]), }); - render(); + renderWithQueryClient(); // Object tags should appear in the data summary (as header and badge) expect(screen.getAllByText("text")).toHaveLength(2); // header + badge @@ -299,7 +313,7 @@ describe("TaskSummary", () => { const annotations: MSTAnnotation[] = []; const store = createMockStore(); - render(); + renderWithQueryClient(); // Should show 0 for both annotations and predictions expect(screen.getByText("Annotations")).toBeInTheDocument(); @@ -322,7 +336,7 @@ describe("TaskSummary", () => { }, }); - render(); + renderWithQueryClient(); // Should not display agreement when it's undefined expect(screen.queryByText("Agreement")).not.toBeInTheDocument(); @@ -351,7 +365,7 @@ describe("TaskSummary", () => { names: new Map([controlWithPerRegion]), }); - render(); + renderWithQueryClient(); expect(screen.getByText("regionLabel")).toBeInTheDocument(); }); @@ -367,7 +381,7 @@ describe("TaskSummary", () => { ]), }); - render(); + renderWithQueryClient(); // Only valid object tags with $ prefix should appear (as header and badge) expect(screen.getAllByText("text")).toHaveLength(2); // header + badge