@@ -89,9 +89,11 @@ def __call__(
8989 step_generator : torch .Generator = None ,
9090 eta : float = 0 ,
9191 noise : torch .Tensor = None ,
92+ encoding : torch .Tensor = None ,
9293 return_dict = True ,
9394 ) -> Union [
94- Union [AudioPipelineOutput , ImagePipelineOutput ], Tuple [List [Image .Image ], Tuple [int , List [np .ndarray ]]]
95+ Union [AudioPipelineOutput , ImagePipelineOutput ],
96+ Tuple [List [Image .Image ], Tuple [int , List [np .ndarray ]]],
9597 ]:
9698 """Generate random mel spectrogram from audio input and convert to audio.
9799
@@ -108,6 +110,7 @@ def __call__(
108110 step_generator (`torch.Generator`): random number generator used to de-noise or None
109111 eta (`float`): parameter between 0 and 1 used with DDIM scheduler
110112 noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
113+ encoding (`torch.Tensor`): tensor of shape (batch_size, seq_length, cross_attention_dim) passed to the UNet when it is a UNet2DConditionModel, or None
111114 return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
112115
113116 Returns:
@@ -124,7 +127,12 @@ def __call__(
124127 self .mel .set_resolution (x_res = input_dims [1 ], y_res = input_dims [0 ])
125128 if noise is None :
126129 noise = torch .randn (
127- (batch_size , self .unet .in_channels , self .unet .sample_size [0 ], self .unet .sample_size [1 ]),
130+ (
131+ batch_size ,
132+ self .unet .in_channels ,
133+ self .unet .sample_size [0 ],
134+ self .unet .sample_size [1 ],
135+ ),
128136 generator = generator ,
129137 device = self .device ,
130138 )
@@ -157,15 +165,25 @@ def __call__(
157165 mask = self .scheduler .add_noise (input_images , noise , torch .tensor (self .scheduler .timesteps [start_step :]))
158166
159167 for step , t in enumerate (self .progress_bar (self .scheduler .timesteps [start_step :])):
160- model_output = self .unet (images , t )["sample" ]
168+ if isinstance (self .unet , UNet2DConditionModel ):
169+ model_output = self .unet (images , t , encoding )["sample" ]
170+ else :
171+ model_output = self .unet (images , t )["sample" ]
161172
162173 if isinstance (self .scheduler , DDIMScheduler ):
163174 images = self .scheduler .step (
164- model_output = model_output , timestep = t , sample = images , eta = eta , generator = step_generator
175+ model_output = model_output ,
176+ timestep = t ,
177+ sample = images ,
178+ eta = eta ,
179+ generator = step_generator ,
165180 )["prev_sample" ]
166181 else :
167182 images = self .scheduler .step (
168- model_output = model_output , timestep = t , sample = images , generator = step_generator
183+ model_output = model_output ,
184+ timestep = t ,
185+ sample = images ,
186+ generator = step_generator ,
169187 )["prev_sample" ]
170188
171189 if mask is not None :
0 commit comments