Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
4ae5e6e
Updated implementation of the ML Backend experience. Docs & Test to f…
Feb 4, 2024
72da7a9
Updating the experience a bit more based on the feedback. Plus
Feb 18, 2024
e6fc978
minor updates on the experience
Mar 4, 2024
9b1f263
removing debug info, uncommenting bits
Mar 4, 2024
bf134a6
Merge branch 'develop' into dev/ml-backend-exp
deppp Mar 5, 2024
12a7227
fixing small issue to make it backward compatible with the previous p…
Mar 5, 2024
8eec2cd
Merge branch 'dev/ml-backend-exp' of https://github.com/heartexlabs/l…
Mar 5, 2024
9d314e7
Fix errors, code cleanup, fix ruff
Mar 7, 2024
fb48fc7
Downgrade testing-library, reformat with linters
Mar 7, 2024
9211b86
Remove excessive calls in task api
Mar 7, 2024
3b082f2
Fix frontend script
Mar 7, 2024
a14857d
ci: Build frontend
robot-ci-heartex Mar 7, 2024
1379c5c
fmt
jombooth Mar 7, 2024
7b8825d
fix sdk version & change api /predict/test
Mar 8, 2024
2b6f5c4
ci: Build frontend
robot-ci-heartex Mar 8, 2024
547c9a2
try update lock
Mar 8, 2024
4983250
Merge branch 'develop' into 'fb-ml-backend-exp'
bmartel Mar 8, 2024
38b87e1
[submodules] Copy src HumanSignal/dm2
bmartel Mar 8, 2024
753b561
ci: Build frontend
robot-ci-heartex Mar 8, 2024
d4b4bb3
Address review comments with stylistic changes, remove unusable code,…
Mar 9, 2024
03f9056
Additional changes
Mar 9, 2024
9005c57
Add svg icon file
Mar 9, 2024
e9a3983
ci: Build frontend
robot-ci-heartex Mar 9, 2024
c972f98
Fix ml/predict/test api
Mar 9, 2024
b55d6d3
Running a formatter/lint on the code
bmartel Mar 11, 2024
7d6f8ea
ci: Build frontend
robot-ci-heartex Mar 11, 2024
9ef194b
Running a formatter/lint on the code
bmartel Mar 11, 2024
97e5de6
Merge remote-tracking branch 'origin/fb-ml-backend-exp' into fb-ml-ba…
bmartel Mar 11, 2024
4004226
ci: Build frontend
robot-ci-heartex Mar 11, 2024
a644e14
Running a formatter/lint on the code
bmartel Mar 11, 2024
dbeb40a
Merge remote-tracking branch 'origin/fb-ml-backend-exp' into fb-ml-ba…
bmartel Mar 11, 2024
42e5975
Running a formatter/lint on the code
bmartel Mar 11, 2024
c0e12f4
ci: Build frontend
robot-ci-heartex Mar 11, 2024
023ff5a
Handle security on password input
Mar 11, 2024
8a4e568
ci: Build frontend
robot-ci-heartex Mar 11, 2024
b3023ce
Display default password when not specified
Mar 11, 2024
478a54b
ci: Build frontend
robot-ci-heartex Mar 11, 2024
efd96c9
Reduce number of calls due to project id cache
Mar 11, 2024
911c1e3
ci: Build frontend
robot-ci-heartex Mar 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pre-commit-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,4 @@ repos:
rev: v0.9.1
hooks:
- id: blue
args: [ --verbose ]
6 changes: 6 additions & 0 deletions label_studio/core/all_urls.json
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,12 @@
"module": "ml.api.MLBackendTrainAPI",
"name": "ml:api:ml-train",
"decorators": ""
},
{
"url": "/api/ml/<int:pk>/predict/test",
"module": "ml.api.MLBackendPredictAPI",
"name": "ml:api:ml-predict-test",
"decorators": ""
},
{
"url": "/api/ml/<int:pk>/interactive-annotating",
Expand Down
5 changes: 3 additions & 2 deletions label_studio/data_manager/actions/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from core.permissions import AllPermissions
from core.redis import start_job_async_or_sync
from core.utils.common import load_func
from data_manager.functions import evaluate_predictions
from data_manager.functions import retrieve_predictions
from django.conf import settings
from projects.models import Project
from tasks.functions import update_tasks_counters
Expand All @@ -24,7 +24,7 @@ def retrieve_tasks_predictions(project, queryset, **kwargs):
:param project: project instance
:param queryset: filtered tasks db queryset
"""
evaluate_predictions(queryset)
retrieve_predictions(queryset)
return {'processed_items': queryset.count(), 'detail': 'Retrieved ' + str(queryset.count()) + ' predictions'}


Expand Down Expand Up @@ -138,6 +138,7 @@ def async_project_summary_recalculation(tasks_ids_list, project_id):
'title': 'Retrieve Predictions',
'order': 90,
'dialog': {
'modal_title': 'Retrieve Predictions',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modal_title -> title

'text': 'Send the selected tasks to all ML backends connected to the project.'
'This operation might be abruptly interrupted due to a timeout. '
'The recommended way to get predictions is to update tasks using the Label Studio API.'
Expand Down
1 change: 1 addition & 0 deletions label_studio/data_manager/actions/next_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def next_task(project, queryset, **kwargs):
# serialize task
context = {'request': request, 'project': project, 'resolve_uri': True, 'annotations': False}
serializer = NextTaskSerializer(next_task, context=context)

response = serializer.data
response['queue'] = queue_info
return response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def predictions_to_annotations_form(user, project):
{
'type': 'select',
'name': 'model_version',
'label': 'Choose a model',
'label': 'Choose predictions',
'options': versions,
}
],
Expand All @@ -95,8 +95,10 @@ def predictions_to_annotations_form(user, project):
'title': 'Create Annotations From Predictions',
'order': 91,
'dialog': {
'text': 'This action will create new annotations from predictions with the selected model version '
'for each selected task.',
'modal_title': 'Create Annotations From Predictions',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modal_title -> title

'text': 'Create annotations from predictions using selected predictions set '
'for each selected task.'
'Your account will be assigned as an owner to those annotations. ',
'type': 'confirm',
'form': predictions_to_annotations_form,
},
Expand Down
24 changes: 17 additions & 7 deletions label_studio/data_manager/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from core.utils.common import int_from_request, load_func
from core.utils.params import bool_from_request
from data_manager.actions import get_all_actions, perform_action
from data_manager.functions import evaluate_predictions, get_prepare_params, get_prepared_queryset
from data_manager.functions import get_prepare_params, get_prepared_queryset
from data_manager.managers import get_fields_for_evaluation
from data_manager.models import View
from data_manager.serializers import DataManagerTaskSerializer, ViewResetSerializer, ViewSerializer
Expand Down Expand Up @@ -255,11 +255,15 @@ def get(self, request):
# keep ids ordering
page = [tasks_by_ids[_id] for _id in ids]

# TODO(MM): this needs a discussion, because I'd expect
# people to retrieve predictions manually instead of on DM load; plus it
# will slow down the initial DM load

# retrieve ML predictions if tasks don't have them
if not review and project.evaluate_predictions_automatically:
tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
evaluate_predictions(tasks_for_predictions)
[tasks_by_ids[_id].refresh_from_db() for _id in ids]
# if not review and project.retrieve_predictions_automatically:
# tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
# retrieve_predictions(tasks_for_predictions)
# [tasks_by_ids[_id].refresh_from_db() for _id in ids]

if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
serializer = self.task_serializer_class(
Expand All @@ -271,13 +275,18 @@ def get(self, request):
else:
serializer = self.task_serializer_class(page, many=True, context=context)
return self.get_paginated_response(serializer.data)

# TODO
# all tasks
if project.evaluate_predictions_automatically:
evaluate_predictions(queryset.filter(predictions__isnull=True))
# if project.retrieve_predictions_automatically:
# retrieve_predictions(queryset.filter(predictions__isnull=True))

queryset = Task.prepared.annotate_queryset(
queryset, fields_for_evaluation=fields_for_evaluation, all_fields=all_fields, request=request
)

serializer = self.task_serializer_class(queryset, many=True, context=context)

return Response(serializer.data)


Expand Down Expand Up @@ -377,6 +386,7 @@ def post(self, request):
# perform action and return the result dict
kwargs = {'request': request} # pass advanced params to actions
result = perform_action(action_id, project, queryset, request.user, **kwargs)

code = result.pop('response_code', 200)

return Response(result, status=code)
19 changes: 13 additions & 6 deletions label_studio/data_manager/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,23 @@ def get_prepared_queryset(request, project):
return queryset


def retrieve_predictions(tasks, backend=None):
    """Call an ML backend to retrieve predictions for the given tasks.

    :param tasks: task queryset (or any non-empty sequence of tasks) used as
        the prediction input; the first task's project selects the backend
        when none is given
    :param backend: ML backend to use; when omitted, the first backend
        connected to the tasks' project is used
    :return: whatever ``backend.predict_and_save(tasks=...)`` returns, or
        ``None`` when there are no tasks or no backend is available
    """
    if not tasks:
        return None

    if not backend:
        # IMPORTANT: only one ML backend is used per project now
        # (``ml_backends.first()``), not every connected backend as before.
        backend = tasks[0].project.ml_backends.first()

    if not backend:
        # The project has no ML backend connected: nothing to call.
        return None

    return backend.predict_and_save(tasks=tasks)


def filters_ordering_selected_items_exist(data):
Expand Down
119 changes: 104 additions & 15 deletions label_studio/ml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from core.feature_flags import flag_set
from core.permissions import ViewClassPermission, all_permissions
from django.conf import settings
from django.http import Http404
from django.utils.decorators import method_decorator
from django_filters.rest_framework import DjangoFilterBackend
from drf_yasg.utils import swagger_auto_schema
Expand Down Expand Up @@ -76,10 +77,11 @@ class MLBackendListAPI(generics.ListCreateAPIView):
def get_queryset(self):
    """Return the ML backends of the project named by the ``project`` query parameter.

    Raises a 404 when no project matches the given primary key, and runs the
    object-level permission check against the project before its backends
    are touched.
    """
    project_pk = self.request.query_params.get('project')
    project = generics.get_object_or_404(Project, pk=project_pk)

    self.check_object_permissions(self.request, project)

    # A single call on the project replaces the former
    # ``MLBackend.objects.filter(...)`` + per-backend ``update_state()`` loop,
    # so state is not refreshed twice. NOTE(review): the call is presumed to
    # refresh and return the project's backends -- confirm against Project.
    return project.update_ml_backends_state()

def perform_create(self, serializer):
Expand Down Expand Up @@ -202,6 +204,58 @@ def post(self, request, *args, **kwargs):
return Response(status=status.HTTP_200_OK)


@method_decorator(
    name='post',
    decorator=swagger_auto_schema(
        tags=['Machine Learning'],
        operation_summary='Predict',
        operation_description="""
        After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
        """,
        manual_parameters=[
            openapi.Parameter(
                name='id',
                type=openapi.TYPE_INTEGER,
                in_=openapi.IN_PATH,
                description='A unique integer value identifying this ML backend.',
            ),
        ],
        responses={
            200: openapi.Response(title='Predicting OK', description='Predicting has successfully started.'),
            500: openapi.Response(
                description='Predicting error',
                schema=openapi.Schema(
                    title='Error message',
                    description='Error message',
                    type=openapi.TYPE_STRING,
                    example='Server responded with an error.',
                ),
            ),
            501: openapi.Response(
                description='Not implemented: only random=true test predictions are supported.',
            ),
        },
    ),
)
class MLBackendPredictTestAPI(APIView):
    """Run a test prediction against an ML backend using a task of its project.

    Only the ``?random=true`` mode is implemented: a random task is picked
    from the backend's project and sent to the backend. Any other request is
    answered with 501 rather than pretending a prediction was started.
    """

    serializer_class = MLBackendSerializer
    permission_required = all_permissions.projects_change

    def post(self, request, *args, **kwargs):
        """Predict on a random project task and relay the backend's answer.

        Raises ``Http404`` when the backend does not exist or its project has
        no task to pick.
        """
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(self.request, ml_backend)

        # Query parameters arrive as strings, so "false"/"0" must be parsed
        # explicitly -- a bare truthiness test would treat them as enabled.
        raw_flag = str(request.query_params.get('random', '')).strip().lower()
        use_random_task = raw_flag not in ('', '0', 'false', 'no', 'off')

        if not use_random_task:
            # TODO: predicting on a specific, caller-chosen task is not
            # implemented yet and needs a dedicated task parameter.
            return Response(
                status=status.HTTP_501_NOT_IMPLEMENTED,
                data={'error': 'Not implemented - you must provide random=true query parameter'},
            )

        task = Task.get_random(project=ml_backend.project)
        if not task:
            raise Http404

        # NOTE(review): ``_predict`` is expected to return a dict of
        # ``Response`` keyword arguments -- confirm against MLBackend.
        response_kwargs = ml_backend._predict(task)
        return Response(**response_kwargs)


@method_decorator(
name='post',
decorator=swagger_auto_schema(
Expand All @@ -227,29 +281,64 @@ def post(self, request, *args, **kwargs):
),
)
class MLBackendInteractiveAnnotating(APIView):
    """Forward an interactive-annotation request for a task to an ML backend.

    The backend is called once per request and the first entry of its
    ``results`` list is returned under the ``data`` key. Backend-side
    failures are reported in an ``errors`` list, still with HTTP 200.
    """

    permission_required = all_permissions.tasks_view

    def _error_response(self, message, log_function=logger.info):
        """Log ``message`` with ``log_function`` and wrap it in an HTTP 200 ``errors`` response."""
        log_function(message)
        return Response({'errors': [message]}, status=status.HTTP_200_OK)

    def _get_task(self, ml_backend, validated_data):
        """Fetch the requested task, restricted to the backend's project (404 otherwise)."""
        return generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)

    def _get_credentials(self, request, context, project):
        """Add the project's task-data credentials to ``context`` when the feature flag is on.

        ``context`` is updated in place and also returned.
        """
        if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
            context.update(
                project_credentials_login=project.task_data_login,
                project_credentials_password=project.task_data_password,
            )
        return context

    def _get_ml_results(self, ml_api_result):
        """Return the first item of the backend's ``results`` list, or ``None`` if there is none."""
        results = ml_api_result.response.get('results', [None])
        if isinstance(results, list) and len(results) >= 1:
            return results[0]

        return None

    def post(self, request, *args, **kwargs):
        """Validate the request, call the backend once and return its first result."""
        ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
        self.check_object_permissions(request, ml_backend)

        serializer = MLInteractiveAnnotatingRequest(data=request.data)
        serializer.is_valid(raise_exception=True)

        task = self._get_task(ml_backend, serializer.validated_data)
        context = self._get_credentials(request, serializer.validated_data.get('context', {}), task.project)

        # Single backend call; its result object is inspected below.
        ml_api_result = ml_backend.interactive_annotating(task, context, user=request.user)

        if ml_api_result.is_error:
            # Name the project in the message, not this view instance.
            message = f'Prediction not created for project {task.project}: {ml_api_result.error_message}'
            return self._error_response(message)

        if not isinstance(ml_api_result.response, dict) or 'results' not in ml_api_result.response:
            message = f'Incorrect response from ML service it must be a dict and contain "results" key: {ml_api_result.response}'
            return self._error_response(message)

        ml_results = self._get_ml_results(ml_api_result)

        if not ml_results:
            message = f'ML backend has to return a list with at least 1 annotation but it returned: {type(ml_results)}'
            return self._error_response(message, logger.warning)

        return Response({'data': ml_results}, status=status.HTTP_200_OK)


@method_decorator(
Expand Down
Loading