Skip to content

Commit 57900d3

Browse files
committed
Fixed minor bug in hflow, updated README
1 parent 635fd19 commit 57900d3

File tree

3 files changed

+22
-14
lines changed

3 files changed

+22
-14
lines changed

README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,19 @@ The authors release their official implementation which can be found [here](http
88

99
```python
1010
import torch
11-
from src.hflow import HierarchicalFlow
11+
from src.hflow import HierarchyFlow
1212

1313
flow = HierarchyFlow(
1414
inp_channels=3,
15-
out_channels=[30, 120],
16-
pad_size=10,
15+
flow_channel_mult=[10, 4, 4], # Set channel mult of the hierarchy convs
16+
feat_channel_mult=[3, 3, 3], # Set the channel mult factors of the style convs
17+
pad_size = 10, # Input pad size
18+
pad_mode = 'reflect',
19+
style_out_dim = 8, # Number of channel of final style features
20+
style_conv_kw = [
21+
{'kernel_size' : 7, 'stride' : 1, 'padding' : 3},
22+
*[{'kernel_size' : 4, 'stride' : 2, 'padding' : 1}] * 2
23+
] # Parameter for style convolutional layers, should match length of feat_channel_mult
1724
)
1825

1926
x = torch.randn(1, 3, 256, 256)

src/hflow.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from typing import List
77
from torch import Tensor
8-
from itertools import pairwise
8+
from itertools import pairwise, repeat
99

1010
class HierarchyFlow(nn.Module):
1111
'''
@@ -20,12 +20,12 @@ def __init__(
2020
feat_channel_mult : List[int] = [3, 3],
2121
pad_size : int = 10,
2222
pad_mode : str = 'reflect',
23-
style_dim : int = 8,
24-
style_kw : dict | None = None,
23+
style_out_dim : int = 8,
24+
style_conv_kw : dict | None = None,
2525
):
2626
super(HierarchyFlow, self).__init__()
2727

28-
style_kw = default(style_kw, {'padding' : 0})
28+
style_conv_kw = default(style_conv_kw, repeat({'kernel_size' : 3}))
2929

3030
self.inp_channels = inp_channels
3131
self.flow_channel_mult = flow_channel_mult
@@ -39,18 +39,18 @@ def __init__(
3939
ReversiblePad2d(pad_size, pad_mode=pad_mode),
4040
*[HierarchyBlock(
4141
inp_chn, out_chn,
42-
mlp_inp_dim=style_dim)
42+
mlp_inp_dim=style_out_dim)
4343
for inp_chn, out_chn in pairwise(flow_channels)]
4444
]
4545
)
4646

4747
self.style_block = nn.Sequential([
4848
nn.Sequential(
49-
nn.Conv2d(inp_chn, out_chn, **style_kw),
49+
nn.Conv2d(inp_chn, out_chn, **style),
5050
nn.ReLU(),
51-
) for inp_chn, out_chn in zip(feat_channels)],
51+
) for (inp_chn, out_chn), style in zip(feat_channels, style_conv_kw)],
5252
nn.AdaptiveAvgPool2d(1), # global average pooling
53-
nn.Conv2d(feat_channels[-1], style_dim, 1, 1, 0)
53+
nn.Conv2d(feat_channels[-1], style_out_dim, 1, 1, 0)
5454
)
5555

5656
def forward(

src/losses.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,10 @@ class StyleLoss(nn.Module):
5050
the `content_weight` parameter.
5151
5252
Args:
53-
content_weight (float): The weight of the content loss.
54-
style_weight (float): The weight of the style loss.
55-
vgg_layers (list): The layers of the VGG-19 model to use.
53+
- enc_depth (List[int]): List of integers representing the
54+
backbone layer depth to use as feature encoders.
55+
- backbone (str): Torchvision model to use as feature extractor
56+
- content_weight (float): The weight of the content loss.
5657
'''
5758

5859
def __init__(

0 commit comments

Comments
 (0)