Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -915,17 +915,13 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi

# In the order described by the "Column Projection" section of the Iceberg spec:
# https://iceberg.apache.org/spec/#column-projection
# Evaluate column projection first if it exists
if field_id in self.projected_field_values:
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
Record(self.projected_field_values[field_id])
):
return AlwaysTrue()

# Evaluate initial_default value
# Evaluate column projection first if it exists, otherwise default to the initial-default-value
field_value = (
self.projected_field_values[field_id] if field.field_id in self.projected_field_values else field.initial_default
)
return (
AlwaysTrue()
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default))
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field_value))
else AlwaysFalse()
)

Expand All @@ -940,7 +936,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi


def translate_column_names(
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool = True, projected_field_values: Dict[int, Any] = EMPTY_DICT
) -> BooleanExpression:
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))

Expand Down
58 changes: 8 additions & 50 deletions tests/expressions/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,7 +1750,7 @@ def test_translate_column_names_missing_column_match_explicit_null() -> None:
)

# Translate column names
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: None})

# Should evaluate to AlwaysTrue because the missing column is treated as null
# missing_col's default initial_default (None) satisfies the IsNull predicate
Expand Down Expand Up @@ -1828,12 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
)

# Projected column that is missing in the file schema
projected_field_values = {2: 42}

# Translate column names
translated_expr = translate_column_names(
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
)
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 42})

# Should evaluate to AlwaysTrue since projected field value matches the expression literal
# even though the field is missing in the file schema
Expand All @@ -1860,18 +1855,13 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
)

# Projected column that is missing in the file schema
projected_field_values = {2: 1}

# Translate column names
translated_expr = translate_column_names(
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
)
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 1})

# Should evaluate to AlwaysFalse since projected field value does not match the expression literal
assert translated_expr == AlwaysFalse()


def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None:
def test_translate_column_names_missing_column_projected_field_ignores_initial_default() -> None:
"""Test translate_column_names when projected field value doesn't match but initial_default does."""
# Original schema with a field that has an initial_default
original_schema = Schema(
Expand All @@ -1891,43 +1881,11 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
)

# Projected field value that differs from both the expression literal and initial_default
projected_field_values = {2: 10} # This doesn't match expression literal (42)

# Translate column names
translated_expr = translate_column_names(
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
)

# Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does
assert translated_expr == AlwaysTrue()


def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None:
"""Test translate_column_names when neither the projected field value nor initial_default matches."""
# Original schema with a field that has an initial_default that doesn't match the expression
original_schema = Schema(
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10),
schema_id=1,
)

# Create bound expression for the missing column
unbound_expr = EqualTo("missing_col", 42)
bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True))

# File schema only has the existing column (field_id=1), missing field_id=2
file_schema = Schema(
NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False),
schema_id=1,
)

# Projected field value that equals initial_default (10) but does not match the expression literal
projected_field_values = {2: 10} # This doesn't match expression literal (42)

# Translate column names
translated_expr = translate_column_names(
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
bound_expr,
file_schema,
projected_field_values={2: 10}, # This doesn't match expression literal (42)
)

# Should evaluate to AlwaysFalse since neither the projected field value nor initial_default matches
# Should evaluate to AlwaysFalse since projected field value doesn't match the expression literal
assert translated_expr == AlwaysFalse()
11 changes: 4 additions & 7 deletions tests/integration/test_hive_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import time
from datetime import date

import pytest
from pyspark.sql import SparkSession
Expand Down Expand Up @@ -75,12 +76,8 @@ def test_migrate_table(
tbl = session_catalog_hive.load_table(dst_table_identifier)
assert tbl.schema().column_names == ["number", "dt"]

# TODO: Returns the primitive type (int), rather than the logical type
# assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {'2022-01-01', '2023-01-01'}

assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {date(2023, 1, 1), date(2022, 1, 1)}
assert tbl.scan(row_filter="number > 3").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]

assert tbl.scan(row_filter="dt == '2023-01-01'").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]

# TODO: Issue with filtering the projected column
# assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
assert tbl.scan(row_filter="dt < '2022-02-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]