Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 56 additions & 14 deletions sycl/include/sycl/detail/builtins/builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@

#pragma once

#include <sycl/detail/fwd/multi_ptr.hpp>
#include <sycl/detail/helpers.hpp>
#include <sycl/detail/type_traits.hpp>
#include <sycl/detail/type_traits/vec_marray_traits.hpp>
#include <sycl/detail/vector_convert.hpp>
#include <sycl/marray.hpp>
#include <sycl/multi_ptr.hpp>
#include <sycl/vector.hpp>

namespace sycl {
Expand All @@ -83,19 +83,6 @@ template <typename> struct use_fast_math : std::false_type {};
#endif
template <typename T> constexpr bool use_fast_math_v = use_fast_math<T>::value;

// Utility trait for getting the decoration of a multi_ptr.
template <typename T> struct get_multi_ptr_decoration;
template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
struct get_multi_ptr_decoration<
multi_ptr<ElementType, Space, DecorateAddress>> {
static constexpr access::decorated value = DecorateAddress;
};

template <typename T>
constexpr access::decorated get_multi_ptr_decoration_v =
get_multi_ptr_decoration<T>::value;

// Utility trait for checking if a multi_ptr has a "writable" address space,
// i.e. global, local, private or generic.
template <typename T> struct has_writeable_addr_space : std::false_type {};
Expand All @@ -110,6 +97,61 @@ struct has_writeable_addr_space<multi_ptr<ElementType, Space, DecorateAddress>>
template <typename T>
constexpr bool has_writeable_addr_space_v = has_writeable_addr_space<T>::value;

// Classification of pointer-like types used by builtin pointer helpers.
enum class builtin_ptr_kind { raw, multi_ptr };

// Maps a pointer-like type to the corresponding builtin_ptr_kind tag.
template <typename T>
using builtin_ptr_kind_tag_t = std::integral_constant<
builtin_ptr_kind,
is_multi_ptr_v<std::remove_cv_t<std::remove_reference_t<T>>>
? builtin_ptr_kind::multi_ptr
: builtin_ptr_kind::raw>;

// Returns Ptr unchanged for raw pointer-like types.
template <typename PtrTy>
decltype(auto) builtin_raw_ptr(
PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind, builtin_ptr_kind::raw>) {
return std::forward<PtrTy>(Ptr);
}

// Extracts the underlying raw pointer from a multi_ptr.
template <typename PtrTy>
auto builtin_raw_ptr(
PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind, builtin_ptr_kind::multi_ptr>) {
return get_raw_pointer(std::forward<PtrTy>(Ptr));
}

// Returns a raw pointer representation for raw pointers and multi_ptrs.
template <typename PtrTy> auto builtin_raw_ptr(PtrTy &&Ptr) {
return builtin_raw_ptr(std::forward<PtrTy>(Ptr),
builtin_ptr_kind_tag_t<PtrTy>{});
}

// Returns a pointer to the first element for raw pointer-like types.
template <typename PtrTy>
decltype(auto) builtin_element_ptr(
PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind, builtin_ptr_kind::raw>) {
return &(*std::forward<PtrTy>(Ptr))[0];
}

// Returns a pointer to the first element while preserving multi_ptr semantics.
template <typename PtrTy>
auto builtin_element_ptr(
PtrTy &&Ptr,
std::integral_constant<builtin_ptr_kind, builtin_ptr_kind::multi_ptr>) {
return detail::builtin_element_ptr(std::forward<PtrTy>(Ptr));
}

// Returns an element pointer for raw pointers and multi_ptrs.
template <typename PtrTy> auto builtin_element_ptr(PtrTy &&Ptr) {
return builtin_element_ptr(std::forward<PtrTy>(Ptr),
builtin_ptr_kind_tag_t<PtrTy>{});
}

// Utility trait for changing the element type of a type T. If T is a scalar,
// the new type replaces T completely.
template <typename NewElemT, typename T, typename = void>
Expand Down
15 changes: 5 additions & 10 deletions sycl/include/sycl/detail/builtins/math_functions.inc
Original file line number Diff line number Diff line change
Expand Up @@ -223,13 +223,7 @@ auto builtin_delegate_ptr_impl(FuncTy F, PtrTy p, Ts... xs) {

// TODO: Optimize for sizes. Make not to violate ANSI-aliasing rules for the
// pointer argument.
auto p0 = [&]() {
if constexpr (is_multi_ptr_v<PtrTy>)
return address_space_cast<PtrTy::address_space,
get_multi_ptr_decoration_v<PtrTy>>(&(*p)[0]);
else
return &(*p)[0];
}();
auto p0 = builtin_element_ptr(p);

constexpr auto N = T0::size();
if constexpr (N <= 16)
Expand Down Expand Up @@ -314,7 +308,8 @@ using builtin_last_raw_intptr_t =
PtrTy p) { \
if constexpr (is_multi_ptr_v<PtrTy>) { \
/* TODO: Can't really create multi_ptr on host... */ \
return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), p.get_raw()); \
return NAME##_impl(SYCL_CONCAT(LESS_ONE(NUM_ARGS), _ARG), \
builtin_raw_ptr(p)); \
} else { \
return builtin_delegate_ptr_impl( \
[](auto... xs) { return NAME##_impl(xs...); }, p, \
Expand Down Expand Up @@ -355,7 +350,7 @@ template <typename T0, typename T1> auto modf_impl(T0 &x, T1 &&y) {
if constexpr (is_multi_ptr_v<std::remove_reference_t<T1>>) {
// TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on
// host.
return modf_impl(x, y.get_raw());
return modf_impl(x, builtin_raw_ptr(std::forward<T1>(y)));
} else {
return builtin_delegate_ptr_impl(
[](auto x, auto y) { return modf_impl(x, y); }, y,
Expand Down Expand Up @@ -433,7 +428,7 @@ template <typename T0, typename T1> auto sincos_impl(T0 &x, T1 &&y) {
if constexpr (is_multi_ptr_v<std::remove_reference_t<T1>>) {
// TODO: Spec needs to be clarified, multi_ptr shouldn't be possible on
// host.
return sincos_impl(x, y.get_raw());
return sincos_impl(x, builtin_raw_ptr(std::forward<T1>(y)));
} else {
return builtin_delegate_ptr_impl(
[](auto... xs) { return sincos_impl(xs...); }, y,
Expand Down
13 changes: 13 additions & 0 deletions sycl/include/sycl/detail/fwd/multi_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

#include <sycl/access/access.hpp>

#include <type_traits>

namespace sycl {
inline namespace _V1 {
// Forward declaration
Expand All @@ -20,5 +22,16 @@ template <access::address_space Space, access::decorated DecorateAddress,
typename ElementType>
multi_ptr<ElementType, Space, DecorateAddress>
address_space_cast(ElementType *pointer);

namespace detail {
template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
std::add_pointer_t<ElementType>
get_raw_pointer(multi_ptr<ElementType, Space, DecorateAddress> Ptr);

template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
auto builtin_element_ptr(multi_ptr<ElementType, Space, DecorateAddress> Ptr);
} // namespace detail
} // namespace _V1
} // namespace sycl
15 changes: 15 additions & 0 deletions sycl/include/sycl/multi_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1377,6 +1377,21 @@ address_space_cast(ElementType *pointer) {
pointer));
}

namespace detail {
template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
std::add_pointer_t<ElementType>
get_raw_pointer(multi_ptr<ElementType, Space, DecorateAddress> Ptr) {
return Ptr.get_raw();
}

template <typename ElementType, access::address_space Space,
access::decorated DecorateAddress>
auto builtin_element_ptr(multi_ptr<ElementType, Space, DecorateAddress> Ptr) {
return address_space_cast<Space, DecorateAddress>(&(*Ptr)[0]);
}
} // namespace detail

template <
typename ElementType, access::address_space Space,
access::decorated DecorateAddress = access::decorated::legacy,
Expand Down
6 changes: 2 additions & 4 deletions sycl/test/include_deps/sycl_khr_includes_math.hpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@
// CHECK-NEXT: feature_test.hpp
// CHECK-NEXT: builtins.hpp
// CHECK-NEXT: detail/builtins/builtins.hpp
// CHECK-NEXT: detail/fwd/multi_ptr.hpp
// CHECK-NEXT: access/access.hpp
// CHECK-NEXT: detail/helpers.hpp
// CHECK-NEXT: __spirv/spirv_types.hpp
// CHECK-NEXT: detail/defines.hpp
// CHECK-NEXT: access/access.hpp
// CHECK-NEXT: detail/export.hpp
// CHECK-NEXT: memory_enums.hpp
// CHECK-NEXT: __spirv/spirv_vars.hpp
// CHECK-NEXT: detail/type_traits.hpp
// CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp
// CHECK-NEXT: detail/fwd/multi_ptr.hpp
// CHECK-NEXT: detail/vector_convert.hpp
// CHECK-NEXT: detail/generic_type_traits.hpp
// CHECK-NEXT: aliases.hpp
Expand All @@ -37,8 +37,6 @@
// CHECK-NEXT: detail/common.hpp
// CHECK-NEXT: detail/fwd/accessor.hpp
// CHECK-NEXT: marray.hpp
// CHECK-NEXT: multi_ptr.hpp
// CHECK-NEXT: detail/address_space_cast.hpp
// CHECK-NEXT: detail/builtins/common_functions.inc
// CHECK-NEXT: detail/builtins/helper_macros.hpp
// CHECK-NEXT: detail/builtins/geometric_functions.inc
Expand Down
8 changes: 4 additions & 4 deletions sycl/test/include_deps/sycl_khr_includes_usm.hpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@
// CHECK-NEXT: usm.hpp
// CHECK-NEXT: builtins.hpp
// CHECK-NEXT: detail/builtins/builtins.hpp
// CHECK-NEXT: detail/fwd/multi_ptr.hpp
// CHECK-NEXT: access/access.hpp
// CHECK-NEXT: detail/helpers.hpp
// CHECK-NEXT: __spirv/spirv_types.hpp
// CHECK-NEXT: detail/defines.hpp
// CHECK-NEXT: access/access.hpp
// CHECK-NEXT: detail/export.hpp
// CHECK-NEXT: memory_enums.hpp
// CHECK-NEXT: __spirv/spirv_vars.hpp
// CHECK-NEXT: detail/type_traits.hpp
// CHECK-NEXT: detail/type_traits/vec_marray_traits.hpp
// CHECK-NEXT: detail/fwd/multi_ptr.hpp
// CHECK-NEXT: detail/vector_convert.hpp
// CHECK-NEXT: detail/generic_type_traits.hpp
// CHECK-NEXT: aliases.hpp
Expand All @@ -38,8 +38,6 @@
// CHECK-NEXT: detail/common.hpp
// CHECK-NEXT: detail/fwd/accessor.hpp
// CHECK-NEXT: marray.hpp
// CHECK-NEXT: multi_ptr.hpp
// CHECK-NEXT: detail/address_space_cast.hpp
// CHECK-NEXT: detail/builtins/common_functions.inc
// CHECK-NEXT: detail/builtins/helper_macros.hpp
// CHECK-NEXT: detail/builtins/geometric_functions.inc
Expand Down Expand Up @@ -102,6 +100,8 @@
// CHECK-NEXT: ext/oneapi/accessor_property_list.hpp
// CHECK-NEXT: detail/accessor_iterator.hpp
// CHECK-NEXT: detail/handler_proxy.hpp
// CHECK-NEXT: multi_ptr.hpp
// CHECK-NEXT: detail/address_space_cast.hpp
// CHECK-NEXT: pointers.hpp
// CHECK-NEXT: properties/accessor_properties.hpp
// CHECK-NEXT: properties/runtime_accessor_properties.def
Expand Down
70 changes: 70 additions & 0 deletions sycl/test/regression/builtins_multi_ptr_include_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_BUILTINS_ONLY
// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_BUILTINS_FIRST
// RUN: %clangxx -fsycl -fsyntax-only -Wno-deprecated-declarations %s -DTEST_MULTI_PTR_FIRST

// Regression coverage for builtins/multi_ptr decoupling.
// We want to preserve these behaviors:
// 1. <sycl/builtins.hpp> compiles without including <sycl/multi_ptr.hpp>.
// 2. Including builtins before multi_ptr still allows later multi_ptr
// instantiation for scalar pointer builtins.
// 3. Including builtins before multi_ptr still allows later multi_ptr
// instantiation for vector pointer builtins.
// 4. Including multi_ptr before builtins also works for those builtin calls.

#if defined(TEST_BUILTINS_ONLY)
#include <sycl/builtins.hpp>

int main() {
auto Value = sycl::fmin(1.0f, 2.0f);
(void)Value;
return 0;
}

#elif defined(TEST_BUILTINS_FIRST)
#include <sycl/builtins.hpp>
#include <sycl/multi_ptr.hpp>

SYCL_EXTERNAL void
testScalar(sycl::multi_ptr<float, sycl::access::address_space::global_space,
sycl::access::decorated::no>
Ptr) {
(void)sycl::modf(1.0f, Ptr);
(void)sycl::sincos(1.0f, Ptr);
}

SYCL_EXTERNAL void
testVector(sycl::multi_ptr<sycl::vec<float, 2>,
sycl::access::address_space::global_space,
sycl::access::decorated::no>
Ptr) {
sycl::vec<float, 2> Value{1.0f, 2.0f};
(void)sycl::fract(Value, Ptr);
}

int main() { return 0; }

#elif defined(TEST_MULTI_PTR_FIRST)
// clang-format off
#include <sycl/multi_ptr.hpp>
#include <sycl/builtins.hpp>
// clang-format on

SYCL_EXTERNAL void
testScalar(sycl::multi_ptr<float, sycl::access::address_space::global_space,
sycl::access::decorated::no>
Ptr) {
(void)sycl::modf(1.0f, Ptr);
(void)sycl::sincos(1.0f, Ptr);
}

SYCL_EXTERNAL void
testVector(sycl::multi_ptr<sycl::vec<float, 2>,
sycl::access::address_space::global_space,
sycl::access::decorated::no>
Ptr) {
sycl::vec<float, 2> Value{1.0f, 2.0f};
(void)sycl::fract(Value, Ptr);
}

int main() { return 0; }
#endif
Loading