3 changes: 3 additions & 0 deletions jobs/jobs/config.py
@@ -54,6 +54,9 @@ def get_service_uri(prefix: str) -> str: # noqa
PAGINATION_THRESHOLD = 7
PROVIDE_JWT_IF_NO_ANY = True

# request settings
REQUEST_TIMEOUT = 60 # in seconds

# S3 settings
STORAGE_PROVIDER = os.getenv("STORAGE_PROVIDER")
JOBS_SIGNED_URL_ENABLED = (
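The new constant is a module-level default. If the limit ever needs to vary per deployment, it could follow the env-driven pattern of the neighboring settings; a minimal sketch (reading `REQUEST_TIMEOUT` from the environment is an assumption here, not part of this change):

```python
import os

# Hypothetical variant: overridable via the environment, defaulting to 60 s,
# in the same style as STORAGE_PROVIDER above.
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "60"))  # in seconds
```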
7 changes: 7 additions & 0 deletions jobs/jobs/main.py
@@ -72,6 +72,13 @@ async def create_job(
)
)

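# Merge in files referenced by the requested revisions before dispatching the job.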
if job_params.revisions:
await utils.update_create_job_params_using_revisions(
job_params=job_params,
current_tenant=current_tenant,
jwt_token=jw_token,
)

if job_params.type == schemas.JobType.ExtractionJob:
created_extraction_job = await create_job_funcs.create_extraction_job(
extraction_job_input=job_params, # type: ignore
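The revision branch runs only when `job_params.revisions` is non-empty. A quick sketch of parameters that would exercise it, mirroring the fixture used in the tests below (values are illustrative):

```python
from jobs import schemas

# Only a non-empty `revisions` list triggers the lookup added above.
job_params = schemas.JobParams(
    name="name_1",
    type=schemas.JobType.ExtractionJob,
    pipeline_name="pipeline_name_1",
    files=[1],
    revisions=["revision_id_1", "revision_id_2"],
)
```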
69 changes: 62 additions & 7 deletions jobs/jobs/utils.py
@@ -7,7 +7,7 @@
import jobs.airflow_utils as airflow_utils
import jobs.databricks_utils as databricks_utils
import jobs.pipeline as pipeline
from jobs import db_service
from jobs import db_service, schemas
from jobs.config import (
ANNOTATION_SERVICE_HOST,
ASSETS_SERVICE_HOST,
@@ -16,6 +16,7 @@
JOBS_SIGNED_URL_KEY_NAME,
PAGINATION_THRESHOLD,
PIPELINES_SERVICE_HOST,
REQUEST_TIMEOUT,
ROOT_PATH,
TAXONOMY_SERVICE_HOST,
USERS_HOST,
@@ -572,11 +573,8 @@ async def get_job_progress(
"X-Current-Tenant": current_tenant,
"Authorization": f"Bearer: {jw_token}",
}
timeout = aiohttp.ClientTimeout(total=5)
try:
_, response = await fetch(
method="GET", url=url, headers=headers, timeout=timeout
)
_, response = await fetch(method="GET", url=url, headers=headers)
except aiohttp.client_exceptions.ClientError as err:
logger.exception(f"Failed request url = {url}, error = {err}")
raise fastapi.HTTPException(
@@ -596,6 +594,8 @@ async def fetch(
headers: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> Tuple[int, Any]:
if "timeout" not in kwargs:
kwargs["timeout"] = aiohttp.ClientTimeout(total=REQUEST_TIMEOUT)
async with aiohttp.request(
method=method, url=url, json=body, headers=headers, data=data, **kwargs
) as resp:
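Since the default is injected only when the caller supplies no `timeout` kwarg, call sites can still opt into a tighter limit. A small usage sketch (the wrapper function is hypothetical):

```python
import aiohttp

from jobs.utils import fetch

async def get_with_short_timeout(url: str, headers: dict) -> tuple:
    # An explicit timeout suppresses the REQUEST_TIMEOUT default inside fetch().
    return await fetch(
        method="GET",
        url=url,
        headers=headers,
        timeout=aiohttp.ClientTimeout(total=5),
    )
```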
@@ -767,13 +767,11 @@ async def get_annotation_revisions(
"X-Current-Tenant": current_tenant,
"Authorization": f"Bearer: {jw_token}",
}
timeout = aiohttp.ClientTimeout(total=5)
try:
_, response = await fetch(
method="GET",
url=f"{ANNOTATION_SERVICE_HOST}/revisions/{job_id}/{file_id}",
headers=headers,
timeout=timeout,
)
except aiohttp.client_exceptions.ClientError as err:
logger.exception(
@@ -788,6 +786,42 @@
return response


async def get_annotations_by_revisions(
current_tenant: Optional[str], jw_token: str, revisions: List[str]
) -> Dict[str, Any]:
    """Fetch annotations filtered by the given revision ids."""

headers = {
"X-Current-Tenant": current_tenant,
"Authorization": f"Bearer: {jw_token}",
}

post_data = {
"filters": [
{"field": "revision", "operator": "in", "value": revisions}
]
}

try:
_, response = await fetch(
method="POST",
url=f"{ANNOTATION_SERVICE_HOST}/annotation",
headers=headers,
body=post_data,
)
except aiohttp.client_exceptions.ClientError as err:
logger.exception(
f"Failed request to get annotations by revisions: {revisions}"
)
raise fastapi.HTTPException(
status_code=fastapi.status.HTTP_400_BAD_REQUEST,
detail="Could not retrieve selected annotations:"
f" {', '.join(revisions)}",
) from err

return response


async def search_datasets_by_ids(
datasets_ids: List[int], current_tenant: str, jw_token: str
) -> Dict[str, Any]:
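For reference, a hedged sketch of consuming the new helper; the tenant, token, and revision ids are placeholders, and the response shape (a "data" key holding a list of annotation dicts) matches how update_create_job_params_using_revisions reads it below:

```python
from jobs import utils

async def file_ids_for_revisions(tenant: str, token: str) -> set:
    # Matching annotations come back under the "data" key.
    response = await utils.get_annotations_by_revisions(
        current_tenant=tenant,
        jw_token=token,
        revisions=["revision_id_1", "revision_id_2"],
    )
    return {
        item["file_id"] for item in response.get("data", []) if "file_id" in item
    }
```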
@@ -860,3 +894,24 @@ async def validate_create_job_previous_jobs(
detail="Jobs with these ids do not exist.",
)
return [j.id for j in previous_jobs]


async def update_create_job_params_using_revisions(
job_params: schemas.JobParams, current_tenant: str, jwt_token: str
) -> None:
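    """Merge file ids referenced by the requested revisions into job_params.files."""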
response = await get_annotations_by_revisions(
current_tenant=current_tenant,
jw_token=jwt_token,
revisions=list(job_params.revisions),
)

unique_file_ids_of_revisions = {
    data["file_id"]
    for data in response.get("data", [])
    if "file_id" in data
}
job_params.files = list(
unique_file_ids_of_revisions.union(job_params.files)
)
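Note that set.union makes the merge idempotent and order-insensitive: files the caller listed explicitly are kept, and duplicates collapse. For example:

```python
# Revision-derived ids {2, 3} merged with explicit files [1, 3] yield {1, 2, 3}.
assert {2, 3}.union([1, 3]) == {1, 2, 3}
```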
2 changes: 1 addition & 1 deletion jobs/pytest.ini
@@ -1,2 +1,2 @@
[pytest]
asyncio_mode=strict
asyncio_mode=auto
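With asyncio_mode=auto, pytest-asyncio runs every `async def` test on an event loop without an explicit marker, which is why the new tests below carry no @pytest.mark.asyncio decorator. A minimal illustration (the test itself is hypothetical):

```python
import asyncio

# Collected and awaited automatically under asyncio_mode=auto.
async def test_event_loop_is_available():
    assert await asyncio.sleep(0) is None
```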
69 changes: 68 additions & 1 deletion jobs/tests/test_utils.py
@@ -1,11 +1,12 @@
from unittest.mock import patch
from unittest.mock import AsyncMock, Mock, patch

import aiohttp.client_exceptions
import pytest
from fastapi import HTTPException
from tests.conftest import FakePipeline, patched_create_pre_signed_s3_url

import jobs.utils as utils
from jobs.schemas import JobParams, JobType

# --------------TEST get_files_data_from_datasets---------------

@@ -1193,3 +1194,69 @@ async def test_execute_external_pipeline(sign_s3_links: bool):
)
else:
assert FakePipeline.calls[-1]["files"][0].get("signed_url") is None


async def test_update_create_job_params_using_revisions(monkeypatch):
job_params = JobParams(
name="name_1",
type=JobType.ExtractionJob,
pipeline_name="pipeline_name_1",
files=[1],
revisions=["revision_id_1", "revision_id_2", "revision_id_3"],
)

mock_response = {
"data": [
{"file_id": 2, "revision": "revision_id_1"},
{"file_id": 3, "revision": "revision_id_2"},
{"file_id": 3, "revision": "revision_id_3"},
]
}

mock_current_tenant = Mock()
mock_jwt_token = Mock()
mock_get_annotations_by_revisions = AsyncMock(return_value=mock_response)

monkeypatch.setattr(
utils,
"get_annotations_by_revisions",
mock_get_annotations_by_revisions,
)

await utils.update_create_job_params_using_revisions(
job_params, mock_current_tenant, mock_jwt_token
)

mock_get_annotations_by_revisions.assert_called_once()

# union() produces a set, so compare without relying on iteration order
assert sorted(job_params.files) == [1, 2, 3]


async def test_get_annotations_by_revisions(monkeypatch):
revisions = ["revision_id_1", "revision_id_2"]

mock_fetch_response_status = Mock()
mock_fetch_response_json = Mock()

def mock_fetch_side_effect(**kwargs):
assert kwargs["url"].endswith("/annotation")
assert kwargs["method"] == "POST"
assert kwargs["body"]["filters"][0] == {
"field": "revision",
"operator": "in",
"value": revisions,
}

return mock_fetch_response_status, mock_fetch_response_json

mock_fetch = AsyncMock(side_effect=mock_fetch_side_effect)
mock_current_tenant = Mock()
mock_jw_token = Mock()

monkeypatch.setattr(utils, "fetch", mock_fetch)

await utils.get_annotations_by_revisions(
mock_current_tenant, mock_jw_token, revisions
)

mock_fetch.assert_called_once()