Skip to content

Scores to Label mapping #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 25, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions src/Microsoft.ML.Core/Data/ITransformModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ public interface ITransformModel
/// Note that the schema may have columns that aren't needed by this transform model.
/// If an IDataView exists with this schema, then applying this transform model to it
/// shouldn't fail because of column type issues.
/// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
/// however that doing so may cause issues for composing transform models. For example,
/// if transform model A needs column X and model B needs Y, that is NOT produced by A,
/// then trimming A's input schema would cause composition to fail.
/// </summary>
// REVIEW: Would be nice to be able to trim this to the minimum needed somehow. Note
// however that doing so may cause issues for composing transform models. For example,
Copy link
Contributor

@TomFinley TomFinley May 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh good. I'm half tempted to compose a regular expression to see how many of these beauties had slipped into the XML docs. :P :D #ByDesign

// if transform model A needs column X and model B needs Y, that is NOT produced by A,
// then trimming A's input schema would cause composition to fail.
ISchema InputSchema { get; }

/// <summary>
/// Gets the output schema that this transform model was originally instantiated on. The schema
/// that results from <see cref="Apply(IHostEnvironment, ITransformModel)"/> may differ from this
/// one, in the same way that <see cref="InputSchema"/> may differ from the schema of the data
/// views this transform model is applied to.
/// </summary>
ISchema OutputSchema { get; }

/// <summary>
/// Apply the transform(s) in the model to the given input data.
/// </summary>
Expand Down
12 changes: 8 additions & 4 deletions src/Microsoft.ML.Data/EntryPoints/TransformModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ public sealed class TransformModel : ITransformModel
/// if transform model A needs column X and model B needs Y, that is NOT produced by A,
/// then trimming A's input schema would cause composition to fail.
/// </summary>
public ISchema InputSchema
{
get { return _schemaRoot; }
}
public ISchema InputSchema => _schemaRoot;

/// <summary>
/// Gets the output schema of this transform model, as exposed by the underlying transform chain.
/// The <see cref="InputSchema"/> might have columns that are not needed by this transform; those
/// columns will also be seen in the <see cref="OutputSchema"/> produced by this transform.
/// </summary>
public ISchema OutputSchema => _chain.Schema;

/// <summary>
/// Create a TransformModel containing the transforms from "result" back to "input".
Expand Down
34 changes: 34 additions & 0 deletions src/Microsoft.ML/PredictionModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
Expand All @@ -29,6 +30,39 @@ internal Runtime.EntryPoints.TransformModel PredictorModel
get { return _predictorModel; }
}

/// <summary>
/// Attempts to retrieve the label names that correspond to the slots of the score column,
/// so that each index of the score array of a multi-class classifier can be mapped to a label.
/// </summary>
/// <param name="names">
/// On success, an array holding one label name per score slot, in slot order;
/// <c>null</c> when the method returns <c>false</c>.
/// </param>
/// <param name="scoreColumnName">Name of the score column to look up in the model's output schema.</param>
/// <returns>
/// <c>true</c> if the label names were retrieved; <c>false</c> if the score column is not found,
/// the column does not have slot names for the expected number of values, or the number of
/// slot names read does not match that expected count.
/// </returns>
public bool TryGetScoreLabelNames(out string[] names, string scoreColumnName = DefaultColumnNames.Score)
{
    // The out parameter stays null on every failure path below.
    names = null;

    ISchema outputSchema = _predictorModel.OutputSchema;
    if (!outputSchema.TryGetColumnIndex(scoreColumnName, out int scoreColumn))
        return false;

    // The score column's value count is the number of labels we expect to find.
    int slotCount = outputSchema.GetColumnType(scoreColumn).ValueCount;
    if (!outputSchema.HasSlotNames(scoreColumn, slotCount))
        return false;

    VBuffer<DvText> slotNames = default;
    outputSchema.GetMetadata(MetadataUtils.Kinds.SlotNames, scoreColumn, ref slotNames);
    if (slotNames.Length != slotCount)
        return false;

    // Copy the slot names out as plain strings, preserving slot order.
    names = new string[slotCount];
    int position = 0;
    foreach (DvText slotName in slotNames.DenseValues())
    {
        names[position] = slotName.ToString();
        position++;
    }

    return true;
}

/// <summary>
/// Read model from file asynchronously.
/// </summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,14 @@ public void TrainAndPredictIrisModelWithStringLabelTest()
pipeline.Add(new StochasticDualCoordinateAscentClassifier());

PredictionModel<IrisDataWithStringLabel, IrisPrediction> model = pipeline.Train<IrisDataWithStringLabel, IrisPrediction>();
string[] scoreLabels;
model.TryGetScoreLabelNames(out scoreLabels);

Assert.NotNull(scoreLabels);
Assert.Equal(3, scoreLabels.Length);
Assert.Equal("Iris-setosa", scoreLabels[0]);
Assert.Equal("Iris-versicolor", scoreLabels[1]);
Assert.Equal("Iris-virginica", scoreLabels[2]);

IrisPrediction prediction = model.Predict(new IrisDataWithStringLabel()
{
Expand Down