Open
Description
A user on Discord ran into problems when trying to modify the notebook https://github.com/deepset-ai/haystack-cookbook/blob/main/notebooks/chat_with_SQL_3_ways.ipynb.
We should update the example to use ChatPromptBuilder and OpenAIChatGenerator instead, as we want to shift to ChatGenerators.
The user later reported that the suggestion posted on Discord worked.
I recommend also adapting the SQLQuery component to accept List[ChatMessage] instead of List[str] as input. Something along the lines of the code you already have for filling the database, plus the following:
import re
import sqlite3
from typing import List, Optional

import pandas as pd
from haystack import component, Pipeline
from haystack.components.builders import ChatPromptBuilder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.dataclasses import ChatMessage
@component
class SQLQuery:
    """Haystack component that runs SQL queries against a SQLite database.

    Queries arrive as ``ChatMessage`` objects (e.g. the ``replies`` of a chat
    generator). The SQL in each message's text is executed with
    ``pandas.read_sql`` and the resulting DataFrame is returned in string form.
    """

    def __init__(self, sql_database: str):
        """Connect to the SQLite database at ``sql_database``.

        :param sql_database: Path passed to ``sqlite3.connect``.
        """
        # check_same_thread=False allows the connection to be used from threads
        # other than the one that created it.
        # NOTE(review): the SQL executed here is LLM-generated and therefore
        # untrusted; consider a read-only connection, e.g.
        # sqlite3.connect(f"file:{path}?mode=ro", uri=True), so a generated
        # statement cannot modify or drop data.
        self.connection = sqlite3.connect(sql_database, check_same_thread=False)

    @staticmethod
    def _extract_sql(text: Optional[str]) -> str:
        """Return the SQL statement contained in a message's ``text``.

        Chat models frequently wrap code in a Markdown fence (```sql ... ```),
        which SQLite cannot parse, so the fence and an optional language tag on
        the opening line are stripped. Unfenced text is returned stripped of
        surrounding whitespace.

        :raises ValueError: if ``text`` is ``None``, empty, or only whitespace.
        """
        if not text or not text.strip():
            raise ValueError("SQLQuery received a message without a SQL query in its text.")
        sql = text.strip()
        if sql.startswith("```") and sql.endswith("```") and len(sql) >= 6:
            inner = sql[3:-3]
            first_line, newline, rest = inner.partition("\n")
            # A lone lowercase word on the opening fence line (e.g. "sql") is a
            # language tag, not part of the query.
            if newline and re.fullmatch(r"[a-z0-9_+-]*", first_line.strip()):
                inner = rest
            sql = inner.strip()
        return sql

    @component.output_types(results=List[str], queries=List[str])
    def run(self, queries: List[ChatMessage]):
        """Execute the SQL carried by each message.

        :param queries: Messages whose text is a SQL query, optionally wrapped in
            a Markdown code fence.
        :returns: Dict with ``results`` (one stringified DataFrame per message, in
            input order) and ``queries`` (the original text of each message).
        :raises ValueError: if a message carries no text.
        """
        results = []
        for query in queries:
            sql_query = self._extract_sql(query.text)
            result = pd.read_sql(sql_query, self.connection)
            results.append(f"{result}")
        return {"results": results, "queries": [msg.text for msg in queries]}
# Comma-separated column names, interpolated into the prompt as {{columns}}.
# NOTE(review): `df` is not defined in this snippet; presumably it is the
# DataFrame used earlier to populate 'absenteeism.db' — confirm.
columns = ", ".join(df.columns.to_list())

# Components: render chat messages -> generate SQL with the LLM -> execute it.
prompt_builder = ChatPromptBuilder()
sql_query = SQLQuery('absenteeism.db')
llm = OpenAIChatGenerator(model="gpt-4o-mini")

pipeline = Pipeline()
for component_name, component_instance in (
    ("prompt_builder", prompt_builder),
    ("llm", llm),
    ("sql_querier", sql_query),
):
    pipeline.add_component(component_name, component_instance)

# Wire prompt_builder -> llm -> sql_querier.
for sender, receiver in (
    ("prompt_builder.prompt", "llm.messages"),
    ("llm.replies", "sql_querier.queries"),
):
    pipeline.connect(sender, receiver)
# Chat template: a fixed system instruction plus a user turn whose Jinja
# placeholders ({{question}}, {{columns}}) ChatPromptBuilder fills from
# `template_variables` at run time.
system_message = ChatMessage.from_system("You are a helpful assistant that generates SQL queries based on natural language questions.")
# The prompt asks for bare SQL: the reply is executed verbatim by SQLQuery, and
# explanations or Markdown fences would make the statement unparseable.
user_message = ChatMessage.from_user("""Please generate an SQL query. The query should answer the following Question: {{question}};
The query is to be answered for the table called 'absenteeism' with the following
Columns: {{columns}};
Return only the SQL query, without any explanation or Markdown formatting.
Answer:""")

result = pipeline.run(
    data={
        "prompt_builder": {
            "template": [system_message, user_message],
            "template_variables": {
                "question": "On which days of the week does the average absenteeism time exceed 4 hours?",
                "columns": columns,
            },
        }
    }
)
# SQLQuery returns one stringified DataFrame per LLM reply; show the first.
print(result["sql_querier"]["results"][0])
Metadata
Metadata
Assignees
Labels
No labels