""" uses a pre-trained universal sentence encoder model to calculate
semantic similarity between abstracts from english wikipedia articles.
time taken: 0hr:08 """
import csv
import pickle
import time

import numpy as np
import tensorflow_hub as hub
from absl import logging
# reduce logging output (set before loading so load-time logs are suppressed too)
logging.set_verbosity(logging.ERROR)

# load the pre-trained USE module from TF Hub
module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
model = hub.load(module_url)
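
# sanity check (a minimal sketch, not in the original script): USE v4 maps each
# input string to a 512-dimensional vector, so a single test sentence should
# embed to shape (1, 512)
sanity = model(["this is a test sentence"])
assert sanity.shape == (1, 512)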
# process embeddings in batches to prevent OOM errors; the full dataset uses
# 32 batches of batch_size rows, but this test run processes a single batch
embeddings = []
batch_size = 38233
num_batches = 1
for i in range(num_batches):
    # extract one batch of abstracts from the cleaned csv file
    # (the file is re-read each iteration; only rows falling inside the
    # current batch window are kept)
    abstracts = []
    with open("cleaned_data/abstracts_test.txt") as csv_file:
        csv_reader = csv.reader(csv_file, delimiter=",")
        for j, row in enumerate(csv_reader):
            if batch_size * i <= j < batch_size * (i + 1):
                abstracts.append(row[2])

    # embed the batch of abstracts into a (batch, 512) tensor of vectors
    start_time = time.time()
    current_embeddings = model(abstracts)

    # add this batch's embeddings to the master list
    embeddings.extend(current_embeddings)

    # display elapsed time for this batch
    end_time = time.time()
    print("time taken:", end_time - start_time)
# convert the list of embedding tensors to a single float32 numpy array
embeddings = np.array(embeddings).astype(np.float32)

# save the embedded abstracts
with open("models/USE/over150chars/all_embeddingsxxxx.pkl", "wb") as file:
    pickle.dump(embeddings, file)
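
# round-trip check (illustrative sketch, assuming the same path as above):
# reload the pickle to confirm the array was written intact
with open("models/USE/over150chars/all_embeddingsxxxx.pkl", "rb") as file:
    reloaded = pickle.load(file)
assert reloaded.shape == embeddings.shape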
# check that all embeddings were captured (expected: (num_abstracts, 512))
print(embeddings.shape)
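
# semantic-similarity sketch (illustrative, not part of the original pipeline):
# USE embeddings are approximately unit length, so the inner product of two
# embedding vectors approximates their cosine similarity
if embeddings.shape[0] >= 2:
    similarity = np.inner(embeddings[0], embeddings[1])
    print("similarity between first two abstracts:", similarity)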