-
Notifications
You must be signed in to change notification settings - Fork 3
/
masked_fill.py
57 lines (46 loc) · 1.36 KB
/
masked_fill.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
from time import time
# Seed torch's RNG so the torch.randn inputs created below are the same on every run.
torch.manual_seed(0)
warmups = 100 # untimed warm-up iterations executed before each measurement
total_times = 10 # accumulated timed budget per measurement, in seconds
# Operation being benchmarked: output = input.masked_fill_(mask, value)
def run_single_test(N, C, contiguous=True, *, warmup_iters=None, duration=None):
    """Time in-place ``masked_fill_`` on an ``N x C`` tensor and print the mean.

    A random ``N x C`` tensor is created, a mask selecting its non-negative
    entries is built, and ``masked_fill_(mask, 0)`` is called repeatedly.
    After the warm-up calls, timed calls are accumulated until the summed
    call time reaches the time budget; the mean time per call is printed in
    milliseconds.

    Args:
        N: Number of rows of the input tensor.
        C: Number of columns of the input tensor.
        contiguous: When true the tensor is created directly as ``N x C``;
            otherwise it is a ``narrow`` view of an ``N x (C + 16)`` tensor.
        warmup_iters: Untimed warm-up calls. ``None`` (the default) uses the
            module-level ``warmups`` constant.
        duration: Time budget in seconds for the timed calls. ``None`` (the
            default) uses the module-level ``total_times`` constant.

    Returns:
        The mean time per timed call, in milliseconds.
    """
    # Function-scope import: the module itself only imports ``time.time``,
    # which is a wall clock (it can jump and may be coarse); ``perf_counter``
    # is the clock intended for measuring short intervals.
    from time import perf_counter

    # Resolve the module constants at call time, and only when the caller did
    # not supply an override.
    if warmup_iters is None:
        warmup_iters = warmups
    if duration is None:
        duration = total_times

    if contiguous:
        data = torch.randn(N, C)
    else:
        # Keep the first C columns of a wider tensor; the result is a view
        # into the wider storage rather than a freshly packed tensor.
        data = torch.randn(N, C + 16).narrow(1, 0, C)
    mask = data.ge(0.0)

    for _ in range(warmup_iters):
        data.masked_fill_(mask, 0)

    elapsed = 0.0
    iters = 0
    # Always time at least one call so the mean is defined even for a
    # non-positive budget; for a positive budget this loop is equivalent to
    # "while elapsed < duration".
    while True:
        start = perf_counter()
        data.masked_fill_(mask, 0)
        # Compute the small delta first, then accumulate it, so precision is
        # not lost by adding the running total to a large timestamp.
        elapsed += perf_counter() - start
        iters += 1
        if elapsed >= duration:
            break

    mean_ms = elapsed * 1000 / iters
    print("input size: [{} {}]; contiguous: {}; time = {:.3f} ms".format(
        N, C, ("True" if contiguous else "False"), mean_ms))
    return mean_ms
def benchmark():
    """Run ``run_single_test`` for every layout and row-count combination.

    The contiguous layout is measured first, then the non-contiguous one.
    Within each layout the row count steps through 128, 256, 512 and 1024,
    always with 1000 columns.
    """
    row_counts = (128, 256, 512, 1024)
    for is_contiguous in (True, False):
        for rows in row_counts:
            run_single_test(rows, 1000, is_contiguous)
# Module-level entry point: the whole benchmark sweep starts as soon as this
# file is executed, because the call is not behind a __main__ guard.
benchmark()
def validate():
    """Print ``masked_fill_`` results for a bool mask and for a byte mask.

    A random ``3 x 4`` tensor is created and a mask selecting its
    non-negative entries is built. The tensor, the mask and the filled result
    are printed for the bool mask, then the mask and the filled result are
    printed again for the same mask converted with ``.byte()``.

    Each fill runs on its own copy of the source tensor. ``masked_fill_``
    modifies its tensor in place, so filling the same tensor twice would make
    the byte-mask result trivially equal to the bool-mask result regardless
    of how the byte mask is handled.

    NOTE(review): some PyTorch releases may warn about or reject non-bool
    masks in ``masked_fill_`` -- confirm against the installed version.
    """
    source = torch.randn(3, 4)
    mask = source.ge(0.0)

    print('bool mask')
    print('input', source)
    print('mask', mask)
    # Fill a copy so the source stays intact for the byte-mask run below.
    output = source.clone().masked_fill_(mask, 0)
    print('output', output)

    mask1 = mask.byte()
    print('byte mask')
    print('mask', mask1)
    # Start again from the untouched source so this fill is actually
    # exercised by the byte mask instead of re-zeroing already-zero entries.
    output1 = source.clone().masked_fill_(mask1, 0)
    print('output1', output1)
#validate()