Source code for quark.torch.quantization.api

#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for PyTorch."""

import json
import logging
from dataclasses import fields
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

import torch
import torch.fx
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

import quark.torch.kernel
from quark.shares.utils.import_utils import is_safetensors_available, is_transformers_available
from quark.shares.utils.log import ScreenLogger, log_errors
from quark.torch.algorithm.api import add_algorithm_config_by_model, apply_advanced_quant_algo
from quark.torch.algorithm.utils.utils import clear_memory
from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
from quark.torch.quantization.config.config_verification import check_and_adjust_quant_config, init_quantization_config
from quark.torch.quantization.config.type import Dtype, QSchemeType, QuantizationMode
from quark.torch.quantization.graph.processor.pre_check_befor_quant import check_supported_model_and_config
from quark.torch.quantization.graph.processor.processor import (
    post_calib_optimize,
    post_quant_optimize,
    prepare_quant_model,
)
from quark.torch.quantization.model_transformation import process_model_transformation
from quark.torch.quantization.nn.modules import (
    QuantConv2d,
    QuantConvTranspose2d,
    QuantEmbedding,
    QuantEmbeddingBag,
    QuantLinear,
)
from quark.torch.quantization.nn.modules.mixin import QuantMixin
from quark.torch.quantization.tensor_quantize import (
    FakeQuantizeBase,
    NonScaledFakeQuantize,
    ScaledFakeQuantize,
    SequentialQuantize,
    enable_or_disable_quantizer,
)
from quark.torch.quantization.utils import count_calibration_tokens, deep_compare
from quark.torch.utils import getattr_recursive, setattr_recursive
from quark.torch.utils.pack import create_pack_method
from quark.torch.utils.profile import gpu_memory_profiled

if is_transformers_available():
    from transformers.feature_extraction_utils import BatchFeature

import os
from collections import Counter
from pathlib import Path

from quark.torch.quantization.debug import check_scale_stats, collect_quantization_statistics, insert_stats_hooks

if is_safetensors_available():
    from safetensors.torch import load_file

__all__ = ["ModelQuantizer", "load_params"]

logger = ScreenLogger(__name__)

QUARK_QUANT_OPS: dict[
    str, type[Union[QuantConv2d, QuantConvTranspose2d, QuantLinear, QuantEmbedding, QuantEmbeddingBag]]
] = {
    "QuantConv2d": QuantConv2d,
    "QuantConvTranspose2d": QuantConvTranspose2d,
    "QuantLinear": QuantLinear,
    "QuantEmbedding": QuantEmbedding,
    "QuantEmbeddingBag": QuantEmbeddingBag,
}


class ModelQuantizer:
    """
    Provides an API for quantizing deep learning models using PyTorch. This class handles the configuration
    and processing of the model for quantization based on user-defined parameters. It is essential to ensure
    that the ``config`` provided has all necessary quantization parameters defined. This class assumes that
    the model is compatible with the quantization settings specified in ``config``.

    :param Config config: The model quantization configuration.
    """

    def __init__(self, config: Config, multi_device: bool = False) -> None:
        self.config = config
        self.is_all_dynamic: bool | None = None
        self.is_weight_only: bool | None = None
        self.is_act_dynamic: bool | None = None
        self.is_act_contain_scale_per_tensor: bool | None = None
        self._is_accelerate: bool | None = None
        self.multi_device: bool = multi_device
        self.init_config()

    def set_logging_level(self) -> None:
        if self.config.log_severity_level == 0:
            ScreenLogger.set_shared_level(logging.DEBUG)
        elif self.config.log_severity_level == 1:
            ScreenLogger.set_shared_level(logging.INFO)
        elif self.config.log_severity_level == 2:
            ScreenLogger.set_shared_level(logging.WARNING)
        elif self.config.log_severity_level == 3:
            ScreenLogger.set_shared_level(logging.ERROR)
        else:
            ScreenLogger.set_shared_level(logging.CRITICAL)
    @gpu_memory_profiled(tag=" QuantizeModel")  # type: ignore[arg-type]
    def quantize_model(
        self,
        model: nn.Module,
        dataloader: Union[
            DataLoader[torch.Tensor],
            DataLoader[list[dict[str, torch.Tensor]]],
            DataLoader[dict[str, torch.Tensor]],
            DataLoader[list["BatchFeature"]],
        ]
        | None = None,
    ) -> nn.Module:
        """
        Quantizes the given PyTorch model to optimize its performance and reduce its size. The dataloader
        provides the data used for calibration during the quantization process. Depending on the type of data
        provided (tensors directly, or structured as lists or dictionaries of tensors), the function adapts
        the quantization approach accordingly.

        It is important that the model and dataloader are compatible in terms of the data they expect and
        produce. A misalignment in data handling between the model and the dataloader can lead to errors
        during the quantization process.

        :param torch.nn.Module model: The PyTorch model to be quantized. This model should be already trained and ready for quantization.
        :param Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]], DataLoader[List[BatchFeature]]]] dataloader: The ``torch.utils.data.DataLoader`` providing the data that the quantization process will use for calibration. This can be a simple ``DataLoader`` returning tensors, or a more complex structure returning either a list of dictionaries or a dictionary of tensors.
        :return: The quantized version of the input model. This model is optimized for inference, with a reduced size and potentially improved performance on targeted devices.
        :rtype: torch.nn.Module

        Example:

            .. code-block:: python

                # Model & Data preparation
                from torch.utils.data import DataLoader
                from transformers import AutoModelForCausalLM, AutoTokenizer

                from quark.torch import ModelQuantizer
                from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
                from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
                from quark.torch.quantization.observer.observer import PerGroupMinMaxObserver

                model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
                model.eval()
                tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

                quant_spec = QuantizationSpec(
                    dtype=Dtype.uint4,
                    observer_cls=PerGroupMinMaxObserver,
                    symmetric=False,
                    scale_type=ScaleType.float,
                    round_method=RoundType.half_even,
                    qscheme=QSchemeType.per_group,
                    ch_axis=1,
                    is_dynamic=False,
                    group_size=128,
                )
                quant_config = Config(global_quant_config=QuantizationConfig(weight=quant_spec))

                text = "Hello, how are you?"
                tokenized_outputs = tokenizer(text, return_tensors="pt")
                calib_dataloader = DataLoader(tokenized_outputs['input_ids'])

                quantizer = ModelQuantizer(quant_config)
                quant_model = quantizer.quantize_model(model, calib_dataloader)
        """
        logger.info(f"Quantizing with the quantization configuration:\n{self.config}")

        # Step 0: pre-quantization device check
        self._check_model_device(model)

        # Step 1: prepare the quantization model for graph mode and eager mode
        model = self._prepare_model(model)

        # Step 2 [optional]: apply an advanced quantization algorithm such as GPTQ, AWQ, Qronos, ...
        model = self._apply_advanced_quant_algo(model, dataloader)

        # Step 3 [optional]: run calibration
        model = self._do_calibration(model, dataloader)

        # Step 4 [optional]: post-calibration optimization
        model = self._do_post_calib_optimization(model)

        # Optionally, collect statistics on the quantization errors over the network weights/activations.
        if os.environ.get("QUARK_DEBUG", None) is not None:
            log_dir = Path(os.environ["QUARK_DEBUG"])
            log_dir.mkdir(parents=True, exist_ok=True)

            stats: dict[str, Any] = {}

            dataloader = dataloader if not self.is_all_dynamic else None

            with insert_stats_hooks(model, stats, log_dir):
                collect_quantization_statistics(model, dataloader, stats, log_dir)

        # Check the scales of the quantized model.
        if os.getenv("QUARK_CHECK_SCALE") == "1":
            check_scale_stats(model, self.config)

        # Attach quant_config to the quantized model so that it can be used for export.
        model.quant_config = self.config

        # Add a flag to indicate that the model is quantized.
        model.quark_quantized = True

        return model
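    # Illustrative, hedged sketch (not part of the original source) of how the debugging hooks above are
    # typically driven from the environment; the `quantizer` and `calib_dataloader` names are examples only.
    #
    #     import os
    #     os.environ["QUARK_DEBUG"] = "./quark_debug_logs"  # collect weight/activation quantization statistics
    #     os.environ["QUARK_CHECK_SCALE"] = "1"             # additionally check the scales of the quantized model
    #     quant_model = quantizer.quantize_model(model, calib_dataloader)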
    def _check_model_device(self, model: nn.Module) -> None:
        # When using accelerate, the device cannot be "cpu" or "disk" for now.
        if hasattr(model, "hf_device_map"):
            if not self.multi_device:
                for _, layer_device in model.hf_device_map.items():
                    if layer_device == "cpu" or layer_device == "disk":
                        # TODO: We should handle this for customers.
                        raise MemoryError(
                            "Out of memory. The available GPU memory is insufficient to load the entire model. "
                            "You can try adding '--multi_device'."
                        )
            self._is_accelerate = True
        else:
            self._is_accelerate = False

    def _generate_complete_config_by_model(
        self,
        model: nn.Module,
        dataloader: Union[
            DataLoader[torch.Tensor],
            DataLoader[list[dict[str, torch.Tensor]]],
            DataLoader[dict[str, torch.Tensor]],
            DataLoader[list["BatchFeature"]],
            None,
        ],
    ) -> None:
        """
        Generates a complete configuration based on the provided model and dataloader.
        """
        self.config = add_algorithm_config_by_model(model, dataloader, self.config)
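    # Illustrative, hedged sketch (not part of the original source): models dispatched with accelerate expose
    # `hf_device_map`; when parts of the model are offloaded, the quantizer can be constructed with
    # `multi_device=True` so that the device check above does not raise. The model name is an example only.
    #
    #     model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", device_map="auto")
    #     quantizer = ModelQuantizer(quant_config, multi_device=True)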
    @staticmethod
    def freeze(model: Union[nn.Module, torch.fx.GraphModule]) -> Union[nn.Module, torch.fx.GraphModule]:
        """
        Freezes the quantized model by replacing ``FakeQuantize`` modules with ``FrozenFakeQuantize`` modules.
        This method needs to be applied in order to be able to compile a quantized model through ``torch.compile``.

        :param torch.nn.Module model: The neural network model containing quantized layers.
        :return: The modified model with ``FakeQuantize`` modules replaced by ``FrozenFakeQuantize`` modules.
        :rtype: torch.nn.Module
        """
        logger.info("Freeze model start.")

        # ---- Replace FakeQuantize with FrozenFakeQuantize. --------------
        named_modules = dict(model.named_modules(remove_duplicate=False))
        for name, module in named_modules.items():
            if isinstance(module, FakeQuantizeBase):
                if module.is_dynamic:
                    # TODO: Add freeze for dynamic models.
                    logger.warning("Cannot freeze a dynamically quantized model for now. Keeping FakeQuantize.")
                else:
                    frozen_quantized_module = module.to_frozen_module()
                    setattr_recursive(model, name, frozen_quantized_module)

        # ---- If the model was quantized in fx.graph mode. --------------
        if isinstance(model, torch.fx.GraphModule):
            model = model.freeze_model()
            assert isinstance(model, torch.fx.GraphModule)
            model = post_quant_optimize(model=model, hw_constrain=True)  # TODO: pass argument

        logger.info("Freeze model end.")
        return model
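    # Illustrative, hedged sketch (not part of the original source) of the intended freeze/compile flow;
    # the variable names are examples only.
    #
    #     quant_model = quantizer.quantize_model(model, calib_dataloader)
    #     frozen_model = ModelQuantizer.freeze(quant_model)
    #     compiled_model = torch.compile(frozen_model)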
    def _prepare_model(self, model: nn.Module) -> nn.Module:
        if self.config.quant_mode is QuantizationMode.eager_mode:
            return process_model_transformation(model, self.config)
        elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
            # Quantization with torch.fx does not support some quantization configs and some FX graphs.
            # This raises an error if the config / model used are not supported.
            check_supported_model_and_config(model, self.config)  # type: ignore [arg-type]
            return prepare_quant_model(model, self.config).eval()  # type: ignore [arg-type]

    def _apply_advanced_quant_algo(
        self,
        model: nn.Module,
        dataloader: Union[
            DataLoader[torch.Tensor],
            DataLoader[list[dict[str, torch.Tensor]]],
            DataLoader[dict[str, torch.Tensor]],
            DataLoader[list["BatchFeature"]],
        ]
        | None = None,
    ) -> nn.Module:
        return apply_advanced_quant_algo(model, self.config, self._is_accelerate, dataloader)

    def _check_token_distribution(
        self,
        model: nn.Module,
        dataloader: Union[
            DataLoader[torch.Tensor],
            DataLoader[list[dict[str, torch.Tensor]]],
            DataLoader[dict[str, torch.Tensor]],
            DataLoader[list["BatchFeature"]],
        ],
    ) -> None:
        """
        A helper function that warns when a MoE module received 0 tokens throughout the calibration process.
        """
        threshold = (
            float(os.environ["TOKEN_DISTRIBUTION_THRESHOLD"])
            if os.getenv("TOKEN_DISTRIBUTION_THRESHOLD") is not None
            else 0.0
        )
        assert 0.0 <= threshold <= 1.0, "threshold should be in [0.0, 1.0]"

        total_token_count = count_calibration_tokens(dataloader)
        if total_token_count == 0:
            logger.warning("No tokens found in calibration dataset. Skipping token distribution check.")
            return

        # Get the observed token count for each module.
        token_counts: Counter[str] = Counter()
        for name, module in model.named_modules():
            if isinstance(module, ScaledFakeQuantize):
                if "_input_quantizer" in name:
                    if module.observer._num_observed_tokens is not None:
                        token_counts[name.replace("._input_quantizer", "")] = module.observer._num_observed_tokens

        for module_name, token_count in token_counts.items():
            if (token_count / float(total_token_count)) <= threshold:
                logger.warning(
                    f"The module {module_name} "
                    f"received {token_count} tokens, which is less than {threshold * 100:.1f}% "
                    f"of all {total_token_count} calibration tokens."
                )

    # When using multi_device, torch.no_grad() must be applied here or offloading will fail:
    # the GPU memory used for gradients cannot be cleaned up by torch.cuda.empty_cache().
    @torch.no_grad()
    def _do_calibration(
        self,
        model: nn.Module,
        dataloader: Union[
            DataLoader[torch.Tensor],
            DataLoader[list[dict[str, torch.Tensor]]],
            DataLoader[dict[str, torch.Tensor]],
            DataLoader[list["BatchFeature"]],
        ]
        | None = None,
    ) -> nn.Module:
        # Just calibrate; turn off quantization.
        if self.is_all_dynamic:  # TODO: to be deprecated
            logger.info("Dynamic quantization, no calibration.")
        elif self.is_weight_only or (self.is_act_dynamic and not self.is_act_contain_scale_per_tensor):
            logger.info("Weight calibration start.")
            for module in model.modules():
                if isinstance(module, ScaledFakeQuantize):
                    module.enable_observer()
                    module.disable_fake_quant()

            # Simply run through the observers to set the min_val, max_val, scale and zero_point buffers for the weight and bias.
            named_modules = dict(model.named_modules(remove_duplicate=False))
            for name, module in tqdm(named_modules.items()):
                if isinstance(module, QuantMixin):
                    if module._weight_quantizer is not None and isinstance(
                        module._weight_quantizer, (ScaledFakeQuantize, SequentialQuantize)
                    ):
                        weight_quantizers: Union[list[ScaledFakeQuantize], SequentialQuantize] = (
                            [module._weight_quantizer]
                            if isinstance(module._weight_quantizer, ScaledFakeQuantize)
                            else module._weight_quantizer
                        )
                        if all(
                            quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
                            for quantizer in weight_quantizers
                        ):
                            # This condition prevents layers that have already been quantized from being quantized again.
                            if module.weight.device == torch.device("meta"):
                                weight = module._hf_hook.weights_map["weight"].data
                                weight = module.get_quant_weight(weight.to(module._hf_hook.execution_device))
                                del weight
                            else:
                                _ = module.get_quant_weight(module.weight)

                    if module._bias_quantizer is not None and isinstance(
                        module._bias_quantizer, (ScaledFakeQuantize, SequentialQuantize)
                    ):
                        bias_quantizers: Union[list[ScaledFakeQuantize], SequentialQuantize] = (
                            [module._bias_quantizer]
                            if isinstance(module._bias_quantizer, ScaledFakeQuantize)
                            else module._bias_quantizer
                        )
                        if all(
                            quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
                            for quantizer in bias_quantizers
                        ):
                            if module.bias.device == torch.device("meta"):
                                bias = module._hf_hook.weights_map["bias"].data
                                _ = module.get_quant_bias(bias.to(module._hf_hook.execution_device))
                                del bias
                            else:
                                _ = module.get_quant_bias(module.bias)

            torch.cuda.empty_cache()
            clear_memory()
            logger.info("Weight calibration end.")
        else:
            logger.info("Calibration start.")
            for module in model.modules():
                if isinstance(module, ScaledFakeQuantize):
                    module.enable_observer()
                    module.disable_fake_quant()

            assert dataloader is not None
            with torch.no_grad():
                for data in tqdm(dataloader):
                    if isinstance(data, dict):  # pragma: no cover
                        model(**data)
                    elif is_transformers_available() and isinstance(data, BatchFeature):  # pragma: no cover
                        _ = model(**data)
                    else:
                        model(data)

            self._check_token_distribution(model, dataloader)
            clear_memory()
            logger.info("Calibration end.")

        logger.info("Model quantization has been completed.")

        # Step 5 [optional]: do evaluation; turn quantization on.
        if self.config.algo_config is not None and any(
            cfg.name == "gptq" and hasattr(cfg, "static_groups") and cfg.static_groups is False
            for cfg in self.config.algo_config
        ):
            logger.warning(
                "Dynamic groups in GPTQ (static_groups=false) do not support FakeQuantize for export. "
                "FakeQuantize is turned off for weights but kept on for activations in order to run evaluations."
            )
            named_modules = dict(model.named_modules(remove_duplicate=False))
            for _, module in tqdm(named_modules.items()):
                if isinstance(module, QuantMixin):
                    if module._weight_quantizer is not None:
                        enable_or_disable_quantizer(module._weight_quantizer, enable=False)
                    if module._input_quantizer is not None:
                        enable_or_disable_quantizer(module._input_quantizer, enable=True)
                    if module._output_quantizer is not None:
                        enable_or_disable_quantizer(module._output_quantizer, enable=True)
        else:
            for name, module in model.named_modules():
                if isinstance(module, ScaledFakeQuantize):
                    if module.is_dynamic and not (
                        module.is_scale_quant and module.qscheme == QSchemeType.per_tensor
                    ):
                        # For dynamic quantization, the observer should stay enabled and the qparams should be updated every time.
                        module.enable_observer()
                        module.enable_fake_quant()
                    else:
                        module.disable_observer()
                        module.enable_fake_quant()
                elif isinstance(module, NonScaledFakeQuantize):
                    module.enable_fake_quant()

        return model

    def _do_post_calib_optimization(self, model: nn.Module) -> nn.Module:
        """
        In some cases:

        1. After calibration we obtain the weight, activation and bias scales.
        2. Some hardware constraints require that bias_scale = weight_scale * act_scale.

        After calibration, we therefore need to do some optimization before performing QAT/export.
        """
        if self.config.quant_mode is QuantizationMode.eager_mode:
            # Keep this API as-is for now. TODO
            assert isinstance(model, nn.Module)
            return model
        elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
            # During calibration the observers record each tensor's distribution, and the scale and zero point
            # are calculated. In some hardware-constrained cases, e.g. b_scale = w_scale * a_scale, the bias
            # scale needs to be modified after calibration.
            assert isinstance(model, torch.fx.GraphModule)
            model = post_calib_optimize(model)
            return model  # type: ignore[no-any-return]

    def init_config(self) -> None:
        self.set_logging_level()  # set log level: default is info
        logger.info("Configuration checking start.")
        config = self.config
        # TODO: Verify quant algo
        for field in fields(Config):
            if field.name in ["global_quant_config"]:
                quantization_config = getattr(config, field.name)
                _config = check_and_adjust_quant_config(quantization_config)
                setattr(self.config, field.name, _config)
                self.is_all_dynamic, self.is_weight_only, self.is_act_dynamic, self.is_act_contain_scale_per_tensor = (
                    init_quantization_config(quantization_config)
                )
            elif field.name in ["layer_type_quant_config", "layer_quant_config"]:
                quantization_config_list = getattr(config, field.name)
                for quantization_config in quantization_config_list.values():
                    (
                        self.is_all_dynamic,
                        self.is_weight_only,
                        self.is_act_dynamic,
                        self.is_act_contain_scale_per_tensor,
                    ) = init_quantization_config(quantization_config)

        if self.is_weight_only:
            config_parsing_result = "weight only quantization"
        elif self.is_act_dynamic:
            config_parsing_result = "weight quantization and activation dynamic quantization"
        else:
            config_parsing_result = "weight quantization and activation static quantization"

        logger.info(f"Configuration checking end. The configuration is effective. This is {config_parsing_result}.")
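# Illustrative, hedged sketch (not part of the original source): besides `global_quant_config`, the
# configuration check above also iterates over `layer_type_quant_config` and `layer_quant_config`,
# which map layer types / layer names to their own `QuantizationConfig`. The layer name and the reuse
# of `quant_spec` below are examples only.
#
#     quant_config = Config(
#         global_quant_config=QuantizationConfig(weight=quant_spec),
#         layer_quant_config={"model.decoder.layers.0.self_attn.q_proj": QuantizationConfig(weight=quant_spec)},
#     )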
def get_name_and_info(model_info: dict[str, Any], parent_key: str = "") -> Iterable[tuple[str, dict[str, Any]]]:
    for key, value in model_info.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            if value.get("type", None) is not None and value.get("weight", None) is not None:
                yield new_key, value
            else:
                yield from get_name_and_info(value, new_key)
        else:
            continue


# TODO: This function is only used in load_params; add support for SequentialQuantize later.
def from_float_and_dict(
    float_module: nn.Module,
    quant_info: dict[str, Any],
    param_dict: dict[str, torch.Tensor],
    layer_name: str,
    compressed: bool = False,
    reorder: bool = True,
) -> nn.Module:
    input_tensors = None
    quant_params: dict[str, torch.Tensor | None] = {}
    if quant_info.get("input_quant") is not None:
        input_tensors = QuantizationSpec.from_dict(quant_info["input_quant"])
        quant_params["input_scale"] = param_dict[layer_name + ".input_scale"]  # pragma: no cover
        quant_params["input_zero_point"] = param_dict[layer_name + ".input_zero_point"]  # pragma: no cover

    output_tensors = None
    if quant_info.get("output_quant") is not None:
        output_tensors = QuantizationSpec.from_dict(quant_info["output_quant"])
        quant_params["output_scale"] = param_dict[layer_name + ".output_scale"]
        quant_params["output_zero_point"] = param_dict[layer_name + ".output_zero_point"]

    weight_qspec: QuantizationSpec | None = None
    weight_key = quant_info.get("weight")
    if weight_key is None:
        raise KeyError("Missing 'weight' in quant_info")
    weight_tensor = param_dict[weight_key]
    if quant_info.get("weight_quant") is not None:
        weight_qspec = QuantizationSpec.from_dict(quant_info["weight_quant"])
        weight_scale = param_dict[layer_name + ".weight_scale"]
        weight_zero_point = param_dict[layer_name + ".weight_zero_point"]

        if compressed:
            assert isinstance(weight_qspec, QuantizationSpec), "weight_qspec must be a QuantizationSpec instance"
            assert isinstance(weight_qspec.qscheme, QSchemeType), "weight_qspec.qscheme must be a QSchemeType instance"
            assert isinstance(weight_qspec.dtype, Dtype), "weight_qspec.dtype must be a Dtype instance"
            pack_method = create_pack_method(qscheme=weight_qspec.qscheme.value, dtype=weight_qspec.dtype.value)
            weight_tensor = pack_method.unpack(
                weight_tensor,
                reorder,
                **({"origin_packed_axis_size": weight_scale.shape[-1]} if weight_scale.shape != torch.Size([]) else {}),
            )
            weight_tensor = quark.torch.kernel.dequantize(  # type: ignore[attr-defined]
                weight_qspec.dtype.value,
                weight_tensor,
                weight_scale,
                weight_zero_point,
                weight_qspec.ch_axis,
                weight_qspec.group_size,
                weight_qspec.qscheme.value,
            )

        quant_params["weight_scale"] = weight_scale
        quant_params["weight_zero_point"] = weight_zero_point

    module_config = QuantizationConfig(input_tensors=input_tensors, output_tensors=output_tensors, weight=weight_qspec)

    bias_key = quant_info.get("bias")
    bias_tensor = param_dict[bias_key] if bias_key is not None else None

    quant_module: nn.Module
    if quant_info["type"] in QUARK_QUANT_OPS:
        quant_module = QUARK_QUANT_OPS[quant_info["type"]].from_float(
            float_module,
            module_config,
            reload=True,
            weight_tensor=weight_tensor,
            bias_tensor=bias_tensor,
        )
    else:
        raise ValueError(f"The type {quant_info['type']} is not supported in Quark for now!")

    quant_module.load_quant_params(quant_params)
    return quant_module
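# Illustrative, hedged sketch (not part of the original source) of the per-layer entry consumed by
# `from_float_and_dict` above, as implied by the keys it reads; the exact values are examples only.
#
#     quant_info = {
#         "type": "QuantLinear",                               # key into QUARK_QUANT_OPS
#         "weight": "model.layers.0.self_attn.q_proj.weight",  # key into the safetensors param_dict
#         "bias": None,
#         "weight_quant": {...},                               # serialized QuantizationSpec
#         "input_quant": None,
#         "output_quant": None,
#     }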
# TODO: add support for SequentialQuantize later
# TODO: better `reorder` doc
@log_errors
def load_params(
    model: nn.Module | None = None,
    json_path: str = "",
    safetensors_path: str = "",
    pth_path: str = "",
    quant_mode: QuantizationMode = QuantizationMode.eager_mode,
    compressed: bool = False,
    reorder: bool = True,
) -> nn.Module:
    """
    Instantiates a quantized model from saved model files, which are generated by the
    :py:func:`quark.torch.export.api.save_params` function.

    :param torch.nn.Module model: The original PyTorch model.
    :param str json_path: The path of the saved json file. Only available for eager mode quantization.
    :param str safetensors_path: The path of the saved safetensors file. Only available for eager mode quantization.
    :param str pth_path: The path of the saved ``.pth`` file. Only available for ``fx_graph`` mode quantization.
    :param QuantizationMode quant_mode: The quantization mode. The choices are ``QuantizationMode.eager_mode`` and ``QuantizationMode.fx_graph_mode``. Default is ``QuantizationMode.eager_mode``.
    :param bool compressed: Whether the quantized model to load is stored using its compressed data type, or in a "fake quantized" format (QDQ).
    :param bool reorder: Whether to reorder the weights when unpacking them from the compressed format (only used when ``compressed=True``).
    :return: The reloaded quantized version of the input model.
    :rtype: torch.nn.Module

    Examples:

        .. code-block:: python

            # eager mode:
            from quark.torch import load_params

            model = load_params(model, json_path=json_path, safetensors_path=safetensors_path)

        .. code-block:: python

            # fx_graph mode:
            from quark.torch.quantization.api import load_params

            model = load_params(pth_path=model_file_path, quant_mode=QuantizationMode.fx_graph_mode)

    Note:
        This function does not support dynamic quantization for now.
    """
    if quant_mode is QuantizationMode.eager_mode:
        if not is_safetensors_available():
            raise ImportError(
                "The function `load_params` with `quant_mode=QuantizationMode.eager_mode` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
            )

        if model is None:
            raise ValueError("Model should not be None when loading an eager_mode quantized model")
        if json_path == "" or safetensors_path == "":
            raise ValueError("json_path and safetensors_path should not be empty when loading an eager_mode quantized model")

        # Load the model structure and parameters.
        with open(json_path) as file:
            model_dict = json.load(file)
        params_dict = load_file(safetensors_path)

        # Verify that the exported model and the float model have the same configuration.
        model_config = model_dict["config"]
        if model_config:
            float_model_config: dict[str, Any] = {}
            if hasattr(model.config, "to_diff_dict"):
                float_model_config = model.config.to_diff_dict()
            elif hasattr(model.config, "items"):
                float_model_config = dict(model.config.items())

            if not deep_compare(model_config, float_model_config):
                raise RuntimeError("Exported model and float model are not the same model!")

        logger.info("In-place OPs replacement start.")
        for name, module_info in get_name_and_info(model_dict["structure"]):
            float_module = getattr_recursive(model, name)
            if module_info["type"] in QUARK_QUANT_OPS:
                module = from_float_and_dict(
                    float_module, module_info, params_dict, layer_name=name, compressed=compressed, reorder=reorder
                )
                setattr_recursive(model, name, module)
            else:
                device = float_module.weight.device
                weight_key = module_info.get("weight")
                if weight_key is not None:
                    float_module.weight.data = params_dict[weight_key].to(device)
                bias_key = module_info.get("bias")
                if bias_key is not None:
                    float_module.bias.data = params_dict[bias_key].to(device)

        model = ModelQuantizer.freeze(model)
        logger.info("In-place OPs replacement end.")

    elif quant_mode is QuantizationMode.fx_graph_mode:
        if pth_path == "":
            raise ValueError("pth_path should not be empty when loading an fx_graph_mode quantized model")
        loaded_quantized_ep = torch.export.load(pth_path)
        model = loaded_quantized_ep.module()

    return model