-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchroma_compat.py
More file actions
169 lines (137 loc) · 6.96 KB
/
chroma_compat.py
File metadata and controls
169 lines (137 loc) · 6.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from typing import List, Dict, Optional, Iterable
from inferred_relations import InferredRelations
from synsets import Synsets
from code_tokenizer import tokenize_universal, is_prose
class UnForkClient:
    """Chroma-style client that keeps named ``UnForkCollection`` objects in memory.

    Collections live only in ``self.collections`` (name -> collection); nothing
    is persisted.
    """

    def __init__(self, settings: Optional[Dict] = None):
        # ``settings`` is accepted for Chroma API compatibility and ignored.
        self.collections: Dict[str, UnForkCollection] = {}
        self._use_corpus_as_store = True

    def create_collection(self, name: str, metadata: Optional[Dict] = None, embedding_function=None, get_or_create: bool = False):
        """Create (or, with ``get_or_create=True``, fetch) a collection.

        Args:
            name: Collection name.
            metadata: Accepted for Chroma compatibility; currently ignored.
            embedding_function: Accepted for Chroma compatibility; ignored
                (retrieval is token based).
            get_or_create: Return the existing collection instead of
                creating a new one when ``name`` is already present.

        Returns:
            The new or existing ``UnForkCollection``.

        Note:
            When ``get_or_create`` is False and ``name`` already exists, the
            existing collection is replaced by an empty one and its documents
            are discarded.
        """
        if get_or_create and name in self.collections:
            return self.collections[name]
        col = UnForkCollection(name=name, use_corpus_as_store=self._use_corpus_as_store)
        self.collections[name] = col
        return col

    def get_or_create_collection(self, name: str, metadata: Optional[Dict] = None, embedding_function=None):
        """Return collection ``name``, creating it first if it does not exist.

        This is shorthand for ``create_collection(..., get_or_create=True)``.
        It mirrors Chroma's client method of the same name so that Chroma code
        ports unchanged.
        """
        return self.create_collection(
            name,
            metadata=metadata,
            embedding_function=embedding_function,
            get_or_create=True,
        )

    def get_collection(self, name: str):
        """Return collection ``name``.

        Raises:
            ValueError: if no collection with that name exists.
        """
        if name not in self.collections:
            raise ValueError(f"Collection '{name}' does not exist")
        return self.collections[name]

    def list_collections(self):
        """Return the names of all collections, in creation order."""
        return list(self.collections.keys())

    def delete_collection(self, name: str):
        """Drop collection ``name``; a missing name is silently ignored."""
        if name in self.collections:
            del self.collections[name]
class UnForkCollection:
    """In-memory, Chroma-style collection backed by a positional token index.

    How storage works:

    * Every document is tokenized and its tokens are appended to one shared
      ``corpus_tokens`` list.
    * ``token_positions`` maps each lower-cased token to its corpus positions.
    * ``doc_positions`` maps each document id to its half-open ``(start, end)``
      span of the corpus.

    How queries work: ``self.relations`` is asked for relations among the
    synset-expanded query tokens. Each document is then scored by how many
    related positions fall inside its span.

    Note: ``query`` reports these overlap counts under Chroma's
    ``'distances'`` key, but here a HIGHER value means a BETTER match (results
    are sorted descending).
    """

    def __init__(self, name: str, use_corpus_as_store: bool = True):
        self.name = name
        self.use_corpus_as_store = use_corpus_as_store
        # Lower-cased token -> list of positions in corpus_tokens.
        self.token_positions: Dict[str, List[int]] = {}
        # All indexed tokens in insertion order. The list is never compacted:
        # tokens of deleted or re-indexed documents stay here.
        self.corpus_tokens: List[str] = []
        self.doc_store: Dict[str, str] = {}
        self.metadata_store: Dict[str, Dict] = {}
        # Doc id -> (start, end) half-open span into corpus_tokens.
        self.doc_positions: Dict[str, tuple] = {}
        # Synset expansion support (empty until user loads file)
        self.synsets = Synsets()
        # Explicit (non-corpus) relations are not implemented yet, so both
        # storage modes use relations inferred from the corpus.
        self.relations = InferredRelations(self.token_positions, self.corpus_tokens)

    @staticmethod
    def _one_or_many(value):
        """Wrap a lone ``str`` or ``dict`` in a list; return anything else unchanged.

        Chroma accepts either one item or a list, and iterating a bare string
        would otherwise treat each character as a separate id or document.
        """
        return [value] if isinstance(value, (str, dict)) else value

    def add(self, ids: List[str] = None, documents: List[str] = None, metadatas: List[Dict] = None, embeddings=None):
        """Store documents and index their tokens.

        Args:
            ids: Document ids (one or many). Must match ``documents`` in length.
            documents: Raw document texts (one or many).
            metadatas: Optional metadata dicts aligned with ``ids``.
            embeddings: Accepted for Chroma compatibility and ignored.

        Raises:
            ValueError: if ids/documents are missing, their lengths differ, or
                ``metadatas`` has fewer entries than ``ids``.

        Re-adding an existing id replaces its text and re-indexes it. Its old
        metadata is kept unless new metadata is supplied.
        """
        ids = self._one_or_many(ids)
        documents = self._one_or_many(documents)
        metadatas = self._one_or_many(metadatas)
        if not ids or not documents:
            raise ValueError("Both 'ids' and 'documents' are required")
        if len(ids) != len(documents):
            raise ValueError("ids and documents must have same length")
        # Check before mutating anything. Otherwise a short list would raise
        # IndexError after earlier documents had already been stored.
        if metadatas and len(metadatas) < len(ids):
            raise ValueError("metadatas must have one entry per id")
        for idx, (doc_id, doc_text) in enumerate(zip(ids, documents)):
            self.doc_store[doc_id] = doc_text
            if metadatas:
                self.metadata_store[doc_id] = metadatas[idx]
            self._index_document(doc_id, doc_text)

    def _index_document(self, doc_id: str, text: str) -> None:
        """Append ``text``'s tokens to the corpus and point ``doc_id`` at them."""
        tokens = self._tokenize(text)
        start_pos = len(self.corpus_tokens)
        self.corpus_tokens.extend(tokens)
        for offset, token in enumerate(tokens):
            key = token.lower()
            pos = start_pos + offset
            self.token_positions.setdefault(key, []).append(pos)
            # Also add to GDA index for fast lookups
            self.relations.add_token(key, pos)
        # Any span previously recorded for doc_id is abandoned. Its tokens
        # remain in the corpus but no longer fall inside any document span.
        self.doc_positions[doc_id] = (start_pos, start_pos + len(tokens))

    def query(self, query_texts: Iterable[str] = None, query_embeddings=None, n_results: int = 10, where=None, where_document=None, include=None):
        """Rank documents for each query text.

        Args:
            query_texts: One query string or an iterable of them.
            query_embeddings, where, where_document, include: Accepted for
                Chroma compatibility and ignored.
            n_results: Maximum number of hits returned per query.

        Returns:
            A dict with keys ``'ids'``, ``'documents'``, ``'metadatas'`` and
            ``'distances'``. Each value holds one inner list per query text,
            in query order. ``'distances'`` holds overlap scores, where higher
            means a better match. With no query text, every key maps to
            ``[[]]``.
        """
        if isinstance(query_texts, str):
            query_texts = [query_texts]
        # list() also accepts generators and other non-indexable iterables.
        texts = list(query_texts) if query_texts is not None else []
        if not texts:
            return {'ids': [[]], 'documents': [[]], 'metadatas': [[]], 'distances': [[]]}

        result = {'ids': [], 'documents': [], 'metadatas': [], 'distances': []}
        for query_text in texts:
            ranked = self._rank(query_text, n_results)
            result_ids = [doc_id for doc_id, _ in ranked]
            result['ids'].append(result_ids)
            result['documents'].append([self.doc_store.get(doc_id, "") for doc_id in result_ids])
            result['metadatas'].append([self.metadata_store.get(doc_id, {}) for doc_id in result_ids])
            result['distances'].append([score for _, score in ranked])
        return result

    def _rank(self, query_text: str, n_results: int):
        """Return up to ``n_results`` ``(doc_id, score)`` pairs, best first."""
        tokens = self._tokenize(query_text)
        # Expand query tokens using synsets to capture semantic variants
        expanded = set(self.synsets.expand_words(tokens)) | set(tokens)
        # The index is keyed by lower-cased tokens (see _index_document), so
        # normalize the query the same way. Sort for a deterministic order,
        # because str set iteration order changes between interpreter runs
        # (hash randomization).
        query_keys = sorted({token.lower() for token in expanded})
        inferred = self.relations.infer_relations_from_positions(query_keys, window_size=5)
        doc_scores = self._score_documents(inferred)
        # sorted() is stable, so equal scores keep doc_positions order.
        return sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[:n_results]

    def load_synsets_from_file(self, filepath: str) -> int:
        """Load synsets from a file (Moby-format CSV lines) into this collection.
        Returns number of groups added.
        """
        return self.synsets.load_from_file(filepath)

    def peek(self, limit: int = 10):
        """Return the first ``limit`` stored documents, in insertion order."""
        doc_ids = list(self.doc_store.keys())[:limit]
        return {
            'ids': doc_ids,
            'documents': [self.doc_store[d] for d in doc_ids],
            'metadatas': [self.metadata_store.get(d, {}) for d in doc_ids]
        }

    def count(self):
        """Return the number of stored documents."""
        return len(self.doc_store)

    def get(self, ids=None, where=None, limit=None):
        """Fetch documents by id, or every document when no filter is given.

        Args:
            ids: One id or a list of ids. Unknown ids yield ``""`` for the
                document and ``{}`` for the metadata.
            where: Metadata filter; not implemented (ignored when ``ids`` is
                given).
            limit: Optional maximum number of entries returned.

        Raises:
            NotImplementedError: if ``where`` is given without ``ids``.
        """
        ids = self._one_or_many(ids)
        if ids is not None:
            selected = list(ids)
        elif where is not None:
            raise NotImplementedError("Filtering by 'where' not yet implemented")
        else:
            selected = list(self.doc_store)
        if limit is not None:
            selected = selected[:limit]
        return {
            'ids': selected,
            'documents': [self.doc_store.get(doc_id, "") for doc_id in selected],
            'metadatas': [self.metadata_store.get(doc_id, {}) for doc_id in selected]
        }

    def delete(self, ids=None, where=None):
        """Remove documents by id; unknown ids are ignored.

        Token positions are left in the corpus but no longer belong to any
        document, so they cannot score.

        Raises:
            NotImplementedError: if ``where`` is given without ``ids``, rather
                than silently deleting nothing.
        """
        ids = self._one_or_many(ids)
        if ids:
            for doc_id in ids:
                self.doc_store.pop(doc_id, None)
                self.metadata_store.pop(doc_id, None)
                self.doc_positions.pop(doc_id, None)
        elif where is not None:
            raise NotImplementedError("Filtering by 'where' not yet implemented")

    def update(self, ids, documents=None, metadatas=None):
        """Replace the text and/or metadata of the given ids.

        New text is re-tokenized and re-indexed, so later queries match the
        new content rather than the old. Entries beyond the length of
        ``documents`` or ``metadatas`` are left unchanged for that field.
        """
        ids = self._one_or_many(ids)
        documents = self._one_or_many(documents)
        metadatas = self._one_or_many(metadatas)
        for idx, doc_id in enumerate(ids):
            if documents and idx < len(documents):
                self.doc_store[doc_id] = documents[idx]
                self._index_document(doc_id, documents[idx])
            if metadatas and idx < len(metadatas):
                self.metadata_store[doc_id] = metadatas[idx]

    def _tokenize(self, text: str):
        """Universal tokenizer - handles code and prose automatically."""
        return tokenize_universal(text)

    def _score_documents(self, inferred_relations: Iterable[tuple]):
        """Score each document by how many related positions fall in its span.

        Every member of every relation tuple of length >= 2 is treated as a
        corpus position.

        If there are no such positions at all, every document is returned,
        scored by reverse insertion order (the earliest document scores
        highest). Otherwise only documents with a non-zero overlap are
        returned.
        """
        scores = {}
        # Convert relations to a set of positions for O(1) lookup
        relation_positions = set()
        for rel in inferred_relations:
            if len(rel) >= 2:
                # Add all positions from the relation tuple
                relation_positions.update(rel)
        if not relation_positions:
            # No relations found, score by document order
            for i, doc_id in enumerate(self.doc_positions.keys()):
                scores[doc_id] = len(self.doc_positions) - i
            return scores
        # Score documents based on how many relation positions they contain
        for doc_id, (start, end) in self.doc_positions.items():
            doc_positions = set(range(start, end))
            overlap = len(relation_positions & doc_positions)
            if overlap > 0:
                scores[doc_id] = overlap
        return scores