From 8fb20a2cdf5407ba5762ef18bb0c145e9bda18be Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Mon, 11 Sep 2023 22:05:56 +0800 Subject: [PATCH 1/2] feat: initial commit on CDDP AI assistant integration over Streamlit pages --- requirements.txt | 4 +- src/cddp/__init__.py | 2 + src/cddp/openai_api.py | 521 ++++++++++++++++++++++++++ src/streamlit/app.py | 2 +- src/streamlit/pages/1_Editor.py | 58 ++- src/streamlit/pages/2_AI Assistant.py | 275 +++++++++++++- src/streamlit/streamlit_utils.py | 165 ++++++++ 7 files changed, 1010 insertions(+), 17 deletions(-) create mode 100644 src/cddp/openai_api.py create mode 100644 src/streamlit/streamlit_utils.py diff --git a/requirements.txt b/requirements.txt index b6cf549..5ff8d0d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,6 @@ flask databricks-cli streamlit streamlit-echarts -streamlit-extras \ No newline at end of file +streamlit-extras +langchain==0.0.223 +openai==0.27.8 \ No newline at end of file diff --git a/src/cddp/__init__.py b/src/cddp/__init__.py index e1485db..e02ac20 100644 --- a/src/cddp/__init__.py +++ b/src/cddp/__init__.py @@ -379,6 +379,8 @@ def init_staging_sample_dataframe(spark, config): output = task["output"]["type"] type = task["input"]["type"] task_landing_path = utils.get_path_for_current_env(type,task["input"]["path"]) + if not task_landing_path: + task_landing_path = staging_path if not os.path.exists(task_landing_path): os.makedirs(task_landing_path) filename = task['name']+".json" diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py new file mode 100644 index 0000000..43f3745 --- /dev/null +++ b/src/cddp/openai_api.py @@ -0,0 +1,521 @@ +from langchain.chat_models import AzureChatOpenAI +from langchain.prompts import PromptTemplate +from langchain.chains import LLMChain +import json +import os + + +OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") +DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") +MODEL = os.getenv("OPENAI_MODEL") +API_VERSION = 
os.getenv("OPENAI_API_VERSION") + + +def _prepare_openapi_llm(): + llm = AzureChatOpenAI(deployment_name=DEPLOYMENT, + model=MODEL, + openai_api_version=API_VERSION, + openai_api_base=OPENAI_API_BASE) + + return llm + + +def recommend_tables_for_industry(industry_name: str, industry_contexts: str): + """ Recommend database tables for a given industry and relevant contexts. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + + :returns: recommened tables with schema in array of json format + """ + + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should be an array of JSON format objects like below. + {{ + "table_name": "{{table name}}", + "table_description": "{{table description}}", + [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Please recommend 7 to 10 database tables: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts}) + results = response["text"] + + return results + + +def recommend_tables_for_industry_mock(industry_name: str, industry_contexts: str): + results = """ + [ + { + "table_name": "airlines", + "table_description": "Information about the airline companies", + "columns": + [ + { + "column_name": "airline_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": 
"name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "country", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "flights", + "table_description": "Information about flights operated by airlines", + "columns": + [ + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "airline_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "origin", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "destination", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "departure_time", + "data_type": "datetime", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "arrival_time", + "data_type": "datetime", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "passengers", + "table_description": "Information about passengers", + "columns": + [ + { + "column_name": "passenger_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "age", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "gender", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": 
false, + "is_foreign_key": true + } + ] + }, + { + "table_name": "bookings", + "table_description": "Information about flight bookings made by passengers", + "columns": + [ + { + "column_name": "booking_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "passenger_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + } + ] + }, + { + "table_name": "seats", + "table_description": "Information about seats available in flights", + "columns": + [ + { + "column_name": "seat_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "flight_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "passenger_id", + "data_type": "integer", + "is_null": true, + "is_primary_key": false, + "is_foreign_key": true + }, + { + "column_name": "seat_number", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + }, + { + "table_name": "airports", + "table_description": "Information about airports", + "columns": + [ + { + "column_name": "airport_id", + "data_type": "integer", + "is_null": false, + "is_primary_key": true, + "is_foreign_key": false + }, + { + "column_name": "name", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + }, + { + "column_name": "location", + "data_type": "varchar(255)", + "is_null": false, + "is_primary_key": false, + "is_foreign_key": false + } + ] + } + ] + """ + + return results + + +def recommend_custom_table(industry_name: str, + industry_contexts: str, + recommened_tables: str, + custom_table_name: str, + custom_table_description: str): + 
""" Recommend custom/user-defined table for a input custom table name and description of a given industry. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param custom_table_name: custom table name + :param custom_table_description: custom table description + + :returns: recommened custom tables with schema in json format + """ + + recommend_custom_tables_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. + {custom_table_description} + + Please only output the new added table without previously recommended tables, therefore the outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], + template=recommend_custom_tables_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommened_tables": recommened_tables, + "custom_table_name": custom_table_name, + "custom_table_description": custom_table_description}) + results = response["text"] + + return results + + +def recommend_data_processing_logics(industry_name: str, + industry_contexts: str, + recommened_tables: str, + processing_logic: str): + """ Recommend data processing logics for a given industry and sample tables. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param processing_logic: either data cleaning, data transformation or data aggregation + + :returns: recommened data processing logics in array of json format + """ + + recommend_data_cleaning_logics_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. + You response should be in an array of JSON format like below. + {{ + "description": "{{descriptions on the data cleaning logic}}", + "involved_tables": [ + "{{involved table X}}", + "{{involved table Y}}", + "{{involved table Z}}" + ], + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{cleaned table schema in json string format}}" + }} + + Therefore outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], + template=recommend_data_cleaning_logics_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "processing_logic": processing_logic, + "recommened_tables": recommened_tables}) + results = json.loads(response["text"]) + + return results + + +def generate_custom_data_processing_logics(industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + """ Generate custom data processing logics for input data processing requirements. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param involved_tables: tables required by the custom data processing logic, in json string format + :param custom_data_processing_logic: custom data processing requirements + :param output_table_name: output/sink table name for processed data + + :returns: custom data processing logic json format + """ + + generate_custom_data_processing_logics_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + We have database tables listed below with table_name and schema in JSON format. + {involved_tables} + + And below is the data processing requirement. + {custom_data_processing_logic} + + Please help to generate Spark SQL statement with output data schema. + And your response should be in JSON format like below. + {{ + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{output data schema in JSON string format}}" + }} + + And the above data schema string should follows below JSON format, while value for the "table_name" key should strictly be "{output_table_name}". + {{ + "table_name": "{output_table_name}", + "coloumns": [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Therefore the outcome would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], + template=generate_custom_data_processing_logics_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + + # Run the chain only specifying the input variable. 
+ response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "custom_data_processing_logic": custom_data_processing_logic, + "involved_tables": involved_tables, + "output_table_name": output_table_name}) + results = response["text"] + + return results + + +def generate_sample_data(industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + """ Generate custom data processing logics for input data processing requirements. + + :param industry_name: industry name + :param number_of_lines: number of lines sample data required + :param target_table: target table name and its schema in json format + + :returns: generated sample data in array of json format + """ + + generate_sample_data_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. + {target_table} + + And below are patterns of column values in json format, if it's not provided please ignore this requirement. + {column_values_patterns} + + And the sample data should be an array of json object like below. 
+ {{ + "{{column X}}": "{{column value}}", + "{{column Y}}": "{{column value}}", + "{{column Z}}": "{{column value}}" + }} + + The sample data would be: + """ + + llm = _prepare_openapi_llm() + prompt = PromptTemplate( + input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], + template=generate_sample_data_template, + ) + chain = LLMChain(llm=llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "number_of_lines": number_of_lines, + "target_table": target_table, + "column_values_patterns": column_values_patterns}) + results = response["text"] + + return results + + +def generate_sample_data_mock(industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + if target_table["table_name"] == "flights": + results = """ + [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { 
"flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" 
}, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] + """ + + if target_table["table_name"] == "passengers": + results = """ + [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": "Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan 
Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] + """ + + if target_table["table_name"] == "bookings": + results = """ + [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, "passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { "booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + """ + + return results diff --git a/src/streamlit/app.py b/src/streamlit/app.py index 585de86..d11edbc 100644 --- a/src/streamlit/app.py +++ b/src/streamlit/app.py @@ -282,7 +282,7 @@ def import_pipeline(): pipeline_name = st.text_input('Pipeline name', key='pipeline_name', value=current_pipeline_obj['name']) if pipeline_name: current_pipeline_obj['name'] = pipeline_name - industry_list = ["Other", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", 
"Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] + industry_list = ["Other", "Airlines", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] industry_selected_idx = 0 if 'industry' in current_pipeline_obj: industry_selected_idx = industry_list.index(current_pipeline_obj['industry']) diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py index a9814ec..3311457 100644 --- a/src/streamlit/pages/1_Editor.py +++ b/src/streamlit/pages/1_Editor.py @@ -3,6 +3,7 @@ sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) import cddp import streamlit as st +import streamlit_utils import pandas as pd import json from pyspark.sql import SparkSession @@ -16,11 +17,13 @@ import uuid from io import StringIO import numpy as np +from cddp import openai_api from streamlit_echarts import st_echarts from streamlit_extras.chart_container import chart_container from streamlit_extras.stylable_container import stylable_container from streamlit_extras.colored_header import colored_header + st.set_page_config(page_title="CDDP - Pipeline Editor") current_pipeline_obj = None @@ -119,11 +122,14 @@ def run_task(task_name, stage="standard"): res_df = None if stage == "standard": res_df = cddp.start_standard_job(spark, config, task, False, True) + schema = res_df.schema.json() elif stage == "serving": res_df = cddp.start_serving_job(spark, config, task, False, True) + schema = res_df.schema.json() dataframe = res_df.toPandas() print(dataframe) st.session_state[f'_{task_name}_data'] = dataframe + st.session_state[f'_{task_name}_schema'] = schema except Exception as e: st.error(f"Cannot run task: {e}") @@ -169,7 +175,7 @@ def 
create_pipeline(): colored_header( label="Config-Driven Data Pipeline Editor", - description=f"Pipeline ID: {current_pipeline_obj['id']}", + description=f"Pipeline ID: {current_pipeline_obj.get('id', str(uuid.uuid4()))}", color_name="violet-70", ) @@ -182,6 +188,8 @@ def import_pipeline(): def get_pipeline_path(): if 'working_folder' not in st.session_state: + if "id" not in current_pipeline_obj: + current_pipeline_obj['id'] = str(uuid.uuid4()) return"./"+current_pipeline_obj['id']+".json" else: return st.session_state["working_folder"]+"/"+current_pipeline_obj['id']+".json" @@ -229,7 +237,7 @@ def save_pipeline_to_workspace(): if pipeline_name: current_pipeline_obj['name'] = pipeline_name st.text_input('Pipeline ID', key='pipeline_id', value=current_pipeline_obj['id'], disabled=True) - industry_list = ["Other", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] + industry_list = ["Other", "Airlines", "Agriculture", "Automotive", "Banking", "Chemical", "Construction", "Education", "Energy", "Entertainment", "Food", "Government", "Healthcare", "Hospitality", "Insurance", "Machinery", "Manufacturing", "Media", "Mining", "Pharmaceutical", "Real Estate", "Retail", "Telecommunications", "Transportation", "Utilities", "Wholesale"] industry_selected_idx = 0 if 'industry' in current_pipeline_obj: industry_selected_idx = industry_list.index(current_pipeline_obj['industry']) @@ -249,20 +257,36 @@ def save_pipeline_to_workspace(): st.divider() + generated_tables = [] + selected_tables = [] + if "current_generated_tables" in st.session_state: + if "selected_tables" in st.session_state["current_generated_tables"]: + selected_tables = st.session_state["current_generated_tables"]["selected_tables"] + if 
"generated_tables" in st.session_state["current_generated_tables"]: + generated_tables = json.loads(st.session_state["current_generated_tables"]["generated_tables"]) + + if "current_generated_sample_data" not in st.session_state: + st.session_state['current_generated_sample_data'] = {} + current_generated_sample_data = st.session_state['current_generated_sample_data'] st.subheader('Staging Zone') pipeline_obj = st.session_state['current_pipeline_obj'] + if "staged_tables" not in st.session_state: + st.session_state["staged_tables"] = [] + staged_tables = st.session_state["staged_tables"] + for i in range(len(pipeline_obj["staging"]) ): target_name = pipeline_obj['staging'][i]['output']['target'] stg_name = st.text_input(f'Dataset Name', key=f"stg_{i}_name", value=target_name) if stg_name: with st.expander(stg_name+" Settings"): - st.selectbox( - 'Choose a dataset to add to staging zone', - ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'stg_{i}_ai_dataset') + # selected_table = st.selectbox( + # 'Choose a dataset to add to staging zone', + # selected_tables, + # key=f'stg_{i}_ai_dataset') pipeline_obj['staging'][i]['output']['target'] = stg_name pipeline_obj['staging'][i]['name'] = stg_name @@ -275,7 +299,7 @@ def save_pipeline_to_workspace(): uploaded_file = st.file_uploader(f'Choose sample csv file', key=f'stg_{i}_file') if uploaded_file is not None: - + dataframe = pd.read_csv(uploaded_file) spark = st.session_state["spark"] stringio = StringIO(uploaded_file.getvalue().decode("utf-8")) @@ -288,7 +312,7 @@ def save_pipeline_to_workspace(): if 'sampleData' in pipeline_obj['staging'][i] and len(pipeline_obj['staging'][i]['sampleData']) > 0: sampleData = pipeline_obj['staging'][i]['sampleData'] dataframe = pd.DataFrame(sampleData) - st.dataframe(dataframe) + st.dataframe(dataframe) st.button("Delete", key="delete_stg_"+str(i), on_click=delete_task, args = ['staging', i]) @@ -346,9 +370,13 @@ def add_dataset(): if std_desc: 
pipeline_obj['standard'][i]['description'] = std_desc + # Get all staged table details + staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() + st.multiselect( - 'Choose datasets to add to transformation', - ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'std_{i}_ai_dataset') + 'Choose datasets to add to transformation', + staged_table_names, + key=f'std_{i}_ai_dataset') st.button(f'Generate SQL', key=f'std_{i}_gen') if len(current_pipeline_obj['standard'][i]['code']['sql']) == 0: @@ -362,6 +390,10 @@ def add_dataset(): st.button(f'Run SQL', key=f'run_std_{i}_sql', on_click=run_task, args = [std_name, "standard"]) if '_'+std_name+'_data' in st.session_state: st.dataframe(st.session_state['_'+std_name+'_data']) + if f"_{std_name}_schema" in st.session_state: + streamlit_utils.add_std_srv_schema("standard", + std_name, + st.session_state[f"_{std_name}_schema"]) st.button("Delete", key="delete_std_"+str(i), on_click=delete_task, args = ['standard', i]) @@ -405,9 +437,13 @@ def add_transformation(): if srv_desc: pipeline_obj['serving'][i]['description'] = srv_desc + # Get all standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() + st.multiselect( - 'Choose datasets to add to aggregation', - ['Dataset 1', 'Dataset 2', 'Dataset 3', 'Dataset 4', 'Dataset 5'], key=f'srv_{i}_ai_dataset') + 'Choose datasets to add to aggregation', + staged_table_names + standardized_table_names, + key=f'srv_{i}_ai_dataset') st.button(f'Generate SQL', key=f'srv_{i}_gen') if len(current_pipeline_obj['serving'][i]['code']['sql']) == 0: diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 00d5176..38530fc 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -1,8 +1,275 @@ +import cddp +import json +import os +import pandas as pd import streamlit as st -import time -import numpy as np 
+import sys +from cddp import openai_api +sys.path.append(os.path.join(os.path.dirname(__file__), '....')) +import streamlit_utils -st.set_page_config(page_title="AI Assiatant") +st.set_page_config(page_title="AI Assiatant") st.markdown("# AI Assiatant") -st.sidebar.header("AI Assiatant") \ No newline at end of file +st.sidebar.header("AI Assiatant") + +# Display pipeline basic info fetching from Editor page if it's maintained there +current_pipeline_obj = {} +if "current_pipeline_obj" in st.session_state: + current_pipeline_obj = st.session_state["current_pipeline_obj"] +else: + st.session_state["current_pipeline_obj"] = {} + +st.subheader('Generate tables and sample data') +pipeline_name = st.text_input('Pipeline name', + key='pipeline_name_ai_assistance', + value=current_pipeline_obj.get("name", ""), + disabled=True) +pipeline_industry = st.text_input('Industry', + key='pipeline_industry_ai_assistance', + value=current_pipeline_obj.get("industry", ""), + disabled=True) +pipeline_desc = st.text_area('Pipeline description', + key='pipeline_description_ai_assistance', + value=current_pipeline_obj.get("description", ""), + disabled=True) + + +# Initialize current_generated_tables key in session state +if "current_generated_tables" not in st.session_state: + st.session_state['current_generated_tables'] = {} +current_generated_tables = st.session_state['current_generated_tables'] + +# Get industry name from global pipeline config object +industry_name = st.session_state["current_pipeline_obj"].get("industry", "") +industry_contexts = st.session_state["current_pipeline_obj"].get("description", "") + +# Initialize generated sample data key in session state +if "current_generated_sample_data" not in st.session_state: + st.session_state['current_generated_sample_data'] = {} +current_generated_sample_data = st.session_state['current_generated_sample_data'] + + +# AI Assistnt for tables generation +tables = [] +generate_tables_col1, generate_tables_col2 = st.columns(2) 
+with generate_tables_col1: + st.button("Generate", on_click=streamlit_utils.click_button, kwargs={"button_name": "generate_tables"}) + +if "generate_tables" not in st.session_state: + st.session_state["generate_tables"] = False +elif st.session_state["generate_tables"]: + st.session_state["generate_tables"] = False # Reset button clicked status + with generate_tables_col2: + with st.spinner('Generating...'): + tables = openai_api.recommend_tables_for_industry_mock(industry_name, industry_contexts) + current_generated_tables["generated_tables"] = tables + +try: + if "generated_tables" in current_generated_tables: + tables = json.loads(current_generated_tables["generated_tables"]) + + if "selected_tables" not in current_generated_tables: + current_generated_tables["selected_tables"] = [] + + for table in tables: + columns = table["columns"] + columns_df = pd.DataFrame.from_dict(columns, orient='columns') + + check_flag = False + for selected_table in current_generated_tables["selected_tables"]: + if selected_table == table["table_name"]: + check_flag = True + + sample_data = None + with st.expander(table["table_name"]): + gen_table_name = table["table_name"] + gen_table_desc = table["table_description"] + st.checkbox("Add to staging zone", + key=gen_table_name, + value=check_flag, + on_change=streamlit_utils.add_to_staging_zone, + args=[gen_table_name, gen_table_desc]) + st.write(gen_table_desc) + st.write(columns_df) + + st.subheader(f"Generate sample data by AI assistant") + rows_count = st.slider("Number of rows", min_value=5, max_value=50, key=f'gen_rows_count_slider_{gen_table_name}') + enable_data_requirements = st.toggle("With extra sample data requirements", key=f'data_requirements_toggle_{gen_table_name}') + data_requirements = "" + if enable_data_requirements: + data_requirements = st.text_area("Extra requirements for sample data", + key=f'data_requirements_text_area_{gen_table_name}', + placeholder="Exp: value of column X should follow patterns xxx-xxxx, 
while x could be A-Z or 0-9") + + generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) + with generate_sample_data_col1: + st.button("Generate", + key=f"generate_data_button_{gen_table_name}", + on_click=streamlit_utils.click_button, + kwargs={"button_name": f"generate_sample_data_{gen_table_name}"}) + + if f"generate_sample_data_{gen_table_name}" not in st.session_state: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False + + if f"{gen_table_name}_smaple_data_generated" not in st.session_state: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = False + elif st.session_state[f"generate_sample_data_{gen_table_name}"]: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status + if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: + with generate_sample_data_col2: + with st.spinner('Generating...'): + sample_data = openai_api.generate_sample_data_mock(pipeline_industry, + rows_count, + table, + data_requirements) + st.session_state[f"{gen_table_name}_smaple_data_generated"] = True + # Store generated data to session_state + current_generated_sample_data[gen_table_name] = sample_data + + # Also update current_pipeline_obj if checked check-box before generating sample data + if sample_data and st.session_state[gen_table_name]: + spark = st.session_state["spark"] + json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") + + for index, dataset in enumerate(current_pipeline_obj['staging']): + if dataset["name"] == gen_table_name: + i = index + + current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) + current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) + + if st.session_state[f"{gen_table_name}_smaple_data_generated"]: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag + json_sample_data = json.loads(sample_data) + current_generated_sample_data[gen_table_name] = json_sample_data # 
Save generated data to session_state + # st.session_state[f'stg_{i}_data'] = sample_data + + if gen_table_name in current_generated_sample_data: + sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') + st.write(sample_data_df) + + if st.session_state[table["table_name"]] and table["table_name"] not in current_generated_tables["selected_tables"]: + current_generated_tables["selected_tables"].append(table["table_name"]) + +except ValueError as e: + # TODO: Add error/exception to standard error-showing widget + st.write(tables) + + + +st.divider() +st.subheader('Generate data transformation logic') + +if "standardized_tables" not in st.session_state: + st.session_state["standardized_tables"] = [] +standardized_tables = st.session_state["standardized_tables"] + +if "current_generated_std_srv_sqls" not in st.session_state: + st.session_state['current_generated_std_srv_sqls'] = {} +current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] + +# Get all staged table details +staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() +# Get all standardized table details +standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() + +std_name = st.text_input("Transformation Name") +selected_staged_tables = st.multiselect( + 'Choose datasets to do the data transformation', + staged_table_names + standardized_table_names, + key=f'std_involved_tables') +process_requirements = st.text_area("Transformation requirements", key=f"std_transform_requirements") + +generate_transform_sql_col1, generate_transform_sql_col2 = st.columns(2) +with generate_transform_sql_col1: + st.button(f'Generate SQL', + key=f'std_gen_sql', + on_click=streamlit_utils.click_button, + kwargs={"button_name": f"std_gen_transform_sql"}) + +if "std_gen_transform_sql" not in st.session_state: + st.session_state["std_gen_transform_sql"] = False +if 
st.session_state["std_gen_transform_sql"]:
+    st.session_state["std_gen_transform_sql"] = False # Reset clicked status
+    with generate_transform_sql_col2:
+        with st.spinner('Generating...'):
+            process_logic = openai_api.generate_custom_data_processing_logics(industry_name=pipeline_industry,
+                                                                              industry_contexts=pipeline_desc,
+                                                                              involved_tables=staged_table_details + standardized_table_details,
+                                                                              custom_data_processing_logic=process_requirements,
+                                                                              output_table_name=std_name)
+    try:
+        process_logic_json = json.loads(process_logic)
+        std_sql_val = process_logic_json["sql"]
+        current_generated_std_srv_sqls[std_name] = std_sql_val
+    except ValueError as e:
+        st.write(process_logic)
+
+    with st.expander(std_name, expanded=True):
+        st.button("Add to Standardization zone",
+                  key=f"gen_transformation_{std_name}",
+                  on_click=streamlit_utils.add_to_std_srv_zone,
+                  args=[f"gen_transformation_{std_name}", std_name, process_requirements, "standard"])
+        st.text_input("Involved tables", value=", ".join(selected_staged_tables), disabled=True)
+        std_sql = st.text_area(f'Transformation Spark SQL',
+                               key=f'std_transform_sql_{std_name}',
+                               value=current_generated_std_srv_sqls.get(std_name, ""),
+                               on_change=streamlit_utils.update_sql,
+                               args=[f'std_transform_sql_{std_name}', std_name])
+
+
+
+st.divider()
+st.subheader('Generate data aggregation logic')
+
+if "serving_tables" not in st.session_state:
+    st.session_state["serving_tables"] = []
+serving_tables = st.session_state["serving_tables"]
+
+# Get all standardized table details
+standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables()
+
+srv_name = st.text_input("Aggregation Name")
+selected_stg_std_tables = st.multiselect(
+    'Choose datasets to do the data aggregation',
+    staged_table_names + standardized_table_names,
+    key=f'srv_involved_tables')
+process_requirements = st.text_area("Aggregation requirements", key=f"srv_aggregate_requirements")
+
+generate_aggregate_sql_col1, 
generate_aggregate_sql_col2 = st.columns(2)
+with generate_aggregate_sql_col1:
+    st.button(f'Generate SQL',
+              key=f'srv_gen_sql',
+              on_click=streamlit_utils.click_button,
+              kwargs={"button_name": f"srv_gen_aggregate_sql"})
+
+if "srv_gen_aggregate_sql" not in st.session_state:
+    st.session_state["srv_gen_aggregate_sql"] = False
+if st.session_state["srv_gen_aggregate_sql"]:
+    st.session_state["srv_gen_aggregate_sql"] = False # Reset clicked status
+    with generate_aggregate_sql_col2:
+        with st.spinner('Generating...'):
+            process_logic = openai_api.generate_custom_data_processing_logics(industry_name=pipeline_industry,
+                                                                              industry_contexts=pipeline_desc,
+                                                                              involved_tables=staged_table_details + standardized_table_details,
+                                                                              custom_data_processing_logic=process_requirements,
+                                                                              output_table_name=srv_name)
+    try:
+        process_logic_json = json.loads(process_logic)
+        srv_sql_val = process_logic_json["sql"]
+        current_generated_std_srv_sqls[srv_name] = srv_sql_val
+    except ValueError as e:
+        st.write(process_logic)
+
+    with st.expander(srv_name, expanded=True):
+        st.button("Add to Serving zone",
+                  key=f"gen_aggregation_{srv_name}",
+                  on_click=streamlit_utils.add_to_std_srv_zone,
+                  args=[f"gen_aggregation_{srv_name}", srv_name, process_requirements, "serving"])
+        st.text_input("Involved tables", value=", ".join(selected_stg_std_tables), disabled=True)
+        srv_sql = st.text_area(f'Transformation Spark SQL',
+                               key=f'srv_aggregate_sql_{srv_name}',
+                               value=current_generated_std_srv_sqls.get(srv_name, ""),
+                               on_change=streamlit_utils.update_sql,
+                               args=[f'srv_aggregate_sql_{srv_name}', srv_name])
diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py
new file mode 100644
index 0000000..f3440d5
--- /dev/null
+++ b/src/streamlit/streamlit_utils.py
@@ -0,0 +1,165 @@
+import cddp
+import json
+import streamlit as st
+
+
+def get_selected_tables(tables):
+    selected_tables = []
+    if len(tables) > 0:
+        for table in tables:
+            table_name = table["table_name"]
+            if table_name in st.session_state and 
st.session_state[table_name]: + selected_tables.append(table_name) + + return selected_tables + + +def click_button(button_name): + st.session_state[button_name] = True + + +def get_selected_table_details(tables, table_name): + target_table = {} + for table in tables: + if table["table_name"] == table_name: + target_table = table + + return target_table + + +def get_selected_tables_details(tables, table_names): + target_tables = [] + for table_name in table_names: + for table_details in tables: + if table_details["table_name"] == table_name: + target_tables.append(table_details) + + return target_tables + + +def is_json_string(input: str): + is_json = True + try: + json.loads(input) + except ValueError as e: + is_json = False + + return is_json + + +def update_sql(key, std_name): + current_generated_std_sqls = st.session_state['current_generated_std_sqls'] + current_generated_std_sqls[std_name] = st.session_state[key] + + +def add_to_staging_zone(stg_name, stg_desc): + pipeline_obj = st.session_state['current_pipeline_obj'] + current_generated_sample_data = st.session_state['current_generated_sample_data'] + sample_data = current_generated_sample_data.get(stg_name, None) + + if st.session_state[stg_name]: # Add to staging zone if checkbox is checked + if sample_data: + spark = st.session_state["spark"] + json_str, schema = cddp.load_sample_data(spark, json.dumps(sample_data), format="json") + json_sample_data = json.loads(json_str) + json_schema = json.loads(schema) + else: + json_sample_data = [] + json_schema = {} + + pipeline_obj["staging"].append({ + "name": stg_name, + "description": stg_desc, + "input": { + "type": "filestore", + "format": "json", + "path": "", + "read-type": "batch" + }, + "output": { + "target": stg_name, + "type": ["file", "view"] + }, + "schema": json_schema, + "sampleData": json_sample_data + }) + else: # Remove staging task from staging zone if it's unchecked + for index, obj in enumerate(pipeline_obj["staging"]): + if obj["name"] == 
stg_name:
+                del pipeline_obj['staging'][index]
+
+
+def get_staged_tables():
+    pipeline_obj = st.session_state['current_pipeline_obj']
+    staging = pipeline_obj.get("staging", None)
+
+    staged_table_names = []
+    staged_table_details = []
+    if staging:
+        for staged_table in staging:
+            staged_table_names.append(staged_table["output"]["target"])
+            staged_table_details.append({
+                "table_name": staged_table["output"]["target"],
+                "schema": staged_table.get("schema", "")
+            })
+
+    return staged_table_names, staged_table_details
+
+
+def add_to_std_srv_zone(button_key, std_srv_name, std_srv_desc, zone):
+    pipeline_obj = st.session_state["current_pipeline_obj"]
+    current_generated_std_srv_sqls = st.session_state["current_generated_std_srv_sqls"]
+
+    if st.session_state[button_key]:  # Add to std or srv zone if click add-to-std/srv-zone button
+        pipeline_obj[zone].append({
+            "name": std_srv_name,
+            "type": "batch",
+            "description": std_srv_desc,
+            "code": {
+                "lang": "sql",
+                "sql": [current_generated_std_srv_sqls[std_srv_name]]
+            },
+            "output": {
+                "target": std_srv_name,
+                "type": ["file", "view"]
+            },
+            "dependency": []
+        })
+    else:   # Remove staging task from staging zone if it's unchecked
+        for index, obj in enumerate(pipeline_obj[zone]):
+            if obj["name"] == std_srv_name:
+                del pipeline_obj[zone][index]
+
+
+def add_std_srv_schema(zone, output_table_name, schema):
+    if "current_std_srv_tables_schema" not in st.session_state:
+        st.session_state['current_std_srv_tables_schema'] = {}
+    current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema']
+
+    current_std_srv_tables_schema.setdefault(zone, {})
+    current_std_srv_tables_schema[zone][output_table_name] = schema
+
+
+def get_standardized_tables():
+    pipeline_obj = st.session_state['current_pipeline_obj']
+    standard = pipeline_obj.get("standard", None)
+
+    if "current_std_srv_tables_schema" not in st.session_state:
+        st.session_state['current_std_srv_tables_schema'] = {}
+    current_std_srv_tables_schema = 
st.session_state['current_std_srv_tables_schema']
+
+    if "standard" not in current_std_srv_tables_schema:
+        current_std_srv_tables_schema["standard"] = {}
+
+    standardized_table_names = []
+    standardized_table_details = []
+    if standard:
+        for standardized_table in standard:
+            std_name = standardized_table["output"]["target"]
+            standardized_table_names.append(std_name)
+            standardized_table_details.append({
+                "table_name": std_name,
+                "schema": current_std_srv_tables_schema["standard"].get(std_name, "")
+            })
+
+    return standardized_table_names, standardized_table_details

From 5f4c48f9cf4c66d9d3d68e9b5d3a90418cb51fe5 Mon Sep 17 00:00:00 2001
From: thurston-chen
Date: Tue, 12 Sep 2023 05:39:58 +0800
Subject: [PATCH 2/2] fix: add on_change callback to SQL text area

---
 src/streamlit/pages/1_Editor.py  | 25 +++++++++++++++++++++----
 src/streamlit/streamlit_utils.py |  6 +++---
 2 files changed, 24 insertions(+), 7 deletions(-)

diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py
index 3311457..f725cba 100644
--- a/src/streamlit/pages/1_Editor.py
+++ b/src/streamlit/pages/1_Editor.py
@@ -352,6 +352,10 @@ def add_dataset():
 
     st.divider()
 
+    if "current_generated_std_srv_sqls" not in st.session_state:
+        st.session_state['current_generated_std_srv_sqls'] = {}
+    current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls']
+
     st.subheader('Standardization Zone')
 
     for i in range(len(pipeline_obj["standard"])):
@@ -381,8 +385,15 @@
             st.button(f'Generate SQL', key=f'std_{i}_gen')
             if len(current_pipeline_obj['standard'][i]['code']['sql']) == 0:
                 current_pipeline_obj['standard'][i]['code']['sql'].append("")
-            std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0]
-            std_sql = st.text_area(f'SQL', key=f'std_{i}_sql', value=std_sql_val)
+            # std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0]
+            current_generated_std_srv_sqls[std_name] = current_pipeline_obj['standard'][i]['code']['sql'][0]
+
+            
std_sql = st.text_area(f'SQL', + key=f'std_{i}_sql', + # value=std_sql_val, + value=current_generated_std_srv_sqls[std_name], + on_change=streamlit_utils.update_sql, + args=[f'std_{i}_sql', std_name]) if std_sql: pipeline_obj['standard'][i]['code']['sql'][0] = std_sql @@ -448,8 +459,14 @@ def add_transformation(): st.button(f'Generate SQL', key=f'srv_{i}_gen') if len(current_pipeline_obj['serving'][i]['code']['sql']) == 0: current_pipeline_obj['serving'][i]['code']['sql'].append("") - srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] - srv_sql = st.text_area(f'SQL', key=f'srv_{i}_sql', value=srv_sql_val) + # srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] + current_generated_std_srv_sqls[srv_name] = current_pipeline_obj['serving'][i]['code']['sql'][0] + srv_sql = st.text_area(f'SQL', + key=f'srv_{i}_sql', + # value=srv_sql_val, + value=current_generated_std_srv_sqls[srv_name], + on_change=streamlit_utils.update_sql, + args=[f'srv_{i}_sql', srv_name]) if srv_sql: pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index f3440d5..3a99e1a 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -47,9 +47,9 @@ def is_json_string(input: str): return is_json -def update_sql(key, std_name): - current_generated_std_sqls = st.session_state['current_generated_std_sqls'] - current_generated_std_sqls[std_name] = st.session_state[key] +def update_sql(key, table_name): + current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] + current_generated_std_srv_sqls[table_name] = st.session_state[key] def add_to_staging_zone(stg_name, stg_desc):