Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions DCU-adaptation.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Notes (说明)

```
1. Modified the fused_kernels folder to support the Hygon DCU accelerator (修改fused_kernel文件夹适配海光DCU)
2. Modified arguments.py to support launching via mpirun (修改arguments.py文件适配mpirun启动方式)
3. `sbatch run-gpt3.sh` can launch a 161B-parameter GPT model across 88 nodes on the Taiyuan platform (sbatch run-gpt3.sh可以在太原平台启动88节点运行161B的GPT模型)
```

Binary file not shown.
12 changes: 8 additions & 4 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def parse_args(extra_args_provider=None, defaults={},
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
allow_abbrev=False)

# Standard arguments.
parser = _add_network_size_args(parser)
parser = _add_regularization_args(parser)
Expand Down Expand Up @@ -67,8 +66,8 @@ def parse_args(extra_args_provider=None, defaults={},
args = parser.parse_args()

# Distributed args.
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))
#args.rank = int(os.getenv('RANK', '0'))
#args.world_size = int(os.getenv("WORLD_SIZE", '1'))
# Tensor model parallel size.
args.tensor_model_parallel_size = min(
args.tensor_model_parallel_size, args.world_size)
Expand Down Expand Up @@ -315,7 +314,6 @@ def parse_args(extra_args_provider=None, defaults={},
import bitsandbytes as bnb
except ModuleNotFoundError:
raise ModuleNotFoundError("Please install bitsandbytes from https://github.com/facebookresearch/bitsandbytes.")

_print_args(args)
return args

Expand Down Expand Up @@ -746,6 +744,12 @@ def _add_distributed_args(parser):
group.add_argument('--use-cpu-initialization', action='store_true',
default=None, help='If set, affine parallel weights '
'initialization uses CPU' )

group.add_argument('--rank', default=-1, type=int, help='node rank for distributed training')
group.add_argument('--dist_url', type=str, default="env://127.0.0.1:23456")
group.add_argument('--world_size', type=int, default=-1, help='number of nodes for distributed training')
group.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend')

return parser


Expand Down
13 changes: 5 additions & 8 deletions megatron/fused_kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
sources=sources,
build_directory=buildpath,
extra_cflags=['-O3',],
extra_cuda_cflags=['-O3',
'--use_fast_math'] + extra_cuda_flags,
extra_cuda_cflags=['-O3'] + extra_cuda_flags,
verbose=(args.rank == 0)
)
# '-gencode', 'arch=compute_70,code=sm_70',
Expand All @@ -66,9 +65,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):

if args.masked_softmax_fusion:
extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
'-U__CUDA_NO_HALF_CONVERSIONS__',
'--expt-relaxed-constexpr',
'--expt-extended-lambda']
'-U__CUDA_NO_HALF_CONVERSIONS__']

# Upper triangular softmax.
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
Expand All @@ -87,18 +84,18 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags):
# Mixed precision fused layer norm.
# =================================

extra_cuda_flags = ['-maxrregcount=50']
extra_cuda_flags = []
sources=[srcpath / 'layer_norm_cuda.cpp',
srcpath / 'layer_norm_cuda_kernel.cu']
fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
"fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)


def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "--version"],
universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release_idx = output.index("version") + 1
release = output[release_idx].split(".")
bare_metal_major = release[0]
bare_metal_minor = release[1][0]
Expand Down
6 changes: 3 additions & 3 deletions megatron/fused_kernels/layer_norm_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -247,13 +247,13 @@ void cuWelfordMuSigma2(
}
}

template<typename U> U rsqrt(U v) {
template<typename U> __device__ U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> float rsqrt(float v) {
template<> __device__ float rsqrt(float v) {
return rsqrtf(v);
}
template<> double rsqrt(double v) {
template<> __device__ double rsqrt(double v) {
return rsqrt(v);
}

Expand Down
Loading