-
Notifications
You must be signed in to change notification settings - Fork 12
Support generation tasks and various improvements #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,29 @@ | ||
| import copy | ||
| from functools import lru_cache | ||
| import math | ||
| import os | ||
| import psutil | ||
|
|
||
| import bert_score | ||
| import openai | ||
| from rouge_score import rouge_scorer | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from grammar_definition import apply_prompt_format, flatten | ||
| from parsing_supernatural_instructions_tasks import OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS | ||
|
|
||
| openai.api_key = os.getenv("OPENAI_API_KEY") | ||
| PRINT_HIDDEN_STATE = False | ||
|
|
||
|
|
||
| # Keep track of 10 different messages and then warn again | ||
| # Adapted from https://stackoverflow.com/a/66062313 | ||
| @lru_cache(10) | ||
| def print_once(msg: str): | ||
| print(msg) | ||
|
|
||
|
|
||
| def call_openai_api_with_retry(args, prompt, max_tokens=10): | ||
| try: | ||
| if args.gpt3_engine in ['gpt-4', 'gpt-3.5-turbo']: | ||
|
|
@@ -45,11 +56,15 @@ def call_openai_api_with_retry(args, prompt, max_tokens=10): | |
|
|
||
|
|
||
| def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, temperature): | ||
| inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda') | ||
| if tokenizer.chat_template is None: | ||
| inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda') | ||
| else: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Support chat models |
||
| turns = [[{"role": "user", "content": prompt}] for prompt in prompt_list] | ||
| inputs = tokenizer.apply_chat_template(turns, padding=True, return_tensors='pt', return_dict=True, add_generation_prompt=True, return_token_type_ids=False).to('cuda') | ||
|
|
||
| with torch.no_grad(): | ||
| outputs = model.generate( | ||
| **inputs, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens, | ||
| **inputs, pad_token_id=tokenizer.pad_token_id, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens, | ||
| return_dict_in_generate=True, output_hidden_states=True, output_attentions=False, output_scores=True | ||
| ) | ||
|
|
||
|
|
@@ -68,7 +83,10 @@ def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, t | |
| logits_list[i].append(logits) | ||
| final_prompt_hidden_state_list = [None for _ in range(len(prompt_list))] | ||
|
|
||
| generated_answer_list = [s.lower() for s in tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)] | ||
| sequences = outputs['sequences'] | ||
| if tokenizer.chat_template is not None: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The truncation of chat models is a bit different from that of regular models and needs to be done separately. Also got rid of lowercasing for generation tasks (it is done below for classification tasks). |
||
| sequences = [seq[len(input_ids):] for seq, input_ids in zip(sequences, inputs["input_ids"], strict=True)] | ||
| generated_answer_list = [s for s in tokenizer.batch_decode(sequences, skip_special_tokens=True)] | ||
| return generated_answer_list, logits_list, final_prompt_hidden_state_list | ||
|
|
||
|
|
||
|
|
@@ -128,6 +146,8 @@ def _setup_full_prompts_to_test_on(input_fields_list, regex_key_idx_list, select | |
| replacing some variables referring to multiple choice options. | ||
| - Apply prompt format to the desired set of examples to be tested (determined by interval_ids_to_test). | ||
| """ | ||
| if n_shot == 0: | ||
| assert len(demos_fields_list) == 0 and len(demos_regex_key_idx_list) == 0 and len(demonstrations_outputs) == 0 | ||
| demonstration_string = _setup_formatted_demonstrations_with_definition( | ||
| structured_prompt_format, demonstration_definition, demonstrations_outputs, | ||
| original_to_current_multiple_choice_classes, demos_fields_list, demos_regex_key_idx_list | ||
|
|
@@ -143,11 +163,22 @@ def _setup_full_prompts_to_test_on(input_fields_list, regex_key_idx_list, select | |
|
|
||
| full_prompt_string_list = [] | ||
| for input_element, idx in zip(inputs, selected_dataset_ids): | ||
| full_prompt_string_list.append(input_element if n_shot == 0 else demonstration_string + "\n\n" + input_element) | ||
| full_prompt_string_list.append(demonstration_string + "\n\n" + input_element) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think even for 0-shot, the instruction (demonstration_definition) should still be there. |
||
|
|
||
| return full_prompt_string_list, selected_dataset_ids | ||
|
|
||
|
|
||
| # lazily initialize a bertscore model | ||
| bert_score_model = None | ||
|
|
||
|
|
||
| def get_bert_score_model(): | ||
| global bert_score_model | ||
| if bert_score_model is None: | ||
| bert_score_model = bert_score.BERTScorer(lang='en') | ||
| return bert_score_model | ||
|
|
||
|
|
||
| def evaluate_prompt_format( | ||
| args, dataset, input_fields_list, regex_key_idx_list, selected_dataset_ids, | ||
| demos_fields_list, demos_regex_key_idx_list, demonstrations_outputs, demonstration_definition, | ||
|
|
@@ -165,7 +196,11 @@ def evaluate_prompt_format( | |
| structured_prompt_format, original_to_current_multiple_choice_classes, interval_ids_to_test, args.n_shot) | ||
|
|
||
| # 2. update the output values if needed, i.e. if the multiple choice classes now have different names | ||
| assert all(len(dataset[idx]['output']) == 1 for idx in selected_dataset_ids) | ||
| assert all(len(dataset[idx]['output']) >= 1 for idx in selected_dataset_ids) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Be a bit lenient here. This doesn't hold for some generation tasks. When that happens, only take the first one.
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wouldn't it be better to raise some kind of warning in the > 1 case? Otherwise the code will silently take the first output and you might never know it happened
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah makes sense. Added a warning below. |
||
| for idx in selected_dataset_ids: | ||
| if len(dataset[idx]['output']) > 1: | ||
| print_once("Warning: multiple outputs for a single input. Only the first one will be used.") | ||
| dataset[idx]['output'] = [dataset[idx]['output'][0]] | ||
| dataset_updated = copy.deepcopy(dataset) | ||
| if original_to_current_multiple_choice_classes: | ||
| for idx in range(len(dataset)): | ||
|
|
@@ -177,11 +212,25 @@ def evaluate_prompt_format( | |
| return solve_with_rank_based_scoring( | ||
| dataset_updated, selected_dataset_ids, model, tokenizer, input_prompt_string_list, args.batch_size_llm) | ||
|
|
||
| elif args.evaluation_metric == 'exact_prefix_matching': | ||
| elif args.evaluation_metric in {'exact_prefix_matching', 'rouge', 'bertscore'}: | ||
| logs = generate_text_with_metadata( | ||
| args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, | ||
| dataset_updated, selected_dataset_ids, output_classes) | ||
| return exact_prefix_matching_scoring(logs) | ||
| if args.evaluation_metric == 'exact_prefix_matching': | ||
| return exact_prefix_matching_scoring(logs) | ||
| elif args.evaluation_metric == 'rouge': | ||
| scorer = rouge_scorer.RougeScorer(['rougeL']) | ||
| scores = [scorer.score(entry['generation'], entry['answer'])["rougeL"].fmeasure for entry in logs] | ||
| avg_score = sum(scores) / len(scores) | ||
| return (avg_score, None, None), (None, logs) | ||
| elif args.evaluation_metric == 'bertscore': | ||
| references = [entry['answer'] for entry in logs] | ||
| candidates = [entry['generation'] for entry in logs] | ||
| scorer = get_bert_score_model() | ||
| P, R, F1 = scorer.score(candidates, references) | ||
| return (F1.mean().item(), None, None), (None, logs) | ||
|
|
||
|
|
||
|
|
||
|
|
||
| def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, dataset, | ||
|
|
@@ -205,7 +254,9 @@ def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer | |
| generation_list, score_list, final_prompt_hidden_state_list = query_model_parallelized( | ||
| model, tokenizer, full_prompt_string_list, max_tokens=args.max_new_tokens, top_p=1.0, temperature=1.0, | ||
| ) | ||
| if model_will_repeat_input: | ||
| if args.task_filename not in OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS: | ||
| generation_list = [gen.lower() for gen in generation_list] | ||
| if model_will_repeat_input and tokenizer.chat_template is None: # for chat models, truncation happens in query_model_parallelized | ||
| generation_list = [generation[len(full_prompt_string):] | ||
| for generation, full_prompt_string in zip(generation_list, full_prompt_string_list)] | ||
|
|
||
|
|
@@ -398,16 +449,20 @@ def get_ranking_based_generation_single_token_output_classes(prompts, output_cla | |
| output_classes_tokens = [t for t in tokenizer(output_classes, return_token_type_ids=False)['input_ids']] | ||
| all_classes_share_common_prefix = len(set([tuple(t[:-1]) for t in output_classes_tokens])) == 1 | ||
|
|
||
| tokenized_inputs_list = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False)[ | ||
| 'input_ids'].tolist() | ||
| if tokenizer.chat_template is None: | ||
| tokenized = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False) | ||
| else: | ||
| turns = [[{"role": "user", "content": prompt}] for prompt in prompts] | ||
| tokenized = tokenizer.apply_chat_template(turns, return_tensors="pt", padding=True, return_dict=True, add_generation_prompt=True, return_token_type_ids=False) | ||
| tokenized_inputs_list = tokenized["input_ids"].tolist() | ||
| if all_classes_share_common_prefix: | ||
| for i in range(len(tokenized_inputs_list)): | ||
| # if the tokenized element is [1, 29871, 29900], get [29871] | ||
| tokenized_inputs_list[i] += output_classes_tokens[0][1:-1] | ||
| tokenized_inputs = torch.tensor(tokenized_inputs_list).to('cuda') | ||
|
|
||
| with torch.no_grad(): | ||
| outputs = model.generate(input_ids=tokenized_inputs, | ||
| outputs = model.generate(input_ids=tokenized_inputs, pad_token_id=tokenizer.pad_token_id, | ||
| top_p=top_p, temperature=temperature, max_new_tokens=1, | ||
| return_dict_in_generate=True, output_scores=True) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like
_setup_non_formatted_dataset_with_one_field_only works well for generation tasks out of the box.