diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py
index 53c423a292..e38fd8c5c4 100644
--- a/returnn/tf/layers/base.py
+++ b/returnn/tf/layers/base.py
@@ -161,7 +161,7 @@ def __init__(self, name, network, output=None, n_out=NotSpecified, out_type=None
     self.collocate_with = collocate_with or []
     self.post_init_hooks = []  # list of functions
     self.sources = list(sources)
-    self.params = {}  # type: typing.Dict[str,tf.Variable]
+    self.params = {}  # type: typing.Dict[str,typing.Union[tf.Variable,tf.Tensor]]
     self.saveable_param_replace = {}  # type: typing.Dict[tf.Variable,typing.Union['tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject',None]]  # see get_saveable_params_dict()  # nopep8
     self.reuse_params = reuse_params
     self.param_device = param_device
@@ -800,7 +800,7 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
     :param bool|None saveable:
     :param list[list[int]]|None axes_split_info: e.g. [[n],[n]*4] for LSTM matrices
     :return: param
-    :rtype tf.Variable
+    :rtype tf.Variable|tf.Tensor
     """
     _param = param
     if isinstance(param, tf.Tensor):
@@ -810,54 +810,56 @@
       import re
       possible_params = tf_compat.v1.get_collection(
         tf_compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=re.escape(self.get_absolute_name_scope_prefix()))
-      if not possible_params:
-        # None found. Just return as-is.
-        return param
-      all_ops = graph_editor.get_backward_walk_ops([param.op], inclusive=False, control_inputs=False)
-      all_1st_tensors = [op.outputs[0] for op in all_ops if len(op.outputs) == 1]
+      if possible_params:
+        all_ops = graph_editor.get_backward_walk_ops([param.op], inclusive=False, control_inputs=False)
+        all_1st_tensors = [op.outputs[0] for op in all_ops if len(op.outputs) == 1]
+        # noinspection PyProtectedMember
+        possible_params = [p for p in possible_params if tf_util.var_handle_or_ref(p) in all_1st_tensors]
+        if possible_params:
+          assert len(possible_params) == 1
+          param = possible_params[0]
+    assert isinstance(param, (tf.Variable, tf.Tensor))
+    if isinstance(param, tf.Variable):
+      if not self.trainable:
+        trainable_collection_ref = param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
+        if param in trainable_collection_ref:
+          trainable_collection_ref.remove(param)
+      if trainable is None:
+        trainable = param in param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
+      if saveable is None:
+        saveable = True
+      if custom_update:
+        assert trainable
+        custom_update.set_on_var(param)
+      if axes_split_info:
+        tf_util.set_param_axes_split_info(param, axes_split_info)
+      if not saveable:
+        self.saveable_param_replace[param] = None
+      if getattr(param, "RETURNN_layer", None) is None:
+        param.RETURNN_layer = self
+      if getattr(param, "RETURNN_updater_opts", None) is None and self.updater_opts.truth_value:
+        param.RETURNN_updater_opts = self.updater_opts
+      # Note that any further postprocessing on the parameter should not be done here,
+      # as we cannot guarantee that the result from this method is really used,
+      # e.g. when we use official TF code such as the official LSTM cell.
+      # The better way is to do it in self.var_creation_scope(), which also applies in those cases.
+
+    if getattr(_param, "_RETURNN_layer_map_name", None) is not None:
+      # Be explicit, take param_name directly from ReuseParams.variable_custom_getter
       # noinspection PyProtectedMember
-      possible_params = [p for p in possible_params if tf_util.var_handle_or_ref(p) in all_1st_tensors]
-      if not possible_params:
-        # Not found. Just return as-is.
-        return param
-      assert len(possible_params) == 1
-      param = possible_params[0]
-      assert isinstance(param, tf.Variable)
-    if not self.trainable:
-      trainable_collection_ref = param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
-      if param in trainable_collection_ref:
-        trainable_collection_ref.remove(param)
-    if trainable is None:
-      trainable = param in param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
-    if saveable is None:
-      saveable = True
-    if custom_update:
-      assert trainable
-      custom_update.set_on_var(param)
-    if axes_split_info:
-      tf_util.set_param_axes_split_info(param, axes_split_info)
-    if self.reuse_params:
-      name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
+      param_name = _param._RETURNN_layer_map_name
     else:
-      name_scope_prefix = self.get_absolute_name_scope_prefix()
-    assert param.name
-    assert param.name[:len(name_scope_prefix)] == name_scope_prefix
-    assert param.name[-2:] == ":0"
-    param_name = param.name[len(name_scope_prefix):-2]
+      if self.reuse_params:
+        name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
+      else:
+        name_scope_prefix = self.get_absolute_name_scope_prefix()
+      assert param.name
+      assert param.name[:len(name_scope_prefix)] == name_scope_prefix
+      assert param.name[-2:] == ":0"
+      param_name = param.name[len(name_scope_prefix):-2]
     if param_name not in self.params:
       self.params[param_name] = param
-    else:
-      assert self.params[param_name] is param
-    if not saveable:
-      self.saveable_param_replace[param] = None
-    if getattr(param, "RETURNN_layer", None) is None:
-      param.RETURNN_layer = self
-    if getattr(param, "RETURNN_updater_opts", None) is None and self.updater_opts.truth_value:
-      param.RETURNN_updater_opts = self.updater_opts
-    # Note that any further postprocessing on the parameter should not be done here,
-    # as we cannot guarantee that the result from this method is really used,
-    # e.g. when we use official TF code such as the official LSTM cell.
-    # The better way is to do it in self.var_creation_scope(), which also applies in those cases.
+    assert self.params[param_name] is param
     return _param
 
   def set_param_values_by_dict(self, values_dict, session, ignore_wrong_shape=False, copy_param_mode=None):
@@ -1755,8 +1757,12 @@ def custom_getter(getter, name, *args, **kwargs):
       assert name.startswith(abs_scope_prefix)
       param_name = name[len(abs_scope_prefix):]  # e.g. "W" (not "rec/W")
       if self.custom_func:
-        return self.custom_func(
+        variable = self.custom_func(
           base_layer=base_layer, reuse_layer=self.reuse_layer, name=param_name, getter=getter, full_name=name, **kwargs)
+        # The name of the variable created by custom_func might not match param_name.
+        # We store it here for LayerBase.add_param.
+        variable._RETURNN_layer_map_name = param_name
+        return variable
       if self.param_map is not None:
         if not self.auto_create_missing:
           assert param_name in self.param_map
diff --git a/returnn/tf/network.py b/returnn/tf/network.py
index 73f7e94778..16d5e0cbe1 100644
--- a/returnn/tf/network.py
+++ b/returnn/tf/network.py
@@ -1158,6 +1158,8 @@ def get_params_list(self):
     for layer in self._get_all_layers():
       assert isinstance(layer, LayerBase)
       for param_name, param in sorted(layer.params.items()):
+        if isinstance(param, tf.Tensor):  # could happen with reuse_param
+          continue
         assert isinstance(param, tf.Variable)
         if param in ls:  # could happen with reuse_params
           continue
@@ -1205,6 +1207,8 @@ def get_trainable_params(self):
       layer = self.layers[layer_name]
       assert isinstance(layer, LayerBase)
       for param_name, param in sorted(layer.params.items()):
+        if isinstance(param, tf.Tensor):  # could happen with reuse_params
+          continue
         assert isinstance(param, tf.Variable)
         if param in trainable_vars_col:
           ls.append(param)
diff --git a/tests/test_TFNetworkLayer.py b/tests/test_TFNetworkLayer.py
index 47a9a71ac9..62c10ddb44 100644
--- a/tests/test_TFNetworkLayer.py
+++ b/tests/test_TFNetworkLayer.py
@@ -1875,7 +1875,7 @@ def test_reuse_params_map_custom():
     l1 = network.layers["l1"]
     l2 = network.layers["output"]
     assert_equal(set(l1.params.keys()), {"W"})
-    assert_equal(set(l2.params.keys()), {"b"})
+    assert_equal(set(l2.params.keys()), {"W", "b"})
     assert_equal(set(network.get_trainable_params()), {l1.params["W"], l2.params["b"]})
 
 
@@ -1904,7 +1904,7 @@ def test_reuse_params_map_custom_rev():
     network.construct_from_dict(config.typed_dict["network"])
     l1 = network.layers["l1"]
     l2 = network.layers["output"]
-    assert_equal(set(l1.params.keys()), {"b"})
+    assert_equal(set(l1.params.keys()), {"W", "b"})
     assert_equal(set(l2.params.keys()), {"W"})
     assert_equal(set(network.get_trainable_params()), {l2.params["W"], l1.params["b"]})
 
@@ -1968,7 +1968,7 @@ def test_reuse_params_map_custom_dep_loop():
     assert_equal(set(train_rec_layer.cell.input_layers_moved_out), {"output", "target_embed"})
     assert_equal(set(train_rec_layer.cell.output_layers_moved_out), {"output_prob", "readout", "readout_in"})
     assert isinstance(train_rec_layer.cell.output_layers_net, TFNetwork)
-    assert_equal(set(train_rec_layer.cell.output_layers_net.layers["output_prob"].params.keys()), {"b"})
+    assert_equal(set(train_rec_layer.cell.output_layers_net.layers["output_prob"].params.keys()), {"W", "b"})
   with make_scope() as session:
     print("Construct for search")
     search_net = TFNetwork(config=config, train_flag=False, eval_flag=True, search_flag=True)
@@ -3059,6 +3059,33 @@ def make_feed_dict(seq_len=10):
     session.run(network.get_default_output_layer().output.placeholder, feed_dict=feed)
 
 
+def test_ReuseParams_different_names():
+  n_batch, n_time, n_total, n_heads = 7, 3, 40, 2
+  assert n_total % n_heads == 0
+  config = Config({
+    "extern_data": {"data": {"dim": n_total}},
+    "debug_print_layer_output_template": True,
+  })
+  with make_scope():
+    net = TFNetwork(config=config)
+
+    def custom(reuse_layer, *args, **kwargs):
+      return reuse_layer.params['QKV']
+
+    net.construct_from_dict({
+      "self_att": {"class": "self_attention", "num_heads": n_heads, "total_key_dim": n_total, "n_out": n_total},
+      "linear": {"class": "linear", "n_out": n_total * 3, "activation": None, "with_bias": False,
+                 "reuse_params": {
+                   "auto_create_missing": False,  # should not matter as we do not have any bias
+                   "map": {"W": {"reuse_layer": "self_att", "custom": custom}}}},
+      "output": {"class": "copy", "from": "linear"}})
+
+    self_att = net.get_layer("self_att")
+    linear = net.get_layer("linear")
+    assert list(self_att.params.keys()) == ["QKV"] and list(linear.params.keys()) == ["W"]
+    assert self_att.params["QKV"] is linear.params["W"]
+
+
 def test_LossAsIs_custom_dim():
   config = Config()
   config.update({