
Some confusion about the fast grads calculation when converting to PyTorch #36

@ShunLu91

Description


Hello,

Many thanks for your great efforts. After reading your paper and code, I found it to be a nice, solid piece of work, and I really enjoyed it.

To use this method in my own model training, I am trying to reimplement it in PyTorch. I noticed that you use the following code to compute the gradient norm in fast mode:

if self.fast:
    grads = K.sqrt(sum([
        self._sum_per_sample(K.square(g))
        for g in K.gradients(losses, self.parameter_list)
    ]))

As far as I can tell, the expression [self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)] already squares the gradients and sums them per sample. I am confused about why you do not apply K.sqrt() to this directly to get each sample's gradient norm, but instead introduce another sum() inside the K.sqrt() call.
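To make my reading concrete, here is a small NumPy analogy (the shapes and names are my own assumptions, not code from this repo):

import numpy as np

B = 4  # batch size (made up)
# One (B,)-shaped tensor per parameter tensor, i.e. what I understand
# self._sum_per_sample(K.square(g)) to return for each gradient g:
per_param = [np.random.rand(B) for _ in range(3)]

# The builtin sum() adds these (B,)-shaped tensors elementwise, giving
# one squared norm per sample, aggregated over all parameter tensors:
sq_norms = sum(per_param)   # shape (B,)
norms = np.sqrt(sq_norms)   # per-sample gradient norms

Is my reading correct that this sum() aggregates across parameter tensors rather than across samples?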

Besides, I compared the results of sum([self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)]) and [self._sum_per_sample(K.square(g)) for g in K.gradients(losses, self.parameter_list)], and found that they were equal, which really surprised me. If I remove the sum() inside K.sqrt(), a data-type error is raised. Does this sum() therefore only convert the data type, rather than performing an actual summation?
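Here is a tiny PyTorch analogy of the behavior I mean (my own example, not from the repo):

import torch

a = torch.tensor([1.0, 2.0])   # stands in for one per-parameter result
b = torch.tensor([3.0, 4.0])   # stands in for another

print(sum([a, b]))              # tensor([4., 6.]): elementwise addition
print(torch.sqrt(sum([a, b])))  # fine: the argument is a single tensor
# torch.sqrt([a, b])            # raises TypeError: expected a Tensor, not a list

So it looks to me as if sum() both adds the per-parameter tensors and, as a side effect, turns the list into a single tensor that sqrt() accepts.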

Looking forward to your reply. I will share my PyTorch implementation once it is ready.
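In the meantime, here is a rough draft of my PyTorch translation, just to check that I understand the quantity being computed (a naive per-sample loop, not your fast approximation; model and losses are my own placeholders):

import torch

def per_sample_grad_norms(model, losses):
    # losses: 1-D tensor of per-sample losses, e.g. from a criterion
    # built with reduction='none'
    params = [p for p in model.parameters() if p.requires_grad]
    norms = []
    for i in range(losses.shape[0]):
        grads = torch.autograd.grad(losses[i], params, retain_graph=True)
        # Same role as sum([...]) in your Keras code: add the squared
        # gradient entries across all parameter tensors for this sample
        sq = sum(g.pow(2).sum() for g in grads)
        norms.append(sq.sqrt())
    return torch.stack(norms)

Does this match the per-sample gradient norm that your fast mode approximates?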

Best,
Shun Lu
