diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index a5b94451..460ca29e 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -169,6 +169,7 @@ class Error(Enum): INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY= f"{error_prefix} Validation error. The plainTextEntities field must be an array of DetectEntities enums. Specify a valid plainTextEntities." INVALID_DEIDENTIFY_FILE_REQUEST= f"{error_prefix} Validation error. Invalid deidentify file request. Specify a valid deidentify file request." + INVALID_DEIDENTIFY_FILE_INPUT= f"{error_prefix} Validation error. Invalid deidentify file input. Please provide either a file or a file path." EMPTY_FILE_OBJECT= f"{error_prefix} Validation error. File object cannot be empty. Specify a valid file object." INVALID_FILE_FORMAT= f"{error_prefix} Validation error. Invalid file format. Specify a valid file format." MISSING_FILE_SOURCE= f"{error_prefix} Validation error. Provide exactly one of filePath, base64, or fileObject." @@ -197,7 +198,7 @@ class Error(Enum): INVALID_FILE_OR_ENCODED_FILE= f"{error_prefix} . Error while decoding base64 and saving file" INVALID_FILE_TYPE = f"{error_prefix} Validation error. Invalid file type. Specify a valid file type." INVALID_FILE_NAME= f"{error_prefix} Validation error. Invalid file name. Specify a valid file name." - FILE_READ_ERROR= f"{error_prefix} Validation error. Unable to read file. Verify the file path." + INVALID_DEIDENTIFY_FILE_PATH= f"{error_prefix} Validation error. Invalid file path. Specify a valid file path." INVALID_BASE64_HEADER= f"{error_prefix} Validation error. Invalid base64 header. Specify a valid base64 header." INVALID_WAIT_TIME= f"{error_prefix} Validation error. Invalid wait time. Specify a valid wait time as number and should not be greater than 64 secs." INVALID_OUTPUT_DIRECTORY= f"{error_prefix} Validation error. Invalid output directory. Specify a valid output directory as string." diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index 0ff9f038..bbca6e85 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -9,6 +9,7 @@ from skyflow.utils.logger import log_info, log_error_log from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, TokenFormat, Transformations, \ GetDetectRunRequest, Bleep, DeidentifyFileRequest +from skyflow.vault.detect._file_input import FileInput valid_vault_config_keys = ["vault_id", "cluster_id", "credentials", "env"] valid_connection_config_keys = ["connection_id", "connection_url", "credentials"] @@ -257,9 +258,42 @@ def validate_update_connection_config(logger, config): return True +def validate_file_from_request(file_input: FileInput): + if file_input is None: + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) + + has_file = hasattr(file_input, 'file') and file_input.file is not None + has_file_path = hasattr(file_input, 'file_path') and file_input.file_path is not None + + # Must provide exactly one of file or file_path + if (has_file and has_file_path) or (not has_file and not has_file_path): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value, invalid_input_error_code) + + if has_file: + file = file_input.file + # Validate file object has required attributes + if not hasattr(file, 'name') or not isinstance(file.name, str) or not file.name.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) + + # Validate file name + file_name = os.path.splitext(file.name)[0] + if not file_name or not file_name.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) + + elif has_file_path: + file_path = file_input.file_path + if not isinstance(file_path, str) or not file_path.strip(): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) + + if not os.path.exists(file_path) or not os.path.isfile(file_path): + raise SkyflowError(SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value, invalid_input_error_code) + def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): if not hasattr(request, 'file') or request.file is None: raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_INPUT.value, invalid_input_error_code) + + # Validate file input first + validate_file_from_request(request.file) # Optional: entities if hasattr(request, 'entities') and request.entities is not None: diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 1dbd533c..8e6576cf 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -1,3 +1,4 @@ +import io import json import os from skyflow.error import SkyflowError @@ -20,6 +21,7 @@ from skyflow.vault.detect import DeidentifyTextRequest, DeidentifyTextResponse, ReidentifyTextRequest, \ ReidentifyTextResponse, DeidentifyFileRequest, DeidentifyFileResponse, GetDetectRunRequest + class Detect: def __init__(self, vault_client): self.__vault_client = vault_client @@ -124,10 +126,22 @@ def output_to_dict_list(output): word_count = getattr(word_character_count, "word_count", None) char_count = getattr(word_character_count, "character_count", None) + base64_string = first_output.get("file", None) + extension = first_output.get("extension", None) + + file_obj = None + if base64_string is not None: + file_bytes = base64.b64decode(base64_string) + file_obj = io.BytesIO(file_bytes) + file_obj.name = f"deidentified.{extension}" if extension else "processed_file" + else: + file_obj = None + return DeidentifyFileResponse( - file=first_output.get("file", None), + file_base64=base64_string, + file=file_obj, # File class will be instantiated in DeidentifyFileResponse type=first_output.get("type", None), - extension=first_output.get("extension", None), + extension=extension, word_count=word_count, char_count=char_count, size_in_kb=size, @@ -216,16 +230,26 @@ def reidentify_text(self, request: ReidentifyTextRequest) -> ReidentifyTextRespo log_error_log(SkyflowMessages.ErrorLogs.REIDENTIFY_TEXT_REQUEST_REJECTED.value, self.__vault_client.get_logger()) handle_exception(e, self.__vault_client.get_logger()) + def __get_file_from_request(self, request: DeidentifyFileRequest): + file_input = request.file + + # Check for file + if hasattr(file_input, 'file') and file_input.file is not None: + return file_input.file + + # Check for file_path if file is not provided + if hasattr(file_input, 'file_path') and file_input.file_path is not None: + return open(file_input.file_path, 'rb') + def deidentify_file(self, request: DeidentifyFileRequest): log_info(SkyflowMessages.Info.DETECT_FILE_TRIGGERED.value, self.__vault_client.get_logger()) validate_deidentify_file_request(self.__vault_client.get_logger(), request) self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response - file_obj = request.file + file_obj = self.__get_file_from_request(request) file_name = getattr(file_obj, 'name', None) file_extension = self._get_file_extension(file_name) if file_name else None file_content = file_obj.read() - base64_string = base64.b64encode(file_content).decode('utf-8') try: @@ -375,7 +399,7 @@ def deidentify_file(self, request: DeidentifyFileRequest): file_name_only = 'processed-'+os.path.basename(file_name) output_file_path = f"{request.output_directory}/{file_name_only}" with open(output_file_path, 'wb') as output_file: - output_file.write(base64.b64decode(parsed_response.file)) + output_file.write(base64.b64decode(parsed_response.file_base64)) log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger()) return parsed_response diff --git a/skyflow/vault/detect/__init__.py b/skyflow/vault/detect/__init__.py index e385a1f2..bd09fed8 100644 --- a/skyflow/vault/detect/__init__.py +++ b/skyflow/vault/detect/__init__.py @@ -10,4 +10,5 @@ from ._deidentify_file_request import DeidentifyFileRequest from ._audio_bleep import Bleep from ._deidentify_file_response import DeidentifyFileResponse -from ._get_detect_run_request import GetDetectRunRequest \ No newline at end of file +from ._get_detect_run_request import GetDetectRunRequest +from ._file_input import FileInput \ No newline at end of file diff --git a/skyflow/vault/detect/_deidentify_file_request.py b/skyflow/vault/detect/_deidentify_file_request.py index a429f5d5..09d8b118 100644 --- a/skyflow/vault/detect/_deidentify_file_request.py +++ b/skyflow/vault/detect/_deidentify_file_request.py @@ -3,6 +3,7 @@ from skyflow.vault.detect import TokenFormat, Transformations from skyflow.vault.detect._audio_bleep import Bleep from skyflow.utils.enums import MaskingMethod, DetectOutputTranscriptions +from skyflow.vault.detect._file_input import FileInput class DeidentifyFileRequest: def __init__( @@ -24,7 +25,7 @@ def __init__( output_directory: Optional[str] = None, wait_time: Optional[Union[int, float]] = None ): - self.file: object = file + self.file: FileInput = file self.entities: Optional[List[DetectEntities]] = entities self.allow_regex_list: Optional[List[str]] = allow_regex_list self.restrict_regex_list: Optional[List[str]] = restrict_regex_list diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index f386080d..dfe1851a 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -1,7 +1,11 @@ +import io +from skyflow.vault.detect._file import File + class DeidentifyFileResponse: def __init__( self, - file: str = None, + file_base64: str = None, + file: io.BytesIO = None, type: str = None, extension: str = None, word_count: int = None, @@ -15,7 +19,8 @@ def __init__( status: str = None, errors: list = [], ): - self.file = file + self.file_base64 = file_base64 + self.file = File(file) if file else None self.type = type self.extension = extension self.word_count = word_count @@ -32,12 +37,12 @@ def __init__( def __repr__(self): return ( f"DeidentifyFileResponse(" - f"file={self.file!r}, type={self.type!r}, extension={self.extension!r}, " - f"word_count={self.word_count!r}, char_count={self.char_count!r}, " - f"size_in_kb={self.size_in_kb!r}, duration_in_seconds={self.duration_in_seconds!r}, " - f"page_count={self.page_count!r}, slide_count={self.slide_count!r}, " - f"entities={self.entities!r}, run_id={self.run_id!r}, status={self.status!r})," - f"errors={self.errors!r})" + f"file_base64={self.file_base64!r}, file={self.file!r}, type={self.type!r}, " + f"extension={self.extension!r}, word_count={self.word_count!r}, " + f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " + f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " + f"slide_count={self.slide_count!r}, entities={self.entities!r}, " + f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_file.py b/skyflow/vault/detect/_file.py new file mode 100644 index 00000000..ad188666 --- /dev/null +++ b/skyflow/vault/detect/_file.py @@ -0,0 +1,53 @@ +import io +import mimetypes +import time + +class File: + def __init__(self, file: io.BytesIO = None): + self.file = file + + @property + def name(self) -> str: + """Get file name""" + if self.file: + return getattr(self.file, 'name', 'unknown') + return None + + @property + def size(self) -> int: + """Get file size in bytes""" + if self.file: + pos = self.file.tell() + self.file.seek(0, io.SEEK_END) + size = self.file.tell() + self.file.seek(pos) + return size + return None + + @property + def type(self) -> str: + """Get file mime type""" + if self.file: + return mimetypes.guess_type(self.name)[0] or '' + return None + + @property + def last_modified(self) -> int: + """Get file last modified timestamp in milliseconds""" + if self.file: + return int(time.time() * 1000) + return None + + def seek(self, offset, whence=0): + if self.file: + return self.file.seek(offset, whence) + + def read(self, size=-1): + if self.file: + return self.file.read(size) + + def __repr__(self): + return ( + f"File(name={self.name!r}, size={self.size!r}, type={self.type!r}, " + f"last_modified={self.last_modified!r})" + ) diff --git a/skyflow/vault/detect/_file_input.py b/skyflow/vault/detect/_file_input.py new file mode 100644 index 00000000..472ca0e2 --- /dev/null +++ b/skyflow/vault/detect/_file_input.py @@ -0,0 +1,19 @@ +class FileInput: + """ + Represents a file input for the vault detection process. + + Attributes: + file (str): The file object to be processed. This can be a file-like object or a binary string. + file_path (str): The path to the file to be processed. + """ + + def __init__(self, file: str= None, file_path: str = None): + self.file = file + self.file_path = file_path + + def __repr__(self) -> str: + return f"FileInput(file={self.file!r}, file_path={self.file_path!r})" + + def __str__(self) -> str: + return self.__repr__() + \ No newline at end of file diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index 29db32dc..dbe92e93 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -6,8 +6,12 @@ from skyflow.utils import SkyflowMessages from skyflow.vault.controller import Detect from skyflow.vault.detect import DeidentifyTextRequest, ReidentifyTextRequest, \ - TokenFormat, DateTransformation, Transformations, DeidentifyFileRequest, GetDetectRunRequest, DeidentifyFileResponse + TokenFormat, DateTransformation, Transformations, DeidentifyFileRequest, GetDetectRunRequest, \ + DeidentifyFileResponse, FileInput from skyflow.utils.enums import DetectEntities, TokenType +import io + +from skyflow.vault.detect._file import File VAULT_ID = "test_vault_id" @@ -127,7 +131,7 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 file_obj.name = "/tmp/test.txt" mock_basename.return_value = "test.txt" mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -149,18 +153,38 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 with patch.object(self.detect, "_Detect__poll_for_processed_file", return_value=processed_response) as mock_poll, \ patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse(file="dGVzdCBjb250ZW50", type="txt", extension="txt", + return_value=DeidentifyFileResponse(file_base64="dGVzdCBjb250ZW50", + file=io.BytesIO(b"test content"), type="txt", + extension="txt", word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], run_id="runid123", status="SUCCESS", errors=[])) as mock_parse: result = self.detect.deidentify_file(req) + mock_validate.assert_called_once() files_api.deidentify_text.assert_called_once() mock_poll.assert_called_once() mock_parse.assert_called_once() + self.assertIsInstance(result, DeidentifyFileResponse) self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50") + self.assertEqual(result.type, "txt") + self.assertEqual(result.extension, "txt") + + self.assertIsInstance(result.file, File) + result.file.seek(0) + self.assertEqual(result.file.read(), b"test content") + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + self.assertEqual(result.size_in_kb, 1) + self.assertIsNone(result.duration_in_seconds) + self.assertIsNone(result.page_count) + self.assertIsNone(result.slide_count) + self.assertEqual(result.entities, []) + self.assertEqual(result.errors, []) @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") @@ -170,7 +194,7 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): file_obj.read.return_value = file_content file_obj.name = "audio.mp3" mock_base64.b64encode.return_value = b"YXVkaW8gYnl0ZXM=" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -192,7 +216,9 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): with patch.object(self.detect, "_Detect__poll_for_processed_file", return_value=processed_response) as mock_poll, \ patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse(file="YXVkaW8gYnl0ZXM=", type="mp3", extension="mp3", + return_value=DeidentifyFileResponse(file_base64="YXVkaW8gYnl0ZXM=", + file=io.BytesIO(b"audio bytes"), type="mp3", + extension="mp3", word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=1, page_count=None, slide_count=None, entities=[], run_id="runid456", @@ -234,11 +260,12 @@ def test_get_detect_run_success(self, mock_validate): response.word_character_count = Mock(word_count=1, character_count=1) files_api.get_run.return_value = response with patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse(file="file", type="txt", extension="txt", word_count=1, - char_count=1, size_in_kb=1, duration_in_seconds=None, - page_count=None, slide_count=None, entities=[], - run_id="runid789", status="SUCCESS", - errors=[])) as mock_parse: + return_value=DeidentifyFileResponse( + file="file", type="txt", extension="txt", word_count=1, + char_count=1, size_in_kb=1, duration_in_seconds=None, + page_count=None, slide_count=None, entities=[], + run_id="runid789", status="SUCCESS", + errors=[])) as mock_parse: result = self.detect.get_detect_run(req) mock_validate.assert_called_once() files_api.get_run.assert_called_once() @@ -262,11 +289,10 @@ def test_get_detect_run_exception(self, mock_validate): @patch("skyflow.vault.controller._detect.open", create=True) @patch.object(Detect, "_Detect__poll_for_processed_file") def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, mock_base64, mock_validate): - """Test all file type branches with optimized mocking""" - # Common mocks file_content = b"test content" mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50" + mock_base64.b64decode.return_value = file_content # Prepare a generic processed_response for all branches processed_response = Mock() @@ -283,69 +309,78 @@ def test_deidentify_file_all_branches(self, mock_poll, mock_open, mock_basename, processed_response.run_id = "runid123" mock_poll.return_value = processed_response - # Patch __parse_deidentify_file_response to return a valid DeidentifyFileResponse - with patch.object(self.detect, "_Detect__parse_deidentify_file_response", - return_value=DeidentifyFileResponse( - file="dGVzdCBjb250ZW50", type="pdf", extension="pdf", - word_count=1, char_count=1, size_in_kb=1, - duration_in_seconds=1, page_count=1, slide_count=1, - entities=[], run_id="runid123", status="SUCCESS", errors=[] - )) as mock_parse: - # Test configuration for different file types - test_cases = [ - ("test.pdf", "pdf", "deidentify_pdf"), - ("test.jpg", "jpg", "deidentify_image"), - ("test.pptx", "pptx", "deidentify_presentation"), - ("test.csv", "csv", "deidentify_spreadsheet"), - ("test.docx", "docx", "deidentify_document"), - ("test.json", "json", "deidentify_structured_text"), - ("test.xml", "xml", "deidentify_structured_text"), - ("test.unknown", "unknown", "deidentify_file") - ] + # Test configuration for different file types + test_cases = [ + ("test.pdf", "pdf", "deidentify_pdf"), + ("test.jpg", "jpg", "deidentify_image"), + ("test.pptx", "pptx", "deidentify_presentation"), + ("test.csv", "csv", "deidentify_spreadsheet"), + ("test.docx", "docx", "deidentify_document"), + ("test.json", "json", "deidentify_structured_text"), + ("test.xml", "xml", "deidentify_structured_text"), + ("test.unknown", "unknown", "deidentify_file") + ] + + for file_name, extension, api_method in test_cases: + with self.subTest(file_type=extension): + # Setup file mock + file_obj = Mock() + file_obj.read.return_value = file_content + file_obj.name = file_name + mock_basename.return_value = file_name + + # Setup request with FileInput + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) + req.entities = [] + req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = "/tmp" + + # Setup API mock + files_api = Mock() + files_api.with_raw_response = files_api + api_method_mock = Mock() + setattr(files_api, api_method, api_method_mock) + self.vault_client.get_detect_file_api.return_value = files_api + + # Setup API response + api_response = Mock() + api_response.data = Mock(run_id="runid123") + api_method_mock.return_value = api_response + + # Actually run the method + result = self.detect.deidentify_file(req) + + # Verify the result + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50") + self.assertIsInstance(result.file, File) + result.file.seek(0) # Reset file pointer before reading + self.assertEqual(result.file.read(), b"test content") + self.assertEqual(result.type, "pdf") + self.assertEqual(result.extension, "pdf") + self.assertEqual(result.size_in_kb, 1) + self.assertEqual(result.duration_in_seconds, 1) + self.assertEqual(result.page_count, 1) + self.assertEqual(result.slide_count, 1) + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + + # Verify API was called + api_method_mock.assert_called_once() + mock_poll.assert_called_with("runid123", None) - for file_name, extension, api_method in test_cases: - with self.subTest(file_type=extension): - # Setup file mock - file_obj = Mock() - file_obj.read.return_value = file_content - file_obj.name = file_name - mock_basename.return_value = file_name - - # Setup request - req = DeidentifyFileRequest(file=file_obj) - req.entities = [] - req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) - req.allow_regex_list = [] - req.restrict_regex_list = [] - req.transformations = None - req.output_directory = "/tmp" - - # Setup API mock - files_api = Mock() - files_api.with_raw_response = files_api - api_method_mock = Mock() - setattr(files_api, api_method, api_method_mock) - self.vault_client.get_detect_file_api.return_value = files_api - - # Setup API response - api_response = Mock() - api_response.data = Mock(run_id="runid123") - api_method_mock.return_value = api_response - - # Actually run the method - result = self.detect.deidentify_file(req) - self.assertIsInstance(result, DeidentifyFileResponse) - self.assertEqual(result.status, "SUCCESS") - self.assertEqual(result.file, "dGVzdCBjb250ZW50") - self.assertEqual(result.type, "pdf") - self.assertEqual(result.extension, "pdf") @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") def test_deidentify_file_exception(self, mock_base64, mock_validate): file_obj = Mock() file_obj.read.side_effect = Exception("Read error") file_obj.name = "test.txt" - req = DeidentifyFileRequest(file=file_obj) + req = DeidentifyFileRequest(file=FileInput(file=file_obj)) req.entities = [] req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) req.allow_regex_list = [] @@ -404,8 +439,8 @@ def test_parse_deidentify_file_response_dict_and_obj(self): # Dict input data = { "output": [ - {"processedFile": "abc", "processedFileType": "pdf", "processedFileExtension": "pdf"}, - {"processedFile": "def", "processedFileType": "entities", "processedFileExtension": "json"} + {"processedFile": "YWJj", "processedFileType": "pdf", "processedFileExtension": "pdf"}, # base64 for "abc" + {"processedFile": "ZGVm", "processedFileType": "entities", "processedFileExtension": "json"} # base64 for "def" ], "word_character_count": {"word_count": 5, "character_count": 10}, "size": 1, @@ -426,9 +461,9 @@ class DummyWordChar: class DummyData: output = [ type("O", (), - {"processed_file": "abc", "processed_file_type": "pdf", "processed_file_extension": "pdf"})(), + {"processed_file": "YWJj", "processed_file_type": "pdf", "processed_file_extension": "pdf"})(), type("O", (), - {"processed_file": "def", "processed_file_type": "entities", "processed_file_extension": "json"})() + {"processed_file": "ZGVm", "processed_file_type": "entities", "processed_file_extension": "json"})() ] word_character_count = DummyWordChar() size = 1 @@ -441,7 +476,9 @@ class DummyData: obj_data = DummyData() result = self.detect._Detect__parse_deidentify_file_response(obj_data, "runid", "SUCCESS") self.assertIsInstance(result, DeidentifyFileResponse) - + self.assertEqual(result.file_base64, "YWJj") + self.assertIsInstance(result.file, File) + self.assertEqual(result.file.read(), b"abc") def test_get_token_format_missing_attribute(self): """Test __get_token_format when token_format attribute is missing""" class DummyRequest: @@ -559,12 +596,11 @@ def track_sleep(*args): self.assertEqual(calls, [2, 2]) self.assertEqual(result.status, "SUCCESS") - def test_parse_deidentify_file_response_output_conversion(self): """Test output conversion in parse_deidentify_file_response""" class OutputObj: - processed_file = "file123" + processed_file = "YWJjMTIz" # base64 for "abc123" processed_file_type = "pdf" processed_file_extension = "pdf" @@ -574,6 +610,103 @@ class OutputObj: result = self.detect._Detect__parse_deidentify_file_response(data) - self.assertEqual(result.file, "file123") + # Check base64 string + self.assertEqual(result.file_base64, "YWJjMTIz") + # Check File object + self.assertIsInstance(result.file, File) + self.assertEqual(result.file.read(), b"abc123") + # Check other attributes self.assertEqual(result.type, "pdf") - self.assertEqual(result.extension, "pdf") \ No newline at end of file + self.assertEqual(result.extension, "pdf") + # Reset file pointer and verify content again + result.file.seek(0) + self.assertEqual(result.file.read(), b"abc123") + + @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") + @patch("skyflow.vault.controller._detect.base64") + @patch("skyflow.vault.controller._detect.os.path.basename") + @patch("skyflow.vault.controller._detect.open", create=True) + def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_base64, mock_validate): + # Setup mock file context + mock_file = MagicMock() + mock_file.read.return_value = b"test content from file path" + mock_file.name = "/path/to/test.txt" + mock_file.__enter__.return_value = mock_file # Mock context manager + mock_open.return_value = mock_file + mock_basename.return_value = "test.txt" + mock_base64.b64encode.return_value = b"dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo" # base64 of "test content from file path" + mock_base64.b64decode.return_value = b"test content from file path" + # Create request with file_path + req = DeidentifyFileRequest(file=FileInput(file_path="/path/to/test.txt")) + req.entities = [] + req.token_format = Mock(default="default", entity_unique_counter=[], entity_only=[]) + req.allow_regex_list = [] + req.restrict_regex_list = [] + req.transformations = None + req.output_directory = "/tmp" + + # Setup API mock + files_api = Mock() + files_api.with_raw_response = files_api + files_api.deidentify_text = Mock() + self.vault_client.get_detect_file_api.return_value = files_api + api_response = Mock() + api_response.data = Mock(run_id="runid123") + files_api.deidentify_text.return_value = api_response + + # Setup processed response + processed_response = Mock() + processed_response.status = "SUCCESS" + processed_response.output = [] + processed_response.word_character_count = Mock(word_count=1, character_count=1) + + # Test the method + with patch.object(self.detect, "_Detect__poll_for_processed_file", + return_value=processed_response) as mock_poll, \ + patch.object(self.detect, "_Detect__parse_deidentify_file_response", + return_value=DeidentifyFileResponse( + file_base64="dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo", + file=io.BytesIO(b"test content from file path"), + type="txt", + extension="txt", + word_count=1, + char_count=1, + size_in_kb=1, + duration_in_seconds=None, + page_count=None, + slide_count=None, + entities=[], + run_id="runid123", + status="SUCCESS", + errors=[] + )) as mock_parse: + + result = self.detect.deidentify_file(req) + + mock_file.read.assert_called_once() + mock_basename.assert_called_with("/path/to/test.txt") + + mock_validate.assert_called_once() + files_api.deidentify_text.assert_called_once() + mock_poll.assert_called_once() + mock_parse.assert_called_once() + + # Response assertions + self.assertIsInstance(result, DeidentifyFileResponse) + self.assertEqual(result.status, "SUCCESS") + self.assertEqual(result.run_id, "runid123") + self.assertEqual(result.file_base64, "dGVzdCBjb250ZW50IGZyb20gZmlsZSBwYXRo") + self.assertEqual(result.type, "txt") + self.assertEqual(result.extension, "txt") + + self.assertIsInstance(result.file, File) + result.file.seek(0) + self.assertEqual(result.file.read(), b"test content from file path") + self.assertEqual(result.word_count, 1) + self.assertEqual(result.char_count, 1) + self.assertEqual(result.size_in_kb, 1) + self.assertIsNone(result.duration_in_seconds) + self.assertIsNone(result.page_count) + self.assertIsNone(result.slide_count) + self.assertEqual(result.entities, []) + self.assertEqual(result.errors, []) \ No newline at end of file