Commits
40 commits
35aefd5
Add punctuation data prep code (raw)
pzelasko Apr 16, 2021
52239db
Add segeval metrics
pzelasko Apr 16, 2021
ddd6d24
Initial GPU CRF loss implementation
pzelasko May 7, 2021
171a56e
Add documentation
pzelasko May 7, 2021
76978f8
Add option for trainable/frozen transition scores
pzelasko May 7, 2021
82a5395
Fix arc scores
pzelasko May 9, 2021
8cc0fd0
Fix from_str usage
pzelasko May 9, 2021
44ba0c8
Fix supervision segments and linear fsa
pzelasko May 9, 2021
fbdafa0
First attempt at CRF training
pzelasko May 9, 2021
e92e92d
Merge remote-tracking branch 'origin/punctuation' into punctuation
pzelasko May 9, 2021
a92232d
try modelling O as <eps>
pzelasko May 9, 2021
a812033
try modelling O as <eps>
pzelasko May 9, 2021
d1db50b
try modelling O as <eps>
pzelasko May 9, 2021
7565e6e
Add extra inputs to CRF loss
pzelasko May 11, 2021
ee553a1
Fix
pzelasko May 11, 2021
4b79a51
all subwords get the same label token
pzelasko May 11, 2021
441226c
Various fixes
pzelasko May 12, 2021
4181024
Exclude O (blank) symbol from loss computation
pzelasko May 12, 2021
2e3d24e
Fix supervision segments offset and remove the need for input lens
pzelasko May 12, 2021
9f3a7b5
various fixes to CRF
pzelasko May 12, 2021
2c86086
Working CRF numerator with subword token skipping
pzelasko May 12, 2021
c7bd9fb
Working naive CRF denominator version
pzelasko May 12, 2021
39d14c9
Local CRF loss import
pzelasko Jul 5, 2021
6335ac4
A lot of various fixes
pzelasko Oct 8, 2021
5257081
Merge branch 'punctuation' of https://github.com/pzelasko/daseg into …
pzelasko Oct 8, 2021
636218e
Handle edge case in boundary similarity computation for single-segmen…
pzelasko Oct 16, 2021
85cf4a0
Option to create partial segments for training
pzelasko Oct 29, 2021
6e978de
Handle partial segments in call encoding
pzelasko Oct 29, 2021
b9d321d
Add boundary statistics to the tracked measures
pzelasko Oct 29, 2021
b957662
Add missing imports
pzelasko Oct 29, 2021
276b2b1
Replace boundary statistics with boundary edit matrix
pzelasko Oct 29, 2021
63987b2
B at different tolerance thresholds + summarization
pzelasko Oct 29, 2021
bc9e5d3
Fix
pzelasko Oct 29, 2021
dc89eaf
Remove CM for now
pzelasko Oct 29, 2021
7a87cce
Readability fix
pzelasko Oct 29, 2021
1ece991
Readability fix
pzelasko Oct 29, 2021
1c5b267
moar metrixx
pzelasko Oct 29, 2021
cf9430e
fix
pzelasko Oct 29, 2021
ac1776e
fix
pzelasko Oct 29, 2021
c140ee7
update gitignore
pzelasko Nov 4, 2021
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,3 +1,5 @@
__pycache__/
**/__pycache__
deps/
.idea/
*.txt
26 changes: 20 additions & 6 deletions daseg/bin/dasg
@@ -130,12 +130,18 @@ def evaluate(
results['args'] = ctx.params
with SlackNotifier(' '.join(sys.argv)) as slack:
for res_grp in (
'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics', 'ORIGINAL_zhao_kawahara_metrics',
'sklearn_metrics', 'seqeval_metrics', 'zhao_kawahara_metrics',
'segeval_metrics',
'args'
):
slack.write_and_print(f'{res_grp.upper()}:')
for key, val in results[res_grp].items():
slack.write_and_print(f'{key}\t{val:.2%}')
try:
# TODO: rewrite cleaner later
if isinstance(val, int): raise Exception()
slack.write_and_print(f'{key}\t{val:.2%}')
except:
slack.write_and_print(f'{key}\t{val}')
if save_output is not None:
with open(save_output, 'wb') as f:
pickle.dump(results, f)
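
The TODO above marks the try/except formatting as a stopgap; a minimal alternative sketch (a hypothetical helper, not part of this PR) dispatches on the value's type instead of catching exceptions:

def format_metric(val) -> str:
    # Percent-format floats; print ints and anything else verbatim.
    return f'{val:.2%}' if isinstance(val, float) else str(val)

# At the call site this would replace the try/except block:
# slack.write_and_print(f'{key}\t{format_metric(val)}')
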
@@ -152,6 +158,7 @@ def evaluate(
@click.option('-l', '--max-sequence-length', default=4096, type=int)
@click.option('-n', '--turns', is_flag=True)
@click.option('-w', '--windows-if-exceeds-max-len', is_flag=True)
@click.option('-a', '--allow-partial-segments', is_flag=True)
def prepare_exp(
output_dir: Path,
model_name_or_path: str,
Expand All @@ -162,6 +169,7 @@ def prepare_exp(
max_sequence_length: int,
turns: bool,
windows_if_exceeds_max_len: bool,
allow_partial_segments: bool,
):
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
@@ -183,7 +191,8 @@ def prepare_exp(
labels=model.labels,
max_seq_length=max_sequence_length,
use_turns=turns,
windows_if_exceeds_max_length=windows_if_exceeds_max_len
windows_if_exceeds_max_length=windows_if_exceeds_max_len,
allow_partial_segments=allow_partial_segments,
)
for key, split_corpus in corpus.train_dev_test_split().items()
}
@@ -203,6 +212,7 @@ def prepare_exp(
@click.option('-r', '--random-seed', default=1050, type=int)
@click.option('-g', '--num-gpus', default=0, type=int)
@click.option('-f', '--fp16', is_flag=True)
@click.option('--crf/--no-crf', default=False)
def train_transformer(
exp_dir: Path,
model_name_or_path: str,
@@ -212,15 +222,16 @@ def train_transformer(
gradient_accumulation_steps: int,
random_seed: int,
num_gpus: int,
fp16: bool
fp16: bool,
crf: bool
):
pl.seed_everything(random_seed)
output_path = Path(exp_dir)
with open(output_path / 'dataset.pkl', 'rb') as f:
datasets: Dict[str, Dataset] = pickle.load(f)
with open(output_path / 'labels.pkl', 'rb') as f:
labels: List[str] = pickle.load(f)
model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path)
model = DialogActTransformer(labels=labels, model_name_or_path=model_name_or_path, crf=crf)
loaders = {
key: to_dataloader(
dset,
@@ -386,7 +397,10 @@ def evaluate_bigru(
results['args'] = ctx.params
with SlackNotifier(' '.join(sys.argv)) as slack:
for key, val in results['log'].items():
slack.write_and_print(f'{key}\t{val:.2%}')
try:
slack.write_and_print(f'{key}\t{val:.2%}')
except:
slack.write_and_print(f'{key}\t{val}')
if save_output is not None:
with open(save_output, 'wb') as f:
pickle.dump(results, f)
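
The train_transformer changes above add a --crf/--no-crf flag and pass it through to DialogActTransformer; a minimal usage sketch of the constructor (the import path, model name, and label set below are illustrative assumptions, not taken from this diff):

from daseg import DialogActTransformer  # import path assumed

labels = ['<blank>', 'Statement', 'Question']  # illustrative label set
model = DialogActTransformer(labels=labels, model_name_or_path='bert-base-uncased', crf=True)
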
2 changes: 1 addition & 1 deletion daseg/conversion.py
@@ -50,7 +50,7 @@ def predictions_to_dataset(
) -> DialogActCorpus:
dialogues = {}
for (call_id, call), pred_tags in zip(original_dataset.dialogues.items(), predictions):
words, _, speakers = call.words_with_metadata(add_turn_token=True)
words, orig_labels, speakers = call.words_with_metadata(add_turn_token=True)
assert len(words) == len(pred_tags), \
f'Mismatched words ({len(words)}) and predicted tags ({len(pred_tags)}) counts for conversation "{call_id}"'

91 changes: 86 additions & 5 deletions daseg/data.py
@@ -7,8 +7,9 @@
from functools import partial
from itertools import chain, groupby
from pathlib import Path
from typing import Callable, Dict, FrozenSet, Iterable, List, Mapping, NamedTuple, Optional, Set, Tuple
from typing import Callable, Dict, FrozenSet, Iterable, List, Literal, Mapping, NamedTuple, Optional, Set, Tuple

import numpy as np
from cytoolz.itertoolz import sliding_window
from more_itertools import flatten
from spacy import displacy
@@ -211,7 +212,7 @@ def turns(self) -> Iterable['Call']:

@property
def dialog_acts(self) -> List[str]:
return sorted(set(segment.dialog_act for call in self.calls for segment in call))
return sorted(set(str(segment.dialog_act) for call in self.calls for segment in call))

@property
def dialog_act_labels(self) -> List[str]:
@@ -234,7 +235,7 @@ def joint_coding_dialog_act_label_frequencies(self):

@property
def joint_coding_dialog_act_labels(self) -> List[str]:
return list(chain([BLANK, CONTINUE_TAG], self.dialog_acts))
return list(chain([CONTINUE_TAG], self.dialog_acts))

@property
def vocabulary(self) -> Counter:
@@ -443,11 +444,15 @@ def encode(
for (_, prv), (idx, cur), (_, nxt) in segment_windows:
if use_joint_coding:
enc = [(word, CONTINUE_TAG) for word in cur.text.split()]
if not (continuations_allowed and nxt is not None and nxt.is_continuation):
there_is_a_continuation = continuations_allowed and nxt is not None and nxt.is_continuation
segment_is_finished = cur.completeness in ['complete', 'left-truncated']
if not there_is_a_continuation and segment_is_finished:
enc[-1] = (enc[-1][0], cur.dialog_act)
else:
enc = [(word, f'{CONTINUE_TAG}{cur.dialog_act}') for word in cur.text.split()]
if not (continuations_allowed and cur.is_continuation):
cur_is_continuation = continuations_allowed and cur.is_continuation
segment_is_starting = cur.completeness in ['complete', 'right-truncated']
if not cur_is_continuation and segment_is_starting:
enc[0] = (enc[0][0], f'{BEGIN_TAG}{cur.dialog_act}')
words, acts = zip(*enc)
encoded_segment = EncodedSegment(
@@ -475,6 +480,23 @@ def encode(

return encoded_call
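
To make the completeness check in the joint-coding branch above concrete, a small illustration (the words and dialog act are hypothetical; the tag constants are the ones used in this module):

# A complete segment "okay I see" with dialog act "Statement" encodes as
#   [("okay", CONTINUE_TAG), ("I", CONTINUE_TAG), ("see", "Statement")]
# If the same segment were right-truncated (it continues in the next window),
# segment_is_finished would be False and every word would keep CONTINUE_TAG.
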

def cut_segments_in_windows(self, window_size_in_tokens: int, tokenizer) -> "Call":
"""
Return a copy of ``self`` where functional segments longer than ``window_size_in_tokens``
are split into partial segments. Counting tokens requires a ``tokenizer``
argument that exposes a ``.tokenize(text: str)`` method.
"""
return Call(
list(
chain.from_iterable(
fs.to_windows(
window_size_in_tokens=window_size_in_tokens, tokenizer=tokenizer
)
for fs in self
)
)
)
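
A minimal usage sketch for the new method (the tokenizer choice and window size are illustrative; any object exposing a .tokenize(text) method works):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# `call` is a Call instance as defined in this module.
windowed_call = call.cut_segments_in_windows(window_size_in_tokens=128, tokenizer=tokenizer)
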


def prepare_call_windows(
call: Call,
@@ -508,6 +530,9 @@ class FunctionalSegment(NamedTuple):
is_continuation: bool = False
start: Optional[float] = None
end: Optional[float] = None
completeness: Literal[
"complete", "left-truncated", "right-truncated", "both-truncated"
] = "complete"

@property
def num_words(self) -> int:
Expand All @@ -521,6 +546,62 @@ def with_vocabulary(self, vocabulary: Set[str]) -> 'FunctionalSegment':
new_text = ' '.join(w if w in vocabulary else OOV for w in self.text.split())
return FunctionalSegment(new_text, *self[1:])

def to_windows(
self, window_size_in_tokens: int, tokenizer
) -> List["FunctionalSegment"]:
words = self.text.split()
segment_tokens_per_word = [len(tokenizer.tokenize(w)) for w in words]
n_segment_tokens = sum(segment_tokens_per_word)
if n_segment_tokens <= window_size_in_tokens:
return [self]

partial_segments = []
is_first = True
is_last = False
while words:
if n_segment_tokens <= window_size_in_tokens:
# Everything that remains fits into a single window.
text = " ".join(words)
words = []
is_last = True
else:
# Take the longest prefix of words whose token count fits the window;
# find_nearest returns the index of the last word that still fits.
n_partial_words = find_nearest(
segment_tokens_per_word, window_size_in_tokens
) + 1
text = " ".join(words[:n_partial_words])
words = words[n_partial_words:]
# Recompute the token counts for the words that remain.
segment_tokens_per_word = segment_tokens_per_word[n_partial_words:]
n_segment_tokens = sum(segment_tokens_per_word)
partial_segments.append(
FunctionalSegment(
text=text,
dialog_act=self.dialog_act,
speaker=self.speaker,
is_continuation=self.is_continuation,
start=self.start,
end=self.end,
completeness=(
"right-truncated"
if is_first
else "left-truncated"
if is_last
else "both-truncated"
),
)
)
is_first = False

return partial_segments


def find_nearest(array: List[int], value: int):
"""
Find the index of the element in the cumulative sum of ``array`` that is
closest to ``value`` without exceeding it.
"""
array = np.cumsum(array)
diff = array - value
# Mask out prefix sums that exceed ``value``, then pick the closest one.
diff = np.where(diff <= 0, diff, np.inf)
idx = np.abs(diff).argmin()
return idx
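
A small worked example of the cumulative-sum lookup (numbers are illustrative):

# tokens per word: [3, 2, 4, 1] -> cumulative sums: [3, 5, 9, 10]
# find_nearest([3, 2, 4, 1], value=6) returns 1, because 5 is the closest
# prefix sum that does not exceed 6; to_windows then keeps the first
# idx + 1 = 2 words in the current window.
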


class EncodedSegment(NamedTuple):
words: List[str]