Skip to content

Commit 0d80420

Browse files
authored
Implement a general chat template support for more models (#918)
* Add general chat template implementation * add tokenize flag for chat template API
1 parent 2c3d6f7 commit 0d80420

20 files changed

+5095
-1135
lines changed

CMakeLists.txt

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,11 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR)
5757
set(_ORTX_STANDALONE_PROJECT ON)
5858
endif()
5959

60+
set(_ORTX_CPP_NO_RTTI ON)
6061
set(_ORTX_SHARED_BUILD_SUPPORTED ON)
6162
if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
6263
set(_ORTX_SHARED_BUILD_SUPPORTED OFF)
64+
set(_ORTX_CPP_NO_RTTI OFF)
6365
endif()
6466

6567
option(CC_OPTIMIZE "Allow compiler optimizations, Set to OFF to disable" ON)
@@ -189,6 +191,18 @@ if(NOT OCOS_BUILD_PYTHON AND OCOS_ENABLE_PYTHON)
189191
set(OCOS_BUILD_PYTHON ON CACHE INTERNAL "")
190192
endif()
191193

194+
if(OCOS_BUILD_PYTHON)
195+
set(_ORTX_CPP_NO_RTTI OFF)
196+
endif()
197+
198+
if(_ORTX_CPP_NO_RTTI)
199+
if(MSVC)
200+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /GR-")
201+
else()
202+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fno-rtti")
203+
endif()
204+
endif()
205+
192206
if(OCOS_BUILD_ANDROID)
193207
if(NOT CMAKE_TOOLCHAIN_FILE MATCHES "android.toolchain.cmake")
194208
message(FATAL_ERROR "CMAKE_TOOLCHAIN_FILE must be set to build/cmake/android.toolchain.cmake from the Android NDK.")

include/ortx_cpp_helper.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class OrtxObjectPtr : public std::unique_ptr<T, OrtxDeleter<T>> {
3535
*
3636
* Constructs an OrtxObjectPtr with a null pointer.
3737
*/
38-
OrtxObjectPtr() : std::unique_ptr<T, OrtxDeleter<T>>(nullptr) {}
38+
explicit OrtxObjectPtr(T* ptr=nullptr) : std::unique_ptr<T, OrtxDeleter<T>>(ptr) {}
3939

4040
/**
4141
* @brief Constructor that creates an OrtxObjectPtr from a function call.
@@ -58,6 +58,15 @@ class OrtxObjectPtr : public std::unique_ptr<T, OrtxDeleter<T>> {
5858
}
5959
}
6060

61+
template <typename TFn, typename... Args>
62+
static OrtxObjectPtr<T> FromCapi(TFn fn, Args&&... args) {
63+
OrtxObject* proc = nullptr;
64+
extError_t err = fn(&proc, std::forward<Args>(args)...);
65+
if (err == kOrtxOK) {
66+
return OrtxObjectPtr(static_cast<T*>(proc));
67+
}
68+
}
69+
6170
/**
6271
* @brief Get the error code associated with the creation of the OrtxObject.
6372
*

include/ortx_tokenizer.h

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,19 +33,21 @@ struct OrtxTokenizerBlob {
3333
const size_t reserved_blob_1_len;
3434

3535
#ifdef __cplusplus
36-
OrtxTokenizerBlob(const std::string_view& config_json_blob,
37-
const std::string_view& vocab_json_blob,
38-
const std::string_view& token_module_blob = {},
39-
const std::string_view& raw_model_blob = {})
40-
: config_json_blob(config_json_blob.data()), vocab_json_blob(vocab_json_blob.data()),
41-
token_module_blob(token_module_blob.data()), raw_model_blob(raw_model_blob.data()),
42-
reserved_blob_1(nullptr), config_blob_len(config_json_blob.size()),
43-
vocab_blob_len(vocab_json_blob.size()), token_module_blob_len(token_module_blob.size()),
44-
raw_model_blob_len(raw_model_blob.size()), reserved_blob_1_len(0) {}
36+
OrtxTokenizerBlob(const std::string_view& config_json_blob, const std::string_view& vocab_json_blob,
37+
const std::string_view& token_module_blob = {}, const std::string_view& raw_model_blob = {})
38+
: config_json_blob(config_json_blob.data()),
39+
vocab_json_blob(vocab_json_blob.data()),
40+
token_module_blob(token_module_blob.data()),
41+
raw_model_blob(raw_model_blob.data()),
42+
reserved_blob_1(nullptr),
43+
config_blob_len(config_json_blob.size()),
44+
vocab_blob_len(vocab_json_blob.size()),
45+
token_module_blob_len(token_module_blob.size()),
46+
raw_model_blob_len(raw_model_blob.size()),
47+
reserved_blob_1_len(0) {}
4548
#endif
4649
};
4750

48-
4951
#ifdef __cplusplus
5052
extern "C" {
5153
#endif
@@ -64,8 +66,8 @@ extError_t ORTX_API_CALL OrtxCreateTokenizer(OrtxTokenizer** tokenizer, const ch
6466
* \param tokenizer_blob Pointer to the tokenizer blob
6567
* \return Error code indicating the success or failure of the operation
6668
*/
67-
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer, const struct OrtxTokenizerBlob* tokenizer_blob);
68-
69+
extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer,
70+
const struct OrtxTokenizerBlob* tokenizer_blob);
6971

7072
/** \brief Tokenize the input using the specified tokenizer
7173
*
@@ -75,8 +77,8 @@ extError_t ORTX_API_CALL OrtxCreateTokenizerFromBlob(OrtxTokenizer** tokenizer,
7577
* \param output Pointer to store the tokenized result
7678
* \return Error code indicating the success or failure of the operation
7779
*/
78-
extError_t ORTX_API_CALL OrtxTokenize(
79-
const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size, OrtxTokenId2DArray** output);
80+
extError_t ORTX_API_CALL OrtxTokenize(const OrtxTokenizer* tokenizer, const char* input[], size_t batch_size,
81+
OrtxTokenId2DArray** output);
8082

8183
/**
8284
* Converts a token to its corresponding ID.
@@ -101,8 +103,8 @@ extError_t ORTX_API_CALL OrtxConvertTokenToId(const OrtxTokenizer* tokenizer, co
101103
* @param output A pointer to the OrtxTokenId2DArray object to store the output.
102104
* @return An extError_t value indicating the success or failure of the operation.
103105
*/
104-
extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(
105-
const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang, const char* task, int no_timestamps, OrtxTokenId2DArray** output);
106+
extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(const OrtxTokenizer* tokenizer, size_t batch_size, const char* lang,
107+
const char* task, int no_timestamps, OrtxTokenId2DArray** output);
106108

107109
/** \brief Detokenize the input using the specified tokenizer
108110
*
@@ -111,8 +113,8 @@ extError_t ORTX_API_CALL OrtxGetDecoderPromptIds(
111113
* \param output Pointer to store the detokenized result
112114
* \return Error code indicating the success or failure of the operation
113115
*/
114-
extError_t ORTX_API_CALL OrtxDetokenize(
115-
const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input, OrtxStringArray** output);
116+
extError_t ORTX_API_CALL OrtxDetokenize(const OrtxTokenizer* tokenizer, const OrtxTokenId2DArray* input,
117+
OrtxStringArray** output);
116118

117119
/** \brief Detokenize the input using the specified tokenizer (1D version)
118120
*
@@ -122,8 +124,8 @@ extError_t ORTX_API_CALL OrtxDetokenize(
122124
* \param output Pointer to store the detokenized result
123125
* \return Error code indicating the success or failure of the operation
124126
*/
125-
extError_t ORTX_API_CALL OrtxDetokenize1D(
126-
const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len, OrtxStringArray** output);
127+
extError_t ORTX_API_CALL OrtxDetokenize1D(const OrtxTokenizer* tokenizer, const extTokenId_t* input, size_t len,
128+
OrtxStringArray** output);
127129

128130
/** \brief Detokenize the input using the specified tokenizer with caching
129131
*
@@ -133,8 +135,8 @@ extError_t ORTX_API_CALL OrtxDetokenize1D(
133135
* \param text_out Pointer to store the detokenized text
134136
* \return Error code indicating the success or failure of the operation
135137
*/
136-
extError_t ORTX_API_CALL OrtxDetokenizeCached(
137-
const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache, extTokenId_t next_id, const char** text_out);
138+
extError_t ORTX_API_CALL OrtxDetokenizeCached(const OrtxTokenizer* tokenizer, OrtxDetokenizerCache* cache,
139+
extTokenId_t next_id, const char** text_out);
138140

139141
/**
140142
* @brief Retrieves the C-style string representation from an OrtxString object.
@@ -182,28 +184,30 @@ extError_t ORTX_API_CALL OrtxTokenId2DArrayGetBatch(const OrtxTokenId2DArray* to
182184
* \param length Pointer to store the length of the item
183185
* \return Error code indicating the success or failure of the operation
184186
*/
185-
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(
186-
const OrtxTokenId2DArray* token_id_2d_array, size_t index, const extTokenId_t** item, size_t* length);
187+
extError_t ORTX_API_CALL OrtxTokenId2DArrayGetItem(const OrtxTokenId2DArray* token_id_2d_array, size_t index,
188+
const extTokenId_t** item, size_t* length);
187189

188190
/**
189191
* @brief Applies a chat template to the given input.
190192
*
191193
* This function processes the specified template with the provided input using the
192-
* tokenizer, and outputs the resulting string array. Optionally, it can include a
194+
* tokenizer, and outputs the result as an OrtxTensorResult. Optionally, it can include a
193195
* generation prompt in the output. The chat template can be provided as a string or
194-
* be retrieved from a loaded tokenizer json file which contains the chat template its json file.
195-
* if both tokenizer and template_str are provided, the template_str will supersede the tokenizer.
196+
* be retrieved from a loaded tokenizer json file which contains the chat template in its json file.
197+
* If both tokenizer and template_str are provided, the template_str will supersede the tokenizer.
196198
*
197-
* @param tokenizer Pointer to an OrtxTokenizer used for template processing
198-
* @param template_str Null-terminated string representing the chat template, can be null if tokenizer.json has one.
199+
* @param tokenizer Pointer to an OrtxTokenizer used for template processing.
200+
* @param template_str Null-terminated string representing the chat template; can be null if tokenizer.json has one.
199201
* @param input Null-terminated string containing the input to be processed.
200-
* @param output an OrtxString that will be populated with the output strings.
202+
* @param output Pointer to an OrtxTensorResult that will be populated with the output; the templated text is
203+
* read from index 0 and, if tokenize is true, the token IDs are stored at index 1.
201204
* @param add_generation_prompt Indicates whether to add a generation prompt to the output.
205+
* @param tokenize Indicates whether to tokenize the templated text into token IDs.
202206
* @return extError_t Returns an error code indicating success or the type of failure.
203207
*/
204208
extError_t ORTX_API_CALL OrtxApplyChatTemplate(const OrtxTokenizer* tokenizer, const char* template_str,
205-
const char* input, OrtxString** output,
206-
bool add_generation_prompt);
209+
const char* input, OrtxTensorResult** output,
210+
bool add_generation_prompt, bool tokenize);
207211

208212
#ifdef __cplusplus
209213
}

include/ortx_utils.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ typedef OrtxObject OrtxTensorResult;
4646
} \
4747
} while (0)
4848

49-
5049
typedef uint32_t extTokenId_t;
5150

5251
#ifdef __cplusplus
@@ -99,7 +98,7 @@ extError_t ORTX_API_CALL OrtxDisposeOnly(OrtxObject* object);
9998
* @param tensor A pointer to a variable that will hold the retrieved tensor.
10099
* @return An error code indicating the success or failure of the operation.
101100
*/
102-
extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t index, OrtxTensor** tensor);
101+
extError_t ORTX_API_CALL OrtxTensorResultGetAt(const OrtxTensorResult* result, size_t index, OrtxTensor** tensor);
103102

104103
/**
105104
* @brief Retrieves the data type of the given tensor.
@@ -112,18 +111,19 @@ extError_t ORTX_API_CALL OrtxTensorResultGetAt(OrtxTensorResult* result, size_t
112111
*
113112
* @return An `extError_t` value indicating the success or failure of the operation.
114113
*/
115-
extError_t ORTX_API_CALL OrtxGetTensorType(OrtxTensor* tensor, extDataType_t* type);
114+
extError_t ORTX_API_CALL OrtxGetTensorType(const OrtxTensor* tensor, extDataType_t* type);
116115

117116
/**
118117
* @brief Retrieves the size of each element in the given tensor.
119118
*
120-
* This function calculates the size of each element in the specified tensor and stores it in the provided size variable.
119+
* This function calculates the size of each element in the specified tensor and stores it in the provided size
120+
* variable.
121121
*
122122
* @param tensor A pointer to the OrtxTensor object.
123123
* @param size A pointer to a size_t variable to store the size of each element.
124124
* @return An extError_t value indicating the success or failure of the operation.
125125
*/
126-
extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(OrtxTensor* tensor, size_t* size);
126+
extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(const OrtxTensor* tensor, size_t* size);
127127

128128
/** \brief Get the data from the tensor
129129
*
@@ -133,7 +133,7 @@ extError_t ORTX_API_CALL OrtxGetTensorSizeOfElement(OrtxTensor* tensor, size_t*
133133
* \param num_dims Pointer to store the number of dimensions
134134
* \return Error code indicating the success or failure of the operation
135135
*/
136-
extError_t ORTX_API_CALL OrtxGetTensorData(OrtxTensor* tensor, const void** data, const int64_t** shape,
136+
extError_t ORTX_API_CALL OrtxGetTensorData(const OrtxTensor* tensor, const void** data, const int64_t** shape,
137137
size_t* num_dims);
138138

139139
#ifdef __cplusplus

onnxruntime_extensions/pp_api.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,10 @@ def tokenize(self, text):
5757
def detokenize(self, tokens):
5858
return batch_detokenize(self.tokenizer, [tokens])
5959

60-
def apply_chat_template(self, chat, add_generation_prompt=True):
61-
prompt = _apply_chat_template(
62-
self.tokenizer, "", chat, add_generation_prompt)
63-
return prompt
60+
def apply_chat_template(self, chat, add_generation_prompt=True, tokenize=False):
61+
result = _apply_chat_template(
62+
self.tokenizer, "", chat, add_generation_prompt, tokenize)
63+
return tensor_result_get_at(result, 1 if tokenize else 0)
6464

6565
def __del__(self):
6666
if delete_object and self.tokenizer:

0 commit comments

Comments
 (0)