Skip to content

Commit adad40c

Browse files
Add deterministic option for LightGBM (#7415)
* add deterministic option * updated core manifest * updated netfx core manifest
1 parent f7c9790 commit adad40c

File tree

4 files changed

+243
-1
lines changed

4 files changed

+243
-1
lines changed

src/Microsoft.ML.LightGbm/LightGbmTrainerBase.cs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,10 @@ public class OptionsBase : TrainerInputBaseWithGroupId
6161
{nameof(CategoricalSmoothing), "cat_smooth" },
6262
{nameof(L2CategoricalRegularization), "cat_l2" },
6363
{nameof(HandleMissingValue), "use_missing" },
64-
{nameof(UseZeroAsMissingValue), "zero_as_missing" }
64+
{nameof(UseZeroAsMissingValue), "zero_as_missing" },
65+
{nameof(Deterministic), "deterministic"},
66+
{nameof(ForceRowWise), "force_row_wise"},
67+
{nameof(ForceColumnWise), "force_col_wise"},
6568
};
6669

6770
internal string GetOptionName(string name)
@@ -236,6 +239,24 @@ private protected OptionsBase() { }
236239
[Argument(ArgumentType.AtMostOnce, HelpText = "Sets the random seed for LightGBM to use.")]
237240
public int? Seed;
238241

242+
/// <summary>
243+
/// Setting this to true should ensure the stable results when using the same data and the same parameters and different num_threads.
244+
/// </summary>
245+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to use deterministic algorithm.")]
246+
public bool Deterministic = false;
247+
248+
/// <summary>
249+
/// Whether to force column-wise histogram building.
250+
/// </summary>
251+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to force column-wise histogram building.")]
252+
public bool ForceColumnWise = false;
253+
254+
/// <summary>
255+
/// Whether to force row-wise histogram building.
256+
/// </summary>
257+
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether to force row-wise histogram building.")]
258+
public bool ForceRowWise = false;
259+
239260
[Argument(ArgumentType.Multiple, HelpText = "Parallel LightGBM Learning Algorithm", ShortName = "parag")]
240261
internal ISupportParallel ParallelTrainer = new SingleTrainerFactory();
241262

@@ -279,6 +300,9 @@ internal virtual Dictionary<string, object> ToDictionary(IHost host)
279300
res[GetOptionName(nameof(MaximumCategoricalSplitPointCount))] = MaximumCategoricalSplitPointCount;
280301
res[GetOptionName(nameof(CategoricalSmoothing))] = CategoricalSmoothing;
281302
res[GetOptionName(nameof(L2CategoricalRegularization))] = L2CategoricalRegularization;
303+
res[GetOptionName(nameof(Deterministic))] = Deterministic;
304+
res[GetOptionName(nameof(ForceColumnWise))] = ForceColumnWise;
305+
res[GetOptionName(nameof(ForceRowWise))] = ForceRowWise;
282306

283307
return res;
284308
}

test/BaselineOutput/Common/EntryPoints/core_manifest.json

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11891,6 +11891,33 @@
1189111891
"IsNullable": true,
1189211892
"Default": null
1189311893
},
11894+
{
11895+
"Name": "Deterministic",
11896+
"Type": "Bool",
11897+
"Desc": "Whether to use deterministic algorithm.",
11898+
"Required": false,
11899+
"SortOrder": 150.0,
11900+
"IsNullable": false,
11901+
"Default": false
11902+
},
11903+
{
11904+
"Name": "ForceColumnWise",
11905+
"Type": "Bool",
11906+
"Desc": "Whether to force column-wise histogram building.",
11907+
"Required": false,
11908+
"SortOrder": 150.0,
11909+
"IsNullable": false,
11910+
"Default": false
11911+
},
11912+
{
11913+
"Name": "ForceRowWise",
11914+
"Type": "Bool",
11915+
"Desc": "Whether to force row-wise histogram building.",
11916+
"Required": false,
11917+
"SortOrder": 150.0,
11918+
"IsNullable": false,
11919+
"Default": false
11920+
},
1189411921
{
1189511922
"Name": "ParallelTrainer",
1189611923
"Type": {
@@ -12410,6 +12437,33 @@
1241012437
"IsNullable": true,
1241112438
"Default": null
1241212439
},
12440+
{
12441+
"Name": "Deterministic",
12442+
"Type": "Bool",
12443+
"Desc": "Whether to use deterministic algorithm.",
12444+
"Required": false,
12445+
"SortOrder": 150.0,
12446+
"IsNullable": false,
12447+
"Default": false
12448+
},
12449+
{
12450+
"Name": "ForceColumnWise",
12451+
"Type": "Bool",
12452+
"Desc": "Whether to force column-wise histogram building.",
12453+
"Required": false,
12454+
"SortOrder": 150.0,
12455+
"IsNullable": false,
12456+
"Default": false
12457+
},
12458+
{
12459+
"Name": "ForceRowWise",
12460+
"Type": "Bool",
12461+
"Desc": "Whether to force row-wise histogram building.",
12462+
"Required": false,
12463+
"SortOrder": 150.0,
12464+
"IsNullable": false,
12465+
"Default": false
12466+
},
1241312467
{
1241412468
"Name": "ParallelTrainer",
1241512469
"Type": {
@@ -12929,6 +12983,33 @@
1292912983
"IsNullable": true,
1293012984
"Default": null
1293112985
},
12986+
{
12987+
"Name": "Deterministic",
12988+
"Type": "Bool",
12989+
"Desc": "Whether to use deterministic algorithm.",
12990+
"Required": false,
12991+
"SortOrder": 150.0,
12992+
"IsNullable": false,
12993+
"Default": false
12994+
},
12995+
{
12996+
"Name": "ForceColumnWise",
12997+
"Type": "Bool",
12998+
"Desc": "Whether to force column-wise histogram building.",
12999+
"Required": false,
13000+
"SortOrder": 150.0,
13001+
"IsNullable": false,
13002+
"Default": false
13003+
},
13004+
{
13005+
"Name": "ForceRowWise",
13006+
"Type": "Bool",
13007+
"Desc": "Whether to force row-wise histogram building.",
13008+
"Required": false,
13009+
"SortOrder": 150.0,
13010+
"IsNullable": false,
13011+
"Default": false
13012+
},
1293213013
{
1293313014
"Name": "ParallelTrainer",
1293413015
"Type": {
@@ -13409,6 +13490,33 @@
1340913490
"IsNullable": true,
1341013491
"Default": null
1341113492
},
13493+
{
13494+
"Name": "Deterministic",
13495+
"Type": "Bool",
13496+
"Desc": "Whether to use deterministic algorithm.",
13497+
"Required": false,
13498+
"SortOrder": 150.0,
13499+
"IsNullable": false,
13500+
"Default": false
13501+
},
13502+
{
13503+
"Name": "ForceColumnWise",
13504+
"Type": "Bool",
13505+
"Desc": "Whether to force column-wise histogram building.",
13506+
"Required": false,
13507+
"SortOrder": 150.0,
13508+
"IsNullable": false,
13509+
"Default": false
13510+
},
13511+
{
13512+
"Name": "ForceRowWise",
13513+
"Type": "Bool",
13514+
"Desc": "Whether to force row-wise histogram building.",
13515+
"Required": false,
13516+
"SortOrder": 150.0,
13517+
"IsNullable": false,
13518+
"Default": false
13519+
},
1341213520
{
1341313521
"Name": "ParallelTrainer",
1341413522
"Type": {

test/BaselineOutput/Common/EntryPoints/netcoreapp/core_manifest.json

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11891,6 +11891,33 @@
1189111891
"IsNullable": true,
1189211892
"Default": null
1189311893
},
11894+
{
11895+
"Name": "Deterministic",
11896+
"Type": "Bool",
11897+
"Desc": "Whether to use deterministic algorithm.",
11898+
"Required": false,
11899+
"SortOrder": 150.0,
11900+
"IsNullable": false,
11901+
"Default": false
11902+
},
11903+
{
11904+
"Name": "ForceColumnWise",
11905+
"Type": "Bool",
11906+
"Desc": "Whether to force column-wise histogram building.",
11907+
"Required": false,
11908+
"SortOrder": 150.0,
11909+
"IsNullable": false,
11910+
"Default": false
11911+
},
11912+
{
11913+
"Name": "ForceRowWise",
11914+
"Type": "Bool",
11915+
"Desc": "Whether to force row-wise histogram building.",
11916+
"Required": false,
11917+
"SortOrder": 150.0,
11918+
"IsNullable": false,
11919+
"Default": false
11920+
},
1189411921
{
1189511922
"Name": "ParallelTrainer",
1189611923
"Type": {
@@ -12410,6 +12437,33 @@
1241012437
"IsNullable": true,
1241112438
"Default": null
1241212439
},
12440+
{
12441+
"Name": "Deterministic",
12442+
"Type": "Bool",
12443+
"Desc": "Whether to use deterministic algorithm.",
12444+
"Required": false,
12445+
"SortOrder": 150.0,
12446+
"IsNullable": false,
12447+
"Default": false
12448+
},
12449+
{
12450+
"Name": "ForceColumnWise",
12451+
"Type": "Bool",
12452+
"Desc": "Whether to force column-wise histogram building.",
12453+
"Required": false,
12454+
"SortOrder": 150.0,
12455+
"IsNullable": false,
12456+
"Default": false
12457+
},
12458+
{
12459+
"Name": "ForceRowWise",
12460+
"Type": "Bool",
12461+
"Desc": "Whether to force row-wise histogram building.",
12462+
"Required": false,
12463+
"SortOrder": 150.0,
12464+
"IsNullable": false,
12465+
"Default": false
12466+
},
1241312467
{
1241412468
"Name": "ParallelTrainer",
1241512469
"Type": {
@@ -12929,6 +12983,33 @@
1292912983
"IsNullable": true,
1293012984
"Default": null
1293112985
},
12986+
{
12987+
"Name": "Deterministic",
12988+
"Type": "Bool",
12989+
"Desc": "Whether to use deterministic algorithm.",
12990+
"Required": false,
12991+
"SortOrder": 150.0,
12992+
"IsNullable": false,
12993+
"Default": false
12994+
},
12995+
{
12996+
"Name": "ForceColumnWise",
12997+
"Type": "Bool",
12998+
"Desc": "Whether to force column-wise histogram building.",
12999+
"Required": false,
13000+
"SortOrder": 150.0,
13001+
"IsNullable": false,
13002+
"Default": false
13003+
},
13004+
{
13005+
"Name": "ForceRowWise",
13006+
"Type": "Bool",
13007+
"Desc": "Whether to force row-wise histogram building.",
13008+
"Required": false,
13009+
"SortOrder": 150.0,
13010+
"IsNullable": false,
13011+
"Default": false
13012+
},
1293213013
{
1293313014
"Name": "ParallelTrainer",
1293413015
"Type": {
@@ -13409,6 +13490,33 @@
1340913490
"IsNullable": true,
1341013491
"Default": null
1341113492
},
13493+
{
13494+
"Name": "Deterministic",
13495+
"Type": "Bool",
13496+
"Desc": "Whether to use deterministic algorithm.",
13497+
"Required": false,
13498+
"SortOrder": 150.0,
13499+
"IsNullable": false,
13500+
"Default": false
13501+
},
13502+
{
13503+
"Name": "ForceColumnWise",
13504+
"Type": "Bool",
13505+
"Desc": "Whether to force column-wise histogram building.",
13506+
"Required": false,
13507+
"SortOrder": 150.0,
13508+
"IsNullable": false,
13509+
"Default": false
13510+
},
13511+
{
13512+
"Name": "ForceRowWise",
13513+
"Type": "Bool",
13514+
"Desc": "Whether to force row-wise histogram building.",
13515+
"Required": false,
13516+
"SortOrder": 150.0,
13517+
"IsNullable": false,
13518+
"Default": false
13519+
},
1341213520
{
1341313521
"Name": "ParallelTrainer",
1341413522
"Type": {

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ public void LightGBMBinaryEstimator()
6666
NumberOfLeaves = 10,
6767
MinimumExampleCountPerLeaf = 2,
6868
UnbalancedSets = false, // default value
69+
Deterministic = true,
70+
ForceRowWise = true
6971
});
7072

7173
var pipeWithTrainer = pipe.Append(trainer);

0 commit comments

Comments
 (0)