-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaccuracy.py
More file actions
116 lines (90 loc) · 3.52 KB
/
accuracy.py
File metadata and controls
116 lines (90 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import gzip
import os
import pickle
import random
import re
import shutil
from pathlib import Path

import numpy as np
import pyscipopt
import scipy.io as io
import torch
import torch.nn.functional as F
import torch_geometric

from config import *
from dataset import MIPDataset
from label_opt import labelOpt
from label_opt import labelOpt,lexOpt
from nn import GNNPolicy
# cuDNN autotuning flags (carried over from the training scripts; harmless here).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# ---- Command-line configuration ----
parser = argparse.ArgumentParser()
parser.add_argument('--expName', type=str, default='SMSP-rs')
parser.add_argument('--dataset', type=str, default='SMSP')
parser.add_argument('--PE', type=str, default='Y')
args = parser.parse_args()

EXP_NAME = args.expName
DATASET = args.dataset

# Resolve dataset-specific paths from the shared config table (confInfo from config).
info = confInfo[DATASET]
DIR_SOL = os.path.join(info['trainDir'], 'sol')          # ground-truth solution files
DIR_PRE = os.path.join(EXP_NAME, 'logits_valid')         # predicted logits per instance
os.makedirs(EXP_NAME, exist_ok=True)

# Pair each solution file with its predicted-logits counterpart (.sol -> .prob).
sample_files = []
for solName in os.listdir(DIR_SOL):
    preName = solName.replace('.sol', '.prob')
    sample_files.append((os.path.join(DIR_PRE, preName), os.path.join(DIR_SOL, solName)))

# Reproduce the split used at training time: same seed, same shuffle order,
# validation set = last 20% of the shuffled list.
random.seed(0)
random.shuffle(sample_files)
valid_files = sample_files[int(0.8 * len(sample_files)):]

# Top-k fractions at which prediction error is measured.
Ks = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
nErrors = []  # per-instance list of error counts, one entry per k in Ks
# Evaluate top-k rounding accuracy on the validation split: for each instance,
# round the predicted logits and count disagreements with the order-optimized
# label X_bar on the k most confident variables (logit farthest from 0.5).
numPat = re.compile(r'\d+')  # hoisted: extracts integer indices from variable names, e.g. 'X_3_7'
for step, (pre_file, sol_file) in enumerate(valid_files):
    # Predicted logits: pickled dict {'pre': flat ndarray}.
    # NOTE: was pickle.load(open(...)) — file handles were leaked; use context managers.
    with open(pre_file, 'rb') as f:
        pre = pickle.load(f)['pre']
    # Ground-truth solutions (gzip-pickled); take the first (best) solution.
    with gzip.open(sol_file, 'rb') as f:
        solData = pickle.load(f)
    sol = solData['sols'][0].astype(int)
    varnames = solData['varNames']

    if DATASET == 'IP':
        # Item-placement instances: variables 'place_i_j' form an (nItem, nBin) matrix.
        nItem = max(int(numPat.findall(name)[0]) for name in varnames if 'place' in name) + 1
        nBin = max(int(numPat.findall(name)[1]) for name in varnames if 'place' in name) + 1
        X = torch.zeros((nItem, nBin))
        for ind, name in enumerate(varnames):
            if 'place' in name:
                ss = numPat.findall(name)
                a, b = int(ss[0]), int(ss[1])
                X[a, b] = sol[ind]
        X_hat = torch.Tensor(pre.reshape(nItem, nBin))
    elif DATASET == 'SMSP':
        # Scheduling instances: 'X_i_j' rows stacked above 'Y_i_j' rows.
        nItem = max(int(numPat.findall(name)[0]) for name in varnames if 'X' in name) + 1
        nCap = max(int(numPat.findall(name)[0]) for name in varnames if 'Y' in name) + 1
        X = torch.zeros((nItem + nCap, nItem))
        for ind, name in enumerate(varnames):
            ss = numPat.findall(name)
            a, b = int(ss[0]), int(ss[1])
            if 'X' in name:
                X[a, b] = sol[ind]
            elif 'Y' in name:
                X[a + nItem, b] = sol[ind]
        X_hat = torch.Tensor(pre.reshape(nItem + nCap, nItem))
    else:
        # Previously an unknown dataset fell through to a NameError on X;
        # fail fast with a clear message instead.
        raise ValueError(f'Unsupported dataset: {DATASET}')

    # Align the label with the prediction (symmetry/ordering optimization on CPU).
    X_bar = labelOpt(X_hat[None, :, :], X.clone()[None, :, :], device='cpu')
    X_hat = X_hat.reshape(-1)
    X_round = X_hat.round()
    X_bar = X_bar.reshape(-1)

    n = X_hat.shape[-1]
    # Rank variables once by confidence (distance of logit from 0.5, descending);
    # the original re-sorted inside the k-loop, which was redundant work.
    ordering = (-(X_hat - 0.5).abs()).sort(dim=-1)[1]
    kErrors = []
    for k in Ks:
        nTop = int(n * k)
        topIdx = ordering[:nTop]
        # Error = #top-k variables where rounding disagrees with the optimized label.
        kErrors.append((X_round[topIdx] != X_bar[topIdx]).sum().item())
    nErrors.append(kErrors)
    print(f'Processed {step}/{len(valid_files)}')
# Aggregate per-instance error counts into mean/std per top-k fraction
# and write a small text report alongside the experiment outputs.
nErrors = np.array(nErrors)  # shape: (num_instances, len(Ks))
errorMean = nErrors.mean(axis=0)
errorStd = nErrors.std(axis=0)
reportPath = os.path.join(EXP_NAME, 'prediction_error.txt')
with open(reportPath, 'w') as f:
    for i, k in enumerate(Ks):
        s = f'k={k} mean: {errorMean[i]:.2f} std: {errorStd[i]:.2f}\n'
        f.write(s)
        print(s)