Skip to content

test: Check that _convert_chat_completion_chunk_to_streaming_chunk works for MistralChatGenerator #1953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 187 additions & 54 deletions integrations/mistral/tests/test_mistral_chat_generator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import datetime
from unittest.mock import patch
from unittest.mock import ANY, patch

import pytest
import pytz
Expand All @@ -11,12 +11,27 @@
from haystack.tools import Tool
from haystack.utils.auth import Secret
from openai import OpenAIError
from openai.types.chat import ChatCompletion, ChatCompletionMessage
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChoiceChunk
from openai.types.chat.chat_completion_chunk import ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction
from openai.types.completion_usage import CompletionUsage

from haystack_integrations.components.generators.mistral.chat.chat_generator import MistralChatGenerator


class CollectorCallback:
    """
    Test helper usable as a streaming callback: it records every chunk it is called with.
    """

    def __init__(self):
        # Chunks are kept in arrival order so a test can inspect them after the run.
        self.chunks: list[StreamingChunk] = []

    def __call__(self, chunk: StreamingChunk) -> None:
        """Store the received chunk at the end of ``self.chunks``."""
        self.chunks.append(chunk)


@pytest.fixture
def chat_messages():
return [
Expand Down Expand Up @@ -179,6 +194,138 @@ def test_from_dict_fail_wo_env_var(self, monkeypatch):
with pytest.raises(ValueError, match="None of the .* environment variables are set"):
MistralChatGenerator.from_dict(data)

def test_handle_stream_response(self):
    """
    Pass two hand-built Mistral chunks through ``_handle_stream_response`` and check both the
    chunks forwarded to the streaming callback and the first reply that is returned.
    """
    chunk_id = "76535283139540de943bc2036121d4c5"
    created_at = 1750076261
    model_name = "mistral-small-latest"
    token_usage = {"completion_tokens": 35, "prompt_tokens": 77, "total_tokens": 112}

    def make_tool_call_deltas():
        # Build fresh objects on every call so the expected value never aliases the input.
        return [
            ChoiceDeltaToolCall(
                index=0,
                id="FL1FFlqUG",
                function=ChoiceDeltaToolCallFunction(arguments='{"city": "Paris"}', name="weather"),
            ),
            ChoiceDeltaToolCall(
                index=1,
                id="xSuhp66iB",
                function=ChoiceDeltaToolCallFunction(arguments='{"city": "Berlin"}', name="weather"),
            ),
        ]

    # First chunk: empty assistant content, no finish reason.
    opening_chunk = ChatCompletionChunk(
        id=chunk_id,
        choices=[ChoiceChunk(delta=ChoiceDelta(content="", role="assistant"), index=0)],
        created=created_at,
        model=model_name,
        object="chat.completion.chunk",
    )
    # Second chunk: both tool calls at once, plus the finish reason and the token usage.
    tool_call_chunk = ChatCompletionChunk(
        id=chunk_id,
        choices=[
            ChoiceChunk(
                delta=ChoiceDelta(tool_calls=make_tool_call_deltas()),
                finish_reason="tool_calls",
                index=0,
            )
        ],
        created=created_at,
        model=model_name,
        object="chat.completion.chunk",
        usage=CompletionUsage(**token_usage),
    )

    collector = CollectorCallback()
    generator = MistralChatGenerator(api_key=Secret.from_token("test-api-key"))
    replies = generator._handle_stream_response(  # type: ignore
        [opening_chunk, tool_call_chunk], callback=collector
    )
    result = replies[0]

    # The callback must have seen exactly two chunks:
    # the initial empty-content chunk and the chunk carrying the tool calls.
    assert len(collector.chunks) == 2
    first_streamed, second_streamed = collector.chunks

    expected_first = StreamingChunk(
        content="",
        meta={
            "model": model_name,
            "index": 0,
            "tool_calls": None,
            "finish_reason": None,
            "received_at": ANY,
        },
        # TODO Uncomment once new haystack version is released
        # component_info=ComponentInfo(
        #     type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
        #     name=None,
        # ),
    )
    assert first_streamed == expected_first

    expected_second = StreamingChunk(
        content="",
        meta={
            "model": model_name,
            "index": 0,
            "tool_calls": make_tool_call_deltas(),
            "finish_reason": "tool_calls",
            "received_at": ANY,
            # TODO Uncomment once new haystack version is released
            # "usage": {
            #     "completion_tokens": 35,
            #     "prompt_tokens": 77,
            #     "total_tokens": 112,
            #     "completion_tokens_details": None,
            #     "prompt_tokens_details": None,
            # },
        },
        # TODO Uncomment once new haystack version is released
        # component_info=ComponentInfo(
        #     type="haystack_integrations.components.generators.mistral.chat.chat_generator.MistralChatGenerator",
        #     name=None,
        # ),
        # index=0,
        # tool_calls=[
        #     ToolCallDelta(index=0, tool_name="weather", arguments='{"city": "Paris"}', id="FL1FFlqUG"),
        #     ToolCallDelta(index=1, tool_name="weather", arguments='{"city": "Berlin"}', id="xSuhp66iB"),
        # ],
        # start=True
    )
    assert second_streamed == expected_second

    # The reply carries tool calls only, so its text is expected to be None.
    assert result.text is None

    # Both tool calls must come back, in order, with their arguments parsed into dicts.
    observed_calls = [(call.id, call.tool_name, call.arguments) for call in result.tool_calls]
    assert observed_calls == [
        ("FL1FFlqUG", "weather", {"city": "Paris"}),
        ("xSuhp66iB", "weather", {"city": "Berlin"}),
    ]

    # Meta information attached to the reply.
    expected_meta = {"model": model_name, "finish_reason": "tool_calls", "index": 0}
    for key, expected_value in expected_meta.items():
        assert result.meta[key] == expected_value
    assert result.meta["completion_start_time"] is not None
    assert result.meta["usage"] == {
        **token_usage,
        "completion_tokens_details": None,
        "prompt_tokens_details": None,
    }

def test_run(self, chat_messages, mock_chat_completion, monkeypatch): # noqa: ARG002
monkeypatch.setenv("MISTRAL_API_KEY", "fake-api-key")
component = MistralChatGenerator()
Expand Down Expand Up @@ -291,42 +438,44 @@ def test_live_run_with_tools_and_response(self, tools):
"""
Integration test that the MistralChatGenerator component can run with tools and get a response.
"""
initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
initial_messages = [ChatMessage.from_user("What's the weather like in Paris and Berlin?")]
component = MistralChatGenerator(tools=tools)
results = component.run(messages=initial_messages, generation_kwargs={"tool_choice": "any"})

assert len(results["replies"]) > 0, "No replies received"
assert len(results["replies"]) == 1

# Find the message with tool calls
tool_message = None
for message in results["replies"]:
if message.tool_call:
tool_message = message
break

assert tool_message is not None, "No message with tool call found"
assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"

tool_call = tool_message.tool_call
assert tool_call.id, "Tool call does not contain value for 'id' key"
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
tool_message = results["replies"][0]

assert isinstance(tool_message, ChatMessage)
tool_calls = tool_message.tool_calls
assert len(tool_calls) == 2
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)

for tool_call in tool_calls:
assert tool_call.id is not None
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name == "weather"

arguments = [tool_call.arguments for tool_call in tool_calls]
assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
assert tool_message.meta["finish_reason"] == "tool_calls"

new_messages = [
initial_messages[0],
tool_message,
ChatMessage.from_tool(tool_result="22° C", origin=tool_call),
ChatMessage.from_tool(tool_result="22° C and sunny", origin=tool_calls[0]),
ChatMessage.from_tool(tool_result="16° C and windy", origin=tool_calls[1]),
]
# Pass the tool result to the model to get the final response
results = component.run(new_messages)

assert len(results["replies"]) == 1
final_message = results["replies"][0]
assert not final_message.tool_call
assert final_message.is_from(ChatRole.ASSISTANT)
assert len(final_message.text) > 0
assert "paris" in final_message.text.lower()
assert "berlin" in final_message.text.lower()

@pytest.mark.skipif(
not os.environ.get("MISTRAL_API_KEY", None),
Expand All @@ -337,45 +486,29 @@ def test_live_run_with_tools_streaming(self, tools):
"""
Integration test that the MistralChatGenerator component can run with tools and streaming.
"""

class Callback:
def __init__(self):
self.responses = ""
self.counter = 0
self.tool_calls = []

def __call__(self, chunk: StreamingChunk) -> None:
self.counter += 1
if chunk.content:
self.responses += chunk.content
if chunk.meta.get("tool_calls"):
self.tool_calls.extend(chunk.meta["tool_calls"])

callback = Callback()
component = MistralChatGenerator(tools=tools, streaming_callback=callback)
component = MistralChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)
results = component.run(
[ChatMessage.from_user("What's the weather like in Paris?")], generation_kwargs={"tool_choice": "any"}
[ChatMessage.from_user("What's the weather like in Paris and Berlin?")],
generation_kwargs={"tool_choice": "any"},
)

assert len(results["replies"]) > 0, "No replies received"
assert callback.counter > 1, "Streaming callback was not called multiple times"
assert callback.tool_calls, "No tool calls received in streaming"
assert len(results["replies"]) == 1

# Find the message with tool calls
tool_message = None
for message in results["replies"]:
if message.tool_call:
tool_message = message
break

assert tool_message is not None, "No message with tool call found"
assert isinstance(tool_message, ChatMessage), "Tool message is not a ChatMessage instance"
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT), "Tool message is not from the assistant"

tool_call = tool_message.tool_call
assert tool_call.id, "Tool call does not contain value for 'id' key"
assert tool_call.tool_name == "weather"
assert tool_call.arguments == {"city": "Paris"}
tool_message = results["replies"][0]

assert isinstance(tool_message, ChatMessage)
tool_calls = tool_message.tool_calls
assert len(tool_calls) == 2
assert ChatMessage.is_from(tool_message, ChatRole.ASSISTANT)

for tool_call in tool_calls:
assert tool_call.id is not None
assert isinstance(tool_call, ToolCall)
assert tool_call.tool_name == "weather"

arguments = [tool_call.arguments for tool_call in tool_calls]
assert sorted(arguments, key=lambda x: x["city"]) == [{"city": "Berlin"}, {"city": "Paris"}]
assert tool_message.meta["finish_reason"] == "tool_calls"

@pytest.mark.skipif(
Expand Down
Loading