3
3
open System
4
4
open System.IO
5
5
open Microsoft.ML
6
- open Microsoft.ML .Data
7
- open MulticlassClassification_Iris.DataStructures
8
- open Common
9
6
open MulticlassClassification_Iris
7
+ open MulticlassClassification_Iris.DataStructures
10
8
11
9
let appPath = Path.GetDirectoryName( Environment.GetCommandLineArgs().[ 0 ])
12
10
@@ -21,85 +19,80 @@ let modelPath = sprintf @"%s/IrisClassificationModel.zip" baseModelsPath
21
19
let buildTrainEvaluateAndSaveModel ( mlContext : MLContext ) =
22
20
23
21
// STEP 1: Common data loading configuration
24
- let textLoader =
25
- mlContext.Data.CreateTextReader(
26
- separatorChar = '\t' ,
27
- hasHeader = true ,
28
- columns =
29
- [|
30
- TextLoader.Column( " Label" , Nullable DataKind.R4, 0 )
31
- TextLoader.Column( " SepalLength" , Nullable DataKind.R4, 1 )
32
- TextLoader.Column( " SepalWidth" , Nullable DataKind.R4, 2 )
33
- TextLoader.Column( " PetalLength" , Nullable DataKind.R4, 3 )
34
- TextLoader.Column( " PetalWidth" , Nullable DataKind.R4, 4 )
35
- |]
36
- )
37
-
38
- let trainingDataView = textLoader.Read trainDataPath
39
- let testDataView = textLoader.Read testDataPath
40
-
41
- // STEP 2: Common data process configuration with pipeline data transformations
42
- let dataProcessPipeline =
43
- mlContext.Transforms.Concatenate( " Features" , [| " SepalLength" ; " SepalWidth" ; " PetalLength" ; " PetalWidth" |])
44
- |> Common.ModelBuilder.appendCacheCheckpoint mlContext
22
+ let trainingDataView = mlContext.Data.ReadFromTextFile< IrisData>( trainDataPath, hasHeader = true )
23
+ let testDataView = mlContext.Data.ReadFromTextFile< IrisData>( testDataPath, hasHeader = true )
45
24
46
- // STEP 3: Set the training algorithm, then create and config the modelBuilder
25
+ // STEP 2: Common data process configuration with pipeline data transformations
26
+ let dataProcessPipeline =
27
+ mlContext.Transforms.Concatenate( " Features" , " SepalLength" ,
28
+ " SepalWidth" ,
29
+ " PetalLength" ,
30
+ " PetalWidth" )
31
+ .AppendCacheCheckpoint( mlContext)
32
+
33
+ // STEP 3: Set the training algorithm, then append the trainer to the pipeline
47
34
let trainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( labelColumn = " Label" , featureColumn = " Features" )
48
- let modelBuilder =
49
- Common.ModelBuilder.create mlContext dataProcessPipeline
50
- |> Common.ModelBuilder.addTrainer trainer
35
+ let trainingPipeline = dataProcessPipeline.Append( trainer)
36
+
37
+
38
+
51
39
52
40
// STEP 4: Train the model fitting to the DataSet
41
+
53
42
//Measure training time
54
43
let watch = System.Diagnostics.Stopwatch.StartNew()
55
44
56
45
printfn " =============== Training the model ==============="
57
- let trainedModel =
58
- modelBuilder
59
- |> Common.ModelBuilder.train trainingDataView
46
+ let trainedModel = trainingPipeline.Fit( trainingDataView)
60
47
61
48
//Stop measuring time
62
49
watch.Stop()
63
- let elapsedMs = watch.ElapsedMilliseconds
64
- printfn " ***** Training time: %f seconds *****" (( float) elapsedMs/ 1000.0 )
50
+ let elapsedMs = float watch.ElapsedMilliseconds
51
+ printfn " ***** Training time: %f seconds *****" ( elapsedMs/ 1000. )
52
+
65
53
66
54
// STEP 5: Evaluate the model and show accuracy stats
67
55
printfn " ===== Evaluating Model's accuracy with Test data ====="
68
- let metrics =
69
- ( trainedModel, modelBuilder)
70
- |> Common.ModelBuilder.evaluateMultiClassClassificationModel testDataView " Label" " Score"
56
+ let predictions = trainedModel.Transform( testDataView)
57
+ let metrics = mlContext.MulticlassClassification.Evaluate( predictions, " Label" , " Score" )
71
58
72
59
Common.ConsoleHelper.printMultiClassClassificationMetrics ( trainer.ToString()) metrics
73
60
74
61
// STEP 6: Save/persist the trained model to a .ZIP file
75
- ( trainedModel , modelBuilder )
76
- |> Common.ModelBuilder.saveModelAsFile modelPath
62
+ use fs = new FileStream ( modelPath , FileMode.Create , FileAccess.Write , FileShare.Write )
63
+ mlContext.Model.Save ( trainedModel , fs );
77
64
65
+ printfn " The model is saved to %s " modelPath
78
66
79
67
80
68
let testSomePredictions ( mlContext : MLContext ) =
81
-
82
69
//Test Classification Predictions with some hard-coded samples
83
- let modelScorer =
84
- Common.ModelScorer.create mlContext
85
- |> Common.ModelScorer.loadModelFromZipFile modelPath
86
-
87
- let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris1
88
- printfn " Actual: setosa. Predicted probability: setosa: %.4f " prediction.Score.[ 0 ]
89
- printfn " versicolor: %.4f " prediction.Score.[ 1 ]
90
- printfn " virginica: %.4f " prediction.Score.[ 2 ]
70
+ use stream = new FileStream( modelPath, FileMode.Open, FileAccess.Read, FileShare.Read)
71
+ let trainedModel = mlContext.Model.Load( stream);
72
+
73
+ // Create prediction engine related to the loaded trained model
74
+ let predEngine = trainedModel.CreatePredictionEngine< IrisData, IrisPrediction>( mlContext)
75
+
76
+ //Score sample 1
77
+ let resultprediction1 = predEngine.Predict( DataStructures.SampleIrisData.Iris1)
78
+
79
+ printfn " Actual: setosa. Predicted probability: setosa: %.4f " resultprediction1.Score.[ 0 ]
80
+ printfn " versicolor: %.4f " resultprediction1.Score.[ 1 ]
81
+ printfn " virginica: %.4f " resultprediction1.Score.[ 2 ]
91
82
printfn " "
92
83
93
- let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris2
94
- printfn " Actual: virginica. Predicted probability: setosa: %.4f " prediction.Score.[ 0 ]
95
- printfn " versicolor: %.4f " prediction.Score.[ 1 ]
96
- printfn " virginica: %.4f " prediction.Score.[ 2 ]
84
+ //Score sample 2
85
+ let resultprediction2 = predEngine.Predict( DataStructures.SampleIrisData.Iris2);
86
+ printfn " Actual: virginica. Predicted probability: setosa: %.4f " resultprediction2.Score.[ 0 ]
87
+ printfn " versicolor: %.4f " resultprediction2.Score.[ 1 ]
88
+ printfn " virginica: %.4f " resultprediction2.Score.[ 2 ]
97
89
printfn " "
98
90
99
- let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris3
100
- printfn " Actual: versicolor. Predicted probability: setosa: %.4f " prediction.Score.[ 0 ]
101
- printfn " versicolor: %.4f " prediction.Score.[ 1 ]
102
- printfn " virginica: %.4f " prediction.Score.[ 2 ]
91
+ //Score sample 3
92
+ let resultprediction2 = predEngine.Predict( DataStructures.SampleIrisData.Iris3);
93
+ printfn " Actual: versicolor. Predicted probability: setosa: %.4f " resultprediction2.Score.[ 0 ]
94
+ printfn " versicolor: %.4f " resultprediction2.Score.[ 1 ]
95
+ printfn " virginica: %.4f " resultprediction2.Score.[ 2 ]
103
96
printfn " "
104
97
105
98
@@ -117,6 +110,6 @@ let main argv =
117
110
testSomePredictions mlContext
118
111
119
112
printfn " =============== End of process, hit any key to finish ==============="
120
- ConsoleHelper.consolePressAnyKey ()
113
+ Console.ReadKey () |> ignore
121
114
122
115
0 // return an integer exit code
0 commit comments