use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding and gathering the frame atom coordinates#269
Open
OccupyMars2030 wants to merge 4 commits intobytedance:mainfrom
Conversation
…dding to an atom embedding use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding
OccupyMars2030
commented
Mar 14, 2026
| remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)] | ||
| remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds | ||
| ranges.extend(remaining_dims) | ||
| return data[ranges] |
Contributor
Author
There was a problem hiding this comment.
Since we are gathering data along only "one" dimension, there is no need to construct an index tensor for each dim like "ranges". Just use torch.gather() because it is more logically concise
This was referenced Mar 14, 2026
add more detailed docstring
Contributor
Author
TODO(maybe 2 weeks later): study if we can use
|
Collaborator
|
@OccupyMars2025 |
OccupyMars2030
commented
Mar 15, 2026
|
|
||
| pad_left = (n_keys - n_queries) // 2 | ||
| pad_right = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2) | ||
| pad_right = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n) |
Contributor
Author
There was a problem hiding this comment.
This line has nothing to do with my other modification. I just cannot understand the original calculation method for pad_right.
Note: both n_keys and n_queries are even integer numbers. You have used assert statement to confirm it.
https://github.com/OccupyMars2025/Protenix/wiki/explain-how-to-calculate-pad_right-for-the-key-tensor
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The following code compares the execution time of the two versions of
gather_frame_atom_by_indices()which invokesbatched_gather()On CPU, my implementation is about 5x faster.
On colab, H100 GPU, my implementation is about 15x faster.
To compare the execution time of the two versions of
batched_gather(), you can refer to aqlaboratory/openfold-3#135 (comment)