Issue with BERT tokenizer -> ONNX and merge_models: Non-zero status code returned while running Gather node. Name:'/bert/Gather' #953


Open
jilldaly opened this issue May 12, 2025 · 0 comments


Hi there,

The example given by @MLRadfys for merging a converted RoBERTa tokenizer with an ONNX model was very helpful. However, when I adapted it for BERT, it fails on this line of code:

outputs = session.run(None, input_feed)

with this error:

2025-05-08 15:56:03.400921608 [E:onnxruntime:, sequential_executor.cc:516 ExecuteKernel] Non-zero status code returned while running Gather node. Name:'/bert/Gather' Status Message: indices element out of data bounds, idx=1 must be within the inclusive range [-1,0]

Any suggestions / feedback on where the code is going wrong would be much appreciated.

Output from the code before the error was raised:

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Tokenizer Model Inputs: ['text']
Tokenizer Model Outputs: ['input_ids', 'token_type_ids', 'attention_mask', 'offset_mapping']
Model Inputs: ['input_ids', 'attention_mask', 'token_type_ids']
Model Outputs: ['logits']
 Hugging Face input feed ['A test text!']
 Hugging Face logits: [[-0.44599634  0.35376194]]
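For context, the message itself just means that a Gather node received an index of 1 while the axis it gathers from only has size 1. A minimal standalone sketch (not part of my pipeline, just an illustration of the error class) that reproduces the same message:

import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort

# A Gather over a size-1 axis, fed an index of 1, fails with the same message
node = helper.make_node("Gather", ["data", "indices"], ["out"], axis=0)
graph = helper.make_graph(
    [node],
    "gather_oob_repro",
    [helper.make_tensor_value_info("data", TensorProto.FLOAT, [1]),
     helper.make_tensor_value_info("indices", TensorProto.INT64, [1])],
    [helper.make_tensor_value_info("out", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
model.ir_version = 8  # keep the IR version modest so any recent onnxruntime accepts it

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
# Raises: "indices element out of data bounds, idx=1 must be within the inclusive range [-1,0]"
sess.run(None, {"data": np.zeros(1, np.float32), "indices": np.array([1], np.int64)})

So presumably one of the tensors feeding '/bert/Gather' ends up with an unexpected size or rank once the two graphs are stitched together.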

Code to convert the BERT tokenizer to ONNX and merge it with the ONNX model:

# Merging a tokenizer into an ONNX model / basic usage of onnxruntime-extensions:
# https://github.com/microsoft/onnxruntime-extensions/issues/798

import torch
import onnx
import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer, BertForSequenceClassification
from onnxruntime_extensions import gen_processing_models, get_library_path

# --------------------------------
#     Configuration & HF Setup   
# --------------------------------
# Load the Huggingface BERT tokenizer and model
input_text           = "A test text!"
model_type           = "bert-base-cased"
model                = BertForSequenceClassification.from_pretrained(model_type)
tokenizer            = AutoTokenizer.from_pretrained(model_type)

# Get HF model logits
hf_inputs = tokenizer(
    input_text,
    return_tensors="pt",
    padding=True,
    truncation=True,
)

with torch.no_grad():
    outputs = model(**hf_inputs)

logits_tensor = outputs.logits
hf_logits = logits_tensor.cpu().detach().numpy()

# --------------------------------
#     Export Tokenizer to ONNX   
# --------------------------------
onnx_tokenizer_path  = "bert_tokenizer.onnx"

# Export the tokenizer to ONNX using gen_processing_models
tokenizer_onnx_model = gen_processing_models(tokenizer, pre_kwargs={})[0]

# Save the tokenizer ONNX model
with open(onnx_tokenizer_path, "wb") as f:
    f.write(tokenizer_onnx_model.SerializeToString())


# --------------------------------
#    Export BERT model to ONNX   
# --------------------------------
# Export the Huggingface BERT model to ONNX
onnx_model_path     = "bert_model.onnx"
dummy_input         = tokenizer("This is a dummy input", return_tensors="pt")
seq_len             = dummy_input["input_ids"].shape[1] 

torch.onnx.export(
    model,
    (dummy_input['input_ids'],dummy_input["attention_mask"],dummy_input["token_type_ids"]),
    onnx_model_path,
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids":      {0: "batch_size", 1: "seq_len"},
        "attention_mask": {0: "batch_size", 1: "seq_len"},
        "token_type_ids": {0: "batch_size", 1: "seq_len"},
        "logits":         {0: "batch_size"}
    },
)
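
# Optional sanity check: validate the exported ONNX model before merging it
onnx.checker.check_model(onnx_model_path)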

# --------------------------------
#  Merge tokenizer & model ONNX
# --------------------------------
onnx_combined_model_path  = "combined_bert_model_tokenizer.onnx"

# Load the tokenizer and model ONNX files
tokenizer_onnx_model     = onnx.load(onnx_tokenizer_path)
model_onnx_model         = onnx.load(onnx_model_path)

# Force versions to be the same
model_onnx_model.ir_version = tokenizer_onnx_model.ir_version

# Inspect the ONNX models to find the correct input/output names
print("Tokenizer Model Inputs:", [node.name for node in tokenizer_onnx_model.graph.input])
print("Tokenizer Model Outputs:", [node.name for node in tokenizer_onnx_model.graph.output])
print("Model Inputs:", [node.name for node in model_onnx_model.graph.input])
print("Model Outputs:", [node.name for node in model_onnx_model.graph.output])

# Merge the tokenizer and model ONNX files
combined_model = onnx.compose.merge_models(
    tokenizer_onnx_model,
    model_onnx_model,
    io_map=[
        ('input_ids', 'input_ids'), 
        ('attention_mask', 'attention_mask'),
        ('token_type_ids', 'token_type_ids'),
    ]
)

# Save the combined model
onnx.save(combined_model, onnx_combined_model_path)

# --------------------------------
#   Test the combined ONNX model
# --------------------------------
# Initialize ONNX Runtime SessionOptions and load the custom ops library
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())

# Initialize ONNX Runtime Inference session with Extensions
session = ort.InferenceSession(onnx_combined_model_path, sess_options=sess_options, providers=['CPUExecutionProvider'])
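
# Optional sanity check: ask the session which input name the merged graph
# actually expects, instead of hard-coding "text" in the input feed below
print("Combined model inputs:", [i.name for i in session.get_inputs()])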

# Prepare dummy input text
input_feed = {"text": np.asarray([input_text])}  # Assuming "input_text" is the input expected by the tokenizer

print(" Hugging Face input feed", np.asarray([input_text]))
print(" Hugging Face logits:", hf_logits)

# Run the combined model (this is the call that raises the Gather error)
outputs = session.run(None, input_feed)

# Select the logits by name, since the merged graph also exposes tokenizer
# outputs that the model does not consume (e.g. offset_mapping)
output_names    = [o.name for o in session.get_outputs()]
combined_logits = outputs[output_names.index("logits")]
print("logits:", combined_logits[0])


# ---------------------------
#     Verify exact parity
# ---------------------------
print(" Hugging Face logits:", hf_logits.flatten())
print("Combined ONNX logits:", combined_logits.flatten())
print("       Match exactly?", np.allclose(hf_logits, combined_logits, atol=1e-6))

RoBERTa working example posted in #798:

import torch
from onnxruntime_extensions import gen_processing_models
from onnxruntime_extensions import get_library_path
import onnx
import onnxruntime as ort
import numpy as np
from transformers import RobertaForSequenceClassification, RobertaTokenizer

# Step 1: Load the Huggingface Roberta tokenizer and model
input_text = "A test text!"
model_type = "roberta-base"
model = RobertaForSequenceClassification.from_pretrained(model_type)
tokenizer = RobertaTokenizer.from_pretrained(model_type)

# Step 2: Export the tokenizer to ONNX using gen_processing_models
onnx_tokenizer_path = "tokenizer.onnx"

# Generate the tokenizer ONNX model
tokenizer_onnx_model = gen_processing_models(tokenizer, pre_kwargs={})[0]

# Save the tokenizer ONNX model
with open(onnx_tokenizer_path, "wb") as f:
    f.write(tokenizer_onnx_model.SerializeToString())

# Step 3: Export the Huggingface Roberta model to ONNX
onnx_model_path = "model.onnx"
dummy_input = tokenizer("This is a dummy input", return_tensors="pt")


# 5. Export the model to ONNX
torch.onnx.export(
    model,                                                              # model to be exported
    (dummy_input['input_ids'],dummy_input["attention_mask"]),           # model input (dummy input)
    onnx_model_path,                                                    # where to save the ONNX model
    input_names=["input_ids", "attention_mask_input"],                  # input tensor name
    output_names=["logits"],                                            # output tensor names
    dynamic_axes={"input_ids": {0: "batch_size", 1: "sequence_length"}, # dynamic axes
    "logits": {0: "batch_size"}
    }
)

# Step 4: Merge the tokenizer and model ONNX files into one
onnx_combined_model_path = "combined_model_tokenizer.onnx"

# Load the tokenizer and model ONNX files
tokenizer_onnx_model = onnx.load(onnx_tokenizer_path)
model_onnx_model = onnx.load(onnx_model_path)

# Inspect the ONNX models to find the correct input/output names
print("Tokenizer Model Inputs:", [node.name for node in tokenizer_onnx_model.graph.input])
print("Tokenizer Model Outputs:", [node.name for node in tokenizer_onnx_model.graph.output])
print("Model Inputs:", [node.name for node in model_onnx_model.graph.input])
print("Model Outputs:", [node.name for node in model_onnx_model.graph.output])

# Merge the tokenizer and model ONNX files
combined_model = onnx.compose.merge_models(
    tokenizer_onnx_model,
    model_onnx_model,
    io_map=[('input_ids', 'input_ids'), ('attention_mask', 'attention_mask_input')]
)

# Save the combined model
onnx.save(combined_model, onnx_combined_model_path)

# Step 5: Test the combined ONNX model using an Inference session with ONNX Runtime Extensions
# Initialize ONNX Runtime SessionOptions and load custom ops library
sess_options = ort.SessionOptions()
sess_options.register_custom_ops_library(get_library_path())

# Initialize ONNX Runtime Inference session with Extensions
session = ort.InferenceSession(onnx_combined_model_path, sess_options=sess_options, providers=['CPUExecutionProvider'])

# Prepare dummy input text
input_feed = {"input_text": np.asarray([input_text])}  # Assuming "input_text" is the input expected by the tokenizer

# Run the model
outputs = session.run(None, input_feed)

# Print the outputs
print("logits:", outputs[1][0])

Originally posted by @MLRadfys in #798
