diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..65ee06b --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,25 @@ +name: CI + +on: + push: + branches: [ main ] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + pip install pytest + - name: Run pytest + run: pytest + - name: Ensure coverage threshold + run: python tools/run_trace_coverage.py --min 80 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..669ddd6 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,26 @@ +import sys +import types +from pathlib import Path + +# Ensure the repository root is importable as a package during tests +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + +# Provide a lightweight onnxruntime stub so modules can be imported in tests +if "onnxruntime" not in sys.modules: + class _StubSession: + def __init__(self, *args, **kwargs): + raise RuntimeError( + "onnxruntime is stubbed in tests. Monkeypatch InferenceSession in individual tests." + ) + + sys.modules["onnxruntime"] = types.SimpleNamespace(InferenceSession=_StubSession) + +if "spacy" not in sys.modules: + def _stub_load(*_args, **_kwargs): + raise RuntimeError( + "spacy is stubbed in tests. Monkeypatch spacy.load within individual tests." 
+ ) + + sys.modules["spacy"] = types.SimpleNamespace(load=_stub_load) diff --git a/tests/test_candidates.py b/tests/test_candidates.py new file mode 100644 index 0000000..6c484c0 --- /dev/null +++ b/tests/test_candidates.py @@ -0,0 +1,134 @@ +import io +import json + +from groundkg import candidates + + +def test_non_overlapping_chunks_filters_overlaps(): + sent = "Alice and Bob met Charlie" + ents = [ + {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + {"text": "Charlie", "start": 18, "end": 25, "label": "PERSON"}, + ] + + chunks = candidates.non_overlapping_chunks(sent, ents) + + # "Alice" and "Charlie" should be excluded because they overlap entities + chunk_texts = {c["text"] for c in chunks} + assert chunk_texts == {"Bob"} + # ensure chunks carry the expected metadata + for chunk in chunks: + assert chunk["label"] == "NOUNPHRASE" + assert chunk["end"] - chunk["start"] >= 3 + + +def test_main_emits_subject_object_pairs(tmp_path, monkeypatch): + record = { + "doc_id": "d1", + "sent_idx": 0, + "sent_start": 0, + "text": "Alice visited Paris with Charlie", + "entities": [ + {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + {"text": "Paris", "start": 13, "end": 18, "label": "GPE"}, + ], + } + ner_path = tmp_path / "ner.jsonl" + ner_path.write_text(json.dumps(record) + "\n", encoding="utf-8") + + monkeypatch.setenv("PYTHONHASHSEED", "0") # ensure deterministic iteration if needed + monkeypatch.setattr("sys.argv", ["candidates.py", str(ner_path)]) + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + + candidates.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert any( + rec["subject"]["text"] == "Alice" and rec["object"]["text"] == "Paris" + for rec in lines + ) + for rec in lines: + assert rec["doc_id"] == "d1" + + +def test_main_respects_char_distance_limit(tmp_path, monkeypatch): + long_text = "Alice" + " " * 151 + "Paris" + record = { + "doc_id": "d1", + "sent_idx": 0, + 
"sent_start": 0, + "text": long_text, + "entities": [ + {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + {"text": "Paris", "start": 156, "end": 161, "label": "GPE"}, + ], + } + ner_path = tmp_path / "ner.jsonl" + ner_path.write_text(json.dumps(record) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + monkeypatch.setattr("sys.argv", ["candidates.py", str(ner_path)]) + + candidates.main() + + assert buf.getvalue().strip() == "" + + +def test_main_caps_pairs_at_limit(tmp_path, monkeypatch): + tokens = ["Alice"] + [f"Obj{i}" for i in range(12)] + text = " ".join(tokens) + + ents = [] + cursor = 0 + for token in tokens: + start = cursor + end = start + len(token) + label = "PERSON" if token == "Alice" else "PRODUCT" + ents.append({"text": token, "start": start, "end": end, "label": label}) + cursor = end + 1 # account for spaces + + record = { + "doc_id": "d1", + "sent_idx": 0, + "sent_start": 0, + "text": text, + "entities": ents, + } + ner_path = tmp_path / "ner.jsonl" + ner_path.write_text(json.dumps(record) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + monkeypatch.setattr("sys.argv", ["candidates.py", str(ner_path)]) + + candidates.main() + + lines = [line for line in buf.getvalue().splitlines() if line] + assert len(lines) == candidates.MAX_PAIRS_PER_SENT + + +def test_main_falls_back_to_chunks_without_entities(tmp_path, monkeypatch): + record = { + "doc_id": "d2", + "sent_idx": 0, + "sent_start": 0, + "text": "Solar Panel helps Bright Homes", + "entities": [], + } + ner_path = tmp_path / "ner.jsonl" + ner_path.write_text(json.dumps(record) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + monkeypatch.setattr("sys.argv", ["candidates.py", str(ner_path)]) + + candidates.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert lines, "expected chunk-derived candidates" + for rec in 
lines: + assert rec["subject"]["label"] == "NOUNPHRASE" + assert rec["object"]["label"] == "NOUNPHRASE" diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..d2d52a3 --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,9 @@ +import io +import runpy + + +def test_cli_entrypoint_prints_message(monkeypatch): + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + runpy.run_module("groundkg.cli", run_name="__main__") + assert "Use `python -m groundkg.extract_open" in buf.getvalue() diff --git a/tests/test_dedupe_edges.py b/tests/test_dedupe_edges.py new file mode 100644 index 0000000..4d73da0 --- /dev/null +++ b/tests/test_dedupe_edges.py @@ -0,0 +1,69 @@ +import io +import json + +from groundkg import dedupe_edges + + +def test_key_normalizes_fields(): + edge = { + "subject": " Alice ", + "predicate": "uses", + "object": " Gadget ", + "evidence": {"quote": "Alice uses the gadget."}, + } + assert dedupe_edges.key(edge) == ("alice", "uses", "gadget", "Alice uses the gadget.") + + +def test_main_filters_duplicates(tmp_path, monkeypatch): + edge = { + "subject": "Alice", + "predicate": "uses", + "object": "Gadget", + "evidence": {"quote": "Alice uses the gadget."}, + } + dup_path = tmp_path / "edges.jsonl" + dup_path.write_text("\n".join(json.dumps(e) for e in (edge, edge)) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.argv", ["dedupe_edges.py", str(dup_path)]) + monkeypatch.setattr("sys.stdout", buf) + + dedupe_edges.main() + + lines = buf.getvalue().splitlines() + assert len(lines) == 1 + assert json.loads(lines[0]) == edge + + +def test_key_handles_missing_evidence_quote(): + edge = {"subject": "Alice", "predicate": "uses", "object": "Gadget"} + assert dedupe_edges.key(edge) == ("alice", "uses", "gadget", "") + + +def test_main_dedupes_whitespace_only_quotes(tmp_path, monkeypatch): + edges = [ + { + "subject": "Alice", + "predicate": "uses", + "object": "Gadget", + "evidence": {"quote": " Alice uses 
the gadget. "}, + }, + { + "subject": "alice ", + "predicate": "uses", + "object": "gadget", + "evidence": {"quote": "Alice uses the gadget."}, + }, + ] + dup_path = tmp_path / "edges.jsonl" + dup_path.write_text("\n".join(json.dumps(e) for e in edges) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.argv", ["dedupe_edges.py", str(dup_path)]) + monkeypatch.setattr("sys.stdout", buf) + + dedupe_edges.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert len(lines) == 1 + assert lines[0]["subject"].strip().lower() == "alice" diff --git a/tests/test_export_ttl.py b/tests/test_export_ttl.py new file mode 100644 index 0000000..3224d33 --- /dev/null +++ b/tests/test_export_ttl.py @@ -0,0 +1,70 @@ +import io +import json + +from groundkg import export_ttl + + +def test_iri_sanitizes_text(): + assert export_ttl.iri("node", "Acme, Inc./R&D") == "ex:node/Acme_Inc._R&D" + + +def test_emit_edge_triple_builds_expected_turtle(): + triple, subj = export_ttl.emit_edge_triple({"subject": "Alice", "predicate": "uses", "object": "Gadget"}) + assert triple == "ex:node/Alice ex:uses ex:node/Gadget .\n" + assert subj == "ex:node/Alice" + + +def test_emit_attr_triples_formats_values(): + attr = { + "name": "Battery Life", + "valueNumber": 12, + "unit": "hours", + "valueBoolean": True, + "valueString": "High capacity", + "time": "2023-05-01", + "evidence": {"char_start": 42}, + } + rendered = export_ttl.emit_attr_triples(attr, "ex:node/Alice") + assert "ex:hasAttribute" in rendered + assert "ex:valueNumber 12" in rendered + assert "ex:unit \"hours\"" in rendered + assert "ex:valueBoolean true" in rendered + assert "ex:valueString \"High capacity\"" in rendered + assert rendered.endswith(" .\n") + + +def test_main_reads_edges_and_attributes(tmp_path, monkeypatch): + edges_path = tmp_path / "edges.jsonl" + attrs_path = tmp_path / "attributes.jsonl" + edge = {"subject": "Alice", "predicate": "uses", "object": "Gadget"} + 
edges_path.write_text(json.dumps(edge) + "\n", encoding="utf-8") + attrs_path.write_text(json.dumps({"name": "Battery", "valueNumber": 3}) + "\n", encoding="utf-8") + + buf = io.StringIO() + monkeypatch.setattr("sys.argv", ["export_ttl.py", str(edges_path)]) + monkeypatch.setattr("sys.stdout", buf) + + export_ttl.main() + + output = buf.getvalue() + assert output.startswith(export_ttl.PREFIX) + assert "ex:node/Alice ex:uses ex:node/Gadget" in output + assert "ex:hasAttribute" in output + + +def test_main_ignores_malformed_attribute_lines(tmp_path, monkeypatch): + edges_path = tmp_path / "edges.jsonl" + attrs_path = tmp_path / "attributes.jsonl" + edge = {"subject": "Alice", "predicate": "uses", "object": "Gadget"} + edges_path.write_text(json.dumps(edge) + "\n", encoding="utf-8") + attrs_path.write_text("{" + "\n", encoding="utf-8") # malformed JSON + + buf = io.StringIO() + monkeypatch.setattr("sys.argv", ["export_ttl.py", str(edges_path)]) + monkeypatch.setattr("sys.stdout", buf) + + export_ttl.main() + + output = buf.getvalue() + assert "ex:node/Alice ex:uses ex:node/Gadget" in output + assert "ex:hasAttribute" not in output diff --git a/tests/test_ner_tag.py b/tests/test_ner_tag.py new file mode 100644 index 0000000..9eb1398 --- /dev/null +++ b/tests/test_ner_tag.py @@ -0,0 +1,131 @@ +import io +import json + +import pytest + +from groundkg import ner_tag + + +class FakeEntity: + def __init__(self, text, start_char, end_char, label): + self.text = text + self.start_char = start_char + self.end_char = end_char + self.label_ = label + + +class FakeSentence: + def __init__(self, text, start, start_char): + self.text = text + self.start = start + self.start_char = start_char + self.end_char = start_char + len(text) + + +class FakeDoc: + def __init__(self, sents, ents): + self.sents = sents + self.ents = ents + + +class FakeNLP: + def __init__(self, doc): + self._doc = doc + self.enabled_pipes = [] + + def enable_pipe(self, name): + self.enabled_pipes.append(name) + 
+ def __call__(self, text): + return self._doc + + +@pytest.fixture +def fake_doc(): + sent_text = "Alice lives in Paris." + sentence = FakeSentence(sent_text, start=0, start_char=0) + ents = [ + FakeEntity("Alice", 0, 5, "PERSON"), + FakeEntity("Paris", 15, 20, "GPE"), + ] + return FakeDoc([sentence], ents) + + +def test_main_streams_sentence_entities(tmp_path, monkeypatch, fake_doc): + text_path = tmp_path / "doc.txt" + text_path.write_text("Alice lives in Paris.", encoding="utf-8") + + fake_nlp = FakeNLP(fake_doc) + + def fake_load(model, disable): + assert model == "en_core_web_trf" + assert "textcat" in disable + return fake_nlp + + monkeypatch.setattr(ner_tag.spacy, "load", fake_load) + monkeypatch.setattr( + "sys.argv", + [ + "ner_tag.py", + str(text_path), + "--doc-id", + "doc-1", + "--model", + "en_core_web_trf", + ], + ) + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + + ner_tag.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert len(lines) == 1 + record = lines[0] + assert record["doc_id"] == "doc-1" + assert record["sent_idx"] == 0 + assert record["text"] == "Alice lives in Paris." + assert record["entities"] == [ + {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + {"text": "Paris", "start": 15, "end": 20, "label": "GPE"}, + ] + assert fake_nlp.enabled_pipes == ["ner"] + + +def test_main_defaults_doc_id_and_handles_multiple_sentences(tmp_path, monkeypatch): + text = "Alice meets Bob. Charlie visits Rome." 
+ text_path = tmp_path / "doc.txt" + text_path.write_text(text, encoding="utf-8") + + sentences = [ + FakeSentence("Alice meets Bob.", start=0, start_char=0), + FakeSentence("Charlie visits Rome.", start=3, start_char=17), + ] + ents = [ + FakeEntity("Alice", 0, 5, "PERSON"), + FakeEntity("Charlie", 17, 24, "PERSON"), + FakeEntity("Rome", 32, 36, "GPE"), + ] + fake_doc = FakeDoc(sentences, ents) + fake_nlp = FakeNLP(fake_doc) + + def fake_load(model, disable): + assert model == "en_core_web_trf" + return fake_nlp + + monkeypatch.setattr(ner_tag.spacy, "load", fake_load) + monkeypatch.setattr("sys.argv", ["ner_tag.py", str(text_path)]) + + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + + ner_tag.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert [rec["doc_id"] for rec in lines] == ["doc", "doc"] + assert [rec["text"] for rec in lines] == ["Alice meets Bob.", "Charlie visits Rome."] + # Ensure entity offsets are relative to each sentence + second_entities = lines[1]["entities"] + assert any(ent == {"text": "Charlie", "start": 0, "end": 7, "label": "PERSON"} for ent in second_entities) + assert any(ent == {"text": "Rome", "start": 15, "end": 19, "label": "GPE"} for ent in second_entities) diff --git a/tests/test_re_modules.py b/tests/test_re_modules.py new file mode 100644 index 0000000..8ac686f --- /dev/null +++ b/tests/test_re_modules.py @@ -0,0 +1,269 @@ +import io +import json +import sys +import types + +import pytest + +if "numpy" not in sys.modules: + fake_np = types.ModuleType("numpy") + + def array(data, dtype=None): + return data + + def argmax(seq): + return max(range(len(seq)), key=lambda i: seq[i]) + + def zeros(shape, dtype=float): + rows, cols = shape + return [[dtype(0) for _ in range(cols)] for _ in range(rows)] + + fake_np.array = array + fake_np.argmax = argmax + fake_np.zeros = zeros + fake_np.isscalar = lambda value: isinstance(value, (int, float)) + sys.modules["numpy"] = fake_np + +from 
groundkg import re_infer, re_score + + +def test_re_score_mark_orders_entities(): + text = "The gadget Alice built" + subject = {"start": 4, "end": 10} + obj = {"start": 11, "end": 16} + marked = re_score.mark(text, obj, subject) # subject starts later than object + assert "[E1]" in marked and "[E2]" in marked + assert marked.index("[E1]") < marked.index("[E2]") + + +def test_re_score_main_streams_predictions(tmp_path, monkeypatch): + cand_path = tmp_path / "cands.jsonl" + candidates = [ + { + "doc_id": "d1", + "sent_start": 0, + "text": "Alice uses the gadget", + "subject": {"text": "Alice", "start": 0, "end": 5}, + "object": {"text": "gadget", "start": 12, "end": 18}, + } + ] + cand_path.write_text("\n".join(json.dumps(c) for c in candidates) + "\n", encoding="utf-8") + + onnx_path = tmp_path / "model.onnx" + onnx_path.write_text("", encoding="utf-8") + classes_path = tmp_path / "classes.json" + classes = ["none", "uses"] + classes_path.write_text(json.dumps(classes), encoding="utf-8") + + class FakeInput: + name = "text" + + class FakeSession: + def get_inputs(self): + return [FakeInput()] + + def run(self, _outputs, feeds): + assert isinstance(feeds, dict) + probs = fake_np.zeros((1, len(classes)), dtype=float) + probs[0][1] = 0.9 + return [["uses"], probs] + + monkeypatch.setattr(re_score.ort, "InferenceSession", lambda *a, **k: FakeSession()) + buf = io.StringIO() + monkeypatch.setattr("sys.stdout", buf) + monkeypatch.setattr( + "sys.argv", + [ + "re_score.py", + str(cand_path), + str(onnx_path), + str(classes_path), + ], + ) + + re_score.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert len(lines) == 1 + record = lines[0] + assert record["pred"] == "uses" + assert abs(record["prob"] - 0.9) < 1e-6 + assert record["subject"]["text"] == "Alice" + assert record["object"]["text"] == "gadget" + + +def test_re_score_missing_model_exits(tmp_path, monkeypatch, capsys): + cand_path = tmp_path / "cands.jsonl" + 
cand_path.write_text("{}\n", encoding="utf-8") + missing_model = tmp_path / "missing.onnx" + classes_path = tmp_path / "classes.json" + classes_path.write_text(json.dumps(["none"]), encoding="utf-8") + + monkeypatch.setattr( + "sys.argv", + ["re_score.py", str(cand_path), str(missing_model), str(classes_path)], + ) + + with pytest.raises(SystemExit) as exc: + re_score.main() + + assert exc.value.code == 2 + captured = capsys.readouterr() + assert "missing" in captured.err.lower() + + +def test_type_compatible_enforces_allowed_pairs(): + assert re_infer.type_compatible("uses", "PERSON", "PRODUCT") + assert not re_infer.type_compatible("uses", "PERSON", "GPE") + + +def test_mark_indicates_swapped_subject_object(): + text = "Paris is home to Alice" + subject = {"start": 17, "end": 22} + obj = {"start": 0, "end": 5} + marked, swapped = re_infer.mark(text, subject, obj) + assert swapped is True + assert marked.startswith("[E1]Paris") + + +def test_main_filters_by_threshold_and_types(tmp_path, monkeypatch): + cand_path = tmp_path / "cands.jsonl" + preds_yaml = tmp_path / "preds.yaml" + preds_yaml.write_text("predicates", encoding="utf-8") + onnx_path = tmp_path / "model.onnx" + onnx_path.write_text("", encoding="utf-8") + thresh_path = tmp_path / "thresh.json" + thresholds = {"uses": 0.5} + thresh_path.write_text(json.dumps(thresholds), encoding="utf-8") + + candidates = [ + { + "doc_id": "d1", + "sent_start": 0, + "text": "Alice uses the gadget", + "subject": {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + "object": {"text": "gadget", "start": 12, "end": 18, "label": "PRODUCT"}, + }, + { + "doc_id": "d1", + "sent_start": 0, + "text": "Alice uses the city", + "subject": {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + "object": {"text": "city", "start": 12, "end": 16, "label": "GPE"}, + }, + ] + cand_path.write_text("\n".join(json.dumps(c) for c in candidates) + "\n", encoding="utf-8") + + classes = json.load(open("models/classes.json", "r", 
encoding="utf-8")) + uses_idx = classes.index("uses") + + class FakeInput: + name = "input_text" + + class FakeSession: + def get_inputs(self): + return [FakeInput()] + + def run(self, *_args, **_kwargs): + probs = fake_np.zeros((1, len(classes)), dtype=float) + probs[0][uses_idx] = 0.92 + labels = ["uses"] + return [labels, probs] + + monkeypatch.setattr(re_infer.ort, "InferenceSession", lambda *a, **k: FakeSession()) + buf = io.StringIO() + monkeypatch.setattr( + "sys.argv", + [ + "re_infer.py", + str(cand_path), + str(preds_yaml), + str(onnx_path), + str(thresh_path), + ], + ) + monkeypatch.setattr("sys.stdout", buf) + + re_infer.main() + + lines = [line for line in buf.getvalue().splitlines() if line] + assert len(lines) == 1 + edge = json.loads(lines[0]) + assert edge["subject"] == "Alice" + assert edge["object"] == "gadget" + assert edge["predicate"] == "uses" + + +def test_main_skips_low_prob_and_allows_unknown_predicate(tmp_path, monkeypatch): + cand_path = tmp_path / "cands.jsonl" + preds_yaml = tmp_path / "preds.yaml" + preds_yaml.write_text("predicates", encoding="utf-8") + onnx_path = tmp_path / "model.onnx" + onnx_path.write_text("", encoding="utf-8") + thresh_path = tmp_path / "thresh.json" + thresholds = {"uses": 0.8, "provides": 0.5} + thresh_path.write_text(json.dumps(thresholds), encoding="utf-8") + + candidates = [ + { + "doc_id": "d1", + "sent_start": 0, + "text": "Alice uses the gadget", + "subject": {"text": "Alice", "start": 0, "end": 5, "label": "PERSON"}, + "object": {"text": "gadget", "start": 12, "end": 18, "label": "PRODUCT"}, + }, + { + "doc_id": "d2", + "sent_start": 10, + "text": "Bob provides Rome with tools", + "subject": {"text": "Bob", "start": 0, "end": 3, "label": "PERSON"}, + "object": {"text": "Rome", "start": 13, "end": 17, "label": "GPE"}, + }, + ] + cand_path.write_text("\n".join(json.dumps(c) for c in candidates) + "\n", encoding="utf-8") + + classes = json.load(open("models/classes.json", "r", encoding="utf-8")) + 
uses_idx = classes.index("uses") + provides_idx = classes.index("provides") + + class FakeInput: + name = "input_text" + + class FakeSession: + def __init__(self): + self.outputs = [ + [["uses"], [[0.0 for _ in classes]]], + [["provides"], [[0.0 for _ in classes]]], + ] + self.outputs[0][1][0][uses_idx] = 0.6 # below threshold + self.outputs[1][1][0][provides_idx] = 0.9 + + def get_inputs(self): + return [FakeInput()] + + def run(self, *_args, **_kwargs): + return self.outputs.pop(0) + + monkeypatch.setattr(re_infer.ort, "InferenceSession", lambda *a, **k: FakeSession()) + monkeypatch.setattr(re_infer, "ALLOWED_TYPES", {k: v for k, v in re_infer.ALLOWED_TYPES.items() if k != "provides"}) + buf = io.StringIO() + monkeypatch.setattr( + "sys.argv", + [ + "re_infer.py", + str(cand_path), + str(preds_yaml), + str(onnx_path), + str(thresh_path), + ], + ) + monkeypatch.setattr("sys.stdout", buf) + + re_infer.main() + + lines = [json.loads(line) for line in buf.getvalue().splitlines() if line] + assert len(lines) == 1 + edge = lines[0] + assert edge["predicate"] == "provides" + assert edge["subject"] == "Bob" + assert edge["object"] == "Rome" diff --git a/tools/run_trace_coverage.py b/tools/run_trace_coverage.py new file mode 100755 index 0000000..34aa611 --- /dev/null +++ b/tools/run_trace_coverage.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +"""Run pytest under trace module and report package coverage.""" +from __future__ import annotations + +import argparse +import ast +import pathlib +import sys +from trace import Trace + +import pytest + + +def statement_lines(path: pathlib.Path) -> set[int]: + tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path)) + return {node.lineno for node in ast.walk(tree) if isinstance(node, ast.stmt)} + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--package", default="groundkg", help="Package directory to measure") + parser.add_argument("--min", type=float, default=0.0, 
help="Minimum required coverage percentage") + parser.add_argument("pytest_args", nargs=argparse.REMAINDER, help="Arguments forwarded to pytest") + args = parser.parse_args() + + pkg_path = pathlib.Path(args.package) + if not pkg_path.exists(): + parser.error(f"Package path {pkg_path} not found") + + tracer = Trace(count=True, trace=False, ignoredirs=[sys.prefix, sys.exec_prefix]) + exit_code = tracer.runfunc(pytest.main, args.pytest_args or []) + if exit_code != 0: + return exit_code + + results = tracer.results() + counts = results.counts + + total_statements = 0 + total_executed = 0 + report_lines = [] + for path in sorted(pkg_path.rglob("*.py")): + if path.name == "__init__.py" and not path.read_text(encoding="utf-8").strip(): + continue + stmts = statement_lines(path) + if not stmts: + continue + executed = sum(1 for line in stmts if counts.get((str(path.resolve()), line), 0)) + total_statements += len(stmts) + total_executed += executed + pct = 100.0 * executed / len(stmts) + report_lines.append(f"{path}: {pct:.1f}% ({executed}/{len(stmts)})") + + overall = 100.0 * total_executed / total_statements if total_statements else 100.0 + print("\n".join(report_lines)) + print(f"Overall coverage for {pkg_path}: {overall:.1f}% ({total_executed}/{total_statements})") + + if overall < args.min: + print(f"Coverage {overall:.1f}% is below required minimum {args.min}%", file=sys.stderr) + return 1 + return 0 + + +if __name__ == "__main__": + raise SystemExit(main())