Skip to content

Commit 0fcf7ae

Browse files
Issue 3234: use model schema type instead of class definition schema (#5228)
* fix issue 3234, use model schema type instead of class definition schema * refine comments * change test to cover both cases
1 parent 4f90006 commit 0fcf7ae

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

src/Microsoft.ML.Data/DataView/DataViewConstructionUtils.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,13 @@ internal static SchemaDefinition GetSchemaDefinition<TRow>(IHostEnvironment env,
7171
var schemaDefinitionCol = schemaDefinition.FirstOrDefault(c => c.ColumnName == name);
7272
if (schemaDefinitionCol == null)
7373
throw env.Except($"Type should contain a member named {name}");
74+
75+
//Always use column type from model as this type can be more specific.
76+
//This can be a corner case:
77+
//For example, we can load a model whose schema contains Vector<Single, 38>
78+
//and define this field in input class as float[] without specific array length.
79+
schemaDefinitionCol.ColumnType = col.Type;
80+
7481
var annotations = col.Annotations;
7582
if (annotations != null)
7683
{

test/Microsoft.ML.Functional.Tests/ModelFiles.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,7 @@ public class ModelInput
217217
public string[] CategoricalFeatures;
218218
public float[] NumericalFeatures;
219219
#pragma warning restore SA1401
220+
public float Label;
220221
}
221222

222223
public class ModelOutput
@@ -233,14 +234,29 @@ public void LoadModelWithOptionalColumnTransform()
233234
SchemaDefinition inputSchemaDefinition = SchemaDefinition.Create(typeof(ModelInput));
234235
inputSchemaDefinition[nameof(ModelInput.CategoricalFeatures)].ColumnType = new VectorDataViewType(TextDataViewType.Instance, 5);
235236
inputSchemaDefinition[nameof(ModelInput.NumericalFeatures)].ColumnType = new VectorDataViewType(NumberDataViewType.Single, 3);
237+
236238
var mlContext = new MLContext(1);
237239
ITransformer trainedModel;
238240
DataViewSchema dataViewSchema;
239241
trainedModel = mlContext.Model.Load(TestCommon.GetDataPath(DataDir, "backcompat", "modelwithoptionalcolumntransform.zip"), out dataViewSchema);
242+
243+
var modelInput = new ModelInput()
244+
{
245+
CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" },
246+
NumericalFeatures = new float[] { 1, 1, 1 },
247+
Label = 1
248+
};
249+
250+
// test create prediction engine with user defined schema
240251
var model = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchemaDefinition: inputSchemaDefinition);
241-
var prediction = model.Predict(new ModelInput() { CategoricalFeatures = new[] { "ABC", "ABC", "ABC", "ABC", "ABC" }, NumericalFeatures = new float [] { 1, 1, 1 } });
252+
var prediction = model.Predict(modelInput);
253+
254+
// test create prediction engine with schema loaded from model
255+
var model2 = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel, inputSchema: dataViewSchema);
256+
var prediction2 = model2.Predict(modelInput);
242257

243258
Assert.Equal(1, prediction.Score[0]);
259+
Assert.Equal(1, prediction2.Score[0]);
244260
}
245261

246262
[Fact]

0 commit comments

Comments
 (0)