@@ -20,40 +20,46 @@ def get_a_var(obj):
2020 return None
2121
2222
23- def parallel_apply (modules , inputs , kwargs_tup = None ):
23+ def parallel_apply (modules , inputs , kwargs_tup = None , devices = None ):
2424 assert len (modules ) == len (inputs )
25- if kwargs_tup :
25+ if kwargs_tup is not None :
2626 assert len (modules ) == len (kwargs_tup )
2727 else :
2828 kwargs_tup = ({},) * len (modules )
29- # Fast track
30- if len (modules ) == 1 :
31- return (modules [0 ](* inputs [0 ], ** kwargs_tup [0 ]), )
29+ if devices is not None :
30+ assert len (modules ) == len (devices )
31+ else :
32+ devices = [None ] * len (modules )
3233
3334 lock = threading .Lock ()
3435 results = {}
3536
36- def _worker (i , module , input , kwargs , results , lock ):
37- var_input = get_a_var (input )
37+ def _worker (i , module , input , kwargs , results , lock , device = None ):
38+ if device is None :
39+ device = get_a_var (input ).get_device ()
3840 try :
39- with torch .cuda .device_of ( var_input ):
41+ with torch .cuda .device ( device ):
4042 output = module (* input , ** kwargs )
4143 with lock :
4244 results [i ] = output
4345 except Exception as e :
4446 with lock :
4547 results [i ] = e
4648
47- threads = [threading .Thread (target = _worker ,
48- args = (i , module , input , kwargs , results , lock ),
49- )
50- for i , (module , input , kwargs ) in
51- enumerate (zip (modules , inputs , kwargs_tup ))]
49+ if len (modules ) > 1 :
50+ threads = [threading .Thread (target = _worker ,
51+ args = (i , module , input , kwargs , results , lock , device ),
52+ )
53+ for i , (module , input , kwargs , device ) in
54+ enumerate (zip (modules , inputs , kwargs_tup , devices ))]
55+
56+ for thread in threads :
57+ thread .start ()
58+ for thread in threads :
59+ thread .join ()
60+ else :
61+ _worker (0 , modules [0 ], inputs [0 ], kwargs_tup [0 ], results , lock , devices [0 ])
5262
53- for thread in threads :
54- thread .start ()
55- for thread in threads :
56- thread .join ()
5763 outputs = []
5864 for i in range (len (inputs )):
5965 output = results [i ]
0 commit comments