Skip to content

Commit eeba2ee

Browse files
authored
Bail to default implementation upon any unforeseen situation (#6538)
1 parent e451fb7 commit eeba2ee

File tree

4 files changed

+29
-14
lines changed

4 files changed

+29
-14
lines changed

src/Microsoft.ML.FastTree/RandomForestClassification.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,15 @@ public static extern unsafe int DecisionForestClassificationCompute(
265265
[BestFriend]
266266
private bool IsDispatchingToOneDalEnabled()
267267
{
268-
return OneDalUtils.IsDispatchingEnabled();
268+
try
269+
{
270+
return OneDalUtils.IsDispatchingEnabled();
271+
}
272+
catch (Exception)
273+
{
274+
// Bail to default implementation upon encountering any situation where dispatch failed
275+
return false;
276+
}
269277
}
270278

271279
[BestFriend]

src/Microsoft.ML.FastTree/RandomForestRegression.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,15 @@ public static extern unsafe int DecisionForestRegressionCompute(
398398
[BestFriend]
399399
private bool IsDispatchingToOneDalEnabled()
400400
{
401-
return OneDalUtils.IsDispatchingEnabled();
401+
try
402+
{
403+
return OneDalUtils.IsDispatchingEnabled();
404+
}
405+
catch (Exception)
406+
{
407+
// fall back to original implementation for any circumstance that prevents dispatching
408+
return false;
409+
}
402410
}
403411

404412
[BestFriend]

src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
using Microsoft.ML.Internal.Internallearn;
1515
using Microsoft.ML.Internal.Utilities;
1616
using Microsoft.ML.Model;
17+
using Microsoft.ML.OneDal;
1718
using Microsoft.ML.Runtime;
1819
using Microsoft.ML.Trainers;
19-
using Microsoft.ML.OneDal;
2020

2121
[assembly: LoadableClass(OlsTrainer.Summary, typeof(OlsTrainer), typeof(OlsTrainer.Options),
2222
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
@@ -409,7 +409,15 @@ private void ComputeMklRegression(IChannel ch, FloatLabelCursor.Factory cursorFa
409409
[BestFriend]
410410
private bool IsDispatchingToOneDalEnabled()
411411
{
412-
return OneDalUtils.IsDispatchingEnabled();
412+
try
413+
{
414+
return OneDalUtils.IsDispatchingEnabled();
415+
}
416+
catch (Exception)
417+
{
418+
// Bail to default implementation upon any situation that prevents dispatching
419+
return false;
420+
}
413421
}
414422

415423
private OlsModelParameters TrainCore(IChannel ch, FloatLabelCursor.Factory cursorFactory, int featureCount)

src/Microsoft.ML.OneDal/OneDalUtils.cs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using System.IO;
76
using System.Collections.Generic;
7+
using System.IO;
88
using System.Linq;
99
using System.Runtime.InteropServices;
1010
using Microsoft.ML.Internal.Utilities;
@@ -17,14 +17,6 @@ namespace Microsoft.ML.OneDal
1717
internal static class OneDalUtils
1818
{
1919

20-
#if false
21-
[BestFriend]
22-
internal static bool IsDispatchingEnabled()
23-
{
24-
return Environment.GetEnvironmentVariable("MLNET_BACKEND") == "ONEDAL" &&
25-
System.Runtime.InteropServices.RuntimeInformation.ProcessArchitecture == System.Runtime.InteropServices.Architecture.X64;
26-
}
27-
#else
2820
[BestFriend]
2921
internal static bool IsDispatchingEnabled()
3022
{
@@ -47,7 +39,6 @@ internal static bool IsDispatchingEnabled()
4739
}
4840
return false;
4941
}
50-
#endif
5142

5243
[BestFriend]
5344
internal static long GetTrainData(IChannel channel, FloatLabelCursor.Factory cursorFactory, ref List<float> featuresList, ref List<float> labelsList, int numberOfFeatures)

0 commit comments

Comments
 (0)