|
19 | 19 | import java.io.IOException;
|
20 | 20 | import java.time.Duration;
|
21 | 21 | import java.util.Set;
|
| 22 | +import java.util.function.Supplier; |
22 | 23 |
|
23 | 24 | import org.junit.jupiter.api.Test;
|
24 | 25 | import org.slf4j.Logger;
|
|
32 | 33 | import org.springframework.ai.chat.model.ChatModel;
|
33 | 34 | import org.springframework.ai.content.Media;
|
34 | 35 | import org.springframework.ai.model.tool.ToolCallingChatOptions;
|
| 36 | +import org.springframework.ai.tool.annotation.Tool; |
35 | 37 | import org.springframework.ai.tool.function.FunctionToolCallback;
|
36 | 38 | import org.springframework.beans.factory.annotation.Autowired;
|
37 | 39 | import org.springframework.boot.SpringBootConfiguration;
|
@@ -171,6 +173,61 @@ public record WeatherRequest(String location, String unit) {
|
171 | 173 | public record WeatherResponse(int temp, String unit) {
|
172 | 174 | }
|
173 | 175 |
|
| 176 | + // https://github.com/spring-projects/spring-ai/issues/1878 |
| 177 | + @Test |
| 178 | + void toolAnnotationWeatherForecastTest() { |
| 179 | + |
| 180 | + ChatClient chatClient = ChatClient.builder(this.chatModel).build(); |
| 181 | + |
| 182 | + String response = chatClient.prompt() |
| 183 | + .tools(new DummyWeatherForcastTools()) |
| 184 | + .user("Get current weather in Amsterdam") |
| 185 | + .call() |
| 186 | + .content(); |
| 187 | + |
| 188 | + assertThat(response).isNotEmpty(); |
| 189 | + assertThat(response).contains("20 degrees"); |
| 190 | + } |
| 191 | + |
| 192 | + public static class DummyWeatherForcastTools { |
| 193 | + |
| 194 | + @Tool(description = "Get the current weather forcast in Amsterdam") |
| 195 | + String getCurrentDateTime() { |
| 196 | + return "Weahter is hot and sunny wiht a temperature of 20 degrees"; |
| 197 | + } |
| 198 | + |
| 199 | + } |
| 200 | + |
| 201 | + // https://github.com/spring-projects/spring-ai/issues/1878 |
| 202 | + @Test |
| 203 | + void supplierBasedToolCalling() { |
| 204 | + |
| 205 | + ChatClient chatClient = ChatClient.builder(this.chatModel).build(); |
| 206 | + |
| 207 | + WeatherService.Response response = chatClient.prompt() |
| 208 | + .toolCallbacks(FunctionToolCallback.builder("weather", new WeatherService()) |
| 209 | + .description("Get the current weather") |
| 210 | + .inputType(Void.class) |
| 211 | + .build()) |
| 212 | + .user("Get current weather in Amsterdam") |
| 213 | + .call() |
| 214 | + .entity(WeatherService.Response.class); |
| 215 | + |
| 216 | + assertThat(response).isNotNull(); |
| 217 | + assertThat(response.temp()).isEqualTo(30.0); |
| 218 | + } |
| 219 | + |
| 220 | + public static class WeatherService implements Supplier<WeatherService.Response> { |
| 221 | + |
| 222 | + public record Response(double temp) { |
| 223 | + } |
| 224 | + |
| 225 | + public Response get() { |
| 226 | + return new Response(30.0); |
| 227 | + } |
| 228 | + |
| 229 | + } |
| 230 | + |
174 | 231 | @SpringBootConfiguration
|
175 | 232 | public static class Config {
|
176 | 233 |
|
|
0 commit comments