@@ -66,7 +66,8 @@ namespace Microsoft.ML.Vision
66
66
///
67
67
/// ### Training Algorithm Details
68
68
/// Trains a Deep Neural Network(DNN) by leveraging an existing pre-trained model such as Resnet50 for the purpose
69
- /// of classifying images.
69
+ /// of classifying images. The technique was inspired from [TensorFlow's retrain image classification tutorial]
70
+ /// (https://www.tensorflow.org/hub/tutorials/image_retraining)
70
71
/// ]]>
71
72
/// </format>
72
73
/// </remarks>
@@ -392,7 +393,7 @@ public sealed class Options : TrainerInputBaseWithLabel
392
393
public Action < ImageClassificationMetrics > MetricsCallback = null ;
393
394
394
395
/// <summary>
395
- /// Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory
396
+ /// Indicates the path where the image bottleneck cache files and trained model are saved, default is a new temporary directory
396
397
/// </summary>
397
398
[ Argument ( ArgumentType . AtMostOnce , HelpText = "Indicates the path where the models get downloaded to and cache files saved, default is a new temporary directory." , SortOrder = 15 ) ]
398
399
public string WorkspacePath = null ;
@@ -591,6 +592,7 @@ private void InitializeTrainingGraph(IDataView input)
591
592
_classCount = labelCount == 1 ? 2 : ( int ) labelCount ;
592
593
var imageSize = ImagePreprocessingSize [ _options . Arch ] ;
593
594
_session = LoadTensorFlowSessionFromMetaGraph ( Host , _options . Arch ) . Session ;
595
+ _session . graph . as_default ( ) ;
594
596
( _jpegData , _resizedImage ) = AddJpegDecoding ( imageSize . Item1 , imageSize . Item2 , 3 ) ;
595
597
_jpegDataTensorName = _jpegData . name ;
596
598
_resizedImageTensorName = _resizedImage . name ;
@@ -631,6 +633,14 @@ private protected override MulticlassPredictionTransformer<ImageClassificationMo
631
633
632
634
private protected override ImageClassificationModelParameters TrainModelCore ( TrainContext trainContext )
633
635
{
636
+ // Workspace directory is cleaned after training run. However, the pipeline can be re-used by calling
637
+ // fit() again after transform(), in which case we must ensure workspace directory exists. This scenario
638
+ // is typical in the case of cross-validation.
639
+ if ( ! Directory . Exists ( _options . WorkspacePath ) )
640
+ {
641
+ Directory . CreateDirectory ( _options . WorkspacePath ) ;
642
+ }
643
+
634
644
InitializeTrainingGraph ( trainContext . TrainingSet . Data ) ;
635
645
CheckTrainingParameters ( _options ) ;
636
646
var validationSet = trainContext . ValidationSet ? . Data ?? _options . ValidationSet ;
@@ -1301,7 +1311,7 @@ private void VariableSummaries(RefVariable var)
1301
1311
var optimizer = useLearningRateScheduler ? tf . train . GradientDescentOptimizer ( _learningRateInput ) :
1302
1312
tf . train . GradientDescentOptimizer ( learningRate ) ;
1303
1313
1304
- _trainStep = optimizer . minimize ( crossEntropyMean ) ;
1314
+ _trainStep = optimizer . minimize ( crossEntropyMean ) ;
1305
1315
} ) ;
1306
1316
1307
1317
return ( _trainStep , crossEntropyMean , _labelTensor , _softMaxTensor ) ;
@@ -1341,6 +1351,11 @@ private void Dispose(bool disposing)
1341
1351
{
1342
1352
_session . close ( ) ;
1343
1353
}
1354
+
1355
+ if ( _session != null && _session . graph != IntPtr . Zero )
1356
+ {
1357
+ _session . graph . Dispose ( ) ;
1358
+ }
1344
1359
}
1345
1360
1346
1361
/// <summary>
@@ -1527,6 +1542,11 @@ private void Dispose(bool disposing)
1527
1542
{
1528
1543
_session . close ( ) ;
1529
1544
}
1545
+
1546
+ if ( _session != null && _session . graph != IntPtr . Zero )
1547
+ {
1548
+ _session . graph . Dispose ( ) ;
1549
+ }
1530
1550
}
1531
1551
}
1532
1552
}
0 commit comments