In the new, general implementation of the SGLD optimizer, the step-wise parameter update is applied to `p.data` BEFORE `weight_decay` is added to the step `d_p`. The weight-decay term therefore never reaches the parameters, and the operation is effectively a no-op. Introduced here:
```python
p.data.add_(
    preconditioning.noise_coef * noise,
    alpha=group["lr"] ** 0.5,
)
# Apply weight decay separately from other updates
if group["weight_decay"] != 0:
    d_p.add_(group["weight_decay"] * p.data)

# Bounding box enforcement
if group["bounding_box_size"] is not None:
    initial_param = state["initial_param"]
    torch.clamp_(
        p.data,
        min=initial_param - group["bounding_box_size"],
        max=initial_param + group["bounding_box_size"],
    )

# Track metrics
metrics = group["metrics"]
if "dws" in metrics:
    metrics["dws"].append(d_p.clone())

if "grad_norm" in metrics and p.grad is not None:
    metrics["grad_norm"] += (
        (p.grad.data * group["nbeta"] * 0.5 * group["lr"]) ** 2
    ).sum()

if "weight_norm" in metrics:
    metrics["weight_norm"] += (p.data**2).sum()

if "noise" in metrics:
    metrics["noise"].append(noise)

if "noise_norm" in metrics:
    metrics["noise_norm"] += (noise**2).sum()
```
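A minimal, dependency-free sketch of why the ordering matters. It assumes a simplified scalar SGLD update `w <- w - (lr / 2) * (nbeta * grad + weight_decay * w) + sqrt(lr) * noise`; the helper `sgld_step` and its `fold_decay_first` flag are hypothetical illustrations, not part of devinterp's API, and the real optimizer's preconditioning and bounding box are left out.

```python
import math


def sgld_step(w, grad, noise, lr, nbeta, weight_decay, fold_decay_first=True):
    """One simplified SGLD step over a list of scalar weights (hypothetical helper).

    Assumed update:
        w <- w - (lr / 2) * (nbeta * grad + weight_decay * w) + sqrt(lr) * noise
    fold_decay_first=False mimics the reported ordering, where the decay term is
    added to the step d_p only after the weights have already been updated.
    """
    new_w = []
    for w_i, g_i, n_i in zip(w, grad, noise):
        d_p = nbeta * g_i  # drift from the scaled loss gradient
        if fold_decay_first and weight_decay != 0:
            d_p += weight_decay * w_i  # decay is part of the step before it is applied
        w_i = w_i - 0.5 * lr * d_p + math.sqrt(lr) * n_i
        if not fold_decay_first and weight_decay != 0:
            d_p += weight_decay * w_i  # too late: w_i was already updated, so this is unused
        new_w.append(w_i)
    return new_w


# With zero gradient and zero noise, only weight decay can move the weight.
print(sgld_step([1.0], [0.0], [0.0], lr=0.1, nbeta=1.0, weight_decay=1.0))
print(sgld_step([1.0], [0.0], [0.0], lr=0.1, nbeta=1.0, weight_decay=1.0, fold_decay_first=False))
```

With the decay folded into the step first, the weight shrinks toward zero; with the reported ordering it stays at `1.0`, exactly as if `weight_decay` were `0`.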
With the use of specified Priors, the weight_decay may be less of an issue.
Referenced code: `devinterp/src/devinterp/optim/sgmcmc.py`, lines 386 to 420 in 491c4b9.