Skip to content

Commit 1480fda

Browse files
authored
Recreate workspace directory when fit() is called, fix documentation, free up unmanaged memory. (dotnet#4438)
* Recreate workspace directory when fit() is called, fix documentation, free up unmanaged memory. * PR feedback.
1 parent 948521e commit 1480fda

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

src/Microsoft.ML.Vision/ImageClassificationTrainer.cs

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ namespace Microsoft.ML.Vision
6666
///
6767
/// ### Training Algorithm Details
6868
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose
69-
/// of classifying images.
69+
/// of classifying images. The technique was inspired from [TensorFlow's retrain image classification tutorial]
70+
/// (https://www.tensorflow.org/hub/tutorials/image_retraining)
7071
/// ]]>
7172
/// </format>
7273
/// </remarks>
@@ -392,7 +393,7 @@ public sealed class Options : TrainerInputBaseWithLabel
392393
public Action<ImageClassificationMetrics> MetricsCallback = null;
393394

394395
/// <summary>
395-
/// Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory
396+
/// Indicates the path where the image bottleneck cache files and trained model are saved, default is a new temporary directory
396397
/// </summary>
397398
[Argument(ArgumentType.AtMostOnce, HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory.", SortOrder = 15)]
398399
public string WorkspacePath = null;
@@ -591,6 +592,7 @@ private void InitializeTrainingGraph(IDataView input)
591592
_classCount = labelCount == 1 ? 2 : (int)labelCount;
592593
var imageSize = ImagePreprocessingSize[_options.Arch];
593594
_session = LoadTensorFlowSessionFromMetaGraph(Host, _options.Arch).Session;
595+
_session.graph.as_default();
594596
(_jpegData, _resizedImage) = AddJpegDecoding(imageSize.Item1, imageSize.Item2, 3);
595597
_jpegDataTensorName = _jpegData.name;
596598
_resizedImageTensorName = _resizedImage.name;
@@ -631,6 +633,14 @@ private protected override MulticlassPredictionTransformer<ImageClassificationMo
631633

632634
private protected override ImageClassificationModelParameters TrainModelCore(TrainContext trainContext)
633635
{
636+
// Workspace directory is cleaned after training run. However, the pipeline can be re-used by calling
637+
// fit() again after transform(), in which case we must ensure workspace directory exists. This scenario
638+
// is typical in the case of cross-validation.
639+
if (!Directory.Exists(_options.WorkspacePath))
640+
{
641+
Directory.CreateDirectory(_options.WorkspacePath);
642+
}
643+
634644
InitializeTrainingGraph(trainContext.TrainingSet.Data);
635645
CheckTrainingParameters(_options);
636646
var validationSet = trainContext.ValidationSet?.Data ?? _options.ValidationSet;
@@ -1301,7 +1311,7 @@ private void VariableSummaries(RefVariable var)
13011311
var optimizer = useLearningRateScheduler ? tf.train.GradientDescentOptimizer(_learningRateInput) :
13021312
tf.train.GradientDescentOptimizer(learningRate);
13031313

1304-
_trainStep = optimizer.minimize(crossEntropyMean);
1314+
_trainStep = optimizer.minimize(crossEntropyMean);
13051315
});
13061316

13071317
return (_trainStep, crossEntropyMean, _labelTensor, _softMaxTensor);
@@ -1341,6 +1351,11 @@ private void Dispose(bool disposing)
13411351
{
13421352
_session.close();
13431353
}
1354+
1355+
if (_session != null && _session.graph != IntPtr.Zero)
1356+
{
1357+
_session.graph.Dispose();
1358+
}
13441359
}
13451360

13461361
/// <summary>
@@ -1527,6 +1542,11 @@ private void Dispose(bool disposing)
15271542
{
15281543
_session.close();
15291544
}
1545+
1546+
if (_session != null && _session.graph != IntPtr.Zero)
1547+
{
1548+
_session.graph.Dispose();
1549+
}
15301550
}
15311551
}
15321552
}

0 commit comments

Comments
 (0)