@@ -282,13 +282,13 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
282282 https://arxiv.org/abs/2205.11487
283283 """
284284 dtype = sample .dtype
285- batch_size , channels , * remaining_dims = sample .shape
285+ batch_size , channels , height , width = sample .shape
286286
287287 if dtype not in (torch .float32 , torch .float64 ):
288288 sample = sample .float () # upcast for quantile calculation, and clamp not implemented for cpu half
289289
290290 # Flatten sample for doing quantile calculation along each image
291- sample = sample .reshape (batch_size , channels * np . prod ( remaining_dims ) )
291+ sample = sample .reshape (batch_size , channels * height * width )
292292
293293 abs_sample = sample .abs () # "a certain percentile absolute pixel value"
294294
@@ -300,7 +300,7 @@ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
300300 s = s .unsqueeze (1 ) # (batch_size, 1) because clamp will broadcast along dim=0
301301 sample = torch .clamp (sample , - s , s ) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
302302
303- sample = sample .reshape (batch_size , channels , * remaining_dims )
303+ sample = sample .reshape (batch_size , channels , height , width )
304304 sample = sample .to (dtype )
305305
306306 return sample
0 commit comments