44import os
55import argparse
66from distutils .version import LooseVersion
7+ import queue
8+ import threading
79# Numerical libs
810import numpy as np
911import torch
@@ -41,6 +43,7 @@ def __init__(self, cfg, gpu, img_in, img_out):
4143 self .img_in = img_in
4244 self .img_out = img_out
4345 self .bridge = CvBridge ()
46+ self .loader_q = queue .Queue (1 )
4447
4548
4649 def visualize_result (self , data , pred , cfg ):
@@ -70,7 +73,14 @@ def visualize_result(self, data, pred, cfg):
7073 #Image.fromarray(im_vis).save(
7174 # os.path.join(cfg.TEST.result, img_name.replace('.jpg', '.png')))
7275
73- def run_inference (self , loader ):
76+ def run_inference (self ):
77+ rospy .loginfo ("Waiting for loader from queue..." )
78+ while self .loader_q .empty ():
79+ rospy .sleep (0.01 )
80+ tic = rospy .get_rostime ()
81+ rospy .loginfo ("Processing image..." )
82+ loader = self .loader_q .get ()
83+
7484 self .segmentation_module .eval ()
7585
7686 pbar = tqdm (total = len (loader ))
@@ -98,20 +108,24 @@ def run_inference(self, loader):
98108
99109 _ , pred = torch .max (scores , dim = 1 )
100110 pred = as_numpy (pred .squeeze (0 ).cpu ())
111+ # print(dir(scores))
112+ # print(scores.shape)
101113
102114 # visualization
103115 self .visualize_result (
104116 (batch_data ['img_ori' ], batch_data ['info' ]),
105117 pred ,
106118 self .cfg
107119 )
108-
109120 pbar .update (1 )
121+
122+ rospy .loginfo ('Inference done in %.03f seconds.' %
123+ ((rospy .get_rostime () - tic ).to_sec ()))
110124
111125 def image_callback (self , img ):
112- tic = rospy . get_rostime ()
113- rospy . loginfo ( "Processing image..." )
114-
126+ if self . loader_q . full ():
127+ return
128+
115129 try :
116130 cv_image = self .bridge .imgmsg_to_cv2 (img , desired_encoding = "rgb8" )
117131 except CvBridgeError as e :
@@ -120,6 +134,7 @@ def image_callback(self, img):
120134 imgs = []
121135 # Need it in PIL?
122136 PILimg = Image .fromarray (cv_image )
137+ PILimg = PILimg .resize ((480 , 320 ))
123138 # In case we ever want to batch multiple images
124139 imgs .append (PILimg )
125140 img_list = [{'img' : x } for x in imgs ]
@@ -139,9 +154,7 @@ def image_callback(self, img):
139154
140155 # img_labels = self.segment(gpu)
141156
142- self .run_inference (loader )
143- rospy .loginfo ('Inference done in %.03f seconds.' %
144- ((rospy .get_rostime () - tic ).to_sec ()))
157+ self .loader_q .put (loader )
145158
146159 def main (self ):
147160
@@ -164,11 +177,15 @@ def main(self):
164177 self .segmentation_module = SegmentationModule (net_encoder , net_decoder , crit )
165178 self .segmentation_module .cuda ()
166179
167- self .seg_pub = rospy .Publisher (self .img_out , sensor_msgs .msg .Image , queue_size = 10 )
180+ self .seg_pub = rospy .Publisher (self .img_out , sensor_msgs .msg .Image , queue_size = 1 )
168181 rospy .Subscriber (self .img_in , sensor_msgs .msg .Image , self .image_callback )
169182
170183 rospy .loginfo ("Listening for image messages on topic %s..." % self .img_in )
171184 rospy .loginfo ("Publishing segmented images to topic %s..." % self .img_out )
185+
186+ while not rospy .is_shutdown ():
187+ self .run_inference ()
188+
172189 rospy .spin ()
173190
174191
0 commit comments