Source code for quark.torch.export.api

#
# Copyright (C) 2023 - 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark Exporting and Importing API for PyTorch."""

from __future__ import annotations

import dataclasses
import json
import re
import shutil
import tempfile
from abc import ABC, abstractmethod
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Any

import torch
import torch.nn as nn
from tqdm import tqdm

if TYPE_CHECKING:
    from quark.torch.quantization.config.config import QConfig

from quark.shares.utils.import_utils import (
    is_accelerate_available,
    is_gguf_available_and_version_0_6_0,
    is_safetensors_available,
    is_transformers_available,
)
from quark.shares.utils.log import ScreenLogger
from quark.torch.algorithm.rotation.rotation import RotationProcessor
from quark.torch.export.config.config import JsonExporterConfig
from quark.torch.export.json_export.builder.native_model_info_builder import NativeModelInfoBuilder
from quark.torch.export.main_export.model_post_process import ModelPostProcessor
from quark.torch.export.main_export.quant_config_parser import QuantConfigParser, get_layer_quant_config
from quark.torch.export.main_import.pretrained_config import PretrainedConfig
from quark.torch.export.nn.modules.qparamslinear import QParamsLinear, QParamsLinearWithRotation
from quark.torch.export.onnx import convert_model_to_uint4_int4, export_onnx_model_optimization
from quark.torch.export.safetensors import _load_weights_from_safetensors, export_hf_model
from quark.torch.export.utils import (
    _build_quantized_model,
    _convert_quantized_model,
    _fix_loaded_weights_key_mismatch,
    _fix_state_dict_key_on_save,
    _handle_multi_device_loading,
    _untie_parameters,
)
from quark.torch.quantization.config.type import QuantizationMode
from quark.torch.quantization.model_transformation import (
    export_cache_state_dict_from_model,
    import_model_with_cache_from_safetensors,
)
from quark.torch.quantization.tensor_quantize import NonScaledFakeQuantize, ScaledFakeQuantize
from quark.torch.utils import QPARAMSLINEAR_OVERRIDES_STATE_DICT, setattr_recursive

if is_gguf_available_and_version_0_6_0():
    from quark.torch.export.gguf_export.api import convert_exported_model_to_gguf
if is_transformers_available():
    from transformers import PreTrainedModel
if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, add_hook_to_module
if is_safetensors_available():
    from safetensors.torch import save_file

__all__ = [
    "export_safetensors",
    "export_onnx",
    "export_gguf",
    "import_model_from_safetensors",
    "save_params",
]

logger = ScreenLogger(__name__)


def _get_submodule_or_none(model: nn.Module, module_path: str) -> nn.Module | None:
    """Resolve a dotted submodule path without raising exceptions.

    Works with nested modules and ModuleList entries (numeric names).
    Returns None if any segment is missing.
    """
    cur: nn.Module = model
    for part in module_path.split("."):
        modules = getattr(cur, "_modules", None)
        if not isinstance(modules, dict) or part not in modules:
            return None
        cur = modules[part]
    return cur
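
# Usage sketch for _get_submodule_or_none (illustrative only; the dotted path below
# refers to a hypothetical transformers-style attention projection):
#
#   module = _get_submodule_or_none(model, "model.layers.0.self_attn.k_proj")
#   if module is not None:
#       module.register_buffer("output_scale", scale)  # safe: the path was resolved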


class BaseExporter(ABC):
    """Base class for all model exporters."""

    def __init__(self, **kwargs: Any) -> None:
        self.output_dir: Path

    @abstractmethod
    def _validate(self) -> None:
        """Validate export parameters and model compatibility."""
        raise NotImplementedError("`_validate` cannot be called directly on BaseExporter")

    @abstractmethod
    def _export_impl(self) -> None:
        """Perform the actual export operation."""
        raise NotImplementedError("`_export_impl` cannot be called directly on BaseExporter")

    def _export(self, *args: Any, **kwargs: Any) -> None:
        """Process the model for validation and export. Sets attributes, validates, and exports."""

        self.output_dir.mkdir(parents=True, exist_ok=True)

        self._validate()
        self._export_impl()


class SafetensorsExporter(BaseExporter):
    """Base exporter for Safetensors format."""

    def __init__(
        self, model: torch.nn.Module, output_dir: Path, custom_mode: str, weight_format: str, pack_method: str
    ) -> None:
        super().__init__()

        self.model = model
        self.output_dir = output_dir
        self.custom_mode = custom_mode
        self.weight_format = weight_format
        self.pack_method = pack_method

    def _validate(self) -> None:
        """Validate Safetensors export parameters."""
        if self.weight_format not in ["real_quantized", "fake_quantized"]:
            raise ValueError(
                f"Weight_format must be one of `real_quantized`, `fake_quantized` when exporting to safetensors format, got {self.weight_format}."
            )
        if self.pack_method not in ["reorder", "order"]:
            raise ValueError(
                f"Pack_method must be one of `reorder`, `order` when exporting to safetensors format, got {self.pack_method}."
            )

        # Validate that model has quant_config
        if getattr(self.model, "quark_quantized", False) and getattr(self.model, "quant_config", None) is None:
            raise ValueError("Model must have a 'quant_config' attribute if it is quantized with quark.")

        # Validate model type
        if not is_transformers_available() or not isinstance(self.model, PreTrainedModel):
            raise NotImplementedError(
                "Exporting to safetensors format is currently only supported for Transformers models. Please open an issue."
            )

    def _prepare_quantization_config(
        self, quant_config: QConfig, temp_json_config: JsonExporterConfig
    ) -> dict[str, Any]:
        """Prepare quantization configuration. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement _prepare_quantization_config")

    def _export_impl(self) -> None:
        """Export model to Safetensors format."""
        # Get quant_config from the model
        quant_config = getattr(self.model, "quant_config", None)

        # Create a copy of the model to avoid modifying the original
        original_config = self.model.config.__dict__.copy()

        # Prepare cache for export if present
        cache_prepared = False
        if quant_config is not None:
            from quark.torch.quantization.model_transformation import prepare_model_for_cache_export

            cache_prepared = prepare_model_for_cache_export(self.model, quant_config)

        # Create temporary config objects for processing
        # If the kv_cache is not quantized, there is no need to export the kv_cache group names.
        kv_cache_group = (
            getattr(quant_config, "kv_cache_group", []) if getattr(quant_config, "kv_cache_quant_config", {}) else []
        )
        temp_json_config = JsonExporterConfig(
            weight_format=self.weight_format,
            pack_method=self.pack_method,
            kv_cache_group=kv_cache_group,
            min_kv_scale=getattr(quant_config, "min_kv_scale", 0.0),
        )

        if quant_config is not None:
            # Prepare quantization configuration
            quantization_config_dict = self._prepare_quantization_config(quant_config, temp_json_config)

            # Update model config with quantization info
            self.model.config.update({"quantization_config": quantization_config_dict})

        # Process model for export
        processor = ModelPostProcessor(
            self.model,
            temp_json_config,
            custom_mode=self.custom_mode,
            output_quant=quant_config is not None and quant_config.global_quant_config.output_tensors is not None,
        )
        processor.merge_scale()
        processed_model = processor.get_processed_model()

        # Export cache state dict if present
        cache_state_dict = {}
        if cache_prepared and quant_config is not None:
            cache_state_dict = export_cache_state_dict_from_model(self.model, quant_config)
            if cache_state_dict:
                logger.info(f"Including cache quantization state dict with {len(cache_state_dict)} entries")

        # Export using HF format (single call), optionally attach alias-only output scales
        alias_only_cache: dict[str, torch.Tensor] = (
            {k: v for k, v in cache_state_dict.items() if k.endswith(".output_scale")} if cache_state_dict else {}
        )
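        # Each alias-only key (e.g. "<module path>.output_scale") is attached below as a
        # temporary buffer on its owning module so that export_hf_model serializes it;
        # buffers registered this way are removed again right after the export call.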

        inserted_buffers: list[tuple[nn.Module, str]] = []
        if alias_only_cache:
            for full_key, tensor in alias_only_cache.items():
                module_path = full_key.rsplit(".", 1)[0]
                target_module = _get_submodule_or_none(processed_model, module_path)
                if target_module is None:
                    logger.debug(
                        "[CACHE EXPORT] Skipping alias %s because module path %s was not found",
                        full_key,
                        module_path,
                    )
                    continue

                existing_attr = getattr(target_module, "output_scale", None)
                if isinstance(existing_attr, torch.Tensor):
                    with torch.no_grad():
                        existing_attr.copy_(tensor)
                elif existing_attr is None:
                    target_module.register_buffer("output_scale", tensor.clone())
                    inserted_buffers.append((target_module, "output_scale"))
                else:
                    logger.debug(
                        "[CACHE EXPORT] Unable to attach output_scale to %s because attribute already exists",
                        module_path,
                    )
                    continue

        if self.weight_format == "real_quantized":
            # Useful only for Transformers models, see the comment below regarding the serialization keys.
            processed_model._fix_state_dict_key_on_save = staticmethod(_fix_state_dict_key_on_save)

        # Export using HF format
        export_hf_model(model=processed_model, export_dir=str(self.output_dir))

        # Clean up any temporary buffers we inserted
        for module, buffer_name in inserted_buffers:
            buffers = getattr(module, "_buffers", None)
            if isinstance(buffers, dict) and buffer_name in buffers:
                del buffers[buffer_name]
            if hasattr(module, buffer_name):
                delattr(module, buffer_name)

        # Reset model config to original state
        self.model.config.__dict__.clear()
        self.model.config.__dict__.update(original_config)

        # Reset model to original state
        processor.reset_model()

        logger.info(f"Successfully exported model to Safetensors format in {self.custom_mode} mode: {self.output_dir}")


class QuarkSafetensorsExporter(SafetensorsExporter):
    """Exporter for Safetensors format in quark mode."""

    def _validate(self) -> None:
        """Validate quark mode export parameters."""
        super()._validate()
        if self.custom_mode != "quark":
            raise ValueError(f"QuarkSafetensorsExporter only supports custom_mode='quark', got {self.custom_mode}.")

    def _prepare_quantization_config(
        self, quant_config: QConfig, temp_json_config: JsonExporterConfig
    ) -> dict[str, Any]:
        """Prepare quantization configuration for quark mode."""
        quark_quant_config = quant_config.to_dict()
        quantization_config_dict = {}

        # Handle quark mode
        quark_quant_config["export"] = dataclasses.asdict(temp_json_config)
        quantization_config_dict.update(quark_quant_config)
        return quantization_config_dict


class CustomSafetensorsExporter(SafetensorsExporter):
    """Exporter for Safetensors format in custom modes (awq, fp8)."""

    def _validate(self) -> None:
        """Validate custom mode export parameters."""
        super()._validate()
        if self.custom_mode not in ["awq", "fp8"]:
            raise ValueError(
                f"CustomSafetensorsExporter only supports custom_mode in ['awq', 'fp8'], got {self.custom_mode}."
            )

    def _prepare_quantization_config(
        self, quant_config: QConfig, temp_json_config: JsonExporterConfig
    ) -> dict[str, Any]:
        """Prepare quantization configuration for custom modes (awq, fp8)."""
        quark_quant_config = quant_config.to_dict()
        quantization_config_dict = {}
        config_parser = QuantConfigParser(quant_config, temp_json_config)

        # Handle custom modes (awq, fp8)
        custom_config, inferred_custom_mode = config_parser.get_custom_config()
        if inferred_custom_mode != self.custom_mode:
            raise ValueError(
                f"Requested to export the model in the custom mode `{self.custom_mode}`, but the quantization config used does not appear to match with this `custom_mode`."
            )

        if len(custom_config) > 0:
            quantization_config_dict.update(custom_config)
        else:
            quantization_config_dict.update(quark_quant_config)

        # Add export info for HF format
        quantization_config_dict["export"] = dataclasses.asdict(temp_json_config)
        return quantization_config_dict
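
# Illustrative note (assumption, not a normative schema): the "quantization_config"
# entry written into the model's config.json differs by mode. In "quark" mode it is
# the full quark quantization config dict plus an "export" section; in "awq"/"fp8"
# mode it is the parser-derived custom config plus the same "export" section, e.g.:
#
#   "export": {
#       "weight_format": "real_quantized",
#       "pack_method": "reorder",
#       "kv_cache_group": [...],
#       "min_kv_scale": 0.0,
#       ...
#   }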


class OnnxExporter(BaseExporter):
    """Exporter for ONNX format."""

    def __init__(
        self,
        model: torch.nn.Module,
        output_dir: Path,
        input_args: tuple[Any, ...],
        opset_version: int | None,
        input_names: list[str],
        output_names: list[str],
        verbose: bool,
        do_constant_folding: bool,
        operator_export_type: torch.onnx.OperatorExportTypes,
        uint4_int4_flag: bool,
        dynamo: bool = False,
    ) -> None:
        super().__init__()

        self.model = model
        self.output_dir = output_dir
        self.input_args = input_args
        self.opset_version = opset_version
        self.input_names = input_names
        self.output_names = output_names
        self.verbose = verbose
        self.do_constant_folding = do_constant_folding
        self.operator_export_type = operator_export_type
        self.uint4_int4_flag = uint4_int4_flag
        self.dynamo = dynamo

    def _validate(self) -> None:
        """Validate ONNX export parameters."""
        # Basic validation - ONNX export is generally more permissive
        if not isinstance(self.input_args, (torch.Tensor, tuple)):
            raise ValueError("input_args must be a torch.Tensor or tuple")

    def _export_impl(self) -> None:
        """Export model to ONNX format."""
        logger.info("Start exporting quantized onnx model ...")

        # When the transformers version is higher than 4.55.0, the use_cache option introduces
        # DynamicCache objects during ONNX export, which makes the export fail.
        # Disable use_cache to avoid this.
        if hasattr(self.model, "config"):
            original_use_cache = getattr(self.model.config, "use_cache", None)
            if hasattr(self.model.config, "use_cache"):
                self.model.config.use_cache = False

        # Enable fake quantization for ONNX export
        for module in self.model.modules():
            if isinstance(module, ScaledFakeQuantize):
                module.disable_observer()
                module.enable_fake_quant()

        # Define output path
        onnx_path = self.output_dir / "quark_model.onnx"

        # Export to ONNX
        try:
            torch.onnx.export(
                self.model.eval(),
                self.input_args,
                str(onnx_path),
                verbose=self.verbose,
                input_names=self.input_names,
                output_names=self.output_names,
                opset_version=self.opset_version,
                do_constant_folding=self.do_constant_folding,
                operator_export_type=self.operator_export_type,
                dynamo=self.dynamo,
            )
        except Exception as e:
            if not self.dynamo:
                raise Exception(
                    f"The ONNX export failed during `torch.onnx.export` call. This could be due to an issue in the model definition not compatible with torch.jit.trace-based ONNX export, or other issues. Consider trying to export using `dynamo=True`, refer to `quark.torch.export.api.export_onnx` API documentation. Error: {e}"
                )
            else:
                raise Exception(
                    f"The ONNX export failed during `torch.onnx.export` call. Consider trying to export using `dynamo=False`, refer to `quark.torch.export.api.export_onnx` API documentation. Error: {e}"
                )

        export_onnx_model_optimization(onnx_path)

        # Handle uint4/int4 conversion if needed
        if self.uint4_int4_flag:
            convert_model_to_uint4_int4(str(onnx_path))
        else:
            logger.info(f"Quantized onnx model exported to {onnx_path} successfully.")

        # restore the use_cache option
        if hasattr(self.model, "config") and hasattr(self.model.config, "use_cache"):
            self.model.config.use_cache = original_use_cache

        logger.info(f"Successfully exported model to ONNX format: {onnx_path}")


class GgufExporter(BaseExporter):
    """Exporter for GGUF format."""

    def __init__(self, model: torch.nn.Module, output_dir: Path, model_type: str, tokenizer_path: str | Path) -> None:
        super().__init__()

        self.model = model
        self.output_dir = output_dir
        self.model_type = model_type
        self.tokenizer_path = tokenizer_path

    def _validate(self) -> None:
        """Validate GGUF export parameters."""
        if not self.model_type:
            raise ValueError("model_type must be specified for GGUF export")

        # tokenizer_path may be a local directory or a HuggingFace model name; in either
        # case it is passed through as-is and the GGUF converter handles validation.
        self.actual_tokenizer_path = self.tokenizer_path

    def _export_impl(self) -> None:
        """Export model to GGUF format."""
        logger.info("Start exporting GGUF model ...")

        # First export to quark format (JSON + safetensors)
        temp_quark_dir = self.output_dir / "temp_quark_export"
        temp_quark_dir.mkdir(exist_ok=True)

        temp_config = JsonExporterConfig()
        params_dict: dict[str, torch.Tensor] = {}
        builder = NativeModelInfoBuilder(model=self.model, config=temp_config)
        info = builder.build_model_info(params_dict)

        # Save JSON info
        json_path = temp_quark_dir / f"{self.model_type}.json"
        with open(json_path, "w") as f:
            json.dump(info, f, indent=4)

        # Handle tensor sharing for safetensors
        data_ptr_list: list[str] = []
        for key, value in params_dict.items():
            if str(value.data_ptr()) in data_ptr_list:
                params_dict[key] = value.clone()
            else:
                data_ptr_list.append(str(value.data_ptr()))

        # Save safetensors
        if not is_safetensors_available():
            raise ImportError(
                "The function `export_gguf` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
            )

        safetensors_path = temp_quark_dir / f"{self.model_type}.safetensors"
        save_file(params_dict, safetensors_path)

        # Convert to GGUF format
        gguf_output_path = self.output_dir / f"{self.model_type}.gguf"
        convert_exported_model_to_gguf(
            model_name=self.model_type,
            json_path=json_path,
            safetensor_path=safetensors_path,
            tokenizer_dir=self.actual_tokenizer_path,
            output_file_path=gguf_output_path,
        )

        # Clean up temporary files
        shutil.rmtree(temp_quark_dir)

        logger.info(f"Successfully exported model to GGUF format: {gguf_output_path}")


def export_safetensors(
    model: torch.nn.Module,
    output_dir: str | Path,
    custom_mode: str = "quark",
    weight_format: str = "real_quantized",
    pack_method: str = "reorder",
) -> None:
    """
    Export the quantized PyTorch model to Safetensors format. The model's network architecture or
    configuration is stored in the json file, and parameters including weight, bias, scale, and
    zero_point are stored in the safetensors file.

    :param torch.nn.Module model: The quantized model to be exported.
    :param Union[str, Path] output_dir: Directory to save the exported files.
    :param str custom_mode: Export mode determining quantization handling. Defaults to ``"quark"``. Possible values are:

        * ``"quark"``: standard quark format. This is the default and recommended format that should be favored.
        * ``"awq"``: targets AutoAWQ library.
        * ``"fp8"``: targets vLLM-compatible fp8 models.
    :param str weight_format: How to handle quantized parameters. Defaults to ``"real_quantized"``. Possible values are:

        * ``"real_quantized"``: actual quantized parameters.
        * ``"fake_quantized"``: QDQ (Quantize-Dequantize) representation of quantized parameters.
    :param str pack_method: Real_quantized parameter packing strategy. Defaults to ``"reorder"``. Possible values are:

        * ``"reorder"``: reorder the real_quantized parameters layout for hardware.
        * ``"order"``: keep the original real_quantized parameters layout.
    :return: ``None``

    Example:

    .. code-block:: python

        from quark.torch import export_safetensors

        export_path = "./output_dir"
        export_safetensors(model, export_path, custom_mode="quark", weight_format="real_quantized", pack_method="reorder")
    """
    # Get quant_config from the model
    if getattr(model, "quark_quantized", False) and getattr(model, "quant_config", None) is None:
        raise ValueError("Model must have a 'quant_config' attribute if it is quantized with quark.")

    if custom_mode != "quark":
        logger.warning(
            f"The 'custom_mode' parameter is deprecated and will be removed in version 1.0. "
            f"Currently using custom_mode='{custom_mode}', but only 'quark' mode will be supported in the future. "
            f"Please migrate to using custom_mode='quark'."
        )

    # Choose the appropriate exporter based on custom_mode
    if custom_mode == "quark":
        exporter_cls = QuarkSafetensorsExporter
    elif custom_mode in ["awq", "fp8"]:
        exporter_cls = CustomSafetensorsExporter  # type: ignore
    else:
        raise ValueError(f"Custom_mode must be one of `quark`, `fp8`, `awq`, got {custom_mode}.")

    exporter = exporter_cls(
        model=model,
        output_dir=Path(output_dir),
        custom_mode=custom_mode,
        weight_format=weight_format,
        pack_method=pack_method,
    )
    exporter._export()
def export_onnx(
    model: torch.nn.Module,
    output_dir: str | Path,
    input_args: tuple[Any, ...],
    opset_version: int | None = None,
    input_names: list[str] = [],
    output_names: list[str] = [],
    verbose: bool = False,
    do_constant_folding: bool = True,
    operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX,
    uint4_int4_flag: bool = False,
    dynamo: bool = False,
) -> None:
    """
    Export the onnx graph of the quantized PyTorch model.

    :param torch.nn.Module model: The quantized model to be exported.
    :param Union[str, Path] output_dir: Directory to save the ONNX file.
    :param Union[torch.Tensor, Tuple[float]] input_args: Example inputs for ONNX tracing.
    :param Optional[int] opset_version: The version of the ONNX opset to target. If not set, it defaults to the latest version that is stable for the current version of PyTorch. Defaults to ``None``.
    :param List[str] input_names: Names to assign to the input nodes of the onnx graph, in order. Defaults to ``[]``.
    :param List[str] output_names: Names to assign to the output nodes of the onnx graph, in order. Defaults to ``[]``.
    :param bool verbose: Flag to control showing verbose log or not. Defaults to ``False``.
    :param bool do_constant_folding: Flag to apply constant folding optimization. Defaults to ``True``.
    :param torch.onnx.OperatorExportTypes operator_export_type: Export operator type in onnx graph. The choices include ``OperatorExportTypes.ONNX``, ``OperatorExportTypes.ONNX_FALLTHROUGH``, ``OperatorExportTypes.ONNX_ATEN`` and ``OperatorExportTypes.ONNX_ATEN_FALLBACK``. Defaults to ``OperatorExportTypes.ONNX``.
    :param bool uint4_int4_flag: Flag to indicate uint4/int4 quantized model or not. Defaults to ``False``.
    :param bool dynamo: Whether to export the model with ``torch.export`` ExportedProgram instead of TorchScript. Please refer to `PyTorch documentation <https://docs.pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.export>`__ for more details. Defaults to ``False``.
    :return: None

    Example:

    .. code-block:: python

        from quark.torch import export_onnx

        export_onnx(model, output_dir, input_args)

    **Note**: Mixed quantization of int4/uint4 and int8/uint8 is not supported currently. In other words, if the model contains both quantized nodes of uint4/int4 and uint8/int8, this function cannot be used to export the ONNX graph.
    """
    exporter = OnnxExporter(
        model=model,
        output_dir=Path(output_dir),
        input_args=input_args,
        opset_version=opset_version,
        input_names=input_names,
        output_names=output_names,
        verbose=verbose,
        do_constant_folding=do_constant_folding,
        operator_export_type=operator_export_type,
        uint4_int4_flag=uint4_int4_flag,
        dynamo=dynamo,
    )
    exporter._export()
def export_gguf(
    model: torch.nn.Module,
    output_dir: str | Path,
    model_type: str,
    tokenizer_path: str | Path,
) -> None:
    """
    Export the gguf file of the quantized PyTorch model.

    :param torch.nn.Module model: The quantized model to be exported.
    :param Union[str, Path] output_dir: Directory to save the GGUF file.
    :param str model_type: The model type of the model, e.g. ``"gpt2"``, ``"gptj"``, or ``"llama"``.
    :param Union[str, Path] tokenizer_path: Tokenizer needs to be encoded into the gguf model. This argument specifies the directory path of the tokenizer, which contains tokenizer.json, tokenizer_config.json and/or tokenizer.model.
    :return: None

    Example:

    .. code-block:: python

        from quark.torch import export_gguf

        export_gguf(model, output_dir, model_type, tokenizer_path)

    Note: Currently, only asymmetric int4 per_group weight-only quantization is supported, and the group_size must be 32. Supported models include Llama2-7b, Llama2-13b, Llama2-70b, and Llama3-8b.
    """
    if not is_gguf_available_and_version_0_6_0():
        raise ImportError(
            "The function `export_gguf` requires the package `gguf==0.6.0` to be installed, but it was not found. Please install `gguf==0.6.0`."
        )

    exporter = GgufExporter(
        model=model, output_dir=Path(output_dir), model_type=model_type, tokenizer_path=tokenizer_path
    )
    exporter._export()
class BaseImporter(ABC):
    """Base class for all model importers."""

    def __init__(self) -> None:
        pass

    @abstractmethod
    def _validate(self) -> None:
        """Validate import parameters and model compatibility."""
        pass

    @abstractmethod
    def _import_impl(self) -> torch.nn.Module:
        """Perform the actual import operation."""
        pass

    def _import(self, *args: Any, **kwargs: Any) -> torch.nn.Module:
        """Process the model for validation and import. Sets attributes, validates, and imports."""
        for k, v in kwargs.items():
            setattr(self, k, v)

        self._validate()
        return self._import_impl()


class SafetensorsImporter(BaseImporter):
    """Importer for Safetensors format."""

    def __init__(self) -> None:
        super().__init__()
        # Declare attributes that will be set by _import method
        self.model: torch.nn.Module
        self.model_dir: str
        self.multi_device: bool

    def _validate(self) -> None:
        """Validate Safetensors import parameters."""
        if not is_safetensors_available():
            raise ImportError(
                "The function `import_model_from_safetensors` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
            )

    def _import_impl(self) -> torch.nn.Module:
        """Import model from Safetensors format."""
        logger.info("Start importing safetensors quantized model ...")

        # Create temporary model config object
        model_config = PretrainedConfig(pretrained_dir=self.model_dir)

        # Load weights from file, on cpu device.
        checkpoint_weights = _load_weights_from_safetensors(self.model_dir)

        # For some transformer models, the huggingface checkpoint keys are different from the model.state_dict keys.
        # So we need to convert the checkpoint keys to model.state_dict keys using the _checkpoint_conversion_mapping attribute.
        if hasattr(self.model, "_checkpoint_conversion_mapping") and len(self.model._checkpoint_conversion_mapping) > 0:
            checkpoint_conversion_mapping = self.model._checkpoint_conversion_mapping
            converted_weights = {}
            for name, param in checkpoint_weights.items():
                new_name = name
                for pattern, replacement in checkpoint_conversion_mapping.items():
                    converted_name = re.sub(pattern, replacement, name)
                    if converted_name != name:
                        new_name = converted_name
                        break  # Apply first matching pattern only
                converted_weights[new_name] = param
            checkpoint_weights = converted_weights

        original_model_on_meta_device = False
        for name, param in chain(self.model.named_parameters(), self.model.named_buffers()):
            if param.device.type == "meta":
                original_model_on_meta_device = True
                break

        if original_model_on_meta_device:
            has_non_persistent_buffers = any(
                len(submodule._non_persistent_buffers_set) > 0 for submodule in self.model.modules()
            )
            if has_non_persistent_buffers:
                raise NotImplementedError(
                    "Reloading a safetensors model using the original non-quantized model placed on meta device while it contains non-persistent buffers is not supported, as the non-persistent buffers can not be reloaded from the serialized checkpoint. Please consider initializing the original non-quantized model on cpu or cuda device. Please open an issue for the feature to be supported."
                )

        # Build model with quantization support
        model = _build_quantized_model(self.model, model_config, checkpoint_weights)

        # Handle parameter untying
        if is_accelerate_available():
            _untie_parameters(model, checkpoint_weights)

        # Save cache-related keys BEFORE any filtering or state_dict operations
        # (needed for real_quantized mode where these keys are not in model's state_dict)
        cache_state_dict = {
            k: v for k, v in checkpoint_weights.items() if ".output_scale" in k and ("k_proj" in k or "v_proj" in k)
        }
        if cache_state_dict:
            logger.debug(
                f"Saved {len(cache_state_dict)} cache keys before filtering. Cache keys: {', '.join(list(cache_state_dict.keys()))}."
            )

        # Get current model state dict
        model_state_dict = model.state_dict()

        # There is a mismatch between serialized checkpoints and `QParamsLinear` parameters/buffers keys.
        # See context in #3665.
        # TODO: Remove condition once we drop transformers<=4.56 support.
        if not QPARAMSLINEAR_OVERRIDES_STATE_DICT:
            checkpoint_weights = _fix_loaded_weights_key_mismatch(
                checkpoint_weights,
                weight_format=model_config.weight_format,
                custom_mode=model_config.quantization_config["quant_method"],
            )

        # In case we are loading the quantized weights into a model that is not on meta device,
        # we re-use the original device the weights were placed on, as `assign=True` is used later.
        # This is helpful e.g. in case the original model was dispatched to multiple
        # devices ahead of time with `accelerate`.
        for name, param in model_state_dict.items():
            if name not in checkpoint_weights:
                raise ValueError(f"The loaded checkpoint misses the key {name} present in the model weights.")
            else:
                if param.device.type != "meta":
                    checkpoint_weights[name] = checkpoint_weights[name].to(param.device)

        # Handle multi-device loading if enabled
        if self.multi_device and is_accelerate_available():
            _handle_multi_device_loading(model, checkpoint_weights)

        # Load weights into model with strict=False to handle missing quantization parameters
        model.load_state_dict(checkpoint_weights, assign=True, strict=False)

        # Convert model
        model = _convert_quantized_model(model, model_config)

        config_from_model = getattr(model, "quant_config", None)
        if config_from_model is not None:
            # Use the saved cache state dict for import
            import_model_with_cache_from_safetensors(model, cache_state_dict, config_from_model)

        logger.info("safetensors quantized model imported successfully.")
        return model
def import_model_from_safetensors(
    model: torch.nn.Module, model_dir: str, multi_device: bool = False
) -> torch.nn.Module:
    """
    Imports a quantized model from the local directory ``model_dir`` into a non-quantized model ``model``.

    :param torch.nn.Module model: The non-quantized model that will be transformed in place into a quantized model using the ``"quantization_config"`` in the ``config.json`` file retrieved from the local directory ``model_dir``, and into which the quantized weights will be loaded.
    :param str model_dir: Directory containing the model files (``config.json`` and ``model.safetensors``).
    :param bool multi_device: Whether to use multi-device loading using the Accelerate library. Defaults to ``False``.

    :return: The model with loaded weights and proper quantization modules.
    """
    importer = SafetensorsImporter()
    return importer._import(model=model, model_dir=model_dir, multi_device=multi_device)
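
# Usage sketch (illustrative; the model class and directories below are assumptions,
# not requirements of this API):
#
#   from transformers import AutoModelForCausalLM
#
#   model = AutoModelForCausalLM.from_pretrained("<original-model-id>")
#   model = import_model_from_safetensors(model, "./exported_model_dir")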
def save_params(
    model: nn.Module,
    model_type: str,
    args: tuple[Any, ...] | None = None,
    kwargs: dict[str, Any] | None = None,
    export_dir: Path | str = tempfile.gettempdir(),
    quant_mode: QuantizationMode = QuantizationMode.eager_mode,
    compressed: bool = False,
    reorder: bool = True,
) -> None:
    """
    Save the network architecture or configurations and parameters of the quantized model.
    For eager mode quantization, the model's configurations are stored in a json file, and parameters including weight, bias, scale, and zero_point are stored in a safetensors file.
    For fx_graph mode quantization, the model's network architecture and parameters are stored in a pth file.

    :param torch.nn.Module model: The quantized model to be saved.
    :param str model_type: The type of the model, e.g. gpt2, gptj, llama or gptnext.
    :param Optional[Tuple[Any, ...]] args: Example tuple inputs for this quantized model. Only available for fx_graph mode quantization. Default is ``None``.
    :param Optional[Dict[str, Any]] kwargs: Example dict inputs for this quantized model. Only available for fx_graph mode quantization. Default is ``None``.
    :param Union[Path, str] export_dir: The target export directory.
    :param QuantizationMode quant_mode: The quantization mode. The choices include ``QuantizationMode.eager_mode`` and ``QuantizationMode.fx_graph_mode``. Default is ``QuantizationMode.eager_mode``.
    :param bool compressed: Whether to export the compressed (real quantized) model or the QDQ model. Default is ``False``, which exports the QDQ model.
    :param bool reorder: Pack method that packs the weights (e.g. packs four ``torch.int8`` values into one ``torch.int32`` value). Default is ``True``.
    :return: None

    Examples:

    .. code-block:: python

        # eager mode:
        from quark.torch import save_params

        save_params(model, model_type=model_type, export_dir="./save_dir")

    .. code-block:: python

        # fx_graph mode:
        from quark.torch.export.api import save_params

        save_params(model, model_type=model_type, args=example_inputs, export_dir="./save_dir", quant_mode=QuantizationMode.fx_graph_mode)
    """
    logger.info("Start saving parameters of quantized model ...")

    for name, submodule in model.named_modules():
        if isinstance(submodule, (ScaledFakeQuantize, NonScaledFakeQuantize)):
            if not submodule.frozen_params and not submodule.is_dynamic:
                raise ValueError(
                    f"`model = ModelQuantizer.freeze(model)` needs to be called prior to running `save_params`, but found soft parameters in the model (in {name}). Please double check your code or open an issue."
                )

    export_dir = Path(export_dir)
    export_dir.mkdir(parents=True, exist_ok=True)

    if quant_mode is QuantizationMode.eager_mode:
        if not is_safetensors_available():
            raise ImportError(
                "The function `save_params` with `quant_mode=QuantizationMode.eager_mode` requires the package `safetensors` to be installed, but it was not found. Please install `safetensors`."
            )

        params_dict: dict[str, torch.Tensor] = {}
        builder = NativeModelInfoBuilder(model=model, config=JsonExporterConfig())
        info = builder.build_model_info(params_dict, compressed=compressed, reorder=reorder)
        json_path = export_dir / f"{model_type}.json"
        with open(json_path, "w") as f:
            json.dump(info, f, indent=4)

        # handle shared tensors
        data_ptr_list: list[str] = []
        for key, value in params_dict.items():
            if str(value.data_ptr()) in data_ptr_list:
                params_dict[key] = value.clone()
            else:
                data_ptr_list.append(str(value.data_ptr()))

        params_path = export_dir / f"{model_type}.safetensors"
        save_file(params_dict, params_path)
    elif quant_mode is QuantizationMode.fx_graph_mode:
        if args is None:
            raise ValueError("args should not be None when saving fx_graph_mode quantized model")
        model_file_path = export_dir / f"{model_type}_quantized.pth"
        exported_model = torch.export.export(model, args, kwargs=kwargs)
        torch.export.save(exported_model, model_file_path)

    logger.info(f"Parameters of quantized model saved to {export_dir} successfully.")
def _map_to_quark(model: nn.Module, quantization_config: QConfig, pack_method: str, custom_mode: str) -> None:
    """
    Maps a non-quantized model (possibly on meta device) to a model with QParamsLinear layers with weights not initialized. This function is useful to later load a checkpoint in the quark model using `model.load_state_dict(state_dict)`.

    Parameters:
        model (torch.nn.Module): An instance of the original not-quantized model. This model may be on `meta` device, or may have random weights.
        quantization_config (QConfig): The quantization configuration originally used to quantize the model in Quark.
        pack_method (str): The packing method used when the model was serialized.
        custom_mode (str): The custom mode to use to initialize the `QParamsLinear` layers. The recommended mode is simply quark-native `"quark"`, but `"awq"` and `"fp8"` are also available.
    """
    named_modules = dict(model.named_modules(remove_duplicate=False))

    layers_online_rotation = set()
    rotation_config = quantization_config.get_rotation_config()
    if rotation_config is not None:
        layers_online_rotation = RotationProcessor.get_online_rotation_layers(rotation_config, model)
        if rotation_config.r3:
            raise NotImplementedError(
                "Reloading a model quantization using rotation algorithm with r3=True is not supported at the moment. Please open an issue."
            )

    for op_name, float_module in tqdm(named_modules.items()):
        op_type = type(float_module)
        layer_quantization_config = get_layer_quant_config(quantization_config, op_type, op_name)
        if layer_quantization_config is not None and isinstance(float_module, nn.Linear):
            if op_name in layers_online_rotation:
                qparams_linear_cls = QParamsLinearWithRotation
            else:
                qparams_linear_cls = QParamsLinear
            qparams_linear = qparams_linear_cls.from_module(
                float_module,
                custom_mode=custom_mode,
                pack_method=pack_method,
                algo_config=quantization_config.algo_config,
                quant_config=layer_quantization_config,
            )
            # For multi_device, the accelerate hook carries the placement info.
            if hasattr(float_module, "_hf_hook"):
                hook = float_module._hf_hook
                quark_hook = AlignDevicesHook(
                    execution_device=hook.execution_device,
                    offload=hook.offload,
                    io_same_device=hook.io_same_device,
                    weights_map=hook.weights_map,
                    offload_buffers=hook.offload_buffers,
                    place_submodules=hook.place_submodules,
                    skip_keys=hook.skip_keys,
                    tied_params_map=hook.tied_params_map,
                )
                add_hook_to_module(qparams_linear, quark_hook)

            setattr_recursive(model, op_name, qparams_linear)
            float_module.to("meta")
            del float_module
            # This call is needed to lower the peak memory.
            torch.cuda.empty_cache()


def _move_quantizer_to_dict(model: nn.Module) -> None:
    """
    Move the model's QParamsLinear quantizers into a dict, so that the model works with tensor parallelism (tp).

    Parameters:
        model (torch.nn.Module): An instance of the original not-quantized model. This model may be on `meta` device, or may have random weights.
    """
    dict_name = "_quant_dict"
    quantizer_names = ["weight_quantizer", "input_quantizer", "output_quantizer", "bias_quantizer"]
    named_modules = dict(model.named_modules(remove_duplicate=False))
    for module_name, float_module in tqdm(named_modules.items()):
        # If the current module has any of the quantizers named above, set that attribute to None and save the quantizer into the dict.
        if isinstance(float_module, (torch.nn.Linear, torch.nn.Module)):
            if hasattr(float_module, dict_name):
                qdict = {}
                for quantizer_name in quantizer_names:
                    if hasattr(float_module, quantizer_name):
                        quantizer = getattr(float_module, quantizer_name, None)
                        if quantizer is not None:
                            qdict[quantizer_name] = quantizer
                            setattr(float_module, quantizer_name, None)
                if len(qdict) > 0:
                    setattr(float_module, dict_name, qdict)
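
# Usage sketch for _map_to_quark (illustrative; based on its docstring, which suggests
# loading a serialized checkpoint after the mapping):
#
#   _map_to_quark(model, quantization_config, pack_method="reorder", custom_mode="quark")
#   model.load_state_dict(state_dict)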