Skip to content

Commit ba3e7fe

Browse files
committed
Fixing additional GPU memory on device 0 due to discretization
1 parent 061d11d commit ba3e7fe

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

sgm/modules/diffusionmodules/discretizer.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import numpy as np
33
from functools import partial
4+
from abc import abstractmethod
45

56
from ...util import append_zero
67
from ...modules.diffusionmodules.util import make_beta_schedule
@@ -13,19 +14,23 @@ def generate_roughly_equally_spaced_steps(
1314

1415

1516
class Discretization:
16-
def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
17-
sigmas = self.get_sigmas(n, device)
17+
def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
18+
sigmas = self.get_sigmas(n, device=device)
1819
sigmas = append_zero(sigmas) if do_append_zero else sigmas
1920
return sigmas if not flip else torch.flip(sigmas, (0,))
2021

22+
@abstractmethod
23+
def get_sigmas(self, n, device):
24+
raise NotImplementedError("abstract class should not be called")
25+
2126

2227
class EDMDiscretization(Discretization):
2328
def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
2429
self.sigma_min = sigma_min
2530
self.sigma_max = sigma_max
2631
self.rho = rho
2732

28-
def get_sigmas(self, n, device):
33+
def get_sigmas(self, n, device="cpu"):
2934
ramp = torch.linspace(0, 1, n, device=device)
3035
min_inv_rho = self.sigma_min ** (1 / self.rho)
3136
max_inv_rho = self.sigma_max ** (1 / self.rho)
@@ -40,6 +45,7 @@ def __init__(
4045
linear_end=0.0120,
4146
num_timesteps=1000,
4247
):
48+
super().__init__()
4349
self.num_timesteps = num_timesteps
4450
betas = make_beta_schedule(
4551
"linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
@@ -48,7 +54,7 @@ def __init__(
4854
self.alphas_cumprod = np.cumprod(alphas, axis=0)
4955
self.to_torch = partial(torch.tensor, dtype=torch.float32)
5056

51-
def get_sigmas(self, n, device):
57+
def get_sigmas(self, n, device="cpu"):
5258
if n < self.num_timesteps:
5359
timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
5460
alphas_cumprod = self.alphas_cumprod[timesteps]

0 commit comments

Comments
 (0)