Skip to content

[Ilu backend] Fix the implement of normalizations and add other ops.#199

Open
madengfei wants to merge 37 commits intoXPU-Forces:masterfrom
madengfei:ilu_ops_pr_0401
Open

[Ilu backend] Fix the implement of normalizations and add other ops.#199
madengfei wants to merge 37 commits intoXPU-Forces:masterfrom
madengfei:ilu_ops_pr_0401

Conversation

@madengfei
Copy link
Copy Markdown
Collaborator

  1. Fix the implement of normalizations, test accuracy ok.
  2. Add implements of RoPE, KVCache, PagedPrefillAttention,PagedDecodeAttention,Swiglu,Gelu, Silu, etc.

@wwens7
Copy link
Copy Markdown
Collaborator

wwens7 commented Apr 1, 2026

可以参考着补充一下 perf test,后续优化性能可以直接跑 perf

@madengfei
Copy link
Copy Markdown
Collaborator Author

好的,已经开始做了。

@zhangjihang-BD
Copy link
Copy Markdown
Collaborator

我已经仔细阅读了这个 PR 的全部变更内容。下面是我的详细 Review。


PR #199 Review: [Ilu backend] Fix the implement of normalizations and add other ops

概览: 本 PR 为 ILU (Iluvatar 天数智芯) 后端新增了大量 Triton kernel 实现(+4682/-416,43 个文件),涵盖 RoPE、KVCache、PagedAttention、SwiGLU、GELU、SiLU、GroupGemm、Sampling、SDPA 等算子,并修复了已有的 normalization 实现。


一、架构设计问题 (High Priority)

1. Core 层注入平台特定代码,破坏了架构分层

多处在 core/operators/ 中直接引入 ILU 平台分发逻辑,违反了 core 层应保持平台无关的设计原则:

  • core/operators/linear.py — 在 MojoLinear.forward 中 try/except 导入 linear_fwd_impl
  • core/operators/sampling.py — 多处 if get_platform() == "ilu" 分支
  • core/operators/kv_cache.pyif get_platform() == "ilu" 分支
  • core/operators/gemm.pyMojoGroupGemm.forwardMojoQuantGroupLinearReduceSum.forward 中同样注入 ILU 分支
  • experimental/operators/store_lowrank.py — 同理

这些应该通过 supported_platforms_list 和 operator registry 机制实现(像 TTXSiluTTXGelu 等那样),而不是在基类中做条件分发。这种做法会导致:

  • 后续每新增一个平台都需修改 core 层
  • 违反开闭原则
  • 难以测试和维护

建议: 为这些算子创建对应的 TTX operator 子类(部分已有如 TTXStorePagedKVCache),将 ILU 分发逻辑收拢到 backends/ttx/operators/ 层。

2. MojoLinear.forward 的 fallback 机制有缺陷

try:
    from mojo_opset.backends.ttx.kernels import linear_fwd_impl
    return linear_fwd_impl(input, self.weight, self.bias)
except (NotImplementedError, TypeError):
    return F.linear(input, self.weight, self.bias)
  • ImportError 不在 catch 范围内,如果 linear_fwd_impl 导入失败会直接抛异常
  • 这个 try/except 会让 所有平台(包括 CPU、CUDA、NPU)都尝试导入 ILU kernel,即使它们不需要

二、代码重复问题 (Medium Priority)

3. 大量工具函数/常量在多个文件中重复定义

以下内容在 fused_add_layernorm.pyfused_add_rmsnorm.pylayernorm.pyrmsnorm.py 四个文件中几乎完全一致:

  • COL_BLOCKING_THRESHOLD = 2048
  • TOKEN_BLOCK_SIZE_TABLE 字典
  • _block_size_n_pow2() 函数
  • layer_norm_fwd_heuristics() / rms_norm_fwd_heuristics() — 逻辑完全一致

建议: 提取到 utils.py 中统一管理。

4. calculate_settings()swiglu.pysilu.py 中各有一份

gelu.pysilu.py 导入 calculate_settings(好的做法),但 swiglu.py 自己又定义了一份完全一样的函数。


三、正确性与健壮性问题 (Medium-High Priority)

5. Prefill Attention kernel 采用两遍扫描,存在精度与性能隐患

_paged_prefill_causal_attn_kernel 先做一遍循环求 max,再做一遍循环算 softmax。这种写法:

  • 对 K 做了 2 次全扫描,复杂度是 online softmax 的 2 倍
  • 每个 program 处理单个 (query, head) pair,无分块并行,长序列时性能很差
  • 逐元素点积 tl.sum(q_vec * k_vec) 而非矩阵 tl.dot,未利用 tensor core

建议: 后续可参考 FlashAttention 或 FlashDecoding 的 Triton 实现做优化(目前 PR 评论中也提到了补 perf test 的事)。

6. paged_attention_prefill_impl 中有 Python for-loop 逐 batch 处理

KV gather、GQA expand 和 block 拼接都是 PyTorch + Python 层面的循环,对于大 batch 会很慢。虽然从正确性角度可以接受作为初始实现,但需要明确标注为"功能验证版本"。

7. paged_attention_decode_impl 有硬编码限制

assert block_size <= 128, f"temp: only support block_size <= 128, but got {block_size}"

以及 kernel 内 tl.static_assert(PAGE_SIZE == BLOCK_SIZE_N),限制了 block size 必须等于 page size。这些约束应在文档中明确说明。

8. convolution.py 类型标注错误

conv_state_indices: Optional[str] = None,

应为 Optional[torch.Tensor]

9. gqa_interleave 方向可能反了(Prefill vs Decode 不一致)

paged_attention_prefill_impl 中:

if gqa_interleave:
    k_expanded = k_unpadded.repeat((1, num_q_heads // num_kv_heads, 1))

而在 paged_decode_kernel 中:

if GQA_INTERLEAVE:
    kv_head_id = q_head_id % NUM_KV_HEADS
else:
    kv_head_id = q_head_id // (NUM_Q_HEADS // NUM_KV_HEADS)

repeat vs repeat_interleave 的语义与 decode 中的 % vs // 映射关系需要仔细确认是否一致。


四、代码质量问题 (Low-Medium Priority)

10. tests/utils.py 留有注释掉的旧代码

# device = get_platform()
device = get_torch_device()

应该直接删除注释行。

11. torch_npu 在通用 kernels/utils.py 中全局导入

try:
    import torch_npu
except ImportError:
    torch_npu = None

放在通用的 kernels/utils.py 中会导致非 NPU 环境产生不必要的导入尝试。应该延迟导入或放在 NPU 特定模块中。

12. store_label_cache_infer_impl 硬编码 ub_buffer = 192

这个值的含义和单位(KB? bytes?)不明确,也没有根据实际硬件信息动态确定。

13. m_grouped_matmul_impl 用 Python for-loop 调用 per-group kernel

虽然存在 _m_grouped_matmul_bKmajor_kernel_m_grouped_matmul_bNmajor_kernel 这两个支持多 group 融合的 kernel,但 m_grouped_matmul_impl 实际上调用的是简单的 _group_matmul_kernel,逐 group 循环。那两个复杂的融合 kernel 和 get_autotune_config() 完全是死代码。

14. 未使用的导入

  • fused_add_layernorm.py 导入了 Tuple 但未使用
  • fused_add_rmsnorm.py 导入了 VEC_ALIGN_BYTESalign 但未使用
  • fused_add_layernorm.py 导入了 VEC_ALIGN_BYTES 但未使用

15. test_linear.pytest_gemm.py 存在大量重复测试用例

新增的 test_linear.pytest_group_gemmtest_quant_group_linear_reduce_sumtest_grouped_matmul_cases_via_group_lineartest_group_linear_two_groups_single_call 等测试与 test_gemm.py 中的几乎完全一致。建议复用或明确区分。


五、正面评价

  1. .npu()device=query.device: attention.py 中将硬编码的 .npu() 改为 device=query.device 是正确的设备无关写法
  2. libentry fallback: utils.py 中对 libentry 做了 try/import 处理,兼容了没有该 API 的 Triton 版本
  3. ilu_grid_dim_from_row_tasks: 针对 ILU 硬件特性(vectorcore 数量)进行 grid 下限保护,设计合理
  4. Normalization 修复: 将 BLOCK_SIZE 调整为 power-of-2(512:16 替换 512:10 等)以适配 ILU Triton 的 tl.arange 要求,fix 合理
  5. test_moe.py dtype fix: current_state.float() 确保了 F.linear 输入 dtype 一致性
  6. 测试用例有精度验证: 新增的测试都通过 forward_diff_withtorch.testing.assert_close 做了精度对比

六、总结建议

优先级 问题 建议
P0 Core 层注入平台分发 通过 TTX operator 子类 + registry 实现
P1 MojoLinear.forward fallback 缺陷 修正异常捕获范围;移到 TTX 子类
P1 GQA interleave 一致性 验证 prefill/decode 的 head 映射逻辑
P2 大量代码重复 提取公共函数到 utils
P2 死代码(autotune configs) 清理或启用
P2 注释掉的旧代码 删除
P3 类型标注错误 修正 conv_state_indices 类型
P3 未使用导入 清理

总体来说,这个 PR 作为 ILU 后端的功能基线是可以接受的(accuracy test 已通过),但在合入 master 前建议优先解决 P0 的架构问题,至少要避免对 core 层的侵入式修改。

From Opus.

@LeoLau94
Copy link
Copy Markdown
Contributor

LeoLau94 commented Apr 1, 2026

我已经仔细阅读了这个 PR 的全部变更内容。下面是我的详细 Review。

PR #199 Review: [Ilu backend] Fix the implement of normalizations and add other ops

概览: 本 PR 为 ILU (Iluvatar 天数智芯) 后端新增了大量 Triton kernel 实现(+4682/-416,43 个文件),涵盖 RoPE、KVCache、PagedAttention、SwiGLU、GELU、SiLU、GroupGemm、Sampling、SDPA 等算子,并修复了已有的 normalization 实现。

一、架构设计问题 (High Priority)

1. Core 层注入平台特定代码,破坏了架构分层

多处在 core/operators/ 中直接引入 ILU 平台分发逻辑,违反了 core 层应保持平台无关的设计原则:

  • core/operators/linear.py — 在 MojoLinear.forward 中 try/except 导入 linear_fwd_impl
  • core/operators/sampling.py — 多处 if get_platform() == "ilu" 分支
  • core/operators/kv_cache.pyif get_platform() == "ilu" 分支
  • core/operators/gemm.pyMojoGroupGemm.forwardMojoQuantGroupLinearReduceSum.forward 中同样注入 ILU 分支
  • experimental/operators/store_lowrank.py — 同理

这些应该通过 supported_platforms_list 和 operator registry 机制实现(像 TTXSiluTTXGelu 等那样),而不是在基类中做条件分发。这种做法会导致:

  • 后续每新增一个平台都需修改 core 层
  • 违反开闭原则
  • 难以测试和维护

建议: 为这些算子创建对应的 TTX operator 子类(部分已有如 TTXStorePagedKVCache),将 ILU 分发逻辑收拢到 backends/ttx/operators/ 层。

2. MojoLinear.forward 的 fallback 机制有缺陷

try:
    from mojo_opset.backends.ttx.kernels import linear_fwd_impl
    return linear_fwd_impl(input, self.weight, self.bias)
except (NotImplementedError, TypeError):
    return F.linear(input, self.weight, self.bias)
  • ImportError 不在 catch 范围内,如果 linear_fwd_impl 导入失败会直接抛异常
  • 这个 try/except 会让 所有平台(包括 CPU、CUDA、NPU)都尝试导入 ILU kernel,即使它们不需要

二、代码重复问题 (Medium Priority)

3. 大量工具函数/常量在多个文件中重复定义

以下内容在 fused_add_layernorm.pyfused_add_rmsnorm.pylayernorm.pyrmsnorm.py 四个文件中几乎完全一致:

  • COL_BLOCKING_THRESHOLD = 2048
  • TOKEN_BLOCK_SIZE_TABLE 字典
  • _block_size_n_pow2() 函数
  • layer_norm_fwd_heuristics() / rms_norm_fwd_heuristics() — 逻辑完全一致

建议: 提取到 utils.py 中统一管理。

4. calculate_settings()swiglu.pysilu.py 中各有一份

gelu.pysilu.py 导入 calculate_settings(好的做法),但 swiglu.py 自己又定义了一份完全一样的函数。

三、正确性与健壮性问题 (Medium-High Priority)

5. Prefill Attention kernel 采用两遍扫描,存在精度与性能隐患

_paged_prefill_causal_attn_kernel 先做一遍循环求 max,再做一遍循环算 softmax。这种写法:

  • 对 K 做了 2 次全扫描,复杂度是 online softmax 的 2 倍
  • 每个 program 处理单个 (query, head) pair,无分块并行,长序列时性能很差
  • 逐元素点积 tl.sum(q_vec * k_vec) 而非矩阵 tl.dot,未利用 tensor core

建议: 后续可参考 FlashAttention 或 FlashDecoding 的 Triton 实现做优化(目前 PR 评论中也提到了补 perf test 的事)。

6. paged_attention_prefill_impl 中有 Python for-loop 逐 batch 处理

KV gather、GQA expand 和 block 拼接都是 PyTorch + Python 层面的循环,对于大 batch 会很慢。虽然从正确性角度可以接受作为初始实现,但需要明确标注为"功能验证版本"。

7. paged_attention_decode_impl 有硬编码限制

assert block_size <= 128, f"temp: only support block_size <= 128, but got {block_size}"

以及 kernel 内 tl.static_assert(PAGE_SIZE == BLOCK_SIZE_N),限制了 block size 必须等于 page size。这些约束应在文档中明确说明。

8. convolution.py 类型标注错误

conv_state_indices: Optional[str] = None,

应为 Optional[torch.Tensor]

9. gqa_interleave 方向可能反了(Prefill vs Decode 不一致)

paged_attention_prefill_impl 中:

if gqa_interleave:
    k_expanded = k_unpadded.repeat((1, num_q_heads // num_kv_heads, 1))

而在 paged_decode_kernel 中:

if GQA_INTERLEAVE:
    kv_head_id = q_head_id % NUM_KV_HEADS
else:
    kv_head_id = q_head_id // (NUM_Q_HEADS // NUM_KV_HEADS)

repeat vs repeat_interleave 的语义与 decode 中的 % vs // 映射关系需要仔细确认是否一致。

四、代码质量问题 (Low-Medium Priority)

10. tests/utils.py 留有注释掉的旧代码

# device = get_platform()
device = get_torch_device()

应该直接删除注释行。

11. torch_npu 在通用 kernels/utils.py 中全局导入

try:
    import torch_npu
except ImportError:
    torch_npu = None

放在通用的 kernels/utils.py 中会导致非 NPU 环境产生不必要的导入尝试。应该延迟导入或放在 NPU 特定模块中。

12. store_label_cache_infer_impl 硬编码 ub_buffer = 192

这个值的含义和单位(KB? bytes?)不明确,也没有根据实际硬件信息动态确定。

13. m_grouped_matmul_impl 用 Python for-loop 调用 per-group kernel

虽然存在 _m_grouped_matmul_bKmajor_kernel_m_grouped_matmul_bNmajor_kernel 这两个支持多 group 融合的 kernel,但 m_grouped_matmul_impl 实际上调用的是简单的 _group_matmul_kernel,逐 group 循环。那两个复杂的融合 kernel 和 get_autotune_config() 完全是死代码。

14. 未使用的导入

  • fused_add_layernorm.py 导入了 Tuple 但未使用
  • fused_add_rmsnorm.py 导入了 VEC_ALIGN_BYTESalign 但未使用
  • fused_add_layernorm.py 导入了 VEC_ALIGN_BYTES 但未使用

15. test_linear.pytest_gemm.py 存在大量重复测试用例

新增的 test_linear.pytest_group_gemmtest_quant_group_linear_reduce_sumtest_grouped_matmul_cases_via_group_lineartest_group_linear_two_groups_single_call 等测试与 test_gemm.py 中的几乎完全一致。建议复用或明确区分。

五、正面评价

  1. .npu()device=query.device: attention.py 中将硬编码的 .npu() 改为 device=query.device 是正确的设备无关写法
  2. libentry fallback: utils.py 中对 libentry 做了 try/import 处理,兼容了没有该 API 的 Triton 版本
  3. ilu_grid_dim_from_row_tasks: 针对 ILU 硬件特性(vectorcore 数量)进行 grid 下限保护,设计合理
  4. Normalization 修复: 将 BLOCK_SIZE 调整为 power-of-2(512:16 替换 512:10 等)以适配 ILU Triton 的 tl.arange 要求,fix 合理
  5. test_moe.py dtype fix: current_state.float() 确保了 F.linear 输入 dtype 一致性
  6. 测试用例有精度验证: 新增的测试都通过 forward_diff_withtorch.testing.assert_close 做了精度对比

六、总结建议

优先级 问题 建议
P0 Core 层注入平台分发 通过 TTX operator 子类 + registry 实现
P1 MojoLinear.forward fallback 缺陷 修正异常捕获范围;移到 TTX 子类
P1 GQA interleave 一致性 验证 prefill/decode 的 head 映射逻辑
P2 大量代码重复 提取公共函数到 utils
P2 死代码(autotune configs) 清理或启用
P2 注释掉的旧代码 删除
P3 类型标注错误 修正 conv_state_indices 类型
P3 未使用导入 清理
总体来说,这个 PR 作为 ILU 后端的功能基线是可以接受的(accuracy test 已通过),但在合入 master 前建议优先解决 P0 的架构问题,至少要避免对 core 层的侵入式修改。

From Opus.

有理有据,令人信服,这是哪位大师?

@madengfei
Copy link
Copy Markdown
Collaborator Author

这个reviewer太牛了,就照着这改了

yao.guo and others added 3 commits April 2, 2026 04:59
…es#191)

* feat(*): support quantized operators for torch_npu backend.

* feat(moe_quant): support quantized fused moe.

* refactor(*): refactor the LLM-generated code by LLM itself.

* chore(moe.py): add TODO.

* chore(README.md): udpate operator status.

* misc: update MoE configurations in test_moe.

---------

Co-authored-by: wwens7 <zhaowenshuo.oo@bytedance.com>
@zhangjihang-BD
Copy link
Copy Markdown
Collaborator

我再试试gemini老师

@zhangjihang-BD
Copy link
Copy Markdown
Collaborator

/gemini review

Copy link
Copy Markdown

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for the ILU backend by adding various Triton-based kernels and operators, including activation functions, attention mechanisms, normalization, and linear operations. My review identified several critical issues regarding the implementation of post-normalization logic in fused kernels and performance concerns related to kernel launch overhead and algorithmic complexity in the attention and grouped GEMM implementations.

I am having trouble creating individual review comments. Click here to see my feedback.

mojo_opset/backends/ttx/kernels/ilu/fused_add_layernorm.py (178-181)

high

The logic for add_mode="post" appears to be incorrect. The kernel computes Y = LayerNorm(hidden_states + residual), which is a pre-normalization operation. For a post-normalization architecture, the operation is typically output = LayerNorm(hidden_states) + residual. This implementation does not seem to support the post-norm case correctly. Returning Y for both the output and the next residual for post-norm models is likely to cause incorrect model behavior.

mojo_opset/backends/ttx/kernels/ilu/fused_add_rmsnorm.py (192-195)

high

The logic for add_mode="post" appears to be incorrect. The kernel computes Y = RMSNorm(hidden_states + residual), which is a pre-normalization operation. For a post-normalization architecture, the operation is typically output = RMSNorm(hidden_states) + residual. This implementation does not seem to support the post-norm case correctly. Returning Y for both the output and the next residual for post-norm models is likely to cause incorrect model behavior.

mojo_opset/backends/ttx/kernels/ilu/group_gemm.py (247-303)

medium

The current implementation of m_grouped_matmul_impl launches a separate kernel for each group within a Python loop. This approach can be inefficient due to kernel launch overhead, especially when dealing with a large number of small groups. This file already contains more advanced kernels like _m_grouped_matmul_bNmajor_kernel and _m_grouped_matmul_bKmajor_kernel that are designed to process all groups in a single launch. It would be more performant to refactor this function to use one of those kernels.

mojo_opset/backends/ttx/kernels/ilu/quant_group_linear.py (105-128)

medium

The quant_group_linear_reduce_sum_impl function launches a Triton kernel for each item in the batch via a Python loop. This can lead to significant performance degradation due to kernel launch overhead, especially with large batch sizes. To improve performance, consider refactoring the implementation to handle the batch dimension inside the Triton kernel, allowing for a single kernel launch to process the entire batch.

mojo_opset/backends/ttx/kernels/ilu/sdpa.py (24-108)

medium

The forward kernel _sdpa_masked_fwd_kernel computes the full attention matrix for each query token without tiling over the sequence dimension. This results in O(S^2) complexity per head, which can be very inefficient and consume a large amount of memory for long sequences. For better performance, consider implementing a tiled approach, similar to FlashAttention, which has O(S) complexity.

@madengfei
Copy link
Copy Markdown
Collaborator Author

目前已经按照Opus老师的建议改了一版。 @zhangjihang-BD

@yfchen-byted
Copy link
Copy Markdown
Collaborator

yfchen-byted commented Apr 2, 2026

RoPE有一个拆分的重构刚合入,麻烦rebase下

@madengfei
Copy link
Copy Markdown
Collaborator Author

RoPE有一个拆分的重构刚合入,麻烦rebase下
好的,正在弄

@yfchen-byted
Copy link
Copy Markdown
Collaborator

\gemini review

@wwens7
Copy link
Copy Markdown
Collaborator

wwens7 commented Apr 7, 2026

需要补一下 perf test 哈,包括 ilu 的 perf helper func(主要是 device perf)

@madengfei
Copy link
Copy Markdown
Collaborator Author

需要补一下 perf test 哈,包括 ilu 的 perf helper func(主要是 device perf)


ok

wwens7
wwens7 previously approved these changes Apr 8, 2026
tests/utils.py Outdated
perf_fn = perf_npu
elif device == 'mlu':
perf_fn = perf_mlu
elif platform == "ilu":
Copy link
Copy Markdown
Collaborator

@yfchen-byted yfchen-byted Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

话说为什么这里会存在diverge;现在的auto_switch_platform装饰器都是基于get_platform的,用torch_device的话是否会存在不一致

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok,"device=get_platform()”确实就是一致的,diverge是不合理的。我改下。

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants