diff --git a/tests/test_event_pipeline.py b/tests/test_event_pipeline.py
new file mode 100644
index 0000000..7fe25cd
--- /dev/null
+++ b/tests/test_event_pipeline.py
@@ -0,0 +1,194 @@
+import json
+import sys
+
+import pytest
+
+np = pytest.importorskip("numpy")
+
+from groundkg import event_extract, events_to_edges, re_score
+
+
+def test_event_extract_main_generates_events(tmp_path, monkeypatch):
+    manifest = tmp_path / "manifest.jsonl"
+    events_out = tmp_path / "events.jsonl"
+
+    records = [
+        {
+            "doc_id": "doc1",
+            "text": "MegaCorp acquired StartUp for $5 million on Jan 2, 2022.",
+        },
+        {
+            "doc_id": "doc2",
+            "text": "Bright Future secured $3M from Big VC on Feb 5, 2021.",
+        },
+        {
+            "doc_id": "doc3",
+            "text": "Tech Corp launched HyperWidget on 2023.",
+        },
+        {
+            "doc_id": "doc4",
+            "source_org": "ACME",
+            "text": "ACME appointed Jane Doe as CTO on Mar 3, 2020.",
+        },
+        {
+            "doc_id": "doc5",
+            "text": "John Smith founded Future Labs in 2019.",
+        },
+    ]
+    manifest.write_text("\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8")
+
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        [
+            "event_extract",
+            "--manifest",
+            str(manifest),
+            "--out",
+            str(events_out),
+        ],
+    )
+
+    event_extract.main()
+
+    lines = [json.loads(line) for line in events_out.read_text(encoding="utf-8").splitlines()]
+
+    event_types = {line["type"] for line in lines}
+    expected_types = {"Acquisition", "Funding", "Launch", "Appointment", "Founding"}
+    assert expected_types.issubset(event_types)
+
+    acq = next(line for line in lines if line["type"] == "Acquisition")
+    assert acq["roles"].get("acquirer") == "MegaCorp"
+    assert acq["roles"].get("target", "").startswith("StartUp")
+    assert acq["amount_text"] == "$5 million"
+    assert acq["date_text"] == "Jan 2, 2022"
+
+    funding = next(line for line in lines if line["type"] == "Funding")
+    assert funding["roles"]["recipient"] == "Bright Future"
+
+    assert any(ev["roles"].get("actor") == "ACME" for ev in lines if ev["type"] == "Appointment")
+
+    founding = next(line for line in lines if line["type"] == "Founding")
+    assert "founder_or_actor" in founding["roles"]
+
+
+def test_events_to_edges_main(tmp_path, monkeypatch):
+    events_file = tmp_path / "events.jsonl"
+    edges_out = tmp_path / "edges.jsonl"
+
+    event_record = {
+        "event_id": "E1",
+        "type": "Acquisition",
+        "trigger": "acquired",
+        "date_text": "Jan 2, 2022",
+        "amount_text": "$5 million",
+        "roles": {"acquirer": "MegaCorp", "target": "StartUp", "empty": ""},
+        "confidence": 0.75,
+        "source": "doc1#s",
+    }
+    events_file.write_text(json.dumps(event_record) + "\n", encoding="utf-8")
+
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        ["events_to_edges", "--events", str(events_file), "--out", str(edges_out)],
+    )
+
+    events_to_edges.main()
+
+    edges = [json.loads(line) for line in edges_out.read_text(encoding="utf-8").splitlines()]
+    subjects = {edge["subject"] for edge in edges}
+    assert subjects == {"event:E1"}
+    predicates = {edge["predicate"] for edge in edges}
+    assert predicates == {"type", "trigger", "date", "amount", "acquirer", "target"}
+
+
+def test_re_score_mark_orders_entities():
+    text = "Object met Subject"
+    subject = {"start": 11, "end": 18}
+    obj = {"start": 0, "end": 6}
+    marked = re_score.mark(text, subject, obj)
+    assert marked.startswith("[E1]Object[/E1] met [E2]Subject[/E2]")
+
+
+def test_re_score_main_batches(tmp_path, monkeypatch, capsys):
+    cand_path = tmp_path / "candidates.jsonl"
+    onnx_path = tmp_path / "model.onnx"
+    classes_path = tmp_path / "classes.json"
+
+    num_candidates = 33
+    candidates = []
+    for i in range(num_candidates):
+        candidates.append(
+            {
+                "doc_id": f"doc{i}",
+                "sent_start": i,
+                "text": f"Sentence {i}",
+                "subject": {"start": 0, "end": 7},
+                "object": {"start": 9, "end": 12},
+            }
+        )
+    cand_path.write_text("\n".join(json.dumps(c) for c in candidates) + "\n", encoding="utf-8")
+    onnx_path.write_text("placeholder", encoding="utf-8")
+    classes_path.write_text(json.dumps(["NEG", "POS"]), encoding="utf-8")
+
+    class DummyEmbedder:
+        def __init__(self):
+            self.calls = []
+
+        def encode(self, texts, show_progress_bar=False, convert_to_numpy=True):
+            self.calls.append(list(texts))
+            batch = np.arange(len(texts) * re_score.EMBEDDING_DIM, dtype=np.float32)
+            return batch.reshape(len(texts), re_score.EMBEDDING_DIM)
+
+    class DummyInput:
+        def __init__(self):
+            self.name = "input"
+            self.shape = [None, re_score.EMBEDDING_DIM]
+
+    class DummyOutputInfo:
+        def __init__(self, name):
+            self.name = name
+            self.shape = [None, 2]
+            self.type = "tensor(float)"
+
+    class DummySession:
+        def __init__(self, path, providers):
+            self.path = path
+            self.providers = providers
+            self.calls = 0
+
+        def get_inputs(self):
+            return [DummyInput()]
+
+        def get_outputs(self):
+            return [DummyOutputInfo("label"), DummyOutputInfo("prob")]
+
+        def run(self, _, feeds):
+            self.calls += 1
+            probs = np.zeros((1, 2), dtype=np.float32)
+            probs[0, self.calls % 2] = 0.8
+            return [np.array(["label"]), probs]
+
+    dummy_embedder = DummyEmbedder()
+    monkeypatch.setattr(re_score, "get_embedder", lambda: dummy_embedder)
+    monkeypatch.setattr(re_score.ort, "InferenceSession", DummySession)
+
+    monkeypatch.setattr(
+        sys,
+        "argv",
+        [
+            "re_score",
+            str(cand_path),
+            str(onnx_path),
+            str(classes_path),
+        ],
+    )
+
+    re_score.main()
+
+    captured = capsys.readouterr()
+    lines = [json.loads(line) for line in captured.out.splitlines()]
+    assert len(lines) == num_candidates
+    assert {rec["pred"] for rec in lines} <= {"NEG", "POS"}
+    assert dummy_embedder.calls  # ensure embeddings were requested
diff --git a/tests/test_re_modules.py b/tests/test_re_modules.py
index 7a42503..befaac4 100644
--- a/tests/test_re_modules.py
+++ b/tests/test_re_modules.py
@@ -5,113 +5,7 @@
 import pytest
 
-if "numpy" not in sys.modules:
-    fake_np = types.ModuleType("numpy")
-
-    class FakeArray:
-        def __init__(self, data, dtype=None):
-            self.data = data
-            self.dtype = dtype
-            if isinstance(data, list) and len(data) > 0 and isinstance(data[0], list):
-                self.shape = (len(data), len(data[0]))
-            elif isinstance(data, list):
-                self.shape = (len(data),)
-            else:
-                self.shape = ()
-
-        def reshape(self, *shape):
-            # Simple reshape - just return a new FakeArray with new shape
-            flat = self._flatten()
-            if len(shape) == 1:
-                if isinstance(shape[0], tuple):
-                    new_shape = shape[0]
-                else:
-                    # Handle reshape(1, -1) case
-                    if shape[0] == 1:
-                        return FakeArray([flat], self.dtype)
-                    new_shape = shape[0]
-            elif len(shape) == 2:
-                # Handle reshape(1, -1) case
-                if shape[0] == 1:
-                    return FakeArray([flat], self.dtype)
-                new_shape = shape
-            else:
-                new_shape = shape
-            return FakeArray(flat, self.dtype)
-
-        def _flatten(self):
-            result = []
-            for item in self.data:
-                if isinstance(item, list):
-                    result.extend(item)
-                else:
-                    result.append(item)
-            return result
-
-        def astype(self, dtype):
-            return FakeArray(self.data, dtype)
-
-        def flatten(self):
-            return FakeArray(self._flatten(), self.dtype)
-
-        def __getitem__(self, idx):
-            item = self.data[idx]
-            # If item is a list, wrap it in FakeArray for proper method access
-            if isinstance(item, list):
-                return FakeArray(item, self.dtype)
-            return item
-
-        def __setitem__(self, idx, value):
-            if isinstance(idx, tuple):
-                # Handle 2D indexing like probs[0][1] = 0.9
-                self.data[idx[0]][idx[1]] = value
-            else:
-                self.data[idx] = value
-
-        def __len__(self):
-            return len(self.data)
-
-        def __iter__(self):
-            for item in self.data:
-                if isinstance(item, list):
-                    yield FakeArray(item, self.dtype)
-                else:
-                    yield item
-
-    def array(data, dtype=None):
-        return FakeArray(data, dtype)
-
-    def argmax(seq):
-        if hasattr(seq, '__len__') and len(seq) > 0:
-            if hasattr(seq[0], '__len__'):
-                # 2D array, get argmax of first row
-                return max(range(len(seq[0])), key=lambda i: seq[0][i])
-            return max(range(len(seq)), key=lambda i: seq[i])
-        return 0
-
-    def zeros(shape, dtype=float):
-        if isinstance(shape, tuple) and len(shape) == 2:
-            rows, cols = shape
-            return FakeArray([[dtype(0) for _ in range(cols)] for _ in range(rows)], dtype)
-        elif isinstance(shape, tuple) and len(shape) == 1:
-            return FakeArray([dtype(0)] * shape[0], dtype)
-        else:
-            rows, cols = shape
-            return FakeArray([[dtype(0) for _ in range(cols)] for _ in range(rows)], dtype)
-
-    def asarray(data, dtype=None):
-        if isinstance(data, FakeArray):
-            return data
-        return FakeArray(data, dtype)
-
-    fake_np.array = array
-    fake_np.argmax = argmax
-    fake_np.zeros = zeros
-    fake_np.asarray = asarray
-    fake_np.isscalar = lambda value: isinstance(value, (int, float))
-    fake_np.float32 = float
-    fake_np.float64 = float
-    sys.modules["numpy"] = fake_np
+np = pytest.importorskip("numpy")
 
 from groundkg import re_infer, re_score
@@ -160,10 +54,10 @@ def get_outputs(self):
 
         def run(self, _outputs, feeds):
             assert isinstance(feeds, dict)
-            probs = fake_np.zeros((1, len(classes)), dtype=float)
+            probs = np.zeros((1, len(classes)), dtype=float)
             probs[0][1] = 0.9
-            # Return FakeArray objects to match ONNX output format
-            return [fake_np.array(["uses"]), probs]
+            # Return numpy arrays to match ONNX output format
+            return [np.array(["uses"]), probs]
 
     # Mock embedder to return fake embeddings (384 dims for all-MiniLM-L6-v2)
     class FakeEmbedder:
@@ -171,7 +65,7 @@ def encode(self, texts, show_progress_bar=False, convert_to_numpy=True):
             # Return fake embeddings: one per text, each 384 dimensions
             if isinstance(texts, str):
                 texts = [texts]
-            return fake_np.array([[0.1] * 384 for _ in texts])
+            return np.array([[0.1] * 384 for _ in texts], dtype=float)
 
     monkeypatch.setattr(re_score, "get_embedder", lambda: FakeEmbedder())
     monkeypatch.setattr(re_score.ort, "InferenceSession", lambda *a, **k: FakeSession())
@@ -273,7 +167,7 @@ def get_inputs(self):
         return [FakeInput()]
 
     def run(self, *_args, **_kwargs):
-        probs = fake_np.zeros((1, len(classes)), dtype=float)
+        probs = np.zeros((1, len(classes)), dtype=float)
         probs[0][uses_idx] = 0.92
         labels = ["uses"]
         return [labels, probs]