Source code for quark.onnx.quantization.api

#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for ONNX."""

import os
import logging
import warnings
from quark.shares.utils.log import ScreenLogger, log_errors

from pathlib import Path
from typing import Union, Optional

import onnx
from onnxruntime.quantization.calibrate import CalibrationDataReader

from quark.onnx.quantization.config.config import Config
from quark.onnx.quantize import quantize_static, quantize_dynamic

__all__ = ["ModelQuantizer"]

logger = ScreenLogger(__name__)


class ModelQuantizer:
    """
    Provides an API for quantizing deep learning models using ONNX. This class handles the configuration and
    processing of the model for quantization based on user-defined parameters.

    Args:
        config (Config): Configuration object containing settings for quantization.

    Note:
        It is essential to ensure that the 'config' provided has all necessary quantization parameters defined.
        This class assumes that the model is compatible with the quantization settings specified in 'config'.
    """

    def __init__(self, config: Config) -> None:
        """
        Initializes the ModelQuantizer with the provided configuration.

        Args:
            config (Config): Configuration object containing global quantization settings.
        """
        self.config = config.global_quant_config
        self.set_logging_level()

        if self.config.ignore_warnings:
            warnings.simplefilter("ignore", ResourceWarning)
            warnings.simplefilter("ignore", UserWarning)

    def set_logging_level(self) -> None:
        # Map the configured verbosity onto the shared ScreenLogger level.
        if self.config.debug_mode:
            ScreenLogger.set_shared_level(logging.DEBUG)
        elif self.config.crypto_mode:
            ScreenLogger.set_shared_level(logging.CRITICAL)
        elif self.config.log_severity_level == 0:
            ScreenLogger.set_shared_level(logging.DEBUG)
        elif self.config.log_severity_level == 1:
            ScreenLogger.set_shared_level(logging.INFO)
        elif self.config.log_severity_level == 2:
            ScreenLogger.set_shared_level(logging.WARNING)
        elif self.config.log_severity_level == 3:
            ScreenLogger.set_shared_level(logging.ERROR)
        else:
            ScreenLogger.set_shared_level(logging.CRITICAL)
    @log_errors
    def quantize_model(self,
                       model_input: Union[str, Path, onnx.ModelProto],
                       model_output: Optional[Union[str, Path]] = None,
                       calibration_data_reader: Optional[CalibrationDataReader] = None,
                       calibration_data_path: Optional[str] = None) -> Optional[onnx.ModelProto]:
        """
        Quantizes the given ONNX model and saves the output to the specified path or returns a ModelProto.

        :param Union[str, Path, onnx.ModelProto] model_input: Path to the input ONNX model file or a ModelProto.
        :param Optional[Union[str, Path]] model_output: Path where the quantized ONNX model will be saved.
            Defaults to ``None``, in which case the model is not saved but the function returns a ModelProto.
        :param Optional[CalibrationDataReader] calibration_data_reader: Data reader for model calibration.
            Defaults to ``None``.
        :param Optional[str] calibration_data_path: Path to the calibration data. Defaults to ``None``.
        :return: The quantized model as an ``onnx.ModelProto`` if ``model_output`` is ``None``, otherwise ``None``.
        """
        if isinstance(model_input, (str, Path)) and not os.path.exists(model_input):
            raise FileNotFoundError(f"Input model file {model_input} does not exist.")

        if not self.config.use_dynamic_quant:
            # Static (post-training) quantization: activations are calibrated
            # using the provided data reader or calibration data path.
            return quantize_static(model_input=model_input,
                                   model_output=model_output,
                                   calibration_data_reader=calibration_data_reader,
                                   calibration_data_path=calibration_data_path,
                                   calibrate_method=self.config.calibrate_method,
                                   quant_format=self.config.quant_format,
                                   activation_type=self.config.activation_type,
                                   weight_type=self.config.weight_type,
                                   input_nodes=self.config.input_nodes,
                                   output_nodes=self.config.output_nodes,
                                   op_types_to_quantize=self.config.op_types_to_quantize,
                                   nodes_to_quantize=self.config.nodes_to_quantize,
                                   extra_op_types_to_quantize=self.config.extra_op_types_to_quantize,
                                   nodes_to_exclude=self.config.nodes_to_exclude,
                                   subgraphs_to_exclude=self.config.subgraphs_to_exclude,
                                   specific_tensor_precision=self.config.specific_tensor_precision,
                                   execution_providers=self.config.execution_providers,
                                   per_channel=self.config.per_channel,
                                   reduce_range=self.config.reduce_range,
                                   optimize_model=self.config.optimize_model,
                                   use_external_data_format=self.config.use_external_data_format,
                                   convert_fp16_to_fp32=self.config.convert_fp16_to_fp32,
                                   convert_nchw_to_nhwc=self.config.convert_nchw_to_nhwc,
                                   include_sq=self.config.include_sq,
                                   include_rotation=self.config.include_rotation,
                                   include_cle=self.config.include_cle,
                                   include_auto_mp=self.config.include_auto_mp,
                                   include_fast_ft=self.config.include_fast_ft,
                                   enable_npu_cnn=self.config.enable_npu_cnn,
                                   enable_npu_transformer=self.config.enable_npu_transformer,
                                   debug_mode=self.config.debug_mode,
                                   crypto_mode=self.config.crypto_mode,
                                   print_summary=self.config.print_summary,
                                   extra_options=self.config.extra_options)
        else:
            # Dynamic quantization: weights are quantized offline and activations
            # are quantized at runtime, so no calibration data is required.
            return quantize_dynamic(model_input=model_input,
                                    model_output=model_output,
                                    op_types_to_quantize=self.config.op_types_to_quantize,
                                    per_channel=self.config.per_channel,
                                    reduce_range=self.config.reduce_range,
                                    weight_type=self.config.weight_type,
                                    nodes_to_quantize=self.config.nodes_to_quantize,
                                    nodes_to_exclude=self.config.nodes_to_exclude,
                                    subgraphs_to_exclude=self.config.subgraphs_to_exclude,
                                    use_external_data_format=self.config.use_external_data_format,
                                    debug_mode=self.config.debug_mode,
                                    crypto_mode=self.config.crypto_mode,
                                    extra_options=self.config.extra_options)
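

# -----------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the shipped module). It shows how the
# API above is typically driven: build a ``Config``, wrap calibration inputs in
# a ``CalibrationDataReader``, then call ``quantize_model``. The
# ``get_default_config`` helper, the ``"XINT8"`` preset name, the model paths,
# and ``_RandomCalibrationReader`` are assumptions for illustration; substitute
# your own configuration and data pipeline.
if __name__ == "__main__":
    import numpy as np
    from quark.onnx.quantization.config.config import get_default_config  # assumed location

    class _RandomCalibrationReader(CalibrationDataReader):
        """Feeds a few random batches for calibration (hypothetical stand-in)."""

        def __init__(self, input_name: str, shape: tuple, num_batches: int = 8) -> None:
            self._batches = ({input_name: np.random.rand(*shape).astype(np.float32)}
                             for _ in range(num_batches))

        def get_next(self) -> Optional[dict]:
            # Return None when exhausted, as the CalibrationDataReader contract requires.
            return next(self._batches, None)

    quant_config = get_default_config("XINT8")  # assumed preset name
    quantizer = ModelQuantizer(Config(global_quant_config=quant_config))
    quantizer.quantize_model(model_input="model.onnx",
                             model_output="model_quantized.onnx",
                             calibration_data_reader=_RandomCalibrationReader("input", (1, 3, 224, 224)))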