
Commit 4a7e4ce

Add conditional generation to AudioDiffusionPipeline (huggingface#1826)

* Add conditional generation
* Add fast test for conditional audio generation

1 parent f45c675 · commit 4a7e4ce

File tree

2 files changed: +52 −5 lines


src/diffusers/pipelines/audio_diffusion/pipeline_audio_diffusion.py

Lines changed: 23 additions & 5 deletions
@@ -89,9 +89,11 @@ def __call__(
         step_generator: torch.Generator = None,
         eta: float = 0,
         noise: torch.Tensor = None,
+        encoding: torch.Tensor = None,
         return_dict=True,
     ) -> Union[
-        Union[AudioPipelineOutput, ImagePipelineOutput], Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]]
+        Union[AudioPipelineOutput, ImagePipelineOutput],
+        Tuple[List[Image.Image], Tuple[int, List[np.ndarray]]],
     ]:
         """Generate random mel spectrogram from audio input and convert to audio.
 
@@ -108,6 +110,7 @@ def __call__(
             step_generator (`torch.Generator`): random number generator used to de-noise or None
             eta (`float`): parameter between 0 and 1 used with DDIM scheduler
             noise (`torch.Tensor`): noise tensor of shape (batch_size, 1, height, width) or None
+            encoding (`torch.Tensor`): for UNet2DConditionModel shape (batch_size, seq_length, cross_attention_dim)
             return_dict (`bool`): if True return AudioPipelineOutput, ImagePipelineOutput else Tuple
 
         Returns:
@@ -124,7 +127,12 @@ def __call__(
         self.mel.set_resolution(x_res=input_dims[1], y_res=input_dims[0])
         if noise is None:
             noise = torch.randn(
-                (batch_size, self.unet.in_channels, self.unet.sample_size[0], self.unet.sample_size[1]),
+                (
+                    batch_size,
+                    self.unet.in_channels,
+                    self.unet.sample_size[0],
+                    self.unet.sample_size[1],
+                ),
                 generator=generator,
                 device=self.device,
             )
@@ -157,15 +165,25 @@ def __call__(
             mask = self.scheduler.add_noise(input_images, noise, torch.tensor(self.scheduler.timesteps[start_step:]))
 
         for step, t in enumerate(self.progress_bar(self.scheduler.timesteps[start_step:])):
-            model_output = self.unet(images, t)["sample"]
+            if isinstance(self.unet, UNet2DConditionModel):
+                model_output = self.unet(images, t, encoding)["sample"]
+            else:
+                model_output = self.unet(images, t)["sample"]
 
             if isinstance(self.scheduler, DDIMScheduler):
                 images = self.scheduler.step(
-                    model_output=model_output, timestep=t, sample=images, eta=eta, generator=step_generator
+                    model_output=model_output,
+                    timestep=t,
+                    sample=images,
+                    eta=eta,
+                    generator=step_generator,
                 )["prev_sample"]
             else:
                 images = self.scheduler.step(
-                    model_output=model_output, timestep=t, sample=images, generator=step_generator
+                    model_output=model_output,
+                    timestep=t,
+                    sample=images,
+                    generator=step_generator,
                 )["prev_sample"]
 
             if mask is not None:
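For context on the conditional branch above: `encoding` is passed as the third positional argument of UNet2DConditionModel, i.e. its encoder hidden states for cross-attention. A self-contained sketch reusing the dummy configuration from the test file below (the timestep 0 is arbitrary):

import torch
from diffusers import UNet2DConditionModel

# Same dummy config as dummy_unet_condition in the test file below.
unet = UNet2DConditionModel(
    sample_size=(64, 32),
    in_channels=1,
    out_channels=1,
    layers_per_block=2,
    block_out_channels=(128, 128),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=10,
)
images = torch.randn(1, 1, 64, 32)  # (batch, in_channels, sample_size[0], sample_size[1])
encoding = torch.randn(1, 1, 10)    # (batch, seq_length, cross_attention_dim)
model_output = unet(images, 0, encoding)["sample"]  # same call as the conditional branch
print(model_output.shape)           # torch.Size([1, 1, 64, 32])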

tests/pipelines/audio_diffusion/test_audio_diffusion.py

Lines changed: 29 additions & 0 deletions
@@ -26,6 +26,7 @@
     DDPMScheduler,
     DiffusionPipeline,
     Mel,
+    UNet2DConditionModel,
     UNet2DModel,
 )
 from diffusers.utils import slow, torch_device
@@ -56,6 +57,21 @@ def dummy_unet(self):
         )
         return model
 
+    @property
+    def dummy_unet_condition(self):
+        torch.manual_seed(0)
+        model = UNet2DConditionModel(
+            sample_size=(64, 32),
+            in_channels=1,
+            out_channels=1,
+            layers_per_block=2,
+            block_out_channels=(128, 128),
+            down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
+            up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
+            cross_attention_dim=10,
+        )
+        return model
+
     @property
     def dummy_vqvae_and_unet(self):
         torch.manual_seed(0)
@@ -128,6 +144,19 @@ def test_audio_diffusion(self):
         expected_slice = np.array([120, 117, 110, 109, 138, 167, 138, 148, 132, 121])
         assert np.abs(image_slice.flatten() - expected_slice).max() == 0
 
+        dummy_unet_condition = self.dummy_unet_condition
+        pipe = AudioDiffusionPipeline(
+            vqvae=self.dummy_vqvae_and_unet[0], unet=dummy_unet_condition, mel=mel, scheduler=scheduler
+        )
+
+        np.random.seed(0)
+        encoding = torch.rand((1, 1, 10))
+        output = pipe(generator=generator, encoding=encoding)
+        image = output.images[0]
+        image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
+        expected_slice = np.array([120, 139, 147, 123, 124, 96, 115, 121, 126, 144])
+        assert np.abs(image_slice.flatten() - expected_slice).max() == 0
+
 
 @slow
 @require_torch_gpu
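The new test follows the file's existing fingerprint pattern: with fixed seeds, the first ten raw bytes of the rendered spectrogram are compared against a recorded slice, so any numerical drift in the conditional de-noising path shows up as a nonzero difference. A self-contained illustration of that check with stand-in values (not the test's real bytes):

import numpy as np
from PIL import Image

image = Image.new("L", (32, 64), color=120)       # stand-in for output.images[0]
image_slice = np.frombuffer(image.tobytes(), dtype="uint8")[:10]
expected_slice = np.full(10, 120, dtype="uint8")  # stand-in for the recorded bytes
assert np.abs(image_slice.astype(int) - expected_slice.astype(int)).max() == 0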
