diff --git a/README.md b/README.md index 39a58429..67b0d1c9 100644 --- a/README.md +++ b/README.md @@ -215,7 +215,7 @@ table_name = '' # Replace with your actual table name # Create Insert Request insert_request = InsertRequest( - table_name=table_name, + table=table_name, values=insert_data, return_tokens=True, # Optional: Get tokens for inserted data continue_on_error=True # Optional: Continue on partial errors @@ -273,7 +273,7 @@ options = InsertOptions( ```python insert_request = InsertRequest( - table_name=table_name, # Replace with the table name + table=table_name, # Replace with the table name values=insert_data, return_tokens=False, # Do not return tokens continue_on_error=False, # Stop inserting if any record fails @@ -474,7 +474,7 @@ try: # Step 2: Create Insert Request insert_request = InsertRequest( - table_name='table1', # Specify the table in the vault where the data will be inserted + table='table1', # Specify the table in the vault where the data will be inserted values=insert_data, # Attach the data (records) to be inserted return_tokens=True, # Specify if tokens should be returned upon successful insertion continue_on_error=True # Optional: Continue on partial errors @@ -551,7 +551,7 @@ try: # Step 2: Build an InsertRequest object with the table name and the data to insert insert_request = InsertRequest( - table_name='', # Replace with the actual table name in your Skyflow vault + table='', # Replace with the actual table name in your Skyflow vault values=insert_data, # Attach the data to be inserted ) @@ -608,7 +608,7 @@ try: # Step 4: Build the InsertRequest object with the data records to insert insert_request = InsertRequest( - table_name='table1', # Specify the table in the vault where the data will be inserted + table='table1', # Specify the table in the vault where the data will be inserted values=insert_data, # Attach the data (records) to be inserted return_tokens=True, # Specify if tokens should be returned upon successful insertion continue_on_error=True # Specify to continue inserting records even if an error occurs for some records @@ -686,7 +686,7 @@ try: # Step 3: Build the InsertRequest object with the upsertData insert_request = InsertRequest( - table_name='table1', # Specify the table in the vault where the data will be inserted + table='table1', # Specify the table in the vault where the data will be inserted values=insert_data, # Attach the data (records) to be inserted return_tokens=True, # Specify if tokens should be returned upon successful insertion upsert='cardholder_name' # Specify the field to be used for upsert operations (e.g., cardholder_name) @@ -1897,23 +1897,24 @@ ReidentifyTextResponse( ``` ### Deidentify File -To deidentify files, use the `deidentify_file` method. The `DeidentifyFileRequest` class creates a deidentify file request, which includes the file to be deidentified and various configuration options. +To deidentify files, use the `deidentify_file` method. The `DeidentifyFileRequest` class creates a deidentify file request, supports providing either a file or a file path in class FileInput for de-identification, along with various configuration options. #### Construct a Deidentify File request ```python from skyflow.error import SkyflowError from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, Bleep +from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, Bleep, FileInput """ This example demonstrates how to deidentify file, along with corresponding DeidentifyFileRequest schema. """ try: # Initialize Skyflow client # Step 1: Open file for deidentification - file = open('', 'rb') # Open the file in read-binary mode + file_path="" + file = open(file_path, 'rb') # Open the file in read-binary mode # Step 2: Create deidentify file request request = DeidentifyFileRequest( - file=file, # File object to deidentify + file=FileInput(file), # File to de-identify (can also provide a file path) entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect # Token format configuration @@ -1971,7 +1972,7 @@ except Exception as error: ```python from skyflow.error import SkyflowError from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Bleep +from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Bleep, FileInput """ * Skyflow Deidentify File Example * @@ -1985,7 +1986,7 @@ try: file = open('sensitive_document.txt', 'rb') # Open the file in read-binary mode # Step 2: Create deidentify file request request = DeidentifyFileRequest( - file=file, # File object to deidentify + file=FileInput(file), # File to de-identify (can also provide a file path) entities=[ DetectEntities.SSN, DetectEntities.CREDIT_CARD @@ -2038,7 +2039,6 @@ DeidentifyFileResponse( ], run_id='83abcdef-2b61-4a83-a4e0-cbc71ffabffd', status='SUCCESS', - errors=[] ) ``` @@ -2121,7 +2121,7 @@ except Exception as error: print('Unexpected Error:', error) # Print the stack trace for debugging purposes ``` -Sample Response +Sample Response: ```python DeidentifyFileResponse( file='TXkgY2FyZCBudW1iZXIgaXMgW0NSRURJVF9DQVJEXQpteSBzZWNvbmQ…', # Base64 encoded file content @@ -2142,7 +2142,26 @@ DeidentifyFileResponse( ], run_id='48ec05ba-96ec-4641-a8e2-35e066afef95', status='SUCCESS', - errors=[] +) +``` + +Incase of invalid/expired RunId: + +```python +DeidentifyFileResponse( + file_base64=None, + file=None, + type='UNKNOWN', + extension=None, + word_count=None, + char_count=None, + size_in_kb=0.0, + duration_in_seconds=None, + page_count=None, + slide_count=None, + entities=[], + run_id='1e9f321f-dd51-4ab1-a014-21212fsdfsd', + status='UNKNOWN' ) ``` diff --git a/samples/detect_api/deidentify_file.py b/samples/detect_api/deidentify_file.py index c9877d58..99b4b26e 100644 --- a/samples/detect_api/deidentify_file.py +++ b/samples/detect_api/deidentify_file.py @@ -1,7 +1,7 @@ from skyflow.error import SkyflowError from skyflow import Env, Skyflow, LogLevel from skyflow.utils.enums import DetectEntities, MaskingMethod, DetectOutputTranscriptions -from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep +from skyflow.vault.detect import DeidentifyFileRequest, TokenFormat, Transformations, DateTransformation, Bleep, FileInput """ * Skyflow Deidentify File Example @@ -39,7 +39,7 @@ def perform_file_deidentification(): file = open(file_path, 'rb') # Step 5: Configure Deidentify File Request with all options deidentify_request = DeidentifyFileRequest( - file=file, # File object to deidentify + file=FileInput(file), # File to de-identify (can also provide a file path) entities=[DetectEntities.SSN, DetectEntities.CREDIT_CARD], # Entities to detect allow_regex_list=[''], # Optional: Patterns to allow restrict_regex_list=[''], # Optional: Patterns to restrict diff --git a/samples/vault_api/insert_byot.py b/samples/vault_api/insert_byot.py index ae4c1eae..5161f886 100644 --- a/samples/vault_api/insert_byot.py +++ b/samples/vault_api/insert_byot.py @@ -70,7 +70,7 @@ def perform_secure_data_insertion_with_byot(): ] insert_request = InsertRequest( - table_name=table_name, + table=table_name, values=insert_data, token_mode=TokenMode.ENABLE, # Enable Bring Your Own Token (BYOT) tokens=tokens, # Specify tokens to use for BYOT diff --git a/samples/vault_api/insert_records.py b/samples/vault_api/insert_records.py index 32ec1fae..76ec2259 100644 --- a/samples/vault_api/insert_records.py +++ b/samples/vault_api/insert_records.py @@ -47,7 +47,7 @@ def perform_secure_data_insertion(): # Step 5: Create Insert Request insert_request = InsertRequest( - table_name=table_name, + table=table_name, values=insert_data, return_tokens=True, # Optional: Get tokens for inserted data continue_on_error=True # Optional: Continue on partial errors diff --git a/skyflow/utils/_skyflow_messages.py b/skyflow/utils/_skyflow_messages.py index 460ca29e..e92d251c 100644 --- a/skyflow/utils/_skyflow_messages.py +++ b/skyflow/utils/_skyflow_messages.py @@ -374,6 +374,7 @@ class ErrorLogs(Enum): DEIDENTIFY_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." DETECT_RUN_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Detect get run resulted in failure." DEIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify text resulted in failure." + SAVING_DEIDENTIFY_FILE_FAILED = f"{ERROR}: [{error_prefix}] Error while saving deidentified file to output directory." REIDENTIFY_TEXT_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Reidentify text resulted in failure." DETECT_FILE_REQUEST_REJECTED = f"{ERROR}: [{error_prefix}] Deidentify file resulted in failure." diff --git a/skyflow/utils/_utils.py b/skyflow/utils/_utils.py index 77ffe580..114079b5 100644 --- a/skyflow/utils/_utils.py +++ b/skyflow/utils/_utils.py @@ -447,3 +447,6 @@ def encode_column_values(get_request): encoded_column_values.append(quote(column)) return encoded_column_values + +def get_attribute(obj, camel_case, snake_case): + return getattr(obj, camel_case, None) or getattr(obj, snake_case, None) diff --git a/skyflow/utils/validations/_validations.py b/skyflow/utils/validations/_validations.py index bbca6e85..a7840f07 100644 --- a/skyflow/utils/validations/_validations.py +++ b/skyflow/utils/validations/_validations.py @@ -276,7 +276,7 @@ def validate_file_from_request(file_input: FileInput): raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_TYPE.value, invalid_input_error_code) # Validate file name - file_name = os.path.splitext(file.name)[0] + file_name, _ = os.path.splitext(os.path.basename(file.name)) if not file_name or not file_name.strip(): raise SkyflowError(SkyflowMessages.Error.INVALID_FILE_NAME.value, invalid_input_error_code) @@ -393,10 +393,10 @@ def validate_deidentify_file_request(logger, request: DeidentifyFileRequest): raise SkyflowError(SkyflowMessages.Error.WAIT_TIME_GREATER_THEN_64.value, invalid_input_error_code) def validate_insert_request(logger, request): - if not isinstance(request.table_name, str): + if not isinstance(request.table, str): log_error_log(SkyflowMessages.ErrorLogs.TABLE_IS_REQUIRED.value.format("INSERT"), logger = logger) raise SkyflowError(SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) - if not request.table_name.strip(): + if not request.table.strip(): log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TABLE_NAME.value.format("INSERT"), logger = logger) raise SkyflowError(SkyflowMessages.Error.MISSING_TABLE_NAME_IN_INSERT.value, invalid_input_error_code) diff --git a/skyflow/vault/controller/_detect.py b/skyflow/vault/controller/_detect.py index 93fac69e..62d551c1 100644 --- a/skyflow/vault/controller/_detect.py +++ b/skyflow/vault/controller/_detect.py @@ -6,8 +6,9 @@ from skyflow.generated.rest import DeidentifyTextRequestFile, DeidentifyAudioRequestFile, DeidentifyPdfRequestFile, \ DeidentifyImageRequestFile, DeidentifyPresentationRequestFile, DeidentifySpreadsheetRequestFile, \ DeidentifyDocumentRequestFile, DeidentifyFileRequestFile +from skyflow.generated.rest.types.deidentify_status_response import DeidentifyStatusResponse from skyflow.utils._skyflow_messages import SkyflowMessages -from skyflow.utils._utils import get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response +from skyflow.utils._utils import get_attribute, get_metrics, handle_exception, parse_deidentify_text_response, parse_reidentify_text_response from skyflow.utils.constants import SKY_META_DATA_HEADER from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_deidentify_file_request, validate_get_detect_run_request @@ -83,6 +84,43 @@ def __poll_for_processed_file(self, run_id, max_wait_time=64): except Exception as e: raise e + def __save_deidentify_file_response_output(self, response: DeidentifyStatusResponse, output_directory: str, original_file_name: str, name_without_ext: str): + if not response or not hasattr(response, 'output') or not response.output or not output_directory: + return + + if not os.path.exists(output_directory): + return + + deidentify_file_prefix = "processed-" + output_list = response.output + + base_original_filename = os.path.basename(original_file_name) + base_name_without_ext = os.path.splitext(base_original_filename)[0] + + for idx, output in enumerate(output_list): + try: + processed_file = get_attribute(output, 'processedFile', 'processed_file') + processed_file_type = get_attribute(output, 'processedFileType', 'processed_file_type') + processed_file_extension = get_attribute(output, 'processedFileExtension', 'processed_file_extension') + + if not processed_file: + continue + + decoded_data = base64.b64decode(processed_file) + + if idx == 0 or processed_file_type == 'redacted_file': + output_file_name = os.path.join(output_directory, deidentify_file_prefix + base_original_filename) + if processed_file_extension: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + else: + output_file_name = os.path.join(output_directory, f"{deidentify_file_prefix}{base_name_without_ext}.{processed_file_extension}") + + with open(output_file_name, 'wb') as f: + f.write(decoded_data) + except Exception as e: + log_error_log(SkyflowMessages.ErrorLogs.SAVING_DEIDENTIFY_FILE_FAILED.value, self.__vault_client.get_logger()) + handle_exception(e, self.__vault_client.get_logger()) + def __parse_deidentify_file_response(self, data, run_id=None, status=None): output = getattr(data, "output", []) status_val = getattr(data, "status", None) or status @@ -141,8 +179,8 @@ def output_to_dict_list(output): return DeidentifyFileResponse( file_base64=base64_string, - file=file_obj, # File class will be instantiated in DeidentifyFileResponse - type=first_output.get("type", None), + file=file_obj, + type=first_output.get("type", "UNKNOWN"), extension=extension, word_count=word_count, char_count=char_count, @@ -153,7 +191,6 @@ def output_to_dict_list(output): entities=entities, run_id=run_id_val, status=status_val, - errors=None ) def __get_token_format(self, request): @@ -396,12 +433,11 @@ def deidentify_file(self, request: DeidentifyFileRequest): run_id = getattr(api_response.data, 'run_id', None) processed_response = self.__poll_for_processed_file(run_id, request.wait_time) - parsed_response = self.__parse_deidentify_file_response(processed_response, run_id) if request.output_directory and processed_response.status == 'SUCCESS': - file_name_only = 'processed-'+os.path.basename(file_name) - output_file_path = f"{request.output_directory}/{file_name_only}" - with open(output_file_path, 'wb') as output_file: - output_file.write(base64.b64decode(parsed_response.file_base64)) + name_without_ext, _ = os.path.splitext(file_name) + self.__save_deidentify_file_response_output(processed_response, request.output_directory, file_name, name_without_ext) + + parsed_response = self.__parse_deidentify_file_response(processed_response, run_id) log_info(SkyflowMessages.Info.DETECT_FILE_SUCCESS.value, self.__vault_client.get_logger()) return parsed_response @@ -411,9 +447,9 @@ def deidentify_file(self, request: DeidentifyFileRequest): handle_exception(e, self.__vault_client.get_logger()) def get_detect_run(self, request: GetDetectRunRequest): + log_info(SkyflowMessages.Info.GET_DETECT_RUN_TRIGGERED.value,self.__vault_client.get_logger()) log_info(SkyflowMessages.Info.VALIDATING_GET_DETECT_RUN_INPUT.value, self.__vault_client.get_logger()) validate_get_detect_run_request(self.__vault_client.get_logger(), request) - log_info(SkyflowMessages.Info.DEIDENTIFY_TEXT_REQUEST_RESOLVED.value, self.__vault_client.get_logger()) self.__initialize() files_api = self.__vault_client.get_detect_file_api().with_raw_response @@ -428,6 +464,7 @@ def get_detect_run(self, request: GetDetectRunRequest): parsed_response = self.__parse_deidentify_file_response(DeidentifyFileResponse(run_id=run_id, status='IN_PROGRESS')) else: parsed_response = self.__parse_deidentify_file_response(response.data, run_id, response.data.status) + log_info(SkyflowMessages.Info.GET_DETECT_RUN_SUCCESS.value,self.__vault_client.get_logger()) return parsed_response except Exception as e: log_error_log(SkyflowMessages.ErrorLogs.DETECT_FILE_REQUEST_REJECTED.value, diff --git a/skyflow/vault/controller/_vault.py b/skyflow/vault/controller/_vault.py index 4602cf87..7a288724 100644 --- a/skyflow/vault/controller/_vault.py +++ b/skyflow/vault/controller/_vault.py @@ -6,6 +6,7 @@ parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values, get_metrics from skyflow.utils.constants import SKY_META_DATA_HEADER from skyflow.utils.enums import RequestMethod +from skyflow.utils.enums.redaction_type import RedactionType from skyflow.utils.logger import log_info, log_error_log from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \ validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request @@ -53,7 +54,7 @@ def __build_insert_body(self, request: InsertRequest): records_list = self.__build_batch_field_records( request.values, request.tokens, - request.table_name, + request.table, request.return_tokens, request.upsert ) @@ -85,7 +86,7 @@ def insert(self, request: InsertRequest): else: api_response = records_api.record_service_insert_record(self.__vault_client.get_vault_id(), - request.table_name, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) + request.table, records=insert_body,tokenization= request.return_tokens, upsert=request.upsert, homogeneous=request.homogeneous, byot=request.token_mode.value, request_options=self.__get_headers()) insert_response = parse_insert_response(api_response, request.continue_on_error) log_info(SkyflowMessages.Info.INSERT_SUCCESS.value, self.__vault_client.get_logger()) @@ -201,7 +202,7 @@ def detokenize(self, request: DetokenizeRequest): tokens_list = [ V1DetokenizeRecordRequest( token=item.get('token'), - redaction=item.get('redaction', None) + redaction=item.get('redaction', RedactionType.DEFAULT) ) for item in request.data ] diff --git a/skyflow/vault/data/_insert_request.py b/skyflow/vault/data/_insert_request.py index 742c5120..909edd88 100644 --- a/skyflow/vault/data/_insert_request.py +++ b/skyflow/vault/data/_insert_request.py @@ -2,7 +2,7 @@ class InsertRequest: def __init__(self, - table_name, + table, values, tokens = None, upsert = None, @@ -10,7 +10,7 @@ def __init__(self, token_mode = TokenMode.DISABLE, return_tokens = True, continue_on_error = False): - self.table_name = table_name + self.table = table self.values = values self.tokens = tokens self.upsert = upsert diff --git a/skyflow/vault/detect/_deidentify_file_response.py b/skyflow/vault/detect/_deidentify_file_response.py index 90a0d493..b340e21c 100644 --- a/skyflow/vault/detect/_deidentify_file_response.py +++ b/skyflow/vault/detect/_deidentify_file_response.py @@ -17,7 +17,6 @@ def __init__( entities: list = None, # list of dicts with keys 'file' and 'extension' run_id: str = None, status: str = None, - errors: list = None, ): self.file_base64 = file_base64 self.file = File(file) if file else None @@ -32,7 +31,6 @@ def __init__( self.entities = entities if entities is not None else [] self.run_id = run_id self.status = status - self.errors = errors def __repr__(self): return ( @@ -42,7 +40,7 @@ def __repr__(self): f"char_count={self.char_count!r}, size_in_kb={self.size_in_kb!r}, " f"duration_in_seconds={self.duration_in_seconds!r}, page_count={self.page_count!r}, " f"slide_count={self.slide_count!r}, entities={self.entities!r}, " - f"run_id={self.run_id!r}, status={self.status!r}, errors={self.errors!r})" + f"run_id={self.run_id!r}, status={self.status!r})" ) def __str__(self): diff --git a/skyflow/vault/detect/_file_input.py b/skyflow/vault/detect/_file_input.py index 472ca0e2..6b8bc2fb 100644 --- a/skyflow/vault/detect/_file_input.py +++ b/skyflow/vault/detect/_file_input.py @@ -1,13 +1,15 @@ +from io import BufferedReader + class FileInput: """ Represents a file input for the vault detection process. Attributes: - file (str): The file object to be processed. This can be a file-like object or a binary string. + file (BufferedReader): The file object to be processed. This can be a file-like object or a binary string. file_path (str): The path to the file to be processed. """ - def __init__(self, file: str= None, file_path: str = None): + def __init__(self, file: BufferedReader= None, file_path: str = None): self.file = file self.file_path = file_path diff --git a/skyflow/vault/tokens/_tokenize_response.py b/skyflow/vault/tokens/_tokenize_response.py index 264b3987..598c2a1c 100644 --- a/skyflow/vault/tokens/_tokenize_response.py +++ b/skyflow/vault/tokens/_tokenize_response.py @@ -1,10 +1,11 @@ class TokenizeResponse: - def __init__(self, tokenized_fields = None): + def __init__(self, tokenized_fields = None, errors = None): self.tokenized_fields = tokenized_fields + self.errors = errors def __repr__(self): - return f"TokenizeResponse(tokenized_fields={self.tokenized_fields})" + return f"TokenizeResponse(tokenized_fields={self.tokenized_fields}, errors={self.errors})" def __str__(self): return self.__repr__() diff --git a/tests/utils/validations/__init__.py b/tests/utils/validations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/validations/test__validations.py b/tests/utils/validations/test__validations.py new file mode 100644 index 00000000..48332a55 --- /dev/null +++ b/tests/utils/validations/test__validations.py @@ -0,0 +1,1046 @@ +import unittest +from unittest.mock import Mock, patch, MagicMock +import tempfile +import os + +from skyflow.error import SkyflowError +from skyflow.utils.validations._validations import ( + validate_required_field, validate_api_key, validate_credentials, + validate_log_level, validate_keys, validate_vault_config, + validate_update_vault_config, validate_connection_config, + validate_update_connection_config, validate_file_from_request, + validate_insert_request, validate_delete_request, validate_query_request, + validate_get_detect_run_request, validate_get_request, validate_update_request, + validate_detokenize_request, validate_tokenize_request, validate_invoke_connection_params, + validate_deidentify_text_request, validate_reidentify_text_request, validate_deidentify_file_request +) +from skyflow.utils import SkyflowMessages +from skyflow.utils.enums import DetectEntities, RedactionType +from skyflow.vault.data import GetRequest, UpdateRequest +from skyflow.vault.detect import DeidentifyTextRequest, Transformations, DateTransformation, ReidentifyTextRequest, \ + FileInput, DeidentifyFileRequest +from skyflow.vault.tokens import DetokenizeRequest +from skyflow.vault.connection._invoke_connection_request import InvokeConnectionRequest + +class TestValidations(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.temp_file = tempfile.NamedTemporaryFile(delete=False) + cls.temp_file.write(b"test content") + cls.temp_file.close() + cls.temp_file_path = cls.temp_file.name + cls.temp_dir = tempfile.TemporaryDirectory() + cls.temp_dir_path = cls.temp_dir.name + + @classmethod + def tearDownClass(cls): + if os.path.exists(cls.temp_file_path): + os.unlink(cls.temp_file_path) + cls.temp_dir.cleanup() + + def setUp(self): + self.logger = Mock() + + def test_validate_required_field_valid(self): + config = {"test_field": "test_value"} + validate_required_field( + self.logger, + config, + "test_field", + str, + "Empty error", + "Invalid error" + ) + + def test_validate_required_field_missing(self): + config = {} + with self.assertRaises(SkyflowError) as context: + validate_required_field( + self.logger, + config, + "vault_id", + str, + "Empty error", + "Invalid error" + ) + self.assertEqual(context.exception.message, "Invalid error") + + def test_validate_required_field_empty_string(self): + config = {"test_field": ""} + with self.assertRaises(SkyflowError) as context: + validate_required_field( + self.logger, + config, + "test_field", + str, + "Empty error", + "Invalid error" + ) + self.assertEqual(context.exception.message, "Empty error") + + def test_validate_required_field_wrong_type(self): + config = {"test_field": 123} + with self.assertRaises(SkyflowError) as context: + validate_required_field( + self.logger, + config, + "test_field", + str, + "Empty error", + "Invalid error" + ) + self.assertEqual(context.exception.message, "Invalid error") + + def test_validate_api_key_valid(self): + valid_key = "sky-abc12-1234567890abcdef1234567890abcdef" + self.assertTrue(validate_api_key(valid_key, self.logger)) + + def test_validate_api_key_invalid_prefix(self): + invalid_key = "invalid-abc12-1234567890abcdef1234567890abcdef" + self.assertFalse(validate_api_key(invalid_key, self.logger)) + + def test_validate_api_key_invalid_length(self): + invalid_key = "sky-abc12-123456" + self.assertFalse(validate_api_key(invalid_key, self.logger)) + + def test_validate_credentials_with_api_key(self): + credentials = { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + validate_credentials(self.logger, credentials) + + def test_validate_credentials_with_expired_token(self): + credentials = { + "token": "expired_token" + } + with patch('skyflow.service_account.is_expired', return_value=True): + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS_TOKEN.value) + + def test_validate_credentials_empty_credentials(self): + credentials = {} + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value) + + def test_validate_credentials_multiple_auth_methods(self): + credentials = { + "token": "valid_token", + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.MULTIPLE_CREDENTIALS_PASSED.value) + + + def test_validate_credentials_with_empty_context(self): + credentials = { + "token": "valid_token", + "context": "" + } + with patch('skyflow.service_account.is_expired', return_value=False): + with self.assertRaises(SkyflowError) as context: + validate_credentials(self.logger, credentials) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONTEXT.value) + + def test_validate_log_level_valid(self): + from skyflow.utils.enums import LogLevel + log_level = LogLevel.ERROR + validate_log_level(self.logger, log_level) + + def test_validate_log_level_invalid(self): + class InvalidEnum: + pass + invalid_log_level = InvalidEnum() + with self.assertRaises(SkyflowError) as context: + validate_log_level(self.logger, invalid_log_level) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_LOG_LEVEL.value) + + def test_validate_log_level_none(self): + with self.assertRaises(SkyflowError) as context: + validate_log_level(self.logger, None) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_LOG_LEVEL.value) + + def test_validate_keys_valid(self): + config = {"vault_id": "test_id", "cluster_id": "test_cluster"} + validate_keys(self.logger, config, ["vault_id", "cluster_id"]) + + def test_validate_keys_invalid(self): + config = {"invalid_key": "value"} + with self.assertRaises(SkyflowError) as context: + validate_keys(self.logger, config, ["vault_id", "cluster_id"]) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_KEY.value.format("invalid_key")) + + def test_validate_vault_config_valid(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + }, + "env": Env.DEV + } + self.assertTrue(validate_vault_config(self.logger, config)) + + def test_validate_vault_config_missing_required(self): + config = { + "cluster_id": "cluster123" + } + with self.assertRaises(SkyflowError) as context: + validate_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_VAULT_ID.value) + + + def test_validate_update_vault_config_valid(self): + from skyflow.utils.enums import Env + config = { + "vault_id": "vault123", + "cluster_id": "cluster123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + }, + "env": Env.DEV + } + self.assertTrue(validate_update_vault_config(self.logger, config)) + + def test_validate_update_vault_config_missing_credentials(self): + config = { + "vault_id": "vault123", + "cluster_id": "cluster123" + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("vault", "vault123")) + + def test_validate_update_vault_config_invalid_cluster_id(self): + config = { + "vault_id": "vault123", + "cluster_id": "", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_vault_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CLUSTER_ID.value.format("vault123")) + + def test_validate_connection_config_valid(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + self.assertTrue(validate_connection_config(self.logger, config)) + + def test_validate_connection_config_missing_url(self): + config = { + "connection_id": "conn123", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONNECTION_URL.value.format("conn123")) + + def test_validate_connection_config_empty_connection_id(self): + config = { + "connection_id": "", + "connection_url": "https://example.com", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + with self.assertRaises(SkyflowError) as context: + validate_connection_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_ID.value) + + def test_validate_update_connection_config_valid(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + self.assertTrue(validate_update_connection_config(self.logger, config)) + + def test_validate_update_connection_config_missing_credentials(self): + config = { + "connection_id": "conn123", + "connection_url": "https://example.com" + } + with self.assertRaises(SkyflowError) as context: + validate_update_connection_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CREDENTIALS.value.format("connection", "conn123")) + + def test_validate_update_connection_config_empty_url(self): + config = { + "connection_id": "conn123", + "connection_url": "", + "credentials": { + "api_key": "sky-abc12-1234567890abcdef1234567890abcdef" + } + } + with self.assertRaises(SkyflowError) as context: + validate_update_connection_config(self.logger, config) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_CONNECTION_URL.value.format("conn123")) + + def test_validate_file_from_request_valid_file(self): + file_obj = MagicMock() + file_obj.name = "test.txt" + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = None + validate_file_from_request(file_input) + + def test_validate_file_from_request_valid_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = self.temp_file_path + validate_file_from_request(file_input) + + def test_validate_file_from_request_missing_both(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = None + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value) + + def test_validate_file_from_request_both_provided(self): + file_obj = MagicMock() + file_obj.name = "test.txt" + file_input = MagicMock() + file_input.file = file_obj + file_input.file_path = "/path/to/file" + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_INPUT.value) + + + def test_validate_file_from_request_invalid_file_path(self): + file_input = MagicMock() + file_input.file = None + file_input.file_path = "/nonexistent/path/to/file" + with self.assertRaises(SkyflowError) as context: + validate_file_from_request(file_input) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DEIDENTIFY_FILE_PATH.value) + + def test_validate_insert_request_valid(self): + request = MagicMock() + request.table = "test_table" + request.values = [{"field1": "value1"}] + request.upsert = None + request.homogeneous = None + request.token_mode = None + request.return_tokens = False + request.continue_on_error = False + request.tokens = None + validate_insert_request(self.logger, request) + + def test_validate_insert_request_invalid_table(self): + request = MagicMock() + request.table = 123 + request.values = [{"field1": "value1"}] + with self.assertRaises(SkyflowError) as context: + validate_insert_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_NAME_IN_INSERT.value) + + def test_validate_insert_request_empty_values(self): + request = MagicMock() + request.table = "test_table" + request.values = [] + with self.assertRaises(SkyflowError) as context: + validate_insert_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_DATA_IN_INSERT.value) + + + def test_validate_delete_request_valid(self): + request = MagicMock() + request.table = "test_table" + request.ids = ["id1", "id2"] + validate_delete_request(self.logger, request) + + def test_validate_delete_request_empty_table(self): + request = MagicMock() + request.table = "" + request.ids = ["id1"] + with self.assertRaises(SkyflowError) as context: + validate_delete_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_delete_request_missing_ids(self): + request = MagicMock() + request.table = "test_table" + request.ids = None + with self.assertRaises(SkyflowError) as context: + validate_delete_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_RECORD_IDS_IN_DELETE.value) + + def test_validate_query_request_valid(self): + request = MagicMock() + request.query = "SELECT * FROM test_table" + validate_query_request(self.logger, request) + + def test_validate_query_request_empty_query(self): + request = MagicMock() + request.query = "" + with self.assertRaises(SkyflowError) as context: + validate_query_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_QUERY.value) + + def test_validate_query_request_invalid_query_type(self): + request = MagicMock() + request.query = 123 + with self.assertRaises(SkyflowError) as context: + validate_query_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_TYPE.value.format(str(type(123)))) + + def test_validate_query_request_non_select_query(self): + request = MagicMock() + request.query = "INSERT INTO test_table VALUES (1)" + with self.assertRaises(SkyflowError) as context: + validate_query_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(request.query)) + + def test_validate_get_detect_run_request_valid(self): + request = MagicMock() + request.run_id = "test_run_123" + validate_get_detect_run_request(self.logger, request) + + def test_validate_get_detect_run_request_empty_run_id(self): + request = MagicMock() + request.run_id = "" + with self.assertRaises(SkyflowError) as context: + validate_get_detect_run_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_RUN_ID.value) + + def test_validate_get_detect_run_request_invalid_run_id_type(self): + request = MagicMock() + request.run_id = 123 # Invalid type + with self.assertRaises(SkyflowError) as context: + validate_get_detect_run_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_RUN_ID.value) + + def test_validate_get_request_valid(self): + from skyflow.utils.enums import RedactionType + request = MagicMock() + request.table = "test_table" + request.redaction_type = RedactionType.PLAIN_TEXT + request.column_name = None + request.column_values = None + request.ids = ["id1", "id2"] + request.fields = ["field1", "field2"] + request.offset = None + request.limit = None + request.download_url = False + request.return_tokens = False + validate_get_request(self.logger, request) + + + def test_validate_get_request_invalid_table_type(self): + request = MagicMock() + request.table = 123 + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_get_request_empty_table(self): + request = MagicMock() + request.table = "" + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TABLE_VALUE.value) + + def test_validate_get_request_invalid_redaction_type(self): + request = GetRequest( + table="test_table", + fields="invalid", + ids=["id1", "id2"], + redaction_type="invalid" + ) + + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(request.redaction_type))) + + def test_validate_get_request_invalid_fields_type(self): + request= GetRequest( + table="test_table", + fields="invalid" + ) + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(request.fields))) + + def test_validate_get_request_empty_fields(self): + request = GetRequest( + table="test_table", + ids=[], + fields=[] + ) + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_FIELDS_VALUE.value.format(type(request.fields))) + + def test_validate_get_request_invalid_column_values_type(self): + request = GetRequest( + table="test_table", + column_name="test_column", + column_values="invalid", + ) + + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_COLUMN_VALUE.value.format(type(request.column_values))) + + def test_validate_get_request_tokens_with_redaction(self): + request = GetRequest( + table="test_table", + return_tokens=True, + redaction_type = RedactionType.PLAIN_TEXT + ) + + with self.assertRaises(SkyflowError) as context: + validate_get_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.REDACTION_WITH_TOKENS_NOT_SUPPORTED.value) + + def test_validate_query_request_valid_complex(self): + request = MagicMock() + request.query = "SELECT * FROM table1 JOIN table2 ON table1.id = table2.id WHERE field = 'value'" + validate_query_request(self.logger, request) + + + def test_validate_query_request_invalid_update(self): + request = MagicMock() + request.query = "UPDATE table SET field = 'value'" # Only SELECT allowed + with self.assertRaises(SkyflowError) as context: + validate_query_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_QUERY_COMMAND.value.format(request.query)) + + def test_validate_update_request_valid(self): + request = MagicMock() + request.table = "test_table" + request.data = {"skyflow_id": "id123", "field1": "value1"} + request.return_tokens = False + request.token_mode = None + request.tokens = None + validate_update_request(self.logger, request) + + def test_validate_update_request_invalid_table_type(self): + request = UpdateRequest( + table=123, + data = {"skyflow_id": "id123"} + ) + with self.assertRaises(SkyflowError) as context: + validate_update_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TABLE_VALUE.value) + + def test_validate_update_request_invalid_token_mode(self): + request = UpdateRequest( + table="test_table", + data = {"skyflow_id": "id123", "field1": "value1"}, + token_mode = "invalid" + ) + with self.assertRaises(SkyflowError) as context: + validate_update_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value) + + def test_validate_detokenize_request_valid(self): + request = MagicMock() + request.data = [{"token": "token123"}] + request.continue_on_error = False + validate_detokenize_request(self.logger, request) + + def test_validate_detokenize_request_empty_data(self): + request = MagicMock() + request.data = [] # Empty list + request.continue_on_error = False + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value) + + def test_validate_detokenize_request_invalid_token(self): + request = MagicMock() + request.data = [{"token": 123}] # Invalid token type + request.continue_on_error = False + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE")) + + def test_validate_tokenize_request_valid(self): + request = MagicMock() + request.values = [{"value": "test", "column_group": "group1"}] + validate_tokenize_request(self.logger, request) + + + def test_validate_tokenize_request_invalid_values_type(self): + request = MagicMock() + request.values = "invalid" # Should be list + with self.assertRaises(SkyflowError) as context: + validate_tokenize_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETERS.value.format(type(request.values))) + + def test_validate_tokenize_request_empty_values(self): + request = MagicMock() + request.values = [] # Empty list + with self.assertRaises(SkyflowError) as context: + validate_tokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TOKENIZE_PARAMETERS.value) + + def test_validate_tokenize_request_missing_required_fields(self): + request = MagicMock() + request.values = [{"value": "test"}] # Missing column_group + with self.assertRaises(SkyflowError) as context: + validate_tokenize_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TOKENIZE_PARAMETER_KEY.value.format(0)) + + def test_validate_invoke_connection_params_valid(self): + query_params = {"param1": "value1"} + path_params = {"path1": "value1"} + validate_invoke_connection_params(self.logger, query_params, path_params) + + def test_validate_invoke_connection_params_invalid_path_params_type(self): + request = InvokeConnectionRequest( + method="GET", + query_params={"param1": "value1"}, + path_params="invalid" + ) + with self.assertRaises(SkyflowError) as context: + validate_invoke_connection_params(self.logger, request.query_params, request.path_params) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_PATH_PARAMS.value) + + def test_validate_invoke_connection_params_invalid_query_params_type(self): + request = InvokeConnectionRequest( + method="GET", + query_params="invalid", + path_params={"path1": "value1"} + ) + with self.assertRaises(SkyflowError) as context: + validate_invoke_connection_params(self.logger, request.query_params, request.path_params) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value) + + def test_validate_invoke_connection_params_non_string_path_param(self): + request = InvokeConnectionRequest( + method="GET", + query_params={"param1": "value1"}, + path_params={1: "value1"} + ) + with self.assertRaises(SkyflowError) as context: + validate_invoke_connection_params(self.logger, request.query_params, request.path_params) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_PATH_PARAMS.value) + + def test_validate_invoke_connection_params_non_string_query_param_key(self): + request = InvokeConnectionRequest( + method="GET", + query_params={1: "value1"}, + path_params={"path1": "value1"} + ) + with self.assertRaises(SkyflowError) as context: + validate_invoke_connection_params(self.logger, request.query_params, request.path_params) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value) + + def test_validate_invoke_connection_params_non_serializable_query_params(self): + class NonSerializable: + pass + request = InvokeConnectionRequest( + method="GET", + query_params={"param1": NonSerializable()}, + path_params={"path1": "value1"} + ) + with self.assertRaises(SkyflowError) as context: + validate_invoke_connection_params(self.logger, request.query_params, request.path_params) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_QUERY_PARAMS.value) + + def test_validate_deidentify_text_request_valid(self): + request = DeidentifyTextRequest( + text="test", + entities=None, + allow_regex_list=None, + restrict_regex_list = None, + token_format = None, + transformations = None, + ) + validate_deidentify_text_request(self.logger, request) + + def test_validate_reidentify_text_request_valid(self): + request = ReidentifyTextRequest( + text="test", + masked_entities=[DetectEntities.CREDIT_CARD], + redacted_entities=[DetectEntities.SSN], + plain_text_entities=None, + ) + validate_reidentify_text_request(self.logger, request) + + def test_validate_reidentify_text_request_empty_text(self): + request = ReidentifyTextRequest( + text="", + masked_entities=[DetectEntities.CREDIT_CARD], + redacted_entities=[DetectEntities.SSN], + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value) + + def test_validate_reidentify_text_request_invalid_redacted_entities(self): + request = ReidentifyTextRequest( + text="test", + redacted_entities="invalid", + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value) + + def test_validate_reidentify_text_request_invalid_plain_text_entities(self): + request = ReidentifyTextRequest( + text="test", + plain_text_entities="invalid", + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value) + + + def test_validate_deidentify_text_request_empty_text(self): + request = DeidentifyTextRequest( + text="", + entities=None, + allow_regex_list=None, + restrict_regex_list=None, + token_format=None, + transformations=None, + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value) + + def test_validate_deidentify_text_request_invalid_text_type(self): + request = DeidentifyTextRequest( + text=["test"], + entities=None, + allow_regex_list=None, + restrict_regex_list=None, + token_format=None, + transformations=None, + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TEXT_IN_DEIDENTIFY.value) + + def test_validate_deidentify_text_request_invalid_entities_type(self): + request = DeidentifyTextRequest( + text="test", + entities="invalid", + allow_regex_list=None, + restrict_regex_list=None, + token_format=None, + transformations=None, + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_ENTITIES_IN_DEIDENTIFY.value) + + def test_validate_deidentify_text_request_invalid_allow_regex(self): + request = DeidentifyTextRequest( + text="test", + allow_regex_list="invalid", + restrict_regex_list=None, + token_format=None, + transformations=None, + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value) + + def test_validate_deidentify_text_request_invalid_restrict_regex(self): + request = DeidentifyTextRequest( + text="test", + restrict_regex_list="invalid", + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value) + + def test_validate_deidentify_text_request_invalid_token_format(self): + request = DeidentifyTextRequest( + text="test", + token_format="invalid", + transformations=None, + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value) + + + def test_validate_reidentify_text_request_valid(self): + request = MagicMock() + request.text = "test text" + request.redacted_entities = None + request.masked_entities = None + request.plain_text_entities = None + validate_reidentify_text_request(self.logger, request) + + def test_validate_reidentify_text_request_empty_text(self): + request = MagicMock() + request.text = "" # Empty text + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value) + + def test_validate_reidentify_text_request_invalid_text_type(self): + request = MagicMock() + request.text = 123 # Invalid type + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_TEXT_IN_REIDENTIFY.value) + + def test_validate_reidentify_text_request_invalid_redacted_entities(self): + request = MagicMock() + request.text = "test text" + request.redacted_entities = "invalid" + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_REDACTED_ENTITIES_IN_REIDENTIFY.value) + + def test_validate_reidentify_text_request_invalid_plain_text_entities(self): + request = ReidentifyTextRequest( + text="test text", + plain_text_entities="invalid" + ) + with self.assertRaises(SkyflowError) as context: + validate_reidentify_text_request(self.logger, request) + self.assertEqual(context.exception.message, + SkyflowMessages.Error.INVALID_PLAIN_TEXT_ENTITIES_IN_REIDENTIFY.value) + + def test_validate_deidentify_file_request_valid(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + entities=None, + allow_regex_list=None, + restrict_regex_list=None, + token_format=None, + transformations=None, + output_processed_image=None, + output_ocr_text=None, + masking_method=None, + pixel_density=None, + max_resolution=None, + output_processed_audio=None, + output_transcription=None, + bleep=None, + output_directory=None, + wait_time=None + ) + validate_deidentify_file_request(self.logger, request) + + def test_validate_deidentify_file_request_missing_file(self): + request = DeidentifyFileRequest(file=None) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_FILE_INPUT.value) + + def test_validate_deidentify_file_request_invalid_entities(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + entities="invalid" + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_DETECT_ENTITIES_TYPE.value) + + def test_validate_deidentify_file_request_invalid_allow_regex(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + allow_regex_list="invalid", + entities=[DetectEntities.ACCOUNT_NUMBER] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_ALLOW_REGEX_LIST.value) + + def test_validate_deidentify_file_request_invalid_restrict_regex(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + restrict_regex_list="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_RESTRICT_REGEX_LIST.value) + + def test_validate_deidentify_file_request_invalid_token_format(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + token_format="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_FORMAT.value) + + def test_validate_deidentify_file_request_invalid_transformations(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + transformations="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TRANSFORMATIONS.value) + + def test_validate_deidentify_file_request_invalid_output_processed_image(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + output_processed_image="true", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_IMAGE.value) + + def test_validate_deidentify_file_request_invalid_output_ocr_text(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + output_ocr_text="true", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_OCR_TEXT.value) + + def test_validate_deidentify_file_request_invalid_masking_method(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + masking_method="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_MASKING_METHOD.value) + + def test_validate_deidentify_file_request_invalid_pixel_density(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + pixel_density="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_PIXEL_DENSITY.value) + + def test_validate_deidentify_file_request_invalid_max_resolution(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + max_resolution="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_MAXIMUM_RESOLUTION.value) + + def test_validate_deidentify_file_request_invalid_output_processed_audio(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + output_processed_audio="true", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_PROCESSED_AUDIO.value) + + def test_validate_deidentify_file_request_invalid_output_transcription(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + output_transcription="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_OUTPUT_TRANSCRIPTION.value) + + def test_validate_deidentify_file_request_invalid_wait_time(self): + file_input = FileInput(file_path=self.temp_file_path) + request = DeidentifyFileRequest( + file=file_input, + wait_time="invalid", + entities=[DetectEntities.SSN] + ) + with self.assertRaises(SkyflowError) as context: + validate_deidentify_file_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_WAIT_TIME.value) + + def test_validate_detokenize_request_valid(self): + request = DetokenizeRequest( + data=[{"token": "token123", "redaction": RedactionType.PLAIN_TEXT}], + continue_on_error=False + ) + validate_detokenize_request(self.logger, request) + + def test_validate_detokenize_request_empty_data(self): + request = DetokenizeRequest(data=[], continue_on_error=False) + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value) + + def test_validate_detokenize_request_invalid_token_type(self): + request = DetokenizeRequest(data=[{"token": 123}], continue_on_error=False) + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE")) + + def test_validate_detokenize_request_missing_token_key(self): + request = DetokenizeRequest(data=[{"not_token": "value"}], continue_on_error=False) + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value.format(str(type(request.data)))) + + def test_validate_detokenize_request_invalid_continue_on_error_type(self): + request = DetokenizeRequest(data=[{"token": "token123"}], continue_on_error="invalid") + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CONTINUE_ON_ERROR_TYPE.value) + + def test_validate_detokenize_request_invalid_redaction_type(self): + request = DetokenizeRequest(data=[{"token": "token123", "redaction": "invalid"}], continue_on_error=False) + with self.assertRaises(SkyflowError) as context: + validate_detokenize_request(self.logger, request) + self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(str(type("invalid")))) diff --git a/tests/vault/controller/test__detect.py b/tests/vault/controller/test__detect.py index 3096ce08..dc3a753f 100644 --- a/tests/vault/controller/test__detect.py +++ b/tests/vault/controller/test__detect.py @@ -159,7 +159,7 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], run_id="runid123", - status="SUCCESS", errors=None)) as mock_parse: + status="SUCCESS")) as mock_parse: result = self.detect.deidentify_file(req) mock_validate.assert_called_once() @@ -184,7 +184,6 @@ def test_deidentify_file_txt_success(self, mock_open, mock_basename, mock_base64 self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) - self.assertEqual(result.errors, None) @patch("skyflow.vault.controller._detect.validate_deidentify_file_request") @patch("skyflow.vault.controller._detect.base64") @@ -222,7 +221,7 @@ def test_deidentify_file_audio_success(self, mock_base64, mock_validate): word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=1, page_count=None, slide_count=None, entities=[], run_id="runid456", - status="SUCCESS", errors=None)) as mock_parse: + status="SUCCESS")) as mock_parse: result = self.detect.deidentify_file(req) mock_validate.assert_called_once() files_api.deidentify_audio.assert_called_once() @@ -263,8 +262,7 @@ def test_get_detect_run_success(self, mock_validate): return_value=DeidentifyFileResponse(file="file", type="txt", extension="txt", word_count=1, char_count=1, size_in_kb=1, duration_in_seconds=None, page_count=None, slide_count=None, entities=[], - run_id="runid789", status="SUCCESS", - errors=None)) as mock_parse: + run_id="runid789", status="SUCCESS")) as mock_parse: result = self.detect.get_detect_run(req) mock_validate.assert_called_once() files_api.get_run.assert_called_once() @@ -658,7 +656,11 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba # Setup processed response processed_response = Mock() processed_response.status = "SUCCESS" - processed_response.output = [] + processed_response.output = [ + Mock(processedFile="dGVzdCBjb250ZW", + processedFileType="txt", + processedFileExtension="txt") + ] processed_response.wordCharacterCount = Mock(wordCount=1, characterCount=1) # Test the method @@ -679,16 +681,14 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba entities=[], run_id="runid123", status="SUCCESS", - errors=None )) as mock_parse: result = self.detect.deidentify_file(req) mock_file.read.assert_called_once() - mock_basename.assert_called_with("/path/to/test.txt") - mock_validate.assert_called_once() files_api.deidentify_text.assert_called_once() + mock_basename.assert_called_with("/path/to/test.txt") mock_poll.assert_called_once() mock_parse.assert_called_once() @@ -710,4 +710,3 @@ def test_deidentify_file_using_file_path(self, mock_open, mock_basename, mock_ba self.assertIsNone(result.page_count) self.assertIsNone(result.slide_count) self.assertEqual(result.entities, []) - self.assertEqual(result.errors, None) diff --git a/tests/vault/controller/test__vault.py b/tests/vault/controller/test__vault.py index 0c8a7743..b1ac71e3 100644 --- a/tests/vault/controller/test__vault.py +++ b/tests/vault/controller/test__vault.py @@ -28,7 +28,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate) # Mock request request = InsertRequest( - table_name=TABLE_NAME, + table=TABLE_NAME, values=[{"field": "value"}], tokens=None, return_tokens=True, @@ -87,7 +87,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val # Mock request with continue_on_error set to False request = InsertRequest( - table_name=TABLE_NAME, + table=TABLE_NAME, values=[{"field": "value"}], tokens=None, return_tokens=True, @@ -127,7 +127,7 @@ def test_insert_with_continue_on_error_false(self, mock_parse_response, mock_val @patch("skyflow.vault.controller._vault.validate_insert_request") def test_insert_handles_generic_error(self, mock_validate): - request = InsertRequest(table_name="test_table", values=[{"column_name": "value"}], return_tokens=False, + request = InsertRequest(table="test_table", values=[{"column_name": "value"}], return_tokens=False, upsert=False, homogeneous=False, continue_on_error=False, token_mode=Mock()) records_api = self.vault_client.get_records_api.return_value @@ -145,7 +145,7 @@ def test_insert_with_continue_on_error_false_when_tokens_are_not_none(self, mock # Mock request with continue_on_error set to False request = InsertRequest( - table_name=TABLE_NAME, + table=TABLE_NAME, values=[{"field": "value"}], tokens=[{"token_field": "token_val1"}], return_tokens=True,