Skip to content

Learning with counts (Dracula) transformer #4514

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
70e3480
count table transform
yaeldMS Nov 7, 2019
26169c8
Dracula with unit tests
yaeldMS Dec 2, 2019
b818101
Fix entry point catalog test
yaeldMS Dec 2, 2019
08a7108
Address code review comments
yaeldMS Dec 3, 2019
2b2428b
create estimator from trained transformer
yaeldMS Dec 10, 2019
0ef935b
switch from three dimensional array of counts to two dimensional arra…
yaeldMS Dec 11, 2019
097f2f1
change mechanism for loading a pre-trained count table
yaeldMS Dec 16, 2019
73fd1ff
Add a sample
yaeldMS Dec 18, 2019
741b5ad
fix entrypoint catalog
yaeldMS Dec 25, 2019
8d00f47
documentation
yaeldMS Dec 27, 2019
c13ff8d
count table transform
yaeldMS Nov 7, 2019
f756a20
Dracula with unit tests
yaeldMS Dec 2, 2019
8880cc5
Address code review comments
yaeldMS Dec 3, 2019
330c6c5
create estimator from trained transformer
yaeldMS Dec 10, 2019
5c33181
change mechanism for loading a pre-trained count table
yaeldMS Dec 16, 2019
c8a9df5
Add a sample
yaeldMS Dec 18, 2019
a93ea18
documentation
yaeldMS Dec 27, 2019
9c921ce
fix unit tests
yaeldMS Dec 28, 2019
d7616a7
Delete unused file
yaeldMS Jan 1, 2020
90803b7
make CountTable* classes internal
yaeldMS Jan 10, 2020
c9cf4ce
Possible solution for adding noise only when training a pipeline
yaeldMS Jan 28, 2020
3d23f80
Fix bug
yaeldMS Jan 29, 2020
96e0041
Make all APIs and classes internal.
yaeldMS Feb 5, 2020
3fc56ec
Exclude dracula sample.
yaeldMS Feb 9, 2020
66f7865
Switch to using HashingTransformer instead of HashJoiningTransform.
yaeldMS May 18, 2020
91887da
Fix EntryPointCatalog test
yaeldMS May 19, 2020
c36e5b3
Address code review comments.
yaeldMS Jun 1, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
create estimator from trained transformer
  • Loading branch information
yaeldMS committed May 18, 2020
commit 2b2428b3eaee906d5a38f8bd73239f03cfee0b6e
59 changes: 47 additions & 12 deletions src/Microsoft.ML.Transforms/Dracula/CMCountTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ private static VersionInfo GetVersionInfo()

private readonly int _depth; // Number of different hash functions
private readonly int _width; // Hash space. May be any number, typically a power of 2
private readonly float[][][] _tables; // dimensions: label cardinality * depth * width

public float[][][] Tables { get; }

public CMCountTable(float[][][] tables, float[] priorCounts)
: base(Utils.Size(tables), priorCounts, 0, null)
Expand All @@ -60,7 +61,7 @@ public CMCountTable(float[][][] tables, float[] priorCounts)
Contracts.Check(_width > 0, "width must be positive");
Contracts.Check(tables.All(t => t.All(t2 => Utils.Size(t2) == _width)), "Width must be the same for all depths");

_tables = tables;
Tables = tables;
}

public static CMCountTable Create(IHostEnvironment env, ModelLoadContext ctx)
Expand Down Expand Up @@ -90,20 +91,20 @@ private CMCountTable(IHostEnvironment env, ModelLoadContext ctx)
_width = ctx.Reader.ReadInt32();
env.CheckDecode(_width > 0);

_tables = new float[LabelCardinality][][];
Tables = new float[LabelCardinality][][];
for (int i = 0; i < LabelCardinality; i++)
{
bool isSparse = ctx.Reader.ReadBoolByte();

_tables[i] = new float[_depth][];
Tables[i] = new float[_depth][];
for (int j = 0; j < _depth; j++)
{
if (!isSparse)
_tables[i][j] = ctx.Reader.ReadSingleArray(_width);
Tables[i][j] = ctx.Reader.ReadSingleArray(_width);
else
{
float[] table;
_tables[i][j] = table = new float[_width];
Tables[i][j] = table = new float[_width];
int pos = -1;
for (; ; )
{
Expand Down Expand Up @@ -145,7 +146,7 @@ public override void Save(ModelSaveContext ctx)

for (int iLabel = 0; iLabel < LabelCardinality; iLabel++)
{
var table = _tables[iLabel];
var table = Tables[iLabel];
bool isSparse = IsTableSparse(table);
ctx.Writer.WriteBoolByte(isSparse);
foreach (var array in table)
Expand Down Expand Up @@ -199,7 +200,7 @@ public override void GetCounts(long key, Span<float> counts)
for (int ilabel = 0; ilabel < LabelCardinality; ilabel++)
{
float minValue = -1;
var table = _tables[ilabel];
var table = Tables[ilabel];
for (int idepth = 0; idepth < _depth; idepth++)
{
int iwidth = (int)(Hashing.MixHash(Hashing.MurmurRound(hash, (uint)idepth)) % _width);
Expand All @@ -222,12 +223,17 @@ public override int AppendRows(List<int> hashIds, List<ulong> hashValues, List<f
hashIds.Add(i);
hashValues.Add((ulong)j);
for (int label = 0; label < LabelCardinality; label++)
countsCur[label] = _tables[label][i][j];
countsCur[label] = Tables[label][i][j];
counts.Add(countsCur);
}
}
return _depth * _width;
}

/// <summary>
/// Returns a new <see cref="CMCountTableBuilder.Builder"/> constructed from this count table.
/// </summary>
public override InternalCountTableBuilderBase ToBuilder() => new CMCountTableBuilder.Builder(this);
}

internal sealed class CMCountTableBuilder : CountTableBuilderBase
Expand Down Expand Up @@ -268,9 +274,9 @@ internal CMCountTableBuilder(IHostEnvironment env, Options options)
{
}

internal override InternalCountTableBuilderBase GetBuilderHelper(long labelCardinality) => new Builder(labelCardinality, _depth, _width);
internal override InternalCountTableBuilderBase GetInternalBuilder(long labelCardinality) => new Builder(labelCardinality, _depth, _width);

private sealed class Builder : InternalCountTableBuilderBase
internal sealed class Builder : InternalCountTableBuilderBase
{
private readonly int _depth;
private readonly double[][][] _tables; // label cardinality * depth * width
Expand All @@ -294,7 +300,36 @@ public Builder(long labelCardinality, int depth, int width)
}
}

internal override ICountTable CreateCountTable()
/// <summary>
/// Creates a builder whose tables start as a double-precision copy of the tables of an existing
/// <see cref="CMCountTable"/>. The depth is taken from the first label's table array and the width
/// from the first row of the first label; all other labels and rows are asserted to agree.
/// </summary>
/// <param name="table">The count table whose <c>Tables</c> and <c>LabelCardinality</c> are read.</param>
public Builder(CMCountTable table)
    : base(table.LabelCardinality)
{
    // Note: the base initializer above already reads table.LabelCardinality before this assert runs.
    Contracts.AssertValue(table);

    _tables = new double[LabelCardinality][][];
    for (int label = 0; label < LabelCardinality; label++)
    {
        float[][] source = table.Tables[label];
        if (label > 0)
            Contracts.Assert(_depth == source.Length);
        else
            _depth = source.Length;

        var converted = new double[_depth][];
        for (int d = 0; d < _depth; d++)
        {
            float[] row = source[d];
            if (label > 0 || d > 0)
                Contracts.Assert(_width == row.Length);
            else
                _width = row.Length;

            // Array.CopyTo widens each float to double while copying.
            var copy = new double[_width];
            row.CopyTo(copy, 0);
            converted[d] = copy;
        }

        _tables[label] = converted;
    }
}

internal override CountTableBase CreateCountTable()
{
var priorCounts = PriorCounts.Select(x => (float)x).ToArray();

Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.Transforms/Dracula/CountTable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ internal abstract class CountTableBase : ICountTable, ICanSaveModel
{
public const int LabelCardinalityLim = 100;

protected readonly int LabelCardinality; // number of values the label can assume
public readonly int LabelCardinality; // number of values the label can assume
private readonly double[] _priorFrequencies;

public float GarbageThreshold { get; private set; } // garbage bin threshold
Expand Down Expand Up @@ -146,5 +146,7 @@ public virtual void Save(ModelSaveContext ctx)
}

public abstract int AppendRows(List<int> hashIds, List<ulong> hashValues, List<float[]> counts);

public abstract InternalCountTableBuilderBase ToBuilder();
}
}
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Dracula/CountTableBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ private protected CountTableBuilderBase()
{
}

internal abstract InternalCountTableBuilderBase GetBuilderHelper(long labelCardinality);
internal abstract InternalCountTableBuilderBase GetInternalBuilder(long labelCardinality);

/// <summary>
/// Creates a <see cref="CMCountTableBuilder"/> from the given depth and width.
/// </summary>
/// <param name="depth">Value passed as the builder's depth; defaults to 4.</param>
/// <param name="width">Value passed as the builder's width; defaults to <c>1 &lt;&lt; 23</c>.</param>
/// <returns>The new builder, typed as <see cref="CountTableBuilderBase"/>.</returns>
public static CountTableBuilderBase CreateCMCountTableBuilder(int depth = 4, int width = 1 << 23)
{
    return new CMCountTableBuilder(depth, width);
}
Expand Down Expand Up @@ -53,6 +53,6 @@ internal void Increment(long key, long labelKey)

internal abstract void InsertOrUpdateRawCounts(int hashId, long hashValue, in VBuffer<float> counts);

internal abstract ICountTable CreateCountTable();
internal abstract CountTableBase CreateCountTable();
}
}
71 changes: 47 additions & 24 deletions src/Microsoft.ML.Transforms/Dracula/CountTableTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ public SharedColumnOptions(string name, string inputColumnName, float priorCoeff
private readonly CountTableBuilderBase _sharedBuilder;
private readonly string _labelColumnName;
private readonly string _externalCountsFile;
private readonly CountTableTransformer _initialCounts;

internal CountTableEstimator(IHostEnvironment env, string labelColumnName, CountTableBuilderBase countTableBuilder, params SharedColumnOptions[] columns)
: this(env, labelColumnName, columns)
Expand All @@ -95,6 +96,26 @@ internal CountTableEstimator(IHostEnvironment env, string labelColumnName, strin
_builders = columns.Select(c => c.CountTableBuilder).ToArray();
}

/// <summary>
/// Creates an estimator seeded with an existing <see cref="CountTableTransformer"/>. The per-column
/// options are derived from <paramref name="initial"/> via <c>ExtractColumnOptions</c>, and the
/// transformer itself is stored in <c>_initialCounts</c>.
/// </summary>
/// <param name="env">The host environment.</param>
/// <param name="labelColumnName">Name of the label column.</param>
/// <param name="initial">The transformer whose settings are reused.</param>
/// <param name="columns">Input/output column name pairs for the new estimator.</param>
internal CountTableEstimator(
    IHostEnvironment env,
    string labelColumnName,
    CountTableTransformer initial,
    params InputOutputColumnPair[] columns)
    : this(env, labelColumnName, ExtractColumnOptions(initial, columns))
{
    _initialCounts = initial;
}

/// <summary>
/// Builds one <see cref="SharedColumnOptions"/> per requested column pair. For the column at index i,
/// the prior coefficient and Laplace scale come from <c>initial.Featurizer</c> and the seed from
/// <c>initial.Seeds</c>, each at the same index i.
/// </summary>
/// <param name="initial">The transformer whose featurizer settings and seeds are reused; must not be null.</param>
/// <param name="columns">
/// Input/output column name pairs; must not be null, and its length must equal the number of
/// prior coefficients in <paramref name="initial"/>'s featurizer.
/// </param>
/// <returns>An array of column options with the same length as <paramref name="columns"/>.</returns>
private static ColumnOptionsBase[] ExtractColumnOptions(CountTableTransformer initial, InputOutputColumnPair[] columns)
{
    Contracts.CheckValue(initial, nameof(initial));
    // The caller declares columns as a params array, which can still be passed as null explicitly.
    // Check it here so the failure names the parameter instead of surfacing at columns.Length.
    Contracts.CheckValue(columns, nameof(columns));

    var featurizer = initial.Featurizer;
    if (columns.Length != featurizer.PriorCoef.Length)
        throw Contracts.ExceptParam(nameof(columns), $"New estimator applied {columns.Length} columns, but old transformer applied to {featurizer.PriorCoef.Length} columns");

    var cols = new ColumnOptionsBase[columns.Length];
    for (int i = 0; i < columns.Length; i++)
    {
        cols[i] = new SharedColumnOptions(columns[i].OutputColumnName, columns[i].InputColumnName,
            featurizer.PriorCoef[i], featurizer.LaplaceScale[i], initial.Seeds[i]);
    }
    return cols;
}

private CountTableEstimator(IHostEnvironment env, string labelColumnName, ColumnOptionsBase[] columns)
{
Contracts.CheckValue(env, nameof(env));
Expand Down Expand Up @@ -130,9 +151,11 @@ public CountTableTransformer Fit(IDataView input)
inputColumns[i] = col.GetValueOrDefault();
}

_host.Assert((_sharedBuilder == null) != (_builders == null));
_host.Assert(_initialCounts != null || _sharedBuilder != null || _builders != null);
MultiCountTableBuilderBase multiBuilder;
if (_builders != null)
if (_initialCounts != null)
multiBuilder = _initialCounts.Featurizer.MultiCountTable.ToBuilder(_host);
else if (_builders != null)
multiBuilder = new ParallelMultiCountTableBuilder(_host, inputColumns, _builders, labelCardinality, _externalCountsFile);
else
multiBuilder = new BagMultiCountTableBuilder(_host, inputColumns, _sharedBuilder, labelCardinality);
Expand Down Expand Up @@ -385,10 +408,10 @@ internal static class Defaults
public const bool SharedTable = false;
}

//private readonly DraculaFeaturizer[][] _featurizers; // parallel to count tables
private readonly DraculaFeaturizer _featurizer;
internal readonly DraculaFeaturizer Featurizer;
private readonly string[] _labelClassNames;
private readonly int[] _seeds;

internal int[] Seeds { get; }

internal const string Summary = "Transforms the categorical column into the set of features: count of each label class, "
+ "log-odds for each label class, back-off indicator. The input columns must be keys. This is a part of the Dracula transform.";
Expand All @@ -414,9 +437,9 @@ internal CountTableTransformer(IHostEnvironment env, DraculaFeaturizer featurize
Host.AssertValueOrNull(labelClassNames);
Host.Assert(Utils.Size(seeds) == featurizer.ColCount);

_featurizer = featurizer;
Featurizer = featurizer;
_labelClassNames = labelClassNames;
_seeds = seeds;
Seeds = seeds;
}

// Factory method for SignatureLoadDataTransform.
Expand Down Expand Up @@ -511,8 +534,8 @@ private CountTableTransformer(IHost host, ModelLoadContext ctx)
}
}

_seeds = ctx.Reader.ReadIntArray(ColumnPairs.Length);
ctx.LoadModel<DraculaFeaturizer, SignatureLoadModel>(host, out _featurizer, "DraculaFeaturizer");
Seeds = ctx.Reader.ReadIntArray(ColumnPairs.Length);
ctx.LoadModel<DraculaFeaturizer, SignatureLoadModel>(host, out Featurizer, "DraculaFeaturizer");
}

private protected override void SaveModel(ModelSaveContext ctx)
Expand Down Expand Up @@ -542,8 +565,8 @@ private protected override void SaveModel(ModelSaveContext ctx)
}
}

ctx.Writer.WriteIntsNoCount(_seeds);
ctx.SaveModel(_featurizer, "DraculaFeaturizer");
ctx.Writer.WriteIntsNoCount(Seeds);
ctx.SaveModel(Featurizer, "DraculaFeaturizer");
}

private protected override IRowMapper MakeRowMapper(DataViewSchema schema) => new Mapper(this, schema);
Expand All @@ -553,7 +576,7 @@ public void SaveCountTables(string path)
var saver = new TextSaver(Host, new TextSaver.Arguments() { OutputHeader = false, OutputSchema = false, Dense = true });
using (var stream = new FileStream(path, FileMode.Create))
using (var ch = Host.Start("Saving Count Tables"))
DataSaverUtils.SaveDataView(ch, saver, _featurizer.ToDataView(), stream);
DataSaverUtils.SaveDataView(ch, saver, Featurizer.ToDataView(), stream);
}

private sealed class Mapper : OneToOneMapperBase
Expand All @@ -572,8 +595,8 @@ protected override DataViewSchema.DetachedColumn[] GetOutputColumnsCore()
{
var inputCol = InputSchema[_parent.ColumnPairs[i].inputColumnName];
var valueCount = inputCol.Type.GetValueCount();
Host.Check((long)valueCount * _parent._featurizer.NumFeatures < int.MaxValue, "Too large output size");
var type = new VectorDataViewType(NumberDataViewType.Single, valueCount, _parent._featurizer.NumFeatures);
Host.Check((long)valueCount * _parent.Featurizer.NumFeatures < int.MaxValue, "Too large output size");
var type = new VectorDataViewType(NumberDataViewType.Single, valueCount, _parent.Featurizer.NumFeatures);

// We supply slot names if the source is a single-value column, or if it has slot names.
if (!(inputCol.Type is VectorDataViewType) || inputCol.HasSlotNames())
Expand Down Expand Up @@ -608,7 +631,7 @@ private ValueGetter<VBuffer<ReadOnlyMemory<char>>> GetSlotNamesGetter(DataViewSc
ValueGetter<VBuffer<ReadOnlyMemory<char>>> getter =
(ref VBuffer<ReadOnlyMemory<char>> dst) =>
{
_parent._featurizer.GetFeatureNames(_parent._labelClassNames, ref featureNames);
_parent.Featurizer.GetFeatureNames(_parent._labelClassNames, ref featureNames);
int nFeatures = featureNames.Length;

var editor = VBufferEditor.Create(ref dst, nFeatures * inputSlotNames.Length);
Expand Down Expand Up @@ -638,12 +661,12 @@ protected override Delegate MakeGetter(DataViewRow input, int iinfo, Func<int, b

private ValueGetter<VBuffer<float>> ConstructSingleGetter(DataViewRow input, int iinfo)
{
Host.Assert(_parent._featurizer.SlotCount[iinfo] == 1);
Host.Assert(_parent.Featurizer.SlotCount[iinfo] == 1);
uint src = 0;
var srcGetter = input.GetGetter<uint>(input.Schema[_parent.ColumnPairs[iinfo].inputColumnName]);
var outputLength = _parent._featurizer.NumFeatures;
var rand = new Random(_parent._seeds[iinfo]);
var featurizer = _parent._featurizer;
var outputLength = _parent.Featurizer.NumFeatures;
var rand = new Random(_parent.Seeds[iinfo]);
var featurizer = _parent.Featurizer;
return (ref VBuffer<float> dst) =>
{
srcGetter(ref src);
Expand All @@ -657,12 +680,12 @@ private ValueGetter<VBuffer<float>> ConstructVectorGetter(DataViewRow input, int
{
var inputCol = input.Schema[_parent.ColumnPairs[iinfo].inputColumnName];
int n = inputCol.Type.GetValueCount();
Host.Assert(_parent._featurizer.SlotCount[iinfo] == n);
Host.Assert(_parent.Featurizer.SlotCount[iinfo] == n);
VBuffer<uint> src = default;

var outputLength = _parent._featurizer.NumFeatures;
var outputLength = _parent.Featurizer.NumFeatures;
var srcGetter = input.GetGetter<VBuffer<uint>>(inputCol);
var rand = new Random(_parent._seeds[iinfo]);
var rand = new Random(_parent.Seeds[iinfo]);
return (ref VBuffer<float> dst) =>
{
srcGetter(ref src);
Expand All @@ -671,7 +694,7 @@ private ValueGetter<VBuffer<float>> ConstructVectorGetter(DataViewRow input, int
{
var srcValues = src.GetValues();
for (int i = 0; i < n; i++)
_parent._featurizer.GetFeatures(iinfo, i, rand, srcValues[i], editor.Values.Slice(i * outputLength, outputLength));
_parent.Featurizer.GetFeatures(iinfo, i, rand, srcValues[i], editor.Values.Slice(i * outputLength, outputLength));
}
else
{
Expand All @@ -681,7 +704,7 @@ private ValueGetter<VBuffer<float>> ConstructVectorGetter(DataViewRow input, int
for (int i = 0; i < srcIndices.Length; i++)
{
var index = srcIndices[i];
_parent._featurizer.GetFeatures(iinfo, index, rand, srcValues[i], editor.Values.Slice(index * outputLength, outputLength));
_parent.Featurizer.GetFeatures(iinfo, index, rand, srcValues[i], editor.Values.Slice(index * outputLength, outputLength));
}
}

Expand Down
Loading