diff --git a/setup.py b/setup.py index 8c09ec2e..7cd88cab 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ if sys.version_info < (3, 8): raise RuntimeError("skyflow requires Python 3.8+") -current_version = '2.0.0b6' +current_version = '2.0.0b6.dev0+c3095a9' setup( name='skyflow', diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index a5b94451..460ca29e 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -169,6 +169,7 @@ class Error(Enum): INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY= f"{error_prefix} Validation error. The plainTextEntities field must be an array of DetectEntities enums. Specify a valid plainTextEntities." INVALID_DEIDENTIFY_FILE_REQUEST= f"{error_prefix} Validation error. Invalid deidentify file request. Specify a valid deidentify file request." + INVALID_DEIDENTIFY_FILE_INPUT= f"{error_prefix} Validation error. Invalid deidentify file input. Please provide either a file or a file path." EMPTY_FILE_OBJECT= f"{error_prefix} Validation error. File object cannot be empty. Specify a valid file object." INVALID_FILE_FORMAT= f"{error_prefix} Validation error. Invalid file format. Specify a valid file format." MISSING_FILE_SOURCE= f"{error_prefix} Validation error. Provide exactly one of filePath, base64, or fileObject." @@ -197,7 +198,7 @@ class Error(Enum): INVALID_FILE_OR_ENCODED_FILE= f"{error_prefix} . Error while decoding base64 and saving file" INVALID_FILE_TYPE = f"{error_prefix} Validation error. Invalid file type. Specify a valid file type." INVALID_FILE_NAME= f"{error_prefix} Validation error. Invalid file name. Specify a valid file name." - FILE_READ_ERROR= f"{error_prefix} Validation error. Unable to read file. Verify the file path." + INVALID_DEIDENTIFY_FILE_PATH= f"{error_prefix} Validation error. Invalid file path. Specify a valid file path." INVALID_BASE64_HEADER= f"{error_prefix} Validation error. Invalid base64 header. Specify a valid base64 header." INVALID_WAIT_TIME= f"{error_prefix} Validation error. Invalid wait time. Specify a valid wait time as number and should not be greater than 64 secs." INVALID_OUTPUT_DIRECTORY= f"{error_prefix} Validation error. Invalid output directory. Specify a valid output directory as string." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 6b013a85..77ffe580 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -211,7 +211,6 @@ def get_metrics(): } return details_dic - def parse_insert_response(api_response, continue_on_error): # Retrieve the headers and data from the API response api_response_headers = api_response.headers @@ -239,13 +238,13 @@ def parse_insert_response(api_response, continue_on_error): error = { 'request_index': idx, 'request_id': request_id, - 'error': response['Body']['error'] + 'error': response['Body']['error'], + 'http_code': response['Status'], } errors.append(error) insert_response.inserted_fields = inserted_fields - insert_response.errors = errors - + insert_response.errors = errors if len(errors) > 0 else None else: for record in api_response_data.records: field_data = { @@ -257,6 +256,7 @@ def parse_insert_response(api_response, continue_on_error): inserted_fields.append(field_data) insert_response.inserted_fields = inserted_fields + insert_response.errors = None return insert_response @@ -275,21 +275,17 @@ def parse_delete_response(api_response: V1BulkDeleteRecordResponse): delete_response = DeleteResponse() deleted_ids = api_response.record_id_response delete_response.deleted_ids = deleted_ids - delete_response.errors = [] + delete_response.errors = None return delete_response - def parse_get_response(api_response: V1BulkGetRecordResponse): get_response = GetResponse() data = [] - errors = [] for record in api_response.records: field_data = {field: value for field, value in record.fields.items()} data.append(field_data) get_response.data = data - get_response.errors = errors - return get_response def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): @@ -320,7 +316,7 @@ def parse_detokenize_response(api_response: HttpResponse[V1DetokenizeResponse]): errors = errors detokenize_response = DetokenizeResponse() detokenize_response.detokenized_fields = detokenized_fields - detokenize_response.errors = errors + detokenize_response.errors = errors if len(errors) > 0 else None return detokenize_response @@ -357,7 +353,7 @@ def parse_invoke_connection_response(api_response: requests.Response): if 'x-request-id' in api_response.headers: metadata['request_id'] = api_response.headers['x-request-id'] - return InvokeConnectionResponse(data=data, metadata=metadata) + return InvokeConnectionResponse(data=data, metadata=metadata, errors=None) except Exception as e: raise SkyflowError(SkyflowMessages.Error.RESPONSE_NOT_JSON.value.format(content), status_code) except HTTPError: diff --git a/skyflow/utils/_version.py b/skyflow/utils/_version.py index 0cd8592c..e394031c 100644 --- a/skyflow/utils/_version.py +++ b/skyflow/utils/_version.py @@ -1 +1 @@ -SDK_VERSION = '2.0.0b6' \ No newline at end of file +SDK_VERSION = '2.0.0b6.dev0+c3095a9' \ No newline at end of file diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 0ff9f038..bbca6e85 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -9,6 +9,7 @@ from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest +from skyflow.vault.detect._file_input import FileInput valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] @@ -257,9 +258,42 @@ def validate_update_connection_config(logger, config): return True +def validate_file_from_request(file_input: FileInput): + if file_input is None: + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) + + has_file = hasattr(file_input, 'file') and file_input.file is not None + has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None + + # Must provide exactly one of file or file_path + if (has_file and has_file_path) or (not has_file and not has_file_path): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value, invalid_input_error_code) + + if has_file: + file = file_input.file + # Validate file object has required attributes + if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) + + # Validate file name + file_name = os.path.splitext(file.name)[0] + if not file_name or not file_name.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) + + elif has_file_path: + file_path = file_input.file_path + if not isinstance(file_path, str) or not file_path.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) + + if not os.path.exists(file_path) or not os.path.isfile(file_path): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) + def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if not hasattr(request, 'file') or request.file is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) + + # Validate file input first + validate_file_from_request(request.file) # Optional: entities if hasattr(request, 'entities') and request.entities is not None: diff --git a/skyflow/vault/connection/_invoke_connection_response.py b/skyflow/vault/connection/_invoke_connection_response.py index 818b94a1..882e150c 100644 --- a/skyflow/vault/connection/_invoke_connection_response.py +++ b/skyflow/vault/connection/_invoke_connection_response.py @@ -1,10 +1,11 @@ class InvokeConnectionResponse: - def __init__(self, data=None, metadata=None): + def __init__(self, data=None, metadata=None, errors=None): self.data = data self.metadata = metadata if metadata else {} + self.errors = errors if errors else None def __repr__(self): - return f"ConnectionResponse('data'={self.data},'metadata'={self.metadata})" + return f"ConnectionResponse('data'={self.data},'metadata'={self.metadata}), 'errors'={self.errors})" def __str__(self): return self.__repr__() \ No newline at end of file diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 1dbd533c..606d58ef 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -1,3 +1,4 @@ +import io import json import os from skyflow.error import SkyflowError @@ -20,6 +21,7 @@ from skyflow.vault.detect import DeidentifyTextRequest, DeidentifyTextResponse, ReidentifyTextRequest, \ ReidentifyTextResponse, DeidentifyFileRequest, DeidentifyFileResponse, GetDetectRunRequest + class Detect: def __init__(self, vault_client): self.__vault_client = vault_client @@ -124,10 +126,22 @@ def output_to_dict_list(output): word_count = getattr(word_character_count, "word_count", None) char_count = getattr(word_character_count, "character_count", None) + base64_string = first_output.get("file", None) + extension = first_output.get("extension", None) + + file_obj = None + if base64_string is not None: + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + else: + file_obj = None + return DeidentifyFileResponse( - file=first_output.get("file", None), + file_base64=base64_string, + file=file_obj, # File class will be instantiated in DeidentifyFileResponse type=first_output.get("type", None), - extension=first_output.get("extension", None), + extension=extension, word_count=word_count, char_count=char_count, size_in_kb=size, @@ -137,7 +151,7 @@ def output_to_dict_list(output): entities=entities, run_id=run_id_val, status=status_val, - errors=[] + errors=None ) def __get_token_format(self, request): @@ -216,16 +230,26 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_error_log(SkyflowMessages.ErrorLogs.REIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) + def __get_file_from_request(self, request: DeidentifyFileRequest): + file_input = request.file + + # Check for file + if hasattr(file_input, 'file') and file_input.file is not None: + return file_input.file + + # Check for file_path if file is not provided + if hasattr(file_input, 'file_path') and file_input.file_path is not None: + return open(file_input.file_path, 'rb') + def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) validate_deidentify_file_request(self.__vault_client.get_logger(), request) self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response - file_obj = request.file + file_obj = self.__get_file_from_request(request) file_name = getattr(file_obj, 'name', None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') try: @@ -375,7 +399,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_name_only = 'processed-'+os.path.basename(file_name) output_file_path = f"{request.output_directory}/{file_name_only}" with open(output_file_path, 'wb') as output_file: - output_file.write(base64.b64decode(parsed_response.file)) + output_file.write(base64.b64decode(parsed_response.file_base64)) log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger()) return parsed_response diff --git a/skyflow/vault/data/_insert_response.py b/skyflow/vault/data/_insert_response.py index 6407426d..0c7c777f 100644 --- a/skyflow/vault/data/_insert_response.py +++ b/skyflow/vault/data/_insert_response.py @@ -1,7 +1,5 @@ class InsertResponse: def __init__(self, inserted_fields = None, errors=None): - if errors is None: - errors = list() self.inserted_fields = inserted_fields self.errors = errors diff --git a/skyflow/vault/data/_query_response.py b/skyflow/vault/data/_query_response.py index e2034758..b97fa9bd 100644 --- a/skyflow/vault/data/_query_response.py +++ b/skyflow/vault/data/_query_response.py @@ -1,7 +1,7 @@ class QueryResponse: def __init__(self): self.fields = [] - self.errors = [] + self.errors = None def __repr__(self): return f"QueryResponse(fields={self.fields}, errors={self.errors})" diff --git a/skyflow/vault/data/_update_response.py b/skyflow/vault/data/_update_response.py index dbbb9cc7..c37ee000 100644 --- a/skyflow/vault/data/_update_response.py +++ b/skyflow/vault/data/_update_response.py @@ -1,7 +1,7 @@ class UpdateResponse: def __init__(self, updated_field = None, errors=None): self.updated_field = updated_field - self.errors = errors if errors is not None else [] + self.errors = errors def __repr__(self): return f"UpdateResponse(updated_field={self.updated_field}, errors={self.errors})" diff --git a/skyflow/vault/detect/__init__.py b/skyflow/vault/detect/__init__.py index e385a1f2..bd09fed8 100644 --- a/skyflow/vault/detect/__init__.py +++ b/skyflow/vault/detect/__init__.py @@ -10,4 +10,5 @@ from ._deidentify_file_request import DeidentifyFileRequest from ._audio_bleep import Bleep from ._deidentify_file_response import DeidentifyFileResponse -from ._get_detect_run_request import GetDetectRunRequest \ No newline at end of file +from ._get_detect_run_request import GetDetectRunRequest +from ._file_input import FileInput \ No newline at end of file diff --git a/skyflow/vault/detect/_deidentify_file_request.py b/skyflow/vault/detect/_deidentify_file_request.py index a429f5d5..09d8b118 100644 --- a/skyflow/vault/detect/_deidentify_file_request.py +++ b/skyflow/vault/detect/_deidentify_file_request.py @@ -3,6 +3,7 @@ from skyflow.vault.detect import TokenFormat, Transformations from skyflow.vault.detect._audio_bleep import Bleep from skyflow.utils.enums import MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect._file_input import FileInput class DeidentifyFileRequest: def __init__( @@ -24,7 +25,7 @@ def __init__( output_directory: Optional[str] = None, wait_time: Optional[Union[int, float]] = None ): - self.file: object = file + self.file: FileInput = file self.entities: Optional[List[DetectEntities]] = entities self.allow_regex_list: Optional[List[str]] = allow_regex_list self.restrict_regex_list: Optional[List[str]] = restrict_regex_list diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index f386080d..90a0d493 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,7 +1,11 @@ +import io +from skyflow.vault.detect._file import File + class DeidentifyFileResponse: def __init__( self, - file: str = None, + file_base64: str = None, + file: io.BytesIO = None, type: str = None, extension: str = None, word_count: int = None, @@ -13,9 +17,10 @@ def __init__( entities: list = None, # list of dicts with keys 'file' and 'extension' run_id: str = None, status: str = None, - errors: list = [], + errors: list = None, ): - self.file = file + self.file_base64 = file_base64 + self.file = File(file) if file else None self.type = type self.extension = extension self.word_count = word_count @@ -32,12 +37,12 @@ def __init__( def __repr__(self): return ( f"DeidentifyFileResponse(" - f"file={self.file!r}, type={self.type!r}, extension={self.extension!r}, " - f"word_count={self.word_count!r}, char_count={self.char_count!r}, " - f"size_in_kb={self.size_in_kb!r}, duration_in_seconds={self.duration_in_seconds!r}, " - f"page_count={self.page_count!r}, slide_count={self.slide_count!r}, " - f"entities={self.entities!r}, run_id={self.run_id!r}, status={self.status!r})," - f"errors={self.errors!r})" + f"file_base64={self.file_base64!r}, file={self.file!r}, type={self.type!r}, " + f"extension={self.extension!r}, word_count={self.word_count!r}, " + f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " + f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " + f"slide_count={self.slide_count!r}, entities={self.entities!r}, " + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_file.py b/skyflow/vault/detect/_file.py new file mode 100644 index 00000000..ad188666 --- /dev/null +++ b/skyflow/vault/detect/_file.py @@ -0,0 +1,53 @@ +import io +import mimetypes +import time + +class File: + def __init__(self, file: io.BytesIO = None): + self.file = file + + @property + def name(self) -> str: + """Get file name""" + if self.file: + return getattr(self.file, 'name', 'unknown') + return None + + @property + def size(self) -> int: + """Get file size in bytes""" + if self.file: + pos = self.file.tell() + self.file.seek(0, io.SEEK_END) + size = self.file.tell() + self.file.seek(pos) + return size + return None + + @property + def type(self) -> str: + """Get file mime type""" + if self.file: + return mimetypes.guess_type(self.name)[0] or '' + return None + + @property + def last_modified(self) -> int: + """Get file last modified timestamp in milliseconds""" + if self.file: + return int(time.time() * 1000) + return None + + def seek(self, offset, whence=0): + if self.file: + return self.file.seek(offset, whence) + + def read(self, size=-1): + if self.file: + return self.file.read(size) + + def __repr__(self): + return ( + f"File(name={self.name!r}, size={self.size!r}, type={self.type!r}, " + f"last_modified={self.last_modified!r})" + ) diff --git a/skyflow/vault/detect/_file_input.py b/skyflow/vault/detect/_file_input.py new file mode 100644 index 00000000..472ca0e2 --- /dev/null +++ b/skyflow/vault/detect/_file_input.py @@ -0,0 +1,19 @@ +class FileInput: + """ + Represents a file input for the vault detection process. + + Attributes: + file (str): The file object to be processed. This can be a file-like object or a binary string. + file_path (str): The path to the file to be processed. + """ + + def __init__(self, file: str= None, file_path: str = None): + self.file = file + self.file_path = file_path + + def __repr__(self) -> str: + return f"FileInput(file={self.file!r}, file_path={self.file_path!r})" + + def __str__(self) -> str: + return self.__repr__() + \ No newline at end of file diff --git a/tests/utils/test__utils.py b/tests/utils/test__utils.py index 6324d9a7..6eaacf47 100644 --- a/tests/utils/test__utils.py +++ b/tests/utils/test__utils.py @@ -252,6 +252,12 @@ def test_parse_insert_response(self): result = parse_insert_response(api_response, continue_on_error=True) self.assertEqual(len(result.inserted_fields), 1) self.assertEqual(len(result.errors), 1) + # Assert first successful record + self.assertEqual(result.inserted_fields[0]["skyflow_id"], "id1") + # Assert error record + self.assertEqual(result.errors[0]["error"], TEST_ERROR_MESSAGE) + self.assertEqual(result.errors[0]["http_code"], 400) + self.assertEqual(result.errors[0]["request_id"], "12345") def test_parse_insert_response_continue_on_error_false(self): mock_api_response = Mock() @@ -270,7 +276,7 @@ def test_parse_insert_response_continue_on_error_false(self): ] self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) + self.assertEqual(result.errors, None) def test_parse_update_record_response(self): api_response = Mock() @@ -291,7 +297,7 @@ def test_parse_delete_response_successful(self): expected_deleted_ids = ["id_1", "id_2", "id_3"] self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.errors, []) + self.assertEqual(result.errors, None) def test_parse_get_response_successful(self): mock_api_response = Mock() @@ -310,7 +316,7 @@ def test_parse_get_response_successful(self): ] self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) + # self.assertEqual(result.errors, None) def test_parse_detokenize_response_with_mixed_records(self): mock_api_response = Mock() @@ -384,6 +390,7 @@ def test_parse_invoke_connection_response_successful(self, mock_response): self.assertIsInstance(result, InvokeConnectionResponse) self.assertEqual(result.data["key"], "value") self.assertEqual(result.metadata["request_id"], "1234") + self.assertEqual(result.errors, None) @patch("requests.Response") def test_parse_invoke_connection_response_json_decode_error(self, mock_response): diff --git a/tests/vault/controller/test__connection.py b/tests/vault/controller/test__connection.py index 70702514..4ccad1c7 100644 --- a/tests/vault/controller/test__connection.py +++ b/tests/vault/controller/test__connection.py @@ -55,7 +55,8 @@ def test_invoke_success(self, mock_send): # Assertions for successful invocation expected_response = { 'data': {"response": "success"}, - 'metadata': {"request_id": "test-request-id"} + 'metadata': {"request_id": "test-request-id"}, + 'errors': None } self.assertEqual(vars(response), expected_response) self.mock_vault_client.get_bearer_token.assert_called_once() diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index 29db32dc..1352f85b 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -6,8 +6,12 @@ from skyflow.utils import SkyflowMessages from skyflow.vault.controller import Detect from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, \ - TokenFormat, DateTransformation, Transformations, DeidentifyFileRequest, GetDetectRunRequest, DeidentifyFileResponse + TokenFormat, DateTransformation, Transformations, DeidentifyFileRequest, GetDetectRunRequest, \ + DeidentifyFileResponse, FileInput from skyflow.utils.enums import DetectEntities, TokenType +import io + +from skyflow.vault.detect._file import File VAULT_ID = "test_vault_id" @@ -127,7 +131,7 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 file_obj.name = "/tmp/test.txt" mock_basename.return_value = "test.txt" mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -149,18 +153,38 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 with patch.object(self.detect, "_Detect__poll_for_processed_file", return_value=processed_response) as mock_poll, \ patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse(file="dGVzdCBjb250ZW50", type="txt", extension="txt", + return_value=DeidentifyFileResponse(file_base64="dGVzdCBjb250ZW50", + file=io.BytesIO(b"test content"), type="txt", + extension="txt", word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], run_id="runid123", - status="SUCCESS", errors=[])) as mock_parse: + status="SUCCESS", errors=None)) as mock_parse: result = self.detect.deidentify_file(req) + mock_validate.assert_called_once() files_api.deidentify_text.assert_called_once() mock_poll.assert_called_once() mock_parse.assert_called_once() + self.assertIsInstance(result, DeidentifyFileResponse) self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50") + self.assertEqual(result.type, "txt") + self.assertEqual(result.extension, "txt") + + self.assertIsInstance(result.file, File) + result.file.seek(0) + self.assertEqual(result.file.read(), b"test content") + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + self.assertEqual(result.size_in_kb, 1) + self.assertIsNone(result.duration_in_seconds) + self.assertIsNone(result.page_count) + self.assertIsNone(result.slide_count) + self.assertEqual(result.entities, []) + self.assertEqual(result.errors, None) @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") @@ -170,7 +194,7 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): file_obj.read.return_value = file_content file_obj.name = "audio.mp3" mock_base64.b64encode.return_value = b"YXVkaW8gYnl0ZXM=" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -192,11 +216,13 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): with patch.object(self.detect, "_Detect__poll_for_processed_file", return_value=processed_response) as mock_poll, \ patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse(file="YXVkaW8gYnl0ZXM=", type="mp3", extension="mp3", + return_value=DeidentifyFileResponse(file_base64="YXVkaW8gYnl0ZXM=", + file=io.BytesIO(b"audio bytes"), type="mp3", + extension="mp3", word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=1, page_count=None, slide_count=None, entities=[], run_id="runid456", - status="SUCCESS", errors=[])) as mock_parse: + status="SUCCESS", errors=None)) as mock_parse: result = self.detect.deidentify_file(req) mock_validate.assert_called_once() files_api.deidentify_audio.assert_called_once() @@ -238,7 +264,7 @@ def test_get_detect_run_success(self, mock_validate): char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], run_id="runid789", status="SUCCESS", - errors=[])) as mock_parse: + errors=None)) as mock_parse: result = self.detect.get_detect_run(req) mock_validate.assert_called_once() files_api.get_run.assert_called_once() @@ -262,11 +288,10 @@ def test_get_detect_run_exception(self, mock_validate): @patch("skyflow.vault.controller._detect.open", create=True) @patch.object(Detect, "_Detect__poll_for_processed_file") def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, mock_base64, mock_validate): - """Test all file type branches with optimized mocking""" - # Common mocks file_content = b"test content" mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50" + mock_base64.b64decode.return_value = file_content # Prepare a generic processed_response for all branches processed_response = Mock() @@ -283,69 +308,78 @@ def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, processed_response.run_id = "runid123" mock_poll.return_value = processed_response - # Patch __parse_deidentify_file_response to return a valid DeidentifyFileResponse - with patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse( - file="dGVzdCBjb250ZW50", type="pdf", extension="pdf", - word_count=1, char_count=1, size_in_kb=1, - duration_in_seconds=1, page_count=1, slide_count=1, - entities=[], run_id="runid123", status="SUCCESS", errors=[] - )) as mock_parse: - # Test configuration for different file types - test_cases = [ - ("test.pdf", "pdf", "deidentify_pdf"), - ("test.jpg", "jpg", "deidentify_image"), - ("test.pptx", "pptx", "deidentify_presentation"), - ("test.csv", "csv", "deidentify_spreadsheet"), - ("test.docx", "docx", "deidentify_document"), - ("test.json", "json", "deidentify_structured_text"), - ("test.xml", "xml", "deidentify_structured_text"), - ("test.unknown", "unknown", "deidentify_file") - ] + # Test configuration for different file types + test_cases = [ + ("test.pdf", "pdf", "deidentify_pdf"), + ("test.jpg", "jpg", "deidentify_image"), + ("test.pptx", "pptx", "deidentify_presentation"), + ("test.csv", "csv", "deidentify_spreadsheet"), + ("test.docx", "docx", "deidentify_document"), + ("test.json", "json", "deidentify_structured_text"), + ("test.xml", "xml", "deidentify_structured_text"), + ("test.unknown", "unknown", "deidentify_file") + ] + + for file_name, extension, api_method in test_cases: + with self.subTest(file_type=extension): + # Setup file mock + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = file_name + mock_basename.return_value = file_name + + # Setup request with FileInput + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = "/tmp" + + # Setup API mock + files_api = Mock() + files_api.with_raw_response = files_api + api_method_mock = Mock() + setattr(files_api, api_method, api_method_mock) + self.vault_client.get_detect_file_api.return_value = files_api + + # Setup API response + api_response = Mock() + api_response.data = Mock(run_id="runid123") + api_method_mock.return_value = api_response + + # Actually run the method + result = self.detect.deidentify_file(req) + + # Verify the result + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50") + self.assertIsInstance(result.file, File) + result.file.seek(0) # Reset file pointer before reading + self.assertEqual(result.file.read(), b"test content") + self.assertEqual(result.type, "pdf") + self.assertEqual(result.extension, "pdf") + self.assertEqual(result.size_in_kb, 1) + self.assertEqual(result.duration_in_seconds, 1) + self.assertEqual(result.page_count, 1) + self.assertEqual(result.slide_count, 1) + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + + # Verify API was called + api_method_mock.assert_called_once() + mock_poll.assert_called_with("runid123", None) - for file_name, extension, api_method in test_cases: - with self.subTest(file_type=extension): - # Setup file mock - file_obj = Mock() - file_obj.read.return_value = file_content - file_obj.name = file_name - mock_basename.return_value = file_name - - # Setup request - req = DeidentifyFileRequest(file=file_obj) - req.entities = [] - req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) - req.allow_regex_list = [] - req.restrict_regex_list = [] - req.transformations = None - req.output_directory = "/tmp" - - # Setup API mock - files_api = Mock() - files_api.with_raw_response = files_api - api_method_mock = Mock() - setattr(files_api, api_method, api_method_mock) - self.vault_client.get_detect_file_api.return_value = files_api - - # Setup API response - api_response = Mock() - api_response.data = Mock(run_id="runid123") - api_method_mock.return_value = api_response - - # Actually run the method - result = self.detect.deidentify_file(req) - self.assertIsInstance(result, DeidentifyFileResponse) - self.assertEqual(result.status, "SUCCESS") - self.assertEqual(result.file, "dGVzdCBjb250ZW50") - self.assertEqual(result.type, "pdf") - self.assertEqual(result.extension, "pdf") @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") def test_deidentify_file_exception(self, mock_base64, mock_validate): file_obj = Mock() file_obj.read.side_effect = Exception("Read error") file_obj.name = "test.txt" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -404,8 +438,8 @@ def test_parse_deidentify_file_response_dict_and_obj(self): # Dict input data = { "output": [ - {"processedFile": "abc", "processedFileType": "pdf", "processedFileExtension": "pdf"}, - {"processedFile": "def", "processedFileType": "entities", "processedFileExtension": "json"} + {"processedFile": "YWJj", "processedFileType": "pdf", "processedFileExtension": "pdf"}, # base64 for "abc" + {"processedFile": "ZGVm", "processedFileType": "entities", "processedFileExtension": "json"} # base64 for "def" ], "word_character_count": {"word_count": 5, "character_count": 10}, "size": 1, @@ -426,9 +460,9 @@ class DummyWordChar: class DummyData: output = [ type("O", (), - {"processed_file": "abc", "processed_file_type": "pdf", "processed_file_extension": "pdf"})(), + {"processed_file": "YWJj", "processed_file_type": "pdf", "processed_file_extension": "pdf"})(), type("O", (), - {"processed_file": "def", "processed_file_type": "entities", "processed_file_extension": "json"})() + {"processed_file": "ZGVm", "processed_file_type": "entities", "processed_file_extension": "json"})() ] word_character_count = DummyWordChar() size = 1 @@ -441,7 +475,9 @@ class DummyData: obj_data = DummyData() result = self.detect._Detect__parse_deidentify_file_response(obj_data, "runid", "SUCCESS") self.assertIsInstance(result, DeidentifyFileResponse) - + self.assertEqual(result.file_base64, "YWJj") + self.assertIsInstance(result.file, File) + self.assertEqual(result.file.read(), b"abc") def test_get_token_format_missing_attribute(self): """Test __get_token_format when token_format attribute is missing""" class DummyRequest: @@ -559,12 +595,11 @@ def track_sleep(*args): self.assertEqual(calls, [2, 2]) self.assertEqual(result.status, "SUCCESS") - def test_parse_deidentify_file_response_output_conversion(self): """Test output conversion in parse_deidentify_file_response""" class OutputObj: - processed_file = "file123" + processed_file = "YWJjMTIz" # base64 for "abc123" processed_file_type = "pdf" processed_file_extension = "pdf" @@ -574,6 +609,103 @@ class OutputObj: result = self.detect._Detect__parse_deidentify_file_response(data) - self.assertEqual(result.file, "file123") + # Check base64 string + self.assertEqual(result.file_base64, "YWJjMTIz") + # Check File object + self.assertIsInstance(result.file, File) + self.assertEqual(result.file.read(), b"abc123") + # Check other attributes self.assertEqual(result.type, "pdf") - self.assertEqual(result.extension, "pdf") \ No newline at end of file + self.assertEqual(result.extension, "pdf") + # Reset file pointer and verify content again + result.file.seek(0) + self.assertEqual(result.file.read(), b"abc123") + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + @patch("skyflow.vault.controller._detect.os.path.basename") + @patch("skyflow.vault.controller._detect.open", create=True) + def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_base64, mock_validate): + # Setup mock file context + mock_file = MagicMock() + mock_file.read.return_value = b"test content from file path" + mock_file.name = "/path/to/test.txt" + mock_file.__enter__.return_value = mock_file # Mock context manager + mock_open.return_value = mock_file + mock_basename.return_value = "test.txt" + mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo" # base64 of "test content from file path" + mock_base64.b64decode.return_value = b"test content from file path" + # Create request with file_path + req = DeidentifyFileRequest(file=FileInput(file_path="/path/to/test.txt")) + req.entities = [] + req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = "/tmp" + + # Setup API mock + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text = Mock() + self.vault_client.get_detect_file_api.return_value = files_api + api_response = Mock() + api_response.data = Mock(run_id="runid123") + files_api.deidentify_text.return_value = api_response + + # Setup processed response + processed_response = Mock() + processed_response.status = "SUCCESS" + processed_response.output = [] + processed_response.word_character_count = Mock(word_count=1, character_count=1) + + # Test the method + with patch.object(self.detect, "_Detect__poll_for_processed_file", + return_value=processed_response) as mock_poll, \ + patch.object(self.detect, "_Detect__parse_deidentify_file_response", + return_value=DeidentifyFileResponse( + file_base64="dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo", + file=io.BytesIO(b"test content from file path"), + type="txt", + extension="txt", + word_count=1, + char_count=1, + size_in_kb=1, + duration_in_seconds=None, + page_count=None, + slide_count=None, + entities=[], + run_id="runid123", + status="SUCCESS", + errors=None + )) as mock_parse: + + result = self.detect.deidentify_file(req) + + mock_file.read.assert_called_once() + mock_basename.assert_called_with("/path/to/test.txt") + + mock_validate.assert_called_once() + files_api.deidentify_text.assert_called_once() + mock_poll.assert_called_once() + mock_parse.assert_called_once() + + # Response assertions + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo") + self.assertEqual(result.type, "txt") + self.assertEqual(result.extension, "txt") + + self.assertIsInstance(result.file, File) + result.file.seek(0) + self.assertEqual(result.file.read(), b"test content from file path") + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + self.assertEqual(result.size_in_kb, 1) + self.assertIsNone(result.duration_in_seconds) + self.assertIsNone(result.page_count) + self.assertIsNone(result.slide_count) + self.assertEqual(result.entities, []) + self.assertEqual(result.errors, None) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 39b44ae1..0c8a7743 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -123,7 +123,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_insert_request") def test_insert_handles_generic_error(self, mock_validate): @@ -181,7 +181,7 @@ def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock # Assert that the result matches the expected InsertResponse self.assertEqual(result.inserted_fields, expected_inserted_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_update_request") @patch("skyflow.vault.controller._vault.parse_update_record_response") @@ -223,7 +223,7 @@ def test_update_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected UpdateResponse self.assertEqual(result.updated_field, expected_updated_field) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_update_request") def test_update_handles_generic_error(self, mock_validate): @@ -257,7 +257,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Expected parsed response expected_deleted_ids = ["12345", "67890"] - expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=[]) + expected_response = DeleteResponse(deleted_ids=expected_deleted_ids, errors=None) # Set the return value for the parse response mock_parse_response.return_value = expected_response @@ -273,7 +273,7 @@ def test_delete_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected DeleteResponse self.assertEqual(result.deleted_ids, expected_deleted_ids) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_delete_request") def test_delete_handles_generic_exception(self, mock_validate): @@ -330,7 +330,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): {"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"} ] - expected_response = GetResponse(data=expected_data, errors=[]) + expected_response = GetResponse(data=expected_data, errors=None) # Set the return value for parse_get_response mock_parse_response.return_value = expected_response @@ -346,7 +346,7 @@ def test_get_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected GetResponse self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_get_request") @patch("skyflow.vault.controller._vault.parse_get_response") @@ -381,7 +381,7 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid {"field1": "value1", "field2": "value2"}, {"field1": "value3", "field2": "value4"} ] - expected_response = GetResponse(data=expected_data, errors=[]) + expected_response = GetResponse(data=expected_data, errors=None) # Set the return value for parse_get_response mock_parse_response.return_value = expected_response @@ -397,7 +397,7 @@ def test_get_successful_with_column_values(self, mock_parse_response, mock_valid # Check that the result matches the expected GetResponse self.assertEqual(result.data, expected_data) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_get_request") def test_get_handles_generic_error(self, mock_validate): @@ -446,7 +446,7 @@ def test_query_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected QueryResponse self.assertEqual(result.fields, expected_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_query_request") def test_query_handles_generic_error(self, mock_validate): @@ -495,7 +495,7 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate): {"token": "token1", "value": "value1", "type": "STRING"}, {"token": "token2", "value": "value2", "type": "STRING"} ] - expected_response = DetokenizeResponse(detokenized_fields=expected_fields, errors=[]) + expected_response = DetokenizeResponse(detokenized_fields=expected_fields, errors=None) # Set the return value for parse_detokenize_response mock_parse_response.return_value = expected_response @@ -511,7 +511,7 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate): # Check that the result matches the expected DetokenizeResponse self.assertEqual(result.detokenized_fields, expected_fields) - self.assertEqual(result.errors, []) # No errors expected + self.assertEqual(result.errors, None) # No errors expected @patch("skyflow.vault.controller._vault.validate_detokenize_request") def test_detokenize_handles_generic_error(self, mock_validate):