From 0a6c187f1f42a216e2985e0cfa9fd2ece5ff3587 Mon Sep 17 00:00:00 2001 From: zhangyunze Date: Mon, 16 Mar 2026 06:36:26 +0000 Subject: [PATCH 1/2] feat: Add RMSNorm op in cambricon backend. --- CMakeLists.txt | 2 +- src/CMakeLists.txt | 39 +++- src/base/rms_norm.h | 10 +- src/cambricon/cast.h | 99 ++++++++ src/cambricon/common.h | 50 ++++ src/cambricon/rmsnorm/kernel.mlu | 298 ++++++++++++++++++++++++ src/cambricon/rmsnorm/rms_norm.h | 93 ++++++++ src/common/cast.h | 11 - src/cpu/add/add.h | 2 +- src/{common => }/cpu/cast.h | 0 src/cpu/causal_softmax/causal_softmax.h | 2 +- src/cpu/gemm/gemm.h | 2 +- src/cpu/rms_norm/rms_norm.h | 2 +- src/cpu/swiglu/swiglu.h | 2 +- src/cuda/add/kernel.h | 1 + src/{common => }/cuda/cast.h | 0 src/cuda/causal_softmax/kernel.cuh | 3 +- src/{common => }/cuda/kernel_commons.h | 0 src/cuda/rms_norm/kernel.h | 12 +- src/cuda/swiglu/kernel.cuh | 2 +- src/data_type.h | 6 + 21 files changed, 605 insertions(+), 31 deletions(-) create mode 100644 src/cambricon/cast.h create mode 100644 src/cambricon/rmsnorm/kernel.mlu create mode 100644 src/cambricon/rmsnorm/rms_norm.h delete mode 100644 src/common/cast.h rename src/{common => }/cpu/cast.h (100%) rename src/{common => }/cuda/cast.h (100%) rename src/{common => }/cuda/kernel_commons.h (100%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 570b7d7..b9e2deb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -179,7 +179,7 @@ if(WITH_CAMBRICON) endif() # If all other platforms are not enabled, CPU is enabled by default. -if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE) +if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_MOORE AND NOT WITH_CAMBRICON) add_compile_definitions(WITH_CPU=1) endif() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3ca0715..585e3ab 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -127,11 +127,48 @@ if(WITH_MOORE) endif() if(WITH_CAMBRICON) - target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1) + file(GLOB_RECURSE CAMBRICON_MLU_SOURCES CONFIGURE_DEPENDS "cambricon/*/*.mlu") + find_program(CNCC_COMPILER cncc HINTS "${NEUWARE_HOME}/bin" "$ENV{NEUWARE_HOME}/bin" /usr/local/neuware/bin) + if(CNCC_COMPILER) + message(STATUS "Found cncc: ${CNCC_COMPILER}") + set(MLU_COMPILE_OPTS + -c --bang-mlu-arch=mtp_592 -O3 -fPIC -Wall -Werror -std=c++17 -pthread + -I${CMAKE_CURRENT_SOURCE_DIR} -I${NEUWARE_HOME}/include + -idirafter /usr/local/neuware/lib/clang/11.1.0/include + ) + function(compile_mlu_file src_file) + get_filename_component(name ${src_file} NAME_WE) + get_filename_component(path ${src_file} DIRECTORY) + set(out_file "${CMAKE_CURRENT_BINARY_DIR}/${path}/${name}.o") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/${path}") + add_custom_command(OUTPUT ${out_file} + COMMAND ${CNCC_COMPILER} ${MLU_COMPILE_OPTS} -c ${src_file} -o ${out_file} + DEPENDS ${src_file} + COMMENT "Building MLU kernel: ${src_file}" + ) + set_property(DIRECTORY APPEND PROPERTY CAMBRICON_OBJECTS ${out_file}) + endfunction() + foreach(src ${CAMBRICON_MLU_SOURCES}) + compile_mlu_file(${src}) + endforeach() + get_directory_property(CAMBRICON_OBJECT_FILES CAMBRICON_OBJECTS) + if(CAMBRICON_OBJECT_FILES) + target_sources(infiniops PRIVATE ${CAMBRICON_OBJECT_FILES}) + endif() + else() + message(WARNING "cncc compiler not found. MLU kernels will not be compiled.") + endif() + target_compile_definitions(infiniops PRIVATE WITH_CAMBRICON=1) target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include") target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB}) + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang") + target_compile_options(infiniops PUBLIC + "$<$:SHELL:-idirafter /usr/local/neuware/lib/clang/11.1.0/include>" + ) + endif() + list(APPEND DEVICE_LIST "cambricon") endif() diff --git a/src/base/rms_norm.h b/src/base/rms_norm.h index 65f44b3..dc28f0a 100644 --- a/src/base/rms_norm.h +++ b/src/base/rms_norm.h @@ -12,15 +12,17 @@ namespace infini::ops { class RmsNorm : public Operator { public: RmsNorm(const Tensor input, const Tensor weight, float eps, Tensor out) - : eps_{eps}, + : input_shape_{input.shape()}, out_shape_{out.shape()}, - input_shape_{input.shape()}, - out_strides_{out.strides()}, input_strides_{input.strides()}, + out_strides_{out.strides()}, + eps_{eps}, dim_{out.size(-1)}, ndim_{out.ndim()}, batch_size_{ndim_ == 2 ? out.size(-2) : out.size(-3)}, - nhead_{ndim_ == 2 ? 1 : out.size(-2)} {} + nhead_{ndim_ == 2 ? 1 : out.size(-2)} { + assert(input.dtype() == out.dtype()); + } RmsNorm(const Tensor input, const Tensor weight, Tensor out) : RmsNorm{input, weight, 1e-6f, out} {} diff --git a/src/cambricon/cast.h b/src/cambricon/cast.h new file mode 100644 index 0000000..fca4265 --- /dev/null +++ b/src/cambricon/cast.h @@ -0,0 +1,99 @@ +#ifndef INFINI_OPS_COMMON_CAMBRICON_CAST_H_ +#define INFINI_OPS_COMMON_CAMBRICON_CAST_H_ + +#include "bang_bf16.h" +#include "bang_fp16.h" +#include "data_type.h" + +namespace infini::ops { + +namespace detail { + +template +using PureType = std::remove_cv_t>; + +template +__host__ __device__ constexpr float ToFloatHelper(T&& x) { + using PureSrc = PureType; + if constexpr (IsBFloat16) { + return __bfloat162float__(x); + } else if constexpr (IsFP16) { + return __half2float(x); + } else { + return static_cast(std::forward(x)); + } +} + +template +__host__ __device__ constexpr Dst FromFloatHelper(float f) { + using PureDst = PureType; + if constexpr (IsBFloat16) { + return __float2bfloat16__(f); + } else if constexpr (IsFP16) { + return __float2half__(f); + } else { + return static_cast(f); + } +} + +// Priority tags for overload resolution. +struct PriorityLow {}; + +struct PriorityHigh : PriorityLow {}; + +// Fallback: lowest priority. This always matches if nothing else does. +template +__host__ __device__ constexpr Dst HardwareCast(Src&& x, PriorityLow) { + return FromFloatHelper(ToFloatHelper(std::forward(x))); +} + +// Usage: `DEFINE_DIRECT_CAST(INTRINSIC, CONDITION)`. +#define DEFINE_DIRECT_CAST(INTRINSIC, ...) \ + template \ + __host__ __device__ auto HardwareCast(Src x, PriorityHigh) \ + ->std::enable_if_t<(__VA_ARGS__), \ + decltype(INTRINSIC(std::declval()))> { \ + return INTRINSIC(x); \ + } + +DEFINE_DIRECT_CAST( + __bfloat162int_rz__, + std::is_same_v, int>&& IsBFloat16>) +DEFINE_DIRECT_CAST( + __bfloat162short_rz__, + std::is_same_v, short>&& IsBFloat16>) +DEFINE_DIRECT_CAST( + __int2bfloat16_rn__, + IsBFloat16>&& std::is_same_v, int>) +DEFINE_DIRECT_CAST(__int2half_rn__, + IsFP16>&& std::is_same_v, int>) +DEFINE_DIRECT_CAST( + __float2bfloat16__, + IsBFloat16>&& std::is_same_v, double>) +DEFINE_DIRECT_CAST( + __float2half__, + IsFP16>&& std::is_same_v, double>) +DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) +#undef DEFINE_DIRECT_CAST + +} // namespace detail + +template +__host__ __device__ Dst Cast(Src&& x) { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureSrc = std::remove_cv_t>; + using PureDst = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } else { + return detail::HardwareCast(std::forward(x), + detail::PriorityHigh{}); + } +} + +} // namespace infini::ops + +#endif diff --git a/src/cambricon/common.h b/src/cambricon/common.h index 50775c2..fc8ede0 100644 --- a/src/cambricon/common.h +++ b/src/cambricon/common.h @@ -2,19 +2,58 @@ #define INFINI_OPS_CAMBRICON_COMMON_H_ #include +#include #include "data_type.h" +#include "device.h" + +#define NRAM_MAX_SIZE (1024 * 240) + +#ifdef __BANG__ + +namespace infini::ops::reduce { + +constexpr int batch_size = 128 / sizeof(float); + +__mlu_func__ void SumInternal(float* dst, float* src, int max_batch) { + const int width = max_batch / batch_size; + + if (width >= 4) { + __bang_sumpool(dst, src, batch_size, 1, width, 1, width, 1, 1); + __bang_reduce_sum(dst, dst, batch_size); + } else { + float sum = 0.0f; + for (int i = 0; i < max_batch; ++i) { + sum += src[i]; + } + dst[0] = sum; + } +} + +} // namespace infini::ops::reduce + +#endif // __BANG__ namespace infini::ops::cnnl_utils { inline cnnlDataType_t GetDataType(DataType dtype) { switch (dtype) { + case DataType::kInt8: + return CNNL_DTYPE_INT8; + case DataType::kUInt8: + return CNNL_DTYPE_UINT8; case DataType::kInt32: return CNNL_DTYPE_INT32; + case DataType::kInt64: + return CNNL_DTYPE_INT64; case DataType::kFloat16: return CNNL_DTYPE_HALF; case DataType::kFloat32: return CNNL_DTYPE_FLOAT; + case DataType::kBFloat16: + return CNNL_DTYPE_BFLOAT16; + case DataType::kFloat64: + return CNNL_DTYPE_DOUBLE; default: return CNNL_DTYPE_INVALID; } @@ -22,4 +61,15 @@ inline cnnlDataType_t GetDataType(DataType dtype) { } // namespace infini::ops::cnnl_utils +namespace infini::ops::cnrt_utils { + +inline void GetLaunchConfig(const Device& device, int* core_per_cluster, + int* cluster_count) { + int device_id = device.index(); + cnrtDeviceGetAttribute(cluster_count, cnrtAttrClusterCount, device_id); + cnrtDeviceGetAttribute(core_per_cluster, cnrtAttrMcorePerCluster, device_id); +} + +} // namespace infini::ops::cnrt_utils + #endif diff --git a/src/cambricon/rmsnorm/kernel.mlu b/src/cambricon/rmsnorm/kernel.mlu new file mode 100644 index 0000000..6a4a25c --- /dev/null +++ b/src/cambricon/rmsnorm/kernel.mlu @@ -0,0 +1,298 @@ +#define WITH_CAMBRICON +#include "rms_norm.h" + +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +namespace infini::ops { + +template +__mlu_global__ void rmsnorm(T *output, const T *input, const Tw *weight, + size_t *shape, ptrdiff_t *output_strides, ptrdiff_t *input_strides, + float epsilon, int num_dims, int norm_dim_size) { + // Calculate problem dimensions + int batch_volume = 1; + for (int dim_idx = 0; dim_idx < num_dims - 1; ++dim_idx) { + batch_volume *= shape[dim_idx]; + } + int vector_size = shape[num_dims - 1]; + + // Task distribution across cores + int remaining_tasks = batch_volume % taskDim; + int base_tasks_per_core = batch_volume / taskDim; + int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0); + int task_start_idx = (taskId < remaining_tasks ? taskId * (base_tasks_per_core + 1) : remaining_tasks * (base_tasks_per_core + 1) + (taskId - remaining_tasks) * base_tasks_per_core); + + // Determine optimal batch size based on vector size + int max_batch_size; + if (vector_size <= 64) { + // For small vectors, process the entire vector at once + max_batch_size = vector_size; + } else { + // For larger vectors, use optimized batch size + max_batch_size = (NRAM_MAX_SIZE - 256) / (2 * sizeof(T) + sizeof(Tw) + sizeof(float)); + max_batch_size = std::min(max_batch_size, vector_size); + max_batch_size = (max_batch_size / 64) * 64; // Align to 64 elements + } + + constexpr int reduce_buffer_size = 128 / sizeof(float); + + // NRAM buffer allocation with dynamic sizing + float *reduction_buffer = (float *)nram_buffer; + T *input_cache = (T *)(reduction_buffer + reduce_buffer_size); + Tw *weight_cache = (Tw *)(input_cache + max_batch_size); + float *float_buffer = (float *)(weight_cache + max_batch_size); + float *weight_float_buffer = (float *)(float_buffer + max_batch_size); + + // Process vectors assigned to current core + for (int task_idx = 0; task_idx < actual_tasks; ++task_idx) { + int current_index = task_start_idx + task_idx; + + // Calculate memory offsets for current task + int input_offset = 0; + int output_offset = 0; + int temp_index = current_index; + + for (int dim = 0; dim < num_dims - 1; ++dim) { + int dim_coord = temp_index % shape[dim]; + input_offset += dim_coord * input_strides[dim]; + output_offset += dim_coord * output_strides[dim]; + temp_index /= shape[dim]; + } + + // Compute sum of squares + float sum_squared = 0.0f; + + if (vector_size <= 128) { + // Small vector optimization: process entire vector at once + __memcpy(input_cache, input + input_offset, vector_size * sizeof(T), GDRAM2NRAM); + + // Convert to float and square + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, vector_size); + } else { + __memcpy(float_buffer, input_cache, vector_size * sizeof(float), NRAM2NRAM); + } + + __bang_mul(float_buffer, float_buffer, float_buffer, vector_size); + + // Direct accumulation for small vectors + for (int i = 0; i < vector_size; ++i) { + sum_squared += float_buffer[i]; + } + } else { + // Large vector processing with chunking + __bang_write_value(reduction_buffer, reduce_buffer_size, 0); + size_t processed_elements = 0; + + while (processed_elements < vector_size) { + size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements); + + // Load input data + __memcpy(input_cache, input + input_offset + processed_elements * input_strides[num_dims - 1], + current_batch * sizeof(T), GDRAM2NRAM); + + // Convert to float and square + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, current_batch); + } else { + __memcpy(float_buffer, input_cache, current_batch * sizeof(float), NRAM2NRAM); + } + + __bang_mul(float_buffer, float_buffer, float_buffer, current_batch); + + // Accumulate squared values + float batch_sum = 0.0f; + if (current_batch >= 128) { + infini::ops::reduce::SumInternal(reduction_buffer, float_buffer, current_batch); + batch_sum = reduction_buffer[0]; + } else { + for (size_t i = 0; i < current_batch; ++i) { + batch_sum += float_buffer[i]; + } + } + + sum_squared += batch_sum; + processed_elements += current_batch; + } + } + + // Compute normalization factor + float rms_value = sqrtf(sum_squared / vector_size + epsilon); + float inv_rms = 1.0f / rms_value; + + // Process vector for normalization + if (vector_size <= max_batch_size) { + // Process entire vector at once for small vectors + __memcpy(input_cache, input + input_offset, vector_size * sizeof(T), GDRAM2NRAM); + __memcpy(weight_cache, weight, vector_size * sizeof(Tw), GDRAM2NRAM); + + // Convert input to float + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, vector_size); + } else { + __memcpy(float_buffer, input_cache, vector_size * sizeof(float), NRAM2NRAM); + } + + // Convert weight to float if needed + if constexpr (std::is_same::value) { + __bang_half2float(weight_float_buffer, reinterpret_cast(weight_cache), vector_size); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(weight_float_buffer, weight_cache, vector_size); + } else { + __memcpy(weight_float_buffer, weight_cache, vector_size * sizeof(float), NRAM2NRAM); + } + + // Multiply by weight and apply normalization + __bang_mul(float_buffer, float_buffer, weight_float_buffer, vector_size); + __bang_mul_scalar(float_buffer, float_buffer, inv_rms, vector_size); + + // Convert back to output type + if constexpr (std::is_same::value) { + __bang_float2half(reinterpret_cast(input_cache), float_buffer, vector_size); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(input_cache, float_buffer, vector_size); + } else { + __memcpy(input_cache, float_buffer, vector_size * sizeof(float), NRAM2NRAM); + } + + // Store results + __memcpy(output + output_offset, input_cache, vector_size * sizeof(T), NRAM2GDRAM); + } else { + // Large vector processing with chunking + size_t processed_elements = 0; + while (processed_elements < vector_size) { + size_t current_batch = std::min((size_t)max_batch_size, vector_size - processed_elements); + + // Load input and weight data + __memcpy(input_cache, input + input_offset + processed_elements * input_strides[num_dims - 1], + current_batch * sizeof(T), GDRAM2NRAM); + __memcpy(weight_cache, weight + processed_elements, current_batch * sizeof(Tw), GDRAM2NRAM); + + // Convert input to float + if constexpr (std::is_same::value) { + __bang_half2float(float_buffer, reinterpret_cast(input_cache), current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(float_buffer, input_cache, current_batch); + } else { + __memcpy(float_buffer, input_cache, current_batch * sizeof(float), NRAM2NRAM); + } + + // Convert weight to float if needed + if constexpr (std::is_same::value) { + __bang_half2float(weight_float_buffer, reinterpret_cast(weight_cache), current_batch); + } else if constexpr (std::is_same::value) { + __bang_bfloat162float(weight_float_buffer, weight_cache, current_batch); + } else { + __memcpy(weight_float_buffer, weight_cache, current_batch * sizeof(float), NRAM2NRAM); + } + + // Multiply by weight and apply normalization + __bang_mul(float_buffer, float_buffer, weight_float_buffer, current_batch); + __bang_mul_scalar(float_buffer, float_buffer, inv_rms, current_batch); + + // Convert back to output type + if constexpr (std::is_same::value) { + __bang_float2half(reinterpret_cast(input_cache), float_buffer, current_batch); + } else if constexpr (std::is_same::value) { + __bang_float2bfloat16(input_cache, float_buffer, current_batch); + } else { + __memcpy(input_cache, float_buffer, current_batch * sizeof(float), NRAM2NRAM); + } + + // Store results + __memcpy(output + output_offset + processed_elements * output_strides[num_dims - 1], + input_cache, current_batch * sizeof(T), NRAM2GDRAM); + + processed_elements += current_batch; + } + } + } +} + +template +void RmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, cnrtQueue_t queue, void *y, const void *x, const void *w, const size_t *shape, const ptrdiff_t *y_strides, const ptrdiff_t *x_strides, float eps, int ndim) { + cnrtDim3_t kernel_dim; + cnrtFunctionType_t kernel_type; + + // Configure kernel dimensions + kernel_dim.x = core_per_cluster; + kernel_dim.y = cluster_count; + kernel_dim.z = 1; + kernel_type = cnrtFuncTypeUnion1; // Can choose others, but must adapt kernel_type accordingly + int dimsize = shape[ndim - 1]; // Length of operation dimension + int dim_s; // dim_s is the next power of 2 greater than dimsize + float mi = log2(dimsize); + if (floor(mi) == mi) { + dim_s = dimsize; + } else { + dim_s = pow(2, floor(mi) + 1); + } + constexpr int reduce_num = 128 / sizeof(float); // Cambricon __bang_reduce_sum can only reduce 128 bytes at a time + if (dim_s < reduce_num) { + dim_s = reduce_num; // Force dim_s >= reduce_num + } + + // Prepare device pointers + auto y_ = reinterpret_cast(y); + auto x_ = reinterpret_cast(x); + auto w_ = reinterpret_cast(w); + char *tmp_device = reinterpret_cast(workspace); + char *tmp_stride = tmp_device + ndim * sizeof(size_t); + size_t *mlu_shape = (size_t *)tmp_device; + ptrdiff_t *mlu_x_strides = (ptrdiff_t *)tmp_stride; + ptrdiff_t *mlu_y_strides = mlu_x_strides + ndim; + + // Copy shape and stride information to device + CNRT_CHECK(cnrtMemcpyAsync(mlu_shape, const_cast(shape), ndim * sizeof(size_t), queue, cnrtMemcpyHostToDev)); // const not supported + CNRT_CHECK(cnrtMemcpyAsync(mlu_x_strides, const_cast(x_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + CNRT_CHECK(cnrtMemcpyAsync(mlu_y_strides, const_cast(y_strides), ndim * sizeof(ptrdiff_t), queue, cnrtMemcpyHostToDev)); + + // Launch kernel + rmsnorm<<>>(y_, x_, w_, mlu_shape, mlu_y_strides, mlu_x_strides, eps, ndim, dim_s); + + cnrtQueueSync(queue); +} + +template void RmsnormUnion<__half, __half>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion<__half, __bang_bfloat16>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion<__half, float>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion<__bang_bfloat16, __half>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion<__bang_bfloat16, __bang_bfloat16>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion<__bang_bfloat16, float>( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +template void RmsnormUnion( + void *, int, int, cnrtQueue_t, void *, const void *, const void *, + const size_t *, const ptrdiff_t *, const ptrdiff_t *, float, int); + +} // namespace infini::ops diff --git a/src/cambricon/rmsnorm/rms_norm.h b/src/cambricon/rmsnorm/rms_norm.h new file mode 100644 index 0000000..a4540eb --- /dev/null +++ b/src/cambricon/rmsnorm/rms_norm.h @@ -0,0 +1,93 @@ +#ifndef INFINI_OPS_CAMBRICON_RMS_NORM_H_ +#define INFINI_OPS_CAMBRICON_RMS_NORM_H_ + +#include +#include +#include + +#include "../common.h" +#include "base/rms_norm.h" + +namespace infini::ops { + +template +void RmsnormUnion(void *workspace, int core_per_cluster, int cluster_count, + cnrtQueue_t queue, void *y, const void *x, const void *w, + const size_t *shape, const ptrdiff_t *y_strides, + const ptrdiff_t *x_strides, float eps, int ndim); + +template <> +class Operator : public RmsNorm { + public: + Operator(const Tensor input, const Tensor weight, float eps, Tensor out) + : RmsNorm{input, weight, eps, out} { + cnrt_utils::GetLaunchConfig(input.device(), &core_per_cluster, + &cluster_count); + cnrtMalloc(&default_workspace_, workspace_size_in_bytes()); + } + + void operator()(const Tensor input, const Tensor weight, float eps, + Tensor out) const override { + auto queue = static_cast(stream_ ? stream_ : 0); + auto workspace{workspace_ ? workspace_ : default_workspace_}; + + DispatchFunc< + List, + List>( + {static_cast(input.dtype()), + static_cast(Device::Type::kCambricon)}, + 0, + [&](auto input_tag) { + constexpr DataType IDT = static_cast(ListGet<0>(input_tag)); + using InputT = TypeMapType; + DispatchFunc< + List, + List>( + {static_cast(weight.dtype()), + static_cast(Device::Type::kCambricon)}, + 0, + [&](auto weight_tag) { + constexpr DataType WDT = + static_cast(ListGet<0>(weight_tag)); + using WeightT = TypeMapType; + + RmsnormUnion( + workspace, core_per_cluster, cluster_count, queue, + out.data(), input.data(), weight.data(), out_shape_.data(), + out_strides_.data(), input_strides_.data(), eps, ndim_); + }, + "CambriconRmsNorm::operator() - weight dispatch", List<>{}); + }, + "CambriconRmsNorm::operator() - output dispatch", List<>{}); + // DispatchFunc, + // List>( + // {input.dtype(), weight.dtype()}, + // [&](auto input_tag, auto weight_tag) { + // using InputT = typename decltype(input_tag)::type; + // using WeightT = typename decltype(weight_tag)::type; + + // RmsnormUnion( + // workspace, core_per_cluster, cluster_count, queue, + // out.data(), input.data(), weight.data(), + // out_shape_.data(), out_strides_.data(), + // input_strides_.data(), eps, ndim_); + // }, + // "CambriconRmsNorm::operator()"); + } + + ~Operator() { cnrtFree(default_workspace_); } + + std::size_t workspace_size_in_bytes() const override { + return ndim_ * (sizeof(size_t) + 2 * sizeof(ptrdiff_t)); + } + + void *default_workspace_{nullptr}; + int core_per_cluster = 0; + int cluster_count = 0; +}; + +} // namespace infini::ops + +#endif diff --git a/src/common/cast.h b/src/common/cast.h deleted file mode 100644 index a37fb94..0000000 --- a/src/common/cast.h +++ /dev/null @@ -1,11 +0,0 @@ -#ifndef INFINI_OPS_COMMON_CAST_H_ -#define INFINI_OPS_COMMON_CAST_H_ - -#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR) || defined(WITH_METAX) || \ - defined(WITH_MOORE) -#include "common/cuda/cast.h" -#else -#include "common/cpu/cast.h" -#endif - -#endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index ec605c3..89675a3 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -4,8 +4,8 @@ #include #include "base/add.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/cast.h" namespace infini::ops { diff --git a/src/common/cpu/cast.h b/src/cpu/cast.h similarity index 100% rename from src/common/cpu/cast.h rename to src/cpu/cast.h diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index ca207a2..f159d42 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -4,8 +4,8 @@ #include #include "base/causal_softmax.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/cast.h" #include "data_type.h" #include "tensor.h" diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index 685a94a..96b1d37 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -4,8 +4,8 @@ #include #include "base/gemm.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/cast.h" namespace infini::ops { diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index b3caeb0..2130377 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -4,8 +4,8 @@ #include #include "base/rms_norm.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/cast.h" #include "data_type.h" #include "tensor.h" diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index ac2b3b2..065aad5 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -4,8 +4,8 @@ #include #include "base/swiglu.h" -#include "common/cast.h" #include "common/generic_utils.h" +#include "cpu/cast.h" namespace infini::ops { diff --git a/src/cuda/add/kernel.h b/src/cuda/add/kernel.h index c174afb..e3a5bd0 100644 --- a/src/cuda/add/kernel.h +++ b/src/cuda/add/kernel.h @@ -6,6 +6,7 @@ #include "base/add.h" #include "common/generic_utils.h" #include "cuda/add/kernel.cuh" +#include "cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/common/cuda/cast.h b/src/cuda/cast.h similarity index 100% rename from src/common/cuda/cast.h rename to src/cuda/cast.h diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh index d578998..ea241c2 100644 --- a/src/cuda/causal_softmax/kernel.cuh +++ b/src/cuda/causal_softmax/kernel.cuh @@ -75,8 +75,7 @@ __global__ void CausalSoftmaxKernel( for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { if (col < valid_len) { - Compute diff = - Cast(input_row[col]) - Cast(max_val); + Compute diff = Cast(input_row[col]) - Cast(max_val); out_row[col] = ExpAndCast(diff); } else { out_row[col] = Cast(0.0f); diff --git a/src/common/cuda/kernel_commons.h b/src/cuda/kernel_commons.h similarity index 100% rename from src/common/cuda/kernel_commons.h rename to src/cuda/kernel_commons.h diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index dc28ee5..48ab21e 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -41,12 +41,12 @@ class CudaRmsNorm : public RmsNorm { [&](auto tag) { using T = typename decltype(tag)::type; -#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ - RmsNormKernel \ - <<>>( \ - reinterpret_cast(out.data()), stride_out_batch, \ - stride_out_nhead, reinterpret_cast(input.data()), \ - stride_input_batch, stride_input_nhead, \ +#define LAUNCH_RMS_NORM_KERNEL(BLOCK_SIZE) \ + RmsNormKernel \ + <<>>( \ + reinterpret_cast(out.data()), stride_out_batch, \ + stride_out_nhead, reinterpret_cast(input.data()), \ + stride_input_batch, stride_input_nhead, \ reinterpret_cast(weight.data()), nhead_, dim_, eps); if (block_size == CUDA_BLOCK_SIZE_2048) { diff --git a/src/cuda/swiglu/kernel.cuh b/src/cuda/swiglu/kernel.cuh index 8004b76..f3997e6 100644 --- a/src/cuda/swiglu/kernel.cuh +++ b/src/cuda/swiglu/kernel.cuh @@ -3,7 +3,7 @@ #include -#include "common/cuda/kernel_commons.h" +#include "cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/data_type.h b/src/data_type.h index af2aec7..ce2adfe 100644 --- a/src/data_type.h +++ b/src/data_type.h @@ -17,6 +17,9 @@ #elif defined(WITH_MOORE) #include #include +#elif defined(WITH_CAMBRICON) +#include "bang_bf16.h" +#include "bang_fp16.h" #endif #include "common/constexpr_map.h" @@ -207,6 +210,9 @@ DEFINE_DATA_TYPE_MAPPING(kBFloat16, __maca_bfloat16) #elif defined(WITH_MOORE) DEFINE_DATA_TYPE_MAPPING(kFloat16, half) DEFINE_DATA_TYPE_MAPPING(kBFloat16, __mt_bfloat16) +#elif defined(WITH_CAMBRICON) +DEFINE_DATA_TYPE_MAPPING(kFloat16, __half) +DEFINE_DATA_TYPE_MAPPING(kBFloat16, __bang_bfloat16) #else DEFINE_DATA_TYPE_MAPPING(kFloat16, Float16) DEFINE_DATA_TYPE_MAPPING(kBFloat16, BFloat16) From 1d6fe7113b26ff3ce213a2a6df9f748b2cce9b3b Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Fri, 20 Mar 2026 17:50:18 +0800 Subject: [PATCH 2/2] refactor: make `Cast` utility to use `Device::Type` template parameter --- src/cambricon/{cast.h => cast_.h} | 30 +++++++------ src/cast.h | 25 +++++++++++ src/cpu/add/add.h | 7 +-- src/cpu/cast.h | 57 ------------------------ src/cpu/cast_.h | 59 +++++++++++++++++++++++++ src/cpu/causal_softmax/causal_softmax.h | 19 ++++---- src/cpu/gemm/gemm.h | 15 ++++--- src/cpu/rms_norm/rms_norm.h | 9 ++-- src/cpu/swiglu/swiglu.h | 9 ++-- src/cuda/add/kernel.cuh | 2 +- src/cuda/{cast.h => cast_.h} | 4 +- src/cuda/causal_softmax/kernel.cuh | 19 ++++---- src/cuda/causal_softmax/kernel.h | 2 +- src/cuda/rms_norm/kernel.cuh | 12 ++--- src/cuda/rms_norm/kernel.h | 2 +- src/operator.h | 5 ++- 16 files changed, 160 insertions(+), 116 deletions(-) rename src/cambricon/{cast.h => cast_.h} (82%) create mode 100644 src/cast.h delete mode 100644 src/cpu/cast.h create mode 100644 src/cpu/cast_.h rename src/cuda/{cast.h => cast_.h} (95%) diff --git a/src/cambricon/cast.h b/src/cambricon/cast_.h similarity index 82% rename from src/cambricon/cast.h rename to src/cambricon/cast_.h index fca4265..465dab5 100644 --- a/src/cambricon/cast.h +++ b/src/cambricon/cast_.h @@ -3,7 +3,7 @@ #include "bang_bf16.h" #include "bang_fp16.h" -#include "data_type.h" +#include "cast.h" namespace infini::ops { @@ -79,20 +79,22 @@ DEFINE_DIRECT_CAST(__half, IsFP16>&& IsBFloat16>) } // namespace detail template -__host__ __device__ Dst Cast(Src&& x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureSrc = std::remove_cv_t>; - using PureDst = std::remove_cv_t>; - - if constexpr (std::is_same_v) { - return std::forward(x); - } else { - return detail::HardwareCast(std::forward(x), - detail::PriorityHigh{}); +struct CastHelper { + Dst operator()(Src&& x) const { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureSrc = std::remove_cv_t>; + using PureDst = std::remove_cv_t>; + + if constexpr (std::is_same_v) { + return std::forward(x); + } else { + return detail::HardwareCast( + std::forward(x), detail::PriorityHigh{}); + } } -} +}; } // namespace infini::ops diff --git a/src/cast.h b/src/cast.h new file mode 100644 index 0000000..51137d3 --- /dev/null +++ b/src/cast.h @@ -0,0 +1,25 @@ +#ifndef INFINI_OPS_CAST_H_ +#define INFINI_OPS_CAST_H_ + +#include "data_type.h" +#include "device.h" + +namespace infini::ops { + +namespace detail { + +template +struct CastHelper { + Dst operator()(const Src& x) const { return static_cast(x); }; +}; + +} // namespace detail + +template +Dst Cast(Src&& x) { + return detail::CastHelper{}(x); +} + +} // namespace infini::ops + +#endif diff --git a/src/cpu/add/add.h b/src/cpu/add/add.h index 89675a3..604c6db 100644 --- a/src/cpu/add/add.h +++ b/src/cpu/add/add.h @@ -5,7 +5,7 @@ #include "base/add.h" #include "common/generic_utils.h" -#include "cpu/cast.h" +#include "cpu/cast_.h" namespace infini::ops { @@ -52,8 +52,9 @@ class Operator : public Add { auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), out_strides_.data()); - out_ptr[out_idx] = Cast(Cast(input_ptr[input_idx]) + - Cast(other_ptr[other_idx])); + out_ptr[out_idx] = Cast( + Cast(input_ptr[input_idx]) + + Cast(other_ptr[other_idx])); } } }; diff --git a/src/cpu/cast.h b/src/cpu/cast.h deleted file mode 100644 index 68b95fc..0000000 --- a/src/cpu/cast.h +++ /dev/null @@ -1,57 +0,0 @@ -#ifndef INFINI_OPS_COMMON_CPU_CAST_H_ -#define INFINI_OPS_COMMON_CPU_CAST_H_ - -#include "data_type.h" - -namespace infini::ops { - -namespace detail { - -template -constexpr float ToFloatHelper(T &&x) { - using PureSrc = std::remove_cv_t >; - if constexpr (IsBFloat16 || IsFP16) { - return std::forward(x).ToFloat(); - } else { - return static_cast(std::forward(x)); - } -} - -template -constexpr Dst FromFloatHelper(float f) { - using PureDst = std::remove_cv_t >; - if constexpr (IsBFloat16 || IsFP16) { - return PureDst::FromFloat(f); - } else { - return static_cast(f); - } -} - -} // namespace detail - -template -Dst Cast(Src &&x) { - static_assert(!std::is_reference_v, - "`Cast` cannot return reference types"); - - using PureDst = std::remove_cv_t >; - using PureSrc = std::remove_cv_t >; - - if constexpr (std::is_same_v) { - return std::forward(x); - } - - constexpr bool src_is_custom = IsBFloat16 || IsFP16; - constexpr bool dst_is_custom = IsBFloat16 || IsFP16; - - if constexpr (!src_is_custom && !dst_is_custom) { - return static_cast(std::forward(x)); - } else { - return detail::FromFloatHelper( - detail::ToFloatHelper(std::forward(x))); - } -} - -} // namespace infini::ops - -#endif diff --git a/src/cpu/cast_.h b/src/cpu/cast_.h new file mode 100644 index 0000000..c5a5bed --- /dev/null +++ b/src/cpu/cast_.h @@ -0,0 +1,59 @@ +#ifndef INFINI_OPS_COMMON_CPU_CAST_H_ +#define INFINI_OPS_COMMON_CPU_CAST_H_ + +#include "cast.h" + +namespace infini::ops { + +namespace detail { + +template +constexpr float ToFloatHelper(T&& x) { + using PureSrc = std::remove_cv_t >; + if constexpr (IsBFloat16 || IsFP16) { + return std::forward(x).ToFloat(); + } else { + return static_cast(std::forward(x)); + } +} + +template +constexpr Dst FromFloatHelper(float f) { + using PureDst = std::remove_cv_t >; + if constexpr (IsBFloat16 || IsFP16) { + return PureDst::FromFloat(f); + } else { + return static_cast(f); + } +} + +template +struct CastHelper { + Dst operator()(Src&& x) const { + static_assert(!std::is_reference_v, + "`Cast` cannot return reference types"); + + using PureDst = std::remove_cv_t >; + using PureSrc = std::remove_cv_t >; + + if constexpr (std::is_same_v) { + return std::forward(x); + } + + constexpr bool src_is_custom = IsBFloat16 || IsFP16; + constexpr bool dst_is_custom = IsBFloat16 || IsFP16; + + if constexpr (!src_is_custom && !dst_is_custom) { + return static_cast(std::forward(x)); + } else { + return detail::FromFloatHelper( + detail::ToFloatHelper(std::forward(x))); + } + } +}; + +} // namespace detail + +} // namespace infini::ops + +#endif diff --git a/src/cpu/causal_softmax/causal_softmax.h b/src/cpu/causal_softmax/causal_softmax.h index f159d42..679e4b3 100644 --- a/src/cpu/causal_softmax/causal_softmax.h +++ b/src/cpu/causal_softmax/causal_softmax.h @@ -5,7 +5,7 @@ #include "base/causal_softmax.h" #include "common/generic_utils.h" -#include "cpu/cast.h" +#include "cpu/cast_.h" #include "data_type.h" #include "tensor.h" @@ -49,12 +49,12 @@ class Operator : public CausalSoftmax { Tensor::Size valid_len = total_seq_len_ - seq_len_ + i + 1; for (Tensor::Size j = valid_len; j < total_seq_len_; ++j) { - out_row[j * out_stride_j] = Cast(0.0f); + out_row[j * out_stride_j] = Cast(0.0f); } - float max_val = Cast(input_row[0]); + float max_val = Cast(input_row[0]); for (Tensor::Size j = 1; j < valid_len; ++j) { - float v = Cast(input_row[j * input_stride_j]); + float v = Cast(input_row[j * input_stride_j]); if (v > max_val) { max_val = v; } @@ -62,15 +62,16 @@ class Operator : public CausalSoftmax { float sum = 0.0f; for (Tensor::Size j = 0; j < valid_len; ++j) { - float v = - std::exp(Cast(input_row[j * input_stride_j]) - max_val); - out_row[j * out_stride_j] = Cast(v); + float v = std::exp( + Cast(input_row[j * input_stride_j]) - + max_val); + out_row[j * out_stride_j] = Cast(v); sum += v; } for (Tensor::Size j = 0; j < valid_len; ++j) { - out_row[j * out_stride_j] = - Cast(Cast(out_row[j * out_stride_j]) / sum); + out_row[j * out_stride_j] = Cast( + Cast(out_row[j * out_stride_j]) / sum); } } } diff --git a/src/cpu/gemm/gemm.h b/src/cpu/gemm/gemm.h index 96b1d37..afa95b1 100644 --- a/src/cpu/gemm/gemm.h +++ b/src/cpu/gemm/gemm.h @@ -5,7 +5,7 @@ #include "base/gemm.h" #include "common/generic_utils.h" -#include "cpu/cast.h" +#include "cpu/cast_.h" namespace infini::ops { @@ -78,14 +78,19 @@ class Operator : public Gemm { float sum = 0.0f; for (Tensor::Size l = 0; l < k_; ++l) { - float a_val = Cast(A_batch[i * stride_a_m + l * stride_a_k]); - float b_val = Cast(B_batch[l * stride_b_k + j * stride_b_n]); + float a_val = Cast( + A_batch[i * stride_a_m + l * stride_a_k]); + float b_val = Cast( + B_batch[l * stride_b_k + j * stride_b_n]); sum += a_val * b_val; } Tensor::Size idx = i * stride_c_m + j * stride_c_n; - float c_val = beta_value == 0.0f ? 0.0f : Cast(C_batch[idx]); - C_batch[idx] = Cast(alpha_value * sum + beta_value * c_val); + float c_val = beta_value == 0.0f + ? 0.0f + : Cast(C_batch[idx]); + C_batch[idx] = + Cast(alpha_value * sum + beta_value * c_val); } } } diff --git a/src/cpu/rms_norm/rms_norm.h b/src/cpu/rms_norm/rms_norm.h index 2130377..904822c 100644 --- a/src/cpu/rms_norm/rms_norm.h +++ b/src/cpu/rms_norm/rms_norm.h @@ -5,7 +5,7 @@ #include "base/rms_norm.h" #include "common/generic_utils.h" -#include "cpu/cast.h" +#include "cpu/cast_.h" #include "data_type.h" #include "tensor.h" @@ -50,14 +50,15 @@ class Operator : public RmsNorm { float ss = 0; for (Tensor::Size k = 0; k < dim_; ++k) { - float v = Cast(input_row[k]); + float v = Cast(input_row[k]); ss += v * v; } float rms = 1.f / std::sqrt(ss / static_cast(dim_) + eps); for (Tensor::Size k = 0; k < dim_; ++k) { - out_row[k] = Cast(Cast(input_row[k]) * - Cast(weight_ptr[k]) * rms); + out_row[k] = Cast( + Cast(input_row[k]) * + Cast(weight_ptr[k]) * rms); } } } diff --git a/src/cpu/swiglu/swiglu.h b/src/cpu/swiglu/swiglu.h index 065aad5..28cec25 100644 --- a/src/cpu/swiglu/swiglu.h +++ b/src/cpu/swiglu/swiglu.h @@ -5,7 +5,7 @@ #include "base/swiglu.h" #include "common/generic_utils.h" -#include "cpu/cast.h" +#include "cpu/cast_.h" namespace infini::ops { @@ -48,12 +48,13 @@ class Operator : public Swiglu { gate_strides_.data()); auto out_idx = get_idx(i, is_out_contiguous_, out_shape_.data(), out_strides_.data()); - const ComputeType gate_val = Cast(gate_ptr[gate_idx]); + const ComputeType gate_val = + Cast(gate_ptr[gate_idx]); const ComputeType sigmoid_gate = static_cast( 1.0 / (1.0 + std::exp(-static_cast(gate_val)))); const ComputeType swish_gate = gate_val * sigmoid_gate; - out_ptr[out_idx] = - Cast(Cast(input_ptr[input_idx]) * swish_gate); + out_ptr[out_idx] = Cast( + Cast(input_ptr[input_idx]) * swish_gate); } } }; diff --git a/src/cuda/add/kernel.cuh b/src/cuda/add/kernel.cuh index 2d58809..cfd6496 100644 --- a/src/cuda/add/kernel.cuh +++ b/src/cuda/add/kernel.cuh @@ -1,7 +1,7 @@ #ifndef INFINI_OPS_CUDA_ADD_KERNEL_CUH_ #define INFINI_OPS_CUDA_ADD_KERNEL_CUH_ -#include "common/cuda/kernel_commons.h" +#include "cuda/kernel_commons.h" namespace infini::ops { diff --git a/src/cuda/cast.h b/src/cuda/cast_.h similarity index 95% rename from src/cuda/cast.h rename to src/cuda/cast_.h index 1f67a44..eae6127 100644 --- a/src/cuda/cast.h +++ b/src/cuda/cast_.h @@ -97,8 +97,8 @@ __host__ __device__ Dst Cast(Src&& x) { if constexpr (std::is_same_v) { return std::forward(x); } else { - return detail::HardwareCast(std::forward(x), - detail::PriorityHigh{}); + return detail::HardwareCast(std::forward(x), + detail::PriorityHigh{}); } } diff --git a/src/cuda/causal_softmax/kernel.cuh b/src/cuda/causal_softmax/kernel.cuh index ea241c2..f6b0be2 100644 --- a/src/cuda/causal_softmax/kernel.cuh +++ b/src/cuda/causal_softmax/kernel.cuh @@ -5,8 +5,8 @@ #include #include -#include "common/cuda/cast.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/cast_.h" +#include "cuda/kernel_commons.h" namespace infini::ops { @@ -14,7 +14,7 @@ namespace { template __device__ __forceinline__ Data ExpAndCast(Compute x) { - return Cast(expf(Cast(x))); + return Cast(expf(Cast(x))); } struct BlockMaxOp { @@ -41,7 +41,7 @@ __device__ __forceinline__ Compute BlockSum(const Data* data_ptr, size_t count) { Compute thread_sum = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - thread_sum += Cast(data_ptr[i]); + thread_sum += Cast(data_ptr[i]); } using BlockReduce = cub::BlockReduce; __shared__ typename BlockReduce::TempStorage temp_storage; @@ -75,10 +75,11 @@ __global__ void CausalSoftmaxKernel( for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { if (col < valid_len) { - Compute diff = Cast(input_row[col]) - Cast(max_val); - out_row[col] = ExpAndCast(diff); + Compute diff = Cast(input_row[col]) - + Cast(max_val); + out_row[col] = ExpAndCast(diff); } else { - out_row[col] = Cast(0.0f); + out_row[col] = Cast(0.0f); } } __syncthreads(); @@ -92,8 +93,8 @@ __global__ void CausalSoftmaxKernel( __syncthreads(); for (size_t col = threadIdx.x; col < total_seq_len; col += block_size) { - Compute quot = Cast(out_row[col]) / sum_val; - out_row[col] = Cast(quot); + Compute quot = Cast(out_row[col]) / sum_val; + out_row[col] = Cast(quot); } } diff --git a/src/cuda/causal_softmax/kernel.h b/src/cuda/causal_softmax/kernel.h index 924be40..a320f63 100644 --- a/src/cuda/causal_softmax/kernel.h +++ b/src/cuda/causal_softmax/kernel.h @@ -4,8 +4,8 @@ #include #include "base/causal_softmax.h" -#include "common/cuda/kernel_commons.h" #include "cuda/causal_softmax/kernel.cuh" +#include "cuda/kernel_commons.h" #include "data_type.h" #include "dispatcher.h" diff --git a/src/cuda/rms_norm/kernel.cuh b/src/cuda/rms_norm/kernel.cuh index 10228a6..69db534 100644 --- a/src/cuda/rms_norm/kernel.cuh +++ b/src/cuda/rms_norm/kernel.cuh @@ -5,8 +5,8 @@ #include #include -#include "common/cuda/cast.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/cast_.h" +#include "cuda/kernel_commons.h" namespace infini::ops { @@ -17,7 +17,7 @@ __device__ __forceinline__ TCompute SumSquared(const TData* data_ptr, size_t count) { TCompute ss = 0; for (size_t i = threadIdx.x; i < count; i += block_size) { - TCompute value = Cast(data_ptr[i]); + TCompute value = Cast(data_ptr[i]); ss += value * value; } using BlockReduce = cub::BlockReduce; @@ -46,13 +46,15 @@ __global__ void RmsNormKernel(TData* __restrict__ y, int64_t stride_y_batch, __shared__ TCompute rms; if (threadIdx.x == 0) { - rms = Cast(rsqrtf(ss / Cast(dim) + epsilon)); + rms = Cast( + rsqrtf(ss / Cast(dim) + epsilon)); } __syncthreads(); for (size_t i = threadIdx.x; i < dim; i += block_size) { y_ptr[i] = - Cast(Cast(x_ptr[i]) * Cast(w_ptr[i]) * rms); + Cast(Cast(x_ptr[i]) * + Cast(w_ptr[i]) * rms); } } diff --git a/src/cuda/rms_norm/kernel.h b/src/cuda/rms_norm/kernel.h index 48ab21e..3f61c50 100644 --- a/src/cuda/rms_norm/kernel.h +++ b/src/cuda/rms_norm/kernel.h @@ -4,7 +4,7 @@ #include #include "base/rms_norm.h" -#include "common/cuda/kernel_commons.h" +#include "cuda/kernel_commons.h" #include "cuda/rms_norm/kernel.cuh" #include "data_type.h" #include "dispatcher.h" diff --git a/src/operator.h b/src/operator.h index be6fb51..e04e9af 100644 --- a/src/operator.h +++ b/src/operator.h @@ -93,7 +93,7 @@ class OperatorBase { std::size_t workspace_size_in_bytes_{0}; }; -template +template class Operator : public OperatorBase { public: template @@ -157,6 +157,9 @@ class Operator : public OperatorBase { auto operator()(Args&&... args) const { return (*static_cast(this))(std::forward(args)...); } + + protected: + static constexpr Device::Type device_type_{device_type}; }; } // namespace infini::ops