Skip to content

Modified how DataViewTypes are registered #4187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Sep 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public AlienTypeAttributeAttribute(int raceId)
public override void Register()
{
DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId),
typeof(AlienBody), new[] { this });
typeof(AlienBody), this);
}

public override bool Equals(DataViewTypeAttribute other)
Expand Down
103 changes: 80 additions & 23 deletions src/Microsoft.ML.Data/Data/DataViewTypeManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,25 @@ public static class DataViewTypeManager
/// </summary>
internal static DataViewType GetDataViewType(Type type, IEnumerable<Attribute> typeAttributes = null)
{
//Filter attributes as we only care about DataViewTypeAttribute
DataViewTypeAttribute typeAttr = null;
if(typeAttributes != null)
{
typeAttributes = typeAttributes.Where(attr => attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute)));
if (typeAttributes.Count() > 1)
{
throw Contracts.ExceptParam(nameof(type), "Type {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
type.Name, typeAttributes, typeof(DataViewTypeAttribute));
}
else if (typeAttributes.Count() == 1)
{
typeAttr = typeAttributes.First() as DataViewTypeAttribute;
}
}
lock (_lock)
{
// Compute the ID of type with extra attributes.
var rawType = new TypeWithAttributes(type, typeAttributes);
var rawType = new TypeWithAttributes(type, typeAttr);

// Get the DataViewType's ID which typeID is mapped into.
if (!_rawTypeToDataViewTypeMap.TryGetValue(rawType, out DataViewType dataViewType))
Expand All @@ -73,10 +88,25 @@ internal static DataViewType GetDataViewType(Type type, IEnumerable<Attribute> t
/// </summary>
internal static bool Knows(Type type, IEnumerable<Attribute> typeAttributes = null)
{
//Filter attributes as we only care about DataViewTypeAttribute
DataViewTypeAttribute typeAttr = null;
if(typeAttributes != null)
{
typeAttributes = typeAttributes.Where(attr => attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute)));
if (typeAttributes.Count() > 1)
{
throw Contracts.ExceptParam(nameof(type), "Type {0} cannot be marked with multiple attributes, {1}, derived from {2}.",
type.Name, typeAttributes, typeof(DataViewTypeAttribute));
}
else if (typeAttributes.Count() == 1)
{
typeAttr = typeAttributes.First() as DataViewTypeAttribute;
}
}
lock (_lock)
{
// Compute the ID of type with extra attributes.
var rawType = new TypeWithAttributes(type, typeAttributes);
var rawType = new TypeWithAttributes(type, typeAttr);

// Check if this ID has been associated with a DataViewType.
// Note that the dictionary below contains (rawType, dataViewType) pairs (key type is TypeWithAttributes, and value type is DataViewType).
Expand Down Expand Up @@ -111,15 +141,47 @@ internal static bool Knows(DataViewType dataViewType)
/// <param name="type">Native type in C#.</param>
/// <param name="dataViewType">The corresponding type of <paramref name="type"/> in ML.NET's type system.</param>
/// <param name="typeAttributes">The <see cref="Attribute"/>s attached to <paramref name="type"/>.</param>
public static void Register(DataViewType dataViewType, Type type, IEnumerable<Attribute> typeAttributes = null)
[Obsolete("This API is depricated, please use the new form of Register which takes in a single DataViewTypeAttribute instead.", false)]
public static void Register(DataViewType dataViewType, Type type, IEnumerable<Attribute> typeAttributes)
{
DataViewTypeAttribute typeAttr = null;
if (typeAttributes != null)
{
if (typeAttributes.Count() > 1)
{
throw Contracts.ExceptParam(nameof(type), $"Type {type} has too many attributes.");
}
else if (typeAttributes.Count() == 1)
{
var attr = typeAttributes.First();
if (!attr.GetType().IsSubclassOf(typeof(DataViewTypeAttribute)))
{
throw Contracts.ExceptParam(nameof(type), $"Type {type} has an attribute that is not of DataViewTypeAttribute.");
}
else
{
typeAttr = attr as DataViewTypeAttribute;
}
}
}
Register(dataViewType, type, typeAttr);
}
/// <summary>
/// This function tells that <paramref name="dataViewType"/> should be representation of data in <paramref name="type"/> in
/// ML.NET's type system. The registered <paramref name="type"/> must be a standard C# object's type.
/// </summary>
/// <param name="type">Native type in C#.</param>
/// <param name="dataViewType">The corresponding type of <paramref name="type"/> in ML.NET's type system.</param>
/// <param name="typeAttribute">The <see cref="DataViewTypeAttribute"/> attached to <paramref name="type"/>.</param>
public static void Register(DataViewType dataViewType, Type type, DataViewTypeAttribute typeAttribute = null)
{
lock (_lock)
{
if (_bannedRawTypes.Contains(type))
throw Contracts.ExceptParam(nameof(type), $"Type {type} has been registered as ML.NET's default supported type, " +
$"so it can't not be registered again.");

var rawType = new TypeWithAttributes(type, typeAttributes);
var rawType = new TypeWithAttributes(type, typeAttribute);

if (_rawTypeToDataViewTypeMap.ContainsKey(rawType) && _rawTypeToDataViewTypeMap[rawType].Equals(dataViewType) &&
_dataViewTypeToRawTypeMap.ContainsKey(dataViewType) && _dataViewTypeToRawTypeMap[dataViewType].Equals(rawType))
Expand Down Expand Up @@ -152,7 +214,7 @@ public static void Register(DataViewType dataViewType, Type type, IEnumerable<At
}

/// <summary>
/// An instance of <see cref="TypeWithAttributes"/> represents an unique key of its <see cref="TargetType"/> and <see cref="_associatedAttributes"/>.
/// An instance of <see cref="TypeWithAttributes"/> represents an unique key of its <see cref="TargetType"/> and <see cref="_associatedAttribute"/>.
/// </summary>
private class TypeWithAttributes
{
Expand All @@ -162,16 +224,16 @@ private class TypeWithAttributes
public Type TargetType { get; }

/// <summary>
/// The underlying type's attributes. Together with <see cref="TargetType"/>, <see cref="_associatedAttributes"/> uniquely defines
/// The underlying type's attributes. Together with <see cref="TargetType"/>, <see cref="_associatedAttribute"/> uniquely defines
/// a key when using <see cref="TypeWithAttributes"/> as the key type in <see cref="Dictionary{TKey, TValue}"/>. Note that the
/// uniqueness is determined by <see cref="Equals(object)"/> and <see cref="GetHashCode"/> below.
/// </summary>
private IEnumerable<Attribute> _associatedAttributes;
private DataViewTypeAttribute _associatedAttribute;

public TypeWithAttributes(Type type, IEnumerable<Attribute> attributes)
public TypeWithAttributes(Type type, DataViewTypeAttribute attribute)
{
TargetType = type;
_associatedAttributes = attributes;
_associatedAttribute = attribute;
}

public override bool Equals(object obj)
Expand All @@ -183,22 +245,15 @@ public override bool Equals(object obj)
// Flag of having the attribute configurations.
var sameAttributeConfig = true;

if (_associatedAttributes == null && other._associatedAttributes == null)
if (_associatedAttribute == null && other._associatedAttribute == null)
sameAttributeConfig = true;
else if (_associatedAttributes == null && other._associatedAttributes != null)
else if (_associatedAttribute == null && other._associatedAttribute != null)
sameAttributeConfig = false;
else if (_associatedAttributes != null && other._associatedAttributes == null)
sameAttributeConfig = false;
else if (_associatedAttributes.Count() != other._associatedAttributes.Count())
else if (_associatedAttribute != null && other._associatedAttribute == null)
sameAttributeConfig = false;
else
{
var zipped = _associatedAttributes.Zip(other._associatedAttributes, (attr, otherAttr) => (attr, otherAttr));
foreach (var (attr, otherAttr) in zipped)
{
if (!attr.Equals(otherAttr))
sameAttributeConfig = false;
}
sameAttributeConfig = _associatedAttribute.Equals(other._associatedAttribute);
}

return sameType && sameAttributeConfig;
Expand All @@ -213,12 +268,14 @@ public override bool Equals(object obj)
/// </summary>
public override int GetHashCode()
{
if (_associatedAttributes == null)
if (_associatedAttribute == null)
return TargetType.GetHashCode();

var code = TargetType.GetHashCode();
foreach (var attr in _associatedAttributes)
code = Hashing.CombineHash(code, attr.GetHashCode());
if (_associatedAttribute != null)
{
code = Hashing.CombineHash(code, _associatedAttribute.GetHashCode());
}
return code;
}

Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.ImageAnalytics/ImageType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public override int GetHashCode()

public override void Register()
{
DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), new[] { this });
DataViewTypeManager.Register(new ImageDataViewType(Height, Width), typeof(Bitmap), this );
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxTransformer/OnnxMapType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ public sealed class OnnxMapType : StructuredDataViewType
/// <param name="valueType">Value type of the associated ONNX map.</param>
public OnnxMapType(Type keyType, Type valueType) : base(typeof(IDictionary<,>).MakeGenericType(keyType, valueType))
{
DataViewTypeManager.Register(this, RawType, new[] { new OnnxMapTypeAttribute(keyType, valueType) });
DataViewTypeManager.Register(this, RawType, new OnnxMapTypeAttribute(keyType, valueType));
}

public override bool Equals(DataViewType other)
Expand Down Expand Up @@ -95,7 +95,7 @@ public override void Register()
{
var enumerableType = typeof(IDictionary<,>);
var type = enumerableType.MakeGenericType(_keyType, _valueType);
DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, new[] { this });
DataViewTypeManager.Register(new OnnxMapType(_keyType, _valueType), type, this);
}
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.OnnxTransformer/OnnxSequenceType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ private static Type MakeNativeType(Type elementType)
/// <param name="elementType">The element type of a sequence.</param>
public OnnxSequenceType(Type elementType) : base(MakeNativeType(elementType))
{
DataViewTypeManager.Register(this, RawType, new[] { new OnnxSequenceTypeAttribute(elementType) });
DataViewTypeManager.Register(this, RawType, new OnnxSequenceTypeAttribute(elementType));
}

public override bool Equals(DataViewType other)
Expand Down Expand Up @@ -96,7 +96,7 @@ public override void Register()
{
var enumerableType = typeof(IEnumerable<>);
var type = enumerableType.MakeGenericType(_elemType);
DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, new[] { this });
DataViewTypeManager.Register(new OnnxSequenceType(_elemType), type, this);
}
}
}
24 changes: 20 additions & 4 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCustomTypeRegister.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public AlienTypeAttributeAttribute(int raceId)
/// </summary>
public override void Register()
{
DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), new[] { this });
DataViewTypeManager.Register(new DataViewAlienBodyType(RaceId), typeof(AlienBody), this);
}

public override bool Equals(DataViewTypeAttribute other)
Expand Down Expand Up @@ -243,7 +243,7 @@ public void TestTypeManager()
{
// "a" has been registered with AlienBody without any attribute, so the user can't
// register "a" again with AlienBody plus the attribute "c."
DataViewTypeManager.Register(a, typeof(AlienBody), new[] { c });
DataViewTypeManager.Register(a, typeof(AlienBody), c);
}
catch
{
Expand All @@ -268,14 +268,30 @@ public void TestTypeManager()
// Register a type with attribute.
var e = new DataViewAlienBodyType(7788);
var f = new AlienTypeAttributeAttribute(8877);
DataViewTypeManager.Register(e, typeof(AlienBody), new[] { f });
DataViewTypeManager.Register(e, typeof(AlienBody), f);
Assert.True(DataViewTypeManager.Knows(e));
Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f }));
Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new[] { f }));
// "e" is associated with typeof(AlienBody) with "f," so the call below should return true.
Assert.Equal(e, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f }));
// "a" is associated with typeof(AlienBody) without any attribute, so the call below should return false.
Assert.NotEqual(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new[] { f }));
}

[Fact]
public void GetTypeWithAdditionalDataViewTypeAttributes()
{
var a = new DataViewAlienBodyType(7788);
var b = new AlienTypeAttributeAttribute(8877);
var c = new ColumnNameAttribute("foo");
var d = new AlienTypeAttributeAttribute(8876);


DataViewTypeManager.Register(a, typeof(AlienBody), b);
Assert.True(DataViewTypeManager.Knows(a));
Assert.True(DataViewTypeManager.Knows(typeof(AlienBody), new Attribute[] { b, c }));
// "a" is associated with typeof(AlienBody) with "b," so the call below should return true.
Assert.Equal(a, DataViewTypeManager.GetDataViewType(typeof(AlienBody), new Attribute[] { b, c }));
Assert.Throws<ArgumentOutOfRangeException>(() => DataViewTypeManager.Knows(typeof(AlienBody), new Attribute[] { b, d }));
}
}
}
68 changes: 68 additions & 0 deletions test/Microsoft.ML.Tests/OnnxSequenceTypeWithAttributesTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Drawing;
using Microsoft.ML.Data;
using Microsoft.ML.RunTests;
using Microsoft.ML.Transforms.Image;
using Microsoft.ML.Transforms.Onnx;
using Xunit;
using Xunit.Abstractions;
using System.Linq;
using System.IO;
using Microsoft.ML.TestFramework.Attributes;

namespace Microsoft.ML.Tests
{
public class OnnxSequenceTypeWithAttributesTest : BaseTestBaseline
{
public class OutputObj
{
[ColumnName("output")]
[OnnxSequenceType(typeof(IDictionary<string, float>))]
public IEnumerable<IDictionary<string, float>> Output;
}
public class FloatInput
{
[ColumnName("input")]
[VectorType(3)]
public float[] Input { get; set; }
}

public OnnxSequenceTypeWithAttributesTest(ITestOutputHelper output) : base(output)
{
}
public static PredictionEngine<FloatInput, OutputObj> LoadModel(string onnxModelFilePath)
{
var ctx = new MLContext();
var dataView = ctx.Data.LoadFromEnumerable(new List<FloatInput>());

var pipeline = ctx.Transforms.ApplyOnnxModel(
modelFile: onnxModelFilePath,
outputColumnNames: new[] { "output" }, inputColumnNames: new[] { "input" });

var model = pipeline.Fit(dataView);
return ctx.Model.CreatePredictionEngine<FloatInput, OutputObj>(model);
}

[OnnxFact]
public void OnnxSequenceTypeWithColumnNameAttributeTest()
{
var modelFile = Path.Combine(Directory.GetCurrentDirectory(), "zipmap", "TestZipMapString.onnx");
var predictor = LoadModel(modelFile);

FloatInput input = new FloatInput() { Input = new float[] { 1.0f, 2.0f, 3.0f } };
var output = predictor.Predict(input);
var onnx_out = output.Output.FirstOrDefault();
Assert.True(onnx_out.Count == 3, "Output missing data.");
var keys = new List<string>(onnx_out.Keys);
for(var i =0; i < onnx_out.Count; ++i)
{
Assert.Equal(onnx_out[keys[i]], input.Input[i]);
}

}
}
}