diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs
index f2d97b70eb..57b52c342c 100644
--- a/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs
+++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Transforms/CustomMappingSaveAndLoad.cs
@@ -45,6 +45,7 @@ public static void Example()
// the custom action is defined needs to be registered in the
// environment. The following registers the assembly where
// IsUnderThirtyCustomAction is defined.
+ // This is necessary only in versions v1.5-preview2 and earlier
mlContext.ComponentCatalog.RegisterAssembly(typeof(
IsUnderThirtyCustomAction).Assembly);
diff --git a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs
index 819a518188..c918302d5e 100644
--- a/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs
+++ b/src/Microsoft.ML.Transforms/CustomMappingCatalog.cs
@@ -25,7 +25,7 @@ public static class CustomMappingCatalog
/// If the resulting transformer needs to be save-able, the class defining should implement
/// and needs to be decorated with
/// with the provided .
- /// The assembly containing the class should be registered in the environment where it is loaded back
+ /// In versions v1.5-preview2 and earlier, the assembly containing the class should be registered in the environment where it is loaded back
/// using .
/// The contract name, used by ML.NET for loading the model.
/// If is specified, resulting transformer would not be save-able.
diff --git a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs
index 7a781128a5..03c55c8c58 100644
--- a/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs
+++ b/src/Microsoft.ML.Transforms/CustomMappingTransformer.cs
@@ -22,6 +22,7 @@ public sealed class CustomMappingTransformer : ITransformer
private readonly IHost _host;
private readonly Action _mapAction;
private readonly string _contractName;
+ private readonly string _contractAssembly;
internal InternalSchemaDefinition AddedSchema { get; }
internal SchemaDefinition InputSchemaDefinition { get; }
@@ -58,6 +59,7 @@ internal CustomMappingTransformer(IHostEnvironment env, Action mapAc
: InternalSchemaDefinition.Create(typeof(TDst), outputSchemaDefinition);
_contractName = contractName;
+ _contractAssembly = _mapAction.Method.DeclaringType.Assembly.FullName;
AddedSchema = outSchema;
}
@@ -67,7 +69,7 @@ internal void SaveModel(ModelSaveContext ctx)
{
if (_contractName == null)
throw _host.Except("Empty contract name for a transform: the transform cannot be saved");
- LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName);
+ LambdaTransform.SaveCustomTransformer(_host, ctx, _contractName, _contractAssembly);
}
///
diff --git a/src/Microsoft.ML.Transforms/LambdaTransform.cs b/src/Microsoft.ML.Transforms/LambdaTransform.cs
index 4ba05ee4f8..88de4a68e1 100644
--- a/src/Microsoft.ML.Transforms/LambdaTransform.cs
+++ b/src/Microsoft.ML.Transforms/LambdaTransform.cs
@@ -4,6 +4,7 @@
using System;
using System.IO;
+using System.Reflection;
using System.Text;
using Microsoft.ML;
using Microsoft.ML.Data;
@@ -40,14 +41,17 @@ private static VersionInfo GetVersionInfo()
{
return new VersionInfo(
modelSignature: "CUSTOMXF",
- verWrittenCur: 0x00010001,
- verReadableCur: 0x00010001,
+ //verWrittenCur: 0x00010001, // Initial
+ verWrittenCur: 0x00010002, // Added name of assembly in which the contractName is present
+ verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature,
loaderAssemblyName: typeof(LambdaTransform).Assembly.FullName);
}
- internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName)
+ private const uint VerAssemblyNameSaved = 0x00010002;
+
+ internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveContext ctx, string contractName, string contractAssembly)
{
ectx.CheckValue(ctx, nameof(ctx));
ectx.CheckValue(contractName, nameof(contractName));
@@ -56,6 +60,7 @@ internal static void SaveCustomTransformer(IExceptionContext ectx, ModelSaveCont
ctx.SetVersionInfo(GetVersionInfo());
ctx.SaveString(contractName);
+ ctx.SaveString(contractAssembly);
}
// Factory for SignatureLoadModel.
@@ -66,6 +71,12 @@ private static ITransformer Create(IHostEnvironment env, ModelLoadContext ctx)
ctx.CheckAtModel(GetVersionInfo());
var contractName = ctx.LoadString();
+ if (ctx.Header.ModelVerWritten >= VerAssemblyNameSaved)
+ {
+ var contractAssembly = ctx.LoadString();
+ Assembly assembly = Assembly.Load(contractAssembly);
+ env.ComponentCatalog.RegisterAssembly(assembly);
+ }
object factoryObject = env.ComponentCatalog.GetExtensionValue(env, typeof(CustomMappingFactoryAttributeAttribute), contractName);
if (!(factoryObject is ICustomMappingFactory mappingFactory))
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs
index b0660bfe08..3bfa01c6b7 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
+using System.IO;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
@@ -151,7 +152,7 @@ public SuperAlienHero()
///
/// A mapping from to . It is used to create a
- /// in .
+ /// in .
///
[CustomMappingFactoryAttribute("LambdaAlienHero")]
private class AlienFusionProcess : CustomMappingFactory
@@ -171,8 +172,10 @@ public override Action GetMapping()
}
}
- [Fact]
- public void RegisterTypeWithAttribute()
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public void RegisterTypeWithAttribute(bool saveModel)
{
// Build in-memory data.
var tribe = new List() { new AlienHero("ML.NET", 2, 1000, 2000, 3000, 4000, 5000, 6000, 7000) };
@@ -184,6 +187,13 @@ public void RegisterTypeWithAttribute()
var tribeTransformed = model.Transform(tribeDataView);
var tribeEnumerable = ML.Data.CreateEnumerable(tribeTransformed, false).ToList();
+ ITransformer modelForPrediction = model;
+ if (saveModel)
+ {
+ ML.Model.Save(model, tribeDataView.Schema, "customTransform.zip");
+ modelForPrediction = ML.Model.Load("customTransform.zip", out var tribeDataViewSchema);
+ }
+
// Make sure the pipeline output is correct.
Assert.Equal(tribeEnumerable[0].Name, "Super " + tribe[0].Name);
Assert.Equal(tribeEnumerable[0].Merged.Age, tribe[0].One.Age + tribe[0].Two.Age);
@@ -192,7 +202,7 @@ public void RegisterTypeWithAttribute()
Assert.Equal(tribeEnumerable[0].Merged.HandCount, tribe[0].One.HandCount + tribe[0].Two.HandCount);
// Build prediction engine from the trained pipeline.
- var engine = ML.Model.CreatePredictionEngine(model);
+ var engine = ML.Model.CreatePredictionEngine(modelForPrediction);
var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8);
var superAlien = engine.Predict(alien);
@@ -202,6 +212,31 @@ public void RegisterTypeWithAttribute()
Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height);
Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight);
Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount);
+
+ Done();
+ }
+
+ [Fact]
+ void TestCustomTransformBackcompat()
+ {
+ // With older versions, it is necessary to register the assembly
+ ML.ComponentCatalog.RegisterAssembly(typeof(AlienFusionProcess).Assembly);
+
+ var modelPath = Path.Combine(DataDir, "backcompat", "customTransform.zip");
+ var trainedModel = ML.Model.Load(modelPath, out var dataViewSchema);
+
+ var engine = ML.Model.CreatePredictionEngine(trainedModel);
+ var alien = new AlienHero("TEN.LM", 1, 2, 3, 4, 5, 6, 7, 8);
+ var superAlien = engine.Predict(alien);
+
+ // Make sure the prediction engine produces expected result.
+ Assert.Equal(superAlien.Name, "Super " + alien.Name);
+ Assert.Equal(superAlien.Merged.Age, alien.One.Age + alien.Two.Age);
+ Assert.Equal(superAlien.Merged.Height, alien.One.Height + alien.Two.Height);
+ Assert.Equal(superAlien.Merged.Weight, alien.One.Weight + alien.Two.Weight);
+ Assert.Equal(superAlien.Merged.HandCount, alien.One.HandCount + alien.Two.HandCount);
+
+ Done();
}
[Fact]
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 08b7f38ee8..97f473dee6 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -1250,14 +1250,10 @@ public void TensorFlowStringTest()
}
[TensorFlowFact]
+ // This test hangs occasionally
+ [Trait("Category", "SkipInCI")]
public void TensorFlowImageClassificationDefault()
{
- if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- {
- Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationDefault hangs on Linux.");
- return;
- }
-
string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs",
"images");
@@ -1628,13 +1624,10 @@ internal void TensorFlowImageClassificationWithLRScheduling(LearningRateSchedule
[TensorFlowTheory]
[InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Accuracy)]
[InlineData(ImageClassificationTrainer.EarlyStoppingMetric.Loss)]
+ // This test hangs ocassionally
+ [Trait("Category", "SkipInCI")]
public void TensorFlowImageClassificationEarlyStopping(ImageClassificationTrainer.EarlyStoppingMetric earlyStoppingMetric)
{
- if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
- {
- Output.WriteLine("TODO TEST_STABILITY: TensorFlowImageClassificationEarlyStopping hangs on Linux.");
- return;
- }
string imagesDownloadFolderPath = Path.Combine(TensorFlowScenariosTestsFixture.assetsPath, "inputs",
"images");
diff --git a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
index c4d9378024..d6e58b5ebc 100644
--- a/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
+++ b/test/Microsoft.ML.Tests/Transformers/CustomMappingTests.cs
@@ -44,8 +44,10 @@ public override Action GetMapping()
}
}
- [Fact]
- public void TestCustomTransformer()
+ [Theory]
+ [InlineData(true)]
+ [InlineData(false)]
+ public void TestCustomTransformer(bool registerAssembly)
{
string dataPath = GetDataPath("adult.tiny.with-schema.txt");
var source = new MultiFileSource(dataPath);
@@ -62,17 +64,13 @@ public void TestCustomTransformer()
var tempoEnv = new MLContext(1);
var customEst = new CustomMappingEstimator(tempoEnv, MyLambda.MyAction, "MyLambda");
- try
- {
- TestEstimatorCore(customEst, data);
- Assert.True(false, "Cannot work without RegisterAssembly");
- }
- catch (InvalidOperationException ex)
- {
- if (!ex.IsMarked())
- throw;
- }
- ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly);
+ // Before 1.5-preview3 it was required to register the assembly.
+ // Now, the assembly information is automatically saved in the model and the assembly is registered
+ // when loading.
+ // This tests the case that the CustomTransformer still works even if you explicitly register the assembly
+ if (registerAssembly)
+ ML.ComponentCatalog.RegisterAssembly(typeof(MyLambda).Assembly);
+
TestEstimatorCore(customEst, data);
transformedData = customEst.Fit(data).Transform(data);
diff --git a/test/data/backcompat/customTransform.zip b/test/data/backcompat/customTransform.zip
new file mode 100644
index 0000000000..967a43f0d5
Binary files /dev/null and b/test/data/backcompat/customTransform.zip differ