# train.py
import argparse
import os
import random
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from data.loader import load_dataset
from models.model import DenseNet121
from utils.trainer import fit
from utils.utils import get_labels_frequency, set_logger

warnings.simplefilter('ignore')


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--root_path', type=str,
                        default='../Datasets/APTOS/APTOS_images/train_images')
    parser.add_argument('--csv_file_path', type=str, default='../CSVs/')
    parser.add_argument('--logdir', type=str, required=False,
                        default='./logs/aptos/', help='log directory path')
    parser.add_argument('--dataset', type=str, default='aptos')
    parser.add_argument('--split', type=str, default='split1')
    parser.add_argument('--n_distill', type=int, default=20,
                        help='epoch at which the KLD loss starts')
    parser.add_argument('--mode', default='exact', type=str,
                        choices=['exact', 'relax', 'multi_pos'])
    parser.add_argument('--nce_p', default=1, type=int,
                        help='number of positive samples for NCE')
    parser.add_argument('--nce_k', default=4096, type=int,
                        help='number of negative samples for NCE')
    parser.add_argument('--nce_t', default=0.07, type=float,
                        help='temperature parameter for softmax')
    parser.add_argument('--nce_m', default=0.5, type=float,
                        help='momentum for non-parametric updates')
    parser.add_argument('--CCD_mode', type=str,
                        default='sup', choices=['sup', 'unsup'])
    parser.add_argument('--rel_weight', type=float, default=25,
                        help='weight of the relation loss')
    parser.add_argument('--ccd_weight', type=float, default=0.1,
                        help='weight of the CCD loss')
    parser.add_argument('--anchor_type', type=str,
                        default='center', choices=['center', 'class'])
    parser.add_argument('--class_anchor', default=30, type=int,
                        help='number of anchors per class')
    parser.add_argument('--feat_dim', type=int, default=128,
                        help='reduced feature dimension')
    parser.add_argument('--s_dim', type=int, default=128,
                        help='feature dim of the student model')
    parser.add_argument('--t_dim', type=int, default=128,
                        help='feature dim of the EMA teacher')
    parser.add_argument('--n_data', type=int, default=3662,
                        help='total number of training samples')
    parser.add_argument('--t_decay', type=float, default=0.99,
                        help='EMA decay of the teacher')
    parser.add_argument('--epochs', type=int, default=80,
                        help='maximum number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='batch size per GPU')
    parser.add_argument('--drop_rate', type=float, default=0.0,
                        help='dropout rate')
    parser.add_argument('--lr', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('--seed', type=int, default=2024, help='random seed')
    parser.add_argument('--optimizer', type=str, default='adam',
                        help='optimizer name')
    parser.add_argument('--scheduler', type=str, default='OneCycleLR',
                        help='learning-rate scheduler name')
    parser.add_argument('--device', type=str, default='cuda:0', help='device')
    parser.add_argument('--consistency', type=float, default=1,
                        help='maximum consistency-loss weight')
    parser.add_argument('--consistency_rampup', type=float, default=30,
                        help='consistency ramp-up length (epochs)')
    return parser.parse_args()
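
# Example invocation (paths are hypothetical; the defaults above assume the
# APTOS layout under ../Datasets and ../CSVs, so adjust --root_path and
# --csv_file_path to your setup):
#   python train.py --dataset aptos --split split1 --batch_size 64 \
#       --lr 1e-4 --epochs 80 --device cuda:0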

# Set the seed for all random number generators to ensure reproducibility.
def set_seed(seed):
    cudnn.benchmark = False
    cudnn.deterministic = True
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # also seed every device in multi-GPU runs

if __name__ == "__main__":
    # Get arguments
    args = get_args()

    # Set seed
    set_seed(args.seed)

    # Set logger
    if not os.path.exists(args.logdir):
        os.makedirs(args.logdir)
    logger = set_logger(args)
    logger.info(args)

    # Load data
    train_ds, test_ds = load_dataset(args, p=args.nce_p, mode=args.mode)
    n_classes = test_ds.n_classes
    class_index = train_ds.class_index
    print('number of classes:', n_classes)

    # Re-seed each DataLoader worker so worker-side augmentations are reproducible
    def worker_init_fn(worker_id):
        random.seed(args.seed + worker_id)
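
    # Note: the function above only re-seeds Python's `random` module in each
    # worker. If any transform draws from NumPy instead, it would need to be
    # seeded per worker as well (a sketch, not part of the original script):
    #   def worker_init_fn(worker_id):
    #       random.seed(args.seed + worker_id)
    #       np.random.seed(args.seed + worker_id)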

    train_dl = DataLoader(train_ds, batch_size=args.batch_size,
                          shuffle=True, num_workers=12, pin_memory=True,
                          worker_init_fn=worker_init_fn)
    test_dl = DataLoader(test_ds, batch_size=args.batch_size,
                         shuffle=False, num_workers=12, pin_memory=True,
                         worker_init_fn=worker_init_fn)

    # Inverse-frequency class weights computed from the training CSV
    freq = get_labels_frequency(os.path.join(args.csv_file_path, args.dataset,
                                             args.split + '_train.csv'),
                                'diagnosis', 'id_code')
    freq = freq.values
    weights = freq.sum() / freq
    print('class weights:', weights)
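
    # Worked example of the inverse-frequency weighting: with class counts
    # freq = [100, 50], freq.sum() / freq gives weights = [1.5, 3.0], so the
    # rarer class contributes proportionally more to the weighted loss.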

    # Build the models: a trainable student and an EMA teacher
    student = DenseNet121(hidden_units=args.feat_dim,
                          out_size=n_classes, drop_rate=args.drop_rate)
    teacher = DenseNet121(hidden_units=args.feat_dim,
                          out_size=n_classes, drop_rate=args.drop_rate)
    for param in teacher.parameters():
        param.detach_()  # exclude the teacher from autograd; it is not trained by backprop
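
    # Detaching removes the teacher's parameters from the autograd graph;
    # presumably fit() updates them as an exponential moving average of the
    # student, e.g. with t_decay as the EMA momentum (a sketch, not confirmed
    # by this file):
    #   for t_p, s_p in zip(teacher.parameters(), student.parameters()):
    #       t_p.data.mul_(args.t_decay).add_(s_p.data, alpha=1 - args.t_decay)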

    # Fit the model
    fit(student, teacher, train_dl, test_dl, weights,
        class_index, logger, args, device=args.device)
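
    # fit() presumably ramps the consistency weight from 0 to args.consistency
    # over the first args.consistency_rampup epochs; a common choice (assumed,
    # not confirmed by this file) is the sigmoid ramp-up of Laine & Aila:
    #   w(t) = consistency * exp(-5 * (1 - t / consistency_rampup) ** 2)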