Skip to content

Commit 8804ada

Browse files
feat: Add str.join method (#2054)
1 parent 5229e07 commit 8804ada

File tree

8 files changed

+144
-7
lines changed

8 files changed

+144
-7
lines changed

bigframes/core/compile/ibis_compiler/aggregate_compiler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,29 @@ def _(
676676
).to_expr()
677677

678678

679+
@compile_ordered_unary_agg.register
def _(
    op: agg_ops.StringAggOp,
    column: ibis_types.Column,
    window=None,
    order_by: typing.Sequence[ibis_types.Value] = (),
) -> ibis_types.StringValue:
    """Compile a ``StringAggOp`` into an ibis ``StringAgg`` reduction.

    Args:
        op: The string aggregation op; only ``op.sep`` is read here.
        column: The column whose values are concatenated.
        window: Must be ``None``; windowed string aggregation is rejected.
        order_by: Ordering expressions forwarded to ``StringAgg``. An
            immutable empty tuple is the default so no mutable object is
            shared between calls.

    Returns:
        A string expression in which a NULL aggregate result is replaced by
        the empty string.

    Raises:
        NotImplementedError: If ``window`` is not ``None``.
    """
    if window is not None:
        raise NotImplementedError(
            f"StringAgg with windowing is not supported. {constants.FEEDBACK_LINK}"
        )

    return (
        ibis_ops.StringAgg(
            column,  # type: ignore
            sep=op.sep,  # type: ignore
            order_by=order_by,  # type: ignore
        )
        .to_expr()
        # Substitute an empty string whenever the aggregate itself is NULL.
        .fill_null(ibis_types.literal(""))
    )
700+
701+
679702
@compile_binary_agg.register
680703
def _(
681704
op: agg_ops.CorrOp, left: ibis_types.Column, right: ibis_types.Column, window=None

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1216,11 +1216,18 @@ def to_arry_op_impl(*values: ibis_types.Value):
12161216
def array_reduce_op_impl(x: ibis_types.Value, op: ops.ArrayReduceOp):
    # Reduce the array value ``x`` using the aggregation carried by ``op``.
    import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compilers

    # Aggregations that are not order independent are compiled through the
    # ordered entry point; all others use the plain unary one.
    if op.aggregation.order_independent:
        compile_agg = agg_compilers.compile_unary_agg
    else:
        compile_agg = agg_compilers.compile_ordered_unary_agg

    array_value = typing.cast(ibis_types.ArrayValue, x)
    return array_value.reduce(
        lambda elements: compile_agg(
            op.aggregation, typing.cast(ibis_types.Column, elements)
        )
    )
12241231

12251232

12261233
# JSON Ops

bigframes/operations/aggregations.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,26 @@ def skips_nulls(self):
379379
return True
380380

381381
def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
    # The result is a list type built from the first input type.
    element_type = input_types[0]
    return dtypes.list_type(element_type)
383+
384+
385+
@dataclasses.dataclass(frozen=True)
class StringAggOp(UnaryAggregateOp):
    """Unary aggregation over string inputs, parameterized by a separator."""

    name: ClassVar[str] = "string_agg"
    # Separator handed to the compiled string aggregation.
    sep: str = ","

    @property
    def order_independent(self):
        # Reported as order dependent, so callers compile it via the ordered path.
        return False

    @property
    def skips_nulls(self):
        return True

    def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionType:
        """Return the string dtype, rejecting any non-string input type.

        Raises:
            TypeError: If the first input type is not ``dtypes.STRING_DTYPE``.
        """
        input_type = input_types[0]
        if input_type != dtypes.STRING_DTYPE:
            raise TypeError(f"Type {input_type} is not string-like")
        return dtypes.STRING_DTYPE
385402

386403

387404
@dataclasses.dataclass(frozen=True)

bigframes/operations/strings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import bigframes.dataframe as df
2525
import bigframes.operations as ops
2626
from bigframes.operations._op_converters import convert_index, convert_slice
27+
import bigframes.operations.aggregations as agg_ops
2728
import bigframes.operations.base
2829
import bigframes.series as series
2930

@@ -295,6 +296,11 @@ def cat(
295296
) -> series.Series:
296297
return self._apply_binary_op(others, ops.strconcat_op, alignment=join)
297298

299+
def join(self, sep: str) -> series.Series:
    # Joining is expressed as an array reduction that string-aggregates the
    # elements of each row using ``sep``.
    string_agg = agg_ops.StringAggOp(sep=sep)
    reduce_op = ops.ArrayReduceOp(aggregation=string_agg)
    return self._apply_unary_op(reduce_op)
303+
298304
def to_blob(self, connection: Optional[str] = None) -> series.Series:
299305
"""Create a BigFrames Blob series from a series of URIs.
300306

tests/system/small/operations/test_strings.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,14 @@ def test_getitem_w_struct_array():
736736
expected = bpd.Series(expected_data, dtype=bpd.ArrowDtype((pa_struct)))
737737

738738
assert_series_equal(result.to_pandas(), expected.to_pandas())
739+
740+
741+
def test_string_join(session):
    """``Series.str.join`` should match pandas, including for an empty list."""
    separator = "--"
    pd_series = pd.Series([["a", "b", "c"], ["100"], ["hello", "world"], []])
    bf_series = session.read_pandas(pd_series)

    # Cast the pandas result to a pyarrow-backed string dtype before comparing.
    expected = pd_series.str.join(separator).astype("string[pyarrow]")
    actual = bf_series.str.join(separator).to_pandas()

    assert_series_equal(expected, actual, check_dtype=False, check_index_type=False)

third_party/bigframes_vendored/ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,22 @@ def visit_ArrayAggregate(self, op, *, arg, order_by, where):
10881088
expr = arg
10891089
return sge.IgnoreNulls(this=self.agg.array_agg(expr, where=where))
10901090

1091+
def visit_StringAgg(self, op, *, arg, sep, order_by, where):
    """Compile a ``StringAgg`` reduction into a string aggregation call.

    Args:
        op: The operation node being compiled (not read here).
        arg: Compiled expression for the values to concatenate.
        sep: Compiled expression for the separator.
        order_by: Compiled ordering expressions. When non-empty, the
            argument is wrapped in an ORDER BY built from them.
        where: Filter forwarded unchanged to ``self.agg.string_agg``.

    Returns:
        The expression produced by ``self.agg.string_agg``.
    """
    if order_by:
        expr = sge.Order(
            this=arg,
            expressions=[
                # Avoid adding NULLS FIRST / NULLS LAST in SQL, which is
                # unsupported in the ORDER BY of STRING_AGG, by
                # reconstructing the node as plain SQL text.
                f"({order_column.args['this'].sql(dialect='bigquery')}) {'DESC' if order_column.args.get('desc') else 'ASC'}"
                for order_column in order_by
            ],
        )
    else:
        expr = arg
    return self.agg.string_agg(expr, sep, where=where)
1106+
10911107
def visit_FirstNonNullValue(self, op, *, arg):
    # Wrap FIRST_VALUE in IGNORE NULLS so the first non-null value is picked.
    first_value = sge.FirstValue(this=arg)
    return sge.IgnoreNulls(this=first_value)
10931109

third_party/bigframes_vendored/ibis/expr/operations/reductions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,3 +401,20 @@ class ArrayAggregate(Filterable, Reduction):
401401
@attribute
def dtype(self):
    # The aggregate's type is an array whose element type is the argument's type.
    element_type = self.arg.dtype
    return dt.Array(element_type)
404+
405+
406+
@public
class StringAgg(Filterable, Reduction):
    """
    Collects the elements of this expression into a single string. Similar to
    the ibis `GroupConcat`, but adds an `order_by` parameter.
    """

    # Column whose values are collected into the string.
    arg: Column
    # Separator expression handed to the backend's string aggregation.
    sep: Value[dt.String]

    # Ordering expressions; the empty default means no explicit ordering.
    order_by: VarTuple[Value] = ()

    @attribute
    def dtype(self):
        # The result type is always string, whatever the argument's type.
        return dt.string

third_party/bigframes_vendored/pandas/core/strings/accessor.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,3 +1298,43 @@ def center(
12981298
bigframes.series.Series: Returns Series or Index with minimum number of char in object.
12991299
"""
13001300
raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)
1301+
1302+
def join(self, sep: str):
    """
    Join lists contained as elements in the Series/Index with passed delimiter.

    If the elements of a Series are lists themselves, join the content of these
    lists using the delimiter passed to the function.
    This function is equivalent to :meth:`str.join`.

    **Examples:**

    >>> import bigframes.pandas as bpd
    >>> bpd.options.display.progress_bar = None
    >>> import pandas as pd

    Example with a Series whose elements are lists of strings.

    >>> s = bpd.Series([['lion', 'elephant', 'zebra'],
    ... ['dragon'],
    ... ['duck', 'swan', 'fish', 'guppy']])
    >>> s
    0 ['lion' 'elephant' 'zebra']
    1 ['dragon']
    2 ['duck' 'swan' 'fish' 'guppy']
    dtype: list<item: string>[pyarrow]

    >>> s.str.join('-')
    0 lion-elephant-zebra
    1 dragon
    2 duck-swan-fish-guppy
    dtype: string

    Args:
        sep (str):
            Delimiter to use between list entries.

    Returns:
        bigframes.series.Series: The list entries concatenated by intervening occurrences of the delimiter.
    """
    raise NotImplementedError(constants.ABSTRACT_METHOD_ERROR_MESSAGE)

0 commit comments

Comments
 (0)