[8.x] Support configurable chunking in semantic_text fields (#121041) #126545

Merged 6 commits on Apr 10, 2025
5 changes: 5 additions & 0 deletions docs/changelog/121041.yaml
@@ -0,0 +1,5 @@
pr: 121041
summary: Support configurable chunking in `semantic_text` fields
area: Relevance
type: enhancement
issues: []
server/src/main/java/org/elasticsearch/TransportVersions.java
@@ -202,6 +202,7 @@ static TransportVersion def(int id) {
public static final TransportVersion RERANK_COMMON_OPTIONS_ADDED_8_19 = def(8_841_0_15);
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion SEMANTIC_TEXT_CHUNKING_CONFIG_8_19 = def(8_841_0_18);

/*
* STOP! READ THIS FIRST! No, really,
server/src/main/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadata.java
@@ -12,6 +12,7 @@
import org.elasticsearch.TransportVersions;
import org.elasticsearch.cluster.Diff;
import org.elasticsearch.cluster.SimpleDiffable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.xcontent.ToXContentFragment;
@@ -22,8 +23,11 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_CHUNKING_CONFIG_8_19;

/**
* Contains inference field data for fields.
* As inference is done in the coordinator node to avoid re-doing it at shard / replica level, the coordinator needs to check for the need
@@ -35,21 +39,30 @@ public final class InferenceFieldMetadata implements SimpleDiffable<InferenceFie
private static final String INFERENCE_ID_FIELD = "inference_id";
private static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id";
private static final String SOURCE_FIELDS_FIELD = "source_fields";
static final String CHUNKING_SETTINGS_FIELD = "chunking_settings";

private final String name;
private final String inferenceId;
private final String searchInferenceId;
private final String[] sourceFields;
private final Map<String, Object> chunkingSettings;

- public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields) {
- this(name, inferenceId, inferenceId, sourceFields);
+ public InferenceFieldMetadata(String name, String inferenceId, String[] sourceFields, Map<String, Object> chunkingSettings) {
+ this(name, inferenceId, inferenceId, sourceFields, chunkingSettings);
}

- public InferenceFieldMetadata(String name, String inferenceId, String searchInferenceId, String[] sourceFields) {
+ public InferenceFieldMetadata(
+ String name,
+ String inferenceId,
+ String searchInferenceId,
+ String[] sourceFields,
+ Map<String, Object> chunkingSettings
+ ) {
this.name = Objects.requireNonNull(name);
this.inferenceId = Objects.requireNonNull(inferenceId);
this.searchInferenceId = Objects.requireNonNull(searchInferenceId);
this.sourceFields = Objects.requireNonNull(sourceFields);
this.chunkingSettings = chunkingSettings != null ? Map.copyOf(chunkingSettings) : null;
}

public InferenceFieldMetadata(StreamInput input) throws IOException {
@@ -61,6 +74,11 @@ public InferenceFieldMetadata(StreamInput input) throws IOException {
this.searchInferenceId = this.inferenceId;
}
this.sourceFields = input.readStringArray();
if (input.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) {
this.chunkingSettings = input.readGenericMap();
} else {
this.chunkingSettings = null;
}
}

@Override
@@ -71,6 +89,9 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(searchInferenceId);
}
out.writeStringArray(sourceFields);
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_CHUNKING_CONFIG_8_19)) {
out.writeGenericMap(chunkingSettings);
}
}

@Override
@@ -81,16 +102,22 @@ public boolean equals(Object o) {
return Objects.equals(name, that.name)
&& Objects.equals(inferenceId, that.inferenceId)
&& Objects.equals(searchInferenceId, that.searchInferenceId)
- && Arrays.equals(sourceFields, that.sourceFields);
+ && Arrays.equals(sourceFields, that.sourceFields)
+ && Objects.equals(chunkingSettings, that.chunkingSettings);
}

@Override
public int hashCode() {
- int result = Objects.hash(name, inferenceId, searchInferenceId);
+ int result = Objects.hash(name, inferenceId, searchInferenceId, chunkingSettings);
result = 31 * result + Arrays.hashCode(sourceFields);
return result;
}

@Override
public String toString() {
return Strings.toString(this);
}

public String getName() {
return name;
}
@@ -107,6 +134,10 @@ public String[] getSourceFields() {
return sourceFields;
}

public Map<String, Object> getChunkingSettings() {
return chunkingSettings;
}

public static Diff<InferenceFieldMetadata> readDiffFrom(StreamInput in) throws IOException {
return SimpleDiffable.readDiffFrom(InferenceFieldMetadata::new, in);
}
@@ -119,6 +150,11 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId);
}
builder.array(SOURCE_FIELDS_FIELD, sourceFields);
if (chunkingSettings != null) {
builder.startObject(CHUNKING_SETTINGS_FIELD);
builder.mapContents(chunkingSettings);
builder.endObject();
}
return builder.endObject();
}

@@ -131,6 +167,7 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
String currentFieldName = null;
String inferenceId = null;
String searchInferenceId = null;
Map<String, Object> chunkingSettings = null;
List<String> inputFields = new ArrayList<>();
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
if (token == XContentParser.Token.FIELD_NAME) {
@@ -151,6 +188,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
}
}
}
} else if (CHUNKING_SETTINGS_FIELD.equals(currentFieldName)) {
chunkingSettings = parser.map();
} else {
parser.skipChildren();
}
@@ -159,7 +198,8 @@ public static InferenceFieldMetadata fromXContent(XContentParser parser) throws
name,
inferenceId,
searchInferenceId == null ? inferenceId : searchInferenceId,
- inputFields.toArray(String[]::new)
+ inputFields.toArray(String[]::new),
+ chunkingSettings
);
}
}
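
To make the new metadata concrete, here is a minimal construction sketch. The field name, endpoint id, and settings values are hypothetical; the settings keys mirror the test helpers further down.

Map<String, Object> chunking = Map.of(
    "strategy", "sentence_boundary",
    "max_chunk_size", 40,
    "sentence_overlap", 1
);
InferenceFieldMetadata meta = new InferenceFieldMetadata(
    "body",                  // hypothetical field name
    "my-endpoint",           // hypothetical inference_id; search id defaults to the same value
    new String[] { "body" },
    chunking                 // null defers to the inference endpoint's own chunking settings
);
// toXContent then renders, under the field's name:
// { "inference_id": "my-endpoint", "source_fields": ["body"],
//   "chunking_settings": { "strategy": "sentence_boundary", "max_chunk_size": 40, "sentence_overlap": 1 } }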
server/src/main/java/org/elasticsearch/inference/ChunkInferenceInput.java
@@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/

package org.elasticsearch.inference;

import org.elasticsearch.core.Nullable;

import java.util.List;

public record ChunkInferenceInput(String input, @Nullable ChunkingSettings chunkingSettings) {

public ChunkInferenceInput(String input) {
this(input, null);
}

public static List<String> inputs(List<ChunkInferenceInput> chunkInferenceInputs) {
return chunkInferenceInputs.stream().map(ChunkInferenceInput::input).toList();
}
}
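
A quick usage sketch for the new record; fieldSettings is a hypothetical ChunkingSettings instance, and a null settings value means the model's defaults apply downstream.

List<ChunkInferenceInput> inputs = List.of(
    new ChunkInferenceInput("first passage"),                 // no per-input settings: model defaults
    new ChunkInferenceInput("second passage", fieldSettings)  // per-field override
);
// ChunkInferenceInput.inputs(inputs) strips the settings back out:
// ["first passage", "second passage"]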
server/src/main/java/org/elasticsearch/inference/ChunkingSettings.java
@@ -12,6 +12,10 @@
import org.elasticsearch.common.io.stream.VersionedNamedWriteable;
import org.elasticsearch.xcontent.ToXContentObject;

import java.util.Map;

public interface ChunkingSettings extends ToXContentObject, VersionedNamedWriteable {
ChunkingStrategy getChunkingStrategy();

Map<String, Object> asMap();
}
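
The added asMap() is the bridge between typed settings and the generic map that InferenceFieldMetadata carries. Below is a hedged hand-off sketch modeled on TestInferenceFieldMapper further down; the helper itself is hypothetical.

static InferenceFieldMetadata metadataWithChunking(
    String fieldName, String inferenceId, Set<String> sourcePaths, @Nullable ChunkingSettings settings
) {
    // asMap() flattens the typed settings into the map form that the cluster
    // state serializes (writeGenericMap) and renders (toXContent) above.
    return new InferenceFieldMetadata(
        fieldName, inferenceId, sourcePaths.toArray(String[]::new),
        settings == null ? null : settings.asMap()
    );
}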
server/src/main/java/org/elasticsearch/inference/InferenceService.java
@@ -133,18 +133,18 @@ void unifiedCompletionInfer(
/**
* Chunk long text.
*
* @param model The model
* @param query Inference query, mainly for re-ranking
* @param input Inference input
* @param taskSettings Settings in the request to override the model's defaults
* @param inputType For search, ingest etc
* @param timeout The timeout for the request
* @param listener Chunked Inference result listener
*/
void chunkedInfer(
Model model,
@Nullable String query,
- List<String> input,
+ List<ChunkInferenceInput> input,
Map<String, Object> taskSettings,
InputType inputType,
TimeValue timeout,
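
For orientation, a hedged caller sketch against the updated signature; the service, model, settings, and listener are assumed to exist, and the listener's element type is an assumption based on the chunked-result listener the javadoc describes.

static void chunkAndInfer(
    InferenceService service, Model model, ChunkingSettings fieldSettings,
    ActionListener<List<ChunkedInference>> listener  // element type assumed
) {
    List<ChunkInferenceInput> inputs = List.of(
        new ChunkInferenceInput("a long passage to be chunked"),   // model-default chunking
        new ChunkInferenceInput("another passage", fieldSettings)  // field-level override
    );
    service.chunkedInfer(model, null, inputs, Map.of(), InputType.INGEST,
        TimeValue.timeValueSeconds(30), listener);
}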
@@ -694,7 +694,8 @@ private static InferenceFieldMetadata randomInferenceFieldMetadata(String name)
name,
randomIdentifier(),
randomIdentifier(),
- randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new)
+ randomSet(1, 5, ESTestCase::randomIdentifier).toArray(String[]::new),
+ InferenceFieldMetadataTests.generateRandomChunkingSettings()
);
}

server/src/test/java/org/elasticsearch/cluster/metadata/InferenceFieldMetadataTests.java
@@ -15,8 +15,10 @@
import org.elasticsearch.xcontent.XContentParser;

import java.io.IOException;
import java.util.Map;
import java.util.function.Predicate;

import static org.elasticsearch.cluster.metadata.InferenceFieldMetadata.CHUNKING_SETTINGS_FIELD;
import static org.hamcrest.Matchers.equalTo;

public class InferenceFieldMetadataTests extends AbstractXContentTestCase<InferenceFieldMetadata> {
@@ -37,11 +39,6 @@ protected InferenceFieldMetadata createTestInstance() {
return createTestItem();
}

- @Override
- protected Predicate<String> getRandomFieldsExcludeFilter() {
- return p -> p.equals(""); // do not add elements at the top-level as any element at this level is parsed as a new inference field
- }

@Override
protected InferenceFieldMetadata doParseInstance(XContentParser parser) throws IOException {
if (parser.nextToken() == XContentParser.Token.START_OBJECT) {
@@ -58,18 +55,57 @@ protected boolean supportsUnknownFields() {
return true;
}

@Override
protected Predicate<String> getRandomFieldsExcludeFilter() {
// do not add elements at the top-level as any element at this level is parsed as a new inference field,
// and do not add additional elements to chunking maps as they will fail parsing with extra data
return field -> field.equals("") || field.contains(CHUNKING_SETTINGS_FIELD);
}

private static InferenceFieldMetadata createTestItem() {
String name = randomAlphaOfLengthBetween(3, 10);
String inferenceId = randomIdentifier();
String searchInferenceId = randomIdentifier();
String[] inputFields = generateRandomStringArray(5, 10, false, false);
- return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields);
+ Map<String, Object> chunkingSettings = generateRandomChunkingSettings();
+ return new InferenceFieldMetadata(name, inferenceId, searchInferenceId, inputFields, chunkingSettings);
}

public static Map<String, Object> generateRandomChunkingSettings() {
if (randomBoolean()) {
return null; // Defaults to model chunking settings
}
return randomBoolean() ? generateRandomWordBoundaryChunkingSettings() : generateRandomSentenceBoundaryChunkingSettings();
}

private static Map<String, Object> generateRandomWordBoundaryChunkingSettings() {
return Map.of("strategy", "word_boundary", "max_chunk_size", randomIntBetween(20, 100), "overlap", randomIntBetween(1, 50));
}

private static Map<String, Object> generateRandomSentenceBoundaryChunkingSettings() {
return Map.of(
"strategy",
"sentence_boundary",
"max_chunk_size",
randomIntBetween(20, 100),
"sentence_overlap",
randomIntBetween(0, 1)
);
}

public void testNullCtorArgsThrowException() {
- assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0]));
- assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0]));
- assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0]));
- assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null));
+ assertThrows(
+ NullPointerException.class,
+ () -> new InferenceFieldMetadata(null, "inferenceId", "searchInferenceId", new String[0], Map.of())
+ );
+ assertThrows(
+ NullPointerException.class,
+ () -> new InferenceFieldMetadata("name", null, "searchInferenceId", new String[0], Map.of())
+ );
+ assertThrows(NullPointerException.class, () -> new InferenceFieldMetadata("name", "inferenceId", null, new String[0], Map.of()));
+ assertThrows(
+ NullPointerException.class,
+ () -> new InferenceFieldMetadata("name", "inferenceId", "searchInferenceId", null, Map.of())
+ );
}
}
@@ -11,6 +11,7 @@

import org.apache.lucene.search.Query;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadata;
import org.elasticsearch.cluster.metadata.InferenceFieldMetadataTests;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.plugins.MapperPlugin;
import org.elasticsearch.plugins.Plugin;
@@ -102,7 +103,13 @@ private static class TestInferenceFieldMapper extends FieldMapper implements Inf

@Override
public InferenceFieldMetadata getMetadata(Set<String> sourcePaths) {
- return new InferenceFieldMetadata(fullPath(), INFERENCE_ID, SEARCH_INFERENCE_ID, sourcePaths.toArray(new String[0]));
+ return new InferenceFieldMetadata(
+ fullPath(),
+ INFERENCE_ID,
+ SEARCH_INFERENCE_ID,
+ sourcePaths.toArray(new String[0]),
+ InferenceFieldMetadataTests.generateRandomChunkingSettings()
+ );
}

@Override