diff --git a/src/cnlpt/CnlpModelForClassification.py b/src/cnlpt/CnlpModelForClassification.py index 30fdde88..9f38923b 100644 --- a/src/cnlpt/CnlpModelForClassification.py +++ b/src/cnlpt/CnlpModelForClassification.py @@ -427,14 +427,17 @@ def compute_loss( loss_fct = MSELoss() task_loss = loss_fct(task_logits.view(-1), labels.view(-1)) else: - if self.class_weights[task_name] is not None: - class_weights = torch.FloatTensor(self.class_weights[task_name]).to( - self.device - ) + if isinstance(self.class_weights, torch.Tensor): + class_weights = self.class_weights else: - class_weights = None - loss_fct = CrossEntropyLoss(weight=class_weights) + if self.class_weights[task_name] is not None: + class_weights = torch.FloatTensor(self.class_weights[task_name]).to( + self.device + ) + else: + class_weights = None + loss_fct = CrossEntropyLoss(weight=class_weights) if self.relations[task_name]: task_labels = labels[ :, :, state["task_label_ind"] : state["task_label_ind"] + seq_len @@ -453,7 +456,7 @@ def compute_loss( task_labels = labels[:, :, state["task_label_ind"]] else: task_labels = labels[:, 0, state["task_label_ind"], :] - + state["task_label_ind"] += 1 task_loss = loss_fct( task_logits.view(-1, task_num_labels), @@ -477,7 +480,7 @@ def compute_loss( "Have not implemented the case where a classification task " "is part of an MTL setup with relations and sequence tagging" ) - + state["task_label_ind"] += 1 task_loss = loss_fct( task_logits, task_labels.type(torch.LongTensor).to(labels.device) @@ -491,6 +494,40 @@ def compute_loss( ) state["loss"] += task_weight * task_loss + def remove_task_classifiers(self, tasks: list[str] = None): + if tasks is None: + self.classifiers = nn.ModuleDict() + self.tasks = [] + self.class_weights = {} + else: + for task in tasks: + self.classifiers.pop(task) + self.tasks.remove(task) + self.class_weights.pop(task) + + def add_task_classifier(self, task_name: str, label_dictionary: dict[str, list]): + self.tasks.append(task_name) + 
task_num_labels = len(label_dictionary) + self.classifiers[task_name] = ClassificationHead( + self.config, len(label_dictionary) + ) + if self.config.relations[task_name]: + hidden_size = self.config.num_rel_attention_heads + self.classifiers[task_name] = ClassificationHead( + self.config, task_num_labels, hidden_size=hidden_size + ) + else: + self.classifiers[task_name] = ClassificationHead( + self.config, task_num_labels + ) + self.label_dictionary[task_name] = label_dictionary + + def set_class_weights(self, class_weights: Union[list[float], None] = None): + if class_weights is None: + self.class_weights = {x: None for x in self.label_dictionary.keys()} + else: + self.class_weights = class_weights + def forward( self, input_ids=None, diff --git a/src/cnlpt/api/tlink_rest.py b/src/cnlpt/api/tlink_rest.py new file mode 100644 index 00000000..22297226 --- /dev/null +++ b/src/cnlpt/api/tlink_rest.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+from fastapi import FastAPI
+import numpy as np
+from cnlp_rest import initialize_cnlpt_model, create_dataset
+import logging
+from time import time
+from nltk.tokenize import wordpunct_tokenize as tokenize
+
+from temporal_rest import TokenizedSentenceDocument, SentenceDocument, Relation, TemporalResults, create_instance_string
+
+app = FastAPI()
+model_name = "../../../../../thyme/ft/models/inference/tlink/5688819"
+logger = logging.getLogger('Tlink_REST_Processor')
+logger.setLevel(logging.INFO)
+
+max_length = 128
+
+relation_label_list = ["AFTER", "BEFORE", "BEGINS-ON", "BEGINS-ON-1", "CONTAINS", "CONTAINS-1", "ENDS-ON", "ENDS-ON-1", "NOTED-ON", "NOTED-ON-1", "None", "OVERLAP"]
+
+@app.on_event("startup")
+async def startup_event():
+    initialize_cnlpt_model(app, model_name)
+
+@app.post("/temporal/process")
+async def process(doc: TokenizedSentenceDocument):
+    return process_tokenized_sentence_document(doc)
+
+@app.post("/temporal/process_sentence")
+async def process_sentence(doc: SentenceDocument):
+    tokenized_sent = tokenize(doc.sentence)
+    doc = TokenizedSentenceDocument(sent_tokens=[tokenized_sent,], metadata='Single sentence')
+    return process_tokenized_sentence_document(doc)
+
+def process_tokenized_sentence_document(doc: TokenizedSentenceDocument):
+    sents = doc.sent_tokens
+    metadata = doc.metadata
+
+    logger.warning('Received document labeled %s with %d sentences' % (metadata, len(sents)))
+    instances = []
+    start_time = time()
+
+    for sent_ind, token_list in enumerate(sents):
+        inst_str = create_instance_string(token_list)
+        logger.debug('Instance string is %s' % (inst_str))
+        instances.append(inst_str)
+
+    dataset = create_dataset(instances, app.state.tokenizer, label_lists=[relation_label_list], tasks=['tlink'], max_length=max_length)
+    preproc_end = time()
+
+    output = app.state.trainer.predict(test_dataset=dataset)
+
+    rel_predictions = np.argmax(output.predictions[0], axis=1)
+
+    pred_end = time()
+
+    timex_results = []
+ event_results = [] + rel_results = [] + + for sent_ind in range(len(dataset)): + event_results.append( [] ) + timex_results.append( [] ) + rel_results.append( [Relation(arg1=None, arg2=None, category=relation_label_list[rel_predictions[sent_ind]], arg1_start=sents[sent_ind].index(""), arg2_start=sents[sent_ind].index(""))] ) + + + results = TemporalResults(timexes=timex_results, events=event_results, relations=rel_results) + + postproc_end = time() + + preproc_time = preproc_end - start_time + pred_time = pred_end - preproc_end + postproc_time = postproc_end - pred_end + + logging.info("Pre-processing time: %f, processing time: %f, post-processing time %f" % (preproc_time, pred_time, postproc_time)) + + return results + + +@app.post("/temporal/collection_process_complete") +async def collection_process_complete(): + app.state.trainer = None + +def rest(): + import argparse + + parser = argparse.ArgumentParser(description='Run the http server for temporal relation extraction') + parser.add_argument('-p', '--port', type=int, help='The port number to run the server on', default=8000) + + args = parser.parse_args() + + import uvicorn + uvicorn.run("tlink_rest:app", host='0.0.0.0', port=args.port, reload=True) + + +if __name__ == '__main__': + rest() + diff --git a/src/cnlpt/cnlp_metrics.py b/src/cnlpt/cnlp_metrics.py index dc5725bb..d404bd21 100644 --- a/src/cnlpt/cnlp_metrics.py +++ b/src/cnlpt/cnlp_metrics.py @@ -10,12 +10,13 @@ f1_score, precision_score, recall_score, + precision_recall_fscore_support, ) from .cnlp_processors import classification, relex, tagging logger = logging.getLogger(__name__) - +import pdb def fix_np_types(input_variable): """ @@ -31,6 +32,10 @@ def fix_np_types(input_variable): return input_variable +def is_tagged(input_tag, task_name): + return "O" if input_tag == "O" else f"B-{task_name}" + + def tagging_metrics( label_set: list[str], preds: np.ndarray, @@ -69,6 +74,9 @@ def tagging_metrics( pred_seq = [label_set[x] for x in preds] 
label_seq = [label_set[x] for x in labels] + is_pred_entity = [is_tagged(p, task_name.upper()) for p in pred_seq] + is_label_entity = [is_tagged(l, task_name.upper()) for l in label_seq] + num_correct = (preds == labels).sum() acc = num_correct / len(preds) @@ -78,7 +86,7 @@ def tagging_metrics( "acc": acc, "token_f1": fix_np_types(f1), "f1": fix_np_types(seq_f1([label_seq], [pred_seq])), - "report": "\n" + seq_cls([label_seq], [pred_seq]), + "report": "\n" + seq_cls([is_label_entity], [is_pred_entity]) + "\n" + seq_cls([label_seq], [pred_seq]), } diff --git a/src/cnlpt/cnlp_predict.py b/src/cnlpt/cnlp_predict.py index ce37ffc8..57195fdb 100644 --- a/src/cnlpt/cnlp_predict.py +++ b/src/cnlpt/cnlp_predict.py @@ -12,7 +12,6 @@ import pandas as pd import tqdm from transformers import EvalPrediction - from .cnlp_processors import classification, relex, tagging logger = logging.getLogger(__name__) @@ -178,7 +177,7 @@ def compute_disagreements( def classification_disagreements(preds: np.ndarray, labels: np.ndarray) -> np.ndarray: (indices,) = np.where(np.not_equal(preds, labels)) - return indices + return [i for i in range(len(preds))] def relation_or_tagging_disagreements( @@ -190,7 +189,7 @@ def relation_or_tagging_disagreements( for pred, label in zip(preds.astype(int), labels.astype(int)) ] ) - return indices + return [i for i in range(len(preds))] def process_prediction( @@ -254,9 +253,9 @@ def process_prediction( out_table["text"] = list(eval_dataset["text"]) out_table["text"] = out_table["text"].apply(remove_newline) - out_table["text"] = out_table["text"].str.replace('"', "") - out_table["text"] = out_table["text"].str.replace("//", "") - out_table["text"] = out_table["text"].str.replace("\\", "") + out_table["text"] = out_table["text"].str.replace('"', "'") + # out_table["text"] = out_table["text"].str.replace("//", "") + # out_table["text"] = out_table["text"].str.replace("\\", "") word_ids = eval_dataset["word_ids"] for task_name, packet in tqdm.tqdm( 
task_to_label_packet.items(), desc="getting human readable labels" @@ -308,14 +307,13 @@ def get_outputs( if error_analysis: if len(error_inds) > 0: relevant_prob_values = ( - prob_values[error_inds] + prob_values#[error_inds] if output_mode[pred_task] == classification and len(prob_values) > 0 else np.array([]) ) - ground_truth = labels[error_inds].astype(int) - task_prediction = prediction[error_inds].astype(int) - text_samples = pd.Series(text_column[error_inds]) - word_ids = [word_ids[error_ind] for error_ind in error_inds] + ground_truth = labels.astype(int) + task_prediction = prediction.astype(int) + text_samples = pd.Series(text_column) else: return pd.Series([]) else: @@ -334,6 +332,13 @@ def get_outputs( ) elif task_type == tagging: + out = [ + " ".join([task_labels[pred] for pred, label in zip(sent, truth) if label > -100]) + for sent, truth in zip(task_prediction, ground_truth) + ] + + return pd.Series(out) + return get_tagging_prints( character_level, pred_task, @@ -362,8 +367,6 @@ def get_classification_prints( def clean_string(gp: tuple[str, str]) -> str: ground, predicted = gp - if ground == predicted: - return f"_{task_name}_error_detection_bug_" return f"Ground: {ground} Predicted: {predicted}" pred_list = predicted_labels @@ -514,11 +517,12 @@ def get_error_out_string( ) -> str: instance_tokens = get_tokens(instance) ground_string = dict_to_str(disagreements["ground"], instance_tokens) + if len(ground_string) == 0: + ground_string = "None" predicted_string = dict_to_str(disagreements["predicted"], instance_tokens) - - if len(ground_string) == 0 == len(predicted_string): - return f"_{task_name}_error_detection_bug_" + if len(predicted_string) == 0: + predicted_string = "None" return f"Ground: {ground_string} Predicted: {predicted_string}" diff --git a/src/cnlpt/cnlp_processors.py b/src/cnlpt/cnlp_processors.py index 05b010c9..e5a4dfb6 100644 --- a/src/cnlpt/cnlp_processors.py +++ b/src/cnlpt/cnlp_processors.py @@ -167,7 +167,7 @@ def 
__init__(self, data_dir: str, tasks: set[str] = None, max_train_items=-1): else: sep = "\t" - self.dataset = load_dataset("csv", sep=sep, data_files=data_files) + self.dataset = load_dataset("csv", sep=sep, data_files=data_files, keep_default_na=False) ## find out what tasks are available to this dataset, and see the overlap with what the ## user specified at the cli, remove those tasks so we don't also get them from other datasets diff --git a/src/cnlpt/get_tlink.py b/src/cnlpt/get_tlink.py new file mode 100644 index 00000000..4c915e3c --- /dev/null +++ b/src/cnlpt/get_tlink.py @@ -0,0 +1,163 @@ +import os +import sys +import pdb +from tqdm import tqdm + +import anafora +from anafora import AnaforaData, AnaforaRelation +import requests + +# sentence and token splitters: +from spacy.lang.en import English +nlp = English() +tokenizer = nlp.tokenizer + +xml_name_regex = r'Temporal_(Entity|Relation)\.dave\.(completed|inprogress)\.xml' + +reverse_label = { + "AFTER": "BEFORE", + "BEFORE": "AFTER", + "BEGINS-ON": "BEGINS-ON-1", + "BEGINS-ON-1": "BEGINS-ON", + "CONTAINS": "CONTAINS-1", + "CONTAINS-1": "CONTAINS", + "CONTAINS-SUBEVENT": "CONTAINS-SUBEVENT-1", + "CONTAINS-SUBEVENT-1": "CONTAINS-SUBEVENT", + "ENDS-ON": "ENDS-ON-1", + "ENDS-ON-1": "ENDS-ON", + "NOTED-ON": "NOTED-ON-1", + "NOTED-ON-1": "NOTED-ON", + "OVERLAP": "OVERLAP" +} + +relabel = { + "BEGINS-ON-1": "BEFORE", + "CONTAINS-SUBEVENT": "CONTAINS", + "ENDS-ON": "BEFORE", + "NOTED-ON": "OVERLAP", + "NOTED-ON-1": "OVERLAP" +} + +should_switch = { + "AFTER": "BEFORE", + "BEGINS-ON": "BEFORE", + "CONTAINS-1": "CONTAINS", + "CONTAINS-SUBEVENT-1": "CONTAINS", + "ENDS-ON-1": "BEFORE" +} + + +def order_pair(arg1, label, arg2): + if arg1.spans > arg2.spans: + print("THIS SHOULD NOT HAPPEN") + temp = arg1 + arg1 = arg2 + arg2 = temp + label = reverse_label[label] + if label in should_switch: + temp = arg1 + arg1 = arg2 + arg2 = temp + label = should_switch[label] + elif label in relabel: + label = relabel[label] + 
return arg1, label, arg2 + + +def main(args): + if len(args) < 3: + sys.stderr.write("Required arguments: [text_dir]") + sys.exit(-1) + + text_dir = args[3] if len(args) > 3 else args[0] + + url = "http://%s/temporal/process" % args[1] + for sub_dir, text_name, xml_names in tqdm(anafora.walk(args[0], xml_name_regex)): + rel_idx = 0 + # print("Processing filename: %s" % (text_name)) + if len(xml_names) > 1: + sys.stderr.write('There were multiple valid xml files for file %s\n' % (text_name)) + filtered_names = [] + for xml_name in xml_names: + if 'dave' in xml_name: + filtered_names.append(xml_name) + if len(filtered_names) == 1: + sys.stderr.write('Picking the file with "dave" in the title: %s\n' % (filtered_names[0]) ) + xml_names = filtered_names + else: + sys.exit(-1) + xml_name = xml_names[0] + # if os.path.exists(os.path.join(args[2], sub_dir, xml_name)): + # continue + + entity_data = AnaforaData.from_file(os.path.join(args[0], sub_dir, xml_name)) + with open(os.path.join(text_dir, sub_dir, text_name)) as f: + full_text = f.read() + + to_remove = [] + for annot in entity_data.annotations: + if annot.type not in ["EVENT", "TIMEX3"]: + to_remove.append(annot) + for annot in to_remove: + entity_data.annotations.remove(annot) + + tokenized_text = tokenizer(full_text) + span_to_token = {(token.idx, token.idx+len(token)): token.i for token in tokenized_text} + for entity in entity_data.annotations: + try: + entity.tokens = (span_to_token[entity.spans[0]], span_to_token[entity.spans[0]] + 1) + except KeyError: + start_token, end_token = -1, -1 + for key in span_to_token: + if key[0] <= entity.spans[0][0] and key[1] >= entity.spans[0][0]: + start_token = span_to_token[key] + if key[0] <= entity.spans[0][1] and key[1] >= entity.spans[0][1]: + end_token = span_to_token[key] + 1 + if start_token > -1 and end_token > -1: + entity.tokens = (start_token, end_token) + else: + pdb.set_trace() + + # sorting entities for efficiency + entity_annots = [a for a in 
entity_data.annotations] # if a.type in ["EVENT", "TIMEX3"]] + sorted_entity_annots = sorted(entity_annots, key=lambda x: x.spans[0]) + + for i, ent0 in enumerate(sorted_entity_annots): + for j, ent1 in enumerate(sorted_entity_annots[i+1:]): + # if ent1.type not in ["EVENT", "TIMEX3"] or ent0 == ent1 or ent0.tokens[0] > ent1.tokens[0] or ent1.tokens[1] - ent0.tokens[0] > 100: + # continue + if ent1.tokens[1] - ent0.tokens[0] > 20: + break # we've gone too far; skip to the next ent0 + text_start = max(0, ent0.tokens[0] - 12) + text_end = ent1.tokens[1] + 12 + sent = [token.text for token in tokenized_text[text_start:ent0.tokens[0]]] + sent += [""] + [token.text for token in tokenized_text[ent0.tokens[0]:ent0.tokens[1]]] + [""] + sent += [token.text for token in tokenized_text[ent0.tokens[1]:ent1.tokens[0]]] + sent += [""] + [token.text for token in tokenized_text[ent1.tokens[0]:ent1.tokens[1]]] + [""] + sent += [token.text for token in tokenized_text[ent1.tokens[1]:text_end]] + + _r = requests.post(url, json={'sent_tokens': [sent], 'metadata':text_name}) + if _r.status_code != 200: + sys.stderr.write('Error: tlink rest call was not successful\n') + sys.exit(-1) + + if _r.json()["relations"][0][0]["category"] != "None": + new_rel = AnaforaRelation(_annotations=entity_data.annotations) + new_rel.id = f"{rel_idx}@r@{text_name}" + new_rel.type = "TLINK" + ent0, rel_type, ent1 = order_pair(ent0, _r.json()["relations"][0][0]["category"], ent1) + new_rel.properties["Source"] = ent0 + new_rel.properties["Type"] = rel_type + new_rel.properties["Target"] = ent1 + entity_data.annotations.append(new_rel) + # annots_to_add.append(new_rel) + rel_idx += 1 + + # for annot in annots_to_add: + # entity_data.annotations.append(annot) + os.makedirs(os.path.join(args[2], sub_dir), exist_ok=True) + entity_data.to_file(os.path.join(args[2], sub_dir, xml_name)) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/src/cnlpt/train_system.py b/src/cnlpt/train_system.py index 
bdb335be..6aa882cc 100644 --- a/src/cnlpt/train_system.py +++ b/src/cnlpt/train_system.py @@ -309,7 +309,10 @@ def main( dataset.tasks_to_labels[task] = dataset.tasks_to_labels[task][1:] + [ dataset.tasks_to_labels[task][0] ] - labels = dataset.processed_dataset["train"][task] + if tagger[task]: + labels = [token_label for sent in dataset.processed_dataset["train"][task] for token_label in sent.split()] + else: + labels = dataset.processed_dataset["train"][task] weights = [] label_counts = Counter(labels) for label in dataset.tasks_to_labels[task]: @@ -446,6 +449,7 @@ def main( # TODO check when download any pretrained language model to local disk, if # the following condition "is_hub_model(encoder_name)" works or not. + # ^ is_hub_model and is_external_encoder both return False, as long as "model_type": "cnlpt" is in config.json if not is_external_encoder(encoder_name): # we are loading one of our own trained models as a starting point. # @@ -459,7 +463,6 @@ def main( # the model file to be loaded down below the normal way. since that temp file # doesn't have a stored classifier it will use the randomly-inited classifier head # with the size of the supplied config (for the new task). - # TODO This setting 1) is not tested yet. # 2) if training_args.do_train is false: # we evaluate or make predictions of our trained models. # Both two setting require the registeration of CnlpConfig, and use @@ -468,6 +471,11 @@ def main( # Load the cnlp configuration using AutoConfig, this will not override # the arguments from trained cnlp models. While using CnlpConfig will override # the model_type and model_name of the encoder. + if model_args.keep_existing_classifiers == model_args.ignore_existing_classifiers: # XNOR + raise ValueError( + "For continued training of a cnlpt model, one of --keep_existing_classifiers or --ignore_existing_classifiers flags should be selected." 
+ ) + config = AutoConfig.from_pretrained( ( model_args.config_name @@ -477,41 +485,56 @@ def main( cache_dir=model_args.cache_dir, # in this case we're looking at a fine-tuned model (?) character_level=data_args.character_level, + layer=model_args.layer, ) - if training_args.do_train: # Setting 1) only load weights from the encoder - raise NotImplementedError( - "This functionality has not been restored yet" - ) + if model_args.ignore_existing_classifiers: + config.finetuning_task = ( + data_args.task_name + if data_args.task_name is not None + else dataset.tasks + ) + elif model_args.keep_existing_classifiers: + # setting 2) evaluate or make predictions + if ( + config.finetuning_task != data_args.task_name + or config.relations != relations + or config.tagger != tagger + ): + raise ValueError( + "When --keep_existing_classifiers is selected, please ensure" + "that you set the settings the same as those used in the" + "previous training run." + ) + model = CnlpModelForClassification( - model_path=model_args.encoder_name, config=config, - cache_dir=model_args.cache_dir, - tagger=tagger, - relations=relations, - class_weights=dataset.class_weights, + # class_weights=dataset.class_weights, + class_weights=class_weights, final_task_weight=training_args.final_task_weight, - use_prior_tasks=model_args.use_prior_tasks, - argument_regularization=model_args.arg_reg, ) - delattr(model, "classifiers") - delattr(model, "feature_extractors") + if model_args.ignore_existing_classifiers: + model.remove_task_classifiers() + for task in data_args.task_name: + model.add_task_classifier(task, dataset.get_labels()[task]) + model.set_class_weights(dataset.class_weights) + if training_args.do_train: tempmodel = tempfile.NamedTemporaryFile(dir=model_args.cache_dir) torch.save(model.state_dict(), tempmodel) model_name = tempmodel.name - else: + else: # load existing head # setting 2) evaluate or make predictions model = CnlpModelForClassification.from_pretrained( 
model_args.encoder_name, config=config, - class_weights=dataset.class_weights, + class_weights=class_weights, final_task_weight=training_args.final_task_weight, freeze=training_args.freeze, bias_fit=training_args.bias_fit, ) - + model.tasks = data_args.task_name else: # This only works when model_args.encoder_name is one of the # model card from https://huggingface.co/models @@ -541,7 +564,7 @@ def main( config.vocab_size = len(tokenizer) model = CnlpModelForClassification( config=config, - class_weights=dataset.class_weights, + class_weights=class_weights, final_task_weight=training_args.final_task_weight, freeze=training_args.freeze, bias_fit=training_args.bias_fit, @@ -656,15 +679,22 @@ def compute_metrics_fn(p: EvalPrediction): raise RuntimeError( f"Unrecognized label type: {type(training_args.model_selection_label)}" ) - else: # same default as in 0.6.0 + elif dataset.output_modes[task] == relex: task_scores.append( metrics[task_name].get( "one_score", np.mean(metrics[task_name].get("f1")) ) ) + else: + task_scores.append( + metrics[task_name].get( + "one_score", np.mean(metrics[task_name].get("token_f1")) + ) + ) # task_scores.append(processor.get_one_score(metrics.get(task_name, metrics.get(task_name.split('-')[0], None)))) one_score = sum(task_scores) / len(task_scores) + metrics["one_score"] = one_score if model is not None: if not hasattr(model, "best_score") or one_score > model.best_score: @@ -675,7 +705,7 @@ def compute_metrics_fn(p: EvalPrediction): model.best_eval_results = metrics if trainer.is_world_process_zero(): if training_args.do_train: - trainer.save_model() + trainer.save_model() # NOTE: a RobertaConfig is loaded here. why? 
tokenizer.save_pretrained(training_args.output_dir) if model_name == "cnn" or model_name == "lstm": with open( @@ -690,7 +720,7 @@ def compute_metrics_fn(p: EvalPrediction): ) config_dict["task_names"] = task_names json.dump(config_dict, f) - for task_ind, task_name in enumerate(metrics): + for task_ind, task_name in enumerate(task_names): with open(output_eval_file, "a") as writer: logger.info( f"***** Eval results for task {task_name} *****" @@ -720,7 +750,8 @@ def compute_metrics_fn(p: EvalPrediction): return compute_metrics_fn # Initialize our Trainer - training_args.load_best_model_at_end = True + # training_args.load_best_model_at_end = True + # TODO the argument in CnlpTrainingArguments is `model_selection_score`. reconcile this with `metric_for_best_model`? training_args.metric_for_best_model = "one_score" trainer = Trainer( model=model, @@ -884,7 +915,7 @@ def compute_metrics_fn(p: EvalPrediction): out_table = process_prediction( task_names=dataset.tasks, - error_analysis=False, + error_analysis=training_args.error_analysis, output_prob=training_args.output_prob, character_level=data_args.character_level, task_to_label_packet=task_to_label_packet,