-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added multiple related fixes to enable automatic addition of KeyToValue #4878
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -199,23 +199,26 @@ private protected override void SaveAsOnnxCore(OnnxContext ctx) | |
for (int iinfo = 0; iinfo < Bindings.InfoCount; ++iinfo) | ||
outColumnNames[iinfo] = Bindings.GetColumnName(Bindings.MapIinfoToCol(iinfo)); | ||
|
||
string opType = "Binarizer"; | ||
string scoreColumn = Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name; | ||
|
||
OnnxNode node; | ||
var binarizerOutput = ctx.AddIntermediateVariable(null, "BinarizerOutput", true); | ||
string opType = "Binarizer"; | ||
var binarizerOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "BinarizerOutput", false); | ||
node = ctx.CreateNode(opType, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(opType)); | ||
node.AddAttribute("threshold", _threshold); | ||
|
||
string scoreColumn; | ||
if (Bindings.RowMapper.OutputSchema[Bindings.ScoreColumnIndex].Name == "Score") | ||
scoreColumn = outColumnNames[1]; | ||
else | ||
string comparisonOutput = binarizerOutput; | ||
if (Bindings.PredColType is KeyDataViewType) | ||
{ | ||
Host.Assert(Bindings.InfoCount >= 3); | ||
scoreColumn = outColumnNames[2]; | ||
var one = ctx.AddInitializer(1.0f, "one"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this +1 ? to be 1 based? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this is to make it one based and make it consistent with ML.NET results. In reply to: 383088629 [](ancestors = 383088629) |
||
var addOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "Add", false); | ||
opType = "Add"; | ||
ctx.CreateNode(opType, new[] { binarizerOutput, one }, new[] { addOutput }, ctx.GetNodeName(opType), ""); | ||
comparisonOutput = addOutput; | ||
} | ||
node = ctx.CreateNode(opType, ctx.GetVariableName(scoreColumn), binarizerOutput, ctx.GetNodeName(opType)); | ||
node.AddAttribute("threshold", _threshold); | ||
|
||
opType = "Cast"; | ||
node = ctx.CreateNode(opType, binarizerOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); | ||
node = ctx.CreateNode(opType, comparisonOutput, ctx.GetVariableName(outColumnNames[0]), ctx.GetNodeName(opType), ""); | ||
var predictedLabelCol = OutputSchema.GetColumnOrNull(outColumnNames[0]); | ||
Host.Assert(predictedLabelCol.HasValue); | ||
node.AddAttribute("to", predictedLabelCol.Value.Type.RawType); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -787,7 +787,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src | |
long[] termIds; | ||
string opType = "LabelEncoder"; | ||
OnnxNode castNode; | ||
var labelEncoderOutput = ctx.AddIntermediateVariable(_types[iinfo], "LabelEncoderOutput", true); | ||
var labelEncoderOutput = ctx.AddIntermediateVariable(NumberDataViewType.Int64, "LabelEncoderOutput", true); | ||
|
||
if (info.TypeSrc.GetItemType().Equals(TextDataViewType.Instance)) | ||
{ | ||
|
@@ -804,7 +804,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src | |
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Double)) | ||
{ | ||
// LabelEncoder doesn't support double tensors, so values are cast to floats | ||
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
How did it work with null before? Was there an exception? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the last parameter is true, it skips adding shape and type information and therefore accepts null. This still works, but I am prepping some parts of the code base for issues I have seen when run against the master branch of ORT. In reply to: 383088824 [](ancestors = 383088824) |
||
var castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "castOutput", true); | ||
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); | ||
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType(); | ||
castNode.AddAttribute("to", t); | ||
|
@@ -815,7 +815,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src | |
else if (info.TypeSrc.GetItemType().Equals(NumberDataViewType.Int64)) | ||
{ | ||
// LabelEncoder doesn't support mapping int64 -> int64, so values are cast to strings | ||
var castOutput = ctx.AddIntermediateVariable(null, "castOutput", true); | ||
var castOutput = ctx.AddIntermediateVariable(TextDataViewType.Instance, "castOutput", true); | ||
castNode = ctx.CreateNode("Cast", srcVariableName, castOutput, ctx.GetNodeName(opType), ""); | ||
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.String).ToType(); | ||
castNode.AddAttribute("to", t); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this false mean?
#Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
False = Do not skip adding shape and type information (or in other words, add shape and type information). False is the default value. Technically it is not necessary to specify it.
In reply to: 383088227 [](ancestors = 383088227)