Skip to content

Commit 00d4ce7

Browse files
committed
Robust Scaler now added to the Normalizer catalog
1 parent c023271 commit 00d4ce7

File tree

8 files changed

+1336
-154
lines changed

8 files changed

+1336
-154
lines changed

src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,9 @@ private static class Defaults
152152
public const bool LogMeanVarCdf = true;
153153
public const int NumBins = 1024;
154154
public const int MinBinSize = 10;
155+
public const bool CenterData = true;
156+
public const int QuantileMin = 25;
157+
public const int QuantileMax = 75;
155158
}
156159

157160
public abstract class ControlZeroArgumentsBase : ArgumentsBase
@@ -245,6 +248,18 @@ public sealed class SupervisedBinArguments : BinArgumentsBase
245248
public int MinBinSize = Defaults.MinBinSize;
246249
}
247250

251+
/// <summary>
/// Arguments for the robust scaling normalizer: whether to center the data and which
/// quantile bounds (expressed as whole-number percentages) define the scaling range.
/// </summary>
public sealed class RobustScalingArguments : AffineArgumentsBase
{
    /// <summary>Whether the data should be centered around 0.</summary>
    [Argument(
        ArgumentType.AtMostOnce,
        Name = "CenterData",
        ShortName = "center",
        SortOrder = 1,
        HelpText = "Should the data be centered around 0")]
    public bool CenterData = Defaults.CenterData;

    /// <summary>Lower quantile bound, as a percentage (25 by default).</summary>
    [Argument(
        ArgumentType.AtMostOnce,
        Name = "QuantileMin",
        ShortName = "qmin",
        SortOrder = 2,
        HelpText = "Minimum quantile value. Defaults to 25")]
    public uint QuantileMin = Defaults.QuantileMin;

    /// <summary>Upper quantile bound, as a percentage (75 by default).</summary>
    [Argument(
        ArgumentType.AtMostOnce,
        Name = "QuantileMax",
        ShortName = "qmax",
        SortOrder = 3,
        HelpText = "Maximum quantile value. Defaults to 75")]
    public uint QuantileMax = Defaults.QuantileMax;
}
262+
248263
internal const string MinMaxNormalizerSummary = "Normalizes the data based on the observed minimum and maximum values of the data.";
249264
internal const string MeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the data.";
250265
internal const string LogMeanVarNormalizerSummary = "Normalizes the data based on the computed mean and variance of the logarithm of the data.";
@@ -1145,6 +1160,46 @@ public static int GetLabelColumnId(IExceptionContext host, DataViewSchema schema
11451160
return labelColumnId;
11461161
}
11471162
}
1163+
1164+
/// <summary>
/// Factory helpers that select the robust-scaling function builder matching the
/// type of the source column (scalar or known-size vector of Single or Double).
/// </summary>
internal static partial class RobustScaleUtils
{
    /// <summary>
    /// Translates entry <paramref name="icol"/> of <paramref name="args"/> into
    /// <see cref="NormalizingEstimator.RobustScalingColumnOptions"/> and delegates to the
    /// options-based overload.
    /// </summary>
    /// <param name="args">The robust scaling arguments; per-column settings fall back to the shared ones.</param>
    /// <param name="host">Host used for assertions and exceptions.</param>
    /// <param name="icol">Index into <c>args.Columns</c> of the column being built.</param>
    /// <param name="srcIndex">Index of the source column in the cursor's schema.</param>
    /// <param name="srcType">Type of the source column.</param>
    /// <param name="cursor">Cursor the value getter is obtained from.</param>
    public static IColumnFunctionBuilder CreateBuilder(RobustScalingArguments args, IHost host,
        int icol, int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
    {
        Contracts.AssertValue(host);
        host.AssertValue(args);

        var columnArgs = args.Columns[icol];
        var options = new NormalizingEstimator.RobustScalingColumnOptions(
            columnArgs.Name,
            columnArgs.Source ?? columnArgs.Name,
            columnArgs.MaximumExampleCount ?? args.MaximumExampleCount,
            args.CenterData,
            args.QuantileMin,
            args.QuantileMax);

        return CreateBuilder(options, host, srcIndex, srcType, cursor);
    }

    /// <summary>
    /// Creates the function builder for <paramref name="column"/> based on <paramref name="srcType"/>.
    /// </summary>
    /// <param name="column">The robust scaling options for the column.</param>
    /// <param name="host">Host used to raise the exception for unsupported types.</param>
    /// <param name="srcIndex">Index of the source column in the cursor's schema.</param>
    /// <param name="srcType">Type of the source column.</param>
    /// <param name="cursor">Cursor the value getter is obtained from.</param>
    /// <returns>A Single or Double builder, in scalar or vector flavor.</returns>
    public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.RobustScalingColumnOptions column, IHost host,
        int srcIndex, DataViewType srcType, DataViewRowCursor cursor)
    {
        var srcColumn = cursor.Schema[srcIndex];

        if (srcType is NumberDataViewType)
        {
            // Scalar numeric column.
            if (srcType == NumberDataViewType.Single)
            {
                return Sng.RobustScalerOneColumnFunctionBuilder.Create(
                    column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax,
                    cursor.GetGetter<float>(srcColumn));
            }

            if (srcType == NumberDataViewType.Double)
            {
                return Dbl.RobustScalerOneColumnFunctionBuilder.Create(
                    column, host, srcType, column.CenterData, column.QuantileMin, column.QuantileMax,
                    cursor.GetGetter<double>(srcColumn));
            }
        }

        if (srcType is VectorDataViewType vectorType && vectorType.IsKnownSize && vectorType.ItemType is NumberDataViewType)
        {
            // Known-size numeric vector column; the pattern variable is already the typed vector.
            var itemType = vectorType.ItemType;

            if (itemType == NumberDataViewType.Single)
            {
                return Sng.RobustScalerVecFunctionBuilder.Create(
                    column, host, vectorType, column.CenterData, column.QuantileMin, column.QuantileMax,
                    cursor.GetGetter<VBuffer<float>>(srcColumn));
            }

            if (itemType == NumberDataViewType.Double)
            {
                return Dbl.RobustScalerVecFunctionBuilder.Create(
                    column, host, vectorType, column.CenterData, column.QuantileMin, column.QuantileMax,
                    cursor.GetGetter<VBuffer<double>>(srcColumn));
            }
        }

        throw host.ExceptParam(nameof(srcType), "Wrong column type for input column. Expected: Single, Double, Vector of Single or Vector of Double. Got: {0}.", srcType.ToString());
    }
}
11481203
}
11491204

11501205
internal static partial class AffineNormSerializationUtils

src/Microsoft.ML.Data/Transforms/NormalizeColumnDbl.cs

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,117 @@ private void Update(int j, TFloat origVal)
516516
}
517517
}
518518

519+
[BestFriend]
internal static partial class MedianAggregatorUtils
{
    /// <summary>
    /// Folds one more number into a running median that is maintained with two heaps,
    /// following the approach described at
    /// https://www.geeksforgeeks.org/median-of-stream-of-integers-running-integers/.
    /// The heaps are kept within one element of each other in size; with an odd total the
    /// median is the root of the larger heap, with an even total it is the mean of both roots.
    /// </summary>
    /// <param name="num">The new number to account for in our median calculation.</param>
    /// <param name="median">The current median; overwritten with the updated median.</param>
    /// <param name="belowMedianHeap">The MaxHeap that has all the numbers below the median.</param>
    /// <param name="aboveMedianHeap">The MinHeap that has all the numbers above the median.</param>
    [BestFriend]
    internal static void GetMedianSoFar(in double num, ref double median, ref MaxHeap<double> belowMedianHeap, ref MinHeap<double> aboveMedianHeap)
    {
        // Negative: aboveMedianHeap is larger. Zero: same size. Positive: belowMedianHeap is larger.
        int balance = belowMedianHeap.Count().CompareTo(aboveMedianHeap.Count());

        // Decided against the median as it was before this update.
        bool belongsBelow = num < median;

        if (balance == 0)
        {
            // Heaps are the same size: simply put the number where it belongs. The total becomes
            // odd, so the median is the root of whichever heap just grew.
            if (belongsBelow)
            {
                belowMedianHeap.Add(num);
                median = belowMedianHeap.Peek();
            }
            else
            {
                aboveMedianHeap.Add(num);
                median = aboveMedianHeap.Peek();
            }

            return;
        }

        if (balance < 0)
        {
            // aboveMedianHeap holds one more element than belowMedianHeap.
            if (belongsBelow)
            {
                belowMedianHeap.Add(num);
            }
            else
            {
                // Move one element over first so the heaps end up balanced.
                belowMedianHeap.Add(aboveMedianHeap.Pop());
                aboveMedianHeap.Add(num);
            }
        }
        else
        {
            // belowMedianHeap holds one more element than aboveMedianHeap.
            if (belongsBelow)
            {
                // Move one element over first so the heaps end up balanced.
                aboveMedianHeap.Add(belowMedianHeap.Pop());
                belowMedianHeap.Add(num);
            }
            else
            {
                aboveMedianHeap.Add(num);
            }
        }

        // Both heaps now have the same size, so the median is the average of the two roots.
        median = (aboveMedianHeap.Peek() + belowMedianHeap.Peek()) / 2;
    }
}
594+
595+
/// <summary>
/// Tracks the running median of the double values passed to <see cref="ProcessValue"/>.
/// The median is updated on every call, using one heap for the values below the median
/// and one for the values above it.
/// NaN values are ignored: every ordering comparison against NaN is false, so a NaN stored
/// in either heap would break the heap ordering and turn the reported median into NaN.
/// </summary>
internal sealed class MedianDblAggregator : IColumnAggregator<double>
{
    // Not readonly: both heaps are handed to GetMedianSoFar by ref.
    private MedianAggregatorUtils.MaxHeap<double> _belowMedianHeap;
    private MedianAggregatorUtils.MinHeap<double> _aboveMedianHeap;
    private double _median;

    /// <summary>
    /// Creates an aggregator with empty heaps.
    /// </summary>
    /// <param name="contatinerStartingSize">Size passed to both heap constructors; must be positive.</param>
    public MedianDblAggregator(int contatinerStartingSize = 1000)
    {
        Contracts.Check(contatinerStartingSize > 0);
        _belowMedianHeap = new MedianAggregatorUtils.MaxHeap<double>(contatinerStartingSize);
        _aboveMedianHeap = new MedianAggregatorUtils.MinHeap<double>(contatinerStartingSize);
        _median = default;
    }

    /// <summary>
    /// The median of the non-NaN values processed so far (0 if none have been processed).
    /// </summary>
    public double Median
    {
        get { return _median; }
    }

    /// <summary>
    /// Accounts for <paramref name="value"/> in the running median. NaN is skipped.
    /// </summary>
    public void ProcessValue(in double value)
    {
        // A NaN must never reach the heaps; see the class remarks.
        if (double.IsNaN(value))
            return;

        MedianAggregatorUtils.GetMedianSoFar(value, ref _median, ref _belowMedianHeap, ref _aboveMedianHeap);
    }

    /// <summary>
    /// No-op: the median is kept up to date by every <see cref="ProcessValue"/> call.
    /// </summary>
    public void Finish()
    {
    }
}
629+
519630
internal sealed partial class NormalizeTransform
520631
{
521632
internal abstract partial class AffineColumnFunction
@@ -1912,6 +2023,144 @@ public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinni
19122023
return new SupervisedBinVecColumnFunctionBuilder(host, lim, fix, numBins, column.MininimumBinSize, valueColumnId, labelColumnId, dataRow);
19132024
}
19142025
}
2026+
2027+
/// <summary>
/// Function builder for robust scaling of a scalar double column. While the data is scanned it
/// tracks the column's min, max and median; <see cref="CreateColumnFunction"/> then produces an
/// affine function whose scale is 1 / ((max - min) * (quantileMax - quantileMin) / 100) and whose
/// offset is the median when centering is requested, 0 otherwise.
/// </summary>
public sealed class RobustScalerOneColumnFunctionBuilder : OneColumnFunctionBuilderBase<double>
{
    private readonly MinMaxDblAggregator _minMaxAggregator;
    private readonly MedianDblAggregator _medianAggregator;
    private readonly bool _centerData;
    private readonly uint _quantileMin;
    private readonly uint _quantileMax;
    // One-element scratch buffer so a scalar can be fed to the vector-based min/max aggregator.
    private VBuffer<double> _buffer;

    private RobustScalerOneColumnFunctionBuilder(IHost host, long lim, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getSrc)
        : base(host, lim, getSrc)
    {
        // Using the MinMax aggregator since that is what needs to be found here as well.
        // The difference is how the min/max are used.
        _minMaxAggregator = new MinMaxDblAggregator(1);
        _medianAggregator = new MedianDblAggregator();
        _buffer = new VBuffer<double>(1, new double[1]);
        _centerData = centerData;
        _quantileMin = quantileMin;
        _quantileMax = quantileMax;
    }

    /// <summary>
    /// Feeds one value to the min/max and median aggregators, unless the base class rejects it.
    /// </summary>
    /// <returns>Whatever the base implementation returned.</returns>
    protected override bool ProcessValue(in double val)
    {
        if (!base.ProcessValue(in val))
            return false;
        VBufferEditor.CreateFromBuffer(ref _buffer).Values[0] = val;
        _minMaxAggregator.ProcessValue(in _buffer);
        _medianAggregator.ProcessValue(in val);
        return true;
    }

    /// <summary>
    /// Validates the options and creates the builder.
    /// </summary>
    /// <param name="column">Column options; only <c>MaximumExampleCount</c> is read here.</param>
    /// <param name="host">Host used for argument validation.</param>
    /// <param name="srcType">Type of the source column (not used by this overload).</param>
    /// <param name="centerData">Whether the median is used as the offset.</param>
    /// <param name="quantileMin">Lower quantile bound, as a percentage.</param>
    /// <param name="quantileMax">Upper quantile bound, as a percentage; must exceed <paramref name="quantileMin"/>.</param>
    /// <param name="getter">Getter for the source values.</param>
    public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, DataViewType srcType,
        bool centerData, uint quantileMin, uint quantileMax, ValueGetter<double> getter)
    {
        host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
        // The bounds are unsigned: without this check an inverted pair would wrap around when
        // subtracted, and an equal pair would give a zero quantile range.
        host.CheckUserArg(quantileMin < quantileMax, nameof(quantileMax), "Must be greater than quantileMin");
        return new RobustScalerOneColumnFunctionBuilder(host, column.MaximumExampleCount, centerData, quantileMin, quantileMax, getter);
    }

    /// <summary>
    /// Builds the affine column function from the aggregated min, max and median.
    /// </summary>
    public override IColumnFunction CreateColumnFunction()
    {
        _minMaxAggregator.Finish();
        _medianAggregator.Finish();

        double range = _minMaxAggregator.Max[0] - _minMaxAggregator.Min[0];

        // Turn the percentage difference, e.g. 75 - 25, into a fraction, .5. Computed in double:
        // unsigned subtraction could wrap, and dividing by a float literal would round the
        // fraction to float precision first.
        double quantileRange = ((double)_quantileMax - _quantileMin) / 100.0;
        double denominator = range * quantileRange;

        // A constant column (range 0), or a column for which no usable min/max was observed,
        // has no meaningful scale; dividing anyway would yield an infinite or NaN scale.
        // Use a zero scale and zero offset for such a column instead.
        if (!(denominator > 0) || double.IsInfinity(denominator))
            return AffineColumnFunction.Create(Host, 0, 0);

        double scale = 1 / denominator;
        double offset = _centerData ? _medianAggregator.Median : 0;
        return AffineColumnFunction.Create(Host, scale, offset);
    }
}
2083+
2084+
/// <summary>
/// Function builder for robust scaling of a known-size vector of doubles. While the data is
/// scanned it tracks min, max and median per slot; <see cref="CreateColumnFunction"/> then
/// produces an affine function whose per-slot scale is
/// 1 / ((max - min) * (quantileMax - quantileMin) / 100) and whose per-slot offset is the median
/// when centering is requested.
/// </summary>
public sealed class RobustScalerVecFunctionBuilder : OneColumnFunctionBuilderBase<VBuffer<double>>
{
    private readonly MinMaxDblAggregator _minMaxAggregator;
    private readonly MedianDblAggregator[] _medianAggregators;
    private readonly bool _centerData;
    private readonly uint _quantileMin;
    private readonly uint _quantileMax;

    private RobustScalerVecFunctionBuilder(IHost host, long lim, int vectorSize, bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getSrc)
        : base(host, lim, getSrc)
    {
        // Using the MinMax aggregator since that is what needs to be found here as well.
        // The difference is how the min/max are used.
        _minMaxAggregator = new MinMaxDblAggregator(vectorSize);

        // One median aggregator per slot. They are created even when the data is not centered;
        // the array length also serves as the slot count below.
        _medianAggregators = new MedianDblAggregator[vectorSize];
        for (int i = 0; i < vectorSize; i++)
        {
            _medianAggregators[i] = new MedianDblAggregator();
        }

        _centerData = centerData;
        _quantileMin = quantileMin;
        _quantileMax = quantileMax;
    }

    /// <summary>
    /// Feeds one vector to the min/max aggregator and each of its slots to that slot's median
    /// aggregator, unless the base class rejects the value.
    /// </summary>
    /// <returns>Whatever the base implementation returned.</returns>
    protected override bool ProcessValue(in VBuffer<double> val)
    {
        if (!base.ProcessValue(in val))
            return false;
        _minMaxAggregator.ProcessValue(in val);

        // Have to calculate the median per slot.
        var values = val.GetValues();
        if (val.IsDense)
        {
            for (int slot = 0; slot < _medianAggregators.Length; slot++)
            {
                _medianAggregators[slot].ProcessValue(values[slot]);
            }

            return true;
        }

        // Sparse input: GetValues() holds only the explicitly stored entries, so it must not be
        // indexed by slot. Walk the stored indices alongside the slots and feed the default
        // value (0) for every slot that is not stored.
        var indices = val.GetIndices();
        int next = 0;
        for (int slot = 0; slot < _medianAggregators.Length; slot++)
        {
            double slotValue = 0;
            if (next < indices.Length && indices[next] == slot)
            {
                slotValue = values[next];
                next++;
            }

            _medianAggregators[slot].ProcessValue(in slotValue);
        }

        return true;
    }

    /// <summary>
    /// Validates the options and creates the builder.
    /// </summary>
    /// <param name="column">Column options; only <c>MaximumExampleCount</c> is read here.</param>
    /// <param name="host">Host used for argument validation.</param>
    /// <param name="srcType">Vector type of the source column; its size is the number of slots.</param>
    /// <param name="centerData">Whether the per-slot medians are used as offsets.</param>
    /// <param name="quantileMin">Lower quantile bound, as a percentage.</param>
    /// <param name="quantileMax">Upper quantile bound, as a percentage; must exceed <paramref name="quantileMin"/>.</param>
    /// <param name="getter">Getter for the source vectors.</param>
    public static IColumnFunctionBuilder Create(NormalizingEstimator.RobustScalingColumnOptions column, IHost host, VectorDataViewType srcType,
        bool centerData, uint quantileMin, uint quantileMax, ValueGetter<VBuffer<double>> getter)
    {
        host.CheckUserArg(column.MaximumExampleCount > 1, nameof(column.MaximumExampleCount), "Must be greater than 1");
        // The bounds are unsigned: without this check an inverted pair would wrap around when
        // subtracted, and an equal pair would give a zero quantile range.
        host.CheckUserArg(quantileMin < quantileMax, nameof(quantileMax), "Must be greater than quantileMin");
        var vectorSize = srcType.Size;
        return new RobustScalerVecFunctionBuilder(host, column.MaximumExampleCount, vectorSize, centerData, quantileMin, quantileMax, getter);
    }

    /// <summary>
    /// Builds the affine column function from the aggregated per-slot min, max and median.
    /// </summary>
    public override IColumnFunction CreateColumnFunction()
    {
        _minMaxAggregator.Finish();

        int slotCount = _medianAggregators.Length;
        double[] scale = new double[slotCount];
        double[] median = new double[slotCount];

        // Turn the percentage difference, e.g. 75 - 25, into a fraction, .5. Computed in double:
        // unsigned subtraction could wrap, and dividing by a float literal would round the
        // fraction to float precision first. It is the same for every slot, so compute it once.
        double quantileRange = ((double)_quantileMax - _quantileMin) / 100.0;

        for (int i = 0; i < slotCount; i++)
        {
            _medianAggregators[i].Finish();

            double range = _minMaxAggregator.Max[i] - _minMaxAggregator.Min[i];
            double denominator = range * quantileRange;

            if (denominator > 0 && !double.IsInfinity(denominator))
            {
                scale[i] = 1 / denominator;
                median[i] = _medianAggregators[i].Median;
            }
            else
            {
                // A constant slot (range 0), or a slot for which no usable min/max was observed,
                // has no meaningful scale; dividing anyway would yield an infinite or NaN scale.
                // Use a zero scale and zero offset for such a slot instead.
                scale[i] = 0;
                median[i] = 0;
            }
        }

        if (_centerData)
            return AffineColumnFunction.Create(Host, scale, median, null);
        else
            return AffineColumnFunction.Create(Host, scale, null, null);
    }
}
19152164
}
19162165
}
19172166
}

0 commit comments

Comments
 (0)