-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataset.py
More file actions
121 lines (101 loc) · 4.5 KB
/
dataset.py
File metadata and controls
121 lines (101 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import random
import numpy as np
import copy
from feature_extract import extract_features
import torch
from torch.utils.data import Dataset
import os
import gzip
import pickle
from config import *
import argparse
class SeedGenerator:
    """Deterministic seed dispenser.

    Seeds Python's global RNG once with `initialSeed`, pre-draws a fixed pool
    of `nSeeds` integers, and then hands them out one at a time in order, so
    repeated runs with the same `initialSeed` see the same seed sequence.
    """
    def __init__(self, initialSeed=0, nSeeds=1000000):
        random.seed(initialSeed)
        self.initialSeed = initialSeed
        # Draw the whole pool up front; order is fixed by initialSeed.
        self.fixedSeeds = [random.randint(1, 100000000) for _ in range(nSeeds)]
        self.cur_ind = 0
        print(f'set {len(self.fixedSeeds)} seeds, initialSeed {initialSeed}')

    def get_seed(self):
        """Return the next seed from the pool and advance the cursor."""
        seed = self.fixedSeeds[self.cur_ind]
        self.cur_ind += 1
        return seed
class MIPDataset(Dataset):
    """Dataset of MIP instances represented as cached bipartite graphs.

    Each item is loaded from a pre-extracted bipartite-graph file
    ``<instance>.bp`` under `bgdir`.  On first access, the best solution and
    objective from the matching ``.sol`` file are merged into the ``.bp`` file
    together with reorder indices (a one-time cache upgrade, persisted back to
    disk); later accesses read the upgraded cache directly.  Finally `augFunc`
    adds seeded augmentation features before the sample is returned.
    """
    def __init__(self, files, bgdir, reorderFunc, augFunc, sampleTimes, seedGenerator):
        """
        files: iterable of (instance_path, solution_path) pairs.
        bgdir: directory holding the cached bipartite-graph (.bp) files.
        reorderFunc: callable varNames -> dict with 'nGroup', 'nElement' and
            'reorderInds'; used only on the first (cache-upgrade) pass.
        augFunc: callable (data, seed=...) -> data adding augmentation features.
        sampleTimes: if > 0, pre-draw a fixed pool of this many candidate
            seeds; if <= 0, use a fresh seed from `seedGenerator` per item.
        seedGenerator: object exposing .initialSeed and .get_seed()
            (see SeedGenerator).
        """
        self.insPaths = [filepaths[0] for filepaths in files]
        self.solPaths = [filepaths[1] for filepaths in files]
        self.bgdir = bgdir
        self.reorder = reorderFunc
        self.addPos = augFunc
        # A fixed seed pool keeps augmentation reproducible across epochs.
        random.seed(seedGenerator.initialSeed)
        self.seeds = [random.randint(1, 5000) for _ in range(sampleTimes)] if sampleTimes > 0 else None
        os.makedirs(bgdir, exist_ok=True)
        self.seedGenerator = seedGenerator

    def __getitem__(self, index):
        inspath = self.insPaths[index]
        solpath = self.solPaths[index]
        insname = os.path.basename(inspath)
        bgpath = os.path.join(self.bgdir, insname.replace('.gz', '') + '.bp')
        # `with` ensures the gzip handle is closed (the original leaked it).
        with gzip.open(bgpath, 'rb') as f:
            bpData = pickle.load(f)
        if 'sols' not in bpData:
            # First access: merge solution labels into the cached graph.
            with gzip.open(solpath, 'rb') as f:
                solData = pickle.load(f)
            reorderData = self.reorder(bpData['varNames'])
            data = {
                'varFeatures': torch.Tensor(bpData['variableFeatures']),
                'consFeatures': torch.Tensor(bpData['constraintFeatures']),
                'edgeFeatures': torch.Tensor(bpData['edgeWeights']),
                # var ID -> cons ID edge index, shaped (2, nEdges)
                'edgeInds': torch.Tensor(bpData['edgeInds'].astype(int)).permute(1, 0),
                'nGroup': reorderData['nGroup'],
                'nElement': reorderData['nElement'],
                'reorderInds': torch.Tensor(reorderData['reorderInds'])
            }
            sols = solData['sols']
            objs = solData['objs']
            # Element-wise comparison: the original joined both name lists into
            # one string, which can spuriously match (['ab','c'] == ['a','bc']).
            if list(bpData['varNames']) != list(solData['varNames']):
                raise NotImplementedError('variable names in graph and solution files do not match')
            # Keep only the best (first) solution and its objective.
            data['sols'] = torch.Tensor(sols[0])
            data['objs'] = torch.Tensor([objs[0]])
            # Persist the upgraded cache so the merge happens only once.
            with gzip.open(bgpath, 'wb') as f:
                pickle.dump(data, f)
            bpData = data
        # Add augmentation features with a deterministic seed.
        generated_seed = self.seedGenerator.get_seed()
        if self.seeds is None:
            data = self.addPos(bpData, seed=generated_seed)
        else:
            # Pick from the fixed pool, then spread by index so different
            # items get different effective seeds.
            random.seed(generated_seed)
            selectedSeed = random.choice(self.seeds)
            selectedSeed = selectedSeed * (index + 1)
            data = self.addPos(bpData, seed=selectedSeed)
        return data

    def __len__(self):
        return len(self.insPaths)
if __name__ == '__main__':
    # Pre-build (or upgrade) the cached bipartite-graph files for a dataset
    # by iterating the whole DataLoader once.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='SMSP')
    args = parser.parse_args()

    info = confInfo[args.dataset]
    ADDPOS = info['addPosFeature']
    REORDER = info['reorder']
    fileDir = os.path.join(info['trainDir'], 'ins')
    solDir = os.path.join(info['trainDir'], 'sol')
    bgDir = os.path.join(info['trainDir'], 'bg')

    # Pair each solution file with its instance file (same basename minus '.sol').
    solnames = os.listdir(solDir)
    filepaths = [os.path.join(fileDir, solname.replace('.sol', '')) for solname in solnames]
    solpaths = [os.path.join(solDir, solname) for solname in solnames]
    # Bug fix: MIPDataset requires sampleTimes and seedGenerator, which the
    # original call omitted (TypeError at runtime).  sampleTimes=0 means each
    # item draws a fresh seed from the generator.
    dataset = MIPDataset(list(zip(filepaths, solpaths)), bgDir, REORDER, ADDPOS, 0, SeedGenerator())
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=0)
    print('Start constructing bipartite graph ...')
    for step, data in enumerate(data_loader):
        # Touch every field so each sample is fully built (and cached on disk).
        varFeatures = data['varFeatures']
        consFeatures = data['consFeatures']
        edgeFeatures = data['edgeFeatures']
        edgeInds = data['edgeInds']
        sols = data['sols']
        objs = data['objs']
        reorderInds = data['reorderInds']
        print(f'Processed {step}/{len(data_loader)}')
    print('Bipartite graph construction finished!')