Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/diffusers/hooks/group_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,36 +129,45 @@ def _init_cpu_param_dict(self):
if self.stream is None:
return cpu_param_dict

def _maybe_pin(data):
# Only pin plain torch.Tensor instances. Subclasses (e.g. torchao
# AffineQuantizedTensor) may not support pin_memory() correctly and
# can silently lose their quantization metadata.
if self.low_cpu_mem_usage or type(data) is not torch.Tensor:
return data
return data.pin_memory()

for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = _maybe_pin(param.data.cpu())
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)
cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())

for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = _maybe_pin(param.data.cpu())

for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
cpu_param_dict[buffer] = _maybe_pin(buffer.data.cpu())

return cpu_param_dict

@contextmanager
def _pinned_memory_tensors(self):
try:
pinned_dict = {
param: tensor.pin_memory() if not tensor.is_pinned() else tensor
param: (tensor.pin_memory() if not tensor.is_pinned() and type(tensor) is torch.Tensor else tensor)
for param, tensor in self.cpu_param_dict.items()
}
yield pinned_dict
finally:
pinned_dict = None

def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if self.record_stream:
# Tensor subclasses (e.g. torchao AffineQuantizedTensor) may not support
# non-blocking transfers correctly, so fall back to synchronous copy.
non_blocking = self.non_blocking and type(source_tensor) is torch.Tensor
tensor.data = source_tensor.to(self.onload_device, non_blocking=non_blocking)
if self.record_stream and type(tensor.data) is torch.Tensor:
tensor.data.record_stream(default_stream)

def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
Expand Down