@@ -186,6 +186,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
186
186
187
187
if ( _session . graph . OperationByName ( _labelTensor . name . Split ( ':' ) [ 0 ] ) == null )
188
188
throw Host . ExceptParam ( nameof ( options . TensorFlowLabel ) , $ "'{ options . TensorFlowLabel } ' does not exist in the model") ;
189
+ if ( options . EarlyStoppingCriteria != null && options . ValidationSet == null && options . TestOnTrainSet == false )
190
+ throw Host . ExceptParam ( nameof ( options . EarlyStoppingCriteria ) , $ "Early stopping enabled but unable to find a validation" +
191
+ $ " set and/or train set testing disabled. Please disable early stopping or either provide a validation set or enable train set training.") ;
189
192
}
190
193
191
194
private ( Tensor , Tensor ) AddJpegDecoding ( int height , int width , int depth )
@@ -381,6 +384,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
381
384
float crossentropy = 0 ;
382
385
for ( int epoch = 0 ; epoch < epochs ; epoch += 1 )
383
386
{
387
+ batchIndex = 0 ;
384
388
metrics . Train . Accuracy = 0 ;
385
389
metrics . Train . CrossEntropy = 0 ;
386
390
metrics . Train . BatchProcessedCount = 0 ;
@@ -432,6 +436,42 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
432
436
}
433
437
}
434
438
439
+ //Process last incomplete batch
440
+ if ( batchIndex > 0 )
441
+ {
442
+ featureTensorShape [ 0 ] = batchIndex ;
443
+ featureBatchSizeInBytes = sizeof ( float ) * featureLength * batchIndex ;
444
+ labelTensorShape [ 0 ] = batchIndex ;
445
+ labelBatchSizeInBytes = sizeof ( long ) * batchIndex ;
446
+ runner . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
447
+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
448
+ . Run ( ) ;
449
+
450
+ metrics . Train . BatchProcessedCount += 1 ;
451
+
452
+ if ( options . TestOnTrainSet && statisticsCallback != null )
453
+ {
454
+ var outputTensors = testEvalRunner
455
+ . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
456
+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
457
+ . Run ( ) ;
458
+
459
+ outputTensors [ 0 ] . ToScalar < float > ( ref accuracy ) ;
460
+ outputTensors [ 1 ] . ToScalar < float > ( ref crossentropy ) ;
461
+ metrics . Train . Accuracy += accuracy ;
462
+ metrics . Train . CrossEntropy += crossentropy ;
463
+
464
+ outputTensors [ 0 ] . Dispose ( ) ;
465
+ outputTensors [ 1 ] . Dispose ( ) ;
466
+ }
467
+
468
+ batchIndex = 0 ;
469
+ featureTensorShape [ 0 ] = batchSize ;
470
+ featureBatchSizeInBytes = sizeof ( float ) * featureBatch . Length ;
471
+ labelTensorShape [ 0 ] = batchSize ;
472
+ labelBatchSizeInBytes = sizeof ( long ) * batchSize ;
473
+ }
474
+
435
475
if ( options . TestOnTrainSet && statisticsCallback != null )
436
476
{
437
477
metrics . Train . Epoch = epoch ;
@@ -443,7 +483,15 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
443
483
}
444
484
445
485
if ( validationSet == null )
486
+ {
487
+ //Early stopping check
488
+ if ( options . EarlyStoppingCriteria != null )
489
+ {
490
+ if ( options . EarlyStoppingCriteria . ShouldStop ( metrics . Train ) )
491
+ break ;
492
+ }
446
493
continue ;
494
+ }
447
495
448
496
batchIndex = 0 ;
449
497
metrics . Train . BatchProcessedCount = 0 ;
@@ -481,6 +529,31 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
481
529
}
482
530
}
483
531
532
+ //Process last incomplete batch
533
+ if ( batchIndex > 0 )
534
+ {
535
+ featureTensorShape [ 0 ] = batchIndex ;
536
+ featureBatchSizeInBytes = sizeof ( float ) * featureLength * batchIndex ;
537
+ labelTensorShape [ 0 ] = batchIndex ;
538
+ labelBatchSizeInBytes = sizeof ( long ) * batchIndex ;
539
+ var outputTensors = validationEvalRunner
540
+ . AddInput ( new Tensor ( featureBatchPtr , featureTensorShape , TF_DataType . TF_FLOAT , featureBatchSizeInBytes ) , 0 )
541
+ . AddInput ( new Tensor ( labelBatchPtr , labelTensorShape , TF_DataType . TF_INT64 , labelBatchSizeInBytes ) , 1 )
542
+ . Run ( ) ;
543
+
544
+ outputTensors [ 0 ] . ToScalar < float > ( ref accuracy ) ;
545
+ metrics . Train . Accuracy += accuracy ;
546
+ metrics . Train . BatchProcessedCount += 1 ;
547
+ batchIndex = 0 ;
548
+
549
+ featureTensorShape [ 0 ] = batchSize ;
550
+ featureBatchSizeInBytes = sizeof ( float ) * featureBatch . Length ;
551
+ labelTensorShape [ 0 ] = batchSize ;
552
+ labelBatchSizeInBytes = sizeof ( long ) * batchSize ;
553
+
554
+ outputTensors [ 0 ] . Dispose ( ) ;
555
+ }
556
+
484
557
if ( statisticsCallback != null )
485
558
{
486
559
metrics . Train . Epoch = epoch ;
0 commit comments