diff --git a/hyperleaup/creator.py b/hyperleaup/creator.py
index 8e333d2..a1c6744 100644
--- a/hyperleaup/creator.py
+++ b/hyperleaup/creator.py
@@ -1,7 +1,7 @@
 import os
 import logging
 from shutil import copyfile
-from typing import List, Any
+from typing import List, Any, Mapping
 from hyperleaup.creation_mode import CreationMode
 from pyspark.sql import DataFrame
 from pyspark.sql.types import *
@@ -93,14 +93,15 @@ def get_table_def(df: DataFrame, schema_name: str, table_name: str) -> TableDefi
     )
 
 
-def insert_data_into_hyper_file(data: List[Any], name: str, table_def: TableDefinition):
+def insert_data_into_hyper_file(data: List[Any], name: str, table_def: TableDefinition,
+                                hyper_process_parameters: Mapping[str, str] = None):
     """Helper function that inserts data into a .hyper file."""
     # first, create a temp directory on the driver node
     tmp_dir = f"/tmp/hyperleaup/{name}/"
     if not os.path.exists(tmp_dir):
         os.makedirs(tmp_dir)
     hyper_database_path = f"/tmp/hyperleaup/{name}/{name}.hyper"
-    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, parameters=hyper_process_parameters) as hp:
         with Connection(endpoint=hp.endpoint,
                         database=hyper_database_path,
                         create_mode=CreateMode.CREATE_AND_REPLACE) as connection:
@@ -113,10 +114,11 @@ def insert_data_into_hyper_file(data: List[Any], name: str, table_def: TableDefi
     return hyper_database_path
 
 
-def copy_data_into_hyper_file(csv_path: str, name: str, table_def: TableDefinition) -> str:
+def copy_data_into_hyper_file(csv_path: str, name: str, table_def: TableDefinition,
+                              hyper_process_parameters: Mapping[str, str] = None) -> str:
     """Helper function that copies data from a CSV file to a .hyper file."""
     hyper_database_path = f"/tmp/hyperleaup/{name}/{name}.hyper"
-    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, parameters=hyper_process_parameters) as hp:
         with Connection(endpoint=hp.endpoint,
                         database=Path(hyper_database_path),
                         create_mode=CreateMode.CREATE_AND_REPLACE) as connection:
@@ -132,11 +134,12 @@ def copy_data_into_hyper_file(csv_path: str, name: str, table_def: TableDefiniti
     return hyper_database_path
 
 
-def copy_parquet_to_hyper_file(parquet_path: str, name: str, table_def: TableDefinition) -> str:
+def copy_parquet_to_hyper_file(parquet_path: str, name: str, table_def: TableDefinition,
+                               hyper_process_parameters: Mapping[str, str] = None) -> str:
     """Helper function that copies data from a Parquet file to a .hyper file."""
     hyper_database_path = f"/tmp/hyperleaup/{name}/{name}.hyper"
-    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+    with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU, parameters=hyper_process_parameters) as hp:
         with Connection(endpoint=hp.endpoint,
                         database=Path(hyper_database_path),
                         create_mode=CreateMode.CREATE_AND_REPLACE) as connection:
@@ -258,7 +261,8 @@ class Creator:
     def __init__(self, df: DataFrame, name: str,
                  is_dbfs_enabled: bool = False,
                  creation_mode: str = CreationMode.COPY.value,
-                 null_values_replacement = None):
+                 null_values_replacement = None,
+                 hyper_process_parameters: Mapping[str, str] = None):
         if null_values_replacement is None:
             null_values_replacement = {}
         self.df = df
@@ -266,6 +270,7 @@
         self.is_dbfs_enabled = is_dbfs_enabled
         self.creation_mode = creation_mode
         self.null_values_replacement = null_values_replacement
+        self.hyper_process_parameters = hyper_process_parameters
 
     def create(self) -> str:
         """Creates a Tableau Hyper File given a SQL statement"""
@@ -285,7 +290,7 @@ def create(self) -> str:
 
             # COPY data into a Tableau .hyper file
             logging.info("Copying data into Hyper File...")
-            database_path = copy_data_into_hyper_file(csv_path, self.name, table_def)
+            database_path = copy_data_into_hyper_file(csv_path, self.name, table_def, self.hyper_process_parameters)
 
         elif self.creation_mode.upper() == CreationMode.INSERT.value:
@@ -299,7 +304,7 @@
 
             # Insert data into a Tableau .hyper file
             logging.info("Inserting data into Hyper File...")
-            database_path = insert_data_into_hyper_file(data, self.name, table_def)
+            database_path = insert_data_into_hyper_file(data, self.name, table_def, self.hyper_process_parameters)
 
         elif self.creation_mode.upper() == CreationMode.PARQUET.value:
@@ -317,7 +322,8 @@
 
             # COPY data into a Tableau .hyper file
             logging.info("Copying data into Hyper File...")
-            database_path = copy_parquet_to_hyper_file(parquet_path, self.name, table_def)
+            database_path = copy_parquet_to_hyper_file(parquet_path, self.name, table_def,
+                                                       self.hyper_process_parameters)
 
         else:
             raise ValueError(f'Invalid "creation_mode" specified: {self.creation_mode}')
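Each `creator.py` helper now threads an optional `hyper_process_parameters` mapping straight through to `HyperProcess`; when left as `None` (the default), Hyper keeps its stock behaviour. A minimal sketch of what this enables — `log_dir` and `log_file_size_limit` are documented Hyper process settings, but the directory and size below are purely illustrative:

```python
import os
from tableauhyperapi import HyperProcess, Telemetry

# Illustrative location for the redirected hyperd.log
log_dir = "/tmp/hyper_logs"
os.makedirs(log_dir, exist_ok=True)

params = {
    "log_dir": log_dir,              # write hyperd.log here instead of the default location
    "log_file_size_limit": "100M",   # rotate the log once it reaches ~100 MB
}

# parameters=None (the default) would leave Hyper's behaviour unchanged.
with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
                  parameters=params) as hp:
    print(f"Hyper endpoint: {hp.endpoint}")
```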
diff --git a/hyperleaup/hyper_file.py b/hyperleaup/hyper_file.py
index efdf327..1f0e6c4 100644
--- a/hyperleaup/hyper_file.py
+++ b/hyperleaup/hyper_file.py
@@ -1,6 +1,7 @@
 import os
 import logging
 from shutil import copyfile
+from typing import Mapping
 from pyspark.sql import DataFrame
 from tableauhyperapi import HyperProcess, Telemetry, Connection, CreateMode, Inserter
@@ -22,7 +23,8 @@ def __init__(self, name: str,
                  sql: str = None, df: DataFrame = None,
                  is_dbfs_enabled: bool = False,
                  creation_mode: str = CreationMode.PARQUET.value,
-                 null_values_replacement: dict = None):
+                 null_values_replacement: dict = None,
+                 hyper_process_parameters: Mapping[str, str] = None):
         self.name = name
         # Create a DataFrame from Spark SQL
         if sql is not None and df is None:
@@ -33,6 +35,7 @@ def __init__(self, name: str,
         self.creation_mode = creation_mode
         self.is_dbfs_enabled = is_dbfs_enabled
         self.null_values_replacement = null_values_replacement
+        self.hyper_process_parameters = hyper_process_parameters
         # Do not create a Hyper File if loading an existing Hyper File
         if sql is None and df is None:
             self.path = None
@@ -41,12 +44,14 @@ def __init__(self, name: str,
                                self.name,
                                self.is_dbfs_enabled,
                                self.creation_mode,
-                               self.null_values_replacement).create()
+                               self.null_values_replacement,
+                               self.hyper_process_parameters).create()
         self.luid = None
 
     def print_rows(self):
         """Prints the first 1,000 rows of a Hyper file"""
-        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
+                          parameters=self.hyper_process_parameters) as hp:
             with Connection(endpoint=hp.endpoint, database=self.path) as connection:
                 rows = connection.execute_list_query(f"SELECT * FROM {TableName('Extract', 'Extract')} LIMIT 1000")
                 print("Showing first 1,000 rows")
@@ -55,7 +60,8 @@ def print_rows(self):
 
     def print_table_def(self, schema: str = "Extract", table: str = "Extract"):
         """Prints the table definition for a table in a Hyper file."""
-        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
+                          parameters=self.hyper_process_parameters) as hp:
             with Connection(endpoint=hp.endpoint, database=self.path) as connection:
                 table_name = TableName(schema, table)
                 table_definition = connection.catalog.get_table_definition(name=table_name)
@@ -108,7 +114,7 @@ def save(self, path: str) -> str:
         return copyfile(self.path, dest_path)
 
     @staticmethod
-    def load(path: str, is_dbfs_enabled: bool = False):
+    def load(path: str, is_dbfs_enabled: bool = False, hyper_process_parameters: Mapping[str, str] = None):
         """Loads a Hyper File from a source path to a temp dir"""
         # Guard against invalid paths
         if path.lower().startswith("s3"):
@@ -151,7 +157,7 @@ def load(path: str, is_dbfs_enabled: bool = False):
             hyper_file_path = path
 
         # Create a HyperFile object with existing Hyper File path
-        hf = HyperFile(name=name, is_dbfs_enabled=is_dbfs_enabled)
+        hf = HyperFile(name=name, is_dbfs_enabled=is_dbfs_enabled, hyper_process_parameters=hyper_process_parameters)
         hf.path = hyper_file_path
         return hf
 
@@ -174,7 +180,8 @@ def append(self, sql: str = None, df: DataFrame = None):
         # Insert, the new data into Hyper File
         hyper_database_path = self.path
         logging.info(f'Inserting new data into Hyper database: {hyper_database_path}')
-        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU) as hp:
+        with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
+                          parameters=self.hyper_process_parameters) as hp:
             with Connection(endpoint=hp.endpoint,
                             database=hyper_database_path,
                             create_mode=CreateMode.NONE) as connection:
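On the public API side, `HyperFile` accepts the same mapping in its constructor and in `load`, and reuses it for every method that spins up its own `HyperProcess` (`print_rows`, `print_table_def`, `append`). A usage sketch, assuming a running Spark session and an existing `employees_df` DataFrame (both hypothetical here); the paths are illustrative:

```python
from hyperleaup import HyperFile

# `employees_df` is assumed to be an existing Spark DataFrame.
hf = HyperFile(name="employees",
               df=employees_df,
               creation_mode="parquet",
               hyper_process_parameters={"log_dir": "/tmp/hyper_logs"})

# Inspection methods reuse the same process settings.
hf.print_table_def()

# Loading an existing Hyper File can carry the settings as well.
hf2 = HyperFile.load("/tmp/save/employees.hyper",
                     hyper_process_parameters={"log_dir": "/tmp/hyper_logs"})
```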
diff --git a/requirements.txt b/requirements.txt
index a91de78..41c6808 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
 tableauserverclient==0.17.0
 pyspark==3.1.3
 requests==2.26.0
-tableauhyperapi==0.0.13129
+tableauhyperapi==0.0.16638
 urllib3==1.26.6
diff --git a/tests/test_hyper_file.py b/tests/test_hyper_file.py
index 1499b7a..0eeb617 100644
--- a/tests/test_hyper_file.py
+++ b/tests/test_hyper_file.py
@@ -1,7 +1,6 @@
 import os
 from hyperleaup import HyperFile
 from hyperleaup.spark_fixture import get_spark_session
-
 from tests.test_utils import TestUtils
@@ -98,3 +97,31 @@ def test_append(self):
         hf.append(df=df)
         num_rows = TestUtils.get_row_count("Extract", "Extract", "/tmp/save/employees.hyper")
         assert(num_rows == 6)
+
+    def test_hyper_process_parameters(self):
+        data_path = "/tmp/process_parameters"
+
+        log_dir = "/tmp/logs"
+        log_file = f"{log_dir}/hyperd.log"
+        if not os.path.exists(log_dir):
+            os.mkdir(log_dir)
+
+        data = [
+            (1001, "Jane", "Doe", "2000-05-01", 29, False),
+            (1002, "John", "Doe", "1988-05-03", 29, False),
+            (2201, "Elonzo", "Smith", "1990-05-03", 29, True)
+        ]
+        df = get_spark_session().createDataFrame(data, ["id", "first_name", "last_name", "dob", "age", "is_temp"])
+
+        hyper_process_parameters = {"log_dir": log_dir}
+
+        for mode in ["insert", "copy", "parquet"]:
+            if os.path.exists(log_file):
+                os.remove(log_file)
+
+            HyperFile(name="employees", df=df, is_dbfs_enabled=False, creation_mode=mode,
+                      hyper_process_parameters=hyper_process_parameters).save(data_path)
+
+            # Make sure that the logs have been created in the non-standard location
+            assert(os.path.exists(log_file))
+            assert(os.path.isfile(log_file))
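The new test leans on one observable side effect of the `log_dir` setting: Hyper writes `hyperd.log` into the given directory rather than its default location. A standalone sketch of that same check, independent of Spark and hyperleaup; the directory name is illustrative:

```python
import os
from tableauhyperapi import HyperProcess, Telemetry

log_dir = "/tmp/hyper_log_check"  # illustrative location
os.makedirs(log_dir, exist_ok=True)

# Starting (and cleanly stopping) the process should be enough for
# hyperd to emit its log file into the redirected directory.
with HyperProcess(telemetry=Telemetry.DO_NOT_SEND_USAGE_DATA_TO_TABLEAU,
                  parameters={"log_dir": log_dir}):
    pass

assert os.path.isfile(os.path.join(log_dir, "hyperd.log"))
```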