
Batched multi-seed inference: 2x speedup on 5-seed runs#284

Open
longleo17 wants to merge 3 commits into bytedance:main from longleo17:batched-seed-inference

Conversation

@longleo17
Contributor

@longleo17 longleo17 commented Mar 24, 2026

This PR adds batched multi-seed inference: the shared input embedding runs once, followed by a per-seed evoformer pass (with independent MSA subsampling), batched diffusion across all seeds, and a per-seed confidence head. This gives a 2x wall-time speedup on 5-seed inference (39s vs 71s). Builds on top of #283.

Architecture

Phase 1: Shared input embedding (run ONCE — deterministic in eval mode)
    ↓
Phase 2: Per-seed evoformer loop (MSA subsampling differs per seed)
    ↓  stack evoformer outputs → [N_seeds, N_token, ...]
Phase 3: Batched diffusion (all seeds processed together)
    ↓  prepare_cache per-seed (mixes batched z with unbatched atom features)
Phase 4: Per-seed confidence head + summary
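The four phases above can be sketched as follows. This is an illustrative numpy stand-in, not the real Protenix modules: all function names, shapes, and the toy "evoformer"/"diffusion" math here are hypothetical, chosen only to show which phases run once, which run per seed, and where the stack into `[N_seeds, N_token, ...]` happens.

```python
import numpy as np

def batched_seed_inference(features, seeds, n_token=8, dim=4):
    rng0 = np.random.default_rng(0)

    # Phase 1: shared input embedding, run ONCE (deterministic in eval mode).
    s_inputs = rng0.standard_normal((n_token, dim))

    # Phase 2: per-seed evoformer loop; MSA subsampling differs per seed.
    per_seed = []
    for seed in seeds:
        rng = np.random.default_rng(seed)
        msa_rows = rng.choice(16, size=4, replace=False)    # seed-dependent subsample
        per_seed.append(s_inputs + 0.01 * msa_rows.mean())  # toy evoformer output

    # Stack evoformer outputs -> [N_seeds, N_token, ...]
    z = np.stack(per_seed, axis=0)

    # Phase 3: batched diffusion over all seeds at once (vectorized).
    coords = np.tanh(z)

    # Phase 4: per-seed confidence head + summary.
    confidences = [float(c.mean()) for c in coords]
    return coords, confidences

coords, conf = batched_seed_inference({}, seeds=[1, 2, 3, 4, 5])
```

The key point is that only Phase 2 loops over seeds; Phases 1 and 3 are amortized across all of them, which is where the wall-time saving comes from.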

How to enable

Batched mode is enabled by default when multiple seeds are requested. To disable it, set PROTENIX_NOBATCHED_SEEDS=1. On CUDA OOM, the runner automatically falls back to sequential per-seed inference.
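A minimal sketch of the toggle-plus-fallback logic described above. `batched_fn` and `sequential_fn` are hypothetical callables standing in for the two inference paths; the real runner catches CUDA OOM from torch, which is approximated here by matching the error text.

```python
import os

def run_seeds(batched_fn, sequential_fn, seeds):
    use_batched = (
        len(seeds) > 1
        and os.environ.get("PROTENIX_NOBATCHED_SEEDS", "0") != "1"
    )
    if use_batched:
        try:
            return batched_fn(seeds)
        except RuntimeError as e:
            if "out of memory" not in str(e).lower():
                raise
            # CUDA OOM: fall through to the sequential per-seed path.
    return [sequential_fn(s) for s in seeds]
```

Non-OOM errors are re-raised unchanged, so the fallback only masks memory pressure, not real bugs.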

Key changes

  • protenix/model/protenix.py: New _shared_input_embedding() and _batched_seed_inference() methods. Modified main_inference_loop() to separate volatile features (MSA, templates — deleted by forward pass) from non-volatile features, enabling smart cloning instead of full deepcopy. Added N_model_seed_override to forward(). Batched mode now on by default with opt-out via PROTENIX_NOBATCHED_SEEDS=1.
  • runner/inference.py: Prefetch pipeline with ThreadPoolExecutor for CPU/GPU overlap. Seed caching (deep copy on first seed, replay for subsequent). Async CUDA transfer stream with pinned memory. Batched seed mode on by default with OOM fallback to sequential per-seed loop.
  • protenix/utils/torch_utils.py: to_device() now accepts non_blocking=True. New pin_memory() utility for nested dicts.
  • protenix/model/triangular/layers.py, triangular.py: Document that cuequivariance kernels natively support batch dims (no wrappers needed).
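The nested-dict device utilities from the last bullet for protenix/utils/torch_utils.py can be sketched roughly as below. This version is a hedged approximation, not the actual implementation: it duck-types on `.to()`/`.pin_memory()` rather than importing torch, so it can be exercised without a GPU, but the recursion pattern over dicts and lists is the idea being described.

```python
def to_device(obj, device, non_blocking=False):
    """Recursively move tensor-like leaves of a nested structure to a device."""
    if hasattr(obj, "to"):
        return obj.to(device, non_blocking=non_blocking)
    if isinstance(obj, dict):
        return {k: to_device(v, device, non_blocking) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(to_device(v, device, non_blocking) for v in obj)
    return obj  # plain Python values pass through unchanged

def pin_memory(obj):
    """Recursively pin tensor-like leaves so async host-to-device copies work."""
    if hasattr(obj, "pin_memory"):
        return obj.pin_memory()
    if isinstance(obj, dict):
        return {k: pin_memory(v) for k, v in obj.items()}
    return obj
```

Pinning plus `non_blocking=True` is what lets the transfer stream overlap CPU featurization with GPU compute; `non_blocking=True` without pinned memory silently degrades to a synchronous copy.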

Benchmark

| Mode              | Wall time (5 seeds, ~250 tokens) |
|-------------------|----------------------------------|
| Sequential        | 71s                              |
| Batched (default) | 39s                              |

Dependencies

Based on a previous PR (#283) that speeds up the data pipeline.

longleo17 and others added 3 commits March 24, 2026 12:19
Key optimizations:
- Featurizer.encoder: replace manual one-hot dict with torch.nn.functional.one_hot
- ref_atom_name_chars_encoded: vectorized ASCII encoding via numpy frombuffer + one_hot
- Template featurizer: pre-allocate numpy arrays instead of list append + np.stack
- Template featurizer: reuse shared DistogramFeaturesConfig instance
- Template parser: vectorized numpy operations for coordinate extraction
- Template utils: batch numpy operations for atom mask and position computation
- Dataset: replace df.apply(lambda) with df.isin() for eval_type filtering
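The `ref_atom_name_chars_encoded` vectorization from the list above can be sketched like this. The encoding convention here (names padded to 4 characters, `ord(c) - 32`, one-hot over 64 classes) follows the AlphaFold 3-style atom-name feature and is an assumption about the exact offsets Protenix uses; the point is the `np.frombuffer` + fancy-indexing pattern replacing a per-character Python loop.

```python
import numpy as np

def encode_atom_names(names, width=4, num_classes=64):
    # Pad/truncate every name to a fixed width, then view the whole string
    # as raw ASCII bytes in one shot instead of looping char by char.
    padded = "".join(n.ljust(width)[:width] for n in names)
    codes = np.frombuffer(padded.encode("ascii"), dtype=np.uint8) - 32
    codes = codes.reshape(len(names), width)
    # One-hot via row indexing into an identity matrix -> [N_atom, width, num_classes]
    return np.eye(num_classes, dtype=np.int64)[codes]

feats = encode_atom_names(["CA", "N", "OXT"])
```

The same identity-matrix indexing trick covers the "replace manual one-hot dict with one_hot" bullet; in torch the equivalent is `torch.nn.functional.one_hot`.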

Benchmarked on template+MSA heavy workloads: ~37% wall-clock reduction in
CPU featurization time, measured end-to-end on representative PDB complexes.
Architecture: shared input embedding (run once) -> per-seed evoformer
(MSA subsampling differs per seed) -> batched diffusion across all seeds
-> per-seed confidence head.

Enable with PROTENIX_BATCHED_SEEDS=1 env var. Falls back to sequential
on failure. Benchmark: 39s vs 71s wall time on 5-seed inference.

Also adds:
- GPU-CPU overlap pipeline for inference (prefetch, pinned memory, async
  CUDA transfer stream, seed caching)
- Smart copy: volatile features (MSA, templates) cloned per seed instead
  of full deepcopy
- Non-blocking to_device() and pin_memory() utilities
- Document cuequivariance native batch dim support (no wrappers needed)
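The prefetch part of the GPU-CPU overlap pipeline above can be sketched with a one-worker `ThreadPoolExecutor`: while the consumer (the GPU) processes item i, a background thread prepares item i+1 on the CPU. Names are illustrative, not the runner's actual API.

```python
from concurrent.futures import ThreadPoolExecutor

def prefetch(iterable, prepare):
    """Yield prepare(item) for each item, preparing one item ahead."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = None
        for item in iterable:
            nxt = pool.submit(prepare, item)  # start featurizing the next item
            if future is not None:
                yield future.result()         # hand back the previous one
            future = nxt
        if future is not None:
            yield future.result()
```

With depth-1 prefetch, at most one prepared batch is held in memory at a time, which keeps host RAM bounded while still hiding featurization latency behind GPU compute.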
…EDS=1

Batched multi-seed inference is now enabled by default when multiple seeds
are requested. On CUDA OOM, the runner automatically falls back to
sequential per-seed inference. Set PROTENIX_NOBATCHED_SEEDS=1 to force
sequential mode.