-
Notifications
You must be signed in to change notification settings - Fork 0
/
conv2d_LC_layer.py
452 lines (422 loc) · 20.3 KB
/
conv2d_LC_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
# -*- coding: utf-8 -*-
"""Convolutional layers.

MIT License
Copyright (c) 2019 Nikos Karantzas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras import constraints
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.utils import conv_utils
#from tensorflow.python.keras.utils.generic_utils import transpose_shape
#from tensorflow.python.keras.legacy import interfaces
import os
import numpy as np
from random import sample
# import h5py
import scipy.io as sio
import tensorflow as tf
def randomly_sampled_kernel(fixed_kernel, num_filters, sample_size, c_in, c_out):
    """Draw an independent random subset of bank filters per channel pair.

    For every (input channel, output channel) pair, ``sample_size`` distinct
    indices are drawn without replacement from ``range(num_filters)`` and the
    filters at those indices are copied out of ``fixed_kernel``.

    :param fixed_kernel: np.array of predesigned filters;
        It must be of shape [filters_height, filters_width, bank_size]
    :param num_filters: indices are sampled from ``range(num_filters)``
    :param sample_size: number of distinct filters drawn per channel pair
    :param c_in: number of input channels in CNN layer
    :param c_out: number of output channels in CNN layer
    :return:
        a numpy array with
        shape = (filters_height, filters_width, c_in, c_out, sample_size)
    """
    assert fixed_kernel.ndim == 3
    height, width = fixed_kernel.shape[0], fixed_kernel.shape[1]
    sampled = np.zeros((height, width, c_in, c_out, sample_size))
    # np.ndindex walks the (c_in, c_out) grid in row-major order, i.e. the
    # output-channel index varies fastest, so the random draws are consumed
    # in the same sequence as a nested c_in / c_out loop.
    for in_idx, out_idx in np.ndindex(c_in, c_out):
        picked = sample(range(num_filters), sample_size)
        sampled[:, :, in_idx, out_idx, :] = fixed_kernel[:, :, picked]
    return sampled
@tf.keras.utils.register_keras_serializable(package='Custom', name='conv_LC')
class _Conv_LC(Layer):
    """Abstract nD "linear combination" convolution layer (private base).

    Instead of learning a dense convolution kernel, this layer loads a bank
    of predesigned filters from a ``.npy`` file, randomly samples
    `sample_size` of them for every (input channel, output channel) pair,
    and learns only the coefficients (`LC_coeffs`) that linearly combine
    the sampled filters into the effective kernel. That kernel is convolved
    with the layer input to produce a tensor of outputs.
    If `use_bias` is True, a bias vector is created and added to the outputs.
    Finally, if `activation` is not `None`,
    it is applied to the outputs as well.

    Only `rank == 2` is implemented in `call`.

    # Arguments
        rank: An integer, the rank of the convolution,
            e.g. "2" for 2D convolution.
        num_filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the convolution).
        sample_size: Integer, the number of filters drawn from the filter
            bank (without replacement) for every
            (input channel, output channel) pair.
        kernel_size: An integer or tuple/list of n integers, specifying the
            dimensions of the convolution window.
            It is used by `compute_output_shape`; the spatial size of the
            kernel actually convolved is that of the loaded filter bank,
            so the two are expected to agree.
        strides: An integer or tuple/list of n integers,
            specifying the strides of the convolution.
            Specifying any stride value != 1 is incompatible with specifying
            any `dilation_rate` value != 1.
        padding: One of `"valid"` or `"same"` (case-insensitive).
        data_format: A string,
            one of `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, ..., channels)` while `"channels_first"` corresponds to
            inputs with shape `(batch, channels, ...)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        dilation_rate: An integer or tuple/list of n integers, specifying
            the dilation rate to use for dilated convolution.
            Currently, specifying any `dilation_rate` value != 1 is
            incompatible with specifying any `strides` value != 1.
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the `LC_coeffs` weights
            (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        kernel_regularizer: Regularizer function applied to
            the `LC_coeffs` weights
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to the `LC_coeffs`
            weights (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        filter_bank_path: Path of the `.npy` file holding the predesigned
            filter bank, an array of shape
            `(filters_height, filters_width, bank_size)`.
            Defaults to `'5x5_SDPFdict.npy'`, resolved against the current
            working directory.
    """

    def __init__(self,
                 rank,
                 num_filters,
                 sample_size,
                 kernel_size,
                 strides=1,
                 padding='valid',
                 data_format=None,
                 dilation_rate=1,
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 filter_bank_path='5x5_SDPFdict.npy',
                 **kwargs):
        super(_Conv_LC, self).__init__(**kwargs)
        self.rank = rank
        self.num_filters = num_filters
        self.sample_size = sample_size
        self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank,
                                                      'kernel_size')
        self.strides = conv_utils.normalize_tuple(strides, rank, 'strides')
        self.padding = conv_utils.normalize_padding(padding)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank,
                                                        'dilation_rate')
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.activity_regularizer = regularizers.get(activity_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_constraint = constraints.get(bias_constraint)
        self.filter_bank_path = filter_bank_path
        self.input_spec = InputSpec(ndim=self.rank + 2)

    def build(self, input_shape):
        """Load the filter bank, sample the fixed kernel, create weights.

        # Raises
            ValueError: if the channel dimension of `input_shape` is `None`.
        """
        channel_axis = 1 if self.data_format == 'channels_first' else -1
        input_dim = input_shape[channel_axis]
        if input_dim is None:
            raise ValueError('The channel dimension of the inputs '
                             'should be defined. Found `None`.')

        # Predesigned filter bank, shape (height, width, bank_size).
        fixed_filters = np.load(self.filter_bank_path)
        num_fixed_filters = fixed_filters.shape[-1]

        # Shape (height, width, c_in, c_out, sample_size): an independent
        # random subset of the bank for every channel pair.
        # NOTE(review): the sample is stored as a constant, not as a layer
        # weight, so it is re-drawn on every build. A layer rebuilt from
        # config and restored weights therefore looks likely to combine its
        # trained coefficients with a different filter subset - TODO confirm
        # and, if so, persist the sample.
        fixed_kernel = randomly_sampled_kernel(
            fixed_kernel=fixed_filters,
            num_filters=num_fixed_filters,
            sample_size=self.sample_size,
            c_in=input_dim,
            c_out=self.num_filters)
        self.fixed_kernel_tensor = K.constant(
            value=fixed_kernel,
            name='fixed_kernel',)

        # Trainable coefficients: one per sampled filter and channel pair.
        self.LC_coeffs = self.add_weight(
            shape=(input_dim, self.num_filters, self.sample_size),
            initializer=self.kernel_initializer,
            name='coeffs',
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint)

        if self.use_bias:
            self.bias = self.add_weight(shape=(self.num_filters,),
                                        initializer=self.bias_initializer,
                                        name='bias',
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
        else:
            self.bias = None
        # Set input spec.
        self.input_spec = InputSpec(ndim=self.rank + 2,
                                    axes={channel_axis: input_dim})
        self.built = True

    def call(self, inputs):
        """Convolve `inputs` with the linearly combined kernel.

        # Raises
            Exception: if `rank` is anything other than 2.
        """
        # Reject every unsupported rank up front; previously a rank other
        # than 1, 2 or 3 fell through and failed on an unbound `outputs`.
        if self.rank != 2:
            raise Exception('{}D convolution is not currently supported'
                            .format(self.rank))

        # Contract the sample axis: (h, w, c_in, c_out, s) x (c_in, c_out, s)
        # -> (h, w, c_in, c_out), the kernel layout K.conv2d is given.
        self.kernel = tf.einsum('ijklm,klm->ijkl',
                                self.fixed_kernel_tensor, self.LC_coeffs)
        outputs = K.conv2d(
            inputs,
            self.kernel,
            strides=self.strides,
            padding=self.padding,
            data_format=self.data_format,
            dilation_rate=self.dilation_rate)

        if self.use_bias:
            outputs = K.bias_add(
                outputs,
                self.bias,
                data_format=self.data_format)
        if self.activation is not None:
            return self.activation(outputs)
        return outputs

    def _conv_output_space(self, space):
        """Return the output length of every spatial dimension in `space`."""
        return tuple(
            conv_utils.conv_output_length(
                space[i],
                self.kernel_size[i],
                padding=self.padding,
                stride=self.strides[i],
                dilation=self.dilation_rate[i])
            for i in range(len(space)))

    def compute_output_shape(self, input_shape):
        """Return the output shape for `input_shape` under `data_format`."""
        if self.data_format == 'channels_last':
            new_space = self._conv_output_space(input_shape[1:-1])
            return (input_shape[0],) + new_space + (self.num_filters,)
        if self.data_format == 'channels_first':
            new_space = self._conv_output_space(input_shape[2:])
            return (input_shape[0], self.num_filters) + new_space

    def get_config(self):
        """Return the constructor arguments merged over the base config."""
        config = {
            'rank': self.rank,
            'num_filters': self.num_filters,
            'sample_size': self.sample_size,
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding,
            'data_format': self.data_format,
            'dilation_rate': self.dilation_rate,
            'activation': activations.serialize(self.activation),
            'use_bias': self.use_bias,
            'kernel_initializer': initializers.serialize(self.kernel_initializer),
            'bias_initializer': initializers.serialize(self.bias_initializer),
            'kernel_regularizer': regularizers.serialize(self.kernel_regularizer),
            'bias_regularizer': regularizers.serialize(self.bias_regularizer),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
            'kernel_constraint': constraints.serialize(self.kernel_constraint),
            'bias_constraint': constraints.serialize(self.bias_constraint),
            'filter_bank_path': self.filter_bank_path,
        }
        merged = dict(super(_Conv_LC, self).get_config())
        # Layer-specific entries take precedence over the base entries.
        merged.update(config)
        return merged

    @classmethod
    def from_config(cls, config):
        """Instantiate the layer from a `get_config` dictionary."""
        return cls(**config)
@tf.keras.utils.register_keras_serializable(package='Custom', name='conv2D_LC')
class Conv2D_LC(_Conv_LC):
    """2D "linear combination" convolution layer (e.g. over images).

    This is `_Conv_LC` with `rank` fixed to 2: the convolution kernel is a
    learned linear combination of filters sampled from a fixed filter bank,
    and it is convolved with the layer input to produce a tensor of
    outputs. If `use_bias` is True,
    a bias vector is created and added to the outputs. Finally, if
    `activation` is not `None`, it is applied to the outputs as well.
    When using this layer as the first layer in a model,
    provide the keyword argument `input_shape`
    (tuple of integers, does not include the batch axis),
    e.g. `input_shape=(128, 128, 3)` for 128x128 RGB pictures
    in `data_format="channels_last"`.

    # Arguments
        num_filters: Integer, the dimensionality of the output space
            (i.e. the number of output filters in the convolution).
        sample_size: Integer, the number of filter-bank filters that are
            sampled and linearly combined for every
            (input channel, output channel) pair.
        kernel_size: An integer or tuple/list of 2 integers, specifying the
            height and width of the 2D convolution window.
            Can be a single integer to specify the same value for
            all spatial dimensions.
        strides: An integer or tuple/list of 2 integers,
            specifying the strides of the convolution
            along the height and width.
            Can be a single integer to specify the same value for
            all spatial dimensions.
            Specifying any stride value != 1 is incompatible with specifying
            any `dilation_rate` value != 1.
        padding: one of `"valid"` or `"same"` (case-insensitive).
            Note that `"same"` is slightly inconsistent across backends with
            `strides` != 1, as described
            [here](https://github.com/keras-team/keras/pull/9473#issuecomment-372166860)
        data_format: A string,
            one of `"channels_last"` or `"channels_first"`.
            The ordering of the dimensions in the inputs.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)` while `"channels_first"`
            corresponds to inputs with shape
            `(batch, channels, height, width)`.
            It defaults to the `image_data_format` value found in your
            Keras config file at `~/.keras/keras.json`.
            If you never set it, then it will be "channels_last".
        dilation_rate: an integer or tuple/list of 2 integers, specifying
            the dilation rate to use for dilated convolution.
            Can be a single integer to specify the same value for
            all spatial dimensions.
            Currently, specifying any `dilation_rate` value != 1 is
            incompatible with specifying any stride value != 1.
        activation: Activation function to use
            (see [activations](../activations.md)).
            If you don't specify anything, no activation is applied
            (ie. "linear" activation: `a(x) = x`).
        use_bias: Boolean, whether the layer uses a bias vector.
        kernel_initializer: Initializer for the linear-combination
            coefficients (see [initializers](../initializers.md)).
        bias_initializer: Initializer for the bias vector
            (see [initializers](../initializers.md)).
        kernel_regularizer: Regularizer function applied to
            the linear-combination coefficients
            (see [regularizer](../regularizers.md)).
        bias_regularizer: Regularizer function applied to the bias vector
            (see [regularizer](../regularizers.md)).
        activity_regularizer: Regularizer function applied to
            the output of the layer (its "activation").
            (see [regularizer](../regularizers.md)).
        kernel_constraint: Constraint function applied to the
            linear-combination coefficients
            (see [constraints](../constraints.md)).
        bias_constraint: Constraint function applied to the bias vector
            (see [constraints](../constraints.md)).
        **kwargs: Forwarded unchanged to the `_Conv_LC` constructor.

    # Input shape
        4D tensor with shape:
        `(batch, channels, rows, cols)`
        if `data_format` is `"channels_first"`
        or 4D tensor with shape:
        `(batch, rows, cols, channels)`
        if `data_format` is `"channels_last"`.

    # Output shape
        4D tensor with shape:
        `(batch, filters, new_rows, new_cols)`
        if `data_format` is `"channels_first"`
        or 4D tensor with shape:
        `(batch, new_rows, new_cols, filters)`
        if `data_format` is `"channels_last"`.
        `rows` and `cols` values might have changed due to padding.
    """

    def __init__(self,
                 num_filters,
                 sample_size,
                 kernel_size,
                 strides=(1, 1),
                 padding='valid',
                 data_format=None,
                 dilation_rate=(1, 1),
                 activation=None,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 **kwargs):
        super(Conv2D_LC, self).__init__(
            rank=2,
            num_filters=num_filters,
            sample_size=sample_size,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            data_format=data_format,
            dilation_rate=dilation_rate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs)

    def get_config(self):
        """Return the parent config without `rank`.

        `rank` is fixed to 2 by this class and is not a constructor
        argument, so it must not be fed back through `from_config`.
        """
        config = super(Conv2D_LC, self).get_config()
        config.pop('rank')
        return config

    @classmethod
    def from_config(cls, config):
        """Instantiate the layer from a `get_config` dictionary."""
        # The leftover debug prints that dumped `config` to stdout on every
        # deserialization were removed; construction is unchanged.
        return cls(**config)