From 448ea01903eb3d6134a3fcc1e8191f0c31009719 Mon Sep 17 00:00:00 2001 From: "Meszaros, Gergely" Date: Wed, 21 Jan 2026 09:42:16 +0000 Subject: [PATCH] Improve compile and run- time of marray_* by reducing template instatiations (NFCI) Reduce the number of template instantiations in marray_operators.h by making the initial sequences and scalars runtime parameters instead of template parameters. This reduces the number of SYCL kernels by 4x to 16x depending on the test case. This could theoretically decrease runtime performance, but it seems like that's also significantly improved, likely by reducing kernel JIT times. These numbers are from a local run on an Intel GPU. Obtained by running the commands ```bash TESTS=( test_marray_arithmetic_assignment test_marray_arithmetic_binary test_marray_basic test_marray_bitwise test_marray_pre_post test_marray_relational ) ninja -C build "${TESTS[@]}" ``` - Before: 15m 20s (920s) - After: 2m 19s (139s) (6.6x speedup) Runtime: ```bash for test in "${TESTS[@]}"; do build/bin/${test} done ``` - Before: 20m 51s (1251s) - After: 3m 8s (188s) (6.7x speedup) --- tests/marray_basic/marray_operators.h | 494 +++++++++++++------------- 1 file changed, 249 insertions(+), 245 deletions(-) diff --git a/tests/marray_basic/marray_operators.h b/tests/marray_basic/marray_operators.h index 828ceafe3..80ce70d35 100644 --- a/tests/marray_basic/marray_operators.h +++ b/tests/marray_basic/marray_operators.h @@ -27,26 +27,90 @@ #include "marray_common.h" #include "marray_operator_helper.h" +#include #include namespace marray_operators { +/** + * @brief Define several sequences to initialize array instances. */ + +enum class init_sequence { inc, dec, ones, twos }; + +inline constexpr init_sequence all_init_sequences[] = { + init_sequence::inc, init_sequence::dec, init_sequence::ones, + init_sequence::twos}; + +template +inline constexpr DataT get_seq_val(std::size_t num_elements, init_sequence seq, + std::size_t i) { + switch (seq) { + case init_sequence::inc: + return DataT(i + 1); + case init_sequence::dec: + return DataT(num_elements - i); + case init_sequence::ones: + return DataT(1); + case init_sequence::twos: + return DataT(2); + } +} + +inline std::string get_sequence_name(init_sequence seq) { + switch (seq) { + case init_sequence::inc: + return "incrementing sequence"; + case init_sequence::dec: + return "decrementing sequence"; + case init_sequence::ones: + return "sequence of ones"; + case init_sequence::twos: + return "sequence of twos"; + } +} + +/** + * @brief Define several constants to initialize scalar instances. */ + +enum class init_scalar { one, two }; + +static constexpr init_scalar all_init_scalars[] = {init_scalar::one, + init_scalar::two}; + +template +inline DataT get_scalar_val(init_scalar sca) { + switch (sca) { + case init_scalar::one: + return DataT(1); + case init_scalar::two: + return DataT(2); + } +} + +inline std::string get_scalar_name(init_scalar sca) { + switch (sca) { + case init_scalar::one: + return "one (1)"; + case init_scalar::two: + return "two (2)"; + } +} + template struct operators_helper { static constexpr std::size_t NumElements = NumElementsT::value; using marray_t = sycl::marray; using varray_t = std::valarray; - template - static void init(array_type& ma) { + template + static void init(array_type& ma, init_sequence seq) { for (std::size_t i = 0; i < NumElements; i++) { - ma[i] = init_func::template init(i); + ma[i] = get_seq_val(NumElements, seq, i); } } - template - static void init(DataT& d) { - d = init_func::template init(); + static void init(DataT& d, init_scalar sca) { + d = get_scalar_val(sca); } }; @@ -75,88 +139,26 @@ bool are_equal_ignore_division(const T1& lhs, const T1& rhs) { return value_operations::are_equal(lhs, rhs); } -/** - * @brief Define several sequences to initialize array instances. */ - -struct seq_inc { - template - static DataT init(std::size_t i) { - return DataT(i + 1); - } -}; - -template -struct seq_dec { - template - static DataT init(std::size_t i) { - return DataT(NumElements - i); - } -}; - -struct seq_one { - template - static DataT init(std::size_t) { - return {1}; - } -}; - -struct seq_two { - template - static DataT init(std::size_t) { - return DataT(2); - } -}; - -template -inline auto get_sequences() { - return named_type_pack, seq_one, - seq_two>::generate("incrementing sequence", - "decrementing sequence", - "sequences of ones", - "sequence of twos"); -} - -/** - * @brief Define several constants to initialize scalar instances. */ - -struct sca_one { - template - static DataT init() { - return DataT(1); - } -}; - -struct sca_two { - template - static DataT init() { - return DataT(2); - } -}; - -inline auto get_scalars() { - return named_type_pack::generate("one (1)", "two (2)"); -} - -template +template class run_unary_sequence { using helper = operators_helper; template - static void run_on_host(const ResT& res_expected) { + static void run_on_host(init_sequence seq, const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t val_actual; - helper::template init(val_actual); + helper::init(val_actual, seq); auto res_actual = op(val_actual); CHECK(value_operations::are_equal(res_expected, res_actual)); } template - static void run_on_device(const std::valarray& res_expected) { + static void run_on_device(init_sequence seq, + const std::valarray& res_expected) { INFO("validation on device"); auto queue = sycl_cts::util::get_cts_object::queue(); @@ -173,7 +175,7 @@ class run_unary_sequence { cgh.single_task([=]() { OpT op; typename helper::marray_t val_actual; - helper::template init(val_actual); + helper::init(val_actual, seq); res_actual_acc[0] = op(val_actual); }); }) @@ -184,18 +186,20 @@ class run_unary_sequence { } public: - void operator()(const std::string& function_name) { - INFO("for input (sequence) \"" << function_name << "\": "); + void operator()() { + for (const init_sequence seq : all_init_sequences) { + INFO("for input (sequence) \"" << get_sequence_name(seq) << "\": "); - OpT op; + OpT op; - typename helper::varray_t val_expected(helper::NumElements); - helper::template init(val_expected); - auto res_expected = op(val_expected); + typename helper::varray_t val_expected(helper::NumElements); + helper::init(val_expected, seq); + auto res_expected = op(val_expected); - run_on_host(res_expected); + run_on_host(seq, res_expected); - run_on_device(res_expected); + run_on_device(seq, res_expected); + } } }; @@ -215,26 +219,24 @@ class run_unary(); - for_all_combinations( - functions); + run_unary_sequence{}(); } }; -template +template class run_unary_post_sequence { using helper = operators_helper; template - static void run_on_host(const typename helper::varray_t& val_expected, + static void run_on_host(init_sequence seq, + const typename helper::varray_t& val_expected, const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t val_actual; - helper::template init(val_actual); + helper::init(val_actual, seq); auto res_actual = op(val_actual); // check the returned output @@ -244,7 +246,8 @@ class run_unary_post_sequence { } template - static void run_on_device(const typename helper::varray_t& val_expected, + static void run_on_device(init_sequence seq, + const typename helper::varray_t& val_expected, const std::valarray& res_expected) { INFO("validation on device"); @@ -266,7 +269,7 @@ class run_unary_post_sequence { sycl::write_only}; cgh.single_task([=]() { OpT op; - helper::template init(val_actual_acc[0]); + helper::init(val_actual_acc[0], seq); res_actual_acc[0] = op(val_actual_acc[0]); }); }) @@ -280,18 +283,20 @@ class run_unary_post_sequence { } public: - void operator()(const std::string& function_name) { - INFO("for input (sequence) \"" << function_name << "\": "); + void operator()() { + for (const init_sequence seq : all_init_sequences) { + INFO("for input (sequence) \"" << get_sequence_name(seq) << "\": "); - OpT op; + OpT op; - typename helper::varray_t val_expected(helper::NumElements); - helper::template init(val_expected); - auto res_expected = op(val_expected); + typename helper::varray_t val_expected(helper::NumElements); + helper::init(val_expected, seq); + auto res_expected = op(val_expected); - run_on_host(val_expected, res_expected); + run_on_host(seq, val_expected, res_expected); - run_on_device(val_expected, res_expected); + run_on_device(seq, val_expected, res_expected); + } } }; @@ -301,34 +306,33 @@ class run_unary_post { void operator()(const std::string& operator_name) { INFO("for operator \"" << operator_name << "\": "); - const auto functions = get_sequences(); - for_all_combinations( - functions); + run_unary_post_sequence{}(); } }; -template +template class run_binary_sequence_scalar { using helper = operators_helper; template - static void run_on_host(const ResT& res_expected) { + static void run_on_host(init_sequence seq, init_scalar sca, + const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq); DataT rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, sca); auto res_actual = op(lhs_actual, rhs_actual); CHECK(value_operations::are_equal(res_expected, res_actual)); } template - static void run_on_device(const std::valarray& res_expected) { + static void run_on_device(init_sequence seq, init_scalar sca, + const std::valarray& res_expected) { INFO("validation on device"); auto queue = sycl_cts::util::get_cts_object::queue(); @@ -345,9 +349,9 @@ class run_binary_sequence_scalar { cgh.single_task([=]() { OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq); DataT rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, sca); res_actual_acc[0] = op(lhs_actual, rhs_actual); }); }) @@ -359,31 +363,33 @@ class run_binary_sequence_scalar { } public: - void operator()(const std::string& function_name, - const std::string& constant_name) { - INFO("for lhs (sequence) \"" << function_name << "\": "); - INFO("for rhs (scalar) \"" << constant_name << "\": "); + void operator()() { + for (const init_sequence seq : all_init_sequences) { + for (const init_scalar sca : all_init_scalars) { + INFO("for lhs (sequence) \"" << get_sequence_name(seq) << "\": "); + INFO("for rhs (scalar) \"" << get_scalar_name(sca) << "\": "); - OpT op; + OpT op; - typename helper::varray_t lhs_expected(helper::NumElements); - helper::template init(lhs_expected); - DataT rhs_expected; - helper::template init(rhs_expected); - auto res_expected = op(lhs_expected, rhs_expected); + typename helper::varray_t lhs_expected(helper::NumElements); + helper::init(lhs_expected, seq); + DataT rhs_expected; + helper::init(rhs_expected, sca); + auto res_expected = op(lhs_expected, rhs_expected); - run_on_host(res_expected); + run_on_host(seq, sca, res_expected); - run_on_device(res_expected); + run_on_device(seq, sca, res_expected); + } + } } }; -template inline constexpr bool init_seq_contains_too_big_values_for_shift_op( + init_sequence seq, std::size_t seq_el_num, std::size_t max_shift_wo_undef_behavior) { - return (std::is_same_v || - std::is_same_v>)&&seq_el_num > - max_shift_wo_undef_behavior; + return (seq == init_sequence::inc || seq == init_sequence::dec) && + seq_el_num > max_shift_wo_undef_behavior; } /** @@ -403,18 +409,17 @@ inline constexpr bool init_seq_contains_too_big_values_for_shift_op( So right shift operation by N bits with N is less than sizeof(int) - 1 is guaranteed legal for type int, wider types and for small integral types because of its integral promotions to int. */ -template -inline constexpr bool test_case_is_invalid_for_shift_op() { +template +inline constexpr bool test_case_is_invalid_for_shift_op( + init_sequence rhs_seq, std::size_t seq_el_num) { constexpr int max_left_shift_wo_undef_behavior = 8; constexpr int max_right_shift_wo_undef_behavior = sizeof(int) - 1; if constexpr (std::is_same_v || std::is_same_v) - return init_seq_contains_too_big_values_for_shift_op( - max_left_shift_wo_undef_behavior); + return init_seq_contains_too_big_values_for_shift_op( + rhs_seq, seq_el_num, max_left_shift_wo_undef_behavior); if constexpr (std::is_same_v || std::is_same_v) - return init_seq_contains_too_big_values_for_shift_op( - max_right_shift_wo_undef_behavior); + return init_seq_contains_too_big_values_for_shift_op( + rhs_seq, seq_el_num, max_right_shift_wo_undef_behavior); return false; } @@ -424,28 +429,29 @@ inline constexpr bool is_shift_op() { std::is_same_v || std::is_same_v; } -template +template class run_binary_scalar_sequence { using helper = operators_helper; template - static void run_on_host(const ResT& res_expected) { + static void run_on_host(init_scalar sca, init_sequence seq, + const ResT& res_expected) { INFO("validation on host"); OpT op; DataT lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, sca); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq); auto res_actual = op(lhs_actual, rhs_actual); CHECK(value_operations::are_equal(res_expected, res_actual)); } template - static void run_on_device(const std::valarray& res_expected) { + static void run_on_device(init_scalar sca, init_sequence seq, + const std::valarray& res_expected) { INFO("validation on device"); auto queue = sycl_cts::util::get_cts_object::queue(); @@ -462,9 +468,9 @@ class run_binary_scalar_sequence { cgh.single_task([=]() { OpT op; DataT lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, sca); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq); res_actual_acc[0] = op(lhs_actual, rhs_actual); }); }) @@ -476,53 +482,55 @@ class run_binary_scalar_sequence { } public: - void operator()(const std::string& constant_name, - const std::string& function_name) { - if constexpr (is_shift_op() && - test_case_is_invalid_for_shift_op()) { - return; - } + void operator()() { + for (const init_scalar sca : all_init_scalars) { + for (const init_sequence seq : all_init_sequences) { + if (is_shift_op() && + test_case_is_invalid_for_shift_op(seq, helper::NumElements)) + continue; - INFO("for lhs (scalar) \"" << constant_name << "\": "); - INFO("for rhs (sequence) \"" << function_name << "\": "); + INFO("for lhs (scalar) \"" << get_scalar_name(sca) << "\": "); + INFO("for rhs (sequence) \"" << get_sequence_name(seq) << "\": "); - OpT op; + OpT op; - DataT lhs_expected; - helper::template init(lhs_expected); - typename helper::varray_t rhs_expected(helper::NumElements); - helper::template init(rhs_expected); - auto res_expected = op(lhs_expected, rhs_expected); + DataT lhs_expected; + helper::init(lhs_expected, sca); + typename helper::varray_t rhs_expected(helper::NumElements); + helper::init(rhs_expected, seq); + auto res_expected = op(lhs_expected, rhs_expected); - run_on_host(res_expected); + run_on_host(sca, seq, res_expected); - run_on_device(res_expected); + run_on_device(sca, seq, res_expected); + } + } } }; -template +template class run_binary_sequence_sequence { using helper = operators_helper; template - static void run_on_host(const ResT& res_expected) { + static void run_on_host(init_sequence seq1, init_sequence seq2, + const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq1); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq2); auto res_actual = op(lhs_actual, rhs_actual); CHECK(value_operations::are_equal(res_expected, res_actual)); } template - static void run_on_device(const std::valarray& res_expected) { + static void run_on_device(init_sequence seq1, init_sequence seq2, + const std::valarray& res_expected) { INFO("validation on device"); auto queue = sycl_cts::util::get_cts_object::queue(); @@ -539,9 +547,9 @@ class run_binary_sequence_sequence { cgh.single_task([=]() { OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq1); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq2); res_actual_acc[0] = op(lhs_actual, rhs_actual); }); }) @@ -553,28 +561,29 @@ class run_binary_sequence_sequence { } public: - void operator()(const std::string& function_name_1, - const std::string& function_name_2) { - if constexpr (is_shift_op() && - test_case_is_invalid_for_shift_op()) { - return; - } + void operator()() { + for (const init_sequence seq1 : all_init_sequences) { + for (const init_sequence seq2 : all_init_sequences) { + if (is_shift_op() && + test_case_is_invalid_for_shift_op(seq2, helper::NumElements)) + continue; - INFO("for lhs (sequence) \"" << function_name_1 << "\": "); - INFO("for rhs (sequence) \"" << function_name_2 << "\": "); + INFO("for lhs (sequence) \"" << get_sequence_name(seq1) << "\": "); + INFO("for rhs (sequence) \"" << get_sequence_name(seq2) << "\": "); - OpT op; + OpT op; - typename helper::varray_t lhs_expected(helper::NumElements); - helper::template init(lhs_expected); - typename helper::varray_t rhs_expected(helper::NumElements); - helper::template init(rhs_expected); - auto res_expected = op(lhs_expected, rhs_expected); + typename helper::varray_t lhs_expected(helper::NumElements); + helper::init(lhs_expected, seq1); + typename helper::varray_t rhs_expected(helper::NumElements); + helper::init(rhs_expected, seq2); + auto res_expected = op(lhs_expected, rhs_expected); - run_on_host(res_expected); + run_on_host(seq1, seq2, res_expected); - run_on_device(res_expected); + run_on_device(seq1, seq2, res_expected); + } + } } }; @@ -598,33 +607,28 @@ class run_binary< void operator()(const std::string& operator_name) { INFO("for operator \"" << operator_name << "\": "); - const auto constants = get_scalars(); - const auto functions = get_sequences(); - for_all_combinations( - functions, constants); - for_all_combinations( - constants, functions); - for_all_combinations(functions, functions); + run_binary_sequence_scalar{}(); + run_binary_scalar_sequence{}(); + run_binary_sequence_sequence{}(); } }; -template +template class run_binary_assignment_sequence_scalar { using helper = operators_helper; template - static void run_on_host(const typename helper::varray_t& lhs_expected, + static void run_on_host(init_sequence seq, init_scalar sca, + const typename helper::varray_t& lhs_expected, const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq); DataT rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, sca); auto res_actual = op(lhs_actual, rhs_actual); // check the returned output @@ -634,7 +638,8 @@ class run_binary_assignment_sequence_scalar { } template - static void run_on_device(const typename helper::varray_t& lhs_expected, + static void run_on_device(init_sequence seq, init_scalar sca, + const typename helper::varray_t& lhs_expected, const std::valarray& res_expected) { INFO("validation on device"); @@ -656,10 +661,9 @@ class run_binary_assignment_sequence_scalar { sycl::write_only}; cgh.single_task([=]() { OpT op; - typename helper::marray_t lhs_actual; - helper::template init(lhs_actual_acc[0]); + helper::init(lhs_actual_acc[0], seq); DataT rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, sca); res_actual_acc[0] = op(lhs_actual_acc[0], rhs_actual); }); }) @@ -675,41 +679,44 @@ class run_binary_assignment_sequence_scalar { } public: - void operator()(const std::string& function_name, - const std::string& constant_name) { - INFO("for lhs (sequence) \"" << function_name << "\": "); - INFO("for rhs (scalar) \"" << constant_name << "\": "); + void operator()() { + for (const init_sequence seq : all_init_sequences) { + for (const init_scalar sca : all_init_scalars) { + INFO("for lhs (sequence) \"" << get_sequence_name(seq) << "\": "); + INFO("for rhs (scalar) \"" << get_scalar_name(sca) << "\": "); - OpT op; + OpT op; - typename helper::varray_t lhs_expected(helper::NumElements); - helper::template init(lhs_expected); - DataT rhs_expected; - helper::template init(rhs_expected); - auto res_expected = op(lhs_expected, rhs_expected); + typename helper::varray_t lhs_expected(helper::NumElements); + helper::init(lhs_expected, seq); + DataT rhs_expected; + helper::init(rhs_expected, sca); + auto res_expected = op(lhs_expected, rhs_expected); - run_on_host(lhs_expected, res_expected); + run_on_host(seq, sca, lhs_expected, res_expected); - run_on_device(lhs_expected, res_expected); + run_on_device(seq, sca, lhs_expected, res_expected); + } + } } }; -template +template class run_binary_assignment_sequence_sequence { using helper = operators_helper; template - static void run_on_host(const typename helper::varray_t& lhs_expected, + static void run_on_host(init_sequence seq1, init_sequence seq2, + const typename helper::varray_t& lhs_expected, const ResT& res_expected) { INFO("validation on host"); OpT op; typename helper::marray_t lhs_actual; - helper::template init(lhs_actual); + helper::init(lhs_actual, seq1); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq2); auto res_actual = op(lhs_actual, rhs_actual); // check the returned output @@ -719,7 +726,8 @@ class run_binary_assignment_sequence_sequence { } template - static void run_on_device(const typename helper::varray_t& lhs_expected, + static void run_on_device(init_sequence seq1, init_sequence seq2, + const typename helper::varray_t& lhs_expected, const std::valarray& res_expected) { INFO("validation on device"); @@ -741,10 +749,9 @@ class run_binary_assignment_sequence_sequence { sycl::write_only}; cgh.single_task([=]() { OpT op; - typename helper::marray_t lhs_actual; - helper::template init(lhs_actual_acc[0]); + helper::init(lhs_actual_acc[0], seq1); typename helper::marray_t rhs_actual; - helper::template init(rhs_actual); + helper::init(rhs_actual, seq2); res_actual_acc[0] = op(lhs_actual_acc[0], rhs_actual); }); }) @@ -760,28 +767,29 @@ class run_binary_assignment_sequence_sequence { } public: - void operator()(const std::string& function_name_1, - const std::string& function_name_2) { - if constexpr (is_shift_op() && - test_case_is_invalid_for_shift_op()) { - return; - } + void operator()() { + for (const init_sequence seq1 : all_init_sequences) { + for (const init_sequence seq2 : all_init_sequences) { + if (is_shift_op() && + test_case_is_invalid_for_shift_op(seq2, helper::NumElements)) + continue; - INFO("for lhs (sequence) \"" << function_name_1 << "\": "); - INFO("for rhs (sequence) \"" << function_name_2 << "\": "); + INFO("for lhs (sequence) \"" << get_sequence_name(seq1) << "\": "); + INFO("for rhs (sequence) \"" << get_sequence_name(seq2) << "\": "); - OpT op; + OpT op; - typename helper::varray_t lhs_expected(helper::NumElements); - helper::template init(lhs_expected); - typename helper::varray_t rhs_expected(helper::NumElements); - helper::template init(rhs_expected); - auto res_expected = op(lhs_expected, rhs_expected); + typename helper::varray_t lhs_expected(helper::NumElements); + helper::init(lhs_expected, seq1); + typename helper::varray_t rhs_expected(helper::NumElements); + helper::init(rhs_expected, seq2); + auto res_expected = op(lhs_expected, rhs_expected); - run_on_host(lhs_expected, res_expected); + run_on_host(seq1, seq2, lhs_expected, res_expected); - run_on_device(lhs_expected, res_expected); + run_on_device(seq1, seq2, lhs_expected, res_expected); + } + } } }; @@ -807,12 +815,8 @@ class run_binary_assignment< void operator()(const std::string& operator_name) { INFO("for operator \"" << operator_name << "\": "); - const auto constants = get_scalars(); - const auto functions = get_sequences(); - for_all_combinations(functions, constants); - for_all_combinations(functions, functions); + run_binary_assignment_sequence_scalar{}(); + run_binary_assignment_sequence_sequence{}(); } };