-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
52 lines (45 loc) · 1.68 KB
/
main.py
File metadata and controls
52 lines (45 loc) · 1.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import os
# Reproducibility-related environment settings, applied before the imports below.
# NOTE(review): PYTHONHASHSEED is only read at interpreter startup, so setting it
# here does not change str/bytes hash randomization for *this* process; it is
# inherited by child processes started afterwards (e.g. spawned workers). Export it
# in the launching shell if this process itself must be deterministic.
# The "0" value is marked as a temporary placeholder by the original author.
# CUBLAS_WORKSPACE_CONFIG=":4096:8" is presumably meant for deterministic cuBLAS
# behaviour (see the PyTorch reproducibility notes) — confirm intent.
os.environ.update(
    {
        "PYTHONHASHSEED": "0",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }
)
from prototree.prototree import ProtoTree
from util.log import Log
from util.args import get_args, save_args, get_optimizer
from util.data import get_dataloaders
from util.init import init_tree
from util.net import get_network, freeze
from util.visualize import gen_vis, _assign_sample_counts
from util import visualize_delta
from util.analyse import *
from util.save import *
from prototree.train import train_epoch, train_epoch_kontschieder
from prototree.test import eval, eval_fidelity
from prototree.prune import prune
from prototree.project import project, project_with_class_constraints
from prototree.upsample import upsample
from util.utils import initialize_leaf_mu_sigma
import pickle
import torch
from shutil import copy
from copy import deepcopy
from prototree.branch import Branch
from prototree.leaf import Leaf
from prototree.node import Node
import numpy as np
import random
from util.func import min_pool2d
from util.add_func import load_tree, collect_sgima_statistics, update_mu_sigma, delta_update, training_step, tree_prune, check_require_grads
# from car_racing.games.carracing import RacingNet, CarRacing
import math
# Runs at import time: asks PyTorch's CUDA caching allocator to release unused
# cached GPU memory.
# NOTE(review): no tensors have been created by this module yet, so this is most
# likely a no-op here — confirm whether it is still needed.
torch.cuda.empty_cache()
# Author's note below (a bare string, not a docstring, so it has no runtime effect):
# MAE is used instead of test accuracy both for reporting and as the "best loss"
# criterion when saving.
'''
########## Instead of test_accuracy I am considering mae for both reporting and saving as best loss
'''
def log_tree_routing(tree, attr=None, batch_idx=0):
"""
Recursively prints the tree structure along with (1 - ps) for each branch node
for a selected sample in the batch (default: first sample).
"""
def recurse(node, depth=0):
prefix = " " * depth
def mahsa:
print;lll}