diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 779f2b476f..a6268c0d48 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -915,17 +915,13 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi # In the order described by the "Column Projection" section of the Iceberg spec: # https://iceberg.apache.org/spec/#column-projection - # Evaluate column projection first if it exists - if field_id in self.projected_field_values: - if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)( - Record(self.projected_field_values[field_id]) - ): - return AlwaysTrue() - - # Evaluate initial_default value + # Evaluate column projection first if it exists, otherwise default to the initial-default-value + field_value = ( + self.projected_field_values[field_id] if field.field_id in self.projected_field_values else field.initial_default + ) return ( AlwaysTrue() - if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default)) + if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field_value)) else AlwaysFalse() ) @@ -940,7 +936,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi def translate_column_names( - expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT + expr: BooleanExpression, file_schema: Schema, case_sensitive: bool = True, projected_field_values: Dict[int, Any] = EMPTY_DICT ) -> BooleanExpression: return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values)) diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 997cc7f7d7..d0b6ab5ab4 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -1750,7 +1750,7 @@ def test_translate_column_names_missing_column_match_explicit_null() -> None: ) # Translate column names - translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None}) + translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: None}) # Should evaluate to AlwaysTrue because the missing column is treated as null # missing_col's default initial_default (None) satisfies the IsNull predicate @@ -1828,12 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() -> ) # Projected column that is missing in the file schema - projected_field_values = {2: 42} - - # Translate column names - translated_expr = translate_column_names( - bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values - ) + translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 42}) # Should evaluate to AlwaysTrue since projected field value matches the expression literal # even though the field is missing in the file schema @@ -1860,18 +1855,13 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() - ) # Projected column that is missing in the file schema - projected_field_values = {2: 1} - - # Translate column names - translated_expr = translate_column_names( - bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values - ) + translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 1}) # Should evaluate to AlwaysFalse since projected field value does not match the expression literal assert translated_expr == AlwaysFalse() -def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None: +def test_translate_column_names_missing_column_projected_field_ignores_initial_default() -> None: """Test translate_column_names when projected field value doesn't match but initial_default does.""" # Original schema with a field that has an initial_default original_schema = Schema( @@ -1891,43 +1881,11 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init ) # Projected field value that differs from both the expression literal and initial_default - projected_field_values = {2: 10} # This doesn't match expression literal (42) - - # Translate column names - translated_expr = translate_column_names( - bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values - ) - - # Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does - assert translated_expr == AlwaysTrue() - - -def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None: - """Test translate_column_names when both projected field value and initial_default doesn't match.""" - # Original schema with a field that has an initial_default that doesn't match the expression - original_schema = Schema( - NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), - NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10), - schema_id=1, - ) - - # Create bound expression for the missing column - unbound_expr = EqualTo("missing_col", 42) - bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) - - # File schema only has the existing column (field_id=1), missing field_id=2 - file_schema = Schema( - NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), - schema_id=1, - ) - - # Projected field value that matches the expression literal - projected_field_values = {2: 10} # This doesn't match expression literal (42) - - # Translate column names translated_expr = translate_column_names( - bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values + bound_expr, + file_schema, + projected_field_values={2: 10}, # This doesn't match expression literal (42) ) - # Should evaluate to AlwaysFalse since both projected field value and initial_default does not match + # Should evaluate to AlwaysFalse since projected field value doesn't match the expression literal assert translated_expr == AlwaysFalse() diff --git a/tests/integration/test_hive_migration.py b/tests/integration/test_hive_migration.py index 060450731e..51386d56c4 100644 --- a/tests/integration/test_hive_migration.py +++ b/tests/integration/test_hive_migration.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import time +from datetime import date import pytest from pyspark.sql import SparkSession @@ -75,12 +76,8 @@ def test_migrate_table( tbl = session_catalog_hive.load_table(dst_table_identifier) assert tbl.schema().column_names == ["number", "dt"] - # TODO: Returns the primitive type (int), rather than the logical type - # assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {'2022-01-01', '2023-01-01'} - + assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {date(2023, 1, 1), date(2022, 1, 1)} assert tbl.scan(row_filter="number > 3").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6] - assert tbl.scan(row_filter="dt == '2023-01-01'").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6] - - # TODO: Issue with filtering the projected column - # assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3] + assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3] + assert tbl.scan(row_filter="dt < '2022-02-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]