diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java new file mode 100644 index 0000000000000..ea829b2eb7bac --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/MapPathExtractor.java @@ -0,0 +1,209 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.common.Strings; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Extracts fields from a {@link Map}. + * + * Uses a subset of the JSONPath schema to extract fields from a map. + * For more information see here. + * + * This implementation differs in how it handles lists in that JSONPath will flatten inner lists. This implementation + * preserves inner lists. + * + * Examples of the schema: + * + *
+ * {@code
+ * $.field1.array[*].field2
+ * $.field1.field2
+ * }
+ * 
+ * + * Given the map + *
+ * {@code
+ * {
+ *     "request_id": "B4AB89C8-B135-xxxx-A6F8-2BAB801A2CE4",
+ *     "latency": 38,
+ *     "usage": {
+ *         "token_count": 3072
+ *     },
+ *     "result": {
+ *         "embeddings": [
+ *             {
+ *                 "index": 0,
+ *                 "embedding": [
+ *                     2,
+ *                     4
+ *                 ]
+ *             },
+ *             {
+ *                 "index": 1,
+ *                 "embedding": [
+ *                     1,
+ *                     2
+ *                 ]
+ *             }
+ *         ]
+ *     }
+ * }
+ * }
+ * 
+ * + *
+ * {@code
+ * var embeddings = MapPathExtractor.extract(map, "$.result.embeddings[*].embedding");
+ * }
+ * 
+ * + * Will result in: + * + *
+ * {@code
+ * [
+ *   [2, 4],
+ *   [1, 2]
+ * ]
+ * }
+ * 
+ * + * This implementation differs from JSONPath when handling a list of maps. JSONPath will flatten the result and return a single array. + * this implementation will preserve each nested list while gathering the results. + * + * For example + * + *
+ * {@code
+ * {
+ *   "result": [
+ *     {
+ *       "key": [
+ *         {
+ *           "a": 1.1
+ *         },
+ *         {
+ *           "a": 2.2
+ *         }
+ *       ]
+ *     },
+ *     {
+ *       "key": [
+ *         {
+ *           "a": 3.3
+ *         },
+ *         {
+ *           "a": 4.4
+ *         }
+ *       ]
+ *     }
+ *   ]
+ * }
+ * }
+ * {@code var embeddings = MapPathExtractor.extract(map, "$.result[*].key[*].a");}
+ *
+ * JSONPath: {@code [1.1, 2.2, 3.3, 4.4]}
+ * This implementation: {@code [[1.1, 2.2], [3.3, 4.4]]}
+ * 
+ */ +public class MapPathExtractor { + + private static final String DOLLAR = "$"; + + // default for testing + static final Pattern dotFieldPattern = Pattern.compile("^\\.([^.\\[]+)(.*)"); + static final Pattern arrayWildcardPattern = Pattern.compile("^\\[\\*\\](.*)"); + + public static Object extract(Map data, String path) { + if (data == null || data.isEmpty() || path == null || path.trim().isEmpty()) { + return null; + } + + var cleanedPath = path.trim(); + + if (cleanedPath.startsWith(DOLLAR)) { + cleanedPath = cleanedPath.substring(DOLLAR.length()); + } else { + throw new IllegalArgumentException(Strings.format("Path [%s] must start with a dollar sign ($)", cleanedPath)); + } + + return navigate(data, cleanedPath); + } + + private static Object navigate(Object current, String remainingPath) { + if (current == null || remainingPath == null || remainingPath.isEmpty()) { + return current; + } + + var dotFieldMatcher = dotFieldPattern.matcher(remainingPath); + var arrayWildcardMatcher = arrayWildcardPattern.matcher(remainingPath); + + if (dotFieldMatcher.matches()) { + String field = dotFieldMatcher.group(1); + if (field == null || field.isEmpty()) { + throw new IllegalArgumentException( + Strings.format( + "Unable to extract field from remaining path [%s]. Fields must be delimited by a dot character.", + remainingPath + ) + ); + } + + String nextPath = dotFieldMatcher.group(2); + if (current instanceof Map currentMap) { + var fieldFromMap = currentMap.get(field); + if (fieldFromMap == null) { + throw new IllegalArgumentException(Strings.format("Unable to find field [%s] in map", field)); + } + + return navigate(currentMap.get(field), nextPath); + } else { + throw new IllegalArgumentException( + Strings.format( + "Current path [%s] matched the dot field pattern but the current object is not a map, " + + "found invalid type [%s] instead.", + remainingPath, + current.getClass().getSimpleName() + ) + ); + } + } else if (arrayWildcardMatcher.matches()) { + String nextPath = arrayWildcardMatcher.group(1); + if (current instanceof List list) { + List results = new ArrayList<>(); + + for (Object item : list) { + Object result = navigate(item, nextPath); + if (result != null) { + results.add(result); + } + } + + return results; + } else { + throw new IllegalArgumentException( + Strings.format( + "Current path [%s] matched the array field pattern but the current object is not a list, " + + "found invalid type [%s] instead.", + remainingPath, + current.getClass().getSimpleName() + ) + ); + } + } + + throw new IllegalArgumentException(Strings.format("Invalid path received [%s], unable to extract a field name.", remainingPath)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java new file mode 100644 index 0000000000000..cd084ca224798 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/MapPathExtractorTests.java @@ -0,0 +1,187 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.common; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class MapPathExtractorTests extends ESTestCase { + public void testExtract_RetrievesListOfLists() { + Map input = Map.of( + "result", + Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4)))) + ); + + assertThat(MapPathExtractor.extract(input, "$.result.embeddings[*].embedding"), is(List.of(List.of(1, 2), List.of(3, 4)))); + } + + public void testExtract_IteratesListOfMapsToListOfStrings() { + Map input = Map.of( + "result", + List.of(Map.of("key", List.of("value1", "value2")), Map.of("key", List.of("value3", "value4"))) + ); + + assertThat( + MapPathExtractor.extract(input, "$.result[*].key[*]"), + is(List.of(List.of("value1", "value2"), List.of("value3", "value4"))) + ); + } + + public void testExtract_IteratesListOfMapsToListOfMapsOfStringToDoubles() { + Map input = Map.of( + "result", + List.of( + Map.of("key", List.of(Map.of("a", 1.1d), Map.of("a", 2.2d))), + Map.of("key", List.of(Map.of("a", 3.3d), Map.of("a", 4.4d))) + ) + ); + + assertThat(MapPathExtractor.extract(input, "$.result[*].key[*].a"), is(List.of(List.of(1.1d, 2.2d), List.of(3.3d, 4.4d)))); + } + + public void testExtract_ReturnsNullForEmptyList() { + Map input = Map.of(); + + assertNull(MapPathExtractor.extract(input, "$.awesome")); + } + + public void testExtract_ReturnsNull_WhenTheInputMapIsNull() { + assertNull(MapPathExtractor.extract(null, "$.result")); + } + + public void testExtract_ReturnsNull_WhenPathIsNull() { + assertNull(MapPathExtractor.extract(Map.of("key", "value"), null)); + } + + public void testExtract_ReturnsNull_WhenPathIsWhiteSpace() { + assertNull(MapPathExtractor.extract(Map.of("key", "value"), " ")); + } + + public void testExtract_ThrowsException_WhenPathDoesNotStartWithDollarSign() { + var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(Map.of("key", "value"), ".key")); + assertThat(exception.getMessage(), is("Path [.key] must start with a dollar sign ($)")); + } + + public void testExtract_ThrowsException_WhenCannotFindField() { + Map input = Map.of("result", "key"); + + var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$.awesome")); + assertThat(exception.getMessage(), is("Unable to find field [awesome] in map")); + } + + public void testExtract_ThrowsAnException_WhenThePathIsInvalid() { + Map input = Map.of("result", "key"); + + var exception = expectThrows(IllegalArgumentException.class, () -> MapPathExtractor.extract(input, "$awesome")); + assertThat(exception.getMessage(), is("Invalid path received [awesome], unable to extract a field name.")); + } + + public void testExtract_ThrowsException_WhenMissingArraySyntax() { + Map input = Map.of( + "result", + Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4)))) + ); + + var exception = expectThrows( + IllegalArgumentException.class, + // embeddings is missing [*] to indicate that it is an array + () -> MapPathExtractor.extract(input, "$.result.embeddings.embedding") + ); + assertThat( + exception.getMessage(), + is( + "Current path [.embedding] matched the dot field pattern but the current object " + + "is not a map, found invalid type [List12] instead." + ) + ); + } + + public void testExtract_ThrowsException_WhenHasArraySyntaxButIsAMap() { + Map input = Map.of( + "result", + Map.of("embeddings", List.of(Map.of("index", 0, "embedding", List.of(1, 2)), Map.of("index", 1, "embedding", List.of(3, 4)))) + ); + + var exception = expectThrows( + IllegalArgumentException.class, + // result is not an array + () -> MapPathExtractor.extract(input, "$.result[*].embeddings[*].embedding") + ); + assertThat( + exception.getMessage(), + is( + "Current path [[*].embeddings[*].embedding] matched the array field pattern but the current " + + "object is not a list, found invalid type [Map1] instead." + ) + ); + } + + public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty() { + Map input = Map.of("result", List.of()); + + assertThat(MapPathExtractor.extract(input, "$.result"), is(List.of())); + } + + public void testExtract_ReturnsAnEmptyList_WhenItIsEmpty_PathIncludesArray() { + Map input = Map.of("result", List.of()); + + assertThat(MapPathExtractor.extract(input, "$.result[*]"), is(List.of())); + } + + public void testDotFieldPattern() { + { + var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc.123"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is("abc")); + assertThat(matcher.group(2), is(".123")); + } + { + var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[*].123"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is("abc")); + assertThat(matcher.group(2), is("[*].123")); + } + { + var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc[.123"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is("abc")); + assertThat(matcher.group(2), is("[.123")); + } + { + var matcher = MapPathExtractor.dotFieldPattern.matcher(".abc"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is("abc")); + assertThat(matcher.group(2), is("")); + } + } + + public void testArrayWildcardPattern() { + { + var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*].abc.123"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is(".abc.123")); + } + { + var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[*]"); + assertTrue(matcher.matches()); + assertThat(matcher.group(1), is("")); + } + { + var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[1].abc"); + assertFalse(matcher.matches()); + } + { + var matcher = MapPathExtractor.arrayWildcardPattern.matcher("[].abc"); + assertFalse(matcher.matches()); + } + } +}