fix: disable non-blocking tensor copies to MPS during model loading #13308
Open

agarwalprakhar2511 wants to merge 1 commit into huggingface:main
Conversation
When loading model weights with `device_map="mps"`, `load_model_dict_into_meta` unconditionally passes `non_blocking=True` to `set_module_tensor_to_device` (accelerate > 1.8.1). With mmap-backed safetensors the source CPU memory can be released before the asynchronous MPS copy completes, silently corrupting the destination weights. The corruption is non-deterministic and dtype-dependent (float32 corrupts weights but not biases; float16 corrupts biases but not weights), producing extreme values (~1e37), LayerNorm overflow, and NaN outputs.

Move the `non_blocking` / `clear_cache` assignment after `param_device` is resolved, and force `non_blocking=False` when the target is MPS.

Fixes huggingface#13227

Made-with: Cursor
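The reordered flag logic can be sketched as follows. This is a hypothetical, torch-free simplification: `plan_copy_flags` and the flat `device_map` shape are illustrative only; the real logic lives inside `load_model_dict_into_meta`.

```python
def plan_copy_flags(param_name, device_map):
    """Hypothetical simplification of the fixed ordering: resolve the
    parameter's target device first, then derive non_blocking from it,
    forcing synchronous copies when the target is MPS."""
    # Resolve the target device for this parameter (flat device_map assumed).
    param_device = device_map.get(param_name, device_map.get("", "cpu"))
    # With mmap-backed safetensors, an asynchronous MPS copy can outlive
    # its CPU source pages, so MPS targets must copy synchronously.
    non_blocking = not str(param_device).startswith("mps")
    return param_device, non_blocking

print(plan_copy_flags("conv_in.weight", {"": "mps"}))     # → ('mps', False)
print(plan_copy_flags("conv_in.weight", {"": "cuda:0"}))  # → ('cuda:0', True)
```

The key point is only the ordering: the flag cannot be chosen before `param_device` is known, because whether the copy may be asynchronous depends on the resolved target.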
Force-pushed from a6cfdf7 to 9da9b68
Summary
Fixes silent weight corruption when loading models with `device_map="mps"`.

Root cause
`load_model_dict_into_meta` unconditionally sets `non_blocking=True` for `set_module_tensor_to_device` (added for accelerate > 1.8.1). With mmap-backed safetensors, the source CPU tensor can be released or recycled before the asynchronous MPS copy completes, leaving the MPS parameter filled with garbage data. The corruption is non-deterministic and dtype-dependent:
- `float32` + MPS: weights corrupted, biases OK
- `float16` + MPS: biases corrupted, weights OK

This produces extreme values (~1e37), LayerNorm overflow, and NaN/zero outputs (all-black images). The user's workaround, loading to CPU first and then calling `.to("mps")`, works because `.to()` is synchronous and the state dict is still alive.

Fix
In `load_model_dict_into_meta`, detect when the target device is MPS and force `non_blocking=False` for those transfers. All other devices (CUDA, CPU) continue to use `non_blocking=True`.

Affected pipelines
Any pipeline loaded with `device_map="mps"`, including but not limited to `GlmImagePipeline` (reported), Flux, StableDiffusion, etc. The fix is in shared loading infrastructure, not pipeline-specific code.

Test plan
- Unit test: checks the resolved `non_blocking` values for `"mps"`, `torch.device("mps")`, `"cpu"`, `"cuda:0"`, and integer device indices.
- Manual check: `GlmImagePipeline.from_pretrained(..., device_map="mps")` should no longer produce NaN outputs. I don't have access to the 10GB+ GLM-Image model weights to run end-to-end, but the code path is clear and the fix is minimal.

Fixes #13227