Skip to content

Commit 6694a3e

Browse files
authored
Adds RoBERTa model (#11)
1 parent 0858607 commit 6694a3e

File tree

1 file changed

+15
-1
lines changed

1 file changed

+15
-1
lines changed

lib/informers/models.rb

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,18 @@ class RobertaPreTrainedModel < PreTrainedModel
972972
class RobertaModel < RobertaPreTrainedModel
973973
end
974974

975+
class RobertaForTokenClassification < RobertaPreTrainedModel
976+
def call(model_inputs)
977+
TokenClassifierOutput.new(*super(model_inputs))
978+
end
979+
end
980+
981+
class RobertaForSequenceClassification < RobertaPreTrainedModel
982+
def call(model_inputs)
983+
SequenceClassifierOutput.new(*super(model_inputs))
984+
end
985+
end
986+
975987
class RobertaForMaskedLM < RobertaPreTrainedModel
976988
def call(model_inputs)
977989
MaskedLMOutput.new(*super(model_inputs))
@@ -1224,12 +1236,14 @@ class ClapModel < ClapPreTrainedModel
12241236
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
12251237
"bert" => ["BertForSequenceClassification", BertForSequenceClassification],
12261238
"distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
1239+
"roberta" => ["RobertaForSequenceClassification", RobertaForSequenceClassification],
12271240
"xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
12281241
"bart" => ["BartForSequenceClassification", BartForSequenceClassification]
12291242
}
12301243

12311244
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
1232-
"bert" => ["BertForTokenClassification", BertForTokenClassification]
1245+
"bert" => ["BertForTokenClassification", BertForTokenClassification],
1246+
"roberta" => ["RobertaForTokenClassification", RobertaForTokenClassification]
12331247
}
12341248

12351249
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {

0 commit comments

Comments
 (0)