@@ -405,6 +405,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
405
405
}
406
406
407
407
var one = ctx . AddInitializer ( 1.0f , "one" ) ;
408
+ var oneInt = ctx . AddInitializer ( 1 , typeof ( int ) , "oneInt" ) ;
408
409
var zero = ctx . AddInitializer ( 0.0f , "zero" ) ;
409
410
var labelCount = ctx . AddInitializer ( ( float ) _labelCount , "labelCount" ) ;
410
411
var trainingCount = ctx . AddInitializer ( ( float ) _totalTrainingCount , "totalTrainingCount" ) ;
@@ -414,108 +415,119 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
414
415
var labelHistogramName = ctx . AddInitializer ( labelHistogramExpanded , new long [ ] { _featureHistogram [ 0 ] . Length , _labelHistogram . Length } , "labelHistogramExpanded" ) ;
415
416
var learnedAbsentFeatureLogProb = ctx . AddInitializer ( _absentFeaturesLogProb , new long [ ] { _absentFeaturesLogProb . Length , 1 } , "absentFeaturesLogProb" ) ;
416
417
417
- var greaterOutput = ctx . AddIntermediateVariable ( null , "greaterOutput" , true ) ;
418
+ var typeOne = new VectorDataViewType ( NumberDataViewType . Single , 1 ) ;
419
+ var typeFea = new VectorDataViewType ( NumberDataViewType . Single , _featureHistogram [ 0 ] . Length ) ;
420
+ var typeLabelByFea = new VectorDataViewType ( NumberDataViewType . Single , _labelHistogram . Length , _featureHistogram [ 0 ] . Length ) ;
421
+ var typeLabelByOne = new VectorDataViewType ( NumberDataViewType . Single , _labelHistogram . Length , 1 ) ;
422
+
423
+ var greaterOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( BooleanDataViewType . Instance , _featureHistogram [ 0 ] . Length ) , "greaterOutput" ) ;
418
424
var opType = "Greater" ;
419
425
ctx . CreateNode ( opType , new [ ] { featureColumn , zero } , new [ ] { greaterOutput } , ctx . GetNodeName ( opType ) , "" ) ;
420
426
421
427
opType = "Cast" ;
422
- var isFeaturePresent = ctx . AddIntermediateVariable ( null , "isFeaturePresent" , true ) ;
423
- var node = ctx . CreateNode ( opType , greaterOutput , isFeaturePresent , ctx . GetNodeName ( opType ) , "" ) ;
428
+ var castOutput = ctx . AddIntermediateVariable ( typeFea , "CastOutput" ) ;
429
+ var node = ctx . CreateNode ( opType , greaterOutput , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
424
430
var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
425
431
node . AddAttribute ( "to" , t ) ;
426
432
433
+ opType = "ExpandDims" ;
434
+ var isFeaturePresent = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , 1 , _featureHistogram [ 0 ] . Length ) , "isFeaturePresent" ) ;
435
+ ctx . CreateNode ( opType , new [ ] { castOutput , oneInt } , new [ ] { isFeaturePresent } , ctx . GetNodeName ( opType ) , "com.microsoft" ) ;
436
+
427
437
//initialize logProb
428
438
opType = "Div" ;
429
- var divOutput = ctx . AddIntermediateVariable ( null , "DivOutput" , true ) ;
439
+ var divOutput = ctx . AddIntermediateVariable ( typeOne , "DivOutput" ) ;
430
440
ctx . CreateNode ( opType , new [ ] { labelHistogram , trainingCount } , new [ ] { divOutput } , ctx . GetNodeName ( opType ) , "" ) ;
431
441
432
442
opType = "Log" ;
433
- var logOutput = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
443
+ var logOutput = ctx . AddIntermediateVariable ( typeOne , "LogOutput" ) ;
434
444
ctx . CreateNode ( opType , divOutput , logOutput , ctx . GetNodeName ( opType ) , "" ) ;
435
445
436
446
//log1
437
447
opType = "Sum" ;
438
- var sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
448
+ var sumOutput = ctx . AddIntermediateVariable ( _inputType , "SumOutput" ) ;
439
449
ctx . CreateNode ( opType , new [ ] { featureHistogramName , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
440
450
441
- var logOutput1 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
451
+ var logOutput1 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
442
452
LogMul ( ctx , sumOutput , isFeaturePresent , logOutput1 ) ;
443
453
444
454
//log2
445
455
opType = "Transpose" ;
446
- var labelHistogramTrans = ctx . AddIntermediateVariable ( null , "transpose" , true ) ;
456
+ var labelHistogramTrans = ctx . AddIntermediateVariable ( typeFea , "Transpose" ) ;
447
457
ctx . CreateNode ( opType , labelHistogramName , labelHistogramTrans , ctx . GetNodeName ( opType ) , "" ) ;
448
458
449
459
opType = "Sub" ;
450
- var absentFeatureCount = ctx . AddIntermediateVariable ( null , "AbsentFeatureCounts" , true ) ;
460
+ var absentFeatureCount = ctx . AddIntermediateVariable ( typeFea , "AbsentFeatureCounts" ) ;
451
461
ctx . CreateNode ( opType , new [ ] { labelHistogramTrans , featureHistogramName } , new [ ] { absentFeatureCount } , ctx . GetNodeName ( opType ) , "" ) ;
452
462
453
463
opType = "Sum" ;
454
- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
464
+ sumOutput = ctx . AddIntermediateVariable ( typeFea , "SumOutput" ) ;
455
465
ctx . CreateNode ( opType , new [ ] { labelHistogramTrans , labelCount } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
456
466
457
- var logOutput2 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
467
+ var logOutput2 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
458
468
LogMul ( ctx , sumOutput , isFeaturePresent , logOutput2 ) ;
459
469
460
470
//log3
461
471
opType = "Sum" ;
462
- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
472
+ sumOutput = ctx . AddIntermediateVariable ( typeFea , "SumOutput" ) ;
463
473
ctx . CreateNode ( opType , new [ ] { absentFeatureCount , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
464
474
465
- var logOutput3 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
475
+ var logOutput3 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
466
476
LogMul ( ctx , sumOutput , isFeaturePresent , logOutput3 ) ;
467
477
468
478
//result
469
479
opType = "Sub" ;
470
- var logProb = ctx . AddIntermediateVariable ( null , "LogProb" , true ) ;
480
+ var logProb = ctx . AddIntermediateVariable ( typeLabelByFea , "LogProb" ) ;
471
481
ctx . CreateNode ( opType , new [ ] { logOutput1 , logOutput2 } , new [ ] { logProb } , ctx . GetNodeName ( opType ) , "" ) ;
472
482
473
483
opType = "Sub" ;
474
- var absentFeatureLogProb = ctx . AddIntermediateVariable ( null , "AbsentFeatureLogProb" , true ) ;
484
+ var absentFeatureLogProb = ctx . AddIntermediateVariable ( typeLabelByFea , "AbsentFeatureLogProb" ) ;
475
485
ctx . CreateNode ( opType , new [ ] { logOutput3 , logOutput2 } , new [ ] { absentFeatureLogProb } , ctx . GetNodeName ( opType ) , "" ) ;
476
486
477
487
opType = "ReduceSum" ;
478
- var logProbReduceSum = ctx . AddIntermediateVariable ( null , "ReduceSum" , true ) ;
488
+ var logProbReduceSum = ctx . AddIntermediateVariable ( typeLabelByOne , "ReduceSum" ) ;
479
489
node = ctx . CreateNode ( opType , new [ ] { logProb } , new [ ] { logProbReduceSum } , ctx . GetNodeName ( opType ) , "" ) ;
480
- long [ ] list = { 1 } ;
490
+ long [ ] list = { 2 } ;
481
491
node . AddAttribute ( "axes" , list ) ;
482
492
483
493
opType = "ReduceSum" ;
484
- var absentFeatureLogProbReduceSum = ctx . AddIntermediateVariable ( null , "ReduceSum" , true ) ;
494
+ var absentFeatureLogProbReduceSum = ctx . AddIntermediateVariable ( typeLabelByOne , "ReduceSum" ) ;
485
495
node = ctx . CreateNode ( opType , new [ ] { absentFeatureLogProb } , new [ ] { absentFeatureLogProbReduceSum } , ctx . GetNodeName ( opType ) , "" ) ;
486
496
node . AddAttribute ( "axes" , list ) ;
487
497
488
498
opType = "Cast" ;
489
- var castOutput = ctx . AddIntermediateVariable ( null , "CastOutput2" , true ) ;
499
+ castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Single , "CastOutput" ) ;
490
500
node = ctx . CreateNode ( opType , learnedAbsentFeatureLogProb , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
491
501
t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
492
502
node . AddAttribute ( "to" , t ) ;
493
503
494
504
opType = "Sub" ;
495
- var subOutput = ctx . AddIntermediateVariable ( null , "SubOutput" , true ) ;
505
+ var subOutput = ctx . AddIntermediateVariable ( typeLabelByOne , "SubOutput" ) ;
496
506
ctx . CreateNode ( opType , new [ ] { castOutput , absentFeatureLogProbReduceSum } , new [ ] { subOutput } , ctx . GetNodeName ( opType ) , "" ) ;
497
507
498
508
opType = "Sum" ;
499
- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
509
+ sumOutput = ctx . AddIntermediateVariable ( typeLabelByOne , "SumOutput" ) ;
500
510
ctx . CreateNode ( opType , new [ ] { subOutput , logProbReduceSum , logOutput } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
501
511
502
- opType = "Transpose " ;
503
- var transposeOutput = ctx . AddIntermediateVariable ( null , "TransposeOutput" , true ) ;
504
- ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { outputNames [ 1 ] } , ctx . GetNodeName ( opType ) , "" ) ;
512
+ opType = "Squeeze " ;
513
+ var squeezeNode = ctx . CreateNode ( opType , sumOutput , outputNames [ 1 ] , ctx . GetNodeName ( opType ) , "" ) ;
514
+ squeezeNode . AddAttribute ( "axes" , new long [ ] { 2 } ) ;
505
515
506
516
opType = "ArgMax" ;
507
- var scoreIndex = ctx . AddIntermediateVariable ( null , "ScoreIndex" , true ) ;
508
- ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { scoreIndex } , ctx . GetNodeName ( opType ) , "" ) ;
517
+ var scoreIndex = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Int64 , 1 ) , "ScoreIndex" ) ;
518
+ node = ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { scoreIndex } , ctx . GetNodeName ( opType ) , "" ) ;
519
+ node . AddAttribute ( "axis" , 1 ) ;
520
+ node . AddAttribute ( "keepdims" , 0 ) ;
509
521
510
522
opType = "Cast" ;
511
- castOutput = ctx . AddIntermediateVariable ( null , "CastOutput3" , true ) ;
523
+ castOutput = ctx . AddIntermediateVariable ( typeOne , "CastOutput" ) ;
512
524
node = ctx . CreateNode ( opType , scoreIndex , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
513
525
t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
514
526
node . AddAttribute ( "to" , t ) ;
515
527
516
528
//log3
517
529
opType = "Sum" ;
518
- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
530
+ sumOutput = ctx . AddIntermediateVariable ( typeOne , "SumOutput" ) ;
519
531
ctx . CreateNode ( opType , new [ ] { castOutput , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
520
532
521
533
opType = "Cast" ;
@@ -529,7 +541,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
529
541
private void LogMul ( OnnxContext ctx , string input , string isFeaturePresent , string output )
530
542
{
531
543
var opType = "Log" ;
532
- var logOutput = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
544
+ var logOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , _featureHistogram [ 0 ] . Length ) , "LogOutput" ) ;
533
545
ctx . CreateNode ( opType , input , logOutput , ctx . GetNodeName ( opType ) , "" ) ;
534
546
535
547
opType = "Mul" ;
0 commit comments