diff --git a/comm/torch_backend.py b/comm/torch_backend.py index 5eb431e..a7670fb 100644 --- a/comm/torch_backend.py +++ b/comm/torch_backend.py @@ -40,7 +40,7 @@ def recv(self, else: buffer = tensor.cpu() dist.recv(buffer, self.to_global_rank(src), group=self.process_group) - tensor[:] = buffer.to(tensor.device) + tensor.set_(buffer.to(tensor.device)) def isend(self, tensor: torch.Tensor, @@ -63,7 +63,7 @@ def irecv(self, assert False buffer = tensor.cpu() handler = dist.irecv(buffer, self.to_global_rank(src), group=self.process_group) - tensor[:] = buffer.to(tensor.device) + tensor.set_(buffer.to(tensor.device)) return handler def broadcast(self, @@ -75,7 +75,7 @@ def broadcast(self, else: buffer = tensor.cpu() dist.broadcast(buffer, self.to_global_rank(src), group=self.process_group) - tensor[:] = buffer.to(tensor.device) + tensor.set_(buffer.to(tensor.device)) def reduce(self, tensor: torch.Tensor, @@ -90,7 +90,7 @@ def all_reduce(self, op=dist.ReduceOp.SUM): buffer = tensor.cpu() dist.all_reduce(buffer, group=self.process_group, op=op) - tensor[:] = buffer.to(tensor.device) + tensor.set_(buffer.to(tensor.device)) def gather(self, tensor: torch.Tensor, diff --git a/data_parallel/dist_dp_allreduce.py b/data_parallel/dist_dp_allreduce.py index 901a6f4..b077014 100644 --- a/data_parallel/dist_dp_allreduce.py +++ b/data_parallel/dist_dp_allreduce.py @@ -120,6 +120,7 @@ def optimizer_step(self): with torch.cuda.stream(self.torch_optim_comp_stream): self.torch_optim_comp_stream.wait_event(self.allreduce_grad_ready_event) self.profile_mark_optimizer_step_start() + torch.nn.utils.clip_grad_norm_(self.module.parameters(), 1.0) self.optimizer.step() self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event) diff --git a/data_parallel/dist_dp_cocktail_sgd.py b/data_parallel/dist_dp_cocktail_sgd.py index b463b03..e77f3fb 100644 --- a/data_parallel/dist_dp_cocktail_sgd.py +++ b/data_parallel/dist_dp_cocktail_sgd.py @@ -53,7 +53,7 @@ def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim if self.flatten: _params = [] - for i_group, group in enumerate(self.optimizer.optimizer.param_groups): + for i_group, group in enumerate(self.optimizer.param_groups): for i_para, para in enumerate(group["params"]): _params.append(para) self.flatten_para = flatten_tensors(_params) @@ -346,7 +346,7 @@ def _partial_sync(self): cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream) with torch.cuda.stream(self.dp_comm_stream), cupy_dp_stream: - for i_group, group in enumerate(self.optimizer.optimizer.param_groups): + for i_group, group in enumerate(self.optimizer.param_groups): for i_para, para in enumerate(group["params"]): diff --git a/data_parallel/dist_dp_cocktail_sgd_grad.py b/data_parallel/dist_dp_cocktail_sgd_grad.py deleted file mode 100644 index f44674b..0000000 --- a/data_parallel/dist_dp_cocktail_sgd_grad.py +++ /dev/null @@ -1,532 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.cuda -import math -from comm.comm_utils import * -from .flatten_utils import flatten_params, flatten_tensors -from compress.fixpoint import * -from compress.sparsification import * -from compress import flag -import cupy - -import os - -quantization_bits = int(os.environ.get('QUANT_BITS', 8)) -quantization_bucket_size = int(os.environ.get('QUANT_BUCKET_SIZE', 128)) -quantization_stochastic = int(os.environ.get('QUANT_STOCHASTIC', 0)) -quantization_minimum_stochastic_distance = float(os.environ.get('QUANT_MIN_STOCHASTIC_DISTANCE', 
0.2)) -top_k_ratio = float(os.environ.get('TOPK_RATIO', 0.5)) -random_p_ratio = float(os.environ.get('RANDOMP_RATIO', 0.5)) -random_method = os.environ.get('RANDOM_METHOD', 'random_rolling') - -import threading - - -class CocktailSGDGradDP: - def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim.Optimizer = None, flatten=False): - # assert not flatten - self.args = args - self.dp_bits = args.dp_bits - self.flatten = flatten - self.global_rank = args.rank - self.dp_group_size = args.data_group_size - self.enable_tidy_profiling = (args.profiling == 'tidy_profiling') - # self.dp_comm = get_data_parallel_comm() - self.dp_rank = get_data_parallel_rank() - self.pp_comm = get_pipeline_parallel_comm() - self.pp_rank = get_pipeline_parallel_rank() - self.pp_group_size = get_pipeline_parallel_world_size() - self.device = device - self.dp_comm_stream = torch.cuda.Stream(device=device, priority=-1) - self.torch_optim_comp_stream = torch.cuda.default_stream(device=device) - self.backward_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False) - self.sync_gradients_start_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False) - self.sync_gradients_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False) - self.optimizer_step_ready_event = torch.cuda.Event(enable_timing=self.enable_tidy_profiling, blocking=False) - - self.flag_dp_exception = 0 - - self.module = module - assert optimizer is not None - self.optimizer = optimizer - - if self.flatten: - _params = [] - for i_group, group in enumerate(self.optimizer.optimizer.param_groups): - for i_para, para in enumerate(group["params"]): - _params.append(para) - self.flatten_para = flatten_params(_params) - print("Flattened parameter number: {}, element size: {}." - .format(self.flatten_para.data.numel(), self.flatten_para.data.element_size())) - - - num_paras, element_size = self._compute_total_para_num() - print("Total number of parameters: {}, element size: {}, total size {} MB." 
- .format(num_paras, element_size, num_paras * element_size // 1024 // 1024)) - - if self.enable_tidy_profiling: - self.global_rank = args.rank - self.init_event = None - self.init_time_stamp = None - - # assert self.flatten - self.sync_gradients_start_event = torch.cuda.Event(enable_timing=True, blocking=False) - self.optimizer_step_start_event = torch.cuda.Event(enable_timing=True, blocking=False) - - self.gather_start_event = torch.cuda.Event(enable_timing=True, blocking=False) - self.sync_start_event = torch.cuda.Event(enable_timing=True, blocking=False) - self.gather_end_event = torch.cuda.Event(enable_timing=True, blocking=False) - self.sync_end_event = torch.cuda.Event(enable_timing=True, blocking=False) - - self.dp_state_dict = {} - - @property - def dp_comm(self): - return get_data_parallel_comm() - - def _compute_total_para_num(self): - total_count = 0 - element_size = 0 - for para in self.module.parameters(): - # print("Parameter: ", para.data.shape) - total_count += torch.numel(para.data) - element_size = para.element_size() - return total_count, element_size - - def profile_mark_sync_grad_start(self): - if self.enable_tidy_profiling: - self.dp_comm_stream.record_event(self.sync_gradients_start_event) - - def profile_mark_allreduce_end(self): - pass - - def profile_mark_optimizer_step_start(self): - if self.enable_tidy_profiling: - self.torch_optim_comp_stream.record_event(self.optimizer_step_start_event) - - def _allreduce_gradients(self): - with torch.cuda.stream(self.dp_comm_stream): - cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream) - self.dp_comm_stream.wait_event(self.backward_ready_event) - for name, para in self.module.named_parameters(): - if para.grad is None: - continue - self.dp_comm.all_reduce(para.grad, stream=cupy_dp_stream) - self.dp_comm_stream.record_event(self.sync_gradients_ready_event) - - def _compress(self, x): - # return x - dtype = x.dtype - shape = x.shape - with torch.cuda.stream(self.dp_comm_stream): - cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream) - with cupy_dp_stream: - - k = max(int(top_k_ratio * x.numel()), 1) - if k >= quantization_bucket_size: - # ensure dividable - k = k // quantization_bucket_size * quantization_bucket_size - else: - # bucket_size will be set to k internally - pass - - values, masks, indices = compress_topk(x, k, return_indices=True) - - values_q, scales_q = compress_flexible_nbits_by_bucket( - values, bits=quantization_bits, scale_method='max', bucket_size=quantization_bucket_size, - stochastic=quantization_stochastic, minimum_stochastic_distance=quantization_minimum_stochastic_distance) - - return (values_q, scales_q, masks), (dtype, shape, values.shape) - - def _decompress(self, x_hat, meta_data): - - values_q, scales_q, masks = x_hat - x_dtype, x_shape, values_shape = meta_data - - values = decompress_flexible_nbits_by_bucket(values_q, scales_q, bits=quantization_bits, original_shape=values_shape, bucket_size=quantization_bucket_size) - - x = decompress_topk(values, masks, x_shape) - x = x.view(x_shape).to(x_dtype) - - return x - - def _update_comm_mask(self, para, comm_mask=None): - - if random_method == 'random_rolling': - - sync_every_n_elements = int(1 / random_p_ratio) - - if comm_mask is None: - para_shape = list(para.shape) - assert para_shape[0] == para_shape[0] // self.dp_group_size * self.dp_group_size - para_shape[0] = para_shape[0] // self.dp_group_size - comm_mask = torch.zeros(para_shape, dtype=torch.bool, device=para.device) - 
comm_mask.view(-1)[::sync_every_n_elements] = True - n_potisive = comm_mask.sum().item() // quantization_bucket_size * quantization_bucket_size - if n_potisive != 0: - comm_mask.view(-1)[comm_mask.view(-1).cumsum(-1) > n_potisive] = False - assert comm_mask.sum().item() == n_potisive - else: - comm_mask[:] = True - print('comm_mask:', comm_mask.sum().item(), comm_mask.shape) - else: - comm_mask = comm_mask.roll(1) - - elif random_method == 'random_w_replacement': - - seed = torch.randint(10000, [1]) - self.dp_comm.broadcast(seed, 0) - torch.manual_seed(seed.item()) - - para_shape = list(para.shape) - assert len(para_shape) == 1 - assert para_shape[0] == para_shape[0] // self.dp_group_size * self.dp_group_size - para_shape[0] = para_shape[0] // self.dp_group_size - - n_sample = int(random_p_ratio * para_shape[0]) - n_sample = n_sample // 8 * 8 - n_sample = n_sample // quantization_bucket_size * quantization_bucket_size - comm_mask = torch.randint(para_shape[0], (n_sample,), device=para.device) - - elif random_method == 'random_wo_replacement': - - if comm_mask is None or ((self._cursor+1) * self._n_sample >= len(self._comm_indices)): - - seed = torch.randint(10000, [1]) - self.dp_comm.broadcast(seed, 0) - torch.manual_seed(seed.item()) - - para_shape = list(para.shape) - assert len(para_shape) == 1 - assert para_shape[0] == para_shape[0] // self.dp_group_size * self.dp_group_size - para_shape[0] = para_shape[0] // self.dp_group_size - - n_sample = int(random_p_ratio * para_shape[0]) - n_sample = n_sample // 8 * 8 - n_sample = n_sample // quantization_bucket_size * quantization_bucket_size - self._n_sample = n_sample - self._cursor = 0 - self._comm_indices = torch.randperm(para_shape[0], device=para.device) - - comm_mask = self._comm_indices[self._cursor * self._n_sample: (self._cursor+1) * self._n_sample] - self._cursor += 1 - - else: - - raise Exception(f"""Unknown random method '{random_method}'""") - - - return comm_mask - - - - def _partial_sync(self): - - if self.flatten: - - cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream) - with torch.cuda.stream(self.dp_comm_stream), cupy_dp_stream: - - self.dp_comm_stream.record_event(self.sync_gradients_start_event) - - self.dp_comm.barrier() - - name = 'model' - para = self.flatten_para - - dp_state_dict = self.dp_state_dict - - if name not in dp_state_dict: - - # comm mask - comm_mask_list = [] - comm_data_list = [] - for i in range(self.dp_group_size): - comm_mask = self._update_comm_mask(para) - comm_mask_list.append(comm_mask) - - # global para - global_para = para.data.half() - - # server error - server_error = torch.zeros( - para.size(0) // self.dp_group_size, dtype=torch.float16, device=para.device, - ) - - dp_state_dict[name] = { - "comm_mask_list": comm_mask_list, - "global_para": global_para, - "server_error": server_error, - } - else: - for i in range(self.dp_group_size): - dp_state_dict[name]['comm_mask_list'][i] = self._update_comm_mask(para, dp_state_dict[name]['comm_mask_list'][i]) - - comm_mask_list = dp_state_dict[name]["comm_mask_list"] - comm_data_list = comm_data_list = [None for _ in comm_mask_list] - global_para = dp_state_dict[name]["global_para"] - chunk_size = global_para.size(0) // self.dp_group_size - server_error = dp_state_dict[name]["server_error"] - server_mask = comm_mask_list[self.dp_rank].to(server_error.device) - - for i in range(self.dp_group_size): - comm_mask = comm_mask_list[i] - comm_data_list[i] = (para[i*chunk_size:(i+1)*chunk_size][comm_mask] - 
global_para[i*chunk_size:(i+1)*chunk_size][comm_mask]).half() - - comm_data_compressed_list = [] - comm_data_meta_list = [] - for x in comm_data_list: - data, meta_data = self._compress(x) - comm_data_compressed_list.append(data) - comm_data_meta_list.append(meta_data) - del x - # del comm_data_list - comm_buffer_list = [[torch.zeros_like(x, device='cpu') for x in x_tuple] for x_tuple in comm_data_compressed_list] - - # revert - for i in range(self.dp_group_size): - _data_compressed = self._decompress(comm_data_compressed_list[i], comm_data_meta_list[i]) - para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] -= _data_compressed - del _data_compressed - - _group_calls = [] - for i in range(self.dp_group_size): - for j, to_send in enumerate(comm_data_compressed_list[i]): - if i != self.dp_rank: - call = self.dp_comm.isend( - to_send, dst=i, stream=cupy_dp_stream) - _group_calls.append(call) - else: - comm_buffer_list[i][j][:] = to_send.cpu() - for to_recv in comm_buffer_list[i]: - if i != self.dp_rank: - call = self.dp_comm.irecv( - to_recv, src=i, stream=cupy_dp_stream) - _group_calls.append(call) - for call in _group_calls: - call.wait() - - server_data = self._decompress([z.to(para.device) for z in comm_buffer_list[0]], comm_data_meta_list[0]) / len(comm_buffer_list) - for i in range(1, self.dp_group_size): - server_data.data += self._decompress([z.to(para.device) for z in comm_buffer_list[i]], comm_data_meta_list[i]) / len(comm_buffer_list) - server_data.add_(server_error[server_mask].to(server_data.device)) - server_data_compressed, server_data_meta = self._compress(server_data) - server_error.data[server_mask] = (server_data - self._decompress(server_data_compressed, server_data_meta)).to(server_error.device) - - _group_calls = [] - for i in range(self.dp_group_size): - for j, to_send in enumerate(server_data_compressed): - if i != self.dp_rank: - call = self.dp_comm.isend( - to_send, dst=i, stream=cupy_dp_stream) - _group_calls.append(call) - else: - comm_buffer_list[i][j][:] = to_send.cpu() - for to_recv in comm_buffer_list[i]: - if i != self.dp_rank: - call = self.dp_comm.irecv( - to_recv, src=i, stream=cupy_dp_stream) - _group_calls.append(call) - for call in _group_calls: - call.wait() - - for i in range(self.dp_group_size): - - _data = self._decompress([z.to(para.device) for z in comm_buffer_list[i]], comm_data_meta_list[i]) - para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] += _data - global_para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] += _data - - del _data - - self.dp_comm_stream.record_event(self.sync_gradients_ready_event) - - else: - - cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream) - with torch.cuda.stream(self.dp_comm_stream), cupy_dp_stream: - - for i_group, group in enumerate(self.optimizer.optimizer.param_groups): - for i_para, para in enumerate(group["params"]): - - - para = para.view(-1) - - name = f"{i_group}-{i_para}" - - dp_state_dict = self.dp_state_dict - - if name not in dp_state_dict: - - # comm mask - comm_mask_list = [] - comm_data_list = [] - for i in range(self.dp_group_size): - comm_mask = self._update_comm_mask(para) - comm_mask_list.append(comm_mask) - - # global para - global_para = para.data.half() - - # server error - # server_error = torch.zeros_like(global_para.chunk(self.dp_group_size, 0)[self.dp_rank]) - server_error = torch.zeros( - para.size(0) // self.dp_group_size, dtype=torch.float16, device='cpu', - ) - - # print('server error shape:', server_error.shape) - dp_state_dict[name] = { 
- "comm_mask_list": comm_mask_list, - "global_para": global_para, - "server_error": server_error, - } - else: - for i in range(self.dp_group_size): - dp_state_dict[name]['comm_mask_list'][i] = self._update_comm_mask(para, dp_state_dict[name]['comm_mask_list'][i]) - - comm_mask_list = dp_state_dict[name]["comm_mask_list"] - comm_data_list = [None for _ in comm_mask_list] - global_para = dp_state_dict[name]["global_para"] - chunk_size = global_para.size(0) // self.dp_group_size - server_error = dp_state_dict[name]["server_error"] - server_mask = comm_mask_list[self.dp_rank] - - for i in range(self.dp_group_size): - comm_mask = comm_mask_list[i] - comm_data_list[i] = (para[i*chunk_size:(i+1)*chunk_size][comm_mask] - global_para[i*chunk_size:(i+1)*chunk_size][comm_mask]).half() - - comm_data_compressed_list = [] - comm_data_meta_list = [] - for x in comm_data_list: - data, meta_data = self._compress(x) - comm_data_compressed_list.append(data) - comm_data_meta_list.append(meta_data) - comm_buffer_list = [[torch.zeros_like(x, device='cpu') for x in x_tuple] for x_tuple in comm_data_compressed_list] - - # revert - for i in range(self.dp_group_size): - _data_compressed = self._decompress(comm_data_compressed_list[i], comm_data_meta_list[i]) - print(comm_data_list[i].shape, _data_compressed.shape, comm_mask_list[i].shape) - para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] -= _data_compressed - del _data_compressed - - _group_calls = [] - for i in range(self.dp_group_size): - for j, to_send in enumerate(comm_data_compressed_list[i]): - # print(f"send from {self.dp_rank} to {i}") - if i != self.dp_rank: - call = self.dp_comm.isend( - to_send, dst=i, stream=cupy_dp_stream) - _group_calls.append(call) - else: - comm_buffer_list[i][j][:] = to_send.cpu() - for to_recv in comm_buffer_list[i]: - # print(f"recv from {i} to {self.dp_rank}") - if i != self.dp_rank: - call = self.dp_comm.irecv( - to_recv, src=i, stream=cupy_dp_stream) - _group_calls.append(call) - for call in _group_calls: - call.wait() - - server_data = self._decompress([z.to(para.device) for z in comm_buffer_list[0]], comm_data_meta_list[0]) / len(comm_buffer_list) - for i in range(1, self.dp_group_size): - server_data.data += self._decompress([z.to(para.device) for z in comm_buffer_list[i]], comm_data_meta_list[i]) / len(comm_buffer_list) - server_data.add_(server_error[server_mask].to(server_data.device)) - server_data_compressed, server_data_meta = self._compress(server_data) - server_error.data[server_mask] = (server_data - self._decompress(server_data_compressed, server_data_meta)).cpu() - - - _group_calls = [] - for i in range(self.dp_group_size): - for j, to_send in enumerate(server_data_compressed): - if i != self.dp_rank: - call = self.dp_comm.isend( - to_send, dst=i, stream=cupy_dp_stream) - _group_calls.append(call) - else: - comm_buffer_list[i][j][:] = to_send.cpu() - for to_recv in comm_buffer_list[i]: - if i != self.dp_rank: - call = self.dp_comm.irecv( - to_recv, src=i, stream=cupy_dp_stream) - _group_calls.append(call) - for call in _group_calls: - call.wait() - - - for i in range(self.dp_group_size): - - _data = self._decompress([z.to(para.device) for z in comm_buffer_list[i]], comm_data_meta_list[i]) - para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] += _data - global_para.data[i*chunk_size:(i+1)*chunk_size][comm_mask_list[i]] += _data - - del _data - - self.dp_comm_stream.record_event(self.sync_gradients_ready_event) - - def _try_partial_sync(self): - try: - self._partial_sync() - except: - 
self.flag_dp_exception = 1 - - def pre_optimizer_step(self): - if not flag.FLAG_DISABLE_COMPRESSION: - self.t = threading.Thread(target=self._try_partial_sync) - self.t.start() - - def reinit_dp_comm_if_wrong(self): - - buffers = [torch.zeros(1).long().to(self.device) for _ in range(self.pp_group_size)] - self.pp_comm.all_gather(torch.tensor(self.flag_dp_exception).long().to(self.device), buffers) - self.flag_dp_exception = max([s.item() for s in buffers]) - - if self.flag_dp_exception: - reinit_dp_communicator(self.args) - self.flag_dp_exception = 0 - - def optimizer_step(self): - - if flag.FLAG_DISABLE_COMPRESSION: - self._allreduce_gradients() - else: - self.t.join() - - self.reinit_dp_comm_if_wrong() - - with torch.cuda.stream(self.torch_optim_comp_stream): - self.torch_optim_comp_stream.wait_event(self.sync_gradients_ready_event) - self.torch_optim_comp_stream.wait_event(self.backward_ready_event) - self.profile_mark_optimizer_step_start() - self.optimizer.step() - print('done optim') - self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event) - - def set_time_stamp(self, init_time_stamp, init_event): - self.init_event = init_event - self.init_time_stamp = init_time_stamp - - def get_ts(self, event): - return self.init_time_stamp + self.init_event.elapsed_time(event) * 1e+3 - - def profiling_data_parallel(self, init_time_stamp, init_event): - self.set_time_stamp(init_time_stamp, init_event) - profiling_log = [] - - # assert self.flatten - allreduce_slot = self.sync_gradients_start_event.elapsed_time(self.sync_gradients_ready_event)*1e+3 - allreduce_log = {"name": "opt_shardedPS_sync", "ph": "X", "pid": self.global_rank, "tid": "7. optimizer-comm", - "ts": self.get_ts(self.sync_gradients_start_event), - "dur": allreduce_slot, "cname": "cq_build_passed", - "args": {'para': 'flattened_grad', 'size': self.flatten_para.numel()}} - # print(allreduce_log) - profiling_log.append(allreduce_log) - - optimizer_slot = self.optimizer_step_start_event.elapsed_time(self.optimizer_step_ready_event) * 1e+3 - optimizer_log = {"name": "opt_comp", "ph": "X", "pid": self.global_rank, "tid": "8. 
optimizer-comp", - "ts": self.get_ts(self.optimizer_step_start_event), "dur": optimizer_slot, "cname": "bad"} - # print(optimizer_log) - profiling_log.append(optimizer_log) - - return profiling_log diff --git a/data_parallel/dist_dp_utils.py b/data_parallel/dist_dp_utils.py index 075f658..0e3f52c 100644 --- a/data_parallel/dist_dp_utils.py +++ b/data_parallel/dist_dp_utils.py @@ -2,7 +2,6 @@ from .dist_dp_sharded_ps import ShardedPSDP from .dist_dp_local import LocalDP from .dist_dp_cocktail_sgd import CocktailSGDDP -from .dist_dp_cocktail_sgd_grad import CocktailSGDGradDP def get_dp_module(args, device, module, optimizer): @@ -16,8 +15,6 @@ def get_dp_module(args, device, module, optimizer): return ShardedPSDP(args, device, module, optimizer, flatten=False) elif args.dp_mode == 'cocktail_sgd': return CocktailSGDDP(args, device, module, optimizer, flatten=True) - elif args.dp_mode == 'cocktail_sgd_grad': - return CocktailSGDGradDP(args, device, module, optimizer, flatten=True) else: print("Not recognize this data parallel mode.") assert False diff --git a/dist_lm_train.py b/dist_lm_train.py index f9bb316..392d566 100644 --- a/dist_lm_train.py +++ b/dist_lm_train.py @@ -301,15 +301,15 @@ def main(): 'd_model': args.embedding_dim, 'd_inner': args.embedding_dim * 4, 'vocab_size': 50257, - 'attn_cfg': dict(num_heads = 12), # HARD CODED FOR 125M + 'attn_cfg': dict(num_heads = 12, fused_bias_fc=True, use_flash_attn=True), # HARD CODED FOR 125M 'attn_layer_idx': [1, 8], # HARD CODED FOR 125M - 'ssm_cfg': dict(mode='diag', measure='diag-lin'), + 'ssm_cfg': dict(mode='diag', measure='diag-lin', use_fast_fftconv=True), 'pad_vocab_size_multiple': 8, 'max_position_embeddings': 0, 'resid_dropout': 0.0, 'embed_dropout': 0.1, 'layer_norm_epsilon': 1e-5, - 'fused_mlp': True, + 'fused_mlp': False, 'fused_dropout_add_ln': True, 'residual_in_fp32': True }) @@ -354,9 +354,6 @@ def main(): if args.load_checkpoint: load_checkpoint(pipe, args) - if args.fp16: - pipe.optimizer.reload_model_params() - if args.profiling == 'no-profiling': train_loop(args, pipe, device, train_data_loader, test_data_loader) else: diff --git a/example_scripts/pretrain_h3_125m_5btok.sh b/example_scripts/pretrain_h3_125m_5btok.sh index 58252bf..f6941dd 100644 --- a/example_scripts/pretrain_h3_125m_5btok.sh +++ b/example_scripts/pretrain_h3_125m_5btok.sh @@ -1,7 +1,7 @@ netif=lo export GLOO_SOCKET_IFNAME=${netif} export NCCL_SOCKET_IFNAME=${netif} -export WANDB_NAME=h3-125m-pretrain-pile-ar-5btok-linear +export WANDB_NAME=h3-125m-pretrain-pile-ar-5btok-linear-jue-amp # export WANDB_NAME=test export QUANT_BITS=4 @@ -26,7 +26,7 @@ ARGS="--model-name ./empty_model_configs/h3 \ --evaluation-steps 4000 \ --evaluation-num-batch 256 \ --evaluation-data pile \ ---lr 6e-4 --seq-length 2048 --batch-size 16 --micro-batch-size 8 --gradient-accumulate-step 2 \ +--lr 6e-4 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \ --dist-url tcp://127.0.0.1:7033 \ --world-size 8 --pipeline-group-size 1 --data-group-size 8 \ --job-id 0 --net-interface ${netif} \ diff --git a/example_scripts/pretrain_h3_125m_cocktail_5btok.sh b/example_scripts/pretrain_h3_125m_cocktail_5btok.sh new file mode 100644 index 0000000..e93bf62 --- /dev/null +++ b/example_scripts/pretrain_h3_125m_cocktail_5btok.sh @@ -0,0 +1,55 @@ +netif=lo +export GLOO_SOCKET_IFNAME=${netif} +export NCCL_SOCKET_IFNAME=${netif} +export WANDB_NAME=h3-125m-pretrain-pile-cocktail-5btok-linear +# export WANDB_NAME=test + +export QUANT_BITS=4 +export TOPK_RATIO=0.5 +export 
RANDOMP_RATIO=0.4 + +export SHOW_DATA=0 + +# the model name argument is IGNORED +ARGS="--model-name ./empty_model_configs/h3 \ +--tokenizer-name gpt2 \ +--load-pretrained-model false \ +--project-name cocktail-sgd \ +--model-type h3 \ +--optimizer adam \ +--seed 42 \ +--task-name pile \ +--checkpoint-path ./model_ckpts/$WANDB_NAME \ +--num-layers 12 --embedding-dim 768 \ +--total-steps 20000 --warmup-steps 200 --train-warmup-steps 1000 \ +--checkpoint-steps 500 \ +--evaluation-steps 4000 \ +--evaluation-num-batch 256 \ +--evaluation-data pile \ +--lr 6e-4 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \ +--dist-url tcp://127.0.0.1:7033 \ +--world-size 8 --pipeline-group-size 1 --data-group-size 8 \ +--job-id 0 --net-interface ${netif} \ +--dp-backend gloo \ +--dp-mode cocktail_sgd \ +--pp-mode gpipe --profiling no-profiling" + +(trap 'kill 0' SIGINT; \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 0 --rank 0 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 1 --rank 1 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 2 --rank 2 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 3 --rank 3 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 4 --rank 4 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 5 --rank 5 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 6 --rank 6 \ + & \ +python dist_lm_train.py $(echo ${ARGS}) --cuda-id 7 --rank 7 \ + & \ +wait) + diff --git a/optimizer/optimizer.py b/optimizer/optimizer.py index 8d9963a..103d068 100644 --- a/optimizer/optimizer.py +++ b/optimizer/optimizer.py @@ -62,6 +62,9 @@ def _multi_tensor_copy_this_to_that(this, that): class Fp16Optimizer: # If offload is set to true, the fp32 copy is stored on CPU. def __init__(self, optimizer, grad_scaler, device, offload=False): + + print('WARN: THIS IMPL IS DEPRECATED! 
AND WILL BE REMOVED SOON!') + self.offload = offload if self.offload: self.cpu_to_gpu_stream = torch.cuda.Stream(device=device, priority=-1) diff --git a/pipeline_parallel/dist_gpipe_pipeline_async.py b/pipeline_parallel/dist_gpipe_pipeline_async.py index 2b1264c..640abf2 100644 --- a/pipeline_parallel/dist_gpipe_pipeline_async.py +++ b/pipeline_parallel/dist_gpipe_pipeline_async.py @@ -11,6 +11,8 @@ import wandb from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup +from torch.cuda.amp import GradScaler, autocast + flag_profile = int(os.environ.get('FLAG_BENCHMARK', '0')) def get_parameter_names(model, forbidden_layer_types): @@ -84,6 +86,7 @@ def __init__(self, args, config, device, use_dp=False, if args.fp16: self.use_fp16 = True self.use_dynamic_scale = (args.loss_scale == 0) + self.scaler = GradScaler() print("=======Gpipe use FP16") else: self.use_fp16 = False @@ -102,6 +105,7 @@ def __init__(self, args, config, device, use_dp=False, self.post_node_rank = self.pp_rank + \ 1 if self.pp_rank != self.pipeline_group_size - 1 else -1 self.comm = get_pipeline_parallel_comm() + self.dp_comm = get_data_parallel_comm() self.gradient_accumulate_step = args.gradient_accumulate_step print("=======Gradient accumulate step: ", self.gradient_accumulate_step) @@ -207,20 +211,10 @@ def __init__(self, args, config, device, use_dp=False, else: self.model = _StageFull(args, config, device) - if self.use_fp16: - self.model.half() - if do_train: - if self.use_fp16: - tmp_optimizer = create_optimizer( - self.model, optimizer_type=getattr(args, 'optimizer', 'adamw'), learning_rate=args.lr) - self.optimizer = get_fp16_optimizer( - args, tmp_optimizer, device) - optim = tmp_optimizer - else: - self.optimizer = create_optimizer( - self.model, optimizer_type=getattr(args, 'optimizer', 'adamw'), learning_rate=args.lr) - optim = self.optimizer + self.optimizer = create_optimizer( + self.model, optimizer_type=getattr(args, 'optimizer', 'adamw'), learning_rate=args.lr) + optim = self.optimizer if args.total_scheduler_steps is not None: total_sched_steps = args.total_scheduler_steps else: @@ -327,7 +321,7 @@ def forward_stage(self, input_data=None, aux_input_data=None): for i in range(self.micro_batch_num): if self.pipeline_group_size > 1: if self.pp_rank == 0: # Only send output to next node, do not receive - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.profile_mark_forward_comp_start(i) current_micro_output = self.model( self.input_micro_batches[i], @@ -353,7 +347,7 @@ def forward_stage(self, input_data=None, aux_input_data=None): self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event( self.forward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.torch_comp_stream.wait_event( self.forward_recv_ready_events[i]) self.profile_mark_forward_comp_start(i) @@ -372,7 +366,7 @@ def forward_stage(self, input_data=None, aux_input_data=None): self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event( self.forward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.torch_comp_stream.wait_event( self.forward_recv_ready_events[i]) self.profile_mark_forward_comp_start(i) @@ -392,7 +386,7 @@ def forward_stage(self, input_data=None, 
aux_input_data=None): dst=self.post_node_rank, stream=cupy_send_stream) self.profile_mark_forward_send_end(i) else: - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.profile_mark_forward_comp_start(i) current_micro_output = self.model( self.input_micro_batches[i], @@ -450,14 +444,14 @@ def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target for i in range(self.micro_batch_num): if self.pipeline_group_size > 1: if self.pp_rank == self.pipeline_group_size - 1: # only send grad back to last node, do not receive - with torch.cuda.stream(self.torch_comp_stream) as st: + with torch.cuda.stream(self.torch_comp_stream) as st, autocast(): self.profile_mark_backward_comp_start(i) loss = loss_func( input=cached_output_micro_batches[i], target=target_as_micro_batches[i]) if not flag_profile: tr_loss.append(loss.item()) - if self.use_fp16: - self.optimizer.scale(loss).backward() + if self.use_fp16 and self.use_dynamic_scale: + self.scaler.scale(loss).backward() else: loss.backward() self.torch_comp_stream.record_event( @@ -480,7 +474,7 @@ def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target self.output_micro_batches_grad[i], src=self.post_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event( self.backward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.torch_comp_stream.wait_event( self.backward_recv_ready_events[i]) self.profile_mark_backward_comp_start(i) @@ -497,7 +491,7 @@ def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target self.output_micro_batches_grad[i], src=self.post_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event( self.backward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.torch_comp_stream.wait_event( self.backward_recv_ready_events[i]) self.profile_mark_backward_comp_start(i) @@ -516,18 +510,21 @@ def backward_stage(self, cached_output_micro_batches: List[torch.Tensor], target self.profile_mark_backward_send_end(i) else: - with torch.cuda.stream(self.torch_comp_stream) as st: + with torch.cuda.stream(self.torch_comp_stream) as st, autocast(): self.profile_mark_backward_comp_start(i) loss = loss_func( input=cached_output_micro_batches[i], target=target_as_micro_batches[i]) if not flag_profile: tr_loss.append(loss.item()) - if self.use_fp16: - self.optimizer.scale(loss).backward() + if self.use_fp16 and self.use_dynamic_scale: + self.scaler.scale(loss).backward() else: loss.backward() self.torch_comp_stream.record_event( self.backward_comp_ready_events[i]) + + if self.pp_rank == self.pipeline_group_size - 1: + print('loss: ', sum(tr_loss) / len(tr_loss)) if not flag_profile: if self.pp_rank == self.pipeline_group_size - 1: @@ -579,23 +576,41 @@ def save_on_disk(self, path): torch.save(self.model.state_dict(), os.path.join(path, 'pytorch_model.bin')) def optimizer_step(self): - # hard code: grad clipping - if not self.use_fp16: - torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) - if self.use_dp: - with torch.cuda.stream(self.torch_comp_stream): - self.torch_comp_stream.record_event( - self.dp_optim.backward_ready_event) - self.dp_optim.optimizer_step() - self.scheduler.step() - else: - with torch.cuda.stream(self.torch_comp_stream): - if self.enable_tidy_profiling: - self.optimizer_start_event.record() - 
self.optimizer.step() + + has_nan = False + if self.use_fp16 and self.use_dynamic_scale: + self.scaler.unscale_(self.optimizer) + optimizer_state = self.scaler._per_optimizer_states[id(self.optimizer)] + has_nan = sum(v.item() for v in optimizer_state["found_inf_per_device"].values()) + has_nan = torch.tensor(has_nan, device=self.device, dtype=torch.float32) + self.dp_comm.all_reduce(has_nan) + has_nan = has_nan.item() + + if not has_nan: + + if self.use_dp: + with torch.cuda.stream(self.torch_comp_stream): + self.torch_comp_stream.record_event( + self.dp_optim.backward_ready_event) + self.dp_optim.optimizer_step() self.scheduler.step() - if self.enable_tidy_profiling: - self.optimizer_end_event.record() + else: + with torch.cuda.stream(self.torch_comp_stream): + + torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0) + + if self.enable_tidy_profiling: + self.optimizer_start_event.record() + self.optimizer.step() + self.scheduler.step() + if self.enable_tidy_profiling: + self.optimizer_end_event.record() + else: + print('Found nan/inf, skip.') + + if self.use_fp16 and self.use_dynamic_scale: + self.scaler.update() + if self.enable_tidy_profiling: self.profiling_optimizer_step() @@ -620,10 +635,12 @@ def sgd_iter(self, input_=None, target=None, aux_input_data=None, loss_func=torch.nn.functional.cross_entropy): - if self.use_fp16 and self.use_dynamic_scale: - scales_buffer = [torch.ones_like(self.optimizer.grad_scaler._scale) for _ in range(self.pipeline_group_size)] - self.comm.all_gather(self.optimizer.grad_scaler._scale, scales_buffer) - self.optimizer.grad_scaler._scale.data[:] = min([s.item() for s in scales_buffer]) + if self.use_fp16 and self.use_dynamic_scale and self.scaler._scale is not None: + scales_buffer = [torch.ones_like(self.scaler._scale) for _ in range(self.pipeline_group_size)] + self.comm.all_gather(self.scaler._scale, scales_buffer) + self.scaler._scale.data[:] = min([s.item() for s in scales_buffer]) + + #print(self.scaler._scale) self.comm.barrier() @@ -708,7 +725,7 @@ def infer_stage(self, input_data=None, aux_input_data=None, for i in range(self.micro_batch_num): if self.pipeline_group_size > 1: if self.pp_rank == 0: # Only send output to next node, do not receive - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): current_micro_output = self.model( self.input_micro_batches[i], **{k: v[i] for k, v in aux_input_data.items()}, @@ -723,7 +740,7 @@ def infer_stage(self, input_data=None, aux_input_data=None, cupy_recv_stream = cupy.cuda.ExternalStream(self.torch_recv_stream.cuda_stream) self.comm.recv(self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event(self.forward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): self.torch_comp_stream.wait_event(self.forward_recv_ready_events[i]) current_micro_output = self.model( self.input_micro_batches[i], input_ids=input_ids_micro_batches[i], @@ -736,7 +753,7 @@ def infer_stage(self, input_data=None, aux_input_data=None, cupy_recv_stream = cupy.cuda.ExternalStream(self.torch_recv_stream.cuda_stream) self.comm.recv(self.input_micro_batches[i], src=self.pre_node_rank, stream=cupy_recv_stream) self.torch_recv_stream.record_event(self.forward_recv_ready_events[i]) - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): 
self.torch_comp_stream.wait_event(self.forward_recv_ready_events[i]) current_micro_output = self.model( self.input_micro_batches[i], @@ -748,7 +765,7 @@ def infer_stage(self, input_data=None, aux_input_data=None, self.torch_send_stream.wait_event(self.forward_comp_ready_events[i]) self.comm.send(current_micro_output.data, dst=self.post_node_rank, stream=cupy_send_stream) else: - with torch.cuda.stream(self.torch_comp_stream): + with torch.cuda.stream(self.torch_comp_stream), autocast(): current_micro_output = self.model( self.input_micro_batches[i], **{k: v[i] for k, v in aux_input_data.items()}
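
Note on the repeated `tensor[:] = ...` to `tensor.set_(...)` change in comm/torch_backend.py: slice assignment copies the received values into the tensor's existing storage, while `set_` re-points the tensor at the buffer's storage without an element copy, so any views or aliases created earlier keep seeing the old data; the `.to(tensor.device)` beforehand keeps the result on the expected device since `set_` shares the source's storage. The snippet below is only a standalone illustration of that semantic difference in plain PyTorch, not a claim about which behavior the backend requires.

```python
import torch

recv_buf = torch.arange(4, dtype=torch.float32)

# Slice assignment: copy into the tensor's existing storage.
t = torch.zeros(4)
view = t[:2]                      # a view sharing t's storage
t[:] = recv_buf
print(view)                       # tensor([0., 1.]) -- the view observes the update

# set_: re-point the tensor at recv_buf's storage (no element copy).
t = torch.zeros(4)
view = t[:2]
t.set_(recv_buf)
print(t)                          # tensor([0., 1., 2., 3.])
print(view)                       # tensor([0., 0.]) -- still the old storage
print(t.data_ptr() == recv_buf.data_ptr())   # True: storage is now shared
```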
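The optimizer_step changes in pipeline_parallel/dist_gpipe_pipeline_async.py move fp16 handling from the custom Fp16Optimizer onto torch.cuda.amp: unscale the gradients, read GradScaler's found-inf flags, all-reduce them over the data-parallel group so every rank agrees on whether to skip, then clip, step, and update the scale. Below is a minimal self-contained sketch of that pattern; `amp_step` and `dp_group` are placeholder names, and the `_per_optimizer_states` lookup is a private GradScaler detail that simply mirrors what the diff does.

```python
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast

def amp_step(model, optimizer, loss_fn, batch, device, scaler: GradScaler, dp_group=None):
    inputs, targets = batch
    with autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()

    # Unscale so gradient clipping operates on true gradient values.
    scaler.unscale_(optimizer)

    # GradScaler records per-device inf/nan flags during unscale_; reduce them
    # across data-parallel ranks so every rank makes the same skip decision.
    state = scaler._per_optimizer_states[id(optimizer)]          # private API, mirrors the diff
    found_inf = sum(v.item() for v in state["found_inf_per_device"].values())
    found_inf = torch.tensor(float(found_inf), device=device)
    if dp_group is not None:
        dist.all_reduce(found_inf, group=dp_group)

    if found_inf.item() == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
    else:
        print("Found nan/inf in gradients, skipping step.")
    optimizer.zero_grad(set_to_none=True)

    # update() shrinks the scale after an overflow and slowly grows it otherwise.
    scaler.update()
    return loss.detach()
```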
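sgd_iter now synchronizes the dynamic loss scale across pipeline stages by all-gathering each stage's `scaler._scale` and taking the minimum, so no stage steps with a larger scale than a peer that recently overflowed. A standalone sketch of that step, assuming torch.distributed is initialized; `sync_loss_scale` and `pp_group` are placeholder names, and `_scale` is a private GradScaler attribute, exactly as used in the diff:

```python
import torch
import torch.distributed as dist
from torch.cuda.amp import GradScaler

def sync_loss_scale(scaler: GradScaler, pp_group, world_size: int) -> None:
    # The scale tensor is created lazily on the first scale() call.
    if scaler._scale is None:
        return
    gathered = [torch.ones_like(scaler._scale) for _ in range(world_size)]
    dist.all_gather(gathered, scaler._scale, group=pp_group)
    # Adopt the smallest scale so every stage steps with the same, safest value.
    scaler._scale.data[:] = min(s.item() for s in gathered)
```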