Source code for quark.onnx.optimize

#
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
from quark.shares.utils.log import ScreenLogger
import numpy as np
from math import sqrt
import copy
import onnx
from onnxruntime.quantization.onnx_model import ONNXModel
from onnxruntime.transformers.onnx_model import OnnxModel
from onnxruntime.transformers.fusion_layernorm import FusionLayerNormalization
from onnxruntime.transformers.fusion_gelu import FusionGelu
from .quant_utils import QUANT_OP_TYPES, DEQUANT_OP_TYPES, get_clip_min_max, get_opset_version
from typing import Tuple, List, Optional, Union
from numpy.typing import NDArray
from onnx import ModelProto, NodeProto

logger = ScreenLogger(__name__)


class Optimize(object):
    """
    A class for optimizations to be applied to onnx model before quantization.

    :param onnx.ModelProto model: The ONNX model to be optimized.
    :param List[str] op_types_to_quantize: A list of operation types to be quantized.
    :param Optional[List[str]] nodes_to_quantize: A list of node names to be quantized.
    :param Optional[List[str]] nodes_to_exclude: A list of node names to be excluded from quantization. Defaults to ``None``.
    """

    def __init__(self, model: ModelProto, op_types_to_quantize: List[str], nodes_to_quantize: Optional[List[str]],
                 nodes_to_exclude: Optional[List[str]]) -> None:
        self.model = model
        self.op_types_to_quantize = op_types_to_quantize
        self.nodes_to_quantize = nodes_to_quantize
        self.nodes_to_exclude = nodes_to_exclude

    def should_quantize_node(self, node: NodeProto) -> bool:
        if (self.nodes_to_quantize is not None and len(self.nodes_to_quantize) != 0
                and node.name not in self.nodes_to_quantize):
            return False

        if node.op_type not in self.op_types_to_quantize:
            return False

        if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude:
            return False

        return True

    def replace_node_with(self, node: NodeProto, replaced_type: str) -> NodeProto:
        new_node = onnx.helper.make_node(replaced_type, inputs=node.input, outputs=node.output, name=node.name)
        self.model.graph.node.append(new_node)
        return new_node
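
    # Filtering sketch (illustrative, not part of the original module): a node is
    # optimized only if it passes all three checks in should_quantize_node, i.e. it
    # is named in nodes_to_quantize (when that list is non-empty), its op type is
    # in op_types_to_quantize, and it is not listed in nodes_to_exclude. The node
    # name below is a hypothetical placeholder.
    #
    #   opt = Optimize(model, op_types_to_quantize=["Conv"],
    #                  nodes_to_quantize=None, nodes_to_exclude=["conv_head"])
    #   # A Conv named "conv_head" is skipped; any other Conv is eligible.
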
    def convert_bn_to_conv(self) -> None:
        """Convert BatchNormalization to Conv.
        """

        def _get_folded_conv_weights(bn_gamma: NDArray[np.float32], bn_beta: NDArray[np.float32],
                                     bn_mm: NDArray[np.float32], bn_mv: NDArray[np.float32],
                                     bn_epsilon: float) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
            if bn_gamma is not None:
                multiplier = bn_gamma / np.sqrt(bn_mv + bn_epsilon)
            else:
                multiplier = 1 / np.sqrt(bn_mv + bn_epsilon)
            folded_conv_kernel = multiplier
            folded_conv_bias = bn_beta + (-bn_mm) * multiplier
            return folded_conv_kernel, folded_conv_bias

        self.op_types_to_quantize.append("BatchNormalization")

        nodes_to_remove: List[NodeProto] = []
        init_to_remove: List[str] = []
        onnx_model = ONNXModel(self.model)

        for node in onnx_model.model.graph.node:
            if node.op_type == 'BatchNormalization' and self.should_quantize_node(node):
                input_name = node.input[0]
                input_shape: List[int] = []
                for input_info in onnx_model.model.graph.value_info:
                    if input_info.name == input_name:
                        input_shape = [dim.dim_value for dim in input_info.type.tensor_type.shape.dim]

                if len(node.input) == 5 and len(input_shape) == 4:
                    bn_epsilon = next((attr.f for attr in node.attribute if attr.name == 'epsilon'), 1e-10)
                    for init in onnx_model.model.graph.initializer:
                        if init.name == node.input[1]:
                            bn_gamma = onnx.numpy_helper.to_array(init)
                        elif init.name == node.input[2]:
                            bn_beta = onnx.numpy_helper.to_array(init)
                        elif init.name == node.input[3]:
                            bn_mm = onnx.numpy_helper.to_array(init)
                        elif init.name == node.input[4]:
                            bn_mv = onnx.numpy_helper.to_array(init)
                    try:
                        weights, bias = _get_folded_conv_weights(bn_gamma, bn_beta, bn_mm, bn_mv, bn_epsilon)
                        num_channel = bn_mm.shape[0]
                        weights = weights.reshape([num_channel, 1, 1, 1])
                        weights_tensor = onnx.numpy_helper.from_array(weights, name=node.output[0] + "weights")
                        bias_tensor = onnx.numpy_helper.from_array(bias, name=node.output[0] + "bias")
                        onnx_model.model.graph.initializer.extend([weights_tensor, bias_tensor])

                        new_node = onnx.helper.make_node(
                            "Conv",
                            inputs=[node.input[0], node.output[0] + "weights", node.output[0] + "bias"],
                            outputs=[node.output[0]],
                            group=num_channel,
                            kernel_shape=[1, 1],
                            strides=[1, 1],
                            name=node.name)
                        nodes_to_remove.append(node)
                        init_to_remove.extend([node.input[1], node.input[2], node.input[3], node.input[4]])
                        onnx_model.model.graph.node.append(new_node)
                        logger.info(f"Found BatchNormalization node {node.name}. "
                                    f"Replacing with Conv.")
                    except Exception as e:
                        logger.warning(f"Failed to generate the Conv's weights and bias because of {e}, "
                                       "skip converting BatchNormalization to Conv")
                else:
                    logger.warning(f"Failed to convert BatchNormalization {node.name} to Conv because "
                                   "its inputs or input shape do not meet the requirements")

        onnx_model.remove_nodes(nodes_to_remove)
        onnx_model.remove_initializers(init_to_remove)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
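
    # Numeric sketch (illustrative, not part of the original module) of the folding
    # above: per channel, BatchNormalization is an affine map, which is what the
    # depthwise 1x1 Conv encodes with kernel = gamma / sqrt(var + eps) and
    # bias = beta - mean * kernel.
    #
    #   x = np.random.rand(1, 3, 4, 4).astype(np.float32)
    #   gamma, beta = np.random.rand(3), np.random.rand(3)
    #   mean, var, eps = np.random.rand(3), np.random.rand(3) + 0.1, 1e-5
    #   def per_channel(v):
    #       return v.reshape(1, 3, 1, 1)
    #   bn_out = (per_channel(gamma) * (x - per_channel(mean)) /
    #             np.sqrt(per_channel(var) + eps) + per_channel(beta))
    #   kernel = gamma / np.sqrt(var + eps)   # the 1x1 depthwise Conv weight
    #   bias = beta - mean * kernel           # the Conv bias
    #   conv_out = x * per_channel(kernel) + per_channel(bias)
    #   assert np.allclose(bn_out, conv_out, atol=1e-5)
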
    def convert_reduce_mean_to_global_avg_pool(self) -> None:
        """Convert ReduceMean to GlobalAveragePool.
        """
        from .quant_utils import check_reduce_mean_condition
        nodes_to_remove = []
        onnx_model = ONNXModel(self.model)
        for node in onnx_model.model.graph.node:
            if node.op_type == 'ReduceMean' and check_reduce_mean_condition(
                    onnx_model.model, node) and self.should_quantize_node(node):
                if len(node.input) == 1:
                    new_node = self.replace_node_with(node, 'GlobalAveragePool')
                    nodes_to_remove.append(node)
                    logger.info(f"Found ReduceMean node {node.name} with axes=[2, 3]. "
                                f"Replacing with GlobalAveragePool.")
                # Handling opset >= 18 for ReduceMean, whose axes come in as a second input
                elif len(node.input) == 2:
                    new_node = onnx.helper.make_node('GlobalAveragePool',
                                                     inputs=[node.input[0]],
                                                     outputs=node.output,
                                                     name=node.name)
                    nodes_to_remove.append(node)
                    onnx_model.model.graph.node.append(new_node)
                    logger.info(f"Found ReduceMean node {node.name} with axes=[2, 3]. "
                                f"Replacing with GlobalAveragePool.")
        onnx_model.remove_nodes(nodes_to_remove)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
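
    # Semantics sketch (illustrative, not part of the original module): on an NCHW
    # tensor, GlobalAveragePool is the mean over the spatial axes with kept dims,
    # which is exactly ReduceMean(axes=[2, 3], keepdims=1); hence the rewrite is
    # value-preserving for 4-D inputs. From opset 18 on, ReduceMean carries axes as
    # a second input instead of an attribute, which is the two-input case above.
    #
    #   x = np.random.rand(1, 3, 8, 8).astype(np.float32)
    #   gap_out = np.mean(x, axis=(2, 3), keepdims=True)  # shape (1, 3, 1, 1)
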
    def split_large_kernel_pool(self) -> None:
        """
        For pooling with an excessively large kernel size in the onnx model,
        split it into multiple smaller poolings.
        """

        def _get_factors(num: int) -> Tuple[int, int]:
            factor_1 = int(sqrt(num))
            while factor_1 > 1:
                if num % factor_1 == 0:
                    factor_2 = num / factor_1
                    return int(factor_1), int(factor_2)
                factor_1 = factor_1 - 1
            factor_2 = num
            return int(factor_1), int(factor_2)

        onnx_model = ONNXModel(self.model)
        for node in onnx_model.model.graph.node:
            if node.op_type == "GlobalAveragePool" and self.should_quantize_node(node):
                input_name = node.input[0]
                kw = None
                kh = None
                for input_info in onnx_model.model.graph.value_info:
                    if input_info.name == input_name:
                        input_shape = [dim.dim_value for dim in input_info.type.tensor_type.shape.dim]
                        if len(input_shape) == 4:
                            shape_to_check = True
                            kh = input_shape[2]
                            kw = input_shape[3]
                        break
                if not kw or not kh:
                    logger.warning('Failed to get the input shape, skip optimizing for GlobalAveragePool {}.'.format(
                        node.name))
                    continue
                # Only one split is supported.
                # TODO: Support multiple split operations
                elif kw * kh > 512:
                    kh1, kh2 = _get_factors(kh)
                    kw1, kw2 = _get_factors(kw)
                    if kh1 * kw1 > 512 or kh2 * kw2 > 512:
                        logger.warning("After split, the kernel size is still too large. "
                                       "Currently, only one split is supported. Skip optimization.")
                    else:
                        split_tensor = node.input[0] + "_Split"
                        pool_node = onnx.helper.make_node("AveragePool",
                                                          inputs=[node.input[0]],
                                                          outputs=[split_tensor],
                                                          kernel_shape=[kh1, kw1],
                                                          strides=[kh1, kw1],
                                                          name=split_tensor)

                        if not node.name:
                            node.name = node.output[0]
                        node.input[0] = split_tensor

                        onnx_model.model.graph.node.extend([pool_node])
                        logger.info(f"Found GlobalAveragePool node {node.name} with large kernel size. "
                                    f"Split it into multiple AveragePools.")

        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
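
    # Numeric sketch (illustrative, not part of the original module): averaging in
    # two stages is exact when the first-stage blocks tile the input evenly, which
    # is what the kernel/stride factorization above guarantees.
    #
    #   x = np.random.rand(1, 1, 36, 36).astype(np.float32)   # kh = kw = 36
    #   kh1, kh2 = 6, 6                                        # _get_factors(36)
    #   # First stage: AveragePool with kernel (6, 6) and stride (6, 6).
    #   stage1 = x.reshape(1, 1, 6, 6, 6, 6).mean(axis=(3, 5))
    #   # Second stage: the remaining GlobalAveragePool over the (6, 6) output.
    #   stage2 = stage1.mean(axis=(2, 3), keepdims=True)
    #   assert np.allclose(stage2, x.mean(axis=(2, 3), keepdims=True))
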
    def convert_split_to_slice(self) -> None:
        """Convert Split to Slice.
        """
        nodes_to_remove: List[NodeProto] = []
        init_to_remove: List[str] = []
        onnx_model = ONNXModel(self.model)
        for node in onnx_model.model.graph.node:
            if node.op_type == 'Split' and self.should_quantize_node(node):
                num_input = len(node.input)

                axis_attr = next((attr for attr in node.attribute if attr.name == 'axis'), None)
                assert axis_attr is not None, "No axis attribute found in Split node"
                axis = axis_attr.i

                input_name = node.input[0]
                output_names = node.output

                if num_input == 2:
                    splits = None
                    for init in onnx_model.model.graph.initializer:
                        if init.name == node.input[1]:
                            splits = onnx.numpy_helper.to_array(init).tolist()
                    if splits is None:
                        logger.warning(f"No split detected of {node.name}, "
                                       "failed to convert split to slice, please check the input model.")
                        break
                elif num_input == 1:
                    split_attr = next((attr for attr in node.attribute if attr.name == 'split'), None)
                    if split_attr is None:
                        logger.warning(f"No split detected of {node.name}, "
                                       "failed to convert split to slice, please check the input model.")
                        break
                    splits = split_attr.ints
                else:
                    logger.warning(f"Failed to convert split of {node.name} to slice, "
                                   "the number of input nodes is not supported.")
                    break

                starts = [sum(splits[:i]) for i in range(len(splits))]
                ends = [sum(splits[:i + 1]) for i in range(len(splits))]

                for i in range(len(output_names)):
                    starts_node = onnx.helper.make_node('Constant',
                                                        inputs=[],
                                                        outputs=[output_names[i] + '_starts_' + str(i)],
                                                        value=onnx.helper.make_tensor(
                                                            name=output_names[i] + '_starts_' + str(i),
                                                            data_type=onnx.TensorProto.INT64,
                                                            dims=[1],
                                                            vals=[starts[i]]))
                    ends_node = onnx.helper.make_node('Constant',
                                                      inputs=[],
                                                      outputs=[output_names[i] + '_ends_' + str(i)],
                                                      value=onnx.helper.make_tensor(
                                                          name=output_names[i] + '_ends_' + str(i),
                                                          data_type=onnx.TensorProto.INT64,
                                                          dims=[1],
                                                          vals=[ends[i]]))
                    axes_node = onnx.helper.make_node('Constant',
                                                      inputs=[],
                                                      outputs=[output_names[i] + '_axes_' + str(i)],
                                                      value=onnx.helper.make_tensor(
                                                          name=output_names[i] + '_axes_' + str(i),
                                                          data_type=onnx.TensorProto.INT64,
                                                          dims=[1],
                                                          vals=[axis]))
                    steps_node = onnx.helper.make_node('Constant',
                                                       inputs=[],
                                                       outputs=[output_names[i] + '_steps_' + str(i)],
                                                       value=onnx.helper.make_tensor(
                                                           name=output_names[i] + '_steps_' + str(i),
                                                           data_type=onnx.TensorProto.INT64,
                                                           dims=[1],
                                                           vals=[1]))
                    slice_node = onnx.helper.make_node("Slice",
                                                       inputs=[
                                                           input_name, output_names[i] + '_starts_' + str(i),
                                                           output_names[i] + '_ends_' + str(i),
                                                           output_names[i] + '_axes_' + str(i),
                                                           output_names[i] + '_steps_' + str(i)
                                                       ],
                                                       outputs=[output_names[i]],
                                                       name=output_names[i] + '_' + str(i))
                    onnx_model.model.graph.node.extend([slice_node, starts_node, ends_node, axes_node, steps_node])

                nodes_to_remove.append(node)
                if len(node.input) > 1:
                    init_to_remove.append(node.input[1])
                logger.info(f"Found Split node {node.name}. "
                            f"Replacing with Slice.")

        onnx_model.remove_nodes(nodes_to_remove)
        onnx_model.remove_initializers(init_to_remove)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
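
    # Semantics sketch (illustrative, not part of the original module): each Split
    # output is a contiguous slice along the split axis, so Split can be rewritten
    # as one Slice per output with starts/ends derived from the cumulative splits.
    #
    #   x = np.arange(12).reshape(2, 6)
    #   splits = [2, 4]                       # Split(axis=1, split=[2, 4])
    #   starts, ends = [0, 2], [2, 6]         # cumulative sums, as computed above
    #   out0, out1 = x[:, 0:2], x[:, 2:6]     # the equivalent Slice nodes
    #   expected = np.split(x, [2], axis=1)
    #   assert (expected[0] == out0).all() and (expected[1] == out1).all()
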
    def fuse_instance_norm(self) -> None:
        '''
        The decomposed instance normalization subgraph will be fused into an InstanceNormalization operation.
        '''
        onnx_model = ONNXModel(self.model)
        tensor_to_producer_dict = {}
        remove_nodes: List[NodeProto] = []
        remove_inits: List[onnx.TensorProto] = []
        for node in onnx_model.model.graph.node:
            for output in node.output:
                tensor_to_producer_dict[output] = node
        for init in onnx_model.model.graph.initializer:
            tensor_to_producer_dict[init.name] = init
        for node in onnx_model.model.graph.node:
            if node.op_type == "Add":
                try:
                    add0_i0 = node.input[0]
                    add0_i1 = node.input[1]
                    add0_i0_node = tensor_to_producer_dict[add0_i0]
                    add0_i1_node = tensor_to_producer_dict[add0_i1]
                    # TODO: Use different dictionaries to distinguish between node and init.
                    if add0_i0_node.op_type == "Mul" and add0_i1_node.op_type == "Sub":
                        sub0_node = add0_i1_node
                        sub0_i0 = sub0_node.input[0]
                        sub0_i1 = sub0_node.input[1]
                        sub0_i1_node = tensor_to_producer_dict[sub0_i1]
                        if sub0_i1_node.op_type == "Mul":
                            mul0_node = sub0_i1_node
                            mul0_i0 = mul0_node.input[0]
                            mul0_i1 = mul0_node.input[1]
                            mul0_i0_node = tensor_to_producer_dict[mul0_i0]
                            mul0_i1_node = tensor_to_producer_dict[mul0_i1]
                            if mul0_i0_node.op_type == "GlobalAveragePool" and mul0_i1_node.op_type == "Mul":
                                mul1_node = mul0_i1_node
                                mul1_i0 = mul1_node.input[0]
                                mul1_i1 = mul1_node.input[1]
                                mul1_i0_node = tensor_to_producer_dict[mul1_i0]
                                mul1_i1_node = tensor_to_producer_dict[mul1_i1]
                                if mul1_i0_node.op_type == "Reciprocal":
                                    rec0_node = mul1_i0_node
                                    rec0_i0 = rec0_node.input[0]
                                    rec0_i0_node = tensor_to_producer_dict[rec0_i0]
                                    if rec0_i0_node.op_type == "Sqrt":
                                        sqr0_node = rec0_i0_node
                                        sqr0_i0 = sqr0_node.input[0]
                                        sqr0_i0_node = tensor_to_producer_dict[sqr0_i0]
                                        if sqr0_i0_node.op_type == "Add":
                                            add1_node = sqr0_i0_node
                                            add1_i0 = add1_node.input[0]
                                            add1_i1 = add1_node.input[1]
                                            add1_i0_node = tensor_to_producer_dict[add1_i0]
                                            if add1_i0_node.op_type == "GlobalAveragePool":
                                                gap0_node = add1_i0_node
                                                gap0_i0 = gap0_node.input[0]
                                                gap0_i0_node = tensor_to_producer_dict[gap0_i0]
                                                if gap0_i0_node.op_type == "Mul":
                                                    mul2_node = gap0_i0_node
                                                    mul2_i0 = mul2_node.input[0]
                                                    mul2_i0_node = tensor_to_producer_dict[mul2_i0]
                                                    if mul2_i0_node.op_type == "Sub":
                                                        sub1_node = mul2_i0_node
                                                        sub1_i0 = sub1_node.input[0]
                                                        sub1_i1 = sub1_node.input[1]
                                                        sub1_i0_node = tensor_to_producer_dict[sub1_i0]
                                                        sub1_i1_node = tensor_to_producer_dict[sub1_i1]
                                                        if sub1_i1_node.op_type == "GlobalAveragePool":
                                                            # Remove nodes
                                                            remove_node_list = [
                                                                node,
                                                                add0_i0_node,
                                                                add0_i1_node,
                                                                sub0_i1_node,
                                                                mul0_i0_node,
                                                                mul0_i1_node,
                                                                mul1_i0_node,
                                                                rec0_i0_node,
                                                                sqr0_i0_node,
                                                                add1_i0_node,
                                                                gap0_i0_node,
                                                                mul2_i0_node,
                                                            ]
                                                            # Add InstanceNormalization
                                                            bias_init = onnx_model.get_initializer(sub0_i0)
                                                            bias_init.dims[:] = [bias_init.dims[1]]
                                                            weight_init = onnx_model.get_initializer(mul1_i1)
                                                            weight_init.dims[:] = [weight_init.dims[1]]
                                                            eps_init = onnx_model.get_initializer(add1_i1)
                                                            instance_norm_node = onnx.helper.make_node(
                                                                "InstanceNormalization",
                                                                [sub1_i0, mul1_i1, sub0_i0],
                                                                node.output,
                                                                node.name,
                                                                epsilon=onnx.numpy_helper.to_array(eps_init).item())
                                                            logger.info(
                                                                f"Matched Instance Normalization, fuse it into "
                                                                f"InstanceNormalization {node.name}")
                                                            onnx_model.add_node(instance_norm_node)
                                                            remove_nodes.extend(remove_node_list)
                                                            remove_inits.append(eps_init)
                except Exception as e:
                    logger.debug(f"FuseInstanceNorm is enabled, but {node.name} does not meet "
                                 f"the matching rules: {e}, skipping this node")
        onnx_model.remove_nodes(remove_nodes)
        onnx_model.remove_initializers(remove_inits)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
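
    # Numeric sketch (illustrative, not part of the original module) of the pattern
    # matched above: the decomposed subgraph computes, per sample and channel,
    # y = x * m + (b - mean * m) with m = g / sqrt(var + eps) and mean/var taken
    # over the spatial dims, which is exactly InstanceNormalization.
    #
    #   x = np.random.rand(2, 3, 8, 8).astype(np.float32)
    #   g = np.random.rand(1, 3, 1, 1).astype(np.float32)   # scale (weight)
    #   b = np.random.rand(1, 3, 1, 1).astype(np.float32)   # bias
    #   eps = 1e-5
    #   mean = x.mean(axis=(2, 3), keepdims=True)            # GlobalAveragePool
    #   var = ((x - mean) ** 2).mean(axis=(2, 3), keepdims=True)
    #   m = g / np.sqrt(var + eps)                           # Mul(Reciprocal(Sqrt(Add)), g)
    #   decomposed = x * m + (b - mean * m)                  # Add(Mul, Sub) pattern
    #   fused = g * (x - mean) / np.sqrt(var + eps) + b      # InstanceNormalization
    #   assert np.allclose(decomposed, fused, atol=1e-5)
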
    def fuse_l2_norm(self) -> None:
        """
        Convert L2norm ops to LpNormalization.
        """
        onnx_model = ONNXModel(self.model)
        tensor_to_producer_dict = {}
        remove_nodes: List[NodeProto] = []
        remove_inits: List[onnx.TensorProto] = []
        for node in onnx_model.model.graph.node:
            for output in node.output:
                tensor_to_producer_dict[output] = node
        for init in onnx_model.model.graph.initializer:
            tensor_to_producer_dict[init.name] = init
        for node in onnx_model.model.graph.node:
            if node.op_type == "Mul":
                try:
                    inp_0 = node.input[0]
                    inp_1 = node.input[1]
                    inp_0_node = tensor_to_producer_dict[inp_0]
                    inp_1_node = tensor_to_producer_dict[inp_1]
                    if inp_0_node.op_type == "Unsqueeze" and inp_1_node.op_type == "Reciprocal":
                        rec_node = inp_1_node
                        rec_inp_0 = rec_node.input[0]
                        rec_inp_0_node = tensor_to_producer_dict[rec_inp_0]
                        if rec_inp_0_node.op_type == "Sqrt":
                            sqrt_node = rec_inp_0_node
                            sqrt_inp_0 = sqrt_node.input[0]
                            sqrt_inp_0_node = tensor_to_producer_dict[sqrt_inp_0]
                            if sqrt_inp_0_node.op_type == "Max":
                                max_node = sqrt_inp_0_node
                                max_inp_0 = max_node.input[0]
                                max_inp_1 = max_node.input[1]
                                max_inp_0_node = tensor_to_producer_dict[max_inp_0]
                                if max_inp_0_node.op_type == "ReduceSum":
                                    red_node = max_inp_0_node
                                    red_inp_0 = red_node.input[0]
                                    red_inp_0_node = tensor_to_producer_dict[red_inp_0]
                                    if red_inp_0_node.op_type == "Mul":
                                        mul_node = red_inp_0_node
                                        mul_inp_0 = mul_node.input[0]
                                        mul_inp_0_node = tensor_to_producer_dict[mul_inp_0]
                                        if mul_inp_0_node.op_type == "Unsqueeze":
                                            uns_node = mul_inp_0_node
                                            # Remove nodes
                                            logger.info(f"Found L2norm ops from {node.name}.")
                                            nodes_to_remove_list = [
                                                node,
                                                rec_node,
                                                sqrt_node,
                                                max_node,
                                                red_node,
                                                mul_node,
                                            ]
                                            remove_nodes.extend(nodes_to_remove_list)
                                            eps_init = onnx_model.get_initializer(max_inp_1)
                                            remove_inits.append(eps_init)
                                            # Add LpNormalization
                                            inp = uns_node.output[0]
                                            out = node.output[0]
                                            l2norm_node = onnx.helper.make_node("LpNormalization", [inp], [out],
                                                                                node.name,
                                                                                p=2)
                                            onnx_model.add_node(l2norm_node)
                                            logger.info(f"Converted L2norm ops from {node.name} to LpNormalization.")
                except Exception as e:
                    logger.debug(f"FuseL2Norm is enabled, but {node.name} does not meet "
                                 f"the matching rules: {e}, skipping this node")
        onnx_model.remove_nodes(remove_nodes)
        onnx_model.remove_initializers(remove_inits)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
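
    # Numeric sketch (illustrative, not part of the original module) of the pattern
    # matched above: x * 1/sqrt(max(sum(x*x), eps)) is the L2 normalization that
    # LpNormalization(p=2) computes, with eps only guarding against division by zero.
    #
    #   x = np.random.rand(4, 16).astype(np.float32)
    #   eps = 1e-12
    #   decomposed = x * (1.0 / np.sqrt(np.maximum((x * x).sum(axis=-1, keepdims=True), eps)))
    #   lp_norm = x / np.linalg.norm(x, ord=2, axis=-1, keepdims=True)  # LpNormalization(p=2)
    #   assert np.allclose(decomposed, lp_norm, atol=1e-6)
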
    def fold_batch_norm(self) -> None:
        """
        fold BatchNormalization to target operations
        """

        def _get_folded_weight_bias(
                target_type: str, target_weight: NDArray[np.float32],
                target_bias: Union[NDArray[np.float32], NDArray[np.float64]], bn_gamma: Optional[NDArray[np.float32]],
                bn_beta: Optional[NDArray[np.float32]], bn_mean: NDArray[np.float32], bn_var: NDArray[np.float32],
                bn_epsilon: float) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
            if bn_gamma is not None:
                multiplier = bn_gamma / np.sqrt(bn_var + bn_epsilon)
            else:
                multiplier = 1 / np.sqrt(bn_var + bn_epsilon)

            if target_type == "Gemm":
                bn_weight = np.diag(multiplier)
            elif target_type == "ConvTranspose":
                bn_weight = multiplier.reshape(1, len(multiplier), 1, 1)

            if bn_beta is not None:
                bn_bias = bn_beta + (-bn_mean) * multiplier
            else:
                bn_bias = (-bn_mean) * multiplier

            if target_type == "Gemm":
                folded_weight = np.dot(bn_weight, target_weight)
                folded_bias = np.dot(bn_weight, target_bias) + bn_bias
            elif target_type == "ConvTranspose":
                folded_weight = bn_weight * target_weight
                folded_bias = bn_weight.reshape(1, -1) * target_bias + bn_bias
                folded_bias = folded_bias.reshape(-1)

            return folded_weight, folded_bias

        onnx_model = ONNXModel(self.model)
        TARGET_OPS = ('ConvTranspose', 'Gemm')
        remove_nodes = []
        for node in onnx_model.model.graph.node:
            if node.op_type != 'BatchNormalization' or self.should_quantize_node(node):
                continue
            if len(node.input) != 5:
                logger.warning(f"BatchNorm {node.name} with {len(node.input)} inputs cannot be folded.")
                continue

            target_node = onnx_model.get_parent(node, 0)
            if target_node is None:
                logger.warning(f"BatchNorm {node.name} is an isolated node and cannot be folded.")
                continue
            if target_node.op_type not in TARGET_OPS:
                logger.debug(f"BatchNorm {node.name} after node {target_node.name} cannot be folded.")
                continue

            bn_gamma_init = onnx_model.get_initializer(node.input[1])
            bn_gamma = None if bn_gamma_init is None else onnx.numpy_helper.to_array(bn_gamma_init)
            bn_beta_init = onnx_model.get_initializer(node.input[2])
            bn_beta = None if bn_beta_init is None else onnx.numpy_helper.to_array(bn_beta_init)
            bn_mean_init = onnx_model.get_initializer(node.input[3])
            bn_mean = None if bn_mean_init is None else onnx.numpy_helper.to_array(bn_mean_init)
            bn_var_init = onnx_model.get_initializer(node.input[4])
            bn_var = None if bn_var_init is None else onnx.numpy_helper.to_array(bn_var_init)
            bn_epsilon = next((attr.f for attr in node.attribute if attr.name == 'epsilon'), 1e-10)
            if bn_mean is None or bn_var is None:
                logger.warning(f"BatchNorm {node.name} is missing mean or variance and cannot be folded.")
                continue

            target_weight_init = onnx_model.get_initializer(target_node.input[1])
            target_weight = None if target_weight_init is None else onnx.numpy_helper.to_array(target_weight_init)
            target_bias_init = onnx_model.get_initializer(target_node.input[2]) if len(target_node.input) > 2 else None
            target_bias = None if target_bias_init is None else onnx.numpy_helper.to_array(target_bias_init)
            if target_weight is None:
                logger.warning(f"BatchNorm {node.name}'s target node {target_node.name} is not foldable.")
                continue

            target_type = target_node.op_type
            if target_type == "Gemm":
                transB = next((attr.i for attr in target_node.attribute if attr.name == 'transB'), 0)
                # TODO: Support transB is 0
                if transB == 0:
                    logger.debug(f"Target node {target_node.name}'s transB=0 is not supported.")
                    continue
            if target_type == "ConvTranspose":
                group = next((attr.i for attr in target_node.attribute if attr.name == 'group'), 1)
                # TODO: Support ConvTranspose group != 1
                if group != 1:
                    logger.debug(f"Target node {target_node.name}'s group !=1 is not supported.")
                    continue

            if target_bias is None:
                if target_type == "Gemm":
                    target_bias = np.zeros(target_weight.shape[0])
                else:  # target_type == "ConvTranspose":
                    target_bias = np.zeros(target_weight.shape[1])
                target_bias_name = target_node.name + "_bias_4bn"
                target_bias_init = onnx.numpy_helper.from_array(target_bias.astype(np.float32), name=target_bias_name)
                onnx_model.add_initializer(target_bias_init)
                target_node.input.append(target_bias_name)

            # Calculate the weight and bias after folded
            folded_weight, folded_bias = _get_folded_weight_bias(target_type, target_weight, target_bias, bn_gamma,
                                                                 bn_beta, bn_mean, bn_var, bn_epsilon)

            # Update target node's weight and bias
            folded_weight_init = onnx.numpy_helper.from_array(folded_weight.astype(np.float32),
                                                              name=target_weight_init.name)
            target_weight_init.CopyFrom(folded_weight_init)
            folded_bias_init = onnx.numpy_helper.from_array(folded_bias.astype(np.float32),
                                                            name=target_bias_init.name)
            target_bias_init = onnx_model.get_initializer(target_node.input[2])
            target_bias_init.CopyFrom(folded_bias_init)

            # Deal with the tensor name
            children = onnx_model.get_children(target_node)
            for child in children:
                if child is node:  # this node will be removed
                    continue
                for input_index, input_name in enumerate(child.input):
                    if input_name == target_node.output[0]:
                        child.input[input_index] = node.output[0]
            target_node.output[0] = node.output[0]
            # TODO: has shared initializers?
            remove_nodes.append(node)
            logger.info(f"Folded {node.op_type} {node.name} to {target_node.op_type} {target_node.name}.")

        onnx_model.remove_nodes(remove_nodes)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
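
    # Numeric sketch (illustrative, not part of the original module) of the Gemm
    # case above with transB=1: BatchNorm(x @ W.T + b) equals x @ W'.T + b' where
    # W' = diag(m) @ W, b' = m * b + (beta - mean * m), m = gamma / sqrt(var + eps).
    #
    #   rng = np.random.default_rng(0)
    #   x, W, b = rng.random((4, 8)), rng.random((6, 8)), rng.random(6)
    #   gamma, beta = rng.random(6), rng.random(6)
    #   mean, var, eps = rng.random(6), rng.random(6) + 0.1, 1e-5
    #   m = gamma / np.sqrt(var + eps)
    #   bn_out = m * ((x @ W.T + b) - mean) + beta
    #   folded = x @ (np.diag(m) @ W).T + (np.diag(m) @ b + (beta - mean * m))
    #   assert np.allclose(bn_out, folded)
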
    def convert_clip_to_relu(self) -> None:
        '''
        Convert Clip to Relu.
        '''
        nodes_to_remove = []
        init_to_remove = []
        onnx_model = ONNXModel(self.model)

        for node in onnx_model.model.graph.node:
            if node.op_type == 'Clip' and self.should_quantize_node(node):
                min_value, max_value, para_type = get_clip_min_max(onnx_model.model, node)

                if min_value is None or min_value < 0:
                    continue  # could not be replaced with Relu

                if para_type == 1:
                    # This Clip node's min and max come from initializers
                    for init in onnx_model.model.graph.initializer:
                        if len(node.input) > 1 and init.name == node.input[1]:
                            init_to_remove.append(init)
                        if len(node.input) > 2 and init.name == node.input[2]:
                            init_to_remove.append(init)
                elif para_type == 2:
                    # This Clip node's min and max come from other nodes
                    for nd in onnx_model.model.graph.node:
                        if ((len(node.input) > 1 and node.input[1] in nd.output)
                                or (len(node.input) > 2 and node.input[2] in nd.output)) is False:
                            continue
                        if nd.op_type == 'Identity':
                            for init in onnx_model.model.graph.initializer:
                                if len(nd.input) > 1 and init.name == nd.input[1]:
                                    init_to_remove.append(init)
                                if len(nd.input) > 2 and init.name == nd.input[2]:
                                    init_to_remove.append(init)
                            nodes_to_remove.append(nd)
                        elif nd.op_type == 'Constant':
                            nodes_to_remove.append(nd)

                logger.info(f"Convert Clip node {node.name} to Relu, "
                            f"its min is {min_value}, max is {max_value} and type is {para_type}")

                relu_node = onnx.helper.make_node("Relu", [node.input[0]], node.output, node.name)
                onnx_model.model.graph.node.extend([relu_node])  # insert a Relu node
                nodes_to_remove.append(node)  # to remove this Clip node

        onnx_model.remove_nodes(nodes_to_remove)
        onnx_model.remove_initializers(init_to_remove)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
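
    # Semantics sketch (illustrative, not part of the original module): with
    # min = 0 and no upper bound, Clip is exactly Relu, which is the equivalence
    # this pass relies on (note it accepts any min >= 0).
    #
    #   x = np.linspace(-3, 3, 7, dtype=np.float32)
    #   clip_out = np.clip(x, 0.0, None)   # Clip(min=0, no max)
    #   relu_out = np.maximum(x, 0.0)      # Relu
    #   assert np.allclose(clip_out, relu_out)
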
    def fold_batch_norm_after_concat(self) -> None:
        """
        fold BatchNormalization (after concat) to target operations
        """

        def _get_folded_weight_bias(
                target_type: str, target_weight: NDArray[np.float32],
                target_bias: Union[NDArray[np.float32], NDArray[np.float64]],
                bn_gamma: Union[NDArray[np.float32], None], bn_beta: Union[NDArray[np.float32], None],
                bn_mean: NDArray[np.float32], bn_var: NDArray[np.float32], bn_epsilon: float, start: int,
                end: int) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
            if bn_gamma is not None:
                multiplier = bn_gamma[start:end] / np.sqrt(bn_var[start:end] + bn_epsilon)
            else:
                multiplier = 1 / np.sqrt(bn_var[start:end] + bn_epsilon)

            if target_type == "Gemm":
                bn_weight = np.diag(multiplier)
            elif target_type == "ConvTranspose":
                bn_weight = multiplier.reshape(1, len(multiplier), 1, 1)
            elif target_type == "Conv":
                bn_weight = multiplier.reshape(len(multiplier), 1, 1, 1)

            if bn_beta is not None:
                bn_bias = bn_beta[start:end] + (-bn_mean[start:end]) * multiplier
            else:
                bn_bias = (-bn_mean[start:end]) * multiplier

            if target_type == "Gemm":
                folded_weight = np.dot(bn_weight, target_weight)
                folded_bias = np.dot(bn_weight, target_bias) + bn_bias
            elif target_type == "ConvTranspose":
                folded_weight = bn_weight * target_weight
                folded_bias = bn_weight.reshape(1, -1) * target_bias + bn_bias
                folded_bias = folded_bias.reshape(-1)
            elif target_type == "Conv":
                folded_weight = bn_weight * target_weight
                folded_bias = bn_weight.reshape(-1) * target_bias + bn_bias

            return folded_weight, folded_bias

        onnx_model = ONNXModel(self.model)
        TARGET_OPS = ('ConvTranspose', 'Gemm', 'Conv')
        remove_nodes = []
        for node in onnx_model.model.graph.node:
            if node.op_type != 'BatchNormalization' or self.should_quantize_node(node):
                continue
            if len(node.input) != 5:
                logger.warning(f"BatchNorm {node.name} with {len(node.input)} inputs cannot be folded.")
                continue

            # find potential target nodes
            parent_node = onnx_model.get_parent(node, 0)
            if parent_node is None:
                logger.warning(f"BatchNorm {node.name} is an isolated node and cannot be folded.")
                continue
            if parent_node.op_type == 'Concat':
                grandparent_nodes = onnx_model.get_parents(parent_node)
            else:
                continue

            # check if all target nodes satisfy the requirements to be folded
            is_foldable = True
            for target_node in grandparent_nodes:
                target_type = target_node.op_type
                if target_type not in TARGET_OPS:
                    logger.debug(f"Not all parent nodes of Concat are in ['ConvTranspose', 'Gemm', 'Conv'], "
                                 f"so BatchNorm {node.name} after Concat node cannot be folded.")
                    is_foldable = False
                    break
                if target_type == "Gemm":
                    transB = next((attr.i for attr in target_node.attribute if attr.name == 'transB'), 0)
                    # TODO: Support transB is 0
                    if transB == 0:
                        logger.debug(f"Target node {target_node.name}'s transB=0 is not supported.")
                        is_foldable = False
                        break
                if target_type == "ConvTranspose":
                    group = next((attr.i for attr in target_node.attribute if attr.name == 'group'), 1)
                    # TODO: Support ConvTranspose group != 1
                    if group != 1:
                        logger.debug(f"Target node {target_node.name}'s group !=1 is not supported.")
                        is_foldable = False
                        break
                target_weight_init = onnx_model.get_initializer(target_node.input[1])
                target_weight = None if target_weight_init is None else onnx.numpy_helper.to_array(target_weight_init)
                if target_weight is None:
                    logger.warning(f"BatchNorm {node.name}'s target node {target_node.name} is not foldable.")
                    is_foldable = False
                    break
            if is_foldable is False:
                continue

            bn_gamma_init = onnx_model.get_initializer(node.input[1])
            bn_gamma = None if bn_gamma_init is None else onnx.numpy_helper.to_array(bn_gamma_init)
            bn_beta_init = onnx_model.get_initializer(node.input[2])
            bn_beta = None if bn_beta_init is None else onnx.numpy_helper.to_array(bn_beta_init)
            bn_mean_init = onnx_model.get_initializer(node.input[3])
            bn_mean = None if bn_mean_init is None else onnx.numpy_helper.to_array(bn_mean_init)
            bn_var_init = onnx_model.get_initializer(node.input[4])
            bn_var = None if bn_var_init is None else onnx.numpy_helper.to_array(bn_var_init)
            bn_epsilon = next((attr.f for attr in node.attribute if attr.name == 'epsilon'), 1e-10)
            if bn_mean is None or bn_var is None:
                logger.warning(f"BatchNorm {node.name} is missing mean or variance and cannot be folded.")
                continue

            # fold batchnorm to target nodes
            start_idx, end_idx = 0, 0
            for i in range(len(grandparent_nodes)):
                target_node = grandparent_nodes[i]
                target_weight_init = onnx_model.get_initializer(target_node.input[1])
                target_weight = onnx.numpy_helper.to_array(target_weight_init)
                target_bias_init = onnx_model.get_initializer(target_node.input[2]) if len(
                    target_node.input) > 2 else None
                target_bias = None if target_bias_init is None else onnx.numpy_helper.to_array(target_bias_init)
                if target_bias is None:
                    if target_type == "Conv":
                        target_bias = np.zeros(target_weight.shape[0])
                    elif target_type == "Gemm":
                        target_bias = np.zeros(target_weight.shape[0])
                    else:  # if target_type == "ConvTranspose":
                        target_bias = np.zeros(target_weight.shape[1])
                    target_bias_name = target_node.name + "_bias_4bn"
                    target_bias_init = onnx.numpy_helper.from_array(target_bias.astype(np.float32),
                                                                    name=target_bias_name)
                    onnx_model.add_initializer(target_bias_init)
                    target_node.input.append(target_bias_name)

                end_idx += target_bias.shape[0]
                # Calculate the weight and bias after folded
                folded_weight, folded_bias = _get_folded_weight_bias(target_type, target_weight, target_bias,
                                                                     bn_gamma, bn_beta, bn_mean, bn_var, bn_epsilon,
                                                                     start_idx, end_idx)
                start_idx += target_bias.shape[0]

                # Update target node's weight and bias
                folded_weight_init = onnx.numpy_helper.from_array(folded_weight.astype(np.float32),
                                                                  name=target_weight_init.name)
                target_weight_init.CopyFrom(folded_weight_init)
                folded_bias_init = onnx.numpy_helper.from_array(folded_bias.astype(np.float32),
                                                                name=target_bias_init.name)
                target_bias_init = onnx_model.get_initializer(target_node.input[2])
                target_bias_init.CopyFrom(folded_bias_init)

            # Deal with the tensor name
            children = onnx_model.get_children(parent_node)
            for child in children:
                if child is node:  # this node will be removed
                    continue
                for input_index, input_name in enumerate(child.input):
                    if input_name == parent_node.output[0]:
                        child.input[input_index] = node.output[0]
            parent_node.output[0] = node.output[0]
            # TODO: has shared initializers?
            remove_nodes.append(node)
            logger.info(f"Folded {node.op_type} {node.name} to {target_node.op_type} {target_node.name}.")

        onnx_model.remove_nodes(remove_nodes)
        onnx_model.clean_initializers()
        onnx_model.topological_sort()
        self.model = onnx_model.model
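
    # Numeric sketch (illustrative, not part of the original module) of the channel
    # slicing above: BatchNorm applied to Concat([a, b], axis=1) equals concatenating
    # per-branch results that use the matching channel slice of the BN parameters,
    # which is why each branch can absorb its own gamma[start:end] / beta[start:end].
    #
    #   a, b = np.random.rand(1, 2, 4, 4), np.random.rand(1, 3, 4, 4)
    #   g, bt = np.random.rand(5), np.random.rand(5)
    #   mu, var, eps = np.random.rand(5), np.random.rand(5) + 0.1, 1e-5
    #   def bn(x, s, e):
    #       m = (g[s:e] / np.sqrt(var[s:e] + eps)).reshape(1, -1, 1, 1)
    #       return m * (x - mu[s:e].reshape(1, -1, 1, 1)) + bt[s:e].reshape(1, -1, 1, 1)
    #   whole = bn(np.concatenate([a, b], axis=1), 0, 5)
    #   split = np.concatenate([bn(a, 0, 2), bn(b, 2, 5)], axis=1)
    #   assert np.allclose(whole, split)
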
    def dedicate_dq_node(self) -> None:
        onnx_model = ONNXModel(self.model)
        output_name_to_node = onnx_model.output_name_to_node()
        input_name_to_nodes = onnx_model.input_name_to_nodes()

        nodes_to_add = []
        for node in onnx_model.model.graph.node:
            if node.op_type not in DEQUANT_OP_TYPES:
                continue

            # Deal with the implicit condition of multiple consumers
            consumers = []
            if onnx_model.is_graph_output(node.output[0]):
                consumers = [node]  # Just to occupy the position

            if node.output[0] not in input_name_to_nodes:
                continue
            children = input_name_to_nodes[node.output[0]]
            if len(children) + len(consumers) < 2:
                continue
            consumers = consumers + children

            if node.input[0] not in output_name_to_node:
                continue
            parent = output_name_to_node[node.input[0]]
            if parent.op_type not in QUANT_OP_TYPES:
                continue

            # If this is a shared weight, copy Q as well
            if onnx_model.get_initializer(parent.input[0]) is not None:
                copy_q = True
            else:
                copy_q = False

            for index, consumer in enumerate(consumers):
                if index == 0:
                    continue

                postfix = f"_{index}"
                if copy_q:
                    # Copy a new QuantizeLinear node
                    parent_new = copy.deepcopy(parent)
                    parent_new.name = parent_new.name + postfix
                    parent_new.output[0] = parent_new.output[0] + postfix
                    nodes_to_add.append(parent_new)

                    output_info = next(
                        (info for info in onnx_model.model.graph.value_info if info.name == parent.output[0]), None)
                    output_info_new = next(
                        (info for info in onnx_model.model.graph.value_info if info.name == parent_new.output[0]),
                        None)
                    if output_info is not None and output_info_new is None:
                        output_info_new = copy.deepcopy(output_info)
                        output_info_new.name = parent_new.output[0]
                        onnx_model.model.graph.value_info.extend([output_info_new])

                # Copy a new DequantizeLinear node
                node_new = copy.deepcopy(node)
                node_new.name = node.name + postfix
                node_new.output[0] = node.output[0] + postfix
                if copy_q:  # Should connect with the new Q
                    node_new.input[0] = parent_new.output[0]
                nodes_to_add.append(node_new)

                # Copy shape info
                output_info = next(
                    (info for info in onnx_model.model.graph.value_info if info.name == node.output[0]), None)
                output_info_new = next(
                    (info for info in onnx_model.model.graph.value_info if info.name == node_new.output[0]), None)
                if output_info is not None and output_info_new is None:
                    output_info_new = copy.deepcopy(output_info)
                    output_info_new.name = node_new.output[0]
                    onnx_model.model.graph.value_info.extend([output_info_new])

                onnx_model.replace_node_input(consumer, node.output[0], node_new.output[0])

        if len(nodes_to_add):
            logger.info(f"Dedicate {len(nodes_to_add)} DQs in post-processing.")
            onnx_model.add_nodes(nodes_to_add)
            onnx_model.topological_sort()

        self.model = onnx_model.model
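
    # Topology sketch (illustrative, not part of the original module) of what
    # dedicate_dq_node does: a DequantizeLinear feeding several consumers is
    # duplicated so each consumer gets its own DQ, and for a quantized initializer
    # (shared weight) the producing QuantizeLinear is duplicated too, which lets
    # per-consumer quantization parameters diverge later:
    #
    #   before:  Q -> DQ -> {ConvA, ConvB}
    #   after:   Q   -> DQ   -> ConvA
    #            Q_1 -> DQ_1 -> ConvB      (Q duplicated only when it quantizes an initializer)
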
def optimize(model: ModelProto,
             op_types_to_quantize: List[str],
             nodes_to_quantize: Optional[List[str]],
             nodes_to_exclude: Optional[List[str]],
             convert_bn_to_conv: bool = True,
             convert_reduce_mean_to_global_avg_pool: bool = True,
             split_large_kernel_pool: bool = True,
             convert_split_to_slice: bool = True,
             fuse_instance_norm: bool = True,
             fuse_l2_norm: bool = True,
             fuse_gelu: bool = True,
             fuse_layer_norm: bool = True,
             fold_batch_norm: bool = True,
             convert_clip_to_relu: bool = True,
             fold_batch_norm_after_concat: bool = True,
             dedicate_dq_node: bool = False) -> ModelProto:
    """
    Optimize an ONNX model to meet specific constraints and requirements for deployment on a CPU/NPU.

    This function applies various optimization techniques to the provided ONNX model based on the specified
    parameters. The optimizations include fusing operations, converting specific layers, and folding batch
    normalization layers, among others.

    :param onnx.ModelProto model: The ONNX model to be optimized.
    :param List[str] op_types_to_quantize: List of operation types to be quantized.
    :param Optional[List[str]] nodes_to_quantize: List of node names to explicitly quantize. If `None`, quantization is applied based on the operation types.
    :param Optional[List[str]] nodes_to_exclude: List of node names to exclude from quantization.
    :param bool convert_bn_to_conv: Flag indicating whether to convert BatchNorm layers to Conv layers.
    :param bool convert_reduce_mean_to_global_avg_pool: Flag indicating whether to convert ReduceMean layers to GlobalAveragePool layers.
    :param bool split_large_kernel_pool: Flag indicating whether to split large kernel pooling operations.
    :param bool convert_split_to_slice: Flag indicating whether to convert Split layers to Slice layers.
    :param bool fuse_instance_norm: Flag indicating whether to fuse InstanceNorm layers.
    :param bool fuse_l2_norm: Flag indicating whether to fuse L2Norm layers.
    :param bool fuse_gelu: Flag indicating whether to fuse Gelu layers.
    :param bool fuse_layer_norm: Flag indicating whether to fuse LayerNorm layers.
    :param bool fold_batch_norm: Flag indicating whether to fold BatchNorm layers into preceding Conv layers.
    :param bool convert_clip_to_relu: Flag indicating whether to convert Clip layers to ReLU layers.
    :param bool fold_batch_norm_after_concat: Flag indicating whether to fold BatchNorm layers after concatenation operations.
    :param bool dedicate_dq_node: Flag indicating whether to duplicate DequantizeLinear nodes with multiple consumers so that each consumer gets a dedicated one. Only used in quantization post-processing.

    :return: The optimized ONNX model.
    :rtype: ModelProto

    Notes:
        - The ``Optimize`` class is used to apply the optimizations based on the provided flags.
        - The function returns the optimized model with the applied transformations.
    """
    onnx_model = OnnxModel(model)
    opset_version = get_opset_version(onnx_model.model)

    optimizer = Optimize(
        model,
        op_types_to_quantize,
        nodes_to_quantize,
        nodes_to_exclude,
    )

    if fuse_instance_norm:
        optimizer.fuse_instance_norm()
    if convert_reduce_mean_to_global_avg_pool:
        optimizer.convert_reduce_mean_to_global_avg_pool()
    if split_large_kernel_pool:
        optimizer.split_large_kernel_pool()
    if convert_split_to_slice:
        optimizer.convert_split_to_slice()
    if fuse_l2_norm:
        optimizer.fuse_l2_norm()
    if fuse_layer_norm:
        if opset_version < 17:
            logger.warning(f"The opset version is {opset_version} < 17. Skipping fusing layer normalization.")
        else:
            fusion_layernorm = FusionLayerNormalization(onnx_model)
            fusion_layernorm.apply()
    if fuse_gelu:
        if opset_version < 20:
            logger.warning(f"The opset version is {opset_version} < 20. Skipping fusing Gelu.")
        else:
            fusion_gelu = FusionGelu(onnx_model)
            fusion_gelu.apply()
    if fold_batch_norm:
        optimizer.fold_batch_norm()
    if convert_clip_to_relu:
        optimizer.convert_clip_to_relu()
    if fold_batch_norm_after_concat:
        optimizer.fold_batch_norm_after_concat()
    if convert_bn_to_conv:
        optimizer.convert_bn_to_conv()

    # Only for quantization post-processing
    if dedicate_dq_node:
        optimizer.dedicate_dq_node()

    return optimizer.model