2727from squad_QSL import get_squad_QSL
2828
2929class BERT_ONNXRuntime_SUT ():
30- def __init__ (self , quantized ):
30+ def __init__ (self , args ):
31+ self .profile = args .profile
32+ self .options = onnxruntime .SessionOptions ()
33+ self .options .enable_profiling = args .profile
34+
3135 print ("Loading ONNX model..." )
32- self .quantized = quantized
33- if not quantized :
34- model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/model.onnx"
35- else :
36+ self .quantized = args .quantized
37+ if self .quantized :
3638 model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/bert_large_v1_1_fake_quant.onnx"
37- self .sess = onnxruntime .InferenceSession (model_path )
39+ else :
40+ model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/model.onnx"
41+ self .sess = onnxruntime .InferenceSession (model_path , self .options )
3842
3943 print ("Constructing SUT..." )
4044 self .sut = lg .ConstructSUT (self .issue_queries , self .flush_queries , self .process_latencies )
41- self .qsl = get_squad_QSL ()
4245 print ("Finished constructing SUT." )
4346
47+ self .qsl = get_squad_QSL ()
48+
4449 def issue_queries (self , query_samples ):
4550 for i in range (len (query_samples )):
4651 eval_features = self .qsl .get_features (query_samples [i ].index )
47- if not self .quantized :
52+ if self .quantized :
4853 fd = {
4954 "input_ids" : np .array (eval_features .input_ids ).astype (np .int64 )[np .newaxis , :],
50- "input_mask " : np .array (eval_features .input_mask ).astype (np .int64 )[np .newaxis , :],
51- "segment_ids " : np .array (eval_features .segment_ids ).astype (np .int64 )[np .newaxis , :]
55+ "attention_mask " : np .array (eval_features .input_mask ).astype (np .int64 )[np .newaxis , :],
56+ "token_type_ids " : np .array (eval_features .segment_ids ).astype (np .int64 )[np .newaxis , :]
5257 }
5358 else :
5459 fd = {
5560 "input_ids" : np .array (eval_features .input_ids ).astype (np .int64 )[np .newaxis , :],
56- "attention_mask " : np .array (eval_features .input_mask ).astype (np .int64 )[np .newaxis , :],
57- "token_type_ids " : np .array (eval_features .segment_ids ).astype (np .int64 )[np .newaxis , :]
61+ "input_mask " : np .array (eval_features .input_mask ).astype (np .int64 )[np .newaxis , :],
62+ "segment_ids " : np .array (eval_features .segment_ids ).astype (np .int64 )[np .newaxis , :]
5863 }
5964 scores = self .sess .run ([o .name for o in self .sess .get_outputs ()], fd )
6065 output = np .stack (scores , axis = - 1 )[0 ]
@@ -71,8 +76,9 @@ def process_latencies(self, latencies_ns):
7176 pass
7277
7378 def __del__ (self ):
74- lg .DestroySUT (self .sut )
79+ if self .profile :
80+ print ("ONNX runtime profile dumped to: '{}'" .format (self .sess .end_profiling ()))
7581 print ("Finished destroying SUT." )
7682
77- def get_onnxruntime_sut (quantized = False ):
78- return BERT_ONNXRuntime_SUT (quantized )
83+ def get_onnxruntime_sut (args ):
84+ return BERT_ONNXRuntime_SUT (args )
0 commit comments