Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 4ec5133

Browse files
eldarkurtic
and natuan authored
Fix export of all quantized transformer models (#1654)
* Expose trust_remote_code flag for HF-transformers
* Reload big model with multiple state dict files
* Add description for reload func
* handle new HF interface

---------

Co-authored-by: Tuan Nguyen <[email protected]>
1 parent 718c7f4 commit 4ec5133

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

src/sparseml/transformers/export.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,19 +117,23 @@ def place_model_on_device(self):
117117
return False
118118

119119

120-
def load_task_model(task: str, model_path: str, config: Any) -> Module:
120+
def load_task_model(
121+
task: str, model_path: str, config: Any, trust_remote_code: bool = False
122+
) -> Module:
121123
if task == "masked-language-modeling" or task == "mlm":
122124
return SparseAutoModel.masked_language_modeling_from_pretrained(
123125
model_name_or_path=model_path,
124126
config=config,
125127
model_type="model",
128+
trust_remote_code=trust_remote_code,
126129
)
127130

128131
if task == "question-answering" or task == "qa":
129132
return SparseAutoModel.question_answering_from_pretrained(
130133
model_name_or_path=model_path,
131134
config=config,
132135
model_type="model",
136+
trust_remote_code=trust_remote_code,
133137
)
134138

135139
if (
@@ -142,20 +146,23 @@ def load_task_model(task: str, model_path: str, config: Any) -> Module:
142146
model_name_or_path=model_path,
143147
config=config,
144148
model_type="model",
149+
trust_remote_code=trust_remote_code,
145150
)
146151

147152
if task == "token-classification" or task == "ner":
148153
return SparseAutoModel.token_classification_from_pretrained(
149154
model_name_or_path=model_path,
150155
config=config,
151156
model_type="model",
157+
trust_remote_code=trust_remote_code,
152158
)
153159

154160
if task == "text-generation":
155161
return SparseAutoModel.text_generation_from_pretrained(
156162
model_name_or_path=model_path,
157163
config=config,
158164
model_type="model",
165+
trust_remote_code=trust_remote_code,
159166
)
160167

161168
raise ValueError(f"unrecognized task given of {task}")
@@ -236,6 +243,7 @@ def export_transformer_to_onnx(
236243
finetuning_task: Optional[str] = None,
237244
onnx_file_name: str = MODEL_ONNX_NAME,
238245
num_export_samples: int = 0,
246+
trust_remote_code: bool = False,
239247
data_args: Optional[Union[Dict[str, Any], str]] = None,
240248
one_shot: Optional[str] = None,
241249
) -> str:
@@ -255,6 +263,7 @@ def export_transformer_to_onnx(
255263
is model.onnx. Note that when loading a model directory to a deepsparse
256264
pipeline, it will look only for 'model.onnx'
257265
:param num_export_samples: number of samples (inputs/outputs) to export
266+
:param trust_remote_code: set True to allow custom models in HF-transformers
258267
:param data_args: additional args to instantiate a `DataTrainingArguments`
259268
instance for exporting samples
260269
:param one_shot: one shot recipe to be applied before exporting model
@@ -280,6 +289,7 @@ def export_transformer_to_onnx(
280289
config_args = {"finetuning_task": finetuning_task} if finetuning_task else {}
281290
config = AutoConfig.from_pretrained(
282291
model_path,
292+
trust_remote_code=trust_remote_code,
283293
**config_args,
284294
)
285295
tokenizer = AutoTokenizer.from_pretrained(
@@ -288,7 +298,7 @@ def export_transformer_to_onnx(
288298
if task == "text-generation":
289299
tokenizer.pad_token = tokenizer.eos_token
290300

291-
model = load_task_model(task, model_path, config)
301+
model = load_task_model(task, model_path, config, trust_remote_code)
292302
_LOGGER.info(f"loaded model, config, and tokenizer from {model_path}")
293303

294304
eval_dataset = None
@@ -547,6 +557,11 @@ def _parse_args() -> argparse.Namespace:
547557
help="local path or SparseZoo stub to a recipe that should be applied "
548558
"in a one-shot manner before exporting",
549559
)
560+
parser.add_argument(
561+
"--trust_remote_code",
562+
action="store_true",
563+
help=("Set flag to allow custom models in HF-transformers"),
564+
)
550565

551566
return parser.parse_args()
552567

@@ -559,6 +574,7 @@ def export(
559574
finetuning_task: str,
560575
onnx_file_name: str,
561576
num_export_samples: int = 0,
577+
trust_remote_code: bool = False,
562578
data_args: Optional[str] = None,
563579
one_shot: Optional[str] = None,
564580
):
@@ -570,6 +586,7 @@ def export(
570586
finetuning_task=finetuning_task,
571587
onnx_file_name=onnx_file_name,
572588
num_export_samples=num_export_samples,
589+
trust_remote_code=trust_remote_code,
573590
data_args=data_args,
574591
one_shot=one_shot,
575592
)
@@ -593,6 +610,7 @@ def main():
593610
finetuning_task=args.finetuning_task,
594611
onnx_file_name=args.onnx_file_name,
595612
num_export_samples=args.num_export_samples,
613+
trust_remote_code=args.trust_remote_code,
596614
data_args=args.data_args,
597615
one_shot=args.one_shot,
598616
)

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,11 +683,11 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
683683
dd = torch.load(os.path.join(load_path, f), map_location="cpu")
684684
loaded_state_dict.update(dd)
685685

686-
_, missing, unexpected, _, _ = self.model._load_pretrained_model(
686+
_, missing, unexpected, mismatched, _, _ = self.model._load_pretrained_model(
687687
model=self.model,
688688
state_dict=loaded_state_dict,
689689
loaded_keys=list(loaded_state_dict.keys()),
690-
resolved_archive_file=[],
690+
resolved_archive_file=None,
691691
pretrained_model_name_or_path=load_path,
692692
_fast_init=False,
693693
)
@@ -704,6 +704,12 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
704704
f"{unexpected}"
705705
)
706706

707+
if mismatched:
708+
_LOGGER.warning(
709+
f"Mismatched keys found when reloading model state for SparseML recipe:"
710+
f"{mismatched}"
711+
)
712+
707713
total_loaded = len(current_state_dict) - (len(missing) if len(missing) else 0)
708714
_LOGGER.info(
709715
f"Reloaded {total_loaded} model params for SparseML Recipe from {load_path}"

0 commit comments

Comments (0)