Skip to content

Commit 9bc3d7b

Browse files
ashbhandare authored and codemzs committed
Image Classification API: Fix processing of incomplete batches (< batchSize) and the count of images processed per epoch; enable EarlyStopping without a validation set. Fixes dotnet#4274 and dotnet#4286 (dotnet#4289)
* In ImageClassification, process the incomplete batch where the number of samples < batchSize.
* Fixed batchIndex not resetting in the train loop; enabled EarlyStopping when validationSet is not given for the ImageClassification API.
* Fixed changing the shape of the feature and label tensors for the incomplete batch; detected an edge case where early stopping is not supported.
* Improved the featureBatchSizeInBytes calculation and the exception message.
1 parent 34d970f commit 9bc3d7b

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
186186

187187
if (_session.graph.OperationByName(_labelTensor.name.Split(':')[0]) == null)
188188
throw Host.ExceptParam(nameof(options.TensorFlowLabel), $"'{options.TensorFlowLabel}' does not exist in the model");
189+
if (options.EarlyStoppingCriteria != null && options.ValidationSet == null && options.TestOnTrainSet == false)
190+
throw Host.ExceptParam(nameof(options.EarlyStoppingCriteria), $"Early stopping enabled but unable to find a validation" +
191+
$" set and/or train set testing disabled. Please disable early stopping or either provide a validation set or enable train set training.");
189192
}
190193

191194
private (Tensor, Tensor) AddJpegDecoding(int height, int width, int depth)
@@ -381,6 +384,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
381384
float crossentropy = 0;
382385
for (int epoch = 0; epoch < epochs; epoch += 1)
383386
{
387+
batchIndex = 0;
384388
metrics.Train.Accuracy = 0;
385389
metrics.Train.CrossEntropy = 0;
386390
metrics.Train.BatchProcessedCount = 0;
@@ -432,6 +436,42 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
432436
}
433437
}
434438

439+
//Process last incomplete batch
440+
if (batchIndex > 0)
441+
{
442+
featureTensorShape[0] = batchIndex;
443+
featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex;
444+
labelTensorShape[0] = batchIndex;
445+
labelBatchSizeInBytes = sizeof(long) * batchIndex;
446+
runner.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
447+
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
448+
.Run();
449+
450+
metrics.Train.BatchProcessedCount += 1;
451+
452+
if (options.TestOnTrainSet && statisticsCallback != null)
453+
{
454+
var outputTensors = testEvalRunner
455+
.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
456+
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
457+
.Run();
458+
459+
outputTensors[0].ToScalar<float>(ref accuracy);
460+
outputTensors[1].ToScalar<float>(ref crossentropy);
461+
metrics.Train.Accuracy += accuracy;
462+
metrics.Train.CrossEntropy += crossentropy;
463+
464+
outputTensors[0].Dispose();
465+
outputTensors[1].Dispose();
466+
}
467+
468+
batchIndex = 0;
469+
featureTensorShape[0] = batchSize;
470+
featureBatchSizeInBytes = sizeof(float) * featureBatch.Length;
471+
labelTensorShape[0] = batchSize;
472+
labelBatchSizeInBytes = sizeof(long) * batchSize;
473+
}
474+
435475
if (options.TestOnTrainSet && statisticsCallback != null)
436476
{
437477
metrics.Train.Epoch = epoch;
@@ -443,7 +483,15 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
443483
}
444484

445485
if (validationSet == null)
486+
{
487+
//Early stopping check
488+
if (options.EarlyStoppingCriteria != null)
489+
{
490+
if (options.EarlyStoppingCriteria.ShouldStop(metrics.Train))
491+
break;
492+
}
446493
continue;
494+
}
447495

448496
batchIndex = 0;
449497
metrics.Train.BatchProcessedCount = 0;
@@ -481,6 +529,31 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
481529
}
482530
}
483531

532+
//Process last incomplete batch
533+
if(batchIndex > 0)
534+
{
535+
featureTensorShape[0] = batchIndex;
536+
featureBatchSizeInBytes = sizeof(float) * featureLength * batchIndex;
537+
labelTensorShape[0] = batchIndex;
538+
labelBatchSizeInBytes = sizeof(long) * batchIndex;
539+
var outputTensors = validationEvalRunner
540+
.AddInput(new Tensor(featureBatchPtr, featureTensorShape, TF_DataType.TF_FLOAT, featureBatchSizeInBytes), 0)
541+
.AddInput(new Tensor(labelBatchPtr, labelTensorShape, TF_DataType.TF_INT64, labelBatchSizeInBytes), 1)
542+
.Run();
543+
544+
outputTensors[0].ToScalar<float>(ref accuracy);
545+
metrics.Train.Accuracy += accuracy;
546+
metrics.Train.BatchProcessedCount += 1;
547+
batchIndex = 0;
548+
549+
featureTensorShape[0] = batchSize;
550+
featureBatchSizeInBytes = sizeof(float) * featureBatch.Length;
551+
labelTensorShape[0] = batchSize;
552+
labelBatchSizeInBytes = sizeof(long) * batchSize;
553+
554+
outputTensors[0].Dispose();
555+
}
556+
484557
if (statisticsCallback != null)
485558
{
486559
metrics.Train.Epoch = epoch;

0 commit comments

Comments
 (0)