Commit 4a3f0f5

Author: Jonas Müller

Revert "Replace most print()s with logging calls (Stability-AI#42)" (Stability-AI#65)

This reverts commit 6f6d3f8.

1 parent 7934245 · commit 4a3f0f5

File tree: 10 files changed, +91 -117 lines changed

sgm/data/dataset.py

Lines changed: 13 additions & 13 deletions
@@ -1,22 +1,20 @@
-import logging
 from typing import Optional

 import torchdata.datapipes.iter
 import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule

-logger = logging.getLogger(__name__)
-
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    raise NotImplementedError(
-        "Datasets not yet available. "
-        "To enable, we need to add stable-datasets as a submodule; "
-        "please use ``git submodule update --init --recursive`` "
-        "and do ``pip install -e stable-datasets/`` from the root of this repo"
-    ) from e
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)


 class StableDataModuleFromConfig(LightningDataModule):
@@ -41,8 +39,8 @@ def __init__(
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            logger.warning(
-                "No Validation datapipeline defined, using that one from training"
+            print(
+                "Warning: No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train

@@ -54,10 +52,12 @@ def __init__(

         self.dummy = dummy
         if self.dummy:
-            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)

     def setup(self, stage: str) -> None:
-        logger.debug("Preparing datasets")
+        print("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
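The assertion in this hunk documents the expected config shape: every split passed to StableDataModuleFromConfig must carry a "datapipeline" and a "loader" field, and a missing validation config falls back to the training pipeline with the warning shown above. A minimal sketch of such a config follows; the nested keys and values are illustrative assumptions, not taken from this repository.

# Hedged sketch only: "datapipeline" and "loader" are required per the assert
# above; the nested keys and values are hypothetical placeholders.
from omegaconf import OmegaConf

train_cfg = OmegaConf.create(
    {
        "datapipeline": {"urls": ["data/shard-{00000..00009}.tar"]},  # hypothetical
        "loader": {"batch_size": 4, "num_workers": 2},                # hypothetical
    }
)

# With no validation config, the module reuses the training pipeline and
# prints the warning seen in the diff:
# datamodule = StableDataModuleFromConfig(train=train_cfg)
# datamodule.setup("fit")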

sgm/lr_scheduler.py

Lines changed: 15 additions & 16 deletions
@@ -1,9 +1,5 @@
-import logging
-
 import numpy as np

-logger = logging.getLogger(__name__)
-

 class LambdaWarmUpCosineScheduler:
     """
@@ -28,8 +24,9 @@ def __init__(
         self.verbosity_interval = verbosity_interval

     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -86,11 +83,12 @@ def find_in_interval(self, n):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )

         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
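These schedule() methods return a learning-rate multiplier per step; sgm/models/diffusion.py (further down in this diff) wires them into torch.optim.lr_scheduler.LambdaLR via lr_lambda=scheduler.schedule. A minimal sketch of that wiring, with a plain warm-up function standing in for the scheduler object since its constructor arguments are not shown in this diff:

# Hedged sketch: a stand-in lr_lambda illustrating how LambdaLR consumes
# these schedule() callables; the warm-up length and model are illustrative.
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)


def lr_multiplier(step: int) -> float:
    # stands in for LambdaWarmUpCosineScheduler(...).schedule
    warm_up_steps = 1000
    return min(1.0, (step + 1) / warm_up_steps)


sched = LambdaLR(opt, lr_lambda=lr_multiplier)

for _ in range(3):
    opt.step()
    sched.step()  # scales the base lr by lr_multiplier(current_step)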

sgm/models/autoencoder.py

Lines changed: 8 additions & 11 deletions
@@ -1,4 +1,3 @@
-import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -15,8 +14,6 @@
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config

-logger = logging.getLogger(__name__)
-

 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -41,7 +38,7 @@ def __init__(

         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ def init_from_ckpt(
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    logger.debug(f"Deleting key {k} from state_dict.")
+                    print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.debug(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")

     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")

     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ def decode(self, *args, **kwargs) -> torch.Tensor:
         raise NotImplementedError("decode()-method of abstract base class called")

     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
+        print(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
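instantiate_optimizer_from_config resolves cfg["target"] to a class and calls it with the parameters, learning rate, and any extra params. A self-contained sketch of that pattern, assuming an importlib-based stand-in for sgm.util.get_obj_from_str and an illustrative AdamW config:

# Hedged sketch of config-driven optimizer construction; the resolver below
# is an assumption about get_obj_from_str's behavior, and the cfg values are
# illustrative.
import importlib

import torch


def get_obj_from_str_sketch(target: str):
    module_name, cls_name = target.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), cls_name)


cfg = {"target": "torch.optim.AdamW", "params": {"weight_decay": 1e-2}}
params = torch.nn.Linear(4, 4).parameters()
opt = get_obj_from_str_sketch(cfg["target"])(params, lr=1e-4, **cfg.get("params", dict()))
print(type(opt).__name__)  # -> AdamW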

sgm/models/diffusion.py

Lines changed: 7 additions & 10 deletions
@@ -1,4 +1,3 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union

@@ -19,8 +18,6 @@
     log_txt_as_img,
 )

-logger = logging.getLogger(__name__)
-

 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -76,7 +73,7 @@ def __init__(
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ def init_from_ckpt(
             raise NotImplementedError

         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.info(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")

     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ def ema_scope(self, context=None):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")

     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ def configure_optimizers(self):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            logger.debug("Setting up LambdaLR scheduler...")
+            print("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),

sgm/modules/attention.py

Lines changed: 17 additions & 21 deletions
@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@
 from packaging import version
 from torch import nn

-
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )

 try:
@@ -53,7 +48,7 @@
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")

 from .diffusionmodules.util import checkpoint

@@ -294,7 +289,7 @@ def __init__(
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ def __init__(
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ def __init__(
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")

     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ def __init__(
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ def __init__(
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
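On torch >= 2.0 this module imports SDPBackend and sdp_kernel and keeps a BACKEND_MAP of attention kernels. A hedged sketch, not taken from this file, of how such a kernel context is typically used together with PyTorch's scaled_dot_product_attention; the enable_* flags and tensor shapes are illustrative:

# Hedged sketch, not from this repository: selecting SDP kernels around
# F.scaled_dot_product_attention, which is what the sdp_kernel / SDPBackend
# imports above enable on torch >= 2.0. Flags and shapes are illustrative.
import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

q = torch.randn(1, 8, 64, 64)  # (batch, heads, tokens, dim_head)
k = torch.randn(1, 8, 64, 64)
v = torch.randn(1, 8, 64, 64)

with sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 8, 64, 64])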

sgm/modules/autoencoding/losses/__init__.py

Lines changed: 1 addition & 5 deletions
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, Union

 import torch
@@ -11,9 +10,6 @@
 from ....util import default, instantiate_from_config


-logger = logging.getLogger(__name__)
-
-
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -108,7 +104,7 @@ def __init__(
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            logger.info(
+            print(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )
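adopt_weight, shown at the top of this file, gates a loss weight on the global step: below the threshold it yields the replacement value, afterwards the weight itself. A small usage sketch; the trailing return is implied by the truncated hunk, and the numbers are illustrative:

# Hedged sketch of adopt_weight; the body mirrors the truncated hunk above,
# with the implied `return weight` added.
def adopt_weight(weight, global_step, threshold=0, value=0.0):
    if global_step < threshold:
        weight = value
    return weight


print(adopt_weight(weight=1.0, global_step=500, threshold=1000))   # 0.0 (not yet active)
print(adopt_weight(weight=1.0, global_step=2000, threshold=1000))  # 1.0 (active)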
