
AutoScheme not supported in MLLM/VLM mode (Qwen3-VL) — timeline for support? #1273

@tosharybkin

Description

I’m trying to use AutoScheme (adaptive mixed-bit scheme search) with a multimodal (VLM/MLLM) model, but AutoRound currently reports that AutoScheme is not supported for multimodal LLMs.


What I’m doing

Model: Qwen/Qwen3-VL-30B-A3B-Instruct
Goal: automatically pick a mixed scheme close to avg_bits=2.8 using options like ("W2A16", "W4A16") on real image+prompt calibration samples.
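As a back-of-the-envelope check on what avg_bits=2.8 implies with only these two options (my own simple linear-average assumption over quantized weights; AutoRound's actual accounting, e.g. scale/zero-point overhead when ignore_scale_zp_bits=False, may differ):

```python
# Rough sanity check: with only W2 and W4 options, what mix of weights
# reaches a given average bit-width? Assumes a plain linear average over
# quantized weights; AutoRound's real bookkeeping may differ.
def w4_fraction(avg_bits: float, low: int = 2, high: int = 4) -> float:
    """Fraction of weights that must be at `high` bits to hit avg_bits."""
    return (avg_bits - low) / (high - low)

frac = w4_fraction(2.8)
print(f"~{frac:.0%} of weights at W4, ~{1 - frac:.0%} at W2")
# -> ~40% of weights at W4, ~60% at W2
```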

Minimal code:

import torch
from torch.utils.data import DataLoader
from transformers import AutoProcessor, Qwen3VLMoeForConditionalGeneration
from auto_round import AutoRound, AutoScheme

model_id = "Qwen/Qwen3-VL-30B-A3B-Instruct"

model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)
model.eval()

processor = AutoProcessor.from_pretrained(model_id)
tokenizer = processor.tokenizer  # pass tokenizer to AutoRound for VLM

# JsonlImagePromptDataset / Qwen3VLCollator are my own helpers (definitions
# omitted): the dataset yields (image, prompt) samples, the collator builds
# model inputs via the processor.
ds = JsonlImagePromptDataset("/path/to/valid.jsonl")
collate = Qwen3VLCollator(processor=processor, image_dir="/path/to/images")

def collate_to_device(batch):
    inputs = collate(batch)
    inputs.pop("token_type_ids", None)
    return inputs.to(model.device)

loader = DataLoader(ds, batch_size=1, shuffle=True, num_workers=0, collate_fn=collate_to_device)

scheme = AutoScheme(
    avg_bits=2.8,
    options=("W2A16", "W4A16"),
    ignore_scale_zp_bits=True,
    batch_size=1,
)

ar = AutoRound(
    model=model,
    tokenizer=tokenizer,
    scheme=scheme,
    dataset=loader,
    nsamples=200,
    seqlen=3000,
    iters=200,
    batch_size=1,
    low_gpu_mem_usage=False,
    device_map=0,
    quant_nontext_module=False,  # I don't want to quantize vision modules
)

ar.quantize_and_save(output_dir="./out", format="auto_round")

What happens

AutoRound detects VLM mode and then prints:

  • INFO ... using MLLM mode for multimodal model.
  • INFO ... AutoScheme is not yet supported for multimodal LLMs.

So AutoScheme cannot be used in this setup.

Expected behavior

I expected AutoScheme to be usable with multimodal models, at least in this sense:

  • allowing multimodal (image+prompt) calibration/tuning runs, and selecting mixed W2/W4 schemes for the LLM part based on the actual multimodal forward path.

Why it matters

My use case is OCR with a VLM (documents/screenshots). Visual tokens are crucial here: the LLM's reasoning and decoding quality depend on the visual evidence being injected correctly into the language backbone.

Importantly: I do NOT plan to quantize the visual modules (vision encoder / projector). I want to keep them in higher precision (that’s why quant_nontext_module=False).
However, even if the vision side is kept unquantized, it still changes which parts of the LLM are active and how they are used: the presence of visual tokens affects attention patterns, MLP activations, and in MoE models potentially expert routing / activation distribution inside the LLM.
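To illustrate the routing point with a toy example (not AutoRound code, and hypothetical logit values): a MoE router's top-k choice depends on the input token, so text-only calibration can exercise a different set of experts than image-conditioned tokens would.

```python
# Toy top-2 MoE routing: the selected experts depend entirely on the
# per-token router logits, so a different input distribution (e.g. visual
# tokens) can activate different experts inside the LLM.
def top2_experts(logits):
    """Return indices of the two largest router logits (top-2 routing)."""
    return sorted(range(len(logits)), key=lambda i: -logits[i])[:2]

# Hypothetical router logits for a text-only token vs a visually
# conditioned token (values made up for illustration).
text_token_logits = [2.0, 0.1, -1.0, 0.3]
visual_token_logits = [-0.5, 1.5, 2.2, 0.0]

print(top2_experts(text_token_logits))    # -> [0, 3]
print(top2_experts(visual_token_logits))  # -> [2, 1]  (different experts)
```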

So for OCR workloads, it’s not enough to run AutoScheme on text-only calibration prompts. To pick good mixed-bit allocations (e.g. W2/W4 for avg_bits=2.8), AutoScheme needs to work in multimodal mode, using image+prompt samples, because that is the forward path that determines which LLM weights matter most under visual conditioning.


Questions

Is AutoScheme support for multimodal (VLM/MLLM) models on your roadmap?
If yes, could you share an estimated timeline / milestone (or what is blocking it)?

Thanks!
