From fdc075ca968d431e50ff63fc0c648c595cdaaa85 Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Mon, 5 Aug 2024 12:58:33 +0200 Subject: [PATCH 01/10] Frozen Bitlinear implemention --- .gitignore | 1 + bitlinear/__init__.py | 10 ++++------ bitlinear/kernels.py | 2 +- bitlinear/measures.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index ca0b987..7cc7d22 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ dist examples/data examples/models +.vscode/ diff --git a/bitlinear/__init__.py b/bitlinear/__init__.py index 5d740a3..540dcf3 100644 --- a/bitlinear/__init__.py +++ b/bitlinear/__init__.py @@ -8,10 +8,8 @@ AbsMean, AbsMedian, ) -from .kernels import ( - Naive, - NaiveListComp, - TernaryNaive, +from .frozen_bitlinear.frozen_bitlinear import ( TorchLinear, - TorchMulAdd, -) \ No newline at end of file + Naive +) + diff --git a/bitlinear/kernels.py b/bitlinear/kernels.py index 5896fe7..048e179 100644 --- a/bitlinear/kernels.py +++ b/bitlinear/kernels.py @@ -75,4 +75,4 @@ def __call__(input, weight, bias=None): value += bias[j] out.append(value) output.append(out) - return torch.Tensor(output) + return torch.Tensor(output) \ No newline at end of file diff --git a/bitlinear/measures.py b/bitlinear/measures.py index 9ac4bec..986394e 100644 --- a/bitlinear/measures.py +++ b/bitlinear/measures.py @@ -12,4 +12,4 @@ def __call__(self, input, keepdim): class AbsMedian(Measure): def __call__(self, input, keepdim): - return input.abs().median(dim=-1, keepdim=keepdim).values if keepdim else input.abs().median() + return input.abs().median(dim=-1, keepdim=keepdim).values if keepdim else input.abs().median() \ No newline at end of file From 746c5512553110ba42c0fe5424d0f8842d6293eb Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Mon, 5 Aug 2024 12:58:48 +0200 Subject: [PATCH 02/10] Frozen Bitlinear implemention --- bitlinear/frozen_bitlinear/.gitignore | 4 + bitlinear/frozen_bitlinear/ReadME.md | 116 ++++++ bitlinear/frozen_bitlinear/__init__.py | 1 + bitlinear/frozen_bitlinear/cuda/__init__.py | 2 + .../frozen_bitlinear/cuda/kernels/__init__.py | 1 + .../anthropic/files/bitlinear_naive.cu | 84 ++++ .../cuda/kernels/anthropic/files/naive.cu | 82 ++++ .../cuda/kernels/anthropic/files/row_major.cu | 146 +++++++ .../cuda/kernels/anthropic/setup.py | 23 ++ .../cuda/kernels/naive_linear/naive_linear.cu | 178 +++++++++ .../cuda/kernels/naive_linear/setup.py | 19 + .../cuda/kernels/no_sync/linear.cu | 128 ++++++ .../cuda/kernels/no_sync/setup.py | 19 + .../cuda/kernels/streamed_linear/linear.cu | 134 +++++++ .../cuda/kernels/streamed_linear/setup.py | 19 + .../cuda/pack_weights/pack_weights.cu | 98 +++++ .../cuda/pack_weights/setup.py | 19 + .../frozen_bitlinear/frozen_bitlinear.py | 141 +++++++ bitlinear/frozen_bitlinear/requirements.txt | 8 + bitlinear/frozen_bitlinear/scripts/build.sh | 34 ++ bitlinear/frozen_bitlinear/scripts/clean.sh | 23 ++ bitlinear/frozen_bitlinear/scripts/test.sh | 57 +++ bitlinear/frozen_bitlinear/src/__init__.py | 2 + bitlinear/frozen_bitlinear/src/anthropic.py | 27 ++ bitlinear/frozen_bitlinear/src/default.py | 78 ++++ bitlinear/frozen_bitlinear/src/triton.py | 154 +++++++ .../frozen_bitlinear/src/utils/Kernel.py | 26 ++ .../frozen_bitlinear/src/utils/__init__.py | 2 + .../frozen_bitlinear/src/utils/helpers.py | 13 + bitlinear/frozen_bitlinear/tests/Benchmark.py | 375 ++++++++++++++++++ bitlinear/frozen_bitlinear/tests/__main__.py | 15 + bitlinear/frozen_bitlinear/tests/helpers.py | 29 ++ 32 files changed, 2057 insertions(+) create mode 100644 bitlinear/frozen_bitlinear/.gitignore create mode 100644 bitlinear/frozen_bitlinear/ReadME.md create mode 100644 bitlinear/frozen_bitlinear/__init__.py create mode 100644 bitlinear/frozen_bitlinear/cuda/__init__.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/__init__.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py create mode 100644 bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py create mode 100644 bitlinear/frozen_bitlinear/frozen_bitlinear.py create mode 100644 bitlinear/frozen_bitlinear/requirements.txt create mode 100755 bitlinear/frozen_bitlinear/scripts/build.sh create mode 100755 bitlinear/frozen_bitlinear/scripts/clean.sh create mode 100755 bitlinear/frozen_bitlinear/scripts/test.sh create mode 100644 bitlinear/frozen_bitlinear/src/__init__.py create mode 100644 bitlinear/frozen_bitlinear/src/anthropic.py create mode 100644 bitlinear/frozen_bitlinear/src/default.py create mode 100644 bitlinear/frozen_bitlinear/src/triton.py create mode 100644 bitlinear/frozen_bitlinear/src/utils/Kernel.py create mode 100644 bitlinear/frozen_bitlinear/src/utils/__init__.py create mode 100644 bitlinear/frozen_bitlinear/src/utils/helpers.py create mode 100644 bitlinear/frozen_bitlinear/tests/Benchmark.py create mode 100644 bitlinear/frozen_bitlinear/tests/__main__.py create mode 100644 bitlinear/frozen_bitlinear/tests/helpers.py diff --git a/bitlinear/frozen_bitlinear/.gitignore b/bitlinear/frozen_bitlinear/.gitignore new file mode 100644 index 0000000..9ce6516 --- /dev/null +++ b/bitlinear/frozen_bitlinear/.gitignore @@ -0,0 +1,4 @@ +results +**/weights +**.so +**/build \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/ReadME.md b/bitlinear/frozen_bitlinear/ReadME.md new file mode 100644 index 0000000..0c47631 --- /dev/null +++ b/bitlinear/frozen_bitlinear/ReadME.md @@ -0,0 +1,116 @@ +## Setup +Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: +``` +> conda create -n env python pip +> conda activate env +> pip install -r requirements.txt +``` + +Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: +``` +> cd path/to/bitlinear/bitlinear/kernels/cuda/pack_weights +> python setup.py build_ext --inplace +> cd path/to/bitlinear/bitlinear/kernels/cuda/kernels/* +> python setup.py build_ext --inplace +``` +You can also build all of them at once by running +``` +> cd path/to/bitlinear/bitlinear/kernels/ +> chmod +x scripts/build.sh +> scripts/build.sh +``` + +## Testing +Choose one of the kernels available in ```path/to/bitlinear/bitlinear/kernels/src/``` to test against the PyTorch baseline for your device. +``` +< conda activate env +< cd path/to/bitlinear/bitlinear/kernels +< chmod +x scripts/test.sh +< scripts/test.sh + < Which Kernel would you like to test? + < kernel_name +``` + +All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/kernels/results/{%Y%m%d_%T}/```. + +If any issues come up, you can reach me at ```sopsahl@mit.edu```. + +## Cleanup +To clean the builds and results, run the following commands. +``` +> cd path/to/bitlinear/bitlinear/kernels +> chmod +x scripts/clean.sh +> scripts/clean.sh +``` + + + + + + + + + +## Motivations + +Matrix multiplications are a key building block of most modern high-performance computing systems. +They are notoriously hard to optimize, hence their implementation is generally done by +hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +Unfortunately, these libraries are often proprietary and cannot be easily customized +to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). +In comes Triton, which is easily customizeable and works to fit our need. + +Roughly speaking, the traditional kernel implements the following blocked +algorithm to multiply a (M, K) by a (K, N) matrix: + +```python +# Do in parallel +for m in range(0, M, BLOCK_SIZE_M): + # Do in parallel + for n in range(0, N, BLOCK_SIZE_N): + acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) + for k in range(0, K, BLOCK_SIZE_K): + a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] + b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] + acc += dot(a, b) + C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc +``` +where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. + +In a linear instance, all that is changed is the addition of the bias on the final step. + +## Compute Kernel + +The above algorithm is, actually, fairly straightforward to implement in Triton. +The main difficulty comes from the computation of the memory locations at which blocks +of ```A``` and ```B``` must be read in the inner loop. For that, we need +multi-dimensional pointer arithmetic. + +### Pointer Arithmetic + +For a row-major 2D tensor `X`, the memory location of `X[i, j]` is given +by `&X[i, j] = X + i*stride_xi + j*stride_xj`. +Therefore, blocks of pointers for `A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and +`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: + +```python +&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1) +&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1) +``` + +### L2 Cache Optimizations + +As mentioned above, each program instance computes a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` +block of `C`. +It is important to remember that the order in which these blocks are computed does +matter, since it affects the L2 cache hit rate of our program, and unfortunately, a +simple row-major ordering + +```Python +pid = tl.program_id(axis=0) +grid_n = tl.cdiv(N, BLOCK_SIZE_N) +pid_m = pid // grid_n +pid_n = pid % grid_n +``` + +is just not going to cut it in the traditional case. When we are evaluating the bitlinear case, however, we must remember that the much more expensive operation is the activation load, not the weight load. We need to find balance between a simple row-major ordering that only requires one load for each block on `axis=0`, and one that optimizes over grouped clocking methods. \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/__init__.py b/bitlinear/frozen_bitlinear/__init__.py new file mode 100644 index 0000000..4e22c0f --- /dev/null +++ b/bitlinear/frozen_bitlinear/__init__.py @@ -0,0 +1 @@ +from src import TorchLinear, Naive \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/__init__.py b/bitlinear/frozen_bitlinear/cuda/__init__.py new file mode 100644 index 0000000..bc55d31 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/__init__.py @@ -0,0 +1,2 @@ +from cuda.kernels import * +from cuda.pack_weights import pack_weights \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py new file mode 100644 index 0000000..e6b9d63 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py @@ -0,0 +1 @@ +from cuda.kernels.anthropic import naive, bitlinear_naive \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu new file mode 100644 index 0000000..79398cc --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int *weights, + const half *bias, + half *output, + float scale, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int weight; + + for (int k = 0; k < K; k += 16) { + weight = weights[(col * K + k)/16]; + + for (int offset=0; offset<16; offset++) { + int mask = (weight & (3 << (2 * offset))) >> (2 * offset); + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half((sum + __half2float(bias[row]))); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + float scale, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + scale, + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(bitlinear_naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu new file mode 100644 index 0000000..ca85336 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu @@ -0,0 +1,82 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu new file mode 100644 index 0000000..15cd108 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} + +void run_sgemm_shared_mem_block(int M, int N, int K, float alpha, float *A, + float *B, float beta, float *C) { + dim3 gridDim(CEIL_DIV(M, 32), CEIL_DIV(N, 32)); + dim3 blockDim(32 * 32); + // L1 cache becomes useless, since we access GMEM only via SMEM, so we carve + // out all of L1 to SMEM. This doesn't currently make a difference, since + // occupancy is limited by reg and thread count, but it's good to do anyway. + cudaFuncSetAttribute(sgemm_shared_mem_block<32>, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxShared); + sgemm_shared_mem_block<32> + <<>>(M, N, K, alpha, A, B, beta, C); +} + +__global__ void shared_memory( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K, + int computeBlockSize, + int blockSize + ) +{ + const uint blockRow = blockIdx.x; + const uint threadCol = blockSize * blockIdx.y + threadIdx.x * computeBlockSize; + + // one row shared between all threads in a block + __shared__ half shared_input[K]; + + // advance pointers to the starting positions + input += blockRow * K; + weights += threadCol * K; + output += blockRow * N + threadCol; + + float sum = 0.0; + for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) { + // Have each thread load one of the elements in A & B + // Make the threadCol (=threadIdx.x) the consecutive index + // to allow global memory access coalescing + shared_input[blockRow * BLOCKSIZE + threadCol] = A[blockRow * K + threadCol]; + Bs[blockRow * BLOCKSIZE + threadCol] = B[blockRow * N + threadCol]; + + // block threads in this block until cache is fully populated + __syncthreads(); + + A += BLOCKSIZE; + B += BLOCKSIZE * N; + + // execute the dotproduct on the currently cached block + for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) { + tmp += As[threadRow * BLOCKSIZE + dotIdx] * + Bs[dotIdx * BLOCKSIZE + threadCol]; + } + // need to sync again at the end, to avoid faster threads + // fetching the next block into the cache before slower threads are done + __syncthreads(); + } + C[threadRow * N + threadCol] = + alpha * tmp + beta * C[threadRow * N + threadCol]; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py new file mode 100644 index 0000000..1e7e465 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py @@ -0,0 +1,23 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='naive', + ext_modules=[ + CUDAExtension( + name='naive', + sources=['files/naive.cu'] + ), + CUDAExtension( + name='bitlinear_naive', + sources=['files/bitlinear_naive.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu new file mode 100644 index 0000000..607983b --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +__global__ void fp16_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n, + int *status +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + half sum = __float2half(0.0f); + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum = __hadd(sum, input[row * k + i]); + } else if (weight == 2) { + sum = __hsub(sum, input[row * k + i]); + } else if (weight == 3) { + atomicExch(status, 1); // Set error flag, weights incorrectly packed + return; + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __hadd(sum, bias[row]); + + } +} + +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + + cudaDeviceSynchronize(); + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive_linear_cuda, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py new file mode 100644 index 0000000..e65bc8f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='naive_linear_cuda', + ext_modules=[ + CUDAExtension( + name='naive_linear_cuda', + sources=['naive_linear.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu new file mode 100644 index 0000000..dd3badb --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(no_stream, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py new file mode 100644 index 0000000..36061a6 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='no_stream', + ext_modules=[ + CUDAExtension( + name='no_stream', + sources=['linear.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu new file mode 100644 index 0000000..f5d390e --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + cudaStream_t stream; + cudaStreamCreate(&stream); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(streamed_linear_cuda, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py new file mode 100644 index 0000000..0fdf5f0 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='streamed_linear_cuda', + ext_modules=[ + CUDAExtension( + name='streamed_linear_cuda', + sources=['linear.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu new file mode 100644 index 0000000..c7e22f9 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu @@ -0,0 +1,98 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +__global__ void pack_weights_kernel(const half* weights, int* packed_weights, int n, int k) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + int col = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < n && col < k) { + int weight_index = row * k + col; + float weight_value = __half2float(weights[weight_index]); + + int packed_index = weight_index / 4; // 4 weights per int8 + int bit_index = 2*weight_index % 8; // each weight takes 2 bits + + int bit_mask = (weight_value == 1.0f) ? (1 << bit_index) : + (weight_value == -1.0f) ? (2 << bit_index) : 0; + + atomicOr(&packed_weights[packed_index], bit_mask); // this avoids race condition errors from parallelization + + } +} + +__global__ void int32_to_int8_kernel(const int* packed_weights_int32, int8_t* packed_weights_int8, int packed_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < packed_size) { + int packed_value = packed_weights_int32[i]; + packed_weights_int8[i] = static_cast(packed_value & 0xFF); + } +} + +torch::Tensor packedint8( + torch::Tensor weights, + int n, + int k + ) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + // Calculate size for packed weights tensor + int packed_size = (n * k + 3) / 4; // 4 weights per int8 + + + auto packed_weights_int32 = torch::zeros({packed_size}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + + const unsigned int block_size = 32; + const unsigned int grid_rows = (n + block_size - 1) / block_size; + const unsigned int grid_cols = (k + block_size - 1) / block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(block_size, block_size); + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights_int32.data_ptr(), + n, + k + ); + + cudaError_t err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + cudaDeviceSynchronize(); + + // Allocate final packed weights as int8 + auto packed_weights_int8 = torch::empty({packed_size}, torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA)); + + // Calculate the number of blocks needed for the second kernel + unsigned int grid_size = (packed_size + block_size - 1) / block_size; + + // Launch the second kernel to cast int32 to int8 + int32_to_int8_kernel<<>>( + packed_weights_int32.data_ptr(), + packed_weights_int8.data_ptr(), + packed_size + ); + + err = cudaGetLastError(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel (int32_to_int8_kernel) failed with error: ", cudaGetErrorString(err)); + + cudaDeviceSynchronize(); + + return packed_weights_int8; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packedint8", &packedint8, "Pack fp16 weights into int8 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py b/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py new file mode 100644 index 0000000..8801da6 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='pack_weights', + ext_modules=[ + CUDAExtension( + name='pack_weights', + sources=['pack_weights.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/frozen_bitlinear.py b/bitlinear/frozen_bitlinear/frozen_bitlinear.py new file mode 100644 index 0000000..fa83b15 --- /dev/null +++ b/bitlinear/frozen_bitlinear/frozen_bitlinear.py @@ -0,0 +1,141 @@ +import torch +import torch.nn as nn + +from ..bitlinear import BitLinear +from ..measures import * +from bitlinear.frozen_bitlinear.frozen_bitlinear import TorchLinear, Naive + +class FrozenBitLinear(nn.Linear): + + w_scale = None + + def __init__( + self, + in_features, + out_features, + kernel, + bias=True, + eps=1e-5, + activation_range=8, + activation_measure='AbsMax', + device=None, + dtype=None + ): + super(FrozenBitLinear, self).__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + self.eps = eps + + self.kernel = eval(kernel)(activation_range, activation_measure) if isinstance(kernel, str) else kernel + + def __repr__(self): + return f"FrozenBitLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, kernel={self.kernel}), activation_range={self.activation_range}, activation_measure={self.activation_measure}" + + def forward(self, x): + if self.activation_measure is None: + x_scale, x_quant = 1, x + else: + x_norm = torch.layer_norm(x, x.size()[1:]) + x_scale = 1 / scale(x_norm, self.activation_range, self.activation_measure, True, self.eps) + x_quant = round_clamp(x_norm / x_scale, self.activation_range) + + return self.kernel(x_quant, self.weight, self.bias, self.w_scale * x_scale) + + def freeze_weights(self, weights:torch.Tensor, weightMeasure:str='AbsMean'): + """ + Parameters: + weights : torch.Tensor + tensor of weights to be packed (out_features x in_features) + weightMeasure : str + str corresponding to a weight packing method + options are 'AbsMean', 'AbsMax', and 'AbsMedian' + default is 'AbsMean' + + Returns: + None + + This function must be called for a Frozen BitLinear Module to pack the weights. It takes advantage of a + kernel-specific weight packing function to store the weights in their ternary representation. This is implicitly + called in bitlinear.freeze + """ + + assert (weights.shape[0] == self.out_features) & (weights.shape[1] == self.in_features), "Weights dimensions must match that of the layer" + + packed_weights, self.w_scale = self.kernel.scale_weights(weights, weightMeasure, self.eps) + + if isinstance(packed_weights, torch.nn.Parameter): + self.weight = packed_weights + else: + self.weight = nn.Parameter(packed_weights) + + +def freeze( + model, + kernel=TorchLinear, + weightMeasure = 'AbsMean', + eps=1e-5, + activation_measure='AbsMax', + activation_range=16, + device=None, + dtype=None + ): + """ + Parameters: + model : torch.nn.Module + Model with BitLinear modules to replace with FrozenBitLinear for faster inference + kernel : str or Kernel + Forward kernel to implement in the model + weightMeasure : str + str corresponding to a weight packing method + options are 'AbsMean', 'AbsMax', and 'AbsMedian' + default is 'AbsMean' + eps : float + value to clamp the weights to in the scaling step to avoid divide-by-zero + default is 1e-5 + activation_measure : str + str corresponding to the activation quantization method + options are 'AbsMean', 'AbsMax', 'AbsMedian', and None + 'AbsMax' + activation_range : int + number of bits to represent the activations in + default is 8 + device : Optional(torch.device) + dtype : Optional(torch.dtype) + + Returns: + None + + This function replaces all of the BitLinear instances in a model with FrozenBitLinear in order to + speed up inference. The weights are packed corresponding to the kernel. + """ + + for name, module in model.named_children(): + if isinstance(module, BitLinear): + + kwargs = { + 'kernel' : kernel, + 'in_features' : module.in_features, + 'out_features' : module.out_features, + 'bias' : getattr(module, "bias", None) is not None, + 'device' : device, + 'dtype' : dtype, + 'eps' : eps, + 'activation_measure' : activation_measure, + 'activation_range' : activation_range + } + + frozen_module = FrozenBitLinear(**kwargs) + frozen_module.freeze_weights(module.weight.data, weightMeasure) # pack + + if kwargs['bias']: + frozen_module.bias.data = module.bias.data + + setattr(model, name, frozen_module) + + else: # recursively iterate throughout the rest of the model + freeze(module, kernel, weightMeasure, eps, device, dtype) diff --git a/bitlinear/frozen_bitlinear/requirements.txt b/bitlinear/frozen_bitlinear/requirements.txt new file mode 100644 index 0000000..4aa6a17 --- /dev/null +++ b/bitlinear/frozen_bitlinear/requirements.txt @@ -0,0 +1,8 @@ +torch +triton +pandas +numpy +plotly +tqdm +matplotlib +openpxyl \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/scripts/build.sh b/bitlinear/frozen_bitlinear/scripts/build.sh new file mode 100755 index 0000000..3cb7419 --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/build.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +# Get the directory where the script is located +SCRIPT_DIR=$(dirname "$(realpath "$0")") + +# Define the parent directory paths +Kernel_DIR="${SCRIPT_DIR}/../cuda/kernels" +PACK_WEIGHTS_DIR="${SCRIPT_DIR}/../cuda/pack_weights" + +# Iterate through each subdirectory under $PARENT_DIR/cuda +for dir in "${Kernel_DIR}/"*/; do + if [ -f "${dir}setup.py" ]; then + echo "Building in directory: $dir" + (cd "$dir" && python setup.py build_ext --inplace) + else + # If `setup.py` is not in the first level, check subdirectories + for subdir in "${dir}"*/; do + if [ -f "${subdir}setup.py" ]; then + echo "Building in subdirectory: $subdir" + (cd "$subdir" && python setup.py build_ext --inplace) + fi + done + fi +done + +# Build in the pack_weights directory +if [ -f "${PACK_WEIGHTS_DIR}/setup.py" ]; then + echo "Building in pack_weights directory: $PACK_WEIGHTS_DIR" + (cd "$PACK_WEIGHTS_DIR" && python setup.py build_ext --inplace) +else + echo "No setup.py found in pack_weights directory: ${PACK_WEIGHTS_DIR}" +fi + +echo "Build complete." \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/scripts/clean.sh b/bitlinear/frozen_bitlinear/scripts/clean.sh new file mode 100755 index 0000000..d4ffe47 --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/clean.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# Get the directory where the script is located +SCRIPT_DIR=$(dirname "$(realpath "$0")") +PARENT_DIR="${SCRIPT_DIR}/.." + +# Delete files with .so extension +echo "Deleting .so files..." +find "$PARENT_DIR" -type f -name '*.so' -exec rm -f {} \; + +# Delete directories named 'weights' +echo "Deleting 'weights' directories..." +find "$PARENT_DIR" -type d -name 'weights' -exec rm -rf {} \; + +# Delete directories named 'build' +echo "Deleting 'build' directories..." +find "$PARENT_DIR" -type d -name 'build' -exec rm -rf {} \; + +# Delete directories named 'build' +echo "Deleting 'results' directories..." +find "$PARENT_DIR" -type d -name 'results' -exec rm -rf {} \; + +echo "Cleanup complete." \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/scripts/test.sh b/bitlinear/frozen_bitlinear/scripts/test.sh new file mode 100755 index 0000000..d7a209d --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/test.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +timestamp=$(date "+%Y%m%d_%T") +save_dir="results/$timestamp" + +# Default values +device=0 +kernel="TorchLinear" + +# Function to display usage information +usage() { + echo "Usage: $0 [-d ] [-k ]" + echo " -d Optional device argument (0, 1, 2, 3). Default is 0." + echo " -k Optional kernel argument. Default is 'TorchLinear'." + exit 1 +} + +# Parse options +while getopts ":d:k:" opt; do + case ${opt} in + d ) + device=$OPTARG + # Check if the integer argument is valid (0, 1, 2, 3) + if ! [[ "$device" =~ ^[0-3]$ ]]; then + echo "Error: Device argument must be one of 0, 1, 2, or 3" + usage + fi + ;; + k ) + kernel=$OPTARG + ;; + \? ) + echo "Invalid option: -$OPTARG" >&2 + usage + ;; + : ) + echo "Option -$OPTARG requires an argument" >&2 + usage + ;; + esac +done + +shift $((OPTIND -1)) + +mkdir -p ${save_dir} +cp $0 $save_dir # saves the current training script + +cd "$(dirname "$0")/.." + +run_cmd="CUDA_VISIBLE_DEVICES=$device python -m tests \ + --save_dir $save_dir \ + --kernel $kernel \ + -a \ + 2>&1 | tee ${save_dir}/test-${timestamp}.log" + +echo "$run_cmd" +eval "$run_cmd" diff --git a/bitlinear/frozen_bitlinear/src/__init__.py b/bitlinear/frozen_bitlinear/src/__init__.py new file mode 100644 index 0000000..76951f7 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/__init__.py @@ -0,0 +1,2 @@ +from src.default import TorchLinear +from src.anthropic import Naive diff --git a/bitlinear/frozen_bitlinear/src/anthropic.py b/bitlinear/frozen_bitlinear/src/anthropic.py new file mode 100644 index 0000000..554e926 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/anthropic.py @@ -0,0 +1,27 @@ +from src.utils import Packed8, naive + +class Anthropic(Packed8): + fxn = lambda args: NotImplementedError + + def __call__(self, activations, weights, bias, scale): + # Check constraints. + assert activations.is_contiguous(), "Matrix A must be contiguous" + assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" + + M, K = activations.shape + N = weights.shape[0] * 4//K + + return self.fxn(activations, weights, bias, M, N, K) * scale + +class Naive(Anthropic): + fxn = naive.linear + + + + + + + + + + diff --git a/bitlinear/frozen_bitlinear/src/default.py b/bitlinear/frozen_bitlinear/src/default.py new file mode 100644 index 0000000..9c9cfeb --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/default.py @@ -0,0 +1,78 @@ +import torch +import torch.nn.functional as F + +from src.utils import Default + +class TorchLinear(Default): + + def __call__(self, input, weight, bias=None, scale=1): + return F.linear(input, weight, bias) * scale + + +class TorchMulAdd(TorchLinear): + def __call__(self, input, weight, bias=None, scale=1): + output = input @ weight.t() + if bias is not None: + output += bias + return output + +class Naive(TorchLinear): + def __call__(self, input, weight, bias=None, scale=1): + print(input.shape, weight.shape, bias.shape if bias is not None else None) + input = input.tolist() + weight = weight.tolist() + if bias is not None: + bias = bias.tolist() + output = [] + n = len(input) + m = len(input[0]) + p = len(weight) + for i in range(n): + out = [] + for j in range(p): + value = sum(input[i][k] * weight[j][k] for k in range(m)) + if bias is not None: + value += bias[j] + out.append(value) + output.append(out) + return torch.Tensor(output) * scale + +class NaiveListComp(TorchLinear): + def __call__(self, input, weight, bias=None, scale=1): + input = input.tolist() + weight = weight.tolist() + if bias is not None: + bias = bias.tolist() + output = [] + n = len(input) + m = len(input[0]) + p = len(weight) + output = [[sum(input[i][k] * weight[j][k] for k in range(m)) + (bias[j] if bias is not None else 0) for j in range(p)] for i in range(n)] + return torch.Tensor(output) * scale + +class TernaryNaive(TorchLinear): + def __call__(self, input, weight, bias=None, scale=1): + input = input.tolist() + weight = weight.tolist() + assert all(all(x in {-1, 0, 1} for x in row) for row in input) + if bias is not None: + bias = bias.tolist() + output = [] + n = len(input) + m = len(input[0]) + p = len(weight) + for i in range(n): + out = [] + for j in range(p): + value = 0 + for k in range(m): + match weight[j][k]: + case 1: + value += input[i][k] + case -1: + value -= input[i][k] + if bias is not None: + value += bias[j] + out.append(value) + output.append(out) + return torch.Tensor(output) * scale diff --git a/bitlinear/frozen_bitlinear/src/triton.py b/bitlinear/frozen_bitlinear/src/triton.py new file mode 100644 index 0000000..431a14a --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/triton.py @@ -0,0 +1,154 @@ +# import torch +# import triton +# import triton.language as tl +# from .utils import Kernel, get_cuda_autotune_config + +# ###### Baseline ###### +# class default(Kernel): + +# def __call__(self, activations, weights, bias, scale): +# # Check constraints. +# assert activations.shape[1] == weights.shape[1], "Incompatible dimensions" +# assert activations.is_contiguous(), "Matrix A must be contiguous" +# assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" + +# M, K = activations.shape +# N = weights.shape[0] + +# weights = weights.transpose(0, 1) + +# # Allocates output. +# output = torch.empty((M, N), device=activations.device, dtype=torch.float16) + +# # 1D launch kernel where each block gets its own program. +# grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + +# kernel[grid]( +# activations, weights, output, bias, # +# M, N, K, # +# activations.stride(0), activations.stride(1), # +# weights.stride(0), weights.stride(1), # +# output.stride(0), output.stride(1), # +# bias.stride(0) +# ) +# return output * scale + + +# def get_cuda_autotune_config(): +# return [ +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, +# num_warps=8), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, +# num_warps=2), +# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, +# num_warps=2), +# # Good config for fp8 inputs. +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, +# num_warps=8), +# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, +# num_warps=8), +# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4), +# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, +# num_warps=4) +# ] + +# @triton.autotune( +# configs=get_cuda_autotune_config(), +# key=['M', 'N', 'K'], +# ) +# @triton.jit +# def kernel( +# # Pointers to matrices +# activation_ptr, weights_ptr, ouput_ptr, bias_ptr, +# # Matrix dimensions +# M, N, K, +# # The stride variables represent how much to increase the ptr by when moving by 1 +# # element in a particular dimension. E.g. `stride_am` is how much to increase `activation_ptr` +# # by to get the element one row down (A has M rows). +# stride_am, stride_ak, +# stride_bk, stride_bn, +# stride_cm, stride_cn, +# stride_dm, +# # Meta-parameters +# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, +# GROUP_SIZE_M: tl.constexpr +# ): +# """Kernel for computing the matmul C = A x B + D +# A has shape (M, K), B has shape (K, N) and D has shape (M, 1) +# Output C has shape (M, N) +# """ +# # ----------------------------------------------------------- +# # Map program ids `pid` to the block of C it should compute. +# # This is done in a grouped ordering to promote L2 data reuse. + +# pid = tl.program_id(axis=0) + +# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) +# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) +# num_pid_in_group = GROUP_SIZE_M * num_pid_n +# group_id = pid // num_pid_in_group +# first_pid_m = group_id * GROUP_SIZE_M +# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) +# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) +# pid_n = (pid % num_pid_in_group) // group_size_m + +# # ---------------------------------------------------------- +# # Create pointers for the first blocks of A and B. +# # We will advance this pointer as we move in the K direction +# # and accumulate +# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M +# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N +# offs_k = tl.arange(0, BLOCK_SIZE_K) +# activation_ptrs = activation_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) +# weights_ptrs = weights_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + +# # ----------------------------------------------------------- +# # Iterate to compute a block of the C matrix. +# # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block +# # of fp32 values for higher accuracy. +# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) +# for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): +# # Load the next block of A and B, generate a mask by checking the K dimension. +# # If it is out of bounds, set it to 0. +# inputs = tl.load(activation_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) +# weights = tl.load(weights_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) +# # We accumulate along the K dimension. +# accumulator += tl.dot(inputs, weights) +# # Advance the ptrs to the next K block. +# activation_ptrs += BLOCK_SIZE_K * stride_ak +# weights_ptrs += BLOCK_SIZE_K * stride_bk + +# # Add bias D to the accumulated result +# bias_ptrs = bias_ptr + offs_am[:, None] * stride_dm +# bias = tl.load(bias_ptrs) +# accumulator += bias + +# output = accumulator.to(tl.float16) + +# # ----------------------------------------------------------- +# # Write back the block of the output matrix C with masks. +# offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) +# offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) +# ouput_ptrs = ouput_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] +# ouput_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) +# tl.store(ouput_ptrs, output, mask=ouput_mask) + \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/src/utils/Kernel.py b/bitlinear/frozen_bitlinear/src/utils/Kernel.py new file mode 100644 index 0000000..d02e624 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/Kernel.py @@ -0,0 +1,26 @@ +import torch +import torch.nn.functional as F + +from src.utils.helpers import round_clamp +from cuda.pack_weights import pack_weights + +class Kernel: + + def __call__(self, activations, weights, bias, scale) -> torch.Tensor: + raise NotImplementedError # this needs to be changed depending on the kernel + + def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: + raise NotImplementedError # this needs to be changed depending on the kernel + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class Default(Kernel): + def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: + return round_clamp(weights, measure, eps) + +class Packed8(Kernel): + def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: + weights, scale = round_clamp(weights, measure, eps) + return pack_weights.packedint8(weights, *weights.shape), scale diff --git a/bitlinear/frozen_bitlinear/src/utils/__init__.py b/bitlinear/frozen_bitlinear/src/utils/__init__.py new file mode 100644 index 0000000..fd8b129 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/__init__.py @@ -0,0 +1,2 @@ +from src.utils.Kernel import Packed8, Default +from cuda.kernels import * \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/src/utils/helpers.py b/bitlinear/frozen_bitlinear/src/utils/helpers.py new file mode 100644 index 0000000..2b69422 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/helpers.py @@ -0,0 +1,13 @@ +import torch + +####### HELPERS ######### + +AbsMax = lambda input : input.abs().max() +AbsMedian = lambda input : input.abs().median() +AbsMean = lambda input : input.abs().mean() + +def round_clamp(input, measure, eps) -> tuple[torch.Tensor, float]: + scale = eval(measure)(input.detach()).clamp_(min=eps).item() + scaled_input = input/scale + return (scaled_input.round().clamp(-1, 1) - scaled_input).detach() + scaled_input, scale + diff --git a/bitlinear/frozen_bitlinear/tests/Benchmark.py b/bitlinear/frozen_bitlinear/tests/Benchmark.py new file mode 100644 index 0000000..8c49e75 --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/Benchmark.py @@ -0,0 +1,375 @@ +import torch +import torch.nn.functional as F + +import os + +from itertools import product +from tqdm import tqdm + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import plotly.graph_objects as go + +from src import * +from tests.helpers import weights + +class Benchmark: + + options = [256, 512, 1024, 2048, 4096, 8192] + + def __init__(self, args): + + self.path = args.save_dir + + self.device='cuda' + + self.kernel_name = args.kernel + self.kernel = eval(self.kernel_name)() + + self.baseline = TorchLinear() + + print(f'Testing with {self.kernel_name} kernel.') + self.ref_lib = 'cuBLAS' + + if args.a or args.p: + self.profiling() + if args.a or args.u: + self.unittests() + if args.a or args.t: + self.throughputs() + + def unittest(self, M=512, N=256, K=128): + + torch.manual_seed(0) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + triton_output = self.kernel(inputs, weights_kernel, biases, scale) + torch_output = self.baseline(inputs, weights_torch, biases, scale) + + if torch.allclose(triton_output, torch_output, atol=1e-1, equal_nan=True): + print(f"✅ Triton and Torch matchfor (M={M}, N={N}, K={K})") + return True, 0.0 + else: + difference = torch.max(torch.abs(triton_output - torch_output)).item() / torch.max(torch.abs(torch_output)).item() + print(f"❌ Triton and Torch differ for (M={M}, N={N}, K={K}) -- Maximum Normalized Difference : {difference}") + return False, difference + + def unittests(self, upper_limit=8096, step=128): + + pbar = tqdm(total=upper_limit//step, desc=f"Unit Tests", leave=True) + + results = [] + + for combo in range(step, upper_limit+step, step): + + result, diff = self.unittest(combo, combo, combo) + results.append([combo, result, diff]) + + pbar.set_description(f"M=N=K : {combo} ") + pbar.update() + + pbar.close() + + results = pd.DataFrame(results, columns=['M=N=K', 'Matches Torch', 'Maximum Normalized Difference']) + + save_path = os.path.join(self.path,"unittests") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path,"data.xlsx")) as writer: + results.to_excel(writer) + + print(f'\nUnitTest data saved to {save_path}\n') + + print('-------------------------------------') + print(f'Percentage of matching outputs {results["Matches Torch"].astype(int).mean() * 100}%') + print('-------------------------------------') + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.scatter(results['M=N=K'], results['Maximum Normalized Difference'], marker='o', color='b', s=50) + + ax.set_xlabel('Dimension Size', fontsize=14) + ax.set_ylabel('Maximum Normalized Difference in Outputs', fontsize=14) + ax.set_title(f'Comparing PyTorch and {self.kernel_name} Output Tensors', fontsize=16) + ax.grid(True) + + plt.savefig(os.path.join(save_path, "difference.png"), dpi=300) + plt.show() + + def profiling(self, upper_limit=75): + + combinations = list(product(self.options, self.options, self.options)) + + results = { + item : [] + for item in [ + 'M', 'N', 'K', 'Custom_CUDA', 'cuBLAS_CUDA', 'Custom_CPU', 'cuBLAS_CPU', 'CUDA_factor', 'CPU_factor' + ] + } + + for [M, N, K] in combinations: + + print(f" M : {M} | N : {N} | K : {K}") + + results['M'].append(M) + results['N'].append(N) + results['K'].append(K) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + print("Profiling custom kernel") + with torch.autograd.profiler.profile(use_device = self.device) as prof: + self.kernel(inputs, weights_kernel, biases, scale) + events = prof.key_averages() + results['Custom_CUDA'].append(sum(event.device_time for event in events) / 1000.0) + results['Custom_CPU'].append(sum(event.cpu_time for event in events) / 1000.0) + print(events.table(sort_by="cuda_time_total")) + + print("Profiling cuBLAS kernel") + with torch.autograd.profiler.profile(use_device = self.device) as prof: + self.baseline(inputs, weights_torch, biases, scale) + events = prof.key_averages() + results['cuBLAS_CUDA'].append(sum(event.device_time for event in events) / 1000.0) + results['cuBLAS_CPU'].append(sum(event.cpu_time for event in events) / 1000.0) + print(events.table(sort_by="cuda_time_total")) + + results['CUDA_factor'].append(results['cuBLAS_CUDA'][-1]/results['Custom_CUDA'][-1]) + results['CPU_factor'].append(results['cuBLAS_CPU'][-1]/results['Custom_CPU'][-1]) + + results = pd.DataFrame(results) + + save_path = os.path.join(self.path, "profiling") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path, "data.xlsx")) as writer: + results.to_excel(writer) + + print(f'Profiling data saved to {save_path}\n') + + print(f' CUDA improvement = ~{min(results['CUDA_factor'])}-{max(results['CUDA_factor'])}') + print(f' CPU improvement = ~{min(results['CPU_factor'])}-{max(results['CPU_factor'])}') + + #### Plot Over Iterations + def iterations(): + plt.figure(figsize=(12, 12)) + + plt.subplot(2, 1, 1) + plt.plot(results['Custom_CUDA'], label='Custom CUDA Time') + plt.plot(results['cuBLAS_CUDA'], label='cuBLAS CUDA Time') + plt.xlabel('Iteration') + plt.ylabel('CUDA Time (ms)') + plt.ylim(0, max(np.percentile(results['Custom_CUDA'], upper_limit), np.percentile(results['cuBLAS_CUDA'], upper_limit))) + plt.title('CUDA Time Comparison') + plt.legend() + + plt.subplot(2, 1, 2) + plt.plot(results['Custom_CPU'], label='Custom CPU Time') + plt.plot(results['cuBLAS_CPU'], label='cuBLAS CPU Time') + plt.xlabel('Iteration') + plt.ylabel('CPU Time (ms)') + plt.ylim(0, max(np.percentile(results['Custom_CPU'], upper_limit), np.percentile(results['cuBLAS_CPU'], upper_limit))) + plt.title('CPU Time Comparison') + plt.legend() + + plt.tight_layout() + plt.savefig(os.path.join(save_path, "overall.png")) + plt.show() + + + ### Plot Difference Factor + def difference(device): + fig = go.Figure(data=[go.Scatter3d( + x=results['M'], + y=results['N'], + z=results['K'], + mode='markers', + marker=dict( + size=3, + color=results[f'Custom_{device}']/results[f'cuBLAS_{device}'], + colorscale='Viridis', + colorbar=dict(title='Performance Factor (ms)') + ), + text=results[f'Custom_{device}']/results[f'cuBLAS_{device}'], # Add text for tooltips + hovertext=results[f'Custom_{device}']/results[f'cuBLAS_{device}'] # Use the text for hover info + )]) + + fig.update_layout( + title=f'Performance Difference Between Custom_{device} and cuBLAS_{device}', + scene=dict( + xaxis_title='M', + yaxis_title='N', + zaxis_title='K' + ) + ) + + fig.write_html(os.path.join(save_path, f"{device}_factor.html")) + + ### Plot as evolving over individual values + def individual(fixed_axis, device): + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(len(self.options), len(self.options), figsize=(20, 20)) + + for row, row_val in zip(ax, self.options): + for plot, col_val in zip(row, self.options): + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'cuBLAS_{device}'], label='cuBLAS') + plot.plot(filtered_df[fixed_axis], filtered_df[f'Custom_{device}'], label=self.kernel_name) + + plot.set_title(f'{row_label}={row_val}, {col_label}={col_val}', fontsize=10) + plot.tick_params(axis='both', which='major', labelsize=8) + + # Place the legend outside the subplots + handles, labels = ax[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper left', fancybox=True, shadow=True, fontsize=15) + + # Set the super title for the figure + fig.suptitle(f'Comparing {self.ref_lib} and {self.kernel_name} on {device} (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Performance (ms)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"{device}_fixed{fixed_axis}.png")) + + iterations() + for device in ['CPU', 'CUDA']: + difference(device) + for fixed_axis in ['M', 'N', 'K']: + individual(fixed_axis, device) + + def throughput(self, M, N, K, iterations=100): + + perf = lambda ms: 2 * M * N * K * iterations * 1e-12 / (ms * 1e-3) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + cuBlas_time, cuBlas_mem = get_throughput(self.baseline, iterations)(inputs, weights_torch, biases, scale) + + kernel_time, kernel_mem = get_throughput(self.kernel, iterations)(inputs, weights_kernel, biases, scale) + + return [ M, N, K, perf(cuBlas_time), perf(kernel_time), kernel_mem - cuBlas_mem, ] + + def throughputs(self): + combinations = list(product(self.options, self.options, self.options)) + + pbar = tqdm(total=len(combinations), desc=f"Throughput", leave=True) + + results = [] + + for [M, N, K] in combinations: + + results.append(self.throughput(M, N, K)) + + pbar.set_description(f"M = {M} | N = {N} | K = {K} ") + pbar.update() + + pbar.close() + + results = pd.DataFrame(results, columns=['M', 'N', 'K', 'cuBLAS Performance', 'kernel Performance', 'Memory Difference']) + + save_path = os.path.join(self.path, "performance") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path, "data.xlsx")) as writer: + results.to_excel(writer) + + print(f'Performance data saved to {save_path}\n') + + ### Performance + for fixed_axis in ['M', 'N', 'K']: + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(len(self.options), len(self.options), figsize=(20, 20)) + + for row, row_val in zip(ax, self.options): + for plot, col_val in zip(row, self.options): + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'cuBLAS Performance'], label='cuBLAS') + plot.plot(filtered_df[fixed_axis], filtered_df[f'kernel Performance'], label=self.kernel_name) + + plot.set_title(f'{row_label}={row_val}, {col_label}={col_val}', fontsize=10) + plot.tick_params(axis='both', which='major', labelsize=8) + + # Place the legend outside the subplots + handles, labels = ax[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper left', fancybox=True, shadow=True, fontsize=15) + + # Set the super title for the figure + fig.suptitle(f'Comparing {self.ref_lib} and {self.kernel_name} Performance (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Performance (TFLOPs)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"performance_fixed{fixed_axis}.png")) + plt.show() + + ### memory + for fixed_axis in ['M', 'N', 'K']: + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(1, len(self.options), figsize=(20, 20)) + + for plot, row_val in zip(ax, self.options): + for col_val in self.options: + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'Memory Difference'], label=f'{col_label}={col_val}') + + plot.set_title(f'{row_label}={row_val}', fontsize=10) + plot.legend() + plot.tick_params(axis='both', which='major', labelsize=8) + + + # Set the super title for the figure + fig.suptitle(f'Difference between {self.kernel_name} and {self.ref_lib} Memory (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Memory Difference (MBs)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"memory_fixed{fixed_axis}.png")) + plt.show() + +def get_throughput(func, iterations): + def wrapper(*args): + + for _ in range(10): + func(*args) + + torch.cuda.reset_peak_memory_stats() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + + start_event.record() + + # benchmark + for _ in range(iterations): + func(*args) + + end_event.record() + + torch.cuda.synchronize() + + max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 2) # Convert bytes to MB + + return start_event.elapsed_time(end_event), max_memory_allocated + + return wrapper \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/tests/__main__.py b/bitlinear/frozen_bitlinear/tests/__main__.py new file mode 100644 index 0000000..73c66f4 --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/__main__.py @@ -0,0 +1,15 @@ +from argparse import ArgumentParser +from tests.Benchmark import Benchmark + + +parser = ArgumentParser() + +parser.add_argument("--kernel", type=str, default='TorchLinear') +parser.add_argument("-a", action='store_true') +parser.add_argument("-p", action='store_true') +parser.add_argument("-u", action='store_true') +parser.add_argument("-t", action='store_true') +parser.add_argument("--save_dir", type=str, default='results/tmp') +args = parser.parse_args() + +Benchmark(args) diff --git a/bitlinear/frozen_bitlinear/tests/helpers.py b/bitlinear/frozen_bitlinear/tests/helpers.py new file mode 100644 index 0000000..1cb9165 --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/helpers.py @@ -0,0 +1,29 @@ +import torch +import os +import json + +def weights(N, K, kernel, baseline, dtype = torch.float16) -> tuple[torch.Tensor, torch.Tensor, float]: + path = f'weights/{kernel}/{N}_{K}' + try: + + torch_weights = torch.load(os.path.join(path, 'torch.pt'), weights_only=True).to('cuda') + kernel_weights = torch.load(os.path.join(path, 'kernel.pt'), weights_only=True).to('cuda') + with open(os.path.join(path, 'scale.json'), 'r') as f: + scale = json.load(f) + return torch_weights, kernel_weights, scale + + except Exception as e: + print(f"Could not load weights because of {e}, generating new ones at {path}") + + weights = torch.randn((N, K), device='cuda', dtype=dtype) + (base_weights, scale) = baseline.scale_weights(weights) + (kernel_weights, _) = kernel.scale_weights(weights) + + os.makedirs(path, exist_ok=True) + torch.save(base_weights, os.path.join(path, 'torch.pt')) + torch.save(kernel_weights, os.path.join(path, 'kernel.pt')) + with open(os.path.join(path, 'scale.json'), 'w') as f: + json.dump(scale, f) + + return base_weights, kernel_weights, scale + From f5106d40120f5cfd08be668dd42f79fee447078e Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Mon, 5 Aug 2024 14:11:21 +0200 Subject: [PATCH 03/10] Continued developing Kernel support for activation Quantization -siop --- bitlinear/frozen_bitlinear/ReadME.md | 25 ++++--- .../frozen_bitlinear/frozen_bitlinear.py | 14 +--- bitlinear/frozen_bitlinear/scripts/clean.sh | 73 +++++++++++++++---- bitlinear/frozen_bitlinear/src/anthropic.py | 6 +- bitlinear/frozen_bitlinear/src/default.py | 73 +------------------ .../frozen_bitlinear/src/utils/Kernel.py | 16 ++-- .../frozen_bitlinear/src/utils/helpers.py | 70 ++++++++++++++++-- bitlinear/frozen_bitlinear/tests/Benchmark.py | 2 +- bitlinear/frozen_bitlinear/tests/helpers.py | 8 +- 9 files changed, 159 insertions(+), 128 deletions(-) diff --git a/bitlinear/frozen_bitlinear/ReadME.md b/bitlinear/frozen_bitlinear/ReadME.md index 0c47631..2a137ce 100644 --- a/bitlinear/frozen_bitlinear/ReadME.md +++ b/bitlinear/frozen_bitlinear/ReadME.md @@ -8,39 +8,42 @@ Make sure you have the correct packages installed in your virtual environment. I Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: ``` -> cd path/to/bitlinear/bitlinear/kernels/cuda/pack_weights +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights > python setup.py build_ext --inplace -> cd path/to/bitlinear/bitlinear/kernels/cuda/kernels/* +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* > python setup.py build_ext --inplace ``` You can also build all of them at once by running ``` -> cd path/to/bitlinear/bitlinear/kernels/ +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ > chmod +x scripts/build.sh > scripts/build.sh ``` ## Testing -Choose one of the kernels available in ```path/to/bitlinear/bitlinear/kernels/src/``` to test against the PyTorch baseline for your device. +Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. ``` < conda activate env -< cd path/to/bitlinear/bitlinear/kernels +< cd path/to/bitlinear/bitlinear/frozen_bitlinear < chmod +x scripts/test.sh -< scripts/test.sh - < Which Kernel would you like to test? - < kernel_name +< scripts/test.sh + -d (CUDA_AVAILABLE_DEVICES=$device) + -k ``` -All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/kernels/results/{%Y%m%d_%T}/```. +All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. If any issues come up, you can reach me at ```sopsahl@mit.edu```. ## Cleanup To clean the builds and results, run the following commands. ``` -> cd path/to/bitlinear/bitlinear/kernels +> cd path/to/bitlinear/bitlinear/frozen_bitlinear > chmod +x scripts/clean.sh -> scripts/clean.sh +> scripts/clean.sh + -b (default: builds only) + -w (weights) + -r (results) ``` diff --git a/bitlinear/frozen_bitlinear/frozen_bitlinear.py b/bitlinear/frozen_bitlinear/frozen_bitlinear.py index fa83b15..d3e7c23 100644 --- a/bitlinear/frozen_bitlinear/frozen_bitlinear.py +++ b/bitlinear/frozen_bitlinear/frozen_bitlinear.py @@ -30,21 +30,13 @@ def __init__( ) self.eps = eps - - self.kernel = eval(kernel)(activation_range, activation_measure) if isinstance(kernel, str) else kernel + self.kernel = eval(kernel)(eps, activation_range, activation_measure) if isinstance(kernel, str) else kernel(eps, activation_range, activation_measure) def __repr__(self): return f"FrozenBitLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, kernel={self.kernel}), activation_range={self.activation_range}, activation_measure={self.activation_measure}" def forward(self, x): - if self.activation_measure is None: - x_scale, x_quant = 1, x - else: - x_norm = torch.layer_norm(x, x.size()[1:]) - x_scale = 1 / scale(x_norm, self.activation_range, self.activation_measure, True, self.eps) - x_quant = round_clamp(x_norm / x_scale, self.activation_range) - - return self.kernel(x_quant, self.weight, self.bias, self.w_scale * x_scale) + return self.kernel(x, self.weight, self.bias, self.w_scale) def freeze_weights(self, weights:torch.Tensor, weightMeasure:str='AbsMean'): """ @@ -99,7 +91,7 @@ def freeze( default is 1e-5 activation_measure : str str corresponding to the activation quantization method - options are 'AbsMean', 'AbsMax', 'AbsMedian', and None + options are 'AbsMean', 'AbsMax', 'AbsMedian', and 'Fp16' 'AbsMax' activation_range : int number of bits to represent the activations in diff --git a/bitlinear/frozen_bitlinear/scripts/clean.sh b/bitlinear/frozen_bitlinear/scripts/clean.sh index d4ffe47..c9fe2db 100755 --- a/bitlinear/frozen_bitlinear/scripts/clean.sh +++ b/bitlinear/frozen_bitlinear/scripts/clean.sh @@ -1,23 +1,70 @@ #!/bin/bash +# Function to display usage information +usage() { + echo "Usage: $0 [-b] [-w] [-r]" + echo " -b Deletes kernel builds" + echo " -w Delete 'weights' directories" + echo " -r Delete 'results' directories" + exit 1 +} + # Get the directory where the script is located SCRIPT_DIR=$(dirname "$(realpath "$0")") PARENT_DIR="${SCRIPT_DIR}/.." -# Delete files with .so extension -echo "Deleting .so files..." -find "$PARENT_DIR" -type f -name '*.so' -exec rm -f {} \; +# Parse command-line arguments +while getopts ":bwr" opt; do + case ${opt} in + b ) + DELETE_BUILD=true + ;; + w ) + DELETE_WEIGHTS=true + ;; + r ) + DELETE_RESULTS=true + ;; + \? ) + echo "Invalid option: -$OPTARG" >&2 + usage + ;; + : ) + echo "Invalid option: -$OPTARG requires an argument" >&2 + usage + ;; + esac +done + +# Shift parsed options away from the positional parameters +shift $((OPTIND -1)) + +if [ "$#" -eq 0 ]; then + DELETE_BUILD=true +fi + +# Delete .so files if -s is specified +if [ "$DELETE_BUILD" = true ]; then + echo "Deleting .so files..." + find "$PARENT_DIR" -type f -name '*.so' -exec rm -f {} \; +fi -# Delete directories named 'weights' -echo "Deleting 'weights' directories..." -find "$PARENT_DIR" -type d -name 'weights' -exec rm -rf {} \; +# Delete 'build' directories if -b is specified +if [ "$DELETE_BUILD" = true ]; then + echo "Deleting 'build' directories..." + find "$PARENT_DIR" -type d -name 'build' -exec rm -rf {} \; +fi -# Delete directories named 'build' -echo "Deleting 'build' directories..." -find "$PARENT_DIR" -type d -name 'build' -exec rm -rf {} \; +# Delete 'weights' directories if -w is specified +if [ "$DELETE_WEIGHTS" = true ]; then + echo "Deleting 'weights' directories..." + find "$PARENT_DIR" -type d -name 'weights' -exec rm -rf {} \; +fi -# Delete directories named 'build' -echo "Deleting 'results' directories..." -find "$PARENT_DIR" -type d -name 'results' -exec rm -rf {} \; +# Delete 'results' directories if -r is specified +if [ "$DELETE_RESULTS" = true ]; then + echo "Deleting 'results' directories..." + find "$PARENT_DIR" -type d -name 'results' -exec rm -rf {} \; +fi -echo "Cleanup complete." \ No newline at end of file +echo "Cleanup complete." diff --git a/bitlinear/frozen_bitlinear/src/anthropic.py b/bitlinear/frozen_bitlinear/src/anthropic.py index 554e926..8683f37 100644 --- a/bitlinear/frozen_bitlinear/src/anthropic.py +++ b/bitlinear/frozen_bitlinear/src/anthropic.py @@ -9,9 +9,11 @@ def __call__(self, activations, weights, bias, scale): assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" M, K = activations.shape - N = weights.shape[0] * 4//K + N = weights.shape[0] * 4//K - return self.fxn(activations, weights, bias, M, N, K) * scale + x_quant, x_scale = self.activations(input) + + return self.fxn(x_quant, weights, bias, M, N, K) * scale * x_scale class Naive(Anthropic): fxn = naive.linear diff --git a/bitlinear/frozen_bitlinear/src/default.py b/bitlinear/frozen_bitlinear/src/default.py index 9c9cfeb..3b6c5e5 100644 --- a/bitlinear/frozen_bitlinear/src/default.py +++ b/bitlinear/frozen_bitlinear/src/default.py @@ -4,75 +4,6 @@ from src.utils import Default class TorchLinear(Default): - def __call__(self, input, weight, bias=None, scale=1): - return F.linear(input, weight, bias) * scale - - -class TorchMulAdd(TorchLinear): - def __call__(self, input, weight, bias=None, scale=1): - output = input @ weight.t() - if bias is not None: - output += bias - return output - -class Naive(TorchLinear): - def __call__(self, input, weight, bias=None, scale=1): - print(input.shape, weight.shape, bias.shape if bias is not None else None) - input = input.tolist() - weight = weight.tolist() - if bias is not None: - bias = bias.tolist() - output = [] - n = len(input) - m = len(input[0]) - p = len(weight) - for i in range(n): - out = [] - for j in range(p): - value = sum(input[i][k] * weight[j][k] for k in range(m)) - if bias is not None: - value += bias[j] - out.append(value) - output.append(out) - return torch.Tensor(output) * scale - -class NaiveListComp(TorchLinear): - def __call__(self, input, weight, bias=None, scale=1): - input = input.tolist() - weight = weight.tolist() - if bias is not None: - bias = bias.tolist() - output = [] - n = len(input) - m = len(input[0]) - p = len(weight) - output = [[sum(input[i][k] * weight[j][k] for k in range(m)) + (bias[j] if bias is not None else 0) for j in range(p)] for i in range(n)] - return torch.Tensor(output) * scale - -class TernaryNaive(TorchLinear): - def __call__(self, input, weight, bias=None, scale=1): - input = input.tolist() - weight = weight.tolist() - assert all(all(x in {-1, 0, 1} for x in row) for row in input) - if bias is not None: - bias = bias.tolist() - output = [] - n = len(input) - m = len(input[0]) - p = len(weight) - for i in range(n): - out = [] - for j in range(p): - value = 0 - for k in range(m): - match weight[j][k]: - case 1: - value += input[i][k] - case -1: - value -= input[i][k] - if bias is not None: - value += bias[j] - out.append(value) - output.append(out) - return torch.Tensor(output) * scale + x_quant, x_scale = self.activations(input) + return F.linear(x_quant, weight, bias) * scale * x_scale diff --git a/bitlinear/frozen_bitlinear/src/utils/Kernel.py b/bitlinear/frozen_bitlinear/src/utils/Kernel.py index d02e624..77dbdcc 100644 --- a/bitlinear/frozen_bitlinear/src/utils/Kernel.py +++ b/bitlinear/frozen_bitlinear/src/utils/Kernel.py @@ -1,15 +1,19 @@ import torch import torch.nn.functional as F -from src.utils.helpers import round_clamp +from src.utils.helpers import weight_round_clamp, AbsMax, AbsMean, AbsMedian, Fp16 from cuda.pack_weights import pack_weights class Kernel: + def __init__(self, eps=1e-5, activation_range = 8, activation_measure = 'AbsMax'): + + self.eps = eps + self.activations = eval(activation_measure)(activation_range, eps) def __call__(self, activations, weights, bias, scale) -> torch.Tensor: raise NotImplementedError # this needs to be changed depending on the kernel - def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError # this needs to be changed depending on the kernel def __repr__(self) -> str: @@ -17,10 +21,10 @@ def __repr__(self) -> str: class Default(Kernel): - def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: - return round_clamp(weights, measure, eps) + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: + return weight_round_clamp(weights, measure, self.eps) class Packed8(Kernel): - def scale_weights(self, weights, measure='AbsMean', eps=1e-5) -> tuple[torch.Tensor, float]: - weights, scale = round_clamp(weights, measure, eps) + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: + weights, scale = weight_round_clamp(weights, measure, self.eps) return pack_weights.packedint8(weights, *weights.shape), scale diff --git a/bitlinear/frozen_bitlinear/src/utils/helpers.py b/bitlinear/frozen_bitlinear/src/utils/helpers.py index 2b69422..ec27f86 100644 --- a/bitlinear/frozen_bitlinear/src/utils/helpers.py +++ b/bitlinear/frozen_bitlinear/src/utils/helpers.py @@ -1,13 +1,67 @@ import torch +from math import ceil -####### HELPERS ######### + +''' +Activation Quantization Helpers +''' + +def symmetric_range_from_bits(self, range): + return (ceil(-2**(range-1)), ceil(2**(range-1)-1)) + +def round_clamp(self, input): + return (input.round().clamp(self.range[0], self.range[1]) - input).detach() + input + +class ActivationMeasure: + def __init__(self, range=8, eps=1e-5): + self.range = symmetric_range_from_bits(range) + self.eps = eps -AbsMax = lambda input : input.abs().max() -AbsMedian = lambda input : input.abs().median() -AbsMean = lambda input : input.abs().mean() + def __call__(self, input): + x_norm = torch.layer_norm(input, input.size()[1:]) + x_scale = self.scale(x_norm) + return round_clamp(x_norm/x_scale), x_scale -def round_clamp(input, measure, eps) -> tuple[torch.Tensor, float]: - scale = eval(measure)(input.detach()).clamp_(min=eps).item() - scaled_input = input/scale - return (scaled_input.round().clamp(-1, 1) - scaled_input).detach() + scaled_input, scale + def scale(self, input) -> torch.Tensor: + raise NotImplementedError + + def __repr__(self): + return f"{self.__class__.__name__}()" + +''' +Callable Activation Quantization Classes +''' + +class Fp16(ActivationMeasure): + def __init__(self, range=8, eps=1e-5): + pass + def __call__(self, input) -> tuple[torch.Tensor, torch.Tensor]: + return input, torch.Tensor(1.0) + +class AbsMax(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().max(dim=-1, keepdim=True).values.clamp_(min=self.eps)/max(abs(k) for k in self.range) + +class AbsMean(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().mean(dim=-1, keepdim=True).clamp_(min=self.eps)/max(abs(k) for k in self.range) + +class AbsMedian(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().median(dim=-1, keepdim=True).values.clamp_(min=self.eps)/max(abs(k) for k in self.range) + + +''' +Weight Quantization Functions +''' + +class WeightMeasure: + AbsMax = lambda input : input.abs().max() + AbsMedian = lambda input : input.abs().median() + AbsMean = lambda input : input.abs().mean() + +def weight_round_clamp(input, measure, eps) -> tuple[torch.Tensor, float]: + scale = eval(f'WeightMeasure.{measure}')(input.detach()).clamp_(min=eps) + scaled_input = input/scale + return (scaled_input.round().clamp(-1, 1) - scaled_input).detach() + scaled_input, scale \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/tests/Benchmark.py b/bitlinear/frozen_bitlinear/tests/Benchmark.py index 8c49e75..e1a32fe 100644 --- a/bitlinear/frozen_bitlinear/tests/Benchmark.py +++ b/bitlinear/frozen_bitlinear/tests/Benchmark.py @@ -25,7 +25,7 @@ def __init__(self, args): self.device='cuda' self.kernel_name = args.kernel - self.kernel = eval(self.kernel_name)() + self.kernel = eval(self.kernel_name)(activation_measure='Fp16') self.baseline = TorchLinear() diff --git a/bitlinear/frozen_bitlinear/tests/helpers.py b/bitlinear/frozen_bitlinear/tests/helpers.py index 1cb9165..af96d6c 100644 --- a/bitlinear/frozen_bitlinear/tests/helpers.py +++ b/bitlinear/frozen_bitlinear/tests/helpers.py @@ -2,14 +2,13 @@ import os import json -def weights(N, K, kernel, baseline, dtype = torch.float16) -> tuple[torch.Tensor, torch.Tensor, float]: +def weights(N, K, kernel, baseline, dtype = torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: path = f'weights/{kernel}/{N}_{K}' try: torch_weights = torch.load(os.path.join(path, 'torch.pt'), weights_only=True).to('cuda') kernel_weights = torch.load(os.path.join(path, 'kernel.pt'), weights_only=True).to('cuda') - with open(os.path.join(path, 'scale.json'), 'r') as f: - scale = json.load(f) + scale = torch.load(os.path.join(path, 'scale.pt'), weights_only=True).to('cuda') return torch_weights, kernel_weights, scale except Exception as e: @@ -22,8 +21,7 @@ def weights(N, K, kernel, baseline, dtype = torch.float16) -> tuple[torch.Tensor os.makedirs(path, exist_ok=True) torch.save(base_weights, os.path.join(path, 'torch.pt')) torch.save(kernel_weights, os.path.join(path, 'kernel.pt')) - with open(os.path.join(path, 'scale.json'), 'w') as f: - json.dump(scale, f) + torch.save(scale, os.path.join(path, 'scale.pt')) return base_weights, kernel_weights, scale From 109f9c8c461916a7f3c887cdc79f32f21a5077ee Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Mon, 5 Aug 2024 15:09:01 +0200 Subject: [PATCH 04/10] README completed with code examples and explanations --- bitlinear/frozen_bitlinear/ReadME.md | 119 --------------------------- 1 file changed, 119 deletions(-) delete mode 100644 bitlinear/frozen_bitlinear/ReadME.md diff --git a/bitlinear/frozen_bitlinear/ReadME.md b/bitlinear/frozen_bitlinear/ReadME.md deleted file mode 100644 index 2a137ce..0000000 --- a/bitlinear/frozen_bitlinear/ReadME.md +++ /dev/null @@ -1,119 +0,0 @@ -## Setup -Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: -``` -> conda create -n env python pip -> conda activate env -> pip install -r requirements.txt -``` - -Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights -> python setup.py build_ext --inplace -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* -> python setup.py build_ext --inplace -``` -You can also build all of them at once by running -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ -> chmod +x scripts/build.sh -> scripts/build.sh -``` - -## Testing -Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. -``` -< conda activate env -< cd path/to/bitlinear/bitlinear/frozen_bitlinear -< chmod +x scripts/test.sh -< scripts/test.sh - -d (CUDA_AVAILABLE_DEVICES=$device) - -k -``` - -All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. - -If any issues come up, you can reach me at ```sopsahl@mit.edu```. - -## Cleanup -To clean the builds and results, run the following commands. -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear -> chmod +x scripts/clean.sh -> scripts/clean.sh - -b (default: builds only) - -w (weights) - -r (results) -``` - - - - - - - - - -## Motivations - -Matrix multiplications are a key building block of most modern high-performance computing systems. -They are notoriously hard to optimize, hence their implementation is generally done by -hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). -Unfortunately, these libraries are often proprietary and cannot be easily customized -to accommodate the needs of modern deep learning workloads (e.g., fused activation functions). -In comes Triton, which is easily customizeable and works to fit our need. - -Roughly speaking, the traditional kernel implements the following blocked -algorithm to multiply a (M, K) by a (K, N) matrix: - -```python -# Do in parallel -for m in range(0, M, BLOCK_SIZE_M): - # Do in parallel - for n in range(0, N, BLOCK_SIZE_N): - acc = zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=float32) - for k in range(0, K, BLOCK_SIZE_K): - a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K] - b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N] - acc += dot(a, b) - C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc -``` -where each iteration of the doubly-nested for-loop is performed by a dedicated Triton program instance. - -In a linear instance, all that is changed is the addition of the bias on the final step. - -## Compute Kernel - -The above algorithm is, actually, fairly straightforward to implement in Triton. -The main difficulty comes from the computation of the memory locations at which blocks -of ```A``` and ```B``` must be read in the inner loop. For that, we need -multi-dimensional pointer arithmetic. - -### Pointer Arithmetic - -For a row-major 2D tensor `X`, the memory location of `X[i, j]` is given -by `&X[i, j] = X + i*stride_xi + j*stride_xj`. -Therefore, blocks of pointers for `A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]` and -`B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]` can be defined in pseudo-code as: - -```python -&A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K] = a_ptr + (m : m+BLOCK_SIZE_M)[:, None]*A.stride(0) + (k : k+BLOCK_SIZE_K)[None, :]*A.stride(1) -&B[k : k+BLOCK_SIZE_K, n:n+BLOCK_SIZE_N] = b_ptr + (k : k+BLOCK_SIZE_K)[:, None]*B.stride(0) + (n : n+BLOCK_SIZE_N)[None, :]*B.stride(1) -``` - -### L2 Cache Optimizations - -As mentioned above, each program instance computes a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` -block of `C`. -It is important to remember that the order in which these blocks are computed does -matter, since it affects the L2 cache hit rate of our program, and unfortunately, a -simple row-major ordering - -```Python -pid = tl.program_id(axis=0) -grid_n = tl.cdiv(N, BLOCK_SIZE_N) -pid_m = pid // grid_n -pid_n = pid % grid_n -``` - -is just not going to cut it in the traditional case. When we are evaluating the bitlinear case, however, we must remember that the much more expensive operation is the activation load, not the weight load. We need to find balance between a simple row-major ordering that only requires one load for each block on `axis=0`, and one that optimizes over grouped clocking methods. \ No newline at end of file From d2b2daf54af75d34d89a5171ba50db07f0c8c50e Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Mon, 5 Aug 2024 15:09:08 +0200 Subject: [PATCH 05/10] README completed with code examples and explanations --- bitlinear/frozen_bitlinear/README.md | 136 +++++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 bitlinear/frozen_bitlinear/README.md diff --git a/bitlinear/frozen_bitlinear/README.md b/bitlinear/frozen_bitlinear/README.md new file mode 100644 index 0000000..86039a0 --- /dev/null +++ b/bitlinear/frozen_bitlinear/README.md @@ -0,0 +1,136 @@ +## Setup +Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: +``` +> conda create -n env python pip +> conda activate env +> pip install -r requirements.txt +``` + +Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights +> python setup.py build_ext --inplace +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* +> python setup.py build_ext --inplace +``` +You can also build all of them at once by running +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ +> chmod +x scripts/build.sh +> scripts/build.sh +``` + +## Testing +Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. +``` +< conda activate env +< cd path/to/bitlinear/bitlinear/frozen_bitlinear +< chmod +x scripts/test.sh +< scripts/test.sh + -d (CUDA_AVAILABLE_DEVICES=$device) + -k +``` + +All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. + +If any issues come up, you can reach me at [sopsahl@mit.edu](mailto:sopsahl@mit.edu). + +## Cleanup +To clean the builds and results, run the following commands. +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear +> chmod +x scripts/clean.sh +> scripts/clean.sh + -b (default: builds only) + -w (weights) + -r (results) +``` + + + + + + +## Motivations + +Matrix multiplications are a key building block of most modern high-performance computing systems. +They are notoriously hard to optimize, hence their implementation is generally done by +hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +Unfortunately, these libraries are often proprietary and cannot be easily customized +to accommodate the needs of our ternary system. Attempts at [matmul-free linear implementations](https://github.com/ridgerchu/matmulfreellm) still rely on the cuBLAS kernel for the matmul. The goal of this investigation is to develop a CUDA kernel competitive with cuBLAS in speed and performance, and ultimately far faster, as a ternary system can be implemented by a series of masked adds as opposed to traditional matrix multiplication. +Roughly speaking, the traditional kernel implements the following blocked +algorithm to multiply a (M, K) by a (K, N) matrix: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + for (int k = 0; k < K) { + sum += A[row, k] * B[k, col]; + } + + output[row * N + col] = sum + bias[row]; + } +``` +where each instance of the above code is a single thread. + +We hope to develop a kernel that can avoid the multiplicative step, instead following the successive framework: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + for (int k = 0; k < K) { + if (B[k, col] == 1) { + sum += A[row, k]; + } else if (B[k, col] == -1) { + sum -= A[row, k]; + } + } + + output[row * N + col] = sum + bias[row]; + } +``` + +The code above corresponds to excution of a single thread. For more information on how CUDA parallelization is implemented in code and in reality, visit [here](https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/). + +### Optimizations + +To see savings from our algorithm, we turn to weight packing algorithms. If we represent a weight in two bits, then we can take an eighth of the memory to store the weights, resulting in potentially significant speedups if memory loads are the limiting factor (as they often are). + +Because weight loads are significantly cheaper per capita than activation loads, we want to more heavily prioritize cache hits on activation loads. To do so, we can follow a number of strategies. Within each thread, we can compute a greater number of computations with the same activations. This may lead to speedups in memory loads, but significant optimization is required to find the balance between memory savings and parallelization costs. We can also use shared memory between threads, which allows for threads to compute different outputs in parallel, but without reloading activations each time. A barebones implementation is shown below: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + __shared__ half shared_activations[]; + + if (row < M && col < N) { + + float sum = 0; + + for (int k = 0; k < K) { + shared_activations[k] = A[row, k]; + } + + __syncthreads(); + + for (int k = 0; k < K) { + if (B[k, col] == 1) { + sum += A[row, k]; + } else if (B[k, col] == -1) { + sum -= A[row, k]; + } + } + + output[row * N + col] = sum + bias[row]; +} +``` +The tradeoff with this implementation is the reliance on threads to be synchronized within the thread call, which is necessary to avoid race conditions but results in a performance penalty. + +Further optimization is required to balance the size of K, M and N in scheduling the batch sizes to schedule to the block and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset their gains. Later work can go into developing autotuning algorithms that dynamically schedule the computation as to maximize the L2 cache hit ratio and minimize synchronization delay. \ No newline at end of file From ae0b54ad0715ccf558825b74cf6c0fa0d12020d6 Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Tue, 6 Aug 2024 10:46:27 +0200 Subject: [PATCH 06/10] Refactored kernels and added warptiling beta -siop --- bitlinear/frozen_bitlinear/README.md | 114 +++++++------- .../frozen_bitlinear/cuda/kernels/__init__.py | 2 +- .../cuda/kernels/anthropic/setup.py | 23 --- .../files => archive}/bitlinear_naive.cu | 0 .../{naive_linear => archive}/naive_linear.cu | 0 .../{no_sync/linear.cu => archive/no_sync.cu} | 0 .../{anthropic/files => archive}/row_major.cu | 0 .../naive.cu => archive/shared_memory.cu} | 0 .../linear.cu => archive/streamed_linear.cu} | 0 .../cuda/kernels/naive_linear/naive.cu | 82 ++++++++++ .../cuda/kernels/naive_linear/setup.py | 6 +- .../cuda/kernels/streamed_linear/setup.py | 19 --- .../cuda/kernels/warptiling/kernels.cuh | 4 + .../kernels/warptiling/kernels/warptiling.cuh | 145 ++++++++++++++++++ .../kernels/warptiling_bitlinear.cuh | 107 +++++++++++++ .../cuda/kernels/warptiling/runner.cu | 117 ++++++++++++++ .../kernels/{no_sync => warptiling}/setup.py | 6 +- bitlinear/frozen_bitlinear/tests/Benchmark.py | 4 +- 18 files changed, 525 insertions(+), 104 deletions(-) delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py rename bitlinear/frozen_bitlinear/cuda/kernels/{anthropic/files => archive}/bitlinear_naive.cu (100%) rename bitlinear/frozen_bitlinear/cuda/kernels/{naive_linear => archive}/naive_linear.cu (100%) rename bitlinear/frozen_bitlinear/cuda/kernels/{no_sync/linear.cu => archive/no_sync.cu} (100%) rename bitlinear/frozen_bitlinear/cuda/kernels/{anthropic/files => archive}/row_major.cu (100%) rename bitlinear/frozen_bitlinear/cuda/kernels/{anthropic/files/naive.cu => archive/shared_memory.cu} (100%) rename bitlinear/frozen_bitlinear/cuda/kernels/{streamed_linear/linear.cu => archive/streamed_linear.cu} (100%) create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu rename bitlinear/frozen_bitlinear/cuda/kernels/{no_sync => warptiling}/setup.py (79%) diff --git a/bitlinear/frozen_bitlinear/README.md b/bitlinear/frozen_bitlinear/README.md index 86039a0..23a194f 100644 --- a/bitlinear/frozen_bitlinear/README.md +++ b/bitlinear/frozen_bitlinear/README.md @@ -1,55 +1,3 @@ -## Setup -Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: -``` -> conda create -n env python pip -> conda activate env -> pip install -r requirements.txt -``` - -Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights -> python setup.py build_ext --inplace -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* -> python setup.py build_ext --inplace -``` -You can also build all of them at once by running -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ -> chmod +x scripts/build.sh -> scripts/build.sh -``` - -## Testing -Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. -``` -< conda activate env -< cd path/to/bitlinear/bitlinear/frozen_bitlinear -< chmod +x scripts/test.sh -< scripts/test.sh - -d (CUDA_AVAILABLE_DEVICES=$device) - -k -``` - -All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. - -If any issues come up, you can reach me at [sopsahl@mit.edu](mailto:sopsahl@mit.edu). - -## Cleanup -To clean the builds and results, run the following commands. -``` -> cd path/to/bitlinear/bitlinear/frozen_bitlinear -> chmod +x scripts/clean.sh -> scripts/clean.sh - -b (default: builds only) - -w (weights) - -r (results) -``` - - - - - ## Motivations @@ -133,4 +81,64 @@ Because weight loads are significantly cheaper per capita than activation loads, ``` The tradeoff with this implementation is the reliance on threads to be synchronized within the thread call, which is necessary to avoid race conditions but results in a performance penalty. -Further optimization is required to balance the size of K, M and N in scheduling the batch sizes to schedule to the block and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset their gains. Later work can go into developing autotuning algorithms that dynamically schedule the computation as to maximize the L2 cache hit ratio and minimize synchronization delay. \ No newline at end of file +Further optimization is required to balance the size of K, M and N in scheduling the batch sizes to schedule to the block and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset their gains. Later work can go into developing autotuning algorithms that dynamically schedule the computation as to maximize the L2 cache hit ratio and minimize synchronization delay. + +## Limitations + +As of now, this optimization extends only into inference. The method of packing weights seems redundant in training, as the shadow weights being stored in higher precision is [necessary for gradient accumulation](https://arxiv.org/pdf/2402.17764). The speedups still can be realized in the forward passes, but more investigation and more robust kernels must be devised to turn these into reality. + +Further work can also be spent looking into fusing these kernels to realize even faster speedups. This has been growing in popularity recently, and luckily, alot of the work [has been done](https://github.com/ridgerchu/matmulfreellm) for ternary weight layers, we just need to plug and play with our matmul kernel. + +## Setup +Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: +``` +> conda create -n env python pip +> conda activate env +> pip install -r requirements.txt +``` + +Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights +> python setup.py build_ext --inplace +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* +> python setup.py build_ext --inplace +``` +You can also build all of them at once by running +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ +> chmod +x scripts/build.sh +> scripts/build.sh +``` + +## Testing +Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. +``` +< conda activate env +< cd path/to/bitlinear/bitlinear/frozen_bitlinear +< chmod +x scripts/test.sh +< scripts/test.sh + -d (CUDA_AVAILABLE_DEVICES=$device) + -k +``` + +All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. + +If any issues come up, you can reach me at [sopsahl@mit.edu](mailto:sopsahl@mit.edu). + +All benchmarks have been tested on an A6000. + +## Cleanup +To clean the builds and results, run the following commands. +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear +> chmod +x scripts/clean.sh +> scripts/clean.sh + -b (default: builds only) + -w (weights) + -r (results) +``` + + +## References +... \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py index e6b9d63..4e564b9 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py @@ -1 +1 @@ -from cuda.kernels.anthropic import naive, bitlinear_naive \ No newline at end of file +from cuda.kernels.naive_linear import naive \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py deleted file mode 100644 index 1e7e465..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/setup.py +++ /dev/null @@ -1,23 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -import torch - -if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. The extension requires CUDA.") - -setup( - name='naive', - ext_modules=[ - CUDAExtension( - name='naive', - sources=['files/naive.cu'] - ), - CUDAExtension( - name='bitlinear_naive', - sources=['files/bitlinear_naive.cu'] - ) - ], - cmdclass={ - 'build_ext': BuildExtension - } -) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/bitlinear_naive.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/naive_linear.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive_linear.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/naive_linear.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/no_sync.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/no_sync/linear.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/no_sync.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/row_major.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/row_major.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/row_major.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/shared_memory.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/anthropic/files/naive.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/shared_memory.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/streamed_linear.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/linear.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/archive/streamed_linear.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu new file mode 100644 index 0000000..ca85336 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu @@ -0,0 +1,82 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py index e65bc8f..417a092 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py @@ -6,11 +6,11 @@ raise RuntimeError("CUDA is not available. The extension requires CUDA.") setup( - name='naive_linear_cuda', + name='naive', ext_modules=[ CUDAExtension( - name='naive_linear_cuda', - sources=['naive_linear.cu'] + name='naive', + sources=['naive.cu'] ) ], cmdclass={ diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py deleted file mode 100644 index 0fdf5f0..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/streamed_linear/setup.py +++ /dev/null @@ -1,19 +0,0 @@ -from setuptools import setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -import torch - -if not torch.cuda.is_available(): - raise RuntimeError("CUDA is not available. The extension requires CUDA.") - -setup( - name='streamed_linear_cuda', - ext_modules=[ - CUDAExtension( - name='streamed_linear_cuda', - sources=['linear.cu'] - ) - ], - cmdclass={ - 'build_ext': BuildExtension - } -) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh new file mode 100644 index 0000000..cdd35a0 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh @@ -0,0 +1,4 @@ +#pragma once + +#include "kernels/warptiling.cuh" +#include "kernels/warptiling_bitlinear.cuh" \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh new file mode 100644 index 0000000..f40311c --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh @@ -0,0 +1,145 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +const int WARPSIZE = 32; // warpSize is not constexpr + +namespace wt { +template +__device__ void loadFromGmem(int N, int K, const half *A, const half *B, + half *As, half *Bs, int innerRowA, int innerColA, + int innerRowB, int innerColB) { + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const half2 tmp = reinterpret_cast( + &A[(innerRowA + offset) * K + innerColA * 2])[0]; + As[(innerColA * 2 + 0) * BM + innerRowA + offset] = tmp.x; + As[(innerColA * 2 + 1) * BM + innerRowA + offset] = tmp.y; + } + + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + reinterpret_cast( + &Bs[(innerRowB + offset) * BN + innerColB * 2])[0] = + reinterpret_cast( + &B[(innerRowB + offset) * N + innerColB * 2])[0]; + } +} + +template +__device__ void +processFromSmem(half *regM, half *regN, float *threadResults, const half *As, + const half *Bs, const uint warpRow, const uint warpCol, + const uint threadRowInWarp, const uint threadColInWarp) { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = + As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + } + } + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + regN[wSubColIdx * TN + i] = + Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + } + } + + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + __half2float(regM[wSubRowIdx * TM + resIdxM]) * + __half2float(regN[wSubColIdx * TN + resIdxN]); + } + } + } + } + } +} + +} // namespace wt + +template +__global__ void __launch_bounds__(NUM_THREADS) + sgemmWarptiling(int M, int N, int K, float alpha, half *A, half *B, + float beta, half *C) { + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + const uint warpIdx = threadIdx.x / WARPSIZE; + const uint warpCol = warpIdx % (BN / WN); + const uint warpRow = warpIdx / (BN / WN); + + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; + constexpr uint WSUBN = WN / WNITER; + + const uint threadIdxInWarp = threadIdx.x % WARPSIZE; + const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); + const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); + + __shared__ half As[BM * BK]; + __shared__ half Bs[BK * BN]; + + A += cRow * BM * K; + B += cCol * BN; + C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; + + const uint innerRowA = threadIdx.x / (BK / 2); + const uint innerColA = threadIdx.x % (BK / 2); + constexpr uint rowStrideA = (NUM_THREADS * 2) / BK; + const uint innerRowB = threadIdx.x / (BN / 2); + const uint innerColB = threadIdx.x % (BN / 2); + constexpr uint rowStrideB = NUM_THREADS / (BN / 2); + + float threadResults[WMITER * TM * WNITER * TN] = {0.0}; + half regM[WMITER * TM] = {0.0}; + half regN[WNITER * TN] = {0.0}; + + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + wt::loadFromGmem( + N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB); + __syncthreads(); + wt::processFromSmem(regM, regN, threadResults, As, Bs, warpRow, warpCol, + threadRowInWarp, threadColInWarp); + A += BK; + B += BK * N; + __syncthreads(); + } + + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + half *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 2) { + half2 tmp = reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0]; + const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + wSubColIdx * TN + resIdxN; + tmp.x = __float2half(alpha * threadResults[i + 0] + beta * __half2float(tmp.x)); + tmp.y = __float2half(alpha * threadResults[i + 1] + beta * __half2float(tmp.y)); + reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0] = tmp; + } + } + } + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh new file mode 100644 index 0000000..a9e389a --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh @@ -0,0 +1,107 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +const int WARPSIZE = 32; // warpSize is not constexpr + +namespace wt { +template +__device__ void loadFromGmem(int N, int K, const half *A, const uint8_t *B, half *As, uint8_t *Bs, int innerRowA, int innerColA, int innerRowB, int innerColB) { + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const float4 tmp = reinterpret_cast(&A[(innerRowA + offset) * K + innerColA * 4])[0]; + As[(innerColA * 4 + 0) * BM + innerRowA + offset] = __float2half(tmp.x); + As[(innerColA * 4 + 1) * BM + innerRowA + offset] = __float2half(tmp.y); + As[(innerColA * 4 + 2) * BM + innerRowA + offset] = __float2half(tmp.z); + As[(innerColA * 4 + 3) * BM + innerRowA + offset] = __float2half(tmp.w); + } + + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + reinterpret_cast(&Bs[(innerRowB + offset) * BN + innerColB])[0] = + reinterpret_cast(&B[(innerRowB + offset) * (N / 4) + innerColB])[0]; + } +} + +template +__device__ void processFromSmem(half *regM, uint8_t *regN, float *threadResults, const half *As, const uint8_t *Bs, const uint warpRow, const uint warpCol, const uint threadRowInWarp, const uint threadColInWarp) { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // populate registers for whole warptile + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; + } + } + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + regN[wSubColIdx * TN + i] = Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i]; + } + } + + // execute warptile matmul + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + // Unpack and process 2-bit weights + uint8_t weight = (regN[wSubColIdx * TN + resIdxN / 4] >> ((resIdxN % 4) * 2)) & 0x03; + float val = __half2float(regM[wSubRowIdx * TM + resIdxM]); + if (weight == 1) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + (wSubColIdx * TN) + resIdxN] += val; + } else if (weight == 2) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + (wSubColIdx * TN) + resIdxN] -= val; + } + } + } + } + } + } +} + +template +__global__ void sgemmWarptiling_bitlinear(const int M, const int N, const int K, const float alpha, const half *A, const uint8_t *B, const float beta, half *C) { + const int tid = threadIdx.x; + const int warpId = tid / 32; + const int laneId = tid % 32; + + const int warpRow = warpId / (BN / WN); + const int warpCol = warpId % (BN / WN); + const int threadRowInWarp = laneId / (TN / 4); + const int threadColInWarp = laneId % (TN / 4); + + __shared__ half As[BM * BK]; + __shared__ uint8_t Bs[BK * BN]; + + float threadResults[TM * TN] = {0}; + + for (int i = 0; i < K; i += BK) { + wt::loadFromGmem(N, K, A, B, As, Bs, warpRow, i, warpCol, i); + __syncthreads(); + + wt::processFromSmem(threadResults, As, Bs, warpRow, warpCol, threadRowInWarp, threadColInWarp); + __syncthreads(); + } + + for (int i = 0; i < TM; ++i) { + for (int j = 0; j < TN; ++j) { + const int globalRow = blockIdx.y * BM + warpRow * WM + threadRowInWarp * TM + i; + const int globalCol = blockIdx.x * BN + warpCol * WN + threadColInWarp * TN + j; + if (globalRow < M && globalCol < N) { + C[globalRow * N + globalCol] = __float2half(threadResults[i * TN + j] + __half2float(C[globalRow * N + globalCol])); + } + } + } +} +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu new file mode 100644 index 0000000..a34bd9e --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu @@ -0,0 +1,117 @@ +#include +#include +#include "kernels.cuh" +#include +#include +#include +#include + + +#include +#include +#include + +void warptiling_bitlinear(int M, int N, int K, float alpha, torch::Tensor A, torch::Tensor B, float beta, torch::Tensor C) { + + // Settings for A6000 + const uint NUM_THREADS = 128; + const uint BN = 128; + const uint BM = 128; + const uint BK = 16; + const uint WN = 64; + const uint WM = 64; + const uint WNITER = 4; + const uint TN = 4; + const uint TM = 8; + dim3 blockDim(NUM_THREADS); + + constexpr uint NUM_WARPS = NUM_THREADS / 32; + + // warptile in threadblocktile + static_assert((BN % WN == 0) and (BM % WM == 0)); + static_assert((BN / WN) * (BM / WM) == NUM_WARPS); + + // threads in warpsubtile + static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); + constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); + // warpsubtile in warptile + static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); + + static_assert((NUM_THREADS * 4) % BK == 0, + "NUM_THREADS*4 must be multiple of BK to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of Bs during each iteration)"); + static_assert((NUM_THREADS * 4) % BN == 0, + "NUM_THREADS*4 must be multiple of BN to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of As during each iteration)"); + static_assert(BN % (16 * TN) == 0, + "BN must be a multiple of 16*TN to avoid quantization effects"); + static_assert(BM % (16 * TM) == 0, + "BM must be a multiple of 16*TM to avoid quantization effects"); + static_assert((BM * BK) % (4 * NUM_THREADS) == 0, + "BM*BK must be a multiple of 4*256 to vectorize loads"); + static_assert((BN * BK) % (4 * NUM_THREADS) == 0, + "BN*BK must be a multiple of 4*256 to vectorize loads"); + + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + + sgemmWarptiling_bitlinear + <<>>(M, N, K, alpha, A.data_ptr(), B.data_ptr(), beta, C.data_ptr()); +} + +void warptiling(int M, int N, int K, float alpha, torch::Tensor A, torch::Tensor B, float beta, torch::Tensor C) { + + // Settings for A6000 + const uint NUM_THREADS = 128; + const uint BN = 128; + const uint BM = 128; + const uint BK = 16; + const uint WN = 64; + const uint WM = 64; + const uint WNITER = 4; + const uint TN = 4; + const uint TM = 8; + dim3 blockDim(NUM_THREADS); + + constexpr uint NUM_WARPS = NUM_THREADS / 32; + + // warptile in threadblocktile + static_assert((BN % WN == 0) and (BM % WM == 0)); + static_assert((BN / WN) * (BM / WM) == NUM_WARPS); + + // threads in warpsubtile + static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); + constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); + // warpsubtile in warptile + static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); + + static_assert((NUM_THREADS * 4) % BK == 0, + "NUM_THREADS*4 must be multiple of BK to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of Bs during each iteration)"); + static_assert((NUM_THREADS * 4) % BN == 0, + "NUM_THREADS*4 must be multiple of BN to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of As during each iteration)"); + static_assert(BN % (16 * TN) == 0, + "BN must be a multiple of 16*TN to avoid quantization effects"); + static_assert(BM % (16 * TM) == 0, + "BM must be a multiple of 16*TM to avoid quantization effects"); + static_assert((BM * BK) % (4 * NUM_THREADS) == 0, + "BM*BK must be a multiple of 4*256 to vectorize loads"); + static_assert((BN * BK) % (4 * NUM_THREADS) == 0, + "BN*BK must be a multiple of 4*256 to vectorize loads"); + + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + + sgemmWarptiling + <<>>(M, N, K, alpha, A.data_ptr(), B.data_ptr(), beta, C.data_ptr()); +} + +PYBIND11_MODULE(warptiling, m) { + m.def("warptiling_linear", &warptiling, "Run SGEMM Warptiling with Half Precision (CUDA)"); + m.def("warptiling_bitlinear", &warptiling_bitlinear, "Run SGEMM Warptiling with Half Precision (CUDA) and packed weights"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py similarity index 79% rename from bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py rename to bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py index 36061a6..058235a 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/no_sync/setup.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py @@ -6,11 +6,11 @@ raise RuntimeError("CUDA is not available. The extension requires CUDA.") setup( - name='no_stream', + name='warptiling', ext_modules=[ CUDAExtension( - name='no_stream', - sources=['linear.cu'] + name='warptiling', + sources=['runner.cu'] ) ], cmdclass={ diff --git a/bitlinear/frozen_bitlinear/tests/Benchmark.py b/bitlinear/frozen_bitlinear/tests/Benchmark.py index e1a32fe..93a4ed5 100644 --- a/bitlinear/frozen_bitlinear/tests/Benchmark.py +++ b/bitlinear/frozen_bitlinear/tests/Benchmark.py @@ -16,7 +16,7 @@ class Benchmark: - options = [256, 512, 1024, 2048, 4096, 8192] + options = [256, 512, 1024, 2048, 4096] def __init__(self, args): @@ -59,7 +59,7 @@ def unittest(self, M=512, N=256, K=128): print(f"❌ Triton and Torch differ for (M={M}, N={N}, K={K}) -- Maximum Normalized Difference : {difference}") return False, difference - def unittests(self, upper_limit=8096, step=128): + def unittests(self, upper_limit=4096, step=128): pbar = tqdm(total=upper_limit//step, desc=f"Unit Tests", leave=True) From d3c64b8f7f60c52c084abc217ae082cf6629bd73 Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Tue, 6 Aug 2024 15:46:09 +0200 Subject: [PATCH 07/10] Added beginnings of a warptiling kernel --- bitlinear/frozen_bitlinear/README.md | 6 +- bitlinear/frozen_bitlinear/__init__.py | 2 +- .../frozen_bitlinear/cuda/kernels/__init__.py | 3 +- .../cuda/kernels/warptiling/kernels.cuh | 4 - .../kernels/warptiling/kernels/original.cu | 187 ++++++++++++ .../kernels/warptiling/kernels/warptiling.cu | 271 ++++++++++++++++++ .../kernels/warptiling/kernels/warptiling.cuh | 145 ---------- .../kernels/warptiling_bitlinear.cuh | 107 ------- .../cuda/kernels/warptiling/runner.cu | 117 -------- .../cuda/kernels/warptiling/setup.py | 3 +- .../cuda/pack_weights/pack_weights.cu | 22 +- bitlinear/frozen_bitlinear/scripts/build.sh | 8 - bitlinear/frozen_bitlinear/src/__init__.py | 4 +- .../src/{anthropic.py => naive.py} | 9 +- bitlinear/frozen_bitlinear/src/triton.py | 154 ---------- .../frozen_bitlinear/src/utils/Kernel.py | 2 +- .../frozen_bitlinear/src/utils/helpers.py | 10 +- bitlinear/frozen_bitlinear/src/warptiling.py | 21 ++ bitlinear/frozen_bitlinear/tests/Benchmark.py | 2 +- 19 files changed, 505 insertions(+), 572 deletions(-) delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu rename bitlinear/frozen_bitlinear/src/{anthropic.py => naive.py} (75%) delete mode 100644 bitlinear/frozen_bitlinear/src/triton.py create mode 100644 bitlinear/frozen_bitlinear/src/warptiling.py diff --git a/bitlinear/frozen_bitlinear/README.md b/bitlinear/frozen_bitlinear/README.md index 23a194f..921fad7 100644 --- a/bitlinear/frozen_bitlinear/README.md +++ b/bitlinear/frozen_bitlinear/README.md @@ -83,11 +83,13 @@ The tradeoff with this implementation is the reliance on threads to be synchroni Further optimization is required to balance the size of K, M and N in scheduling the batch sizes to schedule to the block and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset their gains. Later work can go into developing autotuning algorithms that dynamically schedule the computation as to maximize the L2 cache hit ratio and minimize synchronization delay. -## Limitations +## Future Work As of now, this optimization extends only into inference. The method of packing weights seems redundant in training, as the shadow weights being stored in higher precision is [necessary for gradient accumulation](https://arxiv.org/pdf/2402.17764). The speedups still can be realized in the forward passes, but more investigation and more robust kernels must be devised to turn these into reality. -Further work can also be spent looking into fusing these kernels to realize even faster speedups. This has been growing in popularity recently, and luckily, alot of the work [has been done](https://github.com/ridgerchu/matmulfreellm) for ternary weight layers, we just need to plug and play with our matmul kernel. +Further work can also be spent looking into fusing these kernels with layer norm, softmax, and other such adjoining layers to realize even faster speedups. This has been growing in popularity recently, and luckily, alot of the work [has been done](https://github.com/ridgerchu/matmulfreellm) for ternary weight layers, we just need to plug and play with our matmul kernel. + +Finally, we have 1.58 bits of information packed into 2, but by utilizing compression algorithms, we can realize even higher levels of weight packing. Because a weight can only occupy one of 3 states ($00$, $10$, or $01$), we can compress $10$ to just $1$, resulting in a further 13% of savings. Another solution could be extending the context region, which would introduce complex interdependencies, but as each weight must be loaded in higher precision in a load magnitudes more costly than bit-level computation, it may prove beneficial. As weight loads are not the crux of the cost, however, it is unclear if traversing this path is worth it. ## Setup Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: diff --git a/bitlinear/frozen_bitlinear/__init__.py b/bitlinear/frozen_bitlinear/__init__.py index 4e22c0f..95b8b85 100644 --- a/bitlinear/frozen_bitlinear/__init__.py +++ b/bitlinear/frozen_bitlinear/__init__.py @@ -1 +1 @@ -from src import TorchLinear, Naive \ No newline at end of file +from src import TorchLinear, Naive, Warptiling \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py index 4e564b9..8c42e8f 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py @@ -1 +1,2 @@ -from cuda.kernels.naive_linear import naive \ No newline at end of file +from cuda.kernels.naive_linear import naive +from cuda.kernels.warptiling import warptiling \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh deleted file mode 100644 index cdd35a0..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels.cuh +++ /dev/null @@ -1,4 +0,0 @@ -#pragma once - -#include "kernels/warptiling.cuh" -#include "kernels/warptiling_bitlinear.cuh" \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu new file mode 100644 index 0000000..2cc66f3 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu @@ -0,0 +1,187 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +const int WARPSIZE = 32; // warpSize is not constexpr + +namespace wt { +template +__device__ void loadFromGmem(int N, int K, const float *A, const float *B, + float *As, float *Bs, int innerRowA, int innerColA, + int innerRowB, int innerColB) { + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const float4 tmp = reinterpret_cast( + &A[(innerRowA + offset) * K + innerColA * 4])[0]; + // float4 tmp; + // asm("ld.global.nc.v4.f32 {%0, %1, %2, %3}, [%4];" + // : "=f"(tmp.x), "=f"(tmp.y), "=f"(tmp.z), "=f"(tmp.w) + // : "l"(&A[(innerRowA + offset) * K + innerColA * 4])); + As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x; + As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y; + As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z; + As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w; + } + + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + reinterpret_cast( + &Bs[(innerRowB + offset) * BN + innerColB * 4])[0] = + reinterpret_cast( + &B[(innerRowB + offset) * N + innerColB * 4])[0]; + // asm("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];" + // : "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 0]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 1]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 2]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 3]) + // : "l"(&B[(innerRowB + offset) * N + innerColB * 4])); + } +} + +template +__device__ void +processFromSmem(float *regM, float *regN, float *threadResults, const float *As, + const float *Bs, const uint warpRow, const uint warpCol, + const uint threadRowInWarp, const uint threadColInWarp) { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // populate registers for whole warptile + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = + As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + } + } + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + regN[wSubColIdx * TN + i] = + Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + } + } + + // execute warptile matmul + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + regM[wSubRowIdx * TM + resIdxM] * + regN[wSubColIdx * TN + resIdxN]; + } + } + } + } + } +} + +} // namespace wt + +/* + * @tparam BM The threadblock size for M dimension SMEM caching. + * @tparam BN The threadblock size for N dimension SMEM caching. + * @tparam BK The threadblock size for K dimension SMEM caching. + * @tparam WM M dim of continuous tile computed by each warp + * @tparam WN N dim of continuous tile computed by each warp + * @tparam WMITER The number of subwarp tiling steps in M dimension. + * @tparam WNITER The number of subwarp tiling steps in N dimension. + * @tparam TM The per-thread tile size for M dimension. + * @tparam TN The per-thread tile size for N dimension. + */ +template +__global__ void __launch_bounds__(NUM_THREADS) + sgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B, + float beta, float *C) { + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // Placement of the warp in the threadblock tile + const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in + const uint warpCol = warpIdx % (BN / WN); + const uint warpRow = warpIdx / (BN / WN); + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31] + const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4 + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + // Move C_ptr to warp's output tile + C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + constexpr uint rowStrideB = NUM_THREADS / (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[WMITER * TM * WNITER * TN] = {0.0}; + // we cache into registers on the warptile level + float regM[WMITER * TM] = {0.0}; + float regN[WNITER * TN] = {0.0}; + + // outer-most loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + wt::loadFromGmem( + N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB); + __syncthreads(); + wt::processFromSmem(regM, regN, threadResults, As, Bs, warpRow, warpCol, + threadRowInWarp, threadColInWarp); + A += BK; // move BK columns to right + B += BK * N; // move BK rows down + __syncthreads(); + } + + // write out the results + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // move C pointer to current warp subtile + float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) { + // load C vector into registers + float4 tmp = reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0]; + // perform GEMM update in reg + const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + wSubColIdx * TN + resIdxN; + tmp.x = alpha * threadResults[i + 0] + beta * tmp.x; + tmp.y = alpha * threadResults[i + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[i + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[i + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0] = tmp; + } + } + } + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu new file mode 100644 index 0000000..7a5119a --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu @@ -0,0 +1,271 @@ +#include +#include +#include +#include +#include + +// Settings for A6000 +const uint NUM_THREADS = 128; +const uint BN = 128; // The threadblock size for N dimension SMEM caching. +const uint BM = 128; // The threadblock size for M dimension SMEM caching. +const uint BK = 16; // The threadblock size for K dimension SMEM caching. +const uint WN = 64; // N dim of continuous tile computed by each warp +const uint WM = 64; // M dim of continuous tile computed by each warp +const uint WNITER = 4; // The number of subwarp tiling steps in N dimension. +const uint TN = 4; // The per-thread tile size for N dimension. +const uint TM = 8; // The per-thread tile size for M dimension. +const int WARPSIZE = 32; // warpSize is not constexpr +const uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); // The number of subwarp tiling steps in M dimension. +const uint WSUBM = WM / WMITER; +const uint WSUBN = WN / WNITER; +const uint rowStrideA = (NUM_THREADS * 4) / BK; +const uint rowStrideB = NUM_THREADS / (BN / 4); +const uint NUM_WARPS = NUM_THREADS / WARPSIZE; + + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + + +/* +Half4 Helper Functions +*/ + +struct half4 { + half2 x, y; // Each half2 contains two half values +}; + +__device__ half4 loadHalf4(const half* address) { + half4 result; + result.x = *reinterpret_cast(address); + result.y = *reinterpret_cast(address + 2); + return result; +} + + +/* +Memory Operations +*/ + +namespace wt { + + /* + Load from Global Memory 4 items at a time + */ + + __device__ void loadFromGmem( + int N, int K, + const half *A, const int8_t *W, + half *sA, int8_t *sW, + int innerRowA, int innerColA, + int innerRowB, int innerColB + ) + { + // Load matrix A into shared memory sA + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const half4 tmp = loadHalf4(&A[(innerRowA + offset) * K + innerColA * 4]); + sA[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x.x; + sA[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.x.y; + sA[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.y.x; + sA[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.y.y; + } + + // Load 4 items matrix W into shared memory sW as packed 2-bit values + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + sW[(innerRowB + offset) * (BN / 4) + innerColB] = W[(innerRowB + offset) * (K / 4) + innerColB]; + } + } + + /* + Load from Shared Memory and Perform Computation + */ + + __device__ void processFromSmem( + half *regM, int8_t *regN, + float *threadResults, + const half *sA, const int8_t *sW, + const uint warpRow, const uint warpCol, + const uint threadRowInWarp, const uint threadColInWarp + ) + { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + + // Load sub-matrix of A into registers for a whole warptile + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = + sA[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + } + } + + // Unpack and load sub-matrix of sW into registers already translated + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + uint index = (warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i) >> 2; + int8_t packedWeights = sW[dotIdx * (BN / 4) + index]; + uint shift = ((warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i) & 0x03) * 2; + regN[wSubColIdx * TN + i] = (packedWeights >> shift) & 0x03; + } + } + + // execute warptile matmul + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // per-thread results + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + // loads the value and corresponiding weight + float inputVal = __half2float(regM[wSubRowIdx * TM + resIdxM]); + int8_t weight = regN[wSubColIdx * TN + resIdxN]; + // if 0b01, adds the activation + if (weight == 1) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += inputVal; + } + // if 0b10, subtracts the activation + else if (weight == 2) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] -= inputVal; + } + } + } + } + } + } + } + +} + +/* +CUDA Kernel for WarpTiling with the Bitlinear Implementation +*/ +__global__ void __launch_bounds__(NUM_THREADS) sgemmWarptiling( + const int M, const int N, const int K, + const half *A, const int8_t *W, half *C + ) +{ + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // Placement of the warp in the threadblock tile + const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in + const uint warpCol = warpIdx % (BN / WN); + const uint warpRow = warpIdx / (BN / WN); + + // Placement of the thread in the warp subtile + const uint threadIdxInWarp = threadIdx.x % WARPSIZE; + const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); + const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); + + // allocate space for the current blocktile in SMEM + __shared__ half sA[BM * BK]; + __shared__ int8_t sW[BK * BN / 4]; + + // Move blocktile to beginning of A's row and W's column + A += cRow * BM * K; + W += cCol * BN / 4; + // Move C_ptr to warp's output tile + C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; + + // calculating the indices that this thread will load into SMEM + // we'll load 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[WMITER * TM * WNITER * TN] = {0.0f}; + // we cache into registers on the warptile level + half regM[WMITER * TM] = {__float2half(0.0f)}; + int8_t regN[WNITER * TN] = {0}; + + // loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + + wt::loadFromGmem(N, K, A, W, sA, sW, innerRowA, innerColA, innerRowB, innerColB); + + __syncthreads(); + + wt::processFromSmem(regM, regN, threadResults, sA, sW, warpRow, warpCol, threadRowInWarp, threadColInWarp); + + A += BK; + W += BK * N / 4; + + __syncthreads(); + } + + // write out the results + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // move C pointer to current warp subtile + half *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 2) { + // load C vector into registers + half2 tmp = reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0]; + // perform GEMM update in reg + const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + wSubColIdx * TN + resIdxN; + tmp.x = __float2half(threadResults[i + 0] + __half2float(tmp.x)); + tmp.y = __float2half(threadResults[i + 1] + __half2float(tmp.y)); + // write back + reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0] = tmp; + } + } + } + } +} + + +/* +Callable Function that computes the "matmul" of A and W and stores in C +*/ +void matmul(int M, int N, int K, torch::Tensor A, torch::Tensor W, torch::Tensor C) { + + // warptile in threadblocktile + static_assert((BN % WN == 0) and (BM % WM == 0)); + static_assert((BN / WN) * (BM / WM) == NUM_WARPS); + + // threads in warpsubtile + static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); + constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); + + // warpsubtile in warptile + static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); + + static_assert((NUM_THREADS * 4) % BK == 0, + "NUM_THREADS*4 must be multiple of BK to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of Bs during each iteration)"); + static_assert((NUM_THREADS * 4) % BN == 0, + "NUM_THREADS*4 must be multiple of BN to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of As during each iteration)"); + static_assert(BN % (16 * TN) == 0, + "BN must be a multiple of 16*TN to avoid quantization effects"); + static_assert(BM % (16 * TM) == 0, + "BM must be a multiple of 16*TM to avoid quantization effects"); + static_assert((BM * BK) % (4 * NUM_THREADS) == 0, + "BM*BK must be a multiple of 4*256 to vectorize loads"); + static_assert((BN * BK) % (4 * NUM_THREADS) == 0, + "BN*BK must be a multiple of 4*256 to vectorize loads"); + + dim3 blockDim(NUM_THREADS); + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + + sgemmWarptiling<<>>( + M, N, K, + reinterpret_cast(A.data_ptr()), + W.data_ptr(), + reinterpret_cast(C.data_ptr()) + ); +} + +PYBIND11_MODULE(warptiling, m) { + m.def("matmul", &matmul, "Run SGEMM Warptiling with Half Precision (CUDA)"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh deleted file mode 100644 index f40311c..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cuh +++ /dev/null @@ -1,145 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) -const int WARPSIZE = 32; // warpSize is not constexpr - -namespace wt { -template -__device__ void loadFromGmem(int N, int K, const half *A, const half *B, - half *As, half *Bs, int innerRowA, int innerColA, - int innerRowB, int innerColB) { - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - const half2 tmp = reinterpret_cast( - &A[(innerRowA + offset) * K + innerColA * 2])[0]; - As[(innerColA * 2 + 0) * BM + innerRowA + offset] = tmp.x; - As[(innerColA * 2 + 1) * BM + innerRowA + offset] = tmp.y; - } - - for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { - reinterpret_cast( - &Bs[(innerRowB + offset) * BN + innerColB * 2])[0] = - reinterpret_cast( - &B[(innerRowB + offset) * N + innerColB * 2])[0]; - } -} - -template -__device__ void -processFromSmem(half *regM, half *regN, float *threadResults, const half *As, - const half *Bs, const uint warpRow, const uint warpCol, - const uint threadRowInWarp, const uint threadColInWarp) { - for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { - for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { - for (uint i = 0; i < TM; ++i) { - regM[wSubRowIdx * TM + i] = - As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + - threadRowInWarp * TM + i]; - } - } - for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { - for (uint i = 0; i < TN; ++i) { - regN[wSubColIdx * TN + i] = - Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + - threadColInWarp * TN + i]; - } - } - - for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { - for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { - for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { - for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { - threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + - (wSubColIdx * TN) + resIdxN] += - __half2float(regM[wSubRowIdx * TM + resIdxM]) * - __half2float(regN[wSubColIdx * TN + resIdxN]); - } - } - } - } - } -} - -} // namespace wt - -template -__global__ void __launch_bounds__(NUM_THREADS) - sgemmWarptiling(int M, int N, int K, float alpha, half *A, half *B, - float beta, half *C) { - const uint cRow = blockIdx.y; - const uint cCol = blockIdx.x; - - const uint warpIdx = threadIdx.x / WARPSIZE; - const uint warpCol = warpIdx % (BN / WN); - const uint warpRow = warpIdx / (BN / WN); - - constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); - constexpr uint WSUBM = WM / WMITER; - constexpr uint WSUBN = WN / WNITER; - - const uint threadIdxInWarp = threadIdx.x % WARPSIZE; - const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); - const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); - - __shared__ half As[BM * BK]; - __shared__ half Bs[BK * BN]; - - A += cRow * BM * K; - B += cCol * BN; - C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; - - const uint innerRowA = threadIdx.x / (BK / 2); - const uint innerColA = threadIdx.x % (BK / 2); - constexpr uint rowStrideA = (NUM_THREADS * 2) / BK; - const uint innerRowB = threadIdx.x / (BN / 2); - const uint innerColB = threadIdx.x % (BN / 2); - constexpr uint rowStrideB = NUM_THREADS / (BN / 2); - - float threadResults[WMITER * TM * WNITER * TN] = {0.0}; - half regM[WMITER * TM] = {0.0}; - half regN[WNITER * TN] = {0.0}; - - for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { - wt::loadFromGmem( - N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB); - __syncthreads(); - wt::processFromSmem(regM, regN, threadResults, As, Bs, warpRow, warpCol, - threadRowInWarp, threadColInWarp); - A += BK; - B += BK * N; - __syncthreads(); - } - - for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { - for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { - half *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; - for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { - for (uint resIdxN = 0; resIdxN < TN; resIdxN += 2) { - half2 tmp = reinterpret_cast( - &C_interim[(threadRowInWarp * TM + resIdxM) * N + - threadColInWarp * TN + resIdxN])[0]; - const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + - wSubColIdx * TN + resIdxN; - tmp.x = __float2half(alpha * threadResults[i + 0] + beta * __half2float(tmp.x)); - tmp.y = __float2half(alpha * threadResults[i + 1] + beta * __half2float(tmp.y)); - reinterpret_cast( - &C_interim[(threadRowInWarp * TM + resIdxM) * N + - threadColInWarp * TN + resIdxN])[0] = tmp; - } - } - } - } -} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh deleted file mode 100644 index a9e389a..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling_bitlinear.cuh +++ /dev/null @@ -1,107 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) -const int WARPSIZE = 32; // warpSize is not constexpr - -namespace wt { -template -__device__ void loadFromGmem(int N, int K, const half *A, const uint8_t *B, half *As, uint8_t *Bs, int innerRowA, int innerColA, int innerRowB, int innerColB) { - for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { - const float4 tmp = reinterpret_cast(&A[(innerRowA + offset) * K + innerColA * 4])[0]; - As[(innerColA * 4 + 0) * BM + innerRowA + offset] = __float2half(tmp.x); - As[(innerColA * 4 + 1) * BM + innerRowA + offset] = __float2half(tmp.y); - As[(innerColA * 4 + 2) * BM + innerRowA + offset] = __float2half(tmp.z); - As[(innerColA * 4 + 3) * BM + innerRowA + offset] = __float2half(tmp.w); - } - - for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { - reinterpret_cast(&Bs[(innerRowB + offset) * BN + innerColB])[0] = - reinterpret_cast(&B[(innerRowB + offset) * (N / 4) + innerColB])[0]; - } -} - -template -__device__ void processFromSmem(half *regM, uint8_t *regN, float *threadResults, const half *As, const uint8_t *Bs, const uint warpRow, const uint warpCol, const uint threadRowInWarp, const uint threadColInWarp) { - for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { - // populate registers for whole warptile - for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { - for (uint i = 0; i < TM; ++i) { - regM[wSubRowIdx * TM + i] = As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + threadRowInWarp * TM + i]; - } - } - for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { - for (uint i = 0; i < TN; ++i) { - regN[wSubColIdx * TN + i] = Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i]; - } - } - - // execute warptile matmul - for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { - for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { - // calculate per-thread results - for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { - for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { - // Unpack and process 2-bit weights - uint8_t weight = (regN[wSubColIdx * TN + resIdxN / 4] >> ((resIdxN % 4) * 2)) & 0x03; - float val = __half2float(regM[wSubRowIdx * TM + resIdxM]); - if (weight == 1) { - threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + (wSubColIdx * TN) + resIdxN] += val; - } else if (weight == 2) { - threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + (wSubColIdx * TN) + resIdxN] -= val; - } - } - } - } - } - } -} - -template -__global__ void sgemmWarptiling_bitlinear(const int M, const int N, const int K, const float alpha, const half *A, const uint8_t *B, const float beta, half *C) { - const int tid = threadIdx.x; - const int warpId = tid / 32; - const int laneId = tid % 32; - - const int warpRow = warpId / (BN / WN); - const int warpCol = warpId % (BN / WN); - const int threadRowInWarp = laneId / (TN / 4); - const int threadColInWarp = laneId % (TN / 4); - - __shared__ half As[BM * BK]; - __shared__ uint8_t Bs[BK * BN]; - - float threadResults[TM * TN] = {0}; - - for (int i = 0; i < K; i += BK) { - wt::loadFromGmem(N, K, A, B, As, Bs, warpRow, i, warpCol, i); - __syncthreads(); - - wt::processFromSmem(threadResults, As, Bs, warpRow, warpCol, threadRowInWarp, threadColInWarp); - __syncthreads(); - } - - for (int i = 0; i < TM; ++i) { - for (int j = 0; j < TN; ++j) { - const int globalRow = blockIdx.y * BM + warpRow * WM + threadRowInWarp * TM + i; - const int globalCol = blockIdx.x * BN + warpCol * WN + threadColInWarp * TN + j; - if (globalRow < M && globalCol < N) { - C[globalRow * N + globalCol] = __float2half(threadResults[i * TN + j] + __half2float(C[globalRow * N + globalCol])); - } - } - } -} -} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu deleted file mode 100644 index a34bd9e..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/runner.cu +++ /dev/null @@ -1,117 +0,0 @@ -#include -#include -#include "kernels.cuh" -#include -#include -#include -#include - - -#include -#include -#include - -void warptiling_bitlinear(int M, int N, int K, float alpha, torch::Tensor A, torch::Tensor B, float beta, torch::Tensor C) { - - // Settings for A6000 - const uint NUM_THREADS = 128; - const uint BN = 128; - const uint BM = 128; - const uint BK = 16; - const uint WN = 64; - const uint WM = 64; - const uint WNITER = 4; - const uint TN = 4; - const uint TM = 8; - dim3 blockDim(NUM_THREADS); - - constexpr uint NUM_WARPS = NUM_THREADS / 32; - - // warptile in threadblocktile - static_assert((BN % WN == 0) and (BM % WM == 0)); - static_assert((BN / WN) * (BM / WM) == NUM_WARPS); - - // threads in warpsubtile - static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); - constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); - // warpsubtile in warptile - static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); - - static_assert((NUM_THREADS * 4) % BK == 0, - "NUM_THREADS*4 must be multiple of BK to avoid quantization " - "issues during GMEM->SMEM tiling (loading only parts of the " - "final row of Bs during each iteration)"); - static_assert((NUM_THREADS * 4) % BN == 0, - "NUM_THREADS*4 must be multiple of BN to avoid quantization " - "issues during GMEM->SMEM tiling (loading only parts of the " - "final row of As during each iteration)"); - static_assert(BN % (16 * TN) == 0, - "BN must be a multiple of 16*TN to avoid quantization effects"); - static_assert(BM % (16 * TM) == 0, - "BM must be a multiple of 16*TM to avoid quantization effects"); - static_assert((BM * BK) % (4 * NUM_THREADS) == 0, - "BM*BK must be a multiple of 4*256 to vectorize loads"); - static_assert((BN * BK) % (4 * NUM_THREADS) == 0, - "BN*BK must be a multiple of 4*256 to vectorize loads"); - - dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); - - sgemmWarptiling_bitlinear - <<>>(M, N, K, alpha, A.data_ptr(), B.data_ptr(), beta, C.data_ptr()); -} - -void warptiling(int M, int N, int K, float alpha, torch::Tensor A, torch::Tensor B, float beta, torch::Tensor C) { - - // Settings for A6000 - const uint NUM_THREADS = 128; - const uint BN = 128; - const uint BM = 128; - const uint BK = 16; - const uint WN = 64; - const uint WM = 64; - const uint WNITER = 4; - const uint TN = 4; - const uint TM = 8; - dim3 blockDim(NUM_THREADS); - - constexpr uint NUM_WARPS = NUM_THREADS / 32; - - // warptile in threadblocktile - static_assert((BN % WN == 0) and (BM % WM == 0)); - static_assert((BN / WN) * (BM / WM) == NUM_WARPS); - - // threads in warpsubtile - static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); - constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); - // warpsubtile in warptile - static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); - - static_assert((NUM_THREADS * 4) % BK == 0, - "NUM_THREADS*4 must be multiple of BK to avoid quantization " - "issues during GMEM->SMEM tiling (loading only parts of the " - "final row of Bs during each iteration)"); - static_assert((NUM_THREADS * 4) % BN == 0, - "NUM_THREADS*4 must be multiple of BN to avoid quantization " - "issues during GMEM->SMEM tiling (loading only parts of the " - "final row of As during each iteration)"); - static_assert(BN % (16 * TN) == 0, - "BN must be a multiple of 16*TN to avoid quantization effects"); - static_assert(BM % (16 * TM) == 0, - "BM must be a multiple of 16*TM to avoid quantization effects"); - static_assert((BM * BK) % (4 * NUM_THREADS) == 0, - "BM*BK must be a multiple of 4*256 to vectorize loads"); - static_assert((BN * BK) % (4 * NUM_THREADS) == 0, - "BN*BK must be a multiple of 4*256 to vectorize loads"); - - dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); - - sgemmWarptiling - <<>>(M, N, K, alpha, A.data_ptr(), B.data_ptr(), beta, C.data_ptr()); -} - -PYBIND11_MODULE(warptiling, m) { - m.def("warptiling_linear", &warptiling, "Run SGEMM Warptiling with Half Precision (CUDA)"); - m.def("warptiling_bitlinear", &warptiling_bitlinear, "Run SGEMM Warptiling with Half Precision (CUDA) and packed weights"); -} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py index 058235a..4fb8b6a 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py @@ -10,7 +10,8 @@ ext_modules=[ CUDAExtension( name='warptiling', - sources=['runner.cu'] + sources=['kernels/warptiling.cu' + ] ) ], cmdclass={ diff --git a/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu index c7e22f9..c1c15c4 100644 --- a/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu +++ b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu @@ -49,8 +49,8 @@ torch::Tensor packedint8( // Calculate size for packed weights tensor int packed_size = (n * k + 3) / 4; // 4 weights per int8 - auto packed_weights_int32 = torch::zeros({packed_size}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_weights_int8 = torch::empty({packed_size}, torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA)); const unsigned int block_size = 32; const unsigned int grid_rows = (n + block_size - 1) / block_size; @@ -58,6 +58,7 @@ torch::Tensor packedint8( dim3 dimGrid(grid_cols, grid_rows); dim3 dimBlock(block_size, block_size); + unsigned int grid_size = (packed_size + block_size - 1) / block_size; pack_weights_kernel<<>>( reinterpret_cast(weights.data_ptr()), @@ -65,29 +66,14 @@ torch::Tensor packedint8( n, k ); - - cudaError_t err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); - - cudaDeviceSynchronize(); - - // Allocate final packed weights as int8 - auto packed_weights_int8 = torch::empty({packed_size}, torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA)); - - // Calculate the number of blocks needed for the second kernel - unsigned int grid_size = (packed_size + block_size - 1) / block_size; - - // Launch the second kernel to cast int32 to int8 int32_to_int8_kernel<<>>( packed_weights_int32.data_ptr(), packed_weights_int8.data_ptr(), packed_size ); - err = cudaGetLastError(); - TORCH_CHECK(err == cudaSuccess, "CUDA kernel (int32_to_int8_kernel) failed with error: ", cudaGetErrorString(err)); - - cudaDeviceSynchronize(); + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); return packed_weights_int8; } diff --git a/bitlinear/frozen_bitlinear/scripts/build.sh b/bitlinear/frozen_bitlinear/scripts/build.sh index 3cb7419..7b5134c 100755 --- a/bitlinear/frozen_bitlinear/scripts/build.sh +++ b/bitlinear/frozen_bitlinear/scripts/build.sh @@ -12,14 +12,6 @@ for dir in "${Kernel_DIR}/"*/; do if [ -f "${dir}setup.py" ]; then echo "Building in directory: $dir" (cd "$dir" && python setup.py build_ext --inplace) - else - # If `setup.py` is not in the first level, check subdirectories - for subdir in "${dir}"*/; do - if [ -f "${subdir}setup.py" ]; then - echo "Building in subdirectory: $subdir" - (cd "$subdir" && python setup.py build_ext --inplace) - fi - done fi done diff --git a/bitlinear/frozen_bitlinear/src/__init__.py b/bitlinear/frozen_bitlinear/src/__init__.py index 76951f7..1756ef7 100644 --- a/bitlinear/frozen_bitlinear/src/__init__.py +++ b/bitlinear/frozen_bitlinear/src/__init__.py @@ -1,2 +1,4 @@ from src.default import TorchLinear -from src.anthropic import Naive +from src.naive import Naive +from src.warptiling import Warptiling + diff --git a/bitlinear/frozen_bitlinear/src/anthropic.py b/bitlinear/frozen_bitlinear/src/naive.py similarity index 75% rename from bitlinear/frozen_bitlinear/src/anthropic.py rename to bitlinear/frozen_bitlinear/src/naive.py index 8683f37..7ce0c29 100644 --- a/bitlinear/frozen_bitlinear/src/anthropic.py +++ b/bitlinear/frozen_bitlinear/src/naive.py @@ -1,7 +1,7 @@ from src.utils import Packed8, naive -class Anthropic(Packed8): - fxn = lambda args: NotImplementedError +class Naive(Packed8): + fxn = naive.linear def __call__(self, activations, weights, bias, scale): # Check constraints. @@ -11,13 +11,10 @@ def __call__(self, activations, weights, bias, scale): M, K = activations.shape N = weights.shape[0] * 4//K - x_quant, x_scale = self.activations(input) + x_quant, x_scale = self.activations(activations) return self.fxn(x_quant, weights, bias, M, N, K) * scale * x_scale -class Naive(Anthropic): - fxn = naive.linear - diff --git a/bitlinear/frozen_bitlinear/src/triton.py b/bitlinear/frozen_bitlinear/src/triton.py deleted file mode 100644 index 431a14a..0000000 --- a/bitlinear/frozen_bitlinear/src/triton.py +++ /dev/null @@ -1,154 +0,0 @@ -# import torch -# import triton -# import triton.language as tl -# from .utils import Kernel, get_cuda_autotune_config - -# ###### Baseline ###### -# class default(Kernel): - -# def __call__(self, activations, weights, bias, scale): -# # Check constraints. -# assert activations.shape[1] == weights.shape[1], "Incompatible dimensions" -# assert activations.is_contiguous(), "Matrix A must be contiguous" -# assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" - -# M, K = activations.shape -# N = weights.shape[0] - -# weights = weights.transpose(0, 1) - -# # Allocates output. -# output = torch.empty((M, N), device=activations.device, dtype=torch.float16) - -# # 1D launch kernel where each block gets its own program. -# grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) - -# kernel[grid]( -# activations, weights, output, bias, # -# M, N, K, # -# activations.stride(0), activations.stride(1), # -# weights.stride(0), weights.stride(1), # -# output.stride(0), output.stride(1), # -# bias.stride(0) -# ) -# return output * scale - - -# def get_cuda_autotune_config(): -# return [ -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, -# num_warps=8), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, -# num_warps=2), -# triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, -# num_warps=2), -# # Good config for fp8 inputs. -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, -# num_warps=8), -# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, -# num_warps=8), -# triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4), -# triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, -# num_warps=4) -# ] - -# @triton.autotune( -# configs=get_cuda_autotune_config(), -# key=['M', 'N', 'K'], -# ) -# @triton.jit -# def kernel( -# # Pointers to matrices -# activation_ptr, weights_ptr, ouput_ptr, bias_ptr, -# # Matrix dimensions -# M, N, K, -# # The stride variables represent how much to increase the ptr by when moving by 1 -# # element in a particular dimension. E.g. `stride_am` is how much to increase `activation_ptr` -# # by to get the element one row down (A has M rows). -# stride_am, stride_ak, -# stride_bk, stride_bn, -# stride_cm, stride_cn, -# stride_dm, -# # Meta-parameters -# BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, -# GROUP_SIZE_M: tl.constexpr -# ): -# """Kernel for computing the matmul C = A x B + D -# A has shape (M, K), B has shape (K, N) and D has shape (M, 1) -# Output C has shape (M, N) -# """ -# # ----------------------------------------------------------- -# # Map program ids `pid` to the block of C it should compute. -# # This is done in a grouped ordering to promote L2 data reuse. - -# pid = tl.program_id(axis=0) - -# num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) -# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) -# num_pid_in_group = GROUP_SIZE_M * num_pid_n -# group_id = pid // num_pid_in_group -# first_pid_m = group_id * GROUP_SIZE_M -# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) -# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) -# pid_n = (pid % num_pid_in_group) // group_size_m - -# # ---------------------------------------------------------- -# # Create pointers for the first blocks of A and B. -# # We will advance this pointer as we move in the K direction -# # and accumulate -# offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M -# offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N -# offs_k = tl.arange(0, BLOCK_SIZE_K) -# activation_ptrs = activation_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) -# weights_ptrs = weights_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - -# # ----------------------------------------------------------- -# # Iterate to compute a block of the C matrix. -# # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block -# # of fp32 values for higher accuracy. -# accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) -# for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): -# # Load the next block of A and B, generate a mask by checking the K dimension. -# # If it is out of bounds, set it to 0. -# inputs = tl.load(activation_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) -# weights = tl.load(weights_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) -# # We accumulate along the K dimension. -# accumulator += tl.dot(inputs, weights) -# # Advance the ptrs to the next K block. -# activation_ptrs += BLOCK_SIZE_K * stride_ak -# weights_ptrs += BLOCK_SIZE_K * stride_bk - -# # Add bias D to the accumulated result -# bias_ptrs = bias_ptr + offs_am[:, None] * stride_dm -# bias = tl.load(bias_ptrs) -# accumulator += bias - -# output = accumulator.to(tl.float16) - -# # ----------------------------------------------------------- -# # Write back the block of the output matrix C with masks. -# offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) -# offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) -# ouput_ptrs = ouput_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] -# ouput_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) -# tl.store(ouput_ptrs, output, mask=ouput_mask) - \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/src/utils/Kernel.py b/bitlinear/frozen_bitlinear/src/utils/Kernel.py index 77dbdcc..c77b452 100644 --- a/bitlinear/frozen_bitlinear/src/utils/Kernel.py +++ b/bitlinear/frozen_bitlinear/src/utils/Kernel.py @@ -5,7 +5,7 @@ from cuda.pack_weights import pack_weights class Kernel: - def __init__(self, eps=1e-5, activation_range = 8, activation_measure = 'AbsMax'): + def __init__(self, eps=1e-5, activation_range=8, activation_measure = 'AbsMax'): self.eps = eps self.activations = eval(activation_measure)(activation_range, eps) diff --git a/bitlinear/frozen_bitlinear/src/utils/helpers.py b/bitlinear/frozen_bitlinear/src/utils/helpers.py index ec27f86..988c2d4 100644 --- a/bitlinear/frozen_bitlinear/src/utils/helpers.py +++ b/bitlinear/frozen_bitlinear/src/utils/helpers.py @@ -6,11 +6,11 @@ Activation Quantization Helpers ''' -def symmetric_range_from_bits(self, range): +def symmetric_range_from_bits(range): return (ceil(-2**(range-1)), ceil(2**(range-1)-1)) -def round_clamp(self, input): - return (input.round().clamp(self.range[0], self.range[1]) - input).detach() + input +def round_clamp(input, range): + return (input.round().clamp(range[0], range[1]) - input).detach() + input class ActivationMeasure: def __init__(self, range=8, eps=1e-5): @@ -20,7 +20,7 @@ def __init__(self, range=8, eps=1e-5): def __call__(self, input): x_norm = torch.layer_norm(input, input.size()[1:]) x_scale = self.scale(x_norm) - return round_clamp(x_norm/x_scale), x_scale + return round_clamp(x_norm/x_scale, self.range), x_scale def scale(self, input) -> torch.Tensor: raise NotImplementedError @@ -37,7 +37,7 @@ class Fp16(ActivationMeasure): def __init__(self, range=8, eps=1e-5): pass def __call__(self, input) -> tuple[torch.Tensor, torch.Tensor]: - return input, torch.Tensor(1.0) + return input, torch.tensor([[1.0]], device='cuda', dtype=torch.float16) class AbsMax(ActivationMeasure): def scale(self, input) -> torch.Tensor: diff --git a/bitlinear/frozen_bitlinear/src/warptiling.py b/bitlinear/frozen_bitlinear/src/warptiling.py new file mode 100644 index 0000000..17a68a8 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/warptiling.py @@ -0,0 +1,21 @@ +import torch +from src.utils import Packed8, warptiling + +class Warptiling(Packed8): + + def __call__(self, activations, weights, bias, scale): + # Check constraints. + assert activations.is_contiguous(), "Matrix A must be contiguous" + assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" + + M, K = activations.shape + N = weights.shape[0] * 4//K + + x_quant, x_scale = self.activations(activations) + output = torch.zeros((M, N), device='cuda', dtype = torch.float16) + + # print(output, x_quant, weights, bias) + + warptiling.matmul(M, N, K, x_quant, weights, output) + + return (output + bias) * scale * x_scale \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/tests/Benchmark.py b/bitlinear/frozen_bitlinear/tests/Benchmark.py index 93a4ed5..c2c8084 100644 --- a/bitlinear/frozen_bitlinear/tests/Benchmark.py +++ b/bitlinear/frozen_bitlinear/tests/Benchmark.py @@ -27,7 +27,7 @@ def __init__(self, args): self.kernel_name = args.kernel self.kernel = eval(self.kernel_name)(activation_measure='Fp16') - self.baseline = TorchLinear() + self.baseline = TorchLinear(activation_measure='Fp16') print(f'Testing with {self.kernel_name} kernel.') self.ref_lib = 'cuBLAS' From 63527d713380de3e749c0444b744dbc58ad70132 Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Wed, 7 Aug 2024 10:35:32 +0200 Subject: [PATCH 08/10] New method of packing weights to utilize vectorized loads --- .../blocktiling/kernels/blocktiling.cu | 163 ++++++++++++++++++ .../kernels/blocktiling/kernels/original.cu | 80 +++++++++ .../blocktiling/kernels/pack_weights.cu | 78 +++++++++ .../cuda/kernels/blocktiling/setup.py | 19 ++ .../kernels/warptiling/kernels/warptiling.cu | 1 - .../cuda/kernels/warptiling/setup.py | 3 +- 6 files changed, 341 insertions(+), 3 deletions(-) create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu new file mode 100644 index 0000000..e5daa1e --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu @@ -0,0 +1,163 @@ +#include +#include +#include +#include +#include + +// A600 paramters +const uint BM = 64; // The threadblock size for M dimension SMEM caching. +const uint BN = 64; // The threadblock size for N dimension SMEM caching. +const uint BK = 8; // The threadblock size for K dimension SMEM caching. +const uint TM = 8; // The per-thread tile size for M dimension. + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +/* +Helper Functions for Vectorized Loads +*/ +typedef struct { + half2 data[4]; +} half8; + +typedef struct { + char2 data[1]; +} int8_2; + +__device__ void loadfromGMEM( + int N, int K, + const half *A, const int8_t *W, + half *sA, int8_t *sW, + int innerRowA, int innerColA, + int innerRowB, int innerColB + ) { + + // Get the index of the shared_memory + int idx = ...; + int memIdx = ...; + + // Load 8 half values (128 bits) from global memory + const half8 tmp = __ldg(&A[memIdx]); + + // Store the 8 half values into shared memory + sA[8 * idx] = __low2half(tmp.data[0]); + sA[8 * idx + 1] = __high2half(tmp.data[0]); + sA[8 * idx + 2] = __low2half(tmp.data[1]); + sA[8 * idx + 3] = __high2half(tmp.data[1]); + sA[8 * idx + 4] = __low2half(tmp.data[2]); + sA[8 * idx + 5] = __high2half(tmp.data[2]); + sA[8 * idx + 6] = __low2half(tmp.data[3]); + sA[8 * idx + 7] = __high2half(tmp.data[3]); + + + + + + // Synchronize to ensure all threads have written their data + __syncthreads(); + +} + + +namespace btm { + + + + +} + + + +__global__ void sgemm1DBlocktiling( + int M, int N, int K, + const half *A, + const int8_t *W, + const half *bias, + const half *scale, + half *output + ) { + + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // each warp will calculate 32*TM elements, with 32 being the columnar dim. + const int threadCol = threadIdx.x % BN; + const int threadRow = threadIdx.x / BN; + + // allocate space for the current blocktile in SMEM + __shared__ half As[BM * BK]; + __shared__ half Ws[BN * BK]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // todo: adjust this to each thread to load multiple entries and + // better exploit the cache sizes + assert(BM * BK == blockDim.x); + assert(BN * BK == blockDim.x); + + const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing + const uint innerRowA = threadIdx.x / BK; + const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing + const uint innerRowB = threadIdx.x / BN; + + // allocate thread-local cache for results in registerfile + float threadResults[TM] = {0.0}; + + // outer loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; + Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; + __syncthreads(); + + // advance blocktile + A += BK; + B += BK * N; + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // we make the dotproduct loop the outside loop, which facilitates + // reuse of the Bs entry, which we can cache in a tmp var. + float tmpB = Bs[dotIdx * BN + threadCol]; + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + threadResults[resIdx] += + As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + output[(threadRow * TM + resIdx) * N + threadCol] = threadResults[resIdx] + output[(threadRow * TM + resIdx) * N + threadCol]; + } + +} + +torch::Tensor linear( + int M, int N, int K, + half *A, + int8_t *W , + half *bias, + half *scale + ) { + + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / TM); + + sgemmDBlocktiling<<>>( + M, N, K, + reinterpret_cast(A.data_ptr()), + W.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(scale.data_ptr()) + reinterpret_cast(output.data_ptr()) + ); + + return output; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu new file mode 100644 index 0000000..12c9c5d --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm1DBlocktiling(int M, int N, int K, float alpha, + const float *A, const float *B, float beta, + float *C) { + // If we flip x and y here we get ~30% less performance for large matrices. + // The current, 30% faster configuration ensures that blocks with sequential + // blockIDs access columns of B sequentially, while sharing the same row of A. + // The slower configuration would share columns of A, but access into B would + // be non-sequential. So the faster configuration has better spatial locality + // and hence a greater L2 hit rate. + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // each warp will calculate 32*TM elements, with 32 being the columnar dim. + const int threadCol = threadIdx.x % BN; + const int threadRow = threadIdx.x / BN; + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // todo: adjust this to each thread to load multiple entries and + // better exploit the cache sizes + assert(BM * BK == blockDim.x); + assert(BN * BK == blockDim.x); + const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing + const uint innerRowA = threadIdx.x / BK; + const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing + const uint innerRowB = threadIdx.x / BN; + + // allocate thread-local cache for results in registerfile + float threadResults[TM] = {0.0}; + + // outer loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; + Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; + __syncthreads(); + + // advance blocktile + A += BK; + B += BK * N; + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // we make the dotproduct loop the outside loop, which facilitates + // reuse of the Bs entry, which we can cache in a tmp var. + float tmpB = Bs[dotIdx * BN + threadCol]; + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + threadResults[resIdx] += + As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + C[(threadRow * TM + resIdx) * N + threadCol] = + alpha * threadResults[resIdx] + + beta * C[(threadRow * TM + resIdx) * N + threadCol]; + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu new file mode 100644 index 0000000..78661ea --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu @@ -0,0 +1,78 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +__global__ void pack_weights_kernel(const half* weights, short* packed_weights, int n, int k) { + + int row = blockIdx.y * blockDim.y + threadIdx.y; + int colStart = blockIdx.x * 8; + + if (row < n && colStart < k) { + + short packed = 0; + + for (int index = 0; index < 8; ++index) { + if (colStart + index < k) { + float weight_value = __half2float(weights[row * k + colStart + index]); + short bit_mask = (weight_value == 1.0f) ? (1 << (index * 2)) : + (weight_value == -1.0f) ? (2 << (index * 2)) : 0; + packed |= bit_mask; + } + } + + int packed_index = blockIdx.x * n + row; + packed_weights[packed_index] = packed; + } +} + +/* +In this implementation we have 8 weights in the K direction, which are then ordered in +column major order ([0, 0-7], [1, 0-7], ... [(N-1), 0-7], [0, 8-15], ...) +This comes from a NxK Weight Matrix +*/ +torch::Tensor packed_K8_column_major(torch::Tensor weights) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + int n = weights.size(0); + int k = weights.size(1); + + TORCH_CHECK(k % 8 == 0, "K must be divisible by 8"); + + int packed_size = CEIL_DIV(n * k, 8); + auto packed_weights = torch::zeros({packed_size}, torch::TensorOptions().dtype(torch::kInt16).device(torch::kCUDA)); + + const unsigned int block_size = 128; + const unsigned int grid_rows = CEIL_DIV(n, block_size); + const unsigned int grid_cols = CEIL_DIV(k, 8); + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(1, block_size); + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights.data_ptr(), + n, + k + ); + + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + return packed_weights; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packed_K8_column_major", &packed_K8_column_major, "Pack fp16 weights into int16 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py new file mode 100644 index 0000000..fbb543f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='warptiling', + ext_modules=[ + CUDAExtension( + name='warptiling', + sources=['kernels/warptiling.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu index 7a5119a..b89421a 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu @@ -22,7 +22,6 @@ const uint rowStrideA = (NUM_THREADS * 4) / BK; const uint rowStrideB = NUM_THREADS / (BN / 4); const uint NUM_WARPS = NUM_THREADS / WARPSIZE; - #define CEIL_DIV(x, y) ((x) + (y)-1) / (y) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py index 4fb8b6a..fbb543f 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py @@ -10,8 +10,7 @@ ext_modules=[ CUDAExtension( name='warptiling', - sources=['kernels/warptiling.cu' - ] + sources=['kernels/warptiling.cu'] ) ], cmdclass={ From 51aa1e1d3289d9518001d95573185007459db31c Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Wed, 7 Aug 2024 12:38:42 +0200 Subject: [PATCH 09/10] Blocktiling -sopsahl --- .../cuda/kernels/blocktiling/kernels/blocktiling.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu index e5daa1e..db7edae 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu @@ -16,7 +16,7 @@ const uint TM = 8; // The per-thread tile size for M dimension. Helper Functions for Vectorized Loads */ typedef struct { - half2 data[4]; + half data[8]; } half8; typedef struct { @@ -25,7 +25,7 @@ typedef struct { __device__ void loadfromGMEM( int N, int K, - const half *A, const int8_t *W, + const half *A, const int16_t *W, half *sA, int8_t *sW, int innerRowA, int innerColA, int innerRowB, int innerColB @@ -49,7 +49,9 @@ __device__ void loadfromGMEM( sA[8 * idx + 7] = __high2half(tmp.data[3]); - + // Load 128 bits from the int16_t array (8 int16_t values) + short8 int16Data = reinterpret_cast(W)[...]; + int16Result[idx] = int16Data; // Synchronize to ensure all threads have written their data From 350063dc88d9d232fe212696d9cbe7b8f043b45a Mon Sep 17 00:00:00 2001 From: Simon Opsahl Date: Thu, 8 Aug 2024 16:09:52 +0200 Subject: [PATCH 10/10] Preliminary Optimizing Memory Access Patterns --- .../blocktiling/kernels/blocktiling.cu | 165 ------------- .../kernels/blocktiling/kernels/kernel.cu | 230 ++++++++++++++++++ .../1dblocktiling.cu} | 0 .../blocktiling/kernels/reference/original.cu | 80 ++++++ .../pack_weights16.cu} | 0 .../kernels/reference/pack_weights32.cu | 84 +++++++ .../kernels/reference/shared_mem.cu | 57 +++++ .../kernels/reference/vectorized.cu | 122 ++++++++++ .../cuda/kernels/naive_linear/naive.cu | 1 - .../kernels/warptiling/kernels/warptiling.cu | 1 - 10 files changed, 573 insertions(+), 167 deletions(-) delete mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu rename bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/{original.cu => reference/1dblocktiling.cu} (100%) create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu rename bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/{pack_weights.cu => reference/pack_weights16.cu} (100%) create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu create mode 100644 bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu deleted file mode 100644 index db7edae..0000000 --- a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/blocktiling.cu +++ /dev/null @@ -1,165 +0,0 @@ -#include -#include -#include -#include -#include - -// A600 paramters -const uint BM = 64; // The threadblock size for M dimension SMEM caching. -const uint BN = 64; // The threadblock size for N dimension SMEM caching. -const uint BK = 8; // The threadblock size for K dimension SMEM caching. -const uint TM = 8; // The per-thread tile size for M dimension. - -#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) - -/* -Helper Functions for Vectorized Loads -*/ -typedef struct { - half data[8]; -} half8; - -typedef struct { - char2 data[1]; -} int8_2; - -__device__ void loadfromGMEM( - int N, int K, - const half *A, const int16_t *W, - half *sA, int8_t *sW, - int innerRowA, int innerColA, - int innerRowB, int innerColB - ) { - - // Get the index of the shared_memory - int idx = ...; - int memIdx = ...; - - // Load 8 half values (128 bits) from global memory - const half8 tmp = __ldg(&A[memIdx]); - - // Store the 8 half values into shared memory - sA[8 * idx] = __low2half(tmp.data[0]); - sA[8 * idx + 1] = __high2half(tmp.data[0]); - sA[8 * idx + 2] = __low2half(tmp.data[1]); - sA[8 * idx + 3] = __high2half(tmp.data[1]); - sA[8 * idx + 4] = __low2half(tmp.data[2]); - sA[8 * idx + 5] = __high2half(tmp.data[2]); - sA[8 * idx + 6] = __low2half(tmp.data[3]); - sA[8 * idx + 7] = __high2half(tmp.data[3]); - - - // Load 128 bits from the int16_t array (8 int16_t values) - short8 int16Data = reinterpret_cast(W)[...]; - int16Result[idx] = int16Data; - - - // Synchronize to ensure all threads have written their data - __syncthreads(); - -} - - -namespace btm { - - - - -} - - - -__global__ void sgemm1DBlocktiling( - int M, int N, int K, - const half *A, - const int8_t *W, - const half *bias, - const half *scale, - half *output - ) { - - const uint cRow = blockIdx.y; - const uint cCol = blockIdx.x; - - // each warp will calculate 32*TM elements, with 32 being the columnar dim. - const int threadCol = threadIdx.x % BN; - const int threadRow = threadIdx.x / BN; - - // allocate space for the current blocktile in SMEM - __shared__ half As[BM * BK]; - __shared__ half Ws[BN * BK]; - - // Move blocktile to beginning of A's row and B's column - A += cRow * BM * K; - B += cCol * BN; - C += cRow * BM * N + cCol * BN; - - // todo: adjust this to each thread to load multiple entries and - // better exploit the cache sizes - assert(BM * BK == blockDim.x); - assert(BN * BK == blockDim.x); - - const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing - const uint innerRowA = threadIdx.x / BK; - const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing - const uint innerRowB = threadIdx.x / BN; - - // allocate thread-local cache for results in registerfile - float threadResults[TM] = {0.0}; - - // outer loop over block tiles - for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { - // populate the SMEM caches - As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; - Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; - __syncthreads(); - - // advance blocktile - A += BK; - B += BK * N; - - // calculate per-thread results - for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { - // we make the dotproduct loop the outside loop, which facilitates - // reuse of the Bs entry, which we can cache in a tmp var. - float tmpB = Bs[dotIdx * BN + threadCol]; - for (uint resIdx = 0; resIdx < TM; ++resIdx) { - threadResults[resIdx] += - As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; - } - } - __syncthreads(); - } - - // write out the results - for (uint resIdx = 0; resIdx < TM; ++resIdx) { - output[(threadRow * TM + resIdx) * N + threadCol] = threadResults[resIdx] + output[(threadRow * TM + resIdx) * N + threadCol]; - } - -} - -torch::Tensor linear( - int M, int N, int K, - half *A, - int8_t *W , - half *bias, - half *scale - ) { - - auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); - auto output = torch::zeros({M, N}, options); - - dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); - dim3 blockDim((BM * BN) / TM); - - sgemmDBlocktiling<<>>( - M, N, K, - reinterpret_cast(A.data_ptr()), - W.data_ptr(), - reinterpret_cast(bias.data_ptr()), - reinterpret_cast(scale.data_ptr()) - reinterpret_cast(output.data_ptr()) - ); - - return output; -} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu new file mode 100644 index 0000000..556b37c --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +const uint NUM_THREADS = 256; // Number of Threads per block +const uint WARP_SIZE = 32; // Fixed in hardware +const uint NUM_WARPS = NUM_THREADS/WARP_SIZE; // 8 + +const uint TN = 8; // the number of elements calculated in the n dimension per thread +const uint TM = 8; // the number of elements calculated in the m dimension per thread loop + +// Each warp will access 256 elements of A each in batches of 8 +const uint STRIDE_Ak = 8; // The stride for loading A (in K direction) +const uint STRIDE_Wn = 8; // The stride for loading W (in N direction) + +const uint BN = 2048; // NUM_THREADS*STRIDE_Wk*8, the threadblock size for N dimension W SMEM caching. +const uint BK = 512; // The threadblock size for K dimension A SMEM caching. +const uint BM = 4; // NUM_WARPS, The number of rows to calculate per block + +const uint WEIGHTSper32 = 16; +const uint HALVESper32 = 2; + +namespace MemAccess { + typedef struct { + const __half data[8] + } half8; + + __device__ void loadAs( + int K, + const __half *A, half2 *As, + int innerColA, int innerRowA + ) { + + // Load a half8 from A + const half8 tmp = reinterpret_cast(&A[innerRowA * K + innerColA* STRIDE_Ak])[0]; + + // Convert half8 data to half2, transpose it and store in As + As[(innerColA * 4 + 0) * BM + innerRowA] = __halves2half2(tmp.data[1], tmp.data[0]); + As[(innerColA * 4 + 1) * BM + innerRowA] = __halves2half2(tmp.data[3], tmp.data[2]); + As[(innerColA * 4 + 2) * BM + innerRowA] = __halves2half2(tmp.data[5], tmp.data[4]); + As[(innerColA * 4 + 3) * BM + innerRowA] = __halves2half2(tmp.data[7], tmp.data[6]); + + } + + __device__ void loadWs( + const short *W, short *Ws, + int innerRowW + ) { + + // Load an int4 from W + const short8 tmp = reinterpret_cast(W)[innerRowW*STRIDE_Wn]; + + // Convert int4 data to int16_t and store in Ws + Ws[innerRowW*STRIDE_Wn] = tmp[0]; + Ws[innerRowW*STRIDE_Wn + 1] = tmp[1]; + Ws[innerRowW*STRIDE_Wn + 2] = tmp[2]; + Ws[innerRowW*STRIDE_Wn + 3] = tmp[3]; + Ws[innerRowW*STRIDE_Wn + 4] = tmp[4]; + Ws[innerRowW*STRIDE_Wn + 5] = tmp[5]; + Ws[innerRowW*STRIDE_Wn + 6] = tmp[6]; + Ws[innerRowW*STRIDE_Wn + 7] = tmp[7]; + + } + + __device__ void load_regA( + const half2 *As, float *regA, + int index + ) { + + const half8 tmp = reinterpret_cast(&As[index*BM])[0]; + + // loading registers with the 8x2 activations from As + + for (int i=0; i( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0]; + // perform GEMM update in reg + tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x; + tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] = + tmp; + } + } +} + +void runSgemmVectorize(Torch:Tensor *A, int *B, + float beta, float *C) { + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 and N >= 128) { + const uint BM = 128; + const uint BN = 128; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } else { + // this is a hacky solution to the underlying problem + // of not having proper bounds checking in the kernel + const uint BM = 64; + const uint BN = 64; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } +} + + +torch::Tensor linear( + torch::Tensor *A, + torch::Tensor *W , + half *bias, + half *scale + ) { + + int M = A.size(0); + int K = A.size(1); + int N = W.size(0); + + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / TM); + + sgemmDBlocktiling<<>>( + M, N, K, + reinterpret_cast(A.data_ptr()), + W.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(scale.data_ptr()) + reinterpret_cast(output.data_ptr()) + ); + + return output; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/1dblocktiling.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/original.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/1dblocktiling.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu new file mode 100644 index 0000000..12c9c5d --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm1DBlocktiling(int M, int N, int K, float alpha, + const float *A, const float *B, float beta, + float *C) { + // If we flip x and y here we get ~30% less performance for large matrices. + // The current, 30% faster configuration ensures that blocks with sequential + // blockIDs access columns of B sequentially, while sharing the same row of A. + // The slower configuration would share columns of A, but access into B would + // be non-sequential. So the faster configuration has better spatial locality + // and hence a greater L2 hit rate. + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // each warp will calculate 32*TM elements, with 32 being the columnar dim. + const int threadCol = threadIdx.x % BN; + const int threadRow = threadIdx.x / BN; + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // todo: adjust this to each thread to load multiple entries and + // better exploit the cache sizes + assert(BM * BK == blockDim.x); + assert(BN * BK == blockDim.x); + const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing + const uint innerRowA = threadIdx.x / BK; + const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing + const uint innerRowB = threadIdx.x / BN; + + // allocate thread-local cache for results in registerfile + float threadResults[TM] = {0.0}; + + // outer loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; + Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; + __syncthreads(); + + // advance blocktile + A += BK; + B += BK * N; + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // we make the dotproduct loop the outside loop, which facilitates + // reuse of the Bs entry, which we can cache in a tmp var. + float tmpB = Bs[dotIdx * BN + threadCol]; + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + threadResults[resIdx] += + As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + C[(threadRow * TM + resIdx) * N + threadCol] = + alpha * threadResults[resIdx] + + beta * C[(threadRow * TM + resIdx) * N + threadCol]; + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights16.cu similarity index 100% rename from bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/pack_weights.cu rename to bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights16.cu diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu new file mode 100644 index 0000000..580d9b7 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +__global__ void pack_weights_kernel(const __half* weights, int* packed_weights, int n, int k) { + + int rowFactor = blockIdx.y * blockDim.y + threadIdx.y + int row = rowFactor * 2; + int colStart = blockIdx.x * 8; + + if (row < n && colStart < k) { + + int packed = 0; + + for (innerRow=0, innerRow<2, ++innerRow) { + + for (int index=0; index<8; ++index) { + + if (colStart + index < k && innerRow + row < n) { + + float weight_value = __half2float(weights[(row + innerRow) * k + colStart + index]); + int bit_mask = (weight_value == 1.0f) ? (1 << (innerRow*16 + index*2)) : + (weight_value == -1.0f) ? (2 << (innerRow*16 + index*2)) : 0; + + packed |= bit_mask; + } + } + } + + int packed_index = blockIdx.x * n + row / 2; + packed_weights[packed_index] = packed; + } +} + +/* +In this implementation we have 16 weights in the K direction, which are then ordered in +column major order ([0, 0-15], [1, 0-15], ... [(N-1), 0-15], [0, 16-31], ...) +This comes from a NxK Weight Matrix +*/ +torch::Tensor packed_K16_row_major(torch::Tensor weights) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + int n = weights.size(0); + int k = weights.size(1); + + TORCH_CHECK(k % 8 == 0, "K must be divisible by 8"); + + auto packed_weights = torch::zeros({n/2, k/8}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + + const unsigned int block_size = 64; + const unsigned int grid_rows = CEIL_DIV(n/2, block_size); + const unsigned int grid_cols = CEIL_DIV(k, 8); + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(1, block_size); + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights.data_ptr(), + n, + k + ); + + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + return packed_weights; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packed_K16_row_major", &packed_K16_row_major, "Pack fp16 weights into int32 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu new file mode 100644 index 0000000..806b698 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm_shared_mem_block(int M, int N, int K, float alpha, + const float *A, const float *B, + float beta, float *C) { + // the output block that we want to compute in this threadblock + const uint cRow = blockIdx.x; + const uint cCol = blockIdx.y; + + // allocate buffer for current block in fast shared mem + // shared mem is shared between all threads in a block + __shared__ float As[BLOCKSIZE * BLOCKSIZE]; + __shared__ float Bs[BLOCKSIZE * BLOCKSIZE]; + + // the inner row & col that we're accessing in this thread + const uint threadCol = threadIdx.x % BLOCKSIZE; + const uint threadRow = threadIdx.x / BLOCKSIZE; + + // advance pointers to the starting positions + A += cRow * BLOCKSIZE * K; // row=cRow, col=0 + B += cCol * BLOCKSIZE; // row=0, col=cCol + C += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE; // row=cRow, col=cCol + + float tmp = 0.0; + for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) { + // Have each thread load one of the elements in A & B + // Make the threadCol (=threadIdx.x) the consecutive index + // to allow global memory access coalescing + As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol]; + Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol]; + + // block threads in this block until cache is fully populated + __syncthreads(); + A += BLOCKSIZE; + B += BLOCKSIZE * N; + + // execute the dotproduct on the currently cached block + for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) { + tmp += As[threadRow * BLOCKSIZE + dotIdx] * + Bs[dotIdx * BLOCKSIZE + threadCol]; + } + // need to sync again at the end, to avoid faster threads + // fetching the next block into the cache before slower threads are done + __syncthreads(); + } + C[threadRow * N + threadCol] = + alpha * tmp + beta * C[threadRow * N + threadCol]; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu new file mode 100644 index 0000000..e330941 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemmVectorize(int M, int N, int K, float alpha, float *A, + float *B, float beta, float *C) { + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // BN/TN are the number of threads to span a column + const int threadCol = threadIdx.x % (BN / TN); + const int threadRow = threadIdx.x / (BN / TN); + + // allocate space for the current blocktile in smem + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[TM * TN] = {0.0}; + float regM[TM] = {0.0}; + float regN[TN] = {0.0}; + + // outer-most loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + // transpose A while loading it + float4 tmp = + reinterpret_cast(&A[innerRowA * K + innerColA * 4])[0]; + As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x; + As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y; + As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z; + As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w; + + reinterpret_cast(&Bs[innerRowB * BN + innerColB * 4])[0] = + reinterpret_cast(&B[innerRowB * N + innerColB * 4])[0]; + __syncthreads(); + + // advance blocktile + A += BK; // move BK columns to right + B += BK * N; // move BK rows down + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // block into registers + for (uint i = 0; i < TM; ++i) { + regM[i] = As[dotIdx * BM + threadRow * TM + i]; + } + for (uint i = 0; i < TN; ++i) { + regN[i] = Bs[dotIdx * BN + threadCol * TN + i]; + } + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + threadResults[resIdxM * TN + resIdxN] += + regM[resIdxM] * regN[resIdxN]; + } + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) { + // load C vector into registers + float4 tmp = reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0]; + // perform GEMM update in reg + tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x; + tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] = + tmp; + } + } +} + +void runSgemmVectorize(int M, int N, int K, float alpha, float *A, float *B, + float beta, float *C) { + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 and N >= 128) { + const uint BM = 128; + const uint BN = 128; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } else { + // this is a hacky solution to the underlying problem + // of not having proper bounds checking in the kernel + const uint BM = 64; + const uint BN = 64; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu index ca85336..448df63 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu @@ -29,7 +29,6 @@ __global__ void naive_kernel( for (int offset=0; offset<4; offset++) { int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); - float input_val = __half2float(input[row * K + k + offset]); if (mask == 1) { diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu index b89421a..8f06c1f 100644 --- a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu @@ -24,7 +24,6 @@ const uint NUM_WARPS = NUM_THREADS / WARPSIZE; #define CEIL_DIV(x, y) ((x) + (y)-1) / (y) - /* Half4 Helper Functions */