
Commit 6f6d3f8

Replace most print()s with logging calls (Stability-AI#42)
1 parent 6ecd0a9 commit 6f6d3f8

10 files changed: +118 -92 lines

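Because these messages now go through the standard logging module instead of print(), they are silent until the host application configures logging. Below is a minimal sketch of how a training script might surface them; the logger names follow the logging.getLogger(__name__) convention used throughout this commit, and the chosen levels are only an example:

import logging

# Send records to stderr with a simple format. INFO shows the checkpoint and
# EMA messages below; DEBUG would also show "Preparing datasets" and
# "Setting up LambdaLR scheduler...".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# Per-module control is the main gain over print(), e.g. verbose checkpoint
# loading in one module while silencing the LR-scheduler step messages:
logging.getLogger("sgm.models.diffusion").setLevel(logging.DEBUG)
logging.getLogger("sgm.lr_scheduler").setLevel(logging.WARNING)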

sgm/data/dataset.py
Lines changed: 13 additions & 13 deletions

@@ -1,20 +1,22 @@
+import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
 import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
+logger = logging.getLogger(__name__)
+
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    print("#" * 100)
-    print("Datasets not yet available")
-    print("to enable, we need to add stable-datasets as a submodule")
-    print("please use ``git submodule update --init --recursive``")
-    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
-    print("#" * 100)
-    exit(1)
+    raise NotImplementedError(
+        "Datasets not yet available. "
+        "To enable, we need to add stable-datasets as a submodule; "
+        "please use ``git submodule update --init --recursive`` "
+        "and do ``pip install -e stable-datasets/`` from the root of this repo"
+    ) from e
 
 
 class StableDataModuleFromConfig(LightningDataModule):

@@ -39,8 +41,8 @@ def __init__(
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            print(
-                "Warning: No Validation datapipeline defined, using that one from training"
+            logger.warning(
+                "No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train
 

@@ -52,12 +54,10 @@ def __init__(
 
         self.dummy = dummy
         if self.dummy:
-            print("#" * 100)
-            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
-            print("#" * 100)
+            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
 
     def setup(self, stage: str) -> None:
-        print("Preparing datasets")
+        logger.debug("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
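Beyond logging, the first hunk above replaces a print()/exit(1) pair with raise NotImplementedError(...) from e, which chains the original ImportError and leaves termination to the caller. A standalone sketch of the pattern, assuming only that sdata is provided by the stable-datasets submodule as the message says; the wrapper function itself is illustrative and not part of the repo:

def load_dataset_backend():
    # Mirrors the try/except shape used at module import time in sgm/data/dataset.py.
    try:
        from sdata import create_dataset
    except ImportError as e:
        # "from e" keeps the ImportError as __cause__, so the traceback shows
        # both the missing module and the actionable setup instructions.
        raise NotImplementedError(
            "Datasets not yet available. "
            "Run ``git submodule update --init --recursive`` and "
            "``pip install -e stable-datasets/`` from the root of this repo."
        ) from e
    return create_dataset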

sgm/lr_scheduler.py
Lines changed: 16 additions & 15 deletions

@@ -1,5 +1,9 @@
+import logging
+
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 class LambdaWarmUpCosineScheduler:
     """

@@ -24,9 +28,8 @@ def __init__(
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start

@@ -83,12 +86,11 @@ def find_in_interval(self, n):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle

@@ -114,12 +116,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
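A side note on the new scheduler logging: keeping f-strings inside logger.info(...) means the message is formatted on every call to schedule(), even when INFO is filtered out. That is harmless here, but the logging module also supports deferred %-style formatting, shown below as a hedged alternative rather than what the commit does:

import logging

logger = logging.getLogger(__name__)

n, last_lr = 1000, 0.75  # stand-in values for the scheduler state

# Eager: the f-string is built even if no handler will emit the record.
logger.info(f"current step: {n}, recent lr-multiplier: {last_lr}")

# Deferred: arguments are interpolated only if the record is actually emitted.
logger.info("current step: %d, recent lr-multiplier: %s", n, last_lr)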

sgm/models/autoencoder.py
Lines changed: 11 additions & 8 deletions

@@ -1,3 +1,4 @@
+import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager

@@ -14,6 +15,8 @@
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
+logger = logging.getLogger(__name__)
+
 
 class AbstractAutoencoder(pl.LightningModule):
     """

@@ -38,7 +41,7 @@ def __init__(
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

@@ -60,16 +63,16 @@ def init_from_ckpt(
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    print("Deleting key {} from state_dict.".format(k))
+                    logger.debug(f"Deleting key {k} from state_dict.")
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.debug(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:

@@ -86,14 +89,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:

@@ -104,7 +107,7 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        print(f"loading >>> {cfg['target']} <<< optimizer from config")
+        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
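The ema_scope hunks in this file and in sgm/models/diffusion.py share one shape: swap in EMA weights, log the switch, and restore in a finally block so the training weights come back even if the body raises. A condensed, self-contained sketch of that shape; EmaLike is a toy stand-in for LitEma, which is not reproduced here, and the always-restore behavior is a simplification:

import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class EmaLike:
    # Toy stand-in exposing LitEma's store/copy_to/restore-style interface.
    def __init__(self):
        self._backup = None

    def store(self, params):
        self._backup = list(params)

    def copy_to(self, target):
        pass  # would overwrite the target's parameters with the EMA values

    def restore(self, params):
        self._backup = None  # would write the stored parameters back


@contextmanager
def ema_scope(model_ema, params, context=None):
    model_ema.store(params)
    model_ema.copy_to(params)
    if context is not None:
        logger.info(f"{context}: Switched to EMA weights")
    try:
        yield
    finally:
        model_ema.restore(params)
        if context is not None:
            logger.info(f"{context}: Restored training weights")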

sgm/models/diffusion.py
Lines changed: 11 additions & 8 deletions

@@ -1,3 +1,4 @@
+import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 

@@ -18,6 +19,8 @@
     log_txt_as_img,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(

@@ -73,7 +76,7 @@ def __init__(
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast

@@ -94,13 +97,13 @@ def init_from_ckpt(
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.info(
            f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
        )
        if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
        if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()

@@ -179,14 +182,14 @@ def ema_scope(self, context=None):
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(

@@ -202,7 +205,7 @@ def configure_optimizers(self):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            print("Setting up LambdaLR scheduler...")
+            logger.debug("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),

@@ -304,7 +307,7 @@ def log_images(
         log["inputs"] = x
         z = self.encode_first_stage(x)
         log["reconstructions"] = self.decode_first_stage(z)
-        log.update(self.log_conditionings(batch, N))
+        logger.update(self.log_conditionings(batch, N))
 
         for k in c:
             if isinstance(c[k], torch.Tensor):
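The final hunk here looks like an over-eager rename rather than a logging change: log.update(self.log_conditionings(batch, N)) becomes logger.update(...), but log is the dict of images being assembled in log_images() and logging.Logger has no update() method, so the new line would raise AttributeError and drop the conditioning visualizations. A minimal check of that observation, with placeholder dict contents:

import logging

logger = logging.getLogger(__name__)

log = {"inputs": "...", "reconstructions": "..."}  # placeholder image dict
log.update({"conditioning": "..."})  # dict.update is what the original line did

# A logging.Logger exposes no update() method, so the renamed call cannot work.
assert not hasattr(logger, "update")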

sgm/modules/attention.py
Lines changed: 21 additions & 17 deletions

@@ -1,3 +1,4 @@
+import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional

@@ -8,6 +9,10 @@
 from packaging import version
 from torch import nn
 
+
+logger = logging.getLogger(__name__)
+
+
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel

@@ -36,9 +41,9 @@
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    print(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
-        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    logger.warning(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
+        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:

@@ -48,7 +53,7 @@
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
+    logger.debug("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 

@@ -289,7 +294,7 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        print(
+        logger.info(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )

@@ -393,22 +398,21 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            print(
+            logger.warning(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            print(
+            logger.warning(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                assert (
-                    False
-                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-            else:
-                print("Falling back to xformers efficient attention.")
-                attn_mode = "softmax-xformers"
+                raise NotImplementedError(
+                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+                )
+            logger.info("Falling back to xformers efficient attention.")
+            attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)

@@ -437,7 +441,7 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            print(f"{self.__class__.__name__} is using checkpointing")
+            logger.info(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0

@@ -554,7 +558,7 @@ def __init__(
         sdp_backend=None,
     ):
         super().__init__()
-        print(
+        logger.debug(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig

@@ -563,8 +567,8 @@ def __init__(
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                print(
-                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                logger.warning(
+                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
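The xformers hunks also show the feature-detection idiom this module relies on: probe the optional backend at import time, record a flag, and log the outcome instead of printing it. A self-contained sketch follows; the helper function is illustrative and not in the repo, and the bare except: from the original is narrowed to ImportError here as an editorial choice, not part of the commit:

import logging

logger = logging.getLogger(__name__)

try:
    import xformers  # optional accelerated-attention backend

    XFORMERS_IS_AVAILABLE = True
except ImportError:
    XFORMERS_IS_AVAILABLE = False
    logger.debug("no module 'xformers'. Processing without...")


def pick_attention_mode(requested: str) -> str:
    # Mirrors the fallback logic in the attention block's __init__:
    # degrade gracefully and warn via the logger rather than stdout.
    if requested != "softmax" and not XFORMERS_IS_AVAILABLE:
        logger.warning(
            f"Attention mode '{requested}' is not available. Falling back to native attention."
        )
        return "softmax"
    return requested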

sgm/modules/autoencoding/losses/__init__.py
Lines changed: 5 additions & 1 deletion

@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Union
 
 import torch

@@ -10,6 +11,9 @@
 from ....util import default, instantiate_from_config
 
 
+logger = logging.getLogger(__name__)
+
+
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value

@@ -104,7 +108,7 @@ def __init__(
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            print(
+            logger.info(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )
