-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_groups.py
More file actions
104 lines (83 loc) · 4.45 KB
/
test_groups.py
File metadata and controls
104 lines (83 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sentence_transformers import SentenceTransformer, losses
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
# -----------------------------------------------------------------------------
def average_pairwise_similarity(embeddings):
    """Return the mean cosine similarity over all distinct pairs of rows.

    Only the strict lower triangle of the cosine-similarity matrix is
    averaged. This excludes the diagonal (each row against itself) and counts
    every unordered pair once. With fewer than two rows the selection is
    empty, and NumPy's mean of an empty selection is NaN.
    """
    matrix = cosine_similarity(embeddings)
    rows, cols = np.tril_indices(len(matrix), k=-1)
    return matrix[rows, cols].mean()
# -----------------------------------------------------------------------------
def compute_group_similarity(model, group_words):
    """Embed a group's words and score how similar they are to one another.

    Entries that are not strings, and strings that are empty or
    whitespace-only, are ignored before encoding.

    Returns:
        ``(average pairwise cosine similarity, embeddings)`` for the cleaned
        words, or ``(None, None)`` when fewer than two usable words remain.
    """
    cleaned = []
    for token in group_words:
        if not isinstance(token, str):
            continue
        if not token.strip():
            continue
        cleaned.append(token)
    # A pairwise score needs at least one pair of words.
    if len(cleaned) >= 2:
        vectors = model.encode(cleaned)
        return average_pairwise_similarity(vectors), vectors
    return None, None
# -----------------------------------------------------------------------------
def visualize_tsne(group_embeddings_dict, base_path, title="t-SNE of Sample Groups", perplexity=None):
    """Project group embeddings to 2-D with t-SNE and save a scatter plot.

    Args:
        group_embeddings_dict: mapping of group id -> ``(words, embeddings)``.
            Only the embeddings are plotted: one point per embedding row,
            coloured and styled by ``group_<id>``.
        base_path: output path prefix; the figure is written to
            ``base_path + '-TSNE.png'``.
        title: plot title.
        perplexity: t-SNE perplexity. ``None`` (default) uses 30 (scikit-learn's
            default), capped at ``n_points - 1``. scikit-learn rejects a
            perplexity that is not smaller than the number of samples, and a
            handful of small test groups easily falls below 30 points.

    Raises:
        ValueError: if fewer than two embeddings are supplied (t-SNE cannot
            embed them meaningfully).
    """
    all_vectors, all_labels = [], []
    for group_id, (_words, embeddings) in group_embeddings_dict.items():
        all_vectors.extend(embeddings)
        all_labels.extend([f"group_{group_id}"] * len(embeddings))

    vectors = np.asarray(all_vectors)
    n_points = len(vectors)
    if n_points < 2:
        raise ValueError(f"t-SNE needs at least 2 embeddings, got {n_points}")
    if perplexity is None:
        perplexity = min(30.0, n_points - 1)

    tsne = TSNE(n_components=2, perplexity=perplexity, random_state=42)
    reduced = tsne.fit_transform(vectors)
    df = pd.DataFrame({'x': reduced[:, 0], 'y': reduced[:, 1], 'group': all_labels})

    fig = plt.figure(figsize=(10, 8))
    try:
        sns.scatterplot(data=df, x='x', y='y', hue='group', style='group', s=60)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(base_path + '-TSNE.png')
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)
# -----------------------------------------------------------------------------
def load_groups_from_csv(csv_path, group_col='id', data_col='word', min_size=2):
    """Read a CSV and collect its data values into per-group lists.

    Rows whose ``data_col`` is missing are dropped, and the remaining values
    are converted to ``str``. Rows are grouped by ``group_col``; only groups
    with at least ``min_size`` values are kept.

    Returns:
        dict mapping each kept group id to its list of string values, in the
        order produced by ``DataFrame.groupby`` (sorted by key by default).
    """
    table = pd.read_csv(csv_path).dropna(subset=[data_col])
    table[data_col] = table[data_col].astype(str)
    per_group = table.groupby(group_col)[data_col].apply(list)
    kept = {}
    for group_id, values in per_group.items():
        if len(values) < min_size:
            continue
        kept[group_id] = values
    return kept
# =============================================================================
def main(argv=None):
    """Evaluate a fine-tuned SentenceTransformer on held-out groups from a CSV.

    For every group with at least ``--min_size`` rows, the words are embedded
    and their average pairwise cosine similarity is computed. The top 10
    groups are printed, all scores are written to
    ``<csv path without extension>-test-groups.csv``, and a t-SNE plot of the
    first ``--visualize_groups`` scored groups is saved next to it.

    Args:
        argv: optional argument list for ``argparse``; ``None`` reads ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        prog="test_groups", description="Tests a fine tuned model with new groups."
    )
    parser.add_argument('--csv_path', type=str, required=True, help='Path to CSV test examples file.')
    parser.add_argument('--model_path', type=str, required=True, help='Path to fine tuned sentence transformer model.')
    parser.add_argument('--visualize_groups', type=int, required=False, default=3, help='How many groups to visualize with t-SNE. Default: 3')
    parser.add_argument('--group_col', type=str, required=False, default='id', help='Column header in CSV file that denotes the group. Default: \'id\'')
    parser.add_argument('--data_col', type=str, required=False, default='word', help='Column header in CSV file that denotes the data. Default: \'word\'')
    parser.add_argument('--min_size', type=int, required=False, default=2, help='Minimum number of rows a group needs in order to be evaluated. Default: 2')
    args = parser.parse_args(argv)

    csv_path = args.csv_path
    base_path, _ = os.path.splitext(csv_path)

    print("⏳ Loading model...")
    model = SentenceTransformer(args.model_path)

    print(f"📖 Reading groups from: {csv_path}")
    groups = load_groups_from_csv(csv_path, group_col=args.group_col, data_col=args.data_col, min_size=args.min_size)

    results = []
    tsne_samples = {}
    print(f"🤔 Evaluating {len(groups)} groups...")
    for gid, words in groups.items():
        avg_sim, emb = compute_group_similarity(model, words)
        if avg_sim is None:
            continue
        # Report how many words were actually embedded. compute_group_similarity
        # drops blank/non-string entries, so len(words) can overcount.
        results.append({'group_id': gid, 'avg_similarity': avg_sim, 'size': len(emb)})
        if len(tsne_samples) < args.visualize_groups:
            tsne_samples[gid] = (words, emb)

    # Explicit columns keep sort_values working even when no group was scored.
    results_df = pd.DataFrame(results, columns=['group_id', 'avg_similarity', 'size'])
    results_df.sort_values("avg_similarity", ascending=False, inplace=True)
    print("\n📊 Group Similarity Summary:")
    print(results_df.round(3).head(10))

    output_csv = base_path + "-test-groups.csv"
    print(f"📝 Saving results to {output_csv}...")
    results_df.to_csv(output_csv, index=False)

    if tsne_samples:
        print("🎨 Visualizing sample groups with t-SNE...")
        visualize_tsne(tsne_samples, base_path)


if __name__ == "__main__":
    main()