Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions comm/torch_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def recv(self,
else:
buffer = tensor.cpu()
dist.recv(buffer, self.to_global_rank(src), group=self.process_group)
tensor[:] = buffer.to(tensor.device)
tensor.set_(buffer.to(tensor.device))

def isend(self,
tensor: torch.Tensor,
Expand All @@ -63,7 +63,7 @@ def irecv(self,
assert False
buffer = tensor.cpu()
handler = dist.irecv(buffer, self.to_global_rank(src), group=self.process_group)
tensor[:] = buffer.to(tensor.device)
tensor.set_(buffer.to(tensor.device))
return handler

def broadcast(self,
Expand All @@ -75,7 +75,7 @@ def broadcast(self,
else:
buffer = tensor.cpu()
dist.broadcast(buffer, self.to_global_rank(src), group=self.process_group)
tensor[:] = buffer.to(tensor.device)
tensor.set_(buffer.to(tensor.device))

def reduce(self,
tensor: torch.Tensor,
Expand All @@ -90,7 +90,7 @@ def all_reduce(self,
op=dist.ReduceOp.SUM):
buffer = tensor.cpu()
dist.all_reduce(buffer, group=self.process_group, op=op)
tensor[:] = buffer.to(tensor.device)
tensor.set_(buffer.to(tensor.device))

def gather(self,
tensor: torch.Tensor,
Expand Down
1 change: 1 addition & 0 deletions data_parallel/dist_dp_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def optimizer_step(self):
with torch.cuda.stream(self.torch_optim_comp_stream):
self.torch_optim_comp_stream.wait_event(self.allreduce_grad_ready_event)
self.profile_mark_optimizer_step_start()
torch.nn.utils.clip_grad_norm_(self.module.parameters(), 1.0)
self.optimizer.step()
self.torch_optim_comp_stream.record_event(self.optimizer_step_ready_event)

Expand Down
4 changes: 2 additions & 2 deletions data_parallel/dist_dp_cocktail_sgd.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, args, device, module: torch.nn.Module, optimizer: torch.optim

if self.flatten:
_params = []
for i_group, group in enumerate(self.optimizer.optimizer.param_groups):
for i_group, group in enumerate(self.optimizer.param_groups):
for i_para, para in enumerate(group["params"]):
_params.append(para)
self.flatten_para = flatten_tensors(_params)
Expand Down Expand Up @@ -346,7 +346,7 @@ def _partial_sync(self):
cupy_dp_stream = cupy.cuda.ExternalStream(self.dp_comm_stream.cuda_stream)
with torch.cuda.stream(self.dp_comm_stream), cupy_dp_stream:

for i_group, group in enumerate(self.optimizer.optimizer.param_groups):
for i_group, group in enumerate(self.optimizer.param_groups):
for i_para, para in enumerate(group["params"]):


Expand Down
532 changes: 0 additions & 532 deletions data_parallel/dist_dp_cocktail_sgd_grad.py

This file was deleted.

3 changes: 0 additions & 3 deletions data_parallel/dist_dp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from .dist_dp_sharded_ps import ShardedPSDP
from .dist_dp_local import LocalDP
from .dist_dp_cocktail_sgd import CocktailSGDDP
from .dist_dp_cocktail_sgd_grad import CocktailSGDGradDP


def get_dp_module(args, device, module, optimizer):
Expand All @@ -16,8 +15,6 @@ def get_dp_module(args, device, module, optimizer):
return ShardedPSDP(args, device, module, optimizer, flatten=False)
elif args.dp_mode == 'cocktail_sgd':
return CocktailSGDDP(args, device, module, optimizer, flatten=True)
elif args.dp_mode == 'cocktail_sgd_grad':
return CocktailSGDGradDP(args, device, module, optimizer, flatten=True)
else:
print("Not recognize this data parallel mode.")
assert False
9 changes: 3 additions & 6 deletions dist_lm_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,15 +301,15 @@ def main():
'd_model': args.embedding_dim,
'd_inner': args.embedding_dim * 4,
'vocab_size': 50257,
'attn_cfg': dict(num_heads = 12), # HARD CODED FOR 125M
'attn_cfg': dict(num_heads = 12, fused_bias_fc=True, use_flash_attn=True), # HARD CODED FOR 125M
'attn_layer_idx': [1, 8], # HARD CODED FOR 125M
'ssm_cfg': dict(mode='diag', measure='diag-lin'),
'ssm_cfg': dict(mode='diag', measure='diag-lin', use_fast_fftconv=True),
'pad_vocab_size_multiple': 8,
'max_position_embeddings': 0,
'resid_dropout': 0.0,
'embed_dropout': 0.1,
'layer_norm_epsilon': 1e-5,
'fused_mlp': True,
'fused_mlp': False,
'fused_dropout_add_ln': True,
'residual_in_fp32': True
})
Expand Down Expand Up @@ -354,9 +354,6 @@ def main():
if args.load_checkpoint:
load_checkpoint(pipe, args)

if args.fp16:
pipe.optimizer.reload_model_params()

if args.profiling == 'no-profiling':
train_loop(args, pipe, device, train_data_loader, test_data_loader)
else:
Expand Down
4 changes: 2 additions & 2 deletions example_scripts/pretrain_h3_125m_5btok.sh
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
netif=lo
export GLOO_SOCKET_IFNAME=${netif}
export NCCL_SOCKET_IFNAME=${netif}
export WANDB_NAME=h3-125m-pretrain-pile-ar-5btok-linear
export WANDB_NAME=h3-125m-pretrain-pile-ar-5btok-linear-jue-amp
# export WANDB_NAME=test

export QUANT_BITS=4
Expand All @@ -26,7 +26,7 @@ ARGS="--model-name ./empty_model_configs/h3 \
--evaluation-steps 4000 \
--evaluation-num-batch 256 \
--evaluation-data pile \
--lr 6e-4 --seq-length 2048 --batch-size 16 --micro-batch-size 8 --gradient-accumulate-step 2 \
--lr 6e-4 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1 \
--dist-url tcp://127.0.0.1:7033 \
--world-size 8 --pipeline-group-size 1 --data-group-size 8 \
--job-id 0 --net-interface ${netif} \
Expand Down
55 changes: 55 additions & 0 deletions example_scripts/pretrain_h3_125m_cocktail_5btok.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Starts eight local dist_lm_train.py processes (ranks 0-7, one CUDA device id
# each) for the h3 model on the pile task with the cocktail_sgd data-parallel
# mode. Ctrl-C (SIGINT) tears down the whole process group.

# Network interface handed to Gloo, NCCL and --net-interface (loopback here).
netif=lo
export GLOO_SOCKET_IFNAME=${netif}
export NCCL_SOCKET_IFNAME=${netif}

# Run name; also reused below as the checkpoint sub-directory.
export WANDB_NAME=h3-125m-pretrain-pile-cocktail-5btok-linear
# export WANDB_NAME=test

# Exported for the training processes.
# NOTE(review): presumably the cocktail_sgd compression settings -- confirm
# against the code that reads these variables.
export QUANT_BITS=4
export TOPK_RATIO=0.5
export RANDOMP_RATIO=0.4

export SHOW_DATA=0

# Command-line flags shared by every rank, assembled group by group.
# the model name argument is IGNORED
ARGS="--model-name ./empty_model_configs/h3"
ARGS="${ARGS} --tokenizer-name gpt2"
ARGS="${ARGS} --load-pretrained-model false"
ARGS="${ARGS} --project-name cocktail-sgd"
ARGS="${ARGS} --model-type h3"
ARGS="${ARGS} --optimizer adam"
ARGS="${ARGS} --seed 42"
ARGS="${ARGS} --task-name pile"
ARGS="${ARGS} --checkpoint-path ./model_ckpts/$WANDB_NAME"
ARGS="${ARGS} --num-layers 12 --embedding-dim 768"
# Step counts: schedule, checkpointing and evaluation.
ARGS="${ARGS} --total-steps 20000 --warmup-steps 200 --train-warmup-steps 1000"
ARGS="${ARGS} --checkpoint-steps 500"
ARGS="${ARGS} --evaluation-steps 4000"
ARGS="${ARGS} --evaluation-num-batch 256"
ARGS="${ARGS} --evaluation-data pile"
# Learning rate, sequence length and batch sizes.
ARGS="${ARGS} --lr 6e-4 --seq-length 2048 --batch-size 32 --micro-batch-size 1 --gradient-accumulate-step 1"
# Distributed layout: 8 processes, pipeline group of 1, data group of 8.
ARGS="${ARGS} --dist-url tcp://127.0.0.1:7033"
ARGS="${ARGS} --world-size 8 --pipeline-group-size 1 --data-group-size 8"
ARGS="${ARGS} --job-id 0 --net-interface ${netif}"
ARGS="${ARGS} --dp-backend gloo"
ARGS="${ARGS} --dp-mode cocktail_sgd"
ARGS="${ARGS} --pp-mode gpipe --profiling no-profiling"

# Everything runs in one subshell so that the SIGINT trap can signal the whole
# process group with `kill 0`.
(
    trap 'kill 0' SIGINT

    # One background worker per rank; the CUDA device id equals the rank.
    for rank in 0 1 2 3 4 5 6 7; do
        # $(echo ...) is left unquoted on purpose: the shell splits its output
        # back into the individual flags.
        python dist_lm_train.py $(echo ${ARGS}) --cuda-id ${rank} --rank ${rank} &
    done

    # Block until every worker has exited.
    wait
)

3 changes: 3 additions & 0 deletions optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def _multi_tensor_copy_this_to_that(this, that):
class Fp16Optimizer:
# If offload is set to true, the fp32 copy is stored on CPU.
def __init__(self, optimizer, grad_scaler, device, offload=False):

print('WARN: THIS IMPL IS DEPRECATED! AND WILL BE REMOVED SOON!')

self.offload = offload
if self.offload:
self.cpu_to_gpu_stream = torch.cuda.Stream(device=device, priority=-1)
Expand Down
Loading