diff --git a/src/ntops/kernels/__init__.py b/src/ntops/kernels/__init__.py
index f6934ef..cf14e6b 100644
--- a/src/ntops/kernels/__init__.py
+++ b/src/ntops/kernels/__init__.py
@@ -31,6 +31,7 @@
     relu,
     rms_norm,
     rotary_position_embedding,
+    round,
     rsqrt,
     scaled_dot_product_attention,
     sigmoid,
@@ -74,6 +75,7 @@
     "relu",
     "rms_norm",
     "rotary_position_embedding",
+    "round",
     "rsqrt",
     "scaled_dot_product_attention",
     "sigmoid",
diff --git a/src/ntops/kernels/round.py b/src/ntops/kernels/round.py
new file mode 100644
index 0000000..3d95f9a
--- /dev/null
+++ b/src/ntops/kernels/round.py
@@ -0,0 +1,44 @@
+"""Element-wise rounding kernels with ``torch.round`` semantics."""
+
+import functools
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Tensor
+from ninetoothed.language import libdevice
+
+from ntops.kernels.element_wise import arrangement
+
+
+def application(input, output):
+    # nearbyint rounds half to even (the default rounding mode), matching
+    # torch.round. Compute in float32 so libdevice accepts any input dtype.
+    output = libdevice.nearbyint(ntl.cast(input, ntl.float32))  # noqa: F841
+
+
+def application_with_decimals(input, factor, inv_factor, output):
+    # Multiply in the input's original precision to match torch's behavior,
+    # then round and undo the scaling with the precomputed inverse factor.
+    scaled = input * ntl.cast(factor, input.dtype)
+    output = libdevice.nearbyint(ntl.cast(scaled, ntl.float32)) * inv_factor  # noqa: F841
+
+
+def premake(ndim, decimals=0, dtype=None, block_size=None):
+    """Build ``(arrangement, application, tensors)`` for an ndim-D round kernel.
+
+    Any non-zero ``decimals`` selects the scaled variant, which takes two
+    extra float64 scalars (the scale factor and its inverse) at launch time.
+    """
+    arrangement_ = functools.partial(arrangement, block_size=block_size)
+
+    if decimals == 0:
+        tensors = (Tensor(ndim, dtype=dtype), Tensor(ndim, dtype=dtype))
+        return arrangement_, application, tensors
+
+    tensors = (
+        Tensor(ndim, dtype=dtype),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(0, dtype=ninetoothed.float64),
+        Tensor(ndim, dtype=dtype),
+    )
+    return arrangement_, application_with_decimals, tensors
diff --git a/src/ntops/torch/__init__.py b/src/ntops/torch/__init__.py
index 82fc596..9815b27 100644
--- a/src/ntops/torch/__init__.py
+++ b/src/ntops/torch/__init__.py
@@ -31,6 +31,7 @@
 from ntops.torch.relu import relu
 from ntops.torch.rms_norm import rms_norm
 from ntops.torch.rotary_position_embedding import rotary_position_embedding
+from ntops.torch.round import round
 from ntops.torch.rsqrt import rsqrt
 from ntops.torch.scaled_dot_product_attention import scaled_dot_product_attention
 from ntops.torch.sigmoid import sigmoid
@@ -74,6 +75,7 @@
     "relu",
     "rms_norm",
     "rotary_position_embedding",
+    "round",
     "rsqrt",
     "scaled_dot_product_attention",
     "sigmoid",
diff --git a/src/ntops/torch/round.py b/src/ntops/torch/round.py
new file mode 100644
index 0000000..2496767
--- /dev/null
+++ b/src/ntops/torch/round.py
@@ -0,0 +1,25 @@
+import torch
+
+import ntops
+from ntops.torch.utils import _cached_make
+
+
+def round(input, decimals=0, *, out=None):
+    """Round ``input`` element-wise, mirroring ``torch.round``.
+
+    ``decimals`` rounds to that many decimal places (may be negative).
+    """
+    if out is None:
+        out = torch.empty_like(input)
+
+    if decimals == 0:
+        kernel = _cached_make(ntops.kernels.round.premake, input.ndim)
+        kernel(input, out)
+    else:
+        factor = 10.0**decimals
+        # One cached kernel serves every non-zero decimals value: the actual
+        # factor is passed at launch time, so key the cache on a fixed 1.
+        kernel = _cached_make(ntops.kernels.round.premake, input.ndim, decimals=1)
+        kernel(input, factor, 1.0 / factor, out)
+
+    return out
diff --git a/tests/test_round.py b/tests/test_round.py
new file mode 100644
index 0000000..0179f98
--- /dev/null
+++ b/tests/test_round.py
@@ -0,0 +1,18 @@
+import pytest
+import torch
+
+import ntops
+from tests.skippers import skip_if_cuda_not_available
+from tests.utils import generate_arguments
+
+
+@skip_if_cuda_not_available
+@pytest.mark.parametrize("decimals", (0, 2, -1))
+@pytest.mark.parametrize(*generate_arguments())
+def test_round(shape, dtype, device, rtol, atol, decimals):
+    input = torch.randn(shape, dtype=dtype, device=device)
+
+    ninetoothed_output = ntops.torch.round(input, decimals=decimals)
+    reference_output = torch.round(input, decimals=decimals)
+
+    assert torch.allclose(ninetoothed_output, reference_output, rtol=rtol, atol=atol)