Conversation
Force-pushed: 198b12c to cf6e04e
| query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
| key: The key tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
| value: The value tensor with shape ``[batchSize, numHeadsKeyValue, sequenceLengthKeyValue, dimHead]``.
I need to verify why in TRT we use numHeadsQuery and numHeadsKeyValue separately.
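For context on why the head counts can differ: in grouped-query attention (GQA) and multi-query attention (MQA), several query heads share one key/value head, which shrinks the KV cache. A minimal sketch of the head mapping (the helper name is illustrative, not nvtripy's API):

```python
# Sketch: why query and key/value head counts are separate parameters.
# In GQA/MQA, numHeadsKeyValue divides numHeadsQuery, and each group of
# query heads attends against a single shared key/value head.

def kv_head_for_query_head(q_head: int, num_heads_query: int, num_heads_kv: int) -> int:
    """Map a query head index to the key/value head it attends with."""
    assert num_heads_query % num_heads_kv == 0, "KV heads must evenly divide query heads"
    group_size = num_heads_query // num_heads_kv
    return q_head // group_size

# With 8 query heads and 2 KV heads, query heads 0-3 share KV head 0
# and query heads 4-7 share KV head 1.
mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
print(mapping)  # [0, 0, 0, 0, 1, 1, 1, 1]
```

Standard multi-head attention is the special case where both counts are equal, so the mapping is the identity.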
| 5. Matrix multiplication with value (BMM2)
|
| Args:
|     query: The query tensor with shape ``[batchSize, numHeadsQuery, sequenceLengthQuery, dimHead]``.
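For readers following the numbered steps, the pipeline ends with BMM2, the matmul between the attention probabilities and the value tensor. A minimal pure-Python sketch for a single head (illustrative only, not nvtripy's implementation):

```python
import math

# Scaled dot-product attention for one head, shapes [seq_len, dim_head].
# Steps: BMM1 (Q @ K^T), scaling, softmax, then BMM2 (probs @ V).

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def sdpa(q, k, v):
    dim_head = len(q[0])
    # BMM1: Q @ K^T, scaled by 1/sqrt(dim_head)
    scores = matmul(q, [list(col) for col in zip(*k)])
    scores = [[s / math.sqrt(dim_head) for s in row] for row in scores]
    # Softmax over the key dimension
    probs = [softmax(row) for row in scores]
    # BMM2: attention probabilities @ V
    return matmul(probs, v)

q = [[1.0, 0.0], [0.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[2.0, 0.0], [0.0, 2.0]]
out = sdpa(q, k, v)
```

Each output row is a convex combination of the value rows, so here every row of `out` sums to 2.0.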
nit: can we use snake_case to be consistent with the rest of the documentation?
| ##
|
| def get_trt_dtype_enum_str(dtype: "nvtripy.dtype") -> str:
I think we should make this a property of dtype so we don't have to update multiple places when adding new dtypes.
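A hypothetical sketch of that suggestion: hang the TRT enum name off the dtype itself so adding a new dtype means updating a single table. The class and mapping below are illustrative, not nvtripy's actual types:

```python
# Sketch: expose the TRT enum name as a property of the dtype instead of a
# free function, so there is one source of truth when new dtypes are added.
from enum import Enum

class DType(Enum):
    float32 = "float32"
    float16 = "float16"
    int8 = "int8"

    @property
    def trt_enum_str(self) -> str:
        # Adding a dtype only requires adding one entry here.
        return {
            DType.float32: "DataType.FLOAT",
            DType.float16: "DataType.HALF",
            DType.int8: "DataType.INT8",
        }[self]

print(DType.float16.trt_enum_str)  # DataType.HALF
```

Call sites then read `dtype.trt_enum_str` instead of importing a helper, and a missing entry fails loudly at the definition site rather than in scattered `if`/`elif` chains.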
Force-pushed: cf6e04e to a58916a
| assert output.shape == (batch_size, num_heads, seq_len, head_dim)
|
| .. code-block:: python
Since the inputs to all 3 examples are the same, can we omit the input initialization in the docs so that it is easier to tell what is changing between the samples? Also, can we have the quantization sample omit the mask?
I'm conflicted on this - on one hand, it will make the examples much cleaner, but on the other, it'll mean that you can't just copy-paste the example code and have it work.
If all the tensors are the same shape, maybe a compromise could be:
``query = key = value = tp.iota(...)``, although we would need to clarify that it's only being done for the sake of brevity and they don't all need to be the same tensor.
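A sketch of how that compromise could look in a doc example. `iota` below is a plain-Python stand-in for `tp.iota` (so the snippet runs without nvtripy), and the sharing is purely for brevity:

```python
# Doc-brevity pattern from the thread: initialize once, bind all three names.
# NOTE: query, key, and value do NOT need to be the same tensor in real usage;
# sharing one object here only keeps the example short.

def iota(shape):
    """Illustrative stand-in for tp.iota: a flat list 0..n-1 for the shape."""
    n = 1
    for dim in shape:
        n *= dim
    return list(range(n))

query = key = value = iota((1, 2, 4, 8))
```

The trade-off discussed above still applies: the example stays copy-pasteable, at the cost of a clarifying note that the three inputs are only aliased for brevity.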
Force-pushed: a58916a to d3949d7