Skip to content

Data splits to default to MLContext seed when not specified #4764

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 4 additions & 8 deletions src/Microsoft.ML.Data/DataLoadSave/DataOperationsCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,14 @@ internal static IEnumerable<TrainTestData> CrossValidationSplit(IHostEnvironment
/// </summary>
internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDataView data, ref string samplingKeyColumn, int? seed = null)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("rand");
// We need to handle two cases: if samplingKeyColumn is provided, we use hashJoin to
// build a single hash of it. If it is not, we generate a random number.

if (samplingKeyColumn == null)
{
samplingKeyColumn = data.Schema.GetTempColumnName("SamplingKeyColumn");
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)seed);
data = new GenerateNumberTransform(env, data, samplingKeyColumn, (uint?)(seed ?? host.Rand.Next()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why fallback to host.Rand.Next() instead of the fixed seed given to the MLContext?

I think this is doing (1), then (4) from: #4752 (comment)

TrainTestSplit etc should listen to MLContext seed if not given by the user for the specific component.

I think the precedence for PRNG seeds should be:

  1. User seed specified for the component instance (or from parent component if hiding behind another)
  2. User seed specified in MLContext
  3. Default seed in component (unsure if we should remove these, might be a component-by-component discussion)
  4. Random (and more random than current -- see Use a GUID when creating the temp path #4645 (comment))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MLContext seed isn't exposed, i.e. I cannot directly use mlContext.Seed. The seed for Host.Rand is set from MLContext though.

}
else
{
Expand All @@ -517,11 +518,7 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
// instead of having two hash transformations.
var origStratCol = samplingKeyColumn;
samplingKeyColumn = data.Schema.GetTempColumnName(samplingKeyColumn);
HashingEstimator.ColumnOptionsInternal columnOptions;
if (seed.HasValue)
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)seed.Value);
else
columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30);
var columnOptions = new HashingEstimator.ColumnOptionsInternal(samplingKeyColumn, origStratCol, 30, (uint)(seed ?? host.Rand.Next()));
data = new HashingEstimator(env, columnOptions).Fit(data).Transform(data);
}
else
Expand All @@ -533,7 +530,6 @@ internal static void EnsureGroupPreservationColumn(IHostEnvironment env, ref IDa
data = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumnOptions(samplingKeyColumn, origStratCol, ensureZeroUntouched: true)).Fit(data).Transform(data);
}
}

}
}
}
Expand Down