Skip to content

Commit 268ebbc

Browse files
authored
[Part 2] Added convenience constructors for set of transforms. (dotnet#491)
1 parent fbc00db commit 268ebbc

12 files changed

+256
-14
lines changed

src/Microsoft.ML.Data/Transforms/ChooseColumnsTransform.cs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,20 @@ public bool TryUnparse(StringBuilder sb)
5858

5959
public sealed class Arguments
6060
{
61+
/// <summary>
/// Creates an empty set of arguments; the Column field is left unassigned for the caller to populate.
/// Declared explicitly because the internal columns-based constructor would otherwise
/// suppress the implicit parameterless constructor.
/// </summary>
public Arguments()
{
}
65+
66+
/// <summary>
/// Builds arguments that choose the given columns, each one keeping its own name
/// (source and output name are identical).
/// </summary>
/// <param name="columns">Names of the columns to choose. An empty list yields an empty column array.</param>
/// <exception cref="System.ArgumentNullException"><paramref name="columns"/> is null.</exception>
internal Arguments(params string[] columns)
{
    // Fail with a descriptive exception instead of a NullReferenceException on columns.Length.
    if (columns == null)
    {
        throw new System.ArgumentNullException(nameof(columns));
    }

    Column = new Column[columns.Length];
    for (int i = 0; i < columns.Length; i++)
    {
        Column[i] = new Column() { Source = columns[i], Name = columns[i] };
    }
}
74+
6175
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
6276
public Column[] Column;
6377

@@ -442,6 +456,17 @@ private static VersionInfo GetVersionInfo()
442456

443457
private const string RegistrationName = "ChooseColumns";
444458

459+
/// <summary>
/// Convenience constructor for the public-facing API: chooses columns by name,
/// with every chosen column keeping the name it has in the input.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="columns">Names of the columns to choose.</param>
public ChooseColumnsTransform(
    IHostEnvironment env,
    IDataView input,
    params string[] columns)
    : this(env, new Arguments(columns), input)
{
}
469+
445470
/// <summary>
446471
/// Public constructor corresponding to SignatureDataTransform.
447472
/// </summary>

src/Microsoft.ML.Data/Transforms/ConvertTransform.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,23 @@ private static VersionInfo GetVersionInfo()
169169
// This is parallel to Infos.
170170
private readonly ColInfoEx[] _exes;
171171

172+
/// <summary>
/// Convenience constructor for the public-facing API: converts a single column to the requested type.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="resultType">The expected type of the converted column.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to convert. When null, <paramref name="name"/> is used as the source as well.</param>
public ConvertTransform(
    IHostEnvironment env,
    IDataView input,
    DataKind resultType,
    string name,
    string source = null)
    : this(
        env,
        new Arguments
        {
            ResultType = resultType,
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
188+
172189
public ConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
173190
: base(env, RegistrationName, env.CheckRef(args, nameof(args)).Column,
174191
input, null)

src/Microsoft.ML.Data/Transforms/GenerateNumberTransform.cs

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,22 @@ private bool TryParse(string str)
7777
}
7878
}
7979

80+
/// <summary>
/// Default values shared by the Arguments fields and the convenience constructor,
/// so the two cannot drift apart.
/// </summary>
private static class Defaults
{
    // Random seed used when none is specified.
    public const uint Seed = 42;

    // By default random numbers are produced rather than an auto-incremented counter.
    public const bool UseCounter = false;
}
85+
8086
// Settings for the transform: the column definitions, whether to emit a counter
// instead of random numbers, and the random seed.
public sealed class Arguments : TransformInputBase
8187
{
8288
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "New column definition(s) (optional form: name:seed)", ShortName = "col", SortOrder = 1)]
8389
public Column[] Column;
8490

8591
// The initial value now comes from Defaults.UseCounter (false), the same value the field had implicitly before.
[Argument(ArgumentType.AtMostOnce, HelpText = "Use an auto-incremented integer starting at zero instead of a random number", ShortName = "cnt")]
86-
public bool UseCounter;
92+
public bool UseCounter = Defaults.UseCounter;
8793

8894
// The initial value now comes from Defaults.Seed (42), the same literal that was assigned before.
[Argument(ArgumentType.AtMostOnce, HelpText = "The random seed")]
89-
public uint Seed = 42;
95+
public uint Seed = Defaults.Seed;
9096
}
9197

9298
private sealed class Bindings : ColumnBindingsBase
@@ -250,6 +256,18 @@ private static VersionInfo GetVersionInfo()
250256

251257
private const string RegistrationName = "GenerateNumber";
252258

259+
/// <summary>
/// Convenience constructor for the public-facing API: generates a single output column.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="useCounter">Use an auto-incremented integer starting at zero instead of a random number.</param>
/// <param name="seed">The random seed. Defaults to the same value the Arguments class uses, so existing callers are unaffected.</param>
public GenerateNumberTransform(IHostEnvironment env, IDataView input, string name, bool useCounter = Defaults.UseCounter, uint seed = Defaults.Seed)
    // The seed is forwarded as the transform-level Arguments.Seed.
    // NOTE(review): a per-column seed (the "name:seed" column form) presumably takes precedence when set; confirm.
    : this(env, new Arguments() { Column = new[] { new Column() { Name = name } }, UseCounter = useCounter, Seed = seed }, input)
{
}
270+
253271
/// <summary>
254272
/// Public constructor corresponding to SignatureDataTransform.
255273
/// </summary>

src/Microsoft.ML.Data/Transforms/HashTransform.cs

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ public sealed class HashTransform : OneToOneTransformBase, ITransformTemplate
3333
public const int NumBitsMin = 1;
3434
public const int NumBitsLim = 32;
3535

36+
/// <summary>
/// Default values shared by the Arguments fields and the convenience constructor.
/// </summary>
private static class Defaults
{
    // Hashing seed.
    public const uint Seed = 314489979;

    // One below the exclusive limit NumBitsLim.
    public const int HashBits = NumBitsLim - 1;

    // 0 means no invert hashing.
    public const int InvertHash = 0;

    // By default the position of each term is not included in the hash.
    public const bool Ordered = false;
}
43+
3644
public sealed class Arguments
3745
{
3846
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col",
@@ -41,18 +49,18 @@ public sealed class Arguments
4149

4250
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of bits to hash into. Must be between 1 and 31, inclusive",
4351
ShortName = "bits", SortOrder = 2)]
44-
public int HashBits = NumBitsLim - 1;
52+
public int HashBits = Defaults.HashBits;
4553

4654
[Argument(ArgumentType.AtMostOnce, HelpText = "Hashing seed")]
47-
public uint Seed = 314489979;
55+
public uint Seed = Defaults.Seed;
4856

4957
[Argument(ArgumentType.AtMostOnce, HelpText = "Whether the position of each term should be included in the hash",
5058
ShortName = "ord")]
51-
public bool Ordered;
59+
public bool Ordered = Defaults.Ordered;
5260

5361
[Argument(ArgumentType.AtMostOnce, HelpText = "Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.",
5462
ShortName = "ih")]
55-
public int InvertHash;
63+
public int InvertHash = Defaults.InvertHash;
5664
}
5765

5866
public sealed class Column : OneToOneColumn
@@ -234,6 +242,27 @@ public override void Save(ModelSaveContext ctx)
234242
TextModelHelper.SaveAll(Host, ctx, Infos.Length, _keyValues);
235243
}
236244

245+
/// <summary>
/// Convenience constructor for the public-facing API: hashes a single column.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the column to be transformed. When null, <paramref name="name"/> is used as the source as well.</param>
/// <param name="hashBits">Number of bits to hash into. Must be between 1 and 31, inclusive.</param>
/// <param name="invertHash">Limit the number of keys used to generate the slot name to this many. 0 means no invert hashing, -1 means no limit.</param>
public HashTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    string source = null,
    int hashBits = Defaults.HashBits,
    int invertHash = Defaults.InvertHash)
    : this(
        env,
        new Arguments
        {
            HashBits = hashBits,
            InvertHash = invertHash,
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
265+
237266
public HashTransform(IHostEnvironment env, Arguments args, IDataView input)
238267
: base(Contracts.CheckRef(env, nameof(env)), RegistrationName, env.CheckRef(args, nameof(args)).Column,
239268
input, TestType)

src/Microsoft.ML.Data/Transforms/KeyToValueTransform.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,19 @@ private static VersionInfo GetVersionInfo()
7373
private readonly ColumnType[] _types;
7474
private KeyToValueMap[] _kvMaps;
7575

76+
/// <summary>
/// Convenience constructor for the public-facing API: maps a single column.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used as the source as well.</param>
public KeyToValueTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
87+
88+
7689
/// <summary>
7790
/// Public constructor corresponding to SignatureDataTransform.
7891
/// </summary>

src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,19 @@ public bool TryUnparse(StringBuilder sb)
7070
}
7171
}
7272

73+
/// <summary>
/// Default values shared by the Arguments fields and the convenience constructor.
/// </summary>
private static class Defaults
{
    // By default indicator vectors are concatenated rather than combined into a single bag vector.
    public const bool Bag = false;
}
77+
7378
// Settings for the transform: the column definitions and the bag option.
public sealed class Arguments
7479
{
7580
[Argument(ArgumentType.Multiple, HelpText = "New column definition(s) (optional form: name:src)", ShortName = "col", SortOrder = 1)]
7681
public Column[] Column;
7782

7883
// The initial value now comes from Defaults.Bag (false), the same value the field had implicitly before.
[Argument(ArgumentType.AtMostOnce,
7984
HelpText = "Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.")]
80-
public bool Bag;
85+
public bool Bag = Defaults.Bag;
8186
}
8287

8388
internal const string Summary = "Converts a key column to an indicator vector.";
@@ -112,6 +117,23 @@ private static VersionInfo GetVersionInfo()
112117
private readonly bool[] _concat;
113118
private readonly VectorType[] _types;
114119

120+
/// <summary>
/// Convenience constructor for the public-facing API: transforms a single column.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used as the source as well.</param>
/// <param name="bag">Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input is a vector.</param>
public KeyToVectorTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    string source = null,
    bool bag = Defaults.Bag)
    : this(
        env,
        new Arguments
        {
            Bag = bag,
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
136+
115137
/// <summary>
116138
/// Public constructor corresponding to SignatureDataTransform.
117139
/// </summary>

src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,18 @@ private static VersionInfo GetVersionInfo()
6464
private const string RegistrationName = "LabelConvert";
6565
private VectorType _slotType;
6666

67+
/// <summary>
/// Convenience constructor for the public-facing API: converts a single column.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used as the source as well.</param>
public LabelConvertTransform(
    IHostEnvironment env,
    IDataView input,
    string name,
    string source = null)
    : this(
        env,
        new Arguments
        {
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
78+
6779
public LabelConvertTransform(IHostEnvironment env, Arguments args, IDataView input)
6880
: base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, RowCursorUtils.TestGetLabelGetter)
6981
{

src/Microsoft.ML.Data/Transforms/LabelIndicatorTransform.cs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,23 @@ private static string TestIsMulticlassLabel(ColumnType type)
111111
return $"Label column type is not supported for binary remapping: {type}. Supported types: key, float, double.";
112112
}
113113

114+
/// <summary>
/// Convenience constructor for the public-facing API: remaps a single label column
/// against the given class index.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="classIndex">Label of the positive class.</param>
/// <param name="name">Name of the output column.</param>
/// <param name="source">Name of the input column. When null, <paramref name="name"/> is used as the source as well.</param>
public LabelIndicatorTransform(
    IHostEnvironment env,
    IDataView input,
    int classIndex,
    string name,
    string source = null)
    : this(
        env,
        new Arguments
        {
            ClassIndex = classIndex,
            Column = new Column[]
            {
                new Column { Name = name, Source = source ?? name },
            },
        },
        input)
{
}
130+
114131
public LabelIndicatorTransform(IHostEnvironment env, Arguments args, IDataView input)
115132
: base(env, LoadName, Contracts.CheckRef(args, nameof(args)).Column,
116133
input, TestIsMulticlassLabel)

src/Microsoft.ML.Data/Transforms/RangeFilter.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,19 @@ private static VersionInfo GetVersionInfo()
7777
private readonly bool _includeMin;
7878
private readonly bool _includeMax;
7979

80+
/// <summary>
/// Convenience constructor for the public-facing API: filters on a single column
/// using optional lower and upper bounds.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="column">Name of the input column.</param>
/// <param name="minimum">Minimum value (0 to 1 for key types). Passed through as-is, including null.</param>
/// <param name="maximum">Maximum value (0 to 1 for key types). Passed through as-is, including null.</param>
public RangeFilter(
    IHostEnvironment env,
    IDataView input,
    string column,
    Double? minimum = null,
    Double? maximum = null)
    : this(
        env,
        new Arguments
        {
            Column = column,
            Min = minimum,
            Max = maximum,
        },
        input)
{
}
92+
8093
public RangeFilter(IHostEnvironment env, Arguments args, IDataView input)
8194
: base(env, RegistrationName, input)
8295
{

src/Microsoft.ML.Data/Transforms/ShuffleTransform.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,25 @@ namespace Microsoft.ML.Runtime.Data
3333
/// </summary>
3434
public sealed class ShuffleTransform : RowToRowTransformBase
3535
{
36+
/// <summary>
/// Default values shared by the Arguments fields and the convenience constructor.
/// </summary>
private static class Defaults
{
    // Number of rows held in the pool.
    public const int PoolRows = 1000;

    // By default a shuffled view is not forced.
    public const bool ForceShuffle = false;

    // By default shuffling is not restricted to the pool alone.
    public const bool PoolOnly = false;
}
42+
3643
public sealed class Arguments
3744
{
3845
// REVIEW: A more intelligent heuristic, based on the expected size of the inputs, perhaps?
3946
[Argument(ArgumentType.LastOccurenceWins, HelpText = "The pool will have this many rows", ShortName = "rows")]
40-
public int PoolRows = 1000;
47+
public int PoolRows = Defaults.PoolRows;
4148

4249
// REVIEW: Come up with a better way to specify the desired set of functionality.
4350
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.", ShortName = "po")]
44-
public bool PoolOnly;
51+
public bool PoolOnly = Defaults.PoolOnly;
4552

4653
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always provide a shuffled view.", ShortName = "force")]
47-
public bool ForceShuffle;
54+
public bool ForceShuffle = Defaults.ForceShuffle;
4855

4956
[Argument(ArgumentType.LastOccurenceWins, HelpText = "If true, the transform will always shuffle the input. The default value is the same as forceShuffle.", ShortName = "forceSource")]
5057
public bool? ForceShuffleSource;
@@ -79,6 +86,23 @@ private static VersionInfo GetVersionInfo()
7986
// know how to copy other types of values.
8087
private readonly IDataView _subsetInput;
8188

89+
/// <summary>
/// Convenience constructor for the public-facing API; every option falls back to the
/// value in Defaults when omitted.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>, i.e. the output of the previous transform or loader.</param>
/// <param name="poolRows">The pool will have this many rows.</param>
/// <param name="poolOnly">If true, the transform will not attempt to shuffle the input cursor but only shuffle based on the pool. This parameter has no effect if the input data was not itself shufflable.</param>
/// <param name="forceShuffle">If true, the transform will always provide a shuffled view.</param>
public ShuffleTransform(
    IHostEnvironment env,
    IDataView input,
    int poolRows = Defaults.PoolRows,
    bool poolOnly = Defaults.PoolOnly,
    bool forceShuffle = Defaults.ForceShuffle)
    : this(
        env,
        new Arguments
        {
            PoolRows = poolRows,
            PoolOnly = poolOnly,
            ForceShuffle = forceShuffle,
        },
        input)
{
}
105+
82106
/// <summary>
83107
/// Public constructor corresponding to SignatureDataTransform.
84108
/// </summary>

0 commit comments

Comments
 (0)