@@ -63,9 +63,6 @@ def __init__(
6363 training_metrics_config = None ,
6464 callback_configs = None ,
6565 external_checkpoint_path = None ,
66- dataset_meta_data = None ,
67- loss_name = None ,
68- metrics_name = None ,
6966 data_selector = None ,
7067 training_algorithm_class = training_algorithm .OptaxTrainingAlgorithm ,
7168 ):
@@ -125,22 +122,10 @@ def __init__(
125122 external_checkpoint_path: (str) If this argument is set, we will load the
126123 optimizer_state, params, batch_stats, and training_metrics from the
127124 checkpoint at this location.
128- dataset_meta_data: meta_data about the dataset. It is not directly used in
129- the base trainer. Users are expected to overwrite the initialization
130- method in a customimzed trainer to access it.
131- loss_name: name of the loss function. Not directly used in base trainer.
132- Users are expected to overwrite the initialization method in a
133- customimzed trainer to access it.
134- metrics_name: Not directly used in the base trainer. Users are expected to
135- overwrite the initialization method in a customimzed trainer to access
136- it.
137125 data_selector: data selection function returned by
138126 datasets.get_data_selector.
139127 training_algorithm_class: Class of training algorithm to use.
140128 """
141- del dataset_meta_data
142- del loss_name
143- del metrics_name
144129 self ._train_dir = train_dir
145130 self ._model = model
146131 self ._dataset_builder = dataset_builder
0 commit comments