diff --git a/bit_tf2/train.py b/bit_tf2/train.py
index 5292887..f95d7f7 100644
--- a/bit_tf2/train.py
+++ b/bit_tf2/train.py
@@ -27,6 +27,10 @@ import bit_tf2.models as models
 import input_pipeline_tf2_or_jax as input_pipeline
 
 
+# Enable memory growth on every visible GPU (no-op if none) to prevent OOM.
+for gpu in tf.config.experimental.list_physical_devices('GPU'):
+  tf.config.experimental.set_memory_growth(gpu, True)
+
 def reshape_for_keras(features, batch_size, crop_size):
   features["image"] = tf.reshape(features["image"], (batch_size, crop_size, crop_size, 3))
 
@@ -140,7 +144,7 @@ def main(args):
 
   for epoch, accu in enumerate(history.history['val_accuracy']):
     logger.info(
-        f'Step: {epoch * args.eval_every}, '
+        f'Step: {epoch * steps_per_epoch}, '
         f'Test accuracy: {accu:0.3f}')
 
 