Vision Model Quantization using Quark FX Graph Mode#

In this example, we present a vision model quantization workflow. The user specifies an nn.Module and transforms the model into torch.fx.GraphModule format using the PyTorch API. During quantization, after annotation and insertion of quantizers, this modified fx.GraphModule can be used to perform PTQ (Post-Training Quantization) and/or QAT (Quantization-Aware Training). We supply demonstration code and show how users assign the quantization config; more information can be found in the User Guide.

PTQ#

In PTQ, after the FakeQuantize modules are inserted, the observers are enabled during calibration to record each tensor's distribution: values such as the min and max are collected to compute the quantization parameters, while fake quantization is not yet applied. This means all calculations remain in FP32 precision. After calibration, we enable the fake quantizers to perform quantization and evaluation.
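
The two phases can be pictured with PyTorch's stock observer and fake-quantizer toggles. This is only a minimal sketch of the mechanism, not Quark's API: Quark's ``quantize_model`` (Step 4 in the guide below) drives calibration internally, and ``prepared_model`` and ``calib_loader`` here are placeholder names.

import torch
from torch.ao.quantization.fake_quantize import (
    enable_fake_quant, disable_fake_quant, enable_observer, disable_observer)

# Calibration phase: observers collect min/max statistics, fake quantization is
# switched off, so every computation still runs in FP32.
prepared_model.apply(enable_observer)
prepared_model.apply(disable_fake_quant)
with torch.no_grad():
    for images, _ in calib_loader:
        prepared_model(images)

# Evaluation phase: freeze the observers and turn on fake quantization so the
# collected statistics are used to quantize/dequantize tensors.
prepared_model.apply(disable_observer)
prepared_model.apply(enable_fake_quant)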

QAT#

As in PTQ, the model is prepared first. During training, both the observer and the fake quantizer are active: the observer keeps recording the tensor's distribution (such as the min and max values) to compute the quantization parameters, and the tensor is quantized by the fake quantizer in the forward pass.

Quick Start#

Perform PTQ to get the quantized model and export it to ONNX:

python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file path] \
                    --model_export onnx \
                    --export_dir [dir to save exported model]

Users can also choose to perform QAT to further improve classification accuracy. Typically, some training parameters need to be adjusted for higher accuracy.

python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file path] \
                    --model_export onnx \
                    --export_dir [dir to save exported model] \
                    --qat True

Fine-Grained User Guide#

Step 1: Prepare the floating-point model, dataset, and loss function

import torch
import torch.nn as nn
from torchvision.models import resnet18

# Load the floating-point model from a pre-trained checkpoint.
# pretrained, args, and device come from the example script's argument parsing.
float_model = resnet18(pretrained=False)
float_model.load_state_dict(torch.load(pretrained))

# prepare_calib_dataset and prepare_data_loaders are helpers from the example script.
calib_loader = prepare_calib_dataset(args.data_dir, device, calib_length=args.train_batch_size * 10)
train_loader, val_loader = prepare_data_loaders(args.data_dir)
criterion = nn.CrossEntropyLoss().to(device)
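
For reference, a hypothetical sketch of what these data-loading helpers might look like, assuming an ImageNet-style directory layout with ``train`` and ``val`` subfolders; the actual example script provides its own implementations.

import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import ImageFolder

_TRANSFORM = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def prepare_data_loaders(data_dir, batch_size=64):
    # Training and validation loaders over an ImageNet-style folder layout.
    train_set = ImageFolder(os.path.join(data_dir, "train"), _TRANSFORM)
    val_set = ImageFolder(os.path.join(data_dir, "val"), _TRANSFORM)
    return (DataLoader(train_set, batch_size=batch_size, shuffle=True),
            DataLoader(val_set, batch_size=batch_size, shuffle=False))

def prepare_calib_dataset(data_dir, device, calib_length=640):
    # A small, fixed-size slice of the training set used only for calibration.
    # device is accepted to mirror the script's signature but is unused here.
    train_set = ImageFolder(os.path.join(data_dir, "train"), _TRANSFORM)
    calib_set = Subset(train_set, range(calib_length))
    return DataLoader(calib_set, batch_size=64, shuffle=False)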

Step 2: Transform the ``torch.nn.Module`` to ``torch.fx.GraphModule``

from torch._export import capture_pre_autograd_graph
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device), )
graph_model = capture_pre_autograd_graph(float_model, example_inputs)

Step 3: Initialize the quantizer and quantization configuration

from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import QuantizationSpec, QuantizationConfig, Config
from quark.torch.quantization.config.type import Dtype, QSchemeType, ScaleType, RoundType, QuantizationMode
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver

# INT8, per-tensor, symmetric quantization with float scales and static (calibrated) ranges.
INT8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.int8,
                                        qscheme=QSchemeType.per_tensor,
                                        observer_cls=PerTensorMinMaxObserver,
                                        symmetric=True,
                                        scale_type=ScaleType.float,
                                        round_method=RoundType.half_even,
                                        is_dynamic=False)

# Apply the same spec to activations (input/output tensors), weights, and biases.
quant_config = QuantizationConfig(input_tensors=INT8_PER_TENSOR_SPEC,
                                  output_tensors=INT8_PER_TENSOR_SPEC,
                                  weight=INT8_PER_TENSOR_SPEC,
                                  bias=INT8_PER_TENSOR_SPEC)
config = Config(global_quant_config=quant_config,
                quant_mode=QuantizationMode.fx_graph_mode)
quantizer = ModelQuantizer(config)

Step 4: Generate the quantized graph model by performing calibration

quantized_model = quantizer.quantize_model(graph_model, calib_loader)

Step 5 (Optional): QAT for higher accuracy

# train is the fine-tuning routine provided by the example script.
train(quantized_model, train_loader, val_loader, criterion, device_ids)
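
Conceptually, ``train`` is an ordinary training loop run on the fake-quantized model. A minimal single-device sketch follows; the optimizer, learning rate, and epoch count are assumptions, and the real ``train`` in the example script adds validation, checkpointing, and multi-GPU support.

# Hypothetical QAT loop: observers keep updating min/max statistics while
# fake quantization applies a straight-through estimator in the backward pass.
optimizer = torch.optim.SGD(quantized_model.parameters(), lr=1e-5, momentum=0.9)
num_epochs = 3  # assumption for illustration

quantized_model.train()
for epoch in range(num_epochs):
    for images, targets in train_loader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(quantized_model(images), targets)
        loss.backward()
        optimizer.step()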

Step 6: Validate model performance and export

# validate is the evaluation helper from the example script.
acc1_quant = validate(val_loader, quantized_model, criterion, device)
freezed_model = quantizer.freeze(quantized_model)
acc1_freeze = validate(val_loader, freezed_model, criterion, device)
# Check whether acc1_quant == acc1_freeze; freezing should not change accuracy.

# ============== Export to ONNX ==============
from quark.torch import ModelExporter
from quark.torch.export.config.custom_config import DEFAULT_EXPORTER_CONFIG
config = DEFAULT_EXPORTER_CONFIG
exporter = ModelExporter(config=config, export_dir=args.export_dir)
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device),)
exporter.export_onnx_model(freezed_model, example_inputs[0])

# ============ Export using torch.export ============
import os
example_inputs = (next(iter(val_loader))[0].to(device),)
model_file_path = os.path.join(args.export_dir, args.model_name + ".pth")
exported_model = torch.export.export(freezed_model, example_inputs)
torch.export.save(exported_model, model_file_path)
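
The saved program can be reloaded later for inference; a short usage sketch, reusing ``model_file_path`` and ``example_inputs`` from above:

# Reload the exported program and run it through its callable GraphModule.
loaded_program = torch.export.load(model_file_path)
output = loaded_program.module()(example_inputs[0])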

Experiment Results#

We conducted PTQ and QAT on both ResNet-18 and MobileNet-V2. In both models, all weights, biases, and activations are quantized. All tensors are quantized to INT8, per-tensor, symmetric (zero point is 0), and the scale factor is stored in float format. The following table shows the validation accuracy (Top-1 / Top-5) on the ImageNet dataset produced by the above script.

Method         ResNet-18          MobileNetV2
Float Model    69.764 / 89.085    71.881 / 90.301
PTQ (INT8)     69.084 / 88.648    65.291 / 86.254
QAT (INT8)     69.469 / 88.872    68.562 / 88.484