1
1
import torch
2
2
import numpy as np
3
3
from functools import partial
4
+ from abc import abstractmethod
4
5
5
6
from ...util import append_zero
6
7
from ...modules .diffusionmodules .util import make_beta_schedule
@@ -13,19 +14,23 @@ def generate_roughly_equally_spaced_steps(
13
14
14
15
15
16
class Discretization:
    """Base class for noise-level (sigma) schedules.

    Subclasses provide ``get_sigmas``; calling an instance yields the
    schedule, optionally with a trailing zero appended and/or the order
    reversed.
    """

    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        """Return ``n`` sigmas from ``get_sigmas`` on ``device``.

        With ``do_append_zero`` a final zero sigma is appended; with
        ``flip`` the resulting tensor is reversed along dim 0.
        """
        schedule = self.get_sigmas(n, device=device)
        if do_append_zero:
            schedule = append_zero(schedule)
        if flip:
            return torch.flip(schedule, (0,))
        return schedule

    @abstractmethod
    def get_sigmas(self, n, device):
        # Interface hook: concrete subclasses must override this.
        raise NotImplementedError("abstract class should not be called")
21
26
22
27
class EDMDiscretization (Discretization ):
23
28
    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        """Store the EDM schedule parameters.

        sigma_min / sigma_max bound the noise levels of the schedule;
        rho is used as the 1/rho exponent when interpolating between
        them in ``get_sigmas`` (presumably the Karras et al. rho
        curvature parameter — confirm against the full get_sigmas body).
        """
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho
27
32
28
- def get_sigmas (self , n , device ):
33
+ def get_sigmas (self , n , device = "cpu" ):
29
34
ramp = torch .linspace (0 , 1 , n , device = device )
30
35
min_inv_rho = self .sigma_min ** (1 / self .rho )
31
36
max_inv_rho = self .sigma_max ** (1 / self .rho )
@@ -40,6 +45,7 @@ def __init__(
40
45
linear_end = 0.0120 ,
41
46
num_timesteps = 1000 ,
42
47
):
48
+ super ().__init__ ()
43
49
self .num_timesteps = num_timesteps
44
50
betas = make_beta_schedule (
45
51
"linear" , num_timesteps , linear_start = linear_start , linear_end = linear_end
@@ -48,7 +54,7 @@ def __init__(
48
54
self .alphas_cumprod = np .cumprod (alphas , axis = 0 )
49
55
self .to_torch = partial (torch .tensor , dtype = torch .float32 )
50
56
51
- def get_sigmas (self , n , device ):
57
+ def get_sigmas (self , n , device = "cpu" ):
52
58
if n < self .num_timesteps :
53
59
timesteps = generate_roughly_equally_spaced_steps (n , self .num_timesteps )
54
60
alphas_cumprod = self .alphas_cumprod [timesteps ]
0 commit comments