Skip to content

Parallel forward with state#830

Open
mathieu-charbonnel wants to merge 3 commits intostate-spaces:mainfrom
DataDog:mathieu.charbonnel/parallel_forward_with_state
Open

Parallel forward with state#830
mathieu-charbonnel wants to merge 3 commits intostate-spaces:mainfrom
DataDog:mathieu.charbonnel/parallel_forward_with_state

Conversation

@mathieu-charbonnel
Copy link
Copy Markdown

@mathieu-charbonnel mathieu-charbonnel commented Dec 29, 2025

This PR answers the issue #536.
My answer here #536 as well as the OP describe the motivation behind this change.

In this PR we propose a step_chunk function that applies the parallel scan approach to the inference step on a chunk of tokens. Just like the step function this function allows taking last_inputs and hidden_state as arguments. The usage of pscan brings the number of steps in sequence length from O[L] to O(log(L)).

step_chunk combines:

  • the last_input handling (similarly to step code) for convolution continuity

  • Update of the first state,
    The ssm update can be written Ht = Δt * Bt * xt + exp(A × Δt) * Ht-1
    If we denote X[t] = Δt * Bt * xt, A[t] = exp(A × Δt)
    Then the first state should be H[0] = X[0] + A[0] * H[-1] where H[-1] is the last state
    This is done inside the cuda kernel as Δt * Bt and exp(A × Δt) -which are done in the kernel- need to be computed first

  • Then parallel scan is applied exactly the way it is done in forward

In terms of implementation I modified the forward to handle new inference params, and initial state modification in cuda kernel.

Looking forward for some reviews, please note that I am not experienced in cuda kernel development and relied heavily on AI tools for updating it.
I added a few tests to verify correctness of step_chunk processing.

@mathieu-charbonnel mathieu-charbonnel changed the title Mathieu.charbonnel/parallel forward with state Parallel forward with state Jan 6, 2026
@mathieu-charbonnel mathieu-charbonnel force-pushed the mathieu.charbonnel/parallel_forward_with_state branch from 2d562b2 to 0818d8b Compare January 27, 2026 14:04
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant