Quantizing Diffusion Models with Quark#
This guide walks through end-to-end quantization of diffusion model submodules (UNet, transformer, etc.) using AMD Quark. The examples below have been validated on SDXL, SD3, and Flux.1-dev.
For most diffusion pipelines, the bulk of the compute is in a single
heavy submodule (pipe.unet for UNet-based pipelines like SDXL, or
pipe.transformer for transformer-based pipelines like SD3 and
Flux), so quantization targets that submodule directly. Quark’s
get_calib_dataloader (in quark.torch.utils.diffusers) hooks
into the pipeline run, captures the submodule’s inputs as a
dict keyed by the forward parameter names, and packages them into a
DataLoader that ModelQuantizer.quantize_model consumes via
model(**data) automatically.
Prerequisites#
pip install diffusers transformers accelerate
Pattern#
Every diffusion model quantization follows the same two steps:
from quark.torch import ModelQuantizer
from quark.torch.utils.diffusers import get_calib_dataloader

# Step 1: collect calibration data by running the pipeline.
# Additional pipe(...) kwargs (for example guidance_scale) can follow n_steps.
dataloader = get_calib_dataloader(pipe, pipe.unet, prompts, n_steps=20)

# Step 2: quantize (same as LLMs -- ModelQuantizer + QConfig).
pipe.unet = ModelQuantizer(qconfig).quantize_model(pipe.unet, dataloader)
get_calib_dataloader runs the pipeline, captures the target
submodule’s inputs as a dict keyed by the forward parameter names,
and returns a DataLoader. ModelQuantizer.quantize_model
consumes those dicts via model(**data) automatically – no wrapping
is required.
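To make the "dict keyed by the forward parameter names" idea concrete, here is a minimal pure-Python sketch of the capture-and-replay pattern. It is not Quark's implementation: capture_calls and the toy forward function are hypothetical names used only for illustration, and a real submodule would be a torch.nn.Module rather than a plain function.

```python
import inspect


def capture_calls(fn, store):
    """Wrap fn so every call is recorded as a dict keyed by parameter name."""
    signature = inspect.signature(fn)

    def wrapper(*args, **kwargs):
        # Bind positional and keyword arguments to the declared parameter names.
        bound = signature.bind(*args, **kwargs)
        store.append(dict(bound.arguments))
        return fn(*args, **kwargs)

    return wrapper


# Toy stand-in for a submodule's forward method (hypothetical signature).
def forward(sample, timestep, encoder_hidden_states=None):
    return sample * timestep


captured = []
wrapped = capture_calls(forward, captured)

# The "pipeline" calls the submodule with a mix of positional and keyword args.
wrapped(2.0, 10, encoder_hidden_states="text-embedding")

# Each captured dict can be replayed with forward(**data), no wrapping needed.
replayed = forward(**captured[0])
print(captured[0], replayed)
```

Because every entry is keyed by parameter name, the replay step does not need to know how the pipeline originally ordered its arguments.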
Example 1: SDXL – INT8 weight-only#
Observer-based INT8 quantization on the SDXL UNet.
import torch
from diffusers import DiffusionPipeline
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import Int8PerTensorSpec, QConfig, QLayerConfig
from quark.torch.utils.diffusers import get_calib_dataloader

# Load the SDXL pipeline in fp16.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16, variant="fp16",
)
pipe.to("cuda")

prompts = [
    "A serene lake reflecting mountains at sunset",
    "A futuristic city with flying cars at night",
    "A close-up portrait with dramatic lighting",
]

# Run the pipeline on the prompts and capture the UNet inputs.
dataloader = get_calib_dataloader(pipe, pipe.unet, prompts, n_steps=20, guidance_scale=8.0)

# Static, symmetric, per-tensor INT8 applied to weights only.
weight_spec = Int8PerTensorSpec(
    observer_method="min_max", symmetric=True, scale_type="float",
    round_method="half_even", is_dynamic=False,
).to_quantization_spec()

qconfig = QConfig(global_quant_config=QLayerConfig(weight=weight_spec))
pipe.unet = ModelQuantizer(qconfig).quantize_model(pipe.unet, dataloader)

# Generate with the quantized UNet.
image = pipe("A cat on a windowsill", num_inference_steps=30, guidance_scale=8.0).images[0]
image.save("sdxl_int8.png")
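The spec settings above (min_max observer, symmetric, float scale, half_even rounding) correspond to a simple piece of arithmetic. The following pure-Python sketch shows that arithmetic on a flat list of weights. It is an illustration under stated assumptions, not Quark's kernel: quantize_int8_symmetric and dequantize are hypothetical helper names, and the clamp range of [-128, 127] is an assumption (some implementations clamp to [-127, 127]).

```python
def quantize_int8_symmetric(values):
    """Quantize floats to INT8 codes with one shared (per-tensor) scale."""
    # Symmetric min/max observer: the range is set by the largest magnitude.
    amax = max(abs(v) for v in values)
    scale = amax / 127.0 if amax > 0 else 1.0
    # Python's round() rounds halves to the nearest even integer (half_even).
    codes = [max(-128, min(127, round(v / scale))) for v in values]
    return codes, scale


def dequantize(codes, scale):
    """Map INT8 codes back to approximate float values."""
    return [c * scale for c in codes]


weights = [0.4, -1.0, 0.1, 0.75]
codes, scale = quantize_int8_symmetric(weights)
restored = dequantize(codes, scale)

# For in-range values the round-trip error is at most half a quantization step.
max_error = max(abs(w - r) for w, r in zip(weights, restored))
print(codes, scale, max_error)
```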
Example 2: SD3 – SVDQuant w4a4#
SVDQuant decomposes weights via SVD, adds a low-rank correction branch,
and smooths activations. Here both weights and activations are
quantized to INT4 (w4a4).
import torch
from diffusers import StableDiffusion3Pipeline
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import QConfig, SVDQuantConfig
from quark.torch.algorithm.svdquant import build_quant_layer_config
from quark.torch.utils.diffusers import get_calib_dataloader

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

prompts = [
    "A serene lake reflecting mountains at sunset",
    "A futuristic city with flying cars at night",
    "A close-up portrait with dramatic lighting",
    "A golden retriever playing in autumn leaves",
    "An astronaut floating above Earth",
]

# Capture the transformer inputs over 20 denoising steps per prompt.
dataloader = get_calib_dataloader(pipe, pipe.transformer, prompts, n_steps=20)

qconfig = QConfig(
    global_quant_config=build_quant_layer_config("w4a4"),
    # Layers skipped by quantization; *correction* protects the low-rank branch.
    exclude=[
        "*time_text_embed*", "*context_embedder*", "*pos_embed*",
        "*norm_out*", "*proj_out*", "*correction*",
    ],
    algo_config=[SVDQuantConfig(
        svd_rank=32,
        search_alpha=False,
        # Layers skipped by the SVD decomposition.
        exclude_patterns=[
            "*time_text_embed*", "*context_embedder*",
            "*pos_embed*", "*norm_out*", "*proj_out*",
        ],
    )],
)
pipe.transformer = ModelQuantizer(qconfig).quantize_model(pipe.transformer, dataloader)

image = pipe("A cat on a windowsill", num_inference_steps=30).images[0]
image.save("sd3_svdquant_w4a4.png")
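The description of SVDQuant above (SVD decomposition plus a low-rank correction branch) can be illustrated on a single weight matrix. The sketch below is a simplified illustration in plain PyTorch, not Quark's implementation: svd_lowrank_split and fake_quant_int4_per_group are hypothetical helpers, the group size of 64 and rank of 8 are arbitrary choices, and activation smoothing is omitted entirely. It keeps the top singular components in full precision and quantizes only the residual.

```python
import torch


def fake_quant_int4_per_group(w, group_size=64):
    """Symmetric INT4 quantize-dequantize with one scale per group of inputs."""
    out_features, in_features = w.shape
    groups = w.reshape(out_features, in_features // group_size, group_size)
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    codes = torch.clamp(torch.round(groups / scale), -8, 7)
    return (codes * scale).reshape(out_features, in_features)


def svd_lowrank_split(w, rank):
    """Split w into a rank-`rank` product l1 @ l2 and a residual."""
    u, s, vh = torch.linalg.svd(w, full_matrices=False)
    l1 = u[:, :rank] * s[:rank]   # (out_features, rank)
    l2 = vh[:rank, :]             # (rank, in_features)
    residual = w - l1 @ l2
    return l1, l2, residual


torch.manual_seed(0)
w = torch.randn(128, 256)
w[:, 0] *= 20.0  # inject an outlier column that hurts low-bit quantization

l1, l2, residual = svd_lowrank_split(w, rank=8)

# Baseline: quantize the full weight directly.
err_plain = (w - fake_quant_int4_per_group(w)).norm().item()
# Low-rank branch stays in full precision; only the residual is quantized.
err_svd = (w - (l1 @ l2 + fake_quant_int4_per_group(residual))).norm().item()
print(f"plain INT4 error: {err_plain:.2f}, low-rank + INT4 error: {err_svd:.2f}")
```

In this toy setup the low-rank branch absorbs the outlier column, so the residual has a much smaller dynamic range and quantizes with less error. The full-precision branch plays the same role as the branch that the *correction* exclude pattern protects from quantization.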
Example 3: SDXL – INT8 weight + activation (w8a8)#
Static INT8 quantization of both weights and activations. A min/max observer collects activation ranges during calibration.
import torch
from diffusers import DiffusionPipeline
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import Int8PerTensorSpec, QConfig, QLayerConfig
from quark.torch.utils.diffusers import get_calib_dataloader

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16, variant="fp16",
)
pipe.to("cuda")

prompts = [
    "A serene lake reflecting mountains at sunset",
    "A futuristic city with flying cars at night",
    "A close-up portrait with dramatic lighting",
    "A golden retriever playing in autumn leaves",
    "An astronaut floating above Earth",
]

# Calibration data drives the activation ranges, so it matters more here
# than in the weight-only example.
dataloader = get_calib_dataloader(pipe, pipe.unet, prompts, n_steps=20, guidance_scale=8.0)

# The same static per-tensor INT8 spec is used for weights and input activations.
int8_spec = Int8PerTensorSpec(
    observer_method="min_max", symmetric=True, scale_type="float",
    round_method="half_even", is_dynamic=False,
).to_quantization_spec()

qconfig = QConfig(global_quant_config=QLayerConfig(weight=int8_spec, input_tensors=int8_spec))
pipe.unet = ModelQuantizer(qconfig).quantize_model(pipe.unet, dataloader)

image = pipe("A cat on a windowsill", num_inference_steps=30, guidance_scale=8.0).images[0]
image.save("sdxl_int8_w8a8.png")
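For static activation quantization, the observer only watches values during calibration, and the resulting scale is then frozen. The pure-Python sketch below illustrates that idea with a running min/max observer. MinMaxObserver and quantize are hypothetical names written for this illustration and do not mirror Quark's observer classes. The calibration values are made up.

```python
class MinMaxObserver:
    """Track the running minimum and maximum over all calibration batches."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def update(self, batch):
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def scale(self, qmax=127):
        # Symmetric quantization: use the largest magnitude seen so far.
        amax = max(abs(self.min_val), abs(self.max_val))
        return amax / qmax


def quantize(values, scale, qmin=-128, qmax=127):
    """Quantize with a fixed scale; values outside the range saturate."""
    return [max(qmin, min(qmax, round(v / scale))) for v in values]


observer = MinMaxObserver()
calibration_batches = [[-0.5, 0.2, 1.0], [0.1, -2.54, 0.3], [0.9, 0.0, -0.7]]
for batch in calibration_batches:
    observer.update(batch)

# The scale is computed once after calibration and reused at inference time.
static_scale = observer.scale()
codes = quantize([1.0, -2.54, 5.0], static_scale)
print(static_scale, codes)
```

The last input (5.0) lies outside the calibrated range and saturates at the maximum code, which is why calibration prompts should cover the activations expected at inference time.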
Example 4: Flux.1-dev – SVDQuant w4a16#
Flux uses bfloat16 and a different transformer architecture. Note the
Flux-specific pipe_kwargs and exclude patterns.
import torch
from diffusers import FluxPipeline
from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import QConfig, SVDQuantConfig
from quark.torch.algorithm.svdquant import build_quant_layer_config
from quark.torch.utils.diffusers import get_calib_dataloader

# device_map="balanced" places the pipeline components across available devices.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    device_map="balanced",
)

prompts = [
    "A serene lake reflecting mountains at sunset",
    "A futuristic city with flying cars at night",
    "A close-up portrait with dramatic lighting",
    "A golden retriever playing in autumn leaves",
    "An astronaut floating above Earth",
]

# Flux-specific pipeline kwargs are passed alongside n_steps.
dataloader = get_calib_dataloader(
    pipe, pipe.transformer, prompts, n_steps=20,
    height=1024, width=1024, guidance_scale=3.5, max_sequence_length=512,
)

qconfig = QConfig(
    global_quant_config=build_quant_layer_config("w4a16"),
    # Layers skipped by quantization; *correction* protects the low-rank branch.
    exclude=[
        "*x_embedder*", "*context_embedder*", "*time_text_embed*",
        "*norm_out*", "*proj_out*", "*correction*",
    ],
    algo_config=[SVDQuantConfig(
        svd_rank=32,
        search_alpha=False,
        # Layers skipped by the SVD decomposition.
        exclude_patterns=[
            "*x_embedder*", "*context_embedder*", "*time_text_embed*",
            "*norm_out*", "*proj_out*", "*norm1.linear*", "*norm1_context.linear*",
        ],
    )],
)
pipe.transformer = ModelQuantizer(qconfig).quantize_model(pipe.transformer, dataloader)

image = pipe(
    "A cat on a windowsill", num_inference_steps=50,
    height=1024, width=1024, guidance_scale=3.5, max_sequence_length=512,
).images[0]
image.save("flux_svdquant_w4a16.png")
Quick reference#
Which submodule to quantize#
| Pipeline | Target submodule |
|---|---|
| SDXL, SD 1.5, SD 2.1 | pipe.unet |
| Flux, SD3, PixArt | pipe.transformer |
Pipeline-specific kwargs for get_calib_dataloader#
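| Pipeline | Recommended kwargs |
|---|---|
| SDXL | guidance_scale=8.0 (as in Examples 1 and 3) |
| SD3 | (defaults) |
| Flux | height=1024, width=1024, guidance_scale=3.5, max_sequence_length=512 (as in Example 4) |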
Any kwargs accepted by pipe(...) can be passed – the utilities
are pipeline-agnostic.
Exclude patterns by model#
These patterns control which layers are skipped during SVD
decomposition and quantization. *correction* must always be
excluded from quantization to protect the SVD low-rank correction
branch.
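| Model | SVDQuant exclude | Quantization exclude |
|---|---|---|
| SDXL | | same as SVDQuant exclude + *correction* |
| SD3 | *time_text_embed*, *context_embedder*, *pos_embed*, *norm_out*, *proj_out* (as in Example 2) | same as SVDQuant exclude + *correction* |
| Flux | *x_embedder*, *context_embedder*, *time_text_embed*, *norm_out*, *proj_out*, *norm1.linear*, *norm1_context.linear* (as in Example 4) | same as SVDQuant exclude + *correction* |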
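The exclude patterns are glob-style strings matched against module names. The sketch below shows how such patterns select layers, assuming shell-style wildcard semantics as provided by Python's fnmatch module. Whether Quark uses fnmatch internally is an assumption, and the module names are illustrative only (the ".correction" suffix in particular is a hypothetical name for the low-rank branch).

```python
from fnmatch import fnmatchcase


def is_excluded(module_name, patterns):
    """Return True if the module name matches any glob-style exclude pattern."""
    return any(fnmatchcase(module_name, pattern) for pattern in patterns)


# Quantization exclude list from the Flux example.
quant_exclude = [
    "*x_embedder*", "*context_embedder*", "*time_text_embed*",
    "*norm_out*", "*proj_out*", "*correction*",
]

# Illustrative module names; real names depend on the loaded model.
module_names = [
    "x_embedder",
    "transformer_blocks.0.attn.to_q",
    "transformer_blocks.0.attn.to_q.correction",
    "proj_out",
]

quantized = [name for name in module_names if not is_excluded(name, quant_exclude)]
print(quantized)
```

Because each pattern is wrapped in asterisks, it matches anywhere in the dotted module path, so a short pattern can exclude more layers than intended. Listing the matched names before quantizing is a cheap sanity check.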
Calibration data sizing#
len(prompts) * n_steps = total calibration samples.
| Use case | Prompts | Steps | Samples |
|---|---|---|---|
| Quick test | 3 | 10 | 30 |
| Standard | 5–10 | 20 | 100–200 |
| Production | 15+ | 20 | 300+ |
Captured tensors are detached and stored on CPU (one copy per
submodule call). Memory cost is roughly proportional to
len(prompts) * n_steps times the size of one submodule input.
Measure if calibrating with large prompt sets on memory-constrained
hosts.
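A rough estimate of the calibration footprint follows directly from the formula above. The sketch below is back-of-the-envelope arithmetic only: the tensor shapes are hypothetical placeholders (not measured from any pipeline), tensor_bytes and calibration_footprint are illustrative helpers, and 2 bytes per element assumes fp16 or bf16 storage.

```python
from math import prod


def tensor_bytes(shape, bytes_per_element=2):
    """Size of one tensor in bytes (2 bytes per element for fp16/bf16)."""
    return prod(shape) * bytes_per_element


def calibration_footprint(n_prompts, n_steps, input_shapes):
    """Return (sample count, total bytes) for the captured submodule inputs."""
    samples = n_prompts * n_steps
    bytes_per_sample = sum(tensor_bytes(shape) for shape in input_shapes)
    return samples, samples * bytes_per_sample


# Hypothetical per-call inputs: a latent batch and a text-embedding batch.
hypothetical_shapes = [(2, 4, 128, 128), (2, 77, 2048)]

samples, total_bytes = calibration_footprint(5, 20, hypothetical_shapes)
print(samples, "samples,", round(total_bytes / 2**20, 1), "MiB")
```

Substituting the shapes actually captured for a given pipeline gives a quick check on whether a larger prompt set will fit in host memory.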
Using COCO2014 calibration prompts#
For production-quality calibration, you can use prompts from the COCO2014 dataset instead of hand-written prompts.
Setup#
Requires torchvision compatible with your PyTorch version.
export DIFFUSERS_ROOT=$PWD
git clone https://github.com/mlcommons/inference.git
cd inference
git checkout 87ba8cb8a6a4f6525f26255fa513d902b17ab060
cd ./text_to_image/tools/
sh ./download-coco-2014.sh --num-workers 5
sh ./download-coco-2014-calibration.sh -n 5
cd ${DIFFUSERS_ROOT}
export PYTHONPATH="${DIFFUSERS_ROOT}/inference/text_to_image/:$PYTHONPATH"
Dataset files#
Calibration captions:
${DIFFUSERS_ROOT}/inference/text_to_image/coco2014/calibration/captions.tsv
Test captions:
${DIFFUSERS_ROOT}/inference/text_to_image/coco2014/captions/captions_source.tsv
Usage#
def load_coco2014_prompts(coco_dir, max_prompts=None):
    """Read captions (third tab-separated column) from captions_source.tsv."""
    tsv_path = f"{coco_dir}/captions/captions_source.tsv"
    prompts = []
    with open(tsv_path, encoding="utf-8") as f:
        for line in f.readlines()[1:]:  # skip header
            cols = line.split("\t")
            if len(cols) >= 3:
                prompts.append(cols[2].strip())
    return prompts[:max_prompts] if max_prompts else prompts


prompts = load_coco2014_prompts("./inference/text_to_image/coco2014", max_prompts=50)

# Use with get_calib_dataloader as usual.
dataloader = get_calib_dataloader(pipe, pipe.unet, prompts, n_steps=20, guidance_scale=8.0)
Native inference#
Quark can convert quantized Linear layers in Diffusers transformer or UNet modules to native inference kernels during ModelQuantizer.freeze. Use RuntimeOptions to select the native linear backend after quantization and before running generation.
from quark.torch.quantization.api import ModelQuantizer
from quark.torch.quantization.utils import RuntimeOptions

runtime_options = RuntimeOptions(native_linear_mode="fp8_per_tensor")
pipe.transformer = ModelQuantizer.freeze(
    pipe.transformer,
    runtime_options=runtime_options,
)
For MXFP4 quantization, set native_linear_mode="mxfp4". The examples/torch/diffusers/benchmark_flux_fp8_compile.py script provides a complete FLUX FP8 native inference benchmark, including optional torch.compile.
python benchmark_flux_fp8_compile.py --mode eager
Quantization modes#
| Method | Weights | Activations | Observer | Config helper |
|---|---|---|---|---|
| INT8 weight-only | INT8 per-tensor | fp16/bf16 | selected via observer_method | manual |
| INT8 w8a8 | INT8 per-tensor | INT8 per-tensor static | same as above | manual |
| SVDQuant w4a16 | INT4 per-group | fp16/bf16 | per-group min/max (fixed) | build_quant_layer_config("w4a16") |
| SVDQuant w4a4 | INT4 per-group | INT4 per-group dynamic | per-group min/max (fixed) | build_quant_layer_config("w4a4") |
| MXFP4 | MXFP4 | fp16/bf16 | n/a | |
| NVFP4 | FP4 block-16 | FP4 block-16 dynamic | n/a | |
For observer-based methods (INT8), the observer determines how
quantization scales are computed from calibration data. Pass the
observer name via observer_method in Int8PerTensorSpec:
Int8PerTensorSpec(observer_method="percentile", ...)
# other valid values: "min_max", "histogram", "MSE", "histogrampro"
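The choice of observer matters most when the calibration data contains outliers. The pure-Python sketch below contrasts a min/max range with a percentile-clipped range on made-up activation values. It is a conceptual illustration only: amax_min_max and amax_percentile are hypothetical helpers, the nearest-rank percentile definition is an assumption, and Quark's percentile observer may compute its range differently.

```python
import math


def amax_min_max(values):
    """Range used by a symmetric min/max observer: the largest magnitude."""
    return max(abs(v) for v in values)


def amax_percentile(values, percentile=99.0):
    """Range that clips the largest magnitudes (nearest-rank percentile)."""
    ordered = sorted(abs(v) for v in values)
    rank = max(1, math.ceil(percentile * len(ordered) / 100.0))
    return ordered[rank - 1]


# 99 well-behaved activations plus one large outlier.
activations = [i / 100 for i in range(1, 100)] + [50.0]

scale_min_max = amax_min_max(activations) / 127.0
scale_percentile = amax_percentile(activations) / 127.0
print(scale_min_max, scale_percentile)
```

With min/max, the single outlier sets the scale and most values collapse onto a few INT8 codes. The percentile range gives the bulk of the values finer resolution at the cost of saturating the outlier.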
Note
Quark’s SmoothQuant algorithm (SmoothQuantConfig) currently
requires LLM-specific layer structure
(model_decoder_layers, scaling_layers) and is not yet
supported for diffusion models.