Onedal algorithms backed by nuget packages #6521

Merged: 66 commits, Dec 21, 2022
5d1c55f
Add OneDal namespace
Alexsandruss Nov 11, 2022
4d38486
oneDAL algorithms wrapper
Alexsandruss Nov 11, 2022
df4caed
oneDAL Decision Forest integration
Alexsandruss Nov 11, 2022
646cfcd
oneDAL Linear Models
Alexsandruss Nov 11, 2022
9b913e7
oneDAL README
Alexsandruss Nov 11, 2022
3f30986
Fix mac lib path
Alexsandruss Nov 11, 2022
e34cb51
Fix cmake args for Win platform
Alexsandruss Nov 11, 2022
1592aa9
Fix cmake args for Win platform 2
Alexsandruss Nov 11, 2022
2324141
Modify building of onedal wrapper
Alexsandruss Nov 20, 2022
c6dc1c4
Temp.fix of wrapper deps
Alexsandruss Nov 21, 2022
2476425
Add onedal devel version
Alexsandruss Nov 21, 2022
3870283
Modify linking and add tbb libs copy
Alexsandruss Nov 21, 2022
e0e3cd4
Modify link to tbb[malloc]
Alexsandruss Nov 21, 2022
23e9c04
Fix tbb link
Alexsandruss Nov 21, 2022
60fd3aa
Fix tbb versions
Alexsandruss Nov 21, 2022
34da106
tbb linking fixes
Alexsandruss Nov 21, 2022
6ff47c5
Fix tbb linking 2
Alexsandruss Nov 21, 2022
1014d7e
Remove df_clsf inference from wrapper, fix msvc warnings
Alexsandruss Nov 21, 2022
0981764
Fix win linking
Alexsandruss Nov 21, 2022
c37debb
Fix linking on Windows
Alexsandruss Nov 21, 2022
12866e2
fix onedalutils
Alexsandruss Nov 22, 2022
b803f51
Fix usage of onedalutils
Alexsandruss Nov 22, 2022
1545e9d
Fix usage of onedalutils n2
Alexsandruss Nov 22, 2022
fd4bc39
Fix usage of onedalutils n3
Alexsandruss Nov 22, 2022
dbfb792
Fix usage of onedalutils n4
Alexsandruss Nov 22, 2022
057c0e9
Fix win compile options
Alexsandruss Nov 23, 2022
0428e69
Corrected some formatting issues
rgesteve Dec 4, 2022
62d3880
The logic (on linux) to download nuget packages
rgesteve Dec 4, 2022
10a8aa2
First cut trying to download packages on Windows
rgesteve Dec 5, 2022
cf67884
Added a sample for OneDal Random forest
rgesteve Dec 5, 2022
747682f
Merge branch 'onedal_with_nuget' of https://github.com/rgesteve/machi…
rgesteve Dec 5, 2022
0c9a00c
Move download functionality to msbuild
rgesteve Dec 5, 2022
d5f2442
OS-specific downloads
rgesteve Dec 5, 2022
16f7c7b
Revert changes to these files, as now using PackageDownloads instead
rgesteve Dec 6, 2022
4ab37db
Remove unused cmake variables
Alexsandruss Dec 6, 2022
512bf86
Merge branch 'alex_exp_onedal' into onedal_with_nuget
rgesteve Dec 6, 2022
f2aa4b1
Rebasing to main
rgesteve Dec 7, 2022
5875e31
Restore accessibility of ctor in favor of changes in AssemblyInfo
rgesteve Dec 8, 2022
075b48e
dependencies of benchmark driver
rgesteve Dec 12, 2022
d0de991
Small driver that installs dependencies and runs benchmarking scripts
rgesteve Dec 12, 2022
42012ba
Guard onedal as exclusive of x64 arch
rgesteve Dec 12, 2022
6afa6cb
Merge branch 'onedal_with_nuget' of https://github.com/rgesteve/machi…
rgesteve Dec 12, 2022
db39a8f
Consider MacOS builds
rgesteve Dec 13, 2022
858be56
Only build OneDal in x64 architectures
rgesteve Dec 13, 2022
6fb303b
The IS_64BIT_BUILD guard didn't work, switching to string comp
rgesteve Dec 13, 2022
ad499b2
Accomodate arch reporting on mac
rgesteve Dec 13, 2022
13b8931
Activate OneDal only on x64
rgesteve Dec 13, 2022
afae3ea
Only pass build parameters for onedal in x64
rgesteve Dec 14, 2022
7c01466
When on Windows, setting onedal to only build on x64
rgesteve Dec 14, 2022
6f47eac
Avoid CMake 'var not used' error
rgesteve Dec 14, 2022
07be9f3
Copy OneDal wrapper only on x64 architectures
rgesteve Dec 14, 2022
27827c1
OS-specific download (instead of same payload for linux/macos)
rgesteve Dec 15, 2022
e7368d7
fixed sln file
michaelgsharp Dec 15, 2022
e9c1436
A better range for average memory requirements, extra output breaks b…
rgesteve Dec 16, 2022
4b52f28
Merge branch 'onedal_with_nuget' of https://github.com/rgesteve/machi…
rgesteve Dec 16, 2022
371ccd9
Adding a unit test for OneDAL, and updating the usage notes
rgesteve Dec 16, 2022
68f0a17
Fixed styling for oneDAL and small syntax nitpicks
rgesteve Dec 16, 2022
52d4d6a
fixing onedal project not showing in VS
michaelgsharp Dec 16, 2022
3331545
Added temp test to probe loading libraries
rgesteve Dec 19, 2022
d8433e5
Merge branch 'onedal_with_nuget' of https://github.com/rgesteve/machi…
rgesteve Dec 19, 2022
d19621b
Displaying where this is (supposedly) reading Native DLLs
rgesteve Dec 19, 2022
5bfab70
Copy dependencies so that they're included in Microsoft.ML.OneDal nupkg
rgesteve Dec 19, 2022
e2a6094
Remove debugging tests
rgesteve Dec 20, 2022
35a18a1
Addressing having to set LD_LIBRARY_PATH manually
rgesteve Dec 20, 2022
a603030
Copy onedal dependencies to avoid assumption they include pdbs
rgesteve Dec 21, 2022
34ddbba
PATH manipulation on Win to account for dll loading
rgesteve Dec 21, 2022
24 changes: 23 additions & 1 deletion Microsoft.ML.sln
@@ -1,4 +1,4 @@
Microsoft Visual Studio Solution File, Format Version 12.00
Microsoft Visual Studio Solution File, Format Version 12.00
# Visual Studio Version 17
VisualStudioVersion = 17.1.32120.378
MinimumVisualStudioVersion = 10.0.40219.1
@@ -71,6 +71,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Mkl.Components", "src\Microsoft.ML.Mkl.Components\Microsoft.ML.Mkl.Components.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.OneDal", "src\Microsoft.ML.OneDal\Microsoft.ML.OneDal.csproj", "{A7222F94-2AF1-10C9-A21C-C4D22B137A69}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TimeSeries", "src\Microsoft.ML.TimeSeries\Microsoft.ML.TimeSeries.csproj", "{5A79C7F0-3D99-4123-B0DA-7C9FFCD13132}"
@@ -90,6 +92,8 @@ EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples", "docs\samples\Microsoft.ML.Samples\Microsoft.ML.Samples.csproj", "{ECB71297-9DF1-48CE-B93A-CD969221F9B6}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.SamplesUtils", "src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj", "{11A5210E-2EA7-42F1-80DB-827762E9C781}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Samples.OneDal", "docs\samples\Microsoft.ML.Samples.OneDal\Microsoft.ML.Samples.OneDal.csproj", "{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Recommender", "src\Microsoft.ML.Recommender\Microsoft.ML.Recommender.csproj", "{C8E1772B-DFD9-4A4D-830D-6AAB1C668BB3}"
EndProject
@@ -403,6 +407,14 @@ Global
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|x64.ActiveCfg = Debug|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Debug|x64.Build.0 = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|x64.ActiveCfg = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Debug|x64.Build.0 = Debug|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|Any CPU.Build.0 = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|x64.ActiveCfg = Release|Any CPU
{A7222F94-2AF1-10C9-A21C-C4D22B137A69}.Release|x64.Build.0 = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|x64.ActiveCfg = Release|Any CPU
@@ -747,6 +759,14 @@ Global
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|Any CPU.Build.0 = Release|Any CPU
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|x64.ActiveCfg = Release|Any CPU
{C3D82402-F207-4F19-8C57-5AF0FBAF9682}.Release|x64.Build.0 = Release|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Debug|Any CPU.Build.0 = Debug|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Debug|x64.ActiveCfg = Debug|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Debug|x64.Build.0 = Debug|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Release|Any CPU.ActiveCfg = Release|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Release|Any CPU.Build.0 = Release|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Release|x64.ActiveCfg = Release|Any CPU
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -782,6 +802,7 @@ Global
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F94-2AF1-10C9-A21C-C4D22B137A69} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{5A79C7F0-3D99-4123-B0DA-7C9FFCD13132} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{8C05642D-C3AA-4972-B02C-93681161A6BC} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
@@ -825,6 +846,7 @@ Global
{FF0BD187-4451-4A3B-934B-2AE3454896E2} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{BBC3A950-BD68-45AC-9DBD-A8F4D8847745} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{C3D82402-F207-4F19-8C57-5AF0FBAF9682} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{38ED61F4-FA22-4DE9-B0C4-91F327F4EE31} = {DA452A53-2E94-4433-B08C-041EDEC729E6}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
16 changes: 16 additions & 0 deletions README-oneDAL.md
@@ -0,0 +1,16 @@
# oneDAL supported algorithms

oneAPI Data Analytics Library (oneDAL) is a library providing highly optimized machine learning and data analytics kernels. Some of these kernels are integrated into ML.NET via C++/C# interoperability.

[oneDAL Documentation](http://oneapi-src.github.io/oneDAL/) | [oneDAL Repository](https://github.com/oneapi-src/oneDAL)

> Please note that oneDAL acceleration paths are available only on x64 architectures.

Integration consists of:

* A "native" component (under `src/Native/Microsoft.ML.OneDal`) implementing a wrapper that passes data and parameters to oneDAL;
* Dispatching to oneDAL kernels inside the relevant learners: `OLS` (`src/Microsoft.ML.Mkl.Components`), `Logistic Regression` (`src/Microsoft.ML.StandardTrainers`), `Random Forest` (`src/Microsoft.ML.FastTree`).

## Running ML.NET trainers with dispatching to oneDAL kernels

Currently, dispatching to oneDAL inside ML.NET is controlled by the `MLNET_BACKEND` environment variable. If it is set to `ONEDAL`, the oneDAL kernels are used; otherwise, the default ML.NET implementation is used.
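
As a rough illustration of the rule stated in the README, the sketch below sets `MLNET_BACKEND` for the current process and checks whether the oneDAL path would be requested. `BackendProbe` and `OneDalRequested` are hypothetical names, not part of ML.NET, and the exact-match comparison against `ONEDAL` plus the x64 check are assumptions drawn from the README text rather than from the library's actual dispatch code.

```csharp
using System;
using System.Runtime.InteropServices;

public static class BackendProbe
{
    // Hypothetical helper (not an ML.NET API): applies the README's rule that
    // oneDAL is used only when MLNET_BACKEND is "ONEDAL" and the process is x64.
    // The case-sensitive comparison is an assumption.
    public static bool OneDalRequested(string value, Architecture arch) =>
        value == "ONEDAL" && arch == Architecture.X64;

    public static void Main()
    {
        // Set the variable for this process before any trainer is fitted.
        Environment.SetEnvironmentVariable("MLNET_BACKEND", "ONEDAL");

        var value = Environment.GetEnvironmentVariable("MLNET_BACKEND");
        bool requested = OneDalRequested(value, RuntimeInformation.ProcessArchitecture);

        Console.WriteLine(requested
            ? "oneDAL path requested"
            : "default ML.NET path");
    }
}
```

Setting the variable in the shell that launches the sample should have the same effect; whether ML.NET re-reads it on every `Fit` call is not stated in the README.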
54 changes: 54 additions & 0 deletions docs/samples/Microsoft.ML.Samples.OneDal/Microsoft.ML.Samples.OneDal.csproj
@@ -0,0 +1,54 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net7.0</TargetFramework>
<!--This ensures that we can never make the mistake of adding this as a friend assembly. Please don't remove.-->
<PublicSign>false</PublicSign>
</PropertyGroup>

<!--
<ItemGroup>
<Reference Include="Microsoft.ML">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML/Debug/netstandard2.0/Microsoft.ML.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.Core">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML/Debug/netstandard2.0/Microsoft.ML.Core.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.Data">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML/Debug/netstandard2.0/Microsoft.ML.Data.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.Mkl.Components">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML.Mkl.Components/Debug/netstandard2.0/Microsoft.ML.Mkl.Components.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.DataView">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML/Debug/netstandard2.0/Microsoft.ML.DataView.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.Transforms">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML.Transforms/Debug/netstandard2.0/Microsoft.ML.Transforms.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.StandardTrainers">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML/Debug/netstandard2.0/Microsoft.ML.StandardTrainers.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.FastTree">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML.FastTree/Debug/netstandard2.0/Microsoft.ML.FastTree.dll</HintPath>
</Reference>
<Reference Include="Microsoft.ML.OneDal">
<HintPath>../machinelearning/artifacts/bin/Microsoft.ML.OneDal/Debug/netstandard2.0/Microsoft.ML.OneDal.dll</HintPath>
</Reference>
<PackageReference Include="Newtonsoft.Json" Version="13.0.1"/>
</ItemGroup>
-->

<ItemGroup>
<ProjectReference Include="..\..\..\src\Microsoft.ML.Transforms\Microsoft.ML.Transforms.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.Mkl.Components\Microsoft.ML.Mkl.Components.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.StandardTrainers\Microsoft.ML.StandardTrainers.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.FastTree\Microsoft.ML.FastTree.csproj" />
<ProjectReference Include="..\..\..\src\Microsoft.ML.OneDal\Microsoft.ML.OneDal.csproj" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.1"/>
</ItemGroup>

</Project>
203 changes: 203 additions & 0 deletions docs/samples/Microsoft.ML.Samples.OneDal/Program.cs
@@ -0,0 +1,203 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Newtonsoft.Json;

namespace Microsoft.ML.Samples.OneDal
{
class Program
{
public static IDataView[] LoadData(
MLContext mlContext, string trainingFile, string testingFile,
string task, string label = "target", char separator = ',')
{
List<IDataView> dataList = new List<IDataView>();
System.IO.StreamReader file = new System.IO.StreamReader(trainingFile);
string header = file.ReadLine();
file.Close();
string[] headerArray = header.Split(separator);
List<TextLoader.Column> columns = new List<TextLoader.Column>();
foreach (string column in headerArray)
{
if (column == label)
{
if (task == "binary")
columns.Add(new TextLoader.Column(column, DataKind.Boolean, Array.IndexOf(headerArray, column)));
else
columns.Add(new TextLoader.Column(column, DataKind.Single, Array.IndexOf(headerArray, column)));
}
else
{
columns.Add(new TextLoader.Column(column, DataKind.Single, Array.IndexOf(headerArray, column)));
}
}

var loader = mlContext.Data.CreateTextLoader(
separatorChar: separator,
hasHeader: true,
columns: columns.ToArray()
);
dataList.Add(loader.Load(trainingFile));
dataList.Add(loader.Load(testingFile));
return dataList.ToArray();
}

public static string[] GetFeaturesArray(IDataView data, string labelName = "target")
{
List<string> featuresList = new List<string>();
var nColumns = data.Schema.Count;
var columnsEnumerator = data.Schema.GetEnumerator();
for (int i = 0; i < nColumns; i++)
{
columnsEnumerator.MoveNext();
if (columnsEnumerator.Current.Name != labelName)
featuresList.Add(columnsEnumerator.Current.Name);
}

return featuresList.ToArray();
}

public static double[] RunRandomForestClassification(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName, int numberOfTrees, int numberOfLeaves)
{
var featuresArray = GetFeaturesArray(trainingData, labelName);
var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);
var preprocessedTrainingData = preprocessingPipeline.Fit(trainingData).Transform(trainingData);
var preprocessedTestingData = preprocessingPipeline.Fit(trainingData).Transform(testingData);

FastForestBinaryTrainer.Options options = new FastForestBinaryTrainer.Options();
options.LabelColumnName = labelName;
options.FeatureColumnName = "Features";
options.NumberOfTrees = numberOfTrees;
options.NumberOfLeaves = numberOfLeaves;
options.MinimumExampleCountPerLeaf = 5;
options.FeatureFraction = 1.0;

var trainer = mlContext.BinaryClassification.Trainers.FastForest(options);

ITransformer model = trainer.Fit(preprocessedTrainingData);

IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
var trainingMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(trainingPredictions, labelColumnName: labelName);
IDataView testingPredictions = model.Transform(preprocessedTestingData);
var testingMetrics = mlContext.BinaryClassification.EvaluateNonCalibrated(testingPredictions, labelColumnName: labelName);

double[] metrics = new double[4];
metrics[0] = trainingMetrics.Accuracy;
metrics[1] = testingMetrics.Accuracy;
metrics[2] = trainingMetrics.F1Score;
metrics[3] = testingMetrics.F1Score;
return metrics;
}

public static double[] RunRandomForestRegression(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName, int numberOfTrees, int numberOfLeaves)
{
var featuresArray = GetFeaturesArray(trainingData, labelName);
var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);
var preprocessedTrainingData = preprocessingPipeline.Fit(trainingData).Transform(trainingData);
var preprocessedTestingData = preprocessingPipeline.Fit(trainingData).Transform(testingData);

FastForestRegressionTrainer.Options options = new FastForestRegressionTrainer.Options();
options.LabelColumnName = labelName;
options.FeatureColumnName = "Features";
options.NumberOfTrees = numberOfTrees;
options.NumberOfLeaves = numberOfLeaves;
options.MinimumExampleCountPerLeaf = 5;
options.FeatureFraction = 1.0;

var trainer = mlContext.Regression.Trainers.FastForest(options);

ITransformer model = trainer.Fit(preprocessedTrainingData);

IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
var trainingMetrics = mlContext.Regression.Evaluate(trainingPredictions, labelColumnName: labelName);
IDataView testingPredictions = model.Transform(preprocessedTestingData);
var testingMetrics = mlContext.Regression.Evaluate(testingPredictions, labelColumnName: labelName);

double[] metrics = new double[4];
metrics[0] = trainingMetrics.RootMeanSquaredError;
metrics[1] = testingMetrics.RootMeanSquaredError;
metrics[2] = trainingMetrics.RSquared;
metrics[3] = testingMetrics.RSquared;
return metrics;
}

public static double[] RunOLSRegression(MLContext mlContext, IDataView trainingData, IDataView testingData, string labelName)
{
var featuresArray = GetFeaturesArray(trainingData, labelName);
var preprocessingPipeline = mlContext.Transforms.Concatenate("Features", featuresArray);
var preprocessedTrainingData = preprocessingPipeline.Fit(trainingData).Transform(trainingData);
var preprocessedTestingData = preprocessingPipeline.Fit(trainingData).Transform(testingData);

OlsTrainer.Options options = new OlsTrainer.Options();
options.LabelColumnName = labelName;
options.FeatureColumnName = "Features";

var trainer = mlContext.Regression.Trainers.Ols(options);

ITransformer model = trainer.Fit(preprocessedTrainingData);

IDataView trainingPredictions = model.Transform(preprocessedTrainingData);
var trainingMetrics = mlContext.Regression.Evaluate(trainingPredictions, labelColumnName: labelName);
IDataView testingPredictions = model.Transform(preprocessedTestingData);
var testingMetrics = mlContext.Regression.Evaluate(testingPredictions, labelColumnName: labelName);

double[] metrics = new double[4];
metrics[0] = trainingMetrics.RootMeanSquaredError;
metrics[1] = testingMetrics.RootMeanSquaredError;
metrics[2] = trainingMetrics.RSquared;
metrics[3] = testingMetrics.RSquared;
return metrics;
}

static void Main(string[] args)
{
// args[0] - training data filename
// args[1] - testing data filename
// args[2] - machine learning task (regression, binary)
// args[3] - machine learning algorithm (RandomForest, OLS)
// Random Forest parameters:
// args[4] - NumberOfTrees
// args[5] - NumberOfLeaves
var mlContext = new MLContext(seed: 42);
// data[0] - training subset
// data[1] - testing subset
IDataView[] data = LoadData(mlContext, args[0], args[1], args[2]);
string labelName = "target";

var mainWatch = System.Diagnostics.Stopwatch.StartNew();
double[] metrics;
if (args[3] == "RandomForest")
{
int numberOfTrees = Int32.Parse(args[4]);
int numberOfLeaves = Int32.Parse(args[5]);
if (args[2] == "binary")
{

metrics = RunRandomForestClassification(mlContext, data[0], data[1], labelName, numberOfTrees, numberOfLeaves);
mainWatch.Stop();
Console.WriteLine("algorithm,all workflow time[ms],training accuracy,testing accuracy,training F1 score,testing F1 score");
Console.WriteLine($"Random Forest Binary,{mainWatch.Elapsed.TotalMilliseconds},{metrics[0]},{metrics[1]},{metrics[2]},{metrics[3]}");
}
else
{
metrics = RunRandomForestRegression(mlContext, data[0], data[1], labelName, numberOfTrees, numberOfLeaves);
mainWatch.Stop();
Console.WriteLine("algorithm,all workflow time[ms],training RMSE,testing RMSE,training R2 score,testing R2 score");
Console.WriteLine($"Random Forest Regression,{mainWatch.Elapsed.TotalMilliseconds},{metrics[0]},{metrics[1]},{metrics[2]},{metrics[3]}");
}
}
else if (args[3] == "OLS")
{
metrics = RunOLSRegression(mlContext, data[0], data[1], labelName);
mainWatch.Stop();
Console.WriteLine("algorithm,all workflow time[ms],training RMSE,testing RMSE,training R2 score,testing R2 score");
Console.WriteLine($"OLS Regression,{mainWatch.Elapsed.TotalMilliseconds},{metrics[0]},{metrics[1]},{metrics[2]},{metrics[3]}");
}
}
}
}
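
A side note on `LoadData` above: it looks up each column's position with `Array.IndexOf`, which rescans the header for every column and returns the first match if a name is repeated. A minimal sketch of a single-pass alternative follows. `HeaderColumns.Describe` is a hypothetical helper that is not in this PR, and it returns plain tuples instead of `TextLoader.Column` so that it runs without ML.NET; the kind names `"Boolean"` and `"Single"` stand in for the `DataKind` values used in the sample.

```csharp
using System;
using System.Collections.Generic;

public static class HeaderColumns
{
    // Hypothetical helper (not part of the PR): walks the header once and
    // records each column's own index. The label column is typed "Boolean"
    // for the binary task and "Single" otherwise, mirroring LoadData's branches.
    public static List<(string Name, string Kind, int Index)> Describe(
        string header, string task, string label = "target", char separator = ',')
    {
        var columns = new List<(string Name, string Kind, int Index)>();
        string[] names = header.Split(separator);
        for (int i = 0; i < names.Length; i++)
        {
            bool isBinaryLabel = names[i] == label && task == "binary";
            columns.Add((names[i], isBinaryLabel ? "Boolean" : "Single", i));
        }
        return columns;
    }

    public static void Main()
    {
        // Prints one line per column: index, name and kind.
        foreach (var c in Describe("f0,f1,target", "binary"))
            Console.WriteLine($"{c.Index}: {c.Name} ({c.Kind})");
    }
}
```

In the sample itself, the same loop index could be passed straight to the `TextLoader.Column` constructor in place of the `Array.IndexOf` call.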

3 changes: 3 additions & 0 deletions docs/samples/Microsoft.ML.Samples.OneDal/requirements.txt
@@ -0,0 +1,3 @@
numpy
pandas
scikit-learn