-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathconfig.py
More file actions
124 lines (98 loc) · 3.95 KB
/
config.py
File metadata and controls
124 lines (98 loc) · 3.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from ast import literal_eval
import torch
class CfgNode:
    """A lightweight, attribute-style configuration container inspired by yacs.

    Keyword arguments given to the constructor become instance attributes.
    Attribute values may themselves be ``CfgNode`` instances, which produces a
    nested config whose leaves can be addressed with dotted keys in
    :meth:`merge_from_args`.
    """

    # TODO: convert to subclass from a dict like in yacs?
    # TODO: implement freezing to prevent shooting of own foot
    # TODO: additional existence/override checks when reading/writing params?

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __str__(self):
        return self._str_helper(0)

    def _str_helper(self, indent):
        """Render this node as ``key: value`` lines, recursing into sub-nodes.

        Args:
            indent: nesting depth; each level is indented by 4 more spaces.

        Returns:
            The rendered text, one ``\\n``-terminated line per leaf/sub-node key.
        """
        pad = ' ' * (indent * 4)
        parts = []
        for k, v in self.__dict__.items():
            if isinstance(v, CfgNode):
                parts.append("%s%s:\n" % (pad, k))
                # the recursive call already indents every one of its lines at
                # depth indent + 1; padding its output again here would shift
                # only the first line and misalign deeper nesting
                parts.append(v._str_helper(indent + 1))
            else:
                parts.append("%s%s: %s\n" % (pad, k, v))
        return "".join(parts)

    def to_dict(self):
        """Return a plain-dict copy of the config, converting sub-nodes recursively."""
        return {k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items()}

    def merge_from_dict(self, d):
        """Shallow-update attributes from mapping ``d``.

        Existing attributes with the same name are replaced wholesale; a dict
        value is stored as-is and is not merged into a nested ``CfgNode``.
        """
        self.__dict__.update(d)

    def merge_from_args(self, args):
        """Override existing config attributes from command-line style strings.

        Each element of ``args`` (e.g. ``sys.argv[1:]``) must have the form
        ``--arg=value``; ``arg`` may use ``.`` to reach nested sub-nodes, e.g.
        ``--model.n_layer=10 --trainer.batch_size=32``.

        ``value`` is parsed with ``ast.literal_eval`` so that numbers, lists,
        booleans, ``None`` etc. become Python objects; anything that is not a
        valid Python literal is kept as the raw string.

        Raises:
            AssertionError: if an arg is not of the form ``--arg=value`` or the
                target leaf attribute does not already exist.
            AttributeError: if an intermediate component of a dotted key is
                missing.
        """
        for arg in args:
            # split on the first '=' only, so values may themselves contain '='
            keyval = arg.split('=', 1)
            assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg
            key, val = keyval
            # translate val into a python object when it is a valid literal.
            # literal_eval raises ValueError for bare names (e.g. "adam"),
            # SyntaxError for text such as "/tmp/out dir" or "a=b", and
            # TypeError for e.g. unhashable dict keys -- in every such case
            # the value is meant as a plain string
            try:
                val = literal_eval(val)
            except (ValueError, TypeError, SyntaxError):
                pass
            # find the appropriate object to insert the attribute into
            assert key.startswith('--')
            key = key[2:]  # strip the '--'
            keys = key.split('.')
            obj = self
            for k in keys[:-1]:
                obj = getattr(obj, k)
            leaf_key = keys[-1]
            # only existing attributes may be overridden (guards against typos)
            assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config"
            print("command line overwriting config attribute %s with %s" % (key, val))
            setattr(obj, leaf_key, val)
def _gpt_config(in_dim, max_seq_length):
    """Build a fresh GPT ``CfgNode`` shared by the presets below.

    Args:
        in_dim: input feature width; ``out_dim`` is set to the same value.
        max_seq_length: value stored as ``max_seq_length``.

    Returns:
        A new ``CfgNode``. Every other hyperparameter is identical across presets.
    """
    C = CfgNode()
    # either model_type or (n_layer, n_head, n_embd) must be given in the
    # config (per the original note); both are provided here
    C.model_type = 'gpt'
    C.n_layer = 2
    C.n_head = 1
    C.n_embd = 16
    # we probably don't need to project inputs to a higher space, but this
    # transformer can do so
    C.in_dim = in_dim
    C.out_dim = in_dim  # output width mirrors the input width
    # NOTE(review): the author believed vocab_size/block_size are probably not
    # needed for this use -- TODO confirm before removing
    C.vocab_size = None
    C.block_size = 10
    # dropout hyperparameters (all set to 0)
    C.embd_pdrop = 0
    C.resid_pdrop = 0
    C.attn_pdrop = 0
    C.max_seq_length = max_seq_length
    return C


def get_default_config():
    """Return the default preset: in_dim = out_dim = 2, max_seq_length = 6123."""
    return _gpt_config(in_dim=2, max_seq_length=6123)


def linreg_config():
    """Return the ``linreg`` preset: in_dim = out_dim = 1, max_seq_length = 130.

    Otherwise identical to :func:`get_default_config`.
    """
    # 65 * 2 kept in the original's factored form; meaning of the factors
    # is not stated in this file -- TODO document
    return _gpt_config(in_dim=1, max_seq_length=65 * 2)