5
5
from ..math import logsumexp
6
6
from .dist_math import bound , random_choice
7
7
from .distribution import (Discrete , Distribution , draw_values ,
8
- generate_samples , _DrawValuesContext )
8
+ generate_samples , _DrawValuesContext ,
9
+ _DrawValuesContextBlocker , to_tuple )
9
10
from .continuous import get_tau_sigma , Normal
10
11
11
12
@@ -102,6 +103,35 @@ def __init__(self, w, comp_dists, *args, **kwargs):
102
103
103
104
super ().__init__ (shape , dtype , defaults = defaults , * args , ** kwargs )
104
105
106
@property
def comp_dists(self):
    """Return the component distribution(s) stored in ``_comp_dists``."""
    return self._comp_dists

@comp_dists.setter
def comp_dists(self, _comp_dists):
    """Store the component distribution(s) and probe their ``random``.

    After storing the value, every component (or the single component
    object when the value is not a list/tuple) is asked for a draw with
    ``size=23`` inside a ``_DrawValuesContextBlocker``.  The outcome is
    recorded in ``self._comp_dists_vect``: ``True`` when every probe
    call returned, ``False`` when any of them raised.
    """
    self._comp_dists = _comp_dists
    with _DrawValuesContextBlocker():
        stored = self.comp_dists
        # Treat a lone distribution as a one-element collection so that
        # both layouts share a single probing loop.
        if isinstance(stored, (list, tuple)):
            candidates = stored
        else:
            candidates = (stored,)
        try:
            for candidate in candidates:
                candidate.random(size=23)
        except Exception:
            # The components cannot call random with a non-None size, or
            # cannot do so without knowledge of the point, so calls to
            # random will have to be iterated to reach the requested size.
            vectorized = False
        else:
            vectorized = True
        self._comp_dists_vect = vectorized
134
+
105
135
def _comp_logp (self , value ):
106
136
comp_dists = self .comp_dists
107
137
@@ -131,13 +161,33 @@ def _comp_modes(self):
131
161
axis = 1 ))
132
162
133
163
def _comp_samples (self , point = None , size = None ):
134
- try :
135
- samples = self .comp_dists .random (point = point , size = size )
136
- except AttributeError :
137
- samples = np .column_stack ([comp_dist .random (point = point , size = size )
138
- for comp_dist in self .comp_dists ])
139
-
140
- return np .squeeze (samples )
164
+ if self ._comp_dists_vect or size is None :
165
+ try :
166
+ return self .comp_dists .random (point = point , size = size )
167
+ except AttributeError :
168
+ samples = np .array ([comp_dist .random (point = point , size = size )
169
+ for comp_dist in self .comp_dists ])
170
+ samples = np .moveaxis (samples , 0 , samples .ndim - 1 )
171
+ else :
172
+ # We must iterate the calls to random manually
173
+ size = to_tuple (size )
174
+ _size = int (np .prod (size ))
175
+ try :
176
+ samples = np .array ([self .comp_dists .random (point = point ,
177
+ size = None )
178
+ for _ in range (_size )])
179
+ samples = np .reshape (samples , size + samples .shape [1 :])
180
+ except AttributeError :
181
+ samples = np .array ([[comp_dist .random (point = point , size = None )
182
+ for _ in range (_size )]
183
+ for comp_dist in self .comp_dists ])
184
+ samples = np .moveaxis (samples , 0 , samples .ndim - 1 )
185
+ samples = np .reshape (samples , size + samples [1 :])
186
+
187
+ if samples .shape [- 1 ] == 1 :
188
+ return samples [..., 0 ]
189
+ else :
190
+ return samples
141
191
142
192
def logp (self , value ):
143
193
w = self .w
@@ -147,42 +197,99 @@ def logp(self, value):
147
197
broadcast_conditions = False )
148
198
149
199
def random(self, point=None, size=None):
    """Draw random samples from the mixture.

    Parameters
    ----------
    point : optional
        Forwarded unchanged to ``draw_values`` and ``_comp_samples``.
    size : optional
        Requested sample shape; normalised with ``to_tuple`` before use.

    Returns
    -------
    numpy.ndarray
        The selected component draws, reshaped to ``dist_shape`` when
        the effective size is None and to ``size + dist_shape``
        otherwise, where ``size`` is what remains after trailing axes
        equal to ``dist_shape`` have been trimmed from it.
    """
    # Convert size to tuple
    # NOTE(review): whether ``size`` can still be None after this call
    # depends on how to_tuple treats None -- confirm against its definition.
    size = to_tuple(size)
    # Draw the mixture weights and one sample from each component in order
    # to infer shapes
    with _DrawValuesContext() as draw_context:
        # We first need to check the shapes of w and comp_tmp and recompute
        # size
        w = draw_values([self.w], point=point, size=size)[0]
    with _DrawValuesContextBlocker():
        # We don't want to store the values drawn here in the context
        # because they won't have the correct size
        comp_tmp = self._comp_samples(point=point, size=None)

    # When size is not None, it's hard to tell the shape of the w parameter:
    # if w's leading axes equal size, they are taken to be sample axes
    if size is not None and w.shape[:len(size)] == size:
        w_shape = w.shape[len(size):]
    else:
        w_shape = w.shape

    # Try to determine the parameter shape and dist_shape
    param_shape = np.broadcast(np.empty(w_shape),
                               comp_tmp).shape
    if np.asarray(self.shape).size != 0:
        dist_shape = np.broadcast(np.empty(self.shape),
                                  np.empty(param_shape[:-1])).shape
    else:
        dist_shape = param_shape[:-1]

    # When size is not None, dist_shape may overlap the trailing axes of
    # size; trim the overlap so it is not counted twice
    if size is not None:
        if size == dist_shape:
            size = None
        elif size[-len(dist_shape):] == dist_shape:
            size = size[:len(size) - len(dist_shape)]

    # We get an integer _size instead of a tuple size for drawing the
    # mixture, then we just reshape the output
    if size is None:
        _size = None
    else:
        _size = int(np.prod(size))

    # Now we must broadcast w to the shape that considers size, dist_shape
    # and param_shape. However, we must take care with the cases in which
    # dist_shape and param_shape overlap
    if size is not None and w.shape[:len(size)] == size:
        if w.shape[:len(size + dist_shape)] != (size + dist_shape):
            # To allow w to broadcast, we insert new axes in between the
            # "size" axes and the "mixture" axis
            _w = w[(slice(None),) * len(size) +  # Index the size axes
                   (np.newaxis,) * len(dist_shape) +  # Add new axes for the dist_shape
                   (slice(None),)]  # Close with the slice of mixture components
            w = np.broadcast_to(_w, size + dist_shape + (param_shape[-1],))
    elif size is not None:
        w = np.broadcast_to(w, size + dist_shape + (param_shape[-1],))
    else:
        w = np.broadcast_to(w, dist_shape + (param_shape[-1],))

    # Compute the total number of elements needed from the components
    if _size is not None:
        output_size = int(_size * np.prod(dist_shape) * param_shape[-1])
    else:
        output_size = int(np.prod(dist_shape) * param_shape[-1])
    # Get the size we need for the call to _comp_samples: the number of
    # comp_tmp-sized draws that add up to output_size
    mixture_size = int(output_size // np.prod(comp_tmp.shape))
    if mixture_size == 1 and _size is None:
        mixture_size = None

    # Semiflatten the mixture weights. The last axis is the number of
    # mixture components, and the rest is all about size, dist_shape and
    # broadcasting
    w = np.reshape(w, (-1, w.shape[-1]))
    # Normalize mixture weights
    w = w / w.sum(axis=-1, keepdims=True)

    w_samples = generate_samples(random_choice,
                                 p=w,
                                 broadcast_shape=w.shape[:-1] or (1,),
                                 dist_shape=w.shape[:-1] or (1,),
                                 size=size)
    # Sample from the mixture components, re-entering the draw context
    # opened at the top of this method
    with draw_context:
        mixed_samples = self._comp_samples(point=point,
                                           size=mixture_size)
    w_samples = w_samples.flatten()
    # Semiflatten the component draws to be able to zip them with w_samples
    mixed_samples = np.reshape(mixed_samples, (-1, comp_tmp.shape[-1]))
    # Select, row by row, the draw of the chosen component
    samples = np.array([mixed[choice] for choice, mixed in
                        zip(w_samples, mixed_samples)])
    # Reshape the samples to the correct output shape
    if size is None:
        samples = np.reshape(samples, dist_shape)
    else:
        samples = np.reshape(samples, size + dist_shape)
    return samples
187
294
188
295
0 commit comments