Skip to content

Commit de250ba

Browse files
authored
Add example for applying ONNX model to in-memory images (#3851)
* Add example for applying ONNX model to in-memory images * Add expected outputs
1 parent d82cd7c commit de250ba

File tree

4 files changed

+188
-0
lines changed

4 files changed

+188
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
using System;
2+
using System.Drawing;
3+
using System.Linq;
4+
using Microsoft.ML;
5+
using Microsoft.ML.Data;
6+
using Microsoft.ML.Transforms.Image;
7+
8+
namespace Samples.Dynamic
{
    /// <summary>
    /// Sample that scores images held in memory (as <see cref="Bitmap"/> objects)
    /// with an ONNX model through an ML.NET pipeline.
    /// </summary>
    public static class ApplyOnnxModelWithInMemoryImages
    {
        /// <summary>
        /// Example of applying ONNX transform on in-memory images. Builds two
        /// solid-color images, runs them through pixel extraction followed by
        /// the ONNX model, and prints the first and last class probability
        /// for each image.
        /// </summary>
        public static void Example()
        {
            // Download the SqueezeNet image model from ONNX model zoo, version 1.2
            // https://github.com/onnx/models/tree/master/squeezenet or use
            // Microsoft.ML.Onnx.TestModels nuget.
            // It's a multiclass classifier. It consumes an input "data_0" and produces
            // an output "softmaxout_1".
            // The path is built with Path.Combine so that the sample also works on
            // platforms whose directory separator is not '\'.
            var modelPath = System.IO.Path.Combine("squeezenet", "00000001", "model.onnx");

            // Create ML pipeline to score the data using OnnxScoringEstimator
            var mlContext = new MLContext();

            // Create in-memory data points. Its Image/Scores field is the input/output of the used ONNX model.
            var dataPoints = new ImageDataPoint[]
            {
                new ImageDataPoint(Color.Red),
                new ImageDataPoint(Color.Green)
            };

            // Convert the data points to IDataView, the general data type used in ML.NET.
            var dataView = mlContext.Data.LoadFromEnumerable(dataPoints);

            // Create a ML.NET pipeline which contains two steps. First, ExtractPixels is used to convert the 224x224 image to a 3x224x224 float tensor.
            // Then the float tensor is fed into an ONNX model with an input called "data_0" and an output called "softmaxout_1". Note that "data_0" and
            // "softmaxout_1" are model input and output names stored in the used ONNX model file. Users may need to inspect their own models to
            // get the right input and output column names.
            var pipeline = mlContext.Transforms.ExtractPixels("data_0", "Image") // Map column "Image" to column "data_0"
                .Append(mlContext.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelPath)); // Map column "data_0" to column "softmaxout_1"
            var model = pipeline.Fit(dataView);
            var onnx = model.Transform(dataView);

            // Convert IDataView back to IEnumerable<ImageDataPoint> so that user can inspect the output, column "softmaxout_1", of the ONNX transform.
            // Note that column "softmaxout_1" would be stored in ImageDataPoint.Scores because the added attribute [ColumnName("softmaxout_1")]
            // tells that ImageDataPoint.Scores is equivalent to column "softmaxout_1".
            var transformedDataPoints = mlContext.Data.CreateEnumerable<ImageDataPoint>(onnx, false).ToList();

            // The scores are probabilities of all possible classes, so they should all be positive.
            foreach (var dataPoint in transformedDataPoints)
            {
                var firstClassProb = dataPoint.Scores.First();
                var lastClassProb = dataPoint.Scores.Last();
                Console.WriteLine($"The probability of being the first class is {firstClassProb * 100}%.");
                Console.WriteLine($"The probability of being the last class is {lastClassProb * 100}%.");
            }

            // Expected output:
            //  The probability of being the first class is 0.002542659%.
            //  The probability of being the last class is 0.0292684%.
            //  The probability of being the first class is 0.02258059%.
            //  The probability of being the last class is 0.394428%.
        }

        /// <summary>
        /// This class is used in <see cref="Example"/> to describe data points
        /// which will be consumed by the ML.NET pipeline.
        /// </summary>
        private class ImageDataPoint
        {
            // Height of Image, in pixels.
            private const int Height = 224;

            // Width of Image, in pixels.
            private const int Width = 224;

            // Image will be consumed by ONNX image multiclass classification model.
            [ImageType(Height, Width)]
            public Bitmap Image { get; set; }

            // Expected output of ONNX model. It contains probabilities of all classes.
            // Note that the ColumnName below should match the output name in the used
            // ONNX model file.
            [ColumnName("softmaxout_1")]
            public float[] Scores { get; set; }

            /// <summary>
            /// Creates a data point without an image.
            /// </summary>
            public ImageDataPoint()
            {
                Image = null;
            }

            /// <summary>
            /// Creates a data point whose image has every pixel set to <paramref name="color"/>.
            /// </summary>
            /// <param name="color">Color written to every pixel of the image.</param>
            public ImageDataPoint(Color color)
            {
                Image = new Bitmap(Width, Height);
                for (int i = 0; i < Width; ++i)
                {
                    for (int j = 0; j < Height; ++j)
                    {
                        Image.SetPixel(i, j, color);
                    }
                }
            }
        }
    }
}

src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs

+1
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ internal static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog,
118118
/// <format type="text/markdown">
119119
/// <![CDATA[
120120
/// [!code-csharp[ExtractPixels](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ImageAnalytics/ExtractPixels.cs)]
121+
/// [!code-csharp[ApplyOnnxModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ApplyONNXModelWithInMemoryImages.cs)]
121122
/// ]]></format>
122123
/// </example>
123124
public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog catalog,

src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs

+7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
4545
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
4646
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
4747
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
48+
/// <example>
49+
/// <format type="text/markdown">
50+
/// <![CDATA[
51+
/// [!code-csharp[ApplyOnnxModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ApplyONNXModelWithInMemoryImages.cs)]
52+
/// ]]>
53+
/// </format>
54+
/// </example>
4855
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
4956
string outputColumnName,
5057
string inputColumnName,

test/Microsoft.ML.OnnxTransformerTest/OnnxTransformTests.cs

+82
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.Collections.Generic;
7+
using System.Drawing;
78
using System.IO;
89
using System.Linq;
910
using Microsoft.ML;
@@ -14,6 +15,7 @@
1415
using Microsoft.ML.StaticPipe;
1516
using Microsoft.ML.TestFramework.Attributes;
1617
using Microsoft.ML.Tools;
18+
using Microsoft.ML.Transforms.Image;
1719
using Microsoft.ML.Transforms.StaticPipe;
1820
using Xunit;
1921
using Xunit.Abstractions;
@@ -316,5 +318,85 @@ public void TestUnknownDimensions()
316318
Assert.Equal(0, predictions[1].argmax[0]);
317319
Assert.Equal(2, predictions[2].argmax[0]);
318320
}
321+
322+
/// <summary>
/// Data point type fed to the ML.NET pipeline in <see cref="OnnxModelInMemoryImage"/>:
/// an input image plus the score vector produced for it.
/// </summary>
private class ImageDataPoint
{
    /// <summary>
    /// Number of pixel rows in <see cref="Image"/>.
    /// </summary>
    private const int ImageHeight = 224;

    /// <summary>
    /// Number of pixel columns in <see cref="Image"/>.
    /// </summary>
    private const int ImageWidth = 224;

    /// <summary>
    /// Input image for the ONNX image multiclass classification model.
    /// </summary>
    [ImageType(ImageHeight, ImageWidth)]
    public Bitmap Image { get; set; }

    /// <summary>
    /// Model output bound to column "softmaxout_1"; holds the probabilities of all classes.
    /// </summary>
    [ColumnName("softmaxout_1")]
    public float[] Scores { get; set; }

    /// <summary>
    /// Creates a data point that carries no image.
    /// </summary>
    public ImageDataPoint()
    {
        Image = null;
    }

    /// <summary>
    /// Creates a data point whose image is uniformly filled with <paramref name="color"/>.
    /// </summary>
    /// <param name="color">Color assigned to every pixel.</param>
    public ImageDataPoint(Color color)
    {
        var bitmap = new Bitmap(ImageWidth, ImageHeight);

        // Paint row by row; every pixel receives the same color.
        for (int y = 0; y < ImageHeight; ++y)
        {
            for (int x = 0; x < ImageWidth; ++x)
            {
                bitmap.SetPixel(x, y, color);
            }
        }

        Image = bitmap;
    }
}
362+
363+
/// <summary>
/// Test applying ONNX transform on in-memory images. Two solid-color images are
/// scored, and the test checks that each image produces one row whose score
/// vector is non-empty and strictly positive.
/// </summary>
[OnnxFact]
public void OnnxModelInMemoryImage()
{
    // Path of ONNX model. It's a multiclass classifier. It consumes an input "data_0" and produces an output "softmaxout_1".
    var modelFile = "squeezenet/00000001/model.onnx";

    // Create in-memory data points. Its Image/Scores field is the input/output of the used ONNX model.
    var dataPoints = new ImageDataPoint[]
    {
        new ImageDataPoint(Color.Red),
        new ImageDataPoint(Color.Green)
    };

    // Convert the data points to IDataView, the general data type used in ML.NET.
    var dataView = ML.Data.LoadFromEnumerable(dataPoints);

    // Create a ML.NET pipeline which contains two steps. First, ExtractPixels is used to convert the 224x224 image to a 3x224x224 float tensor.
    // Then the float tensor is fed into an ONNX model with an input called "data_0" and an output called "softmaxout_1". Note that "data_0" and
    // "softmaxout_1" are model input and output names stored in the used ONNX model file. Users may need to inspect their own models to
    // get the right input and output column names.
    var pipeline = ML.Transforms.ExtractPixels("data_0", "Image")                   // Map column "Image" to column "data_0"
        .Append(ML.Transforms.ApplyOnnxModel("softmaxout_1", "data_0", modelFile)); // Map column "data_0" to column "softmaxout_1"
    var model = pipeline.Fit(dataView);
    var onnx = model.Transform(dataView);

    // Convert IDataView back to IEnumerable<ImageDataPoint> so that user can inspect the output, column "softmaxout_1", of the ONNX transform.
    // Note that column "softmaxout_1" would be stored in ImageDataPoint.Scores because the added attribute [ColumnName("softmaxout_1")]
    // tells that ImageDataPoint.Scores is equivalent to column "softmaxout_1".
    var transformedDataPoints = ML.Data.CreateEnumerable<ImageDataPoint>(onnx, false).ToList();

    // Each input image must come back as exactly one row; without this check the
    // loop below would pass vacuously if the transform produced no rows.
    Assert.Equal(dataPoints.Length, transformedDataPoints.Count);

    // The scores are probabilities of all possible classes, so they should all be positive.
    foreach (var dataPoint in transformedDataPoints)
    {
        // Guard against a missing or empty score vector, which would also skip the inner loop.
        Assert.NotNull(dataPoint.Scores);
        Assert.NotEmpty(dataPoint.Scores);

        foreach (var score in dataPoint.Scores)
        {
            Assert.True(score > 0);
        }
    }
}
319401
}
320402
}

0 commit comments

Comments
 (0)