Reduce EMA RAM usage and training overhead with local-shard EMA#47
Open
chijw wants to merge 1 commit intoNVlabs:mainfrom
Open
Reduce EMA RAM usage and training overhead with local-shard EMA#47chijw wants to merge 1 commit intoNVlabs:mainfrom
chijw wants to merge 1 commit intoNVlabs:mainfrom
Conversation
There was a problem hiding this comment.
Pull request overview
This PR refactors EMA handling for FSDP training so EMA is maintained as rank-local parameter shards during training, reducing memory and hot-path overhead, while still exporting a full generator_ema at checkpoint time.
Changes:
- Remove
summon_full_params()usage from EMA init/update/copy paths to keep EMA shard-local. - Add
EMA_FSDP.full_state_dict()to export a full EMA checkpoint by temporarily swapping params and using FSDP full-state-dict gathering. - Update trainer checkpoint save/load flow to initialize EMA from EMA weights (when present) and save via
full_state_dict().
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| utils/distributed.py | Makes EMA shard-local during training; adds EMA full-checkpoint export via FSDP full state dict. |
| trainer/distillation.py | Updates EMA init/restore ordering and switches checkpoint saving to export full EMA only at save time. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
You can also share your feedback on Copilot code review. Take the survey.
Comment on lines
+129
to
+132
| checkpoint = fsdp_state_dict(fsdp_module) | ||
| for n, p in fsdp_module.module.named_parameters(): | ||
| if n in live_state: | ||
| p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device)) |
| def full_state_dict(self, fsdp_module): | ||
| live_state = {} | ||
| for n, p in fsdp_module.module.named_parameters(): | ||
| live_state[n] = p.detach().clone() |
Comment on lines
+485
to
+488
| self.model.generator.load_state_dict(checkpoint["generator_ema"], strict=True) | ||
| if self.is_main_process: | ||
| print(f"Setting up EMA with weight {ema_weight}") | ||
| self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight) |
| print(f"Setting up EMA with weight {ema_weight}") | ||
| self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight) | ||
| elif (ema_weight is not None) and (ema_weight > 0.0) and (not self.is_lora_enabled) and self.is_main_process: | ||
| print("Warning: EMA checkpoint not found or EMA not initialized.") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.
The previous implementation used
summon_full_params()in the EMA hot path, which adds unnecessary communication and keeps a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, which reduces both EMA memory usage and per-step overhead.To preserve the existing checkpoint format,
generator_emais still exported as a full state dict at save time. Since EMA is shard-local during training,full_state_dict()reuses the FSDP-wrapped module together withfsdp_state_dict()to gather the full checkpoint, instead of introducing a separate EMA-specific export path.Changes
EMA_FSDPshard-local during trainingsummon_full_params()from EMA init/update/copygenerator_emaonly at save timeself.generator_ema.full_state_dict(self.model.generator)to(dtype=..., device=...)when copying EMA tensors back, for better compatibility with newer PyTorch versions