utils.py
import math

import torch


class PositionalEncoding(torch.nn.Module):
    """Adds sinusoidal positional encodings to a sequence of token embeddings."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)    # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))  # [d_model // 2]
        pe = torch.zeros(max_len, 1, d_model)            # [max_len, 1, d_model]
        pe[:, 0, 0::2] = torch.sin(position * div_term)  # even embedding dimensions
        pe[:, 0, 1::2] = torch.cos(position * div_term)  # odd embedding dimensions
        # Buffer (not a parameter): moves with .to()/.cuda() but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
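
# A minimal usage sketch (the shapes below are illustrative, not from this repo):
# the encodings are added over the sequence dimension, so the output shape
# matches the input.
#
#   pos_enc = PositionalEncoding(d_model=512, dropout=0.1)
#   x = torch.zeros(35, 20, 512)   # [seq_len, batch_size, embedding_dim]
#   y = pos_enc(x)                 # torch.Size([35, 20, 512])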

def transfomer_mask(seq_len):
    """Causal mask of shape [seq_len, seq_len]: 0 on/below the diagonal, -inf above it."""
    return torch.triu(torch.ones(seq_len, seq_len) * float("-inf"), diagonal=1)
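
# Illustration of the mask (assuming seq_len=3): once added to the attention
# scores, position j can only attend to positions <= j.
#
#   transfomer_mask(3)
#   # tensor([[0., -inf, -inf],
#   #         [0.,   0., -inf],
#   #         [0.,   0.,   0.]])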

def textTotensor(raw_text_iter, tokenizer, vocab):
    """Converts an iterable of raw text into one flat 1-D tensor of token ids."""
    data = [torch.tensor(vocab(tokenizer(item)), dtype=torch.long) for item in raw_text_iter]
    # Drop items that tokenize to an empty sequence before concatenating.
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
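
# Hedged usage sketch assuming the torchtext helpers this code is usually paired
# with; `train_iter`, the tokenizer choice, and the vocab construction below are
# assumptions, not defined in this file.
#
#   from torchtext.data.utils import get_tokenizer
#   from torchtext.vocab import build_vocab_from_iterator
#
#   tokenizer = get_tokenizer('basic_english')
#   vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
#   vocab.set_default_index(vocab['<unk>'])
#   # (re-create train_iter first if it is a one-shot iterator)
#   train_data = textTotensor(train_iter, tokenizer, vocab)   # flat 1-D LongTensor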

def batchify_by_cutting(data, bsz):
    """Reshapes a flat token tensor into [seq_len, bsz], cutting off any
    trailing tokens that do not fill a whole column."""
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]                      # trim the remainder
    data = data.view(bsz, seq_len).t().contiguous()  # each column is contiguous text
    return data
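
# Worked example (hypothetical numbers): 26 tokens with bsz=4 keeps the first
# 24 tokens and yields a [6, 4] tensor whose columns are contiguous text.
#
#   batchify_by_cutting(torch.arange(26), 4).shape   # torch.Size([6, 4])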

def get_batch(source, i, bptt=35):
    """Slices a batchified source into an input/target pair: data is [seq_len, bsz]
    and target is the same tokens shifted by one position, flattened to 1-D."""
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i + seq_len]
    target = source[i + 1:i + 1 + seq_len].reshape(-1)
    return [data, target]
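
# Sketch of how the two helpers compose (sizes are illustrative): the target is
# the input shifted by one token, flattened for torch.nn.CrossEntropyLoss.
#
#   train_data = batchify_by_cutting(torch.arange(1000), 20)   # [50, 20]
#   data, target = get_batch(train_data, 0, bptt=35)
#   # data.shape   -> torch.Size([35, 20])
#   # target.shape -> torch.Size([700])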