-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinferred_relations.py
More file actions
97 lines (79 loc) · 3.8 KB
/
inferred_relations.py
File metadata and controls
97 lines (79 loc) · 3.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from functools import lru_cache
from typing import List, Tuple, Dict, Iterable, Optional
from gda_hash import GdaPositionIndex, gda_hash
class InferredRelations:
    """Compatibility shim backed by position-hash inference.

    Lightweight implementation focused on the methods we need for
    UnForkCollection and basic Chroma compatibility.

    Now uses GDA 24-bit hash for 3x faster lookups.
    """

    def __init__(self, token_positions: Dict[str, List[int]], corpus_tokens: List[str],
                 cache_size: int = 4096, use_gda: bool = True):
        """Build the position index.

        Args:
            token_positions: mapping of token -> positions into corpus_tokens.
                NOTE(review): every lookup path lower-cases its key (see
                get_positions), so entries whose keys are not already
                lower-case are unreachable — confirm callers always pass
                lower-cased keys.
            corpus_tokens: the flat token stream the positions refer into.
            cache_size: maximum entries held by the per-instance
                relation-inference cache.
            use_gda: if True, mirror all positions into a GdaPositionIndex
                for faster reads.
        """
        self.corpus_tokens = corpus_tokens
        self.use_gda = use_gda
        if use_gda:
            # Fast GDA hash index; migrate every existing position into it.
            self.gda_index = GdaPositionIndex()
            for token, positions in token_positions.items():
                for pos in positions:
                    self.gda_index.add(token, pos)
            # Keep the plain dict too for compatibility with older callers.
            self.token_positions = token_positions
        else:
            self.token_positions = token_positions
            self.gda_index = None
        # Per-instance LRU cache wrapped around the inference core.
        # Wrapping the bound method here (instead of decorating the class
        # method) keeps the cache per-instance and lets cache_size vary.
        self._cache = lru_cache(maxsize=cache_size)(self._infer_relations_cached)

    def add_token(self, token: str, position: int):
        """Record that *token* occurs at *position* (key is lower-cased)."""
        key = token.lower()
        if self.use_gda and self.gda_index:
            self.gda_index.add(key, position)
        self.token_positions.setdefault(key, []).append(position)

    def get_positions(self, token: str) -> List[int]:
        """Return recorded positions for *token* (case-insensitive lookup)."""
        key = token.lower()
        if self.use_gda and self.gda_index:
            return self.gda_index.get(key)
        return self.token_positions.get(key, [])

    def infer_relations_from_positions(self, query_tokens: Iterable[str],
                                       window_size: int = 5,
                                       invert: bool = False) -> List[Tuple[str, str, str]]:
        """Infer (subject, predicate, object) triples around each query token.

        Args:
            query_tokens: tokens to look up (matched case-insensitively).
            window_size: how many neighbours to scan on one side of a hit.
            invert: scan the window *before* each hit and emit the neighbour
                as the subject instead of the object.

        Returns:
            List of (subject, predicate, object) string triples.
        """
        # Lower-case the key up front: the inference core lower-cases the
        # tokens anyway, so "Foo" and "foo" would otherwise burn two cache
        # slots on identical results.
        key = (tuple(t.lower() for t in query_tokens), window_size, invert)
        return list(self._cache(key))

    def _infer_relations_cached(self, key):
        """Inference core; takes one hashable key so lru_cache can wrap it."""
        query_tokens, window_size, invert = key
        corpus_len = len(self.corpus_tokens)  # loop-invariant, hoisted
        rels = []
        for token in query_tokens:
            token_key = token.lower()
            positions = self.get_positions(token_key)
            if not positions:
                continue
            for pos in positions:
                if invert:
                    # Window *preceding* the hit.
                    neighbor_range = range(max(0, pos - window_size), pos)
                else:
                    # Window *following* the hit, clamped to the corpus end.
                    neighbor_range = range(pos + 1, min(corpus_len, pos + window_size + 1))
                for npos in neighbor_range:
                    neighbor_token = self.corpus_tokens[npos]
                    distance = abs(npos - pos)
                    # Predicate strength decays with distance from the hit.
                    if distance == 1:
                        predicate = "NEXT_TO"
                    elif distance <= 2:
                        predicate = "NEAR"
                    elif distance <= 5:
                        predicate = "RELATED"
                    else:
                        predicate = "DISTANT"
                    if invert:
                        rels.append((neighbor_token, predicate, token_key))
                    else:
                        rels.append((token_key, predicate, neighbor_token))
        # Tuple so the cached value is immutable.
        return tuple(rels)

    # Compatibility helpers
    @property
    def triples(self):
        # Expensive on large corpora - return empty by default to avoid surprises
        return []

    def by_subject(self, subject: str):
        """Triples with *subject* as subject (forward window, size 5)."""
        return self.infer_relations_from_positions([subject], window_size=5)

    def by_object(self, obj: str):
        """Triples with *obj* as object (backward window, size 5)."""
        return self.infer_relations_from_positions([obj], window_size=5, invert=True)