Skip to content

Commit 4bf1a5d

Browse files
committed
Fixed casting logic for hashing to be based on raw type in order to support key types correctly
1 parent ad16cc6 commit 4bf1a5d

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

src/Microsoft.ML.Data/Transforms/Hashing.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -1359,7 +1359,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
13591359
OnnxNode murmurNode;
13601360
OnnxNode isZeroNode;
13611361

1362-
var srcType = _srcTypes[iinfo].GetItemType();
1362+
var srcType = _srcTypes[iinfo].GetItemType().RawType;
13631363
if (_parent._columns[iinfo].Combine)
13641364
return false;
13651365

@@ -1386,9 +1386,9 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, string srcVariable, stri
13861386
}
13871387

13881388
// Since these numeric types are not supported by Onnxruntime, we cast them to UInt32.
1389-
if (srcType == NumberDataViewType.UInt16 || srcType == NumberDataViewType.Int16 ||
1390-
srcType == NumberDataViewType.SByte || srcType == NumberDataViewType.Byte ||
1391-
srcType == BooleanDataViewType.Instance || srcType is KeyDataViewType)
1389+
if (srcType == typeof(ushort) || srcType == typeof(short) ||
1390+
srcType == typeof(sbyte) || srcType == typeof(byte) ||
1391+
srcType == typeof(bool))
13921392
{
13931393
castOutput = ctx.AddIntermediateVariable(NumberDataViewType.UInt32, "CastOutput", true);
13941394
castNode = ctx.CreateNode("Cast", srcVariable, castOutput, ctx.GetNodeName(opType), "");

0 commit comments

Comments
 (0)