From df881983badd3ac244e4a8e8c19549e2ce1437d6 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:05:36 -0400 Subject: [PATCH 1/8] unify model loading --- src/cnlpt/modeling/__init__.py | 39 ++ src/cnlpt/modeling/config/__init__.py | 11 + src/cnlpt/modeling/config/base_config.py | 130 ++++ src/cnlpt/modeling/config/cnn_config.py | 37 + .../modeling/config/hierarchical_config.py | 58 ++ src/cnlpt/modeling/config/lstm_config.py | 33 + .../modeling/config/projection_config.py | 39 ++ src/cnlpt/modeling/load.py | 83 +++ src/cnlpt/modeling/models/__init__.py | 11 + src/cnlpt/modeling/models/cnn_model.py | 105 +++ .../modeling/models/hierarchical_model.py | 240 +++++++ src/cnlpt/modeling/models/lstm_model.py | 68 ++ src/cnlpt/modeling/models/projection_model.py | 337 +++++++++ src/cnlpt/modeling/modules.py | 331 +++++++++ src/cnlpt/modeling/types.py | 43 ++ src/cnlpt/modeling/utils.py | 52 ++ src/cnlpt/models/__init__.py | 9 - src/cnlpt/models/baseline/__init__.py | 4 - src/cnlpt/models/baseline/cnn.py | 109 --- src/cnlpt/models/baseline/lstm.py | 59 -- src/cnlpt/models/cnlp.py | 649 ------------------ src/cnlpt/models/hierarchical.py | 461 ------------- 22 files changed, 1617 insertions(+), 1291 deletions(-) create mode 100644 src/cnlpt/modeling/__init__.py create mode 100644 src/cnlpt/modeling/config/__init__.py create mode 100644 src/cnlpt/modeling/config/base_config.py create mode 100644 src/cnlpt/modeling/config/cnn_config.py create mode 100644 src/cnlpt/modeling/config/hierarchical_config.py create mode 100644 src/cnlpt/modeling/config/lstm_config.py create mode 100644 src/cnlpt/modeling/config/projection_config.py create mode 100644 src/cnlpt/modeling/load.py create mode 100644 src/cnlpt/modeling/models/__init__.py create mode 100644 src/cnlpt/modeling/models/cnn_model.py create mode 100644 src/cnlpt/modeling/models/hierarchical_model.py create mode 100644 src/cnlpt/modeling/models/lstm_model.py create mode 100644 
src/cnlpt/modeling/models/projection_model.py create mode 100644 src/cnlpt/modeling/modules.py create mode 100644 src/cnlpt/modeling/types.py create mode 100644 src/cnlpt/modeling/utils.py delete mode 100644 src/cnlpt/models/__init__.py delete mode 100644 src/cnlpt/models/baseline/__init__.py delete mode 100644 src/cnlpt/models/baseline/cnn.py delete mode 100644 src/cnlpt/models/baseline/lstm.py delete mode 100644 src/cnlpt/models/cnlp.py delete mode 100644 src/cnlpt/models/hierarchical.py diff --git a/src/cnlpt/modeling/__init__.py b/src/cnlpt/modeling/__init__.py new file mode 100644 index 00000000..d6145a6e --- /dev/null +++ b/src/cnlpt/modeling/__init__.py @@ -0,0 +1,39 @@ +from transformers.models.auto.configuration_auto import AutoConfig +from transformers.models.auto.modeling_auto import AutoModel + +from .config import ( + CnnModelConfig, + HierarchicalModelConfig, + LstmModelConfig, + ProjectionModelConfig, +) +from .models import CnnModel, HierarchicalModel, LstmModel, ProjectionModel +from .types import ClassificationMode, ModelType + +__all__ = [ + "ClassificationMode", + "CnnModel", + "CnnModelConfig", + "HierarchicalModel", + "HierarchicalModelConfig", + "LstmModel", + "LstmModelConfig", + "ModelType", + "ProjectionModel", + "ProjectionModelConfig", +] + + +AutoConfig.register("cnlpt.proj", ProjectionModelConfig) +AutoModel.register(ProjectionModelConfig, ProjectionModel) + +AutoConfig.register("cnlpt.cnn", CnnModelConfig) +AutoModel.register(CnnModelConfig, CnnModel) + +AutoConfig.register("cnlpt.hier", HierarchicalModelConfig) +AutoModel.register(HierarchicalModelConfig, HierarchicalModel) + +AutoConfig.register("cnlpt.lstm", LstmModelConfig) +AutoModel.register(LstmModelConfig, LstmModel) + +# TODO(ian) It would be REALLY nice if we could load legacy models... 
diff --git a/src/cnlpt/modeling/config/__init__.py b/src/cnlpt/modeling/config/__init__.py new file mode 100644 index 00000000..d9b5d876 --- /dev/null +++ b/src/cnlpt/modeling/config/__init__.py @@ -0,0 +1,11 @@ +from .cnn_config import CnnModelConfig +from .hierarchical_config import HierarchicalModelConfig +from .lstm_config import LstmModelConfig +from .projection_config import ProjectionModelConfig + +__all__ = [ + "CnnModelConfig", + "HierarchicalModelConfig", + "LstmModelConfig", + "ProjectionModelConfig", +] diff --git a/src/cnlpt/modeling/config/base_config.py b/src/cnlpt/modeling/config/base_config.py new file mode 100644 index 00000000..834fef2e --- /dev/null +++ b/src/cnlpt/modeling/config/base_config.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from dataclasses import asdict +from os import PathLike +from typing import Any, Union + +from transformers import CONFIG_MAPPING, AutoConfig, AutoModel, PretrainedConfig +from transformers import logging as transformers_logging +from transformers.modeling_utils import PreTrainedModel + +from ... 
import __version__ as current_cnlpt_version +from ...data.task_info import TaskInfo +from ..utils import warn_on_version_mismatch + + +def _load_encoder_config( + encoder_config: Union[PretrainedConfig, dict[str, Any], None], + encoder_name: str, +) -> PretrainedConfig: + if ( + isinstance(encoder_config, dict) + and (model_type := encoder_config.get("model_type")) is not None + ): + config_class: type[PretrainedConfig] = CONFIG_MAPPING[model_type] + return config_class.from_dict(encoder_config) + elif isinstance(encoder_config, PretrainedConfig): + return encoder_config + return AutoConfig.from_pretrained(encoder_name) + + +def _load_tasks( + tasks: Union[list[dict[str, Any]], list[TaskInfo]], +) -> list[dict[str, Any]]: + if tasks is None or len(tasks) == 0: + return [] + elif isinstance(tasks[0], TaskInfo): + return [asdict(t) for t in tasks] + else: + return tasks + + +class BaseConfig(PretrainedConfig): + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + cnlpt_version: Union[str, None] = None, + **kwargs, + ): + if cnlpt_version is None: + self.cnlpt_version = current_cnlpt_version + else: + warn_on_version_mismatch(cnlpt_version) + self.cnlpt_version = cnlpt_version + + if "_tasks" in kwargs: + self._tasks = kwargs.pop("_tasks") + else: + self.tasks = tasks + + super().__init__(vocab_size=vocab_size, **kwargs) + + @property + def tasks(self) -> list[TaskInfo]: + if self._tasks is None: + return [] + return [TaskInfo(**t) for t in self._tasks] + + @tasks.setter + def tasks(self, tasks: Union[list[dict[str, Any]], list[TaskInfo]]): + if tasks is None or len(tasks) == 0: + self._tasks = [] + elif isinstance(tasks[0], TaskInfo): + self._tasks = [asdict(t) for t in tasks] + else: + self._tasks = tasks + + +class BaseConfigWithEncoder(BaseConfig): + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + 
cnlpt_version: Union[str, None] = None, + encoder_name: Union[str, PathLike] = "roberta-base", + encoder_config: Union[PretrainedConfig, dict[str, Any], None] = None, + **kwargs, + ): + super().__init__( + tasks=tasks, + vocab_size=vocab_size, + cnlpt_version=cnlpt_version, + **kwargs, + ) + + self._set_encoder(encoder_name, encoder_config) + + def _set_encoder( + self, + encoder_name: str, + encoder_config: Union[PretrainedConfig, dict[str, Any], None], + ): + self.encoder_name = encoder_name + self.encoder_config = _load_encoder_config(encoder_config, encoder_name) + self.encoder_output_dim: int = self._get_encoder_attr("dim", "hidden_size") + self.encoder_dropout: float = self._get_encoder_attr( + "dropout", "mlp_dropout", "hidden_dropout_prob" + ) + + def _get_encoder_attr(self, *keys): + for key in keys: + if (result := getattr(self.encoder_config, key, None)) is not None: + return result + raise ValueError( + f"Encoder config does not have any of the attributes {[*keys]}. " + "Please use a supported encoder (e.g. BERT/RoBERTa/DistilBERT/ModernBERT)" + ) + + def load_encoder_model(self, resize_token_embeddings: bool) -> PreTrainedModel: + # Disable warnings for a moment while we load the model to keep the console clean. + # (The emitted warnings are non-issues.) 
+ verb_before = transformers_logging.get_verbosity() + transformers_logging.set_verbosity_error() + encoder: PreTrainedModel = AutoModel.from_config(config=self.encoder_config) + if resize_token_embeddings: + encoder.resize_token_embeddings(self.vocab_size, mean_resizing=False) + transformers_logging.set_verbosity(verb_before) + return encoder diff --git a/src/cnlpt/modeling/config/cnn_config.py b/src/cnlpt/modeling/config/cnn_config.py new file mode 100644 index 00000000..54bcc39a --- /dev/null +++ b/src/cnlpt/modeling/config/cnn_config.py @@ -0,0 +1,37 @@ +from typing import Any, Union + +from ...data.task_info import CLASSIFICATION, TaskInfo +from .base_config import BaseConfig + + +class CnnModelConfig(BaseConfig): + model_type = "cnlpt.cnn" + + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + use_prior_tasks: bool = False, + embed_dim: int = 100, + num_filters: int = 25, + filter_sizes: tuple[int, ...] 
= (1, 2, 3), + dropout: float = 0.2, + **kwargs, + ): + super().__init__( + tasks=tasks, + vocab_size=vocab_size, + **kwargs, + ) + + self.use_prior_tasks = use_prior_tasks + self.embed_dim = embed_dim + self.filters_per_size = num_filters + self.filter_sizes = filter_sizes + self.dropout = dropout + + if any(t.type != CLASSIFICATION for t in self.tasks): + raise NotImplementedError( + "using a CNN model for non-classification tasks is not yet implemented" + ) diff --git a/src/cnlpt/modeling/config/hierarchical_config.py b/src/cnlpt/modeling/config/hierarchical_config.py new file mode 100644 index 00000000..4aab3b64 --- /dev/null +++ b/src/cnlpt/modeling/config/hierarchical_config.py @@ -0,0 +1,58 @@ +import os +from typing import Any, Union + +from ...data.task_info import TaskInfo +from .base_config import BaseConfigWithEncoder + + +def _resolve_layer(layer: int, n_layers: int): + if layer < 0: + layer = layer + n_layers + 1 + + if layer > n_layers: + raise ValueError( + f"The layer specified ({layer}) is too big for the specified chunk transformer which has {n_layers} layers" + ) + elif layer < 0: + raise ValueError( + f"The layer specified ({layer}) is a negative value which is larger than the actual number of layers {n_layers}" + ) + elif layer == 0: + raise ValueError( + "The classifier layer derived is 0 which is ambiguous -- there is no usable 0th layer in a hierarchical model. 
Enter a value for the layer argument that at least 1 (use one layer) or -1 (use the final layer)" + ) + return layer + + +class HierarchicalModelConfig(BaseConfigWithEncoder): + model_type = "cnlpt.hier" + + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + encoder_name: Union[str, os.PathLike] = "roberta-base", + layer: int = -1, + n_layers: int = 8, + d_inner: int = 2048, + n_head: int = 8, + d_k: int = 8, + d_v: int = 96, + dropout: float = 0.1, + **kwargs, + ): + super().__init__( + tasks=tasks, + vocab_size=vocab_size, + encoder_name=encoder_name, + **kwargs, + ) + + self.layer = _resolve_layer(layer, n_layers) + self.n_layers = n_layers + self.d_inner = d_inner + self.n_head = n_head + self.d_k = d_k + self.d_v = d_v + self.dropout = dropout diff --git a/src/cnlpt/modeling/config/lstm_config.py b/src/cnlpt/modeling/config/lstm_config.py new file mode 100644 index 00000000..2180e0f6 --- /dev/null +++ b/src/cnlpt/modeling/config/lstm_config.py @@ -0,0 +1,33 @@ +from typing import Any, Union + +from ...data.task_info import CLASSIFICATION, TaskInfo +from .base_config import BaseConfig + + +class LstmModelConfig(BaseConfig): + model_type = "cnlpt.lstm" + + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + embed_dim: int = 100, + hidden_size: int = 100, + dropout: float = 0.2, + **kwargs, + ): + super().__init__( + tasks=tasks, + vocab_size=vocab_size, + **kwargs, + ) + + self.embed_dim = embed_dim + self.hidden_size = hidden_size + self.dropout = dropout + + if any(t.type != CLASSIFICATION for t in self.tasks): + raise NotImplementedError( + "using a LSTM model for non-classification tasks is not yet implemented" + ) diff --git a/src/cnlpt/modeling/config/projection_config.py b/src/cnlpt/modeling/config/projection_config.py new file mode 100644 index 00000000..e9270871 --- /dev/null +++ 
b/src/cnlpt/modeling/config/projection_config.py @@ -0,0 +1,39 @@ +import os +from typing import Any, Literal, Union + +from ...data.task_info import TaskInfo +from .base_config import BaseConfigWithEncoder + + +class ProjectionModelConfig(BaseConfigWithEncoder): + model_type = "cnlpt.proj" + + def __init__( + self, + *, + tasks: Union[list[dict[str, Any]], list[TaskInfo], None] = None, + vocab_size: Union[int, None] = None, + encoder_name: Union[str, os.PathLike] = "roberta-base", + encoder_layer: int = -1, + use_prior_tasks: bool = False, + classification_mode: Literal["cls", "tagged"] = "cls", + num_rel_attention_heads: int = 12, + rel_attention_head_dims: int = 64, + character_level: bool = False, + **kwargs, + ): + super().__init__( + tasks=tasks, + vocab_size=vocab_size, + encoder_name=encoder_name, + **kwargs, + ) + + self.encoder_layer = encoder_layer + self.use_prior_tasks = use_prior_tasks + self.tokens = ( + classification_mode == "tagged" + ) # TODO(ian) this should really be self.classification_mode + self.num_rel_attention_heads = num_rel_attention_heads + self.rel_attention_head_dims = rel_attention_head_dims + self.character_level = character_level diff --git a/src/cnlpt/modeling/load.py b/src/cnlpt/modeling/load.py new file mode 100644 index 00000000..93a13690 --- /dev/null +++ b/src/cnlpt/modeling/load.py @@ -0,0 +1,83 @@ +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +from transformers.models.auto.configuration_auto import AutoConfig +from transformers.models.auto.modeling_auto import AutoModel + +from ..data.task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskInfo +from .config import ( + CnnModelConfig, + HierarchicalModelConfig, + LstmModelConfig, + ProjectionModelConfig, +) + + +def try_load_config( + model_name_or_path: str, +) -> Union[ + ProjectionModelConfig, CnnModelConfig, HierarchicalModelConfig, LstmModelConfig +]: + """Load a 
model config, potentially for a model created with an earlier version of CNLPT. + + Args: + config_file: Path to a config file on disk. + + Returns: + The loaded config. + """ + + config_data = PretrainedConfig.get_config_dict(model_name_or_path)[0] + + if "model_type" not in config_data: + raise ValueError("could not infer model type") + model_type: str = config_data.pop("model_type") + + if model_type.startswith("cnlpt."): + # This is a post-0.7 model, so we'll just use autoconfig + return AutoConfig.from_pretrained(model_name_or_path) + + config_data.pop("architectures") + + training_tasks: list[str] = config_data.pop("finetuning_task") + label_dict: dict[str, list[str]] = config_data.pop("label_dictionary") + tagging_tasks: dict[str, bool] = config_data.pop("tagger") + relations_tasks: dict[str, bool] = config_data.pop("relations") + + tasks: list[TaskInfo] = [] + for task_idx, task_name in enumerate(training_tasks): + task_type = ( + TAGGING + if tagging_tasks[task_name] + else RELATIONS + if relations_tasks[task_name] + else CLASSIFICATION + ) + tasks.append( + TaskInfo( + name=task_name, + type=task_type, + index=task_idx, + labels=tuple(label_dict[task_name]), + ) + ) + + config_data["tasks"] = tasks + + if model_type == "cnlpt": + config_data["architectures"] = ["ProjectionModel"] + config_data["model_type"] = ["cnlpt.proj"] + tagged_mode = config_data.pop("tokens") + config_data["classification_mode"] = "tagged" if tagged_mode else "cls" + + return ProjectionModelConfig(**config_data) + else: + raise NotImplementedError( + "loading legacy models other than projection models is not yet implemented" + ) + + +def try_load_pretrained_model(pretrained_model_name_or_path: str) -> PreTrainedModel: + config = try_load_config(pretrained_model_name_or_path) + return AutoModel.from_pretrained(pretrained_model_name_or_path, config=config) diff --git a/src/cnlpt/modeling/models/__init__.py b/src/cnlpt/modeling/models/__init__.py new file mode 100644 index 
00000000..62e549c4 --- /dev/null +++ b/src/cnlpt/modeling/models/__init__.py @@ -0,0 +1,11 @@ +from .cnn_model import CnnModel +from .hierarchical_model import HierarchicalModel +from .lstm_model import LstmModel +from .projection_model import ProjectionModel + +__all__ = [ + "CnnModel", + "HierarchicalModel", + "LstmModel", + "ProjectionModel", +] diff --git a/src/cnlpt/modeling/models/cnn_model.py b/src/cnlpt/modeling/models/cnn_model.py new file mode 100644 index 00000000..6b216270 --- /dev/null +++ b/src/cnlpt/modeling/models/cnn_model.py @@ -0,0 +1,105 @@ +from typing import Union + +import torch +import torch.nn.functional as F +from torch import nn +from transformers.modeling_utils import PreTrainedModel + +from ..config.cnn_config import CnnModelConfig + + +class CnnModel(PreTrainedModel): + base_model_prefix = "cnlpt.cnn" + config_class = CnnModelConfig + + def __init__( + self, + config: CnnModelConfig, + *, + class_weights: Union[dict[str, torch.FloatTensor], None] = None, + **kwargs, + ): + super().__init__(config) + self.config: CnnModelConfig + + self.embed = nn.Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.embed_dim, + ) + self.convs = nn.ModuleList( + [ + nn.Conv1d( + in_channels=config.embed_dim, + out_channels=config.filters_per_size, + kernel_size=filter_size, + ) + for filter_size in config.filter_sizes + ] + ) + + self.loss_fns = { + task.name: nn.CrossEntropyLoss( + weight=class_weights[task.name] if class_weights is not None else None + ) + for task in self.config.tasks + } + + total_filters = len(self.config.filter_sizes) * self.config.filters_per_size + + self.fcs = nn.ModuleList() + for task in self.config.tasks: + self.fcs.append(nn.Linear(total_filters, len(task.labels))) + + if self.config.use_prior_tasks: + self.intertask_matrices: list[list[nn.Linear]] = [] + for i in range(len(self.tasks)): + matrices = [] + for j in range(len(self.tasks) - i): + matrices.append(nn.Linear(2, total_filters)) + 
self.intertask_matrices.append(matrices) + + def forward( + self, + input_ids: Union[torch.LongTensor, None] = None, + labels: Union[torch.LongTensor, None] = None, + output_hidden_states=False, + **kwargs, + ): + embeddings: torch.Tensor = self.embed(input_ids) + embeddings = embeddings.transpose(1, 2) + all_convs: list[torch.Tensor] = [conv(embeddings) for conv in self.convs] + pooled_convs = [ + F.max_pool1d(conv_out, conv_out.shape[2]) for conv_out in all_convs + ] + + fc_in = torch.cat(pooled_convs, 1).squeeze(2) + + logits = [] + loss = 0 + for task, fc in zip(self.config.tasks, self.fcs): + # get feaures from previous tasks using the world's tiniest linear layer + if self.config.use_prior_tasks: + for prev_task_ind in range(task.index): + prev_task_matrix = self.intertask_matrices[prev_task_ind][ + task.index - prev_task_ind - 1 + ] + prev_task_matrix = prev_task_matrix.to(logits[prev_task_ind].device) + prev_task_features = prev_task_matrix(logits[prev_task_ind]) + fc_in = fc_in + prev_task_features + task_logits: torch.Tensor = fc(fc_in) + logits.append(task_logits) + + if labels is not None: + if labels.ndim == 2: + # if len(self.fcs) == 1: + # task_labels = labels[:,0] + task_labels = labels[:, task.index] + elif labels.ndim == 3: + task_labels = labels[:, 0, task.index] + loss += self.loss_fns[task.name]( + task_logits, task_labels.type(torch.LongTensor).to(labels.device) + ) + if output_hidden_states: + return loss, logits, fc_in + else: + return loss, logits diff --git a/src/cnlpt/modeling/models/hierarchical_model.py b/src/cnlpt/modeling/models/hierarchical_model.py new file mode 100644 index 00000000..596a0db3 --- /dev/null +++ b/src/cnlpt/modeling/models/hierarchical_model.py @@ -0,0 +1,240 @@ +""" +Module containing the Hierarchical Transformer module, adapted from Xin Su. 
+""" + +import logging +from dataclasses import dataclass +from typing import Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput +from transformers.modeling_utils import PreTrainedModel + +from ...data.task_info import TaskInfo +from ..config.hierarchical_config import HierarchicalModelConfig +from ..modules import ClassificationHead, EncoderLayer +from ..utils import freeze_encoder_weights, generalize_encoder_forward_kwargs + +logger = logging.getLogger(__name__) + + +@dataclass +class HierarchicalSequenceClassifierOutput(SequenceClassifierOutput): + chunk_attentions: Union[tuple[torch.FloatTensor], None] = None + + +class HierarchicalModel(PreTrainedModel): + """ + Hierarchical Transformer model (https://arxiv.org/abs/2105.06752) + + Adapted from Xin Su's implementation (https://github.com/xinsu626/DocTransformer) + """ + + base_model_prefix = "cnlpt.hier" + config_class = HierarchicalModelConfig + + def __init__( + self, + config: HierarchicalModelConfig, + *, + freeze: float = -1.0, + class_weights: Union[dict[str, torch.FloatTensor], None] = None, + **kwargs, + ): + super().__init__(config) + self.config: HierarchicalModelConfig + + self.encoder = self.config.load_encoder_model(True) + + if freeze > 0: + freeze_encoder_weights(self.encoder, freeze) + + # Document-level transformer layer + self.transformer: list[EncoderLayer] = nn.ModuleList( + [ + EncoderLayer( + d_model=self.config.encoder_output_dim, + d_inner=self.config.d_inner, + n_head=self.config.n_head, + d_k=self.config.d_k, + d_v=self.config.d_v, + dropout=self.config.dropout, + ) + for _ in range(self.config.n_layers) + ] + ) + + self.configure_for_tasks(self.config.tasks, class_weights) + + def configure_for_tasks( + self, + tasks: list[TaskInfo], + class_weights: Union[dict[str, torch.FloatTensor], None], + ): + self.tasks = self.config.tasks = tasks + self.classifiers = 
nn.ModuleDict() + + for task in self.tasks: + self.classifiers[task.name] = ClassificationHead( + hidden_dropout_prob=self.config.dropout, + hidden_size=self.config.encoder_output_dim, + num_labels=len(task.labels), + ) + + self.class_weights = class_weights + + def forward( + self, + input_ids: Union[torch.LongTensor, None] = None, + attention_mask: Union[torch.LongTensor, None] = None, + token_type_ids: Union[torch.LongTensor, None] = None, + position_ids: Union[torch.LongTensor, None] = None, + head_mask: Union[torch.LongTensor, None] = None, + inputs_embeds: Union[torch.FloatTensor, None] = None, + labels: Union[torch.LongTensor, None] = None, + output_attentions: Union[bool, None] = None, + **kwargs, + ): + """ + Forward method. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): + A batch of chunked documents as tokenizer indices. + attention_mask (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): + Attention masks for the batch. + token_type_ids (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): + Token type IDs for the batch. + position_ids: (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): + Position IDs for the batch. + head_mask (`torch.LongTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Token encoder head mask. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_chunks, chunk_len, hidden_size)`, *optional*): + A batch of chunked documents as token embeddings. + labels (`torch.LongTensor` of shape `(batch_size, num_tasks)`, *optional*): + Labels for computing the sequence classification/regression loss. + Indices should be in `[0, ..., self.num_labels[task_ind] - 1]`. + If `self.num_labels[task_ind] == 1` a regression loss is computed (Mean-Square loss), + If `self.num_labels[task_ind] > 1` a classification loss is computed (Cross-Entropy). 
+ output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. + output_hidden_states: If True, return a matrix of shape (batch_size, num_chunks, hidden size) representing the contextualized embeddings of each chunk. The 0-th element of each chunk is the classifier representation for that instance. + event_tokens: not currently used (only relevant for token classification) + + Returns: + + """ + if input_ids is not None: + batch_size, num_chunks, chunk_len = input_ids.shape + flat_shape = (batch_size * num_chunks, chunk_len) + else: # inputs_embeds is not None + batch_size, num_chunks, chunk_len, embed_dim = inputs_embeds.shape + flat_shape = (batch_size * num_chunks, chunk_len, embed_dim) + + encoder_kwargs = generalize_encoder_forward_kwargs( + self.encoder, + attention_mask=( + attention_mask.reshape(flat_shape[:3]) + if attention_mask is not None + else None + ), + token_type_ids=( + token_type_ids.reshape(flat_shape[:3]) + if token_type_ids is not None + else None + ), + position_ids=( + position_ids.reshape(flat_shape[:3]) + if position_ids is not None + else None + ), + head_mask=head_mask, + inputs_embeds=( + inputs_embeds.reshape(flat_shape) if inputs_embeds is not None else None + ), + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=True, + ) + + outputs: BaseModelOutput = self.encoder( + input_ids.reshape(flat_shape[:3]) if input_ids is not None else None, + **encoder_kwargs, + ) + + logits = [] + + # outputs.last_hidden_state.shape: (B * n_chunks, chunk_len, hidden_size) + # (B * n_chunk, hidden_size) + chunks_reps = outputs.last_hidden_state[..., 0, :].reshape( + batch_size, num_chunks, outputs.last_hidden_state.shape[-1] + ) + + # Use pre-trained model's position embedding + position_ids = torch.arange( + num_chunks, dtype=torch.long, device=chunks_reps.device + ) # (n_chunk) + position_ids = position_ids.unsqueeze(0).expand_as( + chunks_reps[:, :, 0] + ) # (B, 
n_chunk) + position_embeddings: torch.Tensor = self.encoder.embeddings.position_embeddings( + position_ids + ) + chunks_reps = chunks_reps + position_embeddings + chunks_attns: Union[list[torch.Tensor], None] = None + + # document encoding (B, n_chunk, hidden_size) + for layer_ind, layer_module in enumerate(self.transformer): + chunks_reps: torch.Tensor + chunks_attn: torch.Tensor + chunks_reps, chunks_attn = layer_module(chunks_reps) + if output_attentions: + if chunks_attns is None: + chunks_attns = [] + chunks_attns.append(chunks_attn) + + ## this case is mainly for when we are doing subsequent fine-tuning using a pre-trained + ## hierarchical model and we want to check whether an earlier layer might provide better + ## classification performance (e.g., if we think the last layer(s) are overfit to the pre-training + ## objective) Just short circuit rather than doing the whole computation. + if layer_ind + 1 >= self.config.layer: + break + + hidden_states = chunks_reps + + # extract first Documents as rep. 
(B, hidden_size) + doc_rep = chunks_reps[:, 0, :] + + total_loss = None + for task in self.tasks: + loss_fct = CrossEntropyLoss( + weight=self.class_weights[task.name] + if self.class_weights is not None + else None + ) + + # predict (B, 5) + task_logits = self.classifiers[task.name](doc_rep) + logits.append(task_logits) + + if labels is not None: + task_labels = labels[:, task.index] + task_loss = loss_fct( + task_logits, task_labels.type(torch.LongTensor).to(labels.device) + ) + if total_loss is None: + total_loss = task_loss + else: + total_loss += task_loss + + output = HierarchicalSequenceClassifierOutput( + loss=total_loss, + logits=logits, + ) + if self.config.output_hidden_states: + output.hidden_states = (*outputs.hidden_states, *hidden_states) + if self.config.output_attentions: + output.attentions = outputs.attentions + output.chunk_attentions = chunks_attns + return output diff --git a/src/cnlpt/modeling/models/lstm_model.py b/src/cnlpt/modeling/models/lstm_model.py new file mode 100644 index 00000000..4990fe11 --- /dev/null +++ b/src/cnlpt/modeling/models/lstm_model.py @@ -0,0 +1,68 @@ +from typing import Union + +import torch +from torch import nn +from transformers.modeling_utils import PreTrainedModel + +from ..config.lstm_config import LstmModelConfig + + +class LstmModel(PreTrainedModel): + base_model_prefix = "cnlpt.lstm" + config_class = LstmModelConfig + + def __init__( + self, + config: LstmModelConfig, + *, + class_weights: Union[dict[str, torch.FloatTensor], None] = None, + **kwargs, + ): + super().__init__(config) + self.config = config + self.embed = nn.Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.embed_dim, + ) + self.lstm = nn.LSTM( + input_size=config.embed_dim, + hidden_size=config.hidden_size, + bidirectional=True, + ) + self.loss_fns = { + task.name: nn.CrossEntropyLoss( + weight=class_weights[task.name] if class_weights is not None else None + ) + for task in self.config.tasks + } + + self.fcs = 
nn.ModuleList() + + for task in config.tasks: + self.fcs.append(nn.Linear(4 * config.hidden_size, len(task.labels))) + + def forward( + self, + input_ids: Union[torch.LongTensor, None] = None, + labels: Union[torch.LongTensor, None] = None, + **kwargs, + ): + embeddings = self.embed(input_ids) + lstm_out = self.lstm(embeddings)[0] + + logits: list[torch.Tensor] = [] + loss = 0 + for task, fc in zip(self.config.tasks, self.fcs): + features = torch.cat((lstm_out[:, 0, :], lstm_out[:, -1, :]), 1) + task_logits: torch.Tensor = fc(features) + logits.append(task_logits) + + if labels is not None: + if labels.ndim == 2: + task_labels = labels[:, 0] + elif labels.ndim == 3: + task_labels = labels[:, 0, task.index] + loss += self.loss_fns[task.name]( + task_logits, task_labels.type(torch.LongTensor).to(labels.device) + ) + return loss, logits diff --git a/src/cnlpt/modeling/models/projection_model.py b/src/cnlpt/modeling/models/projection_model.py new file mode 100644 index 00000000..5a134040 --- /dev/null +++ b/src/cnlpt/modeling/models/projection_model.py @@ -0,0 +1,337 @@ +""" +Module containing the CNLP transformer model. 
+""" + +from __future__ import annotations + +import logging +from typing import Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import SequenceClassifierOutput +from transformers.modeling_utils import PreTrainedModel + +from ...data.task_info import RELATIONS, TAGGING +from ..config.projection_config import ProjectionModelConfig +from ..modules import ClassificationHead, RepresentationProjectionLayer +from ..utils import freeze_encoder_weights, generalize_encoder_forward_kwargs + +logger = logging.getLogger(__name__) + + +class ProjectionModel(PreTrainedModel): + base_model_prefix = "cnlpt.proj" + config_class = ProjectionModelConfig + + def __init__( + self, + config: ProjectionModelConfig, + *, + class_weights: Union[dict[str, torch.FloatTensor], None] = None, + final_task_weight: float = 1.0, + freeze: float = -1.0, + bias_fit: bool = False, + **kwargs, + ): + """Create a new CNLP transformer model instance from a config object. + + Args: + config: The CnlpConfig object that configures this model + class_weights: If provided, the weights to use for each task when computing the loss. Defaults to None. + final_task_weight: The weight to use for the final task when computing the loss. Defaults to 1.0. + freeze: What proportion of encoder weights to freeze (-1 for none). Defaults to -1.0. + bias_fit: Whether to fine-tune only the bias of the encoder. Defaults to False. + """ + super().__init__(config) + self.config: ProjectionModelConfig + + # part of the motivation for only resizing embeddings for non character-level models + # is that at the time of writing, CANINE and Flair are the only game in town. 
+ # CANINE's hashable embeddings for unicode codepoints allows for + # additional parameterization, which rn doesn't seem so relevant + self.encoder = self.config.load_encoder_model(not self.config.character_level) + + self.tasks = self.config.tasks + + if self.encoder.config.model_type == "modernbert": + self.num_layers = len(self.encoder.base_model.layers) + else: + self.num_layers = len(self.encoder.encoder.layer) + if self.config.encoder_layer > self.num_layers: + raise ValueError( + f"The layer specified ({self.config.encoder_layer}) is too big for the specified encoder which has {self.num_layers} layers" + ) + + if freeze > 0: + freeze_encoder_weights(self.encoder, freeze) + + if bias_fit: + for name, param in self.encoder.named_parameters(): + if "bias" not in name: + param.requires_grad = False + + self.feature_extractors = nn.ModuleDict() + self.classifiers = nn.ModuleDict() + + total_prev_task_labels = 0 + for task in self.tasks: + self.feature_extractors[task.name] = RepresentationProjectionLayer( + hidden_dropout_prob=self.config.encoder_dropout, + hidden_size=self.config.encoder_output_dim, + layer=self.config.encoder_layer, + tokens=self.config.tokens, + task_type=task.type, + num_attention_heads=self.config.num_rel_attention_heads, + head_size=self.config.rel_attention_head_dims, + ) + hidden_size = self.config.encoder_output_dim + if task.type == RELATIONS: + hidden_size = self.config.num_rel_attention_heads + if self.config.use_prior_tasks: + hidden_size += total_prev_task_labels + + self.classifiers[task.name] = ClassificationHead( + hidden_dropout_prob=self.config.encoder_dropout, + hidden_size=hidden_size, + num_labels=len(task.labels), + ) + + total_prev_task_labels += len(task.labels) + + self.class_weights = class_weights + self.final_task_weight = final_task_weight + self.use_prior_tasks = self.config.use_prior_tasks + self.reg_temperature = 1.0 + + def predict_relations_with_previous_logits( + self, features: torch.Tensor, logits: 
list[torch.Tensor] + ) -> torch.Tensor: + """For the relation prediction task, use previous predictions of the tagging task as additional features in the + representation used for making the relation prediction. + + Args: + features: The existing feature vector for the relations + logits: The predicted logits from the tagging task + + Returns: + The augmented feature tensor + """ + + # features is (batch x seq x seq x n_heads) + seq_len = features.shape[1] + for prior_task_logits in logits: + if len(features.shape) == 4: + if len(prior_task_logits.shape) == 3: + # prior task is sequence tagging: + # we have batch x seq x num_classes. + # we want to concatenate the num_classes to the variables at each element of the sequence, + # but then need to broadcast it down all the rows of the matrix. + aug = prior_task_logits.unsqueeze( + 2 + ) # add another dimension to repeat along + aug = aug.repeat( + 1, 1, seq_len, 1 + ) # repeat along the new empty dimension so we have our seq logits repeated seq_len x seq_len + features = torch.cat( + (features, aug), 3 + ) # concatenate the relation matrix with the sequence matrix + else: + logger.warning( + f"It is not implemented to add a task of shape {prior_task_logits.shape!s} to a relation matrix" + ) + elif len(features.shape) == 3: + # sequence + logger.warning( + "It is not implemented to add previous task of any type to a sequence task" + ) + + return features + + def compute_loss( + self, + task_logits: torch.FloatTensor, + labels: torch.LongTensor, + task_ind: int, + task_num_labels: int, + batch_size: int, + seq_len: int, + state: dict, + ) -> None: + """ + Computes the loss for a single batch and a single task. 
+ + Args: + task_logits: + labels: + task_ind: + task_num_labels: + batch_size: + seq_len: + state: + :meta private: + """ + task = self.tasks[task_ind] + if task_num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + task_loss = loss_fct(task_logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss( + weight=self.class_weights[task.name] + if self.class_weights is not None + else None + ) + + if task.type == RELATIONS: + task_labels = labels[ + :, :, state["task_label_ind"] : state["task_label_ind"] + seq_len + ] + state["task_label_ind"] += seq_len + task_loss = loss_fct( + task_logits.permute(0, 3, 1, 2), + task_labels.type(torch.LongTensor).to(labels.device), + ) + elif task.type == TAGGING: + # in cases where we are only given a single task the HF code will have one fewer dimension in the labels, so just add a dummy dimension to make our indexing work: + if labels.ndim == 2: + task_labels = labels + elif labels.ndim == 3: + # labels = labels.unsqueeze(1) + task_labels = labels[:, :, state["task_label_ind"]] + else: + task_labels = labels[:, 0, state["task_label_ind"], :] + + state["task_label_ind"] += 1 + task_loss = loss_fct( + task_logits.view(-1, task_num_labels), + task_labels.reshape( + [ + batch_size * seq_len, + ] + ) + .type(torch.LongTensor) + .to(labels.device), + ) + else: # task.type == CLASSIFICATION + if labels.ndim == 1: + task_labels = labels + elif labels.ndim == 2: + task_labels = labels[:, task_ind] + elif labels.ndim == 3: + task_labels = labels[:, 0, task_ind] + else: + raise NotImplementedError( + "Have not implemented the case where a classification task " + "is part of an MTL setup with relations and sequence tagging" + ) + + state["task_label_ind"] += 1 + task_loss = loss_fct( + task_logits, task_labels.type(torch.LongTensor).to(labels.device) + ) + + if state["loss"] is None: + state["loss"] = task_loss + else: + task_weight = ( + 1.0 if task_ind + 1 < len(self.tasks) else self.final_task_weight + ) + 
state["loss"] += task_weight * task_loss + + def forward( + self, + input_ids=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + event_tokens=None, + **kwargs, + ): + r"""Forward method. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): + A batch of chunked documents as tokenizer indices. + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): + Attention masks for the batch. + token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): + Token type IDs for the batch. + position_ids: (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): + Position IDs for the batch. + head_mask (`torch.LongTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Token encoder head mask. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_len, hidden_size)`, *optional*): + A batch of chunked documents as token embeddings. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. + Indices should be in :obj:`[0, ..., config.num_labels - 1]`. + If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), + If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). + output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. + event_tokens: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True. 
+ + Returns: (`transformers.SequenceClassifierOutput`) the output of the model + """ + + kwargs = generalize_encoder_forward_kwargs( + self.encoder, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=True, + return_dict=True, + ) + + outputs = self.encoder(input_ids, **kwargs) + + batch_size, seq_len = input_ids.shape + + logits = [] + + state = dict(loss=None, task_label_ind=0) + + for task in self.tasks: + # hidden_states has shape (layers x batch x seq x hidden) + + # features shape: + # for classification (including event tokens mode): (batch x hidden) + # for tagging: (batch x seq x hidden) + # for relations: (batch x seq x seq x n_heads) + features = self.feature_extractors[task.name]( + outputs.hidden_states, event_tokens + ) + if self.use_prior_tasks: + # note: this specific way of incorporating previous logits doesn't help in my experiments with thyme/clinical tempeval + if task.type == RELATIONS: + features = self.predict_relations_with_previous_logits( + features, logits + ) + task_logits = self.classifiers[task.name](features) + logits.append(task_logits) + + if labels is not None: + self.compute_loss( + task_logits, + labels, + task.index, + len(task.labels), + batch_size, + seq_len, + state, + ) + output = SequenceClassifierOutput(loss=state["loss"], logits=logits) + if self.config.output_hidden_states: + output.hidden_states = outputs.hidden_states + if self.config.output_attentions: + output.attentions = outputs.attentions + + return output diff --git a/src/cnlpt/modeling/modules.py b/src/cnlpt/modeling/modules.py new file mode 100644 index 00000000..f561502a --- /dev/null +++ b/src/cnlpt/modeling/modules.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +import math +from typing import Union + +import torch +import torch.nn.functional as F +from torch import nn + +from 
..data.task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskType


class ClassificationHead(nn.Module):
    """Generic classification head that can be used for any task.

    Applies dropout followed by a single linear projection to the label space.
    """

    def __init__(self, hidden_dropout_prob: float, hidden_size: int, num_labels: int):
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    # NOTE(review): `*kwargs` collects positional extras (it is not `**kwargs`);
    # the extras are ignored — confirm the misnomer is intentional before renaming
    def forward(self, features, *kwargs):
        x = self.dropout(features)
        x = self.out_proj(x)
        return x


class RepresentationProjectionLayer(nn.Module):
    """The class that maps from some output from a text encoder into a feature representation that can be classified.
    Project the representation to a new space depending on the task type, based on arguments passed in to the constructor."""

    def __init__(
        self,
        hidden_dropout_prob: float,
        hidden_size: int,
        layer: int = 10,
        tokens: bool = False,
        task_type: TaskType = CLASSIFICATION,
        num_attention_heads: int = -1,
        head_size: int = 64,
    ):
        """
        Args:
            hidden_dropout_prob: The dropout probability applied to the encoder representation
            hidden_size: The hidden size of the encoder representation
            layer: Which layer to pull the encoder representation from
            tokens: Whether to classify an entity based on the token representation rather than the CLS representation
            task_type: The type of the current task (classification, tagging, or relations)
            num_attention_heads: For relations, how many "features" to use
            head_size: For relations, how big each head should be
        """
        super().__init__()
        self.dropout = nn.Dropout(hidden_dropout_prob)
        if task_type == RELATIONS:
            # relations produce per-head attention scores directly; no dense projection
            self.dense = nn.Identity()
        else:
            self.dense = nn.Linear(hidden_size, hidden_size)

        self.layer_to_use = layer
        self.tokens = tokens
        self.task_type = task_type
        self.hidden_size = hidden_size

        if num_attention_heads <= 0 and self.task_type == RELATIONS:
            raise Exception(
                "Inconsistent configuration: num_attention_heads must be > 0 for relations"
            )

        if self.task_type == RELATIONS:
            self.num_attention_heads = num_attention_heads
            self.attention_head_size = head_size
            self.all_head_size = self.num_attention_heads * self.attention_head_size

            self.query = nn.Linear(hidden_size, self.all_head_size)
            self.key = nn.Linear(hidden_size, self.all_head_size)

        if tokens and self.task_type in (TAGGING, RELATIONS):
            raise Exception(
                "Inconsistent configuration: tokens cannot be true in tagger or relation mode"
            )

    def transpose_for_scores(self, x):
        # x: (batch x seq x all_head)

        new_x_shape = x.size()[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        # (batch x seq x n_heads x head_size)
        x = x.view(*new_x_shape)

        # (batch x n_heads x seq x head_size)
        return x.permute(0, 2, 1, 3)

    def forward(self, features, event_tokens: torch.Tensor, **kwargs):
        # features: (layers x batch x seq x hidden)
        # event_tokens: (batch x seq)

        seq_length = features[0].shape[1]
        if self.tokens:
            # grab the average over the tokens of the thing we want to classify
            # probably involves passing in some sub-sequence of interest so we know what tokens to grab,
            # then we average across those tokens.

            # (batch)
            event_toks_per_seq = event_tokens.sum(1)

            # (batch x seq x hidden)
            expanded_tokens = event_tokens.unsqueeze(2).expand(
                features[0].shape[0], seq_length, self.hidden_size
            )

            # (batch x seq x hidden) -- zero out all non-event token positions
            filtered_features = features[self.layer_to_use] * expanded_tokens

            # (batch x hidden) -- mean over the event tokens only
            x = filtered_features.sum(1) / event_toks_per_seq.unsqueeze(1).expand(
                features[0].shape[0], self.hidden_size
            )
        elif self.task_type == TAGGING:
            # (batch x seq x hidden)
            x = features[self.layer_to_use]
        elif self.task_type == RELATIONS:
            # something like multi-headed attention but without the weighted sum at the end, so i get (num_heads) features for each of N x N grid, which feeds into NxN softmax (with the same parameters)
            # (batch x seq x hidden)
            hidden_states = features[self.layer_to_use]

            # (batch x n_heads x seq x head_size)
            key_layer = self.transpose_for_scores(self.key(hidden_states))
            # (batch x n_heads x seq x head_size)
            query_layer = self.transpose_for_scores(self.query(hidden_states))

            # (batch x n_heads x seq x seq)
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

            # Now we have num_heads features for each N X N relations.
            x = attention_scores / math.sqrt(self.attention_head_size)
            # move the n_heads dimension to the end for easier classification

            # (batch x seq x seq x n_heads)
            x = x.permute(0, 2, 3, 1)

        else:
            # take token (equiv.
to [CLS])
            # (batch x hidden)
            x = features[self.layer_to_use][..., 0, :]

        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)

        # for classification (including event tokens mode): (batch x hidden)
        # for tagging: (batch x seq x hidden)
        # for relations: (batch x seq x seq x n_heads)
        return x


### MODULES FOR HIERARCHICAL MODEL ###


class PositionwiseFeedForward(nn.Module):
    """
    A two-feed-forward-layer module

    Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch)

    Args:
        d_in: the dimensionality of the input and output of the encoder
        d_hid: the inner hidden size of the positionwise FFN in the encoder
        dropout: the amount of dropout to use in training (default 0.1)
    """

    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)  # position-wise
        self.w_2 = nn.Linear(d_hid, d_in)  # position-wise
        self.layer_norm = nn.LayerNorm(d_in, eps=1e-6)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # pre-norm residual block: LN -> FFN -> dropout -> +residual
        residual = x
        x = self.layer_norm(x)

        output = self.w_2(F.relu(self.w_1(x)))
        output = self.dropout(output)
        output += residual

        return output


class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention

    Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch)

    Args:
        temperature: the temperature for scaled dot product attention
        attn_dropout: the amount of dropout to use in training
            for scaled dot product attention (default 0.1, not
            tuned in the rest of the code)
    """

    def __init__(self, temperature, attn_dropout=0.1):
        super().__init__()
        self.temperature = temperature
        self.dropout = nn.Dropout(attn_dropout)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Union[torch.Tensor, None] = None,
    ):
        attn = torch.matmul(q / self.temperature, k.transpose(2, 3))

        if mask is not None:
            # large negative fill so masked positions get ~0 softmax weight
            attn = attn.masked_fill(mask == 0, -1e9)

        attn: torch.Tensor = self.dropout(F.softmax(attn, dim=-1))
        output = torch.matmul(attn, v)

        return output, attn


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module

    Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch)

    Args:
        n_head: the number of attention heads
        d_model: the dimensionality of the input and output of the encoder
        d_k: the size of the query and key vectors
        d_v: the size of the value vector
    """

    def __init__(
        self,
        n_head: int,
        d_model: int,
        d_k: int,
        d_v: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
        self.fc = nn.Linear(n_head * d_v, d_model, bias=False)

        self.attention = ScaledDotProductAttention(temperature=d_k**0.5)

        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Union[torch.Tensor, None] = None,
    ):
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
        sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)

        # pre-norm residual: normalize q before projecting; residual is the raw q
        residual = q
        q = self.layer_norm(q)

        # Pass through the pre-attention projection: b x lq x (n*dv)
        # Separate different heads: b x lq x n x dv
        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        # Transpose for attention dot product: b x n x lq x dv
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)  # For head axis broadcasting.

        output: torch.Tensor
        attn: torch.Tensor
        output, attn = self.attention(q, k, v, mask=mask)

        # Transpose to move the head dimension back: b x lq x n x dv
        # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
        output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
        output = self.dropout(self.fc(output))
        output += residual

        return output, attn


class EncoderLayer(nn.Module):
    """
    Compose with two layers

    Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch)

    Args:
        d_model: the dimensionality of the input and output of the encoder
        d_inner: the inner hidden size of the positionwise FFN in the encoder
        n_head: the number of attention heads
        d_k: the size of the query and key vectors
        d_v: the size of the value vector
        dropout: the amount of dropout to use in training in both the
            attention and FFN steps (default 0.1)
    """

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super().__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
        self, enc_input: torch.Tensor, slf_attn_mask: Union[torch.Tensor, None] = None
    ):
        # self-attention followed by a position-wise feed-forward block
        enc_output: torch.Tensor
        enc_slf_attn: torch.Tensor
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input, mask=slf_attn_mask
        )
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn

# diff --git a/src/cnlpt/modeling/types.py b/src/cnlpt/modeling/types.py
# new file mode 100644
# index 00000000..d8d41ecb
# --- /dev/null
# +++ b/src/cnlpt/modeling/types.py
# @@ -0,0 +1,43 @@
from enum import Enum

from .config import (
    CnnModelConfig,
    HierarchicalModelConfig,
    LstmModelConfig,
    ProjectionModelConfig,
)
from .models import CnnModel, HierarchicalModel, LstmModel, ProjectionModel


class ModelType(str, Enum):
    """The supported cnlpt model architectures; values match the registered
    AutoConfig suffixes (``cnlpt.<value>``)."""

    CNN = "cnn"
LSTM = "lstm" + HIER = "hier" + PROJ = "proj" + + @property + def config_class(self): + if self == ModelType.CNN: + return CnnModelConfig + if self == ModelType.LSTM: + return LstmModelConfig + if self == ModelType.HIER: + return HierarchicalModelConfig + if self == ModelType.PROJ: + return ProjectionModelConfig + + @property + def model_class(self): + if self == ModelType.CNN: + return CnnModel + if self == ModelType.LSTM: + return LstmModel + if self == ModelType.HIER: + return HierarchicalModel + if self == ModelType.PROJ: + return ProjectionModel + + +class ClassificationMode(str, Enum): + CLS = "cls" + TAGGED = "tagged" diff --git a/src/cnlpt/modeling/utils.py b/src/cnlpt/modeling/utils.py new file mode 100644 index 00000000..3ae9adaa --- /dev/null +++ b/src/cnlpt/modeling/utils.py @@ -0,0 +1,52 @@ +import inspect +import logging +import random +import warnings +from typing import Any + +from .. import __version__ as cnlpt_version + +logger = logging.getLogger(__name__) + + +def warn_on_version_mismatch(model_version: str): + ckpt_maj_min = tuple(model_version.split(".", maxsplit=2)[:2]) + cnlpt_maj_min = tuple(cnlpt_version.split(".", maxsplit=2)[:2]) + + if ckpt_maj_min != cnlpt_maj_min: + warning = f"You are loading a model created with cnlpt version {model_version}, but this is version {cnlpt_version}. Be aware that the checkpoint may be incompatible." + + warnings.warn(warning) + logger.warning(warning) + + +def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> dict[str, Any]: + """Create a new input feature argument that preserves only the features that are valid for this encoder. + Warn if a feature is present but not valid for the encoder. 
+ + Args: + encoder: A HF encoder model + + Returns: + Dictionary of valid arguments for this encoder + """ + new_kwargs = dict() + params = inspect.signature(encoder.forward).parameters + for name, value in kwargs.items(): + if name not in params and value is not None: + # Warn if a contentful parameter is not valid + logger.warning( + f"Parameter {name} not present for encoder class {encoder.__class__.__name__}." + ) + elif name in params: + # Pass all, and only, parameters that are valid, + # regardless of whether they are None + new_kwargs[name] = value + # else, value is None and not in params, so we ignore it + return new_kwargs + + +def freeze_encoder_weights(encoder, freeze_prob: float): + for param in encoder.parameters(): + if random.random() < freeze_prob: + param.requires_grad = False diff --git a/src/cnlpt/models/__init__.py b/src/cnlpt/models/__init__.py deleted file mode 100644 index efeb47bd..00000000 --- a/src/cnlpt/models/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from transformers import AutoConfig, AutoModel - -from .cnlp import CnlpConfig, CnlpModelForClassification -from .hierarchical import HierarchicalModel - -__all__ = ["CnlpConfig", "CnlpModelForClassification", "HierarchicalModel"] - -AutoConfig.register("cnlpt", CnlpConfig) -AutoModel.register(CnlpConfig, (CnlpModelForClassification, HierarchicalModel)) diff --git a/src/cnlpt/models/baseline/__init__.py b/src/cnlpt/models/baseline/__init__.py deleted file mode 100644 index bc8279fa..00000000 --- a/src/cnlpt/models/baseline/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .cnn import CnnSentenceClassifier -from .lstm import LstmSentenceClassifier - -__all__ = ["CnnSentenceClassifier", "LstmSentenceClassifier"] diff --git a/src/cnlpt/models/baseline/cnn.py b/src/cnlpt/models/baseline/cnn.py deleted file mode 100644 index 36294c2c..00000000 --- a/src/cnlpt/models/baseline/cnn.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -import torch.nn.functional as F -from huggingface_hub import 
PyTorchModelHubMixin -from torch import nn - - -class CnnSentenceClassifier(nn.Module, PyTorchModelHubMixin): - def __init__( - self, - vocab_size, - task_names: list[str], - num_labels_dict: dict[str, int], - embed_dims=100, - num_filters=25, - dropout=0.2, - filters=(1, 2, 3), - use_prior_tasks=False, - class_weights=None, - ): - super().__init__() - self.dropout = dropout - - self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dims) - self.convs = nn.ModuleList( - [nn.Conv1d(embed_dims, num_filters, x) for x in filters] - ) - # need separate loss functions with different weights - if class_weights is not None: - if class_weights.ndim > 1: - self.loss_fn = { - task_name: nn.CrossEntropyLoss(weight=class_weights[i]) - for i, task_name in enumerate(task_names) - } - else: - self.loss_fn = { - task_name: nn.CrossEntropyLoss(weight=class_weights) - for task_name in task_names - } - else: - self.loss_fn = { - task_name: nn.CrossEntropyLoss() for task_name in task_names - } - self.fcs = nn.ModuleList() - - self.task_names = task_names - for task_name in self.task_names: - if task_name not in num_labels_dict: - raise ValueError("Misalignment between task_names and num_labels_dict") - self.fcs.append( - nn.Linear(num_filters * len(filters), num_labels_dict[task_name]) - ) - - self.use_prior_tasks = use_prior_tasks - if self.use_prior_tasks: - self.intertask_matrices = [] - for i in range(len(self.task_names)): - matrices = [] - for j in range(len(self.task_names) - i): - matrices.append(nn.Linear(2, num_filters * len(filters))) - self.intertask_matrices.append(matrices) - # self.intertask_matrix = nn.Linear(2, num_filters * len(filters)) - # put logits for task a through intertask_matrix[a][b] to get features to add to features of task b - - def forward( - self, - input_ids=None, - event_tokens=None, - labels=None, - output_hidden_states=False, - **kwargs, - ): - embeddings = self.embed(input_ids) - embeddings = embeddings.transpose(1, 2) - all_convs 
= [conv(embeddings) for conv in self.convs] - pooled_convs = [ - F.max_pool1d(conv_out, conv_out.shape[2]) for conv_out in all_convs - ] - - fc_in = torch.cat(pooled_convs, 1).squeeze(2) - - logits = [] - loss = 0 - for task_ind, task_fc in enumerate(self.fcs): - # get feaures from previous tasks using the world's tiniest linear layer - if self.use_prior_tasks: - for prev_task_ind in range(task_ind): - prev_task_matrix = self.intertask_matrices[prev_task_ind][ - task_ind - prev_task_ind - 1 - ] - prev_task_matrix = prev_task_matrix.to(logits[prev_task_ind].device) - prev_task_features = prev_task_matrix(logits[prev_task_ind]) - fc_in = fc_in + prev_task_features - task_logits = task_fc(fc_in) - logits.append(task_logits) - - if labels is not None: - if labels.ndim == 2: - # if len(self.fcs) == 1: - # task_labels = labels[:,0] - task_labels = labels[:, task_ind] - elif labels.ndim == 3: - task_labels = labels[:, 0, task_ind] - loss += self.loss_fn[self.task_names[task_ind]]( - task_logits, task_labels.type(torch.LongTensor).to(labels.device) - ) - if output_hidden_states: - return loss, logits, fc_in - else: - return loss, logits diff --git a/src/cnlpt/models/baseline/lstm.py b/src/cnlpt/models/baseline/lstm.py deleted file mode 100644 index a65bd829..00000000 --- a/src/cnlpt/models/baseline/lstm.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from huggingface_hub import PyTorchModelHubMixin -from torch import nn - - -class LstmSentenceClassifier(nn.Module, PyTorchModelHubMixin): - def __init__( - self, - vocab_size, - task_names: list[str], - num_labels_dict: dict[str, int], - embed_dims=100, - dropout=0.2, - hidden_size=100, - ): - super().__init__() - self.dropout = dropout - - self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dims) - self.lstm = nn.LSTM( - input_size=embed_dims, hidden_size=hidden_size, bidirectional=True - ) - self.loss_fn = nn.CrossEntropyLoss() - - self.fcs = nn.ModuleList() - - self.task_names = task_names - for 
task_name in self.task_names: - if task_name not in num_labels_dict: - raise ValueError("Misalignment between task_names and num_labels_dict") - self.fcs.append(nn.Linear(4 * hidden_size, num_labels_dict[task_name])) - - def forward( - self, - input_ids=None, - event_tokens=None, - labels=None, - **kwargs, - ): - embeddings = self.embed(input_ids) - # embeddings = embeddings.transpose(1,2) - lstm_out = self.lstm(embeddings)[0] - - logits = [] - loss = 0 - for task_ind, task_fc in enumerate(self.fcs): - features = torch.cat((lstm_out[:, 0, :], lstm_out[:, -1, :]), 1) - task_logits = task_fc(features) - logits.append(task_logits) - - if labels is not None: - if labels.ndim == 2: - task_labels = labels[:, 0] - elif labels.ndim == 3: - task_labels = labels[:, 0, task_ind] - loss += self.loss_fn( - task_logits, task_labels.type(torch.LongTensor).to(labels.device) - ) - return loss, logits diff --git a/src/cnlpt/models/cnlp.py b/src/cnlpt/models/cnlp.py deleted file mode 100644 index 69508897..00000000 --- a/src/cnlpt/models/cnlp.py +++ /dev/null @@ -1,649 +0,0 @@ -""" -Module containing the CNLP transformer model. -""" - -from __future__ import annotations - -import inspect -import logging -import math -import random -from os import PathLike -from typing import Any, Union - -import torch -from torch import nn -from torch.nn import CrossEntropyLoss, MSELoss -from transformers import AutoConfig, AutoModel -from transformers.configuration_utils import PretrainedConfig -from transformers.modeling_outputs import SequenceClassifierOutput -from transformers.modeling_utils import PreTrainedModel - -from .. import __version__ as cnlpt_version - -logger = logging.getLogger(__name__) - - -def generalize_encoder_forward_kwargs(encoder, **kwargs: Any) -> dict[str, Any]: - """Create a new input feature argument that preserves only the features that are valid for this encoder. - Warn if a feature is present but not valid for the encoder. 
- - Args: - encoder: A HF encoder model - - Returns: - Dictionary of valid arguments for this encoder - """ - new_kwargs = dict() - params = inspect.signature(encoder.forward).parameters - for name, value in kwargs.items(): - if name not in params and value is not None: - # Warn if a contentful parameter is not valid - logger.warning( - f"Parameter {name} not present for encoder class {encoder.__class__.__name__}." - ) - elif name in params: - # Pass all, and only, parameters that are valid, - # regardless of whether they are None - new_kwargs[name] = value - # else, value is None and not in params, so we ignore it - return new_kwargs - - -def freeze_encoder_weights(encoder, freeze: float): - """Probabilistically freeze the weights of this HF encoder model according to the freeze parameter. - Values of freeze >=1 are treated as if every parameter should be frozen. - - Args: - encoder: HF encoder model - freeze: Probability of freezing any given parameter (0-1) - """ - for param in encoder.parameters(): - if freeze >= 1.0: - param.requires_grad = False - else: - dart = random.random() - if dart < freeze: - param.requires_grad = False - - -class ClassificationHead(nn.Module): - """Generic classification head that can be used for any task.""" - - def __init__(self, config, num_labels, hidden_size=-1): - super().__init__() - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.out_proj = nn.Linear( - config.hidden_size if hidden_size < 0 else hidden_size, num_labels - ) - - def forward(self, features, *kwargs): - x = self.dropout(features) - x = self.out_proj(x) - return x - - -class RepresentationProjectionLayer(nn.Module): - """The class that maps from some output from a text encoder into a feature representation that can be classified. 
- Project the representation to a new space depending on the task type, based on arguments passed in to the constructor.""" - - def __init__( - self, - config: CnlpConfig, - layer: int = 10, - tokens: bool = False, - tagger: bool = False, - relations: bool = False, - num_attention_heads: int = -1, - head_size: int = 64, - ): - """ - Args: - config: The config file for the encoder - layer: Which layer to pull the encoder representation from - tokens: Whether to classify an entity based on the token reprsentation rather than the CLS representation - tagger: Whether the current task is a token tagging task - relations: Whether the current task is relation exttraction - num_attention_heads: For relations, how many "features" to use - head_size: For relations, how big each head should be - """ - super().__init__() - self.dropout = nn.Dropout(config.hidden_dropout_prob) - if relations: - self.dense = nn.Identity() - else: - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - - self.layer_to_use = layer - self.tokens = tokens - self.tagger = tagger - self.relations = relations - self.hidden_size = config.hidden_size - - if num_attention_heads <= 0 and relations: - raise Exception( - "Inconsistent configuration: num_attention_heads must be > 0 for relations" - ) - - if relations: - self.num_attention_heads = num_attention_heads - self.attention_head_size = head_size - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - self.key = nn.Linear(config.hidden_size, self.all_head_size) - - if tokens and (tagger or relations): - raise Exception( - "Inconsistent configuration: tokens cannot be true in tagger or relation mode" - ) - - def transpose_for_scores(self, x): - # x: (batch x seq x all_head) - - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - # (batch x seq x n_heads x head_size) - x = x.view(*new_x_shape) - - # (batch x 
n_heads x seq x head_size) - return x.permute(0, 2, 1, 3) - - def forward(self, features, event_tokens: torch.Tensor, **kwargs): - # features: (layers x batch x seq x hidden) - # event_tokens: (batch x seq) - - seq_length = features[0].shape[1] - if self.tokens: - # grab the average over the tokens of the thing we want to classify - # probably involved passing in some sub-sequence of interest so we know what tokens to grab, - # then we average across those tokens. - - # (batch) - event_toks_per_seq = event_tokens.sum(1) - - # (batch x seq x hidden) - expanded_tokens = event_tokens.unsqueeze(2).expand( - features[0].shape[0], seq_length, self.hidden_size - ) - - # (batch x seq x hidden) - filtered_features = features[self.layer_to_use] * expanded_tokens - - # (batch x hidden) - x = filtered_features.sum(1) / event_toks_per_seq.unsqueeze(1).expand( - features[0].shape[0], self.hidden_size - ) - elif self.tagger: - # (batch x seq x hidden) - x = features[self.layer_to_use] - elif self.relations: - # something like multi-headed attention but without the weighted sum at the end, so i get (num_heads) features for each of N x N grid, which feads into NxN softmax (with the same parameters) - # (batch x seq x hidden) - hidden_states = features[self.layer_to_use] - - # (batch x n_heads x seq x head_size) - key_layer = self.transpose_for_scores(self.key(hidden_states)) - # (batch x n_heads x seq x head_size) - query_layer = self.transpose_for_scores(self.query(hidden_states)) - - # (batch x n_heads x seq x seq) - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - # Now we have num_heads features for each N X N relations. - x = attention_scores / math.sqrt(self.attention_head_size) - # move the 12 dimension to the end for easier classification - - # (batch x seq x seq x n_heads) - x = x.permute(0, 2, 3, 1) - - else: - # take token (equiv. 
to [CLS]) - # (batch x hidden) - x = features[self.layer_to_use][..., 0, :] - - x = self.dropout(x) - x = self.dense(x) - x = torch.tanh(x) - - # for classification (including event tokens mode): (batch x hidden) - # for tagging: (batch x seq x hidden) - # for relations: (batch x seq x seq x n_heads) - return x - - -class CnlpConfig(PretrainedConfig): - """The config class for :class:`CnlpModelForClassification`.""" - - model_type = "cnlpt" - - def __init__( - self, - *, - encoder_name: Union[str, PathLike] = "roberta-base", - finetuning_task: Union[list[str], None] = None, - layer: int = -1, - tokens: bool = False, - num_rel_attention_heads: int = 12, - rel_attention_head_dims: int = 64, - tagger: dict[str, bool] = {}, - relations: dict[str, bool] = {}, - use_prior_tasks: bool = False, - hier_head_config: Union[dict[str, Any], None] = None, - label_dictionary: Union[dict[str, list[str]], None] = None, - character_level: bool = False, - **kwargs, - ): - """Create a new CnlpConfig object. - - Args: - encoder_name: The encoder name to use with :meth:`transformers.AutoConfig.from_pretrained`. Defaults to "roberta-base". - finetuning_task: The tasks for which this model is fine-tuned. Defaults to None. - layer: The index of the encoder layer to extract features from. Defaults to -1. - tokens: If true, sentence-level classification is done based on averaged token embeddings for token(s) surrounded by special tokens. Defaults to False. - num_rel_attention_heads: The number of features/attention heads to use in the NxN relation classifier. Defaults to 12. - rel_attention_head_dims: The number of parameters in each attention head in the NxN relation classifier. Defaults to 64. - tagger: For each task, whether the task is a sequence tagging task. Defaults to {}. - relations: For each task, whether the task is a relation extraction task. Defaults to {}. - use_prior_tasks: Whether to use the outputs from the previous tasks as additional inputs for subsequent tasks. 
Defaults to False. - hier_head_config: If this is a hierarchical model, this is where the config parameters go. Defaults to None. - label_dictionary: A mapping from task names to label sets. Defaults to None. - character_level: Whether the encoder is character level. Defaults to False. - """ - super().__init__(**kwargs) - # self.name_or_path='cnlpt' - self.finetuning_task = finetuning_task - self.layer = layer - self.tokens = tokens - self.num_rel_attention_heads = num_rel_attention_heads - self.rel_attention_head_dims = rel_attention_head_dims - self.tagger = tagger - self.relations = relations - self.use_prior_tasks = use_prior_tasks - self.encoder_name = encoder_name - self.encoder_config = AutoConfig.from_pretrained(encoder_name).to_dict() - self.hier_head_config = hier_head_config - self.label_dictionary = label_dictionary - self.cnlpt_version = cnlpt_version - self.character_level = character_level - if encoder_name.startswith("distilbert"): - self.hidden_dropout_prob = self.encoder_config["dropout"] - self.hidden_size = self.encoder_config["dim"] - elif self.encoder_config["model_type"] == "modernbert": - self.hidden_size = self.encoder_config["hidden_size"] - # downstream uses hidden dropout prob for additional layers, modernbert splits into different dropouts for different - # parts of the encoder -- mlp dropout is probably generally good - self.hidden_dropout_prob = self.encoder_config["mlp_dropout"] - # don't need these in my code but keep them around just in case - self.attention_dropout = self.encoder_config["attention_dropout"] - self.embedding_dropout = self.encoder_config["embedding_dropout"] - self.mlp_dropout = self.encoder_config["mlp_dropout"] - self.classifier_dropout = self.encoder_config["classifier_dropout"] - else: - try: - self.hidden_dropout_prob = self.encoder_config["hidden_dropout_prob"] - self.hidden_size = self.encoder_config["hidden_size"] - except KeyError as ke: - raise ValueError( - f"Encoder config does not have an attribute" - 
f' "{ke.args[0]}"; this is likely because the API of' - f" the chosen encoder differs from the BERT/RoBERTa" - f" API and the DistilBERT API. Encoders with different" - f" APIs are not yet supported (#35)." - ) - - -class CnlpModelForClassification(PreTrainedModel): - """The CNLP transformer model.""" - - base_model_prefix = "cnlpt" - config_class = CnlpConfig - - def __init__( - self, - config: config_class, - *, - class_weights: Union[dict[str, float], None] = None, - final_task_weight: float = 1.0, - freeze: float = -1.0, - bias_fit: bool = False, - ): - """Create a new CNLP transformer model instance from a config object. - - Args: - config: The CnlpConfig object that configures this model - class_weights: If provided, the weights to use for each task when computing the loss. Defaults to None. - final_task_weight: The weight to use for the final task when computing the loss. Defaults to 1.0. - freeze: What proportion of encoder weights to freeze (-1 for none). Defaults to -1.0. - bias_fit: Whether to fine-tune only the bias of the encoder. Defaults to False. - """ - super().__init__(config) - - encoder_config = AutoConfig.from_pretrained(config.encoder_name) - encoder_config.vocab_size = config.vocab_size - config.encoder_config = encoder_config.to_dict() - encoder_model = AutoModel.from_config(encoder_config) - - self.encoder = encoder_model.from_pretrained(config.encoder_name) - - # part of the motivation for leaving this - # logic alone for character level models is that - # at the time of writing, CANINE and Flair are the only game in town. - # CANINE's hashable embeddings for unicode codepoints allows for - # additional parameterization, which rn doesn't seem so relevant - if not config.character_level: - self.encoder.resize_token_embeddings( - encoder_config.vocab_size, mean_resizing=False - ) - # This would seem to be redundant with the label list, which maps from tasks to labels, - # but this version is ordered. 
This will allow the user to specify an order for any methods - # where we feed the output of one task into the next. - # It also will be used as the canonical order of returning results/logits - self.tasks = config.finetuning_task - - if config.layer > self.num_layers: - raise ValueError( - f"The layer specified ({config.layer}) is too big for the specified encoder which has {self.num_layers} layers" - ) - - if freeze > 0: - freeze_encoder_weights(self.encoder, freeze) - - if bias_fit: - for name, param in self.encoder.named_parameters(): - if "bias" not in name: - param.requires_grad = False - - self.feature_extractors = nn.ModuleDict() - self.classifiers = nn.ModuleDict() - - total_prev_task_labels = 0 - for task_name, task_labels in config.label_dictionary.items(): - task_num_labels = len(task_labels) - self.feature_extractors[task_name] = RepresentationProjectionLayer( - config, - layer=config.layer, - tokens=config.tokens, - tagger=config.tagger[task_name], - relations=config.relations[task_name], - num_attention_heads=config.num_rel_attention_heads, - head_size=config.rel_attention_head_dims, - ) - if config.relations[task_name]: - hidden_size = config.num_rel_attention_heads - if config.use_prior_tasks: - hidden_size += total_prev_task_labels - - self.classifiers[task_name] = ClassificationHead( - config, task_num_labels, hidden_size=hidden_size - ) - else: - self.classifiers[task_name] = ClassificationHead( - config, task_num_labels - ) - total_prev_task_labels += task_num_labels - - # Are we operating as a sequence classifier (1 label per input sequence) or a tagger (1 label per input token in the sequence) - self.tagger = config.tagger - self.relations = config.relations - - if class_weights is None: - self.class_weights = {x: None for x in config.label_dictionary.keys()} - else: - self.class_weights = class_weights - - self.label_dictionary = config.label_dictionary - self.final_task_weight = final_task_weight - self.use_prior_tasks = 
config.use_prior_tasks - self.reg_temperature = 1.0 - - # self.init_weights() - - @property - def num_layers(self): - if self.encoder.config.model_type == "modernbert": - return len(self.encoder.base_model.layers) - else: - return len(self.encoder.encoder.layer) - - def predict_relations_with_previous_logits( - self, features: torch.Tensor, logits: list[torch.Tensor] - ) -> torch.Tensor: - """For the relation prediction task, use previous predictions of the tagging task as additional features in the - representation used for making the relation prediction. - - Args: - features: The existing feature vector for the relations - logits: The predicted logits from the tagging task - - Returns: - The augmented feature tensor - """ - - # features is (batch x seq x seq x n_heads) - seq_len = features.shape[1] - for prior_task_logits in logits: - if len(features.shape) == 4: - if len(prior_task_logits.shape) == 3: - # prior task is sequence tagging: - # we have batch x seq x num_classes. - # we want to concatenate the num_classes to the variables at each element of the sequence, - # but then need to broadcast it down all the rows of the matrix. 
- aug = prior_task_logits.unsqueeze( - 2 - ) # add another dimension to repeat along - aug = aug.repeat( - 1, 1, seq_len, 1 - ) # repeat along the new empty dimension so we have our seq logits repeated seq_len x seq_len - features = torch.cat( - (features, aug), 3 - ) # concatenate the relation matrix with the sequence matrix - else: - logger.warning( - f"It is not implemented to add a task of shape {prior_task_logits.shape!s} to a relation matrix" - ) - elif len(features.shape) == 3: - # sequence - logger.warning( - "It is not implemented to add previous task of any type to a sequence task" - ) - - return features - - def compute_loss( - self, - task_logits: torch.FloatTensor, - labels: torch.LongTensor, - task_ind: int, - task_num_labels: int, - batch_size: int, - seq_len: int, - state: dict, - ) -> None: - """ - Computes the loss for a single batch and a single task. - - Args: - task_logits: - labels: - task_ind: - task_num_labels: - batch_size: - seq_len: - state: - :meta private: - """ - task_name = self.tasks[task_ind] - if task_num_labels == 1: - # We are doing regression - loss_fct = MSELoss() - task_loss = loss_fct(task_logits.view(-1), labels.view(-1)) - else: - if self.class_weights[task_name] is not None: - class_weights = torch.FloatTensor(self.class_weights[task_name]).to( - self.device - ) - else: - class_weights = None - loss_fct = CrossEntropyLoss(weight=class_weights) - - if self.relations[task_name]: - task_labels = labels[ - :, :, state["task_label_ind"] : state["task_label_ind"] + seq_len - ] - state["task_label_ind"] += seq_len - task_loss = loss_fct( - task_logits.permute(0, 3, 1, 2), - task_labels.type(torch.LongTensor).to(labels.device), - ) - elif self.tagger[task_name]: - # in cases where we are only given a single task the HF code will have one fewer dimension in the labels, so just add a dummy dimension to make our indexing work: - if labels.ndim == 2: - task_labels = labels - elif labels.ndim == 3: - # labels = labels.unsqueeze(1) - 
task_labels = labels[:, :, state["task_label_ind"]] - else: - task_labels = labels[:, 0, state["task_label_ind"], :] - - state["task_label_ind"] += 1 - task_loss = loss_fct( - task_logits.view(-1, task_num_labels), - task_labels.reshape( - [ - batch_size * seq_len, - ] - ) - .type(torch.LongTensor) - .to(labels.device), - ) - else: - if labels.ndim == 1: - task_labels = labels - elif labels.ndim == 2: - task_labels = labels[:, task_ind] - elif labels.ndim == 3: - task_labels = labels[:, 0, task_ind] - else: - raise NotImplementedError( - "Have not implemented the case where a classification task " - "is part of an MTL setup with relations and sequence tagging" - ) - - state["task_label_ind"] += 1 - task_loss = loss_fct( - task_logits, task_labels.type(torch.LongTensor).to(labels.device) - ) - - if state["loss"] is None: - state["loss"] = task_loss - else: - task_weight = ( - 1.0 if task_ind + 1 < len(self.tasks) else self.final_task_weight - ) - state["loss"] += task_weight * task_loss - - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - event_tokens=None, - ): - r"""Forward method. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): - A batch of chunked documents as tokenizer indices. - attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): - Attention masks for the batch. - token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): - Token type IDs for the batch. - position_ids: (`torch.LongTensor` of shape `(batch_size, sequence_len)`, *optional*): - Position IDs for the batch. - head_mask (`torch.LongTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Token encoder head mask. 
- inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_len, hidden_size)`, *optional*): - A batch of chunked documents as token embeddings. - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. - Indices should be in :obj:`[0, ..., config.num_labels - 1]`. - If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss), - If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). - output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. - output_hidden_states: not used. - event_tokens: a mask defining which tokens in the input are to be averaged for input to classifier head; only used when self.tokens==True. - - Returns: (`transformers.SequenceClassifierOutput`) the output of the model - """ - - kwargs = generalize_encoder_forward_kwargs( - self.encoder, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=True, - ) - - outputs = self.encoder(input_ids, **kwargs) - - batch_size, seq_len = input_ids.shape - - logits = [] - - state = dict(loss=None, task_label_ind=0) - - for task_ind, task_name in enumerate(self.tasks): - task_labels = self.label_dictionary[task_name] - # hidden_states has shape (layers x batch x seq x hidden) - - # features shape: - # for classification (including event tokens mode): (batch x hidden) - # for tagging: (batch x seq x hidden) - # for relations: (batch x seq x seq x n_heads) - features = self.feature_extractors[task_name]( - outputs.hidden_states, event_tokens - ) - if self.use_prior_tasks: - # note: this specific way of incorporating previous logits doesn't help in my experiments with thyme/clinical tempeval - if self.relations[task_name]: - features = 
self.predict_relations_with_previous_logits( - features, logits - ) - task_logits = self.classifiers[task_name](features) - logits.append(task_logits) - - if labels is not None: - self.compute_loss( - task_logits, - labels, - task_ind, - len(task_labels), - batch_size, - seq_len, - state, - ) - - if self.training: - return SequenceClassifierOutput( - loss=state["loss"], - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - else: - return SequenceClassifierOutput( - loss=state["loss"], logits=logits, attentions=outputs.attentions - ) diff --git a/src/cnlpt/models/hierarchical.py b/src/cnlpt/models/hierarchical.py deleted file mode 100644 index 93d4bd8a..00000000 --- a/src/cnlpt/models/hierarchical.py +++ /dev/null @@ -1,461 +0,0 @@ -""" -Module containing the Hierarchical Transformer module, adapted from Xin Su. -""" - -import copy -import logging -from dataclasses import dataclass -from typing import Union, cast - -import torch -import torch.nn.functional as F -from torch import nn -from torch.nn import CrossEntropyLoss -from transformers import AutoConfig, AutoModel -from transformers.modeling_outputs import SequenceClassifierOutput -from transformers.modeling_utils import PreTrainedModel - -from .cnlp import ( - ClassificationHead, - CnlpConfig, - freeze_encoder_weights, - generalize_encoder_forward_kwargs, -) - -logger = logging.getLogger(__name__) - - -@dataclass -class HierarchicalSequenceClassifierOutput(SequenceClassifierOutput): - chunk_attentions: Union[tuple[torch.FloatTensor], None] = None - - -class MultiHeadAttention(nn.Module): - """ - Multi-Head Attention module - - Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch) - - Args: - n_head: the number of attention heads - d_model: the dimensionality of the input and output of the encoder - d_k: the size of the query and key vectors - d_v: the size of the value vector - """ - - def __init__(self, n_head, d_model, 
d_k, d_v, dropout=0.1): - super().__init__() - - self.n_head = n_head - self.d_k = d_k - self.d_v = d_v - - self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) - self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) - self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) - self.fc = nn.Linear(n_head * d_v, d_model, bias=False) - - self.attention = ScaledDotProductAttention(temperature=d_k**0.5) - - self.dropout = nn.Dropout(dropout) - self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) - - def forward(self, q, k, v, mask=None): - d_k, d_v, n_head = self.d_k, self.d_v, self.n_head - sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) - - residual = q - q = self.layer_norm(q) - - # Pass through the pre-attention projection: b x lq x (n*dv) - # Separate different heads: b x lq x n x dv - q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) - k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) - v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) - - # Transpose for attention dot product: b x n x lq x dv - q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) - - if mask is not None: - mask = mask.unsqueeze(1) # For head axis broadcasting. 
- - output, attn = self.attention(q, k, v, mask=mask) - - # Transpose to move the head dimension back: b x lq x n x dv - # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) - output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) - output = self.dropout(self.fc(output)) - output += residual - - return output, attn - - -class PositionwiseFeedForward(nn.Module): - """ - A two-feed-forward-layer module - - Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch) - - Args: - d_in: the dimensionality of the input and output of the encoder - d_hid: the inner hidden size of the positionwise FFN in the encoder - dropout: the amount of dropout to use in training (default 0.1) - """ - - def __init__(self, d_in, d_hid, dropout=0.1): - super().__init__() - self.w_1 = nn.Linear(d_in, d_hid) # position-wise - self.w_2 = nn.Linear(d_hid, d_in) # position-wise - self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - residual = x - x = self.layer_norm(x) - - output = self.w_2(F.relu(self.w_1(x))) - output = self.dropout(output) - output += residual - - return output - - -class ScaledDotProductAttention(nn.Module): - """ - Scaled Dot-Product Attention - - Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch) - - Args: - temperature: the temperature for scaled dot product attention - attn_dropout: the amount of dropout to use in training - for scaled dot product attention (default 0.1, not - tuned in the rest of the code) - """ - - def __init__(self, temperature, attn_dropout=0.1): - super().__init__() - self.temperature = temperature - self.dropout = nn.Dropout(attn_dropout) - - def forward(self, q, k, v, mask=None): - attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) - - if mask is not None: - attn = attn.masked_fill(mask == 0, -1e9) - - attn = self.dropout(F.softmax(attn, 
dim=-1)) - output = torch.matmul(attn, v) - - return output, attn - - -class EncoderLayer(nn.Module): - """ - Compose with two layers - - Original author: Yu-Hsiang Huang (https://github.com/jadore801120/attention-is-all-you-need-pytorch) - - Args: - d_model: the dimensionality of the input and output of the encoder - d_inner: the inner hidden size of the positionwise FFN in the encoder - n_head: the number of attention heads - d_k: the size of the query and key vectors - d_v: the size of the value vector - dropout: the amount of dropout to use in training in both the - attention and FFN steps (default 0.1) - """ - - def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): - super().__init__() - self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) - self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) - - def forward(self, enc_input, slf_attn_mask=None): - enc_output, enc_slf_attn = self.slf_attn( - enc_input, enc_input, enc_input, mask=slf_attn_mask - ) - enc_output = self.pos_ffn(enc_output) - return enc_output, enc_slf_attn - - -class HierarchicalModel(PreTrainedModel): - """ - Hierarchical Transformer model (https://arxiv.org/abs/2105.06752) - - Adapted from Xin Su's implementation (https://github.com/xinsu626/DocTransformer) - - Args: - config: - transformer_head_config: - class_weights: - final_task_weight: - freeze: - """ - - base_model_prefix = "hier" - config_class = CnlpConfig - - def __init__( - self, - config: config_class, - *, - freeze: float = -1.0, - class_weights: Union[list[float], None] = None, - ): - # Initialize common components - super().__init__( - config, - ) - - self.config = cast(CnlpConfig, self.config) # for PyCharm - - assert self.config.hier_head_config is not None, ( - "Hierarchical model is being instantiated with no hierarchical head config" - ) - - encoder_config = AutoConfig.from_pretrained(self.config.encoder_name) - encoder_config.vocab_size = self.config.vocab_size - 
self.config.encoder_config = encoder_config.to_dict() - encoder_model = AutoModel.from_config(encoder_config) - self.encoder = encoder_model.from_pretrained(self.config.encoder_name) - self.encoder.resize_token_embeddings(encoder_config.vocab_size) - if self.config.layer > self.config.hier_head_config["n_layers"]: - raise ValueError( - f"The layer specified ({self.config.layer}) is too big for the specified chunk transformer which has {self.config.hier_head_config['n_layers']} layers" - ) - - if self.config.layer < 0: - self.layer = ( - self.config.hier_head_config["n_layers"] + self.config.layer + 1 - ) - if self.layer < 0: - raise ValueError( - f"The layer specified ({self.config.layer}) is a negative value which is larger than the actual number of layers {self.config.hier_head_config['n_layers']}" - ) - else: - self.layer = self.config.layer - - if self.layer == 0: - raise ValueError( - "The classifier layer derived is 0 which is ambiguous -- there is no usable 0th layer in a hierarchical model. Enter a value for the layer argument that at least 1 (use one layer) or -1 (use the final layer)" - ) - - # This would seem to be redundant with the label list, which maps from tasks to labels, - # but this version is ordered. This will allow the user to specify an order for any methods - # where we feed the output of one task into the next. 
- # It also will be used as the canonical order of returning results/logits - self.tasks = config.finetuning_task - - if freeze > 0: - freeze_encoder_weights(self.encoder, freeze) - - # Document-level transformer layer - transformer_layer = EncoderLayer( - d_model=self.config.hidden_size, - d_inner=self.config.hier_head_config["d_inner"], - n_head=self.config.hier_head_config["n_head"], - d_k=self.config.hier_head_config["d_k"], - d_v=self.config.hier_head_config["d_v"], - dropout=self.config.hier_head_config["dropout"], - ) - self.transformer = nn.ModuleList( - [ - copy.deepcopy(transformer_layer) - for _ in range(self.config.hier_head_config["n_layers"]) - ] - ) - - self.classifiers = nn.ModuleDict() - # for task_num_labels in self.num_labels: - for task_name, task_labels in config.label_dictionary.items(): - task_num_labels = len(task_labels) - self.classifiers[task_name] = ClassificationHead( - self.config, task_num_labels - ) - - self.label_dictionary = config.label_dictionary - self.set_class_weights(class_weights) - - def remove_task_classifiers(self, tasks: Union[list[str], None] = None): - if tasks is None: - self.classifiers = nn.ModuleDict() - self.tasks = [] - self.class_weights = {} - else: - for task in tasks: - self.classifiers.pop(task) - self.tasks.remove(task) - self.class_weights.pop(task) - - def add_task_classifier(self, task_name: str, task_labels: list[str]): - self.tasks.append(task_name) - self.classifiers[task_name] = ClassificationHead(self.config, len(task_labels)) - self.label_dictionary[task_name] = task_labels - - def set_class_weights(self, class_weights: Union[list[float], None] = None): - if class_weights is None: - self.class_weights = {x: None for x in self.label_dictionary.keys()} - else: - self.class_weights = class_weights - - def forward( - self, - input_ids=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - output_attentions=None, - 
output_hidden_states=False, - event_tokens=None, - ): - """ - Forward method. - - Args: - input_ids (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): - A batch of chunked documents as tokenizer indices. - attention_mask (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): - Attention masks for the batch. - token_type_ids (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): - Token type IDs for the batch. - position_ids: (`torch.LongTensor` of shape `(batch_size, num_chunks, chunk_len)`, *optional*): - Position IDs for the batch. - head_mask (`torch.LongTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): - Token encoder head mask. - inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_chunks, chunk_len, hidden_size)`, *optional*): - A batch of chunked documents as token embeddings. - labels (`torch.LongTensor` of shape `(batch_size, num_tasks)`, *optional*): - Labels for computing the sequence classification/regression loss. - Indices should be in `[0, ..., self.num_labels[task_ind] - 1]`. - If `self.num_labels[task_ind] == 1` a regression loss is computed (Mean-Square loss), - If `self.num_labels[task_ind] > 1` a classification loss is computed (Cross-Entropy). - output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. - output_hidden_states: If True, return a matrix of shape (batch_size, num_chunks, hidden size) representing the contextualized embeddings of each chunk. The 0-th element of each chunk is the classifier representation for that instance. 
- event_tokens: not currently used (only relevant for token classification) - - Returns: - - """ - if input_ids is not None: - batch_size, num_chunks, chunk_len = input_ids.shape - flat_shape = (batch_size * num_chunks, chunk_len) - else: # inputs_embeds is not None - batch_size, num_chunks, chunk_len, embed_dim = inputs_embeds.shape - flat_shape = (batch_size * num_chunks, chunk_len, embed_dim) - - kwargs = generalize_encoder_forward_kwargs( - self.encoder, - attention_mask=( - attention_mask.reshape(flat_shape[:3]) - if attention_mask is not None - else None - ), - token_type_ids=( - token_type_ids.reshape(flat_shape[:3]) - if token_type_ids is not None - else None - ), - position_ids=( - position_ids.reshape(flat_shape[:3]) - if position_ids is not None - else None - ), - head_mask=head_mask, - inputs_embeds=( - inputs_embeds.reshape(flat_shape) if inputs_embeds is not None else None - ), - output_attentions=output_attentions, - output_hidden_states=True, - return_dict=True, - ) - - outputs = self.encoder( - input_ids.reshape(flat_shape[:3]) if input_ids is not None else None, - **kwargs, - ) - - logits = [] - hidden_states = None - - # outputs.last_hidden_state.shape: (B * n_chunks, chunk_len, hidden_size) - # (B * n_chunk, hidden_size) - chunks_reps = outputs.last_hidden_state[..., 0, :].reshape( - batch_size, num_chunks, outputs.last_hidden_state.shape[-1] - ) - - # Use pre-trained model's position embedding - position_ids = torch.arange( - num_chunks, dtype=torch.long, device=chunks_reps.device - ) # (n_chunk) - position_ids = position_ids.unsqueeze(0).expand_as( - chunks_reps[:, :, 0] - ) # (B, n_chunk) - position_embeddings = self.encoder.embeddings.position_embeddings(position_ids) - chunks_reps = chunks_reps + position_embeddings - chunks_attns = None - - # document encoding (B, n_chunk, hidden_size) - for layer_ind, layer_module in enumerate(self.transformer): - chunks_reps, chunks_attn = layer_module(chunks_reps) - if output_attentions: - if 
chunks_attns is None: - chunks_attns = [] - chunks_attns.append(chunks_attn) - - ## this case is mainly for when we are doing subsequent fine-tuning using a pre-trained - ## hierarchical model and we want to check whether an earlier layer might provide better - ## classification performance (e.g., if we think the last layer(s) are overfit to the pre-training - ## objective) Just short circuit rather than doing the whole computation. - if layer_ind + 1 >= self.layer: - break - - if output_hidden_states: - hidden_states = chunks_reps - - # extract first Documents as rep. (B, hidden_size) - doc_rep = chunks_reps[:, 0, :] - - total_loss = None - for task_ind, task_name in enumerate(self.tasks): - if self.class_weights[task_name] is not None: - class_weights = torch.FloatTensor(self.class_weights[task_name]).to( - self.device - ) - else: - class_weights = None - loss_fct = CrossEntropyLoss(weight=class_weights) - - # predict (B, 5) - task_logits = self.classifiers[task_name](doc_rep) - logits.append(task_logits) - - if labels is not None: - task_labels = labels[:, task_ind] - task_loss = loss_fct( - task_logits, task_labels.type(torch.LongTensor).to(labels.device) - ) - if total_loss is None: - total_loss = task_loss - else: - total_loss += task_loss - - if self.training: - return HierarchicalSequenceClassifierOutput( - loss=total_loss, - logits=logits, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - chunk_attentions=chunks_attns, - ) - else: - return HierarchicalSequenceClassifierOutput( - loss=total_loss, - logits=logits, - hidden_states=hidden_states, - attentions=outputs.attentions, - chunk_attentions=chunks_attns, - ) From ebc56f0820de05d990cf9519a3663424704aaf50 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:12:36 -0400 Subject: [PATCH 2/8] rework args --- src/cnlpt/args/__init__.py | 23 - src/cnlpt/args/data_args.py | 107 ---- src/cnlpt/args/log.py | 3 - src/cnlpt/args/model_args.py | 204 ------- 
src/cnlpt/args/parse_args.py | 94 ---- src/cnlpt/data/__init__.py | 3 +- src/cnlpt/data/cnlp_dataset.py | 144 +++-- src/cnlpt/data/preprocess.py | 34 +- src/cnlpt/data/task_info.py | 1 + src/cnlpt/train_system/__init__.py | 3 +- .../training_args.py => train_system/args.py} | 65 ++- src/cnlpt/train_system/cnlp_train_system.py | 527 +++--------------- src/cnlpt/train_system/display.py | 31 +- src/cnlpt/train_system/log.py | 2 +- src/cnlpt/train_system/training_callbacks.py | 28 +- src/cnlpt/train_system/utils.py | 2 +- 16 files changed, 280 insertions(+), 991 deletions(-) delete mode 100644 src/cnlpt/args/__init__.py delete mode 100644 src/cnlpt/args/data_args.py delete mode 100644 src/cnlpt/args/log.py delete mode 100644 src/cnlpt/args/model_args.py delete mode 100644 src/cnlpt/args/parse_args.py rename src/cnlpt/{args/training_args.py => train_system/args.py} (50%) diff --git a/src/cnlpt/args/__init__.py b/src/cnlpt/args/__init__.py deleted file mode 100644 index d0d0bf64..00000000 --- a/src/cnlpt/args/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Module containing the CNLP command line argument definitions -""" - -from .data_args import CnlpDataArguments -from .model_args import CnlpModelArguments -from .parse_args import ( - parse_args_dict, - parse_args_from_argv, - parse_args_json_file, - preprocess_args, -) -from .training_args import CnlpTrainingArguments - -__all__ = [ - "CnlpDataArguments", - "CnlpModelArguments", - "CnlpTrainingArguments", - "parse_args_dict", - "parse_args_from_argv", - "parse_args_json_file", - "preprocess_args", -] diff --git a/src/cnlpt/args/data_args.py b/src/cnlpt/args/data_args.py deleted file mode 100644 index 5ca62413..00000000 --- a/src/cnlpt/args/data_args.py +++ /dev/null @@ -1,107 +0,0 @@ -from dataclasses import dataclass, field -from typing import Union - - -@dataclass -class CnlpDataArguments: - """ - Arguments pertaining to what data we are going to input our model for training and eval. 
- - Using :class:`transformers.HfArgumentParser` we can turn this class - into argparse arguments to be able to specify them on - the command line. - """ - - data_dir: list[str] = field( - metadata={ - "help": "The input data dirs. A space-separated list of directories that " - "should contain the .tsv files (or other data files) for the task. " - "Should be presented in the same order as the task names." - } - ) - - task_name: Union[list[str], None] = field( - default_factory=lambda: None, - metadata={ - "help": "A space-separated list of tasks to train on (mainly used as keys to internally track and display output)" - }, - ) - # field( - # metadata={"help": "A space-separated list of tasks to train on: " + ", ".join(cnlp_processors.keys())}) - - max_seq_length: int = field( - default=128, - metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated, sequences shorter will be padded." - }, - ) - overwrite_cache: bool = field( - default=False, - metadata={"help": "Overwrite the cached training and evaluation sets"}, - ) - - weight_classes: bool = field( - default=False, - metadata={ - "help": "A flag that indicates whether class-specific loss should be used. " - "This can be useful in cases with severe class imbalance. The formula " - "for a weight of a class is the count of that class divided the count " - "of the rarest class." 
- }, - ) - - chunk_len: Union[int, None] = field( - default=None, metadata={"help": "Chunk length for hierarchical model"} - ) - character_level: bool = field( - default=False, - metadata={ - "help": "Whether the dataset sould be processed at the character level" - "(otherwise will be processed at the token level)" - }, - ) - - num_chunks: Union[int, None] = field( - default=None, metadata={"help": "Max chunk count for hierarchical model"} - ) - - insert_empty_chunk_at_beginning: bool = field( - default=False, - metadata={"help": "Whether to insert an empty chunk for hierarchical model"}, - ) - - truncate_examples: bool = field( - default=False, - metadata={ - "help": "Whether to truncate input examples when displaying them in the log" - }, - ) - - max_train_items: Union[int, None] = field( - default=-1, - metadata={ - "help": "Set a number of train instances to use during training (useful for debugging data processing logic if a dataset is very large. Default is to train on all training data." - }, - ) - - max_eval_items: Union[int, None] = field( - default=-1, - metadata={ - "help": "Set a number of validation instances to use during training (useful if a dataset has been created using dumb logic like 80/10/10 and 10%% takes forever to evaluate on. Default is evaluate on all validation data." - }, - ) - - max_test_items: Union[int, None] = field( - default=-1, - metadata={ - "help": "Set a number of test instances to use during prediction (useful for debugging)" - }, - ) - - allow_disjoint_labels: bool = field( - default=False, - metadata={ - "help": "Allow tasks to have disjoint label sets in different data splits (useful for testing)." 
- }, - ) diff --git a/src/cnlpt/args/log.py b/src/cnlpt/args/log.py deleted file mode 100644 index e9140742..00000000 --- a/src/cnlpt/args/log.py +++ /dev/null @@ -1,3 +0,0 @@ -import logging - -logger = logging.getLogger("cnlpt.data") diff --git a/src/cnlpt/args/model_args.py b/src/cnlpt/args/model_args.py deleted file mode 100644 index b57cc728..00000000 --- a/src/cnlpt/args/model_args.py +++ /dev/null @@ -1,204 +0,0 @@ -from dataclasses import dataclass, field, fields -from enum import Enum -from typing import Union - -cnlpt_models = ["cnn", "lstm", "hier", "cnlpt"] - - -@dataclass -class CnlpModelArguments: - """ - Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. - See all possible arguments by passing the ``--help`` flag to this script. - """ - - model: Union[str, None] = field( - default="cnlpt", metadata={"help": "Model type", "choices": cnlpt_models} - ) - encoder_name: Union[str, None] = field( - default="roberta-base", - metadata={ - "help": "Path to pretrained model or model identifier from huggingface.co/models" - }, - ) - config_name: Union[str, None] = field( - default=None, - metadata={ - "help": "Pretrained config name or path if not the same as model_name" - }, - ) - tokenizer_name: Union[str, None] = field( - default=None, - metadata={ - "help": "Pretrained tokenizer name or path if not the same as model_name" - }, - ) - cache_dir: Union[str, None] = field( - default=None, - metadata={ - "help": "Where do you want to store the pretrained models downloaded from s3" - }, - ) - layer: int = field( - default=-1, metadata={"help": "Which layer's CLS ('') token to use"} - ) - token: bool = field( - default=False, - metadata={ - "help": "Classify over an actual token rather than the [CLS] ('') token -- requires that the tokens to be classified are surrounded by / tokens" - }, - ) - - # NxN relation classifier-specific arguments - num_rel_feats: int = field( - default=12, - metadata={ - "help": "Number of 
features/attention heads to use in the NxN relation classifier" - }, - ) - head_features: int = field( - default=64, - metadata={ - "help": "Number of parameters in each attention head in the NxN relation classifier" - }, - ) - - # CNN-specific arguments - cnn_embed_dim: int = field( - default=100, - metadata={ - "help": "For the CNN baseline model, the size of the word embedding space." - }, - ) - cnn_num_filters: int = field( - default=25, - metadata={ - "help": ( - "For the CNN baseline model, the number of " - "convolution filters to use for each filter size." - ) - }, - ) - - cnn_filter_sizes: list[int] = field( - default_factory=lambda: [1, 2, 3], - metadata={ - "help": ( - "For the CNN baseline model, a space-separated list " - "of size(s) of the filters (kernels)" - ) - }, - ) - - # LSTM-specific arguments - lstm_embed_dim: int = field( - default=100, - metadata={ - "help": "For the LSTM baseline model, the size of the word embedding space." - }, - ) - lstm_hidden_size: int = field( - default=100, - metadata={ - "help": "For the LSTM baseline model, the hidden size of the LSTM layer" - }, - ) - - # Multi-task classifier-specific arguments - use_prior_tasks: bool = field( - default=False, - metadata={ - "help": "In the multi-task setting, incorporate the logits from the previous tasks into subsequent representation layers. This will be done in the task order specified in the command line." 
- }, - ) - - # Hierarchical Transformer-specific arguments - hier_num_layers: int = field( - default=2, - metadata={ - "help": ( - "For the hierarchical model, the number of document-level transformer " - "layers" - ) - }, - ) - hier_hidden_dim: int = field( - default=2048, - metadata={ - "help": ( - "For the hierarchical model, the inner hidden size of the positionwise " - "FFN in the document-level transformer layers" - ) - }, - ) - hier_n_head: int = field( - default=8, - metadata={ - "help": ( - "For the hierarchical model, the number of attention heads in the " - "document-level transformer layers" - ) - }, - ) - hier_d_k: int = field( - default=8, - metadata={ - "help": ( - "For the hierarchical model, the size of the query and key vectors in " - "the document-level transformer layers" - ) - }, - ) - hier_d_v: int = field( - default=96, - metadata={ - "help": ( - "For the hierarchical model, the size of the value vectors in the " - "document-level transformer layers" - ) - }, - ) - hier_dropout: float = field( - default=0.1, - metadata={ - "help": "For the hierarchical model, the dropout probability for the " - "document-level transformer layers" - }, - ) - keep_existing_classifiers: bool = field( - default=False, - metadata={ - "help": ( - "For the hierarchical model, load classifier weights from " - "the saved checkpoint. For inference of the trained model or " - "continued fine-tuning." - ) - }, - ) - ignore_existing_classifiers: bool = field( - default=False, - metadata={ - "help": ( - "For the hierarchical model, ignore classifier weights " - "from the saved checkpoint. The weights will be initialized." 
- ) - }, - ) - - def to_dict(self): - # adapted from transformers.TrainingArguments.to_dict() - # filter out fields that are defined as field(init=False) - d = { - field.name: getattr(self, field.name) - for field in fields(self) - if field.init - } - - for k, v in d.items(): - if isinstance(v, Enum): - d[k] = v.value - if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum): - d[k] = [x.value for x in v] - if k.endswith("_token"): - d[k] = f"<{k.upper()}>" - return d diff --git a/src/cnlpt/args/parse_args.py b/src/cnlpt/args/parse_args.py deleted file mode 100644 index dd3382c3..00000000 --- a/src/cnlpt/args/parse_args.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -import sys -from typing import Any, Union, cast - -import torch -from transformers.hf_argparser import DataClassType, HfArgumentParser - -from .data_args import CnlpDataArguments -from .log import logger -from .model_args import CnlpModelArguments -from .training_args import CnlpTrainingArguments - - -def _cast_dataclasses_to_args( - dataclasses: tuple[Any, ...], -) -> tuple[CnlpModelArguments, CnlpDataArguments, CnlpTrainingArguments]: - return cast( - tuple[CnlpModelArguments, CnlpDataArguments, CnlpTrainingArguments], dataclasses - ) - - -def _get_args_parser(): - args_dataclasses = cast( - tuple[DataClassType, ...], - (CnlpModelArguments, CnlpDataArguments, CnlpTrainingArguments), - ) - return HfArgumentParser(args_dataclasses, prog="cnlpt train") - - -def parse_args_dict( - args: dict[str, Any], -): - return _cast_dataclasses_to_args(_get_args_parser().parse_dict(args)) - - -def parse_args_json_file( - json_file: Union[str, os.PathLike], -): - return _cast_dataclasses_to_args(_get_args_parser().parse_json_file(json_file)) - - -def parse_args_from_argv( - argv: Union[list[str], None] = None, -): - if argv is None: - argv = sys.argv - if len(argv) == 2 and argv[1].endswith(".json"): - # If we pass only one argument to the script and it's the path to a json file, - # let's parse it to get our 
arguments. - return parse_args_json_file(argv[1]) - else: - return _cast_dataclasses_to_args( - _get_args_parser().parse_args_into_dataclasses(argv) - ) - - -def preprocess_args( - model_args: CnlpModelArguments, - data_args: CnlpDataArguments, - training_args: CnlpTrainingArguments, -): - if ( - training_args.output_dir is not None - and os.path.exists(training_args.output_dir) - and os.listdir(training_args.output_dir) - and training_args.do_train - and not training_args.overwrite_output_dir - ): - raise ValueError( - f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome." - ) - - if training_args.truncation_side_left: - if model_args.model == "hier": - logger.warning( - "truncation_side_left flag is not available for the hierarchical model -- setting to false" - ) - training_args.truncation_side_left = False - - if torch.mps.is_available(): - # pin_memory is unsupported on MPS, but defaults to True, - # so we'll explicitly turn it off to avoid a warning. 
- training_args.dataloader_pin_memory = False - - if training_args.metric_for_best_model is None: - training_args.metric_for_best_model = "eval_avg_acc" - elif not training_args.metric_for_best_model.startswith("eval_"): - training_args.metric_for_best_model = ( - f"eval_{training_args.metric_for_best_model}" - ) - if training_args.load_best_model_at_end is None: - training_args.load_best_model_at_end = True - - return model_args, data_args, training_args diff --git a/src/cnlpt/data/__init__.py b/src/cnlpt/data/__init__.py index 70cf8d2a..40a60d1a 100644 --- a/src/cnlpt/data/__init__.py +++ b/src/cnlpt/data/__init__.py @@ -1,4 +1,4 @@ -from .cnlp_dataset import CnlpDataset +from .cnlp_dataset import CnlpDataset, HierarchicalDataConfig from .predictions import CnlpPredictions from .preprocess import preprocess_raw_data from .task_info import ( @@ -16,6 +16,7 @@ "TAGGING", "CnlpDataset", "CnlpPredictions", + "HierarchicalDataConfig", "TaskInfo", "TaskType", "get_task_type", diff --git a/src/cnlpt/data/cnlp_dataset.py b/src/cnlpt/data/cnlp_dataset.py index a5a10350..1e29b79b 100644 --- a/src/cnlpt/data/cnlp_dataset.py +++ b/src/cnlpt/data/cnlp_dataset.py @@ -1,22 +1,48 @@ +import os +from collections import Counter +from dataclasses import dataclass +from enum import Enum +from typing import Literal, Union + +import torch from datasets import Dataset +from transformers.models.auto.tokenization_auto import AutoTokenizer from transformers.tokenization_utils import PreTrainedTokenizer -from ..args.data_args import CnlpDataArguments from .data_reader import CnlpDataReader from .preprocess import preprocess_raw_data -def _validate_dataset_args(args: CnlpDataArguments, hierarchical: bool): - if hierarchical: - if args.chunk_len is None or args.num_chunks is None: - raise ValueError( - "For the hierarchical model, data_args.chunk_len and data_args.num_chunks must be specified." 
- ) - implicit_max_len = args.chunk_len * args.num_chunks - if args.max_seq_length < implicit_max_len: - raise ValueError( - "For the hierarchical model, the max seq length should be equal to the chunk length * num_chunks, otherwise what is the point?" - ) +@dataclass(frozen=True) +class HierarchicalDataConfig: + chunk_len: int + num_chunks: int + prepend_empty_chunk: bool + + +def load_tokenizer( + model_name_or_path: str, + hf_cache_dir: Union[str, None] = None, + truncation_side: Literal["left", "right"] = "right", + character_level: bool = False, +) -> PreTrainedTokenizer: + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + cache_dir=hf_cache_dir, + add_prefix_space=True, + truncation_side=truncation_side, + additional_special_tokens=( + ["", "", "", "", "", "", "", ""] + if not character_level + else None + ), + ) + return tokenizer + + +class TruncationSide(str, Enum): + LEFT = "left" + RIGHT = "right" class CnlpDataset: @@ -24,9 +50,19 @@ class CnlpDataset: def __init__( self, - args: CnlpDataArguments, - tokenizer: PreTrainedTokenizer, - hierarchical: bool = False, + data_dir: Union[str, os.PathLike], + tokenizer: Union[str, PreTrainedTokenizer] = "roberta-base", + task_names: Union[list[str], None] = None, + hier_config: Union[HierarchicalDataConfig, None] = None, + truncation_side: TruncationSide = TruncationSide.RIGHT, + max_seq_length: int = 128, + use_data_cache: bool = True, + max_train: Union[int, None] = None, + max_eval: Union[int, None] = None, + max_test: Union[int, None] = None, + allow_disjoint_labels: bool = False, + character_level: bool = False, + hf_cache_dir: Union[str, None] = None, ): """Create a new `CnlpDataset`. @@ -35,25 +71,44 @@ def __init__( tokenizer: Tokenizer to tokenize the raw data. hierarchical: Whether this data is being preprocessed for a hierarchical model. Defaults to False. 
""" - _validate_dataset_args(args, hierarchical) - self.hierarchical = hierarchical + if hier_config is not None: + implicit_max_len = hier_config.chunk_len * hier_config.num_chunks + + # TODO(ian) should this be `!=`` instead of `<`? + if max_seq_length < implicit_max_len: + raise ValueError( + "For the hierarchical model, the max seq length should be equal to the chunk length * num_chunks, otherwise what is the point?" + ) + + self.data_dir = data_dir + if isinstance(tokenizer, str): + self.tokenizer = load_tokenizer( + tokenizer, + hf_cache_dir=hf_cache_dir, + truncation_side=truncation_side, + character_level=character_level, + ) + else: + self.tokenizer = tokenizer - reader = CnlpDataReader(allow_disjoint_labels=args.allow_disjoint_labels) - for data_dir in args.data_dir: - reader.load_dir(data_dir) + reader = CnlpDataReader(allow_disjoint_labels=allow_disjoint_labels) + reader.load_dir(data_dir) - self.tasks = reader.get_tasks(args.task_name or None) + self.tasks = reader.get_tasks(task_names) self.dataset = reader.dataset - if (val_limit := (args.max_eval_items or 0)) > 0: - self.dataset["validation"] = self.dataset["validation"].take(val_limit) + if max_train is not None: + self.dataset["train"] = self.dataset["train"].take(max_train) + if max_eval is not None: + self.dataset["validation"] = self.dataset["validation"].take(max_eval) + if max_test is not None: + self.dataset["test"] = self.dataset["test"].take(max_test) - if (train_limit := (args.max_train_items or 0)) > 0: - self.dataset["train"] = self.dataset["train"].take(train_limit) - - if (test_limit := (args.max_test_items or 0)) > 0: - self.dataset["test"] = self.dataset["test"].take(test_limit) + self.hier_config = hier_config + self.truncation_side = truncation_side + self.max_seq_length = max_seq_length + self.character_level = character_level split_data: Dataset for split_name, split_data in self.dataset.items(): @@ -61,19 +116,16 @@ def __init__( preprocess_raw_data, desc=f"Preprocessing 
{split_name} data", batched=True, - load_from_cache_file=not args.overwrite_cache, + load_from_cache_file=use_data_cache, batch_size=100, num_proc=1, fn_kwargs={ - "tokenizer": tokenizer, + "tokenizer": self.tokenizer, "tasks": self.tasks, - "max_length": args.max_seq_length, + "max_length": self.max_seq_length, "inference_only": "train" not in reader.split_names, - "hierarchical": self.hierarchical, - "character_level": args.character_level, - "chunk_len": args.chunk_len, - "num_chunks": args.num_chunks, - "insert_empty_chunk_at_beginning": args.insert_empty_chunk_at_beginning, + "character_level": self.character_level, + "hier_config": self.hier_config, }, ) @@ -91,3 +143,23 @@ def validation_data(self): def test_data(self): """This dataset's test split.""" return self.dataset["test"] + + def get_class_weights(self, device: torch.device): + class_weights: dict[str, torch.FloatTensor] = {} + for task in self.tasks: + train_labels = self.train_data[task.name] + weights: list[float] = [] + train_label_counts = Counter(train_labels) + for label in task.labels: + # class weights are determined by severity of class imbalance + weights.append( + len(train_labels) / (len(task.labels) * train_label_counts[label]) + ) + + class_weights[task.name] = torch.tensor( + # if we just have the one class, simplify the tensor or pytorch will be mad + # TODO(ian) why would we ever have just one class?? 
+ weights[0] if len(weights) == 1 else weights + ).to(device) + + return class_weights diff --git a/src/cnlpt/data/preprocess.py b/src/cnlpt/data/preprocess.py index ce09432d..8faf5507 100644 --- a/src/cnlpt/data/preprocess.py +++ b/src/cnlpt/data/preprocess.py @@ -1,6 +1,6 @@ import logging from collections.abc import Iterable -from typing import Any, Final, Union +from typing import TYPE_CHECKING, Any, Final, Union import numpy as np from transformers.tokenization_utils import PreTrainedTokenizer @@ -8,6 +8,9 @@ from .task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskInfo +if TYPE_CHECKING: + from .cnlp_dataset import HierarchicalDataConfig + logger = logging.getLogger(__name__) MISSING_DATA_STR: Final = "__None__" MASK_VALUE: Final = -100 @@ -19,11 +22,8 @@ def preprocess_raw_data( tasks: Union[Iterable[TaskInfo], None], max_length: Union[int, None] = None, inference_only: bool = False, - hierarchical: bool = False, character_level: bool = False, - chunk_len: int = -1, - num_chunks: int = -1, - insert_empty_chunk_at_beginning: bool = False, + hier_config: Union["HierarchicalDataConfig", None] = None, ) -> BatchEncoding: """Preprocess raw CNLP data for training/evaluation. 
@@ -52,7 +52,7 @@ def preprocess_raw_data( batch=batch, tokenizer=tokenizer, max_length=max_length, - hierarchical=hierarchical, + hierarchical=hier_config is not None, character_level=character_level, ) @@ -88,15 +88,15 @@ def preprocess_raw_data( tokenized_input["event_mask"] = _build_event_mask_character( tokenized_input=tokenized_input ) - if hierarchical: + if hier_config is not None: tokenized_input = _convert_features_to_hierarchical( tokenized_input, - chunk_len=chunk_len, - num_chunks=num_chunks, + chunk_len=hier_config.chunk_len, + num_chunks=hier_config.num_chunks, cls_id=tokenizer.cls_token_id, sep_id=tokenizer.sep_token_id, pad_id=tokenizer.pad_token_id, - insert_empty_chunk_at_beginning=insert_empty_chunk_at_beginning, + insert_empty_chunk_at_beginning=hier_config.prepend_empty_chunk, ) return tokenized_input @@ -313,15 +313,19 @@ def _tokenize_batch( 'The data does not seem to have a text column (literally a column labeled "text" is required)' ) - if hierarchical: - padding = False - else: - padding = "max_length" + # TODO(ian) Why was padding for hierarchical models disabled? At a glance it seems like it's ok + # to enable it, but this should be confirmed before merging. + + # if hierarchical: + # padding = False + # else: + # padding = "max_length" tokenized_batch = tokenizer( sentences, max_length=max_length, - padding=padding, + # padding=padding, # TODO(ian) See above. + padding="max_length", truncation=True, is_split_into_words=not character_level, ) diff --git a/src/cnlpt/data/task_info.py b/src/cnlpt/data/task_info.py index 79e09f12..58e8dbbb 100644 --- a/src/cnlpt/data/task_info.py +++ b/src/cnlpt/data/task_info.py @@ -2,6 +2,7 @@ from functools import cached_property from typing import Final, Literal, Union +# TODO(ian) convert this to an enum for consistency? TaskType = Literal["classification", "tagging", "relations"] "A type of task that this library can train a model to solve." 
diff --git a/src/cnlpt/train_system/__init__.py b/src/cnlpt/train_system/__init__.py index 10f2e655..6dfcc221 100644 --- a/src/cnlpt/train_system/__init__.py +++ b/src/cnlpt/train_system/__init__.py @@ -1,3 +1,4 @@ +from .args import CnlpTrainingArguments from .cnlp_train_system import CnlpTrainSystem -__all__ = ["CnlpTrainSystem"] +__all__ = ["CnlpTrainSystem", "CnlpTrainingArguments"] diff --git a/src/cnlpt/args/training_args.py b/src/cnlpt/train_system/args.py similarity index 50% rename from src/cnlpt/args/training_args.py rename to src/cnlpt/train_system/args.py index f7e8ecc8..fbba3387 100644 --- a/src/cnlpt/args/training_args.py +++ b/src/cnlpt/train_system/args.py @@ -1,30 +1,40 @@ from dataclasses import dataclass, field +from typing import Union +import torch +from transformers.trainer_utils import IntervalStrategy from transformers.training_args import TrainingArguments @dataclass class CnlpTrainingArguments(TrainingArguments): - """ - Additional arguments specific to this class. - See all possible arguments in :class:`transformers.TrainingArguments` - or by passing the ``--help`` flag to this script. - """ + def __post_init__(self): + if self.metric_for_best_model is None: + self.metric_for_best_model = "eval_avg_macro_f1" + elif not self.metric_for_best_model.startswith("eval_"): + self.metric_for_best_model = f"eval_{self.metric_for_best_model}" - evals_per_epoch: int = field( - default=-1, + # `dataloader_pin_memory` is unsupported on mps but defaults to True, + # so we'll disable it here to avoid warnings in the console. + if self.dataloader_pin_memory and self.device == torch.device("mps"): + self.dataloader_pin_memory = False + + return super().__post_init__() + + weight_classes: bool = field( + default=False, metadata={ - "help": "Number of times to evaluate and possibly save model per training epoch (allows for a lazy kind of early stopping)" + "help": "A flag that indicates whether class-specific loss should be used. 
This can be useful in cases with severe class imbalance. The formula for a weight of a class is the count of that class divided the count of the rarest class." }, ) final_task_weight: float = field( default=1.0, metadata={ - "help": "Amount to up/down-weight final task in task list (other tasks weighted 1.0)" + "help": "Amount to up/down-weight final task in task list (other tasks weighted 1.0)." }, ) - freeze: float = field( - default=-1.0, + freeze_encoder: float = field( + default=0.0, metadata={ "help": "Freeze the encoder layers and only train the layer between the encoder and classification architecture. Probably works best with --token flag since [CLS] may not be well-trained for anything in particular. If not specified, no weight freezing will be done. If specified as a flag (no arguments), 100%% of weights will be frozen. If a float (0..1.0) is specified, each weight will be frozen with that probability.", "nargs": "?", @@ -34,34 +44,41 @@ class CnlpTrainingArguments(TrainingArguments): bias_fit: bool = field( default=False, metadata={ - "help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)" + "help": "Only optimize the bias parameters of the encoder (and the weights of the classifier heads), as proposed in the BitFit paper by Ben Zaken et al. 2021 (https://arxiv.org/abs/2106.10199)." }, ) - output_prob: bool = field( + report_probs: bool = field( default=False, metadata={ - "help": "If selected, probability scores will be added to the output prediction file for test data when used with --do_predict, and to the evaluation file for dev data when used with --error_analysis. Currently implemented for classification tasks only." + "help": "If selected, probability scores will be added to the output prediction file for test data when used with --do_predict." 
}, ) - truncation_side_left: bool = field( - default=False, + evals_per_epoch: int = field( + default=0, metadata={ - "help": "Truncate samples from left. Note that hier model do not support this setting." + "help": "Number of times to evaluate and possibly save model per training epoch (allows for a lazy kind of early stopping). Note that setting this argument will automatically override `eval_steps` and `eval_strategy`." }, ) - error_analysis: bool = field( - default=False, + rich_display: bool = field( + default=True, metadata={ - "help": "Pretty printing for instances where at least one ground truth label for any of the tasks disagrees with the model's prediction" + "help": "Whether to render a live progress display in the console during training." }, ) - logging_strategy: str = field(default="epoch") - - rich_display: bool = field( + # override transformers TrainingArguments defaults + logging_strategy: IntervalStrategy = field( + default="epoch", + metadata={"help": "The evaluation strategy to adopt during training."}, + ) + logging_first_step: bool = field( default=True, + metadata={"help": "Whether to log the first step of training."}, + ) + cache_dir: Union[str, None] = field( + default=None, metadata={ - "help": "Whether to render a live progress display in the console during training." 
+ "help": "Optionally override the HuggingFace cache directory.", }, ) diff --git a/src/cnlpt/train_system/cnlp_train_system.py b/src/cnlpt/train_system/cnlp_train_system.py index 6117f941..bfd14244 100644 --- a/src/cnlpt/train_system/cnlp_train_system.py +++ b/src/cnlpt/train_system/cnlp_train_system.py @@ -1,430 +1,74 @@ import contextlib import math import os -from collections import Counter -from typing import Any, Union, cast +from typing import Union import numpy as np import numpy.typing as npt -import torch from datasets import Dataset -from transformers.models.auto.configuration_auto import AutoConfig -from transformers.models.auto.modeling_auto import AutoModel -from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.tokenization_utils import PreTrainedTokenizer from transformers.trainer import Trainer from transformers.trainer_callback import PrinterCallback, TrainerCallback -from transformers.trainer_utils import EvalPrediction, IntervalStrategy, set_seed - -from ..args import ( - CnlpDataArguments, - CnlpModelArguments, - CnlpTrainingArguments, - parse_args_dict, - parse_args_from_argv, - parse_args_json_file, - preprocess_args, -) +from transformers.trainer_utils import EvalPrediction, IntervalStrategy + from ..data import RELATIONS, TAGGING, CnlpDataset, CnlpPredictions -from ..models import CnlpConfig, CnlpModelForClassification, HierarchicalModel -from ..models.baseline import CnnSentenceClassifier, LstmSentenceClassifier +from ..modeling.models import CnnModel, HierarchicalModel, LstmModel, ProjectionModel +from .args import CnlpTrainingArguments from .display import TrainSystemDisplay -from .log import configure_logger_for_training, logger +from .log import configure_logger_for_training from .metrics import TaskEvalPrediction from .training_callbacks import BasicLoggingCallback, DisplayCallback -from .utils import is_external_encoder, simple_softmax +from .utils import simple_softmax class CnlpTrainSystem: - 
"""This class manages the full training workflow for the cnlp_transformers library. - - The train system can be initialized directly from `CnlpModelArguments`, `CnlpDataArguments`, and `CnlpTrainingArguments`, - or using one of the following class methods: - - `from_json_args(json_file)`: Load arguments from a json file. - - `from_args_dict(args)`: Load arguments from a python dictionary. - - `from_argv(argv)`: Load arguments from `sys.argv` or a user-specified list of argv-style arguments. - """ - def __init__( self, - *, - model_args: CnlpModelArguments, - data_args: CnlpDataArguments, + model: Union[CnnModel, LstmModel, HierarchicalModel, ProjectionModel], + dataset: CnlpDataset, training_args: CnlpTrainingArguments, ): - configure_logger_for_training(training_args) - preprocess_args( - model_args=model_args, data_args=data_args, training_args=training_args - ) - self.model_args = model_args - self.data_args = data_args - self.training_args = training_args - self.disp: Union[TrainSystemDisplay, None] = None - - set_seed(self.training_args.seed) - - self.tokenizer = self._init_tokenizer() - self.dataset = self._init_dataset() - self.model = self._init_model() + self.model = model + self.dataset = dataset + self.args = training_args + self._ensure_model_dataset_compatibility() self._set_eval_strategy() - @classmethod - def from_json_args(cls, json_file: Union[str, os.PathLike]): - """Instantiate the train system from a json-formatted args file. - - Args: - json_file: Path to the json-formatted args file. - - Returns: - The new `CnlpTrainSystem` instance. - """ - model_args, data_args, training_args = parse_args_json_file(json_file) - return cls( - model_args=model_args, data_args=data_args, training_args=training_args - ) - - @classmethod - def from_args_dict(cls, args: dict[str, Any]): - """Instantiate the train system from a dict of args. - - Args: - args: Arguments for the train system. - - Returns: - The new `CnlpTrainSystem` instance. 
- """ - model_args, data_args, training_args = parse_args_dict(args) - return cls( - model_args=model_args, data_args=data_args, training_args=training_args - ) - - @classmethod - def from_argv(cls, argv: Union[list[str], None] = None): - """Instantiate the train system from `sys.argv` or a user-specified list of argv-style arguments. - - If `argv` is not specified, `sys.argv` will be used. - - Args: - argv: List of arguments. Optional, defaults to None. - - Returns: - The new `CnlpTrainSystem` instance. - """ - model_args, data_args, training_args = parse_args_from_argv(argv) - return cls( - model_args=model_args, data_args=data_args, training_args=training_args - ) - - def _init_tokenizer(self): - tokenizer_name = self.model_args.tokenizer_name or self.model_args.encoder_name - assert tokenizer_name is not None - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - cache_dir=self.model_args.cache_dir, - add_prefix_space=True, - truncation_side=( - "left" if self.training_args.truncation_side_left else "right" - ), - additional_special_tokens=( - ["", "", "", "", "", "", "", ""] - if not self.data_args.character_level - else None - ), - ) - return cast(PreTrainedTokenizer, tokenizer) - - def _init_dataset(self): - return CnlpDataset( - self.data_args, - tokenizer=self.tokenizer, - hierarchical=(self.model_args.model == "hier"), - ) - - def _init_model(self): - model_name = self.model_args.model - if model_name == "cnn": - return self._init_cnn_model() - elif model_name == "lstm": - return self._init_lstm_model() - elif model_name == "hier": - return self._init_hier_model() - else: - return self._init_cnlpt_model() - - def _get_class_weights(self): - if not self.data_args.weight_classes: - return None - - class_weights: list[list[float]] = [] - for task in self.dataset.tasks: - train_labels = self.dataset.train_data[task.name] - weights: list[float] = [] - train_label_counts = Counter(train_labels) - for label in task.labels: - # class weights are 
determined by severity of class imbalance - weights.append( - len(train_labels) / (len(task.labels) * train_label_counts[label]) - ) - - class_weights.append(weights) - - class_weights_tensor = torch.tensor( - # if we just have the one class, simplify the tensor or pytorch will be mad - class_weights[0] if len(class_weights) == 1 else class_weights - ).to(self.training_args.device) - - return class_weights_tensor - - def _init_cnn_model(self): - model = CnnSentenceClassifier( - len(self.tokenizer), - task_names=[t.name for t in self.dataset.tasks], - num_labels_dict={t.name: len(t.labels) for t in self.dataset.tasks}, - embed_dims=self.model_args.cnn_embed_dim, - num_filters=self.model_args.cnn_num_filters, - filters=self.model_args.cnn_filter_sizes, - use_prior_tasks=self.model_args.use_prior_tasks, - class_weights=self._get_class_weights(), - ) - # Check if the caller specified a saved model to load (e.g., for an inference-only run) - assert self.model_args.encoder_name is not None - model_path = os.path.join(self.model_args.encoder_name, "pytorch_model.bin") - if os.path.exists(model_path): - model.load_state_dict(torch.load(model_path)) - - return model - - def _init_lstm_model(self): - model = LstmSentenceClassifier( - len(self.tokenizer), - task_names=[t.name for t in self.dataset.tasks], - num_labels_dict={t.name: len(t.labels) for t in self.dataset.tasks}, - embed_dims=self.model_args.lstm_embed_dim, - hidden_size=self.model_args.lstm_hidden_size, - ) - # Check if the caller specified a saved model to load (e.g., for an inference-only run) - assert self.model_args.encoder_name is not None - model_path = os.path.join(self.model_args.encoder_name, "pytorch_model.bin") - if os.path.exists(model_path): - model.load_state_dict(torch.load(model_path)) - - return model - - def _init_hier_model(self): - encoder_name = ( - self.model_args.config_name - if self.model_args.config_name - else self.model_args.encoder_name - ) - assert encoder_name is not None - if 
is_external_encoder(encoder_name): - config = CnlpConfig( - encoder_name=encoder_name, - finetuning_task=[t.name for t in self.dataset.tasks], - layer=self.model_args.layer or -1, - tokens=self.model_args.token, - num_rel_attention_heads=self.model_args.num_rel_feats, - rel_attention_head_dims=self.model_args.head_features, - tagger={t.name: t.type == TAGGING for t in self.dataset.tasks}, - relations={t.name: t.type == RELATIONS for t in self.dataset.tasks}, - label_dictionary={t.name: list(t.labels) for t in self.dataset.tasks}, - hier_head_config=dict( - n_layers=self.model_args.hier_num_layers, - d_inner=self.model_args.hier_hidden_dim, - n_head=self.model_args.hier_n_head, - d_k=self.model_args.hier_d_k, - d_v=self.model_args.hier_d_v, - dropout=self.model_args.hier_dropout, - ), - ) - # num_tokens=len(tokenizer)) - config.vocab_size = len(self.tokenizer) - - model = HierarchicalModel( - config=config, - # TODO(ian) as far as I can tell, this was always just None? - class_weights=None, - freeze=self.training_args.freeze, - ) - else: - if ( - self.model_args.keep_existing_classifiers - == self.model_args.ignore_existing_classifiers - ): # XNOR - raise ValueError( - "For continued training of a cnlpt hierarchical model, one of --keep_existing_classifiers or --ignore_existing_classifiers flags should be selected." 
- ) - # use a checkpoint from an existing model - - config: CnlpConfig = AutoConfig.from_pretrained( - encoder_name, - cache_dir=self.model_args.cache_dir, - layer=self.model_args.layer, - ) - task_is_relations = { - t.name: t.type == RELATIONS for t in self.dataset.tasks - } - task_is_tagging = {t.name: t.type == TAGGING for t in self.dataset.tasks} - - if self.model_args.ignore_existing_classifiers: - config.finetuning_task = [t.name for t in self.dataset.tasks] - config.relations = task_is_relations - config.tagger = task_is_tagging - config.label_dictionary = {} # this gets filled in later - elif self.model_args.keep_existing_classifiers: - if ( - config.finetuning_task != [t.name for t in self.dataset.tasks] - or config.relations != task_is_relations - or config.tagger != task_is_tagging - ): - raise ValueError( - "When --keep_existing_classifiers selected, please ensure" - "that you set the settings the same as those used in the" - "previous training run." - ) - - # TODO: check if user overwrote parameters in command line that could change behavior of the model and warn - # if self.data_args.chunk_len is not None: - - logger.info("Loading pre-trained hierarchical model...") - model: HierarchicalModel = AutoModel.from_pretrained( - encoder_name, config=config - ) + if not os.path.exists(self.args.output_dir): + os.mkdir(self.args.output_dir) - if self.model_args.ignore_existing_classifiers: - model.remove_task_classifiers() - for task in self.dataset.tasks: - model.add_task_classifier(task.name, list(task.labels)) - - # TODO(ian) as far as I can tell, this was always just None? 
- model.set_class_weights(None) - - return cast(HierarchicalModel, model) - - def _init_cnlpt_model(self): - # by default cnlpt model, but need to check which encoder they want - encoder_name = self.model_args.encoder_name - assert encoder_name is not None - - # TODO check when download any pretrained language model to local disk, if - # the following condition "is_hub_model(encoder_name)" works or not. - if not is_external_encoder(encoder_name): - # we are loading one of our own trained models as a starting point. - # - # 1) if training_args.do_train is true: - # sometimes we may want to use an encoder that has had continued pre-training, either on - # in-domain MLM or another task we think might be useful. In that case our encoder will just - # be a link to a directory. If the encoder-name is not recognized as a pre-trained model, special - # logic for ad hoc encoders follows: - # we will load it as-is initially, then delete its classifier head, save the encoder - # as a temp file, and make that temp file - # the model file to be loaded down below the normal way. since that temp file - # doesn't have a stored classifier it will use the randomly-inited classifier head - # with the size of the supplied config (for the new task). - # TODO This setting 1) is not tested yet. - # 2) if training_args.do_train is false: - # we evaluate or make predictions of our trained models. - # Both two setting require the registeration of CnlpConfig, and use - # AutoConfig.from_pretrained() to load the configuration file - - # Load the cnlp configuration using AutoConfig, this will not override - # the arguments from trained cnlp models. While using CnlpConfig will override - # the model_type and model_name of the encoder. 
- encoder_name = ( - self.model_args.config_name - if self.model_args.config_name - else encoder_name + def _ensure_model_dataset_compatibility(self): + if ( + isinstance(self.model, HierarchicalModel) + and self.dataset.hier_config is None + ): + raise ValueError( + "to train a hierarchical model, you need a hierarchical-formatted dataset. Pass a HierarchicalDataConfig instance to your dataset initializer." ) - config = AutoConfig.from_pretrained( - encoder_name, - cache_dir=self.model_args.cache_dir, - # in this case we're looking at a fine-tuned model (?) - character_level=self.data_args.character_level, + elif ( + not isinstance(self.model, HierarchicalModel) + and self.dataset.hier_config is not None + ): + raise ValueError( + "cannot train a non-hierarchical model on a hierarchical-formatted dataset. Be sure not to pass a HierarchicalDataConfig instance to your dataset initializer." ) - if self.training_args.do_train: - # Setting 1) only load weights from the encoder - raise NotImplementedError( - "This functionality has not been restored yet" + def _set_eval_strategy(self): + if self.args.do_train: + if self.args.evals_per_epoch > 0: + batches_per_epoch = math.ceil( + len(self.dataset.train_data) / self.args.train_batch_size ) - else: - # setting 2) evaluate or make predictions - model = CnlpModelForClassification.from_pretrained( - self.model_args.encoder_name, - config=config, - # TODO(ian) as far as I can tell, this was always just None? 
- class_weights=None, - final_task_weight=self.training_args.final_task_weight, - freeze=self.training_args.freeze, - bias_fit=self.training_args.bias_fit, + total_steps = int( + self.args.num_train_epochs + * batches_per_epoch + // self.args.gradient_accumulation_steps ) - else: - # This only works when model_args.encoder_name is one of the - # model card from https://huggingface.co/models - # By default, we use model card as the starting point to fine-tune - encoder_name = ( - self.model_args.config_name - if self.model_args.config_name - else encoder_name - ) - config = CnlpConfig( - encoder_name=encoder_name, - finetuning_task=[t.name for t in self.dataset.tasks], - layer=self.model_args.layer, - tokens=self.model_args.token, - num_rel_attention_heads=self.model_args.num_rel_feats, - rel_attention_head_dims=self.model_args.head_features, - tagger={t.name: t.type == TAGGING for t in self.dataset.tasks}, - relations={t.name: t.type == RELATIONS for t in self.dataset.tasks}, - label_dictionary={t.name: list(t.labels) for t in self.dataset.tasks}, - character_level=self.data_args.character_level, - # num_tokens=len(tokenizer), - ) - config.vocab_size = len(self.tokenizer) - model = CnlpModelForClassification( - config=config, - # TODO(ian) as far as I can tell, this was always just None? 
- class_weights=None, - final_task_weight=self.training_args.final_task_weight, - freeze=self.training_args.freeze, - bias_fit=self.training_args.bias_fit, - ) - - return cast(CnlpModelForClassification, model) - - def _set_eval_strategy(self): - if not self.training_args.do_train: - return - - batches_per_epoch = math.ceil( - len(self.dataset.train_data) / self.training_args.train_batch_size - ) - total_steps = int( - self.training_args.num_train_epochs - * batches_per_epoch - // self.training_args.gradient_accumulation_steps - ) - - if self.training_args.evals_per_epoch > 0: - logger.warning( - "Overwriting the value of logging steps based on provided evals_per_epoch argument" - ) - # steps per epoch factors in gradient accumulation steps (as compared to batches_per_epoch above which doesn't) - steps_per_epoch = int(total_steps // self.training_args.num_train_epochs) - self.training_args.eval_steps = ( - steps_per_epoch // self.training_args.evals_per_epoch - ) - self.training_args.eval_strategy = self.training_args.eval_strategy = ( - IntervalStrategy.STEPS - ) - # This will save model per epoch - # training_args.save_strategy = IntervalStrategy.EPOCH - elif self.training_args.do_eval: - logger.info("Evaluation strategy not specified so evaluating every epoch") - self.training_args.eval_strategy = self.training_args.eval_strategy = ( - IntervalStrategy.EPOCH - ) + steps_per_epoch = int(total_steps // self.args.num_train_epochs) + self.args.eval_steps = steps_per_epoch // self.args.evals_per_epoch + self.args.eval_strategy = IntervalStrategy.STEPS + elif self.args.do_eval: + self.args.eval_strategy = IntervalStrategy.EPOCH def _extract_task_predictions(self, p: EvalPrediction): task_predictions: list[TaskEvalPrediction] = [] @@ -441,7 +85,7 @@ def _extract_task_predictions(self, p: EvalPrediction): preds = np.argmax(raw_preds, axis=3) else: preds = np.argmax(raw_preds, axis=1) - if self.training_args.output_prob: + if self.args.report_probs: probs = np.max( 
[simple_softmax(logits) for logits in raw_preds], axis=1, @@ -455,7 +99,7 @@ def _extract_task_predictions(self, p: EvalPrediction): # we are doing inference, so no labels labels = None elif task.type == RELATIONS: - task_label_width = self.data_args.max_seq_length + task_label_width = self.dataset.max_seq_length # relation labels labels = label_ids[ :, :, task_label_offset : task_label_offset + task_label_width @@ -513,7 +157,7 @@ def _compute_metrics(self, p: EvalPrediction): result = summary_metrics | metrics - requested_metric = self.training_args.metric_for_best_model + requested_metric = self.args.metric_for_best_model if requested_metric is not None and requested_metric not in result: submetrics: list[float] = [] for sub in requested_metric.split(","): @@ -527,73 +171,75 @@ def _compute_metrics(self, p: EvalPrediction): return result @contextlib.contextmanager - def trainer(self): - trainer_callbacks: list[TrainerCallback] = [ - BasicLoggingCallback(self.model_args, self.data_args, self.training_args) - ] - - if self.training_args.rich_display and self.training_args.local_rank in (-1, 0): - self.training_args.disable_tqdm = True - self.disp = TrainSystemDisplay( - self.model_args, - self.data_args, - self.training_args, - ) + def _trainer(self): + configure_logger_for_training(self.args) + + trainer_callbacks: list[TrainerCallback] = [BasicLoggingCallback(self)] + if self.args.rich_display and self.args.local_rank in (-1, 0): + self.args.disable_tqdm = True + self.disp = TrainSystemDisplay(self) display_callback = DisplayCallback(self.disp) trainer_callbacks.append(display_callback) else: + self.disp = None display_callback = None with self.disp or contextlib.nullcontext(): trainer = Trainer( model=self.model, - args=self.training_args, + args=self.args, train_dataset=self.dataset.train_data, eval_dataset=self.dataset.validation_data, compute_metrics=self._compute_metrics, callbacks=trainer_callbacks, ) - if self.training_args.rich_display: - # remove the 
PrinterCallback added by default when we initialized the trainer + if self.args.rich_display: + # If tqdm is disabled, the Trainer will automatically add a PrinterCallback + # when initialized. We disable tqdm when the rich display is active, but we + # don't want the PrinterCallback in that case either, so we'll remove it + # manually here. trainer.remove_callback(PrinterCallback) yield trainer self.disp = None + def _evaluate(self, trainer: Trainer): + if self.disp: + self.disp.eval_desc = "Evaluating" + return trainer.evaluate() + + def _predict(self, trainer: Trainer, dataset: Dataset): + if self.disp: + self.disp.eval_desc = "Predicting" + raw_prediction = trainer.predict(dataset) + return CnlpPredictions( + dataset, + raw_prediction, + self.dataset.tasks, + max_seq_length=self.dataset.max_seq_length, + ) + def train(self): - """Begin the training loop.""" + """Run the training loop.""" - with self.trainer() as trainer: + with self._trainer() as trainer: if self.disp: self.disp.eval_desc = "Evaluating" - trainer.train() + trainer.train(resume_from_checkpoint=self.args.resume_from_checkpoint) trainer.save_model() - if self.training_args.do_predict: + if self.args.do_predict: predictions = self._predict(trainer, self.dataset.test_data) predictions_file = os.path.join( - self.training_args.output_dir, "predictions.json" + self.args.output_dir, "predictions.json" ) predictions.save_json( predictions_file, - allow_overwrite=self.training_args.overwrite_output_dir, + allow_overwrite=self.args.overwrite_output_dir, ) - def _evaluate(self, trainer: Trainer): - if self.disp: - self.disp.eval_desc = "Evaluating" - return trainer.evaluate() - - def _predict(self, trainer: Trainer, dataset: Dataset): - if self.disp: - self.disp.eval_desc = "Predicting" - raw_prediction = trainer.predict(dataset) - return CnlpPredictions( - dataset, raw_prediction, self.dataset.tasks, self.data_args - ) - def evaluate(self) -> dict[str, float]: """Run an evaluation on the valdiation set. 
@@ -601,7 +247,7 @@ def evaluate(self) -> dict[str, float]: Evaluation metrics. """ - with self.trainer() as trainer: + with self._trainer() as trainer: return self._evaluate(trainer) def predict(self, dataset: Union[Dataset, None] = None) -> CnlpPredictions: @@ -609,16 +255,11 @@ def predict(self, dataset: Union[Dataset, None] = None) -> CnlpPredictions: Args: dataset: Dataset to run predictions. Optional, defaults to the test data in this - train system's dataset. + train system's dataset. Returns: The prediction output. """ - with self.trainer() as trainer: + with self._trainer() as trainer: return self._predict(trainer, dataset or self.dataset.test_data) - - -def main(argv: Union[list[str], None] = None): - train_system = CnlpTrainSystem.from_argv(argv) - train_system.train() diff --git a/src/cnlpt/train_system/display.py b/src/cnlpt/train_system/display.py index 458c1c67..4e69c618 100644 --- a/src/cnlpt/train_system/display.py +++ b/src/cnlpt/train_system/display.py @@ -1,5 +1,5 @@ import os -from typing import Any, Union +from typing import TYPE_CHECKING, Any, Union from rich.console import Console from rich.live import Live @@ -14,7 +14,8 @@ ) from rich.table import Table -from ..args import CnlpDataArguments, CnlpModelArguments, CnlpTrainingArguments +if TYPE_CHECKING: + from .cnlp_train_system import CnlpTrainSystem console = Console() @@ -24,15 +25,8 @@ def _val_fmt(x): class TrainSystemDisplay: - def __init__( - self, - model_args: CnlpModelArguments, - data_args: CnlpDataArguments, - training_args: CnlpTrainingArguments, - ): - self.model_args = model_args - self.data_args = data_args - self.training_args = training_args + def __init__(self, train_system: "CnlpTrainSystem"): + self.train_system = train_system self.eval_desc = "Evaluating" @@ -64,10 +58,10 @@ def title(self): @property def subtitle(self): - if self.training_args.output_dir is None: + if self.train_system.args.output_dir is None: return "?" 
logfile = os.path.join( - os.path.abspath(self.training_args.output_dir), "train_system.log" + os.path.abspath(self.train_system.args.output_dir), "train_system.log" ) return f"Training log: {logfile}" @@ -101,7 +95,7 @@ def format_metric_name(metric_name: str): _val_fmt(self.eval_metrics[metric_name]), _val_fmt(self.best_eval_metrics[metric_name]), ] - if metric_name == self.training_args.metric_for_best_model: + if metric_name == self.train_system.args.metric_for_best_model: row[0] = f"[bold][cyan]> {row[0]}" row[1] = f"[bold]{row[1]}" row[2] = f"[bold]{row[2]}" @@ -140,14 +134,13 @@ def body(self): meta.add_column(style="blue", justify="right") meta.add_column() out_dir_abspath = ( - os.path.abspath(self.training_args.output_dir) - if self.training_args.output_dir is not None + os.path.abspath(self.train_system.args.output_dir) + if self.train_system.args.output_dir is not None else "?" ) meta.add_row("Output dir:", out_dir_abspath) - meta.add_row("Dataset:", ", ".join(self.data_args.data_dir)) - meta.add_row("Encoder:", str(self.model_args.encoder_name)) - meta.add_row("Model type:", str(self.model_args.model)) + meta.add_row("Dataset:", str(self.train_system.dataset.data_dir)) + meta.add_row("Model type:", type(self.train_system.model).__name__) stats = Table.grid(padding=(1, 1)) stats.add_column(style="blue", justify="right") diff --git a/src/cnlpt/train_system/log.py b/src/cnlpt/train_system/log.py index c125897d..3a359c16 100644 --- a/src/cnlpt/train_system/log.py +++ b/src/cnlpt/train_system/log.py @@ -5,7 +5,7 @@ from transformers import logging as transformers_logging -from ..args import CnlpTrainingArguments +from .args import CnlpTrainingArguments logger = logging.getLogger("cnlpt.train_system") diff --git a/src/cnlpt/train_system/training_callbacks.py b/src/cnlpt/train_system/training_callbacks.py index b4bf4291..95dff1be 100644 --- a/src/cnlpt/train_system/training_callbacks.py +++ b/src/cnlpt/train_system/training_callbacks.py @@ -1,5 +1,5 @@ from 
dataclasses import asdict -from typing import Union +from typing import TYPE_CHECKING, Union from transformers.trainer_callback import ( TrainerCallback, @@ -9,10 +9,12 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, SaveStrategy from transformers.training_args import TrainingArguments -from ..args import CnlpDataArguments, CnlpModelArguments, CnlpTrainingArguments from .display import TrainSystemDisplay from .log import logger +if TYPE_CHECKING: + from .cnlp_train_system import CnlpTrainSystem + class DisplayCallback(TrainerCallback): def __init__( @@ -146,19 +148,15 @@ def on_predict( **kwargs, ): self.display.finish_eval() + self.display.update() self.current_eval_step = 0 class BasicLoggingCallback(TrainerCallback): - def __init__( - self, - model_args: CnlpModelArguments, - data_args: CnlpDataArguments, - training_args: CnlpTrainingArguments, - ): - self.model_args = model_args - self.data_args = data_args - self.training_args = training_args + def __init__(self, train_system: "CnlpTrainSystem"): + self.train_system = train_system + for arg_name, arg_value in asdict(train_system.args).items(): + logger.info(f"{arg_name}={arg_value}") def on_train_begin( self, @@ -167,14 +165,6 @@ def on_train_begin( control: TrainerControl, **kwargs, ): - logger.info("*** TRAIN SYSTEM ARGS ***") - for args_data, prefix in ( - (self.model_args, "model_args"), - (self.data_args, "data_args"), - (self.training_args, "training_args"), - ): - for arg, val in sorted(asdict(args_data).items()): - logger.info("%s.%s: %s", prefix, arg, val) logger.info("*** STARTING TRAINING ***") def on_log( diff --git a/src/cnlpt/train_system/utils.py b/src/cnlpt/train_system/utils.py index 07cccba3..a4c2f05a 100644 --- a/src/cnlpt/train_system/utils.py +++ b/src/cnlpt/train_system/utils.py @@ -14,7 +14,7 @@ def is_cnlpt_model(model_path: str) -> bool: Whether the model is a cnlpt classifier model. 
""" encoder_config = AutoConfig.from_pretrained(model_path) - return encoder_config.model_type == "cnlpt" + return hasattr(encoder_config, "cnlpt_version") def is_external_encoder(model_name_or_path: str) -> bool: From 5e1a449a24308ca0084c0161b9ff2bd02dd94463 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:13:06 -0400 Subject: [PATCH 3/8] rework cli --- pyproject.toml | 5 +- src/cnlpt/__main__.py | 4 +- src/cnlpt/_cli/main.py | 68 +++-- src/cnlpt/_cli/rest.py | 83 ++++--- src/cnlpt/_cli/train.py | 517 +++++++++++++++++++++++++++++++++++++- uv.lock | 531 +++++++++++++++++++++++++++++++++++++++- 6 files changed, 1134 insertions(+), 74 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e28ccda0..af893b7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ readme = "README.md" requires-python = ">=3.9, <3.13" dependencies = [ "anaforatools~=1.1.0", - "click>=8.1.7", "datasets~=2.21.0", "fastapi~=0.115.2", "httpx>=0.27.2", @@ -25,10 +24,12 @@ dependencies = [ "polars>=1.30.0", "pydantic~=1.10.8", "requests~=2.32.2", + "rich>=14.0.0", "scikit-learn~=1.5.2", "seqeval~=1.2.2", "torch>=2.6", "transformers[torch]~=4.51", + "typer~=0.16.0", "uvicorn[standard]~=0.32.0", ] @@ -57,11 +58,13 @@ lint = [ build = ["build", "pip>=21.3", "twine"] test = ["lorem-text>=3.0", "pytest"] docs = ["sphinx", "sphinx-autodoc-typehints", "sphinx-rtd-theme", "sphobjinv"] +notebooks = ["ipykernel", "ipywidgets"] dev = [ { include-group = "build" }, { include-group = "docs" }, { include-group = "lint" }, { include-group = "test" }, + { include-group = "notebooks" }, "pre-commit", ] diff --git a/src/cnlpt/__main__.py b/src/cnlpt/__main__.py index 52d81be7..df2bbc89 100644 --- a/src/cnlpt/__main__.py +++ b/src/cnlpt/__main__.py @@ -1,8 +1,8 @@ -from ._cli.main import cli +from ._cli.main import main as cli_main def main(): - cli() + cli_main() if __name__ == "__main__": diff --git a/src/cnlpt/_cli/main.py b/src/cnlpt/_cli/main.py index 
9d5744fc..ed8f62ce 100644 --- a/src/cnlpt/_cli/main.py +++ b/src/cnlpt/_cli/main.py @@ -1,30 +1,46 @@ -import click - -from .. import __version__ -from .rest import rest_command -from .train import train_command - - -@click.group(invoke_without_command=True) -@click.option( - "--version", - type=bool, - is_flag=True, - default=False, - help="Print the cnlp_transformers version.", -) -@click.pass_context -def cli(ctx: click.Context, version: bool): - if ctx.invoked_subcommand is not None: - return +from typing import Annotated +import rich +import typer +from typer.core import DEFAULT_MARKUP_MODE + +from .. import __version__ as cnlpt_version +from . import rest, train + +app = typer.Typer(add_completion=False, rich_markup_mode=DEFAULT_MARKUP_MODE) + + +app.command(no_args_is_help=True)(rest.rest) +app.command( + no_args_is_help=True, + context_settings={ + "allow_extra_args": True, + "ignore_unknown_options": True, + }, + epilog=train.TRAIN_EPILOG, +)(train.train) + + +def version_callback(version: bool): if version: - print(__version__) - ctx.exit() - else: - click.echo(ctx.get_help()) - ctx.exit() + rich.print(f"cnlp_transformers version: [b cyan]{cnlpt_version}") + raise typer.Exit() + + +@app.callback(no_args_is_help=True) +def cli( + version: Annotated[ + bool, + typer.Option( + "--version", + help="Show the cnlp_transformers version and exit.", + is_eager=True, + callback=version_callback, + ), + ] = False, +): + pass -cli.add_command(rest_command) -cli.add_command(train_command) +def main(): + app() diff --git a/src/cnlpt/_cli/rest.py b/src/cnlpt/_cli/rest.py index dd857277..d82e01c6 100644 --- a/src/cnlpt/_cli/rest.py +++ b/src/cnlpt/_cli/rest.py @@ -1,33 +1,56 @@ -import click - -from ..api import MODEL_TYPES, get_rest_app - - -@click.command("rest", context_settings={"show_default": True}) -@click.option( - "--model-type", - type=click.Choice(MODEL_TYPES), - required=True, -) -@click.option( - "-h", - "--host", - type=str, - default="0.0.0.0", - 
help="Host address to serve the REST app.", -) -@click.option( - "-p", "--port", type=int, default=8000, help="Port to serve the REST app." -) -@click.option( - "--reload", - type=bool, - is_flag=True, - default=False, - help="Auto-reload the REST app.", -) -def rest_command(model_type: str, host: str, port: int, reload: bool): +from typing import Annotated + +import typer + + +def parse_models(ctx: typer.Context, param: typer.CallbackParam, value: str): + if value is None: + return None + + models: list[tuple[str, str]] = [] + if isinstance(value, str): + value = [value] + for item in value: + if "=" in item: + prefix, path = item.split("=", 1) + if not prefix.startswith("/"): + raise typer.BadParameter( + f"route prefix must start with '/': {prefix}", param=param + ) + elif len(value) > 1: + raise typer.BadParameter( + "route prefixes are required when serving more than one model", + param=param, + ) + else: + path = item + prefix = "" + models.append((prefix, path)) + return models + + +def rest( + models: Annotated[ + list[str], + typer.Option( + "--model", + callback=parse_models, + help="Model definition as [ROUTER_PREFIX=]PATH_TO_MODEL. Route prefix must start with '/'. This option can be specified multiple times to serve multiple models simultaneously. 
Route prefixes are required when serving more than one model.", + ), + ], + host: Annotated[ + str, typer.Option("-h", "--host", help="Host address to serve the REST app.") + ] = "0.0.0.0", + port: Annotated[ + int, typer.Option("-p", "--port", help="Port to serve the REST app.") + ] = 8000, +): """Start a REST application from a model.""" import uvicorn - uvicorn.run(get_rest_app(model_type), host=host, port=port, reload=reload) + from ..rest import CnlpRestApp + + app = CnlpRestApp.multi_app( + [(CnlpRestApp(model_path=path), prefix) for prefix, path in models] + ) + uvicorn.run(app, host=host, port=port) diff --git a/src/cnlpt/_cli/train.py b/src/cnlpt/_cli/train.py index bfa5b729..18fb8d22 100644 --- a/src/cnlpt/_cli/train.py +++ b/src/cnlpt/_cli/train.py @@ -1,17 +1,510 @@ -import click +from enum import Enum +from typing import Annotated, Any, Final, Union -from ..train_system.cnlp_train_system import main as train_system +import typer +from click.core import ParameterSource +from transformers.hf_argparser import HfArgumentParser +from transformers.models.auto.modeling_auto import AutoModel +from transformers.trainer_utils import IntervalStrategy +from transformers.training_args import TrainingArguments +from ..data.cnlp_dataset import CnlpDataset, HierarchicalDataConfig, TruncationSide +from ..modeling.config.cnn_config import CnnModelConfig +from ..modeling.config.hierarchical_config import HierarchicalModelConfig +from ..modeling.config.lstm_config import LstmModelConfig +from ..modeling.config.projection_config import ProjectionModelConfig +from ..modeling.models import CnnModel, HierarchicalModel, LstmModel, ProjectionModel +from ..modeling.types import ClassificationMode, ModelType +from ..train_system.args import CnlpTrainingArguments +from ..train_system.cnlp_train_system import CnlpTrainSystem -@click.command( - "train", - context_settings=dict( - ignore_unknown_options=True, +_ARG_COMPAT_METADATA_KEY = "cnlpt.model_arg_compat" +DEFAULT_ENCODER: 
Final = "roberta-base" + + +def compatible_models(types: list[ModelType]): + def callback(ctx: typer.Context, param: typer.CallbackParam, value: Any): + if ctx.resilient_parsing: + return + if _ARG_COMPAT_METADATA_KEY not in ctx.meta: + ctx.meta[_ARG_COMPAT_METADATA_KEY] = {} + ctx.meta[_ARG_COMPAT_METADATA_KEY][param.name] = types + if isinstance(value, Enum): + return value.value + return value + + return callback + + +def training_arg_option( + field_name: str, + *aliases, + compatibility: Union[list[ModelType], None] = None, + **kwargs, +): + field = CnlpTrainingArguments.__dataclass_fields__[field_name] + if len(aliases) == 0: + aliases = (f"--{field_name}",) + if compatibility is not None: + kwargs["callback"] = compatible_models(compatibility) + return typer.Option( + *aliases, + help=field.metadata.get("help", None), + rich_help_panel="CNLPT Training Arguments", + **kwargs, + ) + + +def model_arg_option( + *args, + compatibility: Union[list[ModelType], None] = None, + **kwargs, +): + if compatibility is not None: + kwargs["callback"] = compatible_models(compatibility) + return typer.Option(*args, rich_help_panel="Model Arguments", **kwargs) + + +def data_arg_option( + *args, + compatibility: Union[list[ModelType], None] = None, + **kwargs, +): + if compatibility is not None: + kwargs["callback"] = compatible_models(compatibility) + return typer.Option(*args, rich_help_panel="Data Arguments", **kwargs) + + +##### MODEL ARGS ##### +ModelTypeArg = Annotated[ + ModelType, + model_arg_option( + "--model_type", + help="The type of model to load.", + case_sensitive=False, + ), +] +EncoderArg = Annotated[ + str, + model_arg_option( + "--encoder", + compatibility=["proj", "hier"], + help="For projection and hierarchical models, which encoder model to use.", + ), +] +UsePriorTasksArg = Annotated[ + bool, + model_arg_option( + "--use_prior_tasks", + compatibility=["proj", "cnn"], + help="For projection and CNN models, whether to use the output of prior tasks as an 
input to subsequent ones.", + ), +] +EncoderLayerArg = Annotated[ + int, + model_arg_option( + "--encoder_layer", + "--layer", + compatibility=["proj", "hier"], + help="For projection and hierarchical models, which layer of the encoder to use for representation.", + ), +] +ClassificationModeArg = Annotated[ + ClassificationMode, + model_arg_option( + "--classification_mode", + compatibility=["proj"], + help="For projection models, chooses whether to classify from the [CLS] token or from a token span tagged with <e> </e>.", + case_sensitive=False, + ), +] +RelationAttnHeadsArg = Annotated[ + int, + model_arg_option( + "--relation_attn_heads", + compatibility=["proj"], + help="For projection models, the number of relation attention heads to use for relation extraction tasks.", + ), +] +RelationAttnHeadDimArg = Annotated[ + int, + model_arg_option( + "--relation_attn_head_dim", + compatibility=["proj"], + help="For projection models, the dimension of attention heads for relation extraction tasks.", + ), +] +HierLayersArg = Annotated[ + int, + model_arg_option( + "--hier_layers", + compatibility=["hier"], + help="For hierarchical models, the number of hierarchical layers.", + ), +] +HierUseLayerArg = Annotated[ + int, + model_arg_option( + "--hier_use_layer", + compatibility=["hier"], + help="For hierarchical models, the layer to use for classification.", + ), +] +HierHiddenDimArg = Annotated[ + int, + model_arg_option( + "--hier_hidden_dim", + compatibility=["hier"], + help="For hierarchical models, the hidden dimension of the FFN in each layer.", + ), +] +HierHeadsArg = Annotated[ + int, + model_arg_option( + "--hier_heads", + compatibility=["hier"], + help="For hierarchical models, the number of attention heads.", + ), +] +HierQKDimArg = Annotated[ + int, + model_arg_option( + "--hier_qk_dim", + compatibility=["hier"], + help="For hierarchical models, the dimension of the query and key vectors.", + ), +] +HierVDimArg = Annotated[ + int, + model_arg_option( + "--hier_v_dim", + 
compatibility=["hier"], + help="For hierarchical models, the dimension of the value vectors.", + ), +] +DropoutArg = Annotated[ + # TODO(ian): should this be available for proj models too? + float, + model_arg_option( + "--dropout", + compatibility=["hier", "cnn", "lstm"], + help="For hierarchical, CNN, and LSTM models, the dropout probability.", + ), +] +EmbedDimArg = Annotated[ + int, + model_arg_option( + "--embed_dim", + compatibility=["cnn", "lstm"], + help="For CNN and LSTM models, the embedding dimension.", + ), +] +CnnNumFiltersArg = Annotated[ + int, + model_arg_option( + "--cnn_num_filters", + compatibility=["cnn"], + help="For CNN models, the number of filters per filter size.", ), - add_help_option=False, -) -@click.argument("train_args", nargs=-1, type=click.UNPROCESSED) -def train_command(train_args: list[str]): - "Fine-tune models for clinical NLP." +] +CnnFilterSizesArg = Annotated[ + str, + model_arg_option( + "--cnn_filter_sizes", + compatibility=["cnn"], + help="For CNN models, a comma-separated list of filter sizes to use.", + ), +] +LstmHiddenSizeArg = Annotated[ + int, + model_arg_option( + "--lstm_hidden_size", + compatibility=["lstm"], + help="For LSTM models, the dimension of the hidden layer.", + ), +] + +##### DATA ARGS ##### +DataDirArg = Annotated[ + str, + data_arg_option( + "--data_dir", help="Path to a directory containing CNLPT-formatted data." + ), +] +TaskNamesArg = Annotated[ + Union[list[str], None], + data_arg_option( + "--task", + "-t", + help="The name of a task in the dataset to train on. Can be specified multiple times to target more than one task. Defaults to all tasks.", + ), +] +TokenizerArg = Annotated[ + Union[str, None], + data_arg_option( + "--tokenizer", + help=f'Name or path to a model to use for tokenization. 
For projection and hierarchical models, this will default to the --encoder if left unspecified; otherwise defaults to "{DEFAULT_ENCODER}".', + ), +] +TruncationSideArg = Annotated[ + TruncationSide, + data_arg_option( + "--truncation_side", + help="Which side to perform truncation when tokenizing. Note that hierarchical models don't support left-side truncation.", + compatibility=["cnn", "lstm", "proj"], + case_sensitive=False, + ), +] +MaxSeqLengthArg = Annotated[ + int, + data_arg_option( + "--max_seq_length", help="Maximum sequence length for tokenization." + ), +] +OverwriteDataCacheArg = Annotated[ + bool, + data_arg_option( + "--overwrite_data_cache", + help="Overwrite the data cache to force re-preprocessing of the data.", + ), +] +MaxTrainArg = Annotated[ + Union[int, None], + data_arg_option("--max_train", help="Limit the number of training samples to use."), +] +MaxEvalArg = Annotated[ + Union[int, None], + data_arg_option("--max_eval", help="Limit the number of eval samples to use."), +] +MaxTestArg = Annotated[ + Union[int, None], + data_arg_option("--max_test", help="Limit the number of test samples to use."), +] +AllowDisjointLabelsArg = Annotated[ + bool, + data_arg_option( + "--allow_disjoint_labels", + help="Allow disjoint label sets for tasks in different data splits. Can be useful for debugging.", + ), +] +CharacterLevelArg = Annotated[ + bool, + data_arg_option( + "--character_level", + help="Whether the dataset should be processed at the character level (otherwise will be processed at the token level).", + ), +] +HierChunkLenArg = Annotated[ + Union[int, None], + data_arg_option("--hier_chunk_len", help="Chunk length for hierarchical models."), +] +HierNumChunksArg = Annotated[ + Union[int, None], + data_arg_option( + "--hier_num_chunks", help="Number of chunks for hierarchical models." 
+ ), +] +HierPrependEmptyChunkArg = Annotated[ + Union[int, None], + data_arg_option( + "--hier_prepend_empty_chunk", + help="Whether to prepend an empty chunk for hierarchical models.", + ), +] + +##### TRAINING ARGS ##### +WeightClassesArg = Annotated[bool, training_arg_option("weight_classes")] +FinalTaskWeightArg = Annotated[float, training_arg_option("final_task_weight")] +FreezeEncoderArg = Annotated[float, training_arg_option("freeze_encoder")] +BiasFitArg = Annotated[bool, training_arg_option("bias_fit")] +ReportProbsArg = Annotated[bool, training_arg_option("report_probs")] +EvalsPerEpochArg = Annotated[int, training_arg_option("evals_per_epoch")] +RichDisplayArg = Annotated[bool, training_arg_option("rich_display")] +LoggingStrategyArg = Annotated[ + IntervalStrategy, training_arg_option("logging_strategy") +] +LoggingFirstStepArg = Annotated[bool, training_arg_option("logging_first_step")] +CacheDirArg = Annotated[Union[str, None], training_arg_option("cache_dir")] + + +def train( + ctx: typer.Context, + # ------------------ # + # MODEL ARGS # + # ------------------ # + model_type: ModelTypeArg = ..., + encoder_name: EncoderArg = DEFAULT_ENCODER, + use_prior_tasks: UsePriorTasksArg = False, + encoder_layer: EncoderLayerArg = -1, + classification_mode: ClassificationModeArg = "cls", + relation_attn_heads: RelationAttnHeadsArg = 12, + relation_attn_head_dim: RelationAttnHeadDimArg = 64, + hier_layers: HierLayersArg = 8, + hier_use_layer: HierUseLayerArg = -1, + hier_hidden_dim: HierHiddenDimArg = 2048, + hier_heads: HierHeadsArg = 8, + hier_qk_dim: HierQKDimArg = 8, + hier_v_dim: HierVDimArg = 96, + dropout: DropoutArg = 0.1, + embed_dim: EmbedDimArg = 100, + cnn_num_filters: CnnNumFiltersArg = 25, + cnn_filter_sizes: CnnFilterSizesArg = "1,2,3", + lstm_hidden_size: LstmHiddenSizeArg = 100, + # ----------------- # + # DATA ARGS # + # ----------------- # + data_dir: DataDirArg = ..., + task_names: TaskNamesArg = None, + tokenizer: TokenizerArg = 
DEFAULT_ENCODER, + truncation_side: TruncationSideArg = "right", + max_seq_length: MaxSeqLengthArg = 128, + overwrite_data_cache: OverwriteDataCacheArg = False, + max_train: MaxTrainArg = None, + max_eval: MaxEvalArg = None, + max_test: MaxTestArg = None, + allow_disjoint_labels: AllowDisjointLabelsArg = False, + character_level: CharacterLevelArg = False, + hier_chunk_len: HierChunkLenArg = None, + hier_num_chunks: HierNumChunksArg = None, + hier_prepend_empty_chunk: HierPrependEmptyChunkArg = None, + # --------------------- # + # TRAINING ARGS # + # --------------------- # + weight_classes: WeightClassesArg = False, + final_task_weight: FinalTaskWeightArg = 1.0, + freeze_encoder: FreezeEncoderArg = 0.0, + bias_fit: BiasFitArg = False, + report_probs: ReportProbsArg = False, + evals_per_epoch: EvalsPerEpochArg = 0, + rich_display: RichDisplayArg = True, + logging_strategy: LoggingStrategyArg = "epoch", + logging_first_step: LoggingFirstStepArg = True, + cache_dir: CacheDirArg = None, + # --------------------- # + **kwargs, +): + # TODO(ian): it's probably worth making this docstring pretty descriptive + """Run the cnlp_transformers training system.""" + + # If the tokenizer wasn't explicitly specified and this is a model + # that accepts an encoder, use the encoder's tokenizer. 
+ if ctx.get_parameter_source("tokenizer") != ParameterSource.COMMANDLINE and ( + model_type in (ModelType.HIER, ModelType.PROJ) + ): + tokenizer = encoder_name + + dataset = CnlpDataset( + data_dir=data_dir, + tokenizer=tokenizer, + task_names=task_names, + hier_config=( + HierarchicalDataConfig( + hier_chunk_len, hier_num_chunks, hier_prepend_empty_chunk + ) + if model_type == ModelType.HIER + else None + ), + truncation_side=truncation_side, + max_seq_length=max_seq_length, + use_data_cache=not overwrite_data_cache, + max_train=max_train, + max_eval=max_eval, + max_test=max_test, + allow_disjoint_labels=allow_disjoint_labels, + character_level=character_level, + hf_cache_dir=cache_dir, + ) + + # Since CnlpTrainingArguments inherits from the transformers TrainingArguments, + # rather than maintain explicit arguments for all TrainingArguments fields we'll + # just leave them unprocessed and pass any unknown args to the transformers parser. + hf_args_parser = HfArgumentParser(TrainingArguments) + hf_training_args, extra_args = hf_args_parser.parse_known_args(ctx.args) + + if len(extra_args) > 0: + raise typer.BadParameter(f"unrecognized arguments: {extra_args!s}", ctx) + + # Combine our args with the args parsed by transformers. We must use the `|` + # operator in this order so that our args take precedence. 
+ training_args = CnlpTrainingArguments( + **( + vars(hf_training_args) + | dict( + weight_classes=weight_classes, + final_task_weight=final_task_weight, + freeze_encoder=freeze_encoder, + bias_fit=bias_fit, + report_probs=report_probs, + evals_per_epoch=evals_per_epoch, + rich_display=rich_display, + logging_strategy=logging_strategy, + logging_first_step=logging_first_step, + ) + ) + ) + + if model_type == ModelType.CNN: + config = CnnModelConfig( + tasks=list(dataset.tasks), + vocab_size=len(dataset.tokenizer), + use_prior_tasks=use_prior_tasks, + embed_dim=embed_dim, + num_filters=cnn_num_filters, + filter_sizes=tuple([int(s.strip()) for s in cnn_filter_sizes.split(",")]), + dropout=dropout, + ) + elif model_type == ModelType.LSTM: + config = LstmModelConfig( + tasks=list(dataset.tasks), + vocab_size=len(dataset.tokenizer), + embed_dim=embed_dim, + hidden_size=lstm_hidden_size, + dropout=dropout, + ) + elif model_type == ModelType.HIER: + config = HierarchicalModelConfig( + tasks=list(dataset.tasks), + vocab_size=len(dataset.tokenizer), + encoder_name=encoder_name, + layer=hier_use_layer, + n_layers=hier_layers, + d_inner=hier_hidden_dim, + n_head=hier_heads, + d_k=hier_qk_dim, + d_v=hier_v_dim, + dropout=dropout, + ) + elif model_type == ModelType.PROJ: + config = ProjectionModelConfig( + tasks=list(dataset.tasks), + vocab_size=len(dataset.tokenizer), + encoder_name=encoder_name, + encoder_layer=encoder_layer, + use_prior_tasks=use_prior_tasks, + classification_mode=classification_mode, + num_rel_attention_heads=relation_attn_heads, + rel_attention_head_dims=relation_attn_head_dim, + character_level=character_level, + ) + + model_init_kwargs = {} + if weight_classes: + model_init_kwargs["class_weights"] = dataset.get_class_weights( + training_args.device + ) + if freeze_encoder > 0: + model_init_kwargs["freeze"] = freeze_encoder + if final_task_weight != 1.0: + model_init_kwargs["final_task_weight"] = final_task_weight + if bias_fit: + 
model_init_kwargs["bias_fit"] = True + + model: Union[CnnModel, LstmModel, HierarchicalModel, ProjectionModel] = ( + AutoModel.from_config(config, **model_init_kwargs) + ) + train_system = CnlpTrainSystem(model, dataset, training_args) + train_system.train() + - train_system(argv=list(train_args)) +TRAIN_EPILOG = """[red]More training arguments are available, see the +[b blue][link=https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments]HF Transformers documentation[/link][/b blue].""" diff --git a/uv.lock b/uv.lock index f04e58a3..082cad6d 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9, <3.13" resolution-markers = [ "python_full_version >= '3.12'", @@ -179,6 +179,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a1/ee/48ca1a7c89ffec8b6a0c5d02b89c305671d5ffd8d3c94acf8b8c408575bb/anyio-4.9.0-py3-none-any.whl", hash = "sha256:9f76d541cad6e36af7beb62e978876f3b41e3e04f2c1fbf0884604c0a9c4d93c", size = 100916, upload-time = "2025-03-17T00:02:52.713Z" }, ] +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170, upload-time = "2024-02-06T09:43:11.258Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321, upload-time = "2024-02-06T09:43:09.663Z" }, +] + +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978, upload-time = "2024-11-30T04:30:14.439Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918, upload-time = "2024-11-30T04:30:10.946Z" }, +] + [[package]] name = "async-timeout" version = "5.0.1" @@ -249,6 +267,8 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/fc/97/c783634659c2920c3fc70419e3af40972dbaf758daa229a7d6ea6135c90d/cffi-1.17.1.tar.gz", hash = "sha256:1c39c6016c32bc48dd54561950ebd6836e1670f2ae46128f67cf49e789c52824", size = 516621, upload-time = "2024-09-04T20:45:21.852Z" } wheels = [ + { url = "https://files.pythonhosted.org/packages/90/07/f44ca684db4e4f08a3fdc6eeb9a0d15dc6883efc7b8c90357fdbf74e186c/cffi-1.17.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:df8b1c11f177bc2313ec4b2d46baec87a5f3e71fc8b45dab2ee7cae86d9aba14", size = 182191, upload-time = "2024-09-04T20:43:30.027Z" }, + { url = "https://files.pythonhosted.org/packages/08/fd/cc2fedbd887223f9f5d170c96e57cbf655df9831a6546c1727ae13fa977a/cffi-1.17.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8f2cdc858323644ab277e9bb925ad72ae0e67f69e804f4898c070998d50b1a67", size = 178592, upload-time = "2024-09-04T20:43:32.108Z" }, { url = "https://files.pythonhosted.org/packages/de/cc/4635c320081c78d6ffc2cab0a76025b691a91204f4aa317d568ff9280a2d/cffi-1.17.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edae79245293e15384b51f88b00613ba9f7198016a5948b5dddf4917d4d26382", size = 426024, upload-time = "2024-09-04T20:43:34.186Z" }, { url = 
"https://files.pythonhosted.org/packages/b6/7b/3b2b250f3aab91abe5f8a51ada1b717935fdaec53f790ad4100fe2ec64d1/cffi-1.17.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45398b671ac6d70e67da8e4224a065cec6a93541bb7aebe1b198a61b58c7b702", size = 448188, upload-time = "2024-09-04T20:43:36.286Z" }, { url = "https://files.pythonhosted.org/packages/d3/48/1b9283ebbf0ec065148d8de05d647a986c5f22586b18120020452fff8f5d/cffi-1.17.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad9413ccdeda48c5afdae7e4fa2192157e991ff761e7ab8fdd8926f40b160cc3", size = 455571, upload-time = "2024-09-04T20:43:38.586Z" }, @@ -257,6 +277,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ab/a0/62f00bcb411332106c02b663b26f3545a9ef136f80d5df746c05878f8c4b/cffi-1.17.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:045d61c734659cc045141be4bae381a41d89b741f795af1dd018bfb532fd0df8", size = 461325, upload-time = "2024-09-04T20:43:43.117Z" }, { url = "https://files.pythonhosted.org/packages/36/83/76127035ed2e7e27b0787604d99da630ac3123bfb02d8e80c633f218a11d/cffi-1.17.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:6883e737d7d9e4899a8a695e00ec36bd4e5e4f18fabe0aca0efe0a4b44cdb13e", size = 438784, upload-time = "2024-09-04T20:43:45.256Z" }, { url = "https://files.pythonhosted.org/packages/21/81/a6cd025db2f08ac88b901b745c163d884641909641f9b826e8cb87645942/cffi-1.17.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:6b8b4a92e1c65048ff98cfe1f735ef8f1ceb72e3d5f0c25fdb12087a23da22be", size = 461564, upload-time = "2024-09-04T20:43:46.779Z" }, + { url = "https://files.pythonhosted.org/packages/f8/fe/4d41c2f200c4a457933dbd98d3cf4e911870877bd94d9656cc0fcb390681/cffi-1.17.1-cp310-cp310-win32.whl", hash = "sha256:c9c3d058ebabb74db66e431095118094d06abf53284d9c81f27300d0e0d8bc7c", size = 171804, upload-time = "2024-09-04T20:43:48.186Z" }, + { url = 
"https://files.pythonhosted.org/packages/d1/b6/0b0f5ab93b0df4acc49cae758c81fe4e5ef26c3ae2e10cc69249dfd8b3ab/cffi-1.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:0f048dcf80db46f0098ccac01132761580d28e28bc0f78ae0d58048063317e15", size = 181299, upload-time = "2024-09-04T20:43:49.812Z" }, + { url = "https://files.pythonhosted.org/packages/6b/f4/927e3a8899e52a27fa57a48607ff7dc91a9ebe97399b357b85a0c7892e00/cffi-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a45e3c6913c5b87b3ff120dcdc03f6131fa0065027d0ed7ee6190736a74cd401", size = 182264, upload-time = "2024-09-04T20:43:51.124Z" }, + { url = "https://files.pythonhosted.org/packages/6c/f5/6c3a8efe5f503175aaddcbea6ad0d2c96dad6f5abb205750d1b3df44ef29/cffi-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:30c5e0cb5ae493c04c8b42916e52ca38079f1b235c2f8ae5f4527b963c401caf", size = 178651, upload-time = "2024-09-04T20:43:52.872Z" }, { url = "https://files.pythonhosted.org/packages/94/dd/a3f0118e688d1b1a57553da23b16bdade96d2f9bcda4d32e7d2838047ff7/cffi-1.17.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f75c7ab1f9e4aca5414ed4d8e5c0e303a34f4421f8a0d47a4d019ceff0ab6af4", size = 445259, upload-time = "2024-09-04T20:43:56.123Z" }, { url = "https://files.pythonhosted.org/packages/2e/ea/70ce63780f096e16ce8588efe039d3c4f91deb1dc01e9c73a287939c79a6/cffi-1.17.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a1ed2dd2972641495a3ec98445e09766f077aee98a1c896dcb4ad0d303628e41", size = 469200, upload-time = "2024-09-04T20:43:57.891Z" }, { url = "https://files.pythonhosted.org/packages/1c/a0/a4fa9f4f781bda074c3ddd57a572b060fa0df7655d2a4247bbe277200146/cffi-1.17.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:46bf43160c1a35f7ec506d254e5c890f3c03648a4dbac12d624e4490a7046cd1", size = 477235, upload-time = "2024-09-04T20:44:00.18Z" }, @@ -265,6 +289,10 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/1a/52/d9a0e523a572fbccf2955f5abe883cfa8bcc570d7faeee06336fbd50c9fc/cffi-1.17.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a9b15d491f3ad5d692e11f6b71f7857e7835eb677955c00cc0aefcd0669adaf6", size = 477999, upload-time = "2024-09-04T20:44:05.023Z" }, { url = "https://files.pythonhosted.org/packages/44/74/f2a2460684a1a2d00ca799ad880d54652841a780c4c97b87754f660c7603/cffi-1.17.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:de2ea4b5833625383e464549fec1bc395c1bdeeb5f25c4a3a82b5a8c756ec22f", size = 454242, upload-time = "2024-09-04T20:44:06.444Z" }, { url = "https://files.pythonhosted.org/packages/f8/4a/34599cac7dfcd888ff54e801afe06a19c17787dfd94495ab0c8d35fe99fb/cffi-1.17.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fc48c783f9c87e60831201f2cce7f3b2e4846bf4d8728eabe54d60700b318a0b", size = 478604, upload-time = "2024-09-04T20:44:08.206Z" }, + { url = "https://files.pythonhosted.org/packages/34/33/e1b8a1ba29025adbdcda5fb3a36f94c03d771c1b7b12f726ff7fef2ebe36/cffi-1.17.1-cp311-cp311-win32.whl", hash = "sha256:85a950a4ac9c359340d5963966e3e0a94a676bd6245a4b55bc43949eee26a655", size = 171727, upload-time = "2024-09-04T20:44:09.481Z" }, + { url = "https://files.pythonhosted.org/packages/3d/97/50228be003bb2802627d28ec0627837ac0bf35c90cf769812056f235b2d1/cffi-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:caaf0640ef5f5517f49bc275eca1406b0ffa6aa184892812030f04c2abf589a0", size = 181400, upload-time = "2024-09-04T20:44:10.873Z" }, + { url = "https://files.pythonhosted.org/packages/5a/84/e94227139ee5fb4d600a7a4927f322e1d4aea6fdc50bd3fca8493caba23f/cffi-1.17.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:805b4371bf7197c329fcb3ead37e710d1bca9da5d583f5073b799d5c5bd1eee4", size = 183178, upload-time = "2024-09-04T20:44:12.232Z" }, + { url = "https://files.pythonhosted.org/packages/da/ee/fb72c2b48656111c4ef27f0f91da355e130a923473bf5ee75c5643d00cca/cffi-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = 
"sha256:733e99bc2df47476e3848417c5a4540522f234dfd4ef3ab7fafdf555b082ec0c", size = 178840, upload-time = "2024-09-04T20:44:13.739Z" }, { url = "https://files.pythonhosted.org/packages/cc/b6/db007700f67d151abadf508cbfd6a1884f57eab90b1bb985c4c8c02b0f28/cffi-1.17.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1257bdabf294dceb59f5e70c64a3e2f462c30c7ad68092d01bbbfb1c16b1ba36", size = 454803, upload-time = "2024-09-04T20:44:15.231Z" }, { url = "https://files.pythonhosted.org/packages/1a/df/f8d151540d8c200eb1c6fba8cd0dfd40904f1b0682ea705c36e6c2e97ab3/cffi-1.17.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da95af8214998d77a98cc14e3a3bd00aa191526343078b530ceb0bd710fb48a5", size = 478850, upload-time = "2024-09-04T20:44:17.188Z" }, { url = "https://files.pythonhosted.org/packages/28/c0/b31116332a547fd2677ae5b78a2ef662dfc8023d67f41b2a83f7c2aa78b1/cffi-1.17.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d63afe322132c194cf832bfec0dc69a99fb9bb6bbd550f161a49e9e855cc78ff", size = 485729, upload-time = "2024-09-04T20:44:18.688Z" }, @@ -272,6 +300,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b2/d5/da47df7004cb17e4955df6a43d14b3b4ae77737dff8bf7f8f333196717bf/cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b62ce867176a75d03a665bad002af8e6d54644fad99a3c70905c543130e39d93", size = 479424, upload-time = "2024-09-04T20:44:21.673Z" }, { url = "https://files.pythonhosted.org/packages/0b/ac/2a28bcf513e93a219c8a4e8e125534f4f6db03e3179ba1c45e949b76212c/cffi-1.17.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:386c8bf53c502fff58903061338ce4f4950cbdcb23e2902d86c0f722b786bbe3", size = 484568, upload-time = "2024-09-04T20:44:23.245Z" }, { url = 
"https://files.pythonhosted.org/packages/d4/38/ca8a4f639065f14ae0f1d9751e70447a261f1a30fa7547a828ae08142465/cffi-1.17.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4ceb10419a9adf4460ea14cfd6bc43d08701f0835e979bf821052f1805850fe8", size = 488736, upload-time = "2024-09-04T20:44:24.757Z" }, + { url = "https://files.pythonhosted.org/packages/86/c5/28b2d6f799ec0bdecf44dced2ec5ed43e0eb63097b0f58c293583b406582/cffi-1.17.1-cp312-cp312-win32.whl", hash = "sha256:a08d7e755f8ed21095a310a693525137cfe756ce62d066e53f502a83dc550f65", size = 172448, upload-time = "2024-09-04T20:44:26.208Z" }, + { url = "https://files.pythonhosted.org/packages/50/b9/db34c4755a7bd1cb2d1603ac3863f22bcecbd1ba29e5ee841a4bc510b294/cffi-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:51392eae71afec0d0c8fb1a53b204dbb3bcabcb3c9b807eedf3e1e6ccf2de903", size = 181976, upload-time = "2024-09-04T20:44:27.578Z" }, + { url = "https://files.pythonhosted.org/packages/b9/ea/8bb50596b8ffbc49ddd7a1ad305035daa770202a6b782fc164647c2673ad/cffi-1.17.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b2ab587605f4ba0bf81dc0cb08a41bd1c0a5906bd59243d56bad7668a6fc6c16", size = 182220, upload-time = "2024-09-04T20:45:01.577Z" }, + { url = "https://files.pythonhosted.org/packages/ae/11/e77c8cd24f58285a82c23af484cf5b124a376b32644e445960d1a4654c3a/cffi-1.17.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:28b16024becceed8c6dfbc75629e27788d8a3f9030691a1dbf9821a128b22c36", size = 178605, upload-time = "2024-09-04T20:45:03.837Z" }, { url = "https://files.pythonhosted.org/packages/ed/65/25a8dc32c53bf5b7b6c2686b42ae2ad58743f7ff644844af7cdb29b49361/cffi-1.17.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1d599671f396c4723d016dbddb72fe8e0397082b0a77a4fab8028923bec050e8", size = 424910, upload-time = "2024-09-04T20:45:05.315Z" }, { url = 
"https://files.pythonhosted.org/packages/42/7a/9d086fab7c66bd7c4d0f27c57a1b6b068ced810afc498cc8c49e0088661c/cffi-1.17.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca74b8dbe6e8e8263c0ffd60277de77dcee6c837a3d0881d8c1ead7268c9e576", size = 447200, upload-time = "2024-09-04T20:45:06.903Z" }, { url = "https://files.pythonhosted.org/packages/da/63/1785ced118ce92a993b0ec9e0d0ac8dc3e5dbfbcaa81135be56c69cabbb6/cffi-1.17.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f7f5baafcc48261359e14bcd6d9bff6d4b28d9103847c9e136694cb0501aef87", size = 454565, upload-time = "2024-09-04T20:45:08.975Z" }, @@ -280,6 +312,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/95/b34462f3ccb09c2594aa782d90a90b045de4ff1f70148ee79c69d37a0a5a/cffi-1.17.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9755e4345d1ec879e3849e62222a18c7174d65a6a92d5b346b1863912168b595", size = 460486, upload-time = "2024-09-04T20:45:13.935Z" }, { url = "https://files.pythonhosted.org/packages/fc/fc/a1e4bebd8d680febd29cf6c8a40067182b64f00c7d105f8f26b5bc54317b/cffi-1.17.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f1e22e8c4419538cb197e4dd60acc919d7696e5ef98ee4da4e01d3f8cfa4cc5a", size = 437911, upload-time = "2024-09-04T20:45:15.696Z" }, { url = "https://files.pythonhosted.org/packages/e6/c3/21cab7a6154b6a5ea330ae80de386e7665254835b9e98ecc1340b3a7de9a/cffi-1.17.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c03e868a0b3bc35839ba98e74211ed2b05d2119be4e8a0f224fba9384f1fe02e", size = 460632, upload-time = "2024-09-04T20:45:17.284Z" }, + { url = "https://files.pythonhosted.org/packages/cb/b5/fd9f8b5a84010ca169ee49f4e4ad6f8c05f4e3545b72ee041dbbcb159882/cffi-1.17.1-cp39-cp39-win32.whl", hash = "sha256:e31ae45bc2e29f6b2abd0de1cc3b9d5205aa847cafaecb8af1476a609a2f6eb7", size = 171820, upload-time = "2024-09-04T20:45:18.762Z" }, + { url = 
"https://files.pythonhosted.org/packages/8c/52/b08750ce0bce45c143e1b5d7357ee8c55341b52bdef4b0f081af1eb248c2/cffi-1.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:d016c76bdd850f3c626af19b0542c9677ba156e4ee4fccfdd7848803533ef662", size = 181290, upload-time = "2024-09-04T20:45:20.226Z" }, ] [[package]] @@ -369,7 +403,6 @@ name = "cnlp-transformers" source = { editable = "." } dependencies = [ { name = "anaforatools" }, - { name = "click" }, { name = "datasets" }, { name = "fastapi" }, { name = "httpx" }, @@ -380,10 +413,12 @@ dependencies = [ { name = "polars" }, { name = "pydantic" }, { name = "requests" }, + { name = "rich" }, { name = "scikit-learn" }, { name = "seqeval" }, { name = "torch" }, { name = "transformers", extra = ["torch"] }, + { name = "typer" }, { name = "uvicorn", extra = ["standard"] }, ] @@ -395,6 +430,8 @@ build = [ ] dev = [ { name = "build" }, + { name = "ipykernel" }, + { name = "ipywidgets" }, { name = "lorem-text" }, { name = "pip" }, { name = "pre-commit" }, @@ -423,6 +460,10 @@ docs = [ lint = [ { name = "ruff" }, ] +notebooks = [ + { name = "ipykernel" }, + { name = "ipywidgets" }, +] test = [ { name = "lorem-text" }, { name = "pytest" }, @@ -431,7 +472,6 @@ test = [ [package.metadata] requires-dist = [ { name = "anaforatools", specifier = "~=1.1.0" }, - { name = "click", specifier = ">=8.1.7" }, { name = "datasets", specifier = "~=2.21.0" }, { name = "fastapi", specifier = "~=0.115.2" }, { name = "httpx", specifier = ">=0.27.2" }, @@ -442,10 +482,12 @@ requires-dist = [ { name = "polars", specifier = ">=1.30.0" }, { name = "pydantic", specifier = "~=1.10.8" }, { name = "requests", specifier = "~=2.32.2" }, + { name = "rich", specifier = ">=14.0.0" }, { name = "scikit-learn", specifier = "~=1.5.2" }, { name = "seqeval", specifier = "~=1.2.2" }, { name = "torch", specifier = ">=2.6" }, { name = "transformers", extras = ["torch"], specifier = "~=4.51" }, + { name = "typer", specifier = "~=0.16.0" }, { name = "uvicorn", extras = 
["standard"], specifier = "~=0.32.0" }, ] @@ -457,6 +499,8 @@ build = [ ] dev = [ { name = "build" }, + { name = "ipykernel" }, + { name = "ipywidgets" }, { name = "lorem-text", specifier = ">=3.0" }, { name = "pip", specifier = ">=21.3" }, { name = "pre-commit" }, @@ -475,6 +519,10 @@ docs = [ { name = "sphobjinv" }, ] lint = [{ name = "ruff", specifier = "==0.11.8" }] +notebooks = [ + { name = "ipykernel" }, + { name = "ipywidgets" }, +] test = [ { name = "lorem-text", specifier = ">=3.0" }, { name = "pytest" }, @@ -489,6 +537,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, ] +[[package]] +name = "comm" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/13/7d740c5849255756bc17888787313b61fd38a0a8304fc4f073dfc46122aa/comm-0.2.3.tar.gz", hash = "sha256:2dc8048c10962d55d7ad693be1e7045d891b7ce8d999c97963a5e3e99c055971", size = 6319, upload-time = "2025-07-25T14:02:04.452Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/97/891a0971e1e4a8c5d2b20bbe0e524dc04548d2307fee33cdeba148fd4fc7/comm-0.2.3-py3-none-any.whl", hash = "sha256:c615d91d75f7f04f095b30d1c1711babd43bdc6419c1be9886a85f2f4e489417", size = 7294, upload-time = "2025-07-25T14:02:02.896Z" }, +] + [[package]] name = "cryptography" version = "44.0.2" @@ -551,6 +608,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/72/b3/33c4ad44fa020e3757e9b2fad8a5de53d9079b501e6bbc45bdd18f82f893/datasets-2.21.0-py3-none-any.whl", hash = "sha256:25e4e097110ce28824b746a107727ada94024cba11db8bc588d468414692b65a", size = 527251, upload-time = "2024-08-14T06:40:39.612Z" }, ] +[[package]] +name = "debugpy" +version = "1.8.16" +source = { registry = 
"https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/d4/722d0bcc7986172ac2ef3c979ad56a1030e3afd44ced136d45f8142b1f4a/debugpy-1.8.16.tar.gz", hash = "sha256:31e69a1feb1cf6b51efbed3f6c9b0ef03bc46ff050679c4be7ea6d2e23540870", size = 1643809, upload-time = "2025-08-06T18:00:02.647Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/fd/f1b75ebc61d90882595b81d808efd3573c082e1c3407850d9dccac4ae904/debugpy-1.8.16-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:2a3958fb9c2f40ed8ea48a0d34895b461de57a1f9862e7478716c35d76f56c65", size = 2085511, upload-time = "2025-08-06T18:00:05.067Z" }, + { url = "https://files.pythonhosted.org/packages/df/5e/c5c1934352871128b30a1a144a58b5baa546e1b57bd47dbed788bad4431c/debugpy-1.8.16-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5ca7314042e8a614cc2574cd71f6ccd7e13a9708ce3c6d8436959eae56f2378", size = 3562094, upload-time = "2025-08-06T18:00:06.66Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d5/2ebe42377e5a78dc786afc25e61ee83c5628d63f32dfa41092597d52fe83/debugpy-1.8.16-cp310-cp310-win32.whl", hash = "sha256:8624a6111dc312ed8c363347a0b59c5acc6210d897e41a7c069de3c53235c9a6", size = 5234277, upload-time = "2025-08-06T18:00:08.429Z" }, + { url = "https://files.pythonhosted.org/packages/54/f8/e774ad16a60b9913213dbabb7472074c5a7b0d84f07c1f383040a9690057/debugpy-1.8.16-cp310-cp310-win_amd64.whl", hash = "sha256:fee6db83ea5c978baf042440cfe29695e1a5d48a30147abf4c3be87513609817", size = 5266011, upload-time = "2025-08-06T18:00:10.162Z" }, + { url = "https://files.pythonhosted.org/packages/63/d6/ad70ba8b49b23fa286fb21081cf732232cc19374af362051da9c7537ae52/debugpy-1.8.16-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:67371b28b79a6a12bcc027d94a06158f2fde223e35b5c4e0783b6f9d3b39274a", size = 2184063, upload-time = "2025-08-06T18:00:11.885Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/49/7b03e88dea9759a4c7910143f87f92beb494daaae25560184ff4ae883f9e/debugpy-1.8.16-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b2abae6dd02523bec2dee16bd6b0781cccb53fd4995e5c71cc659b5f45581898", size = 3134837, upload-time = "2025-08-06T18:00:13.782Z" }, + { url = "https://files.pythonhosted.org/packages/5d/52/b348930316921de7565fbe37a487d15409041713004f3d74d03eb077dbd4/debugpy-1.8.16-cp311-cp311-win32.whl", hash = "sha256:f8340a3ac2ed4f5da59e064aa92e39edd52729a88fbde7bbaa54e08249a04493", size = 5159142, upload-time = "2025-08-06T18:00:15.391Z" }, + { url = "https://files.pythonhosted.org/packages/d8/ef/9aa9549ce1e10cea696d980292e71672a91ee4a6a691ce5f8629e8f48c49/debugpy-1.8.16-cp311-cp311-win_amd64.whl", hash = "sha256:70f5fcd6d4d0c150a878d2aa37391c52de788c3dc680b97bdb5e529cb80df87a", size = 5183117, upload-time = "2025-08-06T18:00:17.251Z" }, + { url = "https://files.pythonhosted.org/packages/61/fb/0387c0e108d842c902801bc65ccc53e5b91d8c169702a9bbf4f7efcedf0c/debugpy-1.8.16-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:b202e2843e32e80b3b584bcebfe0e65e0392920dc70df11b2bfe1afcb7a085e4", size = 2511822, upload-time = "2025-08-06T18:00:18.526Z" }, + { url = "https://files.pythonhosted.org/packages/37/44/19e02745cae22bf96440141f94e15a69a1afaa3a64ddfc38004668fcdebf/debugpy-1.8.16-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64473c4a306ba11a99fe0bb14622ba4fbd943eb004847d9b69b107bde45aa9ea", size = 4230135, upload-time = "2025-08-06T18:00:19.997Z" }, + { url = "https://files.pythonhosted.org/packages/f3/0b/19b1ba5ee4412f303475a2c7ad5858efb99c90eae5ec627aa6275c439957/debugpy-1.8.16-cp312-cp312-win32.whl", hash = "sha256:833a61ed446426e38b0dd8be3e9d45ae285d424f5bf6cd5b2b559c8f12305508", size = 5281271, upload-time = "2025-08-06T18:00:21.281Z" }, + { url = 
"https://files.pythonhosted.org/packages/b1/e0/bc62e2dc141de53bd03e2c7cb9d7011de2e65e8bdcdaa26703e4d28656ba/debugpy-1.8.16-cp312-cp312-win_amd64.whl", hash = "sha256:75f204684581e9ef3dc2f67687c3c8c183fde2d6675ab131d94084baf8084121", size = 5323149, upload-time = "2025-08-06T18:00:23.033Z" }, + { url = "https://files.pythonhosted.org/packages/35/40/acdad5944e508d5e936979ad3e96e56b78ba6d7fa75aaffc4426cb921e12/debugpy-1.8.16-cp39-cp39-macosx_14_0_x86_64.whl", hash = "sha256:135ccd2b1161bade72a7a099c9208811c137a150839e970aeaf121c2467debe8", size = 2086696, upload-time = "2025-08-06T18:00:36.469Z" }, + { url = "https://files.pythonhosted.org/packages/2d/eb/8d6a2cf3b29e272b5dfebe6f384f8457977d4fd7a02dab2cae4d421dbae2/debugpy-1.8.16-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:211238306331a9089e253fd997213bc4a4c65f949271057d6695953254095376", size = 3557329, upload-time = "2025-08-06T18:00:38.189Z" }, + { url = "https://files.pythonhosted.org/packages/00/7b/63b9cc4d3c6980c702911c0f6a9748933ce4e4f16ae0ec4fdef7690f6662/debugpy-1.8.16-cp39-cp39-win32.whl", hash = "sha256:88eb9ffdfb59bf63835d146c183d6dba1f722b3ae2a5f4b9fc03e925b3358922", size = 5235114, upload-time = "2025-08-06T18:00:39.586Z" }, + { url = "https://files.pythonhosted.org/packages/05/cf/80947f57e0ef4d6e33ec9c3a109a542678eba465723bf8b599719238eb93/debugpy-1.8.16-cp39-cp39-win_amd64.whl", hash = "sha256:c2c47c2e52b40449552843b913786499efcc3dbc21d6c49287d939cd0dbc49fd", size = 5266799, upload-time = "2025-08-06T18:00:41.013Z" }, + { url = "https://files.pythonhosted.org/packages/52/57/ecc9ae29fa5b2d90107cd1d9bf8ed19aacb74b2264d986ae9d44fe9bdf87/debugpy-1.8.16-py2.py3-none-any.whl", hash = "sha256:19c9521962475b87da6f673514f7fd610328757ec993bf7ec0d8c96f9a325f9e", size = 5287700, upload-time = "2025-08-06T18:00:42.333Z" }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711, upload-time = "2025-02-24T04:41:34.073Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = "sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190, upload-time = "2025-02-24T04:41:32.565Z" }, +] + [[package]] name = "dill" version = "0.3.8" @@ -587,6 +678,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453, upload-time = "2024-07-12T22:25:58.476Z" }, ] +[[package]] +name = "executing" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693, upload-time = "2025-01-22T15:41:29.403Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702, upload-time = "2025-01-22T15:41:25.929Z" }, +] + [[package]] name = "fastapi" version = "0.115.12" @@ -852,6 +952,138 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050, 
upload-time = "2025-03-19T20:10:01.071Z" }, ] +[[package]] +name = "ipykernel" +version = "6.30.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython", version = "8.18.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bb/76/11082e338e0daadc89c8ff866185de11daf67d181901038f9e139d109761/ipykernel-6.30.1.tar.gz", hash = "sha256:6abb270161896402e76b91394fcdce5d1be5d45f456671e5080572f8505be39b", size = 166260, upload-time = "2025-08-04T15:47:35.018Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/c7/b445faca8deb954fe536abebff4ece5b097b923de482b26e78448c89d1dd/ipykernel-6.30.1-py3-none-any.whl", hash = "sha256:aa6b9fb93dca949069d8b85b6c79b2518e32ac583ae9c7d37c51d119e18b3fb4", size = 117484, upload-time = "2025-08-04T15:47:32.622Z" }, +] + +[[package]] +name = "ipython" +version = "8.18.1" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.10'", +] +dependencies = [ + { name = "colorama", marker = "python_full_version < '3.10' and sys_platform == 'win32'" }, + { name = "decorator", marker = "python_full_version < '3.10'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.10'" }, + { name = "jedi", marker = "python_full_version < 
'3.10'" }, + { name = "matplotlib-inline", marker = "python_full_version < '3.10'" }, + { name = "pexpect", marker = "python_full_version < '3.10' and sys_platform != 'win32'" }, + { name = "prompt-toolkit", marker = "python_full_version < '3.10'" }, + { name = "pygments", marker = "python_full_version < '3.10'" }, + { name = "stack-data", marker = "python_full_version < '3.10'" }, + { name = "traitlets", marker = "python_full_version < '3.10'" }, + { name = "typing-extensions", marker = "python_full_version < '3.10'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/b9/3ba6c45a6df813c09a48bac313c22ff83efa26cbb55011218d925a46e2ad/ipython-8.18.1.tar.gz", hash = "sha256:ca6f079bb33457c66e233e4580ebfc4128855b4cf6370dddd73842a9563e8a27", size = 5486330, upload-time = "2023-11-27T09:58:34.596Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/47/6b/d9fdcdef2eb6a23f391251fde8781c38d42acd82abe84d054cb74f7863b0/ipython-8.18.1-py3-none-any.whl", hash = "sha256:e8267419d72d81955ec1177f8a29aaa90ac80ad647499201119e2f05e99aa397", size = 808161, upload-time = "2023-11-27T09:58:30.538Z" }, +] + +[[package]] +name = "ipython" +version = "8.37.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version == '3.10.*'", +] +dependencies = [ + { name = "colorama", marker = "python_full_version == '3.10.*' and sys_platform == 'win32'" }, + { name = "decorator", marker = "python_full_version == '3.10.*'" }, + { name = "exceptiongroup", marker = "python_full_version == '3.10.*'" }, + { name = "jedi", marker = "python_full_version == '3.10.*'" }, + { name = "matplotlib-inline", marker = "python_full_version == '3.10.*'" }, + { name = "pexpect", marker = "python_full_version == '3.10.*' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit", marker = "python_full_version == '3.10.*'" }, + { name = "pygments", marker = "python_full_version == '3.10.*'" }, + { name = "stack-data", 
marker = "python_full_version == '3.10.*'" }, + { name = "traitlets", marker = "python_full_version == '3.10.*'" }, + { name = "typing-extensions", marker = "python_full_version == '3.10.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/31/10ac88f3357fc276dc8a64e8880c82e80e7459326ae1d0a211b40abf6665/ipython-8.37.0.tar.gz", hash = "sha256:ca815841e1a41a1e6b73a0b08f3038af9b2252564d01fc405356d34033012216", size = 5606088, upload-time = "2025-05-31T16:39:09.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/d0/274fbf7b0b12643cbbc001ce13e6a5b1607ac4929d1b11c72460152c9fc3/ipython-8.37.0-py3-none-any.whl", hash = "sha256:ed87326596b878932dbcb171e3e698845434d8c61b8d8cd474bf663041a9dcf2", size = 831864, upload-time = "2025-05-31T16:39:06.38Z" }, +] + +[[package]] +name = "ipython" +version = "9.4.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12'", + "python_full_version == '3.11.*'", +] +dependencies = [ + { name = "colorama", marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, + { name = "decorator", marker = "python_full_version >= '3.11'" }, + { name = "ipython-pygments-lexers", marker = "python_full_version >= '3.11'" }, + { name = "jedi", marker = "python_full_version >= '3.11'" }, + { name = "matplotlib-inline", marker = "python_full_version >= '3.11'" }, + { name = "pexpect", marker = "python_full_version >= '3.11' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit", marker = "python_full_version >= '3.11'" }, + { name = "pygments", marker = "python_full_version >= '3.11'" }, + { name = "stack-data", marker = "python_full_version >= '3.11'" }, + { name = "traitlets", marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/54/80/406f9e3bde1c1fd9bf5a0be9d090f8ae623e401b7670d8f6fdf2ab679891/ipython-9.4.0.tar.gz", hash = "sha256:c033c6d4e7914c3d9768aabe76bbe87ba1dc66a92a05db6bfa1125d81f2ee270", size = 4385338, upload-time = "2025-07-01T11:11:30.606Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/f8/0031ee2b906a15a33d6bfc12dd09c3dfa966b3cb5b284ecfb7549e6ac3c4/ipython-9.4.0-py3-none-any.whl", hash = "sha256:25850f025a446d9b359e8d296ba175a36aedd32e83ca9b5060430fe16801f066", size = 611021, upload-time = "2025-07-01T11:11:27.85Z" }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393, upload-time = "2025-01-17T11:24:34.505Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074, upload-time = "2025-01-17T11:24:33.271Z" }, +] + +[[package]] +name = "ipywidgets" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "comm" }, + { name = "ipython", version = "8.18.1", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.10'" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version == '3.10.*'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyterlab-widgets" }, 
+ { name = "traitlets" }, + { name = "widgetsnbextension" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/48/d3dbac45c2814cb73812f98dd6b38bbcc957a4e7bb31d6ea9c03bf94ed87/ipywidgets-8.1.7.tar.gz", hash = "sha256:15f1ac050b9ccbefd45dccfbb2ef6bed0029d8278682d569d71b8dd96bee0376", size = 116721, upload-time = "2025-05-05T12:42:03.489Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/6a/9166369a2f092bd286d24e6307de555d63616e8ddb373ebad2b5635ca4cd/ipywidgets-8.1.7-py3-none-any.whl", hash = "sha256:764f2602d25471c213919b8a1997df04bef869251db4ca8efba1b76b1bd9f7bb", size = 139806, upload-time = "2025-05-05T12:41:56.833Z" }, +] + [[package]] name = "jaraco-classes" version = "3.4.0" @@ -888,6 +1120,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4f/24b319316142c44283d7540e76c7b5a6dbd5db623abd86bb7b3491c21018/jaraco.functools-4.1.0-py3-none-any.whl", hash = "sha256:ad159f13428bc4acbf5541ad6dec511f91573b90fba04df61dafa2a1231cf649", size = 10187, upload-time = "2024-09-27T19:47:07.14Z" }, ] +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287, upload-time = "2024-11-11T01:41:42.873Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278, upload-time = "2024-11-11T01:41:40.175Z" }, +] + [[package]] name = "jeepney" version = "0.9.0" @@ -945,6 +1189,46 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/01/0e/b27cdbaccf30b890c40ed1da9fd4a3593a5cf94dae54fb34f8a4b74fcd3f/jsonschema_specifications-2025.4.1-py3-none-any.whl", hash = "sha256:4653bffbd6584f7de83a67e0d620ef16900b390ddc7939d56684d6c81e33f1af", size = 18437, upload-time = "2025-04-23T12:34:05.422Z" }, ] +[[package]] +name = "jupyter-client" +version = "8.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata", marker = "python_full_version < '3.10'" }, + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019, upload-time = "2024-09-17T10:44:17.613Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105, upload-time = "2024-09-17T10:44:15.218Z" }, +] + +[[package]] +name = "jupyter-core" +version = "5.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/1b/72906d554acfeb588332eaaa6f61577705e9ec752ddb486f302dafa292d9/jupyter_core-5.8.1.tar.gz", hash = "sha256:0a5f9706f70e64786b75acba995988915ebd4601c8a52e534a40b51c95f59941", size = 88923, upload-time = "2025-05-27T07:38:16.655Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", 
hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880, upload-time = "2025-05-27T07:38:15.137Z" }, +] + +[[package]] +name = "jupyterlab-widgets" +version = "3.0.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/7d/160595ca88ee87ac6ba95d82177d29ec60aaa63821d3077babb22ce031a5/jupyterlab_widgets-3.0.15.tar.gz", hash = "sha256:2920888a0c2922351a9202817957a68c07d99673504d6cd37345299e971bb08b", size = 213149, upload-time = "2025-05-05T12:32:31.004Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571, upload-time = "2025-05-05T12:32:29.534Z" }, +] + [[package]] name = "keyring" version = "25.6.0" @@ -1032,6 +1316,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/73/085399401383ce949f727afec55ec3abd76648d04b9f22e1c0e99cb4bec3/MarkupSafe-3.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a", size = 15506, upload-time = "2024-10-18T15:21:52.974Z" }, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159, upload-time = "2024-04-15T13:44:44.803Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899, upload-time = 
"2024-04-15T13:44:43.265Z" }, +] + [[package]] name = "mdurl" version = "0.1.2" @@ -1159,6 +1455,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418, upload-time = "2024-01-21T14:25:19.227Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195, upload-time = "2024-01-21T14:25:17.223Z" }, +] + [[package]] name = "networkx" version = "3.2.1" @@ -1467,6 +1772,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2f/49/5c30646e96c684570925b772eac4eb0a8cb0ca590fa978f56c5d3ae73ea1/pandas-2.2.3-cp39-cp39-win_amd64.whl", hash = "sha256:4850ba03528b6dd51d6c5d273c46f183f39a9baf3f0143e566b89450965b105e", size = 11618011, upload-time = "2024-09-20T13:10:02.351Z" }, ] +[[package]] +name = "parso" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609, upload-time = "2024-04-05T09:43:55.897Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650, upload-time = "2024-04-05T09:43:53.299Z" }, +] + +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450, upload-time = "2023-11-25T09:07:26.339Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772, upload-time = "2023-11-25T06:56:14.81Z" }, +] + [[package]] name = "pip" version = "25.1.1" @@ -1524,6 +1850,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/74/a88bf1b1efeae488a0c0b7bdf71429c313722d1fc0f377537fbe554e6180/pre_commit-4.2.0-py2.py3-none-any.whl", hash = "sha256:a009ca7205f1eb497d10b845e52c838a98b6cdd2102a6c8e4540e94ee75c58bd", size = 220707, upload-time = "2025-03-18T21:35:19.343Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.51" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bb/6e/9d084c929dfe9e3bfe0c6a47e31f78a25c54627d64a66e884a8bf5474f1c/prompt_toolkit-3.0.51.tar.gz", hash = "sha256:931a162e3b27fc90c86f1b48bb1fb2c528c2761475e57c9c06de13311c7b54ed", size = 428940, upload-time = "2025-04-15T09:18:47.731Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ce/4f/5249960887b1fbe561d9ff265496d170b55a735b76724f10ef19f9e40716/prompt_toolkit-3.0.51-py3-none-any.whl", hash = "sha256:52742911fde84e2d423e2f9a4cf1de7d7ac4e51958f648d9540e0fb8db077b07", size = 387810, upload-time = "2025-04-15T09:18:44.753Z" }, +] + [[package]] name = "propcache" version = "0.3.1" @@ -1612,6 +1950,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/50/1b/6921afe68c74868b4c9fa424dad3be35b095e16687989ebbb50ce4fceb7c/psutil-7.0.0-cp37-abi3-win_amd64.whl", hash = "sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553", size = 244885, upload-time = "2025-02-13T21:54:37.486Z" }, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762, upload-time = "2020-12-28T15:15:30.155Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993, upload-time = "2020-12-28T15:15:28.35Z" }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752, upload-time = "2024-07-21T12:58:21.801Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = 
"sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842, upload-time = "2024-07-21T12:58:20.04Z" }, +] + [[package]] name = "pyarrow" version = "20.0.0" @@ -1770,6 +2126,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, ] +[[package]] +name = "pywin32" +version = "311" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/40/44efbb0dfbd33aca6a6483191dae0716070ed99e2ecb0c53683f400a0b4f/pywin32-311-cp310-cp310-win32.whl", hash = "sha256:d03ff496d2a0cd4a5893504789d4a15399133fe82517455e78bad62efbb7f0a3", size = 8760432, upload-time = "2025-07-14T20:13:05.9Z" }, + { url = "https://files.pythonhosted.org/packages/5e/bf/360243b1e953bd254a82f12653974be395ba880e7ec23e3731d9f73921cc/pywin32-311-cp310-cp310-win_amd64.whl", hash = "sha256:797c2772017851984b97180b0bebe4b620bb86328e8a884bb626156295a63b3b", size = 9590103, upload-time = "2025-07-14T20:13:07.698Z" }, + { url = "https://files.pythonhosted.org/packages/57/38/d290720e6f138086fb3d5ffe0b6caa019a791dd57866940c82e4eeaf2012/pywin32-311-cp310-cp310-win_arm64.whl", hash = "sha256:0502d1facf1fed4839a9a51ccbcc63d952cf318f78ffc00a7e78528ac27d7a2b", size = 8778557, upload-time = "2025-07-14T20:13:11.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/af/449a6a91e5d6db51420875c54f6aff7c97a86a3b13a0b4f1a5c13b988de3/pywin32-311-cp311-cp311-win32.whl", hash = "sha256:184eb5e436dea364dcd3d2316d577d625c0351bf237c4e9a5fabbcfa5a58b151", size = 8697031, upload-time = "2025-07-14T20:13:13.266Z" }, + { url = "https://files.pythonhosted.org/packages/51/8f/9bb81dd5bb77d22243d33c8397f09377056d5c687aa6d4042bea7fbf8364/pywin32-311-cp311-cp311-win_amd64.whl", hash = 
"sha256:3ce80b34b22b17ccbd937a6e78e7225d80c52f5ab9940fe0506a1a16f3dab503", size = 9508308, upload-time = "2025-07-14T20:13:15.147Z" }, + { url = "https://files.pythonhosted.org/packages/44/7b/9c2ab54f74a138c491aba1b1cd0795ba61f144c711daea84a88b63dc0f6c/pywin32-311-cp311-cp311-win_arm64.whl", hash = "sha256:a733f1388e1a842abb67ffa8e7aad0e70ac519e09b0f6a784e65a136ec7cefd2", size = 8703930, upload-time = "2025-07-14T20:13:16.945Z" }, + { url = "https://files.pythonhosted.org/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31", size = 8706543, upload-time = "2025-07-14T20:13:20.765Z" }, + { url = "https://files.pythonhosted.org/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067", size = 9495040, upload-time = "2025-07-14T20:13:22.543Z" }, + { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, + { url = "https://files.pythonhosted.org/packages/59/42/b86689aac0cdaee7ae1c58d464b0ff04ca909c19bb6502d4973cdd9f9544/pywin32-311-cp39-cp39-win32.whl", hash = "sha256:aba8f82d551a942cb20d4a83413ccbac30790b50efb89a75e4f586ac0bb8056b", size = 8760837, upload-time = "2025-07-14T20:12:59.59Z" }, + { url = "https://files.pythonhosted.org/packages/9f/8a/1403d0353f8c5a2f0829d2b1c4becbf9da2f0a4d040886404fc4a5431e4d/pywin32-311-cp39-cp39-win_amd64.whl", hash = "sha256:e0c4cfb0621281fe40387df582097fd796e80430597cb9944f0ae70447bacd91", size = 9590187, upload-time = "2025-07-14T20:13:01.419Z" }, + { url = 
"https://files.pythonhosted.org/packages/60/22/e0e8d802f124772cec9c75430b01a212f86f9de7546bda715e54140d5aeb/pywin32-311-cp39-cp39-win_arm64.whl", hash = "sha256:62ea666235135fee79bb154e695f3ff67370afefd71bd7fea7512fc70ef31e3d", size = 8778162, upload-time = "2025-07-14T20:13:03.544Z" }, +] + [[package]] name = "pywin32-ctypes" version = "0.2.3" @@ -1823,6 +2198,72 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/19/87/5124b1c1f2412bb95c59ec481eaf936cd32f0fe2a7b16b97b81c4c017a6a/PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8", size = 162312, upload-time = "2024-08-06T20:33:49.073Z" }, ] +[[package]] +name = "pyzmq" +version = "27.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/30/5f/557d2032a2f471edbcc227da724c24a1c05887b5cda1e3ae53af98b9e0a5/pyzmq-27.0.1.tar.gz", hash = "sha256:45c549204bc20e7484ffd2555f6cf02e572440ecf2f3bdd60d4404b20fddf64b", size = 281158, upload-time = "2025-08-03T05:05:40.352Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/0b/ccf4d0b152a6a11f0fc01e73978202fe0e8fe0e91e20941598e83a170bee/pyzmq-27.0.1-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:90a4da42aa322de8a3522461e3b5fe999935763b27f69a02fced40f4e3cf9682", size = 1329293, upload-time = "2025-08-03T05:02:56.001Z" }, + { url = "https://files.pythonhosted.org/packages/bc/76/48706d291951b1300d3cf985e503806901164bf1581f27c4b6b22dbab2fa/pyzmq-27.0.1-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:e648dca28178fc879c814cf285048dd22fd1f03e1104101106505ec0eea50a4d", size = 905953, upload-time = "2025-08-03T05:02:59.061Z" }, + { url = 
"https://files.pythonhosted.org/packages/aa/8a/df3135b96712068d184c53120c7dbf3023e5e362a113059a4f85cd36c6a0/pyzmq-27.0.1-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4bca8abc31799a6f3652d13f47e0b0e1cab76f9125f2283d085a3754f669b607", size = 666165, upload-time = "2025-08-03T05:03:00.789Z" }, + { url = "https://files.pythonhosted.org/packages/ee/ed/341a7148e08d2830f480f53ab3d136d88fc5011bb367b516d95d0ebb46dd/pyzmq-27.0.1-cp310-cp310-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:092f4011b26d6b0201002f439bd74b38f23f3aefcb358621bdc3b230afc9b2d5", size = 853756, upload-time = "2025-08-03T05:03:03.347Z" }, + { url = "https://files.pythonhosted.org/packages/c2/bc/d26fe010477c3e901f0f5a3e70446950dde9aa217f1d1a13534eb0fccfe5/pyzmq-27.0.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f02f30a4a6b3efe665ab13a3dd47109d80326c8fd286311d1ba9f397dc5f247", size = 1654870, upload-time = "2025-08-03T05:03:05.331Z" }, + { url = "https://files.pythonhosted.org/packages/32/21/9b488086bf3f55b2eb26db09007a3962f62f3b81c5c6295a6ff6aaebd69c/pyzmq-27.0.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:f293a1419266e3bf3557d1f8778f9e1ffe7e6b2c8df5c9dca191caf60831eb74", size = 2033444, upload-time = "2025-08-03T05:03:07.318Z" }, + { url = "https://files.pythonhosted.org/packages/3d/53/85b64a792223cd43393d25e03c8609df41aac817ea5ce6a27eceeed433ee/pyzmq-27.0.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ce181dd1a7c6c012d0efa8ab603c34b5ee9d86e570c03415bbb1b8772eeb381c", size = 1891289, upload-time = "2025-08-03T05:03:08.96Z" }, + { url = "https://files.pythonhosted.org/packages/23/5b/078aae8fe1c4cdba1a77a598870c548fd52b4d4a11e86b8116bbef47d9f3/pyzmq-27.0.1-cp310-cp310-win32.whl", hash = "sha256:f65741cc06630652e82aa68ddef4986a3ab9073dd46d59f94ce5f005fa72037c", size = 566693, upload-time = "2025-08-03T05:03:10.711Z" }, + { url = 
"https://files.pythonhosted.org/packages/24/e1/4471fff36416ebf1ffe43577b9c7dcf2ff4798f2171f0d169640a48d2305/pyzmq-27.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:44909aa3ed2234d69fe81e1dade7be336bcfeab106e16bdaa3318dcde4262b93", size = 631649, upload-time = "2025-08-03T05:03:12.232Z" }, + { url = "https://files.pythonhosted.org/packages/e8/4c/8edac8dd56f223124aa40403d2c097bbad9b0e2868a67cad9a2a029863aa/pyzmq-27.0.1-cp310-cp310-win_arm64.whl", hash = "sha256:4401649bfa0a38f0f8777f8faba7cd7eb7b5b8ae2abc7542b830dd09ad4aed0d", size = 559274, upload-time = "2025-08-03T05:03:13.728Z" }, + { url = "https://files.pythonhosted.org/packages/ae/18/a8e0da6ababbe9326116fb1c890bf1920eea880e8da621afb6bc0f39a262/pyzmq-27.0.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:9729190bd770314f5fbba42476abf6abe79a746eeda11d1d68fd56dd70e5c296", size = 1332721, upload-time = "2025-08-03T05:03:15.237Z" }, + { url = "https://files.pythonhosted.org/packages/75/a4/9431ba598651d60ebd50dc25755402b770322cf8432adcc07d2906e53a54/pyzmq-27.0.1-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:696900ef6bc20bef6a242973943574f96c3f97d2183c1bd3da5eea4f559631b1", size = 908249, upload-time = "2025-08-03T05:03:16.933Z" }, + { url = "https://files.pythonhosted.org/packages/f0/7a/e624e1793689e4e685d2ee21c40277dd4024d9d730af20446d88f69be838/pyzmq-27.0.1-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f96a63aecec22d3f7fdea3c6c98df9e42973f5856bb6812c3d8d78c262fee808", size = 668649, upload-time = "2025-08-03T05:03:18.49Z" }, + { url = "https://files.pythonhosted.org/packages/6c/29/0652a39d4e876e0d61379047ecf7752685414ad2e253434348246f7a2a39/pyzmq-27.0.1-cp311-cp311-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c512824360ea7490390566ce00bee880e19b526b312b25cc0bc30a0fe95cb67f", size = 856601, upload-time = "2025-08-03T05:03:20.194Z" }, + { url = 
"https://files.pythonhosted.org/packages/36/2d/8d5355d7fc55bb6e9c581dd74f58b64fa78c994079e3a0ea09b1b5627cde/pyzmq-27.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dfb2bb5e0f7198eaacfb6796fb0330afd28f36d985a770745fba554a5903595a", size = 1657750, upload-time = "2025-08-03T05:03:22.055Z" }, + { url = "https://files.pythonhosted.org/packages/ab/f4/cd032352d5d252dc6f5ee272a34b59718ba3af1639a8a4ef4654f9535cf5/pyzmq-27.0.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4f6886c59ba93ffde09b957d3e857e7950c8fe818bd5494d9b4287bc6d5bc7f1", size = 2034312, upload-time = "2025-08-03T05:03:23.578Z" }, + { url = "https://files.pythonhosted.org/packages/e4/1a/c050d8b6597200e97a4bd29b93c769d002fa0b03083858227e0376ad59bc/pyzmq-27.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b99ea9d330e86ce1ff7f2456b33f1bf81c43862a5590faf4ef4ed3a63504bdab", size = 1893632, upload-time = "2025-08-03T05:03:25.167Z" }, + { url = "https://files.pythonhosted.org/packages/6a/29/173ce21d5097e7fcf284a090e8beb64fc683c6582b1f00fa52b1b7e867ce/pyzmq-27.0.1-cp311-cp311-win32.whl", hash = "sha256:571f762aed89025ba8cdcbe355fea56889715ec06d0264fd8b6a3f3fa38154ed", size = 566587, upload-time = "2025-08-03T05:03:26.769Z" }, + { url = "https://files.pythonhosted.org/packages/53/ab/22bd33e7086f0a2cc03a5adabff4bde414288bb62a21a7820951ef86ec20/pyzmq-27.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee16906c8025fa464bea1e48128c048d02359fb40bebe5333103228528506530", size = 632873, upload-time = "2025-08-03T05:03:28.685Z" }, + { url = "https://files.pythonhosted.org/packages/90/14/3e59b4a28194285ceeff725eba9aa5ba8568d1cb78aed381dec1537c705a/pyzmq-27.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:ba068f28028849da725ff9185c24f832ccf9207a40f9b28ac46ab7c04994bd41", size = 558918, upload-time = "2025-08-03T05:03:30.085Z" }, + { url = "https://files.pythonhosted.org/packages/0e/9b/c0957041067c7724b310f22c398be46399297c12ed834c3bc42200a2756f/pyzmq-27.0.1-cp312-abi3-macosx_10_15_universal2.whl", hash = 
"sha256:af7ebce2a1e7caf30c0bb64a845f63a69e76a2fadbc1cac47178f7bb6e657bdd", size = 1305432, upload-time = "2025-08-03T05:03:32.177Z" }, + { url = "https://files.pythonhosted.org/packages/8e/55/bd3a312790858f16b7def3897a0c3eb1804e974711bf7b9dcb5f47e7f82c/pyzmq-27.0.1-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:8f617f60a8b609a13099b313e7e525e67f84ef4524b6acad396d9ff153f6e4cd", size = 895095, upload-time = "2025-08-03T05:03:33.918Z" }, + { url = "https://files.pythonhosted.org/packages/20/50/fc384631d8282809fb1029a4460d2fe90fa0370a0e866a8318ed75c8d3bb/pyzmq-27.0.1-cp312-abi3-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d59dad4173dc2a111f03e59315c7bd6e73da1a9d20a84a25cf08325b0582b1a", size = 651826, upload-time = "2025-08-03T05:03:35.818Z" }, + { url = "https://files.pythonhosted.org/packages/7e/0a/2356305c423a975000867de56888b79e44ec2192c690ff93c3109fd78081/pyzmq-27.0.1-cp312-abi3-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f5b6133c8d313bde8bd0d123c169d22525300ff164c2189f849de495e1344577", size = 839751, upload-time = "2025-08-03T05:03:37.265Z" }, + { url = "https://files.pythonhosted.org/packages/d7/1b/81e95ad256ca7e7ccd47f5294c1c6da6e2b64fbace65b84fe8a41470342e/pyzmq-27.0.1-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:58cca552567423f04d06a075f4b473e78ab5bdb906febe56bf4797633f54aa4e", size = 1641359, upload-time = "2025-08-03T05:03:38.799Z" }, + { url = "https://files.pythonhosted.org/packages/50/63/9f50ec965285f4e92c265c8f18344e46b12803666d8b73b65d254d441435/pyzmq-27.0.1-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:4b9d8e26fb600d0d69cc9933e20af08552e97cc868a183d38a5c0d661e40dfbb", size = 2020281, upload-time = "2025-08-03T05:03:40.338Z" }, + { url = "https://files.pythonhosted.org/packages/02/4a/19e3398d0dc66ad2b463e4afa1fc541d697d7bc090305f9dfb948d3dfa29/pyzmq-27.0.1-cp312-abi3-musllinux_1_2_x86_64.whl", hash = 
"sha256:2329f0c87f0466dce45bba32b63f47018dda5ca40a0085cc5c8558fea7d9fc55", size = 1877112, upload-time = "2025-08-03T05:03:42.012Z" }, + { url = "https://files.pythonhosted.org/packages/bf/42/c562e9151aa90ed1d70aac381ea22a929d6b3a2ce4e1d6e2e135d34fd9c6/pyzmq-27.0.1-cp312-abi3-win32.whl", hash = "sha256:57bb92abdb48467b89c2d21da1ab01a07d0745e536d62afd2e30d5acbd0092eb", size = 558177, upload-time = "2025-08-03T05:03:43.979Z" }, + { url = "https://files.pythonhosted.org/packages/40/96/5c50a7d2d2b05b19994bf7336b97db254299353dd9b49b565bb71b485f03/pyzmq-27.0.1-cp312-abi3-win_amd64.whl", hash = "sha256:ff3f8757570e45da7a5bedaa140489846510014f7a9d5ee9301c61f3f1b8a686", size = 618923, upload-time = "2025-08-03T05:03:45.438Z" }, + { url = "https://files.pythonhosted.org/packages/13/33/1ec89c8f21c89d21a2eaff7def3676e21d8248d2675705e72554fb5a6f3f/pyzmq-27.0.1-cp312-abi3-win_arm64.whl", hash = "sha256:df2c55c958d3766bdb3e9d858b911288acec09a9aab15883f384fc7180df5bed", size = 552358, upload-time = "2025-08-03T05:03:46.887Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f1/cdceaf9b6637570f36eee2dbd25bc5a800637cd9b4103b15fbc4b0658b82/pyzmq-27.0.1-cp39-cp39-macosx_10_15_universal2.whl", hash = "sha256:05a94233fdde585eb70924a6e4929202a747eea6ed308a6171c4f1c715bbe39e", size = 1330651, upload-time = "2025-08-03T05:04:45.583Z" }, + { url = "https://files.pythonhosted.org/packages/74/5c/469d3b9315eb4d5c61c431a4ae8acdb6abb165dfa5ddbc7af639be53891c/pyzmq-27.0.1-cp39-cp39-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:c96702e1082eab62ae583d64c4e19c9b848359196697e536a0c57ae9bd165bd5", size = 906524, upload-time = "2025-08-03T05:04:47.904Z" }, + { url = "https://files.pythonhosted.org/packages/ed/c0/c7a12a533a87beb1143f4a9c8f4d6f82775c04eb3ad27f664e0ef00a6189/pyzmq-27.0.1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c9180d1f5b4b73e28b64e63cc6c4c097690f102aa14935a62d5dd7426a4e5b5a", size = 863547, upload-time = "2025-08-03T05:04:49.579Z" }, 
+ { url = "https://files.pythonhosted.org/packages/41/78/50907d004511bd23eae03d951f3ca4e4cc2e7eb5ec8d3df70d89eca3f97c/pyzmq-27.0.1-cp39-cp39-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e971d8680003d0af6020713e52f92109b46fedb463916e988814e04c8133578a", size = 666797, upload-time = "2025-08-03T05:04:51.263Z" }, + { url = "https://files.pythonhosted.org/packages/67/bd/ec3388888eda39705a4cefb465452a4bca5430a3435803588ced49943fdb/pyzmq-27.0.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:fe632fa4501154d58dfbe1764a0495734d55f84eaf1feda4549a1f1ca76659e9", size = 1655601, upload-time = "2025-08-03T05:04:53.026Z" }, + { url = "https://files.pythonhosted.org/packages/84/50/170a1671a171365dda677886d42c39629a086752696ede70296b8f6224d8/pyzmq-27.0.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4c3874344fd5fa6d58bb51919708048ac4cab21099f40a227173cddb76b4c20b", size = 2034120, upload-time = "2025-08-03T05:04:55.323Z" }, + { url = "https://files.pythonhosted.org/packages/a4/0a/f06841495e4ec33ed65588e94aff07f1dcbc6878e1611577f6b97a449068/pyzmq-27.0.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ec09073ed67ae236785d543df3b322282acc0bdf6d1b748c3e81f3043b21cb5", size = 1891956, upload-time = "2025-08-03T05:04:57.084Z" }, + { url = "https://files.pythonhosted.org/packages/d9/6b/6ba945a4756e4b1ba69b909d2b040d16aff0f0edd56a60874970b8d47237/pyzmq-27.0.1-cp39-cp39-win32.whl", hash = "sha256:f44e7ea288d022d4bf93b9e79dafcb4a7aea45a3cbeae2116792904931cefccf", size = 567388, upload-time = "2025-08-03T05:04:58.704Z" }, + { url = "https://files.pythonhosted.org/packages/b0/b4/8ffb9cfb363bc9d61c5d8d9f79a7ada572b0865dac9f4a547da901b81d76/pyzmq-27.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:ffe6b809a97ac6dea524b3b837d5b28743d8c2f121141056d168ff0ba8f614ef", size = 632004, upload-time = "2025-08-03T05:05:00.434Z" }, + { url = 
"https://files.pythonhosted.org/packages/6c/4b/dd5c4d3bb7261efb30a909d2df447ac77393653e5c34c8a9cd536f429c3e/pyzmq-27.0.1-cp39-cp39-win_arm64.whl", hash = "sha256:fde26267416c8478c95432c81489b53f57b0b5d24cd5c8bfaebf5bbaac4dc90c", size = 559881, upload-time = "2025-08-03T05:05:02.363Z" }, + { url = "https://files.pythonhosted.org/packages/6f/87/fc96f224dd99070fe55d0afc37ac08d7d4635d434e3f9425b232867e01b9/pyzmq-27.0.1-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:544b995a6a1976fad5d7ff01409b4588f7608ccc41be72147700af91fd44875d", size = 835950, upload-time = "2025-08-03T05:05:04.193Z" }, + { url = "https://files.pythonhosted.org/packages/d1/b6/802d96017f176c3a7285603d9ed2982550095c136c6230d3e0b53f52c7e5/pyzmq-27.0.1-pp310-pypy310_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:0f772eea55cccce7f45d6ecdd1d5049c12a77ec22404f6b892fae687faa87bee", size = 799876, upload-time = "2025-08-03T05:05:06.263Z" }, + { url = "https://files.pythonhosted.org/packages/4e/52/49045c6528007cce385f218f3a674dc84fc8b3265330d09e57c0a59b41f4/pyzmq-27.0.1-pp310-pypy310_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c9d63d66059114a6756d09169c9209ffceabacb65b9cb0f66e6fc344b20b73e6", size = 567402, upload-time = "2025-08-03T05:05:08.028Z" }, + { url = "https://files.pythonhosted.org/packages/bc/fe/c29ac0d5a817543ecf0cb18f17195805bad0da567a1c64644aacf11b2779/pyzmq-27.0.1-pp310-pypy310_pp73-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1da8e645c655d86f0305fb4c65a0d848f461cd90ee07d21f254667287b5dbe50", size = 747030, upload-time = "2025-08-03T05:05:10.116Z" }, + { url = "https://files.pythonhosted.org/packages/17/d1/cc1fbfb65b4042016e4e035b2548cdfe0945c817345df83aa2d98490e7fc/pyzmq-27.0.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:1843fd0daebcf843fe6d4da53b8bdd3fc906ad3e97d25f51c3fed44436d82a49", size = 544567, upload-time = "2025-08-03T05:05:11.856Z" }, + { url = 
"https://files.pythonhosted.org/packages/b4/1a/49f66fe0bc2b2568dd4280f1f520ac8fafd73f8d762140e278d48aeaf7b9/pyzmq-27.0.1-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7fb0ee35845bef1e8c4a152d766242164e138c239e3182f558ae15cb4a891f94", size = 835949, upload-time = "2025-08-03T05:05:13.798Z" }, + { url = "https://files.pythonhosted.org/packages/49/94/443c1984b397eab59b14dd7ae8bc2ac7e8f32dbc646474453afcaa6508c4/pyzmq-27.0.1-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f379f11e138dfd56c3f24a04164f871a08281194dd9ddf656a278d7d080c8ad0", size = 799875, upload-time = "2025-08-03T05:05:15.632Z" }, + { url = "https://files.pythonhosted.org/packages/30/f1/fd96138a0f152786a2ba517e9c6a8b1b3516719e412a90bb5d8eea6b660c/pyzmq-27.0.1-pp311-pypy311_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b978c0678cffbe8860ec9edc91200e895c29ae1ac8a7085f947f8e8864c489fb", size = 567403, upload-time = "2025-08-03T05:05:17.326Z" }, + { url = "https://files.pythonhosted.org/packages/16/57/34e53ef2b55b1428dac5aabe3a974a16c8bda3bf20549ba500e3ff6cb426/pyzmq-27.0.1-pp311-pypy311_pp73-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7ebccf0d760bc92a4a7c751aeb2fef6626144aace76ee8f5a63abeb100cae87f", size = 747032, upload-time = "2025-08-03T05:05:19.074Z" }, + { url = "https://files.pythonhosted.org/packages/81/b7/769598c5ae336fdb657946950465569cf18803140fe89ce466d7f0a57c11/pyzmq-27.0.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:77fed80e30fa65708546c4119840a46691290efc231f6bfb2ac2a39b52e15811", size = 544566, upload-time = "2025-08-03T05:05:20.798Z" }, + { url = "https://files.pythonhosted.org/packages/60/8d/c0880acd2d5908eec6fe9b399f0fb630e5f203f8a69f82442d5cb2b2f46c/pyzmq-27.0.1-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d97b59cbd8a6c8b23524a8ce237ff9504d987dc07156258aa68ae06d2dd5f34d", size = 835946, upload-time = "2025-08-03T05:05:31.161Z" }, + { url = 
"https://files.pythonhosted.org/packages/c1/35/6b71409aa6629b3d4917b38961501898827f4fb5ddc680cc8e0cb13987f3/pyzmq-27.0.1-pp39-pypy39_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:27a78bdd384dbbe7b357af95f72efe8c494306b5ec0a03c31e2d53d6763e5307", size = 799870, upload-time = "2025-08-03T05:05:33.01Z" }, + { url = "https://files.pythonhosted.org/packages/16/f6/5d36d8f6571478f32c32f5872abd76eda052746283ca87e24cc5758f7987/pyzmq-27.0.1-pp39-pypy39_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b007e5dcba684e888fbc90554cb12a2f4e492927c8c2761a80b7590209821743", size = 758371, upload-time = "2025-08-03T05:05:34.722Z" }, + { url = "https://files.pythonhosted.org/packages/6f/29/6a7b7f5d47712487d8a3516584a4a484a0147f2537228237397793b2de69/pyzmq-27.0.1-pp39-pypy39_pp73-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:95594b2ceeaa94934e3e94dd7bf5f3c3659cf1a26b1fb3edcf6e42dad7e0eaf2", size = 567395, upload-time = "2025-08-03T05:05:36.701Z" }, + { url = "https://files.pythonhosted.org/packages/eb/37/c1f26d13e9d4c3bfce42fead8ff640f6c06a58decde49a6b295b9d52cefd/pyzmq-27.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:70b719a130b81dd130a57ac0ff636dc2c0127c5b35ca5467d1b67057e3c7a4d2", size = 544561, upload-time = "2025-08-03T05:05:38.608Z" }, +] + [[package]] name = "readme-renderer" version = "44.0" @@ -2270,6 +2711,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/8b/f6/126c9309c8fe93e5d6bb850593cd58d591daf2da45cc78b61e48d8d95879/setuptools-80.1.0-py3-none-any.whl", hash = "sha256:ea0e7655c05b74819f82e76e11a85b31779fee7c4969e82f72bab0664e8317e4", size = 1240689, upload-time = "2025-04-30T17:41:03.789Z" }, ] +[[package]] +name = "shellingham" +version = "1.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = 
"sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" }, +] + [[package]] name = "six" version = "1.17.0" @@ -2536,6 +2986,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9c/4d/103e541e2533df159e1070cd4372b447a0b689e08a49d271b7b950e21f92/sphobjinv-2.3.1.2-py3-none-any.whl", hash = "sha256:66478d1787d28ef3ebeeedad57c592fdea04cf10eeed0df56307c85ab4eee789", size = 50820, upload-time = "2024-12-22T22:34:10.572Z" }, ] +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707, upload-time = "2023-09-30T13:58:05.479Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = "sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521, upload-time = "2023-09-30T13:58:03.53Z" }, +] + [[package]] name = "starlette" version = "0.46.2" @@ -2672,6 +3136,25 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/77/6391214d084a85aeb099d520420d39f405928b6a5f27a3f1a453c27c5173/torch-2.7.1-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:a737b5edd1c44a5c1ece2e9f3d00df9d1b3fb9541138bee56d83d38293fb6c9d", size = 68630146, upload-time = 
"2025-06-04T17:35:26.434Z" }, ] +[[package]] +name = "tornado" +version = "6.5.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/09/ce/1eb500eae19f4648281bb2186927bb062d2438c2e5093d1360391afd2f90/tornado-6.5.2.tar.gz", hash = "sha256:ab53c8f9a0fa351e2c0741284e06c7a45da86afb544133201c5cc8578eb076a0", size = 510821, upload-time = "2025-08-08T18:27:00.78Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f6/48/6a7529df2c9cc12efd2e8f5dd219516184d703b34c06786809670df5b3bd/tornado-6.5.2-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:2436822940d37cde62771cff8774f4f00b3c8024fe482e16ca8387b8a2724db6", size = 442563, upload-time = "2025-08-08T18:26:42.945Z" }, + { url = "https://files.pythonhosted.org/packages/f2/b5/9b575a0ed3e50b00c40b08cbce82eb618229091d09f6d14bce80fc01cb0b/tornado-6.5.2-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:583a52c7aa94ee046854ba81d9ebb6c81ec0fd30386d96f7640c96dad45a03ef", size = 440729, upload-time = "2025-08-08T18:26:44.473Z" }, + { url = "https://files.pythonhosted.org/packages/1b/4e/619174f52b120efcf23633c817fd3fed867c30bff785e2cd5a53a70e483c/tornado-6.5.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b0fe179f28d597deab2842b86ed4060deec7388f1fd9c1b4a41adf8af058907e", size = 444295, upload-time = "2025-08-08T18:26:46.021Z" }, + { url = "https://files.pythonhosted.org/packages/95/fa/87b41709552bbd393c85dd18e4e3499dcd8983f66e7972926db8d96aa065/tornado-6.5.2-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b186e85d1e3536d69583d2298423744740986018e393d0321df7340e71898882", size = 443644, upload-time = "2025-08-08T18:26:47.625Z" }, + { url = "https://files.pythonhosted.org/packages/f9/41/fb15f06e33d7430ca89420283a8762a4e6b8025b800ea51796ab5e6d9559/tornado-6.5.2-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:e792706668c87709709c18b353da1f7662317b563ff69f00bab83595940c7108", size = 443878, upload-time = "2025-08-08T18:26:50.599Z" }, + { url = "https://files.pythonhosted.org/packages/11/92/fe6d57da897776ad2e01e279170ea8ae726755b045fe5ac73b75357a5a3f/tornado-6.5.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:06ceb1300fd70cb20e43b1ad8aaee0266e69e7ced38fa910ad2e03285009ce7c", size = 444549, upload-time = "2025-08-08T18:26:51.864Z" }, + { url = "https://files.pythonhosted.org/packages/9b/02/c8f4f6c9204526daf3d760f4aa555a7a33ad0e60843eac025ccfd6ff4a93/tornado-6.5.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:74db443e0f5251be86cbf37929f84d8c20c27a355dd452a5cfa2aada0d001ec4", size = 443973, upload-time = "2025-08-08T18:26:53.625Z" }, + { url = "https://files.pythonhosted.org/packages/ae/2d/f5f5707b655ce2317190183868cd0f6822a1121b4baeae509ceb9590d0bd/tornado-6.5.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b5e735ab2889d7ed33b32a459cac490eda71a1ba6857b0118de476ab6c366c04", size = 443954, upload-time = "2025-08-08T18:26:55.072Z" }, + { url = "https://files.pythonhosted.org/packages/e8/59/593bd0f40f7355806bf6573b47b8c22f8e1374c9b6fd03114bd6b7a3dcfd/tornado-6.5.2-cp39-abi3-win32.whl", hash = "sha256:c6f29e94d9b37a95013bb669616352ddb82e3bfe8326fccee50583caebc8a5f0", size = 445023, upload-time = "2025-08-08T18:26:56.677Z" }, + { url = "https://files.pythonhosted.org/packages/c7/2a/f609b420c2f564a748a2d80ebfb2ee02a73ca80223af712fca591386cafb/tornado-6.5.2-cp39-abi3-win_amd64.whl", hash = "sha256:e56a5af51cc30dd2cae649429af65ca2f6571da29504a07995175df14c18f35f", size = 445427, upload-time = "2025-08-08T18:26:57.91Z" }, + { url = "https://files.pythonhosted.org/packages/5e/4f/e1f65e8f8c76d73658b33d33b81eed4322fb5085350e4328d5c956f0c8f9/tornado-6.5.2-cp39-abi3-win_arm64.whl", hash = "sha256:d6c33dc3672e3a1f3618eb63b7ef4683a7688e7b9e6e8f0d9aa5726360a004af", size = 444456, upload-time = "2025-08-08T18:26:59.207Z" }, +] + [[package]] name = "tqdm" version = 
"4.67.1" @@ -2684,6 +3167,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, ] +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 161621, upload-time = "2024-04-19T11:11:49.746Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359, upload-time = "2024-04-19T11:11:46.763Z" }, +] + [[package]] name = "transformers" version = "4.51.3" @@ -2746,6 +3238,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7c/b6/74e927715a285743351233f33ea3c684528a0d374d2e43ff9ce9585b73fe/twine-6.1.0-py3-none-any.whl", hash = "sha256:a47f973caf122930bf0fbbf17f80b83bc1602c9ce393c7845f289a3001dc5384", size = 40791, upload-time = "2025-01-21T18:45:24.584Z" }, ] +[[package]] +name = "typer" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "rich" }, + { name = "shellingham" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c5/8c/7d682431efca5fd290017663ea4588bf6f2c6aad085c7f108c5dbc316e70/typer-0.16.0.tar.gz", hash = "sha256:af377ffaee1dbe37ae9440cb4e8f11686ea5ce4e9bae01b84ae7c63b87f1dd3b", size = 102625, upload-time = "2025-05-26T14:30:31.824Z" } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/76/42/3efaf858001d2c2913de7f354563e3a3a2f0decae3efe98427125a8f441e/typer-0.16.0-py3-none-any.whl", hash = "sha256:1f79bed11d4d02d4310e3c1b7ba594183bcedb0ac73b27a9e5f28f6fb5b98855", size = 46317, upload-time = "2025-05-26T14:30:30.523Z" }, +] + [[package]] name = "typing-extensions" version = "4.13.2" @@ -2913,6 +3420,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/d3/ecc62cbd7054f0812f3a7ca7c1c9f7ba99ba45efcfc8297a9fcd2c87b31c/watchfiles-1.0.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0901429650652d3f0da90bad42bdafc1f9143ff3605633c455c999a2d786cac", size = 456511, upload-time = "2025-04-08T10:36:25.42Z" }, ] +[[package]] +name = "wcwidth" +version = "0.2.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301, upload-time = "2024-01-06T02:10:57.829Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166, upload-time = "2024-01-06T02:10:55.763Z" }, +] + [[package]] name = "websockets" version = "15.0.1" @@ -2978,6 +3494,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/a8/5b41e0da817d64113292ab1f8247140aac61cbf6cfd085d6a0fa77f4984f/websockets-15.0.1-py3-none-any.whl", hash = "sha256:f7a866fbc1e97b5c617ee4116daaa09b722101d4a3c170c787450ba409f9736f", size = 169743, upload-time = "2025-03-05T20:03:39.41Z" }, ] +[[package]] +name = "widgetsnbextension" +version = "4.0.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/41/53/2e0253c5efd69c9656b1843892052a31c36d37ad42812b5da45c62191f7e/widgetsnbextension-4.0.14.tar.gz", hash = "sha256:a3629b04e3edb893212df862038c7232f62973373869db5084aed739b437b5af", size = 1097428, upload-time = "2025-04-10T13:01:25.628Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl", hash = "sha256:4875a9eaf72fbf5079dc372a51a9f268fc38d46f767cbf85c43a36da5cb9b575", size = 2196503, upload-time = "2025-04-10T13:01:23.086Z" }, +] + [[package]] name = "xxhash" version = "3.5.0" From 227d5444c75d8cb73e3daeff188582db3905bda8 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:14:41 -0400 Subject: [PATCH 4/8] consolidate rest apis --- src/cnlpt/api/__init__.py | 71 ------ src/cnlpt/api/cnn_rest.py | 102 --------- src/cnlpt/api/current_rest.py | 106 --------- src/cnlpt/api/dtr_rest.py | 106 --------- src/cnlpt/api/event_rest.py | 158 ------------- src/cnlpt/api/hier_rest.py | 104 --------- src/cnlpt/api/negation_rest.py | 112 ---------- src/cnlpt/api/temporal_rest.py | 373 ------------------------------- src/cnlpt/api/termexists_rest.py | 106 --------- src/cnlpt/api/timex_rest.py | 156 ------------- src/cnlpt/api/utils.py | 157 ------------- src/cnlpt/data/analysis.py | 96 +++++--- src/cnlpt/data/predictions.py | 39 +++- src/cnlpt/rest/__init__.py | 3 + src/cnlpt/rest/cnlp_rest.py | 208 +++++++++++++++++ 15 files changed, 298 insertions(+), 1599 deletions(-) delete mode 100644 src/cnlpt/api/__init__.py delete mode 100644 src/cnlpt/api/cnn_rest.py delete mode 100644 src/cnlpt/api/current_rest.py delete mode 100644 src/cnlpt/api/dtr_rest.py delete mode 100644 src/cnlpt/api/event_rest.py delete mode 100644 src/cnlpt/api/hier_rest.py delete mode 100644 src/cnlpt/api/negation_rest.py delete mode 100644 src/cnlpt/api/temporal_rest.py delete mode 100644 src/cnlpt/api/termexists_rest.py delete 
mode 100644 src/cnlpt/api/timex_rest.py delete mode 100644 src/cnlpt/api/utils.py create mode 100644 src/cnlpt/rest/__init__.py create mode 100644 src/cnlpt/rest/cnlp_rest.py diff --git a/src/cnlpt/api/__init__.py b/src/cnlpt/api/__init__.py deleted file mode 100644 index a3e8592f..00000000 --- a/src/cnlpt/api/__init__.py +++ /dev/null @@ -1,71 +0,0 @@ -"""Serve REST APIs for CNLPT models over your network.""" - -from typing import Final - -MODEL_TYPES: Final = ( - "cnn", - "current", - "dtr", - # "event", - "hier", - "negation", - "temporal", - # "termexists", - # "timex", -) -"""The available model types for :func:`get_rest_app`.""" - - -def get_rest_app(model_type: str): - """Get a FastAPI app for a certain model type. - - Args: - model_type: The type of model to serve. - - Returns: - The FastAPI app. - """ - if model_type == "cnn": - from .cnn_rest import app - - return app - elif model_type == "current": - from .current_rest import app - - return app - elif model_type == "dtr": - from .dtr_rest import app - - return app - # elif model_type == "event": - # from .event_rest import app - - # return app - elif model_type == "hier": - from .hier_rest import app - - return app - elif model_type == "negation": - from .negation_rest import app - - return app - elif model_type == "temporal": - from .temporal_rest import app - - return app - # elif model_type == "termexists": - # from .termexists_rest import app - - # return app - # elif model_type == "timex": - # from .timex_rest import app - - # return app - else: - raise ValueError(f"unknown model type: {model_type}") - - -__all__ = [ - "MODEL_TYPES", - "get_rest_app", -] diff --git a/src/cnlpt/api/cnn_rest.py b/src/cnlpt/api/cnn_rest.py deleted file mode 100644 index af556e95..00000000 --- a/src/cnlpt/api/cnn_rest.py +++ /dev/null @@ -1,102 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. 
See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import json -import logging -import os -import sys -from contextlib import asynccontextmanager -from os.path import join -from typing import Any - -import numpy as np -import torch -import torch.backends.mps -from fastapi import FastAPI -from scipy.special import softmax -from transformers import AutoTokenizer, PreTrainedTokenizer - -from ..models.baseline import CnnSentenceClassifier -from .utils import UnannotatedDocument, create_dataset, resolve_device - -MODEL_NAME = os.getenv("MODEL_PATH") -device = os.getenv("MODEL_DEVICE", "auto") -device = resolve_device(device) - -logger = logging.getLogger("CNN_REST_Processor") -logger.setLevel(logging.DEBUG) - -MAX_SEQ_LENGTH = 128 - -model: CnnSentenceClassifier -tokenizer: PreTrainedTokenizer -conf_dict: dict[str, Any] - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global model, tokenizer, conf_dict - if MODEL_NAME is None: - sys.stderr.write( - "This REST container requires a MODEL_PATH environment variable\n" - ) - sys.exit(-1) - conf_file = join(MODEL_NAME, "config.json") - with open(conf_file) as fp: - conf_dict = json.load(fp) - - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) - num_labels_dict = { - task: len(values) for task, values in conf_dict["label_dictionary"].items() - } - model = 
CnnSentenceClassifier.from_pretrained( - MODEL_NAME, - vocab_size=len(tokenizer), - task_names=conf_dict["task_names"], - num_labels_dict=num_labels_dict, - embed_dims=conf_dict["cnn_embed_dim"], - num_filters=conf_dict["cnn_num_filters"], - filters=conf_dict["cnn_filter_sizes"], - ) - - model = model.to(device) - tokenizer = tokenizer - conf_dict = conf_dict - - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/cnn/classify") -async def process(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, tokenizer, max_length=conf_dict["max_seq_length"] - ) - _, logits = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(device), - ) - - prediction = int(np.argmax(logits[0].cpu().detach().numpy(), axis=1)) - result = conf_dict["label_dictionary"][conf_dict["task_names"][0]][prediction] - probabilities = softmax(logits[0][0].cpu().detach().numpy()) - # for redcap purposes, it might make more sense to only output the probability for the predicted class, - # but i'm outputting them all, for transparency - out_probabilities = [str(prob) for prob in probabilities] - return {"result": result, "probabilities": out_probabilities} diff --git a/src/cnlpt/api/current_rest.py b/src/cnlpt/api/current_rest.py deleted file mode 100644 index 682c5228..00000000 --- a/src/cnlpt/api/current_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -logger = logging.getLogger("Current_REST_Processor") -logger.setLevel(logging.DEBUG) - -MODEL_NAME = "mlml-chip/current-thyme" -TASK = "Current" -LABELS = [False, True] - -MAX_LENGTH = 128 - - -class CurrentResults(BaseModel): - """statuses: list of classifier outputs for every input""" - - statuses: list[bool] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/current/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return CurrentResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = 
trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = CurrentResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/dtr_rest.py b/src/cnlpt/api/dtr_rest.py deleted file mode 100644 index ee85e6c9..00000000 --- a/src/cnlpt/api/dtr_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import OLD_DTR_LABEL_LIST -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "tmills/tiny-dtr" -logger = logging.getLogger("DocTimeRel Processor with xtremedistil encoder") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -class DocTimeRelResults(BaseModel): - """statuses: dictionary from entity id to classification decision about DocTimeRel""" - - statuses: list[str] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/dtr/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return DocTimeRelResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1])) - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(OLD_DTR_LABEL_LIST[predictions[ent_ind]]) - - output = 
DocTimeRelResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/event_rest.py b/src/cnlpt/api/event_rest.py deleted file mode 100644 index e2d34c75..00000000 --- a/src/cnlpt/api/event_rest.py +++ /dev/null @@ -1,158 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import ( - EVENT_LABEL_LIST, - Event, - SentenceDocument, - TemporalResults, - TokenizedSentenceDocument, - create_instance_string, -) -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "tmills/event-thyme-colon-pubmedbert" -logger = logging.getLogger("Event_REST_Processor") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - - 
event_predictions = np.argmax(output.predictions[0], axis=2) - - pred_end = time() - - timex_results = [] - event_results = [] - rel_results = [] - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer.batch_encode_plus( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - event_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - event_labels.append( - EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]] - ) - previous_word_idx = word_idx - - event_entities = get_entities(event_labels) - logging.info(f"Extracted {len(event_entities)} events from the sentence") - event_results.append( - [ - Event(dtr=label[0], begin=label[1], end=label[2]) - for label in event_entities - ] - ) - timex_results.append([]) - rel_results.append([]) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=rel_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results - - -@app.post("/temporal/collection_process_complete") -async def collection_process_complete(): - global trainer - trainer = None diff --git a/src/cnlpt/api/hier_rest.py b/src/cnlpt/api/hier_rest.py deleted file mode 100644 index 66711866..00000000 --- a/src/cnlpt/api/hier_rest.py +++ /dev/null @@ -1,104 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -import os -from contextlib import asynccontextmanager - -import torch -from fastapi import FastAPI -from transformers import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - UnannotatedDocument, - create_dataset, - initialize_hier_model, - resolve_device, -) - -MODEL_NAME = os.getenv("MODEL_PATH") - -device = os.getenv("MODEL_DEVICE", "auto") -device = resolve_device(device) - -logger = logging.getLogger("HierRep_REST_Processor") -logger.setLevel(logging.DEBUG) - -tokenizer: PreTrainedTokenizer -model: PreTrainedModel - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, model - tokenizer, model = initialize_hier_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/hier/get_rep") -async def get_representation(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, - tokenizer, - max_length=16000, - hier=True, - chunk_len=200, - num_chunks=80, - insert_empty_chunk_at_beginning=False, - ) - - result = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device), - token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device), - output_hidden_states=True, - ) - - # Convert to a list so python can send it out - 
hidden_states = result["hidden_states"].to("cpu").detach().numpy()[:, 0, :].tolist() - return {"reps": hidden_states[0]} - - -@app.post("/hier/classify") -async def classify(doc: UnannotatedDocument): - instances = [doc.doc_text] - dataset = create_dataset( - instances, - tokenizer, - max_length=16000, - hier=True, - chunk_len=200, - num_chunks=80, - insert_empty_chunk_at_beginning=False, - ) - result = model.forward( - input_ids=torch.LongTensor(dataset["input_ids"]).to(model.device), - token_type_ids=torch.LongTensor(dataset["token_type_ids"]).to(model.device), - attention_mask=torch.LongTensor(dataset["attention_mask"]).to(model.device), - output_hidden_states=False, - ) - - predictions = [ - int(torch.argmax(logits.to("cpu").detach()).numpy()) - for logits in result["logits"] - ] - labels = [next(iter(model.label_dictionary.values()))[x] for x in predictions] - return {"result": labels} diff --git a/src/cnlpt/api/negation_rest.py b/src/cnlpt/api/negation_rest.py deleted file mode 100644 index e4ba78b2..00000000 --- a/src/cnlpt/api/negation_rest.py +++ /dev/null @@ -1,112 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "mlml-chip/negation_pubmedbert_sharpseed" -logger = logging.getLogger("Negation_REST_Processor") -logger.setLevel(logging.DEBUG) - -TASK = "Negation" -LABELS = [-1, 1] - -MAX_LENGTH = 128 - - -class NegationResults(BaseModel): - """statuses: dictionary from entity id to classification decision about negation; true -> negated, false -> not negated""" - - statuses: list[int] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/negation/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return NegationResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - # logger.debug('Entity ind: %d has offsets (%d, %d)' % (ent_ind, offsets[0], offsets[1])) - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = 
NegationResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output - - -@app.get("/negation/{test_str}") -async def test(test_str: str): - return {"argument": test_str} diff --git a/src/cnlpt/api/temporal_rest.py b/src/cnlpt/api/temporal_rest.py deleted file mode 100644 index 7b23e495..00000000 --- a/src/cnlpt/api/temporal_rest.py +++ /dev/null @@ -1,373 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import logging -import os -from contextlib import asynccontextmanager -from time import time -from typing import Union - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from pydantic import BaseModel -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "mlml-chip/thyme2_colon_e2e" -logger = logging.getLogger("Temporal_REST_Processor") -logger.setLevel(logging.INFO) - -LABELS = ["-1", "1"] -TIMEX_LABEL_LIST = [ - "O", - "B-DATE", - "B-DURATION", - "B-PREPOSTEXP", - "B-QUANTIFIER", - "B-SET", - "B-TIME", - "B-SECTIONTIME", - "B-DOCTIME", - "I-DATE", - "I-DURATION", - "I-PREPOSTEXP", - "I-QUANTIFIER", - "I-SET", - "I-TIME", - "I-SECTIONTIME", - "I-DOCTIME", -] -TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)} -EVENT_LABEL_LIST = [ - "O", - "B-AFTER", - "B-BEFORE", - "B-BEFORE/OVERLAP", - "B-OVERLAP", - "I-AFTER", - "I-BEFORE", - "I-BEFORE/OVERLAP", - "I-OVERLAP", -] -EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)} - -RELATION_LABEL_LIST = ["None", "CONTAINS", "OVERLAP", "BEFORE", "BEGINS-ON", "ENDS-ON"] -RELATION_LABEL_DICT = {val: ind for ind, val in enumerate(RELATION_LABEL_LIST)} - -DTR_LABEL_LIST = ["AFTER", "BEFORE", "BEFORE/OVERLAP", "OVERLAP"] -OLD_DTR_LABEL_LIST = ["BEFORE", "OVERLAP", "BEFORE/OVERLAP", "AFTER"] - -LABELS = [TIMEX_LABEL_LIST, EVENT_LABEL_LIST, RELATION_LABEL_LIST] -MAX_LENGTH = 128 - - -class SentenceDocument(BaseModel): - sentence: str - - -class TokenizedSentenceDocument(BaseModel): - """sent_tokens: a list of sentences, where each sentence is a list of tokens""" - - sent_tokens: list[list[str]] - metadata: str - - -class Timex(BaseModel): - begin: int - end: int - timeClass: str - - -class Event(BaseModel): - begin: int - end: int - dtr: str - - 
-class Relation(BaseModel): - # Allow args to be none, so that we can potentially link them to times or events in the client, or if they don't - # care about that. pass back the token indices of the args in addition. - arg1: Union[str, None] - arg2: Union[str, None] - category: str - arg1_start: int - arg2_start: int - - -class TemporalResults(BaseModel): - """lists of timexes, events and relations for list of sentences""" - - timexes: list[list[Timex]] - events: list[list[Event]] - relations: list[list[Relation]] - - -def create_instance_string(tokens: list[str]): - return " ".join(tokens) - - -task_order: dict[str, int] -tasks: list[str] -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global \ - TIMEX_LABEL_LIST, \ - TIMEX_LABEL_DICT, \ - EVENT_LABEL_LIST, \ - EVENT_LABEL_DICT, \ - RELATION_LABEL_LIST, \ - RELATION_LABEL_DICT, \ - task_order, \ - tasks, \ - tokenizer, \ - trainer - - local_model_name = os.getenv("MODEL_NAME", MODEL_NAME) - tokenizer, trainer = initialize_cnlpt_model(local_model_name) - - config_dict = trainer.model.config.to_dict() - # For newer models (version >= 0.6.0), the label dictionary is saved with the model - # config. we can look for it to preserve backwards compatibility for now but - # should eventually remove the hardcoded label lists from our inference tools. 
- label_dict = config_dict.get("label_dictionary", None) - if label_dict is not None: - # some older versions have one label dictionary per dataset, future versions should just - # have a task-keyed dictionary - if type(label_dict) is list: - label_dict = label_dict[0] - - if "event" in label_dict: - EVENT_LABEL_LIST = label_dict["event"] - EVENT_LABEL_DICT = {val: ind for ind, val in enumerate(EVENT_LABEL_LIST)} - print(EVENT_LABEL_LIST) - - if "timex" in label_dict: - TIMEX_LABEL_LIST = label_dict["timex"] - TIMEX_LABEL_DICT = {val: ind for ind, val in enumerate(TIMEX_LABEL_LIST)} - print(TIMEX_LABEL_LIST) - - if "tlinkx" in label_dict: - RELATION_LABEL_LIST = label_dict["tlinkx"] - RELATION_LABEL_DICT = { - val: ind for ind, val in enumerate(RELATION_LABEL_LIST) - } - print(RELATION_LABEL_LIST) - - tasks = config_dict.get("finetuning_task", None) - task_order = {} - if tasks is not None: - print("Overwriting finetuning task order") - for task_ind, task_name in enumerate(tasks): - task_order[task_name] = task_ind - print(task_order) - else: - print("Didn't find a new task ordering in the model config") - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - print(EVENT_LABEL_LIST) - print(TIMEX_LABEL_LIST) - print(RELATION_LABEL_LIST) - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - 
inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = trainer.predict(test_dataset=dataset) - - timex_predictions = np.argmax(output.predictions[task_order["timex"]], axis=2) - event_predictions = np.argmax(output.predictions[task_order["event"]], axis=2) - rel_predictions = np.argmax(output.predictions[task_order["tlinkx"]], axis=3) - rel_inds = np.where(rel_predictions != RELATION_LABEL_DICT["None"]) - - logging.debug(f"Found relation indices: {rel_inds!s}") - - rels_by_sent = {} - for rel_num in range(len(rel_inds[0])): - sent_ind = rel_inds[0][rel_num] - if sent_ind not in rels_by_sent: - rels_by_sent[sent_ind] = [] - - arg1_ind = rel_inds[1][rel_num] - arg2_ind = rel_inds[2][rel_num] - if arg1_ind == arg2_ind: - # no relations between an entity and itself - logger.warning("Found relation between an entity and itself... 
skipping") - continue - - rel_cat = rel_predictions[sent_ind, arg1_ind, arg2_ind] - - rels_by_sent[sent_ind].append((arg1_ind, arg2_ind, rel_cat)) - - pred_end = time() - - timex_results = [] - event_results = [] - rel_results = [] - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - timex_labels = [] - event_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - # tokeni_to_wpi[val] = key - timex_labels.append( - TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]] - ) - try: - event_labels.append( - EVENT_LABEL_LIST[event_predictions[sent_ind][word_pos_idx]] - ) - except Exception as e: - print( - f"exception thrown when sent_ind={sent_ind} and word_pos_idx={word_pos_idx}" - ) - print( - f"prediction is {event_predictions[sent_ind][word_pos_idx]!s}" - ) - raise e - - previous_word_idx = word_idx - - timex_entities = get_entities(timex_labels) - logging.info( - f"Extracted {len(timex_entities)} timex entities from the sentence" - ) - timex_results.append( - [ - Timex(timeClass=label[0], begin=label[1], end=label[2]) - for label in timex_entities - ] - ) - - event_entities = get_entities(event_labels) - logging.info(f"Extracted {len(event_entities)} events from the sentence") - event_results.append( - [ - Event(dtr=label[0], begin=label[1], end=label[2]) - for label in event_entities - ] - ) - - rel_sent_results = [] - for rel in rels_by_sent.get(sent_ind, []): - arg1 = None - arg2 = None - if rel[0] not in wpind_to_ind or rel[1] not in wpind_to_ind: - logging.warning( - "Found a relation to a non-leading wordpiece token... 
ignoring" - ) - continue - - arg1_ind = wpind_to_ind[rel[0]] - arg2_ind = wpind_to_ind[rel[1]] - - sent_timexes = timex_results[-1] - for timex_ind, timex in enumerate(sent_timexes): - if timex.begin == arg1_ind: - arg1 = f"TIMEX-{timex_ind}" - if timex.begin == arg2_ind: - arg2 = f"TIMEX-{timex_ind}" - - sent_events = event_results[-1] - for event_ind, event in enumerate(sent_events): - if event.begin == arg1_ind: - arg1 = f"EVENT-{event_ind}" - if event.begin == arg2_ind: - arg2 = f"EVENT-{event_ind}" - - rel = Relation( - arg1=arg1, - arg2=arg2, - category=RELATION_LABEL_LIST[rel[2]], - arg1_start=arg1_ind, - arg2_start=arg2_ind, - ) - rel_sent_results.append(rel) - - rel_results.append(rel_sent_results) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=rel_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results diff --git a/src/cnlpt/api/termexists_rest.py b/src/cnlpt/api/termexists_rest.py deleted file mode 100644 index a6024513..00000000 --- a/src/cnlpt/api/termexists_rest.py +++ /dev/null @@ -1,106 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from pydantic import BaseModel -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .utils import ( - EntityDocument, - create_dataset, - create_instance_string, - initialize_cnlpt_model, -) - -MODEL_NAME = "mlml-chip/sharpseed-termexists" -logger = logging.getLogger("TermExists_REST_Processor") -logger.setLevel(logging.DEBUG) - -TASK = "TermExists" -LABELS = [-1, 1] - -MAX_LENGTH = 128 - - -class TermExistsResults(BaseModel): - """statuses: list of classifier outputs for every input""" - - statuses: list[int] - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/termexists/process") -async def process(doc: EntityDocument): - doc_text = doc.doc_text - logger.warning( - f"Received document of len {len(doc_text)} to process with {len(doc.entities)} entities" - ) - instances = [] - start_time = time() - - if len(doc.entities) == 0: - return TermExistsResults(statuses=[]) - - for ent_ind, offsets in enumerate(doc.entities): - inst_str = create_instance_string(doc_text, offsets) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, MAX_LENGTH) - preproc_end = time() - - output = 
trainer.predict(test_dataset=dataset) - predictions = output.predictions[0] - predictions = np.argmax(predictions, axis=1) - - pred_end = time() - - results = [] - for ent_ind in range(len(dataset)): - results.append(LABELS[predictions[ent_ind]]) - - output = TermExistsResults(statuses=results) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.warning( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return output diff --git a/src/cnlpt/api/timex_rest.py b/src/cnlpt/api/timex_rest.py deleted file mode 100644 index bd44778b..00000000 --- a/src/cnlpt/api/timex_rest.py +++ /dev/null @@ -1,156 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import logging -from contextlib import asynccontextmanager -from time import time - -import numpy as np -from fastapi import FastAPI -from nltk.tokenize import wordpunct_tokenize as tokenize -from seqeval.metrics.sequence_labeling import get_entities -from transformers import Trainer -from transformers.tokenization_utils import PreTrainedTokenizer - -from .temporal_rest import ( - TIMEX_LABEL_LIST, - SentenceDocument, - TemporalResults, - Timex, - TokenizedSentenceDocument, - create_instance_string, -) -from .utils import create_dataset, initialize_cnlpt_model - -MODEL_NAME = "tmills/timex-thyme-colon-pubmedbert" -logger = logging.getLogger("Timex_REST_Processor") -logger.setLevel(logging.INFO) - -MAX_LENGTH = 128 - - -tokenizer: PreTrainedTokenizer -trainer: Trainer - - -@asynccontextmanager -async def lifespan(app: FastAPI): - global tokenizer, trainer - tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME) - yield - - -app = FastAPI(lifespan=lifespan) - - -@app.post("/temporal/process") -async def process(doc: TokenizedSentenceDocument): - return process_tokenized_sentence_document(doc) - - -@app.post("/temporal/process_sentence") -async def process_sentence(doc: SentenceDocument): - tokenized_sent = tokenize(doc.sentence) - doc = TokenizedSentenceDocument( - sent_tokens=[ - tokenized_sent, - ], - metadata="Single sentence", - ) - return process_tokenized_sentence_document(doc) - - -def process_tokenized_sentence_document(doc: TokenizedSentenceDocument): - sents = doc.sent_tokens - metadata = doc.metadata - - logger.warning(f"Received document labeled {metadata} with {len(sents)} sentences") - instances = [] - start_time = time() - - for sent_ind, token_list in enumerate(sents): - inst_str = create_instance_string(token_list) - logger.debug(f"Instance string is {inst_str}") - instances.append(inst_str) - - dataset = create_dataset(instances, tokenizer, max_length=MAX_LENGTH) - logger.warning(f"Dataset is as follows: {dataset.features!s}") - - preproc_end = 
time() - - output = trainer.predict(test_dataset=dataset) - - timex_predictions = np.argmax(output.predictions[0], axis=2) - - timex_results = [] - event_results = [] - relation_results = [] - - pred_end = time() - - for sent_ind in range(len(dataset)): - batch_encoding = tokenizer.batch_encode_plus( - [ - sents[sent_ind], - ], - is_split_into_words=True, - max_length=MAX_LENGTH, - ) - word_ids = batch_encoding.word_ids(0) - wpind_to_ind = {} - timex_labels = [] - previous_word_idx = None - - for word_pos_idx, word_idx in enumerate(word_ids): - if word_idx != previous_word_idx and word_idx is not None: - key = word_pos_idx - val = len(wpind_to_ind) - - wpind_to_ind[key] = val - timex_labels.append( - TIMEX_LABEL_LIST[timex_predictions[sent_ind][word_pos_idx]] - ) - previous_word_idx = word_idx - - timex_entities = get_entities(timex_labels) - logging.info( - f"Extracted {len(timex_entities)} timex entities from the sentence" - ) - timex_results.append( - [ - Timex(timeClass=label[0], begin=label[1], end=label[2]) - for label in timex_entities - ] - ) - event_results.append([]) - relation_results.append([]) - - results = TemporalResults( - timexes=timex_results, events=event_results, relations=relation_results - ) - - postproc_end = time() - - preproc_time = preproc_end - start_time - pred_time = pred_end - preproc_end - postproc_time = postproc_end - pred_end - - logging.info( - f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}" - ) - - return results diff --git a/src/cnlpt/api/utils.py b/src/cnlpt/api/utils.py deleted file mode 100644 index c756ef70..00000000 --- a/src/cnlpt/api/utils.py +++ /dev/null @@ -1,157 +0,0 @@ -import logging -import os -from typing import Literal, cast - -import torch -from datasets import Dataset -from pydantic import BaseModel -from transformers.hf_argparser import HfArgumentParser - -# Modeling imports -from transformers.models.auto.configuration_auto import AutoConfig 
-from transformers.models.auto.modeling_auto import AutoModel -from transformers.models.auto.tokenization_auto import AutoTokenizer -from transformers.tokenization_utils import PreTrainedTokenizer -from transformers.trainer import Trainer -from transformers.training_args import TrainingArguments - -from ..data.preprocess import preprocess_raw_data -from ..models import CnlpConfig - - -class UnannotatedDocument(BaseModel): - doc_text: str - - -class EntityDocument(BaseModel): - """doc_text: The raw text of the document - offset: A list of entities, where each is a tuple of character offsets into doc_text for that entity - """ - - doc_text: str - entities: list[list[int]] - - -def create_dataset( - inst_list: list[str], - tokenizer: PreTrainedTokenizer, - max_length: int = 128, - hier: bool = False, - chunk_len: int = 200, - num_chunks: int = 40, - insert_empty_chunk_at_beginning: bool = False, -): - """Use a tokenizer to create a dataset from a list of strings.""" - dataset = Dataset.from_dict({"text": inst_list}) - task_dataset = dataset.map( - preprocess_raw_data, - batched=True, - load_from_cache_file=False, - desc="Running tokenizer on dataset, organizing labels, creating hierarchical segments if necessary", - batch_size=100, - fn_kwargs={ - "tokenizer": tokenizer, - "tasks": None, - "max_length": max_length, - "inference_only": True, - "hierarchical": hier, - # TODO: need to get this from the model if necessary - "chunk_len": chunk_len, - "num_chunks": num_chunks, - "insert_empty_chunk_at_beginning": insert_empty_chunk_at_beginning, - }, - ) - return task_dataset - - -def create_instance_string(doc_text: str, offsets: list[int]): - start = max(0, offsets[0] - 100) - end = min(len(doc_text), offsets[1] + 100) - raw_str = ( - doc_text[start : offsets[0]] - + " " - + doc_text[offsets[0] : offsets[1]] - + " " - + doc_text[offsets[1] : end] - ) - return raw_str.replace("\n", " ") - - -def resolve_device( - device: str, -) -> Literal["cuda", "mps", "cpu"]: - device = 
device.lower() - if device not in ("cuda", "mps", "cpu", "auto"): - raise ValueError(f"invalid device {device}") - if device == "auto": - if torch.cuda.is_available(): - device = "cuda" - elif torch.mps.is_available(): - device = "mps" - else: - device = "cpu" - elif device == "cuda" and not torch.cuda.is_available(): - logging.warning( - "Device is set to 'cuda' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." - ) - device = "cpu" - elif device == "mps" and not torch.mps.is_available(): - logging.warning( - "Device is set to 'mps' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." - ) - device = "cpu" - return device - - -def initialize_cnlpt_model( - model_name, - device: Literal["cuda", "mps", "cpu", "auto"] = "auto", - batch_size=8, -): - args = [ - "--output_dir", - "save_run/", - "--per_device_eval_batch_size", - str(batch_size), - "--do_predict", - "--report_to", - "none", - ] - parser = HfArgumentParser((TrainingArguments,)) - training_args = cast( - TrainingArguments, parser.parse_args_into_dataclasses(args=args)[0] - ) - - if torch.mps.is_available(): - # pin_memory is unsupported on MPS, but defaults to True, - # so we'll explicitly turn it off to avoid a warning. 
- training_args.dataloader_pin_memory = False - - config = AutoConfig.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name, config=config) - model = AutoModel.from_pretrained( - model_name, cache_dir=os.getenv("HF_CACHE"), config=config - ) - - model = model.to(resolve_device(device)) - - trainer = Trainer(model=model, args=training_args) - - return tokenizer, trainer - - -def initialize_hier_model( - model_name, - device: Literal["cuda", "mps", "cpu", "auto"] = "auto", -): - config: CnlpConfig = AutoConfig.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name, config=config) - - model = AutoModel.from_pretrained( - model_name, cache_dir=os.getenv("HF_CACHE"), config=config - ) - model.train(False) - - model = model.to(resolve_device(device)) - - return tokenizer, model diff --git a/src/cnlpt/data/analysis.py b/src/cnlpt/data/analysis.py index d2591a33..f83deefd 100644 --- a/src/cnlpt/data/analysis.py +++ b/src/cnlpt/data/analysis.py @@ -175,7 +175,8 @@ def make_preds_df( Returns: The DataFrame for analysis. 
""" - seq_len = len(predictions.input_data["input_ids"][0]) + + seq_len = len(predictions.input_data["word_ids"][0]) df_data = { "sample_idx": list(range(len(predictions.input_data))), @@ -195,37 +196,53 @@ def make_preds_df( else: tasks = predictions.tasks + unlabeled = predictions.raw.label_ids is None + for task in tasks: task_pred = predictions.task_predictions[task.name] - df = df.with_columns( - pl.struct( - labels=pl.struct( + + fields = [] + if not unlabeled: + fields.append( + pl.struct( ids=task_pred.labels, values=task_pred.target_str_labels, - ), - predictions=pl.struct( + ).alias("labels") + ) + + fields.extend( + [ + pl.struct( ids=task_pred.predicted_int_labels, values=task_pred.predicted_str_labels, - ), - model_output=pl.struct( + ).alias("predictions"), + pl.struct( logits=task_pred.logits, probs=task_pred.probs, - ), - ).alias(task.name) + ).alias("model_output"), + ] ) + df = df.with_columns(pl.struct(fields).alias(task.name)) + if task.type == CLASSIFICATION: # classification output is already pretty human-interpretable pass elif task.type == TAGGING: # for tagging, we'll convert BIO tags to labeled spans - df = df.join( - _bio_tags_to_spans( - df, pl.col(task.name).struct.field("labels").struct.field("values") - ), - on="sample_idx", - how="left", - ).rename({"spans": "target_spans"}) + tagging_fields = [] + if not unlabeled: + df = df.join( + _bio_tags_to_spans( + df, + pl.col(task.name).struct.field("labels").struct.field("values"), + ), + on="sample_idx", + how="left", + ).rename({"spans": "target_spans"}) + tagging_fields.append( + pl.field("labels").struct.with_fields(spans="target_spans") + ) df = df.join( _bio_tags_to_spans( @@ -238,20 +255,27 @@ def make_preds_df( how="left", ).rename({"spans": "predicted_spans"}) + tagging_fields.append( + pl.field("predictions").struct.with_fields(spans="predicted_spans") + ) + df = df.with_columns( - pl.col(task.name).struct.with_fields( - pl.field("labels").struct.with_fields(spans="target_spans"), 
- pl.field("predictions").struct.with_fields(spans="predicted_spans"), - ) - ).drop("target_spans", "predicted_spans") + pl.col(task.name).struct.with_fields(tagging_fields) + ).drop("target_spans", "predicted_spans", strict=False) elif task.type == RELATIONS: - df = df.join( - _rel_matrix_to_rels( - df, pl.col(task.name).struct.field("labels").struct.field("values") - ), - on="sample_idx", - how="left", - ).rename({"relations": "target_relations"}) + relations_fields = [] + if not unlabeled: + df = df.join( + _rel_matrix_to_rels( + df, + pl.col(task.name).struct.field("labels").struct.field("values"), + ), + on="sample_idx", + how="left", + ).rename({"relations": "target_relations"}) + relations_fields.append( + pl.field("labels").struct.with_fields(relations="target_relations") + ) df = df.join( _rel_matrix_to_rels( @@ -263,15 +287,15 @@ def make_preds_df( on="sample_idx", how="left", ).rename({"relations": "predicted_relations"}) + relations_fields.append( + pl.field("predictions").struct.with_fields( + relations="predicted_relations" + ) + ) df = df.with_columns( - pl.col(task.name).struct.with_fields( - pl.field("labels").struct.with_fields(relations="target_relations"), - pl.field("predictions").struct.with_fields( - relations="predicted_relations" - ), - ) - ).drop("target_relations", "predicted_relations") + pl.col(task.name).struct.with_fields(relations_fields) + ).drop("target_relations", "predicted_relations", strict=False) else: raise ValueError(f"unknown task type {task.type}") diff --git a/src/cnlpt/data/predictions.py b/src/cnlpt/data/predictions.py index c984be1a..b4c82aa9 100644 --- a/src/cnlpt/data/predictions.py +++ b/src/cnlpt/data/predictions.py @@ -6,11 +6,11 @@ import numpy as np import numpy.typing as npt +import polars as pl from datasets import Dataset from scipy.special import softmax from transformers.trainer_utils import PredictionOutput -from ..args.data_args import CnlpDataArguments from ..data.preprocess import MASK_VALUE from 
.task_info import CLASSIFICATION, TAGGING, TaskInfo @@ -47,7 +47,6 @@ class CnlpPredictions: input_data: Dataset raw: PredictionOutput tasks: list[TaskInfo] - data_args: CnlpDataArguments task_predictions: dict[str, TaskPredictions] @@ -56,19 +55,20 @@ def __init__( input_data: Dataset, raw_prediction: PredictionOutput, tasks: Iterable[TaskInfo], - data_args: CnlpDataArguments, + max_seq_length: int, ): self.input_data = input_data self.raw = raw_prediction self.tasks = sorted(tasks, key=lambda t: t.index) - self.data_args = data_args + self.max_seq_length = max_seq_length # task indices must start at zero and increase by 1 - assert all(idx == t.index for idx, t in enumerate(tasks)) + if not all(idx == t.index for idx, t in enumerate(tasks)): + raise RuntimeError("task indices should start at zero and increase by one") self.task_predictions: dict[str, TaskPredictions] = {} - task_labels: dict[str, npt.NDArray] + task_labels: dict[str, Union[npt.NDArray, None]] if self.raw.label_ids is None: task_labels = {t.name: None for t in tasks} @@ -98,15 +98,17 @@ def __init__( offset += 1 else: # task.type == RELATIONS task_labels[task.name] = self.raw.label_ids[ - :, :, offset : offset + self.data_args.max_seq_length + :, :, offset : offset + self.max_seq_length ].astype(int) - offset += self.data_args.max_seq_length + offset += self.max_seq_length self.task_predictions = { t.name: TaskPredictions( task=t, logits=self.raw.predictions[t.index], - labels=task_labels[t.name].squeeze(), + labels=task_labels[t.name].squeeze() + if task_labels[t.name] is not None + else None, ) for t in tasks } @@ -128,7 +130,7 @@ def arr_to_list(obj): "metrics": self.raw.metrics, }, "tasks": [asdict(t) for t in self.tasks], - "data_args": asdict(self.data_args), + "max_seq_length": self.max_seq_length, } def save_json( @@ -155,16 +157,29 @@ def list_to_arr(obj, dtype): metrics=data["raw"]["metrics"], ) tasks = [TaskInfo(**t) for t in data["tasks"]] - data_args = 
CnlpDataArguments(**data["data_args"]) + max_seq_length = data["max_seq_length"] return cls( input_data=input_data, raw_prediction=raw, tasks=tasks, - data_args=data_args, + max_seq_length=max_seq_length, ) @classmethod def load_json(cls, filepath: Union[str, os.PathLike]): with open(filepath) as f: return cls.from_dict(json.load(f)) + + @property + def metrics(self): + return {k.removeprefix("test_"): v for k, v in self.raw.metrics.items()} + + def metrics_df(self): + metrics = self.metrics.items() + return pl.DataFrame( + { + "metric": [m for m, v in metrics], + "value": [v for m, v in metrics], + } + ) diff --git a/src/cnlpt/rest/__init__.py b/src/cnlpt/rest/__init__.py new file mode 100644 index 00000000..9551ce58 --- /dev/null +++ b/src/cnlpt/rest/__init__.py @@ -0,0 +1,3 @@ +from .cnlp_rest import CnlpRestApp + +__all__ = ["CnlpRestApp"] diff --git a/src/cnlpt/rest/cnlp_rest.py b/src/cnlpt/rest/cnlp_rest.py new file mode 100644 index 00000000..add10fa3 --- /dev/null +++ b/src/cnlpt/rest/cnlp_rest.py @@ -0,0 +1,208 @@ +import logging +from collections.abc import Iterable +from typing import Union + +import polars as pl +import torch +from datasets import Dataset +from fastapi import APIRouter, FastAPI +from pydantic import BaseModel +from transformers.models.auto.modeling_auto import AutoModel +from transformers.models.auto.tokenization_auto import AutoTokenizer +from transformers.trainer import Trainer +from transformers.training_args import TrainingArguments +from typing_extensions import Self + +from ..data.analysis import make_preds_df +from ..data.cnlp_dataset import HierarchicalDataConfig +from ..data.predictions import CnlpPredictions +from ..data.preprocess import preprocess_raw_data +from ..data.task_info import CLASSIFICATION, RELATIONS, TAGGING, TaskInfo +from ..modeling.config.hierarchical_config import HierarchicalModelConfig +from ..modeling.load import try_load_config + + +class InputDocument(BaseModel): + text: str + entity_spans: 
Union[list[tuple[int, int]], None] = None + + def to_text_list(self): + if self.entity_spans is None: + return [self.text] + + text_list: list[str] = [] + for entity_start, entity_end in self.entity_spans: + start = max(0, entity_start - 100) + end = min(len(self.text), entity_end + 100) + text_list.append( + "".join( + [ + self.text[start:entity_start], + "", + self.text[entity_start:entity_end], + "", + self.text[entity_end:end], + ] + ) + ) + return text_list + + +class CnlpRestApp: + def __init__(self, model_path: str, device: str = "auto"): + self.model_path = model_path + self.resolve_device(device) + self.setup_logger(logging.INFO) + self.load_model() + + def resolve_device(self, device: str): + self.device = device.lower() + if self.device == "auto": + if torch.cuda.is_available(): + self.device = "cuda" + elif torch.mps.is_available(): + self.device = "mps" + else: + self.device = "cpu" + else: + try: + torch.tensor([1.0], device=self.device) + except: # noqa: E722 + self.logger.warning( + f"Device is set to '{self.device}' but was not available; setting to 'cpu' and proceeding. If you have a GPU you need to debug why pytorch cannot see it." + ) + self.device = "cpu" + + def setup_logger(self, log_level): + self.logger = logging.getLogger(self.__module__) + self.logger.setLevel(log_level) + + def load_model(self): + training_args = TrainingArguments( + output_dir="cnlp_rest/", + save_strategy="no", + per_device_eval_batch_size=8, + do_predict=True, + ) + + if self.device == "mps": + # pin_memory is unsupported on MPS, but defaults to True, + # so we'll explicitly turn it off to avoid a warning. 
+ training_args.dataloader_pin_memory = False + + self.config = try_load_config(self.model_path) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) + self.model = AutoModel.from_pretrained( + self.model_path, + config=self.config, + ).to(self.device) + self.trainer = Trainer(model=self.model, args=training_args) + + self.tasks: list[TaskInfo] = self.config.tasks + + def create_prediction_dataset( + self, + text: list[str], + max_seq_length: int = 128, + hier_data_config: Union[HierarchicalDataConfig, None] = None, + ): + dataset = Dataset.from_dict({"text": text}) + + return dataset.map( + preprocess_raw_data, + batched=True, + load_from_cache_file=False, + desc="Preprocessing raw input", + batch_size=100, + fn_kwargs={ + "inference_only": True, + "tokenizer": self.tokenizer, + "tasks": None, + "max_length": max_seq_length, + "hier_config": hier_data_config, + }, + ) + + def predict(self, dataset: Dataset, max_seq_length: int): + raw_predictions = self.trainer.predict(dataset) + return CnlpPredictions( + dataset, + raw_prediction=raw_predictions, + tasks=self.tasks, + max_seq_length=max_seq_length, + ) + + def format_predictions(self, predictions: CnlpPredictions): + df = make_preds_df(predictions).select(["text", *[t.name for t in self.tasks]]) + + for task in self.tasks: + if task.type == CLASSIFICATION: + df = df.with_columns( + pl.struct( + prediction=pl.col(task.name) + .struct.field("predictions") + .struct.field("values"), + probs=pl.col(task.name) + .struct.field("model_output") + .struct.field("probs") + .arr.to_struct(fields=task.labels), + ).alias(task.name) + ) + elif task.type == TAGGING: + df = df.with_columns( + pl.struct( + pl.col(task.name) + .struct.field("predictions") + .struct.field("spans") + ).alias(task.name) + ) + elif task.type == RELATIONS: + df = df.with_columns( + pl.struct( + pl.col(task.name) + .struct.field("predictions") + .struct.field("relations") + ).alias(task.name) + ) + + return df.to_dicts() + + def process( + 
self, + input_doc: InputDocument, + max_seq_length: int = 128, + chunk_len: Union[int, None] = None, + num_chunks: Union[int, None] = None, + prepend_empty_chunk: bool = False, + ): + if isinstance(self.config, HierarchicalModelConfig): + hier_data_config = HierarchicalDataConfig( + chunk_len=chunk_len, + num_chunks=num_chunks, + prepend_empty_chunk=prepend_empty_chunk, + ) + else: + hier_data_config = None + + dataset = self.create_prediction_dataset( + input_doc.to_text_list(), max_seq_length, hier_data_config + ) + predictions = self.predict(dataset, max_seq_length) + return self.format_predictions(predictions) + + def router(self, prefix: str = ""): + router = APIRouter(prefix=prefix) + router.add_api_route("/process", self.process, methods=["POST"]) + return router + + def fastapi(self, router_prefix: str = ""): + app = FastAPI() + app.include_router(self.router(prefix=router_prefix)) + return app + + @classmethod + def multi_app(cls, apps: Iterable[tuple[Self, str]]): + multi_app = FastAPI() + for app, router_prefix in apps: + multi_app.include_router(app.router(router_prefix)) + return multi_app From 0c07ce96da46cc4aead85ee6816105b341d993b1 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:16:26 -0400 Subject: [PATCH 5/8] update tests --- test/api/test_api.py | 156 ++++++++++++++++++++------------- test/common/fixtures.py | 115 ++++++++++++++---------- test/data/test_analysis.py | 2 +- test/data/test_cnlp_dataset.py | 20 ++--- test/test_init.py | 60 ++----------- 5 files changed, 175 insertions(+), 178 deletions(-) diff --git a/test/api/test_api.py b/test/api/test_api.py index 6b9a3a15..38ccf240 100644 --- a/test/api/test_api.py +++ b/test/api/test_api.py @@ -5,94 +5,124 @@ import pytest from fastapi.testclient import TestClient -from cnlpt.api.utils import EntityDocument +from cnlpt.rest.cnlp_rest import CnlpRestApp, InputDocument class TestNegation: @pytest.fixture def test_client(self): - from cnlpt.api.negation_rest import app - - with 
TestClient(app) as client: + with TestClient( + CnlpRestApp("mlml-chip/negation_pubmedbert_sharpseed").fastapi() + ) as client: yield client def test_negation_startup(self, test_client): pass def test_negation_process(self, test_client: TestClient): - from cnlpt.api.negation_rest import NegationResults - - doc = EntityDocument( - doc_text="The patient has a sore knee and headache " - "but denies nausea and has no anosmia.", - entities=[[18, 27], [32, 40], [52, 58], [70, 77]], + doc = InputDocument( + text="The patient has a sore knee and headache but denies nausea and has no anosmia.", + entity_spans=[(18, 27), (32, 40), (52, 58), (70, 77)], ) - response = test_client.post("/negation/process", content=doc.json()) + response = test_client.post("/process", content=doc.json()) response.raise_for_status() - assert response.json() == NegationResults.parse_obj( - {"statuses": [-1, -1, 1, 1]} - ) + response_json = response.json() + assert [record["Negation"]["prediction"] for record in response_json] == [ + "-1", # sore knee (not negated) + "-1", # headache (not negated) + "1", # nausea (negated) + "1", # anosmia (negated) + ] class TestTemporal: @pytest.fixture def test_client(self): - from cnlpt.api.temporal_rest import app - - with TestClient(app) as client: + with TestClient(CnlpRestApp("mlml-chip/thyme2_colon_e2e").fastapi()) as client: yield client def test_temporal_startup(self, test_client: TestClient): pass def test_temporal_process_sentence(self, test_client: TestClient): - from cnlpt.api.temporal_rest import ( - SentenceDocument, - TemporalResults, + doc = InputDocument( + text="The patient was diagnosed with adenocarcinoma March 3, 2010 and will be returning for chemotherapy next week." 
+ ) + response = test_client.post("/process", content=doc.json()) + response.raise_for_status() + response_json = response.json() + assert all( + span in response_json[0]["timex"]["spans"] + for span in [ + { + "text": "March 3, 2010 ", + "tag": "DATE", + "start": 6, + "end": 8, + "valid": True, + }, + { + "text": "next week.", + "tag": "DATE", + "start": 15, + "end": 16, + "valid": True, + }, + ] ) - doc = SentenceDocument( - sentence="The patient was diagnosed with adenocarcinoma " - "March 3, 2010 and will be returning for " - "chemotherapy next week." + assert all( + span in response_json[0]["event"]["spans"] + for span in [ + { + "text": "diagnosed ", + "tag": "BEFORE", + "start": 3, + "end": 3, + "valid": True, + }, + { + "text": "adenocarcinoma ", + "tag": "BEFORE", + "start": 5, + "end": 5, + "valid": True, + }, + { + "text": "returning ", + "tag": "AFTER", + "start": 12, + "end": 12, + "valid": True, + }, + { + "text": "chemotherapy ", + "tag": "AFTER", + "start": 14, + "end": 14, + "valid": True, + }, + ] ) - response = test_client.post("/temporal/process_sentence", content=doc.json()) - response.raise_for_status() - out = response.json() - expected_out = TemporalResults.parse_obj( - { - "events": [ - [ - {"begin": 3, "dtr": "BEFORE", "end": 3}, - {"begin": 5, "dtr": "BEFORE", "end": 5}, - {"begin": 13, "dtr": "AFTER", "end": 13}, - {"begin": 15, "dtr": "AFTER", "end": 15}, - ] - ], - "relations": [ - [ - { - "arg1": "TIMEX-0", - "arg1_start": 6, - "arg2": "EVENT-0", - "arg2_start": 3, - "category": "CONTAINS", - }, - { - "arg1": "TIMEX-1", - "arg1_start": 16, - "arg2": "EVENT-2", - "arg2_start": 13, - "category": "CONTAINS", - }, - ] - ], - "timexes": [ - [ - {"begin": 6, "end": 9, "timeClass": "DATE"}, - {"begin": 16, "end": 17, "timeClass": "DATE"}, - ] - ], - } + + assert all( + span in response_json[0]["tlinkx"]["relations"] + for span in [ + [ + { + "arg1_wid": 6, + "arg1_text": "March", + "arg2_wid": 3, + "arg2_text": "diagnosed", + "label": 
"CONTAINS", + }, + { + "arg1_wid": 15, + "arg1_text": "next", + "arg2_wid": 12, + "arg2_text": "returning", + "label": "CONTAINS", + }, + ] + ] ) - assert out == expected_out diff --git a/test/common/fixtures.py b/test/common/fixtures.py index c93674d2..9fa21dc1 100644 --- a/test/common/fixtures.py +++ b/test/common/fixtures.py @@ -11,18 +11,16 @@ from lorem_text import lorem from transformers.models.auto.tokenization_auto import AutoTokenizer -from cnlpt.args import ( - CnlpDataArguments, - CnlpModelArguments, - CnlpTrainingArguments, - parse_args_dict, -) from cnlpt.data import TaskType -from cnlpt.train_system import CnlpTrainSystem +from cnlpt.data.cnlp_dataset import CnlpDataset +from cnlpt.modeling import ( + ModelType, +) +from cnlpt.train_system import CnlpTrainingArguments, CnlpTrainSystem @pytest.fixture(autouse=True) -def disable_mps(monkeypatch): +def disable_mps_for_ci(monkeypatch): """Disable MPS for CI""" if os.getenv("CI", False): monkeypatch.setattr("torch._C._mps_is_available", lambda: False) @@ -144,52 +142,75 @@ def random_cnlp_data_options( ) -@pytest.fixture -def cnlp_args(request): - marker: pytest.Mark = request.node.get_closest_marker("cnlp_args") - custom_args = marker.kwargs if marker is not None else {} +def custom_cnlp_train_args(**kwargs): + return pytest.mark.cnlp_train_args(**kwargs) - with tempfile.TemporaryDirectory(prefix="cnlp_output_dir") as out_dir: - yield parse_args_dict( - { - "data_dir": [], - "encoder_name": "roberta-base", - "output_dir": out_dir, - "overwrite_output_dir": True, - "do_train": True, - "do_eval": True, - "do_predict": True, - "evals_per_epoch": 1, - "num_train_epochs": 1, - "learning_rate": 1e-5, - "report_to": None, - "save_strategy": "best", - "rich_display": False, - "allow_disjoint_labels": True, - **custom_args, - } - ) +def custom_cnlp_dataset_args(**kwargs): + return pytest.mark.cnlp_dataset_args(**kwargs) -def custom_cnlp_args(**kwargs): - return pytest.mark.cnlp_args(**kwargs) + +def 
custom_cnlp_model_config(model_type: ModelType, **kwargs): + return pytest.mark.cnlp_model_config(model_type=model_type, **kwargs) + + +def _get_marker_kwargs( + request: pytest.FixtureRequest, marker_name: str +) -> dict[str, Any]: + marker: pytest.Mark = request.node.get_closest_marker(marker_name) + return marker.kwargs if marker is not None else {} @pytest.fixture def random_cnlp_train_system( + request, random_cnlp_data_dir, - cnlp_args: tuple[CnlpModelArguments, CnlpDataArguments, CnlpTrainingArguments], ): - model_args, data_args, training_args = cnlp_args - data_args.data_dir = [random_cnlp_data_dir] - yield CnlpTrainSystem( - model_args=model_args, data_args=data_args, training_args=training_args - ) + custom_train_args = _get_marker_kwargs(request, "cnlp_train_args") + custom_dataset_args = _get_marker_kwargs(request, "cnlp_dataset_args") + custom_model_config_args = _get_marker_kwargs(request, "cnlp_model_config_args") + + with tempfile.TemporaryDirectory(prefix="cnlp_output_dir") as out_dir: + train_args = CnlpTrainingArguments( + **( + dict( + output_dir=out_dir, + overwrite_output_dir=True, + do_train=True, + do_eval=True, + do_predict=True, + evals_per_epoch=1, + num_train_epochs=1, + rich_display=False, + ) + | custom_train_args + ), + ) + + dataset = CnlpDataset(random_cnlp_data_dir, **custom_dataset_args) + + model_type: ModelType = custom_model_config_args.pop( + "model_type", ModelType.PROJ + ) + + config = model_type.config_class( + tasks=list(dataset.tasks), + vocab_size=len(dataset.tokenizer), + **custom_model_config_args, + ) + + model = model_type.model_class(config) + + yield CnlpTrainSystem( + model=model, + dataset=dataset, + training_args=train_args, + ) - # we must close the logfile handler or tearing down the temporary output directory will fail - for handler in logging.root.handlers: - if isinstance(handler, logging.FileHandler) and handler.baseFilename.endswith( - "train_system.log" - ): - handler.close() - 
logging.root.removeHandler(handler) + # we must close the logfile handler or tearing down the temporary output directory will fail + for handler in logging.root.handlers: + if isinstance( + handler, logging.FileHandler + ) and handler.baseFilename.endswith("train_system.log"): + handler.close() + logging.root.removeHandler(handler) diff --git a/test/data/test_analysis.py b/test/data/test_analysis.py index c592a851..512e270a 100644 --- a/test/data/test_analysis.py +++ b/test/data/test_analysis.py @@ -13,7 +13,7 @@ ) def test_predict(random_cnlp_train_system): predictions = random_cnlp_train_system.predict() - seq_len = random_cnlp_train_system.data_args.max_seq_length + seq_len = random_cnlp_train_system.dataset.max_seq_length df = make_preds_df(predictions) assert df.schema == pl.Schema( { diff --git a/test/data/test_cnlp_dataset.py b/test/data/test_cnlp_dataset.py index c9ce089a..49201ec8 100644 --- a/test/data/test_cnlp_dataset.py +++ b/test/data/test_cnlp_dataset.py @@ -1,6 +1,5 @@ import numpy as np -from cnlpt.args import CnlpDataArguments from cnlpt.data import CnlpDataset from ..common.fixtures import random_cnlp_data_options @@ -13,8 +12,7 @@ n_dev=7, ) def test_create_random_dataset(tokenizer, random_cnlp_data_dir): - args = CnlpDataArguments([random_cnlp_data_dir]) - cnlp_dataset = CnlpDataset(args=args, tokenizer=tokenizer, hierarchical=False) + cnlp_dataset = CnlpDataset(random_cnlp_data_dir, tokenizer=tokenizer) assert len(cnlp_dataset.train_data) == 5 assert len(cnlp_dataset.test_data) == 6 assert len(cnlp_dataset.validation_data) == 7 @@ -29,14 +27,10 @@ def test_labels_shape_classification_only(tokenizer, random_cnlp_data_dir): generates labels with the shape (batch, n_tasks). 
""" batch_size = 3 - max_seq_len = 128 - args = CnlpDataArguments( - [random_cnlp_data_dir], - max_seq_length=max_seq_len, - overwrite_cache=True, + cnlp_dataset = CnlpDataset( + random_cnlp_data_dir, tokenizer=tokenizer, use_data_cache=False ) - cnlp_dataset = CnlpDataset(args=args, tokenizer=tokenizer, hierarchical=False) batch = next(cnlp_dataset.train_data.iter(batch_size)) labels = np.array(batch["label"]) @@ -56,12 +50,12 @@ def test_labels_shape_mixed_tasks(tokenizer, random_cnlp_data_dir): batch_size = 3 max_seq_len = 128 - args = CnlpDataArguments( - [random_cnlp_data_dir], + cnlp_dataset = CnlpDataset( + random_cnlp_data_dir, + tokenizer=tokenizer, max_seq_length=max_seq_len, - overwrite_cache=True, + use_data_cache=False, ) - cnlp_dataset = CnlpDataset(args=args, tokenizer=tokenizer, hierarchical=False) batch = next(cnlp_dataset.train_data.iter(batch_size)) labels = np.array(batch["label"]) diff --git a/test/test_init.py b/test/test_init.py index 12b83959..2a441740 100644 --- a/test/test_init.py +++ b/test/test_init.py @@ -1,3 +1,4 @@ +# ruff: noqa: F401 """ Test suite for initializing the library """ @@ -9,69 +10,20 @@ def test_init(): """ import cnlpt - assert cnlpt.__package__ == "cnlpt" - def test_init_models(): - import cnlpt.models - - assert cnlpt.models.__package__ == "cnlpt.models" - assert cnlpt.models.__all__ == [ - "CnlpConfig", - "CnlpModelForClassification", - "HierarchicalModel", - ] - - import cnlpt.models.baseline - - assert cnlpt.models.baseline.__package__ == "cnlpt.models.baseline" - assert cnlpt.models.baseline.__all__ == [ - "CnnSentenceClassifier", - "LstmSentenceClassifier", - ] + import cnlpt.modeling + import cnlpt.modeling.config + import cnlpt.modeling.models def test_init_train_system(): import cnlpt.train_system - assert cnlpt.train_system.__package__ == "cnlpt.train_system" - assert cnlpt.train_system.__all__ == ["CnlpTrainSystem"] - def test_init_data(): import cnlpt.data - assert cnlpt.data.__package__ == "cnlpt.data" - 
assert cnlpt.data.__all__ == [ - "CLASSIFICATION", - "RELATIONS", - "TAGGING", - "CnlpDataset", - "CnlpPredictions", - "TaskInfo", - "TaskType", - "get_task_type", - "preprocess_raw_data", - ] - - -def test_init_args(): - import cnlpt.args - - assert cnlpt.args.__package__ == "cnlpt.args" - assert cnlpt.args.__all__ == [ - "CnlpDataArguments", - "CnlpModelArguments", - "CnlpTrainingArguments", - "parse_args_dict", - "parse_args_from_argv", - "parse_args_json_file", - "preprocess_args", - ] - - -def test_init_api(): - import cnlpt.api - assert cnlpt.api.__package__ == "cnlpt.api" - assert cnlpt.api.__all__ == ["MODEL_TYPES", "get_rest_app"] +def test_init_rest(): + import cnlpt.rest From 315d12fea53da3d1ade141a6d80ceb304a5fb4f4 Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:18:40 -0400 Subject: [PATCH 6/8] update examples --- examples/.gitignore | 2 + examples/chemprot/README.md | 24 +- examples/chemprot/preprocess_chemprot.py | 28 +- examples/uci_drug/README.md | 118 ++++---- examples/uci_drug/prepare_data.py | 67 +++++ examples/uci_drug/transform_uci_drug.py | 116 -------- examples/uci_drug/uci_drug.ipynb | 332 +++++++++++++++++++++++ 7 files changed, 480 insertions(+), 207 deletions(-) create mode 100644 examples/.gitignore create mode 100644 examples/uci_drug/prepare_data.py delete mode 100644 examples/uci_drug/transform_uci_drug.py create mode 100644 examples/uci_drug/uci_drug.ipynb diff --git a/examples/.gitignore b/examples/.gitignore new file mode 100644 index 00000000..2985b620 --- /dev/null +++ b/examples/.gitignore @@ -0,0 +1,2 @@ +*/dataset/ +*/train_output/ diff --git a/examples/chemprot/README.md b/examples/chemprot/README.md index 3a40956f..9f29ea30 100644 --- a/examples/chemprot/README.md +++ b/examples/chemprot/README.md @@ -1,24 +1,22 @@ # Fine-tuning for tagging: End-to-end example -1. Preprocess the data with `uv run examples/chemprot/prepare_chemprot_dataset.py data/chemprot` +1. 
Preprocess the data with `uv run examples/chemprot/prepare_chemprot_dataset.py` -2. Fine-tune with something like: +2. Fine-tune for NER with something like: ```bash -cnlpt train \ - --task_name chemical_ner gene_ner \ - --data_dir data/chemprot \ - --encoder_name allenai/scibert_scivocab_uncased \ - --do_train \ - --do_eval \ - --cache_dir cache/ \ - --output_dir temp/ \ +uv run cnlpt train \ + --model_type proj \ + --encoder allenai/scibert_scivocab_uncased \ + --data_dir ./dataset \ + --task chemical_ner --task gene_ner \ + --output_dir ./train_output \ --overwrite_output_dir \ - --num_train_epochs 50 \ + --do_train --do_eval \ + --num_train_epochs 3 \ --learning_rate 2e-5 \ --lr_scheduler_type constant \ - --report_to none \ - --save_strategy no \ + --save_strategy best \ --gradient_accumulation_steps 1 \ --eval_accumulation_steps 10 \ --weight_decay 0.2 diff --git a/examples/chemprot/preprocess_chemprot.py b/examples/chemprot/preprocess_chemprot.py index 2f3aa4f2..cdb418ab 100644 --- a/examples/chemprot/preprocess_chemprot.py +++ b/examples/chemprot/preprocess_chemprot.py @@ -1,19 +1,24 @@ import bisect import itertools -import os import re from dataclasses import dataclass -from sys import argv -from typing import Any, Union +from pathlib import Path +from typing import Any import polars as pl from datasets import load_dataset from datasets.dataset_dict import Dataset, DatasetDict +from datasets.utils import disable_progress_bars, enable_progress_bars from rich.console import Console def load_chemprot_dataset(cache_dir="./cache") -> DatasetDict: - return load_dataset("bigbio/chemprot", "chemprot_full_source", cache_dir=cache_dir) + disable_progress_bars() + dataset = load_dataset( + "bigbio/chemprot", "chemprot_full_source", cache_dir=cache_dir + ) + enable_progress_bars() + return dataset def clean_text(text: str): @@ -156,25 +161,18 @@ def preprocess_data(split: Dataset): ) -def main(out_dir: Union[str, os.PathLike]): +if __name__ == "__main__": console = 
Console() - - if not os.path.isdir(out_dir): - os.mkdir(out_dir) + out_dir = Path(__file__).parent / "dataset" + out_dir.mkdir(exist_ok=True) with console.status("Loading dataset...") as st: dataset = load_chemprot_dataset() for split in ("train", "test", "validation"): st.update(f"Preprocessing {split} data...") preprocessed = preprocess_data(dataset[split]) - preprocessed.write_csv( - os.path.join(out_dir, f"{split}.tsv"), separator="\t" - ) + preprocessed.write_csv(out_dir / f"{split}.tsv", separator="\t") console.print( f"[green i]Preprocessed chemprot data saved to [repr.filename]{out_dir}[/]." ) - - -if __name__ == "__main__": - main(argv[1]) diff --git a/examples/uci_drug/README.md b/examples/uci_drug/README.md index 050e4c13..ea396506 100644 --- a/examples/uci_drug/README.md +++ b/examples/uci_drug/README.md @@ -1,74 +1,66 @@ -### Fine-tuning for classification: End-to-end example +# Drug Review Sentiment Classification -1. Download data from [Drug Reviews (Druglib.com) Data Set](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) to `data` folder and extract. Pay attention to their terms: - 1. only use the data for research purposes - 2. don't use the data for any commerical purposes - 3. don't distribute the data to anyone else - 4. cite us +## Jupyter notebook example -2. Run ```python examples/uci_drug/transform_uci_drug.py ``` to preprocess the data from the extract directory into a new directory. This will create {train,dev,test}.tsv in the processed directory specified, where the sentiment ratings have been collapsed into 3 categories. +See the [example notebook](./uci_drug.ipynb) for a step-by-step walkthrough of +how to use CNLPT to train a model for sentiment classification of drug reviews. -3. 
Fine-tune with something like: +## CLI example -```bash -cnlpt train \ - --data_dir \ - --task_name sentiment \ - --encoder_name roberta-base \ - --do_train \ - --do_eval \ - --cache_dir cache/ \ - --output_dir temp/ \ - --overwrite_output_dir \ - --evals_per_epoch 5 \ - --num_train_epochs 1 \ - --learning_rate 1e-5 \ - --report_to none \ - --metric_for_best_model eval_sentiment.avg_micro_f1 \ - --load_best_model_at_end \ - --save_strategy best -``` - -On our hardware, that command results in eval performance like the following: -```sentiment = {'acc': 0.7041800643086816, 'f1': [0.7916666666666666, 0.7228915662650603, 0.19444444444444442], 'acc_and_f1': [0.7479233654876741, 0.7135358152868709, 0.449312254376563], 'recall': [0.8216216216216217, 0.8695652173913043, 0.12280701754385964], 'precision': [0.7638190954773869, 0.6185567010309279, 0.4666666666666667]}``` - -#### Error Analysis for Classification - -If you run the above command with the `--error_analysis` flag, you can obtain the `dev` instances for which the model made an erroneous -prediction, organized by their original index in `dev` split, in the `eval_predictions...tsv` file in the `--output_dir` argument. -For us the first line of this file (after the header) is: - -``` - text sentiment -2 Benefits: helped aleviate whip lash symptoms Side effects: none that i noticed Overall comments: i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December Ground: Medium Predicted: High - -``` - -The number at the beginning of the line, 2, is the index of the instance in the `dev` split. The `text` column contains the text of the erroneous instances and the following columns are the tasks provided to the model, in this case, just `sentiment`. 
`Ground: Medium Predicted: High` indicates that the provided ground truth label for the instance sentiment is `Medium` but the model predicted `High`. - -#### Human Readable Predictions for Classification +If you prefer, you can instead use the CLI to train the model: -Similarly if you run the above command with `--do_predict` you can obtain human readable predictions for the `test` split, in the `test_predictions...tsv` file. For us the first line of this file (after the header) is: - -``` -0 Benefits: The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. Side effects: Some back pain, some nauseau. Overall comments: Took the antibiotics for 14 days. Sinus infection was gone after the 6th day. Low - -``` - -##### Prediction Probability Outputs for Classification - -(Currently only supported for classification tasks), if you run the above command with the `--output_prob` flag, you can see the model's softmax-obtained probability for the predicted classification label. The first error analysis sample from `dev` would now looks like: - -``` - text sentiment -2 Benefits: helped aleviate whip lash symptoms Side effects: none that i noticed Overall comments: i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December Ground: Medium Predicted: High , Probability 0.613825 +### Download and preprocess the data +Use the [`prepare_data.py`](./prepare_data.py) script to download the data and convert it to CNLPT's data format: +```bash +uv run prepare_data.py ``` -And the first prediction sample from `test` now looks like: +> [!TIP] About the dataset: +> This script downloads the +> [*Drug Reviews (Druglib.com)* dataset](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com). 
+> Please be aware of the terms of use: +> +> > Important Notes: +> > +> > When using this dataset, you agree that you +> > +> > 1) only use the data for research purposes +> > 2) don't use the data for any commerical purposes +> > 3) don't distribute the data to anyone else +> > 4) cite UCI data lab and the source +> +> Here is the dataset's BibTeX citation: +> +> ```bibtex +> @misc{drug_reviews_(druglib.com)_461, +> author = {Kallumadi, Surya and Grer, Felix}, +> title = {{Drug Reviews (Druglib.com)}}, +> year = {2018}, +> howpublished = {UCI Machine Learning Repository}, +> note = {{DOI}: https://doi.org/10.24432/C55G6J} +> } +> ``` + +### Train a model + +The following example fine-tunes +[the RoBERTa base model](https://huggingface.co/FacebookAI/roberta-base) +with an added projection layer for classification: -``` - text sentiment -0 Benefits: The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. Side effects: Some back pain, some nauseau. Overall comments: Took the antibiotics for 14 days. Sinus infection was gone after the 6th day. 
Low , Probability 0.370522 +```bash +uv run cnlpt train \ + --model_type proj \ + --encoder roberta-base \ + --data_dir ./dataset \ + --task sentiment \ + --output_dir ./train_output \ + --overwrite_output_dir \ + --do_train --do_eval --do_predict \ + --evals_per_epoch 2 \ + --learning_rate 1e-5 \ + --metric_for_best_model 'sentiment.macro_f1' \ + --load_best_model_at_end \ + --save_strategy best ``` diff --git a/examples/uci_drug/prepare_data.py b/examples/uci_drug/prepare_data.py new file mode 100644 index 00000000..0a5bb6db --- /dev/null +++ b/examples/uci_drug/prepare_data.py @@ -0,0 +1,67 @@ +import io +import zipfile +from pathlib import Path + +import polars as pl +import requests + +DATASET_ZIP_URL = ( + "https://archive.ics.uci.edu/static/public/461/drug+review+dataset+druglib+com.zip" +) +DATA_DIR = Path(__file__).parent / "dataset" + + +def preprocess_raw_data(unprocessed_path: str): + return pl.read_csv(unprocessed_path, separator="\t").select( + id="", + sentiment=pl.col("rating").map_elements( + lambda rating: "Negative" + if rating < 5 + else "Neutral" + if rating < 8 + else "Positive", + return_dtype=pl.String, + ), + text=( + pl.concat_str( + "benefitsReview", + "sideEffectsReview", + "commentsReview", + separator=" ", + ) + .str.replace_all("\n", " ") + .str.replace_all("\r", " ") + .str.replace_all("\t", " ") + ), + ) + + +if __name__ == "__main__": + DATA_DIR.mkdir(exist_ok=True) + + # Download dataset + response = requests.get(DATASET_ZIP_URL) + response.raise_for_status() + zip = zipfile.ZipFile(io.BytesIO(response.content)) + zip.extractall(DATA_DIR) + + raw_train_file = DATA_DIR / "drugLibTrain_raw.tsv" + raw_test_file = DATA_DIR / "drugLibTest_raw.tsv" + + # Preprocess raw data + preprocessed_train_data = preprocess_raw_data(raw_train_file) + preprocessed_test_data = preprocess_raw_data(raw_test_file) + + # 90/10 split for train and dev + preprocessed_train_data, preprocessed_dev_data = ( + 
preprocessed_train_data.iter_slices(int(preprocessed_train_data.shape[0] * 0.9)) + ) + + # Write to tsv files + preprocessed_train_data.write_csv(DATA_DIR / "train.tsv", separator="\t") + preprocessed_dev_data.write_csv(DATA_DIR / "dev.tsv", separator="\t") + preprocessed_test_data.write_csv(DATA_DIR / "test.tsv", separator="\t") + + # Delete raw data files + raw_train_file.unlink() + raw_test_file.unlink() diff --git a/examples/uci_drug/transform_uci_drug.py b/examples/uci_drug/transform_uci_drug.py deleted file mode 100644 index e7162819..00000000 --- a/examples/uci_drug/transform_uci_drug.py +++ /dev/null @@ -1,116 +0,0 @@ -#!/usr/bin/env python3 -""" -Data Download Source: -https://archive.ics.uci.edu/dataset/462/drug+review+dataset+drugs+com - -Data Source: -Surya Kallumadi -Kansas State University -Manhattan, Kansas, USA -surya '@' ksu.edu - -Felix Gräßer -Institut für Biomedizinische Technik -Technische Universität Dresden -Dresden, Germany -felix.graesser '@' tu-dresden.de - - -Important Notes: -When using this dataset, you agree that you -1) only use the data for research purposes -2) don't use the data for any commerical purposes -3) don't distribute the data to anyone else -4) cite UCI data lab and the source -""" - -import csv -import sys -from pathlib import Path - -import pandas as pd - -TRAIN_FILE = "drugsComTrain_raw.tsv" -TEST_FILE = "drugsComTest_raw.tsv" - - -def to_sentiment(rating): - rating = int(rating) - if rating <= 4: - return "Low" - elif rating > 4 and rating < 8: - return "Medium" - else: - return "High" - - -def remove_newline(review): - review = review.replace("'", "'") - review = review.replace("\n", " ") - review = review.replace("\r", " ") - review = review.replace("\t", " ") - return review - - -def main(): - input_path = Path(sys.argv[1]) - output_path = Path(sys.argv[-1]) - - # read-in files - df = pd.read_csv(input_path / TRAIN_FILE, sep="\t", usecols=["review", "rating"]) - test = pd.read_csv(input_path / TEST_FILE, sep="\t", 
usecols=["review", "rating"]) - - # split into sentiments categories - test["sentiment"] = test.rating.apply(to_sentiment) - df["sentiment"] = df.rating.apply(to_sentiment) - - # remove newlines: - test["review"] = test.review.apply(remove_newline) - df["review"] = df.review.apply(remove_newline) - - # remove quotes - df["text"] = df["review"].str.replace('"', "") - test["text"] = test["review"].str.replace('"', "") - - # split train and dev into 9:1 ratio - train = df.sample(frac=0.9, random_state=200) - dev = df.drop(train.index) - - # select column as desired - test = test[["sentiment", "text"]] - train = train[["sentiment", "text"]] - dev = dev[["sentiment", "text"]] - - # output CSVs - output_path.mkdir(parents=True, exist_ok=True) - test.to_csv( - output_path / "test.tsv", - sep="\t", - encoding="utf-8", - index=False, - header=True, - quoting=csv.QUOTE_NONE, - escapechar=None, - ) - train.to_csv( - output_path / "train.tsv", - sep="\t", - encoding="utf-8", - index=False, - header=True, - quoting=csv.QUOTE_NONE, - escapechar=None, - ) - dev.to_csv( - output_path / "dev.tsv", - sep="\t", - encoding="utf-8", - index=False, - header=True, - quoting=csv.QUOTE_NONE, - escapechar=None, - ) - - -if __name__ == "__main__": - main() diff --git a/examples/uci_drug/uci_drug.ipynb b/examples/uci_drug/uci_drug.ipynb new file mode 100644 index 00000000..76d8c33e --- /dev/null +++ b/examples/uci_drug/uci_drug.ipynb @@ -0,0 +1,332 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "038b45b2", + "metadata": {}, + "source": [ + "# Drug Review Sentiment Classification\n", + "\n", + "This example uses the [Druglib drug reviews dataset](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) to train a model to classify drug reviews as having positive, neutral, or negative sentiment.\n", + "\n", + "## Setup\n", + "\n", + "To use this notebook, first make sure your virtual environment is configured correctly. 
From the base cnlp_transformers directory, run:\n", + "\n", + "```bash\n", + "uv sync --group notebooks\n", + "```\n", + "\n", + "This will create a local virtual environment at `.venv` with all the necessary dependencies to run the notebook. Select that environment for your notebook kernel.\n", + "\n", + "## Dataset information \n", + "\n", + "Data Download Source:\n", + "https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com\n", + "\n", + "BibTeX citation:\n", + "\n", + "```bibtex\n", + "@misc{drug_reviews_(druglib.com)_461,\n", + " author = {Kallumadi, Surya and Grer, Felix},\n", + " title = {{Drug Reviews (Druglib.com)}},\n", + " year = {2018},\n", + " howpublished = {UCI Machine Learning Repository},\n", + " note = {{DOI}: https://doi.org/10.24432/C55G6J}\n", + "}\n", + "```\n", + "\n", + "Important Notes:\n", + "When using this dataset, you agree that you\n", + "1) only use the data for research purposes\n", + "2) don't use the data for any commerical purposes\n", + "3) don't distribute the data to anyone else\n", + "4) cite UCI data lab and the source\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a4b8bcf5", + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "import io\n", + "import os\n", + "import zipfile\n", + "from pathlib import Path\n", + "\n", + "import polars as pl\n", + "import requests\n", + "\n", + "from cnlpt.data import CnlpDataset, HierarchicalDataConfig\n", + "from cnlpt.data.analysis import make_preds_df\n", + "from cnlpt.data.predictions import CnlpPredictions\n", + "from cnlpt.modeling import (\n", + " CnnModel,\n", + " CnnModelConfig,\n", + " HierarchicalModel,\n", + " HierarchicalModelConfig,\n", + " LstmModel,\n", + " LstmModelConfig,\n", + " ModelType,\n", + " ProjectionModel,\n", + " ProjectionModelConfig,\n", + ")\n", + "from cnlpt.train_system import CnlpTrainingArguments, CnlpTrainSystem" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e40a7f65", + 
"metadata": {}, + "outputs": [], + "source": [ + "# Define dataset and output directories\n", + "\n", + "DATA_DIR = Path.cwd() / \"dataset\"\n", + "OUTPUT_DIR = DATA_DIR = Path.cwd() / \"train_output\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cd6a3be6", + "metadata": {}, + "outputs": [], + "source": [ + "# Download and preprocess data for training\n", + "FORCE_REDOWNLOAD = False\n", + "\n", + "DATASET_ZIP_URL = (\n", + " \"https://archive.ics.uci.edu/static/public/461/drug+review+dataset+druglib+com.zip\"\n", + ")\n", + "\n", + "\n", + "def preprocess_raw_data(unprocessed_path: str):\n", + " return pl.read_csv(unprocessed_path, separator=\"\\t\").select(\n", + " id=\"\",\n", + " sentiment=pl.col(\"rating\").map_elements(\n", + " lambda rating: \"Negative\"\n", + " if rating < 5\n", + " else \"Neutral\"\n", + " if rating < 8\n", + " else \"Positive\",\n", + " return_dtype=pl.String,\n", + " ),\n", + " text=(\n", + " pl.concat_str(\n", + " \"benefitsReview\",\n", + " \"sideEffectsReview\",\n", + " \"commentsReview\",\n", + " separator=\" \",\n", + " )\n", + " .str.replace_all(\"\\n\", \" \")\n", + " .str.replace_all(\"\\r\", \" \")\n", + " .str.replace_all(\"\\t\", \" \")\n", + " ),\n", + " )\n", + "\n", + "\n", + "DATA_DIR.mkdir(exist_ok=True)\n", + "\n", + "if not FORCE_REDOWNLOAD and all(\n", + " p in os.listdir(DATA_DIR) for p in (\"train.tsv\", \"dev.tsv\", \"test.tsv\")\n", + "):\n", + " print(\"Skipping download and preprocessing, data already exists on disk.\")\n", + "else:\n", + " # Download dataset\n", + " response = requests.get(DATASET_ZIP_URL)\n", + " response.raise_for_status()\n", + " zip = zipfile.ZipFile(io.BytesIO(response.content))\n", + " zip.extractall(DATA_DIR)\n", + "\n", + " raw_train_file = DATA_DIR / \"drugLibTrain_raw.tsv\"\n", + " raw_test_file = DATA_DIR / \"drugLibTest_raw.tsv\"\n", + "\n", + " # Preprocess raw data\n", + " preprocessed_train_data = preprocess_raw_data(raw_train_file)\n", + " 
preprocessed_test_data = preprocess_raw_data(raw_test_file)\n", + "\n", + " # 90/10 split for train and dev\n", + " preprocessed_train_data, preprocessed_dev_data = (\n", + " preprocessed_train_data.iter_slices(int(preprocessed_train_data.shape[0] * 0.9))\n", + " )\n", + "\n", + " # Write to tsv files\n", + " preprocessed_train_data.write_csv(DATA_DIR / \"train.tsv\", separator=\"\\t\")\n", + " preprocessed_dev_data.write_csv(DATA_DIR / \"dev.tsv\", separator=\"\\t\")\n", + " preprocessed_test_data.write_csv(DATA_DIR / \"test.tsv\", separator=\"\\t\")\n", + "\n", + " # Delete raw data files\n", + " raw_train_file.unlink()\n", + " raw_test_file.unlink()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "692c6744", + "metadata": {}, + "outputs": [], + "source": [ + "# Choose a model type\n", + "\n", + "# Change this to choose a different type of model!\n", + "MODEL_TYPE = ModelType.PROJ" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e480ae0e", + "metadata": {}, + "outputs": [], + "source": [ + "# Prepare the dataset\n", + "\n", + "if MODEL_TYPE == ModelType.HIER:\n", + " hier_data_config = HierarchicalDataConfig(\n", + " chunk_len=16, num_chunks=8, prepend_empty_chunk=False\n", + " )\n", + "else:\n", + " hier_data_config = None\n", + "\n", + "dataset = CnlpDataset(\n", + " DATA_DIR,\n", + " task_names=[\"sentiment\"],\n", + " hier_config=hier_data_config,\n", + " # max_train=100, # optionally limit how many train/eval/test instances to use\n", + " # max_eval=100,\n", + " # max_test=100,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87d78367", + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the model\n", + "\n", + "if MODEL_TYPE == ModelType.PROJ:\n", + " model = ProjectionModel(\n", + " ProjectionModelConfig(tasks=dataset.tasks, vocab_size=len(dataset.tokenizer))\n", + " )\n", + "elif MODEL_TYPE == ModelType.HIER:\n", + " model = HierarchicalModel(\n", + " 
HierarchicalModelConfig(tasks=dataset.tasks, vocab_size=len(dataset.tokenizer))\n", + " )\n", + "elif MODEL_TYPE == ModelType.CNN:\n", + " model = CnnModel(\n", + " CnnModelConfig(tasks=dataset.tasks, vocab_size=len(dataset.tokenizer))\n", + " )\n", + "elif MODEL_TYPE == ModelType.LSTM:\n", + " model = LstmModel(\n", + " LstmModelConfig(tasks=dataset.tasks, vocab_size=len(dataset.tokenizer))\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "70434d3f", + "metadata": {}, + "outputs": [], + "source": [ + "# Set up the train system\n", + "\n", + "train_system = CnlpTrainSystem(\n", + " model,\n", + " dataset,\n", + " CnlpTrainingArguments(\n", + " output_dir=OUTPUT_DIR,\n", + " overwrite_output_dir=True,\n", + " do_train=True,\n", + " do_eval=True,\n", + " do_predict=True,\n", + " evals_per_epoch=2,\n", + " num_train_epochs=3,\n", + " learning_rate=1e-5,\n", + " metric_for_best_model=\"sentiment.macro_f1\",\n", + " load_best_model_at_end=True,\n", + " save_strategy=\"best\",\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5110055e", + "metadata": {}, + "outputs": [], + "source": [ + "# Train the model!\n", + "\n", + "train_system.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea2ed5c3", + "metadata": {}, + "outputs": [], + "source": [ + "# Load predictions and print metrics\n", + "\n", + "preds = CnlpPredictions.load_json(OUTPUT_DIR / \"predictions.json\")\n", + "with pl.Config(tbl_rows=len(preds.metrics)):\n", + " display(preds.metrics_df())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6e220348", + "metadata": {}, + "outputs": [], + "source": [ + "# Find prediction errors\n", + "\n", + "analysis = make_preds_df(preds)\n", + "analysis.select(\"sample_id\", \"text\", pl.col(\"sentiment\").struct.unnest()).select(\n", + " \"sample_id\",\n", + " \"text\",\n", + " label=pl.col(\"labels\").struct.field(\"values\"),\n", + " 
prediction=pl.col(\"predictions\").struct.field(\"values\"),\n", + " probabilities=pl.col(\"model_output\")\n", + " .struct.field(\"probs\")\n", + " # labels are sorted automatically by cnlpt, so we should expect them in sorted order\n", + " .arr.to_struct(fields=sorted([\"Negative\", \"Neutral\", \"Positive\"])),\n", + ").unnest(\"probabilities\").filter(pl.col(\"label\") != pl.col(\"prediction\"))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From ebcf90ff428192bb459cb029d2aff0d76b130e4a Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 13:31:33 -0400 Subject: [PATCH 7/8] remove unused kwargs --- src/cnlpt/_cli/train.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/cnlpt/_cli/train.py b/src/cnlpt/_cli/train.py index 18fb8d22..04736c5d 100644 --- a/src/cnlpt/_cli/train.py +++ b/src/cnlpt/_cli/train.py @@ -380,8 +380,6 @@ def train( logging_strategy: LoggingStrategyArg = "epoch", logging_first_step: LoggingFirstStepArg = True, cache_dir: CacheDirArg = None, - # --------------------- # - **kwargs, ): # TODO(ian): it's probably worth making this docstring pretty descriptive """Run the cnlp_transformers training system.""" From db92ec41eb7b3c0b9988ea32a18ef01207563c4d Mon Sep 17 00:00:00 2001 From: ianbulovic Date: Fri, 29 Aug 2025 14:03:32 -0400 Subject: [PATCH 8/8] typo --- src/cnlpt/_cli/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cnlpt/_cli/train.py b/src/cnlpt/_cli/train.py index 04736c5d..eda08629 100644 --- a/src/cnlpt/_cli/train.py +++ b/src/cnlpt/_cli/train.py @@ -220,7 +220,7 @@ def data_arg_option( model_arg_option( 
"--lstm_hidden_size", compatibility=["lstm"], - help="LSTM models, the dimension of the hidden layer.", + help="For LSTM models, the dimension of the hidden layer.", ), ]