Quark for PyTorch - Feature Description#

Quark for PyTorch provides the following key features:

Feature Name            Feature Value
----------------------  -----------------------------------------------------------
Data Type               Float16 / Bfloat16 / Int4 / Uint4 / Int8 / OCP_FP8_E4M3 / OCP_MXFP8_E4M3 / OCP_MXFP6 / OCP_MXFP4 / OCP_MXINT8
Quant Mode              Eager Mode / FX Graph Mode
Quant Strategy          Static quant / Dynamic quant / Weight-only quant
Quant Scheme            Per tensor / Per channel / Per group
Symmetry                Symmetric / Asymmetric
Calibration Method      MinMax / Percentile / MSE
Scale Type              Float32 / Float16
KV-Cache Quant          FP8 KV-Cache Quant
In-Place Replace OP     nn.Linear / nn.Conv2d
Pre-Quant Optimization  SmoothQuant
Quant Algorithm         AWQ / GPTQ
Export Format           ONNX / Json-Safetensors / GGUF (Q4_1)
Operating Systems       Linux (ROCm/CUDA) / Windows (CPU)

Detailed explanations of these features are presented below.

Quant Strategy#

Quark for PyTorch offers three distinct quantization strategies, tailored to the requirements of various hardware backends:

  • Post Training Weight-Only Quantization: The weights are quantized ahead of time, but the activations are not quantized during inference (they keep their original float data type).

  • Post Training Dynamic Quantization: The weights are quantized ahead of time but the activations are dynamically quantized during inference.

  • Post Training Static Quantization: Post Training Static Quantization quantizes both the weights and activations in the model. To achieve the best results, this process necessitates calibration with a dataset that accurately represents the actual data, which allows for precise determination of the optimal quantization parameters for activations.

Here is a sample showing how the different quant strategies are configured:

# 1. Set model
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# 2. Set quantization configuration
from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
from quark.torch.quantization.config.config import Config, QuantizationSpec, QuantizationConfig
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver, PerChannelMinMaxObserver

# 2-1. Weight-only quantization (used by default in this example; comment out these lines if using 2-2 or 2-3 instead).
DEFAULT_UINT4_PER_GROUP_ASYM_SPEC = QuantizationSpec(dtype=Dtype.uint4,
                                                     observer_cls=PerChannelMinMaxObserver,
                                                     symmetric=False,
                                                     scale_type=ScaleType.float,
                                                     round_method=RoundType.half_even,
                                                     qscheme=QSchemeType.per_group,
                                                     ch_axis=0,
                                                     is_dynamic=False,
                                                     group_size=32)
DEFAULT_W_UINT4_PER_GROUP_CONFIG = QuantizationConfig(weight=DEFAULT_UINT4_PER_GROUP_ASYM_SPEC)
quant_config = Config(global_quant_config=DEFAULT_W_UINT4_PER_GROUP_CONFIG)

# 2-2. For dynamic quantization, please uncomment the following lines.
# INT8_PER_TENSOR_DYNAMIC_SPEC = QuantizationSpec(dtype=Dtype.int8,
#                                                 qscheme=QSchemeType.per_tensor,
#                                                 observer_cls=PerTensorMinMaxObserver,
#                                                 symmetric=True,
#                                                 scale_type=ScaleType.float,
#                                                 round_method=RoundType.half_even,
#                                                 is_dynamic=True)
# DEFAULT_W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG = QuantizationConfig(input_tensors=INT8_PER_TENSOR_DYNAMIC_SPEC,
#                                                                      weight=INT8_PER_TENSOR_DYNAMIC_SPEC)
# quant_config = Config(global_quant_config=DEFAULT_W_INT8_A_INT8_PER_TENSOR_DYNAMIC_CONFIG)

# 2-3. For static quantization, please uncomment the following lines.
# FP8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.fp8_e4m3,
#                                        qscheme=QSchemeType.per_tensor,
#                                        observer_cls=PerTensorMinMaxObserver,
#                                        is_dynamic=False)
# DEFAULT_W_FP8_A_FP8_PER_TENSOR_CONFIG = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
#                                                            weight=FP8_PER_TENSOR_SPEC)
# quant_config = Config(global_quant_config=DEFAULT_W_FP8_A_FP8_PER_TENSOR_CONFIG)

# 3. Define the calibration dataloader (this step is still required for weight-only and dynamic quantization)
from torch.utils.data import DataLoader
text = "Hello, how are you?"
tokenized_outputs = tokenizer(text, return_tensors="pt")
calib_dataloader = DataLoader(tokenized_outputs['input_ids'])

# 4. In-place replacement of the model's modules with quantized modules
from quark.torch import ModelQuantizer
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)

All three strategies share the same user API: users simply select the strategy through the quantization configuration, as demonstrated in the example above. More details on setting the quantization configuration are given in the chapter “Configuring Quark for PyTorch”.

Quant Schemes#

Quark for PyTorch supports per-tensor, per-channel, and per-group quantization, with both symmetric and asymmetric methods.

  • Per Tensor Quantization means that the entire tensor is quantized with a single set of quantization parameters; the scaling factor is a scalar.

  • Per Channel Quantization means that the values along one dimension of the tensor, typically the channel dimension, are quantized with their own quantization parameters. The scaling factor is a 1-D tensor whose length equals the size of the quantization axis: for an input tensor of shape (D0, ..., Di, ..., Dn) with ch_axis=i, the scaling factor is a 1-D tensor of length Di.

  • Per Group Quantization means that the tensor is divided into smaller groups that are quantized independently. The scaling factor has the same number of dimensions as the input tensor: for an input tensor of shape (D0, ..., Di, ..., Dn) with ch_axis=i and group_size=m, the scaling factor has shape (D0, ..., Di/m, ..., Dn). The resulting scale shapes for all three schemes are illustrated in the sketch after this list.
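
To make the scale shapes concrete, here is a minimal sketch in plain PyTorch (independent of the Quark API) that computes symmetric scales for one weight tensor under the three schemes; the tensor shape, quantization axis, and group size are illustrative choices.

import torch

# Illustrative weight tensor: shape (64, 32), quantization axis ch_axis=0.
w = torch.randn(64, 32)
qmax = 7  # e.g. the positive bound of a symmetric int4 range

# Per tensor: a single scalar scale for the whole tensor.
scale_per_tensor = w.abs().max() / qmax             # shape: ()

# Per channel: one scale per slice along ch_axis=0.
scale_per_channel = w.abs().amax(dim=1) / qmax      # shape: (64,)

# Per group (group_size=16): the quantization axis is split into blocks of 16,
# each quantized independently, so D0 shrinks from 64 to 64/16 = 4.
groups = w.reshape(64 // 16, 16, 32)
scale_per_group = groups.abs().amax(dim=1) / qmax   # shape: (4, 32)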

Symmetric/Asymmetric Quantization#

Symmetric/asymmetric quantization is primarily used to describe the quantization of integers. Symmetric quantization scales the data by a fixed scaling factor, with the zero-point generally set to zero. Asymmetric quantization uses both a scaling factor and a zero-point that can shift, allowing the zero of the quantized data to represent a value other than zero in the original range.
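
As an illustration in plain PyTorch (not the Quark API), the following sketch maps the same tensor to int8 with both schemes:

import torch

x = torch.tensor([-0.5, 0.0, 1.5, 3.0])

# Symmetric int8: scale from the maximum absolute value, zero-point fixed at 0.
scale_sym = x.abs().max() / 127
q_sym = torch.clamp(torch.round(x / scale_sym), -128, 127)
x_hat_sym = q_sym * scale_sym

# Asymmetric int8: scale from the full range; the zero-point shifts so that the
# float minimum maps to the integer minimum (-128) and the maximum to 127.
scale_asym = (x.max() - x.min()) / 255
zero_point = torch.round(-x.min() / scale_asym) - 128
q_asym = torch.clamp(torch.round(x / scale_asym) + zero_point, -128, 127)
x_hat_asym = (q_asym - zero_point) * scale_asym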

Calibration Methods#

Quark for PyTorch supports the following calibration methods:

  • MinMax Calibration method: The MinMax calibration method computes the quantization parameters from running min/max statistics: the observer records the running minimum and maximum of incoming tensors and derives the scale (and zero-point) from them. A short sketch after this list shows how MinMax and Percentile ranges can be derived.

  • Percentile Calibration method: The Percentile calibration method, often used for robust scaling, sets the quantization range from percentile information in a static histogram of the observed values rather than from the absolute minimum and maximum. This makes it particularly useful for managing outliers in the data.

  • MSE Calibration method: The MSE (Mean Squared Error) calibration method selects the quantization parameters (for example, the clipping range) that minimize the mean squared error between the original tensor values and their quantized-then-dequantized counterparts. This keeps the quantized representation as close as possible to the real data points, which helps preserve model accuracy.
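
A minimal sketch in plain PyTorch (independent of Quark's observer classes) of how MinMax and Percentile calibration can derive an asymmetric int8 range; the 0.1% / 99.9% percentiles are illustrative choices:

import torch

# Pretend these are activation values gathered over a few calibration batches.
calib = torch.cat([torch.randn(1024) for _ in range(8)])
calib[0] = 40.0  # one outlier that would stretch a MinMax range

def asym_int8_params(lo, hi):
    # Derive an asymmetric int8 scale and zero-point from a chosen range [lo, hi].
    scale = (hi - lo) / 255
    zero_point = torch.round(-lo / scale) - 128
    return scale, zero_point

# MinMax: use the absolute minimum and maximum of the observed values.
scale_mm, zp_mm = asym_int8_params(calib.min(), calib.max())

# Percentile: clip the range to e.g. the 0.1% / 99.9% percentiles so that a
# single outlier does not dominate the scale.
lo, hi = torch.quantile(calib, 0.001), torch.quantile(calib, 0.999)
scale_pct, zp_pct = asym_int8_params(lo, hi)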

KV-Cache Quant#

Quark for PyTorch supports quantization of the KV cache in the attention layers of transformer models.
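
A hedged configuration sketch: FP8 KV-cache quantization is commonly expressed by quantizing the outputs of the attention key/value projection layers. The output_tensors field, the layer_quant_config argument, and the "*k_proj"/"*v_proj" name patterns below are assumptions for illustration; check the configuration chapter for the exact fields supported by your Quark version.

from quark.torch.quantization.config.type import Dtype, QSchemeType
from quark.torch.quantization.config.config import Config, QuantizationSpec, QuantizationConfig
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver

# FP8 per-tensor spec, reused for weights, activations, and the KV-cache outputs.
FP8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.fp8_e4m3,
                                       qscheme=QSchemeType.per_tensor,
                                       observer_cls=PerTensorMinMaxObserver,
                                       is_dynamic=False)

GLOBAL_FP8_CONFIG = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
                                       weight=FP8_PER_TENSOR_SPEC)

# Assumption: quantizing the k_proj/v_proj outputs realizes FP8 KV-cache quant;
# the field and pattern names may differ in your Quark version.
KV_CACHE_CONFIG = QuantizationConfig(input_tensors=FP8_PER_TENSOR_SPEC,
                                     weight=FP8_PER_TENSOR_SPEC,
                                     output_tensors=FP8_PER_TENSOR_SPEC)

quant_config = Config(global_quant_config=GLOBAL_FP8_CONFIG,
                      layer_quant_config={"*k_proj": KV_CACHE_CONFIG,
                                          "*v_proj": KV_CACHE_CONFIG})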

Pre-Quant Optimization#

Quark for PyTorch supports SmoothQuant as a pre-quant optimization.
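
Conceptually, SmoothQuant migrates quantization difficulty from activations to weights by rescaling each input channel before quantization. The sketch below illustrates that rescaling in plain PyTorch (not Quark's implementation); alpha=0.5 is the commonly used migration strength.

import torch

def smooth_linear(act_absmax: torch.Tensor, weight: torch.Tensor, alpha: float = 0.5):
    # Conceptual SmoothQuant rescaling for a Linear layer (not Quark's implementation).
    #   act_absmax: per-input-channel absolute maxima of the activations, shape (in_features,)
    #   weight:     Linear weight, shape (out_features, in_features)
    # The product is unchanged because (x / s) @ (w * s).T == x @ w.T, but the
    # rescaled activations have a smaller dynamic range and are easier to quantize.
    w_absmax = weight.abs().amax(dim=0)                    # (in_features,)
    s = act_absmax.pow(alpha) / w_absmax.pow(1.0 - alpha)  # per-channel smoothing factor
    smoothed_weight = weight * s                           # scale weight columns up by s
    # At inference time the division of the activation by s is usually folded
    # into the preceding normalization layer, so no extra op is introduced.
    return s, smoothed_weight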

Advanced Quant Algorithm#

Quark for PyTorch supports AWQ and GPTQ as advanced quantization algorithms.

  • AWQ: Quark for PyTorch re-implements the AWQ algorithm. For now, AWQ is supported only with the UINT4/INT4 data types and per-group or per-channel quantization, running on Linux in GPU mode. To use per-channel quantization, set qscheme to per_group with group_size=-1.

  • GPTQ: Quark for PyTorch re-implements the GPTQ algorithm. For now, GPTQ is supported only with the UINT4/INT4 data types and per-group or per-channel quantization, running on Linux in GPU mode. To use per-channel quantization, set qscheme to per_group with group_size=-1 (a configuration sketch follows this list).
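
For example, a UINT4 weight spec that requests per-channel behaviour for AWQ or GPTQ by setting qscheme to per_group with group_size=-1 might look as follows (reusing the types from the earlier example; how the AWQ/GPTQ algorithm itself is selected, for example through an algorithm configuration passed alongside Config, depends on your Quark version and is not shown here):

from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
from quark.torch.quantization.config.config import Config, QuantizationSpec, QuantizationConfig
from quark.torch.quantization.observer.observer import PerChannelMinMaxObserver

# UINT4 weight spec for AWQ/GPTQ: per-channel behaviour is requested by using
# qscheme=per_group with group_size=-1, as described above.
UINT4_PER_CHANNEL_VIA_GROUP_SPEC = QuantizationSpec(dtype=Dtype.uint4,
                                                    observer_cls=PerChannelMinMaxObserver,
                                                    symmetric=False,
                                                    scale_type=ScaleType.float,
                                                    round_method=RoundType.half_even,
                                                    qscheme=QSchemeType.per_group,
                                                    ch_axis=0,
                                                    is_dynamic=False,
                                                    group_size=-1)
W_UINT4_CONFIG = QuantizationConfig(weight=UINT4_PER_CHANNEL_VIA_GROUP_SPEC)
quant_config = Config(global_quant_config=W_UINT4_CONFIG)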