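-- example_loader.lua
-- Loads single examples for training and evaluation: decodes the JPEG, normalizes the
-- channels, rescales the image to one or more target scales, optionally mirrors it
-- horizontally, and fills a per-example batch table of {images, rois, labels, roi_weights}.
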
require 'torch'

local ExampleLoader = torch.class('ExampleLoader')

function ExampleLoader:__init(dataset, normalization_params, scales, example_loader_opts)
  self.scales = scales
  self.normalization_params = normalization_params
  self.example_loader_opts = example_loader_opts
  self.dataset = dataset
end
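
-- Builds an I-by-J table of tables where entry [i][j] = elem_generator(i, j).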
local function table2d(I, J, elem_generator)
  local res = {}
  for i = 1, I do
    res[i] = {}
    for j = 1, J do
      res[i][j] = elem_generator(i, j)
    end
  end
  return res
end
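
-- Copies src into dst in the requested channel order ('rgb' or 'bgr'), subtracting the
-- per-channel mean and, if given, dividing by the per-channel std from normalization_params.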
local function subtract_mean(dst, src, normalization_params)
  local channel_order = assert(({rgb = {1, 2, 3}, bgr = {3, 2, 1}})[normalization_params.channel_order])
  for c = 1, 3 do
    dst[c]:copy(src[channel_order[c]]):add(-normalization_params.rgb_mean[channel_order[c]])
    if normalization_params.rgb_std then
      dst[c]:div(normalization_params.rgb_std[channel_order[c]])
    end
  end
end
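
-- Rescales img to fit within max_height x max_width while preserving its aspect ratio.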
local function rescale(img, max_height, max_width)
  local scale_factor = max_height / img:size(2)
  if torch.round(img:size(3) * scale_factor) > max_width then
    scale_factor = math.min(scale_factor, max_width / img:size(3))
  end
  return image.scale(img, math.min(max_width, img:size(3) * scale_factor), math.min(max_height, img:size(2) * scale_factor))
end
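
-- Horizontally flips images_j in place and mirrors the x coordinates (columns 1 and 3)
-- of rois_j, swapping x1 and x2 so the boxes remain well-formed.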
local function flip(images_j, rois_j)
  image.hflip(images_j, images_j)
  rois_j:select(2, 1):mul(-1):add(images_j:size(3))
  rois_j:select(2, 3):mul(-1):add(images_j:size(3))
  local tmp = rois_j:select(2, 1):clone()
  rois_j:select(2, 1):copy(rois_j:select(2, 3))
  rois_j:select(2, 3):copy(tmp)
end
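
-- Prepends a singleton (batch) dimension to every tensor argument, in place.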
local function insert_dummy_dim1(...)
  for _, tensor in ipairs({...}) do
    tensor:resize(1, unpack(tensor:size():totable()))
  end
end
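
-- Allocates a batchSize x num_jittered_copies table, each cell holding empty FloatTensors for
-- {images, rois, labels, roi_weights}: two copies per example in training (original plus one
-- random scale), and 1 + numScales copies in evaluation (1 + 2 * numScales with hflips enabled).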
function ExampleLoader:makeBatchTable(batchSize, isTrainingPhase)
  local o = self:getPhaseOpts(isTrainingPhase)
  local num_jittered_copies = isTrainingPhase and 2 or (1 + (o.hflips and 2 or 1) * o.numScales)
  return table2d(batchSize, num_jittered_copies, function()
    return {torch.FloatTensor(), torch.FloatTensor(), torch.FloatTensor(), torch.FloatTensor()}
  end)
end
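
-- Loads one example's labels, proposals and JPEG bytes, then returns a closure that decodes,
-- normalizes, rescales and (optionally) flips the image and fills batchTable[indexInBatch].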
function ExampleLoader:loadExample(exampleIdx, isTrainingPhase)
  local o = self:getPhaseOpts(isTrainingPhase)
  local labels_loaded = self.dataset[o.subset]:getLabels(exampleIdx)
  local rois_loaded = self.dataset[o.subset]:getProposals(exampleIdx)
  local jpeg_loaded = self.dataset[o.subset]:getJpegBytes(exampleIdx)
  local scales = o.scales or self.scales
  local normalization_params = self.normalization_params

  -- Per-ROI weights: in training they are loaded from precomputed .t7 files (one per image)
  -- under the externally defined weight_dir; in evaluation an uninitialized placeholder
  -- tensor with 80 columns (one per class) is allocated instead.
  local rois_wt
  if isTrainingPhase then
    rois_wt = torch.load(weight_dir .. string.gsub(self.dataset[o.subset]:getImageFileName(exampleIdx), 'jpg', 't7'))
  else
    rois_wt = torch.Tensor(rois_loaded:size(1), 80)
  end

  -- Scale index 0 means the original (unscaled) image; training picks one random scale,
  -- evaluation iterates over all of them.
  local scale_inds = isTrainingPhase and {0, torch.random(1, o.numScales)} or torch.range(0, o.numScales):totable()
  local hflips = isTrainingPhase and (o.hflips and torch.random(0, 1) or 0) or (o.hflips and 2 or 0) -- 0 is no_flip, 1 is do_flip, 2 is both
  local rois_perm = isTrainingPhase and torch.randperm(rois_loaded:size(1)) or torch.range(1, rois_loaded:size(1))

  return function(indexInBatch, batchTable)
    image = image or require 'image'
    local img_original = image.decompressJPG(jpeg_loaded, 3, normalization_params.scale == 255 and 'byte' or 'float')
    local height_original, width_original = img_original:size(2), img_original:size(3)

    -- Keep at most numRoisPerImage proposals (shuffled in training) and their weights.
    local roi_inds = rois_perm:sub(1, math.min(rois_loaded:size(1), o.numRoisPerImage)):long()
    local rois_scale0 = rois_loaded:index(1, roi_inds)
    local rois_wt_perm = rois_wt:index(1, roi_inds)

    for j, scale_ind in ipairs(scale_inds) do
      local images, rois, labels, my_wt = unpack(batchTable[indexInBatch][j])
      my_wt:resize(rois_wt_perm:size()):copy(rois_wt_perm)
      local img_scaled = scale_ind == 0 and img_original:clone() or rescale(img_original, scales[scale_ind][1], scales[scale_ind][2])
      local width_scaled, height_scaled = img_scaled:size(3), img_scaled:size(2)
      subtract_mean(images:resize(img_scaled:size()), img_scaled, normalization_params)

      -- Rescale ROI coordinates (x1, y1, x2, y2[, extra]) to the resized image.
      local scale_factors = torch.FloatTensor{{width_scaled / width_original, height_scaled / height_original, width_scaled / width_original, height_scaled / height_original, 1.0}}
      rois:cmul(rois_scale0, scale_factors:narrow(2, 1, rois_scale0:size(2)):contiguous():expandAs(rois_scale0))
      labels:resize(labels_loaded:size()):copy(labels_loaded)

      if hflips == 1 then
        flip(images, rois)
      elseif scale_ind ~= 0 and hflips == 2 then
        -- Evaluation with hflips: also write a flipped copy into the second half of the batch table.
        local jj = #batchTable[indexInBatch] - j + 2
        local images_flipped, rois_flipped, labels_flipped = unpack(batchTable[indexInBatch][jj])
        images_flipped:resizeAs(images):copy(images)
        rois_flipped:resizeAs(rois):copy(rois)
        labels_flipped:resizeAs(labels):copy(labels)
        flip(images_flipped, rois_flipped)
        insert_dummy_dim1(images_flipped, rois_flipped, labels_flipped)
      end
      insert_dummy_dim1(images, rois, labels)
    end
    collectgarbage()
  end
end

function ExampleLoader:getNumExamples(isTrainingPhase)
  return self.dataset[self:getSubset(isTrainingPhase)]:getNumExamples()
end

function ExampleLoader:getPhaseOpts(isTrainingPhase)
  return isTrainingPhase and self.example_loader_opts['training'] or self.example_loader_opts['evaluate']
end

function ExampleLoader:getSubset(isTrainingPhase)
  return self:getPhaseOpts(isTrainingPhase).subset
end
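
-- A minimal usage sketch. The dataset object, option values and normalization parameters
-- below are hypothetical placeholders chosen only to show the expected shapes (the field
-- names match what this file reads); in training, the global weight_dir must also point at
-- the per-image .t7 ROI weight files loaded by loadExample.
--
--   -- illustrative values only; real ones come from the model's preprocessing
--   local normalization_params = {channel_order = 'bgr', rgb_mean = {123, 117, 104}, scale = 255}
--
--   local loader = ExampleLoader(dataset, normalization_params,
--     {{600, 1000}},  -- list of {max_height, max_width} scales
--     {
--       training = {subset = 'train', numScales = 1, hflips = true,  numRoisPerImage = 64},
--       evaluate = {subset = 'val',   numScales = 1, hflips = false, numRoisPerImage = 500},
--     })
--
--   local batchSize = 2
--   local batchTable = loader:makeBatchTable(batchSize, true)
--   for i = 1, batchSize do
--     local fill = loader:loadExample(torch.random(1, loader:getNumExamples(true)), true)
--     fill(i, batchTable)  -- fills batchTable[i][j] with {images, rois, labels, roi_weights}
--   end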