diff --git a/src/pie_modules/taskmodules/re_span_pair_classification.py b/src/pie_modules/taskmodules/re_span_pair_classification.py index 6b61aa868..3a9b546cf 100644 --- a/src/pie_modules/taskmodules/re_span_pair_classification.py +++ b/src/pie_modules/taskmodules/re_span_pair_classification.py @@ -192,13 +192,25 @@ def construct_argument_marker(pos: str, label: Optional[str] = None, role: str = def inject_markers_into_text( text: str, positions_and_markers: List[Tuple[int, str]] -) -> Tuple[str, Dict[int, int]]: +) -> Tuple[str, Dict[int, Tuple[int, List[str]]]]: + """Inject markers into the text at the given positions. + + Args: + text: The text to inject the markers into. + positions_and_markers: A list of tuples where each tuple contains the position in the text + where the marker should be injected and the marker text itself. + + Returns: + A tuple containing the text with the markers injected and a dictionary mapping the original + positions to the new positions and the markers that were injected at that position. + """ offset = 0 - original2new_pos = dict() + original2new_pos: Dict[int, Tuple[int, List[str]]] = dict() for original_pos, marker in sorted(positions_and_markers): text = text[: original_pos + offset] + marker + text[original_pos + offset :] + previous_markers = original2new_pos.get(original_pos, (-1, []))[1] + original2new_pos[original_pos] = (original_pos + offset, previous_markers + [marker]) offset += len(marker) - original2new_pos[original_pos] = original_pos + offset return text, original2new_pos @@ -505,21 +517,23 @@ def inject_markers_for_labeled_spans( if isinstance(document, TextDocumentWithLabeledPartitions): # create "dummy" markers for the partitions so that entries for these positions are created - # in original2new_pos + # in original_pos2new_pos_and_markers for labeled_partition in document.labeled_partitions: positions_and_markers.append((labeled_partition.start, "")) positions_and_markers.append((labeled_partition.end, "")) # inject markers into the text - marked_text, original2new_pos = inject_markers_into_text( + marked_text, original_pos2new_pos_and_markers = inject_markers_into_text( document.text, positions_and_markers ) # construct new spans old2new_spans = dict() for labeled_span in document.labeled_spans: - start = original2new_pos[labeled_span.start] - end = original2new_pos[labeled_span.end] + start_before_markers, markers = original_pos2new_pos_and_markers[labeled_span.start] + # we use just the span *without* the markers as new span + start = start_before_markers + sum(len(marker) for marker in markers) + end = original_pos2new_pos_and_markers[labeled_span.end][0] new_span = LabeledSpan(start=start, end=end, label=labeled_span.label) old2new_spans[labeled_span] = new_span @@ -546,9 +560,13 @@ def inject_markers_for_labeled_spans( new_document.binary_relations.extend(old2new_relations.values()) if isinstance(document, TextDocumentWithLabeledPartitions): for labeled_partition in document.labeled_partitions: - new_start = original2new_pos[labeled_partition.start] - new_end = original2new_pos[labeled_partition.end] - new_labeled_partitions = labeled_partition.copy(start=new_start, end=new_end) + # we use the span *including* the markers as new span + start, _ = original_pos2new_pos_and_markers[labeled_partition.start] + end_before_markers, markers = original_pos2new_pos_and_markers[ + labeled_partition.end + ] + end = end_before_markers + sum(len(marker) for marker in markers) + new_labeled_partitions = labeled_partition.copy(start=start, end=end) new_document.labeled_partitions.append(new_labeled_partitions) new2old_spans = {new_span: old_span for old_span, new_span in old2new_spans.items()} @@ -657,7 +675,7 @@ def encode_target( get_relation_argument_spans_and_roles(relation) ].append(relation) label_indices = [] # list of label indices - candidate_relations = [] + # candidate_relations = [] for candidate_relation in task_encoding.metadata["candidate_relations"]: candidate_roles_and_args = get_relation_argument_spans_and_roles(candidate_relation) gold_relations = gold_roles_and_args2relation.get(candidate_roles_and_args, []) @@ -678,9 +696,9 @@ def encode_target( label_idx = PAD_VALUES["labels"] label_indices.append(label_idx) - candidate_relations.append(candidate_relation) + # candidate_relations.append(candidate_relation) - task_encoding.metadata["candidate_relations"] = candidate_relations + # task_encoding.metadata["candidate_relations"] = candidate_relations target: TargetEncodingType = {"labels": to_tensor("labels", label_indices)} self._maybe_log_example(task_encoding=task_encoding, target=target) @@ -711,7 +729,9 @@ def _maybe_log_example( ): logger.info(f"relation {i}: {label}") for j, arg_idx in enumerate(tuple_indices): - arg_tokens = tokens[span_start_indices[arg_idx] : span_end_indices[arg_idx]] + arg_tokens = tokens[ + span_start_indices[arg_idx] : span_end_indices[arg_idx] + 1 + ] logger.info(f"\targ {j}: {' '.join([str(x) for x in arg_tokens])}") self._logged_examples_counter += 1 diff --git a/tests/models/test_span_tuple_classification.py b/tests/models/test_span_tuple_classification.py index f984fdbdb..e4a6b833c 100644 --- a/tests/models/test_span_tuple_classification.py +++ b/tests/models/test_span_tuple_classification.py @@ -167,7 +167,7 @@ def batch(): ] ), "span_start_indices": tensor([[1, 9, 0, 0], [4, 12, 18, 0], [4, 12, 17, 21]]), - "span_end_indices": tensor([[7, 12, 0, 0], [10, 15, 21, 0], [10, 15, 20, 24]]), + "span_end_indices": tensor([[6, 11, 0, 0], [9, 14, 20, 0], [9, 14, 19, 23]]), "tuple_indices": tensor( [[[0, 1], [-1, -1], [-1, -1]], [[0, 1], [0, 2], [2, 1]], [[0, 1], [2, 3], [3, 2]]] ), @@ -351,46 +351,41 @@ def test_forward_logits(batch, model): tensor( [ [ - -0.23075447976589203, - 0.08129829168319702, - -0.26441076397895813, - 0.3208361268043518, + -0.3551301658153534, + 0.09493370354175568, + -0.15801358222961426, + 0.5679908990859985, ], [ - -0.2247302085161209, - 0.21453489363193512, - -0.20609508454799652, - 0.2984844148159027, + -0.266460657119751, + 0.16119083762168884, + -0.10706772655248642, + 0.5230874419212341, ], [ - -0.0552724152803421, - 0.18319237232208252, - -0.14115819334983826, - 0.23137536644935608, + -0.11953088641166687, + 0.1623934805393219, + -0.04825110733509064, + 0.43645235896110535, ], + [-0.2047966569662094, 0.17388933897018433, -0.06319254636764526, 0.4306640625], [ - -0.2897184491157532, - 0.17462071776390076, - -0.12004873156547546, - 0.1817789375782013, + -0.3208402395248413, + 0.09282125532627106, + -0.05495951324701309, + 0.4880615472793579, ], [ - -0.3101494312286377, - 0.18245069682598114, - -0.13525372743606567, - 0.28625163435935974, + -0.4020463228225708, + 0.2283128798007965, + 0.013205204159021378, + 0.3972089886665344, ], [ - -0.33728304505348206, - 0.22038179636001587, - -0.0482308566570282, - 0.25237396359443665, - ], - [ - -0.3835912048816681, - 0.20549766719341278, - 0.15333643555641174, - 0.23370930552482605, + -0.2575981616973877, + 0.0700659453868866, + -0.010283984243869781, + 0.4580671489238739, ], ] ), @@ -402,7 +397,7 @@ def test_step(batch, model, config): loss = model._step("train", batch) assert loss is not None if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) + torch.testing.assert_close(loss, torch.tensor(1.3872407674789429)) else: raise ValueError(f"Unknown config: {config}") @@ -413,7 +408,7 @@ def test_training_step_and_on_epoch_end(batch, model, config): loss = model.training_step(batch, batch_idx=0) assert loss is not None if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) + torch.testing.assert_close(loss, torch.tensor(1.3872407674789429)) else: raise ValueError(f"Unknown config: {config}") @@ -427,7 +422,7 @@ def test_validation_step_and_on_epoch_end(batch, model, config): assert loss is not None metric_values = {k: v.item() for k, v in metric.compute().items()} if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) + torch.testing.assert_close(loss, torch.tensor(1.3872407674789429)) assert metric_values == { "macro/f1": 0.14814814925193787, "micro/f1": 0.2857142984867096, @@ -449,7 +444,7 @@ def test_test_step_and_on_epoch_end(batch, model, config): assert loss is not None metric_values = {k: v.item() for k, v in metric.compute().items()} if config == {}: - torch.testing.assert_close(loss, torch.tensor(1.3911350965499878)) + torch.testing.assert_close(loss, torch.tensor(1.3872407674789429)) assert metric_values == { "macro/f1": 0.14814814925193787, "micro/f1": 0.2857142984867096, @@ -483,21 +478,21 @@ def test_predict_and_predict_step(model, batch, config, test_step): tensor( [ [ - [0.1973, 0.2695, 0.1907, 0.3425], + [0.1586, 0.2488, 0.1932, 0.3993], [-1.0000, -1.0000, -1.0000, -1.0000], [-1.0000, -1.0000, -1.0000, -1.0000], ], [ - [0.1902, 0.2951, 0.1938, 0.3209], - [0.2213, 0.2809, 0.2031, 0.2947], - [0.1859, 0.2958, 0.2203, 0.2979], + [0.1692, 0.2596, 0.1985, 0.3727], + [0.1944, 0.2578, 0.2088, 0.3390], + [0.1818, 0.2655, 0.2095, 0.3432], ], [ - [0.1772, 0.2900, 0.2111, 0.3217], - [0.1699, 0.2968, 0.2269, 0.3064], - [0.1571, 0.2831, 0.2687, 0.2912], + [0.1650, 0.2495, 0.2152, 0.3704], + [0.1511, 0.2839, 0.2289, 0.3361], + [0.1750, 0.2429, 0.2241, 0.3580], ], - ], + ] ), ) else: diff --git a/tests/taskmodules/test_re_span_pair_classification.py b/tests/taskmodules/test_re_span_pair_classification.py index 80f494f4f..46d7eaaa3 100644 --- a/tests/taskmodules/test_re_span_pair_classification.py +++ b/tests/taskmodules/test_re_span_pair_classification.py @@ -6,7 +6,11 @@ import torch from pytorch_ie import AnnotationLayer, annotation_field from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.documents import TextBasedDocument +from pytorch_ie.documents import ( + TextBasedDocument, + TextDocumentWithLabeledSpansAndBinaryRelations, + TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, +) from torch import tensor from torchmetrics import Metric, MetricCollection @@ -17,7 +21,10 @@ TOKENIZER_NAME_OR_PATH = "bert-base-cased" -CONFIGS = [{}, {"partition_annotation": "sentences"}] +CONFIGS = [ + {}, + {"partition_annotation": "sentences"}, +] CONFIGS_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS} @@ -159,6 +166,55 @@ def task_encodings(taskmodule, document): return result +@pytest.fixture(scope="module") +def normalized_document(taskmodule, document): + return taskmodule.normalize_document(document) + + +def test_normalize_document(taskmodule, document, normalized_document): + assert normalized_document is not None + assert isinstance(normalized_document, TextDocumentWithLabeledSpansAndBinaryRelations) + assert len(normalized_document.labeled_spans) > 0 + assert normalized_document.labeled_spans.resolve() == document.entities.resolve() + assert len(normalized_document.binary_relations) > 0 + assert normalized_document.binary_relations.resolve() == document.relations.resolve() + + +def test_inject_markers_for_labeled_spans(taskmodule, normalized_document): + document_with_markers, injected2original_spans = taskmodule.inject_markers_for_labeled_spans( + normalized_document + ) + assert document_with_markers is not None + assert ( + document_with_markers.text + == "First sentence. [SPAN:PER]Entity G[/SPAN:PER] works at [SPAN:ORG]H[/SPAN:ORG]. And founded [SPAN:ORG]I[/SPAN:ORG]." + ) + assert ( + document_with_markers.labeled_spans.resolve() + == normalized_document.labeled_spans.resolve() + == [("PER", "Entity G"), ("ORG", "H"), ("ORG", "I")] + ) + assert len(injected2original_spans) == 3 + assert all( + labeled_span in injected2original_spans + for labeled_span in document_with_markers.labeled_spans + ) + if isinstance( + normalized_document, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions + ): + assert isinstance( + document_with_markers, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions + ) + assert len(document_with_markers.labeled_partitions) == len( + normalized_document.labeled_partitions + ) + assert document_with_markers.labeled_partitions.resolve() == [ + ("sentence", "First sentence."), + ("sentence", "[SPAN:PER]Entity G[/SPAN:PER] works at [SPAN:ORG]H[/SPAN:ORG]."), + ("sentence", "And founded [SPAN:ORG]I[/SPAN:ORG]."), + ] + + def test_encode_input(task_encodings, document, taskmodule, cfg): assert task_encodings is not None if cfg == {}: @@ -198,8 +254,17 @@ def test_encode_input(task_encodings, document, taskmodule, cfg): ".", "[SEP]", ] + span_start_and_end_tokens = [ + (tokens[start], tokens[end]) + for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) + ] + assert span_start_and_end_tokens == [ + ("[SPAN:PER]", "[/SPAN:PER]"), + ("[SPAN:ORG]", "[/SPAN:ORG]"), + ("[SPAN:ORG]", "[/SPAN:ORG]"), + ] span_tokens = [ - tokens[start:end] + tokens[start : end + 1] for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) ] assert span_tokens == [ @@ -235,8 +300,13 @@ def test_encode_input(task_encodings, document, taskmodule, cfg): "tuple_indices_mask", } tokens = taskmodule.tokenizer.convert_ids_to_tokens(inputs["input_ids"]) + span_start_and_end_tokens = [ + (tokens[start], tokens[end]) + for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) + ] + span_tokens = [ - tokens[start:end] + tokens[start : end + 1] for start, end in zip(inputs["span_start_indices"], inputs["span_end_indices"]) ] tuple_spans = [ @@ -245,6 +315,7 @@ def test_encode_input(task_encodings, document, taskmodule, cfg): if idx == 0: assert tokens == [ "[CLS]", + "[SPAN:PER]", "En", "##ti", "##ty", @@ -258,13 +329,17 @@ def test_encode_input(task_encodings, document, taskmodule, cfg): ".", "[SEP]", ] + assert span_start_and_end_tokens == [ + ("[SPAN:PER]", "[/SPAN:PER]"), + ("[SPAN:ORG]", "[/SPAN:ORG]"), + ] assert span_tokens == [ - ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], + ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], ] assert tuple_spans == [ [ - ["[CLS]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], + ["[SPAN:PER]", "En", "##ti", "##ty", "G", "[/SPAN:PER]"], ["[SPAN:ORG]", "H", "[/SPAN:ORG]"], ] ] @@ -346,18 +421,18 @@ def test_encode_with_multiple_gold_relations_with_same_arguments(document, caplo assert ( caplog.messages[0] == "skip the candidate relation because there are more than one gold relation for " - "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), " - "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" + "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=9, label='PER', score=1.0), " + "tail=LabeledSpan(start=13, end=14, label='ORG', score=1.0), label='org:founded_by', score=1.0), " + "BinaryRelation(head=LabeledSpan(start=5, end=9, label='PER', score=1.0), " + "tail=LabeledSpan(start=13, end=14, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" ) assert ( caplog.messages[1] == "skip the candidate relation because there are more than one gold relation for " - "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='org:founded_by', score=1.0), " - "BinaryRelation(head=LabeledSpan(start=5, end=10, label='PER', score=1.0), " - "tail=LabeledSpan(start=13, end=15, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" + "its args and roles: [BinaryRelation(head=LabeledSpan(start=5, end=9, label='PER', score=1.0), " + "tail=LabeledSpan(start=13, end=14, label='ORG', score=1.0), label='org:founded_by', score=1.0), " + "BinaryRelation(head=LabeledSpan(start=5, end=9, label='PER', score=1.0), " + "tail=LabeledSpan(start=13, end=14, label='ORG', score=1.0), label='per:employee_of', score=1.0)]" ) assert len(encodings) == 1 @@ -396,10 +471,10 @@ def test_maybe_log_example(taskmodule, task_encodings, caplog, cfg): assert caplog.messages == [ "*** Example ***", "doc id: train_doc5", - "tokens: [CLS] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . [SEP]", - "input_ids: 101 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 102", + "tokens: [CLS] [SPAN:PER] En ##ti ##ty G [/SPAN:PER] works at [SPAN:ORG] H [/SPAN:ORG] . [SEP]", + "input_ids: 101 28996 13832 3121 2340 144 28998 1759 1120 28999 145 28997 119 102", "relation 0: per:employee_of", - "\targ 0: [CLS] En ##ti ##ty G [/SPAN:PER]", + "\targ 0: [SPAN:PER] En ##ti ##ty G [/SPAN:PER]", "\targ 1: [SPAN:ORG] H [/SPAN:ORG]", ] else: @@ -483,7 +558,7 @@ def test_collate(taskmodule, task_encodings, cfg): ) torch.testing.assert_close(inputs["attention_mask"], torch.ones_like(inputs["input_ids"])) torch.testing.assert_close(inputs["span_start_indices"], tensor([[4, 12, 18]])) - torch.testing.assert_close(inputs["span_end_indices"], tensor([[10, 15, 21]])) + torch.testing.assert_close(inputs["span_end_indices"], tensor([[9, 14, 20]])) torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1], [0, 2], [2, 1]]])) torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True, True, True]])) assert set(targets) == {"labels"} @@ -492,13 +567,30 @@ def test_collate(taskmodule, task_encodings, cfg): torch.testing.assert_close( inputs["input_ids"], tensor( - [[101, 13832, 3121, 2340, 144, 28998, 1759, 1120, 28999, 145, 28997, 119, 102]] + [ + [ + 101, + 28996, + 13832, + 3121, + 2340, + 144, + 28998, + 1759, + 1120, + 28999, + 145, + 28997, + 119, + 102, + ] + ] ), ) torch.testing.assert_close( - inputs["attention_mask"], tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) + inputs["attention_mask"], tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) ) - torch.testing.assert_close(inputs["span_start_indices"], tensor([[0, 8]])) + torch.testing.assert_close(inputs["span_start_indices"], tensor([[1, 9]])) torch.testing.assert_close(inputs["span_end_indices"], tensor([[6, 11]])) torch.testing.assert_close(inputs["tuple_indices"], tensor([[[0, 1]]])) torch.testing.assert_close(inputs["tuple_indices_mask"], tensor([[True]]))