Skip to content

Commit 4379ead

Browse files
Merge pull request stanfordnlp#843 from drawal1/main
Fixed issues stanfordnlp#894 and stanfordnlp#858 (aws_models issues)
2 parents 4681778 + 9d582d9 commit 4379ead

File tree

2 files changed

+46
-31
lines changed

2 files changed

+46
-31
lines changed

dsp/modules/aws_models.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
class AWSModel(LM):
1919
"""This class adds support for an AWS model.
20+
2021
It is an abstract class and should not be instantiated directly.
2122
Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta.
22-
The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
23+
The subclasses implement the abstract methods _create_body and _call_model
24+
and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
2325
Usage Example:
2426
bedrock = dspy.Bedrock(region_name="us-west-2")
2527
bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
@@ -43,6 +45,7 @@ def __init__(
4345
model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint.
4446
max_context_size (int): The maximum context size in tokens.
4547
max_new_tokens (int): The maximum number of tokens to be sampled from the LM.
48+
**kwargs: Additional arguments.
4649
"""
4750
super().__init__(model=model)
4851
self._model_name: str = model
@@ -117,8 +120,10 @@ def __call__(
117120
There is only support for only_completed=True and return_sorted=False
118121
right now.
119122
"""
120-
assert only_completed, "for now"
121-
assert return_sorted is False, "for now"
123+
if not only_completed:
124+
raise ValueError("only_completed must be True for now")
125+
if return_sorted:
126+
raise ValueError("return_sorted must be False for now")
122127

123128
generated = self.basic_request(prompt, **kwargs)
124129
return [generated]
@@ -182,8 +187,7 @@ def _call_model(self, body: str) -> str:
182187
else:
183188
raise ValueError("Error - provider not recognized")
184189

185-
completion = completion.split(self.kwargs["stop"])[0]
186-
return completion
190+
return completion.split(self.kwargs["stop"])[0]
187191

188192

189193
class AWSAnthropic(AWSModel):
@@ -247,12 +251,11 @@ def _call_model(self, body: str) -> str:
247251
body=body,
248252
)
249253
response_body = json.loads(response["body"].read())
250-
completion = response_body["content"][0]["text"]
251-
return completion
254+
return response_body["content"][0]["text"]
252255

253256

254257
class AWSMeta(AWSModel):
255-
"""Llama2 family of models."""
258+
"""Llama3 family of models."""
256259

257260
def __init__(
258261
self,
@@ -275,10 +278,15 @@ def __init__(
275278
for k, v in kwargs.items():
276279
self.kwargs[k] = v
277280

278-
self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens")
281+
def _format_prompt(self, raw_prompt: str) -> str:
282+
return (
283+
"<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
284+
+ raw_prompt
285+
+ "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
286+
)
279287

280288
def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
281-
base_args: dict[str, Any] = self.kwargs
289+
base_args: dict[str, Any] = self.kwargs.copy()
282290
for k, v in kwargs.items():
283291
base_args[k] = v
284292

@@ -290,6 +298,10 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | floa
290298
query_args.pop("presence_penalty", None)
291299
query_args.pop("model", None)
292300

301+
max_tokens = query_args.pop("max_tokens", None)
302+
if max_tokens:
303+
query_args["max_gen_len"] = max_tokens
304+
293305
query_args["prompt"] = prompt
294306
return (n, query_args)
295307

@@ -299,9 +311,4 @@ def _call_model(self, body: str) -> str:
299311
body=body,
300312
)
301313
response_body = json.loads(response["body"].read())
302-
completion = response_body["generation"]
303-
304-
stop = "\n\n"
305-
completion = completion.split(stop)[0]
306-
307-
return completion
314+
return response_body["generation"]

tests/modules/test_aws_models.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,50 @@
66
import dsp
77
import dspy
88

9+
910
def get_lm(lm_provider: str, model_path: str, **kwargs) -> dsp.modules.lm.LM:
1011
"""get the language model"""
1112
# extract model vendor and name from model name
1213
# Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>
13-
model_vendor = model_path.split('/')[0]
14-
model_name = model_path.split('/')[1]
14+
model_vendor = model_path.split("/")[0]
15+
model_name = model_path.split("/")[1]
1516

16-
if lm_provider == 'Bedrock':
17+
if lm_provider == "Bedrock":
1718
bedrock = dspy.Bedrock(region_name="us-west-2")
18-
if model_vendor == 'mistral':
19+
if model_vendor == "mistral":
1920
return dspy.AWSMistral(bedrock, model_name, **kwargs)
20-
elif model_vendor == 'anthropic':
21+
elif model_vendor == "anthropic":
2122
return dspy.AWSAnthropic(bedrock, model_name, **kwargs)
22-
elif model_vendor == 'meta':
23+
elif model_vendor == "meta":
2324
return dspy.AWSMeta(bedrock, model_name, **kwargs)
2425
else:
25-
raise ValueError("Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>")
26-
elif lm_provider == 'Sagemaker':
26+
raise ValueError(
27+
"Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>"
28+
)
29+
elif lm_provider == "Sagemaker":
2730
sagemaker = dspy.Sagemaker(region_name="us-west-2")
28-
if model_vendor == 'mistral':
31+
if model_vendor == "mistral":
2932
return dspy.AWSMistral(sagemaker, model_name, **kwargs)
30-
elif model_vendor == 'meta':
33+
elif model_vendor == "meta":
3134
return dspy.AWSMeta(sagemaker, model_name, **kwargs)
3235
else:
33-
raise ValueError("Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>")
36+
raise ValueError(
37+
"Model vendor missing or unsupported: Model path format is <MODEL_VENDOR>/<MODEL_NAME_OR_ENDPOINT>"
38+
)
3439
else:
3540
raise ValueError(f"Unsupported model: {model_name}")
3641

42+
3743
def run_tests():
3844
"""Test the providers and models"""
3945
# Configure your AWS credentials with the AWS CLI before running this script
4046
provider_model_tuples = [
41-
('Bedrock', 'mistral/mistral.mixtral-8x7b-instruct-v0:1'),
42-
('Bedrock', 'anthropic/anthropic.claude-3-haiku-20240307-v1:0'),
43-
('Bedrock', 'anthropic/anthropic.claude-3-sonnet-20240229-v1:0'),
44-
('Bedrock', 'meta/meta.llama2-70b-chat-v1'),
47+
("Bedrock", "mistral/mistral.mixtral-8x7b-instruct-v0:1"),
48+
("Bedrock", "anthropic/anthropic.claude-3-haiku-20240307-v1:0"),
49+
("Bedrock", "anthropic/anthropic.claude-3-sonnet-20240229-v1:0"),
50+
("Bedrock", "meta/meta.llama2-70b-chat-v1"),
51+
("Bedrock", "meta/meta.llama3-8b-instruct-v1:0"),
52+
("Bedrock", "meta/meta.llama3-70b-instruct-v1:0"),
4553
# ('Sagemaker', 'mistral/<YOUR_ENDPOINT_NAME>'), # REPLACE YOUR_ENDPOINT_NAME with your sagemaker endpoint
4654
]
4755

0 commit comments

Comments (0)