-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvector_index.py
More file actions
333 lines (288 loc) · 12.2 KB
/
vector_index.py
File metadata and controls
333 lines (288 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
"""
Vector index over RAG chunks (small sub-chunks of AI chunks).
Used for Fast RAG: embed queries, retrieve top-K RAG chunks, resolve parent AI chunks, then extract.
"""
import logging
from typing import List, Dict, Optional, Any
import numpy as np
from openai import OpenAI
import config
from chunker import count_tokens
def _get_embedding_client() -> Optional[OpenAI]:
    """Return an OpenAI-compatible embeddings client, or None when unavailable.

    Backend is selected by config.EMBEDDING_API_TYPE:
      - "lmstudio": uses config.get_embedding_lm_studio_config(), which may
        point at a dedicated LM Studio instance for the embedding model.
      - anything else (default "openai"): requires config.OPENAI_API_KEY;
        returns None when the key is missing.
    Any error while building the client is treated as "unavailable".
    """
    backend = getattr(config, "EMBEDDING_API_TYPE", "openai").lower()
    try:
        if backend == "lmstudio":
            settings = config.get_embedding_lm_studio_config()
        else:
            if not getattr(config, "OPENAI_API_KEY", ""):
                return None
            settings = config.get_api_config("openai")
        return OpenAI(api_key=settings["api_key"], base_url=settings["base_url"])
    except Exception:
        # Best-effort: callers treat None as "Fast RAG disabled".
        return None
def _subdivide_ai_chunk_into_rag_chunks(
    ai_chunk: Dict[str, Any],
    max_tokens: int,
    model: str,
) -> List[Dict[str, Any]]:
    """Split one AI chunk's text into RAG chunks of at most max_tokens tokens.

    Paragraphs (separated by blank lines) are greedily packed into chunks.
    A paragraph that alone exceeds max_tokens is split at an approximate
    character boundary (~4 chars/token), preferring sentence ends. Every
    emitted RAG chunk carries the parent chunk id and video metadata.

    Args:
        ai_chunk: Dict with "text", "chunk_id" and "videos" keys.
        max_tokens: Upper bound on tokens per RAG chunk.
        model: Tokenizer model name passed to count_tokens.

    Returns:
        List of dicts with "text", "parent_chunk_id" and "videos" keys,
        in document order.
    """
    text = ai_chunk.get("text", "")
    parent_chunk_id = ai_chunk.get("chunk_id", 0)
    videos = ai_chunk.get("videos", [])
    if not text.strip():
        return []

    # Split by paragraphs first; fall back to the whole text if there are
    # no blank-line breaks.
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    if not paragraphs:
        paragraphs = [text.strip()]

    rag_chunks: List[Dict[str, Any]] = []
    current: List[str] = []
    current_tokens = 0

    def _emit(chunk_text: str) -> None:
        # All RAG chunks from this AI chunk share the same parent metadata.
        rag_chunks.append({
            "text": chunk_text,
            "parent_chunk_id": parent_chunk_id,
            "videos": videos,
        })

    for para in paragraphs:
        pt = count_tokens(para, model)
        if pt > max_tokens:
            # BUGFIX: flush the pending buffer first so chunks stay in
            # document order (previously the buffered earlier paragraphs
            # were emitted AFTER this oversized paragraph's segments).
            if current:
                _emit("\n\n".join(current))
                current = []
                current_tokens = 0
            # Single paragraph too large: split by rough char budget
            # (~4 chars/token), preferring to break at a sentence end.
            approx_chunk_chars = max(100, (max_tokens * 4) - 50)
            start = 0
            while start < len(para):
                end = min(start + approx_chunk_chars, len(para))
                for sep in (". ", "! ", "? ", "\n"):
                    idx = para.rfind(sep, start, end + 1)
                    if idx != -1:
                        end = idx + len(sep)
                        break
                seg = para[start:end].strip()
                if seg:
                    _emit(seg)
                start = end
            continue
        if current_tokens + pt > max_tokens and current:
            # Buffer full: emit the accumulated paragraphs as one RAG chunk.
            _emit("\n\n".join(current))
            current = []
            current_tokens = 0
        current.append(para)
        current_tokens += pt

    if current:
        _emit("\n\n".join(current))
    return rag_chunks
def build_rag_chunks(ai_chunks: List[Dict[str, Any]], logger: Optional[logging.Logger] = None) -> List[Dict[str, Any]]:
    """Sub-divide every AI chunk into RAG chunks, assigning sequential rag_chunk_ids."""
    log = logger or logging.getLogger(__name__)
    rag_chunks: List[Dict[str, Any]] = []
    for ac in ai_chunks:
        pieces = _subdivide_ai_chunk_into_rag_chunks(
            ac, config.RAG_CHUNK_MAX_TOKENS, config.EXTRACTION_MODEL
        )
        for rc in pieces:
            # Global id equals the chunk's final position in the list.
            rc["rag_chunk_id"] = len(rag_chunks)
            rag_chunks.append(rc)
    log.info(f"Built {len(rag_chunks)} RAG chunks from {len(ai_chunks)} AI chunks")
    return rag_chunks
def _embed_texts(client: OpenAI, texts: List[str], model: str, logger: logging.Logger) -> Optional[np.ndarray]:
    """Embed texts via the embeddings API. Returns an (N, D) float32 array, or None on error."""
    if not texts:
        # No inputs: true dimension is unknown; return a (0, 1536) placeholder.
        return np.zeros((0, 1536), dtype=np.float32)
    batch_size = 100  # modest request sizes to stay under rate limits
    vectors: List[Any] = []
    for offset in range(0, len(texts), batch_size):
        chunk = texts[offset : offset + batch_size]
        try:
            response = client.embeddings.create(input=chunk, model=model)
            vectors.extend(item.embedding for item in response.data)
        except Exception as e:
            # Any failed batch aborts the whole embedding pass.
            logger.error(f"Embedding batch failed: {e}")
            return None
    return np.array(vectors, dtype=np.float32)
class VectorIndex:
    """In-memory cosine-similarity index over RAG chunks.

    Attributes:
        rag_chunks: RAG chunk dicts, row-parallel to `embeddings`.
        embeddings: (N, D) float32 matrix, one row per RAG chunk.
        logger: Logger for diagnostics.
    """

    def __init__(
        self,
        rag_chunks: List[Dict[str, Any]],
        embeddings: np.ndarray,
        logger: Optional[logging.Logger] = None,
    ):
        self.rag_chunks = rag_chunks
        self.embeddings = embeddings  # (N, D)
        self.logger = logger or logging.getLogger(__name__)

    def search(self, query_embedding: np.ndarray, k: int) -> List[Dict[str, Any]]:
        """Return up to k RAG chunks ranked by cosine similarity.

        Each result is a shallow copy of the RAG chunk dict with a
        "score" key (cosine similarity, float) added.
        """
        if self.embeddings.shape[0] == 0:
            return []
        q = query_embedding.reshape(1, -1).astype(np.float32)
        norms = np.linalg.norm(self.embeddings, axis=1, keepdims=True)
        norms[norms == 0] = 1e-9  # guard all-zero rows against divide-by-zero
        q_norm = np.linalg.norm(q)
        if q_norm == 0:
            # BUGFIX: an all-zero query previously divided by zero and
            # produced NaN scores; guard it like the row norms.
            q_norm = 1e-9
        sim = ((self.embeddings / norms) @ (q / q_norm).T).flatten()
        top = np.argsort(-sim)[: min(k, len(sim))]
        out = []
        for i in top:
            r = dict(self.rag_chunks[i])
            r["score"] = float(sim[i])
            out.append(r)
        return out
def build_index(
    ai_chunks: List[Dict[str, Any]],
    logger: Optional[logging.Logger] = None,
) -> Optional[VectorIndex]:
    """Build RAG chunks from AI chunks, embed them, and return a VectorIndex.

    Returns None when there is nothing to index or when the embedding
    backend is unavailable (e.g. missing API key / LM Studio down).
    """
    log = logger or logging.getLogger(__name__)

    rag_chunks = build_rag_chunks(ai_chunks, log)
    if not rag_chunks:
        log.warning("No RAG chunks to index")
        return None

    client = _get_embedding_client()
    if client is None:
        # Tailor the warning to the configured backend.
        backend = getattr(config, "EMBEDDING_API_TYPE", "openai")
        if backend == "lmstudio":
            log.warning("Embedding client unavailable. Check LM Studio is running and embedding model is loaded.")
        else:
            log.warning("OpenAI API key not set; cannot build Fast RAG index. Use LM Studio embeddings or set OPENAI_API_KEY.")
        return None

    embeddings = _embed_texts(client, [rc["text"] for rc in rag_chunks], config.EMBEDDING_MODEL, log)
    return None if embeddings is None else VectorIndex(rag_chunks, embeddings, log)
def retrieve_chunks_fast(
    user_question: str,
    ai_chunks: List[Dict[str, Any]],
    vector_index: Optional[VectorIndex],
    extraction_client: Any,
    k: Optional[int] = None,
    rag_queries: Optional[List[str]] = None,
    expand_window: Optional[int] = None,
    logger: Optional[logging.Logger] = None,
) -> List[Dict[str, Any]]:
    """
    Fast RAG retrieval: (multi-)query RAG index -> parent AI chunk ids -> fetch AI chunks -> extract -> return extracted_content.

    Args:
        user_question: Original user question (used for extraction and fallback when no rag_queries).
        ai_chunks: List of AI chunks (from rag_engine.chunks).
        vector_index: Built index over RAG chunks. If None, returns [] (caller should fall back to full search).
        extraction_client: LLMClient used for extract_relevant_content.
        k: Top-K RAG chunks per query. Default from config.FAST_RAG_K.
        rag_queries: Optional list of AI-generated RAG queries. If None, use [user_question].
        expand_window: AI chunks before/after to include. Default config.EXPAND_CONTEXT_WINDOW.
        logger: Logger.

    Returns:
        extracted_content list (same shape as rag_query extract_relevant_content_parallel).
    """
    # Local import keeps the thread-pool dependency out of module import time.
    from concurrent.futures import ThreadPoolExecutor, as_completed
    log = logger or logging.getLogger(__name__)
    k = k or config.FAST_RAG_K
    expand_window = expand_window if expand_window is not None else config.EXPAND_CONTEXT_WINDOW
    # When set, skip the LLM extraction step and return raw RAG chunk text.
    extract_rag_only = getattr(config, "FAST_RAG_EXTRACT_RAG_CHUNKS_ONLY", False)
    if vector_index is None:
        log.warning("Vector index not available; Fast RAG cannot run.")
        return []
    client = _get_embedding_client()
    if not client:
        # Without an embedding client we cannot embed queries; caller falls back.
        return []
    queries = rag_queries if rag_queries else [user_question]
    # Drop None / non-string entries; note a whitespace-only query becomes "".
    queries = [q.strip() for q in queries if q and isinstance(q, str)]
    if not queries:
        queries = [user_question]
    # Embed each query and search; collect hits per parent (index)
    seen_parent_ids = set()
    # parent_chunk_id -> hit texts for that parent, in retrieval order.
    rag_texts_by_index: Dict[int, List[str]] = {}
    for q in queries:
        try:
            r = client.embeddings.create(input=[q], model=config.EMBEDDING_MODEL)
            qemb = np.array(r.data[0].embedding, dtype=np.float32)
        except Exception as e:
            # A single failed query embedding is skipped; others still run.
            log.warning(f"Failed to embed query '{q[:50]}...': {e}")
            continue
        for hit in vector_index.search(qemb, k):
            pid = hit["parent_chunk_id"]
            seen_parent_ids.add(pid)
            if hit.get("text"):
                rag_texts_by_index.setdefault(pid, []).append(hit["text"].strip())
    extracted_content: List[Dict[str, Any]] = []
    if extract_rag_only:
        # Skip extraction: use RAG chunk text as-is and pass directly to synthesizer.
        log.info("[Fast RAG] Using RAG chunks only (no extraction); passing directly to synthesizer.")
        for idx in sorted(seen_parent_ids):
            if idx not in rag_texts_by_index:
                continue
            # dict.fromkeys de-duplicates while preserving first-seen order.
            parts = list(dict.fromkeys(rag_texts_by_index[idx]))
            if not parts:
                continue
            content = "\n\n---\n\n".join(parts)
            # NOTE(review): parent_chunk_id is used directly as a positional
            # index into ai_chunks — assumes chunk_id == list position; confirm
            # against build_rag_chunks' producer (rag_engine.chunks).
            c = ai_chunks[idx]
            # One output entry per source video so attribution is preserved.
            for v in c.get("videos", []):
                extracted_content.append({
                    "content": content,
                    "video_id": v["video_id"],
                    "video_title": v["video_title"],
                    "video_url": v.get("video_url", ""),
                    "chunk_id": c["chunk_id"],
                })
        return extracted_content
    # Expand to surrounding AI chunks, then extract (full chunk) via LLM
    all_parent_ids = set()
    for pid in seen_parent_ids:
        # Include expand_window neighbors on each side, clamped to valid range.
        for d in range(-expand_window, expand_window + 1):
            idx = pid + d
            if 0 <= idx < len(ai_chunks):
                all_parent_ids.add(idx)
    max_parallel = config.MAX_PARALLEL_EXTRACTIONS
    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        # Map each future back to its source chunk for result assembly.
        futures = {
            executor.submit(
                extraction_client.extract_relevant_content,
                ai_chunks[i]["text"],
                user_question,
                ai_chunks[i]["videos"],
                ai_chunks[i]["chunk_id"],
                None,
                None,
                False,
            ): ai_chunks[i]
            for i in sorted(all_parent_ids)
        }
        for future in as_completed(futures):
            chunk = futures[future]
            try:
                out = future.result(timeout=300)
                if not out or not isinstance(out, str):
                    continue
                # Drop empty results and the extractor's "nothing relevant" sentinel.
                if not out.strip() or out.strip().upper() in ("NO_RELEVANT_CONTENT",) or out.strip().upper().startswith("NO_RELEVANT_CONTENT"):
                    continue
                for v in chunk["videos"]:
                    extracted_content.append({
                        "content": out.strip(),
                        "video_id": v["video_id"],
                        "video_title": v["video_title"],
                        "video_url": v.get("video_url", ""),
                        "chunk_id": chunk["chunk_id"],
                    })
            except Exception as e:
                # Per-chunk extraction failures are non-fatal; log and continue.
                log.debug(f"Extract error for chunk {chunk['chunk_id']}: {e}")
    return extracted_content