diff --git a/.gitignore b/.gitignore index ca0b987..7cc7d22 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ __pycache__ dist examples/data examples/models +.vscode/ diff --git a/bitlinear/__init__.py b/bitlinear/__init__.py index 5d740a3..540dcf3 100644 --- a/bitlinear/__init__.py +++ b/bitlinear/__init__.py @@ -8,10 +8,8 @@ AbsMean, AbsMedian, ) -from .kernels import ( - Naive, - NaiveListComp, - TernaryNaive, +from .frozen_bitlinear.frozen_bitlinear import ( TorchLinear, - TorchMulAdd, -) \ No newline at end of file + Naive +) + diff --git a/bitlinear/frozen_bitlinear/.gitignore b/bitlinear/frozen_bitlinear/.gitignore new file mode 100644 index 0000000..9ce6516 --- /dev/null +++ b/bitlinear/frozen_bitlinear/.gitignore @@ -0,0 +1,4 @@ +results +**/weights +**.so +**/build \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/README.md b/bitlinear/frozen_bitlinear/README.md new file mode 100644 index 0000000..921fad7 --- /dev/null +++ b/bitlinear/frozen_bitlinear/README.md @@ -0,0 +1,146 @@ + +## Motivations + +Matrix multiplications are a key building block of most modern high-performance computing systems. +They are notoriously hard to optimize, hence their implementation is generally done by +hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS). +Unfortunately, these libraries are often proprietary and cannot be easily customized +to accommodate the needs of our ternary system. Attempts at [matmul-free linear implementations](https://github.com/ridgerchu/matmulfreellm) still rely on the cuBLAS kernel for the matmul. The goal of this investigation is to develop a CUDA kernel competitive with cuBLAS in speed and performance, and ultimately far faster, as a ternary system can be implemented by a series of masked adds as opposed to traditional matrix multiplication. +Roughly speaking, the traditional kernel implements the following blocked +algorithm to multiply a (M, K) by a (K, N) matrix: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + for (int k = 0; k < K) { + sum += A[row, k] * B[k, col]; + } + + output[row * N + col] = sum + bias[row]; + } +``` +where each instance of the above code is a single thread. + +We hope to develop a kernel that can avoid the multiplicative step, instead following the successive framework: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + for (int k = 0; k < K) { + if (B[k, col] == 1) { + sum += A[row, k]; + } else if (B[k, col] == -1) { + sum -= A[row, k]; + } + } + + output[row * N + col] = sum + bias[row]; + } +``` + +The code above corresponds to excution of a single thread. For more information on how CUDA parallelization is implemented in code and in reality, visit [here](https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/). + +### Optimizations + +To see savings from our algorithm, we turn to weight packing algorithms. If we represent a weight in two bits, then we can take an eighth of the memory to store the weights, resulting in potentially significant speedups if memory loads are the limiting factor (as they often are). + +Because weight loads are significantly cheaper per capita than activation loads, we want to more heavily prioritize cache hits on activation loads. To do so, we can follow a number of strategies. Within each thread, we can compute a greater number of computations with the same activations. This may lead to speedups in memory loads, but significant optimization is required to find the balance between memory savings and parallelization costs. We can also use shared memory between threads, which allows for threads to compute different outputs in parallel, but without reloading activations each time. A barebones implementation is shown below: + +```C++ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + __shared__ half shared_activations[]; + + if (row < M && col < N) { + + float sum = 0; + + for (int k = 0; k < K) { + shared_activations[k] = A[row, k]; + } + + __syncthreads(); + + for (int k = 0; k < K) { + if (B[k, col] == 1) { + sum += A[row, k]; + } else if (B[k, col] == -1) { + sum -= A[row, k]; + } + } + + output[row * N + col] = sum + bias[row]; +} +``` +The tradeoff with this implementation is the reliance on threads to be synchronized within the thread call, which is necessary to avoid race conditions but results in a performance penalty. + +Further optimization is required to balance the size of K, M and N in scheduling the batch sizes to schedule to the block and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset their gains. Later work can go into developing autotuning algorithms that dynamically schedule the computation as to maximize the L2 cache hit ratio and minimize synchronization delay. + +## Future Work + +As of now, this optimization extends only into inference. The method of packing weights seems redundant in training, as the shadow weights being stored in higher precision is [necessary for gradient accumulation](https://arxiv.org/pdf/2402.17764). The speedups still can be realized in the forward passes, but more investigation and more robust kernels must be devised to turn these into reality. + +Further work can also be spent looking into fusing these kernels with layer norm, softmax, and other such adjoining layers to realize even faster speedups. This has been growing in popularity recently, and luckily, alot of the work [has been done](https://github.com/ridgerchu/matmulfreellm) for ternary weight layers, we just need to plug and play with our matmul kernel. + +Finally, we have 1.58 bits of information packed into 2, but by utilizing compression algorithms, we can realize even higher levels of weight packing. Because a weight can only occupy one of 3 states ($00$, $10$, or $01$), we can compress $10$ to just $1$, resulting in a further 13% of savings. Another solution could be extending the context region, which would introduce complex interdependencies, but as each weight must be loaded in higher precision in a load magnitudes more costly than bit-level computation, it may prove beneficial. As weight loads are not the crux of the cost, however, it is unclear if traversing this path is worth it. + +## Setup +Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run: +``` +> conda create -n env python pip +> conda activate env +> pip install -r requirements.txt +``` + +Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through: +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights +> python setup.py build_ext --inplace +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/* +> python setup.py build_ext --inplace +``` +You can also build all of them at once by running +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear/ +> chmod +x scripts/build.sh +> scripts/build.sh +``` + +## Testing +Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device. +``` +< conda activate env +< cd path/to/bitlinear/bitlinear/frozen_bitlinear +< chmod +x scripts/test.sh +< scripts/test.sh + -d (CUDA_AVAILABLE_DEVICES=$device) + -k +``` + +All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```. + +If any issues come up, you can reach me at [sopsahl@mit.edu](mailto:sopsahl@mit.edu). + +All benchmarks have been tested on an A6000. + +## Cleanup +To clean the builds and results, run the following commands. +``` +> cd path/to/bitlinear/bitlinear/frozen_bitlinear +> chmod +x scripts/clean.sh +> scripts/clean.sh + -b (default: builds only) + -w (weights) + -r (results) +``` + + +## References +... \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/__init__.py b/bitlinear/frozen_bitlinear/__init__.py new file mode 100644 index 0000000..95b8b85 --- /dev/null +++ b/bitlinear/frozen_bitlinear/__init__.py @@ -0,0 +1 @@ +from src import TorchLinear, Naive, Warptiling \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/__init__.py b/bitlinear/frozen_bitlinear/cuda/__init__.py new file mode 100644 index 0000000..bc55d31 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/__init__.py @@ -0,0 +1,2 @@ +from cuda.kernels import * +from cuda.pack_weights import pack_weights \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py new file mode 100644 index 0000000..8c42e8f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/__init__.py @@ -0,0 +1,2 @@ +from cuda.kernels.naive_linear import naive +from cuda.kernels.warptiling import warptiling \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu new file mode 100644 index 0000000..79398cc --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int *weights, + const half *bias, + half *output, + float scale, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int weight; + + for (int k = 0; k < K; k += 16) { + weight = weights[(col * K + k)/16]; + + for (int offset=0; offset<16; offset++) { + int mask = (weight & (3 << (2 * offset))) >> (2 * offset); + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half((sum + __half2float(bias[row]))); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + float scale, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + scale, + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(bitlinear_naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/naive_linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/naive_linear.cu new file mode 100644 index 0000000..607983b --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/naive_linear.cu @@ -0,0 +1,178 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +__global__ void fp16_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n, + int *status +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + half sum = __float2half(0.0f); + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum = __hadd(sum, input[row * k + i]); + } else if (weight == 2) { + sum = __hsub(sum, input[row * k + i]); + } else if (weight == 3) { + atomicExch(status, 1); // Set error flag, weights incorrectly packed + return; + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __hadd(sum, bias[row]); + + } +} + +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + + cudaDeviceSynchronize(); + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive_linear_cuda, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/no_sync.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/no_sync.cu new file mode 100644 index 0000000..dd3badb --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/no_sync.cu @@ -0,0 +1,128 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(no_stream, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/row_major.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/row_major.cu new file mode 100644 index 0000000..15cd108 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/row_major.cu @@ -0,0 +1,146 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} + +void run_sgemm_shared_mem_block(int M, int N, int K, float alpha, float *A, + float *B, float beta, float *C) { + dim3 gridDim(CEIL_DIV(M, 32), CEIL_DIV(N, 32)); + dim3 blockDim(32 * 32); + // L1 cache becomes useless, since we access GMEM only via SMEM, so we carve + // out all of L1 to SMEM. This doesn't currently make a difference, since + // occupancy is limited by reg and thread count, but it's good to do anyway. + cudaFuncSetAttribute(sgemm_shared_mem_block<32>, + cudaFuncAttributePreferredSharedMemoryCarveout, + cudaSharedmemCarveoutMaxShared); + sgemm_shared_mem_block<32> + <<>>(M, N, K, alpha, A, B, beta, C); +} + +__global__ void shared_memory( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K, + int computeBlockSize, + int blockSize + ) +{ + const uint blockRow = blockIdx.x; + const uint threadCol = blockSize * blockIdx.y + threadIdx.x * computeBlockSize; + + // one row shared between all threads in a block + __shared__ half shared_input[K]; + + // advance pointers to the starting positions + input += blockRow * K; + weights += threadCol * K; + output += blockRow * N + threadCol; + + float sum = 0.0; + for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) { + // Have each thread load one of the elements in A & B + // Make the threadCol (=threadIdx.x) the consecutive index + // to allow global memory access coalescing + shared_input[blockRow * BLOCKSIZE + threadCol] = A[blockRow * K + threadCol]; + Bs[blockRow * BLOCKSIZE + threadCol] = B[blockRow * N + threadCol]; + + // block threads in this block until cache is fully populated + __syncthreads(); + + A += BLOCKSIZE; + B += BLOCKSIZE * N; + + // execute the dotproduct on the currently cached block + for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) { + tmp += As[threadRow * BLOCKSIZE + dotIdx] * + Bs[dotIdx * BLOCKSIZE + threadCol]; + } + // need to sync again at the end, to avoid faster threads + // fetching the next block into the cache before slower threads are done + __syncthreads(); + } + C[threadRow * N + threadCol] = + alpha * tmp + beta * C[threadRow * N + threadCol]; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/shared_memory.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/shared_memory.cu new file mode 100644 index 0000000..ca85336 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/shared_memory.cu @@ -0,0 +1,82 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/archive/streamed_linear.cu b/bitlinear/frozen_bitlinear/cuda/kernels/archive/streamed_linear.cu new file mode 100644 index 0000000..f5d390e --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/archive/streamed_linear.cu @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +/* +********************************************************************* +function name: linear_matmul_cuda +description: dot product of two arbitrarily sized matrices. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. +return: none +********************************************************************* +*/ +////////////////////////////////////////////////////////////// +// This performs the computations in fp32 for higher precision +////////////////////////////////////////////////////////////// +__global__ void fp32_linear_matmul_cuda( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + const int m, + const int k, + const int n +) +{ + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // Check if we are within bounds + if (col < n && row < m) { + + float sum = 0.0f; + int start = col * k; + int shift = (start << 1) % 8; + + int index; + int8_t weight; + + // Iterate through the interior dimension + for (int i = 0; i < k; i++) { + + index = (start + i) >> 2; + weight = (weights[index] >> shift) & 0x03; + + if (weight == 1) { + sum += __half2float(input[row * k + i]); + } else if (weight == 2) { + sum -= __half2float(input[row * k + i]); + } + + shift = (shift + 2) % 8; + } + + // Store the result into the correct location in memory + output[row * n + col] = __float2half(sum + __half2float(bias[row])); + + } +} +/* +********************************************************************* +function name: linear +description: linear layer that calls the matmul kernel. +parameters: + input: Input of size m X k. + weights: weight kernel of size n X k packed in int8. Contains + 1 at bit0 if weight is 1 and 1 at bit1 if weight is -1 + bias: bias per output channel. + m,k,n: sizes of matrices. + row_block_size: the number of rows to compute in parallel. + col_block_size: the number of columns to compute in parallel. +return: + output: output of size m x n. +********************************************************************* +*/ +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int m, + int k, + int n, + const unsigned int row_block_size, + const unsigned int col_block_size +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({m, n}, options); + + unsigned int grid_rows = (m + row_block_size - 1) / row_block_size; + unsigned int grid_cols = (n + col_block_size - 1) / col_block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(col_block_size, row_block_size); + + cudaStream_t stream; + cudaStreamCreate(&stream); + + fp32_linear_matmul_cuda<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + m, k, n + ); + + // Check for any errors launching the kernel + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + throw std::runtime_error("CUDA kernel launch failed: " + std::string(cudaGetErrorString(err))); + } + + cudaStreamSynchronize(stream); + cudaStreamDestroy(stream); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(streamed_linear_cuda, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu new file mode 100644 index 0000000..556b37c --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/kernel.cu @@ -0,0 +1,230 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +const uint NUM_THREADS = 256; // Number of Threads per block +const uint WARP_SIZE = 32; // Fixed in hardware +const uint NUM_WARPS = NUM_THREADS/WARP_SIZE; // 8 + +const uint TN = 8; // the number of elements calculated in the n dimension per thread +const uint TM = 8; // the number of elements calculated in the m dimension per thread loop + +// Each warp will access 256 elements of A each in batches of 8 +const uint STRIDE_Ak = 8; // The stride for loading A (in K direction) +const uint STRIDE_Wn = 8; // The stride for loading W (in N direction) + +const uint BN = 2048; // NUM_THREADS*STRIDE_Wk*8, the threadblock size for N dimension W SMEM caching. +const uint BK = 512; // The threadblock size for K dimension A SMEM caching. +const uint BM = 4; // NUM_WARPS, The number of rows to calculate per block + +const uint WEIGHTSper32 = 16; +const uint HALVESper32 = 2; + +namespace MemAccess { + typedef struct { + const __half data[8] + } half8; + + __device__ void loadAs( + int K, + const __half *A, half2 *As, + int innerColA, int innerRowA + ) { + + // Load a half8 from A + const half8 tmp = reinterpret_cast(&A[innerRowA * K + innerColA* STRIDE_Ak])[0]; + + // Convert half8 data to half2, transpose it and store in As + As[(innerColA * 4 + 0) * BM + innerRowA] = __halves2half2(tmp.data[1], tmp.data[0]); + As[(innerColA * 4 + 1) * BM + innerRowA] = __halves2half2(tmp.data[3], tmp.data[2]); + As[(innerColA * 4 + 2) * BM + innerRowA] = __halves2half2(tmp.data[5], tmp.data[4]); + As[(innerColA * 4 + 3) * BM + innerRowA] = __halves2half2(tmp.data[7], tmp.data[6]); + + } + + __device__ void loadWs( + const short *W, short *Ws, + int innerRowW + ) { + + // Load an int4 from W + const short8 tmp = reinterpret_cast(W)[innerRowW*STRIDE_Wn]; + + // Convert int4 data to int16_t and store in Ws + Ws[innerRowW*STRIDE_Wn] = tmp[0]; + Ws[innerRowW*STRIDE_Wn + 1] = tmp[1]; + Ws[innerRowW*STRIDE_Wn + 2] = tmp[2]; + Ws[innerRowW*STRIDE_Wn + 3] = tmp[3]; + Ws[innerRowW*STRIDE_Wn + 4] = tmp[4]; + Ws[innerRowW*STRIDE_Wn + 5] = tmp[5]; + Ws[innerRowW*STRIDE_Wn + 6] = tmp[6]; + Ws[innerRowW*STRIDE_Wn + 7] = tmp[7]; + + } + + __device__ void load_regA( + const half2 *As, float *regA, + int index + ) { + + const half8 tmp = reinterpret_cast(&As[index*BM])[0]; + + // loading registers with the 8x2 activations from As + + for (int i=0; i( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0]; + // perform GEMM update in reg + tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x; + tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] = + tmp; + } + } +} + +void runSgemmVectorize(Torch:Tensor *A, int *B, + float beta, float *C) { + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 and N >= 128) { + const uint BM = 128; + const uint BN = 128; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } else { + // this is a hacky solution to the underlying problem + // of not having proper bounds checking in the kernel + const uint BM = 64; + const uint BN = 64; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } +} + + +torch::Tensor linear( + torch::Tensor *A, + torch::Tensor *W , + half *bias, + half *scale + ) { + + int M = A.size(0); + int K = A.size(1); + int N = W.size(0); + + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / TM); + + sgemmDBlocktiling<<>>( + M, N, K, + reinterpret_cast(A.data_ptr()), + W.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(scale.data_ptr()) + reinterpret_cast(output.data_ptr()) + ); + + return output; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/1dblocktiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/1dblocktiling.cu new file mode 100644 index 0000000..12c9c5d --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/1dblocktiling.cu @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm1DBlocktiling(int M, int N, int K, float alpha, + const float *A, const float *B, float beta, + float *C) { + // If we flip x and y here we get ~30% less performance for large matrices. + // The current, 30% faster configuration ensures that blocks with sequential + // blockIDs access columns of B sequentially, while sharing the same row of A. + // The slower configuration would share columns of A, but access into B would + // be non-sequential. So the faster configuration has better spatial locality + // and hence a greater L2 hit rate. + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // each warp will calculate 32*TM elements, with 32 being the columnar dim. + const int threadCol = threadIdx.x % BN; + const int threadRow = threadIdx.x / BN; + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // todo: adjust this to each thread to load multiple entries and + // better exploit the cache sizes + assert(BM * BK == blockDim.x); + assert(BN * BK == blockDim.x); + const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing + const uint innerRowA = threadIdx.x / BK; + const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing + const uint innerRowB = threadIdx.x / BN; + + // allocate thread-local cache for results in registerfile + float threadResults[TM] = {0.0}; + + // outer loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; + Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; + __syncthreads(); + + // advance blocktile + A += BK; + B += BK * N; + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // we make the dotproduct loop the outside loop, which facilitates + // reuse of the Bs entry, which we can cache in a tmp var. + float tmpB = Bs[dotIdx * BN + threadCol]; + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + threadResults[resIdx] += + As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + C[(threadRow * TM + resIdx) * N + threadCol] = + alpha * threadResults[resIdx] + + beta * C[(threadRow * TM + resIdx) * N + threadCol]; + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu new file mode 100644 index 0000000..12c9c5d --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/original.cu @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm1DBlocktiling(int M, int N, int K, float alpha, + const float *A, const float *B, float beta, + float *C) { + // If we flip x and y here we get ~30% less performance for large matrices. + // The current, 30% faster configuration ensures that blocks with sequential + // blockIDs access columns of B sequentially, while sharing the same row of A. + // The slower configuration would share columns of A, but access into B would + // be non-sequential. So the faster configuration has better spatial locality + // and hence a greater L2 hit rate. + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // each warp will calculate 32*TM elements, with 32 being the columnar dim. + const int threadCol = threadIdx.x % BN; + const int threadRow = threadIdx.x / BN; + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // todo: adjust this to each thread to load multiple entries and + // better exploit the cache sizes + assert(BM * BK == blockDim.x); + assert(BN * BK == blockDim.x); + const uint innerColA = threadIdx.x % BK; // warp-level GMEM coalescing + const uint innerRowA = threadIdx.x / BK; + const uint innerColB = threadIdx.x % BN; // warp-level GMEM coalescing + const uint innerRowB = threadIdx.x / BN; + + // allocate thread-local cache for results in registerfile + float threadResults[TM] = {0.0}; + + // outer loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA]; + Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB]; + __syncthreads(); + + // advance blocktile + A += BK; + B += BK * N; + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // we make the dotproduct loop the outside loop, which facilitates + // reuse of the Bs entry, which we can cache in a tmp var. + float tmpB = Bs[dotIdx * BN + threadCol]; + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + threadResults[resIdx] += + As[(threadRow * TM + resIdx) * BK + dotIdx] * tmpB; + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdx = 0; resIdx < TM; ++resIdx) { + C[(threadRow * TM + resIdx) * N + threadCol] = + alpha * threadResults[resIdx] + + beta * C[(threadRow * TM + resIdx) * N + threadCol]; + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights16.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights16.cu new file mode 100644 index 0000000..78661ea --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights16.cu @@ -0,0 +1,78 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +__global__ void pack_weights_kernel(const half* weights, short* packed_weights, int n, int k) { + + int row = blockIdx.y * blockDim.y + threadIdx.y; + int colStart = blockIdx.x * 8; + + if (row < n && colStart < k) { + + short packed = 0; + + for (int index = 0; index < 8; ++index) { + if (colStart + index < k) { + float weight_value = __half2float(weights[row * k + colStart + index]); + short bit_mask = (weight_value == 1.0f) ? (1 << (index * 2)) : + (weight_value == -1.0f) ? (2 << (index * 2)) : 0; + packed |= bit_mask; + } + } + + int packed_index = blockIdx.x * n + row; + packed_weights[packed_index] = packed; + } +} + +/* +In this implementation we have 8 weights in the K direction, which are then ordered in +column major order ([0, 0-7], [1, 0-7], ... [(N-1), 0-7], [0, 8-15], ...) +This comes from a NxK Weight Matrix +*/ +torch::Tensor packed_K8_column_major(torch::Tensor weights) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + int n = weights.size(0); + int k = weights.size(1); + + TORCH_CHECK(k % 8 == 0, "K must be divisible by 8"); + + int packed_size = CEIL_DIV(n * k, 8); + auto packed_weights = torch::zeros({packed_size}, torch::TensorOptions().dtype(torch::kInt16).device(torch::kCUDA)); + + const unsigned int block_size = 128; + const unsigned int grid_rows = CEIL_DIV(n, block_size); + const unsigned int grid_cols = CEIL_DIV(k, 8); + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(1, block_size); + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights.data_ptr(), + n, + k + ); + + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + return packed_weights; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packed_K8_column_major", &packed_K8_column_major, "Pack fp16 weights into int16 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu new file mode 100644 index 0000000..580d9b7 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/pack_weights32.cu @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +__global__ void pack_weights_kernel(const __half* weights, int* packed_weights, int n, int k) { + + int rowFactor = blockIdx.y * blockDim.y + threadIdx.y + int row = rowFactor * 2; + int colStart = blockIdx.x * 8; + + if (row < n && colStart < k) { + + int packed = 0; + + for (innerRow=0, innerRow<2, ++innerRow) { + + for (int index=0; index<8; ++index) { + + if (colStart + index < k && innerRow + row < n) { + + float weight_value = __half2float(weights[(row + innerRow) * k + colStart + index]); + int bit_mask = (weight_value == 1.0f) ? (1 << (innerRow*16 + index*2)) : + (weight_value == -1.0f) ? (2 << (innerRow*16 + index*2)) : 0; + + packed |= bit_mask; + } + } + } + + int packed_index = blockIdx.x * n + row / 2; + packed_weights[packed_index] = packed; + } +} + +/* +In this implementation we have 16 weights in the K direction, which are then ordered in +column major order ([0, 0-15], [1, 0-15], ... [(N-1), 0-15], [0, 16-31], ...) +This comes from a NxK Weight Matrix +*/ +torch::Tensor packed_K16_row_major(torch::Tensor weights) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + int n = weights.size(0); + int k = weights.size(1); + + TORCH_CHECK(k % 8 == 0, "K must be divisible by 8"); + + auto packed_weights = torch::zeros({n/2, k/8}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + + const unsigned int block_size = 64; + const unsigned int grid_rows = CEIL_DIV(n/2, block_size); + const unsigned int grid_cols = CEIL_DIV(k, 8); + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(1, block_size); + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights.data_ptr(), + n, + k + ); + + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + return packed_weights; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packed_K16_row_major", &packed_K16_row_major, "Pack fp16 weights into int32 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu new file mode 100644 index 0000000..806b698 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/shared_mem.cu @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemm_shared_mem_block(int M, int N, int K, float alpha, + const float *A, const float *B, + float beta, float *C) { + // the output block that we want to compute in this threadblock + const uint cRow = blockIdx.x; + const uint cCol = blockIdx.y; + + // allocate buffer for current block in fast shared mem + // shared mem is shared between all threads in a block + __shared__ float As[BLOCKSIZE * BLOCKSIZE]; + __shared__ float Bs[BLOCKSIZE * BLOCKSIZE]; + + // the inner row & col that we're accessing in this thread + const uint threadCol = threadIdx.x % BLOCKSIZE; + const uint threadRow = threadIdx.x / BLOCKSIZE; + + // advance pointers to the starting positions + A += cRow * BLOCKSIZE * K; // row=cRow, col=0 + B += cCol * BLOCKSIZE; // row=0, col=cCol + C += cRow * BLOCKSIZE * N + cCol * BLOCKSIZE; // row=cRow, col=cCol + + float tmp = 0.0; + for (int bkIdx = 0; bkIdx < K; bkIdx += BLOCKSIZE) { + // Have each thread load one of the elements in A & B + // Make the threadCol (=threadIdx.x) the consecutive index + // to allow global memory access coalescing + As[threadRow * BLOCKSIZE + threadCol] = A[threadRow * K + threadCol]; + Bs[threadRow * BLOCKSIZE + threadCol] = B[threadRow * N + threadCol]; + + // block threads in this block until cache is fully populated + __syncthreads(); + A += BLOCKSIZE; + B += BLOCKSIZE * N; + + // execute the dotproduct on the currently cached block + for (int dotIdx = 0; dotIdx < BLOCKSIZE; ++dotIdx) { + tmp += As[threadRow * BLOCKSIZE + dotIdx] * + Bs[dotIdx * BLOCKSIZE + threadCol]; + } + // need to sync again at the end, to avoid faster threads + // fetching the next block into the cache before slower threads are done + __syncthreads(); + } + C[threadRow * N + threadCol] = + alpha * tmp + beta * C[threadRow * N + threadCol]; +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu new file mode 100644 index 0000000..e330941 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/kernels/reference/vectorized.cu @@ -0,0 +1,122 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) + +template +__global__ void sgemmVectorize(int M, int N, int K, float alpha, float *A, + float *B, float beta, float *C) { + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // BN/TN are the number of threads to span a column + const int threadCol = threadIdx.x % (BN / TN); + const int threadRow = threadIdx.x / (BN / TN); + + // allocate space for the current blocktile in smem + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + C += cRow * BM * N + cCol * BN; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[TM * TN] = {0.0}; + float regM[TM] = {0.0}; + float regN[TN] = {0.0}; + + // outer-most loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + // populate the SMEM caches + // transpose A while loading it + float4 tmp = + reinterpret_cast(&A[innerRowA * K + innerColA * 4])[0]; + As[(innerColA * 4 + 0) * BM + innerRowA] = tmp.x; + As[(innerColA * 4 + 1) * BM + innerRowA] = tmp.y; + As[(innerColA * 4 + 2) * BM + innerRowA] = tmp.z; + As[(innerColA * 4 + 3) * BM + innerRowA] = tmp.w; + + reinterpret_cast(&Bs[innerRowB * BN + innerColB * 4])[0] = + reinterpret_cast(&B[innerRowB * N + innerColB * 4])[0]; + __syncthreads(); + + // advance blocktile + A += BK; // move BK columns to right + B += BK * N; // move BK rows down + + // calculate per-thread results + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // block into registers + for (uint i = 0; i < TM; ++i) { + regM[i] = As[dotIdx * BM + threadRow * TM + i]; + } + for (uint i = 0; i < TN; ++i) { + regN[i] = Bs[dotIdx * BN + threadCol * TN + i]; + } + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + threadResults[resIdxM * TN + resIdxN] += + regM[resIdxM] * regN[resIdxN]; + } + } + } + __syncthreads(); + } + + // write out the results + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) { + // load C vector into registers + float4 tmp = reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0]; + // perform GEMM update in reg + tmp.x = alpha * threadResults[resIdxM * TN + resIdxN] + beta * tmp.x; + tmp.y = alpha * threadResults[resIdxM * TN + resIdxN + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[resIdxM * TN + resIdxN + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[resIdxM * TN + resIdxN + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C[(threadRow * TM + resIdxM) * N + threadCol * TN + resIdxN])[0] = + tmp; + } + } +} + +void runSgemmVectorize(int M, int N, int K, float alpha, float *A, float *B, + float beta, float *C) { + const uint BK = 8; + const uint TM = 8; + const uint TN = 8; + if (M >= 128 and N >= 128) { + const uint BM = 128; + const uint BN = 128; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } else { + // this is a hacky solution to the underlying problem + // of not having proper bounds checking in the kernel + const uint BM = 64; + const uint BN = 64; + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + dim3 blockDim((BM * BN) / (TM * TN)); + sgemmVectorize + <<>>(M, N, K, alpha, A, B, beta, C); + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py new file mode 100644 index 0000000..fbb543f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/blocktiling/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='warptiling', + ext_modules=[ + CUDAExtension( + name='warptiling', + sources=['kernels/warptiling.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu new file mode 100644 index 0000000..448df63 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/naive.cu @@ -0,0 +1,81 @@ +#include +#include +#include +#include +#include + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +__global__ void naive_kernel( + const half *input, + const int8_t *weights, + const half *bias, + half *output, + int M, + int N, + int K + ) +{ + const uint col = blockIdx.x * blockDim.x + threadIdx.x; + const uint row = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < M && col < N) { + + float sum = 0.0f; + int8_t weight; + + for (int k = 0; k < K; k += 4) { + weight = weights[(col * K + k) >> 2]; + + for (int offset=0; offset<4; offset++) { + int8_t mask = (weight & (3 << (2 * offset))) >> (2 * offset); + float input_val = __half2float(input[row * K + k + offset]); + + if (mask == 1) { + sum += input_val; + } else if (mask == 2) { + sum -= input_val; + } + } + + } + + // Store the result into the correct location in memory + output[row * N + col] = __float2half(sum + __half2float(bias[row])); + + } +} + +torch::Tensor linear( + torch::Tensor input, + torch::Tensor weights, + torch::Tensor bias, + int M, + int N, + int K +) { + auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0); + auto output = torch::zeros({M, N}, options); + + uint blockSize = 32; + + dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize)); + dim3 dimBlock(blockSize, blockSize); + + naive_kernel<<>>( + reinterpret_cast(input.data_ptr()), + weights.data_ptr(), + reinterpret_cast(bias.data_ptr()), + reinterpret_cast(output.data_ptr()), + M, N, K + ); + + return output; +} + + +// Binding to generate the .so file, to call from Python. +PYBIND11_MODULE(naive, m) { + m.doc() = "Implementation of bitlinear forward linear in CUDA"; + m.def("linear", &linear, "bitlinear_forward (CUDA)"); +} diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py new file mode 100644 index 0000000..417a092 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/naive_linear/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='naive', + ext_modules=[ + CUDAExtension( + name='naive', + sources=['naive.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu new file mode 100644 index 0000000..2cc66f3 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/original.cu @@ -0,0 +1,187 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +const int WARPSIZE = 32; // warpSize is not constexpr + +namespace wt { +template +__device__ void loadFromGmem(int N, int K, const float *A, const float *B, + float *As, float *Bs, int innerRowA, int innerColA, + int innerRowB, int innerColB) { + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const float4 tmp = reinterpret_cast( + &A[(innerRowA + offset) * K + innerColA * 4])[0]; + // float4 tmp; + // asm("ld.global.nc.v4.f32 {%0, %1, %2, %3}, [%4];" + // : "=f"(tmp.x), "=f"(tmp.y), "=f"(tmp.z), "=f"(tmp.w) + // : "l"(&A[(innerRowA + offset) * K + innerColA * 4])); + As[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x; + As[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.y; + As[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.z; + As[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.w; + } + + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + reinterpret_cast( + &Bs[(innerRowB + offset) * BN + innerColB * 4])[0] = + reinterpret_cast( + &B[(innerRowB + offset) * N + innerColB * 4])[0]; + // asm("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];" + // : "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 0]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 1]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 2]), + // "=f"(Bs[(innerRowB + offset) * BN + innerColB * 4 + 3]) + // : "l"(&B[(innerRowB + offset) * N + innerColB * 4])); + } +} + +template +__device__ void +processFromSmem(float *regM, float *regN, float *threadResults, const float *As, + const float *Bs, const uint warpRow, const uint warpCol, + const uint threadRowInWarp, const uint threadColInWarp) { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + // populate registers for whole warptile + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = + As[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + } + } + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + regN[wSubColIdx * TN + i] = + Bs[(dotIdx * BN) + warpCol * WN + wSubColIdx * WSUBN + + threadColInWarp * TN + i]; + } + } + + // execute warptile matmul + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // calculate per-thread results + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += + regM[wSubRowIdx * TM + resIdxM] * + regN[wSubColIdx * TN + resIdxN]; + } + } + } + } + } +} + +} // namespace wt + +/* + * @tparam BM The threadblock size for M dimension SMEM caching. + * @tparam BN The threadblock size for N dimension SMEM caching. + * @tparam BK The threadblock size for K dimension SMEM caching. + * @tparam WM M dim of continuous tile computed by each warp + * @tparam WN N dim of continuous tile computed by each warp + * @tparam WMITER The number of subwarp tiling steps in M dimension. + * @tparam WNITER The number of subwarp tiling steps in N dimension. + * @tparam TM The per-thread tile size for M dimension. + * @tparam TN The per-thread tile size for N dimension. + */ +template +__global__ void __launch_bounds__(NUM_THREADS) + sgemmWarptiling(int M, int N, int K, float alpha, float *A, float *B, + float beta, float *C) { + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // Placement of the warp in the threadblock tile + const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in + const uint warpCol = warpIdx % (BN / WN); + const uint warpRow = warpIdx / (BN / WN); + + // size of the warp subtile + constexpr uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); + constexpr uint WSUBM = WM / WMITER; // 64/2=32 + constexpr uint WSUBN = WN / WNITER; // 32/2=16 + + // Placement of the thread in the warp subtile + const uint threadIdxInWarp = threadIdx.x % WARPSIZE; // [0, 31] + const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); // i%(16/4) + const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); // i/4 + + // allocate space for the current blocktile in SMEM + __shared__ float As[BM * BK]; + __shared__ float Bs[BK * BN]; + + // Move blocktile to beginning of A's row and B's column + A += cRow * BM * K; + B += cCol * BN; + // Move C_ptr to warp's output tile + C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; + + // calculating the indices that this thread will load into SMEM + // we'll load 128bit / 32bit = 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + constexpr uint rowStrideA = (NUM_THREADS * 4) / BK; + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + constexpr uint rowStrideB = NUM_THREADS / (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[WMITER * TM * WNITER * TN] = {0.0}; + // we cache into registers on the warptile level + float regM[WMITER * TM] = {0.0}; + float regN[WNITER * TN] = {0.0}; + + // outer-most loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + wt::loadFromGmem( + N, K, A, B, As, Bs, innerRowA, innerColA, innerRowB, innerColB); + __syncthreads(); + wt::processFromSmem(regM, regN, threadResults, As, Bs, warpRow, warpCol, + threadRowInWarp, threadColInWarp); + A += BK; // move BK columns to right + B += BK * N; // move BK rows down + __syncthreads(); + } + + // write out the results + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // move C pointer to current warp subtile + float *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 4) { + // load C vector into registers + float4 tmp = reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0]; + // perform GEMM update in reg + const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + wSubColIdx * TN + resIdxN; + tmp.x = alpha * threadResults[i + 0] + beta * tmp.x; + tmp.y = alpha * threadResults[i + 1] + beta * tmp.y; + tmp.z = alpha * threadResults[i + 2] + beta * tmp.z; + tmp.w = alpha * threadResults[i + 3] + beta * tmp.w; + // write back + reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0] = tmp; + } + } + } + } +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu new file mode 100644 index 0000000..8f06c1f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/kernels/warptiling.cu @@ -0,0 +1,269 @@ +#include +#include +#include +#include +#include + +// Settings for A6000 +const uint NUM_THREADS = 128; +const uint BN = 128; // The threadblock size for N dimension SMEM caching. +const uint BM = 128; // The threadblock size for M dimension SMEM caching. +const uint BK = 16; // The threadblock size for K dimension SMEM caching. +const uint WN = 64; // N dim of continuous tile computed by each warp +const uint WM = 64; // M dim of continuous tile computed by each warp +const uint WNITER = 4; // The number of subwarp tiling steps in N dimension. +const uint TN = 4; // The per-thread tile size for N dimension. +const uint TM = 8; // The per-thread tile size for M dimension. +const int WARPSIZE = 32; // warpSize is not constexpr +const uint WMITER = (WM * WN) / (WARPSIZE * TM * TN * WNITER); // The number of subwarp tiling steps in M dimension. +const uint WSUBM = WM / WMITER; +const uint WSUBN = WN / WNITER; +const uint rowStrideA = (NUM_THREADS * 4) / BK; +const uint rowStrideB = NUM_THREADS / (BN / 4); +const uint NUM_WARPS = NUM_THREADS / WARPSIZE; + +#define CEIL_DIV(x, y) ((x) + (y)-1) / (y) + +/* +Half4 Helper Functions +*/ + +struct half4 { + half2 x, y; // Each half2 contains two half values +}; + +__device__ half4 loadHalf4(const half* address) { + half4 result; + result.x = *reinterpret_cast(address); + result.y = *reinterpret_cast(address + 2); + return result; +} + + +/* +Memory Operations +*/ + +namespace wt { + + /* + Load from Global Memory 4 items at a time + */ + + __device__ void loadFromGmem( + int N, int K, + const half *A, const int8_t *W, + half *sA, int8_t *sW, + int innerRowA, int innerColA, + int innerRowB, int innerColB + ) + { + // Load matrix A into shared memory sA + for (uint offset = 0; offset + rowStrideA <= BM; offset += rowStrideA) { + const half4 tmp = loadHalf4(&A[(innerRowA + offset) * K + innerColA * 4]); + sA[(innerColA * 4 + 0) * BM + innerRowA + offset] = tmp.x.x; + sA[(innerColA * 4 + 1) * BM + innerRowA + offset] = tmp.x.y; + sA[(innerColA * 4 + 2) * BM + innerRowA + offset] = tmp.y.x; + sA[(innerColA * 4 + 3) * BM + innerRowA + offset] = tmp.y.y; + } + + // Load 4 items matrix W into shared memory sW as packed 2-bit values + for (uint offset = 0; offset + rowStrideB <= BK; offset += rowStrideB) { + sW[(innerRowB + offset) * (BN / 4) + innerColB] = W[(innerRowB + offset) * (K / 4) + innerColB]; + } + } + + /* + Load from Shared Memory and Perform Computation + */ + + __device__ void processFromSmem( + half *regM, int8_t *regN, + float *threadResults, + const half *sA, const int8_t *sW, + const uint warpRow, const uint warpCol, + const uint threadRowInWarp, const uint threadColInWarp + ) + { + for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) { + + // Load sub-matrix of A into registers for a whole warptile + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint i = 0; i < TM; ++i) { + regM[wSubRowIdx * TM + i] = + sA[(dotIdx * BM) + warpRow * WM + wSubRowIdx * WSUBM + + threadRowInWarp * TM + i]; + } + } + + // Unpack and load sub-matrix of sW into registers already translated + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + for (uint i = 0; i < TN; ++i) { + uint index = (warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i) >> 2; + int8_t packedWeights = sW[dotIdx * (BN / 4) + index]; + uint shift = ((warpCol * WN + wSubColIdx * WSUBN + threadColInWarp * TN + i) & 0x03) * 2; + regN[wSubColIdx * TN + i] = (packedWeights >> shift) & 0x03; + } + } + + // execute warptile matmul + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // per-thread results + for (uint resIdxM = 0; resIdxM < TM; ++resIdxM) { + for (uint resIdxN = 0; resIdxN < TN; ++resIdxN) { + // loads the value and corresponiding weight + float inputVal = __half2float(regM[wSubRowIdx * TM + resIdxM]); + int8_t weight = regN[wSubColIdx * TN + resIdxN]; + // if 0b01, adds the activation + if (weight == 1) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] += inputVal; + } + // if 0b10, subtracts the activation + else if (weight == 2) { + threadResults[(wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + (wSubColIdx * TN) + resIdxN] -= inputVal; + } + } + } + } + } + } + } + +} + +/* +CUDA Kernel for WarpTiling with the Bitlinear Implementation +*/ +__global__ void __launch_bounds__(NUM_THREADS) sgemmWarptiling( + const int M, const int N, const int K, + const half *A, const int8_t *W, half *C + ) +{ + const uint cRow = blockIdx.y; + const uint cCol = blockIdx.x; + + // Placement of the warp in the threadblock tile + const uint warpIdx = threadIdx.x / WARPSIZE; // the warp this thread is in + const uint warpCol = warpIdx % (BN / WN); + const uint warpRow = warpIdx / (BN / WN); + + // Placement of the thread in the warp subtile + const uint threadIdxInWarp = threadIdx.x % WARPSIZE; + const uint threadColInWarp = threadIdxInWarp % (WSUBN / TN); + const uint threadRowInWarp = threadIdxInWarp / (WSUBN / TN); + + // allocate space for the current blocktile in SMEM + __shared__ half sA[BM * BK]; + __shared__ int8_t sW[BK * BN / 4]; + + // Move blocktile to beginning of A's row and W's column + A += cRow * BM * K; + W += cCol * BN / 4; + // Move C_ptr to warp's output tile + C += (cRow * BM + warpRow * WM) * N + cCol * BN + warpCol * WN; + + // calculating the indices that this thread will load into SMEM + // we'll load 4 elements per thread at each step + const uint innerRowA = threadIdx.x / (BK / 4); + const uint innerColA = threadIdx.x % (BK / 4); + const uint innerRowB = threadIdx.x / (BN / 4); + const uint innerColB = threadIdx.x % (BN / 4); + + // allocate thread-local cache for results in registerfile + float threadResults[WMITER * TM * WNITER * TN] = {0.0f}; + // we cache into registers on the warptile level + half regM[WMITER * TM] = {__float2half(0.0f)}; + int8_t regN[WNITER * TN] = {0}; + + // loop over block tiles + for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) { + + wt::loadFromGmem(N, K, A, W, sA, sW, innerRowA, innerColA, innerRowB, innerColB); + + __syncthreads(); + + wt::processFromSmem(regM, regN, threadResults, sA, sW, warpRow, warpCol, threadRowInWarp, threadColInWarp); + + A += BK; + W += BK * N / 4; + + __syncthreads(); + } + + // write out the results + for (uint wSubRowIdx = 0; wSubRowIdx < WMITER; ++wSubRowIdx) { + for (uint wSubColIdx = 0; wSubColIdx < WNITER; ++wSubColIdx) { + // move C pointer to current warp subtile + half *C_interim = C + (wSubRowIdx * WSUBM) * N + wSubColIdx * WSUBN; + for (uint resIdxM = 0; resIdxM < TM; resIdxM += 1) { + for (uint resIdxN = 0; resIdxN < TN; resIdxN += 2) { + // load C vector into registers + half2 tmp = reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0]; + // perform GEMM update in reg + const int i = (wSubRowIdx * TM + resIdxM) * (WNITER * TN) + + wSubColIdx * TN + resIdxN; + tmp.x = __float2half(threadResults[i + 0] + __half2float(tmp.x)); + tmp.y = __float2half(threadResults[i + 1] + __half2float(tmp.y)); + // write back + reinterpret_cast( + &C_interim[(threadRowInWarp * TM + resIdxM) * N + + threadColInWarp * TN + resIdxN])[0] = tmp; + } + } + } + } +} + + +/* +Callable Function that computes the "matmul" of A and W and stores in C +*/ +void matmul(int M, int N, int K, torch::Tensor A, torch::Tensor W, torch::Tensor C) { + + // warptile in threadblocktile + static_assert((BN % WN == 0) and (BM % WM == 0)); + static_assert((BN / WN) * (BM / WM) == NUM_WARPS); + + // threads in warpsubtile + static_assert((WM * WN) % (WARPSIZE * TM * TN * WNITER) == 0); + constexpr uint WMITER = (WM * WN) / (32 * TM * TN * WNITER); + + // warpsubtile in warptile + static_assert((WM % WMITER == 0) and (WN % WNITER == 0)); + + static_assert((NUM_THREADS * 4) % BK == 0, + "NUM_THREADS*4 must be multiple of BK to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of Bs during each iteration)"); + static_assert((NUM_THREADS * 4) % BN == 0, + "NUM_THREADS*4 must be multiple of BN to avoid quantization " + "issues during GMEM->SMEM tiling (loading only parts of the " + "final row of As during each iteration)"); + static_assert(BN % (16 * TN) == 0, + "BN must be a multiple of 16*TN to avoid quantization effects"); + static_assert(BM % (16 * TM) == 0, + "BM must be a multiple of 16*TM to avoid quantization effects"); + static_assert((BM * BK) % (4 * NUM_THREADS) == 0, + "BM*BK must be a multiple of 4*256 to vectorize loads"); + static_assert((BN * BK) % (4 * NUM_THREADS) == 0, + "BN*BK must be a multiple of 4*256 to vectorize loads"); + + dim3 blockDim(NUM_THREADS); + dim3 gridDim(CEIL_DIV(N, BN), CEIL_DIV(M, BM)); + + sgemmWarptiling<<>>( + M, N, K, + reinterpret_cast(A.data_ptr()), + W.data_ptr(), + reinterpret_cast(C.data_ptr()) + ); +} + +PYBIND11_MODULE(warptiling, m) { + m.def("matmul", &matmul, "Run SGEMM Warptiling with Half Precision (CUDA)"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py new file mode 100644 index 0000000..fbb543f --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/kernels/warptiling/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='warptiling', + ext_modules=[ + CUDAExtension( + name='warptiling', + sources=['kernels/warptiling.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu new file mode 100644 index 0000000..c1c15c4 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/pack_weights/pack_weights.cu @@ -0,0 +1,84 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace py = pybind11; + +__global__ void pack_weights_kernel(const half* weights, int* packed_weights, int n, int k) { + int row = blockIdx.x * blockDim.x + threadIdx.x; + int col = blockIdx.y * blockDim.y + threadIdx.y; + + if (row < n && col < k) { + int weight_index = row * k + col; + float weight_value = __half2float(weights[weight_index]); + + int packed_index = weight_index / 4; // 4 weights per int8 + int bit_index = 2*weight_index % 8; // each weight takes 2 bits + + int bit_mask = (weight_value == 1.0f) ? (1 << bit_index) : + (weight_value == -1.0f) ? (2 << bit_index) : 0; + + atomicOr(&packed_weights[packed_index], bit_mask); // this avoids race condition errors from parallelization + + } +} + +__global__ void int32_to_int8_kernel(const int* packed_weights_int32, int8_t* packed_weights_int8, int packed_size) { + int i = blockIdx.x * blockDim.x + threadIdx.x; + + if (i < packed_size) { + int packed_value = packed_weights_int32[i]; + packed_weights_int8[i] = static_cast(packed_value & 0xFF); + } +} + +torch::Tensor packedint8( + torch::Tensor weights, + int n, + int k + ) { + + TORCH_CHECK(weights.is_contiguous(), "weights tensor must be contiguous"); + TORCH_CHECK(weights.dtype() == torch::kFloat16, "weights tensor must be of type float16"); + + // Calculate size for packed weights tensor + int packed_size = (n * k + 3) / 4; // 4 weights per int8 + + auto packed_weights_int32 = torch::zeros({packed_size}, torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA)); + auto packed_weights_int8 = torch::empty({packed_size}, torch::TensorOptions().dtype(torch::kInt8).device(torch::kCUDA)); + + const unsigned int block_size = 32; + const unsigned int grid_rows = (n + block_size - 1) / block_size; + const unsigned int grid_cols = (k + block_size - 1) / block_size; + + dim3 dimGrid(grid_cols, grid_rows); + dim3 dimBlock(block_size, block_size); + unsigned int grid_size = (packed_size + block_size - 1) / block_size; + + pack_weights_kernel<<>>( + reinterpret_cast(weights.data_ptr()), + packed_weights_int32.data_ptr(), + n, + k + ); + int32_to_int8_kernel<<>>( + packed_weights_int32.data_ptr(), + packed_weights_int8.data_ptr(), + packed_size + ); + + cudaError_t err = cudaDeviceSynchronize(); + TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed with error: ", cudaGetErrorString(err)); + + return packed_weights_int8; +} + + +PYBIND11_MODULE(pack_weights, m) { + m.def("packedint8", &packedint8, "Pack fp16 weights into int8 tensor"); +} \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py b/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py new file mode 100644 index 0000000..8801da6 --- /dev/null +++ b/bitlinear/frozen_bitlinear/cuda/pack_weights/setup.py @@ -0,0 +1,19 @@ +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import torch + +if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available. The extension requires CUDA.") + +setup( + name='pack_weights', + ext_modules=[ + CUDAExtension( + name='pack_weights', + sources=['pack_weights.cu'] + ) + ], + cmdclass={ + 'build_ext': BuildExtension + } +) diff --git a/bitlinear/frozen_bitlinear/frozen_bitlinear.py b/bitlinear/frozen_bitlinear/frozen_bitlinear.py new file mode 100644 index 0000000..d3e7c23 --- /dev/null +++ b/bitlinear/frozen_bitlinear/frozen_bitlinear.py @@ -0,0 +1,133 @@ +import torch +import torch.nn as nn + +from ..bitlinear import BitLinear +from ..measures import * +from bitlinear.frozen_bitlinear.frozen_bitlinear import TorchLinear, Naive + +class FrozenBitLinear(nn.Linear): + + w_scale = None + + def __init__( + self, + in_features, + out_features, + kernel, + bias=True, + eps=1e-5, + activation_range=8, + activation_measure='AbsMax', + device=None, + dtype=None + ): + super(FrozenBitLinear, self).__init__( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + dtype=dtype, + ) + + self.eps = eps + self.kernel = eval(kernel)(eps, activation_range, activation_measure) if isinstance(kernel, str) else kernel(eps, activation_range, activation_measure) + + def __repr__(self): + return f"FrozenBitLinear(in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}, kernel={self.kernel}), activation_range={self.activation_range}, activation_measure={self.activation_measure}" + + def forward(self, x): + return self.kernel(x, self.weight, self.bias, self.w_scale) + + def freeze_weights(self, weights:torch.Tensor, weightMeasure:str='AbsMean'): + """ + Parameters: + weights : torch.Tensor + tensor of weights to be packed (out_features x in_features) + weightMeasure : str + str corresponding to a weight packing method + options are 'AbsMean', 'AbsMax', and 'AbsMedian' + default is 'AbsMean' + + Returns: + None + + This function must be called for a Frozen BitLinear Module to pack the weights. It takes advantage of a + kernel-specific weight packing function to store the weights in their ternary representation. This is implicitly + called in bitlinear.freeze + """ + + assert (weights.shape[0] == self.out_features) & (weights.shape[1] == self.in_features), "Weights dimensions must match that of the layer" + + packed_weights, self.w_scale = self.kernel.scale_weights(weights, weightMeasure, self.eps) + + if isinstance(packed_weights, torch.nn.Parameter): + self.weight = packed_weights + else: + self.weight = nn.Parameter(packed_weights) + + +def freeze( + model, + kernel=TorchLinear, + weightMeasure = 'AbsMean', + eps=1e-5, + activation_measure='AbsMax', + activation_range=16, + device=None, + dtype=None + ): + """ + Parameters: + model : torch.nn.Module + Model with BitLinear modules to replace with FrozenBitLinear for faster inference + kernel : str or Kernel + Forward kernel to implement in the model + weightMeasure : str + str corresponding to a weight packing method + options are 'AbsMean', 'AbsMax', and 'AbsMedian' + default is 'AbsMean' + eps : float + value to clamp the weights to in the scaling step to avoid divide-by-zero + default is 1e-5 + activation_measure : str + str corresponding to the activation quantization method + options are 'AbsMean', 'AbsMax', 'AbsMedian', and 'Fp16' + 'AbsMax' + activation_range : int + number of bits to represent the activations in + default is 8 + device : Optional(torch.device) + dtype : Optional(torch.dtype) + + Returns: + None + + This function replaces all of the BitLinear instances in a model with FrozenBitLinear in order to + speed up inference. The weights are packed corresponding to the kernel. + """ + + for name, module in model.named_children(): + if isinstance(module, BitLinear): + + kwargs = { + 'kernel' : kernel, + 'in_features' : module.in_features, + 'out_features' : module.out_features, + 'bias' : getattr(module, "bias", None) is not None, + 'device' : device, + 'dtype' : dtype, + 'eps' : eps, + 'activation_measure' : activation_measure, + 'activation_range' : activation_range + } + + frozen_module = FrozenBitLinear(**kwargs) + frozen_module.freeze_weights(module.weight.data, weightMeasure) # pack + + if kwargs['bias']: + frozen_module.bias.data = module.bias.data + + setattr(model, name, frozen_module) + + else: # recursively iterate throughout the rest of the model + freeze(module, kernel, weightMeasure, eps, device, dtype) diff --git a/bitlinear/frozen_bitlinear/requirements.txt b/bitlinear/frozen_bitlinear/requirements.txt new file mode 100644 index 0000000..4aa6a17 --- /dev/null +++ b/bitlinear/frozen_bitlinear/requirements.txt @@ -0,0 +1,8 @@ +torch +triton +pandas +numpy +plotly +tqdm +matplotlib +openpxyl \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/scripts/build.sh b/bitlinear/frozen_bitlinear/scripts/build.sh new file mode 100755 index 0000000..7b5134c --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/build.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Get the directory where the script is located +SCRIPT_DIR=$(dirname "$(realpath "$0")") + +# Define the parent directory paths +Kernel_DIR="${SCRIPT_DIR}/../cuda/kernels" +PACK_WEIGHTS_DIR="${SCRIPT_DIR}/../cuda/pack_weights" + +# Iterate through each subdirectory under $PARENT_DIR/cuda +for dir in "${Kernel_DIR}/"*/; do + if [ -f "${dir}setup.py" ]; then + echo "Building in directory: $dir" + (cd "$dir" && python setup.py build_ext --inplace) + fi +done + +# Build in the pack_weights directory +if [ -f "${PACK_WEIGHTS_DIR}/setup.py" ]; then + echo "Building in pack_weights directory: $PACK_WEIGHTS_DIR" + (cd "$PACK_WEIGHTS_DIR" && python setup.py build_ext --inplace) +else + echo "No setup.py found in pack_weights directory: ${PACK_WEIGHTS_DIR}" +fi + +echo "Build complete." \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/scripts/clean.sh b/bitlinear/frozen_bitlinear/scripts/clean.sh new file mode 100755 index 0000000..c9fe2db --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/clean.sh @@ -0,0 +1,70 @@ +#!/bin/bash + +# Function to display usage information +usage() { + echo "Usage: $0 [-b] [-w] [-r]" + echo " -b Deletes kernel builds" + echo " -w Delete 'weights' directories" + echo " -r Delete 'results' directories" + exit 1 +} + +# Get the directory where the script is located +SCRIPT_DIR=$(dirname "$(realpath "$0")") +PARENT_DIR="${SCRIPT_DIR}/.." + +# Parse command-line arguments +while getopts ":bwr" opt; do + case ${opt} in + b ) + DELETE_BUILD=true + ;; + w ) + DELETE_WEIGHTS=true + ;; + r ) + DELETE_RESULTS=true + ;; + \? ) + echo "Invalid option: -$OPTARG" >&2 + usage + ;; + : ) + echo "Invalid option: -$OPTARG requires an argument" >&2 + usage + ;; + esac +done + +# Shift parsed options away from the positional parameters +shift $((OPTIND -1)) + +if [ "$#" -eq 0 ]; then + DELETE_BUILD=true +fi + +# Delete .so files if -s is specified +if [ "$DELETE_BUILD" = true ]; then + echo "Deleting .so files..." + find "$PARENT_DIR" -type f -name '*.so' -exec rm -f {} \; +fi + +# Delete 'build' directories if -b is specified +if [ "$DELETE_BUILD" = true ]; then + echo "Deleting 'build' directories..." + find "$PARENT_DIR" -type d -name 'build' -exec rm -rf {} \; +fi + +# Delete 'weights' directories if -w is specified +if [ "$DELETE_WEIGHTS" = true ]; then + echo "Deleting 'weights' directories..." + find "$PARENT_DIR" -type d -name 'weights' -exec rm -rf {} \; +fi + +# Delete 'results' directories if -r is specified +if [ "$DELETE_RESULTS" = true ]; then + echo "Deleting 'results' directories..." + find "$PARENT_DIR" -type d -name 'results' -exec rm -rf {} \; +fi + +echo "Cleanup complete." diff --git a/bitlinear/frozen_bitlinear/scripts/test.sh b/bitlinear/frozen_bitlinear/scripts/test.sh new file mode 100755 index 0000000..d7a209d --- /dev/null +++ b/bitlinear/frozen_bitlinear/scripts/test.sh @@ -0,0 +1,57 @@ +#!/bin/bash + +timestamp=$(date "+%Y%m%d_%T") +save_dir="results/$timestamp" + +# Default values +device=0 +kernel="TorchLinear" + +# Function to display usage information +usage() { + echo "Usage: $0 [-d ] [-k ]" + echo " -d Optional device argument (0, 1, 2, 3). Default is 0." + echo " -k Optional kernel argument. Default is 'TorchLinear'." + exit 1 +} + +# Parse options +while getopts ":d:k:" opt; do + case ${opt} in + d ) + device=$OPTARG + # Check if the integer argument is valid (0, 1, 2, 3) + if ! [[ "$device" =~ ^[0-3]$ ]]; then + echo "Error: Device argument must be one of 0, 1, 2, or 3" + usage + fi + ;; + k ) + kernel=$OPTARG + ;; + \? ) + echo "Invalid option: -$OPTARG" >&2 + usage + ;; + : ) + echo "Option -$OPTARG requires an argument" >&2 + usage + ;; + esac +done + +shift $((OPTIND -1)) + +mkdir -p ${save_dir} +cp $0 $save_dir # saves the current training script + +cd "$(dirname "$0")/.." + +run_cmd="CUDA_VISIBLE_DEVICES=$device python -m tests \ + --save_dir $save_dir \ + --kernel $kernel \ + -a \ + 2>&1 | tee ${save_dir}/test-${timestamp}.log" + +echo "$run_cmd" +eval "$run_cmd" diff --git a/bitlinear/frozen_bitlinear/src/__init__.py b/bitlinear/frozen_bitlinear/src/__init__.py new file mode 100644 index 0000000..1756ef7 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/__init__.py @@ -0,0 +1,4 @@ +from src.default import TorchLinear +from src.naive import Naive +from src.warptiling import Warptiling + diff --git a/bitlinear/frozen_bitlinear/src/default.py b/bitlinear/frozen_bitlinear/src/default.py new file mode 100644 index 0000000..3b6c5e5 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/default.py @@ -0,0 +1,9 @@ +import torch +import torch.nn.functional as F + +from src.utils import Default + +class TorchLinear(Default): + def __call__(self, input, weight, bias=None, scale=1): + x_quant, x_scale = self.activations(input) + return F.linear(x_quant, weight, bias) * scale * x_scale diff --git a/bitlinear/frozen_bitlinear/src/naive.py b/bitlinear/frozen_bitlinear/src/naive.py new file mode 100644 index 0000000..7ce0c29 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/naive.py @@ -0,0 +1,26 @@ +from src.utils import Packed8, naive + +class Naive(Packed8): + fxn = naive.linear + + def __call__(self, activations, weights, bias, scale): + # Check constraints. + assert activations.is_contiguous(), "Matrix A must be contiguous" + assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" + + M, K = activations.shape + N = weights.shape[0] * 4//K + + x_quant, x_scale = self.activations(activations) + + return self.fxn(x_quant, weights, bias, M, N, K) * scale * x_scale + + + + + + + + + + diff --git a/bitlinear/frozen_bitlinear/src/utils/Kernel.py b/bitlinear/frozen_bitlinear/src/utils/Kernel.py new file mode 100644 index 0000000..c77b452 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/Kernel.py @@ -0,0 +1,30 @@ +import torch +import torch.nn.functional as F + +from src.utils.helpers import weight_round_clamp, AbsMax, AbsMean, AbsMedian, Fp16 +from cuda.pack_weights import pack_weights + +class Kernel: + def __init__(self, eps=1e-5, activation_range=8, activation_measure = 'AbsMax'): + + self.eps = eps + self.activations = eval(activation_measure)(activation_range, eps) + + def __call__(self, activations, weights, bias, scale) -> torch.Tensor: + raise NotImplementedError # this needs to be changed depending on the kernel + + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError # this needs to be changed depending on the kernel + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class Default(Kernel): + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: + return weight_round_clamp(weights, measure, self.eps) + +class Packed8(Kernel): + def scale_weights(self, weights, measure='AbsMean') -> tuple[torch.Tensor, torch.Tensor]: + weights, scale = weight_round_clamp(weights, measure, self.eps) + return pack_weights.packedint8(weights, *weights.shape), scale diff --git a/bitlinear/frozen_bitlinear/src/utils/__init__.py b/bitlinear/frozen_bitlinear/src/utils/__init__.py new file mode 100644 index 0000000..fd8b129 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/__init__.py @@ -0,0 +1,2 @@ +from src.utils.Kernel import Packed8, Default +from cuda.kernels import * \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/src/utils/helpers.py b/bitlinear/frozen_bitlinear/src/utils/helpers.py new file mode 100644 index 0000000..988c2d4 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/utils/helpers.py @@ -0,0 +1,67 @@ +import torch +from math import ceil + + +''' +Activation Quantization Helpers +''' + +def symmetric_range_from_bits(range): + return (ceil(-2**(range-1)), ceil(2**(range-1)-1)) + +def round_clamp(input, range): + return (input.round().clamp(range[0], range[1]) - input).detach() + input + +class ActivationMeasure: + def __init__(self, range=8, eps=1e-5): + self.range = symmetric_range_from_bits(range) + self.eps = eps + + def __call__(self, input): + x_norm = torch.layer_norm(input, input.size()[1:]) + x_scale = self.scale(x_norm) + return round_clamp(x_norm/x_scale, self.range), x_scale + + def scale(self, input) -> torch.Tensor: + raise NotImplementedError + + def __repr__(self): + return f"{self.__class__.__name__}()" + + +''' +Callable Activation Quantization Classes +''' + +class Fp16(ActivationMeasure): + def __init__(self, range=8, eps=1e-5): + pass + def __call__(self, input) -> tuple[torch.Tensor, torch.Tensor]: + return input, torch.tensor([[1.0]], device='cuda', dtype=torch.float16) + +class AbsMax(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().max(dim=-1, keepdim=True).values.clamp_(min=self.eps)/max(abs(k) for k in self.range) + +class AbsMean(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().mean(dim=-1, keepdim=True).clamp_(min=self.eps)/max(abs(k) for k in self.range) + +class AbsMedian(ActivationMeasure): + def scale(self, input) -> torch.Tensor: + return input.abs().median(dim=-1, keepdim=True).values.clamp_(min=self.eps)/max(abs(k) for k in self.range) + + +''' +Weight Quantization Functions +''' + +class WeightMeasure: + AbsMax = lambda input : input.abs().max() + AbsMedian = lambda input : input.abs().median() + AbsMean = lambda input : input.abs().mean() + +def weight_round_clamp(input, measure, eps) -> tuple[torch.Tensor, float]: + scale = eval(f'WeightMeasure.{measure}')(input.detach()).clamp_(min=eps) + scaled_input = input/scale + return (scaled_input.round().clamp(-1, 1) - scaled_input).detach() + scaled_input, scale \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/src/warptiling.py b/bitlinear/frozen_bitlinear/src/warptiling.py new file mode 100644 index 0000000..17a68a8 --- /dev/null +++ b/bitlinear/frozen_bitlinear/src/warptiling.py @@ -0,0 +1,21 @@ +import torch +from src.utils import Packed8, warptiling + +class Warptiling(Packed8): + + def __call__(self, activations, weights, bias, scale): + # Check constraints. + assert activations.is_contiguous(), "Matrix A must be contiguous" + assert activations.shape[0] == bias.shape[0], "Bias dimension must match input" + + M, K = activations.shape + N = weights.shape[0] * 4//K + + x_quant, x_scale = self.activations(activations) + output = torch.zeros((M, N), device='cuda', dtype = torch.float16) + + # print(output, x_quant, weights, bias) + + warptiling.matmul(M, N, K, x_quant, weights, output) + + return (output + bias) * scale * x_scale \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/tests/Benchmark.py b/bitlinear/frozen_bitlinear/tests/Benchmark.py new file mode 100644 index 0000000..c2c8084 --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/Benchmark.py @@ -0,0 +1,375 @@ +import torch +import torch.nn.functional as F + +import os + +from itertools import product +from tqdm import tqdm + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import plotly.graph_objects as go + +from src import * +from tests.helpers import weights + +class Benchmark: + + options = [256, 512, 1024, 2048, 4096] + + def __init__(self, args): + + self.path = args.save_dir + + self.device='cuda' + + self.kernel_name = args.kernel + self.kernel = eval(self.kernel_name)(activation_measure='Fp16') + + self.baseline = TorchLinear(activation_measure='Fp16') + + print(f'Testing with {self.kernel_name} kernel.') + self.ref_lib = 'cuBLAS' + + if args.a or args.p: + self.profiling() + if args.a or args.u: + self.unittests() + if args.a or args.t: + self.throughputs() + + def unittest(self, M=512, N=256, K=128): + + torch.manual_seed(0) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + triton_output = self.kernel(inputs, weights_kernel, biases, scale) + torch_output = self.baseline(inputs, weights_torch, biases, scale) + + if torch.allclose(triton_output, torch_output, atol=1e-1, equal_nan=True): + print(f"✅ Triton and Torch matchfor (M={M}, N={N}, K={K})") + return True, 0.0 + else: + difference = torch.max(torch.abs(triton_output - torch_output)).item() / torch.max(torch.abs(torch_output)).item() + print(f"❌ Triton and Torch differ for (M={M}, N={N}, K={K}) -- Maximum Normalized Difference : {difference}") + return False, difference + + def unittests(self, upper_limit=4096, step=128): + + pbar = tqdm(total=upper_limit//step, desc=f"Unit Tests", leave=True) + + results = [] + + for combo in range(step, upper_limit+step, step): + + result, diff = self.unittest(combo, combo, combo) + results.append([combo, result, diff]) + + pbar.set_description(f"M=N=K : {combo} ") + pbar.update() + + pbar.close() + + results = pd.DataFrame(results, columns=['M=N=K', 'Matches Torch', 'Maximum Normalized Difference']) + + save_path = os.path.join(self.path,"unittests") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path,"data.xlsx")) as writer: + results.to_excel(writer) + + print(f'\nUnitTest data saved to {save_path}\n') + + print('-------------------------------------') + print(f'Percentage of matching outputs {results["Matches Torch"].astype(int).mean() * 100}%') + print('-------------------------------------') + + fig, ax = plt.subplots(figsize=(10, 6)) + ax.scatter(results['M=N=K'], results['Maximum Normalized Difference'], marker='o', color='b', s=50) + + ax.set_xlabel('Dimension Size', fontsize=14) + ax.set_ylabel('Maximum Normalized Difference in Outputs', fontsize=14) + ax.set_title(f'Comparing PyTorch and {self.kernel_name} Output Tensors', fontsize=16) + ax.grid(True) + + plt.savefig(os.path.join(save_path, "difference.png"), dpi=300) + plt.show() + + def profiling(self, upper_limit=75): + + combinations = list(product(self.options, self.options, self.options)) + + results = { + item : [] + for item in [ + 'M', 'N', 'K', 'Custom_CUDA', 'cuBLAS_CUDA', 'Custom_CPU', 'cuBLAS_CPU', 'CUDA_factor', 'CPU_factor' + ] + } + + for [M, N, K] in combinations: + + print(f" M : {M} | N : {N} | K : {K}") + + results['M'].append(M) + results['N'].append(N) + results['K'].append(K) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + print("Profiling custom kernel") + with torch.autograd.profiler.profile(use_device = self.device) as prof: + self.kernel(inputs, weights_kernel, biases, scale) + events = prof.key_averages() + results['Custom_CUDA'].append(sum(event.device_time for event in events) / 1000.0) + results['Custom_CPU'].append(sum(event.cpu_time for event in events) / 1000.0) + print(events.table(sort_by="cuda_time_total")) + + print("Profiling cuBLAS kernel") + with torch.autograd.profiler.profile(use_device = self.device) as prof: + self.baseline(inputs, weights_torch, biases, scale) + events = prof.key_averages() + results['cuBLAS_CUDA'].append(sum(event.device_time for event in events) / 1000.0) + results['cuBLAS_CPU'].append(sum(event.cpu_time for event in events) / 1000.0) + print(events.table(sort_by="cuda_time_total")) + + results['CUDA_factor'].append(results['cuBLAS_CUDA'][-1]/results['Custom_CUDA'][-1]) + results['CPU_factor'].append(results['cuBLAS_CPU'][-1]/results['Custom_CPU'][-1]) + + results = pd.DataFrame(results) + + save_path = os.path.join(self.path, "profiling") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path, "data.xlsx")) as writer: + results.to_excel(writer) + + print(f'Profiling data saved to {save_path}\n') + + print(f' CUDA improvement = ~{min(results['CUDA_factor'])}-{max(results['CUDA_factor'])}') + print(f' CPU improvement = ~{min(results['CPU_factor'])}-{max(results['CPU_factor'])}') + + #### Plot Over Iterations + def iterations(): + plt.figure(figsize=(12, 12)) + + plt.subplot(2, 1, 1) + plt.plot(results['Custom_CUDA'], label='Custom CUDA Time') + plt.plot(results['cuBLAS_CUDA'], label='cuBLAS CUDA Time') + plt.xlabel('Iteration') + plt.ylabel('CUDA Time (ms)') + plt.ylim(0, max(np.percentile(results['Custom_CUDA'], upper_limit), np.percentile(results['cuBLAS_CUDA'], upper_limit))) + plt.title('CUDA Time Comparison') + plt.legend() + + plt.subplot(2, 1, 2) + plt.plot(results['Custom_CPU'], label='Custom CPU Time') + plt.plot(results['cuBLAS_CPU'], label='cuBLAS CPU Time') + plt.xlabel('Iteration') + plt.ylabel('CPU Time (ms)') + plt.ylim(0, max(np.percentile(results['Custom_CPU'], upper_limit), np.percentile(results['cuBLAS_CPU'], upper_limit))) + plt.title('CPU Time Comparison') + plt.legend() + + plt.tight_layout() + plt.savefig(os.path.join(save_path, "overall.png")) + plt.show() + + + ### Plot Difference Factor + def difference(device): + fig = go.Figure(data=[go.Scatter3d( + x=results['M'], + y=results['N'], + z=results['K'], + mode='markers', + marker=dict( + size=3, + color=results[f'Custom_{device}']/results[f'cuBLAS_{device}'], + colorscale='Viridis', + colorbar=dict(title='Performance Factor (ms)') + ), + text=results[f'Custom_{device}']/results[f'cuBLAS_{device}'], # Add text for tooltips + hovertext=results[f'Custom_{device}']/results[f'cuBLAS_{device}'] # Use the text for hover info + )]) + + fig.update_layout( + title=f'Performance Difference Between Custom_{device} and cuBLAS_{device}', + scene=dict( + xaxis_title='M', + yaxis_title='N', + zaxis_title='K' + ) + ) + + fig.write_html(os.path.join(save_path, f"{device}_factor.html")) + + ### Plot as evolving over individual values + def individual(fixed_axis, device): + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(len(self.options), len(self.options), figsize=(20, 20)) + + for row, row_val in zip(ax, self.options): + for plot, col_val in zip(row, self.options): + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'cuBLAS_{device}'], label='cuBLAS') + plot.plot(filtered_df[fixed_axis], filtered_df[f'Custom_{device}'], label=self.kernel_name) + + plot.set_title(f'{row_label}={row_val}, {col_label}={col_val}', fontsize=10) + plot.tick_params(axis='both', which='major', labelsize=8) + + # Place the legend outside the subplots + handles, labels = ax[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper left', fancybox=True, shadow=True, fontsize=15) + + # Set the super title for the figure + fig.suptitle(f'Comparing {self.ref_lib} and {self.kernel_name} on {device} (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Performance (ms)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"{device}_fixed{fixed_axis}.png")) + + iterations() + for device in ['CPU', 'CUDA']: + difference(device) + for fixed_axis in ['M', 'N', 'K']: + individual(fixed_axis, device) + + def throughput(self, M, N, K, iterations=100): + + perf = lambda ms: 2 * M * N * K * iterations * 1e-12 / (ms * 1e-3) + + inputs = torch.randn((M, K), device=self.device, dtype=torch.float16) + biases = torch.randn((M, 1), device=self.device, dtype=torch.float16) + + weights_torch, weights_kernel, scale = weights(N, K, self.kernel, self.baseline) + + cuBlas_time, cuBlas_mem = get_throughput(self.baseline, iterations)(inputs, weights_torch, biases, scale) + + kernel_time, kernel_mem = get_throughput(self.kernel, iterations)(inputs, weights_kernel, biases, scale) + + return [ M, N, K, perf(cuBlas_time), perf(kernel_time), kernel_mem - cuBlas_mem, ] + + def throughputs(self): + combinations = list(product(self.options, self.options, self.options)) + + pbar = tqdm(total=len(combinations), desc=f"Throughput", leave=True) + + results = [] + + for [M, N, K] in combinations: + + results.append(self.throughput(M, N, K)) + + pbar.set_description(f"M = {M} | N = {N} | K = {K} ") + pbar.update() + + pbar.close() + + results = pd.DataFrame(results, columns=['M', 'N', 'K', 'cuBLAS Performance', 'kernel Performance', 'Memory Difference']) + + save_path = os.path.join(self.path, "performance") + os.makedirs(save_path, exist_ok=True) + + with pd.ExcelWriter(os.path.join(save_path, "data.xlsx")) as writer: + results.to_excel(writer) + + print(f'Performance data saved to {save_path}\n') + + ### Performance + for fixed_axis in ['M', 'N', 'K']: + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(len(self.options), len(self.options), figsize=(20, 20)) + + for row, row_val in zip(ax, self.options): + for plot, col_val in zip(row, self.options): + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'cuBLAS Performance'], label='cuBLAS') + plot.plot(filtered_df[fixed_axis], filtered_df[f'kernel Performance'], label=self.kernel_name) + + plot.set_title(f'{row_label}={row_val}, {col_label}={col_val}', fontsize=10) + plot.tick_params(axis='both', which='major', labelsize=8) + + # Place the legend outside the subplots + handles, labels = ax[0, 0].get_legend_handles_labels() + fig.legend(handles, labels, loc='upper left', fancybox=True, shadow=True, fontsize=15) + + # Set the super title for the figure + fig.suptitle(f'Comparing {self.ref_lib} and {self.kernel_name} Performance (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Performance (TFLOPs)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"performance_fixed{fixed_axis}.png")) + plt.show() + + ### memory + for fixed_axis in ['M', 'N', 'K']: + row_label = 'N' if fixed_axis == 'M' else 'M' + col_label = 'N' if fixed_axis == 'K' else 'K' + + fig, ax = plt.subplots(1, len(self.options), figsize=(20, 20)) + + for plot, row_val in zip(ax, self.options): + for col_val in self.options: + + filtered_df = results.loc[(results[row_label] == row_val) & (results[col_label] == col_val)] + + plot.plot(filtered_df[fixed_axis], filtered_df[f'Memory Difference'], label=f'{col_label}={col_val}') + + plot.set_title(f'{row_label}={row_val}', fontsize=10) + plot.legend() + plot.tick_params(axis='both', which='major', labelsize=8) + + + # Set the super title for the figure + fig.suptitle(f'Difference between {self.kernel_name} and {self.ref_lib} Memory (Plotting over {fixed_axis})', fontsize=20) + fig.text(0.04, 0.5, 'Memory Difference (MBs)', va='center', rotation='vertical', fontsize=14) + + fig.tight_layout(rect=[0.05, 0.05, 1, 0.95]) # Adjust the layout to make space for the title and legend + + plt.savefig(os.path.join(save_path, f"memory_fixed{fixed_axis}.png")) + plt.show() + +def get_throughput(func, iterations): + def wrapper(*args): + + for _ in range(10): + func(*args) + + torch.cuda.reset_peak_memory_stats() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + + start_event.record() + + # benchmark + for _ in range(iterations): + func(*args) + + end_event.record() + + torch.cuda.synchronize() + + max_memory_allocated = torch.cuda.max_memory_allocated() / (1024 ** 2) # Convert bytes to MB + + return start_event.elapsed_time(end_event), max_memory_allocated + + return wrapper \ No newline at end of file diff --git a/bitlinear/frozen_bitlinear/tests/__main__.py b/bitlinear/frozen_bitlinear/tests/__main__.py new file mode 100644 index 0000000..73c66f4 --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/__main__.py @@ -0,0 +1,15 @@ +from argparse import ArgumentParser +from tests.Benchmark import Benchmark + + +parser = ArgumentParser() + +parser.add_argument("--kernel", type=str, default='TorchLinear') +parser.add_argument("-a", action='store_true') +parser.add_argument("-p", action='store_true') +parser.add_argument("-u", action='store_true') +parser.add_argument("-t", action='store_true') +parser.add_argument("--save_dir", type=str, default='results/tmp') +args = parser.parse_args() + +Benchmark(args) diff --git a/bitlinear/frozen_bitlinear/tests/helpers.py b/bitlinear/frozen_bitlinear/tests/helpers.py new file mode 100644 index 0000000..af96d6c --- /dev/null +++ b/bitlinear/frozen_bitlinear/tests/helpers.py @@ -0,0 +1,27 @@ +import torch +import os +import json + +def weights(N, K, kernel, baseline, dtype = torch.float16) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + path = f'weights/{kernel}/{N}_{K}' + try: + + torch_weights = torch.load(os.path.join(path, 'torch.pt'), weights_only=True).to('cuda') + kernel_weights = torch.load(os.path.join(path, 'kernel.pt'), weights_only=True).to('cuda') + scale = torch.load(os.path.join(path, 'scale.pt'), weights_only=True).to('cuda') + return torch_weights, kernel_weights, scale + + except Exception as e: + print(f"Could not load weights because of {e}, generating new ones at {path}") + + weights = torch.randn((N, K), device='cuda', dtype=dtype) + (base_weights, scale) = baseline.scale_weights(weights) + (kernel_weights, _) = kernel.scale_weights(weights) + + os.makedirs(path, exist_ok=True) + torch.save(base_weights, os.path.join(path, 'torch.pt')) + torch.save(kernel_weights, os.path.join(path, 'kernel.pt')) + torch.save(scale, os.path.join(path, 'scale.pt')) + + return base_weights, kernel_weights, scale + diff --git a/bitlinear/kernels.py b/bitlinear/kernels.py index 5896fe7..048e179 100644 --- a/bitlinear/kernels.py +++ b/bitlinear/kernels.py @@ -75,4 +75,4 @@ def __call__(input, weight, bias=None): value += bias[j] out.append(value) output.append(out) - return torch.Tensor(output) + return torch.Tensor(output) \ No newline at end of file diff --git a/bitlinear/measures.py b/bitlinear/measures.py index 9ac4bec..986394e 100644 --- a/bitlinear/measures.py +++ b/bitlinear/measures.py @@ -12,4 +12,4 @@ def __call__(self, input, keepdim): class AbsMedian(Measure): def __call__(self, input, keepdim): - return input.abs().median(dim=-1, keepdim=keepdim).values if keepdim else input.abs().median() + return input.abs().median(dim=-1, keepdim=keepdim).values if keepdim else input.abs().median() \ No newline at end of file