Skip to content

Access indices array for VBuffer in KeyToVector transformer only when resulting vector is sparse. #3763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/Microsoft.ML.Data/Transforms/KeyToVector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ private ValueGetter<VBuffer<float>> MakeGetterInd(DataViewRow input, int iinfo)
int lenDst = checked(size * lenSrc);
var values = src.GetValues();
int cntSrc = values.Length;
var editor = VBufferEditor.Create(ref dst, lenDst, cntSrc);
var editor = VBufferEditor.Create(ref dst, lenDst, cntSrc, keepOldOnResize: false, requireIndicesOnDense: true);

int count = 0;
if (src.IsDense)
Expand Down Expand Up @@ -814,14 +814,16 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)

var metadata = new List<SchemaShape.Column>();
if (col.Annotations.TryFindColumn(AnnotationUtils.Kinds.KeyValues, out var keyMeta))
if (col.Kind != SchemaShape.Column.VectorKind.VariableVector && keyMeta.ItemType is TextDataViewType)
if (((colInfo.OutputCountVector && col.IsKey) || col.Kind != SchemaShape.Column.VectorKind.VariableVector) && keyMeta.ItemType is TextDataViewType)
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false));
if (!colInfo.OutputCountVector && (col.Kind == SchemaShape.Column.VectorKind.Scalar || col.Kind == SchemaShape.Column.VectorKind.Vector))
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.CategoricalSlotRanges, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Int32, false));
if (!colInfo.OutputCountVector || (col.Kind == SchemaShape.Column.VectorKind.Scalar))
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));

result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
result[colInfo.Name] = new SchemaShape.Column(colInfo.Name,
col.Kind == SchemaShape.Column.VectorKind.VariableVector && !colInfo.OutputCountVector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector,
NumberDataViewType.Single, false, new SchemaShape(metadata));
}

return new SchemaShape(result.Values);
Expand Down
4 changes: 3 additions & 1 deletion src/Microsoft.ML.Transforms/KeyToVectorMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,9 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false));
if (col.Kind == SchemaShape.Column.VectorKind.Scalar)
metadata.Add(new SchemaShape.Column(AnnotationUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BooleanDataViewType.Instance, false));
result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, SchemaShape.Column.VectorKind.Vector, NumberDataViewType.Single, false, new SchemaShape(metadata));
result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName,
col.Kind == SchemaShape.Column.VectorKind.VariableVector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Vector,
NumberDataViewType.Single, false, new SchemaShape(metadata));
}

return new SchemaShape(result.Values);
Expand Down
31 changes: 31 additions & 0 deletions test/BaselineOutput/Common/Categorical/oneHot.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=A:I4:0
#@ col=B:I4:1-2
#@ col=C:I4:3-**
#@ col={name=CatA type=U4 src={ min=-1} key=2}
#@ col={name=CatA src={ min=-1 max=0 vector=+}}
#@ col={name=CatB type=U4 src={ min=-1} key=2}
#@ col={name=CatB src={ min=-1 max=1 vector=+}}
#@ col={name=CatC type=U4 src={ min=-1} key=2}
#@ col={name=CatC src={ min=-1 max=0 vector=+}}
#@ col={name=CatD type=U4 src={ min=-1} key=2}
#@ col={name=CatVA type=U4 src={ min=-1 max=0 vector=+} key=3}
#@ col={name=CatVA src={ min=-1 max=1 vector=+}}
#@ col={name=CatVB type=U4 src={ min=-1 max=0 vector=+} key=3}
#@ col={name=CatVB src={ min=-1 max=4 vector=+}}
#@ col={name=CatVC type=U4 src={ min=-1 max=0 vector=+} key=3}
#@ col={name=CatVC src={ min=-1 max=4 vector=+}}
#@ col={name=CatVD type=U4 src={ min=-1 max=0 vector=+} key=3}
#@ col={name=CatVVA type=U4 src={ min=-1 var=+} key=3}
#@ col={name=CatVVA src={ min=-1 max=1 vector=+}}
#@ col={name=CatVVB type=U4 src={ min=-1 var=+} key=3}
#@ col={name=CatVVB src={ min=-1 var=+}}
#@ col={name=CatVVC type=U4 src={ min=-1 var=+} key=3}
#@ col={name=CatVVC src={ min=-1 var=+}}
#@ col={name=CatVVD type=U4 src={ min=-1 var=+} key=3}
#@ }
A "" "" CatA 1 4 CatB Bit2 Bit1 Bit0 CatC 1 4 CatD "" "" 2 3 4 "" "" [0].Bit2 [0].Bit1 [0].Bit0 [1].Bit2 [1].Bit1 [1].Bit0 "" "" [0].2 [0].3 [0].4 [1].2 [1].3 [1].4 "" "" 3 4 2
1 2 3 3 4 0 1 0 0 0 0 0 0 1 0 0 0 1 1 1 0 0 1 0 0 0 0 0 1 0 1 1 0 0 0 1 0 0 1 0 1 1 1 0 0 1 0 0 0 0 0 1 0 1 1 0 0 0 1 0 0 1
4 2 4 2 4 3 1 0 1 1 0 0 1 1 0 1 1 0 2 1 0 1 0 2 0 0 0 0 1 0 0 2 1 0 0 0 0 1 0 2 2 1 0 1 1 1 2 1 0 0 1 0 0 0 1 0 0 0 2 1 0 0 0 1 0 1 0 1 0 0 2 1 0
31 changes: 31 additions & 0 deletions test/BaselineOutput/Common/CategoricalHash/oneHotHash.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#@ TextLoader{
#@ header+
#@ sep=tab
#@ col=A:TX:0
#@ col=B:TX:1-2
#@ col=C:TX:3-**
#@ col={name=CatA type=U4 src={ min=-1} key=65536}
#@ col={name=CatA src={ min=-1 max=65534 vector=+}}
#@ col={name=CatB type=U4 src={ min=-1} key=65536}
#@ col={name=CatB src={ min=-1 max=16 vector=+}}
#@ col={name=CatC type=U4 src={ min=-1} key=65536}
#@ col={name=CatC src={ min=-1 max=65534 vector=+}}
#@ col={name=CatD type=U4 src={ min=-1} key=65536}
#@ col={name=CatVA type=U4 src={ min=-1 max=0 vector=+} key=65536}
#@ col={name=CatVA src={ min=-1 max=65534 vector=+}}
#@ col={name=CatVB type=U4 src={ min=-1 max=0 vector=+} key=65536}
#@ col={name=CatVB src={ min=-1 max=34 vector=+}}
#@ col={name=CatVC type=U4 src={ min=-1 max=0 vector=+} key=65536}
#@ col={name=CatVC src={ min=-1 max=131070 vector=+}}
#@ col={name=CatVD type=U4 src={ min=-1 max=0 vector=+} key=65536}
#@ col={name=CatVVA type=U4 src={ min=-1 var=+} key=65536}
#@ col={name=CatVVA src={ min=-1 max=65534 vector=+}}
#@ col={name=CatVVB type=U4 src={ min=-1 var=+} key=65536}
#@ col={name=CatVVB src={ min=-1 var=+}}
#@ col={name=CatVVC type=U4 src={ min=-1 var=+} key=65536}
#@ col={name=CatVVC src={ min=-1 var=+}}
#@ col={name=CatVVD type=U4 src={ min=-1 var=+} key=65536}
#@ }
A 393284 2:CatA 65539:CatB 65558:CatC 131095:CatD
1 2 3 2 3 4 17369 589955 17369:1 65536:17369 65540:1 65545:1 65546:1 65547:1 65548:1 65550:1 65551:1 65554:1 65555:17369 82925:1 131092:17369 131093:45477 131094:61578 176572:1 192673:1 196631:45477 196632:61578 196635:1 196637:1 196638:1 196642:1 196643:1 196645:1 196648:1 196650:1 196653:1 196654:1 196655:1 196656:1 196661:1 196665:1 196667:1 196669:45477 196670:61578 242148:1 323785:1 327743:45477 327744:61578 327745:45477 327746:61578 327747:39452 367200:1 373225:1 389326:1 393284:45477 393285:61578 393286:39452 393289:1 393291:1 393292:1 393296:1 393297:1 393299:1 393302:1 393304:1 393307:1 393308:1 393309:1 393310:1 393315:1 393319:1 393321:1 393325:1 393328:1 393329:1 393331:1 393336:1 393337:1 393338:1 393341:45477 393342:61578 393343:39452 438821:1 520458:1 563868:1 589952:45477 589953:61578 589954:39452
4 4 5 3 4 5 20750 589955 20750:1 65536:20750 65540:1 65542:1 65546:1 65551:1 65552:1 65553:1 65555:20750 86306:1 131092:20750 131093:20750 131094:23709 151845:1 154804:1 196631:20750 196632:23709 196636:1 196638:1 196642:1 196647:1 196648:1 196649:1 196654:1 196656:1 196657:1 196658:1 196661:1 196664:1 196665:1 196666:1 196668:1 196669:20750 196670:23709 217421:1 285916:1 327743:20750 327744:23709 327745:47483 327746:61549 327747:22463 350211:1 375231:1 389297:1 393284:47483 393285:61549 393286:22463 393289:1 393291:1 393292:1 393293:1 393296:1 393298:1 393299:1 393300:1 393301:1 393303:1 393304:1 393307:1 393308:1 393309:1 393310:1 393316:1 393317:1 393319:1 393320:1 393322:1 393326:1 393328:1 393330:1 393331:1 393332:1 393333:1 393335:1 393336:1 393337:1 393338:1 393339:1 393340:1 393341:47483 393342:61549 393343:22463 440827:1 520429:1 546879:1 589952:47483 589953:61549 589954:22463
14 changes: 0 additions & 14 deletions test/BaselineOutput/SingleRelease/Categorical/featurized.tsv

This file was deleted.

13 changes: 0 additions & 13 deletions test/BaselineOutput/SingleRelease/CategoricalHash/featurized.tsv

This file was deleted.

29 changes: 22 additions & 7 deletions test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ public CategoricalHashTests(ITestOutputHelper output) : base(output)
private class TestClass
{
public string A;
public string B;
public string C;
[VectorType(2)]
public string[] B;
public string[] C;
}

private class TestMeta
Expand All @@ -45,17 +46,31 @@ private class TestMeta
[Fact]
public void CategoricalHashWorkout()
{
var data = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
var data = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };

var dataView = ML.Data.LoadFromEnumerable(data);
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
new OneHotHashEncodingEstimator.ColumnOptions("CatA", "A", OneHotEncodingEstimator.OutputKind.Bag),
new OneHotHashEncodingEstimator.ColumnOptions("CatA", "A", OneHotEncodingEstimator.OutputKind.Bag),
new OneHotHashEncodingEstimator.ColumnOptions("CatB", "A", OneHotEncodingEstimator.OutputKind.Binary),
new OneHotHashEncodingEstimator.ColumnOptions("CatC", "A", OneHotEncodingEstimator.OutputKind.Indicator),
new OneHotHashEncodingEstimator.ColumnOptions("CatD", "A", OneHotEncodingEstimator.OutputKind.Key),
new OneHotHashEncodingEstimator.ColumnOptions("CatVA", "B", OneHotEncodingEstimator.OutputKind.Bag),
new OneHotHashEncodingEstimator.ColumnOptions("CatVB", "B", OneHotEncodingEstimator.OutputKind.Binary),
new OneHotHashEncodingEstimator.ColumnOptions("CatVC", "B", OneHotEncodingEstimator.OutputKind.Indicator),
new OneHotHashEncodingEstimator.ColumnOptions("CatVD", "B", OneHotEncodingEstimator.OutputKind.Key),
new OneHotHashEncodingEstimator.ColumnOptions("CatVVA", "C", OneHotEncodingEstimator.OutputKind.Bag),
new OneHotHashEncodingEstimator.ColumnOptions("CatVVB", "C", OneHotEncodingEstimator.OutputKind.Binary),
new OneHotHashEncodingEstimator.ColumnOptions("CatVVC", "C", OneHotEncodingEstimator.OutputKind.Indicator),
new OneHotHashEncodingEstimator.ColumnOptions("CatVVD", "C", OneHotEncodingEstimator.OutputKind.Key),
});

TestEstimatorCore(pipe, dataView);
var outputPath = GetOutputPath("CategoricalHash", "oneHotHash.tsv");
var savedData = pipe.Fit(dataView).Transform(dataView);

using (var fs = File.Create(outputPath))
ML.Data.SaveAsText(savedData, fs, headerRow: true, keepHidden: true);
CheckEquality("CategoricalHash", "oneHotHash.tsv");
Done();
}

Expand All @@ -68,7 +83,7 @@ public void CategoricalHashStatic()
VectorString: ctx.LoadText(1, 4),
SingleVectorString: ctx.LoadText(1, 1)));
var data = reader.Load(dataPath);
var wrongCollection = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
var wrongCollection = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };

var invalidData = ML.Data.LoadFromEnumerable(wrongCollection);
var est = data.MakeNewEstimator().
Expand Down Expand Up @@ -211,12 +226,12 @@ public void TestCommandLine()
[Fact]
public void TestOldSavingAndLoading()
{
var data = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } };
var data = new[] { new TestClass() { A = "1", B = new[] { "2", "3" }, C = new[] { "2", "3", "4" } }, new TestClass() { A = "4", B = new[] { "4", "5" }, C = new[] { "3", "4", "5" } } };
var dataView = ML.Data.LoadFromEnumerable(data);
var pipe = ML.Transforms.Categorical.OneHotHashEncoding(new[]{
new OneHotHashEncodingEstimator.ColumnOptions("CatHashA", "A"),
new OneHotHashEncodingEstimator.ColumnOptions("CatHashB", "B"),
new OneHotHashEncodingEstimator.ColumnOptions("CatHashC", "C")
new OneHotHashEncodingEstimator.ColumnOptions("CatHashC", "C"),
});
var result = pipe.Fit(dataView).Transform(dataView);
var resultRoles = new RoleMappedData(result);
Expand Down
Loading