2222from jax import random
2323
2424from ..configuration_utils import ConfigMixin , register_to_config
25- from .scheduling_utils_flax import FlaxSchedulerMixin , FlaxSchedulerOutput
25+ from .scheduling_utils_flax import FlaxSchedulerMixin , FlaxSchedulerOutput , broadcast_to_shape_from_left
2626
2727
2828@flax .struct .dataclass
@@ -80,6 +80,10 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
8080 correct_steps (`int`): number of correction steps performed on a produced sample.
8181 """
8282
83+ @property
84+ def has_state (self ):
85+ return True
86+
8387 @register_to_config
8488 def __init__ (
8589 self ,
@@ -90,12 +94,20 @@ def __init__(
9094 sampling_eps : float = 1e-5 ,
9195 correct_steps : int = 1 ,
9296 ):
93- state = ScoreSdeVeSchedulerState . create ()
97+ pass
9498
95- self .state = self .set_sigmas (state , num_train_timesteps , sigma_min , sigma_max , sampling_eps )
99+ def create_state (self ):
100+ state = ScoreSdeVeSchedulerState .create ()
101+ return self .set_sigmas (
102+ state ,
103+ self .config .num_train_timesteps ,
104+ self .config .sigma_min ,
105+ self .config .sigma_max ,
106+ self .config .sampling_eps ,
107+ )
96108
97109 def set_timesteps (
98- self , state : ScoreSdeVeSchedulerState , num_inference_steps : int , shape : Tuple , sampling_eps : float = None
110+ self , state : ScoreSdeVeSchedulerState , num_inference_steps : int , shape : Tuple = () , sampling_eps : float = None
99111 ) -> ScoreSdeVeSchedulerState :
100112 """
101113 Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -193,8 +205,7 @@ def step_pred(
193205 # equation 6 in the paper: the model_output modeled by the network is grad_x log pt(x)
194206 # also equation 47 shows the analog from SDE models to ancestral sampling methods
195207 diffusion = diffusion .flatten ()
196- while len (diffusion .shape ) < len (sample .shape ):
197- diffusion = diffusion [:, None ]
208+ diffusion = broadcast_to_shape_from_left (diffusion , sample .shape )
198209 drift = drift - diffusion ** 2 * model_output
199210
200211 # equation 6: sample noise for the diffusion term of
@@ -252,8 +263,7 @@ def step_correct(
252263
253264 # compute corrected sample: model_output term and noise term
254265 step_size = step_size .flatten ()
255- while len (step_size .shape ) < len (sample .shape ):
256- step_size = step_size [:, None ]
266+ step_size = broadcast_to_shape_from_left (step_size , sample .shape )
257267 prev_sample_mean = sample + step_size * model_output
258268 prev_sample = prev_sample_mean + ((step_size * 2 ) ** 0.5 ) * noise
259269
0 commit comments