Skip to content

Add option to execute only the last transform in TransformWrapper and have WordBagEstimator return transformer chain #3700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
May 16, 2019
Prev Previous commit
Next Next commit
PR feedback.
  • Loading branch information
codemzs committed May 13, 2019
commit 75757c4221a55b23ad5bec5ae4e3dc98989f82b5
6 changes: 4 additions & 2 deletions docs/samples/Microsoft.ML.Samples/Program.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using System;
using System.Reflection;
using Samples.Dynamic;

namespace Microsoft.ML.Samples
{
Expand All @@ -9,7 +10,7 @@ public static class Program

internal static void RunAll()
{
int samples = 0;
/*int samples = 0;
foreach (var type in Assembly.GetExecutingAssembly().GetTypes())
{
var sample = type.GetMethod("Example", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy);
Expand All @@ -22,7 +23,8 @@ internal static void RunAll()
}
}

Console.WriteLine("Number of samples that ran without any exception: " + samples);
Console.WriteLine("Number of samples that ran without any exception: " + samples);*/
ProduceWordBags.Example();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,6 @@ public static IDataView LoadSelectedTransforms(ModelLoadContext ctx, IDataView s
internal TransformerChain<ITransformer> GetTransformer()
{
var result = new TransformerChain<ITransformer>();
IDataTransform lastTransformer = null;
foreach (var transform in _transforms)
{
if (transform.Transform is RowToRowMapperTransform mapper)
Expand All @@ -413,11 +412,9 @@ internal TransformerChain<ITransformer> GetTransformer()
}
else
{
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, lastTransformer is RowToRowMapperTransform);
ITransformer transformer = new TransformWrapper(_host, transform.Transform, false, true);
result = result.Append(transformer);
}

lastTransformer = transform.Transform;
}
return result;
}
Expand Down
23 changes: 14 additions & 9 deletions src/Microsoft.ML.Transforms/Text/WordBagTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ internal sealed class Options : NgramExtractorTransform.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. It does so by building "
+ "a dictionary of n-grams and using the id in the dictionary as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransfomer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -150,10 +150,16 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
}

IDataView view = input;
view = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns, view);
view = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view).Transform(view);
return NgramExtractorTransform.CreateDataTransform(h, extractorArgs, view);
ITransformer t0 = NgramExtractionUtils.ApplyConcatOnSources(h, options.Columns);
view = t0.Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns).Fit(view);
view = t1.Transform(view);
ITransformer t2 = NgramExtractorTransform.Create(h, extractorArgs, t1.Transform(view));
return new TransformerChain<ITransformer>(new[] { t0, t1, t2 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransfomer(env, options, input).Transform(input);
}

/// <summary>
Expand Down Expand Up @@ -489,13 +495,11 @@ public ITransformer Create(IHostEnvironment env, IDataView input, ExtractorColum

internal static class NgramExtractionUtils
{
public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns, IDataView input)
public static ITransformer ApplyConcatOnSources(IHostEnvironment env, ManyToOneColumn[] columns)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(columns, nameof(columns));
env.CheckValue(input, nameof(input));

IDataView view = input;
var concatColumns = new List<ColumnConcatenatingTransformer.ColumnOptions>();
foreach (var col in columns)
{
Expand All @@ -506,10 +510,11 @@ public static IDataView ApplyConcatOnSources(IHostEnvironment env, ManyToOneColu
if (col.Source.Length > 1)
concatColumns.Add(new ColumnConcatenatingTransformer.ColumnOptions(col.Name, col.Source));
}

if (concatColumns.Count > 0)
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray()).Transform(view);
return new ColumnConcatenatingTransformer(env, concatColumns.ToArray());

return view;
return new TransformerChain<ITransformer>();
}

/// <summary>
Expand Down
14 changes: 10 additions & 4 deletions src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ internal sealed class Options : NgramHashExtractingTransformer.ArgumentsBase
internal const string Summary = "Produces a bag of counts of n-grams (sequences of consecutive words of length 1-n) in a given text. "
+ "It does so by hashing each n-gram and using the hash value as the index in the bag.";

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input)
internal static ITransformer CreateTransformer(IHostEnvironment env, Options options, IDataView input)
{
Contracts.CheckValue(env, nameof(env));
var h = env.Register(RegistrationName);
Expand Down Expand Up @@ -132,7 +132,7 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
};
}

view = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view).Transform(view);
ITransformer t1 = new WordTokenizingEstimator(env, tokenizeColumns.ToArray()).Fit(view);

var featurizeArgs =
new NgramHashExtractingTransformer.Options
Expand All @@ -147,11 +147,17 @@ internal static IDataTransform Create(IHostEnvironment env, Options options, IDa
MaximumNumberOfInverts = options.MaximumNumberOfInverts
};

view = NgramHashExtractingTransformer.Create(h, featurizeArgs, view).Transform(view);
view = t1.Transform(view);
ITransformer t2 = NgramHashExtractingTransformer.Create(h, featurizeArgs, t1.Transform(view));

// Since we added columns with new names, we need to explicitly drop them before we return the IDataTransform.
return ColumnSelectingTransformer.CreateDrop(h, view, tmpColNames.ToArray()) as IDataTransform;
ITransformer t3 = new ColumnSelectingTransformer(env, null, tmpColNames.ToArray());

return new TransformerChain<ITransformer>(new[] { t1, t2, t3 });
}

internal static IDataTransform Create(IHostEnvironment env, Options options, IDataView input) =>
(IDataTransform)CreateTransformer(env, options, input).Transform(input);
}

/// <summary>
Expand Down
4 changes: 2 additions & 2 deletions src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ public ITransformer Fit(IDataView input)
Weighting = _weighting
};

return new TransformWrapper(_host, WordBagBuildingTransformer.Create(_host, options, input), true);
return WordBagBuildingTransformer.CreateTransfomer(_host, options, input);
}

/// <summary>
Expand Down Expand Up @@ -365,7 +365,7 @@ public ITransformer Fit(IDataView input)
MaximumNumberOfInverts = _maximumNumberOfInverts
};

return new TransformWrapper(_host, WordHashBagProducingTransformer.Create(_host, options, input), true);
return WordHashBagProducingTransformer.CreateTransformer(_host, options, input);
}

/// <summary>
Expand Down
6 changes: 5 additions & 1 deletion test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
Original file line number Diff line number Diff line change
Expand Up @@ -981,7 +981,11 @@ public void EntryPointPipelineEnsembleText()
new WordHashBagProducingTransformer.Options()
{
Columns =
new[] { new WordHashBagProducingTransformer.Column() { Name = "Features", Source = new[] { "Text" } }, }
new[]
{
new WordHashBagProducingTransformer.Column()
{Name = "Features", Source = new[] {"Text"}},
}
},
data);
}
Expand Down
63 changes: 63 additions & 0 deletions test/Microsoft.ML.Tests/Scenarios/WordBagTest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Xunit;
using System.Collections.Generic;
using Microsoft.ML.Transforms.Text;

namespace Microsoft.ML.Scenarios
{
    public partial class ScenariosTests
    {
        /// <summary>
        /// End-to-end check that <c>ProduceWordBags</c> returns a real transformer chain:
        /// fits a two-stage word-bag pipeline (over "Text" and the constant "Text2" column),
        /// runs it through a prediction engine, and pins the exact trigram count vector
        /// produced for the first training sentence.
        /// </summary>
        [Fact]
        public static void WordBags()
        {
            var mlContext = new MLContext();
            // Training corpus. The fitted n-gram dictionary — and therefore the
            // expected vector asserted below — depends on every string here
            // (typos included), so these literals must not be edited.
            var samples = new List<TextData>()
            {
                new TextData(){ Text = "This is an example to compute bag-of-word features." },
                new TextData(){ Text = "ML.NET's ProduceWordBags API produces bag-of-word features from input text." },
                new TextData(){ Text = "It does so by first tokenizing text/string into words/tokens then " },
                new TextData(){ Text = "computing n-grams and their neumeric values." },
                new TextData(){ Text = "Each position in the output vector corresponds to a particular n-gram." },
                new TextData(){ Text = "The value at each position corresponds to," },
                new TextData(){ Text = "the number of times n-gram occured in the data (Tf), or" },
                new TextData(){ Text = "the inverse of the number of documents contain the n-gram (Idf)," },
                new TextData(){ Text = "or compute both and multipy together (Tf-Idf)." },
            };

            var dataview = mlContext.Data.LoadFromEnumerable(samples);
            // Two chained word-bag stages: trigrams only (useAllLengths: false)
            // with raw term-frequency weighting. The second stage exercises a
            // second invocation of ProduceWordBags in the same chain.
            var textPipeline =
                mlContext.Transforms.Text.ProduceWordBags("Text", "Text",
                ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf).Append(
                mlContext.Transforms.Text.ProduceWordBags("Text2", "Text2",
                ngramLength: 3, useAllLengths: false, weighting: NgramExtractingEstimator.WeightingCriteria.Tf));


            var textTransformer = textPipeline.Fit(dataview);
            var transformedDataView = textTransformer.Transform(dataview);
            var predictionEngine = mlContext.Model.CreatePredictionEngine<TextData, TransformedTextData>(textTransformer);
            var prediction = predictionEngine.Predict(samples[0]);
            // xUnit's Assert.Equal signature is (expected, actual); the original
            // call had the arguments reversed, which yields misleading
            // "Expected/Actual" labels in failure output. The first sentence
            // contributes the first six trigrams of the dictionary exactly once.
            Assert.Equal(new float[] {
                1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }, prediction.Text);
        }

        // Input row: free text plus a constant second column used to drive the
        // second word-bag stage of the pipeline.
        private class TextData
        {
            public string Text { get; set; }
#pragma warning disable 414
            // Deliberately a field (not a property) with a constant value; CS0414
            // ("assigned but never used") is suppressed because the value is
            // presumably consumed via the data view's schema rather than read
            // directly in this file — TODO confirm.
            public string Text2 = "ABC";
#pragma warning restore 414
        }

        // Output row: the bag-of-words vector computed for the "Text" column.
        private class TransformedTextData
        {
            public float[] Text { get; set; }
        }
    }

}