Parallel forward with state#830
Open
mathieu-charbonnel wants to merge 3 commits intostate-spaces:mainfrom
Open
Parallel forward with state#830mathieu-charbonnel wants to merge 3 commits intostate-spaces:mainfrom
mathieu-charbonnel wants to merge 3 commits intostate-spaces:mainfrom
Conversation
2d562b2 to
0818d8b
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This PR answers the issue #536.
My answer here #536 as well as the OP describe the motivation behind this change.
In this PR we propose a step_chunk function that applies the parallel scan approach to the inference step on a chunk of tokens. Just like the step function this function allows taking last_inputs and hidden_state as arguments. The usage of pscan brings the number of steps in sequence length from O[L] to O(log(L)).
step_chunk combines:
the last_input handling (similarly to step code) for convolution continuity
Update of the first state,
The ssm update can be written Ht = Δt * Bt * xt + exp(A × Δt) * Ht-1
If we denote X[t] = Δt * Bt * xt, A[t] = exp(A × Δt)
Then the first state should be H[0] = X[0] + A[0] * H[-1] where H[-1] is the last state
This is done inside the cuda kernel as Δt * Bt and exp(A × Δt) -which are done in the kernel- need to be computed first
Then parallel scan is applied exactly the way it is done in forward
In terms of implementation I modified the forward to handle new inference params, and initial state modification in cuda kernel.
Looking forward for some reviews, please note that I am not experienced in cuda kernel development and relied heavily on AI tools for updating it.
I added a few tests to verify correctness of step_chunk processing.