Skip to content

[ML] Bedrock Cohere Task Settings Support (#126493) #126559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126493.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126493
summary: Bedrock Cohere Task Settings Support
area: Machine Learning
type: enhancement
issues:
- 126156
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ static TransportVersion def(int id) {
public static final TransportVersion INTRODUCE_LIFECYCLE_TEMPLATE_8_19 = def(8_841_0_14);
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings;
import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings;
Expand Down Expand Up @@ -173,8 +174,13 @@ private static void addAmazonBedrockNamedWriteables(List<NamedWriteableRegistry.
AmazonBedrockEmbeddingsServiceSettings::new
)
);

// no task settings for Amazon Bedrock Embeddings
namedWriteables.add(
new NamedWriteableRegistry.Entry(
TaskSettings.class,
AmazonBedrockEmbeddingsTaskSettings.NAME,
AmazonBedrockEmbeddingsTaskSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,31 @@
import org.elasticsearch.inference.InputType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

import static org.elasticsearch.inference.InputType.invalidInputTypeMessage;

public record AmazonBedrockCohereEmbeddingsRequestEntity(List<String> input, @Nullable InputType inputType) implements ToXContentObject {
public record AmazonBedrockCohereEmbeddingsRequestEntity(
List<String> input,
@Nullable InputType inputType,
AmazonBedrockEmbeddingsTaskSettings taskSettings
) implements ToXContentObject {

private static final String TEXTS_FIELD = "texts";
private static final String INPUT_TYPE_FIELD = "input_type";
private static final String SEARCH_DOCUMENT = "search_document";
private static final String SEARCH_QUERY = "search_query";
private static final String CLUSTERING = "clustering";
private static final String CLASSIFICATION = "classification";
private static final String TRUNCATE = "truncate";

/**
 * Compact canonical constructor: rejects a {@code null} {@code input} or {@code taskSettings}.
 * {@code inputType} is not checked here and may be {@code null}.
 *
 * @throws NullPointerException if {@code input} or {@code taskSettings} is {@code null}
 */
public AmazonBedrockCohereEmbeddingsRequestEntity {
Objects.requireNonNull(input);
Objects.requireNonNull(taskSettings);
}

@Override
Expand All @@ -43,6 +50,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(INPUT_TYPE_FIELD, SEARCH_DOCUMENT);
}

if (taskSettings.cohereTruncation() != null) {
builder.field(TRUNCATE, taskSettings.cohereTruncation().name());
}

builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public static ToXContent createEntity(
return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0));
}
case COHERE -> {
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType);
return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput, inputType, model.getTaskSettings());
}
default -> {
return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ protected void executeRequest(AmazonBedrockBaseClient client) {

/**
 * Returns the request to use after truncation.
 * <p>
 * For the Cohere provider this same instance is returned unchanged. For every other provider the current input is run
 * through the truncator and a new request is built around the result.
 *
 * @return this request for Cohere, otherwise a new request carrying the truncated input
 */
@Override
public Request truncate() {
    // Cohere has its own truncation logic, so only the remaining providers are truncated here.
    if (provider != AmazonBedrockProvider.COHERE) {
        var shortenedInput = truncator.truncate(truncationResult.input());
        return new AmazonBedrockEmbeddingsRequest(truncator, shortenedInput, embeddingsModel, requestEntity, timeout);
    }
    return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ public class AmazonBedrockConstants {
public static final String TOP_K_FIELD = "top_k";
public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens";

public static final String TRUNCATE_FIELD = "truncate";

public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0;
public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ private static AmazonBedrockModel createModel(
context
);
checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider());
checkTaskSettingsForTextEmbeddingModel(model);
return model;
}
case COMPLETION -> {
Expand Down Expand Up @@ -368,6 +369,17 @@ private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvide
}
}

/**
 * Rejects a text embedding model that sets a truncation value for a provider other than Cohere.
 *
 * @param model the embeddings model whose provider and task settings are inspected
 * @throws ElasticsearchStatusException with {@link RestStatus#BAD_REQUEST} when the provider is not Cohere and a
 *                                      truncation value is present
 */
private static void checkTaskSettingsForTextEmbeddingModel(AmazonBedrockEmbeddingsModel model) {
    var provider = model.provider();
    if (provider == AmazonBedrockProvider.COHERE) {
        // Cohere is the one provider allowed to carry the truncate setting.
        return;
    }

    if (model.getTaskSettings().cohereTruncation() == null) {
        // Nothing was configured, so there is nothing to reject.
        return;
    }

    throw new ElasticsearchStatusException(
        "The [{}] task type for provider [{}] does not allow [truncate] field",
        RestStatus.BAD_REQUEST,
        TaskType.TEXT_EMBEDDING,
        provider
    );
}

private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) {
var taskSettings = model.getTaskSettings();
if (taskSettings.topK() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,11 @@

package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;

import org.elasticsearch.common.ValidationException;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.EmptyTaskSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.inference.common.amazon.AwsSecretSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
Expand All @@ -28,10 +25,8 @@ public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel {

public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map<String, Object> taskSettings) {
if (taskSettings != null && taskSettings.isEmpty() == false) {
// no task settings allowed
var validationException = new ValidationException();
validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings");
throw validationException;
var updatedTaskSettings = embeddingsModel.getTaskSettings().updatedTaskSettings(taskSettings);
return new AmazonBedrockEmbeddingsModel(embeddingsModel, updatedTaskSettings);
}

return embeddingsModel;
Expand All @@ -52,7 +47,7 @@ public AmazonBedrockEmbeddingsModel(
taskType,
service,
AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context),
new EmptyTaskSettings(),
AmazonBedrockEmbeddingsTaskSettings.fromMap(taskSettings),
chunkingSettings,
AwsSecretSettings.fromMap(secretSettings)
);
Expand All @@ -63,12 +58,12 @@ public AmazonBedrockEmbeddingsModel(
TaskType taskType,
String service,
AmazonBedrockEmbeddingsServiceSettings serviceSettings,
TaskSettings taskSettings,
AmazonBedrockEmbeddingsTaskSettings taskSettings,
ChunkingSettings chunkingSettings,
AwsSecretSettings secrets
) {
super(
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings(), chunkingSettings),
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
new ModelSecrets(secrets)
);
}
Expand All @@ -77,6 +72,10 @@ public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings
super(model, serviceSettings);
}

/**
 * Builds an embeddings model from an existing {@link Model} and embeddings task settings by delegating to the
 * superclass constructor.
 * NOTE(review): presumably the result is {@code model} with its task settings replaced; confirm against the superclass.
 *
 * @param model the model passed through to the superclass
 * @param taskSettings the task settings passed through to the superclass
 */
public AmazonBedrockEmbeddingsModel(Model model, AmazonBedrockEmbeddingsTaskSettings taskSettings) {
super(model, taskSettings);
}

@Override
public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, Object> taskSettings) {
return creator.create(this, taskSettings);
Expand All @@ -86,4 +85,9 @@ public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map<String, O
public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() {
return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings();
}

/**
 * Returns the task settings held by the superclass, narrowed to {@link AmazonBedrockEmbeddingsTaskSettings}.
 *
 * @return the superclass task settings cast to the embeddings-specific type
 * @throws ClassCastException if the stored task settings are of a different type
 */
@Override
public AmazonBedrockEmbeddingsTaskSettings getTaskSettings() {
return (AmazonBedrockEmbeddingsTaskSettings) super.getTaskSettings();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalEnum;
import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TRUNCATE_FIELD;

/**
 * Task settings for Amazon Bedrock text embedding models.
 * <p>
 * The only setting is the optional {@code truncate} value, held as a {@link CohereTruncation}. A {@code null} value
 * means the setting is absent: {@link #isEmpty()} then returns {@code true} and
 * {@link #toXContent(XContentBuilder, Params)} writes an empty object.
 *
 * @param cohereTruncation the configured truncation mode, or {@code null} when unset
 */
public record AmazonBedrockEmbeddingsTaskSettings(@Nullable CohereTruncation cohereTruncation) implements TaskSettings {
    /**
     * Shared instance with no truncation configured. The cast on {@code null} selects the canonical constructor over
     * the {@link StreamInput} one.
     */
    public static final AmazonBedrockEmbeddingsTaskSettings EMPTY = new AmazonBedrockEmbeddingsTaskSettings((CohereTruncation) null);
    /** Name returned by {@link #getWriteableName()}. */
    public static final String NAME = "amazon_bedrock_embeddings_task_settings";

    /**
     * Parses task settings from a raw settings map.
     *
     * @param map the raw task settings; may be {@code null} or empty
     * @return {@link #EMPTY} when {@code map} is {@code null} or empty, otherwise settings holding the parsed
     *         {@code truncate} value (which may still be {@code null})
     * @throws ValidationException if any validation error was recorded while extracting the value
     */
    public static AmazonBedrockEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
        if (map == null || map.isEmpty()) {
            return EMPTY;
        }

        ValidationException validationException = new ValidationException();

        var cohereTruncation = extractOptionalEnum(
            map,
            TRUNCATE_FIELD,
            ModelConfigurations.TASK_SETTINGS,
            CohereTruncation::fromString,
            CohereTruncation.ALL,
            validationException
        );

        // Errors are accumulated in validationException and surfaced together here.
        if (validationException.validationErrors().isEmpty() == false) {
            throw validationException;
        }

        return new AmazonBedrockEmbeddingsTaskSettings(cohereTruncation);
    }

    /**
     * Reads the settings from a stream; the counterpart of {@link #writeTo(StreamOutput)}.
     *
     * @param in the stream to read the optional truncation enum from
     * @throws IOException if reading from the stream fails
     */
    public AmazonBedrockEmbeddingsTaskSettings(StreamInput in) throws IOException {
        this(in.readOptionalEnum(CohereTruncation.class));
    }

    /**
     * @return {@code true} when no truncation value is set
     */
    @Override
    public boolean isEmpty() {
        return cohereTruncation() == null;
    }

    /**
     * Overlays {@code newSettings} on top of these settings. A truncation value present in {@code newSettings} wins;
     * otherwise the current value is kept.
     *
     * @param newSettings the raw override map; {@code null} or empty means "no overrides"
     * @return this instance when there is nothing to apply, otherwise a new instance with the merged value
     * @throws ValidationException if {@code newSettings} fails validation in {@link #fromMap(Map)}
     */
    @Override
    public AmazonBedrockEmbeddingsTaskSettings updatedTaskSettings(Map<String, Object> newSettings) {
        // Match fromMap's tolerance of a missing map instead of failing inside the HashMap copy constructor.
        if (newSettings == null || newSettings.isEmpty()) {
            return this;
        }

        // Parse from a copy so the caller's map is never handed to the parser.
        var newTaskSettings = fromMap(new HashMap<>(newSettings));

        return new AmazonBedrockEmbeddingsTaskSettings(firstNonNullOrNull(newTaskSettings.cohereTruncation(), cohereTruncation()));
    }

    /** Returns {@code first} unless it is {@code null}, in which case {@code second} (possibly {@code null}) is returned. */
    private static <T> T firstNonNullOrNull(T first, T second) {
        return first != null ? first : second;
    }

    /**
     * @return {@link #NAME}
     */
    @Override
    public String getWriteableName() {
        return NAME;
    }

    /**
     * @return {@link TransportVersions#AMAZON_BEDROCK_TASK_SETTINGS_8_19}
     */
    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersions.AMAZON_BEDROCK_TASK_SETTINGS_8_19;
    }

    /**
     * Writes the optional truncation enum; the counterpart of the {@link StreamInput} constructor.
     *
     * @param out the stream to write to
     * @throws IOException if writing to the stream fails
     */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeOptionalEnum(cohereTruncation());
    }

    /**
     * Writes these settings as an object that contains the {@code truncate} field only when a value is set.
     *
     * @param builder the builder to write to
     * @param params unused by this implementation
     * @return the same builder, after the object has been closed
     * @throws IOException if the builder fails to write
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        if (cohereTruncation != null) {
            builder.field(TRUNCATE_FIELD, cohereTruncation);
        }
        return builder.endObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.inference.services.cohere;

import java.util.EnumSet;
import java.util.Locale;

/**
Expand All @@ -31,6 +32,8 @@ public enum CohereTruncation {
*/
END;

/**
 * Set containing every {@link CohereTruncation} constant.
 * NOTE(review): {@link EnumSet} is mutable and this instance is shared, so callers must treat it as read-only.
 */
public static final EnumSet<CohereTruncation> ALL = EnumSet.allOf(CohereTruncation.class);

@Override
public String toString() {
return name().toLowerCase(Locale.ROOT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import org.elasticsearch.xpack.inference.services.cohere.CohereTruncation;

import java.io.IOException;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -63,7 +62,7 @@ public static CohereEmbeddingsTaskSettings fromMap(Map<String, Object> map) {
TRUNCATE,
ModelConfigurations.TASK_SETTINGS,
CohereTruncation::fromString,
EnumSet.allOf(CohereTruncation.class),
CohereTruncation.ALL,
validationException
);

Expand Down
Loading