Skip to content

Adding CodeGen piece for MatrixFactorization trainer #4391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal class CodeGenerator : IProjectGenerator
private readonly Pipeline _pipeline;
private readonly CodeGeneratorSettings _settings;
private readonly ColumnInferenceResults _columnInferenceResult;
// Trainer names that cause SetRequiredNugetPackages to set includeRecommenderPackage for the generated projects.
private static readonly HashSet<string> _recommendationTrainers = new HashSet<string>() { TrainerName.MatrixFactorization.ToString() };
private static readonly HashSet<string> _lightGbmTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
private static readonly HashSet<string> _mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
private static readonly HashSet<string> _fastTreeTrainers = new HashSet<string>() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(), TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString() };
Expand All @@ -42,9 +43,10 @@ public void GenerateOutput()
bool includeFastTreeePackage = false;
bool includeImageTransformerPackage = false;
bool includeImageClassificationPackage = false;
bool includeRecommenderPackage = false;
// Get the extra nuget packages to be included in the generated project.
SetRequiredNugetPackages(_pipeline.Nodes, ref includeLightGbmPackage, ref includeMklComponentsPackage,
ref includeFastTreeePackage, ref includeImageTransformerPackage, ref includeImageClassificationPackage);
ref includeFastTreeePackage, ref includeImageTransformerPackage, ref includeImageClassificationPackage, ref includeRecommenderPackage);

// Get Namespace
var namespaceValue = Utils.Normalize(_settings.OutputName);
Expand All @@ -54,7 +56,7 @@ public void GenerateOutput()
// Generate Model Project
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
includeImageTransformerPackage, includeImageClassificationPackage);
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);

// Write files to disk.
var modelprojectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.Model");
Expand All @@ -69,7 +71,7 @@ public void GenerateOutput()
// Generate ConsoleApp Project
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
includeImageTransformerPackage, includeImageClassificationPackage);
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);

// Write files to disk.
var consoleAppProjectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.ConsoleApp");
Expand All @@ -89,7 +91,7 @@ public void GenerateOutput()

private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage,
ref bool includeMklComponentsPackage, ref bool includeFastTreePackage,
ref bool includeImageTransformerPackage, ref bool includeImageClassificationPackage)
ref bool includeImageTransformerPackage, ref bool includeImageClassificationPackage, ref bool includeRecommenderPackage)
{
foreach (var node in trainerNodes)
{
Expand Down Expand Up @@ -119,21 +121,26 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
{
includeImageClassificationPackage = true;
}
else if (_recommendationTrainers.Contains(currentNode.Name))
{
includeRecommenderPackage = true;
}
}
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent,
string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue,
Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage)
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

var predictProjectFileContent = GeneratPredictProjectFileContent(_settings.OutputName,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage,
includeImageTransformerPackage, includeImageClassificationPackage);
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage,
_settings.StablePackageVersion, _settings.UnstablePackageVersion);

var transformsAndTrainers = GenerateTransformsAndTrainers();
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, _pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
Expand All @@ -146,7 +153,7 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue,
Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage)
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
var classLabels = GenerateClassLabels();

Expand All @@ -163,7 +170,8 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
consumeModelCSFileContent = Utils.FormatCode(consumeModelCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage,
includeMklComponentsPackage, includeFastTreePackage, includeImageTransformerPackage,
includeImageClassificationPackage);
includeImageClassificationPackage, includeRecommenderPackage,
_settings.StablePackageVersion, _settings.UnstablePackageVersion);

return (modelInputCSFileContent, modelOutputCSFileContent, consumeModelCSFileContent, modelProjectFileContent);
}
Expand Down Expand Up @@ -308,15 +316,19 @@ internal IList<string> GenerateClassLabels()
#region Model project
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage,
bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage)
bool includeImageClassificationPackage, bool includeRecommenderPackage,
string stablePackageVersion, string unstablePackageVersion)
{
ModelProject modelProject = new ModelProject()
{
IncludeLightGBMPackage = includeLightGbmPackage,
IncludeMklComponentsPackage = includeMklComponentsPackage,
IncludeFastTreePackage = includeFastTreePackage,
IncludeImageTransformerPackage = includeImageTransformerPackage,
IncludeImageClassificationPackage = includeImageClassificationPackage
IncludeImageClassificationPackage = includeImageClassificationPackage,
IncludeRecommenderPackage = includeRecommenderPackage,
StablePackageVersion = stablePackageVersion,
UnstablePackageVersion = unstablePackageVersion
};

return modelProject.TransformText();
Expand Down Expand Up @@ -347,7 +359,8 @@ private string GenerateModelInputCSFileContent(string namespaceValue, IList<stri
#region Predict Project
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage,
bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage)
bool includeImageClassificationPackage, bool includeRecommenderPackage,
string stablePackageVersion, string unstablePackageVersion)
{
var predictProjectFileContent = new PredictProject()
{
Expand All @@ -356,7 +369,10 @@ private static string GeneratPredictProjectFileContent(string namespaceValue, bo
IncludeLightGBMPackage = includeLightGbmPackage,
IncludeFastTreePackage = includeFastTreePackage,
IncludeImageTransformerPackage = includeImageTransformerPackage,
IncludeImageClassificationPackage = includeImageClassificationPackage
IncludeImageClassificationPackage = includeImageClassificationPackage,
IncludeRecommenderPackage = includeRecommenderPackage,
StablePackageVersion = stablePackageVersion,
UnstablePackageVersion = unstablePackageVersion
};
return predictProjectFileContent.TransformText();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ public CodeGeneratorSettings()

public GenerateTarget Target { get; set; }

// Package version supplied by the caller and passed through by CodeGenerator to the generated project files.
// NOTE(review): presumably the version used for stable (non-preview) packages — confirm against the project templates.
public string StablePackageVersion { get; set; }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JakeRadMSFT @LittleLittleCloud - what do you think of this approach? It at least removes hard-coded versions from the CodeGenerator assembly, and allows the "app" (CLI and Model Builder) to pick the version they require.


// Package version supplied by the caller and passed through by CodeGenerator to the generated project files.
// NOTE(review): presumably the version used for preview/unstable packages — confirm against the project templates.
public string UnstablePackageVersion { get; set; }

internal TaskKind MlTask { get; set; }

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ internal abstract class TrainerGeneratorBase : ITrainerGenerator
private Dictionary<string, object> _arguments;
private bool _hasAdvancedSettings;
private string _seperator;
// When true, Initialize adds the default "FeatureColumnName" = "Features" property to the node.
// Derived generators can override this to return false and opt out.
protected virtual bool IncludeFeatureColumnName => true;

//abstract properties
internal abstract string OptionsName { get; }
Expand Down Expand Up @@ -47,7 +48,10 @@ private void Initialize(PipelineNode node)
{
node.Properties.Add("LabelColumnName", "Label");
}
node.Properties.Add("FeatureColumnName", "Features");
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FeatureColumnName gets added for all existing trainer generators. I added IncludeFeatureColumnName, which is false for the MatrixFactorization trainer generator only.

Copy link
Member Author

@maryamariyan maryamariyan Oct 25, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO:

  • Look more closely at the implementation to see why all Trainer Generators were expected to have a FeatureColumn. If it is expected, then it should also be added for the MatrixFactorization trainer generator.

if (IncludeFeatureColumnName)
{
node.Properties.Add("FeatureColumnName", "Features");
}

foreach (var kv in node.Properties)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ internal static ITrainerGenerator GetInstance(PipelineNode node)
return new OneVersusAll(node);
case TrainerName.ImageClassification:
return new ImageClassificationTrainer(node);
case TrainerName.MatrixFactorization:
return new MatrixFactorization(node);
default:
throw new ArgumentException($"The trainer '{trainer}' is not handled currently.");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,5 +577,35 @@ internal override IDictionary<string, string> NamedParameters
}
}
}

internal class MatrixFactorization : TrainerGeneratorBase
{
//ClassName of the trainer
internal override string MethodName => "MatrixFactorization";

//Name of the trainer's options type. NOTE(review): presumably emitted when advanced settings are present — confirm in TrainerGeneratorBase.
internal override string OptionsName => "MatrixFactorizationTrainer.Options";
//Returns false so the base Initialize does not add the default "FeatureColumnName" = "Features" property to the node.
protected override bool IncludeFeatureColumnName => false;

//The named parameters to the trainer.
internal override IDictionary<string, string> NamedParameters
{
get
{
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of an anti-pattern: to return a new object every time a property is invoked. I see this is already being used in existing code, so you don't need to fix it in this PR. But in general, calling a property is supposed to be fast.

https://docs.microsoft.com/en-us/dotnet/standard/design-guidelines/property

Although properties are technically very similar to methods, they are quite different in terms of their usage scenarios. They should be seen as smart fields.

new Dictionary<string, string>()
{
{ "MatrixColumnIndexColumnName","matrixColumnIndexColumnName" },
{ "MatrixRowIndexColumnName","matrixRowIndexColumnName" },
{ "LabelColumnName","labelColumnName" }
};
}
}

//Using directives associated with this trainer; a fresh single-element array is returned on each access.
internal override string[] Usings
{
    get { return new[] { "using Microsoft.ML.Trainers;\r\n" }; }
}

//Creates the generator for the given pipeline node; the constructor body is empty and the node is handed to the base class.
public MatrixFactorization(PipelineNode node)
    : base(node)
{
}
}
}
}
Loading