From 87018bb446afdc0778b9935ca94049fdf6fc7bb7 Mon Sep 17 00:00:00 2001 From: yes Date: Sun, 5 Apr 2026 14:28:27 -0700 Subject: [PATCH 1/2] Replace direct builtins-side multi_ptr usage with forward-declared helper customization points so `sycl/detail/builtins/builtins.hpp` no longer needs to include the full `multi_ptr` definition. Preserve multi_ptr builtin behavior once multi_ptr.hpp is included, and update include-deps / regression coverage. On a TU including only the builtins header: Host-only Frontend: 601.943 ms -> 572.657 ms (-4.87%) Device-only Frontend: 592.140 ms -> 457.361 ms (-22.76%) multi_ptr.hpp parse time dropped from 18.708 ms to 0 ms on host multi_ptr.hpp parse time dropped from 108.662 ms to 0 ms on device --- .../include/sycl/detail/builtins/builtins.hpp | 64 +++++++++++++---- .../sycl/detail/builtins/math_functions.inc | 15 ++-- sycl/include/sycl/detail/fwd/multi_ptr.hpp | 14 ++++ sycl/include/sycl/multi_ptr.hpp | 15 ++++ .../sycl_khr_includes_math.hpp.cpp | 4 +- .../sycl_khr_includes_usm.hpp.cpp | 6 +- .../builtins_multi_ptr_include_order.cpp | 68 +++++++++++++++++++ 7 files changed, 156 insertions(+), 30 deletions(-) create mode 100644 sycl/test/regression/builtins_multi_ptr_include_order.cpp diff --git a/sycl/include/sycl/detail/builtins/builtins.hpp b/sycl/include/sycl/detail/builtins/builtins.hpp index d69c0b93ce7c1..4d0c09bedd6f4 100644 --- a/sycl/include/sycl/detail/builtins/builtins.hpp +++ b/sycl/include/sycl/detail/builtins/builtins.hpp @@ -64,11 +64,11 @@ #pragma once #include +#include #include #include #include #include -#include #include namespace sycl { @@ -83,19 +83,6 @@ template struct use_fast_math : std::false_type {}; #endif template constexpr bool use_fast_math_v = use_fast_math::value; -// Utility trait for getting the decoration of a multi_ptr. 
-template struct get_multi_ptr_decoration; -template -struct get_multi_ptr_decoration< - multi_ptr> { - static constexpr access::decorated value = DecorateAddress; -}; - -template -constexpr access::decorated get_multi_ptr_decoration_v = - get_multi_ptr_decoration::value; - // Utility trait for checking if a multi_ptr has a "writable" address space, // i.e. global, local, private or generic. template struct has_writeable_addr_space : std::false_type {}; @@ -110,6 +97,55 @@ struct has_writeable_addr_space> template constexpr bool has_writeable_addr_space_v = has_writeable_addr_space::value; +enum class builtin_ptr_kind { raw, multi_ptr }; + +template +using builtin_ptr_kind_tag_t = std::integral_constant< + builtin_ptr_kind, + is_multi_ptr_v>> + ? builtin_ptr_kind::multi_ptr + : builtin_ptr_kind::raw>; + +template +decltype(auto) +builtin_raw_ptr(PtrTy &&Ptr, + std::integral_constant) { + return std::forward(Ptr); +} + +template +auto builtin_raw_ptr(PtrTy &&Ptr, + std::integral_constant) { + return get_raw_pointer(std::forward(Ptr)); +} + +template auto builtin_raw_ptr(PtrTy &&Ptr) { + return builtin_raw_ptr(std::forward(Ptr), + builtin_ptr_kind_tag_t{}); +} + +template +decltype(auto) +builtin_element_ptr(PtrTy &&Ptr, + std::integral_constant) { + return &(*std::forward(Ptr))[0]; +} + +template +auto builtin_element_ptr(PtrTy &&Ptr, + std::integral_constant) { + return detail::builtin_element_ptr(std::forward(Ptr)); +} + +template auto builtin_element_ptr(PtrTy &&Ptr) { + return builtin_element_ptr(std::forward(Ptr), + builtin_ptr_kind_tag_t{}); +} + // Utility trait for changing the element type of a type T. If T is a scalar, // the new type replaces T completely. 
template diff --git a/sycl/include/sycl/detail/builtins/math_functions.inc b/sycl/include/sycl/detail/builtins/math_functions.inc index 8387fe09e6b69..62f5484517ba6 100644 --- a/sycl/include/sycl/detail/builtins/math_functions.inc +++ b/sycl/include/sycl/detail/builtins/math_functions.inc @@ -223,13 +223,7 @@ auto builtin_delegate_ptr_impl(FuncTy F, PtrTy p, Ts... xs) { // TODO: Optimize for sizes. Make not to violate ANSI-aliasing rules for the // pointer argument. - auto p0 = [&]() { - if constexpr (is_multi_ptr_v) - return address_space_cast>(&(*p)[0]); - else - return &(*p)[0]; - }(); + auto p0 = builtin_element_ptr(p); constexpr auto N = T0::size(); if constexpr (N <= 16) @@ -314,7 +308,8 @@ using builtin_last_raw_intptr_t = PtrTy p) { \ if constexpr (is_multi_ptr_v) { \ /* TODO: Can't really create multi_ptr on host... */ \ - return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), p.get_raw()); \ + return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), \ + builtin_raw_ptr(p)); \ } else { \ return builtin_delegate_ptr_impl( \ [](auto... xs) { return NAME##_impl(xs...); }, p, \ @@ -355,7 +350,7 @@ template auto modf_impl(T0 &x, T1 &&y) { if constexpr (is_multi_ptr_v>) { // TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on // host. - return modf_impl(x, y.get_raw()); + return modf_impl(x, builtin_raw_ptr(std::forward(y))); } else { return builtin_delegate_ptr_impl( [](auto x, auto y) { return modf_impl(x, y); }, y, @@ -433,7 +428,7 @@ template auto sincos_impl(T0 &x, T1 &&y) { if constexpr (is_multi_ptr_v>) { // TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on // host. - return sincos_impl(x, y.get_raw()); + return sincos_impl(x, builtin_raw_ptr(std::forward(y))); } else { return builtin_delegate_ptr_impl( [](auto... 
xs) { return sincos_impl(xs...); }, y, diff --git a/sycl/include/sycl/detail/fwd/multi_ptr.hpp b/sycl/include/sycl/detail/fwd/multi_ptr.hpp index c3718463dc070..202116d8c7999 100644 --- a/sycl/include/sycl/detail/fwd/multi_ptr.hpp +++ b/sycl/include/sycl/detail/fwd/multi_ptr.hpp @@ -10,6 +10,8 @@ #include +#include + namespace sycl { inline namespace _V1 { // Forward declaration @@ -20,5 +22,17 @@ template multi_ptr address_space_cast(ElementType *pointer); + +namespace detail { +template +std::add_pointer_t +get_raw_pointer(multi_ptr Ptr); + +template +auto builtin_element_ptr( + multi_ptr Ptr); +} // namespace detail } // namespace _V1 } // namespace sycl diff --git a/sycl/include/sycl/multi_ptr.hpp b/sycl/include/sycl/multi_ptr.hpp index 3106bd3cc1a6c..ade21a5a6cd9b 100644 --- a/sycl/include/sycl/multi_ptr.hpp +++ b/sycl/include/sycl/multi_ptr.hpp @@ -1377,6 +1377,21 @@ address_space_cast(ElementType *pointer) { pointer)); } +namespace detail { +template +std::add_pointer_t +get_raw_pointer(multi_ptr Ptr) { + return Ptr.get_raw(); +} + +template +auto builtin_element_ptr(multi_ptr Ptr) { + return address_space_cast(&(*Ptr)[0]); +} +} // namespace detail + template < typename ElementType, access::address_space Space, access::decorated DecorateAddress = access::decorated::legacy, diff --git a/sycl/test/include_deps/sycl_khr_includes_math.hpp.cpp b/sycl/test/include_deps/sycl_khr_includes_math.hpp.cpp index 6b0b4af4e44f1..7ff05121b1979 100644 --- a/sycl/test/include_deps/sycl_khr_includes_math.hpp.cpp +++ b/sycl/test/include_deps/sycl_khr_includes_math.hpp.cpp @@ -17,9 +17,9 @@ // CHECK-NEXT: detail/export.hpp // CHECK-NEXT: memory_enums.hpp // CHECK-NEXT: __spirv/spirv_vars.hpp +// CHECK-NEXT: detail/fwd/multi_ptr.hpp // CHECK-NEXT: detail/type_traits.hpp // CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp -// CHECK-NEXT: detail/fwd/multi_ptr.hpp // CHECK-NEXT: detail/vector_convert.hpp // CHECK-NEXT: detail/generic_type_traits.hpp // CHECK-NEXT: aliases.hpp 
@@ -37,8 +37,6 @@ // CHECK-NEXT: detail/common.hpp // CHECK-NEXT: detail/fwd/accessor.hpp // CHECK-NEXT: marray.hpp -// CHECK-NEXT: multi_ptr.hpp -// CHECK-NEXT: detail/address_space_cast.hpp // CHECK-NEXT: detail/builtins/common_functions.inc // CHECK-NEXT: detail/builtins/helper_macros.hpp // CHECK-NEXT: detail/builtins/geometric_functions.inc diff --git a/sycl/test/include_deps/sycl_khr_includes_usm.hpp.cpp b/sycl/test/include_deps/sycl_khr_includes_usm.hpp.cpp index e26c64af65dfb..c9dce9a912558 100644 --- a/sycl/test/include_deps/sycl_khr_includes_usm.hpp.cpp +++ b/sycl/test/include_deps/sycl_khr_includes_usm.hpp.cpp @@ -18,9 +18,9 @@ // CHECK-NEXT: detail/export.hpp // CHECK-NEXT: memory_enums.hpp // CHECK-NEXT: __spirv/spirv_vars.hpp +// CHECK-NEXT: detail/fwd/multi_ptr.hpp // CHECK-NEXT: detail/type_traits.hpp // CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp -// CHECK-NEXT: detail/fwd/multi_ptr.hpp // CHECK-NEXT: detail/vector_convert.hpp // CHECK-NEXT: detail/generic_type_traits.hpp // CHECK-NEXT: aliases.hpp @@ -38,8 +38,6 @@ // CHECK-NEXT: detail/common.hpp // CHECK-NEXT: detail/fwd/accessor.hpp // CHECK-NEXT: marray.hpp -// CHECK-NEXT: multi_ptr.hpp -// CHECK-NEXT: detail/address_space_cast.hpp // CHECK-NEXT: detail/builtins/common_functions.inc // CHECK-NEXT: detail/builtins/helper_macros.hpp // CHECK-NEXT: detail/builtins/geometric_functions.inc @@ -102,6 +100,8 @@ // CHECK-NEXT: ext/oneapi/accessor_property_list.hpp // CHECK-NEXT: detail/accessor_iterator.hpp // CHECK-NEXT: detail/handler_proxy.hpp +// CHECK-NEXT: multi_ptr.hpp +// CHECK-NEXT: detail/address_space_cast.hpp // CHECK-NEXT: pointers.hpp // CHECK-NEXT: properties/accessor_properties.hpp // CHECK-NEXT: properties/runtime_accessor_properties.def diff --git a/sycl/test/regression/builtins_multi_ptr_include_order.cpp b/sycl/test/regression/builtins_multi_ptr_include_order.cpp new file mode 100644 index 0000000000000..b8f2033958e76 --- /dev/null +++ 
b/sycl/test/regression/builtins_multi_ptr_include_order.cpp @@ -0,0 +1,68 @@ +// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_BUILTINS_ONLY +// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_BUILTINS_FIRST +// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_MULTI_PTR_FIRST + +// Regression coverage for builtins/multi_ptr decoupling. +// We want to preserve these behaviors: +// 1. compiles without including . +// 2. Including builtins before multi_ptr still allows later multi_ptr +// instantiation for scalar pointer builtins. +// 3. Including builtins before multi_ptr still allows later multi_ptr +// instantiation for vector pointer builtins. +// 4. Including multi_ptr before builtins also works for those builtin calls. + +#if defined(TEST_BUILTINS_ONLY) +#include + +int main() { + auto Value = sycl::fmin(1.0f, 2.0f); + (void)Value; + return 0; +} + +#elif defined(TEST_BUILTINS_FIRST) +#include +#include + +SYCL_EXTERNAL void testScalar( + sycl::multi_ptr + Ptr) { + (void)sycl::modf(1.0f, Ptr); + (void)sycl::sincos(1.0f, Ptr); +} + +SYCL_EXTERNAL void testVector( + sycl::multi_ptr, + sycl::access::address_space::global_space, + sycl::access::decorated::no> + Ptr) { + sycl::vec Value{1.0f, 2.0f}; + (void)sycl::fract(Value, Ptr); +} + +int main() { return 0; } + +#elif defined(TEST_MULTI_PTR_FIRST) +#include +#include + +SYCL_EXTERNAL void testScalar( + sycl::multi_ptr + Ptr) { + (void)sycl::modf(1.0f, Ptr); + (void)sycl::sincos(1.0f, Ptr); +} + +SYCL_EXTERNAL void testVector( + sycl::multi_ptr, + sycl::access::address_space::global_space, + sycl::access::decorated::no> + Ptr) { + sycl::vec Value{1.0f, 2.0f}; + (void)sycl::fract(Value, Ptr); +} + +int main() { return 0; } +#endif \ No newline at end of file From ba2a3567b29d3ba23603793ba1debcea074cf0f5 Mon Sep 17 00:00:00 2001 From: yes Date: Sun, 5 Apr 2026 15:32:32 -0700 Subject: [PATCH 2/2] Stop pulling 
`sycl/detail/vector_convert.hpp` into `detail/builtins/builtins.hpp`. This keeps the relational builtins path self-contained by introducing a narrow relational_mask_widen helper for relation-mask widening, and replaces the old transitive dependencies with the direct headers that builtins actually needs (generic_type_traits.hpp, half_type.hpp). It also adds the missing direct exception.hpp include for ext/oneapi/bf16_storage_builtins.hpp and updates the affected include-deps / IR checks. Measured with measure_builtin.cpp and -ftime-trace against the base branch: host Total ExecuteCompiler: 604.853 ms -> 533.452 ms (-71.401 ms, -11.80%) host Total Frontend: 600.196 ms -> 529.366 ms (-70.830 ms, -11.80%) device Total ExecuteCompiler: 598.557 ms -> 420.885 ms (-177.672 ms, -29.68%) device Total Frontend: 593.798 ms -> 416.060 ms (-177.738 ms, -29.93%) changes --- .../include/sycl/detail/builtins/builtins.hpp | 46 +- .../sycl/detail/builtins/common_functions.inc | 14 +- .../detail/builtins/geometric_functions.inc | 2 +- .../sycl/detail/builtins/helper_macros.hpp | 10 +- .../detail/builtins/integer_functions.inc | 6 +- .../sycl/detail/builtins/math_functions.inc | 10 +- .../detail/builtins/relational_functions.inc | 4 +- .../detail/type_traits/vec_marray_traits.hpp | 21 + sycl/include/sycl/detail/vector_base.hpp | 147 ++ sycl/include/sycl/detail/vector_core.hpp | 324 ++++ sycl/include/sycl/detail/vector_swizzle.hpp | 294 ++++ .../include/sycl/detail/vector_swizzle_op.hpp | 870 +++++++++ .../sycl/ext/intel/math/imf_half_trivial.hpp | 1 + .../sycl/ext/oneapi/bf16_storage_builtins.hpp | 1 + sycl/include/sycl/vector.hpp | 1547 +---------------- sycl/source/builtins/host_helper_macros.hpp | 2 + .../vector/bf16_builtins_new_vec.cpp | 8 +- .../vector/bf16_builtins_old_vec.cpp | 8 +- ...sycl_khr_includes_group_algorithms.hpp.cpp | 4 + .../sycl_khr_includes_groups.hpp.cpp | 4 + .../sycl_khr_includes_images.hpp.cpp | 4 + .../sycl_khr_includes_interop_handle.hpp.cpp | 4 + 
.../sycl_khr_includes_math.hpp.cpp | 12 +- .../sycl_khr_includes_reduction.hpp.cpp | 11 +- .../sycl_khr_includes_stream.hpp.cpp | 4 - .../sycl_khr_includes_usm.hpp.cpp | 14 +- .../sycl_khr_includes_vec.hpp.cpp | 12 +- .../builtins_multi_ptr_include_order.cpp | 2 + 28 files changed, 1780 insertions(+), 1606 deletions(-) create mode 100644 sycl/include/sycl/detail/vector_base.hpp create mode 100644 sycl/include/sycl/detail/vector_core.hpp create mode 100644 sycl/include/sycl/detail/vector_swizzle.hpp create mode 100644 sycl/include/sycl/detail/vector_swizzle_op.hpp diff --git a/sycl/include/sycl/detail/builtins/builtins.hpp b/sycl/include/sycl/detail/builtins/builtins.hpp index 4d0c09bedd6f4..b2fa267efcd85 100644 --- a/sycl/include/sycl/detail/builtins/builtins.hpp +++ b/sycl/include/sycl/detail/builtins/builtins.hpp @@ -65,11 +65,14 @@ #include #include +#include +#include #include #include -#include +#include #include -#include + +#include namespace sycl { inline namespace _V1 { @@ -83,6 +86,19 @@ template struct use_fast_math : std::false_type {}; #endif template constexpr bool use_fast_math_v = use_fast_math::value; +// Utility trait for getting the decoration of a multi_ptr. +template struct get_multi_ptr_decoration; +template +struct get_multi_ptr_decoration< + multi_ptr> { + static constexpr access::decorated value = DecorateAddress; +}; + +template +constexpr access::decorated get_multi_ptr_decoration_v = + get_multi_ptr_decoration::value; + // Utility trait for checking if a multi_ptr has a "writable" address space, // i.e. global, local, private or generic. template struct has_writeable_addr_space : std::false_type {}; @@ -197,6 +213,28 @@ template marray to_marray(vec X) { return Marray; } +// Relation builtins widen signed-char masks to the required integer element +// type. Keep that conversion local here so builtins.hpp does not need to pull +// in vector_convert.hpp just for vec::convert. 
+template +vec relational_mask_widen(vec X) { + static_assert(is_scalar_arithmetic_v); + +#ifdef __SYCL_DEVICE_ONLY__ + if constexpr (N > 1) { + using src_vector_t = signed char __attribute__((ext_vector_type(N))); + using dst_vector_t = NewElemT __attribute__((ext_vector_type(N))); + auto OpenCLVec = bit_cast(X); + return bit_cast>( + __builtin_convertvector(OpenCLVec, dst_vector_t)); + } +#endif + + vec Result{}; + loop([&](auto idx) { Result[idx] = static_cast(X[idx]); }); + return Result; +} + namespace builtins { #ifdef __SYCL_DEVICE_ONLY__ template auto convert_arg(T &&x) { @@ -212,7 +250,7 @@ template auto convert_arg(T &&x) { __attribute__((ext_vector_type(N)))>; return bit_cast(x); } else if constexpr (is_swizzle_v) { - return convert_arg(simplify_if_swizzle_t{x}); + return convert_arg(materialize_if_swizzle(std::forward(x))); } else { static_assert(is_scalar_arithmetic_v || is_multi_ptr_v || std::is_pointer_v || @@ -262,7 +300,7 @@ auto builtin_default_host_impl(FuncTy F, const Ts &...x) { if constexpr ((... 
|| is_marray_v)) { return builtin_marray_impl(F, x...); } else { - return F(simplify_if_swizzle_t{x}...); + return F(materialize_if_swizzle(x)...); } } diff --git a/sycl/include/sycl/detail/builtins/common_functions.inc b/sycl/include/sycl/detail/builtins/common_functions.inc index 8c8e3e46277a9..40f84da4980b9 100644 --- a/sycl/include/sycl/detail/builtins/common_functions.inc +++ b/sycl/include/sycl/detail/builtins/common_functions.inc @@ -34,8 +34,8 @@ BUILTIN_COMMON(THREE_ARGS, mix, __spirv_ocl_mix) template detail::builtin_enable_common_non_scalar_t mix(T0 x, T1 y, detail::get_elem_type_t z) { - return mix(detail::simplify_if_swizzle_t{x}, - detail::simplify_if_swizzle_t{y}, + return mix(detail::materialize_if_swizzle(x), + detail::materialize_if_swizzle(y), detail::simplify_if_swizzle_t{z}); } @@ -44,7 +44,7 @@ template detail::builtin_enable_common_non_scalar_t step(detail::get_elem_type_t x, T y) { return step(detail::simplify_if_swizzle_t{x}, - detail::simplify_if_swizzle_t{y}); + detail::materialize_if_swizzle(y)); } BUILTIN_COMMON(THREE_ARGS, smoothstep, __spirv_ocl_smoothstep) @@ -53,14 +53,14 @@ detail::builtin_enable_common_non_scalar_t smoothstep(detail::get_elem_type_t x, detail::get_elem_type_t y, T z) { return smoothstep(detail::simplify_if_swizzle_t{x}, detail::simplify_if_swizzle_t{y}, - detail::simplify_if_swizzle_t{z}); + detail::materialize_if_swizzle(z)); } BUILTIN_COMMON(TWO_ARGS, max, __spirv_ocl_fmax_common) template detail::builtin_enable_common_non_scalar_t(max)( T x, detail::get_elem_type_t y) { - return (max)(detail::simplify_if_swizzle_t{x}, + return (max)(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}); } @@ -68,7 +68,7 @@ BUILTIN_COMMON(TWO_ARGS, min, __spirv_ocl_fmin_common) template detail::builtin_enable_common_non_scalar_t(min)( T x, detail::get_elem_type_t y) { - return (min)(detail::simplify_if_swizzle_t{x}, + return (min)(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}); } @@ -76,7 
+76,7 @@ BUILTIN_COMMON(THREE_ARGS, clamp, __spirv_ocl_fclamp) template detail::builtin_enable_common_non_scalar_t clamp(T x, detail::get_elem_type_t y, detail::get_elem_type_t z) { - return clamp(detail::simplify_if_swizzle_t{x}, + return clamp(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}, detail::simplify_if_swizzle_t{z}); } diff --git a/sycl/include/sycl/detail/builtins/geometric_functions.inc b/sycl/include/sycl/detail/builtins/geometric_functions.inc index 1ee7d34fded87..fc018beb085bf 100644 --- a/sycl/include/sycl/detail/builtins/geometric_functions.inc +++ b/sycl/include/sycl/detail/builtins/geometric_functions.inc @@ -44,7 +44,7 @@ auto builtin_delegate_geo_impl(FuncTy F, const Ts &...x) { else return ret; } else { - return F(simplify_if_swizzle_t{x}...); + return F(materialize_if_swizzle(x)...); } } } // namespace detail diff --git a/sycl/include/sycl/detail/builtins/helper_macros.hpp b/sycl/include/sycl/detail/builtins/helper_macros.hpp index 3b320d4e0772a..d1936027b8baa 100644 --- a/sycl/include/sycl/detail/builtins/helper_macros.hpp +++ b/sycl/include/sycl/detail/builtins/helper_macros.hpp @@ -122,14 +122,12 @@ #define THREE_ARGS_ARG x, y, z #define ONE_ARG_SIMPLIFIED_ARG \ - simplify_if_swizzle_t { x } + detail::materialize_if_swizzle(x) #define TWO_ARGS_SIMPLIFIED_ARG \ - simplify_if_swizzle_t{x}, simplify_if_swizzle_t { y } + detail::materialize_if_swizzle(x), detail::materialize_if_swizzle(y) #define THREE_ARGS_SIMPLIFIED_ARG \ - simplify_if_swizzle_t{x}, simplify_if_swizzle_t{y}, \ - simplify_if_swizzle_t { \ - z \ - } + detail::materialize_if_swizzle(x), detail::materialize_if_swizzle(y), \ + detail::materialize_if_swizzle(z) #define TWO_ARGS_ARG_ROTATED y, x #define THREE_ARGS_ARG_ROTATED z, x, y diff --git a/sycl/include/sycl/detail/builtins/integer_functions.inc b/sycl/include/sycl/detail/builtins/integer_functions.inc index 65b91160ac260..60fe11d04c386 100644 --- a/sycl/include/sycl/detail/builtins/integer_functions.inc 
+++ b/sycl/include/sycl/detail/builtins/integer_functions.inc @@ -105,7 +105,7 @@ BUILTIN_GENINT_SU(TWO_ARGS, max) template detail::builtin_enable_integer_non_scalar_t(max)( T x, detail::get_elem_type_t y) { - return (max)(detail::simplify_if_swizzle_t{x}, + return (max)(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}); } @@ -113,7 +113,7 @@ BUILTIN_GENINT_SU(TWO_ARGS, min) template detail::builtin_enable_integer_non_scalar_t(min)( T x, detail::get_elem_type_t y) { - return (min)(detail::simplify_if_swizzle_t{x}, + return (min)(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}); } @@ -121,7 +121,7 @@ BUILTIN_GENINT_SU(THREE_ARGS, clamp) template detail::builtin_enable_integer_non_scalar_t clamp(T x, detail::get_elem_type_t y, detail::get_elem_type_t z) { - return clamp(detail::simplify_if_swizzle_t{x}, + return clamp(detail::materialize_if_swizzle(x), detail::simplify_if_swizzle_t{y}, detail::simplify_if_swizzle_t{z}); } diff --git a/sycl/include/sycl/detail/builtins/math_functions.inc b/sycl/include/sycl/detail/builtins/math_functions.inc index 62f5484517ba6..301c09eb6cf66 100644 --- a/sycl/include/sycl/detail/builtins/math_functions.inc +++ b/sycl/include/sycl/detail/builtins/math_functions.inc @@ -167,7 +167,7 @@ BUILTIN_GENF(THREE_ARGS, mad) BUILTIN_GENF(TWO_ARGS, NAME) \ template \ detail::builtin_enable_math_t NAME(T x, detail::get_elem_type_t y) { \ - return NAME(detail::simplify_if_swizzle_t{x}, \ + return NAME(detail::materialize_if_swizzle(x), \ detail::simplify_if_swizzle_t{y}); \ } @@ -331,7 +331,7 @@ BUILTIN_LAST_INTPTR(THREE_ARGS, remquo) #ifndef __SYCL_DEVICE_ONLY__ namespace detail { template auto fract_impl(T0 &x, T1 &y) { - auto flr = floor(simplify_if_swizzle_t{x}); + auto flr = floor(materialize_if_swizzle(x)); *y = flr; return fmin(x - flr, nextafter(simplify_if_swizzle_t{1.0}, simplify_if_swizzle_t{0.0})); @@ -354,7 +354,7 @@ template auto modf_impl(T0 &x, T1 &&y) { } else { return 
builtin_delegate_ptr_impl( [](auto x, auto y) { return modf_impl(x, y); }, y, - simplify_if_swizzle_t{x}); + materialize_if_swizzle(x)); } } } // namespace detail @@ -390,7 +390,7 @@ BUILTIN_MATH_LAST_INT(rootn) BUILTIN_MATH_LAST_INT(ldexp) template detail::builtin_enable_math_t ldexp(T x, int y) { return ldexp( - detail::simplify_if_swizzle_t{x}, + detail::materialize_if_swizzle(x), detail::change_elements_t>{y}); } @@ -432,7 +432,7 @@ template auto sincos_impl(T0 &x, T1 &&y) { } else { return builtin_delegate_ptr_impl( [](auto... xs) { return sincos_impl(xs...); }, y, - simplify_if_swizzle_t{x}); + materialize_if_swizzle(x)); } } #endif diff --git a/sycl/include/sycl/detail/builtins/relational_functions.inc b/sycl/include/sycl/detail/builtins/relational_functions.inc index fea8f59e2993f..a76534e6230e8 100644 --- a/sycl/include/sycl/detail/builtins/relational_functions.inc +++ b/sycl/include/sycl/detail/builtins/relational_functions.inc @@ -66,7 +66,7 @@ auto builtin_device_rel_impl(FuncTy F, const Ts &...xs) { auto tmp = bit_cast::value>>(ret); using res_elem_type = fixed_width_signed)>; static_assert(is_scalar_arithmetic_v); - return tmp.template convert(); + return relational_mask_widen(tmp); } else if constexpr (std::is_same_v) { return bool{F(builtins::convert_arg(xs)...)}; } else { @@ -80,7 +80,7 @@ template auto builtin_delegate_rel_impl(FuncTy F, const Ts &...x) { using T = typename first_type::type; if constexpr ((... || is_swizzle_v)) { - return F(simplify_if_swizzle_t{x}...); + return F(materialize_if_swizzle(x)...); } else if constexpr (is_vec_v) { // TODO: using Res{} to avoid Werror. Not sure if ok. 
vec, T::size()> Res{}; diff --git a/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp b/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp index 90779a0463d86..2b4f589baea4c 100644 --- a/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp +++ b/sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp @@ -10,6 +10,7 @@ #include #include +#include #include @@ -59,6 +60,26 @@ struct simplify_if_swizzle using simplify_if_swizzle_t = typename simplify_if_swizzle::type; +template constexpr decltype(auto) materialize_if_swizzle(T &&X) { + return std::forward(X); +} + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL +template class OperationCurrentT, int... Indexes> +simplify_if_swizzle_t> +materialize_if_swizzle( + const SwizzleOp &X); +#else +template +simplify_if_swizzle_t> +materialize_if_swizzle(const detail::hide_swizzle_from_adl::Swizzle< + IsConstVec, DataT, VecSize, Indexes...> &X); +#endif + // --------- is_* traits ------------------ // template struct is_vec : std::false_type {}; template struct is_vec> : std::true_type {}; diff --git a/sycl/include/sycl/detail/vector_base.hpp b/sycl/include/sycl/detail/vector_base.hpp new file mode 100644 index 0000000000000..865deffa742c0 --- /dev/null +++ b/sycl/include/sycl/detail/vector_base.hpp @@ -0,0 +1,147 @@ +//==----------------- vector_base.hpp - vec storage helpers ---------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include + +#ifndef __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE +#define __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE !__SYCL_USE_LIBSYCL8_VEC_IMPL +#endif + +namespace sycl { +inline namespace _V1 { +namespace detail { + +template class vec_base_test; + +template class vec_base { + static constexpr size_t AdjustedNum = (NumElements == 3) ? 4 : NumElements; + using DataType = std::conditional_t< +#if __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE + true, +#else + sizeof(std::array) == sizeof(DataT[AdjustedNum]) && + alignof(std::array) == + alignof(DataT[AdjustedNum]), +#endif + DataT[AdjustedNum], std::array>; + + template friend class detail::vec_base_test; + +protected: + static constexpr int alignment = (std::min)((size_t)64, sizeof(DataType)); + alignas(alignment) DataType m_Data; + + template + constexpr vec_base(const DataT &Val, std::index_sequence) + : m_Data{((void)Is, Val)...} {} + + template + constexpr vec_base(const std::array &Arr, + std::index_sequence) + : m_Data{Arr[Is]...} {} + + template + static constexpr bool AllowArgTypeInVariadicCtor = []() constexpr { + if constexpr (std::is_convertible_v) { + return true; + } else if constexpr (is_vec_or_swizzle_v) { + if constexpr (CtorArgTy::size() == 1 && + std::is_convertible_v) { + return true; + } + return std::is_same_v; + } else { + return false; + } + }(); + + template static constexpr int num_elements() { + if constexpr (is_vec_or_swizzle_v) + return T::size(); + else + return 1; + } + + template class FlattenVecArg { + template + static constexpr auto helper(const T &V, std::index_sequence) { +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + if constexpr (is_swizzle_v) + return std::array{static_cast(V.getValue(Is))...}; + else +#endif + return std::array{static_cast(V[Is])...}; + } + + public: + constexpr auto operator()(const T &A) const { + if 
constexpr (is_vec_or_swizzle_v) { + return helper(A, std::make_index_sequence()); + } else { + return std::array{static_cast(A)}; + } + } + }; + + template + using VecArgArrayCreator = ArrayCreator; + +public: + constexpr vec_base() = default; + constexpr vec_base(const vec_base &) = default; + constexpr vec_base(vec_base &&) = default; + constexpr vec_base &operator=(const vec_base &) = default; + constexpr vec_base &operator=(vec_base &&) = default; + + explicit constexpr vec_base(const DataT &arg) + : vec_base(arg, std::make_index_sequence()) {} + + template && ...)) && + ((num_elements() + ...)) == NumElements>> + constexpr vec_base(const argTN &...args) + : vec_base{VecArgArrayCreator::Create(args...), + std::make_index_sequence()} {} +}; + +#if !__SYCL_USE_LIBSYCL8_VEC_IMPL +template class vec_base { + using DataType = std::conditional_t< +#if __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE + true, +#else + sizeof(std::array) == sizeof(DataT[1]) && + alignof(std::array) == alignof(DataT[1]), +#endif + DataT[1], std::array>; + +protected: + static constexpr int alignment = (std::min)((size_t)64, sizeof(DataType)); + alignas(alignment) DataType m_Data; + +public: + constexpr vec_base() = default; + constexpr vec_base(const vec_base &) = default; + constexpr vec_base(vec_base &&) = default; + constexpr vec_base &operator=(const vec_base &) = default; + constexpr vec_base &operator=(vec_base &&) = default; + + constexpr vec_base(const DataT &arg) : m_Data{{arg}} {} +}; +#endif + +} // namespace detail +} // namespace _V1 +} // namespace sycl \ No newline at end of file diff --git a/sycl/include/sycl/detail/vector_core.hpp b/sycl/include/sycl/detail/vector_core.hpp new file mode 100644 index 0000000000000..bc7e0919d66bb --- /dev/null +++ b/sycl/include/sycl/detail/vector_core.hpp @@ -0,0 +1,324 @@ +//==---------------- vector_core.hpp - sycl::vec class --------------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +// Check if Clang's ext_vector_type attribute is available. Host compiler +// may not be Clang, and Clang may not be built with the extension. +#ifdef __clang__ +#ifndef __has_extension +#define __has_extension(x) 0 +#endif +#ifndef __HAS_EXT_VECTOR_TYPE__ +#if __has_extension(attribute_ext_vector_type) +#define __HAS_EXT_VECTOR_TYPE__ +#endif +#endif +#endif // __clang__ + +// See vec::DataType definitions for more details +#ifndef __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE +#define __SYCL_USE_PLAIN_ARRAY_AS_VEC_STORAGE !__SYCL_USE_LIBSYCL8_VEC_IMPL +#endif + +#if !defined(__HAS_EXT_VECTOR_TYPE__) && defined(__SYCL_DEVICE_ONLY__) +#error "SYCL device compiler is built without ext_vector_type support" +#endif + +#include + +#include + +#include +#include +#include +#include + +#include + +namespace sycl { +inline namespace _V1 { + +///////////////////////// class sycl::vec ///////////////////////// +// Provides a cross-platform vector class template that works efficiently on +// SYCL devices as well as in host C++ code. 
+template +class __SYCL_EBO vec : +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + public detail::vec_arith, +#else + public detail::VecOperators>::Combined, +#endif + public detail::ApplyIf< + NumElements == 1, + detail::ScalarConversionOperatorsMixIn>>, + public detail::NamedSwizzlesMixinBoth>, + // Keep it last to simplify ABI layout test: + public detail::vec_base { + static_assert(std::is_same_v>, + "DataT must be cv-unqualified"); + + static_assert(detail::is_allowed_vec_size_v, + "Invalid number of elements for sycl::vec: only 1, 2, 3, 4, 8 " + "or 16 are supported"); + static_assert(sizeof(bool) == sizeof(uint8_t), "bool size is not 1 byte"); + + using Base = detail::vec_base; + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL +#ifdef __SYCL_DEVICE_ONLY__ + using element_type_for_vector_t = typename detail::map_type< + DataT, +#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0) + std::byte, /*->*/ std::uint8_t, // +#endif + bool, /*->*/ std::uint8_t, // + sycl::half, /*->*/ sycl::detail::half_impl::StorageT, // + sycl::ext::oneapi::bfloat16, /*->*/ uint16_t, // + char, /*->*/ detail::ConvertToOpenCLType_t, // + DataT, /*->*/ DataT // + >::type; + +public: + // Type used for passing sycl::vec to SPIRV builtins. + // We can not use ext_vector_type(1) as it's not supported by SPIRV + // plugins (CTS fails). + using vector_t = + typename std::conditional_t; + + // Make it a template to avoid ambiguity with `vec(const DataT &)` when + // `vector_t` is the same as `DataT`. Not that the other ctor isn't a template + // so we don't even need a smart `enable_if` condition here, the mere fact of + // this being a template makes the other ctor preferred. + // For vectors of length 3, make sure to only copy 3 elements, not 4, to work + // around code generation issues, see LLVM #144454. 
+ template < + typename vector_t_ = vector_t, + typename = typename std::enable_if_t>> + constexpr vec(vector_t_ openclVector) { + sycl::detail::memcpy_no_adl(&this->m_Data, &openclVector, + NumElements * + sizeof(element_type_for_vector_t)); + } + + /* @SYCL2020 + * Available only when: compiled for the device. + * Converts this SYCL vec instance to the underlying backend-native vector + * type defined by vector_t. + */ + operator vector_t() const { return sycl::bit_cast(this->m_Data); } + +private: +#endif // __SYCL_DEVICE_ONLY__ +#endif + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + template + using Swizzle = + detail::SwizzleOp, detail::GetOp, + detail::GetOp, Indexes...>; + + template + using ConstSwizzle = + detail::SwizzleOp, detail::GetOp, + detail::GetOp, Indexes...>; +#else + template + using Swizzle = + detail::hide_swizzle_from_adl::Swizzle; + + template + using ConstSwizzle = + detail::hide_swizzle_from_adl::Swizzle; +#endif + + // Element type for relational operator return value. + using rel_t = detail::fixed_width_signed; + +public: + // Aliases required by SYCL 2020 to make sycl::vec consistent + // with that of marray and buffer. + using element_type = DataT; + using value_type = DataT; + + using Base::Base; + constexpr vec(const vec &) = default; + constexpr vec(vec &&) = default; + + /****************** Assignment Operators **************/ + constexpr vec &operator=(const vec &) = default; + constexpr vec &operator=(vec &&) = default; + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + // Template required to prevent ambiguous overload with the copy assignment + // when NumElements == 1. The template prevents implicit conversion from + // vec<_, 1> to DataT. + template + typename std::enable_if_t, + vec &> + operator=(const DataT &Rhs) { + *this = vec{Rhs}; + return *this; + } + + // W/o this, things like "vec = vec" doesn't work. 
+ template + typename std::enable_if_t< + !std::is_same_v && std::is_convertible_v, vec &> + operator=(const vec &Rhs) { + *this = Rhs.template as(); + return *this; + } +#else + template + typename std::enable_if_t, vec &> + operator=(const T &Rhs) { + *this = vec{static_cast(Rhs)}; + return *this; + } +#endif + + __SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead") + static constexpr size_t get_count() { return size(); } + static constexpr size_t size() noexcept { return NumElements; } + __SYCL2020_DEPRECATED( + "get_size() is deprecated, please use byte_size() instead") + static constexpr size_t get_size() { return byte_size(); } + static constexpr size_t byte_size() noexcept { return sizeof(Base); } + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL +private: + // getValue should be able to operate on different underlying + // types: enum cl_float#N , builtin vector float#N, builtin type float. + constexpr auto getValue(int Index) const { + using RetType = + typename std::conditional_t, int8_t, +#ifdef __SYCL_DEVICE_ONLY__ + element_type_for_vector_t +#else + DataT +#endif + >; + +#ifdef __SYCL_DEVICE_ONLY__ + if constexpr (std::is_same_v) + return sycl::bit_cast(this->m_Data[Index]); + else +#endif + return static_cast(this->m_Data[Index]); + } + +public: +#endif + + // Out-of-class definition is in `sycl/detail/vector_convert.hpp` + template + vec convert() const; + + template asT as() const { return sycl::bit_cast(*this); } + + template Swizzle swizzle() { +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + return this; +#else + return Swizzle{*this}; +#endif + } + + template + ConstSwizzle swizzle() const { +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + return this; +#else + return ConstSwizzle{*this}; +#endif + } + + const DataT &operator[](int i) const { return this->m_Data[i]; } + + DataT &operator[](int i) { return this->m_Data[i]; } + + template + void load(size_t Offset, multi_ptr Ptr) { + for (int I = 0; I < NumElements; I++) { + this->m_Data[I] = *multi_ptr( + Ptr + 
Offset * NumElements + I); + } + } + template + void load(size_t Offset, multi_ptr Ptr) { + multi_ptr ConstPtr(Ptr); + load(Offset, ConstPtr); + } + template + void + load(size_t Offset, + accessor + Acc) { + multi_ptr::AS, + access::decorated::yes> + MultiPtr(Acc); + load(Offset, MultiPtr); + } + void load(size_t Offset, const DataT *Ptr) { + for (int I = 0; I < NumElements; ++I) + this->m_Data[I] = Ptr[Offset * NumElements + I]; + } + + template + void store(size_t Offset, + multi_ptr Ptr) const { + for (int I = 0; I < NumElements; I++) { + *multi_ptr(Ptr + Offset * NumElements + + I) = this->m_Data[I]; + } + } + template + void + store(size_t Offset, + accessor + Acc) { + multi_ptr::AS, access::decorated::yes> + MultiPtr(Acc); + store(Offset, MultiPtr); + } + void store(size_t Offset, DataT *Ptr) const { + for (int I = 0; I < NumElements; ++I) + Ptr[Offset * NumElements + I] = this->m_Data[I]; + } + + // friends + template class T4, + int... T5> + friend class detail::SwizzleOp; + template friend class __SYCL_EBO vec; +#if __SYCL_USE_LIBSYCL8_VEC_IMPL + // To allow arithmetic operators access private members of vec. + template friend class detail::vec_arith; +#endif +}; +///////////////////////// class sycl::vec ///////////////////////// + +#ifdef __cpp_deduction_guides +// all compilers supporting deduction guides also support fold expressions +template && ...)>> +vec(T, U...) -> vec; +#endif + +} // namespace _V1 +} // namespace sycl \ No newline at end of file diff --git a/sycl/include/sycl/detail/vector_swizzle.hpp b/sycl/include/sycl/detail/vector_swizzle.hpp new file mode 100644 index 0000000000000..dea6b015042d7 --- /dev/null +++ b/sycl/include/sycl/detail/vector_swizzle.hpp @@ -0,0 +1,294 @@ +//==------------- vector_swizzle.hpp - vec/swizzle support ----------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace sycl { + +// TODO: It should be within _V1 namespace, fix in the next ABI breaking +// window. +enum class rounding_mode { automatic = 0, rte = 1, rtz = 2, rtp = 3, rtn = 4 }; + +inline namespace _V1 { +namespace ext::oneapi { +class bfloat16; +} + +struct elem { + static constexpr int x = 0; + static constexpr int y = 1; + static constexpr int z = 2; + static constexpr int w = 3; + static constexpr int r = 0; + static constexpr int g = 1; + static constexpr int b = 2; + static constexpr int a = 3; + static constexpr int s0 = 0; + static constexpr int s1 = 1; + static constexpr int s2 = 2; + static constexpr int s3 = 3; + static constexpr int s4 = 4; + static constexpr int s5 = 5; + static constexpr int s6 = 6; + static constexpr int s7 = 7; + static constexpr int s8 = 8; + static constexpr int s9 = 9; + static constexpr int sA = 10; + static constexpr int sB = 11; + static constexpr int sC = 12; + static constexpr int sD = 13; + static constexpr int sE = 14; + static constexpr int sF = 15; +}; + +namespace detail { +template class OperationCurrentT, int... Indexes> +class SwizzleOp; + +// Special type indicating that SwizzleOp should just read value from vector - +// not trying to perform any operations. Should not be called. +template class GetOp { +public: + using DataT = T; + DataT getValue(size_t) const { return (DataT)0; } + DataT operator()(DataT, DataT) { return (DataT)0; } +}; + +// Templated vs. non-templated conversion operators behave differently when two +// conversions are needed as in the case below: +// +// sycl::vec v; +// (void)static_cast(v); +// +// Make sure the snippet above compiles. That is important because +// +// sycl::vec v; +// if (v.x() == 42) +// ... 
+// +// must go through `v.x()` returning a swizzle, then its `operator==` returning +// vec and we want that code to compile. +template class ScalarConversionOperatorsMixIn { + using element_type = typename from_incomplete::element_type; + +public: + operator element_type() const { + return (*static_cast(this))[0]; + } + +#if !__SYCL_USE_LIBSYCL8_VEC_IMPL + template < + typename T, typename = std::enable_if_t>, + typename = + std::void_t(std::declval()))>> + explicit operator T() const { + return static_cast((*static_cast(this))[0]); + } +#endif +}; + +template +inline constexpr bool is_fundamental_or_half_or_bfloat16 = + std::is_fundamental_v || std::is_same_v, half> || + std::is_same_v, ext::oneapi::bfloat16>; + +#if !__SYCL_USE_LIBSYCL8_VEC_IMPL +template class ConversionToVecMixin { + using vec_ty = typename from_incomplete::result_vec_ty; + +public: + operator vec_ty() const { + auto &self = *static_cast(this); + if constexpr (vec_ty::size() == 1) + // Avoid recursion by explicitly going through `vec(const DataT &)` ctor. + return vec_ty{static_cast(self)}; + else + // Uses `vec`'s variadic ctor. 
+ return vec_ty{self}; + } +}; + +template class SwizzleBase { + using VecT = typename from_incomplete::vec_ty; + +public: + explicit SwizzleBase(VecT &Vec) : Vec(Vec) {} + + const Self &operator=(const Self &) = delete; + +protected: + VecT &Vec; +}; + +template +class SwizzleBase::is_assignable>> { + using VecT = typename from_incomplete::vec_ty; + using ResultVecT = typename from_incomplete::result_vec_ty; + + using DataT = typename from_incomplete::element_type; + static constexpr int N = from_incomplete::size(); + +public: + explicit SwizzleBase(VecT &Vec) : Vec(Vec) {} + + template + void load(size_t offset, + multi_ptr ptr) const { + ResultVecT v; + v.load(offset, ptr); + *static_cast(this) = v; + } + + template + std::enable_if_t + operator=(const detail::hide_swizzle_from_adl::Swizzle< + OtherIsConstVec, DataT, OtherVecSize, OtherIndexes...> &rhs) { + return (*this = static_cast(rhs)); + } + + const Self &operator=(const ResultVecT &rhs) const { + for (int i = 0; i < N; ++i) + (*static_cast(this))[i] = rhs[i]; + + return *static_cast(this); + } + + template && + !is_swizzle_v>> + const Self &operator=(const T &rhs) const { + for (int i = 0; i < N; ++i) + (*static_cast(this))[i] = static_cast(rhs); + + return *static_cast(this); + } + + // Default copy-assignment. Self's implicitly generated copy-assignment uses + // this. + // + // We're templated on "Self", so each swizzle has its own SwizzleBase and the + // following is ok (1-to-1 bidirectional mapping between Self and its + // SwizzleBase instantiation) even if a bit counterintuitive. + const SwizzleBase &operator=(const SwizzleBase &rhs) const { + const Self &self = (*static_cast(this)); + self = static_cast(static_cast(rhs)); + return self; + } + +protected: + VecT &Vec; +}; + +namespace hide_swizzle_from_adl { +// Can't have sycl::vec anywhere in template parameters because that would bring +// its hidden friends into ADL. Put it in a dedicated namespace to avoid +// anything extra via ADL as well. 
+template +class __SYCL_EBO Swizzle + : public SwizzleBase>, + public SwizzleOperators< + Swizzle>::Combined, + public ApplyIf>>, + public ConversionToVecMixin< + Swizzle>, + public NamedSwizzlesMixinBoth< + Swizzle> { + using Base = SwizzleBase>; + + static constexpr int NumElements = sizeof...(Indexes); + using ResultVec = vec; + + // Get underlying vec index for (*this)[idx] access. + static constexpr auto get_vec_idx(int idx) { + int counter = 0; + int result = -1; + ((result = counter++ == idx ? Indexes : result), ...); + return result; + } + +public: + using Base::Base; + using Base::operator=; + + using element_type = DataT; + using value_type = DataT; + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL +#ifdef __SYCL_DEVICE_ONLY__ + using vector_t = typename vec::vector_t; +#endif // __SYCL_DEVICE_ONLY__ +#endif + + Swizzle() = delete; + Swizzle(const Swizzle &) = delete; + + static constexpr size_t byte_size() noexcept { + return ResultVec::byte_size(); + } + static constexpr size_t size() noexcept { return ResultVec::size(); } + + __SYCL2020_DEPRECATED( + "get_size() is deprecated, please use byte_size() instead") + size_t get_size() const { return static_cast(*this).get_size(); } + + __SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead") + size_t get_count() const { + return static_cast(*this).get_count(); + }; + + template + vec convert() const { + return static_cast(*this) + .template convert(); + } + + template asT as() const { + return static_cast(*this).template as(); + } + + template + void store(size_t offset, + multi_ptr ptr) const { + return static_cast(*this).store(offset, ptr); + } + + template auto swizzle() const { + return this->Vec.template swizzle(); + } + + auto &operator[](int index) const { return this->Vec[get_vec_idx(index)]; } +}; + +template +inline simplify_if_swizzle_t> +materialize_if_swizzle(const hide_swizzle_from_adl::Swizzle< + IsConstVec, DataT, VecSize, Indexes...> &X) { + return static_cast>>(X); +} +} // namespace 
hide_swizzle_from_adl +#endif +} // namespace detail +} // namespace _V1 +} // namespace sycl \ No newline at end of file diff --git a/sycl/include/sycl/detail/vector_swizzle_op.hpp b/sycl/include/sycl/detail/vector_swizzle_op.hpp new file mode 100644 index 0000000000000..c4cb28e285fc2 --- /dev/null +++ b/sycl/include/sycl/detail/vector_swizzle_op.hpp @@ -0,0 +1,870 @@ +//==----------- vector_swizzle_op.hpp - libsycl8 swizzle ops --------------==// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#pragma once + +#include + +#include +#include +#include + +namespace sycl { +inline namespace _V1 { + +#if __SYCL_USE_LIBSYCL8_VEC_IMPL +namespace detail { + +// Special type for working SwizzleOp with scalars, stores a scalar and gives +// the scalar at any index. The provided interface is compatible with +// SwizzleOp operations. +template class GetScalarOp { +public: + using DataT = T; + GetScalarOp(DataT Data) : m_Data(Data) {} + DataT getValue(size_t) const { return m_Data; } + +private: + DataT m_Data; +}; +template using rel_t = detail::fixed_width_signed; + +template struct EqualTo { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs == Rhs) ? -1 : 0; + } +}; + +template struct NotEqualTo { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs != Rhs) ? -1 : 0; + } +}; + +template struct GreaterEqualTo { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs >= Rhs) ? -1 : 0; + } +}; + +template struct LessEqualTo { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs <= Rhs) ? -1 : 0; + } +}; + +template struct GreaterThan { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs > Rhs) ? 
-1 : 0; + } +}; + +template struct LessThan { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs < Rhs) ? -1 : 0; + } +}; + +template struct LogicalAnd { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs && Rhs) ? -1 : 0; + } +}; + +template struct LogicalOr { + constexpr rel_t operator()(const T &Lhs, const T &Rhs) const { + return (Lhs || Rhs) ? -1 : 0; + } +}; + +template struct RShift { + constexpr T operator()(const T &Lhs, const T &Rhs) const { + return Lhs >> Rhs; + } +}; + +template struct LShift { + constexpr T operator()(const T &Lhs, const T &Rhs) const { + return Lhs << Rhs; + } +}; + +///////////////////////// class SwizzleOp ///////////////////////// +// SwizzleOP represents expression templates that operate on vec. +// Actual computation performed on conversion or assignment operators. +template class OperationCurrentT, int... Indexes> +class SwizzleOp : public detail::NamedSwizzlesMixinBoth< + SwizzleOp, + sizeof...(Indexes)> { + using DataT = typename VecT::element_type; + +public: + using element_type = DataT; + using value_type = DataT; + + __SYCL2020_DEPRECATED("get_count() is deprecated, please use size() instead") + size_t get_count() const { return size(); } + static constexpr size_t size() noexcept { return sizeof...(Indexes); } + + template + __SYCL2020_DEPRECATED( + "get_size() is deprecated, please use byte_size() instead") + size_t get_size() const { + return byte_size(); + } + + template size_t byte_size() const noexcept { + return sizeof(DataT) * (Num == 3 ? 4 : Num); + } + +private: + // Certain operators return a vector with a different element type. Also, the + // left and right operand types may differ. CommonDataT selects a result type + // based on these types to ensure that the result value can be represented. 
+ // + // Example 1: + // sycl::vec vec{...}; + // auto result = 300u + vec.x(); + // + // CommonDataT is std::common_type_t since + // it's larger than unsigned char. + // + // Example 2: + // sycl::vec vec{...}; + // auto result = vec.template swizzle() && vec; + // + // CommonDataT is DataT since operator&& returns a vector with element type + // int8_t, which is larger than bool. + // + // Example 3: + // sycl::vec vec{...}; auto swlo = vec.lo(); + // auto result = swlo == swlo; + // + // CommonDataT is DataT since operator== returns a vector with element type + // int8_t, which is the same size as std::byte. std::common_type_t + // can't be used here since there's no type that int8_t and std::byte can both + // be implicitly converted to. + using OpLeftDataT = typename OperationLeftT::DataT; + using OpRightDataT = typename OperationRightT::DataT; + using CommonDataT = std::conditional_t< + sizeof(DataT) >= sizeof(std::common_type_t), + DataT, std::common_type_t>; + + using rel_t = detail::rel_t; + using vec_t = vec; + using vec_rel_t = vec; + + template class OperationCurrentT_, int... Idx_> + using NewLHOp = SwizzleOp, + OperationRightT_, OperationCurrentT_, Idx_...>; + + template class OperationCurrentT_, int... Idx_> + using NewRelOp = SwizzleOp, + SwizzleOp, + OperationRightT_, OperationCurrentT_, Idx_...>; + + template class OperationCurrentT_, int... 
Idx_> + using NewRHOp = SwizzleOp, + OperationCurrentT_, Idx_...>; + + template + using EnableIfOneIndex = + typename std::enable_if_t<1 == IdxNum && SwizzleOp::size() == IdxNum, T>; + + template + using EnableIfMultipleIndexes = + typename std::enable_if_t<1 != IdxNum && SwizzleOp::size() == IdxNum, T>; + + template + using EnableIfScalarType = + typename std::enable_if_t && + detail::is_fundamental_or_half_or_bfloat16>; + + template + using EnableIfNoScalarType = + typename std::enable_if_t || + !detail::is_fundamental_or_half_or_bfloat16>; + + template + using Swizzle = + SwizzleOp, GetOp, GetOp, Indices...>; + + template + using ConstSwizzle = + SwizzleOp, GetOp, GetOp, Indices...>; + +public: +#ifdef __SYCL_DEVICE_ONLY__ + using vector_t = typename vec_t::vector_t; +#endif // __SYCL_DEVICE_ONLY__ + + const DataT &operator[](int i) const { + std::array Idxs{Indexes...}; + return (*m_Vector)[Idxs[i]]; + } + + template + std::enable_if_t, DataT> &operator[](int i) { + std::array Idxs{Indexes...}; + return (*m_Vector)[Idxs[i]]; + } + + template , + typename = EnableIfScalarType> + operator T() const { + return getValue(0); + } + + template > + friend NewRHOp, std::multiplies, Indexes...> + operator*(const T &Lhs, const SwizzleOp &Rhs) { + return NewRHOp, std::multiplies, Indexes...>( + Rhs.m_Vector, GetScalarOp(Lhs), Rhs); + } + + template > + friend NewRHOp, std::plus, Indexes...> + operator+(const T &Lhs, const SwizzleOp &Rhs) { + return NewRHOp, std::plus, Indexes...>( + Rhs.m_Vector, GetScalarOp(Lhs), Rhs); + } + + template > + friend NewRHOp, std::divides, Indexes...> + operator/(const T &Lhs, const SwizzleOp &Rhs) { + return NewRHOp, std::divides, Indexes...>( + Rhs.m_Vector, GetScalarOp(Lhs), Rhs); + } + + // TODO: Check that Rhs arg is suitable. +#ifdef __SYCL_OPASSIGN +#error "Undefine __SYCL_OPASSIGN macro." 
+#endif +#define __SYCL_OPASSIGN(OPASSIGN, OP) \ + friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ + const DataT & Rhs) { \ + Lhs.operatorHelper(vec_t(Rhs)); \ + return Lhs; \ + } \ + template \ + friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ + const RhsOperation & Rhs) { \ + Lhs.operatorHelper(Rhs); \ + return Lhs; \ + } \ + friend const SwizzleOp &operator OPASSIGN(const SwizzleOp & Lhs, \ + const vec_t & Rhs) { \ + Lhs.operatorHelper(Rhs); \ + return Lhs; \ + } + + __SYCL_OPASSIGN(+=, std::plus) + __SYCL_OPASSIGN(-=, std::minus) + __SYCL_OPASSIGN(*=, std::multiplies) + __SYCL_OPASSIGN(/=, std::divides) + __SYCL_OPASSIGN(%=, std::modulus) + __SYCL_OPASSIGN(&=, std::bit_and) + __SYCL_OPASSIGN(|=, std::bit_or) + __SYCL_OPASSIGN(^=, std::bit_xor) + __SYCL_OPASSIGN(>>=, RShift) + __SYCL_OPASSIGN(<<=, LShift) +#undef __SYCL_OPASSIGN + +#ifdef __SYCL_UOP +#error "Undefine __SYCL_UOP macro" +#endif +#define __SYCL_UOP(UOP, OPASSIGN) \ + friend const SwizzleOp &operator UOP(const SwizzleOp & sv) { \ + sv OPASSIGN static_cast(1); \ + return sv; \ + } \ + friend vec_t operator UOP(const SwizzleOp &sv, int) { \ + vec_t Ret = sv; \ + sv OPASSIGN static_cast(1); \ + return Ret; \ + } + + __SYCL_UOP(++, +=) + __SYCL_UOP(--, -=) +#undef __SYCL_UOP + + template + friend typename std::enable_if_t< + std::is_same_v && !detail::is_vgenfloat_v, vec_t> + operator~(const SwizzleOp &Rhs) { + vec_t Tmp = Rhs; + return ~Tmp; + } + + friend vec_rel_t operator!(const SwizzleOp &Rhs) { + vec_t Tmp = Rhs; + return !Tmp; + } + + friend vec_t operator+(const SwizzleOp &Rhs) { + vec_t Tmp = Rhs; + return +Tmp; + } + + friend vec_t operator-(const SwizzleOp &Rhs) { + vec_t Tmp = Rhs; + return -Tmp; + } + +// scalar BINOP vec<> +// scalar BINOP SwizzleOp +// vec<> BINOP SwizzleOp +#ifdef __SYCL_BINOP +#error "Undefine __SYCL_BINOP macro" +#endif +#define __SYCL_BINOP(BINOP, COND) \ + template \ + friend std::enable_if_t<(COND), vec_t> operator BINOP( \ + 
const DataT & Lhs, const SwizzleOp & Rhs) { \ + vec_t Tmp = Rhs; \ + return Lhs BINOP Tmp; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_t> operator BINOP(const SwizzleOp & Lhs, \ + const DataT & Rhs) { \ + vec_t Tmp = Lhs; \ + return Tmp BINOP Rhs; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_t> operator BINOP( \ + const vec_t & Lhs, const SwizzleOp & Rhs) { \ + vec_t Tmp = Rhs; \ + return Lhs BINOP Tmp; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_t> operator BINOP(const SwizzleOp & Lhs, \ + const vec_t & Rhs) { \ + vec_t Tmp = Lhs; \ + return Tmp BINOP Rhs; \ + } + + __SYCL_BINOP(+, (!detail::is_byte_v)) + __SYCL_BINOP(-, (!detail::is_byte_v)) + __SYCL_BINOP(*, (!detail::is_byte_v)) + __SYCL_BINOP(/, (!detail::is_byte_v)) + __SYCL_BINOP(%, (!detail::is_byte_v)) + __SYCL_BINOP(&, true) + __SYCL_BINOP(|, true) + __SYCL_BINOP(^, true) + // We have special <<, >> operators for std::byte. + __SYCL_BINOP(>>, (!detail::is_byte_v)) + __SYCL_BINOP(<<, (!detail::is_byte_v)) + + template + friend std::enable_if_t, vec_t> + operator>>(const SwizzleOp &Lhs, const int shift) { + vec_t Tmp = Lhs; + return Tmp >> shift; + } + + template + friend std::enable_if_t, vec_t> + operator<<(const SwizzleOp &Lhs, const int shift) { + vec_t Tmp = Lhs; + return Tmp << shift; + } +#undef __SYCL_BINOP + +// scalar RELLOGOP vec<> +// scalar RELLOGOP SwizzleOp +// vec<> RELLOGOP SwizzleOp +#ifdef __SYCL_RELLOGOP +#error "Undefine __SYCL_RELLOGOP macro" +#endif +#define __SYCL_RELLOGOP(RELLOGOP, COND) \ + template \ + friend std::enable_if_t<(COND), vec_rel_t> operator RELLOGOP( \ + const DataT & Lhs, const SwizzleOp & Rhs) { \ + vec_t Tmp = Rhs; \ + return Lhs RELLOGOP Tmp; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_rel_t> operator RELLOGOP( \ + const SwizzleOp & Lhs, const DataT & Rhs) { \ + vec_t Tmp = Lhs; \ + return Tmp RELLOGOP Rhs; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_rel_t> operator RELLOGOP( \ + const vec_t 
& Lhs, const SwizzleOp & Rhs) { \ + vec_t Tmp = Rhs; \ + return Lhs RELLOGOP Tmp; \ + } \ + template \ + friend std::enable_if_t<(COND), vec_rel_t> operator RELLOGOP( \ + const SwizzleOp & Lhs, const vec_t & Rhs) { \ + vec_t Tmp = Lhs; \ + return Tmp RELLOGOP Rhs; \ + } + + __SYCL_RELLOGOP(==, (!detail::is_byte_v)) + __SYCL_RELLOGOP(!=, (!detail::is_byte_v)) + __SYCL_RELLOGOP(>, (!detail::is_byte_v)) + __SYCL_RELLOGOP(<, (!detail::is_byte_v)) + __SYCL_RELLOGOP(>=, (!detail::is_byte_v)) + __SYCL_RELLOGOP(<=, (!detail::is_byte_v)) + __SYCL_RELLOGOP(&&, (!detail::is_byte_v && !detail::is_vgenfloat_v)) + __SYCL_RELLOGOP(||, (!detail::is_byte_v && !detail::is_vgenfloat_v)) +#undef __SYCL_RELLOGOP + + template > + SwizzleOp &operator=(const vec &Rhs) { + std::array Idxs{Indexes...}; + for (size_t I = 0; I < Idxs.size(); ++I) { + (*m_Vector)[Idxs[I]] = Rhs[I]; + } + return *this; + } + + template > + SwizzleOp &operator=(const DataT &Rhs) { + std::array Idxs{Indexes...}; + (*m_Vector)[Idxs[0]] = Rhs; + return *this; + } + + template = true> + SwizzleOp &operator=(const DataT &Rhs) { + std::array Idxs{Indexes...}; + for (auto Idx : Idxs) { + (*m_Vector)[Idx] = Rhs; + } + return *this; + } + + template > + SwizzleOp &operator=(DataT &&Rhs) { + std::array Idxs{Indexes...}; + (*m_Vector)[Idxs[0]] = Rhs; + return *this; + } + + template > + NewLHOp, std::multiplies, Indexes...> + operator*(const T &Rhs) const { + return NewLHOp, std::multiplies, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator*(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, + Rhs); + } + + template > + NewLHOp, std::plus, Indexes...> operator+(const T &Rhs) const { + return NewLHOp, std::plus, Indexes...>(m_Vector, *this, + GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator+(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, Rhs); + } + + template > + NewLHOp, std::minus, Indexes...> + operator-(const T &Rhs) const { + return 
NewLHOp, std::minus, Indexes...>(m_Vector, *this, + GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator-(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, Rhs); + } + + template > + NewLHOp, std::divides, Indexes...> + operator/(const T &Rhs) const { + return NewLHOp, std::divides, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator/(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, + Rhs); + } + + template > + NewLHOp, std::modulus, Indexes...> + operator%(const T &Rhs) const { + return NewLHOp, std::modulus, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator%(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, + Rhs); + } + + template > + NewLHOp, std::bit_and, Indexes...> + operator&(const T &Rhs) const { + return NewLHOp, std::bit_and, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator&(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, + Rhs); + } + + template > + NewLHOp, std::bit_or, Indexes...> + operator|(const T &Rhs) const { + return NewLHOp, std::bit_or, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator|(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, Rhs); + } + + template > + NewLHOp, std::bit_xor, Indexes...> + operator^(const T &Rhs) const { + return NewLHOp, std::bit_xor, Indexes...>( + m_Vector, *this, GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator^(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, + Rhs); + } + + template > + NewLHOp, RShift, Indexes...> operator>>(const T &Rhs) const { + return NewLHOp, RShift, Indexes...>(m_Vector, *this, + GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator>>(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, Rhs); + } + + template > + NewLHOp, LShift, Indexes...> operator<<(const T &Rhs) const { + return 
NewLHOp, LShift, Indexes...>(m_Vector, *this, + GetScalarOp(Rhs)); + } + + template > + NewLHOp + operator<<(const RhsOperation &Rhs) const { + return NewLHOp(m_Vector, *this, Rhs); + } + + template class T4, + int... T5, + typename = typename std::enable_if_t> + SwizzleOp &operator=(const SwizzleOp &Rhs) { + std::array Idxs{Indexes...}; + for (size_t I = 0; I < Idxs.size(); ++I) { + (*m_Vector)[Idxs[I]] = Rhs.getValue(I); + } + return *this; + } + + template class T4, + int... T5, + typename = typename std::enable_if_t> + SwizzleOp &operator=(SwizzleOp &&Rhs) { + std::array Idxs{Indexes...}; + for (size_t I = 0; I < Idxs.size(); ++I) { + (*m_Vector)[Idxs[I]] = Rhs.getValue(I); + } + return *this; + } + + template > + NewRelOp, EqualTo, Indexes...> operator==(const T &Rhs) const { + return NewRelOp, EqualTo, Indexes...>(NULL, *this, + GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator==(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, NotEqualTo, Indexes...> + operator!=(const T &Rhs) const { + return NewRelOp, NotEqualTo, Indexes...>( + NULL, *this, GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator!=(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, GreaterEqualTo, Indexes...> + operator>=(const T &Rhs) const { + return NewRelOp, GreaterEqualTo, Indexes...>( + NULL, *this, GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator>=(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, LessEqualTo, Indexes...> + operator<=(const T &Rhs) const { + return NewRelOp, LessEqualTo, Indexes...>( + NULL, *this, GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator<=(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, GreaterThan, Indexes...> + operator>(const T &Rhs) const { + return NewRelOp, GreaterThan, Indexes...>( + NULL, *this, 
GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator>(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, LessThan, Indexes...> operator<(const T &Rhs) const { + return NewRelOp, LessThan, Indexes...>(NULL, *this, + GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator<(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, LogicalAnd, Indexes...> + operator&&(const T &Rhs) const { + return NewRelOp, LogicalAnd, Indexes...>( + NULL, *this, GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator&&(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + + template > + NewRelOp, LogicalOr, Indexes...> + operator||(const T &Rhs) const { + return NewRelOp, LogicalOr, Indexes...>(NULL, *this, + GetScalarOp(Rhs)); + } + + template > + NewRelOp + operator||(const RhsOperation &Rhs) const { + return NewRelOp(NULL, *this, Rhs); + } + +private: + static constexpr int get_vec_idx(int idx) { + int counter = 0; + int result = -1; + ((result = counter++ == idx ? Indexes : result), ...); + return result; + } + +public: + template + ConstSwizzle swizzle() const { + return m_Vector; + } + + template + Swizzle swizzle() { + return m_Vector; + } + + // Leave store() interface to automatic conversion to vec<>. + // Load to vec_t and then assign to swizzle. + template + void load(size_t offset, multi_ptr ptr) { + vec_t Tmp; + Tmp.load(offset, ptr); + *this = Tmp; + } + + template + vec convert() const { + // First materialize the swizzle to vec_t and then apply convert() to it. + vec_t Tmp; + std::array Idxs{Indexes...}; + for (size_t I = 0; I < Idxs.size(); ++I) { + Tmp[I] = (*m_Vector)[Idxs[I]]; + } + return Tmp.template convert(); + } + + template asT as() const { + // First materialize the swizzle to vec_t and then apply as() to it. 
+ vec_t Tmp = *this; + static_assert((sizeof(Tmp) == sizeof(asT)), + "The new SYCL vec type must have the same storage size in " + "bytes as this SYCL swizzled vec"); + static_assert(detail::is_vec_v, + "asT must be SYCL vec of a different element type and " + "number of elements specified by asT"); + return Tmp.template as(); + } + +private: + SwizzleOp(const SwizzleOp &Rhs) + : m_Vector(Rhs.m_Vector), m_LeftOperation(Rhs.m_LeftOperation), + m_RightOperation(Rhs.m_RightOperation) {} + + SwizzleOp(VecT *Vector, OperationLeftT LeftOperation, + OperationRightT RightOperation) + : m_Vector(Vector), m_LeftOperation(LeftOperation), + m_RightOperation(RightOperation) {} + + SwizzleOp(VecT *Vector) : m_Vector(Vector) {} + + SwizzleOp(SwizzleOp &&Rhs) + : m_Vector(Rhs.m_Vector), m_LeftOperation(std::move(Rhs.m_LeftOperation)), + m_RightOperation(std::move(Rhs.m_RightOperation)) {} + + // Either performing CurrentOperation on results of left and right operands + // or reading values from actual vector. Perform implicit type conversion when + // the number of elements == 1 + + template + CommonDataT getValue(EnableIfOneIndex Index) const { + if (std::is_same_v, GetOp>) { + std::array Idxs{Indexes...}; + return (*m_Vector)[Idxs[Index]]; + } + auto Op = OperationCurrentT(); + return Op(m_LeftOperation.getValue(Index), + m_RightOperation.getValue(Index)); + } + + template + DataT getValue(EnableIfMultipleIndexes Index) const { + if (std::is_same_v, GetOp>) { + std::array Idxs{Indexes...}; + return (*m_Vector)[Idxs[Index]]; + } + auto Op = OperationCurrentT(); + return Op(m_LeftOperation.getValue(Index), + m_RightOperation.getValue(Index)); + } + + template