Source code for quark.torch.quantization.config.template

#
# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
from __future__ import annotations

from typing import Any

import torch.nn as nn

from quark.shares.utils.log import ScreenLogger
from quark.torch.export.main_export.quant_config_parser import get_layer_quant_config
from quark.torch.quantization.config.algo_configs import get_algo_config
from quark.torch.quantization.config.config import (
    AlgoConfig,
    AutoSmoothQuantConfig,
    AWQConfig,
    BFP16Spec,
    Config,
    FP8E4M3PerChannelSpec,
    FP8E4M3PerTensorSpec,
    GPTAQConfig,
    GPTQConfig,
    Int4PerChannelSpec,
    Int4PerGroupSpec,
    Int8PerTensorSpec,
    MX6Spec,
    OCP_MXFP4Spec,
    OCP_MXFP6E2M3Spec,
    OCP_MXFP6E3M2Spec,
    QLayerConfig,
    QronosConfig,
    RotationConfig,
    SmoothQuantConfig,
    Uint4PerChannelSpec,
    Uint4PerGroupSpec,
)

logger = ScreenLogger(__name__)


[docs] class QuantizationScheme: """Abstract base class for quantization schemes.""" def __init__(self, config: QLayerConfig): self._config = config @property def config(self) -> QLayerConfig: return self._config
[docs] class Int4WeightOnlyScheme(QuantizationScheme): """Scheme for INT4 weight-only quantization.""" def __init__(self, group_size: int): self.group_size = group_size @property def config(self) -> QLayerConfig: weight_spec = Int4PerGroupSpec( ch_axis=-1, is_dynamic=False, scale_type="float", group_size=self.group_size ).to_quantization_spec() return QLayerConfig(weight=weight_spec)
[docs] class Int4WeightOnlyPerChannelScheme(QuantizationScheme): """Scheme for INT4 weight-only per-channel quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: weight_spec = Int4PerChannelSpec(is_dynamic=False, ch_axis=0).to_quantization_spec() return QLayerConfig(weight=weight_spec)
[docs] class Uint4WeightOnlyScheme(QuantizationScheme): """Scheme for UINT4 weight-only quantization.""" def __init__(self, group_size: int): self.group_size = group_size @property def config(self) -> QLayerConfig: weight_spec = Uint4PerGroupSpec( ch_axis=-1, is_dynamic=False, scale_type="float", group_size=self.group_size ).to_quantization_spec() return QLayerConfig(weight=weight_spec)
[docs] class Uint4WeightOnlyPerChannelScheme(QuantizationScheme): """Scheme for UINT4 weight-only per-channel quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: weight_spec = Uint4PerChannelSpec(is_dynamic=False, ch_axis=0).to_quantization_spec() return QLayerConfig(weight=weight_spec)
[docs] class Int8Scheme(QuantizationScheme): """Scheme for INT8 weight and activation input quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = Int8PerTensorSpec( observer_method="min_max", symmetric=True, scale_type="float", round_method="half_even", is_dynamic=False ).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec)
[docs] class FP8Scheme(QuantizationScheme): """Scheme for FP8 quantization (e4m3 format).""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = FP8E4M3PerTensorSpec( observer_method="min_max", scale_type="float", is_dynamic=False ).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec)
[docs] class MXFP4Scheme(QuantizationScheme): """Scheme for MXFP4 quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = OCP_MXFP4Spec(ch_axis=-1, is_dynamic=False).to_quantization_spec() spec_dynamic = OCP_MXFP4Spec(ch_axis=-1, is_dynamic=True).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec_dynamic)
[docs] class MXFP6E3M2Scheme(QuantizationScheme): """Scheme for MXFP6E3M2 quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = OCP_MXFP6E3M2Spec(ch_axis=-1, is_dynamic=False).to_quantization_spec() spec_dynamic = OCP_MXFP6E3M2Spec(ch_axis=-1, is_dynamic=True).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec_dynamic)
[docs] class MXFP6E2M3Scheme(QuantizationScheme): """Scheme for MXFP6E2M3 quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = OCP_MXFP6E2M3Spec(ch_axis=-1, is_dynamic=False).to_quantization_spec() spec_dynamic = OCP_MXFP6E2M3Spec(ch_axis=-1, is_dynamic=True).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec_dynamic)
[docs] class MXFP4_MXFP6E2M3Scheme(QuantizationScheme): """Scheme for MXFP4 weight and MXFP6E2M3 activation input quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: weight_spec = OCP_MXFP4Spec(ch_axis=-1, is_dynamic=False).to_quantization_spec() input_spec = OCP_MXFP6E2M3Spec(ch_axis=-1, is_dynamic=True).to_quantization_spec() return QLayerConfig(weight=weight_spec, input_tensors=input_spec)
[docs] class MX6Scheme(QuantizationScheme): """Scheme for MX6 quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = MX6Spec(ch_axis=-1, block_size=32).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec)
[docs] class BFP16Scheme(QuantizationScheme): """Scheme for BFP16 quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: spec = BFP16Spec(ch_axis=-1).to_quantization_spec() return QLayerConfig(weight=spec, input_tensors=spec)
[docs] class MXFP4_FP8Scheme(QuantizationScheme): """Scheme for MXFP4 weight and FP8 activation input quantization.""" def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: weight_spec = OCP_MXFP4Spec(ch_axis=-1, is_dynamic=False, scale_calculation_mode="even").to_quantization_spec() input_spec = FP8E4M3PerTensorSpec(observer_method="min_max", is_dynamic=False).to_quantization_spec() return QLayerConfig(weight=weight_spec, input_tensors=input_spec)
[docs] class PTPCFP8Scheme(QuantizationScheme): """Scheme for PTPC FP8 quantization (Dynamic activation per-token quantization, weight quantization per-channel). Uses FP8 Per-Channel Static for weights and FP8 Per-Token Dynamic for activations. """ def __init__(self) -> None: pass @property def config(self) -> QLayerConfig: weight_spec = FP8E4M3PerChannelSpec(is_dynamic=False, ch_axis=0).to_quantization_spec() input_spec = FP8E4M3PerChannelSpec(is_dynamic=True, ch_axis=1).to_quantization_spec() return QLayerConfig(weight=weight_spec, input_tensors=input_spec)
[docs] class QuantizationSchemeCollection: """Collection for quantization schemes.""" def __init__(self) -> None: self._schemes: dict[str, QuantizationScheme] = {} self._collect_supported_schemes() self._custom_registered_schemes: list[str] = [] def _collect_supported_schemes(self) -> None: """Collect all supported quantization schemes.""" # INT4 weight-only schemes self._schemes["int4_wo_32"] = Int4WeightOnlyScheme(group_size=32) self._schemes["int4_wo_64"] = Int4WeightOnlyScheme(group_size=64) self._schemes["int4_wo_128"] = Int4WeightOnlyScheme(group_size=128) self._schemes["int4_wo_per_channel"] = Int4WeightOnlyPerChannelScheme() # UINT4 weight-only schemes self._schemes["uint4_wo_32"] = Uint4WeightOnlyScheme(group_size=32) self._schemes["uint4_wo_64"] = Uint4WeightOnlyScheme(group_size=64) self._schemes["uint4_wo_128"] = Uint4WeightOnlyScheme(group_size=128) self._schemes["uint4_wo_per_channel"] = Uint4WeightOnlyPerChannelScheme() # INT8 scheme self._schemes["int8"] = Int8Scheme() # FP8 quantization schemes self._schemes["fp8"] = FP8Scheme() self._schemes["ptpc_fp8"] = PTPCFP8Scheme() # OCP MXFP quantization schemes self._schemes["mxfp4"] = MXFP4Scheme() self._schemes["mxfp6_e3m2"] = MXFP6E3M2Scheme() self._schemes["mxfp6_e2m3"] = MXFP6E2M3Scheme() self._schemes["mxfp4_mxfp6_e2m3"] = MXFP4_MXFP6E2M3Scheme() self._schemes["mxfp4_fp8"] = MXFP4_FP8Scheme() # MX6 quantization schemes self._schemes["mx6"] = MX6Scheme() # BFP16 quantization schemes self._schemes["bfp16"] = BFP16Scheme()
[docs] def register_scheme(self, scheme_name: str, scheme: QuantizationScheme) -> None: """Register a quantization scheme.""" if scheme_name in self._schemes: raise ValueError(f"Scheme '{scheme_name}' already registered, please use a different name.") self._schemes[scheme_name] = scheme self._custom_registered_schemes.append(scheme_name)
[docs] def unregister_scheme(self, scheme_name: str) -> None: """Unregister a quantization scheme.""" if scheme_name not in self._schemes: raise ValueError(f"Scheme '{scheme_name}' not found, please check the name.") if scheme_name not in self._custom_registered_schemes: raise ValueError( f"Scheme '{scheme_name}' not registered as custom scheme, quark built-in schemes cannot be unregistered." ) del self._schemes[scheme_name] self._custom_registered_schemes.remove(scheme_name)
[docs] def get_supported_schemes(self) -> list[str]: """Get list of supported quantization schemes.""" return list(self._schemes.keys())
[docs] def get_scheme(self, scheme_name: str) -> QuantizationScheme: """Get a quantization scheme by name.""" return self._schemes[scheme_name]
[docs] class LLMTemplate: """ A configuration template that defines how to quantize specific types of LLM models. Each LLM architecture (like llama, qwen, deepseek, etc.) has its own unique structure and naming patterns for layers. This template allows specifying those patterns and quantization settings in a reusable way. :param str model_type: Type of the LLM model. :param List[str] kv_layers_name: List of k_proj and v_proj layer name patterns to match. Default is ``None``. :param Union[str, List[str]] q_layer_name: q_proj layer name pattern to match. Default is ``None``. :param List[str] exclude_layers_name: List of layer name patterns to exclude from quantization. Default is ``[]``. :param AWQConfig awq_config: Configuration for AWQ algorithm. Default is ``None``. :param GPTQConfig gptq_config: Configuration for GPTQ algorithm. Default is ``None``. :param SmoothQuantConfig smoothquant_config: Configuration for SmoothQuant algorithm. Default is ``None``. :param AutoSmoothQuantConfig autosmoothquant_config: Configuration for AutoSmoothQuant algorithm. Default is ``None``. :param RotationConfig rotation_config: Configuration for Rotation algorithm. Default is ``None``. Note: - The quantization schemes supported by the template are: - fp8 - ptpc_fp8 - int4_wo_32 - int4_wo_64 - int4_wo_128 - int4_wo_per_channel - uint4_wo_32 - uint4_wo_64 - uint4_wo_128 - uint4_wo_per_channel - int8 - mxfp4 - mxfp6_e3m2 - mxfp6_e2m3 - mx6 - bfp16 - The quantization algorithms supported by the template are: - awq - gptq - smoothquant - autosmoothquant - rotation - The KV cache schemes supported by the template are: - fp8 - The attention schemes supported by the template are: - fp8 Creating a Custom Template: To create a custom template for a new model type, you can define layer name patterns and algorithm configurations specific to your model architecture. Take `moonshotai/Kimi-K2-Instruct <https://huggingface.co/moonshotai/Kimi-K2-Instruct>`__ as an example: .. code-block:: python from quark.torch import LLMTemplate # Create a new template template = LLMTemplate( model_type="kimi_k2", kv_layers_name=["*kv_b_proj"], exclude_layers_name=["lm_head"] ) # Register the template to LLMTemplate class (optional, if you want to use the template in other places) LLMTemplate.register_template(template) """ _templates: dict[str, LLMTemplate] = {} _SCHEME_COLLECTION = QuantizationSchemeCollection() _SUPPORTED_SCHEMES = _SCHEME_COLLECTION.get_supported_schemes() _SUPPORTED_ALGORITHMS = ["awq", "gptq", "gptaq", "smoothquant", "autosmoothquant", "qronos", "rotation"] _SUPPORTED_KV_CACHE_SCHEMES = ["fp8"] _SUPPORTED_ATTENTION_SCHEMES = ["fp8"] def __init__( self, model_type: str, kv_layers_name: list[str] | None = None, q_layer_name: str | list[str] | None = None, exclude_layers_name: list[str] = [], awq_config: AWQConfig | None = None, gptq_config: GPTQConfig | None = None, gptaq_config: GPTAQConfig | None = None, qronos_config: QronosConfig | None = None, smoothquant_config: SmoothQuantConfig | None = None, autosmoothquant_config: AutoSmoothQuantConfig | None = None, rotation_config: RotationConfig | None = None, ): self.model_type = model_type self.kv_layers_name = kv_layers_name self.q_layer_name = q_layer_name self.exclude_layers_name = exclude_layers_name # Algorithm-specific configuration fields self.algo_config: dict[str, AlgoConfig | None] = {} self.algo_config["awq"] = awq_config self.algo_config["gptq"] = gptq_config self.algo_config["gptaq"] = gptaq_config self.algo_config["qronos"] = qronos_config self.algo_config["smoothquant"] = smoothquant_config self.algo_config["autosmoothquant"] = autosmoothquant_config self.algo_config["rotation"] = rotation_config
[docs] @classmethod def list_available(cls: type[LLMTemplate]) -> list[str]: """ List all available model names of registered templates. :return: List of template names. :rtype: List[str] Example: .. code-block:: python from quark.torch import LLMTemplate templates = LLMTemplate.list_available() print(templates) # ['llama', 'opt', 'gpt2', ...] """ return list(cls._templates.keys())
[docs] @classmethod def register_template(cls, template: LLMTemplate) -> None: """ Register a template. :param LLMTemplate template: The template to register. Example: .. code-block:: python from quark.torch import LLMTemplate # Create template template = LLMTemplate( model_type="llama", kv_layers_name=["*k_proj", "*v_proj"], q_layer_name="*q_proj", exclude_layers_name=["lm_head"], ) # Register template LLMTemplate.register_template(template) """ if template.model_type in cls._templates: logger.warning( f"Template '{template.model_type}' already registered, will overwrite the existing template." ) cls._templates[template.model_type] = template
[docs] @classmethod def get(cls, model_type: str) -> LLMTemplate: """Get a template by model type. :param str model_type: Type of the model. It is obtained from the original LLM HuggingFace model's ``model.config.model_type`` attribute. When the model_type field is not defined, the ``model.config.architecture[0]`` is assigned as the model_type.. Available model types: - chatglm - cohere - dbrx - deepseek - deepseek_v2 - deepseek_v3 - gemma2 - gemma3 - gemma3_text - gptj - gpt_oss - grok-1 - instella - llama - llama4 - mistral - mixtral - mllama - olmo - opt - phi - phi3 - qwen - qwen2 - qwen2_moe - qwen3 - qwen3_moe - qwen3_vl_moe :return: The template object. :rtype: LLMTemplate Example: .. code-block:: python from quark.torch import LLMTemplate template = LLMTemplate.get("llama") print(template) """ if model_type not in cls._templates: available = ", ".join(cls.list_available()) raise ValueError( f"There is no model template defined for the model type '{model_type}'. Available templates: {available}." "you can refer to the example comments within the `register_template` function for instructions on how to register a new model." ) return cls._templates[model_type]
# Register a new quantization scheme for the template
[docs] @classmethod def register_scheme(cls, scheme_name: str, config: QLayerConfig) -> None: """ Register a new quantization scheme for LLMTemplate class. :param str scheme_name: Name of the scheme. :param QLayerConfig config: Configuration for the scheme. Example: .. code-block:: python # Register a new quantization scheme ``int8_wo (int8 weight-only)`` to the template from quark.torch import LLMTemplate from quark.torch.quantization.config.config import Int8PerTensorSpec, QLayerConfig quant_spec = Int8PerTensorSpec(observer_method="min_max", symmetric=True, scale_type="float", round_method="half_even", is_dynamic=False).to_quantization_spec() global_config = QLayerConfig(weight=quant_spec) LLMTemplate.register_scheme("int8_wo", config=global_config) """ cls._SCHEME_COLLECTION.register_scheme(scheme_name, QuantizationScheme(config)) cls._SUPPORTED_SCHEMES = cls._SCHEME_COLLECTION.get_supported_schemes()
[docs] @classmethod def unregister_scheme(cls, scheme_name: str) -> None: """ Unregister a quantization scheme. :param str scheme_name: Name of the scheme to unregister. Example: .. code-block:: python from quark.torch import LLMTemplate LLMTemplate.unregister_scheme("int8") """ cls._SCHEME_COLLECTION.unregister_scheme(scheme_name) cls._SUPPORTED_SCHEMES = cls._SCHEME_COLLECTION.get_supported_schemes()
[docs] @classmethod def get_supported_schemes(cls) -> list[str]: """Get list of supported quantization schemes.""" return cls._SUPPORTED_SCHEMES
[docs] def get_config( self, scheme: str, algorithm: str | list[str] | None = None, kv_cache_scheme: str | None = None, min_kv_scale: float = 0.0, attention_scheme: str | None = None, layer_config: dict[str, str] | None = None, layer_type_config: dict[type[nn.Module], str] | None = None, exclude_layers: list[str] | None = None, algo_configs: dict[str, AlgoConfig] | None = None, ) -> Config: """ Create a quantization configuration based on the provided parameters. :param str scheme: Name of the quantization scheme. :param Optional[Union[str, List[str]]] algorithm: Name or list of names of quantization algorithms to apply. :param Optional[str] kv_cache_scheme: Name of the KV cache quantization scheme. :param float min_kv_scale: Minimum value of KV Cache scale. :param Optional[str] attention_scheme: Name of the attention quantization scheme. :param Optional[Dict[str, str]] layer_config: Dictionary of layer name patterns and quantization scheme names. :param Optional[Dict[Type[nn.Module], str]] layer_type_config: Dictionary of layer types and quantization scheme names. :param Optional[List[str]] exclude_layers: List of layer names to exclude from quantization. :param Optional[Dict[str, AlgoConfig]] algo_configs: Dictionary of algorithm names to their configurations. Example: .. code-block:: python from quark.torch import LLMTemplate template = LLMTemplate.get("llama") config = template.get_config(scheme="fp8", kv_cache_scheme="fp8") """ # Check if the scheme is supported if scheme not in LLMTemplate._SUPPORTED_SCHEMES: raise ValueError(f"Unsupported quantization scheme: {scheme}") # Check if the algorithm is supported if algorithm: if isinstance(algorithm, str): algorithm = [algorithm] for algo in algorithm: if algo not in self._SUPPORTED_ALGORITHMS: raise ValueError(f"Unsupported algorithm: {algo}") # Check if the KV cache scheme is supported if kv_cache_scheme and kv_cache_scheme not in self._SUPPORTED_KV_CACHE_SCHEMES: raise ValueError(f"Unsupported KV cache scheme: {kv_cache_scheme}") # Check if the attention scheme is supported if attention_scheme and attention_scheme not in self._SUPPORTED_ATTENTION_SCHEMES: raise ValueError(f"Unsupported attention scheme: {attention_scheme}") # Set up base global configuration global_config = self._create_global_config(scheme) # Create config object config = Config( global_quant_config=global_config, min_kv_scale=min_kv_scale, exclude=self.exclude_layers_name, kv_cache_group=self.kv_layers_name, ) # Apply algorithm if specified if algorithm: config = self._set_algorithm(config, algorithm, algo_configs) # Set layer quantization configuration first # Apply per-layer configuration overrides if layer_config: config = self._set_layer_name_config(config, layer_config) if layer_type_config: config = self._set_layer_type_config(config, layer_type_config) # Set KV cache quantization quantization configuration after setting layer quantization configuration to aviod the conflicts # Apply KV cache quantization if specified if kv_cache_scheme: config = self._set_kv_cache_config(config, kv_cache_scheme) # Apply attention quantization if specified if attention_scheme: config = self._set_attention_config(config, attention_scheme) # Apply exclude layers configuration if exclude_layers is not None: config = self._set_exclude_layers_config(config, exclude_layers) return config
def _create_global_config(self, scheme: str) -> QLayerConfig: return LLMTemplate._SCHEME_COLLECTION.get_scheme(scheme).config def _set_algorithm( self, config: Config, algorithm: str | list[str], algo_configs: dict[str, AlgoConfig] | None = None ) -> Config: if isinstance(algorithm, str): algorithm = [algorithm] # Clone algo_config to avoid modifying the original algo_config effective_algo_config = dict(self.algo_config) if algo_configs: for algo_name, algo_cfg in algo_configs.items(): effective_algo_config[algo_name.lower()] = algo_cfg for algo in algorithm: if config.algo_config is None: config.algo_config = [] algo_key = algo.lower() if algo_key == "awq": if effective_algo_config.get("awq"): config.algo_config.append(effective_algo_config["awq"]) else: logger.warning( f"No AWQ config provided for {self.model_type}, " "falling back to default AWQ config. If you need customized AWQ quantization for this model, " "please provide the AWQ config, and pass it to LLMTemplate constructor." ) # Fallback to default AWQ config config.algo_config.append(AWQConfig()) elif algo_key == "gptq": if effective_algo_config.get("gptq"): config.algo_config.append(effective_algo_config["gptq"]) else: logger.warning( f"No GPTQ config provided for {self.model_type}, " "falling back to default GPTQ config. If you need customized GPTQ quantization for this model, " "please provide the GPTQ config, and pass it to LLMTemplate constructor." ) # Fallback to default GPTQ config config.algo_config.append(GPTQConfig()) elif algo_key == "gptaq": if effective_algo_config.get("gptaq"): config.algo_config.append(effective_algo_config["gptaq"]) else: raise ValueError( f"No GPTAQ config provided for {self.model_type}. " "Please provide a GPTAQ config and pass it to the LLMTemplate constructor, " "or use a model type that has GPTAQ configuration defined." ) elif algo_key == "qronos": if effective_algo_config.get("qronos"): config.algo_config.append(effective_algo_config["qronos"]) else: raise ValueError( f"No Qronos config provided for {self.model_type}. " "Please provide a Qronos config and pass it to the LLMTemplate constructor, " "or use a model type that has Qronos configuration defined." ) elif algo_key == "smoothquant": if effective_algo_config.get("smoothquant"): config.algo_config.append(effective_algo_config["smoothquant"]) else: logger.warning( f"No SmoothQuant config provided for {self.model_type}, " "falling back to default SmoothQuant config. If you need customized SmoothQuant quantization for this model, " "please provide the SmoothQuant config, and pass it to LLMTemplate constructor." ) # Fallback to default SmoothQuant config config.algo_config.append(SmoothQuantConfig()) elif algo_key == "autosmoothquant": if effective_algo_config.get("autosmoothquant"): config.algo_config.append(effective_algo_config["autosmoothquant"]) else: logger.warning( f"No AutoSmoothQuant config provided for {self.model_type}, " "falling back to default AutoSmoothQuant config. If you need customized AutoSmoothQuant quantization for this model, " "please provide the AutoSmoothQuant config, and pass it to LLMTemplate constructor." ) # Fallback to default AutoSmoothQuant config config.algo_config.append(AutoSmoothQuantConfig()) elif algo_key == "rotation": if effective_algo_config.get("rotation"): config.algo_config.append(effective_algo_config["rotation"]) else: logger.warning( f"No Rotation config provided for {self.model_type}, " "not to use Rotation quantization for this model." ) else: raise ValueError(f"Unsupported algorithm: {algo}") return config def _set_kv_cache_config(self, config: Config, kv_cache_scheme: str) -> Config: # Use pattern matching to identify KV projection layers if self.kv_layers_name is None: return config if kv_cache_scheme == "fp8": spec = FP8E4M3PerTensorSpec(observer_method="min_max", is_dynamic=False).to_quantization_spec() for layer_name in self.kv_layers_name: # Get the layer quantization configuration layer_quant_config = get_layer_quant_config(config, nn.Linear, layer_name) if layer_quant_config is None: continue layer_config = QLayerConfig( weight=layer_quant_config.weight, input_tensors=layer_quant_config.input_tensors, output_tensors=spec, ) config.layer_quant_config[layer_name] = layer_config # Create a separate config for KV cache kv_cache_config = QLayerConfig( weight=layer_quant_config.weight, input_tensors=layer_quant_config.input_tensors, output_tensors=spec, ) config.kv_cache_quant_config[layer_name] = kv_cache_config else: raise ValueError(f"Unsupported KV cache quantization scheme: {kv_cache_scheme}") return config def _set_attention_config(self, config: Config, attention_scheme: str) -> Config: if attention_scheme == "fp8": spec = FP8E4M3PerTensorSpec(observer_method="min_max", is_dynamic=False).to_quantization_spec() config.softmax_quant_spec = spec if self.q_layer_name is not None: if isinstance(self.q_layer_name, str): self.q_layer_name = [self.q_layer_name] for q_layer_name in self.q_layer_name: # Get the layer quantization configuration layer_quant_config = get_layer_quant_config(config, nn.Linear, q_layer_name) if layer_quant_config is None: continue config.layer_quant_config[q_layer_name] = QLayerConfig( weight=layer_quant_config.weight, input_tensors=layer_quant_config.input_tensors, output_tensors=spec, ) else: raise ValueError(f"Unsupported attention quantization scheme: {attention_scheme}") return config def _set_layer_name_config(self, config: Config, layer_name_config: dict[str, str]) -> Config: for layer_name, layer_scheme in layer_name_config.items(): config.layer_quant_config[layer_name] = LLMTemplate._SCHEME_COLLECTION.get_scheme(layer_scheme).config return config def _set_layer_type_config(self, config: Config, layer_type_config: dict[type[nn.Module], str]) -> Config: for layer_type, layer_scheme in layer_type_config.items(): config.layer_type_quant_config[layer_type] = LLMTemplate._SCHEME_COLLECTION.get_scheme(layer_scheme).config return config def _set_exclude_layers_config(self, config: Config, exclude_layers: list[str]) -> Config: config.exclude.clear() for layer_name in exclude_layers: config.exclude.append(layer_name) return config
# Default template configurations DEFAULT_TEMPLATES = { "chatglm": { "kv_layers_name": ["*query_key_value"], "q_layer_name": "*query_key_value", "exclude_layers_name": ["transformer.output_layer"], "awq_config": "chatglm", "gptq_config": "chatglm", "gptaq_config": "chatglm", "qronos_config": "chatglm", "smoothquant_config": "chatglm", "autosmoothquant_config": "chatglm", "rotation_config": "chatglm", }, "cohere": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "cohere", "gptq_config": "cohere", "gptaq_config": "cohere", "qronos_config": "cohere", "smoothquant_config": "cohere", "autosmoothquant_config": "cohere", "rotation_config": "cohere", }, "dbrx": { "kv_layers_name": ["*Wqkv"], "q_layer_name": "*Wqkv", "exclude_layers_name": ["lm_head", "*router.layer"], "awq_config": "dbrx", "gptq_config": "dbrx", "gptaq_config": "dbrx", "qronos_config": "dbrx", "smoothquant_config": "dbrx", "autosmoothquant_config": "dbrx", "rotation_config": "dbrx", }, "deepseek": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*.gate"], "awq_config": "deepseek", "gptq_config": "deepseek", "gptaq_config": "deepseek", "qronos_config": "deepseek", "smoothquant_config": "deepseek", "autosmoothquant_config": "deepseek", "rotation_config": "deepseek", }, "deepseek_v2": { "kv_layers_name": ["*kv_b_proj"], "q_layer_name": ["*q_a_proj", "*q_b_proj"], "exclude_layers_name": ["lm_head", "*self_attn*", "*mlp.gate"], "awq_config": "deepseek_v2", "gptq_config": "deepseek_v2", "gptaq_config": "deepseek_v2", "qronos_config": "deepseek_v2", "smoothquant_config": "deepseek_v2", "autosmoothquant_config": "deepseek_v2", "rotation_config": "deepseek_v2", }, "deepseek_v3": { "kv_layers_name": ["*kv_b_proj"], "q_layer_name": ["*q_a_proj", "*q_b_proj"], "exclude_layers_name": ["lm_head", "*self_attn*", "*mlp.gate"], "awq_config": "deepseek_v3", "gptq_config": "deepseek_v3", "gptaq_config": "deepseek_v3", "qronos_config": "deepseek_v3", "smoothquant_config": "deepseek_v3", "autosmoothquant_config": "deepseek_v3", "rotation_config": "deepseek_v3", }, "gemma2": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "gemma2", "gptq_config": "gemma2", "gptaq_config": "gemma2", "qronos_config": "gemma2", "smoothquant_config": "gemma2", "autosmoothquant_config": "gemma2", "rotation_config": "gemma2", }, "gemma3": { "kv_layers_name": ["*language_model.*k_proj", "*language_model.*v_proj"], "q_layer_name": "*language_model.*q_proj", "exclude_layers_name": ["*vision_tower*", "*multi_modal_projector*", "*lm_head"], "awq_config": "gemma3", "gptq_config": "gemma3", "gptaq_config": "gemma3", "qronos_config": "gemma3", "smoothquant_config": "gemma3", "autosmoothquant_config": "gemma3", "rotation_config": "gemma3", }, "gemma3_text": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["*lm_head"], "awq_config": "gemma3_text", "gptq_config": "gemma3_text", "gptaq_config": "gemma3_text", "qronos_config": "gemma3_text", "smoothquant_config": "gemma3_text", "autosmoothquant_config": "gemma3_text", "rotation_config": "gemma3_text", }, "gptj": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "gptj", "gptq_config": "gptj", "gptaq_config": "gptj", "qronos_config": "gptj", "smoothquant_config": "gptj", "autosmoothquant_config": "gptj", "rotation_config": "gptj", }, "gpt_oss": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*router*"], "awq_config": "gpt_oss", "gptq_config": "gpt_oss", "gptaq_config": "gpt_oss", "qronos_config": "gpt_oss", "smoothquant_config": "gpt_oss", "autosmoothquant_config": "gpt_oss", "rotation_config": "gpt_oss", }, "granitemoehybrid": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*router*"], "awq_config": "granitemoehybrid", "gptq_config": "granitemoehybrid", "gptaq_config": "granitemoehybrid", "qronos_config": "granitemoehybrid", "smoothquant_config": "granitemoehybrid", "autosmoothquant_config": "granitemoehybrid", "rotation_config": "granitemoehybrid", }, "grok-1": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*.gate"], "awq_config": "grok-1", "gptq_config": "grok-1", "gptaq_config": "grok-1", "qronos_config": "grok-1", "smoothquant_config": "grok-1", "autosmoothquant_config": "grok-1", "rotation_config": "grok-1", }, "instella": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "instella", "gptq_config": "instella", "gptaq_config": "instella", "qronos_config": "instella", "smoothquant_config": "instella", "autosmoothquant_config": "instella", "rotation_config": "instella", }, "llama": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "llama", "gptq_config": "llama", "gptaq_config": "llama", "qronos_config": "llama", "smoothquant_config": "llama", "autosmoothquant_config": "llama", "rotation_config": "llama", }, "llama4": { "kv_layers_name": ["*language_model.*.k_proj", "*language_model.*.v_proj"], "q_layer_name": "*language_model.*.q_proj", "exclude_layers_name": [ "multi_modal_projector*", "*feed_forward.router", "vision_model*", "*lm_head", ], "awq_config": "llama4", "gptq_config": "llama4", "gptaq_config": "llama4", "qronos_config": "llama4", "smoothquant_config": "llama4", "autosmoothquant_config": "llama4", "rotation_config": "llama4", }, "mistral": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "mistral", "gptq_config": "mistral", "gptaq_config": "mistral", "qronos_config": "mistral", "smoothquant_config": "mistral", "autosmoothquant_config": "mistral", "rotation_config": "mistral", }, "mixtral": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*.gate"], "awq_config": "mixtral", "gptq_config": "mixtral", "gptaq_config": "mixtral", "qronos_config": "mixtral", "smoothquant_config": "mixtral", "autosmoothquant_config": "mixtral", "rotation_config": "mixtral", }, "mllama": { "kv_layers_name": ["*language_model.*k_proj", "*language_model.*v_proj"], "q_layer_name": "*self_attn.q_proj", "exclude_layers_name": ["*lm_head", "*patch_embedding", "multi_modal_projector"], "awq_config": "mllama", "gptq_config": "mllama", "gptaq_config": "mllama", "qronos_config": "mllama", "smoothquant_config": "mllama", "autosmoothquant_config": "mllama", "rotation_config": "mllama", }, "olmo": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "olmo", "gptq_config": "olmo", "gptaq_config": "olmo", "qronos_config": "olmo", "smoothquant_config": "olmo", "autosmoothquant_config": "olmo", "rotation_config": "olmo", }, "opt": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "opt", "gptq_config": "opt", "gptaq_config": "opt", "qronos_config": "opt", "smoothquant_config": "opt", "autosmoothquant_config": "opt", "rotation_config": "opt", }, "phi": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "phi", "gptq_config": "phi", "gptaq_config": "phi", "qronos_config": "phi", "smoothquant_config": "phi", "autosmoothquant_config": "phi", "rotation_config": "phi", }, "phi3": { "kv_layers_name": ["*qkv_proj"], "q_layer_name": "*qkv_proj", "exclude_layers_name": ["lm_head"], "awq_config": "phi3", "gptq_config": "phi3", "gptaq_config": "phi3", "qronos_config": "phi3", "smoothquant_config": "phi3", "autosmoothquant_config": "phi3", "rotation_config": "phi3", }, "qwen": { "kv_layers_name": ["*c_attn"], "q_layer_name": "*c_attn", "exclude_layers_name": ["lm_head"], "awq_config": "qwen", "gptq_config": "qwen", "gptaq_config": "qwen", "qronos_config": "qwen", "smoothquant_config": "qwen", "autosmoothquant_config": "qwen", "rotation_config": "qwen", }, "qwen2": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "qwen2", "gptq_config": "qwen2", "gptaq_config": "qwen2", "qronos_config": "qwen2", "smoothquant_config": "qwen2", "autosmoothquant_config": "qwen2", "rotation_config": "qwen2", }, "qwen2_moe": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*.gate", "*.shared_expert_gate"], "awq_config": "qwen2_moe", "gptq_config": "qwen2_moe", "gptaq_config": "qwen2_moe", "qronos_config": "qwen2_moe", "smoothquant_config": "qwen2_moe", "autosmoothquant_config": "qwen2_moe", "rotation_config": "qwen2_moe", }, "qwen3": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "qwen3", "gptq_config": "qwen3", "gptaq_config": "qwen3", "qronos_config": "qwen3", "smoothquant_config": "qwen3", "autosmoothquant_config": "qwen3", "rotation_config": "qwen3", }, "qwen3_moe": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*.gate"], "awq_config": "qwen3_moe", "gptq_config": "qwen3_moe", "gptaq_config": "qwen3_moe", "qronos_config": "qwen3_moe", "smoothquant_config": "qwen3_moe", "autosmoothquant_config": "qwen3_moe", "rotation_config": "qwen3_moe", }, "qwen3_vl_moe": { "kv_layers_name": ["*k_proj", "*v_proj"], "q_layer_name": "*q_proj", "exclude_layers_name": ["lm_head", "*mlp.gate", "*.visual.*"], "awq_config": "qwen3_vl_moe", "gptq_config": "qwen3_vl_moe", "gptaq_config": "qwen3_vl_moe", "qronos_config": "qwen3_vl_moe", "smoothquant_config": "qwen3_vl_moe", "autosmoothquant_config": "qwen3_vl_moe", "rotation_config": "qwen3_vl_moe", }, } def _create_template_from_config(model_type: str, config: dict[str, Any]) -> LLMTemplate: """create a template from configuration dictionary.""" return LLMTemplate( model_type=model_type, kv_layers_name=config["kv_layers_name"], q_layer_name=config["q_layer_name"], exclude_layers_name=config["exclude_layers_name"], awq_config=get_algo_config("awq", config["awq_config"]), # type: ignore gptq_config=get_algo_config("gptq", config["gptq_config"]), # type: ignore gptaq_config=get_algo_config("gptaq", config["gptaq_config"]), # type: ignore qronos_config=get_algo_config("qronos", config["qronos_config"]), # type: ignore smoothquant_config=get_algo_config("smoothquant", config["smoothquant_config"]), # type: ignore autosmoothquant_config=get_algo_config("autosmoothquant", config["autosmoothquant_config"]), rotation_config=get_algo_config("rotation", config["rotation_config"]), ) # type: ignore """ Developer Note for Quark Engineers: ==================================== To add a new model template, follow these steps: 1. Add the model configuration to DEFAULT_TEMPLATES dictionary above. 2. Add corresponding algorithm configs to the algo config registry if the algorithm is needed. (see quark/torch/quantization/config/algo_config.py) 3. Update the docstring list in LLMTemplate.get() method to include the new model type. 4. Test the new template with various quantization schemes and algorithms Example for adding "new_model": .. code-block:: python "new_model": { "kv_layers_name": ["*attention.k_proj", "*attention.v_proj"], "q_layer_name": "*attention.q_proj", "exclude_layers_name": ["lm_head"], "awq_config": "new_model", "gptq_config": "new_model", "smoothquant_config": "new_model", "autosmoothquant_config": "new_model", "rotation_config": "new_model", } """ # Register built-in templates for model_type, config in DEFAULT_TEMPLATES.items(): if model_type not in LLMTemplate._templates: template = _create_template_from_config(model_type, config) LLMTemplate.register_template(template)