Skip to content

Commit 034c81c

Browse files
committed
Make type registration framework
Register image type
1 parent f154bb0 commit 034c81c

File tree

7 files changed

+149
-27
lines changed

7 files changed

+149
-27
lines changed

src/Microsoft.ML.Data/Data/SchemaDefinition.cs

+30-25
Original file line numberDiff line numberDiff line change
@@ -392,37 +392,42 @@ public static SchemaDefinition Create(Type userType, Direction direction = Direc
392392

393393
InternalSchemaDefinition.GetVectorAndItemType(memberInfo, out bool isVector, out Type dataType);
394394

395-
PrimitiveDataViewType itemType;
396-
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
397-
if (keyAttr != null)
398-
{
399-
if (!KeyDataViewType.IsValidDataType(dataType))
400-
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
401-
if (keyAttr.KeyCount == null)
402-
itemType = new KeyDataViewType(dataType, dataType.ToMaxInt());
403-
else
404-
itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault());
405-
}
406-
else
407-
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType);
408-
409395
// Get the column type.
410396
DataViewType columnType;
411-
var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
412-
if (vectorAttr != null && !isVector)
413-
throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name);
414-
if (isVector)
397+
if (TypeManager.GetDataViewTypeOrNull(dataType) == null)
415398
{
416-
int[] dims = vectorAttr?.Dims;
417-
if (dims != null && dims.Any(d => d < 0))
418-
throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative");
419-
if (Utils.Size(dims) == 0)
420-
columnType = new VectorDataViewType(itemType, 0);
399+
PrimitiveDataViewType itemType;
400+
var keyAttr = memberInfo.GetCustomAttribute<KeyTypeAttribute>();
401+
if (keyAttr != null)
402+
{
403+
if (!KeyDataViewType.IsValidDataType(dataType))
404+
throw Contracts.ExceptParam(nameof(userType), "Member {0} marked with KeyType attribute, but does not appear to be a valid kind of data for a key type", memberInfo.Name);
405+
if (keyAttr.KeyCount == null)
406+
itemType = new KeyDataViewType(dataType, dataType.ToMaxInt());
407+
else
408+
itemType = new KeyDataViewType(dataType, keyAttr.KeyCount.Count.GetValueOrDefault());
409+
}
410+
else
411+
itemType = ColumnTypeExtensions.PrimitiveTypeFromType(dataType);
412+
413+
var vectorAttr = memberInfo.GetCustomAttribute<VectorTypeAttribute>();
414+
if (vectorAttr != null && !isVector)
415+
throw Contracts.ExceptParam(nameof(userType), $"Member {memberInfo.Name} marked with {nameof(VectorTypeAttribute)}, but does not appear to be a vector type", memberInfo.Name);
416+
if (isVector)
417+
{
418+
int[] dims = vectorAttr?.Dims;
419+
if (dims != null && dims.Any(d => d < 0))
420+
throw Contracts.ExceptParam(nameof(userType), "Some of member {0}'s dimension lengths are negative");
421+
if (Utils.Size(dims) == 0)
422+
columnType = new VectorDataViewType(itemType, 0);
423+
else
424+
columnType = new VectorDataViewType(itemType, dims);
425+
}
421426
else
422-
columnType = new VectorDataViewType(itemType, dims);
427+
columnType = itemType;
423428
}
424429
else
425-
columnType = itemType;
430+
columnType = TypeManager.GetDataViewTypeOrNull(dataType);
426431

427432
cols.Add(new Column(memberInfo.Name, columnType, name));
428433
}

src/Microsoft.ML.Data/DataView/InternalSchemaDefinition.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ public static void GetVectorAndItemType(Type rawType, string name, out bool isVe
187187

188188
if (itemType == typeof(string))
189189
itemType = typeof(ReadOnlyMemory<char>);
190-
else if (!itemType.TryGetDataKind(out _))
190+
else if (!itemType.TryGetDataKind(out _) && TypeManager.GetDataViewTypeOrNull(itemType) == null)
191191
throw Contracts.ExceptParam(nameof(rawType), "Could not determine an IDataView type for member {0}", name);
192192
}
193193

src/Microsoft.ML.Data/DataView/TypedCursor.cs

+4
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ private Action<TRow> GenerateSetter(DataViewRow input, int index, InternalSchema
319319

320320
del = CreateDirectSetter<int>;
321321
}
322+
else if (TypeManager.GetRawTypeOrNull(colType) != null)
323+
{
324+
del = CreateDirectSetter<int>;
325+
}
322326
else
323327
{
324328
// REVIEW: Is this even possible?

src/Microsoft.ML.Data/Utils/ApiUtils.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ private static OpCode GetAssignmentOpCode(Type t)
2323
if (t == typeof(ReadOnlyMemory<char>) || t == typeof(string) || t.IsArray ||
2424
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(VBuffer<>)) ||
2525
(t.IsGenericType && t.GetGenericTypeDefinition() == typeof(Nullable<>)) ||
26-
t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) || t == typeof(DataViewRowId))
26+
t == typeof(DateTime) || t == typeof(DateTimeOffset) || t == typeof(TimeSpan) ||
27+
t == typeof(DataViewRowId) || TypeManager.GetDataViewTypeOrNull(t) != null)
2728
{
2829
return OpCodes.Stobj;
2930
}
+59
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
using System;
2+
using System.Collections.Concurrent;
3+
using System.Collections.Generic;
4+
using Microsoft.ML.Internal.DataView;
5+
6+
namespace Microsoft.ML.Data
7+
{
8+
/// <summary>
/// Process-wide registry that associates user-defined raw .NET types with their
/// <see cref="DataViewType"/> and vice versa. Lookups are lock-free; registrations are
/// serialized so both directions of the mapping are checked and updated together.
/// </summary>
public static class TypeManager
{
    // Types have been used in ML.NET type systems. They can have multiple-to-one type mapping.
    // For example, UInt32 and Key can be mapped to uint. We enforce one-to-one mapping for all
    // user-registered types.
    private static readonly HashSet<Type> _notAllowedRawTypes;
    private static readonly ConcurrentDictionary<Type, DataViewType> _rawTypeToDataViewTypeMap;
    private static readonly ConcurrentDictionary<DataViewType, Type> _dataViewTypeToRawTypeMap;

    // Guards the check-then-insert sequence in Register so that two concurrent registrations
    // cannot both pass the duplicate checks and leave the two maps inconsistent.
    private static readonly object _registrationLock = new object();

    /// <summary>
    /// Constructor to initialize type mappings.
    /// </summary>
    static TypeManager()
    {
        _notAllowedRawTypes = new HashSet<Type>() {
            typeof(Boolean), typeof(SByte), typeof(Byte),
            typeof(Int16), typeof(UInt16), typeof(Int32), typeof(UInt32),
            typeof(Int64), typeof(UInt64), typeof(string), typeof(ReadOnlySpan<char>),
            // Text members are converted to ReadOnlyMemory<char> when a schema is inferred,
            // so that raw type must not be re-registered by users either.
            typeof(ReadOnlyMemory<char>)
        };
        _rawTypeToDataViewTypeMap = new ConcurrentDictionary<Type, DataViewType>();
        _dataViewTypeToRawTypeMap = new ConcurrentDictionary<DataViewType, Type>();
    }

    /// <summary>
    /// Returns the <see cref="DataViewType"/> registered for the raw type <paramref name="type"/>.
    /// </summary>
    /// <param name="type">The raw .NET type to look up.</param>
    /// <returns>The registered <see cref="DataViewType"/>, or null if <paramref name="type"/> has not been registered.</returns>
    public static DataViewType GetDataViewTypeOrNull(Type type)
    {
        // Single lookup instead of ContainsKey followed by the indexer.
        return _rawTypeToDataViewTypeMap.TryGetValue(type, out DataViewType dataViewType) ? dataViewType : null;
    }

    /// <summary>
    /// Returns the raw .NET type registered for the <see cref="DataViewType"/> <paramref name="type"/>.
    /// </summary>
    /// <param name="type">The <see cref="DataViewType"/> to look up.</param>
    /// <returns>The registered raw type, or null if <paramref name="type"/> has not been registered.</returns>
    public static Type GetRawTypeOrNull(DataViewType type)
    {
        // Single lookup instead of ContainsKey followed by the indexer.
        return _dataViewTypeToRawTypeMap.TryGetValue(type, out Type rawType) ? rawType : null;
    }

    /// <summary>
    /// Registers a one-to-one association between <paramref name="rawType"/> and <paramref name="dataViewType"/>.
    /// </summary>
    /// <param name="rawType">The raw .NET type. It must not be one of the reserved built-in raw types
    /// and must not have been registered before.</param>
    /// <param name="dataViewType">The <see cref="DataViewType"/> to associate with <paramref name="rawType"/>.
    /// It must not have been registered before.</param>
    public static void Register(Type rawType, DataViewType dataViewType)
    {
        // Reject nulls up front; otherwise a null dataViewType would be stored in the forward map
        // before the reverse map throws, leaving a half-completed registration behind.
        Contracts.CheckValue(rawType, nameof(rawType));
        Contracts.CheckValue(dataViewType, nameof(dataViewType));

        if (_notAllowedRawTypes.Contains(rawType))
            throw Contracts.ExceptParam(nameof(rawType), $"Type {rawType} has been registered as ML.NET's default type, " +
                $"so it can't be registered again.");

        lock (_registrationLock)
        {
            if (_rawTypeToDataViewTypeMap.TryGetValue(rawType, out DataViewType existingDataViewType))
                throw Contracts.ExceptParam(nameof(rawType), $"Repeated type registration. The raw type {rawType} " +
                    $"has been associated with {existingDataViewType}.");

            // Enforce the one-to-one mapping in the reverse direction as well, instead of silently
            // overwriting the raw type previously associated with this DataViewType.
            if (_dataViewTypeToRawTypeMap.TryGetValue(dataViewType, out Type existingRawType))
                throw Contracts.ExceptParam(nameof(dataViewType), $"Repeated type registration. The data view type {dataViewType} " +
                    $"has been associated with {existingRawType}.");

            _rawTypeToDataViewTypeMap[rawType] = dataViewType;
            _dataViewTypeToRawTypeMap[dataViewType] = rawType;
        }
    }
}
59+
}

src/Microsoft.ML.ImageAnalytics/ImageType.cs

+6
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ public sealed class ImageDataViewType : StructuredDataViewType
1313
{
1414
public readonly int Height;
1515
public readonly int Width;
16+
17+
static ImageDataViewType()
18+
{
19+
TypeManager.Register(typeof(Bitmap), new ImageDataViewType());
20+
}
21+
1622
public ImageDataViewType(int height, int width)
1723
: base(typeof(Bitmap))
1824
{

test/Microsoft.ML.Tests/ImagesTests.cs

+47
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,53 @@ public void TestGreyscaleTransformImages()
184184
Done();
185185
}
186186

187+
/// <summary>
/// Loads the sample images, converts them to grayscale, materializes the result as
/// in-memory <see cref="TransformedImageDataPoint"/> rows and verifies that every pixel of
/// every grayscale image has identical R, G and B components.
/// </summary>
[Fact]
public void TestGrayScaleInMemory()
{
    var imagesDataFile = SamplesUtils.DatasetUtils.DownloadImages();

    var data = ML.Data.CreateTextLoader(new TextLoader.Options()
    {
        Columns = new[]
        {
            new TextLoader.Column("ImagePath", DataKind.String, 0),
            new TextLoader.Column("Name", DataKind.String, 1),
        }
    }).Load(imagesDataFile);

    var imagesFolder = Path.GetDirectoryName(imagesDataFile);
    // Image loading and conversion pipeline.
    var pipeline = ML.Transforms.LoadImages("ImageObject", imagesFolder, "ImagePath")
        .Append(ML.Transforms.ConvertToGrayscale("Grayscale", "ImageObject"));

    var transformedData = pipeline.Fit(data).Transform(data);

    // The rows are materialized into a list, so each row needs its own object. Passing true
    // here would reuse a single row instance, and the list would hold N references to it,
    // meaning only the last image would actually be verified below.
    var transformedDataPoints = ML.Data.CreateEnumerable<TransformedImageDataPoint>(transformedData, false).ToList();

    foreach (var datapoint in transformedDataPoints)
    {
        var image = datapoint.Grayscale;
        Assert.NotNull(image);
        for (int x = 0; x < image.Width; x++)
        {
            for (int y = 0; y < image.Height; y++)
            {
                var pixel = image.GetPixel(x, y);
                // greyscale image has same values for R, G and B.
                Assert.True(pixel.R == pixel.G && pixel.G == pixel.B);
            }
        }
    }
}
225+
226+
/// <summary>
/// Row type used by TestGrayScaleInMemory to read the transformed data back into memory.
/// Property names match the column names used in that test's loader and pipeline.
/// </summary>
private class TransformedImageDataPoint
227+
{
228+
// Text column 0 of the loaded file; used as the input column of LoadImages.
public string ImagePath { get; set; }
229+
// Text column 1 of the loaded file.
public string Name { get; set; }
230+
// Output column of LoadImages and input column of ConvertToGrayscale.
public Bitmap ImageObject { get; set; }
231+
// Output column of ConvertToGrayscale; the image whose pixels the test inspects.
public Bitmap Grayscale { get; set; }
232+
}
233+
187234
[Fact]
188235
public void TestBackAndForthConversionWithAlphaInterleave()
189236
{

0 commit comments

Comments
 (0)