diff --git a/jobs/jobs/config.py b/jobs/jobs/config.py
index 4ec64c6bd..f3095169d 100644
--- a/jobs/jobs/config.py
+++ b/jobs/jobs/config.py
@@ -54,6 +54,9 @@ def get_service_uri(prefix: str) -> str:  # noqa
 PAGINATION_THRESHOLD = 7
 PROVIDE_JWT_IF_NO_ANY = True
 
+# request settings
+REQUEST_TIMEOUT = 60  # in seconds
+
 # S3 settings
 STORAGE_PROVIDER = os.getenv("STORAGE_PROVIDER")
 JOBS_SIGNED_URL_ENABLED = (
diff --git a/jobs/jobs/main.py b/jobs/jobs/main.py
index 6e98313a1..122b69a9f 100644
--- a/jobs/jobs/main.py
+++ b/jobs/jobs/main.py
@@ -72,6 +72,13 @@ async def create_job(
         )
     )
 
+    if len(job_params.revisions) > 0:
+        await utils.update_create_job_params_using_revisions(
+            job_params=job_params,
+            current_tenant=current_tenant,
+            jwt_token=jw_token,
+        )
+
     if job_params.type == schemas.JobType.ExtractionJob:
         created_extraction_job = await create_job_funcs.create_extraction_job(
             extraction_job_input=job_params,  # type: ignore
diff --git a/jobs/jobs/utils.py b/jobs/jobs/utils.py
index 35fc28ee2..a796f0014 100644
--- a/jobs/jobs/utils.py
+++ b/jobs/jobs/utils.py
@@ -7,7 +7,7 @@
 import jobs.airflow_utils as airflow_utils
 import jobs.databricks_utils as databricks_utils
 import jobs.pipeline as pipeline
-from jobs import db_service
+from jobs import db_service, schemas
 from jobs.config import (
     ANNOTATION_SERVICE_HOST,
     ASSETS_SERVICE_HOST,
@@ -16,6 +16,7 @@
     JOBS_SIGNED_URL_KEY_NAME,
     PAGINATION_THRESHOLD,
     PIPELINES_SERVICE_HOST,
+    REQUEST_TIMEOUT,
     ROOT_PATH,
     TAXONOMY_SERVICE_HOST,
     USERS_HOST,
@@ -572,11 +573,8 @@ async def get_job_progress(
         "X-Current-Tenant": current_tenant,
         "Authorization": f"Bearer: {jw_token}",
     }
-    timeout = aiohttp.ClientTimeout(total=5)
     try:
-        _, response = await fetch(
-            method="GET", url=url, headers=headers, timeout=timeout
-        )
+        _, response = await fetch(method="GET", url=url, headers=headers)
     except aiohttp.client_exceptions.ClientError as err:
         logger.exception(f"Failed request url = {url}, error = {err}")
         raise fastapi.HTTPException(
@@ -596,6 +594,8 @@ async def fetch(
     headers: Optional[Dict[str, Any]] = None,
     **kwargs: Any,
 ) -> Tuple[int, Any]:
+    if "timeout" not in kwargs:
+        kwargs["timeout"] = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT)
     async with aiohttp.request(
         method=method, url=url, json=body, headers=headers, data=data, **kwargs
     ) as resp:
@@ -767,13 +767,11 @@ async def get_annotation_revisions(
         "X-Current-Tenant": current_tenant,
         "Authorization": f"Bearer: {jw_token}",
     }
-    timeout = aiohttp.ClientTimeout(total=5)
     try:
         _, response = await fetch(
             method="GET",
             url=f"{ANNOTATION_SERVICE_HOST}/revisions/{job_id}/{file_id}",
             headers=headers,
-            timeout=timeout,
         )
     except aiohttp.client_exceptions.ClientError as err:
         logger.exception(
@@ -788,6 +786,42 @@
     return response
 
 
+async def get_annotations_by_revisions(
+    current_tenant: Optional[str], jw_token: str, revisions: List[str]
+) -> Optional[Dict[str, Any]]:
+    """Get annotations filtered by the given revision ids"""
+
+    headers = {
+        "X-Current-Tenant": current_tenant,
+        "Authorization": f"Bearer: {jw_token}",
+    }
+
+    post_data = {
+        "filters": [
+            {"field": "revision", "operator": "in", "value": revisions}
+        ]
+    }
+
+    try:
+        _, response = await fetch(
+            method="POST",
+            url=f"{ANNOTATION_SERVICE_HOST}/annotation",
+            headers=headers,
+            body=post_data,
+        )
+    except aiohttp.client_exceptions.ClientError as err:
+        logger.exception(
+            f"Failed request to get annotations by revisions: {revisions}"
+        )
+        raise fastapi.HTTPException(
+            status_code=fastapi.status.HTTP_400_BAD_REQUEST,
+            detail="Could not retrieve selected annotations:"
+            f" {', '.join(revisions)}",
+        ) from err
+
+    return response
+
+
 async def search_datasets_by_ids(
     datasets_ids: List[int], current_tenant: str, jw_token: str
 ) -> Dict[str, Any]:
@@ -860,3 +894,24 @@ async def validate_create_job_previous_jobs(
             detail="Jobs with these ids do not exist.",
         )
     return [j.id for j in previous_jobs]
+
+
+async def update_create_job_params_using_revisions(
+    job_params: schemas.JobParams, current_tenant: str, jwt_token: str
+) -> None:
+    response = await get_annotations_by_revisions(
+        current_tenant=current_tenant,
+        jw_token=jwt_token,
+        revisions=list(job_params.revisions),
+    )
+
+    unique_file_ids_of_revisions = set(
+        [
+            data["file_id"]
+            for data in response.get("data", [])
+            if "file_id" in data
+        ]
+    )
+    job_params.files = list(
+        unique_file_ids_of_revisions.union(job_params.files)
+    )
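Note on the jobs/utils.py changes above: dropping the per-call `aiohttp.ClientTimeout(total=5)` objects is safe because `fetch` now falls back to the shared `REQUEST_TIMEOUT` (60 seconds) whenever the caller passes no explicit `timeout`, while an explicit keyword argument still takes precedence. The new `create_job` flow then resolves each requested revision to its `file_id` via the annotation service and merges those ids into `job_params.files`. A minimal standalone sketch of that merge step, using a hypothetical `response` payload shaped like the annotation-service search result the tests mock:

```python
# Hypothetical search result from POST {ANNOTATION_SERVICE_HOST}/annotation
# with filter {"field": "revision", "operator": "in", "value": [...]}.
response = {
    "data": [
        {"file_id": 2, "revision": "revision_id_1"},
        {"file_id": 3, "revision": "revision_id_2"},
        {"file_id": 3, "revision": "revision_id_3"},  # same file, second revision
    ]
}
files = [1]  # ids already listed in job_params.files

# Deduplicate the file ids carried by the revisions, then union them with the
# explicitly requested files; note that set union guarantees no ordering.
unique_file_ids = {
    row["file_id"] for row in response.get("data", []) if "file_id" in row
}
files = list(unique_file_ids.union(files))
print(sorted(files))  # [1, 2, 3]
```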
diff --git a/jobs/pytest.ini b/jobs/pytest.ini
index 2f6c8d12f..40880458c 100644
--- a/jobs/pytest.ini
+++ b/jobs/pytest.ini
@@ -1,2 +1,2 @@
 [pytest]
-asyncio_mode=strict
+asyncio_mode=auto
diff --git a/jobs/tests/test_utils.py b/jobs/tests/test_utils.py
index af3f104c3..4eff11700 100644
--- a/jobs/tests/test_utils.py
+++ b/jobs/tests/test_utils.py
@@ -1,4 +1,4 @@
-from unittest.mock import patch
+from unittest.mock import AsyncMock, Mock, patch
 
 import aiohttp.client_exceptions
 import pytest
@@ -6,6 +6,7 @@
 from tests.conftest import FakePipeline, patched_create_pre_signed_s3_url
 
 import jobs.utils as utils
+from jobs.schemas import JobParams, JobType
 
 
 # --------------TEST get_files_data_from_datasets---------------
@@ -1193,3 +1194,69 @@ async def test_execute_external_pipeline(sign_s3_links: bool):
         )
     else:
         assert FakePipeline.calls[-1]["files"][0].get("signed_url") is None
+
+
+async def test_update_create_job_params_using_revisions(monkeypatch):
+    job_params = JobParams(
+        name="name_1",
+        type=JobType.ExtractionJob,
+        pipeline_name="pipeline_name_1",
+        files=[1],
+        revisions=["revision_id_1", "revision_id_2", "revision_id_3"],
+    )
+
+    mock_response = {
+        "data": [
+            {"file_id": 2, "revision": "revision_id_1"},
+            {"file_id": 3, "revision": "revision_id_2"},
+            {"file_id": 3, "revision": "revision_id_3"},
+        ]
+    }
+
+    mock_current_tenant = Mock()
+    mock_jwt_token = Mock()
+    mock_get_annotations_by_revisions = AsyncMock(return_value=mock_response)
+
+    monkeypatch.setattr(
+        utils,
+        "get_annotations_by_revisions",
+        mock_get_annotations_by_revisions,
+    )
+
+    await utils.update_create_job_params_using_revisions(
+        job_params, mock_current_tenant, mock_jwt_token
+    )
+
+    mock_get_annotations_by_revisions.assert_called_once()
+
+    assert sorted(job_params.files) == [1, 2, 3]
+
+
+async def test_get_annotations_by_revisions(monkeypatch):
+    revisions = ["revision_id_1", "revision_id_2"]
+
+    mock_fetch_response_status = Mock()
+    mock_fetch_response_json = Mock()
+
+    def mock_fetch_side_effect(**kwargs):
+        assert kwargs["url"].endswith("/annotation")
+        assert kwargs["method"] == "POST"
+        assert kwargs["body"]["filters"][0] == {
+            "field": "revision",
+            "operator": "in",
+            "value": revisions,
+        }
+
+        return mock_fetch_response_status, mock_fetch_response_json
+
+    mock_fetch = AsyncMock(side_effect=mock_fetch_side_effect)
+    mock_current_tenant = Mock()
+    mock_jw_token = Mock()
+
+    monkeypatch.setattr(utils, "fetch", mock_fetch)
+
+    await utils.get_annotations_by_revisions(
+        mock_current_tenant, mock_jw_token, revisions
+    )
+
+    mock_fetch.assert_called_once()
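A note on the pytest.ini change: with pytest-asyncio's `asyncio_mode=auto`, bare `async def` tests such as the two added above are collected and awaited automatically, whereas the previous `strict` mode required an explicit marker on every coroutine test. A minimal sketch of the difference, assuming pytest-asyncio is installed:

```python
import pytest

# Under asyncio_mode=strict, only explicitly marked coroutines run as tests:
@pytest.mark.asyncio
async def test_marked():
    assert True

# Under asyncio_mode=auto (as configured in jobs/pytest.ini above), the marker
# is optional and plain async tests are awaited by the plugin as well:
async def test_unmarked():
    assert True
```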