Skip to content

Commit 29003e5

Browse files
authored
Merge pull request tobegit3hub#22 from yejw5/wait_callback
wait callback
2 parents 0c0854c + b2daa9a commit 29003e5

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

minimal_model/python_predict_client/benchmark_qps.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy
44
import multiprocessing
5+
import threading
56
import time
67

78
from grpc.beta import implementations
@@ -21,8 +22,9 @@
2122
FLAGS = tf.app.flags.FLAGS
2223

2324

24-
def _create_rpc_callback():
25+
def _create_rpc_callback(event):
2526
def _callback(result_future):
27+
event.set()
2628
exception = result_future.exception()
2729
if exception:
2830
print(exception)
@@ -59,16 +61,22 @@ def test_one_process(i):
5961
request_number = FLAGS.benchmark_test_number
6062
#start_time = time.time()
6163

64+
events = []
6265
for i in range(request_number):
66+
event = threading.Event()
6367
result_future = stub.Predict.future(request, request_timeout)
6468
#result_future = stub.Predict.future(request, 0.00000001)
65-
result_future.add_done_callback(_create_rpc_callback())
69+
result_future.add_done_callback(_create_rpc_callback(event))
70+
events.append(event)
6671
#result = stub.Predict(request, request_timeout)
6772

6873
#end_time = time.time()
6974
#print("Average latency is: {} ms".format((end_time - start_time) * 1000 / request_number))
7075
#print("Average qps is: {}".format(request_number / (end_time - start_time)))
7176

77+
for event in events:
78+
event.wait()
79+
7280

7381
def main():
7482
thread_number = FLAGS.benchmark_thread_number

0 commit comments

Comments
 (0)