Skip to content

Enable OnnxTransformer to accept KeyDataViewTypes as if they were UInt32 #4824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 12, 2020
Prev Previous commit
Next Next commit
Added comment for clarity
  • Loading branch information
antoniovs1029 committed Feb 11, 2020
commit 60bdfd0bb4acc365b35aedf7ca5363da1dccbe43
8 changes: 7 additions & 1 deletion src/Microsoft.ML.OnnxTransformer/OnnxTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,14 @@ public Mapper(OnnxTransformer parent, DataViewSchema inputSchema) :
throw Host.Except($"Variable length input columns not supported");

if (type.GetItemType() != inputNodeInfo.DataViewType.GetItemType())
if(!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32)))
{
// If the ONNX model input node expects a type that mismatches with the type of the input IDataView column that is provided
// then throw an exception.
// This is done except in the case where the ONNX model input node expects a UInt32 but the input column is actually KeyDataViewType
// This is done to support a corner case originated in NimbusML. For more info, see: https://github.com/microsoft/NimbusML/issues/426
if (!(type.GetItemType() is KeyDataViewType && inputNodeInfo.DataViewType.GetItemType().RawType == typeof(UInt32)))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.Inputs[i], inputNodeInfo.DataViewType.GetItemType().ToString(), type.ToString());
}

// If the column is one dimension we make sure that the total size of the Onnx shape matches.
// Compute the total size of the known dimensions of the shape.
Expand Down