Training_Functions.py
from Helper_Functions import n_words_of_length


def make_train_set_for_target(target, alphabet, lengths=None, max_train_samples_per_length=300,
                              search_size_per_length=1000, provided_examples=None):
    # Build a labelled train set {word: bool} for the given target classifier,
    # sampling words of several lengths over the alphabet and roughly balancing
    # positive and negative examples at each length.
    train_set = {}
    if provided_examples is None:
        provided_examples = []
    if lengths is None:
        lengths = list(range(15)) + [15, 20, 25, 30]
    for l in lengths:
        # start from any user-provided examples of this length, then sample more
        samples = [w for w in provided_examples if len(w) == l]
        samples += n_words_of_length(search_size_per_length, l, alphabet)
        pos = [w for w in samples if target(w)]
        neg = [w for w in samples if not target(w)]
        # cap each class at half the per-length budget, then balance the two
        # classes, allowing the majority class up to 20 extra examples
        pos = pos[:max_train_samples_per_length // 2]
        neg = neg[:max_train_samples_per_length // 2]
        minority = min(len(pos), len(neg))
        pos = pos[:minority + 20]
        neg = neg[:minority + 20]
        train_set.update({w: True for w in pos})
        train_set.update({w: False for w in neg})
    print("made train set of size:", len(train_set), ", of which positive examples:",
          len([w for w in train_set if train_set[w]]))
    return train_set
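
# Usage sketch (not part of the original file): one way to call
# make_train_set_for_target. The target below ("even number of a's") and the
# alphabet "ab" are hypothetical examples; n_words_of_length is assumed, per
# the import above, to sample the requested number of words of length l over
# the alphabet.
#
#     target = lambda w: w.count("a") % 2 == 0
#     train_set = make_train_set_for_target(target, "ab",
#                                           lengths=[1, 2, 5, 10],
#                                           max_train_samples_per_length=100)
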
# curriculum training: per-length passes, then random minibatches, then one full batch
def mixed_curriculum_train(rnn, train_set, outer_loops=3, stop_threshold=0.001, learning_rate=0.001,
                           length_epochs=5, random_batch_epochs=100, single_batch_epochs=100,
                           random_batch_size=20):
    lengths = sorted(set(len(w) for w in train_set))
    for _ in range(outer_loops):
        # curriculum phase: train on each word length separately, shortest first
        for l in lengths:
            training = {w: train_set[w] for w in train_set if len(w) == l}
            if len(set(training.values())) <= 1:  # empty, or length with only one classification
                continue
            rnn.train_group(training, length_epochs, show=False, loss_every=20,
                            stop_threshold=stop_threshold, learning_rate=learning_rate,
                            batch_size=None, print_time=False)
        # all lengths together, in random minibatches
        if rnn.train_group(train_set, random_batch_epochs, show=True, loss_every=20,
                           stop_threshold=stop_threshold, learning_rate=learning_rate,
                           batch_size=random_batch_size, print_time=False) == rnn.finish_signal:
            break
        # all lengths together, in one full batch
        if rnn.train_group(train_set, single_batch_epochs, show=True, loss_every=20,
                           stop_threshold=stop_threshold, learning_rate=learning_rate,
                           batch_size=None, print_time=False) == rnn.finish_signal:
            break
    print("classification loss on last batch was:", rnn.all_losses[-1])