10 changes: 10 additions & 0 deletions src/ntops/kernels/__init__.py
@@ -1,12 +1,16 @@
from ntops.kernels import (
    abs,
    add,
    addcdiv,
    addmm,
    atan2,
    avg_pool2d,
    binary_cross_entropy,
    bitwise_and,
    bitwise_not,
    bitwise_or,
    bmm,
    bucketize,
    clamp,
    conv2d,
    cos,
@@ -23,6 +27,7 @@
    le,
    lt,
    max_pool2d,
    minimum,
    mm,
    mul,
    ne,
@@ -82,4 +87,9 @@
    "softmax",
    "sub",
    "tanh",
    "addcdiv",
    "atan2",
    "binary_cross_entropy",
    "bucketize",
    "minimum",
]
33 changes: 33 additions & 0 deletions src/ntops/kernels/addcdiv.py
@@ -0,0 +1,33 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, tensor1, tensor2, value, output):
    dtype = output.dtype
    val_input = ntl.cast(input, dtype)
    val_t1 = ntl.cast(tensor1, dtype)
    val_t2 = ntl.cast(tensor2, dtype)
    val_v = ntl.cast(value, dtype)

    # out = input + value * (tensor1 / tensor2)
    res = val_input + val_v * (val_t1 / val_t2)

    output = res  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input
        Tensor(ndim, dtype=dtype),  # tensor1
        Tensor(ndim, dtype=dtype),  # tensor2
        Tensor(0, dtype=dtype),  # value (scalar)
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
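
The kernel mirrors torch.addcdiv for same-shaped tensors: out = input + value * (tensor1 / tensor2). A quick parity check, sketched under the assumption of a CUDA device and the ntops.torch.addcdiv wrapper added later in this diff:

import torch

import ntops.torch

a = torch.randn(4, 5, device="cuda")
t1 = torch.randn(4, 5, device="cuda")
t2 = torch.rand(4, 5, device="cuda") + 0.5  # keep the divisor away from zero
torch.testing.assert_close(
    ntops.torch.addcdiv(a, t1, t2, value=0.5),
    torch.addcdiv(a, t1, t2, value=0.5),
)
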
58 changes: 58 additions & 0 deletions src/ntops/kernels/atan2.py
@@ -0,0 +1,58 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    y = ntl.cast(input, ntl.float32)
    x = ntl.cast(other, ntl.float32)

    # Constants (float32-precision pi and pi/2)
    PI = 3.1415927410125732
    HALF_PI = 1.5707963705062866

    abs_y = ntl.where(y < 0, -y, y)
    abs_x = ntl.where(x < 0, -x, x)

    # Reduce to the octant |z| <= 1 by swapping the ratio when |y| > |x|.
    swap_xy = abs_y > abs_x

    num = ntl.where(swap_xy, abs_x, abs_y)
    den = ntl.where(swap_xy, abs_y, abs_x)

    # Guard against division by zero
    den_safe = ntl.where(den == 0.0, 1.0, den)
    z = num / den_safe
    z_sq = z * z

    # Polynomial approximation of atan(z) on [0, 1]
    c0 = 0.9998660
    c1 = -0.3302995
    c2 = 0.1801410
    c3 = -0.0851330
    c4 = 0.0208351

    poly_res = z * (c0 + z_sq * (c1 + z_sq * (c2 + z_sq * (c3 + z_sq * c4))))

    # Undo the swap, then fix up the quadrant from the signs of x and y.
    theta = ntl.where(swap_xy, HALF_PI - poly_res, poly_res)

    res = theta
    res = ntl.where(x < 0, PI - theta, res)
    res = ntl.where(y < 0, -res, res)
    res = ntl.where((x == 0.0) & (y == 0.0), 0.0, res)

    output = ntl.cast(res, output.dtype)


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input (y)
        Tensor(ndim, dtype=dtype),  # other (x)
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
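
The core identity here is atan2(y, x) = atan(|y|/|x|) folded back into the right quadrant; when |y| > |x| the swap uses atan(a) = pi/2 - atan(1/a). The coefficients are a classic Hastings-style fit of atan on [0, 1], accurate to roughly 1e-5, so tests should compare with a tolerance rather than bit-exact equality. A sketch, assuming CUDA and the ntops.torch.atan2 wrapper later in this diff:

import torch

import ntops.torch

y = torch.randn(1000, device="cuda")
x = torch.randn(1000, device="cuda")
torch.testing.assert_close(
    ntops.torch.atan2(y, x), torch.atan2(y, x), atol=1e-4, rtol=1e-4
)
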
85 changes: 85 additions & 0 deletions src/ntops/kernels/binary_cross_entropy.py
@@ -0,0 +1,85 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement_bce(input, target, weight, output, has_weight, block_size):
    input_t = input.flatten().tile((block_size,))
    target_t = target.flatten().tile((block_size,))
    weight_t = weight.flatten().tile((block_size,))
    output_t = output.flatten().tile((block_size,))
    return input_t, target_t, weight_t, output_t, has_weight


def application_bce(input, target, weight, output, has_weight):
    val_input = ntl.cast(input, ntl.float32)
    val_target = ntl.cast(target, ntl.float32)

    # Clamp both log arguments away from zero so the loss stays finite.
    eps = 1e-12
    log_input = ntl.log(ntl.maximum(val_input, eps))
    log_one_minus = ntl.log(ntl.maximum(1.0 - val_input, eps))

    loss = -(val_target * log_input + (1.0 - val_target) * log_one_minus)

    if has_weight:
        val_weight = ntl.cast(weight, ntl.float32)
        loss = loss * val_weight

    output = ntl.cast(loss, output.dtype)


def premake_bce(ndim, dtype=None, has_weight=False, block_size=None):
    arrangement_ = functools.partial(arrangement_bce, block_size=block_size)
    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # input
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # target
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # weight
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # output
        Tensor(0, constexpr=True, value=has_weight),  # has_weight
    )
    return arrangement_, application_bce, tensors


def arrangement_reduce(input, output, block_size):
    input_t = input.tile((block_size,))
    output_t = output.tile((1,))
    return input_t, output_t


def application_reduce(input, output):
    accumulator = 0.0
    for i in range(input.shape[0]):
        accumulator += ntl.cast(input[i], ntl.float32)
    output[0] = ntl.cast(accumulator, output.dtype)


def premake_reduce(dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement_reduce, block_size=block_size)
    tensors = (
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # input
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
    )
    return arrangement_, application_reduce, tensors


def arrangement_div(input, output, divisor):
    return input.tile((1,)), output.tile((1,)), divisor


def application_div(input, output, divisor):
    val = ntl.cast(input, ntl.float32)
    res = val / divisor
    output = ntl.cast(res, output.dtype)


def premake_div(divisor, dtype=None):
    arrangement_ = arrangement_div
    tensors = (
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # input
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # output
        Tensor(0, constexpr=True, value=divisor),  # divisor
    )
    return arrangement_, application_div, tensors
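
The three kernels are stages a wrapper can chain: application_bce produces the per-element loss (optionally weighted), application_reduce sums a block, and application_div divides by the element count for reduction="mean". The semantics they implement, sketched as a plain-PyTorch reference (the eps clamp matches the kernel above):

import torch

def bce_reference(input, target, weight=None, eps=1e-12):
    loss = -(
        target * torch.log(input.clamp(min=eps))
        + (1 - target) * torch.log((1 - input).clamp(min=eps))
    )
    if weight is not None:
        loss = loss * weight
    return loss.mean()
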
57 changes: 57 additions & 0 deletions src/ntops/kernels/bucketize.py
@@ -0,0 +1,57 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor


def arrangement(
    input, boundaries, output, right, bound_len, padded_len=None, block_size=None
):
    input_arranged = input.flatten().tile((block_size,))
    output_arranged = output.flatten().tile((block_size,))

    # Boundaries are padded to `padded_len` and broadcast across input tiles.
    bound_arranged = boundaries.flatten().tile((padded_len,))
    bound_arranged = bound_arranged.expand((input_arranged.shape[0],))

    return input_arranged, bound_arranged, output_arranged, right, bound_len


def application(input, boundaries, output, right, bound_len):
    val_in = ntl.cast(input, ntl.float32)
    val_bound = ntl.cast(boundaries, ntl.float32)

    # Mask out the padding beyond the true boundary count.
    bound_idx = ntl.arange(0, boundaries.shape[0])
    valid_mask = bound_idx < bound_len

    in_bc = ntl.expand_dims(val_in, 1)
    bound_bc = ntl.expand_dims(val_bound, 0)
    mask_bc = ntl.expand_dims(valid_mask, 0)

    if right:
        # count(b <= x)
        cond = bound_bc <= in_bc
    else:
        # count(b < x)
        cond = bound_bc < in_bc

    final_cond = cond & mask_bc

    bucket_idx = ntl.sum(ntl.cast(final_cond, ntl.int32), 1)

    output = ntl.cast(bucket_idx, output.dtype)


def premake(ndim, dtype=None, padded_len=None, block_size=None):
    arrangement_ = functools.partial(
        arrangement, padded_len=padded_len, block_size=block_size
    )

    tensors = (
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # input
        Tensor(1, dtype=dtype, shape_options={"constexpr": True}),  # boundaries
        Tensor(ndim, dtype=dtype, shape_options={"constexpr": True}),  # output
        Tensor(0, dtype=dtype),  # right
        Tensor(0, dtype=dtype),  # bound_len
    )

    return arrangement_, application, tensors
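
Counting boundaries below each value reproduces torch.bucketize on a sorted 1-D boundary tensor: with right=False the result index i satisfies boundaries[i-1] < v <= boundaries[i], i.e. count(b < v), and right=True shifts this to count(b <= v). A parity sketch, assuming CUDA and an ntops.torch.bucketize wrapper:

import torch

import ntops.torch

boundaries = torch.tensor([1.0, 3.0, 5.0, 7.0], device="cuda")
values = torch.tensor([0.5, 3.0, 6.2, 9.0], device="cuda")
assert torch.equal(
    ntops.torch.bucketize(values, boundaries, right=True),
    torch.bucketize(values, boundaries, right=True),
)
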
22 changes: 22 additions & 0 deletions src/ntops/kernels/minimum.py
@@ -0,0 +1,22 @@
import functools

import ninetoothed.language as ntl
from ninetoothed import Tensor

from ntops.kernels.element_wise import arrangement


def application(input, other, output):
    output = ntl.minimum(input, other)  # noqa: F841


def premake(ndim, dtype=None, block_size=None):
    arrangement_ = functools.partial(arrangement, block_size=block_size)

    tensors = (
        Tensor(ndim, dtype=dtype),  # input
        Tensor(ndim, dtype=dtype),  # other
        Tensor(ndim, dtype=dtype),  # output
    )

    return arrangement_, application, tensors
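
Assigning to the output parameter is how these kernels express the write-back (hence the noqa marker for the linter). Usage, assuming CUDA and the ntops.torch.minimum wrapper imported below:

import torch

import ntops.torch

a = torch.randn(8, device="cuda")
b = torch.randn(8, device="cuda")
assert torch.equal(ntops.torch.minimum(a, b), torch.minimum(a, b))
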
10 changes: 10 additions & 0 deletions src/ntops/torch/__init__.py
@@ -1,11 +1,15 @@
from ntops.torch.abs import abs
from ntops.torch.add import add
from ntops.torch.addcdiv import addcdiv
from ntops.torch.addmm import addmm
from ntops.torch.atan2 import atan2
from ntops.torch.avg_pool2d import avg_pool2d
from ntops.torch.binary_cross_entropy import binary_cross_entropy
from ntops.torch.bitwise_and import bitwise_and
from ntops.torch.bitwise_not import bitwise_not
from ntops.torch.bitwise_or import bitwise_or
from ntops.torch.bmm import bmm
from ntops.torch.bucketize import bucketize
from ntops.torch.clamp import clamp
from ntops.torch.conv2d import conv2d
from ntops.torch.cos import cos
@@ -23,6 +27,7 @@
from ntops.torch.lt import lt
from ntops.torch.matmul import matmul
from ntops.torch.max_pool2d import max_pool2d
from ntops.torch.minimum import minimum
from ntops.torch.mm import mm
from ntops.torch.mul import mul
from ntops.torch.ne import ne
@@ -82,4 +87,9 @@
"softmax",
"sub",
"tanh",
"minimum",
"atan2",
"addcdiv",
"bucketize",
"binary_cross_entropy",
]
18 changes: 18 additions & 0 deletions src/ntops/torch/addcdiv.py
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def addcdiv(input, tensor1, tensor2, *, value=1.0, out=None):
    if out is None:
        out = torch.empty_like(input)

    block_size = 1024
    kernel = _cached_make(
        ntops.kernels.addcdiv.premake, input.ndim, input.dtype, block_size=block_size
    )

    kernel(input, tensor1, tensor2, value, out)

    return out
18 changes: 18 additions & 0 deletions src/ntops/torch/atan2.py
@@ -0,0 +1,18 @@
import torch

import ntops
from ntops.torch.utils import _cached_make


def atan2(input, other, *, out=None):
    if out is None:
        out = torch.empty_like(input)

    block_size = 1024
    kernel = _cached_make(
        ntops.kernels.atan2.premake, input.ndim, input.dtype, block_size=block_size
    )

    kernel(input, other, out)

    return out
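
Both wrappers follow the existing elementwise pattern: allocate out with torch.empty_like(input) and dispatch a kernel cached by _cached_make. Note this assumes input and other already share a shape and dtype; unlike torch.atan2, no broadcasting is performed. A minimal smoke test, assuming a CUDA device:

import torch

import ntops.torch

y = torch.randn(3, 4, device="cuda")
x = torch.randn(3, 4, device="cuda")
out = torch.empty_like(y)
ntops.torch.atan2(y, x, out=out)  # writes into the provided buffer
torch.testing.assert_close(out, torch.atan2(y, x), atol=1e-4, rtol=1e-4)
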