|
57 | 57 | flags.DEFINE_string('trainer', 'standard', 'Name of the trainer to use.') |
58 | 58 | flags.DEFINE_string('model', 'fully_connected', 'Name of the model to train.') |
59 | 59 | flags.DEFINE_string('loss', 'cross_entropy', 'Loss function.') |
60 | | -flags.DEFINE_string('metrics', 'classification_metrics', |
61 | | - 'Metrics to be used for evaluation.') |
| 60 | +flags.DEFINE_string( |
| 61 | + 'algoperf_submission_name', '', 'AlgoPerf submission module lookup name.' |
| 62 | +) |
| 63 | +flags.DEFINE_string( |
| 64 | + 'metrics', 'classification_metrics', 'Metrics to be used for evaluation.' |
| 65 | +) |
62 | 66 | flags.DEFINE_string('initializer', 'noop', 'Must be in [noop, meta_init].') |
63 | | -flags.DEFINE_string('experiment_dir', None, |
64 | | - 'Path to save weights and other results. Each trial ' |
65 | | - 'directory will have path experiment_dir/worker_id/.') |
| 67 | +flags.DEFINE_string( |
| 68 | + 'experiment_dir', |
| 69 | + None, |
| 70 | + 'Path to save weights and other results. Each trial ' |
| 71 | + 'directory will have path experiment_dir/worker_id/.', |
| 72 | +) |
66 | 73 | flags.DEFINE_string('dataset', 'mnist', 'Which dataset to train on.') |
67 | 74 | flags.DEFINE_string('data_selector', 'noop', 'Which data selector to use.') |
68 | 75 | flags.DEFINE_integer('num_train_steps', None, 'The number of steps to train.') |
69 | 76 | flags.DEFINE_integer( |
70 | | - 'num_tf_data_prefetches', -1, 'The number of batches to to prefetch from ' |
71 | | - 'network to host at each step. Set to -1 for tf.data.AUTOTUNE.') |
| 77 | + 'num_tf_data_prefetches', |
| 78 | + -1, |
| 79 | + 'The number of batches to prefetch from ' |
| 80 | + 'network to host at each step. Set to -1 for tf.data.AUTOTUNE.', |
| 81 | +) |
72 | 82 | flags.DEFINE_integer( |
73 | | - 'num_device_prefetches', 0, 'The number of batches to to prefetch from ' |
74 | | - 'host to device at each step.') |
| 83 | + 'num_device_prefetches', |
| 84 | + 0, |
| 85 | + 'The number of batches to prefetch from host to device at each step.', |
| 86 | +) |
75 | 87 | flags.DEFINE_integer( |
76 | | - 'num_tf_data_map_parallel_calls', -1, 'The number of parallel calls to ' |
77 | | - 'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.' |
| 88 | + 'num_tf_data_map_parallel_calls', |
| 89 | + -1, |
| 90 | + 'The number of parallel calls to ' |
| 91 | + 'make from tf.data.map. Set to -1 for tf.data.AUTOTUNE.', |
78 | 92 | ) |
79 | 93 | flags.DEFINE_integer('eval_batch_size', None, 'Batch size for evaluation.') |
80 | 94 | flags.DEFINE_bool('eval_use_ema', None, 'If True evals will use ema of params.') |
81 | 95 | flags.DEFINE_integer( |
82 | | - 'eval_num_batches', None, |
| 96 | + 'eval_num_batches', |
| 97 | + None, |
83 | 98 | 'Number of batches for evaluation. Leave None to evaluate ' |
84 | | - 'on the entire validation and test set.') |
| 99 | + 'on the entire validation and test set.', |
| 100 | +) |
85 | 101 | flags.DEFINE_integer( |
86 | | - 'test_num_batches', None, |
| 102 | + 'test_num_batches', |
| 103 | + None, |
87 | 104 | 'Number of batches for eval on test set. Leave None to evaluate ' |
88 | | - 'on the entire test set.') |
89 | | -flags.DEFINE_integer('eval_train_num_batches', None, |
90 | | - 'Number of batches when evaluating on the training set.') |
| 105 | + 'on the entire test set.', |
| 106 | +) |
| 107 | +flags.DEFINE_integer( |
| 108 | + 'eval_train_num_batches', |
| 109 | + None, |
| 110 | + 'Number of batches when evaluating on the training set.', |
| 111 | +) |
91 | 112 | flags.DEFINE_integer('eval_frequency', 1000, 'Evaluate every k steps.') |
92 | 113 | flags.DEFINE_string( |
93 | | - 'hparam_overrides', '', 'JSON representation of a flattened dict of hparam ' |
| 114 | + 'hparam_overrides', |
| 115 | + '', |
| 116 | + 'JSON representation of a flattened dict of hparam ' |
94 | 117 | 'overrides. For nested dictionaries, the override key ' |
95 | | - 'should be specified as lr_hparams.base_lr.') |
| 118 | + 'should be specified as lr_hparams.base_lr.', |
| 119 | +) |
96 | 120 | flags.DEFINE_string( |
97 | | - 'callback_configs', '', 'JSON representation of a list of dictionaries ' |
98 | | - 'which specify general callbacks to be run during eval of training.') |
| 121 | + 'callback_configs', |
| 122 | + '', |
| 123 | + 'JSON representation of a list of dictionaries ' |
| 124 | + 'which specify general callbacks to be run during eval of training.', |
| 125 | +) |
99 | 126 | flags.DEFINE_list( |
100 | | - 'checkpoint_steps', [], 'List of steps to checkpoint the' |
| 127 | + 'checkpoint_steps', |
| 128 | + [], |
| 129 | + 'List of steps to checkpoint the' |
101 | 130 | ' model. The checkpoints will be saved in a separate' |
102 | 131 | 'directory train_dir/checkpoints. Note these checkpoints' |
103 | 132 | 'will be in addition to the normal checkpointing that' |
104 | | - 'occurs during training for preemption purposes.') |
105 | | -flags.DEFINE_string('external_checkpoint_path', None, |
106 | | - 'If this argument is set, the trainer will initialize' |
107 | | - 'the parameters, batch stats, optimizer state, and training' |
108 | | - 'metrics by loading them from the checkpoint at this path.') |
| 133 | + 'occurs during training for preemption purposes.', |
| 134 | +) |
| 135 | +flags.DEFINE_string( |
| 136 | + 'external_checkpoint_path', |
| 137 | + None, |
| 138 | + 'If this argument is set, the trainer will initialize ' |
| 139 | + 'the parameters, batch stats, optimizer state, and training ' |
| 140 | + 'metrics by loading them from the checkpoint at this path.', |
| 141 | +) |
109 | 142 |
|
110 | 143 | flags.DEFINE_string( |
111 | 144 | 'early_stopping_target_name', |
112 | 145 | None, |
113 | 146 | 'A string naming the metric to use to perform early stopping. If this ' |
114 | 147 | 'metric reaches the value `early_stopping_target_value`, training will ' |
115 | | - 'stop. Must include the dataset split (ex: validation/error_rate).') |
| 148 | + 'stop. Must include the dataset split (ex: validation/error_rate).', |
| 149 | +) |
116 | 150 | flags.DEFINE_float( |
117 | 151 | 'early_stopping_target_value', |
118 | 152 | None, |
119 | | - 'A float indicating the value at which to stop training.') |
| 153 | + 'A float indicating the value at which to stop training.', |
| 154 | +) |
120 | 155 | flags.DEFINE_enum( |
121 | 156 | 'early_stopping_mode', |
122 | 157 | None, |
@@ -198,6 +233,7 @@ def _run( |
198 | 233 | initializer_name, |
199 | 234 | model_name, |
200 | 235 | loss_name, |
| 236 | + algoperf_submission_name, |
201 | 237 | metrics_name, |
202 | 238 | num_train_steps, |
203 | 239 | experiment_dir, |
@@ -225,7 +261,8 @@ def _run( |
225 | 261 | hparam_file=hparam_file, |
226 | 262 | hparam_overrides=hparam_overrides, |
227 | 263 | input_pipeline_hps=input_pipeline_hps, |
228 | | - allowed_unrecognized_hparams=allowed_unrecognized_hparams) |
| 264 | + allowed_unrecognized_hparams=allowed_unrecognized_hparams, |
| 265 | + algoperf_submission_name=algoperf_submission_name) |
229 | 266 |
|
230 | 267 | # Note that one should never tune an RNG seed!!! The seed is only included in |
231 | 268 | # the hparams for convenience of running hparam trials with multiple seeds per |
@@ -358,6 +395,7 @@ def main(unused_argv): |
358 | 395 | initializer_name=FLAGS.initializer, |
359 | 396 | model_name=FLAGS.model, |
360 | 397 | loss_name=FLAGS.loss, |
| 398 | + algoperf_submission_name=FLAGS.algoperf_submission_name, |
361 | 399 | metrics_name=FLAGS.metrics, |
362 | 400 | num_train_steps=FLAGS.num_train_steps, |
363 | 401 | experiment_dir=experiment_dir, |
|
0 commit comments