-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdatasetplot.py
More file actions
70 lines (54 loc) · 2.27 KB
/
datasetplot.py
File metadata and controls
70 lines (54 loc) · 2.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import plotly.express as px
import numpy as np
import pandas as pd
from chromadb.api.models import Collection
from chromadb.api.types import IncludeEnum
def fetch_embeddings(collection: Collection):
    """Fetch the embeddings and a label for every record in a ChromaDB collection.

    Args:
        collection: The ChromaDB collection to read from.

    Returns:
        A tuple ``(embeddings, labels)``:

        * ``embeddings`` is ``np.array`` of the stored embeddings (presumably
          2-D, one row per record — TODO confirm for the chromadb version in use).
        * ``labels`` is a list with each record's ``"document_type"`` metadata
          value, or ``"Unknown"`` when that key — or the record's whole
          metadata entry — is missing.
    """
    results = collection.get(include=[IncludeEnum.embeddings, IncludeEnum.metadatas])

    # Compare against None explicitly: the stored value may be None, and some
    # chromadb releases return embeddings as a NumPy array, whose truth value
    # is ambiguous (so `x or []` is not safe here).
    embeddings = results.get("embeddings")
    if embeddings is None:
        embeddings = []

    metadatas = results.get("metadatas")
    if metadatas is None:
        metadatas = []

    # Records inserted without metadata come back as None; treat them as empty
    # dicts so they get the "Unknown" label instead of raising AttributeError.
    labels = [(meta or {}).get("document_type", "Unknown") for meta in metadatas]
    return np.array(embeddings), labels
def reduce_dimensionality(embeddings, method='tsne', dimensions=3):
    """Project embeddings down to ``dimensions`` components with t-SNE or PCA.

    Args:
        embeddings: Array-like with one row per sample.
        method: ``'tsne'`` (scikit-learn ``TSNE`` with ``random_state=42``) or
            ``'pca'`` (scikit-learn ``PCA``).
        dimensions: Number of output components.

    Returns:
        The array produced by the chosen reducer's ``fit_transform``.

    Raises:
        ValueError: If ``method`` is neither ``'tsne'`` nor ``'pca'``, or if
            t-SNE is requested with fewer than two samples.
    """
    if method not in ('tsne', 'pca'):
        raise ValueError("Método inválido! Use 'tsne' ou 'pca'.")

    if method == 'tsne':
        n_samples = len(embeddings)
        if n_samples < 2:
            raise ValueError(
                f"t-SNE requer pelo menos 2 amostras (recebidas: {n_samples})."
            )
        # scikit-learn requires perplexity < n_samples. Its default (30.0)
        # makes collections with 30 or fewer records fail outright, so clamp
        # it. Larger inputs keep the default unchanged.
        perplexity = min(30.0, float(n_samples - 1))
        reducer = TSNE(n_components=dimensions, perplexity=perplexity, random_state=42)
    else:
        reducer = PCA(n_components=dimensions)

    return reducer.fit_transform(embeddings)
def plot_embeddings_3d(embeddings_3d, labels, method='t-SNE'):
    """Show an interactive plotly 3-D scatter of reduced embeddings.

    Args:
        embeddings_3d: Array indexable as ``[:, k]``; columns 0, 1 and 2 become
            the x, y and z coordinates.
        labels: One label per row, used to colour the points.
        method: Name of the reduction method, shown in the figure title.
    """
    # Build the frame column by column: x, y, z from the first three
    # components, then the label column used for colouring.
    columns = {axis: embeddings_3d[:, idx] for idx, axis in enumerate(('x', 'y', 'z'))}
    columns['label'] = labels
    frame = pd.DataFrame(columns)

    axis_titles = {
        'xaxis_title': "Componente 1",
        'yaxis_title': "Componente 2",
        'zaxis_title': "Componente 3",
    }

    figure = px.scatter_3d(
        frame,
        x='x',
        y='y',
        z='z',
        color='label',
        title=f"Visualização dos Embeddings - {method}",
    )
    figure.update_layout(scene=axis_titles)
    figure.show()
import chromadb


def main(db_path='data/chroma', collection_name="langchain"):
    """Load a ChromaDB collection, reduce its embeddings to 3-D with t-SNE and plot them.

    Args:
        db_path: Directory of the persistent ChromaDB store.
        collection_name: Name of the collection to visualise.
    """
    # Open the on-disk store and fetch the target collection.
    client = chromadb.PersistentClient(path=db_path)
    collection = client.get_collection(collection_name)
    print(collection)

    # Extract embeddings and labels, project to 3-D, then plot.
    embeddings, labels = fetch_embeddings(collection)
    embeddings_3d = reduce_dimensionality(embeddings, method='tsne', dimensions=3)
    plot_embeddings_3d(embeddings_3d, labels, method='t-SNE')


# Run the pipeline only when executed as a script. This lets the helpers above
# be imported without opening the database or showing a figure.
if __name__ == "__main__":
    main()