forked from NUS-HPC-AI-Lab/DD-Ranking
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdemo_hard.py
More file actions
66 lines (59 loc) · 2.2 KB
/
demo_hard.py
File metadata and controls
66 lines (59 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import os
import torch
import warnings
from ddranking.metrics import HardLabelEvaluator
from ddranking.config import Config
warnings.filterwarnings("ignore")
""" Use config file to specify the arguments (Recommended) """
config = Config.from_file("./configs/Demo_Hard_Label.yaml")
hard_label_evaluator = HardLabelEvaluator(config)
syn_data_dir = "./baselines/DM/CIFAR10/IPC10/"
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
syn_lr = 0.01
print(hard_label_evaluator.compute_metrics(image_tensor=syn_images, syn_lr=syn_lr))
""" Use keyword arguments """
device = "cuda"
method_name = "DM" # Specify your method name
ipc = 10 # Specify your IPC
dataset = "CIFAR10" # Specify your dataset name
data_dir = "./datasets" # Specify your dataset path
syn_data_dir = "./DM/CIFAR10/IPC10/" # Specify your synthetic data path
model_name = "ConvNet-3" # Specify your model name
im_size = (32, 32) # Specify your image size
dsa_params = {
"prob_flip": 0.5,
"ratio_rotate": 15.0,
"saturation": 2.0,
"brightness": 1.0,
"contrast": 0.5,
"ratio_scale": 1.2,
"ratio_crop_pad": 0.125,
"ratio_cutout": 0.5,
}
syn_images = torch.load(os.path.join(syn_data_dir, f"images.pt"), map_location='cpu')
save_path = f"./results/{dataset}/{model_name}/IPC{ipc}/dm_hard_scores.csv"
hard_label_evaluator = HardLabelEvaluator(
dataset=dataset,
real_data_path=data_dir,
ipc=ipc,
model_name=model_name,
default_lr=0.01,
optimizer='sgd', # Use SGD optimizer
lr_scheduler='step', # Use StepLR learning rate scheduler
weight_decay=0.0005,
momentum=0.9,
use_zca=False,
num_eval=5,
data_aug_func='dsa', # Use DSA data augmentation
aug_params=dsa_params, # Specify DSA parameters
im_size=im_size,
num_epochs=1000,
num_workers=4,
use_torchvision=False,
syn_batch_size=128,
real_batch_size=256,
custom_val_trans=None,
device=device,
save_path=save_path
)
print(hard_label_evaluator.compute_metrics(image_tensor=syn_images, syn_lr=0.01))