
Commit 240abdd

[Flax] added broadcast_to_shape_from_left helper and Scheduler tests (huggingface#864)
* added broadcast_to_shape_from_left helper
* initial tests
* fixed pndm tests
* shape required for pndm
* added require_flax
* fix style
* fix more imports

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 38ae5a2 commit 240abdd

File tree: 9 files changed (+931 −40 lines)

9 files changed

+931
-40
lines changed

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin
+    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
 else:
     from ..utils.dummy_flax_objects import * # noqa F403

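With this re-export in place, the helper and the output class become importable directly from the schedulers package (when Flax is available, per the `else` branch above). A minimal import sketch, assuming diffusers is installed with its Flax extras:

    # Minimal sketch: the names added to the package re-export by this commit.
    from diffusers.schedulers import (
        FlaxSchedulerMixin,
        FlaxSchedulerOutput,
        broadcast_to_shape_from_left,
    )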
src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 6 additions & 9 deletions
@@ -23,7 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -173,7 +173,9 @@ def _get_variance(self, timestep, prev_timestep, alphas_cumprod):

         return variance

-    def set_timesteps(self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDIMSchedulerState:
+    def set_timesteps(
+        self, state: DDIMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDIMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -211,9 +213,6 @@ def step(
             timestep (`int`): current discrete timestep in the diffusion chain.
             sample (`jnp.ndarray`):
                 current instance of sample being created by diffusion process.
-            key (`random.KeyArray`): a PRNG key.
-            eta (`float`): weight of noise for added noise in diffusion step.
-            use_clipped_model_output (`bool`): TODO
             return_dict (`bool`): option for returning tuple rather than FlaxDDIMSchedulerOutput class

         Returns:
@@ -279,13 +278,11 @@ def add_noise(
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[:, None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[:, None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

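The `add_noise` rewrite above is the pattern repeated throughout this commit: a hand-rolled `while` loop that appended trailing singleton axes is collapsed into one `broadcast_to_shape_from_left` call. A quick equivalence check with made-up shapes (a per-batch coefficient against a batch of images):

    import jax.numpy as jnp

    from diffusers.schedulers import broadcast_to_shape_from_left

    coeff = jnp.array([0.1, 0.2, 0.3, 0.4])   # one coefficient per batch element, shape (4,)
    samples = jnp.ones((4, 3, 32, 32))        # (batch, channels, height, width)

    # Old pattern: append trailing singleton axes until the ranks match.
    old = coeff
    while len(old.shape) < len(samples.shape):
        old = old[:, None]                    # final shape (4, 1, 1, 1)

    # New pattern: a single explicit broadcast to the target shape.
    new = broadcast_to_shape_from_left(coeff, samples.shape)   # shape (4, 3, 32, 32)

    # Both yield the same product under NumPy-style broadcasting.
    assert (old * samples == new * samples).all()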
src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 13 additions & 10 deletions
@@ -23,7 +23,7 @@
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray:
@@ -101,6 +101,10 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -129,11 +133,12 @@ def __init__(
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)
         self.one = jnp.array(1.0)

-        self.state = DDPMSchedulerState.create(num_train_timesteps=num_train_timesteps)
-
-        self.variance_type = variance_type
+    def create_state(self):
+        return DDPMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)

-    def set_timesteps(self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple) -> DDPMSchedulerState:
+    def set_timesteps(
+        self, state: DDPMSchedulerState, num_inference_steps: int, shape: Tuple = ()
+    ) -> DDPMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

@@ -214,7 +219,7 @@ def step(
         """
         t = timestep

-        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
+        if model_output.shape[1] == sample.shape[1] * 2 and self.config.variance_type in ["learned", "learned_range"]:
             model_output, predicted_variance = jnp.split(model_output, sample.shape[1], axis=1)
         else:
             predicted_variance = None
@@ -267,13 +272,11 @@ def add_noise(
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

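Beyond the broadcast change, the DDPM diff moves eager state construction out of `__init__`: schedulers now advertise `has_state`, build their functional state via `create_state()`, and read hyperparameters back from `self.config` (where `@register_to_config` stores them). A hedged sketch of the resulting caller-side flow, with illustrative argument values:

    from diffusers.schedulers import FlaxDDPMScheduler

    # Stateless-scheduler flow introduced by this commit (argument values are illustrative).
    scheduler = FlaxDDPMScheduler()
    if scheduler.has_state:
        state = scheduler.create_state()      # DDPMSchedulerState, built on demand
    state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 32, 32))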
src/diffusers/schedulers/scheduling_karras_ve_flax.py

Lines changed: 9 additions & 2 deletions
@@ -87,6 +87,10 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
             A reasonable range is [0.2, 80].
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -97,10 +101,13 @@ def __init__(
         s_min: float = 0.05,
         s_max: float = 50,
     ):
-        self.state = KarrasVeSchedulerState.create()
+        pass
+
+    def create_state(self):
+        return KarrasVeSchedulerState.create()

     def set_timesteps(
-        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: KarrasVeSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> KarrasVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

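The empty `__init__` body is deliberate: `@register_to_config` records the constructor arguments onto `self.config`, so nothing must happen eagerly and `create_state()` can build the state on demand. A short sketch, reusing the defaults visible in the signature above (and assuming the class is re-exported alongside the other Flax schedulers):

    from diffusers.schedulers import FlaxKarrasVeScheduler

    # The constructor only registers config; state is created lazily.
    scheduler = FlaxKarrasVeScheduler(s_min=0.05, s_max=50)
    state = scheduler.create_state()          # KarrasVeSchedulerState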
src/diffusers/schedulers/scheduling_lms_discrete_flax.py

Lines changed: 10 additions & 5 deletions
@@ -20,7 +20,7 @@
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -63,6 +63,10 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
             option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -85,8 +89,10 @@ def __init__(
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = jnp.cumprod(self.alphas, axis=0)

+    def create_state(self):
         self.state = LMSDiscreteSchedulerState.create(
-            num_train_timesteps=num_train_timesteps, sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+            num_train_timesteps=self.config.num_train_timesteps,
+            sigmas=((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5,
         )

     def get_lms_coefficient(self, state, order, t, current_order):
@@ -112,7 +118,7 @@ def lms_derivative(tau):
         return integrated_coeff

     def set_timesteps(
-        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple
+        self, state: LMSDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
     ) -> LMSDiscreteSchedulerState:
         """
         Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -199,8 +205,7 @@ def add_noise(
         timesteps: jnp.ndarray,
     ) -> jnp.ndarray:
         sigma = state.sigmas[timesteps].flatten()
-        while len(sigma.shape) < len(noise.shape):
-            sigma = sigma[..., None]
+        sigma = broadcast_to_shape_from_left(sigma, noise.shape)

         noisy_samples = original_samples + noise * sigma

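The sigmas handed to `LMSDiscreteSchedulerState.create` implement sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), the noise scale implied by the cumulative product of alphas. A small numeric sanity check (the linear beta schedule below is an assumption for illustration, not read from this diff):

    import jax.numpy as jnp

    betas = jnp.linspace(1e-4, 2e-2, 1000)    # hypothetical linear schedule
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
    # Monotonically increasing: little noise at t=0, heavy noise at t=999.
    assert sigmas[0] < sigmas[-1]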
src/diffusers/schedulers/scheduling_pndm_flax.py

Lines changed: 5 additions & 5 deletions
@@ -23,7 +23,7 @@
 import jax.numpy as jnp

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray:
@@ -168,6 +168,8 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int, sha
                 the `FlaxPNDMScheduler` state data class instance.
             num_inference_steps (`int`):
                 the number of diffusion steps used when generating samples with a pre-trained model.
+            shape (`Tuple`):
+                the shape of the samples to be generated.
         """
         offset = self.config.steps_offset

@@ -509,13 +511,11 @@ def add_noise(
     ) -> jnp.ndarray:
         sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5
         sqrt_alpha_prod = sqrt_alpha_prod.flatten()
-        while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_alpha_prod = sqrt_alpha_prod[..., None]
+        sqrt_alpha_prod = broadcast_to_shape_from_left(sqrt_alpha_prod, original_samples.shape)

         sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5
         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
-        while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
-            sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod[..., None]
+        sqrt_one_minus_alpha_prod = broadcast_to_shape_from_left(sqrt_one_minus_alpha_prod, original_samples.shape)

         noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
         return noisy_samples

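PNDM is the reason `shape` appears on `set_timesteps` at all (the commit message notes "shape required for pndm"): its state carries running buffers of past model outputs that must match the sample shape. A hedged usage sketch, assuming FlaxPNDMScheduler exposes the same `create_state()` pattern as the schedulers above:

    from diffusers.schedulers import FlaxPNDMScheduler

    # Sketch only: `shape` must describe the samples that will flow through step().
    scheduler = FlaxPNDMScheduler()
    state = scheduler.create_state()
    state = scheduler.set_timesteps(state, num_inference_steps=50, shape=(1, 3, 64, 64))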
src/diffusers/schedulers/scheduling_sde_ve_flax.py

Lines changed: 18 additions & 8 deletions
@@ -22,7 +22,7 @@
 from jax import random

 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput
+from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left


 @flax.struct.dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
         correct_steps (`int`): number of correction steps performed on a produced sample.
     """

+    @property
+    def has_state(self):
+        return True
+
     @register_to_config
     def __init__(
         self,
@@ -90,12 +94,20 @@ def __init__(
         sampling_eps: float = 1e-5,
         correct_steps: int = 1,
     ):
-        state = ScoreSdeVeSchedulerState.create()
+        pass

-        self.state = self.set_sigmas(state, num_train_timesteps, sigma_min, sigma_max, sampling_eps)
+    def create_state(self):
+        state = ScoreSdeVeSchedulerState.create()
+        return self.set_sigmas(
+            state,
+            self.config.num_train_timesteps,
+            self.config.sigma_min,
+            self.config.sigma_max,
+            self.config.sampling_eps,
+        )

     def set_timesteps(
-        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple, sampling_eps: float = None
+        self, state: ScoreSdeVeSchedulerState, num_inference_steps: int, shape: Tuple = (), sampling_eps: float = None
     ) -> ScoreSdeVeSchedulerState:
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -193,8 +205,7 @@ def step_pred(
         # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
         # also equation 47 shows the analog from SDE models to ancestral sampling methods
         diffusion = diffusion.flatten()
-        while len(diffusion.shape) < len(sample.shape):
-            diffusion = diffusion[:, None]
+        diffusion = broadcast_to_shape_from_left(diffusion, sample.shape)
         drift = drift - diffusion**2 * model_output

         # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ def step_correct(

         # compute corrected sample: model_output term and noise term
         step_size = step_size.flatten()
-        while len(step_size.shape) < len(sample.shape):
-            step_size = step_size[:, None]
+        step_size = broadcast_to_shape_from_left(step_size, sample.shape)
         prev_sample_mean = sample + step_size * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5) * noise

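Here `create_state()` does a bit more than in the other files: it immediately runs `set_sigmas` on the fresh state with values read back from `self.config`, so the returned state already carries the sigma schedule. A sketch with default constructor values:

    from diffusers.schedulers import FlaxScoreSdeVeScheduler

    scheduler = FlaxScoreSdeVeScheduler()
    state = scheduler.create_state()   # ScoreSdeVeSchedulerState with sigmas pre-populated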
src/diffusers/schedulers/scheduling_utils_flax.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
+from typing import Tuple

 import jax.numpy as jnp

@@ -41,3 +42,8 @@ class FlaxSchedulerMixin:
     """

     config_name = SCHEDULER_CONFIG_NAME
+
+
+def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray:
+    assert len(shape) >= x.ndim
+    return jnp.broadcast_to(x.reshape(x.shape + (1,) * (len(shape) - x.ndim)), shape)

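The helper right-pads `x` with singleton axes until its rank matches `len(shape)` and then broadcasts, so the existing axes of `x` align with the leftmost axes of `shape` (the opposite of NumPy's implicit left-padding). A quick demonstration with made-up shapes:

    import jax.numpy as jnp

    from diffusers.schedulers import broadcast_to_shape_from_left

    x = jnp.arange(3.0)                               # shape (3,)
    y = broadcast_to_shape_from_left(x, (3, 2, 4))    # x aligned with the leading axis
    assert y.shape == (3, 2, 4)
    assert (y[:, 0, 0] == x).all()                    # values replicated along trailing axes

    # Plain jnp.broadcast_to(x, (4, 2, 3)) would instead align x with the trailing
    # axis; the assert inside the helper rejects target shapes with rank < x.ndim.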