diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 6b013a85..77ffe580 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -211,7 +211,6 @@ def get_metrics(): } return details_dic - def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response api_response_headers = api_response.headers @@ -239,13 +238,13 @@ def parse_insert_response(api_response, continue_on_error): error = { 'request_index': idx, 'request_id': request_id, - 'error': response['Body']['error'] + 'error': response['Body']['error'], + 'http_code': response['Status'], } errors.append(error) insert_response.inserted_fields = inserted_fields - insert_response.errors = errors - + insert_response.errors = errors if len(errors) > 0 else None else: for record in api_response_data.records: field_data = { @@ -257,6 +256,7 @@ def parse_insert_response(api_response, continue_on_error): inserted_fields.append(field_data) insert_response.inserted_fields = inserted_fields + insert_response.errors = None return insert_response @@ -275,21 +275,17 @@ def parse_delete_response(api_response: V1BulkDeleteRecordResponse): delete_response = DeleteResponse() deleted_ids = api_response.record_id_response delete_response.deleted_ids = deleted_ids - delete_response.errors = [] + delete_response.errors = None return delete_response - def parse_get_response(api_response: V1BulkGetRecordResponse): get_response = GetResponse() data = [] - errors = [] for record in api_response.records: field_data = {field: value for field, value in record.fields.items()} data.append(field_data) get_response.data = data - get_response.errors = errors - return get_response def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): @@ -320,7 +316,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): errors = errors detokenize_response = DetokenizeResponse() detokenize_response.detokenized_fields = detokenized_fields - detokenize_response.errors = errors + detokenize_response.errors = errors if len(errors) > 0 else None return detokenize_response @@ -357,7 +353,7 @@ def parse_invoke_connection_response(api_response: requests.Response): if 'x-request-id' in api_response.headers: metadata['request_id'] = api_response.headers['x-request-id'] - return InvokeConnectionResponse(data=data, metadata=metadata) + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) except HTTPError: diff --git a/skyflow/vault/connection/_invoke_connection_response.py b/skyflow/vault/connection/_invoke_connection_response.py index 818b94a1..882e150c 100644 --- a/skyflow/vault/connection/_invoke_connection_response.py +++ b/skyflow/vault/connection/_invoke_connection_response.py @@ -1,10 +1,11 @@ class InvokeConnectionResponse: - def __init__(self, data=None, metadata=None): + def __init__(self, data=None, metadata=None, errors=None): self.data = data self.metadata = metadata if metadata else {} + self.errors = errors if errors else None def __repr__(self): - return f"ConnectionResponse('data'={self.data},'metadata'={self.metadata})" + return f"ConnectionResponse('data'={self.data},'metadata'={self.metadata}), 'errors'={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 8e6576cf..606d58ef 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -151,7 +151,7 @@ def output_to_dict_list(output): entities=entities, run_id=run_id_val, status=status_val, - errors=[] + errors=None ) def __get_token_format(self, request): diff --git a/skyflow/vault/data/_insert_response.py b/skyflow/vault/data/_insert_response.py index 6407426d..0c7c777f 100644 --- a/skyflow/vault/data/_insert_response.py +++ b/skyflow/vault/data/_insert_response.py @@ -1,7 +1,5 @@ class InsertResponse: def __init__(self, inserted_fields = None, errors=None): - if errors is None: - errors = list() self.inserted_fields = inserted_fields self.errors = errors diff --git a/skyflow/vault/data/_query_response.py b/skyflow/vault/data/_query_response.py index e2034758..b97fa9bd 100644 --- a/skyflow/vault/data/_query_response.py +++ b/skyflow/vault/data/_query_response.py @@ -1,7 +1,7 @@ class QueryResponse: def __init__(self): self.fields = [] - self.errors = [] + self.errors = None def __repr__(self): return f"QueryResponse(fields={self.fields}, errors={self.errors})" diff --git a/skyflow/vault/data/_update_response.py b/skyflow/vault/data/_update_response.py index dbbb9cc7..c37ee000 100644 --- a/skyflow/vault/data/_update_response.py +++ b/skyflow/vault/data/_update_response.py @@ -1,7 +1,7 @@ class UpdateResponse: def __init__(self, updated_field = None, errors=None): self.updated_field = updated_field - self.errors = errors if errors is not None else [] + self.errors = errors def __repr__(self): return f"UpdateResponse(updated_field={self.updated_field}, errors={self.errors})" diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index dfe1851a..90a0d493 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -17,7 +17,7 @@ def __init__( entities: list = None, # list of dicts with keys 'file' and 'extension' run_id: str = None, status: str = None, - errors: list = [], + errors: list = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6324d9a7..6eaacf47 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -252,6 +252,12 @@ def test_parse_insert_response(self): result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) + # Assert first successful record + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + # Assert error record + self.assertEqual(result.errors[0]["error"], TEST_ERROR_MESSAGE) + self.assertEqual(result.errors[0]["http_code"], 400) + self.assertEqual(result.errors[0]["request_id"], "12345") def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() @@ -270,7 +276,7 @@ def test_parse_insert_response_continue_on_error_false(self): ] self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) + self.assertEqual(result.errors, None) def test_parse_update_record_response(self): api_response = Mock() @@ -291,7 +297,7 @@ def test_parse_delete_response_successful(self): expected_deleted_ids = ["id_1", "id_2", "id_3"] self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.errors, []) + self.assertEqual(result.errors, None) def test_parse_get_response_successful(self): mock_api_response = Mock() @@ -310,7 +316,7 @@ def test_parse_get_response_successful(self): ] self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) + # self.assertEqual(result.errors, None) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() @@ -384,6 +390,7 @@ def test_parse_invoke_connection_response_successful(self, mock_response): self.assertIsInstance(result, InvokeConnectionResponse) self.assertEqual(result.data["key"], "value") self.assertEqual(result.metadata["request_id"], "1234") + self.assertEqual(result.errors, None) @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 70702514..4ccad1c7 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -55,7 +55,8 @@ def test_invoke_success(self, mock_send): # Assertions for successful invocation expected_response = { 'data': {"response": "success"}, - 'metadata': {"request_id": "test-request-id"} + 'metadata': {"request_id": "test-request-id"}, + 'errors': None } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index dbe92e93..1352f85b 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -159,7 +159,7 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], run_id="runid123", - status="SUCCESS", errors=[])) as mock_parse: + status="SUCCESS", errors=None)) as mock_parse: result = self.detect.deidentify_file(req) mock_validate.assert_called_once() @@ -184,7 +184,7 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) - self.assertEqual(result.errors, []) + self.assertEqual(result.errors, None) @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") @@ -222,7 +222,7 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=1, page_count=None, slide_count=None, entities=[], run_id="runid456", - status="SUCCESS", errors=[])) as mock_parse: + status="SUCCESS", errors=None)) as mock_parse: result = self.detect.deidentify_file(req) mock_validate.assert_called_once() files_api.deidentify_audio.assert_called_once() @@ -260,12 +260,11 @@ def test_get_detect_run_success(self, mock_validate): response.word_character_count = Mock(word_count=1, character_count=1) files_api.get_run.return_value = response with patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse( - file="file", type="txt", extension="txt", word_count=1, - char_count=1, size_in_kb=1, duration_in_seconds=None, - page_count=None, slide_count=None, entities=[], - run_id="runid789", status="SUCCESS", - errors=[])) as mock_parse: + return_value=DeidentifyFileResponse(file="file", type="txt", extension="txt", word_count=1, + char_count=1, size_in_kb=1, duration_in_seconds=None, + page_count=None, slide_count=None, entities=[], + run_id="runid789", status="SUCCESS", + errors=None)) as mock_parse: result = self.detect.get_detect_run(req) mock_validate.assert_called_once() files_api.get_run.assert_called_once() @@ -678,7 +677,7 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba entities=[], run_id="runid123", status="SUCCESS", - errors=[] + errors=None )) as mock_parse: result = self.detect.deidentify_file(req) @@ -709,4 +708,4 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) - self.assertEqual(result.errors, []) \ No newline at end of file + self.assertEqual(result.errors, None) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 39b44ae1..0c8a7743 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -123,7 +123,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_insert_request") def test_insert_handles_generic_error(self, mock_validate): @@ -181,7 +181,7 @@ def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_update_request") @patch("skyflow.vault.controller._vault.parse_update_record_response") @@ -223,7 +223,7 @@ def test_update_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected UpdateResponse self.assertEqual(result.updated_field, expected_updated_field) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_update_request") def test_update_handles_generic_error(self, mock_validate): @@ -257,7 +257,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Expected parsed response expected_deleted_ids = ["12345", "67890"] - expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=[]) + expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=None) # Set the return value for the parse response mock_parse_response.return_value = expected_response @@ -273,7 +273,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected DeleteResponse self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_delete_request") def test_delete_handles_generic_exception(self, mock_validate): @@ -330,7 +330,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): {"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"} ] - expected_response = GetResponse(data=expected_data, errors=[]) + expected_response = GetResponse(data=expected_data, errors=None) # Set the return value for parse_get_response mock_parse_response.return_value = expected_response @@ -346,7 +346,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected GetResponse self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_get_request") @patch("skyflow.vault.controller._vault.parse_get_response") @@ -381,7 +381,7 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid {"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"} ] - expected_response = GetResponse(data=expected_data, errors=[]) + expected_response = GetResponse(data=expected_data, errors=None) # Set the return value for parse_get_response mock_parse_response.return_value = expected_response @@ -397,7 +397,7 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid # Check that the result matches the expected GetResponse self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_get_request") def test_get_handles_generic_error(self, mock_validate): @@ -446,7 +446,7 @@ def test_query_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected QueryResponse self.assertEqual(result.fields, expected_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_query_request") def test_query_handles_generic_error(self, mock_validate): @@ -495,7 +495,7 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate): {"token": "token1", "value": "value1", "type": "STRING"}, {"token": "token2", "value": "value2", "type": "STRING"} ] - expected_response = DetokenizeResponse(detokenized_fields=expected_fields, errors=[]) + expected_response = DetokenizeResponse(detokenized_fields=expected_fields, errors=None) # Set the return value for parse_detokenize_response mock_parse_response.return_value = expected_response @@ -511,7 +511,7 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected DetokenizeResponse self.assertEqual(result.detokenized_fields, expected_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_detokenize_request") def test_detokenize_handles_generic_error(self, mock_validate):