Skip to content

Commit b64c39d

Browse files
top_k and top_p transposed in vertexai (langchain-ai#5673)
Fix transposed properties in vertexai model Co-authored-by: Dev 2049 <[email protected]>
1 parent 3fb0e48 commit b64c39d

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

langchain/llms/vertexai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def _default_params(self) -> Dict[str, Any]:
4343
base_params = {
4444
"temperature": self.temperature,
4545
"max_output_tokens": self.max_output_tokens,
46-
"top_k": self.top_p,
47-
"top_p": self.top_k,
46+
"top_k": self.top_k,
47+
"top_p": self.top_p,
4848
}
4949
return {**base_params}
5050

tests/integration_tests/chat_models/test_vertexai.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
Your end-user credentials would be used to make the calls (make sure you've run
88
`gcloud auth login` first).
99
"""
10+
from unittest.mock import Mock, patch
11+
1012
import pytest
1113

1214
from langchain.chat_models import ChatVertexAI
@@ -86,3 +88,31 @@ def test_vertexai_single_call_failes_no_message() -> None:
8688
str(exc_info.value)
8789
== "You should provide at least one message to start the chat!"
8890
)
91+
92+
93+
def test_vertexai_args_passed() -> None:
94+
response_text = "Goodbye"
95+
user_prompt = "Hello"
96+
prompt_params = {
97+
"max_output_tokens": 1,
98+
"temperature": 10000.0,
99+
"top_k": 10,
100+
"top_p": 0.5,
101+
}
102+
103+
# Mock the library to ensure the args are passed correctly
104+
with patch(
105+
"vertexai.language_models._language_models.ChatSession.send_message"
106+
) as send_message:
107+
mock_response = Mock(text=response_text)
108+
send_message.return_value = mock_response
109+
110+
model = ChatVertexAI(**prompt_params)
111+
message = HumanMessage(content=user_prompt)
112+
response = model([message])
113+
114+
assert response.content == response_text
115+
send_message.assert_called_once_with(
116+
user_prompt,
117+
**prompt_params,
118+
)

0 commit comments

Comments
 (0)