Skip to content
Draft
48 changes: 47 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,111 +5,157 @@ npm-debug.log*
yarn-debug.log*
yarn-error.log*


# Runtime data
pids
*.pid
*.seed
*.pid.lock


# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov


# Coverage directory used by tools like istanbul
coverage


# nyc test coverage
.nyc_output


# Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
.grunt


# Bower dependency directory (https://bower.io/)
bower_components


# node-waf configuration
.lock-wscript


# Compiled binary addons (http://nodejs.org/api/addons.html)
build/Release


# Dependency directories
node_modules/
jspm_packages/


# TypeScript v1 declaration files
typings/


# Optional npm cache directory
.npm


# Optional eslint cache
.eslintcache


# Optional REPL history
.node_repl_history


# Output of 'npm pack'
*.tgz


# Yarn Integrity file
.yarn-integrity


# dotenv environment variables file
.env
.env.test


# parcel-bundler cache (https://parceljs.org/)
.cache


# next.js build output
.next


# vuepress build output
.vuepress/dist


# Serverless directories
.serverless/


# FuseBox cache
.fusebox/


# DynamoDB Local files
.dynamodb/


# IDEA files
.idea/


# VSCode files
.vscode/


# Jupyter Notebook
.ipynb_checkpoints


# pytest
.pytest_cache/


# Cypress
cypress/videos
cypress/screenshots


# SASS
.sass-cache/


# Python
__pycache__/
*.py[cod]
*$py.class


# Cello
.cello


# Selenium
.wdm


# Jest
/coverage


# End of https://www.toptal.com/developers/gitignore/api/node


.venv/
venv/
server/data/bird
server/batch_job_metadata
server/chroma
server/finetune
server/new_finetune
server/scripts/v1_dataset
server/scripts/v1_wikisql
server/vector_cache
server/data/v1_wikisql
server/data/wikisql/dev_dataset
server/data/wikisql/test_dataset

31 changes: 31 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
alembic==1.13.2
alive_progress==3.2.0
anthropic==0.45.2
attr==0.3.2
bpemb==0.3.6
chromadb==0.5.15
datasketch==1.6.5
datatypes==0.20.0
datefinder==0.7.3
fastapi==0.115.8
func_timeout==4.3.5
llama_index==0.12.17
networkx==3.4.2
nltk==3.9.1
numpy==1.24.4
openai==1.63.0
pandas==2.2.3
protobuf==3.20.3
pydantic==2.10.6
pytest==8.3.4
Requests==2.32.3
scikit_learn==1.6.1
snowflake_connector_python==3.12.3
SQLAlchemy==2.0.28
sqlglot==26.6.0
sqlparse==0.5.3
tiktoken==0.8.0
torch==2.2.2
torchtext==0.17.2
tqdm==4.66.6
uvicorn==0.34.0
6 changes: 5 additions & 1 deletion server/app/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@ def set_database(database_name: str):
if PATH_CONFIG.dataset_type == DatasetType.SYNTHETIC:
if database_name not in [db_type.value for db_type in DatabaseType]:
raise ValueError(ERROR_DATABASE_NOT_FOUND.format(database_name=database_name))


if PATH_CONFIG.dataset_type in [DatasetType.WIKI_DEV, DatasetType.WIKI_TEST]:
if database_name not in os.listdir(PATH_CONFIG.dataset_dir()) or f"{database_name}.db" not in os.listdir(PATH_CONFIG.database_dir(database_name=database_name)):
raise ValueError(ERROR_DATABASE_NOT_FOUND.format(database_name=database_name))

PATH_CONFIG.set_database(database_name)

global engine, SessionLocal
Expand Down
44 changes: 31 additions & 13 deletions server/preprocess/add_descriptions_bird_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def read_csv(file_path, encodings=["utf-8-sig", "ISO-8859-1"]):

def extract_column_type_from_schema(connection, table_name, column_name):
""" Fetches the column data type from a SQLite table schema. """

logger.info(f"Extract col type")

# Fetch the table schema from SQLite
cursor = connection.cursor()
cursor.execute(f"PRAGMA table_info({table_name});")
Expand Down Expand Up @@ -98,6 +99,7 @@ def improve_column_descriptions(database_name, dataset_type, client):

errors = []
try:
logger.info("Connecting to db")
connection = sqlite3.connect(
PATH_CONFIG.sqlite_path(database_name=database_name, dataset_type=dataset_type)
)
Expand All @@ -106,12 +108,15 @@ def improve_column_descriptions(database_name, dataset_type, client):
table_description_csv_path = os.path.join(
base_path, f"{database_name}_tables.csv"
) # TO DO: Update file path as a constant after finalizing path organization.
logger.info("Starting to read")

table_description_df = read_csv(table_description_csv_path)
tables_in_database = get_table_names(connection)
logger.info("Looping")

for table_csv in os.listdir(base_path):
# Skip the Overall Tables Description File
logger.info(f"AT TABLE {table_csv}")
if table_csv == f"{database_name}_tables.csv":
continue

Expand All @@ -124,7 +129,6 @@ def improve_column_descriptions(database_name, dataset_type, client):
})
logger.error(f"Table '{table_name}' does not exist in the SQLite database. Please check {table_csv}")
continue

table_df = read_csv(os.path.join(base_path, table_csv))

# Fetch the first row of data for the table
Expand All @@ -139,9 +143,10 @@ def improve_column_descriptions(database_name, dataset_type, client):
if len(table_description_df) > 0
else "No description available."
)

logger.info("Got table description")
# Get column names from the SQLite database for validation
column_names = get_table_columns(connection, table_name)
logger.info("Got table column")

# Generate improved column descriptions
for idx, row in table_df.iterrows():
Expand All @@ -161,6 +166,8 @@ def improve_column_descriptions(database_name, dataset_type, client):
logger.error(f"Column '{row['original_column_name']}' does not exist. Please check {table_csv}.")
continue

logger.info("Generating improved col description")

improved_description = get_imrpoved_coloumn_description(row, table_name, first_row, connection, client, table_description, errors, database_name)

# Update the improved description in the DataFrame and save it to the CSV file
Expand All @@ -180,7 +187,6 @@ def improve_column_descriptions(database_name, dataset_type, client):
finally:
return errors


def create_database_tables_csv(database_name, dataset_type, client):
""" Create a {database_name}_tables.csv file with table descriptions and connected tables. """

Expand All @@ -197,14 +203,19 @@ def create_database_tables_csv(database_name, dataset_type, client):
table_descriptions = read_csv(table_description_csv_path)
except Exception as e:
table_descriptions = pd.DataFrame(columns=['table_name', 'table_description'])
logger.info(f"Table creating...")

connection = sqlite3.connect(PATH_CONFIG.sqlite_path(database_name=database_name, dataset_type=dataset_type))
cursor = connection.cursor()

logger.info(f"Enter format schema")
schema_ddl = format_schema(FormatType.CODE, database_name=database_name, dataset_type=dataset_type)
logger.info(f"Get table names")
tables = get_table_names(connection)

for table_name in tables:
logger.info(f"Processing table: {table_name}")

if table_name in table_descriptions['table_name'].values and not pd.isna(table_descriptions.loc[table_descriptions['table_name'] == table_name, 'table_description'].values[0]):
logger.info(f"Table {table_name} already has a description. Skipping LLM call.")
continue
Expand All @@ -226,7 +237,6 @@ def create_database_tables_csv(database_name, dataset_type, client):
table_description = client.execute_prompt(prompt)
except Exception as e:
if GOOGLE_RESOURCE_EXHAUSTED_EXCEPTION_STR in str(e):
# Rate limit exceeded: Too many requests. Retrying in 5 seconds...
time.sleep(5)
else:
errors.append({
Expand All @@ -240,12 +250,10 @@ def create_database_tables_csv(database_name, dataset_type, client):
"table_description": [table_description]
})
table_descriptions = pd.concat([table_descriptions, new_row], ignore_index=True)

# Save to CSV after processing each table
table_descriptions.to_csv(table_description_csv_path, index=False)


# Create a DataFrame for all tables
result_df = pd.DataFrame(table_descriptions)
result_df.to_csv(table_description_csv_path, index=False)

connection.close()
except Exception as e:
errors.append({
Expand All @@ -258,22 +266,28 @@ def create_database_tables_csv(database_name, dataset_type, client):

def ensure_description_files_exist(database_name, dataset_type):
""" Ensure description files exist for the given database. If not, create them from column_meaning.json. """
logger.info(f"STEP 0: Setup check")

# Get the base path for the description files and the path to the column meaning file

base_path = PATH_CONFIG.description_dir(database_name=database_name, dataset_type=dataset_type)
logger.info(f"Description path: {base_path}")

column_meaning_path = PATH_CONFIG.column_meaning_path(dataset_type=dataset_type)

logger.info(f"Col meaning path: {column_meaning_path}")

# If the column meaning file exists, load it
column_meaning = None
if os.path.exists(column_meaning_path):
with open(column_meaning_path, 'r') as f:
column_meaning = json.load(f)
logger.info(f"Col meaning")

# If the description directory does not exist and column_meaning is loaded, create the directory and files
if not os.path.exists(base_path) and column_meaning:
os.makedirs(base_path)
connection = sqlite3.connect(PATH_CONFIG.sqlite_path(database_name=database_name, dataset_type=dataset_type))

# Iterate over the column meanings
table_data = {}
for key, description in column_meaning.items():
Expand Down Expand Up @@ -345,10 +359,14 @@ def add_database_descriptions(

error_list = []
for database in tqdm(databases, desc="Generating Descriptions for databases"):
logger.info(f"DATABASE: {database}")
try:
ensure_description_files_exist(database_name=database, dataset_type=dataset_type)
errors = create_database_tables_csv(database_name=database, dataset_type=dataset_type, client=client)
logger.info(f"STEP 1: Create CSV")
#errors = create_database_tables_csv(database_name=database, dataset_type=dataset_type, client=client)
errors = None
if not errors:
logger.info(f"STEP 2: Improve Description")
errors = improve_column_descriptions(database_name=database, dataset_type=dataset_type, client=client)

error_list.extend(errors)
Expand Down
Loading