Pruning

Quark Pruning API for PyTorch.

class quark.torch.pruning.api.ModelPruner(config: Config)

Provides an API for pruning deep learning models using PyTorch.

This class configures and runs pruning on a model according to user-defined parameters. The ‘config’ passed in must define every required pruning parameter, and the model must be compatible with the pruning settings that ‘config’ specifies.

Parameters:

config (Config) – Model pruning configuration.
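As a minimal sketch of how the documented pieces fit together, the hypothetical helper `prune_model` constructs a `ModelPruner` from a `Config` and calls its `pruning_model` method. Only `quark.torch.pruning.api.ModelPruner` and `pruning_model(model, dataloader)` come from this page; building the `Config` object is not covered here, so the caller is assumed to pass in a fully populated one.

```python
from typing import Optional

import torch
from torch.utils.data import DataLoader


def prune_model(
    model: torch.nn.Module,
    config,  # assumed: a fully populated Quark pruning Config
    dataloader: Optional[DataLoader] = None,
) -> torch.nn.Module:
    """Hypothetical convenience wrapper around the documented ModelPruner API."""
    # Imported lazily so this helper can be defined without Quark installed.
    from quark.torch.pruning.api import ModelPruner

    pruner = ModelPruner(config)
    # pruning_model returns the pruned model; the dataloader supplies calibration data.
    return pruner.pruning_model(model, dataloader=dataloader)
```

Keeping the import inside the function is a design choice for the sketch, not a Quark requirement; a module that always depends on Quark can import `ModelPruner` at the top instead.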

pruning_model(model: Module, dataloader: DataLoader[Tensor] | DataLoader[List[Dict[str, Tensor]]] | DataLoader[Dict[str, Tensor]] | None = None) → Module

Prunes the given PyTorch model to reduce its size and, potentially, improve its inference performance.

The dataloader supplies the calibration data used during pruning. Each batch may be a tensor, a dictionary of tensors, or a list of dictionaries of tensors, and the function adapts its handling to the batch type it receives.

The model and dataloader must be compatible: the batches the dataloader yields must match the inputs the model expects. A mismatch between the two can cause errors during pruning.
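Since such a mismatch only surfaces once pruning starts, it can help to push a single batch through the model beforehand. The helper `check_model_dataloader_compat` is hypothetical and not part of Quark. This page does not document how Quark feeds each batch type to the model, so the dispatch used here (tensor passed positionally, dict expanded as keyword arguments, first dict of a list expanded as keyword arguments) is an assumption.

```python
from typing import Any

import torch
from torch import nn
from torch.utils.data import DataLoader


def check_model_dataloader_compat(model: nn.Module, dataloader: DataLoader) -> Any:
    """Run one batch through the model to surface shape or key mismatches early.

    Note: this switches the model to eval mode as a side effect.
    """
    batch = next(iter(dataloader))
    model.eval()
    with torch.no_grad():
        if isinstance(batch, torch.Tensor):
            return model(batch)
        if isinstance(batch, dict):
            # Assumption: the dict keys match the keyword arguments of model.forward.
            return model(**batch)
        if isinstance(batch, list):
            # Assumption: a list batch holds per-sample dicts; only the first is tried.
            return model(**batch[0])
    raise TypeError(f"Unsupported batch type: {type(batch).__name__}")
```

If this call raises (for example a shape error or an unexpected keyword argument), the same misalignment would most likely break the pruning run as well.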

Parameters:
  • model (torch.nn.Module) – The PyTorch model to be pruned. This model should already be trained and ready for pruning.

  • dataloader (Optional[Union[DataLoader[torch.Tensor], DataLoader[List[Dict[str, torch.Tensor]]], DataLoader[Dict[str, torch.Tensor]]]]) – The torch.utils.data.DataLoader providing data that the pruning process will use for calibration. This can be a simple DataLoader returning tensors, or a more complex structure returning either a list of dictionaries or a dictionary of tensors. Defaults to None.
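The sketch below shows one way to build a `DataLoader` for each of the three batch types in the annotation above, using standard PyTorch collation. The calibration data and the `input_ids`/`attention_mask` keys are hypothetical placeholders; the real keys should be whatever inputs the model's `forward` accepts.

```python
import torch
from torch.utils.data import DataLoader

# Hypothetical calibration data: 8 samples of 16 token ids from a 100-token vocabulary.
samples = [torch.randint(0, 100, (16,)) for _ in range(8)]

# 1) DataLoader[Tensor]: default collation stacks samples into one tensor per batch.
tensor_loader = DataLoader(samples, batch_size=4)

# 2) DataLoader[Dict[str, Tensor]]: default collation merges a batch of dicts
#    into a single dict whose values are stacked tensors.
dict_samples = [{"input_ids": s, "attention_mask": torch.ones_like(s)} for s in samples]
dict_loader = DataLoader(dict_samples, batch_size=4)

# 3) DataLoader[List[Dict[str, Tensor]]]: an identity collate_fn keeps each
#    batch as a plain list of per-sample dicts.
list_loader = DataLoader(dict_samples, batch_size=4, collate_fn=lambda batch: batch)
```

Any of these loaders could then be passed as the dataloader argument; which shape to choose depends on how the model's forward method takes its inputs.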

Returns:

The pruned version of the input model, with reduced size and potentially improved inference performance on target devices.

Return type:

torch.nn.Module
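This page does not say whether pruning shrinks weight tensors or zeroes entries in place, so a check that covers both cases is to compare the total parameter count and the fraction of exactly-zero weights before and after pruning. A minimal sketch with two hypothetical helpers:

```python
import torch
from torch import nn


def count_parameters(model: nn.Module) -> int:
    """Total number of parameter elements in the model."""
    return sum(p.numel() for p in model.parameters())


def zero_fraction(model: nn.Module) -> float:
    """Fraction of parameter elements that are exactly zero."""
    total = count_parameters(model)
    zeros = sum(int((p == 0).sum()) for p in model.parameters())
    return zeros / total if total else 0.0
```

For example, record `count_parameters(model)` and `zero_fraction(model)` before calling `pruning_model`, then compute both again on the returned model. Taking the "before" numbers first matters because this page does not state whether the input model is modified in place.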