From 075002c37dd73503821dd0196440f7132bd4f9d6 Mon Sep 17 00:00:00 2001 From: superantichrist Date: Thu, 22 Apr 2021 23:05:25 +0900 Subject: [PATCH 1/7] add simple transformer --- alphafold2_pytorch/transformer.py | 94 +++++++++++ setup.py | 3 +- train_simple.py | 253 ++++++++++++++++++++++++++++++ 3 files changed, 349 insertions(+), 1 deletion(-) create mode 100644 alphafold2_pytorch/transformer.py create mode 100644 train_simple.py diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py new file mode 100644 index 0000000..02cd6f3 --- /dev/null +++ b/alphafold2_pytorch/transformer.py @@ -0,0 +1,94 @@ +###################################################################### +# Transformer! +# ------------ +# +# Transformer is a Seq2Seq model introduced in `“Attention is all you +# need” `__ +# paper for solving machine translation task. Transformer model consists +# of an encoder and decoder block each containing fixed number of layers. +# +# Encoder processes the input sequence by propogating it, through a series +# of Multi-head Attention and Feed forward network layers. The output from +# the Encoder referred to as ``memory``, is fed to the decoder along with +# target tensors. Encoder and decoder are trained in an end-to-end fashion +# using teacher forcing technique. +# + +import math +import torch +from torch import nn +from torch import Tensor +from torch.nn import (TransformerEncoder, TransformerDecoder, + TransformerEncoderLayer, TransformerDecoderLayer) + + +class Seq2SeqTransformer(nn.Module): + def __init__(self, num_encoder_layers: int, num_decoder_layers: int, + emb_size: int, src_vocab_size: int, tgt_vocab_size: int, + dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1): + super(Seq2SeqTransformer, self).__init__() + encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head, + dim_feedforward=dim_feedforward) + self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers) + decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=num_head, + dim_feedforward=dim_feedforward) + self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers) + + self.generator = nn.Linear(emb_size, tgt_vocab_size) + self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size) + self.tgt_tok_emb = TokenEmbedding(src_vocab_size, emb_size) + + def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, + tgt_mask: Tensor, src_padding_mask: Tensor, + tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor): + src_emb = self.src_tok_emb(src) + tgt_emb = self.tgt_tok_emb(trg) + memory = self.transformer_encoder(src_emb) + outs = self.transformer_decoder(tgt_emb, memory) + return self.generator(outs) + + def encode(self, src: Tensor, src_mask: Tensor): + return self.transformer_encoder( + self.src_tok_emb(src), src_mask) + + def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor): + return self.transformer_decoder(self.positional_encoding( + self.tgt_tok_emb(tgt)), memory, + tgt_mask) + + +###################################################################### +# Text tokens are represented by using token embeddings. Positional +# encoding is added to the token embedding to introduce a notion of word +# order. +# + +class PositionalEncoding(nn.Module): + def __init__(self, emb_size: int, dropout, maxlen: int = 5000): + super(PositionalEncoding, self).__init__() + den = torch.exp(- torch.arange(0, emb_size, 2) * math.log(10000) / emb_size) + pos = torch.arange(0, maxlen).reshape(maxlen, 1) + pos_embedding = torch.zeros((maxlen, emb_size)) + pos_embedding[:, 0::2] = torch.sin(pos * den) + pos_embedding[:, 1::2] = torch.cos(pos * den) + pos_embedding = pos_embedding.unsqueeze(-2) + + self.dropout = nn.Dropout(dropout) + self.register_buffer('pos_embedding', pos_embedding) + + def forward(self, token_embedding: Tensor): + return self.dropout(token_embedding + + self.pos_embedding[:token_embedding.size(0), :]) + + +class TokenEmbedding(nn.Module): + def __init__(self, vocab_size: int, emb_size): + super(TokenEmbedding, self).__init__() + self.embedding = nn.Embedding(vocab_size, emb_size) + self.emb_size = emb_size + + def forward(self, tokens: Tensor): + return self.embedding(tokens.long()) * math.sqrt(self.emb_size) + + + diff --git a/setup.py b/setup.py index eb9b370..cf673d0 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,8 @@ 'sidechainnet', 'torch>=1.6', 'tqdm', - 'biopython' + 'biopython', + 'tensorboard' ], setup_requires=[ 'pytest-runner', diff --git a/train_simple.py b/train_simple.py new file mode 100644 index 0000000..87af236 --- /dev/null +++ b/train_simple.py @@ -0,0 +1,253 @@ +import torch +from torch import nn +from torch.optim import Adam +from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch.utils.tensorboard import SummaryWriter +from einops import rearrange + +import sidechainnet as scn +from sidechainnet.dataloaders.collate import prepare_dataloaders +from alphafold2_pytorch import Alphafold2 +import alphafold2_pytorch.constants as constants +from alphafold2_pytorch.utils import get_bucketed_distance_matrix +from alphafold2_pytorch.transformer import Seq2SeqTransformer +import time + +# constants + +DEVICE = None # defaults to cuda if available, else cpu +NUM_EPOCHS = int(1e3) +NUM_BATCHES = int(1e5) +GRADIENT_ACCUMULATE_EVERY = 16 +LEARNING_RATE = 3e-4 +IGNORE_INDEX = -100 +THRESHOLD_LENGTH = 50 + +# set device + +DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS +DEVICE = constants.DEVICE + + +# helpers + + +def cycle(loader, cond=lambda x: True): + while True: + for data in loader: + if not cond(data): + continue + yield data + + +def filter_dictionary_by_seq_length(raw_data, seq_length_threshold, portion): + """Filter SidechainNet data by removing poor-resolution training entries. + + Args: + raw_data (dict): SidechainNet dictionary. + seq_length_threshold (int): sequence length threshold + + Returns: + Filtered dictionary. + """ + new_data = { + "seq": [], + "ang": [], + "ids": [], + "evo": [], + "msk": [], + "crd": [], + "sec": [], + "res": [] + } + train = raw_data[portion] + n_filtered_entries = 0 + total_entires = 0. + for seq, ang, crd, msk, evo, _id, res, sec in zip(train['seq'], train['ang'], + train['crd'], train['msk'], + train['evo'], train['ids'], + train['res'], train['sec']): + total_entires += 1 + if len(seq) > seq_length_threshold: + n_filtered_entries += 1 + continue + else: + new_data["seq"].append(seq) + new_data["ang"].append(ang[:, 0:3]) + new_data["ids"].append(_id) + new_data["evo"].append(evo) + new_data["msk"].append(msk) + new_data["crd"].append(crd) + new_data["sec"].append(sec) + new_data["res"].append(res) + if n_filtered_entries: + print( + f"{total_entires - n_filtered_entries:.0f} out of {total_entires:.0f} ({(total_entires - n_filtered_entries) / total_entires:.1%})" + f" training set entries were included if sequence length <= {seq_length_threshold}") + raw_data[portion] = new_data + return raw_data + + +def create_mask(src, tgt): + src_padding_mask = (src == IGNORE_INDEX).transpose(0, 1) + tgt_padding_mask = (tgt == IGNORE_INDEX).transpose(0, 1) + return src_padding_mask, tgt_padding_mask + + +def train_epoch(model, train_iter, optimizer): + model.train() + losses = 0 + for idx, (batch) in enumerate(train_iter): + seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks + + b, l, _ = seq.shape + + # prepare mask, labels + + seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( + DEVICE).bool() + seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=0) + coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) + angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) + # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) + mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) + + # discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX) + src_padding_mask, tgt_padding_mask = create_mask(seq, seq) + + # predict + + logits = transformer(seq, seq, src_mask=mask, + tgt_mask=mask, src_padding_mask=src_padding_mask, + tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + + # loss + optimizer.zero_grad() + + loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + loss.backward() + + optimizer.step() + losses += loss.item() + return losses / len(train_iter) + + +def evaluate(model, val_iter): + model.eval() + losses = 0 + for idx, (batch) in (enumerate(val_iter)): + seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks + + b, l, _ = seq.shape + + # prepare mask, labels + + seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( + DEVICE).bool() + seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=0) + coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) + angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) + # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) + mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) + + # discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX) + src_padding_mask, tgt_padding_mask = create_mask(seq, seq) + + # predict + + logits = transformer(seq, seq, src_mask=mask, + tgt_mask=mask, src_padding_mask=src_padding_mask, + tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + + # loss + + loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + loss.backward() + + losses += loss.item() + return losses / len(val_iter) + + +# get data + +raw_data = scn.load( + casp_version=12, + thinning=30, + batch_size=1, + dynamic_batching=False +) + +filtered_raw_data = filter_dictionary_by_seq_length(raw_data, THRESHOLD_LENGTH, "train") +writer_train = SummaryWriter("runs/train") +writer_valids = [] +for split in scn.utils.download.VALID_SPLITS: + filtered_raw_data = filter_dictionary_by_seq_length(filtered_raw_data, THRESHOLD_LENGTH, f'{split}') + writer_valids.append(SummaryWriter(f"runs/{split}")) +data = prepare_dataloaders( + filtered_raw_data, + aggregate_model_input=True, + batch_size=1, + num_workers=4, + seq_as_onehot=None, + collate_fn=None, + dynamic_batching=False, + optimize_for_cpu_parallelism=False, + train_eval_downsample=.2) +dl = iter(data['train']) + +# model + +# model = Alphafold2( +# dim=256, +# depth=1, +# heads=8, +# dim_head=64 +# ).to(DEVICE) + +SRC_VOCAB_SIZE = 21 # number of amino acids +TGT_VOCAB_SIZE = 3 # backbone torsion angle +NUM_ENCODER_LAYERS = 3 +NUM_DECODER_LAYERS = 3 +EMB_SIZE = 512 +NUM_HEAD = 8 +FFN_HID_DIM = 512 +transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS, + emb_size=EMB_SIZE, src_vocab_size=SRC_VOCAB_SIZE, tgt_vocab_size=TGT_VOCAB_SIZE, + dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD) + +# optimizer + +for p in transformer.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + +transformer = transformer.to(DEVICE) + +loss_fn = torch.nn.MSELoss() + +optimizer = torch.optim.Adam( + transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 +) + +# training loop +for epoch in range(1, NUM_EPOCHS + 1): + start_time = time.time() + train_loss = train_epoch(transformer, iter(data['train']), optimizer) + end_time = time.time() + valid_count = 0 + for split in scn.utils.download.VALID_SPLITS: + val_loss = evaluate(transformer, iter(data[f'{split}'])) + writer_valids[valid_count].add_scalar("loss", val_loss, epoch) + writer_valids[valid_count].flush() + valid_count += 1 + writer_train.add_scalar("loss", train_loss, epoch) + writer_train.flush() + print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, " + f"Epoch time = {(end_time - start_time):.3f}s")) +print('train ended') +writer_train.close() +valid_count = 0 +for split in scn.utils.download.VALID_SPLITS: + writer_valids[valid_count].close() + valid_count += 1 From 2cf73c64915fa175e029764eeac6ae1998caecea Mon Sep 17 00:00:00 2001 From: superantichrist Date: Mon, 26 Apr 2021 14:13:26 +0900 Subject: [PATCH 2/7] change ignore padding index number threshold to 250 --- alphafold2_pytorch/transformer.py | 1 + train_simple.py | 30 ++++++++++++++++++------------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py index 02cd6f3..e6ae360 100644 --- a/alphafold2_pytorch/transformer.py +++ b/alphafold2_pytorch/transformer.py @@ -38,6 +38,7 @@ def __init__(self, num_encoder_layers: int, num_decoder_layers: int, self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size) self.tgt_tok_emb = TokenEmbedding(src_vocab_size, emb_size) + # todo make mask work def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor, src_padding_mask: Tensor, tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor): diff --git a/train_simple.py b/train_simple.py index 87af236..bde8712 100644 --- a/train_simple.py +++ b/train_simple.py @@ -21,8 +21,9 @@ NUM_BATCHES = int(1e5) GRADIENT_ACCUMULATE_EVERY = 16 LEARNING_RATE = 3e-4 -IGNORE_INDEX = -100 -THRESHOLD_LENGTH = 50 +IGNORE_INDEX = 21 +# todo change protein sequence threshold length +THRESHOLD_LENGTH = 250 # set device @@ -83,7 +84,7 @@ def filter_dictionary_by_seq_length(raw_data, seq_length_threshold, portion): new_data["res"].append(res) if n_filtered_entries: print( - f"{total_entires - n_filtered_entries:.0f} out of {total_entires:.0f} ({(total_entires - n_filtered_entries) / total_entires:.1%})" + f"{portion}: {total_entires - n_filtered_entries:.0f} out of {total_entires:.0f} ({(total_entires - n_filtered_entries) / total_entires:.1%})" f" training set entries were included if sequence length <= {seq_length_threshold}") raw_data[portion] = new_data return raw_data @@ -107,11 +108,11 @@ def train_epoch(model, train_iter, optimizer): seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( DEVICE).bool() - seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=0) + seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) - mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) + mask = ~F.pad(mask, (0, THRESHOLD_LENGTH - l, 0, THRESHOLD_LENGTH - l), value=False) # discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX) src_padding_mask, tgt_padding_mask = create_mask(seq, seq) @@ -123,8 +124,6 @@ def train_epoch(model, train_iter, optimizer): tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) # loss - optimizer.zero_grad() - loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) loss.backward() @@ -145,7 +144,7 @@ def evaluate(model, val_iter): seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( DEVICE).bool() - seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=0) + seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) @@ -163,7 +162,6 @@ def evaluate(model, val_iter): # loss loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) - loss.backward() losses += loss.item() return losses / len(val_iter) @@ -174,7 +172,7 @@ def evaluate(model, val_iter): raw_data = scn.load( casp_version=12, thinning=30, - batch_size=1, + batch_size=100, dynamic_batching=False ) @@ -187,7 +185,7 @@ def evaluate(model, val_iter): data = prepare_dataloaders( filtered_raw_data, aggregate_model_input=True, - batch_size=1, + batch_size=100, num_workers=4, seq_as_onehot=None, collate_fn=None, @@ -205,7 +203,8 @@ def evaluate(model, val_iter): # dim_head=64 # ).to(DEVICE) -SRC_VOCAB_SIZE = 21 # number of amino acids +# +SRC_VOCAB_SIZE = 22 # number of amino acids + padding 21 TGT_VOCAB_SIZE = 3 # backbone torsion angle NUM_ENCODER_LAYERS = 3 NUM_DECODER_LAYERS = 3 @@ -230,6 +229,7 @@ def evaluate(model, val_iter): transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 ) +# todo checkpoint routine # training loop for epoch in range(1, NUM_EPOCHS + 1): start_time = time.time() @@ -245,6 +245,12 @@ def evaluate(model, val_iter): writer_train.flush() print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, " f"Epoch time = {(end_time - start_time):.3f}s")) + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'loss': train_loss, + }, "model.pt") print('train ended') writer_train.close() valid_count = 0 From 9ddcd9e119f4305e67fdcaeafc985634be7ccac2 Mon Sep 17 00:00:00 2001 From: superantichrist Date: Mon, 26 Apr 2021 17:04:25 +0900 Subject: [PATCH 3/7] add checkpoint restore routine add optimization scheduler --- alphafold2_pytorch/transformer.py | 6 +-- train_simple.py | 61 +++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py index e6ae360..8d082f2 100644 --- a/alphafold2_pytorch/transformer.py +++ b/alphafold2_pytorch/transformer.py @@ -25,13 +25,13 @@ class Seq2SeqTransformer(nn.Module): def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int, src_vocab_size: int, tgt_vocab_size: int, - dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1): + dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1, activation: str = "relu"): super(Seq2SeqTransformer, self).__init__() encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head, - dim_feedforward=dim_feedforward) + dim_feedforward=dim_feedforward, dropout=dropout, activation=activation) self.transformer_encoder = TransformerEncoder(encoder_layer, num_layers=num_encoder_layers) decoder_layer = TransformerDecoderLayer(d_model=emb_size, nhead=num_head, - dim_feedforward=dim_feedforward) + dim_feedforward=dim_feedforward, dropout=dropout, activation=activation) self.transformer_decoder = TransformerDecoder(decoder_layer, num_layers=num_decoder_layers) self.generator = nn.Linear(emb_size, tgt_vocab_size) diff --git a/train_simple.py b/train_simple.py index bde8712..89f1c97 100644 --- a/train_simple.py +++ b/train_simple.py @@ -13,6 +13,7 @@ from alphafold2_pytorch.utils import get_bucketed_distance_matrix from alphafold2_pytorch.transformer import Seq2SeqTransformer import time +import os # constants @@ -22,9 +23,20 @@ GRADIENT_ACCUMULATE_EVERY = 16 LEARNING_RATE = 3e-4 IGNORE_INDEX = 21 -# todo change protein sequence threshold length THRESHOLD_LENGTH = 250 +BATCH_SIZE = 100 +# transformer constants + +SRC_VOCAB_SIZE = 22 # number of amino acids + padding 21 +TGT_VOCAB_SIZE = 3 # backbone torsion angle +NUM_ENCODER_LAYERS = 3 +NUM_DECODER_LAYERS = 3 +EMB_SIZE = 512 +NUM_HEAD = 8 +FFN_HID_DIM = 512 + +MODEL_PATH = f"model_{THRESHOLD_LENGTH}_{NUM_ENCODER_LAYERS}_{NUM_DECODER_LAYERS}_{FFN_HID_DIM}.pt" # set device DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS @@ -172,12 +184,13 @@ def evaluate(model, val_iter): raw_data = scn.load( casp_version=12, thinning=30, - batch_size=100, + batch_size=BATCH_SIZE, dynamic_batching=False ) filtered_raw_data = filter_dictionary_by_seq_length(raw_data, THRESHOLD_LENGTH, "train") writer_train = SummaryWriter("runs/train") +writer_train_eval = SummaryWriter("runs/train_eval") writer_valids = [] for split in scn.utils.download.VALID_SPLITS: filtered_raw_data = filter_dictionary_by_seq_length(filtered_raw_data, THRESHOLD_LENGTH, f'{split}') @@ -185,7 +198,7 @@ def evaluate(model, val_iter): data = prepare_dataloaders( filtered_raw_data, aggregate_model_input=True, - batch_size=100, + batch_size=BATCH_SIZE, num_workers=4, seq_as_onehot=None, collate_fn=None, @@ -204,13 +217,6 @@ def evaluate(model, val_iter): # ).to(DEVICE) # -SRC_VOCAB_SIZE = 22 # number of amino acids + padding 21 -TGT_VOCAB_SIZE = 3 # backbone torsion angle -NUM_ENCODER_LAYERS = 3 -NUM_DECODER_LAYERS = 3 -EMB_SIZE = 512 -NUM_HEAD = 8 -FFN_HID_DIM = 512 transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS, emb_size=EMB_SIZE, src_vocab_size=SRC_VOCAB_SIZE, tgt_vocab_size=TGT_VOCAB_SIZE, dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD) @@ -228,29 +234,52 @@ def evaluate(model, val_iter): optimizer = torch.optim.Adam( transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 ) - -# todo checkpoint routine +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5) + +prev_epoch = 0 +if os.path.exists(MODEL_PATH): + checkpoint = torch.load(MODEL_PATH) + transformer.load_state_dict(checkpoint['model_state_dict']) + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + prev_epoch = checkpoint['epoch'] + loss = checkpoint['loss'] + print(f"restore checkpoint. Epoch: {prev_epoch}, loss: {loss:.3f}") # training loop -for epoch in range(1, NUM_EPOCHS + 1): +for epoch in range(prev_epoch + 1, NUM_EPOCHS + 1): start_time = time.time() train_loss = train_epoch(transformer, iter(data['train']), optimizer) end_time = time.time() +# train_eval_loss = evaluate(transformer, iter(data['train-eval'])) + print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, " # Train eval loss: {train_eval_loss:.3f}, " + f"Epoch time = {(end_time - start_time):.3f}s")) valid_count = 0 for split in scn.utils.download.VALID_SPLITS: val_loss = evaluate(transformer, iter(data[f'{split}'])) writer_valids[valid_count].add_scalar("loss", val_loss, epoch) writer_valids[valid_count].flush() + print(f"Epoch: {epoch}, {split} loss: {val_loss:.3f}") valid_count += 1 writer_train.add_scalar("loss", train_loss, epoch) writer_train.flush() - print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, " - f"Epoch time = {(end_time - start_time):.3f}s")) + # writer_train_eval.add_scalar("loss", train_eval_loss, epoch) + # writer_train_eval.flush() + scheduler.step(train_loss) + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + }, MODEL_PATH) torch.save({ 'epoch': epoch, 'model_state_dict': transformer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), 'loss': train_loss, - }, "model.pt") + }, f"model_{THRESHOLD_LENGTH}_{NUM_ENCODER_LAYERS}_{NUM_DECODER_LAYERS}_{FFN_HID_DIM}_{epoch}.pt") print('train ended') writer_train.close() valid_count = 0 From e5f48fad205d95ee5edc6147d662ee5e050fbf6b Mon Sep 17 00:00:00 2001 From: superantichrist Date: Mon, 10 May 2021 16:40:21 +0900 Subject: [PATCH 4/7] add positional encoding add parameter for various test --- .gitignore | 4 ++ alphafold2_pytorch/transformer.py | 19 ++++++--- alphafold2_pytorch/utils.py | 2 +- train_end2end.py | 7 +-- train_simple.py | 71 ++++++++++++++++++++----------- 5 files changed, 67 insertions(+), 36 deletions(-) diff --git a/.gitignore b/.gitignore index 75352a8..5489622 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,7 @@ dmypy.json # Pyre type checker .pyre/ + +.idea/ +runs/ +/model/ diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py index 8d082f2..9ebf6f0 100644 --- a/alphafold2_pytorch/transformer.py +++ b/alphafold2_pytorch/transformer.py @@ -25,7 +25,8 @@ class Seq2SeqTransformer(nn.Module): def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int, src_vocab_size: int, tgt_vocab_size: int, - dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1, activation: str = "relu"): + dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1, activation: str = "relu", + max_len: int = 5000): super(Seq2SeqTransformer, self).__init__() encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation) @@ -37,15 +38,21 @@ def __init__(self, num_encoder_layers: int, num_decoder_layers: int, self.generator = nn.Linear(emb_size, tgt_vocab_size) self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size) self.tgt_tok_emb = TokenEmbedding(src_vocab_size, emb_size) + self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout, maxlen=max_len) # todo make mask work def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor, src_padding_mask: Tensor, - tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor): - src_emb = self.src_tok_emb(src) - tgt_emb = self.tgt_tok_emb(trg) - memory = self.transformer_encoder(src_emb) - outs = self.transformer_decoder(tgt_emb, memory) + tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor, + use_padding_mask: bool = False): + src_emb = self.positional_encoding(self.src_tok_emb(src)) + tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg)) + if use_padding_mask: + memory = self.transformer_encoder(src_emb, src_key_padding_mask=src_padding_mask) + outs = self.transformer_decoder(tgt_emb, memory, tgt_key_padding_mask=tgt_padding_mask) + else: + memory = self.transformer_encoder(src_emb) + outs = self.transformer_decoder(tgt_emb, memory) return self.generator(outs) def encode(self, src: Tensor, src_mask: Tensor): diff --git a/alphafold2_pytorch/utils.py b/alphafold2_pytorch/utils.py index a81c329..8670a2c 100644 --- a/alphafold2_pytorch/utils.py +++ b/alphafold2_pytorch/utils.py @@ -288,7 +288,7 @@ def get_msa_embedd(msa, embedd_model, batch_converter, device = None): return token_reps -def get_esm_embedd(seq, embedd_model, batch_converter, msa_data=None): +def get_esm_embedd(seq, embedd_model, batch_converter, msa_data=None, device = None): """ Returns the ESM embeddings for a protein. Inputs: * seq: ( (b,) L,) tensor of ints (in sidechainnet int-char convention) diff --git a/train_end2end.py b/train_end2end.py index 43c96da..59bbca1 100644 --- a/train_end2end.py +++ b/train_end2end.py @@ -1,5 +1,6 @@ import torch from torch.optim import Adam +from torch import nn from torch.utils.data import DataLoader import torch.nn.functional as F from einops import rearrange @@ -7,7 +8,7 @@ # data import sidechainnet as scn -from sidechainnet.sequence.utils import VOCAB +# from sidechainnet.sequence.utils import VOCAB from sidechainnet.structure.build_info import NUM_COORDS_PER_RES # models @@ -108,11 +109,11 @@ def cycle(loader, cond = lambda x: True): # mask the atoms and backbone positions for each residue # sequence embedding (msa / esm / attn / or nothing) - msa, embedds = None + msa, embedds = None, None # get embedds if FEATURES == "esm": - embedds = get_esm_embedd(seq, embedd_model, batch_converter) + embedds = get_esm_embedd(seq, embedd_model, batch_converter, device=DEVICE) # get msa here elif FEATURES == "msa": pass diff --git a/train_simple.py b/train_simple.py index 89f1c97..4e6b0f0 100644 --- a/train_simple.py +++ b/train_simple.py @@ -23,8 +23,8 @@ GRADIENT_ACCUMULATE_EVERY = 16 LEARNING_RATE = 3e-4 IGNORE_INDEX = 21 -THRESHOLD_LENGTH = 250 -BATCH_SIZE = 100 +THRESHOLD_LENGTH = 50 +BATCH_SIZE = 128 # transformer constants @@ -32,11 +32,12 @@ TGT_VOCAB_SIZE = 3 # backbone torsion angle NUM_ENCODER_LAYERS = 3 NUM_DECODER_LAYERS = 3 -EMB_SIZE = 512 -NUM_HEAD = 8 -FFN_HID_DIM = 512 +EMB_SIZE = 256 +NUM_HEAD = 16 +FFN_HID_DIM = 128 +LOSS_WITHOUT_PADDING = True -MODEL_PATH = f"model_{THRESHOLD_LENGTH}_{NUM_ENCODER_LAYERS}_{NUM_DECODER_LAYERS}_{FFN_HID_DIM}.pt" +MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}.pt" # set device DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS @@ -122,7 +123,8 @@ def train_epoch(model, train_iter, optimizer): DEVICE).bool() seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) - angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) + if not LOSS_WITHOUT_PADDING: + angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) mask = ~F.pad(mask, (0, THRESHOLD_LENGTH - l, 0, THRESHOLD_LENGTH - l), value=False) @@ -135,8 +137,12 @@ def train_epoch(model, train_iter, optimizer): tgt_mask=mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + optimizer.zero_grad() # loss - loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + if LOSS_WITHOUT_PADDING: + loss = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + else: + loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) loss.backward() optimizer.step() @@ -158,7 +164,8 @@ def evaluate(model, val_iter): DEVICE).bool() seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) - angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) + if not LOSS_WITHOUT_PADDING: + angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) @@ -173,7 +180,10 @@ def evaluate(model, val_iter): # loss - loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + if LOSS_WITHOUT_PADDING: + loss = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + else: + loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) losses += loss.item() return losses / len(val_iter) @@ -190,11 +200,12 @@ def evaluate(model, val_iter): filtered_raw_data = filter_dictionary_by_seq_length(raw_data, THRESHOLD_LENGTH, "train") writer_train = SummaryWriter("runs/train") -writer_train_eval = SummaryWriter("runs/train_eval") -writer_valids = [] +# writer_train_eval = SummaryWriter("runs/train_eval") +writer_valid = SummaryWriter("runs/validation") +# writer_valids = [] for split in scn.utils.download.VALID_SPLITS: filtered_raw_data = filter_dictionary_by_seq_length(filtered_raw_data, THRESHOLD_LENGTH, f'{split}') - writer_valids.append(SummaryWriter(f"runs/{split}")) +# writer_valids.append(SummaryWriter(f"runs/{split}")) data = prepare_dataloaders( filtered_raw_data, aggregate_model_input=True, @@ -219,7 +230,7 @@ def evaluate(model, val_iter): # transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS, emb_size=EMB_SIZE, src_vocab_size=SRC_VOCAB_SIZE, tgt_vocab_size=TGT_VOCAB_SIZE, - dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD) + dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD, activation='gelu', max_len=5000) # optimizer @@ -234,7 +245,10 @@ def evaluate(model, val_iter): optimizer = torch.optim.Adam( transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 ) -scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5) +# optimizer = torch.optim.RMSprop( +# transformer.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False +# ) +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=20, verbose=True, factor=0.75) prev_epoch = 0 if os.path.exists(MODEL_PATH): @@ -252,20 +266,24 @@ def evaluate(model, val_iter): train_loss = train_epoch(transformer, iter(data['train']), optimizer) end_time = time.time() # train_eval_loss = evaluate(transformer, iter(data['train-eval'])) - print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, " # Train eval loss: {train_eval_loss:.3f}, " - f"Epoch time = {(end_time - start_time):.3f}s")) valid_count = 0 + val_loss_sum = 0 for split in scn.utils.download.VALID_SPLITS: val_loss = evaluate(transformer, iter(data[f'{split}'])) - writer_valids[valid_count].add_scalar("loss", val_loss, epoch) - writer_valids[valid_count].flush() - print(f"Epoch: {epoch}, {split} loss: {val_loss:.3f}") + # writer_valids[valid_count].add_scalar("loss", val_loss, epoch) + # writer_valids[valid_count].flush() + # print(f"Epoch: {epoch}, {split} loss: {val_loss:.3f}") valid_count += 1 + val_loss_sum += val_loss + print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, val loss: {val_loss_sum/valid_count:.3f}, " + f"Epoch time = {(end_time - start_time):.3f}s")) writer_train.add_scalar("loss", train_loss, epoch) writer_train.flush() + writer_valid.add_scalar("loss", val_loss_sum/valid_count, epoch) + writer_valid.flush() # writer_train_eval.add_scalar("loss", train_eval_loss, epoch) # writer_train_eval.flush() - scheduler.step(train_loss) + scheduler.step(val_loss_sum/valid_count) torch.save({ 'epoch': epoch, 'model_state_dict': transformer.state_dict(), @@ -279,10 +297,11 @@ def evaluate(model, val_iter): 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': train_loss, - }, f"model_{THRESHOLD_LENGTH}_{NUM_ENCODER_LAYERS}_{NUM_DECODER_LAYERS}_{FFN_HID_DIM}_{epoch}.pt") + }, f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}.pt") print('train ended') writer_train.close() -valid_count = 0 -for split in scn.utils.download.VALID_SPLITS: - writer_valids[valid_count].close() - valid_count += 1 +writer_valid.close() +# valid_count = 0 +# for split in scn.utils.download.VALID_SPLITS: +# writer_valids[valid_count].close() +# valid_count += 1 From 8f7acd708dfd7b502fcb11313d6aac4ca305c798 Mon Sep 17 00:00:00 2001 From: superantichrist Date: Thu, 13 May 2021 14:44:33 +0900 Subject: [PATCH 5/7] add best model logic save figure for phi, psi, omega angle mask applied for prediction output and angle data --- train_simple.py | 191 ++++++++++++++++++++++++++++++++++++------------ 1 file changed, 144 insertions(+), 47 deletions(-) diff --git a/train_simple.py b/train_simple.py index 4e6b0f0..14e672c 100644 --- a/train_simple.py +++ b/train_simple.py @@ -14,17 +14,18 @@ from alphafold2_pytorch.transformer import Seq2SeqTransformer import time import os +import matplotlib.pyplot as plt # constants DEVICE = None # defaults to cuda if available, else cpu -NUM_EPOCHS = int(1e3) +NUM_EPOCHS = int(3e5) NUM_BATCHES = int(1e5) GRADIENT_ACCUMULATE_EVERY = 16 LEARNING_RATE = 3e-4 IGNORE_INDEX = 21 THRESHOLD_LENGTH = 50 -BATCH_SIZE = 128 +BATCH_SIZE = 250 # transformer constants @@ -33,11 +34,12 @@ NUM_ENCODER_LAYERS = 3 NUM_DECODER_LAYERS = 3 EMB_SIZE = 256 -NUM_HEAD = 16 +NUM_HEAD = 8 FFN_HID_DIM = 128 -LOSS_WITHOUT_PADDING = True +LOSS_WITHOUT_PADDING = False MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}.pt" +BEST_MODEL_PATH = MODEL_PATH # set device DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS @@ -109,9 +111,12 @@ def create_mask(src, tgt): return src_padding_mask, tgt_padding_mask -def train_epoch(model, train_iter, optimizer): +def train_epoch(model, train_iter, optimizer_, epoch): model.train() losses = 0 + radian_diffs = torch.zeros(THRESHOLD_LENGTH*TGT_VOCAB_SIZE*BATCH_SIZE).to(DEVICE) + logits_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) + angs_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) for idx, (batch) in enumerate(train_iter): seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks @@ -126,38 +131,95 @@ def train_epoch(model, train_iter, optimizer): if not LOSS_WITHOUT_PADDING: angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) - mask = ~F.pad(mask, (0, THRESHOLD_LENGTH - l, 0, THRESHOLD_LENGTH - l), value=False) + mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) # discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX) src_padding_mask, tgt_padding_mask = create_mask(seq, seq) # predict - logits = transformer(seq, seq, src_mask=mask, - tgt_mask=mask, src_padding_mask=src_padding_mask, - tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + logits = model(seq, seq, src_mask=mask, + tgt_mask=mask, src_padding_mask=src_padding_mask, + tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + + optimizer_.zero_grad() + + mask1= mask.unsqueeze(2).expand(-1, -1, 3) + angs1 = torch.acos(torch.zeros(1)).item() * 4 * \ + (angs < -torch.acos(torch.zeros(1)).item() * 1.5) +\ + angs + + angs2 = mask1 * angs1 + logits2 = mask1 * logits + angs3 = angs2.reshape(-1, angs2.shape[-1]) + logits3 = logits2.reshape(-1, logits2.shape[-1]) - optimizer.zero_grad() # loss if LOSS_WITHOUT_PADDING: - loss = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + loss_ = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + diff = logits[:, :l, :].reshape(-1, logits.shape[-1]) - angs.reshape(-1, angs.shape[-1]) else: - loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) - loss.backward() - - optimizer.step() - losses += loss.item() + loss_ = loss_fn(logits3, angs3) + diff = logits3 - angs3 + radian_diff = torch.rad2deg(diff).reshape(-1) + radian_diffs += abs(radian_diff) + logits_avg += abs(torch.rad2deg(logits3)).reshape(-1) + angs_avg += abs(torch.rad2deg(angs3)).reshape(-1) + + # plt.plot(logits3.tolist(), label='logits') + if idx == 0 and epoch % 10 == 0: + plt.clf() + plt.plot(angs3[:, 0:1].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='phi') + plt.plot(logits3[:, 0:1].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='phi_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"./graph/train1_{epoch}_phi.png") + plt.clf() + plt.plot(angs3[:, 1:2].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='psi') + plt.plot(logits3[:, 1:2].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='psi_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"./graph/train1_{epoch}_psi.png") + plt.clf() + plt.plot(angs3[:, 2:3].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='omega') + plt.plot(logits3[:, 2:3].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='omega_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"./graph/train1_{epoch}_omega.png") + # plt.plot(diff.tolist()) + + + loss_.backward() + + # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) + optimizer_.step() + losses += loss_.item() + radian_diffs = radian_diffs / len(train_iter) + logits_avg = logits_avg / len(train_iter) + angs_avg = angs_avg / len(train_iter) + # diff_dict = {str(i): string for i, string in enumerate(radian_diffs.tolist())} + # writer_train.add_scalars("train", diff_dict, epoch) + if epoch % 10 == 0: + plt.clf() + plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH*TGT_VOCAB_SIZE, -1), 1).tolist(), label='diff') + plt.plot(torch.mean(logits_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='logit') + plt.plot(torch.mean(angs_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='ang') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"./graph/train_{epoch}.png") return losses / len(train_iter) def evaluate(model, val_iter): model.eval() losses = 0 + radian_diffs = None # torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) for idx, (batch) in (enumerate(val_iter)): seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks b, l, _ = seq.shape - + if radian_diffs is None: + radian_diffs = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) # prepare mask, labels seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( @@ -174,18 +236,31 @@ def evaluate(model, val_iter): # predict - logits = transformer(seq, seq, src_mask=mask, - tgt_mask=mask, src_padding_mask=src_padding_mask, - tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) + logits = model(seq, seq, src_mask=mask, + tgt_mask=mask, src_padding_mask=src_padding_mask, + tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) - # loss + angs_correction = torch.acos(torch.zeros(1)).item() * 4 * \ + (angs.reshape(-1, angs.shape[-1]) < -torch.acos(torch.zeros(1)).item() * 1.5) + \ + angs.reshape(-1, angs.shape[-1]) + # loss if LOSS_WITHOUT_PADDING: - loss = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + loss_ = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) + diff = logits[:, :l, :].reshape(-1, logits.shape[-1]) - angs.reshape(-1, angs.shape[-1]) else: - loss = loss_fn(logits.reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) - - losses += loss.item() + loss_ = loss_fn(logits.reshape(-1, logits.shape[-1]), angs_correction) + diff = logits.reshape(-1, logits.shape[-1]) - angs_correction + radian_diff = torch.rad2deg(diff).reshape(-1) + radian_diffs += abs(radian_diff) + + losses += loss_.item() + radian_diffs = radian_diffs / len(val_iter) + # diff_dict = {str(i): string for i, string in enumerate(radian_diffs.tolist())} + # writer_train.add_scalars("train", diff_dict, epoch) + # plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist()) + # plt.ylabel('angles') + # plt.savefig("valid.png") return losses / len(val_iter) @@ -248,24 +323,40 @@ def evaluate(model, val_iter): # optimizer = torch.optim.RMSprop( # transformer.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False # ) -scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=20, verbose=True, factor=0.75) +scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=2000, verbose=True, factor=0.75) prev_epoch = 0 -if os.path.exists(MODEL_PATH): - checkpoint = torch.load(MODEL_PATH) - transformer.load_state_dict(checkpoint['model_state_dict']) - optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - if 'scheduler_state_dict' in checkpoint: - scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - prev_epoch = checkpoint['epoch'] - loss = checkpoint['loss'] - print(f"restore checkpoint. Epoch: {prev_epoch}, loss: {loss:.3f}") + + +def restore_model(model_path, model, optimizer_): + prev_epoch_ = 0 + loss_ = 1e10 + valid_loss_ = 1e10 + if os.path.exists(model_path): + checkpoint = torch.load(model_path) + model.load_state_dict(checkpoint['model_state_dict']) + optimizer_.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + prev_epoch_ = checkpoint['epoch'] + loss_ = checkpoint['loss'] + if 'valid_loss' in checkpoint: + valid_loss_ = checkpoint['valid_loss'] + print(f"restore checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}") + return prev_epoch_, loss_, valid_loss_ + + +prev_epoch, loss, valid_loss = restore_model(MODEL_PATH, transformer, optimizer) # training loop +best_valid = valid_loss if valid_loss < 1e10 else 1e10 +restore_epoch = 10 for epoch in range(prev_epoch + 1, NUM_EPOCHS + 1): + if epoch % restore_epoch == 0: + restore_model(BEST_MODEL_PATH, transformer, optimizer) start_time = time.time() - train_loss = train_epoch(transformer, iter(data['train']), optimizer) + train_loss = train_epoch(transformer, iter(data['train']), optimizer, epoch) end_time = time.time() -# train_eval_loss = evaluate(transformer, iter(data['train-eval'])) + # train_eval_loss = evaluate(transformer, iter(data['train-eval'])) valid_count = 0 val_loss_sum = 0 for split in scn.utils.download.VALID_SPLITS: @@ -275,29 +366,35 @@ def evaluate(model, val_iter): # print(f"Epoch: {epoch}, {split} loss: {val_loss:.3f}") valid_count += 1 val_loss_sum += val_loss - print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, val loss: {val_loss_sum/valid_count:.3f}, " + print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, val loss: {val_loss_sum / valid_count:.3f}, " f"Epoch time = {(end_time - start_time):.3f}s")) writer_train.add_scalar("loss", train_loss, epoch) writer_train.flush() - writer_valid.add_scalar("loss", val_loss_sum/valid_count, epoch) + writer_valid.add_scalar("loss", val_loss_sum / valid_count, epoch) writer_valid.flush() # writer_train_eval.add_scalar("loss", train_eval_loss, epoch) # writer_train_eval.flush() - scheduler.step(val_loss_sum/valid_count) + scheduler.step(val_loss_sum / valid_count) torch.save({ 'epoch': epoch, 'model_state_dict': transformer.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'scheduler_state_dict': scheduler.state_dict(), 'loss': train_loss, + 'valid_loss': val_loss_sum / valid_count, }, MODEL_PATH) - torch.save({ - 'epoch': epoch, - 'model_state_dict': transformer.state_dict(), - 'optimizer_state_dict': optimizer.state_dict(), - 'scheduler_state_dict': scheduler.state_dict(), - 'loss': train_loss, - }, f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}.pt") + if val_loss_sum / valid_count < best_valid: + best_valid = val_loss_sum / valid_count + BEST_MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}_{best_valid:.3f}.pt" + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + 'valid_loss': best_valid, + }, BEST_MODEL_PATH) + print(f"new best checkpoint. Epoch: {epoch}, loss: {train_loss:.3f}, valid_loss: {best_valid:.3f}") print('train ended') writer_train.close() writer_valid.close() From 694a38ee1d08484133f9bb88d2f4974d8826626e Mon Sep 17 00:00:00 2001 From: superantichrist Date: Fri, 14 May 2021 17:29:59 +0900 Subject: [PATCH 6/7] calculate loss only for non masked area simulated annealing for learning rate --- alphafold2_pytorch/transformer.py | 2 +- setup.py | 3 +- train_simple.py | 158 +++++++++++++++++++++--------- 3 files changed, 114 insertions(+), 49 deletions(-) diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py index 9ebf6f0..0cb20b0 100644 --- a/alphafold2_pytorch/transformer.py +++ b/alphafold2_pytorch/transformer.py @@ -25,7 +25,7 @@ class Seq2SeqTransformer(nn.Module): def __init__(self, num_encoder_layers: int, num_decoder_layers: int, emb_size: int, src_vocab_size: int, tgt_vocab_size: int, - dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.1, activation: str = "relu", + dim_feedforward: int = 512, num_head: int = 8, dropout: float = 0.0, activation: str = "relu", max_len: int = 5000): super(Seq2SeqTransformer, self).__init__() encoder_layer = TransformerEncoderLayer(d_model=emb_size, nhead=num_head, diff --git a/setup.py b/setup.py index c757f9e..932a81f 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ 'transformers', 'tqdm', 'biopython', - 'tensorboard' + 'tensorboard', + 'matplotlib' ], setup_requires=[ 'pytest-runner', diff --git a/train_simple.py b/train_simple.py index 14e672c..c163913 100644 --- a/train_simple.py +++ b/train_simple.py @@ -22,10 +22,10 @@ NUM_EPOCHS = int(3e5) NUM_BATCHES = int(1e5) GRADIENT_ACCUMULATE_EVERY = 16 -LEARNING_RATE = 3e-4 +LEARNING_RATE = 1e-6 IGNORE_INDEX = 21 THRESHOLD_LENGTH = 50 -BATCH_SIZE = 250 +BATCH_SIZE = 100 # transformer constants @@ -34,18 +34,18 @@ NUM_ENCODER_LAYERS = 3 NUM_DECODER_LAYERS = 3 EMB_SIZE = 256 -NUM_HEAD = 8 +NUM_HEAD = 4 FFN_HID_DIM = 128 LOSS_WITHOUT_PADDING = False MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}.pt" -BEST_MODEL_PATH = MODEL_PATH +BEST_MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_best.pt" # set device DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS DEVICE = constants.DEVICE - +graph_interval = 1 # helpers @@ -114,7 +114,7 @@ def create_mask(src, tgt): def train_epoch(model, train_iter, optimizer_, epoch): model.train() losses = 0 - radian_diffs = torch.zeros(THRESHOLD_LENGTH*TGT_VOCAB_SIZE*BATCH_SIZE).to(DEVICE) + radian_diffs = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) logits_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) angs_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) for idx, (batch) in enumerate(train_iter): @@ -144,10 +144,10 @@ def train_epoch(model, train_iter, optimizer_, epoch): optimizer_.zero_grad() - mask1= mask.unsqueeze(2).expand(-1, -1, 3) + mask1 = mask.unsqueeze(2).expand(-1, -1, 3) angs1 = torch.acos(torch.zeros(1)).item() * 4 * \ - (angs < -torch.acos(torch.zeros(1)).item() * 1.5) +\ - angs + (angs < -torch.acos(torch.zeros(1)).item() * 1.5) + \ + angs angs2 = mask1 * angs1 logits2 = mask1 * logits @@ -159,7 +159,7 @@ def train_epoch(model, train_iter, optimizer_, epoch): loss_ = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) diff = logits[:, :l, :].reshape(-1, logits.shape[-1]) - angs.reshape(-1, angs.shape[-1]) else: - loss_ = loss_fn(logits3, angs3) + loss_ = loss_fn(torch.masked_select(logits, mask1), torch.masked_select(angs1, mask1)) diff = logits3 - angs3 radian_diff = torch.rad2deg(diff).reshape(-1) radian_diffs += abs(radian_diff) @@ -167,28 +167,28 @@ def train_epoch(model, train_iter, optimizer_, epoch): angs_avg += abs(torch.rad2deg(angs3)).reshape(-1) # plt.plot(logits3.tolist(), label='logits') - if idx == 0 and epoch % 10 == 0: + if idx == 0 and epoch % graph_interval == 0: + offset = torch.randint(0, b, (1,))*THRESHOLD_LENGTH plt.clf() - plt.plot(angs3[:, 0:1].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='phi') - plt.plot(logits3[:, 0:1].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='phi_logit') + plt.plot(angs3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi') + plt.plot(logits3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"./graph/train1_{epoch}_phi.png") + plt.savefig(f"graph/train1_{epoch}_phi.png") plt.clf() - plt.plot(angs3[:, 1:2].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='psi') - plt.plot(logits3[:, 1:2].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='psi_logit') + plt.plot(angs3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi') + plt.plot(logits3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"./graph/train1_{epoch}_psi.png") + plt.savefig(f"graph/train1_{epoch}_psi.png") plt.clf() - plt.plot(angs3[:, 2:3].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='omega') - plt.plot(logits3[:, 2:3].reshape(-1)[0:THRESHOLD_LENGTH].tolist(), label='omega_logit') + plt.plot(angs3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega') + plt.plot(logits3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"./graph/train1_{epoch}_omega.png") + plt.savefig(f"graph/train1_{epoch}_omega.png") # plt.plot(diff.tolist()) - loss_.backward() # torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) @@ -199,14 +199,14 @@ def train_epoch(model, train_iter, optimizer_, epoch): angs_avg = angs_avg / len(train_iter) # diff_dict = {str(i): string for i, string in enumerate(radian_diffs.tolist())} # writer_train.add_scalars("train", diff_dict, epoch) - if epoch % 10 == 0: + if epoch % graph_interval == 0: plt.clf() - plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH*TGT_VOCAB_SIZE, -1), 1).tolist(), label='diff') + plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='diff') plt.plot(torch.mean(logits_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='logit') plt.plot(torch.mean(angs_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='ang') plt.ylabel('angles') plt.legend() - plt.savefig(f"./graph/train_{epoch}.png") + plt.savefig(f"graph/train_{epoch}.png") return losses / len(train_iter) @@ -214,12 +214,16 @@ def evaluate(model, val_iter): model.eval() losses = 0 radian_diffs = None # torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) + logits_avg = None + angs_avg = None for idx, (batch) in (enumerate(val_iter)): seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks b, l, _ = seq.shape if radian_diffs is None: radian_diffs = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) + logits_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) + angs_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) # prepare mask, labels seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( @@ -240,27 +244,61 @@ def evaluate(model, val_iter): tgt_mask=mask, src_padding_mask=src_padding_mask, tgt_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask) - angs_correction = torch.acos(torch.zeros(1)).item() * 4 * \ - (angs.reshape(-1, angs.shape[-1]) < -torch.acos(torch.zeros(1)).item() * 1.5) + \ - angs.reshape(-1, angs.shape[-1]) + mask1 = mask.unsqueeze(2).expand(-1, -1, 3) + angs1 = torch.acos(torch.zeros(1)).item() * 4 * \ + (angs < -torch.acos(torch.zeros(1)).item() * 1.5) + \ + angs + + angs2 = mask1 * angs1 + logits2 = mask1 * logits + angs3 = angs2.reshape(-1, angs2.shape[-1]) + logits3 = logits2.reshape(-1, logits2.shape[-1]) # loss if LOSS_WITHOUT_PADDING: loss_ = loss_fn(logits[:, :l, :].reshape(-1, logits.shape[-1]), angs.reshape(-1, angs.shape[-1])) diff = logits[:, :l, :].reshape(-1, logits.shape[-1]) - angs.reshape(-1, angs.shape[-1]) else: - loss_ = loss_fn(logits.reshape(-1, logits.shape[-1]), angs_correction) - diff = logits.reshape(-1, logits.shape[-1]) - angs_correction + loss_ = loss_fn(torch.masked_select(logits, mask1), torch.masked_select(angs1, mask1)) + diff = logits3 - angs3 radian_diff = torch.rad2deg(diff).reshape(-1) radian_diffs += abs(radian_diff) + logits_avg += abs(torch.rad2deg(logits3)).reshape(-1) + angs_avg += abs(torch.rad2deg(angs3)).reshape(-1) + + if idx == 0 and epoch % graph_interval == 0: + offset = torch.randint(0, b, (1,)) * THRESHOLD_LENGTH + plt.clf() + plt.plot(angs3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi') + plt.plot(logits3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"graph/valid1_{epoch}_phi.png") + plt.clf() + plt.plot(angs3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi') + plt.plot(logits3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"graph/valid1_{epoch}_psi.png") + plt.clf() + plt.plot(angs3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega') + plt.plot(logits3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega_logit') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"graph/valid1_{epoch}_omega.png") losses += loss_.item() radian_diffs = radian_diffs / len(val_iter) - # diff_dict = {str(i): string for i, string in enumerate(radian_diffs.tolist())} - # writer_train.add_scalars("train", diff_dict, epoch) - # plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist()) - # plt.ylabel('angles') - # plt.savefig("valid.png") + logits_avg = logits_avg / len(val_iter) + angs_avg = angs_avg / len(val_iter) + if epoch % graph_interval == 0: + plt.clf() + plt.plot(torch.mean(radian_diffs.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='diff') + plt.plot(torch.mean(logits_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='logit') + plt.plot(torch.mean(angs_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='ang') + plt.ylabel('angles') + plt.legend() + plt.savefig(f"graph/valid_{epoch}.png") return losses / len(val_iter) @@ -277,6 +315,7 @@ def evaluate(model, val_iter): writer_train = SummaryWriter("runs/train") # writer_train_eval = SummaryWriter("runs/train_eval") writer_valid = SummaryWriter("runs/validation") +writer_best = SummaryWriter("runs/best") # writer_valids = [] for split in scn.utils.download.VALID_SPLITS: filtered_raw_data = filter_dictionary_by_seq_length(filtered_raw_data, THRESHOLD_LENGTH, f'{split}') @@ -318,7 +357,7 @@ def evaluate(model, val_iter): loss_fn = torch.nn.MSELoss() optimizer = torch.optim.Adam( - transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9 + transformer.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9 ) # optimizer = torch.optim.RMSprop( # transformer.parameters(), lr=0.01, alpha=0.99, eps=1e-08, weight_decay=0, momentum=0, centered=False @@ -328,31 +367,44 @@ def evaluate(model, val_iter): prev_epoch = 0 -def restore_model(model_path, model, optimizer_): +def restore_model(model_path, model, optimizer_, restore_optim=False, restore=True): prev_epoch_ = 0 loss_ = 1e10 valid_loss_ = 1e10 if os.path.exists(model_path): checkpoint = torch.load(model_path) - model.load_state_dict(checkpoint['model_state_dict']) - optimizer_.load_state_dict(checkpoint['optimizer_state_dict']) - if 'scheduler_state_dict' in checkpoint: - scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + if restore: + model.load_state_dict(checkpoint['model_state_dict']) + if restore_optim: + optimizer_.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint: + scheduler.load_state_dict(checkpoint['scheduler_state_dict']) prev_epoch_ = checkpoint['epoch'] loss_ = checkpoint['loss'] if 'valid_loss' in checkpoint: valid_loss_ = checkpoint['valid_loss'] - print(f"restore checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}") + if restore: + print(f"restore checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}") + else: + print(f"best checkpoint. Epoch: {prev_epoch_}, loss: {loss_:.3f}, valid_loss: {valid_loss_:.3f}") return prev_epoch_, loss_, valid_loss_ prev_epoch, loss, valid_loss = restore_model(MODEL_PATH, transformer, optimizer) -# training loop best_valid = valid_loss if valid_loss < 1e10 else 1e10 -restore_epoch = 10 +_, _, valid_restore = restore_model(BEST_MODEL_PATH, transformer, optimizer, restore=False) +if valid_restore < best_valid: + best_valid = valid_restore +# training loop +not_improved_count = 1 +restore_epoch = 11 +warmup_steps = 4000 for epoch in range(prev_epoch + 1, NUM_EPOCHS + 1): - if epoch % restore_epoch == 0: - restore_model(BEST_MODEL_PATH, transformer, optimizer) + if epoch > warmup_steps and not_improved_count % restore_epoch == 0: + not_improved_count = 1 + learning_rate = pow(EMB_SIZE, -0.5)*min(pow(epoch, -0.5), epoch*pow(warmup_steps, -1.5)) + for g in optimizer.param_groups: + g['lr'] = learning_rate start_time = time.time() train_loss = train_epoch(transformer, iter(data['train']), optimizer, epoch) end_time = time.time() @@ -367,7 +419,7 @@ def restore_model(model_path, model, optimizer_): valid_count += 1 val_loss_sum += val_loss print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, val loss: {val_loss_sum / valid_count:.3f}, " - f"Epoch time = {(end_time - start_time):.3f}s")) + f"Epoch time = {(end_time - start_time):.3f}s learning rate: {learning_rate}")) writer_train.add_scalar("loss", train_loss, epoch) writer_train.flush() writer_valid.add_scalar("loss", val_loss_sum / valid_count, epoch) @@ -385,7 +437,7 @@ def restore_model(model_path, model, optimizer_): }, MODEL_PATH) if val_loss_sum / valid_count < best_valid: best_valid = val_loss_sum / valid_count - BEST_MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}_{best_valid:.3f}.pt" + save_path = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_{epoch}_{best_valid:.3f}.pt" torch.save({ 'epoch': epoch, 'model_state_dict': transformer.state_dict(), @@ -394,7 +446,19 @@ def restore_model(model_path, model, optimizer_): 'loss': train_loss, 'valid_loss': best_valid, }, BEST_MODEL_PATH) + torch.save({ + 'epoch': epoch, + 'model_state_dict': transformer.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'scheduler_state_dict': scheduler.state_dict(), + 'loss': train_loss, + 'valid_loss': best_valid, + }, save_path) print(f"new best checkpoint. Epoch: {epoch}, loss: {train_loss:.3f}, valid_loss: {best_valid:.3f}") + writer_best.add_scalar("loss", best_valid, epoch) + writer_best.flush() + elif epoch > warmup_steps: + not_improved_count += 1 print('train ended') writer_train.close() writer_valid.close() From 21b02a7f06fdfaf522ae352baed3e5fc3c46f4e5 Mon Sep 17 00:00:00 2001 From: superantichrist Date: Tue, 1 Jun 2021 14:19:05 +0900 Subject: [PATCH 7/7] apply padding mask correctly --- alphafold2_pytorch/transformer.py | 2 +- train_simple.py | 84 ++++++++++++++++++------------- 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/alphafold2_pytorch/transformer.py b/alphafold2_pytorch/transformer.py index 0cb20b0..7d6fdca 100644 --- a/alphafold2_pytorch/transformer.py +++ b/alphafold2_pytorch/transformer.py @@ -44,7 +44,7 @@ def __init__(self, num_encoder_layers: int, num_decoder_layers: int, def forward(self, src: Tensor, trg: Tensor, src_mask: Tensor, tgt_mask: Tensor, src_padding_mask: Tensor, tgt_padding_mask: Tensor, memory_key_padding_mask: Tensor, - use_padding_mask: bool = False): + use_padding_mask: bool = True): src_emb = self.positional_encoding(self.src_tok_emb(src)) tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg)) if use_padding_mask: diff --git a/train_simple.py b/train_simple.py index c163913..1c88687 100644 --- a/train_simple.py +++ b/train_simple.py @@ -23,29 +23,37 @@ NUM_BATCHES = int(1e5) GRADIENT_ACCUMULATE_EVERY = 16 LEARNING_RATE = 1e-6 -IGNORE_INDEX = 21 -THRESHOLD_LENGTH = 50 +IGNORE_INDEX = 20 +THRESHOLD_LENGTH = 100 BATCH_SIZE = 100 # transformer constants -SRC_VOCAB_SIZE = 22 # number of amino acids + padding 21 +SRC_VOCAB_SIZE = 21 # number of amino acids + padding 20 TGT_VOCAB_SIZE = 3 # backbone torsion angle -NUM_ENCODER_LAYERS = 3 -NUM_DECODER_LAYERS = 3 -EMB_SIZE = 256 -NUM_HEAD = 4 -FFN_HID_DIM = 128 +NUM_ENCODER_LAYERS = 6 +NUM_DECODER_LAYERS = 6 +EMB_SIZE = 512 +NUM_HEAD = 8 +FFN_HID_DIM = 1024 LOSS_WITHOUT_PADDING = False +warmup_steps = 4000 +DROPOUT = 0.1 +MODEL_NAME = f"model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_dropout{DROPOUT}_warmup{warmup_steps}" MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}.pt" BEST_MODEL_PATH = f"model/model_t{THRESHOLD_LENGTH}_b{BATCH_SIZE}_e{NUM_ENCODER_LAYERS}_d{NUM_DECODER_LAYERS}_em{EMB_SIZE}_h{NUM_HEAD}_fh{FFN_HID_DIM}_best.pt" # set device +try: + os.makedirs(f'graph/{MODEL_NAME}/') +except: + print(f'graph/{MODEL_NAME}/ aleardy exist') + DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS DEVICE = constants.DEVICE -graph_interval = 1 +graph_interval = 5 # helpers @@ -118,20 +126,20 @@ def train_epoch(model, train_iter, optimizer_, epoch): logits_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) angs_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) for idx, (batch) in enumerate(train_iter): - seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks + seq, coords, angs, mask = batch.int_seqs, batch.crds, batch.angs, batch.msks - b, l, _ = seq.shape + b, l = seq.shape # prepare mask, labels - seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( + seq, coords, angs, mask = seq.to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( DEVICE).bool() - seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) + # seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) - if not LOSS_WITHOUT_PADDING: - angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) + # if not LOSS_WITHOUT_PADDING: + # angs = F.pad(angs, (0, 0, 0, THRESHOLD_LENGTH - l), value=0) # angs = rearrange(angs, 'b l c -> b (l c)', l=THRESHOLD_LENGTH) - mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) + # mask = F.pad(mask, (0, THRESHOLD_LENGTH - l), value=False) # discretized_distances = get_bucketed_distance_matrix(coords[:, :, 1], mask, DISTOGRAM_BUCKETS, IGNORE_INDEX) src_padding_mask, tgt_padding_mask = create_mask(seq, seq) @@ -161,9 +169,12 @@ def train_epoch(model, train_iter, optimizer_, epoch): else: loss_ = loss_fn(torch.masked_select(logits, mask1), torch.masked_select(angs1, mask1)) diff = logits3 - angs3 + diff = F.pad(diff, (0, 0, 0, (THRESHOLD_LENGTH - l)*BATCH_SIZE), value=0) radian_diff = torch.rad2deg(diff).reshape(-1) radian_diffs += abs(radian_diff) + logits3 = F.pad(logits3, (0, 0, 0, (THRESHOLD_LENGTH - l) * BATCH_SIZE), value=0) logits_avg += abs(torch.rad2deg(logits3)).reshape(-1) + angs3 = F.pad(angs3, (0, 0, 0, (THRESHOLD_LENGTH - l) * BATCH_SIZE), value=0) angs_avg += abs(torch.rad2deg(angs3)).reshape(-1) # plt.plot(logits3.tolist(), label='logits') @@ -174,19 +185,19 @@ def train_epoch(model, train_iter, optimizer_, epoch): plt.plot(logits3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/train1_{epoch}_phi.png") + plt.savefig(f"graph/{MODEL_NAME}/train1_{epoch}_phi.png") plt.clf() plt.plot(angs3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi') plt.plot(logits3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/train1_{epoch}_psi.png") + plt.savefig(f"graph/{MODEL_NAME}/train1_{epoch}_psi.png") plt.clf() plt.plot(angs3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega') plt.plot(logits3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/train1_{epoch}_omega.png") + plt.savefig(f"graph/{MODEL_NAME}/train1_{epoch}_omega.png") # plt.plot(diff.tolist()) loss_.backward() @@ -206,27 +217,27 @@ def train_epoch(model, train_iter, optimizer_, epoch): plt.plot(torch.mean(angs_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='ang') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/train_{epoch}.png") + plt.savefig(f"graph/{MODEL_NAME}/train_{epoch}.png") return losses / len(train_iter) -def evaluate(model, val_iter): +def evaluate(model, val_iter, split_): model.eval() losses = 0 radian_diffs = None # torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * BATCH_SIZE).to(DEVICE) logits_avg = None angs_avg = None for idx, (batch) in (enumerate(val_iter)): - seq, coords, angs, mask = batch.seqs, batch.crds, batch.angs, batch.msks + seq, coords, angs, mask = batch.int_seqs, batch.crds, batch.angs, batch.msks - b, l, _ = seq.shape + b, l = seq.shape if radian_diffs is None: radian_diffs = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) logits_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) angs_avg = torch.zeros(THRESHOLD_LENGTH * TGT_VOCAB_SIZE * b).to(DEVICE) # prepare mask, labels - seq, coords, angs, mask = seq.argmax(dim=-1).to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( + seq, coords, angs, mask = seq.to(DEVICE), coords.to(DEVICE), angs.to(DEVICE), mask.to( DEVICE).bool() seq = F.pad(seq, (0, THRESHOLD_LENGTH - l), value=IGNORE_INDEX) coords = rearrange(coords, 'b (l c) d -> b l c d', l=l) @@ -266,26 +277,26 @@ def evaluate(model, val_iter): logits_avg += abs(torch.rad2deg(logits3)).reshape(-1) angs_avg += abs(torch.rad2deg(angs3)).reshape(-1) - if idx == 0 and epoch % graph_interval == 0: + if epoch % graph_interval == 0: offset = torch.randint(0, b, (1,)) * THRESHOLD_LENGTH plt.clf() plt.plot(angs3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi') plt.plot(logits3[:, 0:1].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='phi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/valid1_{epoch}_phi.png") + plt.savefig(f"graph/{MODEL_NAME}/valid1_{epoch}_phi_{split_}_{idx}.png") plt.clf() plt.plot(angs3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi') plt.plot(logits3[:, 1:2].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='psi_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/valid1_{epoch}_psi.png") + plt.savefig(f"graph/{MODEL_NAME}/valid1_{epoch}_psi_{split_}_{idx}.png") plt.clf() plt.plot(angs3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega') plt.plot(logits3[:, 2:3].reshape(-1)[offset:offset+THRESHOLD_LENGTH].tolist(), label='omega_logit') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/valid1_{epoch}_omega.png") + plt.savefig(f"graph/{MODEL_NAME}/valid1_{epoch}_omega_{split_}_{idx}.png") losses += loss_.item() radian_diffs = radian_diffs / len(val_iter) @@ -298,7 +309,7 @@ def evaluate(model, val_iter): plt.plot(torch.mean(angs_avg.reshape(THRESHOLD_LENGTH * TGT_VOCAB_SIZE, -1), 1).tolist(), label='ang') plt.ylabel('angles') plt.legend() - plt.savefig(f"graph/valid_{epoch}.png") + plt.savefig(f"graph/{MODEL_NAME}/valid_{epoch}.png") return losses / len(val_iter) @@ -312,10 +323,10 @@ def evaluate(model, val_iter): ) filtered_raw_data = filter_dictionary_by_seq_length(raw_data, THRESHOLD_LENGTH, "train") -writer_train = SummaryWriter("runs/train") +writer_train = SummaryWriter(f"runs/{MODEL_NAME}/train") # writer_train_eval = SummaryWriter("runs/train_eval") -writer_valid = SummaryWriter("runs/validation") -writer_best = SummaryWriter("runs/best") +writer_valid = SummaryWriter(f"runs/{MODEL_NAME}/validation") +writer_best = SummaryWriter(f"runs/{MODEL_NAME}/best") # writer_valids = [] for split in scn.utils.download.VALID_SPLITS: filtered_raw_data = filter_dictionary_by_seq_length(filtered_raw_data, THRESHOLD_LENGTH, f'{split}') @@ -344,7 +355,8 @@ def evaluate(model, val_iter): # transformer = Seq2SeqTransformer(num_encoder_layers=NUM_ENCODER_LAYERS, num_decoder_layers=NUM_DECODER_LAYERS, emb_size=EMB_SIZE, src_vocab_size=SRC_VOCAB_SIZE, tgt_vocab_size=TGT_VOCAB_SIZE, - dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD, activation='gelu', max_len=5000) + dim_feedforward=FFN_HID_DIM, num_head=NUM_HEAD, activation='gelu', max_len=5000, + dropout=DROPOUT) # optimizer @@ -397,11 +409,11 @@ def restore_model(model_path, model, optimizer_, restore_optim=False, restore=Tr best_valid = valid_restore # training loop not_improved_count = 1 -restore_epoch = 11 -warmup_steps = 4000 +restore_epoch = 101 for epoch in range(prev_epoch + 1, NUM_EPOCHS + 1): if epoch > warmup_steps and not_improved_count % restore_epoch == 0: not_improved_count = 1 + restore_model(BEST_MODEL_PATH, transformer, optimizer) learning_rate = pow(EMB_SIZE, -0.5)*min(pow(epoch, -0.5), epoch*pow(warmup_steps, -1.5)) for g in optimizer.param_groups: g['lr'] = learning_rate @@ -412,7 +424,7 @@ def restore_model(model_path, model, optimizer_, restore_optim=False, restore=Tr valid_count = 0 val_loss_sum = 0 for split in scn.utils.download.VALID_SPLITS: - val_loss = evaluate(transformer, iter(data[f'{split}'])) + val_loss = evaluate(transformer, iter(data[f'{split}']), split) # writer_valids[valid_count].add_scalar("loss", val_loss, epoch) # writer_valids[valid_count].flush() # print(f"Epoch: {epoch}, {split} loss: {val_loss:.3f}")