Commit 635fd19

Added StyleLoss module

1 parent f4779c4 commit 635fd19

File tree

src/losses.py
src/recorder.py
src/utils.py

3 files changed: +219 -15 lines

src/losses.py

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
from torchvision.models import get_model
from torch.nn.functional import mse_loss

from typing import Tuple
from torch import Tensor

from .recorder import Traced
from .utils import compute_stat

def style_loss(targ : Tensor, pred : Tensor, k : float = 0.) -> Tensor:
    bs, c, *_ = targ.shape

    # Compare mean & std of target and pred tensor
    targ_mean, targ_std = compute_stat(targ)
    pred_mean, pred_std = compute_stat(pred)

    loss_mean = mse_loss(targ_mean, pred_mean, reduction='none')
    loss_std  = mse_loss(targ_std , pred_std , reduction='none')

    if k > 0:
        # Sort the per-channel mean losses and reorder the std losses
        # with the same permutation, so both refer to the same channels
        loss_mean, sort_idx = torch.sort(loss_mean, dim=1)
        loss_std = loss_std.gather(1, sort_idx)

        # Keep only the k-fraction of best-aligned channels
        loss_mean[:, int(c * k):] = 0
        loss_std [:, int(c * k):] = 0

    return (loss_mean + loss_std).mean()

def content_loss(targ : Tensor, pred : Tensor) -> Tensor:
    # Compare the target and pred tensors after normalizing
    # each by its own per-channel mean & std
    targ_mean, targ_std = compute_stat(targ)
    pred_mean, pred_std = compute_stat(pred)

    norm_targ = (targ - targ_mean) / targ_std
    norm_pred = (pred - pred_mean) / pred_std

    return mse_loss(norm_targ, norm_pred)

class StyleLoss(nn.Module):
    '''
    This module implements the style loss used for model
    training in the paper:
    `Hierarchy Flow For High-Fidelity Image-to-Image Translation`
    Fan et al. (2023) (arXiv:2308.06909).

    This loss is a combination of a standard VGG-based style loss
    and an original (modified) content loss. The tradeoff between
    content preservation and style preservation is controlled by
    the `content_weight` parameter.

    Args:
        enc_depth (tuple): Depths of the backbone layers whose
            activations are used to compute the loss.
        backbone (str): Name of the torchvision model to use as
            feature extractor.
        content_weight (float): The weight of the content loss.
    '''

    def __init__(
        self,
        enc_depth : Tuple[int, ...] = (3, 10, 17, 30),
        backbone : str = 'vgg19_bn',
        content_weight : float = .8,
    ) -> None:
        super().__init__()

        backbone = get_model(backbone, weights='DEFAULT')
        backbone = Traced(backbone, enc_depth)

        # Use the backbone as a frozen feature extractor: no parameter
        # updates and no batch-norm statistics drift during training
        backbone.eval()
        for p in backbone.parameters(): p.requires_grad_(False)

        self.backbone = backbone

        self.content_weight = content_weight

    def forward(
        self,
        orig_img : Tensor,
        targ_sty : Tensor,
        pred_img : Tensor,
        align_k : float = 0.8
    ) -> Tensor:
        '''
        Compute the combined style & content loss of the predicted
        images against the style references and the original content.
        '''
        # Get features from the backbone model for the
        # original content image
        (*_, orig_feat) = self.backbone(orig_img)

        # Get the style features for both the reference
        # and produced images
        targ_feats = self.backbone(targ_sty)
        pred_feats = self.backbone(pred_img)

        # Compute the style & content loss
        loss_style = sum([style_loss(f1, f2, k=align_k) for f1, f2 in zip(targ_feats, pred_feats)])
        loss_content = content_loss(orig_feat, pred_feats[-1])

        # Combine the two terms, weighting the content loss as
        # documented in the class docstring
        loss = self.content_weight * loss_content + loss_style

        return loss
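For reference, a minimal usage sketch of the module above. The tensor sizes are illustrative assumptions, pred_img is a stand-in for a generator output, and src is assumed to be importable as a package:

import torch
from src.losses import StyleLoss

criterion = StyleLoss(content_weight=0.8)

# Hypothetical batch of 4 RGB images at 256x256 resolution
orig_img = torch.rand(4, 3, 256, 256)                      # content images
targ_sty = torch.rand(4, 3, 256, 256)                      # style references
pred_img = torch.rand(4, 3, 256, 256, requires_grad=True)  # generator output stand-in

# align_k = 0.8 keeps the 80% best-aligned channels in the style term
loss = criterion(orig_img, targ_sty, pred_img, align_k=0.8)
loss.backward()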

src/recorder.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
import torch.nn as nn
from torch import Tensor

from typing import List

from .utils import default
from .utils import flatten

class FeatureRecorder:
    '''
    Basic feature-recording class that implements a PyTorch hook to
    acquire layer activations as they are processed by the network.
    '''

    def __init__(
        self,
        names : List[str],
        detach : bool = False,
    ) -> None:

        self.names = names
        self.detach = detach
        self.feats = {l : [] for l in names}

    def __call__(self, module : nn.Module, inp : Tensor, out : Tensor) -> None:
        # Optionally detach the layer output from the PyTorch graph.
        # Keeping it attached (the default) lets gradients flow through
        # the recorded features, which is needed when the recorder
        # backs a training loss such as `StyleLoss`
        data = out.detach() if self.detach else out

        # Get the module name (assigned by `Traced`)
        layer = module.name

        self.feats[layer].append(data)

    def clean(
        self,
        names : List[str] | None = None,
    ) -> None:
        self.names = default(names, self.names)
        self.feats = {k : [] for k in self.names}

class Traced(nn.Module):
    '''
    A wrapper class for a Torch Module whose intermediate activations
    are traced (recorded) via forward hooks at the chosen depths.
    '''

    def __init__(
        self,
        module : nn.Module,
        depths : List[int],
    ) -> None:
        super().__init__()

        self.module = module
        self.depths = depths

        # Get the list of layers to be traced
        self.layers = [l for depth, l in enumerate(flatten(module)) if depth in depths]
        self.names = [f'enc_{d}' for d in depths]

        # Tag each traced layer with its name, so the recorder can
        # index its feature dict by module name
        for layer, name in zip(self.layers, self.names): layer.name = name

        # Initialize the feature recorder
        self.recorder = FeatureRecorder(self.names)

        # Register the forward hooks for each layer targeted as 'traced'
        self.hook_handles = [l.register_forward_hook(self.recorder) for l in self.layers]

    @property
    def features(self) -> List[Tensor]:
        # Each traced layer fires its hook once per forward pass, so
        # return the last recorded activation for each layer
        return [self.recorder.feats[k][-1] for k in self.names]

    def forward(self, inp : Tensor, auto_clean : bool = True) -> List[Tensor]:
        # Propagate the input through the network
        _ = self.module(inp)

        # Collect the features
        feats = self.features

        if auto_clean: self.clean()

        return feats

    def clean(self) -> None:
        # Only reset the recorded features: the hooks stay registered,
        # so the module can be traced again on the next forward pass
        self.recorder.clean()

    def remove_hooks(self) -> None:
        for h in self.hook_handles: h.remove()

    def __str__(self) -> str:
        msg = 'Tracing module: \n'
        msg += f'{self.module} \n'
        msg += f'Traced layers: {self.layers} \n'
        msg += f'Traced depths: {self.depths} \n'
        return msg
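A quick sketch of how Traced is meant to be driven; the toy network and the input size are illustrative assumptions:

import torch
import torch.nn as nn
from src.recorder import Traced

# Toy network: flatten() yields its leaf modules in order, so
# depths 1 and 3 select the two ReLU layers
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1), nn.ReLU(),
)
traced = Traced(net, depths=[1, 3])

feats = traced(torch.rand(2, 3, 32, 32))
print([f.shape for f in feats])  # two tensors of shape (2, 8, 32, 32)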

src/utils.py

Lines changed: 26 additions & 15 deletions
@@ -2,13 +2,36 @@
 import torch.nn as nn

 from torch import Tensor
-from typing import Any, Tuple
+from typing import Any, Tuple, List

 from einops import rearrange

+def exists(var : Any | None) -> bool:
+    return var is not None
+
 def default(var : Any | None, val : Any) -> Any:
     return val if var is None else var

+def flatten(model : nn.Module) -> List[nn.Module]:
+    # Recursively collect the leaf modules of the model, in order
+    flattened = [flatten(child) for child in model.children()]
+    res = [model] if list(model.children()) == [] else []
+
+    for c in flattened: res += c
+
+    return res
+
+def compute_stat(feat : Tensor, eps : float = 1e-5) -> Tuple[Tensor, Tensor]:
+    # Unpacking enforces a 4-D (b, c, h, w) input
+    bs, c, h, w = feat.shape
+
+    # Per-channel statistics over the spatial dimensions, kept
+    # broadcastable against the (b, c, h, w) input
+    var = rearrange(feat, 'b c h w -> b c (h w)').var(dim=2) + eps
+    std = rearrange(var.sqrt(), 'b c -> b c 1 1')
+
+    mean = rearrange(feat, 'b c h w -> b c (h w)').mean(dim=2)
+    mean = rearrange(mean, 'b c -> b c 1 1')
+
+    return mean, std

 class ReversibleConcat(nn.Module):
     '''

@@ -59,20 +82,8 @@ def __init__(self) -> None:

     def forward(self, subj : Tensor, feat_mean : Tensor, feat_std : Tensor) -> Tensor:
         # Get subject mean and standard deviation
-        subj_mean, subj_std = self._compute_stat(subj)
+        subj_mean, subj_std = compute_stat(subj)

         norm_feat = (subj - subj_mean) / subj_std

-        return norm_feat * feat_std + feat_mean
-
-    def _compute_stat(self, feat : Tensor, eps : float = 1e-5) -> Tuple[Tensor, Tensor]:
-        # Check input dimension
-        bs, c, h, w = feat.shape
-
-        var = rearrange(feat, 'b c h w -> b c (h w)').var(dim=2) + eps
-        std = rearrange(var.sqrt(), 'b c -> b c 1 1')
-
-        mean = rearrange(feat, 'b c h w -> b c (h w)').mean(dim=2)
-        mean = rearrange(mean, 'b c -> b c 1 1')
-
-        return mean, std
+        return norm_feat * feat_std + feat_mean
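To make the shapes concrete, a small sketch of the extracted compute_stat helper; the feature-map size is an illustrative assumption:

import torch
from src.utils import compute_stat

feat = torch.rand(2, 64, 16, 16)  # (b, c, h, w) feature map
mean, std = compute_stat(feat)

# Per-channel statistics over the spatial dimensions, kept
# broadcastable against the input
print(mean.shape, std.shape)      # torch.Size([2, 64, 1, 1]) twice

# They can therefore normalize the feature map directly
norm = (feat - mean) / std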
