Skip to content

Commit f04c1f8

Browse files
committed
fix lint
1 parent 9d239da commit f04c1f8

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,7 @@ class SequenceParallelConfig:
870870
enabled: bool = True
871871
ulysses_size: Optional[int] = None
872872
gather_logits: bool = True
873-
loss_reduction: str = "mean"
873+
loss_reduction: str = 'mean'
874874
compensate_fsdp_avg: bool = False
875875

876876

@@ -975,17 +975,16 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
975975
if labels is None or sequence_parallel._sp_group is None:
976976
return loss
977977
# Compute global loss via autograd-aware all-reduce.
978-
reduction = str(self.sp_config.get("loss_reduction", "mean")).lower()
979-
if reduction == "none":
980-
raise ValueError(
981-
"SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
982-
"Please aggregate per-token losses before calling reduce_loss."
983-
)
984-
compensate_fsdp_avg = bool(self.sp_config.get("compensate_fsdp_avg", False))
978+
reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower()
979+
if reduction == 'none':
980+
raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
981+
'Please aggregate per-token losses before calling reduce_loss.')
982+
compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False))
985983
compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
986984
sum_metric_scale = float(self.ulysses_size)
987985

988986
class _ReduceSequenceParallelLoss(torch.autograd.Function):
987+
989988
@staticmethod
990989
def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
991990
local_tokens = num_valid_tokens.detach().clone()
@@ -1011,6 +1010,7 @@ def backward(ctx, grad_output: torch.Tensor):
10111010
return grad_local_mean, None
10121011

10131012
class _ReduceSequenceParallelSum(torch.autograd.Function):
1013+
10141014
@staticmethod
10151015
def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
10161016
ctx.sum_metric_scale = sum_metric_scale
@@ -1026,7 +1026,7 @@ def backward(ctx, grad_output: torch.Tensor):
10261026
# logging/metric alignment under outer collect='mean'.
10271027
return grad_output
10281028

1029-
if reduction == "sum":
1029+
if reduction == 'sum':
10301030
return _ReduceSequenceParallelSum.apply(loss)
10311031

10321032
# Default to mean reduction: `loss` is local mean.

src/twinkle/model/transformers/transformers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def _ensure_sp_strategy(self) -> None:
252252
# accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale.
253253
if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None:
254254
if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1:
255-
sp_config["compensate_fsdp_avg"] = True
255+
sp_config['compensate_fsdp_avg'] = True
256256
self.sp_strategy = SequenceParallelStrategy(
257257
self.device_mesh,
258258
sp_config,
@@ -440,9 +440,9 @@ def calculate_loss(self, **kwargs):
440440
optimizer_config = self.optimizer_group[adapter_name]
441441
optimizer_config.num_tokens += counts.item()
442442
if self.sp_strategy is not None and 'labels' in inputs:
443-
reduction = getattr(loss_instance, "reduction", None)
443+
reduction = getattr(loss_instance, 'reduction', None)
444444
if reduction is not None:
445-
self.sp_strategy.sp_config["loss_reduction"] = str(reduction)
445+
self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
446446
loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
447447
optimizer_config.loss_value += loss_value
448448
outputs['loss'] = optimizer_config.loss_value

0 commit comments

Comments (0)