Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 79 additions & 5 deletions sycl/include/sycl/detail/builtins/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,15 @@
#pragma once

#include <sycl/detail/helpers.hpp>
#include <sycl/detail/fwd/multi_ptr.hpp>
#include <sycl/detail/generic_type_traits.hpp>
#include <sycl/detail/memcpy.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/detail/vector_convert.hpp>
#include <sycl/half_type.hpp>
#include <sycl/marray.hpp>
#include <sycl/multi_ptr.hpp>
#include <sycl/vector.hpp>

#include <algorithm>

namespace sycl {
inline namespace _V1 {
Expand Down Expand Up @@ -110,6 +113,55 @@ struct has_writeable_addr_space<multi_ptr<ElementType, Space, DecorateAddress>>
template <typename T>
constexpr bool has_writeable_addr_space_v = has_writeable_addr_space<T>::value;

enum class builtin_ptr_kind { raw, multi_ptr };

template <typename T>
using builtin_ptr_kind_tag_t = std::integral_constant<
builtin_ptr_kind,
is_multi_ptr_v<std::remove_cv_t<std::remove_reference_t<T>>>
? builtin_ptr_kind::multi_ptr
: builtin_ptr_kind::raw>;

template <typename PtrTy>
decltype(auto)
builtin_raw_ptr(PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind,
builtin_ptr_kind::raw>) {
return std::forward<PtrTy>(Ptr);
}

template <typename PtrTy>
auto builtin_raw_ptr(PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind,
builtin_ptr_kind::multi_ptr>) {
return get_raw_pointer(std::forward<PtrTy>(Ptr));
}

template <typename PtrTy> auto builtin_raw_ptr(PtrTy &&Ptr) {
return builtin_raw_ptr(std::forward<PtrTy>(Ptr),
builtin_ptr_kind_tag_t<PtrTy>{});
}

template <typename PtrTy>
decltype(auto)
builtin_element_ptr(PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind,
builtin_ptr_kind::raw>) {
return &(*std::forward<PtrTy>(Ptr))[0];
}

template <typename PtrTy>
auto builtin_element_ptr(PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind,
builtin_ptr_kind::multi_ptr>) {
return detail::builtin_element_ptr(std::forward<PtrTy>(Ptr));
}

template <typename PtrTy> auto builtin_element_ptr(PtrTy &&Ptr) {
return builtin_element_ptr(std::forward<PtrTy>(Ptr),
builtin_ptr_kind_tag_t<PtrTy>{});
}

// Utility trait for changing the element type of a type T. If T is a scalar,
// the new type replaces T completely.
template <typename NewElemT, typename T, typename = void>
Expand Down Expand Up @@ -161,6 +213,28 @@ template <class T, int N> marray<T, N> to_marray(vec<T, N> X) {
return Marray;
}

// Relation builtins widen signed-char masks to the required integer element
// type. Keep that conversion local here so builtins.hpp does not need to pull
// in vector_convert.hpp just for vec::convert.
template <typename NewElemT, int N>
vec<NewElemT, N> relational_mask_widen(vec<signed char, N> X) {
static_assert(is_scalar_arithmetic_v<NewElemT>);

#ifdef __SYCL_DEVICE_ONLY__
if constexpr (N > 1) {
using src_vector_t = signed char __attribute__((ext_vector_type(N)));
using dst_vector_t = NewElemT __attribute__((ext_vector_type(N)));
auto OpenCLVec = bit_cast<src_vector_t>(X);
return bit_cast<vec<NewElemT, N>>(
__builtin_convertvector(OpenCLVec, dst_vector_t));
}
#endif

vec<NewElemT, N> Result{};
loop<N>([&](auto idx) { Result[idx] = static_cast<NewElemT>(X[idx]); });
return Result;
}

namespace builtins {
#ifdef __SYCL_DEVICE_ONLY__
template <typename T> auto convert_arg(T &&x) {
Expand All @@ -176,7 +250,7 @@ template <typename T> auto convert_arg(T &&x) {
__attribute__((ext_vector_type(N)))>;
return bit_cast<result_type>(x);
} else if constexpr (is_swizzle_v<no_cv_ref>) {
return convert_arg(simplify_if_swizzle_t<no_cv_ref>{x});
return convert_arg(materialize_if_swizzle(std::forward<T>(x)));
} else {
static_assert(is_scalar_arithmetic_v<no_cv_ref> ||
is_multi_ptr_v<no_cv_ref> || std::is_pointer_v<no_cv_ref> ||
Expand Down Expand Up @@ -226,7 +300,7 @@ auto builtin_default_host_impl(FuncTy F, const Ts &...x) {
if constexpr ((... || is_marray_v<Ts>)) {
return builtin_marray_impl(F, x...);
} else {
return F(simplify_if_swizzle_t<Ts>{x}...);
return F(materialize_if_swizzle(x)...);
}
}

Expand Down
14 changes: 7 additions & 7 deletions sycl/include/sycl/detail/builtins/common_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ BUILTIN_COMMON(THREE_ARGS, mix, __spirv_ocl_mix)
template <typename T0, typename T1>
detail::builtin_enable_common_non_scalar_t<T0, T1>
mix(T0 x, T1 y, detail::get_elem_type_t<T0> z) {
return mix(detail::simplify_if_swizzle_t<T0>{x},
detail::simplify_if_swizzle_t<T0>{y},
return mix(detail::materialize_if_swizzle(x),
detail::materialize_if_swizzle(y),
detail::simplify_if_swizzle_t<T0>{z});
}

Expand All @@ -44,7 +44,7 @@ template <typename T>
detail::builtin_enable_common_non_scalar_t<T> step(detail::get_elem_type_t<T> x,
T y) {
return step(detail::simplify_if_swizzle_t<T>{x},
detail::simplify_if_swizzle_t<T>{y});
detail::materialize_if_swizzle(y));
}

BUILTIN_COMMON(THREE_ARGS, smoothstep, __spirv_ocl_smoothstep)
Expand All @@ -53,30 +53,30 @@ detail::builtin_enable_common_non_scalar_t<T>
smoothstep(detail::get_elem_type_t<T> x, detail::get_elem_type_t<T> y, T z) {
return smoothstep(detail::simplify_if_swizzle_t<T>{x},
detail::simplify_if_swizzle_t<T>{y},
detail::simplify_if_swizzle_t<T>{z});
detail::materialize_if_swizzle(z));
}

BUILTIN_COMMON(TWO_ARGS, max, __spirv_ocl_fmax_common)
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>(max)(
T x, detail::get_elem_type_t<T> y) {
return (max)(detail::simplify_if_swizzle_t<T>{x},
return (max)(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y});
}

BUILTIN_COMMON(TWO_ARGS, min, __spirv_ocl_fmin_common)
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>(min)(
T x, detail::get_elem_type_t<T> y) {
return (min)(detail::simplify_if_swizzle_t<T>{x},
return (min)(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y});
}

BUILTIN_COMMON(THREE_ARGS, clamp, __spirv_ocl_fclamp)
template <typename T>
detail::builtin_enable_common_non_scalar_t<T>
clamp(T x, detail::get_elem_type_t<T> y, detail::get_elem_type_t<T> z) {
return clamp(detail::simplify_if_swizzle_t<T>{x},
return clamp(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y},
detail::simplify_if_swizzle_t<T>{z});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ auto builtin_delegate_geo_impl(FuncTy F, const Ts &...x) {
else
return ret;
} else {
return F(simplify_if_swizzle_t<T>{x}...);
return F(materialize_if_swizzle(x)...);
}
}
} // namespace detail
Expand Down
10 changes: 4 additions & 6 deletions sycl/include/sycl/detail/builtins/helper_macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,12 @@
#define THREE_ARGS_ARG x, y, z

#define ONE_ARG_SIMPLIFIED_ARG \
simplify_if_swizzle_t<T0> { x }
detail::materialize_if_swizzle(x)
#define TWO_ARGS_SIMPLIFIED_ARG \
simplify_if_swizzle_t<T0>{x}, simplify_if_swizzle_t<T1> { y }
detail::materialize_if_swizzle(x), detail::materialize_if_swizzle(y)
#define THREE_ARGS_SIMPLIFIED_ARG \
simplify_if_swizzle_t<T0>{x}, simplify_if_swizzle_t<T1>{y}, \
simplify_if_swizzle_t<T2> { \
z \
}
detail::materialize_if_swizzle(x), detail::materialize_if_swizzle(y), \
detail::materialize_if_swizzle(z)

#define TWO_ARGS_ARG_ROTATED y, x
#define THREE_ARGS_ARG_ROTATED z, x, y
Expand Down
6 changes: 3 additions & 3 deletions sycl/include/sycl/detail/builtins/integer_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -105,23 +105,23 @@ BUILTIN_GENINT_SU(TWO_ARGS, max)
template <typename T>
detail::builtin_enable_integer_non_scalar_t<T>(max)(
T x, detail::get_elem_type_t<T> y) {
return (max)(detail::simplify_if_swizzle_t<T>{x},
return (max)(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y});
}

BUILTIN_GENINT_SU(TWO_ARGS, min)
template <typename T>
detail::builtin_enable_integer_non_scalar_t<T>(min)(
T x, detail::get_elem_type_t<T> y) {
return (min)(detail::simplify_if_swizzle_t<T>{x},
return (min)(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y});
}

BUILTIN_GENINT_SU(THREE_ARGS, clamp)
template <typename T>
detail::builtin_enable_integer_non_scalar_t<T>
clamp(T x, detail::get_elem_type_t<T> y, detail::get_elem_type_t<T> z) {
return clamp(detail::simplify_if_swizzle_t<T>{x},
return clamp(detail::materialize_if_swizzle(x),
detail::simplify_if_swizzle_t<T>{y},
detail::simplify_if_swizzle_t<T>{z});
}
Expand Down
25 changes: 10 additions & 15 deletions sycl/include/sycl/detail/builtins/math_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ BUILTIN_GENF(THREE_ARGS, mad)
BUILTIN_GENF(TWO_ARGS, NAME) \
template <typename T> \
detail::builtin_enable_math_t<T> NAME(T x, detail::get_elem_type_t<T> y) { \
return NAME(detail::simplify_if_swizzle_t<T>{x}, \
return NAME(detail::materialize_if_swizzle(x), \
detail::simplify_if_swizzle_t<T>{y}); \
}

Expand Down Expand Up @@ -223,13 +223,7 @@ auto builtin_delegate_ptr_impl(FuncTy F, PtrTy p, Ts... xs) {

// TODO: Optimize for sizes. Make not to violate ANSI-aliasing rules for the
// pointer argument.
auto p0 = [&]() {
if constexpr (is_multi_ptr_v<PtrTy>)
return address_space_cast<PtrTy::address_space,
get_multi_ptr_decoration_v<PtrTy>>(&(*p)[0]);
else
return &(*p)[0];
}();
auto p0 = builtin_element_ptr(p);

constexpr auto N = T0::size();
if constexpr (N <= 16)
Expand Down Expand Up @@ -314,7 +308,8 @@ using builtin_last_raw_intptr_t =
PtrTy p) { \
if constexpr (is_multi_ptr_v<PtrTy>) { \
/* TODO: Can't really create multi_ptr on host... */ \
return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), p.get_raw()); \
return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), \
builtin_raw_ptr(p)); \
} else { \
return builtin_delegate_ptr_impl( \
[](auto... xs) { return NAME##_impl(xs...); }, p, \
Expand All @@ -336,7 +331,7 @@ BUILTIN_LAST_INTPTR(THREE_ARGS, remquo)
#ifndef __SYCL_DEVICE_ONLY__
namespace detail {
template <typename T0, typename T1> auto fract_impl(T0 &x, T1 &y) {
auto flr = floor(simplify_if_swizzle_t<T0>{x});
auto flr = floor(materialize_if_swizzle(x));
*y = flr;
return fmin(x - flr, nextafter(simplify_if_swizzle_t<T0>{1.0},
simplify_if_swizzle_t<T0>{0.0}));
Expand All @@ -355,11 +350,11 @@ template <typename T0, typename T1> auto modf_impl(T0 &x, T1 &&y) {
if constexpr (is_multi_ptr_v<std::remove_reference_t<T1>>) {
// TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on
// host.
return modf_impl(x, y.get_raw());
return modf_impl(x, builtin_raw_ptr(std::forward<T1>(y)));
} else {
return builtin_delegate_ptr_impl(
[](auto x, auto y) { return modf_impl(x, y); }, y,
simplify_if_swizzle_t<T0>{x});
materialize_if_swizzle(x));
}
}
} // namespace detail
Expand Down Expand Up @@ -395,7 +390,7 @@ BUILTIN_MATH_LAST_INT(rootn)
BUILTIN_MATH_LAST_INT(ldexp)
template <typename T> detail::builtin_enable_math_t<T> ldexp(T x, int y) {
return ldexp(
detail::simplify_if_swizzle_t<T>{x},
detail::materialize_if_swizzle(x),
detail::change_elements_t<int, detail::simplify_if_swizzle_t<T>>{y});
}

Expand Down Expand Up @@ -433,11 +428,11 @@ template <typename T0, typename T1> auto sincos_impl(T0 &x, T1 &&y) {
if constexpr (is_multi_ptr_v<std::remove_reference_t<T1>>) {
// TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on
// host.
return sincos_impl(x, y.get_raw());
return sincos_impl(x, builtin_raw_ptr(std::forward<T1>(y)));
} else {
return builtin_delegate_ptr_impl(
[](auto... xs) { return sincos_impl(xs...); }, y,
simplify_if_swizzle_t<T0>{x});
materialize_if_swizzle(x));
}
}
#endif
Expand Down
4 changes: 2 additions & 2 deletions sycl/include/sycl/detail/builtins/relational_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ auto builtin_device_rel_impl(FuncTy F, const Ts &...xs) {
auto tmp = bit_cast<vec<signed char, num_elements<T>::value>>(ret);
using res_elem_type = fixed_width_signed<sizeof(get_elem_type_t<T>)>;
static_assert(is_scalar_arithmetic_v<res_elem_type>);
return tmp.template convert<res_elem_type>();
return relational_mask_widen<res_elem_type>(tmp);
} else if constexpr (std::is_same_v<T, half>) {
return bool{F(builtins::convert_arg(xs)...)};
} else {
Expand All @@ -80,7 +80,7 @@ template <typename FuncTy, typename... Ts>
auto builtin_delegate_rel_impl(FuncTy F, const Ts &...x) {
using T = typename first_type<Ts...>::type;
if constexpr ((... || is_swizzle_v<Ts>)) {
return F(simplify_if_swizzle_t<T>{x}...);
return F(materialize_if_swizzle(x)...);
} else if constexpr (is_vec_v<T>) {
// TODO: using Res{} to avoid Werror. Not sure if ok.
vec<fixed_width_signed<sizeof(typename T::element_type)>, T::size()> Res{};
Expand Down
14 changes: 14 additions & 0 deletions sycl/include/sycl/detail/fwd/multi_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <sycl/access/access.hpp>

#include <type_traits>

namespace sycl {
inline namespace _V1 {
// Forward declaration
Expand All @@ -20,5 +22,17 @@ template <access::address_space Space, access::decorated DecorateAddress,
typename ElementType>
multi_ptr<ElementType, Space, DecorateAddress>
address_space_cast(ElementType *pointer);

namespace detail {
template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
std::add_pointer_t<ElementType>
get_raw_pointer(multi_ptr<ElementType, Space, DecorateAddress> Ptr);

template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
auto builtin_element_ptr(
multi_ptr<ElementType, Space, DecorateAddress> Ptr);
} // namespace detail
} // namespace _V1
} // namespace sycl
21 changes: 21 additions & 0 deletions sycl/include/sycl/detail/type_traits/vec_marray_traits.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <cstddef>
#include <type_traits>
#include <utility>

#include <sycl/detail/defines_elementary.hpp>

Expand Down Expand Up @@ -59,6 +60,26 @@ struct simplify_if_swizzle<detail::hide_swizzle_from_adl::Swizzle<
template <typename T>
using simplify_if_swizzle_t = typename simplify_if_swizzle<T>::type;

template <typename T> constexpr decltype(auto) materialize_if_swizzle(T &&X) {
return std::forward<T>(X);
}

#if __SYCL_USE_LIBSYCL8_VEC_IMPL
template <typename VecT, typename OperationLeftT, typename OperationRightT,
template <typename> class OperationCurrentT, int... Indexes>
simplify_if_swizzle_t<SwizzleOp<VecT, OperationLeftT, OperationRightT,
OperationCurrentT, Indexes...>>
materialize_if_swizzle(
const SwizzleOp<VecT, OperationLeftT, OperationRightT, OperationCurrentT,
Indexes...> &X);
#else
template <bool IsConstVec, typename DataT, int VecSize, int... Indexes>
simplify_if_swizzle_t<detail::hide_swizzle_from_adl::Swizzle<
IsConstVec, DataT, VecSize, Indexes...>>
materialize_if_swizzle(const detail::hide_swizzle_from_adl::Swizzle<
IsConstVec, DataT, VecSize, Indexes...> &X);
#endif

// --------- is_* traits ------------------ //
template <typename> struct is_vec : std::false_type {};
template <typename T, int N> struct is_vec<vec<T, N>> : std::true_type {};
Expand Down
Loading
Loading