Skip to content

Added decimal marker option in TextLoader #5145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
May 22, 2020
Merged
Previous commit
Next commit
Added unit test for ',' as a decimal marker, and added decimalMarker to TextLoaderCursor and TextLoaderParser
  • Loading branch information
mstfbl committed May 20, 2020
commit 7658a70887bdda3c59bef9f5ae30b27175c00b9b
12 changes: 7 additions & 5 deletions src/Microsoft.ML.Core/Utilities/DoubleParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,8 @@ private static bool TryParseCore(ReadOnlySpan<char> span, ref int ich, ref bool
Contracts.Assert(num == 0);
Contracts.Assert(exp == 0);

const char decimalMarker = '.';

if (ich >= span.Length)
return false;

Expand Down Expand Up @@ -554,7 +556,7 @@ private static bool TryParseCore(ReadOnlySpan<char> span, ref int ich, ref bool
return false;
break;

case '.':
case decimalMarker:
goto LPoint;

// The common cases.
Expand All @@ -571,7 +573,7 @@ private static bool TryParseCore(ReadOnlySpan<char> span, ref int ich, ref bool
break;
}

// Get digits before '.'
// Get digits before the decimal marker, which may be '.' or ','
uint d;
for (; ; )
{
Expand All @@ -593,14 +595,14 @@ private static bool TryParseCore(ReadOnlySpan<char> span, ref int ich, ref bool
}
Contracts.Assert(i < span.Length);

if (span[i] != '.')
if (span[i] != decimalMarker)
goto LAfterDigits;

LPoint:
Contracts.Assert(i < span.Length);
Contracts.Assert(span[i] == '.');
Contracts.Assert(span[i] == decimalMarker);

// Get the digits after '.'
// Get the digits after the decimal marker, which may be '.' or ','
for (; ; )
{
if (++i >= span.Length)
Expand Down
13 changes: 8 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -709,11 +709,11 @@ public Bindings(TextLoader parent, Column[] cols, IMultiStreamSource headerFile,
ch.Assert(0 <= inputSize & inputSize < SrcLim);
List<ReadOnlyMemory<char>> lines = null;
if (headerFile != null)
Cursor.GetSomeLines(headerFile, 1, parent.ReadMultilines, parent._separators, ref lines);
Cursor.GetSomeLines(headerFile, 1, parent.ReadMultilines, parent._separators, ref lines, parent._decimalMarker);
if (needInputSize && inputSize == 0)
Cursor.GetSomeLines(dataSample, 100, parent.ReadMultilines, parent._separators, ref lines);
Cursor.GetSomeLines(dataSample, 100, parent.ReadMultilines, parent._separators, ref lines, parent._decimalMarker);
else if (headerFile == null && parent.HasHeader)
Cursor.GetSomeLines(dataSample, 1, parent.ReadMultilines, parent._separators, ref lines);
Cursor.GetSomeLines(dataSample, 1, parent.ReadMultilines, parent._separators, ref lines, parent._decimalMarker);

if (needInputSize && inputSize == 0)
{
Expand Down Expand Up @@ -1410,8 +1410,11 @@ private TextLoader(IHost host, ModelLoadContext ctx)
if (_separators.Contains(':'))
host.CheckDecode((_flags & OptionFlags.AllowSparse) == 0);

_decimalMarker = ctx.Reader.ReadChar();
host.CheckDecode(_decimalMarker == '.' || _decimalMarker == ',');
if (ctx.Header.ModelVerWritten >= 0x0001000D)
{
_decimalMarker = ctx.Reader.ReadChar();
host.CheckDecode(_decimalMarker == '.' || _decimalMarker == ',');
}
_bindings = new Bindings(ctx, this);
_parser = new Parser(this);
}
Expand Down
12 changes: 7 additions & 5 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderCursor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ public static DataViewRowCursor Create(TextLoader parent, IMultiStreamSource fil
SetupCursor(parent, active, 0, out srcNeeded, out cthd);
Contracts.Assert(cthd > 0);

var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent.ReadMultilines, parent._separators, parent._maxRows, 1);
var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent.ReadMultilines, parent._separators, parent._maxRows, 1, parent._decimalMarker);
var stats = new ParseStats(parent._host, 1);
return new Cursor(parent, stats, active, reader, srcNeeded, cthd);
}
Expand All @@ -163,7 +163,7 @@ public static DataViewRowCursor[] CreateSet(TextLoader parent, IMultiStreamSourc
SetupCursor(parent, active, n, out srcNeeded, out cthd);
Contracts.Assert(cthd > 0);

var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent.ReadMultilines, parent._separators, parent._maxRows, cthd);
var reader = new LineReader(files, BatchSize, 100, parent.HasHeader, parent.ReadMultilines, parent._separators, parent._maxRows, cthd, parent._decimalMarker);
var stats = new ParseStats(parent._host, cthd);
if (cthd <= 1)
return new DataViewRowCursor[1] { new Cursor(parent, stats, active, reader, srcNeeded, 1) };
Expand Down Expand Up @@ -205,7 +205,7 @@ public override ValueGetter<DataViewRowId> GetIdGetter()
};
}

public static void GetSomeLines(IMultiStreamSource source, int count, bool readMultilines, char[] separators, ref List<ReadOnlyMemory<char>> lines)
public static void GetSomeLines(IMultiStreamSource source, int count, bool readMultilines, char[] separators, ref List<ReadOnlyMemory<char>> lines, char decimalMarker)
{
Contracts.AssertValue(source);
Contracts.Assert(count > 0);
Expand All @@ -215,7 +215,7 @@ public static void GetSomeLines(IMultiStreamSource source, int count, bool readM
count = 2;

LineBatch batch;
var reader = new LineReader(source, count, 1, false, readMultilines, separators, count, 1);
var reader = new LineReader(source, count, 1, false, readMultilines, separators, count, 1, decimalMarker);
try
{
batch = reader.GetBatch();
Expand Down Expand Up @@ -404,6 +404,7 @@ private sealed class LineReader
private readonly bool _hasHeader;
private readonly bool _readMultilines;
private readonly char[] _separators;
private readonly char _decimalMarker;
private readonly int _batchSize;
private readonly IMultiStreamSource _files;

Expand All @@ -413,7 +414,7 @@ private sealed class LineReader
private Task _thdRead;
private volatile bool _abort;

public LineReader(IMultiStreamSource files, int batchSize, int bufSize, bool hasHeader, bool readMultilines, char[] separators, long limit, int cref)
public LineReader(IMultiStreamSource files, int batchSize, int bufSize, bool hasHeader, bool readMultilines, char[] separators, long limit, int cref, char decimalMarker)
{
// Note that files is allowed to be empty.
Contracts.AssertValue(files);
Expand All @@ -430,6 +431,7 @@ public LineReader(IMultiStreamSource files, int batchSize, int bufSize, bool has
_separators = separators;
_files = files;
_cref = cref;
_decimalMarker = decimalMarker;

_queue = new BlockingQueue<LineBatch>(bufSize);
_thdRead = Utils.RunOnBackgroundThreadAsync(ThreadProc);
Expand Down
2 changes: 2 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ public void Clear()
}

private readonly char[] _separators;
private readonly char _decimalMarker;
private readonly OptionFlags _flags;
private readonly int _inputSize;
private readonly ColInfo[] _infos;
Expand Down Expand Up @@ -683,6 +684,7 @@ public Parser(TextLoader parent)
}

_separators = parent._separators;
_decimalMarker = parent._decimalMarker;
_flags = parent._flags;
_inputSize = parent._inputSize;
Contracts.Assert(_inputSize >= 0);
Expand Down
35 changes: 35 additions & 0 deletions test/Microsoft.ML.Tests/TextLoaderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,41 @@ public void TestTextLoaderKeyTypeBackCompat()
}
}

/// <summary>
/// Loads a dataset whose numeric fields use ',' as the decimal marker, trains a
/// one-versus-all linear SVM on it, and checks that the micro-accuracy measured on the
/// same data exceeds 0.83.
/// </summary>
[Fact]
public void TestCommaAsDecimalMarker()
{
    // dataPath is already the result of GetDataPath, so it is handed to Load as-is below
    // rather than being passed through GetDataPath a second time.
    string dataPath = GetDataPath("iris_decimal_marker_as_comma.txt");

    // Create a new context for ML.NET operations. It can be used for exception tracking and logging,
    // as a catalog of available operations and as the source of randomness.
    var mlContext = new MLContext(seed: 1);
    var reader = new TextLoader(mlContext, new TextLoader.Options()
    {
        Columns = new[]
        {
            new TextLoader.Column("Label", DataKind.Single, 0),
            new TextLoader.Column("Features", DataKind.Single, new [] { new TextLoader.Range(1, 4) }),
        },
        DecimalMarker = ','
    });

    // Data: map the label column to a key type and cache the transformed view.
    var textData = reader.Load(dataPath);
    var data = mlContext.Data.Cache(mlContext.Transforms.Conversion.MapValueToKey("Label")
        .Fit(textData).Transform(textData));

    // Pipeline
    var pipeline = mlContext.MulticlassClassification.Trainers.OneVersusAll(
        mlContext.BinaryClassification.Trainers.LinearSvm(new Trainers.LinearSvmTrainer.Options { NumberOfIterations = 100 }),
        useProbabilities: false);

    var model = pipeline.Fit(data);
    var predictions = model.Transform(data);

    // Metrics: evaluated on the same data the model was fit on.
    var metrics = mlContext.MulticlassClassification.Evaluate(predictions);
    Assert.True(metrics.MicroAccuracy > 0.83);
}

// Empty class: declares no fields or properties.
// NOTE(review): presumably used by a test that loads data into a type with no members — confirm against its usages.
private class IrisNoFields
{
}
Expand Down