-
Notifications
You must be signed in to change notification settings - Fork 2
/
FFEncoding.py
32 lines (26 loc) · 940 Bytes
/
FFEncoding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from torch import Tensor
class FFEncoding(object):
    """Helpers that embed class labels into a batch of input samples.

    Every method returns a modified *copy* of ``x``; the input is never
    mutated. A block of ``num_classes`` features is zeroed and the feature
    at the label's index is set to ``x.max()`` — one scalar taken over the
    whole input batch, not a per-sample maximum.
    """

    @staticmethod
    def _stamp_label(x_: Tensor, y: Tensor, offset: int, num_classes: int, peak: Tensor) -> None:
        """Zero ``x_[:, offset:offset + num_classes]`` in place, then set ``x_[i, offset + y[i]]`` to ``peak``."""
        # Assign rather than multiply by 0.0: ``inf * 0.0`` / ``nan * 0.0`` give
        # NaN, and in-place float multiplication fails on integer tensors.
        x_[:, offset : offset + num_classes] = 0.0
        # range(...) pairs row i with column offset + y[i] (advanced indexing).
        x_[range(x_.shape[0]), offset + y] = peak

    @staticmethod
    def overlay(x: Tensor, y: Tensor, num_classes: int = 10) -> Tensor:
        """
        Replace the first ``num_classes`` pixels of data [x] with one-hot-encoded label [y].

        Args:
            x: batch indexed as ``x[sample, feature]`` (assumed to be flattened
                ``(batch, features)`` — confirm against callers). Not modified.
            y: label index per sample (integer tensor with one entry per row,
                or a single int applied to every row); each value is assumed
                to lie in ``[0, num_classes)``.
            num_classes: width of the label region at the start of each row.
                Defaults to 10.

        Returns:
            A copy of ``x`` whose first ``num_classes`` features are zero,
            except ``[i, y[i]]``, which holds ``x.max()``. An empty ``x`` is
            returned as an empty copy.
        """
        x_ = x.clone()
        if x_.numel() == 0:
            # x.max() is undefined for an empty tensor and there is nothing to encode.
            return x_
        FFEncoding._stamp_label(x_, y, 0, num_classes, x.max())
        return x_

    @staticmethod
    def overlay2d(x: Tensor, y: Tensor, num_classes: int = 10, channels: int = 3) -> Tensor:
        """
        Replace the first ``num_classes`` pixels of data [x] for all channels with one-hot-encoded label [y].

        ``x.size(1)`` is split into ``channels`` equal, consecutive segments
        (presumably channel-major flattened images, e.g. ``C*H*W`` — confirm
        against callers). The label block is written at the start of every
        segment. If a segment is shorter than ``num_classes``, neighbouring
        label blocks overlap and later channels overwrite earlier ones.

        Args:
            x: batch indexed as ``x[sample, feature]``. Not modified.
            y: label index per sample (integer tensor with one entry per row,
                or a single int).
            num_classes: width of each label block. Defaults to 10.
            channels: number of equal segments in dim 1. Defaults to 3.

        Returns:
            A copy of ``x`` with the one-hot label stamped into each segment,
            using ``x.max()`` as the hot value.

        Raises:
            ValueError: if ``x.size(1)`` is not divisible by ``channels``.
        """
        if x.size(dim=1) % channels != 0:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise ValueError(f"Expects a {channels} channel image")
        plane = x.size(dim=1) // channels
        x_ = x.clone()
        if x_.numel() == 0:
            # x.max() is undefined for an empty tensor and there is nothing to encode.
            return x_
        peak = x.max()
        # Same order as zeroing/stamping channel 0, then 1, then 2, ...
        for c in range(channels):
            FFEncoding._stamp_label(x_, y, c * plane, num_classes, peak)
        return x_