#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Quantization API for ONNX."""
import logging
import os
import warnings
from pathlib import Path
from typing import List, Optional, Union
import onnx
from onnxruntime.quantization.calibrate import CalibrationDataReader
from quark.onnx.quant_utils import recursive_update
from quark.onnx.quantization.config.config import Config, QConfig, QuantizationConfig
from quark.onnx.quantization.config.maps import _map_q_config
from quark.onnx.quantize import quantize_dynamic, quantize_static
from quark.shares.utils.log import ScreenLogger, log_errors
from .config.algorithm import (
    AdaQuantConfig,
    AdaRoundConfig,
    AlgoConfig,
    AutoMixprecisionConfig,
    CLEConfig,
    QuarotConfig,
    SmoothQuantConfig,
    _algo_flag,
    _resolove_algo_conflict,
)
# Public API of this module.
__all__ = ["ModelQuantizer"]

# Module-level logger; its level is driven by ModelQuantizer.set_logging_level via ScreenLogger.set_shared_level.
logger = ScreenLogger(__name__)

class ModelQuantizer:
    """Provides an API for quantizing deep learning models using ONNX.

    This class handles the configuration and processing of the model for quantization based on
    user-defined parameters.

    :param Union[Config, QConfig] config: Configuration object containing settings for quantization.
        Passing a ``Config`` is deprecated; only its ``global_quant_config`` is used.

    Note:
        - It is essential to ensure that the 'config' provided has all necessary quantization parameters defined.
        - This class assumes that the model is compatible with the quantization settings specified in 'config'.
    """

    def __init__(self, config: Union[Config, QConfig]) -> None:
        """Initializes the ModelQuantizer with the provided configuration.

        The shared logging level is configured from the config and, when requested by the config,
        ``ResourceWarning`` and ``UserWarning`` are silenced via ``warnings.simplefilter``.

        :param Union[Config, QConfig] config: Configuration object containing global quantization settings.
        :raises ValueError: If ``config`` is neither a ``Config`` nor a ``QConfig``.
        """
        if isinstance(config, Config):
            logger.warning("Config has been replaced by QConfig. The old API will be removed in the next release.")
            self.config = config.global_quant_config
            self.set_logging_level()
            ignore_warnings = self.config.ignore_warnings
        elif isinstance(config, QConfig):
            self.config = config  # type: ignore
            self.set_logging_level()
            ignore_warnings = self.config.extra_options.get("IgnoreWarnings", False)
        else:
            raise ValueError("quantization config must be one of Config and QConfig.")

        if ignore_warnings:
            # simplefilter modifies the process-wide warnings filter list, not per-instance state.
            warnings.simplefilter("ignore", ResourceWarning)
            warnings.simplefilter("ignore", UserWarning)

    @staticmethod
    def _severity_to_logging_level(severity: object) -> int:
        """Map a log severity value (0-3) to a ``logging`` level; any other value maps to CRITICAL."""
        levels = {0: logging.DEBUG, 1: logging.INFO, 2: logging.WARNING, 3: logging.ERROR}
        return levels.get(severity, logging.CRITICAL)

    def set_logging_level(self) -> None:
        """Set the shared ``ScreenLogger`` level from ``self.config``.

        Precedence is: debug mode -> DEBUG, then crypto mode -> CRITICAL, then the log severity level
        (0=DEBUG, 1=INFO, 2=WARNING, 3=ERROR, any other value=CRITICAL). For a ``QConfig`` the settings
        are read from ``extra_options`` (``DebugMode``, ``CryptoMode``, ``LogSeverityLevel``), and a
        missing ``LogSeverityLevel`` selects INFO.
        """
        if isinstance(self.config, QuantizationConfig):
            if self.config.debug_mode:
                level = logging.DEBUG
            elif self.config.crypto_mode:
                level = logging.CRITICAL
            else:
                level = self._severity_to_logging_level(self.config.log_severity_level)
            ScreenLogger.set_shared_level(level)
        if isinstance(self.config, QConfig):
            options = self.config.extra_options
            if options.get("DebugMode"):
                level = logging.DEBUG
            elif options.get("CryptoMode"):
                level = logging.CRITICAL
            elif "LogSeverityLevel" not in options:
                level = logging.INFO
            else:
                # Out-of-range severities fall back to CRITICAL, matching the QuantizationConfig branch
                # (previously they left the level unchanged because the fallback branch was unreachable).
                level = self._severity_to_logging_level(options["LogSeverityLevel"])
            ScreenLogger.set_shared_level(level)

    @log_errors
    def quantize_model(
        self,
        model_input: Union[str, Path, onnx.ModelProto],
        model_output: Union[str, Path] | None = None,
        calibration_data_reader: CalibrationDataReader | None = None,
        calibration_data_path: str | None = None,
        algorithms: list[AlgoConfig] | None = None,
    ) -> onnx.ModelProto | None:
        """Quantizes the given ONNX model and saves the output to the specified path or returns a ModelProto.

        :param Union[str, Path, onnx.ModelProto] model_input: Path to the input ONNX model file or a ModelProto.
        :param Optional[Union[str, Path]] model_output: Path where the quantized ONNX model will be saved. Defaults to ``None``, in which case the model is not saved but the function returns a ModelProto.
        :param Union[CalibrationDataReader, None] calibration_data_reader: Data reader for model calibration. Defaults to ``None``.
        :param Optional[str] calibration_data_path: Calibration data path forwarded to ``quantize_static``. Defaults to ``None``.
        :param List[AlgoConfig] algorithms: List of algorithms like CLE, SmoothQuant and AdaRound. Only used with the
            deprecated ``Config`` API; with ``QConfig`` the algorithms come from ``QConfig.algo_config`` and this
            argument is ignored. Defaults to ``None``.
        :return: The value returned by the underlying ``quantize_static`` / ``quantize_dynamic`` call
            (``onnx.ModelProto | None``).
        :raises FileNotFoundError: If ``model_input`` is a path that does not exist.
        """
        if isinstance(self.config, QuantizationConfig):
            algorithms = algorithms or []
            logger.warning(
                "The algorithm API is algo_config in QConfig. The old API will be removed in the next release."
            )
        if isinstance(self.config, QConfig):
            if algorithms:
                logger.warning(
                    "The 'algorithms' argument is ignored when using QConfig; "
                    "configure algorithms through QConfig.algo_config instead."
                )
            # Guard against an unset algo_config so the loop below always iterates a list.
            algorithms = self.config.algo_config or []

        if isinstance(model_input, (str, Path)) and not os.path.exists(model_input):
            raise FileNotFoundError(f"Input model file {model_input} does not exist.")

        # Only the deprecated QuantizationConfig can request dynamic quantization.
        if isinstance(self.config, QuantizationConfig) and self.config.use_dynamic_quant:
            return self._quantize_dynamic_legacy(model_input, model_output)

        algorithms = _resolove_algo_conflict(algorithms)
        # Each algorithm merges its settings into extra_options; note this mutates the config in place.
        for algo in algorithms:
            recursive_update(self.config.extra_options, algo._get_config(self.config.extra_options))

        if isinstance(self.config, QuantizationConfig):
            return self._quantize_static_legacy(
                model_input, model_output, calibration_data_reader, calibration_data_path, algorithms
            )
        if isinstance(self.config, QConfig):
            return self._quantize_static_qconfig(
                model_input, model_output, calibration_data_reader, calibration_data_path, algorithms
            )
        return None

    def _quantize_static_legacy(
        self,
        model_input: Union[str, Path, onnx.ModelProto],
        model_output: Union[str, Path] | None,
        calibration_data_reader: CalibrationDataReader | None,
        calibration_data_path: str | None,
        algorithms: list[AlgoConfig],
    ) -> onnx.ModelProto | None:
        """Run static quantization using the fields of a deprecated ``QuantizationConfig``."""
        config = self.config
        return quantize_static(
            model_input=model_input,
            model_output=model_output,
            calibration_data_reader=calibration_data_reader,
            calibration_data_path=calibration_data_path,
            calibrate_method=config.calibrate_method,
            quant_format=config.quant_format,
            activation_type=config.activation_type,
            weight_type=config.weight_type,
            input_nodes=config.input_nodes,
            output_nodes=config.output_nodes,
            op_types_to_quantize=config.op_types_to_quantize,
            nodes_to_quantize=config.nodes_to_quantize,
            extra_op_types_to_quantize=config.extra_op_types_to_quantize,
            nodes_to_exclude=config.nodes_to_exclude,
            subgraphs_to_exclude=config.subgraphs_to_exclude,
            specific_tensor_precision=config.specific_tensor_precision,
            execution_providers=config.execution_providers,
            per_channel=config.per_channel,
            reduce_range=config.reduce_range,
            optimize_model=config.optimize_model,
            use_external_data_format=config.use_external_data_format,
            convert_fp16_to_fp32=config.convert_fp16_to_fp32,
            convert_nchw_to_nhwc=config.convert_nchw_to_nhwc,
            # Legacy include_* flags are OR-ed with the presence of the matching algorithm config.
            include_sq=(config.include_sq or _algo_flag(algorithms, SmoothQuantConfig)),
            include_rotation=(config.include_rotation or _algo_flag(algorithms, QuarotConfig)),
            include_cle=(config.include_cle or _algo_flag(algorithms, CLEConfig)),
            include_auto_mp=(config.include_auto_mp or _algo_flag(algorithms, AutoMixprecisionConfig)),
            include_fast_ft=(
                config.include_fast_ft
                or _algo_flag(algorithms, AdaRoundConfig)
                or _algo_flag(algorithms, AdaQuantConfig)
            ),
            enable_npu_cnn=config.enable_npu_cnn,
            enable_npu_transformer=config.enable_npu_transformer,
            debug_mode=config.debug_mode,
            crypto_mode=config.crypto_mode,
            print_summary=config.print_summary,
            extra_options=config.extra_options,
        )

    def _quantize_static_qconfig(
        self,
        model_input: Union[str, Path, onnx.ModelProto],
        model_output: Union[str, Path] | None,
        calibration_data_reader: CalibrationDataReader | None,
        calibration_data_path: str | None,
        algorithms: list[AlgoConfig],
    ) -> onnx.ModelProto | None:
        """Run static quantization using the settings that ``_map_q_config`` derives from a ``QConfig``."""
        mapping = _map_q_config(self.config, model_input)
        extra_options = mapping["extra_options"]
        return quantize_static(
            model_input=model_input,
            model_output=model_output,
            calibration_data_reader=calibration_data_reader,
            calibration_data_path=calibration_data_path,
            calibrate_method=mapping["calibrate_method"],
            quant_format=mapping["quant_format"],
            activation_type=mapping["activation_type"].map_onnx_format,
            weight_type=mapping["weight_type"].map_onnx_format,
            input_nodes=extra_options["InputNodes"],
            output_nodes=extra_options["OutputNodes"],
            op_types_to_quantize=extra_options["OpTypesToQuantize"],
            nodes_to_quantize=extra_options["NodesToQuantize"],
            specific_tensor_precision=extra_options["SpecificTensorPrecision"],
            extra_op_types_to_quantize=extra_options["ExtraOpTypesToQuantize"],
            execution_providers=extra_options["ExecutionProviders"],
            optimize_model=extra_options["OptimizeModel"],
            convert_fp16_to_fp32=extra_options["ConvertFP16ToFP32"],
            convert_nchw_to_nhwc=extra_options["ConvertNCHWToNHWC"],
            enable_npu_cnn=extra_options["EnableNPUCnn"],
            debug_mode=extra_options["DebugMode"],
            crypto_mode=extra_options["CryptoMode"],
            print_summary=extra_options["PrintSummary"],
            nodes_to_exclude=mapping["nodes_to_exclude"],
            subgraphs_to_exclude=mapping["subgraphs_to_exclude"],
            per_channel=mapping["per_channel"],
            use_external_data_format=mapping["use_external_data_format"],
            # With QConfig, optional passes are enabled solely by the presence of their algorithm config.
            include_sq=_algo_flag(algorithms, SmoothQuantConfig),
            include_rotation=_algo_flag(algorithms, QuarotConfig),
            include_cle=_algo_flag(algorithms, CLEConfig),
            include_auto_mp=_algo_flag(algorithms, AutoMixprecisionConfig),
            include_fast_ft=_algo_flag(algorithms, AdaRoundConfig) or _algo_flag(algorithms, AdaQuantConfig),
            extra_options=extra_options,
        )

    def _quantize_dynamic_legacy(
        self,
        model_input: Union[str, Path, onnx.ModelProto],
        model_output: Union[str, Path] | None,
    ) -> onnx.ModelProto | None:
        """Run dynamic quantization using the fields of a deprecated ``QuantizationConfig``."""
        config = self.config
        return quantize_dynamic(
            model_input=model_input,
            model_output=model_output,
            op_types_to_quantize=config.op_types_to_quantize,
            per_channel=config.per_channel,
            reduce_range=config.reduce_range,
            weight_type=config.weight_type,
            nodes_to_quantize=config.nodes_to_quantize,
            nodes_to_exclude=config.nodes_to_exclude,
            subgraphs_to_exclude=config.subgraphs_to_exclude,
            use_external_data_format=config.use_external_data_format,
            debug_mode=config.debug_mode,
            crypto_mode=config.crypto_mode,
            extra_options=config.extra_options,
        )