Skip to content

Commit db2f9cf

Browse files
init2winit Team authored and copybara-github committed
internal
PiperOrigin-RevId: 872213481
1 parent 26eaa80 commit db2f9cf

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

init2winit/hyperparameters.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -150,6 +150,8 @@ def build_hparams(model_name,
150 150
if merged_dict.get('use_shallue_label_smoothing', False):
151 151
num_classes = merged_dict['output_shape'][-1]
152 152
merged_dict['label_smoothing'] *= num_classes / float(num_classes - 1)
153+
if 'compile_init_on_cpu' not in merged_dict:
154+
merged_dict['compile_init_on_cpu'] = False
153 155

154 156
merged = config_dict.ConfigDict(merged_dict)
155 157
merged.lock()

init2winit/model_lib/base_model.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -200,10 +200,18 @@ def initialize(self, initializer, hps, rng, metrics_logger):
200 200
# construction.
201 201
# We initialize model params on host to avoid memory issues.
202 202

203+
compile_init_on_cpu = hps.get('compile_init_on_cpu', False)
204+
jit_kwargs = {}
205+
if compile_init_on_cpu:
206+
jit_kwargs['backend'] = 'cpu'
207+
logging.info(
208+
'Compiling model init on %s.',
209+
'cpu' if compile_init_on_cpu else 'device',
210+
)
203 211
start_time = time.time()
204 212
model_init_fn = jax.jit(
205-
functools.partial(self.flax_module.init, train=False),
206-
backend='cpu')
213+
functools.partial(self.flax_module.init, train=False), **jit_kwargs
214+
)
207 215

208 216
init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
209 217
*fake_input_batch)

0 commit comments

Comments
 (0)