Skip to content

Commit a6aaaf8

Browse files
wschincodemzs
authored andcommitted
Allow user to overwrite unknown shapes loaded from ONNX model (#3963)
* Allow user to specify ONNX shape Add a command line test Remove unused code * Address comments
1 parent 7f50e71 commit a6aaaf8

File tree

4 files changed

+542
-52
lines changed

4 files changed

+542
-52
lines changed

src/Microsoft.ML.OnnxTransformer/OnnxCatalog.cs

+80-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6+
using System.Collections.Generic;
67
using Microsoft.ML.Data;
78
using Microsoft.ML.Transforms;
89
using Microsoft.ML.Transforms.Onnx;
@@ -36,6 +37,34 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
3637
bool fallbackToCpu = false)
3738
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), modelFile, gpuDeviceId, fallbackToCpu);
3839

40+
/// <summary>
41+
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the input column.
42+
/// Input/output columns are determined based on the input/output columns of the provided ONNX model.
43+
/// </summary>
44+
/// <remarks>
45+
/// The name/type of input columns must exactly match name/type of the ONNX model inputs.
46+
/// The name/type of the produced output columns will match name/type of the ONNX model outputs.
47+
/// </remarks>
48+
/// <param name="catalog">The transform's catalog.</param>
49+
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
50+
/// <param name="shapeDictionary">ONNX shape should be used to over those loaded from <paramref name="modelFile"/>.</param>
51+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
52+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
53+
/// <example>
54+
/// <format type="text/markdown">
55+
/// <![CDATA[
56+
/// [!code-csharp[ApplyOnnxModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ApplyOnnxModel.cs)]
57+
/// ]]>
58+
/// </format>
59+
/// </example>
60+
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
61+
string modelFile,
62+
IDictionary<string, int[]> shapeDictionary,
63+
int? gpuDeviceId = null,
64+
bool fallbackToCpu = false)
65+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), modelFile, gpuDeviceId, fallbackToCpu,
66+
shapeDictionary: shapeDictionary);
67+
3968
/// <summary>
4069
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the <paramref name="inputColumnName"/> column.
4170
/// </summary>
@@ -58,7 +87,53 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
5887
string modelFile,
5988
int? gpuDeviceId = null,
6089
bool fallbackToCpu = false)
61-
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName }, modelFile, gpuDeviceId, fallbackToCpu);
90+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName },
91+
modelFile, gpuDeviceId, fallbackToCpu);
92+
93+
/// <summary>
94+
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the <paramref name="inputColumnName"/> column.
95+
/// </summary>
96+
/// <param name="catalog">The transform's catalog.</param>
97+
/// <param name="outputColumnName">The output column resulting from the transformation.</param>
98+
/// <param name="inputColumnName">The input column.</param>
99+
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
100+
/// <param name="shapeDictionary">ONNX shape should be used to over those loaded from <paramref name="modelFile"/>.</param>
101+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
102+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
103+
/// <example>
104+
/// <format type="text/markdown">
105+
/// <![CDATA[
106+
/// [!code-csharp[ApplyOnnxModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/ApplyONNXModelWithInMemoryImages.cs)]
107+
/// ]]>
108+
/// </format>
109+
/// </example>
110+
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
111+
string outputColumnName,
112+
string inputColumnName,
113+
string modelFile,
114+
IDictionary<string, int[]> shapeDictionary,
115+
int? gpuDeviceId = null,
116+
bool fallbackToCpu = false)
117+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName },
118+
modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary: shapeDictionary);
119+
120+
/// <summary>
121+
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the <paramref name="inputColumnNames"/> columns.
122+
/// </summary>
123+
/// <param name="catalog">The transform's catalog.</param>
124+
/// <param name="outputColumnNames">The output columns resulting from the transformation.</param>
125+
/// <param name="inputColumnNames">The input columns.</param>
126+
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
127+
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
128+
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
129+
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
130+
string[] outputColumnNames,
131+
string[] inputColumnNames,
132+
string modelFile,
133+
int? gpuDeviceId = null,
134+
bool fallbackToCpu = false)
135+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames,
136+
modelFile, gpuDeviceId, fallbackToCpu);
62137

63138
/// <summary>
64139
/// Create a <see cref="OnnxScoringEstimator"/>, which applies a pre-trained Onnx model to the <paramref name="inputColumnNames"/> columns.
@@ -67,15 +142,18 @@ public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog
67142
/// <param name="outputColumnNames">The output columns resulting from the transformation.</param>
68143
/// <param name="inputColumnNames">The input columns.</param>
69144
/// <param name="modelFile">The path of the file containing the ONNX model.</param>
145+
/// <param name="shapeDictionary">ONNX shape should be used to over those loaded from <paramref name="modelFile"/>.</param>
70146
/// <param name="gpuDeviceId">Optional GPU device ID to run execution on, <see langword="null" /> to run on CPU.</param>
71147
/// <param name="fallbackToCpu">If GPU error, raise exception or fallback to CPU.</param>
72148
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog,
73149
string[] outputColumnNames,
74150
string[] inputColumnNames,
75151
string modelFile,
152+
IDictionary<string, int[]> shapeDictionary,
76153
int? gpuDeviceId = null,
77154
bool fallbackToCpu = false)
78-
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu);
155+
=> new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames,
156+
modelFile, gpuDeviceId, fallbackToCpu, shapeDictionary: shapeDictionary);
79157

80158
/// <summary>
81159
/// Create <see cref="DnnImageFeaturizerEstimator"/>, which applies one of the pre-trained DNN models in

0 commit comments

Comments
 (0)