
Add tp.attention op #709

Draft
yizhuoz004 wants to merge 1 commit into main from tp-attention

Conversation

@yizhuoz004 (Collaborator):

No description provided.

Comment on lines +68 to +70
query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
key: The key tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
value: The value tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
yizhuoz004 (Collaborator, Author):

I need to verify why in TRT we use numHeadsQuery and numHeadsKeyValue separately.
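One likely reason for the separate ``numHeadsQuery`` and ``numHeadsKeyValue`` dimensions is grouped-query attention (GQA), where several query heads share one key/value head. The sketch below (plain numpy, not nvtripy; all names are illustrative assumptions, not the actual op's implementation) shows how the two head counts interact through the five steps the docstring describes:

```python
import numpy as np


def gqa_attention(query, key, value):
    """Hypothetical GQA sketch: num query heads may exceed num KV heads."""
    batch, n_q, seq_q, dim = query.shape
    _, n_kv, seq_kv, _ = key.shape
    assert n_q % n_kv == 0, "query heads must be a multiple of KV heads"
    group = n_q // n_kv
    # Repeat each KV head so it lines up with its group of query heads.
    key = np.repeat(key, group, axis=1)
    value = np.repeat(value, group, axis=1)
    # BMM1 + scaling.
    scores = query @ key.transpose(0, 1, 3, 2) / np.sqrt(dim)
    # Numerically stable softmax over the KV sequence axis.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # BMM2: output has the query's shape.
    return weights @ value
```

With ``numHeadsQuery == numHeadsKeyValue`` this reduces to standard multi-head attention, so exposing both dimensions costs nothing in the common case.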

@yizhuoz004 changed the title from "Add AttentionOp" to "Add tp.attention op" on Nov 4, 2025
5. Matrix multiplication with value (BMM2)

Args:
query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
Collaborator:

nit: can we use snake_case to be consistent with the rest of the documentation?



def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
Collaborator:

I think we should make this a property of dtype so we don't have to update multiple places when adding new dtypes.
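The suggestion amounts to carrying the TRT enum string on the dtype itself, so adding a new dtype touches one place instead of a free function plus a lookup table elsewhere. A minimal sketch of that shape, assuming a simple dtype class (the class, attribute, and enum strings below are hypothetical, not the actual nvtripy or TensorRT definitions):

```python
class DType:
    """Hypothetical dtype carrying its own TRT enum string."""

    def __init__(self, name, trt_enum_str):
        self.name = name
        self._trt_enum_str = trt_enum_str

    @property
    def trt_dtype_enum_str(self):
        # Each dtype declares its TRT mapping at construction time,
        # so no external mapping needs updating for new dtypes.
        return self._trt_enum_str


# Illustrative registrations; the enum strings are placeholders.
float32 = DType("float32", "DataType::kFLOAT")
float16 = DType("float16", "DataType::kHALF")
```

A caller would then use ``dtype.trt_dtype_enum_str`` instead of ``get_trt_dtype_enum_str(dtype)``.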


assert output.shape == (batch_size, num_heads, seq_len, head_dim)

.. code-block:: python
Collaborator:

Since the inputs to all 3 examples are the same, can we omit the input initialization in the docs so that it is easier to tell what is changing between the samples? Also, can we have the quantization sample omit the mask?

Collaborator:

I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.

If all the tensors are the same shape, maybe a compromise could be:

query = key = value = tp.iota(...)

although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.
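One caveat worth stating in that clarification: chained assignment binds all three names to the *same object*, which is fine for read-only examples but surprising if a sample later mutates one input. A small plain-numpy illustration of the pattern (numpy stands in for ``tp.iota``; shapes are placeholders):

```python
import numpy as np

# For brevity only: all three inputs get identical contents.
# In real use, query, key, and value would be distinct tensors.
query = key = value = np.arange(24, dtype=np.float32).reshape(1, 2, 3, 4)

# Note: these are the same object, not three equal copies.
# Use .copy() per name if an example needs to modify one of them.
```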



3 participants