#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for PyTorch."""
import torch
import torch.nn as nn
import torch.fx
import json
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Dict, Any, Optional, Union, List, Tuple, Iterable, Type
from dataclasses import fields
from quark.shares.utils.import_utils import is_transformers_available
from quark.torch.quantization.config.type import QuantizationMode, Dtype, QSchemeType
from quark.torch.quantization.config.config_verification import check_and_adjust_quant_config
from quark.torch.quantization.model_transformation import process_model_transformation
from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
from quark.torch.quantization.config.config_verification import init_quantization_config, verify_quantization_spec
from quark.torch.quantization.graph.processor.processor import prepare_quant_model, post_calib_optimize
from quark.torch.quantization.graph.processor.pre_check_befor_quant import check_supported_model_and_config
from quark.torch.quantization.graph.processor.processor import post_quant_optimize
from quark.torch.quantization.utils import set_op_by_name, get_op_by_name
from quark.torch.quantization.nn.modules.mixin import QuantMixin
from quark.torch.quantization.tensor_quantize import FakeQuantizeBase, ScaledFakeQuantize, NonScaledFakeQuantize, SequentialQuantize, enable_or_disable_quantizer
from quark.torch.quantization.utils import deep_compare, count_calibration_tokens
from quark.shares.utils.log import ScreenLogger, log_errors
from quark.shares.utils.import_utils import is_safetensors_available
from quark.torch.algorithm.api import apply_pre_quantization_optimization, apply_advanced_quant_algo, add_algorithm_config_by_model
from quark.torch.algorithm.utils.utils import clear_memory
import logging
from quark.torch.quantization.nn.modules import QuantConv2d, QuantConvTranspose2d, QuantLinear, QuantEmbedding, QuantEmbeddingBag
from quark.torch.utils.pack import create_pack_method
import quark.torch.kernel
if is_transformers_available():
from transformers.feature_extraction_utils import BatchFeature
import os
from pathlib import Path
from collections import Counter
from quark.torch.quantization.debug import (insert_stats_hooks, collect_quantization_statistics, check_scale_stats)
if is_safetensors_available():
from safetensors.torch import load_file
__all__ = ["ModelQuantizer", "load_params"]
logger = ScreenLogger(__name__)
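# Maps the module type names stored in exported model files to the corresponding Quark quantized module classes.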
QUARK_QUANT_OPS: Dict[str, Type[Union[QuantConv2d, QuantConvTranspose2d, QuantLinear, QuantEmbedding,
QuantEmbeddingBag]]] = {
"QuantConv2d": QuantConv2d,
"QuantConvTranspose2d": QuantConvTranspose2d,
"QuantLinear": QuantLinear,
"QuantEmbedding": QuantEmbedding,
"QuantEmbeddingBag": QuantEmbeddingBag,
}
class ModelQuantizer:
"""
Provides an API for quantizing deep learning models using PyTorch.
This class handles the configuration and processing of the model for quantization based on user-defined parameters. It is essential to ensure that the 'config' provided has all necessary quantization parameters defined. This class assumes that the model is compatible with the quantization settings specified in 'config'.
:param Config config: The model quantization configuration.
"""
def __init__(self, config: Config, multi_device: bool = False) -> None:
self.config = config
self.is_all_dynamic: Optional[bool] = None
self.is_weight_only: Optional[bool] = None
self.is_act_dynamic: Optional[bool] = None
self.is_act_contain_scale_per_tensor: Optional[bool] = None
self._is_accelerate: Optional[bool] = None
self.multi_device: bool = multi_device
self.init_config()
def set_logging_level(self) -> None:
if self.config.log_severity_level == 0:
ScreenLogger.set_shared_level(logging.DEBUG)
elif self.config.log_severity_level == 1:
ScreenLogger.set_shared_level(logging.INFO)
elif self.config.log_severity_level == 2:
ScreenLogger.set_shared_level(logging.WARNING)
elif self.config.log_severity_level == 3:
ScreenLogger.set_shared_level(logging.ERROR)
else:
ScreenLogger.set_shared_level(logging.CRITICAL)
def quantize_model(
self,
model: nn.Module,
dataloader: Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]]]] = None
) -> nn.Module:
"""
Quantizes the given PyTorch model to optimize its performance and reduce its size.
The dataloader is used to provide data necessary for calibration during the quantization process. Depending on the type of data provided (either tensors directly or structured as lists or dictionaries of tensors), the function will adapt the quantization approach accordingly.
It is important that the model and dataloader are compatible in terms of the data they expect and produce. Misalignment in data handling between the model and the dataloader can lead to errors during the quantization process.
:param torch.nn.Module model: The PyTorch model to be quantized. This model should be already trained and ready for quantization.
:param Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]], DataLoader[List[BatchFeature]]]] dataloader: The ``torch.utils.data.DataLoader`` providing data that the quantization process will use for calibration. This can be a simple ``DataLoader`` returning tensors, or a more complex structure returning either a list of dictionaries or a dictionary of tensors.
:return: The quantized version of the input model. This model is now optimized for inference with reduced size and potentially improved performance on targeted devices.
:rtype: torch.nn.Module
Example:
.. code-block:: python
# Model & Data preparation
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from quark.torch.quantization.config.config import Config, QuantizationConfig, QuantizationSpec
from quark.torch.quantization.config.type import Dtype, ScaleType, RoundType, QSchemeType
from quark.torch.quantization.observer.observer import PerGroupMinMaxObserver
from quark.torch import ModelQuantizer
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.eval()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
quant_spec = QuantizationSpec(
dtype=Dtype.uint4,
observer_cls=PerGroupMinMaxObserver,
symmetric=False,
scale_type=ScaleType.float,
round_method=RoundType.half_even,
qscheme=QSchemeType.per_group,
ch_axis=1,
is_dynamic=False,
group_size=128
)
quant_config = Config(global_quant_config=QuantizationConfig(weight=quant_spec))
text = "Hello, how are you?"
tokenized_outputs = tokenizer(text, return_tensors="pt")
calib_dataloader = DataLoader(tokenized_outputs['input_ids'])
quantizer = ModelQuantizer(quant_config)
quant_model = quantizer.quantize_model(model, calib_dataloader)
"""
logger.info(f"Quantizing with the quantization configuration:\n{self.config}")
# Step0-1: Pre quant device check
self._check_model_device(model)
# Step0-2: Enhance config
self._generate_complete_config_by_model(model, dataloader)
# Step1[optional]: Pre quant optimization
model = self._apply_pre_quantization_optimization(model, dataloader)
# Step2: Prepare quantization model for graph mode and eager mode
model = self._prepare_model(model)
# Step3[optional]: Apply Advanced quant algo such as gptq/awq ...
model = self._apply_advanced_quant_algo(model, dataloader)
# Step4[optional]: Do calibration
model = self._do_calibration(model, dataloader)
# Step5[optional]: Post calib optimization
model = self._do_post_calib_optimization(model)
# Optionally, collect statistics on the quantization errors over the network weights/activations.
if os.environ.get("QUARK_DEBUG", None) is not None:
log_dir = Path(os.environ["QUARK_DEBUG"])
log_dir.mkdir(parents=True, exist_ok=True)
stats: Dict[str, Any] = {}
dataloader = dataloader if not self.is_all_dynamic else None
with insert_stats_hooks(model, stats, log_dir):
collect_quantization_statistics(model, dataloader, stats, log_dir)
# Check the scale of the quantized model.
if os.getenv("QUARK_CHECK_SCALE") == "1":
check_scale_stats(model, self.config)
return model
def _check_model_device(self, model: nn.Module) -> None:
# When using accelerate, devices cannot currently be 'cpu' or 'disk'.
if hasattr(model, 'hf_device_map'):
if not self.multi_device:
for _, layer_device in model.hf_device_map.items():
if layer_device == "cpu" or layer_device == "disk":
# TODO: We should handle this for customers.
raise MemoryError(
"Out of memory. The available GPU memory is insufficient to load the entire model. You can try adding '--multi_device' "
)
self._is_accelerate = True
else:
self._is_accelerate = False
def _generate_complete_config_by_model(
self, model: nn.Module, dataloader: Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]],
None]
) -> None:
"""
Generates a complete configuration based on the provided model and dataloader.
"""
self.config = add_algorithm_config_by_model(model, dataloader, self.config)
@staticmethod
def freeze(model: Union[nn.Module, torch.fx.GraphModule]) -> Union[nn.Module, torch.fx.GraphModule]:
"""
Freezes the quantized model by replacing ``FakeQuantize`` modules with ``FreezedFakeQuantize`` modules.
This method must be applied before compiling a quantized model with ``torch.compile``.
:param torch.nn.Module model: The neural network model containing quantized layers.
:return: The modified model with ``FakeQuantize`` modules replaced by ``FreezedFakeQuantize`` modules.
:rtype: torch.nn.Module
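Example:
.. code-block:: python

# A minimal sketch; ``quant_model`` is assumed to be a model returned by ``quantize_model``.
import torch
from quark.torch import ModelQuantizer

# Freeze the quantized model so that it can be compiled.
frozen_model = ModelQuantizer.freeze(quant_model)
compiled_model = torch.compile(frozen_model)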
"""
logger.info("Freeze model start.")
# ----replace FakeQuantize to FreezedFakeQuantize --------------
named_modules = dict(model.named_modules(remove_duplicate=False))
for name, module in named_modules.items():
if isinstance(module, FakeQuantizeBase):
if module.is_dynamic:
# TODO: Add freeze for dynamic model
logger.warning("Cannot freeze dynamic quantize model for now. Keep use FakeQuantize.")
pass
else:
freezed_quantized_module = module.to_freezed_module()
set_op_by_name(model, name, freezed_quantized_module)
# ----if model is quantized in fx.graph mode--------------
if isinstance(model, torch.fx.GraphModule):
model = model.freeze_model()
assert isinstance(model, torch.fx.GraphModule)
model = post_quant_optimize(model=model, hw_constrain=True) # TODO pass argument
logger.info("Freeze model end.")
return model
def _apply_pre_quantization_optimization(
self,
model: nn.Module,
dataloader: Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]]]] = None
) -> nn.Module:
return apply_pre_quantization_optimization(model, self.config, dataloader=dataloader)
def _prepare_model(self, model: nn.Module) -> nn.Module:
if self.config.quant_mode is QuantizationMode.eager_mode:
return process_model_transformation(model, self.config)
elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
# Quantization with torch.fx does not support some quantization config and some FX graphs.
# This raises an error if the config / model used are not supported.
check_supported_model_and_config(model, self.config) # type: ignore [arg-type]
return prepare_quant_model(model, self.config).eval() # type: ignore [arg-type]
def _apply_advanced_quant_algo(
self,
model: nn.Module,
dataloader: Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]]]] = None
) -> nn.Module:
return apply_advanced_quant_algo(model, self.config, self._is_accelerate, dataloader)
def _check_token_distribution(
self, model: nn.Module, dataloader: Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]]]
) -> None:
"""
A helper function that warns when a MoE module received 0 tokens throughout the calibration process.
The warning threshold, a fraction of all calibration tokens in [0.0, 1.0], can be set via the
``TOKEN_DISTRIBUTION_THRESHOLD`` environment variable and defaults to 0.0.
"""
threshold = float(os.environ['TOKEN_DISTRIBUTION_THRESHOLD']) if os.getenv(
"TOKEN_DISTRIBUTION_THRESHOLD") is not None else 0.0
assert 0.0 <= threshold <= 1.0, 'threshold should be in [0.0, 1.0]'
total_token_count = count_calibration_tokens(dataloader)
if total_token_count == 0:
logger.warning("No tokens found in calibration dataset. "
"Skipping token distribution check.")
return
# Get the observer token count for each module
token_counts: Counter[str] = Counter()
for name, module in model.named_modules():
if isinstance(module, ScaledFakeQuantize):
if '_input_quantizer' in name:
if module.observer._num_observed_tokens is not None:
token_counts[name.replace("._input_quantizer", "")] = module.observer._num_observed_tokens
for module_name, token_count in token_counts.items():
if (token_count / float(total_token_count)) <= threshold:
logger.warning(f"The module: {module_name} "
f"received {token_count} tokens less than {threshold * 100:.1f}% "
f"of all {total_token_count} calibration tokens.")
# When using multi_device, @torch.no_grad() must be applied here or offloading will fail:
# the GPU memory used for gradients cannot be cleaned up by torch.cuda.empty_cache().
@torch.no_grad()
def _do_calibration(
self,
model: nn.Module,
dataloader: Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]],
DataLoader[Dict[str, torch.Tensor]], DataLoader[List["BatchFeature"]]]] = None
) -> nn.Module:
# Calibration only: enable observers and disable fake quantization.
if self.is_all_dynamic: # TODO: to be deprecated
logger.info("Dynamic quantization, no calibration.")
elif self.is_weight_only or (self.is_act_dynamic and not self.is_act_contain_scale_per_tensor):
logger.info("Weight calibration start.")
for module in model.modules():
if isinstance(module, ScaledFakeQuantize):
module.enable_observer()
module.disable_fake_quant()
# Simply run through the observers to set min_val, max_val, scale and zero_point buffers for the weight and bias.
named_modules = dict(model.named_modules(remove_duplicate=False))
for name, module in tqdm(named_modules.items()):
if isinstance(module, QuantMixin):
if module._weight_quantizer is not None and isinstance(module._weight_quantizer,
(ScaledFakeQuantize, SequentialQuantize)):
weight_quantizers: Union[List[ScaledFakeQuantize], SequentialQuantize] = [
module._weight_quantizer
] if isinstance(module._weight_quantizer, ScaledFakeQuantize) else module._weight_quantizer
if all(quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
for quantizer in weight_quantizers):
# This condition prevents layers that have already been quantized from being quantized again.
if module.weight.device == torch.device("meta"):
weight = module._hf_hook.weights_map["weight"].data
weight = module.get_quant_weight(weight.to(module._hf_hook.execution_device))
del weight
else:
_ = module.get_quant_weight(module.weight)
if module._bias_quantizer is not None and isinstance(module._bias_quantizer,
(ScaledFakeQuantize, SequentialQuantize)):
bias_quantizers: Union[List[ScaledFakeQuantize], SequentialQuantize] = [
module._bias_quantizer
] if isinstance(module._bias_quantizer, ScaledFakeQuantize) else module._bias_quantizer
if all(quantizer.scale.numel() == 1 and quantizer.scale.item() == 1
for quantizer in bias_quantizers):
if module.bias.device == torch.device("meta"):
bias = module._hf_hook.weights_map["bias"].data
_ = module.get_quant_bias(bias.to(module._hf_hook.execution_device))
del bias
else:
_ = module.get_quant_bias(module.bias)
torch.cuda.empty_cache()
clear_memory()
logger.info("Weight calibration end.")
else:
logger.info("Calibration start.")
for module in model.modules():
if isinstance(module, ScaledFakeQuantize):
module.enable_observer()
module.disable_fake_quant()
assert dataloader is not None
with torch.no_grad():
for data in tqdm(dataloader):
if isinstance(data, dict): # pragma: no cover
model(**data)
elif is_transformers_available() and isinstance(data, BatchFeature): # pragma: no cover
_ = model(**data)
else:
model(data)
self._check_token_distribution(model, dataloader)
clear_memory()
logger.info("Calibration end.")
logger.info("Model quantization has been completed.")
# Prepare for evaluation: turn fake quantization back on.
if (self.config.algo_config) and self.config.algo_config.name in ['gptq'] and hasattr(
self.config.algo_config, "static_groups") and self.config.algo_config.static_groups is False:
logger.warning(
"Dynamic groups in GPTQ (static_groups=false) does not support FakeQuantize for export, turn off FakeQuantize for weight while keeping open FakeQuantize for activation in order to run evaluations."
)
named_modules = dict(model.named_modules(remove_duplicate=False))
for _, module in tqdm(named_modules.items()):
if isinstance(module, QuantMixin):
if module._weight_quantizer is not None:
enable_or_disable_quantizer(module._weight_quantizer, enable=False)
if module._input_quantizer is not None:
enable_or_disable_quantizer(module._input_quantizer, enable=True)
if module._output_quantizer is not None:
enable_or_disable_quantizer(module._output_quantizer, enable=True)
else:
for name, module in model.named_modules():
if isinstance(module, ScaledFakeQuantize):
if module.is_dynamic and not (
module.is_scale_quant and module.qscheme == QSchemeType.per_tensor
): # For dynamic quantization, the observer must stay enabled and update the qparams on every forward pass.
module.enable_observer()
module.enable_fake_quant()
else:
module.disable_observer()
module.enable_fake_quant()
elif isinstance(module, NonScaledFakeQuantize):
module.enable_fake_quant()
return model
def _do_post_calib_optimization(self, model: nn.Module) -> nn.Module:
'''
In some cases:
1. After calibration, the weight, activation and bias scales are available.
2. Some hardware constraints require bias_scale = weight_scale * act_scale.
In such cases an extra optimization pass is needed after calibration, before performing QAT/export.
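For example (illustrative numbers only): with weight_scale = 0.02 and act_scale = 0.5,
the bias scale would be adjusted to 0.02 * 0.5 = 0.01.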
'''
if self.config.quant_mode is QuantizationMode.eager_mode:
# TODO: post-calibration optimization is not implemented for eager mode yet; keep this branch as a pass-through.
assert isinstance(model, nn.Module)
return model
elif self.config.quant_mode is QuantizationMode.fx_graph_mode:
'''
During calibration, the observers record the tensors' distributions, and the scale and zero point are computed.
Under some hardware constraints, e.g. bias_scale = weight_scale * act_scale,
the bias scale needs to be adjusted after calibration.
'''
assert isinstance(model, torch.fx.GraphModule)
model = post_calib_optimize(model)
return model # type: ignore[no-any-return]
def init_config(self) -> None:
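"""
Sets the logging level, verifies the quantization configuration, adjusts the quantization configs if needed,
and records whether the resulting setup is dynamic, weight-only, or uses static/dynamic activation quantization.
"""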
self.set_logging_level() # set log level: default info
logger.info("Configuration checking start.")
config = self.config
verify_quantization_spec(config)
# TODO: Verify quant algo
for field in fields(Config):
if field.name in ["global_quant_config"]:
quantization_config = getattr(config, field.name)
_config = check_and_adjust_quant_config(quantization_config)
setattr(self.config, field.name, _config)
self.is_all_dynamic, self.is_weight_only, self.is_act_dynamic, self.is_act_contain_scale_per_tensor = \
init_quantization_config(quantization_config)
elif field.name in ["layer_type_quant_config", "layer_quant_config"]:
quantization_config_list = getattr(config, field.name)
for quantization_config in quantization_config_list.values():
self.is_all_dynamic, self.is_weight_only, self.is_act_dynamic, self.is_act_contain_scale_per_tensor = \
init_quantization_config(quantization_config)
if self.is_weight_only:
config_parsing_result = 'weight only quantization'
else:
if self.is_act_dynamic:
config_parsing_result = 'weight quantization and activation dynamic quantization'
else:
config_parsing_result = 'weight quantization and activation static quantization'
logger.info(f"Configuration checking end. The configuration is effective. This is {config_parsing_result}.")
def get_name_and_info(model_info: Dict[str, Any], parent_key: str = "") -> Iterable[Tuple[str, Dict[str, Any]]]:
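"""
Recursively walks the nested ``structure`` dictionary of an exported model and yields
``(qualified_module_name, module_info)`` pairs for entries that define both a ``type`` and a ``weight`` field.
"""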
for key, value in model_info.items():
new_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(value, dict):
if value.get("type", None) is not None and value.get("weight", None) is not None:
yield new_key, value
else:
yield from get_name_and_info(value, new_key)
else:
continue
# TODO: This function is only used in load_params, add support for SequentialQuantize later
def from_float_and_dict(float_module: nn.Module,
quant_info: Dict[str, Any],
param_dict: Dict[str, torch.Tensor],
layer_name: str,
compressed: bool = False,
reorder: bool = True) -> nn.Module:
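"""
Rebuilds a quantized module from its float counterpart using the quantization info and parameters
saved on disk: restores the weight (unpacking and dequantizing it if ``compressed`` is True), the
optional bias, and the input/output/weight scales and zero points.
"""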
input_tensors = None
quant_params: Dict[str, Optional[torch.Tensor]] = {}
if quant_info.get("input_quant", None) is not None:
input_tensors = QuantizationSpec.from_dict(quant_info["input_quant"])
quant_params["input_scale"] = param_dict[layer_name + ".input_scale"] # pragma: no cover
quant_params["input_zero_point"] = param_dict[layer_name + ".input_zero_point"] # pragma: no cover
output_tensors = None
if quant_info.get("output_quant", None) is not None:
output_tensors = QuantizationSpec.from_dict(quant_info["output_quant"])
quant_params["output_scale"] = param_dict[layer_name + ".output_scale"]
quant_params["output_zero_point"] = param_dict[layer_name + ".output_zero_point"]
weight_qspec: Optional[QuantizationSpec] = None
weight_tensor = param_dict[quant_info.get("weight", None)]
if quant_info.get("weight_quant", None) is not None:
weight_qspec = QuantizationSpec.from_dict(quant_info["weight_quant"])
weight_scale = param_dict[layer_name + ".weight_scale"]
weight_zero_point = param_dict[layer_name + ".weight_zero_point"]
if compressed:
assert isinstance(weight_qspec, QuantizationSpec), "weight_qspec must be QuantizationSpec instance"
assert isinstance(weight_qspec.qscheme, QSchemeType), "weight_qspec.qscheme must be QSchemeType instance"
assert isinstance(weight_qspec.dtype, Dtype), "weight_qspec.dtype must be Dtype instance"
pack_method = create_pack_method(qscheme=weight_qspec.qscheme.value, dtype=weight_qspec.dtype.value)
weight_tensor = pack_method.unpack(
weight_tensor, reorder,
**({
"origin_packed_axis_size": weight_scale.shape[-1]
} if weight_scale.shape != torch.Size([]) else {}))
weight_tensor = quark.torch.kernel.dequantize( # type: ignore[attr-defined]
weight_qspec.dtype.value, weight_tensor, weight_scale, weight_zero_point, weight_qspec.ch_axis,
weight_qspec.group_size, weight_qspec.qscheme.value)
quant_params["weight_scale"] = weight_scale
quant_params["weight_zero_point"] = weight_zero_point
module_config = QuantizationConfig(input_tensors=input_tensors, output_tensors=output_tensors, weight=weight_qspec)
bias_tensor = None
if quant_info.get("bias", None) is not None:
bias_tensor = param_dict[quant_info.get("bias", None)]
quant_module: nn.Module
if quant_info['type'] in QUARK_QUANT_OPS:
quant_module = QUARK_QUANT_OPS[quant_info['type']].from_float(
float_module,
module_config,
reload=True,
weight_tensor=weight_tensor,
bias_tensor=bias_tensor,
)
else:
raise ValueError(f"The type {quant_info['type']} dose not support in Quark now!")
quant_module.load_quant_params(quant_params)
return quant_module
# TODO: add support for SequentialQuantize later
# TODO: better `reorder` doc
@log_errors
def load_params(model: Optional[nn.Module] = None,
json_path: str = "",
safetensors_path: str = "",
pth_path: str = "",
quant_mode: QuantizationMode = QuantizationMode.eager_mode,
compressed: bool = False,
reorder: bool = True) -> nn.Module:
"""
Instantiates a quantized model from saved model files, which is generated from the :py:func:`quark.torch.export.api.save_params` function.
:param torch.nn.Module model: The original Pytorch model.
:param str json_path: The path of the saved json file. Only available for eager mode quantization.
:param str safetensors_path: The path of the saved safetensors file. Only available for eager mode quantization.
:param str pth_path: The path of the saved ``.pth`` file. Only available for ``fx_graph`` mode quantization.
:param QuantizationMode quant_mode: The quantization mode. The choice includes ``"QuantizationMode.eager_mode"`` and ``"QuantizationMode.fx_graph_mode"``. Default is ``"QuantizationMode.eager_mode"``.
:param bool compressed: Whether the quantized model to load is stored using its compressed data type, or in a "fake quantized" format (QDQ).
:param bool reorder: Whether to reorder the weight tensor when unpacking a compressed checkpoint. Only used when ``compressed=True``.
:return: The reloaded quantized version of the input model.
:rtype: torch.nn.Module
Examples:
.. code-block:: python
# eager mode:
from quark.torch import load_params
model = load_params(model, json_path=json_path, safetensors_path=safetensors_path)
.. code-block:: python
# fx_graph mode:
from quark.torch.quantization.api import load_params
model = load_params(pth_path=model_file_path, quant_mode=QuantizationMode.fx_graph_mode)
Note:
This function does not support dynamic quantization for now.
"""
if quant_mode is QuantizationMode.eager_mode:
if not is_safetensors_available():
raise ImportError(
"The function `load_params` with `quant_mode=QuantizationMode.eager_mode` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
)
if model is None:
raise ValueError("Model should not be none if loading eager_mode quantized model")
if json_path == "" or safetensors_path == "":
raise ValueError("Json_path and safetensors_path should not be empty if loading eager_mode quantized model")
# load model structure and parameters
with open(json_path, "r") as file:
model_dict = json.load(file)
params_dict = load_file(safetensors_path)
# verify exported model and float model have the same configuration
model_config = model_dict["config"]
if model_config:
float_model_config: Dict[str, Any] = {}
if hasattr(model.config, "to_diff_dict"):
float_model_config = model.config.to_diff_dict()
elif hasattr(model.config, "items"):
float_model_config = dict(model.config.items())
if not deep_compare(model_config, float_model_config):
raise RuntimeError("Exported model and float model are not the same model!")
# assert ((json.dumps(model_config) == json.dumps(float_model_config)),
# "Exported model and float model are not the same model!")
logger.info("In-place OPs replacement start.")
for name, module_info in get_name_and_info(model_dict["structure"]):
float_module = get_op_by_name(model, name)
if module_info["type"] in QUARK_QUANT_OPS:
module = from_float_and_dict(float_module,
module_info,
params_dict,
layer_name=name,
compressed=compressed,
reorder=reorder)
set_op_by_name(model, name, module)
else:
device = float_module.weight.device
float_module.weight.data = params_dict[module_info.get("weight", None)].to(device)
if module_info.get("bias", None) is not None:
float_module.bias.data = params_dict[module_info.get("bias", None)].to(device)
model = ModelQuantizer.freeze(model)
logger.info("In-place OPs replacement end.")
elif quant_mode is QuantizationMode.fx_graph_mode:
if pth_path == "":
raise ValueError("Pth_path should not be empty if loading eager_mode quantized model")
loaded_quantized_ep = torch.export.load(pth_path)
model = loaded_quantized_ep.module()
return model