[feat] Support multimodal MTP #14
Conversation
Code Review
This pull request introduces `decoder_input` propagation through the Multi-Token Prediction (MTP) layers and adds a `_get_embeddings` helper method to handle sequence rolling and embedding computation. Several critical issues need to be addressed:

- The `forward` method signature in `patcher.py` is missing the `decoder_input` parameter, which will cause a `NameError`.
- The `MultiTokenPrediction` container must be patched to accept the new argument, or it will raise a `TypeError`.
- The `make_viewless_tensor` function is used without being imported.
- The logic in `_get_embeddings` should preserve pre-computed `decoder_input` values instead of unconditionally overwriting them.
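The last point, preserving a pre-computed `decoder_input`, can be sketched as follows. This is a minimal illustration, not the PR's actual code: the function name, argument names, and the use of NumPy in place of the real tensor library are all assumptions made for the example.

```python
import numpy as np

def get_embeddings(input_ids, embed_fn, decoder_input=None):
    """Sketch (names assumed) of the suggested _get_embeddings fix.

    MTP shifts the target sequence by one token, so the ids are rolled
    along the sequence dimension. The guard below is the point of the
    review comment: only compute embeddings when no pre-computed
    decoder_input was supplied, instead of overwriting it.
    """
    rolled_ids = np.roll(input_ids, shift=-1, axis=1)  # [batch, seq]
    if decoder_input is None:
        # Compute embeddings only when the caller did not pass them in.
        decoder_input = embed_fn(rolled_ids)
    return decoder_input
```

The design point is simply that the `None` check must come before the embedding lookup; unconditionally calling `embed_fn` would discard a multimodal `decoder_input` that was already assembled upstream.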
/gemini review
Code Review
This pull request introduces support for multimodal Multi-Token Prediction (MTP) by modifying the GPT model's forward and post-processing logic to handle combined embedding and decoder inputs. It also adds a `_get_embeddings` method to `MultiTokenPredictionLayer` to manage tensor rolling and sequence parallelism. Review feedback suggests:

- avoiding a potential `IndexError` in `_get_embeddings` by rolling along the sequence dimension directly instead of transposing first;
- improving type safety by adding `Union` to the imports and updating the `embedding` argument's type hint.
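Both suggestions can be illustrated in one small sketch. The function name, argument names, and NumPy usage are assumptions for illustration; the real code operates on framework tensors laid out as `[seq, batch, hidden]`.

```python
from typing import Union
import numpy as np

def mtp_roll(
    hidden: np.ndarray,
    embedding: Union[np.ndarray, None] = None,  # hint widened per the review
) -> np.ndarray:
    """Sketch (names assumed) of the two suggested fixes.

    Rather than transposing to put the sequence axis first, rolling it,
    and transposing back, roll directly along axis 0 (the sequence axis
    of a [seq, batch, hidden] tensor). This avoids the IndexError risk
    the review flags for the transpose-based variant.
    """
    rolled = np.roll(hidden, shift=-1, axis=0)
    if embedding is not None:
        # Combine with a pre-computed embedding when one is provided.
        rolled = rolled + embedding
    return rolled
```

Rolling along the named axis in one step is also cheaper: the transpose pair forces two extra views (and, in some frameworks, copies) for no semantic gain.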
No description provided.