Skip to content

Commit 384a283

Browse files
feat: add support for ModernBERT (#15)
1 parent 332d5f8 commit 384a283

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

lib/informers/models.rb

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,30 @@ def call(model_inputs)
809809
end
810810
end
811811

812+
class ModernBertPreTrainedModel < PreTrainedModel
813+
end
814+
815+
class ModernBertModel < ModernBertPreTrainedModel
816+
end
817+
818+
class ModernBertForMaskedLM < ModernBertPreTrainedModel
819+
def call(model_inputs)
820+
MaskedLMOutput.new(*super(model_inputs))
821+
end
822+
end
823+
824+
class ModernBertForSequenceClassification < ModernBertPreTrainedModel
825+
def call(model_inputs)
826+
SequenceClassifierOutput.new(*super(model_inputs))
827+
end
828+
end
829+
830+
class ModernBertForTokenClassification < ModernBertPreTrainedModel
831+
def call(model_inputs)
832+
TokenClassifierOutput.new(*super(model_inputs))
833+
end
834+
end
835+
812836
class NomicBertPreTrainedModel < PreTrainedModel
813837
end
814838

@@ -1198,6 +1222,7 @@ class ClapModel < ClapPreTrainedModel
11981222

11991223
MODEL_MAPPING_NAMES_ENCODER_ONLY = {
12001224
"bert" => ["BertModel", BertModel],
1225+
"modernbert" => ["ModernBertModel", ModernBertModel],
12011226
"nomic_bert" => ["NomicBertModel", NomicBertModel],
12021227
"electra" => ["ElectraModel", ElectraModel],
12031228
"convbert" => ["ConvBertModel", ConvBertModel],
@@ -1235,6 +1260,7 @@ class ClapModel < ClapPreTrainedModel
12351260

12361261
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = {
12371262
"bert" => ["BertForSequenceClassification", BertForSequenceClassification],
1263+
"modernbert" => ["ModernBertForSequenceClassification", ModernBertForSequenceClassification],
12381264
"distilbert" => ["DistilBertForSequenceClassification", DistilBertForSequenceClassification],
12391265
"roberta" => ["RobertaForSequenceClassification", RobertaForSequenceClassification],
12401266
"xlm-roberta" => ["XLMRobertaForSequenceClassification", XLMRobertaForSequenceClassification],
@@ -1243,6 +1269,7 @@ class ClapModel < ClapPreTrainedModel
12431269

12441270
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = {
12451271
"bert" => ["BertForTokenClassification", BertForTokenClassification],
1272+
"modernbert" => ["ModernBertForTokenClassification", ModernBertForTokenClassification],
12461273
"roberta" => ["RobertaForTokenClassification", RobertaForTokenClassification]
12471274
}
12481275

@@ -1259,6 +1286,7 @@ class ClapModel < ClapPreTrainedModel
12591286

12601287
MODEL_FOR_MASKED_LM_MAPPING_NAMES = {
12611288
"bert" => ["BertForMaskedLM", BertForMaskedLM],
1289+
"modernbert" => ["ModernBertForMaskedLM", ModernBertForMaskedLM],
12621290
"roberta" => ["RobertaForMaskedLM", RobertaForMaskedLM]
12631291
}
12641292

0 commit comments

Comments
 (0)