How to compute gradients with tf.scatter_sub? #4

@mztkenan

When implementing lambda-opt in TensorFlow, I ran into a problem computing gradients through tf.scatter_sub.

θ is the embedding matrix for docid.
The update rule is

θ(t+1) = θ(t) - α * (grad + 2λθ(t)),

which I implement as:

delta = theta_grad_no_reg.values * lr + 2 * lr * cur_scale * cur_theta
next_theta_tensor = tf.scatter_sub(theta, theta_grad_no_reg.indices, delta)

Then I use θ(t+1) in some further computation. In the end, I want the gradient with respect to λ, not θ.

But the gradient is None.
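For reference, the update rule above is differentiable with respect to λ on paper: assuming grad itself does not depend on λ, the derivative is ∂θ(t+1)/∂λ = -2αθ(t) on the updated rows. Below is a minimal NumPy sketch of that check, done outside TensorFlow. The helper name `sgd_step_with_l2` and all values are made up for illustration and are not part of lambda-opt.

```python
import numpy as np


def sgd_step_with_l2(theta, rows, grad_rows, lr, lam_rows):
    """Return a copy of theta where theta[rows] -= lr * (grad + 2 * lam * theta[rows])."""
    next_theta = theta.copy()
    delta = lr * (grad_rows + 2.0 * lam_rows * theta[rows])
    # subtract.at accumulates correctly even if a row index repeats
    np.subtract.at(next_theta, rows, delta)
    return next_theta


rng = np.random.default_rng(0)
theta = rng.normal(size=(10, 3))      # θ(t), hypothetical values
rows = np.array([0, 4])               # rows touched by this step
grad_rows = rng.normal(size=(2, 3))   # gradient without regularization
lam_rows = np.full((2, 3), 0.1)       # λ for the touched rows
lr = 0.01                             # α

next_theta = sgd_step_with_l2(theta, rows, grad_rows, lr, lam_rows)

# Analytic derivative of θ(t+1) with respect to λ, elementwise on the touched rows
analytic = -2.0 * lr * theta[rows]

# Finite-difference check of the same derivative
eps = 1e-6
shifted = sgd_step_with_l2(theta, rows, grad_rows, lr, lam_rows + eps)
numeric = (shifted[rows] - next_theta[rows]) / eps
```

So a None gradient is not a property of the formula itself; it has to come from how the update is expressed in the graph.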

I wrote a demo like this:

import tensorflow as tf

w = tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32)
y = tf.constant([5.0], dtype=tf.float32)

# θ: the embedding matrix
emb_matrix = tf.get_variable("embedding_name", shape=(10, 3),
                             initializer=tf.random_normal_initializer(), dtype=tf.float32)
# look up one row of θ
cur_emb = tf.nn.embedding_lookup(emb_matrix, [0])
# the λ matrix
doc_lambda = tf.get_variable(name='docid_lambda', shape=(10, 3),
                             initializer=tf.random_normal_initializer(), dtype=tf.float32)
# look up one row of λ
cur_lambda = tf.nn.embedding_lookup(doc_lambda, [0])

# θ(t+1): Tensor("ScatterSub:0", shape=(10, 3), dtype=float32_ref)
next_emb_matrix = tf.scatter_sub(emb_matrix, [0], cur_emb * cur_lambda)
# keep computing with the θ(t+1) Tensor, not the Variable
next_cur_emb = tf.nn.embedding_lookup(next_emb_matrix, [0])

y_ = tf.matmul(next_cur_emb, w)
loss = tf.reduce_mean((y - y_) ** 2)
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
grad_var_list = optimizer.compute_gradients(loss)
print(grad_var_list)
# Output:
# [(None, <tf.Variable 'embedding_name:0' shape=(10, 3) dtype=float32_ref>),
#  (None, <tf.Variable 'docid_lambda:0' shape=(10, 3) dtype=float32_ref>)]

The gradients are None here, too. Does this mean the tf.scatter_sub op doesn't provide a gradient?

I know this is a question about using TensorFlow rather than about the paper, but I thought you might know the answer.
Thanks for your help!
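One possible way around the None gradients in the demo above, offered as a sketch rather than a confirmed fix: tf.scatter_sub assigns to the variable in place, and to my knowledge TensorFlow registers no gradient for that op. A functional scatter such as tf.tensor_scatter_nd_sub returns a new tensor instead of mutating the variable, so the dependence on λ stays on the differentiation path. The version below is written against the TensorFlow 2.x eager API (tf.Variable, tf.GradientTape), which is an assumption and differs from the TF 1.x graph code in the demo. The variable names mirror the demo.

```python
import tensorflow as tf

w = tf.constant([[1.0], [2.0], [3.0]], dtype=tf.float32)
y = tf.constant([5.0], dtype=tf.float32)

# θ and λ, same shapes as in the demo
emb_matrix = tf.Variable(tf.random.normal((10, 3)), name="embedding_name")
doc_lambda = tf.Variable(tf.random.normal((10, 3)), name="docid_lambda")
rows = tf.constant([0])

with tf.GradientTape() as tape:
    cur_emb = tf.gather(emb_matrix, rows)
    cur_lambda = tf.gather(doc_lambda, rows)
    # Functional update: builds θ(t+1) as a new tensor, emb_matrix is not modified
    next_emb_matrix = tf.tensor_scatter_nd_sub(
        emb_matrix,
        tf.expand_dims(rows, axis=1),  # indices of shape (num_rows, 1)
        cur_emb * cur_lambda,
    )
    next_cur_emb = tf.gather(next_emb_matrix, rows)
    y_ = tf.matmul(next_cur_emb, w)
    loss = tf.reduce_mean((y - y_) ** 2)

# Gradient with respect to λ only
grad_lambda = tape.gradient(loss, doc_lambda)
# Gradients that flow through a gather may arrive as tf.IndexedSlices; densify for inspection
dense_grad = tf.convert_to_tensor(grad_lambda)
print(dense_grad)
```

With this formulation only row 0 of the λ gradient should be nonzero, since only that row of λ enters the update. Writing θ(t+1) back into emb_matrix would still need a separate assign step.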

Labels: help wanted (extra attention is needed)
