
update train_sts_seed_optimization with SentenceTransformerTrainer #3092

Merged: 4 commits, Dec 2, 2024
Changes from 2 commits
95 changes: 49 additions & 46 deletions examples/training/data_augmentation/train_sts_seed_optimization.py
@@ -23,34 +23,25 @@
python train_sts_seed_optimization.py bert-base-uncased 10 0.3
"""

-import csv
-import gzip
import logging
import math
-import os
import random
import sys

import numpy as np
import torch
-from torch.utils.data import DataLoader
+from datasets import load_dataset

-from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
-from sentence_transformers.readers import InputExample
+from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments

#### Just some code to print debug information to stdout
logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
)
#### /print debug information to stdout


-# Check if dataset exists. If not, download and extract it
-sts_dataset_path = "datasets/stsbenchmark.tsv.gz"
-
-if not os.path.exists(sts_dataset_path):
-    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)


# You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
@@ -85,49 +76,61 @@

model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

-# Convert the dataset to a DataLoader ready for training
-logging.info("Read STSbenchmark train dataset")
-
-train_samples = []
-dev_samples = []
-test_samples = []
-with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
-
-        if row["split"] == "dev":
-            dev_samples.append(inp_example)
-        elif row["split"] == "test":
-            test_samples.append(inp_example)
-        else:
-            train_samples.append(inp_example)
-
-train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
+# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
+train_dataset = load_dataset("sentence-transformers/stsb", split="train")
+eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+logging.info(train_dataset)

train_loss = losses.CosineSimilarityLoss(model=model)

logging.info("Read STSbenchmark dev dataset")
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
+# 4. Define an evaluator for use during training.
+dev_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=eval_dataset["sentence1"],
+    sentences2=eval_dataset["sentence2"],
+    scores=eval_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-dev",
+)

# Configure the training. We skip evaluation in this example
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
+warmup_steps = math.ceil(len(train_dataset) * num_epochs * 0.1)  # 10% of train data for warm-up
Collaborator:
The SentenceTransformerTrainingArguments has a warmup_ratio=0.1 that we can use instead.
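A minimal sketch of what that suggestion could look like (illustrative only, not part of this diff): warmup_ratio is inherited from transformers.TrainingArguments and is applied to the total number of optimizer steps, so the manually computed warmup_steps, which counts dataset rows rather than batches, would no longer be needed. model_save_path, num_epochs and train_batch_size are the script's existing variables.

args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,  # linear warmup over the first 10% of total training steps
)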


# Stopping and Evaluating after 30% of training data (less than 1 epoch)
# We find from (Dodge et al.) that 20-30% is often ideal for convergence of random seed
-steps_per_epoch = math.ceil(len(train_dataloader) * stop_after)
+steps_per_epoch = math.ceil(len(train_dataset) * stop_after)
Collaborator:
I don't think this is used right now


logging.info(f"Warmup-steps: {warmup_steps}")

logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")

-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    evaluator=evaluator,
-    epochs=num_epochs,
-    steps_per_epoch=steps_per_epoch,
-    evaluation_steps=1000,
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
Collaborator:
I think the stop_after isn't actually making it stop after this many steps.
Normally you can use max_steps, but then I think it messes with the scheduler. Ideally we want the scheduler to be "normal" but still stop after stop_after steps, though I'm not sure if that's the old behaviour either.
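One possible way to get that behaviour (a sketch under the assumption that the scheduler should stay configured for the full num_train_epochs; this is not what the PR currently does): stop early from a TrainerCallback instead of setting max_steps, so the learning-rate schedule is left untouched. StopAfterStepsCallback is a hypothetical name; on_step_end, state.global_step, control.should_training_stop and trainer.add_callback are the standard transformers callback API.

import math

from transformers import TrainerCallback


class StopAfterStepsCallback(TrainerCallback):
    # Stop training after a fixed number of optimizer steps without shortening the LR schedule.
    def __init__(self, stop_step: int):
        self.stop_step = stop_step

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step >= self.stop_step:
            control.should_training_stop = True
        return control


# After the trainer below has been created, e.g. to stop after `stop_after`
# (a fraction such as 0.3) of one epoch's optimizer steps:
# steps_in_epoch = math.ceil(len(train_dataset) / train_batch_size)
# trainer.add_callback(StopAfterStepsCallback(math.ceil(steps_in_epoch * stop_after)))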

+    # Required parameter:
+    output_dir=model_save_path,
+    # Optional training parameters:
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    per_device_eval_batch_size=train_batch_size,
    warmup_steps=warmup_steps,
-    output_path=model_save_path,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    evaluation_strategy="steps",
+    eval_steps=stop_after,
+    save_strategy="steps",
+    save_steps=stop_after,
+    logging_steps=stop_after,
+    run_name="sts",  # Will be used in W&B if `wandb` is installed
)

+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+)
+trainer.train()
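For context, the seed-optimization loop that gives this script its name sits in a part of the file that the diff leaves collapsed, so it is not shown above. A rough sketch of that surrounding structure, assuming seed_count is read from the command line as described in the docstring; the exact lines in the script may differ:

for seed in range(seed_count):
    logging.info(f"##### Seed {seed} #####")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # ... rebuild the model, training arguments and trainer as above for this seed ...
    trainer.train()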