Skip to content

fix: constraint parsing, roundtripping #3503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/core/src/delta_datafusion/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ impl Display for BinaryExprFormat<'_> {
impl Display for SqlFormat<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.expr {
Expr::Column(c) => write!(f, "{c}"),
Expr::Column(c) => write!(f, "{}", c.quoted_flat_name()),
Expr::Literal(v) => write!(f, "{}", ScalarValueFormat { scalar: v }),
Expr::Case(case) => {
write!(f, "CASE ")?;
Expand Down Expand Up @@ -727,7 +727,7 @@ mod test {
},
simple!(
Expr::Column(Column::from_qualified_name_ignore_case("Value3")).eq(lit(3_i64)),
"Value3 = 3".to_string()
"\"Value3\" = 3".to_string()
),
simple!(col("active").is_true(), "active IS TRUE".to_string()),
simple!(col("active"), "active".to_string()),
Expand Down
28 changes: 9 additions & 19 deletions crates/core/src/delta_datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1591,40 +1591,30 @@ impl DeltaDataChecker {
return Ok(());
}
let table = MemTable::try_new(record_batch.schema(), vec![vec![record_batch.clone()]])?;
let schema = table.schema();
table.schema();
// Use a random table name to avoid clashes when running multiple parallel tasks, e.g. when using a partitioned table
let table_name: String = uuid::Uuid::new_v4().to_string();
self.ctx.register_table(&table_name, Arc::new(table))?;

let mut violations: Vec<String> = Vec::new();

for check in checks {
let check_name = check.get_name();
if check_name.contains('.') {
if check.get_name().contains('.') {
return Err(DeltaTableError::Generic(
"Support for nested columns is not supported.".to_string(),
"delta constraints for nested columns are not supported at the moment."
.to_string(),
));
}

let field_to_select = if check.as_any().is::<Constraint>() {
"*"
} else {
check_name
check.get_name()
};

// Loop through schema to find the matching field. If the field has a whitespace, we
// need to backtick it, since the expression is an unquoted string
let mut expression = check.get_expression().to_string();
for field in schema.fields() {
if expression.contains(field.name()) {
expression =
expression.replace(field.name(), format!("`{}` ", field.name()).as_str());
break;
}
}
let sql = format!(
"SELECT {} FROM `{table_name}` WHERE NOT ({}) LIMIT 1",
field_to_select, expression
field_to_select,
check.get_expression()
);

let dfs: Vec<RecordBatch> = self.ctx.sql(&sql).await?.collect().await?;
Expand Down Expand Up @@ -2441,7 +2431,7 @@ mod tests {
// Valid invariants return Ok(())
let constraints = vec![
Constraint::new("custom a", "a is not null"),
Constraint::new("custom_b", "b bop < 1000"),
Constraint::new("custom_b", "`b bop` < 1000"),
];
assert!(DeltaDataChecker::new_with_constraints(constraints)
.check_batch(&batch)
Expand All @@ -2451,7 +2441,7 @@ mod tests {
// Violated invariants returns an error with list of violations
let constraints = vec![
Constraint::new("custom_a", "a is null"),
Constraint::new("custom_B", "b bop < 100"),
Constraint::new("custom_B", "\"b bop\" < 100"),
];
let result = DeltaDataChecker::new_with_constraints(constraints)
.check_batch(&batch)
Expand Down
2 changes: 1 addition & 1 deletion crates/core/src/operations/constraints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ mod tests {
let version = table.version();
assert_eq!(version, Some(1));

let expected_expr = "vAlue < 1000"; // spellchecker:disable-line
let expected_expr = "\"vAlue\" < 1000"; // spellchecker:disable-line
assert_eq!(get_constraint_op_params(&mut table).await, expected_expr);
assert_eq!(
get_constraint(&table, "delta.constraints.valid_values"),
Expand Down
1 change: 0 additions & 1 deletion crates/core/src/operations/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,6 @@ impl std::future::IntoFuture for WriteBuilder {
Expression::String(s) => {
let df_schema = DFSchema::try_from(schema.as_ref().to_owned())?;
parse_predicate_expression(&df_schema, s, &state)?
// this.snapshot.unwrap().parse_predicate_expression(s, &state)?
}
};
(Some(fmt_expr_to_sql(&pred)?), Some(pred))
Expand Down
127 changes: 127 additions & 0 deletions python/tests/test_constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import pytest
from arro3.core import Array, DataType, Field, Schema, Table

from deltalake import DeltaTable, write_deltalake
from deltalake.exceptions import DeltaError, DeltaProtocolError


@pytest.fixture()
def sample_table() -> Table:
nrows = 5
return Table(
{
"id": Array(
["1", "2", "3", "4", "5"],
Field("id", type=DataType.string(), nullable=True),
),
"high price": Array(
list(range(nrows)),
Field("high price", type=DataType.int64(), nullable=True),
),
},
)


def test_not_corrupting_expression(tmp_path):
data = Table.from_pydict(
{
"b": Array([1], DataType.int64()),
"color_column": Array(["red"], DataType.string()),
},
)

data2 = Table.from_pydict(
{
"b": Array([1], DataType.int64()),
"color_column": Array(["blue"], DataType.string()),
},
)

write_deltalake(
tmp_path,
data,
mode="overwrite",
partition_by=["color_column"],
predicate="color_column = 'red'",
)
write_deltalake(
tmp_path,
data2,
mode="overwrite",
partition_by=["color_column"],
predicate="color_column = 'blue'",
)


def test_not_corrupting_expression_columns_spaced(tmp_path):
data = Table.from_pydict(
{
"b": Array([1], DataType.int64()),
"color column": Array(["red"], DataType.string()),
},
)

data2 = Table.from_pydict(
{
"b": Array([1], DataType.int64()),
"color column": Array(["blue"], DataType.string()),
},
)

write_deltalake(
tmp_path,
data,
mode="overwrite",
# partition_by=["color column"],
predicate="`color column` = 'red'",
)
write_deltalake(
tmp_path,
data2,
mode="overwrite",
# partition_by=["color column"],
predicate="`color column` = 'blue'",
)


# fmt: off

@pytest.mark.parametrize("sql_string", [
"`high price` >= 0",
'"high price" >= 0',
"\"high price\" >= 0"
])
def test_add_constraint(tmp_path, sample_table: Table, sql_string: str):
write_deltalake(tmp_path, sample_table)

dt = DeltaTable(tmp_path)

dt.alter.add_constraint({"check_price": sql_string})

last_action = dt.history(1)[0]
assert last_action["operation"] == "ADD CONSTRAINT"
assert dt.version() == 1
assert dt.metadata().configuration == {
"delta.constraints.check_price": '"high price" >= 0'
}
assert dt.protocol().min_writer_version == 3

with pytest.raises(DeltaError):
# Invalid constraint
dt.alter.add_constraint({"check_price": '"high price" < 0'})

with pytest.raises(DeltaProtocolError):
data = Table(
{
"id": Array(["1"], DataType.string()),
"high price": Array([-1], DataType.int64()),
},
schema=Schema(
fields=[
Field("id", type=DataType.string(), nullable=True),
Field("high price", type=DataType.int64(), nullable=True),
]
),
)

write_deltalake(tmp_path, data, mode="append")
Loading