Skip to content

Commit 9289600

Browse files
committed
lint
1 parent f2e26dd commit 9289600

File tree

12 files changed

+41
-53
lines changed

12 files changed

+41
-53
lines changed

src/twinkle/dataset/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,9 @@ def encode(self, add_generation_prompt: bool = False, **kwargs):
8686
encode_fn = partial(self.template.batch_encode, add_generation_prompt=add_generation_prompt)
8787
with processing_lock('dataset'):
8888
# use a default lock because encode applies to all datasets
89-
self.dataset = self.dataset.map(encode_fn,
90-
**kwargs).filter(lambda batch: [True] * len(next(iter(batch.values()))) if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']],
91-
**kwargs)
89+
self.dataset = self.dataset.map(encode_fn, **kwargs).filter(
90+
lambda batch: [True] * len(next(iter(batch.values())))
91+
if 'input_ids' not in batch else [len(x) > 0 for x in batch['input_ids']], **kwargs)
9292

9393
@remote_function()
9494
def check(self, **kwargs):

src/twinkle/infra/collectors.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,9 @@
11
import numpy as np
2-
from typing import TYPE_CHECKING, Any, Dict, List
2+
from typing import Any, Dict, List
33

44
from twinkle import DeviceMesh
55
from twinkle.utils import pad_and_stack_tensors
66

7-
if TYPE_CHECKING:
8-
import torch
9-
107

118
def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh) -> Dict[str, Any]:
129
import torch

src/twinkle/loss/dpo.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
from typing import TYPE_CHECKING, Dict, List, Optional, Union
1010

1111
from twinkle.data_format import LossOutput
12-
from twinkle.utils.torch_utils import selective_log_softmax
1312
from twinkle.loss.base import Loss
13+
from twinkle.utils.torch_utils import selective_log_softmax
1414

1515
if TYPE_CHECKING:
1616
import torch
@@ -176,14 +176,10 @@ def _align_logps(
176176
# Truncate right (keep left part) - may happen in Ray result merging
177177
return logps[:, :target_seq_len]
178178
else:
179-
raise ValueError(
180-
f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). '
181-
f'This should not happen when both models process the same batch.'
182-
)
179+
raise ValueError(f'ref_logps seq_len ({src_seq_len}) < target seq_len ({target_seq_len}). '
180+
f'This should not happen when both models process the same batch.')
183181

184-
raise ValueError(
185-
f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}'
186-
)
182+
raise ValueError(f'Cannot align ref_logps shape {logps.shape} to target shape {target_shape}')
187183

188184
def _compute_dpo_loss(
189185
self,
@@ -227,7 +223,7 @@ def _compute_dpo_loss(
227223
elif self.loss_type == 'ipo':
228224
# IPO (Identity Preference Optimization) loss
229225
# Reference: "A General Theoretical Paradigm to Understand Learning from Human Feedback"
230-
losses = (logits - 1 / (2 * self.beta)) ** 2
226+
losses = (logits - 1 / (2 * self.beta))**2
231227
elif self.loss_type == 'kto_pair':
232228
# KTO pair loss (simplified version)
233229
chosen_logratios_scaled = self.beta * chosen_logratios
@@ -236,7 +232,7 @@ def _compute_dpo_loss(
236232
rejected_losses = F.sigmoid(rejected_logratios_scaled)
237233
losses = chosen_losses + rejected_losses
238234
else:
239-
raise ValueError(f"Unknown loss_type: {self.loss_type}")
235+
raise ValueError(f'Unknown loss_type: {self.loss_type}')
240236

241237
# Apply label smoothing if specified
242238
if self.label_smoothing > 0:
@@ -292,7 +288,7 @@ def __call__(
292288
labels = labels.unsqueeze(0)
293289

294290
batch_size = labels.shape[0]
295-
assert batch_size % 2 == 0, "Batch size must be even (chosen + rejected pairs)"
291+
assert batch_size % 2 == 0, 'Batch size must be even (chosen + rejected pairs)'
296292

297293
# Get log probabilities from outputs
298294
logps = self._get_logps_from_outputs(outputs, labels)
@@ -314,9 +310,7 @@ def __call__(
314310
reference_rejected_logps = ref_rejected_logps.to(device=device, dtype=dtype)
315311
elif ref_logps is not None:
316312
# Per-token reference log probs provided, need to align and sum
317-
ref_logps_aligned = self._align_logps(
318-
ref_logps, labels.shape, device, dtype
319-
)
313+
ref_logps_aligned = self._align_logps(ref_logps, labels.shape, device, dtype)
320314
ref_chosen, ref_rejected = self._split_chosen_rejected(ref_logps_aligned)
321315
reference_chosen_logps = self._compute_sequence_logps(ref_chosen, chosen_labels)
322316
reference_rejected_logps = self._compute_sequence_logps(ref_rejected, rejected_labels)
@@ -392,7 +386,7 @@ def __call__(
392386
if labels.dim() == 1:
393387
labels = labels.unsqueeze(0)
394388

395-
assert labels.shape[0] % 2 == 0, "Batch size must be even (chosen + rejected pairs)"
389+
assert labels.shape[0] % 2 == 0, 'Batch size must be even (chosen + rejected pairs)'
396390

397391
# Get log probabilities
398392
logps = self._get_logps_from_outputs(outputs, labels)
@@ -455,7 +449,7 @@ def __call__(
455449
if labels.dim() == 1:
456450
labels = labels.unsqueeze(0)
457451

458-
assert labels.shape[0] % 2 == 0, "Batch size must be even"
452+
assert labels.shape[0] % 2 == 0, 'Batch size must be even'
459453

460454
# Get log probabilities
461455
logps = self._get_logps_from_outputs(outputs, labels)
@@ -521,7 +515,7 @@ def __call__(
521515
if labels.dim() == 1:
522516
labels = labels.unsqueeze(0)
523517

524-
assert labels.shape[0] % 2 == 0, "Batch size must be even"
518+
assert labels.shape[0] % 2 == 0, 'Batch size must be even'
525519

526520
# Get log probabilities
527521
logps = self._get_logps_from_outputs(outputs, labels)
@@ -540,8 +534,8 @@ def __call__(
540534
# Odds ratio: log(odds_chosen / odds_rejected)
541535
# log_odds = log(p/(1-p)) = log(p) - log(1-p)
542536
# Use numerically stable computation
543-
prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1-1e-7)
544-
prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1-1e-7)
537+
prob_chosen = torch.exp(chosen_avg_logps).clamp(min=1e-7, max=1 - 1e-7)
538+
prob_rejected = torch.exp(rejected_avg_logps).clamp(min=1e-7, max=1 - 1e-7)
545539
log_odds_chosen = torch.log(prob_chosen) - torch.log(1 - prob_chosen)
546540
log_odds_rejected = torch.log(prob_rejected) - torch.log(1 - prob_rejected)
547541

src/twinkle/metric/dpo.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ def _align_logps(self, logps, target_shape, device, dtype):
5757
if src_len == target_len:
5858
return logps
5959
elif src_len < target_len:
60-
raise ValueError(
61-
f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). '
62-
f'This should not happen when both models process the same batch.'
63-
)
60+
raise ValueError(f'ref_logps seq_len ({src_len}) < target seq_len ({target_len}). '
61+
f'This should not happen when both models process the same batch.')
6462
else:
6563
return logps[:, :target_len]
6664

@@ -84,7 +82,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
8482
logps = outputs.get('logps')
8583
if logps is None or len(logps) == 0:
8684
return
87-
85+
8886
if isinstance(logps, list) and logps:
8987
logps = pad_and_stack_tensors(logps)
9088

@@ -128,9 +126,7 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
128126
ref_logps = ref_outputs.get('logps')
129127
if ref_logps is not None:
130128
# Align ref_logps to match labels shape (handles different seq lengths)
131-
ref_logps = self._align_logps(
132-
ref_logps, labels.shape, labels.device, logps.dtype
133-
)
129+
ref_logps = self._align_logps(ref_logps, labels.shape, labels.device, logps.dtype)
134130

135131
ref_seq_logps = self._compute_sequence_logps(ref_logps, labels)
136132
ref_chosen_logps, ref_rejected_logps = self._split_chosen_rejected(ref_seq_logps)

src/twinkle/model/megatron/megatron.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@
2525
from twinkle import DeviceMesh, Platform, remote_class, remote_function, requires, torch_util
2626
from twinkle.checkpoint_engine.mixin import CheckpointEngineMixin
2727
from twinkle.data_format import InputFeature, ModelOutput, Trajectory
28-
from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus
2928
from twinkle.hub import HubOperation
3029
from twinkle.infra import collect_tensor_dict
3130
from twinkle.loss import CrossEntropyLoss, Loss
3231
from twinkle.metric import LossMetric, Metric, TrainMetric
3332
from twinkle.model.base import TwinkleModel
33+
from twinkle.model.optimizer_group import BaseOptimizerGroup, TrainStatus
3434
from twinkle.patch import Patch, apply_patch
3535
from twinkle.processor import InputProcessor
3636
from twinkle.template import Template
@@ -435,7 +435,7 @@ def forward_backward(self,
435435
seq_length = original_seq_length + (divisor - original_seq_length % divisor)
436436
else:
437437
seq_length = original_seq_length
438-
438+
439439
num_microbatches = len(inputs)
440440
loss_extra_kwargs_per_mb = []
441441
if num_microbatches <= 1:
@@ -463,10 +463,12 @@ def post_loss_function(output_tensor, inputs, logps):
463463
if not counts:
464464
# Later will gather this value, so it becomes:
465465
# 1. SUM loss: gather_sum(local_num_tokens) = global_num_tokens
466-
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps ) = gradient_accumulation_steps * world_size
466+
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps )
467+
# = gradient_accumulation_steps * world_size
467468
# Then, grad will be divided by this value:
468469
# 1. SUM loss: (global_sum_grad) / (global_num_tokens) = global_sum_grad/global_num_tokens
469-
# 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps)) / (gradient_accumulation_steps * world_size ) = avg_per_token_grad
470+
# 2. PER TOKEN MEAN loss: (gather_sum(per_token_grad * gradient_accumulation_steps))
471+
# / (gradient_accumulation_steps * world_size ) = avg_per_token_grad
470472
counts = torch.tensor(1, device=losses.device)
471473
return self.strategy.reduce_loss(losses, counts, output_tensor, logps)
472474

src/twinkle/model/optimizer_group.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
# Copyright (c) ModelScope Contributors. All rights reserved.
22
from dataclasses import dataclass, field
3-
from typing import Any, Dict, List, Optional
4-
53
from torch.optim import Optimizer
64
from torch.optim.lr_scheduler import LRScheduler
5+
from typing import Any, Dict, List, Optional
76

87
from twinkle import DeviceMesh
98
from twinkle.data_format import InputFeature, ModelOutput
@@ -83,4 +82,3 @@ def calculate_metrics(self, is_training):
8382
status.inputs = None
8483
status.outputs = None
8584
return results
86-

src/twinkle/model/transformers/multi_lora_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union
99

1010
from twinkle import DeviceMesh, remote_class, remote_function, template
11-
from twinkle.infra import collect_tensor_dict
1211
from twinkle.data_format import InputFeature, Trajectory
1312
from twinkle.hub import HubOperation
13+
from twinkle.infra import collect_tensor_dict
1414
from twinkle.loss import Loss
1515
from twinkle.metric import Metric
1616
from twinkle.processor import InputProcessor

src/twinkle/model/transformers/transformers.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -474,13 +474,16 @@ def calculate_loss(self, **kwargs):
474474
counts = torch.tensor(1, device=loss_value.device)
475475
# Later will gather this value, so it becomes:
476476
# 1. SUM loss: gather_sum(local_num_tokens / dp_world_size) = global_num_tokens / dp_world_size
477-
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size ) = gradient_accumulation_steps
477+
# 2. PER TOKEN MEAN loss: gather_sum(1 * gradient_accumulation_steps / dp_world_size )
478+
# = gradient_accumulation_steps
478479
# Then, grad will be divided by this value:
479480
# 1. SUM loss: gather_mean(local_sum_grad) / (global_num_tokens / dp_world_size)
480481
# = (global_sum_grad / dp_world_size) / (global_num_tokens / dp_world_size)
481482
# = global_sum_grad/global_num_tokens
482-
# 2. PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps) / gradient_accumulation_steps
483-
# = (global_per_token_grad * gradient_accumulation_steps / dp_world_size ) / gradient_accumulation_steps
483+
# 2. PER TOKEN MEAN loss: gather_mean(per_token_grad * gradient_accumulation_steps)
484+
# / gradient_accumulation_steps
485+
# = (global_per_token_grad * gradient_accumulation_steps / dp_world_size )
486+
# / gradient_accumulation_steps
484487
# = global_per_token_grad / dp_world_size = avg_per_token_grad
485488
counts = counts / self.device_mesh.data_world_size
486489
optimizer_config = self.optimizer_group[adapter_name]

src/twinkle/template/base.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def _encode_messages(self, trajectory: Trajectory, add_generation_prompt: bool =
320320
input_ids = self.tokenizer.encode(text)
321321
encoded = {}
322322
labels = deepcopy(input_ids)
323-
323+
324324
input_feature = InputFeature(
325325
input_ids=np.array(input_ids),
326326
labels=np.array(labels),
@@ -398,9 +398,7 @@ def batch_encode(
398398
for key in trajectories:
399399
if key in traj_keys:
400400
# Encode this trajectory list
401-
result[key] = self.batch_encode(
402-
trajectories[key], add_generation_prompt=add_generation_prompt
403-
)
401+
result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt)
404402
else:
405403
# Keep non-trajectory columns as-is
406404
result[key] = trajectories[key]

src/twinkle/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
from .parallel import processing_lock
1111
from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend
1212
from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver
13-
from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, pad_and_stack_tensors
13+
from .torch_utils import (pad_and_stack_tensors, pad_sequence_to_length, selective_log_softmax,
14+
stateless_init_process_group, to_device)
1415
from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert
1516
from .unsafe import check_unsafe, trust_remote_code
1617
from .utils import copy_files_by_pattern, deep_getattr

0 commit comments

Comments
 (0)