-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmigrate_embeddings.py
More file actions
164 lines (136 loc) · 5.74 KB
/
migrate_embeddings.py
File metadata and controls
164 lines (136 loc) · 5.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""
Migration Script: MiniLM (384D) → Gemini Embedding 2 (3072D)
Re-embeds all existing memories from the legacy collection into a new v2 collection
using Gemini Embedding 2 for richer semantic matching and multimodal support.
Run this once after upgrading to migrate existing memories.
New installations don't need this — they start with v2 automatically.
Usage:
python migrate_embeddings.py
"""
import time
import logging
import chromadb
from chromadb.config import Settings
from google import genai
from google.genai import types as genai_types
from config import Config
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
logger = logging.getLogger(__name__)
def migrate():
    """Re-embed every legacy memory collection into a ``*_v2`` collection.

    For each entity in ``Config.ENTITIES``, reads all documents out of the
    legacy ChromaDB collection, embeds them with Gemini Embedding 2, and
    writes them into ``<legacy_name>_v2``. Safe to re-run: IDs already
    present in the v2 collection are skipped, so an interrupted or partially
    failed run can simply be restarted.
    """
    # Persistent ChromaDB client (telemetry disabled).
    client = chromadb.PersistentClient(
        path=Config.CHROMA_DB_PATH,
        settings=Settings(anonymized_telemetry=False)
    )
    # Gemini client used for all embedding calls.
    gemini = genai.Client(api_key=Config.GOOGLE_API_KEY)
    logger.info("Gemini Embedding 2 client initialized")

    # Migrate each configured entity independently.
    for entity_name, entity_config in Config.ENTITIES.items():
        _migrate_entity(client, gemini, entity_name, entity_config)


def _migrate_entity(client, gemini, entity_name, entity_config):
    """Migrate one entity's legacy collection into its ``_v2`` counterpart.

    No-ops (with a log line) when the legacy collection is missing or empty.
    """
    legacy_name = entity_config["collection_name"]
    v2_name = legacy_name + "_v2"

    try:
        legacy = client.get_collection(name=legacy_name)
    except Exception:
        # chromadb raises when the collection does not exist — treat as "nothing to do".
        logger.info(f"No legacy collection '{legacy_name}' found for {entity_name} — skipping")
        return

    legacy_count = legacy.count()
    if legacy_count == 0:
        logger.info(f"Legacy collection '{legacy_name}' is empty — skipping")
        return
    logger.info(f"Legacy collection: {legacy_name} ({legacy_count} memories)")

    # Get or create the target v2 collection.
    v2 = client.get_or_create_collection(
        name=v2_name,
        metadata={"entity": entity_name, "embedding_model": "gemini-embedding-2"}
    )
    existing_v2 = v2.count()
    logger.info(f"V2 collection: {v2_name} ({existing_v2} existing memories)")

    # Pull all memories from legacy (ids are always returned by .get()).
    results = legacy.get(include=["documents", "metadatas"])
    if not results or not results["ids"]:
        logger.info("No memories to migrate")
        return
    total = len(results["ids"])
    logger.info(f"Migrating {total} memories...")

    # IDs already in v2 — lets a previously interrupted run resume cleanly.
    existing_ids = set()
    if existing_v2 > 0:
        existing_ids = set(v2.get()["ids"])
        logger.info(f"Skipping {len(existing_ids)} already-migrated memories")

    batch_size = 20
    migrated = 0
    skipped = 0
    failed = 0
    for i in range(0, total, batch_size):
        batch_ids = results["ids"][i:i + batch_size]
        batch_docs = results["documents"][i:i + batch_size]
        batch_metas = results["metadatas"][i:i + batch_size]

        # Drop anything already migrated (resume support).
        new_ids = []
        new_docs = []
        new_metas = []
        for j, mid in enumerate(batch_ids):
            if mid in existing_ids:
                skipped += 1
                continue
            new_ids.append(mid)
            new_docs.append(batch_docs[j])
            new_metas.append(batch_metas[j])
        if not new_ids:
            continue

        # Embed this batch; failures come back as None placeholders.
        batch_embeddings, batch_failed = _embed_documents(gemini, new_ids, new_docs)
        failed += batch_failed

        # Keep only successfully embedded entries, aligned across all lists.
        final_ids = []
        final_docs = []
        final_metas = []
        final_embeddings = []
        for j, emb in enumerate(batch_embeddings):
            if emb is not None:
                final_ids.append(new_ids[j])
                final_docs.append(new_docs[j])
                final_metas.append(new_metas[j])
                final_embeddings.append(emb)

        if final_ids:
            v2.add(
                ids=final_ids,
                embeddings=final_embeddings,
                documents=final_docs,
                metadatas=final_metas
            )
            migrated += len(final_ids)
        logger.info(f" Progress: {migrated + skipped}/{total} (migrated: {migrated}, skipped: {skipped}, failed: {failed})")

        # Small delay to be gentle on the API.
        if i + batch_size < total:
            time.sleep(0.5)

    # Final stats for this entity.
    logger.info("=" * 60)
    logger.info(f"MIGRATION COMPLETE for {entity_name}")
    logger.info(f" Total memories: {total}")
    logger.info(f" Migrated: {migrated}")
    logger.info(f" Skipped (already existed): {skipped}")
    logger.info(f" Failed: {failed}")
    logger.info(f" V2 collection now has: {v2.count()} memories")
    logger.info("=" * 60)
    if failed > 0:
        logger.warning(f"{failed} memories failed to migrate. Run this script again to retry.")


def _embed_documents(gemini, ids, docs):
    """Embed each document with Gemini; return ``(embeddings, failed_count)``.

    A failed embedding yields ``None`` at its position so the caller can keep
    ids/docs/metadatas positionally aligned. ``ids`` is used only for error
    reporting and must be parallel to ``docs``.
    """
    embeddings = []
    failed = 0
    for idx, doc in enumerate(docs):
        try:
            result = gemini.models.embed_content(
                model=Config.GEMINI_EMBEDDING_MODEL,
                contents=doc,
                config=genai_types.EmbedContentConfig(
                    task_type="RETRIEVAL_DOCUMENT",
                    output_dimensionality=Config.GEMINI_EMBEDDING_DIMENSIONS
                )
            )
            embeddings.append(result.embeddings[0].values)
        except Exception as e:
            logger.error(f"Failed to embed memory {ids[idx]}: {e}")
            failed += 1
            embeddings.append(None)  # placeholder keeps lists aligned
    return embeddings, failed
# Script entry point: `python migrate_embeddings.py` runs the one-time migration.
if __name__ == "__main__":
    migrate()