From e8dd92765b18b859348ba919c4968f1bb6daee15 Mon Sep 17 00:00:00 2001
From: Jason
Date: Sun, 22 Mar 2026 01:29:56 +0100
Subject: [PATCH] fix: skip pin_memory and non_blocking transfer for tensor
 subclasses in group offloading

Tensor subclasses such as torchao's AffineQuantizedTensor may not support
pin_memory() correctly and can silently lose quantization metadata when
pinned. Similarly, non-blocking host-to-device transfers may race with
forward computation for subclasses that override .to().

- _init_cpu_param_dict: call pin_memory() only for plain torch.Tensor
- _pinned_memory_tensors: same guard in the context-manager path
- _transfer_tensor_to_device: force synchronous transfer for subclasses,
  and skip record_stream() when the result is not a plain tensor

Fixes #13281
---
 src/diffusers/hooks/group_offloading.py | 27 ++++++++++++++++---------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..bbf916f4436e 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -129,19 +129,25 @@ def _init_cpu_param_dict(self):
         if self.stream is None:
             return cpu_param_dict
 
+        def _maybe_pin(data):
+            # Only pin plain torch.Tensor instances. Subclasses (e.g. torchao
+            # AffineQuantizedTensor) may not support pin_memory() correctly and
+            # can silently lose their quantization metadata.
+            if self.low_cpu_mem_usage or type(data) is not torch.Tensor:
+                return data
+            return data.pin_memory()
+
         for module in self.modules:
             for param in module.parameters():
-                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                cpu_param_dict[param] = _maybe_pin(param.data.cpu())
             for buffer in module.buffers():
-                cpu_param_dict[buffer] = (
-                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
-                )
+                cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())
 
         for param in self.parameters:
-            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            cpu_param_dict[param] = _maybe_pin(param.data.cpu())
 
         for buffer in self.buffers:
-            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())
 
         return cpu_param_dict
 
@@ -149,7 +155,7 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                param: (tensor.pin_memory() if not tensor.is_pinned() and type(tensor) is torch.Tensor else tensor)
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
@@ -157,8 +163,11 @@ def _init_cpu_param_dict(self):
             pinned_dict = None
 
     def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
-        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream:
+        # Tensor subclasses (e.g. torchao AffineQuantizedTensor) may not support
+        # non-blocking transfers correctly, so fall back to synchronous copy.
+        non_blocking = self.non_blocking and type(source_tensor) is torch.Tensor
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=non_blocking)
+        if self.record_stream and type(tensor.data) is torch.Tensor:
             tensor.data.record_stream(default_stream)
 
     def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):