Skip to content

Commit 30b6411

Browse files
committed
fix ga step
1 parent deaf96a commit 30b6411

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

src/twinkle/model/optimizer_group.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def do_grad_sync(self, gradient_accumulation_steps: Optional[int] = None) -> boo
5252
gradient_accumulation_steps = self.gradient_accumulation_steps
5353
else:
5454
self.gradient_accumulation_steps = gradient_accumulation_steps
55-
return (self.cur_step - 1) % gradient_accumulation_steps == 0 and self.cur_step > 1
55+
return gradient_accumulation_steps == 1 or ((self.cur_step - 1) % gradient_accumulation_steps == 0
56+
and self.cur_step > 1)
5657

5758
def _get_lr(self):
5859
"""Get learning rates from optimizer. Override in subclass."""

0 commit comments

Comments
 (0)