Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[flake8]

max-line-length = 120
per-file-ignores =
kerops/kernels/*: B007
__init__.py: F401
23 changes: 23 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Lint

on: [ pull_request ]

env:
MODULE_NAME: kerops

jobs:
lint:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'

- name: Check python code style
run: |
pip install -r requirements-dev.txt
flake8 .
isort --check .
black --check .
30 changes: 28 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,28 @@
# kerops
Efficient and fast algorithms on the GPU
# Kerops
Fast algorithms for the GPU

# Install
*pip is not available right now*
```shell
pip install kerops
```

# How fast is it?
Time comparison (ms) on an NVIDIA RTX 3090. Input is an array of size (1, channels, 350, 350, 128); float16; <b>channels_last_3d</b>. Compared to the usual 3D convolution from torch (kernel_size=3, padding=1, stride=1, bias=False, in_channels=channels, out_channels=channels). Slowdown relative to copying is shown in parentheses.

| channels |torch.clone| kerops.ops.DWConv |torch.nn.Conv3d(C->C)|
|:--------------------:|:---------:|:--------------------:|:-------------------:|
| 8 | 0.61 | 0.79 (x1.30) | 2.45 (x4.00) |
| 16 | 1.21 | 1.41 (x1.17) | 4.48 (x3.70) |
| 32 | 2.40 | 2.99 (x1.25) | 15.3 (x6.38) |
| 64 | 4.78 | 6.29 (x1.32) | 52.0 (x10.89) |
| 128 | 9.55 | 12.8 (x1.34) | 195.0 (x20.44) |


| channels |torch.clone|kerops.ops.DWConvWGRAD|torch.nn.Conv3d(C->C)|
|:--------------------:|:---------:|:--------------------:|:-------------------:|
| 8 | 0.61 | 2.55 (x4.18) | 7.14 (x11.70) |
| 16 | 1.21 | 3.01 (x2.49) | 12.1 (x10.00) |
| 32 | 2.40 | 4.80 (x2.00) | 24.6 (x10.25) |
| 64 | 4.78 | 8.72 (x1.82) | 71.3 (x14.91) |
| 128 | 9.55 | 17.9 (x1.87) | 245.0 (x25.65) |
2 changes: 1 addition & 1 deletion kerops/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.0.1'
__version__ = '0.0.2'
265 changes: 130 additions & 135 deletions kerops/kernels/dw_conv.py

Large diffs are not rendered by default.

101 changes: 0 additions & 101 deletions kerops/ops/_settings.py

This file was deleted.

16 changes: 6 additions & 10 deletions kerops/ops/addition.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
from triton import next_power_of_2

from ..kernels.addition import _AddStats_cl3d_backward_impl, _AddStats_cl3d_impl
from ._settings import configure, get_l1_cache, ConfigurableArg
from ..settings import ConfigurableArg, configure, get_l1_cache


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 8
)
@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
num_channels = x.shape[1]
numel = x.numel()
Expand Down Expand Up @@ -52,11 +49,10 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp
return output, mean, sqmean


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 8
)
def AddStatsBackward(add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
def AddStatsBackward(
add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg
):
num_channels = add_grad.shape[1]
numel = add_grad.numel()
assert add_result.shape == add_grad.shape
Expand Down
22 changes: 14 additions & 8 deletions kerops/ops/avgpool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
from triton import next_power_of_2

from ..kernels.avgpool import _AvgPoolCeilStats_cl3d_backward_impl, _AvgPoolCeilStats_cl3d_impl
from ._settings import configure, get_l1_cache, ConfigurableArg
from ..settings import ConfigurableArg, configure, get_l1_cache


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 2,
_l1_cache_bytes=get_l1_cache,
_num_warps=2,
)
def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
num_channels = x.shape[1]
Expand Down Expand Up @@ -60,11 +60,17 @@ def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: Configu
return output, mean, sqmean


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 4
)
def AvgPoolCeilStatsBackward(inpgrad, meangrad, sqmeangrad, output, outgrad_shape, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4)
def AvgPoolCeilStatsBackward(
inpgrad,
meangrad,
sqmeangrad,
output,
outgrad_shape,
*,
_l1_cache_bytes: ConfigurableArg,
_num_warps: ConfigurableArg,
):
MAX_SIZE = _l1_cache_bytes // inpgrad.element_size() # 32768 for fp16
bsize, num_channels, h_outgrad, w_outgrad, d_outgrad = outgrad_shape
d_inpgrad = inpgrad.shape[-1]
Expand Down
12 changes: 3 additions & 9 deletions kerops/ops/bnrelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,10 @@
from triton import next_power_of_2

from ..kernels.bnrelu import _ApplyBNReLU_cl3d_backward_impl, _ApplyBNReLU_cl3d_impl
from ._settings import configure, get_l1_cache, ConfigurableArg
from ..settings import ConfigurableArg, configure, get_l1_cache


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 8
)
@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
num_channels = x.shape[1]
numel = x.numel()
Expand Down Expand Up @@ -44,10 +41,7 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps
return output


@configure(
_l1_cache_bytes=lambda: get_l1_cache(),
_num_warps=lambda: 8
)
@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8)
def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg):
num_channels = x.shape[1]
numel = x.numel()
Expand Down
63 changes: 32 additions & 31 deletions kerops/ops/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,35 @@
from triton import language as tl, next_power_of_2

from ..kernels.dw_conv import _DWConv_cl3d_impl, _DWConv_wgrad_cl3d_impl
from ._settings import configure, ConfigurableArg
from ..settings import ConfigurableArg, configure


def configure_dwconv(channels):
"""
Hardcoded, benchmarked on RTX 3090, mb should be generated automatically
H, W, D = [350, 350, 128]
def dwconv_warps(channels):
    """Return the benchmarked warp count for the DWConv forward kernel.

    Raises KeyError for channel counts without a tuned configuration.
    """
    warps_for = {8: 1, 16: 2, 32: 2, 64: 2, 128: 4}
    return warps_for[channels]

channels: [[num_warps, D_block], [num_warps, D_block]] one for fwd another for bwd
"""

"""
TODO
More geeky solution is to compare performances with respect to splitting axis D
to N * D_block with padding
"""
def dwconv_dblock(channels):
    """Return the benchmarked D-axis block size for the DWConv forward kernel.

    Raises KeyError for channel counts without a tuned configuration.
    """
    dblock_for = {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}
    return dblock_for[channels]

HARDCODED_CONFIG = {
8: [[1, 32], [1, 32]],
16: [[2, 32], [1, 32]],
32: [[2, 16], [1, 32]],
64: [[2, 8], [1, 16]],
128: [[4, 8], [2, 16]],
}

return HARDCODED_CONFIG.get(channels, None)
def dwconv_wgrad_warps(channels):
    """Return the benchmarked warp count for the DWConv weight-gradient kernel.

    Raises KeyError for channel counts without a tuned configuration.
    """
    warps_for = {8: 1, 16: 1, 32: 1, 64: 1, 128: 2}
    return warps_for[channels]


def dwconv_wgrad_dblock(channels):
    """Return the benchmarked D-axis block size for the DWConv weight-gradient kernel.

    Raises KeyError for channel counts without a tuned configuration.
    """
    dblock_for = {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}
    return dblock_for[channels]


def dwconv_wgrad_ilp(channels):
    """Return the benchmarked ILP factor for the DWConv weight-gradient kernel.

    Raises KeyError for channel counts without a tuned configuration.
    """
    ilp_for = {8: 1, 16: 1, 32: 2, 64: 3, 128: 3}
    return ilp_for[channels]


@configure(
ACCTYPE=lambda: 'float32',
_num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0],
D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1],
ACCTYPE='float32',
_num_warps=lambda x: dwconv_warps(x.shape[1]),
D_block=lambda x: dwconv_dblock(x.shape[1]),
)
def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32):
def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg):
channels = x.shape[1]

assert x.ndim == 5
Expand Down Expand Up @@ -79,11 +74,14 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Confi


@configure(
_num_warps=lambda x: configure_dwconv(x.shape[1])[1][0],
ACCTYPE=lambda: 'float32',
D_block=lambda x: configure_dwconv(x.shape[1])[1][1],
ACCTYPE='float32',
_num_warps=lambda x: dwconv_wgrad_warps(x.shape[1]),
D_block=lambda x: dwconv_wgrad_dblock(x.shape[1]),
ILP=lambda x: dwconv_wgrad_ilp(x.shape[1]),
)
def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg=2, D_block: ConfigurableArg = 32):
def DWConvWGRAD(
x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg, ILP: ConfigurableArg
):
channels = x.shape[1]

assert x.ndim == grad.ndim == 5
Expand All @@ -99,10 +97,10 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Co
bsize, _, H, W, D = x.shape
batch_stride, _, H_stride, W_stride, _ = x.stride()

H_grid = ceil(H / 2)
H_grid = ceil(H / (2 * ILP))
W_grid = ceil(W / 2)
D_grid = ceil(D / D_block)
grid = (H_grid, W_grid * D_grid)
grid = (H_grid, W_grid, D_grid)

grad_w = torch.zeros([bsize, H_grid * W_grid * D_grid, 3, 3, 3, channels], device=x.device, dtype=torch.float16)
WD_grid = W_grid * D_grid # TODO: mb implement in another way
Expand All @@ -121,9 +119,12 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Co
channels,
D_block,
WD_grid,
D_grid,
H_grid,
ILP,
num_warps=_num_warps,
)

grad_w = torch.flip(grad_w.sum(dim=(0, 1)), dims=(2,))
grad_w = grad_w.sum(dim=(0, 1))

return grad_w
Loading