From c691d44b99f093836cffd3eb0730275d62f80aec Mon Sep 17 00:00:00 2001
From: Zhaofeng Wu
Date: Sat, 3 Aug 2024 00:47:45 +0000
Subject: [PATCH 1/3] Support generation tasks and various improvements

---
 data_loading.py |  4 ++--
 main.py         |  2 +-
 utils.py        | 39 +++++++++++++++++++++++++++++++--------
 3 files changed, 34 insertions(+), 11 deletions(-)

diff --git a/data_loading.py b/data_loading.py
index 7b3c9e2..5c416fa 100644
--- a/data_loading.py
+++ b/data_loading.py
@@ -5,7 +5,7 @@
 from grammar_definition import flatten, _one_text_field
 from parsing_supernatural_instructions_tasks import SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT, \
-    create_initial_structured_prompt_format
+    create_initial_structured_prompt_format, OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS

 SUPERNATURAL_INSTRUCTIONS_DIRECTORY = '../natural-instructions/tasks'
 INSTRUCTION_INDUCTION_DIRECTORY = '../instruction-induction'
@@ -145,7 +145,7 @@ def load_supernatural_instructions_task(args):
     """

     # SuperNaturalInstructions Tasks without a defined format
-    if any(t in args.task_filename for t in SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT):
+    if any(t in args.task_filename for t in SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT + OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS):
         raw_dataset = _load_raw_dataset_supernatural_instructions(args.task_filename)
         demonstration_definition = raw_dataset['Definition'][0]
         raw_dataset = raw_dataset['Instances']
diff --git a/main.py b/main.py
index bc900a5..6a08c24 100644
--- a/main.py
+++ b/main.py
@@ -317,7 +317,7 @@ def _get_output_filename(args):
     # params to set up evaluation settings
     parser.add_argument('--num_samples', type=int, default=1000,
                         help='Maximum number of samples to evaluate for each format.')
-    parser.add_argument('--evaluation_metric', choices=['exact_prefix_matching', 'probability_ranking'])
+    parser.add_argument('--evaluation_metric', choices=['exact_prefix_matching', 'probability_ranking', 'rouge', 'bertscore'])
     parser.add_argument('--evaluation_type', type=str, choices=['full', 'format_spread'],
                         help='Determines how to evaluate the array of formats defined. '
                              'Options are full evaluation of each node, or use Thompson Sampling to quickly find the format spread.')
diff --git a/utils.py b/utils.py
index 389f4ff..85d42be 100644
--- a/utils.py
+++ b/utils.py
@@ -4,10 +4,12 @@
 import os
 import psutil
 import openai
+from rouge_score import rouge_scorer
 import torch
 import torch.nn.functional as F

 from grammar_definition import apply_prompt_format, flatten
+from parsing_supernatural_instructions_tasks import OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS

 openai.api_key = os.getenv("OPENAI_API_KEY")
 PRINT_HIDDEN_STATE = False
@@ -45,7 +47,11 @@ def call_openai_api_with_retry(args, prompt, max_tokens=10):


 def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, temperature):
-    inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda')
+    if tokenizer.chat_template is None:
+        inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda')
+    else:
+        turns = [[{"role": "user", "content": prompt}] for prompt in prompt_list]
+        inputs = tokenizer.apply_chat_template(turns, padding=True, return_tensors='pt', return_dict=True, add_generation_prompt=True, return_token_type_ids=False).to('cuda')

     with torch.no_grad():
         outputs = model.generate(
@@ -68,7 +74,10 @@ def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, t
                 logits_list[i].append(logits)

     final_prompt_hidden_state_list = [None for _ in range(len(prompt_list))]
-    generated_answer_list = [s.lower() for s in tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)]
+    sequences = outputs['sequences']
+    if tokenizer.chat_template is not None:
+        sequences = [seq[len(input_ids):] for seq, input_ids in zip(sequences, inputs["input_ids"], strict=True)]
+    generated_answer_list = [s for s in tokenizer.batch_decode(sequences, skip_special_tokens=True)]
     return generated_answer_list, logits_list, final_prompt_hidden_state_list


@@ -128,6 +137,8 @@ def _setup_full_prompts_to_test_on(input_fields_list, regex_key_idx_list, select
         replacing some variables referring to multiple choice options.
     - Apply prompt format to the desired set of examples to be tested (determined by interval_ids_to_test).
     """
+    if n_shot == 0:
+        assert len(demos_fields_list) == 0 and len(demos_regex_key_idx_list) == 0 and len(demonstrations_outputs) == 0
     demonstration_string = _setup_formatted_demonstrations_with_definition(
         structured_prompt_format, demonstration_definition, demonstrations_outputs,
         original_to_current_multiple_choice_classes, demos_fields_list, demos_regex_key_idx_list
@@ -143,7 +154,7 @@
     full_prompt_string_list = []
     for input_element, idx in zip(inputs, selected_dataset_ids):
-        full_prompt_string_list.append(input_element if n_shot == 0 else demonstration_string + "\n\n" + input_element)
+        full_prompt_string_list.append(demonstration_string + "\n\n" + input_element)

     return full_prompt_string_list, selected_dataset_ids
@@ -165,7 +176,10 @@ def evaluate_prompt_format(
         structured_prompt_format, original_to_current_multiple_choice_classes, interval_ids_to_test, args.n_shot)

     # 2. update the output values if needed, i.e. if the multiple choice classes now have different names
-    assert all(len(dataset[idx]['output']) == 1 for idx in selected_dataset_ids)
+    assert all(len(dataset[idx]['output']) >= 1 for idx in selected_dataset_ids)
+    for idx in selected_dataset_ids:
+        if len(dataset[idx]['output']) > 1:
+            dataset[idx]['output'] = [dataset[idx]['output'][0]]
     dataset_updated = copy.deepcopy(dataset)
     if original_to_current_multiple_choice_classes:
         for idx in range(len(dataset)):
@@ -177,11 +191,18 @@ def evaluate_prompt_format(
         return solve_with_rank_based_scoring(
             dataset_updated, selected_dataset_ids, model, tokenizer, input_prompt_string_list,
             args.batch_size_llm)
-    elif args.evaluation_metric == 'exact_prefix_matching':
+    elif args.evaluation_metric in {'exact_prefix_matching', 'rouge'}:
         logs = generate_text_with_metadata(
             args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, dataset_updated,
             selected_dataset_ids, output_classes)
-        return exact_prefix_matching_scoring(logs)
+        if args.evaluation_metric == 'exact_prefix_matching':
+            return exact_prefix_matching_scoring(logs)
+        elif args.evaluation_metric == 'rouge':
+            scorer = rouge_scorer.RougeScorer(['rougeL'])
+            scores = [scorer.score(entry['generation'], entry['answer'])["rougeL"].fmeasure for entry in logs]
+            avg_score = sum(scores) / len(scores)
+            return (avg_score, None, None), (None, logs)
+


 def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, dataset,
@@ -205,7 +226,9 @@ def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer
     generation_list, score_list, final_prompt_hidden_state_list = query_model_parallelized(
         model, tokenizer, full_prompt_string_list, max_tokens=args.max_new_tokens, top_p=1.0, temperature=1.0,
     )
-    if model_will_repeat_input:
+    if args.task_filename not in OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS:
+        generation_list = [gen.lower() for gen in generation_list]
+    if model_will_repeat_input and tokenizer.chat_template is None:  # for chat models, truncation happens in query_model_parallelized
         generation_list = [generation[len(full_prompt_string):]
                            for generation, full_prompt_string in zip(generation_list, full_prompt_string_list)]
@@ -407,7 +430,7 @@ def get_ranking_based_generation_single_token_output_classes(prompts, output_cla
     tokenized_inputs = torch.tensor(tokenized_inputs_list).to('cuda')

     with torch.no_grad():
-        outputs = model.generate(input_ids=tokenized_inputs,
+        outputs = model.generate(input_ids=tokenized_inputs, pad_token_id=tokenizer.pad_token_id,
                                  top_p=top_p, temperature=temperature, max_new_tokens=1,
                                  return_dict_in_generate=True, output_scores=True)
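For context on the chat-template branch introduced above: when `tokenizer.chat_template` is set, each prompt is wrapped as a one-turn conversation and the template inserts the model-specific control tokens, so the generated sequences begin with the templated prompt and are truncated by token count rather than by string length. A minimal standalone sketch of the branching, assuming a recent `transformers` release with batched `apply_chat_template` support (the model name is only an example):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # example model only
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token  # some chat tokenizers ship without a pad token
    prompt_list = ["What is the capital of France?", "Is 7 a prime number?"]

    if tokenizer.chat_template is None:
        # Base LMs: tokenize the raw prompts directly.
        inputs = tokenizer(prompt_list, padding=True, return_tensors="pt")
    else:
        # Chat LMs: wrap each prompt as a single user turn; the template adds the
        # special tokens, and add_generation_prompt=True ends the input where the
        # assistant's reply should begin.
        turns = [[{"role": "user", "content": p}] for p in prompt_list]
        inputs = tokenizer.apply_chat_template(
            turns, padding=True, return_tensors="pt", return_dict=True,
            add_generation_prompt=True,
        )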
From 48ea81cf003573fc99b3bc39272ae81619e4d78e Mon Sep 17 00:00:00 2001
From: Zhaofeng Wu
Date: Mon, 5 Aug 2024 21:34:45 +0000
Subject: [PATCH 2/3] Support bertscore; support chat models for probability ranking

---
 utils.py | 31 +++++++++++++++++++++++++++----
 1 file changed, 27 insertions(+), 4 deletions(-)

diff --git a/utils.py b/utils.py
index 85d42be..41b5eba 100644
--- a/utils.py
+++ b/utils.py
@@ -3,6 +3,7 @@
 import os
 import psutil
+import bert_score
 import openai
 from rouge_score import rouge_scorer
 import torch
@@ -55,7 +56,7 @@ def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, t

     with torch.no_grad():
         outputs = model.generate(
-            **inputs, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens,
+            **inputs, pad_token_id=tokenizer.pad_token_id, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens,
             return_dict_in_generate=True, output_hidden_states=True, output_attentions=False, output_scores=True
         )
@@ -159,6 +160,17 @@
     return full_prompt_string_list, selected_dataset_ids


+# Lazily initialize a BERTScore model
+bert_score_model = None
+
+
+def get_bert_score_model():
+    global bert_score_model
+    if bert_score_model is None:
+        bert_score_model = bert_score.BERTScorer(lang='en')
+    return bert_score_model
+
+
 def evaluate_prompt_format(
         args, dataset, input_fields_list, regex_key_idx_list, selected_dataset_ids,
         demos_fields_list, demos_regex_key_idx_list, demonstrations_outputs, demonstration_definition,
@@ -191,7 +203,7 @@ def evaluate_prompt_format(
         return solve_with_rank_based_scoring(
             dataset_updated, selected_dataset_ids, model, tokenizer, input_prompt_string_list,
             args.batch_size_llm)
-    elif args.evaluation_metric in {'exact_prefix_matching', 'rouge'}:
+    elif args.evaluation_metric in {'exact_prefix_matching', 'rouge', 'bertscore'}:
         logs = generate_text_with_metadata(
             args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, dataset_updated,
             selected_dataset_ids, output_classes)
@@ -202,6 +214,13 @@ def evaluate_prompt_format(
             scores = [scorer.score(entry['generation'], entry['answer'])["rougeL"].fmeasure for entry in logs]
             avg_score = sum(scores) / len(scores)
             return (avg_score, None, None), (None, logs)
+        elif args.evaluation_metric == 'bertscore':
+            references = [entry['answer'] for entry in logs]
+            candidates = [entry['generation'] for entry in logs]
+            scorer = get_bert_score_model()
+            P, R, F1 = scorer.score(candidates, references)
+            return (F1.mean().item(), None, None), (None, logs)
+
@@ -421,8 +440,12 @@ def get_ranking_based_generation_single_token_output_classes(prompts, output_cla
     output_classes_tokens = [t for t in tokenizer(output_classes, return_token_type_ids=False)['input_ids']]
     all_classes_share_common_prefix = len(set([tuple(t[:-1]) for t in output_classes_tokens])) == 1

-    tokenized_inputs_list = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False)[
-        'input_ids'].tolist()
+    if tokenizer.chat_template is None:
+        tokenized = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False)
+    else:
+        turns = [[{"role": "user", "content": prompt}] for prompt in prompts]
+        tokenized = tokenizer.apply_chat_template(turns, return_tensors="pt", padding=True, return_dict=True, add_generation_prompt=True, return_token_type_ids=False)
+    tokenized_inputs_list = tokenized["input_ids"].tolist()
     if all_classes_share_common_prefix:
         for i in range(len(tokenized_inputs_list)):
             # if the tokenized element is [1, 29871, 29900], get [29871]
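The lazily initialized `BERTScorer` above is constructed once and reused across format evaluations, since building it loads a scoring model. In isolation it behaves roughly as follows (a sketch with made-up candidate/reference strings):

    import bert_score

    scorer = bert_score.BERTScorer(lang='en')         # heavy: loads the underlying model once
    candidates = ["paris is the capital of france"]   # model generations
    references = ["Paris is the capital of France."]  # gold answers
    P, R, F1 = scorer.score(candidates, references)   # precision/recall/F1 tensors, one entry per pair
    print(F1.mean().item())                           # mean F1 is what the patch reports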
From ef0056ac6365c3a9fde6cd709423182d1e322a9f Mon Sep 17 00:00:00 2001
From: Zhaofeng Wu
Date: Mon, 26 Aug 2024 15:00:33 +0000
Subject: [PATCH 3/3] Add warning

---
 utils.py | 9 +++++++++
 1 file changed, 9 insertions(+)

diff --git a/utils.py b/utils.py
index 41b5eba..fbef841 100644
--- a/utils.py
+++ b/utils.py
@@ -1,4 +1,5 @@
 import copy
+from functools import lru_cache
 import math
 import os
 import psutil
@@ -16,6 +17,13 @@
 PRINT_HIDDEN_STATE = False


+# Print each distinct message only once; keeps track of 10 messages, after which old ones may warn again
+# Adapted from https://stackoverflow.com/a/66062313
+@lru_cache(10)
+def print_once(msg: str):
+    print(msg)
+
+
 def call_openai_api_with_retry(args, prompt, max_tokens=10):
     try:
         if args.gpt3_engine in ['gpt-4', 'gpt-3.5-turbo']:
@@ -191,6 +199,7 @@ def evaluate_prompt_format(
     assert all(len(dataset[idx]['output']) >= 1 for idx in selected_dataset_ids)
     for idx in selected_dataset_ids:
         if len(dataset[idx]['output']) > 1:
+            print_once("Warning: multiple outputs for a single input. Only the first one will be used.")
             dataset[idx]['output'] = [dataset[idx]['output'][0]]
     dataset_updated = copy.deepcopy(dataset)
     if original_to_current_multiple_choice_classes:
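The deduplication in `print_once` comes entirely from `functools.lru_cache`: the first call with a given message runs the body and caches the (None) result, so repeat calls are cache hits and print nothing, while up to 10 distinct messages stay cached before old ones can be printed again. A quick illustration:

    from functools import lru_cache

    @lru_cache(10)
    def print_once(msg: str):
        print(msg)

    print_once("Warning: multiple outputs for a single input. Only the first one will be used.")
    print_once("Warning: multiple outputs for a single input. Only the first one will be used.")  # cache hit: no output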