Skip to content

Commit ed481b6

Browse files
authored
Fix for MulticlassNaiveBayesTrainer export to Onnx (dotnet#4928)
* adding support for batch input dim
1 parent f6cdf57 commit ed481b6

File tree

1 file changed

+41
-29
lines changed

1 file changed

+41
-29
lines changed

src/Microsoft.ML.StandardTrainers/Standard/MulticlassClassification/MulticlassNaiveBayesTrainer.cs

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
405405
}
406406

407407
var one = ctx.AddInitializer(1.0f, "one");
408+
var oneInt = ctx.AddInitializer(1, typeof(int), "oneInt");
408409
var zero = ctx.AddInitializer(0.0f, "zero");
409410
var labelCount = ctx.AddInitializer((float)_labelCount, "labelCount");
410411
var trainingCount = ctx.AddInitializer((float)_totalTrainingCount, "totalTrainingCount");
@@ -414,108 +415,119 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
414415
var labelHistogramName = ctx.AddInitializer(labelHistogramExpanded, new long[] { _featureHistogram[0].Length, _labelHistogram.Length }, "labelHistogramExpanded");
415416
var learnedAbsentFeatureLogProb = ctx.AddInitializer(_absentFeaturesLogProb, new long[] { _absentFeaturesLogProb.Length, 1 }, "absentFeaturesLogProb");
416417

417-
var greaterOutput = ctx.AddIntermediateVariable(null, "greaterOutput", true);
418+
var typeOne = new VectorDataViewType(NumberDataViewType.Single, 1);
419+
var typeFea = new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length);
420+
var typeLabelByFea = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, _featureHistogram[0].Length);
421+
var typeLabelByOne = new VectorDataViewType(NumberDataViewType.Single, _labelHistogram.Length, 1);
422+
423+
var greaterOutput = ctx.AddIntermediateVariable(new VectorDataViewType(BooleanDataViewType.Instance, _featureHistogram[0].Length), "greaterOutput");
418424
var opType = "Greater";
419425
ctx.CreateNode(opType, new[] { featureColumn, zero }, new[] { greaterOutput }, ctx.GetNodeName(opType), "");
420426

421427
opType = "Cast";
422-
var isFeaturePresent = ctx.AddIntermediateVariable(null, "isFeaturePresent", true);
423-
var node = ctx.CreateNode(opType, greaterOutput, isFeaturePresent, ctx.GetNodeName(opType), "");
428+
var castOutput = ctx.AddIntermediateVariable(typeFea, "CastOutput");
429+
var node = ctx.CreateNode(opType, greaterOutput, castOutput, ctx.GetNodeName(opType), "");
424430
var t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
425431
node.AddAttribute("to", t);
426432

433+
opType = "ExpandDims";
434+
var isFeaturePresent = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, 1, _featureHistogram[0].Length), "isFeaturePresent");
435+
ctx.CreateNode(opType, new[] { castOutput, oneInt }, new[] { isFeaturePresent }, ctx.GetNodeName(opType), "com.microsoft");
436+
427437
//initialize logProb
428438
opType = "Div";
429-
var divOutput = ctx.AddIntermediateVariable(null, "DivOutput", true);
439+
var divOutput = ctx.AddIntermediateVariable(typeOne, "DivOutput");
430440
ctx.CreateNode(opType, new[] { labelHistogram, trainingCount }, new[] { divOutput }, ctx.GetNodeName(opType), "");
431441

432442
opType = "Log";
433-
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
443+
var logOutput = ctx.AddIntermediateVariable(typeOne, "LogOutput");
434444
ctx.CreateNode(opType, divOutput, logOutput, ctx.GetNodeName(opType), "");
435445

436446
//log1
437447
opType = "Sum";
438-
var sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
448+
var sumOutput = ctx.AddIntermediateVariable(_inputType, "SumOutput");
439449
ctx.CreateNode(opType, new[] { featureHistogramName, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
440450

441-
var logOutput1 = ctx.AddIntermediateVariable(null, "LogOutput", true);
451+
var logOutput1 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
442452
LogMul(ctx, sumOutput, isFeaturePresent, logOutput1);
443453

444454
//log2
445455
opType = "Transpose";
446-
var labelHistogramTrans = ctx.AddIntermediateVariable(null, "transpose", true);
456+
var labelHistogramTrans = ctx.AddIntermediateVariable(typeFea, "Transpose");
447457
ctx.CreateNode(opType, labelHistogramName, labelHistogramTrans, ctx.GetNodeName(opType), "");
448458

449459
opType = "Sub";
450-
var absentFeatureCount = ctx.AddIntermediateVariable(null, "AbsentFeatureCounts", true);
460+
var absentFeatureCount = ctx.AddIntermediateVariable(typeFea, "AbsentFeatureCounts");
451461
ctx.CreateNode(opType, new[] { labelHistogramTrans, featureHistogramName }, new[] { absentFeatureCount }, ctx.GetNodeName(opType), "");
452462

453463
opType = "Sum";
454-
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
464+
sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
455465
ctx.CreateNode(opType, new[] { labelHistogramTrans, labelCount }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
456466

457-
var logOutput2 = ctx.AddIntermediateVariable(null, "LogOutput", true);
467+
var logOutput2 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
458468
LogMul(ctx, sumOutput, isFeaturePresent, logOutput2);
459469

460470
//log3
461471
opType = "Sum";
462-
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
472+
sumOutput = ctx.AddIntermediateVariable(typeFea, "SumOutput");
463473
ctx.CreateNode(opType, new[] { absentFeatureCount, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
464474

465-
var logOutput3 = ctx.AddIntermediateVariable(null, "LogOutput", true);
475+
var logOutput3 = ctx.AddIntermediateVariable(typeLabelByFea, "LogOutput");
466476
LogMul(ctx, sumOutput, isFeaturePresent, logOutput3);
467477

468478
//result
469479
opType = "Sub";
470-
var logProb = ctx.AddIntermediateVariable(null, "LogProb", true);
480+
var logProb = ctx.AddIntermediateVariable(typeLabelByFea, "LogProb");
471481
ctx.CreateNode(opType, new[] { logOutput1, logOutput2 }, new[] { logProb }, ctx.GetNodeName(opType), "");
472482

473483
opType = "Sub";
474-
var absentFeatureLogProb = ctx.AddIntermediateVariable(null, "AbsentFeatureLogProb", true);
484+
var absentFeatureLogProb = ctx.AddIntermediateVariable(typeLabelByFea, "AbsentFeatureLogProb");
475485
ctx.CreateNode(opType, new[] { logOutput3, logOutput2 }, new[] { absentFeatureLogProb }, ctx.GetNodeName(opType), "");
476486

477487
opType = "ReduceSum";
478-
var logProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
488+
var logProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
479489
node = ctx.CreateNode(opType, new[] { logProb }, new[] { logProbReduceSum }, ctx.GetNodeName(opType), "");
480-
long[] list = { 1 };
490+
long[] list = { 2 };
481491
node.AddAttribute("axes", list);
482492

483493
opType = "ReduceSum";
484-
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(null, "ReduceSum", true);
494+
var absentFeatureLogProbReduceSum = ctx.AddIntermediateVariable(typeLabelByOne, "ReduceSum");
485495
node = ctx.CreateNode(opType, new[] { absentFeatureLogProb }, new[] { absentFeatureLogProbReduceSum }, ctx.GetNodeName(opType), "");
486496
node.AddAttribute("axes", list);
487497

488498
opType = "Cast";
489-
var castOutput = ctx.AddIntermediateVariable(null, "CastOutput2", true);
499+
castOutput = ctx.AddIntermediateVariable(NumberDataViewType.Single, "CastOutput");
490500
node = ctx.CreateNode(opType, learnedAbsentFeatureLogProb, castOutput, ctx.GetNodeName(opType), "");
491501
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
492502
node.AddAttribute("to", t);
493503

494504
opType = "Sub";
495-
var subOutput = ctx.AddIntermediateVariable(null, "SubOutput", true);
505+
var subOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SubOutput");
496506
ctx.CreateNode(opType, new[] { castOutput, absentFeatureLogProbReduceSum }, new[] { subOutput }, ctx.GetNodeName(opType), "");
497507

498508
opType = "Sum";
499-
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
509+
sumOutput = ctx.AddIntermediateVariable(typeLabelByOne, "SumOutput");
500510
ctx.CreateNode(opType, new[] { subOutput, logProbReduceSum, logOutput }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
501511

502-
opType = "Transpose";
503-
var transposeOutput = ctx.AddIntermediateVariable(null, "TransposeOutput", true);
504-
ctx.CreateNode(opType, new[] { sumOutput }, new[] { outputNames[1] }, ctx.GetNodeName(opType), "");
512+
opType = "Squeeze";
513+
var squeezeNode = ctx.CreateNode(opType, sumOutput, outputNames[1], ctx.GetNodeName(opType), "");
514+
squeezeNode.AddAttribute("axes", new long[] { 2 });
505515

506516
opType = "ArgMax";
507-
var scoreIndex = ctx.AddIntermediateVariable(null, "ScoreIndex", true);
508-
ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
517+
var scoreIndex = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Int64, 1), "ScoreIndex");
518+
node = ctx.CreateNode(opType, new[] { sumOutput }, new[] { scoreIndex }, ctx.GetNodeName(opType), "");
519+
node.AddAttribute("axis", 1);
520+
node.AddAttribute("keepdims", 0);
509521

510522
opType = "Cast";
511-
castOutput = ctx.AddIntermediateVariable(null, "CastOutput3", true);
523+
castOutput = ctx.AddIntermediateVariable(typeOne, "CastOutput");
512524
node = ctx.CreateNode(opType, scoreIndex, castOutput, ctx.GetNodeName(opType), "");
513525
t = InternalDataKindExtensions.ToInternalDataKind(DataKind.Single).ToType();
514526
node.AddAttribute("to", t);
515527

516528
//log3
517529
opType = "Sum";
518-
sumOutput = ctx.AddIntermediateVariable(null, "SumOutput", true);
530+
sumOutput = ctx.AddIntermediateVariable(typeOne, "SumOutput");
519531
ctx.CreateNode(opType, new[] { castOutput, one }, new[] { sumOutput }, ctx.GetNodeName(opType), "");
520532

521533
opType = "Cast";
@@ -529,7 +541,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
529541
private void LogMul(OnnxContext ctx, string input, string isFeaturePresent, string output)
530542
{
531543
var opType = "Log";
532-
var logOutput = ctx.AddIntermediateVariable(null, "LogOutput", true);
544+
var logOutput = ctx.AddIntermediateVariable(new VectorDataViewType(NumberDataViewType.Single, _featureHistogram[0].Length), "LogOutput");
533545
ctx.CreateNode(opType, input, logOutput, ctx.GetNodeName(opType), "");
534546

535547
opType = "Mul";

0 commit comments

Comments
 (0)