Skip to content
Open

Misc #24

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions examples/model_parallel/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
fix_rand()


def test_model(dim, depth, nh=2, dtype=torch.float):
def test_model(dim, depth, nh=2, dtype=torch.float, seqlen=1024, b=1):
model = Transformer(dim, depth=depth, num_heads=nh, tensor_parallel=False, sequence_parallel=False).cuda().to(dtype)
# tp_model = Transformer(dim, depth=depth, num_heads=nh, tensor_parallel=True, sequence_parallel=False).cuda().to(dtype)
tp_model = Transformer(dim, depth=depth, num_heads=nh, tensor_parallel=True, sequence_parallel=True).cuda().to(dtype)
tp_model = Transformer(dim, depth=depth, num_heads=nh, tensor_parallel=True, sequence_parallel=False).cuda().to(dtype)
# tp_model = Transformer(dim, depth=depth, num_heads=nh, tensor_parallel=True, sequence_parallel=True).cuda().to(dtype)


opt = torch.optim.AdamW(model.parameters())
Expand All @@ -24,13 +24,14 @@ def test_model(dim, depth, nh=2, dtype=torch.float):
# sp_model.blocks[ind].init_from_full(model.blocks[ind])

for _ in range(10):
inp = torch.rand((32, 1024, dim)).cuda().to(dtype)
inp = torch.rand((b, seqlen, dim)).cuda().to(dtype)

opt.zero_grad()

out = model(inp)
tp_out = tp_model(inp)
assert torch.allclose(out, tp_out, rtol=1e-1, atol=1e-02)
# import pdb;pdb.set_trace()
assert torch.allclose(out, tp_out, rtol=1e-3, atol=1e-02)

import pdb;pdb.set_trace()
# TODO: fix this misalignment
Expand All @@ -43,4 +44,5 @@ def test_model(dim, depth, nh=2, dtype=torch.float):
opt.step()
tp_opt.step()

test_model(1024, 8)
# test_model(1024, 8)
test_model(4, 2, seqlen=5)
1,611 changes: 1,611 additions & 0 deletions explore/memory-profile/attention_processor.py

Large diffs are not rendered by default.

731 changes: 731 additions & 0 deletions explore/memory-profile/transblock.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions torchdistpackage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from .dist.sharded_ema import ShardedEMA


from .utils import fix_rand
from .utils import fix_rand, sliced_run

from .tools.module_profiler import report_prof, register_profile_hooks, get_model_profile

from .tools.module_replace import replace_all_module
from .tools.bnb_fc import replace_linear_by_bnb
from .tools.bminf_int8 import replace_linear_by_bminf
from .tools.bminf_int8 import replace_linear_by_bminf
from .tools.debug_nan import register_debug_nan_hooks,check_model_params
2 changes: 2 additions & 0 deletions torchdistpackage/dist/launch_from_slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def setup_distributed_slurm(backend="nccl", port=None):
import torch
import torch.distributed as dist

if dist.is_initialized():
return None
num_gpus = torch.cuda.device_count()

if "SLURM_JOB_ID" in os.environ:
Expand Down
10 changes: 7 additions & 3 deletions torchdistpackage/parallel/tensor_parallel/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,17 +36,19 @@ def _naive_attn(self, x):
q= _split_heads(q, self.num_heads, self.head_dim)
k= _split_heads(k, self.num_heads, self.head_dim)
v= _split_heads(v, self.num_heads, self.head_dim)
# (b, nh, s, hdim)

attn = ((q * self.scale) @ k.transpose(-2, -1))
# (b, nh, s, s)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, DIM)
x = self.proj(x)
x = self.proj_drop(x)
return x

def forward(self, x):
x = self._naive_attn(x)
x = self.proj_drop(x)
return x


Expand Down Expand Up @@ -80,14 +82,13 @@ def _naive_attn(self, x):
q= _split_heads(q, self.head_num_per_partition, self.head_dim)
k= _split_heads(k, self.head_num_per_partition, self.head_dim)
v= _split_heads(v, self.head_num_per_partition, self.head_dim)

# (b, nh_per, s, hdim)

attn = ((q * self.scale) @ k.transpose(-2, -1))
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, DIM//self.tp_size)
x = self.proj(x)
x = self.proj_drop(x)
return x

def forward(self, x):
Expand All @@ -96,4 +97,7 @@ def forward(self, x):
x = gather_from_sequence_parallel_region(x)

x = self._naive_attn(x)

x = self.proj_drop(x)

return x
10 changes: 6 additions & 4 deletions torchdistpackage/parallel/tensor_parallel/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(self, dim, mlp_ratio=4, num_heads=8, sequence_parallel=False):
self.sequence_parallel = sequence_parallel

def forward(self, hidden_states):
if self.sequence_parallel:
hidden_states = maybe_split_into_sequence_parallel(hidden_states)
# if self.sequence_parallel:
# hidden_states = maybe_split_into_sequence_parallel(hidden_states)

residual = hidden_states
hidden_states = self.ln_1(hidden_states)
Expand All @@ -67,8 +67,8 @@ def forward(self, hidden_states):
# residual connection
hidden_states = residual + feed_forward_hidden_states

if self.sequence_parallel:
set_sequence_parallel_attr(hidden_states)
# if self.sequence_parallel:
# set_sequence_parallel_attr(hidden_states)
return hidden_states

@torch.no_grad()
Expand All @@ -93,6 +93,8 @@ def __init__(self, dim, mlp_ratio=4, num_heads=8, depth=12, tensor_parallel=True
self.sequence_parallel = sequence_parallel

def forward(self, x):
if self.sequence_parallel:
x = maybe_split_into_sequence_parallel(x)
for blk in self.blocks:
x = blk(x)
if self.sequence_parallel:
Expand Down
4 changes: 2 additions & 2 deletions torchdistpackage/tools/bminf_int8.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import torch

import bminf

from .module_replace import replace_all_module

def if_replace_linear(name, module):
    """Replacement predicate: select every ``nn.Linear`` for substitution."""
    return isinstance(module, torch.nn.Linear)

def get_new_module(name, module):
    """Build the drop-in replacement for *module*: a bminf 8-bit quantized linear."""
    # Imported lazily so bminf is only required when quantization is used.
    import bminf

    return bminf.QuantizedLinear(module)

Expand Down
3 changes: 2 additions & 1 deletion torchdistpackage/tools/bnb_fc.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import torch
import functools
from bitsandbytes.nn import Linear8bitLt

from .module_replace import replace_all_module

def if_replace_linear(name, module):
    """Replacement predicate: select every ``nn.Linear`` for substitution."""
    return isinstance(module, torch.nn.Linear)

def get_new_module(name, module, bnb_kwargs={}):
from bitsandbytes.nn import Linear8bitLt

old_shape = module.weight.shape
indim = old_shape[1]
outdim = old_shape[0]
Expand Down
79 changes: 70 additions & 9 deletions torchdistpackage/tools/debug_nan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,38 @@ def check_tensors(inputs):


def check_model_params(model):
    """Scan every parameter of *model* for NaN/Inf values.

    Every parameter is inspected (no early exit) so each offending tensor
    gets its own report line.  Returns True when all parameters are finite,
    False otherwise.
    """
    all_finite = True
    for param_name, param in model.named_parameters():
        # check_tensor_inf_nan presumably returns False when the tensor
        # holds NaN/Inf — defined earlier in this module.
        if check_tensor_inf_nan(param):
            continue
        print("model param:", param_name, " contains nan/inf!")
        all_finite = False
    return all_finite

def print_tensor_info(module_name, output, msg):
    """Recursively print ``<module_name> <msg> <tensor.max()>`` for every
    tensor found in *output*, which may be a tensor or a nested tuple/list."""
    if isinstance(output, torch.Tensor):
        print(module_name, msg, output.max())
        return
    if isinstance(output, (tuple, list)):
        for item in output:
            print_tensor_info(module_name, item, msg)

@torch.no_grad()  # parenthesized form works on all torch versions; bare @torch.no_grad needs torch>=2.0
def get_max(output):
    """Return the largest element found in *output* as a Python number.

    *output* may be a tensor or an arbitrarily nested tuple/list of tensors.
    The running maximum over a sequence starts at 0, so an all-negative
    sequence reports 0 — callers (e.g. ``find_big_change``) treat 0 as
    "nothing seen".  A lone tensor still returns its true (possibly
    negative) max.

    Fixes vs. the original: no longer shadows the ``max`` builtin, computes
    each recursive max once instead of twice, and returns 0 (not an implicit
    None) for leaves that are neither tensors nor sequences so callers can
    safely compare/divide the result.
    """
    if isinstance(output, torch.Tensor):
        return output.max().item()
    if isinstance(output, (tuple, list)):
        biggest = 0
        for item in output:
            item_max = get_max(item)  # recurse exactly once per element
            if item_max > biggest:
                biggest = item_max
        return biggest
    # Unknown leaf (None, dict, scalar, ...): contribute nothing.
    return 0

def find_big_change(module_name, args, output):
    """Report modules whose forward pass sharply amplifies activations.

    Prints the module name (plus output/input peaks) when the output max is
    at least twice the input max and also exceeds 1000.  Inputs whose max
    is 0 are skipped, which doubles as the divide-by-zero guard.
    """
    output_peak = get_max(output)
    input_peak = get_max(args)
    if input_peak == 0:
        return
    if output_peak / input_peak >= 2 and output_peak > 1000:
        print(module_name, output_peak, input_peak)

# register_forward_hook
@torch.no_grad()
Expand All @@ -38,24 +65,58 @@ def check_value(module, args, output):
import pdb;pdb.set_trace()
print(module_name, "output contains nan/inf!")

# if 'conv_shortcut' in module_name:
# print_tensor_info(module_name, args, "input_max")
# print_tensor_info(module_name, output, "output_max")
find_big_change(module_name, args, output)

return check_value


# model.register_full_backward_hook
def bwd_hook_wrapper(module_name=''):
    """Build a full-backward hook that drops into pdb when grad_output
    contains NaN/Inf and prints (without pdb) when grad_input does.

    NOTE(review): a second ``bwd_hook_wrapper`` defined later in this module
    shadows this one at import time — confirm which version is intended.
    """
    def check_value(module, grad_input, grad_output):
        # check_tensors presumably returns False when any tensor in the
        # collection holds NaN/Inf — defined earlier in this module.
        if not check_tensors(grad_output):
            print(module_name, "grad_output contains nan/inf!")
            import pdb;pdb.set_trace()
        if not check_tensors(grad_input):
            print(module_name, "grad_input contains nan/inf!")
            # import pdb;pdb.set_trace()
    return check_value

# detect param update nan

# usage example:
# fwd_hooks = {}
# bwd_hooks = {}
# for name, module in model.named_modules():
# fwd_hooks[name] = module.register_forward_hook(fwd_hook_wrapper(name))
# bwd_hooks[name] = module.register_forward_hook(bwd_hook_wrapper(name))
def register_debug_nan_hooks(model):
    """Attach a NaN/Inf-checking forward hook to every submodule of *model*.

    Returns ``(fwd_hooks, bwd_hooks)`` dicts keyed by module name; backward
    hooks are currently disabled, so ``bwd_hooks`` comes back empty.  Keep
    the handles to ``remove()`` the hooks later.
    """
    fwd_hooks = {
        mod_name: submodule.register_forward_hook(fwd_hook_wrapper(mod_name))
        for mod_name, submodule in model.named_modules()
    }
    # Backward hooks intentionally not registered here; use
    # register_debug_hooks for grad-norm reporting instead.
    return fwd_hooks, {}

# debug grad norm

def get_grad_norm(grads, norm_type=2.0):
    """Total *norm_type*-norm over a sequence of gradients, skipping Nones.

    Returns the plain scalar 0 when there is nothing to measure, otherwise
    a 0-dim tensor with the combined norm (same reduction shape as
    ``torch.nn.utils.clip_grad_norm_``).
    """
    per_grad_norms = [
        torch.norm(grad.detach(), norm_type)
        for grad in grads
        if grad is not None
    ]
    if not per_grad_norms:
        return 0
    return torch.norm(torch.stack(per_grad_norms), norm_type)

def bwd_hook_wrapper(module_name='', max_grad_norm=0.1):
    """Build a full-backward hook that prints (rank 0 only) whenever the
    gradient norm flowing through a module exceeds *max_grad_norm*.

    NOTE(review): this shadows the earlier ``bwd_hook_wrapper`` in this
    module.  It also assumes ``dist`` (presumably torch.distributed) is
    imported and the process group is initialized — ``dist.get_rank()``
    raises otherwise; confirm before use outside a distributed run.
    """
    def check_value(module, grad_input, grad_output):
        # get_grad_norm returns 0 (int) or a 0-dim tensor; both compare
        # against the float threshold.
        out_grad_norm = get_grad_norm(grad_output)
        in_grad_norm = get_grad_norm(grad_input)
        if out_grad_norm > max_grad_norm:
            if dist.get_rank()==0:
                print(module_name, "grad_output norm:", out_grad_norm, flush=True)
        if in_grad_norm > max_grad_norm:
            if dist.get_rank()==0:
                print(module_name, "grad_input norm:", in_grad_norm, flush=True)
    return check_value

def register_debug_hooks(model):
    """Attach a grad-norm-reporting full backward hook to every submodule.

    Returns a dict mapping module name -> hook handle so callers can
    ``remove()`` the hooks when done.
    """
    return {
        mod_name: submodule.register_full_backward_hook(bwd_hook_wrapper(mod_name))
        for mod_name, submodule in model.named_modules()
    }
23 changes: 22 additions & 1 deletion torchdistpackage/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,25 @@ def partition_params(model, num_partitions, return_dict=False):
if elcnt > numel_per_partition:
partition_id+=1
elcnt=0
return partitions
return partitions

def sliced_run(fn, input: torch.Tensor, micro_bs):
    """Run *fn* over *input* one micro-batch at a time, under ``no_grad``.

    Args:
        fn (function): forward callable applied to each slice.
        input (torch.Tensor): tensor whose dim 0 is split into chunks of
            *micro_bs* rows (last chunk may be smaller).
        micro_bs (integer): micro batch size.

    Returns:
        tensor: the per-chunk outputs concatenated back along dim 0.
    """
    with torch.no_grad():
        chunk_outs = [fn(piece) for piece in input.split(micro_bs, 0)]
    return torch.cat(chunk_outs, dim=0)