@@ -211,7 +211,7 @@ def _worker_wrap_model_memory_efficient(rank, world_size, port, ref_sd):
211211 mesh_dim_names = ('fsdp' , ),
212212 device_type = _DEVICE_TYPE ,
213213 )
214- strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient = True )
214+ strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient_init = True )
215215
216216 model = TinyModel (dim = 32 ).to (_DEVICE_TYPE )
217217 if rank == 0 :
@@ -269,7 +269,7 @@ def _worker_wrap_model_legacy(rank, world_size, port, ref_sd):
269269 mesh_dim_names = ('fsdp' , ),
270270 device_type = _DEVICE_TYPE ,
271271 )
272- strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient = False )
272+ strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient_init = False )
273273
274274 model = TinyModel (dim = 32 ).to (_DEVICE_TYPE )
275275 model .load_state_dict (ref_sd )
@@ -324,7 +324,7 @@ def _worker_wrap_model_per_layer(rank, world_size, port, ref_sd):
324324 mesh_dim_names = ('fsdp' , ),
325325 device_type = _DEVICE_TYPE ,
326326 )
327- strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient = True )
327+ strategy = NativeFSDPStrategy (device_mesh = mesh , mixed_precision = 'no' , memory_efficient_init = True )
328328
329329 model = TinyTransformerModel (dim = 32 , num_layers = 2 ).to (_DEVICE_TYPE )
330330 if rank == 0 :
0 commit comments