100 changes: 53 additions & 47 deletions returnn/tf/layers/base.py
@@ -161,7 +161,7 @@ def __init__(self, name, network, output=None, n_out=NotSpecified, out_type=None
self.collocate_with = collocate_with or []
self.post_init_hooks = [] # list of functions
self.sources = list(sources)
self.params = {} # type: typing.Dict[str,tf.Variable]
self.params = {} # type: typing.Dict[str,typing.Union[tf.Variable,tf.Tensor]]
self.saveable_param_replace = {} # type: typing.Dict[tf.Variable,typing.Union['tensorflow.python.training.saver.BaseSaverBuilder.SaveableObject',None]] # see get_saveable_params_dict() # nopep8
self.reuse_params = reuse_params
self.param_device = param_device
@@ -800,7 +800,7 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
:param bool|None saveable:
:param list[list[int]]|None axes_split_info: e.g. [[n],[n]*4] for LSTM matrices
:return: param
:rtype tf.Variable
:rtype tf.Variable|tf.Tensor
"""
_param = param
if isinstance(param, tf.Tensor):
@@ -810,54 +810,56 @@ def add_param(self, param, custom_update=None, trainable=None, saveable=None, ax
import re
possible_params = tf_compat.v1.get_collection(
tf_compat.v1.GraphKeys.GLOBAL_VARIABLES, scope=re.escape(self.get_absolute_name_scope_prefix()))
if not possible_params:
# None found. Just return as-is.
return param
all_ops = graph_editor.get_backward_walk_ops([param.op], inclusive=False, control_inputs=False)
all_1st_tensors = [op.outputs[0] for op in all_ops if len(op.outputs) == 1]
if possible_params:
all_ops = graph_editor.get_backward_walk_ops([param.op], inclusive=False, control_inputs=False)
all_1st_tensors = [op.outputs[0] for op in all_ops if len(op.outputs) == 1]
# noinspection PyProtectedMember
possible_params = [p for p in possible_params if tf_util.var_handle_or_ref(p) in all_1st_tensors]
if possible_params:
assert len(possible_params) == 1
param = possible_params[0]
assert isinstance(param, (tf.Variable, tf.Tensor))
if isinstance(param, tf.Variable):
if not self.trainable:
trainable_collection_ref = param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if param in trainable_collection_ref:
trainable_collection_ref.remove(param)
if trainable is None:
trainable = param in param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if saveable is None:
saveable = True
if custom_update:
assert trainable
custom_update.set_on_var(param)
if axes_split_info:
tf_util.set_param_axes_split_info(param, axes_split_info)
if not saveable:
self.saveable_param_replace[param] = None
if getattr(param, "RETURNN_layer", None) is None:
param.RETURNN_layer = self
if getattr(param, "RETURNN_updater_opts", None) is None and self.updater_opts.truth_value:
param.RETURNN_updater_opts = self.updater_opts
# Note that any further postprocessing on the parameter should not be done here,
# as we cannot guarantee that the result from this method is really used,
# e.g. when we use official TF code such as the official LSTM cell.
# The better way is to do it in self.var_creation_scope(), which also applies in those cases.

if getattr(_param, "_RETURNN_layer_map_name", None) is not None:
# Be explicit, take param_name directly from ReuseParams.variable_custom_getter
# noinspection PyProtectedMember
possible_params = [p for p in possible_params if tf_util.var_handle_or_ref(p) in all_1st_tensors]
if not possible_params:
# Not found. Just return as-is.
return param
assert len(possible_params) == 1
param = possible_params[0]
assert isinstance(param, tf.Variable)
if not self.trainable:
trainable_collection_ref = param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if param in trainable_collection_ref:
trainable_collection_ref.remove(param)
if trainable is None:
trainable = param in param.graph.get_collection_ref(tf_compat.v1.GraphKeys.TRAINABLE_VARIABLES)
if saveable is None:
saveable = True
if custom_update:
assert trainable
custom_update.set_on_var(param)
if axes_split_info:
tf_util.set_param_axes_split_info(param, axes_split_info)
if self.reuse_params:
name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
param_name = _param._RETURNN_layer_map_name
else:
name_scope_prefix = self.get_absolute_name_scope_prefix()
assert param.name
assert param.name[:len(name_scope_prefix)] == name_scope_prefix
assert param.name[-2:] == ":0"
param_name = param.name[len(name_scope_prefix):-2]
if self.reuse_params:
name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
else:
name_scope_prefix = self.get_absolute_name_scope_prefix()
assert param.name
assert param.name[:len(name_scope_prefix)] == name_scope_prefix
assert param.name[-2:] == ":0"
param_name = param.name[len(name_scope_prefix):-2]
if param_name not in self.params:
self.params[param_name] = param
else:
assert self.params[param_name] is param
if not saveable:
self.saveable_param_replace[param] = None
if getattr(param, "RETURNN_layer", None) is None:
param.RETURNN_layer = self
if getattr(param, "RETURNN_updater_opts", None) is None and self.updater_opts.truth_value:
param.RETURNN_updater_opts = self.updater_opts
# Note that any further postprocessing on the parameter should not be done here,
# as we cannot guarantee that the result from this method is really used,
# e.g. when we use official TF code such as the official LSTM cell.
# The better way is to do it in self.var_creation_scope(), which also applies in those cases.
assert self.params[param_name] is param
return _param

def set_param_values_by_dict(self, values_dict, session, ignore_wrong_shape=False, copy_param_mode=None):
@@ -1755,8 +1757,12 @@ def custom_getter(getter, name, *args, **kwargs):
assert name.startswith(abs_scope_prefix)
param_name = name[len(abs_scope_prefix):] # e.g. "W" (not "rec/W")
if self.custom_func:
return self.custom_func(
variable = self.custom_func(
base_layer=base_layer, reuse_layer=self.reuse_layer, name=param_name, getter=getter, full_name=name, **kwargs)
# The name of the variable created by custom_func might not match param_name.
# We store it here for LayerBase.add_param.
variable._RETURNN_layer_map_name = param_name
return variable
if self.param_map is not None:
if not self.auto_create_missing:
assert param_name in self.param_map
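The widened type of `self.params` (now `tf.Variable` or `tf.Tensor`) comes into play when a `reuse_params` `custom` function returns a derived tensor instead of an existing variable. Below is a minimal sketch of such a setup, not taken from this PR; the layer names, dimensions, and the transposed-weight tying are made up for illustration, mirroring the style of the new test further down:

```python
import tensorflow as tf
from returnn.config import Config
from returnn.tf.network import TFNetwork


def custom_transposed_w(reuse_layer, **kwargs):
  # Hypothetical: return a derived tf.Tensor (not a tf.Variable). add_param
  # then cannot resolve it back to a variable in the layer's own scope, so the
  # tensor itself can end up under "W" in the reusing layer's params dict.
  return tf.transpose(reuse_layer.params["W"])


config = Config({"extern_data": {"data": {"dim": 5}}})
with tf.Graph().as_default():
  net = TFNetwork(config=config)
  net.construct_from_dict({
    "l1": {"class": "linear", "activation": None, "n_out": 10},
    "l2": {"class": "linear", "activation": None, "with_bias": False, "n_out": 5, "from": "l1",
           "reuse_params": {"map": {"W": {"reuse_layer": "l1", "custom": custom_transposed_w}}}},
    "output": {"class": "copy", "from": "l2"}})
```

The new test `test_ReuseParams_different_names` below exercises the complementary case, where the custom function returns an existing variable (the self-attention layer's "QKV") under the mapped name "W".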
4 changes: 4 additions & 0 deletions returnn/tf/network.py
@@ -1158,6 +1158,8 @@ def get_params_list(self):
for layer in self._get_all_layers():
assert isinstance(layer, LayerBase)
for param_name, param in sorted(layer.params.items()):
if isinstance(param, tf.Tensor): # could happen with reuse_params
continue
assert isinstance(param, tf.Variable)
if param in ls: # could happen with reuse_params
continue
@@ -1205,6 +1207,8 @@ def get_trainable_params(self):
layer = self.layers[layer_name]
assert isinstance(layer, LayerBase)
for param_name, param in sorted(layer.params.items()):
if isinstance(param, tf.Tensor): # could happen with reuse_params
continue
assert isinstance(param, tf.Variable)
if param in trainable_vars_col:
ls.append(param)
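Continuing the hypothetical sketch above: tensor-valued params are skipped by the network-level parameter lists patched here, while remaining reachable through the owning layer:

```python
# Still inside the graph scope of the sketch above.
variables = net.get_params_list()
assert all(isinstance(p, tf.Variable) for p in variables)  # tf.Tensor entries are skipped

# The derived tensor remains reachable through the owning layer itself.
assert isinstance(net.layers["l2"].params["W"], tf.Tensor)
```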
33 changes: 30 additions & 3 deletions tests/test_TFNetworkLayer.py
@@ -1875,7 +1875,7 @@ def test_reuse_params_map_custom():
l1 = network.layers["l1"]
l2 = network.layers["output"]
assert_equal(set(l1.params.keys()), {"W"})
assert_equal(set(l2.params.keys()), {"b"})
assert_equal(set(l2.params.keys()), {"W", "b"})
assert_equal(set(network.get_trainable_params()), {l1.params["W"], l2.params["b"]})


@@ -1904,7 +1904,7 @@ def test_reuse_params_map_custom_rev():
network.construct_from_dict(config.typed_dict["network"])
l1 = network.layers["l1"]
l2 = network.layers["output"]
assert_equal(set(l1.params.keys()), {"b"})
assert_equal(set(l1.params.keys()), {"W", "b"})
assert_equal(set(l2.params.keys()), {"W"})
assert_equal(set(network.get_trainable_params()), {l2.params["W"], l1.params["b"]})

@@ -1968,7 +1968,7 @@ def test_reuse_params_map_custom_dep_loop():
assert_equal(set(train_rec_layer.cell.input_layers_moved_out), {"output", "target_embed"})
assert_equal(set(train_rec_layer.cell.output_layers_moved_out), {"output_prob", "readout", "readout_in"})
assert isinstance(train_rec_layer.cell.output_layers_net, TFNetwork)
assert_equal(set(train_rec_layer.cell.output_layers_net.layers["output_prob"].params.keys()), {"b"})
assert_equal(set(train_rec_layer.cell.output_layers_net.layers["output_prob"].params.keys()), {"W", "b"})
with make_scope() as session:
print("Construct for search")
search_net = TFNetwork(config=config, train_flag=False, eval_flag=True, search_flag=True)
@@ -3059,6 +3059,33 @@ def make_feed_dict(seq_len=10):
session.run(network.get_default_output_layer().output.placeholder, feed_dict=feed)


def test_ReuseParams_different_names():
n_batch, n_time, n_total, n_heads = 7, 3, 40, 2
assert n_total % n_heads == 0
config = Config({
"extern_data": {"data": {"dim": n_total}},
"debug_print_layer_output_template": True,
})
with make_scope():
net = TFNetwork(config=config)

def custom(reuse_layer, *args, **kwargs):
return reuse_layer.params['QKV']

net.construct_from_dict({
"self_att": {"class": "self_attention", "num_heads": n_heads, "total_key_dim": n_total, "n_out": n_total},
"linear": {"class": "linear", "n_out": n_total * 3, "activation": None, "with_bias": False,
"reuse_params": {
"auto_create_missing": False, # should not matter as we do not have any bias
"map": {"W": {"reuse_layer": "self_att", "custom": custom}}}},
"output": {"class": "copy", "from": "linear"}})

self_att = net.get_layer("self_att")
linear = net.get_layer("linear")
assert list(self_att.params.keys()) == ["QKV"] and list(linear.params.keys()) == ["W"]
assert self_att.params["QKV"] is linear.params["W"]


def test_LossAsIs_custom_dim():
config = Config()
config.update({