From e49273b387fa30ad054540e9cbd5e0411d2c597d Mon Sep 17 00:00:00 2001
From: Jiali Cheng <20452970+chengjiali@users.noreply.github.com>
Date: Thu, 1 Aug 2024 08:51:02 +0800
Subject: [PATCH] add GPQA, SciBench, SciEval

---
 eval/Subject/gpqa/evaluate_gpqa_cot.py        | 149 +++++++
 eval/Subject/gpqa/util.py                     | 396 ++++++++++++++++++
 .../Subject/scibench/evaluate_scibench_cot.py | 129 ++++++
 eval/Subject/scibench/util.py                 | 365 ++++++++++++++++
 eval/Subject/scieval/SciEval                  |   1 +
 eval/Subject/scieval/evaluate_scieval_cot.py  | 134 ++++++
 eval/Subject/scieval/util.py                  | 358 ++++++++++++++++
 7 files changed, 1532 insertions(+)
 create mode 100644 eval/Subject/gpqa/evaluate_gpqa_cot.py
 create mode 100644 eval/Subject/gpqa/util.py
 create mode 100644 eval/Subject/scibench/evaluate_scibench_cot.py
 create mode 100644 eval/Subject/scibench/util.py
 create mode 160000 eval/Subject/scieval/SciEval
 create mode 100644 eval/Subject/scieval/evaluate_scieval_cot.py
 create mode 100644 eval/Subject/scieval/util.py

diff --git a/eval/Subject/gpqa/evaluate_gpqa_cot.py b/eval/Subject/gpqa/evaluate_gpqa_cot.py
new file mode 100644
index 0000000..23f55e4
--- /dev/null
+++ b/eval/Subject/gpqa/evaluate_gpqa_cot.py
@@ -0,0 +1,149 @@
+from typing import Dict, Any
+import os
+import json
+from tqdm import tqdm
+from datetime import datetime
+import openai
+from time import sleep
+import argparse
+from util import *
+from datasets import Dataset, load_dataset
+import pandas as pd
+from vllm import LLM, SamplingParams
+import torch
+import time
+from collections import namedtuple
+
+Example = namedtuple('Example', ['question', 'choice1', 'choice2', 'choice3', 'choice4', 'correct_index'])
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", type=str, default="./eurus-7b-kto-hf")
+# meta-llama/Llama-2-7b-chat-hf
+# parser.add_argument("--input_data", type=str, default="./theorem_qa.json")
+parser.add_argument("--save_dir", type=str, default="./")
+parser.add_argument("--model_type", type=str, default='mistral')
+
+args = parser.parse_args()
+
+
+llm = LLM(
+    model=args.model,
+    trust_remote_code=True,
+    tensor_parallel_size=torch.cuda.device_count(),
+)
+
+def generate_sample_batch(question_list):
+    sampling_params = SamplingParams(max_tokens=1024,
+                                     temperature=0.0,
+                                     top_p=1,
+                                     n=1,
+                                     stop=["question:"],)
+    outputs = llm.generate(question_list, sampling_params, use_tqdm=False)
+    completions = [output.outputs[0].text for output in outputs]
+    return completions
+
+
+# def create_reader_request(example: Dict[str, Any]) -> str:
+#     string = f"Question: {example['Question']}"
+#     return string
+
+
+# Copied from: https://github.com/idavidrein/gpqa/blob/main/baselines/utils.py#L197
+def load_examples(data):
+    import random
+    seed = 0
+    random.seed(seed)
+
+    def shuffle_choices_and_create_example(row) -> Example:
+        list_choices = [row['Incorrect Answer 1'], row['Incorrect Answer 2'], row['Incorrect Answer 3'], row['Correct Answer']]
+        random.shuffle(list_choices)
+        example = Example(row.Question, list_choices[0], list_choices[1], list_choices[2], list_choices[3],
+                          list_choices.index(row['Correct Answer']))
+        return example
+
+    return [shuffle_choices_and_create_example(row) for _, row in data.iterrows()]
+
+
+from transformers import AutoTokenizer
+tokenizer = AutoTokenizer.from_pretrained(args.model)
+def make_conv(example, model_type, prev_output=None):
+    prompt_round1 = f"What is the correct answer to this question: {example.question}"
+    prompt_round1 += f"\n\nChoices:\n(A) {example.choice1}\n(B) {example.choice2}\n(C) {example.choice3}\n(D) {example.choice4}"
+    prompt_round1 += "\nLet's think step-by-step:"
+    prompt_round2 = '\n\nBased on the above, what is the single, most likely answer choice? Answer in the format \"The correct answer is (insert answer here)\".'
+
+    if prev_output is None:
+        prompt = prompt_round1
+    else:
+        prompt = prompt_round1 + prev_output + prompt_round2
+
+    msg = [{"role": "user", "content": prompt},]
+    out = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
+    return out
+
+
+if __name__ == "__main__":
+
+    test_set = load_dataset('Idavidrein/gpqa', 'gpqa_main')['train'].to_pandas()
+    test_set['example'] = load_examples(test_set)
+    test_set["prompt_round1"] = test_set['example'].apply(lambda example: make_conv(example, args.model_type))
+
+    # Adapted from: https://github.com/idavidrein/gpqa/blob/main/baselines/utils.py#L48
+    # First round: Think step-by-step
+    test_set["completion_round1"] = generate_sample_batch(test_set["prompt_round1"].tolist())
+
+    # Second round: Extract final answer
+    test_set["prompt_round2"] = test_set.apply(lambda row: make_conv(row['example'], args.model_type, row['completion_round1']), axis=1)
+    test_set["completion"] = generate_sample_batch(test_set["prompt_round2"].tolist())
+    test_set.to_json(os.path.join(args.save_dir, "gpqa_completion.json"))
+
+    answered_set = dict()
+    correct, wrong = 0, 0
+
+    output_filename = os.path.join(args.save_dir, "gpqa_output.json")
+    writer = open(output_filename, 'w')
+    accuracy = []
+    index_mapping = ['(a)', '(b)', '(c)', '(d)']
+
+    for example in test_set.to_dict(orient="records"):
+
+        if answered_set and example['Record ID'] in answered_set:
+            writer.write(answered_set[example['Record ID']] + '\n')
+            continue
+
+        result = example["completion"]
+        _, prediction = match_answer(result)
+        prediction = postprocess_number(prediction)
+        reference = index_mapping[example["example"].correct_index]
+
+        verifier = GPQATask(id=example["Record ID"],
+                            prompt=example["prompt_round2"],
+                            reference=reference,
+                            answer_type='Option')
+        acc = verifier.success(prediction)
+        tmp = {
+            'id': example['Record ID'],
+            'question': example['prompt_round2'],
+            'prediction': prediction,
+            'answer': reference,
+            'rationale': result,
+            'answer_type': 'Option',
+            "is_correct": acc,
+        }
+        writer.write(json.dumps(tmp) + '\n')
+
+        accuracy.append(acc)
+
+    writer.close()
+    print()
+
+    accuracy = sum([1 if acc else 0 for acc in accuracy]) / len(accuracy)
+    with open(os.path.join(args.save_dir, "result.txt"), "w") as f:
+        print({"accuracy": accuracy * 100})
+        print({"accuracy": accuracy * 100}, file=f)
+
diff --git a/eval/Subject/gpqa/util.py b/eval/Subject/gpqa/util.py
new file mode 100644
index 0000000..573660a
--- /dev/null
+++ b/eval/Subject/gpqa/util.py
@@ -0,0 +1,396 @@
+# Copied from https://github.com/xingyaoww/mint-bench/blob/main/mint/tasks/reasoning.py
+import numpy as np
+from inspect import isfunction
+from typing import *
+
+import re
+import ast
+from sympy import Rational
+import traceback
+
+
+def extract_code(result: str) -> str:
+    lines = []
+    within_func = False
+    for line in result.split('\n'):
+        if line.startswith('`'):
+            continue
+
+        if line.startswith('def '):
+            within_func = True
+
+        if line.startswith('import') or line.startswith('from'):
+            lines.append(line)
+
+        if within_func:
+            lines.append(line)
+
+        if line.startswith('    return') or line.startswith('  return'):
+            within_func = False
+
+    for line_no in range(len(lines) - 1, -1, -1):
+        if 'return ' not in lines[line_no]:
+            del lines[line_no]
+        else:
break + result = '\n'.join(lines) + return result + +def match_answer(response): + is_matched = False + ans_marker = 'The correct answer is :\n' + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + + ans_marker = 'The correct answer is : ' + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + + # Find boxed + ans_boxed = _last_boxed_only_string(response) + if ans_boxed: + is_matched = True + response = ans_boxed + + # Grade + return is_matched, response + +def _last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + left_brace_idx = None + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if left_brace_idx is None: + left_brace_idx = i + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + + i += 1 + + if left_brace_idx is None or right_brace_idx is None: + return None + + return string[left_brace_idx + 1: right_brace_idx].strip() + +def extract_answer(result: str) -> str: + prediction = result.strip().strip('\n').split('\n')[-1] + tmp = '' + for entry in prediction.split(' ')[::-1]: + if entry == 'is' or entry == 'be' or entry == 'are' or entry.endswith(':'): + break + tmp = entry + ' ' + tmp + prediction = tmp.strip().strip('.') + return prediction + + +def postprocess_number(prediction): + if isinstance(prediction, set): + prediction = list(prediction) + elif isinstance(prediction, np.complex128): + prediction = prediction.real + elif isinstance(prediction, np.ndarray): + prediction = prediction.tolist() + elif isinstance(prediction, complex): + prediction = prediction.real + elif isinstance(prediction, list): + prediction = [float(x) for x in prediction] + elif 'sympy' in str(type(prediction)): + prediction = float(prediction) + elif isfunction(prediction): + prediction = None + + return prediction + + +impossible_questions = ['jianyu_xu/integer_programming_1.json', 'tonyxia/totient6.json'] + + + +def compare_two_numbers(p, gt): + if isinstance(p, int) or isinstance(p, float): + pass + elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str): + return False + elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict): + return False + else: + return False + + if isinstance(gt, float): + return within_eps(pred=p, gt=gt) + else: + return round(p) == gt + + +def compare_two_list(pred, gt): + if not isinstance(pred, list): + return False + elif len(pred) != len(gt): + return False + elif any([not isinstance(x, (int, float)) for x in pred]): + return False + else: + pred = sorted(pred) + gt = sorted(gt) + return all([compare_two_numbers(p, g) for p, g in zip(pred, gt)]) + + +def within_eps(pred: float, gt: float): + eps = abs(gt) * 0.04 + if pred >= gt - eps and pred <= gt + eps: + return True + else: + return False + + +def parse_number_list(s: str): + # Check if the string is a valid list by trying to parse it + parsed_list = ast.literal_eval(s) + return parsed_list + + +def is_number(string): + pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$" + match = re.match(pattern, string) + return bool(match) + + +def 
is_scientific_number(string): + pattern = r"^[-+]?\d+(\.\d+)?e[-]?\d+$" + match = re.match(pattern, string) + return bool(match) + + +def contain_num_and_str(string): + pattern_str = r"[a-zA-Z]" + pattern_num = r"[0-9]" + return bool(re.search(pattern_str, string) and re.search(pattern_num, string)) + + +def is_number(string): + pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$" + match = re.match(pattern, string) + return bool(match) + + +def is_scientific_number(string): + pattern = r"^[-+]?\d+(\.\d+)?e[-]?\d+$" + match = re.match(pattern, string) + return bool(match) + +class GPQATask(): + """Subclass of Task for multiple choice tasks.""" + + task_name = "reasoning" + + def __init__(self, id, prompt: str, reference: str, **kwargs): + self._id = id + self._prompt = ( + "Answer the following question with a number, a list of numbers or True or False. " + + prompt.strip() + ) + self._reference = reference + self._answer_type = kwargs.get("answer_type") + + def extract_answer(self, solution: str) -> Optional[Any]: + """Extract the answer from the given solution.""" + prediction = solution.lower() + # Following the preprocessing steps from TheoremQA + # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170 + + # Preprocessing the string [Stage 1] + if not isinstance(prediction, str): + prediction = str(prediction) if prediction is not None else "" + + # Replace special tokens + if "=" in prediction: + prediction = prediction.split("=")[-1].strip() + if "≈" in prediction: + prediction = prediction.split("≈")[-1].strip() + if "`" in prediction: + prediction = prediction.replace("`", "") + if "$" in prediction: + prediction = prediction.replace("$", "") + if "°" in prediction: + prediction = prediction.replace("°", "") + + # Detect the boolean keyword in the generation + if prediction.lower() in ["true", "yes", "false", "no"]: + if prediction.lower() == "true" or prediction.lower() == "yes": + prediction = "True" + else: + prediction = "False" + if "True" in prediction or "False" in prediction: + prediction = "True" if "True" in prediction else "False" + + # Detect the approximation keyword + if "approximately" in prediction: + prediction = prediction.replace("approximately", "").strip() + if " or " in prediction: + prediction = prediction.split(" or ")[0] + + # Drop the units before and after the number + if re.match(r"[-+]?(?:[\d,]*\.*\d+) [^0-9 ]+$", prediction): + prediction = re.search( + r"([-+]?(?:[\d,]*\.*\d+)) [^0-9 ]+$", prediction + ).group(1) + if re.match(r"[^0-9 ]+ [-+]?(?:[\d,]*\.*\d+)$", prediction): + prediction = re.search( + r"[^0-9 ]+ ([-+]?(?:[\d,]*\.*\d+))$", prediction + ).group(1) + if re.match(r"[-+]?(?:[\d,]*\.*\d+)[^\d]{1,2}$", prediction): + prediction = re.search( + r"([-+]?(?:[\d,]*\.*\d+))[^\d]{1,2}$", prediction + ).group(1) + if re.match(r"[^-+\d]{1,2}(?:[\d,]*\.*\d+)$", prediction): + prediction = re.search( + r"[^-+\d]{1,2}((?:[\d,]*\.*\d+))$", prediction + ).group(1) + + # Preprocessing the number [Stage 1] + if "10^" in prediction: + prediction = re.sub(r"10\^(-?\d+)", r"math.pow(10, \1)", prediction) + if " x " in prediction: + prediction = prediction.replace(" x ", "*") + if " × " in prediction: + prediction = prediction.replace(" × ", "*") + if is_number(prediction): + prediction = prediction.replace(",", "") + + # Preprocessing the option [Stage 3] + if "a)" in prediction or "a )" in prediction or prediction.lower().strip() == "a": + prediction = "(a)" + if "b)" in prediction or "b )" in 
prediction or prediction.lower().strip() == "b": + prediction = "(b)" + if "c)" in prediction or "c )" in prediction or prediction.lower().strip() == "c": + prediction = "(c)" + if "d)" in prediction or "d )" in prediction or prediction.lower().strip() == "d": + prediction = "(d)" + + if ( + "(a)" in prediction + or "(b)" in prediction + or "(c)" in prediction + or "(d)" in prediction + ): + prediction = '"' + re.search(r"\([a-d]\)", prediction).group(0) + '"' + + # # If the prediction is empty, use dummy '0' + # if not prediction: + # prediction = "0" + + # # Converting the string answer to a number/list/bool/option + # try: + # prediction = eval(prediction) + # except Exception as e: + # ''' + # print( + # f"[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}" + # ) + # ''' + # return None # failed to convert the answer + # If the prediction is empty, use dummy '0' + if prediction: + # Converting the string answer to a number/list/bool/option + try: + prediction = eval(prediction) + except Exception as e: + ''' + print( + f"[TASK] Failed to convert the answer: {prediction}\n{traceback.format_exc()}" + ) + ''' + return None # failed to convert the answer + + # Performing common type conversion + if isinstance(prediction, (set, tuple)): + prediction = list(prediction) + if isinstance(prediction[0], complex): + prediction = [tmp.real for tmp in prediction] + elif isinstance(prediction[0], Rational): + prediction = [float(tmp) for tmp in prediction] + elif isinstance(prediction, np.ndarray): + prediction = prediction.tolist() + else: + if isinstance(prediction, complex): + prediction = prediction.real + elif isinstance(prediction, Rational): + prediction = float(prediction) + + return prediction + + def compare_w_digits(self, reference: str, answer: str) -> bool: + if reference.isdigit() and answer.isdigit(): + return abs(float(reference) - float(answer)) <= 0.05 * float(reference) + else: + return reference in answer + + def success(self, solution: str) -> bool: + """This checks whether the given solution can complete the current task.""" + # Follow the implementation from TheoremQA + # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1 + if not solution: # empty + return False + prediction = self.extract_answer(solution) + if prediction is None or prediction == "": + return False + # LOGGER.info(f"GPQA Parsed Prediction: {prediction}") + answer_type = self._answer_type + gt = self._reference + + if isinstance(prediction, (str, int, float)) or isinstance(prediction, list): + # Comparing prediction against the reference + if answer_type in ["bool", "option", "Option"]: + cur_correct = int(prediction == f"({gt})") or int(prediction == gt) + elif answer_type == "integer": + cur_correct = int(compare_two_numbers(prediction, gt)) + elif answer_type == "float": + cur_correct = int(compare_two_numbers(prediction, gt)) + elif answer_type in ["list of integer", "list of float"]: + cur_correct = int(compare_two_list(prediction, gt)) + else: + cur_correct = 0 + return bool(cur_correct) + + + def extract_options(self, prompt: str) -> dict: + # Find the possible option separators (comma, semicolon, or parentheses) + prompt = prompt.split("Options: ")[-1] + # Extract the options using the delimiter + options_match = prompt.split(" , ") + options = {} + for i in range(len(options_match)): + option = options_match[i].strip("[]' ") + option = option.split(")") + letter = option[0].lower().strip() + content = ( + 
option[1] + .lower() + .strip(".") + .replace(". Which option is correct?", "") + .replace(". Which one is correct?", "") + .strip() + ) + options.update({letter: content}) + return options + diff --git a/eval/Subject/scibench/evaluate_scibench_cot.py b/eval/Subject/scibench/evaluate_scibench_cot.py new file mode 100644 index 0000000..0e9bb6c --- /dev/null +++ b/eval/Subject/scibench/evaluate_scibench_cot.py @@ -0,0 +1,129 @@ +from typing import Dict, Any +import os +import json +from tqdm import tqdm +from datetime import datetime +import openai +from time import sleep +import argparse +from util import * +from datasets import Dataset, load_dataset +import pandas as pd +from vllm import LLM, SamplingParams +import torch +import time +from collections import namedtuple + +Example = namedtuple('Example', ['question', 'choice1', 'choice2', 'choice3', 'choice4', 'correct_index']) + + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="./eurus-7b-kto-hf") +# meta-llama/Llama-2-7b-chat-hf +# parser.add_argument("--input_data", type=str, default="./theorem_qa.json") +parser.add_argument("--save_dir", type=str, default="./") +parser.add_argument("--model_type", type=str, default='mistral') + +args = parser.parse_args() + + +llm = LLM( + model=args.model, + trust_remote_code=True, + tensor_parallel_size=torch.cuda.device_count(), +) + +def generate_sample_batch(question_list): + sampling_params = SamplingParams(max_tokens=1024, + temperature=0.0, + top_p=1, + n=1, + stop=["question:"],) + outputs = llm.generate(question_list, sampling_params, use_tqdm=False) + completions = [output.outputs[0].text for output in outputs] + return completions + + +# def create_reader_request(example: Dict[str, Any]) -> str: +# string = f"Question: {example['Question']}" +# return string + + +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained(args.model) +def make_conv(question, model_type, prev_output=None): + # Prompt template copied from: https://github.com/mandyyyyii/scibench/blob/main/eval/eval_zeroCot.py#L11 + prompt_round1 = f"Q: " + question + "\n" + "A: Let's think step by step." + prompt_round2 = 'Therefore, the answer is \\boxed{Your Answer}.' 
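+    # Round 1 only asks the model to reason step by step; round 2 re-feeds the
+    # model's own rationale and asks it to commit to a final \boxed{} answer.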
+ + if prev_output is None: + prompt = prompt_round1 + else: + prompt = prompt_round1 + prev_output + prompt_round2 + + msg = [{"role": "user", "content": prompt},] + out = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) + return out + + + +if __name__ == "__main__": + + + test_set = load_dataset('xw27/scibench')['train'].to_pandas() + test_set["prompt_round1"] = test_set['problem_text'].apply(lambda question: make_conv(question, args.model_type)) + + # Adapted from: https://github.com/mandyyyyii/scibench/blob/main/eval/eval_zeroCot.py#L90 + # First round: Think step-by-step + test_set["completion_round1"] = generate_sample_batch(test_set["prompt_round1"].tolist()) + + # Second round: Extract final answer + test_set["prompt_round2"] = test_set.apply(lambda row: make_conv(row['problem_text'], args.model_type, row['completion_round1']), axis=1) + test_set["completion"] = generate_sample_batch(test_set["prompt_round2"].tolist()) + test_set.to_json(os.path.join(args.save_dir, "scibench_completion.json")) + + answered_set = dict() + correct, wrong = 0, 0 + + output_filename = os.path.join(args.save_dir, "scibench_output.json") + writer = open(output_filename, 'w') + accuracy = [] + + for idx, example in enumerate(test_set.to_dict(orient="records")): + + if answered_set and idx in answered_set: + writer.write(answered_set[idx] + '\n') + continue + + result = example["completion"] + _, prediction = match_answer(result) + prediction = postprocess_number(prediction) + + verifier = SciBenchTask(id=idx, + prompt=example["prompt_round2"], + reference=example["answer_number"], + answer_type='float') + acc = verifier.success(prediction) + tmp = { + 'id': idx, + 'question': example['prompt_round2'], + 'prediction': prediction, + 'answer': example["answer_number"], + 'rationale': result, + 'answer_type': 'float', + "is_correct": acc, + } + writer.write(json.dumps(tmp) + '\n') + + + accuracy.append(acc) + + writer.close() + print() + + accuracy = sum([1 if acc else 0 for acc in accuracy]) / len(accuracy) + with open(os.path.join(args.save_dir, "result.txt"), "w") as f: + print({"accuracy": accuracy * 100}) + print({"accuracy": accuracy * 100}, file=f) + diff --git a/eval/Subject/scibench/util.py b/eval/Subject/scibench/util.py new file mode 100644 index 0000000..4ca1b7a --- /dev/null +++ b/eval/Subject/scibench/util.py @@ -0,0 +1,365 @@ +# copy from https://github.com/xingyaoww/mint-bench/blob/main/mint/tasks/reasoning.py +import numpy as np +from inspect import isfunction +from typing import * + +import json +import re +import ast +from sympy import Rational +import traceback + + +def extract_code(result: str) -> str: + lines = [] + within_func = False + for line in result.split('\n'): + if line.startswith('`'): + continue + + if line.startswith('def '): + within_func = True + + if line.startswith('import') or line.startswith('from'): + lines.append(line) + + if within_func: + lines.append(line) + + if line.startswith(' return') or line.startswith(' return'): + within_func = False + + for line_no in range(len(lines) - 1, -1, -1): + if 'return ' not in lines[line_no]: + del lines[line_no] + else: + break + result = '\n'.join(lines) + return result + +def match_answer(response): + is_matched = False + ans_marker = 'the answer is\n' + ans_idx = response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + + ans_marker = 'the answer is ' + ans_idx = 
response.lower().rfind(ans_marker) + if ans_idx != -1: + is_matched = True + response = response[ans_idx + len(ans_marker):].strip() + if response.endswith("\n"): + response = response[:-2] + print(response) + + # Find boxed + ans_boxed = _last_boxed_only_string(response) + if ans_boxed: + is_matched = True + response = ans_boxed + + # Grade + return is_matched, response + +def _last_boxed_only_string(string): + idx = string.rfind("\\boxed") + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + left_brace_idx = None + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if left_brace_idx is None: + left_brace_idx = i + elif string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + + i += 1 + + if left_brace_idx is None or right_brace_idx is None: + return None + + return string[left_brace_idx + 1: right_brace_idx].strip() + +def extract_answer(result: str) -> str: + prediction = result.strip().strip('\n').split('\n')[-1] + tmp = '' + for entry in prediction.split(' ')[::-1]: + if entry == 'is' or entry == 'be' or entry == 'are' or entry.endswith(':'): + break + tmp = entry + ' ' + tmp + prediction = tmp.strip().strip('.') + return prediction + + +def postprocess_number(prediction): + if isinstance(prediction, set): + prediction = list(prediction) + elif isinstance(prediction, np.complex128): + prediction = prediction.real + elif isinstance(prediction, np.ndarray): + prediction = prediction.tolist() + elif isinstance(prediction, complex): + prediction = prediction.real + elif isinstance(prediction, list): + prediction = [float(x) for x in prediction] + elif 'sympy' in str(type(prediction)): + prediction = float(prediction) + elif isfunction(prediction): + prediction = None + + return prediction + + +impossible_questions = ['jianyu_xu/integer_programming_1.json', 'tonyxia/totient6.json'] + + + +def compare_two_numbers(p, gt): + if isinstance(p, int) or isinstance(p, float): + pass + elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str): + return False + elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict): + return False + else: + return False + + if isinstance(gt, float): + return within_eps(pred=p, gt=gt) + else: + return round(p) == gt + + +def compare_two_list(pred, gt): + if not isinstance(pred, list): + return False + elif len(pred) != len(gt): + return False + elif any([not isinstance(x, (int, float)) for x in pred]): + return False + else: + pred = sorted(pred) + gt = sorted(gt) + return all([compare_two_numbers(p, g) for p, g in zip(pred, gt)]) + + +def within_eps(pred: float, gt: float): + eps = abs(gt) * 0.04 + if pred >= gt - eps and pred <= gt + eps: + return True + else: + return False + + +def parse_number_list(s: str): + # Check if the string is a valid list by trying to parse it + parsed_list = ast.literal_eval(s) + return parsed_list + + +def is_number(string): + pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$" + match = re.match(pattern, string) + return bool(match) + + +def is_scientific_number(string): + pattern = r"^[-+]?\d+(\.\d+)?e[-]?\d+$" + match = re.match(pattern, string) + return bool(match) + + +def contain_num_and_str(string): + pattern_str = r"[a-zA-Z]" + pattern_num = r"[0-9]" + return bool(re.search(pattern_str, string) and re.search(pattern_num, string)) + + +def is_number(string): + pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$" + match = 
re.match(pattern, string) + return bool(match) + + +def remove_not(x): + match_number = re.compile('[\$]?\ *10\^[{]?\ *-?[0-9]+\ *[}]?\ *[\$]?') + result=re.findall(match_number, x) + if len(result) !=0: + return re.split(match_number, x)[-1] + return None + + +def parse_not(inputs): + try: + if not inputs: + return '','' + if '\times' in inputs: + x,ab=inputs.split('\times') + elif '\\times' in inputs: + x,ab=inputs.split('\\times') + elif '*' in inputs: + x,ab=inputs.split('*') + else: + return inputs + return x,ab + except: + return '','' + + +def cal_not(inputs): + try: + x,ab=list(inputs) + match_number = re.compile('10\^[{]?\ *-?[0-9]+\ *[}]?') + ab=re.findall(match_number, ab)[0] + ab=ab[ab.find('^')+1:] + if '{' in ab: + ab=ab[ab.find('{')+1:] + if '}' in ab: + ab=ab[:ab.find('}')] + x=x.strip() + out=float(x)*10**float(ab) + # print(float(x)*10**float(ab)) + return str(out) + except: + print('error') + return inputs + + +def remove_boxed(s): + left = "oxed{" #change + try: + assert s[:len(left)] == left + assert s[-1] == "}" + answer = s[len(left):-1] + if "=" in answer: + answer = answer.split("=")[-1].lstrip(" ") + return answer + except: + return None + + +def last_boxed_only_string(string): + idx = string.rfind("oxed") #change + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx == None: + retval = None + else: + retval = string[idx:right_brace_idx + 1] + + return retval + + +class SciBenchTask(): + """Subclass of Task for multiple choice tasks.""" + + task_name = "reasoning" + + def __init__(self, id, prompt: str, reference: str, **kwargs): + self._id = id + self._prompt = ( + "Answer the following question with a number, a list of numbers or True or False. 
" + + prompt.strip() + ) + self._reference = reference + self._answer_type = kwargs.get('answer_type') + + def extract_answer(self, solution: str) -> Optional[Any]: + """Extract the answer from the given solution.""" + prediction = solution + # Following the preprocessing steps from TheoremQA + # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170 + + # is_matched, model_output = match_answer(prediction) + + # https://github.com/mandyyyyii/scibench/blob/main/eval/eval_zeroCot.py#L96 + model_output = remove_boxed(last_boxed_only_string(prediction)) + + return model_output + + def compare_w_digits(self, reference: str, answer: str) -> bool: + if reference.isdigit() and answer.isdigit(): + return abs(float(reference) - float(answer)) <= 0.05 * float(reference) + else: + return reference in answer + + def success(self, solution: str) -> bool: + """This checks whether the given solution can complete the current task.""" + # Follow the implementation from TheoremQA + # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1 + if not solution: # empty + return False + prediction = self.extract_answer(solution) + print(prediction, self._reference) + + if prediction is None or prediction == "": + return False + # LOGGER.info(f"GPQA Parsed Prediction: {prediction}") + answer_type = self._answer_type + gt = self._reference + + if isinstance(prediction, (str, int, float)) or isinstance(prediction, list): + # Comparing prediction against the reference + if answer_type in ["bool", "option", "Option"]: + cur_correct = int(prediction == f"({gt})") or int(prediction == gt) + elif answer_type == "integer": + cur_correct = int(compare_two_numbers(prediction, gt)) + elif answer_type == "float": + cur_correct = int(compare_two_numbers(prediction, gt)) + elif answer_type in ["list of integer", "list of float"]: + cur_correct = int(compare_two_list(prediction, gt)) + else: + cur_correct = 0 + return bool(cur_correct) + + + def extract_options(self, prompt: str) -> dict: + # Find the possible option separators (comma, semicolon, or parentheses) + prompt = prompt.split("Options: ")[-1] + # Extract the options using the delimiter + options_match = prompt.split(" , ") + options = {} + for i in range(len(options_match)): + option = options_match[i].strip("[]' ") + option = option.split(")") + letter = option[0].lower().strip() + content = ( + option[1] + .lower() + .strip(".") + .replace(". Which option is correct?", "") + .replace(". 
Which one is correct?", "") + .strip() + ) + options.update({letter: content}) + return options + diff --git a/eval/Subject/scieval/SciEval b/eval/Subject/scieval/SciEval new file mode 160000 index 0000000..777119e --- /dev/null +++ b/eval/Subject/scieval/SciEval @@ -0,0 +1 @@ +Subproject commit 777119e96e90c1368d0461874b9e89a70821c3a6 diff --git a/eval/Subject/scieval/evaluate_scieval_cot.py b/eval/Subject/scieval/evaluate_scieval_cot.py new file mode 100644 index 0000000..43a4c53 --- /dev/null +++ b/eval/Subject/scieval/evaluate_scieval_cot.py @@ -0,0 +1,134 @@ +from typing import Dict, Any +import os +import json +from tqdm import tqdm +from datetime import datetime +import openai +from time import sleep +import argparse +from util import * +from datasets import Dataset, load_dataset +import pandas as pd +from vllm import LLM, SamplingParams +import torch +import time +from collections import namedtuple + +Example = namedtuple('Example', ['question', 'choice1', 'choice2', 'choice3', 'choice4', 'correct_index']) + + + +parser = argparse.ArgumentParser() +parser.add_argument("--model", type=str, default="./eurus-7b-kto-hf") +# meta-llama/Llama-2-7b-chat-hf +# parser.add_argument("--input_data", type=str, default="./theorem_qa.json") +parser.add_argument("--save_dir", type=str, default="./") +parser.add_argument("--model_type", type=str, default='mistral') + +args = parser.parse_args() + + +llm = LLM( + model=args.model, + trust_remote_code=True, + tensor_parallel_size=torch.cuda.device_count(), +) + +def generate_sample_batch(question_list): + sampling_params = SamplingParams(max_tokens=1024, + temperature=0.0, + top_p=1, + n=1, + stop=["question:"],) + outputs = llm.generate(question_list, sampling_params, use_tqdm=False) + completions = [output.outputs[0].text for output in outputs] + return completions + + +# def create_reader_request(example: Dict[str, Any]) -> str: +# string = f"Question: {example['Question']}" +# return string + + +from transformers import AutoTokenizer +tokenizer = AutoTokenizer.from_pretrained(args.model) +def make_conv(example, model_type, prev_output=None): + # Prompt template copied from: https://github.com/mandyyyyii/scieval/blob/main/eval/eval_zeroCot.py#L11 + prompt_round1 = example['cot_prompot'] + '\n' + example['question'] + "\nLet's think step by step: " + prompt_round2 = ' Therefore, the answer is' + + if prev_output is None: + prompt = prompt_round1 + else: + prompt = prompt_round1 + prev_output + prompt_round2 + + msg = [{"role": "user", "content": prompt},] + out = tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) + return out + + + +if __name__ == "__main__": + + + test_set = load_dataset('OpenDFM/SciEval')['valid'].to_pandas() + test_set["prompt_round1"] = test_set.apply(lambda example: make_conv(example, args.model_type), axis=1) + + # Adapted from: https://github.com/OpenDFM/SciEval/blob/main/eval.py + # First round: Think step-by-step + test_set["completion_round1"] = generate_sample_batch(test_set["prompt_round1"].tolist()) + + # Second round: Extract final answer + test_set["prompt_round2"] = test_set.apply(lambda row: make_conv(row['problem_text'], args.model_type, row['completion_round1']), axis=1) + test_set["completion"] = generate_sample_batch(test_set["prompt_round2"].tolist()) + test_set.to_json(os.path.join(args.save_dir, "scieval_completion.json")) + + answered_set = dict() + correct, wrong = 0, 0 + + output_filename = os.path.join(args.save_dir, "scieval_output.json") + writer = open(output_filename, 'w') 
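+    # Scoring loop: extract each model answer, compare it against the reference,
+    # and write one JSON record per example before computing overall accuracy.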
+    accuracy = []
+
+    for example in test_set.to_dict(orient="records"):
+
+        if answered_set and example['id'] in answered_set:
+            writer.write(answered_set[example['id']] + '\n')
+            continue
+
+        result = example["completion"]
+        _, prediction = match_answer(result)
+        prediction = postprocess_number(prediction)
+
+        verifier = SciEvalTask(id=example['id'],
+                               prompt=example["prompt_round2"],
+                               reference=example["answer"],
+                               answer_type='Option')
+        acc = verifier.success(prediction)
+        tmp = {
+            'id': example['id'],
+            'question': example['prompt_round2'],
+            'prediction': prediction,
+            'answer': example["answer"],
+            'rationale': result,
+            'category': example['category'],
+            'topic': example['topic'],
+            'ability': example['ability'],
+            'type': example['type'],
+            'task_name': example['task_name'],
+            'answer_type': 'Option',
+            "is_correct": acc,
+        }
+        writer.write(json.dumps(tmp) + '\n')
+
+        accuracy.append(acc)
+
+    writer.close()
+    print()
+
+    accuracy = sum([1 if acc else 0 for acc in accuracy]) / len(accuracy)
+    with open(os.path.join(args.save_dir, "result.txt"), "w") as f:
+        print({"accuracy": accuracy * 100})
+        print({"accuracy": accuracy * 100}, file=f)
+
diff --git a/eval/Subject/scieval/util.py b/eval/Subject/scieval/util.py
new file mode 100644
index 0000000..46227ab
--- /dev/null
+++ b/eval/Subject/scieval/util.py
@@ -0,0 +1,358 @@
+# Copied from https://github.com/xingyaoww/mint-bench/blob/main/mint/tasks/reasoning.py
+import numpy as np
+from inspect import isfunction
+from typing import *
+
+import json
+import re
+import ast
+from sympy import Rational
+import traceback
+
+
+def extract_code(result: str) -> str:
+    lines = []
+    within_func = False
+    for line in result.split('\n'):
+        if line.startswith('`'):
+            continue
+
+        if line.startswith('def '):
+            within_func = True
+
+        if line.startswith('import') or line.startswith('from'):
+            lines.append(line)
+
+        if within_func:
+            lines.append(line)
+
+        if line.startswith('    return') or line.startswith('  return'):
+            within_func = False
+
+    for line_no in range(len(lines) - 1, -1, -1):
+        if 'return ' not in lines[line_no]:
+            del lines[line_no]
+        else:
+            break
+    result = '\n'.join(lines)
+    return result
+
+def match_answer(response):
+    is_matched = False
+    # The SciEval round-2 prompt ends with "Therefore, the answer is", so search
+    # for that marker (lowercased, since the response is lowercased for the search).
+    ans_marker = 'the answer is\n'
+    ans_idx = response.lower().rfind(ans_marker)
+    if ans_idx != -1:
+        is_matched = True
+        response = response[ans_idx + len(ans_marker):].strip()
+        if response.endswith("\n"):
+            response = response[:-2]
+
+    ans_marker = 'the answer is '
+    ans_idx = response.lower().rfind(ans_marker)
+    if ans_idx != -1:
+        is_matched = True
+        response = response[ans_idx + len(ans_marker):].strip()
+        if response.endswith("\n"):
+            response = response[:-2]
+
+    # Find boxed
+    ans_boxed = _last_boxed_only_string(response)
+    if ans_boxed:
+        is_matched = True
+        response = ans_boxed
+
+    # Grade
+    return is_matched, response
+
+def _last_boxed_only_string(string):
+    idx = string.rfind("\\boxed")
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+
+    i = idx
+    left_brace_idx = None
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+            if left_brace_idx is None:
+                left_brace_idx = i
+        elif string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+
+        i += 1
+
+    if left_brace_idx is None or right_brace_idx is None:
+        return None
+
+    return string[left_brace_idx + 1: right_brace_idx].strip()
+
+def extract_answer(result: str) -> str:
+    prediction = result.strip().strip('\n').split('\n')[-1]
+    tmp = ''
+    for entry in prediction.split(' ')[::-1]:
+        if entry == 'is' or entry == 'be' or entry == 'are' or entry.endswith(':'):
+            break
+        tmp = entry + ' ' + tmp
+    prediction = tmp.strip().strip('.')
+    return prediction
+
+
+def postprocess_number(prediction):
+    if isinstance(prediction, set):
+        prediction = list(prediction)
+    elif isinstance(prediction, np.complex128):
+        prediction = prediction.real
+    elif isinstance(prediction, np.ndarray):
+        prediction = prediction.tolist()
+    elif isinstance(prediction, complex):
+        prediction = prediction.real
+    elif isinstance(prediction, list):
+        prediction = [float(x) for x in prediction]
+    elif 'sympy' in str(type(prediction)):
+        prediction = float(prediction)
+    elif isfunction(prediction):
+        prediction = None
+
+    return prediction
+
+
+impossible_questions = ['jianyu_xu/integer_programming_1.json', 'tonyxia/totient6.json']
+
+
+def compare_two_numbers(p, gt):
+    if isinstance(p, int) or isinstance(p, float):
+        pass
+    elif isinstance(p, list) or isinstance(p, bool) or isinstance(p, str):
+        return False
+    elif isinstance(p, tuple) or isinstance(p, complex) or isinstance(p, dict):
+        return False
+    else:
+        return False
+
+    if isinstance(gt, float):
+        return within_eps(pred=p, gt=gt)
+    else:
+        return round(p) == gt
+
+
+def compare_two_list(pred, gt):
+    if not isinstance(pred, list):
+        return False
+    elif len(pred) != len(gt):
+        return False
+    elif any([not isinstance(x, (int, float)) for x in pred]):
+        return False
+    else:
+        pred = sorted(pred)
+        gt = sorted(gt)
+        return all([compare_two_numbers(p, g) for p, g in zip(pred, gt)])
+
+
+def within_eps(pred: float, gt: float):
+    eps = abs(gt) * 0.04
+    if pred >= gt - eps and pred <= gt + eps:
+        return True
+    else:
+        return False
+
+
+def parse_number_list(s: str):
+    # Check if the string is a valid list by trying to parse it
+    parsed_list = ast.literal_eval(s)
+    return parsed_list
+
+
+def is_number(string):
+    pattern = r"^[-+]?(\d{1,3}(,\d{3})*|(\d+))(\.\d+)?$"
+    match = re.match(pattern, string)
+    return bool(match)
+
+
+def is_scientific_number(string):
+    pattern = r"^[-+]?\d+(\.\d+)?e[-]?\d+$"
+    match = re.match(pattern, string)
+    return bool(match)
+
+
+def contain_num_and_str(string):
+    pattern_str = r"[a-zA-Z]"
+    pattern_num = r"[0-9]"
+    return bool(re.search(pattern_str, string) and re.search(pattern_num, string))
+
+
+def remove_not(x):
+    match_number = re.compile('[\$]?\ *10\^[{]?\ *-?[0-9]+\ *[}]?\ *[\$]?')
+    result = re.findall(match_number, x)
+    if len(result) != 0:
+        return re.split(match_number, x)[-1]
+    return None
+
+
+def parse_not(inputs):
+    try:
+        if not inputs:
+            return '', ''
+        if '\times' in inputs:
+            x, ab = inputs.split('\times')
+        elif '\\times' in inputs:
+            x, ab = inputs.split('\\times')
+        elif '*' in inputs:
+            x, ab = inputs.split('*')
+        else:
+            return inputs
+        return x, ab
+    except:
+        return '', ''
+
+
+def cal_not(inputs):
+    try:
+        x, ab = list(inputs)
+        match_number = re.compile('10\^[{]?\ *-?[0-9]+\ *[}]?')
+        ab = re.findall(match_number, ab)[0]
+        ab = ab[ab.find('^') + 1:]
+        if '{' in ab:
+            ab = ab[ab.find('{') + 1:]
+        if '}' in ab:
+            ab = ab[:ab.find('}')]
+        x = x.strip()
+        out = float(x) * 10 ** float(ab)
+        # print(float(x)*10**float(ab))
+        return str(out)
+    except:
+        print('error')
+        return inputs
+
+
+def remove_boxed(s):
+    left = "oxed{"  # change
+    try:
+        assert s[:len(left)] == left
+        assert s[-1] == "}"
+        answer = s[len(left):-1]
+        if "=" in answer:
+            answer = answer.split("=")[-1].lstrip(" ")
+        return answer
+    except:
+        return None
+
+
+def last_boxed_only_string(string):
+    idx = string.rfind("oxed")  # change
+    if idx < 0:
+        idx = string.rfind("\\fbox")
+        if idx < 0:
+            return None
+    i = idx
+    right_brace_idx = None
+    num_left_braces_open = 0
+    while i < len(string):
+        if string[i] == "{":
+            num_left_braces_open += 1
+        if string[i] == "}":
+            num_left_braces_open -= 1
+            if num_left_braces_open == 0:
+                right_brace_idx = i
+                break
+        i += 1
+
+    if right_brace_idx == None:
+        retval = None
+    else:
+        retval = string[idx:right_brace_idx + 1]
+
+    return retval
+
+
+class SciEvalTask():
+    """Task wrapper for SciEval multiple-choice questions."""
+
+    task_name = "reasoning"
+
+    def __init__(self, id, prompt: str, reference: str, **kwargs):
+        self._id = id
+        self._prompt = (
+            "Answer the following question with a number, a list of numbers or True or False. "
+            + prompt.strip()
+        )
+        self._reference = reference
+        self._answer_type = kwargs.get('answer_type')
+
+    def extract_answer(self, solution: str) -> Optional[Any]:
+        """Extract the answer from the given solution."""
+        prediction = solution
+        # Following the preprocessing steps from TheoremQA
+        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L170
+
+        prediction = remove_boxed(last_boxed_only_string(prediction))
+        if prediction is None:
+            # SciEval answers are not requested in \boxed{}, so fall back to the
+            # string already extracted by match_answer().
+            prediction = solution.strip()
+
+        return prediction
+
+    def compare_w_digits(self, reference: str, answer: str) -> bool:
+        if reference.isdigit() and answer.isdigit():
+            return abs(float(reference) - float(answer)) <= 0.05 * float(reference)
+        else:
+            return reference in answer
+
+    def success(self, solution: str) -> bool:
+        """This checks whether the given solution can complete the current task."""
+        # Follow the implementation from TheoremQA
+        # https://github.com/wenhuchen/TheoremQA/blob/123e36beaaa97c01f28a582f13c4f77a6822c199/predict_accuracy.py#L301C9-L317C1
+        if not solution:  # empty
+            return False
+        prediction = self.extract_answer(solution)
+        if prediction is None or prediction == "":
+            return False
+        # LOGGER.info(f"SciEval Parsed Prediction: {prediction}")
+        answer_type = self._answer_type
+        gt = self._reference
+
+        if isinstance(prediction, (str, int, float)) or isinstance(prediction, list):
+            # Comparing prediction against the reference
+            if answer_type in ["bool", "option", "Option"]:
+                cur_correct = int(prediction == f"({gt})") or int(prediction == gt)
+            elif answer_type == "integer":
+                cur_correct = int(compare_two_numbers(prediction, gt))
+            elif answer_type == "float":
+                cur_correct = int(compare_two_numbers(prediction, gt))
+            elif answer_type in ["list of integer", "list of float"]:
+                cur_correct = int(compare_two_list(prediction, gt))
+            else:
+                cur_correct = 0
+            return bool(cur_correct)
+
+
+    def extract_options(self, prompt: str) -> dict:
+        # Find the possible option separators (comma, semicolon, or parentheses)
+        prompt = prompt.split("Options: ")[-1]
+        # Extract the options using the delimiter
+        options_match = prompt.split(" , ")
+        options = {}
+        for i in range(len(options_match)):
+            option = options_match[i].strip("[]' ")
+            option = option.split(")")
+            letter = option[0].lower().strip()
+            content = (
+                option[1]
+                .lower()
+                .strip(".")
+                .replace(". Which option is correct?", "")
+                .replace(". Which one is correct?", "")
+                .strip()
+            )
+            options.update({letter: content})
+        return options
+