From 9f7a068d02a026b1bf8bea3ee02c677f02f939ba Mon Sep 17 00:00:00 2001
From: PPPoint <1024879159@qq.com>
Date: Thu, 5 Feb 2026 19:35:03 +0800
Subject: [PATCH] Finish T1-1-12: minimum, atan2, addcdiv, bucketize,
 binary_cross_entropy
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/ntops/kernels/__init__.py             |  10 ++
 src/ntops/kernels/addcdiv.py              |  33 ++++++
 src/ntops/kernels/atan2.py                |  58 ++++++++++
 src/ntops/kernels/binary_cross_entropy.py |  85 ++++++++++++++
 src/ntops/kernels/bucketize.py            |  57 ++++++++++
 src/ntops/kernels/minimum.py              |  22 ++++
 src/ntops/torch/__init__.py               |  10 ++
 src/ntops/torch/addcdiv.py                |  18 +++
 src/ntops/torch/atan2.py                  |  18 +++
 src/ntops/torch/binary_cross_entropy.py   | 132 ++++++++++++++++++++++
 src/ntops/torch/bucketize.py              |  37 ++++++
 src/ntops/torch/minimum.py                |  18 +++
 12 files changed, 498 insertions(+)
 create mode 100644 src/ntops/kernels/addcdiv.py
 create mode 100644 src/ntops/kernels/atan2.py
 create mode 100644 src/ntops/kernels/binary_cross_entropy.py
 create mode 100644 src/ntops/kernels/bucketize.py
 create mode 100644 src/ntops/kernels/minimum.py
 create mode 100644 src/ntops/torch/addcdiv.py
 create mode 100644 src/ntops/torch/atan2.py
 create mode 100644 src/ntops/torch/binary_cross_entropy.py
 create mode 100644 src/ntops/torch/bucketize.py
 create mode 100644 src/ntops/torch/minimum.py

diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index 084e52c..a90078a 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -1,11 +1,15 @@
 from ntops.kernels import (
     abs,
     add,
+    addcdiv,
     addmm,
+    atan2,
+    binary_cross_entropy,
     bitwise_and,
     bitwise_not,
     bitwise_or,
     bmm,
+    bucketize,
     clamp,
     cos,
     div,
@@ -20,6 +24,7 @@
     layer_norm,
     le,
     lt,
+    minimum,
     mm,
     mul,
     ne,
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "minimum",
+    "atan2",
+    "addcdiv",
+    "bucketize",
+    "binary_cross_entropy",
 ]
diff --git a/src/ntops/kernels/addcdiv.py b/src/ntops/kernels/addcdiv.py
new file mode 100644
index 0000000..e9d4112
--- /dev/null
+++ b/src/ntops/kernels/addcdiv.py
@@ -0,0 +1,33 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, tensor1, tensor2, value, output):
+    dtype = output.dtype
+    val_input = ntl.cast(input, dtype)
+    val_t1 = ntl.cast(tensor1, dtype)
+    val_t2 = ntl.cast(tensor2, dtype)
+    val_v = ntl.cast(value, dtype)
+
+    # out = input + value * (t1 / t2)
+    res = val_input + val_v * (val_t1 / val_t2)
+
+    output = res
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(0, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/atan2.py b/src/ntops/kernels/atan2.py
new file mode 100644
index 0000000..17207f8
--- /dev/null
+++ b/src/ntops/kernels/atan2.py
@@ -0,0 +1,58 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, other, output):
+    y = ntl.cast(input, ntl.float32)
+    x = ntl.cast(other, ntl.float32)
+
+    # Constants: float32 values of pi and pi / 2
+    PI = 3.1415927410125732
+    HALF_PI = 1.5707963705062866
+
+    abs_y = ntl.where(y < 0, -y, y)
+    abs_x = ntl.where(x < 0, -x, x)
+
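+    # Octant reduction: fold the inputs into z = min(|x|, |y|) / max(|x|, |y|)
+    # in [0, 1], approximate atan(z) there, then unfold with the identities
+    #   atan(y / x) = pi/2 - atan(x / y)   (when |y| > |x|)
+    #   atan2(y, x) = pi - atan2(y, -x)    (when x < 0)
+    #   atan2(-y, x) = -atan2(y, x)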
+    swap_xy = abs_y > abs_x
+
+    num = ntl.where(swap_xy, abs_x, abs_y)
+    den = ntl.where(swap_xy, abs_y, abs_x)
+
+    # Guard against division by zero
+    den_safe = ntl.where(den == 0.0, 1.0, den)
+    z = num / den_safe
+    z_sq = z * z
+
+    # Polynomial approximation of atan(z) on [0, 1]
+    c0 = 0.9998660
+    c1 = -0.3302995
+    c2 = 0.1801410
+    c3 = -0.0851330
+    c4 = 0.0208351
+
+    poly_res = z * (c0 + z_sq * (c1 + z_sq * (c2 + z_sq * (c3 + z_sq * c4))))
+
+    theta = ntl.where(swap_xy, HALF_PI - poly_res, poly_res)
+
+    res = theta
+    res = ntl.where(x < 0, PI - theta, res)
+    res = ntl.where(y < 0, -res, res)
+    res = ntl.where((x == 0.0) & (y == 0.0), 0.0, res)
+
+    output = ntl.cast(res, output.dtype)
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/binary_cross_entropy.py b/src/ntops/kernels/binary_cross_entropy.py
new file mode 100644
index 0000000..f35942c
--- /dev/null
+++ b/src/ntops/kernels/binary_cross_entropy.py
@@ -0,0 +1,85 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement_bce(input, target, weight, output, has_weight, block_size):
+    input_t = input.flatten().tile((block_size,))
+    target_t = target.flatten().tile((block_size,))
+    weight_t = weight.flatten().tile((block_size,))
+    output_t = output.flatten().tile((block_size,))
+    return input_t, target_t, weight_t, output_t, has_weight
+
+
+def application_bce(input, target, weight, output, has_weight):
+    val_input = ntl.cast(input, ntl.float32)
+    val_target = ntl.cast(target, ntl.float32)
+
+    # Clamp the log arguments so log(0) cannot produce -inf.
+    eps = 1e-12
+    clamped_p = ntl.maximum(val_input, eps)
+    clamped_q = ntl.maximum(1.0 - val_input, eps)
+
+    # loss = -(target * log(p) + (1 - target) * log(1 - p))
+    pos_term = val_target * ntl.log(clamped_p)
+    neg_term = (1.0 - val_target) * ntl.log(clamped_q)
+    loss = -(pos_term + neg_term)
+
+    if has_weight:
+        val_weight = ntl.cast(weight, ntl.float32)
+        loss = loss * val_weight
+
+    output = ntl.cast(loss, output.dtype)
+
+
+def premake_bce(ndim, dtype=None, has_weight=False, block_size=None):
+    arrangement_ = functools.partial(arrangement_bce, block_size=block_size)
+    tensors = (
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # input
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # target
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # weight
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # output
+        Tensor(0, constexpr=True, value=has_weight),  # has_weight
+    )
+    return arrangement_, application_bce, tensors
+
+
+def arrangement_reduce(input, output, block_size):
+    input_t = input.tile((block_size,))
+    output_t = output.tile((1,))
+    return input_t, output_t
+
+
+def application_reduce(input, output):
+    accumulator = 0.0
+    for i in range(input.shape[0]):
+        accumulator += ntl.cast(input[i], ntl.float32)
+    output[0] = ntl.cast(accumulator, output.dtype)
+
+
+def premake_reduce(dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement_reduce, block_size=block_size)
+    tensors = (
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # input
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
+    )
+    return arrangement_, application_reduce, tensors
+
+
+def arrangement_div(input, output, divisor):
+    return input.tile((1,)), output.tile((1,)), divisor
+
+
+def application_div(input, output, divisor):
+    val = ntl.cast(input, ntl.float32)
+    res = val / divisor
+    output = ntl.cast(res, output.dtype)
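+
+
+# Reduction pipeline (sketch; see the torch wrapper below): elementwise BCE,
+# then repeated premake_reduce passes that shrink N partial losses to
+# ceil(N / block_size) block sums until one element remains, then premake_div
+# for the final "mean" (or a divide-by-1.0 cast for "sum").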
+
+
+def premake_div(divisor, dtype=None):
+    tensors = (
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # input
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
+        Tensor(0, constexpr=True, value=divisor),  # divisor
+    )
+    return arrangement_div, application_div, tensors
diff --git a/src/ntops/kernels/bucketize.py b/src/ntops/kernels/bucketize.py
new file mode 100644
index 0000000..ee4d1fb
--- /dev/null
+++ b/src/ntops/kernels/bucketize.py
@@ -0,0 +1,57 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+
+def arrangement(
+    input, boundaries, output, right, bound_len, padded_len=None, block_size=None
+):
+    input_arranged = input.flatten().tile((block_size,))
+    output_arranged = output.flatten().tile((block_size,))
+
+    # The boundary tile is padded to padded_len and broadcast across all input
+    # tiles; bound_len masks out the padding inside the kernel.
+    bound_arranged = boundaries.flatten().tile((padded_len,))
+    bound_arranged = bound_arranged.expand((input_arranged.shape[0],))
+
+    # Return exactly one arranged value per tensor in premake's tensor tuple.
+    return input_arranged, bound_arranged, output_arranged, right, bound_len
+
+
+def application(input, boundaries, output, right, bound_len):
+    val_in = ntl.cast(input, ntl.float32)
+    val_bound = ntl.cast(boundaries, ntl.float32)
+
+    bound_idx = ntl.arange(0, boundaries.shape[0])
+    valid_mask = bound_idx < bound_len
+
+    in_bc = ntl.expand_dims(val_in, 1)
+    bound_bc = ntl.expand_dims(val_bound, 0)
+    mask_bc = ntl.expand_dims(valid_mask, 0)
+
+    # right is a runtime scalar, so select branchlessly:
+    # right=True counts boundaries <= x; right=False counts boundaries < x.
+    cond = ntl.where(right != 0, bound_bc <= in_bc, bound_bc < in_bc)
+
+    final_cond = cond & mask_bc
+
+    bucket_idx = ntl.sum(ntl.cast(final_cond, ntl.int32), 1)
+
+    output = ntl.cast(bucket_idx, output.dtype)
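+
+
+# Worked example (illustrative): with boundaries = [1, 3, 5] and x = 3,
+# right=False counts b < 3  -> bucket index 1 (the torch.bucketize default),
+# right=True  counts b <= 3 -> bucket index 2.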
+
+
+def premake(ndim, dtype=None, padded_len=None, block_size=None):
+    arrangement_ = functools.partial(
+        arrangement, padded_len=padded_len, block_size=block_size
+    )
+
+    tensors = (
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # input
+        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # boundaries
+        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # output
+        Tensor(0, dtype=dtype),  # right
+        Tensor(0, dtype=dtype),  # bound_len
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/kernels/minimum.py b/src/ntops/kernels/minimum.py
new file mode 100644
index 0000000..9f3d65f
--- /dev/null
+++ b/src/ntops/kernels/minimum.py
@@ -0,0 +1,22 @@
+import functools
+
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, other, output):
+    output = ntl.minimum(input, other)  # noqa: F841
+
+
+def premake(ndim, dtype=None, block_size=None):
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+        Tensor(ndim, dtype=dtype),
+    )
+
+    return arrangement_, application, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 702877e..f5227a1 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -1,10 +1,14 @@
 from ntops.torch.abs import abs
 from ntops.torch.add import add
+from ntops.torch.addcdiv import addcdiv
 from ntops.torch.addmm import addmm
+from ntops.torch.atan2 import atan2
+from ntops.torch.binary_cross_entropy import binary_cross_entropy
 from ntops.torch.bitwise_and import bitwise_and
 from ntops.torch.bitwise_not import bitwise_not
 from ntops.torch.bitwise_or import bitwise_or
 from ntops.torch.bmm import bmm
+from ntops.torch.bucketize import bucketize
 from ntops.torch.clamp import clamp
 from ntops.torch.cos import cos
 from ntops.torch.div import div
@@ -20,6 +24,7 @@
 from ntops.torch.le import le
 from ntops.torch.lt import lt
 from ntops.torch.matmul import matmul
+from ntops.torch.minimum import minimum
 from ntops.torch.mm import mm
 from ntops.torch.mul import mul
 from ntops.torch.ne import ne
@@ -76,4 +81,9 @@
     "softmax",
     "sub",
     "tanh",
+    "minimum",
+    "atan2",
+    "addcdiv",
+    "bucketize",
+    "binary_cross_entropy",
 ]
diff --git a/src/ntops/torch/addcdiv.py b/src/ntops/torch/addcdiv.py
new file mode 100644
index 0000000..e4c627d
--- /dev/null
+++ b/src/ntops/torch/addcdiv.py
@@ -0,0 +1,18 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def addcdiv(input, tensor1, tensor2, *, value=1.0, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    block_size = 1024
+    kernel = _cached_make(
+        ntops.kernels.addcdiv.premake, input.ndim, input.dtype, block_size=block_size
+    )
+
+    kernel(input, tensor1, tensor2, value, out)
+
+    return out
diff --git a/src/ntops/torch/atan2.py b/src/ntops/torch/atan2.py
new file mode 100644
index 0000000..df2303d
--- /dev/null
+++ b/src/ntops/torch/atan2.py
@@ -0,0 +1,18 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def atan2(input, other, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    block_size = 1024
+    kernel = _cached_make(
+        ntops.kernels.atan2.premake, input.ndim, input.dtype, block_size=block_size
+    )
+
+    kernel(input, other, out)
+
+    return out
diff --git a/src/ntops/torch/binary_cross_entropy.py b/src/ntops/torch/binary_cross_entropy.py
new file mode 100644
index 0000000..ac13f03
--- /dev/null
+++ b/src/ntops/torch/binary_cross_entropy.py
@@ -0,0 +1,132 @@
+import math
+
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def next_power_of_2(n):
+    if n == 0:
+        return 1
+    return 1 << (n - 1).bit_length()
+
+
+def get_optimal_block_size(dim_size):
+    target_size = next_power_of_2(dim_size)
+    if target_size > 1024:
+        target_size = 1024
+    if target_size < 32:
+        target_size = 32
+    return target_size
+
+
+def binary_cross_entropy(
+    input,
+    target,
+    weight=None,
+    size_average=None,
+    reduce=None,
+    reduction="mean",
+    out=None,
+):
+    # Map the deprecated size_average/reduce flags onto reduction.
+    if size_average is not None or reduce is not None:
+        if reduce is False:
+            reduction = "none"
+        elif size_average is True or size_average is None:
+            reduction = "mean"
+        else:
+            reduction = "sum"
+
+    device = input.device
+    dtype = input.dtype
+    numel = input.numel()
+
+    if target.shape != input.shape:
+        target = target.expand(input.shape)
+
+    has_weight = False
+    if weight is not None:
+        has_weight = True
+        if weight.shape != input.shape:
+            weight = weight.expand(input.shape)
+        weight = weight.contiguous()
+    else:
+        # The kernel always takes a weight argument; pass input as a dummy.
+        weight = input
+
+    compute_dtype = dtype
+    if reduction != "none":
+        compute_dtype = torch.float32
+
+    if out is not None and reduction == "none":
+        output_tensor = out
+    else:
+        output_tensor = torch.empty(input.shape, dtype=compute_dtype, device=device)
+
+    # Run the elementwise loss in higher precision when it will be reduced.
+    block_size = 1024
+    kernel_bce = _cached_make(
+        ntops.kernels.binary_cross_entropy.premake_bce,
+        input.ndim,
+        compute_dtype,
+        has_weight,
+        block_size,
+    )
+    kernel_bce(input, target, weight, output_tensor, has_weight)
+
+    if reduction == "none":
+        return output_tensor
+
+    # Iteratively sum the float32 losses down to a single element.
+    current = output_tensor.contiguous().view((numel,))
+
+    def iterative_reduce(curr_tensor):
+        while curr_tensor.numel() > 1:
+            curr_numel = curr_tensor.numel()
+            block_size = get_optimal_block_size(curr_numel)
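+            # e.g. numel = 5000: pass 1 gives ceil(5000 / 1024) = 5 partial
+            # sums, pass 2 gives 1 (block sizes are clamped to [32, 1024]).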
+
+            output_len = math.ceil(curr_numel / block_size)
+            output = torch.empty((output_len,), dtype=compute_dtype, device=device)
+
+            kernel_reduce = _cached_make(
+                ntops.kernels.binary_cross_entropy.premake_reduce,
+                compute_dtype,
+                block_size,
+            )
+            kernel_reduce(curr_tensor, output)
+            curr_tensor = output
+        return curr_tensor
+
+    final_sum_tensor = iterative_reduce(current)
+
+    if reduction == "sum":
+        # Cast back to the original dtype by dividing by 1.0 with the div kernel.
+        if dtype != compute_dtype:
+            result = torch.empty((1,), dtype=dtype, device=device)
+            kernel_cast = _cached_make(
+                ntops.kernels.binary_cross_entropy.premake_div, 1.0, dtype
+            )
+            kernel_cast(final_sum_tensor, result, 1)
+            final_sum_tensor = result
+
+        result = final_sum_tensor.view(())
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    elif reduction == "mean":
+        final_output = torch.empty((1,), dtype=dtype, device=device)
+
+        kernel_div = _cached_make(
+            ntops.kernels.binary_cross_entropy.premake_div, numel, dtype
+        )
+        kernel_div(final_sum_tensor, final_output, numel)
+
+        result = final_output.view(())
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    return output_tensor
diff --git a/src/ntops/torch/bucketize.py b/src/ntops/torch/bucketize.py
new file mode 100644
index 0000000..c34b804
--- /dev/null
+++ b/src/ntops/torch/bucketize.py
@@ -0,0 +1,37 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def bucketize(input, boundaries, *, out=None, right=False):
+    if out is None:
+        out = torch.empty_like(input, dtype=torch.int64)
+
+    if boundaries.ndim != 1:
+        raise ValueError("boundaries must be 1-dimensional")
+
+    bound_len = boundaries.numel()
+    if bound_len == 0:
+        out.fill_(0)
+        return out
+
+    # Round the boundary count up to a power of two (at least 16) so the
+    # kernel can load the boundaries as a single fixed-size tile.
+    if (bound_len & (bound_len - 1)) == 0:
+        padded_len = bound_len
+    else:
+        padded_len = 1 << bound_len.bit_length()
+
+    padded_len = max(16, padded_len)
+
+    block_size = 1024
+    kernel = _cached_make(
+        ntops.kernels.bucketize.premake,
+        input.ndim,
+        input.dtype,
+        padded_len=padded_len,
+        block_size=block_size,
+    )
+
+    kernel(input, boundaries, out, right, bound_len)
+
+    return out
diff --git a/src/ntops/torch/minimum.py b/src/ntops/torch/minimum.py
new file mode 100644
index 0000000..d3032b6
--- /dev/null
+++ b/src/ntops/torch/minimum.py
@@ -0,0 +1,18 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def minimum(input, other, *, out=None):
+    if out is None:
+        out = torch.empty_like(input)
+
+    block_size = 1024
+    kernel = _cached_make(
+        ntops.kernels.minimum.premake, input.ndim, input.dtype, block_size=block_size
+    )
+
+    kernel(input, other, out)
+
+    return out
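+
+
+# Usage sketch (assumed to mirror torch.minimum; requires a CUDA device):
+#   >>> import torch
+#   >>> import ntops.torch
+#   >>> a = torch.tensor([1.0, 4.0], device="cuda")
+#   >>> b = torch.tensor([2.0, 3.0], device="cuda")
+#   >>> ntops.torch.minimum(a, b)
+#   tensor([1., 3.], device='cuda:0')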