From 8cdafd031655c6c12387787f47d8ea9185fc8b5b Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Fri, 8 Jun 2018 13:13:30 -0700 Subject: [PATCH 1/2] Create CalibratedPredictor instead of SchemaBindableCalibratedPredictor whenever the predictor implements IValueMapper. --- src/Microsoft.ML.Data/Prediction/Calibrator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index efa52d2ff6..b6f4fe3344 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -853,7 +853,7 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm var predWithFeatureScores = predictor as IPredictorWithFeatureWeights; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali); - if (needValueMapper) + if (needValueMapper || predictor is IValueMapper) return new CalibratedPredictor(env, predictor, cali); return new SchemaBindableCalibratedPredictor(env, predictor, cali); } From 2ff2d77f5b874d1c830593ae5f582c7853061c34 Mon Sep 17 00:00:00 2001 From: Yael Dekel Date: Mon, 11 Jun 2018 14:37:15 -0700 Subject: [PATCH 2/2] Address PR comments. --- .../Prediction/Calibrator.cs | 18 ++---- .../Standard/MultiClass/Ova.cs | 2 +- .../UnitTests/TestCSharpApi.cs | 59 +++++++++++++++++++ 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.ML.Data/Prediction/Calibrator.cs b/src/Microsoft.ML.Data/Prediction/Calibrator.cs index b6f4fe3344..487726572b 100644 --- a/src/Microsoft.ML.Data/Prediction/Calibrator.cs +++ b/src/Microsoft.ML.Data/Prediction/Calibrator.cs @@ -746,13 +746,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat /// The trainer used to train the predictor. /// The predictor that needs calibration. /// The examples to used for calibrator training. - /// Indicates whether the predictor returned needs to be an . - /// This parameter is needed for OVA that uses the predictors as s. If it is false, - /// The predictor returned is an an . /// The original predictor, if no calibration is needed, /// or a metapredictor that wraps the original predictor and the newly trained calibrator. public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator, - int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false) + int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ch, nameof(ch)); @@ -763,7 +760,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema)) return predictor; - return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper); + return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data); } /// @@ -775,13 +772,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel /// The maximum rows to use for calibrator training. /// The predictor that needs calibration. /// The examples to used for calibrator training. - /// Indicates whether the predictor returned needs to be an . - /// This parameter is needed for OVA that uses the predictors as s. If it is false, - /// The predictor returned is an an . /// The original predictor, if no calibration is needed, /// or a metapredictor that wraps the original predictor and the newly trained calibrator. public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer, - int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false) + int maxRows, IPredictor predictor, RoleMappedData data) { Contracts.CheckValue(env, nameof(env)); env.CheckValue(ch, nameof(ch)); @@ -834,10 +828,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal } } var cali = caliTrainer.FinishTraining(ch); - return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali, needValueMapper); + return CreateCalibratedPredictor(env, (IPredictorProducing)predictor, cali); } - public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali, bool needValueMapper = false) + public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing predictor, ICalibrator cali) { Contracts.Assert(predictor != null); if (cali == null) @@ -853,7 +847,7 @@ public static IPredictorProducing CreateCalibratedPredictor(IHostEnvironm var predWithFeatureScores = predictor as IPredictorWithFeatureWeights; if (predWithFeatureScores != null && predictor is IParameterMixer && cali is IParameterMixer) return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali); - if (needValueMapper || predictor is IValueMapper) + if (predictor is IValueMapper) return new CalibratedPredictor(env, predictor, cali); return new SchemaBindableCalibratedPredictor(env, predictor, cali); } diff --git a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs index 3d6e1e67b2..8aa6e4b6e0 100644 --- a/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs +++ b/src/Microsoft.ML.StandardLearners/Standard/MultiClass/Ova.cs @@ -92,7 +92,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe else calibrator = Args.Calibrator.CreateInstance(Host); var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples, - trainer, predictor, td, true); + trainer, predictor, td); predictor = res as TScalarPredictor; Host.Check(predictor != null, "Calibrated predictor does not implement the expected interface"); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index b42dee2d52..f8ac506f6e 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -798,5 +798,64 @@ public void TestOvaMacro() } } } + + [Fact] + public void TestOvaMacroWithUncalibratedLearner() + { + var dataPath = GetDataPath(@"iris.txt"); + using (var env = new TlcEnvironment(42)) + { + // Specify subgraph for OVA + var subGraph = env.CreateExperiment(); + var learnerInput = new Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false }; + var learnerOutput = subGraph.Add(learnerInput); + // Create pipeline with OVA and multiclass scoring. + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + importInput.Arguments.Column = new TextLoaderColumn[] + { + new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } }, + new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } } + }; + var importOutput = experiment.Add(importInput); + var oneVersusAll = new Models.OneVersusAll + { + TrainingData = importOutput.Data, + Nodes = subGraph, + UseProbabilities = true, + }; + var ovaOutput = experiment.Add(oneVersusAll); + var scoreInput = new ML.Transforms.DatasetScorer + { + Data = importOutput.Data, + PredictorModel = ovaOutput.PredictorModel + }; + var scoreOutput = experiment.Add(scoreInput); + var evalInput = new ML.Models.ClassificationEvaluator + { + Data = scoreOutput.ScoredData + }; + var evalOutput = experiment.Add(evalInput); + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + + var data = experiment.GetOutput(evalOutput.OverallMetrics); + var schema = data.Schema; + var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == accCol)) + { + var getter = cursor.GetGetter(accCol); + b = cursor.MoveNext(); + Assert.True(b); + double acc = 0; + getter(ref acc); + Assert.Equal(0.71, acc, 2); + b = cursor.MoveNext(); + Assert.False(b); + } + } + } } }