Skip to content

Commit 7459a14

Browse files
elvaliuliuliuimback82
authored andcommitted
Support UDF that returns RowType (dotnet#376)
1 parent 8ad38fb commit 7459a14

File tree

5 files changed

+318
-23
lines changed

5 files changed

+318
-23
lines changed

src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfComplexTypesTests.cs

Lines changed: 70 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public void TestUdfWithArrayType()
4848

4949
Row[] rows = _df.Select(workingUdf(_df["ids"])).Collect().ToArray();
5050
Assert.Equal(3, rows.Length);
51-
51+
5252
var expected = new[] { "1", "3,5", "2,4" };
5353
string[] rowsToArray = rows.Select(x => x[0].ToString()).ToArray();
5454
Assert.Equal(expected, rowsToArray);
@@ -158,25 +158,76 @@ public void TestUdfWithRowType()
158158
[Fact]
159159
public void TestUdfWithReturnAsRowType()
160160
{
161-
// UDF with return as RowType throws a following exception:
162-
// Unhandled Exception: System.Reflection.TargetInvocationException: Exception has been thrown by the target of an invocation.
163-
// --->System.ArgumentException: System.Object is not supported.
164-
// at Microsoft.Spark.Utils.UdfUtils.GetReturnType(Type type) in Microsoft.Spark\Utils\UdfUtils.cs:line 142
165-
// at Microsoft.Spark.Utils.UdfUtils.GetReturnType(Type type) in Microsoft.Spark\Utils\UdfUtils.cs:line 136
166-
// at Microsoft.Spark.Sql.Functions.CreateUdf[TResult](String name, Delegate execute, PythonEvalType evalType) in Microsoft.Spark\Sql\Functions.cs:line 4053
167-
// at Microsoft.Spark.Sql.Functions.CreateUdf[TResult](String name, Delegate execute) in Microsoft.Spark\Sql\Functions.cs:line 4040
168-
// at Microsoft.Spark.Sql.Functions.Udf[T, TResult](Func`2 udf) in Microsoft.Spark\Sql\Functions.cs:line 3607
169-
Assert.Throws<ArgumentException>(() => Udf<string, Row>(
170-
(str) =>
161+
// Single GenericRow
162+
{
163+
var schema = new StructType(new[]
164+
{
165+
new StructField("col1", new IntegerType()),
166+
new StructField("col2", new StringType())
167+
});
168+
Func<Column, Column> udf = Udf<string>(
169+
str => new GenericRow(new object[] { 1, "abc" }), schema);
170+
171+
Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
172+
Assert.Equal(3, rows.Length);
173+
174+
foreach (Row row in rows)
175+
{
176+
Assert.Equal(2, row.Size());
177+
Assert.Equal(1, row.GetAs<int>("col1"));
178+
Assert.Equal("abc", row.GetAs<string>("col2"));
179+
}
180+
}
181+
182+
// Nested GenericRow
183+
{
184+
var subSchema1 = new StructType(new[]
185+
{
186+
new StructField("col1", new IntegerType()),
187+
});
188+
var subSchema2 = new StructType(new[]
189+
{
190+
new StructField("col1", new StringType()),
191+
new StructField("col2", subSchema1),
192+
});
193+
var schema = new StructType(new[]
194+
{
195+
new StructField("col1", new IntegerType()),
196+
new StructField("col2", subSchema1),
197+
new StructField("col3", subSchema2)
198+
});
199+
200+
Func<Column, Column> udf = Udf<string>(
201+
str => new GenericRow(
202+
new object[]
203+
{
204+
1,
205+
new GenericRow(new object[] { 1 }),
206+
new GenericRow(new object[]
207+
{
208+
"abc",
209+
new GenericRow(new object[] { 10 })
210+
})
211+
}),
212+
schema);
213+
214+
Row[] rows = _df.Select(udf(_df["name"])).Collect().ToArray();
215+
Assert.Equal(3, rows.Length);
216+
217+
foreach (Row row in rows)
171218
{
172-
var structFields = new List<StructField>()
173-
{
174-
new StructField("name", new StringType()),
175-
};
176-
var schema = new StructType(structFields);
177-
var row = new Row(new object[] { str }, schema);
178-
return row;
179-
}));
219+
Assert.Equal(3, row.Size());
220+
Assert.Equal(1, row.GetAs<int>("col1"));
221+
Assert.Equal(
222+
new Row(new object[] { 1 }, subSchema1),
223+
row.GetAs<Row>("col2"));
224+
Assert.Equal(
225+
new Row(
226+
new object[] { "abc", new Row(new object[] { 10 }, subSchema1) },
227+
subSchema2),
228+
row.GetAs<Row>("col3"));
229+
}
230+
}
180231
}
181232
}
182233
}

src/csharp/Microsoft.Spark/Sql/Functions.cs

Lines changed: 205 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
using Microsoft.Spark.Interop;
99
using Microsoft.Spark.Interop.Ipc;
1010
using Microsoft.Spark.Sql.Expressions;
11+
using Microsoft.Spark.Sql.Types;
1112
using Microsoft.Spark.Utils;
1213

1314
namespace Microsoft.Spark.Sql
@@ -3797,6 +3798,189 @@ public static Func<Column, Column, Column, Column, Column, Column, Column, Colum
37973798
return CreateUdf<TResult>(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf)).Apply10;
37983799
}
37993800

3801+
/// <summary>Creates a UDF from the specified delegate.</summary>
3802+
/// <param name="udf">The UDF function implementation.</param>
3803+
/// <param name="returnType">Schema associated with this row</param>
3804+
/// <returns>
3805+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3806+
/// </returns>
3807+
public static Func<Column> Udf(Func<GenericRow> udf, StructType returnType)
3808+
{
3809+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply0;
3810+
}
3811+
3812+
/// <summary>Creates a UDF from the specified delegate.</summary>
3813+
/// <typeparam name="T">Specifies the type of the first argument to the UDF.</typeparam>
3814+
/// <param name="udf">The UDF function implementation.</param>
3815+
/// <param name="returnType">Schema associated with this row</param>
3816+
/// <returns>
3817+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3818+
/// </returns>
3819+
public static Func<Column, Column> Udf<T>(Func<T, GenericRow> udf, StructType returnType)
3820+
{
3821+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply1;
3822+
}
3823+
3824+
/// <summary>Creates a UDF from the specified delegate.</summary>
3825+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3826+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3827+
/// <param name="udf">The UDF function implementation.</param>
3828+
/// <param name="returnType">Schema associated with this row</param>
3829+
/// <returns>
3830+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3831+
/// </returns>
3832+
public static Func<Column, Column, Column> Udf<T1, T2>(
3833+
Func<T1, T2, GenericRow> udf, StructType returnType)
3834+
{
3835+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply2;
3836+
}
3837+
3838+
/// <summary>Creates a UDF from the specified delegate.</summary>
3839+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3840+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3841+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3842+
/// <param name="udf">The UDF function implementation.</param>
3843+
/// <param name="returnType">Schema associated with this row</param>
3844+
/// <returns>
3845+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3846+
/// </returns>
3847+
public static Func<Column, Column, Column, Column> Udf<T1, T2, T3>(
3848+
Func<T1, T2, T3, GenericRow> udf, StructType returnType)
3849+
{
3850+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply3;
3851+
}
3852+
3853+
/// <summary>Creates a UDF from the specified delegate.</summary>
3854+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3855+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3856+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3857+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3858+
/// <param name="udf">The UDF function implementation.</param>
3859+
/// <param name="returnType">Schema associated with this row</param>
3860+
/// <returns>
3861+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3862+
/// </returns>
3863+
public static Func<Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4>(
3864+
Func<T1, T2, T3, T4, GenericRow> udf, StructType returnType)
3865+
{
3866+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply4;
3867+
}
3868+
3869+
/// <summary>Creates a UDF from the specified delegate.</summary>
3870+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3871+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3872+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3873+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3874+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3875+
/// <param name="udf">The UDF function implementation.</param>
3876+
/// <param name="returnType">Schema associated with this row</param>
3877+
/// <returns>
3878+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3879+
/// </returns>
3880+
public static Func<Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5>(
3881+
Func<T1, T2, T3, T4, T5, GenericRow> udf, StructType returnType)
3882+
{
3883+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply5;
3884+
}
3885+
3886+
/// <summary>Creates a UDF from the specified delegate.</summary>
3887+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3888+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3889+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3890+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3891+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3892+
/// <typeparam name="T6">Specifies the type of the sixth argument to the UDF.</typeparam>
3893+
/// <param name="udf">The UDF function implementation.</param>
3894+
/// <param name="returnType">Schema associated with this row</param>
3895+
/// <returns>
3896+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3897+
/// </returns>
3898+
public static Func<Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6>(
3899+
Func<T1, T2, T3, T4, T5, T6, GenericRow> udf, StructType returnType)
3900+
{
3901+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply6;
3902+
}
3903+
3904+
/// <summary>Creates a UDF from the specified delegate.</summary>
3905+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3906+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3907+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3908+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3909+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3910+
/// <typeparam name="T6">Specifies the type of the sixth argument to the UDF.</typeparam>
3911+
/// <typeparam name="T7">Specifies the type of the seventh argument to the UDF.</typeparam>
3912+
/// <param name="udf">The UDF function implementation.</param>
3913+
/// <param name="returnType">Schema associated with this row</param>
3914+
/// <returns>
3915+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3916+
/// </returns>
3917+
public static Func<Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7>(
3918+
Func<T1, T2, T3, T4, T5, T6, T7, GenericRow> udf, StructType returnType)
3919+
{
3920+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply7;
3921+
}
3922+
3923+
/// <summary>Creates a UDF from the specified delegate.</summary>
3924+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3925+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3926+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3927+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3928+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3929+
/// <typeparam name="T6">Specifies the type of the sixth argument to the UDF.</typeparam>
3930+
/// <typeparam name="T7">Specifies the type of the seventh argument to the UDF.</typeparam>
3931+
/// <typeparam name="T8">Specifies the type of the eighth argument to the UDF.</typeparam>
3932+
/// <param name="udf">The UDF function implementation.</param>
3933+
/// <param name="returnType">Schema associated with this row</param>
3934+
/// <returns>
3935+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3936+
/// </returns>
3937+
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8>(
3938+
Func<T1, T2, T3, T4, T5, T6, T7, T8, GenericRow> udf, StructType returnType)
3939+
{
3940+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply8;
3941+
}
3942+
3943+
/// <summary>Creates a UDF from the specified delegate.</summary>
3944+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3945+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3946+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3947+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3948+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3949+
/// <typeparam name="T6">Specifies the type of the sixth argument to the UDF.</typeparam>
3950+
/// <typeparam name="T7">Specifies the type of the seventh argument to the UDF.</typeparam>
3951+
/// <typeparam name="T8">Specifies the type of the eighth argument to the UDF.</typeparam>
3952+
/// <typeparam name="T9">Specifies the type of the ninth argument to the UDF.</typeparam>
3953+
/// <param name="udf">The UDF function implementation.</param>
3954+
/// <param name="returnType">Schema associated with this row</param>
3955+
/// <returns>A delegate that when invoked will return a <see cref="Column"/> for the result of the UDF.</returns>
3956+
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9>(
3957+
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, GenericRow> udf, StructType returnType)
3958+
{
3959+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply9;
3960+
}
3961+
3962+
/// <summary>Creates a UDF from the specified delegate.</summary>
3963+
/// <typeparam name="T1">Specifies the type of the first argument to the UDF.</typeparam>
3964+
/// <typeparam name="T2">Specifies the type of the second argument to the UDF.</typeparam>
3965+
/// <typeparam name="T3">Specifies the type of the third argument to the UDF.</typeparam>
3966+
/// <typeparam name="T4">Specifies the type of the fourth argument to the UDF.</typeparam>
3967+
/// <typeparam name="T5">Specifies the type of the fifth argument to the UDF.</typeparam>
3968+
/// <typeparam name="T6">Specifies the type of the sixth argument to the UDF.</typeparam>
3969+
/// <typeparam name="T7">Specifies the type of the seventh argument to the UDF.</typeparam>
3970+
/// <typeparam name="T8">Specifies the type of the eighth argument to the UDF.</typeparam>
3971+
/// <typeparam name="T9">Specifies the type of the ninth argument to the UDF.</typeparam>
3972+
/// <typeparam name="T10">Specifies the type of the tenth argument to the UDF.</typeparam>
3973+
/// <param name="udf">The UDF function implementation.</param>
3974+
/// <param name="returnType">Schema associated with this row</param>
3975+
/// <returns>
3976+
/// A delegate that returns a <see cref="Column"/> for the result of the UDF.
3977+
/// </returns>
3978+
public static Func<Column, Column, Column, Column, Column, Column, Column, Column, Column, Column, Column> Udf<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10>(
3979+
Func<T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, GenericRow> udf, StructType returnType)
3980+
{
3981+
return CreateUdf(udf.Method.ToString(), UdfUtils.CreateUdfWrapper(udf), returnType).Apply10;
3982+
}
3983+
38003984
/// <summary>Creates a Vector UDF from the specified delegate.</summary>
38013985
/// <typeparam name="T">Specifies the type of the first argument to the UDF.</typeparam>
38023986
/// <typeparam name="TResult">Specifies the return type of the UDF.</typeparam>
@@ -4071,6 +4255,11 @@ private static UserDefinedFunction CreateUdf<TResult>(string name, Delegate exec
40714255
return CreateUdf<TResult>(name, execute, UdfUtils.PythonEvalType.SQL_BATCHED_UDF);
40724256
}
40734257

4258+
private static UserDefinedFunction CreateUdf(string name, Delegate execute, StructType returnType)
4259+
{
4260+
return CreateUdf(name, execute, UdfUtils.PythonEvalType.SQL_BATCHED_UDF, returnType);
4261+
}
4262+
40744263
private static UserDefinedFunction CreateVectorUdf<TResult>(string name, Delegate execute)
40754264
{
40764265
return CreateUdf<TResult>(name, execute, UdfUtils.PythonEvalType.SQL_SCALAR_PANDAS_UDF);
@@ -4079,7 +4268,21 @@ private static UserDefinedFunction CreateVectorUdf<TResult>(string name, Delegat
40794268
private static UserDefinedFunction CreateUdf<TResult>(
40804269
string name,
40814270
Delegate execute,
4082-
UdfUtils.PythonEvalType evalType)
4271+
UdfUtils.PythonEvalType evalType) =>
4272+
CreateUdf(name, execute, evalType, UdfUtils.GetReturnType(typeof(TResult)));
4273+
4274+
private static UserDefinedFunction CreateUdf(
4275+
string name,
4276+
Delegate execute,
4277+
UdfUtils.PythonEvalType evalType,
4278+
StructType returnType) =>
4279+
CreateUdf(name, execute, evalType, returnType.Json);
4280+
4281+
private static UserDefinedFunction CreateUdf(
4282+
string name,
4283+
Delegate execute,
4284+
UdfUtils.PythonEvalType evalType,
4285+
string returnType)
40834286
{
40844287
return UserDefinedFunction.Create(
40854288
name,
@@ -4088,7 +4291,7 @@ private static UserDefinedFunction CreateUdf<TResult>(
40884291
CommandSerDe.SerializedMode.Row,
40894292
CommandSerDe.SerializedMode.Row),
40904293
evalType,
4091-
UdfUtils.GetReturnType(typeof(TResult)));
4294+
returnType);
40924295
}
40934296

40944297
private static Column ApplyFunction(string funcName)

0 commit comments

Comments
 (0)