Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 45 additions & 8 deletions src/cnlpt/CnlpModelForClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,14 +427,17 @@ def compute_loss(
loss_fct = MSELoss()
task_loss = loss_fct(task_logits.view(-1), labels.view(-1))
else:
if self.class_weights[task_name] is not None:
class_weights = torch.FloatTensor(self.class_weights[task_name]).to(
self.device
)
if isinstance(self.class_weights, torch.Tensor):
class_weights = self.class_weights
else:
class_weights = None
loss_fct = CrossEntropyLoss(weight=class_weights)
if self.class_weights[task_name] is not None:
class_weights = torch.FloatTensor(self.class_weights[task_name]).to(
self.device
)
else:
class_weights = None

loss_fct = CrossEntropyLoss(weight=class_weights)
if self.relations[task_name]:
task_labels = labels[
:, :, state["task_label_ind"] : state["task_label_ind"] + seq_len
Expand All @@ -453,7 +456,7 @@ def compute_loss(
task_labels = labels[:, :, state["task_label_ind"]]
else:
task_labels = labels[:, 0, state["task_label_ind"], :]

state["task_label_ind"] += 1
task_loss = loss_fct(
task_logits.view(-1, task_num_labels),
Expand All @@ -477,7 +480,7 @@ def compute_loss(
"Have not implemented the case where a classification task "
"is part of an MTL setup with relations and sequence tagging"
)

state["task_label_ind"] += 1
task_loss = loss_fct(
task_logits, task_labels.type(torch.LongTensor).to(labels.device)
Expand All @@ -491,6 +494,40 @@ def compute_loss(
)
state["loss"] += task_weight * task_loss

def remove_task_classifiers(self, tasks: Union[list[str], None] = None):
    """Remove task-specific classification heads from this model.

    Args:
        tasks: Names of the tasks whose classifier heads should be removed.
            If None, ALL classifiers are removed and the task list and
            class-weight mapping are cleared.

    Raises:
        KeyError: if a named task has no registered classifier / weights.
        ValueError: if a named task is not present in ``self.tasks``.
    """
    if tasks is None:
        # Full reset: an empty ModuleDict keeps the module tree valid.
        self.classifiers = nn.ModuleDict()
        self.tasks = []
        self.class_weights = {}
    else:
        for task in tasks:
            self.classifiers.pop(task)
            self.tasks.remove(task)
            self.class_weights.pop(task)

def add_task_classifier(self, task_name: str, label_dictionary: dict[str, list]):
    """Register a new classification head for ``task_name``.

    Fix: the original constructed a throwaway ``ClassificationHead`` that was
    always overwritten by the if/else below; that dead construction is removed.

    Args:
        task_name: Name of the task to add (appended to ``self.tasks``;
            NOTE(review): no duplicate check -- confirm callers never re-add).
        label_dictionary: Mapping of labels for this task; its length is the
            number of output classes.
    """
    self.tasks.append(task_name)
    task_num_labels = len(label_dictionary)
    if self.config.relations[task_name]:
        # Relation tasks size their head by the relation-attention width.
        self.classifiers[task_name] = ClassificationHead(
            self.config,
            task_num_labels,
            hidden_size=self.config.num_rel_attention_heads,
        )
    else:
        self.classifiers[task_name] = ClassificationHead(
            self.config, task_num_labels
        )
    self.label_dictionary[task_name] = label_dictionary

def set_class_weights(self, class_weights: Union[dict[str, list[float]], None] = None):
    """Set the per-task class weights used by the cross-entropy loss.

    Fix: the annotation said ``list[float]`` but ``compute_loss`` indexes
    ``self.class_weights`` by task name, so the real shape is a mapping from
    task name to a per-class weight list.

    Args:
        class_weights: Mapping from task name to per-class weights, or None
            to register every task in ``self.label_dictionary`` as unweighted.
    """
    if class_weights is None:
        # One entry per known task, all unweighted (value None).
        self.class_weights = dict.fromkeys(self.label_dictionary)
    else:
        self.class_weights = class_weights

def forward(
self,
input_ids=None,
Expand Down
113 changes: 113 additions & 0 deletions src/cnlpt/api/tlink_rest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from fastapi import FastAPI
import numpy as np
from cnlp_rest import create_instance_string, initialize_cnlpt_model, create_dataset
import logging
from time import time
from nltk.tokenize import wordpunct_tokenize as tokenize

from temporal_rest import TokenizedSentenceDocument, SentenceDocument, Relation, TemporalResults, create_instance_string

Check failure on line 25 in src/cnlpt/api/tlink_rest.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F811)

src/cnlpt/api/tlink_rest.py:25:99: F811 Redefinition of unused `create_instance_string` from line 20

app = FastAPI()

Check failure on line 27 in src/cnlpt/api/tlink_rest.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

src/cnlpt/api/tlink_rest.py:18:1: I001 Import block is un-sorted or un-formatted
# Path to the fine-tuned tlink checkpoint -- deployment-specific relative
# path; TODO(review): make this configurable rather than hard-coded.
model_name = "../../../../../thyme/ft/models/inference/tlink/5688819"
logger = logging.getLogger('Tlink_REST_Processor')
logger.setLevel(logging.INFO)

# Maximum sequence length (in tokens) used when building model inputs.
max_length = 128

# Closed label set for the 'tlink' relation task; the "-1"-suffixed labels
# presumably denote the inverted argument order -- confirm against training data.
relation_label_list = ["AFTER", "BEFORE", "BEGINS-ON", "BEGINS-ON-1", "CONTAINS", "CONTAINS-1", "ENDS-ON", "ENDS-ON-1", "NOTED-ON", "NOTED-ON-1", "None", "OVERLAP"]

@app.on_event("startup")
async def startup_event():
    # Load the cnlpt model/tokenizer/trainer into app.state once at startup.
    initialize_cnlpt_model(app, model_name)

@app.post("/temporal/process")
async def process(doc: TokenizedSentenceDocument):
    """Predict tlink relations for a pre-tokenized document."""
    return process_tokenized_sentence_document(doc)

@app.post("/temporal/process_sentence")
async def process_sentence(doc: SentenceDocument):
    """Tokenize a raw sentence, wrap it as a one-sentence document, and process it."""
    sentence_tokens = tokenize(doc.sentence)
    wrapped_doc = TokenizedSentenceDocument(
        sent_tokens=[sentence_tokens],
        metadata='Single sentence',
    )
    return process_tokenized_sentence_document(wrapped_doc)

def process_tokenized_sentence_document(doc: TokenizedSentenceDocument):
    """Run tlink relation prediction over a tokenized document.

    Each sentence is expected to contain ``<e1>`` and ``<e2>`` marker tokens;
    one predicted relation between those arguments is returned per sentence.

    Fixes (all flagged by CI/ruff): deprecated ``logger.warn`` -> ``warning``
    (G010), eager ``%`` string building in logging calls -> lazy args
    (G002/UP031), and the final timing message now uses the module ``logger``
    instead of the root logger.

    Args:
        doc: Document whose ``sent_tokens`` is a list of token lists.

    Returns:
        TemporalResults with one relation list per sentence; the timex and
        event lists are always empty for this endpoint.
    """
    sents = doc.sent_tokens
    metadata = doc.metadata

    logger.warning('Received document labeled %s with %d sentences', metadata, len(sents))

    instances = []
    start_time = time()

    for sent_ind, token_list in enumerate(sents):
        inst_str = create_instance_string(token_list)
        logger.debug('Instance string is %s', inst_str)
        instances.append(inst_str)

    dataset = create_dataset(
        instances,
        app.state.tokenizer,
        label_lists=[relation_label_list],
        tasks=['tlink'],
        max_length=max_length,
    )
    preproc_end = time()

    output = app.state.trainer.predict(test_dataset=dataset)

    # One relation prediction per sentence: argmax over the label dimension.
    rel_predictions = np.argmax(output.predictions[0], axis=1)

    pred_end = time()

    timex_results = []
    event_results = []
    rel_results = []

    for sent_ind in range(len(dataset)):
        event_results.append([])
        timex_results.append([])
        # NOTE(review): list.index() raises ValueError if a sentence lacks the
        # <e1>/<e2> marker tokens -- confirm callers always include them.
        rel_results.append([
            Relation(
                arg1=None,
                arg2=None,
                category=relation_label_list[rel_predictions[sent_ind]],
                arg1_start=sents[sent_ind].index("<e1>"),
                arg2_start=sents[sent_ind].index("<e2>"),
            )
        ])

    results = TemporalResults(
        timexes=timex_results, events=event_results, relations=rel_results
    )

    postproc_end = time()

    preproc_time = preproc_end - start_time
    pred_time = pred_end - preproc_end
    postproc_time = postproc_end - pred_end

    logger.info(
        "Pre-processing time: %f, processing time: %f, post-processing time %f",
        preproc_time, pred_time, postproc_time,
    )

    return results


@app.post("/temporal/collection_process_complete")
async def collection_process_complete():
    # Drop the trainer reference when the collection is finished,
    # presumably to release model memory -- confirm server is not reused after.
    app.state.trainer = None

def rest():
    """CLI entry point: parse the port argument and serve the app via uvicorn."""
    import argparse

    cli = argparse.ArgumentParser(description='Run the http server for temporal relation extraction')
    cli.add_argument('-p', '--port', type=int, help='The port number to run the server on', default=8000)
    options = cli.parse_args()

    import uvicorn
    uvicorn.run("tlink_rest:app", host='0.0.0.0', port=options.port, reload=True)


if __name__ == '__main__':
rest()

12 changes: 10 additions & 2 deletions src/cnlpt/cnlp_metrics.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import logging
from typing import Any

import numpy as np
from seqeval.metrics import classification_report as seq_cls
from seqeval.metrics import f1_score as seq_f1
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
precision_score,
recall_score,
precision_recall_fscore_support,

Check failure on line 13 in src/cnlpt/cnlp_metrics.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

src/cnlpt/cnlp_metrics.py:13:5: F401 `sklearn.metrics.precision_recall_fscore_support` imported but unused
)

from .cnlp_processors import classification, relex, tagging

logger = logging.getLogger(__name__)

Check failure on line 18 in src/cnlpt/cnlp_metrics.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

src/cnlpt/cnlp_metrics.py:1:1: I001 Import block is un-sorted or un-formatted

import pdb

def fix_np_types(input_variable):
"""
Expand All @@ -31,6 +32,10 @@
return input_variable


def is_tagged(input_tag, task_name):
    """Collapse a BIO tag to a binary entity-vs-outside tag for *task_name*.

    "O" stays "O"; any other tag becomes ``B-<task_name>``.
    """
    if input_tag == "O":
        return "O"
    return f"B-{task_name}"


def tagging_metrics(
label_set: list[str],
preds: np.ndarray,
Expand Down Expand Up @@ -69,6 +74,9 @@
pred_seq = [label_set[x] for x in preds]
label_seq = [label_set[x] for x in labels]

is_pred_entity = [is_tagged(p, task_name.upper()) for p in pred_seq]
is_label_entity = [is_tagged(l, task_name.upper()) for l in label_seq]

num_correct = (preds == labels).sum()

acc = num_correct / len(preds)
Expand All @@ -78,7 +86,7 @@
"acc": acc,
"token_f1": fix_np_types(f1),
"f1": fix_np_types(seq_f1([label_seq], [pred_seq])),
"report": "\n" + seq_cls([label_seq], [pred_seq]),
"report": "\n" + seq_cls([is_label_entity], [is_pred_entity]) + "\n" + seq_cls([label_seq], [pred_seq]),
}


Expand Down
36 changes: 20 additions & 16 deletions src/cnlpt/cnlp_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import pandas as pd
import tqdm
from transformers import EvalPrediction

from .cnlp_processors import classification, relex, tagging

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -178,7 +177,7 @@ def compute_disagreements(

def classification_disagreements(preds: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Return the indices at which predictions differ from gold labels.

    NOTE(review): the diff replaced the return with
    ``[i for i in range(len(preds))]`` (every index), which contradicts both
    the function's name and the ``np.ndarray`` return annotation and leaves
    the ``np.where`` computation dead. Restoring disagreement semantics; if
    "output all instances" is genuinely intended, rename the function and
    delete the dead computation instead.
    """
    (indices,) = np.where(np.not_equal(preds, labels))
    return indices


def relation_or_tagging_disagreements(
Expand All @@ -190,7 +189,7 @@ def relation_or_tagging_disagreements(
for pred, label in zip(preds.astype(int), labels.astype(int))
]
)
return indices
return [i for i in range(len(preds))]


def process_prediction(
Expand Down Expand Up @@ -254,9 +253,9 @@ def process_prediction(
out_table["text"] = list(eval_dataset["text"])
out_table["text"] = out_table["text"].apply(remove_newline)

out_table["text"] = out_table["text"].str.replace('"', "")
out_table["text"] = out_table["text"].str.replace("//", "")
out_table["text"] = out_table["text"].str.replace("\\", "")
out_table["text"] = out_table["text"].str.replace('"', "'")
# out_table["text"] = out_table["text"].str.replace("//", "")
# out_table["text"] = out_table["text"].str.replace("\\", "")
word_ids = eval_dataset["word_ids"]
for task_name, packet in tqdm.tqdm(
task_to_label_packet.items(), desc="getting human readable labels"
Expand Down Expand Up @@ -308,14 +307,13 @@ def get_outputs(
if error_analysis:
if len(error_inds) > 0:
relevant_prob_values = (
prob_values[error_inds]
prob_values#[error_inds]
if output_mode[pred_task] == classification and len(prob_values) > 0
else np.array([])
)
ground_truth = labels[error_inds].astype(int)
task_prediction = prediction[error_inds].astype(int)
text_samples = pd.Series(text_column[error_inds])
word_ids = [word_ids[error_ind] for error_ind in error_inds]
ground_truth = labels.astype(int)
task_prediction = prediction.astype(int)
text_samples = pd.Series(text_column)
else:
return pd.Series([])
else:
Expand All @@ -334,6 +332,13 @@ def get_outputs(
)

elif task_type == tagging:
out = [
" ".join([task_labels[pred] for pred, label in zip(sent, truth) if label > -100])
for sent, truth in zip(task_prediction, ground_truth)
]

return pd.Series(out)

return get_tagging_prints(
character_level,
pred_task,
Expand Down Expand Up @@ -362,8 +367,6 @@ def get_classification_prints(

def clean_string(gp: tuple[str, str]) -> str:
ground, predicted = gp
if ground == predicted:
return f"_{task_name}_error_detection_bug_"
return f"Ground: {ground} Predicted: {predicted}"

pred_list = predicted_labels
Expand Down Expand Up @@ -514,11 +517,12 @@ def get_error_out_string(
) -> str:
instance_tokens = get_tokens(instance)
ground_string = dict_to_str(disagreements["ground"], instance_tokens)
if len(ground_string) == 0:
ground_string = "None"

predicted_string = dict_to_str(disagreements["predicted"], instance_tokens)

if len(ground_string) == 0 == len(predicted_string):
return f"_{task_name}_error_detection_bug_"
if len(predicted_string) == 0:
predicted_string = "None"

return f"Ground: {ground_string} Predicted: {predicted_string}"

Expand Down
2 changes: 1 addition & 1 deletion src/cnlpt/cnlp_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(self, data_dir: str, tasks: set[str] = None, max_train_items=-1):
else:
sep = "\t"

self.dataset = load_dataset("csv", sep=sep, data_files=data_files)
self.dataset = load_dataset("csv", sep=sep, data_files=data_files, keep_default_na=False)

## find out what tasks are available to this dataset, and see the overlap with what the
## user specified at the cli, remove those tasks so we don't also get them from other datasets
Expand Down
Loading
Loading