fix: skip pin_memory and non_blocking transfer for tensor subclasses in group offloading#13305

Open
s-zx wants to merge 1 commit into huggingface:main from s-zx:fix/group-offloading-torchao-subclass-v2
Conversation

@s-zx s-zx commented Mar 22, 2026

What does this PR do?

Fixes #13281.

When combining torchao quantization (e.g. Float8WeightOnlyConfig) with group offloading (use_stream=True), inference fails with a device mismatch: the quantized weight remains on CPU while the input is on CUDA.

Root cause

GroupOffloadingHook._init_cpu_param_dict calls .pin_memory() on every parameter/buffer, and _pinned_memory_tensors pins them again before the async host-to-device transfer. For plain torch.Tensor this works correctly, but for tensor subclasses such as torchao's AffineQuantizedTensor, pin_memory() may silently strip the quantization metadata, producing a plain CPU float tensor. The subsequent .to(device, non_blocking=True) then moves a plain float tensor instead of the quantized one, leaving the quantized weight behind on CPU.

Similarly, _transfer_tensor_to_device schedules record_stream for all results, but AffineQuantizedTensor.to(...) may not return a plain torch.Tensor, so record_stream could raise or be a no-op.
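The distinction the fix relies on is an exact-type check: a tensor subclass instance is still an `isinstance` of `torch.Tensor`, so only `type(t) is torch.Tensor` reliably separates plain tensors from subclasses. A minimal sketch of that check (`is_plain_tensor` and `MySubclassTensor` are illustrative names, not part of the PR):

```python
import torch

def is_plain_tensor(t: torch.Tensor) -> bool:
    # Exact-type check: subclass instances (e.g. torchao's
    # AffineQuantizedTensor) fail this test even though they pass
    # isinstance(t, torch.Tensor), so they can be routed around
    # pin_memory() and the async-transfer path.
    return type(t) is torch.Tensor

class MySubclassTensor(torch.Tensor):
    # Stand-in for a quantized tensor subclass.
    pass

plain = torch.randn(2, 2)
sub = plain.as_subclass(MySubclassTensor)
print(is_plain_tensor(plain))  # True
print(is_plain_tensor(sub))    # False
```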

Fix

  • _init_cpu_param_dict: introduce _maybe_pin() helper that calls pin_memory() only when type(data) is torch.Tensor. Tensor subclasses are stored as-is (already on CPU).
  • _pinned_memory_tensors: apply the same type(tensor) is torch.Tensor guard.
  • _transfer_tensor_to_device: force synchronous (blocking) transfer for tensor subclasses, and skip record_stream when the result is not a plain tensor.

Plain-tensor performance is completely unaffected; the async-stream path is preserved for all non-quantized parameters.

…in group offloading

Tensor subclasses such as torchao's AffineQuantizedTensor may not support
pin_memory() correctly and can silently lose quantization metadata when
pinned. Similarly, non-blocking host-to-device transfers may race with
forward computation for subclasses that override .to().

- _init_cpu_param_dict: call pin_memory() only for plain torch.Tensor
- _pinned_memory_tensors: same guard in the context-manager path
- _transfer_tensor_to_device: force synchronous transfer for subclasses,
  and skip record_stream() when the result is not a plain tensor

Fixes huggingface#13281

Successfully merging this pull request may close these issues.

Group offloading with use_stream=True breaks torchao quantized models (device mismatch) in qwen image
