
class AWSModel(LM):
    """This class adds support for an AWS model.
+
    It is an abstract class and should not be instantiated directly.
    Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta.
-    The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
+    The subclasses implement the abstract methods _create_body and _call_model
+    and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
    Usage Example:
        bedrock = dspy.Bedrock(region_name="us-west-2")
        bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
@@ -43,6 +45,7 @@ def __init__(
            model (str, optional): An LM name, e.g., a bedrock name or an AWS endpoint.
            max_context_size (int): The maximum context size in tokens.
            max_new_tokens (int): The maximum number of tokens to be sampled from the LM.
+            **kwargs: Additional arguments.
        """
        super().__init__(model=model)
        self._model_name: str = model
@@ -117,8 +120,10 @@ def __call__(
        There is only support for only_completed=True and return_sorted=False
        right now.
        """
-        assert only_completed, "for now"
-        assert return_sorted is False, "for now"
+        if not only_completed:
+            raise ValueError("only_completed must be True for now")
+        if return_sorted:
+            raise ValueError("return_sorted must be False for now")

        generated = self.basic_request(prompt, **kwargs)
        return [generated]
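Switching from assert to ValueError also keeps the guard active under `python -O`, which strips assert statements. A minimal caller-side sketch, assuming a configured subclass such as AWSMistral (the instance and prompt below are illustrative, not from this diff):

    # Hypothetical setup; any AWSModel subclass behaves the same way.
    bedrock = dspy.Bedrock(region_name="us-west-2")
    lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1")
    try:
        lm("Hello", return_sorted=True)
    except ValueError as e:
        print(e)  # "return_sorted must be False for now"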
@@ -182,8 +187,7 @@ def _call_model(self, body: str) -> str:
        else:
            raise ValueError("Error - provider not recognized")

-        completion = completion.split(self.kwargs["stop"])[0]
-        return completion
+        return completion.split(self.kwargs["stop"])[0]


class AWSAnthropic(AWSModel):
@@ -247,12 +251,11 @@ def _call_model(self, body: str) -> str:
            body=body,
        )
        response_body = json.loads(response["body"].read())
-        completion = response_body["content"][0]["text"]
-        return completion
+        return response_body["content"][0]["text"]


class AWSMeta(AWSModel):
-    """Llama2 family of models."""
+    """Llama3 family of models."""

    def __init__(
        self,
@@ -275,10 +278,15 @@ def __init__(
        for k, v in kwargs.items():
            self.kwargs[k] = v

-        self.kwargs["max_gen_len"] = self.kwargs.pop("max_tokens")
+    def _format_prompt(self, raw_prompt: str) -> str:
+        return (
+            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>"
+            + raw_prompt
+            + "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
+        )

    def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
-        base_args: dict[str, Any] = self.kwargs
+        base_args: dict[str, Any] = self.kwargs.copy()
        for k, v in kwargs.items():
            base_args[k] = v

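The `.copy()` matters because without it `base_args` is the same object as `self.kwargs`, so per-call overrides would mutate the model's defaults and leak into every later request. A standalone sketch of the difference (illustrative values):

    defaults = {"temperature": 0.0}

    # Without .copy(): base_args aliases the shared dict,
    # so a per-call override rewrites the default.
    base_args = defaults
    base_args["temperature"] = 0.9
    assert defaults["temperature"] == 0.9  # leaked!

    # With .copy(): the override stays local to this call.
    defaults = {"temperature": 0.0}
    base_args = defaults.copy()
    base_args["temperature"] = 0.9
    assert defaults["temperature"] == 0.0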
@@ -290,6 +298,10 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
        query_args.pop("presence_penalty", None)
        query_args.pop("model", None)

+        max_tokens = query_args.pop("max_tokens", None)
+        if max_tokens:
+            query_args["max_gen_len"] = max_tokens
+
        query_args["prompt"] = prompt
        return (n, query_args)

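Meta Llama models on Bedrock take `max_gen_len` rather than `max_tokens`, so the generic kwarg is remapped just before the body is built. Roughly, with illustrative values:

    # Illustrative: what _create_body now emits for a Llama request.
    query_args = {"temperature": 0.7, "max_tokens": 512}
    max_tokens = query_args.pop("max_tokens", None)
    if max_tokens:
        query_args["max_gen_len"] = max_tokens
    query_args["prompt"] = "<formatted prompt>"
    # -> {"temperature": 0.7, "max_gen_len": 512, "prompt": "<formatted prompt>"}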
@@ -299,9 +311,4 @@ def _call_model(self, body: str) -> str:
            body=body,
        )
        response_body = json.loads(response["body"].read())
-        completion = response_body["generation"]
-
-        stop = "\n\n"
-        completion = completion.split(stop)[0]
-
-        return completion
+        return response_body["generation"]
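With the hard-coded "\n\n" stop trimming gone, AWSMeta returns the raw `generation` field and relies on the Llama 3 end-of-turn tokens added by `_format_prompt`. A hedged usage sketch; the model ID is an assumption, not confirmed by this diff, so substitute a Llama 3 variant enabled in your account:

    bedrock = dspy.Bedrock(region_name="us-west-2")
    # Model ID is illustrative only.
    llama3 = dspy.AWSMeta(bedrock, "meta.llama3-8b-instruct-v1:0")
    print(llama3("Write a haiku about the ocean.")[0])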