We need to support each of the following choices (from @edmundlth):
There are a few distinct objects where this choice needs to be made (using n for the total dataset size and b for the batch size):
1. k = n or k = b when SGLD sampling for w_i.
2. L_n or L_b in E_w L_k(w).
3. L_n or L_b in L_k(w^*).
4. k = n or k = b in (E_w L_{k'}(w) - L_{k'}(w^*)) * k / log(k).
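For concreteness, here is a minimal sketch of the (E_w L_k(w) - L_k(w^*)) * k / log(k) estimator under the all-k-equal-n choice. All names (`llc_full_data`, `avg_loss`, `ws`) are illustrative placeholders, not an existing API:

```python
import math

def llc_full_data(ws, w_star, avg_loss, data, n):
    """Estimator with every k above equal to n.

    ws       : SGLD draws w_i (sampled with b = n, i.e. full-batch SGLD)
    w_star   : the reference parameter w^*
    avg_loss : hypothetical helper, avg_loss(w, data) -> mean loss over data
    """
    e_w_loss = sum(avg_loss(w, data) for w in ws) / len(ws)  # E_w L_n(w)
    loss_star = avg_loss(w_star, data)                       # L_n(w^*)
    return (e_w_loss - loss_star) * n / math.log(n)          # * k / log(k), k = n
```

The four choices above correspond to (1) how `ws` was sampled, (2) which dataset `avg_loss` sees inside the expectation, (3) which dataset it sees at `w_star`, and (4) which k appears in the final `n / math.log(n)` factor.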
One thing I noticed during TMS, and more recently in the grokking experiment, is that the statistics of L_b(w) and L_n(w) can differ over samples of w within an epoch of SGD or SGLD.
Whenever possible, one should use k = n above. In decreasing order of preference and theoretical support:
1. Do all of 1-4 with k = n: SGLD sampling with b = n, all loss evaluations on the full dataset, and the final calculation with k = n.
2. Do SGLD sampling with k = b (minibatching), but for every SGLD sample w_i, still evaluate the loss on the full dataset.
3. Do SGLD sampling with minibatching and use only the minibatch loss, i.e. approximate E_w L_n(w) as \mean_i L_b(w_i). The loss at w^* should still be evaluated on the full dataset, i.e. L_n(w^*).
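A sketch of option 2: SGLD steps use minibatch gradients (k = b), but every recorded loss is L_n on the full dataset. All names are assumptions for illustration, and the SGLD update is simplified (no localization term):

```python
import math
import random

def sgld_full_eval(w0, w_star, data, grad, avg_loss, steps=100, b=8, eps=1e-3):
    """Option 2: minibatch SGLD sampling, full-data loss evaluation.

    grad(w, batch) and avg_loss(w, data) are hypothetical helpers.
    """
    n = len(data)
    w = w0
    full_losses = []
    for _ in range(steps):
        batch = random.sample(data, b)                   # minibatch drives the SGLD step
        noise = random.gauss(0.0, math.sqrt(eps))        # Langevin noise
        w = w - 0.5 * eps * n * grad(w, batch) + noise   # simplified SGLD update
        full_losses.append(avg_loss(w, data))            # still evaluate L_n(w_i)
    e_w_loss = sum(full_losses) / len(full_losses)       # approximates E_w L_n(w)
    return (e_w_loss - avg_loss(w_star, data)) * n / math.log(n)
```

In practice one would discard burn-in draws and average only post-burn-in losses; the full-data evaluation per sample is what distinguishes this option from option 3.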