Skip to content

Commit af3f9a5

Browse files
committed
Added support for latest pytorch lightning version
1 parent 8700d24 commit af3f9a5

31 files changed

+543
-305
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
data/
2+
Data/
33
logs/
44

55
VanillaVAE/version_0/

README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ $ cd PyTorch-VAE
3939
$ python run.py -c configs/<config-file-name.yaml>
4040
```
4141
**Config file template**
42+
4243
```yaml
4344
model_params:
4445
name: "<name of VAE model>"
@@ -48,10 +49,15 @@ model_params:
4849
.
4950
.
5051

51-
exp_params:
52+
data_params:
5253
data_path: "<path to the celebA dataset>"
53-
img_size: 64 # Models are designed to work for this size
54-
batch_size: 64 # Better to have a square number
54+
train_batch_size: 64 # Better to have a square number
55+
val_batch_size: 64
56+
patch_size: 64 # Models are designed to work for this size
57+
num_workers: 4
58+
59+
exp_params:
60+
manual_seed: 1265
5561
LR: 0.005
5662
weight_decay:
5763
. # Other arguments required for training, like scheduler etc.
@@ -60,7 +66,7 @@ exp_params:
6066

6167
trainer_params:
6268
gpus: 1
63-
max_nb_epochs: 50
69+
max_epochs: 100
6470
gradient_clip_val: 1.5
6571
.
6672
.
@@ -69,15 +75,17 @@ trainer_params:
6975
logging_params:
7076
save_dir: "logs/"
7177
name: "<experiment name>"
72-
manual_seed:
7378
```
7479
7580
**View TensorBoard Logs**
7681
```
7782
$ cd logs/<experiment name>/version_<the version you want>
78-
$ tensorboard --logdir tf
83+
$ tensorboard --logdir .
7984
```
8085

86+
**Note:** The default dataset is CelebA. However, there have been many issues with downloading the dataset from Google Drive (owing to some file structure changes). So, the recommendation is to download the [file](https://drive.google.com/file/d/1m8-EBPgi5MRubrm6iQjafK2QMHDBMSfJ/view?usp=sharing) from Google Drive directly and extract it to the path of your choice. The default path assumed in the config files is `Data/celeba/img_align_celeba`. But you can change it according to your preference.
87+
88+
8189
----
8290
<h2 align="center">
8391
<b>Results</b><br>

configs/bbvae.yaml

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@ model_params:
77
max_capacity: 25
88
Capacity_max_iter: 10000
99

10+
data_params:
11+
data_path: "Data/"
12+
train_batch_size: 64
13+
val_batch_size: 64
14+
patch_size: 64
15+
num_workers: 4
16+
1017
exp_params:
11-
dataset: celeba
12-
data_path: "../../shared/Data/"
13-
img_size: 64
14-
batch_size: 144 # Better to have a square number
15-
LR: 0.0005
18+
LR: 0.005
1619
weight_decay: 0.0
1720
scheduler_gamma: 0.95
21+
kld_weight: 0.00025
22+
manual_seed: 1265
1823

1924
trainer_params:
20-
gpus: 1
21-
max_nb_epochs: 50
22-
max_epochs: 50
25+
gpus: [1]
26+
max_epochs: 10
2327

2428
logging_params:
2529
save_dir: "logs/"
26-
name: "BetaVAE_B"
2730
manual_seed: 1265
31+
name: 'BetaVAE'

configs/betatc_vae.yaml

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,25 @@ model_params:
77
beta: 6.
88
gamma: 1.
99

10+
data_params:
11+
data_path: "Data/"
12+
train_batch_size: 64
13+
val_batch_size: 64
14+
patch_size: 64
15+
num_workers: 4
16+
17+
1018
exp_params:
11-
dataset: celeba
12-
data_path: "../../shared/momo/Data/"
13-
img_size: 64
14-
batch_size: 144 # Better to have a square number
15-
LR: 0.001
19+
LR: 0.005
1620
weight_decay: 0.0
17-
# scheduler_gamma: 0.99
21+
scheduler_gamma: 0.95
22+
kld_weight: 0.00025
23+
manual_seed: 1265
1824

1925
trainer_params:
20-
gpus: 1
21-
max_nb_epochs: 50
22-
max_epochs: 50
26+
gpus: [1]
27+
max_epochs: 10
2328

2429
logging_params:
2530
save_dir: "logs/"
26-
name: "BetaTCVAE"
27-
manual_seed: 1265
31+
name: 'BetaTCVAE'

configs/bhvae.yaml

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,21 +5,25 @@ model_params:
55
loss_type: 'H'
66
beta: 10.
77

8+
data_params:
9+
data_path: "Data/"
10+
train_batch_size: 64
11+
val_batch_size: 64
12+
patch_size: 64
13+
num_workers: 4
14+
15+
816
exp_params:
9-
dataset: celeba
10-
data_path: "../../shared/Data/"
11-
img_size: 64
12-
batch_size: 144 # Better to have a square number
13-
LR: 0.0005
17+
LR: 0.005
1418
weight_decay: 0.0
1519
scheduler_gamma: 0.95
20+
kld_weight: 0.00025
21+
manual_seed: 1265
1622

1723
trainer_params:
18-
gpus: 1
19-
max_nb_epochs: 50
20-
max_epochs: 50
24+
gpus: [1]
25+
max_epochs: 10
2126

2227
logging_params:
2328
save_dir: "logs/"
24-
name: "BetaVAE_H"
25-
manual_seed: 1265
29+
name: 'BetaVAE'

configs/cat_vae.yaml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,25 @@ model_params:
88
anneal_interval: 100
99
alpha: 1.0
1010

11+
data_params:
12+
data_path: "Data/"
13+
train_batch_size: 64
14+
val_batch_size: 64
15+
patch_size: 64
16+
num_workers: 4
17+
18+
1119
exp_params:
12-
dataset: celeba
13-
data_path: "../../shared/Data/"
14-
img_size: 64
15-
batch_size: 144 # Better to have a square number
1620
LR: 0.005
1721
weight_decay: 0.0
1822
scheduler_gamma: 0.95
23+
kld_weight: 0.00025
24+
manual_seed: 1265
1925

2026
trainer_params:
2127
gpus: [1]
22-
max_nb_epochs: 50
23-
max_epochs: 50
28+
max_epochs: 10
2429

2530
logging_params:
2631
save_dir: "logs/"
2732
name: "CategoricalVAE"
28-
manual_seed: 1265

configs/cvae.yaml

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,25 @@ model_params:
44
num_classes: 40
55
latent_dim: 128
66

7+
data_params:
8+
data_path: "Data/"
9+
train_batch_size: 64
10+
val_batch_size: 64
11+
patch_size: 64
12+
num_workers: 4
13+
14+
715
exp_params:
8-
dataset: celeba
9-
data_path: "../../shared/Data/"
10-
img_size: 64
11-
batch_size: 144 # Better to have a square number
1216
LR: 0.005
1317
weight_decay: 0.0
1418
scheduler_gamma: 0.95
19+
kld_weight: 0.00025
20+
manual_seed: 1265
1521

1622
trainer_params:
17-
gpus: 1
18-
max_nb_epochs: 50
19-
max_epochs: 50
23+
gpus: [1]
24+
max_epochs: 10
2025

2126
logging_params:
2227
save_dir: "logs/"
23-
name: "ConditionalVAE"
24-
manual_seed: 1265
28+
name: "ConditionalVAE"

configs/dfc_vae.yaml

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,25 @@ model_params:
33
in_channels: 3
44
latent_dim: 128
55

6+
data_params:
7+
data_path: "Data/"
8+
train_batch_size: 64
9+
val_batch_size: 64
10+
patch_size: 64
11+
num_workers: 4
12+
13+
614
exp_params:
7-
dataset: celeba
8-
data_path: "../../shared/Data/"
9-
img_size: 64
10-
batch_size: 144 # Better to have a square number
1115
LR: 0.005
1216
weight_decay: 0.0
1317
scheduler_gamma: 0.95
18+
kld_weight: 0.00025
19+
manual_seed: 1265
1420

1521
trainer_params:
16-
gpus: 1
17-
max_nb_epochs: 50
18-
max_epochs: 50
22+
gpus: [1]
23+
max_epochs: 10
1924

2025
logging_params:
2126
save_dir: "logs/"
2227
name: "DFCVAE"
23-
manual_seed: 1265

configs/dip_vae.yaml

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@ model_params:
66
lambda_offdiag: 0.1
77

88

9+
data_params:
10+
data_path: "Data/"
11+
train_batch_size: 64
12+
val_batch_size: 64
13+
patch_size: 64
14+
num_workers: 4
15+
16+
917
exp_params:
10-
dataset: celeba
11-
data_path: "../../shared/momo/Data/"
12-
img_size: 64
13-
batch_size: 144 # Better to have a square number
1418
LR: 0.001
1519
weight_decay: 0.0
1620
scheduler_gamma: 0.97
21+
kld_weight: 1
22+
manual_seed: 1265
1723

1824
trainer_params:
19-
gpus: 1
20-
max_nb_epochs: 50
21-
max_epochs: 50
25+
gpus: [1]
26+
max_epochs: 10
2227

2328
logging_params:
2429
save_dir: "logs/"

configs/factorvae.yaml

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,31 @@ model_params:
44
latent_dim: 128
55
gamma: 6.4
66

7+
data_params:
8+
data_path: "Data/"
9+
train_batch_size: 64
10+
val_batch_size: 64
11+
patch_size: 64
12+
num_workers: 4
13+
14+
715
exp_params:
8-
dataset: celeba
9-
data_path: "../../shared/Data/"
1016
submodel: 'discriminator'
1117
retain_first_backpass: True
12-
img_size: 64
13-
batch_size: 144 # Better to have a square number
1418
LR: 0.005
1519
weight_decay: 0.0
16-
scheduler_gamma: 0.95
1720
LR_2: 0.005
1821
scheduler_gamma_2: 0.95
22+
scheduler_gamma: 0.95
23+
kld_weight: 0.00025
24+
manual_seed: 1265
1925

2026
trainer_params:
21-
gpus: [3]
22-
max_nb_epochs: 30
23-
max_epochs: 50
27+
gpus: [1]
28+
max_epochs: 10
2429

2530
logging_params:
2631
save_dir: "logs/"
27-
name: "FactorVAE"
28-
manual_seed: 1265
32+
name: "FactorVAE"
33+
34+

0 commit comments

Comments
 (0)