55
66from typing import List
77from torch import Tensor
8- from itertools import pairwise
8+ from itertools import pairwise , repeat
99
1010class HierarchyFlow (nn .Module ):
1111 '''
@@ -20,12 +20,12 @@ def __init__(
2020 feat_channel_mult : List [int ] = [3 , 3 ],
2121 pad_size : int = 10 ,
2222 pad_mode : str = 'reflect' ,
23- style_dim : int = 8 ,
24- style_kw : dict | None = None ,
23+ style_out_dim : int = 8 ,
24+ style_conv_kw : dict | None = None ,
2525 ):
2626 super (HierarchyFlow , self ).__init__ ()
2727
28- style_kw = default (style_kw , { 'padding ' : 0 } )
28+ style_conv_kw = default (style_conv_kw , repeat ({ 'kernel_size ' : 3 }) )
2929
3030 self .inp_channels = inp_channels
3131 self .flow_channel_mult = flow_channel_mult
@@ -39,18 +39,18 @@ def __init__(
3939 ReversiblePad2d (pad_size , pad_mode = pad_mode ),
4040 * [HierarchyBlock (
4141 inp_chn , out_chn ,
42- mlp_inp_dim = style_dim )
42+ mlp_inp_dim = style_out_dim )
4343 for inp_chn , out_chn in pairwise (flow_channels )]
4444 ]
4545 )
4646
4747 self .style_block = nn .Sequential ([
4848 nn .Sequential (
49- nn .Conv2d (inp_chn , out_chn , ** style_kw ),
49+ nn .Conv2d (inp_chn , out_chn , ** style ),
5050 nn .ReLU (),
51- ) for inp_chn , out_chn in zip (feat_channels )],
51+ ) for ( inp_chn , out_chn ), style in zip (feat_channels , style_conv_kw )],
5252 nn .AdaptiveAvgPool2d (1 ), # global average pooling
53- nn .Conv2d (feat_channels [- 1 ], style_dim , 1 , 1 , 0 )
53+ nn .Conv2d (feat_channels [- 1 ], style_out_dim , 1 , 1 , 0 )
5454 )
5555
5656 def forward (
0 commit comments