From 9a1823eb1032a91cc3cd57c590a70af910282e8d Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Wed, 9 Apr 2025 15:34:05 -0400 Subject: [PATCH] [ML] Bedrock Cohere Task Settings Support (#126493) Add support for Cohere Task Settings and Truncate, through the Amazon Bedrock provider integration. Task Settings can now be passed both during Inference endpoint creation and Inference POST requests. Close #126156 --- docs/changelog/126493.yaml | 6 + .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 10 +- ...nBedrockCohereEmbeddingsRequestEntity.java | 13 +- .../AmazonBedrockEmbeddingsEntityFactory.java | 2 +- .../AmazonBedrockEmbeddingsRequest.java | 3 + .../amazonbedrock/AmazonBedrockConstants.java | 2 + .../amazonbedrock/AmazonBedrockService.java | 12 ++ .../AmazonBedrockEmbeddingsModel.java | 24 +-- .../AmazonBedrockEmbeddingsTaskSettings.java | 98 +++++++++++ .../services/cohere/CohereTruncation.java | 3 + .../CohereEmbeddingsTaskSettings.java | 3 +- ...ockCohereEmbeddingsRequestEntityTests.java | 32 +++- .../AmazonBedrockServiceTests.java | 164 +++++++++++------- .../AmazonBedrockEmbeddingsModelTests.java | 86 +++++++-- ...zonBedrockEmbeddingsTaskSettingsTests.java | 112 ++++++++++++ 16 files changed, 469 insertions(+), 102 deletions(-) create mode 100644 docs/changelog/126493.yaml create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java diff --git a/docs/changelog/126493.yaml b/docs/changelog/126493.yaml new file mode 100644 index 0000000000000..84a54b1058827 --- /dev/null +++ b/docs/changelog/126493.yaml @@ -0,0 +1,6 @@ +pr: 126493 +summary: Bedrock Cohere Task Settings Support +area: Machine Learning +type: enhancement +issues: + - 126156 diff 
--git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 71afb52902f35..bd12813ee4b7c 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -201,6 +201,7 @@ static TransportVersion def(int id) { public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14); public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15); public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); + public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 1089ca101bef6..baefccae4fc84 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -41,6 +41,7 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import 
org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; @@ -173,8 +174,13 @@ private static void addAmazonBedrockNamedWriteables(List input, @Nullable InputType inputType) implements ToXContentObject { +public record AmazonBedrockCohereEmbeddingsRequestEntity( + List input, + @Nullable InputType inputType, + AmazonBedrockEmbeddingsTaskSettings taskSettings +) implements ToXContentObject { private static final String TEXTS_FIELD = "texts"; private static final String INPUT_TYPE_FIELD = "input_type"; @@ -26,9 +31,11 @@ public record AmazonBedrockCohereEmbeddingsRequestEntity(List input, @Nu private static final String SEARCH_QUERY = "search_query"; private static final String CLUSTERING = "clustering"; private static final String CLASSIFICATION = "classification"; + private static final String TRUNCATE = "truncate"; public AmazonBedrockCohereEmbeddingsRequestEntity { Objects.requireNonNull(input); + Objects.requireNonNull(taskSettings); } @Override @@ -43,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT); } + if (taskSettings.cohereTruncation() != null) { + builder.field(TRUNCATE, taskSettings.cohereTruncation().name()); + } + builder.endObject(); return builder; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java index b28f1856506fd..e1ca4a51a913d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java @@ -39,7 +39,7 @@ public static ToXContent createEntity( return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); } case COHERE -> { - return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType); + return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings()); } default -> { return null; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java index 210d1579f714f..3a33962aaf2e9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java @@ -76,6 +76,9 @@ protected void executeRequest(AmazonBedrockBaseClient client) { @Override public Request truncate() { + if (provider == AmazonBedrockProvider.COHERE) { + return this; // Cohere has its own truncation logic + } var truncatedInput = truncator.truncate(truncationResult.input()); return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java index 1755dac2ac13f..b9e3a237a3cc6 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java @@ -19,6 +19,8 @@ public class AmazonBedrockConstants { public static final String TOP_K_FIELD = "top_k"; public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens"; + public static final String TRUNCATE_FIELD = "truncate"; + public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0; public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 1e868bead3ee7..520b5c6b91549 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -303,6 +303,7 @@ private static AmazonBedrockModel createModel( context ); checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider()); + checkTaskSettingsForTextEmbeddingModel(model); return model; } case COMPLETION -> { @@ -368,6 +369,17 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide } } + private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) { + if (model.provider() != AmazonBedrockProvider.COHERE && model.getTaskSettings().cohereTruncation() != null) { + throw new ElasticsearchStatusException( + "The [{}] task type for provider [{}] does not allow [truncate] field", + RestStatus.BAD_REQUEST, + TaskType.TEXT_EMBEDDING, + model.provider() + ); + } + } + private static void 
checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) { var taskSettings = model.getTaskSettings(); if (taskSettings.topK() != null) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java index 30703f584739d..f7874304b457c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java @@ -7,14 +7,11 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.ServiceSettings; -import org.elasticsearch.inference.TaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { if (taskSettings != null && taskSettings.isEmpty() == false) { - // no task settings allowed - var validationException = new ValidationException(); - validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings"); - throw validationException; + var 
updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings); + return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings); } return embeddingsModel; @@ -52,7 +47,7 @@ public AmazonBedrockEmbeddingsModel( taskType, service, AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context), - new EmptyTaskSettings(), + AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings), chunkingSettings, AwsSecretSettings.fromMap(secretSettings) ); @@ -63,12 +58,12 @@ public AmazonBedrockEmbeddingsModel( TaskType taskType, String service, AmazonBedrockEmbeddingsServiceSettings serviceSettings, - TaskSettings taskSettings, + AmazonBedrockEmbeddingsTaskSettings taskSettings, ChunkingSettings chunkingSettings, AwsSecretSettings secrets ) { super( - new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings), + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings), new ModelSecrets(secrets) ); } @@ -77,6 +72,10 @@ public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings super(model, serviceSettings); } + public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) { + super(model, taskSettings); + } + @Override public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { return creator.create(this, taskSettings); @@ -86,4 +85,9 @@ public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map map) { + if (map == null || map.isEmpty()) { + return EMPTY; + } + + ValidationException validationException = new ValidationException(); + + var cohereTruncation = extractOptionalEnum( + map, + TRUNCATE_FIELD, + ModelConfigurations.TASK_SETTINGS, + CohereTruncation::fromString, + CohereTruncation.ALL, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + 
} + + return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation); + } + + public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException { + this(in.readOptionalEnum(CohereTruncation.class)); + } + + @Override + public boolean isEmpty() { + return cohereTruncation() == null; + } + + @Override + public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map newSettings) { + var newTaskSettings = fromMap(new HashMap<>(newSettings)); + + return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation())); + } + + private static T firstNonNullOrNull(T first, T second) { + return first != null ? first : second; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS_8_19; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalEnum(cohereTruncation()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (cohereTruncation != null) { + builder.field(TRUNCATE_FIELD, cohereTruncation); + } + return builder.endObject(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java index e7c9a0247bb1a..16164b2d887fe 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereTruncation.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.services.cohere; +import java.util.EnumSet; import java.util.Locale; /** @@ -31,6 +32,8 @@ public enum CohereTruncation { */ END; + public 
static final EnumSet ALL = EnumSet.allOf(CohereTruncation.class); + @Override public String toString() { return name().toLowerCase(Locale.ROOT); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java index 88bb50def78fe..09d78708b688d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/embeddings/CohereEmbeddingsTaskSettings.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; -import java.util.EnumSet; import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -63,7 +62,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map map) { TRUNCATE, ModelConfigurations.TASK_SETTINGS, CohereTruncation::fromString, - EnumSet.allOf(CohereTruncation.class), + CohereTruncation.ALL, validationException ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java index 613c67f0b282f..b2b8714413329 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java @@ -10,6 +10,9 @@ import 
org.elasticsearch.inference.InputType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import java.io.IOException; import java.util.List; @@ -18,23 +21,46 @@ public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase { public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.CLASSIFICATION); + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + InputType.CLASSIFICATION, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"classification\"}")); } public void testRequestEntity_GeneratesExpectedJsonBody_WithInternalInputType() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), InputType.INTERNAL_SEARCH); + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + InputType.INTERNAL_SEARCH, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_query\"}")); } public void testRequestEntity_GeneratesExpectedJsonBody_WithoutInputType() throws IOException { - var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input"), null); + var entity = new 
AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + null, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); var builder = new AmazonBedrockJsonBuilder(entity); var result = builder.getStringContent(); assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}")); } + + public void testRequestEntity_GeneratesExpectedJsonBody_WithCohereTruncation() throws IOException { + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity( + List.of("test input"), + null, + new AmazonBedrockEmbeddingsTaskSettings(CohereTruncation.START) + ); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\",\"truncate\":\"START\"}")); + } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index ac301c19ba69f..0d37859cb3690 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionTestUtils; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.bytes.BytesArray; @@ -50,6 +51,8 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; import 
org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettingsTests; +import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; import org.hamcrest.CoreMatchers; import org.hamcrest.Matchers; @@ -104,7 +107,7 @@ public void shutdown() throws IOException { public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -114,7 +117,7 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOExcept var secretSettings = (AwsSecretSettings) model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -129,15 +132,62 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOExcept } } - public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + public void testParseRequestConfig_CreatesACohereModel() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, 
instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); - } + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.modelId(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.COHERE)); + var secretSettings = (AwsSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey().toString(), is("access")); + assertThat(secretSettings.secretKey().toString(), is("secret")); + }); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "cohere", null, null, null, null), + AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + modelVerificationListener ); + } + } + + public void testParseRequestConfig_CohereSettingsWithNoCohereModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("The [text_embedding] task type for provider [amazontitan] does not allow [truncate] field") + ); + }); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + AmazonBedrockEmbeddingsTaskSettingsTests.mutableMap("truncate", CohereTruncation.START), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + 
modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); + }); service.parseRequestConfig( "id", @@ -246,13 +296,10 @@ public void testGetConfiguration() throws Exception { public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); + }); service.parseRequestConfig( "id", @@ -269,13 +316,10 @@ public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOExcepti public void testCreateModel_TopKParameter_NotAvailable() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); - } - ); + ActionListener 
modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); + }); service.parseRequestConfig( "id", @@ -300,16 +344,13 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws I config.put("extra_key", "value"); - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ElasticsearchStatusException.class)); - assertThat( - exception.getMessage(), - is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") - ); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, modelVerificationListener); } @@ -322,9 +363,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMa var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret")); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -346,9 +385,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - ActionListener 
modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -370,9 +407,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); - ActionListener modelVerificationListener = ActionListener.wrap((model) -> { - fail("Expected exception, but got model: " + model); - }, e -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(e -> { assertThat(e, instanceOf(ElasticsearchStatusException.class)); assertThat( e.getMessage(), @@ -386,7 +421,7 @@ public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap public void testParseRequestConfig_MovesModel() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -396,7 +431,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { var secretSettings = (AwsSecretSettings) model.getSecretSettings(); assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -413,7 +448,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsProvided() throws IOException { try (var service = 
createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -424,7 +459,7 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChun assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -442,7 +477,7 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChun public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChunkingSettingsNotProvided() throws IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap(model -> { + ActionListener modelVerificationListener = ActionTestUtils.assertNoFailureListener(model -> { assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); @@ -453,7 +488,7 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChun assertThat(secretSettings.accessKey().toString(), is("access")); assertThat(secretSettings.secretKey().toString(), is("secret")); assertThat(model.getConfigurations().getChunkingSettings(), instanceOf(ChunkingSettings.class)); - }, exception -> fail("Unexpected exception: " + exception)); + }); service.parseRequestConfig( "id", @@ -470,13 +505,10 @@ public void testParseRequestConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenChun public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws 
IOException { try (var service = createAmazonBedrockService()) { - ActionListener modelVerificationListener = ActionListener.wrap( - model -> fail("Expected exception, but got model: " + model), - exception -> { - assertThat(exception, instanceOf(ValidationException.class)); - assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); - } - ); + ActionListener modelVerificationListener = ActionTestUtils.assertNoSuccessListener(exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); + }); service.parseRequestConfig( "id", @@ -496,7 +528,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -524,7 +556,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddings var persistedConfig = getPersistedConfigMap( settingsMap, - new HashMap(Map.of()), + new HashMap<>(Map.of()), createRandomChunkingSettingsMap(), secretSettingsMap ); @@ -606,7 +638,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, 
new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfigWithSecrets( @@ -634,7 +666,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); secretSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -660,7 +692,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.secrets().put("extra_key", "value"); var model = service.parsePersistedConfigWithSecrets( @@ -688,7 +720,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe settingsMap.put("extra_key", "value"); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfigWithSecrets( "id", @@ -768,7 +800,7 @@ public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenCh var persistedConfig = getPersistedConfigMap( settingsMap, - new HashMap(Map.of()), + new HashMap<>(Map.of()), createRandomChunkingSettingsMap(), secretSettingsMap ); @@ 
-791,7 +823,7 @@ public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModelWhenCh var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -835,7 +867,7 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); var thrownException = expectThrows( ElasticsearchStatusException.class, @@ -854,7 +886,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -875,7 +907,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettin settingsMap.put("extra_key", "value"); var secretSettingsMap = 
getAmazonBedrockSecretSettingsMap("access", "secret"); - var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap<>(Map.of()), secretSettingsMap); persistedConfig.config().put("extra_key", "value"); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java index e9e31cf0ccca2..be7796a4b7d86 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java @@ -7,12 +7,9 @@ package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; -import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.ChunkingSettings; -import org.elasticsearch.inference.EmptyTaskSettings; -import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; @@ -20,19 +17,23 @@ import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import java.util.Map; +import java.io.IOException; -import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; public class 
AmazonBedrockEmbeddingsModelTests extends ESTestCase { - public void testCreateModel_withTaskSettings_shouldFail() { - var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey"); - var thrownException = assertThrows( - ValidationException.class, - () -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue")) - ); - assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings")); + public void testCreateModel_withTaskSettingsOverride() throws IOException { + var baseTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.randomTaskSettings(); + var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey", baseTaskSettings); + + var overrideTaskSettings = AmazonBedrockEmbeddingsTaskSettingsTests.mutateTaskSettings(baseTaskSettings); + var overrideTaskSettingsMap = AmazonBedrockEmbeddingsTaskSettingsTests.toMap(overrideTaskSettings); + + var overriddenModel = AmazonBedrockEmbeddingsModel.of(baseModel, overrideTaskSettingsMap); + assertThat(overriddenModel.getTaskSettings(), equalTo(overrideTaskSettings)); + assertThat(overriddenModel.getTaskSettings(), not(equalTo(baseTaskSettings))); } // model creation only - no tests to define, but we want to have the public createModel @@ -46,7 +47,15 @@ public static AmazonBedrockEmbeddingsModel createModel( String accessKey, String secretKey ) { - return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + return createModel( + inferenceId, + region, + model, + provider, + accessKey, + secretKey, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); } public static AmazonBedrockEmbeddingsModel createModel( @@ -56,9 +65,22 @@ public static AmazonBedrockEmbeddingsModel createModel( AmazonBedrockProvider provider, String accessKey, String secretKey, - InputType 
inputType + AmazonBedrockEmbeddingsTaskSettings taskSettings ) { - return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + return createModel( + inferenceId, + region, + model, + provider, + null, + false, + null, + null, + new RateLimitSettings(240), + accessKey, + secretKey, + taskSettings + ); } public static AmazonBedrockEmbeddingsModel createModel( @@ -114,7 +136,7 @@ public static AmazonBedrockEmbeddingsModel createModel( similarity, rateLimitSettings ), - new EmptyTaskSettings(), + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings(), chunkingSettings, new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) ); @@ -132,6 +154,36 @@ public static AmazonBedrockEmbeddingsModel createModel( RateLimitSettings rateLimitSettings, String accessKey, String secretKey + ) { + return createModel( + inferenceId, + region, + model, + provider, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings, + accessKey, + secretKey, + AmazonBedrockEmbeddingsTaskSettingsTests.emptyTaskSettings() + ); + } + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey, + AmazonBedrockEmbeddingsTaskSettings taskSettings ) { return new AmazonBedrockEmbeddingsModel( inferenceId, @@ -147,7 +199,7 @@ public static AmazonBedrockEmbeddingsModel createModel( similarity, rateLimitSettings ), - new EmptyTaskSettings(), + taskSettings, null, new AwsSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) ); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java
new file mode 100644
index 0000000000000..3fc76743cc878
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsTaskSettingsTests.java
@@ -0,0 +1,129 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;
+
+import org.elasticsearch.TransportVersion;
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.xcontent.ToXContent;
+import org.elasticsearch.xcontent.XContentParserConfiguration;
+import org.elasticsearch.xcontent.json.JsonXContent;
+import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;
+import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;
+import static org.hamcrest.Matchers.equalTo;
+
+/**
+ * Wire-serialization and map-parsing tests for {@link AmazonBedrockEmbeddingsTaskSettings}. The public static helpers are
+ * also called from other test classes (for example the embeddings model tests) to build task-settings fixtures.
+ */
+public class AmazonBedrockEmbeddingsTaskSettingsTests extends AbstractBWCWireSerializationTestCase<AmazonBedrockEmbeddingsTaskSettings> {
+
+    /** Returns the shared {@link AmazonBedrockEmbeddingsTaskSettings#EMPTY} instance. */
+    public static AmazonBedrockEmbeddingsTaskSettings emptyTaskSettings() {
+        return AmazonBedrockEmbeddingsTaskSettings.EMPTY;
+    }
+
+    /** Builds task settings whose truncation is randomly either {@code null} or one of the {@link CohereTruncation} values. */
+    public static AmazonBedrockEmbeddingsTaskSettings randomTaskSettings() {
+        var truncation = randomBoolean() ? randomFrom(CohereTruncation.values()) : null;
+        return new AmazonBedrockEmbeddingsTaskSettings(truncation);
+    }
+
+    /**
+     * Produces random task settings that are not equal to {@code instance}. When {@code instance} has a truncation set,
+     * candidates without a truncation are rejected as well.
+     * NOTE(review): presumably a null truncation in an override would leave the base value in place - confirm.
+     */
+    public static AmazonBedrockEmbeddingsTaskSettings mutateTaskSettings(AmazonBedrockEmbeddingsTaskSettings instance) {
+        return randomValueOtherThanMany(
+            v -> Objects.equals(instance, v) || (instance.cohereTruncation() != null && v.cohereTruncation() == null),
+            AmazonBedrockEmbeddingsTaskSettingsTests::randomTaskSettings
+        );
+    }
+
+    @Override
+    protected AmazonBedrockEmbeddingsTaskSettings mutateInstanceForVersion(
+        AmazonBedrockEmbeddingsTaskSettings instance,
+        TransportVersion version
+    ) {
+        // No per-version adjustment is made: the instance is handed back unchanged.
+        return instance;
+    }
+
+    @Override
+    protected Writeable.Reader<AmazonBedrockEmbeddingsTaskSettings> instanceReader() {
+        return AmazonBedrockEmbeddingsTaskSettings::new;
+    }
+
+    @Override
+    protected AmazonBedrockEmbeddingsTaskSettings createTestInstance() {
+        return randomTaskSettings();
+    }
+
+    @Override
+    protected AmazonBedrockEmbeddingsTaskSettings mutateInstance(AmazonBedrockEmbeddingsTaskSettings instance) throws IOException {
+        return mutateTaskSettings(instance);
+    }
+
+    /** The {@code EMPTY} constant, a {@code null} map and an empty map must all report {@code isEmpty()}. */
+    public void testEmpty() {
+        assertTrue(emptyTaskSettings().isEmpty());
+        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(null).isEmpty());
+        assertTrue(AmazonBedrockEmbeddingsTaskSettings.fromMap(Map.of()).isEmpty());
+    }
+
+    /** Returns a mutable single-entry map from {@code key} to {@code value.toString()}. */
+    public static Map<String, Object> mutableMap(String key, Enum<?> value) {
+        return new HashMap<>(Map.of(key, value.toString()));
+    }
+
+    /** Every value in {@link CohereTruncation#ALL} must be parsed back from its string form. */
+    public void testValidCohereTruncations() {
+        for (var expectedCohereTruncation : CohereTruncation.ALL) {
+            var map = mutableMap(TRUNCATE_FIELD, expectedCohereTruncation);
+            var taskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(map);
+            assertFalse(taskSettings.isEmpty());
+            assertThat(taskSettings.cohereTruncation(), equalTo(expectedCohereTruncation));
+        }
+    }
+
+    /** An unrecognised truncation string must be rejected with a {@link ValidationException}. */
+    public void testGarbageCohereTruncations() {
+        var map = new HashMap<String, Object>(Map.of(TRUNCATE_FIELD, "oiuesoirtuoawoeirha"));
+        assertThrows(ValidationException.class, () -> AmazonBedrockEmbeddingsTaskSettings.fromMap(map));
+    }
+
+    /** Writing to XContent and parsing the result with {@code fromMap} must yield an equal instance. */
+    public void testXContent() throws IOException {
+        var taskSettings = randomTaskSettings();
+        var taskSettingsAsMap = toMap(taskSettings);
+        var roundTripTaskSettings = AmazonBedrockEmbeddingsTaskSettings.fromMap(new HashMap<>(taskSettingsAsMap));
+        assertThat(roundTripTaskSettings, equalTo(taskSettings));
+    }
+
+    /**
+     * Renders {@code taskSettings} as JSON through {@code toXContent} and parses the UTF-8 bytes back into a map.
+     */
+    public static Map<String, Object> toMap(AmazonBedrockEmbeddingsTaskSettings taskSettings) throws IOException {
+        try (var builder = JsonXContent.contentBuilder()) {
+            taskSettings.toXContent(builder, ToXContent.EMPTY_PARAMS);
+            var taskSettingsBytes = Strings.toString(builder).getBytes(StandardCharsets.UTF_8);
+            try (var parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, taskSettingsBytes)) {
+                return parser.map();
+            }
+        }
+    }
+}