#
# Copyright (C) 2023 - 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization Config API for PyTorch"""
from __future__ import annotations
import json
from abc import abstractmethod
from dataclasses import dataclass, field, fields
from typing import TYPE_CHECKING, Any, TypeVar, cast
import torch.nn as nn
if TYPE_CHECKING:
from quark.torch.quantization.config.template import LLMTemplate
from quark.shares.config import BaseAlgoConfig, BaseConfigImpl, BaseQConfig, BaseQLayerConfig, BaseQTensorConfig
from quark.shares.utils.doc import add_start_docstring
from quark.shares.utils.log import ScreenLogger
from quark.torch.quantization.config.type import (
ALL_DATA_TYPES,
DeviceType,
Dtype,
QSchemeType,
QuantizationMode,
RoundType,
ScaleType,
TQTThresholdInitMeth,
ZeroPointType,
)
from quark.torch.quantization.config.utils import dataclass_pretty_string
from quark.torch.quantization.constants import ONLY_DTYPE_CHANGE, QUARK_LAYER_TYPES, USING_NON_SCALED_QUANT
from quark.torch.quantization.observer import (
OBSERVER_CLASSES,
OBSERVER_MAP,
PER_CHANNEL_OBSERVERS,
PER_GROUP_OBSERVERS,
PER_TENSOR_OBSERVERS,
ObserverBase,
PerBlockMXDiffsObserver,
PerBlockMXObserver,
PerChannelMinMaxObserver,
PerGroupMinMaxObserver,
PerTensorHistogramObserver,
PerTensorHistogramObserverPro,
PerTensorMinMaxObserver,
PerTensorMSEObserver,
PerTensorPercentileObserver,
PlaceholderObserver,
)
# Module-level logger used for the deprecation and configuration warnings emitted in this file.
logger = ScreenLogger(__name__)

# Docstring template attached to the ``*Spec`` helper classes through ``add_start_docstring``:
# ``{0}`` is a short description of the quantization flavour, ``{1}`` the helper class name and
# ``{2}`` the example constructor arguments.
DATA_TYPE_SPEC_DOCSTRING = r"""Helper class to define a :py:class:`.QTensorConfig` using {0}.
Example:
.. code-block:: python
quantization_spec = {1}({2}).to_quantization_spec()
"""

# Maps the ``observer_method`` string accepted by the per-tensor specs to the observer class
# (looked up by ``get_per_tensor_observer``).
PER_TENSOR_OBSERVER_METHOD_MAP: dict[str, type[ObserverBase]] = {
    "min_max": PerTensorMinMaxObserver,
    "histogram": PerTensorHistogramObserver,
    "histogrampro": PerTensorHistogramObserverPro,
    "MSE": PerTensorMSEObserver,
    "percentile": PerTensorPercentileObserver,
}

# String -> enum lookup tables used by ``get_scale_type``, ``get_round_method`` and
# ``get_zero_point_type`` respectively.
SCALE_TYPE_MAP = {"float": ScaleType.float, "power_of_2": ScaleType.pof2}
ROUND_METHOD_MAP = {"round": RoundType.round, "floor": RoundType.floor, "half_even": RoundType.half_even}
ZERO_POINT_TYPE_MAP = {"int32": ZeroPointType.int32, "float32": ZeroPointType.float32}
def get_per_tensor_observer(observer_method: str | None = None) -> type[ObserverBase] | None:
    """Resolve a per-tensor observer name to its observer class.

    :param observer_method: Key of ``PER_TENSOR_OBSERVER_METHOD_MAP``; ``None`` or an empty string
        means "no observer".
    :return: The matching observer class, or ``None`` when ``observer_method`` is falsy.
    :raises AssertionError: If a non-empty ``observer_method`` is not a known key.
    """
    # Falsy input (None / "") short-circuits: nothing to look up.
    if not observer_method:
        return None
    assert observer_method in PER_TENSOR_OBSERVER_METHOD_MAP, (
        f"Invalid observer method. Valid observer methods are {list(PER_TENSOR_OBSERVER_METHOD_MAP.keys())}"
    )
    return PER_TENSOR_OBSERVER_METHOD_MAP[observer_method]
def get_scale_type(scale_type: str | None = None) -> ScaleType | None:
    """Resolve a scale-type name to its ``ScaleType`` member.

    :param scale_type: Key of ``SCALE_TYPE_MAP``; ``None`` or an empty string means "unspecified".
    :return: The matching ``ScaleType``, or ``None`` when ``scale_type`` is falsy.
    :raises AssertionError: If a non-empty ``scale_type`` is not a known key.
    """
    if not scale_type:
        return None
    assert scale_type in SCALE_TYPE_MAP, f"Invalid scale type. Valid scale types are {list(SCALE_TYPE_MAP.keys())}"
    return SCALE_TYPE_MAP[scale_type]
def get_round_method(round_method: str | None = None) -> RoundType | None:
    """Resolve a rounding-method name to its ``RoundType`` member.

    :param round_method: Key of ``ROUND_METHOD_MAP``; ``None`` or an empty string means "unspecified".
    :return: The matching ``RoundType``, or ``None`` when ``round_method`` is falsy.
    :raises AssertionError: If a non-empty ``round_method`` is not a known key.
    """
    if not round_method:
        return None
    assert round_method in ROUND_METHOD_MAP, (
        f"Invalid round method. Valid round methods are {list(ROUND_METHOD_MAP.keys())}"
    )
    return ROUND_METHOD_MAP[round_method]
def get_zero_point_type(zero_point_type: str | None = None) -> ZeroPointType | None:
    """Resolve a zero-point type name to its ``ZeroPointType`` member.

    :param zero_point_type: Key of ``ZERO_POINT_TYPE_MAP``; ``None`` or an empty string means "unspecified".
    :return: The matching ``ZeroPointType``, or ``None`` when ``zero_point_type`` is falsy.
    :raises AssertionError: If a non-empty ``zero_point_type`` is not a known key.
    """
    if not zero_point_type:
        return None
    assert zero_point_type in ZERO_POINT_TYPE_MAP, (
        f"Invalid zero point type, Valid zero point type method are {list(ZERO_POINT_TYPE_MAP.keys())}"
    )
    return ZERO_POINT_TYPE_MAP[zero_point_type]
QCT = TypeVar("QCT", bound="QConfig")
[docs]
@dataclass(eq=True)
class QConfig(BaseConfigImpl, BaseQConfig):
    """
    Top-level quantization configuration of a model, giving hierarchical control (global, per layer type, per layer name) over quantization parameters.

    :param QLayerConfig global_quant_config: Global quantization configuration applied to the entire model unless overridden at the layer level.
    :param Dict[torch.nn.Module, QLayerConfig] layer_type_quant_config: A dictionary mapping from layer types (e.g., nn.Conv2d, nn.Linear) to their quantization configurations.
    :param Dict[str, QLayerConfig] layer_quant_config: A dictionary mapping from layer names to their quantization configurations, allowing for per-layer customization. Default is ``{}``.
    :param Dict[str, QLayerConfig] kv_cache_quant_config: A dictionary mapping from layer names to kv_cache quantization configurations. Default is ``{}``.
    :param List[str] kv_cache_group: A list of layer names to be grouped for kv_cache quantization. Not serialized by :meth:`to_dict`. Default is ``[]``.
    :param float min_kv_scale: The minimum scale of kv_cache quantization. Not serialized by :meth:`to_dict`. Default is ``0.0``.
    :param bool kv_cache_post_rope: Whether KV-cache quantization is applied post-RoPE. Default is ``False`` (legacy pre-RoPE behaviour).
    :param Optional[QTensorConfig] softmax_quant_spec: A quantization specifications of nn.functional.softmax output. Default is ``None``.
    :param List[str] exclude: A list of layer names to be excluded from quantization, enabling selective quantization of the model. Default is ``[]``.
    :param Optional[AlgoConfig] algo_config: Optional configuration for the quantization algorithm, such as GPTQ, AWQ and Qronos. After this process, the datatype/fake_datatype of weights will be changed with quantization scales. Default is ``None``.
    :param QuantizationMode quant_mode: The quantization mode to be used (``eager_mode`` or ``fx_graph_mode``). Default is ``QuantizationMode.eager_mode``.
    """

    # Note: `global_quant_config`, `exclude`, `algo_config`, `log_severity_level`, `version` are inherited from `BaseQConfig`

    # Layer type (e.g., nn.Conv2d, nn.Linear) -> quantization configuration.
    layer_type_quant_config: dict[type[nn.Module], QLayerConfig] = field(default_factory=dict)

    # Layer name -> quantization configuration, for per-layer customization.
    layer_quant_config: dict[str, QLayerConfig] = field(default_factory=dict)

    # Layer name -> kv_cache quantization configuration.
    kv_cache_quant_config: dict[str, QLayerConfig] = field(default_factory=dict)

    # Layer names to be grouped for kv_cache quantization, enabling per-group customization.
    kv_cache_group: list[str] = field(default_factory=list)

    # The minimum scale of kv_cache quantization.
    min_kv_scale: float = 0.0

    # Control whether KV-cache quantization is applied post-RoPE (inside HF Cache.update).
    # When False (default), legacy pre-RoPE behaviour applies (module-level K/V output quantizers).
    kv_cache_post_rope: bool = False

    # Quantization specification of the nn.functional.softmax output.
    softmax_quant_spec: QTensorConfig | None = None

    # The quantization mode to be used (eager_mode or fx_graph_mode).
    quant_mode: QuantizationMode = QuantizationMode.eager_mode

    def to_dict(self) -> dict[str, Any]:
        """Serialize this configuration into a plain dictionary.

        Layer types are stored by class ``__name__`` and ``quant_mode`` by enum member name.
        ``kv_cache_group`` and ``min_kv_scale`` are intentionally left out: they are export-time
        processing parameters that are not needed for import.
        """
        return {
            "global_quant_config": self.global_quant_config.to_dict(),
            "exclude": self.exclude,
            "algo_config": None if self.algo_config is None else [algo.to_dict() for algo in self.algo_config],
            "softmax_quant_spec": None if self.softmax_quant_spec is None else self.softmax_quant_spec.to_dict(),
            "quant_method": "quark",
            "layer_type_quant_config": {
                layer_type.__name__: layer_cfg.to_dict()
                for layer_type, layer_cfg in self.layer_type_quant_config.items()
            },
            "layer_quant_config": {
                layer_name: layer_cfg.to_dict() for layer_name, layer_cfg in self.layer_quant_config.items()
            },
            "kv_cache_quant_config": {
                cache_name: cache_cfg.to_dict() for cache_name, cache_cfg in self.kv_cache_quant_config.items()
            },
            # Needed at import time to determine the cache integration behavior.
            "kv_cache_post_rope": self.kv_cache_post_rope,
            "quant_mode": self.quant_mode.name,
            "version": self.version,
        }

    def __str__(self) -> str:
        """Return the pretty-printed dataclass representation."""
        return dataclass_pretty_string(self)

    @classmethod
    def from_dict(cls: type[QCT], config_dict: dict[str, Any]) -> QCT:
        """Rebuild a configuration from the dictionary produced by :meth:`to_dict`.

        Several legacy (quark<1.0) serialization quirks are tolerated, see the inline comments.

        :param config_dict: Serialized configuration.
        :raises NotImplementedError: If ``layer_type_quant_config`` names a layer type that is not in ``QUARK_LAYER_TYPES``.
        """
        global_quant_config = QLayerConfig.from_dict(config_dict["global_quant_config"])

        # TODO: Deprecate legacy configuration and remove the None check here.
        # Legacy (quark<1.0) configuration used to allow layer_type_quant_config=None in the serialized config,
        # inconsistent with the type hints of the dataclass.
        layer_type_quant_config = {}
        serialized_layer_types = config_dict["layer_type_quant_config"]
        if serialized_layer_types is not None:
            for layer_type_name, serialized_layer_type_cfg in serialized_layer_types.items():
                if layer_type_name not in QUARK_LAYER_TYPES:
                    raise NotImplementedError(
                        f"Quark does not support reloading a quantization `Config` from a dictionary using custom `layer_type_quantization_config`. Found `'{layer_type_name}'` in `layer_type_quantization_config`, which is not among the supported {QUARK_LAYER_TYPES}."
                    )
                layer_type_quant_config[QUARK_LAYER_TYPES[layer_type_name]] = QLayerConfig.from_dict(
                    serialized_layer_type_cfg
                )

        # TODO: Deprecate legacy configuration and remove the None check here.
        # Legacy (quark<1.0) configuration used to allow layer_quant_config=None in the serialized config,
        # inconsistent with the type hints of the dataclass.
        serialized_layers = config_dict["layer_quant_config"]
        layer_quant_config = (
            {}
            if serialized_layers is None
            else {layer_name: QLayerConfig.from_dict(layer_cfg) for layer_name, layer_cfg in serialized_layers.items()}
        )

        # The kv_cache section may be missing entirely or be None.
        serialized_kv_cache = config_dict.get("kv_cache_quant_config")
        kv_cache_quant_config = (
            {}
            if serialized_kv_cache is None
            else {cache_name: QLayerConfig.from_dict(cache_cfg) for cache_name, cache_cfg in serialized_kv_cache.items()}
        )

        # TODO: Deprecate legacy (quark<1.0) configuration and remove the check here.
        # `exclude` used to be serialized as `None` when there was no exclude layer, instead of `[]`.
        exclude = config_dict["exclude"]
        if exclude is None:  # pragma: no cover
            exclude = []

        serialized_algo = config_dict.get("algo_config")
        if serialized_algo is None:
            algo_config = None
        elif isinstance(serialized_algo, list):  # new config
            algo_config = [_load_quant_algo_config_from_dict(algo) for algo in serialized_algo]
        else:  # old config
            algo_config = [_load_quant_algo_config_from_dict(serialized_algo)]

        # Get softmax_quant_spec configuration from config_dict
        serialized_softmax = config_dict.get("softmax_quant_spec")
        softmax_quant_spec = None if serialized_softmax is None else QTensorConfig.from_dict(serialized_softmax)

        # TODO: Deprecate legacy (quark<1.0) configuration and remove the fallback here.
        # The key `"quant_mode"` used not to be serialized in the legacy quantization_config, inconsistent with
        # the type hints of the dataclass.
        quant_mode = (
            QuantizationMode[config_dict["quant_mode"]]  # Access by name and not by value.
            if "quant_mode" in config_dict
            else QuantizationMode.eager_mode
        )

        # get version from config_dict, if not found (e.g. models exported with amd-quark<=0.8), set it to `None`.
        version = config_dict.get("version")

        kv_cache_post_rope = config_dict.get("kv_cache_post_rope", False)

        return cls(
            global_quant_config=global_quant_config,
            layer_type_quant_config=layer_type_quant_config,
            layer_quant_config=layer_quant_config,
            kv_cache_quant_config=kv_cache_quant_config,
            kv_cache_post_rope=kv_cache_post_rope,
            exclude=exclude,
            algo_config=algo_config,
            softmax_quant_spec=softmax_quant_spec,
            quant_mode=quant_mode,
            version=version,
        )

    @staticmethod
    def with_llm_template(
        template: LLMTemplate,
        scheme: str,
        algorithm: str | None = None,
        kv_cache_scheme: str | None = None,
        min_kv_scale: float = 0.0,
        attention_scheme: str | None = None,
        layer_config: dict[str, str] | None = None,
        layer_type_config: dict[type[nn.Module], str] | None = None,
        exclude_layers: list[str] | None = None,
    ) -> Config:
        """Build a configuration by forwarding every argument, unchanged, to ``template.get_config``."""
        return template.get_config(
            scheme=scheme,
            algorithm=algorithm,
            kv_cache_scheme=kv_cache_scheme,
            min_kv_scale=min_kv_scale,
            attention_scheme=attention_scheme,
            layer_config=layer_config,
            layer_type_config=layer_type_config,
            exclude_layers=exclude_layers,
        )

    def __post_init__(self) -> None:
        """Emit the deprecation warning for ``Config`` and sanity warnings on R3 rotation vs. KV cache settings."""
        if self.__class__ is Config:
            logger.warning(
                f"quark.torch.config.config.{self.__class__.__name__} is deprecated and will be removed in a future release. Please use quark.torch.config.config.QConfig instead."
            )

        if self.algo_config is None:
            return

        for algo_config in self.algo_config:
            # Only rotation-style algorithms are checked here.
            if algo_config.name not in ["rotation", "quarot"]:
                continue

            algo_config_cls_name = algo_config.__class__.__name__
            kv_cache_is_quantized = len(self.kv_cache_quant_config) > 0

            if kv_cache_is_quantized and not algo_config.r3:  # type: ignore
                logger.warning(
                    f"The R3 rotation is disabled, but the KV cache is configured to quantized with: {self.kv_cache_quant_config}. KV cache quantization may benefit from R3 rotation if keys are quantized. Consider using `r3=True` in {algo_config_cls_name} configuration."
                )

            if not kv_cache_is_quantized and algo_config.r3:  # type: ignore
                logger.warning(
                    f"No KV cache quantization configuration provided, but `{algo_config_cls_name}.r3` is set to `True`. This setting is only useful in case KV cache quantization is used. Consider using `r3=False`."
                )

    def get_rotation_config(self) -> RotationConfig | None:
        """Return the first ``RotationConfig`` found in ``algo_config``, or ``None`` if there is none."""
        if self.algo_config is None:
            return None
        return next((algo for algo in self.algo_config if isinstance(algo, RotationConfig)), None)
QLT = TypeVar("QLT", bound="QLayerConfig")
[docs]
@dataclass(eq=True)
class QLayerConfig(BaseQLayerConfig):
    """
    Quantization configuration of the individual tensors of a module (inputs, outputs, weight, bias), allowing hierarchical control over how each tensor type is quantized.

    :param Optional[Union[QTensorConfig, List[QTensorConfig]]] input_tensors: Input tensors quantization specification. If None, following the hierarchical quantization setup. e.g. If the ``input_tensors`` in ``layer_type_quant_config`` is ``None``, the configuration from ``global_quant_config`` will be used instead. Defaults to ``None``. If None in ``global_quant_config``, ``input_tensors`` are not quantized.
    :param Optional[Union[QTensorConfig, List[QTensorConfig]]] output_tensors: Output tensors quantization specification. Defaults to ``None``. If ``None``, the same as above.
    :param Optional[Union[QTensorConfig, List[QTensorConfig]]] weight: The weights tensors quantization specification. Defaults to ``None``. If ``None``, the same as above.
    :param Optional[Union[QTensorConfig, List[QTensorConfig]]] bias: The bias tensors quantization specification. Defaults to ``None``. If ``None``, the same as above.
    :param Optional[DeviceType] target_device: Configuration specifying the target device (e.g., CPU, GPU, IPU) for the quantized model.
    """

    # Note: `input_tensors`, `output_tensors`, `weight`, and `bias` are inherited from `BaseQLayerConfig`
    # with type Union[BaseQTensorConfig, list[BaseQTensorConfig]] | None
    # They are not redeclared here to avoid type incompatibility issues

    target_device: DeviceType | None = None

    def __post_init__(self) -> None:
        """Warn on the deprecated ``QuantizationConfig`` alias and reject static per-group activation quantization.

        :raises ValueError: If an input/output tensor spec uses ``QSchemeType.per_group`` with a falsy ``is_dynamic``.
        """
        if self.__class__ is QuantizationConfig:
            logger.warning(
                f"quark.torch.config.config.{self.__class__.__name__} is deprecated and will be removed in a future release. Please use quark.torch.config.config.QLayerConfig instead."
            )

        activation_specs = [
            ("input_tensors", self.input_tensors),
            ("output_tensors", self.output_tensors),
        ]
        for tensor_name, quantization_spec in activation_specs:
            if quantization_spec is None:
                continue

            # Normalize a single spec to a one-element list so both forms are validated alike.
            spec_list = [quantization_spec] if isinstance(quantization_spec, QTensorConfig) else quantization_spec
            assert isinstance(spec_list, list), "quantization_spec must a list"

            for quant_spec in spec_list:
                if quant_spec.qscheme == QSchemeType.per_group and not quant_spec.is_dynamic:
                    raise ValueError(
                        f"The parameterization `qscheme=QSchemeType.per_group` along with `is_dynamic=False` is currently not supported in AMD Quark for `QLayerConfig.input_tensors` and `QLayerConfig.output_tensors`, got `QLayerConfig.{tensor_name}.qscheme={quant_spec.qscheme}` and `QLayerConfig.{tensor_name}.is_dynamic={quant_spec.is_dynamic}`. Consider using `QLayerConfig.{tensor_name}.is_dynamic=True` as static per-group quantization for activations is rarely relevant."
                    )

    def to_dict(self) -> dict[str, Any]:
        """Serialize this layer configuration; ``target_device`` is stored by its enum value."""

        def _serialize(
            spec: QTensorConfig | list[QTensorConfig] | None,
        ) -> dict[str, Any] | list[dict[str, Any]] | None:
            # A tensor spec may be absent, a single spec, or a list of specs.
            if spec is None:
                return None
            if isinstance(spec, list):
                return [item.to_dict() for item in spec]
            return spec.to_dict()

        return {
            "input_tensors": _serialize(self.input_tensors),
            "output_tensors": _serialize(self.output_tensors),
            "weight": _serialize(self.weight),
            "bias": _serialize(self.bias),
            "target_device": None if self.target_device is None else self.target_device.value,
        }

    @classmethod
    def from_dict(cls: type[QLT], quantization_config: dict[str, Any]) -> QLT:
        """Rebuild a layer configuration from the dictionary produced by :meth:`to_dict`.

        :param quantization_config: Serialized layer configuration; the ``"target_device"`` key is optional.
        """

        def _deserialize(
            config: dict[str, Any] | list[dict[str, Any]] | None,
        ) -> QTensorConfig | list[QTensorConfig] | None:
            if config is None:
                return None
            if not isinstance(config, list):
                return QTensorConfig.from_dict(config)
            specs = [QTensorConfig.from_dict(entry) for entry in config]
            assert all(spec is not None for spec in specs), "all quantization specs must be valid (not None)"
            # After verification, all items are guaranteed to be QTensorConfig
            return specs  # type: ignore[return-value]

        input_tensors = _deserialize(quantization_config["input_tensors"])
        output_tensors = _deserialize(quantization_config["output_tensors"])
        weight = _deserialize(quantization_config["weight"])
        bias = _deserialize(quantization_config["bias"])

        # TODO: Deprecate legacy configuration.
        # Legacy (quark<1.0) saved quantization_config does not have the key `"target_device"`.
        raw_target_device = quantization_config.get("target_device")
        target_device = None if raw_target_device is None else DeviceType(raw_target_device)

        return cls(
            input_tensors=input_tensors,
            output_tensors=output_tensors,
            weight=weight,
            bias=bias,
            target_device=target_device,
        )
[docs]
@dataclass
class TwoStageSpec(BaseConfigImpl):
    """
    Base data class for two-stage quantization specifications.

    It holds a ``first_stage`` and a ``second_stage``, each given either as a :py:class:`DataTypeSpec`
    helper or directly as a :py:class:`QTensorConfig`. Subclasses define how the two stages are combined
    in :meth:`to_quantization_spec`.
    """

    first_stage: DataTypeSpec | QTensorConfig
    second_stage: DataTypeSpec | QTensorConfig

    @abstractmethod
    def to_quantization_spec(self) -> list[QTensorConfig]:
        """Return the list of ``QTensorConfig`` objects describing the stages (implemented by subclasses)."""
[docs]
@dataclass
class ProgressiveSpec(TwoStageSpec):
    """
    A data class that specifies a progressive quantization specification for a tensor.

    The first stage quantizes the input tensor, while the second stage quantizes the output from the first stage.

    For example, to progressively quantize a float16 tensor:

    1. First quantize it to fp8_e4m3 using fp8_e4m3 per-tensor quantization, get a fp8_e4m3 tensor.
    2. Then quantize the fp8_e4m3 tensor to int4 using int4 per-channel quantization, get a int4 tensor.

    The configuration for this example would be:

    .. code-block:: python

        quant_spec = ProgressiveSpec(
            first_stage=FP8E4M3PerTensorSpec(observer_method="min_max",
                                             is_dynamic=False),
            second_stage=Int4PerChannelSpec(symmetric=False,
                                            scale_type="float",
                                            round_method="half_even",
                                            ch_axis=0,
                                            is_dynamic=False)
        ).to_quantization_spec()
    """

    def to_quantization_spec(self) -> list[QTensorConfig]:
        """Return ``[first_stage, second_stage]`` with any ``DataTypeSpec`` helper converted to a ``QTensorConfig``."""
        stage_specs = []
        for stage in (self.first_stage, self.second_stage):
            # Helper specs are converted; QTensorConfig instances are passed through as-is.
            if isinstance(stage, DataTypeSpec):
                stage = stage.to_quantization_spec()
            stage_specs.append(stage)
        return stage_specs
[docs]
@dataclass
class ScaleQuantSpec(TwoStageSpec):
    """
    A data class that specifies a two-stage quantization process for scale quantization.

    The quantization happens in two stages:

    1. First stage quantizes the input tensor itself.
    2. Second stage quantizes the scale values from the first stage quantization.

    For example, given a float16 tensor:

    1. First quantize the tensor to fp4_e2m1 using fp4_e2m1 per-group quantization, producing a fp4_e2m1 tensor with float16 scale values.
    2. Then quantize those float16 scale values to fp8_e4m3 using fp8_e4m3 per-tensor quantization.

    The configuration for this example would be:

    .. code-block:: python

        quant_spec = ScaleQuantSpec(
            first_stage=FP4PerGroupSpec(group_size=16, is_dynamic=False),
            second_stage=FP8E4M3PerTensorSpec(observer_method="min_max", is_dynamic=False)
        ).to_quantization_spec()
    """

    def to_quantization_spec(self) -> list[QTensorConfig]:
        """Return ``[first_stage, second_stage]`` as ``QTensorConfig`` objects, the second flagged as a scale quantizer.

        The returned second-stage spec has ``is_scale_quant`` set to ``True``. When ``second_stage`` was
        supplied directly as a ``QTensorConfig``, the flag is set on a shallow copy so that the caller's
        object (which may be reused elsewhere as a regular tensor spec) is left untouched.
        """
        import copy  # local import: only needed for the defensive copy below

        if isinstance(self.second_stage, DataTypeSpec):
            # Freshly built by the helper, safe to modify in place.
            second_stage_spec = self.second_stage.to_quantization_spec()
        else:
            # Caller-owned object: never mutate it behind the caller's back.
            second_stage_spec = copy.copy(self.second_stage)
        second_stage_spec.is_scale_quant = True

        first_stage_spec = (
            self.first_stage.to_quantization_spec() if isinstance(self.first_stage, DataTypeSpec) else self.first_stage
        )
        return [first_stage_spec, second_stage_spec]
[docs]
class DataTypeSpec(BaseConfigImpl):
    """Base class of the helper specs in this module; each one converts itself to a :py:class:`QTensorConfig`."""

    @abstractmethod
    def to_quantization_spec(self) -> QTensorConfig:
        """Return the ``QTensorConfig`` described by this helper spec (implemented by subclasses)."""
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint4 per tensor quantization", "Uint4PerTensorSpec", "is_dynamic=True, symmetric=False"
    )
)
class Uint4PerTensorSpec(DataTypeSpec):
    # String options are resolved to observer classes / enum members in to_quantization_spec().
    observer_method: str | None = "min_max"
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint4, per-tensor ``QTensorConfig`` described by this spec's fields."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.uint4,
            observer_cls=observer,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_tensor,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint4 per channel quantization",
        "Uint4PerChannelSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=0,\n"
        "is_dynamic=False\n",
    )
)
class Uint4PerChannelSpec(DataTypeSpec):
    # ch_axis has no default and must be supplied by the caller.
    ch_axis: int
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = None
    zero_point_type: str | None = "int32"

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint4, per-channel ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        zero_point = get_zero_point_type(self.zero_point_type)
        return QTensorConfig(
            dtype=Dtype.uint4,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            zero_point_type=zero_point,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint4 per group quantization",
        "Uint4PerGroupSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=False,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=128\n",
    )
)
class Uint4PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool = False
    is_dynamic: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint4, per-group ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.uint4,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int3 per group quantization",
        "Int3PerGroupSpec",
        # Example constructor arguments shown in the generated docstring. `ch_axis` is a required
        # field, so it must appear here for the documented example to be runnable.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=32\n",
    )
)
class Int3PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int3, per-group ``QTensorConfig`` described by this spec's fields."""
        return QTensorConfig(
            dtype=Dtype.int3,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=get_scale_type(self.scale_type),
            round_method=get_round_method(self.round_method),
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int3 per channel quantization",
        "Int3PerChannelSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=False,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=0,\n"
        "is_dynamic=False\n",
    )
)
class Int3PerChannelSpec(DataTypeSpec):
    # ch_axis has no default and must be supplied by the caller.
    ch_axis: int
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int3, per-channel ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int3,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int2 per group quantization",
        "Int2PerGroupSpec",
        # Example constructor arguments shown in the generated docstring. `ch_axis` is a required
        # field, so it must appear here for the documented example to be runnable.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=32\n",
    )
)
class Int2PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int2, per-group ``QTensorConfig`` described by this spec's fields."""
        return QTensorConfig(
            dtype=Dtype.int2,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=get_scale_type(self.scale_type),
            round_method=get_round_method(self.round_method),
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int4 per tensor quantization",
        "Int4PerTensorSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        'observer_method="min_max",\n'
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "is_dynamic=False\n",
    )
)
class Int4PerTensorSpec(DataTypeSpec):
    # String options are resolved to observer classes / enum members in to_quantization_spec().
    observer_method: str | None = "min_max"
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int4, per-tensor ``QTensorConfig`` described by this spec's fields."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int4,
            observer_cls=observer,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_tensor,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int4 per channel quantization",
        "Int4PerChannelSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=False,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=0,\n"
        "is_dynamic=False\n",
    )
)
class Int4PerChannelSpec(DataTypeSpec):
    # ch_axis has no default and must be supplied by the caller.
    ch_axis: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int4, per-channel ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int4,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int4 per group quantization",
        "Int4PerGroupSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=128\n",
    )
)
class Int4PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool = True
    is_dynamic: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int4, per-group ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int4,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint8 per tensor quantization",
        "Uint8PerTensorSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        'observer_method="percentile",\n'
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "is_dynamic=False\n",
    )
)
class Uint8PerTensorSpec(DataTypeSpec):
    # String options are resolved to observer classes / enum members in to_quantization_spec().
    observer_method: str | None = "min_max"
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint8, per-tensor ``QTensorConfig`` described by this spec's fields."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.uint8,
            observer_cls=observer,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_tensor,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint8 per channel quantization",
        "Uint8PerChannelSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=0,\n"
        "is_dynamic=False\n",
    )
)
class Uint8PerChannelSpec(DataTypeSpec):
    # ch_axis has no default and must be supplied by the caller.
    ch_axis: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint8, per-channel ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.uint8,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "uint8 per group quantization",
        "Uint8PerGroupSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=False,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=128\n",
    )
)
class Uint8PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the uint8, per-group ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.uint8,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int8 per tensor quantization",
        "Int8PerTensorSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        'observer_method="min_max",\n'
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "is_dynamic=False\n",
    )
)
class Int8PerTensorSpec(DataTypeSpec):
    # String options are resolved to observer classes / enum members in to_quantization_spec().
    observer_method: str | None = "min_max"
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int8, per-tensor ``QTensorConfig`` described by this spec's fields."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int8,
            observer_cls=observer,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_tensor,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int8 per channel quantization",
        "Int8PerChannelSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=False,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=0,\n"
        "is_dynamic=False\n",
    )
)
class Int8PerChannelSpec(DataTypeSpec):
    # ch_axis has no default and must be supplied by the caller.
    ch_axis: int
    symmetric: bool | None = False
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int8, per-channel ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int8,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "int8 per group quantization",
        "Int8PerGroupSpec",
        # Example constructor arguments shown in the generated docstring.
        "\n"
        "symmetric=True,\n"
        'scale_type="float",\n'
        'round_method="half_even",\n'
        "ch_axis=1,\n"
        "is_dynamic=False,\n"
        "group_size=128\n",
    )
)
class Int8PerGroupSpec(DataTypeSpec):
    # ch_axis and group_size have no default and must be supplied by the caller.
    ch_axis: int
    group_size: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Build the int8, per-group ``QTensorConfig`` described by this spec's fields."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.int8,
            observer_cls=PerGroupMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            group_size=self.group_size,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E4M3 per tensor quantization",
        "FP8E4M3PerTensorSpec",
        r"""
observer_method="min_max",
is_dynamic=False
""",
    )
)
class FP8E4M3PerTensorSpec(DataTypeSpec):
    observer_method: str | None = "min_max"
    # Unlike the int8 specs, `scale_type` defaults to None here.
    scale_type: str | None = None
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E4M3 per-tensor ``QTensorConfig`` described by this spec."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        # Symmetry and rounding are not configurable for this spec: they are
        # always `True` and `RoundType.half_even`.
        return QTensorConfig(
            dtype=Dtype.fp8_e4m3,
            qscheme=QSchemeType.per_tensor,
            observer_cls=observer,
            symmetric=True,
            scale_type=scale,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E4M3 per channel quantization",
        "FP8E4M3PerChannelSpec",
        "is_dynamic=False, ch_axis=0",
    )
)
class FP8E4M3PerChannelSpec(DataTypeSpec):
    # `ch_axis` has no default: the caller must always provide it.
    ch_axis: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E4M3 per-channel ``QTensorConfig`` described by this spec."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.fp8_e4m3,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E4M3 per group quantization",
        "FP8E4M3PerGroupSpec",
        r"""
ch_axis=-1,
group_size=group_size,
is_dynamic=True
""",
    )
)
class FP8E4M3PerGroupSpec(DataTypeSpec):
    # `ch_axis` and `group_size` have no defaults and must be supplied.
    ch_axis: int
    group_size: int
    scale_format: str | None = "float32"
    scale_calculation_mode: str | None = None
    is_dynamic: bool | None = True

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E4M3 per-group ``QTensorConfig`` described by this spec."""
        # Observer, scale type and rounding are fixed; `symmetric` is passed
        # as None. Only the dataclass fields are forwarded from the caller.
        spec = QTensorConfig(
            dtype=Dtype.fp8_e4m3,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=self.group_size,
            observer_cls=PerBlockMXObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format=self.scale_format,
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E5M2 per tensor quantization",
        "FP8E5M2PerTensorSpec",
        r"""
observer_method="min_max",
is_dynamic=False
""",
    )
)
class FP8E5M2PerTensorSpec(DataTypeSpec):
    # The string-valued options below are translated into observer / enum
    # objects by the get_* helpers when the spec is converted.
    observer_method: str | None = "min_max"
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E5M2 per-tensor ``QTensorConfig`` described by this spec."""
        observer = get_per_tensor_observer(self.observer_method)
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.fp8_e5m2,
            qscheme=QSchemeType.per_tensor,
            observer_cls=observer,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E5M2 per channel quantization",
        "FP8E5M2PerChannelSpec",
        "is_dynamic=False, ch_axis=0",
    )
)
class FP8E5M2PerChannelSpec(DataTypeSpec):
    # `ch_axis` has no default: the caller must always provide it.
    ch_axis: int
    symmetric: bool | None = True
    scale_type: str | None = "float"
    round_method: str | None = "half_even"
    is_dynamic: bool | None = False

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E5M2 per-channel ``QTensorConfig`` described by this spec."""
        scale = get_scale_type(self.scale_type)
        rounding = get_round_method(self.round_method)
        return QTensorConfig(
            dtype=Dtype.fp8_e5m2,
            qscheme=QSchemeType.per_channel,
            ch_axis=self.ch_axis,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=self.symmetric,
            scale_type=scale,
            round_method=rounding,
            is_dynamic=self.is_dynamic,
        )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP8E5M2 per group quantization",
        "FP8E5M2PerGroupSpec",
        r"""
ch_axis=-1,
group_size=group_size,
is_dynamic=True
""",
    )
)
class FP8E5M2PerGroupSpec(DataTypeSpec):
    # `ch_axis` and `group_size` have no defaults and must be supplied.
    ch_axis: int
    group_size: int
    scale_format: str | None = "float32"
    scale_calculation_mode: str | None = None
    is_dynamic: bool | None = True

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP8 E5M2 per-group ``QTensorConfig`` described by this spec."""
        # Observer, scale type and rounding are fixed; `symmetric` is passed
        # as None. Only the dataclass fields are forwarded from the caller.
        spec = QTensorConfig(
            dtype=Dtype.fp8_e5m2,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=self.group_size,
            observer_cls=PerBlockMXObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format=self.scale_format,
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "FP4 per group quantization",
        "FP4PerGroupSpec",
        r"""
ch_axis=-1,
group_size=group_size,
is_dynamic=True
""",
    )
)
class FP4PerGroupSpec(DataTypeSpec):
    # `ch_axis` and `group_size` have no defaults and must be supplied.
    ch_axis: int
    group_size: int
    scale_format: str | None = "float32"
    scale_calculation_mode: str | None = None
    is_dynamic: bool | None = True

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP4 per-group ``QTensorConfig`` described by this spec."""
        # Observer, scale type and rounding are fixed; `symmetric` is passed
        # as None. Only the dataclass fields are forwarded from the caller.
        spec = QTensorConfig(
            dtype=Dtype.fp4,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=self.group_size,
            observer_cls=PerBlockMXObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format=self.scale_format,
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass
class FP6E2M3PerGroupSpec(DataTypeSpec):
    """Helper that builds a per-group :py:class:`.QTensorConfig` with ``dtype=Dtype.fp6_e2m3``."""

    # `ch_axis` and `group_size` have no defaults and must be supplied.
    ch_axis: int
    group_size: int
    scale_format: str | None = "float32"
    scale_calculation_mode: str | None = None
    is_dynamic: bool | None = True

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP6 E2M3 per-group ``QTensorConfig`` described by this spec."""
        # Observer, scale type and rounding are fixed; `symmetric` is passed
        # as None. Only the dataclass fields are forwarded from the caller.
        spec = QTensorConfig(
            dtype=Dtype.fp6_e2m3,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=self.group_size,
            observer_cls=PerBlockMXObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format=self.scale_format,
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass
class FP6E3M2PerGroupSpec(DataTypeSpec):
    """Helper that builds a per-group :py:class:`.QTensorConfig` with ``dtype=Dtype.fp6_e3m2``."""

    # `ch_axis` and `group_size` have no defaults and must be supplied.
    ch_axis: int
    group_size: int
    scale_format: str | None = "float32"
    scale_calculation_mode: str | None = None
    is_dynamic: bool | None = True

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP6 E3M2 per-group ``QTensorConfig`` described by this spec."""
        # Observer, scale type and rounding are fixed; `symmetric` is passed
        # as None. Only the dataclass fields are forwarded from the caller.
        spec = QTensorConfig(
            dtype=Dtype.fp6_e3m2,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=self.group_size,
            observer_cls=PerBlockMXObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format=self.scale_format,
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass(eq=True)
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "float16 data type. The resulting QTensorConfig does not quantize the tensor.",
        "Float16Spec",
        "",
    )
)
class Float16Spec(DataTypeSpec):
    def to_quantization_spec(self) -> QTensorConfig:
        """Return a ``QTensorConfig`` with ``dtype=Dtype.float16`` and the placeholder observer."""
        spec = QTensorConfig(
            dtype=Dtype.float16,
            observer_cls=PlaceholderObserver,
        )
        return spec
[docs]
@dataclass(eq=True)
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "bfloat16 data type. The resulting QTensorConfig does not quantize the tensor.",
        "Bfloat16Spec",
        "",
    )
)
class Bfloat16Spec(DataTypeSpec):
    def to_quantization_spec(self) -> QTensorConfig:
        """Return a ``QTensorConfig`` with ``dtype=Dtype.bfloat16`` and the placeholder observer."""
        spec = QTensorConfig(
            dtype=Dtype.bfloat16,
            observer_cls=PlaceholderObserver,
        )
        return spec
[docs]
class OCP_MXSpec(DataTypeSpec):
    """Common base of the ``OCP_MX*Spec`` helpers; it only holds the keyword arguments they share."""

    # Keyword arguments that the OCP MX floating-point specs splat into
    # `QTensorConfig(...)`. `QTensorConfig.is_ocp_mxfp4` also compares a
    # config against these values, so keys must match QTensorConfig fields.
    OCP_MX_SPEC_KWARGS = dict(
        observer_cls=PerBlockMXObserver,
        symmetric=None,
        scale_type=ScaleType.float,
        round_method=RoundType.half_even,
        scale_format="e8m0",
        qscheme=QSchemeType.per_group,
        group_size=32,
    )
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using FP8E4M3",
        "OCP_MXFP8E4M3Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXFP8E4M3Spec(OCP_MXSpec):
    # `ch_axis` has no default; note that `is_dynamic` defaults to True.
    ch_axis: int
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the MX FP8 E4M3 ``QTensorConfig``: per-spec fields plus the shared ``OCP_MX_SPEC_KWARGS``."""
        mx_defaults = self.OCP_MX_SPEC_KWARGS
        return QTensorConfig(
            dtype=Dtype.fp8_e4m3,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            scale_calculation_mode=self.scale_calculation_mode,
            **mx_defaults,
        )  # type: ignore[arg-type]
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using FP8E5M2",
        "OCP_MXFP8E5M2Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXFP8E5M2Spec(OCP_MXSpec):
    # `ch_axis` has no default; note that `is_dynamic` defaults to True.
    ch_axis: int
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the MX FP8 E5M2 ``QTensorConfig``: per-spec fields plus the shared ``OCP_MX_SPEC_KWARGS``."""
        mx_defaults = self.OCP_MX_SPEC_KWARGS
        return QTensorConfig(
            dtype=Dtype.fp8_e5m2,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            scale_calculation_mode=self.scale_calculation_mode,
            **mx_defaults,
        )  # type: ignore[arg-type]
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using FP6E3M2",
        "OCP_MXFP6E3M2Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXFP6E3M2Spec(OCP_MXSpec):
    # `ch_axis` has no default; note that `is_dynamic` defaults to True.
    ch_axis: int
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the MX FP6 E3M2 ``QTensorConfig``: per-spec fields plus the shared ``OCP_MX_SPEC_KWARGS``."""
        mx_defaults = self.OCP_MX_SPEC_KWARGS
        return QTensorConfig(
            dtype=Dtype.fp6_e3m2,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            scale_calculation_mode=self.scale_calculation_mode,
            **mx_defaults,
        )  # type: ignore[arg-type]
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using FP6E2M3",
        "OCP_MXFP6E2M3Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXFP6E2M3Spec(OCP_MXSpec):
    # `ch_axis` has no default; note that `is_dynamic` defaults to True.
    ch_axis: int
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the MX FP6 E2M3 ``QTensorConfig``: per-spec fields plus the shared ``OCP_MX_SPEC_KWARGS``."""
        mx_defaults = self.OCP_MX_SPEC_KWARGS
        return QTensorConfig(
            dtype=Dtype.fp6_e2m3,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            scale_calculation_mode=self.scale_calculation_mode,
            **mx_defaults,
        )  # type: ignore[arg-type]
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using FP4",
        "OCP_MXFP4Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXFP4Spec(OCP_MXSpec):
    # `ch_axis` has no default; note that `is_dynamic` defaults to True.
    ch_axis: int
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the MX FP4 ``QTensorConfig``: per-spec fields plus the shared ``OCP_MX_SPEC_KWARGS``."""
        mx_defaults = self.OCP_MX_SPEC_KWARGS
        return QTensorConfig(
            dtype=Dtype.fp4,
            ch_axis=self.ch_axis,
            is_dynamic=self.is_dynamic,
            scale_calculation_mode=self.scale_calculation_mode,
            **mx_defaults,
        )  # type: ignore[arg-type]
[docs]
@dataclass
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX OCP data type using INT8",
        "OCP_MXINT8Spec",
        r"""
ch_axis=-1,
is_dynamic=False
""",
    )
)
class OCP_MXINT8Spec(OCP_MXSpec):
    ch_axis: int
    # Declared for parity with the other OCP MX specs, but it is not
    # forwarded to QTensorConfig by `to_quantization_spec` below.
    is_dynamic: bool = True
    scale_calculation_mode: str = "even"

    # TODO: support Dtype.int8 in PerBlockMXObserver.
    # Dtype.int8 still uses NonScaledFakeQuantize (see tensor_quantize.py),
    # which it needs not to.
    def to_quantization_spec(self) -> QTensorConfig:
        """Return a ``Dtype.mx`` ``QTensorConfig`` whose element dtype is int8 and whose group size is 32."""
        # The shared OCP_MX_SPEC_KWARGS are intentionally not used here; only
        # the fields below are passed.
        spec = QTensorConfig(
            dtype=Dtype.mx,
            mx_element_dtype=Dtype.int8,
            ch_axis=self.ch_axis,
            group_size=32,
            scale_calculation_mode=self.scale_calculation_mode,
        )  # type: ignore[arg-type]
        return spec
[docs]
@dataclass
class OCP_MXFP4DiffsSpec(DataTypeSpec):
    """Helper that builds an FP4 per-group :py:class:`.QTensorConfig` observed by ``PerBlockMXDiffsObserver``."""

    # `ch_axis` has no default: the caller must always provide it.
    ch_axis: int
    is_dynamic: bool | None = False
    scale_calculation_mode: str | None = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the FP4 per-group ``QTensorConfig`` (``e8m0`` scale format, group size 32)."""
        # Everything except ch_axis, is_dynamic and scale_calculation_mode is
        # fixed by this spec.
        spec = QTensorConfig(
            dtype=Dtype.fp4,
            qscheme=QSchemeType.per_group,
            ch_axis=self.ch_axis,
            group_size=32,
            observer_cls=PerBlockMXDiffsObserver,
            symmetric=None,
            scale_type=ScaleType.float,
            scale_format="e8m0",
            scale_calculation_mode=self.scale_calculation_mode,
            round_method=RoundType.half_even,
            is_dynamic=self.is_dynamic,
        )
        return spec
[docs]
@dataclass(eq=True)
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX6 data type as defined in https://arxiv.org/pdf/2302.08007. More details are available in the :doc:`Two Level Quantization Formats </pytorch/adv_two_level>` documentation",
        "MX6Spec",
        "is_dynamic=False",
    )
)
class MX6Spec(DataTypeSpec):
    # `ch_axis` and `block_size` have no defaults and must be supplied.
    ch_axis: int
    block_size: int
    scale_calculation_mode: str | None = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the ``Dtype.mx6`` ``QTensorConfig``; ``block_size`` is passed on as ``group_size``."""
        spec = QTensorConfig(
            dtype=Dtype.mx6,
            ch_axis=self.ch_axis,
            group_size=self.block_size,
            scale_calculation_mode=self.scale_calculation_mode,
        )
        return spec
[docs]
@dataclass(eq=True)
@add_start_docstring(
    DATA_TYPE_SPEC_DOCSTRING.format(
        "MX9 data type as defined in https://arxiv.org/pdf/2302.08007. More details are available in the :doc:`Two Level Quantization Formats </pytorch/adv_two_level>` documentation",
        "MX9Spec",
        "is_dynamic=False",
    )
)
class MX9Spec(DataTypeSpec):
    # `ch_axis` and `block_size` have no defaults and must be supplied.
    ch_axis: int
    block_size: int
    scale_calculation_mode: str | None = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the ``Dtype.mx9`` ``QTensorConfig``; ``block_size`` is passed on as ``group_size``."""
        spec = QTensorConfig(
            dtype=Dtype.mx9,
            ch_axis=self.ch_axis,
            group_size=self.block_size,
            scale_calculation_mode=self.scale_calculation_mode,
        )
        return spec
[docs]
@dataclass
@add_start_docstring(DATA_TYPE_SPEC_DOCSTRING.format("bfp16 data type", "BFP16Spec", "is_dynamic=False"))
class BFP16Spec(DataTypeSpec):
    # `ch_axis` has no default: the caller must always provide it.
    ch_axis: int
    scale_calculation_mode: str | None = "even"

    def to_quantization_spec(self) -> QTensorConfig:
        """Return the ``Dtype.bfp16`` ``QTensorConfig``; the group size is fixed to 8."""
        spec = QTensorConfig(
            dtype=Dtype.bfp16,
            ch_axis=self.ch_axis,
            group_size=8,
            scale_calculation_mode=self.scale_calculation_mode,
        )
        return spec
# Type variable bound to QTensorConfig. `QTensorConfig.from_dict` uses it so
# that the classmethod is typed as returning the class it is called on.
QTT = TypeVar("QTT", bound="QTensorConfig")
[docs]
@dataclass(eq=True)
class QTensorConfig(BaseQTensorConfig):
    """
    Specification of how a single tensor of a model is quantized.

    :param Dtype dtype: The data type for quantization (e.g., int8, int4).
    :param Optional[bool] is_dynamic: Specifies whether dynamic or static quantization should be used. Default is ``None``, which indicates no specification.
    :param Optional[Type[ObserverBase]] observer_cls: The class of observer to be used for determining quantization parameters like min/max values. Default is ``None``.
    :param Optional[QSchemeType] qscheme: The quantization scheme to use, such as per_tensor, per_channel or per_group. Default is ``None``.
    :param Optional[int] ch_axis: The channel axis for per-channel quantization. Default is ``None``.
    :param Optional[int] group_size: The size of the group for per-group quantization, also the block size for MX datatypes. Default is ``None``.
    :param Optional[bool] symmetric: Indicates if the quantization should be symmetric around zero. If True, quantization is symmetric. If ``None``, it defers to a higher-level or global setting. Default is ``None``.
    :param Optional[RoundType] round_method: The rounding method during quantization, such as half_even. If None, it defers to a higher-level or default method. Default is ``None``.
    :param Optional[ScaleType] scale_type: Defines the scale type to be used for quantization, like power of two or float. If ``None``, it defers to a higher-level setting or uses a default method. Default is ``None``.
    :param Optional[Dtype] mx_element_dtype: Defines the data type to be used for the element type when using mx datatypes, the shared scale effectively uses FP8 E8M0.
    :param Optional[bool] is_scale_quant: Indicates whether this spec is for quantizing scales rather than tensors. Default is ``False``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
        from quark.torch.quantization.config.config import QTensorConfig
        from quark.torch.quantization.observer.observer import PerChannelMinMaxObserver

        quantization_spec = QTensorConfig(
            dtype=Dtype.int8,
            qscheme=QSchemeType.per_channel,
            observer_cls=PerChannelMinMaxObserver,
            symmetric=True,
            scale_type=ScaleType.float,
            round_method=RoundType.half_even,
            is_dynamic=False,
            ch_axis=1,
        )
    """

    # Note: `dtype`, `observer_cls`, `is_dynamic`, `qscheme`, `ch_axis`, `group_size`, `symmetric`, `round_method`, `scale_type`, `scale_format`, `scale_calculation_mode`, `qat_spec`, `mx_element_dtype`, `zero_point_type` are inherited from `BaseQTensorConfig`
    qscheme: QSchemeType | None = None
    observer_cls: type[ObserverBase] | None = None
    dtype: Dtype
    scale_type: ScaleType | None = None
    mx_element_dtype: Dtype | None = None
    zero_point_type: ZeroPointType | None = ZeroPointType.int32
    is_scale_quant: bool = False

    def __post_init__(self) -> None:
        """
        Validate the freshly built config and reject conflicting settings.

        Example of a conflict: ``observer_cls`` is a per-tensor observer while ``qscheme`` is
        ``QSchemeType.per_channel``. Such a combination raises an exception whose message names the
        conflict, so the user knows what to fix.

        The checks run in this order: generic dtype / observer sanity checks, mandatory fields for integer and
        FP8 dtypes, then one of three dtype-dependent cases (non-scaled quantization, dtype-only cast,
        scaled quantization).

        :raises ValueError: for unsupported values, missing mandatory fields or conflicting settings.
        :raises AssertionError: in the scaled-quantization case when ``observer_cls`` or ``qscheme`` is ``None``,
            or when a per-tensor scheme is paired with an observer outside ``PER_TENSOR_OBSERVERS``.
        :raises ModuleNotFoundError: when ``qscheme`` is none of per_tensor / per_channel / per_group.
        """
        if self.__class__ is QuantizationSpec:
            logger.warning(
                f"quark.torch.config.config.{self.__class__.__name__} is deprecated and will be removed in a future release. Please use quark.torch.config.config.QTensorConfig instead."
            )

        # NOTE: for developers, every time a new dtype is added, please add the corresponding check for the new dtype.
        cast_only_dtype = self.dtype in ONLY_DTYPE_CHANGE
        if cast_only_dtype and self.observer_cls != PlaceholderObserver:
            raise ValueError(
                f"{self.dtype} only supports observer_cls=PlaceholderObserver (as only type casting is used), got observer_cls={self.observer_cls}."
            )

        if self.dtype not in ALL_DATA_TYPES:
            raise ValueError(f"The value dtype={self.dtype} is not among the supported dtypes {ALL_DATA_TYPES}.")

        # NOTE: for developers, every time a new observer is added, please add the corresponding check for the new observer.
        if self.observer_cls is not None and self.observer_cls not in OBSERVER_CLASSES:
            raise ValueError(
                f"The value observer_cls={self.observer_cls} is not among the supported observer_cls: {OBSERVER_CLASSES}."
            )

        # Integer and FP8 dtypes must state is_dynamic, observer_cls and qscheme.
        dtypes_needing_observer_setup = (
            Dtype.int8,
            Dtype.uint8,
            Dtype.int16,
            Dtype.uint16,
            Dtype.int4,
            Dtype.uint4,
            Dtype.int3,
            Dtype.int2,
            Dtype.int32,
            Dtype.fp8_e4m3,
            Dtype.fp8_e5m2,
        )
        if self.dtype in dtypes_needing_observer_setup:
            if self.is_dynamic is None:
                raise ValueError(
                    f"The field `is_dynamic` cannot be None when Dtype is {self.dtype.name} in QTensorConfig."
                )
            if self.observer_cls is None:
                raise ValueError(
                    f"The field `observer_cls` cannot be None when Dtype is {self.dtype.name} in QTensorConfig."
                )
            if self.qscheme is None:
                raise ValueError(
                    f"The field `qscheme` cannot be None when Dtype is {self.dtype.name} in QTensorConfig. Please reconfigure the quantization settings accordingly."
                )

        # These integer dtypes additionally need symmetric, round_method and scale_type.
        dtypes_needing_int_params = (
            Dtype.int8,
            Dtype.uint8,
            Dtype.int16,
            Dtype.uint16,
            Dtype.int4,
            Dtype.uint4,
            Dtype.int3,
            Dtype.int32,
        )
        if self.dtype in dtypes_needing_int_params:
            if self.symmetric is None:
                raise ValueError(
                    f"The field `symmetric` cannot be None when Dtype is {self.dtype.name} in QTensorConfig. Please reconfigure the quantization settings accordingly."
                )
            if self.round_method is None:
                raise ValueError(
                    f"The field `round_method` cannot be None when Dtype is {self.dtype.name} in QTensorConfig. Please reconfigure the quantization settings accordingly."
                )
            if self.scale_type is None:
                raise ValueError(
                    f"The field `scale_type` cannot be None when Dtype is {self.dtype.name} in QTensorConfig. Please reconfigure the quantization settings accordingly."
                )

        # CASE 1: will only init NonScaledFakeQuantize and PlaceholderObserver,
        # NOTE: quark/torch/quantization/tensor_quantize.py FakeQuantizeBase:get_fake_quantize
        # in this case NonScaledFakeQuantize is used, and only PlaceholderObserver() is initialised.
        # As a result, the quant forward func is quark.torch.kernel.non_scaled_fake_quantize
        # /torch/kernel/hw_emulation/hw_emulation_interface.py: def non_scaled_fake_quantize
        if self.dtype in USING_NON_SCALED_QUANT:
            # 1. In quark/torch/kernel/__init__.py: class NonScaledFakeQuantizeFunction
            #    During quantization the following are needed:
            #    input_tensor, quant_dtype, mx_element_dtype, axis, block_size, scale_calculation_mode
            # 2. Initialising PlaceholderObserver only needs qspec.dtype
            # Summary for QTensorConfig: dtype(r), mx_element_dtype(o), axis(r), group_size(r), scale_calculation_mode(o)
            if self.ch_axis is None:
                raise ValueError(
                    f"When using dtype={self.dtype}, quantization_spec.ch_axis must be specified. Got `ch_axis=None`."
                )
            if self.group_size is None:
                raise ValueError(
                    f"When using dtype={self.dtype}, quantization_spec.group_size must be specified. Got `group_size=None`."
                )
            if self.dtype == Dtype.mx and self.mx_element_dtype is None:
                raise ValueError(
                    f"When using dtype={self.dtype}, quantization_spec.mx_element_dtype must be specified. Got `mx_element_dtype=None`."
                )

            # Required fields first, optional ones last; any other field that differs from
            # its declared default is reported as having no effect.
            consumed_fields = ("dtype", "ch_axis", "group_size", "mx_element_dtype", "scale_calculation_mode")
            for each_field in fields(self):
                if each_field.name in consumed_fields:
                    continue
                value = getattr(self, each_field.name)
                default_value = each_field.default
                if value != default_value:
                    logger.warning(
                        f"When using dtype={self.dtype}, QTensorConfig.{each_field.name} will not take effect. Got {each_field.name}={value} but the default is {each_field.name}={default_value}."
                    )
            return

        # CASE 2: # NOTE only quantization_spec.dtype is needed
        # in quantization actually call: fake_quantize_with_dtype_convert
        if cast_only_dtype:
            consumed_fields = ("dtype", "observer_cls")
            for each_field in fields(self):
                if each_field.name in consumed_fields:
                    continue
                value = getattr(self, each_field.name)
                default_value = each_field.default
                if value != default_value:
                    logger.warning(
                        f"In {self.dtype} quant, QTensorConfig.{each_field.name} will not take effect. User supplied: {value} User should skip setting this field"
                    )
            return

        # CASE 3
        # NOTE: quark/torch/quantization/tensor_quantize.py FakeQuantizeBase:get_fake_quantize
        # we will init ScaledFakeQuantize and the corresponding observer
        # def scaled_fake_quantize, in Quark/quark/torch/kernel/hw_emulation/hw_emulation_interface.py
        # 1. fake_quantize_int: qscheme, axis (if channel/group), group_size, round_mode
        # 2. fake_quantize_fp8_e4m3: qscheme, axis (if channel/group), group_size
        # 3. fake_quantize_fp8_e5m2: qscheme, axis (if channel/group), group_size
        # 4. fake_quantize_fp4_fp6: qscheme, axis (if channel/group), group_size, quant_dtype(channel/group)
        assert self.observer_cls is not None, "Supplied QTensorConfig's observer_cls is None"
        assert self.qscheme is not None, "Supplied QTensorConfig's qscheme is None"

        if self.qscheme == QSchemeType.per_tensor:
            assert self.observer_cls in PER_TENSOR_OBSERVERS, (
                f"You select Tensor wise quant, the observer_cls you select is {self.observer_cls} not support tesnor wise quant."
            )
            return

        if self.qscheme == QSchemeType.per_channel:
            if self.observer_cls not in PER_CHANNEL_OBSERVERS:
                raise ValueError(
                    f"You select channel wise quant, the observer_cls you select is {self.observer_cls} not support channel wise quant."
                )
            if not isinstance(self.ch_axis, int):
                raise ValueError("You select channel wise quant, user must assigned int num to ch_axis.")
            return

        if self.qscheme == QSchemeType.per_group:
            if self.observer_cls not in PER_GROUP_OBSERVERS:
                raise ValueError(
                    f"The combination qscheme={self.qscheme} and observer_cls={self.observer_cls} is not supported. For per group quantization, please use an observer from {PER_GROUP_OBSERVERS}."
                )
            if not isinstance(self.ch_axis, int):
                raise ValueError(
                    f"Got ch_axis={self.ch_axis} in QTensorConfig initialization with qscheme={self.qscheme}. A correct positive integer value is required for per-group quantization."
                )
            if not isinstance(self.group_size, int):
                raise ValueError(
                    f"Got group_size={self.group_size} in QTensorConfig initialization with qscheme={self.qscheme}. A correct positive integer value is required for per-group quantization."
                )
            return

        # NOTE for developer: reached only for a qscheme that is none of the three handled above.
        raise ModuleNotFoundError(
            f"Please decide {self.observer_cls.__name__} belongs to which kind of quant (tensor/channel/group)."
        )

    def set_group_size(self, group_size: int) -> None:
        """
        Overwrite ``group_size`` after checking it is a positive integer or ``-1``.

        :param int group_size: New group size; ``-1`` means the group size equals the dimension size.
        :raises AssertionError: if ``group_size`` is not an ``int``, or is neither positive nor ``-1``.
        """
        is_valid = isinstance(group_size, int) and (group_size > 0 or group_size == -1)
        assert is_valid, (
            "Group size must be a positive integer or -1 (which means group size equals to dimension size)."
        )
        self.group_size = group_size

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this config into a dictionary of plain values.

        Enum-valued fields are stored by member name and ``observer_cls`` by class name; unset values stay ``None``.
        """

        def member_name(member: Any) -> str | None:
            # Enum member -> its name, leaving an unset (None) field as None.
            return None if member is None else member.name

        observer = self.observer_cls
        # TODO: qat_spec is not serialized (zero_point_type is not emitted here either).
        return {
            "dtype": self.dtype.name,
            "is_dynamic": self.is_dynamic,
            "qscheme": member_name(self.qscheme),
            "ch_axis": self.ch_axis,
            "group_size": self.group_size,
            "symmetric": self.symmetric,
            "round_method": member_name(self.round_method),
            "scale_type": member_name(self.scale_type),
            "scale_format": self.scale_format,
            "scale_calculation_mode": self.scale_calculation_mode,
            "mx_element_dtype": member_name(self.mx_element_dtype),
            "observer_cls": None if observer is None else observer.__name__,
            "is_scale_quant": self.is_scale_quant,
        }

    @classmethod
    def from_dict(cls: type[QTT], config_dict: dict[str, Any]) -> QTT:
        """
        Rebuild a config from a dictionary such as the one produced by :py:meth:`to_dict`.

        :param config_dict: Serialized config. ``dtype``, ``group_size`` and ``symmetric`` are read by direct
            indexing, as is ``is_dynamic`` (or the legacy key ``dynamic``) and ``ch_axis`` (or the legacy key ``axis``).
        :return: An instance of the class this method is called on.
        :raises KeyError: if one of the keys read by direct indexing is absent, or an enum name is unknown.
        """

        def enum_member(enum_cls: Any, key: str) -> Any:
            # Missing key or explicit None -> None, otherwise look the member up by name.
            raw_name = config_dict.get(key)
            return enum_cls[raw_name] if raw_name is not None else None

        dtype = Dtype[config_dict["dtype"]]
        mx_element_dtype = enum_member(Dtype, "mx_element_dtype")
        qscheme = enum_member(QSchemeType, "qscheme")
        round_method = enum_member(RoundType, "round_method")
        scale_type = enum_member(ScaleType, "scale_type")
        scale_format = config_dict.get("scale_format")
        scale_calculation_mode = config_dict.get("scale_calculation_mode")

        # TODO: Deprecate legacy configuration.
        # Accommodate the legacy (quark<1.0) export which used custom keys.
        is_dynamic = config_dict["is_dynamic"] if "is_dynamic" in config_dict else config_dict["dynamic"]
        ch_axis = config_dict["ch_axis"] if "ch_axis" in config_dict else config_dict["axis"]
        group_size = config_dict["group_size"]
        symmetric = config_dict["symmetric"]

        if "observer_cls" in config_dict:
            observer_name = config_dict["observer_cls"]
            if observer_name in OBSERVER_MAP:
                observer_cls = OBSERVER_MAP[observer_name]
            else:  # pragma: no cover
                logger.warning(
                    f"Unknown observer_cls={config_dict['observer_cls']}. Loading the QTensorConfig with observer_cls=PlaceholderObserver."
                )
                observer_cls = PlaceholderObserver
        else:  # pragma: no cover
            # quark<1.0 used not to save the `observer_cls` in `QTensorConfig.to_dict()`.
            observer_cls = PlaceholderObserver

        is_scale_quant = config_dict.get("is_scale_quant", False)

        return cls(
            dtype=dtype,
            is_dynamic=is_dynamic,
            qscheme=qscheme,
            ch_axis=ch_axis,
            group_size=group_size,
            symmetric=symmetric,
            round_method=round_method,
            scale_type=scale_type,
            scale_format=scale_format,
            scale_calculation_mode=scale_calculation_mode,
            mx_element_dtype=mx_element_dtype,
            observer_cls=observer_cls,  # type: ignore[arg-type]
            is_scale_quant=is_scale_quant,
        )

    def is_ocp_mxfp4(self) -> bool:
        """Return ``True`` when ``dtype`` is ``Dtype.fp4`` and every ``OCP_MXSpec.OCP_MX_SPEC_KWARGS`` entry equals the same-named field."""
        differs_from_mx_defaults = any(
            getattr(self, key) != value for key, value in OCP_MXSpec.OCP_MX_SPEC_KWARGS.items()
        )
        if differs_from_mx_defaults or self.dtype != Dtype.fp4:
            return False
        return True
[docs]
@dataclass
class QATSpec(BaseConfigImpl):
    """Empty base dataclass for QAT specs; ``TQTSpec`` derives from it."""

    pass
[docs]
@dataclass
class TQTSpec(QATSpec):
    """
    QAT spec for the Trained Quantization Thresholds (TQT) method, implementing https://arxiv.org/abs/1903.08066.

    :param Optional[TQTThresholdInitMeth] threshold_init_meth: Threshold initialization method. Default is ``None``.
    """

    threshold_init_meth: TQTThresholdInitMeth | None = None
[docs]
def load_pre_optimization_config_from_file(file_path: str) -> PreQuantOptConfig:
    """
    Load pre-optimization configuration from a JSON file.

    :param file_path: The path to the JSON file containing the pre-optimization configuration.
    :type file_path: str
    :return: The pre-optimization configuration.
    :rtype: PreQuantOptConfig
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file content is not valid JSON.
    :raises ValueError: If the configuration name is not recognized.
    """
    # JSON is UTF-8 by specification; decode explicitly instead of relying on
    # the platform's locale encoding, which differs between systems.
    with open(file_path, encoding="utf-8") as file:
        algo_config_info = json.load(file)
    return _load_pre_optimization_config_from_dict(algo_config_info)
[docs]
def load_quant_algo_config_from_file(file_path: str) -> AlgoConfig:
    """
    Load quantization algorithm configuration from a JSON file.

    :param file_path: The path to the JSON file containing the quantization algorithm configuration.
    :type file_path: str
    :return: The quantization algorithm configuration.
    :rtype: AlgoConfig
    :raises OSError: If the file cannot be opened.
    :raises json.JSONDecodeError: If the file content is not valid JSON.
    :raises ValueError: If the configuration name is not recognized.
    """
    # JSON is UTF-8 by specification; decode explicitly instead of relying on
    # the platform's locale encoding, which differs between systems.
    with open(file_path, encoding="utf-8") as file:
        algo_config_info = json.load(file)
    return _load_quant_algo_config_from_dict(algo_config_info)
def _load_pre_optimization_config_from_dict(pre_optimization_config_dict: dict[str, Any]) -> PreQuantOptConfig:
    """
    Load pre-optimization configuration from a dictionary.

    The input dictionary is not modified: the deprecated keys are dropped from a shallow copy, and that
    copy is what the selected config class receives.

    :param pre_optimization_config_dict: A dictionary containing the pre-optimization configuration.
    :type pre_optimization_config_dict: Dict[str, Any]
    :return: The pre-optimization configuration.
    :rtype: PreQuantOptConfig
    :raises KeyError: If the dictionary has no ``"name"`` entry.
    :raises ValueError: If the configuration name is not recognized.
    """
    # Shallow copy so that removing the legacy keys below does not leak back
    # into the caller's dictionary.
    config_dict = dict(pre_optimization_config_dict)

    # Deprecate old settings for GQA
    config_dict.pop("num_attention_heads", None)
    config_dict.pop("num_key_value_heads", None)

    name = config_dict["name"]
    if name == "rotation":
        return cast(PreQuantOptConfig, RotationConfig.from_dict(config_dict))
    elif name == "quarot":
        return cast(PreQuantOptConfig, QuaRotConfig.from_dict(config_dict))
    elif name == "smooth":
        return cast(PreQuantOptConfig, SmoothQuantConfig.from_dict(config_dict))
    else:
        raise ValueError(f"Unknown algorithm name {name}")
def _load_quant_algo_config_from_dict(algo_config_dict: dict[str, Any]) -> AlgoConfig:
    """
    Load quantization algorithm configuration from a dictionary.

    The input dictionary is not modified: the deprecated keys are dropped from a shallow copy, and that
    copy is what the selected config class receives.

    :param algo_config_dict: A dictionary containing the quantization algorithm configuration.
    :type algo_config_dict: Dict[str, Any]
    :return: The quantization algorithm configuration.
    :rtype: AlgoConfig
    :raises KeyError: If the dictionary has no ``"name"`` entry.
    :raises ValueError: If the configuration name is not recognized.
    """
    # Shallow copy so that removing the legacy keys below does not leak back
    # into the caller's dictionary.
    config_dict = dict(algo_config_dict)

    # Deprecate old settings for GQA
    config_dict.pop("num_attention_heads", None)
    config_dict.pop("num_key_value_heads", None)

    name = config_dict["name"]
    if name == "rotation":
        return cast(AlgoConfig, RotationConfig.from_dict(config_dict))
    elif name == "quarot":
        return cast(AlgoConfig, QuaRotConfig.from_dict(config_dict))
    elif name == "smooth":
        return cast(AlgoConfig, SmoothQuantConfig.from_dict(config_dict))
    elif name == "awq":
        return cast(AlgoConfig, AWQConfig.from_dict(config_dict))
    elif name == "gptq":  # pragma: no cover
        return cast(AlgoConfig, GPTQConfig.from_dict(config_dict))
    elif name == "gptaq":  # pragma: no cover
        return cast(AlgoConfig, GPTAQConfig.from_dict(config_dict))
    elif name == "autosmoothquant":  # pragma: no cover
        return cast(AlgoConfig, AutoSmoothQuantConfig.from_dict(config_dict))
    elif name == "qronos":  # pragma: no cover
        return cast(AlgoConfig, QronosConfig.from_dict(config_dict))
    else:
        raise ValueError(f"Unknown algorithm name {name}")
[docs]
@dataclass
class PreQuantOptConfig(BaseAlgoConfig):
    """Empty dataclass used as the return type of the pre-optimization configuration loaders."""

    pass
[docs]
@dataclass
class AlgoConfig(BaseAlgoConfig):
    """Empty base dataclass for quantization algorithm configurations such as ``SmoothQuantConfig`` and ``RotationConfig``."""

    pass
[docs]
@dataclass
class SmoothQuantConfig(AlgoConfig):
    """
    A data class that defines the specifications for Smooth Quantization.

    :param str name: The name of the configuration, typically used to identify different quantization settings. Default is ``"smooth"``.
    :param float alpha: The factor of adjustment in the quantization formula, influencing how aggressively weights are quantized. Default is ``1``.
    :param float scale_clamp_min: The minimum scaling factor to be used during quantization, preventing the scale from becoming too small. Default is ``1e-3``.
    :param List[Dict[str, Any]] scaling_layers: Specific settings for scaling layers, allowing customization of quantization parameters for different layers within the model. Default is an empty list.
    :param str model_decoder_layers: Specifies any particular decoder layers in the model that might have unique quantization requirements. Default is ``""``.

    The parameter ``scaling_layers`` can be left to an empty list (default), in which case they will be automatically detected.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import SmoothQuantConfig

        scaling_layers=[
            {
                "prev_op": "input_layernorm",
                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                "inp": "self_attn.q_proj",
                "module2inspect": "self_attn"
            },
            {
                "prev_op": "post_attention_layernorm",
                "layers": ["mlp.gate_proj", "mlp.up_proj"],
                "inp": "mlp.gate_proj",
                "module2inspect": "mlp"
            }
        ]

        smoothquant_config = SmoothQuantConfig(
            scaling_layers=scaling_layers,
            model_decoder_layers="model.layers"
        )
    """

    name: str = "smooth"
    alpha: float = 1
    scale_clamp_min: float = 1e-3
    # A default_factory is used so that each instance gets its own list.
    scaling_layers: list[dict[str, Any]] = field(default_factory=list)
    model_decoder_layers: str = ""
[docs]
@dataclass
class RotationConfig(AlgoConfig):
    """
    A data class that defines the specifications for the rotation algorithms.

    :param Dict[str, List[Dict[str, Any]]] scaling_layers: Required. Specifies the layer names where ``R1`` rotations must be applied if chosen, or where smoothing scales must be trained, if ``train_smooth=True``. It is a dictionary with keys ``"first_layer"``, ``"middle_layers"`` and ``"last_layer"``, each mapping to a list of dictionaries specifying:

        - ``"prev_modules"``: The list of previous modules the activation rotation may be fused into. **This is optional/unused for online R1**.
        - ``"norm_module"``: The normalization layer that is in between the modules in ``"prev_modules"`` and ``"next_modules"``. The normalization weight will typically need to be merged first into the modules of ``"next_modules"`` in order to permute the activation rotation with the normalization, and fuse the activation rotation in the modules of ``"prev_modules"``. **This is optional/unused for online R1**.
        - ``"next_modules"``: The list of modules to fuse inverse rotation into, on the input features dimension. This corresponds to weight rotation on the input features dimension. **This is optional/unused for online R1**.
        - ``"target_modules"``: If specified, the list of modules to apply online rotation on, this may be only on a subset of ``"next_modules"``. This is for example useful for MOE models, in case the experts gate/router is not quantized (skipping online rotation for it). **Useful only for online R1 rotations, optional.** If not provided, ``"next_modules"`` is used instead as the list of modules to apply rotation on.

    :param str name: The name of the configuration, typically used to identify different rotation settings. Default is ``"rotation"``.
    :param bool r1: Whether to apply ``R1`` rotation. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. Defaults to ``True``.
    :param bool r2: Whether to apply ``R2`` rotation. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. Defaults to ``False``.
    :param bool r3: Whether to apply ``R3`` rotation. It is only useful when using KV cache quantization. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. Defaults to ``False``.
    :param bool r4: Whether to apply ``R4`` rotation. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. Defaults to ``False``.
    :param Optional[int] rotation_size: The size of rotations to apply on activations/weights. By default, the activation last dimension (e.g. ``hidden_size``), or weight input/output channel dimension is used as rotation size. In case the parameter ``rotation_size`` is specified, smaller rotations of size ``(rotation_size, rotation_size)`` are applied per-block. Defaults to ``None``.
    :param Optional[bool] random: Deprecated. Use ``random_r1`` and ``random_r2`` instead. When set, its value overrides both ``random_r1`` and ``random_r2``. Defaults to ``None``.
    :param bool random_r1: A boolean flag indicating whether ``R1`` should be a random Hadamard matrix. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. This can be useful for data augmentation purposes where random rotations may be required. Default is ``False``.
    :param bool random_r2: A boolean flag indicating whether ``R2`` should be a random Hadamard matrix. See `SpinQuant paper <https://arxiv.org/abs/2405.16406>`__ for details. This can be useful for data augmentation purposes where random rotations may be required. Default is ``False``. ``random_r1`` and ``random_r2`` are only relevant if we are using Hadamard rotations for ``R1`` and ``R2``.
    :param str backbone: A string indicating the path to the model backbone. Default is ``"model"``.
    :param str model_decoder_layers: A string indicating the path to the list of decoder layers. Default is ``"model.layers"``.
    :param str v_proj: A string indicating the path to the v projection layer, starting from the decoder layer it is in. Default is ``"self_attn.v_proj"``.
    :param str o_proj: A string indicating the path to the o projection layer, starting from the decoder layer it is in. Default is ``"self_attn.o_proj"``.
    :param str self_attn: A string indicating the path to the self attention block, starting from the decoder layer it is in. Default is ``"self_attn"``.
    :param str mlp: A string indicating the path to the multilayer perceptron layer, starting from the decoder layer it is in. Default is ``"mlp"``.
    :param Optional[bool] online_r1_rotation: Whether the activation rotation ``R1`` should be kept online instead of being fused into the preceding layer weights. In case ``online_r1_rotation=False``, a single ``R1`` rotation is shared across all layers, whereas specialized rotations can be used per linear layer in case the activation rotation is done online. Defaults to ``False``, i.e. ``R1`` is fused offline. If ``R1`` is not used, it must be left to ``None``.
    :param Optional[OnlineRotationConfig] online_config: Configuration specific to online rotations. Refer to :py:class:`.OnlineRotationConfig`. Defaults to ``OnlineRotationConfig(shared_parallel=False)`` for online trainable ``R1``, otherwise defaults to ``None`` (no effect).
    :param bool trainable: Whether to insert trainable rotations. Default is ``False``.
    :param Optional[bool] train_smooth: Whether to train SmoothQuant scales. The trainable transform is then ``T = D @ R`` (or ``T = R @ D``) with the rotation ``R`` and the SmoothQuant scales ``D`` (seen as a 1D vector, or a diagonal matrix). This parameter is similar to the approach in `OSTQuant <https://arxiv.org/abs/2501.13987>`__ paper. Defaults to ``None``. In case ``trainable=True``, defaults to ``False``.
    :param Optional[List[str]] smooth_positions: A list of rotation positions to train smoothing scales for. Acceptable list values are ``"r1"``, ``"r2"``, ``"r4"``. Defaults to ``None``. In case ``trainable=True`` and ``train_smooth`` is left to ``None``, it is set to ``[]``.

    Example for llama model, offline rotations:

    .. code-block:: python

        from quark.torch.quantization.config.config import RotationConfig

        scaling_layers = {
            "first_layer": [
                {
                    "prev_modules": ["model.embed_tokens"],
                    "norm_module": "model.layers.layer_id.input_layernorm",
                    "next_modules": [
                        "model.layers.layer_id.self_attn.q_proj",
                        "model.layers.layer_id.self_attn.k_proj",
                        "model.layers.layer_id.self_attn.v_proj",
                    ],
                },
                {
                    "prev_modules": ["model.layers.layer_id.self_attn.o_proj"],
                    "norm_module": "model.layers.layer_id.post_attention_layernorm",
                    "next_modules": ["model.layers.layer_id.mlp.up_proj", "model.layers.layer_id.mlp.gate_proj"],
                },
            ],
            "middle_layers": [
                {
                    "prev_modules": ["model.layers.pre_layer_id.mlp.down_proj"],
                    "norm_module": "model.layers.layer_id.input_layernorm",
                    "next_modules": [
                        "model.layers.layer_id.self_attn.q_proj",
                        "model.layers.layer_id.self_attn.k_proj",
                        "model.layers.layer_id.self_attn.v_proj",
                    ],
                },
                {
                    "prev_modules": ["model.layers.layer_id.self_attn.o_proj"],
                    "norm_module": "model.layers.layer_id.post_attention_layernorm",
                    "next_modules": ["model.layers.layer_id.mlp.up_proj", "model.layers.layer_id.mlp.gate_proj"],
                },
            ],
            "last_layer": [
                {
                    "prev_modules": ["model.layers.layer_id.mlp.down_proj"],
                    "norm_module": "model.norm",
                    "next_modules": ["lm_head"],
                }
            ],
        }

        rotation_config = RotationConfig(
            model_decoder_layers="model.layers",
            v_proj="self_attn.v_proj",
            o_proj="self_attn.o_proj",
            self_attn="self_attn",
            mlp="mlp",
            scaling_layers=scaling_layers,
            r1=True,
            r2=True,
        )

    Example for llama model, online rotations:

    .. code-block:: python

        from quark.torch.quantization.config.config import RotationConfig

        scaling_layers = {
            "first_layer": [
                {
                    "target_modules": [
                        "model.layers.layer_id.self_attn.q_proj",
                        "model.layers.layer_id.self_attn.k_proj",
                        "model.layers.layer_id.self_attn.v_proj",
                    ]
                },
                {
                    "target_modules": ["model.layers.layer_id.mlp.up_proj", "model.layers.layer_id.mlp.gate_proj"]
                },
            ],
            "middle_layers": [
                {
                    "target_modules": [
                        "model.layers.layer_id.self_attn.q_proj",
                        "model.layers.layer_id.self_attn.k_proj",
                        "model.layers.layer_id.self_attn.v_proj",
                    ]
                },
                {
                    "target_modules": ["model.layers.layer_id.mlp.up_proj", "model.layers.layer_id.mlp.gate_proj"]
                },
            ],
            "last_layer": [
                {
                    "target_modules": [],
                }
            ],
        }

        rotation_config = RotationConfig(
            model_decoder_layers="model.layers",
            v_proj="self_attn.v_proj",
            o_proj="self_attn.o_proj",
            self_attn="self_attn",
            mlp="mlp",
            scaling_layers=scaling_layers,
            r1=True,
            r2=True,
            online_r1_rotation=True,
            rotation_size=32,
        )
    """

    scaling_layers: dict[str, list[dict[str, Any]]]
    name: str = "rotation"
    r1: bool = True
    r2: bool = False
    r3: bool = False
    r4: bool = False
    rotation_size: int | None = None
    random: bool | None = None  # TODO: deprecated, remove in 0.12
    random_r1: bool = False
    random_r2: bool = False
    backbone: str = "model"
    model_decoder_layers: str = "model.layers"
    v_proj: str = "self_attn.v_proj"
    o_proj: str = "self_attn.o_proj"
    self_attn: str = "self_attn"
    mlp: str = "mlp"
    online_r1_rotation: bool | None = None
    online_config: OnlineRotationConfig | None = None
    trainable: bool = False
    train_smooth: bool | None = None
    smooth_positions: list[str] | None = None

    def __post_init__(self) -> None:
        """
        Resolve ``None`` defaults and reject unsupported option combinations.

        :raises NotImplementedError: If a random ``R1``/``R2`` is requested together with a custom ``rotation_size``, or if ``trainable=True`` is combined with ``r3=True``.
        :raises ValueError: If ``online_r1_rotation`` is set while ``r1`` is disabled, if ``online_config`` is given without online trainable ``R1``, if ``train_smooth`` is enabled without ``trainable``, or if ``train_smooth`` is enabled while ``smooth_positions`` is ``None``.
        """
        # The deprecated `random` flag, when given, overrides both per-rotation flags.
        if self.random is not None:
            logger.warning(
                f"quark.torch.config.config.RotationConfig argument `random` is deprecated and will be removed in v0.12, please use `random_r1` and `random_r2` instead. Got `random={self.random}`. Setting `random_r1={self.random}` and `random_r2={self.random}`."
            )
            self.random_r1 = self.random
            self.random_r2 = self.random

        if (self.random_r1 or self.random_r2) and self.rotation_size is not None:
            raise NotImplementedError(
                f"random_r1=True or random_r2=True along with a custom rotation_size={self.rotation_size} is not supported at the moment in RotationConfig. Please open an issue."
            )

        if self.r1:
            # With R1 enabled, an unset `online_r1_rotation` means "fused offline".
            if self.online_r1_rotation is None:
                self.online_r1_rotation = False
        elif self.online_r1_rotation is not None:
            # Without R1 the option must stay unset (None), whatever its value.
            raise ValueError(
                f"The configuration online_r1_rotation={self.online_r1_rotation} has no effect along r1={self.r1}."
            )

        if self.trainable and self.r3:
            raise NotImplementedError("trainable=True along `r3=True` is not implemented.")

        if self.online_config is not None:
            # An explicit online configuration is only accepted for online, trainable R1.
            if not self.online_r1_rotation or not self.r1 or not self.trainable:
                raise ValueError(
                    f"Got online_config={self.online_config}, r1={self.r1}, trainable={self.trainable}, for which online_config={self.online_config} has no effect."
                )
        elif self.trainable and self.r1 and self.online_r1_rotation:
            self.online_config = OnlineRotationConfig(shared_parallel=False)

        # For trainable rotations an unset `train_smooth` resolves to "no smoothing";
        # `smooth_positions` is reset to an empty list in that same case.
        if self.trainable and self.train_smooth is None:
            self.train_smooth = False
            self.smooth_positions = []

        if self.train_smooth and not self.trainable:
            raise ValueError(f"train_smooth={self.train_smooth} along trainable={self.trainable} is not supported.")

        if self.train_smooth and self.smooth_positions is None:
            raise ValueError(
                "RotationConfig.smooth_positions needs to be set in case RotationConfig.train_smooth=True, got None."
            )

    @classmethod
    def from_dict(cls, rotation_dict: dict[str, Any]) -> RotationConfig:
        """
        Build a :py:class:`.RotationConfig` from a dictionary of constructor arguments.

        If ``rotation_dict["online_config"]`` is itself a dictionary, it is converted into an
        :py:class:`.OnlineRotationConfig` first. The dictionary passed by the caller is left unmodified.

        :param Dict[str, Any] rotation_dict: Keyword arguments accepted by the constructor.
        :return: The corresponding configuration instance.
        """
        online_config = rotation_dict.get("online_config")
        if isinstance(online_config, dict):
            # Work on a shallow copy so that the caller's dictionary is not mutated.
            rotation_dict = {**rotation_dict, "online_config": OnlineRotationConfig(**online_config)}
        return cls(**rotation_dict)
[docs]
@dataclass
class OnlineRotationConfig:
    """
    Options specific to online rotations.

    :param bool shared_parallel: Required flag with no default value.
    """

    # NOTE(review): the exact effect of this flag is not visible in this file section - confirm against its consumer.
    shared_parallel: bool
@dataclass
class QuaRotConfig(AlgoConfig):
    """
    Deprecated rotation configuration; instantiating this exact class logs a deprecation warning
    pointing to ``RotationConfig``.

    ``scaling_layers`` is required; every other field has a default value.
    """

    # TODO: deprecated, remove in 0.12 release.
    scaling_layers: dict[str, list[dict[str, Any]]]
    name: str = "quarot"
    r1: bool = True
    r2: bool = True
    r3: bool = True
    r4: bool = True
    rotation_size: int | None = None
    random_r1: bool = False
    random_r2: bool = False
    optimized_rotation_path: str | None = None
    backbone: str = "model"
    model_decoder_layers: str = "model.layers"
    v_proj: str = "self_attn.v_proj"
    o_proj: str = "self_attn.o_proj"
    self_attn: str = "self_attn"
    mlp: str = "mlp"

    def __post_init__(self) -> None:
        """
        Emit the deprecation warning and reject options that cannot be combined with a custom ``rotation_size``.

        :raises NotImplementedError: If ``rotation_size`` is set together with a random ``R1``/``R2`` or with ``optimized_rotation_path``.
        """
        # Only warn for QuaRotConfig itself, not for classes deriving from it.
        if self.__class__ is QuaRotConfig:
            logger.warning(
                "quark.torch.config.config.QuaRotConfig is deprecated and will be removed in AMD Quark v0.12. Please use `quark.torch.quantization.config.config.RotationConfig` instead."
            )

        has_custom_size = self.rotation_size is not None

        if has_custom_size and (self.random_r1 or self.random_r2):
            raise NotImplementedError(
                f"random_r1=True or random_r2=True along with a custom rotation_size={self.rotation_size} is not supported at the moment in QuaRotConfig. Please open an issue."
            )

        if has_custom_size and self.optimized_rotation_path is not None:
            raise NotImplementedError(
                f"Using a preset optimized_rotation_path={self.optimized_rotation_path} along with a custom rotation_size={self.rotation_size} is not supported. Please open an issue."
            )
[docs]
@dataclass
class AutoSmoothQuantConfig(AlgoConfig):
    """
    Settings for the AutoSmoothQuant algorithm.

    :param str name: Identifier of this algorithm configuration. Default is ``"autosmoothquant"``.
    :param Optional[List[Dict[str, Any]]] scaling_layers: Per-layer scaling settings, see the example below. Default is ``None``.
    :param Optional[str] model_decoder_layers: Path to the decoder layers of the model, e.g. ``"model.layers"``. Default is ``None``.
    :param Optional[str] compute_scale_loss: Loss used to select the best scale, ``"MSE"`` or ``"MAE"``. Default is ``"MSE"``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import AutoSmoothQuantConfig

        scaling_layers = [
            {
                "prev_op": "input_layernorm",
                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                "inp": "self_attn.q_proj",
                "module2inspect": "self_attn"
            },
            {
                "prev_op": "self_attn.v_proj",
                "layers": ["self_attn.o_proj"],
                "inp": "self_attn.o_proj"
            },
            {
                "prev_op": "post_attention_layernorm",
                "layers": ["mlp.gate_proj", "mlp.up_proj"],
                "inp": "mlp.gate_proj",
                "module2inspect": "mlp"
            },
            {
                "prev_op": "mlp.up_proj",
                "layers": ["mlp.down_proj"],
                "inp": "mlp.down_proj"
            }
        ]

        autosmoothquant_config = AutoSmoothQuantConfig(
            model_decoder_layers="model.layers",
            scaling_layers=scaling_layers
        )
    """

    name: str = "autosmoothquant"
    scaling_layers: list[dict[str, Any]] | None = None
    model_decoder_layers: str | None = None
    compute_scale_loss: str | None = "MSE"
[docs]
@dataclass
class AWQConfig(AlgoConfig):
    """
    Settings for Activation-aware Weight Quantization (AWQ).

    :param str name: Identifier of this algorithm configuration. Default is ``"awq"``.
    :param List[Dict[str, Any]] scaling_layers: Per-layer scaling settings. Default is an empty list, in which case the scaling layers are detected automatically.
    :param str model_decoder_layers: Path to the decoder layers of the model, e.g. ``"model.layers"``. Default is ``""``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import AWQConfig

        scaling_layers = [
            {
                "prev_op": "input_layernorm",
                "layers": ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"],
                "inp": "self_attn.q_proj",
                "module2inspect": "self_attn"
            },
            {
                "prev_op": "post_attention_layernorm",
                "layers": ["mlp.gate_proj", "mlp.up_proj"],
                "inp": "mlp.gate_proj",
                "module2inspect": "mlp"
            },
        ]

        awq_config = AWQConfig(
            model_decoder_layers="model.layers",
            scaling_layers=scaling_layers
        )
    """

    name: str = "awq"
    # A fresh list per instance; an empty list triggers automatic detection.
    scaling_layers: list[dict[str, Any]] = field(default_factory=list)
    # `str()` evaluates to the empty string.
    model_decoder_layers: str = field(default_factory=str)
[docs]
@dataclass
class GPTQConfig(AlgoConfig):
    """
    Settings for Accurate Post-Training Quantization for Generative Pre-trained Transformers (GPTQ).

    :param str name: Identifier of this algorithm configuration. Default is ``"gptq"``.
    :param int block_size: GPTQ divides the columns into blocks of size ``block_size`` and quantizes each block separately. Default is ``128``.
    :param float damp_percent: The percentage used to dampen the quantization effect, aiding in the maintenance of accuracy post-quantization. Default is ``0.01``.
    :param bool desc_act: Indicates whether descending activation is used, typically to enhance model performance with quantization. Default is ``True``.
    :param bool static_groups: Specifies whether the order of groups for quantization is static or can be dynamically adjusted. Default is ``True``. Quark export only supports ``static_groups=True``.
    :param List[str] inside_layer_modules: Names of the modules inside each decoder layer that GPTQ handles. Default is an empty list.
    :param str model_decoder_layers: Path to the decoder layers of the model, e.g. ``"model.layers"``. Default is ``""``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import GPTQConfig

        gptq_config = GPTQConfig(
            inside_layer_modules=[
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.q_proj",
                "self_attn.o_proj",
                "mlp.up_proj",
                "mlp.gate_proj",
                "mlp.down_proj"
            ],
            model_decoder_layers="model.layers"
        )
    """

    name: str = "gptq"
    block_size: int = 128
    damp_percent: float = 0.01
    desc_act: bool = True
    static_groups: bool = True
    inside_layer_modules: list[str] = field(default_factory=list)
    # `str()` evaluates to the empty string.
    model_decoder_layers: str = field(default_factory=str)

    def __post_init__(self) -> None:
        """
        Reject the unsupported ``desc_act=True`` / ``static_groups=False`` combination.

        :raises ValueError: If ``desc_act`` is enabled while ``static_groups`` is disabled.
        """
        combination_supported = self.static_groups or not self.desc_act
        if not combination_supported:
            raise ValueError(
                "AMD Quark does not support using GPTQ with `desc_act=True` and `static_groups=False`. Please use `static_groups=True`, or disable `desc_act`."
            )
[docs]
@dataclass
class GPTAQConfig(AlgoConfig):
    """
    Settings for the GPTAQ post-training quantization algorithm.

    :param str name: Identifier of this algorithm configuration. Default is ``"gptaq"``.
    :param int block_size: GPTAQ divides the columns into blocks of size ``block_size`` and quantizes each block separately. Default is ``128``.
    :param float damp_percent: The percentage used to dampen the quantization effect, aiding in the maintenance of accuracy post-quantization. Default is ``0.01``.
    :param float alpha: GPTAQ-specific coefficient. Default is ``0.25``.
    :param bool desc_act: Indicates whether descending activation is used, typically to enhance model performance with quantization. Default is ``True``.
    :param bool static_groups: Specifies whether the order of groups for quantization is static or can be dynamically adjusted. Default is ``True``. Quark export only supports ``static_groups=True``.
    :param List[str] inside_layer_modules: Names of the modules inside each decoder layer that GPTAQ handles. Default is an empty list.
    :param str model_decoder_layers: Path to the decoder layers of the model, e.g. ``"model.layers"``. Default is ``""``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import GPTAQConfig

        gptaq_config = GPTAQConfig(
            inside_layer_modules=[
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.q_proj",
                "self_attn.o_proj",
                "mlp.up_proj",
                "mlp.gate_proj",
                "mlp.down_proj"
            ],
            model_decoder_layers="model.layers"
        )
    """

    name: str = "gptaq"
    block_size: int = 128
    damp_percent: float = 0.01
    # NOTE(review): the precise role of `alpha` in GPTAQ is not visible in this file section - confirm against the algorithm implementation.
    alpha: float = 0.25
    desc_act: bool = True
    static_groups: bool = True
    inside_layer_modules: list[str] = field(default_factory=list)
    # `str()` evaluates to the empty string.
    model_decoder_layers: str = field(default_factory=str)

    def __post_init__(self) -> None:
        """
        Reject the unsupported ``desc_act=True`` / ``static_groups=False`` combination.

        :raises ValueError: If ``desc_act`` is enabled while ``static_groups`` is disabled.
        """
        if self.static_groups or not self.desc_act:
            return
        raise ValueError(
            "AMD Quark does not support using GPTAQ with `desc_act=True` and `static_groups=False`. Please use `static_groups=True`, or disable `desc_act`."
        )
[docs]
@dataclass
class QronosConfig(AlgoConfig):
    """
    Configuration for Qronos, an advanced post-training quantization algorithm. Implemented as proposed in https://arxiv.org/pdf/2505.11695

    :param List[str] inside_layer_modules: Required. Lists the names of internal layer modules within the model that require specific quantization handling.
    :param str model_decoder_layers: Required. Path to the decoder layers of the model, e.g. ``"model.layers"``.
    :param str name: The configuration name. Default is ``"qronos"``.
    :param int block_size: Qronos divides the columns into blocks of size ``block_size`` and quantizes each block separately. Must be strictly positive. Default is ``128``.
    :param bool desc_act: Indicates whether descending activation is used, typically to enhance model performance with quantization. Default is ``True``.
    :param bool static_groups: Specifies whether the order of groups for quantization is static or can be dynamically adjusted. Default is ``True``. Quark export only supports ``static_groups=True``.
    :param float alpha: Dampening factor for numerical stability during matrix inversions. Default is ``1e-3``.
    :param float beta: Stabilisation factor for Cholesky decomposition. Default is ``1e4``.

    Example:

    .. code-block:: python

        from quark.torch.quantization.config.config import QronosConfig

        qronos_config = QronosConfig(
            inside_layer_modules=[
                "self_attn.k_proj",
                "self_attn.v_proj",
                "self_attn.q_proj",
                "self_attn.o_proj",
                "mlp.up_proj",
                "mlp.gate_proj",
                "mlp.down_proj"
            ],
            model_decoder_layers="model.layers"
        )
    """

    inside_layer_modules: list[str]
    model_decoder_layers: str
    name: str = "qronos"
    block_size: int = 128
    desc_act: bool = True
    static_groups: bool = True
    # NOTE(review): the documentation previously advertised 1e-6 as the default - confirm which value is intended.
    alpha: float = 1e-3
    beta: float = 1e4

    def __post_init__(self) -> None:
        """
        Validate the option combination and the block size.

        :raises ValueError: If ``desc_act`` is enabled while ``static_groups`` is disabled, or if ``block_size`` is not strictly positive.
        """
        if self.desc_act and not self.static_groups:
            raise ValueError(
                "AMD Quark does not support using Qronos with `desc_act=True` and `static_groups=False`. Please use `static_groups=True`."
            )
        if self.block_size <= 0:
            # `block_size` is the size of each block, not a block count.
            raise ValueError(f"block_size must be positive, got {self.block_size}.")
@dataclass(eq=True)
class QuantizationSpec(QTensorConfig):
    """Subclass of :py:class:`.QTensorConfig` that adds no fields or methods of its own."""
@dataclass(eq=True)
class QuantizationConfig(QLayerConfig):
    """Subclass of :py:class:`.QLayerConfig` that adds no fields or methods of its own."""
@dataclass(eq=True)
class Config(QConfig):
    """Subclass of :py:class:`.QConfig` that adds no fields or methods of its own."""