-
Notifications
You must be signed in to change notification settings - Fork 29
/
word2vec.py
58 lines (48 loc) · 1.88 KB
/
word2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from gensim.models import word2vec
import os
import pickle
import random
import torch
import numpy as np
import configparser
def get_texts(train_loader, vocabulary):
    """Decode every training batch back into lists of word tokens.

    Walks ``train_loader.loaders`` (a two-level mapping whose leaves are
    iterables of ``(data, _)`` batches) and converts each row of ``data``
    from token ids to words with ``vocabulary.to_word``. The result is the
    sentence corpus handed to Word2Vec.

    Args:
        train_loader: object exposing a ``loaders`` mapping shaped like
            ``{filename: {key: iterable of (data, label) batches}}``; each
            row of ``data`` must provide ``.tolist()``.
        vocabulary: object with a ``to_word(int) -> str`` method.

    Returns:
        list[list[str]]: one token list per row, across all batches.
    """
    texts = []
    for loaders_by_key in train_loader.loaders.values():
        for loader in loaders_by_key.values():
            # Iterate the loader directly; wrapping it in list(...) first held
            # every batch of the split in memory at once for no benefit.
            for data, _ in loader:
                for row in data:
                    # NOTE(review): padding ids are decoded too (presumably to
                    # '<pad>'), so they take part in Word2Vec training.
                    texts.append([vocabulary.to_word(token_id) for token_id in row.tolist()])
    print('texts', len(texts))
    return texts
def get_weights(model, vocabulary, embed_dim):
    """Build an embedding matrix whose rows follow ``vocabulary``'s ids.

    Row ``i`` holds the trained vector for ``vocabulary.to_word(i)``. The
    ``'<pad>'`` row stays zero. Rows for words the model has no vector for
    also stay zero; for example, words filtered out by Word2Vec's
    ``min_count``, or words that never occur in the training texts.

    Args:
        model: trained Word2Vec model; vectors are looked up via ``model.wv``.
        vocabulary: object supporting ``len()`` and ``to_word(int) -> str``.
        embed_dim: number of columns; must match the model's vector size.

    Returns:
        numpy.ndarray of shape ``(len(vocabulary), embed_dim)``.
    """
    weights = np.zeros((len(vocabulary), embed_dim))
    missing = 0
    for index in range(len(vocabulary)):
        word = vocabulary.to_word(index)
        if word == '<pad>':
            continue
        if word not in model.wv:
            # Indexing model.wv with an unknown word raises KeyError and used
            # to abort the whole run; keep a zero row and report instead.
            missing += 1
            continue
        weights[index] = model.wv[word]
    if missing:
        print('words without a trained vector', missing)
    return weights
def main():
    """Train Word2Vec on the pickled training split and pickle its embeddings.

    Reads every path and hyperparameter from the module-level ``config``
    (populated in the ``__main__`` guard), so it must be set before calling.
    Loads the pickled vocabulary and training loader from ``[data] path`` and
    decodes the batches into token lists. It then trains Word2Vec with
    ``[data] window`` / ``[data] min_count`` and ``[model] embed_dim``
    dimensions. The resulting vocabulary-aligned matrix is pickled to
    ``[data] weights``.
    """
    data_path = config['data']['path']
    embed_dim = int(config['model']['embed_dim'])
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # point the config at locally produced, trusted artifacts.
    with open(os.path.join(data_path, config['data']['vocabulary']), 'rb') as f:
        vocabulary = pickle.load(f)
    with open(os.path.join(data_path, config['data']['train_loader']), 'rb') as f:
        train_loader = pickle.load(f)
    texts = get_texts(train_loader, vocabulary)
    w2v_kwargs = {
        'window': int(config['data']['window']),
        'min_count': int(config['data']['min_count']),
    }
    try:
        # gensim >= 4.0 names the embedding width ``vector_size``.
        model = word2vec.Word2Vec(vector_size=embed_dim, **w2v_kwargs)
    except TypeError:
        # gensim < 4.0 only accepts the older ``size`` keyword.
        model = word2vec.Word2Vec(size=embed_dim, **w2v_kwargs)
    model.build_vocab(texts)
    model.train(texts, total_examples=model.corpus_count, epochs=model.epochs)
    weights = get_weights(model, vocabulary, embed_dim)
    # Close the output handle deterministically so the pickle is fully flushed.
    with open(os.path.join(data_path, config['data']['weights']), 'wb') as f:
        pickle.dump(weights, f)
if __name__ == '__main__':
    config = configparser.ConfigParser()
    # ConfigParser.read returns the files it parsed and silently skips missing
    # ones; without this check a missing config.ini surfaces later as an
    # unhelpful KeyError on config['data'].
    if not config.read('config.ini'):
        raise FileNotFoundError('config.ini not found in the current working directory')
    # Seed the Python, NumPy and torch (CPU + all CUDA devices) RNGs.
    # NOTE(review): these do not appear to seed gensim's Word2Vec, which takes
    # its own `seed` argument (not passed in main) — confirm if reproducibility matters.
    seed = int(config['data']['seed'])
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    main()