Skip to content

Commit cb00b43

Browse files
authored
Support Chat Template Override for Unsupported Models (#947)
1 parent 63c4a4d commit cb00b43

File tree

4 files changed

+27
-10
lines changed

4 files changed

+27
-10
lines changed

onnxruntime_extensions/pp_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ def tokenize(self, text, add_special_tokens = True):
5858
def detokenize(self, tokens):
5959
return batch_detokenize(self.tokenizer, [tokens])
6060

61-
def apply_chat_template(self, chat, add_generation_prompt=True, tokenize=False):
61+
def apply_chat_template(self, chat, template="", add_generation_prompt=True, tokenize=False):
6262
result = _apply_chat_template(
63-
self.tokenizer, "", chat, add_generation_prompt, tokenize)
63+
self.tokenizer, template, chat, add_generation_prompt, tokenize)
6464
return tensor_result_get_at(result, 1 if tokenize else 0)
6565

6666
def __del__(self):

shared/api/chat_template.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,10 @@ OrtxStatus TokenizerImpl::LoadChatTemplate() {
1111
try {
1212
chat_template_root_ = minja::Parser::parse(chat_template, {});
1313
} catch (const std::runtime_error& e) {
14-
chat_template_parsing_status = std::string(e.what());
1514
return OrtxStatus(kOrtxOK, "Warning: The chat template for this model is not yet supported, trying to apply chat template will cause an error.");
1615
}
1716
}
1817

19-
chat_template_parsing_status = "Success";
2018
return OrtxStatus(kOrtxOK, "Loaded chat template.");
2119
}
2220

@@ -722,9 +720,9 @@ void TokenizerImpl::InitializeChatParameters(const char* template_str,
722720
// ApplyChatTemplate method to choose the template logic based on chat_template
723721
OrtxStatus TokenizerImpl::ApplyChatTemplate(const TokenizerImpl::MessageList& message_list, std::string& output,
724722
bool add_generation_prompt) const {
725-
if (chat_template_parsing_status != "Success"){
726-
return OrtxStatus(kOrtxErrorInvalidArgument, "Failed to parse chat template: " + chat_template_parsing_status);
727-
}
723+
// Note: The official chat template from this model's config file may not be supported.
724+
// However, we do not throw an error until checking model_to_template_map as the user
725+
// may pass in a template string in our supported list to override the model config template.
728726

729727
// Find the chat_template string for this model if it is supported
730728
auto it = model_to_template_map.find(chat_template);

shared/api/tokenizer_impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,6 @@ class TokenizerImpl : public OrtxObjectImpl {
5858
std::string chat_template;
5959
mutable MessageList messages;
6060

61-
std::string chat_template_parsing_status;
62-
6361
std::string bos_token;
6462
std::string eos_token;
6563
std::vector<std::string> custom_tools;

test/test_pp_api.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,11 +234,32 @@ def test_OLMa_tokenizer(self):
234234
np.testing.assert_array_equal(ortx_inputs, inputs)
235235

236236
def test_Qwen_QVQ_tokenizer(self):
237-
model_id = "Qwen/QVQ-72B-Preview"
237+
model_id = "Qwen/Qwen3-0.6B-FP8"
238238
test_sentence = [self.tokenizer_test_sentence]
239239
hf_enc = AutoTokenizer.from_pretrained(model_id)
240240
inputs = hf_enc(test_sentence)["input_ids"]
241241
tokenizer = pp_api.Tokenizer(model_id)
242+
243+
# Note: we simply check if chat template override works here, as Qwen/Qwen3-0.6B-FP8 is not a
244+
# supported chat template model, but we do not compare the output of apply_chat_template
245+
# with HF, as it is not supported in Extensions yet.
246+
messages = [
247+
{
248+
"role": "user",
249+
"content": [
250+
{
251+
"type": "image",
252+
"url": "https://huggingface.co/spaces/big-vision/paligemma-hf/resolve/main/examples/password.jpg",
253+
},
254+
{"type": "text", "text": "What is the password?"},
255+
],
256+
}
257+
]
258+
message_json = json.dumps(messages)
259+
templ = """{% for message in messages %}{% if message['role'] == 'system' and 'tools' in message and message['tools'] is not none %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|tool|>' + message['tools'] + '<|/tool|>' + '<|end|>' }}{% else %}{{ '<|' + message['role'] + '|>' + message['content'] + '<|end|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>' }}{% else %}{{ eos_token }}{% endif %}"""
260+
prompt = tokenizer.apply_chat_template(chat=message_json, template=templ)
261+
262+
# Continue tokenizer test
242263
ortx_inputs = tokenizer.tokenize(test_sentence)
243264
np.testing.assert_array_equal(ortx_inputs, inputs)
244265

0 commit comments

Comments (0)