Skip to content

Commit 2f503ba

Browse files
author
Adam Romlein
committed
Implemented queue to fix latency issues.
1 parent 31acbd8 commit 2f503ba

File tree

1 file changed

+26
-9
lines changed

1 file changed

+26
-9
lines changed

src/ros_wrapper.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import os
55
import argparse
66
from distutils.version import LooseVersion
7+
import queue
8+
import threading
79
# Numerical libs
810
import numpy as np
911
import torch
@@ -41,6 +43,7 @@ def __init__(self, cfg, gpu, img_in, img_out):
4143
self.img_in = img_in
4244
self.img_out = img_out
4345
self.bridge = CvBridge()
46+
self.loader_q = queue.Queue(1)
4447

4548

4649
def visualize_result(self, data, pred, cfg):
@@ -70,7 +73,14 @@ def visualize_result(self, data, pred, cfg):
7073
#Image.fromarray(im_vis).save(
7174
# os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png')))
7275

73-
def run_inference(self, loader):
76+
def run_inference(self):
77+
rospy.loginfo("Waiting for loader from queue...")
78+
while self.loader_q.empty():
79+
rospy.sleep(0.01)
80+
tic = rospy.get_rostime()
81+
rospy.loginfo("Processing image...")
82+
loader = self.loader_q.get()
83+
7484
self.segmentation_module.eval()
7585

7686
pbar = tqdm(total=len(loader))
@@ -98,20 +108,24 @@ def run_inference(self, loader):
98108

99109
_, pred = torch.max(scores, dim=1)
100110
pred = as_numpy(pred.squeeze(0).cpu())
111+
# print(dir(scores))
112+
# print(scores.shape)
101113

102114
# visualization
103115
self.visualize_result(
104116
(batch_data['img_ori'], batch_data['info']),
105117
pred,
106118
self.cfg
107119
)
108-
109120
pbar.update(1)
121+
122+
rospy.loginfo('Inference done in %.03f seconds.' %
123+
((rospy.get_rostime() - tic).to_sec()))
110124

111125
def image_callback(self, img):
112-
tic = rospy.get_rostime()
113-
rospy.loginfo("Processing image...")
114-
126+
if self.loader_q.full():
127+
return
128+
115129
try:
116130
cv_image = self.bridge.imgmsg_to_cv2(img, desired_encoding="rgb8")
117131
except CvBridgeError as e:
@@ -120,6 +134,7 @@ def image_callback(self, img):
120134
imgs = []
121135
# Need it in PIL?
122136
PILimg = Image.fromarray(cv_image)
137+
PILimg = PILimg.resize((480, 320))
123138
# In case we ever want to batch multiple images
124139
imgs.append(PILimg)
125140
img_list = [{'img': x} for x in imgs]
@@ -139,9 +154,7 @@ def image_callback(self, img):
139154

140155
# img_labels = self.segment(gpu)
141156

142-
self.run_inference(loader)
143-
rospy.loginfo('Inference done in %.03f seconds.' %
144-
((rospy.get_rostime() - tic).to_sec()))
157+
self.loader_q.put(loader)
145158

146159
def main(self):
147160

@@ -164,11 +177,15 @@ def main(self):
164177
self.segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
165178
self.segmentation_module.cuda()
166179

167-
self.seg_pub = rospy.Publisher(self.img_out, sensor_msgs.msg.Image, queue_size=10)
180+
self.seg_pub = rospy.Publisher(self.img_out, sensor_msgs.msg.Image, queue_size=1)
168181
rospy.Subscriber(self.img_in, sensor_msgs.msg.Image, self.image_callback)
169182

170183
rospy.loginfo("Listening for image messages on topic %s..." % self.img_in)
171184
rospy.loginfo("Publishing segmented images to topic %s..." % self.img_out)
185+
186+
while not rospy.is_shutdown():
187+
self.run_inference()
188+
172189
rospy.spin()
173190

174191

0 commit comments

Comments
 (0)