Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pre-commit-dev.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ repos:
rev: v0.9.1
hooks:
- id: blue
args: [ --verbose ]
6 changes: 0 additions & 6 deletions label_studio/core/all_urls.json
Original file line number Diff line number Diff line change
Expand Up @@ -946,12 +946,6 @@
"module": "ml.api.MLBackendTrainAPI",
"name": "ml:api:ml-train",
"decorators": ""
},
{
"url": "/api/ml/<int:pk>/predict/test",
"module": "ml.api.MLBackendPredictAPI",
"name": "ml:api:ml-predict-test",
"decorators": ""
},
{
"url": "/api/ml/<int:pk>/interactive-annotating",
Expand Down
5 changes: 2 additions & 3 deletions label_studio/data_manager/actions/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from core.permissions import AllPermissions
from core.redis import start_job_async_or_sync
from core.utils.common import load_func
from data_manager.functions import retrieve_predictions
from data_manager.functions import evaluate_predictions
from django.conf import settings
from projects.models import Project
from tasks.functions import update_tasks_counters
Expand All @@ -24,7 +24,7 @@ def retrieve_tasks_predictions(project, queryset, **kwargs):
:param project: project instance
:param queryset: filtered tasks db queryset
"""
retrieve_predictions(queryset)
evaluate_predictions(queryset)
return {'processed_items': queryset.count(), 'detail': 'Retrieved ' + str(queryset.count()) + ' predictions'}


Expand Down Expand Up @@ -138,7 +138,6 @@ def async_project_summary_recalculation(tasks_ids_list, project_id):
'title': 'Retrieve Predictions',
'order': 90,
'dialog': {
'modal_title': 'Retrieve Predictions',
'text': 'Send the selected tasks to all ML backends connected to the project.'
'This operation might be abruptly interrupted due to a timeout. '
'The recommended way to get predictions is to update tasks using the Label Studio API.'
Expand Down
1 change: 0 additions & 1 deletion label_studio/data_manager/actions/next_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def next_task(project, queryset, **kwargs):
# serialize task
context = {'request': request, 'project': project, 'resolve_uri': True, 'annotations': False}
serializer = NextTaskSerializer(next_task, context=context)

response = serializer.data
response['queue'] = queue_info
return response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def predictions_to_annotations_form(user, project):
{
'type': 'select',
'name': 'model_version',
'label': 'Choose predictions',
'label': 'Choose a model',
'options': versions,
}
],
Expand All @@ -95,10 +95,8 @@ def predictions_to_annotations_form(user, project):
'title': 'Create Annotations From Predictions',
'order': 91,
'dialog': {
'modal_title': 'Create Annotations From Predictions',
'text': 'Create annotations from predictions using selected predictions set '
'for each selected task.'
'Your account will be assigned as an owner to those annotations. ',
'text': 'This action will create new annotations from predictions with the selected model version '
'for each selected task.',
'type': 'confirm',
'form': predictions_to_annotations_form,
},
Expand Down
24 changes: 7 additions & 17 deletions label_studio/data_manager/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from core.utils.common import int_from_request, load_func
from core.utils.params import bool_from_request
from data_manager.actions import get_all_actions, perform_action
from data_manager.functions import get_prepare_params, get_prepared_queryset
from data_manager.functions import evaluate_predictions, get_prepare_params, get_prepared_queryset
from data_manager.managers import get_fields_for_evaluation
from data_manager.models import View
from data_manager.serializers import DataManagerTaskSerializer, ViewResetSerializer, ViewSerializer
Expand Down Expand Up @@ -255,15 +255,11 @@ def get(self, request):
# keep ids ordering
page = [tasks_by_ids[_id] for _id in ids]

# TODO MM TODO this needs a discussion, because I'd expect
# people to retrieve manually instead on DM load, plus it
# will slow down initial DM load

# retrieve ML predictions if tasks don't have them
# if not review and project.retrieve_predictions_automatically:
# tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
# retrieve_predictions(tasks_for_predictions)
# [tasks_by_ids[_id].refresh_from_db() for _id in ids]
if not review and project.evaluate_predictions_automatically:
tasks_for_predictions = Task.objects.filter(id__in=ids, predictions__isnull=True)
evaluate_predictions(tasks_for_predictions)
[tasks_by_ids[_id].refresh_from_db() for _id in ids]

if flag_set('fflag_fix_back_leap_24_tasks_api_optimization_05092023_short'):
serializer = self.task_serializer_class(
Expand All @@ -275,18 +271,13 @@ def get(self, request):
else:
serializer = self.task_serializer_class(page, many=True, context=context)
return self.get_paginated_response(serializer.data)

# TODO
# all tasks
# if project.retrieve_predictions_automatically:
# retrieve_predictions(queryset.filter(predictions__isnull=True))

if project.evaluate_predictions_automatically:
evaluate_predictions(queryset.filter(predictions__isnull=True))
queryset = Task.prepared.annotate_queryset(
queryset, fields_for_evaluation=fields_for_evaluation, all_fields=all_fields, request=request
)

serializer = self.task_serializer_class(queryset, many=True, context=context)

return Response(serializer.data)


Expand Down Expand Up @@ -386,7 +377,6 @@ def post(self, request):
# perform action and return the result dict
kwargs = {'request': request} # pass advanced params to actions
result = perform_action(action_id, project, queryset, request.user, **kwargs)

code = result.pop('response_code', 200)

return Response(result, status=code)
19 changes: 6 additions & 13 deletions label_studio/data_manager/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,23 +314,16 @@ def get_prepared_queryset(request, project):
return queryset


def retrieve_predictions(tasks, backend=None):
"""Call ML backend to retrieve predictions with the task queryset as an input"""
def evaluate_predictions(tasks):
"""Call ML backend for prediction evaluation of the task queryset"""
if not tasks:
return

if not backend:
project = tasks[0].project
backend = project.ml_backends.first()
project = tasks[0].project

# IMPORTANT change here, ml_backends.all => ml_backends.first
# we are using only one ML backend, not multiple
if backend:
return backend.predict_and_save(tasks=tasks)

# for ml_backend in project.ml_backends.first():
# # tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
# ml_backend.predict_and_save(tasks=tasks)
for ml_backend in project.ml_backends.all():
# tasks = tasks.filter(~Q(predictions__model_version=ml_backend.model_version))
ml_backend.predict_tasks(tasks)


def filters_ordering_selected_items_exist(data):
Expand Down
119 changes: 15 additions & 104 deletions label_studio/ml/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from core.feature_flags import flag_set
from core.permissions import ViewClassPermission, all_permissions
from django.conf import settings
from django.http import Http404
from django.utils.decorators import method_decorator
from django_filters.rest_framework import DjangoFilterBackend
from drf_yasg.utils import swagger_auto_schema
Expand Down Expand Up @@ -77,11 +76,10 @@ class MLBackendListAPI(generics.ListCreateAPIView):
def get_queryset(self):
project_pk = self.request.query_params.get('project')
project = generics.get_object_or_404(Project, pk=project_pk)

self.check_object_permissions(self.request, project)

ml_backends = project.update_ml_backends_state()

ml_backends = MLBackend.objects.filter(project_id=project.id)
for mlb in ml_backends:
mlb.update_state()
return ml_backends

def perform_create(self, serializer):
Expand Down Expand Up @@ -204,58 +202,6 @@ def post(self, request, *args, **kwargs):
return Response(status=status.HTTP_200_OK)


@method_decorator(
name='post',
decorator=swagger_auto_schema(
tags=['Machine Learning'],
operation_summary='Predict',
operation_description="""
After you add an ML backend, call this API with the ML backend ID to run a test prediction on specific task data
""",
manual_parameters=[
openapi.Parameter(
name='id',
type=openapi.TYPE_INTEGER,
in_=openapi.IN_PATH,
description='A unique integer value identifying this ML backend.',
),
],
responses={
200: openapi.Response(title='Predicting OK', description='Predicting has successfully started.'),
500: openapi.Response(
description='Predicting error',
schema=openapi.Schema(
title='Error message',
description='Error message',
type=openapi.TYPE_STRING,
example='Server responded with an error.',
),
),
},
),
)
class MLBackendPredictTestAPI(APIView):
serializer_class = MLBackendSerializer
permission_required = all_permissions.projects_change

def post(self, request, *args, **kwargs):
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
self.check_object_permissions(self.request, ml_backend)

random = request.query_params.get('random', False)
if random:
task = Task.get_random(project=ml_backend.project)
if not task:
raise Http404

kwargs = ml_backend._predict(task)
return Response(**kwargs)

# TODO this needs to be implemented and needs to have a specific task param
ml_backend.predict()
return Response(status=status.HTTP_200_OK)


@method_decorator(
name='post',
decorator=swagger_auto_schema(
Expand All @@ -281,64 +227,29 @@ def post(self, request, *args, **kwargs):
),
)
class MLBackendInteractiveAnnotating(APIView):
""" """

permission_required = all_permissions.tasks_view

def _error_response(self, message, log_function=logger.info):
""" """
log_function(message)
return Response({'errors': [message]}, status=status.HTTP_200_OK)

def _get_task(self, ml_backend, validated_data):
""" """
return generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)

def _get_credentials(self, request, context, project):
""" """
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
context.update(
project_credentials_login=project.task_data_login,
project_credentials_password=project.task_data_password,
)
return context

def _get_ml_results(self, ml_api_result):
""" """
results = ml_api_result.response.get('results', [None])
if isinstance(results, list) and len(results) >= 1:
return results[0]

return None

def post(self, request, *args, **kwargs):
""" """
ml_backend = generics.get_object_or_404(MLBackend, pk=self.kwargs['pk'])
self.check_object_permissions(request, ml_backend)

self.check_object_permissions(self.request, ml_backend)
serializer = MLInteractiveAnnotatingRequest(data=request.data)
serializer.is_valid(raise_exception=True)
validated_data = serializer.validated_data

task = self._get_task(ml_backend, serializer.validated_data)
context = self._get_credentials(request, serializer.validated_data.get('context', {}), task.project)

ml_api_result = ml_backend.interactive_annotating(task, context, user=self.request.user)
task = generics.get_object_or_404(Task, pk=validated_data['task'], project=ml_backend.project)
context = validated_data.get('context')

if ml_api_result.is_error:
message = f'Prediction not created for project {self}: {ml_api_result.error_message}'
return self._error_response(message)

if not isinstance(ml_api_result.response, dict) or 'results' not in ml_api_result.response:
message = f'Incorrect response from ML service it must be a dict and contain "results" key: {ml_api_result.response}'
return self._error_response(message)

ml_results = self._get_ml_results(ml_api_result)
if flag_set('ff_back_dev_2362_project_credentials_060722_short', request.user):
context['project_credentials_login'] = task.project.task_data_login
context['project_credentials_password'] = task.project.task_data_password

if not ml_results:
message = f'ML backend has to return a list with at least 1 annotation but it returned: {type(ml_results)}'
return self._error_response(message, logger.warning)
result = ml_backend.interactive_annotating(task, context, user=request.user)

return Response({'data': ml_results}, status=status.HTTP_200_OK)
return Response(
result,
status=status.HTTP_200_OK,
)


@method_decorator(
Expand Down
Loading