Vision Model Quantization using Quark FX Graph Mode#

In this example, we present a vision model quantization workflow. The user provides an nn.Module, which is transformed into torch.fx.GraphModule format using the PyTorch API. During quantization, after the graph is annotated and quantizers are inserted, the modified fx.GraphModule can be used to perform PTQ (Post-Training Quantization) and/or QAT (Quantization-Aware Training). We supply demonstration code showing how users assign the quantization config; more information can be found in the User Guide.

PTQ#

In PTQ, after the FakeQuantize modules are inserted, the observers are enabled during calibration to record each tensor's distribution: values such as the minimum and maximum are tracked and later used to compute the quantization parameters, while fake quantization itself is not applied. This means all calculations during calibration run in FP32 precision. After calibration, the fake quantizers are enabled to perform quantization and evaluation.
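
As an illustration of how calibration statistics turn into quantization parameters, the short standalone sketch below derives a symmetric per-tensor INT8 scale from observed min/max values (the function name and the exact clamping range are illustrative assumptions, not Quark's internal implementation):

import torch

def symmetric_int8_scale(t_min: float, t_max: float) -> float:
    # Symmetric signed INT8 maps [-amax, amax] onto the integer grid with the
    # zero point fixed at 0; amax is the largest observed magnitude.
    amax = max(abs(t_min), abs(t_max))
    return amax / 127.0 if amax > 0 else 1.0

# During calibration the observer would update its running min/max on every batch.
x = torch.randn(4, 16)
scale = symmetric_int8_scale(x.min().item(), x.max().item())
print(f"scale = {scale:.6f}, zero_point = 0")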

QAT#

QAT uses the same model preparation as PTQ. During training, both the observer and the fake quantizer are active: the observer records each tensor's distribution (for example, its min and max values) to compute the quantization parameters, and the tensor is then quantized and dequantized by the fake quantizer.
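
The quantize-dequantize operation that the fake quantizer performs can be sketched in a few lines; this is a simplified stand-in for the real FakeQuantize module, with illustrative names and a fixed scale:

import torch

def fake_quantize(x: torch.Tensor, scale: float, qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    # Round to the integer grid, clamp to the representable range, then
    # dequantize back to float so downstream layers see the quantization error.
    q = torch.clamp(torch.round(x / scale), qmin, qmax)
    return q * scale

x = torch.randn(8)
print(fake_quantize(x, scale=0.02))

In real QAT the rounding step is paired with a straight-through estimator so that gradients can still flow back to the weights.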

TQT#

TQT (Trained Quantization Thresholds) is a method for training uniform symmetric quantizers using standard backpropagation and gradient descent. Unlike plain QAT, TQT also provides gradients for the scale factors. Unlike LSQ, which trains the scale factors directly and can suffer from stability issues, TQT constrains the scale factors to powers of two and trains log-thresholds instead. Theoretically, TQT outperforms LSQ, and LSQ outperforms plain QAT. For efficient fixed-point implementations, TQT constrains the quantization scheme to symmetric, per-tensor, power-of-2 scaling. Currently, only signed data types are supported for TQT. More experimental results are on the way.
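
The key difference from LSQ is the parameterization: the trainable quantity is a log2-threshold, and the scale is derived from it as a power of two. A minimal sketch of that forward computation (names and the bit width are illustrative; the straight-through gradient handling of the real method is omitted):

import torch

bit_width = 8
log2_t = torch.tensor(2.3, requires_grad=True)  # trainable log2-threshold (one per tensor)

# Ceil the log-threshold so the threshold is an exact power of two, then divide
# by 2^(b-1) to get the step size of a signed b-bit quantizer.
scale = 2.0 ** torch.ceil(log2_t) / (2 ** (bit_width - 1))

x = torch.randn(16)
q = torch.clamp(torch.round(x / scale), -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1)
x_hat = q * scale  # forward pass only; TQT trains log2_t via a straight-through gradient formulation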

Quick Start#

Perform PTQ to obtain the quantized model and export it to ONNX:

python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file address] \
                    --model_export onnx \
                    --export_dir [dir to save exported model]

Users can also choose to perform QAT to further improve classification accuracy. Typically, some training hyperparameters need to be adjusted to reach higher accuracy.

python3 quantize.py --data_dir [Train and Test Data folder] \
                    --model_name [mobilenetv2 or resnet18] \
                    --pretrained [Pre-trained model file address] \
                    --model_export onnx \
                    --export_dir [dir to save exported model] \
                    --qat True

LSQ and TQT are optimized QAT methods that can theoretically improve accuracy. The flags --tqt True and --lsq True are provided for users to try. Model export is not yet supported for these methods.

Fine-Grained User Guide#

Step 1: Prepare the floating-point model, dataset, and loss function

import torch
import torch.nn as nn
from torchvision.models import resnet18

# Build the floating-point model and load the pre-trained weights.
float_model = resnet18(pretrained=False)
float_model.load_state_dict(torch.load(args.pretrained))

# Calibration, training, and validation data loaders (helpers defined in the example script).
calib_loader = prepare_calib_dataset(args.data_dir, device, calib_length=args.train_batch_size * 10)
train_loader, val_loader = prepare_data_loaders(args.data_dir)
criterion = nn.CrossEntropyLoss().to(device)

Step 2: Transform the ``torch.nn.Module`` into a ``torch.fx.GraphModule``.

from torch._export import capture_pre_autograd_graph

# Capture the model as an FX graph before autograd, as required for FX graph mode quantization.
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device), )
graph_model = capture_pre_autograd_graph(float_model, example_inputs)

Step 3: Initialize the quantizer and the quantization configuration

from quark.torch import ModelQuantizer
from quark.torch.quantization.config.config import QuantizationSpec, QuantizationConfig, Config
from quark.torch.quantization.config.type import Dtype, QSchemeType, ScaleType, RoundType, QuantizationMode
from quark.torch.quantization.observer.observer import PerTensorMinMaxObserver

# INT8, per-tensor, symmetric quantization with float scale factors.
INT8_PER_TENSOR_SPEC = QuantizationSpec(dtype=Dtype.int8,
                                        qscheme=QSchemeType.per_tensor,
                                        observer_cls=PerTensorMinMaxObserver,
                                        symmetric=True,
                                        scale_type=ScaleType.float,
                                        round_method=RoundType.half_even,
                                        is_dynamic=False)

# Apply the same spec to inputs, outputs, weights, and biases.
quant_config = QuantizationConfig(input_tensors=INT8_PER_TENSOR_SPEC,
                                  output_tensors=INT8_PER_TENSOR_SPEC,
                                  weight=INT8_PER_TENSOR_SPEC,
                                  bias=INT8_PER_TENSOR_SPEC)
config = Config(global_quant_config=quant_config,
                quant_mode=QuantizationMode.fx_graph_mode)
quantizer = ModelQuantizer(config)

Step 4: Generate the quantized graph model by performing calibration

# Insert quantizers, run the calibration batches through the observers, and return the quantized graph model.
quantized_model = quantizer.quantize_model(graph_model, calib_loader)

Step 5 (Optional): QAT for higher accuracy

# `train` and `device_ids` are defined in the example script (multi-GPU fine-tuning helper).
train(quantized_model, train_loader, val_loader, criterion, device_ids)
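
Conceptually, a single-device fine-tuning loop for the prepared model would look roughly like the sketch below (the hyperparameters are illustrative assumptions, not the example script's recipe):

import torch

optimizer = torch.optim.SGD(quantized_model.parameters(), lr=1e-4, momentum=0.9)
for images, targets in train_loader:
    images, targets = images.to(device), targets.to(device)
    optimizer.zero_grad()
    loss = criterion(quantized_model(images), targets)
    loss.backward()   # gradients flow back through the fake quantizers
    optimizer.step()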

Step 6: Validate model performance and export the model

import os

# Accuracy of the quantized model before freezing.
acc1_quant = validate(val_loader, quantized_model, criterion, device)
# Freeze the quantized model before export.
freezed_model = quantizer.freeze(quantized_model)
acc1_freeze = validate(val_loader, freezed_model, criterion, device)
# check whether acc1_quant == acc1_freeze

# ============== export to ONNX ==================
from quark.torch import ModelExporter
from quark.torch.export.config.config import ExporterConfig, JsonExporterConfig
export_config = ExporterConfig(json_export_config=JsonExporterConfig())
exporter = ModelExporter(config=export_config, export_dir=args.export_dir)
example_inputs = (torch.rand(args.train_batch_size, 3, 224, 224).to(device),)
exporter.export_onnx_model(freezed_model, example_inputs[0])

# ========== export using torch.export ============
example_inputs = (next(iter(val_loader))[0].to(device),)
model_file_path = os.path.join(args.export_dir, args.model_name + ".pth")
exported_model = torch.export.export(freezed_model, example_inputs)
torch.export.save(exported_model, model_file_path)
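
The saved ExportedProgram can later be reloaded with torch.export.load and executed without the original model definition; a brief usage sketch:

loaded_program = torch.export.load(model_file_path)
with torch.no_grad():
    output = loaded_program.module()(example_inputs[0])  # run the reloaded graph module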

Experiment Results#

We conducted PTQ and QAT on both ResNet-18 and MobileNet-V2. In these models, all weights, biases, and activations are quantized to INT8 with per-tensor, symmetric quantization (zero point is 0), and the scale factors are stored in float format. The following table shows the validation accuracy (Top-1 / Top-5) on the ImageNet dataset produced by the script above.

Method       | ResNet-18       | MobileNetV2
-------------|-----------------|----------------
Float Model  | 69.764 / 89.085 | 71.881 / 90.301
PTQ (INT8)   | 69.084 / 88.648 | 65.291 / 86.254
QAT (INT8)   | 69.469 / 88.872 | 68.562 / 88.484