1 change: 1 addition & 0 deletions .gitignore
@@ -2,3 +2,4 @@ __pycache__
dist
examples/data
examples/models
.vscode/
10 changes: 4 additions & 6 deletions bitlinear/__init__.py
@@ -8,10 +8,8 @@
AbsMean,
AbsMedian,
)
from .kernels import (
Naive,
NaiveListComp,
TernaryNaive,
from .frozen_bitlinear.frozen_bitlinear import (
TorchLinear,
TorchMulAdd,
)
Naive
)

4 changes: 4 additions & 0 deletions bitlinear/frozen_bitlinear/.gitignore
@@ -0,0 +1,4 @@
results
**/weights
**.so
**/build
146 changes: 146 additions & 0 deletions bitlinear/frozen_bitlinear/README.md
@@ -0,0 +1,146 @@

## Motivations

Matrix multiplications are a key building block of most modern high-performance computing systems.
They are notoriously hard to optimize, so their implementations are generally supplied by
hardware vendors themselves as part of so-called "kernel libraries" (e.g., cuBLAS).
Unfortunately, these libraries are often proprietary and cannot easily be customized
to accommodate the needs of our ternary system. Attempts at [matmul-free linear implementations](https://github.com/ridgerchu/matmulfreellm) still rely on the cuBLAS kernel for the matmul. The goal of this investigation is to develop a CUDA kernel competitive with cuBLAS, and ultimately far faster, since a ternary weight matrix can be applied with a series of masked adds in place of traditional multiply-accumulates.
Roughly speaking, a straightforward kernel implements the following
algorithm to multiply an (M, K) matrix by a (K, N) matrix:

```C++
const uint col = blockIdx.x * blockDim.x + threadIdx.x;
const uint row = blockIdx.y * blockDim.y + threadIdx.y;

if (row < M && col < N) {

    float sum = 0.0f;

    for (int k = 0; k < K; k++) {
        sum += A[row * K + k] * B[k * N + col];
    }

    output[row * N + col] = sum + bias[row];
}
```
where each instance of the above code is a single thread.

We hope to develop a kernel that avoids the multiplicative step entirely, instead accumulating with masked adds:

```C++
const uint col = blockIdx.x * blockDim.x + threadIdx.x;
const uint row = blockIdx.y * blockDim.y + threadIdx.y;

if (row < M && col < N) {

    float sum = 0.0f;

    for (int k = 0; k < K; k++) {
        if (B[k * N + col] == 1) {
            sum += A[row * K + k];
        } else if (B[k * N + col] == -1) {
            sum -= A[row * K + k];
        }
    }

    output[row * N + col] = sum + bias[row];
}
```

The code above corresponds to the execution of a single thread. For more information on how CUDA parallelization works in code and on hardware, visit [here](https://developer.nvidia.com/blog/cuda-refresher-cuda-programming-model/).

### Optimizations

To realize savings from our algorithm, we turn to weight packing. Representing each ternary weight in two bits takes an eighth of the memory of 16-bit weights, which can yield significant speedups when memory loads are the limiting factor (as they often are).

Because weight loads are significantly cheaper per element than activation loads, we want to prioritize cache hits on activation loads. Several strategies help. Within each thread, we can perform more computations with the same activations; this can reduce memory traffic, but significant tuning is required to balance memory savings against lost parallelism. We can also use shared memory, which lets the threads in a block compute different outputs in parallel without each reloading the activations. A barebones implementation is shown below:

```C++
const uint col = blockIdx.x * blockDim.x + threadIdx.x;
const uint row = blockIdx.y * blockDim.y + threadIdx.y;

// activations for this row are staged in shared memory for the whole block
extern __shared__ half shared_activations[];

if (row < M) {
    // threads in the block cooperatively load this row's activations once
    for (int k = threadIdx.x; k < K; k += blockDim.x) {
        shared_activations[k] = A[row * K + k];
    }
}

__syncthreads();

if (row < M && col < N) {

    float sum = 0.0f;

    for (int k = 0; k < K; k++) {
        if (B[k * N + col] == 1) {
            sum += shared_activations[k];   // read from shared memory, not global
        } else if (B[k * N + col] == -1) {
            sum -= shared_activations[k];
        }
    }

    output[row * N + col] = sum + bias[row];
}
```
The tradeoff with this implementation is that threads within the block must synchronize, which is necessary to avoid race conditions but incurs a performance penalty.

Further optimization is required to balance how K, M, and N are tiled across blocks and individual threads. Because cuBLAS is highly optimized in this regard, we need a significant speedup to offset its gains. Later work can go into autotuning algorithms that dynamically schedule the computation so as to maximize the L2 cache hit ratio and minimize synchronization delay.

## Future Work

As of now, this optimization extends only to inference. Packing weights appears redundant in training, since shadow weights must be stored in higher precision, which is [necessary for gradient accumulation](https://arxiv.org/pdf/2402.17764). Speedups can still be realized in the forward passes, but more investigation and more robust kernels are needed to make that a reality.

Further work can also go into fusing these kernels with layer norm, softmax, and other adjoining layers for even greater speedups. Kernel fusion has been growing in popularity recently, and luckily, a lot of the work [has been done](https://github.com/ridgerchu/matmulfreellm) for ternary weight layers; we just need to plug in our matmul kernel.

Finally, we currently pack 1.58 bits of information into 2, but compression algorithms could push weight packing further. Because a weight can only occupy one of three states ($00$, $10$, or $01$), we can compress $10$ to just $1$, for roughly a further 13% savings. Another option is extending the coding context across weights, which introduces complex interdependencies; this trades extra bit-level computation for smaller loads, and since a higher-precision load is magnitudes more costly than bit-level computation, it may prove beneficial. As weight loads are not the crux of the cost, however, it is unclear whether this path is worth pursuing.

## Setup
Make sure you have the correct packages installed in your virtual environment. In a conda environment, you can run:
```
> conda create -n env python pip
> conda activate env
> pip install -r requirements.txt
```

Afterwards, you need to build the desired CUDA kernels for use on your machine. To do so, you can selectively build the kernels you would like to test or use through:
```
> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/pack_weights
> python setup.py build_ext --inplace
> cd path/to/bitlinear/bitlinear/frozen_bitlinear/cuda/kernels/*
> python setup.py build_ext --inplace
```
You can also build all of them at once by running
```
> cd path/to/bitlinear/bitlinear/frozen_bitlinear/
> chmod +x scripts/build.sh
> scripts/build.sh
```

## Testing
Choose one of the kernels available in ```path/to/bitlinear/bitlinear/frozen_bitlinear/src/``` to test against the PyTorch baseline for your device.
```
> conda activate env
> cd path/to/bitlinear/bitlinear/frozen_bitlinear
> chmod +x scripts/test.sh
> scripts/test.sh
-d <device> (CUDA_VISIBLE_DEVICES=$device)
-k <kernel_name>
```

All logs, data, and plots are stored locally under ```path/to/bitlinear/bitlinear/frozen_bitlinear/results/{%Y%m%d_%T}/```.

If any issues come up, you can reach me at [sopsahl@mit.edu](mailto:sopsahl@mit.edu).

All benchmarks have been tested on an A6000.

## Cleanup
To clean the builds and results, run the following commands.
```
> cd path/to/bitlinear/bitlinear/frozen_bitlinear
> chmod +x scripts/clean.sh
> scripts/clean.sh
-b (default: builds only)
-w (weights)
-r (results)
```


## References
...
1 change: 1 addition & 0 deletions bitlinear/frozen_bitlinear/__init__.py
@@ -0,0 +1 @@
from src import TorchLinear, Naive, Warptiling
2 changes: 2 additions & 0 deletions bitlinear/frozen_bitlinear/cuda/__init__.py
@@ -0,0 +1,2 @@
from cuda.kernels import *
from cuda.pack_weights import pack_weights
2 changes: 2 additions & 0 deletions bitlinear/frozen_bitlinear/cuda/kernels/__init__.py
@@ -0,0 +1,2 @@
from cuda.kernels.naive_linear import naive
from cuda.kernels.warptiling import warptiling
84 changes: 84 additions & 0 deletions bitlinear/frozen_bitlinear/cuda/kernels/archive/bitlinear_naive.cu
@@ -0,0 +1,84 @@
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>

#define CEIL_DIV(x, y) (((x) + (y) - 1) / (y))

__global__ void naive_kernel(
const half *input,
const int *weights,
const half *bias,
half *output,
float scale,
int M,
int N,
int K
)
{
const uint col = blockIdx.x * blockDim.x + threadIdx.x;
const uint row = blockIdx.y * blockDim.y + threadIdx.y;

if (row < M && col < N) {

float sum = 0.0f;
int weight;

for (int k = 0; k < K; k += 16) {
weight = weights[(col * K + k)/16];

for (int offset=0; offset<16; offset++) {
int mask = (weight & (3 << (2 * offset))) >> (2 * offset);
float input_val = __half2float(input[row * K + k + offset]);

if (mask == 1) {
sum += input_val;
} else if (mask == 2) {
sum -= input_val;
}
}

}

// Store the result into the correct location in memory
output[row * N + col] = __float2half((sum + __half2float(bias[row])));

}
}

torch::Tensor linear(
torch::Tensor input,
torch::Tensor weights,
torch::Tensor bias,
float scale,
int M,
int N,
int K
) {
auto options = torch::TensorOptions().dtype(torch::kFloat16).device(torch::kCUDA, 0);
auto output = torch::zeros({M, N}, options);

uint blockSize = 32;

dim3 dimGrid(CEIL_DIV(N, blockSize), CEIL_DIV(M, blockSize));
dim3 dimBlock(blockSize, blockSize);

naive_kernel<<<dimGrid, dimBlock>>>(
reinterpret_cast<const half*>(input.data_ptr<at::Half>()),
weights.data_ptr<int>(),
reinterpret_cast<const half*>(bias.data_ptr<at::Half>()),
reinterpret_cast<half*>(output.data_ptr<at::Half>()),
scale,
M, N, K
);

return output;
}


// Binding to generate the .so file, to call from Python.
PYBIND11_MODULE(bitlinear_naive, m) {
m.doc() = "Implementation of bitlinear forward linear in CUDA";
m.def("linear", &linear, "bitlinear_forward (CUDA)");
}