[WIP] Dev finetune (stanfordnlp#1796)
* Re-add lm.launch_kwargs

* Re-add launch_kwargs

* Remove extra logs from bootstrap_finetune.py

* Remove extra logs from provider.py

* Update logs for bettertogether.py

* Add status updates to openai.py

* Remove extra log from bootstrap_finetune.py

* Update logs in bootstrap_finetune.py

* Update openai.py

* Update openai.py

* Update openai.py

* Update openai.py

* Log OpenAI training messages

* Update bettertogether.py

* Update bettertogether.py

* Update bootstrap_finetune.py
dilarasoylu authored Nov 14, 2024
1 parent 87aedfe commit 0509a0f
Showing 5 changed files with 46 additions and 21 deletions.
10 changes: 7 additions & 3 deletions dspy/clients/lm.py
@@ -36,6 +36,7 @@ def __init__(
         num_retries: int = 3,
         provider=None,
         finetuning_model: Optional[str] = None,
+        launch_kwargs: Optional[dict[str, Any]] = None,
         **kwargs,
     ):
         """
@@ -68,6 +69,7 @@ def __init__(
         self.callbacks = callbacks or []
         self.num_retries = num_retries
         self.finetuning_model = finetuning_model
+        self.launch_kwargs = launch_kwargs
 
         # TODO(bug): Arbitrary model strings could include the substring "o1-".
         # We should find a more robust way to check for the "o1-" family models.
@@ -113,10 +115,12 @@ def __call__(self, prompt=None, messages=None, **kwargs):
         return outputs
 
     def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
-        self.provider.launch(self.model, **launch_kwargs)
+        launch_kwargs = launch_kwargs or self.launch_kwargs
+        self.provider.launch(self.model, launch_kwargs)
 
-    def kill(self, kill_kwargs: Optional[Dict[str, Any]] = None):
-        self.provider.kill(self.model, **kill_kwargs)
+    def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
+        launch_kwargs = launch_kwargs or self.launch_kwargs
+        self.provider.kill(self.model, launch_kwargs)
 
     def finetune(
         self,
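For context on the lm.py change: the re-added `launch_kwargs` can be set once when the LM is constructed, and both `launch()` and `kill()` fall back to it when called without arguments. A minimal usage sketch; the model name and kwarg contents below are hypothetical, not from this diff:

```python
import dspy

# Hypothetical locally served model; the launch_kwargs stored on the LM
# become the default for later launch() / kill() calls.
lm = dspy.LM("openai/my-local-model", launch_kwargs={"num_gpus": 1})

lm.launch()                  # no argument: falls back to lm.launch_kwargs
lm.launch({"num_gpus": 2})   # an explicit dict takes precedence
lm.kill()                    # kill() reuses lm.launch_kwargs the same way
```

Note that `kill()` now accepts (and falls back to) `launch_kwargs` rather than a separate `kill_kwargs`, matching the renamed parameter in the diff.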
26 changes: 24 additions & 2 deletions dspy/clients/openai.py
@@ -1,5 +1,6 @@
 import re
 import time
+from datetime import datetime
 from typing import Any, Dict, List, Optional
 
 import openai
@@ -248,11 +249,32 @@ def wait_for_job(
         job: TrainingJobOpenAI,
         poll_frequency: int = 20,
     ):
-        # Poll for the job until it is done
         done = False
+        cur_event_id = None
+        reported_estimated_time = False
         while not done:
-            done = OpenAIProvider.is_terminal_training_status(job.status())
+            # Report estimated time if not already reported
+            if not reported_estimated_time:
+                remote_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
+                timestamp = remote_job.estimated_finish
+                if timestamp:
+                    estimated_finish_dt = datetime.fromtimestamp(timestamp)
+                    delta_dt = estimated_finish_dt - datetime.now()
+                    print(f"[OpenAI Provider] The OpenAI estimated time remaining is: {delta_dt}")
+                    reported_estimated_time = True
+
+            # Get new events
+            page = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1)
+            new_event = page.data[0] if page.data else None
+            if new_event and new_event.id != cur_event_id:
+                dt = datetime.fromtimestamp(new_event.created_at)
+                print(f"[OpenAI Provider] {dt} {new_event.message}")
+                cur_event_id = new_event.id
+
+            # Sleep and update the flag
             time.sleep(poll_frequency)
 
+            done = OpenAIProvider.is_terminal_training_status(job.status())
+
     @staticmethod
     def get_trained_model(job):
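Two details of the new loop are easy to miss: it fetches only the single most recent event per poll (`limit=1`) and dedupes on the event id, so events that arrive faster than `poll_frequency` can be skipped; and the terminal-status check now happens after the sleep rather than at the top of the loop. A stripped-down sketch of the same poll-and-dedupe pattern with the OpenAI calls abstracted away (the function and its parameters are illustrative):

```python
import time
from datetime import datetime

def watch_job(fetch_latest_event, is_done, poll_frequency=20):
    """Print each new event once; stop when the job reaches a terminal state."""
    cur_event_id = None
    done = False
    while not done:
        event = fetch_latest_event()  # returns the newest event dict, or None
        if event and event["id"] != cur_event_id:
            dt = datetime.fromtimestamp(event["created_at"])
            print(f"{dt} {event['message']}")
            cur_event_id = event["id"]
        time.sleep(poll_frequency)
        done = is_done()
```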
8 changes: 2 additions & 6 deletions dspy/clients/provider.py
@@ -45,15 +45,11 @@ def is_provider_model(model: str) -> bool:

     @staticmethod
     def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
-        msg = f"`launch()` is called for the auto-launched model `{model}`"
-        msg += " -- no action is taken!"
-        print(msg)
+        pass
 
     @staticmethod
     def kill(model: str, kill_kwargs: Optional[Dict[str, Any]] = None):
-        msg = f"`kill()` is called for the auto-launched model `{model}`"
-        msg += " -- no action is taken!"
-        print(msg)
+        pass
 
     @staticmethod
     def finetune(
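With the base `Provider.launch()` and `Provider.kill()` reduced to silent no-ops, the warning logs move out of the common path; providers that actually manage a server are expected to override these hooks. A hypothetical subclass sketch (the class name and server-management details are illustrative, not part of dspy):

```python
from typing import Any, Dict, Optional

from dspy.clients.provider import Provider


class LocalServerProvider(Provider):
    """Hypothetical provider that manages a locally hosted model server."""

    @staticmethod
    def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
        launch_kwargs = launch_kwargs or {}
        print(f"Starting a server for {model} with {launch_kwargs}")
        # ...spawn the inference server here...

    @staticmethod
    def kill(model: str, kill_kwargs: Optional[Dict[str, Any]] = None):
        print(f"Shutting down the server for {model}")
        # ...terminate the inference server here...
```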
9 changes: 7 additions & 2 deletions dspy/teleprompt/bettertogether.py
@@ -66,10 +66,13 @@ def compile(
         student = prepare_student(student)
         set_missing_predictor_lms(student)
 
+        # Make a shallow copy of the trainset, so that we don't change the order
+        # of the examples in the original trainset
+        trainset = trainset[:]
         print("[BetterTogether] Compiling the student program...")
         student = self._run_strategies(parsed_strategy, student, trainset, valset_ratio)
 
-        print("[BetterTogether] BetterTogether has finished compiling the student program.")
+        print("[BetterTogether] BetterTogether has finished compiling the student program")
         return student
 
     def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> Program:
@@ -80,7 +83,7 @@ def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> Program:

         for ind, step_code in enumerate(parsed_strategy):
             current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1])
-            print(f"[BetterTogether] Step {ind + 1} of {len(parsed_strategy)} - Strategy `{current_strategy}`")
+            print(f"\n[BetterTogether] ########## Step {ind + 1} of {len(parsed_strategy)} - Strategy '{current_strategy}' ##########")
 
             print("[BetterTogether] Shuffling the trainset...")
             self.rng.shuffle(trainset)
@@ -104,6 +107,8 @@ def _compile_prompt_optimizer(self, student, trainset, valset_ratio) -> Program:
         print("[BetterTogether] Preparing for prompt optimization...")
 
         # Sampling a validation set from the trainset for the prompt optimizer
+        # We drop the hints for prompt optimization
+        trainset = [x.with_inputs(*list(set(x.inputs().keys()) - {"hint"})) for x in trainset]
         num_val = int(valset_ratio * len(trainset))
         prompt_valset = trainset[:num_val]
         prompt_trainset = trainset[num_val:]
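The new list comprehension rebuilds each training example with every input key except `"hint"`, so hints used while bootstrapping traces don't leak into prompt optimization. A small sketch of the effect, reusing the `dspy.Example` API exactly as the diff does (the field values are made up):

```python
import dspy

example = dspy.Example(question="What is 2 + 2?", hint="count on your fingers", answer="4")
example = example.with_inputs("question", "hint")

# Drop the hint from the input keys, as the optimizer now does:
example = example.with_inputs(*list(set(example.inputs().keys()) - {"hint"}))

print(example.inputs().keys())  # only "question" remains an input key
```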
14 changes: 6 additions & 8 deletions dspy/teleprompt/bootstrap_finetune.py
@@ -79,7 +79,7 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[P
             training_key = (pred.lm, data_pred_ind)
             if training_key not in key_to_data:
                 train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
-                print(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
+                print(f"[BootstrapFinetune] Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
                 finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_kwargs=self.train_kwargs[pred.lm], data_format=data_format)
                 key_to_data[training_key] = finetune_kwargs
 
@@ -108,7 +108,7 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[P
     @staticmethod
     def finetune_lms(finetune_dict) -> Dict[Any, LM]:
         num_jobs = len(finetune_dict)
-        print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning jobs...")
+        print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning job(s)...")
         # TODO(nit) Pass an identifier to the job so that we can tell the logs
         # coming from different fine-tune threads.
 
@@ -121,7 +121,7 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:
         for ind, (key, job) in enumerate(key_to_job.items()):
             key_to_lm[key] = job.result()
             job.thread.join()
-            print(f"Job {ind + 1}/{num_jobs} completed.")
+            print(f"[BootstrapFinetune] Job {ind + 1}/{num_jobs} is done")
 
         return key_to_lm

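`finetune_lms` starts every job up front and then blocks on each in turn, logging as results come back, so the completion messages arrive in dictionary order even if jobs finish out of order. A generic sketch of that collect-and-log flow using `concurrent.futures` in place of dspy's job/thread objects (illustrative, not the PR's implementation):

```python
from concurrent.futures import ThreadPoolExecutor


def run_finetune_jobs(key_to_fn):
    """Run each zero-argument callable in a thread and collect its result."""
    num_jobs = len(key_to_fn)
    print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning job(s)...")
    with ThreadPoolExecutor() as executor:
        key_to_future = {key: executor.submit(fn) for key, fn in key_to_fn.items()}
        results = {}
        for ind, (key, future) in enumerate(key_to_future.items()):
            results[key] = future.result()  # blocks until this job finishes
            print(f"[BootstrapFinetune] Job {ind + 1}/{num_jobs} is done")
    return results
```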
@@ -130,7 +130,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
         if self.metric:
             print(f"[BootstrapFinetune] Collected data for {len(trace_data)} examples")
             trace_data = [d for d in trace_data if d["score"]]
-            print(f"[BootstrapFinetune] After filtering for score, {len(trace_data)} examples remain")
+            print(f"[BootstrapFinetune] After filtering with the metric, {len(trace_data)} examples remain")
 
         data = []
         adapter = self.adapter[lm] or lm.infer_adapter()
@@ -234,7 +234,6 @@ def set_missing_predictor_lms(program: Program) -> Program:


 def prepare_student(student: Program) -> Program:
-    print("Ensuring that the student is not compiled")
     if getattr(student, "_compiled", False):
         raise ValueError("The student program should not be compiled.")
 
@@ -246,15 +245,14 @@ def prepare_teacher(student: Program, teacher: Program = None) -> Program:

 def prepare_teacher(student: Program, teacher: Program = None) -> Program:
     if teacher is None:
-        print("No teacher provided. Using a copy of the student program as the teacher.")
         return student.deepcopy()
     else:
         teacher = teacher.deepcopy()
 
-    print("Ensuring that the student and teacher are are structurally equivalent.")
+    # Ensuring that the student and teacher are structurally equivalent
     assert_structural_equivalency(student, teacher)
 
-    print("Ensuring that the student and teacher programs do not share predictors.")
+    # Ensuring that the student and teacher programs do not share predictors
     assert_no_shared_predictor(student, teacher)
 
     return teacher
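Taken together, a hypothetical end-to-end use of `BootstrapFinetune` might look like the sketch below; the configured base model, metric, signature, and trainset are all made up for illustration:

```python
import dspy
from dspy.teleprompt import BootstrapFinetune

dspy.configure(lm=dspy.LM("openai/gpt-4o-mini"))  # assumed base model

def exact_match(example, prediction, trace=None):
    return prediction.answer == example.answer

trainset = [
    dspy.Example(question="What is 2 + 2?", answer="4").with_inputs("question"),
    # ...more examples...
]

student = dspy.ChainOfThought("question -> answer")
optimizer = BootstrapFinetune(metric=exact_match)
finetuned_program = optimizer.compile(student, trainset=trainset)
```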
