forked from zoogzog/chexnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
HeatmapGenerator.py
104 lines (73 loc) · 3.52 KB
/
HeatmapGenerator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import numpy as np
import time
import sys
from PIL import Image
import cv2
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from DensenetModels import DenseNet121
from DensenetModels import DenseNet169
from DensenetModels import DenseNet201
#--------------------------------------------------------------------------------
#---- Class to generate heatmaps (CAM)

class HeatmapGenerator ():
    """Render class-activation-style heatmaps over chest X-ray images.

    The trained DenseNet checkpoint is loaded once in __init__. Each call to
    generate() does four things:

    - runs one image through the network's convolutional feature extractor;
    - collapses the channel axis with a per-channel weight vector;
    - scales the resulting map to [0, 1];
    - writes a JET-coloured overlay of that map onto the resized input image.
    """

    # Maps each nnArchitecture name to the attribute under which the
    # DensenetModels wrapper holds its torchvision backbone.
    # NOTE(review): only 'densenet121' was exercised by the original code.
    # The 169/201 names assume DenseNet169/DenseNet201 follow the same
    # self.densenetNNN pattern. Confirm against DensenetModels.py.
    _BACKBONE_ATTR = {
        'DENSE-NET-121': 'densenet121',
        'DENSE-NET-169': 'densenet169',
        'DENSE-NET-201': 'densenet201',
    }

    #--------------------------------------------------------------------------------
    def __init__ (self, pathModel, nnArchitecture, nnClassCount, transCrop):
        """Load a trained DenseNet checkpoint and prepare the image transform.

        Args:
            pathModel: path to the trained model. torch.load() must return a
                dict with a 'state_dict' entry.
            nnArchitecture: 'DENSE-NET-121', 'DENSE-NET-169' or 'DENSE-NET-201'.
            nnClassCount: class count the checkpoint was trained with
                (14 for chxray-14).
            transCrop: size passed to transforms.Resize. It is also the default
                output side length for generate().

        Raises:
            ValueError: if nnArchitecture is not one of the names above.
        """
        # Validate first. The old if/elif chain left `model` unbound for an
        # unknown name and failed later with UnboundLocalError.
        if nnArchitecture not in self._BACKBONE_ATTR:
            raise ValueError('Unknown nnArchitecture %r; expected one of %s'
                             % (nnArchitecture, sorted(self._BACKBONE_ATTR)))

        #---- Initialize the network
        modelClass = {
            'DENSE-NET-121': DenseNet121,
            'DENSE-NET-169': DenseNet169,
            'DENSE-NET-201': DenseNet201,
        }[nnArchitecture]
        model = modelClass(nnClassCount, True).cuda()
        # Wrap in DataParallel before loading so the checkpoint's state_dict
        # keys match.
        # NOTE(review): presumably the checkpoint was saved from a
        # DataParallel model; confirm.
        model = torch.nn.DataParallel(model).cuda()
        modelCheckpoint = torch.load(pathModel)
        model.load_state_dict(modelCheckpoint['state_dict'])

        # Keep only the convolutional feature extractor of the selected
        # backbone. The old code always read `.densenet121`, whatever
        # architecture was requested.
        backbone = getattr(model.module, self._BACKBONE_ATTR[nnArchitecture])
        self.model = backbone.features
        self.model.eval()

        #---- Initialize the weights
        # Second-to-last parameter of `features`; generate() uses one entry
        # per feature-map channel.
        # NOTE(review): presumably this is the final BatchNorm's weight vector
        # rather than the classifier weights of classic CAM. Confirm it is the
        # intended choice.
        self.weights = list(self.model.parameters())[-2]

        self.transCrop = transCrop

        #---- Initialize the image transform - resize + normalize
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transformSequence = transforms.Compose([
            transforms.Resize(transCrop),
            transforms.ToTensor(),
            normalize,
        ])

    #--------------------------------------------------------------------------------
    def generate (self, pathImageFile, pathOutputFile, transCrop=None):
        """Write a heatmap overlay for a single image.

        Args:
            pathImageFile: path of the input image.
            pathOutputFile: path the blended image is written to with
                cv2.imwrite (the extension picks the format).
            transCrop: side length in pixels of the square output. Defaults to
                the value given to __init__.

        Raises:
            IOError: if OpenCV cannot read pathImageFile or write pathOutputFile.
        """
        if transCrop is None:
            transCrop = self.transCrop

        #---- Load image, transform, convert
        imageData = Image.open(pathImageFile).convert('RGB')
        imageData = self.transformSequence(imageData).unsqueeze(0)

        #---- Generate heatmap
        # Inference only: skip building an autograd graph. self.model already
        # lives on the GPU (moved there in __init__).
        with torch.no_grad():
            featureMaps = self.model(imageData.cuda())[0]
            # Channel-weighted sum: sum_i weights[i] * featureMaps[i].
            heatmap = sum(weight * fmap for weight, fmap in zip(self.weights, featureMaps))
        npHeatmap = heatmap.cpu().numpy()

        #---- Blend original and heatmap
        imgOriginal = cv2.imread(pathImageFile, 1)
        if imgOriginal is None:
            raise IOError('Could not read image with OpenCV: %s' % pathImageFile)
        imgOriginal = cv2.resize(imgOriginal, (transCrop, transCrop))

        cam = cv2.resize(self._normalizeCam(npHeatmap), (transCrop, transCrop))
        colorMap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        # The sum can exceed 255. Saturate explicitly to uint8 instead of
        # relying on imwrite's implicit float -> 8-bit conversion.
        blended = np.clip(np.rint(colorMap * 0.5 + imgOriginal), 0, 255).astype(np.uint8)
        if not cv2.imwrite(pathOutputFile, blended):
            raise IOError('Could not write heatmap to: %s' % pathOutputFile)

    #--------------------------------------------------------------------------------
    @staticmethod
    def _normalizeCam (npHeatmap):
        """Scale a raw activation map into [0, 1] relative to its peak.

        The map is divided by its maximum, as before. Negative values are then
        clipped to 0 so that np.uint8(255 * cam) cannot wrap around. A map
        whose maximum is <= 0 becomes all zeros; dividing by it would either
        divide by zero or flip the sign of every value.
        """
        peak = np.max(npHeatmap)
        if peak <= 0:
            return np.zeros_like(npHeatmap)
        return np.clip(npHeatmap / peak, 0, 1)
#--------------------------------------------------------------------------------
#---- Example run on one test image (executed only when run as a script)
pathInputImage = 'test/00009285_000.png'
pathOutputImage = 'test/heatmap.png'
pathModel = 'models/m-25012018-123527.pth.tar'
nnArchitecture = 'DENSE-NET-121'
nnClassCount = 14
transCrop = 224

# Guarded so that importing this module (e.g. to reuse HeatmapGenerator) does
# not load the checkpoint onto the GPU or overwrite test/heatmap.png.
if __name__ == '__main__':
    h = HeatmapGenerator(pathModel, nnArchitecture, nnClassCount, transCrop)
    h.generate(pathInputImage, pathOutputImage, transCrop)