forked from ansschh/lt-gate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcheck_splits.py
More file actions
159 lines (139 loc) · 5.96 KB
/
check_splits.py
File metadata and controls
159 lines (139 loc) · 5.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# check_splits.py
from pathlib import Path
import json
import numpy as np
try:
import h5py
except ImportError:
raise SystemExit("Missing dependency: pip install h5py")
# Repository root = the directory containing this script; all split files
# are expected under <root>/data/{fast,slow}/{train,val,test}.h5.
ROOT = Path(__file__).resolve().parent
DATA = ROOT / "data"
def _array_stats(x):
    """Return shape-derived statistics for an array-like dataset ``x``.

    Only the first 10 samples along axis 0 are read for the min/max/NaN
    checks so the scan stays fast on large files.  Dimensions beyond the
    dataset's actual rank are reported as ``None``.
    """
    shape = tuple(x.shape)
    sample = x[:10]
    return {
        "num_samples": int(shape[0]),
        "seq_len": int(shape[1]) if len(shape) > 1 else None,
        "channels": int(shape[2]) if len(shape) > 2 else None,
        "height": int(shape[3]) if len(shape) > 3 else None,
        "width": int(shape[4]) if len(shape) > 4 else None,
        "x_min": float(np.min(sample)),
        "x_max": float(np.max(sample)),
        "has_nan": bool(np.isnan(sample).any()),
    }


def summarize_h5(path: Path):
    """Summarize the structure and contents of one HDF5 split file.

    Only the first top-level key is inspected.  If it is a group, its
    children are searched for a data dataset ('x'/'data'/'images'/
    'features'/'input') and a label dataset ('y'/'labels'/'targets'/
    'classes').  If it is a plain dataset, it is treated as unlabeled data.

    Returns a dict of summary statistics, or a dict containing an "error"
    entry when the structure cannot be determined.  (The original version
    raised NameError on a file with no top-level keys because ``y`` was
    never bound; initializing it up front routes that case to the final
    error return instead.)
    """
    with h5py.File(path, "r") as f:
        keys = list(f.keys())
        print(f"Available keys in {path}: {keys}")
        # Initialize up front so an empty file falls through to the final
        # error return instead of crashing on an unbound name.
        x = None
        y = None
        data_key = None
        label_key = None
        if keys:
            first_key = keys[0]
            print(f"Exploring key '{first_key}' in {path}")
            # A group is expected to hold the data/label datasets as children.
            if isinstance(f[first_key], h5py.Group):
                subgroup_keys = list(f[first_key].keys())
                print(f" Subgroup keys: {subgroup_keys}")
                for subkey in subgroup_keys:
                    if subkey in ['x', 'data', 'images', 'features', 'input']:
                        data_key = f"{first_key}/{subkey}"
                    elif subkey in ['y', 'labels', 'targets', 'classes']:
                        label_key = f"{first_key}/{subkey}"
                if data_key and label_key:
                    x = f[data_key]
                    y = f[label_key]
                else:
                    print(f" Could not find data/label keys in subgroups: {subgroup_keys}")
                    return {
                        "path": str(path),
                        "keys": keys,
                        "subgroup_keys": subgroup_keys,
                        "error": "No data/label keys found in subgroups",
                    }
            else:
                # Bare dataset at the top level: treat it as data with no labels.
                dataset = f[first_key]
                print(f" Dataset shape: {dataset.shape}")
                print(f" Dataset dtype: {dataset.dtype}")
                x = dataset
                stats = {
                    "path": str(path),
                    "keys": keys,
                    "dataset_shape": tuple(x.shape),
                    "dataset_dtype": str(x.dtype),
                    "note": "Single dataset, no separate labels",
                }
                stats.update(_array_stats(x))
                return stats
        if y is not None:
            # Group with both data and labels found.
            stats = {
                "path": str(path),
                "keys": keys,
                "data_key": data_key,
                "label_key": label_key,
                "x_shape": tuple(x.shape),
                "y_shape": tuple(y.shape),
                "label_set": np.unique(y[:]).tolist(),
            }
            stats.update(_array_stats(x))
            return stats
        return {
            "path": str(path),
            "keys": keys,
            "error": "Could not determine data structure",
        }
def pretty(payload):
    """Render *payload* as a human-readable, 2-space-indented JSON string."""
    return json.dumps(payload, indent=2)
def main():
    """Scan data/{fast,slow}/{train,val,test}.h5, print per-split summaries,
    run label-range sanity checks, and compare sample counts against
    data/manifest.json when that file exists.
    """
    manifest = {}
    man_path = DATA / "manifest.json"
    if man_path.exists():
        # Best-effort read: an unreadable or corrupt manifest only disables
        # the manifest comparison below; it must not abort the whole check.
        # Use a context manager so the handle is closed deterministically
        # (the original leaked it via json.load(open(...))).
        try:
            with open(man_path, "r") as fh:
                manifest = json.load(fh)
        except (OSError, ValueError):  # ValueError covers JSONDecodeError
            pass
    results = {}
    for split in ["fast", "slow"]:
        split_dir = DATA / split
        for part in ["train", "val", "test"]:
            h5 = split_dir / f"{part}.h5"
            if not h5.exists():
                print(f"[MISSING] {h5}")
                continue
            stats = summarize_h5(h5)
            results[f"{split}/{part}"] = stats
    # Print summary
    print("\n=== DATA SUMMARY ===")
    for k, v in results.items():
        print(f"\n[{k}]")
        print(pretty(v))
    # Quick sanity checks: the "fast" split should only hold labels 0-4,
    # the "slow" split only labels 5-9.
    print("\n=== QUICK CHECKS ===")
    fast_labels = set(results.get("fast/train", {}).get("label_set", []))
    slow_labels = set(results.get("slow/train", {}).get("label_set", []))
    ok_fast = fast_labels.issubset(set(range(0, 5)))
    ok_slow = slow_labels.issubset(set(range(5, 10)))
    print(f"Fast labels subset of {{0..4}}: {ok_fast} -> {sorted(fast_labels)}")
    print(f"Slow labels subset of {{5..9}}: {ok_slow} -> {sorted(slow_labels)}")
    # Compare counts to manifest if present
    if manifest:
        print("\n=== MANIFEST CHECK ===")
        for k in ["fast", "slow"]:
            for part in ["train", "val", "test"]:
                key = f"{k}/{part}"
                cnt = results.get(key, {}).get("num_samples")
                man_cnt = manifest.get(k, {}).get(part)
                if cnt is not None and man_cnt is not None:
                    print(f"{key}: data={cnt}, manifest={man_cnt}, match={cnt==man_cnt}")
# Run the split checker only when executed as a script, not on import.
if __name__ == "__main__":
    main()