Skip to content

Added cross entropy support to validation training, edited metric reporting #5255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 2, 2020
22 changes: 11 additions & 11 deletions src/Microsoft.ML.Vision/ImageClassificationTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,8 @@ public sealed class TrainMetrics
/// </summary>
public override string ToString()
{
    // Learning rate and cross-entropy are reported for both the training and
    // validation phases so the per-epoch log lines have uniformly aligned
    // columns; the alignment widths (,10 / ,3) keep the output tabular.
    return $"Phase: Training, Dataset used: {DatasetUsed.ToString(),10}, Batch Processed Count: {BatchProcessedCount,3}, Learning Rate: {LearningRate,10} " +
        $"Epoch: {Epoch,3}, Accuracy: {Accuracy,10}, Cross-Entropy: {CrossEntropy,10}";
}
}

Expand Down Expand Up @@ -951,8 +947,8 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

if (validationNeeded)
{
validationEvalRunner = new Runner(_session, new[] { _bottleneckInput.name, _labelTensor.name },
new[] { _evaluationStep.name });
validationEvalRunner = new Runner(_session, runnerInputTensorNames.ToArray(),
new[] { _evaluationStep.name, _crossEntropy.name }, new[] { _trainStep.name });
}

runner = new Runner(_session, runnerInputTensorNames.ToArray(),
Expand Down Expand Up @@ -1029,21 +1025,25 @@ private void TrainAndEvaluateClassificationLayer(string trainBottleneckFilePath,

// Evaluate.
TrainAndEvaluateClassificationLayerCore(epoch, learningRate, featureFileStartOffset,
metrics, labelTensorShape, featureTensorShape, batchSize,
validationSetLabelReader, validationSetFeatureReader, labelBuffer, featuresBuffer,
labelBufferSizeInBytes, featureBufferSizeInBytes, featureFileRecordSize, null,
metrics, labelTensorShape, featureTensorShape, batchSize, validationSetLabelReader,
validationSetFeatureReader, labelBuffer, featuresBuffer, labelBufferSizeInBytes,
featureBufferSizeInBytes, featureFileRecordSize, _options.LearningRateScheduler,
trainState, validationEvalRunner, featureBufferPtr, labelBufferPtr,
(outputTensors, metrics) =>
{
outputTensors[0].ToScalar(ref accuracy);
outputTensors[1].ToScalar(ref crossentropy);
metrics.Train.Accuracy += accuracy;
metrics.Train.CrossEntropy += crossentropy;
outputTensors[0].Dispose();
outputTensors[1].Dispose();
});

if (statisticsCallback != null)
{
metrics.Train.Epoch = epoch;
metrics.Train.Accuracy /= metrics.Train.BatchProcessedCount;
metrics.Train.CrossEntropy /= metrics.Train.BatchProcessedCount;
metrics.Train.DatasetUsed = ImageClassificationMetrics.Dataset.Validation;
statisticsCallback(metrics);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1483,17 +1483,19 @@ public void TensorFlowImageClassification(ImageClassificationTrainer.Architectur
[TensorFlowFact]
public void TensorFlowImageClassificationWithExponentialLRScheduling()
{
    // Use a named instance (rather than an inline `new`) so the decay interval
    // can be forwarded; the shared test verifies the learning rate actually
    // decays once per NumEpochsPerDecay epochs during training and validation.
    ExponentialLRDecay exponentialLRDecay = new ExponentialLRDecay();
    TensorFlowImageClassificationWithLRScheduling(exponentialLRDecay, 50, (int)exponentialLRDecay.NumEpochsPerDecay);
}

[TensorFlowFact]
public void TensorFlowImageClassificationWithPolynomialLRScheduling()
{
    // Polynomial decay converges to a floor value, so in addition to the decay
    // interval the shared test needs EndLearningRate to accept epochs where the
    // rate has already bottomed out instead of asserting strict decrease.
    PolynomialLRDecay polynomialLRDecay = new PolynomialLRDecay();
    TensorFlowImageClassificationWithLRScheduling(polynomialLRDecay, 50, (int)polynomialLRDecay.NumEpochsPerDecay, polynomialLRDecay.EndLearningRate);
}

internal void TensorFlowImageClassificationWithLRScheduling(LearningRateScheduler learningRateScheduler, int epoch)
internal void TensorFlowImageClassificationWithLRScheduling(
LearningRateScheduler learningRateScheduler, int epoch, int numEpochsPerDecay, float endLearningRate = 0.0f)
{
//Load all the original images info
IEnumerable<ImageData> images = LoadImagesFromDirectory(
Expand Down Expand Up @@ -1521,6 +1523,11 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
var (trainSetBottleneckCachedValuesFileName, validationSetBottleneckCachedValuesFileName,
workspacePath, isReuse) = getInitialParameters(ImageClassificationTrainer.Architecture.ResnetV2101, _finalImagesFolderName);

float[] learningRatesTraining = new float[epoch];
float[] learningRatesValidation = new float[epoch];
float[] crossEntropyTraining = new float[epoch];
float[] crossEntropyValidation = new float[epoch];
float baseLearningRate = 0.01f;
var options = new ImageClassificationTrainer.Options()
{
FeatureColumnName = "Image",
Expand All @@ -1531,8 +1538,64 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
Epoch = epoch,
BatchSize = 10,
LearningRate = 0.01f,
MetricsCallback = (metric) => Console.WriteLine(metric),
LearningRate = baseLearningRate,
MetricsCallback = (metric) =>
{
// Check that learning rates in metrics from both the training and validation phases decay and are sensible
if (metric.Train != null)
{
if (metric.Train.Epoch > 1)
{
Assert.InRange(metric.Train.LearningRate, 0, baseLearningRate);
Assert.True(metric.Train.CrossEntropy > 0);
}
// Save learning rate and cross entropy values in training phase
if (metric.Train.DatasetUsed == ImageClassificationTrainer.ImageClassificationMetrics.Dataset.Train)
{
learningRatesTraining[metric.Train.Epoch] = metric.Train.LearningRate;
crossEntropyTraining[metric.Train.Epoch] = metric.Train.CrossEntropy;
// Check that learning rates over each epoch-per-decay are decreasing, and that cross entropy is also decreasing
if (metric.Train.Epoch > 1)
{
// Testing PolynomialLRDecay training
if (endLearningRate != 0.0)
{
Assert.True(learningRatesTraining[metric.Train.Epoch - numEpochsPerDecay] > learningRatesTraining[metric.Train.Epoch]
|| learningRatesTraining[metric.Train.Epoch] == endLearningRate);
}
// Testing ExponentialLRDecay training
else
{
Assert.True(learningRatesTraining[metric.Train.Epoch - numEpochsPerDecay] > learningRatesTraining[metric.Train.Epoch]);
}
Assert.True(crossEntropyTraining[metric.Train.Epoch - numEpochsPerDecay] > crossEntropyTraining[metric.Train.Epoch]);
}
}
// Save learning rate and cross entropy values in validation phase
else
{
learningRatesValidation[metric.Train.Epoch] = metric.Train.LearningRate;
crossEntropyValidation[metric.Train.Epoch] = metric.Train.CrossEntropy;
// Check that learning rates over each epoch-per-decay are decreasing, and that cross entropy is also decreasing
if (metric.Train.Epoch > 1)
{
// Testing PolynomialLRDecay validation
if (endLearningRate != 0.0)
{
Assert.True(learningRatesValidation[metric.Train.Epoch - numEpochsPerDecay] > learningRatesValidation[metric.Train.Epoch]
|| learningRatesValidation[metric.Train.Epoch] == endLearningRate);
}
// Testing ExponentialLRDecay validation
else
{
Assert.True(learningRatesValidation[metric.Train.Epoch - numEpochsPerDecay] > learningRatesValidation[metric.Train.Epoch]);
}
Assert.True(crossEntropyValidation[metric.Train.Epoch - numEpochsPerDecay] > crossEntropyValidation[metric.Train.Epoch]);
}
}
}
Console.WriteLine(metric);
},
ValidationSet = validationSet,
WorkspacePath = workspacePath,
TrainSetBottleneckCachedValuesFileName = trainSetBottleneckCachedValuesFileName,
Expand Down