File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 4141from collections import defaultdict
4242import psutil
4343
44+ from clt .training .utils import torch_bfloat16_to_numpy_uint16
45+
4446try :
4547 import GPUtil
4648except ImportError :
@@ -764,13 +766,13 @@ def _write_chunk(
764766
765767 # Convert to numpy
766768 with self ._conditional_measure (f"chunk_{ chunk_idx } _layer_{ lid } _convert_numpy" ):
767- inp_np = inp_perm . to ( self . torch_dtype ). numpy ()
768- tgt_np = tgt_perm . to ( self . torch_dtype ). numpy ()
769-
770- # Handle bfloat16 conversion
771- if h5py_dtype_str == "uint16" and inp_np . dtype == np . dtype ( "bfloat16" ) :
772- inp_np = inp_np . view ( np . uint16 )
773- tgt_np = tgt_np . view ( np . uint16 )
769+ # Handle bfloat16 conversion
770+ if h5py_dtype_str == "uint16" :
771+ inp_np = torch_bfloat16_to_numpy_uint16 ( inp_perm )
772+ tgt_np = torch_bfloat16_to_numpy_uint16 ( tgt_perm )
773+ else :
774+ inp_np = inp_perm . to ( self . torch_dtype ). numpy ( )
775+ tgt_np = tgt_perm . to ( self . torch_dtype ). numpy ( )
774776
775777 # Store prepared data
776778 layer_data [lid ] = (inp_np , tgt_np )
Original file line number Diff line number Diff line change 11import datetime
22
3+ import numpy as np
4+ import torch
5+
36
47# Helper function to format elapsed time
58def _format_elapsed_time (seconds : float ) -> str :
@@ -11,3 +14,7 @@ def _format_elapsed_time(seconds: float) -> str:
1114 return f"{ td .days * 24 + hours :02d} :{ minutes :02d} :{ seconds :02d} "
1215 else :
1316 return f"{ minutes :02d} :{ seconds :02d} "
17+
18+
def torch_bfloat16_to_numpy_uint16(x: torch.Tensor) -> np.ndarray:
    """Return the bfloat16 bit patterns of *x* as a ``uint16`` NumPy array.

    NumPy has no native bfloat16 dtype, so the tensor is widened to float32
    and the upper 16 bits of every element (sign, exponent and the leading
    7 mantissa bits) are kept. For a tensor that is already bfloat16 this is
    lossless, because widening only appends 16 zero bits.

    NOTE(review): for inputs of another dtype the low 16 bits of the float32
    value are truncated, not rounded, which can differ from
    ``x.to(torch.bfloat16)``. Confirm that callers only pass bfloat16 data.

    Args:
        x: Tensor of any shape. It is detached and moved to the CPU before
            conversion, so tensors that require grad or live on another
            device are accepted.

    Returns:
        A newly allocated, C-contiguous, writable ``np.uint16`` array with
        the same shape as ``x``.
    """
    as_float32 = x.detach().cpu().float().contiguous().numpy()
    # Same item size, so this only reinterprets the float32 bits as integers.
    bits = as_float32.view(np.uint32)
    # An integer shift selects the high half-word regardless of host byte
    # order; writing into a preallocated array keeps 0-d inputs as arrays and
    # avoids holding on to a strided view of the float32 buffer.
    out = np.empty(bits.shape, dtype=np.uint16)
    np.right_shift(bits, np.uint32(16), out=out, casting="unsafe")
    return out
You can’t perform that action at this time.
0 commit comments