Skip to content

Adding CodeGen piece for MatrixFactorization trainer #4391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 30, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge remote-tracking branch 'upstream/master' into codegen-recommend…
…ation

 Conflicts:
	src/Microsoft.ML.CodeGenerator/CodeGenerator/CSharp/CodeGenerator.cs
	src/Microsoft.ML.CodeGenerator/CodeGenerator/CSharp/TrainerGeneratorFactory.cs
	src/Microsoft.ML.CodeGenerator/CodeGenerator/CSharp/TrainerGenerators.cs
	src/Microsoft.ML.CodeGenerator/Templates/Console/ModelProject.cs
	src/Microsoft.ML.CodeGenerator/Templates/Console/ModelProject.tt
	src/Microsoft.ML.CodeGenerator/Templates/Console/PredictProject.cs
	src/Microsoft.ML.CodeGenerator/Templates/Console/PredictProject.tt
	test/Microsoft.ML.CodeGenerator.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProjectFileContentTest.approved.txt
	test/Microsoft.ML.CodeGenerator.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTestOnlyStableProjects.approved.txt
	test/Microsoft.ML.CodeGenerator.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs
  • Loading branch information
maryamariyan committed Oct 29, 2019
commit 70e7ab38f2018d72839bf010711287519260b5ad
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,22 @@ public void GenerateOutput()
bool includeLightGbmPackage = false;
bool includeMklComponentsPackage = false;
bool includeFastTreeePackage = false;
bool includeImageTransformerPackage = false;
bool includeImageClassificationPackage = false;
bool includeRecommenderPackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage, ref includeFastTreeePackage, ref includeRecommenderPackage);
// Get the extra nuget packages to be included in the generated project.
SetRequiredNugetPackages(_pipeline.Nodes, ref includeLightGbmPackage, ref includeMklComponentsPackage,
ref includeFastTreeePackage, ref includeImageTransformerPackage, ref includeImageClassificationPackage, ref includeRecommenderPackage);

// Get Namespace
var namespaceValue = Utils.Normalize(_settings.OutputName);
var labelType = _columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == _settings.LabelName).First().DataKind;
Type labelTypeCsharp = Utils.GetCSharpType(labelType);

// Generate Model Project
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage, includeRecommenderPackage);
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);

// Write files to disk.
var modelprojectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.Model");
Expand All @@ -63,7 +69,9 @@ public void GenerateOutput()
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

// Generate ConsoleApp Project
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage, includeRecommenderPackage);
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage,
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);

// Write files to disk.
var consoleAppProjectDir = Path.Combine(_settings.OutputBaseDir, $"{_settings.OutputName}.ConsoleApp");
Expand All @@ -81,7 +89,9 @@ public void GenerateOutput()
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
}

private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage, ref bool includeFastTreePackage, ref bool includeRecommenderPackage)
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage,
ref bool includeMklComponentsPackage, ref bool includeFastTreePackage,
ref bool includeImageTransformerPackage, ref bool includeImageClassificationPackage, includeRecommenderPackage)
{
foreach (var node in trainerNodes)
{
Expand All @@ -103,19 +113,33 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
{
includeFastTreePackage = true;
}
else if (_imageTransformers.Contains(currentNode.Name))
{
includeImageTransformerPackage = true;
}
else if (_imageClassificationTrainers.Contains(currentNode.Name))
{
includeImageClassificationPackage = true;
}
else if (_recommendationTrainers.Contains(currentNode.Name))
{
includeRecommenderPackage = true;
}
}
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeRecommenderPackage)
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent,
string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue,
Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

var predictProjectFileContent = GeneratPredictProjectFileContent(_settings.OutputName, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage, includeRecommenderPackage);
var predictProjectFileContent = GeneratPredictProjectFileContent(_settings.OutputName,
includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage,
includeImageTransformerPackage, includeImageClassificationPackage, includeRecommenderPackage);

var transformsAndTrainers = GenerateTransformsAndTrainers();
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, _pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
Expand All @@ -124,7 +148,11 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
}

internal (string ModelInputCSFileContent, string ModelOutputCSFileContent, string ConsumeModelCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeRecommenderPackage)
internal (string ModelInputCSFileContent, string ModelOutputCSFileContent, string ConsumeModelCSFileContent,
string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue,
Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage,
bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
var classLabels = GenerateClassLabels();

Expand All @@ -139,7 +167,10 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
// generate ConsumeModel.cs
var consumeModelCSFileContent = GenerateConsumeModelCSFileContent(namespaceValue);
consumeModelCSFileContent = Utils.FormatCode(consumeModelCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage, includeRecommenderPackage);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage,
includeMklComponentsPackage, includeFastTreePackage, includeImageTransformerPackage,
includeImageClassificationPackage, includeRecommenderPackage);

return (modelInputCSFileContent, modelOutputCSFileContent, consumeModelCSFileContent, modelProjectFileContent);
}

Expand Down Expand Up @@ -281,9 +312,20 @@ internal IList<string> GenerateClassLabels()
}

#region Model project
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeRecommenderPackage)
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage,
bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeFastTreePackage = includeFastTreePackage, IncludeRecommenderPackage = includeRecommenderPackage };
ModelProject modelProject = new ModelProject()
{
IncludeLightGBMPackage = includeLightGbmPackage,
IncludeMklComponentsPackage = includeMklComponentsPackage,
IncludeFastTreePackage = includeFastTreePackage,
IncludeImageTransformerPackage = includeImageTransformerPackage,
IncludeImageClassificationPackage = includeImageClassificationPackage,
IncludeRecommenderPackage = includeRecommenderPackage
};

return modelProject.TransformText();
}

Expand All @@ -310,9 +352,20 @@ private string GenerateModelInputCSFileContent(string namespaceValue, IList<stri
#endregion

#region Predict Project
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeRecommenderPackage)
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage,
bool includeMklComponentsPackage, bool includeFastTreePackage, bool includeImageTransformerPackage,
bool includeImageClassificationPackage, bool includeRecommenderPackage)
{
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage, IncludeFastTreePackage = includeFastTreePackage, IncludeRecommenderPackage = includeRecommenderPackage };
var predictProjectFileContent = new PredictProject()
{
Namespace = namespaceValue,
IncludeMklComponentsPackage = includeMklComponentsPackage,
IncludeLightGBMPackage = includeLightGbmPackage,
IncludeFastTreePackage = includeFastTreePackage,
IncludeImageTransformerPackage = includeImageTransformerPackage,
IncludeImageClassificationPackage = includeImageClassificationPackage,
IncludeRecommenderPackage = includeRecommenderPackage
};
return predictProjectFileContent.TransformText();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ internal static ITrainerGenerator GetInstance(PipelineNode node)
return new SymbolicSgdLogisticRegressionBinary(node);
case TrainerName.Ova:
return new OneVersusAll(node);
case TrainerName.ImageClassification:
return new ImageClassificationTrainer(node);
case TrainerName.MatrixFactorization:
return new MatrixFactorization(node);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,27 @@ public override string[] GenerateUsings()
}
}

internal sealed class ImageClassificationTrainer : TrainerGeneratorBase
{
//ClassName of the trainer
internal override string MethodName => "ImageClassification";
internal override string OptionsName => "ImageClassificationTrainer.Options";
internal override string[] Usings => new string[] { "using Microsoft.ML.Dnn;\r\n" };

public ImageClassificationTrainer(PipelineNode node) : base(node)
{
}
//The named parameters to the trainer.
internal override IDictionary<string, string> NamedParameters
{
get
{
return
new Dictionary<string, string>();
}
}
}

internal class MatrixFactorization : TrainerGeneratorBase
{
//ClassName of the trainer
Expand Down
11 changes: 11 additions & 0 deletions src/Microsoft.ML.CodeGenerator/Templates/Console/ModelProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ public virtual string TransformText()
if (IncludeFastTreePackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"$(StablePackageVer" +
"sion)\" />\r\n");
}
if (IncludeImageTransformerPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.ImageAnalytics\" Version=\"$(StablePack" +
"ageVersion)\" />\r\n");
}
if (IncludeImageClassificationPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Dnn\" Version=\"$(StablePackageVersion)" +
"\" />\r\n\t<PackageReference Include=\"SciSharp.TensorFlow.Redist\" Version=\"$(StableP" +
"ackageVersion)\" />\r\n");
}
if (IncludeRecommenderPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Recommender\" Version=\"$(UnstablePacka" +
Expand All @@ -61,6 +70,8 @@ public virtual string TransformText()
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
public bool IncludeImageTransformerPackage {get; set;}
public bool IncludeImageClassificationPackage {get; set;}
public bool IncludeRecommenderPackage {get;set;}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@
<# if (IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeImageTransformerPackage){ #>
<PackageReference Include="Microsoft.ML.ImageAnalytics" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeImageClassificationPackage){ #>
<PackageReference Include="Microsoft.ML.Dnn" Version="$(StablePackageVersion)" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeRecommenderPackage){ #>
<PackageReference Include="Microsoft.ML.Recommender" Version="$(UnstablePackageVersion)" />
<#}#>
Expand All @@ -40,5 +47,7 @@
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
public bool IncludeImageTransformerPackage {get; set;}
public bool IncludeImageClassificationPackage {get; set;}
public bool IncludeRecommenderPackage {get;set;}
#>
11 changes: 11 additions & 0 deletions src/Microsoft.ML.CodeGenerator/Templates/Console/PredictProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ public virtual string TransformText()
if (IncludeFastTreePackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"$(StablePackageVer" +
"sion)\" />\r\n");
}
if (IncludeImageTransformerPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.ImageAnalytics\" Version=\"$(StablePack" +
"ageVersion)\" />\r\n");
}
if (IncludeImageClassificationPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Dnn\" Version=\"$(StablePackageVersion)" +
"\" />\r\n\t<PackageReference Include=\"SciSharp.TensorFlow.Redist\" Version=\"$(StableP" +
"ackageVersion)\" />\r\n");
}
if (IncludeRecommenderPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Recommender\" Version=\"$(UnstablePacka" +
Expand All @@ -66,6 +75,8 @@ public virtual string TransformText()
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
public bool IncludeImageTransformerPackage {get; set;}
public bool IncludeImageClassificationPackage {get; set;}
public bool IncludeRecommenderPackage {get;set;}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
<# if (IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeImageTransformerPackage){ #>
<PackageReference Include="Microsoft.ML.ImageAnalytics" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeImageClassificationPackage){ #>
<PackageReference Include="Microsoft.ML.Dnn" Version="$(StablePackageVersion)" />
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="$(StablePackageVersion)" />
<#}#>
<# if (IncludeRecommenderPackage){ #>
<PackageReference Include="Microsoft.ML.Recommender" Version="$(UnstablePackageVersion)" />
<#}#>
Expand All @@ -39,5 +46,7 @@ public string Namespace {get;set;}
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
public bool IncludeImageTransformerPackage {get; set;}
public bool IncludeImageClassificationPackage {get; set;}
public bool IncludeRecommenderPackage {get;set;}
#>
Loading
You are viewing a condensed version of this merge commit. You can view the full changes here.