import torch
import torch.nn as nn
import torch_pruning as tp
from torch_pruning import BasePruningFunc
from typing import Sequence

##############################
# Pruners
##############################
class RMSNormPrunner(BasePruningFunc):
    """Prunes channels of a LLaMA RMSNorm layer by slicing its 1-D weight."""

    def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module:
        # Keep every channel that is not scheduled for removal.
        keep_idxs = list(set(range(layer.weight.size(0))) - set(idxs))
        keep_idxs.sort()
        layer.weight = torch.nn.Parameter(layer.weight[keep_idxs])
        return layer

    # RMSNorm has a single weight vector, so in- and out-channel pruning coincide.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        return layer.weight.size(0)

    def get_in_channels(self, layer):
        return layer.weight.size(0)
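
# A minimal sanity check for RMSNormPrunner, wrapped in a helper so nothing runs
# at import time. `_ToyRMSNorm` is a stub assumed here for illustration: it only
# mimics the 1-D `weight` parameter of LLaMA's RMSNorm, not its forward pass.
def _demo_rmsnorm_pruner():
    class _ToyRMSNorm(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.weight = nn.Parameter(torch.ones(dim))

    norm = _ToyRMSNorm(8)
    RMSNormPrunner().prune_out_channels(norm, idxs=[0, 3])
    assert norm.weight.size(0) == 6  # two of the eight channels were removed
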
class AttentionPrunner(BasePruningFunc):
    """Prunes the embedding dimension of a LLaMA attention block.

    The same channel indices are removed from the outputs and inputs of
    wq/wk/wv/wo, so the block keeps a square projection layout.
    """

    def prune_out_channels(self, layer: nn.Module, idxs: Sequence[int]) -> nn.Module:
        # Channels must be removed evenly across all heads so head_dim stays uniform.
        assert len(idxs) % layer.n_heads == 0
        for sub_layer in [layer.wq, layer.wk, layer.wv, layer.wo]:
            # Drop the selected output channels (rows of the weight matrix).
            keep_idxs = list(set(range(sub_layer.out_features)) - set(idxs))
            keep_idxs.sort()
            sub_layer.out_features = sub_layer.out_features - len(idxs)
            sub_layer.weight = torch.nn.Parameter(sub_layer.weight.data[keep_idxs])
            if sub_layer.bias is not None:
                sub_layer.bias = torch.nn.Parameter(sub_layer.bias.data[keep_idxs])

            # Drop the selected input channels (columns of the weight matrix).
            keep_idxs = list(set(range(sub_layer.in_features)) - set(idxs))
            keep_idxs.sort()
            sub_layer.in_features = sub_layer.in_features - len(idxs)
            sub_layer.weight = torch.nn.Parameter(
                sub_layer.weight.data.cpu().clone()[:, keep_idxs]
            )

        # Shrink the bookkeeping attributes and the KV cache to the new head_dim.
        layer.dim = layer.dim - len(idxs)
        layer.head_dim = layer.dim // layer.n_heads
        layer.cache_k.data = layer.cache_k.data.cpu().clone()[..., :layer.head_dim]
        layer.cache_v.data = layer.cache_v.data.cpu().clone()[..., :layer.head_dim]
        return layer

    # The block's input and output share the same embedding dimension.
    prune_in_channels = prune_out_channels

    def get_out_channels(self, layer):
        return layer.dim

    def get_in_channels(self, layer):
        return layer.dim
attention_pruner = AttentionPrunner()
rmsnorm_pruner = RMSNormPrunner()
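
# Hedged sketch of wiring these pruners into torch_pruning, assuming the v1.x
# DependencyGraph API; `attention_cls` and `rmsnorm_cls` are placeholders for
# the LLaMA Attention and RMSNorm module classes, which this file never imports.
# The meta pruners accept the same mapping via their `customized_pruners` kwarg.
def _register_llama_pruners(DG: tp.DependencyGraph, attention_cls, rmsnorm_cls):
    # Tell the dependency graph how to prune module types that torch_pruning
    # cannot handle out of the box.
    DG.register_customized_layer(attention_cls, attention_pruner)
    DG.register_customized_layer(rmsnorm_cls, rmsnorm_pruner)
    return DG
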
##############################
# Importance
##############################
class MagnitudeImportance(tp.importance.Importance):
    """Scores coupled channel groups by the p-norm of the weights they touch."""

    def __init__(self, p=2, group_reduction="mean", normalizer=None):
        self.p = p
        self.group_reduction = group_reduction
        self.normalizer = normalizer

    def _reduce(self, group_imp):
        # Collapse the per-layer scores of a group into a single score vector.
        if self.group_reduction == "sum":
            group_imp = group_imp.sum(dim=0)
        elif self.group_reduction == "mean":
            group_imp = group_imp.mean(dim=0)
        elif self.group_reduction == "max":
            group_imp = group_imp.max(dim=0)[0]
        elif self.group_reduction == "prod":
            group_imp = torch.prod(group_imp, dim=0)
        elif self.group_reduction == "first":
            group_imp = group_imp[0]
        elif self.group_reduction is None:
            group_imp = group_imp
        else:
            raise NotImplementedError
        return group_imp

    @torch.no_grad()
    def __call__(self, group, ch_groups=1):
        group_imp = []
        # Accumulate a p-norm score for every dependency in the group.
        for dep, idxs in group:
            idxs.sort()
            layer = dep.target.module
            prune_fn = dep.handler
            # Linear out_channels: score each output row.
            if prune_fn == tp.prune_linear_out_channels:
                w = layer.weight.data[idxs].flatten(1)
                local_norm = w.abs().pow(self.p).sum(1)
                group_imp.append(local_norm)
            # Linear in_channels: score each input column.
            elif prune_fn == tp.prune_linear_in_channels:
                w = layer.weight
                local_norm = w.abs().pow(self.p).sum(0)
                local_norm = local_norm[idxs]
                group_imp.append(local_norm)
            # RMSNorm: score each entry of the scale vector.
            elif prune_fn == rmsnorm_pruner.prune_out_channels:
                w = layer.weight.data[idxs]
                local_norm = w.abs().pow(self.p)
                group_imp.append(local_norm)
            # Embedding: score each embedding dimension.
            elif prune_fn == tp.prune_embedding_out_channels:
                w = layer.weight.data[:, idxs]
                local_norm = w.abs().pow(self.p)
                group_imp.append(local_norm)
            # Attention: sum row and column scores over wq/wk/wv/wo.
            elif prune_fn == attention_pruner.prune_out_channels:
                local_norm = 0
                for sub_layer in [layer.wq, layer.wk, layer.wv, layer.wo]:
                    w_out = sub_layer.weight.data[idxs]
                    local_norm += w_out.abs().pow(self.p).sum(1)
                    w_in = sub_layer.weight.data[:, idxs]
                    local_norm += w_in.abs().pow(self.p).sum(0)
                group_imp.append(local_norm)

        if len(group_imp) == 0:
            return None

        # Layers in a group may expose different channel counts; fold larger
        # score vectors down to the smallest size before stacking, and skip
        # any vector whose length is not a multiple of that size.
        min_imp_size = min([len(imp) for imp in group_imp])
        aligned_group_imp = []
        for imp in group_imp:
            if len(imp) > min_imp_size and len(imp) % min_imp_size == 0:
                imp = imp.view(len(imp) // min_imp_size, min_imp_size).sum(0)
                aligned_group_imp.append(imp)
            elif len(imp) == min_imp_size:
                aligned_group_imp.append(imp)
        group_imp = torch.stack(aligned_group_imp, dim=0)
        group_imp = self._reduce(group_imp)
        if self.normalizer is not None:
            group_imp = self.normalizer(group, group_imp)
        return group_imp
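
# Hedged usage sketch for MagnitudeImportance, again assuming the v1.x
# torch_pruning API; `model`, `example_inputs`, and `linear` (any nn.Linear
# traced inside `model`) are placeholders supplied by the caller.
def _score_pruning_group(model, example_inputs, linear):
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=example_inputs)
    # Gather every layer coupled to removing the first four output channels
    # of `linear`, then score those channels across the whole group.
    group = DG.get_pruning_group(linear, tp.prune_linear_out_channels, idxs=[0, 1, 2, 3])
    imp = MagnitudeImportance(p=2, group_reduction="mean")
    return imp(group)  # one score per channel; smaller means safer to prune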