diff --git a/.gitignore b/.gitignore index a6a7d7a..918e59d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +__pycache__/ +**/__pycache__ deps/ .idea/ *.txt diff --git a/daseg/bin/dasg b/daseg/bin/dasg index c665ecb..be29e3b 100644 --- a/daseg/bin/dasg +++ b/daseg/bin/dasg @@ -130,12 +130,18 @@ def evaluate( results['args'] = ctx.params with SlackNotifier(' '.join(sys.argv)) as slack: for res_grp in ( - 'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics', 'ORIGINAL_zhao_kawahara_metrics', + 'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics', + 'segeval_metrics', 'args' ): slack.write_and_print(f'{res_grp.upper()}:') for key, val in results[res_grp].items(): - slack.write_and_print(f'{key}\t{val:.2%}') + try: + # TODO: rewrite cleaner later + if isinstance(val, int): raise Exception() + slack.write_and_print(f'{key}\t{val:.2%}') + except: + slack.write_and_print(f'{key}\t{val}') if save_output is not None: with open(save_output, 'wb') as f: pickle.dump(results, f) @@ -152,6 +158,7 @@ def evaluate( @click.option('-l', '--max-sequence-length', default=4096, type=int) @click.option('-n', '--turns', is_flag=True) @click.option('-w', '--windows-if-exceeds-max-len', is_flag=True) +@click.option('-a', '--allow-partial-segments', is_flag=True) def prepare_exp( output_dir: Path, model_name_or_path: str, @@ -162,6 +169,7 @@ def prepare_exp( max_sequence_length: int, turns: bool, windows_if_exceeds_max_len: bool, + allow_partial_segments: bool, ): output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) @@ -183,7 +191,8 @@ def prepare_exp( labels=model.labels, max_seq_length=max_sequence_length, use_turns=turns, - windows_if_exceeds_max_length=windows_if_exceeds_max_len + windows_if_exceeds_max_length=windows_if_exceeds_max_len, + allow_partial_segments=allow_partial_segments, ) for key, split_corpus in corpus.train_dev_test_split().items() } @@ -203,6 +212,7 @@ def prepare_exp( @click.option('-r', '--random-seed', default=1050, type=int) @click.option('-g', '--num-gpus', default=0, type=int) @click.option('-f', '--fp16', is_flag=True) +@click.option('--crf/--no-crf', default=False) def train_transformer( exp_dir: Path, model_name_or_path: str, @@ -212,7 +222,8 @@ def train_transformer( gradient_accumulation_steps: int, random_seed: int, num_gpus: int, - fp16: bool + fp16: bool, + crf: bool ): pl.seed_everything(random_seed) output_path = Path(exp_dir) @@ -220,7 +231,7 @@ def train_transformer( datasets: Dict[str, Dataset] = pickle.load(f) with open(output_path / 'labels.pkl', 'rb') as f: labels: List[str] = pickle.load(f) - model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path) + model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path, crf=crf) loaders = { key: to_dataloader( dset, @@ -386,7 +397,10 @@ def evaluate_bigru( results['args'] = ctx.params with SlackNotifier(' '.join(sys.argv)) as slack: for key, val in results['log'].items(): - slack.write_and_print(f'{key}\t{val:.2%}') + try: + slack.write_and_print(f'{key}\t{val:.2%}') + except: + slack.write_and_print(f'{key}\t{val}') if save_output is not None: with open(save_output, 'wb') as f: pickle.dump(results, f) diff --git a/daseg/conversion.py b/daseg/conversion.py index 0f49a4d..aa7f1a1 100644 --- a/daseg/conversion.py +++ b/daseg/conversion.py @@ -50,7 +50,7 @@ def predictions_to_dataset( ) -> DialogActCorpus: dialogues = {} for (call_id, call), pred_tags in zip(original_dataset.dialogues.items(), predictions): - words, _, speakers = 
call.words_with_metadata(add_turn_token=True) + words, orig_labels, speakers = call.words_with_metadata(add_turn_token=True) assert len(words) == len(pred_tags), \ f'Mismatched words ({len(words)}) and predicted tags ({len(pred_tags)}) counts for conversation "{call_id}"' diff --git a/daseg/data.py b/daseg/data.py index 7d2d31c..f3378ce 100644 --- a/daseg/data.py +++ b/daseg/data.py @@ -7,8 +7,9 @@ from functools import partial from itertools import chain, groupby from pathlib import Path -from typing import Callable, Dict, FrozenSet, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple +from typing import Callable, Dict, FrozenSet, Iterable, List, Literal, Mapping, NamedTuple, Optional, Set, Tuple +import numpy as np from cytoolz.itertoolz import sliding_window from more_itertools import flatten from spacy import displacy @@ -211,7 +212,7 @@ def turns(self) -> Iterable['Call']: @property def dialog_acts(self) -> List[str]: - return sorted(set(segment.dialog_act for call in self.calls for segment in call)) + return sorted(set(str(segment.dialog_act) for call in self.calls for segment in call)) @property def dialog_act_labels(self) -> List[str]: @@ -234,7 +235,7 @@ def joint_coding_dialog_act_label_frequencies(self): @property def joint_coding_dialog_act_labels(self) -> List[str]: - return list(chain([BLANK, CONTINUE_TAG], self.dialog_acts)) + return list(chain([CONTINUE_TAG], self.dialog_acts)) @property def vocabulary(self) -> Counter: @@ -443,11 +444,15 @@ def encode( for (_, prv), (idx, cur), (_, nxt) in segment_windows: if use_joint_coding: enc = [(word, CONTINUE_TAG) for word in cur.text.split()] - if not (continuations_allowed and nxt is not None and nxt.is_continuation): + there_is_a_continuation = continuations_allowed and nxt is not None and nxt.is_continuation + segment_is_finished = cur.completeness in ['complete', 'left-truncated'] + if not there_is_a_continuation and segment_is_finished: enc[-1] = (enc[-1][0], cur.dialog_act) else: enc = [(word, f'{CONTINUE_TAG}{cur.dialog_act}') for word in cur.text.split()] - if not (continuations_allowed and cur.is_continuation): + cur_is_continuation = continuations_allowed and cur.is_continuation + segment_is_starting = cur.completeness in ['complete', 'right-truncated'] + if not cur_is_continuation and segment_is_starting: enc[0] = (enc[0][0], f'{BEGIN_TAG}{cur.dialog_act}') words, acts = zip(*enc) encoded_segment = EncodedSegment( @@ -475,6 +480,23 @@ def encode( return encoded_call + def cut_segments_in_windows(self, window_size_in_tokens: int, tokenizer) -> "Call": + """ + Return a copy of ``self`` where functional segments longer than ``window_size_in_tokens`` + are split into partial segments. In order to know the token count we will need a tokenizer + argument that has a method ``.tokenize(text: str)``. 
+ """ + return Call( + list( + chain.from_iterable( + fs.to_windows( + window_size_in_tokens=window_size_in_tokens, tokenizer=tokenizer + ) + for fs in self + ) + ) + ) + def prepare_call_windows( call: Call, @@ -508,6 +530,9 @@ class FunctionalSegment(NamedTuple): is_continuation: bool = False start: Optional[float] = None end: Optional[float] = None + completeness: Literal[ + "complete", "left-truncated", "right-truncated", "both-truncated" + ] = "complete" @property def num_words(self) -> int: @@ -521,6 +546,62 @@ def with_vocabulary(self, vocabulary: Set[str]) -> 'FunctionalSegment': new_text = ' '.join(w if w in vocabulary else OOV for w in self.text.split()) return FunctionalSegment(new_text, *self[1:]) + def to_windows( + self, window_size_in_tokens: int, tokenizer + ) -> List["FunctionalSegment"]: + words = self.text.split() + segment_tokens_per_word = [len(tokenizer.tokenize(w)) for w in words] + n_segment_tokens = sum(segment_tokens_per_word) + if n_segment_tokens <= window_size_in_tokens: + return [self] + + partial_segments = [] + is_first = True + is_last = False + while words: + if n_segment_tokens < window_size_in_tokens: + text = " ".join(words) + words = [] + is_last = True + else: + n_partial_words = find_nearest( + segment_tokens_per_word, window_size_in_tokens + ) + text = " ".join(words[:n_partial_words]) + words = words[n_partial_words:] + partial_segments.append( + FunctionalSegment( + text=text, + dialog_act=self.dialog_act, + speaker=self.speaker, + is_continuation=self.is_continuation, + start=self.start, + end=self.end, + completeness=( + "right-truncated" + if is_first + else "left-truncated" + if is_last + else "both-truncated" + ), + ) + ) + is_first = False + + return partial_segments + + +def find_nearest(array: List[int], value: int): + """ + Find the index of the closest element to ``value`` in cumulative sum of ``array`` + that is not greater than ``value``. + """ + array = np.cumsum(array) + diff = array - value + diff = np.where(diff <= 0, diff, np.inf) + idx = diff.argmin() + return idx + class EncodedSegment(NamedTuple): words: List[str] diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index d06f358..6446c15 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -1,19 +1,21 @@ import warnings from functools import partial from itertools import chain -from typing import Iterable, Optional, List +from typing import Iterable, List, Optional import numpy as np import torch from torch.nn import CrossEntropyLoss -from torch.utils.data import DataLoader, TensorDataset, SequentialSampler, RandomSampler, Dataset +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, TensorDataset from transformers import PreTrainedTokenizer -from daseg import DialogActCorpus, Call +from daseg import Call, DialogActCorpus, FunctionalSegment +from daseg.data import find_nearest from daseg.utils_ner import InputExample, convert_examples_to_features -def as_windows(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_joint_coding: bool) -> Iterable[Call]: +def as_windows_old(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_joint_coding: bool) -> Iterable[ + Call]: if not use_joint_coding: warnings.warn('Call windows are not available when joint coding is turned off. 
Some calls will be truncated.') return [call] @@ -33,6 +35,53 @@ def as_windows(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_ yield Call(window) +def as_windows( + call: Call, + max_length: int, + tokenizer: PreTrainedTokenizer, + use_joint_coding: bool, + allow_partial_segments: bool = False, +) -> Iterable[Call]: + if not use_joint_coding: + warnings.warn( + "Call windows are not available when joint coding is turned off. Some calls will be truncated." + ) + return [call] + window = [] + cur_len = 0 + for segment in call: + words = segment.text.split() + segment_tokens_per_word = [len(tokenizer.tokenize(w)) for w in words] + n_segment_tokens = sum(segment_tokens_per_word) + if cur_len + n_segment_tokens > max_length: + if not window and not allow_partial_segments: + raise ValueError( + "Max sequence length is too low - a segment longer than this value was found." + ) + if (n_partial_tokens := max_length - cur_len) > 0: + n_partial_words = find_nearest(segment_tokens_per_word, n_partial_tokens) + if n_partial_words > 0: + partial_segment = FunctionalSegment( + text=' '.join(words[:n_partial_words]), + dialog_act=segment.dialog_act, + speaker=segment.speaker, + is_continuation=segment.is_continuation, + start=segment.start, + end=segment.end, + completeness="right-truncated" + ) + window.append(partial_segment) + n_segment_tokens -= np.cumsum(segment_tokens_per_word)[n_partial_words] + else: + n_segment_tokens = 0 + yield Call(window) + window = [] + cur_len = n_segment_tokens + window.append(segment) + if window: + yield Call(window) + + def to_dataset( corpus: DialogActCorpus, tokenizer: PreTrainedTokenizer, @@ -41,7 +90,8 @@ def to_dataset( max_seq_length: Optional[int] = None, windows_if_exceeds_max_length: bool = False, use_joint_coding: bool = True, - use_turns: bool = False + use_turns: bool = False, + allow_partial_segments: bool = False, ) -> TensorDataset: ner_examples = [] for idx, call in enumerate(corpus.calls): @@ -55,7 +105,8 @@ def to_dataset( call=call, max_length=max_seq_length, tokenizer=tokenizer, - use_joint_coding=use_joint_coding + use_joint_coding=use_joint_coding, + allow_partial_segments=allow_partial_segments, ) else: call_parts = [call] @@ -118,7 +169,7 @@ def to_dataloader(dataset: Dataset, padding_at_start: bool, batch_size: int, tra batch_size=batch_size, collate_fn=partial(truncate_padding_collate_fn, padding_at_start=padding_at_start), pin_memory=True, - num_workers=4 + num_workers=0 ) @@ -132,6 +183,7 @@ def to_transformers_train_dataloader( use_joint_coding: bool = True, use_turns: bool = False, windows_if_exceeds_max_length: bool = False, + allow_partial_segments: bool = False, ) -> DataLoader: dataset = to_dataset( corpus=corpus, @@ -141,7 +193,8 @@ def to_transformers_train_dataloader( max_seq_length=max_seq_length, use_joint_coding=use_joint_coding, use_turns=use_turns, - windows_if_exceeds_max_length=windows_if_exceeds_max_length + windows_if_exceeds_max_length=windows_if_exceeds_max_length, + allow_partial_segments=allow_partial_segments, ) return to_dataloader(dataset, batch_size=batch_size, train=True, padding_at_start=model_type == 'xlnet') @@ -154,7 +207,8 @@ def to_transformers_eval_dataloader( labels: Iterable[str], max_seq_length: Optional[int] = None, use_joint_coding: bool = True, - use_turns: bool = False + use_turns: bool = False, + allow_partial_segments: bool = False, ) -> DataLoader: """ Convert the DA dataset into a PyTorch DataLoader for inference. 
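A minimal, self-contained sketch of the prefix-fitting question that the segment windowing above keeps answering - how many leading words of a segment fit into the remaining token budget. The helper name below is made up for illustration and is not the numpy-based find_nearest used in the diff:

from itertools import accumulate
from typing import List


def words_that_fit(tokens_per_word: List[int], budget: int) -> int:
    """Count how many leading words fit into ``budget`` tokens in total."""
    n_words = 0
    for cum_tokens in accumulate(tokens_per_word):
        if cum_tokens > budget:
            break
        n_words += 1
    return n_words


# With per-word subword counts [2, 3, 4] and a budget of 5 tokens, only the
# first two words fit (2 + 3 = 5); a single 7-token word never fits a budget of 5.
assert words_that_fit([2, 3, 4], budget=5) == 2
assert words_that_fit([7], budget=5) == 0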
@@ -173,28 +227,40 @@ def to_transformers_eval_dataloader( labels=labels, max_seq_length=max_seq_length, use_joint_coding=use_joint_coding, - use_turns=use_turns + use_turns=use_turns, + allow_partial_segments=allow_partial_segments, ) return to_dataloader(dataset, batch_size=batch_size, train=False, padding_at_start=model_type == 'xlnet') -def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False): +def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False, add_ilen: bool = True, sort: bool = True): + batch = sorted(batch, key=lambda tensors: (tensors[3] != -100).to(torch.int32).sum(), reverse=True) redundant_padding = max(mask.sum() for _, mask, _, _ in batch) n_tensors = len(batch[0]) concat_tensors = (torch.cat([sample[i].unsqueeze(0) for sample in batch]) for i in range(n_tensors)) if padding_at_start: return [t[:, -redundant_padding:] for t in concat_tensors] - return [t[:, :redundant_padding] for t in concat_tensors] + truncated = [t[:, :redundant_padding] for t in concat_tensors] + if add_ilen: + # Here we add extra tensor that states the input lens + truncated.append( + truncated[1].sum(dim=1) + ) + return truncated def pad_list_of_arrays(arrays: List[np.ndarray], value: float) -> List[np.ndarray]: - max_out_len = max(x.shape[1] for x in arrays) + if len(arrays[0].shape) > 1: + max_out_len = max(x.shape[1] for x in arrays) + else: + max_out_len = max(x.shape[0] for x in arrays) return [pad_array(t, target_len=max_out_len, value=value) for t in arrays] def pad_array(arr: np.ndarray, target_len: int, value: float): - if arr.shape[1] == target_len: + len_dim = 1 if len(arr.shape) > 1 else 0 + if arr.shape[len_dim] == target_len: return arr pad_shape = list(arr.shape) - pad_shape[1] = target_len - arr.shape[1] - return np.concatenate([arr, np.ones(pad_shape) * value], axis=1) + pad_shape[len_dim] = target_len - arr.shape[len_dim] + return np.concatenate([arr, np.ones(pad_shape) * value], axis=len_dim) diff --git a/daseg/losses/__init__.py b/daseg/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py new file mode 100644 index 0000000..4d44843 --- /dev/null +++ b/daseg/losses/crf.py @@ -0,0 +1,283 @@ +from typing import Dict, List + +import numpy as np +import k2 +import torch +from torch import Tensor, nn + + +class CRFLoss(nn.Module): + """ + Conditional Random Field loss implemented with K2 library. It supports GPU computation. + + Currently, this loss assumes specific topologies for dialog acts/punctuation labeling. + """ + + def __init__( + self, + label_set: List[str], + label2id: Dict[str, int], + ignore_index: int = -100 + ): + super().__init__() + self.label_set = label_set + self.label2id = label2id + self.ignore_index = ignore_index + self.den = k2.arc_sort(make_denominator(label_set, label2id, shared=True)) + self.den.requires_grad_(False) + self.A = create_bigram_lm([self.label2id[l] for l in label_set]).to('cuda') + self.A_scores = nn.Parameter(self.A.scores.clone(), requires_grad=True).to('cuda') + + def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): + global it + + # Determine all relevant shapes - max_seqlen_scored is the longest sequence length of log_probs + # after we remove ignored indices. 
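+        # Illustration: ``make_segments`` (defined at the bottom of this file) lays the k2
+        # supervision out as one [fsa_index, start_frame, num_frames] row per batch item,
+        # where num_frames counts the positions that still carry a real label, e.g. for
+        #   labels = [[3, 3, 5, -100],
+        #             [7, 2, -100, -100]]
+        # it yields [[0, 0, 3], [1, 0, 2]] (int32), with rows already ordered by decreasing
+        # length because the collate function sorts each batch that way.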
+ bs, seqlen, nclass = log_probs.shape + max_seqlen_scored = (labels[0] != self.ignore_index).sum() + supervision_segments = make_segments(labels) + + log_probs_scored = log_probs.new_zeros(bs, max_seqlen_scored, nclass) + assert max_seqlen_scored == supervision_segments[0, 2] + for i in range(bs): + log_probs_scored[i, :supervision_segments[i, 2], :] = log_probs[i, labels[i] != self.ignore_index, :] + + # (batch, seqlen, classes) + posteriors = k2.DenseFsaVec(log_probs_scored, supervision_segments) + + # (fsavec) + self.A.set_scores_stochastic_(self.A_scores) + nums = make_numerator(labels) + A_cpu = self.A.to('cpu') + nums = k2.intersect(A_cpu, nums).to(log_probs.device) + for i in range(nums.shape[0]): + # The supervision has to have exactly the same number of arcs as the number of tokens + # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. + assert nums[i].num_arcs == supervision_segments[i, 2] + 1 + + # (fsavec) + #den = k2.intersect(A_cpu, self.den).detach().to(log_probs.device) + den = self.den.to('cuda') + + if it % 100 == 0: + #print_transition_probabilities(self.A, self.den.symbols, list(self.label2id.values())) + for i in range(min(3, labels.size(0))): + print('*' * 120) + print('log_probs_scored.shape', log_probs_scored.shape) + print(f'labels[{i}][:20] = ', labels[i][:20]) + print(f'labels[{i}][labels[{i}] != self.ignore_index][:20] = ', labels[i][labels[i] != self.ignore_index][:20]) + print(f'nums[{i}].labels[:20] = ', nums[i].labels[:20]) + print(f'log_probs_scored[{i}][:20] = ', log_probs_scored.argmax(dim=2)[i][:20]) + print('*' * 120) + it += 1 + + # (fsavec) + num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) + den_lattices = k2.intersect_dense(den, posteriors, output_beam=10.0) + + # (batch,) + num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) + den_scores = den_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) + + # (scalar) + num_tokens = (labels != self.ignore_index).to(torch.int32).sum() + loss = (num_scores - den_scores).sum() / num_tokens + #loss = num_scores.sum() / num_tokens + return loss + +it = 0 + + +def create_bigram_lm(labels: List[int]) -> k2.Fsa: + """ + Create a bigram LM. + The resulting FSA (A) has a start-state and a state for + each label 0, 1, 2, ....; and each of the above-mentioned states + has a transition to the state for each phone and also to the final-state. + """ + final_state = len(labels) + 1 + rules = '' + for i in range(1, final_state): + rules += f'0 {i} {labels[i-1]} 0.0\n' + + for i in range(1, final_state): + for j in range(1, final_state): + rules += f'{i} {j} {labels[j-1]} 0.0\n' + rules += f'{i} {final_state} -1 0.0\n' + rules += f'{final_state}' + return k2.Fsa.from_str(rules) + + +def print_transition_probabilities(P: k2.Fsa, phone_symbol_table: k2.SymbolTable, + phone_ids: List[int], filename: str = None): + '''Print the transition probabilities of a phone LM. + + Args: + P: + A bigram phone LM. + phone_symbol_table: + The phone symbol table. + phone_ids: + A list of phone ids + filename: + Filename to save the printed result. 
+    '''
+    num_phones = len(phone_ids)
+    table = np.zeros((num_phones + 1, num_phones + 2))
+    table[:, 0] = 0
+    table[0, -1] = 0  # the start state has no arcs to the final state
+    #assert P.arcs.dim0() == num_phones + 2
+    arcs = P.arcs.values()[:, :3]
+    probability = P.scores.exp().tolist()
+
+    assert arcs.shape[0] - num_phones == num_phones * (num_phones + 1)
+    for i, arc in enumerate(arcs.tolist()):
+        src_state, dest_state, label = arc[0], arc[1], arc[2]
+        prob = probability[i]
+        if label != -1:
+            assert label == dest_state
+        else:
+            assert dest_state == num_phones + 1
+        table[src_state][dest_state] = prob
+
+    try:
+        from prettytable import PrettyTable
+    except ImportError:
+        print('Please run `pip install prettytable`. Skip printing')
+        return
+
+    x = PrettyTable()
+
+    field_names = ['source']
+    field_names.append('sum')
+    for i in phone_ids:
+        field_names.append(phone_symbol_table[i])
+    field_names.append('final')
+
+    x.field_names = field_names
+
+    for row in range(num_phones + 1):
+        this_row = []
+        if row == 0:
+            this_row.append('start')
+        else:
+            this_row.append(phone_symbol_table[row])
+        this_row.append('{:.6f}'.format(table[row, 1:].sum()))
+        for col in range(1, num_phones + 2):
+            this_row.append('{:.6f}'.format(table[row, col]))
+        x.add_row(this_row)
+    print(str(x))
+    #with open(filename, 'w') as f:
+    #    f.write(str(x))
+
+
+def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.SymbolTable:
+    """
+    Creates a symbol table given a list of classes (e.g. dialog acts, punctuation, etc.).
+    It adds extra symbols:
+    - 'O' which is used to indicate special tokens such as
+    - (when shared=True) 'I-' which is the "in-the-middle" symbol shared between all classes
+    - (when shared=False) 'I-' which is the "in-the-middle" symbol,
+      specific for each class (N x classes -> N x I- symbols)
+    """
+    symtab = k2.SymbolTable()
+    del symtab._sym2id['']
+    del symtab._id2sym[0]
+    if shared:
+        symtab.add('I-', label2id['I-'])
+    for l in label_set:
+        symtab.add(l, label2id[l])
+        if not shared:
+            symtab.add(f'I-{l}', label2id[f'I-{l}'])
+    return symtab
+
+
+def make_numerator(labels: Tensor) -> k2.Fsa:
+    """
+    Creates a numerator supervision FSA.
+    It simply encodes the ground truth label sequence and allows no leeway.
+    Returns a :class:`k2.FsaVec` with FSAs of differing length.
+    """
+    assert len(labels.shape) == 2
+    nums = k2.create_fsa_vec([k2.linear_fsa(lab[lab != -100].tolist()) for lab in labels])
+    return nums
+
+
+def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.Fsa:
+    """
+    Creates a "simple" denominator that encodes all possible transitions
+    given the input label set.
+
+    The labeling scheme is assumed to be IE with joint coding, e.g.:
+
+        Here I am.~~~~~~ How are you today?~~
+        I~~~ I Statement I~~ I~~ I~~ Question
+
+    Or without joint coding:
+
+        Here~~~~~~~ I~~~~~~~~~~ am.~~~~~~ How~~~~~~~ are~~~~~~~ you~~~~~~~ today?~~
+        I-Statement I-Statement Statement I-Question I-Question I-Question Question
+
+    When shared=True, it uses a shared "in-the-middle" label for all classes;
+    otherwise each class has a separate one.
+    """
+    symtab = make_symbol_table(label_set, label2id, shared=shared)
+
+    """
+    shared=True
+    0 0 Statement
+    0 0 Question
+    0 1 I-
+    1 1 I-
+    1 0 Statement
+    1 0 Question
+    0 2 -1
+    2
+    """
+
+    """
+    shared=False
+    0 0 Statement
+    0 0 Question
+    0 1 I-Statement
+    0 1 I-Question
+    1 1 I-Statement
+    1 1 I-Question
+    1 0 Statement
+    1 0 Question
+    0 2 -1
+    2
+    """
+
+    s = []
+    if shared:
+        s += [
+            f'0 1 {symtab["I-"]} 0.0',
+            f'1 1 {symtab["I-"]} 0.0'
+        ]
+    for idx, label in enumerate(label_set):
+        s += [f'0 0 {symtab[label]} 0.0']
+        if not shared:
+            s += [
+                f'0 1 {symtab["I-" + label]} 0.0',
+                f'1 1 {symtab["I-" + label]} 0.0'
+            ]
+        s += [f'1 0 {symtab[label]} 0.0']
+    s += ['0 2 -1 0.0', '2']
+    s.sort()
+    fsa = k2.Fsa.from_str('\n'.join(s))
+    fsa.symbols = symtab
+    return fsa
+
+
+def make_segments(labels: Tensor) -> Tensor:
+    """
+    Creates a supervision segments tensor that indicates for each batch example,
+    at which index the example has started, and how many tokens it has.
+    """
+    bs = labels.size(0)
+    return torch.stack([
+        torch.arange(bs, dtype=torch.int32),
+        torch.zeros(bs, dtype=torch.int32),
+        (labels != -100).to(torch.int32).sum(dim=1).cpu()
+    ], dim=1).to(torch.int32)
diff --git a/daseg/metrics.py b/daseg/metrics.py
index e1176e2..3f3c105 100644
--- a/daseg/metrics.py
+++ b/daseg/metrics.py
@@ -9,6 +9,7 @@ import torch
 from Bio import pairwise2
 from more_itertools import flatten
+from segeval.similarity import boundary_statistics
 
 from daseg import Call, DialogActCorpus
 from daseg.data import CONTINUE_TAG
@@ -55,6 +56,69 @@ def compute_seqeval_metrics(true_labels: List[List[str]], predictions: List[List
     }
 
+
+def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus):
+    from statistics import mean
+    from segeval import boundary_similarity, pk
+    from segeval.data import Dataset
+    from segeval.compute import summarize
+
+    def fix_single_seg_calls(true, pred):
+        for cid in true.keys():
+            true_segs = true[cid]["ref"]
+            pred_segs = pred[cid]["hyp"]
+            if len(true_segs) == len(pred_segs) == 1:
+                true[cid]["ref"] = true_segs + [1]
+                pred[cid]["hyp"] = pred_segs + [1]
+
+    true_segments = {
+        cid: {"ref": [len(fs.text.split()) for fs in call]}
+        for cid, call in true_dataset.dialogues.items()
+    }
+    pred_segments = {
+        cid: {"hyp": [len(fs.text.split()) for fs in call]}
+        for cid, call in pred_dataset.dialogues.items()
+    }
+
+    fix_single_seg_calls(true_segments, pred_segments)
+
+    pred_segments = Dataset(pred_segments)
+    true_segments = Dataset(true_segments)
+
+    summary = {
+        "correct": 0,
+        "almost_correct": 0,
+        "missed_boundaries": 0,
+        "false_boundaries": 0,
+    }
+    for stat in boundary_statistics(true_segments, pred_segments).values():
+        summary["correct"] += len(stat["matches"])
+        summary["almost_correct"] += len(stat["transpositions"])
+        summary["missed_boundaries"] += len(stat["full_misses"])
+        summary["false_boundaries"] += len(stat["additions"])
+    OK = summary["correct"] + summary["almost_correct"]
+    summary["boundary_precision"] = OK / (OK + summary["false_boundaries"])
+    summary["boundary_recall"] = OK / (OK + summary["missed_boundaries"])
+    summary["boundary_f1"] = (
+        2
+        * summary["boundary_precision"]
+        * summary["boundary_recall"]
+        / (summary["boundary_precision"] + summary["boundary_recall"])
+    )
+
+    B2_mean, B2_std, *_ = summarize(boundary_similarity(true_segments, pred_segments))
+    B5_mean, B5_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=5))
+    B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments,
pred_segments, n_t=10)) + summary.update({ + "pk": float(mean(pk(true_segments, pred_segments).values())), + "B2": float(B2_mean), + "B2𝛔": float(B2_std), + "B5": float(B5_mean), + "B5𝛔": float(B5_std), + "B10": float(B10_mean), + "B10𝛔": float(B10_std), + }) + return summary + + def compute_zhao_kawahara_metrics_levenshtein(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): """ Source: diff --git a/daseg/models/bigru.py b/daseg/models/bigru.py index 4d46b94..a1c1e2c 100644 --- a/daseg/models/bigru.py +++ b/daseg/models/bigru.py @@ -9,7 +9,8 @@ from torch.nn import CrossEntropyLoss from daseg.conversion import joint_coding_predictions_to_corpus -from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_sklearn_metrics, compute_zhao_kawahara_metrics +from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_segeval_metrics, compute_sklearn_metrics, \ + compute_zhao_kawahara_metrics class ZhaoKawaharaBiGru(pl.LightningModule): @@ -194,6 +195,11 @@ def compute_metrics(self, logits, true_labels): ) metrics.update({k: results[k] for k in ('micro_f1', 'macro_f1')}) + # Pk and B metrics + metrics.update(compute_segeval_metrics( + true_dataset=true_dataset, pred_dataset=pred_dataset) + ) + # We show the metrics obtained with Zhao-Kawahara code which computes them differently # (apparently the segment insertion errors are not counted) original_zhao_kawahara_metrics = compute_original_zhao_kawahara_metrics( diff --git a/daseg/models/transformer_model.py b/daseg/models/transformer_model.py index 27ff235..cdab308 100644 --- a/daseg/models/transformer_model.py +++ b/daseg/models/transformer_model.py @@ -19,7 +19,8 @@ from daseg.conversion import predictions_to_dataset from daseg.data import DialogActCorpus from daseg.dataloaders.transformers import pad_list_of_arrays, to_transformers_eval_dataloader -from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_seqeval_metrics, compute_sklearn_metrics, \ +from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_segeval_metrics, compute_seqeval_metrics, \ + compute_sklearn_metrics, \ compute_zhao_kawahara_metrics from daseg.models.longformer_model import LongformerForTokenClassification @@ -72,7 +73,8 @@ def from_pl_checkpoint(path: Path, device: str = 'cpu'): torch.arange(pl_model.config.max_position_embeddings).expand((1, -1)) # Remove extra keys that are no longer needed... 
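+        # (the deletion below is guarded because checkpoints saved without the pooler head
+        #  may not contain these keys)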
        for k in ["model.longformer.pooler.dense.weight", "model.longformer.pooler.dense.bias"]:
-            del ckpt['state_dict'][k]
+            if k in ckpt['state_dict']:
+                del ckpt['state_dict'][k]
 
         # Manually load the state dict
         pl_model.load_state_dict(ckpt['state_dict'])
@@ -126,10 +128,12 @@ def predict(
                 use_joint_coding=use_joint_coding,
                 use_turns=use_turns
             )
+        elif not isinstance(dataset, torch.utils.data.DataLoader):
+            dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
         else:
             dataloader = dataset
 
-        eval_ce_losses, logits, out_label_ids = zip(*list(maybe_tqdm(
+        eval_ce_losses, logits, out_label_ids, input_ids = zip(*list(maybe_tqdm(
             map(
                 partial(
                     predict_batch_in_windows,
@@ -147,18 +151,24 @@ def predict(
         pad_token_label_id = CrossEntropyLoss().ignore_index
         out_label_ids = pad_list_of_arrays(out_label_ids, value=pad_token_label_id)
        logits = pad_list_of_arrays(logits, value=0)
-        out_label_ids = np.concatenate(out_label_ids, axis=0)
+        out_label_ids = np.concatenate(out_label_ids, axis=0).astype(np.int32)
         logits = np.concatenate(logits, axis=0)
         preds = np.argmax(logits, axis=2)
+        input_ids = np.concatenate(pad_list_of_arrays(input_ids, value=0), axis=0)
         label_map = {int(k): v for k, v in self.config.id2label.items()}
+        label_map[-100] = 'O'
+        turn_tok_id = self.tokenizer.encode('')[1]
+
+        assert out_label_ids.shape == preds.shape, f'{out_label_ids.shape} == {preds.shape}'
         out_label_list: List[List[str]] = [[] for _ in range(out_label_ids.shape[0])]
-        preds_list: List[List[str]] = [[] for _ in range(out_label_ids.shape[0])]
+        preds_list: List[List[str]] = [[] for _ in range(preds.shape[0])]
         for i in range(out_label_ids.shape[0]):
             for j in range(out_label_ids.shape[1]):
-                if out_label_ids[i, j] != pad_token_label_id:
+                is_ignored = out_label_ids[i, j] == pad_token_label_id
+                is_turn = input_ids[i, j] == turn_tok_id
+                if not is_ignored or is_turn:
                     out_label_list[i].append(label_map[out_label_ids[i][j]])
                     preds_list[i].append(label_map[preds[i][j]])
@@ -180,8 +190,10 @@ def predict(
             # (apparently the segment insertion errors are not counted)
             "ORIGINAL_zhao_kawahara_metrics": compute_original_zhao_kawahara_metrics(
                 true_turns=out_label_list, pred_turns=preds_list
-            )
+            ),
         })
+        # Pk and B metrics
+
         if isinstance(dataset, DialogActCorpus):
             if use_turns:
                 dataset = DialogActCorpus(dialogues={str(i): turn for i, turn in enumerate(dataset.turns)})
@@ -195,6 +207,9 @@ def predict(
             results["zhao_kawahara_metrics"] = compute_zhao_kawahara_metrics(
                 true_dataset=dataset, pred_dataset=results['dataset']
             )
+            results["segeval_metrics"] = compute_segeval_metrics(
+                true_dataset=dataset, pred_dataset=results['dataset']
+            )
         return results
@@ -224,7 +239,7 @@ def predict_batch_in_windows(
     maxlen = batch[0].shape[1]
     window_shift = window_len - window_overlap
     windows = (
-        [t[:, i: i + window_len].contiguous().to(device) for t in batch]
+        [t[:, i: i + window_len].contiguous().to(device) if len(t.shape) > 1 else t for t in batch]
         for i in range(0, maxlen, window_shift)
     )
@@ -255,7 +270,7 @@ def predict_batch_in_windows(
             mems = outputs[2]
     # workaround for PyTorch file descriptor leaks:
     # https://github.com/pytorch/pytorch/issues/973
-    returns = ce_loss, np.concatenate(logits, axis=1), deepcopy(batch[3].detach().cpu().numpy())
+    returns = ce_loss, np.concatenate(logits, axis=1), deepcopy(batch[3].detach().cpu().numpy()), deepcopy(batch[0].detach().cpu().numpy())
     for t in batch:
         del t
     return returns
diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py
index fb96a9a..d2be498 100644
---
a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -1,5 +1,4 @@ from pathlib import Path -from pathlib import Path from typing import Dict, List import numpy as np @@ -16,17 +15,24 @@ class DialogActTransformer(pl.LightningModule): - def __init__(self, labels: List[str], model_name_or_path: str, pretrained: bool = True): + def __init__( + self, + labels: List[str], + model_name_or_path: str, + pretrained: bool = True, + crf: bool = False + ): super().__init__() self.save_hyperparameters() self.pad_token_label_id = CrossEntropyLoss().ignore_index self.labels = labels + self.label2id = {label: i for i, label in enumerate(self.labels)} self.num_labels = len(self.labels) self.config = AutoConfig.from_pretrained( model_name_or_path, num_labels=self.num_labels, id2label={str(i): label for i, label in enumerate(self.labels)}, - label2id={label: i for i, label in enumerate(self.labels)}, + label2id=self.label2id ) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.tokenizer.add_special_tokens({'additional_special_tokens': [NEW_TURN]}) @@ -40,6 +46,11 @@ def __init__(self, labels: List[str], model_name_or_path: str, pretrained: bool model_class = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(self.config)] self.model = model_class(self.config) self.model.resize_token_embeddings(len(self.tokenizer)) + if crf: + from daseg.losses.crf import CRFLoss + self.crf = CRFLoss([l for l in self.labels if l != 'O' and not l.startswith('I-')], self.label2id) + else: + self.crf = None def forward(self, **inputs): return self.model(**inputs) @@ -51,11 +62,22 @@ def training_step(self, batch, batch_num): inputs["token_type_ids"] = ( batch[2] if self.config.model_type in ["bert", "xlnet"] else None ) # XLM and RoBERTa don"t use token_type_ids - outputs = self(**inputs) - loss = outputs[0] - tensorboard_logs = {"loss": loss} - return {"loss": loss, "log": tensorboard_logs} + ce_loss, logits = outputs[:2] + if self.crf is not None: + log_probs = torch.nn.functional.log_softmax(logits, dim=2) + labels, ilens = batch[3], batch[4] + crf_loss = -self.crf(log_probs, ilens, labels) + #ce_loss = 0.1 * ce_loss + loss = crf_loss# + ce_loss + #loss = ce_loss + logs = {"loss": loss, 'crf_loss': crf_loss, 'ce_loss': ce_loss} + else: + loss = ce_loss + logs = {"loss": loss} + progdict = logs.copy() + progdict.pop('loss') + return {"loss": loss, "log": logs, 'progress_bar': progdict} def validation_step(self, batch, batch_nb): "Compute validation" @@ -66,10 +88,14 @@ def validation_step(self, batch, batch_nb): batch[2] if self.config.model_type in ["bert", "xlnet"] else None ) # XLM and RoBERTa don"t use token_type_ids outputs = self(**inputs) - tmp_eval_loss, logits = outputs[:2] + loss, logits = outputs[:2] + if self.crf is not None: + log_probs = torch.nn.functional.log_softmax(logits, dim=2) + labels, ilens = batch[3], batch[4] + loss = -self.crf(log_probs, ilens, labels) preds = logits.detach().cpu().numpy() - out_label_ids = inputs["labels"].detach().cpu().numpy() - return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids} + out_label_ids = inputs["labels"] + return {"val_loss": loss, "pred": preds, "target": out_label_ids} def test_step(self, batch, batch_nb): return self.validation_step(batch, batch_nb) @@ -129,7 +155,7 @@ def compute_and_set_total_steps( def get_lr_scheduler(self): scheduler = get_linear_schedule_with_warmup( - self.opt, num_warmup_steps=0, num_training_steps=self.total_steps + self.opt, num_warmup_steps=250, 
num_training_steps=self.total_steps ) scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} return scheduler @@ -148,12 +174,14 @@ def configure_optimizers(self): "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8) - self.opt = optimizer - - scheduler = self.get_lr_scheduler() + self.opt = AdamW( + optimizer_grouped_parameters, + lr=5e-5, + eps=1e-8 + ) + self.scheduler = self.get_lr_scheduler() - return [optimizer], [scheduler] + return [self.opt], [self.scheduler] def set_output_dir(self, output_dir: Path): self.output_dir = Path(output_dir) @@ -175,6 +203,9 @@ def set_output_dir(self, output_dir: Path): def pad_outputs(outputs: Dict) -> Dict: max_out_len = max(x["pred"].shape[1] for x in outputs) for x in outputs: + for k in ['pred', 'target']: + if isinstance(x[k], torch.Tensor): + x[k] = x[k].cpu().numpy() x["pred"] = pad_array(x["pred"], target_len=max_out_len, value=0) x["target"] = pad_array(x["target"], target_len=max_out_len, value=CrossEntropyLoss().ignore_index) return outputs diff --git a/daseg/punctuation/__init__.py b/daseg/punctuation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/daseg/punctuation/data.py b/daseg/punctuation/data.py new file mode 100644 index 0000000..79e4a9d --- /dev/null +++ b/daseg/punctuation/data.py @@ -0,0 +1,270 @@ +# conda install pandas numpy matplotlib seaborn +# pip install jupyterlab tqdm kaldialign ipywidgets jupyterlab_widgets transformers cytoolz datasets seqeval +import pickle +import re +import string +from functools import lru_cache +from itertools import chain +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict + +from tqdm.auto import tqdm + +from daseg import DialogActCorpus + + +class Example(TypedDict): + words: List[str] + upper_words: List[str] + norm_words: List[str] + punct: List[str] + is_upper: List[bool] + speaker: str + + +class PunctuationData(TypedDict): + train: List[Example] + dev: List[Example] + test: List[Example] + idx2punct: Dict[int, str] + punct2idx: Dict[str, int] + vocab: List[str] + + +@lru_cache(20000) +def is_special_token(word: str) -> bool: + return any( + word.startswith(ltok) and word.endswith(rtok) + for ltok, rtok in ['[]', '<>'] + ) + + +@lru_cache(20000) +def split_punctuation_from_word(word: str, _pattern=re.compile(r'(\w|[\[\]<>])')) -> Tuple[str, str]: + # this silly _pattern will correctly handle "[NOISE]." -> "[NOISE]", "." + if not word: + return '', '' + word = word[::-1] # hi. -> .ih + match = _pattern.search(word) + if match is not None: + first_non_punc_idx = match.span()[0] + text = word[first_non_punc_idx:][::-1] + punc = word[:first_non_punc_idx][::-1] + return text, punc + else: + # pure punctuation + return '', word + + +def preprocess_punctuation( + text: str, + _precedences=['?', '!', '...', '.', ',', '--', ';'], +) -> str: + text = text.replace('"', '') + text.replace('+', '') + text = re.sub(r'--+', '--', text) + words = text.split() + norm_words = [] + for w in words: + w, punc = split_punctuation_from_word(w) + for sym in _precedences: + if sym in punc: + norm_words.append(f'{w}{sym}') + break + else: + norm_words.append(w) + return ' '.join(norm_words) + + +def create_example( + text: str, + _punctuation=string.punctuation.replace("'", ""), + _special=re.compile(r'\[\[.*?\]\]|\[.*?\]|<.*?>', ) +) -> Optional[Example]: + """ + Converts a text segment / utterance into a dict that + can be used for punctuation/truecasing model training/eval. + + .. 
code-block:: python + + { + 'words': List[str], + 'upper_words': List[str], + 'norm_words': List[str], + 'punct': List[str], + 'is_upper': List[bool], + 'speaker': str + } + + :param text: + :param _punctuation: list of punctuation symbols (default is globally cached). + :param _special: regex pattern for detecting special symbols like [UNK] or (default is globally cached). + :return: see above. + """ + text = text.replace(' --', '--').replace(' ...', '...').strip() # stick punctuation to the text + text = _special.sub('', text) + text = ' '.join(text.split()) + + text_base, text_punct = split_punctuation_from_word(text) + if not text or not text_base: + return None + + # get rid of pesky punctuations like "hey...?!;" -> "hey?" + text = preprocess_punctuation(text) + # rich words and lower/no-punc words + words = text.split() + norm_words = [w.lower().translate(str.maketrans('', '', _punctuation)) for w in words] + # filter out the words that consisted only of punctuation + idx_to_remove = [] + for idx, (w, nw) in enumerate(zip(words, norm_words)): + if not nw: + idx_to_remove.append(idx) + norm_words = [ + w if not is_special_token(split_punctuation_from_word(words[idx])[0]) else + split_punctuation_from_word(words[idx])[0] + for idx, w in enumerate(norm_words) + if idx not in idx_to_remove + ] + words = [ + w + for idx, w in enumerate(words) + if idx not in idx_to_remove + ] + upper_words, labels = zip(*(split_punctuation_from_word(w) for w in words)) + is_upper = [not is_special_token(w) and any(c.isupper() for c in w) for w in upper_words] + return { + 'words': words, + 'upper_words': upper_words, + 'norm_words': norm_words, + 'punct': labels, + 'is_upper': is_upper + } + + +def train_dev_test_split( + texts: List[Example], + train_portion: float = 0.9, + dev_portion: float = 0.05 +) -> Dict[str, List[Example]]: + train_part = round(len(texts) * train_portion) + dev_part = round(len(texts) * dev_portion) + data = { + 'train': texts[:train_part], + 'dev': texts[train_part:train_part + dev_part], + 'test': texts[train_part + dev_part:] + } + return data + + +def add_vocab_and_labels( + data: Dict[str, Any], + texts: Dict[str, List[Example]] +) -> PunctuationData: + """Pickle structure: + { + 'train': [ + { + # Each dict represents a single turn + 'words': List[str], + 'upper_words': List[str], + 'norm_words': List[str], + 'punct': List[str], + 'is_upper': List[bool], + 'speaker': str + } + ], + 'dev': [...], + 'test': [...], + 'idx2punct': Dict[int, str], + 'punct2idx': Dict[str, int], + 'vocab': List[str] + } + """ + puncts = {p for conv in texts for t in conv for p in t['punct']}; + puncts + idx2punct = list(puncts) + punct2idx = {p: idx for idx, p in enumerate(puncts)} + vocab = sorted({p for conv in texts for t in conv for p in t['norm_words']}) + data.update({ + 'idx2punct': idx2punct, + 'punct2idx': punct2idx, + 'vocab': vocab + }) + return data + + +"""Corpus specific parts""" + +CLSP_FISHER_PATHS = ( + Path('/export/corpora3/LDC/LDC2004T19'), + Path('/export/corpora3/LDC/LDC2005T19') +) + + +def prepare_no_timing(txo: Path) -> Optional[List[Tuple[str, str]]]: + try: + lines = txo.read_text().splitlines() + turns = (l.split() for l in lines if l.strip()) + turns = ((t[0][0], ' '.join(t[1:])) for t in turns) # (speaker, text) + turns = [ + {**data, 'speaker': speaker} + for speaker, data in + ((speaker, create_example(text)) for speaker, text in turns) + if data is not None + ] + return turns + except Exception as e: + print(f'Error processing path: {txo} -- {e}') + return None + + 
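+# Rough illustration (hand-traced from the helpers above, not an executed doctest) of the
+# per-word structure that ``create_example`` returns; some fields come back as tuples:
+#
+#   create_example("Hey, how are you?")
+#   -> {'words':       ['Hey,', 'how', 'are', 'you?'],
+#       'upper_words': ('Hey', 'how', 'are', 'you'),     # punctuation stripped, casing kept
+#       'norm_words':  ['hey', 'how', 'are', 'you'],     # lowercased, punctuation-free
+#       'punct':       (',', '', '', '?'),               # per-word punctuation labels
+#       'is_upper':    [True, False, False, False]}      # truecasing targets
+#
+# The 'speaker' field declared in ``Example`` is filled in by the callers
+# (``prepare_no_timing`` above, ``prepare_swda`` below).
+
+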
+def prepare_fisher(
+    paths: Sequence[Path] = CLSP_FISHER_PATHS,
+    output_path: Path = Path('fisher.pkl')
+) -> Dict[str, Any]:
+    txos = list(
+        tqdm(
+            chain.from_iterable(
+                path.rglob('*.txo') for path in paths
+            ),
+            desc='Scanning for Fisher transcripts'
+        )
+    )
+    texts = [
+        t
+        for t in (
+            prepare_no_timing(txo)
+            for txo in tqdm(txos, desc='Processing txos')
+        )
+        if t is not None
+    ]
+    data = train_dev_test_split(texts)
+    data = add_vocab_and_labels(data, texts)
+    with open(output_path, 'wb') as f:
+        pickle.dump(data, f)
+    return data
+
+
+def prepare_swda(
+    corpus: DialogActCorpus,
+    output_path: Path = Path('swda.pkl')
+) -> Dict[str, Any]:
+    data = {}
+    splits = corpus.train_dev_test_split()
+    for key, split in splits.items():
+        calls = []
+        for call in split.calls:
+            turns = []
+            for speaker, turn in call.turns:
+                text = ' '.join(fs.text for fs in turn)
+                example = create_example(text)
+                if example is None:
+                    continue
+                turns.append({
+                    **example,
+                    'speaker': speaker
+                })
+            calls.append(turns)
+        data[key] = calls
+    texts = sum(data.values(), [])
+    data = add_vocab_and_labels(data, texts)
+    with open(output_path, 'wb') as f:
+        pickle.dump(data, f)
+    return data
diff --git a/daseg/utils_ner.py b/daseg/utils_ner.py
index f4682e6..fbcb418 100644
--- a/daseg/utils_ner.py
+++ b/daseg/utils_ner.py
@@ -18,7 +18,7 @@
 import logging
 import os
 
-from daseg.data import NEW_TURN
+from daseg.data import BLANK, NEW_TURN
 
 logger = logging.getLogger(__name__)
@@ -103,6 +103,7 @@ def convert_examples_to_features(
     label_map = {label: i for i, label in enumerate(label_list)}
     label_map[NEW_TURN] = pad_token_label_id
+    label_map[BLANK] = pad_token_label_id
 
     features = []
     for (ex_index, example) in enumerate(examples):
@@ -114,8 +115,11 @@ def convert_examples_to_features(
         for word, label in zip(example.words, example.labels):
             word_tokens = tokenizer.tokenize(word)
             tokens.extend(word_tokens)
-            word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)
+            # TODO: decide whether to label only the first sub-token (current) or every sub-token (variant below)
             # Use the real label id for the first token of the word, and padding ids for the remaining tokens
+            word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1)
+            # # Use the real label id for every token of a word
+            # word_labels = [label_map[label]] * len(word_tokens)
             label_ids.extend(word_labels)
 
         # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa.
diff --git a/requirements.txt b/requirements.txt
index 119babd..ef0bfa1 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ cytoolz
 scikit-learn>=0.22
 git+https://github.com/pzelasko/plz
 git+https://github.com/pzelasko/seqeval
+segeval
 Biopython
 transformers>=4.0.0,<=5.0.0
 gluonnlp
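For reference, a tiny self-contained check of the boundary precision/recall/F1 arithmetic that compute_segeval_metrics derives from segeval's boundary statistics (the counts below are invented purely for illustration):

correct, almost_correct = 8, 2        # exact boundary matches and near-miss transpositions
missed, false_alarms = 3, 1           # boundaries that were dropped / spuriously inserted

ok = correct + almost_correct
precision = ok / (ok + false_alarms)                 # 10 / 11 ~ 0.909
recall = ok / (ok + missed)                          # 10 / 13 ~ 0.769
f1 = 2 * precision * recall / (precision + recall)   # harmonic mean ~ 0.833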