diff --git a/src/cnlpt/_cli/rest.py b/src/cnlpt/_cli/rest.py index dd857277..60a81da6 100644 --- a/src/cnlpt/_cli/rest.py +++ b/src/cnlpt/_cli/rest.py @@ -1,13 +1,43 @@ +from typing import Union + import click -from ..api import MODEL_TYPES, get_rest_app + +def parse_models( + ctx: click.Context, + param: click.Parameter, + value: Union[tuple[str, ...], None], +): + if value is None: + return None + + models: list[tuple[str, str]] = [] + for item in value: + if "=" in item: + prefix, path = item.split("=", 1) + if not prefix.startswith("/"): + raise click.BadParameter( + f"route prefix must start with '/': {prefix}", param=param + ) + elif len(value) > 1: + raise click.BadParameter( + "route prefixes are required when serving more than one model", + param=param, + ) + else: + path = item + prefix = "" + models.append((prefix, path)) + return models @click.command("rest", context_settings={"show_default": True}) @click.option( - "--model-type", - type=click.Choice(MODEL_TYPES), - required=True, + "--model", + "models", + multiple=True, + callback=parse_models, + help="Model definition as [ROUTER_PREFIX=]PATH_TO_MODEL. Prefix must start with '/'.", ) @click.option( "-h", @@ -19,15 +49,13 @@ @click.option( "-p", "--port", type=int, default=8000, help="Port to serve the REST app." ) -@click.option( - "--reload", - type=bool, - is_flag=True, - default=False, - help="Auto-reload the REST app.", -) -def rest_command(model_type: str, host: str, port: int, reload: bool): +def rest_command(models: list[tuple[str, str]], host: str, port: int): """Start a REST application from a model.""" import uvicorn - uvicorn.run(get_rest_app(model_type), host=host, port=port, reload=reload) + from ..rest import CnlpRestApp + + app = CnlpRestApp.multi_app( + [(CnlpRestApp(model_path=path), prefix) for prefix, path in models] + ) + uvicorn.run(app, host=host, port=port) diff --git a/src/cnlpt/api/__init__.py b/src/cnlpt/api/__init__.py deleted file mode 100644 index a3e8592f..00000000 --- a/src/cnlpt/api/__init__.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Serve REST APIs for CNLPT models over your network.""" - -from typing import Final - -MODEL_TYPES: Final = ( - "cnn", - "current", - "dtr", - # "event", - "hier", - "negation", - "temporal", - # "termexists", - # "timex", -) -"""The available model types for :func:`get_rest_app`.""" - - -def get_rest_app(model_type: str): - """Get a FastAPI app for a certain model type. - - Args: - model_type: The type of model to serve. - - Returns: - The FastAPI app. - """ - if model_type == "cnn": - from .cnn_rest import app - - return app - elif model_type == "current": - from .current_rest import app - - return app - elif model_type == "dtr": - from .dtr_rest import app - - return app - # elif model_type == "event": - # from .event_rest import app - - # return app - elif model_type == "hier": - from .hier_rest import app - - return app - elif model_type == "negation": - from .negation_rest import app - - return app - elif model_type == "temporal": - from .temporal_rest import app - - return app - # elif model_type == "termexists": - # from .termexists_rest import app - - # return app - # elif model_type == "timex": - # from .timex_rest import app - - # return app - else: - raise ValueError(f"unknown model type: {model_type}") - - -__all__ = [ - "MODEL_TYPES", - "get_rest_app", -] diff --git a/src/cnlpt/api/cnn_rest.py b/src/cnlpt/api/cnn_rest.py deleted file mode 100644 index af556e95..00000000 --- a/src/cnlpt/api/cnn_rest.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -import logging -import os -import sys -from contextlib import asynccontextmanager -from os.path import join -from typing import Any - -import numpy as np -import torch -import torch.backends.mps -from fastapi import FastAPI -from scipy.special import softmax -from transformers import AutoTokenizer, PreTrainedTokenizer - -from ..models.baseline import CnnSentenceClassifier -from .utils import UnannotatedDocument, create_dataset, resolve_device - -MODEL_NAME = os.getenv("MODEL_PATH") -device = os.getenv("MODEL_DEVICE", "auto") -device = resolve_device(device) - -logger = logging.getLogger("CNN_REST_Processor") -logger.setLevel(logging.DEBUG) - -MAX_SEQ_LENGTH = 128 - -model: CnnSentenceClassifier -tokenizer: PreTrainedTokenizer -conf_dict: dict[str, Any] - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global model, tokenizer, conf_dict - if MODEL_NAME is None: - sys.stderr.write( - "This REST container requires a MODEL_PATH environment variable\n" - ) - sys.exit(-1) - conf_file = join(MODEL_NAME, "config.json") - with open(conf_file) as fp: - conf_dict = json.load(fp) - - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - num_labels_dict = { - task: len(values) for task, values in conf_dict["label_dictionary"].items() - } - model = CnnSentenceClassifier.from_pretrained( - MODEL_NAME, - vocab_size=len(tokenizer), - task_names=conf_dict["task_names"], - num_labels_dict=num_labels_dict, - embed_dims=conf_dict["cnn_embed_dim"], - num_filters=conf_dict["cnn_num_filters"], - filters=conf_dict["cnn_filter_sizes"], - ) - - model = model.to(device) - tokenizer = tokenizer - conf_dict = conf_dict - - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/cnn/classify") -async def process(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, tokenizer, max_length=conf_dict["max_seq_length"] - ) - _, logits = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(device), - ) - - prediction = int(np.argmax(logits[0].cpu().detach().numpy(), axis=1)) - result = conf_dict["label_dictionary"][conf_dict["task_names"][0]][prediction] - probabilities = softmax(logits[0][0].cpu().detach().numpy()) - # for redcap purposes, it might make more sense to only output the probability for the predicted class, - # but i'm outputting them all, for transparency - out_probabilities = [str(prob) for prob in probabilities] - return {"result": result, "probabilities": out_probabilities} diff --git a/src/cnlpt/api/current_rest.py b/src/cnlpt/api/current_rest.py deleted file mode 100644 index 682c5228..00000000 --- a/src/cnlpt/api/current_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -logger = logging.getLogger("Current_REST_Processor") -logger.setLevel(logging.DEBUG) - -MODEL_NAME = "mlml-chip/current-thyme" -TASK = "Current" -LABELS = [False, True] - -MAX_LENGTH = 128 - - -class CurrentResults(BaseModel): - """statuses: list of classifier outputs for every input""" - - statuses: list[bool] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/current/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return CurrentResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = CurrentResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/dtr_rest.py b/src/cnlpt/api/dtr_rest.py deleted file mode 100644 index ee85e6c9..00000000 --- a/src/cnlpt/api/dtr_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import OLD_DTR_LABEL_LIST -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "tmills/tiny-dtr" -logger = logging.getLogger("DocTimeRel Processor with xtremedistil encoder") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -class DocTimeRelResults(BaseModel): - """statuses: dictionary from entity id to classification decision about DocTimeRel""" - - statuses: list[str] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/dtr/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return DocTimeRelResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1])) - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(OLD_DTR_LABEL_LIST[predictions[ent_ind]]) - - output = DocTimeRelResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/event_rest.py b/src/cnlpt/api/event_rest.py deleted file mode 100644 index e2d34c75..00000000 --- a/src/cnlpt/api/event_rest.py +++ /dev/null @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import ( - EVENT_LABEL_LIST, - Event, - SentenceDocument, - TemporalResults, - TokenizedSentenceDocument, - create_instance_string, -) -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "tmills/event-thyme-colon-pubmedbert" -logger = logging.getLogger("Event_REST_Processor") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - - event_predictions = np.argmax(output.predictions[0], axis=2) - - pred_end = time() - - timex_results = [] - event_results = [] - rel_results = [] - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer.batch_encode_plus( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - event_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - event_labels.append( - EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]] - ) - previous_word_idx = word_idx - - event_entities = get_entities(event_labels) - logging.info(f"Extracted {len(event_entities)} events from the sentence") - event_results.append( - [ - Event(dtr=label[0], begin=label[1], end=label[2]) - for label in event_entities - ] - ) - timex_results.append([]) - rel_results.append([]) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=rel_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results - - -@app.post("/temporal/collection_process_complete") -async def collection_process_complete(): - global trainer - trainer = None diff --git a/src/cnlpt/api/hier_rest.py b/src/cnlpt/api/hier_rest.py deleted file mode 100644 index 66711866..00000000 --- a/src/cnlpt/api/hier_rest.py +++ /dev/null @@ -1,104 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -import os -from contextlib import asynccontextmanager - -import torch -from fastapi import FastAPI -from transformers import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - UnannotatedDocument, - create_dataset, - initialize_hier_model, - resolve_device, -) - -MODEL_NAME = os.getenv("MODEL_PATH") - -device = os.getenv("MODEL_DEVICE", "auto") -device = resolve_device(device) - -logger = logging.getLogger("HierRep_REST_Processor") -logger.setLevel(logging.DEBUG) - -tokenizer: PreTrainedTokenizer -model: PreTrainedModel - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, model - tokenizer, model = initialize_hier_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/hier/get_rep") -async def get_representation(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, - tokenizer, - max_length=16000, - hier=True, - chunk_len=200, - num_chunks=80, - insert_empty_chunk_at_beginning=False, - ) - - result = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device), - token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device), - output_hidden_states=True, - ) - - # Convert to a list so python can send it out - hidden_states = result["hidden_states"].to("cpu").detach().numpy()[:, 0, :].tolist() - return {"reps": hidden_states[0]} - - -@app.post("/hier/classify") -async def classify(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, - tokenizer, - max_length=16000, - hier=True, - chunk_len=200, - num_chunks=80, - insert_empty_chunk_at_beginning=False, - ) - result = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device), - token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device), - output_hidden_states=False, - ) - - predictions = [ - int(torch.argmax(logits.to("cpu").detach()).numpy()) - for logits in result["logits"] - ] - labels = [next(iter(model.label_dictionary.values()))[x] for x in predictions] - return {"result": labels} diff --git a/src/cnlpt/api/negation_rest.py b/src/cnlpt/api/negation_rest.py deleted file mode 100644 index e4ba78b2..00000000 --- a/src/cnlpt/api/negation_rest.py +++ /dev/null @@ -1,112 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "mlml-chip/negation_pubmedbert_sharpseed" -logger = logging.getLogger("Negation_REST_Processor") -logger.setLevel(logging.DEBUG) - -TASK = "Negation" -LABELS = [-1, 1] - -MAX_LENGTH = 128 - - -class NegationResults(BaseModel): - """statuses: dictionary from entity id to classification decision about negation; true -> negated, false -> not negated""" - - statuses: list[int] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/negation/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return NegationResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1])) - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = NegationResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output - - -@app.get("/negation/{test_str}") -async def test(test_str: str): - return {"argument": test_str} diff --git a/src/cnlpt/api/temporal_rest.py b/src/cnlpt/api/temporal_rest.py deleted file mode 100644 index 7b23e495..00000000 --- a/src/cnlpt/api/temporal_rest.py +++ /dev/null @@ -1,373 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -import os -from contextlib import asynccontextmanager -from time import time -from typing import Union - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from pydantic import BaseModel -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "mlml-chip/thyme2_colon_e2e" -logger = logging.getLogger("Temporal_REST_Processor") -logger.setLevel(logging.INFO) - -LABELS = ["-1", "1"] -TIMEX_LABEL_LIST = [ - "O", - "B-DATE", - "B-DURATION", - "B-PREPOSTEXP", - "B-QUANTIFIER", - "B-SET", - "B-TIME", - "B-SECTIONTIME", - "B-DOCTIME", - "I-DATE", - "I-DURATION", - "I-PREPOSTEXP", - "I-QUANTIFIER", - "I-SET", - "I-TIME", - "I-SECTIONTIME", - "I-DOCTIME", -] -TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)} -EVENT_LABEL_LIST = [ - "O", - "B-AFTER", - "B-BEFORE", - "B-BEFORE/OVERLAP", - "B-OVERLAP", - "I-AFTER", - "I-BEFORE", - "I-BEFORE/OVERLAP", - "I-OVERLAP", -] -EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)} - -RELATION_LABEL_LIST = ["None", "CONTAINS", "OVERLAP", "BEFORE", "BEGINS-ON", "ENDS-ON"] -RELATION_LABEL_DICT = {val: ind for ind, val in enumerate(RELATION_LABEL_LIST)} - -DTR_LABEL_LIST = ["AFTER", "BEFORE", "BEFORE/OVERLAP", "OVERLAP"] -OLD_DTR_LABEL_LIST = ["BEFORE", "OVERLAP", "BEFORE/OVERLAP", "AFTER"] - -LABELS = [TIMEX_LABEL_LIST, EVENT_LABEL_LIST, RELATION_LABEL_LIST] -MAX_LENGTH = 128 - - -class SentenceDocument(BaseModel): - sentence: str - - -class TokenizedSentenceDocument(BaseModel): - """sent_tokens: a list of sentences, where each sentence is a list of tokens""" - - sent_tokens: list[list[str]] - metadata: str - - -class Timex(BaseModel): - begin: int - end: int - timeClass: str - - -class Event(BaseModel): - begin: int - end: int - dtr: str - - -class Relation(BaseModel): - # Allow args to be none, so that we can potentially link them to times or events in the client, or if they don't - # care about that. pass back the token indices of the args in addition. - arg1: Union[str, None] - arg2: Union[str, None] - category: str - arg1_start: int - arg2_start: int - - -class TemporalResults(BaseModel): - """lists of timexes, events and relations for list of sentences""" - - timexes: list[list[Timex]] - events: list[list[Event]] - relations: list[list[Relation]] - - -def create_instance_string(tokens: list[str]): - return " ".join(tokens) - - -task_order: dict[str, int] -tasks: list[str] -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global \ - TIMEX_LABEL_LIST, \ - TIMEX_LABEL_DICT, \ - EVENT_LABEL_LIST, \ - EVENT_LABEL_DICT, \ - RELATION_LABEL_LIST, \ - RELATION_LABEL_DICT, \ - task_order, \ - tasks, \ - tokenizer, \ - trainer - - local_model_name = os.getenv("MODEL_NAME", MODEL_NAME) - tokenizer, trainer = initialize_cnlpt_model(local_model_name) - - config_dict = trainer.model.config.to_dict() - # For newer models (version >= 0.6.0), the label dictionary is saved with the model - # config. we can look for it to preserve backwards compatibility for now but - # should eventually remove the hardcoded label lists from our inference tools. - label_dict = config_dict.get("label_dictionary", None) - if label_dict is not None: - # some older versions have one label dictionary per dataset, future versions should just - # have a task-keyed dictionary - if type(label_dict) is list: - label_dict = label_dict[0] - - if "event" in label_dict: - EVENT_LABEL_LIST = label_dict["event"] - EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)} - print(EVENT_LABEL_LIST) - - if "timex" in label_dict: - TIMEX_LABEL_LIST = label_dict["timex"] - TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)} - print(TIMEX_LABEL_LIST) - - if "tlinkx" in label_dict: - RELATION_LABEL_LIST = label_dict["tlinkx"] - RELATION_LABEL_DICT = { - val: ind for ind, val in enumerate(RELATION_LABEL_LIST) - } - print(RELATION_LABEL_LIST) - - tasks = config_dict.get("finetuning_task", None) - task_order = {} - if tasks is not None: - print("Overwriting finetuning task order") - for task_ind, task_name in enumerate(tasks): - task_order[task_name] = task_ind - print(task_order) - else: - print("Didn't find a new task ordering in the model config") - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - print(EVENT_LABEL_LIST) - print(TIMEX_LABEL_LIST) - print(RELATION_LABEL_LIST) - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - - timex_predictions = np.argmax(output.predictions[task_order["timex"]], axis=2) - event_predictions = np.argmax(output.predictions[task_order["event"]], axis=2) - rel_predictions = np.argmax(output.predictions[task_order["tlinkx"]], axis=3) - rel_inds = np.where(rel_predictions != RELATION_LABEL_DICT["None"]) - - logging.debug(f"Found relation indices: {rel_inds!s}") - - rels_by_sent = {} - for rel_num in range(len(rel_inds[0])): - sent_ind = rel_inds[0][rel_num] - if sent_ind not in rels_by_sent: - rels_by_sent[sent_ind] = [] - - arg1_ind = rel_inds[1][rel_num] - arg2_ind = rel_inds[2][rel_num] - if arg1_ind == arg2_ind: - # no relations between an entity and itself - logger.warning("Found relation between an entity and itself... skipping") - continue - - rel_cat = rel_predictions[sent_ind, arg1_ind, arg2_ind] - - rels_by_sent[sent_ind].append((arg1_ind, arg2_ind, rel_cat)) - - pred_end = time() - - timex_results = [] - event_results = [] - rel_results = [] - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - timex_labels = [] - event_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - # tokeni_to_wpi[val] = key - timex_labels.append( - TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]] - ) - try: - event_labels.append( - EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]] - ) - except Exception as e: - print( - f"exception thrown when sent_ind={sent_ind} and word_pos_idx={word_pos_idx}" - ) - print( - f"prediction is {event_predictions[sent_ind][word_pos_idx]!s}" - ) - raise e - - previous_word_idx = word_idx - - timex_entities = get_entities(timex_labels) - logging.info( - f"Extracted {len(timex_entities)} timex entities from the sentence" - ) - timex_results.append( - [ - Timex(timeClass=label[0], begin=label[1], end=label[2]) - for label in timex_entities - ] - ) - - event_entities = get_entities(event_labels) - logging.info(f"Extracted {len(event_entities)} events from the sentence") - event_results.append( - [ - Event(dtr=label[0], begin=label[1], end=label[2]) - for label in event_entities - ] - ) - - rel_sent_results = [] - for rel in rels_by_sent.get(sent_ind, []): - arg1 = None - arg2 = None - if rel[0] not in wpind_to_ind or rel[1] not in wpind_to_ind: - logging.warning( - "Found a relation to a non-leading wordpiece token... ignoring" - ) - continue - - arg1_ind = wpind_to_ind[rel[0]] - arg2_ind = wpind_to_ind[rel[1]] - - sent_timexes = timex_results[-1] - for timex_ind, timex in enumerate(sent_timexes): - if timex.begin == arg1_ind: - arg1 = f"TIMEX-{timex_ind}" - if timex.begin == arg2_ind: - arg2 = f"TIMEX-{timex_ind}" - - sent_events = event_results[-1] - for event_ind, event in enumerate(sent_events): - if event.begin == arg1_ind: - arg1 = f"EVENT-{event_ind}" - if event.begin == arg2_ind: - arg2 = f"EVENT-{event_ind}" - - rel = Relation( - arg1=arg1, - arg2=arg2, - category=RELATION_LABEL_LIST[rel[2]], - arg1_start=arg1_ind, - arg2_start=arg2_ind, - ) - rel_sent_results.append(rel) - - rel_results.append(rel_sent_results) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=rel_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results diff --git a/src/cnlpt/api/termexists_rest.py b/src/cnlpt/api/termexists_rest.py deleted file mode 100644 index a6024513..00000000 --- a/src/cnlpt/api/termexists_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "mlml-chip/sharpseed-termexists" -logger = logging.getLogger("TermExists_REST_Processor") -logger.setLevel(logging.DEBUG) - -TASK = "TermExists" -LABELS = [-1, 1] - -MAX_LENGTH = 128 - - -class TermExistsResults(BaseModel): - """statuses: list of classifier outputs for every input""" - - statuses: list[int] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/termexists/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return TermExistsResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = TermExistsResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/timex_rest.py b/src/cnlpt/api/timex_rest.py deleted file mode 100644 index bd44778b..00000000 --- a/src/cnlpt/api/timex_rest.py +++ /dev/null @@ -1,156 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import ( - TIMEX_LABEL_LIST, - SentenceDocument, - TemporalResults, - Timex, - TokenizedSentenceDocument, - create_instance_string, -) -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "tmills/timex-thyme-colon-pubmedbert" -logger = logging.getLogger("Timex_REST_Processor") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - logger.warning(f"Dataset is as follows: {dataset.features!s}") - - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - - timex_predictions = np.argmax(output.predictions[0], axis=2) - - timex_results = [] - event_results = [] - relation_results = [] - - pred_end = time() - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer.batch_encode_plus( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - timex_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - timex_labels.append( - TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]] - ) - previous_word_idx = word_idx - - timex_entities = get_entities(timex_labels) - logging.info( - f"Extracted {len(timex_entities)} timex entities from the sentence" - ) - timex_results.append( - [ - Timex(timeClass=label[0], begin=label[1], end=label[2]) - for label in timex_entities - ] - ) - event_results.append([]) - relation_results.append([]) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=relation_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results diff --git a/src/cnlpt/api/utils.py b/src/cnlpt/api/utils.py deleted file mode 100644 index c756ef70..00000000 --- a/src/cnlpt/api/utils.py +++ /dev/null @@ -1,157 +0,0 @@ -import logging -import os -from typing import Literal, cast - -import torch -from datasets import Dataset -from pydantic import BaseModel -from transformers.hf_argparser import HfArgumentParser - -# Modeling imports -from transformers.models.auto.configuration_auto import AutoConfig -from transformers.models.auto.modeling_auto import AutoModel -from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.trainer import Trainer -from transformers.training_args import TrainingArguments - -from ..data.preprocess import preprocess_raw_data -from ..models import CnlpConfig - - -class UnannotatedDocument(BaseModel): - doc_text: str - - -class EntityDocument(BaseModel): - """doc_text: The raw text of the document - offset: A list of entities, where each is a tuple of character offsets into doc_text for that entity - """ - - doc_text: str - entities: list[list[int]] - - -def create_dataset( - inst_list: list[str], - tokenizer: PreTrainedTokenizer, - max_length: int = 128, - hier: bool = False, - chunk_len: int = 200, - num_chunks: int = 40, - insert_empty_chunk_at_beginning: bool = False, -): - """Use a tokenizer to create a dataset from a list of strings.""" - dataset = Dataset.from_dict({"text": inst_list}) - task_dataset = dataset.map( - preprocess_raw_data, - batched=True, - load_from_cache_file=False, - desc="Running tokenizer on dataset, organizing labels, creating hierarchical segments if necessary", - batch_size=100, - fn_kwargs={ - "tokenizer": tokenizer, - "tasks": None, - "max_length": max_length, - "inference_only": True, - "hierarchical": hier, - # TODO: need to get this from the model if necessary - "chunk_len": chunk_len, - "num_chunks": num_chunks, - "insert_empty_chunk_at_beginning": insert_empty_chunk_at_beginning, - }, - ) - return task_dataset - - -def create_instance_string(doc_text: str, offsets: list[int]): - start = max(0, offsets[0] - 100) - end = min(len(doc_text), offsets[1] + 100) - raw_str = ( - doc_text[start : offsets[0]] - + " " - + doc_text[offsets[0] : offsets[1]] - + " " - + doc_text[offsets[1] : end] - ) - return raw_str.replace("\n", " ") - - -def resolve_device( - device: str, -) -> Literal["cuda", "mps", "cpu"]: - device = device.lower() - if device not in ("cuda", "mps", "cpu", "auto"): - raise ValueError(f"invalid device {device}") - if device == "auto": - if torch.cuda.is_available(): - device = "cuda" - elif torch.mps.is_available(): - device = "mps" - else: - device = "cpu" - elif device == "cuda" and not torch.cuda.is_available(): - logging.warning( - "Device is set to 'cuda' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." - ) - device = "cpu" - elif device == "mps" and not torch.mps.is_available(): - logging.warning( - "Device is set to 'mps' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." - ) - device = "cpu" - return device - - -def initialize_cnlpt_model( - model_name, - device: Literal["cuda", "mps", "cpu", "auto"] = "auto", - batch_size=8, -): - args = [ - "--output_dir", - "save_run/", - "--per_device_eval_batch_size", - str(batch_size), - "--do_predict", - "--report_to", - "none", - ] - parser = HfArgumentParser((TrainingArguments,)) - training_args = cast( - TrainingArguments, parser.parse_args_into_dataclasses(args=args)[0] - ) - - if torch.mps.is_available(): - # pin_memory is unsupported on MPS, but defaults to True, - # so we'll explicitly turn it off to avoid a warning. - training_args.dataloader_pin_memory = False - - config = AutoConfig.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name, config=config) - model = AutoModel.from_pretrained( - model_name, cache_dir=os.getenv("HF_CACHE"), config=config - ) - - model = model.to(resolve_device(device)) - - trainer = Trainer(model=model, args=training_args) - - return tokenizer, trainer - - -def initialize_hier_model( - model_name, - device: Literal["cuda", "mps", "cpu", "auto"] = "auto", -): - config: CnlpConfig = AutoConfig.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name, config=config) - - model = AutoModel.from_pretrained( - model_name, cache_dir=os.getenv("HF_CACHE"), config=config - ) - model.train(False) - - model = model.to(resolve_device(device)) - - return tokenizer, model diff --git a/src/cnlpt/data/analysis.py b/src/cnlpt/data/analysis.py index d2591a33..2d7fd7a1 100644 --- a/src/cnlpt/data/analysis.py +++ b/src/cnlpt/data/analysis.py @@ -175,6 +175,7 @@ def make_preds_df( Returns: The DataFrame for analysis. """ + seq_len = len(predictions.input_data["input_ids"][0]) df_data = { @@ -195,37 +196,53 @@ def make_preds_df( else: tasks = predictions.tasks + unlabeled = predictions.raw.label_ids is None + for task in tasks: task_pred = predictions.task_predictions[task.name] - df = df.with_columns( - pl.struct( - labels=pl.struct( + + fields = [] + if not unlabeled: + fields.append( + pl.struct( ids=task_pred.labels, values=task_pred.target_str_labels, - ), - predictions=pl.struct( + ).alias("labels") + ) + + fields.extend( + [ + pl.struct( ids=task_pred.predicted_int_labels, values=task_pred.predicted_str_labels, - ), - model_output=pl.struct( + ).alias("predictions"), + pl.struct( logits=task_pred.logits, probs=task_pred.probs, - ), - ).alias(task.name) + ).alias("model_output"), + ] ) + df = df.with_columns(pl.struct(fields).alias(task.name)) + if task.type == CLASSIFICATION: # classification output is already pretty human-interpretable pass elif task.type == TAGGING: # for tagging, we'll convert BIO tags to labeled spans - df = df.join( - _bio_tags_to_spans( - df, pl.col(task.name).struct.field("labels").struct.field("values") - ), - on="sample_idx", - how="left", - ).rename({"spans": "target_spans"}) + tagging_fields = [] + if not unlabeled: + df = df.join( + _bio_tags_to_spans( + df, + pl.col(task.name).struct.field("labels").struct.field("values"), + ), + on="sample_idx", + how="left", + ).rename({"spans": "target_spans"}) + tagging_fields.append( + pl.field("labels").struct.with_fields(spans="target_spans") + ) df = df.join( _bio_tags_to_spans( @@ -238,20 +255,27 @@ def make_preds_df( how="left", ).rename({"spans": "predicted_spans"}) + tagging_fields.append( + pl.field("predictions").struct.with_fields(spans="predicted_spans") + ) + df = df.with_columns( - pl.col(task.name).struct.with_fields( - pl.field("labels").struct.with_fields(spans="target_spans"), - pl.field("predictions").struct.with_fields(spans="predicted_spans"), - ) - ).drop("target_spans", "predicted_spans") + pl.col(task.name).struct.with_fields(tagging_fields) + ).drop("target_spans", "predicted_spans", strict=False) elif task.type == RELATIONS: - df = df.join( - _rel_matrix_to_rels( - df, pl.col(task.name).struct.field("labels").struct.field("values") - ), - on="sample_idx", - how="left", - ).rename({"relations": "target_relations"}) + relations_fields = [] + if not unlabeled: + df = df.join( + _rel_matrix_to_rels( + df, + pl.col(task.name).struct.field("labels").struct.field("values"), + ), + on="sample_idx", + how="left", + ).rename({"relations": "target_relations"}) + relations_fields.append( + pl.field("labels").struct.with_fields(relations="target_relations") + ) df = df.join( _rel_matrix_to_rels( @@ -263,15 +287,15 @@ def make_preds_df( on="sample_idx", how="left", ).rename({"relations": "predicted_relations"}) + relations_fields.append( + pl.field("predictions").struct.with_fields( + relations="predicted_relations" + ) + ) df = df.with_columns( - pl.col(task.name).struct.with_fields( - pl.field("labels").struct.with_fields(relations="target_relations"), - pl.field("predictions").struct.with_fields( - relations="predicted_relations" - ), - ) - ).drop("target_relations", "predicted_relations") + pl.col(task.name).struct.with_fields(relations_fields) + ).drop("target_relations", "predicted_relations", strict=False) else: raise ValueError(f"unknown task type {task.type}") diff --git a/src/cnlpt/data/predictions.py b/src/cnlpt/data/predictions.py index c984be1a..fb42b60a 100644 --- a/src/cnlpt/data/predictions.py +++ b/src/cnlpt/data/predictions.py @@ -106,7 +106,9 @@ def __init__( t.name: TaskPredictions( task=t, logits=self.raw.predictions[t.index], - labels=task_labels[t.name].squeeze(), + labels=task_labels[t.name].squeeze() + if task_labels[t.name] is not None + else None, ) for t in tasks } diff --git a/src/cnlpt/rest/__init__.py b/src/cnlpt/rest/__init__.py new file mode 100644 index 00000000..9551ce58 --- /dev/null +++ b/src/cnlpt/rest/__init__.py @@ -0,0 +1,3 @@ +from .cnlp_rest import CnlpRestApp + +__all__ = ["CnlpRestApp"] diff --git a/src/cnlpt/rest/cnlp_rest.py b/src/cnlpt/rest/cnlp_rest.py new file mode 100644 index 00000000..1f2355cd --- /dev/null +++ b/src/cnlpt/rest/cnlp_rest.py @@ -0,0 +1,224 @@ +import logging +import os +from collections.abc import Iterable +from typing import Union + +import polars as pl +import torch +from datasets import Dataset +from fastapi import APIRouter, FastAPI +from pydantic import BaseModel +from transformers.models.auto.configuration_auto import AutoConfig +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.trainer import Trainer +from transformers.training_args import TrainingArguments +from typing_extensions import Self + +from ..args.data_args import CnlpDataArguments +from ..data.analysis import make_preds_df +from ..data.predictions import CnlpPredictions +from ..data.preprocess import preprocess_raw_data +from ..data.task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskInfo + + +class InputDocument(BaseModel): + text: str + entity_spans: Union[list[tuple[int, int]], None] = None + + def to_text_list(self): + if self.entity_spans is None: + return [self.text] + + text_list: list[str] = [] + for entity_start, entity_end in self.entity_spans: + start = max(0, entity_start - 100) + end = min(len(self.text), entity_end + 100) + text_list.append( + "".join( + [ + self.text[start:entity_start], + "", + self.text[entity_start:entity_end], + "", + self.text[entity_end:end], + ] + ) + ) + return text_list + + +class CnlpRestApp: + def __init__(self, model_path: str, device: str = "auto"): + self.model_path = model_path + self.resolve_device(device) + self.setup_logger(logging.DEBUG) + self.load_model() + + def resolve_device(self, device: str): + self.device = device.lower() + if self.device == "auto": + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" + else: + try: + torch.tensor([1.0], device=self.device) + except: # noqa: E722 + self.logger.warning( + f"Device is set to '{self.device}' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." + ) + self.device = "cpu" + + def setup_logger(self, log_level): + self.logger = logging.getLogger(self.__module__) + self.logger.setLevel(log_level) + + def load_model(self): + training_args = TrainingArguments( + output_dir="save_run/", + save_strategy="no", + per_device_eval_batch_size=8, + do_predict=True, + ) + + if self.device == "mps": + # pin_memory is unsupported on MPS, but defaults to True, + # so we'll explicitly turn it off to avoid a warning. + training_args.dataloader_pin_memory = False + + self.config = AutoConfig.from_pretrained(self.model_path) + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_path, config=self.config + ) + self.model = AutoModel.from_pretrained( + self.model_path, cache_dir=os.getenv("HF_CACHE"), config=self.config + ).to(self.device) + self.trainer = Trainer(model=self.model, args=training_args) + + self.tasks: list[TaskInfo] = [] + for task_idx, task_name in enumerate(self.config.finetuning_task): + if self.config.tagger[task_name]: + task_type = TAGGING + elif self.config.relations[task_name]: + task_type = RELATIONS + else: + task_type = CLASSIFICATION + + self.tasks.append( + TaskInfo( + name=task_name, + type=task_type, + index=task_idx, + labels=tuple(self.config.label_dictionary[task_name]), + ) + ) + + def create_prediction_dataset( + self, + text: list[str], + data_args: CnlpDataArguments, + ): + dataset = Dataset.from_dict({"text": text}) + + return dataset.map( + preprocess_raw_data, + batched=True, + load_from_cache_file=False, + desc="Preprocessing raw input", + batch_size=100, + fn_kwargs={ + "inference_only": True, + "tokenizer": self.tokenizer, + "tasks": None, + "max_length": data_args.max_seq_length, + "hierarchical": self.config.model_type == "hier", + "chunk_len": data_args.chunk_len or -1, + "num_chunks": data_args.num_chunks or -1, + "insert_empty_chunk_at_beginning": data_args.insert_empty_chunk_at_beginning, + }, + ) + + def predict(self, dataset: Dataset, data_args: CnlpDataArguments): + raw_predictions = self.trainer.predict(dataset) + return CnlpPredictions( + dataset, + raw_prediction=raw_predictions, + tasks=self.tasks, + data_args=data_args, + ) + + def format_predictions(self, predictions: CnlpPredictions): + df = make_preds_df(predictions).select(["text", *[t.name for t in self.tasks]]) + + for task in self.tasks: + if task.type == CLASSIFICATION: + df = df.with_columns( + pl.struct( + prediction=pl.col(task.name) + .struct.field("predictions") + .struct.field("values"), + probs=pl.col(task.name) + .struct.field("model_output") + .struct.field("probs") + .arr.to_struct(fields=task.labels), + ).alias(task.name) + ) + elif task.type == TAGGING: + df = df.with_columns( + pl.struct( + pl.col(task.name) + .struct.field("predictions") + .struct.field("spans") + ).alias(task.name) + ) + elif task.type == RELATIONS: + df = df.with_columns( + pl.struct( + pl.col(task.name) + .struct.field("predictions") + .struct.field("relations") + ).alias(task.name) + ) + + return df.to_dicts() + + def process( + self, + input_doc: InputDocument, + max_seq_length: int = 128, + chunk_len: Union[int, None] = None, + num_chunks: Union[int, None] = None, + insert_empty_chunk_at_beginning: bool = False, + ): + data_args = CnlpDataArguments( + data_dir=[], + max_seq_length=max_seq_length, + chunk_len=chunk_len, + num_chunks=num_chunks, + insert_empty_chunk_at_beginning=insert_empty_chunk_at_beginning, + ) + + dataset = self.create_prediction_dataset(input_doc.to_text_list(), data_args) + predictions = self.predict(dataset, data_args) + return self.format_predictions(predictions) + + def router(self, prefix: str = ""): + router = APIRouter(prefix=prefix) + router.add_api_route("/process", self.process, methods=["POST"]) + return router + + def fastapi(self, router_prefix: str = ""): + app = FastAPI() + app.include_router(self.router(prefix=router_prefix)) + return app + + @classmethod + def multi_app(cls, apps: Iterable[tuple[Self, str]]): + multi_app = FastAPI() + for app, router_prefix in apps: + multi_app.include_router(app.router(router_prefix)) + return multi_app diff --git a/test/api/test_api.py b/test/api/test_api.py index 6b9a3a15..c3f152a1 100644 --- a/test/api/test_api.py +++ b/test/api/test_api.py @@ -5,94 +5,159 @@ import pytest from fastapi.testclient import TestClient -from cnlpt.api.utils import EntityDocument +from cnlpt.rest.cnlp_rest import CnlpRestApp, InputDocument class TestNegation: @pytest.fixture def test_client(self): - from cnlpt.api.negation_rest import app - - with TestClient(app) as client: + with TestClient( + CnlpRestApp("mlml-chip/negation_pubmedbert_sharpseed").fastapi() + ) as client: yield client def test_negation_startup(self, test_client): pass def test_negation_process(self, test_client: TestClient): - from cnlpt.api.negation_rest import NegationResults - - doc = EntityDocument( - doc_text="The patient has a sore knee and headache " + doc = InputDocument( + text="The patient has a sore knee and headache " "but denies nausea and has no anosmia.", - entities=[[18, 27], [32, 40], [52, 58], [70, 77]], + entity_spans=[(18, 27), (32, 40), (52, 58), (70, 77)], ) - response = test_client.post("/negation/process", content=doc.json()) + response = test_client.post("/process", content=doc.json()) response.raise_for_status() - assert response.json() == NegationResults.parse_obj( - {"statuses": [-1, -1, 1, 1]} - ) + assert response.json() == [ + { + "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.", + "Negation": { + "prediction": "-1", + "probs": { + "1": pytest.approx(0.0002379878715146333, rel=1e-04), + "-1": pytest.approx(0.9997619986534119, rel=1e-04), + }, + }, + }, + { + "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.", + "Negation": { + "prediction": "-1", + "probs": { + "1": pytest.approx(0.0004393413255456835, rel=1e-04), + "-1": pytest.approx(0.9995606541633606, rel=1e-04), + }, + }, + }, + { + "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.", + "Negation": { + "prediction": "1", + "probs": { + "1": pytest.approx(0.9921413660049438, rel=1e-04), + "-1": pytest.approx(0.007858583703637123, rel=1e-04), + }, + }, + }, + { + "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.", + "Negation": { + "prediction": "1", + "probs": { + "1": pytest.approx(0.9928833246231079, rel=1e-04), + "-1": pytest.approx(0.0071166763082146645, rel=1e-04), + }, + }, + }, + ] class TestTemporal: @pytest.fixture def test_client(self): - from cnlpt.api.temporal_rest import app - - with TestClient(app) as client: + with TestClient(CnlpRestApp("mlml-chip/thyme2_colon_e2e").fastapi()) as client: yield client def test_temporal_startup(self, test_client: TestClient): pass def test_temporal_process_sentence(self, test_client: TestClient): - from cnlpt.api.temporal_rest import ( - SentenceDocument, - TemporalResults, - ) - - doc = SentenceDocument( - sentence="The patient was diagnosed with adenocarcinoma " + doc = InputDocument( + text="The patient was diagnosed with adenocarcinoma " "March 3, 2010 and will be returning for " "chemotherapy next week." ) - response = test_client.post("/temporal/process_sentence", content=doc.json()) + response = test_client.post("/process", content=doc.json()) response.raise_for_status() - out = response.json() - expected_out = TemporalResults.parse_obj( + assert response.json() == [ { - "events": [ - [ - {"begin": 3, "dtr": "BEFORE", "end": 3}, - {"begin": 5, "dtr": "BEFORE", "end": 5}, - {"begin": 13, "dtr": "AFTER", "end": 13}, - {"begin": 15, "dtr": "AFTER", "end": 15}, + "text": "The patient was diagnosed with adenocarcinoma March 3, 2010 and will be returning for chemotherapy next week.", + "timex": { + "spans": [ + { + "text": "March 3, 2010 ", + "tag": "DATE", + "start": 6, + "end": 8, + "valid": True, + }, + { + "text": "next week.", + "tag": "DATE", + "start": 15, + "end": 16, + "valid": True, + }, ] - ], - "relations": [ - [ + }, + "event": { + "spans": [ + { + "text": "diagnosed ", + "tag": "BEFORE", + "start": 3, + "end": 3, + "valid": True, + }, + { + "text": "adenocarcinoma ", + "tag": "BEFORE", + "start": 5, + "end": 5, + "valid": True, + }, { - "arg1": "TIMEX-0", - "arg1_start": 6, - "arg2": "EVENT-0", - "arg2_start": 3, - "category": "CONTAINS", + "text": "returning ", + "tag": "AFTER", + "start": 12, + "end": 12, + "valid": True, }, { - "arg1": "TIMEX-1", - "arg1_start": 16, - "arg2": "EVENT-2", - "arg2_start": 13, - "category": "CONTAINS", + "text": "chemotherapy ", + "tag": "AFTER", + "start": 14, + "end": 14, + "valid": True, }, ] - ], - "timexes": [ - [ - {"begin": 6, "end": 9, "timeClass": "DATE"}, - {"begin": 16, "end": 17, "timeClass": "DATE"}, + }, + "tlinkx": { + "relations": [ + { + "arg1_wid": 6, + "arg1_text": "March", + "arg2_wid": 3, + "arg2_text": "diagnosed", + "label": "CONTAINS", + }, + { + "arg1_wid": 15, + "arg1_text": "next", + "arg2_wid": 12, + "arg2_text": "returning", + "label": "CONTAINS", + }, ] - ], + }, } - ) - assert out == expected_out + ] diff --git a/test/test_init.py b/test/test_init.py index 12b83959..d2fab80e 100644 --- a/test/test_init.py +++ b/test/test_init.py @@ -71,7 +71,7 @@ def test_init_args(): def test_init_api(): - import cnlpt.api + import cnlpt.rest - assert cnlpt.api.__package__ == "cnlpt.api" - assert cnlpt.api.__all__ == ["MODEL_TYPES", "get_rest_app"] + assert cnlpt.rest.__package__ == "cnlpt.rest" + assert cnlpt.rest.__all__ == ["CnlpRestApp"]