Batched multi-seed inference: 2x speedup on 5-seed runs #284
Open
longleo17 wants to merge 3 commits into bytedance:main from
Conversation
Key optimizations:
- Featurizer.encoder: replace manual one-hot dict with torch.nn.functional.one_hot
- ref_atom_name_chars_encoded: vectorized ASCII encoding via numpy frombuffer + one_hot
- Template featurizer: pre-allocate numpy arrays instead of list append + np.stack
- Template featurizer: reuse shared DistogramFeaturesConfig instance
- Template parser: vectorized numpy operations for coordinate extraction
- Template utils: batch numpy operations for atom mask and position computation
- Dataset: replace df.apply(lambda) with df.isin() for eval_type filtering

Benchmarked on template+MSA-heavy workloads: ~37% wall-clock reduction in CPU featurization time, measured end-to-end on representative PDB complexes.
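The ref_atom_name_chars_encoded change can be illustrated with a short numpy-only sketch. This is not the PR's code: the function name, the 4-character padding, and the 64-class vocabulary are assumptions made for the example; the PR itself uses torch.nn.functional.one_hot, which the identity-matrix indexing below mimics.

```python
import numpy as np

def encode_atom_name_chars(atom_name: str, num_classes: int = 64) -> np.ndarray:
    """Vectorized ASCII one-hot encoding of an atom name (illustrative sketch).

    Replaces a per-character Python loop: np.frombuffer reads all raw bytes
    in one call, and indexing an identity matrix produces all one-hot rows
    in one call (the numpy analogue of torch.nn.functional.one_hot).
    """
    # Pad to a fixed width of 4, since PDB atom names are at most 4 chars.
    padded = atom_name.ljust(4)
    # Shift by 32 so printable ASCII starts at class 0; atom-name characters
    # (uppercase letters, digits, apostrophes) all land below num_classes.
    codes = np.frombuffer(padded.encode("ascii"), dtype=np.uint8) - 32
    return np.eye(num_classes, dtype=np.float32)[codes]
```

Indexing `np.eye(num_classes)` with an integer array yields one one-hot row per character, so the whole name is encoded without any Python-level loop.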
Architecture: shared input embedding (run once) -> per-seed evoformer (MSA subsampling differs per seed) -> batched diffusion across all seeds -> per-seed confidence head.

Enable with PROTENIX_BATCHED_SEEDS=1 env var. Falls back to sequential on failure.

Benchmark: 39s vs 71s wall time on 5-seed inference.

Also adds:
- GPU-CPU overlap pipeline for inference (prefetch, pinned memory, async CUDA transfer stream, seed caching)
- Smart copy: volatile features (MSA, templates) cloned per seed instead of full deepcopy
- Non-blocking to_device() and pin_memory() utilities
- Document cuequivariance native batch dim support (no wrappers needed)
…EDS=1

Batched multi-seed inference is now enabled by default when multiple seeds are requested. On CUDA OOM, the runner automatically falls back to sequential per-seed inference. Set PROTENIX_NOBATCHED_SEEDS=1 to force sequential mode.
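The opt-out logic from this commit can be sketched in a few lines. Only the PROTENIX_NOBATCHED_SEEDS env var name comes from the PR; the helper function and its name are illustrative, not the actual runner code.

```python
import os

def use_batched_seeds(n_seeds: int) -> bool:
    """Illustrative sketch: batched multi-seed inference is the default
    whenever more than one seed is requested, and setting
    PROTENIX_NOBATCHED_SEEDS=1 forces the sequential per-seed path.
    """
    if os.environ.get("PROTENIX_NOBATCHED_SEEDS") == "1":
        return False  # explicit opt-out forces sequential mode
    return n_seeds > 1  # batched only pays off with multiple seeds
```

In the PR this decision is additionally revisited at runtime: a CUDA OOM during the batched path triggers an automatic fallback to the sequential loop.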
Batched multi-seed inference that runs the shared input embedding once, then the per-seed evoformer (with independent MSA subsampling), then batched diffusion across all seeds, then the per-seed confidence head. This gives a ~2x wall-time speedup on 5-seed inference (39s vs 71s). Builds on top of #283.
Architecture

shared input embedding (run once) -> per-seed evoformer (independent MSA subsampling) -> batched diffusion across all seeds -> per-seed confidence head
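The four-stage flow described above (shared embedding once, per-seed trunk, batched diffusion, per-seed confidence) can be sketched as follows. The `model` object and its stage methods are illustrative stand-ins, not the real Protenix API; a stub model is included so the control flow is runnable.

```python
def multi_seed_inference(model, features, seeds, batched=True):
    """Illustrative sketch of the batched multi-seed control flow."""
    if not batched or len(seeds) < 2:
        # Sequential fallback: each seed runs the full pipeline independently.
        return [model.full_forward(features, seed=s) for s in seeds]
    shared = model.input_embedding(features)                   # stage 1: run once, shared
    trunks = [model.evoformer(shared, seed=s) for s in seeds]  # stage 2: per-seed MSA subsampling
    coords = model.diffusion(trunks)                           # stage 3: batched across all seeds
    return [model.confidence(t, c)                             # stage 4: per-seed confidence head
            for t, c in zip(trunks, coords)]

class StubModel:
    """Minimal stand-in that records how often each stage runs."""
    def __init__(self):
        self.embed_calls = 0
    def input_embedding(self, features):
        self.embed_calls += 1
        return features
    def evoformer(self, shared, seed):
        return (shared, seed)
    def diffusion(self, trunks):
        return [trunk[1] * 10 for trunk in trunks]
    def confidence(self, trunk, coords):
        return coords
```

The key property is visible in the stub: `input_embedding` runs exactly once regardless of how many seeds are requested, while evoformer and confidence run once per seed.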
How to enable
Enabled by default when multiple seeds are requested. To disable, set PROTENIX_NOBATCHED_SEEDS=1. On CUDA OOM, the runner automatically falls back to sequential per-seed inference.

Key changes

- protenix/model/protenix.py: New _shared_input_embedding() and _batched_seed_inference() methods. Modified main_inference_loop() to separate volatile features (MSA, templates, which are deleted by the forward pass) from non-volatile features, enabling smart cloning instead of a full deepcopy. Added N_model_seed_override to forward(). Batched mode is now on by default, with opt-out via PROTENIX_NOBATCHED_SEEDS=1.
- runner/inference.py: Prefetch pipeline with ThreadPoolExecutor for CPU/GPU overlap. Seed caching (deep copy on first seed, replay for subsequent seeds). Async CUDA transfer stream with pinned memory. Batched seed mode on by default with OOM fallback to the sequential per-seed loop.
- protenix/utils/torch_utils.py: to_device() now accepts non_blocking=True. New pin_memory() utility for nested dicts.
- protenix/model/triangular/layers.py, triangular.py: Document that cuequivariance kernels natively support batch dims (no wrappers needed).

Benchmark

39s vs 71s wall time on 5-seed inference.
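The ThreadPoolExecutor prefetch pattern from runner/inference.py can be sketched with stdlib-only code. The function and its `load_fn`/`process_fn` parameters are illustrative names, not the PR's actual API; in the real runner, `load_fn` stands in for CPU featurization plus pinned-memory transfer, and `process_fn` for the GPU forward pass.

```python
from concurrent.futures import ThreadPoolExecutor

def prefetch_pipeline(load_fn, process_fn, items):
    """Illustrative sketch of CPU/GPU overlap via prefetching.

    While the current item is being processed, the next item's load runs
    on a worker thread, so loading cost is hidden behind processing time.
    """
    if not items:
        return []
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        # Warm up: start loading the first item immediately.
        future = pool.submit(load_fn, items[0])
        for i in range(len(items)):
            data = future.result()  # wait only if the load hasn't finished
            if i + 1 < len(items):
                # Kick off the next load before processing the current item.
                future = pool.submit(load_fn, items[i + 1])
            results.append(process_fn(data))
    return results
```

A single worker is enough here because each load only needs to finish before the *next* processing step begins; the PR pairs this with an async CUDA transfer stream so the host-to-device copy also overlaps with compute.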
Dependencies
Based on #283, a previous PR that speeds up the data pipeline.