Skip to content

Commit 53b263f

Browse files
mariuszwojcikCESARDELATORRE
authored andcommitted
#FS230: move F# MultiClassification_Iris sample to use direct ML.NET API. (dotnet#238)
* dotnet#202: migrate BikeSharingDemand F# sample to v0.9 * migrate Sentiment Analysis F# sample to v0.9 * dotnet#213: append CacheCheckpoint in F# samples to mitigate ML.NET bug #2099 * #FS230: move F# BikeSharingDemand example to use direct ML.NET API. * #FS230: move F# MultiClassification_Iris sample to use direct ML.NET API.
1 parent ee509d4 commit 53b263f

File tree

4 files changed

+64
-63
lines changed

4 files changed

+64
-63
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,24 @@
11
namespace MulticlassClassification_Iris
22

33
module DataStructures =
4+
open Microsoft.ML.Data
45

56
/// Holds information about Iris flower to be classified.
67
[<CLIMutable>]
78
type IrisData = {
9+
[<LoadColumn(0)>]
10+
Label : float32
11+
12+
[<LoadColumn(1)>]
813
SepalLength : float32
14+
15+
[<LoadColumn(2)>]
916
SepalWidth : float32
17+
18+
[<LoadColumn(3)>]
1019
PetalLength : float32
20+
21+
[<LoadColumn(4)>]
1122
PetalWidth : float32
1223
}
1324

@@ -18,8 +29,8 @@ module DataStructures =
1829
}
1930

2031

21-
module TestIrisData =
22-
let Iris1 = { SepalLength = 5.1f; SepalWidth = 3.3f; PetalLength = 1.6f; PetalWidth= 0.2f }
23-
let Iris2 = { SepalLength = 6.4f; SepalWidth = 3.1f; PetalLength = 5.5f; PetalWidth = 2.2f }
24-
let Iris3 = { SepalLength = 4.4f; SepalWidth = 3.1f; PetalLength = 2.5f; PetalWidth = 1.2f }
32+
module SampleIrisData =
33+
let Iris1 = { Label = 0.f; SepalLength = 5.1f; SepalWidth = 3.3f; PetalLength = 1.6f; PetalWidth= 0.2f }
34+
let Iris2 = { Label = 0.f; SepalLength = 6.4f; SepalWidth = 3.1f; PetalLength = 5.5f; PetalWidth = 2.2f }
35+
let Iris3 = { Label = 0.f; SepalLength = 4.4f; SepalWidth = 3.1f; PetalLength = 2.5f; PetalWidth = 1.2f }
2536

samples/fsharp/getting-started/MulticlassClassification_Iris/IrisClassification/IrisClassificationConsoleApp/MulticlassClassification_Iris.fsproj

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,6 @@
88
<ItemGroup>
99
<Compile Include="DataStructures\DataStructures.fs" />
1010
<Compile Include="..\..\..\..\common_v0.9\ConsoleHelper.fs" Link="Common\ConsoleHelper.fs" />
11-
<Compile Include="..\..\..\..\common_v0.9\ModelBuilder.fs" Link="Common\ModelBuilder.fs" />
12-
<Compile Include="..\..\..\..\common_v0.9\ModelScorer.fs" Link="Common\ModelScorer.fs" />
13-
<Compile Include="..\..\..\..\common_v0.9\Pipeline.fs" Link="Common\Pipeline.fs" />
1411
</ItemGroup>
1512

1613
<ItemGroup>

samples/fsharp/getting-started/MulticlassClassification_Iris/IrisClassification/IrisClassificationConsoleApp/Program.fs

Lines changed: 49 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,8 @@
33
open System
44
open System.IO
55
open Microsoft.ML
6-
open Microsoft.ML.Data
7-
open MulticlassClassification_Iris.DataStructures
8-
open Common
96
open MulticlassClassification_Iris
7+
open MulticlassClassification_Iris.DataStructures
108

119
let appPath = Path.GetDirectoryName(Environment.GetCommandLineArgs().[0])
1210

@@ -21,85 +19,80 @@ let modelPath = sprintf @"%s/IrisClassificationModel.zip" baseModelsPath
2119
let buildTrainEvaluateAndSaveModel (mlContext : MLContext) =
2220

2321
// STEP 1: Common data loading configuration
24-
let textLoader =
25-
mlContext.Data.CreateTextReader(
26-
separatorChar = '\t',
27-
hasHeader = true,
28-
columns =
29-
[|
30-
TextLoader.Column("Label", Nullable DataKind.R4, 0)
31-
TextLoader.Column("SepalLength", Nullable DataKind.R4, 1)
32-
TextLoader.Column("SepalWidth", Nullable DataKind.R4, 2)
33-
TextLoader.Column("PetalLength", Nullable DataKind.R4, 3)
34-
TextLoader.Column("PetalWidth", Nullable DataKind.R4, 4)
35-
|]
36-
)
37-
38-
let trainingDataView = textLoader.Read trainDataPath
39-
let testDataView = textLoader.Read testDataPath
40-
41-
// STEP 2: Common data process configuration with pipeline data transformations
42-
let dataProcessPipeline =
43-
mlContext.Transforms.Concatenate("Features", [| "SepalLength"; "SepalWidth"; "PetalLength"; "PetalWidth"|])
44-
|> Common.ModelBuilder.appendCacheCheckpoint mlContext
22+
let trainingDataView = mlContext.Data.ReadFromTextFile<IrisData>(trainDataPath, hasHeader = true)
23+
let testDataView = mlContext.Data.ReadFromTextFile<IrisData>(testDataPath, hasHeader = true)
4524

46-
// STEP 3: Set the training algorithm, then create and config the modelBuilder
25+
// STEP 2: Common data process configuration with pipeline data transformations
26+
let dataProcessPipeline =
27+
mlContext.Transforms.Concatenate("Features", "SepalLength",
28+
"SepalWidth",
29+
"PetalLength",
30+
"PetalWidth")
31+
.AppendCacheCheckpoint(mlContext)
32+
33+
// STEP 3: Set the training algorithm, then append the trainer to the pipeline
4734
let trainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn = "Label", featureColumn = "Features")
48-
let modelBuilder =
49-
Common.ModelBuilder.create mlContext dataProcessPipeline
50-
|> Common.ModelBuilder.addTrainer trainer
35+
let trainingPipeline = dataProcessPipeline.Append(trainer)
36+
37+
38+
5139

5240
// STEP 4: Train the model fitting to the DataSet
41+
5342
//Measure training time
5443
let watch = System.Diagnostics.Stopwatch.StartNew()
5544

5645
printfn "=============== Training the model ==============="
57-
let trainedModel =
58-
modelBuilder
59-
|> Common.ModelBuilder.train trainingDataView
46+
let trainedModel = trainingPipeline.Fit(trainingDataView)
6047

6148
//Stop measuring time
6249
watch.Stop()
63-
let elapsedMs = watch.ElapsedMilliseconds
64-
printfn "***** Training time: %f seconds *****" ((float)elapsedMs/1000.0)
50+
let elapsedMs = float watch.ElapsedMilliseconds
51+
printfn "***** Training time: %f seconds *****" (elapsedMs/1000.)
52+
6553

6654
// STEP 5: Evaluate the model and show accuracy stats
6755
printfn "===== Evaluating Model's accuracy with Test data ====="
68-
let metrics =
69-
(trainedModel, modelBuilder)
70-
|> Common.ModelBuilder.evaluateMultiClassClassificationModel testDataView "Label" "Score"
56+
let predictions = trainedModel.Transform(testDataView)
57+
let metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score")
7158

7259
Common.ConsoleHelper.printMultiClassClassificationMetrics (trainer.ToString()) metrics
7360

7461
// STEP 6: Save/persist the trained model to a .ZIP file
75-
(trainedModel, modelBuilder)
76-
|> Common.ModelBuilder.saveModelAsFile modelPath
62+
use fs = new FileStream(modelPath, FileMode.Create, FileAccess.Write, FileShare.Write)
63+
mlContext.Model.Save(trainedModel, fs);
7764

65+
printfn "The model is saved to %s" modelPath
7866

7967

8068
let testSomePredictions (mlContext : MLContext) =
81-
8269
//Test Classification Predictions with some hard-coded samples
83-
let modelScorer =
84-
Common.ModelScorer.create mlContext
85-
|> Common.ModelScorer.loadModelFromZipFile modelPath
86-
87-
let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris1
88-
printfn "Actual: setosa. Predicted probability: setosa: %.4f" prediction.Score.[0]
89-
printfn " versicolor: %.4f" prediction.Score.[1]
90-
printfn " virginica: %.4f" prediction.Score.[2]
70+
use stream = new FileStream(modelPath, FileMode.Open, FileAccess.Read, FileShare.Read)
71+
let trainedModel = mlContext.Model.Load(stream);
72+
73+
// Create prediction engine related to the loaded trained model
74+
let predEngine = trainedModel.CreatePredictionEngine<IrisData, IrisPrediction>(mlContext)
75+
76+
//Score sample 1
77+
let resultprediction1 = predEngine.Predict(DataStructures.SampleIrisData.Iris1)
78+
79+
printfn "Actual: setosa. Predicted probability: setosa: %.4f" resultprediction1.Score.[0]
80+
printfn " versicolor: %.4f" resultprediction1.Score.[1]
81+
printfn " virginica: %.4f" resultprediction1.Score.[2]
9182
printfn ""
9283

93-
let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris2
94-
printfn "Actual: virginica. Predicted probability: setosa: %.4f" prediction.Score.[0]
95-
printfn " versicolor: %.4f" prediction.Score.[1]
96-
printfn " virginica: %.4f" prediction.Score.[2]
84+
//Score sample 2
85+
let resultprediction2 = predEngine.Predict(DataStructures.SampleIrisData.Iris2);
86+
printfn "Actual: virginica. Predicted probability: setosa: %.4f" resultprediction2.Score.[0]
87+
printfn " versicolor: %.4f" resultprediction2.Score.[1]
88+
printfn " virginica: %.4f" resultprediction2.Score.[2]
9789
printfn ""
9890

99-
let prediction = modelScorer |> Common.ModelScorer.predictSingle DataStructures.TestIrisData.Iris3
100-
printfn "Actual: versicolor. Predicted probability: setosa: %.4f" prediction.Score.[0]
101-
printfn " versicolor: %.4f" prediction.Score.[1]
102-
printfn " virginica: %.4f" prediction.Score.[2]
91+
//Score sample 3
92+
let resultprediction2 = predEngine.Predict(DataStructures.SampleIrisData.Iris3);
93+
printfn "Actual: versicolor. Predicted probability: setosa: %.4f" resultprediction2.Score.[0]
94+
printfn " versicolor: %.4f" resultprediction2.Score.[1]
95+
printfn " virginica: %.4f" resultprediction2.Score.[2]
10396
printfn ""
10497

10598

@@ -117,6 +110,6 @@ let main argv =
117110
testSomePredictions mlContext
118111

119112
printfn "=============== End of process, hit any key to finish ==============="
120-
ConsoleHelper.consolePressAnyKey()
113+
Console.ReadKey() |> ignore
121114

122115
0 // return an integer exit code

0 commit comments

Comments
 (0)