Skip to content

Commit 0509a0f

Browse files
authored
[WIP] Dev finetune (stanfordnlp#1796)
* Re-add lm.launch_kwargs * Re-add launch_kwargs * Remove extra logs from bootstrap_finetune.py * Remove extra logs from provider.py * Update logs for bettertogether.py * Add status updates to openai.py * Remove extra log from bootstrap_finetune.py * Update logs in bootstrap_finetune.py * Update openai.py * Update openai.py * Update openai.py * Update openai.py * Log OpenAI training messages * Update bettertogether.py * Update bettertogether.py * Update bootstrap_finetune.py
1 parent 87aedfe commit 0509a0f

File tree

5 files changed

+46
-21
lines changed

5 files changed

+46
-21
lines changed

dspy/clients/lm.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def __init__(
3636
num_retries: int = 3,
3737
provider=None,
3838
finetuning_model: Optional[str] = None,
39+
launch_kwargs: Optional[dict[str, Any]] = None,
3940
**kwargs,
4041
):
4142
"""
@@ -68,6 +69,7 @@ def __init__(
6869
self.callbacks = callbacks or []
6970
self.num_retries = num_retries
7071
self.finetuning_model = finetuning_model
72+
self.launch_kwargs = launch_kwargs
7173

7274
# TODO(bug): Arbitrary model strings could include the substring "o1-".
7375
# We should find a more robust way to check for the "o1-" family models.
@@ -113,10 +115,12 @@ def __call__(self, prompt=None, messages=None, **kwargs):
113115
return outputs
114116

115117
def launch(self, launch_kwargs: Optional[Dict[str, Any]] = None):
116-
self.provider.launch(self.model, **launch_kwargs)
118+
launch_kwargs = launch_kwargs or self.launch_kwargs
119+
self.provider.launch(self.model, launch_kwargs)
117120

118-
def kill(self, kill_kwargs: Optional[Dict[str, Any]] = None):
119-
self.provider.kill(self.model, **kill_kwargs)
121+
def kill(self, launch_kwargs: Optional[Dict[str, Any]] = None):
122+
launch_kwargs = launch_kwargs or self.launch_kwargs
123+
self.provider.kill(self.model, launch_kwargs)
120124

121125
def finetune(
122126
self,

dspy/clients/openai.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
import time
3+
from datetime import datetime
34
from typing import Any, Dict, List, Optional
45

56
import openai
@@ -248,11 +249,32 @@ def wait_for_job(
248249
job: TrainingJobOpenAI,
249250
poll_frequency: int = 20,
250251
):
252+
# Poll for the job until it is done
251253
done = False
254+
cur_event_id = None
255+
reported_estimated_time = False
252256
while not done:
253-
done = OpenAIProvider.is_terminal_training_status(job.status())
257+
# Report estimated time if not already reported
258+
if not reported_estimated_time:
259+
remote_job = openai.fine_tuning.jobs.retrieve(job.provider_job_id)
260+
timestamp = remote_job.estimated_finish
261+
if timestamp:
262+
estimated_finish_dt = datetime.fromtimestamp(timestamp)
263+
delta_dt = estimated_finish_dt - datetime.now()
264+
print(f"[OpenAI Provider] The OpenAI estimated time remaining is: {delta_dt}")
265+
reported_estimated_time = True
266+
267+
# Get new events
268+
page = openai.fine_tuning.jobs.list_events(fine_tuning_job_id=job.provider_job_id, limit=1)
269+
new_event = page.data[0] if page.data else None
270+
if new_event and new_event.id != cur_event_id:
271+
dt = datetime.fromtimestamp(new_event.created_at)
272+
print(f"[OpenAI Provider] {dt} {new_event.message}")
273+
cur_event_id = new_event.id
274+
275+
# Sleep and update the flag
254276
time.sleep(poll_frequency)
255-
277+
done = OpenAIProvider.is_terminal_training_status(job.status())
256278

257279
@staticmethod
258280
def get_trained_model(job):

dspy/clients/provider.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,11 @@ def is_provider_model(model: str) -> bool:
4545

4646
@staticmethod
4747
def launch(model: str, launch_kwargs: Optional[Dict[str, Any]] = None):
48-
msg = f"`launch()` is called for the auto-launched model `{model}`"
49-
msg += " -- no action is taken!"
50-
print(msg)
48+
pass
5149

5250
@staticmethod
5351
def kill(model: str, kill_kwargs: Optional[Dict[str, Any]] = None):
54-
msg = f"`kill()` is called for the auto-launched model `{model}`"
55-
msg += " -- no action is taken!"
56-
print(msg)
52+
pass
5753

5854
@staticmethod
5955
def finetune(

dspy/teleprompt/bettertogether.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,10 +66,13 @@ def compile(
6666
student = prepare_student(student)
6767
set_missing_predictor_lms(student)
6868

69+
# Make a shallow copy of the trainset, so that we don't change the order
70+
# of the examples in the original trainset
71+
trainset = trainset[:]
6972
print("[BetterTogether] Compiling the student program...")
7073
student = self._run_strategies(parsed_strategy, student, trainset, valset_ratio)
7174

72-
print("[BetterTogether] BetterTogether has finished compiling the student program.")
75+
print("[BetterTogether] BetterTogether has finished compiling the student program")
7376
return student
7477

7578
def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> Program:
@@ -80,7 +83,7 @@ def _run_strategies(self, parsed_strategy, student, trainset, valset_ratio) -> P
8083

8184
for ind, step_code in enumerate(parsed_strategy):
8285
current_strategy = self.STRAT_SEP.join(parsed_strategy[:ind + 1])
83-
print(f"[BetterTogether] Step {ind + 1} of {len(parsed_strategy)} - Strategy `{current_strategy}`")
86+
print(f"\n[BetterTogether] ########## Step {ind + 1} of {len(parsed_strategy)} - Strategy '{current_strategy}' ##########")
8487

8588
print("[BetterTogether] Shuffling the trainset...")
8689
self.rng.shuffle(trainset)
@@ -104,6 +107,8 @@ def _compile_prompt_optimizer(self, student, trainset, valset_ratio) -> Program:
104107
print("[BetterTogether] Preparing for prompt optimization...")
105108

106109
# Sampling a validation set from the trainset for the prompt optimizer
110+
# We drop the hints for prompt optimization
111+
trainset = [x.with_inputs(*list(set(x.inputs().keys()) - {"hint"})) for x in trainset]
107112
num_val = int(valset_ratio * len(trainset))
108113
prompt_valset = trainset[:num_val]
109114
prompt_trainset = trainset[num_val:]

dspy/teleprompt/bootstrap_finetune.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[P
7979
training_key = (pred.lm, data_pred_ind)
8080
if training_key not in key_to_data:
8181
train_data, data_format = self._prepare_finetune_data(trace_data=trace_data, lm=pred.lm, pred_ind=data_pred_ind)
82-
print(f"Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
82+
print(f"[BootstrapFinetune] Using {len(train_data)} data points for fine-tuning the model: {pred.lm.model}")
8383
finetune_kwargs = dict(lm=pred.lm, train_data=train_data, train_kwargs=self.train_kwargs[pred.lm], data_format=data_format)
8484
key_to_data[training_key] = finetune_kwargs
8585

@@ -108,7 +108,7 @@ def compile(self, student: Program, trainset: List[Example], teacher: Optional[P
108108
@staticmethod
109109
def finetune_lms(finetune_dict) -> Dict[Any, LM]:
110110
num_jobs = len(finetune_dict)
111-
print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning jobs...")
111+
print(f"[BootstrapFinetune] Starting {num_jobs} fine-tuning job(s)...")
112112
# TODO(nit) Pass an identifier to the job so that we can tell the logs
113113
# coming from different fine-tune threads.
114114

@@ -121,7 +121,7 @@ def finetune_lms(finetune_dict) -> Dict[Any, LM]:
121121
for ind, (key, job) in enumerate(key_to_job.items()):
122122
key_to_lm[key] = job.result()
123123
job.thread.join()
124-
print(f"Job {ind + 1}/{num_jobs} completed.")
124+
print(f"[BootstrapFinetune] Job {ind + 1}/{num_jobs} is done")
125125

126126
return key_to_lm
127127

@@ -130,7 +130,7 @@ def _prepare_finetune_data(self, trace_data: List[Dict[str, Any]], lm: LM, pred_
130130
if self.metric:
131131
print(f"[BootstrapFinetune] Collected data for {len(trace_data)} examples")
132132
trace_data = [d for d in trace_data if d["score"]]
133-
print(f"[BootstrapFinetune] After filtering for score, {len(trace_data)} examples remain")
133+
print(f"[BootstrapFinetune] After filtering with the metric, {len(trace_data)} examples remain")
134134

135135
data = []
136136
adapter = self.adapter[lm] or lm.infer_adapter()
@@ -234,7 +234,6 @@ def set_missing_predictor_lms(program: Program) -> Program:
234234

235235

236236
def prepare_student(student: Program) -> Program:
237-
print("Ensuring that the student is not compiled")
238237
if getattr(student, "_compiled", False):
239238
raise ValueError("The student program should not be compiled.")
240239

@@ -246,15 +245,14 @@ def prepare_student(student: Program) -> Program:
246245

247246
def prepare_teacher(student: Program, teacher: Program = None) -> Program:
248247
if teacher is None:
249-
print("No teacher provided. Using a copy of the student program as the teacher.")
250248
return student.deepcopy()
251249
else:
252250
teacher = teacher.deepcopy()
253251

254-
print("Ensuring that the student and teacher are are structurally equivalent.")
252+
# Ensuring that the student and teacher are are structurally equivalent
255253
assert_structural_equivalency(student, teacher)
256254

257-
print("Ensuring that the student and teacher programs do not share predictors.")
255+
# Ensuring that the student and teacher programs do not share predictors
258256
assert_no_shared_predictor(student, teacher)
259257

260258
return teacher

0 commit comments

Comments
 (0)