Skip to content

Add Seed property to MLContext and use as default for data splits #4775

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Microsoft.ML.Core/Data/IHostEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ internal interface ICancelable
bool IsCanceled { get; }
}

/// <summary>
/// A host environment that, in addition to the <see cref="IHostEnvironment"/> contract,
/// exposes an optional seed value through <see cref="Seed"/>.
/// </summary>
[BestFriend]
internal interface ISeededEnvironment : IHostEnvironment
{
/// <summary>
/// The seed that, if assigned, makes components requiring randomness behave deterministically.
/// The value is nullable and read-only on this interface; <c>null</c> means no seed was assigned.
/// </summary>
int? Seed { get; }
}

/// <summary>
/// A host is coupled to a component and provides random number generation and concurrency guidance.
/// Note that the random number generation, like the host environment methods, should be accessed only
Expand Down
19 changes: 1 addition & 18 deletions src/Microsoft.ML.Core/Environment/ConsoleEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -366,24 +366,7 @@ protected override void Dispose(bool disposing)
public ConsoleEnvironment(int? seed = null, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: this(RandomUtils.Create(seed), verbose, sensitivity, outWriter, errWriter, testWriter)
{
}

// REVIEW: do we really care about custom random? If we do, let's make this ctor public.
/// <summary>
/// Create an ML.NET environment for local execution, with console feedback.
/// </summary>
/// <param name="rand">A custom source of randomness to use in the environment.</param>
/// <param name="verbose">Set to <c>true</c> for fully verbose logging.</param>
/// <param name="sensitivity">Allowed message sensitivity.</param>
/// <param name="outWriter">Text writer to print normal messages to.</param>
/// <param name="errWriter">Text writer to print error messages to.</param>
/// <param name="testWriter">Optional TextWriter to write messages if the host is a test environment.</param>
private ConsoleEnvironment(Random rand, bool verbose = false,
MessageSensitivity sensitivity = MessageSensitivity.All,
TextWriter outWriter = null, TextWriter errWriter = null, TextWriter testWriter = null)
: base(rand, verbose, nameof(ConsoleEnvironment))
: base(seed, verbose, nameof(ConsoleEnvironment))
{
Contracts.CheckValueOrNull(outWriter);
Contracts.CheckValueOrNull(errWriter);
Expand Down
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Core/Environment/HostEnvironmentBase.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ internal interface IMessageSource
/// query progress.
/// </summary>
[BestFriend]
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, IHostEnvironment, IChannelProvider, ICancelable
internal abstract class HostEnvironmentBase<TEnv> : ChannelProviderBase, ISeededEnvironment, IChannelProvider, ICancelable
where TEnv : HostEnvironmentBase<TEnv>
{
void ICancelable.CancelExecution()
Expand Down Expand Up @@ -330,6 +330,9 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)

// The random number generator for this host.
private readonly Random _rand;

public int? Seed { get; }

// A dictionary mapping the type of message to the Dispatcher that gets the strongly typed dispatch delegate.
protected readonly ConcurrentDictionary<Type, Dispatcher> ListenerDict;

Expand All @@ -345,14 +348,14 @@ public void RemoveListener(Action<IMessageSource, TMessage> listenerFunc)
private readonly List<WeakReference<IHost>> _children;

/// <summary>
/// The main constructor.
/// The main constructor.
/// </summary>
protected HostEnvironmentBase(Random rand, bool verbose,
protected HostEnvironmentBase(int? seed, bool verbose,
string shortName = null, string parentFullName = null)
: base(shortName, parentFullName, verbose)
{
Contracts.CheckValueOrNull(rand);
_rand = rand ?? RandomUtils.Create();
Seed = seed;
_rand = RandomUtils.Create(Seed);
ListenerDict = new ConcurrentDictionary<Type, Dispatcher>();
ProgressTracker = new ProgressReporting.ProgressTracker(this);
_cancelLock = new object();
Expand Down
11 changes: 8 additions & 3 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,12 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("rand");
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
// build a single hash of it. If it is not, we generate a random number.
if (samplingKeyColumn == null)
{
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? host.Rand.Next()));
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? ((ISeededEnvironment)env).Seed));
}
else
{
Expand All @@ -518,7 +517,13 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
// instead of having two hash transformations.
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? host.Rand.Next()));
HashingEstimator.ColumnOptionsInternal columnOptions;
if (seed.HasValue)
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
else if (((ISeededEnvironment)env).Seed.HasValue)
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)((ISeededEnvironment)env).Seed.Value);
else
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
}
else
Expand Down
3 changes: 2 additions & 1 deletion src/Microsoft.ML.Data/MLContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Microsoft.ML
/// create components for data preparation, feature engineering, training, prediction, model evaluation.
/// It also allows logging, execution control, and the ability to set repeatable random numbers.
/// </summary>
public sealed class MLContext : IHostEnvironment
public sealed class MLContext : ISeededEnvironment
{
// REVIEW: consider making LocalEnvironment and MLContext the same class instead of encapsulation.
private readonly LocalEnvironment _env;
Expand Down Expand Up @@ -140,6 +140,7 @@ private void ProcessMessage(IMessageSource source, ChannelMessage message)
IChannel IChannelProvider.Start(string name) => _env.Start(name);
IPipe<TMessage> IChannelProvider.StartPipe<TMessage>(string name) => _env.StartPipe<TMessage>(name);
IProgressChannel IProgressChannelProvider.StartProgressChannel(string name) => _env.StartProgressChannel(name);
int? ISeededEnvironment.Seed => _env.Seed;

[BestFriend]
internal void CancelExecution() => ((ICancelable)_env).CancelExecution();
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Data/Utilities/LocalEnvironment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ protected override void Dispose(bool disposing)
/// </summary>
/// <param name="seed">Random seed. Set to <c>null</c> for a non-deterministic environment.</param>
public LocalEnvironment(int? seed = null)
: base(RandomUtils.Create(seed), verbose: false)
: base(seed, verbose: false)
{
}

Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public void AutoFitBinaryTest()
[Fact]
public void AutoFitMultiTest()
{
var context = new MLContext(1);
var context = new MLContext(42);
var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel);
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions);
var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath);
Expand Down