Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions skyflow/generated/rest/api/tokens_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def record_service_detokenize_with_http_info(

_response_types_map: Dict[str, Optional[str]] = {
'200': "V1DetokenizeResponse",
'207': "V1DetokenizeResponse",
'404': "object",
}
response_data = self.api_client.call_api(
Expand Down
11 changes: 7 additions & 4 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from urllib.parse import quote
from skyflow.error import SkyflowError
from skyflow.generated.rest import V1UpdateRecordResponse, V1BulkDeleteRecordResponse, \
V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse
V1DetokenizeResponse, V1TokenizeResponse, V1GetQueryResponse, V1BulkGetRecordResponse, ApiResponse
from skyflow.utils.logger import log_error, log_error_log
from . import SkyflowMessages, SDK_VERSION
from .enums import Env, ContentType, EnvUrls
Expand Down Expand Up @@ -195,7 +195,8 @@ def parse_insert_response(api_response, continue_on_error):
errors = []
insert_response = InsertResponse()
if continue_on_error:
for idx, response in enumerate(api_response.responses):
response_data = json.loads(api_response.raw_data.decode('utf-8'))
for idx, response in enumerate(response_data.get('responses', [])):
if response['Status'] == 200:
body = response['Body']
if 'records' in body:
Expand All @@ -210,6 +211,7 @@ def parse_insert_response(api_response, continue_on_error):
inserted_fields.append(inserted_field)
elif response['Status'] == 400:
error = {
'request_id': api_response.headers.get('x-request-id'),
'request_index': idx,
'error': response['Body']['error']
}
Expand Down Expand Up @@ -264,13 +266,14 @@ def parse_get_response(api_response: V1BulkGetRecordResponse):

return get_response

def parse_detokenize_response(api_response: V1DetokenizeResponse):
def parse_detokenize_response(api_response: ApiResponse[V1DetokenizeResponse]):
detokenized_fields = []
errors = []

for record in api_response.records:
for record in api_response.data.records:
if record.error:
errors.append({
"request_id": api_response.headers.get('x-request-id'),
"token": record.token,
"error": record.error
})
Expand Down
8 changes: 4 additions & 4 deletions skyflow/utils/enums/env.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from enum import Enum

class Env(Enum):
DEV = 'DEV',
SANDBOX = 'SANDBOX',
DEV = 'DEV'
SANDBOX = 'SANDBOX'
PROD = 'PROD'
STAGE = 'STAGE'

class EnvUrls(Enum):
PROD = "vault.skyflowapis.com",
SANDBOX = "vault.skyflowapis-preview.com",
PROD = "vault.skyflowapis.com"
SANDBOX = "vault.skyflowapis-preview.com"
DEV = "vault.skyflowapis.dev"
STAGE = "vault.skyflowapis.tech"
6 changes: 3 additions & 3 deletions skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,15 +514,15 @@ def validate_detokenize_request(logger, request):
raise SkyflowError(SkyflowMessages.Error.EMPTY_TOKENS_LIST_VALUE.value, invalid_input_error_code)

for item in request.data:
if 'token' not in item or 'redaction' not in item:
if 'token' not in item:
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKENS_LIST_VALUE.value(type(request.data)), invalid_input_error_code)
token = item.get('token')
redaction = item.get('redaction')
redaction = item.get('redaction', RedactionType.PLAIN_TEXT)

if not isinstance(token, str) or not token:
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_TYPE.value.format("DETOKENIZE"), invalid_input_error_code)

if not isinstance(redaction, RedactionType) or not redaction:
if redaction is not None and not isinstance(redaction, RedactionType):
raise SkyflowError(SkyflowMessages.Error.INVALID_REDACTION_TYPE.value.format(type(redaction)), invalid_input_error_code)

def validate_tokenize_request(logger, request):
Expand Down
10 changes: 7 additions & 3 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skyflow.utils import SkyflowMessages, parse_insert_response, \
handle_exception, parse_update_record_response, parse_delete_response, parse_detokenize_response, \
parse_tokenize_response, parse_query_response, parse_get_response, encode_column_values
from skyflow.utils.enums import RedactionType
from skyflow.utils.logger import log_info, log_error_log
from skyflow.utils.validations import validate_insert_request, validate_delete_request, validate_query_request, \
validate_get_request, validate_update_request, validate_detokenize_request, validate_tokenize_request
Expand Down Expand Up @@ -89,7 +90,7 @@ def insert(self, request: InsertRequest):
log_info(SkyflowMessages.Info.INSERT_TRIGGERED.value, self.__vault_client.get_logger())

if request.continue_on_error:
api_response = records_api.record_service_batch_operation(self.__vault_client.get_vault_id(),
api_response = records_api.record_service_batch_operation_with_http_info(self.__vault_client.get_vault_id(),
insert_body)

else:
Expand Down Expand Up @@ -230,14 +231,17 @@ def detokenize(self, request: DetokenizeRequest):
log_info(SkyflowMessages.Info.DETOKENIZE_REQUEST_RESOLVED.value, self.__vault_client.get_logger())
self.__initialize()
tokens_list = [
V1DetokenizeRecordRequest(token=item.get('token'), redaction=item.get('redaction').value)
V1DetokenizeRecordRequest(
token=item.get('token'),
redaction=item.get('redaction').value if item.get('redaction') else RedactionType.PLAIN_TEXT.value
)
for item in request.data
]
payload = V1DetokenizePayload(detokenization_parameters=tokens_list, continue_on_error=request.continue_on_error)
tokens_api = self.__vault_client.get_tokens_api()
try:
log_info(SkyflowMessages.Info.DETOKENIZE_TRIGGERED.value, self.__vault_client.get_logger())
api_response = tokens_api.record_service_detokenize(
api_response = tokens_api.record_service_detokenize_with_http_info(
self.__vault_client.get_vault_id(),
detokenize_payload=payload
)
Expand Down
24 changes: 18 additions & 6 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,13 +183,23 @@ def test_construct_invoke_connection_request_with_form_date_content_type(self):

def test_parse_insert_response(self):
api_response = Mock()
api_response.responses = [
{"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}},
{"Status": 400, "Body": {"error": TEST_ERROR_MESSAGE}}
]

api_response.raw_data = json.dumps({
"responses": [
{"Status": 200, "Body": {"records": [{"skyflow_id": "id1"}]}},
{"Status": 400, "Body": {"error": "TEST_ERROR_MESSAGE"}}
]
}).encode('utf-8')

api_response.headers = {"x-request-id": "test-request-id"}

result = parse_insert_response(api_response, continue_on_error=True)

self.assertEqual(len(result.inserted_fields), 1)
self.assertEqual(len(result.errors), 1)
self.assertEqual(result.inserted_fields[0]['skyflow_id'], "id1")
self.assertEqual(result.errors[0]['error'], "TEST_ERROR_MESSAGE")
self.assertEqual(result.errors[0]['request_id'], "test-request-id")

def test_parse_insert_response_continue_on_error_false(self):
mock_api_response = Mock()
Expand Down Expand Up @@ -252,11 +262,13 @@ def test_parse_get_response_successful(self):

def test_parse_detokenize_response_with_mixed_records(self):
mock_api_response = Mock()
mock_api_response.records = [
mock_api_response.data = Mock() # Ensure `data` exists
mock_api_response.data.records = [
Mock(token="token1", value="value1", value_type=Mock(value="Type1"), error=None),
Mock(token="token2", value=None, value_type=None, error="Some error"),
Mock(token="token3", value="value3", value_type=Mock(value="Type2"), error=None),
]
mock_api_response.headers = {"x-request-id": "test-request-id"} # Mock headers

result = parse_detokenize_response(mock_api_response)
self.assertIsInstance(result, DetokenizeResponse)
Expand All @@ -267,7 +279,7 @@ def test_parse_detokenize_response_with_mixed_records(self):
]

expected_errors = [
{"token": "token2", "error": "Some error"}
{"request_id": "test-request-id", "token": "token2", "error": "Some error"}
]

self.assertEqual(result.detokenized_fields, expected_detokenized_fields)
Expand Down
21 changes: 11 additions & 10 deletions tests/vault/controller/test__vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from unittest.mock import Mock, patch
from skyflow.generated.rest import RecordServiceBatchOperationBody, V1BatchRecord, RecordServiceInsertRecordBody, \
V1FieldRecords, RecordServiceUpdateRecordBody, RecordServiceBulkDeleteRecordBody, QueryServiceExecuteQueryBody, \
V1DetokenizeRecordRequest, V1DetokenizePayload, V1TokenizePayload, V1TokenizeRecordRequest, RedactionEnumREDACTION
V1DetokenizeRecordRequest, V1DetokenizePayload, V1TokenizePayload, V1TokenizeRecordRequest, RedactionEnumREDACTION, \
BatchRecordMethod
from skyflow.utils.enums import RedactionType, TokenMode
from skyflow.vault.controller import Vault
from skyflow.vault.data import InsertRequest, InsertResponse, UpdateResponse, UpdateRequest, DeleteResponse, \
Expand Down Expand Up @@ -43,7 +44,7 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)
V1BatchRecord(
fields={"field": "value"},
table_name=TABLE_NAME,
method="POST",
method=BatchRecordMethod.POST,
tokenization=True,
upsert="column_name"
)
Expand Down Expand Up @@ -71,14 +72,14 @@ def test_insert_with_continue_on_error(self, mock_parse_response, mock_validate)
# Set the return value for the parse response
mock_parse_response.return_value = expected_response
records_api = self.vault_client.get_records_api.return_value
records_api.record_service_batch_operation.return_value = mock_api_response
records_api.record_service_batch_operation_with_http_info.return_value = mock_api_response

# Call the insert function
result = self.vault.insert(request)

# Assertions
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
records_api.record_service_batch_operation.assert_called_once_with(VAULT_ID, expected_body)
records_api.record_service_batch_operation_with_http_info.assert_called_once_with(VAULT_ID, expected_body)
mock_parse_response.assert_called_once_with(mock_api_response, True)

# Assert that the result matches the expected InsertResponse
Expand Down Expand Up @@ -481,28 +482,28 @@ def test_detokenize_successful(self, mock_parse_response, mock_validate):
# Mock API response
mock_api_response = Mock()
mock_api_response.records = [
Mock(token="token1", value="value1", value_type=Mock(value="STRING"), error=None),
Mock(token="token2", value="value2", value_type=Mock(value="STRING"), error=None)
Mock(skyflow_id="id_1", token="token1", value="value1", value_type=Mock(value="STRING"), error=None),
Mock(skyflow_id="id_2", token="token2", value="value2", value_type=Mock(value="STRING"), error=None)
]

# Expected parsed response
expected_fields = [
{"token": "token1", "value": "value1", "type": "STRING"},
{"token": "token2", "value": "value2", "type": "STRING"}
{"skyflow_id": "id_1", "token": "token1", "value": "value1", "type": "STRING"},
{"skyflow_id": "id_2", "token": "token2", "value": "value2", "type": "STRING"}
]
expected_response = DetokenizeResponse(detokenized_fields=expected_fields, errors=[])

# Set the return value for parse_detokenize_response
mock_parse_response.return_value = expected_response
tokens_api = self.vault_client.get_tokens_api.return_value
tokens_api.record_service_detokenize.return_value = mock_api_response
tokens_api.record_service_detokenize_with_http_info.return_value = mock_api_response

# Call the detokenize function
result = self.vault.detokenize(request)

# Assertions
mock_validate.assert_called_once_with(self.vault_client.get_logger(), request)
tokens_api.record_service_detokenize.assert_called_once_with(
tokens_api.record_service_detokenize_with_http_info.assert_called_once_with(
VAULT_ID,
detokenize_payload=expected_payload
)
Expand Down
Loading