@@ -238,8 +238,8 @@ def test_queue_enumerate(self):
238238        tr_latents  =  []
239239        for  tr  in  trs :
240240            tr_latents .append (tuple ([int (tr .nodes [name ]["value" ].view (- 1 ).data [0 ]) for  name  in  tr 
241-                                      if  tr .nodes [name ]["type" ] ==  "sample" 
242-                                      and   not  tr .nodes [name ]["is_observed" ]]))
241+                                      if  tr .nodes [name ]["type" ] ==  "sample"   and 
242+                                      not  tr .nodes [name ]["is_observed" ]]))
243243
244244        assert  true_latents  ==  set (tr_latents )
245245
@@ -333,8 +333,8 @@ def test_prior_dict(self):
333333            if  name  in  {'sigma1' , 'mu1' , 'sigma2' , 'mu2' }:
334334                self .assertTrue (name  +  "_prior"  ==  lifted_tr .nodes [name ]['fn' ].__name__ )
335335            if  tr .nodes [name ]["type" ] ==  "param" :
336-                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample" 
337-                                 and   not  lifted_tr .nodes [name ]["is_observed" ])
336+                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample"   and 
337+                                 not  lifted_tr .nodes [name ]["is_observed" ])
338338
339339    def  test_unlifted_param (self ):
340340        tr  =  poutine .trace (self .guide ).get_trace ()
@@ -343,8 +343,8 @@ def test_unlifted_param(self):
343343            self .assertTrue (name  in  lifted_tr )
344344            if  name  in  ('sigma1' , 'mu1' ):
345345                self .assertTrue (name  +  "_prior"  ==  lifted_tr .nodes [name ]['fn' ].__name__ )
346-                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample" 
347-                                 and   not  lifted_tr .nodes [name ]["is_observed" ])
346+                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample"   and 
347+                                 not  lifted_tr .nodes [name ]["is_observed" ])
348348            if  name  in  ('sigma2' , 'mu2' ):
349349                self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "param" )
350350
@@ -353,8 +353,8 @@ def test_random_module(self):
353353        lifted_tr  =  poutine .trace (pyro .random_module ("name" , self .model , prior = self .prior )).get_trace ()
354354        for  name  in  lifted_tr .nodes .keys ():
355355            if  lifted_tr .nodes [name ]["type" ] ==  "param" :
356-                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample" 
357-                                 and   not  lifted_tr .nodes [name ]["is_observed" ])
356+                 self .assertTrue (lifted_tr .nodes [name ]["type" ] ==  "sample"   and 
357+                                 not  lifted_tr .nodes [name ]["is_observed" ])
358358
359359    def  test_random_module_prior_dict (self ):
360360        pyro .clear_param_store ()
@@ -366,8 +366,8 @@ def test_random_module_prior_dict(self):
366366                dist_name  =  name [3 :]
367367                self .assertTrue (
368368                    dist_name  +  "_prior"  ==  lifted_tr .nodes [key_name ]['fn' ].__name__ )
369-                 self .assertTrue (lifted_tr .nodes [key_name ]["type" ] ==  "sample" 
370-                                 and   not  lifted_tr .nodes [key_name ]["is_observed" ])
369+                 self .assertTrue (lifted_tr .nodes [key_name ]["type" ] ==  "sample"   and 
370+                                 not  lifted_tr .nodes [key_name ]["is_observed" ])
371371
372372
373373class  QueuePoutineMixedTest (TestCase ):
0 commit comments