Skip to content

Commit 864223a

Browse files
committed
add trainable conv_seg for MaskCLIP
1 parent 9c793ec commit 864223a

File tree

3 files changed

+18
-19
lines changed

3 files changed

+18
-19
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,4 +123,8 @@ pretrain/
123123
vis/
124124
log/
125125
show_dirs/
126-
wordnet/
126+
wordnet/
127+
corrupt.sh
128+
run.sh
129+
configs/rename.py
130+
demo/maskclip_demo.ipynb

configs/maskclip/finetune/maskclip_vit16_p10_ft_480x480_40k_pascal_context_59.py renamed to configs/maskclip/finetune/maskclip_vit16_p10_ftt_480x480_40k_pascal_context_59.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
_base_ = './maskclip_vit16_p10_480x480_40k_pascal_context_59.py'
22
model = dict(
33
decode_head=dict(
4-
freeze_text=True,
4+
text_embeddings_path=None,
55
),
66
)
77
# data = dict(

mmseg/models/decode_heads/maskclip_head.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,20 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
1515
visual_projs_path, vit=False, bg_thresh=0.,
1616
num_vote=0, vote_thresh=0., topk_text=0,
1717
cls_thresh=0., attn_pooling=False, num_heads=32,
18-
freeze_text=False, **kwargs):
18+
**kwargs):
1919
super(MaskClipHead, self).__init__(**kwargs)
2020

2121
self.text_categories = text_categories
2222
self.text_channels = text_channels
2323
self.text_embeddings_path = text_embeddings_path
2424
self.visual_projs_path = visual_projs_path
2525

26-
self.register_buffer('text_embeddings', torch.randn(text_categories, text_channels))
26+
if self.text_embeddings_path is None:
27+
self.text_embeddings = nn.Parameter(torch.zeros(text_categories, text_channels))
28+
nn.init.normal_(self.text_embeddings, mean=0.0, std=0.01)
29+
else:
30+
self.register_buffer('text_embeddings', torch.randn(text_categories, text_channels))
31+
self.load_text_embeddings()
2732

2833
self.vit = vit
2934
if vit:
@@ -33,6 +38,7 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
3338
self.k_proj = nn.Conv2d(self.in_channels, self.in_channels, 1)
3439
self.v_proj = nn.Conv2d(self.in_channels, self.in_channels, 1)
3540
self.c_proj = nn.Conv2d(self.in_channels, text_channels, 1)
41+
self.load_visual_projs()
3642

3743
self.bg_thresh = bg_thresh
3844
self.num_vote = num_vote
@@ -43,14 +49,13 @@ def __init__(self, text_categories, text_channels, text_embeddings_path,
4349
self.cls_thresh = cls_thresh
4450
self.attn_pooling = attn_pooling
4551
self.num_heads = num_heads
46-
self.freeze_text = freeze_text
47-
48-
self.load_text_embeddings()
49-
self.load_visual_projs()
5052

5153
def init_weights(self):
5254
super(MaskClipHead, self).init_weights()
53-
self.load_text_embeddings()
55+
if self.text_embeddings_path is None:
56+
nn.init.normal_(self.text_embeddings, mean=0.0, std=0.01)
57+
else:
58+
self.load_text_embeddings()
5459
self.load_visual_projs()
5560

5661
def load_text_embeddings(self):
@@ -70,16 +75,6 @@ def load_visual_projs(self):
7075
current_attr.load_state_dict(state_dict)
7176
print_log(f'Loaded proj weights from {self.visual_projs_path}', logger=get_root_logger())
7277

73-
def _freeze_text(self):
74-
"""Freeze params and norm stats."""
75-
if self.freeze_text:
76-
self.text_embeddings.requires_grad = False
77-
78-
def train(self, mode=True):
79-
super(MaskClipHead, self).train(mode)
80-
if mode:
81-
self._freeze_text()
82-
8378
def forward(self, inputs):
8479
x = self._transform_inputs(inputs)
8580
q, k, v, cls_token = None, None, None, None

0 commit comments

Comments
 (0)