L2 regulariser and SAM #24

@konstantinos-p

Description

It seems to me that there might be a mistake in the way the noised state is computed in the current implementation. Specifically

sam/sam_jax/training_utils

In line 537, `forward_and_loss`, which includes the L2 regularization, is used to compute `grad`; this is then passed in line 546 to `dual_vector(grad)`.

I don't think this is exactly what the original SAM paper describes. The L2 regularization shouldn't be evaluated at the noised state, as it is now; only the cross-entropy loss should be. A separate gradient of the L2 term should be computed at the clean state and summed with the SAM gradient.

Is there something that I'm missing?
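To make the proposed fix concrete, here is a minimal JAX sketch of that gradient computation. This is an illustration, not the repo's code: `cross_entropy_loss` stands in for a `forward_and_loss` without the L2 term, and `dual_vector` is simplified to a flat array rather than a pytree; `rho` and `l2` are hypothetical hyperparameters.

```python
import jax
import jax.numpy as jnp


def dual_vector(g):
    # Normalize the gradient to unit L2 norm, as in the SAM paper's
    # epsilon-hat; simplified here to a single flat array.
    return g / (jnp.linalg.norm(g) + 1e-12)


def cross_entropy_loss(params, x, y):
    # Toy linear model + softmax cross-entropy, standing in for
    # forward_and_loss WITHOUT the L2 regularization term.
    logits = x @ params
    logp = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(y * logp, axis=-1))


def sam_gradient(params, x, y, rho=0.05, l2=1e-4):
    # Perturbation computed from the cross-entropy gradient only.
    g_clean = jax.grad(cross_entropy_loss)(params, x, y)
    noised = params + rho * dual_vector(g_clean)
    # SAM gradient of the cross-entropy at the noised state...
    g_sam = jax.grad(cross_entropy_loss)(noised, x, y)
    # ...plus the L2 gradient evaluated at the CLEAN state.
    return g_sam + l2 * params
```

The point of the sketch is the last line: the `l2 * params` term uses the clean `params`, whereas the current implementation effectively differentiates the L2 term at `noised` as well.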
