Skip to content

added in support for System.DateTime type for the DateTimeTransformer #4661

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 139 additions & 80 deletions src/Microsoft.ML.Featurizers/DateTimeTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ internal DateTimeEstimator(IHostEnvironment env, Options options)

public DateTimeTransformer Fit(IDataView input)
{
return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.Country);
return new DateTimeTransformer(_host, _options.Source, _options.Prefix, _options.Country, input.Schema);
}

public SchemaShape GetOutputSchema(SchemaShape inputSchema)
Expand Down Expand Up @@ -246,14 +246,22 @@ public sealed class DateTimeTransformer : RowToRowTransformerBase, IDisposable
internal const string LoadName = "DateTimeTransform";
internal const string LoaderSignature = "DateTimeTransform";
private LongTypedColumn _column;
private DataViewSchema _schema;

#endregion

internal DateTimeTransformer(IHostEnvironment host, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country) :
internal DateTimeTransformer(IHostEnvironment host, string inputColumnName, string columnPrefix, DateTimeEstimator.HolidayList country, DataViewSchema schema) :
base(host.Register(nameof(DateTimeTransformer)))
{
host.Check(!CommonExtensions.OsIsCentOS7(), "CentOS7 is not supported");

_schema = schema;
if (_schema[inputColumnName].Type.RawType != typeof(long) &&
_schema[inputColumnName].Type.RawType != typeof(DateTime))
{
throw new Exception($"Unsupported type {_schema[inputColumnName].Type.RawType} for input column ${inputColumnName}. Only long and System.DateTime are supported");
}

_column = new LongTypedColumn(inputColumnName, columnPrefix);
_column.CreateTransformerFromEstimator(country);
}
Expand Down Expand Up @@ -346,6 +354,37 @@ protected override bool ReleaseHandle()

#region TimePoint

// Exact native representation
[StructLayout(LayoutKind.Sequential, Pack = 1)]
internal struct NativeTimePoint
{
public int Year;
public byte Month;
public byte Day;
public byte Hour;
public byte Minute;
public byte Second;
public byte AmPm;
public byte Hour12;
public byte DayOfWeek;
public byte DayOfQuarter;
public ushort DayOfYear;
public ushort WeekOfMonth;
public byte QuarterOfYear;
public byte HalfOfYear;
public byte WeekIso;
public int YearIso;
public IntPtr MonthLabelPointer;
public IntPtr MonthLabelSize;
public IntPtr AmPmLabelPointer;
public IntPtr AmPmLabelSize;
public IntPtr DayOfWeekLabelPointer;
public IntPtr DayOfWeekLabelSize;
public IntPtr HolidayNamePointer;
public IntPtr HolidayNameSize;
public byte IsPaidTimeOff;
}

[StructLayoutAttribute(LayoutKind.Sequential)]
internal struct TimePoint
{
Expand Down Expand Up @@ -443,9 +482,11 @@ private static unsafe string GetStringFromPointer(ref ReadOnlySpan<byte> rawData

};

#endregion
#endregion

#region ColumnInfo

#region BaseClass
#region BaseClass

internal delegate bool DestroyCppTransformerEstimator(IntPtr estimator, out IntPtr errorHandle);
internal delegate bool DestroyTransformerSaveData(IntPtr buffer, IntPtr bufferSize, out IntPtr errorHandle);
Expand All @@ -455,11 +496,16 @@ internal abstract class TypedColumn : IDisposable
{
internal readonly string Source;
internal readonly string Prefix;
internal readonly int IntPtrSize;
internal readonly int StructSize;

internal TypedColumn(string source, string prefix)
internal unsafe TypedColumn(string source, string prefix)
{
Source = source;
Prefix = prefix;
IntPtrSize = IntPtr.Size;

StructSize = sizeof(NativeTimePoint);
}

internal abstract void CreateTransformerFromEstimator(DateTimeEstimator.HolidayList country);
Expand Down Expand Up @@ -542,22 +588,17 @@ internal TypedColumn(string source, string prefix) :

}

#endregion
#endregion BaseClass

#region DateTimeTypedColumn
#region LongTypedColumn

internal sealed class LongTypedColumn : TypedColumn<long>
{
private TransformerEstimatorSafeHandle _transformerHandler;
private readonly int _intPtrSize;
private readonly int _structSize;

internal LongTypedColumn(string source, string prefix) :
base(source, prefix)
{
_intPtrSize = IntPtr.Size;

// The native struct is 25 bytes + 8 size_t.
_structSize = 25 + (_intPtrSize * 8);
}

[DllImport("Featurizers", EntryPoint = "DateTimeFeaturizer_CreateEstimator"), SuppressUnmanagedCodeSecurity]
Expand Down Expand Up @@ -601,10 +642,9 @@ internal override TimePoint Transform(long input)

using (var handler = new TransformedDataSafeHandle(output, DestroyTransformedDataNative))
{
// 29 plus size.
unsafe
{
return new TimePoint(new ReadOnlySpan<byte>(output.ToPointer(), _structSize), _intPtrSize);
return new TimePoint(new ReadOnlySpan<byte>(output.ToPointer(), StructSize), IntPtrSize);
}
}
}
Expand Down Expand Up @@ -633,18 +673,21 @@ private protected override bool CreateTransformerSaveDataHelper(out IntPtr buffe
CreateTransformerSaveDataNative(_transformerHandler, out buffer, out bufferSize, out errorHandle);
}

#endregion
#endregion LongTypedColumn

#endregion ColumnInfo

private sealed class Mapper : MapperBase
{

#region Class data members
#region Class data members
private static readonly DateTime _unixEpoch = new DateTime(1970, 1, 1);

private readonly DateTimeTransformer _parent;
private ConcurrentDictionary<long, TimePoint> _cache;
private ConcurrentQueue<long> _oldestKeys;

#endregion
#endregion

public Mapper(DateTimeTransformer parent, DataViewSchema inputSchema) :
base(parent.Host.Register(nameof(Mapper)), inputSchema, parent)
Expand All @@ -667,85 +710,101 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
return columns.ToArray();
}

private Delegate MakeGetter<T>(DataViewRow input, int iinfo)
private Delegate MakeGetter<TInput, TTransformed>(DataViewRow input, int iinfo)
{
// If already in posix time.
if (typeof(TInput) == typeof(long))
return MakeLongGetter<TTransformed>(input, iinfo);
// System.DateTime
else
return MakeDateTimeGetter<TTransformed>(input, iinfo);
}

private Delegate MakeLongGetter<TTransformed>(DataViewRow input, int iinfo)
{
var getter = input.GetGetter<long>(input.Schema[_parent._column.Source]);
ValueGetter<T> result = (ref T dst) =>
ValueGetter<TTransformed> result = (ref TTransformed dst) =>
{
long dateTime = default;
getter(ref dateTime);

if (!_cache.TryGetValue(dateTime, out TimePoint timePoint))
{
_cache[dateTime] = _parent._column.Transform(dateTime);
_oldestKeys.Enqueue(dateTime);
timePoint = _cache[dateTime];

// If more than 100 cached items, remove 20
if (_cache.Count > 100)
{
for (int i = 0; i < 20; i++)
{
long key;
while (!_oldestKeys.TryDequeue(out key)) { }
while (!_cache.TryRemove(key, out TimePoint removedValue)) { }
}
}
}
var timePoint = _parent._column.Transform(dateTime);

if (iinfo == 0)
dst = (T)Convert.ChangeType(timePoint.Year, typeof(T));
else if (iinfo == 1)
dst = (T)Convert.ChangeType(timePoint.Month, typeof(T));
else if (iinfo == 2)
dst = (T)Convert.ChangeType(timePoint.Day, typeof(T));
else if (iinfo == 3)
dst = (T)Convert.ChangeType(timePoint.Hour, typeof(T));
else if (iinfo == 4)
dst = (T)Convert.ChangeType(timePoint.Minute, typeof(T));
else if (iinfo == 5)
dst = (T)Convert.ChangeType(timePoint.Second, typeof(T));
else if (iinfo == 6)
dst = (T)Convert.ChangeType(timePoint.AmPm, typeof(T));
else if (iinfo == 7)
dst = (T)Convert.ChangeType(timePoint.Hour12, typeof(T));
else if (iinfo == 8)
dst = (T)Convert.ChangeType(timePoint.DayOfWeek, typeof(T));
else if (iinfo == 9)
dst = (T)Convert.ChangeType(timePoint.DayOfQuarter, typeof(T));
else if (iinfo == 10)
dst = (T)Convert.ChangeType(timePoint.DayOfYear, typeof(T));
else if (iinfo == 11)
dst = (T)Convert.ChangeType(timePoint.WeekOfMonth, typeof(T));
else if (iinfo == 12)
dst = (T)Convert.ChangeType(timePoint.QuarterOfYear, typeof(T));
else if (iinfo == 13)
dst = (T)Convert.ChangeType(timePoint.HalfOfYear, typeof(T));
else if (iinfo == 14)
dst = (T)Convert.ChangeType(timePoint.WeekIso, typeof(T));
else if (iinfo == 15)
dst = (T)Convert.ChangeType(timePoint.YearIso, typeof(T));
else if (iinfo == 16)
dst = (T)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(T));
else if (iinfo == 17)
dst = (T)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(T));
else if (iinfo == 18)
dst = (T)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(T));
else if (iinfo == 19)
dst = (T)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(T));
else
dst = (T)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(T));
dst = GetColumnFromStruct<TTransformed>(ref timePoint, iinfo);
};

return result;
}

private Delegate MakeDateTimeGetter<TTransformed>(DataViewRow input, int iinfo)
{
var getter = input.GetGetter<DateTime>(input.Schema[_parent._column.Source]);
ValueGetter<TTransformed> result = (ref TTransformed dst) =>
{
DateTime dateTime = default;
getter(ref dateTime);

var timePoint = _parent._column.Transform(dateTime.Subtract(_unixEpoch).Ticks / TimeSpan.TicksPerSecond);

dst = GetColumnFromStruct<TTransformed>(ref timePoint, iinfo);
};

return result;
}

private TTransformed GetColumnFromStruct<TTransformed>(ref TimePoint timePoint, int iinfo)
{
if (iinfo == 0)
return (TTransformed)Convert.ChangeType(timePoint.Year, typeof(TTransformed));
else if (iinfo == 1)
return (TTransformed)Convert.ChangeType(timePoint.Month, typeof(TTransformed));
else if (iinfo == 2)
return (TTransformed)Convert.ChangeType(timePoint.Day, typeof(TTransformed));
else if (iinfo == 3)
return (TTransformed)Convert.ChangeType(timePoint.Hour, typeof(TTransformed));
else if (iinfo == 4)
return (TTransformed)Convert.ChangeType(timePoint.Minute, typeof(TTransformed));
else if (iinfo == 5)
return (TTransformed)Convert.ChangeType(timePoint.Second, typeof(TTransformed));
else if (iinfo == 6)
return (TTransformed)Convert.ChangeType(timePoint.AmPm, typeof(TTransformed));
else if (iinfo == 7)
return (TTransformed)Convert.ChangeType(timePoint.Hour12, typeof(TTransformed));
else if (iinfo == 8)
return (TTransformed)Convert.ChangeType(timePoint.DayOfWeek, typeof(TTransformed));
else if (iinfo == 9)
return (TTransformed)Convert.ChangeType(timePoint.DayOfQuarter, typeof(TTransformed));
else if (iinfo == 10)
return (TTransformed)Convert.ChangeType(timePoint.DayOfYear, typeof(TTransformed));
else if (iinfo == 11)
return (TTransformed)Convert.ChangeType(timePoint.WeekOfMonth, typeof(TTransformed));
else if (iinfo == 12)
return (TTransformed)Convert.ChangeType(timePoint.QuarterOfYear, typeof(TTransformed));
else if (iinfo == 13)
return (TTransformed)Convert.ChangeType(timePoint.HalfOfYear, typeof(TTransformed));
else if (iinfo == 14)
return (TTransformed)Convert.ChangeType(timePoint.WeekIso, typeof(TTransformed));
else if (iinfo == 15)
return (TTransformed)Convert.ChangeType(timePoint.YearIso, typeof(TTransformed));
else if (iinfo == 16)
return (TTransformed)Convert.ChangeType(timePoint.MonthLabel.AsMemory(), typeof(TTransformed));
else if (iinfo == 17)
return (TTransformed)Convert.ChangeType(timePoint.AmPmLabel.AsMemory(), typeof(TTransformed));
else if (iinfo == 18)
return (TTransformed)Convert.ChangeType(timePoint.DayOfWeekLabel.AsMemory(), typeof(TTransformed));
else if (iinfo == 19)
return (TTransformed)Convert.ChangeType(timePoint.HolidayName.AsMemory(), typeof(TTransformed));
else
return (TTransformed)Convert.ChangeType(timePoint.IsPaidTimeOff, typeof(TTransformed));
}

protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, bool> activeOutput, out Action disposer)
{
disposer = null;

// Have to add 1 to iinfo since the enum starts at 1
return Utils.MarshalInvoke(MakeGetter<int>, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType(), input, iinfo);
return Utils.MarshalInvoke(MakeGetter<int, int>, new Type[] { input.Schema[_parent._column.Source].Type.RawType, ((DateTimeEstimator.ColumnsProduced)iinfo + 1).GetRawColumnType() }, input, iinfo);

}

private protected override Func<int, bool> GetDependenciesCore(Func<int, bool> activeOutput)
Expand Down
Loading