-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtrain_utilities.py
More file actions
82 lines (63 loc) · 2.82 KB
/
train_utilities.py
File metadata and controls
82 lines (63 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import re
import sys
import logging
import pickle
from pathlib import Path
from itertools import product
from lp import *
from gomory import *
class TensorSet:
    """A set of tensors, deduplicated by element values.

    Tensors are not hashable by value, so each appended tensor is detached,
    moved to CPU, and converted to a tuple of Python scalars before insertion.
    Two tensors with equal element values are therefore stored only once.

    NOTE(review): tuple(t.tolist()) is hashable only for 1-D tensors — a 2-D
    tensor would yield a tuple of (unhashable) lists. Assumes 1-D inputs;
    confirm against callers.
    """

    def __init__(self, set_=None):
        # Keyword avoids a mutable default argument; each instance gets a
        # fresh set unless one is supplied explicitly.
        self._set = set_ if set_ is not None else set()

    def __len__(self):
        """Number of distinct tensors stored."""
        return len(self._set)

    def append(self, new_tensor):
        """Insert *new_tensor*; a no-op if an equal-valued tensor is present."""
        key = tuple(new_tensor.detach().cpu().numpy().tolist())
        # .add() is the idiomatic single-element insert (the original built a
        # throwaway one-element set and called .update() on it).
        self._set.add(key)
def get_instance(instance_path, add_variable_bounds=False, presolve=True, force_reload=False, device="cpu"):
    """Return solution data for an instance, using an on-disk pickle cache.

    On a cache hit (``<parent>/solutions/<stem>.pkl`` exists and
    ``force_reload`` is false) the pickled tuple is loaded and any tensors in
    it are moved to ``device``. Otherwise the instance is loaded and solved
    from scratch, and a CPU copy of the results is written back to the cache.

    Returns:
        Tuple ``(A, b, c, vtypes, lp_value, lp_solution, ilp_value,
        ilp_solution, gomory_values)``.
    """
    instance_path = Path(instance_path)
    solution_path = instance_path.parent / "solutions" / (instance_path.stem + ".pkl")
    if solution_path.exists() and not force_reload:
        # Cache hit.
        # NOTE(review): pickle.load — only safe because the cache file is
        # produced locally by this same function; never point it at
        # untrusted data.
        with open(solution_path, "rb") as cache_file:
            cached = pickle.load(cache_file)
        return tuple(item.to(device) if torch.is_tensor(item) else item
                     for item in cached)
    # Cache miss (or forced reload): solve everything from scratch.
    A, b, c, vtypes, _ = load_instance(instance_path, device=device,
                                       add_variable_bounds=add_variable_bounds, presolve=presolve)
    lp_value, lp_solution, _ = solve_lp(A, b, c)
    ilp_value, ilp_solution = solve_ilp(A, b, c, vtypes)
    gomory_values = compute_gomory_bounds(A, b, c, vtypes, nb_rounds=2)
    solutions = (A, b, c, vtypes, lp_value, lp_solution,
                 ilp_value, ilp_solution, gomory_values)
    # Persist a CPU copy so the cache file is device-independent.
    solutions_on_cpu = tuple(item.to("cpu") if torch.is_tensor(item) else item
                             for item in solutions)
    solution_path.parent.mkdir(exist_ok=True)
    with open(solution_path, "wb") as cache_file:
        pickle.dump(solutions_on_cpu, cache_file)
    return solutions
def configure_logging(output_file=None):
    """Return the "subadditive" logger, attaching handlers on first use.

    If the logger (or an ancestor, per ``hasHandlers``) already has handlers,
    it is returned unchanged so repeated calls never stack duplicates.
    Otherwise an INFO-level stdout handler is attached, preceded by an
    INFO-level file handler when *output_file* is given.
    """
    logger = logging.getLogger("subadditive")
    if logger.hasHandlers():
        return logger  # already configured (or an ancestor handles records)
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter(fmt='[%(asctime)s] %(message)s',
                                  datefmt='%H:%M:%S')
    handlers = [logging.StreamHandler(sys.stdout)]
    if output_file:
        # File handler goes first, matching the original attachment order.
        handlers.insert(0, logging.FileHandler(output_file))
    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
def path_ordering(path):
    """Natural-sort key for a path: split on '/', '_' and '.', numbers as ints.

    e.g. ``"dir/inst_10.lp"`` sorts after ``"dir/inst_2.lp"`` because the
    numeric components compare as integers (10 > 2), not as strings.

    Returns:
        List of str and int components.
    """
    groups = re.split(r"[/_.]", str(path))
    # Use .isdecimal(), not .isnumeric(): .isnumeric() is also true for
    # characters such as '²' (superscript two) that int() cannot parse,
    # which would make this key function raise ValueError.
    return [int(group) if group.isdecimal() else group for group in groups]
def dict_product(**iterable):
    """Yield every combination of the keyword-argument iterables as a dict.

    >>> list(dict_product(a=[1, 2], b=[3]))
    [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]
    """
    keys = tuple(iterable)
    for combo in product(*(iterable[key] for key in keys)):
        yield dict(zip(keys, combo))