Quark for PyTorch#
Quantizing a floating-point model with Quark for PyTorch involves the following high-level steps:
1. Load the original floating-point model.
2. Set the quantization configuration.
3. Define the data loader.
4. Use the Quark API to perform an in-place replacement of the model's modules with quantized modules.
5. (Optional) Export the quantized model to other formats, such as ONNX.
Supported Features#
Quark for PyTorch supports the following key features:
| Feature Name | Feature Value |
|---|---|
| Data Type | Float16 / Bfloat16 / Int4 / Uint4 / Int8 / OCP_FP8_E4M3 / OCP_MXFP8_E4M3 / OCP_MXFP6 / OCP_MXFP4 / OCP_MXINT8 |
| Quant Mode | Eager Mode / FX Graph Mode |
| Quant Strategy | Static quant / Dynamic quant / Weight-only quant |
| Quant Scheme | Per tensor / Per channel / Per group |
| Symmetric | Symmetric / Asymmetric |
| Calibration Method | MinMax / Percentile / MSE |
| Scale Type | Float32 / Float16 |
| KV-Cache Quant | FP8 KV-Cache Quant |
| In-Place Replace OP | nn.Linear / nn.Conv2d / nn.ConvTranspose2d / nn.Embedding / nn.EmbeddingBag |
| Pre-Quant Optimization | SmoothQuant |
| Quant Algorithm | AWQ / GPTQ |
| Export Format | ONNX / Json-Safetensors / GGUF (Q4_1) |
| Operating Systems | Linux (ROCm/CUDA) / Windows (CPU) |
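Most of these options map directly onto the configuration objects used in the Basic Example below. As a minimal sketch reusing only the classes and arguments that appear in that example, an asymmetric int8 per-tensor weight-only configuration differs from the symmetric one solely in the symmetric flag (the variable names here are illustrative):

from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver

# Same fields as the symmetric spec in the Basic Example; only `symmetric` changes.
ASYM_INT8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.int8,
                                             qscheme=QSchemeType.per_tensor,
                                             observer_cls=PerTensorMinMaxObserver,
                                             symmetric=False,  # asymmetric quantization
                                             scale_type=ScaleType.float,
                                             round_method=RoundType.half_even,
                                             is_dynamic=False)
asym_quant_config = Config(global_quant_config=QuantizationConfig(weight=ASYM_INT8_PER_TENSOR_SPEC))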
Basic Example#
This example shows a basic use case: quantizing the opt-125m model with the int8 data type for symmetric, per-tensor, weight-only quantization.
# 1. Set model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
# 2. Set quantization configuration
from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
from quark.torch.quantization.config.config import Config, QuantizationSpec, QuantizationConfig
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver
DEFAULT_INT8_PER_TENSOR_SYM_SPEC = QuantizationSpec(dtype=Dtype.int8,
                                                    qscheme=QSchemeType.per_tensor,
                                                    observer_cls=PerTensorMinMaxObserver,
                                                    symmetric=True,
                                                    scale_type=ScaleType.float,
                                                    round_method=RoundType.half_even,
                                                    is_dynamic=False)
DEFAULT_W_INT8_PER_TENSOR_CONFIG = QuantizationConfig(weight=DEFAULT_INT8_PER_TENSOR_SYM_SPEC)
quant_config = Config(global_quant_config=DEFAULT_W_INT8_PER_TENSOR_CONFIG)
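# Only the `weight` field is set above, so activations stay in floating point (weight-only quantization).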
# 3. Define calibration dataloader (this step is still needed for weight-only and dynamic quantization)
from torch.utils.data import DataLoader
text = "Hello, how are you?"
tokenized_outputs = tokenizer(text, return_tensors="pt")
calib_dataloader = DataLoader(tokenized_outputs['input_ids'])
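# A single short prompt is enough for this weight-only example; static activation quantization
# typically needs a larger set of representative calibration samples.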
# 4. In-place replacement of the model's modules with quantized modules
from quark.torch import ModelQuantizer
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
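# Supported modules in the model (e.g. nn.Linear) have now been replaced in place by their quantized counterparts.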
# # 5. (Optional) Export to ONNX
# # To export the quantized model, freeze it first
# freezed_quantized_model = quantizer.freeze(quant_model)
# from quark.torch import ModelExporter
# # Get a dummy input from the calibration dataloader
# for data in calib_dataloader:
#     input_args = data
#     break
# quant_model = quant_model.to('cuda')
# input_args = input_args.to('cuda')
# exporter = ModelExporter('export_path')
# exporter.export_onnx_model(quant_model, input_args)
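If the export step is uncommented and run, the quantized model is written as ONNX under export_path. As a quick sanity check, the exported graph can be validated with the onnx package; this is a minimal sketch, and the output file name is an assumption, so substitute the file the exporter actually writes:

# import onnx
# onnx_model = onnx.load('export_path/model.onnx')  # file name is an assumption
# onnx.checker.check_model(onnx_model)               # raises if the exported graph is malformed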
If the code runs successfully, the terminal displays [QUARK-INFO]: Model quantization has been completed.
For more detailed information, see the section on Advanced Quark Features for PyTorch.