@@ -972,6 +972,18 @@ class RobertaPreTrainedModel < PreTrainedModel
972972 class RobertaModel < RobertaPreTrainedModel
973973 end
974974
975+ class RobertaForTokenClassification < RobertaPreTrainedModel
976+ def call ( model_inputs )
977+ TokenClassifierOutput . new ( *super ( model_inputs ) )
978+ end
979+ end
980+
981+ class RobertaForSequenceClassification < RobertaPreTrainedModel
982+ def call ( model_inputs )
983+ SequenceClassifierOutput . new ( *super ( model_inputs ) )
984+ end
985+ end
986+
975987 class RobertaForMaskedLM < RobertaPreTrainedModel
976988 def call ( model_inputs )
977989 MaskedLMOutput . new ( *super ( model_inputs ) )
@@ -1224,12 +1236,14 @@ class ClapModel < ClapPreTrainedModel
12241236 MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
12251237 "bert" => [ "BertForSequenceClassification" , BertForSequenceClassification ] ,
12261238 "distilbert" => [ "DistilBertForSequenceClassification" , DistilBertForSequenceClassification ] ,
1239+ "roberta" => [ "RobertaForSequenceClassification" , RobertaForSequenceClassification ] ,
12271240 "xlm-roberta" => [ "XLMRobertaForSequenceClassification" , XLMRobertaForSequenceClassification ] ,
12281241 "bart" => [ "BartForSequenceClassification" , BartForSequenceClassification ]
12291242 }
12301243
12311244 MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
1232- "bert" => [ "BertForTokenClassification" , BertForTokenClassification ]
1245+ "bert" => [ "BertForTokenClassification" , BertForTokenClassification ] ,
1246+ "roberta" => [ "RobertaForTokenClassification" , RobertaForTokenClassification ]
12331247 }
12341248
12351249 MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES = {
0 commit comments