diff --git a/src/cnlpt/_cli/rest.py b/src/cnlpt/_cli/rest.py
index dd857277..60a81da6 100644
--- a/src/cnlpt/_cli/rest.py
+++ b/src/cnlpt/_cli/rest.py
@@ -1,13 +1,43 @@
+from typing import Union
+
 import click
-from ..api import MODEL_TYPES, get_rest_app
+
+def parse_models(
+    ctx: click.Context,
+    param: click.Parameter,
+    value: Union[tuple[str, ...], None],
+):
+    if not value:
+        raise click.UsageError("at least one --model is required", ctx=ctx)
+
+    models: list[tuple[str, str]] = []
+    for item in value:
+        if "=" in item:
+            prefix, path = item.split("=", 1)
+            if not prefix.startswith("/"):
+                raise click.BadParameter(
+                    f"route prefix must start with '/': {prefix}", param=param
+                )
+        elif len(value) > 1:
+            raise click.BadParameter(
+                "route prefixes are required when serving more than one model",
+                param=param,
+            )
+        else:
+            path = item
+            prefix = ""
+        models.append((prefix, path))
+    return models
 @click.command("rest", context_settings={"show_default": True})
 @click.option(
-    "--model-type",
-    type=click.Choice(MODEL_TYPES),
-    required=True,
+    "--model",
+    "models",
+    multiple=True,
+    callback=parse_models,
+    help="Model definition as [ROUTER_PREFIX=]PATH_TO_MODEL. Prefix must start with '/'.",
 )
 @click.option(
     "-h",
@@ -19,15 +49,13 @@
@click.option(
"-p", "--port", type=int, default=8000, help="Port to serve the REST app."
)
-@click.option(
- "--reload",
- type=bool,
- is_flag=True,
- default=False,
- help="Auto-reload the REST app.",
-)
-def rest_command(model_type: str, host: str, port: int, reload: bool):
+def rest_command(models: list[tuple[str, str]], host: str, port: int):
"""Start a REST application from a model."""
import uvicorn
- uvicorn.run(get_rest_app(model_type), host=host, port=port, reload=reload)
+ from ..rest import CnlpRestApp
+
+ app = CnlpRestApp.multi_app(
+ [(CnlpRestApp(model_path=path), prefix) for prefix, path in models]
+ )
+ uvicorn.run(app, host=host, port=port)
diff --git a/src/cnlpt/api/__init__.py b/src/cnlpt/api/__init__.py
deleted file mode 100644
index a3e8592f..00000000
--- a/src/cnlpt/api/__init__.py
+++ /dev/null
@@ -1,71 +0,0 @@
-"""Serve REST APIs for CNLPT models over your network."""
-
-from typing import Final
-
-MODEL_TYPES: Final = (
- "cnn",
- "current",
- "dtr",
- # "event",
- "hier",
- "negation",
- "temporal",
- # "termexists",
- # "timex",
-)
-"""The available model types for :func:`get_rest_app`."""
-
-
-def get_rest_app(model_type: str):
- """Get a FastAPI app for a certain model type.
-
- Args:
- model_type: The type of model to serve.
-
- Returns:
- The FastAPI app.
- """
- if model_type == "cnn":
- from .cnn_rest import app
-
- return app
- elif model_type == "current":
- from .current_rest import app
-
- return app
- elif model_type == "dtr":
- from .dtr_rest import app
-
- return app
- # elif model_type == "event":
- # from .event_rest import app
-
- # return app
- elif model_type == "hier":
- from .hier_rest import app
-
- return app
- elif model_type == "negation":
- from .negation_rest import app
-
- return app
- elif model_type == "temporal":
- from .temporal_rest import app
-
- return app
- # elif model_type == "termexists":
- # from .termexists_rest import app
-
- # return app
- # elif model_type == "timex":
- # from .timex_rest import app
-
- # return app
- else:
- raise ValueError(f"unknown model type: {model_type}")
-
-
-__all__ = [
- "MODEL_TYPES",
- "get_rest_app",
-]
diff --git a/src/cnlpt/api/cnn_rest.py b/src/cnlpt/api/cnn_rest.py
deleted file mode 100644
index af556e95..00000000
--- a/src/cnlpt/api/cnn_rest.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import json
-import logging
-import os
-import sys
-from contextlib import asynccontextmanager
-from os.path import join
-from typing import Any
-
-import numpy as np
-import torch
-import torch.backends.mps
-from fastapi import FastAPI
-from scipy.special import softmax
-from transformers import AutoTokenizer, PreTrainedTokenizer
-
-from ..models.baseline import CnnSentenceClassifier
-from .utils import UnannotatedDocument, create_dataset, resolve_device
-
-MODEL_NAME = os.getenv("MODEL_PATH")
-device = os.getenv("MODEL_DEVICE", "auto")
-device = resolve_device(device)
-
-logger = logging.getLogger("CNN_REST_Processor")
-logger.setLevel(logging.DEBUG)
-
-MAX_SEQ_LENGTH = 128
-
-model: CnnSentenceClassifier
-tokenizer: PreTrainedTokenizer
-conf_dict: dict[str, Any]
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global model, tokenizer, conf_dict
- if MODEL_NAME is None:
- sys.stderr.write(
- "This REST container requires a MODEL_PATH environment variable\n"
- )
- sys.exit(-1)
- conf_file = join(MODEL_NAME, "config.json")
- with open(conf_file) as fp:
- conf_dict = json.load(fp)
-
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
- num_labels_dict = {
- task: len(values) for task, values in conf_dict["label_dictionary"].items()
- }
- model = CnnSentenceClassifier.from_pretrained(
- MODEL_NAME,
- vocab_size=len(tokenizer),
- task_names=conf_dict["task_names"],
- num_labels_dict=num_labels_dict,
- embed_dims=conf_dict["cnn_embed_dim"],
- num_filters=conf_dict["cnn_num_filters"],
- filters=conf_dict["cnn_filter_sizes"],
- )
-
- model = model.to(device)
- tokenizer = tokenizer
- conf_dict = conf_dict
-
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/cnn/classify")
-async def process(doc: UnannotatedDocument):
- instances = [doc.doc_text]
- dataset = create_dataset(
- instances, tokenizer, max_length=conf_dict["max_seq_length"]
- )
- _, logits = model.forward(
- input_ids=torch.LongTensor(dataset["input_ids"]).to(device),
- attention_mask=torch.LongTensor(dataset["attention_mask"]).to(device),
- )
-
- prediction = int(np.argmax(logits[0].cpu().detach().numpy(), axis=1))
- result = conf_dict["label_dictionary"][conf_dict["task_names"][0]][prediction]
- probabilities = softmax(logits[0][0].cpu().detach().numpy())
- # for redcap purposes, it might make more sense to only output the probability for the predicted class,
- # but i'm outputting them all, for transparency
- out_probabilities = [str(prob) for prob in probabilities]
- return {"result": result, "probabilities": out_probabilities}
diff --git a/src/cnlpt/api/current_rest.py b/src/cnlpt/api/current_rest.py
deleted file mode 100644
index 682c5228..00000000
--- a/src/cnlpt/api/current_rest.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .utils import (
- EntityDocument,
- create_dataset,
- create_instance_string,
- initialize_cnlpt_model,
-)
-
-logger = logging.getLogger("Current_REST_Processor")
-logger.setLevel(logging.DEBUG)
-
-MODEL_NAME = "mlml-chip/current-thyme"
-TASK = "Current"
-LABELS = [False, True]
-
-MAX_LENGTH = 128
-
-
-class CurrentResults(BaseModel):
- """statuses: list of classifier outputs for every input"""
-
- statuses: list[bool]
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/current/process")
-async def process(doc: EntityDocument):
- doc_text = doc.doc_text
- logger.warning(
- f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities"
- )
- instances = []
- start_time = time()
-
- if len(doc.entities) == 0:
- return CurrentResults(statuses=[])
-
- for ent_ind, offsets in enumerate(doc.entities):
- inst_str = create_instance_string(doc_text, offsets)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, MAX_LENGTH)
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
- predictions = output.predictions[0]
- predictions = np.argmax(predictions, axis=1)
-
- pred_end = time()
-
- results = []
- for ent_ind in range(len(dataset)):
- results.append(LABELS[predictions[ent_ind]])
-
- output = CurrentResults(statuses=results)
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.info(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return output
diff --git a/src/cnlpt/api/dtr_rest.py b/src/cnlpt/api/dtr_rest.py
deleted file mode 100644
index ee85e6c9..00000000
--- a/src/cnlpt/api/dtr_rest.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .temporal_rest import OLD_DTR_LABEL_LIST
-from .utils import (
- EntityDocument,
- create_dataset,
- create_instance_string,
- initialize_cnlpt_model,
-)
-
-MODEL_NAME = "tmills/tiny-dtr"
-logger = logging.getLogger("DocTimeRel Processor with xtremedistil encoder")
-logger.setLevel(logging.INFO)
-
-MAX_LENGTH = 128
-
-
-class DocTimeRelResults(BaseModel):
- """statuses: dictionary from entity id to classification decision about DocTimeRel"""
-
- statuses: list[str]
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/dtr/process")
-async def process(doc: EntityDocument):
- doc_text = doc.doc_text
- logger.warning(
- f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities"
- )
- instances = []
- start_time = time()
-
- if len(doc.entities) == 0:
- return DocTimeRelResults(statuses=[])
-
- for ent_ind, offsets in enumerate(doc.entities):
- # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1]))
- inst_str = create_instance_string(doc_text, offsets)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH)
-
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
- predictions = output.predictions[0]
- predictions = np.argmax(predictions, axis=1)
-
- pred_end = time()
-
- results = []
- for ent_ind in range(len(dataset)):
- results.append(OLD_DTR_LABEL_LIST[predictions[ent_ind]])
-
- output = DocTimeRelResults(statuses=results)
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.warning(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return output
diff --git a/src/cnlpt/api/event_rest.py b/src/cnlpt/api/event_rest.py
deleted file mode 100644
index e2d34c75..00000000
--- a/src/cnlpt/api/event_rest.py
+++ /dev/null
@@ -1,158 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from nltk.tokenize import wordpunct_tokenize as tokenize
-from seqeval.metrics.sequence_labeling import get_entities
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .temporal_rest import (
- EVENT_LABEL_LIST,
- Event,
- SentenceDocument,
- TemporalResults,
- TokenizedSentenceDocument,
- create_instance_string,
-)
-from .utils import create_dataset, initialize_cnlpt_model
-
-MODEL_NAME = "tmills/event-thyme-colon-pubmedbert"
-logger = logging.getLogger("Event_REST_Processor")
-logger.setLevel(logging.INFO)
-
-MAX_LENGTH = 128
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/temporal/process")
-async def process(doc: TokenizedSentenceDocument):
- return process_tokenized_sentence_document(doc)
-
-
-@app.post("/temporal/process_sentence")
-async def process_sentence(doc: SentenceDocument):
- tokenized_sent = tokenize(doc.sentence)
- doc = TokenizedSentenceDocument(
- sent_tokens=[
- tokenized_sent,
- ],
- metadata="Single sentence",
- )
- return process_tokenized_sentence_document(doc)
-
-
-def process_tokenized_sentence_document(doc: TokenizedSentenceDocument):
- sents = doc.sent_tokens
- metadata = doc.metadata
-
- logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences")
- instances = []
- start_time = time()
-
- for sent_ind, token_list in enumerate(sents):
- inst_str = create_instance_string(token_list)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH)
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
-
- event_predictions = np.argmax(output.predictions[0], axis=2)
-
- pred_end = time()
-
- timex_results = []
- event_results = []
- rel_results = []
-
- for sent_ind in range(len(dataset)):
- batch_encoding = tokenizer.batch_encode_plus(
- [
- sents[sent_ind],
- ],
- is_split_into_words=True,
- max_length=MAX_LENGTH,
- )
- word_ids = batch_encoding.word_ids(0)
- wpind_to_ind = {}
- event_labels = []
- previous_word_idx = None
-
- for word_pos_idx, word_idx in enumerate(word_ids):
- if word_idx != previous_word_idx and word_idx is not None:
- key = word_pos_idx
- val = len(wpind_to_ind)
-
- wpind_to_ind[key] = val
- event_labels.append(
- EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]]
- )
- previous_word_idx = word_idx
-
- event_entities = get_entities(event_labels)
- logging.info(f"Extracted {len(event_entities)} events from the sentence")
- event_results.append(
- [
- Event(dtr=label[0], begin=label[1], end=label[2])
- for label in event_entities
- ]
- )
- timex_results.append([])
- rel_results.append([])
-
- results = TemporalResults(
- timexes=timex_results, events=event_results, relations=rel_results
- )
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.info(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return results
-
-
-@app.post("/temporal/collection_process_complete")
-async def collection_process_complete():
- global trainer
- trainer = None
diff --git a/src/cnlpt/api/hier_rest.py b/src/cnlpt/api/hier_rest.py
deleted file mode 100644
index 66711866..00000000
--- a/src/cnlpt/api/hier_rest.py
+++ /dev/null
@@ -1,104 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-import os
-from contextlib import asynccontextmanager
-
-import torch
-from fastapi import FastAPI
-from transformers import PreTrainedModel
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .utils import (
- UnannotatedDocument,
- create_dataset,
- initialize_hier_model,
- resolve_device,
-)
-
-MODEL_NAME = os.getenv("MODEL_PATH")
-
-device = os.getenv("MODEL_DEVICE", "auto")
-device = resolve_device(device)
-
-logger = logging.getLogger("HierRep_REST_Processor")
-logger.setLevel(logging.DEBUG)
-
-tokenizer: PreTrainedTokenizer
-model: PreTrainedModel
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, model
- tokenizer, model = initialize_hier_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/hier/get_rep")
-async def get_representation(doc: UnannotatedDocument):
- instances = [doc.doc_text]
- dataset = create_dataset(
- instances,
- tokenizer,
- max_length=16000,
- hier=True,
- chunk_len=200,
- num_chunks=80,
- insert_empty_chunk_at_beginning=False,
- )
-
- result = model.forward(
- input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device),
- token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device),
- attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device),
- output_hidden_states=True,
- )
-
- # Convert to a list so python can send it out
- hidden_states = result["hidden_states"].to("cpu").detach().numpy()[:, 0, :].tolist()
- return {"reps": hidden_states[0]}
-
-
-@app.post("/hier/classify")
-async def classify(doc: UnannotatedDocument):
- instances = [doc.doc_text]
- dataset = create_dataset(
- instances,
- tokenizer,
- max_length=16000,
- hier=True,
- chunk_len=200,
- num_chunks=80,
- insert_empty_chunk_at_beginning=False,
- )
- result = model.forward(
- input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device),
- token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device),
- attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device),
- output_hidden_states=False,
- )
-
- predictions = [
- int(torch.argmax(logits.to("cpu").detach()).numpy())
- for logits in result["logits"]
- ]
- labels = [next(iter(model.label_dictionary.values()))[x] for x in predictions]
- return {"result": labels}
diff --git a/src/cnlpt/api/negation_rest.py b/src/cnlpt/api/negation_rest.py
deleted file mode 100644
index e4ba78b2..00000000
--- a/src/cnlpt/api/negation_rest.py
+++ /dev/null
@@ -1,112 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .utils import (
- EntityDocument,
- create_dataset,
- create_instance_string,
- initialize_cnlpt_model,
-)
-
-MODEL_NAME = "mlml-chip/negation_pubmedbert_sharpseed"
-logger = logging.getLogger("Negation_REST_Processor")
-logger.setLevel(logging.DEBUG)
-
-TASK = "Negation"
-LABELS = [-1, 1]
-
-MAX_LENGTH = 128
-
-
-class NegationResults(BaseModel):
- """statuses: dictionary from entity id to classification decision about negation; true -> negated, false -> not negated"""
-
- statuses: list[int]
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/negation/process")
-async def process(doc: EntityDocument):
- doc_text = doc.doc_text
- logger.warning(
- f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities"
- )
- instances = []
- start_time = time()
-
- if len(doc.entities) == 0:
- return NegationResults(statuses=[])
-
- for ent_ind, offsets in enumerate(doc.entities):
- # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1]))
- inst_str = create_instance_string(doc_text, offsets)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, MAX_LENGTH)
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
- predictions = output.predictions[0]
- predictions = np.argmax(predictions, axis=1)
-
- pred_end = time()
-
- results = []
- for ent_ind in range(len(dataset)):
- results.append(LABELS[predictions[ent_ind]])
-
- output = NegationResults(statuses=results)
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.warning(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return output
-
-
-@app.get("/negation/{test_str}")
-async def test(test_str: str):
- return {"argument": test_str}
diff --git a/src/cnlpt/api/temporal_rest.py b/src/cnlpt/api/temporal_rest.py
deleted file mode 100644
index 7b23e495..00000000
--- a/src/cnlpt/api/temporal_rest.py
+++ /dev/null
@@ -1,373 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-import os
-from contextlib import asynccontextmanager
-from time import time
-from typing import Union
-
-import numpy as np
-from fastapi import FastAPI
-from nltk.tokenize import wordpunct_tokenize as tokenize
-from pydantic import BaseModel
-from seqeval.metrics.sequence_labeling import get_entities
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .utils import create_dataset, initialize_cnlpt_model
-
-MODEL_NAME = "mlml-chip/thyme2_colon_e2e"
-logger = logging.getLogger("Temporal_REST_Processor")
-logger.setLevel(logging.INFO)
-
-LABELS = ["-1", "1"]
-TIMEX_LABEL_LIST = [
- "O",
- "B-DATE",
- "B-DURATION",
- "B-PREPOSTEXP",
- "B-QUANTIFIER",
- "B-SET",
- "B-TIME",
- "B-SECTIONTIME",
- "B-DOCTIME",
- "I-DATE",
- "I-DURATION",
- "I-PREPOSTEXP",
- "I-QUANTIFIER",
- "I-SET",
- "I-TIME",
- "I-SECTIONTIME",
- "I-DOCTIME",
-]
-TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)}
-EVENT_LABEL_LIST = [
- "O",
- "B-AFTER",
- "B-BEFORE",
- "B-BEFORE/OVERLAP",
- "B-OVERLAP",
- "I-AFTER",
- "I-BEFORE",
- "I-BEFORE/OVERLAP",
- "I-OVERLAP",
-]
-EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)}
-
-RELATION_LABEL_LIST = ["None", "CONTAINS", "OVERLAP", "BEFORE", "BEGINS-ON", "ENDS-ON"]
-RELATION_LABEL_DICT = {val: ind for ind, val in enumerate(RELATION_LABEL_LIST)}
-
-DTR_LABEL_LIST = ["AFTER", "BEFORE", "BEFORE/OVERLAP", "OVERLAP"]
-OLD_DTR_LABEL_LIST = ["BEFORE", "OVERLAP", "BEFORE/OVERLAP", "AFTER"]
-
-LABELS = [TIMEX_LABEL_LIST, EVENT_LABEL_LIST, RELATION_LABEL_LIST]
-MAX_LENGTH = 128
-
-
-class SentenceDocument(BaseModel):
- sentence: str
-
-
-class TokenizedSentenceDocument(BaseModel):
- """sent_tokens: a list of sentences, where each sentence is a list of tokens"""
-
- sent_tokens: list[list[str]]
- metadata: str
-
-
-class Timex(BaseModel):
- begin: int
- end: int
- timeClass: str
-
-
-class Event(BaseModel):
- begin: int
- end: int
- dtr: str
-
-
-class Relation(BaseModel):
- # Allow args to be none, so that we can potentially link them to times or events in the client, or if they don't
- # care about that. pass back the token indices of the args in addition.
- arg1: Union[str, None]
- arg2: Union[str, None]
- category: str
- arg1_start: int
- arg2_start: int
-
-
-class TemporalResults(BaseModel):
- """lists of timexes, events and relations for list of sentences"""
-
- timexes: list[list[Timex]]
- events: list[list[Event]]
- relations: list[list[Relation]]
-
-
-def create_instance_string(tokens: list[str]):
- return " ".join(tokens)
-
-
-task_order: dict[str, int]
-tasks: list[str]
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global \
- TIMEX_LABEL_LIST, \
- TIMEX_LABEL_DICT, \
- EVENT_LABEL_LIST, \
- EVENT_LABEL_DICT, \
- RELATION_LABEL_LIST, \
- RELATION_LABEL_DICT, \
- task_order, \
- tasks, \
- tokenizer, \
- trainer
-
- local_model_name = os.getenv("MODEL_NAME", MODEL_NAME)
- tokenizer, trainer = initialize_cnlpt_model(local_model_name)
-
- config_dict = trainer.model.config.to_dict()
- # For newer models (version >= 0.6.0), the label dictionary is saved with the model
- # config. we can look for it to preserve backwards compatibility for now but
- # should eventually remove the hardcoded label lists from our inference tools.
- label_dict = config_dict.get("label_dictionary", None)
- if label_dict is not None:
- # some older versions have one label dictionary per dataset, future versions should just
- # have a task-keyed dictionary
- if type(label_dict) is list:
- label_dict = label_dict[0]
-
- if "event" in label_dict:
- EVENT_LABEL_LIST = label_dict["event"]
- EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)}
- print(EVENT_LABEL_LIST)
-
- if "timex" in label_dict:
- TIMEX_LABEL_LIST = label_dict["timex"]
- TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)}
- print(TIMEX_LABEL_LIST)
-
- if "tlinkx" in label_dict:
- RELATION_LABEL_LIST = label_dict["tlinkx"]
- RELATION_LABEL_DICT = {
- val: ind for ind, val in enumerate(RELATION_LABEL_LIST)
- }
- print(RELATION_LABEL_LIST)
-
- tasks = config_dict.get("finetuning_task", None)
- task_order = {}
- if tasks is not None:
- print("Overwriting finetuning task order")
- for task_ind, task_name in enumerate(tasks):
- task_order[task_name] = task_ind
- print(task_order)
- else:
- print("Didn't find a new task ordering in the model config")
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/temporal/process")
-async def process(doc: TokenizedSentenceDocument):
- return process_tokenized_sentence_document(doc)
-
-
-@app.post("/temporal/process_sentence")
-async def process_sentence(doc: SentenceDocument):
- tokenized_sent = tokenize(doc.sentence)
- doc = TokenizedSentenceDocument(
- sent_tokens=[
- tokenized_sent,
- ],
- metadata="Single sentence",
- )
- return process_tokenized_sentence_document(doc)
-
-
-def process_tokenized_sentence_document(doc: TokenizedSentenceDocument):
- sents = doc.sent_tokens
- metadata = doc.metadata
-
- print(EVENT_LABEL_LIST)
- print(TIMEX_LABEL_LIST)
- print(RELATION_LABEL_LIST)
-
- logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences")
- instances = []
- start_time = time()
-
- for sent_ind, token_list in enumerate(sents):
- inst_str = create_instance_string(token_list)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, MAX_LENGTH)
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
-
- timex_predictions = np.argmax(output.predictions[task_order["timex"]], axis=2)
- event_predictions = np.argmax(output.predictions[task_order["event"]], axis=2)
- rel_predictions = np.argmax(output.predictions[task_order["tlinkx"]], axis=3)
- rel_inds = np.where(rel_predictions != RELATION_LABEL_DICT["None"])
-
- logging.debug(f"Found relation indices: {rel_inds!s}")
-
- rels_by_sent = {}
- for rel_num in range(len(rel_inds[0])):
- sent_ind = rel_inds[0][rel_num]
- if sent_ind not in rels_by_sent:
- rels_by_sent[sent_ind] = []
-
- arg1_ind = rel_inds[1][rel_num]
- arg2_ind = rel_inds[2][rel_num]
- if arg1_ind == arg2_ind:
- # no relations between an entity and itself
- logger.warning("Found relation between an entity and itself... skipping")
- continue
-
- rel_cat = rel_predictions[sent_ind, arg1_ind, arg2_ind]
-
- rels_by_sent[sent_ind].append((arg1_ind, arg2_ind, rel_cat))
-
- pred_end = time()
-
- timex_results = []
- event_results = []
- rel_results = []
-
- for sent_ind in range(len(dataset)):
- batch_encoding = tokenizer(
- [
- sents[sent_ind],
- ],
- is_split_into_words=True,
- max_length=MAX_LENGTH,
- )
- word_ids = batch_encoding.word_ids(0)
- wpind_to_ind = {}
- timex_labels = []
- event_labels = []
- previous_word_idx = None
-
- for word_pos_idx, word_idx in enumerate(word_ids):
- if word_idx != previous_word_idx and word_idx is not None:
- key = word_pos_idx
- val = len(wpind_to_ind)
-
- wpind_to_ind[key] = val
- # tokeni_to_wpi[val] = key
- timex_labels.append(
- TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]]
- )
- try:
- event_labels.append(
- EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]]
- )
- except Exception as e:
- print(
- f"exception thrown when sent_ind={sent_ind} and word_pos_idx={word_pos_idx}"
- )
- print(
- f"prediction is {event_predictions[sent_ind][word_pos_idx]!s}"
- )
- raise e
-
- previous_word_idx = word_idx
-
- timex_entities = get_entities(timex_labels)
- logging.info(
- f"Extracted {len(timex_entities)} timex entities from the sentence"
- )
- timex_results.append(
- [
- Timex(timeClass=label[0], begin=label[1], end=label[2])
- for label in timex_entities
- ]
- )
-
- event_entities = get_entities(event_labels)
- logging.info(f"Extracted {len(event_entities)} events from the sentence")
- event_results.append(
- [
- Event(dtr=label[0], begin=label[1], end=label[2])
- for label in event_entities
- ]
- )
-
- rel_sent_results = []
- for rel in rels_by_sent.get(sent_ind, []):
- arg1 = None
- arg2 = None
- if rel[0] not in wpind_to_ind or rel[1] not in wpind_to_ind:
- logging.warning(
- "Found a relation to a non-leading wordpiece token... ignoring"
- )
- continue
-
- arg1_ind = wpind_to_ind[rel[0]]
- arg2_ind = wpind_to_ind[rel[1]]
-
- sent_timexes = timex_results[-1]
- for timex_ind, timex in enumerate(sent_timexes):
- if timex.begin == arg1_ind:
- arg1 = f"TIMEX-{timex_ind}"
- if timex.begin == arg2_ind:
- arg2 = f"TIMEX-{timex_ind}"
-
- sent_events = event_results[-1]
- for event_ind, event in enumerate(sent_events):
- if event.begin == arg1_ind:
- arg1 = f"EVENT-{event_ind}"
- if event.begin == arg2_ind:
- arg2 = f"EVENT-{event_ind}"
-
- rel = Relation(
- arg1=arg1,
- arg2=arg2,
- category=RELATION_LABEL_LIST[rel[2]],
- arg1_start=arg1_ind,
- arg2_start=arg2_ind,
- )
- rel_sent_results.append(rel)
-
- rel_results.append(rel_sent_results)
-
- results = TemporalResults(
- timexes=timex_results, events=event_results, relations=rel_results
- )
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.info(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return results
diff --git a/src/cnlpt/api/termexists_rest.py b/src/cnlpt/api/termexists_rest.py
deleted file mode 100644
index a6024513..00000000
--- a/src/cnlpt/api/termexists_rest.py
+++ /dev/null
@@ -1,106 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from pydantic import BaseModel
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .utils import (
- EntityDocument,
- create_dataset,
- create_instance_string,
- initialize_cnlpt_model,
-)
-
-MODEL_NAME = "mlml-chip/sharpseed-termexists"
-logger = logging.getLogger("TermExists_REST_Processor")
-logger.setLevel(logging.DEBUG)
-
-TASK = "TermExists"
-LABELS = [-1, 1]
-
-MAX_LENGTH = 128
-
-
-class TermExistsResults(BaseModel):
- """statuses: list of classifier outputs for every input"""
-
- statuses: list[int]
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/termexists/process")
-async def process(doc: EntityDocument):
- doc_text = doc.doc_text
- logger.warning(
- f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities"
- )
- instances = []
- start_time = time()
-
- if len(doc.entities) == 0:
- return TermExistsResults(statuses=[])
-
- for ent_ind, offsets in enumerate(doc.entities):
- inst_str = create_instance_string(doc_text, offsets)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, MAX_LENGTH)
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
- predictions = output.predictions[0]
- predictions = np.argmax(predictions, axis=1)
-
- pred_end = time()
-
- results = []
- for ent_ind in range(len(dataset)):
- results.append(LABELS[predictions[ent_ind]])
-
- output = TermExistsResults(statuses=results)
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.warning(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return output
diff --git a/src/cnlpt/api/timex_rest.py b/src/cnlpt/api/timex_rest.py
deleted file mode 100644
index bd44778b..00000000
--- a/src/cnlpt/api/timex_rest.py
+++ /dev/null
@@ -1,156 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import logging
-from contextlib import asynccontextmanager
-from time import time
-
-import numpy as np
-from fastapi import FastAPI
-from nltk.tokenize import wordpunct_tokenize as tokenize
-from seqeval.metrics.sequence_labeling import get_entities
-from transformers import Trainer
-from transformers.tokenization_utils import PreTrainedTokenizer
-
-from .temporal_rest import (
- TIMEX_LABEL_LIST,
- SentenceDocument,
- TemporalResults,
- Timex,
- TokenizedSentenceDocument,
- create_instance_string,
-)
-from .utils import create_dataset, initialize_cnlpt_model
-
-MODEL_NAME = "tmills/timex-thyme-colon-pubmedbert"
-logger = logging.getLogger("Timex_REST_Processor")
-logger.setLevel(logging.INFO)
-
-MAX_LENGTH = 128
-
-
-tokenizer: PreTrainedTokenizer
-trainer: Trainer
-
-
-@asynccontextmanager
-async def lifespan(app: FastAPI):
- global tokenizer, trainer
- tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
- yield
-
-
-app = FastAPI(lifespan=lifespan)
-
-
-@app.post("/temporal/process")
-async def process(doc: TokenizedSentenceDocument):
- return process_tokenized_sentence_document(doc)
-
-
-@app.post("/temporal/process_sentence")
-async def process_sentence(doc: SentenceDocument):
- tokenized_sent = tokenize(doc.sentence)
- doc = TokenizedSentenceDocument(
- sent_tokens=[
- tokenized_sent,
- ],
- metadata="Single sentence",
- )
- return process_tokenized_sentence_document(doc)
-
-
-def process_tokenized_sentence_document(doc: TokenizedSentenceDocument):
- sents = doc.sent_tokens
- metadata = doc.metadata
-
- logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences")
- instances = []
- start_time = time()
-
- for sent_ind, token_list in enumerate(sents):
- inst_str = create_instance_string(token_list)
- logger.debug(f"Instance string is {inst_str}")
- instances.append(inst_str)
-
- dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH)
- logger.warning(f"Dataset is as follows: {dataset.features!s}")
-
- preproc_end = time()
-
- output = trainer.predict(test_dataset=dataset)
-
- timex_predictions = np.argmax(output.predictions[0], axis=2)
-
- timex_results = []
- event_results = []
- relation_results = []
-
- pred_end = time()
-
- for sent_ind in range(len(dataset)):
- batch_encoding = tokenizer.batch_encode_plus(
- [
- sents[sent_ind],
- ],
- is_split_into_words=True,
- max_length=MAX_LENGTH,
- )
- word_ids = batch_encoding.word_ids(0)
- wpind_to_ind = {}
- timex_labels = []
- previous_word_idx = None
-
- for word_pos_idx, word_idx in enumerate(word_ids):
- if word_idx != previous_word_idx and word_idx is not None:
- key = word_pos_idx
- val = len(wpind_to_ind)
-
- wpind_to_ind[key] = val
- timex_labels.append(
- TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]]
- )
- previous_word_idx = word_idx
-
- timex_entities = get_entities(timex_labels)
- logging.info(
- f"Extracted {len(timex_entities)} timex entities from the sentence"
- )
- timex_results.append(
- [
- Timex(timeClass=label[0], begin=label[1], end=label[2])
- for label in timex_entities
- ]
- )
- event_results.append([])
- relation_results.append([])
-
- results = TemporalResults(
- timexes=timex_results, events=event_results, relations=relation_results
- )
-
- postproc_end = time()
-
- preproc_time = preproc_end - start_time
- pred_time = pred_end - preproc_end
- postproc_time = postproc_end - pred_end
-
- logging.info(
- f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
- )
-
- return results
diff --git a/src/cnlpt/api/utils.py b/src/cnlpt/api/utils.py
deleted file mode 100644
index c756ef70..00000000
--- a/src/cnlpt/api/utils.py
+++ /dev/null
@@ -1,157 +0,0 @@
-import logging
-import os
-from typing import Literal, cast
-
-import torch
-from datasets import Dataset
-from pydantic import BaseModel
-from transformers.hf_argparser import HfArgumentParser
-
-# Modeling imports
-from transformers.models.auto.configuration_auto import AutoConfig
-from transformers.models.auto.modeling_auto import AutoModel
-from transformers.models.auto.tokenization_auto import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer
-from transformers.trainer import Trainer
-from transformers.training_args import TrainingArguments
-
-from ..data.preprocess import preprocess_raw_data
-from ..models import CnlpConfig
-
-
-class UnannotatedDocument(BaseModel):
- doc_text: str
-
-
-class EntityDocument(BaseModel):
- """doc_text: The raw text of the document
- offset: A list of entities, where each is a tuple of character offsets into doc_text for that entity
- """
-
- doc_text: str
- entities: list[list[int]]
-
-
-def create_dataset(
- inst_list: list[str],
- tokenizer: PreTrainedTokenizer,
- max_length: int = 128,
- hier: bool = False,
- chunk_len: int = 200,
- num_chunks: int = 40,
- insert_empty_chunk_at_beginning: bool = False,
-):
- """Use a tokenizer to create a dataset from a list of strings."""
- dataset = Dataset.from_dict({"text": inst_list})
- task_dataset = dataset.map(
- preprocess_raw_data,
- batched=True,
- load_from_cache_file=False,
- desc="Running tokenizer on dataset, organizing labels, creating hierarchical segments if necessary",
- batch_size=100,
- fn_kwargs={
- "tokenizer": tokenizer,
- "tasks": None,
- "max_length": max_length,
- "inference_only": True,
- "hierarchical": hier,
- # TODO: need to get this from the model if necessary
- "chunk_len": chunk_len,
- "num_chunks": num_chunks,
- "insert_empty_chunk_at_beginning": insert_empty_chunk_at_beginning,
- },
- )
- return task_dataset
-
-
-def create_instance_string(doc_text: str, offsets: list[int]):
- start = max(0, offsets[0] - 100)
- end = min(len(doc_text), offsets[1] + 100)
- raw_str = (
- doc_text[start : offsets[0]]
- + " "
- + doc_text[offsets[0] : offsets[1]]
- + " "
- + doc_text[offsets[1] : end]
- )
- return raw_str.replace("\n", " ")
-
-
-def resolve_device(
- device: str,
-) -> Literal["cuda", "mps", "cpu"]:
- device = device.lower()
- if device not in ("cuda", "mps", "cpu", "auto"):
- raise ValueError(f"invalid device {device}")
- if device == "auto":
- if torch.cuda.is_available():
- device = "cuda"
- elif torch.mps.is_available():
- device = "mps"
- else:
- device = "cpu"
- elif device == "cuda" and not torch.cuda.is_available():
- logging.warning(
- "Device is set to 'cuda' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it."
- )
- device = "cpu"
- elif device == "mps" and not torch.mps.is_available():
- logging.warning(
- "Device is set to 'mps' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it."
- )
- device = "cpu"
- return device
-
-
-def initialize_cnlpt_model(
- model_name,
- device: Literal["cuda", "mps", "cpu", "auto"] = "auto",
- batch_size=8,
-):
- args = [
- "--output_dir",
- "save_run/",
- "--per_device_eval_batch_size",
- str(batch_size),
- "--do_predict",
- "--report_to",
- "none",
- ]
- parser = HfArgumentParser((TrainingArguments,))
- training_args = cast(
- TrainingArguments, parser.parse_args_into_dataclasses(args=args)[0]
- )
-
- if torch.mps.is_available():
- # pin_memory is unsupported on MPS, but defaults to True,
- # so we'll explicitly turn it off to avoid a warning.
- training_args.dataloader_pin_memory = False
-
- config = AutoConfig.from_pretrained(model_name)
- tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
- model = AutoModel.from_pretrained(
- model_name, cache_dir=os.getenv("HF_CACHE"), config=config
- )
-
- model = model.to(resolve_device(device))
-
- trainer = Trainer(model=model, args=training_args)
-
- return tokenizer, trainer
-
-
-def initialize_hier_model(
- model_name,
- device: Literal["cuda", "mps", "cpu", "auto"] = "auto",
-):
- config: CnlpConfig = AutoConfig.from_pretrained(model_name)
- tokenizer = AutoTokenizer.from_pretrained(model_name, config=config)
-
- model = AutoModel.from_pretrained(
- model_name, cache_dir=os.getenv("HF_CACHE"), config=config
- )
- model.train(False)
-
- model = model.to(resolve_device(device))
-
- return tokenizer, model
diff --git a/src/cnlpt/data/analysis.py b/src/cnlpt/data/analysis.py
index d2591a33..2d7fd7a1 100644
--- a/src/cnlpt/data/analysis.py
+++ b/src/cnlpt/data/analysis.py
@@ -175,6 +175,7 @@ def make_preds_df(
Returns:
The DataFrame for analysis.
"""
+
seq_len = len(predictions.input_data["input_ids"][0])
df_data = {
@@ -195,37 +196,53 @@ def make_preds_df(
else:
tasks = predictions.tasks
+ unlabeled = predictions.raw.label_ids is None
+
for task in tasks:
task_pred = predictions.task_predictions[task.name]
- df = df.with_columns(
- pl.struct(
- labels=pl.struct(
+
+ fields = []
+ if not unlabeled:
+ fields.append(
+ pl.struct(
ids=task_pred.labels,
values=task_pred.target_str_labels,
- ),
- predictions=pl.struct(
+ ).alias("labels")
+ )
+
+ fields.extend(
+ [
+ pl.struct(
ids=task_pred.predicted_int_labels,
values=task_pred.predicted_str_labels,
- ),
- model_output=pl.struct(
+ ).alias("predictions"),
+ pl.struct(
logits=task_pred.logits,
probs=task_pred.probs,
- ),
- ).alias(task.name)
+ ).alias("model_output"),
+ ]
)
+ df = df.with_columns(pl.struct(fields).alias(task.name))
+
if task.type == CLASSIFICATION:
# classification output is already pretty human-interpretable
pass
elif task.type == TAGGING:
# for tagging, we'll convert BIO tags to labeled spans
- df = df.join(
- _bio_tags_to_spans(
- df, pl.col(task.name).struct.field("labels").struct.field("values")
- ),
- on="sample_idx",
- how="left",
- ).rename({"spans": "target_spans"})
+ tagging_fields = []
+ if not unlabeled:
+ df = df.join(
+ _bio_tags_to_spans(
+ df,
+ pl.col(task.name).struct.field("labels").struct.field("values"),
+ ),
+ on="sample_idx",
+ how="left",
+ ).rename({"spans": "target_spans"})
+ tagging_fields.append(
+ pl.field("labels").struct.with_fields(spans="target_spans")
+ )
df = df.join(
_bio_tags_to_spans(
@@ -238,20 +255,27 @@ def make_preds_df(
how="left",
).rename({"spans": "predicted_spans"})
+ tagging_fields.append(
+ pl.field("predictions").struct.with_fields(spans="predicted_spans")
+ )
+
df = df.with_columns(
- pl.col(task.name).struct.with_fields(
- pl.field("labels").struct.with_fields(spans="target_spans"),
- pl.field("predictions").struct.with_fields(spans="predicted_spans"),
- )
- ).drop("target_spans", "predicted_spans")
+ pl.col(task.name).struct.with_fields(tagging_fields)
+ ).drop("target_spans", "predicted_spans", strict=False)
elif task.type == RELATIONS:
- df = df.join(
- _rel_matrix_to_rels(
- df, pl.col(task.name).struct.field("labels").struct.field("values")
- ),
- on="sample_idx",
- how="left",
- ).rename({"relations": "target_relations"})
+ relations_fields = []
+ if not unlabeled:
+ df = df.join(
+ _rel_matrix_to_rels(
+ df,
+ pl.col(task.name).struct.field("labels").struct.field("values"),
+ ),
+ on="sample_idx",
+ how="left",
+ ).rename({"relations": "target_relations"})
+ relations_fields.append(
+ pl.field("labels").struct.with_fields(relations="target_relations")
+ )
df = df.join(
_rel_matrix_to_rels(
@@ -263,15 +287,15 @@ def make_preds_df(
on="sample_idx",
how="left",
).rename({"relations": "predicted_relations"})
+ relations_fields.append(
+ pl.field("predictions").struct.with_fields(
+ relations="predicted_relations"
+ )
+ )
df = df.with_columns(
- pl.col(task.name).struct.with_fields(
- pl.field("labels").struct.with_fields(relations="target_relations"),
- pl.field("predictions").struct.with_fields(
- relations="predicted_relations"
- ),
- )
- ).drop("target_relations", "predicted_relations")
+ pl.col(task.name).struct.with_fields(relations_fields)
+ ).drop("target_relations", "predicted_relations", strict=False)
else:
raise ValueError(f"unknown task type {task.type}")
diff --git a/src/cnlpt/data/predictions.py b/src/cnlpt/data/predictions.py
index c984be1a..fb42b60a 100644
--- a/src/cnlpt/data/predictions.py
+++ b/src/cnlpt/data/predictions.py
@@ -106,7 +106,9 @@ def __init__(
t.name: TaskPredictions(
task=t,
logits=self.raw.predictions[t.index],
- labels=task_labels[t.name].squeeze(),
+ labels=task_labels[t.name].squeeze()
+ if task_labels[t.name] is not None
+ else None,
)
for t in tasks
}
diff --git a/src/cnlpt/rest/__init__.py b/src/cnlpt/rest/__init__.py
new file mode 100644
index 00000000..9551ce58
--- /dev/null
+++ b/src/cnlpt/rest/__init__.py
@@ -0,0 +1,3 @@
+from .cnlp_rest import CnlpRestApp
+
+__all__ = ["CnlpRestApp"]
diff --git a/src/cnlpt/rest/cnlp_rest.py b/src/cnlpt/rest/cnlp_rest.py
new file mode 100644
index 00000000..1f2355cd
--- /dev/null
+++ b/src/cnlpt/rest/cnlp_rest.py
@@ -0,0 +1,224 @@
+import logging
+import os
+from collections.abc import Iterable
+from typing import Union
+
+import polars as pl
+import torch
+from datasets import Dataset
+from fastapi import APIRouter, FastAPI
+from pydantic import BaseModel
+from transformers.models.auto.configuration_auto import AutoConfig
+from transformers.models.auto.modeling_auto import AutoModel
+from transformers.models.auto.tokenization_auto import AutoTokenizer
+from transformers.trainer import Trainer
+from transformers.training_args import TrainingArguments
+from typing_extensions import Self
+
+from ..args.data_args import CnlpDataArguments
+from ..data.analysis import make_preds_df
+from ..data.predictions import CnlpPredictions
+from ..data.preprocess import preprocess_raw_data
+from ..data.task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskInfo
+
+
+class InputDocument(BaseModel):
+ text: str
+ entity_spans: Union[list[tuple[int, int]], None] = None
+
+ def to_text_list(self):
+ if self.entity_spans is None:
+ return [self.text]
+
+ text_list: list[str] = []
+ for entity_start, entity_end in self.entity_spans:
+ start = max(0, entity_start - 100)
+ end = min(len(self.text), entity_end + 100)
+ text_list.append(
+ "".join(
+ [
+ self.text[start:entity_start],
+ "",
+ self.text[entity_start:entity_end],
+ "",
+ self.text[entity_end:end],
+ ]
+ )
+ )
+ return text_list
+
+
+class CnlpRestApp:
+ def __init__(self, model_path: str, device: str = "auto"):
+ self.model_path = model_path
+        self.setup_logger(logging.DEBUG)
+        self.resolve_device(device)
+ self.load_model()
+
+ def resolve_device(self, device: str):
+ self.device = device.lower()
+ if self.device == "auto":
+ if torch.cuda.is_available():
+ self.device = "cuda"
+ elif torch.mps.is_available():
+ self.device = "mps"
+ else:
+ self.device = "cpu"
+ else:
+ try:
+ torch.tensor([1.0], device=self.device)
+            except RuntimeError:
+ self.logger.warning(
+ f"Device is set to '{self.device}' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it."
+ )
+ self.device = "cpu"
+
+ def setup_logger(self, log_level):
+ self.logger = logging.getLogger(self.__module__)
+ self.logger.setLevel(log_level)
+
+ def load_model(self):
+ training_args = TrainingArguments(
+ output_dir="save_run/",
+ save_strategy="no",
+ per_device_eval_batch_size=8,
+ do_predict=True,
+ )
+
+ if self.device == "mps":
+ # pin_memory is unsupported on MPS, but defaults to True,
+ # so we'll explicitly turn it off to avoid a warning.
+ training_args.dataloader_pin_memory = False
+
+ self.config = AutoConfig.from_pretrained(self.model_path)
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ self.model_path, config=self.config
+ )
+ self.model = AutoModel.from_pretrained(
+ self.model_path, cache_dir=os.getenv("HF_CACHE"), config=self.config
+ ).to(self.device)
+ self.trainer = Trainer(model=self.model, args=training_args)
+
+ self.tasks: list[TaskInfo] = []
+ for task_idx, task_name in enumerate(self.config.finetuning_task):
+ if self.config.tagger[task_name]:
+ task_type = TAGGING
+ elif self.config.relations[task_name]:
+ task_type = RELATIONS
+ else:
+ task_type = CLASSIFICATION
+
+ self.tasks.append(
+ TaskInfo(
+ name=task_name,
+ type=task_type,
+ index=task_idx,
+ labels=tuple(self.config.label_dictionary[task_name]),
+ )
+ )
+
+ def create_prediction_dataset(
+ self,
+ text: list[str],
+ data_args: CnlpDataArguments,
+ ):
+ dataset = Dataset.from_dict({"text": text})
+
+ return dataset.map(
+ preprocess_raw_data,
+ batched=True,
+ load_from_cache_file=False,
+ desc="Preprocessing raw input",
+ batch_size=100,
+ fn_kwargs={
+ "inference_only": True,
+ "tokenizer": self.tokenizer,
+ "tasks": None,
+ "max_length": data_args.max_seq_length,
+ "hierarchical": self.config.model_type == "hier",
+ "chunk_len": data_args.chunk_len or -1,
+ "num_chunks": data_args.num_chunks or -1,
+ "insert_empty_chunk_at_beginning": data_args.insert_empty_chunk_at_beginning,
+ },
+ )
+
+ def predict(self, dataset: Dataset, data_args: CnlpDataArguments):
+ raw_predictions = self.trainer.predict(dataset)
+ return CnlpPredictions(
+ dataset,
+ raw_prediction=raw_predictions,
+ tasks=self.tasks,
+ data_args=data_args,
+ )
+
+ def format_predictions(self, predictions: CnlpPredictions):
+ df = make_preds_df(predictions).select(["text", *[t.name for t in self.tasks]])
+
+ for task in self.tasks:
+ if task.type == CLASSIFICATION:
+ df = df.with_columns(
+ pl.struct(
+ prediction=pl.col(task.name)
+ .struct.field("predictions")
+ .struct.field("values"),
+ probs=pl.col(task.name)
+ .struct.field("model_output")
+ .struct.field("probs")
+ .arr.to_struct(fields=task.labels),
+ ).alias(task.name)
+ )
+ elif task.type == TAGGING:
+ df = df.with_columns(
+ pl.struct(
+ pl.col(task.name)
+ .struct.field("predictions")
+ .struct.field("spans")
+ ).alias(task.name)
+ )
+ elif task.type == RELATIONS:
+ df = df.with_columns(
+ pl.struct(
+ pl.col(task.name)
+ .struct.field("predictions")
+ .struct.field("relations")
+ ).alias(task.name)
+ )
+
+ return df.to_dicts()
+
+ def process(
+ self,
+ input_doc: InputDocument,
+ max_seq_length: int = 128,
+ chunk_len: Union[int, None] = None,
+ num_chunks: Union[int, None] = None,
+ insert_empty_chunk_at_beginning: bool = False,
+ ):
+ data_args = CnlpDataArguments(
+ data_dir=[],
+ max_seq_length=max_seq_length,
+ chunk_len=chunk_len,
+ num_chunks=num_chunks,
+ insert_empty_chunk_at_beginning=insert_empty_chunk_at_beginning,
+ )
+
+ dataset = self.create_prediction_dataset(input_doc.to_text_list(), data_args)
+ predictions = self.predict(dataset, data_args)
+ return self.format_predictions(predictions)
+
+ def router(self, prefix: str = ""):
+ router = APIRouter(prefix=prefix)
+ router.add_api_route("/process", self.process, methods=["POST"])
+ return router
+
+ def fastapi(self, router_prefix: str = ""):
+ app = FastAPI()
+ app.include_router(self.router(prefix=router_prefix))
+ return app
+
+ @classmethod
+ def multi_app(cls, apps: Iterable[tuple[Self, str]]):
+ multi_app = FastAPI()
+ for app, router_prefix in apps:
+ multi_app.include_router(app.router(router_prefix))
+ return multi_app
diff --git a/test/api/test_api.py b/test/api/test_api.py
index 6b9a3a15..c3f152a1 100644
--- a/test/api/test_api.py
+++ b/test/api/test_api.py
@@ -5,94 +5,159 @@
import pytest
from fastapi.testclient import TestClient
-from cnlpt.api.utils import EntityDocument
+from cnlpt.rest.cnlp_rest import CnlpRestApp, InputDocument
class TestNegation:
@pytest.fixture
def test_client(self):
- from cnlpt.api.negation_rest import app
-
- with TestClient(app) as client:
+ with TestClient(
+ CnlpRestApp("mlml-chip/negation_pubmedbert_sharpseed").fastapi()
+ ) as client:
yield client
def test_negation_startup(self, test_client):
pass
def test_negation_process(self, test_client: TestClient):
- from cnlpt.api.negation_rest import NegationResults
-
- doc = EntityDocument(
- doc_text="The patient has a sore knee and headache "
+ doc = InputDocument(
+ text="The patient has a sore knee and headache "
"but denies nausea and has no anosmia.",
- entities=[[18, 27], [32, 40], [52, 58], [70, 77]],
+ entity_spans=[(18, 27), (32, 40), (52, 58), (70, 77)],
)
- response = test_client.post("/negation/process", content=doc.json())
+ response = test_client.post("/process", content=doc.json())
response.raise_for_status()
- assert response.json() == NegationResults.parse_obj(
- {"statuses": [-1, -1, 1, 1]}
- )
+ assert response.json() == [
+ {
+ "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.",
+ "Negation": {
+ "prediction": "-1",
+ "probs": {
+ "1": pytest.approx(0.0002379878715146333, rel=1e-04),
+ "-1": pytest.approx(0.9997619986534119, rel=1e-04),
+ },
+ },
+ },
+ {
+ "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.",
+ "Negation": {
+ "prediction": "-1",
+ "probs": {
+ "1": pytest.approx(0.0004393413255456835, rel=1e-04),
+ "-1": pytest.approx(0.9995606541633606, rel=1e-04),
+ },
+ },
+ },
+ {
+ "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.",
+ "Negation": {
+ "prediction": "1",
+ "probs": {
+ "1": pytest.approx(0.9921413660049438, rel=1e-04),
+ "-1": pytest.approx(0.007858583703637123, rel=1e-04),
+ },
+ },
+ },
+ {
+ "text": "The patient has a sore knee and headache but denies nausea and has no anosmia.",
+ "Negation": {
+ "prediction": "1",
+ "probs": {
+ "1": pytest.approx(0.9928833246231079, rel=1e-04),
+ "-1": pytest.approx(0.0071166763082146645, rel=1e-04),
+ },
+ },
+ },
+ ]
class TestTemporal:
@pytest.fixture
def test_client(self):
- from cnlpt.api.temporal_rest import app
-
- with TestClient(app) as client:
+ with TestClient(CnlpRestApp("mlml-chip/thyme2_colon_e2e").fastapi()) as client:
yield client
def test_temporal_startup(self, test_client: TestClient):
pass
def test_temporal_process_sentence(self, test_client: TestClient):
- from cnlpt.api.temporal_rest import (
- SentenceDocument,
- TemporalResults,
- )
-
- doc = SentenceDocument(
- sentence="The patient was diagnosed with adenocarcinoma "
+ doc = InputDocument(
+ text="The patient was diagnosed with adenocarcinoma "
"March 3, 2010 and will be returning for "
"chemotherapy next week."
)
- response = test_client.post("/temporal/process_sentence", content=doc.json())
+ response = test_client.post("/process", content=doc.json())
response.raise_for_status()
- out = response.json()
- expected_out = TemporalResults.parse_obj(
+ assert response.json() == [
{
- "events": [
- [
- {"begin": 3, "dtr": "BEFORE", "end": 3},
- {"begin": 5, "dtr": "BEFORE", "end": 5},
- {"begin": 13, "dtr": "AFTER", "end": 13},
- {"begin": 15, "dtr": "AFTER", "end": 15},
+ "text": "The patient was diagnosed with adenocarcinoma March 3, 2010 and will be returning for chemotherapy next week.",
+ "timex": {
+ "spans": [
+ {
+ "text": "March 3, 2010 ",
+ "tag": "DATE",
+ "start": 6,
+ "end": 8,
+ "valid": True,
+ },
+ {
+ "text": "next week.",
+ "tag": "DATE",
+ "start": 15,
+ "end": 16,
+ "valid": True,
+ },
]
- ],
- "relations": [
- [
+ },
+ "event": {
+ "spans": [
+ {
+ "text": "diagnosed ",
+ "tag": "BEFORE",
+ "start": 3,
+ "end": 3,
+ "valid": True,
+ },
+ {
+ "text": "adenocarcinoma ",
+ "tag": "BEFORE",
+ "start": 5,
+ "end": 5,
+ "valid": True,
+ },
{
- "arg1": "TIMEX-0",
- "arg1_start": 6,
- "arg2": "EVENT-0",
- "arg2_start": 3,
- "category": "CONTAINS",
+ "text": "returning ",
+ "tag": "AFTER",
+ "start": 12,
+ "end": 12,
+ "valid": True,
},
{
- "arg1": "TIMEX-1",
- "arg1_start": 16,
- "arg2": "EVENT-2",
- "arg2_start": 13,
- "category": "CONTAINS",
+ "text": "chemotherapy ",
+ "tag": "AFTER",
+ "start": 14,
+ "end": 14,
+ "valid": True,
},
]
- ],
- "timexes": [
- [
- {"begin": 6, "end": 9, "timeClass": "DATE"},
- {"begin": 16, "end": 17, "timeClass": "DATE"},
+ },
+ "tlinkx": {
+ "relations": [
+ {
+ "arg1_wid": 6,
+ "arg1_text": "March",
+ "arg2_wid": 3,
+ "arg2_text": "diagnosed",
+ "label": "CONTAINS",
+ },
+ {
+ "arg1_wid": 15,
+ "arg1_text": "next",
+ "arg2_wid": 12,
+ "arg2_text": "returning",
+ "label": "CONTAINS",
+ },
]
- ],
+ },
}
- )
- assert out == expected_out
+ ]
diff --git a/test/test_init.py b/test/test_init.py
index 12b83959..d2fab80e 100644
--- a/test/test_init.py
+++ b/test/test_init.py
@@ -71,7 +71,7 @@ def test_init_args():
def test_init_api():
- import cnlpt.api
+ import cnlpt.rest
- assert cnlpt.api.__package__ == "cnlpt.api"
- assert cnlpt.api.__all__ == ["MODEL_TYPES", "get_rest_app"]
+ assert cnlpt.rest.__package__ == "cnlpt.rest"
+ assert cnlpt.rest.__all__ == ["CnlpRestApp"]