@@ -13,63 +13,79 @@ namespace Samples.Dynamic.Trainers.BinaryClassification
13
13
{<#=Comments#>
14
14
public static void Example()
15
15
{
16
- // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
17
- // as a catalog of available operations and as the source of randomness.
18
- // Setting the seed to a fixed number in this example to make outputs deterministic.
16
+ // Create a new context for ML.NET operations. It can be used for
17
+ // exception tracking and logging, as a catalog of available operations
18
+ // and as the source of randomness. Setting the seed to a fixed number
19
+ // in this example to make outputs deterministic.
19
20
var mlContext = new MLContext(seed: 0);
20
21
21
22
// Create a list of training data points.
22
23
var dataPoints = GenerateRandomDataPoints(1000);
23
24
24
- // Convert the list of data points to an IDataView object, which is consumable by ML.NET API.
25
+ // Convert the list of data points to an IDataView object, which is
26
+ // consumable by ML.NET API.
25
27
var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints);
26
28
<# if (CacheData) { #>
27
29
28
- // ML.NET doesn't cache data set by default. Therefore, if one reads a data set from a file and accesses it many times,
29
- // it can be slow due to expensive featurization and disk operations. When the considered data can fit into memory,
30
- // a solution is to cache the data in memory. Caching is especially helpful when working with iterative algorithms
30
+ // ML.NET doesn't cache data set by default. Therefore, if one reads a
31
+ // data set from a file and accesses it many times, it can be slow due
32
+ // to expensive featurization and disk operations. When the considered
33
+ // data can fit into memory, a solution is to cache the data in memory.
34
+ // Caching is especially helpful when working with iterative algorithms
31
35
// which needs many data passes.
32
36
trainingData = mlContext.Data.Cache(trainingData);
33
37
<# } #>
34
38
35
39
<# if (TrainerOptions == null) { #>
36
40
// Define the trainer.
37
- var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>();
41
+ var pipeline = mlContext.BinaryClassification.Trainers
42
+ .<#=Trainer#>();
38
43
<# } else { #>
39
44
// Define trainer options.
40
45
var options = new <#=TrainerOptions#>;
41
46
42
47
// Define the trainer.
43
- var pipeline = mlContext.BinaryClassification.Trainers.<#=Trainer#>(options);
48
+ var pipeline = mlContext.BinaryClassification.Trainers
49
+ .<#=Trainer#>(options);
44
50
<# } #>
45
51
46
52
// Train the model.
47
53
var model = pipeline.Fit(trainingData);
48
54
49
- // Create testing data. Use different random seed to make it different from training data.
50
- var testData = mlContext.Data.LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
55
+ // Create testing data. Use different random seed to make it different
56
+ // from training data.
57
+ var testData = mlContext.Data
58
+ .LoadFromEnumerable(GenerateRandomDataPoints(500, seed:123));
51
59
52
60
// Run the model on test data set.
53
61
var transformedTestData = model.Transform(testData);
54
62
55
63
// Convert IDataView object to a list.
56
- var predictions = mlContext.Data.CreateEnumerable<Prediction>(transformedTestData, reuseRowObject: false).ToList();
64
+ var predictions = mlContext.Data
65
+ .CreateEnumerable<Prediction>(transformedTestData,
66
+ reuseRowObject: false).ToList();
57
67
58
68
// Print 5 predictions.
59
69
foreach (var p in predictions.Take(5))
60
- Console.WriteLine($"Label: {p.Label}, Prediction: {p.PredictedLabel}");
70
+ Console.WriteLine($"Label: {p.Label}, "
71
+ + $"Prediction: {p.PredictedLabel}");
61
72
62
73
<#=ExpectedOutputPerInstance#>
63
- <# string Evaluator = IsCalibrated ? "Evaluate" : "EvaluateNonCalibrated"; #>
74
+ <# string Evaluator = IsCalibrated ? "Evaluate" :
75
+ "EvaluateNonCalibrated"; #>
64
76
65
77
// Evaluate the overall metrics.
66
- var metrics = mlContext.BinaryClassification.<#=Evaluator#>(transformedTestData);
78
+ var metrics = mlContext.BinaryClassification
79
+ .<#=Evaluator#>(transformedTestData);
80
+
67
81
PrintMetrics(metrics);
68
82
69
83
<#=ExpectedOutput#>
70
84
}
71
85
72
- private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count, int seed=0)
86
+ private static IEnumerable<DataPoint> GenerateRandomDataPoints(int count,
87
+ int seed=0)
88
+
73
89
{
74
90
var random = new Random(seed);
75
91
float randomFloat() => (float)random.NextDouble();
@@ -80,13 +96,18 @@ namespace Samples.Dynamic.Trainers.BinaryClassification
80
96
{
81
97
Label = label,
82
98
// Create random features that are correlated with the label.
83
- // For data points with false label, the feature values are slightly increased by adding a constant.
84
- Features = Enumerable.Repeat(label, 50).Select(x => x ? randomFloat() : randomFloat() + <#=DataSepValue#>).ToArray()
99
+ // For data points with false label, the feature values are
100
+ // slightly increased by adding a constant.
101
+ Features = Enumerable.Repeat(label, 50)
102
+ .Select(x => x ? randomFloat() : randomFloat() +
103
+ <#=DataSepValue#>).ToArray()
104
+
85
105
};
86
106
}
87
107
}
88
108
89
- // Example with label and 50 feature values. A data set is a collection of such examples.
109
+ // Example with label and 50 feature values. A data set is a collection of
110
+ // such examples.
90
111
private class DataPoint
91
112
{
92
113
public bool Label { get; set; }
@@ -109,11 +130,15 @@ namespace Samples.Dynamic.Trainers.BinaryClassification
109
130
Console.WriteLine($"Accuracy: {metrics.Accuracy:F2}");
110
131
Console.WriteLine($"AUC: {metrics.AreaUnderRocCurve:F2}");
111
132
Console.WriteLine($"F1 Score: {metrics.F1Score:F2}");
112
- Console.WriteLine($"Negative Precision: {metrics.NegativePrecision:F2}");
133
+ Console.WriteLine($"Negative Precision: " +
134
+ $"{metrics.NegativePrecision:F2}");
135
+
113
136
Console.WriteLine($"Negative Recall: {metrics.NegativeRecall:F2}");
114
- Console.WriteLine($"Positive Precision: {metrics.PositivePrecision:F2}");
137
+ Console.WriteLine($"Positive Precision: " +
138
+ $"{metrics.PositivePrecision:F2}");
139
+
115
140
Console.WriteLine($"Positive Recall: {metrics.PositiveRecall:F2}\n");
116
141
Console.WriteLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());
117
142
}
118
143
}
119
- }
144
+ }
0 commit comments