
Added cross entropy support to validation training, edited metric reporting #5255

Merged · 8 commits · Jul 2, 2020
12 changes: 8 additions & 4 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
@@ -170,11 +170,11 @@ public sealed class TrainMetrics
public override string ToString()
{
if (DatasetUsed == ImageClassificationMetrics.Dataset.Train)
return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, Learning Rate: {LearningRate,10} " +

@mstfbl (Contributor, Author) commented on Jul 2, 2020:
I have added the learning rate to the end of the string here for cosmetic reasons: the columns of statistics now align between the training and validation lines. :) #Resolved

$"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}";
return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " +
$"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}, Learning Rate: {LearningRate,10}";
else
return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, " +
$"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}";
$"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}";
}
}
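
For illustration only (not part of the PR), here is a minimal, self-contained sketch of how C#'s alignment specifiers such as {BatchProcessedCount,3} and {Accuracy,10} pad their values, and why appending Learning Rate after the shared columns keeps the training and validation lines aligned. All values below are made up.

```csharp
using System;

class AlignmentSketch
{
    static void Main()
    {
        // Hypothetical metric values, purely for demonstration.
        float accuracy = 0.975f, crossEntropy = 0.083f, learningRate = 0.01f;
        int epoch = 2, batches = 18;
        string train = "Train", validation = "Validation";

        // Training-phase line: Learning Rate now trails the shared columns.
        Console.WriteLine(
            $"Phase: Training, Dataset used: {train,10}, Batch Processed Count: {batches,3}, " +
            $"Epoch: {epoch,3}, Accuracy: {accuracy,10}, Cross-Entropy: {crossEntropy,10}, Learning Rate: {learningRate,10}");

        // Validation-phase line: identical leading columns, no Learning Rate.
        Console.WriteLine(
            $"Phase: Training, Dataset used: {validation,10}, Batch Processed Count: {batches,3}, " +
            $"Epoch: {epoch,3}, Accuracy: {accuracy,10}, Cross-Entropy: {crossEntropy,10}");
    }
}
```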

@@ -952,7 +952,7 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
if (validationNeeded)
{
validationEvalRunner = new Runner(_session, new[] { _bottleneckInput.name, _labelTensor.name },
- new[] { _evaluationStep.name });
+ new[] { _evaluationStep.name, _crossEntropy.name });
}

runner = new Runner(_session, runnerInputTensorNames.ToArray(),
@@ -1036,14 +1036,18 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,
(outputTensors, metrics) =>
{
outputTensors[0].ToScalar(ref accuracy);
+ outputTensors[1].ToScalar(ref crossentropy);
metrics.Train.Accuracy += accuracy;
+ metrics.Train.CrossEntropy += crossentropy;
outputTensors[0].Dispose();
+ outputTensors[1].Dispose();
});

if (statisticsCallback != null)
{
metrics.Train.Epoch = epoch;
metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
+ metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount;
metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation;
statisticsCallback(metrics);
}
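
As a side note, the hunk above follows an accumulate-then-average pattern: per-batch scalars are summed into the metrics object, then divided by BatchProcessedCount once per epoch before the callback fires. A minimal standalone sketch of that pattern is below; the EpochMetrics type and its members are hypothetical stand-ins, not ML.NET APIs.

```csharp
// Illustrative only; not taken from ImageClassificationTrainer.
sealed class EpochMetrics
{
    public float Accuracy;
    public float CrossEntropy;
    public int BatchProcessedCount;

    // Called once per validation batch with the scalars fetched from the session run.
    public void Accumulate(float batchAccuracy, float batchCrossEntropy)
    {
        Accuracy += batchAccuracy;
        CrossEntropy += batchCrossEntropy;
        BatchProcessedCount++;
    }

    // Called once per epoch, before reporting, to turn the running sums into means.
    public void Average()
    {
        Accuracy /= BatchProcessedCount;
        CrossEntropy /= BatchProcessedCount;
    }
}
```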
@@ -1489,7 +1489,6 @@ public void TensorFlowImageClassificationWithExponentialLRScheduling()
[TensorFlowFact]
public void TensorFlowImageClassificationWithPolynomialLRScheduling()
{

TensorFlowImageClassificationWithLRScheduling(new PolynomialLRDecay(), 50);
}

@@ -1521,6 +1520,8 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
var (trainSetBottleneckCachedValuesFileName, validationSetBottleneckCachedValuesFileName,
workspacePath, isReuse) = getInitialParameters(ImageClassificationTrainer.Architecture.ResnetV2101, _finalImagesFolderName);

+ float[] crossEntropyTraining = new float[epoch];
+ float[] crossEntropyValidation = new float[epoch];
var options = new ImageClassificationTrainer.Options()
{
FeatureColumnName = "Image",
@@ -1532,7 +1533,30 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
Epoch = epoch,
BatchSize = 10,
LearningRate = 0.01f,
- MetricsCallback = (metric) => Console.WriteLine(metric),
+ MetricsCallback = (metric) =>
+ {
+ if (metric.Train != null)
+ {
+ // Check that cross-entropy values during both the training and validation phases are decreasing and sensible
+ if (metric.Train.DatasetUsed == ImageClassificationTrainer.ImageClassificationMetrics.Dataset.Train)
+ {
+ // Save cross-entropy values from the training phase
+ crossEntropyTraining[metric.Train.Epoch] = metric.Train.CrossEntropy;
+ // Check that cross-entropy decreases from one epoch to the next in the training phase
+ if (metric.Train.Epoch > 0)
+ Assert.True(crossEntropyTraining[metric.Train.Epoch - 1] > crossEntropyTraining[metric.Train.Epoch]);
+ }
+ else
+ {
+ // Save cross-entropy values from the validation phase
+ crossEntropyValidation[metric.Train.Epoch] = metric.Train.CrossEntropy;
+ // Check that cross-entropy decreases from one epoch to the next in the validation phase
+ if (metric.Train.Epoch > 0)
+ Assert.True(crossEntropyValidation[metric.Train.Epoch - 1] > crossEntropyValidation[metric.Train.Epoch]);
+ }
+ }
+ Console.WriteLine(metric);
+ },
ValidationSet = validationSet,
WorkspacePath = workspacePath,
TrainSetBottleneckCachedValuesFileName = trainSetBottleneckCachedValuesFileName,