Source code for quark.onnx.quantization.api

#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for ONNX."""

import logging
import os
import warnings
from pathlib import Path
from typing import Union

import onnx
from onnxruntime.quantization.calibrate import CalibrationDataReader

from quark.onnx.quant_utils import recursive_update
from quark.onnx.quantization.config.config import Config, QConfig, QuantizationConfig
from quark.onnx.quantization.config.maps import _map_q_config
from quark.onnx.quantize import quantize_dynamic, quantize_static
from quark.shares.utils.log import ScreenLogger, log_errors

from .config.algorithm import (
    AdaQuantConfig,
    AdaRoundConfig,
    AlgoConfig,
    AutoMixprecisionConfig,
    CLEConfig,
    QuarotConfig,
    SmoothQuantConfig,
    _algo_flag,
    _resolove_algo_conflict,
)

__all__ = ["ModelQuantizer"]

logger = ScreenLogger(__name__)


class ModelQuantizer:
    """Provides an API for quantizing deep learning models using ONNX. This class
    handles the configuration and processing of the model for quantization based on
    user-defined parameters.

    :param Union[Config, QConfig] config: Configuration object containing settings for quantization.

    Note:
        - It is essential to ensure that the 'config' provided has all necessary
          quantization parameters defined.
        - This class assumes that the model is compatible with the quantization
          settings specified in 'config'.
    """

    def __init__(self, config: Union[Config, QConfig]) -> None:
        """Initializes the ModelQuantizer with the provided configuration.

        :param Union[Config, QConfig] config: Configuration object containing global quantization settings.
        """
        if isinstance(config, Config):
            logger.warning("Config has been replaced by QConfig. The old API will be removed in the next release.")
            self.config = config.global_quant_config
            self.set_logging_level()
            if self.config.ignore_warnings:
                warnings.simplefilter("ignore", ResourceWarning)
                warnings.simplefilter("ignore", UserWarning)
        elif isinstance(config, QConfig):
            self.config = config  # type: ignore
            self.set_logging_level()
            if self.config.extra_options.get("IgnoreWarnings"):
                warnings.simplefilter("ignore", ResourceWarning)
                warnings.simplefilter("ignore", UserWarning)
        else:
            raise ValueError("The quantization config must be an instance of Config or QConfig.")

    def set_logging_level(self) -> None:
        if isinstance(self.config, QuantizationConfig):
            if self.config.debug_mode:
                ScreenLogger.set_shared_level(logging.DEBUG)
            elif self.config.crypto_mode:
                ScreenLogger.set_shared_level(logging.CRITICAL)
            elif self.config.log_severity_level == 0:
                ScreenLogger.set_shared_level(logging.DEBUG)
            elif self.config.log_severity_level == 1:
                ScreenLogger.set_shared_level(logging.INFO)
            elif self.config.log_severity_level == 2:
                ScreenLogger.set_shared_level(logging.WARNING)
            elif self.config.log_severity_level == 3:
                ScreenLogger.set_shared_level(logging.ERROR)
            else:
                ScreenLogger.set_shared_level(logging.CRITICAL)
        if isinstance(self.config, QConfig):
            if self.config.extra_options.get("DebugMode"):
                ScreenLogger.set_shared_level(logging.DEBUG)
            elif self.config.extra_options.get("CryptoMode"):
                ScreenLogger.set_shared_level(logging.CRITICAL)
            elif "LogSeverityLevel" not in self.config.extra_options:
                # Default to INFO when no explicit severity level is given.
                ScreenLogger.set_shared_level(logging.INFO)
            else:
                # Map an explicit severity level to a shared logging level;
                # unrecognized values fall back to CRITICAL.
                level = self.config.extra_options["LogSeverityLevel"]
                if level == 0:
                    ScreenLogger.set_shared_level(logging.DEBUG)
                elif level == 1:
                    ScreenLogger.set_shared_level(logging.INFO)
                elif level == 2:
                    ScreenLogger.set_shared_level(logging.WARNING)
                elif level == 3:
                    ScreenLogger.set_shared_level(logging.ERROR)
                else:
                    ScreenLogger.set_shared_level(logging.CRITICAL)
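    # A minimal sketch of how ``extra_options`` drives the shared log level
    # above. The ``extra_options`` attribute appears in this module; passing
    # it as a QConfig constructor keyword is an assumption for illustration:
    #
    #     QConfig(extra_options={"DebugMode": True})      # -> logging.DEBUG
    #     QConfig(extra_options={"CryptoMode": True})     # -> logging.CRITICAL
    #     QConfig(extra_options={})                       # -> logging.INFO (default)
    #     QConfig(extra_options={"LogSeverityLevel": 3})  # -> logging.ERROR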
    @log_errors
    def quantize_model(
        self,
        model_input: Union[str, Path, onnx.ModelProto],
        model_output: Union[str, Path] | None = None,
        calibration_data_reader: CalibrationDataReader | None = None,
        calibration_data_path: str | None = None,
        algorithms: list[AlgoConfig] | None = None,
    ) -> onnx.ModelProto | None:
        """Quantizes the given ONNX model and saves the output to the specified path,
        or returns the quantized model as a ModelProto.

        :param Union[str, Path, onnx.ModelProto] model_input: Path to the input ONNX model file, or a ModelProto.
        :param Optional[Union[str, Path]] model_output: Path where the quantized ONNX model will be saved.
            Defaults to ``None``, in which case the model is not saved and the function returns a ModelProto.
        :param Optional[CalibrationDataReader] calibration_data_reader: Data reader for model calibration.
            Defaults to ``None``.
        :param Optional[str] calibration_data_path: Path to calibration data. Defaults to ``None``.
        :param Optional[List[AlgoConfig]] algorithms: List of algorithm configurations such as CLE,
            SmoothQuant, and AdaRound. Defaults to ``None``.
        :return: The quantized model as a ModelProto if ``model_output`` is ``None``, otherwise ``None``.
        """
        if isinstance(self.config, QuantizationConfig):
            algorithms = algorithms or []
            logger.warning(
                "The `algorithms` argument has been replaced by `algo_config` in QConfig. "
                "The old API will be removed in the next release."
            )
        if isinstance(self.config, QConfig):
            algorithms = self.config.algo_config

        if isinstance(model_input, (str, Path)) and not os.path.exists(model_input):
            raise FileNotFoundError(f"Input model file {model_input} does not exist.")

        if not (isinstance(self.config, QuantizationConfig) and self.config.use_dynamic_quant):
            algorithms = _resolove_algo_conflict(algorithms)
            for algo in algorithms:
                recursive_update(self.config.extra_options, algo._get_config(self.config.extra_options))
            if isinstance(self.config, QuantizationConfig):
                return quantize_static(
                    model_input=model_input,
                    model_output=model_output,
                    calibration_data_reader=calibration_data_reader,
                    calibration_data_path=calibration_data_path,
                    calibrate_method=self.config.calibrate_method,
                    quant_format=self.config.quant_format,
                    activation_type=self.config.activation_type,
                    weight_type=self.config.weight_type,
                    input_nodes=self.config.input_nodes,
                    output_nodes=self.config.output_nodes,
                    op_types_to_quantize=self.config.op_types_to_quantize,
                    nodes_to_quantize=self.config.nodes_to_quantize,
                    extra_op_types_to_quantize=self.config.extra_op_types_to_quantize,
                    nodes_to_exclude=self.config.nodes_to_exclude,
                    subgraphs_to_exclude=self.config.subgraphs_to_exclude,
                    specific_tensor_precision=self.config.specific_tensor_precision,
                    execution_providers=self.config.execution_providers,
                    per_channel=self.config.per_channel,
                    reduce_range=self.config.reduce_range,
                    optimize_model=self.config.optimize_model,
                    use_external_data_format=self.config.use_external_data_format,
                    convert_fp16_to_fp32=self.config.convert_fp16_to_fp32,
                    convert_nchw_to_nhwc=self.config.convert_nchw_to_nhwc,
                    include_sq=(self.config.include_sq or _algo_flag(algorithms, SmoothQuantConfig)),
                    include_rotation=(self.config.include_rotation or _algo_flag(algorithms, QuarotConfig)),
                    include_cle=(self.config.include_cle or _algo_flag(algorithms, CLEConfig)),
                    include_auto_mp=(self.config.include_auto_mp or _algo_flag(algorithms, AutoMixprecisionConfig)),
                    include_fast_ft=(
                        self.config.include_fast_ft
                        or _algo_flag(algorithms, AdaRoundConfig)
                        or _algo_flag(algorithms, AdaQuantConfig)
                    ),
                    enable_npu_cnn=self.config.enable_npu_cnn,
                    enable_npu_transformer=self.config.enable_npu_transformer,
                    debug_mode=self.config.debug_mode,
                    crypto_mode=self.config.crypto_mode,
                    print_summary=self.config.print_summary,
                    extra_options=self.config.extra_options,
                )
            if isinstance(self.config, QConfig):
                mapping = _map_q_config(self.config, model_input)
                return quantize_static(
                    model_input=model_input,
                    model_output=model_output,
                    calibration_data_reader=calibration_data_reader,
                    calibration_data_path=calibration_data_path,
                    calibrate_method=mapping["calibrate_method"],
                    quant_format=mapping["quant_format"],
                    activation_type=mapping["activation_type"].map_onnx_format,
                    weight_type=mapping["weight_type"].map_onnx_format,
                    input_nodes=mapping["extra_options"]["InputNodes"],
                    output_nodes=mapping["extra_options"]["OutputNodes"],
                    op_types_to_quantize=mapping["extra_options"]["OpTypesToQuantize"],
                    nodes_to_quantize=mapping["extra_options"]["NodesToQuantize"],
                    specific_tensor_precision=mapping["extra_options"]["SpecificTensorPrecision"],
                    extra_op_types_to_quantize=mapping["extra_options"]["ExtraOpTypesToQuantize"],
                    execution_providers=mapping["extra_options"]["ExecutionProviders"],
                    optimize_model=mapping["extra_options"]["OptimizeModel"],
                    convert_fp16_to_fp32=mapping["extra_options"]["ConvertFP16ToFP32"],
                    convert_nchw_to_nhwc=mapping["extra_options"]["ConvertNCHWToNHWC"],
                    enable_npu_cnn=mapping["extra_options"]["EnableNPUCnn"],
                    debug_mode=mapping["extra_options"]["DebugMode"],
                    crypto_mode=mapping["extra_options"]["CryptoMode"],
                    print_summary=mapping["extra_options"]["PrintSummary"],
                    nodes_to_exclude=mapping["nodes_to_exclude"],
                    subgraphs_to_exclude=mapping["subgraphs_to_exclude"],
                    per_channel=mapping["per_channel"],
                    use_external_data_format=mapping["use_external_data_format"],
                    include_sq=_algo_flag(algorithms, SmoothQuantConfig),
                    include_rotation=_algo_flag(algorithms, QuarotConfig),
                    include_cle=_algo_flag(algorithms, CLEConfig),
                    include_auto_mp=_algo_flag(algorithms, AutoMixprecisionConfig),
                    include_fast_ft=_algo_flag(algorithms, AdaRoundConfig) or _algo_flag(algorithms, AdaQuantConfig),
                    extra_options=mapping["extra_options"],
                )
        else:
            return quantize_dynamic(
                model_input=model_input,
                model_output=model_output,
                op_types_to_quantize=self.config.op_types_to_quantize,
                per_channel=self.config.per_channel,
                reduce_range=self.config.reduce_range,
                weight_type=self.config.weight_type,
                nodes_to_quantize=self.config.nodes_to_quantize,
                nodes_to_exclude=self.config.nodes_to_exclude,
                subgraphs_to_exclude=self.config.subgraphs_to_exclude,
                use_external_data_format=self.config.use_external_data_format,
                debug_mode=self.config.debug_mode,
                crypto_mode=self.config.crypto_mode,
                extra_options=self.config.extra_options,
            )
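# A minimal end-to-end usage sketch (illustrative, not part of the module).
# ``my_qconfig`` and ``RandomDataReader`` are hypothetical names; any class
# implementing onnxruntime's CalibrationDataReader interface (``get_next``)
# can serve as the calibration reader:
#
#     import numpy as np
#     from onnxruntime.quantization.calibrate import CalibrationDataReader
#
#     class RandomDataReader(CalibrationDataReader):
#         def __init__(self, batches):
#             self._iter = iter(batches)
#
#         def get_next(self):
#             # Return the next input feed dict, or None when exhausted.
#             return next(self._iter, None)
#
#     reader = RandomDataReader([{"input": np.zeros((1, 3, 224, 224), np.float32)}])
#     quantizer = ModelQuantizer(config=my_qconfig)  # my_qconfig: a QConfig instance
#     quantized_model = quantizer.quantize_model(
#         model_input="model.onnx",
#         model_output="model_quantized.onnx",
#         calibration_data_reader=reader,
#     )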