Hi there,

the example given by @MLRadfys in #798 for merging a converted Roberta tokenizer with an ONNX model was very helpful. However, when I adapted it for BERT, it fails on this line of code:
```python
outputs = session.run(None, input_feed)
```
with this error:

```
2025-05-08 15:56:03.400921608 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Gather node. Name:'/bert/Gather' Status Message: indices element out of data bounds, idx=1 must be within the inclusive range [-1,0]
```
Any suggestions or feedback on where the code is going wrong would be much appreciated.
Output from the code before the error was raised:

```
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Tokenizer Model Inputs: ['text']
Tokenizer Model Outputs: ['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping']
Model Inputs: ['input_ids', 'attention_mask', 'token_type_ids']
Model Outputs: ['logits']
Hugging Face input feed ['A test text!']
Hugging Face logits: [[-0.44599634 0.35376194]]
```
Code to convert the BERT tokenizer to ONNX and merge it with the ONNX model:
```python
# Including the tokenizer in the onnx model / basic usage of the onnxruntime-extensions:
# https://github.com/microsoft/onnxruntime-extensions/issues/798
import torch
import onnx
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer, BertForSequenceClassification
from onnxruntime_extensions import gen_processing_models, get_library_path

# --------------------------------
# Configuration & HF setup
# --------------------------------
# Load the Hugging Face BERT tokenizer and model
input_text = "A test text!"
model_type = "bert-base-cased"
model = BertForSequenceClassification.from_pretrained(model_type)
tokenizer = AutoTokenizer.from_pretrained(model_type)

# Get HF model logits
hf_inputs = tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
    truncation=True,
)
with torch.no_grad():
    outputs = model(**hf_inputs)
logits_tensor = outputs.logits
hf_logits = logits_tensor.cpu().detach().numpy()

# --------------------------------
# Export tokenizer to ONNX
# --------------------------------
onnx_tokenizer_path = "bert_tokenizer.onnx"

# Export the tokenizer to ONNX using gen_processing_models
tokenizer_onnx_model = gen_processing_models(tokenizer, pre_kwargs={})[0]

# Save the tokenizer ONNX model
with open(onnx_tokenizer_path, "wb") as f:
    f.write(tokenizer_onnx_model.SerializeToString())

# --------------------------------
# Export BERT model to ONNX
# --------------------------------
# Export the Hugging Face BERT model to ONNX
onnx_model_path = "bert_model.onnx"
dummy_input = tokenizer("This is a dummy input", return_tensors="pt")
seq_len = dummy_input["input_ids"].shape[1]
torch.onnx.export(
    model,
    (dummy_input["input_ids"], dummy_input["attention_mask"], dummy_input["token_type_ids"]),
    onnx_model_path,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "token_type_ids": {0: "batch_size", 1: "seq_len"},
        "logits": {0: "batch_size"},
    },
)

# Step 4: Merge the tokenizer and model ONNX files into one
onnx_combined_model_path = "combined_bert_model_tokenizer.onnx"

# Load the tokenizer and model ONNX files
tokenizer_onnx_model = onnx.load(onnx_tokenizer_path)
model_onnx_model = onnx.load(onnx_model_path)

# Force the IR versions to be the same
model_onnx_model.ir_version = tokenizer_onnx_model.ir_version

# Inspect the ONNX models to find the correct input/output names
print("Tokenizer Model Inputs:", [node.name for node in tokenizer_onnx_model.graph.input])
print("Tokenizer Model Outputs:", [node.name for node in tokenizer_onnx_model.graph.output])
print("Model Inputs:", [node.name for node in model_onnx_model.graph.input])
print("Model Outputs:", [node.name for node in model_onnx_model.graph.output])

# Merge the tokenizer and model ONNX files
combined_model = onnx.compose.merge_models(
    tokenizer_onnx_model,
    model_onnx_model,
    io_map=[
        ("input_ids", "input_ids"),
        ("attention_mask", "attention_mask"),
        ("token_type_ids", "token_type_ids"),
    ],
)

# Save the combined model
onnx.save(combined_model, onnx_combined_model_path)

# Step 5: Test the combined ONNX model using an inference session with ONNX Runtime Extensions
# Initialize ONNX Runtime SessionOptions and load the custom ops library
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())

# Initialize the ONNX Runtime inference session with extensions
session = ort.InferenceSession(onnx_combined_model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])

# Prepare the input text
input_feed = {"text": np.asarray([input_text])}  # "text" is the input expected by the tokenizer graph
print("Hugging Face input feed", np.asarray([input_text]))
print("Hugging Face logits:", hf_logits)
# Run the model
outputs = session.run(None, input_feed)

# Print the outputs
# (outputs[1] is assumed to hold the logits; outputs[0] is the leftover tokenizer output)
combined_logits = outputs[1]
print("logits:", outputs[1][0])

# ---------------------------
# Verify exact parity
# ---------------------------
print("Hugging Face logits:", hf_logits.flatten())
print("Combined ONNX logits:", combined_logits.flatten())
print("Match exactly?", np.allclose(hf_logits, combined_logits, atol=1e-6))
```
Roberta working example, originally posted by @MLRadfys in #798:

```python
import torch
from onnxruntime_extensions import gen_processing_models
from onnxruntime_extensions import get_library_path
import onnx
import onnxruntime as ort
import numpy as np
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Step 1: Load the Hugging Face Roberta tokenizer and model
input_text = "A test text!"
model_type = "roberta-base"
model = RobertaForSequenceClassification.from_pretrained(model_type)
tokenizer = RobertaTokenizer.from_pretrained(model_type)

# Step 2: Export the tokenizer to ONNX using gen_processing_models
onnx_tokenizer_path = "tokenizer.onnx"

# Generate the tokenizer ONNX model
tokenizer_onnx_model = gen_processing_models(tokenizer, pre_kwargs={})[0]

# Save the tokenizer ONNX model
with open(onnx_tokenizer_path, "wb") as f:
    f.write(tokenizer_onnx_model.SerializeToString())

# Step 3: Export the Hugging Face Roberta model to ONNX
onnx_model_path = "model.onnx"
dummy_input = tokenizer("This is a dummy input", return_tensors="pt")

# Export the model to ONNX
torch.onnx.export(
    model,                                                       # model to be exported
    (dummy_input["input_ids"], dummy_input["attention_mask"]),   # model input (dummy input)
    onnx_model_path,                                             # where to save the ONNX model
    input_names=["input_ids", "attention_mask_input"],           # input tensor names
    output_names=["logits"],                                     # output tensor names
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence_length"},    # dynamic axes
        "logits": {0: "batch_size"},
    },
)
# Step 4: Merge the tokenizer and model ONNX files into one
onnx_combined_model_path = "combined_model_tokenizer.onnx"

# Load the tokenizer and model ONNX files
tokenizer_onnx_model = onnx.load(onnx_tokenizer_path)
model_onnx_model = onnx.load(onnx_model_path)

# Inspect the ONNX models to find the correct input/output names
print("Tokenizer Model Inputs:", [node.name for node in tokenizer_onnx_model.graph.input])
print("Tokenizer Model Outputs:", [node.name for node in tokenizer_onnx_model.graph.output])
print("Model Inputs:", [node.name for node in model_onnx_model.graph.input])
print("Model Outputs:", [node.name for node in model_onnx_model.graph.output])

# Merge the tokenizer and model ONNX files
combined_model = onnx.compose.merge_models(
    tokenizer_onnx_model,
    model_onnx_model,
    io_map=[("input_ids", "input_ids"), ("attention_mask", "attention_mask_input")],
)

# Save the combined model
onnx.save(combined_model, onnx_combined_model_path)

# Step 5: Test the combined ONNX model using an inference session with ONNX Runtime Extensions
# Initialize ONNX Runtime SessionOptions and load the custom ops library
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())

# Initialize the ONNX Runtime inference session with extensions
session = ort.InferenceSession(onnx_combined_model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])

# Prepare dummy input text
input_feed = {"input_text": np.asarray([input_text])}  # Assuming "input_text" is the input expected by the tokenizer

# Run the model
outputs = session.run(None, input_feed)

# Print the outputs
print("logits:", outputs[1][0])
```
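For completeness, the same parity check used in the BERT script could be added to the Roberta example; here is a hedged sketch (assuming the variables from the example above are in scope, and that the second graph output holds the logits, as in the print above):

```python
# Sketch: compare the merged Roberta model's logits with the Hugging Face model.
# Assumes model, tokenizer, input_text and outputs from the Roberta example above.
import numpy as np
import torch

with torch.no_grad():
    hf_logits = model(**tokenizer(input_text, return_tensors="pt")).logits.cpu().numpy()

combined_logits = outputs[1]  # assumed to be the logits output of the merged graph
print("Hugging Face logits:", hf_logits.flatten())
print("Combined ONNX logits:", np.asarray(combined_logits).flatten())
print("Match?", np.allclose(hf_logits, combined_logits, atol=1e-6))
```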