Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions bigframes/ml/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def predict(
max_output_tokens (int, default 128):
Maximum number of tokens that can be generated in the response. Specify a lower value for shorter responses and a higher value for longer responses.
A token may be smaller than a word. A token is approximately four characters. 100 tokens correspond to roughly 60-80 words.
Default 128. Possible values [1, 1024].
Default 128. For the 'text-bison' model, possible values are in the range [1, 1024]. For the 'text-bison-32k' model, possible values are in the range [1, 8196].
Please ensure that the specified value for max_output_tokens is within the appropriate range for the model being used.

top_k (int, default 40):
Top-k changes how the model selects tokens for output. A top-k of 1 means the selected token is the most probable among all tokens
Expand All @@ -184,12 +185,26 @@ def predict(
# Params reference: https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models
if temperature < 0.0 or temperature > 1.0:
raise ValueError(f"temperature must be [0.0, 1.0], but is {temperature}.")
if max_output_tokens not in range(1, 1025):

if (
self.model_name == _TEXT_GENERATOR_BISON_ENDPOINT
and max_output_tokens not in range(1, 1025)
):
raise ValueError(
f"max_output_token must be [1, 1024] for TextBison model, but is {max_output_tokens}."
)

if (
self.model_name == _TEXT_GENERATOR_BISON_32K_ENDPOINT
and max_output_tokens not in range(1, 8197)
):
raise ValueError(
f"max_output_token must be [1, 1024], but is {max_output_tokens}."
f"max_output_token must be [1, 8196] for TextBison 32k model, but is {max_output_tokens}."
)

if top_k not in range(1, 41):
raise ValueError(f"top_k must be [1, 40], but is {top_k}.")

if top_p < 0.0 or top_p > 1.0:
raise ValueError(f"top_p must be [0.0, 1.0], but is {top_p}.")

Expand Down