Skip to content

Commit bd0d13f

Browse files
authored
add the tensorrt python api for hrnet (wang-xinyu#692)
1 parent ec20cea commit bd0d13f

File tree

1 file changed

+217
-0
lines changed

1 file changed

+217
-0
lines changed
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
"""
2+
An example that uses TensorRT's Python api to make inferences for hrnet.
3+
"""
4+
import os
5+
import shutil
6+
import random
7+
import sys
8+
import threading
9+
import time
10+
import cv2
11+
import numpy as np
12+
import pycuda.autoinit
13+
import pycuda.driver as cuda
14+
import tensorrt as trt
15+
from imgaug import augmenters as iaa
16+
17+
def get_img_path_batches(batch_size, img_dir):
    """
    description: Walk img_dir recursively and group every file path found
                 into consecutive batches.
    param:
        batch_size: int, number of paths per batch
        img_dir:    str, root directory to walk
    return:
        list of lists of file paths; the last batch may be shorter than
        batch_size, and the result is empty when no file is found
    """
    # Lazily flatten the directory walk into a single stream of file paths.
    all_paths = (
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(img_dir)
        for fname in fnames
    )

    batches = []
    current = []
    for path in all_paths:
        # Flush a full batch before adding the next path.
        if len(current) == batch_size:
            batches.append(current)
            current = []
        current.append(path)

    # Keep the trailing, possibly partial, batch.
    if current:
        batches.append(current)
    return batches
29+
30+
class Hrnet_TRT(object):
    """
    description: A Hrnet class that wraps TensorRT ops, preprocess and postprocess ops.
    """

    def __init__(self, engine_file_path):
        """
        description: Create a CUDA context, deserialize the TensorRT engine
                     and allocate host/device buffers for every binding.
        param:
            engine_file_path: str, path of the serialized engine file
        raises:
            RuntimeError: if the TensorRT runtime cannot be created or the
                          engine cannot be deserialized
        """
        # Create a Context on this device; make_context() also makes it the
        # current context, which is why destroy() has to pop it later.
        self.cfx = cuda.Device(0).make_context()
        try:
            stream = cuda.Stream()
            runtime = trt.Runtime(trt.Logger(trt.Logger.INFO))
            # Explicit check instead of `assert`, which is stripped under -O.
            if not runtime:
                raise RuntimeError("failed to create the TensorRT runtime")

            # Deserialize the engine from file
            with open(engine_file_path, "rb") as f:
                engine = runtime.deserialize_cuda_engine(f.read())
            if engine is None:
                raise RuntimeError(
                    "failed to deserialize TensorRT engine: {}".format(engine_file_path)
                )
            context = engine.create_execution_context()

            host_inputs = []
            cuda_inputs = []
            host_outputs = []
            cuda_outputs = []
            bindings = []

            for binding in engine:
                binding_shape = engine.get_binding_shape(binding)
                size = trt.volume(binding_shape) * engine.max_batch_size
                dtype = trt.nptype(engine.get_binding_dtype(binding))
                # Allocate host and device buffers
                host_mem = cuda.pagelocked_empty(size, dtype)
                cuda_mem = cuda.mem_alloc(host_mem.nbytes)
                # Append the device buffer to device bindings.
                bindings.append(int(cuda_mem))
                # Append to the appropriate list.
                if engine.binding_is_input(binding):
                    # NOTE(review): presumably the input binding is laid out
                    # as (H, W, C) -- confirm against the engine definition.
                    self.input_w = binding_shape[-2]
                    self.input_h = binding_shape[-3]
                    host_inputs.append(host_mem)
                    cuda_inputs.append(cuda_mem)
                else:
                    host_outputs.append(host_mem)
                    cuda_outputs.append(cuda_mem)
        except Exception:
            # Do not leave the freshly created context on the stack when
            # construction fails half-way.
            self.cfx.pop()
            raise

        # Store
        self.stream = stream
        self.context = context
        self.engine = engine
        self.host_inputs = host_inputs
        self.cuda_inputs = cuda_inputs
        self.host_outputs = host_outputs
        self.cuda_outputs = cuda_outputs
        self.bindings = bindings
        self.batch_size = engine.max_batch_size

    def infer(self, image_raw):
        """
        description: Run the engine on one image and resize the result back
                     to the size of the input image.
        param:
            image_raw: numpy, raw image as returned by cv2.imread
        return:
            output:   numpy uint8 array with the height/width of image_raw
            use_time: float, seconds spent on host->device copy, execution
                      and device->host copy
        """
        # Restore
        stream = self.stream
        context = self.context
        host_inputs = self.host_inputs
        cuda_inputs = self.cuda_inputs
        host_outputs = self.host_outputs
        cuda_outputs = self.cuda_outputs
        bindings = self.bindings

        # Make self the active context, pushing it on top of the context stack.
        self.cfx.push()
        try:
            print('ori_shape: ', image_raw.shape)
            # Remember the original size; the network output is produced at
            # (input_h, input_w) and is resized back to this size below.
            w_ori, h_ori = image_raw.shape[1], image_raw.shape[0]
            # Do image preprocess
            input_image = self.preprocess_image(image_raw)
            # Copy input image to host buffer
            np.copyto(host_inputs[0], input_image.ravel())
            start = time.time()
            # Transfer input data to the GPU.
            cuda.memcpy_htod_async(cuda_inputs[0], host_inputs[0], stream)
            # Run inference.
            context.execute_async(bindings=bindings, stream_handle=stream.handle)
            # Transfer predictions back from the GPU.
            cuda.memcpy_dtoh_async(host_outputs[0], cuda_outputs[0], stream)
            # Synchronize the stream
            stream.synchronize()
            end = time.time()
        finally:
            # Remove any context from the top of the context stack,
            # deactivating it -- also when inference raised, so the stack
            # stays balanced for destroy().
            self.cfx.pop()

        # Here we use the first row of output in that batch_size = 1
        output = host_outputs[0]
        # Do postprocess
        output = output.reshape(self.input_h, self.input_w).astype('uint8')
        print('output_shape: ', output.shape)
        # NOTE(review): cv2.resize uses its default interpolation here; if the
        # output holds more than two class ids, nearest-neighbour may be the
        # better choice -- confirm before changing.
        output = cv2.resize(output, (w_ori, h_ori))
        return output, end - start

    def destroy(self):
        """
        description: Pop the context that __init__ left on the context stack.
        """
        # Remove any context from the top of the context stack, deactivating it.
        self.cfx.pop()

    def preprocess_image(self, image_raw):
        """
        description: Convert a BGR image to RGB, resize it to the network
                     input size and cast it to float32.
        param:
            image_raw: numpy, raw image
        return:
            image:  the processed image
        """
        image = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)
        resize = iaa.Resize({
            'width': self.input_w,
            'height': self.input_h
        })
        image = resize.augment_image(image)
        print('resized', image.shape, image.dtype)
        image = image.astype(np.float32)
        return image

    def get_raw_image(self, image_path_batch):
        """
        description: Read the first image of a batch of image paths
        param:
            image_path_batch: iterable of image paths
        return:
            the result of cv2.imread for the first path, or None when the
            batch is empty
        """
        # infer() handles a single image, so only the first path is read.
        first_path = next(iter(image_path_batch), None)
        if first_path is None:
            return None
        return cv2.imread(first_path)

    def get_raw_image_zeros(self, image_path_batch=None):
        """
        description: Ready data for warmup
        param:
            image_path_batch: unused, kept for interface compatibility
        return:
            an all-zero uint8 image of shape (input_h, input_w, 3)
        """
        return np.zeros([self.input_h, self.input_w, 3], dtype=np.uint8)
158+
159+
160+
class inferThread(threading.Thread):
    """
    description: Thread that runs inference for one batch of image paths and
                 writes the result into the output/ directory.
    """

    def __init__(self, hrnet_wrapper, image_path_batch):
        """
        param:
            hrnet_wrapper:    object providing get_raw_image() and infer()
            image_path_batch: list of image paths handled by this thread
        """
        threading.Thread.__init__(self)
        self.hrnet_wrapper = hrnet_wrapper
        self.image_path_batch = image_path_batch

    def run(self):
        """
        description: Read the image, run inference and save the result once
                     per path in the batch.
        """
        image_raw = self.hrnet_wrapper.get_raw_image(self.image_path_batch)
        if image_raw is None:
            # cv2.imread gives None for unreadable / non-image files; skip
            # the batch instead of crashing inside infer() on `.shape`.
            print('input->{}, could not be read as an image, skipped'.format(self.image_path_batch))
            return

        batch_image_raw, use_time = self.hrnet_wrapper.infer(image_raw)
        for img_path in self.image_path_batch:
            save_name = os.path.join('output', os.path.basename(img_path))
            # Save image; the result is scaled by 255 before writing.
            cv2.imwrite(save_name, batch_image_raw * 255)
        print('input->{}, time->{:.2f}ms, saving into output/'.format(self.image_path_batch, use_time * 1000))
174+
175+
176+
class warmUpThread(threading.Thread):
    """
    description: Thread that pushes one all-zero image through the wrapper
                 to warm the engine up.
    """

    def __init__(self, hrnet_wrapper):
        super(warmUpThread, self).__init__()
        self.hrnet_wrapper = hrnet_wrapper

    def run(self):
        wrapper = self.hrnet_wrapper
        dummy_image = wrapper.get_raw_image_zeros()
        result, elapsed = wrapper.infer(dummy_image)
        # Report the shape of the first row of the result and the time in ms.
        first_row_shape = result[0].shape
        print('warm_up->{}, time->{:.2f}ms'.format(first_row_shape, elapsed * 1000))
184+
185+
186+
187+
if __name__ == "__main__":
    # load custom engine: the first command-line argument overrides the
    # default location of the generated engine file
    engine_file_path = sys.argv[1] if len(sys.argv) > 1 else "build/hrnet.engine"

    # Start every run from an empty output/ directory.
    if os.path.exists('output/'):
        shutil.rmtree('output/')
    os.makedirs('output/')

    # a hrnet instance
    hrnet_wrapper = Hrnet_TRT(engine_file_path)
    try:
        print('batch size is', hrnet_wrapper.batch_size)  # batch size is set to 1!

        image_dir = "samples/"
        image_path_batches = get_img_path_batches(hrnet_wrapper.batch_size, image_dir)

        # Ten warm-up passes, each in its own thread, run one after another.
        for _ in range(10):
            warmup_worker = warmUpThread(hrnet_wrapper)
            warmup_worker.start()
            warmup_worker.join()

        # One inference thread per batch, also run one after another.
        for batch in image_path_batches:
            infer_worker = inferThread(hrnet_wrapper, batch)
            infer_worker.start()
            infer_worker.join()
    finally:
        # destroy the instance
        hrnet_wrapper.destroy()

0 commit comments

Comments
 (0)