diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 2066f9a6..cc5e4ec6 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -5,7 +5,9 @@ import os from typing import Any, Callable, List, Literal, Optional, TypeVar, Union -from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires +from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires + +logger = get_logger() T1 = TypeVar('T1', bound=object) @@ -437,6 +439,9 @@ def new_init(self, *args, **kwargs): # Pass an instance_id is recommended instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}' remote_group = kwargs.get('remote_group') + if remote_group is None: + logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class ' + 'does not need remote execution.') # If cannot trust_remote_code, no callable and type can be used. check_unsafe(*args, **kwargs) diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index 4006c1e8..c5f25b9c 100644 --- a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -11,4 +11,8 @@ def __call__(self, inputs, outputs, **kwargs): import torch logits = outputs['logits'].view(-1, outputs['logits'].shape[-1]) labels = inputs['labels'].view(-1) - return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) + loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) + if self.reduction != 'sum': + return loss + else: + return loss, (labels != -100).sum()