Source code for quark.torch.quantization.api

#
# Copyright (C) 2023 - 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for PyTorch."""

import json
import time
from typing import Any, Iterable

import torch
import torch.fx
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import quark.torch.kernel
from quark.shares.utils.import_utils import is_safetensors_available, is_transformers_available
from quark.shares.utils.log import ScreenLogger, log_errors
from quark.torch.algorithm.api import add_algorithm_config_by_model, apply_advanced_quant_algo
from quark.torch.algorithm.utils.utils import clear_memory
from quark.torch.quantization.config.config import QConfig, QLayerConfig, QTensorConfig
from quark.torch.quantization.config.config_verification import ConfigVerifier
from quark.torch.quantization.config.type import Dtype, QSchemeType, QuantizationMode
from quark.torch.quantization.graph.processor.pre_check_befor_quant import check_supported_model_and_config
from quark.torch.quantization.graph.processor.processor import (
    post_calib_optimize,
    post_quant_optimize,
    prepare_quant_model,
)
from quark.torch.quantization.model_transformation import process_model_transformation
from quark.torch.quantization.nn.modules import (
    QuantConv2d,
    QuantConvTranspose2d,
    QuantEmbedding,
    QuantEmbeddingBag,
    QuantLinear,
)
from quark.torch.quantization.nn.modules.mixin import QuantMixin
from quark.torch.quantization.tensor_quantize import (
    FakeQuantizeBase,
    NonScaledFakeQuantize,
    ScaledFakeQuantize,
    SequentialQuantize,
    enable_or_disable_quantizer,
)
from quark.torch.quantization.utils import count_calibration_tokens, deep_compare
from quark.torch.utils import (
    QUARK_COUNT_OBSERVED_SAMPLES,
    QUARK_TOKENS_DISTRIBUTION_PATH,
    TOKEN_DISTRIBUTION_THRESHOLD,
    create_pack_method,
    getattr_recursive,
    gpu_memory_profiled,
    setattr_recursive,
)

if is_transformers_available():
    from transformers.feature_extraction_utils import BatchFeature

import os
from collections import Counter
from pathlib import Path

from quark.torch.quantization.debug import check_scale_stats, collect_quantization_statistics, insert_stats_hooks

if is_safetensors_available():
    from safetensors.torch import load_file

__all__ = ["ModelQuantizer", "load_params"]

logger = ScreenLogger(__name__)

QUARK_QUANT_OPS: dict[
    str, type[QuantConv2d | QuantConvTranspose2d | QuantLinear | QuantEmbedding | QuantEmbeddingBag]
] = {
    "QuantConv2d": QuantConv2d,
    "QuantConvTranspose2d": QuantConvTranspose2d,
    "QuantLinear": QuantLinear,
    "QuantEmbedding": QuantEmbedding,
    "QuantEmbeddingBag": QuantEmbeddingBag,
}
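
# Illustrative lookup only (a sketch of how this registry is used further below): `load_params` /
# `from_float_and_dict` resolve the serialized layer "type" string through `QUARK_QUANT_OPS` and
# rebuild the corresponding quantized module, e.g.
#
#     quant_cls = QUARK_QUANT_OPS["QuantLinear"]
#     quant_module = quant_cls.from_float(
#         float_module, module_config, reload=True, weight_tensor=weight_tensor, bias_tensor=bias_tensor
#     )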


class ModelQuantizer:
    """
    Provides an API for quantizing deep learning models using PyTorch. This class handles the
    configuration and processing of the model for quantization based on user-defined parameters.

    It is essential to ensure that the 'config' provided has all the necessary quantization
    parameters defined. This class assumes that the model is compatible with the quantization
    settings specified in 'config'.

    :param QConfig config: The model quantization configuration.
    :param bool multi_device: Whether to allow quantizing a model dispatched across multiple
        devices (for example through ``accelerate`` offloading). Defaults to ``False``.
    """

    def __init__(self, config: QConfig, multi_device: bool = False) -> None:
        self.config = config
        self._is_accelerate: bool | None = None
        self.multi_device: bool = multi_device
        self.config_verifier = ConfigVerifier(config)
        self.init_config()
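
    # Minimal usage sketch (assumes `quant_config` is a valid `QConfig` and `calib_dataloader` a
    # calibration DataLoader; see `quantize_model` below for a full example):
    #
    #     quantizer = ModelQuantizer(quant_config)
    #     quant_model = quantizer.quantize_model(model, calib_dataloader)
    #
    # `multi_device=True` is meant for models dispatched across several devices with accelerate
    # (see `_check_model_device`).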
    @gpu_memory_profiled(tag=" QuantizeModel")  # type: ignore[arg-type]
    def quantize_model(
        self,
        model: nn.Module,
        dataloader: DataLoader[torch.Tensor]
        | DataLoader[list[dict[str, torch.Tensor]]]
        | DataLoader[dict[str, torch.Tensor]]
        | DataLoader[list["BatchFeature"]]
        | None = None,
    ) -> nn.Module:
        """
        Quantizes the given PyTorch model to optimize its performance and reduce its size. The
        dataloader provides the data needed for calibration during the quantization process.
        Depending on the type of data provided (either tensors directly, or structured as lists or
        dictionaries of tensors), the function adapts the quantization approach accordingly.

        It is important that the model and dataloader are compatible in terms of the data they
        expect and produce. A mismatch in data handling between the model and the dataloader can
        lead to errors during the quantization process.

        :param torch.nn.Module model: The PyTorch model to be quantized. This model should already
            be trained and ready for quantization.
        :param Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]], DataLoader[List[BatchFeature]]]] dataloader:
            The ``torch.utils.data.DataLoader`` providing the data that the quantization process
            uses for calibration. This can be a simple ``DataLoader`` returning tensors, or a more
            complex structure returning either a list of dictionaries or a dictionary of tensors.
        :return: The quantized version of the input model. This model is now optimized for
            inference with reduced size and potentially improved performance on targeted devices.
        :rtype: torch.nn.Module

        Example:

            .. code-block:: python

                # Model & Data preparation
                from torch.utils.data import DataLoader
                from transformers import AutoModelForCausalLM, AutoTokenizer
                from quark.torch.quantization.config.config import QConfig, QLayerConfig, QTensorConfig
                from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
                from quark.torch.quantization.observer.observer import PerGroupMinMaxObserver
                from quark.torch import ModelQuantizer

                model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype="auto")
                model.eval()
                tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

                quant_spec = QTensorConfig(
                    dtype=Dtype.uint4,
                    observer_cls=PerGroupMinMaxObserver,
                    symmetric=False,
                    scale_type=ScaleType.float,
                    round_method=RoundType.half_even,
                    qscheme=QSchemeType.per_group,
                    ch_axis=1,
                    is_dynamic=False,
                    group_size=128,
                )
                quant_config = QConfig(global_quant_config=QLayerConfig(weight=quant_spec))

                text = "Hello, how are you?"
                tokenized_outputs = tokenizer(text, return_tensors="pt")
                calib_dataloader = DataLoader(tokenized_outputs['input_ids'])

                quantizer = ModelQuantizer(quant_config)
                quant_model = quantizer.quantize_model(model, calib_dataloader)
        """
        logger.info(f"Quantizing with the quantization configuration:\n{self.config}")

        # Step 0: Pre-quantization device check.
        self._check_model_device(model)

        # Step 1: Prepare the quantization model, for graph mode and eager mode.
        model = self._prepare_model(model)

        # Step 2 [optional]: Apply an advanced quantization algorithm such as GPTQ, AWQ, Qronos, ...
        model = self._apply_advanced_quant_algo(model, dataloader)

        # Step 3 [optional]: Do calibration.
        model = self._do_calibration(model, dataloader)

        # Step 4 [optional]: Post-calibration optimization.
        model = self._do_post_calib_optimazation(model)

        # Optionally, collect statistics on the quantization errors over the network weights/activations.
        if os.environ.get("QUARK_DEBUG", None) is not None:
            log_dir = Path(os.environ["QUARK_DEBUG"])
            log_dir.mkdir(parents=True, exist_ok=True)

            stats: dict[str, Any] = {}

            dataloader = dataloader if not self.config_verifier.is_all_dynamic else None

            with insert_stats_hooks(model, stats, log_dir):
                collect_quantization_statistics(model, dataloader, stats, log_dir)

        # Check the scales of the quantized model.
        if os.getenv("QUARK_CHECK_SCALE") == "1":
            check_scale_stats(model, self.config)

        # Attach quant_config to the quantized model, so that it can be used for export.
        model.quant_config = self.config

        # Add a flag to indicate that the model is quantized.
        model.quark_quantized = True

        return model
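
    # Debugging sketch (illustrative): the environment variables checked above can be set before
    # calling `quantize_model` to dump quantization statistics and verify scales, e.g.
    #
    #     os.environ["QUARK_DEBUG"] = "./quark_debug_logs"  # hypothetical output directory
    #     os.environ["QUARK_CHECK_SCALE"] = "1"
    #     quant_model = quantizer.quantize_model(model, calib_dataloader)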
    def _check_model_device(self, model: nn.Module) -> None:
        # When the model is dispatched with accelerate, the device cannot be "cpu" or "disk" for now.
        if hasattr(model, "hf_device_map"):
            if not self.multi_device:
                for _, layer_device in model.hf_device_map.items():
                    if layer_device == "cpu" or layer_device == "disk":
                        # TODO: We should handle this for customers.
                        raise MemoryError(
                            "Out of memory. The available GPU memory is insufficient to load the entire model. You can try adding '--multi_device'."
                        )
            self._is_accelerate = True
        else:
            self._is_accelerate = False

    def _generate_complete_config_by_model(
        self,
        model: nn.Module,
        dataloader: DataLoader[torch.Tensor]
        | DataLoader[list[dict[str, torch.Tensor]]]
        | DataLoader[dict[str, torch.Tensor]]
        | DataLoader[list["BatchFeature"]]
        | None,
    ) -> None:
        """
        Generates a complete configuration based on the provided model and dataloader.
        """
        self.config = add_algorithm_config_by_model(model, dataloader, self.config)
    @staticmethod
    @torch.no_grad()
    def freeze(
        model: nn.Module | torch.fx.GraphModule, quantize: bool | None = None
    ) -> nn.Module | torch.fx.GraphModule:
        """
        Freezes the quantized model by replacing ``FakeQuantize`` modules with ``FrozenFakeQuantize`` modules.
        In order to be able to compile a quantized model through ``torch.compile``, this method needs to be applied.

        :param torch.nn.Module model: The neural network model containing quantized layers.
        :param Optional[bool] quantize: Whether to effectively quantize the weights, moving away
            from soft weights quantized on the fly to e.g. ``QuantLinear.weight`` actually holding
            the fake-quantized weights. This can be disabled e.g. if we simply want to switch to
            ``FrozenFakeQuantize`` from a model whose ``QuantLinear`` already holds the
            fake-quantized weights in high precision. Defaults to ``True`` for PyTorch eager
            models, and ``False`` for FX graph models.
        :return: The modified model with ``FakeQuantize`` modules replaced by ``FrozenFakeQuantize`` modules.
        :rtype: torch.nn.Module
        """
        logger.info("Freeze model start.")

        # ---- Replace FakeQuantize with FrozenFakeQuantize. ----
        named_modules = dict(model.named_modules(remove_duplicate=False))

        # NOTE: For GraphModules, we disable soft weight quantization, as ONNX export expects us
        # to run `QuantMixin.get_quant_weight` and `QuantMixin.get_quant_bias` in order to do
        # pattern matching and add `QuantizeLinear` nodes in the ONNX graph for the weights.
        # This may actually also be required for non-graph models that pass through ONNX export;
        # it should be clarified with Haoliang.
        if isinstance(model, torch.fx.GraphModule):
            if quantize:
                raise ValueError(
                    f"ModelQuantizer.freeze does not quantize the weights when using a model that is a torch.fx.GraphModule, but got quantize={quantize}. Please check your code or open an issue."
                )
            quantize = False
        elif quantize is None:
            quantize = True

        # Lists the FakeQuantizeBase names which are modified by calling
        # `FakeQuantizeBase.to_frozen_module`.
        frozen_names = []
        quantizer_names = set()
        for name, module in named_modules.items():
            if isinstance(module, QuantMixin):
                for subname, submodule in module.named_modules():
                    if isinstance(submodule, FakeQuantizeBase):
                        full_name = name + "." + subname
                        quantizer_names.add(full_name)

                        if not submodule.is_dynamic:
                            frozen_names.append(full_name)

                            if quantize:
                                if "weight_quantizer" in subname:
                                    module.weight.data = module.get_quant_weight(module.weight)
                                elif "bias_quantizer" in subname and module.bias is not None:
                                    # TODO: better understand why in some cases (e.g. in test/test_for_torch/test_fx_quant_align_hw_pow_of_2.py's `test_torch_fold_bn_after_concat_strategy`) we do have a `bias_quantizer` module attached, but the bias is None.
                                    module.bias.data = module.get_quant_bias(module.bias)

                            frozen_quantized_module = submodule.to_frozen_module(frozen_params=quantize)
                            setattr_recursive(model, full_name, frozen_quantized_module)

        # ---- If the model was quantized in fx.graph mode. ----
        if isinstance(model, torch.fx.GraphModule):
            # The graph may have `ScaledFakeQuantize` modules that are not leaves of a
            # `QuantMixin`, which is specifically the case for activation quantization.
            # See e.g. the test `test_pixel_shuffle_annotation`.
            for name, module in named_modules.items():
                if isinstance(module, FakeQuantizeBase) and name not in quantizer_names:
                    quantizer_names.add(name)

                    if not module.is_dynamic:
                        frozen_quantized_module = module.to_frozen_module(frozen_params=False)
                        setattr_recursive(model, name, frozen_quantized_module)
                        frozen_names.append(name)

            model = model.freeze_model()
            assert isinstance(model, torch.fx.GraphModule)
            model = post_quant_optimize(model=model, hw_constrain=True)  # TODO: pass argument

        if len(frozen_names) > 0:
            frozen_names = "\n- ".join(frozen_names)  # type: ignore
            logger.debug(f"Converted to frozen quantizers: \n- {frozen_names}")

        logger.info("Freeze model end.")
        return model
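
    # Usage sketch (illustrative): freezing is required before compiling a quantized model, e.g.
    #
    #     quant_model = quantizer.quantize_model(model, calib_dataloader)
    #     frozen_model = ModelQuantizer.freeze(quant_model)
    #     compiled_model = torch.compile(frozen_model)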
    def _prepare_model(self, model: nn.Module) -> nn.Module:
        if self.config.quant_mode is QuantizationMode.eager_mode:
            return process_model_transformation(model, self.config)
        elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
            # Quantization with torch.fx does not support some quantization configs and some FX graphs.
            # This raises an error if the config / model used are not supported.
            check_supported_model_and_config(model, self.config)  # type: ignore [arg-type]
            return prepare_quant_model(model, self.config).eval()  # type: ignore [arg-type]

    def _apply_advanced_quant_algo(
        self,
        model: nn.Module,
        dataloader: DataLoader[torch.Tensor]
        | DataLoader[list[dict[str, torch.Tensor]]]
        | DataLoader[dict[str, torch.Tensor]]
        | DataLoader[list["BatchFeature"]]
        | None = None,
    ) -> nn.Module:
        return apply_advanced_quant_algo(model, self.config, self._is_accelerate, dataloader)

    def _check_token_distribution(
        self,
        model: nn.Module,
        dataloader: DataLoader[torch.Tensor]
        | DataLoader[list[dict[str, torch.Tensor]]]
        | DataLoader[dict[str, torch.Tensor]]
        | DataLoader[list["BatchFeature"]],
    ) -> None:
        """
        A helper function that warns when a MoE module received no (or too few) tokens throughout
        the calibration process.
        """
        assert 0.0 <= TOKEN_DISTRIBUTION_THRESHOLD <= 1.0, "threshold should be in [0.0, 1.0]"

        total_token_count = count_calibration_tokens(dataloader)
        if total_token_count == 0:
            logger.warning("No tokens found in calibration dataset. Skipping token distribution check.")
            return

        # Get the observed token count for each module.
        token_counts: Counter[str] = Counter()
        for name, module in model.named_modules():
            if isinstance(module, ScaledFakeQuantize):
                if "_input_quantizer" in name:
                    if module.observer._num_observed_tokens is not None:
                        token_counts[name.replace("._input_quantizer", "")] = module.observer._num_observed_tokens

        for module_name, token_count in token_counts.items():
            if (token_count / float(total_token_count)) <= TOKEN_DISTRIBUTION_THRESHOLD:
                logger.warning(
                    f"The module {module_name} "
                    f"received {token_count} tokens, less than {TOKEN_DISTRIBUTION_THRESHOLD * 100:.1f}% "
                    f"of all {total_token_count} calibration tokens."
                )

        # Output the token distribution if enabled.
        if QUARK_TOKENS_DISTRIBUTION_PATH:
            token_stats = {
                "total_token_count": total_token_count,
                "token_counts": dict(token_counts),
                "threshold": TOKEN_DISTRIBUTION_THRESHOLD,
            }
            timestamp = int(time.time())
            filename = f"tokens_distribution_{timestamp}.json"
            try:
                with open(os.path.join(QUARK_TOKENS_DISTRIBUTION_PATH, filename), "w", encoding="utf-8") as f:
                    json.dump(token_stats, f, indent=2, ensure_ascii=False)
                logger.info(f"Token distribution statistics saved to {filename}")
            except Exception as e:
                logger.error(f"Failed to save token distribution statistics: {e}")

    # When using multi_device, the decorator below must be added here, or offloading will fail.
    # The GPU memory used for gradients cannot be cleaned up by torch.cuda.empty_cache().
    @torch.no_grad()
    def _do_calibration(
        self,
        model: nn.Module,
        dataloader: DataLoader[torch.Tensor]
        | DataLoader[list[dict[str, torch.Tensor]]]
        | DataLoader[dict[str, torch.Tensor]]
        | DataLoader[list["BatchFeature"]]
        | None = None,
    ) -> nn.Module:
        # Calibration only: quantization is turned off.
        if self.config_verifier.is_all_dynamic:  # TODO: to be deprecated
            logger.info("Dynamic quantization, no calibration.")
        elif self.config_verifier.is_weight_only or (
            self.config_verifier.is_act_dynamic and not self.config_verifier.is_act_contain_scale_per_tensor
        ):
            logger.info("Weight calibration start.")
            for module in model.modules():
                if isinstance(module, ScaledFakeQuantize):
                    module.enable_observer()
                    module.disable_fake_quant()

            # Simply run through the observers to set the min_val, max_val, scale and zero_point buffers for the weight and bias.
            named_modules = dict(model.named_modules(remove_duplicate=False))
            for name, module in tqdm(named_modules.items()):
                if isinstance(module, QuantMixin):
                    if module._weight_quantizer is not None and isinstance(
                        module._weight_quantizer, (ScaledFakeQuantize, SequentialQuantize)
                    ):
                        weight_quantizers: list[ScaledFakeQuantize] | SequentialQuantize = (
                            [module._weight_quantizer]
                            if isinstance(module._weight_quantizer, ScaledFakeQuantize)
                            else module._weight_quantizer
                        )
                        is_static_not_quantized = all(
                            hasattr(quantizer, "scale") for quantizer in weight_quantizers
                        ) and all(
                            quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
                            for quantizer in weight_quantizers
                        )

                        if is_static_not_quantized:
                            # This condition prevents layers that have already been quantized from being quantized again.
                            if module.weight.device == torch.device("meta"):
                                weight = module._hf_hook.weights_map["weight"].data
                                weight = module.get_quant_weight(weight.to(module._hf_hook.execution_device))
                                del weight
                            else:
                                _ = module.get_quant_weight(module.weight)

                    if module._bias_quantizer is not None and isinstance(
                        module._bias_quantizer, (ScaledFakeQuantize, SequentialQuantize)
                    ):
                        bias_quantizers: list[ScaledFakeQuantize] | SequentialQuantize = (
                            [module._bias_quantizer]
                            if isinstance(module._bias_quantizer, ScaledFakeQuantize)
                            else module._bias_quantizer
                        )
                        is_static_not_quantized = all(
                            hasattr(quantizer, "scale") for quantizer in bias_quantizers
                        ) and all(
                            quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
                            for quantizer in bias_quantizers
                        )

                        if is_static_not_quantized:
                            if module.bias.device == torch.device("meta"):
                                bias = module._hf_hook.weights_map["bias"].data
                                _ = module.get_quant_bias(bias.to(module._hf_hook.execution_device))
                                del bias
                            else:
                                _ = module.get_quant_bias(module.bias)

            torch.cuda.empty_cache()
            clear_memory()
            logger.info("Weight calibration end.")
        else:
            logger.info("Calibration start.")
            for module in model.modules():
                if isinstance(module, ScaledFakeQuantize):
                    module.enable_observer()
                    module.disable_fake_quant()

            assert dataloader is not None
            with torch.no_grad():
                for data in tqdm(dataloader):
                    if isinstance(data, dict):  # pragma: no cover
                        model(**data)
                    elif is_transformers_available() and isinstance(data, BatchFeature):  # pragma: no cover
                        _ = model(**data)
                    else:
                        model(data)

            if QUARK_COUNT_OBSERVED_SAMPLES:
                self._check_token_distribution(model, dataloader)

            clear_memory()
            logger.info("Calibration end.")

        logger.info("Model quantization has been completed.")

        # Step 5 [optional]: prepare for evaluation, turn quantization on.
        # Some algorithms handle the quantization of weights, bypassing the attached quantizers.
        override_weight_quantizers = False
        if self.config.algo_config is not None:
            for algo_config in self.config.algo_config:
                if algo_config.name.lower() in ["gptq", "gptaq", "qronos"] and not algo_config.static_groups:  # type: ignore
                    override_weight_quantizers = True
                    break

        if override_weight_quantizers:
            logger.warning(
                f"The algorithm configuration {self.config.algo_config} with a dynamic groups setting (static_groups=False) does not support FakeQuantize for export. "
                f"Turning off FakeQuantize for weights while keeping FakeQuantize enabled for activations to run evaluations."
            )
            named_modules = dict(model.named_modules(remove_duplicate=False))
            for _, module in tqdm(named_modules.items()):
                if isinstance(module, QuantMixin):
                    if module._weight_quantizer is not None:
                        enable_or_disable_quantizer(module._weight_quantizer, enable=False)
                    if module._input_quantizer is not None:
                        enable_or_disable_quantizer(module._input_quantizer, enable=True)
                    if module._output_quantizer is not None:
                        enable_or_disable_quantizer(module._output_quantizer, enable=True)
        else:
            for name, module in model.named_modules():
                if isinstance(module, ScaledFakeQuantize):
                    if module.is_dynamic and not (module.is_scale_quant and module.qscheme == QSchemeType.per_tensor):
                        # For dynamic quantization, the observer stays enabled and updates the qparams at every call.
                        module.enable_observer()
                        module.enable_fake_quant()
                    else:
                        module.disable_observer()
                        module.enable_fake_quant()
                elif isinstance(module, NonScaledFakeQuantize):
                    module.enable_fake_quant()

        return model

    def _do_post_calib_optimazation(self, model: nn.Module) -> nn.Module:
        """
        In some cases:

        1. After calibration we have the weight, activation and bias scales.
        2. Some hardware constraints require bias_scale = weight_scale * act_scale.

        So after calibration we need to do some optimization, and then perform QAT/export.
        """
        if self.config.quant_mode is QuantizationMode.eager_mode:
            # Keep this API for now. TODO
            assert isinstance(model, nn.Module)
            return model
        elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
            """
            During calibration, the observers record the tensors' distributions, and the scales and
            zero points are calculated. In some hardware-constrained cases, e.g.
            b_scale = w_scale * a_scale, we need to modify the bias scale after calibration.
            """
            assert isinstance(model, torch.fx.GraphModule)
            model = post_calib_optimize(model)
            return model  # type: ignore[no-any-return]

    def init_config(self) -> None:
        logger.info("Configuration checking start.")

        # TODO: Verify the quantization algorithm configuration.
        self.config_verifier.verify_config()

        if self.config_verifier.is_weight_only:
            config_parsing_result = "weight only quantization"
        else:
            if self.config_verifier.is_act_dynamic:
                config_parsing_result = "weight quantization and activation dynamic quantization"
            else:
                config_parsing_result = "weight quantization and activation static quantization"

        logger.info(f"Configuration checking end. The configuration is effective. This is {config_parsing_result}.")
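
# Illustrative shape of the `structure` dictionary walked by `get_name_and_info` below (a sketch
# inferred from the accesses made in this file; the layer name is hypothetical):
#
#     {
#         "decoder": {
#             "layers": {
#                 "0": {
#                     "fc1": {
#                         "type": "QuantLinear",
#                         "weight": "decoder.layers.0.fc1.weight",
#                         "bias": "decoder.layers.0.fc1.bias",
#                         "weight_quant": {...},
#                         "input_quant": {...},
#                         "output_quant": {...},
#                     },
#                 },
#             },
#         },
#     }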
def get_name_and_info(model_info: dict[str, Any], parent_key: str = "") -> Iterable[tuple[str, dict[str, Any]]]:
    for key, value in model_info.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            if value.get("type", None) is not None and value.get("weight", None) is not None:
                yield new_key, value
            else:
                yield from get_name_and_info(value, new_key)
        else:
            continue


# TODO: This function is only used in load_params, add support for SequentialQuantize later.
def from_float_and_dict(
    float_module: nn.Module,
    quant_info: dict[str, Any],
    param_dict: dict[str, torch.Tensor],
    layer_name: str,
    compressed: bool = False,
    reorder: bool = True,
) -> nn.Module:
    input_tensors = None
    quant_params: dict[str, torch.Tensor | None] = {}
    if quant_info.get("input_quant") is not None:
        input_tensors = QTensorConfig.from_dict(quant_info["input_quant"])
        if input_tensors is not None and not input_tensors.is_dynamic:
            quant_params["input_scale"] = param_dict[layer_name + ".input_scale"]  # pragma: no cover
            quant_params["input_zero_point"] = param_dict[layer_name + ".input_zero_point"]  # pragma: no cover

    output_tensors = None
    if quant_info.get("output_quant") is not None:
        output_tensors = QTensorConfig.from_dict(quant_info["output_quant"])
        if output_tensors is not None and not output_tensors.is_dynamic:
            quant_params["output_scale"] = param_dict[layer_name + ".output_scale"]
            quant_params["output_zero_point"] = param_dict[layer_name + ".output_zero_point"]

    weight_qspec: QTensorConfig | None = None
    weight_key = quant_info.get("weight")
    if weight_key is None:
        raise KeyError("Missing 'weight' in quant_info")
    weight_tensor = param_dict[weight_key]
    if quant_info.get("weight_quant") is not None:
        weight_qspec = QTensorConfig.from_dict(quant_info["weight_quant"])
        weight_scale = param_dict[layer_name + ".weight_scale"]
        weight_zero_point = param_dict[layer_name + ".weight_zero_point"]

        if compressed:
            assert isinstance(weight_qspec, QTensorConfig), "weight_qspec must be a QTensorConfig instance"
            assert isinstance(weight_qspec.qscheme, QSchemeType), "weight_qspec.qscheme must be a QSchemeType instance"
            assert isinstance(weight_qspec.dtype, Dtype), "weight_qspec.dtype must be a Dtype instance"
            pack_method = create_pack_method(qscheme=weight_qspec.qscheme.value, dtype=weight_qspec.dtype.value)
            weight_tensor = pack_method.unpack(
                weight_tensor,
                reorder,
                **({"origin_packed_axis_size": weight_scale.shape[-1]} if weight_scale.shape != torch.Size([]) else {}),
            )
            weight_tensor = quark.torch.kernel.dequantize(  # type: ignore[attr-defined]
                weight_qspec.dtype.value,
                weight_tensor,
                weight_scale,
                weight_zero_point,
                weight_qspec.ch_axis,
                weight_qspec.group_size,
                weight_qspec.qscheme.value,
            )

        quant_params["weight_scale"] = weight_scale
        quant_params["weight_zero_point"] = weight_zero_point

    module_config = QLayerConfig(input_tensors=input_tensors, output_tensors=output_tensors, weight=weight_qspec)

    bias_tensor = None
    bias_key = quant_info.get("bias")
    bias_tensor = param_dict[bias_key] if bias_key is not None else None

    quant_module: nn.Module
    if quant_info["type"] in QUARK_QUANT_OPS:
        quant_module = QUARK_QUANT_OPS[quant_info["type"]].from_float(
            float_module,
            module_config,
            reload=True,
            weight_tensor=weight_tensor,
            bias_tensor=bias_tensor,
        )
    else:
        raise ValueError(f"The type {quant_info['type']} is not supported in Quark for now!")
    quant_module.load_quant_params(quant_params)
    return quant_module


# TODO: Add support for SequentialQuantize later.
# TODO: Better `reorder` documentation.
# TODO: Consider deprecating this - is anybody really using it?
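
# Illustrative `param_dict` keys consumed by `from_float_and_dict` above for a layer named
# "decoder.layers.0.fc1" (hypothetical name, static input and weight quantization):
#
#     decoder.layers.0.fc1.weight_scale
#     decoder.layers.0.fc1.weight_zero_point
#     decoder.layers.0.fc1.input_scale
#     decoder.layers.0.fc1.input_zero_point
#
# plus the raw weight tensor stored under the key given by quant_info["weight"].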
@log_errors
def load_params(
    model: nn.Module | None = None,
    json_path: str = "",
    safetensors_path: str = "",
    pth_path: str = "",
    quant_mode: QuantizationMode = QuantizationMode.eager_mode,
    compressed: bool = False,
    reorder: bool = True,
) -> nn.Module:
    """
    Instantiates a quantized model from saved model files, which are generated by the
    :py:func:`quark.torch.export.api.save_params` function.

    :param torch.nn.Module model: The original PyTorch model.
    :param str json_path: The path of the saved json file. Only available for eager mode quantization.
    :param str safetensors_path: The path of the saved safetensors file. Only available for eager mode quantization.
    :param str pth_path: The path of the saved ``.pth`` file. Only available for ``fx_graph`` mode quantization.
    :param QuantizationMode quant_mode: The quantization mode. The choices are
        ``QuantizationMode.eager_mode`` and ``QuantizationMode.fx_graph_mode``. Default is
        ``QuantizationMode.eager_mode``.
    :param bool compressed: Whether the quantized model to load is stored using its compressed data
        type, or in a "fake quantized" format (QDQ).
    :param bool reorder: Reorder.
    :return: The reloaded quantized version of the input model.
    :rtype: torch.nn.Module

    Examples:

        .. code-block:: python

            # eager mode:
            from quark.torch import load_params
            model = load_params(model, json_path=json_path, safetensors_path=safetensors_path)

        .. code-block:: python

            # fx_graph mode:
            from quark.torch.quantization.api import load_params
            model = load_params(pth_path=model_file_path, quant_mode=QuantizationMode.fx_graph_mode)

    Note:
        This function does not support dynamic quantization for now.
    """
    if quant_mode is QuantizationMode.eager_mode:
        if not is_safetensors_available():
            raise ImportError(
                "The function `load_params` with `quant_mode=QuantizationMode.eager_mode` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
            )

        if model is None:
            raise ValueError("Model should not be None if loading an eager_mode quantized model.")
        if json_path == "" or safetensors_path == "":
            raise ValueError("json_path and safetensors_path should not be empty if loading an eager_mode quantized model.")

        # Load the model structure and parameters.
        with open(json_path) as file:
            model_dict = json.load(file)
        params_dict = load_file(safetensors_path)

        # Verify that the exported model and the float model have the same configuration.
        model_config = model_dict["config"]
        if model_config:
            float_model_config: dict[str, Any] = {}
            if hasattr(model.config, "to_diff_dict"):
                float_model_config = model.config.to_diff_dict()
            elif hasattr(model.config, "items"):
                float_model_config = dict(model.config.items())

            if not deep_compare(model_config, float_model_config):
                raise RuntimeError("Exported model and float model are not the same model!")
            # assert ((json.dumps(model_config) == json.dumps(float_model_config)),
            #         "Exported model and float model are not the same model!")

        logger.info("In-place OPs replacement start.")
        for name, module_info in get_name_and_info(model_dict["structure"]):
            float_module = getattr_recursive(model, name)
            if module_info["type"] in QUARK_QUANT_OPS:
                module = from_float_and_dict(
                    float_module, module_info, params_dict, layer_name=name, compressed=compressed, reorder=reorder
                )
                setattr_recursive(model, name, module)
            else:
                device = float_module.weight.device
                weight_key = module_info.get("weight")
                if weight_key is not None:
                    float_module.weight.data = params_dict[weight_key].to(device)
                bias_key = module_info.get("bias")
                if bias_key is not None:
                    float_module.bias.data = params_dict[bias_key].to(device)

        # Convert the relevant quantizers to `FrozenScaledFakeQuantize`. This essentially removes
        # the observers from weight quantizers. We disable running quantization, as the loaded
        # weights are already quantized.
        model = ModelQuantizer.freeze(model, quantize=False)

        logger.info("In-place OPs replacement end.")
    elif quant_mode is QuantizationMode.fx_graph_mode:
        if pth_path == "":
            raise ValueError("pth_path should not be empty if loading an fx_graph_mode quantized model.")

        loaded_quantized_ep = torch.export.load(pth_path)
        model = loaded_quantized_ep.module()

    return model