diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..a1f5a6a --- /dev/null +++ b/.flake8 @@ -0,0 +1,6 @@ +[flake8] + +max-line-length = 120 +per-file-ignores = + kerops/kernels/*: B007 + __init__.py: F401 \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..951d456 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,23 @@ +name: Lint + +on: [ pull_request ] + +env: + MODULE_NAME: kerops + +jobs: + lint: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Check python code style + run: | + pip install -r requirements-dev.txt + flake8 . + isort --check . + black --check . \ No newline at end of file diff --git a/README.md b/README.md index c3b11f8..38dd106 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,28 @@ -# kerops -Efficient and fast algorithms on the GPU +# Kerops +Fast algorithms for the GPU + +# Install +*The package is not published on PyPI yet; once it is, install it with:* +```shell +pip install kerops +``` + +# How fast is it? +Time comparison (ms) on an NVIDIA RTX 3090. The input is a float16 tensor of shape (1, channels, 350, 350, 128) in channels_last_3d memory format. The baseline is the standard 3D convolution from torch (kernel_size=3, padding=1, stride=1, bias=False, in_channels=channels, out_channels=channels). The slowdown relative to a plain copy (torch.clone) is shown in parentheses. 
+ +| channels |torch.clone| kerops.ops.DWConv |torch.nn.Conv3d(C->C)| +|:--------------------:|:---------:|:--------------------:|:-------------------:| +| 8 | 0.61 | 0.79 (x1.30) | 2.45 (x4.00) | +| 16 | 1.21 | 1.41 (x1.17) | 4.48 (x3.70) | +| 32 | 2.40 | 2.99 (x1.25) | 15.3 (x6.38) | +| 64 | 4.78 | 6.29 (x1.32) | 52.0 (x10.89) | +| 128 | 9.55 | 12.8 (x1.34) | 195.0 (x20.44) | + + +| channels |torch.clone|kerops.ops.DWConvWGRAD|torch.nn.Conv3d(C->C)| +|:--------------------:|:---------:|:--------------------:|:-------------------:| +| 8 | 0.61 | 2.55 (x4.18) | 7.14 (x11.70) | +| 16 | 1.21 | 3.01 (x2.49) | 12.1 (x10.00) | +| 32 | 2.40 | 4.80 (x2.00) | 24.6 (x10.25) | +| 64 | 4.78 | 8.72 (x1.82) | 71.3 (x14.91) | +| 128 | 9.55 | 17.9 (x1.87) | 245.0 (x25.65) | diff --git a/kerops/__version__.py b/kerops/__version__.py index b8023d8..d18f409 100644 --- a/kerops/__version__.py +++ b/kerops/__version__.py @@ -1 +1 @@ -__version__ = '0.0.1' +__version__ = '0.0.2' diff --git a/kerops/kernels/dw_conv.py b/kerops/kernels/dw_conv.py index d2fa328..4bcf056 100644 --- a/kerops/kernels/dw_conv.py +++ b/kerops/kernels/dw_conv.py @@ -28,13 +28,13 @@ def _DWConv_cl3d_impl( d_offset = tl.arange(0, D_block) near_offset = tl.arange(0, 4) - 1 - offset = d_offset[:, None, None] * channels + channels_offset[None, :, None] + near_offset[None, None, :] * channels - mask = d_offset[:, None, None] + near_offset[None, None, :] < D - D_block * D_cell - mask = mask and (d_offset[:, None, None] + near_offset[None, None, :] >= 0 - D_block * D_cell) - mask = mask and (near_offset[None, None, :] != 2) + offset = d_offset[:, None, None] * channels + channels_offset[None, None, :] + near_offset[None, :, None] * channels + mask = d_offset[:, None, None] + near_offset[None, :, None] < D - D_block * D_cell + mask = mask and (d_offset[:, None, None] + near_offset[None, :, None] >= 0 - D_block * D_cell) + mask = mask and (near_offset[None, :, None] != 2) - weight_offset = channels_offset[None, :, None] 
+ tl.arange(0, 4)[None, None, :] * channels - weight_mask = tl.arange(0, 4)[None, None, :] != 3 + weight_offset = channels_offset[None, None, :] + tl.arange(0, 4)[None, :, None] * channels + weight_mask = tl.arange(0, 4)[None, :, None] != 3 weight_h0_w0 = tl.load(weight_ptr + weight_offset, mask=weight_mask, other=0.0) weight_h0_w1 = tl.load((weight_ptr + 3 * channels) + weight_offset, mask=weight_mask, other=0.0) @@ -68,57 +68,57 @@ def _DWConv_cl3d_impl( for k in tl.static_range(0, 16): if k == 0: - h0_w0 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h0_w0, axis=1) elif k == 1: - h0_w0 += tl.sum(x * weight_h1_w0, axis=2) - h1_w0 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h1_w0, axis=1) + h1_w0 += tl.sum(x * weight_h0_w0, axis=1) elif k == 2: - h0_w0 += tl.sum(x * weight_h2_w0, axis=2) - h1_w0 += tl.sum(x * weight_h1_w0, axis=2) + h0_w0 += tl.sum(x * weight_h2_w0, axis=1) + h1_w0 += tl.sum(x * weight_h1_w0, axis=1) elif k == 3: - h1_w0 += tl.sum(x * weight_h2_w0, axis=2) + h1_w0 += tl.sum(x * weight_h2_w0, axis=1) elif k == 4: - h0_w0 += tl.sum(x * weight_h0_w1, axis=2) - h0_w1 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h0_w1, axis=1) + h0_w1 += tl.sum(x * weight_h0_w0, axis=1) elif k == 5: - h0_w0 += tl.sum(x * weight_h1_w1, axis=2) - h0_w1 += tl.sum(x * weight_h1_w0, axis=2) - h1_w0 += tl.sum(x * weight_h0_w1, axis=2) - h1_w1 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h1_w1, axis=1) + h0_w1 += tl.sum(x * weight_h1_w0, axis=1) + h1_w0 += tl.sum(x * weight_h0_w1, axis=1) + h1_w1 += tl.sum(x * weight_h0_w0, axis=1) elif k == 6: - h0_w0 += tl.sum(x * weight_h2_w1, axis=2) - h0_w1 += tl.sum(x * weight_h2_w0, axis=2) - h1_w0 += tl.sum(x * weight_h1_w1, axis=2) - h1_w1 += tl.sum(x * weight_h1_w0, axis=2) + h0_w0 += tl.sum(x * weight_h2_w1, axis=1) + h0_w1 += tl.sum(x * weight_h2_w0, axis=1) + h1_w0 += tl.sum(x * weight_h1_w1, axis=1) + h1_w1 += tl.sum(x * weight_h1_w0, axis=1) elif k 
== 7: - h1_w0 += tl.sum(x * weight_h2_w1, axis=2) - h1_w1 += tl.sum(x * weight_h2_w0, axis=2) + h1_w0 += tl.sum(x * weight_h2_w1, axis=1) + h1_w1 += tl.sum(x * weight_h2_w0, axis=1) elif k == 8: - h0_w0 += tl.sum(x * weight_h0_w2, axis=2) - h0_w1 += tl.sum(x * weight_h0_w1, axis=2) + h0_w0 += tl.sum(x * weight_h0_w2, axis=1) + h0_w1 += tl.sum(x * weight_h0_w1, axis=1) elif k == 9: - h0_w0 += tl.sum(x * weight_h1_w2, axis=2) - h0_w1 += tl.sum(x * weight_h1_w1, axis=2) - h1_w0 += tl.sum(x * weight_h0_w2, axis=2) - h1_w1 += tl.sum(x * weight_h0_w1, axis=2) + h0_w0 += tl.sum(x * weight_h1_w2, axis=1) + h0_w1 += tl.sum(x * weight_h1_w1, axis=1) + h1_w0 += tl.sum(x * weight_h0_w2, axis=1) + h1_w1 += tl.sum(x * weight_h0_w1, axis=1) elif k == 10: - h0_w0 += tl.sum(x * weight_h2_w2, axis=2) - h0_w1 += tl.sum(x * weight_h2_w1, axis=2) - h1_w0 += tl.sum(x * weight_h1_w2, axis=2) - h1_w1 += tl.sum(x * weight_h1_w1, axis=2) + h0_w0 += tl.sum(x * weight_h2_w2, axis=1) + h0_w1 += tl.sum(x * weight_h2_w1, axis=1) + h1_w0 += tl.sum(x * weight_h1_w2, axis=1) + h1_w1 += tl.sum(x * weight_h1_w1, axis=1) elif k == 11: - h1_w0 += tl.sum(x * weight_h2_w2, axis=2) - h1_w1 += tl.sum(x * weight_h2_w1, axis=2) + h1_w0 += tl.sum(x * weight_h2_w2, axis=1) + h1_w1 += tl.sum(x * weight_h2_w1, axis=1) elif k == 12: - h0_w1 += tl.sum(x * weight_h0_w2, axis=2) + h0_w1 += tl.sum(x * weight_h0_w2, axis=1) elif k == 13: - h0_w1 += tl.sum(x * weight_h1_w2, axis=2) - h1_w1 += tl.sum(x * weight_h0_w2, axis=2) + h0_w1 += tl.sum(x * weight_h1_w2, axis=1) + h1_w1 += tl.sum(x * weight_h0_w2, axis=1) elif k == 14: - h0_w1 += tl.sum(x * weight_h2_w2, axis=2) - h1_w1 += tl.sum(x * weight_h1_w2, axis=2) + h0_w1 += tl.sum(x * weight_h2_w2, axis=1) + h1_w1 += tl.sum(x * weight_h1_w2, axis=1) else: - h1_w1 += tl.sum(x * weight_h2_w2, axis=2) + h1_w1 += tl.sum(x * weight_h2_w2, axis=1) k_ = k + 1 i = (k_ % 4) - 1 @@ -157,17 +157,17 @@ def _DWConv_wgrad_cl3d_impl( channels: tl.constexpr, D_block: tl.constexpr, 
WD_grid, + D_grid, + delta_H_grid, + ILP: tl.constexpr, ): H_cell = tl.program_id(0) - W_D_cell = tl.program_id(1) - - D_gridsize = tl.cdiv(D, D_block) - W_cell = W_D_cell // D_gridsize - D_cell = W_D_cell % D_gridsize + W_cell = tl.program_id(1) + D_cell = tl.program_id(2) input_ptr += D_cell * D_block * channels grad_ptr += D_cell * D_block * channels - weight_grad_ptr += (H_cell * WD_grid + W_D_cell) * 27 * channels + weight_grad_ptr += (H_cell * WD_grid + W_cell * D_grid + D_cell) * 27 * channels channels_offset = tl.arange(0, channels) channels_offset = tl.max_contiguous(tl.multiple_of(channels_offset, channels), channels) @@ -179,11 +179,8 @@ def _DWConv_wgrad_cl3d_impl( mask = mask and (d_offset[None, None, :] + near_offset[:, None, None] >= 0 - D_block * D_cell) mask = mask and (near_offset[:, None, None] != 2) - in_offset = d_offset[None, None, :] * channels + channels_offset[None, :, None] - in_mask = d_offset[None, None, :] < D - D_block * D_cell - - H1_load = 2 * H_cell + 1 < H - W1_load = 2 * W_cell + 1 < W + grad_offset = d_offset[None, :] * channels + channels_offset[:, None] + grad_mask = d_offset[None, :] < D - D_block * D_cell h0_w0 = tl.zeros([4, channels], dtype=ACCTYPE) h0_w1 = tl.zeros([4, channels], dtype=ACCTYPE) @@ -195,91 +192,89 @@ def _DWConv_wgrad_cl3d_impl( h2_w1 = tl.zeros([4, channels], dtype=ACCTYPE) h2_w2 = tl.zeros([4, channels], dtype=ACCTYPE) - tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + 2 * W_cell * W_stride - x_h0_w0 = tl.load(tmp_input_ptr + in_offset, mask=in_mask, other=0.0) - - tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + 2 * W_cell * W_stride - x_h1_w0 = tl.load(tmp_input_ptr + in_offset, mask=in_mask and H1_load, other=0.0) - - tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + (2 * W_cell + 1) * W_stride - x_h0_w1 = tl.load(tmp_input_ptr + in_offset, mask=in_mask and W1_load, other=0.0) - - tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride - x_h1_w1 = 
tl.load(tmp_input_ptr + in_offset, mask=in_mask and (W1_load and H1_load), other=0.0) - gradw_offset = tl.arange(0, 4)[:, None] * channels + channels_offset[None, :] gradw_mask = near_offset[:, None] != 2 - load_next = (2 * H_cell - 1 < H and 2 * H_cell - 1 >= 0) and (2 * W_cell - 1 < W and 2 * W_cell - 1 >= 0) - tmp_grad_ptr = grad_ptr + (2 * H_cell - 1) * H_stride + (2 * W_cell - 1) * W_stride - i = -1 - j = -1 - grad = tl.zeros([4, channels, D_block], dtype=tl.float16) - if load_next: - grad = tl.load(tmp_grad_ptr + offset, mask=mask) - - for k in tl.static_range(0, 16): - if load_next: - if i == -1 and j == -1: - h2_w2 += tl.sum(grad * x_h0_w0, axis=2) - elif i == -1 and j == 0: - h2_w1 += tl.sum(grad * x_h0_w0, axis=2) - h2_w2 += tl.sum(grad * x_h0_w1, axis=2) - elif i == -1 and j == 1: - h2_w0 += tl.sum(grad * x_h0_w0, axis=2) - h2_w1 += tl.sum(grad * x_h0_w1, axis=2) - elif i == -1 and j == 2: - h2_w0 += tl.sum(grad * x_h0_w1, axis=2) - elif i == 0 and j == -1: - h1_w2 += tl.sum(grad * x_h0_w0, axis=2) - h2_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 0 and j == 0: - h1_w1 += tl.sum(grad * x_h0_w0, axis=2) - h2_w1 += tl.sum(grad * x_h1_w0, axis=2) - h1_w2 += tl.sum(grad * x_h0_w1, axis=2) - h2_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 0 and j == 1: - h1_w0 += tl.sum(grad * x_h0_w0, axis=2) - h2_w0 += tl.sum(grad * x_h1_w0, axis=2) - h1_w1 += tl.sum(grad * x_h0_w1, axis=2) - h2_w1 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 0 and j == 2: - h1_w0 += tl.sum(grad * x_h0_w1, axis=2) - h2_w0 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == -1: - h0_w2 += tl.sum(grad * x_h0_w0, axis=2) - h1_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 1 and j == 0: - h0_w1 += tl.sum(grad * x_h0_w0, axis=2) - h1_w1 += tl.sum(grad * x_h1_w0, axis=2) - h0_w2 += tl.sum(grad * x_h0_w1, axis=2) - h1_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == 1: - h0_w0 += tl.sum(grad * x_h0_w0, axis=2) - h1_w0 += tl.sum(grad * x_h1_w0, axis=2) - h0_w1 += 
tl.sum(grad * x_h0_w1, axis=2) - h1_w1 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == 2: - h0_w0 += tl.sum(grad * x_h0_w1, axis=2) - h1_w0 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 2 and j == -1: - h0_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 2 and j == 0: - h0_w1 += tl.sum(grad * x_h1_w0, axis=2) - h0_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 2 and j == 1: - h0_w0 += tl.sum(grad * x_h1_w0, axis=2) - h0_w1 += tl.sum(grad * x_h1_w1, axis=2) - else: - h0_w0 += tl.sum(grad * x_h1_w1, axis=2) - - k_ = k + 1 - i = (k_ % 4) - 1 - j = (k_ // 4) - 1 - load_next = (2 * H_cell + i < H and 2 * H_cell + i >= 0) and (2 * W_cell + j < W and 2 * W_cell + j >= 0) - tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride - if load_next and k_ < 16: - grad = tl.load(tmp_grad_ptr + offset, mask=mask) # , other=0.) + for ilp in tl.static_range(0, ILP): + H0_load = 2 * H_cell < H + H1_load = 2 * H_cell + 1 < H + W1_load = 2 * W_cell + 1 < W + + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + 2 * W_cell * W_stride + x_h0_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H0_load) + + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + 2 * W_cell * W_stride + x_h1_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H1_load) + + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + (2 * W_cell + 1) * W_stride + x_h0_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H0_load)) + + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride + x_h1_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H1_load)) + + for k in tl.static_range(0, 16): + i = (k % 4) - 1 + j = (k // 4) - 1 + load_next = (2 * H_cell + i < H and 2 * H_cell + i >= 0) and (2 * W_cell + j < W and 2 * W_cell + j >= 0) + tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride + + if load_next: + grad = tl.load(tmp_grad_ptr + grad_offset, mask=grad_mask, other=0.0)[None] + + if i == 
-1 and j == -1: + h2_w2 += tl.sum(grad * x_h0_w0, axis=2) + elif i == -1 and j == 0: + h2_w1 += tl.sum(grad * x_h0_w0, axis=2) + h2_w2 += tl.sum(grad * x_h0_w1, axis=2) + elif i == -1 and j == 1: + h2_w0 += tl.sum(grad * x_h0_w0, axis=2) + h2_w1 += tl.sum(grad * x_h0_w1, axis=2) + elif i == -1 and j == 2: + h2_w0 += tl.sum(grad * x_h0_w1, axis=2) + elif i == 0 and j == -1: + h1_w2 += tl.sum(grad * x_h0_w0, axis=2) + h2_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 0 and j == 0: + h1_w1 += tl.sum(grad * x_h0_w0, axis=2) + h2_w1 += tl.sum(grad * x_h1_w0, axis=2) + h1_w2 += tl.sum(grad * x_h0_w1, axis=2) + h2_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 0 and j == 1: + h1_w0 += tl.sum(grad * x_h0_w0, axis=2) + h2_w0 += tl.sum(grad * x_h1_w0, axis=2) + h1_w1 += tl.sum(grad * x_h0_w1, axis=2) + h2_w1 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 0 and j == 2: + h1_w0 += tl.sum(grad * x_h0_w1, axis=2) + h2_w0 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == -1: + h0_w2 += tl.sum(grad * x_h0_w0, axis=2) + h1_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 1 and j == 0: + h0_w1 += tl.sum(grad * x_h0_w0, axis=2) + h1_w1 += tl.sum(grad * x_h1_w0, axis=2) + h0_w2 += tl.sum(grad * x_h0_w1, axis=2) + h1_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == 1: + h0_w0 += tl.sum(grad * x_h0_w0, axis=2) + h1_w0 += tl.sum(grad * x_h1_w0, axis=2) + h0_w1 += tl.sum(grad * x_h0_w1, axis=2) + h1_w1 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == 2: + h0_w0 += tl.sum(grad * x_h0_w1, axis=2) + h1_w0 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 2 and j == -1: + h0_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 2 and j == 0: + h0_w1 += tl.sum(grad * x_h1_w0, axis=2) + h0_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 2 and j == 1: + h0_w0 += tl.sum(grad * x_h1_w0, axis=2) + h0_w1 += tl.sum(grad * x_h1_w1, axis=2) + else: + h0_w0 += tl.sum(grad * x_h1_w1, axis=2) + + H_cell += delta_H_grid tl.store(weight_grad_ptr + gradw_offset, h0_w0, mask=gradw_mask) 
tl.store((weight_grad_ptr + 3 * channels) + gradw_offset, h0_w1, mask=gradw_mask) diff --git a/kerops/ops/_settings.py b/kerops/ops/_settings.py deleted file mode 100644 index a2f0d25..0000000 --- a/kerops/ops/_settings.py +++ /dev/null @@ -1,101 +0,0 @@ -import inspect -from functools import wraps - - -L1_CACHE_BYTES = 65536 - - -def get_l1_cache(): - global L1_CACHE_BYTES - return L1_CACHE_BYTES - - -def set_l1_cache(new_cache): - global L1_CACHE_BYTES - L1_CACHE_BYTES = new_cache - - -class ConfigurableArg: - pass - - -class EmptyKwarg: - pass - - -def check_function_signature(signature): - for param in signature.parameters.values(): - if param.annotation is ConfigurableArg and param.kind is not inspect.Parameter.KEYWORD_ONLY: - raise RuntimeError(f'ConfigurableArg must be keyword-only - {param.name}') - elif param.annotation is not ConfigurableArg and param.kind is inspect.Parameter.KEYWORD_ONLY: - raise RuntimeError(f'non-ConfigurableArg must not be keyword-only - {param.name}') - - -def get_configurable_args_from_signature(signature): - return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] - - -def get_usual_args_from_signature(signature): - return [ - param.name for param in signature.parameters.values() - if param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD - ] - - -def is_configurators_fit(configurable_args, configurators_names): - if set(configurable_args) != set(configurators_names): - raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') - - -def configurator_call(args, configurator, usual_args): - conf_sign = inspect.signature(configurator) - - # take argnames from configurator, map args with respect to origin function's argnames - conf_args = [args[usual_args.index(param.name)] for param in conf_sign.parameters.values()] - - return configurator(*conf_args) - - -class ConfiguredFunction: - def __init__(self, 
origin_function, signature, configurable_args, usual_args, **configurators): - self.origin_function = origin_function - self.signature = signature - self.configurable_args = configurable_args - self.usual_args = usual_args - self.configurators = configurators - - - def __call__(self, *args, **kwargs): - tmp_kwargs = {**{arg: EmptyKwarg for arg in self.configurable_args}, **kwargs} - - bind = self.signature.bind(*args, **tmp_kwargs) - bind.apply_defaults() - - configured_kwargs = { - k: configurator_call(bind.args, self.configurators[k], self.usual_args) - if input_v is EmptyKwarg else input_v - for k, input_v in bind.kwargs.items() - } - - return self.origin_function(*bind.args, **configured_kwargs) - - - def reconfigure(self, **new_configurators): - is_configurators_fit(self.configurable_args, new_configurators.keys()) - self.configurators = new_configurators - - -def configure(**configurators): - def wrapper(function): - signature = inspect.signature(function) - - check_function_signature(signature) - - configurable_args = get_configurable_args_from_signature(signature) - - usual_args = get_usual_args_from_signature(signature) - - is_configurators_fit(configurable_args, configurators.keys()) - - return wraps(function)(ConfiguredFunction(function, signature, configurable_args, usual_args, **configurators)) - return wrapper diff --git a/kerops/ops/addition.py b/kerops/ops/addition.py index 426e52d..0e29f76 100644 --- a/kerops/ops/addition.py +++ b/kerops/ops/addition.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.addition import _AddStats_cl3d_backward_impl, _AddStats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ..settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): 
num_channels = x.shape[1] numel = x.numel() @@ -52,11 +49,10 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp return output, mean, sqmean -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 -) -def AddStatsBackward(add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) +def AddStatsBackward( + add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg +): num_channels = add_grad.shape[1] numel = add_grad.numel() assert add_result.shape == add_grad.shape diff --git a/kerops/ops/avgpool.py b/kerops/ops/avgpool.py index 39453ec..82e0f18 100644 --- a/kerops/ops/avgpool.py +++ b/kerops/ops/avgpool.py @@ -5,12 +5,12 @@ from triton import next_power_of_2 from ..kernels.avgpool import _AvgPoolCeilStats_cl3d_backward_impl, _AvgPoolCeilStats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ..settings import ConfigurableArg, configure, get_l1_cache @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 2, + _l1_cache_bytes=get_l1_cache, + _num_warps=2, ) def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] @@ -60,11 +60,17 @@ def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: Configu return output, mean, sqmean -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 -) -def AvgPoolCeilStatsBackward(inpgrad, meangrad, sqmeangrad, output, outgrad_shape, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) +def AvgPoolCeilStatsBackward( + inpgrad, + meangrad, + sqmeangrad, + output, + outgrad_shape, + *, + _l1_cache_bytes: ConfigurableArg, + _num_warps: ConfigurableArg, +): MAX_SIZE = _l1_cache_bytes // inpgrad.element_size() # 
32768 for fp16 bsize, num_channels, h_outgrad, w_outgrad, d_outgrad = outgrad_shape d_inpgrad = inpgrad.shape[-1] diff --git a/kerops/ops/bnrelu.py b/kerops/ops/bnrelu.py index d0076ef..f894ee1 100644 --- a/kerops/ops/bnrelu.py +++ b/kerops/ops/bnrelu.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.bnrelu import _ApplyBNReLU_cl3d_backward_impl, _ApplyBNReLU_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ..settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() @@ -44,10 +41,7 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps return output -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index 1f1fae8..b163dc8 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -4,40 +4,35 @@ from triton import language as tl, next_power_of_2 from ..kernels.dw_conv import _DWConv_cl3d_impl, _DWConv_wgrad_cl3d_impl -from ._settings import configure, ConfigurableArg +from ..settings import ConfigurableArg, configure -def configure_dwconv(channels): - """ - Hardcoded, benchmarked on RTX 3090, mb should be generated automatically - H, W, D = [350, 350, 128] +def dwconv_warps(channels): + return {8: 1, 16: 2, 32: 2, 64: 2, 128: 4}[channels] - channels: [[num_warps, D_block], [num_warps, D_block]] one for fwd another for bwd - """ - """ - TODO - More geeky solution is to compare performances with respect to splitting axis 
D - to N * D_block with padding - """ +def dwconv_dblock(channels): + return {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}[channels] - HARDCODED_CONFIG = { - 8: [[1, 32], [1, 32]], - 16: [[2, 32], [1, 32]], - 32: [[2, 16], [1, 32]], - 64: [[2, 8], [1, 16]], - 128: [[4, 8], [2, 16]], - } - return HARDCODED_CONFIG.get(channels, None) +def dwconv_wgrad_warps(channels): + return {8: 1, 16: 1, 32: 1, 64: 1, 128: 2}[channels] + + +def dwconv_wgrad_dblock(channels): + return {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}[channels] + + +def dwconv_wgrad_ilp(channels): + return {8: 1, 16: 1, 32: 2, 64: 3, 128: 3}[channels] @configure( - ACCTYPE=lambda: 'float32', - _num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0], - D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1], + ACCTYPE='float32', + _num_warps=lambda x: dwconv_warps(x.shape[1]), + D_block=lambda x: dwconv_dblock(x.shape[1]), ) -def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32): +def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): channels = x.shape[1] assert x.ndim == 5 @@ -79,11 +74,14 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Confi @configure( - _num_warps=lambda x: configure_dwconv(x.shape[1])[1][0], - ACCTYPE=lambda: 'float32', - D_block=lambda x: configure_dwconv(x.shape[1])[1][1], + ACCTYPE='float32', + _num_warps=lambda x: dwconv_wgrad_warps(x.shape[1]), + D_block=lambda x: dwconv_wgrad_dblock(x.shape[1]), + ILP=lambda x: dwconv_wgrad_ilp(x.shape[1]), ) -def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg=2, D_block: ConfigurableArg = 32): +def DWConvWGRAD( + x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg, ILP: ConfigurableArg +): channels = x.shape[1] assert x.ndim == grad.ndim == 5 @@ -99,10 +97,10 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: 
ConfigurableArg = 'float32', _num_warps: Co bsize, _, H, W, D = x.shape batch_stride, _, H_stride, W_stride, _ = x.stride() - H_grid = ceil(H / 2) + H_grid = ceil(H / (2 * ILP)) W_grid = ceil(W / 2) D_grid = ceil(D / D_block) - grid = (H_grid, W_grid * D_grid) + grid = (H_grid, W_grid, D_grid) grad_w = torch.zeros([bsize, H_grid * W_grid * D_grid, 3, 3, 3, channels], device=x.device, dtype=torch.float16) WD_grid = W_grid * D_grid # TODO: mb implement in another way @@ -121,9 +119,12 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Co channels, D_block, WD_grid, + D_grid, + H_grid, + ILP, num_warps=_num_warps, ) - grad_w = torch.flip(grad_w.sum(dim=(0, 1)), dims=(2,)) + grad_w = grad_w.sum(dim=(0, 1)) return grad_w diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 960c59f..ea8e8f9 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -4,33 +4,38 @@ from triton import next_power_of_2 from ..kernels.linear import _ReLULinearAdd, _ReLULinearAddBackward -from ._settings import configure, ConfigurableArg +from ..settings import ConfigurableArg, configure -def configure_linear(in_channels): - # num_warps, D_block, ILP - HARDCODED_CONFIG = { - 16: [[2, 16, 8], [4, 16, 16]], - 32: [[2, 16, 8], [8, 32, 16]], - 64: [[1, 16, 4], [8, 32, 16]], - 128: [[1, 16, 4], [8, 32, 16]], - } +def fwd_warps(in_channels): + return {16: 2, 32: 2, 64: 1, 128: 1}[in_channels] + + +def fwd_ilp(in_channels): + return {16: 8, 32: 8, 64: 4, 128: 4}[in_channels] + + +def bwd_warps(in_channels): + return {16: 4, 32: 8, 64: 8, 128: 8}[in_channels] + + +def bwd_dblock(in_channels): + return {16: 16, 32: 32, 64: 32, 128: 32}[in_channels] - return HARDCODED_CONFIG.get(in_channels, None) @configure( - _num_warps=lambda weight: configure_linear(weight.shape[0])[0][0], - D_block=lambda weight: configure_linear(weight.shape[0])[0][1], - _ILP=lambda weight: configure_linear(weight.shape[0])[0][2], + _num_warps=lambda weight: fwd_warps(weight.shape[0]), + 
D_block=16, + _ILP=lambda weight: fwd_ilp(weight.shape[0]), ) def ReLULinearAdd( x, weight, add_other, *, - _num_warps: ConfigurableArg=2, - D_block: ConfigurableArg=16, - _ILP: ConfigurableArg=8, + _num_warps: ConfigurableArg, + D_block: ConfigurableArg, + _ILP: ConfigurableArg, ): in_channels = x.shape[1] out_channels = weight.shape[1] @@ -72,18 +77,18 @@ def ReLULinearAdd( @configure( - _num_warps=lambda weight: configure_linear(weight.shape[0])[1][0], - D_block=lambda weight: configure_linear(weight.shape[0])[1][1], - _ILP=lambda weight: configure_linear(weight.shape[0])[1][2], + _num_warps=lambda weight: bwd_warps(weight.shape[0]), + D_block=lambda weight: bwd_dblock(weight.shape[0]), + _ILP=16, ) def ReLULinearBackward( input, grad, weight, *, - _num_warps: ConfigurableArg=8, - D_block: ConfigurableArg=32, - _ILP: ConfigurableArg=16, + _num_warps: ConfigurableArg, + D_block: ConfigurableArg, + _ILP: ConfigurableArg, ): in_channels = weight.shape[0] out_channels = grad.shape[1] diff --git a/kerops/ops/quantization.py b/kerops/ops/quantization.py index e90e8d8..6300edd 100644 --- a/kerops/ops/quantization.py +++ b/kerops/ops/quantization.py @@ -3,16 +3,10 @@ import torch from ..kernels.quantization import _DequantUint8Window_impl, _QuantUint8Window_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ..settings import ConfigurableArg, configure, get_l1_cache -__all__ = ['QuantUint8Window', 'DequantUint8Window'] - - -@configure( - _num_warps=lambda: 4, - _l1_cache_bytes=lambda: get_l1_cache() -) +@configure(_num_warps=4, _l1_cache_bytes=get_l1_cache) def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() MAX_SIZE = _l1_cache_bytes // (2 * x.element_size()) @@ -26,10 +20,7 @@ def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: return output -@configure( - _num_warps=lambda: 4, - _l1_cache_bytes=lambda: get_l1_cache() -) +@configure(_num_warps=4, 
_l1_cache_bytes=get_l1_cache) def DequantUint8Window(x, init_dtype, window, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() output = torch.empty_like(x, dtype=init_dtype) diff --git a/kerops/ops/stats.py b/kerops/ops/stats.py index 1157325..dc643ff 100644 --- a/kerops/ops/stats.py +++ b/kerops/ops/stats.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.stats import _Stats_cl3d_backward_impl, _Stats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ..settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() @@ -34,10 +31,7 @@ def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): return mean, sqmean -@configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) def StatsBackward(x, mean_grad, sqmean_grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() diff --git a/kerops/settings/__init__.py b/kerops/settings/__init__.py new file mode 100644 index 0000000..31decd2 --- /dev/null +++ b/kerops/settings/__init__.py @@ -0,0 +1,3 @@ +from .hardware_conf import get_l1_cache +from .utils import ConfigurableArg +from .wrapper import configure diff --git a/kerops/settings/hardware_conf.py b/kerops/settings/hardware_conf.py new file mode 100644 index 0000000..e9eba7e --- /dev/null +++ b/kerops/settings/hardware_conf.py @@ -0,0 +1,11 @@ +L1_CACHE_BYTES = 65536 + + +def get_l1_cache(): + global L1_CACHE_BYTES + return L1_CACHE_BYTES + + +def set_l1_cache(new_cache): + global L1_CACHE_BYTES + L1_CACHE_BYTES = new_cache diff --git a/kerops/settings/utils.py 
b/kerops/settings/utils.py new file mode 100644 index 0000000..6b623a9 --- /dev/null +++ b/kerops/settings/utils.py @@ -0,0 +1,30 @@ +from inspect import Parameter + + +class ConfigurableArg: + pass + + +def validate_signature(signature): + for param in signature.parameters.values(): + if param.annotation is ConfigurableArg and param.kind is not Parameter.KEYWORD_ONLY: + raise RuntimeError(f'ConfigurableArg must be keyword-only - {param.name}') + elif param.annotation is not ConfigurableArg and param.kind is Parameter.KEYWORD_ONLY: + raise RuntimeError(f'non-ConfigurableArg must not be keyword-only - {param.name}') + + +def get_config_args(signature): + return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] + + +def get_standard_args(signature): + return [ + param.name + for param in signature.parameters.values() + if param.kind is Parameter.POSITIONAL_ONLY or param.kind is Parameter.POSITIONAL_OR_KEYWORD + ] + + +def configs_match(configurable_args, configurators_names): + if set(configurable_args) != set(configurators_names): + raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') diff --git a/kerops/settings/wrapper.py b/kerops/settings/wrapper.py new file mode 100644 index 0000000..36a4bbe --- /dev/null +++ b/kerops/settings/wrapper.py @@ -0,0 +1,68 @@ +import inspect +from functools import wraps +from typing import Callable + +from .utils import configs_match, get_config_args, get_standard_args, validate_signature + + +class EmptyKwarg: + pass + + +class ConfiguredFunction: + def __init__(self, origin_function, signature, configurable_args, usual_args, **configurators): + self.origin_function = origin_function + self.signature = signature + self.configurable_args = configurable_args + self.usual_args = usual_args + self.configurators = configurators + + @staticmethod + def configurator_call(args, configurator, usual_args): + if isinstance(configurator, Callable): + conf_sign = 
inspect.signature(configurator) + + # take argnames from configurator, map args with respect to origin function's argnames + conf_args = [args[usual_args.index(param.name)] for param in conf_sign.parameters.values()] + + return configurator(*conf_args) + else: + return configurator + + def __call__(self, *args, **kwargs): + tmp_kwargs = {**{arg: EmptyKwarg for arg in self.configurable_args}, **kwargs} + + bind = self.signature.bind(*args, **tmp_kwargs) + bind.apply_defaults() + + configured_kwargs = { + k: ( + self.configurator_call(bind.args, self.configurators[k], self.usual_args) + if input_v is EmptyKwarg + else input_v + ) + for k, input_v in bind.kwargs.items() + } + + return self.origin_function(*bind.args, **configured_kwargs) + + def reconfigure(self, **new_configurators): + configs_match(self.configurable_args, new_configurators.keys()) + self.configurators = new_configurators + + +def configure(**configurators): + def wrapper(function): + signature = inspect.signature(function) + + validate_signature(signature) + + configurable_args = get_config_args(signature) + + usual_args = get_standard_args(signature) + + configs_match(configurable_args, configurators.keys()) + + return wraps(function)(ConfiguredFunction(function, signature, configurable_args, usual_args, **configurators)) + + return wrapper diff --git a/requirements-dev.txt b/requirements-dev.txt index 4c736bb..793564c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,4 @@ -matplotlib -seaborn - -black<23.0.0 +black flake8 flake8-tidy-imports flake8-quotes diff --git a/requirements.txt b/requirements.txt index 4bb9eb1..b24cb10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -triton==2.3.0 +triton==3.1.0 torch