Skip to content

Commit b4bff87

Browse files
Add public generic methods to TextLoader catalog that accept Options objects (#5134)
* Added new API for generic text loader catalog methods, with an Options parameter * Added a couple of tests to directly test the new API
1 parent 7d8c85b commit b4bff87

File tree

4 files changed

+96
-30
lines changed

4 files changed

+96
-30
lines changed

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs

+18-9
Original file line numberDiff line numberDiff line change
@@ -1460,6 +1460,23 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
14601460
bool trimWhitespace = Defaults.TrimWhitespace,
14611461
IMultiStreamSource dataSample = null)
14621462
{
1463+
Options options = new Options
1464+
{
1465+
HasHeader = hasHeader,
1466+
Separators = new[] { separator },
1467+
AllowQuoting = allowQuoting,
1468+
AllowSparse = supportSparse,
1469+
TrimWhitespace = trimWhitespace
1470+
};
1471+
1472+
return CreateTextLoader<TInput>(host, options, dataSample);
1473+
}
1474+
1475+
internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
1476+
Options options = null,
1477+
IMultiStreamSource dataSample = null)
1478+
{
1479+
options = options ?? new Options();
14631480
var userType = typeof(TInput);
14641481

14651482
var fieldInfos = userType.GetFields(BindingFlags.Public | BindingFlags.Instance);
@@ -1520,15 +1537,7 @@ internal static TextLoader CreateTextLoader<TInput>(IHostEnvironment host,
15201537
columns.Add(column);
15211538
}
15221539

1523-
Options options = new Options
1524-
{
1525-
HasHeader = hasHeader,
1526-
Separators = new[] { separator },
1527-
AllowQuoting = allowQuoting,
1528-
AllowSparse = supportSparse,
1529-
TrimWhitespace = trimWhitespace,
1530-
Columns = columns.ToArray()
1531-
};
1540+
options.Columns = columns.ToArray();
15321541

15331542
return new TextLoader(host, options, dataSample: dataSample);
15341543
}

src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderSaverCatalog.cs

+49-14
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,19 @@ public static TextLoader CreateTextLoader<TInput>(this DataOperationsCatalog cat
119119
=> TextLoader.CreateTextLoader<TInput>(CatalogUtils.GetEnvironment(catalog), hasHeader, separatorChar, allowQuoting,
120120
allowSparse, trimWhitespace, dataSample: dataSample);
121121

122+
/// <summary>
123+
/// Create a text loader <see cref="TextLoader"/> by inferring the dataset schema from a data model type.
124+
/// </summary>
125+
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
126+
/// <param name="options">Defines the settings of the load operation. No need to specify a Columns field,
127+
/// as columns will be inferred by this method.</param>
128+
/// <param name="dataSample">The optional location of a data sample. The sample can be used to infer information
129+
/// about the columns, such as slot names.</param>
130+
public static TextLoader CreateTextLoader<TInput>(this DataOperationsCatalog catalog,
131+
TextLoader.Options options,
132+
IMultiStreamSource dataSample = null)
133+
=> TextLoader.CreateTextLoader<TInput>(CatalogUtils.GetEnvironment(catalog), options, dataSample);
134+
122135
/// <summary>
123136
/// Load a <see cref="IDataView"/> from a text file using <see cref="TextLoader"/>.
124137
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
@@ -172,6 +185,35 @@ public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog,
172185
return loader.Load(new MultiFileSource(path));
173186
}
174187

188+
/// <summary>
189+
/// Load a <see cref="IDataView"/> from a text file using <see cref="TextLoader"/>.
190+
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
191+
/// </summary>
192+
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
193+
/// <param name="path">Specifies a file from which to load.</param>
194+
/// <param name="options">Defines the settings of the load operation.</param>
195+
/// <example>
196+
/// <format type="text/markdown">
197+
/// <![CDATA[
198+
/// [!code-csharp[LoadFromTextFile](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/SaveAndLoadFromText.cs)]
199+
/// ]]>
200+
/// </format>
201+
/// </example>
202+
public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, string path,
203+
TextLoader.Options options = null)
204+
{
205+
Contracts.CheckNonEmpty(path, nameof(path));
206+
if (!File.Exists(path))
207+
{
208+
throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path);
209+
}
210+
211+
var env = catalog.GetEnvironment();
212+
var source = new MultiFileSource(path);
213+
214+
return new TextLoader(env, options, dataSample: source).Load(source);
215+
}
216+
175217
/// <summary>
176218
/// Load a <see cref="IDataView"/> from a text file using <see cref="TextLoader"/>.
177219
/// Note that <see cref="IDataView"/>'s are lazy, so no actual loading happens here, just schema validation.
@@ -221,27 +263,20 @@ public static IDataView LoadFromTextFile<TInput>(this DataOperationsCatalog cata
221263
/// </summary>
222264
/// <param name="catalog">The <see cref="DataOperationsCatalog"/> catalog.</param>
223265
/// <param name="path">Specifies a file from which to load.</param>
224-
/// <param name="options">Defines the settings of the load operation.</param>
225-
/// <example>
226-
/// <format type="text/markdown">
227-
/// <![CDATA[
228-
/// [!code-csharp[LoadFromTextFile](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/DataOperations/SaveAndLoadFromText.cs)]
229-
/// ]]>
230-
/// </format>
231-
/// </example>
232-
public static IDataView LoadFromTextFile(this DataOperationsCatalog catalog, string path,
233-
TextLoader.Options options = null)
266+
/// <param name="options">Defines the settings of the load operation. No need to specify a Columns field,
267+
/// as columns will be inferred by this method.</param>
268+
/// <returns>The data view.</returns>
269+
public static IDataView LoadFromTextFile<TInput>(this DataOperationsCatalog catalog, string path,
270+
TextLoader.Options options)
234271
{
235272
Contracts.CheckNonEmpty(path, nameof(path));
236273
if (!File.Exists(path))
237274
{
238275
throw Contracts.ExceptParam(nameof(path), "File does not exist at path: {0}", path);
239276
}
240277

241-
var env = catalog.GetEnvironment();
242-
var source = new MultiFileSource(path);
243-
244-
return new TextLoader(env, options, dataSample: source).Load(source);
278+
return TextLoader.CreateTextLoader<TInput>(CatalogUtils.GetEnvironment(catalog), options)
279+
.Load(new MultiFileSource(path));
245280
}
246281

247282
/// <summary>

test/Microsoft.ML.Functional.Tests/Prediction.cs

+7-2
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ public void ReconfigurablePrediction()
3636
{
3737
var mlContext = new MLContext(seed: 1);
3838

39+
var options = new TextLoader.Options
40+
{
41+
HasHeader = TestDatasets.Sentiment.fileHasHeader,
42+
Separators = new[] { TestDatasets.Sentiment.fileSeparator }
43+
};
44+
3945
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(TestCommon.GetDataPath(DataDir, TestDatasets.Sentiment.trainFilename),
40-
hasHeader: TestDatasets.Sentiment.fileHasHeader,
41-
separatorChar: TestDatasets.Sentiment.fileSeparator);
46+
options);
4247

4348
// Create a training pipeline.
4449
var pipeline = mlContext.Transforms.Text.FeaturizeText("Features", "SentimentText")

test/Microsoft.ML.Tests/TextLoaderTests.cs

+22-5
Original file line numberDiff line numberDiff line change
@@ -704,8 +704,10 @@ public class IrisColumnIndices
704704
public string Type;
705705
}
706706

707-
[Fact]
708-
public void LoaderColumnsFromIrisData()
707+
[Theory]
708+
[InlineData(true)]
709+
[InlineData(false)]
710+
public void LoaderColumnsFromIrisData(bool useOptionsObject)
709711
{
710712
var dataPath = GetDataPath(TestDatasets.irisData.trainFilename);
711713
var mlContext = new MLContext(1);
@@ -719,7 +721,12 @@ public void LoaderColumnsFromIrisData()
719721
var irisFirstRowValues = irisFirstRow.Values.GetEnumerator();
720722

721723
// Simple load
722-
var dataIris = mlContext.Data.CreateTextLoader<Iris>(separatorChar: ',').Load(dataPath);
724+
IDataView dataIris;
725+
if (useOptionsObject)
726+
dataIris = mlContext.Data.CreateTextLoader<Iris>(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath);
727+
else
728+
dataIris = mlContext.Data.CreateTextLoader<Iris>(separatorChar: ',').Load(dataPath);
729+
723730
var previewIris = dataIris.Preview(1);
724731

725732
Assert.Equal(5, previewIris.ColumnView.Length);
@@ -735,7 +742,12 @@ public void LoaderColumnsFromIrisData()
735742
Assert.Equal("Iris-setosa", previewIris.RowView[0].Values[index].Value.ToString());
736743

737744
// Load with start and end indexes
738-
var dataIrisStartEnd = mlContext.Data.CreateTextLoader<IrisStartEnd>(separatorChar: ',').Load(dataPath);
745+
IDataView dataIrisStartEnd;
746+
if (useOptionsObject)
747+
dataIrisStartEnd = mlContext.Data.CreateTextLoader<IrisStartEnd>(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath);
748+
else
749+
dataIrisStartEnd = mlContext.Data.CreateTextLoader<IrisStartEnd>(separatorChar: ',').Load(dataPath);
750+
739751
var previewIrisStartEnd = dataIrisStartEnd.Preview(1);
740752

741753
Assert.Equal(2, previewIrisStartEnd.ColumnView.Length);
@@ -752,7 +764,12 @@ public void LoaderColumnsFromIrisData()
752764
}
753765

754766
// load setting the distinct columns. Loading column 0 and 2
755-
var dataIrisColumnIndices = mlContext.Data.CreateTextLoader<IrisColumnIndices>(separatorChar: ',').Load(dataPath);
767+
IDataView dataIrisColumnIndices;
768+
if (useOptionsObject)
769+
dataIrisColumnIndices = mlContext.Data.CreateTextLoader<IrisColumnIndices>(new TextLoader.Options() { Separator = ",", AllowQuoting = false }).Load(dataPath);
770+
else
771+
dataIrisColumnIndices = mlContext.Data.CreateTextLoader<IrisColumnIndices>(separatorChar: ',').Load(dataPath);
772+
756773
var previewIrisColumnIndices = dataIrisColumnIndices.Preview(1);
757774

758775
Assert.Equal(2, previewIrisColumnIndices.ColumnView.Length);

0 commit comments

Comments
 (0)