Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 194 additions & 0 deletions tests/test_event_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
import json
import sys

import pytest

np = pytest.importorskip("numpy")

from groundkg import event_extract, events_to_edges, re_score


def test_event_extract_main_generates_events(tmp_path, monkeypatch):
    """End-to-end run of ``event_extract.main`` over a five-document manifest.

    Checks that each supported event type appears in the output and spot-checks
    role, amount, and date extraction for representative sentences.
    """
    manifest_path = tmp_path / "manifest.jsonl"
    output_path = tmp_path / "events.jsonl"

    rows = [
        {"doc_id": "doc1", "text": "MegaCorp acquired StartUp for $5 million on Jan 2, 2022."},
        {"doc_id": "doc2", "text": "Bright Future secured $3M from Big VC on Feb 5, 2021."},
        {"doc_id": "doc3", "text": "Tech Corp launched HyperWidget on 2023."},
        {
            "doc_id": "doc4",
            "source_org": "ACME",
            "text": "ACME appointed Jane Doe as CTO on Mar 3, 2020.",
        },
        {"doc_id": "doc5", "text": "John Smith founded Future Labs in 2019."},
    ]
    with manifest_path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row) + "\n")

    # Drive main() through its CLI entry point.
    argv = ["event_extract", "--manifest", str(manifest_path), "--out", str(output_path)]
    monkeypatch.setattr(sys, "argv", argv)
    event_extract.main()

    emitted = [json.loads(raw) for raw in output_path.read_text(encoding="utf-8").splitlines()]

    expected_types = {"Acquisition", "Funding", "Launch", "Appointment", "Founding"}
    assert expected_types <= {ev["type"] for ev in emitted}

    # Map each type to its first emitted event (presence guaranteed above).
    first_of_type = {}
    for ev in emitted:
        first_of_type.setdefault(ev["type"], ev)

    acquisition = first_of_type["Acquisition"]
    assert acquisition["roles"].get("acquirer") == "MegaCorp"
    assert acquisition["roles"].get("target", "").startswith("StartUp")
    assert acquisition["amount_text"] == "$5 million"
    assert acquisition["date_text"] == "Jan 2, 2022"

    assert first_of_type["Funding"]["roles"]["recipient"] == "Bright Future"

    appointments = [ev for ev in emitted if ev["type"] == "Appointment"]
    assert any(ev["roles"].get("actor") == "ACME" for ev in appointments)

    assert "founder_or_actor" in first_of_type["Founding"]["roles"]


def test_events_to_edges_main(tmp_path, monkeypatch):
    """``events_to_edges.main`` should fan one event record out into edges.

    Every emitted edge hangs off the ``event:<id>`` subject, and empty role
    values (here the ``"empty"`` role) must not produce a predicate.
    """
    source_file = tmp_path / "events.jsonl"
    target_file = tmp_path / "edges.jsonl"

    event = {
        "event_id": "E1",
        "type": "Acquisition",
        "trigger": "acquired",
        "date_text": "Jan 2, 2022",
        "amount_text": "$5 million",
        "roles": {"acquirer": "MegaCorp", "target": "StartUp", "empty": ""},
        "confidence": 0.75,
        "source": "doc1#s",
    }
    source_file.write_text(json.dumps(event) + "\n", encoding="utf-8")

    argv = ["events_to_edges", "--events", str(source_file), "--out", str(target_file)]
    monkeypatch.setattr(sys, "argv", argv)

    events_to_edges.main()

    parsed = [json.loads(raw) for raw in target_file.read_text(encoding="utf-8").splitlines()]
    assert {edge["subject"] for edge in parsed} == {"event:E1"}
    assert {edge["predicate"] for edge in parsed} == {
        "type",
        "trigger",
        "date",
        "amount",
        "acquirer",
        "target",
    }


def test_re_score_mark_orders_entities():
    """``re_score.mark`` tags spans in document order, not by role.

    "Object" (chars 0-6) precedes "Subject" (chars 11-18), so the earlier
    span should receive the [E1] markers even though it is the object.
    """
    sentence = "Object met Subject"
    subject_span = {"start": 11, "end": 18}
    object_span = {"start": 0, "end": 6}

    tagged = re_score.mark(sentence, subject_span, object_span)

    assert tagged.startswith("[E1]Object[/E1] met [E2]Subject[/E2]")


def test_re_score_main_batches(tmp_path, monkeypatch, capsys):
    """Run ``re_score.main`` over 33 candidates with stubbed embedder/ONNX session.

    33 is deliberately not a multiple of a typical batch size, so a final
    partial batch must also be scored; every candidate must yield exactly one
    output line on stdout.
    """
    cand_path = tmp_path / "candidates.jsonl"
    onnx_path = tmp_path / "model.onnx"
    classes_path = tmp_path / "classes.json"

    num_candidates = 33
    candidates = []
    for i in range(num_candidates):
        candidates.append(
            {
                "doc_id": f"doc{i}",
                "sent_start": i,
                "text": f"Sentence {i}",
                "subject": {"start": 0, "end": 7},
                "object": {"start": 9, "end": 12},
            }
        )
    cand_path.write_text("\n".join(json.dumps(c) for c in candidates) + "\n", encoding="utf-8")
    # The session is fully stubbed below, so the model file content is irrelevant.
    onnx_path.write_text("placeholder", encoding="utf-8")
    classes_path.write_text(json.dumps(["NEG", "POS"]), encoding="utf-8")

    class DummyEmbedder:
        """Stand-in for the sentence embedder; records every batch of texts."""

        def __init__(self):
            self.calls = []

        def encode(self, texts, show_progress_bar=False, convert_to_numpy=True):
            self.calls.append(list(texts))
            # Deterministic ascending values shaped (len(texts), EMBEDDING_DIM).
            batch = np.arange(len(texts) * re_score.EMBEDDING_DIM, dtype=np.float32)
            return batch.reshape(len(texts), re_score.EMBEDDING_DIM)

    class DummyInput:
        """Mimics an onnxruntime NodeArg for the model's single input."""

        def __init__(self):
            self.name = "input"
            self.shape = [None, re_score.EMBEDDING_DIM]

    class DummyOutputInfo:
        """Mimics an onnxruntime NodeArg for a model output."""

        def __init__(self, name):
            self.name = name
            self.shape = [None, 2]
            self.type = "tensor(float)"

    class DummySession:
        """Fake InferenceSession whose predictions alternate per call."""

        def __init__(self, path, providers):
            self.path = path
            self.providers = providers
            self.calls = 0

        def get_inputs(self):
            return [DummyInput()]

        def get_outputs(self):
            return [DummyOutputInfo("label"), DummyOutputInfo("prob")]

        def run(self, _, feeds):
            # Alternate the winning class index (0/1) on successive calls so
            # both "NEG" and "POS" predictions can appear in the output.
            self.calls += 1
            probs = np.zeros((1, 2), dtype=np.float32)
            probs[0, self.calls % 2] = 0.8
            return [np.array(["label"]), probs]

    dummy_embedder = DummyEmbedder()
    monkeypatch.setattr(re_score, "get_embedder", lambda: dummy_embedder)
    monkeypatch.setattr(re_score.ort, "InferenceSession", DummySession)

    # Positional CLI args: candidates file, model path, classes file.
    monkeypatch.setattr(
        sys,
        "argv",
        [
            "re_score",
            str(cand_path),
            str(onnx_path),
            str(classes_path),
        ],
    )

    re_score.main()

    # main() writes one JSON record per candidate to stdout.
    captured = capsys.readouterr()
    lines = [json.loads(line) for line in captured.out.splitlines()]
    assert len(lines) == num_candidates
    assert {rec["pred"] for rec in lines} <= {"NEG", "POS"}
    assert dummy_embedder.calls  # ensure embeddings were requested
118 changes: 6 additions & 112 deletions tests/test_re_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,113 +5,7 @@

import pytest

# Legacy shim: install a minimal fake "numpy" module when the real one is
# absent, providing just enough surface (array/zeros/argmax/asarray) for the
# RE-scoring tests. NOTE(review): removed in this diff in favor of
# pytest.importorskip("numpy").
if "numpy" not in sys.modules:
    fake_np = types.ModuleType("numpy")

    class FakeArray:
        # Minimal ndarray stand-in backed by nested Python lists.
        def __init__(self, data, dtype=None):
            self.data = data
            self.dtype = dtype
            # Infer shape: 2-D for list-of-lists, 1-D for a flat list,
            # scalar-like otherwise.
            if isinstance(data, list) and len(data) > 0 and isinstance(data[0], list):
                self.shape = (len(data), len(data[0]))
            elif isinstance(data, list):
                self.shape = (len(data),)
            else:
                self.shape = ()

        def reshape(self, *shape):
            # Simple reshape - just return a new FakeArray with new shape
            flat = self._flatten()
            if len(shape) == 1:
                if isinstance(shape[0], tuple):
                    new_shape = shape[0]
                else:
                    # Handle reshape(1, -1) case
                    if shape[0] == 1:
                        return FakeArray([flat], self.dtype)
                    new_shape = shape[0]
            elif len(shape) == 2:
                # Handle reshape(1, -1) case
                if shape[0] == 1:
                    return FakeArray([flat], self.dtype)
                new_shape = shape
            else:
                new_shape = shape
            # NOTE(review): new_shape is computed but never applied; every
            # non-(1, -1) reshape just returns the flattened data.
            return FakeArray(flat, self.dtype)

        def _flatten(self):
            # Flatten one level of nesting into a single flat list.
            result = []
            for item in self.data:
                if isinstance(item, list):
                    result.extend(item)
                else:
                    result.append(item)
            return result

        def astype(self, dtype):
            # Only records the requested dtype; values are not converted.
            return FakeArray(self.data, dtype)

        def flatten(self):
            return FakeArray(self._flatten(), self.dtype)

        def __getitem__(self, idx):
            item = self.data[idx]
            # If item is a list, wrap it in FakeArray for proper method access
            if isinstance(item, list):
                return FakeArray(item, self.dtype)
            return item

        def __setitem__(self, idx, value):
            if isinstance(idx, tuple):
                # Handle 2D indexing like probs[0][1] = 0.9
                self.data[idx[0]][idx[1]] = value
            else:
                self.data[idx] = value

        def __len__(self):
            return len(self.data)

        def __iter__(self):
            # Wrap nested rows so iteration over a 2-D array yields FakeArrays.
            for item in self.data:
                if isinstance(item, list):
                    yield FakeArray(item, self.dtype)
                else:
                    yield item

    def array(data, dtype=None):
        return FakeArray(data, dtype)

    def argmax(seq):
        # Index of the maximum; for a 2-D input only the first row is scanned.
        if hasattr(seq, '__len__') and len(seq) > 0:
            if hasattr(seq[0], '__len__'):
                # 2D array, get argmax of first row
                return max(range(len(seq[0])), key=lambda i: seq[0][i])
            return max(range(len(seq)), key=lambda i: seq[i])
        return 0

    def zeros(shape, dtype=float):
        # NOTE(review): the final else assumes `shape` is an unpackable pair,
        # so a bare-int shape would raise TypeError here.
        if isinstance(shape, tuple) and len(shape) == 2:
            rows, cols = shape
            return FakeArray([[dtype(0) for _ in range(cols)] for _ in range(rows)], dtype)
        elif isinstance(shape, tuple) and len(shape) == 1:
            return FakeArray([dtype(0)] * shape[0], dtype)
        else:
            rows, cols = shape
            return FakeArray([[dtype(0) for _ in range(cols)] for _ in range(rows)], dtype)

    def asarray(data, dtype=None):
        if isinstance(data, FakeArray):
            return data
        return FakeArray(data, dtype)

    # Publish the shim API and register it so `import numpy` finds it.
    fake_np.array = array
    fake_np.argmax = argmax
    fake_np.zeros = zeros
    fake_np.asarray = asarray
    fake_np.isscalar = lambda value: isinstance(value, (int, float))
    fake_np.float32 = float
    fake_np.float64 = float
    sys.modules["numpy"] = fake_np
np = pytest.importorskip("numpy")

from groundkg import re_infer, re_score

Expand Down Expand Up @@ -160,18 +54,18 @@ def get_outputs(self):

def run(self, _outputs, feeds):
assert isinstance(feeds, dict)
probs = fake_np.zeros((1, len(classes)), dtype=float)
probs = np.zeros((1, len(classes)), dtype=float)
probs[0][1] = 0.9
# Return FakeArray objects to match ONNX output format
return [fake_np.array(["uses"]), probs]
# Return numpy arrays to match ONNX output format
return [np.array(["uses"]), probs]

# Mock embedder to return fake embeddings (384 dims for all-MiniLM-L6-v2)
class FakeEmbedder:
def encode(self, texts, show_progress_bar=False, convert_to_numpy=True):
# Return fake embeddings: one per text, each 384 dimensions
if isinstance(texts, str):
texts = [texts]
return fake_np.array([[0.1] * 384 for _ in texts])
return np.array([[0.1] * 384 for _ in texts], dtype=float)

monkeypatch.setattr(re_score, "get_embedder", lambda: FakeEmbedder())
monkeypatch.setattr(re_score.ort, "InferenceSession", lambda *a, **k: FakeSession())
Expand Down Expand Up @@ -273,7 +167,7 @@ def get_inputs(self):
return [FakeInput()]

def run(self, *_args, **_kwargs):
probs = fake_np.zeros((1, len(classes)), dtype=float)
probs = np.zeros((1, len(classes)), dtype=float)
probs[0][uses_idx] = 0.92
labels = ["uses"]
return [labels, probs]
Expand Down