Skip to content

Fixes #4385 about calling the Create methods when loading models from disk #4485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Nov 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 93 additions & 7 deletions src/Microsoft.ML.Core/ComponentModel/ComponentCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,54 @@

namespace Microsoft.ML.Runtime
{

internal static class Extension
{
internal static AccessModifier Accessmodifier(this MethodInfo methodInfo)
{
if (methodInfo.IsFamilyAndAssembly)
return AccessModifier.PrivateProtected;
if (methodInfo.IsPrivate)
return AccessModifier.Private;
if (methodInfo.IsFamily)
return AccessModifier.Protected;
if (methodInfo.IsFamilyOrAssembly)
return AccessModifier.ProtectedInternal;
if (methodInfo.IsAssembly)
return AccessModifier.Internal;
if (methodInfo.IsPublic)
return AccessModifier.Public;
throw new ArgumentException("Did not find access modifier", "methodInfo");
}

internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo)
{
if (constructorInfo.IsFamilyAndAssembly)
return AccessModifier.PrivateProtected;
if (constructorInfo.IsPrivate)
return AccessModifier.Private;
if (constructorInfo.IsFamily)
return AccessModifier.Protected;
if (constructorInfo.IsFamilyOrAssembly)
return AccessModifier.ProtectedInternal;
if (constructorInfo.IsAssembly)
return AccessModifier.Internal;
if (constructorInfo.IsPublic)
return AccessModifier.Public;
throw new ArgumentException("Did not find access modifier", "constructorInfo");
}

internal enum AccessModifier
{
PrivateProtected,
Private,
Protected,
ProtectedInternal,
Internal,
Public
}
}

/// <summary>
/// This catalogs instantiatable components (aka, loadable classes). Components are registered via
/// a descendant of <see cref="LoadableClassAttributeBase"/>, identifying the names and signature types under which the component
Expand Down Expand Up @@ -414,21 +462,59 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp
ctor = null;
create = null;
requireEnvironment = false;
bool requireEnvironmentCtor = false;
bool requireEnvironmentCreate = false;
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes);

if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null)
return true;
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)

// Find both 'ctor' and 'create' methods if available
if (instType.IsAssignableFrom(loaderType))
{
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null)
{
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null)
requireEnvironmentCtor = true;
}
}

if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null)
{
requireEnvironment = true;
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
requireEnvironmentCreate = true;
}

if (ctor != null && create != null)
{
Copy link
Member Author

@antoniovs1029 antoniovs1029 Nov 18, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't really like this nested if's, I am not sure if they're legible enough, and it is ambiguous what should happen if there's a 'protected' create or constructor method (which, I believe, never happens in the codebase...). Still, this gets the job done.

I can think of a couple of ways of making this, but not sure if they would be more legible. Please, let me know if I should rewrite this in another way. #Resolved

Copy link
Member Author

@antoniovs1029 antoniovs1029 Nov 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I've changed this nested if's in a new iteration (see this comment) but I am not sure if I prefer the nested of's or the new solution. #Resolved

// If both 'ctor' and 'create' methods were found
// Choose the one that is 'more' public
// If they have the same visibility, then throw an exception, since this shouldn't happen.

if (ctor.Accessmodifier() == create.Accessmodifier())
{
Copy link
Member Author

@antoniovs1029 antoniovs1029 Nov 19, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I changed the nested if's I had (link to comment) for this other solution using the Accessmodifier() extension method (as suggested by @yaeldekel ). Although I think this one is more legible, I wouldn't be sure if it's worth it to create the extension method only for this.... So let me know your opinions, Thanks! #Resolved

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me this seems cleaner.

A couple of more ways you can decrease the amount of "if/else"s in the code:

  1. you can put the assignment inside the if condition, like this:
if ((ctor = loaderType.GetConstructor(...)) == null)
  1. If you throw or return inside the "if", then you don't need the "else":
if (ctor.Accessmodifier() == create.Accessmodifier())
    throw ...
if (ctor.Accessmodifier() > create.Accessmodifier())
{
    ...
    return true
}
if (ctor.Accessmodifier() < create.Accessmodifier())
...
etc.

In reply to: 347688260 [](ancestors = 347688260)

throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them.");
}
if (ctor.Accessmodifier() > create.Accessmodifier())
{
create = null;
requireEnvironment = requireEnvironmentCtor;
return true;
}
ctor = null;
requireEnvironment = requireEnvironmentCreate;
return true;
}
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) != null)

if (ctor != null && create == null)
{
requireEnvironment = requireEnvironmentCtor;
return true;
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null)
}

if (ctor == null && create != null)
{
requireEnvironment = true;
requireEnvironment = requireEnvironmentCreate;
return true;
}

Expand Down
8 changes: 4 additions & 4 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ private SchemaBindableCalibratedModelParameters(IHostEnvironment env, ModelLoadC
_featureContribution = SubModel as IFeatureContributionMapper;
}

private static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
internal static CalibratedModelParametersBase Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
Expand Down Expand Up @@ -1224,7 +1224,7 @@ private NaiveCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(_binProbs.All(x => (0 <= x && x <= 1)));
}

private static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static NaiveCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -1675,7 +1675,7 @@ private PlattCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(FloatUtils.IsFinite(Offset));
}

private static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static PlattCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -1972,7 +1972,7 @@ private IsotonicCalibrator(IHostEnvironment env, ModelLoadContext ctx)
_host.CheckDecode(valuePrev <= 1);
}

private static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IsotonicCalibrator Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
}

// Factory method for SignatureLoadModel.
private static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FeatureContributionCalculatingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
ctx.CheckAtModel(GetVersionInfo());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public LabelIndicatorTransform(IHostEnvironment env,
{
}

public LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
internal LabelIndicatorTransform(IHostEnvironment env, Options options, IDataView input)
: base(env, LoadName, Contracts.CheckRef(options, nameof(options)).Columns,
input, TestIsMulticlassLabel)
{
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Data/Transforms/SkipTakeFilter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private SkipTakeFilter(long skip, long take, IHostEnvironment env, IDataView inp
/// <param name="env">Host Environment.</param>
/// <param name="options">Options for the skip operation.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
internal SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input)
: this(options.Count, Options.DefaultTake, env, input)
{
}
Expand All @@ -112,7 +112,7 @@ public SkipTakeFilter(IHostEnvironment env, SkipOptions options, IDataView input
/// <param name="env">Host Environment.</param>
/// <param name="options">Options for the take operation.</param>
/// <param name="input">Input <see cref="IDataView"/>.</param>
public SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
internal SkipTakeFilter(IHostEnvironment env, TakeOptions options, IDataView input)
: this(Options.DefaultSkip, options.Count, env, input)
{
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ private SlotsDroppingTransformer(IHostEnvironment env, ModelLoadContext ctx)
}

// Factory method for SignatureLoadModel.
private static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static SlotsDroppingTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
ctx.CheckAtModel(GetVersionInfo());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ private bool IsValid(IValueMapperDist mapper, out VectorDataViewType inputType)
}
}

private static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleDistributionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ private bool IsValid(IValueMapper mapper, out VectorDataViewType inputType)
}
}

private static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ private void InitializeMappers(out IValueMapper[] mappers, out VectorDataViewTyp
inputType = new VectorDataViewType(NumberDataViewType.Single);
}

private static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static EnsembleMulticlassModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRanking.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,7 +1170,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new FastTreeRankingModelParameters(env, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/FastTreeTweedie.cs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastTreeTweedieModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(GamBinaryModelParameters).Assembly.FullName);
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/GamRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(GamRegressionModelParameters).Assembly.FullName);
}

private static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static GamRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.FastTree/RandomForestRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.Writer.Write(_quantileSampleCount);
}

private static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static FastForestRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ private static VersionInfo GetVersionInfo()
loaderAssemblyName: typeof(TreeEnsembleFeaturizationTransformer).Assembly.FullName);
}

private static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static TreeEnsembleFeaturizationTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new TreeEnsembleFeaturizationTransformer(env, ctx);
}
}
2 changes: 1 addition & 1 deletion src/Microsoft.ML.KMeansClustering/KMeansModelParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
/// <summary>
/// This method is called by reflection to instantiate a predictor.
/// </summary>
private static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static KMeansModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmBinaryTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
internal static IPredictorProducing<float> Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmRankingTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static LightGbmRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static LightGbmRankingModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
return new LightGbmRankingModelParameters(env, ctx);
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.LightGbm/LightGbmRegressionTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
ctx.SetVersionInfo(GetVersionInfo());
}

private static LightGbmRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static LightGbmRegressionModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Mkl.Components/OlsLinearRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ private static void ProbCheckDecode(Double p)
Contracts.CheckDecode(0 <= p && p <= 1);
}

private static OlsModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static OlsModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.PCA/PcaTrainer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ private protected override void SaveCore(ModelSaveContext ctx)
writer.WriteSinglesNoCount(_eigenVectors[i].GetValues().Slice(0, _dimension));
}

private static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static PcaModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ private MatrixFactorizationModelParameters(IHostEnvironment env, ModelLoadContex
/// <summary>
/// Load model from the given context
/// </summary>
private static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
internal static MatrixFactorizationModelParameters Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
Expand Down Expand Up @@ -556,7 +556,7 @@ private static VersionInfo GetVersionInfo()
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(MatrixFactorizationPredictionTransformer).Assembly.FullName);
}
private static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
internal static MatrixFactorizationPredictionTransformer Create(IHostEnvironment env, ModelLoadContext ctx)
=> new MatrixFactorizationPredictionTransformer(env, ctx);

}
Expand Down
Loading