
Commit 5791f4a

[Type Hints] VAE models (huggingface#344)
* [Type Hints] VAE models
* apply suggestions from code review; also annotate the return types
1 parent 878af0e commit 5791f4a

File tree

1 file changed: +24, -22 lines changed

src/diffusers/models/vae.py

Lines changed: 24 additions & 22 deletions
@@ -1,3 +1,5 @@
+from typing import Optional, Tuple
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -293,7 +295,7 @@ def __init__(self, parameters, deterministic=False):
         if self.deterministic:
             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
 
-    def sample(self, generator=None):
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
         return x
 
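For reference, a minimal sketch (not part of the commit) of how the newly annotated `sample` signature reads at a call site. It assumes `DiagonalGaussianDistribution` is imported from this module and is constructed, as in this file, from a parameters tensor that concatenates mean and log-variance along the channel dimension; the shapes are illustrative.

```python
import torch
from diffusers.models.vae import DiagonalGaussianDistribution

# 8 channels = mean (4) and log-variance (4) concatenated along dim=1
parameters = torch.randn(1, 8, 32, 32)
posterior = DiagonalGaussianDistribution(parameters)

# Optional[torch.Generator] documents that a seeded generator may be passed
# for reproducible draws; leaving it as None falls back to the default RNG.
generator = torch.Generator().manual_seed(0)
latents = posterior.sample(generator=generator)
print(latents.shape)  # torch.Size([1, 4, 32, 32])
```

The `-> torch.FloatTensor` hint only documents what `sample` already returned; runtime behavior is unchanged.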
@@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=3,
-        sample_size=32,
-        num_vq_embeddings=256,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
     ):
         super().__init__()
 
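A quick sketch of the now-typed `VQModel` constructor called explicitly; the values are just the defaults from the signature above, and config registration via `@register_to_config` is untouched by this commit.

```python
from diffusers import VQModel

# Every keyword argument now carries a type hint; these are the defaults.
model = VQModel(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(64,),
    layers_per_block=1,
    act_fn="silu",
    latent_channels=3,
    sample_size=32,
    num_vq_embeddings=256,
)
```

One design note: `Tuple[str]` strictly describes a one-element tuple (the variadic spelling is `Tuple[str, ...]`), which happens to match the single-entry defaults used here.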
@@ -382,7 +384,7 @@ def decode(self, h, force_not_quantize=False):
         dec = self.decoder(quant)
         return dec
 
-    def forward(self, sample):
+    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         x = sample
         h = self.encode(x)
         dec = self.decode(h)
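And a sketch of the annotated `VQModel.forward`, assuming the default configuration. At this revision `forward` returns the decoded tensor directly, which is what the `-> torch.FloatTensor` hint records; later diffusers releases wrap model outputs in dataclasses.

```python
import torch
from diffusers import VQModel

model = VQModel()  # defaults from the typed signature above
sample = torch.randn(1, 3, 32, 32)  # FloatTensor in

# encode to the latent space, quantize against the codebook, decode back
with torch.no_grad():
    reconstruction = model(sample)  # FloatTensor out

print(reconstruction.shape)  # expected: torch.Size([1, 3, 32, 32])
```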
@@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=4,
-        sample_size=32,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 4,
+        sample_size: int = 32,
     ):
         super().__init__()
 
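The same pattern for `AutoencoderKL`, again just restating the defaults from the typed signature above; the only differences from `VQModel` are `latent_channels=4` and the absence of a VQ codebook.

```python
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    in_channels=3,
    out_channels=3,
    down_block_types=("DownEncoderBlock2D",),
    up_block_types=("UpDecoderBlock2D",),
    block_out_channels=(64,),
    layers_per_block=1,
    act_fn="silu",
    latent_channels=4,
    sample_size=32,
)
```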
@@ -440,7 +442,7 @@ def decode(self, z):
         dec = self.decoder(z)
         return dec
 
-    def forward(self, sample, sample_posterior=False):
+    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
         x = sample
         posterior = self.encode(x)
         if sample_posterior:
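Finally, a sketch of the annotated `AutoencoderKL.forward` under the default configuration. It relies on the behavior visible in this file: `encode` returns a `DiagonalGaussianDistribution`, and `sample_posterior=True` draws a latent from it before decoding (the default `False` branch uses the distribution's mode); output types differ in later diffusers versions.

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # defaults from the typed signature above
sample = torch.randn(1, 3, 32, 32)

# sample_posterior=True samples from the encoder's posterior before decoding;
# the default False decodes the posterior's mode instead.
with torch.no_grad():
    reconstruction = vae(sample, sample_posterior=True)

print(reconstruction.shape)  # expected: torch.Size([1, 3, 32, 32])
```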
