Skip to content

Leverage Iceberg-Rust for all the transforms #1833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 19 commits into from
Jun 7, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

128 changes: 40 additions & 88 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,36 @@ def parse_transform(v: Any) -> Any:
return v


def _pyiceberg_transform_wrapper(
transform_func: Callable[["ArrayLike", Any], "ArrayLike"],
*args: Any,
expected_type: Optional["pa.DataType"] = None,
) -> Callable[["ArrayLike"], "ArrayLike"]:
try:
import pyarrow as pa
except ModuleNotFoundError as e:
raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

def _transform(array: "ArrayLike") -> "ArrayLike":
def _cast_if_needed(arr: "ArrayLike") -> "ArrayLike":
if expected_type is not None:
return arr.cast(expected_type)
else:
return arr

if isinstance(array, pa.Array):
return _cast_if_needed(transform_func(array, *args))
elif isinstance(array, pa.ChunkedArray):
result_chunks = []
for arr in array.iterchunks():
result_chunks.append(_cast_if_needed(transform_func(arr, *args)))
return pa.chunked_array(result_chunks)
else:
raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

return _transform


class Transform(IcebergRootModel[str], ABC, Generic[S, T]):
"""Transform base class for concrete transforms.

Expand Down Expand Up @@ -198,27 +228,6 @@ def supports_pyarrow_transform(self) -> bool:
@abstractmethod
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...

def _pyiceberg_transform_wrapper(
    self, transform_func: Callable[["ArrayLike", Any], "ArrayLike"], *args: Any
) -> Callable[["ArrayLike"], "ArrayLike"]:
    """Adapt a pyiceberg-core transform function to accept PyArrow arrays.

    Args:
        transform_func: The underlying transform, invoked as ``transform_func(arr, *args)``
            on each plain ``pa.Array``.
        *args: Extra positional arguments forwarded to ``transform_func``
            (e.g. the bucket count or truncate width).

    Returns:
        A callable that applies the transform to a ``pa.Array`` or ``pa.ChunkedArray``.

    Raises:
        ModuleNotFoundError: If PyArrow is not installed.
    """
    try:
        import pyarrow as pa
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError("For bucket/truncate transforms, PyArrow needs to be installed") from e

    def _transform(array: "ArrayLike") -> "ArrayLike":
        if isinstance(array, pa.Array):
            return transform_func(array, *args)
        elif isinstance(array, pa.ChunkedArray):
            # Transform each chunk individually and reassemble so the
            # chunked layout of the input is preserved.
            result_chunks = []
            for arr in array.iterchunks():
                result_chunks.append(transform_func(arr, *args))
            return pa.chunked_array(result_chunks)
        else:
            raise ValueError(f"PyArrow array can only be of type pa.Array or pa.ChunkedArray, but found {type(array)}")

    return _transform


class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
Expand Down Expand Up @@ -375,7 +384,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the bucket transform to a PyArrow array.

    The scraped diff conflated the pre-change (``self._pyiceberg_transform_wrapper``)
    and post-change return statements, leaving unreachable duplicate code; only the
    post-change call to the module-level wrapper is kept.
    """
    from pyiceberg_core import transform as pyiceberg_core_transform

    # Delegates the bucket hashing to the Rust implementation in pyiceberg-core.
    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.bucket, self._num_buckets)

@property
def supports_pyarrow_transform(self) -> bool:
Expand Down Expand Up @@ -501,22 +510,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the year transform to a PyArrow array.

    The scraped diff conflated the deleted epoch-based ``pc.years_between``
    implementation with the added pyiceberg-core delegation; only the
    post-change implementation is kept. ``source`` is retained for interface
    compatibility even though the Rust transform no longer needs it.
    """
    import pyarrow as pa

    from pyiceberg_core import transform as pyiceberg_core_transform

    # Results are cast to int32, the Iceberg partition value type for year.
    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.year, expected_type=pa.int32())


class MonthTransform(TimeTransform[S]):
Expand Down Expand Up @@ -575,28 +571,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the month transform to a PyArrow array.

    The scraped diff conflated the deleted epoch-arithmetic implementation
    (``pc.years_between``/``pc.month``) with the added pyiceberg-core
    delegation; only the post-change implementation is kept. ``source`` is
    retained for interface compatibility.
    """
    import pyarrow as pa

    from pyiceberg_core import transform as pyiceberg_core_transform

    # Results are cast to int32, the Iceberg partition value type for month.
    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.month, expected_type=pa.int32())


class DayTransform(TimeTransform[S]):
Expand Down Expand Up @@ -663,22 +640,9 @@ def __repr__(self) -> str:

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the day transform to a PyArrow array.

    The scraped diff conflated the deleted ``pc.days_between`` implementation
    with the added pyiceberg-core delegation; only the post-change
    implementation is kept. ``source`` is retained for interface compatibility.
    """
    import pyarrow as pa

    from pyiceberg_core import transform as pyiceberg_core_transform

    # Results are cast to int32, the Iceberg partition value type for day.
    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.day, expected_type=pa.int32())


class HourTransform(TimeTransform[S]):
Expand Down Expand Up @@ -728,21 +692,9 @@ def __repr__(self) -> str:
return "HourTransform()"

def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the hour transform to a PyArrow array.

    The scraped diff conflated the deleted ``pc.hours_between`` implementation
    with the added pyiceberg-core delegation; only the post-change
    implementation is kept. No ``expected_type`` cast is requested here,
    unlike the year/month/day transforms. ``source`` is retained for
    interface compatibility.
    """
    from pyiceberg_core import transform as pyiceberg_core_transform

    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.hour)


def _base64encode(buffer: bytes) -> str:
Expand Down Expand Up @@ -965,7 +917,7 @@ def __repr__(self) -> str:
def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
    """Return a callable applying the truncate transform to a PyArrow array.

    The scraped diff conflated the pre-change (``self._pyiceberg_transform_wrapper``)
    and post-change return statements, leaving unreachable duplicate code; only the
    post-change call to the module-level wrapper is kept.
    """
    from pyiceberg_core import transform as pyiceberg_core_transform

    # Delegates truncation to the Rust implementation in pyiceberg-core.
    return _pyiceberg_transform_wrapper(pyiceberg_core_transform.truncate, self._width)

@property
def supports_pyarrow_transform(self) -> bool:
Expand Down
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ psycopg2-binary = { version = ">=2.9.6", optional = true }
sqlalchemy = { version = "^2.0.18", optional = true }
getdaft = { version = ">=0.2.12", optional = true }
cachetools = "^5.5.0"
pyiceberg-core = { version = "^0.4.0", optional = true }
pyiceberg-core = { version = "0.4.0.dev20250326000154", source="testpypi", optional = true }
polars = { version = "^1.21.0", optional = true }
thrift-sasl = { version = ">=0.4.3", optional = true }

Expand Down Expand Up @@ -115,6 +115,14 @@ mkdocs-material = "9.6.9"
mkdocs-material-extensions = "1.3.1"
mkdocs-section-index = "0.3.9"

[[tool.poetry.source]]
name = "pypi"
priority = "primary"

[[tool.poetry.source]]
name = "testpypi"
url = "https://test.pypi.org/simple/"

[[tool.mypy.overrides]]
module = "pytest_mock.*"
ignore_missing_imports = true
Expand Down
4 changes: 2 additions & 2 deletions tests/table/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,8 @@ def test_partition_type(table_schema_simple: Schema) -> None:
(DecimalType(5, 9), Decimal(19.25)),
(DateType(), datetime.date(1925, 5, 22)),
(TimeType(), datetime.time(19, 25, 00)),
(TimestampType(), datetime.datetime(19, 5, 1, 22, 1, 1)),
(TimestamptzType(), datetime.datetime(19, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)),
(TimestampType(), datetime.datetime(2022, 5, 1, 22, 1, 1)),
(TimestamptzType(), datetime.datetime(2022, 5, 1, 22, 1, 1, tzinfo=datetime.timezone.utc)),
(StringType(), "abc"),
(UUIDType(), UUID("12345678-1234-5678-1234-567812345678").bytes),
(FixedType(5), 'b"\x8e\xd1\x87\x01"'),
Expand Down