diff --git a/docs/changelog/139379.yaml b/docs/changelog/139379.yaml new file mode 100644 index 0000000000000..bfaaa4f253837 --- /dev/null +++ b/docs/changelog/139379.yaml @@ -0,0 +1,5 @@ +pr: 139379 +summary: Add MV_INTERSECT Function +area: ES|QL +type: enhancement +issues: [] diff --git a/docs/reference/query-languages/esql/_snippets/functions/description/mv_intersect.md b/docs/reference/query-languages/esql/_snippets/functions/description/mv_intersect.md new file mode 100644 index 0000000000000..ada84a68d29c7 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/description/mv_intersect.md @@ -0,0 +1,6 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Description** + +Returns a subset of the inputs sets that contains the intersection of values in provided mv arguments. + diff --git a/docs/reference/query-languages/esql/_snippets/functions/examples/mv_intersect.md b/docs/reference/query-languages/esql/_snippets/functions/examples/mv_intersect.md new file mode 100644 index 0000000000000..b255719eec46c --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/examples/mv_intersect.md @@ -0,0 +1,55 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Examples** + +```esql +ROW a = [1, 2, 3, 4, 5], b = [2, 3, 4, 5, 6] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` + +| finalValue:integer | +| --- | +| [2, 3, 4, 5] | + +```esql +ROW a = [1, 2, 3, 4, 5]::long, b = [2, 3, 4, 5, 6]::long +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` + +| finalValue:long | +| --- | +| [2, 3, 4, 5] | + +```esql +ROW a = [true, false, false, false], b = [false] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` + +| finalValue:boolean | +| --- | +| [false] | + +```esql +ROW a = [5.2, 10.5, 1.12345, 2.6928], b = [10.5, 2.6928] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` + +| finalValue:double | +| --- | +| [10.5, 2.6928] | + +```esql +ROW a = ["one", "two", "three", "four", "five"], b = ["one", "four"] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` + +| finalValue:keyword | +| --- | +| ["one", "four"] | + + diff --git a/docs/reference/query-languages/esql/_snippets/functions/layout/mv_intersect.md b/docs/reference/query-languages/esql/_snippets/functions/layout/mv_intersect.md new file mode 100644 index 0000000000000..908c90043fae8 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/layout/mv_intersect.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +## `MV_INTERSECT` [esql-mv_intersect] +```{applies_to} +stack: preview 9.3.0 +serverless: preview +``` + +**Syntax** + +:::{image} ../../../images/functions/mv_intersect.svg +:alt: Embedded +:class: text-center +::: + + +:::{include} ../parameters/mv_intersect.md +::: + +:::{include} ../description/mv_intersect.md +::: + +:::{include} ../types/mv_intersect.md +::: + +:::{include} ../examples/mv_intersect.md +::: diff --git a/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_intersect.md b/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_intersect.md new file mode 100644 index 0000000000000..8e54ce567a570 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/parameters/mv_intersect.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Parameters** + +`field1` +: + +`field2` +: + diff --git a/docs/reference/query-languages/esql/_snippets/functions/types/mv_intersect.md b/docs/reference/query-languages/esql/_snippets/functions/types/mv_intersect.md new file mode 100644 index 0000000000000..64515ccdcbab1 --- /dev/null +++ b/docs/reference/query-languages/esql/_snippets/functions/types/mv_intersect.md @@ -0,0 +1,27 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +**Supported types** + +| field1 | field2 | result | +| --- | --- | --- | +| boolean | boolean | boolean | +| cartesian_point | cartesian_point | cartesian_point | +| cartesian_shape | cartesian_shape | cartesian_shape | +| date | date | date | +| date_nanos | date_nanos | date_nanos | +| double | double | double | +| geo_point | geo_point | geo_point | +| geo_shape | geo_shape | geo_shape | +| geohash | geohash | geohash | +| geohex | geohex | geohex | +| geotile | geotile | geotile | +| integer | integer | integer | +| ip | ip | ip | +| keyword | keyword | keyword | +| keyword | text | keyword | +| long | long | long | +| text | keyword | keyword | +| text | text | keyword | +| unsigned_long | unsigned_long | unsigned_long | +| version | version | version | + diff --git a/docs/reference/query-languages/esql/functions-operators/mv-functions.md b/docs/reference/query-languages/esql/functions-operators/mv-functions.md index defb3f86dbf32..8c58a3bf9dbe0 100644 --- a/docs/reference/query-languages/esql/functions-operators/mv-functions.md +++ b/docs/reference/query-languages/esql/functions-operators/mv-functions.md @@ -14,7 +14,6 @@ mapped_pages: :::{include} ../_snippets/lists/mv-functions.md ::: - :::{include} ../_snippets/functions/layout/mv_append.md ::: @@ -36,6 +35,9 @@ mapped_pages: :::{include} ../_snippets/functions/layout/mv_first.md ::: +:::{include} ../_snippets/functions/layout/mv_intersect.md +::: + :::{include} ../_snippets/functions/layout/mv_last.md ::: diff --git a/docs/reference/query-languages/esql/images/functions/mv_intersect.svg b/docs/reference/query-languages/esql/images/functions/mv_intersect.svg new file mode 100644 index 0000000000000..a439854651513 --- /dev/null +++ b/docs/reference/query-languages/esql/images/functions/mv_intersect.svg @@ -0,0 +1 @@ +MV_INTERSECT(field1,field2) \ No newline at end of file diff --git a/docs/reference/query-languages/esql/kibana/definition/functions/mv_intersect.json b/docs/reference/query-languages/esql/kibana/definition/functions/mv_intersect.json new file mode 100644 index 0000000000000..fd5b961af6600 --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/definition/functions/mv_intersect.json @@ -0,0 +1,377 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it.", + "type" : "scalar", + "name" : "mv_intersect", + "description" : "Returns a subset of the inputs sets that contains the intersection of values in provided mv arguments.", + "signatures" : [ + { + "params" : [ + { + "name" : "field1", + "type" : "boolean", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "boolean", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "cartesian_point", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "cartesian_point", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "cartesian_point" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "cartesian_shape", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "cartesian_shape", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "cartesian_shape" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "date", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "date", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "date" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "date_nanos", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "date_nanos", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "date_nanos" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "double", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "geo_point", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "geo_point", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "geo_point" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "geo_shape", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "geo_shape", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "geo_shape" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "geohash", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "geohash", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "geohash" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "geohex", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "geohex", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "geohex" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "geotile", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "geotile", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "geotile" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "integer", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "ip", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "ip", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "ip" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "keyword", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "keyword", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "keyword", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "text", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "long", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "long" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "text", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "keyword", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "text", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "text", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "keyword" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "unsigned_long", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "unsigned_long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "unsigned_long" + }, + { + "params" : [ + { + "name" : "field1", + "type" : "version", + "optional" : false, + "description" : "" + }, + { + "name" : "field2", + "type" : "version", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "version" + } + ], + "examples" : [ + "ROW a = [1, 2, 3, 4, 5], b = [2, 3, 4, 5, 6]\n| EVAL finalValue = MV_INTERSECT(a, b)\n| KEEP finalValue", + "ROW a = [1, 2, 3, 4, 5]::long, b = [2, 3, 4, 5, 6]::long\n| EVAL finalValue = MV_INTERSECT(a, b)\n| KEEP finalValue", + "ROW a = [true, false, false, false], b = [false]\n| EVAL finalValue = MV_INTERSECT(a, b)\n| KEEP finalValue", + "ROW a = [5.2, 10.5, 1.12345, 2.6928], b = [10.5, 2.6928]\n| EVAL finalValue = MV_INTERSECT(a, b)\n| KEEP finalValue", + "ROW a = [\"one\", \"two\", \"three\", \"four\", \"five\"], b = [\"one\", \"four\"]\n| EVAL finalValue = MV_INTERSECT(a, b)\n| KEEP finalValue" + ], + "preview" : true, + "snapshot_only" : false +} diff --git a/docs/reference/query-languages/esql/kibana/docs/functions/mv_intersect.md b/docs/reference/query-languages/esql/kibana/docs/functions/mv_intersect.md new file mode 100644 index 0000000000000..e277a2ad6910e --- /dev/null +++ b/docs/reference/query-languages/esql/kibana/docs/functions/mv_intersect.md @@ -0,0 +1,10 @@ +% This is generated by ESQL's AbstractFunctionTestCase. Do not edit it. See ../README.md for how to regenerate it. + +### MV INTERSECT +Returns a subset of the inputs sets that contains the intersection of values in provided mv arguments. + +```esql +ROW a = [1, 2, 3, 4, 5], b = [2, 3, 4, 5, 6] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +``` diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_intersect.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_intersect.csv-spec new file mode 100644 index 0000000000000..cfa6d476ad17e --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_intersect.csv-spec @@ -0,0 +1,134 @@ +testMvIntersectWithIntValues +required_capability: fn_mv_intersect + +// tag::testMvIntersectWithIntValues[] +ROW a = [1, 2, 3, 4, 5], b = [2, 3, 4, 5, 6] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +// end::testMvIntersectWithIntValues[] +; + +// tag::testMvIntersectWithIntValues-result[] +finalValue:integer +[2, 3, 4, 5] +// end::testMvIntersectWithIntValues-result[] +; + +testMvIntersectWithLongValues +required_capability: fn_mv_intersect + +// tag::testMvIntersectWithLongValues[] +ROW a = [1, 2, 3, 4, 5]::long, b = [2, 3, 4, 5, 6]::long +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +// end::testMvIntersectWithLongValues[] +; + +// tag::testMvIntersectWithLongValues-result[] +finalValue:long +[2, 3, 4, 5] +// end::testMvIntersectWithLongValues-result[] +; + +testMvIntersectWithBooleanValues +required_capability: fn_mv_intersect + +// tag::testMvIntersectWithBooleanValues[] +ROW a = [true, false, false, false], b = [false] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +// end::testMvIntersectWithBooleanValues[] +; + +// tag::testMvIntersectWithBooleanValues-result[] +finalValue:boolean +[false] +// end::testMvIntersectWithBooleanValues-result[] +; + +testMvIntersectWithDoubleValues +required_capability: fn_mv_intersect + +// tag::testMvIntersectWithDoubleValues[] +ROW a = [5.2, 10.5, 1.12345, 2.6928], b = [10.5, 2.6928] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +// end::testMvIntersectWithDoubleValues[] +; + +// tag::testMvIntersectWithDoubleValues-result[] +finalValue:double +[10.5, 2.6928] +// end::testMvIntersectWithDoubleValues-result[] +; + +testMvIntersectWithBytesRefValues +required_capability: fn_mv_intersect + +// tag::testMvIntersectWithBytesRefValues[] +ROW a = ["one", "two", "three", "four", "five"], b = ["one", "four"] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue +// end::testMvIntersectWithBytesRefValues[] +; + +// tag::testMvIntersectWithBytesRefValues-result[] +finalValue:keyword +["one", "four"] +// end::testMvIntersectWithBytesRefValues-result[] +; + +testMvIntersectNullReturnedWhenNoIntersection +required_capability: fn_mv_intersect + +ROW a = [1, 2, 3, 4], b = [5, 6, 7, 8] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue; + +finalValue:integer +null +; + +testMvIntersectSingleValueReturnedWhenOnlyOneIntersection +required_capability: fn_mv_intersect + +ROW a = [1, 2, 3, 4], b = [4, 5, 6, 7] +| EVAL finalValue = MV_INTERSECT(a, b) +| KEEP finalValue; + +finalValue:integer +4 +; + +testMvIntersectNullReturnedWhenFirstArgIsNull +required_capability: fn_mv_intersect + +ROW a = [1, 2, 3, 4] +| EVAL finalValue = MV_INTERSECT(null, a) +| KEEP finalValue; + +finalValue:integer +null +; + +testMvIntersectNullReturnedWhenSecondArgIsNull +required_capability: fn_mv_intersect + +ROW a = [1, 2, 3, 4] +| EVAL finalValue = MV_INTERSECT(a, null) +| KEEP finalValue; + +finalValue:integer +null +; + +testMvIntersectNullReturnedWhenBothArgsIsNull +required_capability: fn_mv_intersect + +ROW a = [1, 2, 3, 4] +| EVAL finalValue = MV_INTERSECT(null, null) +| KEEP finalValue; + +finalValue:null +null +; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBooleanEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBooleanEvaluator.java new file mode 100644 index 0000000000000..bfc1db7065634 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBooleanEvaluator.java @@ -0,0 +1,127 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvIntersect}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class MvIntersectBooleanEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvIntersectBooleanEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator field1; + + private final EvalOperator.ExpressionEvaluator field2; + + private final DriverContext driverContext; + + private Warnings warnings; + + public MvIntersectBooleanEvaluator(Source source, EvalOperator.ExpressionEvaluator field1, + EvalOperator.ExpressionEvaluator field2, DriverContext driverContext) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (BooleanBlock field1Block = (BooleanBlock) field1.eval(page)) { + try (BooleanBlock field2Block = (BooleanBlock) field2.eval(page)) { + return eval(page.getPositionCount(), field1Block, field2Block); + } + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += field1.baseRamBytesUsed(); + baseRamBytesUsed += field2.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public BooleanBlock eval(int positionCount, BooleanBlock field1Block, BooleanBlock field2Block) { + try(BooleanBlock.Builder result = driverContext.blockFactory().newBooleanBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!field1Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (!field2Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvIntersect.process(result, p, field1Block, field2Block); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvIntersectBooleanEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(field1, field2); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory field1; + + private final EvalOperator.ExpressionEvaluator.Factory field2; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field1, + EvalOperator.ExpressionEvaluator.Factory field2) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + } + + @Override + public MvIntersectBooleanEvaluator get(DriverContext context) { + return new MvIntersectBooleanEvaluator(source, field1.get(context), field2.get(context), context); + } + + @Override + public String toString() { + return "MvIntersectBooleanEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBytesRefEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBytesRefEvaluator.java new file mode 100644 index 0000000000000..b274aeca1d3d0 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectBytesRefEvaluator.java @@ -0,0 +1,128 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvIntersect}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class MvIntersectBytesRefEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvIntersectBytesRefEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator field1; + + private final EvalOperator.ExpressionEvaluator field2; + + private final DriverContext driverContext; + + private Warnings warnings; + + public MvIntersectBytesRefEvaluator(Source source, EvalOperator.ExpressionEvaluator field1, + EvalOperator.ExpressionEvaluator field2, DriverContext driverContext) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (BytesRefBlock field1Block = (BytesRefBlock) field1.eval(page)) { + try (BytesRefBlock field2Block = (BytesRefBlock) field2.eval(page)) { + return eval(page.getPositionCount(), field1Block, field2Block); + } + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += field1.baseRamBytesUsed(); + baseRamBytesUsed += field2.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public BytesRefBlock eval(int positionCount, BytesRefBlock field1Block, + BytesRefBlock field2Block) { + try(BytesRefBlock.Builder result = driverContext.blockFactory().newBytesRefBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!field1Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (!field2Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvIntersect.process(result, p, field1Block, field2Block); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvIntersectBytesRefEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(field1, field2); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory field1; + + private final EvalOperator.ExpressionEvaluator.Factory field2; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field1, + EvalOperator.ExpressionEvaluator.Factory field2) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + } + + @Override + public MvIntersectBytesRefEvaluator get(DriverContext context) { + return new MvIntersectBytesRefEvaluator(source, field1.get(context), field2.get(context), context); + } + + @Override + public String toString() { + return "MvIntersectBytesRefEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectDoubleEvaluator.java new file mode 100644 index 0000000000000..09f1c35198a43 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectDoubleEvaluator.java @@ -0,0 +1,127 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvIntersect}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class MvIntersectDoubleEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvIntersectDoubleEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator field1; + + private final EvalOperator.ExpressionEvaluator field2; + + private final DriverContext driverContext; + + private Warnings warnings; + + public MvIntersectDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator field1, + EvalOperator.ExpressionEvaluator field2, DriverContext driverContext) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (DoubleBlock field1Block = (DoubleBlock) field1.eval(page)) { + try (DoubleBlock field2Block = (DoubleBlock) field2.eval(page)) { + return eval(page.getPositionCount(), field1Block, field2Block); + } + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += field1.baseRamBytesUsed(); + baseRamBytesUsed += field2.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public DoubleBlock eval(int positionCount, DoubleBlock field1Block, DoubleBlock field2Block) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!field1Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (!field2Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvIntersect.process(result, p, field1Block, field2Block); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvIntersectDoubleEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(field1, field2); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory field1; + + private final EvalOperator.ExpressionEvaluator.Factory field2; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field1, + EvalOperator.ExpressionEvaluator.Factory field2) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + } + + @Override + public MvIntersectDoubleEvaluator get(DriverContext context) { + return new MvIntersectDoubleEvaluator(source, field1.get(context), field2.get(context), context); + } + + @Override + public String toString() { + return "MvIntersectDoubleEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectIntEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectIntEvaluator.java new file mode 100644 index 0000000000000..f30a411429a0c --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectIntEvaluator.java @@ -0,0 +1,127 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvIntersect}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class MvIntersectIntEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvIntersectIntEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator field1; + + private final EvalOperator.ExpressionEvaluator field2; + + private final DriverContext driverContext; + + private Warnings warnings; + + public MvIntersectIntEvaluator(Source source, EvalOperator.ExpressionEvaluator field1, + EvalOperator.ExpressionEvaluator field2, DriverContext driverContext) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (IntBlock field1Block = (IntBlock) field1.eval(page)) { + try (IntBlock field2Block = (IntBlock) field2.eval(page)) { + return eval(page.getPositionCount(), field1Block, field2Block); + } + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += field1.baseRamBytesUsed(); + baseRamBytesUsed += field2.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public IntBlock eval(int positionCount, IntBlock field1Block, IntBlock field2Block) { + try(IntBlock.Builder result = driverContext.blockFactory().newIntBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!field1Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (!field2Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvIntersect.process(result, p, field1Block, field2Block); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvIntersectIntEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(field1, field2); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory field1; + + private final EvalOperator.ExpressionEvaluator.Factory field2; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field1, + EvalOperator.ExpressionEvaluator.Factory field2) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + } + + @Override + public MvIntersectIntEvaluator get(DriverContext context) { + return new MvIntersectIntEvaluator(source, field1.get(context), field2.get(context), context); + } + + @Override + public String toString() { + return "MvIntersectIntEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectLongEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectLongEvaluator.java new file mode 100644 index 0000000000000..e4f7cf0aafca4 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectLongEvaluator.java @@ -0,0 +1,127 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import org.apache.lucene.util.RamUsageEstimator; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.Warnings; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.xpack.esql.core.tree.Source; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvIntersect}. + * This class is generated. Edit {@code EvaluatorImplementer} instead. + */ +public final class MvIntersectLongEvaluator implements EvalOperator.ExpressionEvaluator { + private static final long BASE_RAM_BYTES_USED = RamUsageEstimator.shallowSizeOfInstance(MvIntersectLongEvaluator.class); + + private final Source source; + + private final EvalOperator.ExpressionEvaluator field1; + + private final EvalOperator.ExpressionEvaluator field2; + + private final DriverContext driverContext; + + private Warnings warnings; + + public MvIntersectLongEvaluator(Source source, EvalOperator.ExpressionEvaluator field1, + EvalOperator.ExpressionEvaluator field2, DriverContext driverContext) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + this.driverContext = driverContext; + } + + @Override + public Block eval(Page page) { + try (LongBlock field1Block = (LongBlock) field1.eval(page)) { + try (LongBlock field2Block = (LongBlock) field2.eval(page)) { + return eval(page.getPositionCount(), field1Block, field2Block); + } + } + } + + @Override + public long baseRamBytesUsed() { + long baseRamBytesUsed = BASE_RAM_BYTES_USED; + baseRamBytesUsed += field1.baseRamBytesUsed(); + baseRamBytesUsed += field2.baseRamBytesUsed(); + return baseRamBytesUsed; + } + + public LongBlock eval(int positionCount, LongBlock field1Block, LongBlock field2Block) { + try(LongBlock.Builder result = driverContext.blockFactory().newLongBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!field1Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (!field2Block.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvIntersect.process(result, p, field1Block, field2Block); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvIntersectLongEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(field1, field2); + } + + private Warnings warnings() { + if (warnings == null) { + this.warnings = Warnings.createWarnings( + driverContext.warningsMode(), + source.source().getLineNumber(), + source.source().getColumnNumber(), + source.text() + ); + } + return warnings; + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory field1; + + private final EvalOperator.ExpressionEvaluator.Factory field2; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory field1, + EvalOperator.ExpressionEvaluator.Factory field2) { + this.source = source; + this.field1 = field1; + this.field2 = field2; + } + + @Override + public MvIntersectLongEvaluator get(DriverContext context) { + return new MvIntersectLongEvaluator(source, field1.get(context), field2.get(context), context); + } + + @Override + public String toString() { + return "MvIntersectLongEvaluator[" + "field1=" + field1 + ", field2=" + field2 + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 4ce5333592029..2526ecfa2b4c0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -1746,6 +1746,11 @@ public enum Cap { */ FIX_INLINE_STATS_INCORRECT_PRUNNING(INLINE_STATS.enabled), + /** + * Support for the MV_INTERSECT function which returns the set intersection of two multivalued fields + */ + FN_MV_INTERSECT, + // Last capability should still have a comma for fewer merge conflicts when adding new ones :) // This comment prevents the semicolon from being on the previous capability when Spotless formats the file. ; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index c87714dee048c..91d6053d761d6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -160,6 +160,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvDedupe; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvFirst; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvIntersect; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvLast; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian; @@ -515,6 +516,7 @@ private static FunctionDefinition[][] functions() { def(MvCount.class, MvCount::new, "mv_count"), def(MvDedupe.class, MvDedupe::new, "mv_dedupe"), def(MvFirst.class, MvFirst::new, "mv_first"), + def(MvIntersect.class, MvIntersect::new, "mv_intersect"), def(MvLast.class, MvLast::new, "mv_last"), def(MvMax.class, MvMax::new, "mv_max"), def(MvMedian.class, MvMedian::new, "mv_median"), diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java index 8dafc630e0e02..72ef3fc75f28b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvFunctionWritables.java @@ -21,6 +21,7 @@ public static List getNamedWriteables() { MvCount.ENTRY, MvDedupe.ENTRY, MvFirst.ENTRY, + MvIntersect.ENTRY, MvLast.ENTRY, MvMax.ENTRY, MvMedian.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersect.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersect.java new file mode 100644 index 0000000000000..5ce5002fe1668 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersect.java @@ -0,0 +1,320 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.ann.Position; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BytesRefBlock; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.LongBlock; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesTo; +import org.elasticsearch.xpack.esql.expression.function.FunctionAppliesToLifecycle; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; + +import java.io.IOException; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Consumer; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isRepresentableExceptCountersDenseVectorAggregateMetricDoubleAndExponentialHistogram; + +/** + * Adds a function to return a result set with multivalued items that are contained in the input sets. + * Example: + * Given set A = {"a","b","c"} and set B = {"b","c","d"}, MV_INTERSECT(A, B) returns {"b", "c"} + */ +public class MvIntersect extends EsqlScalarFunction implements EvaluatorMapper { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MvIntersect", + MvIntersect::new + ); + + private final Expression field1; + private final Expression field2; + + private DataType dataType; + + @FunctionInfo( + returnType = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "geohash", + "geotile", + "geohex", + "integer", + "ip", + "keyword", + "long", + "unsigned_long", + "version" }, + description = "Returns a subset of the inputs sets that contains the intersection of values in provided mv arguments.", + preview = true, + examples = { + @Example(file = "mv_intersect", tag = "testMvIntersectWithIntValues"), + @Example(file = "mv_intersect", tag = "testMvIntersectWithLongValues"), + @Example(file = "mv_intersect", tag = "testMvIntersectWithBooleanValues"), + @Example(file = "mv_intersect", tag = "testMvIntersectWithDoubleValues"), + @Example(file = "mv_intersect", tag = "testMvIntersectWithBytesRefValues") }, + appliesTo = { @FunctionAppliesTo(lifeCycle = FunctionAppliesToLifecycle.PREVIEW, version = "9.3.0") } + ) + public MvIntersect( + Source source, + @Param( + name = "field1", + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "geohash", + "geotile", + "geohex", + "integer", + "ip", + "keyword", + "long", + "text", + "unsigned_long", + "version" } + ) Expression field1, + @Param( + name = "field2", + type = { + "boolean", + "cartesian_point", + "cartesian_shape", + "date", + "date_nanos", + "double", + "geo_point", + "geo_shape", + "geohash", + "geotile", + "geohex", + "integer", + "ip", + "keyword", + "long", + "text", + "unsigned_long", + "version" } + ) Expression field2 + ) { + super(source, List.of(field1, field2)); + this.field1 = field1; + this.field2 = field2; + } + + private MvIntersect(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); + } + + @Override + public boolean foldable() { + return field1.foldable() && field2.foldable(); + } + + @Evaluator(extraName = "Boolean") + static void process(BooleanBlock.Builder builder, @Position int position, BooleanBlock field1, BooleanBlock field2) { + processIntersectSet(builder, position, field1, field2, (p, block) -> ((BooleanBlock) block).getBoolean(p), builder::appendBoolean); + } + + @Evaluator(extraName = "BytesRef") + static void process(BytesRefBlock.Builder builder, @Position int position, BytesRefBlock field1, BytesRefBlock field2) { + processIntersectSet(builder, position, field1, field2, (p, block) -> { + BytesRef value = new BytesRef(); + return ((BytesRefBlock) block).getBytesRef(p, value); + }, builder::appendBytesRef); + } + + @Evaluator(extraName = "Int") + static void process(IntBlock.Builder builder, @Position int position, IntBlock field1, IntBlock field2) { + processIntersectSet(builder, position, field1, field2, (p, block) -> ((IntBlock) block).getInt(p), builder::appendInt); + } + + @Evaluator(extraName = "Long") + static void process(LongBlock.Builder builder, @Position int position, LongBlock field1, LongBlock field2) { + processIntersectSet(builder, position, field1, field2, (p, block) -> ((LongBlock) block).getLong(p), builder::appendLong); + } + + @Evaluator(extraName = "Double") + static void process(DoubleBlock.Builder builder, @Position int position, DoubleBlock field1, DoubleBlock field2) { + processIntersectSet(builder, position, field1, field2, (p, block) -> ((DoubleBlock) block).getDouble(p), builder::appendDouble); + } + + @Override + public DataType dataType() { + if (dataType == null) { + resolveType(); + } + return dataType; + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + // ensure all children are the same type + ElementType field1Type = PlannerUtils.toElementType(field1.dataType()); + ElementType field2Type = PlannerUtils.toElementType(field2.dataType()); + + if (field1Type != field2Type && field1Type.equals(ElementType.NULL) == false && field2Type.equals(ElementType.NULL) == false) { + return new TypeResolution("All child fields must be the same type"); + } + + Expression evaluatedField = field1Type.equals(ElementType.NULL) ? field2 : field1; + + this.dataType = evaluatedField.dataType().noText(); + + TypeResolution resolution = isRepresentableExceptCountersDenseVectorAggregateMetricDoubleAndExponentialHistogram( + evaluatedField, + sourceText(), + FIRST + ); + if (resolution.unresolved()) { + return resolution; + } + + return resolution; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new MvIntersect(source(), newChildren.getFirst(), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, MvIntersect::new, field1, field2); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field1); + out.writeNamedWriteable(field2); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(ToEvaluator toEvaluator) { + return switch (PlannerUtils.toElementType(dataType())) { + case BOOLEAN -> new MvIntersectBooleanEvaluator.Factory(source(), toEvaluator.apply(field1), toEvaluator.apply(field2)); + case BYTES_REF -> new MvIntersectBytesRefEvaluator.Factory(source(), toEvaluator.apply(field1), toEvaluator.apply(field2)); + case INT -> new MvIntersectIntEvaluator.Factory(source(), toEvaluator.apply(field1), toEvaluator.apply(field2)); + case LONG -> new MvIntersectLongEvaluator.Factory(source(), toEvaluator.apply(field1), toEvaluator.apply(field2)); + case DOUBLE -> new MvIntersectDoubleEvaluator.Factory(source(), toEvaluator.apply(field1), toEvaluator.apply(field2)); + case NULL -> EvalOperator.CONSTANT_NULL_FACTORY; + default -> throw EsqlIllegalArgumentException.illegalDataType(dataType); + }; + } + + @Override + public Nullability nullable() { + return Nullability.TRUE; + } + + @Override + public int hashCode() { + return Objects.hash(field1, field2); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || obj.getClass() != getClass()) { + return false; + } + MvIntersect other = (MvIntersect) obj; + return Objects.equals(other.field1, field1) && Objects.equals(other.field2, field2); + } + + static void processIntersectSet( + Block.Builder builder, + int position, + Block field1, + Block field2, + BiFunction getValueFunction, + Consumer addValueFunction + ) { + int firstValueCount = field1.getValueCount(position); + int secondValueCount = field2.getValueCount(position); + if (firstValueCount == 0 || secondValueCount == 0) { + // if either block has no values, there will be no intersection + builder.appendNull(); + return; + } + + int firstValueIndex = field1.getFirstValueIndex(position); + int secondValueIndex = field2.getFirstValueIndex(position); + + Set values = new LinkedHashSet<>(); + for (int i = 0; i < firstValueCount; i++) { + values.add(getValueFunction.apply(firstValueIndex + i, field1)); + } + + Set secondValues = new HashSet<>(); + for (int i = 0; i < secondValueCount; i++) { + secondValues.add(getValueFunction.apply(secondValueIndex + i, field2)); + } + + values.retainAll(secondValues); + if (values.isEmpty()) { + builder.appendNull(); + return; + } + + builder.beginPositionEntry(); + for (T value : values) { + addValueFunction.accept(value); + } + builder.endPositionEntry(); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index 9d68ad9ac7f1b..5115a5b3a4455 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -3435,6 +3435,23 @@ public void testTopSnippetsFunctionInvalidInputs() { } + public void testMvIntersectValidatesDataTypesAreEqual() { + String[] values = { + "[\"one\", \"two\", \"three\", \"four\", \"five\"]", + "[1, 2, 3, 4, 5]", + "[1, 2, 3, 4, 5]::long", + "[1.1, 2.2, 3.3, 4.4, 5.5]" }; + for (int i = 0; i < values.length; i++) { + for (int j = 0; j < values.length; j++) { + if (i == j) { + continue; + } + String query = "ROW a = " + values[i] + ", b = " + values[j] + " | EVAL finalValue = MV_INTERSECT(a, b)"; + assertThat(error(query, tsdb), containsString(": All child fields must be the same type")); + } + } + } + private void checkVectorFunctionsNullArgs(String functionInvocation) throws Exception { query("from test | eval similarity = " + functionInvocation, fullTextAnalyzer); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectTests.java new file mode 100644 index 0000000000000..10745d8827c43 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvIntersectTests.java @@ -0,0 +1,293 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.geo.GeometryTestUtils; +import org.elasticsearch.geo.ShapeTestUtils; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import org.hamcrest.Matcher; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.randomLiteral; +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.CARTESIAN; +import static org.elasticsearch.xpack.esql.core.util.SpatialCoordinateTypes.GEO; +import static org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSliceTests.randomGrid; +import static org.hamcrest.Matchers.equalTo; + +public class MvIntersectTests extends AbstractScalarFunctionTestCase { + + public MvIntersectTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List suppliers = new ArrayList<>(); + booleans(suppliers); + ints(suppliers); + longs(suppliers); + doubles(suppliers); + bytesRefs(suppliers); + return parameterSuppliersFromTypedData(anyNullIsNull(true, suppliers)); + } + + @Override + protected Expression build(Source source, List args) { + return new MvIntersect(source, args.get(0), args.get(1)); + } + + private static Matcher matchResult(HashSet result) { + if (result == null || result.isEmpty()) { + return equalTo(null); + } + + if (result.size() > 1) { + return equalTo(new ArrayList<>(result)); + } + + return equalTo(result.stream().findFirst().get()); + } + + private static void booleans(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.BOOLEAN, DataType.BOOLEAN), () -> { + List field1 = randomList(1, 10, () -> randomBoolean()); + List field2 = randomList(1, 10, () -> randomBoolean()); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.BOOLEAN, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.BOOLEAN, "field2") + ), + "MvIntersectBooleanEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.BOOLEAN, + matchResult(result) + ); + })); + } + + private static void ints(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER), () -> { + List field1 = randomList(1, 10, () -> randomIntBetween(1, 10)); + List field2 = randomList(1, 10, () -> randomIntBetween(1, 10)); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.INTEGER, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.INTEGER, "field2") + ), + "MvIntersectIntEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.INTEGER, + matchResult(result) + ); + })); + } + + private static void longs(List suppliers) { + addLongTestCase(suppliers, DataType.LONG, ESTestCase::randomLong); + addLongTestCase(suppliers, DataType.DATETIME, ESTestCase::randomLong); + addLongTestCase(suppliers, DataType.DATE_NANOS, ESTestCase::randomNonNegativeLong); + for (DataType gridType : new DataType[] { DataType.GEOHASH, DataType.GEOTILE, DataType.GEOHEX }) { + addLongTestCase(suppliers, gridType, () -> randomGrid(gridType)); + } + + suppliers.add(new TestCaseSupplier(List.of(DataType.UNSIGNED_LONG, DataType.UNSIGNED_LONG), () -> { + List field1 = randomList(1, 10, ESTestCase::randomLong); + List field2 = randomList(1, 10, ESTestCase::randomLong); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.UNSIGNED_LONG, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.UNSIGNED_LONG, "field2") + ), + "MvIntersectLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.UNSIGNED_LONG, + matchResult(result) + ); + })); + } + + private static void addLongTestCase(List suppliers, DataType dataType, Supplier longSupplier) { + suppliers.add(new TestCaseSupplier(List.of(dataType, dataType), () -> { + List field1 = randomList(1, 10, longSupplier); + List field2 = randomList(1, 10, longSupplier); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, dataType, "field1"), + new TestCaseSupplier.TypedData(field2, dataType, "field2") + ), + "MvIntersectLongEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + dataType, + matchResult(result) + ); + })); + } + + private static void doubles(List suppliers) { + suppliers.add(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field1 = randomList(1, 10, () -> randomDouble()); + List field2 = randomList(1, 10, () -> randomDouble()); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.DOUBLE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.DOUBLE, "field2") + ), + "MvIntersectDoubleEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.DOUBLE, + matchResult(result) + ); + })); + } + + private static void bytesRefs(List suppliers) { + for (DataType lhs : new DataType[] { DataType.KEYWORD, DataType.TEXT }) { + for (DataType rhs : new DataType[] { DataType.KEYWORD, DataType.TEXT }) { + suppliers.add(new TestCaseSupplier(List.of(lhs, rhs), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(lhs).value()); + List field2 = randomList(1, 10, () -> randomLiteral(rhs).value()); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, lhs, "field1"), + new TestCaseSupplier.TypedData(field2, rhs, "field2") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.KEYWORD, + matchResult(result) + ); + })); + } + } + suppliers.add(new TestCaseSupplier(List.of(DataType.IP, DataType.IP), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(DataType.IP).value()); + List field2 = randomList(1, 10, () -> randomLiteral(DataType.IP).value()); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.IP, "field"), + new TestCaseSupplier.TypedData(field2, DataType.IP, "field") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.IP, + matchResult(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.VERSION, DataType.VERSION), () -> { + List field1 = randomList(1, 10, () -> randomLiteral(DataType.VERSION).value()); + List field2 = randomList(1, 10, () -> randomLiteral(DataType.VERSION).value()); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.VERSION, "field"), + new TestCaseSupplier.TypedData(field2, DataType.VERSION, "field") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.VERSION, + matchResult(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.GEO_POINT, DataType.GEO_POINT), () -> { + List field1 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomPoint()))); + List field2 = randomList(1, 10, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomPoint()))); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.GEO_POINT, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.GEO_POINT, "field2") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.GEO_POINT, + matchResult(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.CARTESIAN_POINT, DataType.CARTESIAN_POINT), () -> { + List field1 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint()))); + List field2 = randomList(1, 10, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomPoint()))); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.CARTESIAN_POINT, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.CARTESIAN_POINT, "field2") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.CARTESIAN_POINT, + matchResult(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.GEO_SHAPE, DataType.GEO_SHAPE), () -> { + var field1 = randomList(1, 3, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean(), 500)))); + var field2 = randomList(1, 3, () -> new BytesRef(GEO.asWkt(GeometryTestUtils.randomGeometry(randomBoolean(), 500)))); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.GEO_SHAPE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.GEO_SHAPE, "field2") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.GEO_SHAPE, + matchResult(result) + ); + })); + + suppliers.add(new TestCaseSupplier(List.of(DataType.CARTESIAN_SHAPE, DataType.CARTESIAN_SHAPE), () -> { + var field1 = randomList(1, 3, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean(), 500)))); + var field2 = randomList(1, 3, () -> new BytesRef(CARTESIAN.asWkt(ShapeTestUtils.randomGeometry(randomBoolean(), 500)))); + var result = new LinkedHashSet<>(field1); + result.retainAll(field2); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field1, DataType.CARTESIAN_SHAPE, "field1"), + new TestCaseSupplier.TypedData(field2, DataType.CARTESIAN_SHAPE, "field2") + ), + "MvIntersectBytesRefEvaluator[field1=Attribute[channel=0], field2=Attribute[channel=1]]", + DataType.CARTESIAN_SHAPE, + matchResult(result) + ); + })); + } +}