diff --git a/DCU-adaptation.md b/DCU-adaptation.md new file mode 100644 index 000000000..3652b55f9 --- /dev/null +++ b/DCU-adaptation.md @@ -0,0 +1,8 @@ +# 说明 + +``` +1. 修改fused_kernel文件夹适配海光DCU +2. 修改arguments.py文件适配mpirun启动方式 +3. sbatch run-gpt3.sh可以在太原平台启动88节点运行161B的GPT模型 +``` + diff --git a/lightop-whl/lightop-0.1-cp37-cp37m-linux_x86_64.whl b/lightop-whl/lightop-0.1-cp37-cp37m-linux_x86_64.whl new file mode 100644 index 000000000..5960e9471 Binary files /dev/null and b/lightop-whl/lightop-0.1-cp37-cp37m-linux_x86_64.whl differ diff --git a/megatron/arguments.py b/megatron/arguments.py index c18235a78..60ec52432 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -34,7 +34,6 @@ def parse_args(extra_args_provider=None, defaults={}, """Parse all arguments.""" parser = argparse.ArgumentParser(description='Megatron-LM Arguments', allow_abbrev=False) - # Standard arguments. parser = _add_network_size_args(parser) parser = _add_regularization_args(parser) @@ -67,8 +66,8 @@ def parse_args(extra_args_provider=None, defaults={}, args = parser.parse_args() # Distributed args. - args.rank = int(os.getenv('RANK', '0')) - args.world_size = int(os.getenv("WORLD_SIZE", '1')) + #args.rank = int(os.getenv('RANK', '0')) + #args.world_size = int(os.getenv("WORLD_SIZE", '1')) # Tensor model parallel size. 
args.tensor_model_parallel_size = min( args.tensor_model_parallel_size, args.world_size) @@ -315,7 +314,6 @@ def parse_args(extra_args_provider=None, defaults={}, import bitsandbytes as bnb except ModuleNotFoundError: raise ModuleNotFoundError("Please install bitsandbytes from https://github.com/facebookresearch/bitsandbytes.") - _print_args(args) return args @@ -746,6 +744,12 @@ def _add_distributed_args(parser): group.add_argument('--use-cpu-initialization', action='store_true', default=None, help='If set, affine parallel weights ' 'initialization uses CPU' ) + + group.add_argument('--rank', default=-1, type=int, help='node rank for distributed training') + group.add_argument('--dist_url', type=str, default="env://127.0.0.1:23456") + group.add_argument('--world_size', type=int, default=-1, help='number of nodes for distributed training') + group.add_argument('--dist_backend', default='nccl', type=str, help='distributed backend') + return parser diff --git a/megatron/fused_kernels/__init__.py b/megatron/fused_kernels/__init__.py index e2ac2567b..7e24d984d 100644 --- a/megatron/fused_kernels/__init__.py +++ b/megatron/fused_kernels/__init__.py @@ -54,8 +54,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags): sources=sources, build_directory=buildpath, extra_cflags=['-O3',], - extra_cuda_cflags=['-O3', - '--use_fast_math'] + extra_cuda_flags, + extra_cuda_cflags=['-O3'] + extra_cuda_flags, verbose=(args.rank == 0) ) # '-gencode', 'arch=compute_70,code=sm_70', @@ -66,9 +65,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags): if args.masked_softmax_fusion: extra_cuda_flags = ['-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '--expt-relaxed-constexpr', - '--expt-extended-lambda'] + '-U__CUDA_NO_HALF_CONVERSIONS__'] # Upper triangular softmax. 
sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp', @@ -87,7 +84,7 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags): # Mixed precision fused layer norm. # ================================= - extra_cuda_flags = ['-maxrregcount=50'] + extra_cuda_flags = [] sources=[srcpath / 'layer_norm_cuda.cpp', srcpath / 'layer_norm_cuda_kernel.cu'] fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper( @@ -95,10 +92,10 @@ def _cpp_extention_load_helper(name, sources, extra_cuda_flags): def _get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], + raw_output = subprocess.check_output([cuda_dir + "/bin/hipcc", "--version"], universal_newlines=True) output = raw_output.split() - release_idx = output.index("release") + 1 + release_idx = output.index("version") + 1 release = output[release_idx].split(".") bare_metal_major = release[0] bare_metal_minor = release[1][0] diff --git a/megatron/fused_kernels/layer_norm_cuda_kernel.cu b/megatron/fused_kernels/layer_norm_cuda_kernel.cu index 28a579e1a..26904ca0a 100644 --- a/megatron/fused_kernels/layer_norm_cuda_kernel.cu +++ b/megatron/fused_kernels/layer_norm_cuda_kernel.cu @@ -247,13 +247,13 @@ void cuWelfordMuSigma2( } } -template U rsqrt(U v) { +template __device__ U rsqrt(U v) { return U(1) / sqrt(v); } -template<> float rsqrt(float v) { +template<> __device__ float rsqrt(float v) { return rsqrtf(v); } -template<> double rsqrt(double v) { +template<> __device__ double rsqrt(double v) { return rsqrt(v); } diff --git a/megatron/fused_kernels/layer_norm_hip_kernel.hip b/megatron/fused_kernels/layer_norm_hip_kernel.hip new file mode 100644 index 000000000..6f1d7991f --- /dev/null +++ b/megatron/fused_kernels/layer_norm_hip_kernel.hip @@ -0,0 +1,833 @@ +// !!! This is a file automatically generated by hipify!!! +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*This code is copied from NVIDIA apex: + * https://github.com/NVIDIA/apex + * with minor changes. */ + +#include "ATen/ATen.h" +#include "ATen/AccumulateType.h" +#include "ATen/hip/HIPContext.h" +#include "ATen/hip/DeviceUtils.cuh" + +#include +#include + +#include "type_shim.h" + +template __device__ +void cuWelfordOnlineSum( + const U curr, + U& mu, + U& sigma2, + U& count) +{ + count = count + U(1); + U delta = curr - mu; + U lmean = mu + delta / count; + mu = lmean; + U delta2 = curr - lmean; + sigma2 = sigma2 + delta * delta2; +} + +template __device__ +void cuChanOnlineSum( + const U muB, + const U sigma2B, + const U countB, + U& mu, + U& sigma2, + U& count) +{ + U delta = muB - mu; + U nA = count; + U nB = countB; + count = count + countB; + U nX = count; + if (nX > U(0)) { + nA = nA / nX; + nB = nB / nX; + mu = nA*mu + nB*muB; + sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX; + } else { + mu = U(0); + sigma2 = U(0); + } +} + +template __device__ +void cuWelfordMuSigma2( + const T* __restrict__ vals, + const int n1, + const int n2, + const int i1, + U& mu, + U& sigma2, + U* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. 
+ // + // compute variance and mean over n2 + U count = U(0); + mu= U(0); + sigma2 = U(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const T* lvals = vals + i1*n2; + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + U curr = static_cast(lvals[l+k]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + U curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1<(muB,sigma2B,countB,mu,sigma2,count); + } + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + U* ubuf = (U*)buf; + U* ibuf = (U*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + U muB = ubuf[2*threadIdx.y]; + U sigma2B = ubuf[2*threadIdx.y+1]; + U countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/U(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/U(n2), 0); + } + } +} + +template<> __device__ +void cuWelfordMuSigma2( + const at::Half* __restrict__ vals, + const int n1, + const int 
n2, + const int i1, + float& mu, + float& sigma2, + float* buf) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensor is contiguous + // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available. + // + // compute variance and mean over n2 + float count = 0.0f; + mu= float(0); + sigma2 = float(0); + if (i1 < n1) { + // one warp normalizes one n1 index, + // synchronization is implicit + // initialize with standard Welford algorithm + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const at::Half* lvals = vals + i1*n2; + int l = 8*thrx; + if ((((size_t)lvals)&3) != 0) { + // 16 bit alignment + // first thread consumes first point + if (thrx == 0) { + float curr = static_cast(lvals[0]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + ++l; + } + // at this point, lvals[l] are 32 bit aligned for all threads. + for (; l+7 < n2; l+=8*numx) { + for (int k = 0; k < 8; k+=2) { + float2 curr = __half22float2(*((__half2*)(lvals+l+k))); + cuWelfordOnlineSum(curr.x,mu,sigma2,count); + cuWelfordOnlineSum(curr.y,mu,sigma2,count); + } + } + for (; l < n2; ++l) { + float curr = static_cast(lvals[l]); + cuWelfordOnlineSum(curr,mu,sigma2,count); + } + // intra-warp reductions + for (int l = 0; l <= 4; ++l) { + int srcLaneB = (threadIdx.x+(1< 1) { + float* ubuf = (float*)buf; + float* ibuf = (float*)(ubuf + blockDim.y); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_y = threadIdx.y - offset; + ubuf[2*wrt_y] = mu; + ubuf[2*wrt_y+1] = sigma2; + ibuf[wrt_y] = count; + } + __syncthreads(); + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + float muB = ubuf[2*threadIdx.y]; + float sigma2B = ubuf[2*threadIdx.y+1]; + float countB = ibuf[threadIdx.y]; + cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count); + } + __syncthreads(); + } + // 
threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values + if (threadIdx.x == 0 && threadIdx.y == 0) { + ubuf[0] = mu; + ubuf[1] = sigma2; + } + __syncthreads(); + mu = ubuf[0]; + sigma2 = ubuf[1]/float(n2); + // don't care about final value of count, we know count == n2 + } else { + mu = WARP_SHFL(mu, 0); + sigma2 = WARP_SHFL(sigma2/float(n2), 0); + } + } +} + +template __device__ U rsqrt(U v) { + return U(1) / sqrt(v); +} +template<> __device__ float rsqrt(float v) { + return rsqrtf(v); +} +template<> __device__ double rsqrt(double v) { + return rsqrt(v); +} + +namespace { +// This is the un-specialized struct. Note that we prevent instantiation of this +// struct by putting an undefined symbol in the function body so it won't compile. +// template +// struct SharedMemory +// { +// // Ensure that we won't compile any un-specialized types +// __device__ T *getPointer() +// { +// extern __device__ void error(void); +// error(); +// return NULL; +// } +// }; +// https://github.com/NVIDIA/apex/issues/246 +template +struct SharedMemory; + +template <> +struct SharedMemory +{ + __device__ float *getPointer() + { + HIP_DYNAMIC_SHARED( float, s_float) + return s_float; + } +}; + +} + +template __global__ +void cuApplyLayerNorm( + V* __restrict__ output_vals, + U* __restrict__ mean, + U* __restrict__ invvar, + const T* __restrict__ vals, + const int n1, + const int n2, + const U epsilon, + const V* __restrict__ gamma, + const V* __restrict__ beta + ) +{ + // Assumptions: + // 1) blockDim.x == warpSize + // 2) Tensors are contiguous + // + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + SharedMemory shared; + U* buf = shared.getPointer(); + U mu,sigma2; + cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf); + const T* lvals = vals + i1*n2; + V* ovals = output_vals + i1*n2; + U c_invvar = rsqrt(sigma2 + epsilon); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL && beta != NULL) { + 
for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = (curr - mu) * c_invvar * static_cast(gamma[i]) + static_cast(beta[i]); + } + } else { + for (int i = thrx; i < n2; i+=numx) { + U curr = static_cast(lvals[i]); + ovals[i] = static_cast(c_invvar * (curr - mu)); + } + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean[i1] = mu; + invvar[i1] = c_invvar; + } + __syncthreads(); + } +} + +template __device__ +void cuLoadWriteStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] = curr_dout; + warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar; + } else { + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } + } else { + for (int k = 0; k < blockDim.y; ++k) { + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + warp_buf1[write_idx] = U(0); + warp_buf2[write_idx] = U(0); + } + } +} + +template __device__ +void cuLoadAddStridedInputs( + const int i1_block, + const int thr_load_row_off, + const int thr_load_col_off, + const int i2_off, + const int row_stride, + U* warp_buf1, + U* warp_buf2, + const T* input, + const V* dout, + const int i1_end, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar + ) +{ + int i1 = i1_block+thr_load_row_off; + if (i1 < i1_end) { + U curr_mean = mean[i1]; + U curr_invvar = invvar[i1]; + for (int k = 0; k < 
blockDim.y; ++k) { + int i2 = i2_off + k; + int load_idx = i1*n2+i2; + int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k; + if (i2(input[load_idx]); + U curr_dout = static_cast(dout[load_idx]); + warp_buf1[write_idx] += curr_dout; + warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar; + } + } + } +} + +template __global__ +void cuComputePartGradGammaBeta( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + U* part_grad_gamma, + U* part_grad_beta) +{ + const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y); + const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y; + const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y; + const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y; + const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1; + const int row_stride = blockDim.x+1; + const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1); + const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y; + const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off; + SharedMemory shared; + U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements + U* warp_buf1 = (U*)buf; + U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride; + // compute partial sums from strided inputs + // do this to increase number of loads in flight + cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) { + cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar); + 
} + __syncthreads(); + // inter-warp reductions + // sum within each warp + U acc1 = U(0); + U acc2 = U(0); + for (int k = 0; k < blockDim.y; ++k) { + int row1 = threadIdx.y + k*blockDim.y; + int idx1 = row1*row_stride + threadIdx.x; + acc1 += warp_buf1[idx1]; + acc2 += warp_buf2[idx1]; + } + warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1; + warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2; + __syncthreads(); + // sum all warps + for (int offset = blockDim.y/2; offset > 1; offset /= 2) { + if (threadIdx.y < offset) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + offset; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + warp_buf1[idx1] += warp_buf1[idx2]; + warp_buf2[idx1] += warp_buf2[idx2]; + } + __syncthreads(); + } + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (threadIdx.y == 0 && i2 < n2) { + int row1 = threadIdx.y; + int row2 = threadIdx.y + 1; + int idx1 = row1*row_stride + threadIdx.x; + int idx2 = row2*row_stride + threadIdx.x; + part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2]; + part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2]; + } +} + +template __global__ +void cuComputeGradGammaBeta( + const U* part_grad_gamma, + const U* part_grad_beta, + const int part_size, + const int n1, + const int n2, + V* grad_gamma, + V* grad_beta) +{ + // sum partial gradients for gamma and beta + SharedMemory shared; + U* buf = shared.getPointer(); + int i2 = blockIdx.x * blockDim.x + threadIdx.x; + if (i2 < n2) { + // each warp does sequential reductions until reduced part_size is num_warps + int num_warp_reductions = part_size / blockDim.y; + U sum_gamma = U(0); + U sum_beta = U(0); + const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2; + const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2; + for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) { + sum_gamma += 
part_grad_gamma_ptr[warp_offset*n2]; + sum_beta += part_grad_beta_ptr[warp_offset*n2]; + } + // inter-warp reductions + const int nbsize3 = blockDim.x * blockDim.y / 2; + for (int offset = blockDim.y/2; offset >= 1; offset /= 2) { + // top half write to shared memory + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[write_idx] = sum_gamma; + buf[write_idx+nbsize3] = sum_beta; + } + __syncthreads(); + // bottom half sums + if (threadIdx.y < offset) { + const int read_idx = threadIdx.y * blockDim.x + threadIdx.x; + sum_gamma += buf[read_idx]; + sum_beta += buf[read_idx+nbsize3]; + } + __syncthreads(); + } + // write out fully summed gradients + if (threadIdx.y == 0) { + grad_gamma[i2] = sum_gamma; + grad_beta[i2] = sum_beta; + } + } +} + +template __global__ +void cuComputeGradInput( + const V* __restrict__ dout, + const T* __restrict__ input, + const int n1, + const int n2, + const U* __restrict__ mean, + const U* __restrict__ invvar, + U epsilon, + const V* gamma, + T* grad_input) +{ + for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) { + U sum_loss1 = U(0); + U sum_loss2 = U(0); + const U c_mean = mean[i1]; + const U c_invvar = invvar[i1]; + const T* k_input = input + i1*n2; + const V* k_dout = dout + i1*n2; + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + if (gamma != NULL) { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for (int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss * gamma[l+k]; + sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss * gamma[l]; + sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar; + } + } else { + int l = 4*thrx; + for (; l+3 < n2; l+=4*numx) { + for 
(int k = 0; k < 4; ++k) { + const U c_h = static_cast(k_input[l+k]); + const U c_loss = static_cast(k_dout[l+k]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + for (; l < n2; ++l) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + sum_loss1 += c_loss; + sum_loss2 += c_loss * (c_h - c_mean) * c_invvar; + } + } + // intra-warp reductions + for (int mask = blockDim.x/2; mask > 0; mask /= 2) { + sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask); + sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask); + } + // inter-warp reductions + if (blockDim.y > 1) { + SharedMemory shared; + U* buf = shared.getPointer(); + for (int offset = blockDim.y/2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.y >= offset && threadIdx.y < 2*offset) { + const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x; + buf[2*wrt_i] = sum_loss1; + buf[2*wrt_i+1] = sum_loss2; + } + __syncthreads(); + // lower half merges + if (threadIdx.y < offset) { + const int read_i = threadIdx.y * blockDim.x + threadIdx.x; + sum_loss1 += buf[2*read_i]; + sum_loss2 += buf[2*read_i+1]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + buf[2*threadIdx.x] = sum_loss1; + buf[2*threadIdx.x+1] = sum_loss2; + } + __syncthreads(); + if (threadIdx.y !=0) { + sum_loss1 = buf[2*threadIdx.x]; + sum_loss2 = buf[2*threadIdx.x+1]; + } + } + // all threads now have the two sums over l + U fH = (U)n2; + U term1 = (U(1) / fH) * c_invvar; + T* k_grad_input = grad_input + i1*n2; + if (gamma != NULL) { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss * gamma[l]; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } else { + for (int l = thrx; l < n2; l+=numx) { + const U c_h = static_cast(k_input[l]); + const U 
c_loss = static_cast(k_dout[l]); + U f_grad_input = fH * c_loss; + f_grad_input -= sum_loss1; + f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2; + f_grad_input *= term1; + k_grad_input[l] = static_cast(f_grad_input); + } + } + // prevent race where buf is written again before reads are done + __syncthreads(); + } +} + + + + +template +void HostApplyLayerNorm( + V* output, + U* mean, + U* invvar, + const T* input, + int n1, + int n2, + double epsilon, + const V* gamma, + const V* beta + ) +{ + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + const dim3 threads(32,4,1); + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks(1, ::min((uint64_t)n1, maxGridY), 1); + int nshared = + threads.y > 1 ? + threads.y*sizeof(U)+(threads.y/2)*sizeof(U) : + 0; + hipLaunchKernelGGL(( cuApplyLayerNorm), dim3(blocks), dim3(threads), nshared, stream, + output, + mean, + invvar, + input, + n1,n2, + U(epsilon), + gamma,beta); +} + + +void cuda_layer_norm( + at::Tensor* output, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel", + HostApplyLayerNorm( + output->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input->DATA_PTR(), + n1,n2, + epsilon, + gamma != NULL ? gamma->DATA_PTR() : NULL, + beta != NULL ? 
beta->DATA_PTR() : NULL); + ) +} + + +template +void HostLayerNormGradient( + const V* dout, + const U* mean, + const U* invvar, + at::Tensor* input, + int n1, + int n2, + const V* gamma, + const V* beta, + double epsilon, + T* grad_input, + V* grad_gamma, + V* grad_beta + ) +{ + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); + + if (gamma != NULL && beta != NULL) { + // compute grad_gamma(j) and grad_beta(j) + const int part_size = 16; + const dim3 threads2(32,4,1); + const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1); + const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * + (threads2.x + 1); + const int nshared2_b = threads2.x * threads2.y * sizeof(U); + const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b; + at::Tensor part_grad_gamma = at::empty( + {part_size,n2}, input->options().dtype(at::ScalarType::Float)); + at::Tensor part_grad_beta = at::empty_like(part_grad_gamma); + hipLaunchKernelGGL(( cuComputePartGradGammaBeta), dim3(blocks2), dim3(threads2), nshared2, stream, + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR()); + + const dim3 threads3(32,8,1); + const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1); + const int nshared3 = threads3.x * threads3.y * sizeof(U); + hipLaunchKernelGGL(( cuComputeGradGammaBeta), dim3(blocks3), dim3(threads3), nshared3, stream, + part_grad_gamma.DATA_PTR(), + part_grad_beta.DATA_PTR(), + part_size, + n1,n2, + grad_gamma, + grad_beta); + } + + // compute grad_input + const uint64_t maxGridY = + at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + const dim3 blocks1(1, ::min((uint64_t)n1, maxGridY), 1); + const dim3 threads1(32,4,1); + int nshared = + threads1.y > 1 ? 
+ threads1.y*threads1.x*sizeof(U) : + 0; + hipLaunchKernelGGL(( cuComputeGradInput), dim3(blocks1), dim3(threads1), nshared, stream, + dout, + input->DATA_PTR(), + n1,n2, + mean, + invvar, + U(epsilon), + gamma, + grad_input); +} + + +void cuda_layer_norm_gradient( + at::Tensor* dout, + at::Tensor* mean, + at::Tensor* invvar, + at::Tensor* input, + int n1, + int n2, + #ifdef VERSION_GE_1_1 + at::IntArrayRef normalized_shape, + #else + at::IntList normalized_shape, + #endif + at::Tensor* gamma, + at::Tensor* beta, + double epsilon, + at::Tensor* grad_input, + at::Tensor* grad_gamma, + at::Tensor* grad_beta) +{ + using namespace at; + DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES( + input->scalar_type(), gamma->scalar_type(), + "cuda_layer_norm_gradient_kernel", + HostLayerNormGradient( + dout->DATA_PTR(), + mean->DATA_PTR(), + invvar->DATA_PTR(), + input, + n1,n2, + // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta + // if gamma Tensor is NULL on input. + gamma != NULL ? gamma->DATA_PTR() : NULL, + gamma != NULL ? beta->DATA_PTR() : NULL, + epsilon, + grad_input->DATA_PTR(), + gamma != NULL ? grad_gamma->DATA_PTR() : NULL, + gamma != NULL ? grad_beta->DATA_PTR() : NULL); + ) +} diff --git a/megatron/fused_kernels/scaled_masked_softmax.cpp b/megatron/fused_kernels/scaled_masked_softmax.cpp index 1852aee6f..5dd326357 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.cpp +++ b/megatron/fused_kernels/scaled_masked_softmax.cpp @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include +#include #include #include diff --git a/megatron/fused_kernels/scaled_masked_softmax.h b/megatron/fused_kernels/scaled_masked_softmax.h index 013dd8366..cda79bb22 100644 --- a/megatron/fused_kernels/scaled_masked_softmax.h +++ b/megatron/fused_kernels/scaled_masked_softmax.h @@ -17,11 +17,11 @@ #pragma once #include -#include +#include #include #include #include -#include +#include #include namespace { @@ -86,7 +86,7 @@ struct Max { template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 +#if HIP_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); #else return __shfl_xor(value, laneMask, width); @@ -497,55 +497,55 @@ void dispatch_scaled_softmax_forward( switch (log2_elements) { case 0: // 1 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 1: // 2 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 2: // 4 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 3: // 8 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 4: // 16 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 5: // 32 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 6: // 64 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 7: // 128 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, 
batch_count, key_seq_len); break; case 8: // 256 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 9: // 512 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 10: // 1024 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 11: // 2048 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; case 12: // 4096 scaled_softmax_warp_forward - <<>>(dst, src, scale, batch_count, key_seq_len); + <<>>(dst, src, scale, batch_count, key_seq_len); break; default: break; @@ -591,55 +591,55 @@ void dispatch_scaled_masked_softmax_forward( switch (log2_elements) { case 0: // 1 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 1: // 2 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 2: // 4 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 3: // 8 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 4: // 16 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 5: // 32 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, 
pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 6: // 64 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 7: // 128 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 8: // 256 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 9: // 512 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 10: // 1024 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 11: // 2048 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; case 12: // 4096 scaled_masked_softmax_warp_forward - <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); + <<>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches); break; default: break; @@ -683,55 +683,55 @@ void dispatch_scaled_masked_softmax_backward( switch (log2_elements) { case 0: // 1 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 1: // 2 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 2: // 4 
scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 3: // 8 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 4: // 16 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 5: // 32 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 6: // 64 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 7: // 128 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 8: // 256 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 9: // 512 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 10: // 1024 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 11: // 2048 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; case 12: // 4096 scaled_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, 
batch_count, key_seq_len); + <<>>(grad_input, grad, output, scale, batch_count, key_seq_len); break; default: diff --git a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu index 2efee39a6..d7d6c66b4 100644 --- a/megatron/fused_kernels/scaled_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_masked_softmax_cuda.cu @@ -18,7 +18,7 @@ #include #include #include -#include +// #include #include #include #include "scaled_masked_softmax.h" diff --git a/megatron/fused_kernels/scaled_masked_softmax_hip.hip b/megatron/fused_kernels/scaled_masked_softmax_hip.hip new file mode 100644 index 000000000..341c7b9b2 --- /dev/null +++ b/megatron/fused_kernels/scaled_masked_softmax_hip.hip @@ -0,0 +1,118 @@ +// !!! This is a file automatically generated by hipify!!! +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +// #include +#include +#include +#include "scaled_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_masked_softmax { + +int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){ + return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads); +} + + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + torch::Tensor const& mask, + float scale_factor) +{ + // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = input.size(0); + const int pad_batches = mask.size(0); + const int attn_heads = input.size(1); + const int query_seq_len = input.size(2); + const int key_seq_len = input.size(3); + TORCH_INTERNAL_ASSERT(key_seq_len <= 4096); + TORCH_INTERNAL_ASSERT(query_seq_len > 1); + TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches); + TORCH_INTERNAL_ASSERT(mask.size(1) == 1); + TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len); + TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* mask_ptr = static_cast(mask.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_masked_softmax_forward", + dispatch_scaled_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + reinterpret_cast(mask_ptr), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads, + pad_batches); + ); + return softmax_results; +} + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float 
scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); + + //output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len] + const int batches = output_grads.size(0); + const int attn_heads = output_grads.size(1); + const int query_seq_len = output_grads.size(2); + const int key_seq_len = output_grads.size(3); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_masked_softmax_backward", + dispatch_scaled_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + query_seq_len, + key_seq_len, + batches, + attn_heads); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp index ea283588d..1885ca4df 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include +#include #include #include diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h index 6df83fc10..1680dfff1 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include @@ -85,7 +85,7 @@ struct Max { template __device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { -#if CUDA_VERSION >= 9000 +#if HIP_VERSION >= 9000 return __shfl_xor_sync(mask, value, laneMask, width); #else return __shfl_xor(value, laneMask, width); @@ -369,51 +369,51 @@ void dispatch_scaled_upper_triang_masked_softmax_forward( switch (log2_elements) { case 0: // 1 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 1: // 2 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 2: // 4 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 3: // 8 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 4: // 16 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, 
softmax_elements); break; case 5: // 32 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 6: // 64 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 7: // 128 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 8: // 256 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 9: // 512 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 10: // 1024 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 11: // 2048 scaled_upper_triang_masked_softmax_warp_forward - <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(dst, src, scale, batch_count, softmax_elements_stride, softmax_elements); break; default: break; @@ -460,51 +460,51 @@ void dispatch_scaled_upper_triang_masked_softmax_backward( switch (log2_elements) { case 0: // 1 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, 
softmax_elements_stride, softmax_elements); break; case 1: // 2 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 2: // 4 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 3: // 8 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 4: // 16 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 5: // 32 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 6: // 64 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 7: // 128 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 8: // 256 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, 
scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 9: // 512 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 10: // 1024 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; case 11: // 2048 scaled_upper_triang_masked_softmax_warp_backward - <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); + <<>>(grad_input, grad, output, scale, batch_count, softmax_elements_stride, softmax_elements); break; default: break; diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu index 5efc3d412..b455a09a3 100644 --- a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu @@ -18,7 +18,7 @@ #include #include #include -#include +// #include #include #include #include "scaled_upper_triang_masked_softmax.h" diff --git a/megatron/fused_kernels/scaled_upper_triang_masked_softmax_hip.hip b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_hip.hip new file mode 100644 index 000000000..b033fe0ec --- /dev/null +++ b/megatron/fused_kernels/scaled_upper_triang_masked_softmax_hip.hip @@ -0,0 +1,99 @@ +// !!! This is a file automatically generated by hipify!!! +/* coding=utf-8 + * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +// #include +#include +#include +#include "scaled_upper_triang_masked_softmax.h" +#include "type_shim.h" + +namespace multihead_attn { +namespace fused_softmax { +namespace scaled_upper_triang_masked_softmax { + +torch::Tensor fwd_cuda( + torch::Tensor const& input, + float scale_factor) +{ + // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = input.size(0); + const int seq_len = input.size(1); + TORCH_INTERNAL_ASSERT(seq_len <= 2048); + + // Output + auto act_options = input.options().requires_grad(false); + torch::Tensor softmax_results = + torch::empty({attn_batches, seq_len, seq_len}, act_options); + + // Softmax Intermediate Result Ptr + void* input_ptr = static_cast(input.data_ptr()); + void* softmax_results_ptr = static_cast(softmax_results.data_ptr()); + + DISPATCH_HALF_AND_BFLOAT( + input.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_forward", + dispatch_scaled_upper_triang_masked_softmax_forward( + reinterpret_cast(softmax_results_ptr), + reinterpret_cast(input_ptr), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + return softmax_results; +} + + +torch::Tensor bwd_cuda( + torch::Tensor const& output_grads_, + torch::Tensor const& softmax_results_, + float scale_factor) { + + auto output_grads = output_grads_.contiguous(); + auto softmax_results = softmax_results_.contiguous(); 
+ + //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len] + const int attn_batches = output_grads.size(0); + const int seq_len = output_grads.size(1); + TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2)); + + void* output_grads_ptr = static_cast(output_grads.data_ptr()); + + //Softmax Grad + DISPATCH_HALF_AND_BFLOAT( + output_grads_.scalar_type(), + "dispatch_scaled_upper_triang_masked_softmax_backward", + dispatch_scaled_upper_triang_masked_softmax_backward( + reinterpret_cast(output_grads_ptr), + reinterpret_cast(output_grads_ptr), + reinterpret_cast(softmax_results.data_ptr()), + scale_factor, + seq_len, + seq_len, + attn_batches); + ); + + //backward pass is completely in-place + return output_grads; +} +} +} +} diff --git a/megatron/model/fused_layer_norm.py b/megatron/model/fused_layer_norm.py index 55e9c9dd8..b57441300 100644 --- a/megatron/model/fused_layer_norm.py +++ b/megatron/model/fused_layer_norm.py @@ -29,7 +29,7 @@ import importlib import torch import torch.nn.functional as F - +from lightop import op global fused_mix_prec_layer_norm_cuda fused_mix_prec_layer_norm_cuda = None @@ -81,6 +81,8 @@ def __init__(self, normalized_shape, eps=1e-5): self.eps = eps self.weight = Parameter(torch.Tensor(*normalized_shape)) self.bias = Parameter(torch.Tensor(*normalized_shape)) + + self.reset_parameters() args = get_args() @@ -108,4 +110,7 @@ def forward(self, input): return FusedLayerNormAffineFunction.apply( input, self.weight, self.bias, self.normalized_shape, self.eps) else: - return F.layer_norm(input, self.normalized_shape, self.weight, self.bias) + # print("*"*60) + # print(input.dtype,self.weight.dtype,self.bias.dtype) + # print("*"*60) + return op.layernorm_forward_autograd(input, self.weight,self.bias,self.eps) diff --git a/megatron/model/fused_softmax.py b/megatron/model/fused_softmax.py index 07192e2bf..edcce8832 100644 --- a/megatron/model/fused_softmax.py +++ b/megatron/model/fused_softmax.py @@ -18,6 +18,8 
@@ import torch.nn as nn from megatron.enums import AttnMaskType +from lightop.fusesoftmax import FuseSoftmax + class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): """ Fused operation which performs following three operations in sequence @@ -220,8 +222,9 @@ def forward_torch_softmax(self, input, mask): mask = self.get_causal_mask(input.shape[2]) mask_output = self.mask_func(input, mask) if mask is not None else input - - probs = torch.nn.Softmax(dim=-1)(mask_output) + #改 + # probs = torch.nn.Softmax(dim=-1)(mask_output) + probs = FuseSoftmax(dim=-1)(mask_output) if self.input_in_float16 and self.softmax_in_fp32: if self.input_in_fp16: diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 03e6faaec..f1dd761d8 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -28,6 +28,7 @@ from megatron.model.fused_bias_gelu import bias_gelu_impl from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu +from lightop import op import deepspeed from .glu_activations import GLU_ACTIVATIONS @@ -405,26 +406,31 @@ def forward(self, hidden_states, attention_mask, layer_past=None, return output, bias +# def bias_dropout_add(x, bias, residual, prob, training): +# # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor +# out = torch.nn.functional.dropout(x + bias, p=prob, training=training) +# out = residual + out +# return out def bias_dropout_add(x, bias, residual, prob, training): # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor - out = torch.nn.functional.dropout(x + bias, p=prob, training=training) - out = residual + out + out = op.add_dropout_forward_autograd(x + bias, residual,prob,training) return out + def get_bias_dropout_add(training): def _bias_dropout_add(x, bias, residual, prob): return bias_dropout_add(x, bias, residual, prob, training) return _bias_dropout_add -@torch.jit.script +# @torch.jit.script def bias_dropout_add_fused_train(x, bias, residual, prob): # type: (Tensor, Tensor, 
Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, True) -@torch.jit.script +# @torch.jit.script def bias_dropout_add_fused_inference(x, bias, residual, prob): # type: (Tensor, Tensor, Tensor, float) -> Tensor return bias_dropout_add(x, bias, residual, prob, False) diff --git a/run-gpt3.sh b/run-gpt3.sh new file mode 100644 index 000000000..00ce894f4 --- /dev/null +++ b/run-gpt3.sh @@ -0,0 +1,48 @@ +#!/bin/bash + +#SBATCH -p ty_zhiyuan +#SBATCH -N 88 +#SBATCH --cpus-per-task=1 +#SBATCH --ntasks-per-node=32 +#SBATCH --mem 0 +#SBATCH --gres=dcu:4 +#SBATCH -J gpt3 +#SBATCH -o logs/%j.out +#SBATCH -e logs/%j.err +ulimit -u 200000 +#ethtool -g eno1 + +#export MIOPEN_USER_DB_PATH=/tmp/miopen-udb +#export MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopen-cache +export NCCL_IB_HCA=mlx5_0 +export NCCL_SOCKET_IFNAME=ib0 +export HSA_FORCE_FINE_GRAIN_PCIE=1 +export OMP_NUM_THREADS=1 +echo %j +echo "START TIME: $(date)" +hostfile=./hostfile_dir/$SLURM_JOB_ID +scontrol show hostnames $SLURM_JOB_NODELIST > ${hostfile} +rm `pwd`/hostfile_dir/hostfile-dl-* -f + +for i in `cat $hostfile` +do + echo ${i} slots=4 >> `pwd`/hostfile_dir/hostfile-dl-$SLURM_JOB_ID +done +np=$(cat $hostfile|sort|uniq |wc -l) + +np=$(($np*4)) + +nodename=$(cat $hostfile |sed -n "1p") +echo $nodename +dist_url=`echo $nodename | awk '{print $1}'` + +which mpirun +echo "dist_url: $dist_url" +echo "np: $np" +mpirun -np $np --allow-run-as-root --hostfile `pwd`/hostfile_dir/hostfile-dl-$SLURM_JOB_ID --bind-to none `pwd`/single.sh $dist_url +echo "END TIME: $(date)" + +#just for rename log +now=$(date +"%Y%m%d_%H%M%S") +cp logs/$SLURM_JOB_ID.out logs/gpt3-$now-$SLURM_JOB_ID.log +cp logs/$SLURM_JOB_ID.err logs/gpt3-$now-$SLURM_JOB_ID.err \ No newline at end of file diff --git a/single.sh b/single.sh new file mode 100644 index 000000000..72117d538 --- /dev/null +++ b/single.sh @@ -0,0 +1,168 @@ +#!/bin/bash +export NCCL_SOCKET_IFNAME=ib0 +export NCCL_PXN_DISABLE=0 +export NCCL_IB_HCA=mlx5_0 #0号网卡 +export 
HSA_FORCE_FINE_GRAIN_PCIE=1 +export MIOPEN_FIND_MODE=3 + +export MIOPEN_USER_DB_PATH=/tmp/miopen-udb +export MIOPEN_CUSTOM_CACHE_DIR=/tmp/miopen-cache + +mkdir -p log_wwl +now=$(date +"%Y%m%d_%H%M%S") + +lrank=$OMPI_COMM_WORLD_LOCAL_RANK +RANK=$OMPI_COMM_WORLD_RANK +WORLD_SIZE=$OMPI_COMM_WORLD_SIZE + + +MODEL_NAME=gpt3_175B +DATA_OUTPUT_PATH=./ +LOGS_PATH=$DATA_OUTPUT_PATH/logs +CHECKPOINT_PATH=checkpoints/$MODEL_NAME + +DATA_PATH="/public/home/platform/wwl/Megatron-DeepSpeed/data/meg-gpt2-oscar-en-10k_text_document" +TENSORBOARD_PATH=output_dir/tensorboard/$MODEL_NAME +CODECARBON_PATH=output_dir/codecarbon/$MODEL_NAME + + +SAVE_INTERVAL=250 + +TP_SIZE=4 +PP_SIZE=88 + +NHIDDEN=12352 #12352 +NLAYERS=88 +NHEADS=64 +SEQ_LEN=2048 + + +ZERO_STAGE=1 +config_json="./${MODEL_NAME}_ds_config.json" + + + +MICRO_BATCH_SIZE=1 +GLOBAL_BATCH_SIZE=1 #5760 +cat < $config_json +{ + "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, + "train_batch_size": $GLOBAL_BATCH_SIZE, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + + + +DEEPSPEED_ARGS=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " +export CMD=" \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --num-attention-heads $NHEADS \ + --seq-length $SEQ_LEN \ + --max-position-embeddings $SEQ_LEN \ + --micro-batch-size 1 \ + --global-batch-size $GLOBAL_BATCH_SIZE \ + --train-samples 360000 \ + --loss-scale 12 \ + --clip-grad 1.0 \ + --fp16 \ + --checkpoint-activations \ + --seed 42 + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --adam-eps 1e-8 \ + --lr 6.0e-5 \ + --min-lr 6.0e-6 \ + 
--lr-decay-style cosine \ + --clip-grad 1.0 \ + --weight-decay 1e-1 \ + --exit-duration-in-mins 1190 \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --vocab-file /public/home/platform/wwl/Megatron-DeepSpeed/data/gpt2-vocab.json \ + --merge-file /public/home/platform/wwl/Megatron-DeepSpeed/data/gpt2-merges.txt \ + --log-interval 1 \ + --save-interval $SAVE_INTERVAL \ + --eval-interval 1000 \ + --eval-iters 40 \ + --tensorboard-dir $TENSORBOARD_PATH \ + --tensorboard-queue-size 5 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + $DEEPSPEED_ARGS \ + " +# APP="deepspeed --num_gpus 1 /work/home/jsyadmin/.conda/envs/ldk-h/bin/python3.7 -u `pwd`/pretrain_gpt.py \ + +APP="python -u `pwd`/pretrain_gpt.py \ + --rank ${RANK} \ + --world_size ${WORLD_SIZE} \ + --dist_url tcp://${1}:34566 \ + --num-workers 2 \ + ${CMD} \ + " +#echo ${APP} + + + +case ${lrank} in +[0]) + export HIP_VISIBLE_DEVICES=0,1,2,3 + export UCX_NET_DEVICES=mlx5_0:1 + export UCX_IB_PCI_BW=mlx5_0:50Gbs + #echo NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=0 --membind=0 ${APP} + NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=0 --membind=0 ${APP} + + #echo GLOO_SOCKET_IFNAME=ib0 numactl --cpunodebind=0 --membind=0 ${APP} + #GLOO_SOCKET_IFNAME=ib0 numactl --cpunodebind=0 --membind=0 ${APP} + ;; +[1]) + export HIP_VISIBLE_DEVICES=0,1,2,3 + export UCX_NET_DEVICES=mlx5_1:1 + export UCX_IB_PCI_BW=mlx5_1:50Gbs + #echo NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=1 --membind=1 ${APP} + NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=1 --membind=1 ${APP} + ;; +[2]) + export HIP_VISIBLE_DEVICES=0,1,2,3 + export UCX_NET_DEVICES=mlx5_2:1 + export UCX_IB_PCI_BW=mlx5_2:50Gbs + #echo NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=2 --membind=2 ${APP} + NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=2 --membind=2 ${APP} + ;; +[3]) + export 
HIP_VISIBLE_DEVICES=0,1,2,3 + export UCX_NET_DEVICES=mlx5_3:1 + export UCX_IB_PCI_BW=mlx5_3:50Gbs + #echo NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=3 --membind=3 ${APP} + NCCL_SOCKET_IFNAME=ib0 numactl --cpunodebind=3 --membind=3 ${APP} + ;; +esac +