Commit
fix data changes
delfosseaurelien committed Sep 17, 2021
2 parents 5efa4e7 + f36deb0 commit cb54bbb
Showing 6 changed files with 200 additions and 151 deletions.
214 changes: 161 additions & 53 deletions biotransformers/lightning_utils/data.py
@@ -1,11 +1,14 @@
import functools
import math
import random
from collections import OrderedDict
from typing import Callable, List, Sequence, Tuple
from typing import Callable, List, Optional, Sequence, Tuple

import numpy as np
import torch
import torch.distributed as dist
from esm.data import BatchConverter
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, Sampler


@@ -187,6 +190,7 @@ def get_batch_indices(
sequence_strs: List[str],
toks_per_batch: int,
crop_sizes: Tuple[int, int] = (600, 1200),
seed: int = 0,
) -> List[List[List[Tuple[int, int]]]]:
"""
This sampler aims to create batches that do not contain a fixed number of sequences
@@ -208,31 +212,37 @@ def get_batch_indices(
Args:
sequence_strs: list of strings
toks_per_batch: maximum number of tokens per batch
crop_sizes: min and max sequence lengths when cropping
toks_per_batch (int): Maximum number of tokens per batch
extra_toks_per_seq (int, optional): Number of extra tokens added per sequence. Defaults to 0.
crop_sizes (Tuple[int, int]): min and max sequence lengths when cropping
seed (int): seed for the random generator
Returns:
List: List of batch indices and lengths
"""
min_size, max_size = crop_sizes
buffer_type = List[Tuple[int, int]]

def crop_length(length: int) -> int:
crop_size = random.randint(min_size, max_size) - 2
rand_generator = random.Random(seed)

def crop_length(length: int, random_generator: random.Random) -> int:
crop_size = random_generator.randint(min_size, max_size) - 2
if length > crop_size:
return crop_size
else:
return length

sizes = [(crop_length(len(s)), i) for i, s in enumerate(sequence_strs)]
sizes = [
(crop_length(len(s), rand_generator), i) for i, s in enumerate(sequence_strs)
]
min_length, max_length = min([t[0] for t in sizes]), max([t[0] for t in sizes])

# if there is a large gap between min and max size, sort the list
if min_length < 0.8 * max_length:
sizes.sort()
# otherwise shuffle it
else:
random.shuffle(sizes)
rand_generator.shuffle(sizes)

batches: List[List[buffer_type]] = []
buffer: buffer_type = []
@@ -257,7 +267,7 @@ def _flush_current_buf():

_flush_current_buf()

random.shuffle(batches)
rand_generator.shuffle(batches)
return batches
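
The hunk above threads an explicit seed through get_batch_indices so that batching is reproducible. A minimal sketch of calling it directly (the sequences and parameter values are illustrative only; the import path follows the file shown in this diff):

    from biotransformers.lightning_utils.data import get_batch_indices

    # Toy sequences of very different lengths (placeholders).
    sequences = ["MKTAYIAKQR", "MKV" * 300, "GSHM" * 50]

    batches = get_batch_indices(
        sequence_strs=sequences,
        toks_per_batch=1024,
        crop_sizes=(600, 1200),
        seed=0,
    )

    # Per the return annotation, each element is a nested list of
    # (sequence_index, crop_length) tuples; the same seed yields the same batches.
    print(len(batches), batches[0])

Seeding matters here because the distributed sampler below regenerates the batch list on every epoch from seed + epoch, so all ranks agree on the same batches without communicating.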


@@ -297,6 +307,98 @@ def __iter__(self):
)


class DistributedBatchWithConstantNumberTokensSampler(Sampler):
"""
Sampler that returns batches of sequence indices from the dataset so that each
batch contains a roughly constant number of tokens rather than a fixed number of
sequences. Because sequences may be cropped dynamically when sampling, the sampler
also returns, alongside the indices, the desired cropping lengths to inform the
dataloader. This version of the sampler is distributed, to be used with the DDP
accelerator.
"""

def __init__(
self,
sequence_strs: List[str],
toks_per_batch: int,
crop_sizes: Tuple[int, int] = (512, 1024),
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
Sampler.__init__(self, data_source=None)

# Replicate Torch Distributed Sampler logic
if num_replicas is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError("Requires distributed package to be available")
rank = dist.get_rank()
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)

self._num_replicas = num_replicas
self._rank = rank
self._epoch = 0
self._seed = seed

self._sequence_strs = sequence_strs
self._toks_per_batch = toks_per_batch
self._crop_sizes = crop_sizes
self._init_batches = get_batch_indices(
sequence_strs=sequence_strs,
toks_per_batch=toks_per_batch,
crop_sizes=crop_sizes,
seed=self._seed + self._epoch,
)
self._num_samples = math.ceil(len(self._init_batches) / self._num_replicas)
self._total_size = self._num_samples * self._num_replicas

def __len__(self) -> int:
return self._num_samples

def set_epoch(self, epoch: int) -> None:
self._epoch = epoch

def __iter__(self):

# generate batches with a constant number of tokens
batches = get_batch_indices(
sequence_strs=self._sequence_strs,
toks_per_batch=self._toks_per_batch,
crop_sizes=self._crop_sizes,
seed=self._seed + self._epoch,
)

# shuffle the indices
rng = np.random.default_rng(seed=self._seed + self._epoch)
indices = list(rng.permutation(len(batches)))

# add extra samples to make it evenly divisible
padding_size = self._total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
assert len(indices) == self._total_size

# subsample (to get batches for this worker)
indices = indices[self._rank : self._total_size : self._num_replicas]
assert len(indices) == self._num_samples

# get corresponding batches
batches = [batches[i] for i in indices]

# return iterator
yield from batches
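
A quick, hypothetical sanity check of the sampler outside of an actual DDP run; passing num_replicas and rank explicitly skips the torch.distributed lookups in __init__, and the sequence list is a placeholder:

    from biotransformers.lightning_utils.data import (
        DistributedBatchWithConstantNumberTokensSampler,
    )

    sequences = ["MKTAYIAKQR" * 20] * 100  # placeholder data

    sampler = DistributedBatchWithConstantNumberTokensSampler(
        sequence_strs=sequences,
        toks_per_batch=1024,
        crop_sizes=(512, 1024),
        num_replicas=2,  # pretend world size
        rank=0,          # this worker's rank
        seed=0,
    )
    sampler.set_epoch(1)         # batches are regenerated from seed + epoch
    print(len(sampler))          # ceil(total_batches / num_replicas)
    batches_for_rank_0 = list(sampler)

Because every rank builds the same padded, permuted batch list from the same seed + epoch, the rank-strided slice in __iter__ hands out disjoint batches to each worker without any communication.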


class BatchWithConstantNumberTokensDataset(Dataset):
"""
Dataset class designed to work in tandem with the BatchWithConstantNumberTokensSampler.
@@ -319,52 +421,58 @@ def __getitem__(self, sampler_out) -> List[str]:
return sequences


def create_dataloader(
sequences: List[str],
alphabet: AlphabetDataLoader,
masking_ratio: float,
masking_prob: float,
random_token_prob: float,
num_workers: int,
toks_per_batch: int,
crop_sizes: Tuple[int, int] = (512, 1024),
) -> DataLoader:
"""Create the PyTorch Dataloader.
class BatchWithConstantNumberTokensDataModule(LightningDataModule):
def __init__(
self,
train_sequences: List[str],
validation_sequences: List[str],
alphabet: AlphabetDataLoader,
masking_ratio: float,
masking_prob: float,
random_token_prob: float,
num_workers: int,
toks_per_batch: int,
crop_sizes: Tuple[int, int] = (512, 1024),
):
LightningDataModule.__init__(self)
self._train_sequences = train_sequences
self._validation_sequences = validation_sequences
self._alphabet = alphabet
self._masking_ratio = masking_ratio
self._masking_prob = masking_prob
self._random_token_prob = random_token_prob
self._num_workers = num_workers
self._toks_per_batch = toks_per_batch
self._crop_sizes = crop_sizes

Args:
sequences: list of sequences
alphabet: facebook alphabet.
filter_len: whether to filter data by sequence length.
num_workers: num of parallel data samplers
masking_ratio: ratio of tokens to be masked.
masking_prob: probability that the chosen token is replaced with a mask token.
random_token_prob: probability that the chosen token is replaced with a random token.
toks_per_batch: number of tokens per batch
crop_sizes: range of lengths used to dynamically crop sequences when sampling them
def _get_dataloader(self, sequences: List[str]) -> DataLoader:
dataset = BatchWithConstantNumberTokensDataset(sequences)
batch_sampler = DistributedBatchWithConstantNumberTokensSampler(
sequence_strs=sequences,
toks_per_batch=self._toks_per_batch,
crop_sizes=self._crop_sizes,
)

Returns:
torch DataLoader
"""
loader = DataLoader(
dataset,
num_workers=self._num_workers,
collate_fn=functools.partial(
collate_fn,
tokenizer=self._alphabet.tokenizer(),
alphabet=self._alphabet,
masking_ratio=self._masking_ratio,
masking_prob=self._masking_prob,
random_token_prob=self._random_token_prob,
),
pin_memory=True,
worker_init_fn=worker_init_fn,
batch_sampler=batch_sampler,
sampler=None,
)
return loader

dataset = BatchWithConstantNumberTokensDataset(sequences)
batch_sampler = BatchWithConstantNumberTokensSampler(
sequence_strs=sequences, toks_per_batch=toks_per_batch, crop_sizes=crop_sizes
)
def train_dataloader(self):
return self._get_dataloader(self._train_sequences)

loader = DataLoader(
dataset,
num_workers=num_workers,
collate_fn=functools.partial(
collate_fn,
tokenizer=alphabet.tokenizer(),
alphabet=alphabet,
masking_ratio=masking_ratio,
masking_prob=masking_prob,
random_token_prob=random_token_prob,
),
pin_memory=True,
worker_init_fn=worker_init_fn,
batch_sampler=batch_sampler,
sampler=None,
)
return loader
def val_dataloader(self):
return self._get_dataloader(self._validation_sequences)
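
The new LightningDataModule wires the dataset, the distributed sampler, and the masking collate function together. A hedged usage sketch (the model, sequence lists, and hyper-parameter values are placeholders, and the alphabet is whatever AlphabetDataLoader the chosen wrapper exposes). Note that _get_dataloader builds a DistributedBatchWithConstantNumberTokensSampler without explicit num_replicas/rank, so this assumes the DDP process group is initialized by the Trainer:

    import pytorch_lightning as pl
    from biotransformers.lightning_utils.data import BatchWithConstantNumberTokensDataModule

    datamodule = BatchWithConstantNumberTokensDataModule(
        train_sequences=train_seqs,        # list of protein strings (placeholder)
        validation_sequences=val_seqs,     # placeholder
        alphabet=alphabet,                 # AlphabetDataLoader from the model wrapper
        masking_ratio=0.025,
        masking_prob=0.8,
        random_token_prob=0.1,
        num_workers=4,
        toks_per_batch=2048,
        crop_sizes=(512, 1024),
    )

    # Flags as in Lightning 1.x (contemporary with this commit); newer
    # versions use strategy="ddp" instead of accelerator="ddp".
    trainer = pl.Trainer(gpus=2, accelerator="ddp", max_epochs=10)
    trainer.fit(lightning_module, datamodule=datamodule)  # lightning_module: placeholder LightningModule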
23 changes: 5 additions & 18 deletions biotransformers/wrappers/esm_wrappers.py
@@ -7,10 +7,7 @@

import esm
import torch
from biotransformers.lightning_utils.data import (
AlphabetDataLoader,
convert_ckpt_to_statedict,
)
from biotransformers.lightning_utils.data import AlphabetDataLoader
from biotransformers.utils.constant import DEFAULT_ESM_MODEL, ESM_LIST
from biotransformers.utils.logger import logger # noqa
from biotransformers.utils.utils import _generate_chunks, _get_num_batch_iter
@@ -49,6 +46,10 @@ def model(self) -> torch.nn.Module:
"""Return torch model."""
return self._model

def set_model(self, model: torch.nn.Module):
"""Set torch model."""
self._model = model.to(self._model.device)

@property
def clean_model_id(self) -> str:
"""Clean model ID (in case the model directory is not)"""
@@ -118,20 +119,6 @@ def process_sequences_and_tokens(
}
return encoded_inputs

def _load_model(self, path_model: str, map_location=None):
"""Load model."""
if path_model.endswith(".pt"):
loaded_model = torch.load(path_model)
elif path_model.endswith(".ckpt"):
loaded_model = convert_ckpt_to_statedict(
torch.load(path_model)["state_dict"]
)
else:
raise ValueError("Expecting a .pt or .ckpt file")
self._model.load_state_dict(loaded_model, map_location)
self._model.eval()
log.info("Load model %s" % path_model)

def model_pass(
self,
model_inputs: Dict[str, torch.Tensor],
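This diff drops the wrapper-level _load_model helper (and its convert_ckpt_to_statedict import) in favor of a plain set_model setter, so restoring weights now happens outside the wrapper. A hypothetical sketch of the new flow, assuming convert_ckpt_to_statedict is still available from lightning_utils.data and that wrapper is an already constructed ESM/Rostlab wrapper instance:

    import torch
    from biotransformers.lightning_utils.data import convert_ckpt_to_statedict

    # wrapper: an already built ESM/Rostlab wrapper exposing .model and .set_model (placeholder)
    model = wrapper.model

    state_dict = torch.load("finetuned.pt")  # placeholder path
    # For a Lightning checkpoint instead:
    # state_dict = convert_ckpt_to_statedict(torch.load("finetuned.ckpt")["state_dict"])

    model.load_state_dict(state_dict)
    model.eval()
    wrapper.set_model(model)  # set_model moves the model back onto the wrapper's device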
4 changes: 2 additions & 2 deletions biotransformers/wrappers/language_model.py
@@ -99,8 +99,8 @@ def model(self) -> torch.nn.Module:
pass

@abstractmethod
def _load_model(self, path: str):
"""Load model."""
def set_model(self, model: torch.nn.Module):
"""Set torch model."""
pass

@abstractmethod
23 changes: 5 additions & 18 deletions biotransformers/wrappers/rostlab_wrapper.py
@@ -9,10 +9,7 @@

import torch
import copy
from biotransformers.lightning_utils.data import (
AlphabetDataLoader,
convert_ckpt_to_statedict,
)
from biotransformers.lightning_utils.data import AlphabetDataLoader
from biotransformers.utils.constant import DEFAULT_ROSTLAB_MODEL, ROSTLAB_LIST
from biotransformers.utils.logger import logger # noqa
from biotransformers.utils.utils import _generate_chunks, _get_num_batch_iter
@@ -52,6 +49,10 @@ def model(self) -> torch.nn.Module:
"""Return torch model."""
return self._model

def set_model(self, model: torch.nn.Module):
"""Set torch model."""
self._model = model.to(self._model.device)

@property
def clean_model_id(self) -> str:
"""Clean model ID (in case the model directory is not)"""
@@ -102,20 +103,6 @@ def embeddings_size(self) -> int:
"""Returns size of the embeddings"""
return self.hidden_size

def _load_model(self, path_model: str, map_location=None):
"""Load model."""
if path_model.endswith(".pt"):
loaded_model = torch.load(path_model)
elif path_model.endswith(".ckpt"):
loaded_model = convert_ckpt_to_statedict(
torch.load(path_model)["state_dict"]
)
else:
raise ValueError("Expecting a .pt or .ckpt file")
self._model.load_state_dict(loaded_model, map_location)
self._model.eval()
log.info("Load model %s" % path_model)

def process_sequences_and_tokens(
self,
sequences_list: List[str],