@@ -1806,8 +1806,8 @@ def eval(self, a, b):
18061806
18071807 for i , df in enumerate (
18081808 [
1809- self .spark .sql ("SELECT * FROM test_udtf(a=> 10, b=> 'x')" ),
1810- self .spark .sql ("SELECT * FROM test_udtf(b=> 'x', a=> 10)" ),
1809+ self .spark .sql ("SELECT * FROM test_udtf(a => 10, b => 'x')" ),
1810+ self .spark .sql ("SELECT * FROM test_udtf(b => 'x', a => 10)" ),
18111811 TestUDTF (a = lit (10 ), b = lit ("x" )),
18121812 TestUDTF (b = lit ("x" ), a = lit (10 )),
18131813 ]
@@ -1827,15 +1827,15 @@ def eval(self, a, b):
18271827 AnalysisException ,
18281828 "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE" ,
18291829 ):
1830- self .spark .sql ("SELECT * FROM test_udtf(a=> 10, a=> 100)" ).show ()
1830+ self .spark .sql ("SELECT * FROM test_udtf(a => 10, a => 100)" ).show ()
18311831
18321832 with self .assertRaisesRegex (AnalysisException , "UNEXPECTED_POSITIONAL_ARGUMENT" ):
1833- self .spark .sql ("SELECT * FROM test_udtf(a=> 10, 'x')" ).show ()
1833+ self .spark .sql ("SELECT * FROM test_udtf(a => 10, 'x')" ).show ()
18341834
18351835 with self .assertRaisesRegex (
18361836 PythonException , r"eval\(\) got an unexpected keyword argument 'c'"
18371837 ):
1838- self .spark .sql ("SELECT * FROM test_udtf(c=> 'x')" ).show ()
1838+ self .spark .sql ("SELECT * FROM test_udtf(c => 'x')" ).show ()
18391839
18401840 def test_udtf_with_kwargs (self ):
18411841 @udtf (returnType = "a: int, b: string" )
@@ -1847,8 +1847,8 @@ def eval(self, **kwargs):
18471847
18481848 for i , df in enumerate (
18491849 [
1850- self .spark .sql ("SELECT * FROM test_udtf(a=> 10, b=> 'x')" ),
1851- self .spark .sql ("SELECT * FROM test_udtf(b=> 'x', a=> 10)" ),
1850+ self .spark .sql ("SELECT * FROM test_udtf(a => 10, b => 'x')" ),
1851+ self .spark .sql ("SELECT * FROM test_udtf(b => 'x', a => 10)" ),
18521852 TestUDTF (a = lit (10 ), b = lit ("x" )),
18531853 TestUDTF (b = lit ("x" ), a = lit (10 )),
18541854 ]
@@ -1874,15 +1874,85 @@ def eval(self, **kwargs):
18741874
18751875 for i , df in enumerate (
18761876 [
1877- self .spark .sql ("SELECT * FROM test_udtf(a=> 10, b=> 'x')" ),
1878- self .spark .sql ("SELECT * FROM test_udtf(b=> 'x', a=> 10)" ),
1877+ self .spark .sql ("SELECT * FROM test_udtf(a => 10, b => 'x')" ),
1878+ self .spark .sql ("SELECT * FROM test_udtf(b => 'x', a => 10)" ),
18791879 TestUDTF (a = lit (10 ), b = lit ("x" )),
18801880 TestUDTF (b = lit ("x" ), a = lit (10 )),
18811881 ]
18821882 ):
18831883 with self .subTest (query_no = i ):
18841884 assertDataFrameEqual (df , [Row (a = 10 , b = "x" )])
18851885
def test_udtf_with_named_arguments_lateral_join(self):
    # Named arguments bound to columns of the outer relation via a LATERAL
    # join must give the same rows whichever order the arguments are written in.
    @udtf
    class TestUDTF:
        @staticmethod
        def analyze(a, b):
            # The output schema is a single column "a" typed after argument "a".
            return AnalyzeResult(StructType().add("a", a.data_type))

        def eval(self, a, b):
            yield a,

    self.spark.udtf.register("test_udtf", TestUDTF)

    # lateral join
    outer_relation = "SELECT f.* FROM VALUES (0, 'x'), (1, 'y') t(a, b), LATERAL "
    udtf_calls = (
        "test_udtf(a => a, b => b) f",
        "test_udtf(b => b, a => a) f",
    )
    # Build every DataFrame up front, before any sub-test starts.
    lateral_dfs = [self.spark.sql(outer_relation + call) for call in udtf_calls]

    for query_no, df in enumerate(lateral_dfs):
        with self.subTest(query_no=query_no):
            assertDataFrameEqual(df, [Row(a=0), Row(a=1)])
1913+
def test_udtf_with_named_arguments_and_defaults(self):
    # Exercises a UDTF whose second argument is optional, called positionally,
    # by name, and with a mix of both, from SQL and from the Python API.
    @udtf
    class TestUDTF:
        @staticmethod
        def analyze(a, b=None):
            schema = StructType().add("a", a.data_type)
            # When "b" is omitted, eval() falls back to its integer default,
            # so the "b" column is declared as an integer.
            b_type = IntegerType() if b is None else b.data_type
            return AnalyzeResult(schema.add("b", b_type))

        def eval(self, a, b=100):
            yield a, b

    self.spark.udtf.register("test_udtf", TestUDTF)

    # without "b"
    dfs_without_b = [
        self.spark.sql("SELECT * FROM test_udtf(10)"),
        self.spark.sql("SELECT * FROM test_udtf(a => 10)"),
        TestUDTF(lit(10)),
        TestUDTF(a=lit(10)),
    ]
    for query_no, df in enumerate(dfs_without_b):
        with self.subTest(query_no=query_no):
            assertDataFrameEqual(df, [Row(a=10, b=100)])

    # with "b"
    dfs_with_b = [
        self.spark.sql("SELECT * FROM test_udtf(10, b => 'z')"),
        self.spark.sql("SELECT * FROM test_udtf(a => 10, b => 'z')"),
        self.spark.sql("SELECT * FROM test_udtf(b => 'z', a => 10)"),
        TestUDTF(lit(10), b=lit("z")),
        TestUDTF(a=lit(10), b=lit("z")),
        TestUDTF(b=lit("z"), a=lit(10)),
    ]
    for query_no, df in enumerate(dfs_with_b):
        with self.subTest(query_no=query_no):
            assertDataFrameEqual(df, [Row(a=10, b="z")])
1955+
18861956
18871957class UDTFTests (BaseUDTFTestsMixin , ReusedSQLTestCase ):
18881958 @classmethod
0 commit comments