fix: increase recursion limit, cache compilation tree hashes #184

Merged: 8 commits, Dec 14, 2023
75 changes: 72 additions & 3 deletions bigframes/core/nodes.py
@@ -14,7 +14,7 @@

from __future__ import annotations

from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
import functools
import typing
from typing import Optional, Tuple
@@ -66,6 +66,13 @@ def session(self):
return sessions[0]
return None

# BigFrameNode trees can be very deep, so it's important to avoid recalculating the hash from scratch
# Each subclass of BigFrameNode should use this property to implement __hash__
# The default dataclass-generated __hash__ method is not cached
@functools.cached_property
def _node_hash(self):
return hash(tuple(hash(getattr(self, field.name)) for field in fields(self)))

Collaborator:
Let's document this (needed for cached_property and to avoid an infinite loop).

Contributor Author:
Added a comment explaining the need for this hash impl.

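For readers of this diff, a minimal standalone sketch of the caching pattern above (the ToyNode class is invented for illustration and is not bigframes code): functools.cached_property stores the computed hash in the instance's __dict__, which works even on frozen dataclasses because frozen only intercepts __setattr__, so a deep tree is walked once in total rather than once per hash() call.

from dataclasses import dataclass, fields
import functools
from typing import Optional


@dataclass(frozen=True)
class ToyNode:
    # stand-in for a BigFrameNode-like tree node
    child: Optional["ToyNode"] = None
    label: str = "x"

    @functools.cached_property
    def _node_hash(self) -> int:
        # cached_property writes straight into the instance __dict__, which is
        # allowed on a frozen dataclass (frozen only blocks __setattr__)
        return hash(tuple(hash(getattr(self, f.name)) for f in fields(self)))

    def __hash__(self) -> int:
        return self._node_hash


# Build a deep chain; the first hash() walks it once, later calls are O(1).
node = None
for i in range(100):
    node = ToyNode(node, str(i))
assert hash(node) == hash(node)          # second call hits the cache
assert "_node_hash" in node.__dict__     # cached value stored on the instance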

@dataclass(frozen=True)
class UnaryNode(BigFrameNode):
@@ -95,6 +102,9 @@ class JoinNode(BigFrameNode):
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
return (self.left_child, self.right_child)

def __hash__(self):
return self._node_hash

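The repeated __hash__ overrides in the node classes below look redundant but are load-bearing: @dataclass(frozen=True) with the default eq=True generates a fresh, uncached __hash__ for every decorated class unless that class defines one in its own body, silently shadowing the cached implementation inherited from BigFrameNode. A standalone sketch with toy classes (not bigframes code):

from dataclasses import dataclass


@dataclass(frozen=True)
class Base:
    def __hash__(self):
        return 42  # stand-in for the cached _node_hash


@dataclass(frozen=True)
class Child(Base):
    x: int = 0


print(hash(Base()))    # 42: an explicit __hash__ in the class body is kept
print(hash(Child(1)))  # not 42: @dataclass generated a fresh Child.__hash__

This is why every frozen node subclass in this PR re-declares __hash__ to route back to _node_hash.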

@dataclass(frozen=True)
class ConcatNode(BigFrameNode):
@@ -104,13 +114,19 @@ class ConcatNode(BigFrameNode):
def child_nodes(self) -> typing.Sequence[BigFrameNode]:
return self.children

def __hash__(self):
return self._node_hash


# Input Nodes
@dataclass(frozen=True)
class ReadLocalNode(BigFrameNode):
feather_bytes: bytes
column_ids: typing.Tuple[str, ...]

def __hash__(self):
return self._node_hash


# TODO: Refactor to take raw gbq object reference
@dataclass(frozen=True)
@@ -125,45 +141,70 @@ class ReadGbqNode(BigFrameNode):
def session(self):
return (self.table_session,)

def __hash__(self):
return self._node_hash


# Unary nodes
@dataclass(frozen=True)
class DropColumnsNode(UnaryNode):
columns: Tuple[str, ...]

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class PromoteOffsetsNode(UnaryNode):
col_id: str

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class FilterNode(UnaryNode):
predicate_id: str
keep_null: bool = False

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class OrderByNode(UnaryNode):
by: Tuple[OrderingColumnReference, ...]

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class ReversedNode(UnaryNode):
pass
# otherwise-unused field to make sure this node has a distinct hash
reversed: bool = True

def __hash__(self):
return self._node_hash

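Why ReversedNode grows a dummy field: _node_hash mixes only field values, never the class itself, so two field-less subclasses wrapping the same child would hash identically. A standalone sketch (toy classes, caching elided, not bigframes code):

from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Node:
    child: object

    def _node_hash(self) -> int:
        # same recipe as BigFrameNode._node_hash, caching elided
        return hash(tuple(hash(getattr(self, f.name)) for f in fields(self)))


@dataclass(frozen=True)
class Reversed(Node):
    pass  # no extra fields


@dataclass(frozen=True)
class Reproject(Node):
    pass  # no extra fields


# Both see the same field tuple, so their hashes collide:
print(Reversed("c")._node_hash() == Reproject("c")._node_hash())  # True

The reversed: bool = True field lengthens ReversedNode's field tuple, which breaks this collision.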

@dataclass(frozen=True)
class SelectNode(UnaryNode):
column_ids: typing.Tuple[str, ...]

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class ProjectUnaryOpNode(UnaryNode):
input_id: str
op: ops.UnaryOp
output_id: Optional[str] = None

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class ProjectBinaryOpNode(UnaryNode):
@@ -172,6 +213,9 @@ class ProjectBinaryOpNode(UnaryNode):
op: ops.BinaryOp
output_id: str

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class ProjectTernaryOpNode(UnaryNode):
@@ -181,19 +225,28 @@ class ProjectTernaryOpNode(UnaryNode):
op: ops.TernaryOp
output_id: str

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class AggregateNode(UnaryNode):
aggregations: typing.Tuple[typing.Tuple[str, agg_ops.AggregateOp, str], ...]
by_column_ids: typing.Tuple[str, ...] = tuple([])
dropna: bool = True

def __hash__(self):
return self._node_hash


# TODO: Unify into aggregate
@dataclass(frozen=True)
class CorrNode(UnaryNode):
corr_aggregations: typing.Tuple[typing.Tuple[str, str, str], ...]

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class WindowOpNode(UnaryNode):
@@ -204,10 +257,14 @@ class WindowOpNode(UnaryNode):
never_skip_nulls: bool = False
skip_reproject_unsafe: bool = False

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class ReprojectOpNode(UnaryNode):
pass
def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
@@ -223,19 +280,28 @@ class UnpivotNode(UnaryNode):
] = (pandas.Float64Dtype(),)
how: typing.Literal["left", "right"] = "left"

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class AssignNode(UnaryNode):
source_id: str
destination_id: str

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class AssignConstantNode(UnaryNode):
destination_id: str
value: typing.Hashable
dtype: typing.Optional[bigframes.dtypes.Dtype]

def __hash__(self):
return self._node_hash


@dataclass(frozen=True)
class RandomSampleNode(UnaryNode):
@@ -244,3 +310,6 @@ class RandomSampleNode(UnaryNode):
@property
def deterministic(self) -> bool:
return False

def __hash__(self):
return self._node_hash
4 changes: 4 additions & 0 deletions bigframes/pandas/__init__.py
@@ -18,6 +18,7 @@

from collections import namedtuple
import inspect
import sys
import typing
from typing import (
Any,
@@ -657,6 +658,9 @@ def read_gbq_function(function_name: str):
close_session = global_session.close_session
reset_session = global_session.close_session

# SQL compilation uses recursive algorithms on deep trees
# 10M tree depth should be sufficient to generate any SQL that is under the BigQuery limit
sys.setrecursionlimit(max(10000000, sys.getrecursionlimit()))

# Use __all__ to let type checkers know what is part of the public API.
__all__ = [
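A note on the max(...) idiom in the recursion-limit bump above: it only ever raises the limit and never clobbers a larger value the user may have configured. A sketch of the failure mode it guards against (toy recursion standing in for the compiler, not bigframes code):

import sys


def compile_depth(n: int) -> int:
    # stand-in for a recursive tree walk such as SQL compilation
    return 0 if n == 0 else 1 + compile_depth(n - 1)


sys.setrecursionlimit(1000)  # CPython's usual default
try:
    compile_depth(5000)
except RecursionError:
    print("default limit too small for a deep expression tree")

# Raise, but never lower, the limit. A very high limit trusts the process
# stack to be deep enough for whatever recursion actually occurs.
sys.setrecursionlimit(max(10000000, sys.getrecursionlimit()))
print(compile_depth(5000))  # 5000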
7 changes: 7 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
@@ -3667,6 +3667,13 @@ def test_df_dot_operator_series(
)


def test_recursion_limit(scalars_df_index):
    scalars_df_index = scalars_df_index[["int64_too", "int64_col", "float64_col"]]
    for i in range(400):
        scalars_df_index = scalars_df_index + 4
    scalars_df_index.to_pandas()

Collaborator:
Might be able to make this a unit test, but as discussed, it is good to exercise ibis here too.

def test_to_pandas_downsampling_option_override(session):
df = session.read_gbq("bigframes-dev.bigframes_tests_sys.batting")
download_size = 1