fix bug in online_softmax when all loaded values in a warp are -inf #64

Open
LitLeo wants to merge 3 commits into Dao-AILab:main from LitLeo:main
Conversation

@LitLeo LitLeo commented Jan 8, 2026

If max_x == -inf, then x - max_x == nan, so exp_x == nan, which produces an incorrect result.

@cute.jit
def online_softmax_reduce(
    x: cute.TensorSSA,
    threads_per_row: cutlass.Constexpr[int],
    reduction_buffer: Optional[cute.Tensor] = None,
    mbar_ptr: Optional[cute.Pointer] = None,
    hook_fn: Optional[Callable] = None,
    phase: Optional[Int32] = None,
    return_exp_x: bool = False,
) -> [Float32, Float32, Optional[cute.TensorSSA]]:
    assert x.dtype == Float32, "x must be of type Float32"
    """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)"""
    max_x = cute.arch.warp_reduction(
        x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0),
        cute.arch.fmax,
        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
    )
    log2_e = math.log2(math.e)
    exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True)
    sum_exp_x = cute.arch.warp_reduction(
        exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0),
        operator.add,
        threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE),
    )
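The failure mode can be reproduced in plain Python (a minimal sketch, not the CuTe DSL): when every value in a row is -inf, the row max is also -inf, and `x - max_x` evaluates to `(-inf) - (-inf)`, which is NaN, so the subsequent exp is NaN as well.

```python
import math

# All values in the (warp's slice of the) row are -inf.
x = [float("-inf")] * 4
max_x = max(x)                     # -inf

diff = x[0] - max_x                # (-inf) - (-inf) -> nan
exp_x = math.exp(diff)             # exp(nan) -> nan

assert math.isnan(diff)
assert math.isnan(exp_x)
```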


tridao commented Jan 11, 2026

I think it's better to have the check:

max_x_cur = Float32(0.0) if max_x == -Float32.inf else max_x

Then subtract max_x_cur (instead of max_x) before the exp.
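In plain Python, the suggested guard looks like this (an illustrative sketch, not the CuTe implementation): subtracting 0.0 instead of -inf makes every term `exp(-inf - 0.0) == 0.0` rather than `exp(nan) == nan`, while leaving ordinary inputs unchanged.

```python
import math

def online_softmax_guarded(x):
    """Sketch of the guarded reduction; names are illustrative."""
    max_x = max(x)
    # Guard: if the whole row is -inf, subtract 0.0 instead of -inf.
    max_x_cur = 0.0 if max_x == float("-inf") else max_x
    sum_exp_x = sum(math.exp(v - max_x_cur) for v in x)
    return max_x, sum_exp_x

# All -inf: sum of exponentials is exactly 0.0, not NaN.
m, s = online_softmax_guarded([float("-inf")] * 4)
assert s == 0.0

# Ordinary inputs are unaffected by the guard.
m2, s2 = online_softmax_guarded([1.0, 2.0])
assert abs(s2 - (math.exp(-1.0) + 1.0)) < 1e-12
```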


LitLeo commented Jan 12, 2026

> I think it's better to have the check:
>
> `max_x_cur = Float32(0.0) if max_x == -Float32.inf else max_x`
>
> Then subtract max_x_cur before the exp.

if max_x == -Float32.inf:
    max_x = Float32(0.0)

is OK, but I noticed that row_reduce has the same issue.
Here is a more general approach I found. What do you think of this solution?

if const_expr(not is_even_N):
    # utils.fill_oob(tXsX, tXpX, -tXsX.element_type.inf)
    utils.fill_oob(tXsX, tXpX, -BFloat16(2**15))
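The idea behind the finite fill can be sketched in plain Python (values are illustrative): padding out-of-bounds lanes with a large finite negative like -2**15 keeps the subtraction finite, so exp() simply underflows to 0.0, whereas a fully -inf-padded warp would hit `(-inf) - (-inf) == nan`.

```python
import math

real = [1.0, 2.5]
padded = [-float(2**15)] * 2        # OOB lanes filled with a finite sentinel
x = real + padded

m = max(x)                          # 2.5, finite
exps = [math.exp(v - m) for v in x] # padded lanes underflow to exactly 0.0
assert exps[2] == 0.0 and exps[3] == 0.0

# By contrast, an all -inf fill yields nan for a fully-padded warp:
assert math.isnan(math.exp(float("-inf") - float("-inf")))
```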


tridao commented Jan 12, 2026

I prefer changing max_x_cur. For the case where the input is entirely -inf (not out of bounds), we still want the output to be zero.

Remove handling for max_x being -inf to avoid NaN in exp_x.

LitLeo commented Jan 19, 2026

> I prefer changing max_x_cur. For the case where the input is entirely -inf (not out of bounds), we still want the output to be zero.

Done.


tridao commented Jan 19, 2026

Thanks, could you add a test case where the whole vector is -inf?


LitLeo commented Jan 20, 2026

> Thanks, could you add a test case where the whole vector is -inf?

A test case where the entire vector is -inf is meaningless, since the loss for such a vector would be NaN. This bug is triggered when N=150000, so I've added a case with that N to test_cross_entropy.
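This point can be illustrated in plain Python (a sketch, not the repo's test code): even with the max_x_cur guard, a row that is entirely -inf has a sum of exponentials of exactly 0, so the softmax probabilities would be 0/0 and any cross-entropy loss over that row is NaN by construction.

```python
import math

x = [float("-inf")] * 4
max_x = max(x)
max_x_cur = 0.0 if max_x == float("-inf") else max_x   # guarded, as in the fix
sum_exp = sum(math.exp(v - max_x_cur) for v in x)      # each term is exp(-inf) == 0.0

# The denominator is 0, so the probabilities are 0/0: undefined,
# which is why a fully -inf row cannot have a meaningful loss.
assert sum_exp == 0.0
```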
