From 5ba2b01713054c2b4e2b6d8c07462ce676e4f16e Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Feb 2026 17:44:06 +0800 Subject: [PATCH 1/2] fix --- src/twinkle/infra/__init__.py | 6 +++++- src/twinkle/loss/cross_entropy.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 2066f9a6..6abb7b41 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -5,7 +5,9 @@ import os from typing import Any, Callable, List, Literal, Optional, TypeVar, Union -from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires +from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires, get_logger + +logger = get_logger() T1 = TypeVar('T1', bound=object) @@ -437,6 +439,8 @@ def new_init(self, *args, **kwargs): # Pass an instance_id is recommended instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}' remote_group = kwargs.get('remote_group') + if remote_group is None: + logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class does not need remote execution.') # If cannot trust_remote_code, no callable and type can be used. check_unsafe(*args, **kwargs) diff --git a/src/twinkle/loss/cross_entropy.py b/src/twinkle/loss/cross_entropy.py index 4006c1e8..c5f25b9c 100644 --- a/src/twinkle/loss/cross_entropy.py +++ b/src/twinkle/loss/cross_entropy.py @@ -11,4 +11,8 @@ def __call__(self, inputs, outputs, **kwargs): import torch logits = outputs['logits'].view(-1, outputs['logits'].shape[-1]) labels = inputs['labels'].view(-1) - return torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) + loss = torch.nn.CrossEntropyLoss(reduction=self.reduction)(logits, labels) + if self.reduction != 'sum': + return loss + else: + return loss, (labels != -100).sum() From 01b62acd7d64603a78d82c439b00184f59d57271 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Fri, 27 Feb 2026 17:45:08 +0800 Subject: [PATCH 2/2] lint code --- src/twinkle/infra/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py index 6abb7b41..cc5e4ec6 100644 --- a/src/twinkle/infra/__init__.py +++ b/src/twinkle/infra/__init__.py @@ -5,7 +5,7 @@ import os from typing import Any, Callable, List, Literal, Optional, TypeVar, Union -from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, requires, get_logger +from twinkle.utils import DeviceGroup, DeviceMesh, Platform, check_unsafe, framework_util, get_logger, requires logger = get_logger() @@ -440,7 +440,8 @@ def new_init(self, *args, **kwargs): instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}' remote_group = kwargs.get('remote_group') if remote_group is None: - logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class does not need remote execution.') + logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class ' + 'does not need remote execution.') # If cannot trust_remote_code, no callable and type can be used. check_unsafe(*args, **kwargs)