Skip to content

Commit f7002fb

Browse files
committed
[SPARK-44749][PYTHON][FOLLOWUP][TESTS] Add more tests for named arguments in Python UDTF
### What changes were proposed in this pull request?

This is a follow-up of apache#42422. Adds more tests for named arguments in Python UDTF.

### Why are the changes needed?

There are more cases to test.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added related tests.

Closes apache#42490 from ueshin/issues/SPARK-44749/tests.

Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
1 parent 1d56290 commit f7002fb

File tree

2 files changed

+80
-10
lines changed

2 files changed

+80
-10
lines changed

python/pyspark/sql/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15654,7 +15654,7 @@ def udtf(
1565415654
+---+---+---+
1565515655
1565615656
>>> _ = spark.udtf.register("test_udtf", TestUDTFWithKwargs)
15657-
>>> spark.sql("SELECT * FROM test_udtf(1, x=>'x', b=>'b')").show()
15657+
>>> spark.sql("SELECT * FROM test_udtf(1, x => 'x', b => 'b')").show()
1565815658
+---+---+---+
1565915659
| a| b| x|
1566015660
+---+---+---+

python/pyspark/sql/tests/test_udtf.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,8 +1806,8 @@ def eval(self, a, b):
18061806

18071807
for i, df in enumerate(
18081808
[
1809-
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
1810-
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
1809+
self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
1810+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
18111811
TestUDTF(a=lit(10), b=lit("x")),
18121812
TestUDTF(b=lit("x"), a=lit(10)),
18131813
]
@@ -1827,15 +1827,15 @@ def eval(self, a, b):
18271827
AnalysisException,
18281828
"DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
18291829
):
1830-
self.spark.sql("SELECT * FROM test_udtf(a=>10, a=>100)").show()
1830+
self.spark.sql("SELECT * FROM test_udtf(a => 10, a => 100)").show()
18311831

18321832
with self.assertRaisesRegex(AnalysisException, "UNEXPECTED_POSITIONAL_ARGUMENT"):
1833-
self.spark.sql("SELECT * FROM test_udtf(a=>10, 'x')").show()
1833+
self.spark.sql("SELECT * FROM test_udtf(a => 10, 'x')").show()
18341834

18351835
with self.assertRaisesRegex(
18361836
PythonException, r"eval\(\) got an unexpected keyword argument 'c'"
18371837
):
1838-
self.spark.sql("SELECT * FROM test_udtf(c=>'x')").show()
1838+
self.spark.sql("SELECT * FROM test_udtf(c => 'x')").show()
18391839

18401840
def test_udtf_with_kwargs(self):
18411841
@udtf(returnType="a: int, b: string")
@@ -1847,8 +1847,8 @@ def eval(self, **kwargs):
18471847

18481848
for i, df in enumerate(
18491849
[
1850-
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
1851-
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
1850+
self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
1851+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
18521852
TestUDTF(a=lit(10), b=lit("x")),
18531853
TestUDTF(b=lit("x"), a=lit(10)),
18541854
]
@@ -1874,15 +1874,85 @@ def eval(self, **kwargs):
18741874

18751875
for i, df in enumerate(
18761876
[
1877-
self.spark.sql("SELECT * FROM test_udtf(a=>10, b=>'x')"),
1878-
self.spark.sql("SELECT * FROM test_udtf(b=>'x', a=>10)"),
1877+
self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'x')"),
1878+
self.spark.sql("SELECT * FROM test_udtf(b => 'x', a => 10)"),
18791879
TestUDTF(a=lit(10), b=lit("x")),
18801880
TestUDTF(b=lit("x"), a=lit(10)),
18811881
]
18821882
):
18831883
with self.subTest(query_no=i):
18841884
assertDataFrameEqual(df, [Row(a=10, b="x")])
18851885

1886+
def test_udtf_with_named_arguments_lateral_join(self):
1887+
@udtf
1888+
class TestUDTF:
1889+
@staticmethod
1890+
def analyze(a, b):
1891+
return AnalyzeResult(StructType().add("a", a.data_type))
1892+
1893+
def eval(self, a, b):
1894+
yield a,
1895+
1896+
self.spark.udtf.register("test_udtf", TestUDTF)
1897+
1898+
# lateral join
1899+
for i, df in enumerate(
1900+
[
1901+
self.spark.sql(
1902+
"SELECT f.* FROM "
1903+
"VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(a => a, b => b) f"
1904+
),
1905+
self.spark.sql(
1906+
"SELECT f.* FROM "
1907+
"VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL test_udtf(b => b, a => a) f"
1908+
),
1909+
]
1910+
):
1911+
with self.subTest(query_no=i):
1912+
assertDataFrameEqual(df, [Row(a=0), Row(a=1)])
1913+
1914+
def test_udtf_with_named_arguments_and_defaults(self):
1915+
@udtf
1916+
class TestUDTF:
1917+
@staticmethod
1918+
def analyze(a, b=None):
1919+
schema = StructType().add("a", a.data_type)
1920+
if b is None:
1921+
return AnalyzeResult(schema.add("b", IntegerType()))
1922+
else:
1923+
return AnalyzeResult(schema.add("b", b.data_type))
1924+
1925+
def eval(self, a, b=100):
1926+
yield a, b
1927+
1928+
self.spark.udtf.register("test_udtf", TestUDTF)
1929+
1930+
# without "b"
1931+
for i, df in enumerate(
1932+
[
1933+
self.spark.sql("SELECT * FROM test_udtf(10)"),
1934+
self.spark.sql("SELECT * FROM test_udtf(a => 10)"),
1935+
TestUDTF(lit(10)),
1936+
TestUDTF(a=lit(10)),
1937+
]
1938+
):
1939+
with self.subTest(query_no=i):
1940+
assertDataFrameEqual(df, [Row(a=10, b=100)])
1941+
1942+
# with "b"
1943+
for i, df in enumerate(
1944+
[
1945+
self.spark.sql("SELECT * FROM test_udtf(10, b => 'z')"),
1946+
self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'z')"),
1947+
self.spark.sql("SELECT * FROM test_udtf(b => 'z', a => 10)"),
1948+
TestUDTF(lit(10), b=lit("z")),
1949+
TestUDTF(a=lit(10), b=lit("z")),
1950+
TestUDTF(b=lit("z"), a=lit(10)),
1951+
]
1952+
):
1953+
with self.subTest(query_no=i):
1954+
assertDataFrameEqual(df, [Row(a=10, b="z")])
1955+
18861956

18871957
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
18881958
@classmethod

0 commit comments

Comments (0)