Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ jobs:
uses: ./.github/workflows/shared-tests.yml
with:
python-version: '3.8'
secrets: inherit
7 changes: 7 additions & 0 deletions .github/workflows/shared-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ jobs:
with:
python-version: ${{ inputs.python-version }}

- name: create-json
id: create-json
uses: jsdaniell/create-json@1.1.2
with:
name: "credentials.json"
json: ${{ secrets.VALID_SKYFLOW_CREDS_TEST }}

- name: 'Run Tests'
run: |
pip install -r requirements.txt
Expand Down
6 changes: 3 additions & 3 deletions skyflow/utils/_skyflow_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,11 @@ class Error(Enum):
EMPTY_DATA_IN_INSERT = f"{error_prefix} Validation error. Data array cannot be empty. Specify data in insert request."
INVALID_UPSERT_OPTIONS_TYPE = f"{error_prefix} Validation error. 'upsert' key cannot be empty in options. At least one object of table and column is required."
INVALID_HOMOGENEOUS_TYPE = f"{error_prefix} Validation error. Invalid type of homogeneous. Specify homogeneous as a string."
INVALID_TOKEN_STRICT_TYPE = f"{error_prefix} Validation error. Invalid type of token strict. Specify token strict as a enum."
INVALID_TOKEN_MODE_TYPE = f"{error_prefix} Validation error. Invalid type of token mode. Specify token mode as a TokenMode enum."
INVALID_RETURN_TOKENS_TYPE = f"{error_prefix} Validation error. Invalid type of return tokens. Specify return tokens as a boolean."
INVALID_CONTINUE_ON_ERROR_TYPE = f"{error_prefix} Validation error. Invalid type of continue on error. Specify continue on error as a boolean."
TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE = f"{error_prefix} Validation error. 'token_strict' wasn't specified. Set 'token_strict' to 'ENABLE' to insert tokens."
INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT = f"{error_prefix} Validation error. 'byot' is set to 'ENABLE_STRICT', but some fields are missing tokens. Specify tokens for all fields."
TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE = f"{error_prefix} Validation error. 'token_mode' wasn't specified. Set 'token_mode' to 'ENABLE' to insert tokens."
INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT = f"{error_prefix} Validation error. 'token_mode' is set to 'ENABLE_STRICT', but some fields are missing tokens. Specify tokens for all fields."
NO_TOKENS_IN_INSERT = f"{error_prefix} Validation error. Tokens weren't specified for records while 'token_strict' was {{}}. Specify tokens."
BATCH_INSERT_FAILURE = f"{error_prefix} Insert operation failed."
GET_FAILURE = f"{error_prefix} Get operation failed."
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/enums/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .env import Env, EnvUrls
from .log_level import LogLevel
from .content_types import ContentType
from .token_strict import TokenStrict
from .token_mode import TokenMode
from .method import Method
from .redaction_type import RedactionType
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from enum import Enum

from skyflow.generated.rest import V1BYOT


class TokenStrict(Enum):
class TokenMode(Enum):
DISABLE = V1BYOT.DISABLE
ENABLE = V1BYOT.ENABLE
ENABLE_STRICT = V1BYOT.ENABLE_STRICT
42 changes: 21 additions & 21 deletions skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import re
from skyflow.service_account import is_expired
from skyflow.utils.enums import LogLevel, TokenStrict, Env, RedactionType
from skyflow.utils.enums import LogLevel, Env, RedactionType, TokenMode
from skyflow.error import SkyflowError
from skyflow.utils import SkyflowMessages
from skyflow.utils.logger import log_info, log_error_log
Expand Down Expand Up @@ -286,9 +286,9 @@ def validate_insert_request(logger, request):
log_error_log(SkyflowMessages.ErrorLogs.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), logger = logger)
raise SkyflowError(SkyflowMessages.Error.HOMOGENOUS_NOT_SUPPORTED_WITH_UPSERT.value.format("INSERT"), invalid_input_error_code)

if request.token_strict is not None:
if not isinstance(request.token_strict, TokenStrict):
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code)
if request.token_mode is not None:
if not isinstance(request.token_mode, TokenMode):
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code)

if not isinstance(request.return_tokens, bool):
raise SkyflowError(SkyflowMessages.Error.INVALID_RETURN_TOKENS_TYPE.value, invalid_input_error_code)
Expand All @@ -311,21 +311,21 @@ def validate_insert_request(logger, request):
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value("INSERT"), logger = logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code)

if request.token_strict == TokenStrict.ENABLE and not request.tokens:
raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_strict), invalid_input_error_code)
if request.token_mode == TokenMode.ENABLE and not request.tokens:
raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_mode), invalid_input_error_code)

if request.token_strict == TokenStrict.DISABLE and request.tokens:
raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE.value, invalid_input_error_code)
if request.token_mode == TokenMode.DISABLE and request.tokens:
raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE.value, invalid_input_error_code)

if request.token_strict == TokenStrict.ENABLE_STRICT:
if request.token_mode == TokenMode.ENABLE_STRICT:
if len(request.values) != len(request.tokens):
log_error_log(SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("INSERT"), logger = logger)
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code)
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code)

for v, t in zip(request.values, request.tokens):
if set(v.keys()) != set(t.keys()):
log_error_log(SkyflowMessages.ErrorLogs.MISMATCH_OF_FIELDS_AND_TOKENS.value.format("INSERT"), logger=logger)
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value, invalid_input_error_code)
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value, invalid_input_error_code)

def validate_delete_request(logger, request):
if not isinstance(request.table, str):
Expand Down Expand Up @@ -467,36 +467,36 @@ def validate_update_request(logger, request):
if not len(request.data.items()):
raise SkyflowError(SkyflowMessages.Error.UPDATE_FIELD_KEY_ERROR.value, invalid_input_error_code)

if request.token_strict is not None:
if not isinstance(request.token_strict, TokenStrict):
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_STRICT_TYPE.value, invalid_input_error_code)
if request.token_mode is not None:
if not isinstance(request.token_mode, TokenMode):
raise SkyflowError(SkyflowMessages.Error.INVALID_TOKEN_MODE_TYPE.value, invalid_input_error_code)

if request.tokens:
if not isinstance(request.tokens, dict) or not request.tokens:
log_error_log(SkyflowMessages.ErrorLogs.EMPTY_TOKENS.value.format("UPDATE"), logger=logger)
raise SkyflowError(SkyflowMessages.Error.INVALID_TYPE_OF_DATA_IN_INSERT.value, invalid_input_error_code)

if request.token_strict == TokenStrict.ENABLE and not request.tokens:
raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_Strict),
if request.token_mode == TokenMode.ENABLE and not request.tokens:
raise SkyflowError(SkyflowMessages.Error.NO_TOKENS_IN_INSERT.value.format(request.token_mode),
invalid_input_error_code)

if request.token_strict == TokenStrict.DISABLE and request.tokens:
raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_STRICT_DISABLE.value, invalid_input_error_code)
if request.token_mode == TokenMode.DISABLE and request.tokens:
raise SkyflowError(SkyflowMessages.Error.TOKENS_PASSED_FOR_TOKEN_MODE_DISABLE.value, invalid_input_error_code)

if request.token_strict == TokenStrict.ENABLE_STRICT:
if request.token_mode == TokenMode.ENABLE_STRICT:
if len(field) != len(request.tokens):
log_error_log(
SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"),
logger=logger)
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value,
raise SkyflowError(SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value,
invalid_input_error_code)

if set(field.keys()) != set(request.tokens.keys()):
log_error_log(
SkyflowMessages.ErrorLogs.INSUFFICIENT_TOKENS_PASSED_FOR_BYOT_ENABLE_STRICT.value.format("UPDATE"),
logger=logger)
raise SkyflowError(
SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_STRICT_ENABLE_STRICT.value,
SkyflowMessages.Error.INSUFFICIENT_TOKENS_PASSED_FOR_TOKEN_MODE_ENABLE_STRICT.value,
invalid_input_error_code)

def validate_detokenize_request(logger, request):
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def __build_insert_body(self, request: InsertRequest):
body = RecordServiceBatchOperationBody(
records=records_list,
continue_on_error=request.continue_on_error,
byot=request.token_strict.value
byot=request.token_mode.value
)
return body
else:
Expand All @@ -64,7 +64,7 @@ def __build_insert_body(self, request: InsertRequest):
tokenization=request.return_tokens,
upsert=request.upsert,
homogeneous=request.homogeneous,
byot=request.token_strict.value
byot=request.token_mode.value
)

def insert(self, request: InsertRequest):
Expand Down Expand Up @@ -103,7 +103,7 @@ def update(self, request: UpdateRequest):
self.__initialize()
field = {key: value for key, value in request.data.items() if key != "skyflow_id"}
record = V1FieldRecords(fields=field, tokens = request.tokens)
payload = RecordServiceUpdateRecordBody(record=record, tokenization=request.return_tokens, byot=request.token_strict.value)
payload = RecordServiceUpdateRecordBody(record=record, tokenization=request.return_tokens, byot=request.token_mode.value)

records_api = self.__vault_client.get_records_api()
try:
Expand Down
6 changes: 3 additions & 3 deletions skyflow/vault/data/_insert_request.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from skyflow.utils.enums import TokenStrict
from skyflow.utils.enums import TokenMode

class InsertRequest:
def __init__(self,
Expand All @@ -7,15 +7,15 @@ def __init__(self,
tokens = None,
upsert = None,
homogeneous = False,
token_strict = TokenStrict.DISABLE,
token_mode = TokenMode.DISABLE,
return_tokens = True,
continue_on_error = False):
self.table_name = table_name
self.values = values
self.tokens = tokens
self.upsert = upsert
self.homogeneous = homogeneous
self.token_strict = token_strict
self.token_mode = token_mode
self.return_tokens = return_tokens
self.continue_on_error = continue_on_error

7 changes: 3 additions & 4 deletions skyflow/vault/data/_update_request.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from skyflow.utils.enums import TokenStrict

from skyflow.utils.enums import TokenMode

class UpdateRequest:
def __init__(self, table, data, tokens = None, return_tokens = False, token_strict = TokenStrict.DISABLE):
def __init__(self, table, data, tokens = None, return_tokens = False, token_mode = TokenMode.DISABLE):
self.table = table
self.data = data
self.tokens = tokens
self.return_tokens = return_tokens
self.token_strict = token_strict
self.token_mode = token_mode
12 changes: 9 additions & 3 deletions tests/constants/test_constants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from skyflow import LogLevel
from skyflow import Env
import os
import json

creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json")
with open(creds_path, 'r') as file:
credentials = json.load(file)

#client initialization constants

Expand Down Expand Up @@ -37,7 +43,9 @@

# service account constants

VALID_CREDENTIALS_STRING = '{"clientID":"b78eee76e91c43eda7a0e83f5c3a98e6","clientName":"test_V2","tokenURI":"https://manage.skyflowapis.dev/v1/auth/sa/oauth/token","keyID":"f927c615ca2b433294dcf45da0ba010c","privateKey":"-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC7iqpXXHjMuk5z\nh4PdOp6CxFr2zy6HCe2HKHzNvYcRk04jpjQgw/oRwXd8B5doMTmIzpxJ0K9sDBO4\nvYSdwjRhFnpnXWoVHKijtMUxWAuyZdB1mA/3hqeElpb6218aQeyGA6H98TTzb5G6\nJxn5lBr0qChm2o4gJHbYUO8PVvvm/ixDMrb87sH+yfCTYEWCcE9AozK3d1mST9F1\nSEnQEDML3mBTBqgLRn0NuEI273RpyAierY8KhQkiKg+0p3d2KkIrqgz05XlyKgw+\nV/ECymq2nH4vi3vGzSWMFSxiQ65fKZim+SPqIOLJZGemTOkfGv4SRWCCZ0WOLXO5\nsRcpsttpAgMBAAECggEBAKPcnsVCKNJInq9W4qJzy3fadNhdYfvfcsi7WYCybseu\ne4GugLF4SpElB285etMw32JnlCryybOQQdMS1EK7IuUJrN2Pw1a6+aZAFmPs2BuB\n1khJGvpdjxTMNxLshgX9P9pAZlPpYyiofR23eHyXKY5HNzXXFIOFGMocvSQcDnFe\neQom0mcd5EwVs5Zk4RDtLQlKdqByGgmMI/GRtdG8Of5jKhG1g3YomYAGIaFqCAEJ\nyUJEhfGMztpl8glLPECt2X09oUVrwwM7zOj3a0B46b1zmuLlcusIQHgMg9pNJBOR\nno4LAC8pMX5JEJjFRAYGsrntooHSAWR2n09GzDNkkxkCgYEA3FxZb2rOtcLQf8Dw\nC0UmcYWo3n0o4TVIjbSLPb0vKDIjkLmK42rNmMinD7BooSGPyPB/SNOm5NircDlv\nR5OVA4F3WUhMiDMPcZu/CJ5yWHNyR4y+erZh8NSbc4xwSfTdmMnclPkQBDu8N41N\n9KBwvTT8mqaIMw7NjhN0J1IM1O8CgYEA2d9/4+24GSMCltTQkGh7sALZyGhef/cR\nvL1cNMAeHzvrJVffp9rixtmVs20XCwbVH6AZXlHk5ALfSTo6XurBQyhPls6LC6ns\nNOoyviveo0fV0H8fj39wOWjZh3LQS/5CgxtBh6URMDVJfGv7NIAMOHBXVh8EyA8E\ndrks47VGRScCgYAxt1wuOQi+FV/5EsyVnlpYDnHVEKPie6UM44juuvoitX00r8fY\nG0abi9m1PnW8tNe93BS7l5T12LSFM1AZ9AAQtGr658bsi6iWVy84gJcHwbQs1GI9\nSVy7exw/a5YB+Y7tY82yhqbIbbm/RtApuvD0nznGon/kFRjnTxhLrsVaXQKBgCOH\nNbS2bCH1OpPcClKyJxFRta/fjSFy6bqMan/ToFXZkIPba4ZUxExG6QmETZCnwZNR\nqTFfS2L/MOghDamywGcyKKBf9/6j6/fJBRNL1hdsPGqugDgHQQarmWVkDKGHydLV\nW/9BpKbm2Z/nf+RUySle8G8DyeTRxhmSIsbTJa1bAoGBAMD6TXQg15dX9+hltpMh\n16IJB6Y15AA9KEiVKDyD+WF4V7BVIbsMmjFGoNBAF5/uwJk5UVKaGHjP5Dl12InR\n38wOrDi+uuOGlDfsiPJZ91reGdoXVNAfky9sRK1uBiRiskaWliP4hLdb1SsUzu5s\nHbxRby/7eC3gvCVA+6LqV9Fv\n-----END PRIVATE KEY-----\n","keyValidAfterTime":"2024-10-21T18:06:26.000Z","keyValidBeforeTime":"2025-10-21T18:06:26.000Z","keyAlgorithm":"KEY_ALG_RSA_2048"}'
VALID_CREDENTIALS_STRING = json.dumps(credentials)

VALID_SERVICE_ACCOUNT_CREDS = credentials

CREDENTIALS_WITHOUT_CLIENT_ID = {
'privateKey': 'private_key'
Expand All @@ -54,8 +62,6 @@
'keyID': 'key_id'
}

VALID_SERVICE_ACCOUNT_CREDS = {"clientID":"b78eee76e91c43eda7a0e83f5c3a98e6","clientName":"test_V2","tokenURI":"https://manage.skyflowapis.dev/v1/auth/sa/oauth/token","keyID":"f927c615ca2b433294dcf45da0ba010c","privateKey":"-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC7iqpXXHjMuk5z\nh4PdOp6CxFr2zy6HCe2HKHzNvYcRk04jpjQgw/oRwXd8B5doMTmIzpxJ0K9sDBO4\nvYSdwjRhFnpnXWoVHKijtMUxWAuyZdB1mA/3hqeElpb6218aQeyGA6H98TTzb5G6\nJxn5lBr0qChm2o4gJHbYUO8PVvvm/ixDMrb87sH+yfCTYEWCcE9AozK3d1mST9F1\nSEnQEDML3mBTBqgLRn0NuEI273RpyAierY8KhQkiKg+0p3d2KkIrqgz05XlyKgw+\nV/ECymq2nH4vi3vGzSWMFSxiQ65fKZim+SPqIOLJZGemTOkfGv4SRWCCZ0WOLXO5\nsRcpsttpAgMBAAECggEBAKPcnsVCKNJInq9W4qJzy3fadNhdYfvfcsi7WYCybseu\ne4GugLF4SpElB285etMw32JnlCryybOQQdMS1EK7IuUJrN2Pw1a6+aZAFmPs2BuB\n1khJGvpdjxTMNxLshgX9P9pAZlPpYyiofR23eHyXKY5HNzXXFIOFGMocvSQcDnFe\neQom0mcd5EwVs5Zk4RDtLQlKdqByGgmMI/GRtdG8Of5jKhG1g3YomYAGIaFqCAEJ\nyUJEhfGMztpl8glLPECt2X09oUVrwwM7zOj3a0B46b1zmuLlcusIQHgMg9pNJBOR\nno4LAC8pMX5JEJjFRAYGsrntooHSAWR2n09GzDNkkxkCgYEA3FxZb2rOtcLQf8Dw\nC0UmcYWo3n0o4TVIjbSLPb0vKDIjkLmK42rNmMinD7BooSGPyPB/SNOm5NircDlv\nR5OVA4F3WUhMiDMPcZu/CJ5yWHNyR4y+erZh8NSbc4xwSfTdmMnclPkQBDu8N41N\n9KBwvTT8mqaIMw7NjhN0J1IM1O8CgYEA2d9/4+24GSMCltTQkGh7sALZyGhef/cR\nvL1cNMAeHzvrJVffp9rixtmVs20XCwbVH6AZXlHk5ALfSTo6XurBQyhPls6LC6ns\nNOoyviveo0fV0H8fj39wOWjZh3LQS/5CgxtBh6URMDVJfGv7NIAMOHBXVh8EyA8E\ndrks47VGRScCgYAxt1wuOQi+FV/5EsyVnlpYDnHVEKPie6UM44juuvoitX00r8fY\nG0abi9m1PnW8tNe93BS7l5T12LSFM1AZ9AAQtGr658bsi6iWVy84gJcHwbQs1GI9\nSVy7exw/a5YB+Y7tY82yhqbIbbm/RtApuvD0nznGon/kFRjnTxhLrsVaXQKBgCOH\nNbS2bCH1OpPcClKyJxFRta/fjSFy6bqMan/ToFXZkIPba4ZUxExG6QmETZCnwZNR\nqTFfS2L/MOghDamywGcyKKBf9/6j6/fJBRNL1hdsPGqugDgHQQarmWVkDKGHydLV\nW/9BpKbm2Z/nf+RUySle8G8DyeTRxhmSIsbTJa1bAoGBAMD6TXQg15dX9+hltpMh\n16IJB6Y15AA9KEiVKDyD+WF4V7BVIbsMmjFGoNBAF5/uwJk5UVKaGHjP5Dl12InR\n38wOrDi+uuOGlDfsiPJZ91reGdoXVNAfky9sRK1uBiRiskaWliP4hLdb1SsUzu5s\nHbxRby/7eC3gvCVA+6LqV9Fv\n-----END PRIVATE KEY-----\n","keyValidAfterTime":"2024-10-21T18:06:26.000Z","keyValidBeforeTime":"2025-10-21T18:06:26.000Z","keyAlgorithm":"KEY_ALG_RSA_2048"}

# utils constants

VALID_URL = "https://example.com/path?query=1"
Expand Down
12 changes: 3 additions & 9 deletions tests/service_account/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,14 @@ def test_generate_bearer_token_invalid_file_path(self, mock_open):

@patch("json.load", side_effect=json.JSONDecodeError("Expecting value", "", 0))
def test_generate_bearer_token_invalid_json(self, mock_json_load):
creds_path = os.path.join(os.path.dirname(__file__), "valid_credentials.json")
creds_path = os.path.join(os.path.dirname(__file__), "invalid_creds.json")
with self.assertRaises(SkyflowError) as context:
generate_bearer_token(creds_path)
self.assertEqual(context.exception.message, SkyflowMessages.Error.FILE_INVALID_JSON.value.format(creds_path))

@patch("skyflow.service_account._utils.get_service_account_token")
def test_generate_bearer_token_valid_file_path(self, mock_generate_bearer_token):
creds_path = os.path.join(os.path.dirname(__file__), "valid_credentials.json")
creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json")
generate_bearer_token(creds_path)
mock_generate_bearer_token.assert_called_once()

Expand Down Expand Up @@ -105,7 +105,7 @@ def test_get_signed_data_token_response_object(self):
self.assertEqual(response[1], signed_token)

def test_generate_signed_data_tokens_from_file_path(self):
creds_path = os.path.join(os.path.dirname(__file__), "valid_credentials.json")
creds_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), "credentials.json")
options = {"data_tokens": ["token1", "token2"], "ctx": 'ctx'}
result = generate_signed_data_tokens(creds_path, options)
self.assertEqual(len(result), 2)
Expand All @@ -116,12 +116,6 @@ def test_generate_signed_data_tokens_from_invalid_file_path(self):
result = generate_signed_data_tokens('credentials1.json', options)
self.assertEqual(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIAL_FILE_PATH.value)

# def test_generate_signed_data_tokens_from_valid_file_path_with_invalid_credentials(self):
# options = {"data_tokens": ["token1", "token2"]}
# with self.assertRaises(SkyflowError) as context:
# result = generate_signed_data_tokens("invalid_creds.json", options)
# self.assertEqual(context.exception.message, SkyflowMessages.Error.FILE_INVALID_JSON.value.format("invalid_creds.json"))

def test_generate_signed_data_tokens_from_creds(self):
options = {"data_tokens": ["token1", "token2"]}
result = generate_signed_data_tokens_from_creds(VALID_CREDENTIALS_STRING, options)
Expand Down
Loading
Loading