@@ -516,6 +516,117 @@ private void Update(int j, TFloat origVal)
516
516
}
517
517
}
518
518
519
+ [ BestFriend ]
520
+ internal static partial class MedianAggregatorUtils
521
+ {
522
+ /// <summary>
523
+ /// Based on the algorithm on GeeksForGeeks https://www.geeksforgeeks.org/median-of-stream-of-integers-running-integers/.
524
+ /// </summary>
525
+ /// <param name="num">The new number to account for in our median calculation.</param>
526
+ /// <param name="median">The current median.</param>
527
+ /// <param name="belowMedianHeap">The MaxHeap that has all the numbers below the median.</param>
528
+ /// <param name="aboveMedianHeap">The MinHeap that has all the numbers above the median.</param>
529
+ [ BestFriend ]
530
+ internal static void GetMedianSoFar ( in double num , ref double median , ref MaxHeap < double > belowMedianHeap , ref MinHeap < double > aboveMedianHeap )
531
+ {
532
+ int comparison = belowMedianHeap . Count ( ) . CompareTo ( aboveMedianHeap . Count ( ) ) ;
533
+
534
+ if ( comparison < 0 )
535
+ { // More elements in aboveMedianHeap than belowMedianHeap.
536
+ if ( num < median )
537
+ { // Current element belongs in the belowMedianHeap.
538
+ // Insert new number into belowMedianHeap
539
+ belowMedianHeap . Add ( num ) ;
540
+
541
+ }
542
+ else
543
+ { // Current element belongs in aboveMedianHeap.
544
+ // Need to move one to belowMedianHeap to keep heeps balanced.
545
+ belowMedianHeap . Add ( aboveMedianHeap . Pop ( ) ) ;
546
+
547
+ aboveMedianHeap . Add ( num ) ;
548
+ }
549
+
550
+ // Both heaps are balanced so median is the average of the 2 heaps.
551
+ median = ( aboveMedianHeap . Peek ( ) + belowMedianHeap . Peek ( ) ) / 2 ;
552
+
553
+ }
554
+ else if ( comparison == 0 )
555
+ { // Both heaps have the same number of elements. Simple put the number where it belongs.
556
+ if ( num < median )
557
+ { // Current element belongs in the belowMedianHeap.
558
+ belowMedianHeap . Add ( num ) ;
559
+
560
+ // Now we have an odd number of items, median is the new root of the belowMedianHeap
561
+ median = belowMedianHeap . Peek ( ) ;
562
+
563
+ }
564
+ else
565
+ { // Current element belongs in above median heap.
566
+ aboveMedianHeap . Add ( num ) ;
567
+
568
+ // Now we have an odd number of items, median is the new root of the aboveMedianHeap
569
+ median = aboveMedianHeap . Peek ( ) ;
570
+ }
571
+
572
+ }
573
+ else
574
+ { // More elements in belowMedianHeap than aboveMedianHeap.
575
+ if ( num < median )
576
+ { // Current element belongs in the belowMedianHeap.
577
+ // Need to move one to aboveMedianHeap to keep heeps balanced.
578
+ aboveMedianHeap . Add ( belowMedianHeap . Pop ( ) ) ;
579
+
580
+ // Insert new number into belowMedianHeap
581
+ belowMedianHeap . Add ( num ) ;
582
+
583
+ }
584
+ else
585
+ { // Current element belongs in aboveMedianHeap.
586
+ aboveMedianHeap . Add ( num ) ;
587
+ }
588
+
589
+ // Both heaps are balanced so median is the average of the 2 heaps.
590
+ median = ( aboveMedianHeap . Peek ( ) + belowMedianHeap . Peek ( ) ) / 2 ;
591
+ }
592
+ }
593
+ }
594
+
595
+ /// <summary>
596
+ /// Base class for tracking median values for a single valued column.
597
+ /// It tracks median values of non-sparse values (vCount).
598
+ /// NaNs are ignored when updating min and max.
599
+ /// </summary>
600
+ internal sealed class MedianDblAggregator : IColumnAggregator < double >
601
+ {
602
+ private MedianAggregatorUtils . MaxHeap < double > _belowMedianHeap ;
603
+ private MedianAggregatorUtils . MinHeap < double > _aboveMedianHeap ;
604
+ private double _median ;
605
+
606
+ public MedianDblAggregator ( int contatinerStartingSize = 1000 )
607
+ {
608
+ Contracts . Check ( contatinerStartingSize > 0 ) ;
609
+ _belowMedianHeap = new MedianAggregatorUtils . MaxHeap < double > ( contatinerStartingSize ) ;
610
+ _aboveMedianHeap = new MedianAggregatorUtils . MinHeap < double > ( contatinerStartingSize ) ;
611
+ _median = default ;
612
+ }
613
+
614
+ public double Median
615
+ {
616
+ get { return _median ; }
617
+ }
618
+
619
+ public void ProcessValue ( in double value )
620
+ {
621
+ MedianAggregatorUtils . GetMedianSoFar ( value , ref _median , ref _belowMedianHeap , ref _aboveMedianHeap ) ;
622
+ }
623
+
624
+ public void Finish ( )
625
+ {
626
+ // Finish is a no-op because we are updating the median continually as we go
627
+ }
628
+ }
629
+
519
630
internal sealed partial class NormalizeTransform
520
631
{
521
632
internal abstract partial class AffineColumnFunction
@@ -1912,6 +2023,144 @@ public static IColumnFunctionBuilder Create(NormalizingEstimator.SupervisedBinni
1912
2023
return new SupervisedBinVecColumnFunctionBuilder ( host , lim , fix , numBins , column . MininimumBinSize , valueColumnId , labelColumnId , dataRow ) ;
1913
2024
}
1914
2025
}
2026
+
2027
+ public sealed class RobustScalerOneColumnFunctionBuilder : OneColumnFunctionBuilderBase < double >
2028
+ {
2029
+ private readonly MinMaxDblAggregator _minMaxAggregator ;
2030
+ private readonly MedianDblAggregator _medianAggregator ;
2031
+ private readonly bool _centerData ;
2032
+ private readonly uint _quantileMin ;
2033
+ private readonly uint _quantileMax ;
2034
+ private VBuffer < double > _buffer ;
2035
+
2036
+ private RobustScalerOneColumnFunctionBuilder ( IHost host , long lim , bool centerData , uint quantileMin , uint quantileMax , ValueGetter < double > getSrc )
2037
+ : base ( host , lim , getSrc )
2038
+ {
2039
+ // Using the MinMax aggregator since that is what needs to be found here as well.
2040
+ // The difference is how the min/max are used.
2041
+ _minMaxAggregator = new MinMaxDblAggregator ( 1 ) ;
2042
+ _medianAggregator = new MedianDblAggregator ( ) ;
2043
+ _buffer = new VBuffer < double > ( 1 , new double [ 1 ] ) ;
2044
+ _centerData = centerData ;
2045
+ _quantileMin = quantileMin ;
2046
+ _quantileMax = quantileMax ;
2047
+ }
2048
+
2049
+ protected override bool ProcessValue ( in double val )
2050
+ {
2051
+ if ( ! base . ProcessValue ( in val ) )
2052
+ return false ;
2053
+ VBufferEditor . CreateFromBuffer ( ref _buffer ) . Values [ 0 ] = val ;
2054
+ _minMaxAggregator . ProcessValue ( in _buffer ) ;
2055
+ _medianAggregator . ProcessValue ( in val ) ;
2056
+ return true ;
2057
+ }
2058
+
2059
+ public static IColumnFunctionBuilder Create ( NormalizingEstimator . RobustScalingColumnOptions column , IHost host , DataViewType srcType ,
2060
+ bool centerData , uint quantileMin , uint quantileMax , ValueGetter < double > getter )
2061
+ {
2062
+ host . CheckUserArg ( column . MaximumExampleCount > 1 , nameof ( column . MaximumExampleCount ) , "Must be greater than 1" ) ;
2063
+ return new RobustScalerOneColumnFunctionBuilder ( host , column . MaximumExampleCount , centerData , quantileMin , quantileMax , getter ) ;
2064
+ }
2065
+
2066
+ public override IColumnFunction CreateColumnFunction ( )
2067
+ {
2068
+ _minMaxAggregator . Finish ( ) ;
2069
+ _medianAggregator . Finish ( ) ;
2070
+
2071
+ double median = _medianAggregator . Median ;
2072
+ double range = _minMaxAggregator . Max [ 0 ] - _minMaxAggregator . Min [ 0 ] ;
2073
+ // Divide the range by 100 because we need to make the number, i.e. 75, into a decimal, .75
2074
+ double quantileRange = ( _quantileMax - _quantileMin ) / 100f ;
2075
+ double scale = 1 / ( range * quantileRange ) ;
2076
+
2077
+ if ( _centerData )
2078
+ return AffineColumnFunction . Create ( Host , scale , median ) ;
2079
+ else
2080
+ return AffineColumnFunction . Create ( Host , scale , 0 ) ;
2081
+ }
2082
+ }
2083
+
2084
+ public sealed class RobustScalerVecFunctionBuilder : OneColumnFunctionBuilderBase < VBuffer < double > >
2085
+ {
2086
+ private readonly MinMaxDblAggregator _minMaxAggregator ;
2087
+ private readonly MedianDblAggregator [ ] _medianAggregators ;
2088
+ private readonly bool _centerData ;
2089
+ private readonly uint _quantileMin ;
2090
+ private readonly uint _quantileMax ;
2091
+
2092
+ private RobustScalerVecFunctionBuilder ( IHost host , long lim , int vectorSize , bool centerData , uint quantileMin , uint quantileMax , ValueGetter < VBuffer < double > > getSrc )
2093
+ : base ( host , lim , getSrc )
2094
+ {
2095
+ // Using the MinMax aggregator since that is what needs to be found here as well.
2096
+ // The difference is how the min/max are used.
2097
+ _minMaxAggregator = new MinMaxDblAggregator ( vectorSize ) ;
2098
+
2099
+ // If we aren't centering data dont need the median.
2100
+ _medianAggregators = new MedianDblAggregator [ vectorSize ] ;
2101
+
2102
+ for ( int i = 0 ; i < vectorSize ; i ++ )
2103
+ {
2104
+ _medianAggregators [ i ] = new MedianDblAggregator ( ) ;
2105
+ }
2106
+
2107
+ _centerData = centerData ;
2108
+ _quantileMin = quantileMin ;
2109
+ _quantileMax = quantileMax ;
2110
+ }
2111
+
2112
+ protected override bool ProcessValue ( in VBuffer < double > val )
2113
+ {
2114
+ if ( ! base . ProcessValue ( in val ) )
2115
+ return false ;
2116
+ _minMaxAggregator . ProcessValue ( in val ) ;
2117
+
2118
+ // Have to calculate the median per slot
2119
+ var span = val . GetValues ( ) ;
2120
+ for ( int i = 0 ; i < _medianAggregators . Length ; i ++ )
2121
+ {
2122
+ _medianAggregators [ i ] . ProcessValue ( span [ i ] ) ;
2123
+ }
2124
+
2125
+ return true ;
2126
+ }
2127
+
2128
+ public static IColumnFunctionBuilder Create ( NormalizingEstimator . RobustScalingColumnOptions column , IHost host , VectorDataViewType srcType ,
2129
+ bool centerData , uint quantileMin , uint quantileMax , ValueGetter < VBuffer < double > > getter )
2130
+ {
2131
+ host . CheckUserArg ( column . MaximumExampleCount > 1 , nameof ( column . MaximumExampleCount ) , "Must be greater than 1" ) ;
2132
+ var vectorSize = srcType . Size ;
2133
+ return new RobustScalerVecFunctionBuilder ( host , column . MaximumExampleCount , vectorSize , centerData , quantileMin , quantileMax , getter ) ;
2134
+ }
2135
+
2136
+ public override IColumnFunction CreateColumnFunction ( )
2137
+ {
2138
+ _minMaxAggregator . Finish ( ) ;
2139
+
2140
+ double [ ] scale = new double [ _medianAggregators . Length ] ;
2141
+ double [ ] median = new double [ _medianAggregators . Length ] ;
2142
+
2143
+ // Have to calculate the median per slot
2144
+ for ( int i = 0 ; i < _medianAggregators . Length ; i ++ )
2145
+ {
2146
+ _medianAggregators [ i ] . Finish ( ) ;
2147
+ median [ i ] = _medianAggregators [ i ] . Median ;
2148
+
2149
+ double range = _minMaxAggregator . Max [ i ] - _minMaxAggregator . Min [ i ] ;
2150
+
2151
+ // Divide the range by 100 because we need to make the number, i.e. 75, into a decimal, .75
2152
+ double quantileRange = ( _quantileMax - _quantileMin ) / 100f ;
2153
+ scale [ i ] = 1 / ( range * quantileRange ) ;
2154
+
2155
+ }
2156
+
2157
+ if ( _centerData )
2158
+ return AffineColumnFunction . Create ( Host , scale , median , null ) ;
2159
+ else
2160
+ return AffineColumnFunction . Create ( Host , scale , null , null ) ;
2161
+
2162
+ }
2163
+ }
1915
2164
}
1916
2165
}
1917
2166
}
0 commit comments