-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added the assembly name of the custom transform to the model file #4989
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@@ -192,7 +196,7 @@ public void RegisterTypeWithAttribute() | |||
Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); | |||
|
|||
// Build prediction engine from the trained pipeline. | |||
var engine = ML.Model.CreatePredictionEngine<AlienHero, SuperAlienHero>(model); | |||
var engine = ML.Model.CreatePredictionEngine<AlienHero, SuperAlienHero>(modelSaved); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be good to have this test both for the model before saving it, and for the model that is loaded from disk. Ideally both models should behave the same, but if a bug is introduced that only appears in one of both cases, it might be good to have both of the tests.
The InlineDataAttribute
could become handy for this to avoid code duplication, as used by the PermutationFeatureImportanceTests
:
machinelearning/test/Microsoft.ML.Tests/PermutationFeatureImportanceTests.cs
Lines 30 to 51 in cacc72f
[Theory] | |
[InlineData(true)] | |
[InlineData(false)] | |
public void TestPfiRegressionOnDenseFeatures(bool saveModel) | |
{ | |
var data = GetDenseDataset(); | |
var model = ML.Regression.Trainers.OnlineGradientDescent().Fit(data); | |
ImmutableArray<RegressionMetricsStatistics> pfi; | |
if(saveModel) | |
{ | |
var modelAndSchemaPath = GetOutputPath("TestPfiRegressionOnDenseFeatures.zip"); | |
ML.Model.Save(model, data.Schema, modelAndSchemaPath); | |
var loadedModel = ML.Model.Load(modelAndSchemaPath, out var schema); | |
var castedModel = loadedModel as RegressionPredictionTransformer<LinearRegressionModelParameters>; | |
pfi = ML.Regression.PermutationFeatureImportance(castedModel, data); | |
} | |
else | |
{ | |
pfi = ML.Regression.PermutationFeatureImportance(model, data); | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (!ex.IsMarked()) | ||
throw; | ||
} | ||
ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wonder, what happens (and what should happen) if the user actually registers the assembly as done here and then tries to load the model? Will it throw an exception, or will it work anyway with the fix on this PR? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will work. The assembly can be registered multiple times. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great.
If anything, I would recommend using the InlineData trick I mentioned in another comment to test both cases: when the user registers the assembly manually (even if it's not necessary) and when they don't register it. #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@@ -58,6 +60,7 @@ public sealed class CustomMappingTransformer<TSrc, TDst> : ITransformer | |||
: InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition); | |||
|
|||
_contractName = contractName; | |||
_contractAssembly = _mapAction.Method.DeclaringType.Assembly.FullName; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there any case where the loaded model would actually require having a different name registered from the "FullName" retrieved from here? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or any case where trying to access that member of _mapAction
would throw? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The call to Method can throw a MemberAccessException. But that would be up to the caller to fix in their code and the exception would help with that.
In reply to: 401790783 [](ancestors = 401790783)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should enforce the idea that the same transform that is used in training should be used in prediction as well. If they were to be different, then they are not the same pipelines and not the same models.
In reply to: 401431899 [](ancestors = 401431899)
In general, it LGTM, I just left some comments and also it would be good to update the docs and samples in here to say that it's no longer necessary for the user to register the assembly when loading back the model: |
@@ -3,7 +3,9 @@ | |||
// See the LICENSE file in the project root for more information. | |||
|
|||
using System; | |||
using System.Diagnostics.Contracts; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just curious, why do we need System.Diagnostics.Contracts
in here? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM It only needs to update the docs and the samples to show that it is no longer necessary to register the assembly
Fixes #4965
Right now when a custom transformer is used, the model file retains the name of the transformer as a string. This is not enough information to re-instantiate the custom transformer when the model is loaded from file. If the assembly containing the custom transformer is not already registered with the component catalog, then the model will fail to load.
To fix this, I have incremented the model version and am now saving the assembly name of the transform with the model.