Skip to content

Commit f11ac75

Browse files
committed
feat: Add tool_calls support to JdbcChatMemoryRepository
- Add tool_calls column to all database schemas (PostgreSQL, MySQL, SQL Server, HSQLDB) - Implement JSON serialization/deserialization for AssistantMessage.ToolCall objects - Update all dialect classes to include tool_calls in SELECT and INSERT queries - Add PostgreSQL JSONB type support with explicit casting (::jsonb) - Add comprehensive unit tests for tool calls functionality - Add integration tests for tool calls across all supported databases - Maintain backward compatibility with existing chat memory data This enhancement allows the JDBC chat memory repository to persist and retrieve tool call information from AI assistant messages, enabling full conversation context preservation including function calls and their metadata. Signed-off-by: astor-dev <[email protected]>
1 parent 3919204 commit f11ac75

File tree

12 files changed

+219
-17
lines changed

12 files changed

+219
-17
lines changed

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/HsqldbChatMemoryRepositoryDialect.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,20 @@
1818

1919
/**
2020
* HSQLDB-specific SQL dialect for chat memory repository.
21+
*
22+
* @author DoHoon Kim
23+
* @since 1.0.0
2124
*/
2225
public class HsqldbChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {
2326

2427
@Override
2528
public String getSelectMessagesSql() {
26-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY timestamp ASC";
29+
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY timestamp ASC";
2730
}
2831

2932
@Override
3033
public String getInsertMessageSql() {
31-
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, timestamp) VALUES (?, ?, ?, ?)";
34+
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, timestamp) VALUES (?, ?, ?, ?, ?)";
3235
}
3336

3437
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepository.java

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.sql.Timestamp;
2323
import java.time.Instant;
2424
import java.util.List;
25+
import java.util.Map;
2526
import java.util.concurrent.atomic.AtomicLong;
2627

2728
import javax.sql.DataSource;
@@ -36,6 +37,7 @@
3637
import org.springframework.ai.chat.messages.SystemMessage;
3738
import org.springframework.ai.chat.messages.ToolResponseMessage;
3839
import org.springframework.ai.chat.messages.UserMessage;
40+
import org.springframework.ai.model.ModelOptionsUtils;
3941
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
4042
import org.springframework.jdbc.core.JdbcTemplate;
4143
import org.springframework.jdbc.core.RowMapper;
@@ -53,6 +55,7 @@
5355
* @author Linar Abzaltdinov
5456
* @author Mark Pollack
5557
* @author Yanming Zhou
58+
* @author DoHoon Kim
5659
* @since 1.0.0
5760
*/
5861
public final class JdbcChatMemoryRepository implements ChatMemoryRepository {
@@ -124,7 +127,15 @@ public void setValues(PreparedStatement ps, int i) throws SQLException {
124127
ps.setString(1, this.conversationId);
125128
ps.setString(2, message.getText());
126129
ps.setString(3, message.getMessageType().name());
127-
ps.setTimestamp(4, new Timestamp(this.instantSeq.getAndIncrement()));
130+
131+
// Handle tool_calls column
132+
String toolCallsJson = null;
133+
if (message instanceof AssistantMessage assistantMessage && assistantMessage.hasToolCalls()) {
134+
toolCallsJson = ModelOptionsUtils.toJsonString(assistantMessage.getToolCalls());
135+
}
136+
ps.setString(4, toolCallsJson);
137+
138+
ps.setTimestamp(5, new Timestamp(this.instantSeq.getAndIncrement()));
128139
}
129140

130141
@Override
@@ -140,14 +151,25 @@ private static class MessageRowMapper implements RowMapper<Message> {
140151
public Message mapRow(ResultSet rs, int i) throws SQLException {
141152
var content = rs.getString(1);
142153
var type = MessageType.valueOf(rs.getString(2));
154+
var toolCallsJson = rs.getString(3);
143155

144156
return switch (type) {
145157
case USER -> new UserMessage(content);
146-
case ASSISTANT -> new AssistantMessage(content);
158+
case ASSISTANT -> {
159+
List<AssistantMessage.ToolCall> toolCalls = List.of();
160+
if (toolCallsJson != null && !toolCallsJson.trim().isEmpty()) {
161+
try {
162+
toolCalls = ModelOptionsUtils.OBJECT_MAPPER.readValue(toolCallsJson,
163+
ModelOptionsUtils.OBJECT_MAPPER.getTypeFactory()
164+
.constructCollectionType(List.class, AssistantMessage.ToolCall.class));
165+
}
166+
catch (Exception e) {
167+
logger.warn("Failed to deserialize tool calls JSON: {}", toolCallsJson, e);
168+
}
169+
}
170+
yield new AssistantMessage(content, Map.of(), toolCalls);
171+
}
147172
case SYSTEM -> new SystemMessage(content);
148-
// The content is always stored empty for ToolResponseMessages.
149-
// If we want to capture the actual content, we need to extend
150-
// AddBatchPreparedStatement to support it.
151173
case TOOL -> new ToolResponseMessage(List.of());
152174
};
153175
}

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/JdbcChatMemoryRepositoryDialect.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@
1616

1717
package org.springframework.ai.chat.memory.repository.jdbc;
1818

19-
import javax.sql.DataSource;
2019
import java.sql.Connection;
2120

21+
import javax.sql.DataSource;
22+
2223
/**
2324
* Abstraction for database-specific SQL for chat memory repository.
2425
*/

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/MysqlChatMemoryRepositoryDialect.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@
2020
* MySQL dialect for chat memory repository.
2121
*
2222
* @author Mark Pollack
23+
* @author DoHoon Kim
2324
* @since 1.0.0
2425
*/
2526
public class MysqlChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {
2627

2728
@Override
2829
public String getSelectMessagesSql() {
29-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
30+
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY `timestamp`";
3031
}
3132

3233
@Override
3334
public String getInsertMessageSql() {
34-
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, `timestamp`) VALUES (?, ?, ?, ?)";
35+
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, `timestamp`) VALUES (?, ?, ?, ?, ?)";
3536
}
3637

3738
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/PostgresChatMemoryRepositoryDialect.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@
2020
* Dialect for Postgres.
2121
*
2222
* @author Mark Pollack
23+
* @author DoHoon Kim
2324
* @since 1.0.0
2425
*/
2526
public class PostgresChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {
2627

2728
@Override
2829
public String getSelectMessagesSql() {
29-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY \"timestamp\"";
30+
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY \"timestamp\"";
3031
}
3132

3233
@Override
3334
public String getInsertMessageSql() {
34-
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, \"timestamp\") VALUES (?, ?, ?, ?)";
35+
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, \"timestamp\") VALUES (?, ?, ?, ?::jsonb, ?)";
3536
}
3637

3738
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/java/org/springframework/ai/chat/memory/repository/jdbc/SqlServerChatMemoryRepositoryDialect.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@
2020
* Dialect for SQL Server.
2121
*
2222
* @author Mark Pollack
23+
* @author DoHoon Kim
2324
* @since 1.0.0
2425
*/
2526
public class SqlServerChatMemoryRepositoryDialect implements JdbcChatMemoryRepositoryDialect {
2627

2728
@Override
2829
public String getSelectMessagesSql() {
29-
return "SELECT content, type FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
30+
return "SELECT content, type, tool_calls FROM SPRING_AI_CHAT_MEMORY WHERE conversation_id = ? ORDER BY [timestamp]";
3031
}
3132

3233
@Override
3334
public String getInsertMessageSql() {
34-
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, [timestamp]) VALUES (?, ?, ?, ?)";
35+
return "INSERT INTO SPRING_AI_CHAT_MEMORY (conversation_id, content, type, tool_calls, [timestamp]) VALUES (?, ?, ?, ?, ?)";
3536
}
3637

3738
@Override

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/resources/org/springframework/ai/chat/memory/repository/jdbc/schema-hsqldb.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ CREATE TABLE SPRING_AI_CHAT_MEMORY (
22
conversation_id VARCHAR(36) NOT NULL,
33
content LONGVARCHAR NOT NULL,
44
type VARCHAR(10) NOT NULL,
5+
tool_calls LONGVARCHAR NULL,
56
timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL
67
);
78

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/resources/org/springframework/ai/chat/memory/repository/jdbc/schema-mariadb.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
22
conversation_id VARCHAR(36) NOT NULL,
33
content TEXT NOT NULL,
44
type VARCHAR(10) NOT NULL,
5+
tool_calls LONGTEXT NULL,
56
`timestamp` TIMESTAMP NOT NULL,
67
CONSTRAINT TYPE_CHECK CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
78
);

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/resources/org/springframework/ai/chat/memory/repository/jdbc/schema-postgresql.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ CREATE TABLE IF NOT EXISTS SPRING_AI_CHAT_MEMORY (
22
conversation_id VARCHAR(36) NOT NULL,
33
content TEXT NOT NULL,
44
type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL')),
5+
tool_calls JSONB NULL,
56
"timestamp" TIMESTAMP NOT NULL
67
);
78

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/main/resources/org/springframework/ai/chat/memory/repository/jdbc/schema-sqlserver.sql

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ CREATE TABLE SPRING_AI_CHAT_MEMORY (
22
conversation_id VARCHAR(36) NOT NULL,
33
content NVARCHAR(MAX) NOT NULL,
44
type VARCHAR(10) NOT NULL,
5+
tool_calls NVARCHAR(MAX) NULL,
56
[timestamp] DATETIME2 NOT NULL DEFAULT SYSDATETIME(),
67
CONSTRAINT TYPE_CHECK CHECK (type IN ('USER', 'ASSISTANT', 'SYSTEM', 'TOOL'))
78
);

memory/repository/spring-ai-model-chat-memory-repository-jdbc/src/test/java/org/springframework/ai/chat/memory/repository/jdbc/AbstractJdbcChatMemoryRepositoryIT.java

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.sql.Timestamp;
2020
import java.util.List;
21+
import java.util.Map;
2122
import java.util.UUID;
2223
import java.util.stream.Collectors;
2324

@@ -77,10 +78,10 @@ void saveMessagesSingleMessage(String content, MessageType messageType) {
7778
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
7879
.from(this.jdbcTemplate.getDataSource());
7980
String selectSql = dialect.getSelectMessagesSql()
80-
.replace("content, type", "conversation_id, content, type, timestamp");
81+
.replace("content, type, tool_calls", "conversation_id, content, type, tool_calls, timestamp");
8182
var result = this.jdbcTemplate.queryForMap(selectSql, conversationId);
8283

83-
assertThat(result.size()).isEqualTo(4);
84+
assertThat(result.size()).isEqualTo(5);
8485
assertThat(result.get("conversation_id")).isEqualTo(conversationId);
8586
assertThat(result.get("content")).isEqualTo(message.getText());
8687
assertThat(result.get("type")).isEqualTo(messageType.name());
@@ -102,7 +103,7 @@ void saveMessagesMultipleMessages() {
102103
JdbcChatMemoryRepositoryDialect dialect = JdbcChatMemoryRepositoryDialect
103104
.from(this.jdbcTemplate.getDataSource());
104105
String selectSql = dialect.getSelectMessagesSql()
105-
.replace("content, type", "conversation_id, content, type, timestamp");
106+
.replace("content, type, tool_calls", "conversation_id, content, type, tool_calls, timestamp");
106107
var results = this.jdbcTemplate.queryForList(selectSql, conversationId);
107108

108109
assertThat(results).hasSize(messages.size());
@@ -186,6 +187,67 @@ void testMessageOrder() {
186187
"4-Fourth message");
187188
}
188189

190+
@Test
191+
void saveAndRetrieveAssistantMessageWithToolCalls() {
192+
String conversationId = UUID.randomUUID().toString();
193+
194+
// Create tool calls
195+
List<AssistantMessage.ToolCall> toolCalls = List.of(
196+
new AssistantMessage.ToolCall("call_1", "function", "get_weather", "{\"location\":\"Seoul\"}"),
197+
new AssistantMessage.ToolCall("call_2", "function", "get_time", "{\"timezone\":\"Asia/Seoul\"}"));
198+
199+
var assistantMessage = new AssistantMessage("I'll help you with that.", Map.of(), toolCalls);
200+
201+
this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage));
202+
203+
// Retrieve and verify
204+
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
205+
assertThat(retrievedMessages).hasSize(1);
206+
207+
Message retrievedMessage = retrievedMessages.get(0);
208+
assertThat(retrievedMessage).isInstanceOf(AssistantMessage.class);
209+
210+
AssistantMessage retrievedAssistantMessage = (AssistantMessage) retrievedMessage;
211+
assertThat(retrievedAssistantMessage.getText()).isEqualTo("I'll help you with that.");
212+
assertThat(retrievedAssistantMessage.hasToolCalls()).isTrue();
213+
assertThat(retrievedAssistantMessage.getToolCalls()).hasSize(2);
214+
215+
// Verify first tool call
216+
AssistantMessage.ToolCall firstToolCall = retrievedAssistantMessage.getToolCalls().get(0);
217+
assertThat(firstToolCall.id()).isEqualTo("call_1");
218+
assertThat(firstToolCall.type()).isEqualTo("function");
219+
assertThat(firstToolCall.name()).isEqualTo("get_weather");
220+
assertThat(firstToolCall.arguments()).isEqualTo("{\"location\":\"Seoul\"}");
221+
222+
// Verify second tool call
223+
AssistantMessage.ToolCall secondToolCall = retrievedAssistantMessage.getToolCalls().get(1);
224+
assertThat(secondToolCall.id()).isEqualTo("call_2");
225+
assertThat(secondToolCall.type()).isEqualTo("function");
226+
assertThat(secondToolCall.name()).isEqualTo("get_time");
227+
assertThat(secondToolCall.arguments()).isEqualTo("{\"timezone\":\"Asia/Seoul\"}");
228+
}
229+
230+
@Test
231+
void saveAndRetrieveAssistantMessageWithoutToolCalls() {
232+
String conversationId = UUID.randomUUID().toString();
233+
234+
var assistantMessage = new AssistantMessage("Simple response without tool calls.");
235+
236+
this.chatMemoryRepository.saveAll(conversationId, List.of(assistantMessage));
237+
238+
// Retrieve and verify
239+
List<Message> retrievedMessages = this.chatMemoryRepository.findByConversationId(conversationId);
240+
assertThat(retrievedMessages).hasSize(1);
241+
242+
Message retrievedMessage = retrievedMessages.get(0);
243+
assertThat(retrievedMessage).isInstanceOf(AssistantMessage.class);
244+
245+
AssistantMessage retrievedAssistantMessage = (AssistantMessage) retrievedMessage;
246+
assertThat(retrievedAssistantMessage.getText()).isEqualTo("Simple response without tool calls.");
247+
assertThat(retrievedAssistantMessage.hasToolCalls()).isFalse();
248+
assertThat(retrievedAssistantMessage.getToolCalls()).isEmpty();
249+
}
250+
189251
/**
190252
* Base configuration for all integration tests.
191253
*/

0 commit comments

Comments
 (0)