
Commit 49837a3

Author: Krista Opsahl-Ong
adding in bayesian prompt optimizer
1 parent 072b20e commit 49837a3

File tree: 2 files changed, +235 −1 lines


dspy/teleprompt/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,4 +5,5 @@
 from .finetune import *
 from .teleprompt_optuna import *
 from .knn_fewshot import *
-from .signature_opt import SignatureOptimizer
+from .signature_opt import SignatureOptimizer
+from .signature_opt_bayesian import BayesianSignatureOptimizer
dspy/teleprompt/signature_opt_bayesian.py

Lines changed: 233 additions & 0 deletions
@@ -0,0 +1,233 @@
import dsp
import dspy
from dspy.teleprompt.teleprompt import Teleprompter
from dspy.signatures import Signature
from dspy.evaluate.evaluate import Evaluate
from collections import defaultdict
import random
from dspy.teleprompt import BootstrapFewShot
import numpy as np
import optuna
import math

random.seed(42)

"""
16+
USAGE SUGGESTIONS:
17+
18+
The following code can be used to compile a optimized signature teleprompter using the BayesianSignatureOptimizer, and evaluate it on an end task:
19+
20+
from dspy.teleprompt import BayesianSignatureOptimizer
21+
22+
teleprompter = BayesianSignatureOptimizer(prompt_model=prompt_model, task_model=task_model, metric=metric, n=10, init_temperature=1.0)
23+
kwargs = dict(num_threads=NUM_THREADS, display_progress=True, display_table=0)
24+
compiled_prompt_opt = teleprompter.compile(program, devset=devset[:DEV_NUM], optuna_trials_num=100, max_bootstrapped_demos=3, max_labeled_demos=5, eval_kwargs=kwargs)
25+
eval_score = evaluate(compiled_prompt_opt, devset=evalset[:EVAL_NUM], **kwargs)
26+
27+
Note that this teleprompter takes in the following parameters:
28+
29+
* prompt_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)).
30+
* task_model: The model used for prompt generation. When unspecified, defaults to the model set in settings (ie. dspy.settings.configure(lm=task_model)).
31+
* metric: The task metric used for optimization.
32+
* n: The number of new prompts to generate and evaluate. Default=10.
33+
* init_temperature: The temperature used to generate new prompts. Higher roughly equals more creative. Default=1.0.
34+
* verbose: Tells the method whether or not to print intermediate steps.
35+
* track_stats: Tells the method whether or not to track statistics about the optimization process.
36+
If True, the method will track a dictionary with a key corresponding to the trial number,
37+
and a value containing a dict with the following keys:
38+
* program: the program being evaluated at a given trial
39+
* score: the last average evaluated score for the program
40+
* pruned: whether or not this program was pruned
41+
This information will be returned as attributes of the best program.
42+
"""

class BasicGenerateInstruction(Signature):
    """You are an instruction optimizer for large language models. I will give you a ``signature`` of fields (inputs and outputs) in English. Your task is to propose an instruction that will lead a good language model to perform the task well. Don't be afraid to be creative."""

    basic_instruction = dspy.InputField(desc="The initial instructions before optimization")
    proposed_instruction = dspy.OutputField(desc="The improved instructions for the language model")
    proposed_prefix_for_output_field = dspy.OutputField(desc="The string at the end of the prompt, which will help the model start solving the task")
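
# The optimizer below (1) uses BasicGenerateInstruction to ask the prompt model for n-1
# alternative instructions per predictor (keeping the original as the n-th candidate),
# (2) bootstraps n few-shot demo sets with BootstrapFewShot, and (3) lets Optuna pick an
# instruction index and a demo-set index per predictor to maximize the metric on devset.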

class BayesianSignatureOptimizer(Teleprompter):
    def __init__(self, prompt_model=None, task_model=None, n=10, metric=None, init_temperature=1.0, verbose=False, track_stats=False):
        self.n = n
        self.metric = metric
        self.init_temperature = init_temperature
        self.prompt_model = prompt_model
        self.task_model = task_model
        self.verbose = verbose
        self.track_stats = track_stats

    def _print_full_program(self, program):
        for i, predictor in enumerate(program.predictors()):
            if self.verbose: print(f"Predictor {i}")
            if hasattr(predictor, 'extended_signature'):
                if self.verbose: print(f"i: {predictor.extended_signature.instructions}")
                if self.verbose: print(f"p: {predictor.extended_signature.fields[-1].name}")
            else:
                if self.verbose: print(f"i: {predictor.extended_signature1.instructions}")
                if self.verbose: print(f"p: {predictor.extended_signature1.fields[-1].name}")
            if self.verbose: print("\n")

    def _print_model_history(self, model, n=1):
        if self.verbose: print(f"Model ({model}) History:")
        history = model.inspect_history(n=n)
        for history_item in history:
            if self.verbose: print(f"{history_item}")
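
    # Returns instruction candidates keyed by id(predictor): each entry holds the prompt model's
    # proposed_instruction / proposed_prefix_for_output_field completions, with the predictor's
    # original instruction and prefix appended as the final candidate.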
    def _generate_first_N_candidates(self, module, N):
        candidates = {}
        evaluated_candidates = defaultdict(dict)

        # Seed the prompt optimizer zero shot with just the instruction, generate BREADTH new prompts
        for predictor in module.predictors():
            basic_instruction = None
            basic_prefix = None
            if hasattr(predictor, 'extended_signature'):
                basic_instruction = predictor.extended_signature.instructions
                basic_prefix = predictor.extended_signature.fields[-1].name
            else:
                basic_instruction = predictor.extended_signature1.instructions
                basic_prefix = predictor.extended_signature1.fields[-1].name
            if self.prompt_model:
                with dspy.settings.context(lm=self.prompt_model):
                    instruct = dspy.Predict(BasicGenerateInstruction, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)
            else:
                instruct = dspy.Predict(BasicGenerateInstruction, n=N-1, temperature=self.init_temperature)(basic_instruction=basic_instruction)
            # Add in our initial prompt as a candidate as well
            instruct.completions.proposed_instruction.append(basic_instruction)
            instruct.completions.proposed_prefix_for_output_field.append(basic_prefix)
            candidates[id(predictor)] = instruct.completions
            evaluated_candidates[id(predictor)] = {}

        if self.verbose and self.prompt_model: self._print_model_history(self.prompt_model)

        return candidates, evaluated_candidates

    def compile(self, student, *, devset, optuna_trials_num, max_bootstrapped_demos, max_labeled_demos, eval_kwargs):
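        # devset: examples used both to bootstrap few-shot demos and to score each trial.
        # optuna_trials_num: number of Optuna trials to run.
        # max_bootstrapped_demos / max_labeled_demos: passed through to BootstrapFewShot.
        # eval_kwargs: forwarded to Evaluate (e.g. num_threads, display_progress).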

        # Set up program and evaluation function
        module = student.deepcopy()
        evaluate = Evaluate(devset=devset, metric=self.metric, **eval_kwargs)

        # Generate N candidate prompts
        instruction_candidates, _ = self._generate_first_N_candidates(module, self.n)

        # Generate N few-shot example sets
        demo_candidates = {}
        for seed in range(self.n):
            if self.verbose: print(f"Creating basic bootstrap {seed}/{self.n}")

            # Create a new basic bootstrapped few-shot program.
            rng = random.Random(seed)
            shuffled_devset = devset[:]  # Create a copy of devset
            rng.shuffle(shuffled_devset)  # Shuffle the copy
            tp = BootstrapFewShot(metric=self.metric, max_bootstrapped_demos=max_bootstrapped_demos, max_labeled_demos=max_labeled_demos)
            candidate_program = tp.compile(student=module.deepcopy(), trainset=shuffled_devset)

            # Store the candidate demos
            for module_p, candidate_p in zip(module.predictors(), candidate_program.predictors()):
                if id(module_p) not in demo_candidates.keys():
                    demo_candidates[id(module_p)] = []
                demo_candidates[id(module_p)].append(candidate_p.demos)

        # Initialize variables to store the best program and its score
        best_score = float('-inf')
        best_program = None
        trial_num = 0

        trial_logs = {}

        # Define our trial objective
        def create_objective(baseline_program, instruction_candidates, demo_candidates, evaluate, devset):
            def objective(trial):
                nonlocal best_program, best_score, trial_num, trial_logs  # Allow access to the outer variables
                candidate_program = baseline_program.deepcopy()

                # Suggest the instruction to use for our predictor
                if self.verbose: print(f"Starting trial num: {trial_num}")
                trial_logs[trial_num] = {}

                for p_old, p_new in zip(baseline_program.predictors(), candidate_program.predictors()):

                    # Get instruction candidates for our given predictor
                    p_instruction_candidates = instruction_candidates[id(p_old)]
                    p_demo_candidates = demo_candidates[id(p_old)]

                    # Suggest the index of the instruction candidate to use in our trial
                    instruction_idx = trial.suggest_int(f"{id(p_old)}_predictor_instruction", low=0, high=len(p_instruction_candidates)-1)
                    demos_idx = trial.suggest_int(f"{id(p_old)}_predictor_demos", low=0, high=len(p_demo_candidates)-1)
                    trial_logs[trial_num]["instruction_idx"] = instruction_idx
                    trial_logs[trial_num]["demos_idx"] = demos_idx

                    # Get the selected instruction candidate
                    selected_candidate = p_instruction_candidates[instruction_idx]
                    selected_instruction = selected_candidate.proposed_instruction.strip('"').strip()
                    selected_prefix = selected_candidate.proposed_prefix_for_output_field.strip('"').strip()

                    # Use this candidate in our program
                    p_new.extended_signature.instructions = selected_instruction
                    p_new.extended_signature.fields[-1] = p_new.extended_signature.fields[-1]._replace(name=selected_prefix)

                    # Get the selected demos
                    selected_demos = p_demo_candidates[demos_idx]

                    # Use these demos in our program
                    p_new.demos = selected_demos

                if self.verbose: print("Evaluating the following program:")
                self._print_full_program(candidate_program)
                trial_logs[trial_num]["program"] = candidate_program

                # Evaluate with the new prompts
                total_score, curr_avg_score = 0, 0
                batch_size = 100
                num_batches = math.ceil(len(devset) / batch_size)
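
                # Evaluate the devset in batches so a running weighted average can be
                # reported to Optuna after each batch, letting the pruner stop weak trials early.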

                for i in range(num_batches):
                    start_index = i * batch_size
                    end_index = min((i + 1) * batch_size, len(devset))
                    split_dev = devset[start_index:end_index]
                    split_score = evaluate(candidate_program, devset=split_dev, display_table=0)
                    if self.verbose: print(f"Split {i} score: {split_score}")

                    total_score += split_score * len(split_dev)
                    curr_weighted_avg_score = total_score / min((i + 1) * batch_size, len(devset))
                    if self.verbose: print(f"curr average score: {curr_weighted_avg_score}")

                    trial.report(curr_weighted_avg_score, i)

                    # Handle pruning based on the intermediate value.
                    if trial.should_prune():
                        if self.verbose: print("Optuna decided to prune!")
                        trial_logs[trial_num]["score"] = curr_weighted_avg_score
                        trial_logs[trial_num]["pruned"] = True
                        trial_num += 1
                        raise optuna.TrialPruned()

                if self.verbose: print(f"Fully evaluated score: {curr_weighted_avg_score}")
                self._print_model_history(self.task_model, n=1)
                score = curr_weighted_avg_score

                trial_logs[trial_num]["score"] = curr_weighted_avg_score
                trial_logs[trial_num]["pruned"] = False

                # Update the best program if the current score is better
                if score > best_score:
                    best_score = score
                    best_program = candidate_program.deepcopy()

                trial_num += 1

                return score

            return objective
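
        # create_study defaults to Optuna's TPE sampler, which provides the Bayesian search
        # over the suggested instruction/demo indices (with the median pruner for early stopping).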
        # Run the trials
        objective_function = create_objective(module, instruction_candidates, demo_candidates, evaluate, devset)
        study = optuna.create_study(direction="maximize")
        study.optimize(objective_function, n_trials=optuna_trials_num)

        if best_program is not None and self.track_stats:
            best_program.trial_logs = trial_logs

        return best_program
