
fix: disable non-blocking tensor copies to MPS during model loading#13308

Open
agarwalprakhar2511 wants to merge 1 commit into huggingface:main from agarwalprakhar2511:fix/mps-weight-corruption-non-blocking

Conversation


@agarwalprakhar2511 agarwalprakhar2511 commented Mar 22, 2026

Summary

Fixes silent weight corruption when loading models with device_map="mps".

Root cause

load_model_dict_into_meta unconditionally sets non_blocking=True for set_module_tensor_to_device (added for accelerate > 1.8.1). With mmap-backed safetensors, the source CPU tensor can be released or recycled before the asynchronous MPS copy completes, leaving the MPS parameter filled with garbage data.

The corruption is non-deterministic and dtype-dependent:

  • float32 + MPS: weights corrupted, biases OK
  • float16 + MPS: biases corrupted, weights OK

This produces extreme values (~1e37), LayerNorm overflow, and NaN/zero outputs (all-black images). The user's workaround — loading to CPU first, then calling .to("mps") — works because .to() is synchronous and the state dict is still alive.

Fix

In load_model_dict_into_meta, detect when the target device is MPS and force non_blocking=False for those transfers. All other devices (CUDA, CPU) continue to use non_blocking=True.

# Force a synchronous copy when the target device is MPS;
# all other targets keep the non-blocking fast path.
is_mps_target = str(param_device) == "mps" or (
    isinstance(param_device, torch.device) and param_device.type == "mps"
)
set_module_kwargs["non_blocking"] = not is_mps_target
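For illustration, the predicate can be exercised without a torch install by duck-typing the device object. This is a sketch, not the PR's code: `is_mps_target` below folds the `isinstance(param_device, torch.device)` check into a `getattr` lookup (a `torch.device` exposes a `.type` attribute), and the `SimpleNamespace` stand-in for `torch.device("mps")` is hypothetical.

```python
from types import SimpleNamespace


def is_mps_target(param_device):
    # Mirrors the PR's check: either the literal string "mps", or a device
    # object whose .type attribute is "mps" (as torch.device("mps") would be).
    return str(param_device) == "mps" or getattr(param_device, "type", None) == "mps"


# Stand-in for torch.device("mps"): str() of it is not "mps",
# so the attribute branch is what matches here.
mps_dev = SimpleNamespace(type="mps")

assert is_mps_target("mps") is True
assert is_mps_target(mps_dev) is True
assert is_mps_target("cpu") is False
assert is_mps_target("cuda:0") is False
assert is_mps_target(0) is False  # integer device index stays non-blocking
```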

Affected pipelines

Any pipeline loaded with device_map="mps", including GlmImagePipeline (reported), Flux, and Stable Diffusion. The fix is in shared loading infrastructure, not pipeline-specific code.

Test plan

  • Verified MPS device detection produces correct non_blocking values for "mps", torch.device("mps"), "cpu", "cuda:0", and integer device indices.
  • The reporter's reproduction (GlmImagePipeline.from_pretrained(..., device_map="mps")) should no longer produce NaN outputs. I don't have access to the 10GB+ GLM-Image model weights to run end-to-end, but the code path is clear and the fix is minimal.
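The first bullet of the test plan can be sketched as a table-driven check. Assumed names: `resolve_non_blocking` is a hypothetical helper mirroring the PR's assignment, and `SimpleNamespace(type="mps")` stands in for `torch.device("mps")` so the sketch runs without torch.

```python
from types import SimpleNamespace


def resolve_non_blocking(param_device):
    # Hypothetical helper mirroring the PR's logic: MPS targets are forced
    # to a synchronous copy; every other device keeps non_blocking=True.
    is_mps = str(param_device) == "mps" or getattr(param_device, "type", None) == "mps"
    return not is_mps


# Device spec -> expected non_blocking value, per the test plan above.
cases = [
    ("mps", False),                         # forced synchronous
    (SimpleNamespace(type="mps"), False),   # torch.device("mps") stand-in
    ("cpu", True),
    ("cuda:0", True),
    (0, True),                              # integer device index
]
for device, expected in cases:
    assert resolve_non_blocking(device) == expected
```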

Fixes #13227

When loading model weights with `device_map="mps"`, `load_model_dict_into_meta`
unconditionally passes `non_blocking=True` to `set_module_tensor_to_device`
(accelerate > 1.8.1). With mmap-backed safetensors the source CPU memory can
be released before the asynchronous MPS copy completes, silently corrupting
the destination weights.

The corruption is non-deterministic and dtype-dependent (float32 corrupts
weights but not biases; float16 corrupts biases but not weights), producing
extreme values (~1e37), LayerNorm overflow, and NaN outputs.

Move the `non_blocking` / `clear_cache` assignment after `param_device` is
resolved, and force `non_blocking=False` when the target is MPS.

Fixes huggingface#13227

Made-with: Cursor
agarwalprakhar2511 force-pushed the fix/mps-weight-corruption-non-blocking branch from a6cfdf7 to 9da9b68 on March 22, 2026 at 09:26

Development

Successfully merging this pull request may close these issues.

[Bug] GlmImagePipeline silently corrupts weights on MPS accelerator
