diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs index f2d97b70eb..57b52c342c 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs @@ -45,6 +45,7 @@ public static void Example() // the custom action is defined needs to be registered in the // environment. The following registers the assembly where // IsUnderThirtyCustomAction is defined. + // This is necessary only in versions v1.5-preview2 and earlier mlContext.ComponentCatalog.RegisterAssembly(typeof( IsUnderThirtyCustomAction).Assembly); diff --git a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs index 819a518188..c918302d5e 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs @@ -25,7 +25,7 @@ public static class CustomMappingCatalog /// If the resulting transformer needs to be save-able, the class defining should implement /// and needs to be decorated with /// with the provided . - /// The assembly containing the class should be registered in the environment where it is loaded back + /// In versions v1.5-preview2 and earlier, the assembly containing the class should be registered in the environment where it is loaded back /// using . /// The contract name, used by ML.NET for loading the model. /// If is specified, resulting transformer would not be save-able. diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs index 7a781128a5..03c55c8c58 100644 --- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs +++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs @@ -22,6 +22,7 @@ public sealed class CustomMappingTransformer : ITransformer private readonly IHost _host; private readonly Action _mapAction; private readonly string _contractName; + private readonly string _contractAssembly; internal InternalSchemaDefinition AddedSchema { get; } internal SchemaDefinition InputSchemaDefinition { get; } @@ -58,6 +59,7 @@ internal CustomMappingTransformer(IHostEnvironment env, Action mapAc : InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition); _contractName = contractName; + _contractAssembly = _mapAction.Method.DeclaringType.Assembly.FullName; AddedSchema = outSchema; } @@ -67,7 +69,7 @@ internal void SaveModel(ModelSaveContext ctx) { if (_contractName == null) throw _host.Except("Empty contract name for a transform: the transform cannot be saved"); - LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName); + LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName, _contractAssembly); } /// diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs index 4ba05ee4f8..88de4a68e1 100644 --- a/src/Microsoft.ML.Transforms/LambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs @@ -4,6 +4,7 @@ using System; using System.IO; +using System.Reflection; using System.Text; using Microsoft.ML; using Microsoft.ML.Data; @@ -40,14 +41,17 @@ private static VersionInfo GetVersionInfo() { return new VersionInfo( modelSignature: "CUSTOMXF", - verWrittenCur: 0x00010001, - verReadableCur: 0x00010001, + //verWrittenCur: 0x00010001, // Initial + verWrittenCur: 0x00010002, // Added name of assembly in which the contractName is present + verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature, loaderAssemblyName: typeof(LambdaTransform).Assembly.FullName); } - internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName) + private const uint VerAssemblyNameSaved = 0x00010002; + + internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName, string contractAssembly) { ectx.CheckValue(ctx, nameof(ctx)); ectx.CheckValue(contractName, nameof(contractName)); @@ -56,6 +60,7 @@ internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveCont ctx.SetVersionInfo(GetVersionInfo()); ctx.SaveString(contractName); + ctx.SaveString(contractAssembly); } // Factory for SignatureLoadModel. @@ -66,6 +71,12 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx) ctx.CheckAtModel(GetVersionInfo()); var contractName = ctx.LoadString(); + if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved) + { + var contractAssembly = ctx.LoadString(); + Assembly assembly = Assembly.Load(contractAssembly); + env.ComponentCatalog.RegisterAssembly(assembly); + } object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName); if (!(factoryObject is ICustomMappingFactory mappingFactory)) diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs index b0660bfe08..3bfa01c6b7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.Linq; using Microsoft.ML.Data; using Microsoft.ML.Transforms; @@ -151,7 +152,7 @@ public SuperAlienHero() /// /// A mapping from to . It is used to create a - /// in . + /// in . /// [CustomMappingFactoryAttribute("LambdaAlienHero")] private class AlienFusionProcess : CustomMappingFactory @@ -171,8 +172,10 @@ public override Action GetMapping() } } - [Fact] - public void RegisterTypeWithAttribute() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void RegisterTypeWithAttribute(bool saveModel) { // Build in-memory data. var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) }; @@ -184,6 +187,13 @@ public void RegisterTypeWithAttribute() var tribeTransformed = model.Transform(tribeDataView); var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList(); + ITransformer modelForPrediction = model; + if (saveModel) + { + ML.Model.Save(model, tribeDataView.Schema, "customTransform.zip"); + modelForPrediction = ML.Model.Load("customTransform.zip", out var tribeDataViewSchema); + } + // Make sure the pipeline output is correct. Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name); Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age); @@ -192,7 +202,7 @@ public void RegisterTypeWithAttribute() Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount); // Build prediction engine from the trained pipeline. - var engine = ML.Model.CreatePredictionEngine(model); + var engine = ML.Model.CreatePredictionEngine(modelForPrediction); var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); var superAlien = engine.Predict(alien); @@ -202,6 +212,31 @@ public void RegisterTypeWithAttribute() Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); + + Done(); + } + + [Fact] + void TestCustomTransformBackcompat() + { + // With older versions, it is necessary to register the assembly + ML.ComponentCatalog.RegisterAssembly(typeof(AlienFusionProcess).Assembly); + + var modelPath = Path.Combine(DataDir, "backcompat", "customTransform.zip"); + var trainedModel = ML.Model.Load(modelPath, out var dataViewSchema); + + var engine = ML.Model.CreatePredictionEngine(trainedModel); + var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8); + var superAlien = engine.Predict(alien); + + // Make sure the prediction engine produces expected result. + Assert.Equal(superAlien.Name, "Super " + alien.Name); + Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age); + Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height); + Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight); + Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount); + + Done(); } [Fact] diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 08b7f38ee8..97f473dee6 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -1250,14 +1250,10 @@ public void TensorFlowStringTest() } [TensorFlowFact] + // This test hangs occasionally + [Trait("Category", "SkipInCI")] public void TensorFlowImageClassificationDefault() { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationDefault hangs on Linux."); - return; - } - string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs", "images"); @@ -1628,13 +1624,10 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule [TensorFlowTheory] [InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Accuracy)] [InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Loss)] + // This test hangs ocassionally + [Trait("Category", "SkipInCI")] public void TensorFlowImageClassificationEarlyStopping(ImageClassificationTrainer.EarlyStoppingMetric earlyStoppingMetric) { - if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) - { - Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationEarlyStopping hangs on Linux."); - return; - } string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs", "images"); diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs index c4d9378024..d6e58b5ebc 100644 --- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs @@ -44,8 +44,10 @@ public override Action GetMapping() } } - [Fact] - public void TestCustomTransformer() + [Theory] + [InlineData(true)] + [InlineData(false)] + public void TestCustomTransformer(bool registerAssembly) { string dataPath = GetDataPath("adult.tiny.with-schema.txt"); var source = new MultiFileSource(dataPath); @@ -62,17 +64,13 @@ public void TestCustomTransformer() var tempoEnv = new MLContext(1); var customEst = new CustomMappingEstimator(tempoEnv, MyLambda.MyAction, "MyLambda"); - try - { - TestEstimatorCore(customEst, data); - Assert.True(false, "Cannot work without RegisterAssembly"); - } - catch (InvalidOperationException ex) - { - if (!ex.IsMarked()) - throw; - } - ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); + // Before 1.5-preview3 it was required to register the assembly. + // Now, the assembly information is automatically saved in the model and the assembly is registered + // when loading. + // This tests the case that the CustomTransformer still works even if you explicitly register the assembly + if (registerAssembly) + ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly); + TestEstimatorCore(customEst, data); transformedData = customEst.Fit(data).Transform(data); diff --git a/test/data/backcompat/customTransform.zip b/test/data/backcompat/customTransform.zip new file mode 100644 index 0000000000..967a43f0d5 Binary files /dev/null and b/test/data/backcompat/customTransform.zip differ