Skip to content

Commit 8fc2bb7

Browse files
committed
wip
1 parent bebe60e commit 8fc2bb7

File tree

6 files changed

+83
-52
lines changed

6 files changed

+83
-52
lines changed

src/twinkle/infra/collectors.py

Lines changed: 2 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import TYPE_CHECKING, Any, Dict, List
33

44
from twinkle import DeviceMesh
5+
from twinkle.utils import pad_and_stack_tensors
56

67
if TYPE_CHECKING:
78
import torch
@@ -39,7 +40,7 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh)
3940
result[key] = merged
4041

4142
elif isinstance(first_value, torch.Tensor):
42-
result[key] = _pad_and_stack_tensors(values)
43+
result[key] = pad_and_stack_tensors(values)
4344

4445
elif isinstance(first_value, dict):
4546
result[key] = collect_tensor_dict(values)
@@ -53,36 +54,3 @@ def collect_tensor_dict(outputs: List[Dict[str, Any]], device_mesh: DeviceMesh)
5354
if 'loss' in result and len(result['loss']) > 1:
5455
result['loss'] = np.mean(result['loss'])
5556
return result
56-
57-
58-
def _pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor':
59-
import torch
60-
if not tensors:
61-
raise ValueError('Empty tensor list')
62-
63-
if len(tensors) == 1:
64-
return tensors[0]
65-
66-
max_ndim = max(t.ndim for t in tensors)
67-
expanded_tensors = []
68-
for t in tensors:
69-
while t.ndim < max_ndim:
70-
t = t.unsqueeze(0)
71-
expanded_tensors.append(t)
72-
73-
max_shape = []
74-
for dim in range(max_ndim):
75-
max_shape.append(max(t.shape[dim] for t in expanded_tensors))
76-
77-
padded_tensors = []
78-
for t in expanded_tensors:
79-
if list(t.shape) == max_shape:
80-
padded_tensors.append(t)
81-
else:
82-
pad_params = []
83-
for dim in range(max_ndim - 1, -1, -1):
84-
pad_params.extend([0, max_shape[dim] - t.shape[dim]])
85-
padded = torch.nn.functional.pad(t, pad_params, value=pad_value)
86-
padded_tensors.append(padded)
87-
88-
return torch.cat(padded_tensors, dim=0)

src/twinkle/loss/dpo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ def __call__(
284284
# Extract ref_logps from ref_outputs if provided
285285
if ref_outputs is not None and ref_logps is None:
286286
ref_logps = ref_outputs.get('logps')
287-
288287
labels = inputs.get('labels')
289288
assert labels is not None, "inputs must contain 'labels'"
290289
if not torch.is_tensor(labels):

src/twinkle/metric/dpo.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import List, Union
44

55
from twinkle.data_format import InputFeature, ModelOutput
6+
from twinkle.utils import pad_and_stack_tensors
67
from .base import Metric
78

89

@@ -81,13 +82,20 @@ def accumulate(self, inputs: Union[InputFeature, List[InputFeature]], outputs: M
8182
"""
8283
import torch
8384
logps = outputs.get('logps')
84-
if logps is None:
85+
if logps is None or len(logps) == 0:
8586
return
87+
88+
if isinstance(logps, list) and logps:
89+
logps = pad_and_stack_tensors(logps)
8690

8791
# Get labels from inputs
8892
if isinstance(inputs, list):
89-
assert len(inputs) == 1
90-
inputs = inputs[0]
93+
labels = [input['labels'] for input in inputs]
94+
if len(labels) == 1:
95+
labels = labels[0]
96+
else:
97+
labels = pad_and_stack_tensors(labels)
98+
inputs = {'labels': labels}
9199

92100
labels = torch.as_tensor(inputs['labels'])
93101
if labels.dim() == 1:

src/twinkle/model/megatron/megatron.py

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,38 @@ def _not_encoded(inputs):
250250
assert isinstance(inputs, dict)
251251
return 'input_ids' not in inputs and 'input_embedding' not in inputs
252252

253+
@staticmethod
def _slice_value_for_microbatch(value, mb_start: int, mb_end: int, micro_batch_size: int):
    """Recursively slice a value for microbatch processing.

    Tensors, ndarrays, lists and tuples whose leading dimension exceeds
    ``micro_batch_size`` are cut down to ``[mb_start:mb_end]``; dicts
    (e.g. ``ref_outputs: {"logps": tensor}``) are descended into
    recursively; everything else passes through unchanged.

    Args:
        value: The value to slice (tensor, ndarray, list, tuple, dict, or scalar)
        mb_start: Start index of the microbatch
        mb_end: End index of the microbatch
        micro_batch_size: Size of each microbatch

    Returns:
        Sliced value with the same structure
    """
    if isinstance(value, dict):
        # Recurse into nested mappings, preserving their keys.
        recurse = MegatronModel._slice_value_for_microbatch
        return {key: recurse(item, mb_start, mb_end, micro_batch_size)
                for key, item in value.items()}

    # Work out the leading-dimension length for sliceable containers;
    # 0-dim tensors/arrays have no leading dim and fall through unsliced.
    if isinstance(value, torch.Tensor):
        leading = value.shape[0] if value.dim() >= 1 else 0
    elif isinstance(value, np.ndarray):
        leading = value.shape[0] if value.ndim >= 1 else 0
    elif isinstance(value, (list, tuple)):
        leading = len(value)
    else:
        # Scalars or other non-sliceable values pass through as-is.
        return value

    # Only slice when the batch dimension is actually larger than one
    # microbatch; small tensors are shared across microbatches untouched.
    return value[mb_start:mb_end] if leading > micro_batch_size else value
284+
253285
def _postprocess_tensor_cp(self, tensor):
254286
"""All-gather and reconstruct full sequence from CP-split tensor.
255287
@@ -401,8 +433,6 @@ def forward_backward(self,
401433
else:
402434
seq_length = original_seq_length
403435

404-
if 'ref_outputs' in kwargs:
405-
breakpoint()
406436
num_microbatches = len(inputs)
407437
loss_extra_kwargs_per_mb = []
408438
if num_microbatches <= 1:
@@ -411,17 +441,10 @@ def forward_backward(self,
411441
for mb_idx in range(num_microbatches):
412442
mb_start = mb_idx * micro_batch_size
413443
mb_end = mb_start + micro_batch_size
414-
mb_kwargs = {}
415-
for key, value in kwargs.items():
416-
if isinstance(value, torch.Tensor) and value.dim() >= 1 and value.shape[0] > micro_batch_size:
417-
mb_kwargs[key] = value[mb_start:mb_end]
418-
elif isinstance(value, np.ndarray) and value.ndim >= 1 and value.shape[0] > micro_batch_size:
419-
mb_kwargs[key] = value[mb_start:mb_end]
420-
elif isinstance(value, (list, tuple)) and len(value) > micro_batch_size:
421-
mb_kwargs[key] = value[mb_start:mb_end]
422-
else:
423-
# Scalars, small tensors, or non-sliceable values pass through as-is
424-
mb_kwargs[key] = value
444+
mb_kwargs = {
445+
key: self._slice_value_for_microbatch(value, mb_start, mb_end, micro_batch_size)
446+
for key, value in kwargs.items()
447+
}
425448
loss_extra_kwargs_per_mb.append(mb_kwargs)
426449

427450
_mb_counter = [0] # mutable counter for closure

src/twinkle/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .parallel import processing_lock
1111
from .platforms import GPU, NPU, Platform, ensure_hccl_socket_env, ensure_npu_backend
1212
from .safetensors import LazyTensor, SafetensorLazyLoader, StreamingSafetensorSaver
13-
from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device
13+
from .torch_utils import pad_sequence_to_length, selective_log_softmax, stateless_init_process_group, to_device, pad_and_stack_tensors
1414
from .transformers_utils import find_all_linears, find_layers, get_modules_to_not_convert
1515
from .unsafe import check_unsafe, trust_remote_code
1616
from .utils import copy_files_by_pattern, deep_getattr

src/twinkle/utils/torch_utils.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,36 @@ def stateless_init_process_group(
190190

191191
communicator = Communicator(pg, device=device)
192192
return communicator
193+
194+
195+
def pad_and_stack_tensors(tensors: List['torch.Tensor'], pad_value: float = -200) -> 'torch.Tensor':
    """Concatenate tensors along dim 0 after padding them to a common shape.

    Every tensor is first lifted (by prepending singleton dimensions) to the
    highest rank found in ``tensors``, then right-padded in each dimension
    with ``pad_value`` up to the per-dimension maximum, and finally
    concatenated along dimension 0.

    Args:
        tensors: Non-empty list of tensors; ranks and shapes may differ.
        pad_value: Fill value used for the padded regions (default -200).

    Returns:
        The concatenated tensor. A single-element list returns that tensor
        unchanged.

    Raises:
        ValueError: If ``tensors`` is empty.
    """
    import torch

    if not tensors:
        raise ValueError('Empty tensor list')
    if len(tensors) == 1:
        return tensors[0]

    target_ndim = max(t.ndim for t in tensors)

    def lift(t):
        # Prepend singleton dims until the tensor reaches the target rank.
        while t.dim() < target_ndim:
            t = t[None]
        return t

    lifted = [lift(t) for t in tensors]
    # Per-dimension maximum across all (rank-aligned) tensors.
    target_shape = [max(sizes) for sizes in zip(*(t.shape for t in lifted))]

    def pad_to_target(t):
        if list(t.shape) == target_shape:
            return t
        # F.pad expects (left, right) pairs starting from the LAST dimension;
        # we only pad on the right of each dimension.
        spec = []
        for size, full in zip(reversed(list(t.shape)), reversed(target_shape)):
            spec += [0, full - size]
        return torch.nn.functional.pad(t, spec, value=pad_value)

    return torch.cat([pad_to_target(t) for t in lifted], dim=0)

0 commit comments

Comments (0)