Skip to content

Commit 6bdede8

Browse files
authored
Update augment branch to use kernels for zero-fill before layernorm and rmsnorm (#1)
1 parent 80a2d68 commit 6bdede8

File tree

4 files changed

+118
-24
lines changed

4 files changed

+118
-24
lines changed

setup.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,19 @@ def te_version() -> str:
3535
with open(root_path / "VERSION", "r") as f:
3636
version = f.readline().strip()
3737

38+
try:
39+
output = subprocess.run(
40+
["git", "rev-parse" , "--short", "HEAD"],
41+
capture_output=True,
42+
cwd=root_path,
43+
check=True,
44+
universal_newlines=True,
45+
)
46+
except (CalledProcessError, OSError):
47+
commit = ""
48+
else:
49+
commit = output.stdout.strip()
50+
3851
# [augment] Here is where we replace the git hash with our own versioning.
3952
# You can disable this behavior with NVTE_NO_AUGMENT_VERSION=1.
4053
if not int(os.getenv("NVTE_NO_AUGMENT_VERSION", "0")):
@@ -43,21 +56,10 @@ def te_version() -> str:
4356
torch_version = parse(torch.__version__)
4457
cuda_version = parse(torch.version.cuda)
4558
version_string = f".cu{cuda_version.major}{cuda_version.minor}.torch{torch_version.major}{torch_version.minor}"
46-
return version + "+augment" + version_string
59+
return version + "+augment" + version_string + "." + commit
4760

4861
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")):
49-
try:
50-
output = subprocess.run(
51-
["git", "rev-parse" , "--short", "HEAD"],
52-
capture_output=True,
53-
cwd=root_path,
54-
check=True,
55-
universal_newlines=True,
56-
)
57-
except (CalledProcessError, OSError):
58-
pass
59-
else:
60-
commit = output.stdout.strip()
62+
if len(commit) > 0:
6163
version += f"+{commit}"
6264
return version
6365

transformer_engine/common/layer_norm/ln_api.cpp

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
* See LICENSE for license information.
55
************************************************************************/
66

7+
#include <cstdlib>
78
#include <transformer_engine/layer_norm.h>
9+
#include <string>
810
#include <vector>
911
#include "ln.h"
1012
#include "../common.h"
@@ -31,6 +33,9 @@ Compute always in FP32
3133
namespace transformer_engine {
3234
namespace layer_norm {
3335

36+
// [Augment] Forward declare helper kernel added to avoid using memset.
37+
void launch_zero_out(void *, size_t, size_t, cudaStream_t);
38+
3439
using namespace transformer_engine;
3540

3641
// Create registries and provide runtime versions of config hash functions.
@@ -232,16 +237,36 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
232237
params.barrier = reinterpret_cast<int*>(barrier->data.dptr);
233238
}
234239

240+
// NOTE[augment]: this envvar exists to restore the prior behavior of TE (i.e., use a memset
241+
// kernel). So if you want to get the upstream behavior, run with NVTE_FORCE_MEMSET=1.
242+
const char *envval = std::getenv("NVTE_FORCE_MEMSET");
243+
bool force_memset = (envval != nullptr) && (std::string(envval) == "1");
235244
// Clear buffers
236245
if ( params.fp8_out ) {
237-
cudaMemsetAsync(params.amax, 0,
238-
layer_norm::product(z->amax.shape) *
239-
typeToSize(z->amax.dtype), stream);
246+
if ( force_memset ) {
247+
cudaMemsetAsync(params.amax, 0,
248+
layer_norm::product(z->amax.shape) *
249+
typeToSize(z->amax.dtype), stream);
250+
} else {
251+
// [Augment] Use the zero-out kernel, not memset.
252+
layer_norm::launch_zero_out(params.amax,
253+
layer_norm::product(z->amax.shape),
254+
typeToSize(z->amax.dtype),
255+
stream);
256+
}
240257
}
241258
if ( launch_params.barrier_size > 0 ) {
242-
cudaMemsetAsync(params.barrier, 0,
243-
layer_norm::product(barrier->data.shape) *
244-
typeToSize(barrier->data.dtype), stream);
259+
if ( force_memset ) {
260+
cudaMemsetAsync(params.barrier, 0,
261+
layer_norm::product(barrier->data.shape) *
262+
typeToSize(barrier->data.dtype), stream);
263+
} else {
264+
// [Augment] Use the zero-out kernel, not memset.
265+
layer_norm::launch_zero_out(params.barrier,
266+
layer_norm::product(barrier->data.shape),
267+
typeToSize(barrier->data.dtype),
268+
stream);
269+
}
245270
}
246271

247272
// Launch the kernel.

transformer_engine/common/layer_norm/ln_fwd_cuda_kernel.cu

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,48 @@
1010

1111
using namespace transformer_engine::layer_norm;
1212

// [Augment] What follows is a small custom kernel (and launch function) to zero-out a buffer.
// We use this to replace a call to cudaMemsetAsync, which introduces gaps in CUDA graph
// execution. I am sure there is a more natural place to put this, but I haven't spent the time
// to figure out the TE build system.
namespace transformer_engine::layer_norm {

// Kernel itself: grid-stride loop over N elements of type T, writing zero to each.
// Indices are size_t so the kernel is correct even when N exceeds INT_MAX, and the
// grid-stride form keeps correctness independent of the launch configuration.
template <typename T>
__launch_bounds__(128)
__global__ void zero_out(T *x, const size_t N) {
    const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
    for (size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
         i < N; i += stride) {
        x[i] = 0;
    }
}

// Launch function: switch over element size and use an appropriate dtype for each.
// `buf` must hold at least N elements of `elem_size` bytes; the launch is enqueued on
// `stream` asynchronously. NOTE: if speed ever becomes an issue, this ought to be vectorized.
void launch_zero_out(void *buf, size_t N, size_t elem_size, cudaStream_t stream) {
    if (N == 0 || elem_size == 0) {
        return;  // nothing to clear; avoid launching with an empty grid
    }
    constexpr int kThreads = 128;
    // Cap the block count: the grid-stride loop covers any remainder, and the cap
    // keeps the computed grid size safely within `int` for very large buffers.
    constexpr size_t kMaxBlocks = 65535;
    const size_t wanted = (N + kThreads - 1) / kThreads;
    const int num_blocks = static_cast<int>(wanted < kMaxBlocks ? wanted : kMaxBlocks);
    switch (elem_size) {
        case 1:
            zero_out<byte><<<num_blocks, kThreads, 0, stream>>>(
                reinterpret_cast<byte*>(buf), N);
            break;
        case 2:
            zero_out<fp16><<<num_blocks, kThreads, 0, stream>>>(
                reinterpret_cast<fp16*>(buf), N);
            break;
        case 4:
            zero_out<float><<<num_blocks, kThreads, 0, stream>>>(
                reinterpret_cast<float*>(buf), N);
            break;
        case 8:
            zero_out<double><<<num_blocks, kThreads, 0, stream>>>(
                reinterpret_cast<double*>(buf), N);
            break;
        default:
            // Unexpected element size: still clear the buffer by zeroing it byte-wise
            // (zero has the same byte pattern regardless of element type), instead of
            // silently leaving it uninitialized as the previous code did.
            zero_out<byte><<<num_blocks, kThreads, 0, stream>>>(
                reinterpret_cast<byte*>(buf), N * elem_size);
            break;
    }
}

}  // namespace transformer_engine::layer_norm
54+
1355
template<
1456
typename weight_t,
1557
typename input_t,

transformer_engine/common/rmsnorm/rmsnorm_api.cpp

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
* See LICENSE for license information.
55
************************************************************************/
66

7+
#include <cstdlib>
78
#include <numeric>
9+
#include <string>
810
#include <vector>
911
#include "../common.h"
1012
#include "rmsnorm.h"
@@ -35,6 +37,9 @@ namespace transformer_engine {
3537

3638
namespace layer_norm {
3739
uint64_t get_key(DType wtype, DType itype, DType otype, DType ctype, uint64_t hidden_size);
40+
41+
// [Augment] Forward declare helper kernel added to avoid using memset.
42+
void launch_zero_out(void *, size_t, size_t, cudaStream_t);
3843
}
3944

4045
namespace rmsnorm {
@@ -177,15 +182,35 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
177182
params.barrier = reinterpret_cast<int *>(barrier->data.dptr);
178183
}
179184

185+
// NOTE[augment]: this envvar exists to restore the prior behavior of TE (i.e., use a memset
186+
// kernel). So if you want to get the upstream behavior, run with NVTE_FORCE_MEMSET=1.
187+
const char *envval = std::getenv("NVTE_FORCE_MEMSET");
188+
bool force_memset = (envval != nullptr) && (std::string(envval) == "1");
180189
// Clear buffers
181190
if (params.fp8_out) {
182-
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
183-
stream);
191+
if (force_memset) {
192+
cudaMemsetAsync(params.amax, 0, rmsnorm::product(z->amax.shape) * typeToSize(z->amax.dtype),
193+
stream);
194+
} else {
195+
// [Augment] Use the zero-out kernel, not memset.
196+
layer_norm::launch_zero_out(params.amax,
197+
rmsnorm::product(z->amax.shape),
198+
typeToSize(z->amax.dtype),
199+
stream);
200+
}
184201
}
185202
if (launch_params.barrier_size > 0) {
186-
cudaMemsetAsync(params.barrier, 0,
187-
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
188-
stream);
203+
if (force_memset) {
204+
cudaMemsetAsync(params.barrier, 0,
205+
rmsnorm::product(barrier->data.shape) * typeToSize(barrier->data.dtype),
206+
stream);
207+
} else {
208+
// [Augment] Use the zero-out kernel, not memset.
209+
layer_norm::launch_zero_out(params.barrier,
210+
rmsnorm::product(barrier->data.shape),
211+
typeToSize(barrier->data.dtype),
212+
stream);
213+
}
189214
}
190215

191216
// Launch the kernel.

0 commit comments

Comments
 (0)