update train_sts_seed_optimization with SentenceTransformerTrainer #3092
Changes from 2 commits
File changed: train_sts_seed_optimization.py

```diff
@@ -23,34 +23,25 @@
 python train_sts_seed_optimization.py bert-base-uncased 10 0.3
 """
 
-import csv
-import gzip
 import logging
 import math
-import os
 import random
 import sys
 
 import numpy as np
 import torch
-from torch.utils.data import DataLoader
+from datasets import load_dataset
 
-from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models, util
+from sentence_transformers import LoggingHandler, SentenceTransformer, losses, models
 from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
-from sentence_transformers.readers import InputExample
+from sentence_transformers.similarity_functions import SimilarityFunction
+from sentence_transformers.trainer import SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
 
 #### Just some code to print debug information to stdout
 logging.basicConfig(
     format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO, handlers=[LoggingHandler()]
 )
 #### /print debug information to stdout
 
 
-# Check if dataset exists. If not, download and extract it
-sts_dataset_path = "datasets/stsbenchmark.tsv.gz"
-
-if not os.path.exists(sts_dataset_path):
-    util.http_get("https://sbert.net/datasets/stsbenchmark.tsv.gz", sts_dataset_path)
-
-
 # You can specify any huggingface/transformers pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
```
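The removed download-and-cache block is superseded in the next hunk by `datasets.load_dataset`. As a quick sanity-check sketch of what that replacement returns (an assumption based on the columns the diff itself reads, `sentence1`/`sentence2`/`score`, and on the dropped `/ 5.0` normalization, which implies the Hub dataset ships scores already in [0, 1]):

```python
from datasets import load_dataset

# Assumption: the Hub dataset exposes pre-normalized similarity scores,
# so the old manual "score / 5.0" step is no longer needed.
train_dataset = load_dataset("sentence-transformers/stsb", split="train")
print(train_dataset.column_names)  # expected: ['sentence1', 'sentence2', 'score']
print(train_dataset[0])            # one sentence pair with its similarity score
```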
```diff
@@ -85,49 +76,61 @@
 
 model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
 
-# Convert the dataset to a DataLoader ready for training
-logging.info("Read STSbenchmark train dataset")
-
-train_samples = []
-dev_samples = []
-test_samples = []
-with gzip.open(sts_dataset_path, "rt", encoding="utf8") as fIn:
-    reader = csv.DictReader(fIn, delimiter="\t", quoting=csv.QUOTE_NONE)
-    for row in reader:
-        score = float(row["score"]) / 5.0  # Normalize score to range 0 ... 1
-        inp_example = InputExample(texts=[row["sentence1"], row["sentence2"]], label=score)
-
-        if row["split"] == "dev":
-            dev_samples.append(inp_example)
-        elif row["split"] == "test":
-            test_samples.append(inp_example)
-        else:
-            train_samples.append(inp_example)
-
-train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=train_batch_size)
+# 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
+train_dataset = load_dataset("sentence-transformers/stsb", split="train")
+eval_dataset = load_dataset("sentence-transformers/stsb", split="validation")
+test_dataset = load_dataset("sentence-transformers/stsb", split="test")
+logging.info(train_dataset)
 
 train_loss = losses.CosineSimilarityLoss(model=model)
 
-logging.info("Read STSbenchmark dev dataset")
-evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name="sts-dev")
+# 4. Define an evaluator for use during training.
+dev_evaluator = EmbeddingSimilarityEvaluator(
+    sentences1=eval_dataset["sentence1"],
+    sentences2=eval_dataset["sentence2"],
+    scores=eval_dataset["score"],
+    main_similarity=SimilarityFunction.COSINE,
+    name="sts-dev",
+)
 
 # Configure the training. We skip evaluation in this example
-warmup_steps = math.ceil(len(train_dataloader) * num_epochs * 0.1)  # 10% of train data for warm-up
+warmup_steps = math.ceil(len(train_dataset) * num_epochs * 0.1)  # 10% of train data for warm-up
 
 # Stopping and Evaluating after 30% of training data (less than 1 epoch)
 # We find from (Dodge et al.) that 20-30% is often ideal for convergence of random seed
-steps_per_epoch = math.ceil(len(train_dataloader) * stop_after)
+steps_per_epoch = math.ceil(len(train_dataset) * stop_after)
```
Review comment on `steps_per_epoch = math.ceil(len(train_dataset) * stop_after)`: I don't think this is used right now.
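If the reviewer is right that `steps_per_epoch` is dead code in the new version, the same early-stopping intent could plausibly be expressed through `max_steps`, a standard `transformers.TrainingArguments` field that `SentenceTransformerTrainingArguments` inherits. A minimal sketch, assuming the Trainer counts optimizer steps (batches) rather than samples, using the script's `train_dataset`, `train_batch_size`, `stop_after`, and `model_save_path`:

```python
import math

from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Assumption: one optimizer step per batch, so convert the sample count
# to steps before applying the stop_after fraction (e.g. 0.3).
batches_per_epoch = math.ceil(len(train_dataset) / train_batch_size)
max_train_steps = math.ceil(batches_per_epoch * stop_after)

args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    max_steps=max_train_steps,  # when > 0, this overrides num_train_epochs
    per_device_train_batch_size=train_batch_size,
)
```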
```diff
 
 logging.info(f"Warmup-steps: {warmup_steps}")
 
 logging.info(f"Early-stopping: {int(stop_after * 100)}% of the training-data")
 
-# Train the model
-model.fit(
-    train_objectives=[(train_dataloader, train_loss)],
-    evaluator=evaluator,
-    epochs=num_epochs,
-    steps_per_epoch=steps_per_epoch,
-    evaluation_steps=1000,
-    warmup_steps=warmup_steps,
-    output_path=model_save_path,
-)
+# 5. Define the training arguments
+args = SentenceTransformerTrainingArguments(
+    # Required parameter:
+    output_dir=model_save_path,
+    # Optional training parameters:
+    num_train_epochs=num_epochs,
+    per_device_train_batch_size=train_batch_size,
+    per_device_eval_batch_size=train_batch_size,
+    warmup_steps=warmup_steps,
+    fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
+    bf16=False,  # Set to True if you have a GPU that supports BF16
+    # Optional tracking/debugging parameters:
+    evaluation_strategy="steps",
+    eval_steps=stop_after,
+    save_strategy="steps",
+    save_steps=stop_after,
+    logging_steps=stop_after,
+    run_name="sts",  # Will be used in W&B if `wandb` is installed
+)
+
+# 6. Create the trainer & start training
+trainer = SentenceTransformerTrainer(
+    model=model,
+    args=args,
+    train_dataset=train_dataset,
+    eval_dataset=eval_dataset,
+    loss=train_loss,
+    evaluator=dev_evaluator,
+)
+trainer.train()
```
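One gap worth noting: the diff loads `test_dataset` but never uses it. A hedged sketch of a post-training test evaluation, mirroring the dev evaluator construction in the diff above:

```python
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.similarity_functions import SimilarityFunction

# Sketch: score the trained model on the so-far-unused test split.
test_evaluator = EmbeddingSimilarityEvaluator(
    sentences1=test_dataset["sentence1"],
    sentences2=test_dataset["sentence2"],
    scores=test_dataset["score"],
    main_similarity=SimilarityFunction.COSINE,
    name="sts-test",
)
test_evaluator(model)  # runs the evaluation and logs the scores
```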
Review comment on `args = SentenceTransformerTrainingArguments(`: The `SentenceTransformerTrainingArguments` has a `warmup_ratio=0.1` that we can use instead.
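A sketch of that suggestion, letting the Trainer derive the warm-up from a ratio instead of the hand-computed `warmup_steps` (which, as written in the diff, counts samples rather than optimizer steps):

```python
from sentence_transformers.training_args import SentenceTransformerTrainingArguments

# Sketch: warmup_ratio replaces the manual warmup_steps computation and
# sidesteps the samples-vs-batches ambiguity entirely.
args = SentenceTransformerTrainingArguments(
    output_dir=model_save_path,
    num_train_epochs=num_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    warmup_ratio=0.1,  # warm up over 10% of total optimizer steps
)
```

Relatedly, in recent `transformers` versions `eval_steps`, `save_steps`, and `logging_steps` also accept a float in [0, 1) as a fraction of total training steps, which is presumably why the diff passes `stop_after` (e.g. 0.3) to them directly.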