@@ -44,63 +44,6 @@ def pad_sequence_to_length(
4444 return F .pad (tensor , pad_tuple , mode = 'constant' , value = pad_value )
4545
4646
47- # TODO: check and remove
48- def pad_2d_list_to_tensor (
49- data_list : List [List ],
50- max_length : Optional [int ] = None ,
51- pad_value : float = 0.0 ,
52- left_pad : bool = False ,
53- dtype : 'torch.dtype' = None ,
54- device : Optional [Union [str , 'torch.device' ]] = None ,
55- ) -> 'torch.Tensor' :
56- """
57- Pad a 2D list (e.g., list of logprobs) to a 2D tensor.
58-
59- Args:
60- data_list: List of lists, each inner list can have different lengths
61- max_length: Target length. If None, uses the max length in data_list
62- pad_value: Value to use for padding
63- left_pad: If True, pad on the left (right-align data); otherwise pad on the right
64- dtype: Output tensor dtype
65- device: Output tensor device
66-
67- Returns:
68- Padded tensor of shape [batch, max_length]
69- """
70- import torch
71- if dtype is None :
72- dtype = torch .float32
73- if not data_list :
74- return torch .tensor ([], dtype = dtype , device = device )
75-
76- # Find max length
77- lengths = [len (item ) if item is not None else 0 for item in data_list ]
78- data_max_len = max (lengths ) if lengths else 0
79- target_len = max_length if max_length is not None and max_length > data_max_len else data_max_len
80-
81- if target_len == 0 :
82- return torch .full ((len (data_list ), 0 ), pad_value , dtype = dtype , device = device )
83-
84- batch_size = len (data_list )
85- result = torch .full ((batch_size , target_len ), pad_value , dtype = dtype , device = device )
86-
87- for i , item in enumerate (data_list ):
88- if item is None or len (item ) == 0 :
89- continue
90- seq_len = min (len (item ), target_len )
91- if torch .is_tensor (item ):
92- values = item [- seq_len :] if left_pad else item [:seq_len ]
93- else :
94- values = torch .tensor (item [- seq_len :] if left_pad else item [:seq_len ], dtype = dtype , device = device )
95-
96- if left_pad :
97- result [i , - seq_len :] = values
98- else :
99- result [i , :seq_len ] = values
100-
101- return result
102-
103-
10447def selective_log_softmax (logits , index ) -> 'torch.Tensor' :
10548 """
10649 refer: trl/trainer/utils
0 commit comments