Skip to content

Commit 1b67219

Browse files
committed
Add QA sweepable
1 parent a823199 commit 1b67219

File tree

5 files changed

+89
-2
lines changed

5 files changed

+89
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
using Microsoft.ML.TorchSharp;
5+
using Microsoft.ML.TorchSharp.NasBert;
6+
using Microsoft.ML.TorchSharp.Roberta;
7+
8+
namespace Microsoft.ML.AutoML.CodeGen
9+
{
10+
internal class QuestionAnsweringMulti
11+
{
12+
public IEstimator<ITransformer> BuildFromOption(MLContext context, QATrainer.Options param)
13+
{
14+
return context.MulticlassClassification.Trainers.QuestionAnswer(
15+
contextColumnName: param.ContextColumnName,
16+
questionColumnName: param.QuestionColumnName,
17+
trainingAnswerColumnName: param.TrainingAnswerColumnName,
18+
answerIndexColumnName: param.AnswerIndexStartColumnName,
19+
predictedAnswerColumnName: param.PredictedAnswerColumnName,
20+
scoreColumnName: param.ScoreColumnName,
21+
batchSize: param.BatchSize,
22+
maxEpochs: param.MaxEpoch,
23+
topK: param.TopKAnswers,
24+
architecture: BertArchitecture.Roberta);
25+
}
26+
}
27+
}

src/Microsoft.ML.AutoML/CodeGen/estimator-schema.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@
7373
"ForecastBySsa",
7474
"TextClassifcation",
7575
"SentenceSimilarity",
76-
"ObjectDetection"
76+
"ObjectDetection",
77+
"QuestionAnswering"
7778
]
7879
},
7980
"nugetDependencies": {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
{
2+
"$schema": "./search-space-schema.json#",
3+
"name": "question_answering _option",
4+
"search_space": [
5+
{
6+
"name": "ContextColumnName",
7+
"type": "string",
8+
"default": "Context"
9+
},
10+
{
11+
"name": " QuestionColumnName",
12+
"type": "string",
13+
"default": "Question"
14+
},
15+
{
16+
"name": "TrainingAnswerColumnName",
17+
"type": "string",
18+
"default": "TrainingAnswer"
19+
},
20+
{
21+
"name": "AnswerIndexStartColumnName",
22+
"type": "string",
23+
"default": "AnswerStart"
24+
},
25+
{
26+
"name": "ScoreColumnName",
27+
"type": "string",
28+
"default": "Score"
29+
},
30+
{
31+
"name": "predictedAnswerColumnName",
32+
"type": "string",
33+
"default": "Answer"
34+
},
35+
{
36+
"name": "BatchSize",
37+
"type": "integer",
38+
"default": 4
39+
},
40+
{
41+
"name": "MaxEpochs",
42+
"type": "integer",
43+
"default": 10
44+
},
45+
{
46+
"name": "Architecture",
47+
"type": "bertArchitecture",
48+
"default": "BertArchitecture.Roberta"
49+
}
50+
]
51+
}

src/Microsoft.ML.AutoML/CodeGen/search-space-schema.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@
146146
"dnn_featurizer_image_option",
147147
"text_classification_option",
148148
"sentence_similarity_option",
149-
"object_detection_option"
149+
"object_detection_option",
150+
"question_answering _option"
150151
]
151152
},
152153
"option_name": {

src/Microsoft.ML.AutoML/CodeGen/trainer-estimators.json

+7
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,13 @@
532532
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
533533
"searchOption": "object_detection_option"
534534
},
535+
{
536+
"functionName": "QuestionAnswering",
537+
"estimatorTypes": [ "MultiClassification" ],
538+
"nugetDependencies": [ "Microsoft.ML", "Microsoft.ML.TorchSharp" ],
539+
"usingStatements": [ "Microsoft.ML", "Microsoft.ML.Trainers", "Microsoft.ML.TorchSharp" ],
540+
"searchOption": "question_answering_option"
541+
},
535542
{
536543
"functionName": "ForecastBySsa",
537544
"estimatorTypes": [ "Forecasting" ],

0 commit comments

Comments
 (0)