From 9e0fa0f2468a0cbf74ce66f5e00bd27a434d99b4 Mon Sep 17 00:00:00 2001 From: Anihilatorgunn Date: Thu, 8 Aug 2024 14:26:27 +0300 Subject: [PATCH 01/19] readme --- README.md | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index c3b11f8..71a5c97 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,19 @@ -# kerops +# Kerops Efficient and fast algorithms on the GPU + +# Install +*pip is not available right now* +```shell +pip install kerops +``` + +# How fast is it? +Time comparison (ms) for NVidia RTX 3090. Input is an array of size (1, channels, 350, 350, 128); float16; channels_last_3d. Compared to usual 3d convolution from torch (kernel_size=3, padding=1, stride=1, bias=False, in_channels=channels, out_channels=channels). Slowdown compared to copying is shown in parentheses. + +| channels |torch.clone| kerops.ops.DWConv |torch.nn.Conv3d(C->C)| +|:--------------------:|:---------:|:--------------------:|:-------------------:| +| 8 | 0.61 | 0.81 (x1.32) | 2.45 (x4.00) | +| 16 | 1.21 | 1.27 (1.27) | 4.48 (x3.70) | +| 32 | 2.40 | 3.12 (1.30) | 15.3 (x6.38) | +| 64 | 4.78 | 6.29 (1.32) | 52.0 (x10.89) | +| 128 | 9.55 | 13.2 (1.38) | 195.0 (x20.44) | From d4423b3604b32b425960c0c8e3b9e7bc40298ab9 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 15:26:27 +0300 Subject: [PATCH 02/19] configure support non-callable arguments --- kerops/ops/_settings.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/kerops/ops/_settings.py b/kerops/ops/_settings.py index a2f0d25..15e4089 100644 --- a/kerops/ops/_settings.py +++ b/kerops/ops/_settings.py @@ -1,5 +1,6 @@ import inspect from functools import wraps +from typing import Callable L1_CACHE_BYTES = 65536 @@ -47,15 +48,6 @@ def is_configurators_fit(configurable_args, configurators_names): raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') -def configurator_call(args, configurator, usual_args): - conf_sign = inspect.signature(configurator) - - # take argnames from configurator, map args with respect to origin function's argnames - conf_args = [args[usual_args.index(param.name)] for param in conf_sign.parameters.values()] - - return configurator(*conf_args) - - class ConfiguredFunction: def __init__(self, origin_function, signature, configurable_args, usual_args, **configurators): self.origin_function = origin_function @@ -65,6 +57,19 @@ def __init__(self, origin_function, signature, configurable_args, usual_args, ** self.configurators = configurators + @staticmethod + def configurator_call(args, configurator, usual_args): + if isinstance(configurator, Callable): + conf_sign = inspect.signature(configurator) + + # take argnames from configurator, map args with respect to origin function's argnames + conf_args = [args[usual_args.index(param.name)] for param in conf_sign.parameters.values()] + + return configurator(*conf_args) + else: + return configurator + + def __call__(self, *args, **kwargs): tmp_kwargs = {**{arg: EmptyKwarg for arg in self.configurable_args}, **kwargs} @@ -72,7 +77,7 @@ def __call__(self, *args, **kwargs): bind.apply_defaults() configured_kwargs = { - k: configurator_call(bind.args, self.configurators[k], self.usual_args) + k: self.configurator_call(bind.args, self.configurators[k], self.usual_args) if input_v is EmptyKwarg else input_v for k, input_v in bind.kwargs.items() } From 5667214111f8d585a2ab3a6c05c5bb7c9f642cd6 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 15:34:21 +0300 Subject: [PATCH 03/19] a bit cleaner --- kerops/ops/addition.py | 8 ++++---- kerops/ops/avgpool.py | 8 ++++---- kerops/ops/bnrelu.py | 8 ++++---- kerops/ops/conv.py | 4 ++-- kerops/ops/linear.py | 1 + kerops/ops/quantization.py | 11 ++++------- kerops/ops/stats.py | 8 ++++---- 7 files changed, 23 insertions(+), 25 deletions(-) diff --git a/kerops/ops/addition.py b/kerops/ops/addition.py index 426e52d..bf59f9d 100644 --- a/kerops/ops/addition.py +++ b/kerops/ops/addition.py @@ -9,8 +9,8 @@ @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 + _l1_cache_bytes=get_l1_cache, + _num_warps=8 ) def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] @@ -53,8 +53,8 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 + _l1_cache_bytes=get_l1_cache, + _num_warps=8 ) def AddStatsBackward(add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = add_grad.shape[1] diff --git a/kerops/ops/avgpool.py b/kerops/ops/avgpool.py index 39453ec..e38e234 100644 --- a/kerops/ops/avgpool.py +++ b/kerops/ops/avgpool.py @@ -9,8 +9,8 @@ @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 2, + _l1_cache_bytes=get_l1_cache, + _num_warps=2, ) def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] @@ -61,8 +61,8 @@ def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: Configu @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 + _l1_cache_bytes=get_l1_cache, + _num_warps=4 ) def AvgPoolCeilStatsBackward(inpgrad, meangrad, sqmeangrad, output, outgrad_shape, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): MAX_SIZE = _l1_cache_bytes // inpgrad.element_size() # 32768 for fp16 diff --git a/kerops/ops/bnrelu.py b/kerops/ops/bnrelu.py index d0076ef..f1e4ce7 100644 --- a/kerops/ops/bnrelu.py +++ b/kerops/ops/bnrelu.py @@ -9,8 +9,8 @@ @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 + _l1_cache_bytes=get_l1_cache, + _num_warps=8 ) def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] @@ -45,8 +45,8 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 8 + _l1_cache_bytes=get_l1_cache, + _num_warps=8 ) def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index 1f1fae8..d537fbf 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -33,7 +33,7 @@ def configure_dwconv(channels): @configure( - ACCTYPE=lambda: 'float32', + ACCTYPE='float32', _num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0], D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1], ) @@ -79,8 +79,8 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Confi @configure( + ACCTYPE='float32', _num_warps=lambda x: configure_dwconv(x.shape[1])[1][0], - ACCTYPE=lambda: 'float32', D_block=lambda x: configure_dwconv(x.shape[1])[1][1], ) def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg=2, D_block: ConfigurableArg = 32): diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 960c59f..9ebccda 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -18,6 +18,7 @@ def configure_linear(in_channels): return HARDCODED_CONFIG.get(in_channels, None) + @configure( _num_warps=lambda weight: configure_linear(weight.shape[0])[0][0], D_block=lambda weight: configure_linear(weight.shape[0])[0][1], diff --git a/kerops/ops/quantization.py b/kerops/ops/quantization.py index e90e8d8..49e8daf 100644 --- a/kerops/ops/quantization.py +++ b/kerops/ops/quantization.py @@ -6,12 +6,9 @@ from ._settings import configure, get_l1_cache, ConfigurableArg -__all__ = ['QuantUint8Window', 'DequantUint8Window'] - - @configure( - _num_warps=lambda: 4, - _l1_cache_bytes=lambda: get_l1_cache() + _num_warps=4, + _l1_cache_bytes=get_l1_cache ) def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() @@ -27,8 +24,8 @@ def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: @configure( - _num_warps=lambda: 4, - _l1_cache_bytes=lambda: get_l1_cache() + _num_warps=4, + _l1_cache_bytes=get_l1_cache ) def DequantUint8Window(x, init_dtype, window, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() diff --git a/kerops/ops/stats.py b/kerops/ops/stats.py index 1157325..2427828 100644 --- a/kerops/ops/stats.py +++ b/kerops/ops/stats.py @@ -9,8 +9,8 @@ @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 + _l1_cache_bytes=get_l1_cache, + _num_warps=4 ) def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] @@ -35,8 +35,8 @@ def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): @configure( - _l1_cache_bytes=lambda: get_l1_cache(), - _num_warps=lambda: 4 + _l1_cache_bytes=get_l1_cache, + _num_warps=4 ) def StatsBackward(x, mean_grad, sqmean_grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] From 35eddc9eea7cc664a6132b838882c8edefabab47 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 15:34:37 +0300 Subject: [PATCH 04/19] black + isort --- kerops/ops/_settings.py | 16 ++++++++-------- kerops/ops/addition.py | 16 ++++++---------- kerops/ops/avgpool.py | 18 ++++++++++++------ kerops/ops/bnrelu.py | 12 +++--------- kerops/ops/conv.py | 10 +++++++--- kerops/ops/linear.py | 14 +++++++------- kerops/ops/quantization.py | 12 +++--------- kerops/ops/stats.py | 12 +++--------- 8 files changed, 49 insertions(+), 61 deletions(-) diff --git a/kerops/ops/_settings.py b/kerops/ops/_settings.py index 15e4089..aede5ba 100644 --- a/kerops/ops/_settings.py +++ b/kerops/ops/_settings.py @@ -36,9 +36,10 @@ def get_configurable_args_from_signature(signature): return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] -def get_usual_args_from_signature(signature): +def get_usual_args_from_signature(signature): return [ - param.name for param in signature.parameters.values() + param.name + for param in signature.parameters.values() if param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD ] @@ -56,7 +57,6 @@ def __init__(self, origin_function, signature, configurable_args, usual_args, ** self.usual_args = usual_args self.configurators = configurators - @staticmethod def configurator_call(args, configurator, usual_args): if isinstance(configurator, Callable): @@ -69,22 +69,21 @@ def configurator_call(args, configurator, usual_args): else: return configurator - def __call__(self, *args, **kwargs): tmp_kwargs = {**{arg: EmptyKwarg for arg in self.configurable_args}, **kwargs} - + bind = self.signature.bind(*args, **tmp_kwargs) bind.apply_defaults() configured_kwargs = { k: self.configurator_call(bind.args, self.configurators[k], self.usual_args) - if input_v is EmptyKwarg else input_v + if input_v is EmptyKwarg + else input_v for k, input_v in bind.kwargs.items() } return self.origin_function(*bind.args, **configured_kwargs) - def reconfigure(self, **new_configurators): is_configurators_fit(self.configurable_args, new_configurators.keys()) self.configurators = new_configurators @@ -93,7 +92,7 @@ def reconfigure(self, **new_configurators): def configure(**configurators): def wrapper(function): signature = inspect.signature(function) - + check_function_signature(signature) configurable_args = get_configurable_args_from_signature(signature) @@ -103,4 +102,5 @@ def wrapper(function): is_configurators_fit(configurable_args, configurators.keys()) return wraps(function)(ConfiguredFunction(function, signature, configurable_args, usual_args, **configurators)) + return wrapper diff --git a/kerops/ops/addition.py b/kerops/ops/addition.py index bf59f9d..421a596 100644 --- a/kerops/ops/addition.py +++ b/kerops/ops/addition.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.addition import _AddStats_cl3d_backward_impl, _AddStats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ._settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() @@ -52,11 +49,10 @@ def AddStats(x, y, inplace=False, *, _l1_cache_bytes: ConfigurableArg, _num_warp return output, mean, sqmean -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=8 -) -def AddStatsBackward(add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) +def AddStatsBackward( + add_grad, mean_grad, sqmean_grad, add_result, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg +): num_channels = add_grad.shape[1] numel = add_grad.numel() assert add_result.shape == add_grad.shape diff --git a/kerops/ops/avgpool.py b/kerops/ops/avgpool.py index e38e234..a0bbbbb 100644 --- a/kerops/ops/avgpool.py +++ b/kerops/ops/avgpool.py @@ -5,7 +5,7 @@ from triton import next_power_of_2 from ..kernels.avgpool import _AvgPoolCeilStats_cl3d_backward_impl, _AvgPoolCeilStats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ._settings import ConfigurableArg, configure, get_l1_cache @configure( @@ -60,11 +60,17 @@ def AvgPoolCeilStats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: Configu return output, mean, sqmean -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=4 -) -def AvgPoolCeilStatsBackward(inpgrad, meangrad, sqmeangrad, output, outgrad_shape, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) +def AvgPoolCeilStatsBackward( + inpgrad, + meangrad, + sqmeangrad, + output, + outgrad_shape, + *, + _l1_cache_bytes: ConfigurableArg, + _num_warps: ConfigurableArg, +): MAX_SIZE = _l1_cache_bytes // inpgrad.element_size() # 32768 for fp16 bsize, num_channels, h_outgrad, w_outgrad, d_outgrad = outgrad_shape d_inpgrad = inpgrad.shape[-1] diff --git a/kerops/ops/bnrelu.py b/kerops/ops/bnrelu.py index f1e4ce7..1e5e87e 100644 --- a/kerops/ops/bnrelu.py +++ b/kerops/ops/bnrelu.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.bnrelu import _ApplyBNReLU_cl3d_backward_impl, _ApplyBNReLU_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ._settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() @@ -44,10 +41,7 @@ def ApplyBNReLU(x, weight, bias, *, _l1_cache_bytes: ConfigurableArg, _num_warps return output -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=8 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) def ApplyBNReLUBackward(x, weight, bias, grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index d537fbf..b1f5d02 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -4,7 +4,7 @@ from triton import language as tl, next_power_of_2 from ..kernels.dw_conv import _DWConv_cl3d_impl, _DWConv_wgrad_cl3d_impl -from ._settings import configure, ConfigurableArg +from ._settings import ConfigurableArg, configure def configure_dwconv(channels): @@ -37,7 +37,9 @@ def configure_dwconv(channels): _num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0], D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1], ) -def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32): +def DWConv( + x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32 +): channels = x.shape[1] assert x.ndim == 5 @@ -83,7 +85,9 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: Confi _num_warps=lambda x: configure_dwconv(x.shape[1])[1][0], D_block=lambda x: configure_dwconv(x.shape[1])[1][1], ) -def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg=2, D_block: ConfigurableArg = 32): +def DWConvWGRAD( + x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32 +): channels = x.shape[1] assert x.ndim == grad.ndim == 5 diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 9ebccda..30d79ad 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -4,7 +4,7 @@ from triton import next_power_of_2 from ..kernels.linear import _ReLULinearAdd, _ReLULinearAddBackward -from ._settings import configure, ConfigurableArg +from ._settings import ConfigurableArg, configure def configure_linear(in_channels): @@ -29,9 +29,9 @@ def ReLULinearAdd( weight, add_other, *, - _num_warps: ConfigurableArg=2, - D_block: ConfigurableArg=16, - _ILP: ConfigurableArg=8, + _num_warps: ConfigurableArg = 2, + D_block: ConfigurableArg = 16, + _ILP: ConfigurableArg = 8, ): in_channels = x.shape[1] out_channels = weight.shape[1] @@ -82,9 +82,9 @@ def ReLULinearBackward( grad, weight, *, - _num_warps: ConfigurableArg=8, - D_block: ConfigurableArg=32, - _ILP: ConfigurableArg=16, + _num_warps: ConfigurableArg = 8, + D_block: ConfigurableArg = 32, + _ILP: ConfigurableArg = 16, ): in_channels = weight.shape[0] out_channels = grad.shape[1] diff --git a/kerops/ops/quantization.py b/kerops/ops/quantization.py index 49e8daf..d9af418 100644 --- a/kerops/ops/quantization.py +++ b/kerops/ops/quantization.py @@ -3,13 +3,10 @@ import torch from ..kernels.quantization import _DequantUint8Window_impl, _QuantUint8Window_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ._settings import ConfigurableArg, configure, get_l1_cache -@configure( - _num_warps=4, - _l1_cache_bytes=get_l1_cache -) +@configure(_num_warps=4, _l1_cache_bytes=get_l1_cache) def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() MAX_SIZE = _l1_cache_bytes // (2 * x.element_size()) @@ -23,10 +20,7 @@ def QuantUint8Window(x, window, *, _num_warps: ConfigurableArg, _l1_cache_bytes: return output -@configure( - _num_warps=4, - _l1_cache_bytes=get_l1_cache -) +@configure(_num_warps=4, _l1_cache_bytes=get_l1_cache) def DequantUint8Window(x, init_dtype, window, _num_warps: ConfigurableArg, _l1_cache_bytes: ConfigurableArg): numel = x.numel() output = torch.empty_like(x, dtype=init_dtype) diff --git a/kerops/ops/stats.py b/kerops/ops/stats.py index 2427828..c3f81a0 100644 --- a/kerops/ops/stats.py +++ b/kerops/ops/stats.py @@ -5,13 +5,10 @@ from triton import next_power_of_2 from ..kernels.stats import _Stats_cl3d_backward_impl, _Stats_cl3d_impl -from ._settings import configure, get_l1_cache, ConfigurableArg +from ._settings import ConfigurableArg, configure, get_l1_cache -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=4 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() @@ -34,10 +31,7 @@ def Stats(x, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): return mean, sqmean -@configure( - _l1_cache_bytes=get_l1_cache, - _num_warps=4 -) +@configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) def StatsBackward(x, mean_grad, sqmean_grad, *, _l1_cache_bytes: ConfigurableArg, _num_warps: ConfigurableArg): num_channels = x.shape[1] numel = x.numel() From 3dfa1327b56f8892e3c89f17050e5760d9cf8945 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 15:41:01 +0300 Subject: [PATCH 05/19] small :nail_care: --- kerops/ops/conv.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index b1f5d02..a502fb5 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -37,9 +37,7 @@ def configure_dwconv(channels): _num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0], D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1], ) -def DWConv( - x, weight, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32 -): +def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): channels = x.shape[1] assert x.ndim == 5 @@ -85,9 +83,7 @@ def DWConv( _num_warps=lambda x: configure_dwconv(x.shape[1])[1][0], D_block=lambda x: configure_dwconv(x.shape[1])[1][1], ) -def DWConvWGRAD( - x, grad, *, ACCTYPE: ConfigurableArg = 'float32', _num_warps: ConfigurableArg = 2, D_block: ConfigurableArg = 32 -): +def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): channels = x.shape[1] assert x.ndim == grad.ndim == 5 From 7ddb06e69f534bac53c1cba6b3045418232d020f Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 16:05:52 +0300 Subject: [PATCH 06/19] :nail_care: --- kerops/ops/conv.py | 37 ++++++++++++++----------------------- kerops/ops/linear.py | 42 +++++++++++++++++++++++++++--------------- 2 files changed, 41 insertions(+), 38 deletions(-) diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index a502fb5..cc5ae0b 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -7,35 +7,26 @@ from ._settings import ConfigurableArg, configure -def configure_dwconv(channels): - """ - Hardcoded, benchmarked on RTX 3090, mb should be generated automatically - H, W, D = [350, 350, 128] +def dwconv_warps(channels): + return {8: 1, 16: 2, 32: 2, 64: 2, 128: 4}[channels] - channels: [[num_warps, D_block], [num_warps, D_block]] one for fwd another for bwd - """ - """ - TODO - More geeky solution is to compare performances with respect to splitting axis D - to N * D_block with padding - """ +def dwconv_dblock(channels): + return {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}[channels] - HARDCODED_CONFIG = { - 8: [[1, 32], [1, 32]], - 16: [[2, 32], [1, 32]], - 32: [[2, 16], [1, 32]], - 64: [[2, 8], [1, 16]], - 128: [[4, 8], [2, 16]], - } - return HARDCODED_CONFIG.get(channels, None) +def dwconv_wgrad_warps(channels): + return {8: 1, 16: 1, 32: 1, 64: 1, 128: 2}[channels] + + +def dwconv_wgrad_dblock(channels): + return {8: 32, 16: 32, 32: 32, 64: 16, 128: 16}[channels] @configure( ACCTYPE='float32', - _num_warps=lambda weight: configure_dwconv(weight.shape[-1])[0][0], - D_block=lambda weight: configure_dwconv(weight.shape[-1])[0][1], + _num_warps=lambda x: dwconv_warps(x.shape[1]), + D_block=lambda x: dwconv_dblock(x.shape[1]), ) def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): channels = x.shape[1] @@ -80,8 +71,8 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, @configure( ACCTYPE='float32', - _num_warps=lambda x: configure_dwconv(x.shape[1])[1][0], - D_block=lambda x: configure_dwconv(x.shape[1])[1][1], + _num_warps=lambda x: dwconv_wgrad_warps(x.shape[1]), + D_block=lambda x: dwconv_wgrad_dblock(x.shape[1]), ) def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): channels = x.shape[1] diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 30d79ad..42d14b6 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -7,22 +7,34 @@ from ._settings import ConfigurableArg, configure -def configure_linear(in_channels): - # num_warps, D_block, ILP - HARDCODED_CONFIG = { - 16: [[2, 16, 8], [4, 16, 16]], - 32: [[2, 16, 8], [8, 32, 16]], - 64: [[1, 16, 4], [8, 32, 16]], - 128: [[1, 16, 4], [8, 32, 16]], - } +def fwd_warps(in_channels): + return {16: 2, 32: 2, 64: 1, 128: 1}[in_channels] - return HARDCODED_CONFIG.get(in_channels, None) + +def fwd_dblock(in_channels): + return {16: 16, 32: 16, 64: 16, 128: 16}[in_channels] + + +def fwd_ilp(in_channels): + return {16: 8, 32: 8, 64: 4, 128: 4}[in_channels] + + +def bwd_warps(in_channels): + return {16: 4, 32: 8, 64: 8, 128: 8}[in_channels] + + +def bwd_dblock(in_channels): + return {16: 16, 32: 32, 64: 32, 128: 32}[in_channels] + + +def bwd_ilp(in_channels): + return {16: 16, 32: 16, 64: 16, 128: 16}[in_channels] @configure( - _num_warps=lambda weight: configure_linear(weight.shape[0])[0][0], - D_block=lambda weight: configure_linear(weight.shape[0])[0][1], - _ILP=lambda weight: configure_linear(weight.shape[0])[0][2], + _num_warps=lambda weight: fwd_warps(weight.shape[0]), + D_block=lambda weight: fwd_dblock(weight.shape[0]), + _ILP=lambda weight: fwd_ilp(weight.shape[0]), ) def ReLULinearAdd( x, @@ -73,9 +85,9 @@ def ReLULinearAdd( @configure( - _num_warps=lambda weight: configure_linear(weight.shape[0])[1][0], - D_block=lambda weight: configure_linear(weight.shape[0])[1][1], - _ILP=lambda weight: configure_linear(weight.shape[0])[1][2], + _num_warps=lambda weight: bwd_warps(weight.shape[0]), + D_block=lambda weight: bwd_dblock(weight.shape[0]), + _ILP=lambda weight: bwd_ilp(weight.shape[0]), ) def ReLULinearBackward( input, From 2c44243d62aee0d73e5110315faabeed25b956d1 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 16:21:28 +0300 Subject: [PATCH 07/19] new black --- kerops/ops/_settings.py | 8 +++++--- requirements-dev.txt | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/kerops/ops/_settings.py b/kerops/ops/_settings.py index aede5ba..43e754a 100644 --- a/kerops/ops/_settings.py +++ b/kerops/ops/_settings.py @@ -76,9 +76,11 @@ def __call__(self, *args, **kwargs): bind.apply_defaults() configured_kwargs = { - k: self.configurator_call(bind.args, self.configurators[k], self.usual_args) - if input_v is EmptyKwarg - else input_v + k: ( + self.configurator_call(bind.args, self.configurators[k], self.usual_args) + if input_v is EmptyKwarg + else input_v + ) for k, input_v in bind.kwargs.items() } diff --git a/requirements-dev.txt b/requirements-dev.txt index 4c736bb..09e76e8 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,7 @@ matplotlib seaborn -black<23.0.0 +black flake8 flake8-tidy-imports flake8-quotes From 0d95e07b354260dc725af7f6cd60d4bb73065fca Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 16:43:57 +0300 Subject: [PATCH 08/19] settings are placed in a separate module --- kerops/ops/addition.py | 2 +- kerops/ops/avgpool.py | 2 +- kerops/ops/bnrelu.py | 2 +- kerops/ops/conv.py | 2 +- kerops/ops/linear.py | 2 +- kerops/ops/quantization.py | 2 +- kerops/ops/stats.py | 2 +- kerops/settings/__init__.py | 3 ++ kerops/settings/hardware_conf.py | 11 +++++ kerops/settings/utils.py | 30 ++++++++++++ .../{ops/_settings.py => settings/wrapper.py} | 47 +++---------------- 11 files changed, 57 insertions(+), 48 deletions(-) create mode 100644 kerops/settings/__init__.py create mode 100644 kerops/settings/hardware_conf.py create mode 100644 kerops/settings/utils.py rename kerops/{ops/_settings.py => settings/wrapper.py} (61%) diff --git a/kerops/ops/addition.py b/kerops/ops/addition.py index 421a596..0e29f76 100644 --- a/kerops/ops/addition.py +++ b/kerops/ops/addition.py @@ -5,7 +5,7 @@ from triton import next_power_of_2 from ..kernels.addition import _AddStats_cl3d_backward_impl, _AddStats_cl3d_impl -from ._settings import ConfigurableArg, configure, get_l1_cache +from ..settings import ConfigurableArg, configure, get_l1_cache @configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) diff --git a/kerops/ops/avgpool.py b/kerops/ops/avgpool.py index a0bbbbb..82e0f18 100644 --- a/kerops/ops/avgpool.py +++ b/kerops/ops/avgpool.py @@ -5,7 +5,7 @@ from triton import next_power_of_2 from ..kernels.avgpool import _AvgPoolCeilStats_cl3d_backward_impl, _AvgPoolCeilStats_cl3d_impl -from ._settings import ConfigurableArg, configure, get_l1_cache +from ..settings import ConfigurableArg, configure, get_l1_cache @configure( diff --git a/kerops/ops/bnrelu.py b/kerops/ops/bnrelu.py index 1e5e87e..f894ee1 100644 --- a/kerops/ops/bnrelu.py +++ b/kerops/ops/bnrelu.py @@ -5,7 +5,7 @@ from triton import next_power_of_2 from ..kernels.bnrelu import _ApplyBNReLU_cl3d_backward_impl, _ApplyBNReLU_cl3d_impl -from ._settings import ConfigurableArg, configure, get_l1_cache +from ..settings import ConfigurableArg, configure, get_l1_cache @configure(_l1_cache_bytes=get_l1_cache, _num_warps=8) diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index cc5ae0b..b251eb1 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -4,7 +4,7 @@ from triton import language as tl, next_power_of_2 from ..kernels.dw_conv import _DWConv_cl3d_impl, _DWConv_wgrad_cl3d_impl -from ._settings import ConfigurableArg, configure +from ..settings import ConfigurableArg, configure def dwconv_warps(channels): diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 42d14b6..961f166 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -4,7 +4,7 @@ from triton import next_power_of_2 from ..kernels.linear import _ReLULinearAdd, _ReLULinearAddBackward -from ._settings import ConfigurableArg, configure +from ..settings import ConfigurableArg, configure def fwd_warps(in_channels): diff --git a/kerops/ops/quantization.py b/kerops/ops/quantization.py index d9af418..6300edd 100644 --- a/kerops/ops/quantization.py +++ b/kerops/ops/quantization.py @@ -3,7 +3,7 @@ import torch from ..kernels.quantization import _DequantUint8Window_impl, _QuantUint8Window_impl -from ._settings import ConfigurableArg, configure, get_l1_cache +from ..settings import ConfigurableArg, configure, get_l1_cache @configure(_num_warps=4, _l1_cache_bytes=get_l1_cache) diff --git a/kerops/ops/stats.py b/kerops/ops/stats.py index c3f81a0..dc643ff 100644 --- a/kerops/ops/stats.py +++ b/kerops/ops/stats.py @@ -5,7 +5,7 @@ from triton import next_power_of_2 from ..kernels.stats import _Stats_cl3d_backward_impl, _Stats_cl3d_impl -from ._settings import ConfigurableArg, configure, get_l1_cache +from ..settings import ConfigurableArg, configure, get_l1_cache @configure(_l1_cache_bytes=get_l1_cache, _num_warps=4) diff --git a/kerops/settings/__init__.py b/kerops/settings/__init__.py new file mode 100644 index 0000000..31decd2 --- /dev/null +++ b/kerops/settings/__init__.py @@ -0,0 +1,3 @@ +from .hardware_conf import get_l1_cache +from .utils import ConfigurableArg +from .wrapper import configure diff --git a/kerops/settings/hardware_conf.py b/kerops/settings/hardware_conf.py new file mode 100644 index 0000000..e9eba7e --- /dev/null +++ b/kerops/settings/hardware_conf.py @@ -0,0 +1,11 @@ +L1_CACHE_BYTES = 65536 + + +def get_l1_cache(): + global L1_CACHE_BYTES + return L1_CACHE_BYTES + + +def set_l1_cache(new_cache): + global L1_CACHE_BYTES + L1_CACHE_BYTES = new_cache diff --git a/kerops/settings/utils.py b/kerops/settings/utils.py new file mode 100644 index 0000000..d28aad5 --- /dev/null +++ b/kerops/settings/utils.py @@ -0,0 +1,30 @@ +from inspect import Parameter + + +class ConfigurableArg: + pass + + +def check_function_signature(signature): + for param in signature.parameters.values(): + if param.annotation is ConfigurableArg and param.kind is not Parameter.KEYWORD_ONLY: + raise RuntimeError(f'ConfigurableArg must be keyword-only - {param.name}') + elif param.annotation is not ConfigurableArg and param.kind is Parameter.KEYWORD_ONLY: + raise RuntimeError(f'non-ConfigurableArg must not be keyword-only - {param.name}') + + +def get_configurable_args_from_signature(signature): + return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] + + +def get_usual_args_from_signature(signature): + return [ + param.name + for param in signature.parameters.values() + if param.kind is Parameter.POSITIONAL_ONLY or param.kind is Parameter.POSITIONAL_OR_KEYWORD + ] + + +def is_configurators_fit(configurable_args, configurators_names): + if set(configurable_args) != set(configurators_names): + raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') diff --git a/kerops/ops/_settings.py b/kerops/settings/wrapper.py similarity index 61% rename from kerops/ops/_settings.py rename to kerops/settings/wrapper.py index 43e754a..6620f48 100644 --- a/kerops/ops/_settings.py +++ b/kerops/settings/wrapper.py @@ -2,53 +2,18 @@ from functools import wraps from typing import Callable - -L1_CACHE_BYTES = 65536 - - -def get_l1_cache(): - global L1_CACHE_BYTES - return L1_CACHE_BYTES - - -def set_l1_cache(new_cache): - global L1_CACHE_BYTES - L1_CACHE_BYTES = new_cache - - -class ConfigurableArg: - pass +from .utils import ( + check_function_signature, + get_configurable_args_from_signature, + get_usual_args_from_signature, + is_configurators_fit, +) class EmptyKwarg: pass -def check_function_signature(signature): - for param in signature.parameters.values(): - if param.annotation is ConfigurableArg and param.kind is not inspect.Parameter.KEYWORD_ONLY: - raise RuntimeError(f'ConfigurableArg must be keyword-only - {param.name}') - elif param.annotation is not ConfigurableArg and param.kind is inspect.Parameter.KEYWORD_ONLY: - raise RuntimeError(f'non-ConfigurableArg must not be keyword-only - {param.name}') - - -def get_configurable_args_from_signature(signature): - return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] - - -def get_usual_args_from_signature(signature): - return [ - param.name - for param in signature.parameters.values() - if param.kind is inspect.Parameter.POSITIONAL_ONLY or param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD - ] - - -def is_configurators_fit(configurable_args, configurators_names): - if set(configurable_args) != set(configurators_names): - raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') - - class ConfiguredFunction: def __init__(self, origin_function, signature, configurable_args, usual_args, **configurators): self.origin_function = origin_function From 3b9bfc5ba4b5e8ddea088ac823ef27ac01f19f71 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sat, 28 Dec 2024 16:48:19 +0300 Subject: [PATCH 09/19] better naming --- kerops/settings/utils.py | 8 ++++---- kerops/settings/wrapper.py | 17 ++++++----------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/kerops/settings/utils.py b/kerops/settings/utils.py index d28aad5..6b623a9 100644 --- a/kerops/settings/utils.py +++ b/kerops/settings/utils.py @@ -5,7 +5,7 @@ class ConfigurableArg: pass -def check_function_signature(signature): +def validate_signature(signature): for param in signature.parameters.values(): if param.annotation is ConfigurableArg and param.kind is not Parameter.KEYWORD_ONLY: raise RuntimeError(f'ConfigurableArg must be keyword-only - {param.name}') @@ -13,11 +13,11 @@ def check_function_signature(signature): raise RuntimeError(f'non-ConfigurableArg must not be keyword-only - {param.name}') -def get_configurable_args_from_signature(signature): +def get_config_args(signature): return [param.name for param in signature.parameters.values() if param.annotation is ConfigurableArg] -def get_usual_args_from_signature(signature): +def get_standard_args(signature): return [ param.name for param in signature.parameters.values() @@ -25,6 +25,6 @@ def get_usual_args_from_signature(signature): ] -def is_configurators_fit(configurable_args, configurators_names): +def configs_match(configurable_args, configurators_names): if set(configurable_args) != set(configurators_names): raise RuntimeError(f'Configuration mismatch, {configurable_args=}, {configurators_names=}') diff --git a/kerops/settings/wrapper.py b/kerops/settings/wrapper.py index 6620f48..36a4bbe 100644 --- a/kerops/settings/wrapper.py +++ b/kerops/settings/wrapper.py @@ -2,12 +2,7 @@ from functools import wraps from typing import Callable -from .utils import ( - check_function_signature, - get_configurable_args_from_signature, - get_usual_args_from_signature, - is_configurators_fit, -) +from .utils import configs_match, get_config_args, get_standard_args, validate_signature class EmptyKwarg: @@ -52,7 +47,7 @@ def __call__(self, *args, **kwargs): return self.origin_function(*bind.args, **configured_kwargs) def reconfigure(self, **new_configurators): - is_configurators_fit(self.configurable_args, new_configurators.keys()) + configs_match(self.configurable_args, new_configurators.keys()) self.configurators = new_configurators @@ -60,13 +55,13 @@ def configure(**configurators): def wrapper(function): signature = inspect.signature(function) - check_function_signature(signature) + validate_signature(signature) - configurable_args = get_configurable_args_from_signature(signature) + configurable_args = get_config_args(signature) - usual_args = get_usual_args_from_signature(signature) + usual_args = get_standard_args(signature) - is_configurators_fit(configurable_args, configurators.keys()) + configs_match(configurable_args, configurators.keys()) return wraps(function)(ConfiguredFunction(function, signature, configurable_args, usual_args, **configurators)) From bc65761a93bac7ff8c322527ecbf82149af11ff0 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sun, 5 Jan 2025 13:53:45 +0300 Subject: [PATCH 10/19] triton==3.1.0, fast DWConvWGRAD --- kerops/kernels/dw_conv.py | 183 +++++++++++++++++++------------------- kerops/ops/conv.py | 18 ++-- requirements.txt | 2 +- tests/test_ops.py | 6 +- 4 files changed, 107 insertions(+), 102 deletions(-) diff --git a/kerops/kernels/dw_conv.py b/kerops/kernels/dw_conv.py index d2fa328..33f4f52 100644 --- a/kerops/kernels/dw_conv.py +++ b/kerops/kernels/dw_conv.py @@ -157,17 +157,17 @@ def _DWConv_wgrad_cl3d_impl( channels: tl.constexpr, D_block: tl.constexpr, WD_grid, + D_grid, + delta_H_grid, + ILP: tl.constexpr, ): H_cell = tl.program_id(0) - W_D_cell = tl.program_id(1) - - D_gridsize = tl.cdiv(D, D_block) - W_cell = W_D_cell // D_gridsize - D_cell = W_D_cell % D_gridsize + W_cell = tl.program_id(1) + D_cell = tl.program_id(2) input_ptr += D_cell * D_block * channels grad_ptr += D_cell * D_block * channels - weight_grad_ptr += (H_cell * WD_grid + W_D_cell) * 27 * channels + weight_grad_ptr += (H_cell * WD_grid + W_cell * D_grid + D_cell) * 27 * channels channels_offset = tl.arange(0, channels) channels_offset = tl.max_contiguous(tl.multiple_of(channels_offset, channels), channels) @@ -179,11 +179,8 @@ def _DWConv_wgrad_cl3d_impl( mask = mask and (d_offset[None, None, :] + near_offset[:, None, None] >= 0 - D_block * D_cell) mask = mask and (near_offset[:, None, None] != 2) - in_offset = d_offset[None, None, :] * channels + channels_offset[None, :, None] - in_mask = d_offset[None, None, :] < D - D_block * D_cell - - H1_load = 2 * H_cell + 1 < H - W1_load = 2 * W_cell + 1 < W + in_offset = d_offset[None, :] * channels + channels_offset[:, None] + in_mask = d_offset[None, :] < D - D_block * D_cell h0_w0 = tl.zeros([4, channels], dtype=ACCTYPE) h0_w1 = tl.zeros([4, channels], dtype=ACCTYPE) @@ -195,91 +192,91 @@ def _DWConv_wgrad_cl3d_impl( h2_w1 = tl.zeros([4, channels], dtype=ACCTYPE) h2_w2 = tl.zeros([4, channels], dtype=ACCTYPE) - tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + 2 * W_cell * W_stride - x_h0_w0 = tl.load(tmp_input_ptr + in_offset, mask=in_mask, other=0.0) - - tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + 2 * W_cell * W_stride - x_h1_w0 = tl.load(tmp_input_ptr + in_offset, mask=in_mask and H1_load, other=0.0) - - tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + (2 * W_cell + 1) * W_stride - x_h0_w1 = tl.load(tmp_input_ptr + in_offset, mask=in_mask and W1_load, other=0.0) - - tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride - x_h1_w1 = tl.load(tmp_input_ptr + in_offset, mask=in_mask and (W1_load and H1_load), other=0.0) - gradw_offset = tl.arange(0, 4)[:, None] * channels + channels_offset[None, :] gradw_mask = near_offset[:, None] != 2 - load_next = (2 * H_cell - 1 < H and 2 * H_cell - 1 >= 0) and (2 * W_cell - 1 < W and 2 * W_cell - 1 >= 0) - tmp_grad_ptr = grad_ptr + (2 * H_cell - 1) * H_stride + (2 * W_cell - 1) * W_stride - i = -1 - j = -1 - grad = tl.zeros([4, channels, D_block], dtype=tl.float16) - if load_next: - grad = tl.load(tmp_grad_ptr + offset, mask=mask) - - for k in tl.static_range(0, 16): - if load_next: - if i == -1 and j == -1: - h2_w2 += tl.sum(grad * x_h0_w0, axis=2) - elif i == -1 and j == 0: - h2_w1 += tl.sum(grad * x_h0_w0, axis=2) - h2_w2 += tl.sum(grad * x_h0_w1, axis=2) - elif i == -1 and j == 1: - h2_w0 += tl.sum(grad * x_h0_w0, axis=2) - h2_w1 += tl.sum(grad * x_h0_w1, axis=2) - elif i == -1 and j == 2: - h2_w0 += tl.sum(grad * x_h0_w1, axis=2) - elif i == 0 and j == -1: - h1_w2 += tl.sum(grad * x_h0_w0, axis=2) - h2_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 0 and j == 0: - h1_w1 += tl.sum(grad * x_h0_w0, axis=2) - h2_w1 += tl.sum(grad * x_h1_w0, axis=2) - h1_w2 += tl.sum(grad * x_h0_w1, axis=2) - h2_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 0 and j == 1: - h1_w0 += tl.sum(grad * x_h0_w0, axis=2) - h2_w0 += tl.sum(grad * x_h1_w0, axis=2) - h1_w1 += tl.sum(grad * x_h0_w1, axis=2) - h2_w1 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 0 and j == 2: - h1_w0 += tl.sum(grad * x_h0_w1, axis=2) - h2_w0 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == -1: - h0_w2 += tl.sum(grad * x_h0_w0, axis=2) - h1_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 1 and j == 0: - h0_w1 += tl.sum(grad * x_h0_w0, axis=2) - h1_w1 += tl.sum(grad * x_h1_w0, axis=2) - h0_w2 += tl.sum(grad * x_h0_w1, axis=2) - h1_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == 1: - h0_w0 += tl.sum(grad * x_h0_w0, axis=2) - h1_w0 += tl.sum(grad * x_h1_w0, axis=2) - h0_w1 += tl.sum(grad * x_h0_w1, axis=2) - h1_w1 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 1 and j == 2: - h0_w0 += tl.sum(grad * x_h0_w1, axis=2) - h1_w0 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 2 and j == -1: - h0_w2 += tl.sum(grad * x_h1_w0, axis=2) - elif i == 2 and j == 0: - h0_w1 += tl.sum(grad * x_h1_w0, axis=2) - h0_w2 += tl.sum(grad * x_h1_w1, axis=2) - elif i == 2 and j == 1: - h0_w0 += tl.sum(grad * x_h1_w0, axis=2) - h0_w1 += tl.sum(grad * x_h1_w1, axis=2) - else: - h0_w0 += tl.sum(grad * x_h1_w1, axis=2) - - k_ = k + 1 - i = (k_ % 4) - 1 - j = (k_ // 4) - 1 - load_next = (2 * H_cell + i < H and 2 * H_cell + i >= 0) and (2 * W_cell + j < W and 2 * W_cell + j >= 0) - tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride - if load_next and k_ < 16: - grad = tl.load(tmp_grad_ptr + offset, mask=mask) # , other=0.) + for ilp in tl.static_range(0, ILP): + H0_load = 2 * H_cell < H + H1_load = 2 * H_cell + 1 < H + W1_load = 2 * W_cell + 1 < W + + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + 2 * W_cell * W_stride + x_h0_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H0_load) + + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + 2 * W_cell * W_stride + x_h1_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H1_load) + + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + (2 * W_cell + 1) * W_stride + x_h0_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H0_load)) + + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride + x_h1_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H1_load)) + + #grad = tl.zeros([channels, D_block], dtype=tl.float16)[None] + + for k in tl.static_range(0, 16): + i = (k % 4) - 1 + j = (k // 4) - 1 + load_next = (2 * H_cell + i < H and 2 * H_cell + i >= 0) and (2 * W_cell + j < W and 2 * W_cell + j >= 0) + tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride + + if load_next: + grad = tl.load(tmp_grad_ptr + in_offset, mask=in_mask, other=0.)[None] + + if i == -1 and j == -1: + h2_w2 += tl.sum(grad * x_h0_w0, axis=2) + elif i == -1 and j == 0: + h2_w1 += tl.sum(grad * x_h0_w0, axis=2) + h2_w2 += tl.sum(grad * x_h0_w1, axis=2) + elif i == -1 and j == 1: + h2_w0 += tl.sum(grad * x_h0_w0, axis=2) + h2_w1 += tl.sum(grad * x_h0_w1, axis=2) + elif i == -1 and j == 2: + h2_w0 += tl.sum(grad * x_h0_w1, axis=2) + elif i == 0 and j == -1: + h1_w2 += tl.sum(grad * x_h0_w0, axis=2) + h2_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 0 and j == 0: + h1_w1 += tl.sum(grad * x_h0_w0, axis=2) + h2_w1 += tl.sum(grad * x_h1_w0, axis=2) + h1_w2 += tl.sum(grad * x_h0_w1, axis=2) + h2_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 0 and j == 1: + h1_w0 += tl.sum(grad * x_h0_w0, axis=2) + h2_w0 += tl.sum(grad * x_h1_w0, axis=2) + h1_w1 += tl.sum(grad * x_h0_w1, axis=2) + h2_w1 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 0 and j == 2: + h1_w0 += tl.sum(grad * x_h0_w1, axis=2) + h2_w0 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == -1: + h0_w2 += tl.sum(grad * x_h0_w0, axis=2) + h1_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 1 and j == 0: + h0_w1 += tl.sum(grad * x_h0_w0, axis=2) + h1_w1 += tl.sum(grad * x_h1_w0, axis=2) + h0_w2 += tl.sum(grad * x_h0_w1, axis=2) + h1_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == 1: + h0_w0 += tl.sum(grad * x_h0_w0, axis=2) + h1_w0 += tl.sum(grad * x_h1_w0, axis=2) + h0_w1 += tl.sum(grad * x_h0_w1, axis=2) + h1_w1 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 1 and j == 2: + h0_w0 += tl.sum(grad * x_h0_w1, axis=2) + h1_w0 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 2 and j == -1: + h0_w2 += tl.sum(grad * x_h1_w0, axis=2) + elif i == 2 and j == 0: + h0_w1 += tl.sum(grad * x_h1_w0, axis=2) + h0_w2 += tl.sum(grad * x_h1_w1, axis=2) + elif i == 2 and j == 1: + h0_w0 += tl.sum(grad * x_h1_w0, axis=2) + h0_w1 += tl.sum(grad * x_h1_w1, axis=2) + else: + h0_w0 += tl.sum(grad * x_h1_w1, axis=2) + + H_cell += delta_H_grid tl.store(weight_grad_ptr + gradw_offset, h0_w0, mask=gradw_mask) tl.store((weight_grad_ptr + 3 * channels) + gradw_offset, h0_w1, mask=gradw_mask) diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index b251eb1..c928d60 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -20,7 +20,11 @@ def dwconv_wgrad_warps(channels): def dwconv_wgrad_dblock(channels): - return {8: 32, 16: 32, 32: 32, 64: 16, 128: 16}[channels] + return {8: 32, 16: 32, 32: 16, 64: 8, 128: 8}[channels] + + +def dwconv_wgrad_ilp(channels): + return {8: 1, 16: 1, 32: 2, 64: 3, 128: 3}[channels] @configure( @@ -73,8 +77,9 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, ACCTYPE='float32', _num_warps=lambda x: dwconv_wgrad_warps(x.shape[1]), D_block=lambda x: dwconv_wgrad_dblock(x.shape[1]), + ILP=lambda x: dwconv_wgrad_ilp(x.shape[1]) ) -def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg): +def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg, ILP: ConfigurableArg): channels = x.shape[1] assert x.ndim == grad.ndim == 5 @@ -90,10 +95,10 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableAr bsize, _, H, W, D = x.shape batch_stride, _, H_stride, W_stride, _ = x.stride() - H_grid = ceil(H / 2) + H_grid = ceil(H / (2 * ILP)) W_grid = ceil(W / 2) D_grid = ceil(D / D_block) - grid = (H_grid, W_grid * D_grid) + grid = (H_grid, W_grid, D_grid) grad_w = torch.zeros([bsize, H_grid * W_grid * D_grid, 3, 3, 3, channels], device=x.device, dtype=torch.float16) WD_grid = W_grid * D_grid # TODO: mb implement in another way @@ -112,9 +117,12 @@ def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableAr channels, D_block, WD_grid, + D_grid, + H_grid, + ILP, num_warps=_num_warps, ) - grad_w = torch.flip(grad_w.sum(dim=(0, 1)), dims=(2,)) + grad_w = grad_w.sum(dim=(0, 1)) return grad_w diff --git a/requirements.txt b/requirements.txt index 4bb9eb1..b24cb10 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -triton==2.3.0 +triton==3.1.0 torch diff --git a/tests/test_ops.py b/tests/test_ops.py index d452ef8..f777630 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -27,17 +27,17 @@ def bsize(request): return request.param -@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_1(request): return request.param -@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_2(request): return request.param -@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_3(request): return request.param From 21ce2e27c82b7ec68568c5de8693d8cb483ed11d Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sun, 5 Jan 2025 15:20:23 +0300 Subject: [PATCH 11/19] DWConv for triton==3.1.0 --- kerops/kernels/dw_conv.py | 84 +++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/kerops/kernels/dw_conv.py b/kerops/kernels/dw_conv.py index 33f4f52..247e261 100644 --- a/kerops/kernels/dw_conv.py +++ b/kerops/kernels/dw_conv.py @@ -28,13 +28,13 @@ def _DWConv_cl3d_impl( d_offset = tl.arange(0, D_block) near_offset = tl.arange(0, 4) - 1 - offset = d_offset[:, None, None] * channels + channels_offset[None, :, None] + near_offset[None, None, :] * channels - mask = d_offset[:, None, None] + near_offset[None, None, :] < D - D_block * D_cell - mask = mask and (d_offset[:, None, None] + near_offset[None, None, :] >= 0 - D_block * D_cell) - mask = mask and (near_offset[None, None, :] != 2) + offset = d_offset[:, None, None] * channels + channels_offset[None, None, :] + near_offset[None, :, None] * channels + mask = d_offset[:, None, None] + near_offset[None, :, None] < D - D_block * D_cell + mask = mask and (d_offset[:, None, None] + near_offset[None, :, None] >= 0 - D_block * D_cell) + mask = mask and (near_offset[None, :, None] != 2) - weight_offset = channels_offset[None, :, None] + tl.arange(0, 4)[None, None, :] * channels - weight_mask = tl.arange(0, 4)[None, None, :] != 3 + weight_offset = channels_offset[None, None, :] + tl.arange(0, 4)[None, :, None] * channels + weight_mask = tl.arange(0, 4)[None, :, None] != 3 weight_h0_w0 = tl.load(weight_ptr + weight_offset, mask=weight_mask, other=0.0) weight_h0_w1 = tl.load((weight_ptr + 3 * channels) + weight_offset, mask=weight_mask, other=0.0) @@ -68,57 +68,57 @@ def _DWConv_cl3d_impl( for k in tl.static_range(0, 16): if k == 0: - h0_w0 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h0_w0, axis=1) elif k == 1: - h0_w0 += tl.sum(x * weight_h1_w0, axis=2) - h1_w0 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h1_w0, axis=1) + h1_w0 += tl.sum(x * weight_h0_w0, axis=1) elif k == 2: - h0_w0 += tl.sum(x * weight_h2_w0, axis=2) - h1_w0 += tl.sum(x * weight_h1_w0, axis=2) + h0_w0 += tl.sum(x * weight_h2_w0, axis=1) + h1_w0 += tl.sum(x * weight_h1_w0, axis=1) elif k == 3: - h1_w0 += tl.sum(x * weight_h2_w0, axis=2) + h1_w0 += tl.sum(x * weight_h2_w0, axis=1) elif k == 4: - h0_w0 += tl.sum(x * weight_h0_w1, axis=2) - h0_w1 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h0_w1, axis=1) + h0_w1 += tl.sum(x * weight_h0_w0, axis=1) elif k == 5: - h0_w0 += tl.sum(x * weight_h1_w1, axis=2) - h0_w1 += tl.sum(x * weight_h1_w0, axis=2) - h1_w0 += tl.sum(x * weight_h0_w1, axis=2) - h1_w1 += tl.sum(x * weight_h0_w0, axis=2) + h0_w0 += tl.sum(x * weight_h1_w1, axis=1) + h0_w1 += tl.sum(x * weight_h1_w0, axis=1) + h1_w0 += tl.sum(x * weight_h0_w1, axis=1) + h1_w1 += tl.sum(x * weight_h0_w0, axis=1) elif k == 6: - h0_w0 += tl.sum(x * weight_h2_w1, axis=2) - h0_w1 += tl.sum(x * weight_h2_w0, axis=2) - h1_w0 += tl.sum(x * weight_h1_w1, axis=2) - h1_w1 += tl.sum(x * weight_h1_w0, axis=2) + h0_w0 += tl.sum(x * weight_h2_w1, axis=1) + h0_w1 += tl.sum(x * weight_h2_w0, axis=1) + h1_w0 += tl.sum(x * weight_h1_w1, axis=1) + h1_w1 += tl.sum(x * weight_h1_w0, axis=1) elif k == 7: - h1_w0 += tl.sum(x * weight_h2_w1, axis=2) - h1_w1 += tl.sum(x * weight_h2_w0, axis=2) + h1_w0 += tl.sum(x * weight_h2_w1, axis=1) + h1_w1 += tl.sum(x * weight_h2_w0, axis=1) elif k == 8: - h0_w0 += tl.sum(x * weight_h0_w2, axis=2) - h0_w1 += tl.sum(x * weight_h0_w1, axis=2) + h0_w0 += tl.sum(x * weight_h0_w2, axis=1) + h0_w1 += tl.sum(x * weight_h0_w1, axis=1) elif k == 9: - h0_w0 += tl.sum(x * weight_h1_w2, axis=2) - h0_w1 += tl.sum(x * weight_h1_w1, axis=2) - h1_w0 += tl.sum(x * weight_h0_w2, axis=2) - h1_w1 += tl.sum(x * weight_h0_w1, axis=2) + h0_w0 += tl.sum(x * weight_h1_w2, axis=1) + h0_w1 += tl.sum(x * weight_h1_w1, axis=1) + h1_w0 += tl.sum(x * weight_h0_w2, axis=1) + h1_w1 += tl.sum(x * weight_h0_w1, axis=1) elif k == 10: - h0_w0 += tl.sum(x * weight_h2_w2, axis=2) - h0_w1 += tl.sum(x * weight_h2_w1, axis=2) - h1_w0 += tl.sum(x * weight_h1_w2, axis=2) - h1_w1 += tl.sum(x * weight_h1_w1, axis=2) + h0_w0 += tl.sum(x * weight_h2_w2, axis=1) + h0_w1 += tl.sum(x * weight_h2_w1, axis=1) + h1_w0 += tl.sum(x * weight_h1_w2, axis=1) + h1_w1 += tl.sum(x * weight_h1_w1, axis=1) elif k == 11: - h1_w0 += tl.sum(x * weight_h2_w2, axis=2) - h1_w1 += tl.sum(x * weight_h2_w1, axis=2) + h1_w0 += tl.sum(x * weight_h2_w2, axis=1) + h1_w1 += tl.sum(x * weight_h2_w1, axis=1) elif k == 12: - h0_w1 += tl.sum(x * weight_h0_w2, axis=2) + h0_w1 += tl.sum(x * weight_h0_w2, axis=1) elif k == 13: - h0_w1 += tl.sum(x * weight_h1_w2, axis=2) - h1_w1 += tl.sum(x * weight_h0_w2, axis=2) + h0_w1 += tl.sum(x * weight_h1_w2, axis=1) + h1_w1 += tl.sum(x * weight_h0_w2, axis=1) elif k == 14: - h0_w1 += tl.sum(x * weight_h2_w2, axis=2) - h1_w1 += tl.sum(x * weight_h1_w2, axis=2) + h0_w1 += tl.sum(x * weight_h2_w2, axis=1) + h1_w1 += tl.sum(x * weight_h1_w2, axis=1) else: - h1_w1 += tl.sum(x * weight_h2_w2, axis=2) + h1_w1 += tl.sum(x * weight_h2_w2, axis=1) k_ = k + 1 i = (k_ % 4) - 1 From e9f84adc76128ba08a928f9a3f356a53066a868f Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Sun, 5 Jan 2025 22:12:40 +0300 Subject: [PATCH 12/19] tests --- tests/test_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_ops.py b/tests/test_ops.py index f777630..d452ef8 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -27,17 +27,17 @@ def bsize(request): return request.param -@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_1(request): return request.param -@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_2(request): return request.param -@pytest.fixture(params=[1, 2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) +@pytest.fixture(params=[2, 3, 7, 13, 16, 32, 37, 53, 111, 128]) def other_3(request): return request.param From 538f57748ad7cbd5c5d5701061c9f6ce44ba5bb1 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Mon, 6 Jan 2025 19:53:19 +0300 Subject: [PATCH 13/19] readme --- README.md | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 71a5c97..38dd106 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Kerops -Efficient and fast algorithms on the GPU +Fast algorithms for GPU # Install *pip is not available right now* @@ -12,8 +12,17 @@ Time comparison (ms) for NVidia RTX 3090. Input is an array of size (1, channels | channels |torch.clone| kerops.ops.DWConv |torch.nn.Conv3d(C->C)| |:--------------------:|:---------:|:--------------------:|:-------------------:| -| 8 | 0.61 | 0.81 (x1.32) | 2.45 (x4.00) | -| 16 | 1.21 | 1.27 (1.27) | 4.48 (x3.70) | -| 32 | 2.40 | 3.12 (1.30) | 15.3 (x6.38) | -| 64 | 4.78 | 6.29 (1.32) | 52.0 (x10.89) | -| 128 | 9.55 | 13.2 (1.38) | 195.0 (x20.44) | +| 8 | 0.61 | 0.79 (x1.30) | 2.45 (x4.00) | +| 16 | 1.21 | 1.41 (x1.17) | 4.48 (x3.70) | +| 32 | 2.40 | 2.99 (x1.25) | 15.3 (x6.38) | +| 64 | 4.78 | 6.29 (x1.32) | 52.0 (x10.89) | +| 128 | 9.55 | 12.8 (x1.34) | 195.0 (x20.44) | + + +| channels |torch.clone|kerops.ops.DWConvWGRAD|torch.nn.Conv3d(C->C)| +|:--------------------:|:---------:|:--------------------:|:-------------------:| +| 8 | 0.61 | 2.55 (x4.18) | 7.14 (x11.70) | +| 16 | 1.21 | 3.01 (x2.49) | 12.1 (x10.00) | +| 32 | 2.40 | 4.80 (x2.00) | 24.6 (x10.25) | +| 64 | 4.78 | 8.72 (x1.82) | 71.3 (x14.91) | +| 128 | 9.55 | 17.9 (x1.87) | 245.0 (x25.65) | From f01c1fa1f69ceb36786818b1a8518704084efe14 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 13:53:22 +0300 Subject: [PATCH 14/19] minor fix --- kerops/ops/linear.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/kerops/ops/linear.py b/kerops/ops/linear.py index 961f166..ea8e8f9 100644 --- a/kerops/ops/linear.py +++ b/kerops/ops/linear.py @@ -11,10 +11,6 @@ def fwd_warps(in_channels): return {16: 2, 32: 2, 64: 1, 128: 1}[in_channels] -def fwd_dblock(in_channels): - return {16: 16, 32: 16, 64: 16, 128: 16}[in_channels] - - def fwd_ilp(in_channels): return {16: 8, 32: 8, 64: 4, 128: 4}[in_channels] @@ -27,13 +23,9 @@ def bwd_dblock(in_channels): return {16: 16, 32: 32, 64: 32, 128: 32}[in_channels] -def bwd_ilp(in_channels): - return {16: 16, 32: 16, 64: 16, 128: 16}[in_channels] - - @configure( _num_warps=lambda weight: fwd_warps(weight.shape[0]), - D_block=lambda weight: fwd_dblock(weight.shape[0]), + D_block=16, _ILP=lambda weight: fwd_ilp(weight.shape[0]), ) def ReLULinearAdd( @@ -41,9 +33,9 @@ def ReLULinearAdd( weight, add_other, *, - _num_warps: ConfigurableArg = 2, - D_block: ConfigurableArg = 16, - _ILP: ConfigurableArg = 8, + _num_warps: ConfigurableArg, + D_block: ConfigurableArg, + _ILP: ConfigurableArg, ): in_channels = x.shape[1] out_channels = weight.shape[1] @@ -87,16 +79,16 @@ def ReLULinearAdd( @configure( _num_warps=lambda weight: bwd_warps(weight.shape[0]), D_block=lambda weight: bwd_dblock(weight.shape[0]), - _ILP=lambda weight: bwd_ilp(weight.shape[0]), + _ILP=16, ) def ReLULinearBackward( input, grad, weight, *, - _num_warps: ConfigurableArg = 8, - D_block: ConfigurableArg = 32, - _ILP: ConfigurableArg = 16, + _num_warps: ConfigurableArg, + D_block: ConfigurableArg, + _ILP: ConfigurableArg, ): in_channels = weight.shape[0] out_channels = grad.shape[1] From 14e54571c50e36a5ca8c5612dff918e52f20bdbf Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 13:53:55 +0300 Subject: [PATCH 15/19] black + isort --- kerops/kernels/dw_conv.py | 18 +++++++++--------- kerops/ops/conv.py | 6 ++++-- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/kerops/kernels/dw_conv.py b/kerops/kernels/dw_conv.py index 247e261..b62ffbd 100644 --- a/kerops/kernels/dw_conv.py +++ b/kerops/kernels/dw_conv.py @@ -199,30 +199,30 @@ def _DWConv_wgrad_cl3d_impl( H0_load = 2 * H_cell < H H1_load = 2 * H_cell + 1 < H W1_load = 2 * W_cell + 1 < W - + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + 2 * W_cell * W_stride x_h0_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H0_load) - + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + 2 * W_cell * W_stride x_h1_w0 = tl.load(tmp_input_ptr + offset, mask=mask and H1_load) - + tmp_input_ptr = input_ptr + 2 * H_cell * H_stride + (2 * W_cell + 1) * W_stride x_h0_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H0_load)) - + tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride x_h1_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H1_load)) - #grad = tl.zeros([channels, D_block], dtype=tl.float16)[None] - + # grad = tl.zeros([channels, D_block], dtype=tl.float16)[None] + for k in tl.static_range(0, 16): i = (k % 4) - 1 j = (k // 4) - 1 load_next = (2 * H_cell + i < H and 2 * H_cell + i >= 0) and (2 * W_cell + j < W and 2 * W_cell + j >= 0) - tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride + tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride if load_next: - grad = tl.load(tmp_grad_ptr + in_offset, mask=in_mask, other=0.)[None] - + grad = tl.load(tmp_grad_ptr + in_offset, mask=in_mask, other=0.0)[None] + if i == -1 and j == -1: h2_w2 += tl.sum(grad * x_h0_w0, axis=2) elif i == -1 and j == 0: diff --git a/kerops/ops/conv.py b/kerops/ops/conv.py index c928d60..b163dc8 100644 --- a/kerops/ops/conv.py +++ b/kerops/ops/conv.py @@ -77,9 +77,11 @@ def DWConv(x, weight, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, ACCTYPE='float32', _num_warps=lambda x: dwconv_wgrad_warps(x.shape[1]), D_block=lambda x: dwconv_wgrad_dblock(x.shape[1]), - ILP=lambda x: dwconv_wgrad_ilp(x.shape[1]) + ILP=lambda x: dwconv_wgrad_ilp(x.shape[1]), ) -def DWConvWGRAD(x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg, ILP: ConfigurableArg): +def DWConvWGRAD( + x, grad, *, ACCTYPE: ConfigurableArg, _num_warps: ConfigurableArg, D_block: ConfigurableArg, ILP: ConfigurableArg +): channels = x.shape[1] assert x.ndim == grad.ndim == 5 From bcf9086d48af281386300ad19c346f5e7d9a2e20 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 13:55:19 +0300 Subject: [PATCH 16/19] naming --- kerops/kernels/dw_conv.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/kerops/kernels/dw_conv.py b/kerops/kernels/dw_conv.py index b62ffbd..4bcf056 100644 --- a/kerops/kernels/dw_conv.py +++ b/kerops/kernels/dw_conv.py @@ -179,8 +179,8 @@ def _DWConv_wgrad_cl3d_impl( mask = mask and (d_offset[None, None, :] + near_offset[:, None, None] >= 0 - D_block * D_cell) mask = mask and (near_offset[:, None, None] != 2) - in_offset = d_offset[None, :] * channels + channels_offset[:, None] - in_mask = d_offset[None, :] < D - D_block * D_cell + grad_offset = d_offset[None, :] * channels + channels_offset[:, None] + grad_mask = d_offset[None, :] < D - D_block * D_cell h0_w0 = tl.zeros([4, channels], dtype=ACCTYPE) h0_w1 = tl.zeros([4, channels], dtype=ACCTYPE) @@ -212,8 +212,6 @@ def _DWConv_wgrad_cl3d_impl( tmp_input_ptr = input_ptr + (2 * H_cell + 1) * H_stride + (2 * W_cell + 1) * W_stride x_h1_w1 = tl.load(tmp_input_ptr + offset, mask=mask and (W1_load and H1_load)) - # grad = tl.zeros([channels, D_block], dtype=tl.float16)[None] - for k in tl.static_range(0, 16): i = (k % 4) - 1 j = (k // 4) - 1 @@ -221,7 +219,7 @@ def _DWConv_wgrad_cl3d_impl( tmp_grad_ptr = grad_ptr + (2 * H_cell + i) * H_stride + (2 * W_cell + j) * W_stride if load_next: - grad = tl.load(tmp_grad_ptr + in_offset, mask=in_mask, other=0.0)[None] + grad = tl.load(tmp_grad_ptr + grad_offset, mask=grad_mask, other=0.0)[None] if i == -1 and j == -1: h2_w2 += tl.sum(grad * x_h0_w0, axis=2) From 402f1af62a2331195a9d3f7103b7adc048ef0cbe Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 13:56:18 +0300 Subject: [PATCH 17/19] version --- kerops/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kerops/__version__.py b/kerops/__version__.py index b8023d8..d18f409 100644 --- a/kerops/__version__.py +++ b/kerops/__version__.py @@ -1 +1 @@ -__version__ = '0.0.1' +__version__ = '0.0.2' From 8b7a8e0274d520cf08275451fdd39ed3688afc76 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 18:59:53 +0300 Subject: [PATCH 18/19] CI --- .flake8 | 4 ++++ .github/workflows/lint.yml | 23 +++++++++++++++++++++++ requirements-dev.txt | 3 --- 3 files changed, 27 insertions(+), 3 deletions(-) create mode 100644 .flake8 create mode 100644 .github/workflows/lint.yml diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..653028a --- /dev/null +++ b/.flake8 @@ -0,0 +1,4 @@ +[flake8] + +max-line-length = 120 +per-file-ignores = __init__.py: F401 \ No newline at end of file diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..951d456 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,23 @@ +name: Lint + +on: [ pull_request ] + +env: + MODULE_NAME: kerops + +jobs: + lint: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Check python code style + run: | + pip install -r requirements-dev.txt + flake8 . + isort --check . + black --check . \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index 09e76e8..793564c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,3 @@ -matplotlib -seaborn - black flake8 flake8-tidy-imports From 8aa810986b6e8a8978cced103b24bff51d9781a7 Mon Sep 17 00:00:00 2001 From: AnihilatorGun Date: Tue, 7 Jan 2025 19:04:54 +0300 Subject: [PATCH 19/19] CI --- .flake8 | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.flake8 b/.flake8 index 653028a..a1f5a6a 100644 --- a/.flake8 +++ b/.flake8 @@ -1,4 +1,6 @@ [flake8] max-line-length = 120 -per-file-ignores = __init__.py: F401 \ No newline at end of file +per-file-ignores = + kerops/kernels/*: B007 + __init__.py: F401 \ No newline at end of file