Skip to content

remove stats correction from ES|QL sample #129319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class RestSampleTestCase extends ESRestTestCase {
public void skipWhenSampleDisabled() throws IOException {
assumeTrue(
"Requires SAMPLE capability",
EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.SAMPLE_V2.capabilityName()))
EsqlSpecTestCase.hasCapabilities(adminClient(), List.of(EsqlCapabilities.Cap.SAMPLE_V3.capabilityName()))
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
// because the CSV tests don't support such assertions.

row
required_capability: sample_v2
required_capability: sample_v3

ROW x = 1 | SAMPLE .999999999
;
Expand All @@ -20,7 +20,7 @@ x:integer


row and mv_expand
required_capability: sample_v2
required_capability: sample_v3

ROW x = [1,2,3,4,5] | MV_EXPAND x | SAMPLE .999999999
;
Expand All @@ -35,15 +35,14 @@ x:integer


adjust stats for sampling
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no), sum_emp_no = SUM(emp_no)
| EVAL is_expected = count >= 20 AND count <= 180 AND
values_count >= 10 AND values_count <= 90 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no), sum_emp_no = SUM(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
avg_emp_no > 10010 AND avg_emp_no < 10090 AND
sum_emp_no > 20*10010 AND sum_emp_no < 180*10090
sum_emp_no > 10*10010 AND sum_emp_no < 90*10090
| KEEP is_expected
;

Expand All @@ -53,14 +52,13 @@ true


before where
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.5
| WHERE emp_no > 10050
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 5 AND count <= 95 AND
values_count >= 2 AND values_count <= 48 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 2 AND count <= 48 AND
avg_emp_no > 10055 AND avg_emp_no < 10095
| KEEP is_expected
;
Expand All @@ -71,14 +69,13 @@ true


after where
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| WHERE emp_no <= 10050
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 5 AND count <= 95 AND
values_count >= 2 AND values_count <= 48 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 2 AND count <= 48 AND
avg_emp_no > 10005 AND avg_emp_no < 10045
| KEEP is_expected
;
Expand All @@ -89,14 +86,13 @@ true


before sort
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.5
| SORT emp_no
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 20 AND count <= 180 AND
values_count >= 10 AND values_count <= 90 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
Expand All @@ -107,14 +103,13 @@ true


after sort
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SORT emp_no
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 20 AND count <= 180 AND
values_count >= 10 AND values_count <= 90 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
Expand All @@ -125,13 +120,13 @@ true


before limit
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.5
| LIMIT 10
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no))
| EVAL is_expected = count == 10 AND values_count == 10
| STATS count = COUNT(emp_no)
| EVAL is_expected = count == 10
| KEEP is_expected
;

Expand All @@ -141,14 +136,13 @@ true


after limit
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| LIMIT 50
| SAMPLE 0.5
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no))
| EVAL is_expected = count >= 5 AND count <= 95 AND
values_count >= 2 AND values_count <= 48
| STATS count = COUNT(emp_no)
| EVAL is_expected = count >= 2 AND count <= 48
| KEEP is_expected
;

Expand All @@ -158,7 +152,7 @@ true


before mv_expand
required_capability: sample_v2
required_capability: sample_v3

ROW x = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50], y = [1,2]
| MV_EXPAND x
Expand All @@ -176,7 +170,7 @@ true


after mv_expand
required_capability: sample_v2
required_capability: sample_v3

ROW x = [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50], y = [1,2]
| MV_EXPAND x
Expand All @@ -194,15 +188,14 @@ true


multiple samples
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.7
| SAMPLE 0.8
| SAMPLE 0.9
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(emp_no)), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 20 AND count <= 180 AND
values_count >= 10 AND values_count <= 90 AND
| STATS count = COUNT(), avg_emp_no = AVG(emp_no)
| EVAL is_expected = count >= 10 AND count <= 90 AND
avg_emp_no > 10010 AND avg_emp_no < 10090
| KEEP is_expected
;
Expand All @@ -213,15 +206,14 @@ true


after stats
required_capability: sample_v2
required_capability: sample_v3

FROM employees
| SAMPLE 0.5
| STATS avg_salary = AVG(salary) BY job_positions
| SAMPLE 0.8
| STATS count = COUNT(), values_count = MV_COUNT(VALUES(avg_salary)), avg_avg_salary = AVG(avg_salary)
| EVAL is_expected = count >= 1 AND count <= 20 AND
values_count >= 1 AND values_count <= 16 AND
| STATS count = COUNT(), avg_avg_salary = AVG(avg_salary)
| EVAL is_expected = count >= 1 AND count <= 16 AND
avg_avg_salary > 25000 AND avg_avg_salary < 75000
| KEEP is_expected
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1082,7 +1082,7 @@ public enum Cap {
/**
* Support for the SAMPLE command
*/
SAMPLE_V2(Build.current().isSnapshot()),
SAMPLE_V3(Build.current().isSnapshot()),

/**
* The {@code _query} API now gives a cast recommendation if multiple types are found in certain instances.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvCount;
import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;
import org.elasticsearch.xpack.esql.planner.ToAggregator;

Expand All @@ -39,11 +37,9 @@
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT;
import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isType;

public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression, HasSampleCorrection {
public class Count extends AggregateFunction implements ToAggregator, SurrogateExpression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Count", Count::new);

private final boolean isSampleCorrected;

@FunctionInfo(
returnType = "long",
description = "Returns the total number (count) of input values.",
Expand Down Expand Up @@ -98,20 +94,11 @@ public Count(
}

public Count(Source source, Expression field, Expression filter) {
this(source, field, filter, false);
}

private Count(Source source, Expression field, Expression filter, boolean isSampleCorrected) {
super(source, field, filter, emptyList());
this.isSampleCorrected = isSampleCorrected;
}

private Count(StreamInput in) throws IOException {
super(in);
// isSampleCorrected is only used during query optimization to mark
// whether this function has been processed. Hence there's no need to
// serialize it.
this.isSampleCorrected = false;
}

@Override
Expand Down Expand Up @@ -182,14 +169,4 @@ public Expression surrogate() {

return null;
}

@Override
public boolean isSampleCorrected() {
return isSampleCorrected;
}

@Override
public Expression sampleCorrection(Expression sampleProbability) {
return new ToLong(source(), new Div(source(), new Count(source(), field(), filter(), true), sampleProbability));
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@
import org.elasticsearch.xpack.esql.expression.function.FunctionType;
import org.elasticsearch.xpack.esql.expression.function.Param;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.FromAggregateMetricDouble;
import org.elasticsearch.xpack.esql.expression.function.scalar.convert.ToLong;
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div;
import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul;

import java.io.IOException;
Expand All @@ -45,11 +43,9 @@
/**
* Sum all values of a field in matching documents.
*/
public class Sum extends NumericAggregate implements SurrogateExpression, HasSampleCorrection {
public class Sum extends NumericAggregate implements SurrogateExpression {
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Sum", Sum::new);

private final boolean isSampleCorrected;

@FunctionInfo(
returnType = { "long", "double" },
description = "The sum of a numeric expression.",
Expand All @@ -69,20 +65,11 @@ public Sum(Source source, @Param(name = "number", type = { "aggregate_metric_dou
}

public Sum(Source source, Expression field, Expression filter) {
this(source, field, filter, false);
}

private Sum(Source source, Expression field, Expression filter, boolean isSampleCorrected) {
super(source, field, filter, emptyList());
this.isSampleCorrected = isSampleCorrected;
}

private Sum(StreamInput in) throws IOException {
super(in);
// isSampleCorrected is only used during query optimization to mark
// whether this function has been processed. Hence there's no need to
// serialize it.
this.isSampleCorrected = false;
}

@Override
Expand Down Expand Up @@ -160,19 +147,4 @@ public Expression surrogate() {
? new Mul(s, new MvSum(s, field), new Count(s, new Literal(s, StringUtils.WILDCARD, DataType.KEYWORD)))
: null;
}

@Override
public boolean isSampleCorrected() {
return isSampleCorrected;
}

@Override
public Expression sampleCorrection(Expression sampleProbability) {
Expression correctedSum = new Div(source(), new Sum(source(), field(), filter(), true), sampleProbability);
return switch (dataType()) {
case DOUBLE -> correctedSum;
case LONG -> new ToLong(source(), correctedSum);
default -> throw new IllegalStateException("unexpected data type [" + dataType() + "]");
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import org.elasticsearch.xpack.esql.VerificationException;
import org.elasticsearch.xpack.esql.common.Failures;
import org.elasticsearch.xpack.esql.core.type.DataType;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.ApplySampleCorrections;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanFunctionEqualsElimination;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.BooleanSimplification;
import org.elasticsearch.xpack.esql.optimizer.rules.logical.CombineBinaryComparisons;
Expand Down Expand Up @@ -130,7 +129,6 @@ protected static Batch<LogicalPlan> substitutions() {
return new Batch<>(
"Substitutions",
Limiter.ONCE,
new ApplySampleCorrections(),
new SubstituteSurrogatePlans(),
// Translate filtered expressions into aggregate with filters - can't use surrogate expressions because it was
// retrofitted for constant folding - this needs to be fixed.
Expand Down
Loading