Skip to content
Open
2 changes: 1 addition & 1 deletion auto_round/auto_scheme/delta_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
apply_quant_scheme,
compute_avg_bits_for_scheme,
compute_layer_bits,
dispatch_model_by_all_available_devices,
parse_shared_layers,
remove_quant_scheme,
)
Expand All @@ -45,6 +44,7 @@
SUPPORTED_LAYER_TYPES,
check_to_quantized,
clear_memory,
dispatch_model_by_all_available_devices,
get_block_names,
get_major_device,
get_module,
Expand Down
50 changes: 1 addition & 49 deletions auto_round/auto_scheme/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from auto_round.schemes import QuantizationScheme, preset_name_to_scheme
from auto_round.utils import (
DEVICE_ENVIRON_VARIABLE_MAPPING,
SUPPORTED_LAYER_TYPES,
check_to_quantized,
get_block_names,
Expand Down Expand Up @@ -213,55 +214,6 @@ def compute_layer_bits(
return total_bits, avg_bits


# NOTE: this dispatch does NOT follow a dict device_map literally; it only
# extracts the set of available devices from it and balances the model
# across all of them.
def dispatch_model_by_all_available_devices(
    model: torch.nn.Module, device_map: Union[str, int, dict, None]
) -> torch.nn.Module:
    """Dispatch ``model`` across every device referenced by ``device_map``.

    Behavior by ``device_map`` value:
      * ``None`` — treated as device 0 (single device).
      * ``"auto"`` — let accelerate balance memory over all visible devices.
      * anything else — parse the devices it mentions; place on the single
        device if only one, otherwise balance over exactly those devices.

    Args:
        model: The model to place; modified/moved in place and also returned.
        device_map: Device specification (string, int, dict, or None).

    Returns:
        The dispatched model (hooks attached when split across devices).

    Raises:
        ValueError: If a parsed device entry is not a recognizable string.
    """
    if device_map is None:
        device_map = 0

    no_split_modules = normalize_no_split_modules(getattr(model, "_no_split_modules", []))
    if device_map == "auto":
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=no_split_modules,
        )
        device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=no_split_modules)
        model = dispatch_model(model, device_map=device_map)
        return model

    devices = parse_available_devices(device_map)

    # Single device: a plain .to() suffices, no dispatch hooks needed.
    if len(devices) == 1:
        model.to(devices[0])
        return model

    max_memory = get_balanced_memory(
        model,
        max_memory=None,
        no_split_module_classes=no_split_modules,
    )

    # Restrict the balanced-memory budget to only the requested devices.
    # Assumes at most one accelerator type is present (hence index 0 for a
    # bare accelerator name like "cuda") — TODO confirm for multi-type setups.
    new_max_memory = {}
    for device in devices:
        if ":" in device:
            key = int(device.split(":")[-1])
        elif device == "cpu":
            key = "cpu"
        elif isinstance(device, str):
            # Bare accelerator name without an index, e.g. "cuda"/"xpu".
            key = 0
        else:
            raise ValueError(f"Unsupported device {device} in device_map: {device_map}")
        new_max_memory[key] = max_memory[key]
    model.tie_weights()
    # BUG FIX: previously the unfiltered ``max_memory`` was passed here, so
    # the device filtering above was dead code and the model could be placed
    # on devices the caller excluded. Pass the filtered ``new_max_memory``.
    device_map = infer_auto_device_map(model, max_memory=new_max_memory, no_split_module_classes=no_split_modules)
    model = dispatch_model(model, device_map=device_map)
    return model


def merge_lists_unionfind(list_of_lists):
parent = {}

Expand Down
15 changes: 12 additions & 3 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,11 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
block = convert_module_to_hp_if_necessary(block, dtype=self.amp_dtype, device=self.device)
update_block_global_scale_if_needed(block, self.data_type, self.group_size)
self._register_act_max_hook(block)
if is_auto_device_mapping(self.device_map) and len(self.device_list) > 1:
if (
is_auto_device_mapping(self.device_map)
and len(self.device_list) > 1
and not getattr(self, "is_diffusion", False)
):
set_auto_device_map_for_block_with_tuning(
block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, self.device
)
Expand Down Expand Up @@ -2320,7 +2324,8 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
max_memory=new_max_memory,
no_split_module_classes=no_split_modules,
)
self.model.tie_weights()
if hasattr(self.model, "tie_weights") and callable(self.model.tie_weights):
self.model.tie_weights()
device_map = infer_auto_device_map(
self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
)
Expand Down Expand Up @@ -2979,7 +2984,11 @@ def _quantize_block(
if auto_offload:
# card_0_in_high_risk indicates that card_0 memory is already in high usage (90%) w/o any weights
# loss_device is used to calculate loss on the second device if available and card_0_in_high_risk
if is_auto_device_mapping(self.device_map) and len(self.device_list) > 1:
if (
is_auto_device_mapping(self.device_map)
and len(self.device_list) > 1
and not getattr(self, "is_diffusion", False)
):
card_0_in_high_risk, loss_device = set_auto_device_map_for_block_with_tuning(
block, self.device_map, input_ids, self.low_gpu_mem_usage, self.batch_size, device
)
Expand Down
9 changes: 6 additions & 3 deletions auto_round/compressors/diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,12 @@ auto-round \

For diffusion models, we currently only validate quantization on FLUX.1-dev, which involves quantizing the transformer component of the pipeline.

| Model | calibration dataset |
|--------------|--------------|
| black-forest-labs/FLUX.1-dev | COCO2014 |
| Model | calibration dataset | Model Link |
|---------------|---------------------|--------------|
| black-forest-labs/FLUX.1-dev | COCO2014 | - |
| Tongyi-MAI/Z-Image | COCO2014 | - |
| Tongyi-MAI/Z-Image-Turbo | COCO2014 | - |
| stepfun-ai/NextStep-1.1 | COCO2014 | - |



Expand Down
Loading
Loading