import torch
import torch.nn as nn
from torchvision.models import get_model
from torch.nn.functional import mse_loss

from typing import Tuple
from torch import Tensor

from .recorder import Traced
from .utils import compute_stat

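# NOTE: the two project-local helpers are used under the following
# assumptions (inferred from this file, not from their definitions):
# `compute_stat` returns the per-channel (mean, std) of a feature map,
# broadcastable against it (e.g. shape (bs, c, 1, 1)), and `Traced`
# wraps a model so that calling it returns the activations recorded
# at the layer indices given at construction time.
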
def style_loss(targ : Tensor, pred : Tensor, k : float = 0.) -> Tensor:
    bs, c, *_ = targ.shape

    # Compare per-channel mean & std of the target and predicted features
    targ_mean, targ_std = compute_stat(targ)
    pred_mean, pred_std = compute_stat(pred)

    loss_mean = mse_loss(targ_mean, pred_mean, reduction='none')
    loss_std  = mse_loss(targ_std , pred_std , reduction='none')

    if k > 0:
        # Aligned style loss: keep only the fraction k of channels whose
        # mean statistics match best and zero out the rest. The std term
        # is zeroed at the same channels via scatter, since the sort
        # permutes loss_mean but leaves loss_std in its original order.
        loss_mean, sort_idx = torch.sort(loss_mean, dim=1)

        loss_mean[:, int(c * k):] = 0
        loss_std = loss_std.scatter(1, sort_idx[:, int(c * k):], 0.)

    return (loss_mean + loss_std).mean()

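# Example of the aligned style loss above (hypothetical shapes): for
# feature maps of shape (2, 512, 32, 32) and k = 0.8, only the
# int(512 * 0.8) = 409 best-matching channels contribute to the loss;
# the remaining 103 are zeroed out before averaging.
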
def content_loss(targ : Tensor, pred : Tensor) -> Tensor:
    # Normalize each feature map by its own per-channel mean & std,
    # then compare the normalized activations
    targ_mean, targ_std = compute_stat(targ)
    pred_mean, pred_std = compute_stat(pred)

    norm_targ = (targ - targ_mean) / targ_std
    norm_pred = (pred - pred_mean) / pred_std

    return mse_loss(norm_targ, norm_pred)

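# Note: because both feature maps are normalized before the comparison,
# the content loss is invariant to exactly the per-channel statistics
# that the style loss matches, so the two objectives do not directly
# pull against each other.
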
class StyleLoss(nn.Module):
    '''
    This module implements the style loss used for model
    training in the paper:
    `Hierarchy Flow For High-Fidelity Image-to-Image Translation`
    Fan et al. (2023) (arXiv:2308.06909).

    The loss combines a standard VGG-19 (style) loss with an
    original (modified) content loss. The trade-off between
    content preservation and style preservation is controlled
    by the `content_weight` parameter.

    Args:
        enc_depth (tuple of int): Indices of the backbone layers
            whose activations are used as style/content features.
        backbone (str): Name of the torchvision model used as the
            feature extractor (default: 'vgg19_bn').
        content_weight (float): The weight of the content loss.
    '''

    def __init__(
        self,
        enc_depth : Tuple[int, ...] = (3, 10, 17, 30),
        backbone : str = 'vgg19_bn',
        content_weight : float = .8,
    ) -> None:
        super().__init__()

        backbone = get_model(backbone, weights='DEFAULT')
        backbone = Traced(backbone, enc_depth)

        # The backbone is only used as a frozen feature extractor:
        # disable its gradients and keep it in eval mode so that the
        # batch-norm statistics are not updated during training
        backbone = backbone.eval().requires_grad_(False)

        self.backbone = backbone

        self.content_weight = content_weight

    def forward(
        self,
        orig_img : Tensor,
        targ_sty : Tensor,
        pred_img : Tensor,
        align_k : float = 0.8
    ) -> Tensor:
        '''
        Compute the combined style & content loss.

        Args:
            orig_img (Tensor): Original content image.
            targ_sty (Tensor): Target style reference image.
            pred_img (Tensor): Image produced by the model.
            align_k (float): Fraction of channels kept by the
                aligned style loss (see `style_loss`).

        Returns:
            Tensor: Scalar loss value.
        '''
        # Get features from the backbone model for the
        # original content image (only the deepest one is used)
        (*_, orig_feat) = self.backbone(orig_img)

        # Get the style features for both the reference
        # and produced images
        targ_feats = self.backbone(targ_sty)
        pred_feats = self.backbone(pred_img)

        # Compute the style & content loss
        loss_style = sum(
            style_loss(f1, f2, k=align_k)
            for f1, f2 in zip(targ_feats, pred_feats)
        )
        loss_content = content_loss(orig_feat, pred_feats[-1])

        # Combine the two terms, weighting the content loss
        loss = self.content_weight * loss_content + loss_style

        return loss
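
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only, not part of the original
    # code): random tensors stand in for properly pre-processed images.
    # Real inputs should be 3-channel images normalized with the
    # statistics expected by the pretrained torchvision VGG weights.
    criterion = StyleLoss(content_weight=0.8)

    orig_img = torch.randn(2, 3, 224, 224)                      # content images
    targ_sty = torch.randn(2, 3, 224, 224)                      # style references
    pred_img = torch.randn(2, 3, 224, 224, requires_grad=True)  # model outputs

    loss = criterion(orig_img, targ_sty, pred_img, align_k=0.8)
    loss.backward()  # Gradients reach pred_img only: the backbone is frozen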