diff --git a/src/diffusers/hooks/group_offloading.py b/src/diffusers/hooks/group_offloading.py
index 891ac28455af..bbf916f4436e 100644
--- a/src/diffusers/hooks/group_offloading.py
+++ b/src/diffusers/hooks/group_offloading.py
@@ -129,19 +129,25 @@ def _init_cpu_param_dict(self):
         if self.stream is None:
             return cpu_param_dict
 
+        def _maybe_pin(data):
+            # Only pin plain torch.Tensor instances. Subclasses (e.g. torchao
+            # AffineQuantizedTensor) may not support pin_memory() correctly and
+            # can silently lose their quantization metadata.
+            if self.low_cpu_mem_usage or type(data) is not torch.Tensor:
+                return data
+            return data.pin_memory()
+
         for module in self.modules:
             for param in module.parameters():
-                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+                cpu_param_dict[param] = _maybe_pin(param.data.cpu())
             for buffer in module.buffers():
-                cpu_param_dict[buffer] = (
-                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
-                )
+                cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())
 
         for param in self.parameters:
-            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
+            cpu_param_dict[param] = _maybe_pin(param.data.cpu())
 
         for buffer in self.buffers:
-            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
+            cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())
 
         return cpu_param_dict
 
@@ -149,7 +155,7 @@ def _init_cpu_param_dict(self):
     def _pinned_memory_tensors(self):
         try:
             pinned_dict = {
-                param: tensor.pin_memory() if not tensor.is_pinned() else tensor
+                param: (tensor.pin_memory() if not tensor.is_pinned() and type(tensor) is torch.Tensor else tensor)
                 for param, tensor in self.cpu_param_dict.items()
             }
             yield pinned_dict
@@ -157,8 +163,11 @@ def _init_cpu_param_dict(self):
             pinned_dict = None
 
     def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
-        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream:
+        # Tensor subclasses (e.g. torchao AffineQuantizedTensor) may not support
+        # non-blocking transfers correctly, so fall back to synchronous copy.
+        non_blocking = self.non_blocking and type(source_tensor) is torch.Tensor
+        tensor.data = source_tensor.to(self.onload_device, non_blocking=non_blocking)
+        if self.record_stream and type(tensor.data) is torch.Tensor:
             tensor.data.record_stream(default_stream)
 
     def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):