@@ -117,19 +117,23 @@ def place_model_on_device(self):
117
117
return False
118
118
119
119
120
- def load_task_model (task : str , model_path : str , config : Any ) -> Module :
120
+ def load_task_model (
121
+ task : str , model_path : str , config : Any , trust_remote_code : bool = False
122
+ ) -> Module :
121
123
if task == "masked-language-modeling" or task == "mlm" :
122
124
return SparseAutoModel .masked_language_modeling_from_pretrained (
123
125
model_name_or_path = model_path ,
124
126
config = config ,
125
127
model_type = "model" ,
128
+ trust_remote_code = trust_remote_code ,
126
129
)
127
130
128
131
if task == "question-answering" or task == "qa" :
129
132
return SparseAutoModel .question_answering_from_pretrained (
130
133
model_name_or_path = model_path ,
131
134
config = config ,
132
135
model_type = "model" ,
136
+ trust_remote_code = trust_remote_code ,
133
137
)
134
138
135
139
if (
@@ -142,20 +146,23 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
142
146
model_name_or_path = model_path ,
143
147
config = config ,
144
148
model_type = "model" ,
149
+ trust_remote_code = trust_remote_code ,
145
150
)
146
151
147
152
if task == "token-classification" or task == "ner" :
148
153
return SparseAutoModel .token_classification_from_pretrained (
149
154
model_name_or_path = model_path ,
150
155
config = config ,
151
156
model_type = "model" ,
157
+ trust_remote_code = trust_remote_code ,
152
158
)
153
159
154
160
if task == "text-generation" :
155
161
return SparseAutoModel .text_generation_from_pretrained (
156
162
model_name_or_path = model_path ,
157
163
config = config ,
158
164
model_type = "model" ,
165
+ trust_remote_code = trust_remote_code ,
159
166
)
160
167
161
168
raise ValueError (f"unrecognized task given of { task } " )
@@ -236,6 +243,7 @@ def export_transformer_to_onnx(
236
243
finetuning_task : Optional [str ] = None ,
237
244
onnx_file_name : str = MODEL_ONNX_NAME ,
238
245
num_export_samples : int = 0 ,
246
+ trust_remote_code : bool = False ,
239
247
data_args : Optional [Union [Dict [str , Any ], str ]] = None ,
240
248
one_shot : Optional [str ] = None ,
241
249
) -> str :
@@ -255,6 +263,7 @@ def export_transformer_to_onnx(
255
263
is model.onnx. Note that when loading a model directory to a deepsparse
256
264
pipeline, it will look only for 'model.onnx'
257
265
:param num_export_samples: number of samples (inputs/outputs) to export
266
+ :param trust_remote_code: set True to allow custom models in HF-transformers
258
267
:param data_args: additional args to instantiate a `DataTrainingArguments`
259
268
instance for exporting samples
260
269
:param one_shot: one shot recipe to be applied before exporting model
@@ -280,6 +289,7 @@ def export_transformer_to_onnx(
280
289
config_args = {"finetuning_task" : finetuning_task } if finetuning_task else {}
281
290
config = AutoConfig .from_pretrained (
282
291
model_path ,
292
+ trust_remote_code = trust_remote_code ,
283
293
** config_args ,
284
294
)
285
295
tokenizer = AutoTokenizer .from_pretrained (
@@ -288,7 +298,7 @@ def export_transformer_to_onnx(
288
298
if task == "text-generation" :
289
299
tokenizer .pad_token = tokenizer .eos_token
290
300
291
- model = load_task_model (task , model_path , config )
301
+ model = load_task_model (task , model_path , config , trust_remote_code )
292
302
_LOGGER .info (f"loaded model, config, and tokenizer from { model_path } " )
293
303
294
304
eval_dataset = None
@@ -547,6 +557,11 @@ def _parse_args() -> argparse.Namespace:
547
557
help = "local path or SparseZoo stub to a recipe that should be applied "
548
558
"in a one-shot manner before exporting" ,
549
559
)
560
+ parser .add_argument (
561
+ "--trust_remote_code" ,
562
+ action = "store_true" ,
563
+ help = ("Set flag to allow custom models in HF-transformers" ),
564
+ )
550
565
551
566
return parser .parse_args ()
552
567
@@ -559,6 +574,7 @@ def export(
559
574
finetuning_task : str ,
560
575
onnx_file_name : str ,
561
576
num_export_samples : int = 0 ,
577
+ trust_remote_code : bool = False ,
562
578
data_args : Optional [str ] = None ,
563
579
one_shot : Optional [str ] = None ,
564
580
):
@@ -570,6 +586,7 @@ def export(
570
586
finetuning_task = finetuning_task ,
571
587
onnx_file_name = onnx_file_name ,
572
588
num_export_samples = num_export_samples ,
589
+ trust_remote_code = trust_remote_code ,
573
590
data_args = data_args ,
574
591
one_shot = one_shot ,
575
592
)
@@ -593,6 +610,7 @@ def main():
593
610
finetuning_task = args .finetuning_task ,
594
611
onnx_file_name = args .onnx_file_name ,
595
612
num_export_samples = args .num_export_samples ,
613
+ trust_remote_code = args .trust_remote_code ,
596
614
data_args = args .data_args ,
597
615
one_shot = args .one_shot ,
598
616
)
0 commit comments