forked from VainF/Torch-Pruning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_serialization.py
58 lines (49 loc) · 1.78 KB
/
test_serialization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import sys
import tempfile

# Make the in-repo torch_pruning package importable when running from tests/.
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

import torch
from torchvision.models import vit_b_16 as entry
import torch_pruning as tp
from torchvision.models.vision_transformer import VisionTransformer
def test_serialization():
    """Round-trip a pruned ViT-B/16 through ``tp.state_dict`` / ``tp.load_state_dict``.

    The test prunes a torchvision ``vit_b_16`` with L1 magnitude importance,
    serializes it with ``tp.state_dict`` into a temporary file, reloads it into
    a freshly constructed (unpruned) model with ``tp.load_state_dict`` and
    checks that:

    * a user-defined attribute (``customized_value``) survives the round trip;
    * the patched ``hidden_dim`` attribute is restored to the pruned width;
    * the restored model still runs a forward and a backward pass.
    """
    model = entry().eval()
    customized_value = 8
    # Attach an arbitrary user attribute; it is asserted on the reloaded model
    # below to check that plain module attributes are restored, not only
    # parameters/buffers.
    model.customized_value = customized_value

    importance = tp.importance.MagnitudeImportance(p=1)
    round_to = None
    if isinstance(model, VisionTransformer):
        # Round pruned channel counts to a multiple of the head count so the
        # embedding width still splits evenly across attention heads.
        round_to = model.encoder.layers[0].num_heads
    pruner = tp.pruner.MagnitudePruner(
        model,
        example_inputs=torch.randn(1, 3, 224, 224),
        importance=importance,
        iterative_steps=1,
        pruning_ratio=0.5,
        round_to=round_to,
    )
    pruner.step()

    if isinstance(model, VisionTransformer):
        # Torchvision relies on the hidden_dim variable for forwarding, so we
        # have to modify this variable after pruning.
        model.hidden_dim = model.conv_proj.out_channels
    true_hidden_dim = model.hidden_dim
    print(model.class_token.shape, model.encoder.pos_embedding.shape)

    # Write the checkpoint into a throwaway directory so the test does not
    # leave a 'test.pth' artifact behind in the current working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        ckpt_path = os.path.join(tmp_dir, 'test.pth')
        state_dict = tp.state_dict(model)
        torch.save(state_dict, ckpt_path)

        # Create a new, unpruned model to load the pruned checkpoint into.
        model = entry().eval()
        print(model)
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if
        # the pruned state_dict holds objects outside that allow-list, pass
        # weights_only=False here (the file was just written by this test).
        loaded_state_dict = torch.load(ckpt_path, map_location='cpu')

    # Load the pruned state_dict (this call was previously split across two
    # lines as `tp.load_stat` / `e_dict(...)` and never ran).
    tp.load_state_dict(model, state_dict=loaded_state_dict)
    print(model)

    # The user attribute and the patched hidden_dim must be restored.
    assert model.customized_value == customized_value
    assert model.hidden_dim == true_hidden_dim
    print(model.customized_value)  # check the user attributes
    print(model.hidden_dim)

    # The restored model must still be usable for forward and backward passes.
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)
    loss = out.sum()
    loss.backward()
if __name__ == '__main__':
    # Allow running this check directly as a script, outside of pytest.
    test_serialization()