
Reduce EMA RAM usage and training overhead with local-shard EMA #47

Open

chijw wants to merge 1 commit into NVlabs:main from chijw:main

Conversation


@chijw chijw commented Mar 15, 2026

Summary

This PR changes the EMA path so that EMA is maintained as local shards during training instead of materializing full parameters on every rank.

The previous implementation used summon_full_params() in the EMA hot path, which adds unnecessary communication and keeps a full CPU EMA copy on each rank. With this change, each rank updates only its local EMA shard during training, which reduces both EMA memory usage and per-step overhead.
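
As a rough sketch of the idea (not the exact implementation in this PR; the shadow dict and the update method name are illustrative), a shard-local EMA clones only the parameters each rank already owns and updates them in place:

    import torch

    class EMA_FSDP:
        """Shard-local EMA: each rank tracks only the parameters it already owns."""

        def __init__(self, fsdp_module, decay=0.999):
            self.decay = decay
            # Clone the rank-local (sharded) parameters; no summon_full_params(),
            # so no all-gather and no full-size CPU copy per rank.
            self.shadow = {
                n: p.detach().clone()
                for n, p in fsdp_module.module.named_parameters()
            }

        @torch.no_grad()
        def update(self, fsdp_module):
            # Per-step cost scales with the local shard size, not the full model.
            for n, p in fsdp_module.module.named_parameters():
                self.shadow[n].mul_(self.decay).add_(p.detach(), alpha=1.0 - self.decay)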

To preserve the existing checkpoint format, generator_ema is still exported as a full state dict at save time. Since EMA is shard-local during training, full_state_dict() temporarily swaps the EMA shards into the FSDP-wrapped module and reuses fsdp_state_dict() to gather the full checkpoint, rather than introducing a separate EMA-specific export path.
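
Roughly, the export works like the following sketch. The self.shadow attribute stands in for the local EMA shards and is an assumed name, not necessarily what the PR uses:

    def full_state_dict(self, fsdp_module):
        # Stash the live (trained) shard values so they can be restored afterwards.
        live_state = {n: p.detach().clone()
                      for n, p in fsdp_module.module.named_parameters()}
        # Temporarily swap the EMA shards into the FSDP-wrapped module ...
        for n, p in fsdp_module.module.named_parameters():
            if n in self.shadow:
                p.data.copy_(self.shadow[n].to(dtype=p.dtype, device=p.device))
        # ... so the existing fsdp_state_dict() helper gathers a full EMA checkpoint.
        checkpoint = fsdp_state_dict(fsdp_module)
        # Put the live weights back so training continues unaffected.
        for n, p in fsdp_module.module.named_parameters():
            if n in live_state:
                p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
        return checkpoint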

Changes

  • keep EMA_FSDP shard-local during training
  • remove summon_full_params() from EMA init/update/copy
  • export full generator_ema only at save time
  • switch trainer save paths to use self.generator_ema.full_state_dict(self.model.generator)
  • use to(dtype=..., device=...) when copying EMA tensors back, for better compatibility with newer PyTorch versions
  • adjust checkpoint restore order so that, when generator_ema exists, EMA is initialized from EMA weights before restoring the live generator weights (see the sketch after this list)
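
A condensed sketch of the resulting trainer flow, pieced together from the diff below; the surrounding checkpoint plumbing is omitted:

    # Save path (sketch): gather the full EMA state dict only when writing a checkpoint.
    checkpoint["generator_ema"] = self.generator_ema.full_state_dict(self.model.generator)

    # Restore path (sketch): when a generator_ema entry exists, seed the EMA shards
    # from the EMA weights first, then restore the live generator weights as before.
    if "generator_ema" in checkpoint:
        self.model.generator.load_state_dict(checkpoint["generator_ema"], strict=True)
        self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
    # ... the live generator weights are loaded after this, overwriting the temporary copy.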

Copilot AI review requested due to automatic review settings March 15, 2026 06:16

Copilot AI left a comment


Pull request overview

This PR refactors EMA handling for FSDP training so EMA is maintained as rank-local parameter shards during training, reducing memory and hot-path overhead, while still exporting a full generator_ema at checkpoint time.

Changes:

  • Remove summon_full_params() usage from EMA init/update/copy paths to keep EMA shard-local.
  • Add EMA_FSDP.full_state_dict() to export a full EMA checkpoint by temporarily swapping params and using FSDP full-state-dict gathering.
  • Update trainer checkpoint save/load flow to initialize EMA from EMA weights (when present) and save via full_state_dict().

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

  • utils/distributed.py: Makes EMA shard-local during training; adds EMA full-checkpoint export via FSDP full state dict.
  • trainer/distillation.py: Updates EMA init/restore ordering and switches checkpoint saving to export full EMA only at save time.



Comment on lines +129 to +132
def full_state_dict(self, fsdp_module):
    live_state = {}
    for n, p in fsdp_module.module.named_parameters():
        live_state[n] = p.detach().clone()
    # ... (lines omitted in this hunk)
    checkpoint = fsdp_state_dict(fsdp_module)
    for n, p in fsdp_module.module.named_parameters():
        if n in live_state:
            p.data.copy_(live_state[n].to(dtype=p.dtype, device=p.device))
Comment on lines +485 to +488
    self.model.generator.load_state_dict(checkpoint["generator_ema"], strict=True)
    if self.is_main_process:
        print(f"Setting up EMA with weight {ema_weight}")
    self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
elif (ema_weight is not None) and (ema_weight > 0.0) and (not self.is_lora_enabled) and self.is_main_process:
    print("Warning: EMA checkpoint not found or EMA not initialized.")