Skip to content

Commit ae2ac0d

Browse files
committed
Added some edits to address Yael's comments
1 parent b1e5739 commit ae2ac0d

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/ImageClassification/ResnetV2101TransferLearningTrainTestSplit.cs

-2
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,6 @@ public static void Example()
3030
//Download the image set and unzip
3131
string finalImagesFolderName = DownloadImageSet(
3232
imagesDownloadFolderPath);
33-
// Use this for testing on big flower datasets:
34-
//string finalImagesFolderName = "flower_photos";
3533
string fullImagesetFolderPath = Path.Combine(
3634
imagesDownloadFolderPath, finalImagesFolderName);
3735

docs/samples/Microsoft.ML.Samples/Program.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ internal static void RunAll()
2424
}
2525
}
2626

27-
Console.WriteLine("Number of samples that ran without any exception: " + samples);
28-
27+
Console.WriteLine("Number of samples that ran without any exception: " + samples);
2928

3029
}
3130
}

src/Microsoft.ML.Dnn/ImageClassificationTransform.cs

+8-7
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,9 @@ private void CheckTrainingParameters(ImageClassificationEstimator.Options option
189189
return (jpegData, resizedImage);
190190
}
191191

192-
private static Tensor Encode(VBuffer<byte> buffer, int length)
192+
private static Tensor Encode(VBuffer<byte> buffer)
193193
{
194+
int length = buffer.Length;
194195
var size = c_api.TF_StringEncodedSize((UIntPtr)length);
195196
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8));
196197
//AllocationType = AllocationType.Tensorflow;
@@ -221,7 +222,7 @@ public ImageProcessor(ImageClassificationTransformer transformer)
221222

222223
public Tensor ProcessImage(VBuffer<byte> imgBuf)
223224
{
224-
var imageTensor = Encode(imgBuf, imgBuf.Length);
225+
var imageTensor = Encode(imgBuf);
225226
var processedTensor = _imagePreprocessingRunner.AddInput(imageTensor, 0).Run()[0];
226227
imageTensor.Dispose();
227228
return processedTensor;
@@ -1170,7 +1171,7 @@ internal sealed class Options : TransformInputBase
11701171
private readonly IHost _host;
11711172
private readonly Options _options;
11721173
private readonly DnnModel _dnnModel;
1173-
private readonly TF_DataType[] _tfInputTypes;
1174+
private readonly DataViewType[] _inputTypes;
11741175
private readonly DataViewType[] _outputTypes;
11751176
private ImageClassificationTransformer _transformer;
11761177

@@ -1179,7 +1180,7 @@ internal ImageClassificationEstimator(IHostEnvironment env, Options options, Dnn
11791180
_host = Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageClassificationEstimator));
11801181
_options = options;
11811182
_dnnModel = dnnModel;
1182-
_tfInputTypes = new[] { TF_DataType.TF_STRING };
1183+
_inputTypes = new[] { new VectorDataViewType(NumberDataViewType.Byte) };
11831184
_outputTypes = new[] { new VectorDataViewType(NumberDataViewType.Single), NumberDataViewType.UInt32.GetItemType() };
11841185
}
11851186

@@ -1206,9 +1207,9 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
12061207
var input = _options.InputColumns[i];
12071208
if (!inputSchema.TryFindColumn(input, out var col))
12081209
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
1209-
//var expectedType = DnnUtils.Tf2MlNetType(_tfInputTypes[i]);
1210-
//if (col.ItemType != expectedType)
1211-
// throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
1210+
var expectedType = _inputTypes[i];
1211+
if (!col.ItemType.Equals(expectedType))
1212+
throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, expectedType.ToString(), col.ItemType.ToString());
12121213
}
12131214
for (var i = 0; i < _options.OutputColumns.Length; i++)
12141215
{

0 commit comments

Comments
 (0)