deepspeed/runtime/engine.py (50 additions, 0 deletions)
@@ -708,6 +708,18 @@ def zero_optimization_stage(self):

def mics_shard_size(self):
return self._config.mics_shard_size

def lins_enable(self):
return self._config.zero_config.enable_lins

def lins_param_partition_num(self):
return self._config.zero_config.lins_param_partition_num

def lins_os_partition_num(self):
return self._config.zero_config.lins_os_partition_num

def lins_grad_partition_num(self):
return self._config.zero_config.lins_grad_partition_num

def zero_reduce_bucket_size(self):
return self._config.zero_config.reduce_bucket_size
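
Review note: these accessors surface the new zero_config fields added in deepspeed/runtime/zero/config.py (second file in this diff). A minimal, illustrative ds_config snippet showing the keys they read; the values are placeholders, and the stage-3 placement follows the dispatch in _configure_zero_optimizer below:

ds_config = {
    "zero_optimization": {
        "stage": 3,
        "enable_lins": True,
        "lins_param_partition_num": 8,   # placeholder; the default -1 leaves it unset
        "lins_os_partition_num": 4,      # placeholder
        "lins_grad_partition_num": 8,    # placeholder
    }
}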
@@ -1532,6 +1544,8 @@ def _configure_zero_optimizer(self, optimizer):
ranks=[0])
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)
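# note: MiCS takes precedence here; the LinS branch below runs only when mics_shard_size <= 0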
elif self.lins_enable():
return self._return_lins_optimizer(optimizer, timers)

log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
@@ -1608,6 +1622,40 @@ def _return_mics_optimizer(self, basic_optimizer, timers):
communication_data_type=self.communication_data_type)
return optimizer

def _return_lins_optimizer(self, basic_optimizer, timers):
from deepspeed.runtime.zero.lins import LinS_Optimizer
model_dtype, gradient_accumulation_dtype = self.get_data_types()
print(f"enable LinS_Optimizer!!!!", flush=True)
optimizer = LinS_Optimizer(self.module,
basic_optimizer,
timers=timers,
ds_config=self.config,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
reduce_bucket_size=self.zero_reduce_bucket_size(),
prefetch_bucket_size=self.zero_prefetch_bucket_size(),
max_reuse_distance=self.zero_max_reuse_distance(),
max_live_parameters=self.zero_max_live_parameters(),
param_persistence_threshold=self.zero_param_persistence_threshold(),
model_persistence_threshold=self.zero_model_persistence_threshold(),
dp_process_group=self.seq_data_parallel_group,
reduce_scatter=self.zero_reduce_scatter(),
overlap_comm=self.zero_overlap_comm(),
offload_optimizer_config=self.zero_offload_optimizer(),
offload_param_config=self.zero_offload_param(),
sub_group_size=self.zero_sub_group_size(),
mpu=self.mpu,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
aio_config=self.aio_config(),
gradient_accumulation_dtype=gradient_accumulation_dtype,
communication_data_type=self.communication_data_type)
return optimizer
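
For context, a minimal usage sketch (the model and config values are hypothetical; deepspeed.initialize is the standard entry point and reaches _return_lins_optimizer via _configure_zero_optimizer when enable_lins is set):

import torch
import deepspeed

model = torch.nn.Linear(16, 16)  # placeholder model
ds_config = {
    "train_batch_size": 8,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    "zero_optimization": {"stage": 3, "enable_lins": True},
}
# the returned optimizer is the LinS_Optimizer constructed above
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)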

def _configure_eigenvalue(self):
eigenvalue = Eigenvalue(
verbose=self.eigenvalue_verbose(),
@@ -1659,6 +1707,8 @@ def deepspeed_io(self,
data_sampler=None,
collate_fn=None,
num_local_io_workers=None):

logger.debug(f"deepspeed_io received dataset of type {type(dataset)}")
if not (self.is_map_style_dataset(dataset) or self.is_iterable_style_dataset(dataset)):
raise ValueError("Training data must be a torch Dataset")

deepspeed/runtime/zero/config.py (7 additions, 0 deletions)
@@ -293,6 +293,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
Override nn.Module apply function, for Stage 3.
"""

enable_lins: bool = False  # master switch for the LinS optimizer path
lins_param_partition_num: int = -1  # parameter partition count (-1 leaves it unset)
lins_os_partition_num: int = -1  # optimizer-state partition count (-1 leaves it unset)
lins_grad_partition_num: int = -1  # gradient partition count (-1 leaves it unset)
hierarchical_allgather: bool = False  # toggle hierarchical allgather
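
Sketch only, not part of this PR: a guard for the new partition counts, written in the style of the existing validators below and assuming -1 means "unset":

@validator("lins_param_partition_num", "lins_os_partition_num", "lins_grad_partition_num")
def lins_partition_num_valid(cls, field_value, values):
    # assumption: -1 means "unset / derive a default"; explicit values must be positive
    assert field_value == -1 or field_value > 0, "LinS partition counts must be -1 or positive"
    return field_value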


# Validators
@validator("overlap_comm")
def overlap_comm_valid(cls, field_value, values):