From 7a60a155f73d850823fa3db4723a5b2694c36fa8 Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Thu, 21 Sep 2023 11:59:36 +0800 Subject: [PATCH 1/8] fix: refactor AI Assistant page with bugs fixes --- src/cddp/openai_api.py | 233 ++--------------- src/streamlit/pages/1_Editor.py | 7 + src/streamlit/pages/2_AI Assistant.py | 361 ++++++++++++++++---------- src/streamlit/streamlit_utils.py | 146 ++++++++--- 4 files changed, 371 insertions(+), 376 deletions(-) diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py index 2788d7b..aaabe58 100644 --- a/src/cddp/openai_api.py +++ b/src/cddp/openai_api.py @@ -106,220 +106,29 @@ def recommend_tables_for_industry(industry_name: str, industry_contexts: str): return results +def recommend_data_processing_scenario_mock(industry_name: str): + results = """ + [ { "pipeline_name": "Flight Delay Pipeline", "description": "Collects flight data from various sources and generates metrics on delays", "stages": [ { "stage": "staging", "description": "Collects flight data from airline APIs and airport databases" }, { "stage": "standard", "description": "Transforms data to calculate delay metrics based on departure and arrival times" }, { "stage": "serving", "description": "Aggregates delay metrics by airport, airline, and route" } ] }, { "pipeline_name": "Baggage Handling Pipeline", "description": "Tracks baggage movement and generates metrics on handling efficiency", "stages": [ { "stage": "staging", "description": "Collects baggage movement data from RFID scanners and baggage handling systems" }, { "stage": "standard", "description": "Transforms data to calculate metrics on baggage handling efficiency, such as time to load and unload baggage" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and baggage handling company" } ] }, { "pipeline_name": "Revenue Management Pipeline", "description": "Analyzes sales data to optimize pricing and revenue", "stages": [ { "stage": "staging", "description": 
"Collects sales data from ticketing systems and travel booking websites" }, { "stage": "standard", "description": "Transforms data to calculate revenue metrics, such as average ticket price and revenue per seat" }, { "stage": "serving", "description": "Aggregates metrics by route, fare class, and seasonality" } ] }, { "pipeline_name": "Maintenance Pipeline", "description": "Monitors aircraft health and schedules maintenance", "stages": [ { "stage": "staging", "description": "Collects aircraft sensor data and maintenance records" }, { "stage": "standard", "description": "Transforms data to identify potential maintenance issues and schedule preventative maintenance" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, age, and maintenance history" } ] }, { "pipeline_name": "Customer Service Pipeline", "description": "Analyzes customer feedback to improve service", "stages": [ { "stage": "staging", "description": "Collects customer feedback from surveys, social media, and customer service interactions" }, { "stage": "standard", "description": "Transforms data to identify common issues and sentiment analysis of customer feedback" }, { "stage": "serving", "description": "Aggregates metrics by route, airline, and customer feedback channel" } ] }, { "pipeline_name": "Fuel Efficiency Pipeline", "description": "Monitors fuel usage to optimize efficiency and reduce costs", "stages": [ { "stage": "staging", "description": "Collects fuel usage data from aircraft sensors and fueling systems" }, { "stage": "standard", "description": "Transforms data to calculate fuel efficiency metrics, such as fuel burn per passenger mile" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, route, and seasonality" } ] }, { "pipeline_name": "Security Pipeline", "description": "Monitors security incidents to improve safety and compliance", "stages": [ { "stage": "staging", "description": "Collects security incident data from airport security 
systems and passenger screening" }, { "stage": "standard", "description": "Transforms data to identify common security incidents and compliance issues" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and security incident type" } ] } ] + """ + + return results + + def recommend_tables_for_industry_mock(industry_name: str, industry_contexts: str): results = """ - [ - { - "table_name": "airlines", - "table_description": "Information about the airline companies", - "columns": - [ - { - "column_name": "airline_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "country", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "flights", - "table_description": "Information about flights operated by airlines", - "columns": - [ - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "airline_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "origin", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "destination", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "departure_time", - "data_type": "datetime", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "arrival_time", - "data_type": "datetime", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "passengers", - "table_description": "Information about 
passengers", - "columns": - [ - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "age", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "gender", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - } - ] - }, - { - "table_name": "bookings", - "table_description": "Information about flight bookings made by passengers", - "columns": - [ - { - "column_name": "booking_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - } - ] - }, - { - "table_name": "seats", - "table_description": "Information about seats available in flights", - "columns": - [ - { - "column_name": "seat_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": true, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "seat_number", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "airports", - 
"table_description": "Information about airports", - "columns": - [ - { - "column_name": "airport_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "location", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - } - ] + [{"table_name":"airlines","table_description":"Information about the airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"flights","table_description":"Information about flights operated by airlines","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"passengers","table_description":"Information about 
passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"bookings","table_description":"Information about flight bookings made by passengers","columns":[{"column_name":"booking_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"seats","table_description":"Information about seats available in flights","columns":[{"column_name":"seat_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"passenger_id","data_type":"integer","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"seat_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"airports","table_description":"Information about 
airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"location","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]}] + """ + + return results + + +def generate_custom_data_processing_logics_mock(industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + results = """ + {"sql":"SELECT flights.airline_id, flights.arrival_time, flights.departure_time, flights.destination, flights.flight_id, flights.origin, passengers.age, passengers.gender, passengers.name, passengers.passenger_id FROM flights JOIN bookings ON flights.flight_id = bookings.flight_id JOIN passengers ON bookings.passenger_id = passengers.passenger_id","schema":{"table_name":"std_flights_p","columns":[{"column_name":"airline_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"arrival_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"name","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"passenger_id",
"data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true}]}} """ return results diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py index 67567b6..e2f17be 100644 --- a/src/streamlit/pages/1_Editor.py +++ b/src/streamlit/pages/1_Editor.py @@ -185,6 +185,13 @@ def run_task(task_name, stage="standard"): def delete_task(type, index): if type == "staging": + # Also sync checkbox status in the AI Assistant page + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + for table in generated_tables: + if table["table_name"] == current_pipeline_obj['staging'][index]["name"]: + table["staged_flag"] = False + break + del current_pipeline_obj['staging'][index] elif type == "standard": del current_pipeline_obj['standard'][index] diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 7a74f3a..633c5c5 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -19,7 +19,7 @@ st.set_page_config(page_title="AI Assistant") colored_header( - label="AI Assiatant", + label="AI Assitant", description=f"Leverage AI to assist you in data pipeline development", color_name="violet-70", ) @@ -103,29 +103,35 @@ st.session_state["generate_tables"] = False elif st.session_state["generate_tables"]: st.session_state["generate_tables"] = False # Reset button clicked status - with generate_tables_col2: - with st.spinner('Generating...'): - tables = openai_api.recommend_tables_for_industry(industry_name, industry_contexts) - current_generated_tables["generated_tables"] = tables + + has_staged_table = streamlit_utils.has_staged_table() + + if has_staged_table: + st.warning("Some of current generated tables have been added to Staging zone, please remove them before generating tables again!") + else: + with generate_tables_col2: + with st.spinner('Generating...'): + tables = openai_api.recommend_tables_for_industry(industry_name, 
industry_contexts) + current_generated_tables["generated_tables"] = json.loads(tables) + + # Clean current_generated_sample_data key in session state once Generate button is clicked again + st.session_state['current_generated_sample_data'] = {} try: if "generated_tables" in current_generated_tables: - tables = json.loads(current_generated_tables["generated_tables"]) - - if "selected_tables" not in current_generated_tables: - current_generated_tables["selected_tables"] = [] + tables = current_generated_tables["generated_tables"] - for table in tables: + for gen_table_index, table in enumerate(tables): columns = table["columns"] columns_df = pd.DataFrame.from_dict(columns, orient='columns') - check_flag = False - for selected_table in current_generated_tables["selected_tables"]: - if selected_table == table["table_name"]: - check_flag = True - sample_data = None - with st.expander(table["table_name"]): + added_to_stage = current_generated_tables["generated_tables"][gen_table_index].get("staged_flag", False) + expander_label = table["table_name"] + if added_to_stage: + expander_label = table["table_name"] + " :heavy_check_mark:" + + with st.expander(expander_label, expanded=added_to_stage): gen_table_name = table["table_name"] gen_table_desc = table["table_description"] @@ -167,7 +173,8 @@ current_generated_sample_data[gen_table_name] = sample_data # Also update current_pipeline_obj if checked check-box before generating sample data - if sample_data and st.session_state[gen_table_name]: + # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: + if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): spark = st.session_state["spark"] json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") @@ -188,18 +195,12 @@ sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') st.write(sample_data_df) - st.checkbox("Add to staging zone", - 
key=gen_table_name, - value=check_flag, - on_change=streamlit_utils.add_to_staging_zone, - args=[gen_table_name, gen_table_desc]) - - - # TODO we need to change the key of table["table_name"] - if st.session_state[table["table_name"]] and table["table_name"] not in current_generated_tables["selected_tables"]: - current_generated_tables["selected_tables"].append(table["table_name"]) - - + # Show checkbox only after sample data has been generated + st.checkbox("Add to staging zone", + key=f"add_to_staging_{gen_table_index}_checkbox", + value=added_to_stage, + on_change=streamlit_utils.add_to_staging_zone, + args=[gen_table_index, gen_table_name, gen_table_desc]) except ValueError as e: # TODO: Add error/exception to standard error-showing widget @@ -207,119 +208,211 @@ with tab_std_sql: - # st.divider() - # st.subheader('Generate data transformation logic') - + if "current_std_srv_tables_schema" not in st.session_state: + st.session_state['current_std_srv_tables_schema'] = {} + current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] - if "standardized_tables" not in st.session_state: - st.session_state["standardized_tables"] = [] - standardized_tables = st.session_state["standardized_tables"] + if "standard" not in current_std_srv_tables_schema: + current_std_srv_tables_schema["standard"] = {} if "current_generated_std_srv_sqls" not in st.session_state: st.session_state['current_generated_std_srv_sqls'] = {} current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] + if "current_selected_std_tables" not in st.session_state: + st.session_state['current_selected_std_tables'] = [] + # current_selected_std_tables = st.session_state['current_selected_std_tables'] + # Get all staged table details staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() - # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() - - std_name 
= st.text_input("Transformation Name") - selected_staged_tables = st.multiselect( - 'Choose datasets to do the data transformation', - staged_table_names + standardized_table_names, - key=f'std_involved_tables') - process_requirements = st.text_area("Transformation requirements", key=f"std_transform_requirements") - - generate_transform_sql_col1, generate_transform_sql_col2 = st.columns(2) - with generate_transform_sql_col1: - st.button(f'Generate SQL', - key=f'std_gen_sql', - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"std_gen_transform_sql"}) - - if "std_gen_transform_sql" not in st.session_state: - st.session_state["std_gen_transform_sql"] = False - if st.session_state["std_gen_transform_sql"]: - st.session_state["std_gen_transform_sql"] = False # Reset clicked status - with generate_transform_sql_col2: - with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=staged_table_details + standardized_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=std_name) - try: - process_logic_json = json.loads(process_logic) - std_sql_val = process_logic_json["sql"] - current_generated_std_srv_sqls[std_name] = std_sql_val - except ValueError as e: - st.write(process_logic) - - with st.expander(std_name, expanded=True): - st.button("Add to Standardization zone", - key=f"gen_transformation_{std_name}", - on_click=streamlit_utils.add_to_std_srv_zone, - args=[f"gen_transformation_{std_name}", std_name, process_requirements, "standard"]) - st.text_input("Invovled tables", value=", ".join(selected_staged_tables), disabled=True) - std_sql = st.text_area(f'Transformation Spark SQL', - key=f'std_transform_sql_{std_name}', - value=current_generated_std_srv_sqls[std_name], - on_change=streamlit_utils.update_sql, - args=[f'std_transform_sql_{std_name}', std_name]) + + for i in 
range(len(current_pipeline_obj["standard"])): + target_name = current_pipeline_obj['standard'][i]['output']['target'] + std_name = st.text_input(f'Transformation Name', key=f'std_{i}_name', value=target_name) + + if std_name: + st.subheader(std_name) + current_generated_std_srv_sqls[std_name] = "" + with st.expander(std_name+" Settings", expanded=True): + + current_pipeline_obj['standard'][i]['output']['target'] = std_name + current_pipeline_obj['standard'][i]['name'] = std_name + + # Get all standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() + + optional_tables = staged_table_names + standardized_table_names + if std_name in optional_tables: + optional_tables.remove(std_name) # Remove itself from the optional tables list + + if len(st.session_state['current_selected_std_tables']) == i: + # Set None to default value of multiselect, otherwise it requires to be values in options + st.session_state['current_selected_std_tables'].append(None) + selected_staged_tables = st.multiselect( + 'Choose datasets to do the data transformation', + options=optional_tables, + default=st.session_state['current_selected_std_tables'][i], + on_change=streamlit_utils.update_selected_tables, + key=f'std_{i}_involved_tables', + args=[i, f'std_{i}_involved_tables', 'current_selected_std_tables']) + + if 'description' not in current_pipeline_obj['standard'][i]: + current_pipeline_obj['standard'][i]['description'] = "" + process_requirements = st.text_area("Transformation requirements", + key=f"std_{i}_transform_requirements", + value=current_pipeline_obj['standard'][i]['description']) + current_pipeline_obj['standard'][i]['description'] = process_requirements + + generate_transform_sql_col1, generate_transform_sql_col2 = st.columns(2) + with generate_transform_sql_col1: + st.button(f'Generate SQL', + key=f'std_{i}_gen_sql', + on_click=streamlit_utils.click_button, + kwargs={"button_name": 
f"std_{i}_gen_transform_sql"}) + + if len(current_pipeline_obj['standard'][i]['code']['sql']) == 0: + current_pipeline_obj['standard'][i]['code']['sql'].append("") + std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0] + + # Get selected staged table details + selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + standardized_table_details, + selected_staged_tables) + if st.session_state[f"std_{i}_gen_sql"]: + with generate_transform_sql_col2: + with st.spinner('Generating...'): + process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, + industry_contexts=industry_contexts, + involved_tables=selected_table_details, + custom_data_processing_logic=process_requirements, + output_table_name=std_name) + try: + process_logic_json = json.loads(process_logic) + std_sql_val = process_logic_json["sql"] + std_table_schema = process_logic_json["schema"] + current_std_srv_tables_schema["standard"][std_name] = std_table_schema + # current_generated_std_srv_sqls[std_name] = std_sql_val + current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql_val + st.write(selected_table_details) + except ValueError as e: + st.write(process_logic) + + std_sql = st.text_area(f'Transformation Spark SQL', + key=f'std_{i}_transform_sql_text_area', + value=std_sql_val) + # on_change=streamlit_utils.update_sql, + # args=[f'std_transform_sql_{std_name}', std_name]) + current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql + + st.button(f'Run SQL', key=f'run_std_{i}_sql', on_click=streamlit_utils.run_task, args = [std_name, "standard"]) + if f"_{std_name}_standard_data" in st.session_state: + st.dataframe(st.session_state[f"_{std_name}_standard_data"]) + + st.button(f"Delete {std_name}", key="delete_std_"+str(i), on_click=streamlit_utils.delete_task, args = ['standard', i]) + + if i != len(current_pipeline_obj["standard"]) - 1: + st.divider() + + if len(current_pipeline_obj["standard"]) == 0: + 
st.write("No transformation in standardization zone") + + st.divider() + st.button('Add Transformation', on_click=streamlit_utils.add_transformation, use_container_width=True) with tab_srv_sql: - # st.divider() - # st.subheader('Generate data aggregation logic') - - if "serving_tables" not in st.session_state: - st.session_state["serving_tables"] = [] - serving_tables = st.session_state["serving_tables"] - - # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() - - srv_name = st.text_input("Aggregation Name") - selected_stg_std_tables = st.multiselect( - 'Choose datasets to do the data aggregation', - staged_table_names + standardized_table_names, - key=f'srv_involved_tables') - process_requirements = st.text_area("Aggregation requirements", key=f"srv_aggregate_requirements") - - generate_aggregate_sql_col1, generate_aggregate_sql_col2 = st.columns(2) - with generate_aggregate_sql_col1: - st.button(f'Generate SQL', - key=f'srv_gen_sql', - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"srv_gen_aggregate_sql"}) - - if "srv_gen_aggregate_sql" not in st.session_state: - st.session_state["srv_gen_aggregate_sql"] = False - if st.session_state["srv_gen_aggregate_sql"]: - st.session_state["srv_gen_aggregate_sql"] = False # Reset clicked status - with generate_aggregate_sql_col2: - with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=staged_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=srv_name) - try: - process_logic_json = json.loads(process_logic) - srv_sql_val = process_logic_json["sql"] - current_generated_std_srv_sqls[srv_name] = srv_sql_val - except ValueError as e: - st.write(process_logic) - - with st.expander(srv_name, expanded=True): - st.button("Add to Serving zone", - 
key=f"gen_aggregation_{srv_name}", - on_click=streamlit_utils.add_to_std_srv_zone, - args=[f"gen_aggregation_{srv_name}", srv_name, process_requirements, "serving"]) - st.text_input("Invovled tables", value=", ".join(selected_stg_std_tables), disabled=True) - std_sql = st.text_area(f'Transformation Spark SQL', - key=f'srv_aggregate_sql_{srv_name}', - value=current_generated_std_srv_sqls[srv_name], - on_change=streamlit_utils.update_sql, - args=[f'std_transform_sql_{srv_name}', srv_name]) \ No newline at end of file + + if "current_selected_srv_tables" not in st.session_state: + st.session_state['current_selected_srv_tables'] = [] + # current_selected_srv_tables = st.session_state['current_selected_srv_tables'] + + if "serving" not in current_std_srv_tables_schema: + current_std_srv_tables_schema["serving"] = {} + + for i in range(len(current_pipeline_obj["serving"])): + target_name = current_pipeline_obj['serving'][i]['output']['target'] + srv_name = st.text_input(f'Aggregation Name', key=f'srv_{i}_name', value=target_name) + + if srv_name: + st.subheader(srv_name) + current_generated_std_srv_sqls[srv_name] = "" + with st.expander(srv_name+" Settings", expanded=True): + + current_pipeline_obj['serving'][i]['output']['target'] = srv_name + current_pipeline_obj['serving'][i]['name'] = srv_name + + optional_tables = staged_table_names + standardized_table_names + # if srv_name in optional_tables: + # optional_tables.remove(srv_name) # Remove itself from the optional tables list + + if len(st.session_state['current_selected_srv_tables']) == i: + # Set None to default value of multiselect, otherwise it requires to be values in options + st.session_state['current_selected_srv_tables'].append(None) + selected_staged_tables = st.multiselect( + 'Choose datasets to do the data aggregation', + options=optional_tables, + default=st.session_state['current_selected_srv_tables'][i], + on_change=streamlit_utils.update_selected_tables, + key=f'srv_{i}_involved_tables', + args=[i, 
f'srv_{i}_involved_tables', 'current_selected_srv_tables'])
+
+                if 'description' not in current_pipeline_obj['serving'][i]:
+                    current_pipeline_obj['serving'][i]['description'] = ""
+                process_requirements = st.text_area("Aggregation requirements",
+                                                    key=f"srv_{i}_aggregate_requirements",
+                                                    value=current_pipeline_obj['serving'][i]['description'])
+                current_pipeline_obj['serving'][i]['description'] = process_requirements
+
+                generate_aggregate_sql_col1, generate_aggregate_sql_col2 = st.columns(2)
+                with generate_aggregate_sql_col1:
+                    st.button(f'Generate SQL',
+                              key=f'srv_{i}_gen_sql',
+                              on_click=streamlit_utils.click_button,
+                              kwargs={"button_name": f"srv_{i}_gen_aggregate_sql"})
+
+                if len(current_pipeline_obj['serving'][i]['code']['sql']) == 0:
+                    current_pipeline_obj['serving'][i]['code']['sql'].append("")
+                srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0]
+
+                # Get selected staged table details
+                selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + standardized_table_details,
+                                                                                     selected_staged_tables)
+                if st.session_state[f"srv_{i}_gen_sql"]:
+                    with generate_aggregate_sql_col2:
+                        with st.spinner('Generating...'):
+                            process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name,
+                                                                                              industry_contexts=industry_contexts,
+                                                                                              involved_tables=selected_table_details,
+                                                                                              custom_data_processing_logic=process_requirements,
+                                                                                              output_table_name=srv_name)
+                            try:
+                                process_logic_json = json.loads(process_logic)
+                                srv_sql_val = process_logic_json["sql"]
+                                srv_table_schema = process_logic_json["schema"]
+                                current_std_srv_tables_schema["serving"][srv_name] = srv_table_schema
+                                # current_generated_srv_srv_sqls[srv_name] = srv_sql_val
+                                current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql_val
+                                st.write(selected_table_details)
+                            except ValueError as e:
+                                st.write(process_logic)
+
+                srv_sql = st.text_area(f'Aggregation Spark SQL',
+                                       key=f'srv_{i}_aggregate_sql_text_area',
value=srv_sql_val) + # on_change=streamlit_utils.update_sql, + # args=[f'srv_aggregate_sql_{srv_name}', srv_name]) + current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql + + st.button(f'Run SQL', key=f'run_srv_{i}_sql', on_click=streamlit_utils.run_task, args = [srv_name, "serving"]) + if f"_{srv_name}_serving_data" in st.session_state: + st.dataframe(st.session_state[f"_{srv_name}_serving_data"]) + + st.button(f"Delete {srv_name}", key="delete_srv_"+str(i), on_click=streamlit_utils.delete_task, args = ['serving', i]) + + if i != len(current_pipeline_obj["serving"]) - 1: + st.divider() + + if len(current_pipeline_obj["serving"]) == 0: + st.write("No aggregation in serving zone") + + st.divider() + st.button('Add Aggregation', on_click=streamlit_utils.add_aggregation, use_container_width=True) diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 322d1fe..1d237ce 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -1,6 +1,7 @@ import cddp import json import streamlit as st +import tempfile import uuid @@ -30,10 +31,11 @@ def get_selected_table_details(tables, table_name): def get_selected_tables_details(tables, table_names): target_tables = [] - for table_name in table_names: - for table_details in tables: - if table_details["table_name"] == table_name: - target_tables.append(table_details) + if table_names: + for table_name in table_names: + for table_details in tables: + if table_details["table_name"] == table_name: + target_tables.append(table_details) return target_tables @@ -53,12 +55,15 @@ def update_sql(key, table_name): current_generated_std_srv_sqls[table_name] = st.session_state[key] -def add_to_staging_zone(stg_name, stg_desc): +def add_to_staging_zone(gen_table_index, stg_name, stg_desc): pipeline_obj = st.session_state['current_pipeline_obj'] current_generated_sample_data = st.session_state['current_generated_sample_data'] sample_data = 
current_generated_sample_data.get(stg_name, None) + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + + if st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: # Add to staging zone if checkbox is checked + generated_tables[gen_table_index]["staged_flag"] = True - if st.session_state[stg_name]: # Add to staging zone if checkbox is checked if sample_data: spark = st.session_state["spark"] json_str, schema = cddp.load_sample_data(spark, json.dumps(sample_data), format="json") @@ -69,23 +74,8 @@ def add_to_staging_zone(stg_name, stg_desc): json_schema = {} add_stg_dataset(pipeline_obj, stg_name, json_schema, json_sample_data) - # pipeline_obj["staging"].append({ - # "name": stg_name, - # "description": stg_desc, - # "input": { - # "type": "filestore", - # "format": "json", - # "path": f"/FileStore/cddp_apps/{pipeline_obj['id']}/landing/{task_name}", - # "read-type": "batch" - # }, - # "output": { - # "target": stg_name, - # "type": ["file", "view"] - # }, - # "schema": json_schema, - # "sampleData": json_sample_data - # }) else: # Remove staging task from staging zone if it's unchecked + generated_tables[gen_table_index]["staged_flag"] = False for index, obj in enumerate(pipeline_obj["staging"]): if obj["name"] == stg_name: del pipeline_obj['staging'][index] @@ -145,14 +135,8 @@ def add_std_srv_schema(zone, output_table_name, schema): def get_standardized_tables(): pipeline_obj = st.session_state['current_pipeline_obj'] standard = pipeline_obj.get("standard", None) - - if "current_std_srv_tables_schema" not in st.session_state: - st.session_state['current_std_srv_tables_schema'] = {} current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] - if "standard" not in current_std_srv_tables_schema: - current_std_srv_tables_schema["standard"] = {} - standardized_table_names = [] standardized_table_details = [] if standard: @@ -164,7 +148,7 @@ def get_standardized_tables(): "schema": 
current_std_srv_tables_schema["standard"].get(std_name, "") }) - return standardized_table_names, standardized_table_names + return standardized_table_names, standardized_table_details def create_pipeline(): @@ -197,4 +181,106 @@ def add_stg_dataset(pipeline_obj, task_name, schema={}, sample_data=[]): }, "schema": schema, "sampleData": sample_data - }) \ No newline at end of file + }) + + +def run_task(task_name, stage="standard"): + dataframe = None + try: + spark = st.session_state["spark"] + config = st.session_state["current_pipeline_obj"] + with tempfile.TemporaryDirectory() as tmpdir: + working_dir = tmpdir+"/"+config['name'] + cddp.init(spark, config, working_dir) + cddp.clean_database(spark, config) + cddp.init_database(spark, config) + try: + cddp.init_staging_sample_dataframe(spark, config) + except Exception as e: + print(e) + if stage in config: + for task in config[stage]: + + if task_name == task['name']: + print(f"start {stage} task: "+task_name) + res_df = None + if stage == "standard": + res_df = cddp.start_standard_job(spark, config, task, False, True) + elif stage == "serving": + res_df = cddp.start_serving_job(spark, config, task, False, True) + dataframe = res_df.toPandas() + print(dataframe) + st.session_state[f'_{task_name}_{stage}_data'] = dataframe + + res_schema = res_df.schema.json() + st.session_state["current_std_srv_tables_schema"][stage][task_name] = json.loads(res_schema) + + except Exception as e: + st.error(f"Cannot run task: {e}") + + return dataframe + + +def add_transformation(): + task_name = "untitled"+str(len(st.session_state["current_pipeline_obj"]["standard"])+1) + st.session_state["current_pipeline_obj"]["standard"].append({ + "name": task_name, + "type": "batch", + "code": { + "lang": "sql", + "sql": [] + }, + "output": { + "target": task_name, + "type": ["file", "view"] + }, + "dependency":[] + }) + + +def delete_task(type, index): + current_pipeline_obj = st.session_state["current_pipeline_obj"] + + if type == 
"staging": + del current_pipeline_obj['staging'][index] + elif type == "standard": + del current_pipeline_obj['standard'][index] + elif type == "serving": + del current_pipeline_obj['serving'][index] + elif type == "visualization": + del current_pipeline_obj['visualization'][index] + + st.session_state['current_pipeline_obj'] = current_pipeline_obj + + +def update_selected_tables(index, multiselect_key, session_state_key): + st.session_state[session_state_key][index] = st.session_state[multiselect_key] + + +def add_aggregation(): + task_name = "untitled"+str(len(st.session_state["current_pipeline_obj"]["serving"])+1) + st.session_state["current_pipeline_obj"]["serving"].append({ + "name": task_name, + "type": "batch", + "code": { + "lang": "sql", + "sql": [] + }, + "output": { + "target": task_name, + "type": ["file", "view"] + }, + "dependency":[] + }) + + +def has_staged_table(): + has_staged_table = False + if "generated_tables" in st.session_state["current_generated_tables"]: + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + for table in generated_tables: + if "staged_flag" in table and table["staged_flag"]: + has_staged_table = True + break + + return has_staged_table From 75416b3870c8964c87d01e7d14e48088b25b976f Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Sun, 24 Sep 2023 17:23:51 +0800 Subject: [PATCH 2/8] fix: add tasks dependency checking logic --- src/streamlit/pages/2_AI Assistant.py | 58 +++++++++++++++++++++------ src/streamlit/streamlit_utils.py | 39 ++++++++++++------ 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 633c5c5..7cf659f 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -19,7 +19,7 @@ st.set_page_config(page_title="AI Assistant") colored_header( - label="AI Assitant", + label="AI Assistant", description=f"Leverage AI to assist you in data pipeline 
development", color_name="violet-70", ) @@ -228,7 +228,20 @@ for i in range(len(current_pipeline_obj["standard"])): target_name = current_pipeline_obj['standard'][i]['output']['target'] - std_name = st.text_input(f'Transformation Name', key=f'std_{i}_name', value=target_name) + disable_std_name_input = False + disable_std_task_deletion = False + std_name_has_dependency = streamlit_utils.check_tables_dependency(target_name) + + if std_name_has_dependency: + disable_std_name_input = True + disable_std_task_deletion = True + st.info("""This standardization task has been referenced in other tasks! + Please remove relevant dependency before trying to rename or delete this task.""") + + std_name = st.text_input('Transformation Name', + key=f'std_{i}_name', + value=target_name, + disabled=disable_std_name_input) if std_name: st.subheader(std_name) @@ -239,7 +252,7 @@ current_pipeline_obj['standard'][i]['name'] = std_name # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() + standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') optional_tables = staged_table_names + standardized_table_names if std_name in optional_tables: @@ -292,7 +305,6 @@ current_std_srv_tables_schema["standard"][std_name] = std_table_schema # current_generated_std_srv_sqls[std_name] = std_sql_val current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql_val - st.write(selected_table_details) except ValueError as e: st.write(process_logic) @@ -307,7 +319,11 @@ if f"_{std_name}_standard_data" in st.session_state: st.dataframe(st.session_state[f"_{std_name}_standard_data"]) - st.button(f"Delete {std_name}", key="delete_std_"+str(i), on_click=streamlit_utils.delete_task, args = ['standard', i]) + st.button(f"Delete {std_name}", + key="delete_std_"+str(i), + on_click=streamlit_utils.delete_task, + args = ['standard', i], + disabled=disable_std_task_deletion) if i != 
len(current_pipeline_obj["standard"]) - 1: st.divider() @@ -330,7 +346,20 @@ for i in range(len(current_pipeline_obj["serving"])): target_name = current_pipeline_obj['serving'][i]['output']['target'] - srv_name = st.text_input(f'Aggregation Name', key=f'srv_{i}_name', value=target_name) + srv_name_has_dependency = streamlit_utils.check_tables_dependency(target_name) + disable_srv_name_input = False + disable_srv_task_deletion = False + + if srv_name_has_dependency: + disable_srv_name_input = True + disable_srv_task_deletion = True + st.info("""This serving task has been referenced in other tasks! + Please remove relevant dependency before trying to rename the task.""") + + srv_name = st.text_input('Aggregation Name', + key=f'srv_{i}_name', + value=target_name, + disabled=srv_name_has_dependency) if srv_name: st.subheader(srv_name) @@ -340,9 +369,11 @@ current_pipeline_obj['serving'][i]['output']['target'] = srv_name current_pipeline_obj['serving'][i]['name'] = srv_name - optional_tables = staged_table_names + standardized_table_names - # if srv_name in optional_tables: - # optional_tables.remove(srv_name) # Remove itself from the optional tables list + # Get all standardized table details + serving_table_names, serving_table_details = streamlit_utils.get_std_srv_tables('serving') + optional_tables = staged_table_names + standardized_table_names + serving_table_names + if srv_name in optional_tables: + optional_tables.remove(srv_name) # Remove itself from the optional tables list if len(st.session_state['current_selected_srv_tables']) == i: # Set None to default value of multiselect, otherwise it requires to be values in options @@ -374,7 +405,7 @@ srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] # Get selected staged table details - selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + standardized_table_details, + selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + 
standardized_table_details + serving_table_details, selected_staged_tables) if st.session_state[f"srv_{i}_gen_sql"]: with generate_aggregate_sql_col2: @@ -391,7 +422,6 @@ current_std_srv_tables_schema["serving"][std_name] = srv_table_schema # current_generated_srv_srv_sqls[srv_name] = srv_sql_val current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql_val - st.write(selected_table_details) except ValueError as e: st.write(process_logic) @@ -406,7 +436,11 @@ if f"_{srv_name}_serving_data" in st.session_state: st.dataframe(st.session_state[f"_{srv_name}_serving_data"]) - st.button(f"Delete {srv_name}", key="delete_srv_"+str(i), on_click=streamlit_utils.delete_task, args = ['serving', i]) + st.button(f"Delete {srv_name}", + key="delete_srv_"+str(i), + on_click=streamlit_utils.delete_task, + args = ['serving', i], + disabled=disable_srv_task_deletion) if i != len(current_pipeline_obj["serving"]) - 1: st.divider() diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 1d237ce..48561b5 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -132,23 +132,23 @@ def add_std_srv_schema(zone, output_table_name, schema): current_std_srv_tables_schema[zone][output_table_name] = schema -def get_standardized_tables(): +def get_std_srv_tables(task_type): pipeline_obj = st.session_state['current_pipeline_obj'] - standard = pipeline_obj.get("standard", None) + task = pipeline_obj.get(task_type, None) current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] - standardized_table_names = [] - standardized_table_details = [] - if standard: - for standardized_table in standard: - std_name = standardized_table["output"]["target"] - standardized_table_names.append(std_name) - standardized_table_details.append({ - "table_name": std_name, - "schema": current_std_srv_tables_schema["standard"].get(std_name, "") + std_srv_table_names = [] + std_srv_table_details = [] + if task: + for std_srv_table in 
task: + task_name = std_srv_table["output"]["target"] + std_srv_table_names.append(task_name) + std_srv_table_details.append({ + "table_name": task_name, + "schema": current_std_srv_tables_schema[task_type].get(task_name, "") }) - return standardized_table_names, standardized_table_details + return std_srv_table_names, std_srv_table_details def create_pipeline(): @@ -240,13 +240,17 @@ def add_transformation(): def delete_task(type, index): current_pipeline_obj = st.session_state["current_pipeline_obj"] + current_selected_std_tables = st.session_state.get('current_selected_std_tables', []) + current_selected_srv_tables = st.session_state.get('current_selected_srv_tables', []) if type == "staging": del current_pipeline_obj['staging'][index] elif type == "standard": del current_pipeline_obj['standard'][index] + del current_selected_std_tables[index] elif type == "serving": del current_pipeline_obj['serving'][index] + del current_selected_srv_tables[index] elif type == "visualization": del current_pipeline_obj['visualization'][index] @@ -284,3 +288,14 @@ def has_staged_table(): break return has_staged_table + + +def check_tables_dependency(target_name): + has_dependency = False + current_std_srv_tasks = st.session_state.get('current_selected_std_tables', []) + st.session_state.get('current_selected_srv_tables', []) + + for task in current_std_srv_tasks: + if task and target_name in task: + has_dependency = True + + return has_dependency From cc58b118cddcabe2a30a0ff61d58d21b4dabb49f Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Mon, 25 Sep 2023 18:50:47 +0800 Subject: [PATCH 3/8] fix: unified session state object for std/srv editing tasks --- src/streamlit/pages/2_AI Assistant.py | 83 ++++++++++----------------- src/streamlit/streamlit_utils.py | 50 +++++++++------- 2 files changed, 60 insertions(+), 73 deletions(-) diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 7cf659f..d6edd15 100644 --- a/src/streamlit/pages/2_AI 
Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -207,25 +207,18 @@ st.write(tables) with tab_std_sql: - - if "current_std_srv_tables_schema" not in st.session_state: - st.session_state['current_std_srv_tables_schema'] = {} - current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] - - if "standard" not in current_std_srv_tables_schema: - current_std_srv_tables_schema["standard"] = {} - - if "current_generated_std_srv_sqls" not in st.session_state: - st.session_state['current_generated_std_srv_sqls'] = {} - current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] - - if "current_selected_std_tables" not in st.session_state: - st.session_state['current_selected_std_tables'] = [] - # current_selected_std_tables = st.session_state['current_selected_std_tables'] + if "current_editing_pipeline_tasks" not in st.session_state: + st.session_state['current_editing_pipeline_tasks'] = {} + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] + if "standard" not in current_editing_pipeline_tasks: + current_editing_pipeline_tasks['standard'] = [] # Get all staged table details staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() + # Get all standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') + for i in range(len(current_pipeline_obj["standard"])): target_name = current_pipeline_obj['standard'][i]['output']['target'] disable_std_name_input = False @@ -245,29 +238,24 @@ if std_name: st.subheader(std_name) - current_generated_std_srv_sqls[std_name] = "" with st.expander(std_name+" Settings", expanded=True): - current_pipeline_obj['standard'][i]['output']['target'] = std_name current_pipeline_obj['standard'][i]['name'] = std_name - # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') - 
optional_tables = staged_table_names + standardized_table_names if std_name in optional_tables: optional_tables.remove(std_name) # Remove itself from the optional tables list - if len(st.session_state['current_selected_std_tables']) == i: - # Set None to default value of multiselect, otherwise it requires to be values in options - st.session_state['current_selected_std_tables'].append(None) + if len(current_editing_pipeline_tasks['standard']) == i: + current_editing_pipeline_tasks['standard'].append({}) + current_editing_pipeline_tasks['standard'][i]['target'] = std_name selected_staged_tables = st.multiselect( 'Choose datasets to do the data transformation', options=optional_tables, - default=st.session_state['current_selected_std_tables'][i], + default=current_editing_pipeline_tasks['standard'][i].get('involved_tables', None), on_change=streamlit_utils.update_selected_tables, key=f'std_{i}_involved_tables', - args=[i, f'std_{i}_involved_tables', 'current_selected_std_tables']) + args=['standard', i, f'std_{i}_involved_tables']) if 'description' not in current_pipeline_obj['standard'][i]: current_pipeline_obj['standard'][i]['description'] = "" @@ -302,8 +290,7 @@ process_logic_json = json.loads(process_logic) std_sql_val = process_logic_json["sql"] std_table_schema = process_logic_json["schema"] - current_std_srv_tables_schema["standard"][std_name] = std_table_schema - # current_generated_std_srv_sqls[std_name] = std_sql_val + current_editing_pipeline_tasks['standard'][i]['query_results_schema'] = std_table_schema current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql_val except ValueError as e: st.write(process_logic) @@ -311,13 +298,12 @@ std_sql = st.text_area(f'Transformation Spark SQL', key=f'std_{i}_transform_sql_text_area', value=std_sql_val) - # on_change=streamlit_utils.update_sql, - # args=[f'std_transform_sql_{std_name}', std_name]) + current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql - st.button(f'Run SQL', key=f'run_std_{i}_sql', 
on_click=streamlit_utils.run_task, args = [std_name, "standard"]) - if f"_{std_name}_standard_data" in st.session_state: - st.dataframe(st.session_state[f"_{std_name}_standard_data"]) + st.button(f'Run SQL', key=f'run_std_{i}_sql', on_click=streamlit_utils.run_task, args = [std_name, "standard", i]) + if 'sql_query_results' in current_editing_pipeline_tasks['standard'][i]: + st.dataframe(current_editing_pipeline_tasks['standard'][i]['sql_query_results']) st.button(f"Delete {std_name}", key="delete_std_"+str(i), @@ -337,12 +323,8 @@ with tab_srv_sql: - if "current_selected_srv_tables" not in st.session_state: - st.session_state['current_selected_srv_tables'] = [] - # current_selected_srv_tables = st.session_state['current_selected_srv_tables'] - - if "serving" not in current_std_srv_tables_schema: - current_std_srv_tables_schema["serving"] = {} + if "serving" not in current_editing_pipeline_tasks: + current_editing_pipeline_tasks['serving'] = [] for i in range(len(current_pipeline_obj["serving"])): target_name = current_pipeline_obj['serving'][i]['output']['target'] @@ -363,9 +345,7 @@ if srv_name: st.subheader(srv_name) - current_generated_std_srv_sqls[srv_name] = "" with st.expander(srv_name+" Settings", expanded=True): - current_pipeline_obj['serving'][i]['output']['target'] = srv_name current_pipeline_obj['serving'][i]['name'] = srv_name @@ -375,16 +355,17 @@ if srv_name in optional_tables: optional_tables.remove(srv_name) # Remove itself from the optional tables list - if len(st.session_state['current_selected_srv_tables']) == i: - # Set None to default value of multiselect, otherwise it requires to be values in options - st.session_state['current_selected_srv_tables'].append(None) + if len(current_editing_pipeline_tasks['serving']) == i: + current_editing_pipeline_tasks['serving'].append({}) + current_editing_pipeline_tasks['serving'][i]['target'] = srv_name + selected_staged_tables = st.multiselect( 'Choose datasets to do the data aggregation', 
options=optional_tables, - default=st.session_state['current_selected_srv_tables'][i], + default=current_editing_pipeline_tasks['serving'][i].get('involved_tables', None), on_change=streamlit_utils.update_selected_tables, key=f'srv_{i}_involved_tables', - args=[i, f'srv_{i}_involved_tables', 'current_selected_srv_tables']) + args=['serving', i, f'srv_{i}_involved_tables']) if 'description' not in current_pipeline_obj['serving'][i]: current_pipeline_obj['serving'][i]['description'] = "" @@ -419,8 +400,7 @@ process_logic_json = json.loads(process_logic) srv_sql_val = process_logic_json["sql"] srv_table_schema = process_logic_json["schema"] - current_std_srv_tables_schema["serving"][std_name] = srv_table_schema - # current_generated_srv_srv_sqls[srv_name] = srv_sql_val + current_editing_pipeline_tasks['serving'][i]['query_results_schema'] = srv_table_schema current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql_val except ValueError as e: st.write(process_logic) @@ -428,13 +408,12 @@ srv_sql = st.text_area(f'Aggregation Spark SQL', key=f'srv_{i}_aggregate_sql_text_area', value=srv_sql_val) - # on_change=streamlit_utils.update_sql, - # args=[f'srv_aggregate_sql_{srv_name}', srv_name]) + current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql - st.button(f'Run SQL', key=f'run_srv_{i}_sql', on_click=streamlit_utils.run_task, args = [srv_name, "serving"]) - if f"_{srv_name}_serving_data" in st.session_state: - st.dataframe(st.session_state[f"_{srv_name}_serving_data"]) + st.button(f'Run SQL', key=f'run_srv_{i}_sql', on_click=streamlit_utils.run_task, args = [srv_name, "serving", i]) + if 'sql_query_results' in current_editing_pipeline_tasks['serving'][i]: + st.dataframe(current_editing_pipeline_tasks['serving'][i]['sql_query_results']) st.button(f"Delete {srv_name}", key="delete_srv_"+str(i), diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 48561b5..8819cf2 100644 --- a/src/streamlit/streamlit_utils.py +++ 
b/src/streamlit/streamlit_utils.py @@ -134,19 +134,21 @@ def add_std_srv_schema(zone, output_table_name, schema): def get_std_srv_tables(task_type): pipeline_obj = st.session_state['current_pipeline_obj'] - task = pipeline_obj.get(task_type, None) - current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] + tasks = pipeline_obj.get(task_type, None) + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] std_srv_table_names = [] std_srv_table_details = [] - if task: - for std_srv_table in task: + if tasks: + for std_srv_table in tasks: task_name = std_srv_table["output"]["target"] std_srv_table_names.append(task_name) - std_srv_table_details.append({ - "table_name": task_name, - "schema": current_std_srv_tables_schema[task_type].get(task_name, "") - }) + for index, task in enumerate(current_editing_pipeline_tasks[task_type]): + if task['target'] == task_name: + std_srv_table_details.append({ + "table_name": task_name, + "schema": current_editing_pipeline_tasks[task_type][index].get('query_results_schema', '') + }) return std_srv_table_names, std_srv_table_details @@ -184,8 +186,9 @@ def add_stg_dataset(pipeline_obj, task_name, schema={}, sample_data=[]): }) -def run_task(task_name, stage="standard"): +def run_task(task_name, stage="standard", index=None): dataframe = None + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] try: spark = st.session_state["spark"] config = st.session_state["current_pipeline_obj"] @@ -209,11 +212,11 @@ def run_task(task_name, stage="standard"): elif stage == "serving": res_df = cddp.start_serving_job(spark, config, task, False, True) dataframe = res_df.toPandas() - print(dataframe) - st.session_state[f'_{task_name}_{stage}_data'] = dataframe + # print(dataframe) + current_editing_pipeline_tasks[stage][index]['sql_query_results'] = dataframe res_schema = res_df.schema.json() - st.session_state["current_std_srv_tables_schema"][stage][task_name] = 
json.loads(res_schema) + current_editing_pipeline_tasks[stage][index]['query_results_schema'] = json.loads(res_schema) except Exception as e: st.error(f"Cannot run task: {e}") @@ -247,18 +250,18 @@ def delete_task(type, index): del current_pipeline_obj['staging'][index] elif type == "standard": del current_pipeline_obj['standard'][index] - del current_selected_std_tables[index] + del st.session_state['current_editing_pipeline_tasks']['standard'][index] elif type == "serving": del current_pipeline_obj['serving'][index] - del current_selected_srv_tables[index] + del st.session_state['current_editing_pipeline_tasks']['serving'][index] elif type == "visualization": del current_pipeline_obj['visualization'][index] st.session_state['current_pipeline_obj'] = current_pipeline_obj -def update_selected_tables(index, multiselect_key, session_state_key): - st.session_state[session_state_key][index] = st.session_state[multiselect_key] +def update_selected_tables(task_type, index, multiselect_key): + st.session_state['current_editing_pipeline_tasks'][task_type][index]['involved_tables'] = st.session_state[multiselect_key] def add_aggregation(): @@ -292,10 +295,15 @@ def has_staged_table(): def check_tables_dependency(target_name): has_dependency = False - current_std_srv_tasks = st.session_state.get('current_selected_std_tables', []) + st.session_state.get('current_selected_srv_tables', []) - - for task in current_std_srv_tasks: - if task and target_name in task: - has_dependency = True + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] + current_used_std_tables = [] + current_used_srv_tables = [] + for task in current_editing_pipeline_tasks['standard']: + current_used_std_tables += task.get('involved_tables', []) + for task in current_editing_pipeline_tasks['serving']: + current_used_srv_tables += task.get('involved_tables', []) + + if target_name in current_used_std_tables + current_used_srv_tables: + has_dependency = True return 
has_dependency From e14cfcab396598282ccc7534824152919ea63dda Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Tue, 26 Sep 2023 08:49:48 +0800 Subject: [PATCH 4/8] fix: add dependency check for staging tables and fix some UI issues --- src/streamlit/pages/2_AI Assistant.py | 50 +++++++++++++++++++++------ src/streamlit/streamlit_utils.py | 8 +++-- 2 files changed, 46 insertions(+), 12 deletions(-) diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index d6edd15..19f77de 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -93,7 +93,8 @@ if "current_generated_sample_data" not in st.session_state: st.session_state['current_generated_sample_data'] = {} current_generated_sample_data = st.session_state['current_generated_sample_data'] -# AI Assistnt for tables generation + + # AI Assistnt for tables generation tables = [] generate_tables_col1, generate_tables_col2 = st.columns(2) with generate_tables_col1: @@ -115,7 +116,8 @@ current_generated_tables["generated_tables"] = json.loads(tables) # Clean current_generated_sample_data key in session state once Generate button is clicked again - st.session_state['current_generated_sample_data'] = {} + del st.session_state['current_generated_sample_data'] + current_generated_sample_data = {} try: if "generated_tables" in current_generated_tables: @@ -124,6 +126,7 @@ for gen_table_index, table in enumerate(tables): columns = table["columns"] columns_df = pd.DataFrame.from_dict(columns, orient='columns') + stg_name_has_dependency = streamlit_utils.check_tables_dependency(table["table_name"]) sample_data = None added_to_stage = current_generated_tables["generated_tables"][gen_table_index].get("staged_flag", False) @@ -134,18 +137,41 @@ with st.expander(expander_label, expanded=added_to_stage): gen_table_name = table["table_name"] gen_table_desc = table["table_description"] + gen_table_sample_data_count = 
current_generated_tables["generated_tables"][gen_table_index].get("sample_data_count", 5) + sample_data_requirements_flag = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_requirements_flag", False) + data_requirements = current_generated_tables["generated_tables"][gen_table_index].get("data_requirements", "") st.write(gen_table_desc) st.write(columns_df) st.write(f"Generate sample data") - rows_count = st.slider("Number of rows", min_value=5, max_value=50, key=f'gen_rows_count_slider_{gen_table_name}') - enable_data_requirements = st.toggle("With extra sample data requirements", key=f'data_requirements_toggle_{gen_table_name}') - data_requirements = "" + rows_count = st.slider("Number of rows", + min_value=5, + max_value=50, + value=gen_table_sample_data_count, + key=f'gen_rows_count_slider_{gen_table_name}', + on_change=streamlit_utils.widget_on_change, + args=[f'gen_rows_count_slider_{gen_table_name}', + gen_table_index, + 'sample_data_count']) + + enable_data_requirements = st.toggle("With extra sample data requirements", + value=sample_data_requirements_flag, + key=f'data_requirements_toggle_{gen_table_name}', + on_change=streamlit_utils.widget_on_change, + args=[f'data_requirements_toggle_{gen_table_name}', + gen_table_index, + 'sample_data_requirements_flag']) + if enable_data_requirements: data_requirements = st.text_area("Extra requirements for sample data", - key=f'data_requirements_text_area_{gen_table_name}', - placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9") + value=data_requirements, + key=f'data_requirements_text_area_{gen_table_name}', + placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9", + on_change=streamlit_utils.widget_on_change, + args=[f'data_requirements_text_area_{gen_table_name}', + gen_table_index, + 'data_requirements']) generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) with 
generate_sample_data_col1: @@ -189,18 +215,22 @@ st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag json_sample_data = json.loads(sample_data) current_generated_sample_data[gen_table_name] = json_sample_data # Save generated data to session_state - # st.session_state[f'stg_{i}_data'] = sample_data if gen_table_name in current_generated_sample_data: sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') st.write(sample_data_df) - # Show checkbox only after sample data has been generated + if stg_name_has_dependency: + st.info("""This table has been referenced in other tasks! + Please remove relevant dependency before trying to remove it from staging zone.""") + + # Show checkbox only after sample data has been generated st.checkbox("Add to staging zone", key=f"add_to_staging_{gen_table_index}_checkbox", value=added_to_stage, on_change=streamlit_utils.add_to_staging_zone, - args=[gen_table_index, gen_table_name, gen_table_desc]) + args=[gen_table_index, gen_table_name, gen_table_desc], + disabled=stg_name_has_dependency) except ValueError as e: # TODO: Add error/exception to standard error-showing widget diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 8819cf2..95e0b36 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -167,6 +167,7 @@ def create_pipeline(): } return st.session_state['current_pipeline_obj'] + def add_stg_dataset(pipeline_obj, task_name, schema={}, sample_data=[]): pipeline_obj["staging"].append({ "name": task_name, @@ -243,8 +244,6 @@ def add_transformation(): def delete_task(type, index): current_pipeline_obj = st.session_state["current_pipeline_obj"] - current_selected_std_tables = st.session_state.get('current_selected_std_tables', []) - current_selected_srv_tables = st.session_state.get('current_selected_srv_tables', []) if type == "staging": del 
current_pipeline_obj['staging'][index] @@ -307,3 +306,8 @@ def check_tables_dependency(target_name): has_dependency = True return has_dependency + + +def widget_on_change(widget_key, index, session_state_key): + current_generated_tables = st.session_state['current_generated_tables']['generated_tables'] + current_generated_tables[index][session_state_key] = st.session_state[widget_key] From b137bdb1db8a53feb94bf5edaaa067839d2bc885 Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Sat, 7 Oct 2023 17:57:56 +0800 Subject: [PATCH 5/8] refactor: generate one table at a time/request --- src/cddp/openai_api.py | 755 ++++++++++++++------------ src/streamlit/pages/2_AI Assistant.py | 204 +++---- src/streamlit/streamlit_utils.py | 179 ++++++ 3 files changed, 652 insertions(+), 486 deletions(-) diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py index aaabe58..459801d 100644 --- a/src/cddp/openai_api.py +++ b/src/cddp/openai_api.py @@ -3,368 +3,437 @@ from langchain.chains import LLMChain import json import os +import time from dotenv import load_dotenv load_dotenv() -OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") -DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") -MODEL = os.getenv("OPENAI_MODEL") -API_VERSION = os.getenv("OPENAI_API_VERSION") +class OpenaiApi: + OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") + DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") + MODEL = os.getenv("OPENAI_MODEL") + API_VERSION = os.getenv("OPENAI_API_VERSION") -def _prepare_openapi_llm(): - llm = AzureChatOpenAI(deployment_name=DEPLOYMENT, - model=MODEL, - openai_api_version=API_VERSION, - openai_api_base=OPENAI_API_BASE) - return llm + def __init__(self): + self.llm = AzureChatOpenAI(deployment_name=self.DEPLOYMENT, + model=self.MODEL, + openai_api_version=self.API_VERSION, + openai_api_base=self.OPENAI_API_BASE) -def recommend_data_processing_scenario(industry_name: str): - recommend_data_processing_scenario_template = """ - You're a data engineer and familiar with the {industry_name} industry 
IT systems. - Please recommend 7 to 10 different data processing pipelines scenarios, including required steps like collecting data sources, transforming the data and generating aggregated metrics, etc. - Your answer should be in an array of JSON format like below. - [ + def recommend_data_processing_scenario(self, industry_name: str): + recommend_data_processing_scenario_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please recommend 7 to 10 different data processing pipelines scenarios, including required steps like collecting data sources, transforming the data and generating aggregated metrics, etc. + Your answer should be in an array of JSON format like below. + [ + {{ + "pipeline_name": "{{name of the data processing pipeline}}", + "description": "{{short description on the data pipeline}}", + "stages": [ + {{ + "stage": "staging", + "description": "{{description on collected data sources for the rest of the data processing pipeline}}" + }}, + {{ + "stage": "standard", + "description": "{{description on data transformation logics}}" + }}, + {{ + "stage": "serving", + "description": "{{description on data aggregation logics}}" + }} + ] + }} + ] + Therefore your answers are: + """ + + prompt = PromptTemplate( + input_variables=["industry_name"], + template=recommend_data_processing_scenario_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name}) + results = response["text"] + + return results + + def recommend_tables_for_industry(self, industry_name: str, industry_contexts: str): + """ Recommend database tables for a given industry and relevant contexts. 
+ + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + + :returns: recommened tables with schema in array of json format + """ + + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should be an array of JSON format objects like below. {{ - "pipeline_name": "{{name of the data processing pipeline}}", - "description": "{{short description on the data pipeline}}", - [ - {{ - "stage": "staging", - "description": "{{description on collected data sources for the rest of the data processing pipeline}}" - }}, - {{ - "stage": "standard", - "description": "{{description on data transformation logics}}" - }}, + "table_name": "{{table name}}", + "table_description": "{{table description}}", + "columns": [ {{ - "stage": "serving", - "description": "{{description on data aggregation logics}}" + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} }} ] }} - ] - Therefore your answers are: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name"], - template=recommend_data_processing_scenario_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name}) - results = response["text"] - - return results - -def recommend_tables_for_industry(industry_name: str, industry_contexts: str): - """ Recommend database tables for a given industry and relevant contexts. 
- - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - - :returns: recommened tables with schema in array of json format - """ - - recommaned_tables_for_industry_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - Please recommend some database tables with potential data schema for the above contexts. - Your response should be an array of JSON format objects like below. - {{ - "table_name": "{{table name}}", - "table_description": "{{table description}}", - [ - {{ - "column_name": "{{column name}}", - "data_type": "{{data type}}", - "is_null": {{true or false}}, - "is_primary_key": {{true or false}}, - "is_foreign_key": {{true or false}} - }} - ] - }} - - Please recommend 7 to 10 database tables: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts"], - template=recommaned_tables_for_industry_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts}) - results = response["text"] + + Please recommend 7 to 10 database tables: + """ - return results - - -def recommend_data_processing_scenario_mock(industry_name: str): - results = """ - [ { "pipeline_name": "Flight Delay Pipeline", "description": "Collects flight data from various sources and generates metrics on delays", "stages": [ { "stage": "staging", "description": "Collects flight data from airline APIs and airport databases" }, { "stage": "standard", "description": "Transforms data to calculate delay metrics based on departure and arrival times" }, { "stage": "serving", "description": "Aggregates delay metrics by airport, airline, and route" } ] }, { "pipeline_name": "Baggage Handling Pipeline", "description": "Tracks baggage movement and generates metrics on handling efficiency", 
"stages": [ { "stage": "staging", "description": "Collects baggage movement data from RFID scanners and baggage handling systems" }, { "stage": "standard", "description": "Transforms data to calculate metrics on baggage handling efficiency, such as time to load and unload baggage" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and baggage handling company" } ] }, { "pipeline_name": "Revenue Management Pipeline", "description": "Analyzes sales data to optimize pricing and revenue", "stages": [ { "stage": "staging", "description": "Collects sales data from ticketing systems and travel booking websites" }, { "stage": "standard", "description": "Transforms data to calculate revenue metrics, such as average ticket price and revenue per seat" }, { "stage": "serving", "description": "Aggregates metrics by route, fare class, and seasonality" } ] }, { "pipeline_name": "Maintenance Pipeline", "description": "Monitors aircraft health and schedules maintenance", "stages": [ { "stage": "staging", "description": "Collects aircraft sensor data and maintenance records" }, { "stage": "standard", "description": "Transforms data to identify potential maintenance issues and schedule preventative maintenance" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, age, and maintenance history" } ] }, { "pipeline_name": "Customer Service Pipeline", "description": "Analyzes customer feedback to improve service", "stages": [ { "stage": "staging", "description": "Collects customer feedback from surveys, social media, and customer service interactions" }, { "stage": "standard", "description": "Transforms data to identify common issues and sentiment analysis of customer feedback" }, { "stage": "serving", "description": "Aggregates metrics by route, airline, and customer feedback channel" } ] }, { "pipeline_name": "Fuel Efficiency Pipeline", "description": "Monitors fuel usage to optimize efficiency and reduce costs", "stages": [ { 
"stage": "staging", "description": "Collects fuel usage data from aircraft sensors and fueling systems" }, { "stage": "standard", "description": "Transforms data to calculate fuel efficiency metrics, such as fuel burn per passenger mile" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, route, and seasonality" } ] }, { "pipeline_name": "Security Pipeline", "description": "Monitors security incidents to improve safety and compliance", "stages": [ { "stage": "staging", "description": "Collects security incident data from airport security systems and passenger screening" }, { "stage": "standard", "description": "Transforms data to identify common security incidents and compliance issues" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and security incident type" } ] } ] - """ - - return results - - -def recommend_tables_for_industry_mock(industry_name: str, industry_contexts: str): - results = """ - [{"table_name":"airlines","table_description":"Information about the airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"flights","table_description":"Information about flights operated by 
airlines","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"passengers","table_description":"Information about passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"bookings","table_description":"Information about flight bookings made by passengers","columns":[{"column_name":"booking_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"seats","table_description":"Information about seats available in 
flights","columns":[{"column_name":"seat_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"passenger_id","data_type":"integer","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"seat_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"airports","table_description":"Information about airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"location","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]}] - """ - - return results - - -def generate_custom_data_processing_logics_mock(industry_name: str, - industry_contexts: str, - involved_tables: str, - custom_data_processing_logic: str, - output_table_name: str): - results = """ - {"sql":"SELECT flights.airline_id, flights.arrival_time, flights.departure_time, flights.destination, flights.flight_id, flights.origin, passengers.age, passengers.gender, passengers.name, passengers.passenger_id FROM flights JOIN bookings ON flights.flight_id = bookings.flight_id JOIN passengers ON bookings.passenger_id = 
passengers.passenger_id","schema":{"table_name":"std_flights_p","columns":[{"column_name":"airline_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"arrival_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"name","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true}]}} - """ - - return results - - -def recommend_custom_table(industry_name: str, - industry_contexts: str, - recommened_tables: str, - custom_table_name: str, - custom_table_description: str): - """ Recommend custom/user-defined table for a input custom table name and description of a given industry. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param recommened_tables: previously recommened tables in json string format - :param custom_table_name: custom table name - :param custom_table_description: custom table description - - :returns: recommened custom tables with schema in json format - """ - - recommend_custom_tables_template=""" - You're a data engineer and familiar with the {industry_name} industry IT systems. 
- And below is relevant contexts of the industry: - {industry_contexts} - - You've recommended below potential database tables and schema previously in json format. - {recommened_tables} - - Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. - {custom_table_description} - - Please only output the new added table without previously recommended tables, therefore the outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], - template=recommend_custom_tables_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "recommened_tables": recommened_tables, - "custom_table_name": custom_table_name, - "custom_table_description": custom_table_description}) - results = response["text"] - - return results - - -def recommend_data_processing_logics(industry_name: str, - industry_contexts: str, - recommened_tables: str, - processing_logic: str): - """ Recommend data processing logics for a given industry and sample tables. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param recommened_tables: previously recommened tables in json string format - :param processing_logic: either data cleaning, data transformation or data aggregation - - :returns: recommened data processing logics in array of json format - """ - - recommend_data_cleaning_logics_template=""" - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - You've recommended below potential database tables and schema previously in json format. 
- {recommened_tables} - - Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. - You response should be in an array of JSON format like below. - {{ - "description": "{{descriptions on the data cleaning logic}}", - "involved_tables": [ - "{{involved table X}}", - "{{involved table Y}}", - "{{involved table Z}}" - ], - "sql": "{{Spark SQL statement to do the data cleaning}}", - "schema": "{{cleaned table schema in json string format}}" - }} - - Therefore outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], - template=recommend_data_cleaning_logics_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "processing_logic": processing_logic, - "recommened_tables": recommened_tables}) - results = json.loads(response["text"]) - - return results - - -def generate_custom_data_processing_logics(industry_name: str, - industry_contexts: str, - involved_tables: str, - custom_data_processing_logic: str, - output_table_name: str): - """ Generate custom data processing logics for input data processing requirements. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param involved_tables: tables required by the custom data processing logic, in json string format - :param custom_data_processing_logic: custom data processing requirements - :param output_table_name: output/sink table name for processed data - - :returns: custom data processing logic json format - """ - - generate_custom_data_processing_logics_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - We have database tables listed below with table_name and schema in JSON format. 
- {involved_tables} - - And below is the data processing requirement. - {custom_data_processing_logic} - - Please help to generate Spark SQL statement with output data schema. - And your response should be in JSON format like below. - {{ - "sql": "{{Spark SQL statement to do the data cleaning}}", - "schema": "{{output data schema in JSON string format}}" - }} - - And the above data schema string should follows below JSON format, while value for the "table_name" key should strictly be "{output_table_name}". - {{ - "table_name": "{output_table_name}", - "coloumns": [ - {{ - "column_name": "{{column name}}", - "data_type": "{{data type}}", - "is_null": {{true or false}}, - "is_primary_key": {{true or false}}, - "is_foreign_key": {{true or false}} - }} - ] - }} - - Therefore the outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], - template=generate_custom_data_processing_logics_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - - # Run the chain only specifying the input variable. - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "custom_data_processing_logic": custom_data_processing_logic, - "involved_tables": involved_tables, - "output_table_name": output_table_name}) - results = response["text"] - - return results - - -def generate_sample_data(industry_name: str, - number_of_lines: int, - target_table: str, - column_values_patterns: str): - """ Generate custom data processing logics for input data processing requirements. 
- - :param industry_name: industry name - :param number_of_lines: number of lines sample data required - :param target_table: target table name and its schema in json format - - :returns: generated sample data in array of json format - """ - - generate_sample_data_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. - {target_table} - - And below are patterns of column values in json format, if it's not provided please ignore this requirement. - {column_values_patterns} - - And the sample data should be an array of json object like below. - {{ - "{{column X}}": "{{column value}}", - "{{column Y}}": "{{column value}}", - "{{column Z}}": "{{column value}}" - }} - - The sample data would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], - template=generate_sample_data_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "number_of_lines": number_of_lines, - "target_table": target_table, - "column_values_patterns": column_values_patterns}) - results = response["text"] - - return results - - -def generate_sample_data_mock(industry_name: str, - number_of_lines: int, - target_table: str, - column_values_patterns: str): - if target_table["table_name"] == "flights": + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts}) + results = response["text"] + + return results + + + def recommend_data_processing_scenario_mock(self, industry_name: str): results = """ - [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", 
"destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { "flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": 
"Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" }, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] + [ { "pipeline_name": "Flight Delay Pipeline", "description": "Collects flight data from various sources and generates metrics on delays", "stages": [ { "stage": "staging", "description": "Collects flight data from airline APIs and airport databases" }, { "stage": "standard", "description": "Transforms data to calculate delay metrics based on departure and arrival times" }, { "stage": "serving", "description": "Aggregates delay metrics by airport, airline, and route" } ] }, { "pipeline_name": "Baggage Handling Pipeline", "description": "Tracks baggage movement and generates metrics on handling efficiency", "stages": [ { "stage": "staging", "description": "Collects baggage movement data from RFID 
scanners and baggage handling systems" }, { "stage": "standard", "description": "Transforms data to calculate metrics on baggage handling efficiency, such as time to load and unload baggage" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and baggage handling company" } ] }, { "pipeline_name": "Revenue Management Pipeline", "description": "Analyzes sales data to optimize pricing and revenue", "stages": [ { "stage": "staging", "description": "Collects sales data from ticketing systems and travel booking websites" }, { "stage": "standard", "description": "Transforms data to calculate revenue metrics, such as average ticket price and revenue per seat" }, { "stage": "serving", "description": "Aggregates metrics by route, fare class, and seasonality" } ] }, { "pipeline_name": "Maintenance Pipeline", "description": "Monitors aircraft health and schedules maintenance", "stages": [ { "stage": "staging", "description": "Collects aircraft sensor data and maintenance records" }, { "stage": "standard", "description": "Transforms data to identify potential maintenance issues and schedule preventative maintenance" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, age, and maintenance history" } ] }, { "pipeline_name": "Customer Service Pipeline", "description": "Analyzes customer feedback to improve service", "stages": [ { "stage": "staging", "description": "Collects customer feedback from surveys, social media, and customer service interactions" }, { "stage": "standard", "description": "Transforms data to identify common issues and sentiment analysis of customer feedback" }, { "stage": "serving", "description": "Aggregates metrics by route, airline, and customer feedback channel" } ] }, { "pipeline_name": "Fuel Efficiency Pipeline", "description": "Monitors fuel usage to optimize efficiency and reduce costs", "stages": [ { "stage": "staging", "description": "Collects fuel usage data from aircraft sensors and fueling 
systems" }, { "stage": "standard", "description": "Transforms data to calculate fuel efficiency metrics, such as fuel burn per passenger mile" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, route, and seasonality" } ] }, { "pipeline_name": "Security Pipeline", "description": "Monitors security incidents to improve safety and compliance", "stages": [ { "stage": "staging", "description": "Collects security incident data from airport security systems and passenger screening" }, { "stage": "standard", "description": "Transforms data to identify common security incidents and compliance issues" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and security incident type" } ] } ] """ - if target_table["table_name"] == "passengers": + return results + + + def recommend_tables_for_industry_mock(self, industry_name: str, industry_contexts: str): results = """ - [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": 
"Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] + [{"table_name":"airlines","table_description":"Information about the airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"flights","table_description":"Information about flights operated by 
airlines","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"passengers","table_description":"Information about passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"bookings","table_description":"Information about flight bookings made by passengers","columns":[{"column_name":"booking_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"seats","table_description":"Information about seats available in 
flights","columns":[{"column_name":"seat_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"passenger_id","data_type":"integer","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"seat_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"airports","table_description":"Information about airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"location","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]}] """ - - if target_table["table_name"] == "bookings": + + return results + + + def recommend_tables_for_industry_one_at_a_time(self, industry_name: str, industry_contexts: str, recommended_tables: str = None, index: int = None): + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should strictly be in JSON format like below. + {{ + "table_name": "{{table name}}", + "table_description": "{{table description}}", + "columns": [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false (strictly in lower case)}}, + "is_primary_key": {{true or false (strictly in lower case)}}, + "is_foreign_key": {{true or false (strictly in lower case)}} + }} + ] + }} + + You've recommanded below tables in previous conversations. 
+ {recommended_tables} + + Please recommend another one without duplicated table (if there's no table listed above, please go ahead to create a new one): + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommended_tables"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommended_tables": recommended_tables}) + results = response["text"] + + return results + + + def recommend_tables_for_industry_one_at_a_time_mock(self, industry_name: str, industry_contexts: str, index: int): + if index == 0: + results = """ + {"table_name":"flights","table_description":"Table to store information about flights","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"departure_city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_date","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"time","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_date","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"time","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 1: + results = """ + {"table_name":"passengers","table_description":"Table to store information about 
passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"date_of_birth","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"email","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"phone_number","data_type":"varchar(20)","is_null":true,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 2: + results = """ + {"table_name":"airlines","table_description":"Table to store information about airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 3: + results = """ + {"table_name":"tickets","table_description":"Table to store information about airline 
tickets","columns":[{"column_name":"ticket_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"ticket_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_class","data_type":"varchar(50)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_price","data_type":"decimal(10,2)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_status","data_type":"varchar(50)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 4: + results = """ + {"table_name":"airports","table_description":"Table to store information about airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airport_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + # Simulate request consuming response time + time.sleep(1) + + return results + + + def generate_custom_data_processing_logics_mock(self, + industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): results = """ - [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, 
"passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { "booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + {"sql":"SELECT flights.airline_id, flights.arrival_time, flights.departure_time, flights.destination, flights.flight_id, flights.origin, passengers.age, passengers.gender, passengers.name, passengers.passenger_id FROM flights JOIN bookings ON flights.flight_id = bookings.flight_id JOIN passengers ON bookings.passenger_id = 
passengers.passenger_id","schema":{"table_name":"std_flights_p","columns":[{"column_name":"airline_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"arrival_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"name","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true}]}} + """ + + return results + + + def recommend_custom_table(self, + industry_name: str, + industry_contexts: str, + recommened_tables: str, + custom_table_name: str, + custom_table_description: str): + """ Recommend custom/user-defined table for a input custom table name and description of a given industry. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param custom_table_name: custom table name + :param custom_table_description: custom table description + + :returns: recommened custom tables with schema in json format + """ + + recommend_custom_tables_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. 
+ And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. + {custom_table_description} + + Please only output the new added table without previously recommended tables, therefore the outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], + template=recommend_custom_tables_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommened_tables": recommened_tables, + "custom_table_name": custom_table_name, + "custom_table_description": custom_table_description}) + results = response["text"] + + return results + + + def recommend_data_processing_logics(self, + industry_name: str, + industry_contexts: str, + recommened_tables: str, + processing_logic: str): + """ Recommend data processing logics for a given industry and sample tables. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param processing_logic: either data cleaning, data transformation or data aggregation + + :returns: recommened data processing logics in array of json format + """ + + recommend_data_cleaning_logics_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. 
+ {recommened_tables} + + Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. + You response should be in an array of JSON format like below. + {{ + "description": "{{descriptions on the data cleaning logic}}", + "involved_tables": [ + "{{involved table X}}", + "{{involved table Y}}", + "{{involved table Z}}" + ], + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{cleaned table schema in json string format}}" + }} + + Therefore outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], + template=recommend_data_cleaning_logics_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "processing_logic": processing_logic, + "recommened_tables": recommened_tables}) + results = json.loads(response["text"]) + + return results + + + def generate_custom_data_processing_logics(self, + industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + """ Generate custom data processing logics for input data processing requirements. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param involved_tables: tables required by the custom data processing logic, in json string format + :param custom_data_processing_logic: custom data processing requirements + :param output_table_name: output/sink table name for processed data + + :returns: custom data processing logic json format + """ + + generate_custom_data_processing_logics_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + We have database tables listed below with table_name and schema in JSON format. 
+ {involved_tables} + + And below is the data processing requirement. + {custom_data_processing_logic} + + Please help to generate Spark SQL statement with output data schema. + And your response should be in JSON format like below. + {{ + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{output data schema in JSON string format}}" + }} + + And the above data schema string should follows below JSON format, while value for the "table_name" key should strictly be "{output_table_name}". + {{ + "table_name": "{output_table_name}", + "coloumns": [ + {{ + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Therefore the outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], + template=generate_custom_data_processing_logics_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + + # Run the chain only specifying the input variable. + response = chain({"industry_name": industry_name, + "industry_contexts": industry_contexts, + "custom_data_processing_logic": custom_data_processing_logic, + "involved_tables": involved_tables, + "output_table_name": output_table_name}) + results = response["text"] + + return results + + + def generate_sample_data(self, + industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + """ Generate custom data processing logics for input data processing requirements. 
+ + :param industry_name: industry name + :param number_of_lines: number of lines sample data required + :param target_table: target table name and its schema in json format + + :returns: generated sample data in array of json format + """ + + generate_sample_data_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. + {target_table} + + And below are patterns of column values in json format, if it's not provided please ignore this requirement. + {column_values_patterns} + + And the sample data should be an array of json object like below. + {{ + "{{column X}}": "{{column value}}", + "{{column Y}}": "{{column value}}", + "{{column Z}}": "{{column value}}" + }} + + The sample data would be: """ - return results \ No newline at end of file + prompt = PromptTemplate( + input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], + template=generate_sample_data_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + response = chain({"industry_name": industry_name, + "number_of_lines": number_of_lines, + "target_table": target_table, + "column_values_patterns": column_values_patterns}) + results = response["text"] + + return results + + + def generate_sample_data_mock(self, + industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + if target_table["table_name"] == "flights": + results = """ + [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": 
"2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { "flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", 
"departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" }, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] + """ + + if target_table["table_name"] == "passengers": + results = """ + [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 
65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": "Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] + """ + + if target_table["table_name"] == "bookings": + results = """ + [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, "passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { 
"booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + """ + + return results diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 19f77de..ddbceef 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -6,9 +6,10 @@ import sys import streamlit_utils sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) -import cddp.openai_api as openai_api +from cddp.openai_api import OpenaiApi from streamlit_extras.switch_page_button import switch_page from streamlit_extras.colored_header import colored_header +from time import sleep if "current_pipeline_obj" not in st.session_state: switch_page("Home") @@ -24,6 +25,7 @@ color_name="violet-70", ) +openai_api = OpenaiApi() @@ -76,9 +78,6 @@ st.divider() - - - with tab_data: # Initialize current_generated_tables key in session state if "current_generated_tables" not in st.session_state: @@ -94,147 +93,66 @@ st.session_state['current_generated_sample_data'] = {} current_generated_sample_data = st.session_state['current_generated_sample_data'] + if "disable_generate_table_button" not in st.session_state: + st.session_state["disable_generate_table_button"] = False + # AI Assistnt for tables generation tables = [] - generate_tables_col1, generate_tables_col2 = st.columns(2) - with generate_tables_col1: - st.button("Generate", on_click=streamlit_utils.click_button, kwargs={"button_name": "generate_tables"}, use_container_width=True) - - if "generate_tables" not in st.session_state: + placeholder = st.empty() + container = placeholder.container() + with container: + st.button("Generate", + on_click=streamlit_utils.generate_tables, + args=[placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api], + disabled=st.session_state["disable_generate_table_button"], + use_container_width=True) + + if 
streamlit_utils.has_staged_table() and st.session_state["generate_tables"]: # Show warning message if some of generated tables have been referenced by other std/srv tasks + with container: + st.warning("Some of current generated tables have been added to Staging zone, please remove them before generating tables again!") st.session_state["generate_tables"] = False - elif st.session_state["generate_tables"]: - st.session_state["generate_tables"] = False # Reset button clicked status - - has_staged_table = streamlit_utils.has_staged_table() + st.session_state["disable_generate_data_widget"] = False + + # Render generated tables once widgets inside table expander has been clicked/updated + if "generated_tables" in current_generated_tables and not st.session_state["generate_tables"]: + tables = current_generated_tables["generated_tables"] + with container: + for gen_table_index, table in enumerate(tables): + streamlit_utils.render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api) + elif "generate_tables_count" in st.session_state and len(current_generated_tables.get("generated_tables", [])) == st.session_state["generate_tables_count"]: + st.session_state["disable_generate_data_widget"] = False + tables = current_generated_tables["generated_tables"] + + with placeholder.container(): + st.button("Generate", + key="generate_table_redraw", + on_click=streamlit_utils.generate_tables, + args=[placeholder, current_generated_tables, current_pipeline_obj, current_generated_sample_data, industry_name, industry_contexts], + disabled=st.session_state["disable_generate_table_button"], + use_container_width=True) + + for gen_table_index, table in enumerate(tables): + streamlit_utils.render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api, + "redraw") - if has_staged_table: 
- st.warning("Some of current generated tables have been added to Staging zone, please remove them before generating tables again!") - else: - with generate_tables_col2: - with st.spinner('Generating...'): - tables = openai_api.recommend_tables_for_industry(industry_name, industry_contexts) - current_generated_tables["generated_tables"] = json.loads(tables) - - # Clean current_generated_sample_data key in session state once Generate button is clicked again - del st.session_state['current_generated_sample_data'] - current_generated_sample_data = {} - - try: - if "generated_tables" in current_generated_tables: - tables = current_generated_tables["generated_tables"] - - for gen_table_index, table in enumerate(tables): - columns = table["columns"] - columns_df = pd.DataFrame.from_dict(columns, orient='columns') - stg_name_has_dependency = streamlit_utils.check_tables_dependency(table["table_name"]) - - sample_data = None - added_to_stage = current_generated_tables["generated_tables"][gen_table_index].get("staged_flag", False) - expander_label = table["table_name"] - if added_to_stage: - expander_label = table["table_name"] + " :heavy_check_mark:" - - with st.expander(expander_label, expanded=added_to_stage): - gen_table_name = table["table_name"] - gen_table_desc = table["table_description"] - gen_table_sample_data_count = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_count", 5) - sample_data_requirements_flag = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_requirements_flag", False) - data_requirements = current_generated_tables["generated_tables"][gen_table_index].get("data_requirements", "") - - st.write(gen_table_desc) - st.write(columns_df) - - st.write(f"Generate sample data") - rows_count = st.slider("Number of rows", - min_value=5, - max_value=50, - value=gen_table_sample_data_count, - key=f'gen_rows_count_slider_{gen_table_name}', - on_change=streamlit_utils.widget_on_change, - 
args=[f'gen_rows_count_slider_{gen_table_name}', - gen_table_index, - 'sample_data_count']) - - enable_data_requirements = st.toggle("With extra sample data requirements", - value=sample_data_requirements_flag, - key=f'data_requirements_toggle_{gen_table_name}', - on_change=streamlit_utils.widget_on_change, - args=[f'data_requirements_toggle_{gen_table_name}', - gen_table_index, - 'sample_data_requirements_flag']) - - if enable_data_requirements: - data_requirements = st.text_area("Extra requirements for sample data", - value=data_requirements, - key=f'data_requirements_text_area_{gen_table_name}', - placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9", - on_change=streamlit_utils.widget_on_change, - args=[f'data_requirements_text_area_{gen_table_name}', - gen_table_index, - 'data_requirements']) - - generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) - with generate_sample_data_col1: - st.button("Generate Sample Data", - key=f"generate_data_button_{gen_table_name}", - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"generate_sample_data_{gen_table_name}"}) - - if f"generate_sample_data_{gen_table_name}" not in st.session_state: - st.session_state[f"generate_sample_data_{gen_table_name}"] = False - - if f"{gen_table_name}_smaple_data_generated" not in st.session_state: - st.session_state[f"{gen_table_name}_smaple_data_generated"] = False - elif st.session_state[f"generate_sample_data_{gen_table_name}"]: - st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status - if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: - with generate_sample_data_col2: - with st.spinner('Generating...'): - sample_data = openai_api.generate_sample_data(industry_name, - rows_count, - table, - data_requirements) - st.session_state[f"{gen_table_name}_smaple_data_generated"] = True - # Store generated data to session_state - 
current_generated_sample_data[gen_table_name] = sample_data - - # Also update current_pipeline_obj if checked check-box before generating sample data - # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: - if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): - spark = st.session_state["spark"] - json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") - - for index, dataset in enumerate(current_pipeline_obj['staging']): - if dataset["name"] == gen_table_name: - i = index - - current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) - current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) - - if st.session_state[f"{gen_table_name}_smaple_data_generated"]: - st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag - json_sample_data = json.loads(sample_data) - current_generated_sample_data[gen_table_name] = json_sample_data # Save generated data to session_state - - if gen_table_name in current_generated_sample_data: - sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') - st.write(sample_data_df) - - if stg_name_has_dependency: - st.info("""This table has been referenced in other tasks! 
- Please remove relevant dependency before trying to remove it from staging zone.""") - - # Show checkbox only after sample data has been generated - st.checkbox("Add to staging zone", - key=f"add_to_staging_{gen_table_index}_checkbox", - value=added_to_stage, - on_change=streamlit_utils.add_to_staging_zone, - args=[gen_table_index, gen_table_name, gen_table_desc], - disabled=stg_name_has_dependency) - - except ValueError as e: - # TODO: Add error/exception to standard error-showing widget - st.write(tables) with tab_std_sql: if "current_editing_pipeline_tasks" not in st.session_state: @@ -258,7 +176,7 @@ if std_name_has_dependency: disable_std_name_input = True disable_std_task_deletion = True - st.info("""This standardization task has been referenced in other tasks! + st.info("""This standardization task has been referenced by other tasks! Please remove relevant dependency before trying to rename or delete this task.""") std_name = st.text_input('Transformation Name', @@ -365,7 +283,7 @@ if srv_name_has_dependency: disable_srv_name_input = True disable_srv_task_deletion = True - st.info("""This serving task has been referenced in other tasks! + st.info("""This serving task has been referenced by other tasks! 
Please remove relevant dependency before trying to rename the task.""") srv_name = st.text_input('Aggregation Name', diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 95e0b36..61ed731 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -1,8 +1,11 @@ import cddp import json +import random import streamlit as st import tempfile import uuid +import pandas as pd +from time import sleep def get_selected_tables(tables): @@ -311,3 +314,179 @@ def check_tables_dependency(target_name): def widget_on_change(widget_key, index, session_state_key): current_generated_tables = st.session_state['current_generated_tables']['generated_tables'] current_generated_tables[index][session_state_key] = st.session_state[widget_key] + st.session_state["generate_tables"] = False + + +def render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api, + key_suffix="init"): + sample_data = None + added_to_stage = current_generated_tables["generated_tables"][gen_table_index].get("staged_flag", False) + expander_label = table["table_name"] + + if added_to_stage: + expander_label += " :heavy_check_mark:" + + with st.expander(expander_label, expanded=added_to_stage): + gen_table_name = table["table_name"] + gen_table_desc = table["table_description"] + stg_name_has_dependency = check_tables_dependency(gen_table_name) + + gen_table_sample_data_count = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_count", 5) + sample_data_requirements_flag = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_requirements_flag", False) + data_requirements = current_generated_tables["generated_tables"][gen_table_index].get("data_requirements", "") + + + columns = table["columns"] + columns_df = pd.DataFrame.from_dict(columns, orient='columns') + + st.write(gen_table_desc) + 
st.write(columns_df) + + st.write(f"Generate sample data") + rows_count = st.slider("Number of rows", + min_value=5, + max_value=50, + value=gen_table_sample_data_count, + key=f'gen_rows_count_slider_{gen_table_name}_{key_suffix}', + on_change=widget_on_change, + args=[f'gen_rows_count_slider_{gen_table_name}_{key_suffix}', + gen_table_index, + 'sample_data_count'], + disabled=st.session_state["disable_generate_data_widget"]) + + enable_data_requirements = st.toggle("With extra sample data requirements", + value=sample_data_requirements_flag, + key=f'data_requirements_toggle_{gen_table_name}_{key_suffix}', + on_change=widget_on_change, + args=[f'data_requirements_toggle_{gen_table_name}_{key_suffix}', + gen_table_index, + 'sample_data_requirements_flag'], + disabled=st.session_state["disable_generate_data_widget"]) + + if enable_data_requirements: + data_requirements = st.text_area("Extra requirements for sample data", + value=data_requirements, + key=f'data_requirements_text_area_{gen_table_name}_{key_suffix}', + placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9", + on_change=widget_on_change, + args=[f'data_requirements_text_area_{gen_table_name}_{key_suffix}', + gen_table_index, + 'data_requirements'], + disabled=st.session_state["disable_generate_data_widget"]) + + generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) + with generate_sample_data_col1: + st.button("Generate Sample Data", + key=f"generate_data_button_{gen_table_name}_{key_suffix}", + on_click=click_button, + kwargs={"button_name": f"generate_sample_data_{gen_table_name}"}, + disabled=st.session_state["disable_generate_data_widget"], + use_container_width=True) + + if f"generate_sample_data_{gen_table_name}" not in st.session_state: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False + + if f"{gen_table_name}_smaple_data_generated" not in st.session_state: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = 
False + elif st.session_state[f"generate_sample_data_{gen_table_name}"]: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status + if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: + with generate_sample_data_col2: + with st.spinner('Generating...'): + sample_data = openai_api.generate_sample_data(industry_name, + rows_count, + table, + data_requirements) + st.session_state[f"{gen_table_name}_smaple_data_generated"] = True + # Store generated data to session_state + current_generated_sample_data[gen_table_name] = sample_data + + # Also update current_pipeline_obj if checked check-box before generating sample data + # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: + if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): + spark = st.session_state["spark"] + json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") + + for index, dataset in enumerate(current_pipeline_obj['staging']): + if dataset["name"] == gen_table_name: + i = index + + current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) + current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) + + if st.session_state[f"{gen_table_name}_smaple_data_generated"]: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag + json_sample_data = json.loads(sample_data) + current_generated_sample_data[gen_table_name] = json_sample_data # Save generated data to session_state + + if gen_table_name in current_generated_sample_data: + sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') + st.write(sample_data_df) + + if stg_name_has_dependency: + st.info("""This table has been referenced by other tasks! 
+ Please remove relevant dependency before trying to remove it from staging zone.""") + + # Show checkbox only after sample data has been generated + st.checkbox("Add to staging zone", + key=f"add_to_staging_{gen_table_index}_checkbox", + value=added_to_stage, + on_change=add_to_staging_zone, + args=[gen_table_index, gen_table_name, gen_table_desc], + disabled=stg_name_has_dependency) + + +def generate_tables(placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api): + st.session_state["generate_tables"] = True + gen_tables_count = st.session_state["generate_tables_count"] = random.randint(5, 8) + if "disable_generate_data_widget" not in st.session_state or not st.session_state["disable_generate_data_widget"]: + st.session_state["disable_generate_data_widget"] = True + + # Clean existing table expanders before render new generate ones + placeholder.empty() + sleep(0.01) # Workaround for elements empty/cleaning issue, https://github.com/streamlit/streamlit/issues/5044 + + with placeholder.container(): + spinner_container = st.empty().container() + if not has_staged_table(): # Generate new tables if there's no existing dependency/references found in std/srv tasks + # Clean current_generated_sample_data key in session state + del st.session_state['current_generated_sample_data'] + current_generated_sample_data = {} + + gen_table_index = 0 + generated_tables = "" + current_generated_tables["generated_tables"] = [] + + while gen_table_index < gen_tables_count: + with spinner_container: + with st.spinner(f'Generating {gen_table_index + 1} of {gen_tables_count} tables...'): + table = openai_api.recommend_tables_for_industry_one_at_a_time(industry_name, industry_contexts, generated_tables) + + current_generated_tables["generated_tables"].append(json.loads(table)) + generated_tables = json.dumps(current_generated_tables["generated_tables"], indent=2) + + 
render_table_expander(json.loads(table), + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api) + + gen_table_index += 1 + + # Clean the temporary generated tables and redraw them again in 2_AI_Assistant.py to enable all widgets inside table expanders. + placeholder.empty() + sleep(0.01) From 0f30156c6807c6c6b5f1dbcb28c13039b944871e Mon Sep 17 00:00:00 2001 From: thurston-chen Date: Sun, 8 Oct 2023 17:45:06 +0800 Subject: [PATCH 6/8] refactor: add exception handling and retry for OpenAI requests --- requirements.txt | 1 + src/cddp/openai_api.py | 102 ++++++++++++++++---------- src/streamlit/pages/2_AI Assistant.py | 61 +++++++++------ src/streamlit/streamlit_utils.py | 90 +++++++++++++---------- 4 files changed, 156 insertions(+), 98 deletions(-) diff --git a/requirements.txt b/requirements.txt index bf519df..f612654 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ setuptools wheel pandas==1.5.3 +numpy==1.23.5 pyarrow fastparquet pyspark==3.3.0 diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py index 459801d..0cafd5c 100644 --- a/src/cddp/openai_api.py +++ b/src/cddp/openai_api.py @@ -13,6 +13,7 @@ class OpenaiApi: DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") MODEL = os.getenv("OPENAI_MODEL") API_VERSION = os.getenv("OPENAI_API_VERSION") + MAX_RETRY = 3 def __init__(self): @@ -22,6 +23,35 @@ def __init__(self): openai_api_base=self.OPENAI_API_BASE) + def _is_valid_json(self, input: str): + """Check whether input string is valid JSON or not + + :param input: input string + + :returns: boolean value on validation check + """ + try: + json.loads(input) + valid_flag = True + except ValueError as e: + valid_flag = False + + return valid_flag + + + def _run_chain(self, chain: LLMChain, params: dict): + retry_count = 0 + + while retry_count < self.MAX_RETRY: + response = chain(params) + if self._is_valid_json(response["text"]): + return response["text"] + 
print(f"[ERROR] Got bad format response from LLM: {response['text']}") + retry_count += 1 + + raise ValueError("Got bad format response from LLM") + + def recommend_data_processing_scenario(self, industry_name: str): recommend_data_processing_scenario_template = """ You're a data engineer and familiar with the {industry_name} industry IT systems. @@ -55,11 +85,11 @@ def recommend_data_processing_scenario(self, industry_name: str): template=recommend_data_processing_scenario_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - response = chain({"industry_name": industry_name}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name}) return results + def recommend_tables_for_industry(self, industry_name: str, industry_contexts: str): """ Recommend database tables for a given industry and relevant contexts. @@ -98,9 +128,8 @@ def recommend_tables_for_industry(self, industry_name: str, industry_contexts: s template=recommaned_tables_for_industry_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts}) return results @@ -154,10 +183,9 @@ def recommend_tables_for_industry_one_at_a_time(self, industry_name: str, indust template=recommaned_tables_for_industry_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "recommended_tables": recommended_tables}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommended_tables": recommended_tables}) return results @@ -238,12 +266,11 @@ def recommend_custom_table(self, template=recommend_custom_tables_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - 
response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "recommened_tables": recommened_tables, - "custom_table_name": custom_table_name, - "custom_table_description": custom_table_description}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommened_tables": recommened_tables, + "custom_table_name": custom_table_name, + "custom_table_description": custom_table_description}) return results @@ -292,11 +319,10 @@ def recommend_data_processing_logics(self, template=recommend_data_cleaning_logics_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "processing_logic": processing_logic, - "recommened_tables": recommened_tables}) - results = json.loads(response["text"]) + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "processing_logic": processing_logic, + "recommened_tables": recommened_tables}) return results @@ -360,12 +386,11 @@ def generate_custom_data_processing_logics(self, chain = LLMChain(llm=self.llm, prompt=prompt) # Run the chain only specifying the input variable. - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "custom_data_processing_logic": custom_data_processing_logic, - "involved_tables": involved_tables, - "output_table_name": output_table_name}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "custom_data_processing_logic": custom_data_processing_logic, + "involved_tables": involved_tables, + "output_table_name": output_table_name}) return results @@ -392,14 +417,16 @@ def generate_sample_data(self, And below are patterns of column values in json format, if it's not provided please ignore this requirement. 
{column_values_patterns} - And the sample data should be an array of json object like below. - {{ - "{{column X}}": "{{column value}}", - "{{column Y}}": "{{column value}}", - "{{column Z}}": "{{column value}}" - }} + And the sample data should strictly be in JSON format like below. + [ + {{ + "{{column X}}": "{{column value}}", + "{{column Y}}": "{{column value}}", + "{{column Z}}": "{{column value}}" + }} + ] - The sample data would be: + Therefore the sample data would be: """ prompt = PromptTemplate( @@ -407,11 +434,10 @@ def generate_sample_data(self, template=generate_sample_data_template, ) chain = LLMChain(llm=self.llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "number_of_lines": number_of_lines, - "target_table": target_table, - "column_values_patterns": column_values_patterns}) - results = response["text"] + results = self._run_chain(chain, {"industry_name": industry_name, + "number_of_lines": number_of_lines, + "target_table": target_table, + "column_values_patterns": column_values_patterns}) return results diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index ddbceef..8757705 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -58,9 +58,13 @@ st.session_state["generated_usecases"] = False # Reset button clicked status with generate_use_cases_col2: with st.spinner('Generating...'): - usecases = openai_api.recommend_data_processing_scenario(current_pipeline_obj.get("industry", "")) - st.session_state['current_generated_usecases'] = json.loads(usecases) - + try: + usecases = openai_api.recommend_data_processing_scenario(current_pipeline_obj.get("industry", "")) + st.session_state['current_generated_usecases'] = json.loads(usecases) + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") if 
"current_generated_usecases" in st.session_state: usecases = st.session_state['current_generated_usecases'] @@ -113,14 +117,15 @@ disabled=st.session_state["disable_generate_table_button"], use_container_width=True) - if streamlit_utils.has_staged_table() and st.session_state["generate_tables"]: # Show warning message if some of generated tables have been referenced by other std/srv tasks + # Show warning message if some of generated tables have been referenced by other std/srv tasks + if streamlit_utils.has_staged_table() and st.session_state["has_clicked_generate_tables_btn"]: with container: st.warning("Some of current generated tables have been added to Staging zone, please remove them before generating tables again!") - st.session_state["generate_tables"] = False + st.session_state["has_clicked_generate_tables_btn"] = False st.session_state["disable_generate_data_widget"] = False # Render generated tables once widgets inside table expander has been clicked/updated - if "generated_tables" in current_generated_tables and not st.session_state["generate_tables"]: + if "generated_tables" in current_generated_tables and not st.session_state["has_clicked_generate_tables_btn"]: tables = current_generated_tables["generated_tables"] with container: for gen_table_index, table in enumerate(tables): @@ -139,7 +144,13 @@ st.button("Generate", key="generate_table_redraw", on_click=streamlit_utils.generate_tables, - args=[placeholder, current_generated_tables, current_pipeline_obj, current_generated_sample_data, industry_name, industry_contexts], + args=[placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api], disabled=st.session_state["disable_generate_table_button"], use_container_width=True) @@ -164,9 +175,6 @@ # Get all staged table details staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() - # Get all standardized table details - 
standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') - for i in range(len(current_pipeline_obj["standard"])): target_name = current_pipeline_obj['standard'][i]['output']['target'] disable_std_name_input = False @@ -190,6 +198,9 @@ current_pipeline_obj['standard'][i]['output']['target'] = std_name current_pipeline_obj['standard'][i]['name'] = std_name + # Get latest standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') + optional_tables = staged_table_names + standardized_table_names if std_name in optional_tables: optional_tables.remove(std_name) # Remove itself from the optional tables list @@ -229,19 +240,22 @@ if st.session_state[f"std_{i}_gen_sql"]: with generate_transform_sql_col2: with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=selected_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=std_name) try: + process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, + industry_contexts=industry_contexts, + involved_tables=selected_table_details, + custom_data_processing_logic=process_requirements, + output_table_name=std_name) + process_logic_json = json.loads(process_logic) std_sql_val = process_logic_json["sql"] std_table_schema = process_logic_json["schema"] current_editing_pipeline_tasks['standard'][i]['query_results_schema'] = std_table_schema current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql_val except ValueError as e: - st.write(process_logic) + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") std_sql = st.text_area(f'Transformation Spark SQL', key=f'std_{i}_transform_sql_text_area', @@ 
-339,19 +353,22 @@ if st.session_state[f"srv_{i}_gen_sql"]: with generate_aggregate_sql_col2: with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=selected_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=srv_name) try: + process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, + industry_contexts=industry_contexts, + involved_tables=selected_table_details, + custom_data_processing_logic=process_requirements, + output_table_name=srv_name) + process_logic_json = json.loads(process_logic) srv_sql_val = process_logic_json["sql"] srv_table_schema = process_logic_json["schema"] current_editing_pipeline_tasks['serving'][i]['query_results_schema'] = srv_table_schema current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql_val except ValueError as e: - st.write(process_logic) + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") srv_sql = st.text_area(f'Aggregation Spark SQL', key=f'srv_{i}_aggregate_sql_text_area', diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 61ed731..f411606 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -314,7 +314,7 @@ def check_tables_dependency(target_name): def widget_on_change(widget_key, index, session_state_key): current_generated_tables = st.session_state['current_generated_tables']['generated_tables'] current_generated_tables[index][session_state_key] = st.session_state[widget_key] - st.session_state["generate_tables"] = False + st.session_state["has_clicked_generate_tables_btn"] = False def render_table_expander(table, @@ -396,29 +396,36 @@ def render_table_expander(table, 
st.session_state[f"{gen_table_name}_smaple_data_generated"] = False elif st.session_state[f"generate_sample_data_{gen_table_name}"]: st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status + st.session_state["has_clicked_generate_tables_btn"] = False + if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: with generate_sample_data_col2: with st.spinner('Generating...'): - sample_data = openai_api.generate_sample_data(industry_name, - rows_count, - table, - data_requirements) - st.session_state[f"{gen_table_name}_smaple_data_generated"] = True - # Store generated data to session_state - current_generated_sample_data[gen_table_name] = sample_data - - # Also update current_pipeline_obj if checked check-box before generating sample data - # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: - if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): - spark = st.session_state["spark"] - json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") - - for index, dataset in enumerate(current_pipeline_obj['staging']): - if dataset["name"] == gen_table_name: - i = index - - current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) - current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) + try: + sample_data = openai_api.generate_sample_data(industry_name, + rows_count, + table, + data_requirements) + st.session_state[f"{gen_table_name}_smaple_data_generated"] = True + # Store generated data to session_state + current_generated_sample_data[gen_table_name] = sample_data + + # Also update current_pipeline_obj if checked check-box before generating sample data + # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: + if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): + spark = st.session_state["spark"] + json_str, schema = cddp.load_sample_data(spark, 
sample_data, format="json") + + for index, dataset in enumerate(current_pipeline_obj['staging']): + if dataset["name"] == gen_table_name: + i = index + + current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) + current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") if st.session_state[f"{gen_table_name}_smaple_data_generated"]: st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag @@ -449,7 +456,7 @@ def generate_tables(placeholder, industry_name, industry_contexts, openai_api): - st.session_state["generate_tables"] = True + st.session_state["has_clicked_generate_tables_btn"] = True gen_tables_count = st.session_state["generate_tables_count"] = random.randint(5, 8) if "disable_generate_data_widget" not in st.session_state or not st.session_state["disable_generate_data_widget"]: st.session_state["disable_generate_data_widget"] = True @@ -470,22 +477,29 @@ def generate_tables(placeholder, current_generated_tables["generated_tables"] = [] while gen_table_index < gen_tables_count: - with spinner_container: - with st.spinner(f'Generating {gen_table_index + 1} of {gen_tables_count} tables...'): - table = openai_api.recommend_tables_for_industry_one_at_a_time(industry_name, industry_contexts, generated_tables) - - current_generated_tables["generated_tables"].append(json.loads(table)) - generated_tables = json.dumps(current_generated_tables["generated_tables"], indent=2) - - render_table_expander(json.loads(table), - current_generated_tables, - current_generated_sample_data, - current_pipeline_obj, - gen_table_index, - industry_name, - openai_api) - - gen_table_index += 1 + try: + with spinner_container: + with st.spinner(f'Generating {gen_table_index + 1} of {gen_tables_count} tables...'): + table = 
openai_api.recommend_tables_for_industry_one_at_a_time(industry_name, industry_contexts, generated_tables) + + current_generated_tables["generated_tables"].append(json.loads(table)) + generated_tables = json.dumps(current_generated_tables["generated_tables"], indent=2) + + render_table_expander(json.loads(table), + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api) + + gen_table_index += 1 + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + break + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") + break # Clean the temporary generated tables and redraw them again in 2_AI_Assistant.py to enable all widgets inside table expanders. placeholder.empty() From 68cd1119b725aa84da72d71fae9bdd2ccd8cfc90 Mon Sep 17 00:00:00 2001 From: Sean Ma Date: Thu, 18 Apr 2024 08:50:26 +0000 Subject: [PATCH 7/8] fix diff JAVA_HOME path in macos and win --- Dockerfile | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 8248f0f..3a8c9a5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,8 +14,10 @@ RUN apt-get update && \ apt-get clean; -ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64/ -RUN export JAVA_HOME +# ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-arm64/ +# ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64/ +RUN export JAVA_HOME="$(dirname $(dirname $(readlink -f $(which java))))" +RUN echo $JAVA_HOME WORKDIR /usr/src/app ENV FLASK_APP=./src/app.py From 772889d96e5a1e331d3c286a37773a7e20b562ce Mon Sep 17 00:00:00 2001 From: Sean Ma Date: Thu, 18 Apr 2024 08:50:57 +0000 Subject: [PATCH 8/8] fix fruit app visualization issues --- example/pipeline_fruit_batch_2.json | 76 ++--------------------------- src/streamlit/pages/1_Editor.py | 4 +- 2 files changed, 6 insertions(+), 74 deletions(-) diff --git a/example/pipeline_fruit_batch_2.json 
b/example/pipeline_fruit_batch_2.json index 57ee935..a12fa90 100644 --- a/example/pipeline_fruit_batch_2.json +++ b/example/pipeline_fruit_batch_2.json @@ -258,16 +258,12 @@ ], "visualization": [ { - "name": "untitled1", + "name": "TotalSales", "type": "Bar Chart", "input": "srv_fruit_sales_total", - "description": "" - }, - { - "name": "untitled2", - "type": "Line Chart", - "input": "srv_fruit_sales_total", - "description": "" + "description": "", + "x_axis": "fruit", + "y_axis": "total" } ], "industry": "Other", @@ -416,70 +412,6 @@ "total": 17.0 } ] - }, - "visualization": { - "untitled1": { - "xAxis": { - "type": "category", - "data": [ - "Fiji Apple", - "Green Apple", - "Peach", - "Green Grape", - "Orange", - "Red Grape", - "Banana" - ] - }, - "yAxis": { - "type": "value" - }, - "series": [ - { - "data": [ - 56.0, - 45.0, - 39.0, - 36.0, - 28.0, - 24.0, - 17.0 - ], - "type": "bar" - } - ] - }, - "untitled2": { - "xAxis": { - "type": "category", - "data": [ - "Fiji Apple", - "Green Apple", - "Peach", - "Green Grape", - "Orange", - "Red Grape", - "Banana" - ] - }, - "yAxis": { - "type": "value" - }, - "series": [ - { - "data": [ - 56.0, - 45.0, - 39.0, - 36.0, - 28.0, - 24.0, - 17.0 - ], - "type": "line" - } - ] - } } } } \ No newline at end of file diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py index e2f17be..5fbee7e 100644 --- a/src/streamlit/pages/1_Editor.py +++ b/src/streamlit/pages/1_Editor.py @@ -698,9 +698,9 @@ def clean_vis_data(vis_name): selected_y_axis_index = 1 if len(cols) > 1 else 0 for j in range(len(cols)): - if cols[j] == current_pipeline_obj['visualization'][i]['x_axis']: + if 'x_axis' in current_pipeline_obj['visualization'] and cols[j] == current_pipeline_obj['visualization'][i]['x_axis']: selected_x_axis_index = j - if cols[j] == current_pipeline_obj['visualization'][i]['y_axis']: + if 'y_axis' in current_pipeline_obj['visualization'] and cols[j] == current_pipeline_obj['visualization'][i]['y_axis']: 
selected_y_axis_index = j x_axis = st.selectbox('X Axis', cols, key=f'vis_x_axis_{i}', index=selected_x_axis_index)