Skip to content

Commit 6a2b9e7

Browse files
committed
remove
1 parent c5aafff commit 6a2b9e7

1 file changed

Lines changed: 0 additions & 57 deletions

File tree

src/twinkle/utils/torch_utils.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -44,63 +44,6 @@ def pad_sequence_to_length(
4444
return F.pad(tensor, pad_tuple, mode='constant', value=pad_value)
4545

4646

47-
# TODO: check and remove
48-
def pad_2d_list_to_tensor(
    data_list: List[List],
    max_length: Optional[int] = None,
    pad_value: float = 0.0,
    left_pad: bool = False,
    dtype: 'torch.dtype' = None,
    device: Optional[Union[str, 'torch.device']] = None,
) -> 'torch.Tensor':
    """
    Pad a ragged 2D list (e.g., per-sample logprobs) into a dense 2D tensor.

    Args:
        data_list: Sequence of rows. Each row may be a list or a 1D tensor of
            any length; ``None`` rows are treated as empty.
        max_length: Minimum output width. Rows are never truncated: when a row
            is longer than ``max_length`` (or ``max_length`` is None), the
            width is the longest row's length.
        pad_value: Fill value for positions not covered by row data.
        left_pad: If True, pad on the left (right-align data); otherwise pad
            on the right (left-align data).
        dtype: Output tensor dtype. Defaults to ``torch.float32``.
        device: Output tensor device.

    Returns:
        Tensor of shape ``[len(data_list), width]`` where
        ``width = max(max_length or 0, longest row)``. An empty or ``None``
        ``data_list`` yields a 2D tensor with zero rows.
    """
    import torch

    if dtype is None:
        dtype = torch.float32

    # Treat a missing/empty batch as zero rows so the result is always 2D.
    rows = data_list if data_list else []

    lengths = [0 if item is None else len(item) for item in rows]
    longest = max(lengths, default=0)
    # max_length only ever widens the output; it never truncates a row.
    target_len = longest if max_length is None else max(max_length, longest)

    result = torch.full((len(rows), target_len), pad_value, dtype=dtype, device=device)

    for i, (item, seq_len) in enumerate(zip(rows, lengths)):
        if seq_len == 0:
            # None or empty row: leave the whole row as padding.
            continue
        if torch.is_tensor(item):
            # Slice assignment converts dtype/device as needed.
            values = item
        else:
            values = torch.tensor(item, dtype=dtype, device=device)

        if left_pad:
            result[i, target_len - seq_len:] = values
        else:
            result[i, :seq_len] = values

    return result
102-
103-
10447
def selective_log_softmax(logits, index) -> 'torch.Tensor':
10548
"""
10649
refer: trl/trainer/utils

0 commit comments

Comments
 (0)