Skip to content

Commit ac25618

Browse files
authored
feat: Support callable for series map method (#2100)
1 parent c56a78c commit ac25618

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

bigframes/series.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import typing
2626
from typing import (
2727
Any,
28+
Callable,
2829
cast,
2930
Iterable,
3031
List,
@@ -2339,7 +2340,7 @@ def _throw_if_index_contains_duplicates(
23392340

23402341
def map(
23412342
self,
2342-
arg: typing.Union[Mapping, Series],
2343+
arg: typing.Union[Mapping, Series, Callable],
23432344
na_action: Optional[str] = None,
23442345
*,
23452346
verify_integrity: bool = False,
@@ -2361,6 +2362,7 @@ def map(
23612362
)
23622363
map_df = map_df.set_index("keys")
23632364
elif callable(arg):
2365+
# This is for remote function and managed funtion.
23642366
return self.apply(arg)
23652367
else:
23662368
# Mirroring pandas, call the uncallable object

tests/system/large/functions/test_managed_function.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1245,7 +1245,7 @@ def the_sum(s):
12451245
cleanup_function_assets(the_sum_mf, session.bqclient, ignore_failures=False)
12461246

12471247

1248-
def test_managed_function_series_where_mask(session, dataset_id, scalars_dfs):
1248+
def test_managed_function_series_where_mask_map(session, dataset_id, scalars_dfs):
12491249
try:
12501250

12511251
# The return type has to be bool type for callable where condition.
@@ -1286,6 +1286,13 @@ def _is_positive(s):
12861286
# Ignore any dtype difference.
12871287
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
12881288

1289+
# Test series.map method.
1290+
bf_result = bf_int64_filtered.map(is_positive_mf).to_pandas()
1291+
pd_result = pd_int64_filtered.map(_is_positive)
1292+
1293+
# Ignore any dtype difference.
1294+
pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)
1295+
12891296
finally:
12901297
# Clean up the gcp assets created for the managed function.
12911298
cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)

0 commit comments

Comments
 (0)