Skip to content

Commit dd154eb

Browse files
committed
Fixed some criteria formatting and missing flags in multisimilarity.py
1 parent ec7fd6c commit dd154eb

File tree

10 files changed

+41
-15
lines changed

10 files changed

+41
-15
lines changed

criteria/adversarial_separation.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ def __init__(self, opt):
2121
super().__init__()
2222

2323
####
24-
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
25-
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
26-
self.REQUIRES_OPTIM = REQUIRES_OPTIM
27-
28-
####
2924
self.embed_dim = opt.embed_dim
3025
self.proj_dim = opt.diva_decorrnet_dim
3126

@@ -43,6 +38,14 @@ def __init__(self, opt):
4338
self.lr = opt.diva_decorrnet_lr
4439

4540

41+
####
42+
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
43+
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
44+
self.REQUIRES_OPTIM = REQUIRES_OPTIM
45+
46+
47+
48+
4649
def forward(self, feature_dict):
4750
#Apply gradient reversal on input embeddings.
4851
adj_feature_dict = {key:torch.nn.functional.normalize(grad_reverse(features),dim=-1) for key, features in feature_dict.items()}

criteria/angular.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, opt, batchminer):
2424
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
2525
self.REQUIRES_OPTIM = REQUIRES_OPTIM
2626

27+
28+
2729
def forward(self, batch, labels, **kwargs):
2830
####NOTE: Normalize Angular Loss, but dont normalize npair loss!
2931
anchors, positives, negatives = self.batchminer(batch, labels)

criteria/arcface.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,6 @@ def __init__(self, opt):
1313
super(Criterion, self).__init__()
1414
self.par = opt
1515

16-
####
17-
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
18-
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
19-
self.REQUIRES_OPTIM = REQUIRES_OPTIM
20-
2116
####
2217
self.angular_margin = opt.loss_arcface_angular_margin
2318
self.feature_scale = opt.loss_arcface_feature_scale
@@ -30,6 +25,14 @@ def __init__(self, opt):
3025

3126
self.lr = opt.loss_arcface_lr
3227

28+
####
29+
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
30+
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
31+
self.REQUIRES_OPTIM = REQUIRES_OPTIM
32+
33+
34+
35+
3336
def forward(self, batch, labels, **kwargs):
3437
bs, labels = len(batch), labels.to(self.par.device)
3538

criteria/contrastive.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(self, opt, batchminer):
2222
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
2323
self.REQUIRES_OPTIM = REQUIRES_OPTIM
2424

25+
26+
2527
def forward(self, batch, labels, **kwargs):
2628
sampled_triplets = self.batchminer(batch, labels)
2729

criteria/lifted.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(self, opt, batchminer):
2222
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
2323
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
2424
self.REQUIRES_OPTIM = REQUIRES_OPTIM
25+
26+
2527

2628
def forward(self, batch, labels, **kwargs):
2729
anchors, positives, negatives = self.batchminer(batch, labels)

criteria/margin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def __init__(self, opt, batchminer):
3434
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
3535
self.REQUIRES_OPTIM = REQUIRES_OPTIM
3636

37+
38+
3739
def forward(self, batch, labels, **kwargs):
3840
sampled_triplets = self.batchminer(batch, labels)
3941

criteria/multisimilarity.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ def __init__(self, opt):
1919

2020
self.name = 'multisimilarity'
2121

22+
####
23+
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
24+
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
25+
self.REQUIRES_OPTIM = REQUIRES_OPTIM
26+
27+
2228
def forward(self, batch, labels, **kwargs):
2329
similarity = batch.mm(batch.T)
2430

criteria/proxynca.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@ def __init__(self, opt):
1717
"""
1818
super(Criterion, self).__init__()
1919

20-
####
21-
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
22-
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
23-
self.REQUIRES_OPTIM = REQUIRES_OPTIM
24-
2520
####
2621
self.num_proxies = opt.n_classes
2722
self.embed_dim = opt.embed_dim
@@ -34,6 +29,13 @@ def __init__(self, opt):
3429
self.optim_dict_list = [{'params':self.proxies, 'lr':opt.lr * opt.loss_proxynca_lrmulti}]
3530

3631

32+
####
33+
self.ALLOWED_MINING_OPS = ALLOWED_MINING_OPS
34+
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
35+
self.REQUIRES_OPTIM = REQUIRES_OPTIM
36+
37+
38+
3739
def forward(self, batch, labels, **kwargs):
3840
#Empirically, multiplying the embeddings during the computation of the loss seem to allow for more stable training;
3941
#Acts as a temperature in the NCA objective.

criteria/quadruplet.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ def __init__(self, opt, batchminer):
2222
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
2323
self.REQUIRES_OPTIM = REQUIRES_OPTIM
2424

25+
26+
2527
def triplet_distance(self, anchor, positive, negative):
2628
return torch.nn.functional.relu(torch.norm(anchor-positive, p=2, dim=-1)-torch.norm(anchor-negative, p=2, dim=-1)+self.margin_alpha_1)
2729

criteria/snr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def __init__(self, opt, batchminer):
2424
self.REQUIRES_BATCHMINER = REQUIRES_BATCHMINER
2525
self.REQUIRES_OPTIM = REQUIRES_OPTIM
2626

27+
28+
2729
def forward(self, batch, labels, **kwargs):
2830
sampled_triplets = self.batchminer(batch, labels)
2931
anchors = [triplet[0] for triplet in sampled_triplets]

0 commit comments

Comments (0)