diff --git a/requirements.txt b/requirements.txt index dc760cb..ac34d3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,8 +14,9 @@ databricks-cli streamlit streamlit-echarts streamlit-extras - +langchain==0.0.223 +openai==0.27.8 azure-core==1.24.1 azure.identity azure-storage-blob -azure-data-tables \ No newline at end of file +azure-data-tables diff --git a/src/cddp/__init__.py b/src/cddp/__init__.py index e1485db..e02ac20 100644 --- a/src/cddp/__init__.py +++ b/src/cddp/__init__.py @@ -379,6 +379,8 @@ def init_staging_sample_dataframe(spark, config): output = task["output"]["type"] type = task["input"]["type"] task_landing_path = utils.get_path_for_current_env(type,task["input"]["path"]) + if not task_landing_path: + task_landing_path = staging_path if not os.path.exists(task_landing_path): os.makedirs(task_landing_path) filename = task['name']+".json" diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py new file mode 100644 index 0000000..43f3745 --- /dev/null +++ b/src/cddp/openai_api.py @@ -0,0 +1,521 @@ +from langchain.chat_models import AzureChatOpenAI +from langchain.prompts import PromptTemplate +from langchain.chains import LLMChain +import json +import os + + +OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") +DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") +MODEL = os.getenv("OPENAI_MODEL") +API_VERSION = os.getenv("OPENAI_API_VERSION") + + +def _prepare_openapi_llm(): + llm = AzureChatOpenAI(deployment_name=DEPLOYMENT, + model=MODEL, + openai_api_version=API_VERSION, + openai_api_base=OPENAI_API_BASE) + + return llm + + +def recommend_tables_for_industry(industry_name: str, industry_contexts: str): + """ Recommend database tables for a given industry and relevant contexts. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + + :returns: recommended tables with schema in array of json format + """ + + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should be an array of JSON format objects like below. + {{ + "table_name": "{{table name}}", + "table_description": "{{table description}}", + "columns": [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Please recommend 7 to 10 database tables: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts}) + results = response["text"] + + return results + + +def recommend_tables_for_industry_mock(industry_name: str, industry_contexts: str): + results = """ + [ + { + "table_name": "airlines", + "table_description": "Information about the airline companies", + "columns": + [ + { + "column_name": "airline_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "country", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "flights", + "table_description": "Information about flights operated by airlines", + 
"columns": + [ + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "airline_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "origin", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "destination", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "departure_time", + "data_type": "datetime", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "arrival_time", + "data_type": "datetime", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "passengers", + "table_description": "Information about passengers", + "columns": + [ + { + "column_name": "passenger_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "age", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "gender", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + } + ] + }, + { + "table_name": "bookings", + "table_description": "Information about flight bookings made by passengers", + "columns": + [ + { + "column_name": "booking_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "passenger_id", + "data_type": "integer", 
+ "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + } + ] + }, + { + "table_name": "seats", + "table_description": "Information about seats available in flights", + "columns": + [ + { + "column_name": "seat_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "passenger_id", + "data_type": "integer", + "is_null": true, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "seat_number", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "airports", + "table_description": "Information about airports", + "columns": + [ + { + "column_name": "airport_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "location", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + } + ] + """ + + return results + + +def recommend_custom_table(industry_name: str, + industry_contexts: str, + recommened_tables: str, + custom_table_name: str, + custom_table_description: str): + """ Recommend custom/user-defined table for a input custom table name and description of a given industry. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommended tables in json string format + :param custom_table_name: custom table name + :param custom_table_description: custom table description + + :returns: recommended custom tables with schema in json format + """ + + recommend_custom_tables_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. + {custom_table_description} + + Please only output the new added table without previously recommended tables, therefore the outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], + template=recommend_custom_tables_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommened_tables": recommened_tables, + "custom_table_name": custom_table_name, + "custom_table_description": custom_table_description}) + results = response["text"] + + return results + + +def recommend_data_processing_logics(industry_name: str, + industry_contexts: str, + recommened_tables: str, + processing_logic: str): + """ Recommend data processing logics for a given industry and sample tables. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommended tables in json string format + :param processing_logic: either data cleaning, data transformation or data aggregation + + :returns: recommended data processing logics in array of json format + """ + + recommend_data_cleaning_logics_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. + Your response should be an array of JSON objects like below. + {{ + "description": "{{descriptions on the data cleaning logic}}", + "involved_tables": [ + "{{involved table X}}", + "{{involved table Y}}", + "{{involved table Z}}" + ], + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{cleaned table schema in json string format}}" + }} + + Therefore the outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], + template=recommend_data_cleaning_logics_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "processing_logic": processing_logic, + "recommened_tables": recommened_tables}) + results = json.loads(response["text"]) + + return results + + +def generate_custom_data_processing_logics(industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + """ Generate custom data processing logics for input data processing requirements. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param involved_tables: tables required by the custom data processing logic, in json string format + :param custom_data_processing_logic: custom data processing requirements + :param output_table_name: output/sink table name for processed data + + :returns: custom data processing logic in json format + """ + + generate_custom_data_processing_logics_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + We have database tables listed below with table_name and schema in JSON format. + {involved_tables} + + And below is the data processing requirement. + {custom_data_processing_logic} + + Please help to generate Spark SQL statement with output data schema. + And your response should be in JSON format like below. + {{ + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{output data schema in JSON string format}}" + }} + + And the above data schema string should follow the JSON format below, while value for the "table_name" key should strictly be "{output_table_name}". + {{ + "table_name": "{output_table_name}", + "columns": [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Therefore the outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], + template=generate_custom_data_processing_logics_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + + # Run the chain, passing a value for every input variable. 
+ response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "custom_data_processing_logic": custom_data_processing_logic, + "involved_tables": involved_tables, + "output_table_name": output_table_name}) + results = response["text"] + + return results + + +def generate_sample_data(industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + """ Generate sample data for a target table, optionally following given column value patterns. + + :param industry_name: industry name + :param number_of_lines: number of sample data lines required + :param target_table: target table name and its schema in json format + + :returns: generated sample data in array of json format + """ + + generate_sample_data_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. + {target_table} + + And below are patterns of column values in json format, if it's not provided please ignore this requirement. + {column_values_patterns} + + And the sample data should be an array of json object like below. 
+ {{ + "{{column X}}": "{{column value}}", + "{{column Y}}": "{{column value}}", + "{{column Z}}": "{{column value}}" + }} + + The sample data would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], + template=generate_sample_data_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "number_of_lines": number_of_lines, + "target_table": target_table, + "column_values_patterns": column_values_patterns}) + results = response["text"] + + return results + + +def generate_sample_data_mock(industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + if target_table["table_name"] == "flights": + results = """ + [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { 
"flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" 
}, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] + """ + + if target_table["table_name"] == "passengers": + results = """ + [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": "Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan 
Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] + """ + + if target_table["table_name"] == "bookings": + results = """ + [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, "passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { "booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + """ + + return results diff --git a/src/streamlit/app.py b/src/streamlit/app.py index 585de86..d11edbc 100644 --- a/src/streamlit/app.py +++ b/src/streamlit/app.py @@ -282,7 +282,7 @@ def import_pipeline(): pipeline_name = st.text_input('Pipeline name', key='pipeline_name', value=current_pipeline_obj['name']) if pipeline_name: current_pipeline_obj['name'] = pipeline_name - industry_list = ["Other", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", 
"Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] + industry_list = ["Other", "Airlines", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] industry_selected_idx = 0 if 'industry' in current_pipeline_obj: industry_selected_idx = industry_list.index(current_pipeline_obj['industry']) diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py index 20357a3..76d07ac 100644 --- a/src/streamlit/pages/1_Editor.py +++ b/src/streamlit/pages/1_Editor.py @@ -3,6 +3,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) import cddp import streamlit as st +import streamlit_utils import pandas as pd import json from pyspark.sql import SparkSession @@ -16,6 +17,7 @@ import uuid from io import StringIO import numpy as np +from cddp import openai_api from streamlit_echarts import st_echarts from streamlit_extras.chart_container import chart_container from streamlit_extras.stylable_container import stylable_container @@ -26,6 +28,7 @@ if "working_folder" not in st.session_state: switch_page("Home") + st.set_page_config(page_title="CDDP - Pipeline Editor") current_pipeline_obj = None @@ -120,11 +123,14 @@ def run_task(task_name, stage="standard"): res_df = None if stage == "standard": res_df = cddp.start_standard_job(spark, config, task, False, True) + schema = res_df.schema.json() elif stage == "serving": res_df = cddp.start_serving_job(spark, config, task, False, True) + schema = res_df.schema.json() dataframe = res_df.toPandas() print(dataframe) st.session_state[f'_{task_name}_data'] = dataframe + st.session_state[f'_{task_name}_schema'] = schema except Exception as e: st.error(f"Cannot run task: {e}") @@ -172,7 
+178,7 @@ def create_pipeline(): colored_header( label="Config-Driven Data Pipeline Editor", - description=f"Pipeline ID: {current_pipeline_obj['id']}", + description=f"Pipeline ID: {current_pipeline_obj.get('id', str(uuid.uuid4()))}", color_name="violet-70", ) @@ -185,6 +191,8 @@ def import_pipeline(): def get_pipeline_path(): if 'working_folder' not in st.session_state: + if "id" not in current_pipeline_obj: + current_pipeline_obj['id'] = str(uuid.uuid4()) return"./"+current_pipeline_obj['id']+".json" else: return st.session_state["working_folder"]+"/"+current_pipeline_obj['id']+".json" @@ -266,7 +274,7 @@ def publish_pipeline_to_gallery(): if pipeline_name: current_pipeline_obj['name'] = pipeline_name st.text_input('Pipeline ID', key='pipeline_id', value=current_pipeline_obj['id'], disabled=True) - industry_list = ["Other", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] + industry_list = ["Other", "Airlines", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] industry_selected_idx = 0 if 'industry' in current_pipeline_obj: industry_selected_idx = industry_list.index(current_pipeline_obj['industry']) @@ -286,17 +294,36 @@ def publish_pipeline_to_gallery(): st.divider() + generated_tables = [] + selected_tables = [] + if "current_generated_tables" in st.session_state: + if "selected_tables" in st.session_state["current_generated_tables"]: + selected_tables = 
st.session_state["current_generated_tables"]["selected_tables"] + if "generated_tables" in st.session_state["current_generated_tables"]: + generated_tables = json.loads(st.session_state["current_generated_tables"]["generated_tables"]) + + if "current_generated_sample_data" not in st.session_state: + st.session_state['current_generated_sample_data'] = {} + current_generated_sample_data = st.session_state['current_generated_sample_data'] st.subheader('Staging Zone') pipeline_obj = st.session_state['current_pipeline_obj'] + if "staged_tables" not in st.session_state: + st.session_state["staged_tables"] = [] + staged_tables = st.session_state["staged_tables"] + for i in range(len(pipeline_obj["staging"]) ): target_name = pipeline_obj['staging'][i]['output']['target'] stg_name = st.text_input(f'Dataset Name', key=f"stg_{i}_name", value=target_name) if stg_name: with st.expander(stg_name+" Settings"): + # selected_table = st.selectbox( + # 'Choose a dataset to add to staging zone', + # selected_tables, + # key=f'stg_{i}_ai_dataset') # st.selectbox( # 'Choose a dataset to add to staging zone', # ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'stg_{i}_ai_dataset') @@ -312,7 +339,7 @@ def publish_pipeline_to_gallery(): uploaded_file = st.file_uploader(f'Choose sample csv file', key=f'stg_{i}_file') if uploaded_file is not None: - + dataframe = pd.read_csv(uploaded_file) spark = st.session_state["spark"] stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) @@ -325,7 +352,7 @@ def publish_pipeline_to_gallery(): if 'sampleData' in pipeline_obj['staging'][i] and len(pipeline_obj['staging'][i]['sampleData']) > 0: sampleData = pipeline_obj['staging'][i]['sampleData'] dataframe = pd.DataFrame(sampleData) - st.dataframe(dataframe) + st.dataframe(dataframe) st.button("Delete", key="delete_stg_"+str(i), on_click=delete_task, args = ['staging', i]) @@ -365,6 +392,10 @@ def add_dataset(): st.divider() + if "current_generated_std_srv_sqls" not in 
st.session_state: + st.session_state['current_generated_std_srv_sqls'] = {} + current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] + st.subheader('Standardization Zone') for i in range(len(pipeline_obj["standard"])): @@ -383,6 +414,13 @@ def add_dataset(): if std_desc: pipeline_obj['standard'][i]['description'] = std_desc + # Get all staged table details + staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() + + st.multiselect( + 'Choose datasets to add to transformation', + staged_table_names, + key=f'std_{i}_ai_dataset') # st.multiselect( # 'Choose datasets to add to transformation', # ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'std_{i}_ai_dataset') @@ -390,8 +428,15 @@ def add_dataset(): st.button(f'Generate SQL', key=f'std_{i}_gen') if len(current_pipeline_obj['standard'][i]['code']['sql']) == 0: current_pipeline_obj['standard'][i]['code']['sql'].append("") - std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0] - std_sql = st.text_area(f'SQL', key=f'std_{i}_sql', value=std_sql_val) + # std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0] + current_generated_std_srv_sqls[std_name] = current_pipeline_obj['standard'][i]['code']['sql'][0] + + std_sql = st.text_area(f'SQL', + key=f'std_{i}_sql', + # value=std_sql_val, + value=current_generated_std_srv_sqls[std_name], + on_change=streamlit_utils.update_sql, + args=[f'std_{i}_sql', std_name]) if std_sql: pipeline_obj['standard'][i]['code']['sql'][0] = std_sql @@ -399,6 +444,10 @@ def add_dataset(): st.button(f'Run SQL', key=f'run_std_{i}_sql', on_click=run_task, args = [std_name, "standard"]) if '_'+std_name+'_data' in st.session_state: st.dataframe(st.session_state['_'+std_name+'_data']) + if f"_{std_name}_schema" in st.session_state: + streamlit_utils.add_std_srv_schema("standard", + std_name, + st.session_state[f"_{std_name}_schema"]) st.button("Delete", key="delete_std_"+str(i), on_click=delete_task, 
args = ['standard', i]) @@ -442,6 +491,13 @@ def add_transformation(): if srv_desc: pipeline_obj['serving'][i]['description'] = srv_desc + # Get all standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() + + st.multiselect( + 'Choose datasets to add to aggregation', + staged_table_names + standardized_table_names, + key=f'srv_{i}_ai_dataset') # st.multiselect( # 'Choose datasets to add to aggregation', # ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'srv_{i}_ai_dataset') @@ -449,8 +505,14 @@ def add_transformation(): st.button(f'Generate SQL', key=f'srv_{i}_gen') if len(current_pipeline_obj['serving'][i]['code']['sql']) == 0: current_pipeline_obj['serving'][i]['code']['sql'].append("") - srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] - srv_sql = st.text_area(f'SQL', key=f'srv_{i}_sql', value=srv_sql_val) + # srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] + current_generated_std_srv_sqls[srv_name] = current_pipeline_obj['serving'][i]['code']['sql'][0] + srv_sql = st.text_area(f'SQL', + key=f'srv_{i}_sql', + # value=srv_sql_val, + value=current_generated_std_srv_sqls[srv_name], + on_change=streamlit_utils.update_sql, + args=[f'srv_{i}_sql', srv_name]) if srv_sql: pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 9f22fbd..38530fc 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -11,7 +11,7 @@ st.set_page_config(page_title="AI Assiatant") st.markdown("# AI Assiatant") - +st.sidebar.header("AI Assiatant") # Display pipeline basic info fetching from Editor page if it's maintained there current_pipeline_obj = {} @@ -272,4 +272,4 @@ key=f'srv_aggregate_sql_{srv_name}', value=current_generated_std_srv_sqls[srv_name], on_change=streamlit_utils.update_sql, - 
args=[f'std_transform_sql_{srv_name}', srv_name]) \ No newline at end of file + args=[f'std_transform_sql_{srv_name}', srv_name]) diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 73b6811..3a99e1a 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -162,4 +162,4 @@ def get_standardized_tables(): "schema": current_std_srv_tables_schema["standard"].get(std_name, "") }) - return standardized_table_names, standardized_table_names \ No newline at end of file + return standardized_table_names, standardized_table_names