Skip to content

Commit 31dabcd

Browse files
jdownerneerajprad
authored andcommitted
Minor linting fixes (pyro-ppl#522)
* distributions: removed redundant conditional * tests: fixed lint warnings * tests: fixed lint warnings
1 parent 4616026 commit 31dabcd

File tree

4 files changed

+24
-24
lines changed

4 files changed

+24
-24
lines changed

pyro/distributions/bernoulli.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,16 @@ def batch_shape(self, x=None):
5050
event_dim = 1
5151
ps = self.ps
5252
if x is not None:
53-
if x is not None:
54-
if x.size()[-event_dim] != ps.size()[-event_dim]:
55-
raise ValueError("The event size for the data and distribution parameters must match.\n"
56-
"Expected x.size()[-1] == self.ps.size()[-1], but got {} vs {}".format(
57-
x.size(-1), ps.size(-1)))
58-
try:
59-
ps = self.ps.expand_as(x)
60-
except RuntimeError as e:
61-
raise ValueError("Parameter `ps` with shape {} is not broadcastable to "
62-
"the data shape {}. \nError: {}".format(ps.size(), x.size(), str(e)))
53+
if x.size()[-event_dim] != ps.size()[-event_dim]:
54+
raise ValueError("The event size for the data and distribution parameters must match.\n"
55+
"Expected x.size()[-1] == self.ps.size()[-1], but got {} vs {}".format(
56+
x.size(-1), ps.size(-1)))
57+
try:
58+
ps = self.ps.expand_as(x)
59+
except RuntimeError as e:
60+
raise ValueError("Parameter `ps` with shape {} is not broadcastable to "
61+
"the data shape {}. \nError: {}".format(ps.size(), x.size(), str(e)))
62+
6363
return ps.size()[:-event_dim]
6464

6565
def event_shape(self):

pyro/poutine/replay_poutine.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def __init__(self, fn, guide_trace, sites=None):
2222
# case 1: no sites
2323
if sites is None:
2424
self.sites = {site: site for site in guide_trace.nodes.keys()
25-
if guide_trace.nodes[site]["type"] == "sample"
26-
and not guide_trace.nodes[site]["is_observed"]}
25+
if guide_trace.nodes[site]["type"] == "sample" and
26+
not guide_trace.nodes[site]["is_observed"]}
2727
# case 2: sites is a list/tuple/set
2828
elif isinstance(sites, (list, tuple, set)):
2929
self.sites = {site: site for site in sites}

tests/infer/test_sampling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,8 @@ def test_complete(self):
8080
for tr, _ in posterior._traces():
8181
tr_latents.add(tuple([tr.nodes[name]["value"].view(-1).data[0]
8282
for name in tr.nodes.keys()
83-
if tr.nodes[name]["type"] == "sample"
84-
and not tr.nodes[name]["is_observed"]]))
83+
if tr.nodes[name]["type"] == "sample" and
84+
not tr.nodes[name]["is_observed"]]))
8585

8686
assert true_latents == tr_latents
8787

tests/poutine/test_poutines.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -238,8 +238,8 @@ def test_queue_enumerate(self):
238238
tr_latents = []
239239
for tr in trs:
240240
tr_latents.append(tuple([int(tr.nodes[name]["value"].view(-1).data[0]) for name in tr
241-
if tr.nodes[name]["type"] == "sample"
242-
and not tr.nodes[name]["is_observed"]]))
241+
if tr.nodes[name]["type"] == "sample" and
242+
not tr.nodes[name]["is_observed"]]))
243243

244244
assert true_latents == set(tr_latents)
245245

@@ -333,8 +333,8 @@ def test_prior_dict(self):
333333
if name in {'sigma1', 'mu1', 'sigma2', 'mu2'}:
334334
self.assertTrue(name + "_prior" == lifted_tr.nodes[name]['fn'].__name__)
335335
if tr.nodes[name]["type"] == "param":
336-
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample"
337-
and not lifted_tr.nodes[name]["is_observed"])
336+
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample" and
337+
not lifted_tr.nodes[name]["is_observed"])
338338

339339
def test_unlifted_param(self):
340340
tr = poutine.trace(self.guide).get_trace()
@@ -343,8 +343,8 @@ def test_unlifted_param(self):
343343
self.assertTrue(name in lifted_tr)
344344
if name in ('sigma1', 'mu1'):
345345
self.assertTrue(name + "_prior" == lifted_tr.nodes[name]['fn'].__name__)
346-
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample"
347-
and not lifted_tr.nodes[name]["is_observed"])
346+
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample" and
347+
not lifted_tr.nodes[name]["is_observed"])
348348
if name in ('sigma2', 'mu2'):
349349
self.assertTrue(lifted_tr.nodes[name]["type"] == "param")
350350

@@ -353,8 +353,8 @@ def test_random_module(self):
353353
lifted_tr = poutine.trace(pyro.random_module("name", self.model, prior=self.prior)).get_trace()
354354
for name in lifted_tr.nodes.keys():
355355
if lifted_tr.nodes[name]["type"] == "param":
356-
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample"
357-
and not lifted_tr.nodes[name]["is_observed"])
356+
self.assertTrue(lifted_tr.nodes[name]["type"] == "sample" and
357+
not lifted_tr.nodes[name]["is_observed"])
358358

359359
def test_random_module_prior_dict(self):
360360
pyro.clear_param_store()
@@ -366,8 +366,8 @@ def test_random_module_prior_dict(self):
366366
dist_name = name[3:]
367367
self.assertTrue(
368368
dist_name + "_prior" == lifted_tr.nodes[key_name]['fn'].__name__)
369-
self.assertTrue(lifted_tr.nodes[key_name]["type"] == "sample"
370-
and not lifted_tr.nodes[key_name]["is_observed"])
369+
self.assertTrue(lifted_tr.nodes[key_name]["type"] == "sample" and
370+
not lifted_tr.nodes[key_name]["is_observed"])
371371

372372

373373
class QueuePoutineMixedTest(TestCase):

0 commit comments

Comments
 (0)