LLM Pruning#
Overview of LLM Pruning#
Deployment of LLMs is constrained by their intensive computational demands.
To make LLMs more accessible and efficient for practical use, LLM pruning reduces the number of parameters, effectively creating a small LLM from a larger one.
Pruning is performed either by dropping entire layers (depth pruning) or by dropping neurons, attention heads, and embedding channels (width pruning).
The higher the pruning ratio, the greater the accuracy degradation. As a result, pruning is often followed by fine-tuning or retraining to recover accuracy.
Mainstream Methods Overview#
Unstructured pruning: Operates on individual weights, setting individual elements of a weight matrix to 0, so a sparsity rate of s% eliminates s% of the entries in that matrix. However, the efficiency of the resulting sparse network cannot be realized on general-purpose GPU hardware.
Semi-structured pruning: Enforces exactly N non-zero values in each block of M consecutive weights. Whether this yields GPU acceleration depends on the specific hardware architecture.
Structured pruning: These methods can be categorized into two granularity levels.
Channel-wise: Based on weight importance, entire rows or columns of the weight matrices are deleted, providing a more hardware-friendly solution that reduces storage requirements and improves GPU inference performance.
Layer-wise: Entire decoder layers are deleted, which reduces the model parameters and simplifies the model structure more directly, and can therefore bring a more significant GPU inference speedup. (A short sketch contrasting these three sparsity patterns follows this list.)
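To make the difference concrete, here is a minimal PyTorch sketch (illustrative only, not part of Quark) that applies the three patterns to a toy weight matrix; the tensor shapes and sparsity levels are arbitrary assumptions.

import torch

torch.manual_seed(0)
W = torch.randn(8, 16)  # toy weight matrix [out_features, in_features]

# Unstructured: zero out the s% smallest-magnitude entries anywhere in the matrix.
s = 0.5
threshold = W.abs().flatten().kthvalue(int(s * W.numel())).values
W_unstructured = torch.where(W.abs() > threshold, W, torch.zeros_like(W))

# Semi-structured (2:4): keep only the 2 largest-magnitude weights in every block of 4.
blocks = W.reshape(-1, 4)                   # consecutive groups of 4 weights
keep = blocks.abs().topk(2, dim=1).indices  # the 2 survivors per block
mask = torch.zeros_like(blocks).scatter_(1, keep, 1.0)
W_semi = (blocks * mask).reshape_as(W)

# Structured (channel-wise): drop whole output channels (rows) by L2 norm, which
# physically shrinks the matrix instead of just masking entries.
row_importance = W.norm(dim=1)
kept_rows = row_importance.topk(4).indices.sort().values
W_structured = W[kept_rows]                 # shape [4, 16] -> real compute savings

print(W_unstructured.count_nonzero(), W_semi.count_nonzero(), W_structured.shape)

The first two variants only mask values, so the dense matrix shape and FLOP count stay the same unless the hardware exploits the sparsity pattern; the structured variant changes the actual tensor shape, which is why it translates directly into smaller storage and faster GPU inference.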
As this brief overview suggests, although researchers have proposed several pruning methods, unstructured and semi-structured pruning, despite competitive accuracy, may not reduce the actual amount of computation in real production environments and, in particular, may not bring GPU inference acceleration.
To obtain a real reduction in model size and a real speedup of the forward computation, the Quark Pruning tool therefore focuses on structured pruning methods, which directly reduce the model size and accelerate GPU inference.
Quark Pruning Tool Feature#
Supported Methods Overview#
OSSCAR Pruner: A structured pruning method that uses a small amount of calibration data to perform layer-wise reconstruction and minimize the layer reconstruction error.
It reduces the intermediate dimension in fully connected layers (the output channels of the preceding FC layer and the input channels of the following FC layer).
It does not change the overall LLM model structure.
After pruning, if the number of remaining channels is a power of 2 (e.g., 4096 = 2**12), GPU inference time is usually reduced.
Depth-Wise Pruner: Based on perplexity (PPL) influence, consecutive decoder layers whose deletion has little effect on the final PPL are regarded as less important to the LLM and can be deleted (a sketch of this selection idea follows the list).
It deletes consecutive decoder layers according to a pre-defined pruning ratio.
It directly reduces the model size, because entire decoder layers are removed, and therefore improves GPU inference performance.
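The following is a minimal sketch of the selection idea behind the depth-wise pruner, not Quark's actual implementation: every window of num_prune consecutive decoder layers is scored by the PPL measured with that window temporarily removed, and the least damaging window is then deleted permanently. It assumes a Llama-style model (decoder layers in model.model.layers) and a hypothetical compute_ppl(model, calib_data) helper.

import torch.nn as nn

def prune_depth(model, calib_data, num_prune, compute_ppl):
    all_layers = list(model.model.layers)            # decoder layers of a Llama-style model
    best_start, best_ppl = 0, float("inf")
    for start in range(len(all_layers) - num_prune + 1):
        kept = all_layers[:start] + all_layers[start + num_prune:]
        model.model.layers = nn.ModuleList(kept)     # temporarily drop this window
        ppl = compute_ppl(model, calib_data)         # low PPL => the window matters little
        if ppl < best_ppl:
            best_start, best_ppl = start, ppl
    # Permanently delete the least important consecutive window.
    kept = all_layers[:best_start] + all_layers[best_start + num_prune:]
    model.model.layers = nn.ModuleList(kept)
    model.config.num_hidden_layers = len(kept)       # keep the config consistent
    # A full implementation would also renumber per-layer attributes such as layer_idx.
    return model

In Quark, the depth-wise pruner is configured through LayerImportancePruneConfig, as shown in the API usage section below.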
For example, the code block below shows a typical Llama model structure. The OSSCAR pruner decreases the output channels of gate_proj and up_proj and, correspondingly, the input channels of down_proj (a toy sketch of this channel shrinking follows the structure). The depth pruner directly deletes entire LlamaDecoderLayer modules; for example, after deleting 10 consecutive layers, the model retains 70 (80 - 10) layers, which directly improves GPU inference performance.
LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(128256, 8192)
(layers): ModuleList(
(0-79): 80 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8192, out_features=8192, bias=False)
(k_proj): Linear(in_features=8192, out_features=1024, bias=False)
(v_proj): Linear(in_features=8192, out_features=1024, bias=False)
(o_proj): Linear(in_features=8192, out_features=8192, bias=False)
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
(up_proj): Linear(in_features=8192, out_features=28672, bias=False)
(down_proj): Linear(in_features=28672, out_features=8192, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
(post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
)
)
(norm): LlamaRMSNorm((8192,), eps=1e-05)
(rotary_emb): LlamaRotaryEmbedding()
)
(lm_head): Linear(in_features=8192, out_features=128256, bias=False)
)
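As an illustration of what OSSCAR-style width pruning does to this structure, the toy sketch below shrinks the MLP intermediate dimension of one LlamaDecoderLayer from 28672 to 24576 channels by simply keeping the first keep channels; the real method instead selects the surviving channels via layer-wise reconstruction on calibration data, and the function name shrink_mlp is purely illustrative.

import torch.nn as nn

def shrink_mlp(mlp, keep=24576):
    d_model = mlp.gate_proj.in_features              # 8192 in the structure above
    new_gate = nn.Linear(d_model, keep, bias=False)
    new_up = nn.Linear(d_model, keep, bias=False)
    new_down = nn.Linear(keep, d_model, bias=False)
    # Keep the surviving intermediate channels: rows of gate_proj/up_proj
    # (output channels) and the matching columns of down_proj (input channels).
    new_gate.weight.data = mlp.gate_proj.weight.data[:keep, :].clone()
    new_up.weight.data = mlp.up_proj.weight.data[:keep, :].clone()
    new_down.weight.data = mlp.down_proj.weight.data[:, :keep].clone()
    mlp.gate_proj, mlp.up_proj, mlp.down_proj = new_gate, new_up, new_down
    return mlp

# Applied to every decoder layer, e.g.:
#   for layer in model.model.layers:
#       shrink_mlp(layer.mlp)
# The model's config.intermediate_size should then be updated to match.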
API Usage#
Prepare LLM & Tokenizer#
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(ckpt_path, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(ckpt_path, padding_side="left")
Init Prune Config & Calibration/Test Dataset#
import json

from datasets import load_dataset
from quark.torch.pruning.config import Config
from quark.torch.pruning.config import OSSCARConfig, LayerImportancePruneConfig

# Init Prune config
# Using OSSCAR
algo_config = OSSCARConfig.from_dict(json.load(algo_config_file))
# or using the depth-wise pruner
algo_config = LayerImportancePruneConfig.from_dict(json.load(algo_config_file))
pruning_config = Config(algo_config=algo_config)

# Init the Calibration/Test dataset (fake code)
def get_wikitext_dataset(data_dir, tokenizer):
    testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
    testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt")
    return testenc

validation_data = get_wikitext_dataset(data_dir, tokenizer)
Init Pruner & Perform Pruning#
from quark.torch import ModelPruner
model_pruner = ModelPruner(pruning_config)
model = model_pruner.pruning_model(model, validation_data)
Evaluate the Metric & Save Pruned Model#
# evaluate
eval_model(args, model, main_device) # evaluate the PPL of the pruned model
# save to target path
model.save_pretrained(save_dir, safe_serialization=True)
tokenizer.save_pretrained(save_dir)
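eval_model is not defined in the snippet above. As a self-contained alternative, the sketch below computes an approximate wikitext-2 perplexity in non-overlapping windows, assuming validation_data is the tokenizer output returned by get_wikitext_dataset above; the function name and the seq_len value are illustrative assumptions.

import torch

@torch.no_grad()
def compute_wikitext_ppl(model, encodings, seq_len=2048):
    device = next(model.parameters()).device
    input_ids = encodings["input_ids"]               # shape [1, num_tokens]
    nlls = []
    for start in range(0, input_ids.size(1) - seq_len, seq_len):
        chunk = input_ids[:, start:start + seq_len].to(device)
        # Passing labels=chunk lets the model compute the shifted cross-entropy loss.
        loss = model(chunk, labels=chunk).loss
        nlls.append(loss.float() * seq_len)
    return torch.exp(torch.stack(nlls).sum() / (len(nlls) * seq_len)).item()

print("PPL after pruning:", compute_wikitext_ppl(model, validation_data))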
Prune Results#
OSSCAR#

| Model Name | Model Size | Pruning Rate | Pruned Model Size | PPL on Wikitext-2 (Before Pruning) | PPL on Wikitext-2 (After Pruning) |
|---|---|---|---|---|---|
| mistralai/Mixtral-8x7B-Instruct-v0.1 | 46.7B | 9.4838% | 42.2B | 4.1370 | 5.1195 |
| CohereForAI/c4ai-command-r-08-2024 | 32.3B | 7.4025% | 29.9B | 4.5081 | 6.3794 |
| Qwen/Qwen2.5-14B-Instruct | 14.8B | 7.0284% | 13.7B | 5.6986 | 7.5994 |
| meta-llama/Meta-Llama-3-8B | 8.0B | 6.8945% | 7.5B | 6.1382 | 8.0755 |
| meta-llama/Llama-2-7b-hf | 6.7B | 6.7224% | 6.2B | 5.4721 | 6.2462 |
| facebook/opt-6.7b | 6.7B | 7.5651% | 6.2B | 10.8602 | 11.8958 |
| THUDM/chatglm3-6b | 6.2B | 7.7590% | 5.6B | 29.9560 | 36.0010 |
| microsoft/Phi-3.5-mini-instruct | 3.8B | 5.9274% | 3.6B | 6.1959 | 7.8074 |
Depth-Wise Pruner#

| Model | PPL (Original) | PPL (After Pruning) | Pruning Ratio | Deleted Layers |
|---|---|---|---|---|
| llama-2-7b-chat-hf | 6.9419 | 8.139 | 9.01% | 3 layers [11-13] of 32 |
| llama-3.1-405B | 1.85624 | 2.8429 | 10.3% | 13 layers [18-30] of 126 |
| Llama-2-70b-chat-hf | 4.6460 | 5.5305 | 11.16% | 9 layers [27-35] of 80 |
| Qwen2.5-14B-Instruct | 5.7010 | 7.0681 | 9.32% | 5 layers [22-26] of 48 |
| Mixtral-8x7B-Instruct-v0.1 | 4.1378 | 5.0338 | 10% | 3 layers [12-14] of 32 |
| facebook/opt-6.7b | 10.8605 | 14.4321 | 12.1% | 4 layers [21-24] of 32 |
| deepseek-moe-16b-chat | 7.3593 | 8.94327 | 10.76% | 3 layers [11-13] of 28 |
Example Code#
All the example code can be found under /examples/torch/language_modeling/llm_pruning
Other#
Reference: OSSCAR: One-Shot Structured Pruning in Vision and Language Models with Combinatorial Optimization