From ece0bcc15226f3b5e5634bef3c095d9ecc12ea01 Mon Sep 17 00:00:00 2001 From: itholic Date: Fri, 23 Jun 2023 17:52:46 +0900 Subject: [PATCH 1/3] [SPARK-44155][SQL] Adding a dev utility to improve error messages based on LLM --- dev/error_message_refiner.py | 265 +++++++++++++++++++++++++++++++++++ dev/tox.ini | 3 + 2 files changed, 268 insertions(+) create mode 100644 dev/error_message_refiner.py diff --git a/dev/error_message_refiner.py b/dev/error_message_refiner.py new file mode 100644 index 0000000000000..bc274633fc56e --- /dev/null +++ b/dev/error_message_refiner.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Utility for refining error messages based on LLM. + +Usage: + python error_message_refiner.py [--gpt_version=] + +Arguments: + Required. + The name of the error class to refine the messages for. + The list of error classes is located in + `core/src/main/resources/error/error-classes.json`. + +Options: + --gpt_version= Optional. + The version of Chat GPT to use for refining the error messages. + If not provided, the default version("gpt-3.5-turbo") will be used. + +Example usage: + python error_message_refiner.py CANNOT_DECODE_URL --gpt_version=gpt-4 + +Description: + This script refines error messages using the LLM based approach. + It takes the name of the error class as a required argument and, optionally, + allows specifying the version of Chat GPT to use for refining the messages. + + Options: + --gpt_version: Specifies the version of Chat GPT. + If not provided, the default version("gpt-3.5-turbo") will be used. + + Note: + - Ensure that the necessary dependencies are installed before running the script. + - Ensure that the valid API key is entered in the `api-key.txt`. + - The refined error messages will be displayed in the console output. + - To use the gpt-4 model, you need to join the waitlist. Please refer to + https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4 for more details. +""" + +import argparse +import json +import openai +import re +import subprocess +import random +import os +from typing import Tuple, Optional + +from sparktestsupport import SPARK_HOME + +PATH_TO_ERROR_CLASS = f"{SPARK_HOME}/core/src/main/resources/error/error-classes.json" +# Register your own open API key for the environment variable. +# The API key can be obtained from https://platform.openai.com/account/api-keys. +OPENAI_API_KEY = os.environ.get("OPEN_API_KEY") + +openai.api_key = OPENAI_API_KEY + + +def _git_grep_files(search_string: str, exclude: str = None) -> str: + """ + Executes 'git grep' command to search for files containing the given search string. + Returns the file path where the search string is found. + """ + result = subprocess.run( + ["git", "grep", "-l", search_string, "--", f"{SPARK_HOME}/*.scala"], + capture_output=True, + text=True, + ) + output = result.stdout.strip() + + files = output.split("\n") + files = [file for file in files if "Suite" not in file] + if exclude is not None: + files = [file for file in files if exclude not in file] + file = random.choice(files) + return file + + +def _find_function(file_name: str, search_string: str) -> Optional[str]: + """ + Searches for a function in the given file containing the specified search string. + Returns the name of the function if found, otherwise None. + """ + with open(file_name, "r") as file: + content = file.read() + functions = re.findall(r"def\s+(\w+)\s*\(", content) + + for function in functions: + function_content = re.search( + rf"def\s+{re.escape(function)}(?:(?!def).)*?{re.escape(search_string)}", + content, + re.DOTALL, + ) + if function_content and search_string in function_content.group(0): + return function + + return None + + +def _find_func_body(file_name: str, search_string: str) -> Optional[str]: + """ + Searches for a function body in the given file containing the specified search string. + Returns the function body if found, otherwise None. + """ + with open(file_name, "r") as file: + content = file.read() + functions = re.findall(r"def\s+(\w+)\s*\(", content) + + for function in functions: + function_content = re.search( + rf"def\s+{re.escape(function)}(?:(?!def\s).)*?{re.escape(search_string)}", + content, + re.DOTALL, + ) + if function_content and search_string in function_content.group(0): + return function_content.group(0) + + return None + + +def _get_error_function(error_class: str) -> str: + """ + Retrieves the name of the error function that triggers the given error class. + """ + search_string = error_class + matched_file = _git_grep_files(search_string) + err_func = _find_function(matched_file, search_string) + return err_func + + +def _get_source_code(error_class: str) -> str: + """ + Retrieves the source code of a function where the given error class is being invoked. + """ + search_string = error_class + matched_file = _git_grep_files(search_string) + err_func = _find_function(matched_file, search_string) + source_file = _git_grep_files(err_func, exclude=matched_file) + source_code = _find_func_body(source_file, err_func) + return source_code + + +def _get_error_message(error_class: str) -> str: + """ + Returns the error message from the provided error class + listed in core/src/main/resources/error/error-classes.json. + """ + with open(PATH_TO_ERROR_CLASS) as f: + error_classes = json.load(f) + + error_message = error_classes.get(error_class) + if error_message: + return error_message["message"] + else: + return f"Error message not found for class: {error_class}" + + +def ask_chat_gpt(error_class: str, gpt_version: str) -> Tuple[str, str]: + """ + Requests error message improvement from Chat GPT. + Returns a tuple containing the old error message and the refined error message. + """ + old_error_message = " ".join(_get_error_message(error_class)) + error_function = _get_error_function(error_class) + source_code = _get_source_code(error_class) + prompt = f"""I would like to improve the error messages in Apache Spark. +Apache Spark provides error message guidelines, so error messages generated by Apache Spark should follow the following rules: + +1. Include What, Why, and How +Error messages should explicitly address What, Why, and How. +For example, the error message "Unable to generate an encoder for inner class <> without access to the scope that this class was defined in. Try moving this class out of its parent class." follows the guidelines as it includes What, Why, and How. +On the other hand, the error message "Unsupported function name <>" only includes What and does not provide Why and How, so it needs improvement to comply with the guidelines. + +2. Use clear language +Error messages in should adhere to the Diction guide and Wording guide provided in the two tables below in Markdown format. + +- Diction guide: + +| Phrases | When to use | Example | +| ------------------------------------------------------------ | ------------------------------------------------------------ | ----------------------------------------------- | +| Unsupported | The user may reasonably assume that the operation is supported, but it is not. This error may go away in the future if developers add support for the operation. | `Data type <> is unsupported.` | +| Invalid / Not allowed / Unexpected | The user made a mistake when specifying an operation. The message should inform the user of how to resolve the error. | `Array has size <>, index <> is invalid.` | +| `Found <> generators for the clause <>. Only one generator is allowed.` | | | +| `Found an unexpected state format version <>. Expected versions 1 or 2.` | | | +| Failed to | The system encountered an unexpected error that cannot be reasonably attributed to user error. | `Failed to compile <>.` | +| Cannot | Any time, preferably only if one of the above alternatives does not apply. | `Cannot generate code for unsupported type <>.` | + +- Wording guide: + +| Best practice | Before | After | +| ------------------------------------------------------------ | ------------------------------------------------------------ | ------------------------------------------------------------ | +| Use active voice | `DataType <> is not supported by <>.` | `<> does not support datatype <>.` | +| Avoid time-based statements, such as promises of future support | `Pandas UDF aggregate expressions are currently not supported in pivot.` | `Pivot does not support Pandas UDF aggregate expressions.` | +| `Parquet type not yet supported: <>.` | `<> does not support Parquet type.` | | +| Use the present tense to describe the error and provide suggestions | `Couldn't find the reference column for <> at <>.` | `Cannot find the reference column for <> at <>.` | +| `Join strategy hint parameter should be an identifier or string but was <>.` | `Cannot use join strategy hint parameter <>. Use a table name or identifier to specify the parameter.` | | +| Provide concrete examples if the resolution is unclear | `<> Hint expects a partition number as a parameter.` | `<> Hint expects a partition number as a parameter. For example, specify 3 partitions with <>(3).` | +| Avoid sounding accusatory, judgmental, or insulting | `You must specify an amount for <>.` | `<> cannot be empty. Specify an amount for <>.` | +| Be direct | `LEGACY store assignment policy is disallowed in Spark data source V2. Please set the configuration spark.sql.storeAssignmentPolicy to other values.` | `Spark data source V2 does not allow the LEGACY store assignment policy. Set the configuration spark.sql.storeAssignment to ANSI or STRICT.` | +| Do not use programming jargon in user-facing errors | `RENAME TABLE source and destination databases do not match: '<>' != '<>'.` | `RENAME TABLE source and destination databases do not match. The source database is <>, but the destination database is <>.` | + +To adhere to the above guidelines, please improve the following error message: "{old_error_message}" +Please note that the angle brackets in the error message represent the placeholder for the error message parameter, so they should not be modified or removed. +For more detail, the error message is triggered through a error function named "{error_function}", and the source code that actually calls the error function is as follows: + +{source_code} + +When improving the error, please also refer to the source code provided above to provide more detailed context to users.""" + try: + response = openai.ChatCompletion.create( + model=gpt_version, + messages=[ + { + "role": "system", + "content": "You are a engineer who want to improve the error messages " + "for Apache Spark users.", + }, + {"role": "user", "content": prompt}, + ], + ) + except openai.error.AuthenticationError: + raise openai.error.AuthenticationError( + "Please verify if the API key is entered correctly in `dev/api_key.txt`." + ) + except openai.error.InvalidRequestError as e: + if "gpt-4" in str(e): + raise openai.error.AuthenticationError( + "To use the gpt-4 model, you need to join the waitlist. " + "Please refer to " + "https://help.openai.com/en/articles/7102672-how-can-i-access-gpt-4 " + "for more details." + ) + raise e + result = "" + for choice in response.choices: + result += choice.message.content + return old_error_message, result + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("error_class", type=str) + parser.add_argument("--gpt_version", type=str, default="gpt-3.5-turbo") + + args = parser.parse_args() + + old_error_message, new_error_message = ask_chat_gpt(args.error_class, args.gpt_version) + print(f"[{args.error_class}]\nBefore: {old_error_message}\nAfter: {new_error_message}") diff --git a/dev/tox.ini b/dev/tox.ini index e2a77786ed4cc..389bb9d63da25 100644 --- a/dev/tox.ini +++ b/dev/tox.ini @@ -42,6 +42,9 @@ per-file-ignores = python/pyspark/streaming/tests/*.py: F403, python/pyspark/tests/*.py: F403, python/pyspark/testing/*: F401 + # This script includes a prompt for ChatGPT, + # so we make an exception for E501 in order not to impose a limit on the length of the prompt. + dev/error_message_refiner.py: E501, exclude = */target/*, docs/.local_ruby_bundle/, From 0ee23726ed9c06c1f0040d8f995d71efb47f66b7 Mon Sep 17 00:00:00 2001 From: itholic Date: Wed, 12 Jul 2023 11:15:38 +0900 Subject: [PATCH 2/3] adjusted comments --- dev/error_message_refiner.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/dev/error_message_refiner.py b/dev/error_message_refiner.py index bc274633fc56e..6ec09c18b0548 100644 --- a/dev/error_message_refiner.py +++ b/dev/error_message_refiner.py @@ -21,21 +21,21 @@ Utility for refining error messages based on LLM. Usage: - python error_message_refiner.py [--gpt_version=] + python error_message_refiner.py [--model_name=] Arguments: Required. The name of the error class to refine the messages for. The list of error classes is located in - `core/src/main/resources/error/error-classes.json`. + `common/utils/src/main/resources/error/error-classes.json`. Options: - --gpt_version= Optional. + --model_name= Optional. The version of Chat GPT to use for refining the error messages. If not provided, the default version("gpt-3.5-turbo") will be used. Example usage: - python error_message_refiner.py CANNOT_DECODE_URL --gpt_version=gpt-4 + python error_message_refiner.py CANNOT_DECODE_URL --model_name=gpt-4 Description: This script refines error messages using the LLM based approach. @@ -43,7 +43,7 @@ allows specifying the version of Chat GPT to use for refining the messages. Options: - --gpt_version: Specifies the version of Chat GPT. + --model_name: Specifies the version of Chat GPT. If not provided, the default version("gpt-3.5-turbo") will be used. Note: @@ -65,10 +65,10 @@ from sparktestsupport import SPARK_HOME -PATH_TO_ERROR_CLASS = f"{SPARK_HOME}/core/src/main/resources/error/error-classes.json" +PATH_TO_ERROR_CLASS = f"{SPARK_HOME}/common/utils/src/main/resources/error/error-classes.json" # Register your own open API key for the environment variable. # The API key can be obtained from https://platform.openai.com/account/api-keys. -OPENAI_API_KEY = os.environ.get("OPEN_API_KEY") +OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") openai.api_key = OPENAI_API_KEY @@ -172,7 +172,7 @@ def _get_error_message(error_class: str) -> str: return f"Error message not found for class: {error_class}" -def ask_chat_gpt(error_class: str, gpt_version: str) -> Tuple[str, str]: +def ask_chat_gpt(error_class: str, model_name: str) -> Tuple[str, str]: """ Requests error message improvement from Chat GPT. Returns a tuple containing the old error message and the refined error message. @@ -225,7 +225,7 @@ def ask_chat_gpt(error_class: str, gpt_version: str) -> Tuple[str, str]: When improving the error, please also refer to the source code provided above to provide more detailed context to users.""" try: response = openai.ChatCompletion.create( - model=gpt_version, + model=model_name, messages=[ { "role": "system", @@ -237,7 +237,7 @@ def ask_chat_gpt(error_class: str, gpt_version: str) -> Tuple[str, str]: ) except openai.error.AuthenticationError: raise openai.error.AuthenticationError( - "Please verify if the API key is entered correctly in `dev/api_key.txt`." + "Please verify if the API key is set correctly in `OPENAI_API_KEY`." ) except openai.error.InvalidRequestError as e: if "gpt-4" in str(e): @@ -257,9 +257,9 @@ def ask_chat_gpt(error_class: str, gpt_version: str) -> Tuple[str, str]: if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("error_class", type=str) - parser.add_argument("--gpt_version", type=str, default="gpt-3.5-turbo") + parser.add_argument("--model_name", type=str, default="gpt-3.5-turbo") args = parser.parse_args() - old_error_message, new_error_message = ask_chat_gpt(args.error_class, args.gpt_version) + old_error_message, new_error_message = ask_chat_gpt(args.error_class, args.model_name) print(f"[{args.error_class}]\nBefore: {old_error_message}\nAfter: {new_error_message}") From accb87a6d58ae0e068e4aa943010e2657a2b6d27 Mon Sep 17 00:00:00 2001 From: Haejoon Lee <44108233+itholic@users.noreply.github.com> Date: Thu, 20 Jul 2023 10:45:15 +0900 Subject: [PATCH 3/3] Update dev/error_message_refiner.py Co-authored-by: Amanda Liu --- dev/error_message_refiner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/error_message_refiner.py b/dev/error_message_refiner.py index 6ec09c18b0548..0afebe9c356bb 100644 --- a/dev/error_message_refiner.py +++ b/dev/error_message_refiner.py @@ -174,7 +174,7 @@ def _get_error_message(error_class: str) -> str: def ask_chat_gpt(error_class: str, model_name: str) -> Tuple[str, str]: """ - Requests error message improvement from Chat GPT. + Requests error message improvement from GPT API Returns a tuple containing the old error message and the refined error message. """ old_error_message = " ".join(_get_error_message(error_class))