Skip to content

Adding a LoadColumnNameAttribute #4308

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 89 additions & 25 deletions src/Microsoft.ML.Experimental/DataLoadSave/Database/DatabaseLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,21 @@ internal static DatabaseLoader CreateDatabaseLoader<TInput>(IHostEnvironment hos
var column = new Column();
column.Name = mappingAttrName?.Name ?? memberInfo.Name;

var mappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
var indexMappingAttr = memberInfo.GetCustomAttribute<LoadColumnAttribute>();
var nameMappingAttr = memberInfo.GetCustomAttribute<LoadColumnNameAttribute>();

if (mappingAttr is object)
if (indexMappingAttr is object)
{
var sources = mappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
column.Source = sources;
if (nameMappingAttr is object)
{
throw Contracts.Except($"Cannot specify both {nameof(LoadColumnAttribute)} and {nameof(LoadColumnNameAttribute)}");
}

column.Source = indexMappingAttr.Sources.Select((source) => Range.FromTextLoaderRange(source)).ToArray();
}
else if (nameMappingAttr is object)
{
column.Source = nameMappingAttr.Sources.Select((source) => new Range(source)).ToArray();
}

InternalDataKind dk;
Expand Down Expand Up @@ -228,7 +237,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
public DbType Type = DbType.Single;

/// <summary>
/// Source index range(s) of the column.
/// Source index or name range(s) of the column.
/// </summary>
[Argument(ArgumentType.Multiple, HelpText = "Source index range(s) of the column", ShortName = "src")]
public Range[] Source;
Expand All @@ -241,7 +250,7 @@ public Column(string name, DbType dbType, Range[] source, KeyCount keyCount = nu
}

/// <summary>
/// Specifies the range of indices of input columns that should be mapped to an output column.
/// Specifies the range of indices or names of input columns that should be mapped to an output column.
/// </summary>
public sealed class Range
{
Expand All @@ -256,6 +265,19 @@ public Range(int index)
Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative");
Min = index;
Max = index;
Name = null;
}

/// <summary>
/// A range that selects a single input column by its name instead of by index.
/// Will result in a scalar column.
/// </summary>
/// <param name="name">The name of the field of the table to read. Checked with <c>Contracts.CheckValue</c> before it is stored.</param>
public Range(string name)
{
    Contracts.CheckValue(name, nameof(name));

    // A name-based range carries no index bounds; both are set to -1.
    Name = name;
    Min = Max = -1;
}

/// <summary>
Expand All @@ -278,15 +300,30 @@ public Range(int min, int max)
/// <summary>
/// The minimum index of the column, inclusive.
/// </summary>
/// <remarks>
/// This value is ignored if <see cref="Name" /> is not <c>null</c>.
/// </remarks>
[Argument(ArgumentType.Required, HelpText = "First index in the range")]
public int Min;

/// <summary>
/// The maximum index of the column, inclusive.
/// </summary>
/// <remarks>
/// This value is ignored if <see cref="Name" /> is not <c>null</c>.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Last index in the range")]
public int Max;

/// <summary>
/// The name of the input column.
/// </summary>
/// <remarks>
/// This value, if non-<c>null</c>, overrides <see cref="Min" /> and <see cref="Max" />.
/// </remarks>
[Argument(ArgumentType.AtMostOnce, HelpText = "Name of the column")]
public string Name;

/// <summary>
/// Force scalar columns to be treated as vectors of length one.
/// </summary>
Expand Down Expand Up @@ -318,17 +355,28 @@ public sealed class Options
/// </summary>
internal readonly struct Segment
{
    // Source column name; null when the segment is addressed by index.
    public readonly string Name;
    // Inclusive lower index bound; -1 when the segment is addressed by name.
    public readonly int Min;
    // Exclusive upper index bound; -1 when the segment is addressed by name.
    public readonly int Lim;
    // The forceVector value supplied by the caller, stored unchanged.
    public readonly bool ForceVector;

    /// <summary>
    /// Creates an index-based segment covering the half-open interval [min, lim).
    /// </summary>
    public Segment(int min, int lim, bool forceVector)
    {
        Contracts.Assert(min >= 0 & lim > min);
        ForceVector = forceVector;
        Lim = lim;
        Min = min;
        Name = null;
    }

    /// <summary>
    /// Creates a name-based segment; both index bounds are set to -1.
    /// </summary>
    public Segment(string name, bool forceVector)
    {
        Contracts.Assert(name is object);
        ForceVector = forceVector;
        Lim = -1;
        Min = -1;
        Name = name;
    }
}

/// <summary>
Expand Down Expand Up @@ -368,19 +416,23 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
if (segs != null)
{
var order = Utils.GetIdentityPermutation(segs.Length);
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));

// Check that the segments are disjoint.
for (int i = 1; i < order.Length; i++)
if ((segs.Length != 0) && (segs[0].Name is null))
{
int a = order[i - 1];
int b = order[i];
Contracts.Assert(segs[a].Min <= segs[b].Min);
if (segs[a].Lim > segs[b].Min)
Array.Sort(order, (x, y) => segs[x].Min.CompareTo(segs[y].Min));

// Check that the segments are disjoint.
for (int i = 1; i < order.Length; i++)
{
throw user ?
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
int a = order[i - 1];
int b = order[i];
Contracts.Assert(segs[a].Min <= segs[b].Min);
if (segs[a].Lim > segs[b].Min)
{
throw user ?
Contracts.ExceptUserArg(nameof(Column.Source), "Intervals specified for column '{0}' overlap", name) :
Contracts.ExceptDecode("Intervals specified for column '{0}' overlap", name);
}
}
}

Expand All @@ -389,7 +441,7 @@ public static ColInfo Create(string name, PrimitiveDataViewType itemType, Segmen
for (int i = 0; i < segs.Length; i++)
{
var seg = segs[i];
size += seg.Lim - seg.Min;
size += (seg.Name is null) ? seg.Lim - seg.Min : 1;
}
Contracts.Assert(size >= segs.Length);

Expand Down Expand Up @@ -454,15 +506,23 @@ public Bindings(DatabaseLoader parent, Column[] cols)
for (int i = 0; i < segs.Length; i++)
{
var range = col.Source[i];

int min = range.Min;
ch.CheckUserArg(0 <= min, nameof(range.Min));

Segment seg;

int max = range.Max;
ch.CheckUserArg(min <= max, nameof(range.Max));
seg = new Segment(min, max + 1, range.ForceVector);
if (range.Name is null)
{
int min = range.Min;
ch.CheckUserArg(0 <= min, nameof(range.Min));

int max = range.Max;
ch.CheckUserArg(min <= max, nameof(range.Max));
seg = new Segment(min, max + 1, range.ForceVector);
}
else
{
string columnName = range.Name;
ch.CheckUserArg(columnName != null, nameof(range.Name));
seg = new Segment(columnName, range.ForceVector);
}

segs[i] = seg;
}
Expand Down Expand Up @@ -490,6 +550,7 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
// ulong: count for key range
// int: number of segments
// foreach segment:
// string id: name
// int: min
// int: lim
// byte: force vector (verWrittenCur: verIsVectorSupported)
Expand Down Expand Up @@ -532,11 +593,12 @@ public Bindings(ModelLoadContext ctx, DatabaseLoader parent)
segs = new Segment[cseg];
for (int iseg = 0; iseg < cseg; iseg++)
{
    // Per-segment layout: name (string or null), min, lim, force-vector byte.
    string columnName = ctx.LoadStringOrNull();
    int min = ctx.Reader.ReadInt32();
    int lim = ctx.Reader.ReadInt32();
    bool forceVector = ctx.Reader.ReadBoolByte();

    if (columnName is null)
    {
        // Index-based segment: the bounds must describe a non-empty, non-negative interval.
        Contracts.CheckDecode(0 <= min && min < lim);
        segs[iseg] = new Segment(min, lim, forceVector);
    }
    else
    {
        // Name-based segments are constructed (and therefore saved) with Min = Lim = -1,
        // so the index bounds check must not be applied to them; doing so would reject
        // every saved loader that uses name-based columns.
        segs[iseg] = new Segment(columnName, forceVector);
    }
}
}

Expand All @@ -563,6 +625,7 @@ internal void Save(ModelSaveContext ctx)
// ulong: count for key range
// int: number of segments
// foreach segment:
// string id: name
// int: min
// int: lim
// byte: force vector (verWrittenCur: verIsVectorSupported)
Expand All @@ -588,6 +651,7 @@ internal void Save(ModelSaveContext ctx)
ctx.Writer.Write(info.Segments.Length);
foreach (var seg in info.Segments)
{
ctx.SaveStringOrNull(seg.Name);
ctx.Writer.Write(seg.Min);
ctx.Writer.Write(seg.Lim);
ctx.Writer.WriteBoolByte(seg.ForceVector);
Expand Down
Loading