diff --git a/Dockerfile b/Dockerfile index 8248f0f..3a8c9a5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,8 +14,10 @@ RUN apt-get update && \ apt-get clean; -ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64/ -RUN export JAVA_HOME +# ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-arm64/ +# ENV JAVA_HOME /usr/lib/jvm/java-11-openjdk-amd64/ +RUN export JAVA_HOME="$(dirname $(dirname $(readlink -f $(which java))))" +RUN echo $JAVA_HOME WORKDIR /usr/src/app ENV FLASK_APP=./src/app.py diff --git a/example/pipeline_fruit_batch_2.json b/example/pipeline_fruit_batch_2.json index 57ee935..a12fa90 100644 --- a/example/pipeline_fruit_batch_2.json +++ b/example/pipeline_fruit_batch_2.json @@ -258,16 +258,12 @@ ], "visualization": [ { - "name": "untitled1", + "name": "TotalSales", "type": "Bar Chart", "input": "srv_fruit_sales_total", - "description": "" - }, - { - "name": "untitled2", - "type": "Line Chart", - "input": "srv_fruit_sales_total", - "description": "" + "description": "", + "x_axis": "fruit", + "y_axis": "total" } ], "industry": "Other", @@ -416,70 +412,6 @@ "total": 17.0 } ] - }, - "visualization": { - "untitled1": { - "xAxis": { - "type": "category", - "data": [ - "Fiji Apple", - "Green Apple", - "Peach", - "Green Grape", - "Orange", - "Red Grape", - "Banana" - ] - }, - "yAxis": { - "type": "value" - }, - "series": [ - { - "data": [ - 56.0, - 45.0, - 39.0, - 36.0, - 28.0, - 24.0, - 17.0 - ], - "type": "bar" - } - ] - }, - "untitled2": { - "xAxis": { - "type": "category", - "data": [ - "Fiji Apple", - "Green Apple", - "Peach", - "Green Grape", - "Orange", - "Red Grape", - "Banana" - ] - }, - "yAxis": { - "type": "value" - }, - "series": [ - { - "data": [ - 56.0, - 45.0, - 39.0, - 36.0, - 28.0, - 24.0, - 17.0 - ], - "type": "line" - } - ] - } } } } \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index bf519df..f612654 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ setuptools wheel pandas==1.5.3 +numpy==1.23.5 pyarrow fastparquet pyspark==3.3.0 diff --git a/src/cddp/openai_api.py b/src/cddp/openai_api.py index 2788d7b..0cafd5c 100644 --- a/src/cddp/openai_api.py +++ b/src/cddp/openai_api.py @@ -3,559 +3,463 @@ from langchain.chains import LLMChain import json import os +import time from dotenv import load_dotenv load_dotenv() -OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") -DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") -MODEL = os.getenv("OPENAI_MODEL") -API_VERSION = os.getenv("OPENAI_API_VERSION") +class OpenaiApi: + OPENAI_API_BASE = os.getenv("OPENAI_API_BASE") + DEPLOYMENT = os.getenv("OPENAI_DEPLOYMENT") + MODEL = os.getenv("OPENAI_MODEL") + API_VERSION = os.getenv("OPENAI_API_VERSION") + MAX_RETRY = 3 -def _prepare_openapi_llm(): - llm = AzureChatOpenAI(deployment_name=DEPLOYMENT, - model=MODEL, - openai_api_version=API_VERSION, - openai_api_base=OPENAI_API_BASE) - return llm + def __init__(self): + self.llm = AzureChatOpenAI(deployment_name=self.DEPLOYMENT, + model=self.MODEL, + openai_api_version=self.API_VERSION, + openai_api_base=self.OPENAI_API_BASE) -def recommend_data_processing_scenario(industry_name: str): - recommend_data_processing_scenario_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - Please recommend 7 to 10 different data processing pipelines scenarios, including required steps like collecting data sources, transforming the data and generating aggregated metrics, etc. - Your answer should be in an array of JSON format like below. - [ + def _is_valid_json(self, input: str): + """Check whether input string is valid JSON or not + + :param input: input string + + :returns: boolean value on validation check + """ + try: + json.loads(input) + valid_flag = True + except ValueError as e: + valid_flag = False + + return valid_flag + + + def _run_chain(self, chain: LLMChain, params: dict): + retry_count = 0 + + while retry_count < self.MAX_RETRY: + response = chain(params) + if self._is_valid_json(response["text"]): + return response["text"] + print(f"[ERROR] Got bad format response from LLM: {response['text']}") + retry_count += 1 + + raise ValueError("Got bad format response from LLM") + + + def recommend_data_processing_scenario(self, industry_name: str): + recommend_data_processing_scenario_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please recommend 7 to 10 different data processing pipelines scenarios, including required steps like collecting data sources, transforming the data and generating aggregated metrics, etc. + Your answer should be in an array of JSON format like below. + [ + {{ + "pipeline_name": "{{name of the data processing pipeline}}", + "description": "{{short description on the data pipeline}}", + "stages": [ + {{ + "stage": "staging", + "description": "{{description on collected data sources for the rest of the data processing pipeline}}" + }}, + {{ + "stage": "standard", + "description": "{{description on data transformation logics}}" + }}, + {{ + "stage": "serving", + "description": "{{description on data aggregation logics}}" + }} + ] + }} + ] + Therefore your answers are: + """ + + prompt = PromptTemplate( + input_variables=["industry_name"], + template=recommend_data_processing_scenario_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name}) + + return results + + + def recommend_tables_for_industry(self, industry_name: str, industry_contexts: str): + """ Recommend database tables for a given industry and relevant contexts. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + + :returns: recommened tables with schema in array of json format + """ + + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should be an array of JSON format objects like below. {{ - "pipeline_name": "{{name of the data processing pipeline}}", - "description": "{{short description on the data pipeline}}", - [ + "table_name": "{{table name}}", + "table_description": "{{table description}}", + "columns": [ {{ - "stage": "staging", - "description": "{{description on collected data sources for the rest of the data processing pipeline}}" - }}, + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} + }} + ] + }} + + Please recommend 7 to 10 database tables: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts}) + + return results + + + def recommend_data_processing_scenario_mock(self, industry_name: str): + results = """ + [ { "pipeline_name": "Flight Delay Pipeline", "description": "Collects flight data from various sources and generates metrics on delays", "stages": [ { "stage": "staging", "description": "Collects flight data from airline APIs and airport databases" }, { "stage": "standard", "description": "Transforms data to calculate delay metrics based on departure and arrival times" }, { "stage": "serving", "description": "Aggregates delay metrics by airport, airline, and route" } ] }, { "pipeline_name": "Baggage Handling Pipeline", "description": "Tracks baggage movement and generates metrics on handling efficiency", "stages": [ { "stage": "staging", "description": "Collects baggage movement data from RFID scanners and baggage handling systems" }, { "stage": "standard", "description": "Transforms data to calculate metrics on baggage handling efficiency, such as time to load and unload baggage" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and baggage handling company" } ] }, { "pipeline_name": "Revenue Management Pipeline", "description": "Analyzes sales data to optimize pricing and revenue", "stages": [ { "stage": "staging", "description": "Collects sales data from ticketing systems and travel booking websites" }, { "stage": "standard", "description": "Transforms data to calculate revenue metrics, such as average ticket price and revenue per seat" }, { "stage": "serving", "description": "Aggregates metrics by route, fare class, and seasonality" } ] }, { "pipeline_name": "Maintenance Pipeline", "description": "Monitors aircraft health and schedules maintenance", "stages": [ { "stage": "staging", "description": "Collects aircraft sensor data and maintenance records" }, { "stage": "standard", "description": "Transforms data to identify potential maintenance issues and schedule preventative maintenance" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, age, and maintenance history" } ] }, { "pipeline_name": "Customer Service Pipeline", "description": "Analyzes customer feedback to improve service", "stages": [ { "stage": "staging", "description": "Collects customer feedback from surveys, social media, and customer service interactions" }, { "stage": "standard", "description": "Transforms data to identify common issues and sentiment analysis of customer feedback" }, { "stage": "serving", "description": "Aggregates metrics by route, airline, and customer feedback channel" } ] }, { "pipeline_name": "Fuel Efficiency Pipeline", "description": "Monitors fuel usage to optimize efficiency and reduce costs", "stages": [ { "stage": "staging", "description": "Collects fuel usage data from aircraft sensors and fueling systems" }, { "stage": "standard", "description": "Transforms data to calculate fuel efficiency metrics, such as fuel burn per passenger mile" }, { "stage": "serving", "description": "Aggregates metrics by aircraft type, route, and seasonality" } ] }, { "pipeline_name": "Security Pipeline", "description": "Monitors security incidents to improve safety and compliance", "stages": [ { "stage": "staging", "description": "Collects security incident data from airport security systems and passenger screening" }, { "stage": "standard", "description": "Transforms data to identify common security incidents and compliance issues" }, { "stage": "serving", "description": "Aggregates metrics by airport, airline, and security incident type" } ] } ] + """ + + return results + + + def recommend_tables_for_industry_mock(self, industry_name: str, industry_contexts: str): + results = """ + [{"table_name":"airlines","table_description":"Information about the airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"flights","table_description":"Information about flights operated by airlines","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"datetime","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"passengers","table_description":"Information about passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"bookings","table_description":"Information about flight bookings made by passengers","columns":[{"column_name":"booking_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true}]},{"table_name":"seats","table_description":"Information about seats available in flights","columns":[{"column_name":"seat_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"passenger_id","data_type":"integer","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"seat_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]},{"table_name":"airports","table_description":"Information about airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"location","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]}] + """ + + return results + + + def recommend_tables_for_industry_one_at_a_time(self, industry_name: str, industry_contexts: str, recommended_tables: str = None, index: int = None): + recommaned_tables_for_industry_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + Please recommend some database tables with potential data schema for the above contexts. + Your response should strictly be in JSON format like below. + {{ + "table_name": "{{table name}}", + "table_description": "{{table description}}", + "columns": [ {{ - "stage": "standard", - "description": "{{description on data transformation logics}}" - }}, + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false (strictly in lower case)}}, + "is_primary_key": {{true or false (strictly in lower case)}}, + "is_foreign_key": {{true or false (strictly in lower case)}} + }} + ] + }} + + You've recommanded below tables in previous conversations. + {recommended_tables} + + Please recommend another one without duplicated table (if there's no table listed above, please go ahead to create a new one): + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommended_tables"], + template=recommaned_tables_for_industry_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommended_tables": recommended_tables}) + + return results + + + def recommend_tables_for_industry_one_at_a_time_mock(self, industry_name: str, industry_contexts: str, index: int): + if index == 0: + results = """ + {"table_name":"flights","table_description":"Table to store information about flights","columns":[{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"departure_city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_date","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"time","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_date","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"arrival_time","data_type":"time","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 1: + results = """ + {"table_name":"passengers","table_description":"Table to store information about passengers","columns":[{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"date_of_birth","data_type":"date","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"email","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"phone_number","data_type":"varchar(20)","is_null":true,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 2: + results = """ + {"table_name":"airlines","table_description":"Table to store information about airline companies","columns":[{"column_name":"airline_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airline_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 3: + results = """ + {"table_name":"tickets","table_description":"Table to store information about airline tickets","columns":[{"column_name":"ticket_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"flight_id","data_type":"integer","is_null":false,"is_primary_key":false,"is_foreign_key":true},{"column_name":"ticket_number","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_class","data_type":"varchar(50)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_price","data_type":"decimal(10,2)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"ticket_status","data_type":"varchar(50)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + elif index == 4: + results = """ + {"table_name":"airports","table_description":"Table to store information about airports","columns":[{"column_name":"airport_id","data_type":"integer","is_null":false,"is_primary_key":true,"is_foreign_key":false},{"column_name":"airport_name","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"city","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false},{"column_name":"country","data_type":"varchar(255)","is_null":false,"is_primary_key":false,"is_foreign_key":false}]} + """ + # Simulate request consuming response time + time.sleep(1) + + return results + + + def generate_custom_data_processing_logics_mock(self, + industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + results = """ + {"sql":"SELECT flights.airline_id, flights.arrival_time, flights.departure_time, flights.destination, flights.flight_id, flights.origin, passengers.age, passengers.gender, passengers.name, passengers.passenger_id FROM flights JOIN bookings ON flights.flight_id = bookings.flight_id JOIN passengers ON bookings.passenger_id = passengers.passenger_id","schema":{"table_name":"std_flights_p","columns":[{"column_name":"airline_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"arrival_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"departure_time","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"destination","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"flight_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true},{"column_name":"origin","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"age","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"gender","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"name","data_type":"string","is_null":true,"is_primary_key":false,"is_foreign_key":false},{"column_name":"passenger_id","data_type":"long","is_null":true,"is_primary_key":false,"is_foreign_key":true}]}} + """ + + return results + + + def recommend_custom_table(self, + industry_name: str, + industry_contexts: str, + recommened_tables: str, + custom_table_name: str, + custom_table_description: str): + """ Recommend custom/user-defined table for a input custom table name and description of a given industry. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param custom_table_name: custom table name + :param custom_table_description: custom table description + + :returns: recommened custom tables with schema in json format + """ + + recommend_custom_tables_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. + {custom_table_description} + + Please only output the new added table without previously recommended tables, therefore the outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], + template=recommend_custom_tables_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "recommened_tables": recommened_tables, + "custom_table_name": custom_table_name, + "custom_table_description": custom_table_description}) + + return results + + + def recommend_data_processing_logics(self, + industry_name: str, + industry_contexts: str, + recommened_tables: str, + processing_logic: str): + """ Recommend data processing logics for a given industry and sample tables. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param recommened_tables: previously recommened tables in json string format + :param processing_logic: either data cleaning, data transformation or data aggregation + + :returns: recommened data processing logics in array of json format + """ + + recommend_data_cleaning_logics_template=""" + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + You've recommended below potential database tables and schema previously in json format. + {recommened_tables} + + Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. + You response should be in an array of JSON format like below. + {{ + "description": "{{descriptions on the data cleaning logic}}", + "involved_tables": [ + "{{involved table X}}", + "{{involved table Y}}", + "{{involved table Z}}" + ], + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{cleaned table schema in json string format}}" + }} + + Therefore outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], + template=recommend_data_cleaning_logics_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "processing_logic": processing_logic, + "recommened_tables": recommened_tables}) + + return results + + + def generate_custom_data_processing_logics(self, + industry_name: str, + industry_contexts: str, + involved_tables: str, + custom_data_processing_logic: str, + output_table_name: str): + """ Generate custom data processing logics for input data processing requirements. + + :param industry_name: industry name + :param industry_contexts: industry descriptions/contexts + :param involved_tables: tables required by the custom data processing logic, in json string format + :param custom_data_processing_logic: custom data processing requirements + :param output_table_name: output/sink table name for processed data + + :returns: custom data processing logic json format + """ + + generate_custom_data_processing_logics_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + And below is relevant contexts of the industry: + {industry_contexts} + + We have database tables listed below with table_name and schema in JSON format. + {involved_tables} + + And below is the data processing requirement. + {custom_data_processing_logic} + + Please help to generate Spark SQL statement with output data schema. + And your response should be in JSON format like below. + {{ + "sql": "{{Spark SQL statement to do the data cleaning}}", + "schema": "{{output data schema in JSON string format}}" + }} + + And the above data schema string should follows below JSON format, while value for the "table_name" key should strictly be "{output_table_name}". + {{ + "table_name": "{output_table_name}", + "coloumns": [ {{ - "stage": "serving", - "description": "{{description on data aggregation logics}}" + "column_name": "{{column name}}", + "data_type": "{{data type}}", + "is_null": {{true or false}}, + "is_primary_key": {{true or false}}, + "is_foreign_key": {{true or false}} }} ] }} - ] - Therefore your answers are: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name"], - template=recommend_data_processing_scenario_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name}) - results = response["text"] - - return results - -def recommend_tables_for_industry(industry_name: str, industry_contexts: str): - """ Recommend database tables for a given industry and relevant contexts. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - - :returns: recommened tables with schema in array of json format - """ - - recommaned_tables_for_industry_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - Please recommend some database tables with potential data schema for the above contexts. - Your response should be an array of JSON format objects like below. - {{ - "table_name": "{{table name}}", - "table_description": "{{table description}}", + + Therefore the outcome would be: + """ + + prompt = PromptTemplate( + input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], + template=generate_custom_data_processing_logics_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + + # Run the chain only specifying the input variable. + results = self._run_chain(chain, {"industry_name": industry_name, + "industry_contexts": industry_contexts, + "custom_data_processing_logic": custom_data_processing_logic, + "involved_tables": involved_tables, + "output_table_name": output_table_name}) + + return results + + + def generate_sample_data(self, + industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + """ Generate custom data processing logics for input data processing requirements. + + :param industry_name: industry name + :param number_of_lines: number of lines sample data required + :param target_table: target table name and its schema in json format + + :returns: generated sample data in array of json format + """ + + generate_sample_data_template = """ + You're a data engineer and familiar with the {industry_name} industry IT systems. + Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. + {target_table} + + And below are patterns of column values in json format, if it's not provided please ignore this requirement. + {column_values_patterns} + + And the sample data should strictly be in JSON format like below. [ {{ - "column_name": "{{column name}}", - "data_type": "{{data type}}", - "is_null": {{true or false}}, - "is_primary_key": {{true or false}}, - "is_foreign_key": {{true or false}} + "{{column X}}": "{{column value}}", + "{{column Y}}": "{{column value}}", + "{{column Z}}": "{{column value}}" }} ] - }} - - Please recommend 7 to 10 database tables: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts"], - template=recommaned_tables_for_industry_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts}) - results = response["text"] - - return results - - -def recommend_tables_for_industry_mock(industry_name: str, industry_contexts: str): - results = """ - [ - { - "table_name": "airlines", - "table_description": "Information about the airline companies", - "columns": - [ - { - "column_name": "airline_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "country", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "flights", - "table_description": "Information about flights operated by airlines", - "columns": - [ - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "airline_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "origin", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "destination", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "departure_time", - "data_type": "datetime", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "arrival_time", - "data_type": "datetime", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "passengers", - "table_description": "Information about passengers", - "columns": - [ - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "age", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "gender", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - } - ] - }, - { - "table_name": "bookings", - "table_description": "Information about flight bookings made by passengers", - "columns": - [ - { - "column_name": "booking_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - } - ] - }, - { - "table_name": "seats", - "table_description": "Information about seats available in flights", - "columns": - [ - { - "column_name": "seat_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "flight_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "passenger_id", - "data_type": "integer", - "is_null": true, - "is_primary_key": false, - "is_foreign_key": true - }, - { - "column_name": "seat_number", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - }, - { - "table_name": "airports", - "table_description": "Information about airports", - "columns": - [ - { - "column_name": "airport_id", - "data_type": "integer", - "is_null": false, - "is_primary_key": true, - "is_foreign_key": false - }, - { - "column_name": "name", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - }, - { - "column_name": "location", - "data_type": "varchar(255)", - "is_null": false, - "is_primary_key": false, - "is_foreign_key": false - } - ] - } - ] - """ - - return results - - -def recommend_custom_table(industry_name: str, - industry_contexts: str, - recommened_tables: str, - custom_table_name: str, - custom_table_description: str): - """ Recommend custom/user-defined table for a input custom table name and description of a given industry. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param recommened_tables: previously recommened tables in json string format - :param custom_table_name: custom table name - :param custom_table_description: custom table description - - :returns: recommened custom tables with schema in json format - """ - - recommend_custom_tables_template=""" - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - You've recommended below potential database tables and schema previously in json format. - {recommened_tables} - - Please help to add another {custom_table_name} table with the same table schema format, please include as many as possible columns reflecting the reality and the table description below. - {custom_table_description} - - Please only output the new added table without previously recommended tables, therefore the outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "recommened_tables", "custom_table_name", "custom_table_description"], - template=recommend_custom_tables_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "recommened_tables": recommened_tables, - "custom_table_name": custom_table_name, - "custom_table_description": custom_table_description}) - results = response["text"] - - return results - - -def recommend_data_processing_logics(industry_name: str, - industry_contexts: str, - recommened_tables: str, - processing_logic: str): - """ Recommend data processing logics for a given industry and sample tables. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param recommened_tables: previously recommened tables in json string format - :param processing_logic: either data cleaning, data transformation or data aggregation - - :returns: recommened data processing logics in array of json format - """ - - recommend_data_cleaning_logics_template=""" - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - You've recommended below potential database tables and schema previously in json format. - {recommened_tables} - - Please recommend 5 to 7 {processing_logic} logic over the above tables with Spark SQL statements. - You response should be in an array of JSON format like below. - {{ - "description": "{{descriptions on the data cleaning logic}}", - "involved_tables": [ - "{{involved table X}}", - "{{involved table Y}}", - "{{involved table Z}}" - ], - "sql": "{{Spark SQL statement to do the data cleaning}}", - "schema": "{{cleaned table schema in json string format}}" - }} - - Therefore outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "processing_logic", "recommened_tables"], - template=recommend_data_cleaning_logics_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "processing_logic": processing_logic, - "recommened_tables": recommened_tables}) - results = json.loads(response["text"]) - - return results - - -def generate_custom_data_processing_logics(industry_name: str, - industry_contexts: str, - involved_tables: str, - custom_data_processing_logic: str, - output_table_name: str): - """ Generate custom data processing logics for input data processing requirements. - - :param industry_name: industry name - :param industry_contexts: industry descriptions/contexts - :param involved_tables: tables required by the custom data processing logic, in json string format - :param custom_data_processing_logic: custom data processing requirements - :param output_table_name: output/sink table name for processed data - - :returns: custom data processing logic json format - """ - - generate_custom_data_processing_logics_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - And below is relevant contexts of the industry: - {industry_contexts} - - We have database tables listed below with table_name and schema in JSON format. - {involved_tables} - - And below is the data processing requirement. - {custom_data_processing_logic} - - Please help to generate Spark SQL statement with output data schema. - And your response should be in JSON format like below. - {{ - "sql": "{{Spark SQL statement to do the data cleaning}}", - "schema": "{{output data schema in JSON string format}}" - }} - - And the above data schema string should follows below JSON format, while value for the "table_name" key should strictly be "{output_table_name}". - {{ - "table_name": "{output_table_name}", - "coloumns": [ - {{ - "column_name": "{{column name}}", - "data_type": "{{data type}}", - "is_null": {{true or false}}, - "is_primary_key": {{true or false}}, - "is_foreign_key": {{true or false}} - }} - ] - }} - - Therefore the outcome would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "industry_contexts", "involved_tables", "custom_data_processing_logic", "output_table_name"], - template=generate_custom_data_processing_logics_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - - # Run the chain only specifying the input variable. - response = chain({"industry_name": industry_name, - "industry_contexts": industry_contexts, - "custom_data_processing_logic": custom_data_processing_logic, - "involved_tables": involved_tables, - "output_table_name": output_table_name}) - results = response["text"] - - return results - - -def generate_sample_data(industry_name: str, - number_of_lines: int, - target_table: str, - column_values_patterns: str): - """ Generate custom data processing logics for input data processing requirements. - - :param industry_name: industry name - :param number_of_lines: number of lines sample data required - :param target_table: target table name and its schema in json format - - :returns: generated sample data in array of json format - """ - - generate_sample_data_template = """ - You're a data engineer and familiar with the {industry_name} industry IT systems. - Please help to generate {number_of_lines} lines of sample data for below table with table schema in json format. - {target_table} - - And below are patterns of column values in json format, if it's not provided please ignore this requirement. - {column_values_patterns} - - And the sample data should be an array of json object like below. - {{ - "{{column X}}": "{{column value}}", - "{{column Y}}": "{{column value}}", - "{{column Z}}": "{{column value}}" - }} - - The sample data would be: - """ - - llm = _prepare_openapi_llm() - prompt = PromptTemplate( - input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], - template=generate_sample_data_template, - ) - chain = LLMChain(llm=llm, prompt=prompt) - response = chain({"industry_name": industry_name, - "number_of_lines": number_of_lines, - "target_table": target_table, - "column_values_patterns": column_values_patterns}) - results = response["text"] - - return results - - -def generate_sample_data_mock(industry_name: str, - number_of_lines: int, - target_table: str, - column_values_patterns: str): - if target_table["table_name"] == "flights": - results = """ - [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { "flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" }, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] - """ - if target_table["table_name"] == "passengers": - results = """ - [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": "Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] - """ - - if target_table["table_name"] == "bookings": - results = """ - [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, "passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { "booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + Therefore the sample data would be: """ - return results \ No newline at end of file + prompt = PromptTemplate( + input_variables=["industry_name", "number_of_lines", "target_table", "column_values_patterns"], + template=generate_sample_data_template, + ) + chain = LLMChain(llm=self.llm, prompt=prompt) + results = self._run_chain(chain, {"industry_name": industry_name, + "number_of_lines": number_of_lines, + "target_table": target_table, + "column_values_patterns": column_values_patterns}) + + return results + + + def generate_sample_data_mock(self, + industry_name: str, + number_of_lines: int, + target_table: str, + column_values_patterns: str): + if target_table["table_name"] == "flights": + results = """ + [ { "flight_id": 1, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-01 08:00:00", "arrival_time": "2021-01-01 11:30:00" }, { "flight_id": 2, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-02 14:30:00", "arrival_time": "2021-01-02 16:00:00" }, { "flight_id": 3, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-03 10:45:00", "arrival_time": "2021-01-04 06:15:00" }, { "flight_id": 4, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-05 16:20:00", "arrival_time": "2021-01-05 19:45:00" }, { "flight_id": 5, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-06 09:15:00", "arrival_time": "2021-01-06 10:30:00" }, { "flight_id": 6, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-07 12:00:00", "arrival_time": "2021-01-07 15:30:00" }, { "flight_id": 7, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-08 18:45:00", "arrival_time": "2021-01-08 20:15:00" }, { "flight_id": 8, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-09 14:30:00", "arrival_time": "2021-01-10 08:00:00" }, { "flight_id": 9, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-11 20:00:00", "arrival_time": "2021-01-11 23:25:00" }, { "flight_id": 10, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-12 13:45:00", "arrival_time": "2021-01-12 15:00:00" }, { "flight_id": 11, "airline_id": 1001, "origin": "New York", "destination": "Los Angeles", "departure_time": "2021-01-13 08:00:00", "arrival_time": "2021-01-13 11:30:00" }, { "flight_id": 12, "airline_id": 1002, "origin": "London", "destination": "Paris", "departure_time": "2021-01-14 14:30:00", "arrival_time": "2021-01-14 16:00:00" }, { "flight_id": 13, "airline_id": 1003, "origin": "Tokyo", "destination": "Sydney", "departure_time": "2021-01-15 10:45:00", "arrival_time": "2021-01-16 06:15:00" }, { "flight_id": 14, "airline_id": 1004, "origin": "Chicago", "destination": "Miami", "departure_time": "2021-01-17 16:20:00", "arrival_time": "2021-01-17 19:45:00" }, { "flight_id": 15, "airline_id": 1005, "origin": "Sydney", "destination": "Melbourne", "departure_time": "2021-01-18 09:15:00", "arrival_time": "2021-01-18 10:30:00" }, { "flight_id": 16, "airline_id": 1001, "origin": "Los Angeles", "destination": "New York", "departure_time": "2021-01-19 12:00:00", "arrival_time": "2021-01-19 15:30:00" }, { "flight_id": 17, "airline_id": 1002, "origin": "Paris", "destination": "London", "departure_time": "2021-01-20 18:45:00", "arrival_time": "2021-01-20 20:15:00" }, { "flight_id": 18, "airline_id": 1003, "origin": "Sydney", "destination": "Tokyo", "departure_time": "2021-01-21 14:30:00", "arrival_time": "2021-01-22 08:00:00" }, { "flight_id": 19, "airline_id": 1004, "origin": "Miami", "destination": "Chicago", "departure_time": "2021-01-23 20:00:00", "arrival_time": "2021-01-23 23:25:00" }, { "flight_id": 20, "airline_id": 1005, "origin": "Melbourne", "destination": "Sydney", "departure_time": "2021-01-24 13:45:00", "arrival_time": "2021-01-24 15:00:00" } ] + """ + + if target_table["table_name"] == "passengers": + results = """ + [ { "passenger_id": 1, "name": "John Smith", "age": 35, "gender": "Male", "flight_id": 1 }, { "passenger_id": 2, "name": "Jane Doe", "age": 45, "gender": "Female", "flight_id": 1 }, { "passenger_id": 3, "name": "Michael Johnson", "age": 60, "gender": "Male", "flight_id": 2 }, { "passenger_id": 4, "name": "Emily Williams", "age": 25, "gender": "Female", "flight_id": 3 }, { "passenger_id": 5, "name": "David Brown", "age": 55, "gender": "Male", "flight_id": 4 }, { "passenger_id": 6, "name": "Sarah Davis", "age": 30, "gender": "Female", "flight_id": 4 }, { "passenger_id": 7, "name": "Robert Martinez", "age": 65, "gender": "Male", "flight_id": 5 }, { "passenger_id": 8, "name": "Jessica Thomas", "age": 40, "gender": "Female", "flight_id": 5 }, { "passenger_id": 9, "name": "Christopher Wilson", "age": 50, "gender": "Male", "flight_id": 1 }, { "passenger_id": 10, "name": "Stephanie Taylor", "age": 27, "gender": "Female", "flight_id": 2 }, { "passenger_id": 11, "name": "Daniel Anderson", "age": 65, "gender": "Male", "flight_id": 3 }, { "passenger_id": 12, "name": "Melissa Thompson", "age": 42, "gender": "Female", "flight_id": 4 }, { "passenger_id": 13, "name": "Matthew White", "age": 32, "gender": "Male", "flight_id": 5 }, { "passenger_id": 14, "name": "Amanda Harris", "age": 52, "gender": "Female", "flight_id": 1 }, { "passenger_id": 15, "name": "Andrew Lee", "age": 65, "gender": "Male", "flight_id": 2 }, { "passenger_id": 16, "name": "Jennifer Clark", "age": 28, "gender": "Female", "flight_id": 3 }, { "passenger_id": 17, "name": "James Rodriguez", "age": 62, "gender": "Male", "flight_id": 4 }, { "passenger_id": 18, "name": "Nicole Walker", "age": 38, "gender": "Female", "flight_id": 5 }, { "passenger_id": 19, "name": "Ryan Wright", "age": 47, "gender": "Male", "flight_id": 1 }, { "passenger_id": 20, "name": "Lauren Hall", "age": 31, "gender": "Female", "flight_id": 2 } ] + """ + + if target_table["table_name"] == "bookings": + results = """ + [ { "booking_id": 1, "passenger_id": 1, "flight_id": 1 }, { "booking_id": 2, "passenger_id": 2, "flight_id": 2 }, { "booking_id": 3, "passenger_id": 3, "flight_id": 3 }, { "booking_id": 4, "passenger_id": 4, "flight_id": 4 }, { "booking_id": 5, "passenger_id": 5, "flight_id": 5 }, { "booking_id": 6, "passenger_id": 1, "flight_id": 2 }, { "booking_id": 7, "passenger_id": 2, "flight_id": 3 }, { "booking_id": 8, "passenger_id": 3, "flight_id": 4 }, { "booking_id": 9, "passenger_id": 4, "flight_id": 5 }, { "booking_id": 10, "passenger_id": 5, "flight_id": 1 }, { "booking_id": 11, "passenger_id": 1, "flight_id": 3 }, { "booking_id": 12, "passenger_id": 2, "flight_id": 4 }, { "booking_id": 13, "passenger_id": 3, "flight_id": 5 }, { "booking_id": 14, "passenger_id": 4, "flight_id": 1 }, { "booking_id": 15, "passenger_id": 5, "flight_id": 2 }, { "booking_id": 16, "passenger_id": 1, "flight_id": 4 }, { "booking_id": 17, "passenger_id": 2, "flight_id": 5 }, { "booking_id": 18, "passenger_id": 3, "flight_id": 1 }, { "booking_id": 19, "passenger_id": 4, "flight_id": 2 }, { "booking_id": 20, "passenger_id": 5, "flight_id": 3 } ] + """ + + return results diff --git a/src/streamlit/pages/1_Editor.py b/src/streamlit/pages/1_Editor.py index 67567b6..5fbee7e 100644 --- a/src/streamlit/pages/1_Editor.py +++ b/src/streamlit/pages/1_Editor.py @@ -185,6 +185,13 @@ def run_task(task_name, stage="standard"): def delete_task(type, index): if type == "staging": + # Also sync checkbox status in the AI Assistant page + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + for table in generated_tables: + if table["table_name"] == current_pipeline_obj['staging'][index]["name"]: + table["staged_flag"] = False + break + del current_pipeline_obj['staging'][index] elif type == "standard": del current_pipeline_obj['standard'][index] @@ -691,9 +698,9 @@ def clean_vis_data(vis_name): selected_y_axis_index = 1 if len(cols) > 1 else 0 for j in range(len(cols)): - if cols[j] == current_pipeline_obj['visualization'][i]['x_axis']: + if 'x_axis' in current_pipeline_obj['visualization'] and cols[j] == current_pipeline_obj['visualization'][i]['x_axis']: selected_x_axis_index = j - if cols[j] == current_pipeline_obj['visualization'][i]['y_axis']: + if 'y_axis' in current_pipeline_obj['visualization'] and cols[j] == current_pipeline_obj['visualization'][i]['y_axis']: selected_y_axis_index = j x_axis = st.selectbox('X Axis', cols, key=f'vis_x_axis_{i}', index=selected_x_axis_index) diff --git a/src/streamlit/pages/2_AI Assistant.py b/src/streamlit/pages/2_AI Assistant.py index 7a74f3a..8757705 100644 --- a/src/streamlit/pages/2_AI Assistant.py +++ b/src/streamlit/pages/2_AI Assistant.py @@ -6,9 +6,10 @@ import sys import streamlit_utils sys.path.append(os.path.join(os.path.dirname(__file__), '../..')) -import cddp.openai_api as openai_api +from cddp.openai_api import OpenaiApi from streamlit_extras.switch_page_button import switch_page from streamlit_extras.colored_header import colored_header +from time import sleep if "current_pipeline_obj" not in st.session_state: switch_page("Home") @@ -19,11 +20,12 @@ st.set_page_config(page_title="AI Assistant") colored_header( - label="AI Assiatant", + label="AI Assistant", description=f"Leverage AI to assist you in data pipeline development", color_name="violet-70", ) +openai_api = OpenaiApi() @@ -56,9 +58,13 @@ st.session_state["generated_usecases"] = False # Reset button clicked status with generate_use_cases_col2: with st.spinner('Generating...'): - usecases = openai_api.recommend_data_processing_scenario(current_pipeline_obj.get("industry", "")) - st.session_state['current_generated_usecases'] = json.loads(usecases) - + try: + usecases = openai_api.recommend_data_processing_scenario(current_pipeline_obj.get("industry", "")) + st.session_state['current_generated_usecases'] = json.loads(usecases) + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") if "current_generated_usecases" in st.session_state: usecases = st.session_state['current_generated_usecases'] @@ -76,9 +82,6 @@ st.divider() - - - with tab_data: # Initialize current_generated_tables key in session state if "current_generated_tables" not in st.session_state: @@ -93,233 +96,301 @@ if "current_generated_sample_data" not in st.session_state: st.session_state['current_generated_sample_data'] = {} current_generated_sample_data = st.session_state['current_generated_sample_data'] -# AI Assistnt for tables generation + + if "disable_generate_table_button" not in st.session_state: + st.session_state["disable_generate_table_button"] = False + + # AI Assistnt for tables generation tables = [] - generate_tables_col1, generate_tables_col2 = st.columns(2) - with generate_tables_col1: - st.button("Generate", on_click=streamlit_utils.click_button, kwargs={"button_name": "generate_tables"}, use_container_width=True) - - if "generate_tables" not in st.session_state: - st.session_state["generate_tables"] = False - elif st.session_state["generate_tables"]: - st.session_state["generate_tables"] = False # Reset button clicked status - with generate_tables_col2: - with st.spinner('Generating...'): - tables = openai_api.recommend_tables_for_industry(industry_name, industry_contexts) - current_generated_tables["generated_tables"] = tables - - try: - if "generated_tables" in current_generated_tables: - tables = json.loads(current_generated_tables["generated_tables"]) - - if "selected_tables" not in current_generated_tables: - current_generated_tables["selected_tables"] = [] - - for table in tables: - columns = table["columns"] - columns_df = pd.DataFrame.from_dict(columns, orient='columns') - - check_flag = False - for selected_table in current_generated_tables["selected_tables"]: - if selected_table == table["table_name"]: - check_flag = True - - sample_data = None - with st.expander(table["table_name"]): - gen_table_name = table["table_name"] - gen_table_desc = table["table_description"] - - st.write(gen_table_desc) - st.write(columns_df) - - st.write(f"Generate sample data") - rows_count = st.slider("Number of rows", min_value=5, max_value=50, key=f'gen_rows_count_slider_{gen_table_name}') - enable_data_requirements = st.toggle("With extra sample data requirements", key=f'data_requirements_toggle_{gen_table_name}') - data_requirements = "" - if enable_data_requirements: - data_requirements = st.text_area("Extra requirements for sample data", - key=f'data_requirements_text_area_{gen_table_name}', - placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9") - - generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) - with generate_sample_data_col1: - st.button("Generate Sample Data", - key=f"generate_data_button_{gen_table_name}", - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"generate_sample_data_{gen_table_name}"}) - - if f"generate_sample_data_{gen_table_name}" not in st.session_state: - st.session_state[f"generate_sample_data_{gen_table_name}"] = False - - if f"{gen_table_name}_smaple_data_generated" not in st.session_state: - st.session_state[f"{gen_table_name}_smaple_data_generated"] = False - elif st.session_state[f"generate_sample_data_{gen_table_name}"]: - st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status - if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: - with generate_sample_data_col2: - with st.spinner('Generating...'): - sample_data = openai_api.generate_sample_data(industry_name, - rows_count, - table, - data_requirements) - st.session_state[f"{gen_table_name}_smaple_data_generated"] = True - # Store generated data to session_state - current_generated_sample_data[gen_table_name] = sample_data - - # Also update current_pipeline_obj if checked check-box before generating sample data - if sample_data and st.session_state[gen_table_name]: - spark = st.session_state["spark"] - json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") - - for index, dataset in enumerate(current_pipeline_obj['staging']): - if dataset["name"] == gen_table_name: - i = index - - current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) - current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) - - if st.session_state[f"{gen_table_name}_smaple_data_generated"]: - st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag - json_sample_data = json.loads(sample_data) - current_generated_sample_data[gen_table_name] = json_sample_data # Save generated data to session_state - # st.session_state[f'stg_{i}_data'] = sample_data - - if gen_table_name in current_generated_sample_data: - sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') - st.write(sample_data_df) - - st.checkbox("Add to staging zone", - key=gen_table_name, - value=check_flag, - on_change=streamlit_utils.add_to_staging_zone, - args=[gen_table_name, gen_table_desc]) - - - # TODO we need to change the key of table["table_name"] - if st.session_state[table["table_name"]] and table["table_name"] not in current_generated_tables["selected_tables"]: - current_generated_tables["selected_tables"].append(table["table_name"]) - - - - except ValueError as e: - # TODO: Add error/exception to standard error-showing widget - st.write(tables) + placeholder = st.empty() + container = placeholder.container() + with container: + st.button("Generate", + on_click=streamlit_utils.generate_tables, + args=[placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api], + disabled=st.session_state["disable_generate_table_button"], + use_container_width=True) + + # Show warning message if some of generated tables have been referenced by other std/srv tasks + if streamlit_utils.has_staged_table() and st.session_state["has_clicked_generate_tables_btn"]: + with container: + st.warning("Some of current generated tables have been added to Staging zone, please remove them before generating tables again!") + st.session_state["has_clicked_generate_tables_btn"] = False + st.session_state["disable_generate_data_widget"] = False + + # Render generated tables once widgets inside table expander has been clicked/updated + if "generated_tables" in current_generated_tables and not st.session_state["has_clicked_generate_tables_btn"]: + tables = current_generated_tables["generated_tables"] + with container: + for gen_table_index, table in enumerate(tables): + streamlit_utils.render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api) + elif "generate_tables_count" in st.session_state and len(current_generated_tables.get("generated_tables", [])) == st.session_state["generate_tables_count"]: + st.session_state["disable_generate_data_widget"] = False + tables = current_generated_tables["generated_tables"] + + with placeholder.container(): + st.button("Generate", + key="generate_table_redraw", + on_click=streamlit_utils.generate_tables, + args=[placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api], + disabled=st.session_state["disable_generate_table_button"], + use_container_width=True) + + for gen_table_index, table in enumerate(tables): + streamlit_utils.render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api, + "redraw") + with tab_std_sql: + if "current_editing_pipeline_tasks" not in st.session_state: + st.session_state['current_editing_pipeline_tasks'] = {} + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] + if "standard" not in current_editing_pipeline_tasks: + current_editing_pipeline_tasks['standard'] = [] + + # Get all staged table details + staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() - # st.divider() - # st.subheader('Generate data transformation logic') + for i in range(len(current_pipeline_obj["standard"])): + target_name = current_pipeline_obj['standard'][i]['output']['target'] + disable_std_name_input = False + disable_std_task_deletion = False + std_name_has_dependency = streamlit_utils.check_tables_dependency(target_name) + + if std_name_has_dependency: + disable_std_name_input = True + disable_std_task_deletion = True + st.info("""This standardization task has been referenced by other tasks! + Please remove relevant dependency before trying to rename or delete this task.""") + + std_name = st.text_input('Transformation Name', + key=f'std_{i}_name', + value=target_name, + disabled=disable_std_name_input) + + if std_name: + st.subheader(std_name) + with st.expander(std_name+" Settings", expanded=True): + current_pipeline_obj['standard'][i]['output']['target'] = std_name + current_pipeline_obj['standard'][i]['name'] = std_name + + # Get latest standardized table details + standardized_table_names, standardized_table_details = streamlit_utils.get_std_srv_tables('standard') + + optional_tables = staged_table_names + standardized_table_names + if std_name in optional_tables: + optional_tables.remove(std_name) # Remove itself from the optional tables list + + if len(current_editing_pipeline_tasks['standard']) == i: + current_editing_pipeline_tasks['standard'].append({}) + current_editing_pipeline_tasks['standard'][i]['target'] = std_name + selected_staged_tables = st.multiselect( + 'Choose datasets to do the data transformation', + options=optional_tables, + default=current_editing_pipeline_tasks['standard'][i].get('involved_tables', None), + on_change=streamlit_utils.update_selected_tables, + key=f'std_{i}_involved_tables', + args=['standard', i, f'std_{i}_involved_tables']) + + if 'description' not in current_pipeline_obj['standard'][i]: + current_pipeline_obj['standard'][i]['description'] = "" + process_requirements = st.text_area("Transformation requirements", + key=f"std_{i}_transform_requirements", + value=current_pipeline_obj['standard'][i]['description']) + current_pipeline_obj['standard'][i]['description'] = process_requirements + + generate_transform_sql_col1, generate_transform_sql_col2 = st.columns(2) + with generate_transform_sql_col1: + st.button(f'Generate SQL', + key=f'std_{i}_gen_sql', + on_click=streamlit_utils.click_button, + kwargs={"button_name": f"std_{i}_gen_transform_sql"}) + + if len(current_pipeline_obj['standard'][i]['code']['sql']) == 0: + current_pipeline_obj['standard'][i]['code']['sql'].append("") + std_sql_val = current_pipeline_obj['standard'][i]['code']['sql'][0] + + # Get selected staged table details + selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + standardized_table_details, + selected_staged_tables) + if st.session_state[f"std_{i}_gen_sql"]: + with generate_transform_sql_col2: + with st.spinner('Generating...'): + try: + process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, + industry_contexts=industry_contexts, + involved_tables=selected_table_details, + custom_data_processing_logic=process_requirements, + output_table_name=std_name) + + process_logic_json = json.loads(process_logic) + std_sql_val = process_logic_json["sql"] + std_table_schema = process_logic_json["schema"] + current_editing_pipeline_tasks['standard'][i]['query_results_schema'] = std_table_schema + current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql_val + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") + + std_sql = st.text_area(f'Transformation Spark SQL', + key=f'std_{i}_transform_sql_text_area', + value=std_sql_val) + + current_pipeline_obj['standard'][i]['code']['sql'][0] = std_sql + + st.button(f'Run SQL', key=f'run_std_{i}_sql', on_click=streamlit_utils.run_task, args = [std_name, "standard", i]) + if 'sql_query_results' in current_editing_pipeline_tasks['standard'][i]: + st.dataframe(current_editing_pipeline_tasks['standard'][i]['sql_query_results']) + + st.button(f"Delete {std_name}", + key="delete_std_"+str(i), + on_click=streamlit_utils.delete_task, + args = ['standard', i], + disabled=disable_std_task_deletion) + + if i != len(current_pipeline_obj["standard"]) - 1: + st.divider() + if len(current_pipeline_obj["standard"]) == 0: + st.write("No transformation in standardization zone") - if "standardized_tables" not in st.session_state: - st.session_state["standardized_tables"] = [] - standardized_tables = st.session_state["standardized_tables"] + st.divider() + st.button('Add Transformation', on_click=streamlit_utils.add_transformation, use_container_width=True) - if "current_generated_std_srv_sqls" not in st.session_state: - st.session_state['current_generated_std_srv_sqls'] = {} - current_generated_std_srv_sqls = st.session_state['current_generated_std_srv_sqls'] - # Get all staged table details - staged_table_names, staged_table_details = streamlit_utils.get_staged_tables() - # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() - - std_name = st.text_input("Transformation Name") - selected_staged_tables = st.multiselect( - 'Choose datasets to do the data transformation', - staged_table_names + standardized_table_names, - key=f'std_involved_tables') - process_requirements = st.text_area("Transformation requirements", key=f"std_transform_requirements") - - generate_transform_sql_col1, generate_transform_sql_col2 = st.columns(2) - with generate_transform_sql_col1: - st.button(f'Generate SQL', - key=f'std_gen_sql', - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"std_gen_transform_sql"}) - - if "std_gen_transform_sql" not in st.session_state: - st.session_state["std_gen_transform_sql"] = False - if st.session_state["std_gen_transform_sql"]: - st.session_state["std_gen_transform_sql"] = False # Reset clicked status - with generate_transform_sql_col2: - with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=staged_table_details + standardized_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=std_name) - try: - process_logic_json = json.loads(process_logic) - std_sql_val = process_logic_json["sql"] - current_generated_std_srv_sqls[std_name] = std_sql_val - except ValueError as e: - st.write(process_logic) +with tab_srv_sql: - with st.expander(std_name, expanded=True): - st.button("Add to Standardization zone", - key=f"gen_transformation_{std_name}", - on_click=streamlit_utils.add_to_std_srv_zone, - args=[f"gen_transformation_{std_name}", std_name, process_requirements, "standard"]) - st.text_input("Invovled tables", value=", ".join(selected_staged_tables), disabled=True) - std_sql = st.text_area(f'Transformation Spark SQL', - key=f'std_transform_sql_{std_name}', - value=current_generated_std_srv_sqls[std_name], - on_change=streamlit_utils.update_sql, - args=[f'std_transform_sql_{std_name}', std_name]) + if "serving" not in current_editing_pipeline_tasks: + current_editing_pipeline_tasks['serving'] = [] + + for i in range(len(current_pipeline_obj["serving"])): + target_name = current_pipeline_obj['serving'][i]['output']['target'] + srv_name_has_dependency = streamlit_utils.check_tables_dependency(target_name) + disable_srv_name_input = False + disable_srv_task_deletion = False + + if srv_name_has_dependency: + disable_srv_name_input = True + disable_srv_task_deletion = True + st.info("""This serving task has been referenced by other tasks! + Please remove relevant dependency before trying to rename the task.""") + + srv_name = st.text_input('Aggregation Name', + key=f'srv_{i}_name', + value=target_name, + disabled=srv_name_has_dependency) + + if srv_name: + st.subheader(srv_name) + with st.expander(srv_name+" Settings", expanded=True): + current_pipeline_obj['serving'][i]['output']['target'] = srv_name + current_pipeline_obj['serving'][i]['name'] = srv_name + + # Get all standardized table details + serving_table_names, serving_table_details = streamlit_utils.get_std_srv_tables('serving') + optional_tables = staged_table_names + standardized_table_names + serving_table_names + if srv_name in optional_tables: + optional_tables.remove(srv_name) # Remove itself from the optional tables list + + if len(current_editing_pipeline_tasks['serving']) == i: + current_editing_pipeline_tasks['serving'].append({}) + current_editing_pipeline_tasks['serving'][i]['target'] = srv_name + + selected_staged_tables = st.multiselect( + 'Choose datasets to do the data aggregation', + options=optional_tables, + default=current_editing_pipeline_tasks['serving'][i].get('involved_tables', None), + on_change=streamlit_utils.update_selected_tables, + key=f'srv_{i}_involved_tables', + args=['serving', i, f'srv_{i}_involved_tables']) + + if 'description' not in current_pipeline_obj['serving'][i]: + current_pipeline_obj['serving'][i]['description'] = "" + process_requirements = st.text_area("Aggregation requirements", + key=f"srv_{i}_aggregate_requirements", + value=current_pipeline_obj['serving'][i]['description']) + current_pipeline_obj['serving'][i]['description'] = process_requirements + + generate_aggregate_sql_col1, generate_aggregate_sql_col2 = st.columns(2) + with generate_aggregate_sql_col1: + st.button(f'Generate SQL', + key=f'srv_{i}_gen_sql', + on_click=streamlit_utils.click_button, + kwargs={"button_name": f"srv_{i}_gen_aggregate_sql"}) + + if len(current_pipeline_obj['serving'][i]['code']['sql']) == 0: + current_pipeline_obj['serving'][i]['code']['sql'].append("") + srv_sql_val = current_pipeline_obj['serving'][i]['code']['sql'][0] + + # Get selected staged table details + selected_table_details = streamlit_utils.get_selected_tables_details(staged_table_details + standardized_table_details + serving_table_details, + selected_staged_tables) + if st.session_state[f"srv_{i}_gen_sql"]: + with generate_aggregate_sql_col2: + with st.spinner('Generating...'): + try: + process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, + industry_contexts=industry_contexts, + involved_tables=selected_table_details, + custom_data_processing_logic=process_requirements, + output_table_name=srv_name) + + process_logic_json = json.loads(process_logic) + srv_sql_val = process_logic_json["sql"] + srv_table_schema = process_logic_json["schema"] + current_editing_pipeline_tasks['serving'][i]['query_results_schema'] = srv_table_schema + current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql_val + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") + + srv_sql = st.text_area(f'Aggregation Spark SQL', + key=f'srv_{i}_aggregate_sql_text_area', + value=srv_sql_val) + + current_pipeline_obj['serving'][i]['code']['sql'][0] = srv_sql + + st.button(f'Run SQL', key=f'run_srv_{i}_sql', on_click=streamlit_utils.run_task, args = [srv_name, "serving", i]) + if 'sql_query_results' in current_editing_pipeline_tasks['serving'][i]: + st.dataframe(current_editing_pipeline_tasks['serving'][i]['sql_query_results']) + + st.button(f"Delete {srv_name}", + key="delete_srv_"+str(i), + on_click=streamlit_utils.delete_task, + args = ['serving', i], + disabled=disable_srv_task_deletion) + + if i != len(current_pipeline_obj["serving"]) - 1: + st.divider() + if len(current_pipeline_obj["serving"]) == 0: + st.write("No aggregation in serving zone") -with tab_srv_sql: - # st.divider() - # st.subheader('Generate data aggregation logic') - - if "serving_tables" not in st.session_state: - st.session_state["serving_tables"] = [] - serving_tables = st.session_state["serving_tables"] - - # Get all standardized table details - standardized_table_names, standardized_table_details = streamlit_utils.get_standardized_tables() - - srv_name = st.text_input("Aggregation Name") - selected_stg_std_tables = st.multiselect( - 'Choose datasets to do the data aggregation', - staged_table_names + standardized_table_names, - key=f'srv_involved_tables') - process_requirements = st.text_area("Aggregation requirements", key=f"srv_aggregate_requirements") - - generate_aggregate_sql_col1, generate_aggregate_sql_col2 = st.columns(2) - with generate_aggregate_sql_col1: - st.button(f'Generate SQL', - key=f'srv_gen_sql', - on_click=streamlit_utils.click_button, - kwargs={"button_name": f"srv_gen_aggregate_sql"}) - - if "srv_gen_aggregate_sql" not in st.session_state: - st.session_state["srv_gen_aggregate_sql"] = False - if st.session_state["srv_gen_aggregate_sql"]: - st.session_state["srv_gen_aggregate_sql"] = False # Reset clicked status - with generate_aggregate_sql_col2: - with st.spinner('Generating...'): - process_logic = openai_api.generate_custom_data_processing_logics(industry_name=industry_name, - industry_contexts=industry_contexts, - involved_tables=staged_table_details, - custom_data_processing_logic=process_requirements, - output_table_name=srv_name) - try: - process_logic_json = json.loads(process_logic) - srv_sql_val = process_logic_json["sql"] - current_generated_std_srv_sqls[srv_name] = srv_sql_val - except ValueError as e: - st.write(process_logic) - - with st.expander(srv_name, expanded=True): - st.button("Add to Serving zone", - key=f"gen_aggregation_{srv_name}", - on_click=streamlit_utils.add_to_std_srv_zone, - args=[f"gen_aggregation_{srv_name}", srv_name, process_requirements, "serving"]) - st.text_input("Invovled tables", value=", ".join(selected_stg_std_tables), disabled=True) - std_sql = st.text_area(f'Transformation Spark SQL', - key=f'srv_aggregate_sql_{srv_name}', - value=current_generated_std_srv_sqls[srv_name], - on_change=streamlit_utils.update_sql, - args=[f'std_transform_sql_{srv_name}', srv_name]) \ No newline at end of file + st.divider() + st.button('Add Aggregation', on_click=streamlit_utils.add_aggregation, use_container_width=True) diff --git a/src/streamlit/streamlit_utils.py b/src/streamlit/streamlit_utils.py index 322d1fe..f411606 100644 --- a/src/streamlit/streamlit_utils.py +++ b/src/streamlit/streamlit_utils.py @@ -1,7 +1,11 @@ import cddp import json +import random import streamlit as st +import tempfile import uuid +import pandas as pd +from time import sleep def get_selected_tables(tables): @@ -30,10 +34,11 @@ def get_selected_table_details(tables, table_name): def get_selected_tables_details(tables, table_names): target_tables = [] - for table_name in table_names: - for table_details in tables: - if table_details["table_name"] == table_name: - target_tables.append(table_details) + if table_names: + for table_name in table_names: + for table_details in tables: + if table_details["table_name"] == table_name: + target_tables.append(table_details) return target_tables @@ -53,12 +58,15 @@ def update_sql(key, table_name): current_generated_std_srv_sqls[table_name] = st.session_state[key] -def add_to_staging_zone(stg_name, stg_desc): +def add_to_staging_zone(gen_table_index, stg_name, stg_desc): pipeline_obj = st.session_state['current_pipeline_obj'] current_generated_sample_data = st.session_state['current_generated_sample_data'] sample_data = current_generated_sample_data.get(stg_name, None) + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + + if st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: # Add to staging zone if checkbox is checked + generated_tables[gen_table_index]["staged_flag"] = True - if st.session_state[stg_name]: # Add to staging zone if checkbox is checked if sample_data: spark = st.session_state["spark"] json_str, schema = cddp.load_sample_data(spark, json.dumps(sample_data), format="json") @@ -69,23 +77,8 @@ def add_to_staging_zone(stg_name, stg_desc): json_schema = {} add_stg_dataset(pipeline_obj, stg_name, json_schema, json_sample_data) - # pipeline_obj["staging"].append({ - # "name": stg_name, - # "description": stg_desc, - # "input": { - # "type": "filestore", - # "format": "json", - # "path": f"/FileStore/cddp_apps/{pipeline_obj['id']}/landing/{task_name}", - # "read-type": "batch" - # }, - # "output": { - # "target": stg_name, - # "type": ["file", "view"] - # }, - # "schema": json_schema, - # "sampleData": json_sample_data - # }) else: # Remove staging task from staging zone if it's unchecked + generated_tables[gen_table_index]["staged_flag"] = False for index, obj in enumerate(pipeline_obj["staging"]): if obj["name"] == stg_name: del pipeline_obj['staging'][index] @@ -142,29 +135,25 @@ def add_std_srv_schema(zone, output_table_name, schema): current_std_srv_tables_schema[zone][output_table_name] = schema -def get_standardized_tables(): +def get_std_srv_tables(task_type): pipeline_obj = st.session_state['current_pipeline_obj'] - standard = pipeline_obj.get("standard", None) - - if "current_std_srv_tables_schema" not in st.session_state: - st.session_state['current_std_srv_tables_schema'] = {} - current_std_srv_tables_schema = st.session_state['current_std_srv_tables_schema'] + tasks = pipeline_obj.get(task_type, None) + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] - if "standard" not in current_std_srv_tables_schema: - current_std_srv_tables_schema["standard"] = {} - - standardized_table_names = [] - standardized_table_details = [] - if standard: - for standardized_table in standard: - std_name = standardized_table["output"]["target"] - standardized_table_names.append(std_name) - standardized_table_details.append({ - "table_name": std_name, - "schema": current_std_srv_tables_schema["standard"].get(std_name, "") - }) + std_srv_table_names = [] + std_srv_table_details = [] + if tasks: + for std_srv_table in tasks: + task_name = std_srv_table["output"]["target"] + std_srv_table_names.append(task_name) + for index, task in enumerate(current_editing_pipeline_tasks[task_type]): + if task['target'] == task_name: + std_srv_table_details.append({ + "table_name": task_name, + "schema": current_editing_pipeline_tasks[task_type][index].get('query_results_schema', '') + }) - return standardized_table_names, standardized_table_names + return std_srv_table_names, std_srv_table_details def create_pipeline(): @@ -181,6 +170,7 @@ def create_pipeline(): } return st.session_state['current_pipeline_obj'] + def add_stg_dataset(pipeline_obj, task_name, schema={}, sample_data=[]): pipeline_obj["staging"].append({ "name": task_name, @@ -197,4 +187,320 @@ def add_stg_dataset(pipeline_obj, task_name, schema={}, sample_data=[]): }, "schema": schema, "sampleData": sample_data - }) \ No newline at end of file + }) + + +def run_task(task_name, stage="standard", index=None): + dataframe = None + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] + try: + spark = st.session_state["spark"] + config = st.session_state["current_pipeline_obj"] + with tempfile.TemporaryDirectory() as tmpdir: + working_dir = tmpdir+"/"+config['name'] + cddp.init(spark, config, working_dir) + cddp.clean_database(spark, config) + cddp.init_database(spark, config) + try: + cddp.init_staging_sample_dataframe(spark, config) + except Exception as e: + print(e) + if stage in config: + for task in config[stage]: + + if task_name == task['name']: + print(f"start {stage} task: "+task_name) + res_df = None + if stage == "standard": + res_df = cddp.start_standard_job(spark, config, task, False, True) + elif stage == "serving": + res_df = cddp.start_serving_job(spark, config, task, False, True) + dataframe = res_df.toPandas() + # print(dataframe) + current_editing_pipeline_tasks[stage][index]['sql_query_results'] = dataframe + + res_schema = res_df.schema.json() + current_editing_pipeline_tasks[stage][index]['query_results_schema'] = json.loads(res_schema) + + except Exception as e: + st.error(f"Cannot run task: {e}") + + return dataframe + + +def add_transformation(): + task_name = "untitled"+str(len(st.session_state["current_pipeline_obj"]["standard"])+1) + st.session_state["current_pipeline_obj"]["standard"].append({ + "name": task_name, + "type": "batch", + "code": { + "lang": "sql", + "sql": [] + }, + "output": { + "target": task_name, + "type": ["file", "view"] + }, + "dependency":[] + }) + + +def delete_task(type, index): + current_pipeline_obj = st.session_state["current_pipeline_obj"] + + if type == "staging": + del current_pipeline_obj['staging'][index] + elif type == "standard": + del current_pipeline_obj['standard'][index] + del st.session_state['current_editing_pipeline_tasks']['standard'][index] + elif type == "serving": + del current_pipeline_obj['serving'][index] + del st.session_state['current_editing_pipeline_tasks']['serving'][index] + elif type == "visualization": + del current_pipeline_obj['visualization'][index] + + st.session_state['current_pipeline_obj'] = current_pipeline_obj + + +def update_selected_tables(task_type, index, multiselect_key): + st.session_state['current_editing_pipeline_tasks'][task_type][index]['involved_tables'] = st.session_state[multiselect_key] + + +def add_aggregation(): + task_name = "untitled"+str(len(st.session_state["current_pipeline_obj"]["serving"])+1) + st.session_state["current_pipeline_obj"]["serving"].append({ + "name": task_name, + "type": "batch", + "code": { + "lang": "sql", + "sql": [] + }, + "output": { + "target": task_name, + "type": ["file", "view"] + }, + "dependency":[] + }) + + +def has_staged_table(): + has_staged_table = False + if "generated_tables" in st.session_state["current_generated_tables"]: + generated_tables = st.session_state["current_generated_tables"]["generated_tables"] + for table in generated_tables: + if "staged_flag" in table and table["staged_flag"]: + has_staged_table = True + break + + return has_staged_table + + +def check_tables_dependency(target_name): + has_dependency = False + current_editing_pipeline_tasks = st.session_state['current_editing_pipeline_tasks'] + current_used_std_tables = [] + current_used_srv_tables = [] + for task in current_editing_pipeline_tasks['standard']: + current_used_std_tables += task.get('involved_tables', []) + for task in current_editing_pipeline_tasks['serving']: + current_used_srv_tables += task.get('involved_tables', []) + + if target_name in current_used_std_tables + current_used_srv_tables: + has_dependency = True + + return has_dependency + + +def widget_on_change(widget_key, index, session_state_key): + current_generated_tables = st.session_state['current_generated_tables']['generated_tables'] + current_generated_tables[index][session_state_key] = st.session_state[widget_key] + st.session_state["has_clicked_generate_tables_btn"] = False + + +def render_table_expander(table, + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api, + key_suffix="init"): + sample_data = None + added_to_stage = current_generated_tables["generated_tables"][gen_table_index].get("staged_flag", False) + expander_label = table["table_name"] + + if added_to_stage: + expander_label += " :heavy_check_mark:" + + with st.expander(expander_label, expanded=added_to_stage): + gen_table_name = table["table_name"] + gen_table_desc = table["table_description"] + stg_name_has_dependency = check_tables_dependency(gen_table_name) + + gen_table_sample_data_count = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_count", 5) + sample_data_requirements_flag = current_generated_tables["generated_tables"][gen_table_index].get("sample_data_requirements_flag", False) + data_requirements = current_generated_tables["generated_tables"][gen_table_index].get("data_requirements", "") + + + columns = table["columns"] + columns_df = pd.DataFrame.from_dict(columns, orient='columns') + + st.write(gen_table_desc) + st.write(columns_df) + + st.write(f"Generate sample data") + rows_count = st.slider("Number of rows", + min_value=5, + max_value=50, + value=gen_table_sample_data_count, + key=f'gen_rows_count_slider_{gen_table_name}_{key_suffix}', + on_change=widget_on_change, + args=[f'gen_rows_count_slider_{gen_table_name}_{key_suffix}', + gen_table_index, + 'sample_data_count'], + disabled=st.session_state["disable_generate_data_widget"]) + + enable_data_requirements = st.toggle("With extra sample data requirements", + value=sample_data_requirements_flag, + key=f'data_requirements_toggle_{gen_table_name}_{key_suffix}', + on_change=widget_on_change, + args=[f'data_requirements_toggle_{gen_table_name}_{key_suffix}', + gen_table_index, + 'sample_data_requirements_flag'], + disabled=st.session_state["disable_generate_data_widget"]) + + if enable_data_requirements: + data_requirements = st.text_area("Extra requirements for sample data", + value=data_requirements, + key=f'data_requirements_text_area_{gen_table_name}_{key_suffix}', + placeholder="Exp: value of column X should follow patterns xxx-xxxx, while x could be A-Z or 0-9", + on_change=widget_on_change, + args=[f'data_requirements_text_area_{gen_table_name}_{key_suffix}', + gen_table_index, + 'data_requirements'], + disabled=st.session_state["disable_generate_data_widget"]) + + generate_sample_data_col1, generate_sample_data_col2 = st.columns(2) + with generate_sample_data_col1: + st.button("Generate Sample Data", + key=f"generate_data_button_{gen_table_name}_{key_suffix}", + on_click=click_button, + kwargs={"button_name": f"generate_sample_data_{gen_table_name}"}, + disabled=st.session_state["disable_generate_data_widget"], + use_container_width=True) + + if f"generate_sample_data_{gen_table_name}" not in st.session_state: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False + + if f"{gen_table_name}_smaple_data_generated" not in st.session_state: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = False + elif st.session_state[f"generate_sample_data_{gen_table_name}"]: + st.session_state[f"generate_sample_data_{gen_table_name}"] = False # Reset clicked status + st.session_state["has_clicked_generate_tables_btn"] = False + + if not st.session_state[f"{gen_table_name}_smaple_data_generated"]: + with generate_sample_data_col2: + with st.spinner('Generating...'): + try: + sample_data = openai_api.generate_sample_data(industry_name, + rows_count, + table, + data_requirements) + st.session_state[f"{gen_table_name}_smaple_data_generated"] = True + # Store generated data to session_state + current_generated_sample_data[gen_table_name] = sample_data + + # Also update current_pipeline_obj if checked check-box before generating sample data + # if sample_data and st.session_state[f"add_to_staging_{gen_table_index}_checkbox"]: + if sample_data and st.session_state.get(f"add_to_staging_{gen_table_index}_checkbox", False): + spark = st.session_state["spark"] + json_str, schema = cddp.load_sample_data(spark, sample_data, format="json") + + for index, dataset in enumerate(current_pipeline_obj['staging']): + if dataset["name"] == gen_table_name: + i = index + + current_pipeline_obj['staging'][i]['sampleData'] = json.loads(json_str) + current_pipeline_obj['staging'][i]['schema'] = json.loads(schema) + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") + + if st.session_state[f"{gen_table_name}_smaple_data_generated"]: + st.session_state[f"{gen_table_name}_smaple_data_generated"] = False # Reset data generated flag + json_sample_data = json.loads(sample_data) + current_generated_sample_data[gen_table_name] = json_sample_data # Save generated data to session_state + + if gen_table_name in current_generated_sample_data: + sample_data_df = pd.DataFrame.from_dict(current_generated_sample_data[gen_table_name], orient='columns') + st.write(sample_data_df) + + if stg_name_has_dependency: + st.info("""This table has been referenced by other tasks! + Please remove relevant dependency before trying to remove it from staging zone.""") + + # Show checkbox only after sample data has been generated + st.checkbox("Add to staging zone", + key=f"add_to_staging_{gen_table_index}_checkbox", + value=added_to_stage, + on_change=add_to_staging_zone, + args=[gen_table_index, gen_table_name, gen_table_desc], + disabled=stg_name_has_dependency) + + +def generate_tables(placeholder, + current_generated_tables, + current_pipeline_obj, + current_generated_sample_data, + industry_name, + industry_contexts, + openai_api): + st.session_state["has_clicked_generate_tables_btn"] = True + gen_tables_count = st.session_state["generate_tables_count"] = random.randint(5, 8) + if "disable_generate_data_widget" not in st.session_state or not st.session_state["disable_generate_data_widget"]: + st.session_state["disable_generate_data_widget"] = True + + # Clean existing table expanders before render new generate ones + placeholder.empty() + sleep(0.01) # Workaround for elements empty/cleaning issue, https://github.com/streamlit/streamlit/issues/5044 + + with placeholder.container(): + spinner_container = st.empty().container() + if not has_staged_table(): # Generate new tables if there's no existing dependency/references found in std/srv tasks + # Clean current_generated_sample_data key in session state + del st.session_state['current_generated_sample_data'] + current_generated_sample_data = {} + + gen_table_index = 0 + generated_tables = "" + current_generated_tables["generated_tables"] = [] + + while gen_table_index < gen_tables_count: + try: + with spinner_container: + with st.spinner(f'Generating {gen_table_index + 1} of {gen_tables_count} tables...'): + table = openai_api.recommend_tables_for_industry_one_at_a_time(industry_name, industry_contexts, generated_tables) + + current_generated_tables["generated_tables"].append(json.loads(table)) + generated_tables = json.dumps(current_generated_tables["generated_tables"], indent=2) + + render_table_expander(json.loads(table), + current_generated_tables, + current_generated_sample_data, + current_pipeline_obj, + gen_table_index, + industry_name, + openai_api) + + gen_table_index += 1 + except ValueError as e: + st.error("Got invalid response from AI Assistant, please try again!") + break + except Exception as e: + st.error("Got error while getting help from AI Assistant, please try again!") + break + + # Clean the temporary generated tables and redraw them again in 2_AI_Assistant.py to enable all widgets inside table expanders. + placeholder.empty() + sleep(0.01)