-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fixes #4385 about calling the Create methods when loading models from disk #4485
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d8e1dfc
e7f0910
1e5170a
b1e0761
a1a430d
dd62e51
e076101
9841ed5
87c1b5a
33fb93c
bb408f4
26609e8
ad833ad
88aef21
1dd3e33
d5a0b9a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,54 @@ | |
|
||
namespace Microsoft.ML.Runtime | ||
{ | ||
|
||
internal static class Extension | ||
{ | ||
internal static AccessModifier Accessmodifier(this MethodInfo methodInfo) | ||
{ | ||
if (methodInfo.IsFamilyAndAssembly) | ||
return AccessModifier.PrivateProtected; | ||
if (methodInfo.IsPrivate) | ||
return AccessModifier.Private; | ||
if (methodInfo.IsFamily) | ||
return AccessModifier.Protected; | ||
if (methodInfo.IsFamilyOrAssembly) | ||
return AccessModifier.ProtectedInternal; | ||
if (methodInfo.IsAssembly) | ||
return AccessModifier.Internal; | ||
if (methodInfo.IsPublic) | ||
return AccessModifier.Public; | ||
throw new ArgumentException("Did not find access modifier", "methodInfo"); | ||
} | ||
|
||
internal static AccessModifier Accessmodifier(this ConstructorInfo constructorInfo) | ||
{ | ||
if (constructorInfo.IsFamilyAndAssembly) | ||
return AccessModifier.PrivateProtected; | ||
if (constructorInfo.IsPrivate) | ||
return AccessModifier.Private; | ||
if (constructorInfo.IsFamily) | ||
return AccessModifier.Protected; | ||
if (constructorInfo.IsFamilyOrAssembly) | ||
return AccessModifier.ProtectedInternal; | ||
if (constructorInfo.IsAssembly) | ||
return AccessModifier.Internal; | ||
if (constructorInfo.IsPublic) | ||
return AccessModifier.Public; | ||
throw new ArgumentException("Did not find access modifier", "constructorInfo"); | ||
} | ||
|
||
internal enum AccessModifier | ||
{ | ||
PrivateProtected, | ||
Private, | ||
Protected, | ||
ProtectedInternal, | ||
Internal, | ||
Public | ||
} | ||
} | ||
|
||
/// <summary> | ||
/// This catalogs instantiatable components (aka, loadable classes). Components are registered via | ||
/// a descendant of <see cref="LoadableClassAttributeBase"/>, identifying the names and signature types under which the component | ||
|
@@ -414,21 +462,59 @@ private static bool TryGetIniters(Type instType, Type loaderType, Type[] parmTyp | |
ctor = null; | ||
create = null; | ||
requireEnvironment = false; | ||
bool requireEnvironmentCtor = false; | ||
bool requireEnvironmentCreate = false; | ||
var parmTypesWithEnv = Utils.Concat(new Type[1] { typeof(IHostEnvironment) }, parmTypes); | ||
|
||
if (Utils.Size(parmTypes) == 0 && (getter = FindInstanceGetter(instType, loaderType)) != null) | ||
return true; | ||
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) != null) | ||
return true; | ||
if (instType.IsAssignableFrom(loaderType) && (ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null) | ||
|
||
// Find both 'ctor' and 'create' methods if available | ||
if (instType.IsAssignableFrom(loaderType)) | ||
{ | ||
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypes ?? Type.EmptyTypes, null)) == null) | ||
{ | ||
if ((ctor = loaderType.GetConstructor(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic, null, parmTypesWithEnv ?? Type.EmptyTypes, null)) != null) | ||
requireEnvironmentCtor = true; | ||
} | ||
} | ||
|
||
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) == null) | ||
{ | ||
requireEnvironment = true; | ||
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null) | ||
requireEnvironmentCreate = true; | ||
} | ||
|
||
if (ctor != null && create != null) | ||
{ | ||
// If both 'ctor' and 'create' methods were found | ||
// Choose the one that is 'more' public | ||
// If they have the same visibility, then throw an exception, since this shouldn't happen. | ||
|
||
if (ctor.Accessmodifier() == create.Accessmodifier()) | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So I changed the nested if's I had (link to comment) for this other solution using the Accessmodifier() extension method (as suggested by @yaeldekel ). Although I think this one is more legible, I wouldn't be sure if it's worth it to create the extension method only for this.... So let me know your opinions, Thanks! #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To me this seems cleaner. A couple of more ways you can decrease the amount of "if/else"s in the code:
In reply to: 347688260 [](ancestors = 347688260) |
||
throw Contracts.Except($"Can't load type {instType}, because it has both create and constructor methods with the same visibility. Please indicate which one should be used by changing either the signature or the visibility of one of them."); | ||
} | ||
if (ctor.Accessmodifier() > create.Accessmodifier()) | ||
{ | ||
create = null; | ||
requireEnvironment = requireEnvironmentCtor; | ||
return true; | ||
} | ||
ctor = null; | ||
requireEnvironment = requireEnvironmentCreate; | ||
return true; | ||
} | ||
if ((create = FindCreateMethod(instType, loaderType, parmTypes ?? Type.EmptyTypes)) != null) | ||
|
||
if (ctor != null && create == null) | ||
{ | ||
requireEnvironment = requireEnvironmentCtor; | ||
return true; | ||
if ((create = FindCreateMethod(instType, loaderType, parmTypesWithEnv ?? Type.EmptyTypes)) != null) | ||
} | ||
|
||
if (ctor == null && create != null) | ||
{ | ||
requireEnvironment = true; | ||
requireEnvironment = requireEnvironmentCreate; | ||
return true; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't really like this nested if's, I am not sure if they're legible enough, and it is ambiguous what should happen if there's a 'protected' create or constructor method (which, I believe, never happens in the codebase...). Still, this gets the job done.
I can think of a couple of ways of making this, but not sure if they would be more legible. Please, let me know if I should rewrite this in another way. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I've changed this nested if's in a new iteration (see this comment) but I am not sure if I prefer the nested of's or the new solution. #Resolved