@@ -15,15 +15,20 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
                 visual_projs_path, vit=False, bg_thresh=0.,
                 num_vote=0, vote_thresh=0., topk_text=0,
                 cls_thresh=0., attn_pooling=False, num_heads=32,
-                freeze_text=False, **kwargs):
+                **kwargs):
         super(MaskClipHead, self).__init__(**kwargs)

         self.text_categories = text_categories
         self.text_channels = text_channels
         self.text_embeddings_path = text_embeddings_path
         self.visual_projs_path = visual_projs_path

-        self.register_buffer('text_embeddings', torch.randn(text_categories, text_channels))
+        if self.text_embeddings_path is None:
+            self.text_embeddings = nn.Parameter(torch.zeros(text_categories, text_channels))
+            nn.init.normal_(self.text_embeddings, mean=0.0, std=0.01)
+        else:
+            self.register_buffer('text_embeddings', torch.randn(text_categories, text_channels))
+            self.load_text_embeddings()

         self.vit = vit
         if vit:
@@ -33,6 +38,7 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
             self.k_proj = nn.Conv2d(self.in_channels, self.in_channels, 1)
             self.v_proj = nn.Conv2d(self.in_channels, self.in_channels, 1)
             self.c_proj = nn.Conv2d(self.in_channels, text_channels, 1)
+        self.load_visual_projs()

         self.bg_thresh = bg_thresh
         self.num_vote = num_vote
@@ -43,14 +49,13 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
         self.cls_thresh = cls_thresh
         self.attn_pooling = attn_pooling
         self.num_heads = num_heads
-        self.freeze_text = freeze_text
-
-        self.load_text_embeddings()
-        self.load_visual_projs()

     def init_weights(self):
         super(MaskClipHead, self).init_weights()
-        self.load_text_embeddings()
+        if self.text_embeddings_path is None:
+            nn.init.normal_(self.text_embeddings, mean=0.0, std=0.01)
+        else:
+            self.load_text_embeddings()
         self.load_visual_projs()

     def load_text_embeddings(self):
@@ -70,16 +75,6 @@ def load_visual_projs(self):
             current_attr.load_state_dict(state_dict)
         print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())

-    def _freeze_text(self):
-        """Freeze params and norm stats."""
-        if self.freeze_text:
-            self.text_embeddings.requires_grad = False
-
-    def train(self, mode=True):
-        super(MaskClipHead, self).train(mode)
-        if mode:
-            self._freeze_text()
-
     def forward(self, inputs):
         x = self._transform_inputs(inputs)
         q, k, v, cls_token = None, None, None, None