@@ -28,6 +28,7 @@ def __init__(
         max_new_tokens: int,
         truncate_long_prompts: bool = False,
         input_output_ratio: int = 3,
+        batch_n: bool = True,
     ) -> None:
         """_summary_

@@ -40,6 +41,7 @@ def __init__(
             input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3.
             temperature (float, optional): _description_. Defaults to 0.0.
             truncate_long_prompts (bool, optional): If True, remove extremely long inputs to context. Defaults to False.
+            batch_n (bool, optional): If False, call the LM n times rather than batching; not all AWS models support the n parameter. Defaults to True.
         """
         super().__init__(model=model)
         # AWS doesn't have an equivalent of max_tokens so let's clarify
@@ -48,9 +50,10 @@ def __init__(
         self._max_new_tokens: int = max_new_tokens
         self._model_name: str = model
         self._truncate_long_prompt_prompts: bool = truncate_long_prompts
+        self._batch_n: bool = batch_n

         import boto3
-
+
         self.predictor = boto3.client(service_name, region_name=region_name)

     @abstractmethod
@@ -72,7 +75,7 @@ def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]:
         return query_kwargs

     @abstractmethod
-    def _call_model(self, body: str) -> str:
+    def _call_model(self, body: str) -> str | list[str]:
         """Call model, get generated input without the formatted prompt"""
         pass

@@ -82,7 +85,20 @@ def _extract_input_parameters(
     ) -> dict[str, str | float | int]:
         pass

-    def basic_request(self, prompt, **kwargs) -> str:
+    def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]:
+        body = self._create_body(formatted_prompt, **kwargs)
+        json_body = json.dumps(body)
+        llm_out: str | list[str] = self._call_model(json_body)
+        if isinstance(llm_out, str):
+            llm_out = llm_out.replace(formatted_prompt, "")
+        else:
+            llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out]
+        self.history.append(
+            {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}
+        )
+        return llm_out
+
+    def basic_request(self, prompt, **kwargs) -> str | list[str]:
         """Query the endpoint."""

         # Remove any texts that are too long
@@ -92,16 +108,28 @@ def basic_request(self, prompt, **kwargs) -> str:
             formatted_prompt = self._format_prompt(truncated_prompt)
         else:
             formatted_prompt = self._format_prompt((prompt))
-        body = self._create_body(formatted_prompt, **kwargs)
-        json_body: str = json.dumps(body)
-
-        generated: str = self._call_model(json_body)

-        self.history.append(
-            {"prompt": formatted_prompt, "response": generated, "kwargs": body}
-        )
+        llm_out: str | list[str]
+        if "n" in kwargs.keys():
+            if self._batch_n:
+                llm_out = self._simple_api_call(
+                    formatted_prompt=formatted_prompt, **kwargs
+                )
+            else:
+                n = kwargs.pop("n")
+                llm_out = []
+                for _ in range(0, n):
+                    generated: str | list[str] = self._simple_api_call(
+                        formatted_prompt=formatted_prompt, **kwargs
+                    )
+                    if isinstance(generated, str):
+                        llm_out.append(generated)
+                    else:
+                        raise TypeError("Error, list type was returned from LM call")
+        else:
+            llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs)

-        return generated.replace(formatted_prompt, "")
+        return llm_out

     def _estimate_tokens(self, text: str) -> int:
         return len(text) * CHARS2TOKENS
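
A minimal standalone sketch of the fallback behaviour this diff adds (illustrative only; the function name request_completions and the stub backends are assumptions, not part of the library): when batch_n is False, the n parameter is removed from the request and the model is queried once per completion instead of in a single batched call.

    # Hypothetical, self-contained illustration of the batch_n fallback; not the
    # library's API. The stub callables stand in for _simple_api_call.
    from typing import Callable, List


    def request_completions(
        call_batched: Callable[[str, int], List[str]],  # backend that accepts an n-style parameter
        call_single: Callable[[str], str],              # backend returning one completion per call
        prompt: str,
        n: int,
        batch_n: bool,
    ) -> List[str]:
        if batch_n:
            # One request asking the backend for n generations at once.
            return call_batched(prompt, n)
        # Fallback for models without an n parameter: n independent single calls.
        return [call_single(prompt) for _ in range(n)]


    print(request_completions(
        call_batched=lambda p, k: [f"{p} -> generation {i}" for i in range(k)],
        call_single=lambda p: f"{p} -> generation",
        prompt="Hello",
        n=3,
        batch_n=False,
    ))  # ['Hello -> generation', 'Hello -> generation', 'Hello -> generation']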