
Gradual slowdown of training in bigger batch sizes #3050

Open
sidharthg-couture opened this issue Nov 11, 2024 · 7 comments


@sidharthg-couture

I am facing a very weird issue here.

Issue

  • Training slows down over time for batch sizes 64 and 128. For batch size 32, the speed stays fairly constant.
  • TensorBoard graph of epoch vs. time, colored by batch size (yellow = 32, purple = 64, green = 128):
[Image: TensorBoard plot of epoch vs. time for batch sizes 32, 64 and 128]

As you can see the training slows down for larger batches.

Experiment Details

  • I am training a Stella-400M model with MultipleNegativesRankingLoss (MNR loss)
  • Training runs on 4x A100 GPUs using DeepSpeed ZeRO Stage 2

Training Code

import logging  # needed for logging.basicConfig / logging.info below
import traceback
from datetime import datetime
# from accelerate.logging import get_logger
# logging = get_logger(__name__, log_level="INFO")


from datasets import Dataset, load_dataset
from datasets.utils.logging import disable_progress_bar, enable_progress_bar
enable_progress_bar()
# disable_progress_bar()

from peft import LoraConfig, get_peft_model
from peft import PeftConfig, PeftModel

from transformers import AutoModel, AutoTokenizer, AutoConfig, TrainerCallback, get_linear_schedule_with_warmup

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Transformer, Pooling
from sentence_transformers.evaluation import SentenceEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss, GISTEmbedLoss
from sentence_transformers.trainer import SentenceTransformerTrainer
from sentence_transformers.training_args import MultiDatasetBatchSamplers, SentenceTransformerTrainingArguments, BatchSamplers
from sentence_transformers.sampler import NoDuplicatesBatchSampler, ProportionalBatchSampler
from sentence_transformers import LoggingHandler, util
from transformers import BitsAndBytesConfig

import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import BatchSampler, ConcatDataset, SubsetRandomSampler

from typing import Any, Iterator
from collections import defaultdict
from itertools import accumulate, cycle


import pandas as pd
import numpy as np
import gc
import os
import random


RANDOM_STATE = 0

random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

#setting env variables
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_DATASETS_CACHE"] = "/data/hgfc_cache_new"
os.environ["HF_HOME"] = "/data/hgfc_cache_new"

TRIAL = False # control param for test runs

BASE_DIR = "/data/training_runs/stella_mnr_train_epoch3"

CACHE_DIR = "/data/hgfc_cache_new"

# if TRIAL: BASE_DIR = "/data/monil/stella_mnr_train_final"

if not os.path.exists(BASE_DIR): os.makedirs(BASE_DIR)


# Hyper Params
CONFIG = {
    "base_model_path": "/data/models/stella_checkpoint_2ep",

    "triplet_data_path": "/data/datasets/prompted_train_data_full_run/triplet_data",
    "duplet_data_path": "/data/datasets/prompted_train_data_full_run/duplet_data",
    "query_pair_path": "/data/datasets/prompted_train_data_full_run/query_pair_data",
    "product_descriptions_path": "/data/datasets/product_descriptions_with_keywords_14072024.pkl",
    "deepspeed_config_path": "/app/notebooks/monil/deepspeed.config",
    "keep_in_memory": True,
    
    "test_samples": 100 if TRIAL else 10000,
    "dev_samples": 100 if TRIAL else 5000,
    "batch_size": 32,
    "accumulation_step": 1,
    "eval_steps": 10 if TRIAL else 10000,
    "num_epochs":3,

    "output_dir":f"{BASE_DIR}/model_checkpoints_bf16",
    "logging_dir":f"{BASE_DIR}/train_logs/run_4x32_in_memory_fixed",
    "datasets_dir":f"{BASE_DIR}/val_test_datasets"
}


if not os.path.exists(CONFIG["logging_dir"]): os.makedirs(CONFIG["logging_dir"])
if not os.path.exists(CONFIG["output_dir"]): os.makedirs(CONFIG["output_dir"])
if not os.path.exists(CONFIG["datasets_dir"]): os.makedirs(CONFIG["datasets_dir"])


if not TRIAL:
    logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, filename=CONFIG["logging_dir"]+"/logs.txt")
else:
    logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)  

logging.info(CONFIG)

# custom evaluator class
class SearchEvaluator(SentenceEvaluator):
    """
    This class evaluates a SentenceTransformer model for the task of re-ranking.
    """
    def __init__(self, dev_dict, product_descriptions, num_sample_products = 50000):

        random.seed(RANDOM_STATE)
        torch.manual_seed(RANDOM_STATE)
        self.product_descriptions = product_descriptions
        
        self.all_products = list({value for values in dev_dict.values() for value in values["product_code"]}.union({value for values in dev_dict.values() for value in values["anti_product_code"]}))
        self.all_products = list(random.sample(self.all_products, min(len(self.all_products), num_sample_products)))
        
        self.dev_dict = {key:{"product_code":values["product_code"].intersection(set(self.all_products)),"anti_product_code":values["anti_product_code"].intersection(set(self.all_products))}  for key, values in dev_dict.items()}
        
        self.max_top_k = max(list({len(value["product_code"]) for value in self.dev_dict.values()}))
        self.csv_file = "SearchEvaluator_results.csv"
        print("Unique products", len(self.all_products))
        print("Max top k", self.max_top_k)
        

    def __call__(self, model, output_path: str = None, epoch: int = -1, steps: int = -1) -> float:
        
        corpus_embeddings = model.encode([self.product_descriptions[code] for code in self.all_products], convert_to_tensor=True, show_progress_bar=False)
        queries = list(self.dev_dict.keys())
        
        query_embedding = model.encode(queries, convert_to_tensor=True, show_progress_bar=False)
        
#         print("Embeddings created")
        prec = 0
        valid_queries = 0
        for i, result in enumerate(util.semantic_search(query_embedding, corpus_embeddings, top_k=self.max_top_k)):
            orig_products = set(self.dev_dict[queries[i]]["product_code"])
            num_products = len(orig_products)
            if num_products>0:
                # use j so the rank index does not visually shadow the outer query index i
                retrieved_products = {self.all_products[result[j]["corpus_id"]] for j in range(num_products)}
                prec+= (len(retrieved_products.intersection(orig_products))/num_products)
                valid_queries+=1

        prec = prec/valid_queries

        logging.info("precision on dev set at step {}: {}".format(steps, prec))

        del corpus_embeddings
        del query_embedding
        gc.collect()
        torch.cuda.empty_cache()
        
        return prec


def main():
    # Set the log level to INFO to get more information  

    print("entered main")
    
    logging.info(CONFIG)
    

    # loading data
    logging.info("Imports successful, loading data")
    
    product_descriptions = pd.read_pickle(CONFIG["product_descriptions_path"])
    logging.info("loaded {} unique product descriptions in catalogue".format(len(product_descriptions)))
  
    triplet_train_dataset = load_dataset("parquet", data_files=f"{CONFIG['triplet_data_path']}/train/*.parquet", cache_dir=CACHE_DIR, keep_in_memory = CONFIG["keep_in_memory"])
    logging.info(f"Triplet train dataset:  {triplet_train_dataset}")
    logging.info(triplet_train_dataset["train"][0])
      
    triplet_dev_dataset = load_dataset("parquet", data_files=f"{CONFIG['triplet_data_path']}/dev/*.parquet", cache_dir=CACHE_DIR, keep_in_memory = CONFIG["keep_in_memory"])
    logging.info(f"Triplet dev dataset:  {triplet_dev_dataset}")
    logging.info(triplet_dev_dataset["train"][0])
      
      
    grouped_dev_triplets = triplet_dev_dataset['train'].to_pandas().groupby("query").agg(set).reset_index()
    dev_dict = {row["query"]:{"product_code":row["selected_ids"], "anti_product_code":row["anti_selected_ids"]} for _,row in grouped_dev_triplets.iterrows()}
    logging.info("Triplet Data dev query count: {}".format(len(dev_dict)))
    logging.info(triplet_dev_dataset["train"][0])
  
    duplet_train_dataset = load_dataset("parquet", data_files=f"{CONFIG['duplet_data_path']}/train/*.parquet", cache_dir=CACHE_DIR, keep_in_memory = CONFIG["keep_in_memory"])
    logging.info(f"Duplet train dataset:  {duplet_train_dataset}")
    logging.info(duplet_train_dataset["train"][0])
  
    query_pair_train_dataset = load_dataset("parquet", data_files=f"{CONFIG['query_pair_path']}/train/*.parquet", cache_dir=CACHE_DIR, keep_in_memory = CONFIG["keep_in_memory"])
    logging.info(f"Query pair train dataset: {query_pair_train_dataset}")
    logging.info(query_pair_train_dataset["train"][0])
      

    final_train_dataset = {
          "product_triplets": triplet_train_dataset.select_columns(['query', 'positive', 'negative']).shuffle(seed=RANDOM_STATE).rename_column("query", "anchor")['train'],
          "product_duplets": duplet_train_dataset.select_columns(['query', 'positive']).shuffle(seed=RANDOM_STATE).rename_column("query", "anchor")['train'],
          "hinglish_duplets": query_pair_train_dataset.select_columns(['query', 'positive']).shuffle(seed=RANDOM_STATE).rename_column("query", "anchor")['train']
  
    }
  
    final_dev_dataset = {
          "product_triplets": triplet_dev_dataset.select_columns(['query', 'positive']).shuffle(seed=RANDOM_STATE).rename_column("query", "anchor")['train']
    }

    # model loading
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    logging.info("loading model for finetuning")
    device = f"cuda:{local_rank}"

    model = SentenceTransformer(CONFIG["base_model_path"], trust_remote_code=True, device=device)
    
    logging.info("Final Finetuning Model:\n{}".format(model))
    
    gc.collect()
    torch.cuda.empty_cache()


    model = model.to(f"cuda:{local_rank}")

    logging.info("triggering trainer")
    args = SentenceTransformerTrainingArguments(
        do_train=True,
        do_eval=True,
        # Required parameter:
        output_dir=CONFIG["output_dir"],
        overwrite_output_dir = False,
        # Optional training parameters:
        num_train_epochs=CONFIG["num_epochs"],
        per_device_train_batch_size=CONFIG["batch_size"],
        per_device_eval_batch_size=CONFIG["batch_size"],
        gradient_accumulation_steps=CONFIG["accumulation_step"],
        load_best_model_at_end=True,
        metric_for_best_model = "eval_evaluator",
        learning_rate=5e-6,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        max_grad_norm = 1.0,
        deepspeed = CONFIG["deepspeed_config_path"],
        disable_tqdm=False,
        logging_dir = CONFIG["logging_dir"],
        eval_strategy="steps",
        eval_steps=CONFIG["eval_steps"],
        save_strategy="steps",
        save_steps=CONFIG["eval_steps"],
        # save_steps = 10,
        save_total_limit=20,
        logging_steps=10,
        save_safetensors=False,
        eval_on_start=False,
        torch_empty_cache_steps=None,
        # ignore_data_skip=True
    )

    search_evaluator = SearchEvaluator(dev_dict, product_descriptions)

    train_loss = MultipleNegativesRankingLoss(model)

    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=final_train_dataset,
        eval_dataset=final_dev_dataset,
        loss=train_loss,
        evaluator=search_evaluator,
    )
    trainer.train()
    
    trainer.save_model()
    trainer.save_state()


if __name__ == "__main__":
    main()

DeepSpeed Config (standard config, but included in case it helps)

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 2
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto"
}

GPU utilisation for different batch sizes

  • We are using a 4x A100 cluster, where each GPU has 80 GB of VRAM.
  • Batch size 32 uses around 19% of VRAM on each GPU (about 15 GB)
  • Batch size 64 uses around 35% of VRAM on each GPU (about 28 GB)
  • Batch size 128 uses around 60% of VRAM on each GPU (about 48 GB)

VRAM usage stays fairly constant throughout training, fluctuating by only about 0.5-1%.

Conclusion

I can't work out why training slows down like this for larger batch sizes. With MNR loss, larger batch sizes are preferred because each anchor sees more in-batch negatives, and ideally training would also finish faster, if not for this issue.

I have spent quite some time trying to understand the issue, but have been unable to. Any help will be appreciated. Thanks!
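For context on why batch size matters here: MultipleNegativesRankingLoss treats the other positives in the batch as negatives, so a batch of B pairs gives each anchor B - 1 in-batch negatives. Below is a minimal pure-Python sketch of that idea, not the library implementation: cosine similarity scaled by 20 (the default `scale` of sentence-transformers' `MultipleNegativesRankingLoss`), then cross-entropy with the matching positive as the target. It omits the hard-negative column. It also shows why duplicates hurt: if the same pair appears twice in a batch, each copy's positive becomes a false negative for the other and the loss cannot approach zero, which is what `BatchSamplers.NO_DUPLICATES` guards against.

```python
import math


def cosine(u, v):
    # Cosine similarity between two equal-length, non-zero vectors.
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def mnr_loss(anchors, positives, scale=20.0):
    """In-batch MNR loss (sketch): row i's target is positives[i]; every other
    positive in the batch acts as a negative. Hard negatives are omitted."""
    total = 0.0
    for i, a in enumerate(anchors):
        logits = [scale * cosine(a, p) for p in positives]
        # Numerically stable log-sum-exp, then cross-entropy for the diagonal target.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[i]
    return total / len(anchors)


if __name__ == "__main__":
    distinct = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
    duplicated = [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
    print(mnr_loss(distinct, distinct))      # close to 0
    print(mnr_loss(duplicated, duplicated))  # about 2*ln(2)/3, i.e. 0.462
```

With the duplicated pair, each copy's logits tie with its twin, so those two rows contribute about ln(2) each no matter how well the model is trained.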

@sidharthg-couture
Author

I think the issue is that the BatchSamplers.NO_DUPLICATES sampler takes longer and longer to find a batch with no duplicates as training proceeds. So I tried increasing dataloader_num_workers to 32, but once training starts, the CPU cores don't seem to be used after the initial spikes.

The training also seems to be more or less as slow as before. How could I tackle this issue?

I am fairly certain that the increasing time spent creating each batch is causing the slowdown.
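To see why greedy no-duplicates batching can get slower over an epoch, and more so with bigger batches, here is a simplified, hypothetical re-implementation modelled on the sampler loop quoted later in this thread. It is not the library's code; `greedy_no_dup_batches` and `make_rows` are illustrative names and the data is synthetic. It reports how many rows each batch had to inspect. Two effects are visible in the logic: rows skipped because they conflicted stay at the front of `remaining` and are re-scanned for every following batch, and as the epoch goes on the remaining rows come from fewer distinct queries, so more of them conflict. A larger `batch_size` makes both worse because `batch_values` holds more texts. In the real sampler every inspected row is a `self.dataset[index]` lookup into a Hugging Face `Dataset`, which is likely much more expensive than the list access here.

```python
import random
from typing import Dict, Iterator, List, Tuple


def greedy_no_dup_batches(
    rows: List[Dict[str, str]], batch_size: int, seed: int = 0
) -> Iterator[Tuple[List[int], int]]:
    """Simplified greedy 'no duplicates' batching (hypothetical re-implementation).

    Yields (batch_indices, rows_inspected) so the scan cost per batch is visible.
    """
    order = list(range(len(rows)))
    random.Random(seed).shuffle(order)
    # Insertion-ordered dict: keeps the shuffled order and allows cheap deletion.
    remaining = dict.fromkeys(order)
    while remaining:
        batch_values: set = set()
        batch: List[int] = []
        inspected = 0
        # Each batch restarts the scan at the front of `remaining`, so rows that
        # were skipped earlier because they conflicted are inspected again.
        for idx in remaining:
            inspected += 1
            values = set(rows[idx].values())
            if values & batch_values:
                continue
            batch.append(idx)
            batch_values |= values
            if len(batch) == batch_size:
                break
        for idx in batch:
            del remaining[idx]
        yield batch, inspected


def make_rows(num_queries: int, per_query: int) -> List[Dict[str, str]]:
    # Hypothetical data: each query text repeats `per_query` times with distinct
    # positives, similar to triplet data where one query has many labelled products.
    return [
        {"anchor": f"q{q}", "positive": f"p{q}_{k}"}
        for q in range(num_queries)
        for k in range(per_query)
    ]


if __name__ == "__main__":
    rows = make_rows(num_queries=200, per_query=20)
    for bs in (32, 64, 128):
        stats = [inspected for _, inspected in greedy_no_dup_batches(rows, bs)]
        tenth = max(1, len(stats) // 10)
        print(
            f"batch_size={bs}: batches={len(stats)}, "
            f"mean rows inspected, first 10% of batches: {sum(stats[:tenth]) / tenth:.0f}, "
            f"last 10%: {sum(stats[-tenth:]) / tenth:.0f}"
        )
```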

@tomaarsen
Collaborator

Hmm. It is indeed possible that the NoDuplicatesBatchSampler introduces a larger overhead with larger batches. I don't see a clear path for improving this, however. I'm also not sure whether increasing dataloader_num_workers helps with the sampler; it might still be just one CPU core that takes care of the batch sampling.

  • Tom Aarsen

@sidharthg-couture
Author


Yeah, it seems that increasing the number of dataloader workers does not affect performance. And I have mostly confirmed that NoDuplicatesBatchSampler is causing the issue, since with a standard batch sampler (BatchSamplers.BATCH_SAMPLER) the epoch/time graph is a straight line.

Is there no possible fix for this?
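One way to confirm that the batch sampler by itself is the bottleneck, without involving the model or DeepSpeed, is to iterate it on its own and time every batch. The helpers below are a generic sketch that works with any iterable of batches; `time_batches` and `summarize` are hypothetical names, not part of sentence-transformers.

```python
import time
from typing import Iterable, List


def time_batches(batch_iter: Iterable, max_batches: int = 1000) -> List[float]:
    """Wall-clock seconds taken to produce each batch from any iterable
    (for example a batch sampler), with no model in the loop."""
    durations: List[float] = []
    it = iter(batch_iter)
    for _ in range(max_batches):
        start = time.perf_counter()
        try:
            next(it)
        except StopIteration:
            break
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations: List[float], buckets: int = 5) -> List[float]:
    # Mean time per batch over consecutive chunks of about len/buckets batches,
    # so an upward trend across the epoch is easy to spot.
    n = max(1, len(durations) // buckets)
    return [
        sum(durations[i:i + n]) / len(durations[i:i + n])
        for i in range(0, len(durations), n)
    ]
```

For example, you could pass a `NoDuplicatesBatchSampler` built over one of the training datasets (in recent sentence-transformers versions its constructor takes `dataset`, `batch_size` and `drop_last`; check your installed version) and compare the chunk means with those of a plain batch sampler over the same data. If the means rise steadily for the no-duplicates sampler but stay flat for the plain one, that matches the epoch/time graphs in the issue.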

@tomaarsen
Collaborator

The fix would be to speed up the batch sampler, but I don't know if there's room for improvement. Perhaps it's faster to hash each text and compare based on that rather than doing set overlap with strings:

# Excerpt from NoDuplicatesBatchSampler.__iter__
while remaining_indices:
    batch_values = set()
    batch_indices = []
    for index in remaining_indices:
        sample_values = {
            value
            for key, value in self.dataset[index].items()
            if not key.endswith("_prompt_length") and key != "dataset_name"
        }
        if sample_values & batch_values:
            continue
        # ... (rest of the loop omitted)

Or perhaps set overlap to begin with is slower than e.g. doing set membership a few times over. I'm not sure.

  • Tom Aarsen

@sidharthg-couture
Author

Okay, I will try looking into this. Also, the bottleneck could be avoided by having fewer duplicates in the training data, right?

@tomaarsen
Collaborator

I appreciate it!
And yes, if there are few or no duplicates, there's no need for the NoDuplicatesBatchSampler and training will be faster. If you have no duplicates but still use the NoDuplicatesBatchSampler, I assume it's still slower than training without that batch sampler.

  • Tom Aarsen
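Before deciding whether `BatchSamplers.NO_DUPLICATES` is needed, it can help to measure how many rows actually share a text with another row. A minimal sketch, assuming rows are dicts mapping column names to texts (which is also what iterating a Hugging Face `Dataset` yields); `duplicate_report` is a hypothetical name.

```python
from collections import Counter
from typing import Dict, Iterable, Sequence


def duplicate_report(rows: Iterable[Dict[str, str]], columns: Sequence[str]) -> Dict[str, float]:
    """Count rows that share at least one text (in `columns`) with another row."""
    rows = list(rows)
    counts: Counter = Counter()
    for row in rows:
        # A text repeated within a single row is counted once, so only
        # cross-row sharing is flagged.
        counts.update({row[c] for c in columns})
    shared = sum(1 for row in rows if any(counts[row[c]] > 1 for c in columns))
    return {
        "rows": len(rows),
        "rows_with_shared_text": shared,
        "fraction": shared / len(rows) if rows else 0.0,
    }
```

For instance, `duplicate_report(final_train_dataset["product_triplets"], ["anchor", "positive", "negative"])` would load that split into a Python list and report the share of rows with repeated texts. A high fraction suggests the no-duplicates sampler will have to skip many rows per batch; a fraction near zero suggests the plain `BatchSamplers.BATCH_SAMPLER` would lose little.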

@sidharthg-couture
Author


Yeah, that makes sense. Thank you for the replies.
