From 35aefd5c36f47cfa495ac7d5c1b9113492a804f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 16 Apr 2021 16:53:03 -0400 Subject: [PATCH 01/38] Add punctuation data prep code (raw) --- daseg/punctuation/__init__.py | 0 daseg/punctuation/data.py | 243 ++++++++++++++++++++++++++++++++++ 2 files changed, 243 insertions(+) create mode 100644 daseg/punctuation/__init__.py create mode 100644 daseg/punctuation/data.py diff --git a/daseg/punctuation/__init__.py b/daseg/punctuation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/daseg/punctuation/data.py b/daseg/punctuation/data.py new file mode 100644 index 0000000..968b851 --- /dev/null +++ b/daseg/punctuation/data.py @@ -0,0 +1,243 @@ +# conda install pandas numpy matplotlib seaborn +# pip install jupyterlab tqdm kaldialign ipywidgets jupyterlab_widgets transformers cytoolz datasets seqeval +import pickle +import re +import string +from functools import lru_cache +from itertools import chain +from pathlib import Path +from typing import Any, Dict, List, Sequence, Tuple + +from tqdm.auto import tqdm + +from daseg import DialogActCorpus + +SPECIAL = ['[COUGH]', + '[COUGH]--[inaudible]', + '[COUGH]]', + '[LAUGH]', + '[LAUGH]]', + '[LIPSMACK]', + '[MN]', + '[NOISE]', + '[PAUSE]', + '[SIGH]', + '[a]', + '[inaudible]', + '[overspeaking]', + '[pause]'] + + +@lru_cache(20000) +def is_special_token(word: str) -> bool: + return any( + word.startswith(ltok) and word.endswith(rtok) + for ltok, rtok in ['[]', '<>'] + ) + + +@lru_cache(20000) +def split_punc(word: str, _pattern=re.compile(r'(\w|[\[\]<>])')) -> str: + # this silly _pattern will correctly handle "[NOISE]." -> "[NOISE]", "." + if not word: + return '', '' + word = word[::-1] # hi. -> .ih + match = _pattern.search(word) + if match is not None: + first_non_punc_idx = match.span()[0] + text = word[first_non_punc_idx:][::-1] + punc = word[:first_non_punc_idx][::-1] + return text, punc + else: + # pure punctuation + return '', word + + +def norm_punct(text: str): + text = text.replace('"', '') + text.replace('+', '') + text = re.sub(r'--+', '--', text) + precendences = ['?', '!', '...', '.', ',', '--', ';'] + words = text.split() + norm_words = [] + for w in words: + w, punc = split_punc(w) + for sym in precendences: + if sym in punc: + norm_words.append(f'{w}{sym}') + # print('inner', f'|{w}|{punc}|{sym}|{norm_words[-1]}|') + break + else: + norm_words.append(w) + # print('before return', words, norm_words) + return ' '.join(norm_words) + + +def to_words_labels_pair( + text: str, + _punctuation=string.punctuation.replace("'", ""), + _special=re.compile('|'.join(w.replace('[', '\[').replace(']', '\]') for w in SPECIAL)) +): + text = text.replace(' --', '--').replace(' ...', '...').strip() # stick punctuation to the text + text = _special.sub('', text) + text = ' '.join(text.split()) + + text_base, text_punct = split_punc(text) + if not text or not text_base: + return None + + # get rid of pesky punctuations like "hey...?!;" -> "hey?" 
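+    # (a single mark per word is kept, chosen by the precedence order: ? ! ... . , -- ;)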
+ text = norm_punct(text) + # rich words and lower/no-punc words + words = text.split() + norm_words = [w.lower().translate(str.maketrans('', '', _punctuation)) for w in words] + # filter out the words that consisted only of punctuation + idx_to_remove = [] + for idx, (w, nw) in enumerate(zip(words, norm_words)): + if not nw: + idx_to_remove.append(idx) + norm_words = [ + w if not is_special_token(split_punc(words[idx])[0]) else split_punc(words[idx])[0] + for idx, w in enumerate(norm_words) + if idx not in idx_to_remove + ] + words = [ + w + for idx, w in enumerate(words) + if idx not in idx_to_remove + ] + try: + upper_words, labels = zip(*(split_punc(w) for w in words)) + except: + print(text) + print(words) + print([split_punc(w) for w in words]) + raise + is_upper = [not is_special_token(w) and any(c.isupper() for c in w) for w in upper_words] + return { + 'words': words, + 'upper_words': upper_words, + 'norm_words': norm_words, + 'punct': labels, + 'is_upper': is_upper + } + + +def prepare_no_timing(txo: Path) -> List[Tuple[str, str]]: + try: + lines = txo.read_text().splitlines() + turns = (l.split() for l in lines if l.strip()) + turns = ((t[0][0], ' '.join(t[1:])) for t in turns) # (speaker, text) + turns = [ + {**data, 'speaker': speaker} + for speaker, data in + ((speaker, to_words_labels_pair(text)) for speaker, text in turns) + if data is not None + ] + return turns + except Exception as e: + print(f'Error processing path: {txo} -- {e}') + return None + + +def train_dev_test_split(texts): + train_part = round(len(texts) * 0.9) + dev_part = round(len(texts) * 0.05) + data = { + 'train': texts[:train_part], + 'dev': texts[train_part:train_part + dev_part], + 'test': texts[train_part + dev_part:] + } + return data + + +def add_vocab_and_labels(data: Dict[str, Any], texts) -> Dict[str, Any]: + """Pickle structure: + { + 'train': [ + { + # Each dict represents a single turn + 'words': List[str], + 'upper_words': List[str], + 'norm_words': List[str], + 'punct': List[str], + 'is_upper': List[bool], + 'speaker': str + } + ], + 'dev': [...], + 'test': [...], + 'idx2punct': Dict[int, str], + 'punct2idx': Dict[str, int], + 'vocab': List[str] + } + """ + puncts = {p for conv in texts for t in conv for p in t['punct']}; + puncts + idx2punct = list(puncts) + punct2idx = {p: idx for idx, p in enumerate(puncts)} + vocab = sorted({p for conv in texts for t in conv for p in t['norm_words']}) + data.update({ + 'idx2punct': idx2punct, + 'punct2idx': punct2idx, + 'vocab': vocab + }) + return data + + +CLSP_FISHER_PATHS = ( + Path('/export/corpora3/LDC/LDC2004T19'), + Path('/export/corpora3/LDC/LDC2005T19') +) + + +def prepare_fisher( + paths: Sequence[Path] = CLSP_FISHER_PATHS, + output_path: Path = Path('fisher.pkl') +) -> Dict[str, Any]: + txos = list( + tqdm( + chain.from_iterable( + path.rglob('*.txo') for path in paths, + ), + desc='Scanning for Fisher transcripts' + ) + ) + texts = [ + t + for t in ( + prepare_no_timing(txo) + for txo in tqdm(txos, desc='Processing txos') + ) + if t is not None + ] + data = train_dev_test_split(texts) + data = add_vocab_and_labels(data, texts) + with open(output_path, 'wb') as f: + pickle.dump(data, f) + return data + + +def prepare_swda( + corpus: DialogActCorpus, + output_path: Path = Path('swda.pkl') +) -> List[Tuple[str, str]]: + data = {} + splits = corpus.train_dev_test_split() + for key, split in splits.items(): + calls = [] + for call in split.calls: + turns = [] + for speaker, turn in call.turns: + text = ' '.join(fs.text for fs in turn) + 
turns.append({ + **to_words_labels_pair(text), + 'speaker': speaker + }) + calls.append(turns) + data[key] = calls + texts = sum(data.values(), []) + data = add_vocab_and_labels(data, texts) + with open(output_path, 'wb') as f: + pickle.dump(data, f) + return data From 52239db44cfde7b76ef9fbeaa6e5de3b4f834344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 16 Apr 2021 17:23:12 -0400 Subject: [PATCH 02/38] Add segeval metrics --- daseg/metrics.py | 20 ++++++++++++++++++++ daseg/models/bigru.py | 8 +++++++- daseg/models/transformer_model.py | 10 ++++++++-- requirements.txt | 1 + 4 files changed, 36 insertions(+), 3 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index e1176e2..11baf5c 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -55,6 +55,26 @@ def compute_seqeval_metrics(true_labels: List[List[str]], predictions: List[List } +def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): + from statistics import mean + from segeval.data import Dataset + from segeval import boundary_similarity, pk + + true_segments = Dataset({ + cid: {'ref': [len(fs.text.split()) for fs in call]} + for cid, call in true_dataset.dialogues.items() + }) + pred_segments = Dataset({ + cid: {'hyp': [len(fs.text.split()) for fs in call]} + for cid, call in pred_dataset.dialogues.items() + }) + + return { + 'pk': float(mean(pk(true_segments, pred_segments).values())), + 'B': float(mean(boundary_similarity(true_segments, pred_segments).values())) + } + + def compute_zhao_kawahara_metrics_levenshtein(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): """ Source: diff --git a/daseg/models/bigru.py b/daseg/models/bigru.py index 4d46b94..a1c1e2c 100644 --- a/daseg/models/bigru.py +++ b/daseg/models/bigru.py @@ -9,7 +9,8 @@ from torch.nn import CrossEntropyLoss from daseg.conversion import joint_coding_predictions_to_corpus -from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_sklearn_metrics, compute_zhao_kawahara_metrics +from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_segeval_metrics, compute_sklearn_metrics, \ + compute_zhao_kawahara_metrics class ZhaoKawaharaBiGru(pl.LightningModule): @@ -194,6 +195,11 @@ def compute_metrics(self, logits, true_labels): ) metrics.update({k: results[k] for k in ('micro_f1', 'macro_f1')}) + # Pk and B metrics + metrics.update(compute_segeval_metrics( + true_dataset=true_dataset, pred_dataset=pred_dataset) + ) + # We show the metrics obtained with Zhao-Kawahara code which computes them differently # (apparently the segment insertion errors are not counted) original_zhao_kawahara_metrics = compute_original_zhao_kawahara_metrics( diff --git a/daseg/models/transformer_model.py b/daseg/models/transformer_model.py index 27ff235..1a9e3bb 100644 --- a/daseg/models/transformer_model.py +++ b/daseg/models/transformer_model.py @@ -19,7 +19,8 @@ from daseg.conversion import predictions_to_dataset from daseg.data import DialogActCorpus from daseg.dataloaders.transformers import pad_list_of_arrays, to_transformers_eval_dataloader -from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_seqeval_metrics, compute_sklearn_metrics, \ +from daseg.metrics import compute_original_zhao_kawahara_metrics, compute_segeval_metrics, compute_seqeval_metrics, \ + compute_sklearn_metrics, \ compute_zhao_kawahara_metrics from daseg.models.longformer_model import LongformerForTokenClassification @@ -180,8 +181,10 @@ def predict( # (apparently the segment insertion 
errors are not counted) "ORIGINAL_zhao_kawahara_metrics": compute_original_zhao_kawahara_metrics( true_turns=out_label_list, pred_turns=preds_list - ) + ), }) + # Pk and B metrics + if isinstance(dataset, DialogActCorpus): if use_turns: dataset = DialogActCorpus(dialogues={str(i): turn for i, turn in enumerate(dataset.turns)}) @@ -195,6 +198,9 @@ def predict( results["zhao_kawahara_metrics"] = compute_zhao_kawahara_metrics( true_dataset=dataset, pred_dataset=results['dataset'] ) + results["segeval_metrics"] = compute_segeval_metrics( + true_dataset=dataset, pred_dataset=results['dataset'] + ) return results diff --git a/requirements.txt b/requirements.txt index 119babd..ef0bfa1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ cytoolz scikit-learn>=0.22 git+https://github.com/pzelasko/plz git+https://github.com/pzelasko/seqeval +segeval Biopython transformers>=4.0.0,<=5.0.0 gluonnlp From ddd6d2469b83087890ca7f66932dbc471ebd5429 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 7 May 2021 18:17:26 -0400 Subject: [PATCH 03/38] Initial GPU CRF loss implementation --- daseg/losses/__init__.py | 0 daseg/losses/crf.py | 113 ++++++++++++++++++++++++++++++ daseg/punctuation/data.py | 141 +++++++++++++++++++++++--------------- 3 files changed, 197 insertions(+), 57 deletions(-) create mode 100644 daseg/losses/__init__.py create mode 100644 daseg/losses/crf.py diff --git a/daseg/losses/__init__.py b/daseg/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py new file mode 100644 index 0000000..15ddab5 --- /dev/null +++ b/daseg/losses/crf.py @@ -0,0 +1,113 @@ +from typing import List + +import k2 +import torch +from torch import Tensor, nn + + +class CRFLoss(nn.Module): + def __init__(self, label_set: List[str]): + super().__init__() + self.label_set = label_set + self.den = make_topology(label_set, shared=True) + self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=True) + + def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): + posteriors = k2.DenseFsaVec( + log_probs, + supervision_segments=make_segments(input_lens) + ) + nums = make_numerator(labels, input_lens) + self.den.set_scores_stochastic_(self.den_scores) + + num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) + den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + + num_score = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True).sum() + den_score = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True).sum() + + loss = num_score - den_score + return loss + + +def make_symbol_table(label_set: List[str], shared: bool = True) -> k2.SymbolTable: + symtab = k2.SymbolTable() + symtab.add('O') + if shared: + symtab.add('I-') + for l in label_set: + symtab.add(l) + if not shared: + symtab.add(f'I-{l}') + return symtab + + +def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: + assert labels.size(0) == input_lens.size(0) + assert len(labels.shape) == 2 + assert len(input_lens.shape) == 1 + nums = k2.create_fsa_vec([ + k2.linear_fsa(l[:llen]) for l, llen in zip(labels, input_lens) + ]) + return nums + + +def make_topology(label_set: List[str], shared: bool = True) -> k2.Fsa: + symtab = make_symbol_table(label_set, shared=shared) + + """ + shared=True + 0 0 O + 0 0 Statement + 0 0 Question + 0 1 I- + 1 1 I- + 1 0 Statement + 1 0 Question + 0 2 -1 + 2 + """ + + """ + shared=False + 0 0 O + 0 0 Statement + 0 0 Question + 0 
1 I-Statement + 0 1 I-Question + 1 1 I-Statement + 1 1 I-Question + 1 0 Statement + 1 0 Question + 0 2 -1 + 2 + """ + + s = [f'0 0 {symtab["O"]}'] + if shared: + s += [ + f'0 1 {symtab["I-"]}', + f'1 1 {symtab["I-"]}' + ] + for idx, label in enumerate(label_set): + s += [f'0 0 {symtab[label]}'] + if not shared: + s += [ + f'0 1 {symtab["I-" + label]}' + f'1 1 {symtab["I-" + label]}' + ] + s += [f'1 0 {symtab[label]}'] + s += ['0 2 -1', '2'] + s.sort() + fsa = k2.Fsa.from_str(s) + fsa.symbols = symtab + return fsa + + +def make_segments(input_lens: Tensor) -> Tensor: + bs = input_lens.size(0) + return torch.stack([ + torch.arange(bs, dtype=torch.int32), + torch.zeros(bs, dtype=torch.int32), + input_lens.cpu().to(torch.int32) + ]) diff --git a/daseg/punctuation/data.py b/daseg/punctuation/data.py index 968b851..79e4a9d 100644 --- a/daseg/punctuation/data.py +++ b/daseg/punctuation/data.py @@ -6,26 +6,29 @@ from functools import lru_cache from itertools import chain from pathlib import Path -from typing import Any, Dict, List, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypedDict from tqdm.auto import tqdm from daseg import DialogActCorpus -SPECIAL = ['[COUGH]', - '[COUGH]--[inaudible]', - '[COUGH]]', - '[LAUGH]', - '[LAUGH]]', - '[LIPSMACK]', - '[MN]', - '[NOISE]', - '[PAUSE]', - '[SIGH]', - '[a]', - '[inaudible]', - '[overspeaking]', - '[pause]'] + +class Example(TypedDict): + words: List[str] + upper_words: List[str] + norm_words: List[str] + punct: List[str] + is_upper: List[bool] + speaker: str + + +class PunctuationData(TypedDict): + train: List[Example] + dev: List[Example] + test: List[Example] + idx2punct: Dict[int, str] + punct2idx: Dict[str, int] + vocab: List[str] @lru_cache(20000) @@ -37,7 +40,7 @@ def is_special_token(word: str) -> bool: @lru_cache(20000) -def split_punc(word: str, _pattern=re.compile(r'(\w|[\[\]<>])')) -> str: +def split_punctuation_from_word(word: str, _pattern=re.compile(r'(\w|[\[\]<>])')) -> Tuple[str, str]: # this silly _pattern will correctly handle "[NOISE]." -> "[NOISE]", "." if not word: return '', '' @@ -53,41 +56,61 @@ def split_punc(word: str, _pattern=re.compile(r'(\w|[\[\]<>])')) -> str: return '', word -def norm_punct(text: str): +def preprocess_punctuation( + text: str, + _precedences=['?', '!', '...', '.', ',', '--', ';'], +) -> str: text = text.replace('"', '') text.replace('+', '') text = re.sub(r'--+', '--', text) - precendences = ['?', '!', '...', '.', ',', '--', ';'] words = text.split() norm_words = [] for w in words: - w, punc = split_punc(w) - for sym in precendences: + w, punc = split_punctuation_from_word(w) + for sym in _precedences: if sym in punc: norm_words.append(f'{w}{sym}') - # print('inner', f'|{w}|{punc}|{sym}|{norm_words[-1]}|') break else: norm_words.append(w) - # print('before return', words, norm_words) return ' '.join(norm_words) -def to_words_labels_pair( +def create_example( text: str, _punctuation=string.punctuation.replace("'", ""), - _special=re.compile('|'.join(w.replace('[', '\[').replace(']', '\]') for w in SPECIAL)) -): + _special=re.compile(r'\[\[.*?\]\]|\[.*?\]|<.*?>', ) +) -> Optional[Example]: + """ + Converts a text segment / utterance into a dict that + can be used for punctuation/truecasing model training/eval. + + .. 
code-block:: python + + { + 'words': List[str], + 'upper_words': List[str], + 'norm_words': List[str], + 'punct': List[str], + 'is_upper': List[bool], + 'speaker': str + } + + :param text: + :param _punctuation: list of punctuation symbols (default is globally cached). + :param _special: regex pattern for detecting special symbols like [UNK] or (default is globally cached). + :return: see above. + """ text = text.replace(' --', '--').replace(' ...', '...').strip() # stick punctuation to the text text = _special.sub('', text) text = ' '.join(text.split()) - text_base, text_punct = split_punc(text) + text_base, text_punct = split_punctuation_from_word(text) if not text or not text_base: return None # get rid of pesky punctuations like "hey...?!;" -> "hey?" - text = norm_punct(text) + text = preprocess_punctuation(text) # rich words and lower/no-punc words words = text.split() norm_words = [w.lower().translate(str.maketrans('', '', _punctuation)) for w in words] @@ -97,7 +120,8 @@ def to_words_labels_pair( if not nw: idx_to_remove.append(idx) norm_words = [ - w if not is_special_token(split_punc(words[idx])[0]) else split_punc(words[idx])[0] + w if not is_special_token(split_punctuation_from_word(words[idx])[0]) else + split_punctuation_from_word(words[idx])[0] for idx, w in enumerate(norm_words) if idx not in idx_to_remove ] @@ -106,13 +130,7 @@ def to_words_labels_pair( for idx, w in enumerate(words) if idx not in idx_to_remove ] - try: - upper_words, labels = zip(*(split_punc(w) for w in words)) - except: - print(text) - print(words) - print([split_punc(w) for w in words]) - raise + upper_words, labels = zip(*(split_punctuation_from_word(w) for w in words)) is_upper = [not is_special_token(w) and any(c.isupper() for c in w) for w in upper_words] return { 'words': words, @@ -123,26 +141,13 @@ def to_words_labels_pair( } -def prepare_no_timing(txo: Path) -> List[Tuple[str, str]]: - try: - lines = txo.read_text().splitlines() - turns = (l.split() for l in lines if l.strip()) - turns = ((t[0][0], ' '.join(t[1:])) for t in turns) # (speaker, text) - turns = [ - {**data, 'speaker': speaker} - for speaker, data in - ((speaker, to_words_labels_pair(text)) for speaker, text in turns) - if data is not None - ] - return turns - except Exception as e: - print(f'Error processing path: {txo} -- {e}') - return None - - -def train_dev_test_split(texts): - train_part = round(len(texts) * 0.9) - dev_part = round(len(texts) * 0.05) +def train_dev_test_split( + texts: List[Example], + train_portion: float = 0.9, + dev_portion: float = 0.05 +) -> Dict[str, List[Example]]: + train_part = round(len(texts) * train_portion) + dev_part = round(len(texts) * dev_portion) data = { 'train': texts[:train_part], 'dev': texts[train_part:train_part + dev_part], @@ -151,7 +156,10 @@ def train_dev_test_split(texts): return data -def add_vocab_and_labels(data: Dict[str, Any], texts) -> Dict[str, Any]: +def add_vocab_and_labels( + data: Dict[str, Any], + texts: Dict[str, List[Example]] +) -> PunctuationData: """Pickle structure: { 'train': [ @@ -185,12 +193,31 @@ def add_vocab_and_labels(data: Dict[str, Any], texts) -> Dict[str, Any]: return data +"""Corpus specific parts""" + CLSP_FISHER_PATHS = ( Path('/export/corpora3/LDC/LDC2004T19'), Path('/export/corpora3/LDC/LDC2005T19') ) +def prepare_no_timing(txo: Path) -> Optional[List[Tuple[str, str]]]: + try: + lines = txo.read_text().splitlines() + turns = (l.split() for l in lines if l.strip()) + turns = ((t[0][0], ' '.join(t[1:])) for t in turns) # (speaker, text) + turns = [ + 
{**data, 'speaker': speaker} + for speaker, data in + ((speaker, create_example(text)) for speaker, text in turns) + if data is not None + ] + return turns + except Exception as e: + print(f'Error processing path: {txo} -- {e}') + return None + + def prepare_fisher( paths: Sequence[Path] = CLSP_FISHER_PATHS, output_path: Path = Path('fisher.pkl') @@ -231,7 +258,7 @@ def prepare_swda( for speaker, turn in call.turns: text = ' '.join(fs.text for fs in turn) turns.append({ - **to_words_labels_pair(text), + **create_example(text), 'speaker': speaker }) calls.append(turns) From 171a56e6bc70ef95746da712ff449e80f4bf79f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 7 May 2021 18:32:39 -0400 Subject: [PATCH 04/38] Add documentation --- daseg/losses/crf.py | 55 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 15ddab5..d9a6d4e 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -6,31 +6,50 @@ class CRFLoss(nn.Module): + """ + Conditional Random Field loss implemented with K2 library. It supports GPU computation. + + Currently, this loss assumes specific topologies for dialog acts/punctuation labeling. + """ + def __init__(self, label_set: List[str]): super().__init__() self.label_set = label_set - self.den = make_topology(label_set, shared=True) + self.den = make_denominator(label_set, shared=True) self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=True) def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): + # (batch, seqlen, classes) posteriors = k2.DenseFsaVec( log_probs, supervision_segments=make_segments(input_lens) ) + # (fsavec) nums = make_numerator(labels, input_lens) self.den.set_scores_stochastic_(self.den_scores) + # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) - num_score = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True).sum() - den_score = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True).sum() + # (batch,) + num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) + den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) - loss = num_score - den_score + # (scalar) + loss = (num_scores - den_scores).sum() return loss def make_symbol_table(label_set: List[str], shared: bool = True) -> k2.SymbolTable: + """ + Creates a symbol table given a list of classes (e.g. dialog acts, punctuation, etc.). + It adds extra symbols: + - 'O' which is used to indicate special tokens such as + - (when shared=True) 'I-' which is the "in-the-middle" symbol shared between all classes + - (when shared=False) 'I-' which is the "in-the-middle" symbol, + specific for each class (N x classes -> N x I- symbols) + """ symtab = k2.SymbolTable() symtab.add('O') if shared: @@ -43,6 +62,11 @@ def make_symbol_table(label_set: List[str], shared: bool = True) -> k2.SymbolTab def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: + """ + Creates a numerator supervision FSA. + It simply encodes the ground truth label sequence and allows no leeway. + Returns a :class:`k2.FsaVec` with FSAs of differing length. 
+ """ assert labels.size(0) == input_lens.size(0) assert len(labels.shape) == 2 assert len(input_lens.shape) == 1 @@ -52,7 +76,24 @@ def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: return nums -def make_topology(label_set: List[str], shared: bool = True) -> k2.Fsa: +def make_denominator(label_set: List[str], shared: bool = True) -> k2.Fsa: + """ + Creates a "simple" denominator that encodes all possible transitions + given the input label set. + + The labeling scheme is assumed to be IE with joint coding, e.g.: + + Here I am.~~~~~~ How are you today?~~ + I~~~ I Statement I~~ I~~ I~~ Question + + Or without joint coding: + + Here~~~~~~~ I~~~~~~~~~~ am.~~~~~~ How~~~~~~~ are~~~~~~~ you~~~~~~~ today?~~ + I-Statement I-Statement Statement I-Question I-Question I-Question Question + + When shared=True, it uses a shared "in-the-middle" label for all classes; + otherwise each class has a separate one. + """ symtab = make_symbol_table(label_set, shared=shared) """ @@ -105,6 +146,10 @@ def make_topology(label_set: List[str], shared: bool = True) -> k2.Fsa: def make_segments(input_lens: Tensor) -> Tensor: + """ + Creates a supervision segments tensor that indicates for each batch example, + at which index the example has started, and how many tokens it has. + """ bs = input_lens.size(0) return torch.stack([ torch.arange(bs, dtype=torch.int32), From 76978f8c3051e431abbab74e053286fcf7d80484 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 7 May 2021 18:33:48 -0400 Subject: [PATCH 05/38] Add option for trainable/freezed transition scores --- daseg/losses/crf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index d9a6d4e..673cf5e 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -12,11 +12,11 @@ class CRFLoss(nn.Module): Currently, this loss assumes specific topologies for dialog acts/punctuation labeling. 
""" - def __init__(self, label_set: List[str]): + def __init__(self, label_set: List[str], trainable_transition_scores: bool = True): super().__init__() self.label_set = label_set self.den = make_denominator(label_set, shared=True) - self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=True) + self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores) def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): # (batch, seqlen, classes) From 82a5395ff5fafcc596048f62f0eb0416afa89e82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:08:24 -0400 Subject: [PATCH 06/38] Fix arc scores --- daseg/losses/crf.py | 16 ++++++++-------- daseg/models/transformer_pl.py | 27 ++++++++++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 673cf5e..c23f35f 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -124,21 +124,21 @@ def make_denominator(label_set: List[str], shared: bool = True) -> k2.Fsa: 2 """ - s = [f'0 0 {symtab["O"]}'] + s = [f'0 0 {symtab["O"]} 0.0'] if shared: s += [ - f'0 1 {symtab["I-"]}', - f'1 1 {symtab["I-"]}' + f'0 1 {symtab["I-"]} 0.0', + f'1 1 {symtab["I-"]} 0.0' ] for idx, label in enumerate(label_set): - s += [f'0 0 {symtab[label]}'] + s += [f'0 0 {symtab[label]} 0.0'] if not shared: s += [ - f'0 1 {symtab["I-" + label]}' - f'1 1 {symtab["I-" + label]}' + f'0 1 {symtab["I-" + label]} 0.0' + f'1 1 {symtab["I-" + label]} 0.0' ] - s += [f'1 0 {symtab[label]}'] - s += ['0 2 -1', '2'] + s += [f'1 0 {symtab[label]} 0.0'] + s += ['0 2 -1 0.0', '2'] s.sort() fsa = k2.Fsa.from_str(s) fsa.symbols = symtab diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index fb96a9a..481f4d5 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -12,11 +12,18 @@ from daseg.data import NEW_TURN from daseg.dataloaders.transformers import pad_array +from daseg.losses.crf import CRFLoss from daseg.metrics import as_tensors, compute_sklearn_metrics class DialogActTransformer(pl.LightningModule): - def __init__(self, labels: List[str], model_name_or_path: str, pretrained: bool = True): + def __init__( + self, + labels: List[str], + model_name_or_path: str, + pretrained: bool = True, + crf: bool = False + ): super().__init__() self.save_hyperparameters() self.pad_token_label_id = CrossEntropyLoss().ignore_index @@ -40,6 +47,10 @@ def __init__(self, labels: List[str], model_name_or_path: str, pretrained: bool model_class = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(self.config)] self.model = model_class(self.config) self.model.resize_token_embeddings(len(self.tokenizer)) + if crf: + self.crf = CRFLoss(list(set(labels) - {'O', 'I-'})) + else: + self.crf = None def forward(self, **inputs): return self.model(**inputs) @@ -53,7 +64,10 @@ def training_step(self, batch, batch_num): ) # XLM and RoBERTa don"t use token_type_ids outputs = self(**inputs) - loss = outputs[0] + loss, logits = outputs[:2] + if self.crf is not None: + log_probs = torch.nn.functional.log_softmax(logits) + loss = -self.crf(log_probs) tensorboard_logs = {"loss": loss} return {"loss": loss, "log": tensorboard_logs} @@ -66,10 +80,13 @@ def validation_step(self, batch, batch_nb): batch[2] if self.config.model_type in ["bert", "xlnet"] else None ) # XLM and RoBERTa don"t use token_type_ids outputs = self(**inputs) - tmp_eval_loss, logits = outputs[:2] + loss, logits = outputs[:2] + if self.crf is not 
None: + log_probs = torch.nn.functional.log_softmax(logits) + loss = -self.crf(log_probs) preds = logits.detach().cpu().numpy() - out_label_ids = inputs["labels"].detach().cpu().numpy() - return {"val_loss": tmp_eval_loss.detach().cpu(), "pred": preds, "target": out_label_ids} + out_label_ids = inputs["labels"] + return {"val_loss": loss, "pred": preds, "target": out_label_ids} def test_step(self, batch, batch_nb): return self.validation_step(batch, batch_nb) From 8cc0fd02e1014679561692b1ee947e695e31a236 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:11:54 -0400 Subject: [PATCH 07/38] Fix from_str usage --- daseg/losses/crf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index c23f35f..7134eac 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -140,7 +140,7 @@ def make_denominator(label_set: List[str], shared: bool = True) -> k2.Fsa: s += [f'1 0 {symtab[label]} 0.0'] s += ['0 2 -1 0.0', '2'] s.sort() - fsa = k2.Fsa.from_str(s) + fsa = k2.Fsa.from_str('\n'.join(s)) fsa.symbols = symtab return fsa From 44ba0c87abef9359eff05c04ee0518dbc17682ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 10 May 2021 01:24:40 +0200 Subject: [PATCH 08/38] Fix supervision segments and linear fsa --- daseg/losses/crf.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 7134eac..e1712c2 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -20,10 +20,9 @@ def __init__(self, label_set: List[str], trainable_transition_scores: bool = Tru def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): # (batch, seqlen, classes) - posteriors = k2.DenseFsaVec( - log_probs, - supervision_segments=make_segments(input_lens) - ) + supervision_segments=make_segments(input_lens) + posteriors = k2.DenseFsaVec(log_probs, supervision_segments) + # (fsavec) nums = make_numerator(labels, input_lens) self.den.set_scores_stochastic_(self.den_scores) @@ -71,7 +70,7 @@ def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: assert len(labels.shape) == 2 assert len(input_lens.shape) == 1 nums = k2.create_fsa_vec([ - k2.linear_fsa(l[:llen]) for l, llen in zip(labels, input_lens) + k2.linear_fsa(l[:llen].tolist()) for l, llen in zip(labels, input_lens) ]) return nums @@ -155,4 +154,4 @@ def make_segments(input_lens: Tensor) -> Tensor: torch.arange(bs, dtype=torch.int32), torch.zeros(bs, dtype=torch.int32), input_lens.cpu().to(torch.int32) - ]) + ], dim=1) From fbdafa0e086d70e0180e6e6b0448ff4004db02b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:35:30 -0400 Subject: [PATCH 09/38] First attempt at CRF training --- daseg/bin/dasg | 6 ++++-- daseg/losses/crf.py | 26 ++++++++++++++++---------- daseg/models/transformer_pl.py | 5 +++-- 3 files changed, 23 insertions(+), 14 deletions(-) diff --git a/daseg/bin/dasg b/daseg/bin/dasg index c665ecb..70ddd8c 100644 --- a/daseg/bin/dasg +++ b/daseg/bin/dasg @@ -203,6 +203,7 @@ def prepare_exp( @click.option('-r', '--random-seed', default=1050, type=int) @click.option('-g', '--num-gpus', default=0, type=int) @click.option('-f', '--fp16', is_flag=True) +@click.option('--crf/--no-crf', default=False) def train_transformer( exp_dir: Path, model_name_or_path: str, @@ -212,7 +213,8 @@ def train_transformer( gradient_accumulation_steps: int, random_seed: int, num_gpus: int, - fp16: bool + fp16: bool, + crf: bool ): 
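+    # The --crf flag switches the training objective from plain cross-entropy to the k2-based CRF loss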
pl.seed_everything(random_seed) output_path = Path(exp_dir) @@ -220,7 +222,7 @@ def train_transformer( datasets: Dict[str, Dataset] = pickle.load(f) with open(output_path / 'labels.pkl', 'rb') as f: labels: List[str] = pickle.load(f) - model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path) + model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path, crf=crf) loaders = { key: to_dataloader( dset, diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 7134eac..88cc00d 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Dict, List import k2 import torch @@ -12,10 +12,16 @@ class CRFLoss(nn.Module): Currently, this loss assumes specific topologies for dialog acts/punctuation labeling. """ - def __init__(self, label_set: List[str], trainable_transition_scores: bool = True): + def __init__( + self, + label_set: List[str], + label2id: Dict[str, int], + trainable_transition_scores: bool = True + ): super().__init__() self.label_set = label_set - self.den = make_denominator(label_set, shared=True) + self.label2id = label2id + self.den = make_denominator(label_set, label2id, shared=True) self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores) def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): @@ -41,7 +47,7 @@ def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): return loss -def make_symbol_table(label_set: List[str], shared: bool = True) -> k2.SymbolTable: +def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.SymbolTable: """ Creates a symbol table given a list of classes (e.g. dialog acts, punctuation, etc.). It adds extra symbols: @@ -51,13 +57,13 @@ def make_symbol_table(label_set: List[str], shared: bool = True) -> k2.SymbolTab specific for each class (N x classes -> N x I- symbols) """ symtab = k2.SymbolTable() - symtab.add('O') + symtab.add('O', label2id['O']) if shared: - symtab.add('I-') + symtab.add('I-', label2id['I-']) for l in label_set: - symtab.add(l) + symtab.add(l, label2id[l]) if not shared: - symtab.add(f'I-{l}') + symtab.add(f'I-{l}', label2id[f'I-{l}']) return symtab @@ -76,7 +82,7 @@ def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: return nums -def make_denominator(label_set: List[str], shared: bool = True) -> k2.Fsa: +def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.Fsa: """ Creates a "simple" denominator that encodes all possible transitions given the input label set. @@ -94,7 +100,7 @@ def make_denominator(label_set: List[str], shared: bool = True) -> k2.Fsa: When shared=True, it uses a shared "in-the-middle" label for all classes; otherwise each class has a separate one. 
""" - symtab = make_symbol_table(label_set, shared=shared) + symtab = make_symbol_table(label_set, label2id, shared=shared) """ shared=True diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index 481f4d5..0fa3b05 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -28,12 +28,13 @@ def __init__( self.save_hyperparameters() self.pad_token_label_id = CrossEntropyLoss().ignore_index self.labels = labels + self.label2id = {label: i for i, label in enumerate(self.labels)} self.num_labels = len(self.labels) self.config = AutoConfig.from_pretrained( model_name_or_path, num_labels=self.num_labels, id2label={str(i): label for i, label in enumerate(self.labels)}, - label2id={label: i for i, label in enumerate(self.labels)}, + label2id=self.label2id ) self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) self.tokenizer.add_special_tokens({'additional_special_tokens': [NEW_TURN]}) @@ -48,7 +49,7 @@ def __init__( self.model = model_class(self.config) self.model.resize_token_embeddings(len(self.tokenizer)) if crf: - self.crf = CRFLoss(list(set(labels) - {'O', 'I-'})) + self.crf = CRFLoss(self.labels, self.label2id) else: self.crf = None From a92232dfc470c9ee44e0b42cda89965567b3552e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:43:14 -0400 Subject: [PATCH 10/38] try modelling O as --- daseg/losses/crf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 18ba774..9705d21 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -56,7 +56,7 @@ def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bo specific for each class (N x classes -> N x I- symbols) """ symtab = k2.SymbolTable() - symtab.add('O', label2id['O']) + # symtab.add('O', label2id['O']) if shared: symtab.add('I-', label2id['I-']) for l in label_set: From a8120331f2aafeb2898b42da3291f98b2d964430 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:45:30 -0400 Subject: [PATCH 11/38] try modelling O as --- daseg/models/transformer_pl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index 0fa3b05..074b7fe 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -49,7 +49,7 @@ def __init__( self.model = model_class(self.config) self.model.resize_token_embeddings(len(self.tokenizer)) if crf: - self.crf = CRFLoss(self.labels, self.label2id) + self.crf = CRFLoss([l for l in self.labels if l != 'O' and not l.startswith('I-')], self.label2id) else: self.crf = None From d1db50b0ade9fead24ad9b4c7cd6285d4e4eded7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 9 May 2021 19:46:34 -0400 Subject: [PATCH 12/38] try modelling O as --- daseg/losses/crf.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 9705d21..22beaa3 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -129,7 +129,8 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo 2 """ - s = [f'0 0 {symtab["O"]} 0.0'] + # s = [f'0 0 {symtab["O"]} 0.0'] + s = [f'0 0 {symtab[""]} 0.0'] if shared: s += [ f'0 1 {symtab["I-"]} 0.0', From 7565e6e7f60b83556c3d271286b4d87605b534eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 11 May 2021 10:28:53 -0400 Subject: [PATCH 13/38] Add extra inputs to CRF loss 
--- daseg/dataloaders/transformers.py | 9 +++++++-- daseg/losses/crf.py | 1 + daseg/models/transformer_pl.py | 7 ++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index d06f358..6fc99a7 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -178,13 +178,18 @@ def to_transformers_eval_dataloader( return to_dataloader(dataset, batch_size=batch_size, train=False, padding_at_start=model_type == 'xlnet') -def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False): +def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False, add_ilen: bool = True): redundant_padding = max(mask.sum() for _, mask, _, _ in batch) n_tensors = len(batch[0]) concat_tensors = (torch.cat([sample[i].unsqueeze(0) for sample in batch]) for i in range(n_tensors)) if padding_at_start: return [t[:, -redundant_padding:] for t in concat_tensors] - return [t[:, :redundant_padding] for t in concat_tensors] + truncated = [t[:, :redundant_padding] for t in concat_tensors] + # Here we add extra tensor that states the input lens + truncated.append( + truncated[0].new_tensor([mask.sum() for _, mask, _, _ in truncated]) + ) + return truncated def pad_list_of_arrays(arrays: List[np.ndarray], value: float) -> List[np.ndarray]: diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 22beaa3..fbaad29 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Dict, List import k2 diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index 074b7fe..0259d13 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -63,12 +63,12 @@ def training_step(self, batch, batch_num): inputs["token_type_ids"] = ( batch[2] if self.config.model_type in ["bert", "xlnet"] else None ) # XLM and RoBERTa don"t use token_type_ids - outputs = self(**inputs) loss, logits = outputs[:2] if self.crf is not None: log_probs = torch.nn.functional.log_softmax(logits) - loss = -self.crf(log_probs) + labels, ilens = batch[3], batch[4] + loss = -self.crf(log_probs, ilens, labels) tensorboard_logs = {"loss": loss} return {"loss": loss, "log": tensorboard_logs} @@ -84,7 +84,8 @@ def validation_step(self, batch, batch_nb): loss, logits = outputs[:2] if self.crf is not None: log_probs = torch.nn.functional.log_softmax(logits) - loss = -self.crf(log_probs) + labels, ilens = batch[3], batch[4] + loss = -self.crf(log_probs, ilens, labels) preds = logits.detach().cpu().numpy() out_label_ids = inputs["labels"] return {"val_loss": loss, "pred": preds, "target": out_label_ids} From ee553a1e00369b4a7eb3c2a91a058b422bde673e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 11 May 2021 10:33:31 -0400 Subject: [PATCH 14/38] Fix --- daseg/dataloaders/transformers.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index 6fc99a7..67f239b 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -185,10 +185,11 @@ def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_star if padding_at_start: return [t[:, -redundant_padding:] for t in concat_tensors] truncated = [t[:, :redundant_padding] for t in concat_tensors] - # Here we add extra tensor that states the input lens - truncated.append( - 
truncated[0].new_tensor([mask.sum() for _, mask, _, _ in truncated]) - ) + if add_ilen: + # Here we add extra tensor that states the input lens + truncated.append( + truncated[0].new_tensor(truncated[1].sum(dim=1)) + ) return truncated From 4b79a51ba29deba49d7ac547a6e99b70f8b6844f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Tue, 11 May 2021 10:37:08 -0400 Subject: [PATCH 15/38] all subwords get the same label token --- daseg/utils_ner.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/daseg/utils_ner.py b/daseg/utils_ner.py index f4682e6..e0a4d45 100644 --- a/daseg/utils_ner.py +++ b/daseg/utils_ner.py @@ -114,8 +114,11 @@ def convert_examples_to_features( for word, label in zip(example.words, example.labels): word_tokens = tokenizer.tokenize(word) tokens.extend(word_tokens) - word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1) - # Use the real label id for the first token of the word, and padding ids for the remaining tokens + # Change it! + # # Use the real label id for the first token of the word, and padding ids for the remaining tokens + # word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1) + # Use the real label id for every token of a word + word_labels = [label_map[label]] * len(word_tokens) label_ids.extend(word_labels) # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. From 441226c7f207df4c2000caa6354e814912512660 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 16:18:32 +0200 Subject: [PATCH 16/38] Various fixes --- daseg/dataloaders/transformers.py | 5 +++-- daseg/losses/crf.py | 32 +++++++++++++++++-------------- daseg/models/transformer_model.py | 3 ++- daseg/models/transformer_pl.py | 9 ++++++++- daseg/utils_ner.py | 8 ++++---- 5 files changed, 35 insertions(+), 22 deletions(-) diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index 67f239b..4c3a894 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -178,7 +178,8 @@ def to_transformers_eval_dataloader( return to_dataloader(dataset, batch_size=batch_size, train=False, padding_at_start=model_type == 'xlnet') -def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False, add_ilen: bool = True): +def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_start: bool = False, add_ilen: bool = True, sort: bool = True): + batch = sorted(batch, key=lambda tensors: (tensors[3] != -100).to(torch.int32).sum(), reverse=True) redundant_padding = max(mask.sum() for _, mask, _, _ in batch) n_tensors = len(batch[0]) concat_tensors = (torch.cat([sample[i].unsqueeze(0) for sample in batch]) for i in range(n_tensors)) @@ -188,7 +189,7 @@ def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_star if add_ilen: # Here we add extra tensor that states the input lens truncated.append( - truncated[0].new_tensor(truncated[1].sum(dim=1)) + truncated[1].sum(dim=1) ) return truncated diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index fbaad29..24d6e7e 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -22,28 +22,30 @@ def __init__( super().__init__() self.label_set = label_set self.label2id = label2id - self.den = make_denominator(label_set, label2id, shared=True) - self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores) + #self.den = make_denominator(label_set, label2id, 
shared=True).to('cuda') + #self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): # (batch, seqlen, classes) - supervision_segments=make_segments(input_lens) + #supervision_segments=make_segments(input_lens) + supervision_segments=make_segments(labels) posteriors = k2.DenseFsaVec(log_probs, supervision_segments) # (fsavec) - nums = make_numerator(labels, input_lens) - self.den.set_scores_stochastic_(self.den_scores) + nums = make_numerator(labels, input_lens).to(log_probs.device) + #self.den.set_scores_stochastic_(self.den_scores) # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) - den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + #den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) # (batch,) num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) - den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) + #den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) # (scalar) - loss = (num_scores - den_scores).sum() + #loss = (num_scores - den_scores).sum() / log_probs.size(0) + loss = num_scores.sum() return loss @@ -77,7 +79,7 @@ def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: assert len(labels.shape) == 2 assert len(input_lens.shape) == 1 nums = k2.create_fsa_vec([ - k2.linear_fsa(l[:llen].tolist()) for l, llen in zip(labels, input_lens) + k2.linear_fsa(l[l != -100].tolist()) for l, llen in zip(labels, input_lens) ]) return nums @@ -152,14 +154,16 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo return fsa -def make_segments(input_lens: Tensor) -> Tensor: +def make_segments(labels: Tensor) -> Tensor: """ Creates a supervision segments tensor that indicates for each batch example, at which index the example has started, and how many tokens it has. """ - bs = input_lens.size(0) + #bs = input_lens.size(0) + bs = labels.size(0) return torch.stack([ torch.arange(bs, dtype=torch.int32), - torch.zeros(bs, dtype=torch.int32), - input_lens.cpu().to(torch.int32) - ], dim=1) + torch.ones(bs, dtype=torch.int32), # start at one because of [BOS] + (labels != -100).to(torch.int32).sum(dim=1).cpu() + #input_lens.cpu().to(torch.int32) - 3 # subtract one for [BOS] and two for [CLS] and [EOS] + ], dim=1).to(torch.int32) diff --git a/daseg/models/transformer_model.py b/daseg/models/transformer_model.py index 1a9e3bb..fa52f66 100644 --- a/daseg/models/transformer_model.py +++ b/daseg/models/transformer_model.py @@ -73,7 +73,8 @@ def from_pl_checkpoint(path: Path, device: str = 'cpu'): torch.arange(pl_model.config.max_position_embeddings).expand((1, -1)) # Remove extra keys that are no longer needed... 
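+    # Only delete the pooler weights when they are actually present in the checkpoint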
for k in ["model.longformer.pooler.dense.weight", "model.longformer.pooler.dense.bias"]: - del ckpt['state_dict'][k] + if k in ckpt['state_dict']: + del ckpt['state_dict'][k] # Manually load the state dict pl_model.load_state_dict(ckpt['state_dict']) diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index 0259d13..9444428 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -69,7 +69,11 @@ def training_step(self, batch, batch_num): log_probs = torch.nn.functional.log_softmax(logits) labels, ilens = batch[3], batch[4] loss = -self.crf(log_probs, ilens, labels) - tensorboard_logs = {"loss": loss} + batch_size = log_probs.size(0) + num_tokens = batch[1].sum() + tensorboard_logs = {"loss": loss * batch_size / num_tokens} + else: + tensorboard_logs = {"loss": loss} return {"loss": loss, "log": tensorboard_logs} def validation_step(self, batch, batch_nb): @@ -194,6 +198,9 @@ def set_output_dir(self, output_dir: Path): def pad_outputs(outputs: Dict) -> Dict: max_out_len = max(x["pred"].shape[1] for x in outputs) for x in outputs: + for k in ['pred', 'target']: + if isinstance(x[k], torch.Tensor): + x[k] = x[k].cpu().numpy() x["pred"] = pad_array(x["pred"], target_len=max_out_len, value=0) x["target"] = pad_array(x["target"], target_len=max_out_len, value=CrossEntropyLoss().ignore_index) return outputs diff --git a/daseg/utils_ner.py b/daseg/utils_ner.py index e0a4d45..93e21c6 100644 --- a/daseg/utils_ner.py +++ b/daseg/utils_ner.py @@ -115,10 +115,10 @@ def convert_examples_to_features( word_tokens = tokenizer.tokenize(word) tokens.extend(word_tokens) # Change it! - # # Use the real label id for the first token of the word, and padding ids for the remaining tokens - # word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1) - # Use the real label id for every token of a word - word_labels = [label_map[label]] * len(word_tokens) + # Use the real label id for the first token of the word, and padding ids for the remaining tokens + word_labels = [label_map[label]] + [pad_token_label_id] * (len(word_tokens) - 1) + # # Use the real label id for every token of a word + # word_labels = [label_map[label]] * len(word_tokens) label_ids.extend(word_labels) # Account for [CLS] and [SEP] with "- 2" and with "- 3" for RoBERTa. 
From 4181024dc8734a992283d774a8a8ad2f9f58e72c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 10:43:44 -0400 Subject: [PATCH 17/38] Exclude O (blank) symbol from loss computation --- daseg/losses/crf.py | 4 ++-- daseg/utils_ner.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 24d6e7e..c495a5f 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -1,4 +1,3 @@ -import dataclasses from typing import Dict, List import k2 @@ -133,7 +132,8 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo """ # s = [f'0 0 {symtab["O"]} 0.0'] - s = [f'0 0 {symtab[""]} 0.0'] + # s = [f'0 0 {symtab[""]} 0.0'] + s = [] if shared: s += [ f'0 1 {symtab["I-"]} 0.0', diff --git a/daseg/utils_ner.py b/daseg/utils_ner.py index 93e21c6..fbcb418 100644 --- a/daseg/utils_ner.py +++ b/daseg/utils_ner.py @@ -18,7 +18,7 @@ import logging import os -from daseg.data import NEW_TURN +from daseg.data import BLANK, NEW_TURN logger = logging.getLogger(__name__) @@ -103,6 +103,7 @@ def convert_examples_to_features( label_map = {label: i for i, label in enumerate(label_list)} label_map[NEW_TURN] = pad_token_label_id + label_map[BLANK] = pad_token_label_id features = [] for (ex_index, example) in enumerate(examples): From 2e3d24ea0b7a130ed4a94850d1a01f4c32298944 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 11:37:04 -0400 Subject: [PATCH 18/38] Fix supervision segments offset and remove the need for input lens --- daseg/losses/crf.py | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index c495a5f..d0b078d 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -26,21 +26,24 @@ def __init__( def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): # (batch, seqlen, classes) - #supervision_segments=make_segments(input_lens) - supervision_segments=make_segments(labels) + # supervision_segments=make_segments(input_lens) + supervision_segments = make_segments(labels) posteriors = k2.DenseFsaVec(log_probs, supervision_segments) # (fsavec) - nums = make_numerator(labels, input_lens).to(log_probs.device) - #self.den.set_scores_stochastic_(self.den_scores) + nums = make_numerator(labels).to(log_probs.device) + print('nums.shape', nums.shape) + print('nums->num_arcs', [nums[0].num_arcs for i in range(nums.shape[0])]) + print('segments', supervision_segments) + # self.den.set_scores_stochastic_(self.den_scores) # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) - #den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + # den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) # (batch,) num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) - #den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) + # den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) # (scalar) #loss = (num_scores - den_scores).sum() / log_probs.size(0) @@ -58,7 +61,6 @@ def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bo specific for each class (N x classes -> N x I- symbols) """ symtab = k2.SymbolTable() - # symtab.add('O', label2id['O']) if shared: symtab.add('I-', label2id['I-']) for l in label_set: @@ -68,18 +70,14 @@ def make_symbol_table(label_set: List[str], label2id: Dict[str, int], 
shared: bo return symtab -def make_numerator(labels: Tensor, input_lens: Tensor) -> k2.Fsa: +def make_numerator(labels: Tensor) -> k2.Fsa: """ Creates a numerator supervision FSA. It simply encodes the ground truth label sequence and allows no leeway. Returns a :class:`k2.FsaVec` with FSAs of differing length. """ - assert labels.size(0) == input_lens.size(0) assert len(labels.shape) == 2 - assert len(input_lens.shape) == 1 - nums = k2.create_fsa_vec([ - k2.linear_fsa(l[l != -100].tolist()) for l, llen in zip(labels, input_lens) - ]) + nums = k2.create_fsa_vec([k2.linear_fsa(lab[lab != -100].tolist()) for lab in labels]) return nums @@ -105,7 +103,6 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo """ shared=True - 0 0 O 0 0 Statement 0 0 Question 0 1 I- @@ -118,7 +115,6 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo """ shared=False - 0 0 O 0 0 Statement 0 0 Question 0 1 I-Statement @@ -131,8 +127,6 @@ def make_denominator(label_set: List[str], label2id: Dict[str, int], shared: boo 2 """ - # s = [f'0 0 {symtab["O"]} 0.0'] - # s = [f'0 0 {symtab[""]} 0.0'] s = [] if shared: s += [ @@ -159,11 +153,9 @@ def make_segments(labels: Tensor) -> Tensor: Creates a supervision segments tensor that indicates for each batch example, at which index the example has started, and how many tokens it has. """ - #bs = input_lens.size(0) bs = labels.size(0) return torch.stack([ torch.arange(bs, dtype=torch.int32), - torch.ones(bs, dtype=torch.int32), # start at one because of [BOS] + torch.zeros(bs, dtype=torch.int32), (labels != -100).to(torch.int32).sum(dim=1).cpu() - #input_lens.cpu().to(torch.int32) - 3 # subtract one for [BOS] and two for [CLS] and [EOS] ], dim=1).to(torch.int32) From 9f3a7b5e43131491119accc23c6d5177c3a7ebf7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 21:10:13 +0200 Subject: [PATCH 19/38] various fixes to CRF --- daseg/data.py | 2 +- daseg/losses/crf.py | 41 ++++++++++++++++++++++++---------- daseg/models/transformer_pl.py | 34 +++++++++++++++------------- 3 files changed, 49 insertions(+), 28 deletions(-) diff --git a/daseg/data.py b/daseg/data.py index 7d2d31c..60cc782 100644 --- a/daseg/data.py +++ b/daseg/data.py @@ -234,7 +234,7 @@ def joint_coding_dialog_act_label_frequencies(self): @property def joint_coding_dialog_act_labels(self) -> List[str]: - return list(chain([BLANK, CONTINUE_TAG], self.dialog_acts)) + return list(chain([CONTINUE_TAG], self.dialog_acts)) @property def vocabulary(self) -> Counter: diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index d0b078d..e29e545 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -16,40 +16,55 @@ def __init__( self, label_set: List[str], label2id: Dict[str, int], - trainable_transition_scores: bool = True + trainable_transition_scores: bool = False ): super().__init__() self.label_set = label_set self.label2id = label2id - #self.den = make_denominator(label_set, label2id, shared=True).to('cuda') - #self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') + self.den = make_denominator(label_set, label2id, shared=True).to('cuda') + self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): + global it # (batch, seqlen, classes) - # supervision_segments=make_segments(input_lens) supervision_segments = make_segments(labels) posteriors = 
k2.DenseFsaVec(log_probs, supervision_segments) # (fsavec) nums = make_numerator(labels).to(log_probs.device) - print('nums.shape', nums.shape) - print('nums->num_arcs', [nums[0].num_arcs for i in range(nums.shape[0])]) - print('segments', supervision_segments) - # self.den.set_scores_stochastic_(self.den_scores) + for i in range(nums.shape[0]): + # The supervision has to have exactly the same number of arcs as the number of tokens + # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. + assert nums[i].num_arcs == supervision_segments[i, 2] + 1 + self.den.set_scores_stochastic_(self.den_scores) + + if it % 100 == 0: + for i in range(min(3, labels.size(0))): + print('*' * 120) + print('log_probs.shape', log_probs.shape) + print(f'labels[{i}][:20] = ', labels[i][:20]) + print(f'labels[{i}][labels[{i}] != -100][:20] = ', labels[i][labels[i] != -100][:20]) + print(f'nums[{i}].labels[:20] = ', nums[i].labels[:20]) + print(f'log_probs[{i}][:20] = ', log_probs.argmax(dim=2)[i][:20]) + print('*' * 120) + it += 1 # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) - # den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) # (batch,) num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) - # den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) + den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) # (scalar) - #loss = (num_scores - den_scores).sum() / log_probs.size(0) - loss = num_scores.sum() + num_tokens = (labels != -100).to(torch.int32).sum() + loss = (num_scores - den_scores).sum() / num_tokens + #loss = num_scores.sum() / (labels != -100).to(torch.int32).sum() return loss +it = 0 + def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.SymbolTable: """ @@ -61,6 +76,8 @@ def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bo specific for each class (N x classes -> N x I- symbols) """ symtab = k2.SymbolTable() + del symtab._sym2id[''] + del symtab._id2sym[0] if shared: symtab.add('I-', label2id['I-']) for l in label_set: diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index 9444428..b7bf2f8 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -64,17 +64,19 @@ def training_step(self, batch, batch_num): batch[2] if self.config.model_type in ["bert", "xlnet"] else None ) # XLM and RoBERTa don"t use token_type_ids outputs = self(**inputs) - loss, logits = outputs[:2] + ce_loss, logits = outputs[:2] if self.crf is not None: - log_probs = torch.nn.functional.log_softmax(logits) + log_probs = torch.nn.functional.log_softmax(logits, dim=2) labels, ilens = batch[3], batch[4] - loss = -self.crf(log_probs, ilens, labels) - batch_size = log_probs.size(0) - num_tokens = batch[1].sum() - tensorboard_logs = {"loss": loss * batch_size / num_tokens} + crf_loss = -self.crf(log_probs, ilens, labels) + ce_loss = 0.1 * ce_loss + loss = crf_loss + ce_loss + logs = {"loss": loss, 'crf_loss': crf_loss, 'ce_loss': ce_loss} else: - tensorboard_logs = {"loss": loss} - return {"loss": loss, "log": tensorboard_logs} + logs = {"loss": loss} + progdict = logs.copy() + progdict.pop('loss') + return {"loss": loss, "log": logs, 'progress_bar': logs} def validation_step(self, batch, batch_nb): "Compute validation" @@ -87,7 +89,7 @@ def validation_step(self, batch, 
batch_nb): outputs = self(**inputs) loss, logits = outputs[:2] if self.crf is not None: - log_probs = torch.nn.functional.log_softmax(logits) + log_probs = torch.nn.functional.log_softmax(logits, dim=2) labels, ilens = batch[3], batch[4] loss = -self.crf(log_probs, ilens, labels) preds = logits.detach().cpu().numpy() @@ -152,7 +154,7 @@ def compute_and_set_total_steps( def get_lr_scheduler(self): scheduler = get_linear_schedule_with_warmup( - self.opt, num_warmup_steps=0, num_training_steps=self.total_steps + self.opt, num_warmup_steps=250, num_training_steps=self.total_steps ) scheduler = {"scheduler": scheduler, "interval": "step", "frequency": 1} return scheduler @@ -171,12 +173,14 @@ def configure_optimizers(self): "weight_decay": 0.0, }, ] - optimizer = AdamW(optimizer_grouped_parameters, lr=5e-5, eps=1e-8) - self.opt = optimizer - - scheduler = self.get_lr_scheduler() + self.opt = AdamW( + optimizer_grouped_parameters, + lr=5e-5, + eps=1e-8 + ) + self.scheduler = self.get_lr_scheduler() - return [optimizer], [scheduler] + return [self.opt], [self.scheduler] def set_output_dir(self, output_dir: Path): self.output_dir = Path(output_dir) From 2c86086f271ddc5d410d55c2e9ee64f384a5e8e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 22:17:28 +0200 Subject: [PATCH 20/38] Working CRF numerator with subword token skipping --- daseg/losses/crf.py | 67 +++++++++++++++++++++++++++------- daseg/models/transformer_pl.py | 8 ++-- 2 files changed, 58 insertions(+), 17 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index e29e545..5d14d62 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -16,56 +16,95 @@ def __init__( self, label_set: List[str], label2id: Dict[str, int], - trainable_transition_scores: bool = False + trainable_transition_scores: bool = True, + ignore_index: int = -100 ): super().__init__() self.label_set = label_set self.label2id = label2id + self.ignore_index = ignore_index self.den = make_denominator(label_set, label2id, shared=True).to('cuda') - self.den_scores = nn.Parameter(self.den.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') + self.A = create_bigram_lm([self.label2id[l] for l in label_set]).to('cuda') + self.A_scores = nn.Parameter(self.A.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): global it - # (batch, seqlen, classes) + + # Determine all relevant shapes - max_seqlen_scored is the longest sequence length of log_probs + # after we remove ignored indices. + bs, seqlen, nclass = log_probs.shape + max_seqlen_scored = (labels[0] != self.ignore_index).sum() supervision_segments = make_segments(labels) - posteriors = k2.DenseFsaVec(log_probs, supervision_segments) + + log_probs_scored = log_probs.new_zeros(bs, max_seqlen_scored, nclass) + assert max_seqlen_scored == supervision_segments[0, 2] + for i in range(bs): + log_probs_scored[i, :supervision_segments[i, 2], :] = log_probs[i, labels[i] != self.ignore_index, :] + + # (batch, seqlen, classes) + posteriors = k2.DenseFsaVec(log_probs_scored, supervision_segments) # (fsavec) - nums = make_numerator(labels).to(log_probs.device) + nums = make_numerator(labels) + for i in range(nums.shape[0]): + # The supervision has to have exactly the same number of arcs as the number of tokens + # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. 
+ assert nums[i].num_arcs == supervision_segments[i, 2] + 1 + self.A.set_scores_stochastic_(self.A_scores) + nums = k2.intersect(self.A.to('cpu'), nums).to(log_probs.device) for i in range(nums.shape[0]): # The supervision has to have exactly the same number of arcs as the number of tokens # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. assert nums[i].num_arcs == supervision_segments[i, 2] + 1 - self.den.set_scores_stochastic_(self.den_scores) if it % 100 == 0: for i in range(min(3, labels.size(0))): print('*' * 120) - print('log_probs.shape', log_probs.shape) + print('log_probs_scored.shape', log_probs_scored.shape) print(f'labels[{i}][:20] = ', labels[i][:20]) - print(f'labels[{i}][labels[{i}] != -100][:20] = ', labels[i][labels[i] != -100][:20]) + print(f'labels[{i}][labels[{i}] != self.ignore_index][:20] = ', labels[i][labels[i] != self.ignore_index][:20]) print(f'nums[{i}].labels[:20] = ', nums[i].labels[:20]) - print(f'log_probs[{i}][:20] = ', log_probs.argmax(dim=2)[i][:20]) + print(f'log_probs_scored[{i}][:20] = ', log_probs_scored.argmax(dim=2)[i][:20]) print('*' * 120) it += 1 # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) - den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + #den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) # (batch,) num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) - den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) + #den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) # (scalar) - num_tokens = (labels != -100).to(torch.int32).sum() - loss = (num_scores - den_scores).sum() / num_tokens - #loss = num_scores.sum() / (labels != -100).to(torch.int32).sum() + num_tokens = (labels != self.ignore_index).to(torch.int32).sum() + #loss = (num_scores - den_scores).sum() / num_tokens + loss = num_scores.sum() / num_tokens return loss it = 0 +def create_bigram_lm(labels: List[int]) -> k2.Fsa: + """ + Create a bigram LM. + The resulting FSA (A) has a start-state and a state for + each label 0, 1, 2, ....; and each of the above-mentioned states + has a transition to the state for each phone and also to the final-state. + """ + final_state = len(labels) + 1 + rules = '' + for i in range(1, final_state): + rules += f'0 {i} {labels[i-1]} 0.0\n' + + for i in range(1, final_state): + for j in range(1, final_state): + rules += f'{i} {j} {labels[j-1]} 0.0\n' + rules += f'{i} {final_state} -1 0.0\n' + rules += f'{final_state}' + return k2.Fsa.from_str(rules) + + def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.SymbolTable: """ Creates a symbol table given a list of classes (e.g. dialog acts, punctuation, etc.). 
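[Editor's note] For readers less familiar with k2's textual FSA format, the sketch below spells out the rule string that create_bigram_lm produces for a hypothetical two-label set (the ids 3 and 5 are invented for illustration). Each rule line is "src_state dest_state label score"; label -1 is the arc into the final state, and the last line names the final state itself.

# Illustration only: what create_bigram_lm builds for labels [3, 5] (hypothetical ids).
# State 0 is the start state, states 1-2 stand for the two labels, state 3 is final.
import k2

rules = """0 1 3 0.0
0 2 5 0.0
1 1 3 0.0
1 2 5 0.0
1 3 -1 0.0
2 1 3 0.0
2 2 5 0.0
2 3 -1 0.0
3"""

A = k2.Fsa.from_str(rules)  # same constructor create_bigram_lm ends with
print(A.num_arcs)           # 8: fully-connected bigram arcs plus the two exit arcs

On every forward pass the trainable A_scores parameter is pushed through set_scores_stochastic_, so the learned arc scores act as per-state (log) transition probabilities, which is also how print_transition_probabilities interprets them later in the series.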
diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index b7bf2f8..d98f2df 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -69,14 +69,16 @@ def training_step(self, batch, batch_num): log_probs = torch.nn.functional.log_softmax(logits, dim=2) labels, ilens = batch[3], batch[4] crf_loss = -self.crf(log_probs, ilens, labels) - ce_loss = 0.1 * ce_loss - loss = crf_loss + ce_loss + #ce_loss = 0.1 * ce_loss + loss = crf_loss# + ce_loss + #loss = ce_loss logs = {"loss": loss, 'crf_loss': crf_loss, 'ce_loss': ce_loss} else: + loss = ce_loss logs = {"loss": loss} progdict = logs.copy() progdict.pop('loss') - return {"loss": loss, "log": logs, 'progress_bar': logs} + return {"loss": loss, "log": logs, 'progress_bar': progdict} def validation_step(self, batch, batch_nb): "Compute validation" From c7bd9fb80594b0af3dda15812defb6b4fd286a9f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Wed, 12 May 2021 23:06:38 +0200 Subject: [PATCH 21/38] Working naive CRF denominator version --- daseg/losses/crf.py | 92 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 79 insertions(+), 13 deletions(-) diff --git a/daseg/losses/crf.py b/daseg/losses/crf.py index 5d14d62..4d44843 100644 --- a/daseg/losses/crf.py +++ b/daseg/losses/crf.py @@ -1,5 +1,6 @@ from typing import Dict, List +import numpy as np import k2 import torch from torch import Tensor, nn @@ -16,16 +17,16 @@ def __init__( self, label_set: List[str], label2id: Dict[str, int], - trainable_transition_scores: bool = True, ignore_index: int = -100 ): super().__init__() self.label_set = label_set self.label2id = label2id self.ignore_index = ignore_index - self.den = make_denominator(label_set, label2id, shared=True).to('cuda') + self.den = k2.arc_sort(make_denominator(label_set, label2id, shared=True)) + self.den.requires_grad_(False) self.A = create_bigram_lm([self.label2id[l] for l in label_set]).to('cuda') - self.A_scores = nn.Parameter(self.A.scores.clone(), requires_grad=trainable_transition_scores).to('cuda') + self.A_scores = nn.Parameter(self.A.scores.clone(), requires_grad=True).to('cuda') def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): global it @@ -45,19 +46,21 @@ def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): posteriors = k2.DenseFsaVec(log_probs_scored, supervision_segments) # (fsavec) - nums = make_numerator(labels) - for i in range(nums.shape[0]): - # The supervision has to have exactly the same number of arcs as the number of tokens - # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. - assert nums[i].num_arcs == supervision_segments[i, 2] + 1 self.A.set_scores_stochastic_(self.A_scores) - nums = k2.intersect(self.A.to('cpu'), nums).to(log_probs.device) + nums = make_numerator(labels) + A_cpu = self.A.to('cpu') + nums = k2.intersect(A_cpu, nums).to(log_probs.device) for i in range(nums.shape[0]): # The supervision has to have exactly the same number of arcs as the number of tokens # which contain labels to score, plus one extra arc for k2's special end-of-fst arc. 
assert nums[i].num_arcs == supervision_segments[i, 2] + 1 + # (fsavec) + #den = k2.intersect(A_cpu, self.den).detach().to(log_probs.device) + den = self.den.to('cuda') + if it % 100 == 0: + #print_transition_probabilities(self.A, self.den.symbols, list(self.label2id.values())) for i in range(min(3, labels.size(0))): print('*' * 120) print('log_probs_scored.shape', log_probs_scored.shape) @@ -70,16 +73,16 @@ def forward(self, log_probs: Tensor, input_lens: Tensor, labels: Tensor): # (fsavec) num_lattices = k2.intersect_dense(nums, posteriors, output_beam=10.0) - #den_lattice = k2.intersect_dense(self.den, posteriors, output_beam=10.0) + den_lattices = k2.intersect_dense(den, posteriors, output_beam=10.0) # (batch,) num_scores = num_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) - #den_scores = den_lattice.get_tot_scores(use_double_scores=True, log_semiring=True) + den_scores = den_lattices.get_tot_scores(use_double_scores=True, log_semiring=True) # (scalar) num_tokens = (labels != self.ignore_index).to(torch.int32).sum() - #loss = (num_scores - den_scores).sum() / num_tokens - loss = num_scores.sum() / num_tokens + loss = (num_scores - den_scores).sum() / num_tokens + #loss = num_scores.sum() / num_tokens return loss it = 0 @@ -105,6 +108,69 @@ def create_bigram_lm(labels: List[int]) -> k2.Fsa: return k2.Fsa.from_str(rules) +def print_transition_probabilities(P: k2.Fsa, phone_symbol_table: k2.SymbolTable, + phone_ids: List[int], filename: str = None): + '''Print the transition probabilities of a phone LM. + + Args: + P: + A bigram phone LM. + phone_symbol_table: + The phone symbol table. + phone_ids: + A list of phone ids + filename: + Filename to save the printed result. + ''' + num_phones = len(phone_ids) + table = np.zeros((num_phones + 1, num_phones + 2)) + table[:, 0] = 0 + table[0, -1] = 0 # the start state has no arcs to the final state + #assert P.arcs.dim0() == num_phones + 2 + arcs = P.arcs.values()[:, :3] + probability = P.scores.exp().tolist() + + assert arcs.shape[0] - num_phones == num_phones * (num_phones + 1) + for i, arc in enumerate(arcs.tolist()): + src_state, dest_state, label = arc[0], arc[1], arc[2] + prob = probability[i] + if label != -1: + assert label == dest_state + else: + assert dest_state == num_phones + 1 + table[src_state][dest_state] = prob + + try: + from prettytable import PrettyTable + except ImportError: + print('Please run `pip install prettytable`. Skip printing') + return + + x = PrettyTable() + + field_names = ['source'] + field_names.append('sum') + for i in phone_ids: + field_names.append(phone_symbol_table[i]) + field_names.append('final') + + x.field_names = field_names + + for row in range(num_phones + 1): + this_row = [] + if row == 0: + this_row.append('start') + else: + this_row.append(phone_symbol_table[row]) + this_row.append('{:.6f}'.format(table[row, 1:].sum())) + for col in range(1, num_phones + 2): + this_row.append('{:.6f}'.format(table[row, col])) + x.add_row(this_row) + print(str(x)) + #with open(filename, 'w') as f: + # f.write(str(x)) + + def make_symbol_table(label_set: List[str], label2id: Dict[str, int], shared: bool = True) -> k2.SymbolTable: """ Creates a symbol table given a list of classes (e.g. dialog acts, punctuation, etc.). 
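[Editor's note] As a usage note for the loss in its current form, here is a minimal, hypothetical sketch of how CRFLoss might be driven at this point in the series. A CUDA device is assumed (the class moves its FSAs to 'cuda'), and the label inventory, ids and tensor contents are invented for the example; log_probs has shape (batch, seq_len, num_classes) and labels marks every position the CRF should skip (padding, special tokens and subword continuations in this setup) with -100.

import torch
from daseg.losses.crf import CRFLoss

# Hypothetical joint-coding inventory: 'I-' (id 0) continues a segment and the act
# label closes it; only the act labels go into label_set, mirroring how
# transformer_pl.py constructs the loss.
label2id = {"I-": 0, "Statement": 1, "Question": 2}
crf = CRFLoss(label_set=["Statement", "Question"], label2id=label2id)

batch, seqlen, nclass = 2, 8, len(label2id)
logits = torch.randn(batch, seqlen, nclass, device="cuda")
log_probs = torch.nn.functional.log_softmax(logits, dim=2)

# -100 marks ignored positions; each row keeps the same number of scored tokens
# here, since forward() sizes its buffers from the first row of the batch.
labels = torch.full((batch, seqlen), -100, dtype=torch.long, device="cuda")
labels[:, 2:6] = torch.tensor([0, 0, 1, 2], device="cuda")  # I- I- Statement Question

ilens = torch.tensor([seqlen, seqlen])  # accepted but currently unused by forward()
loss = -crf(log_probs, ilens, labels)   # negated as in training_step: forward returns
                                        # the per-token conditional log-likelihood

The debug block guarded by it % 100 == 0 fires on the very first call, which is handy for checking that the numerator labels line up with the argmax of log_probs_scored.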
From 39d14c97696c6955d5dbfad80184f9f38f47a116 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 5 Jul 2021 11:40:06 -0400 Subject: [PATCH 22/38] Local CRF loss import --- daseg/models/transformer_pl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/daseg/models/transformer_pl.py b/daseg/models/transformer_pl.py index d98f2df..d2be498 100644 --- a/daseg/models/transformer_pl.py +++ b/daseg/models/transformer_pl.py @@ -1,5 +1,4 @@ from pathlib import Path -from pathlib import Path from typing import Dict, List import numpy as np @@ -12,7 +11,6 @@ from daseg.data import NEW_TURN from daseg.dataloaders.transformers import pad_array -from daseg.losses.crf import CRFLoss from daseg.metrics import as_tensors, compute_sklearn_metrics @@ -49,6 +47,7 @@ def __init__( self.model = model_class(self.config) self.model.resize_token_embeddings(len(self.tokenizer)) if crf: + from daseg.losses.crf import CRFLoss self.crf = CRFLoss([l for l in self.labels if l != 'O' and not l.startswith('I-')], self.label2id) else: self.crf = None From 6335ac4ccfe707c55af77e2b40a74bf79b5b6d0f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 8 Oct 2021 20:24:29 +0200 Subject: [PATCH 23/38] A lot of various fixes --- daseg/bin/dasg | 15 ++++++++++++--- daseg/conversion.py | 2 +- daseg/data.py | 2 +- daseg/dataloaders/transformers.py | 14 +++++++++----- daseg/models/transformer_model.py | 20 ++++++++++++++------ 5 files changed, 37 insertions(+), 16 deletions(-) diff --git a/daseg/bin/dasg b/daseg/bin/dasg index 70ddd8c..ed82747 100644 --- a/daseg/bin/dasg +++ b/daseg/bin/dasg @@ -130,12 +130,18 @@ def evaluate( results['args'] = ctx.params with SlackNotifier(' '.join(sys.argv)) as slack: for res_grp in ( - 'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics', 'ORIGINAL_zhao_kawahara_metrics', + 'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics', + 'segeval_metrics', 'args' ): slack.write_and_print(f'{res_grp.upper()}:') for key, val in results[res_grp].items(): - slack.write_and_print(f'{key}\t{val:.2%}') + try: + # TODO: rewrite cleaner later + if isinstance(val, int): raise Exception() + slack.write_and_print(f'{key}\t{val:.2%}') + except: + slack.write_and_print(f'{key}\t{val}') if save_output is not None: with open(save_output, 'wb') as f: pickle.dump(results, f) @@ -388,7 +394,10 @@ def evaluate_bigru( results['args'] = ctx.params with SlackNotifier(' '.join(sys.argv)) as slack: for key, val in results['log'].items(): - slack.write_and_print(f'{key}\t{val:.2%}') + try: + slack.write_and_print(f'{key}\t{val:.2%}') + except: + slack.write_and_print(f'{key}\t{val}') if save_output is not None: with open(save_output, 'wb') as f: pickle.dump(results, f) diff --git a/daseg/conversion.py b/daseg/conversion.py index 0f49a4d..aa7f1a1 100644 --- a/daseg/conversion.py +++ b/daseg/conversion.py @@ -50,7 +50,7 @@ def predictions_to_dataset( ) -> DialogActCorpus: dialogues = {} for (call_id, call), pred_tags in zip(original_dataset.dialogues.items(), predictions): - words, _, speakers = call.words_with_metadata(add_turn_token=True) + words, orig_labels, speakers = call.words_with_metadata(add_turn_token=True) assert len(words) == len(pred_tags), \ f'Mismatched words ({len(words)}) and predicted tags ({len(pred_tags)}) counts for conversation "{call_id}"' diff --git a/daseg/data.py b/daseg/data.py index 60cc782..94ad1c6 100644 --- a/daseg/data.py +++ b/daseg/data.py @@ -211,7 +211,7 @@ def turns(self) -> Iterable['Call']: @property def 
dialog_acts(self) -> List[str]: - return sorted(set(segment.dialog_act for call in self.calls for segment in call)) + return sorted(set(str(segment.dialog_act) for call in self.calls for segment in call)) @property def dialog_act_labels(self) -> List[str]: diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index 4c3a894..6746525 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -118,7 +118,7 @@ def to_dataloader(dataset: Dataset, padding_at_start: bool, batch_size: int, tra batch_size=batch_size, collate_fn=partial(truncate_padding_collate_fn, padding_at_start=padding_at_start), pin_memory=True, - num_workers=4 + num_workers=0 ) @@ -195,13 +195,17 @@ def truncate_padding_collate_fn(batch: List[List[torch.Tensor]], padding_at_star def pad_list_of_arrays(arrays: List[np.ndarray], value: float) -> List[np.ndarray]: - max_out_len = max(x.shape[1] for x in arrays) + if len(arrays[0].shape) > 1: + max_out_len = max(x.shape[1] for x in arrays) + else: + max_out_len = max(x.shape[0] for x in arrays) return [pad_array(t, target_len=max_out_len, value=value) for t in arrays] def pad_array(arr: np.ndarray, target_len: int, value: float): - if arr.shape[1] == target_len: + len_dim = 1 if len(arr.shape) > 1 else 0 + if arr.shape[len_dim] == target_len: return arr pad_shape = list(arr.shape) - pad_shape[1] = target_len - arr.shape[1] - return np.concatenate([arr, np.ones(pad_shape) * value], axis=1) + pad_shape[len_dim] = target_len - arr.shape[len_dim] + return np.concatenate([arr, np.ones(pad_shape) * value], axis=len_dim) diff --git a/daseg/models/transformer_model.py b/daseg/models/transformer_model.py index fa52f66..cdab308 100644 --- a/daseg/models/transformer_model.py +++ b/daseg/models/transformer_model.py @@ -128,10 +128,12 @@ def predict( use_joint_coding=use_joint_coding, use_turns=use_turns ) + elif not isinstance(dataset, torch.utils.data.DataLoader): + dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size) else: dataloader = dataset - eval_ce_losses, logits, out_label_ids = zip(*list(maybe_tqdm( + eval_ce_losses, logits, out_label_ids, input_ids = zip(*list(maybe_tqdm( map( partial( predict_batch_in_windows, @@ -149,18 +151,24 @@ def predict( pad_token_label_id = CrossEntropyLoss().ignore_index out_label_ids = pad_list_of_arrays(out_label_ids, value=pad_token_label_id) logits = pad_list_of_arrays(logits, value=0) - out_label_ids = np.concatenate(out_label_ids, axis=0) + out_label_ids = np.concatenate(out_label_ids, axis=0).astype(np.int32) logits = np.concatenate(logits, axis=0) preds = np.argmax(logits, axis=2) + input_ids = np.concatenate(pad_list_of_arrays(input_ids, value=0), axis=0) label_map = {int(k): v for k, v in self.config.id2label.items()} + label_map[-100] = 'O' + turn_tok_id = self.tokenizer.encode('')[1] + assert out_label_ids.shape == preds.shape, f'{out_label_ids.shape} == {preds.shape}' out_label_list: List[List[str]] = [[] for _ in range(out_label_ids.shape[0])] - preds_list: List[List[str]] = [[] for _ in range(out_label_ids.shape[0])] + preds_list: List[List[str]] = [[] for _ in range(preds.shape[0])] for i in range(out_label_ids.shape[0]): for j in range(out_label_ids.shape[1]): - if out_label_ids[i, j] != pad_token_label_id: + is_ignored = out_label_ids[i, j] == pad_token_label_id + is_turn = input_ids[i, j] == turn_tok_id + if not is_ignored or is_turn: out_label_list[i].append(label_map[out_label_ids[i][j]]) preds_list[i].append(label_map[preds[i][j]]) @@ -231,7 +239,7 @@ def 
predict_batch_in_windows( maxlen = batch[0].shape[1] window_shift = window_len - window_overlap windows = ( - [t[:, i: i + window_len].contiguous().to(device) for t in batch] + [t[:, i: i + window_len].contiguous().to(device) if len(t.shape) > 1 else t for t in batch] for i in range(0, maxlen, window_shift) ) @@ -262,7 +270,7 @@ def predict_batch_in_windows( mems = outputs[2] # workaround for PyTorch file descriptor leaks: # https://github.com/pytorch/pytorch/issues/973 - returns = ce_loss, np.concatenate(logits, axis=1), deepcopy(batch[3].detach().cpu().numpy()) + returns = ce_loss, np.concatenate(logits, axis=1), deepcopy(batch[3].detach().cpu().numpy()), deepcopy(batch[3].detach().cpu().numpy()) for t in batch: del t return returns From 636218eabc91548b1e4ac4e8970a6747e09c1416 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 16 Oct 2021 02:20:43 +0200 Subject: [PATCH 24/38] Handle edge case in boundary similarity computation for single-segment examples --- daseg/metrics.py | 31 ++++++++++++++++++++++--------- 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 11baf5c..899c157 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -60,18 +60,31 @@ def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogA from segeval.data import Dataset from segeval import boundary_similarity, pk - true_segments = Dataset({ - cid: {'ref': [len(fs.text.split()) for fs in call]} + def fix_single_seg_calls(true, pred): + for cid in true.keys(): + true_segs = true[cid]["ref"] + pred_segs = pred[cid]["hyp"] + if len(true_segs) == len(pred_segs) == 1: + true[cid]["ref"] = true_segs + [1] + pred[cid]["hyp"] = pred_segs + [1] + + true_segments = { + cid: {"ref": [len(fs.text.split()) for fs in call]} for cid, call in true_dataset.dialogues.items() - }) - pred_segments = Dataset({ - cid: {'hyp': [len(fs.text.split()) for fs in call]} + } + pred_segments = { + cid: {"hyp": [len(fs.text.split()) for fs in call]} for cid, call in pred_dataset.dialogues.items() - }) - + } + + fix_single_seg_calls(true_segments, pred_segments) + + pred_segments = Dataset(pred_segments) + true_segments = Dataset(true_segments) + return { - 'pk': float(mean(pk(true_segments, pred_segments).values())), - 'B': float(mean(boundary_similarity(true_segments, pred_segments).values())) + "pk": float(mean(pk(true_segments, pred_segments).values())), + "B": float(mean(boundary_similarity(true_segments, pred_segments).values())), } From 85cf4a027b896e3a47dd7e127e3149b5c837a98d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 18:28:09 -0400 Subject: [PATCH 25/38] Option to create partial segments for training --- daseg/bin/dasg | 5 +- daseg/data.py | 79 ++++++++++++++++++++++++++++++- daseg/dataloaders/transformers.py | 71 +++++++++++++++++++++++---- 3 files changed, 144 insertions(+), 11 deletions(-) diff --git a/daseg/bin/dasg b/daseg/bin/dasg index ed82747..be29e3b 100644 --- a/daseg/bin/dasg +++ b/daseg/bin/dasg @@ -158,6 +158,7 @@ def evaluate( @click.option('-l', '--max-sequence-length', default=4096, type=int) @click.option('-n', '--turns', is_flag=True) @click.option('-w', '--windows-if-exceeds-max-len', is_flag=True) +@click.option('-a', '--allow-partial-segments', is_flag=True) def prepare_exp( output_dir: Path, model_name_or_path: str, @@ -168,6 +169,7 @@ def prepare_exp( max_sequence_length: int, turns: bool, windows_if_exceeds_max_len: bool, + allow_partial_segments: bool, ): output_path = 
Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) @@ -189,7 +191,8 @@ def prepare_exp( labels=model.labels, max_seq_length=max_sequence_length, use_turns=turns, - windows_if_exceeds_max_length=windows_if_exceeds_max_len + windows_if_exceeds_max_length=windows_if_exceeds_max_len, + allow_partial_segments=allow_partial_segments, ) for key, split_corpus in corpus.train_dev_test_split().items() } diff --git a/daseg/data.py b/daseg/data.py index 94ad1c6..6c0f3ff 100644 --- a/daseg/data.py +++ b/daseg/data.py @@ -7,8 +7,9 @@ from functools import partial from itertools import chain, groupby from pathlib import Path -from typing import Callable, Dict, FrozenSet, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple +from typing import Callable, Dict, FrozenSet, Iterable, List, Literal, Mapping, NamedTuple, Optional, Set, Tuple +import numpy as np from cytoolz.itertoolz import sliding_window from more_itertools import flatten from spacy import displacy @@ -475,6 +476,23 @@ def encode( return encoded_call + def cut_segments_in_windows(self, window_size_in_tokens: int, tokenizer) -> "Call": + """ + Return a copy of ``self`` where functional segments longer than ``window_size_in_tokens`` + are split into partial segments. In order to know the token count we will need a tokenizer + argument that has a method ``.tokenize(text: str)``. + """ + return Call( + list( + chain.from_iterable( + fs.to_windows( + window_size_in_tokens=window_size_in_tokens, tokenizer=tokenizer + ) + for fs in self + ) + ) + ) + def prepare_call_windows( call: Call, @@ -508,6 +526,9 @@ class FunctionalSegment(NamedTuple): is_continuation: bool = False start: Optional[float] = None end: Optional[float] = None + completeness: Literal[ + "complete", "left-truncated", "right-truncated", "both-truncated" + ] = "complete" @property def num_words(self) -> int: @@ -521,6 +542,62 @@ def with_vocabulary(self, vocabulary: Set[str]) -> 'FunctionalSegment': new_text = ' '.join(w if w in vocabulary else OOV for w in self.text.split()) return FunctionalSegment(new_text, *self[1:]) + def to_windows( + self, window_size_in_tokens: int, tokenizer + ) -> List["FunctionalSegment"]: + words = self.text.split() + segment_tokens_per_word = [len(tokenizer.tokenize(w)) for w in words] + n_segment_tokens = sum(segment_tokens_per_word) + if n_segment_tokens <= window_size_in_tokens: + return [self] + + partial_segments = [] + is_first = True + is_last = False + while words: + if n_segment_tokens < window_size_in_tokens: + text = " ".join(words) + words = [] + is_last = True + else: + n_partial_words = find_nearest( + segment_tokens_per_word, window_size_in_tokens + ) + text = " ".join(words[:n_partial_words]) + words = words[n_partial_words:] + partial_segments.append( + FunctionalSegment( + text=text, + dialog_act=self.dialog_act, + speaker=self.speaker, + is_continuation=self.is_continuation, + start=self.start, + end=self.end, + completeness=( + "right-truncated" + if is_first + else "left-truncated" + if is_last + else "both-truncated" + ), + ) + ) + is_first = False + + return partial_segments + + +def find_nearest(array: List[int], value: int): + """ + Find the index of the closest element to ``value`` in cumulative sum of ``array`` + that is not greater than ``value``. 
+ """ + array = np.cumsum(array) + diff = array - value + diff = np.where(diff <= 0, diff, np.inf) + idx = diff.argmin() + return idx + class EncodedSegment(NamedTuple): words: List[str] diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index 6746525..42b2a48 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -1,19 +1,20 @@ import warnings from functools import partial from itertools import chain -from typing import Iterable, Optional, List +from typing import Iterable, List, Optional import numpy as np import torch from torch.nn import CrossEntropyLoss -from torch.utils.data import DataLoader, TensorDataset, SequentialSampler, RandomSampler, Dataset +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, TensorDataset from transformers import PreTrainedTokenizer -from daseg import DialogActCorpus, Call +from daseg import Call, DialogActCorpus from daseg.utils_ner import InputExample, convert_examples_to_features -def as_windows(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_joint_coding: bool) -> Iterable[Call]: +def as_windows_old(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_joint_coding: bool) -> Iterable[ + Call]: if not use_joint_coding: warnings.warn('Call windows are not available when joint coding is turned off. Some calls will be truncated.') return [call] @@ -33,6 +34,52 @@ def as_windows(call: Call, max_length: int, tokenizer: PreTrainedTokenizer, use_ yield Call(window) +def as_windows( + call: Call, + max_length: int, + tokenizer: PreTrainedTokenizer, + use_joint_coding: bool, + allow_partial_segments: bool = False, +) -> Iterable[Call]: + if not use_joint_coding: + warnings.warn( + "Call windows are not available when joint coding is turned off. Some calls will be truncated." + ) + return [call] + window = [] + cur_len = 0 + for segment in call: + words = segment.text.split() + segment_tokens_per_word = [len(tokenizer.tokenize(w)) for w in words] + n_segment_tokens = sum(segment_tokens_per_word) + if cur_len + n_segment_tokens > max_length: + if not window and not allow_partial_segments: + raise ValueError( + "Max sequence length is too low - a segment longer than this value was found." 
+ ) + if (n_partial_tokens := max_length - cur_len) > 0: + n_partial_words = find_nearest(segment_tokens_per_word, n_partial_tokens) + partial_segment = FunctionalSegment( + text=' '.join(words[:n_partial_words]), + dialog_act=segment.dialog_act, + speaker=segment.speaker, + is_continuation=segment.is_continuation, + start=segment.start, + end=segment.end, + completeness="right-truncated" + ) + window.append(partial_segment) + n_segment_tokens -= np.cumsum(segment_tokens_per_word)[n_partial_words] + else: + n_segment_tokens = 0 + yield Call(window) + window = [] + cur_len = n_segment_tokens + window.append(segment) + if window: + yield Call(window) + + def to_dataset( corpus: DialogActCorpus, tokenizer: PreTrainedTokenizer, @@ -41,7 +88,8 @@ def to_dataset( max_seq_length: Optional[int] = None, windows_if_exceeds_max_length: bool = False, use_joint_coding: bool = True, - use_turns: bool = False + use_turns: bool = False, + allow_partial_segments: bool = False, ) -> TensorDataset: ner_examples = [] for idx, call in enumerate(corpus.calls): @@ -55,7 +103,8 @@ def to_dataset( call=call, max_length=max_seq_length, tokenizer=tokenizer, - use_joint_coding=use_joint_coding + use_joint_coding=use_joint_coding, + allow_partial_segments=allow_partial_segments, ) else: call_parts = [call] @@ -132,6 +181,7 @@ def to_transformers_train_dataloader( use_joint_coding: bool = True, use_turns: bool = False, windows_if_exceeds_max_length: bool = False, + allow_partial_segments: bool = False, ) -> DataLoader: dataset = to_dataset( corpus=corpus, @@ -141,7 +191,8 @@ def to_transformers_train_dataloader( max_seq_length=max_seq_length, use_joint_coding=use_joint_coding, use_turns=use_turns, - windows_if_exceeds_max_length=windows_if_exceeds_max_length + windows_if_exceeds_max_length=windows_if_exceeds_max_length, + allow_partial_segments=allow_partial_segments, ) return to_dataloader(dataset, batch_size=batch_size, train=True, padding_at_start=model_type == 'xlnet') @@ -154,7 +205,8 @@ def to_transformers_eval_dataloader( labels: Iterable[str], max_seq_length: Optional[int] = None, use_joint_coding: bool = True, - use_turns: bool = False + use_turns: bool = False, + allow_partial_segments: bool = False, ) -> DataLoader: """ Convert the DA dataset into a PyTorch DataLoader for inference. 
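[Editor's note] One reading note on the splitting helper used here: find_nearest operates on the cumulative sum of per-word subword counts, and because that sum is non-decreasing, the masked argmin in its body always lands on index 0 rather than on the largest prefix that still fits the budget, which is what its docstring describes. The sketch below (toy, hypothetical token counts) shows the docstring's reading, using argmax over the masked differences instead.

import numpy as np

def largest_prefix_index(tokens_per_word, budget):
    # Index of the last word whose cumulative subword count does not exceed the
    # budget: the non-positive difference closest to zero, hence argmax.
    csum = np.cumsum(tokens_per_word)
    diff = np.where(csum - budget <= 0, csum - budget, -np.inf)
    return int(diff.argmax())

tokens_per_word = [1, 3, 2, 2, 4]   # e.g. len(tokenizer.tokenize(w)) per word
budget = 7
print(largest_prefix_index(tokens_per_word, budget))  # -> 2 (cumulative count 6 <= 7)

With the argmin version, n_partial_words evaluates to 0, so the `if n_partial_words > 0:` guard introduced in a later patch ends up skipping the partial segment entirely.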
@@ -173,7 +225,8 @@ def to_transformers_eval_dataloader( labels=labels, max_seq_length=max_seq_length, use_joint_coding=use_joint_coding, - use_turns=use_turns + use_turns=use_turns, + allow_partial_segments=allow_partial_segments, ) return to_dataloader(dataset, batch_size=batch_size, train=False, padding_at_start=model_type == 'xlnet') From 6e978de39ee8fec41f51def0fd23f3af3e58f746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 18:39:44 -0400 Subject: [PATCH 26/38] Handle partial segments in call encoding --- daseg/data.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/daseg/data.py b/daseg/data.py index 6c0f3ff..f3378ce 100644 --- a/daseg/data.py +++ b/daseg/data.py @@ -444,11 +444,15 @@ def encode( for (_, prv), (idx, cur), (_, nxt) in segment_windows: if use_joint_coding: enc = [(word, CONTINUE_TAG) for word in cur.text.split()] - if not (continuations_allowed and nxt is not None and nxt.is_continuation): + there_is_a_continuation = continuations_allowed and nxt is not None and nxt.is_continuation + segment_is_finished = cur.completeness in ['complete', 'left-truncated'] + if not there_is_a_continuation and segment_is_finished: enc[-1] = (enc[-1][0], cur.dialog_act) else: enc = [(word, f'{CONTINUE_TAG}{cur.dialog_act}') for word in cur.text.split()] - if not (continuations_allowed and cur.is_continuation): + cur_is_continuation = continuations_allowed and cur.is_continuation + segment_is_starting = cur.completeness in ['complete', 'right-truncated'] + if not cur_is_continuation and segment_is_starting: enc[0] = (enc[0][0], f'{BEGIN_TAG}{cur.dialog_act}') words, acts = zip(*enc) encoded_segment = EncodedSegment( From b9d321d79a6ca6262caaa158e0a7c57323fdc239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 18:50:19 -0400 Subject: [PATCH 27/38] Add boundary statistics to the tracked measures --- daseg/metrics.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 899c157..d7caeb6 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -58,6 +58,7 @@ def compute_seqeval_metrics(true_labels: List[List[str]], predictions: List[List def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): from statistics import mean from segeval.data import Dataset + from segeval.similarity import boundary_statistics from segeval import boundary_similarity, pk def fix_single_seg_calls(true, pred): @@ -76,16 +77,19 @@ def fix_single_seg_calls(true, pred): cid: {"hyp": [len(fs.text.split()) for fs in call]} for cid, call in pred_dataset.dialogues.items() } - + fix_single_seg_calls(true_segments, pred_segments) - + pred_segments = Dataset(pred_segments) true_segments = Dataset(true_segments) - - return { + + result = boundary_statistics(true_segments, pred_segments) + + result.update({ "pk": float(mean(pk(true_segments, pred_segments).values())), "B": float(mean(boundary_similarity(true_segments, pred_segments).values())), - } + }) + return result def compute_zhao_kawahara_metrics_levenshtein(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): From b957662264e0a85fe6ca5c3a0439fec96b0cdc1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 18:52:29 -0400 Subject: [PATCH 28/38] Add missing imports --- daseg/dataloaders/transformers.py | 3 ++- daseg/metrics.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/daseg/dataloaders/transformers.py 
b/daseg/dataloaders/transformers.py index 42b2a48..c738078 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -9,7 +9,8 @@ from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, TensorDataset from transformers import PreTrainedTokenizer -from daseg import Call, DialogActCorpus +from daseg import Call, DialogActCorpus, FunctionalSegment +from daseg.data import find_nearest from daseg.utils_ner import InputExample, convert_examples_to_features diff --git a/daseg/metrics.py b/daseg/metrics.py index d7caeb6..cb14a31 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -84,7 +84,7 @@ def fix_single_seg_calls(true, pred): true_segments = Dataset(true_segments) result = boundary_statistics(true_segments, pred_segments) - + result.update({ "pk": float(mean(pk(true_segments, pred_segments).values())), "B": float(mean(boundary_similarity(true_segments, pred_segments).values())), From 276b2b13419345d3b60869df50aef808a4bf0cc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 18:58:41 -0400 Subject: [PATCH 29/38] Repalce boundary statistics with boundary edit matrix --- daseg/metrics.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index cb14a31..214731c 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -58,7 +58,7 @@ def compute_seqeval_metrics(true_labels: List[List[str]], predictions: List[List def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): from statistics import mean from segeval.data import Dataset - from segeval.similarity import boundary_statistics + from segeval.similarity import boundary_confusion_matrix from segeval import boundary_similarity, pk def fix_single_seg_calls(true, pred): @@ -83,13 +83,11 @@ def fix_single_seg_calls(true, pred): pred_segments = Dataset(pred_segments) true_segments = Dataset(true_segments) - result = boundary_statistics(true_segments, pred_segments) - - result.update({ + return { "pk": float(mean(pk(true_segments, pred_segments).values())), "B": float(mean(boundary_similarity(true_segments, pred_segments).values())), - }) - return result + "CM": boundary_confusion_matrix(true_segments, pred_segments), + } def compute_zhao_kawahara_metrics_levenshtein(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): From 63987b24a861e161c2a814dd7de67d24673703b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:04:51 -0400 Subject: [PATCH 30/38] B at different tolerance thresholds + summarization --- daseg/metrics.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 214731c..bd67873 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -9,6 +9,7 @@ import torch from Bio import pairwise2 from more_itertools import flatten +from segeval.compute import summarize from daseg import Call, DialogActCorpus from daseg.data import CONTINUE_TAG @@ -85,8 +86,10 @@ def fix_single_seg_calls(true, pred): return { "pk": float(mean(pk(true_segments, pred_segments).values())), - "B": float(mean(boundary_similarity(true_segments, pred_segments).values())), - "CM": boundary_confusion_matrix(true_segments, pred_segments), + "B(tol=2)": summarize(boundary_similarity(true_segments, pred_segments)), + "B(tol=5)": summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), + "B(tol=10)": summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), + 
"CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), } From bc9e5d3a058bf13f897f6c72ee0174890cb92f64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:05:50 -0400 Subject: [PATCH 31/38] Fix --- daseg/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index bd67873..b0c09a2 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -9,7 +9,6 @@ import torch from Bio import pairwise2 from more_itertools import flatten -from segeval.compute import summarize from daseg import Call, DialogActCorpus from daseg.data import CONTINUE_TAG @@ -58,9 +57,10 @@ def compute_seqeval_metrics(true_labels: List[List[str]], predictions: List[List def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): from statistics import mean + from segeval import boundary_similarity, pk from segeval.data import Dataset from segeval.similarity import boundary_confusion_matrix - from segeval import boundary_similarity, pk + from segeval.compute import summarize def fix_single_seg_calls(true, pred): for cid in true.keys(): From dc89eaf932f1621aaf3ebde9eef1bcb06f717c79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:06:46 -0400 Subject: [PATCH 32/38] Remove CM for now --- daseg/metrics.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index b0c09a2..0ab4d0b 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -59,7 +59,6 @@ def compute_segeval_metrics(true_dataset: DialogActCorpus, pred_dataset: DialogA from statistics import mean from segeval import boundary_similarity, pk from segeval.data import Dataset - from segeval.similarity import boundary_confusion_matrix from segeval.compute import summarize def fix_single_seg_calls(true, pred): @@ -89,7 +88,7 @@ def fix_single_seg_calls(true, pred): "B(tol=2)": summarize(boundary_similarity(true_segments, pred_segments)), "B(tol=5)": summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), "B(tol=10)": summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), - "CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), + # "CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), } From 7a87ccea8f2ad2bbcc203428a276cb5d6afc58c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:10:40 -0400 Subject: [PATCH 33/38] Readability fix --- daseg/metrics.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 0ab4d0b..33cbf1f 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -83,11 +83,17 @@ def fix_single_seg_calls(true, pred): pred_segments = Dataset(pred_segments) true_segments = Dataset(true_segments) + B2_mean, B2_std, *_ = summarize(boundary_similarity(true_segments, pred_segments)), + B5_mean, B5_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), + B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), return { "pk": float(mean(pk(true_segments, pred_segments).values())), - "B(tol=2)": summarize(boundary_similarity(true_segments, pred_segments)), - "B(tol=5)": summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), - "B(tol=10)": summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), + "B2": B2_mean, + "B2𝛔": B2_std, + "B5": B5_mean, + "B5𝛔": B5_std, + "B10": B10_mean, + 
"B10𝛔": B10_std, # "CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), } From 1ece9919a972124a71362e39c16323d8d64d6bad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:11:14 -0400 Subject: [PATCH 34/38] Readability fix --- daseg/metrics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 33cbf1f..9edf18c 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -88,12 +88,12 @@ def fix_single_seg_calls(true, pred): B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), return { "pk": float(mean(pk(true_segments, pred_segments).values())), - "B2": B2_mean, - "B2𝛔": B2_std, - "B5": B5_mean, - "B5𝛔": B5_std, - "B10": B10_mean, - "B10𝛔": B10_std, + "B2": float(B2_mean), + "B2𝛔": float(B2_std), + "B5": float(B5_mean), + "B5𝛔": float(B5_std), + "B10": float(B10_mean), + "B10𝛔": float(B10_std), # "CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), } From 1c5b267343966b983cd5cbc49f93b32480001ace Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:25:09 -0400 Subject: [PATCH 35/38] moar metrixx --- daseg/metrics.py | 27 ++++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index 9edf18c..eb4f860 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -9,6 +9,7 @@ import torch from Bio import pairwise2 from more_itertools import flatten +from segeval.similarity import boundary_statistics from daseg import Call, DialogActCorpus from daseg.data import CONTINUE_TAG @@ -83,10 +84,30 @@ def fix_single_seg_calls(true, pred): pred_segments = Dataset(pred_segments) true_segments = Dataset(true_segments) + summary = { + "correct": 0, + "almost_correct": 0, + "missed_boundaries": 0, + "false_boundaries": 0, + } + for stat in boundary_statistics(true_segments, pred_segments).values(): + summary["correct"] += len(stat["matches"]) + summary["almost_correct"] += len(stat["transpositions"]) + summary["missed_boundaries"] += len(stat["full_misses"]) + summary["false_boundaries"] += len(stat["additions"]) + OK = summary["correct"] + summary["almost_correct"] + summary["boundary_precision"] = OK / (OK + summary["false_boundaries"]) + summary["boundary_recall"] = OK / (OK + summary["missed_boundaries"]) + summary["boundary_f1"] = ( + summary["boundary_precision"] + * summary["boundary_recall"] + / (summary["boundary_precision"] + summary["boundary_recall"]) + ) + B2_mean, B2_std, *_ = summarize(boundary_similarity(true_segments, pred_segments)), B5_mean, B5_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), - return { + summary.update({ "pk": float(mean(pk(true_segments, pred_segments).values())), "B2": float(B2_mean), "B2𝛔": float(B2_std), @@ -94,8 +115,8 @@ def fix_single_seg_calls(true, pred): "B5𝛔": float(B5_std), "B10": float(B10_mean), "B10𝛔": float(B10_std), - # "CM": summarize(boundary_confusion_matrix(true_segments, pred_segments)), - } + }) + return summary def compute_zhao_kawahara_metrics_levenshtein(true_dataset: DialogActCorpus, pred_dataset: DialogActCorpus): From cf9430e29f19971d8aafd21d4e6c02abac652c24 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:25:49 -0400 Subject: [PATCH 36/38] fix --- daseg/metrics.py | 6 +++--- 1 file changed, 3 
insertions(+), 3 deletions(-) diff --git a/daseg/metrics.py b/daseg/metrics.py index eb4f860..3f3c105 100644 --- a/daseg/metrics.py +++ b/daseg/metrics.py @@ -104,9 +104,9 @@ def fix_single_seg_calls(true, pred): / (summary["boundary_precision"] + summary["boundary_recall"]) ) - B2_mean, B2_std, *_ = summarize(boundary_similarity(true_segments, pred_segments)), - B5_mean, B5_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=5)), - B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=10)), + B2_mean, B2_std, *_ = summarize(boundary_similarity(true_segments, pred_segments)) + B5_mean, B5_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=5)) + B10_mean, B10_std, *_ = summarize(boundary_similarity(true_segments, pred_segments, n_t=10)) summary.update({ "pk": float(mean(pk(true_segments, pred_segments).values())), "B2": float(B2_mean), From ac1776e30fe9b95bfb6f18f253a194d91e0ccdce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 29 Oct 2021 19:44:05 -0400 Subject: [PATCH 37/38] fix --- daseg/dataloaders/transformers.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/daseg/dataloaders/transformers.py b/daseg/dataloaders/transformers.py index c738078..6446c15 100644 --- a/daseg/dataloaders/transformers.py +++ b/daseg/dataloaders/transformers.py @@ -60,17 +60,18 @@ def as_windows( ) if (n_partial_tokens := max_length - cur_len) > 0: n_partial_words = find_nearest(segment_tokens_per_word, n_partial_tokens) - partial_segment = FunctionalSegment( - text=' '.join(words[:n_partial_words]), - dialog_act=segment.dialog_act, - speaker=segment.speaker, - is_continuation=segment.is_continuation, - start=segment.start, - end=segment.end, - completeness="right-truncated" - ) - window.append(partial_segment) - n_segment_tokens -= np.cumsum(segment_tokens_per_word)[n_partial_words] + if n_partial_words > 0: + partial_segment = FunctionalSegment( + text=' '.join(words[:n_partial_words]), + dialog_act=segment.dialog_act, + speaker=segment.speaker, + is_continuation=segment.is_continuation, + start=segment.start, + end=segment.end, + completeness="right-truncated" + ) + window.append(partial_segment) + n_segment_tokens -= np.cumsum(segment_tokens_per_word)[n_partial_words] else: n_segment_tokens = 0 yield Call(window) From c140ee745dac61b2d0076eb1e8afd667530376d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 4 Nov 2021 18:21:36 +0100 Subject: [PATCH 38/38] update gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index a6a7d7a..918e59d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +__pycache__/ +**/__pycache__ deps/ .idea/ *.txt
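[Editor's note] A closing observation on the boundary metrics added in [PATCH 35/38]: the boundary_f1 expression there evaluates precision * recall / (precision + recall), which is half of the usual harmonic mean. A short sketch of the conventional definition, with made-up counts, for comparison:

# Conventional F1 is the harmonic mean of precision and recall; the expression in
# the metrics patch omits the leading factor of 2. Counts are invented for illustration.
correct, almost_correct = 80, 10
missed_boundaries, false_boundaries = 15, 5

ok = correct + almost_correct
precision = ok / (ok + false_boundaries)   # 90 / 95
recall = ok / (ok + missed_boundaries)     # 90 / 105
f1 = 2 * precision * recall / (precision + recall)
print(round(precision, 3), round(recall, 3), round(f1, 3))  # 0.947 0.857 0.9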