Description
As discussed in #446, `ReuseParams.get_absolute_name_scope_prefix` does not infer the name of a reused parameter correctly if a `custom_func` is used.
I have added a simple test case in #448, where I share the `W` parameter of a `LinearLayer` with the `QKV` parameter of a `SelfAttentionLayer`:
"self_att": {"class": "self_attention", "num_heads": n_heads, "total_key_dim": n_total, "n_out": n_total},
"linear": {"class": "linear", "n_out": n_total * 3, "activation": None, "with_bias": False,
"reuse_params": {
"auto_create_missing": False, # should not matter as we do not have any bias
"map": {"W": {"reuse_layer": "self_att", "custom": custom}}}}with
```python
def custom(reuse_layer, *args, **kwargs):
  return reuse_layer.params['QKV']
```

This gives the error:

```
File "/home/runner/work/returnn/returnn/returnn/tf/network.py", line 703, in _create_layer
line: layer = layer_class(**layer_desc)
locals:
layer = <not found>
layer_class = <local> <class 'returnn.tf.layers.basic.LinearLayer'>
layer_desc = <local> {'n_out': 40, 'activation': None, 'with_bias': False, 'reuse_params': <ReuseParams reuse_layer None, map {'W': <ReuseParams reuse_layer <SelfAttentionLayer 'self_att' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>, map None>}>, 'sources': [<SourceLayer '..., len = 8
File "/home/runner/work/returnn/returnn/returnn/tf/layers/basic.py", line 1430, in __init__
line: weights = self.add_param(tf_compat.v1.get_variable(
name="W", shape=weights_shape, dtype=tf.float32, initializer=fwd_weights_initializer))
locals:
weights = <not found>
self = <local> <LinearLayer 'linear' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>
self.add_param = <local> <bound method LayerBase.add_param of <LinearLayer 'linear' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>>
tf_compat = <global> <module 'returnn.tf.compat' from '/home/runner/work/returnn/returnn/returnn/tf/compat.py'>
tf_compat.v1 = <global> <module 'tensorflow._api.v2.compat.v1' from '/home/runner/.local/lib/python3.7/site-packages/tensorflow/_api/v2/compat/v1/__init__.py'>
tf_compat.v1.get_variable = <global> <function get_variable at 0x7fe7e9d20cb0>
name = <not found>
shape = <not found>
weights_shape = <local> (40, 40)
dtype = <not found>
tf = <global> <module 'tensorflow' from '/home/runner/.local/lib/python3.7/site-packages/tensorflow/__init__.py'>
tf.float32 = <global> tf.float32
initializer = <not found>
fwd_weights_initializer = <local> <tensorflow.python.ops.init_ops.GlorotUniform object at 0x7fe7dc3df790>
File "/home/runner/work/returnn/returnn/returnn/tf/layers/base.py", line 840, in add_param
line: name_scope_prefix = self.reuse_params.get_absolute_name_scope_prefix(base_layer=self, param=param)
locals:
name_scope_prefix = <not found>
self = <local> <LinearLayer 'linear' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>
self.reuse_params = <local> <ReuseParams reuse_layer None, map {'W': <ReuseParams reuse_layer <SelfAttentionLayer 'self_att' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>, map None>}>
self.reuse_params.get_absolute_name_scope_prefix = <local> <bound method ReuseParams.get_absolute_name_scope_prefix of <ReuseParams reuse_layer None, map {'W': <ReuseParams reuse_layer <SelfAttentionLayer 'self_att' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>, map None>}>>
base_layer = <not found>
param = <local> <tf.Variable 'self_att/QKV:0' shape=(40, 120) dtype=float32>
File "/home/runner/work/returnn/returnn/returnn/tf/layers/base.py", line 1714, in get_absolute_name_scope_prefix
line: assert self.auto_create_missing
locals:
self = <local> <ReuseParams reuse_layer None, map {'W': <ReuseParams reuse_layer <SelfAttentionLayer 'self_att' out_type=Data(shape=(None, 40), batch_shape_meta=[B,T|'time:var:extern_data:data',F|40])>, map None>}>
self.auto_create_missing = <local> False
AssertionError:
```
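
The locals above already show the mismatch: `add_param` runs on the `linear` layer, but the variable it receives is `self_att/QKV:0`, so the layer-local name `W` cannot be derived from the variable's scope. Here is a small, self-contained illustration of that inference; `infer_map_name` is a hypothetical helper that only models the prefix-stripping idea, not actual RETURNN code:

```python
# Simplified, hypothetical model of the name inference add_param attempts:
# derive the layer-local parameter name from the variable's absolute name
# by stripping the owning layer's name-scope prefix.
def infer_map_name(var_absolute_name, layer_scope_prefix):
  name = var_absolute_name.split(":")[0]  # drop the ":0" suffix
  if name.startswith(layer_scope_prefix):
    return name[len(layer_scope_prefix):]  # e.g. "linear/W" -> "W"
  return None  # name cannot be recovered from the variable


# Values as they appear in the locals of the traceback above:
print(infer_map_name("linear/W:0", "linear/"))      # "W"  -- the plain case works
print(infer_map_name("self_att/QKV:0", "linear/"))  # None -- reused param, name is lost
```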
As @albertz mentioned in #446:
> The bug is somewhat clear. When you get into `add_param`, you have already gotten the param correctly. That is correctly done in `tf.get_variable` via the variable scope (via `var_creation_scope`). However, the task of `add_param` is just to add it to `self.params`, under the correct name (`W` here). The code for that is somewhat complicated, because we try to infer the name from the created variable.
>
> I think this needs to be changed, namely the whole name-inference logic for the case where we do something custom with the parameter. We can catch that in `var_creation_scope`. We could still leave the (simple) default case as it is. But whenever we do something custom there, we could catch the name from a custom getter, store it in `param._RETURNN_layer_map_name` or so, and then read it in `add_param`.
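
A rough sketch of how that suggestion could look; everything below is hypothetical (in particular `resolve_reused_param` and the simplified `add_param` signature) and only illustrates the idea of storing the map name on the variable in the custom getter so that `add_param` can read it back instead of re-inferring it:

```python
# Hypothetical sketch of the suggested direction, not actual RETURNN code.
# Idea: whenever var_creation_scope does something custom with a parameter,
# remember the layer-local name (e.g. "W") directly on the returned variable,
# so add_param can read it back instead of inferring it from the TF name scope.

def resolve_reused_param(reuse_params, getter, name, *args, **kwargs):
  # Placeholder for the real lookup: either take the variable from the
  # reuse_layer or call the user-supplied custom func, as reuse_params says.
  ...

def make_custom_getter(reuse_params):
  def custom_getter(getter, name, *args, **kwargs):
    # `name` is e.g. "linear/W"; the last component is the map name the
    # calling layer expects to use as the key in self.params.
    map_name = name.split("/")[-1]
    param = resolve_reused_param(reuse_params, getter, name, *args, **kwargs)
    param._RETURNN_layer_map_name = map_name
    return param
  return custom_getter

def add_param(self, param):
  # Simplified signature. Prefer the explicitly stored name; fall back to the
  # old scope-based inference only for the simple default case.
  name = getattr(param, "_RETURNN_layer_map_name", None)
  if name is None:
    name = param.name.split("/")[-1].split(":")[0]
  self.params[name] = param
  return param
```

As suggested in the quote, the simple default case (no custom func) could keep the existing inference; only the custom path would rely on the stored attribute.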