Commit bb5d0de

Merge pull request stanfordnlp#388 from JamesScharf/bedrock-args
Bedrock args hotfix
2 parents c8405aa + 4fd70be

File tree

2 files changed: +51 −12 lines changed


dsp/modules/aws_lm.py

Lines changed: 39 additions & 11 deletions
@@ -28,6 +28,7 @@ def __init__(
         max_new_tokens: int,
         truncate_long_prompts: bool = False,
         input_output_ratio: int = 3,
+        batch_n: bool = True,
     ) -> None:
         """_summary_

@@ -40,6 +41,7 @@ def __init__(
             input_output_ratio (int, optional): The rough size of the number of input tokens to output tokens in the worst case. Defaults to 3.
             temperature (float, optional): _description_. Defaults to 0.0.
             truncate_long_prompts (bool, optional): If True, remove extremely long inputs to context. Defaults to False.
+            batch_n (bool, optional): If False, call the LM N times rather than batching. Not all AWS models support the n parameter. Defaults to True.
         """
         super().__init__(model=model)
         # AWS doesn't have an equivalent of max_tokens so let's clarify
@@ -48,9 +50,10 @@ def __init__(
         self._max_new_tokens: int = max_new_tokens
         self._model_name: str = model
         self._truncate_long_prompt_prompts: bool = truncate_long_prompts
+        self._batch_n: bool = batch_n

         import boto3
-
+
         self.predictor = boto3.client(service_name, region_name=region_name)

     @abstractmethod
@@ -72,7 +75,7 @@ def _sanitize_kwargs(self, query_kwargs: dict[str, Any]) -> dict[str, Any]:
         return query_kwargs

     @abstractmethod
-    def _call_model(self, body: str) -> str:
+    def _call_model(self, body: str) -> str | list[str]:
         """Call model, get generated input without the formatted prompt"""
         pass


@@ -82,7 +85,20 @@ def _extract_input_parameters(
     ) -> dict[str, str | float | int]:
         pass

-    def basic_request(self, prompt, **kwargs) -> str:
+    def _simple_api_call(self, formatted_prompt: str, **kwargs) -> str | list[str]:
+        body = self._create_body(formatted_prompt, **kwargs)
+        json_body = json.dumps(body)
+        llm_out: str | list[str] = self._call_model(json_body)
+        if isinstance(llm_out, str):
+            llm_out = llm_out.replace(formatted_prompt, "")
+        else:
+            llm_out = [generated.replace(formatted_prompt, "") for generated in llm_out]
+        self.history.append(
+            {"prompt": formatted_prompt, "response": llm_out, "kwargs": body}
+        )
+        return llm_out
+
+    def basic_request(self, prompt, **kwargs) -> str | list[str]:
         """Query the endpoint."""

         # Remove any texts that are too long
@@ -92,16 +108,28 @@ def basic_request(self, prompt, **kwargs) -> str:
             formatted_prompt = self._format_prompt(truncated_prompt)
         else:
             formatted_prompt = self._format_prompt((prompt))
-        body = self._create_body(formatted_prompt, **kwargs)
-        json_body: str = json.dumps(body)
-
-        generated: str = self._call_model(json_body)

-        self.history.append(
-            {"prompt": formatted_prompt, "response": generated, "kwargs": body}
-        )
+        llm_out: str | list[str]
+        if "n" in kwargs.keys():
+            if self._batch_n:
+                llm_out = self._simple_api_call(
+                    formatted_prompt=formatted_prompt, **kwargs
+                )
+            else:
+                n: int = kwargs.pop("n")  # read n before stripping it from kwargs
+                llm_out = []
+                for _ in range(0, n):
+                    generated: str | list[str] = self._simple_api_call(
+                        formatted_prompt=formatted_prompt, **kwargs
+                    )
+                    if isinstance(generated, str):
+                        llm_out.append(generated)
+                    else:
+                        raise TypeError("Error, list type was returned from LM call")
+        else:
+            llm_out = self._simple_api_call(formatted_prompt=formatted_prompt, **kwargs)

-        return generated.replace(formatted_prompt, "")
+        return llm_out

     def _estimate_tokens(self, text: str) -> int:
         return len(text) * CHARS2TOKENS
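
The dispatch above is the heart of the hotfix: when a backend cannot batch, a request for n completions becomes n sequential calls with `n` stripped from the arguments before each one. Below is a minimal standalone sketch of that fan-out; the stubs and their names (`fake_single_call`, `fake_batched_call`, `request`) are illustrative stand-ins for the `_simple_api_call` / boto3 round trip, not part of the commit.

# Standalone sketch of the batch_n dispatch (all names here are hypothetical).
def fake_single_call(prompt: str, **kwargs) -> str:
    # Stand-in for one _simple_api_call round trip to the model.
    return f"completion(temperature={kwargs.get('temperature', 0.0)})"

def fake_batched_call(prompt: str, n: int, **kwargs) -> list[str]:
    # Stand-in for a backend that honors `n` in a single request body.
    return [f"completion {i}" for i in range(n)]

def request(prompt: str, batch_n: bool, **kwargs) -> str | list[str]:
    if "n" in kwargs:
        if batch_n:
            # Backend accepts `n`: pass it through in one request.
            return fake_batched_call(prompt, **kwargs)
        # Backend rejects `n`: strip it, then call once per completion.
        n = kwargs.pop("n")
        return [fake_single_call(prompt, **kwargs) for _ in range(n)]
    return fake_single_call(prompt, **kwargs)

print(request("hello", batch_n=False, n=2, temperature=0.7))
# ['completion(temperature=0.7)', 'completion(temperature=0.7)']

Either way the caller sees one `str | list[str]` result, which is why `_call_model` and `basic_request` both widen their return types in this commit.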

dsp/modules/bedrock.py

Lines changed: 12 additions & 1 deletion
@@ -30,7 +30,13 @@ def __init__(
             truncate_long_prompts=False,
             input_output_ratio=input_output_ratio,
             max_new_tokens=max_new_tokens,
+            batch_n=False,  # Bedrock does not support the `n` parameter
         )
+        self._validate_model(model)
+
+    def _validate_model(self, model: str) -> None:
+        if "claude" not in model.lower():
+            raise NotImplementedError("Only claude models are supported as of now")

     def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]:
         base_args: dict[str, Any] = {
@@ -41,7 +47,12 @@ def _create_body(self, prompt: str, **kwargs) -> dict[str, str | float]:
         query_args: dict[str, Any] = self._sanitize_kwargs(base_args)
         query_args["prompt"] = prompt
         # AWS Bedrock forbids these keys
-
+        if "max_tokens" in query_args.keys():
+            max_tokens: int = query_args["max_tokens"]
+            input_tokens: int = self._estimate_tokens(prompt)
+            max_tokens_to_sample: int = max_tokens - input_tokens
+            del query_args["max_tokens"]
+            query_args["max_tokens_to_sample"] = max_tokens_to_sample
         return query_args

     def _call_model(self, body: str) -> str:
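
The `max_tokens` rewrite converts a caller's total token budget into Anthropic's output-only `max_tokens_to_sample` by subtracting the estimated input length. A worked sketch of that arithmetic follows; the diff only shows `_estimate_tokens` as `len(text) * CHARS2TOKENS`, so the CHARS2TOKENS value of 0.25 (roughly four characters per token) is an assumption for illustration, not taken from the commit.

# Sketch of the max_tokens -> max_tokens_to_sample conversion above.
CHARS2TOKENS = 0.25  # assumed value (~4 chars/token); not shown in this diff

def estimate_tokens(text: str) -> int:
    # Mirrors _estimate_tokens: a rough character-count heuristic.
    return int(len(text) * CHARS2TOKENS)

def to_bedrock_args(query_args: dict, prompt: str) -> dict:
    # Claude on Bedrock takes an output budget, so the estimated
    # input tokens are subtracted from the caller's total budget.
    if "max_tokens" in query_args:
        total_budget = query_args.pop("max_tokens")
        query_args["max_tokens_to_sample"] = total_budget - estimate_tokens(prompt)
    return query_args

prompt = "x" * 400  # estimates to 100 input tokens
print(to_bedrock_args({"max_tokens": 600}, prompt))
# {'max_tokens_to_sample': 500}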
