Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ DateTime~=5.5
PyJWT~=2.9.0
requests~=2.32.3
coverage
cryptography
cryptography
python-dotenv~=1.0.1
273 changes: 131 additions & 142 deletions skyflow/utils/_skyflow_messages.py

Large diffs are not rendered by default.

10 changes: 8 additions & 2 deletions skyflow/utils/_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import json
import urllib.parse
from dotenv import load_dotenv
from requests.sessions import PreparedRequest
from requests.models import HTTPError
import requests
Expand All @@ -22,6 +23,9 @@
invalid_input_error_code = SkyflowMessages.ErrorCodes.INVALID_INPUT.value

def get_credentials(config_level_creds = None, common_skyflow_creds = None, logger = None):
dotenv_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), ".env")
if dotenv_path:
load_dotenv(dotenv_path)
env_skyflow_credentials = os.getenv("SKYFLOW_CREDENTIALS")
if config_level_creds:
return config_level_creds
Expand All @@ -30,8 +34,10 @@ def get_credentials(config_level_creds = None, common_skyflow_creds = None, logg
if env_skyflow_credentials:
env_skyflow_credentials.strip()
try:
env_creds = json.loads(env_skyflow_credentials.replace('\n', '\\n'))
return env_creds
env_creds = env_skyflow_credentials.replace('\n', '\\n')
return {
'credentials_string': env_creds
}
except json.JSONDecodeError:
raise SkyflowError(SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value, invalid_input_error_code)
else:
Expand Down
2 changes: 1 addition & 1 deletion skyflow/utils/validations/_validations.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def validate_update_vault_config(logger, config):
return True

def validate_connection_config(logger, config):
log_info(SkyflowMessages.Info.VALIDATE_CONNECTION_CONFIG.value, logger)
log_info(SkyflowMessages.Info.VALIDATING_CONNECTION_CONFIG.value, logger)
validate_keys(logger, config, valid_connection_config_keys)

validate_required_field(
Expand Down
8 changes: 4 additions & 4 deletions skyflow/vault/client/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
from skyflow.generated.rest import Configuration, RecordsApi, ApiClient, TokensApi, QueryApi
from skyflow.service_account import generate_bearer_token, generate_bearer_token_from_creds, is_expired
from skyflow.utils import get_vault_url, get_credentials, SkyflowMessages
Expand All @@ -23,7 +24,7 @@ def set_logger(self, log_level, logger):
self.__logger = logger

def initialize_client_configuration(self):
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
credentials = get_credentials(self.__config.get("credentials"), self.__common_skyflow_credentials, logger = self.__logger)
token = self.get_bearer_token(credentials)
vault_url = get_vault_url(self.__config.get("cluster_id"),
self.__config.get("env"),
Expand Down Expand Up @@ -58,8 +59,6 @@ def get_bearer_token(self, credentials):
"ctx": self.__config.get("ctx")
}

log_info(SkyflowMessages.Info.GENERATE_BEARER_TOKEN_TRIGGERED, self.__logger)

if self.__bearer_token is None or self.__is_config_updated:
if 'path' in credentials:
path = credentials.get("path")
Expand All @@ -77,12 +76,13 @@ def get_bearer_token(self, credentials):
self.__logger
)
self.__is_config_updated = False
else:
log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)

if is_expired(self.__bearer_token):
self.__is_config_updated = True
raise SyntaxError(SkyflowMessages.Error.EXPIRED_TOKEN.value, SkyflowMessages.ErrorCodes.INVALID_INPUT.value)

log_info(SkyflowMessages.Info.REUSE_BEARER_TOKEN.value, self.__logger)
return self.__bearer_token

def update_config(self, config):
Expand Down
2 changes: 0 additions & 2 deletions skyflow/vault/controller/_vault.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from venv import logger

from skyflow.generated.rest import V1FieldRecords, RecordServiceInsertRecordBody, V1DetokenizeRecordRequest, \
V1DetokenizePayload, V1TokenizeRecordRequest, V1TokenizePayload, QueryServiceExecuteQueryBody, \
RecordServiceBulkDeleteRecordBody, RecordServiceUpdateRecordBody, RecordServiceBatchOperationBody, V1BatchRecord, \
Expand Down
4 changes: 3 additions & 1 deletion tests/constants/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
EMPTY_URL = ""
SCOPES_LIST = ["admin", "user", "viewer"]
FORMATTED_SCOPES = "role:admin role:user role:viewer"
INVALID_JSON_FORMAT = '{"invalid": json}'
INVALID_JSON_FORMAT = '[{"invalid": "json"}]'

TEST_ERROR_MESSAGE = "Test error message."

Expand All @@ -90,6 +90,8 @@
CREDENTIALS_WITH_PATH = {"path": "/path/to/creds.json"}
CREDENTIALS_WITH_STRING = {"credentials_string": "dummy_credentials_string"}

VALID_ENV_CREDENTIALS = {"clientID":"CLIENT_ID","clientName":"test_V2","tokenURI":"TOKEN_URI","keyID":"KEY_ID","privateKey":"PRIVATE_KEY","keyValidAfterTime":"2024-10-21T18:06:26.000Z","keyValidBeforeTime":"2025-10-21T18:06:26.000Z","keyAlgorithm":"KEY_ALG_RSA_2048"}


# connection controller constants

Expand Down
22 changes: 6 additions & 16 deletions tests/utils/test__utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,17 @@
from skyflow.vault.connection import InvokeConnectionResponse
from skyflow.vault.data import InsertResponse, DeleteResponse, GetResponse, QueryResponse
from skyflow.vault.tokens import DetokenizeResponse, TokenizeResponse
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE
from tests.constants.test_constants import VALID_CREDENTIALS_STRING, INVALID_JSON_FORMAT, TEST_ERROR_MESSAGE, \
VALID_ENV_CREDENTIALS


class TestUtils(unittest.TestCase):
# def test_get_credentials_empty_credentials(self):
# with self.assertRaises(SkyflowError) as context:
# get_credentials()
# self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_CREDENTIALS.value)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": VALID_CREDENTIALS_STRING})
@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": json.dumps(VALID_ENV_CREDENTIALS)})
def test_get_credentials_env_variable(self):
creds = get_credentials()
VALID_CREDENTIALS_STRING.strip()
print(type(creds))
self.assertEqual(creds, json.loads(VALID_CREDENTIALS_STRING.replace('\n', '\\n')))
credentials = get_credentials()
credentials_string = credentials.get('credentials_string')
self.assertEqual(credentials_string, json.dumps(VALID_ENV_CREDENTIALS).replace('\n', '\\n'))

def test_get_credentials_with_config_level_creds(self):
test_creds = {"authToken": "test_token"}
Expand All @@ -41,12 +37,6 @@ def test_get_credentials_with_common_creds(self):
creds = get_credentials(common_skyflow_creds=test_creds)
self.assertEqual(creds, test_creds)

@patch.dict(os.environ, {"SKYFLOW_CREDENTIALS": INVALID_JSON_FORMAT})
def test_get_credentials_invalid_json_format(self):
with self.assertRaises(SkyflowError) as context:
get_credentials()
self.assertIn(context.exception.message, SkyflowMessages.Error.INVALID_JSON_FORMAT_IN_CREDENTIALS_ENV.value)

def test_get_vault_url_valid(self):
valid_cluster_id = "testCluster"
valid_env = Env.DEV
Expand Down
2 changes: 1 addition & 1 deletion tests/vault/client/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_set_logger(self):
@patch("skyflow.vault.client.client.VaultClient.initialize_api_client")
def test_initialize_client_configuration(self, mock_init_api_client, mock_config, mock_get_vault_url,
mock_get_credentials):
mock_get_credentials.return_value = CREDENTIALS_WITH_API_KEY
mock_get_credentials.return_value = (CREDENTIALS_WITH_API_KEY)
mock_get_vault_url.return_value = "https://test-vault-url.com"

self.vault_client.initialize_client_configuration()
Expand Down
Loading