Skip to content

Commit a29f7f1

Browse files
authored
Merge pull request pymc-devs#1331 from pymc-devs/pep8
STY Ran autopep8 on full code-base.
2 parents d092dd8 + 7a0bdb5 commit a29f7f1

File tree

104 files changed

+1289
-894
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

104 files changed

+1289
-894
lines changed

pymc3/backends/base.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class BaseTrace(object):
2424
Sampling values will be stored for these variables. If None,
2525
`model.unobserved_RVs` is used.
2626
"""
27+
2728
def __init__(self, name, model=None, vars=None):
2829
self.name = name
2930

@@ -35,17 +36,16 @@ def __init__(self, name, model=None, vars=None):
3536
self.varnames = [var.name for var in vars]
3637
self.fn = model.fastfn(vars)
3738

38-
39-
## Get variable shapes. Most backends will need this
40-
## information.
39+
# Get variable shapes. Most backends will need this
40+
# information.
4141
var_values = list(zip(self.varnames, self.fn(model.test_point)))
4242
self.var_shapes = {var: value.shape
4343
for var, value in var_values}
4444
self.var_dtypes = {var: value.dtype
4545
for var, value in var_values}
4646
self.chain = None
4747

48-
## Sampling methods
48+
# Sampling methods
4949

5050
def setup(self, draws, chain):
5151
"""Perform chain-specific setup.
@@ -76,7 +76,7 @@ def close(self):
7676
"""
7777
pass
7878

79-
## Selection methods
79+
# Selection methods
8080

8181
def __getitem__(self, idx):
8282
if isinstance(idx, slice):
@@ -149,6 +149,7 @@ class MultiTrace(object):
149149
of the MultiTrace instance, which returns the number of draws), the
150150
trace with the highest chain number is always used.
151151
"""
152+
152153
def __init__(self, straces):
153154
self._straces = {}
154155
for strace in straces:

pymc3/backends/ndarray.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@ class NDArray(base.BaseTrace):
1919
Sampling values will be stored for these variables. If None,
2020
`model.unobserved_RVs` is used.
2121
"""
22+
2223
def __init__(self, name=None, model=None, vars=None):
2324
super(NDArray, self).__init__(name, model, vars)
2425
self.draw_idx = 0
2526
self.draws = None
2627
self.samples = {}
2728

28-
## Sampling methods
29+
# Sampling methods
2930

3031
def setup(self, draws, chain):
3132
"""Perform chain-specific setup.
@@ -70,12 +71,12 @@ def record(self, point):
7071
def close(self):
7172
if self.draw_idx == self.draws:
7273
return
73-
## Remove trailing zeros if interrupted before completed all
74-
## draws.
74+
# Remove trailing zeros if interrupted before completed all
75+
# draws.
7576
self.samples = {var: vtrace[:self.draw_idx]
7677
for var, vtrace in self.samples.items()}
7778

78-
## Selection methods
79+
# Selection methods
7980

8081
def __len__(self):
8182
if not self.samples: # `setup` has not been called.

pymc3/backends/sqlite.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
'WHERE chain = ?'),
3535
'draw_count': ('SELECT COUNT(*) FROM [{table}] '
3636
'WHERE chain = ?'),
37-
## Named placeholders are used in the selection templates because
38-
## some values occur more than once in the same template.
37+
# Named placeholders are used in the selection templates because
38+
# some values occur more than once in the same template.
3939
'select': ('SELECT * FROM [{table}] '
4040
'WHERE (chain = :chain)'),
4141
'select_burn': ('SELECT * FROM [{table}] '
@@ -71,6 +71,7 @@ class SQLite(base.BaseTrace):
7171
Sampling values will be stored for these variables. If None,
7272
`model.unobserved_RVs` is used.
7373
"""
74+
7475
def __init__(self, name, model=None, vars=None):
7576
super(SQLite, self).__init__(name, model, vars)
7677
self._var_cols = {}
@@ -80,13 +81,13 @@ def __init__(self, name, model=None, vars=None):
8081
self._len = None
8182

8283
self.db = _SQLiteDB(name)
83-
## Inserting sampling information is queued to avoid locks
84-
## caused by hitting the database with transactions each
85-
## iteration.
84+
# Inserting sampling information is queued to avoid locks
85+
# caused by hitting the database with transactions each
86+
# iteration.
8687
self._queue = {varname: [] for varname in self.varnames}
8788
self._queue_limit = 5000
8889

89-
## Sampling methods
90+
# Sampling methods
9091

9192
def setup(self, draws, chain):
9293
"""Perform chain-specific setup.
@@ -127,7 +128,7 @@ def _create_table(self):
127128
def _create_insert_queries(self, chain):
128129
template = TEMPLATES['insert']
129130
for varname, var_cols in self._var_cols.items():
130-
## Create insert statement for each variable.
131+
# Create insert statement for each variable.
131132
var_str = ', '.join(var_cols)
132133
placeholders = ', '.join(['?'] * len(var_cols))
133134
statement = template.format(table=varname,
@@ -164,7 +165,7 @@ def close(self):
164165
self._execute_queue()
165166
self.db.close()
166167

167-
## Selection methods
168+
# Selection methods
168169

169170
def __len__(self):
170171
if not self._is_setup:
@@ -252,6 +253,7 @@ def point(self, idx):
252253

253254

254255
class _SQLiteDB(object):
256+
255257
def __init__(self, name):
256258
self.name = name
257259
self.con = None
@@ -306,8 +308,8 @@ def load(name, model=None):
306308

307309
def _get_table_list(cursor):
308310
"""Return a list of table names in the current database."""
309-
## Modified from Django. Skips the sqlite_sequence system table used
310-
## for autoincrement key generation.
311+
# Modified from Django. Skips the sqlite_sequence system table used
312+
# for autoincrement key generation.
311313
cursor.execute("SELECT name FROM sqlite_master "
312314
"WHERE type='table' AND NOT name='sqlite_sequence' "
313315
"ORDER BY name")

pymc3/backends/text.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class Text(base.BaseTrace):
3737
Sampling values will be stored for these variables. If None,
3838
`model.unobserved_RVs` is used.
3939
"""
40+
4041
def __init__(self, name, model=None, vars=None):
4142
if not os.path.exists(name):
4243
os.mkdir(name)
@@ -49,7 +50,7 @@ def __init__(self, name, model=None, vars=None):
4950
self._fh = None
5051
self.df = None
5152

52-
## Sampling methods
53+
# Sampling methods
5354

5455
def setup(self, draws, chain):
5556
"""Perform chain-specific setup.
@@ -96,7 +97,7 @@ def close(self):
9697
self._fh.close()
9798
self._fh = None # Avoid serialization issue.
9899

99-
## Selection methods
100+
# Selection methods
100101

101102
def _load_df(self):
102103
if self.df is None:
@@ -194,5 +195,6 @@ def dump(name, trace, chains=None):
194195

195196
for chain in chains:
196197
filename = os.path.join(name, 'chain-{}.csv'.format(chain))
197-
df = ttab.trace_to_dataframe(trace, chains=chain, flat_names=flat_names)
198+
df = ttab.trace_to_dataframe(
199+
trace, chains=chain, flat_names=flat_names)
198200
df.to_csv(filename, index=False)

pymc3/blocking.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ class ArrayOrdering(object):
1717
"""
1818
An ordering for an array space
1919
"""
20+
2021
def __init__(self, vars):
2122
self.vmap = []
2223
dim = 0
@@ -33,6 +34,7 @@ class DictToArrayBijection(object):
3334
"""
3435
A mapping between a dict space and an array space
3536
"""
37+
3638
def __init__(self, ordering, dpoint):
3739
self.ordering = ordering
3840
self.dpt = dpoint
@@ -85,6 +87,7 @@ class DictToVarBijection(object):
8587
"""
8688
A mapping between a dict space and the array space for one element within the dict space
8789
"""
90+
8891
def __init__(self, var, idx, dpoint):
8992
self.var = str(var)
9093
self.idx = idx
@@ -111,6 +114,7 @@ class Compose(object):
111114
"""
112115
Compose two functions in a pickleable way
113116
"""
117+
114118
def __init__(self, fa, fb):
115119
self.fa = fa
116120
self.fb = fb

pymc3/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
__all__ = ['get_data_file']
55

6+
67
def get_data_file(pkg, path):
78
"""Returns a file object for a package data file.
8-
9+
910
Parameters
1011
----------
1112
pkg : str
@@ -18,4 +19,3 @@ def get_data_file(pkg, path):
1819
"""
1920

2021
return io.BytesIO(pkgutil.get_data(pkg, path))
21-

pymc3/diagnostics.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def geweke(x, first=.1, last=.5, intervals=20):
7575
last_start_idx = (1 - last) * end
7676

7777
# Calculate starting indices
78-
start_indices = np.arange(0, int(last_start_idx), step=int((last_start_idx) / (intervals - 1)))
78+
start_indices = np.arange(0, int(last_start_idx), step=int(
79+
(last_start_idx) / (intervals - 1)))
7980

8081
# Loop over start indices
8182
for start in start_indices:
@@ -151,9 +152,9 @@ def calc_rhat(x):
151152
W = np.mean(np.var(x, axis=1, ddof=1))
152153

153154
# Estimate of marginal posterior variance
154-
Vhat = W*(n - 1)/n + B/n
155+
Vhat = W * (n - 1) / n + B / n
155156

156-
return np.sqrt(Vhat/W)
157+
return np.sqrt(Vhat / W)
157158

158159
except ValueError:
159160

@@ -223,7 +224,7 @@ def calc_vhat(x):
223224
W = np.mean(np.var(x, axis=1, ddof=1))
224225

225226
# Estimate of marginal posterior variance
226-
Vhat = W*(n - 1)/n + B/n
227+
Vhat = W * (n - 1) / n + B / n
227228

228229
return Vhat
229230

@@ -243,21 +244,22 @@ def calc_n_eff(x):
243244

244245
Vhat = calc_vhat(x)
245246

246-
variogram = lambda t: (sum(sum((x[j][i] - x[j][i-t])**2
247-
for i in range(t,n)) for j in range(m)) / (m*(n - t)))
247+
variogram = lambda t: (sum(sum((x[j][i] - x[j][i - t])**2
248+
for i in range(t, n)) for j in range(m)) / (m * (n - t)))
248249

249250
rho = np.ones(n)
250-
# Iterate until the sum of consecutive estimates of autocorrelation is negative
251+
# Iterate until the sum of consecutive estimates of autocorrelation is
252+
# negative
251253
while not negative_autocorr and (t < n):
252254

253-
rho[t] = 1. - variogram(t)/(2.*Vhat)
255+
rho[t] = 1. - variogram(t) / (2. * Vhat)
254256

255257
if not t % 2:
256-
negative_autocorr = sum(rho[t-1:t+1]) < 0
258+
negative_autocorr = sum(rho[t - 1:t + 1]) < 0
257259

258260
t += 1
259261

260-
return int(m*n / (1. + 2*rho[1:t].sum()))
262+
return int(m * n / (1. + 2 * rho[1:t].sum()))
261263

262264
n_eff = {}
263265
for var in mtrace.varnames:

pymc3/distributions/__init__.py

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -62,53 +62,51 @@
6262
from .transforms import sum_to_1
6363

6464
__all__ = ['Uniform',
65-
'Flat',
66-
'Normal',
67-
'Beta',
68-
'Exponential',
69-
'Laplace',
70-
'StudentT',
71-
'Cauchy',
72-
'HalfCauchy',
73-
'Gamma',
74-
'Weibull',
75-
'Bound',
76-
'StudentTpos',
77-
'Lognormal',
78-
'ChiSquared',
79-
'HalfNormal',
80-
'Wald',
81-
'Pareto',
82-
'InverseGamma',
83-
'ExGaussian',
84-
'VonMises',
85-
'Binomial',
86-
'BetaBinomial',
87-
'Bernoulli',
88-
'Poisson',
89-
'NegativeBinomial',
90-
'ConstantDist',
91-
'ZeroInflatedPoisson',
92-
'ZeroInflatedNegativeBinomial',
93-
'DiscreteUniform',
94-
'Geometric',
95-
'Categorical',
96-
'DensityDist',
97-
'Distribution',
98-
'Continuous',
99-
'Discrete',
100-
'NoDistribution',
101-
'TensorType',
102-
'MvNormal',
103-
'MvStudentT',
104-
'Dirichlet',
105-
'Multinomial',
106-
'Wishart',
107-
'WishartBartlett',
108-
'LKJCorr',
109-
'AR1',
110-
'GaussianRandomWalk',
111-
'GARCH11'
112-
]
113-
114-
65+
'Flat',
66+
'Normal',
67+
'Beta',
68+
'Exponential',
69+
'Laplace',
70+
'StudentT',
71+
'Cauchy',
72+
'HalfCauchy',
73+
'Gamma',
74+
'Weibull',
75+
'Bound',
76+
'StudentTpos',
77+
'Lognormal',
78+
'ChiSquared',
79+
'HalfNormal',
80+
'Wald',
81+
'Pareto',
82+
'InverseGamma',
83+
'ExGaussian',
84+
'VonMises',
85+
'Binomial',
86+
'BetaBinomial',
87+
'Bernoulli',
88+
'Poisson',
89+
'NegativeBinomial',
90+
'ConstantDist',
91+
'ZeroInflatedPoisson',
92+
'ZeroInflatedNegativeBinomial',
93+
'DiscreteUniform',
94+
'Geometric',
95+
'Categorical',
96+
'DensityDist',
97+
'Distribution',
98+
'Continuous',
99+
'Discrete',
100+
'NoDistribution',
101+
'TensorType',
102+
'MvNormal',
103+
'MvStudentT',
104+
'Dirichlet',
105+
'Multinomial',
106+
'Wishart',
107+
'WishartBartlett',
108+
'LKJCorr',
109+
'AR1',
110+
'GaussianRandomWalk',
111+
'GARCH11'
112+
]

0 commit comments

Comments
 (0)