Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions data_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from grammar_definition import flatten, _one_text_field
from parsing_supernatural_instructions_tasks import SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT, \
create_initial_structured_prompt_format
create_initial_structured_prompt_format, OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS

SUPERNATURAL_INSTRUCTIONS_DIRECTORY = '../natural-instructions/tasks'
INSTRUCTION_INDUCTION_DIRECTORY = '../instruction-induction'
Expand Down Expand Up @@ -145,7 +145,7 @@ def load_supernatural_instructions_task(args):
"""

# SuperNaturalInstructions Tasks without a defined format
if any(t in args.task_filename for t in SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT):
if any(t in args.task_filename for t in SUPERNATURAL_INSTRUCTIONS_TASKS_WITH_NO_FORMAT + OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS):
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like _setup_non_formatted_dataset_with_one_field_only works well for generation tasks out of the box.

raw_dataset = _load_raw_dataset_supernatural_instructions(args.task_filename)
demonstration_definition = raw_dataset['Definition'][0]
raw_dataset = raw_dataset['Instances']
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def _get_output_filename(args):

# params to set up evaluation settings
parser.add_argument('--num_samples', type=int, default=1000, help='Maximum number of samples to evaluate for each format.')
parser.add_argument('--evaluation_metric', choices=['exact_prefix_matching', 'probability_ranking'])
parser.add_argument('--evaluation_metric', choices=['exact_prefix_matching', 'probability_ranking', 'rouge', 'bertscore'])
parser.add_argument('--evaluation_type', type=str, choices=['full', 'format_spread'],
help='Determines how to evaluate the array of formats defined. '
'Options are full evaluation of each node, or use Thompson Sampling to quickly find the format spread.')
Expand Down
77 changes: 66 additions & 11 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,29 @@
import copy
from functools import lru_cache
import math
import os
import psutil

import bert_score
import openai
from rouge_score import rouge_scorer
import torch
import torch.nn.functional as F

from grammar_definition import apply_prompt_format, flatten
from parsing_supernatural_instructions_tasks import OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS

openai.api_key = os.getenv("OPENAI_API_KEY")
PRINT_HIDDEN_STATE = False


@lru_cache(maxsize=10)
def print_once(msg: str) -> None:
    """Print *msg* only once per cache window.

    Backed by an LRU cache of size 10: the 10 most recently seen distinct
    messages are suppressed on repeat; once a message is evicted it will be
    printed (warned) again. Adapted from https://stackoverflow.com/a/66062313.
    """
    print(msg)


def call_openai_api_with_retry(args, prompt, max_tokens=10):
try:
if args.gpt3_engine in ['gpt-4', 'gpt-3.5-turbo']:
Expand Down Expand Up @@ -45,11 +56,15 @@ def call_openai_api_with_retry(args, prompt, max_tokens=10):


def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, temperature):
inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda')
if tokenizer.chat_template is None:
inputs = tokenizer(prompt_list, padding=True, return_tensors='pt', return_token_type_ids=False).to('cuda')
else:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Support chat models

turns = [[{"role": "user", "content": prompt}] for prompt in prompt_list]
inputs = tokenizer.apply_chat_template(turns, padding=True, return_tensors='pt', return_dict=True, add_generation_prompt=True, return_token_type_ids=False).to('cuda')

with torch.no_grad():
outputs = model.generate(
**inputs, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens,
**inputs, pad_token_id=tokenizer.pad_token_id, top_p=top_p, temperature=temperature, max_new_tokens=max_tokens,
return_dict_in_generate=True, output_hidden_states=True, output_attentions=False, output_scores=True
)

Expand All @@ -68,7 +83,10 @@ def query_model_parallelized(model, tokenizer, prompt_list, max_tokens, top_p, t
logits_list[i].append(logits)
final_prompt_hidden_state_list = [None for _ in range(len(prompt_list))]

generated_answer_list = [s.lower() for s in tokenizer.batch_decode(outputs['sequences'], skip_special_tokens=True)]
sequences = outputs['sequences']
if tokenizer.chat_template is not None:
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The truncation of chat models is a bit different from regular models and need to be done separately. Also got rid of lowercasing for generation tasks (done below for classification tasks).

sequences = [seq[len(input_ids):] for seq, input_ids in zip(sequences, inputs["input_ids"], strict=True)]
generated_answer_list = [s for s in tokenizer.batch_decode(sequences, skip_special_tokens=True)]
return generated_answer_list, logits_list, final_prompt_hidden_state_list


Expand Down Expand Up @@ -128,6 +146,8 @@ def _setup_full_prompts_to_test_on(input_fields_list, regex_key_idx_list, select
replacing some variables referring to multiple choice options.
- Apply prompt format to the desired set of examples to be tested (determined by interval_ids_to_test).
"""
if n_shot == 0:
assert len(demos_fields_list) == 0 and len(demos_regex_key_idx_list) == 0 and len(demonstrations_outputs) == 0
demonstration_string = _setup_formatted_demonstrations_with_definition(
structured_prompt_format, demonstration_definition, demonstrations_outputs,
original_to_current_multiple_choice_classes, demos_fields_list, demos_regex_key_idx_list
Expand All @@ -143,11 +163,22 @@ def _setup_full_prompts_to_test_on(input_fields_list, regex_key_idx_list, select

full_prompt_string_list = []
for input_element, idx in zip(inputs, selected_dataset_ids):
full_prompt_string_list.append(input_element if n_shot == 0 else demonstration_string + "\n\n" + input_element)
full_prompt_string_list.append(demonstration_string + "\n\n" + input_element)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think even for 0-shot, the instruction (demonstration_definition) should still be there.


return full_prompt_string_list, selected_dataset_ids


# Lazily initialize a BERTScore model: the scorer is only constructed the
# first time a 'bertscore' evaluation actually needs it.
bert_score_model = None


def get_bert_score_model():
    """Return the module-level ``bert_score.BERTScorer`` singleton (English),
    creating it on first call and reusing it afterwards."""
    global bert_score_model
    if bert_score_model is None:
        bert_score_model = bert_score.BERTScorer(lang='en')
    return bert_score_model


def evaluate_prompt_format(
args, dataset, input_fields_list, regex_key_idx_list, selected_dataset_ids,
demos_fields_list, demos_regex_key_idx_list, demonstrations_outputs, demonstration_definition,
Expand All @@ -165,7 +196,11 @@ def evaluate_prompt_format(
structured_prompt_format, original_to_current_multiple_choice_classes, interval_ids_to_test, args.n_shot)

# 2. update the output values if needed, i.e. if the multiple choice classes now have different names
assert all(len(dataset[idx]['output']) == 1 for idx in selected_dataset_ids)
assert all(len(dataset[idx]['output']) >= 1 for idx in selected_dataset_ids)
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Be a bit lenient here. This doesn't hold for some generation tasks. When that happens, only take the first one.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be better to raise some kind of warning in the > 1 case? Otherwise the code will silently take the first output and you might never know it happened

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah makes sense. Added a warning below.

for idx in selected_dataset_ids:
if len(dataset[idx]['output']) > 1:
print_once("Warning: multiple outputs for a single input. Only the first one will be used.")
dataset[idx]['output'] = [dataset[idx]['output'][0]]
dataset_updated = copy.deepcopy(dataset)
if original_to_current_multiple_choice_classes:
for idx in range(len(dataset)):
Expand All @@ -177,11 +212,25 @@ def evaluate_prompt_format(
return solve_with_rank_based_scoring(
dataset_updated, selected_dataset_ids, model, tokenizer, input_prompt_string_list, args.batch_size_llm)

elif args.evaluation_metric == 'exact_prefix_matching':
elif args.evaluation_metric in {'exact_prefix_matching', 'rouge', 'bertscore'}:
logs = generate_text_with_metadata(
args, input_prompt_string_list, model, tokenizer, model_will_repeat_input,
dataset_updated, selected_dataset_ids, output_classes)
return exact_prefix_matching_scoring(logs)
if args.evaluation_metric == 'exact_prefix_matching':
return exact_prefix_matching_scoring(logs)
elif args.evaluation_metric == 'rouge':
scorer = rouge_scorer.RougeScorer(['rougeL'])
scores = [scorer.score(entry['generation'], entry['answer'])["rougeL"].fmeasure for entry in logs]
avg_score = sum(scores) / len(scores)
return (avg_score, None, None), (None, logs)
elif args.evaluation_metric == 'bertscore':
references = [entry['answer'] for entry in logs]
candidates = [entry['generation'] for entry in logs]
scorer = get_bert_score_model()
P, R, F1 = scorer.score(candidates, references)
return (F1.mean().item(), None, None), (None, logs)




def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer, model_will_repeat_input, dataset,
Expand All @@ -205,7 +254,9 @@ def generate_text_with_metadata(args, input_prompt_string_list, model, tokenizer
generation_list, score_list, final_prompt_hidden_state_list = query_model_parallelized(
model, tokenizer, full_prompt_string_list, max_tokens=args.max_new_tokens, top_p=1.0, temperature=1.0,
)
if model_will_repeat_input:
if args.task_filename not in OPEN_GENERATION_SUPERNATURAL_INSTRUCTIONS_TASKS:
generation_list = [gen.lower() for gen in generation_list]
if model_will_repeat_input and tokenizer.chat_template is None: # for chat models, truncation happens in query_model_parallelized
generation_list = [generation[len(full_prompt_string):]
for generation, full_prompt_string in zip(generation_list, full_prompt_string_list)]

Expand Down Expand Up @@ -398,16 +449,20 @@ def get_ranking_based_generation_single_token_output_classes(prompts, output_cla
output_classes_tokens = [t for t in tokenizer(output_classes, return_token_type_ids=False)['input_ids']]
all_classes_share_common_prefix = len(set([tuple(t[:-1]) for t in output_classes_tokens])) == 1

tokenized_inputs_list = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False)[
'input_ids'].tolist()
if tokenizer.chat_template is None:
tokenized = tokenizer(prompts, return_tensors="pt", padding=True, return_token_type_ids=False)
else:
turns = [[{"role": "user", "content": prompt}] for prompt in prompts]
tokenized = tokenizer.apply_chat_template(turns, return_tensors="pt", padding=True, return_dict=True, add_generation_prompt=True, return_token_type_ids=False)
tokenized_inputs_list = tokenized["input_ids"].tolist()
if all_classes_share_common_prefix:
for i in range(len(tokenized_inputs_list)):
# if the tokenized element is [1, 29871, 29900], get [29871]
tokenized_inputs_list[i] += output_classes_tokens[0][1:-1]
tokenized_inputs = torch.tensor(tokenized_inputs_list).to('cuda')

with torch.no_grad():
outputs = model.generate(input_ids=tokenized_inputs,
outputs = model.generate(input_ids=tokenized_inputs, pad_token_id=tokenizer.pad_token_id,
top_p=top_p, temperature=temperature, max_new_tokens=1,
return_dict_in_generate=True, output_scores=True)

Expand Down