-
Notifications
You must be signed in to change notification settings - Fork 0
/
ParamVis.py
54 lines (44 loc) · 1.61 KB
/
ParamVis.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""
对模型网络初始化和训练后各层参数的可视化
1. 绘制各层参数的直方图
2. 绘制各层参数的热力图
"""
import os
import sys
import json
sys.path.append("..")
import matplotlib.pyplot as plt
import src.Layers as L
from src.MLPModel import MLPModel
if not os.path.exists("images"):
os.makedirs("images")
ckpt_path = "./models/model_epoch_100.pkl"
nn_architecture = json.load(open(ckpt_path.replace(".pkl", ".json"), "r"))
model = MLPModel(nn_architecture)
print(model.layers)
for i, layer in enumerate(model.layers):
if isinstance(layer, L.Linear):
plt.figure()
plt.hist(layer.W.flatten(), bins=100)
plt.title(f"Layer {int(i/2) + 1} Weight Distribution")
plt.savefig(f"images/layer_{int(i/2) + 1}_weight_distribution_init.png")
plt.figure()
plt.imshow(layer.W, cmap="hot", interpolation="nearest")
plt.title(f"Layer {int(i/2) + 1} Weight Matrix")
plt.colorbar()
plt.savefig(f"images/layer_{i + 1}_weight_matrix_init.png")
print(i)
model.load_model_dict(path=ckpt_path)
for i, layer in enumerate(model.layers):
if isinstance(layer, L.Linear):
plt.figure()
plt.hist(layer.W.flatten(), bins=100)
plt.title(f"Layer {int(i/2) + 1} Weight Distribution")
plt.savefig(f"images/layer_{int(i/2) + 1}_weight_distribution.png")
print(i)
plt.figure()
plt.imshow(layer.W, cmap="hot", interpolation="nearest")
plt.title(f"Layer {int(i/2) + 1} Weight Matrix")
plt.colorbar()
plt.savefig(f"images/layer_{int(i/2) + 1}_weight_matrix.png")
print(i)