Skip to content

use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding and gathering the frame atom coordinates#269

Open
OccupyMars2030 wants to merge 4 commits into bytedance:main from
OccupyMars2030:use-torch-gather
Open

use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding and gathering the frame atom coordinates#269
OccupyMars2030 wants to merge 4 commits into bytedance:main from
OccupyMars2030:use-torch-gather

Conversation

@OccupyMars2030
Copy link
Copy Markdown
Contributor

@OccupyMars2030 OccupyMars2030 commented Mar 14, 2026

The following code compares the execution time of the two versions of gather_frame_atom_by_indices() which invokes batched_gather()

import torch

# this is mostly from openfold.utils.torch_utils import batched_gather
def batched_gather(
    data: torch.Tensor, inds: torch.Tensor, dim: int = 0, no_batch_dims: int = 0
) -> torch.Tensor:
    """Gather data according to indices specified by inds.

    Args:
        data (torch.Tensor): the input data
            [..., K, ...]
        inds (torch.Tensor): the indices for gathering data
            [..., N]
        dim (int, optional): along which dimension to gather data by inds
            (the dimension holding "K"/"N"). Defaults to 0.
        no_batch_dims (int, optional): number of leading batch dimensions
            before the "dim" dimension. Defaults to 0.

    Returns:
        torch.Tensor: gathered data
            [..., N, ...]
    """

    # Fast path: plain 1-D fancy indexing when there are no batch dims.
    if len(inds.shape) == 1 and no_batch_dims == 0 and dim == 0:
        return data[inds]

    # Build one broadcastable arange per batch dimension so that advanced
    # indexing pairs each batch element with its own row of `inds`.
    ranges = []
    for i, s in enumerate(data.shape[:no_batch_dims]):
        r = torch.arange(s)
        r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1))))
        ranges.append(r)

    # Full slices for the non-batch dimensions, with `inds` substituted at
    # the gather dimension (expressed relative to the batch dims).
    remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
    remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
    ranges.extend(remaining_dims)
    # Index with a tuple, not a list: list indexing is deprecated and will be
    # reinterpreted as a tensor index (x[torch.tensor(seq)]) in PyTorch 2.9,
    # which would silently change the result.
    return data[tuple(ranges)]


def gather_frame_atom_by_indices(
    coordinate: torch.Tensor, frame_atom_index: torch.Tensor, dim: int = -2
) -> torch.Tensor:
    """construct frames from coordinate

    Args:
        coordinate (torch.Tensor):  the input coordinate
            [..., N_atom, 3]
        frame_atom_index (torch.Tensor): indices of three atoms in each frame
            [..., N_frame, 3] or [N_frame, 3]
        dim (int): along which dimension to select the frame atoms
    Returns:
        torch.Tensor: the constructed frames
            [..., N_frame, 3[three atom], 3[three coordinate]]
    """
    if len(frame_atom_index.shape) == 2:
        # Naive (unbatched) case: the same frame indices apply to every
        # batch element, so a plain index_select per frame atom suffices.
        per_atom_coords = [
            torch.index_select(coordinate, dim=dim, index=frame_atom_index[:, atom_col])
            for atom_col in range(3)
        ]  # each entry: [..., N_frame, 3]
        return torch.stack(per_atom_coords, dim=dim)
    else:
        assert (
            frame_atom_index.shape[:dim] == coordinate.shape[:dim]
        ), "batch size dims should match"

    # Batched case: gather each of the three frame atoms independently,
    # matching batch dimensions between coordinate and frame_atom_index.
    n_batch_dims = len(coordinate.shape[:dim])
    per_atom_coords = [
        batched_gather(
            data=coordinate,
            inds=frame_atom_index[..., atom_col],
            dim=dim,
            no_batch_dims=n_batch_dims,
        )
        for atom_col in range(3)
    ]  # each entry: [..., N_frame, 3]
    return torch.stack(per_atom_coords, dim=dim)


def batched_gather_myversion(
    data: torch.Tensor, inds: torch.Tensor
) -> torch.Tensor:
    """Gather data according to indices specified by inds along dim = inds.dim() - 1.

    The leading inds.dim() - 1 dimensions are treated as batch dimensions and
    must match between data and inds; the trailing dimensions of data (those
    beyond inds.dim()) are carried through unchanged.

    Args:
        data (torch.Tensor): the input data
            [..., K, ...]
        inds (torch.Tensor): the indices for gathering data
            [..., N]

    Returns:
        torch.Tensor: gathered data
            [..., N, ...]
    """
    assert inds.dim() <= data.dim(), "inds must have less or equal dimensions than data"
    n_batch_dims = inds.dim() - 1
    # Report the actual shapes on mismatch (the original f-string had no
    # placeholders, so the message carried no diagnostic information).
    assert inds.shape[:n_batch_dims] == data.shape[:n_batch_dims], (
        f"Batch dimensions must match between data and inds, "
        f"got {tuple(data.shape[:n_batch_dims])} and {tuple(inds.shape[:n_batch_dims])}"
    )

    if inds.dim() == data.dim():
        # No trailing dims to broadcast over: a direct gather along the last dim.
        return torch.gather(data, dim=-1, index=inds)

    # Broadcast inds over data's trailing dims: append singleton dims, then
    # expand (a view, no copy) so torch.gather sees a full-rank index tensor.
    trailing_shape = data.shape[inds.dim():]
    inds_broadcasted = inds.reshape(inds.shape + (1,) * len(trailing_shape))
    inds_broadcasted = inds_broadcasted.expand(inds.shape + trailing_shape)
    return torch.gather(data, dim=n_batch_dims, index=inds_broadcasted)


def gather_frame_atom_by_indices_my_version(
    coordinate: torch.Tensor, frame_atom_index: torch.Tensor, dim: int = -2
) -> torch.Tensor:
    """construct frames from coordinate

    Args:
        coordinate (torch.Tensor):  the input coordinate
            [..., N_atom, 3[three coordinates]]
        frame_atom_index (torch.Tensor): indices of three atoms in each frame
            [..., N_frame, 3[three atoms per frame]] or [N_frame, 3[three atoms per frame]]
        dim (int): along which dimension to select the frame atoms
    Returns:
        torch.Tensor: the constructed frames
            [..., N_frame, 3[three atoms per frame], 3[three coordinates]]
    """
    if len(frame_atom_index.shape) == 2:
        # Naive (unbatched) case: fancy indexing broadcasts the shared
        # frame indices over all leading batch dimensions of coordinate.
        return coordinate[..., frame_atom_index, :]
    else:
        assert (
            frame_atom_index.shape[:dim] == coordinate.shape[:dim]
        ), f"the size of each batch dim should match, got {frame_atom_index.shape[:dim]} and {coordinate.shape[:dim]}"

    # Flatten (N_frame, 3) into a single index axis so one batched gather
    # fetches all frame-atom coordinates at once, then restore the frame
    # structure and append the coordinate dimension.
    n_frame, atoms_per_frame = frame_atom_index.shape[-2:]
    flat_index = frame_atom_index.reshape(
        *frame_atom_index.shape[:-2], -1
    )  # [..., N_frame*3]
    flat_coords = batched_gather_myversion(
        data=coordinate,
        inds=flat_index,
    )  # [..., N_frame*3, 3[three coordinates]]
    return flat_coords.reshape(
        *flat_coords.shape[:-2], n_frame, atoms_per_frame, coordinate.shape[-1]
    )  # [..., N_frame, 3, 3]



# ----- sanity check: both implementations must agree on shape, layout, values -----
n_atoms = 100
n_frames = 11
n_atoms_per_frame = 3  # must be 3: a frame is defined by exactly three atoms
n_coordinates = 3
batch_dims = (20, 3, 4)
coordinate = torch.randn(*batch_dims, n_atoms, n_coordinates)  # shape (20, 3, 4, 100, 3)
# frame_atom_index = torch.randint(0, n_atoms, (n_frames, n_atoms_per_frame))
frame_atom_index = torch.randint(0, n_atoms, (*batch_dims, n_frames, n_atoms_per_frame))

# Move to GPU only when one is available; the unconditional
# .to(torch.device('cuda:0')) crashed on CPU-only machines.
if torch.cuda.is_available():
    coordinate = coordinate.to(torch.device('cuda:0'))
    frame_atom_index = frame_atom_index.to(torch.device('cuda:0'))

dim = -2
x1 = gather_frame_atom_by_indices(coordinate, frame_atom_index, dim=dim)
x1_002 = gather_frame_atom_by_indices_my_version(coordinate, frame_atom_index, dim=dim)
print("x1 shape:", x1.shape)  # (*batch_dims, n_frames, 3, 3) == (20, 3, 4, 11, 3, 3)
print("x1_002 shape:", x1_002.shape)  # must match x1
print("x1 stride:", x1.stride())
print("x1_002 stride:", x1_002.stride())  # must match x1 for a layout-identical result
print("x1 is contiguous:", x1.is_contiguous())  # Should be True
print("x1_002 is contiguous:", x1_002.is_contiguous())  # Should be True
print("x1 and x1_002 are equal:", torch.equal(x1, x1_002))  # Should be True







# ===== compare the execution time of the two versions =====
# NOTE: relies on `coordinate`, `frame_atom_index`, and `dim` defined by the
# sanity-check section above; statement order (warmup before timing, CUDA
# synchronize around the timed region) is deliberate and must be preserved.
import torch
import time
import torch.utils.benchmark as benchmark

# ----------------------
# Your two functions
# ----------------------
fun1 = gather_frame_atom_by_indices
fun2 = gather_frame_atom_by_indices_my_version

# ----------------------
# Setup: Create test tensor (match your real use case!)
# ----------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Warmup: VERY important for GPU (compiles kernels, allocates memory)
fun1(coordinate, frame_atom_index, dim=dim)
fun2(coordinate, frame_atom_index, dim=dim)
if DEVICE == "cuda":
    torch.cuda.synchronize()

# ==============================================================================
# Method 1: Manual timing (simple, good for quick checks)
# ==============================================================================
def time_function(func, repeats: int = 1000) -> float:
    """Return the average wall-clock seconds per call of `func` over `repeats` runs.

    Synchronizes CUDA before starting and before reading the end time, so
    asynchronous kernel launches are fully accounted for in the measurement.
    """
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    start = time.time()
    
    for _ in range(repeats):
        func(coordinate, frame_atom_index, dim)
    
    if DEVICE == "cuda":
        torch.cuda.synchronize()
    end = time.time()
    return (end - start) / repeats  # average time per call

# Run
t1 = time_function(fun1)
t2 = time_function(fun2)

print("=" * 50)
print("Manual Timing (avg per call)")
print(f"fun1: {t1*10**6:.4f} us")
print(f"fun2: {t2*10**6:.4f} us")
print(f"Faster by: {max(t1/t2, t2/t1):.2f}x")

# ==============================================================================
# Method 2: Torch Built-in Benchmark (RECOMMENDED - most accurate)
# Handles warmup, CUDA sync, statistics, outliers
# ==============================================================================
print("\n" + "=" * 50)
print("Torch Benchmark (official, accurate)")

# benchmark.Timer executes `stmt` with the given globals and manages its own
# warmup and CUDA synchronization.
t_fun1 = benchmark.Timer(
    stmt='fun1(coordinate, frame_atom_index, dim)',
    globals={'fun1': fun1, 'coordinate': coordinate, 'frame_atom_index': frame_atom_index, 'dim': dim}
)
t_fun2 = benchmark.Timer(
    stmt='fun2(coordinate, frame_atom_index, dim)',
    globals={'fun2': fun2, 'coordinate': coordinate, 'frame_atom_index': frame_atom_index, 'dim': dim}
)


res1 = t_fun1.timeit(1000)
res2 = t_fun2.timeit(1000)

print(res1)
print(res2)
print(f"Faster function: {'fun1' if res1.mean < res2.mean else 'fun2'}")
print(f"Faster by: {max(res1.mean/res2.mean, res2.mean/res1.mean):.2f}x")

On CPU, my implementation is about 5x faster.

/tmp/ipykernel_624/2240100266.py:36: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:347.)
  return data[ranges]
x1 shape: torch.Size([20, 3, 4, 11, 3, 3])
x1_002 shape: torch.Size([20, 3, 4, 11, 3, 3])
x1 stride: (1188, 396, 99, 9, 3, 1)
x1_002 stride: (1188, 396, 99, 9, 3, 1)
x1 is contiguous: True
x1_002 is contiguous: True
x1 and x1_002 are equal: True
==================================================
Manual Timing (avg per call)
fun1: 193.9402 us
fun2: 47.1077 us
Faster by: 4.12x

==================================================
Torch Benchmark (official, accurate)
<torch.utils.benchmark.utils.common.Measurement object at 0x794018597020>
fun1(coordinate, frame_atom_index, dim)
  252.56 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee2c9cc20>
fun2(coordinate, frame_atom_index, dim)
  46.05 us
  1 measurement, 1000 runs , 1 thread
Faster function: fun2
Faster by: 5.48x

On Colab, with an H100 GPU, my implementation is about 15x faster.

/tmp/ipykernel_624/961558422.py:36: UserWarning: Using a non-tuple sequence for multidimensional indexing is deprecated and will be changed in pytorch 2.9; use x[tuple(seq)] instead of x[seq]. In pytorch 2.9 this will be interpreted as tensor index, x[torch.tensor(seq)], which will result either in an error or a different result (Triggered internally at /pytorch/torch/csrc/autograd/python_variable_indexing.cpp:347.)
  return data[ranges]
x1 shape: torch.Size([20, 3, 4, 11, 3, 3])
x1_002 shape: torch.Size([20, 3, 4, 11, 3, 3])
x1 stride: (1188, 396, 99, 9, 3, 1)
x1_002 stride: (1188, 396, 99, 9, 3, 1)
x1 is contiguous: True
x1_002 is contiguous: True
x1 and x1_002 are equal: True
==================================================
Manual Timing (avg per call)
fun1: 285.7633 us
fun2: 18.6350 us
Faster by: 15.33x

==================================================
Torch Benchmark (official, accurate)
<torch.utils.benchmark.utils.common.Measurement object at 0x794019706690>
fun1(coordinate, frame_atom_index, dim)
  298.83 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x793ee2653920>
fun2(coordinate, frame_atom_index, dim)
  19.55 us
  1 measurement, 1000 runs , 1 thread
Faster function: fun2
Faster by: 15.29x

To compare the execution time of the two versions of batched_gather(), you can refer to aqlaboratory/openfold-3#135 (comment)

…dding to an atom embedding

use torch.gather() to simplify the process of broadcasting a token embedding to an atom embedding
remaining_dims = [slice(None) for _ in range(len(data.shape) - no_batch_dims)]
remaining_dims[dim - no_batch_dims if dim >= 0 else dim] = inds
ranges.extend(remaining_dims)
return data[ranges]
Copy link
Copy Markdown
Contributor Author

@OccupyMars2030 OccupyMars2030 Mar 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we are gathering data along only "one" dimension, there is no need to construct an index tensor for each dim like "ranges". Just use torch.gather() because it is more logically concise

add more detailed docstring
@OccupyMars2030
Copy link
Copy Markdown
Contributor Author

OccupyMars2030 commented Mar 14, 2026

TODO(maybe 2 weeks later): study if we can use torch.repeat_interleave(), torch.Tensor.repeat() to improve efficiency.

reference: https://github.com/aqlaboratory/openfold-3/blob/main/openfold3/core/utils/atomize_utils.py#L106, Openfold-3 include the function batched_gather(), but Openfold-3 didn't use it, how does it handle these 2 tasks:

1. broadcasting a token embedding to an atom embedding.

2. gathering the frame atom coordinates

@zhangyuxuann
Copy link
Copy Markdown
Collaborator

@OccupyMars2025
Thanks for the contribution! I'll review this PR soon. Appreciate your effort.


pad_left = (n_keys - n_queries) // 2
pad_right = int((n_trunks - 1 / 2) * n_queries + n_keys / 2 - n + 1 / 2)
pad_right = (n_keys - n_queries) // 2 + (n_trunks * n_queries - n)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line has nothing to do with my other modification. I just cannot understand the original calculation method for pad_right.

Note: both n_keys and n_queries are even integer numbers. You have used assert statement to confirm it.

https://github.com/OccupyMars2025/Protenix/wiki/explain-how-to-calculate-pad_right-for-the-key-tensor

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants