Skip to content

Commit eb91bbb

Browse files
fix grad norm bug (#81)
1 parent 2fa2608 commit eb91bbb

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

src/twinkle/infra/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
import os
66
from typing import Any, Callable, List, Literal, Optional, TypeVar, Union
77

8-
from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires
8+
from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires
9+
10+
logger = get_logger()
911

1012
T1 = TypeVar('T1', bound=object)
1113

@@ -437,6 +439,9 @@ def new_init(self, *args, **kwargs):
437439
# Passing an instance_id is recommended
438440
instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
439441
remote_group = kwargs.get('remote_group')
442+
if remote_group is None:
443+
logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class '
444+
'does not need remote execution.')
440445
# If remote code cannot be trusted, no callables or types can be used.
441446
check_unsafe(*args, **kwargs)
442447

src/twinkle/loss/cross_entropy.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,8 @@ def __call__(self, inputs, outputs, **kwargs):
1111
import torch
1212
logits = outputs['logits'].view(-1, outputs['logits'].shape[-1])
1313
labels = inputs['labels'].view(-1)
14-
return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
14+
loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels)
15+
if self.reduction != 'sum':
16+
return loss
17+
else:
18+
return loss, (labels != -100).sum()

0 commit comments

Comments
 (0)