Hi, I still need to test this, but if I am reading this correctly:
Lines 79 to 94 in f90a42b
for group in self.param_groups:
    params = group["params"]
    params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
    for base_i in range(len(params))[::dist.get_world_size()]:
        if base_i + dist.get_rank() < len(params):
            p = params[base_i + dist.get_rank()]
            if p.grad is None:
                # continue
                p.grad = torch.zeros_like(p)  # Force synchronization
            state = self.state[p]
            if len(state) == 0:
                state["momentum_buffer"] = torch.zeros_like(p)
            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
            p.mul_(1 - group["lr"] * group["weight_decay"])
            p.add_(update.reshape(p.shape), alpha=-group["lr"])
        dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
It seems that each GPU updates state["momentum_buffer"] of its assigned parameters in place, and the parameters themselves are synced with all_gather(), but state["momentum_buffer"] is never synced across ranks. Unless there is some automatic guarantee I am missing, this works fine as long as training is never interrupted, but once you save and reload a checkpoint, only the state["momentum_buffer"] entries of the parameters handled by cuda:0 are correct (assuming the optimizer state dict is saved from rank 0)...
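If that is the case, one possible workaround (untested, just a sketch of what I mean, not the repository's actual API) would be to broadcast each momentum buffer from the rank that owns the corresponding parameter before saving, mirroring the round-robin assignment in the snippet above where rank r updates params[base_i + r]. The helper name sync_momentum_buffers is my own:

import torch
import torch.distributed as dist

def sync_momentum_buffers(optimizer):
    # Broadcast every parameter's momentum buffer from the rank that updated it,
    # so all ranks (in particular the rank that writes the checkpoint) end up
    # with a complete, consistent optimizer state.
    world_size = dist.get_world_size()
    for group in optimizer.param_groups:
        params = group["params"]
        for base_i in range(0, len(params), world_size):
            for offset in range(world_size):
                i = base_i + offset
                if i >= len(params):
                    break
                p = params[i]
                state = optimizer.state[p]  # torch.optim keeps state in a defaultdict(dict)
                if "momentum_buffer" not in state:
                    # Non-owner ranks may never have created this entry.
                    state["momentum_buffer"] = torch.zeros_like(p)
                # offset is exactly the rank that owns params[base_i + offset]
                dist.broadcast(state["momentum_buffer"], src=offset)

Calling something like sync_momentum_buffers(optimizer) on every rank right before torch.save(optimizer.state_dict(), ...) should make the checkpoint consistent, at the cost of one broadcast per parameter; alternatively the buffers could be all_gathered inside step() the same way params_pad is.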