Skip to content

Commit 81ed5cc

Browse files
committed
Support ONNX runtime profiling.
1 parent b5cf6d8 commit 81ed5cc

File tree

7 files changed

+40
-22
lines changed

7 files changed

+40
-22
lines changed

v0.7/language/bert/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
build/
22
eval_features.pickle
3+
onnxruntime_profile__*.json

v0.7/language/bert/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,4 @@ evaluate:
169169
clean:
170170
@rm -rf ${BUILD_DIR}
171171
@rm -f ${FEATURE_CACHE}
172+
@rm -f onnxruntime_profile__*.json

v0.7/language/bert/onnxruntime_SUT.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,34 +27,39 @@
2727
from squad_QSL import get_squad_QSL
2828

2929
class BERT_ONNXRuntime_SUT():
30-
def __init__(self, quantized):
30+
def __init__(self, args):
31+
self.profile = args.profile
32+
self.options = onnxruntime.SessionOptions()
33+
self.options.enable_profiling = args.profile
34+
3135
print("Loading ONNX model...")
32-
self.quantized = quantized
33-
if not quantized:
34-
model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/model.onnx"
35-
else:
36+
self.quantized = args.quantized
37+
if self.quantized:
3638
model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/bert_large_v1_1_fake_quant.onnx"
37-
self.sess = onnxruntime.InferenceSession(model_path)
39+
else:
40+
model_path = "build/data/bert_tf_v1_1_large_fp32_384_v2/model.onnx"
41+
self.sess = onnxruntime.InferenceSession(model_path, self.options)
3842

3943
print("Constructing SUT...")
4044
self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries, self.process_latencies)
41-
self.qsl = get_squad_QSL()
4245
print("Finished constructing SUT.")
4346

47+
self.qsl = get_squad_QSL()
48+
4449
def issue_queries(self, query_samples):
4550
for i in range(len(query_samples)):
4651
eval_features = self.qsl.get_features(query_samples[i].index)
47-
if not self.quantized:
52+
if self.quantized:
4853
fd = {
4954
"input_ids": np.array(eval_features.input_ids).astype(np.int64)[np.newaxis, :],
50-
"input_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
51-
"segment_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
55+
"attention_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
56+
"token_type_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
5257
}
5358
else:
5459
fd = {
5560
"input_ids": np.array(eval_features.input_ids).astype(np.int64)[np.newaxis, :],
56-
"attention_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
57-
"token_type_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
61+
"input_mask": np.array(eval_features.input_mask).astype(np.int64)[np.newaxis, :],
62+
"segment_ids": np.array(eval_features.segment_ids).astype(np.int64)[np.newaxis, :]
5863
}
5964
scores = self.sess.run([o.name for o in self.sess.get_outputs()], fd)
6065
output = np.stack(scores, axis=-1)[0]
@@ -71,8 +76,9 @@ def process_latencies(self, latencies_ns):
7176
pass
7277

7378
def __del__(self):
74-
lg.DestroySUT(self.sut)
79+
if self.profile:
80+
print("ONNX runtime profile dumped to: '{}'".format(self.sess.end_profiling()))
7581
print("Finished destroying SUT.")
7682

77-
def get_onnxruntime_sut(quantized=False):
78-
return BERT_ONNXRuntime_SUT(quantized)
83+
def get_onnxruntime_sut(args):
84+
return BERT_ONNXRuntime_SUT(args)

v0.7/language/bert/pytorch_SUT.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,10 @@ def __init__(self):
5353

5454
print("Constructing SUT...")
5555
self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries, self.process_latencies)
56-
self.qsl = get_squad_QSL()
5756
print("Finished constructing SUT.")
5857

58+
self.qsl = get_squad_QSL()
59+
5960
def issue_queries(self, query_samples):
6061
with torch.no_grad():
6162
for i in range(len(query_samples)):
@@ -77,7 +78,6 @@ def process_latencies(self, latencies_ns):
7778
pass
7879

7980
def __del__(self):
80-
lg.DestroySUT(self.sut)
8181
print("Finished destroying SUT.")
8282

8383
def get_pytorch_sut():

v0.7/language/bert/run.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,15 @@
2222
import mlperf_loadgen as lg
2323
import subprocess
2424

25+
from squad_QSL import get_squad_QSL
26+
2527
def get_args():
2628
parser = argparse.ArgumentParser()
2729
parser.add_argument("--backend", choices=["tf","pytorch","onnxruntime"], default="tf", help="Backend")
2830
parser.add_argument("--scenario", choices=["SingleStream", "Offline", "Server", "MultiStream"], default="Offline", help="Scenario")
2931
parser.add_argument("--accuracy", action="store_true", help="enable accuracy pass")
3032
parser.add_argument("--quantized", action="store_true", help="use quantized model (only valid for onnxruntime backend)")
33+
parser.add_argument("--profile", action="store_true", help="enable profiling (only valid for onnxruntime backend)")
3134
parser.add_argument("--mlperf_conf", default="build/mlperf.conf", help="mlperf rules config")
3235
parser.add_argument("--user_conf", default="user.conf", help="mlperf rules config")
3336
args = parser.parse_args()
@@ -45,15 +48,17 @@ def main():
4548

4649
if args.backend == "pytorch":
4750
assert not args.quantized, "Quantized model is only supported by onnxruntime backend!"
51+
assert not args.profile, "Profiling is only supported by onnxruntime backend!"
4852
from pytorch_SUT import get_pytorch_sut
4953
sut = get_pytorch_sut()
5054
elif args.backend == "tf":
5155
assert not args.quantized, "Quantized model is only supported by onnxruntime backend!"
56+
assert not args.profile, "Profiling is only supported by onnxruntime backend!"
5257
from tf_SUT import get_tf_sut
5358
sut = get_tf_sut()
5459
elif args.backend == "onnxruntime":
5560
from onnxruntime_SUT import get_onnxruntime_sut
56-
sut = get_onnxruntime_sut(args.quantized)
61+
sut = get_onnxruntime_sut(args)
5762
else:
5863
raise ValueError("Unknown backend: {:}".format(args.backend))
5964

@@ -76,7 +81,7 @@ def main():
7681
log_settings = lg.LogSettings()
7782
log_settings.log_output = log_output_settings
7883

79-
print("Running Loadgen test...")
84+
print("Running LoadGen test...")
8085
lg.StartTestWithLogSettings(sut.sut, sut.qsl.qsl, settings, log_settings)
8186

8287
if args.accuracy:
@@ -85,5 +90,11 @@ def main():
8590

8691
print("Done!")
8792

93+
print("Destroying SUT...")
94+
lg.DestroySUT(sut.sut)
95+
96+
print("Destroying QSL...")
97+
lg.DestroyQSL(sut.qsl.qsl)
98+
8899
if __name__ == "__main__":
89100
main()

v0.7/language/bert/squad_QSL.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,6 @@ def get_features(self, sample_id):
8484
return self.eval_features[sample_id]
8585

8686
def __del__(self):
87-
lg.DestroyQSL(self.qsl)
8887
print("Finished destroying QSL.")
8988

9089
def get_squad_QSL():

v0.7/language/bert/tf_SUT.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,10 @@ def __init__(self, batch_size=8):
4646

4747
print("Constructing SUT...")
4848
self.sut = lg.ConstructSUT(self.issue_queries, self.flush_queries, self.process_latencies)
49-
self.qsl = get_squad_QSL()
5049
print("Finished constructing SUT.")
5150

51+
self.qsl = get_squad_QSL()
52+
5253
def issue_queries(self, query_samples):
5354
input_ids = np.zeros((len(query_samples), 1, 384), dtype=np.int32)
5455
input_mask = np.zeros((len(query_samples), 1, 384), dtype=np.int32)
@@ -81,7 +82,6 @@ def process_latencies(self, latencies_ns):
8182
pass
8283

8384
def __del__(self):
84-
lg.DestroySUT(self.sut)
8585
print("Finished destroying SUT.")
8686

8787
def create_model(self, bert_config, is_training, input_ids, input_mask, segment_ids, use_one_hot_embeddings):

0 commit comments

Comments (0)