quark.torch.export.api#
Module Contents#
Classes#
- class quark.torch.export.api.ModelExporter(config: quark.torch.export.config.config.ExporterConfig, export_dir: Union[pathlib.Path, str] = tempfile.gettempdir())#
Provides an API for exporting quantized PyTorch deep learning models. This class converts the quantized model to JSON-safetensors files or an ONNX graph and saves the result to export_dir.
- Parameters:
config (ExporterConfig) – Configuration object containing settings for exporting.
export_dir (Union[Path, str]) – The target export directory. This can be either a string or a pathlib.Path object.
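For instance, a minimal setup might look like the following sketch; the output directory path here is only a placeholder:
from pathlib import Path
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG

# export_dir accepts either a plain string or a pathlib.Path object
exporter = ModelExporter(config=DEFAULT_EXPORTER_CONFIG, export_dir=Path("./output_dir"))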
- export_model_info(model: torch.nn.Module, model_type: str, model_dtype: torch.dtype = torch.float16, export_type: str = 'vllm-adopt') None #
Exports the JSON and safetensors files of the quantized PyTorch model. The model’s network architecture is stored in the JSON file, and parameters including weight, bias, scale, and zero_point are stored in the safetensors file.
- Parameters:
model (torch.nn.Module) – The quantized model to be exported.
model_type (str) – The type of the model, e.g. gpt2, gptj, llama or gptnext.
model_dtype (torch.dtype) – The weight data type of the quantized model. Default is torch.float16.
export_type (str) – The specific format in which the JSON and safetensors files are stored. The choices are ‘vllm-adopt’ and ‘native’. Default is ‘vllm-adopt’. If set to ‘vllm-adopt’, the exported files are customized for the VLLM compiler. The ‘native’ configuration is currently for internal testing only.
- Returns:
None
Examples:
export_path = "./output_dir"
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG
exporter = ModelExporter(config=DEFAULT_EXPORTER_CONFIG, export_dir=export_path)
exporter.export_model_info(model, model_type, model_dtype, export_type="vllm-adopt")
Note
Since export_type ‘native’ is currently for internal testing only, this function is used only to export the files required by the VLLM compiler. Supported quantization types include fp8, int4_per_group, and w4a8_per_group. Supported models include Llama2-7b, Llama2-13b, Llama2-70b, and Llama3-8b.
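As an optional sanity check after exporting, the stored tensors can be inspected with the safetensors library. The sketch below is only illustrative: it assumes the safetensors file(s) land directly in export_dir and therefore globs for them rather than relying on a specific filename.
from pathlib import Path
from safetensors import safe_open

export_dir = Path("./output_dir")
for st_file in export_dir.glob("*.safetensors"):
    # Tensors such as weight, bias, scale, and zero_point are stored in this file
    with safe_open(str(st_file), framework="pt") as f:
        print(st_file.name, list(f.keys())[:10])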
- export_onnx_model(model: torch.nn.Module, input_args: Union[torch.Tensor, Tuple[float]], input_names: List[str] = [], output_names: List[str] = [], verbose: bool = False, opset_version: Optional[str] = None, do_constant_folding: bool = True, operator_export_type: torch.onnx.OperatorExportTypes = torch.onnx.OperatorExportTypes.ONNX, uint4_int4_flag: bool = False) None #
Exports the ONNX graph of the quantized PyTorch model.
- Parameters:
model (torch.nn.Module) – The quantized model to be exported.
input_args (Union[torch.Tensor, Tuple[float]]) – Example inputs for this quantized model.
input_names (List[str]) – Names to assign to the input nodes of the ONNX graph, in order. Default is an empty list.
output_names (List[str]) – Names to assign to the output nodes of the ONNX graph, in order. Default is an empty list.
verbose (bool) – Flag to control whether to print verbose logs. Default is False.
opset_version (Optional[str]) – The version of the default (ai.onnx) opset to target. If not set, it defaults to the latest opset version that is stable for the current version of PyTorch.
do_constant_folding (bool) – Apply the constant-folding optimization. Default is True.
operator_export_type (torch.onnx.OperatorExportTypes) – The operator export type for the ONNX graph. The choices include OperatorExportTypes.ONNX, OperatorExportTypes.ONNX_FALLTHROUGH, OperatorExportTypes.ONNX_ATEN and OperatorExportTypes.ONNX_ATEN_FALLBACK. Default is OperatorExportTypes.ONNX.
uint4_int4_flag (bool) – Flag indicating whether the model is uint4/int4 quantized. Default is False.
- Returns:
None
Examples:
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG
exporter = ModelExporter(config=DEFAULT_EXPORTER_CONFIG, export_dir=export_path)
exporter.export_onnx_model(model, input_args)
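For illustration, a fuller call that sets several of the optional arguments might look like the following sketch; the dummy input shape and the node names are placeholders, not values prescribed by the API:
import torch
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG

exporter = ModelExporter(config=DEFAULT_EXPORTER_CONFIG, export_dir="./output_dir")
# `model` is the quantized model; the dummy input shape is a placeholder
dummy_input = torch.randint(0, 32000, (1, 128))
exporter.export_onnx_model(
    model,
    input_args=dummy_input,
    input_names=["input_ids"],
    output_names=["logits"],
    do_constant_folding=True,
    uint4_int4_flag=False,
)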
Note
Mixed quantization of int4/uint4 and int8/uint8 is not currently supported. In other words, if the model contains both uint4/int4 and uint8/int8 quantized nodes, this function cannot be used to export the ONNX graph.
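After export, the resulting graph can optionally be validated with the onnx package. The sketch below assumes the .onnx file(s) are written into the export directory and globs for them rather than relying on a specific filename:
from pathlib import Path
import onnx

for onnx_file in Path("./output_dir").glob("*.onnx"):
    model_proto = onnx.load(str(onnx_file))
    # Raises an exception if the exported graph is not a valid ONNX model
    onnx.checker.check_model(model_proto)
    print(onnx_file.name, "opset", model_proto.opset_import[0].version)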