@@ -21,13 +21,14 @@ def __init__(
2121 mixed_precision : Literal ['no' , 'fp8' , 'fp16' , 'bf16' ] = 'bf16' ,
2222 ddp_config : Dict [str , Any ] = None ,
2323 fsdp_config : Dict [str , Any ] = None ,
24+ memory_efficient : bool = True ,
2425 ):
2526 from accelerate import Accelerator
2627
2728 self .device_mesh = device_mesh
2829 self .mixed_precision = mixed_precision
2930 parallelism_config = self ._parallelism_config_from_device_mesh (device_mesh )
30- fsdp_plugin = self ._fsdp_config_from_device_mesh (device_mesh , fsdp_config )
31+ fsdp_plugin = self ._fsdp_config_from_device_mesh (device_mesh , fsdp_config , memory_efficient )
3132
3233 kwargs_handlers = []
3334 if ddp_config is not None :
@@ -69,7 +70,7 @@ def _parallelism_config_from_device_mesh(device_mesh: DeviceMesh):
6970
7071 return parallelism_config
7172
72- def _fsdp_config_from_device_mesh (self , device_mesh : DeviceMesh , fsdp_config : Dict [str , Any ]):
73+ def _fsdp_config_from_device_mesh (self , device_mesh : DeviceMesh , fsdp_config : Dict [str , Any ], memory_efficient : bool ):
7374 from accelerate import FullyShardedDataParallelPlugin
7475 from torch .distributed .fsdp import BackwardPrefetch
7576 from torch .distributed .fsdp import ShardingStrategy as FSDPShardingStrategy
@@ -107,7 +108,7 @@ def _fsdp_config_from_device_mesh(self, device_mesh: DeviceMesh, fsdp_config: Di
107108 activation_checkpointing = fsdp_config .pop ('activation_checkpointing' , False ),
108109 auto_wrap_policy = fsdp_config .pop ('auto_wrap_policy' , 'transformer_based_wrap' ), # noqa
109110 reshard_after_forward = fsdp_config .pop ('reshard_after_forward' , True ),
110- cpu_ram_efficient_loading = fsdp_config .pop ('cpu_ram_efficient_loading' , True ),
111+ cpu_ram_efficient_loading = fsdp_config .pop ('cpu_ram_efficient_loading' , memory_efficient ),
111112 ** fsdp_config ,
112113 )
113114 # The env vars (ACCELERATE_USE_FSDP, FSDP_CPU_RAM_EFFICIENT_LOADING) are set
0 commit comments