@@ -217,6 +217,7 @@ public class ModelInput
217
217
public string [ ] CategoricalFeatures ;
218
218
public float [ ] NumericalFeatures ;
219
219
#pragma warning restore SA1401
220
+ public float Label ;
220
221
}
221
222
222
223
public class ModelOutput
@@ -233,14 +234,29 @@ public void LoadModelWithOptionalColumnTransform()
233
234
SchemaDefinition inputSchemaDefinition = SchemaDefinition . Create ( typeof ( ModelInput ) ) ;
234
235
inputSchemaDefinition [ nameof ( ModelInput . CategoricalFeatures ) ] . ColumnType = new VectorDataViewType ( TextDataViewType . Instance , 5 ) ;
235
236
inputSchemaDefinition [ nameof ( ModelInput . NumericalFeatures ) ] . ColumnType = new VectorDataViewType ( NumberDataViewType . Single , 3 ) ;
237
+
236
238
var mlContext = new MLContext ( 1 ) ;
237
239
ITransformer trainedModel ;
238
240
DataViewSchema dataViewSchema ;
239
241
trainedModel = mlContext . Model . Load ( TestCommon . GetDataPath ( DataDir , "backcompat" , "modelwithoptionalcolumntransform.zip" ) , out dataViewSchema ) ;
242
+
243
+ var modelInput = new ModelInput ( )
244
+ {
245
+ CategoricalFeatures = new [ ] { "ABC" , "ABC" , "ABC" , "ABC" , "ABC" } ,
246
+ NumericalFeatures = new float [ ] { 1 , 1 , 1 } ,
247
+ Label = 1
248
+ } ;
249
+
250
+ // test create prediction engine with user defined schema
240
251
var model = mlContext . Model . CreatePredictionEngine < ModelInput , ModelOutput > ( trainedModel , inputSchemaDefinition : inputSchemaDefinition ) ;
241
- var prediction = model . Predict ( new ModelInput ( ) { CategoricalFeatures = new [ ] { "ABC" , "ABC" , "ABC" , "ABC" , "ABC" } , NumericalFeatures = new float [ ] { 1 , 1 , 1 } } ) ;
252
+ var prediction = model . Predict ( modelInput ) ;
253
+
254
+ // test create prediction engine with schema loaded from model
255
+ var model2 = mlContext . Model . CreatePredictionEngine < ModelInput , ModelOutput > ( trainedModel , inputSchema : dataViewSchema ) ;
256
+ var prediction2 = model2 . Predict ( modelInput ) ;
242
257
243
258
Assert . Equal ( 1 , prediction . Score [ 0 ] ) ;
259
+ Assert . Equal ( 1 , prediction2 . Score [ 0 ] ) ;
244
260
}
245
261
246
262
[ Fact ]
0 commit comments