diff --git a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json index 5d28914734..0ccb7b1fcf 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json +++ b/src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json @@ -46,6 +46,26 @@ "BertArchitecture.Roberta" ] }, + "dataKind": { + "type": "string", + "enum": [ + "DataKind.Int16", + "DataKind.SByte", + "DataKind.Byte", + "DataKind.UInt16", + "DataKind.Int32", + "DataKind.UInt32", + "DataKind.Int64", + "DataKind.UInt64", + "DataKind.Single", + "DataKind.Double", + "DataKind.String", + "DataKind.Boolean", + "DataKind.TimeSpan", + "DataKind.DateTime", + "DataKind.DateTimeOffset" + ] + }, "bertArchitectureArray": { "type": "array", "items": { @@ -217,7 +237,8 @@ "TrainingAnswerColumnName", "AnswerIndexStartColumnName", "predictedAnswerColumnName", - "TopKAnswers" + "TopKAnswers", + "TargetType" ] }, "option_type": { @@ -235,7 +256,8 @@ "anchor", "dnnModelFactory", "bertArchitecture", - "imageClassificationArchType" + "imageClassificationArchType", + "dataKind" ] } }, diff --git a/src/Microsoft.ML.AutoML/CodeGen/type_converter_search_space.json b/src/Microsoft.ML.AutoML/CodeGen/type_converter_search_space.json index d91293b857..38d860e9c4 100644 --- a/src/Microsoft.ML.AutoML/CodeGen/type_converter_search_space.json +++ b/src/Microsoft.ML.AutoML/CodeGen/type_converter_search_space.json @@ -9,6 +9,11 @@ { "name": "InputColumnNames", "type": "strings" + }, + { + "name": "TargetType", + "type": "dataKind", + "default": "DataKind.Single" } ] } diff --git a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/TypeConvert.cs b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/TypeConvert.cs index f5e567baa0..795064a2c9 100644 --- a/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/TypeConvert.cs +++ b/src/Microsoft.ML.AutoML/SweepableEstimator/Estimators/TypeConvert.cs @@ -9,7 +9,7 @@ internal partial class ConvertType public override IEstimator BuildFromOption(MLContext context, ConvertTypeOption param) { var inputOutputPairs = AutoMlUtils.CreateInputOutputColumnPairsFromStrings(param.InputColumnNames, param.OutputColumnNames); - return context.Transforms.Conversion.ConvertType(inputOutputPairs); + return context.Transforms.Conversion.ConvertType(inputOutputPairs, param.TargetType); } } } diff --git a/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_creditapproval_test.approved.txt b/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_creditapproval_test.approved.txt index c2bea3ba98..37287a9664 100644 --- a/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_creditapproval_test.approved.txt +++ b/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_creditapproval_test.approved.txt @@ -21,7 +21,8 @@ ], "InputColumnNames": [ "BooleanFeatures" - ] + ], + "TargetType": "Single" } }, "e2": { diff --git a/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_newspaperchurn_test.approved.txt b/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_newspaperchurn_test.approved.txt index fe234e918c..d85e26c72d 100644 --- a/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_newspaperchurn_test.approved.txt +++ b/test/Microsoft.ML.AutoML.Tests/ApprovalTests/AutoFeaturizerTests.AutoFeaturizer_newspaperchurn_test.approved.txt @@ -21,7 +21,8 @@ ], "InputColumnNames": [ "dummy for Children" - ] + ], + "TargetType": "Single" } }, "e2": { diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs b/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs index 5480d683fa..379d9e618e 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/SearchSpaceGenerator.cs @@ -56,6 +56,7 @@ public void Execute(GeneratorExecutionContext context) "dnnModelFactory" => "string", "bertArchitecture" => "BertArchitecture", "imageClassificationArchType" => "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture", + "dataKind" => "Microsoft.ML.Data.DataKind", _ => throw new ArgumentException("unknown type"), }; @@ -74,6 +75,7 @@ public void Execute(GeneratorExecutionContext context) (_, "ColorsOrder") => defaultToken.GetValue(), (_, "BertArchitecture") => defaultToken.GetValue(), (_, "Microsoft.ML.Vision.ImageClassificationTrainer.Architecture") => defaultToken.GetValue(), + (_, "Microsoft.ML.Data.DataKind") => defaultToken.GetValue(), (_, _) => throw new ArgumentException("unknown"), }; diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs index be05d616f2..409937611b 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.cs @@ -34,6 +34,7 @@ public virtual string TransformText() using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor; using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture; using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture; +using DataKind = Microsoft.ML.Data.DataKind; #nullable enable namespace "); @@ -92,7 +93,7 @@ internal class SearchSpaceBase /// /// The string builder that generation-time code is using to assemble generated output /// - protected System.Text.StringBuilder GenerationEnvironment + public System.Text.StringBuilder GenerationEnvironment { get { diff --git a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt index b64990f3f2..bbfad4ecb5 100644 --- a/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt +++ b/tools-local/Microsoft.ML.AutoML.SourceGenerator/Template/SearchSpace.tt @@ -12,6 +12,7 @@ using ResizingKind = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Resizi using Anchor = Microsoft.ML.Transforms.Image.ImageResizingEstimator.Anchor; using BertArchitecture = Microsoft.ML.TorchSharp.NasBert.BertArchitecture; using static Microsoft.ML.Vision.ImageClassificationTrainer.Architecture; +using DataKind = Microsoft.ML.Data.DataKind; #nullable enable namespace <#=NameSpace#>