-
Notifications
You must be signed in to change notification settings - Fork 9
/
main.py
189 lines (166 loc) · 10.4 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import argparse
import os
import random
import warnings
import matplotlib.pyplot as plt
import numpy as np
import torch
import loader
import processor
from os.path import join as j
from config.parse_args import parse_args
# Suppress all Python warnings for the remainder of the run.
warnings.filterwarnings('ignore')

# Every path is resolved relative to the directory that contains this script,
# so results do not depend on the current working directory.
base_path = os.path.dirname(os.path.realpath(__file__))
data_path = j(base_path, '../../data')

# Checkpoint roots for the speech-emotion-recognition (SER) model and the
# speech-to-emotive-gestures (S2EG) model.
models_ser_path, models_s2eg_path = (
    j(base_path, 'models', sub_dir) for sub_dir in ('ser_v1', 's2eg_v1')
)

# Only the SER root is created here; the S2EG path is not touched at this point.
os.makedirs(models_ser_path, exist_ok=True)
def str2bool(v):
    """Convert a command-line value to a bool; intended as an argparse ``type``.

    Args:
        v: a ``bool`` (returned unchanged) or a string. Case-insensitive
            'yes'/'true'/'t'/'y'/'1' map to True and
            'no'/'false'/'f'/'n'/'0' map to False.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the string is not one of the accepted
            spellings (argparse turns this into a usage error).
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


# Command-line options for this script. Boolean options that take a value use
# str2bool: argparse's `type=bool` would turn any non-empty string, including
# "False", into True.
parser = argparse.ArgumentParser(description='Speech to Emotive Gestures')
parser.add_argument('--dataset-ser', type=str, default='iemocap', metavar='D-SER',
                    help='dataset to train and evaluate speech emotion recognition '
                         '(default: %(default)s)')
parser.add_argument('--dataset-s2eg', type=str, default='ted_db', metavar='D-S2G',
                    help='dataset to train and evaluate speech to emotive gestures '
                         '(default: %(default)s)')
parser.add_argument('-dap', '--dataset-s2eg-already-processed',
                    help='Optional. Set to True if dataset has already been processed. '
                         'If not, or if you are not sure, set it to False.',
                    type=str2bool, default=True)
# `is_config_file=True` is a configargparse option; the stdlib argparse used here
# rejects it with a TypeError. This parser only needs to accept the flag. The file
# is presumably read by config.parse_args() -- TODO confirm.
parser.add_argument('-c', '--config', required=True, help='Config file path')
parser.add_argument('--frame-drop', type=int, default=2, metavar='FD',
                    help='frame down-sample rate (default: %(default)s)')
parser.add_argument('--add-mirrored', type=str2bool, default=False, metavar='AM',
                    help='perform data augmentation by mirroring all the sequences '
                         '(default: %(default)s)')
parser.add_argument('--train-ser', type=str2bool, default=False, metavar='T-SER',
                    help='train the ser model (default: %(default)s)')
parser.add_argument('--emo-as-cats', type=str2bool, default=True, metavar='EAC',
                    help='consider emotions as categories (True) or dimensions (False) '
                         '(default: %(default)s)')
parser.add_argument('--train-s2eg', type=str2bool, default=False, metavar='T-S2EG',
                    help='train the s2eg model (default: %(default)s)')
parser.add_argument('--use-multiple-gpus', type=str2bool, default=True, metavar='T',
                    help='use multiple GPUs if available (default: %(default)s)')
parser.add_argument('--ser-load-last-best', type=str2bool, default=True, metavar='SER-LB',
                    help='load the most recent best model for ser (default: %(default)s)')
parser.add_argument('--s2eg-load-last-best', type=str2bool, default=True, metavar='S2EG-LB',
                    help='load the most recent best model for s2eg (default: %(default)s)')
parser.add_argument('--batch-size', type=int, default=16, metavar='B',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--num-worker', type=int, default=4, metavar='W',
                    help='number of workers (default: %(default)s)')
parser.add_argument('--ser-start-epoch', type=int, default=600, metavar='SER-SE',
                    help='starting epoch of training of ser (default: %(default)s)')
parser.add_argument('--ser-num-epoch', type=int, default=5000, metavar='SER-NE',
                    help='number of epochs to train ser (default: %(default)s)')
parser.add_argument('--s2eg-start-epoch', type=int, default=142, metavar='S2EG-SE',
                    help='starting epoch of training of s2eg (default: %(default)s)')
parser.add_argument('--s2eg-num-epoch', type=int, default=50000, metavar='S2EG-NE',
                    help='number of epochs to train s2eg (default: %(default)s)')
parser.add_argument('--ser-optimizer', type=str, default='Adam', metavar='SER-O',
                    help='optimizer (default: %(default)s)')
parser.add_argument('--base-lr-ser', type=float, default=1e-3, metavar='LR-SER',
                    help='base learning rate for ser (default: %(default)s)')
parser.add_argument('--base-tr', type=float, default=1., metavar='TR',
                    help='base teacher rate (default: %(default)s)')
# `type=list` would split a CLI value into single characters. Parse one or more
# floats instead. The default is unchanged (a numpy array).
parser.add_argument('--step', type=float, nargs='+', default=0.05 * np.arange(20),
                    metavar='[S]',
                    help='fraction of steps when learning rate will be decreased '
                         '(default: 0.05 * arange(20), i.e. 0.0, 0.05, ..., 0.95)')
parser.add_argument('--lr-ser-decay', type=float, default=0.999, metavar='LRD-SER',
                    help='learning rate decay for ser (default: %(default)s)')
parser.add_argument('--lr-s2eg-decay', type=float, default=0.999, metavar='LRD-S2EG',
                    help='learning rate decay for s2eg (default: %(default)s)')
parser.add_argument('--gradient-clip', type=float, default=0.1, metavar='GC',
                    help='gradient clip threshold (default: %(default)s)')
# These store_true flags default to True, so on their own they can never be
# switched off. The paired --no-* flags (same dest) make that possible.
parser.add_argument('--nesterov', action='store_true', default=True,
                    help='use nesterov (on by default; see --no-nesterov)')
parser.add_argument('--no-nesterov', dest='nesterov', action='store_false',
                    help='do not use nesterov')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                    help='momentum (default: %(default)s)')
parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='D',
                    help='Weight decay (default: %(default)s)')
parser.add_argument('--upper-body-weight', type=float, default=1., metavar='UBW',
                    help='loss weight on the upper body joint motions (default: %(default)s)')
parser.add_argument('--affs-reg', type=float, default=0.8, metavar='AR',
                    help='regularization for affective features loss (default: %(default)s)')
parser.add_argument('--quat-norm-reg', type=float, default=0.1, metavar='QNR',
                    help='regularization for unit norm constraint (default: %(default)s)')
parser.add_argument('--quat-reg', type=float, default=1.2, metavar='QR',
                    help='regularization for quaternion loss (default: %(default)s)')
parser.add_argument('--recons-reg', type=float, default=1.2, metavar='RCR',
                    help='regularization for reconstruction loss (default: %(default)s)')
parser.add_argument('--eval-interval', type=int, default=1, metavar='EI',
                    help='interval after which model is evaluated (default: %(default)s)')
parser.add_argument('--log-interval', type=int, default=100, metavar='LI',
                    help='interval after which log is printed (default: %(default)s)')
parser.add_argument('--save-interval', type=int, default=10, metavar='SI',
                    help='interval after which model is saved (default: %(default)s)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--pavi-log', action='store_true', default=False,
                    help='pavi log')
parser.add_argument('--print-log', action='store_true', default=True,
                    help='print log (on by default; see --no-print-log)')
parser.add_argument('--no-print-log', dest='print_log', action='store_false',
                    help='do not print log')
parser.add_argument('--save-log', action='store_true', default=True,
                    help='save log (on by default; see --no-save-log)')
parser.add_argument('--no-save-log', dest='save_log', action='store_false',
                    help='do not save log')
# TODO: add a save_result option.
# Parse this script's own flags, then the model/training configuration through
# the project's config parser (which presumably also reads sys.argv -- TODO confirm).
args = parser.parse_args()
randomized = False
config_args = parse_args()

# TED gesture data is needed for S2EG training and for gesture generation. The
# original skipped loading whenever --train-ser was set. With --train-s2eg also
# set, every S2EG split, pose_dim and time_steps were then None, and those Nones
# were handed to the Processor.
if args.train_s2eg or not args.train_ser:
    (train_data_ted, eval_data_ted, test_data_ted,
     train_data_ted_wav, eval_data_ted_wav, test_data_ted_wav,
     ted_wav_max_all, ted_wav_min_all) = loader.load_ted_db_data(
        data_path, args.dataset_s2eg, config_args,
        args.dataset_s2eg_already_processed)
    pose_dim = 27  # 9 x 3
    time_steps = 34
else:
    # SER-only training: skip the (expensive) TED load.
    (train_data_ted, eval_data_ted, test_data_ted,
     train_data_ted_wav, eval_data_ted_wav, test_data_ted_wav,
     ted_wav_max_all, ted_wav_min_all, pose_dim, time_steps) = [None] * 10

# IEMOCAP speech data and emotion labels for the SER model are always loaded.
(train_data_wav, eval_data_wav, test_data_wav,
 train_labels_cat, eval_labels_cat, test_labels_cat,
 train_labels_dim, eval_labels_dim, test_labels_dim,
 iemocap_wav_max_all, iemocap_wav_min_all) = loader.load_iemocap_data(data_path, args.dataset_ser)
# The unpacking requires a 4-D array; presumably (samples, channels, height, width).
_, wav_channels, wav_height, wav_width = train_data_wav.shape
num_emo_cats = train_labels_cat.shape[-1]
num_emo_dims = train_labels_dim.shape[-1]

# SER checkpoints are kept per dataset and per label representation.
if args.emo_as_cats:
    ser_suffix = '_{:02d}_cats'.format(num_emo_cats)
else:
    ser_suffix = '_{:02d}_dims'.format(num_emo_dims)
args.work_dir_ser = j(models_ser_path, args.dataset_ser + ser_suffix)
args.work_dir_s2eg = j(models_s2eg_path, args.dataset_s2eg)
args.video_save_path = j(base_path, 'outputs', 'videos_trimodal_style')
args.quantitative_save_path = j(base_path, 'outputs', 'quantitative')
for out_dir in (args.work_dir_ser, args.work_dir_s2eg,
                args.video_save_path, args.quantitative_save_path):
    os.makedirs(out_dir, exist_ok=True)

data_loader = dict(train_data_ser=train_data_wav, train_data_s2eg=train_data_ted,
                   train_data_s2eg_wav=train_data_ted_wav,
                   train_labels_cat=train_labels_cat, train_labels_dim=train_labels_dim,
                   eval_data_ser=eval_data_wav, eval_data_s2eg=eval_data_ted,
                   eval_data_s2eg_wav=eval_data_ted_wav,
                   eval_labels_cat=eval_labels_cat, eval_labels_dim=eval_labels_dim,
                   test_data_ser=test_data_wav, test_data_s2eg=test_data_ted,
                   test_data_s2eg_wav=test_data_ted_wav,
                   test_labels_cat=test_labels_cat, test_labels_dim=test_labels_dim,
                   ted_wav_max_all=ted_wav_max_all, ted_wav_min_all=ted_wav_min_all)

pr = processor.Processor(args, config_args, data_path, data_loader,
                         wav_channels, wav_height, wav_width,
                         num_emo_cats, num_emo_dims, pose_dim,
                         time_steps, save_path=base_path)

if args.train_ser or args.train_s2eg:
    pr.train()

# Generate gestures for samples 5 and 12 of the TED test LMDB.
# NOTE(review): s2eg_epoch=142 is hard-coded (it matches the --s2eg-start-epoch
# default); this also runs after SER-only training, when pose_dim/time_steps are
# None -- confirm whether that is intended.
# Alternative that uses the in-memory test split:
# pr.generate_gestures(samples_to_generate=len(data_loader['test_data_s2eg_wav']),
#                      randomized=randomized, ser_epoch='best', s2eg_epoch=142)
pr.generate_gestures_by_env_file(j(data_path, 'ted_db/lmdb_test'), [5, 12],
                                 randomized=randomized, ser_epoch='best', s2eg_epoch=142)