@@ -1,3 +1,5 @@
+from typing import Optional, Tuple
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -293,7 +295,7 @@ def __init__(self, parameters, deterministic=False):
         if self.deterministic:
             self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
 
-    def sample(self, generator=None):
+    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
         x = self.mean + self.std * torch.randn(self.mean.shape, generator=generator, device=self.parameters.device)
         return x
 
@@ -327,16 +329,16 @@ class VQModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=3,
-        sample_size=32,
-        num_vq_embeddings=256,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
     ):
         super().__init__()
 
@@ -382,7 +384,7 @@ def decode(self, h, force_not_quantize=False):
         dec = self.decoder(quant)
         return dec
 
-    def forward(self, sample):
+    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
         x = sample
         h = self.encode(x)
         dec = self.decode(h)
@@ -393,15 +395,15 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        in_channels=3,
-        out_channels=3,
-        down_block_types=("DownEncoderBlock2D",),
-        up_block_types=("UpDecoderBlock2D",),
-        block_out_channels=(64,),
-        layers_per_block=1,
-        act_fn="silu",
-        latent_channels=4,
-        sample_size=32,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 4,
+        sample_size: int = 32,
     ):
         super().__init__()
 
@@ -440,7 +442,7 @@ def decode(self, z):
         dec = self.decoder(z)
         return dec
 
-    def forward(self, sample, sample_posterior=False):
+    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
         x = sample
         posterior = self.encode(x)
         if sample_posterior:
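
For context, a minimal usage sketch of the newly typed signatures (not part of the diff; it assumes this module is importable as diffusers.models.vae and that encode() returns the posterior distribution directly, as the forward hunk above suggests):

import torch
from diffusers.models.vae import AutoencoderKL  # assumed import path

# The annotated defaults (in_channels: int = 3, latent_channels: int = 4, ...) keep their original values.
vae = AutoencoderKL(in_channels=3, out_channels=3, latent_channels=4, sample_size=32)

x = torch.randn(1, 3, 32, 32)                   # forward() expects a torch.FloatTensor
reconstruction = vae(x, sample_posterior=True)  # -> torch.FloatTensor

# sample() now documents its optional torch.Generator argument for reproducible draws.
posterior = vae.encode(x)                       # DiagonalGaussianDistribution (assumed return type in this version)
generator = torch.Generator(device="cpu").manual_seed(0)
z = posterior.sample(generator=generator)       # -> torch.FloatTensor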