Skip to content

Commit a48e608

Browse files
committed
Fix formatting and add API key for embeddings
Signed-off-by: Filip Hrisafov <[email protected]>
1 parent 08510d6 commit a48e608

File tree

2 files changed

+76
-12
lines changed

2 files changed

+76
-12
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,17 @@ public ResponseEntity<ChatCompletion> chatCompletionEntity(ChatCompletionRequest
175175
Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false.");
176176
Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null.");
177177

178-
return this.restClient.post().uri(this.completionsPath).headers(headers -> {
179-
headers.addAll(additionalHttpHeader);
180-
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
181-
headers.setBearerAuth(this.apiKey.getValue());
182-
}
183-
}).body(chatRequest).retrieve().toEntity(ChatCompletion.class);
178+
// @formatter:off
179+
return this.restClient.post()
180+
.uri(this.completionsPath)
181+
.headers(headers -> {
182+
headers.addAll(additionalHttpHeader);
183+
addDefaultHeadersIfMissing(headers);
184+
})
185+
.body(chatRequest)
186+
.retrieve()
187+
.toEntity(ChatCompletion.class);
188+
// @formatter:on
184189
}
185190

186191
/**
@@ -209,12 +214,13 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
209214

210215
AtomicBoolean isInsideTool = new AtomicBoolean(false);
211216

212-
return this.webClient.post().uri(this.completionsPath).headers(headers -> {
213-
headers.addAll(additionalHttpHeader);
214-
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
215-
headers.setBearerAuth(this.apiKey.getValue());
216-
}
217-
})
217+
// @formatter:off
218+
return this.webClient.post()
219+
.uri(this.completionsPath)
220+
.headers(headers -> {
221+
headers.addAll(additionalHttpHeader);
222+
addDefaultHeadersIfMissing(headers);
223+
}) // @formatter:on
218224
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
219225
.retrieve()
220226
.bodyToFlux(String.class)
@@ -288,13 +294,20 @@ public <T> ResponseEntity<EmbeddingList<Embedding>> embeddings(EmbeddingRequest<
288294

289295
return this.restClient.post()
290296
.uri(this.embeddingsPath)
297+
.headers(this::addDefaultHeadersIfMissing)
291298
.body(embeddingRequest)
292299
.retrieve()
293300
.toEntity(new ParameterizedTypeReference<>() {
294301

295302
});
296303
}
297304

305+
private void addDefaultHeadersIfMissing(HttpHeaders headers) {
306+
if (!headers.containsKey(HttpHeaders.AUTHORIZATION) && !(this.apiKey instanceof NoopApiKey)) {
307+
headers.setBearerAuth(this.apiKey.getValue());
308+
}
309+
}
310+
298311
// Package-private getters for mutate/copy
299312
String getBaseUrl() {
300313
return this.baseUrl;

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,57 @@ void dynamicApiKeyWebClientWithAdditionalAuthorizationHeader() throws Interrupte
372372
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer additional-key");
373373
}
374374

375+
@Test
376+
void dynamicApiKeyRestClientEmbeddings() throws InterruptedException {
377+
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
378+
OpenAiApi api = OpenAiApi.builder()
379+
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
380+
.baseUrl(mockWebServer.url("/").toString())
381+
.build();
382+
383+
MockResponse mockResponse = new MockResponse().setResponseCode(200)
384+
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
385+
.setBody("""
386+
{
387+
"object": "list",
388+
"data": [
389+
{
390+
"object": "embedding",
391+
"index": 0,
392+
"embedding": [
393+
-0.005540426,
394+
0.0047363234,
395+
-0.015009919,
396+
-0.027093535,
397+
-0.015173893,
398+
0.015173893,
399+
-0.017608276
400+
]
401+
}
402+
],
403+
"model": "text-embedding-ada-002-v2",
404+
"usage": {
405+
"prompt_tokens": 2,
406+
"total_tokens": 2
407+
}
408+
}
409+
""");
410+
mockWebServer.enqueue(mockResponse);
411+
mockWebServer.enqueue(mockResponse);
412+
413+
OpenAiApi.EmbeddingRequest<String> request = new OpenAiApi.EmbeddingRequest<>("Hello world");
414+
ResponseEntity<OpenAiApi.EmbeddingList<OpenAiApi.Embedding>> response = api.embeddings(request);
415+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
416+
RecordedRequest recordedRequest = mockWebServer.takeRequest();
417+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
418+
419+
response = api.embeddings(request);
420+
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
421+
422+
recordedRequest = mockWebServer.takeRequest();
423+
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
424+
}
425+
375426
}
376427

377428
}

0 commit comments

Comments
 (0)