Skip to content

Commit 4fbaba1

Browse files
sourabh2k15 authored and copybara-github committed
Internal
PiperOrigin-RevId: 750320849
1 parent 0775266 commit 4fbaba1

File tree

7 files changed

+793
-49
lines changed

7 files changed

+793
-49
lines changed

init2winit/hyperparameters.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ def build_hparams(model_name,
102102
hparam_file,
103103
hparam_overrides,
104104
input_pipeline_hps=None,
105-
allowed_unrecognized_hparams=None):
105+
allowed_unrecognized_hparams=None,
106+
algoperf_submission_name=None):
106107
"""Build experiment hyperparameters.
107108
108109
Args:
@@ -121,6 +122,7 @@ def build_hparams(model_name,
121122
hparams from an error to a warning can be useful when trying to tune using
122123
a shared search space over multiple workloads that don't all support the
123124
same set of hyperparameters.
125+
algoperf_submission_name: The name of the algoperf submission.
124126
125127
Returns:
126128
A ConfigDict of experiment hyperparameters.
@@ -163,6 +165,10 @@ def build_hparams(model_name,
163165
for key in ['opt_hparams', 'lr_hparams']:
164166
merged[key].unlock()
165167

168+
if algoperf_submission_name:
169+
with merged.unlocked():
170+
merged['algoperf_submission_name'] = algoperf_submission_name
171+
166172
if hparam_file:
167173
logging.info('Loading hparams from %s', hparam_file)
168174
with gfile.GFile(hparam_file, 'r') as f:

init2winit/main.py

Lines changed: 68 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -57,66 +57,101 @@
5757
flags.DEFINE_string('trainer', 'standard', 'Name of the trainer to use.')
5858
flags.DEFINE_string('model', 'fully_connected', 'Name of the model to train.')
5959
flags.DEFINE_string('loss', 'cross_entropy', 'Loss function.')
60-
flags.DEFINE_string('metrics', 'classification_metrics',
61-
'Metrics to be used for evaluation.')
60+
flags.DEFINE_string(
61+
'algoperf_submission_name', '', 'AlgoPerf submission module lookup name.'
62+
)
63+
flags.DEFINE_string(
64+
'metrics', 'classification_metrics', 'Metrics to be used for evaluation.'
65+
)
6266
flags.DEFINE_string('initializer', 'noop', 'Must be in [noop, meta_init].')
63-
flags.DEFINE_string('experiment_dir', None,
64-
'Path to save weights and other results. Each trial '
65-
'directory will have path experiment_dir/worker_id/.')
67+
flags.DEFINE_string(
68+
'experiment_dir',
69+
None,
70+
'Path to save weights and other results. Each trial '
71+
'directory will have path experiment_dir/worker_id/.',
72+
)
6673
flags.DEFINE_string('dataset', 'mnist', 'Which dataset to train on.')
6774
flags.DEFINE_string('data_selector', 'noop', 'Which data selector to use.')
6875
flags.DEFINE_integer('num_train_steps', None, 'The number of steps to train.')
6976
flags.DEFINE_integer(
70-
'num_tf_data_prefetches', -1, 'The number of batches to to prefetch from '
71-
'network to host at each step. Set to -1 for tf.data.AUTOTUNE.')
77+
'num_tf_data_prefetches',
78+
-1,
79+
'The number of batches to prefetch from '
80+
'network to host at each step. Set to -1 for tf.data.AUTOTUNE.',
81+
)
7282
flags.DEFINE_integer(
73-
'num_device_prefetches', 0, 'The number of batches to to prefetch from '
74-
'host to device at each step.')
83+
'num_device_prefetches',
84+
0,
85+
'The number of batches to to prefetch from host to device at each step.',
86+
)
7587
flags.DEFINE_integer(
76-
'num_tf_data_map_parallel_calls', -1, 'The number of parallel calls to '
77-
'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.'
88+
'num_tf_data_map_parallel_calls',
89+
-1,
90+
'The number of parallel calls to '
91+
'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.',
7892
)
7993
flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.')
8094
flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.')
8195
flags.DEFINE_integer(
82-
'eval_num_batches', None,
96+
'eval_num_batches',
97+
None,
8398
'Number of batches for evaluation. Leave None to evaluate '
84-
'on the entire validation and test set.')
99+
'on the entire validation and test set.',
100+
)
85101
flags.DEFINE_integer(
86-
'test_num_batches', None,
102+
'test_num_batches',
103+
None,
87104
'Number of batches for eval on test set. Leave None to evaluate '
88-
'on the entire test set.')
89-
flags.DEFINE_integer('eval_train_num_batches', None,
90-
'Number of batches when evaluating on the training set.')
105+
'on the entire test set.',
106+
)
107+
flags.DEFINE_integer(
108+
'eval_train_num_batches',
109+
None,
110+
'Number of batches when evaluating on the training set.',
111+
)
91112
flags.DEFINE_integer('eval_frequency', 1000, 'Evaluate every k steps.')
92113
flags.DEFINE_string(
93-
'hparam_overrides', '', 'JSON representation of a flattened dict of hparam '
114+
'hparam_overrides',
115+
'',
116+
'JSON representation of a flattened dict of hparam '
94117
'overrides. For nested dictionaries, the override key '
95-
'should be specified as lr_hparams.base_lr.')
118+
'should be specified as lr_hparams.base_lr.',
119+
)
96120
flags.DEFINE_string(
97-
'callback_configs', '', 'JSON representation of a list of dictionaries '
98-
'which specify general callbacks to be run during eval of training.')
121+
'callback_configs',
122+
'',
123+
'JSON representation of a list of dictionaries '
124+
'which specify general callbacks to be run during eval of training.',
125+
)
99126
flags.DEFINE_list(
100-
'checkpoint_steps', [], 'List of steps to checkpoint the'
127+
'checkpoint_steps',
128+
[],
129+
'List of steps to checkpoint the'
101130
' model. The checkpoints will be saved in a separate'
102131
'directory train_dir/checkpoints. Note these checkpoints'
103132
'will be in addition to the normal checkpointing that'
104-
'occurs during training for preemption purposes.')
105-
flags.DEFINE_string('external_checkpoint_path', None,
106-
'If this argument is set, the trainer will initialize'
107-
'the parameters, batch stats, optimizer state, and training'
108-
'metrics by loading them from the checkpoint at this path.')
133+
'occurs during training for preemption purposes.',
134+
)
135+
flags.DEFINE_string(
136+
'external_checkpoint_path',
137+
None,
138+
'If this argument is set, the trainer will initialize'
139+
'the parameters, batch stats, optimizer state, and training'
140+
'metrics by loading them from the checkpoint at this path.',
141+
)
109142

110143
flags.DEFINE_string(
111144
'early_stopping_target_name',
112145
None,
113146
'A string naming the metric to use to perform early stopping. If this '
114147
'metric reaches the value `early_stopping_target_value`, training will '
115-
'stop. Must include the dataset split (ex: validation/error_rate).')
148+
'stop. Must include the dataset split (ex: validation/error_rate).',
149+
)
116150
flags.DEFINE_float(
117151
'early_stopping_target_value',
118152
None,
119-
'A float indicating the value at which to stop training.')
153+
'A float indicating the value at which to stop training.',
154+
)
120155
flags.DEFINE_enum(
121156
'early_stopping_mode',
122157
None,
@@ -198,6 +233,7 @@ def _run(
198233
initializer_name,
199234
model_name,
200235
loss_name,
236+
algoperf_submission_name,
201237
metrics_name,
202238
num_train_steps,
203239
experiment_dir,
@@ -225,7 +261,8 @@ def _run(
225261
hparam_file=hparam_file,
226262
hparam_overrides=hparam_overrides,
227263
input_pipeline_hps=input_pipeline_hps,
228-
allowed_unrecognized_hparams=allowed_unrecognized_hparams)
264+
allowed_unrecognized_hparams=allowed_unrecognized_hparams,
265+
algoperf_submission_name=algoperf_submission_name)
229266

230267
# Note that one should never tune an RNG seed!!! The seed is only included in
231268
# the hparams for convenience of running hparam trials with multiple seeds per
@@ -358,6 +395,7 @@ def main(unused_argv):
358395
initializer_name=FLAGS.initializer,
359396
model_name=FLAGS.model,
360397
loss_name=FLAGS.loss,
398+
algoperf_submission_name=FLAGS.algoperf_submission_name,
361399
metrics_name=FLAGS.metrics,
362400
num_train_steps=FLAGS.num_train_steps,
363401
experiment_dir=experiment_dir,
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# coding=utf-8
2+
# Copyright 2024 The init2winit Authors.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Init2winit workload."""
17+
18+
from init2winit.experiments.mlcommons.workloads import mlcommons_targets
19+
from init2winit.experiments.mlcommons.workloads import mlcommons_workload_info
20+
from init2winit.trainer_lib import spec
21+
22+
23+
class Init2winitWorkload(spec.Workload):
  """Wraps an init2winit model and hparams as an MLCommons `spec.Workload`.

  The instance is populated via `initialize` (not a constructor); the
  per-workload step budget and validation target metric come from the
  `mlcommons_workload_info` / `mlcommons_targets` lookup tables.
  """

  def initialize(self, model, hps):
    """Stores the model and hyperparameters this workload wraps."""
    self._model = model
    self._hps = hps

  @property
  def workload_name(self):
    """Workload key: explicit `hps.workload_name`, else '<dataset>_<model>'."""
    hps = self._hps
    # Fall back to a name derived from the dataset/model pair when no
    # explicit workload name was configured.
    name = hps.workload_name or hps.dataset + '_' + hps.model
    # Record the most recently computed name (recomputed on every access).
    self._workload_name = name
    return name

  @property
  def param_shapes(self):
    """Shapes of the wrapped model's parameters."""
    return self._model.param_shapes

  @property
  def model_params_types(self):
    """Types of the wrapped model's parameters."""
    return self._model.param_types

  @property
  def step_hint(self):
    """Per-workload training step budget from the MLCommons lookup table.

    Raises:
      ValueError: If the workload is not present in the lookup table.
    """
    steps = mlcommons_workload_info.num_train_steps
    if self.workload_name not in steps:
      raise ValueError(
          f'Workload {self.workload_name} not found in num_train_steps.')
    return steps[self.workload_name]

  @property
  def target_metric_name(self):
    """Name of the validation metric used as this workload's target.

    Raises:
      ValueError: If the workload is not present in the targets table.
    """
    targets = mlcommons_targets.validation_targets
    if self.workload_name not in targets:
      raise ValueError(
          f'Workload {self.workload_name} not found in validation targets.')
    return targets[self.workload_name]['metric']

0 commit comments

Comments
 (0)