# generate_proth_ml_dataset.py
import pandas as pd
import numpy as np
import random
from collections import defaultdict
# Parameters
K_MIN, K_MAX = 1, 1200
N_MIN, N_MAX = 1, 100_000
NEG_PER_POS = 2
# Upper bound on random draws per positive row, so sampling always terminates
# even when (almost) every n in range is already taken for a given k.
MAX_ATTEMPTS = 100
# Set to an int for a reproducible dataset; None lets `random` seed itself.
SEED = None
INPUT_CSV = "prothsearch_primtal.csv"
OUTPUT_CSV = "proth_ml_dataset.csv"


def load_positives(path=INPUT_CSV, k_range=(K_MIN, K_MAX), n_range=(N_MIN, N_MAX)):
    """Load the positive (k, n) pairs from the input CSV.

    Args:
        path: CSV file with at least the columns ``k`` and ``n``.
        k_range: Inclusive ``(min, max)`` bounds applied to column ``k``.
        n_range: Inclusive ``(min, max)`` bounds applied to column ``n``.

    Returns:
        A list of unique ``(k, n)`` int tuples, in first-seen file order.
    """
    df = pd.read_csv(path)
    # Only keep k and n in the specified range (both ends inclusive).
    in_range = df['k'].between(*k_range) & df['n'].between(*n_range)
    pairs = df.loc[in_range, ['k', 'n']].itertuples(index=False, name=None)
    # dict.fromkeys drops repeated input rows while keeping their order, so a
    # pair listed twice in the CSV does not become two label-1 rows.
    return list(dict.fromkeys((int(k), int(n)) for k, n in pairs))


def build_dataset(positives, neg_per_pos=NEG_PER_POS, n_min=N_MIN, n_max=N_MAX,
                  max_attempts=MAX_ATTEMPTS, rng=None):
    """Build labelled rows: every positive plus sampled negatives for its k.

    For each positive ``(k, n)`` up to ``neg_per_pos`` values of ``n`` are
    drawn uniformly from ``[n_min, n_max]`` and emitted with label 0. A draw
    is rejected when ``(k, n)`` is a positive or was already emitted as a
    negative for the same ``k``, so no ``(k, n)`` pair appears twice in the
    result. Label 0 therefore means "not listed in the input", not "verified
    composite".

    Args:
        positives: Iterable of ``(k, n)`` pairs labelled 1.
        neg_per_pos: Negatives requested per positive.
        n_min: Smallest ``n`` that may be drawn (inclusive).
        n_max: Largest ``n`` that may be drawn (inclusive).
        max_attempts: Maximum draws per positive; fewer than ``neg_per_pos``
            negatives are produced when this budget runs out.
        rng: Object providing ``randint(a, b)``; defaults to the ``random``
            module's shared generator.

    Returns:
        A list of ``{'k', 'n', 'label'}`` dicts, unshuffled.
    """
    rng = random if rng is None else rng
    positives = list(dict.fromkeys((int(k), int(n)) for k, n in positives))

    # Build lookup for fast negative sampling: k -> every n that is unavailable
    # (the positives for that k, plus negatives as they are handed out).
    taken_by_k = defaultdict(set)
    for k, n in positives:
        taken_by_k[k].add(n)

    rows = []
    for k, n in positives:
        rows.append({'k': k, 'n': n, 'label': 1})
        taken = taken_by_k[k]
        found = 0
        attempts = 0
        while found < neg_per_pos and attempts < max_attempts:
            attempts += 1
            n_neg = rng.randint(n_min, n_max)
            if n_neg in taken:
                continue
            # Reserve the value so a later positive with the same k cannot
            # produce the same negative row again.
            taken.add(n_neg)
            rows.append({'k': k, 'n': n_neg, 'label': 0})
            found += 1
    return rows


def main():
    """Generate the dataset from INPUT_CSV and write it to OUTPUT_CSV."""
    rng = random.Random(SEED)

    # 1. Load positive examples
    positives = load_positives()

    # 2./3. Generate negative examples
    data = build_dataset(positives, rng=rng)

    # 4. Shuffle and save
    rng.shuffle(data)
    # Explicit columns keep the CSV header even when there are no rows.
    df_out = pd.DataFrame(data, columns=['k', 'n', 'label'])
    df_out.to_csv(OUTPUT_CSV, index=False)

    # Surface under-sampling instead of silently writing a skewed dataset.
    missing = NEG_PER_POS * len(positives) - (len(data) - len(positives))
    if missing > 0:
        print(f"⚠️ {missing} negative examples could not be sampled "
              f"within {MAX_ATTEMPTS} attempts per positive")
    print(f"✅ Saved {len(df_out)} rows to {OUTPUT_CSV}")


if __name__ == "__main__":
    main()