Skip to content

Commit 814a874

Browse files
UltraShieldRogalexmirrington
authored andcommitted
cnn config update
1 parent ece7a04 commit 814a874

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

code/algorithm/main.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,12 @@ def main(config):
127127
model = TFIDF(ImageCaptionDataset.CLASSES,
128128
tfidf_vectorizer,
129129
config.threshold)
130-
elif config.model_type == 'cnn':
131-
model = CNN(ImageCaptionDataset.CLASSES, 'alexnet')
130+
elif config.model_type in [
131+
'resnet', 'alexnet', 'vgg',
132+
'densenet', 'googlenet', 'resnext',
133+
'wide_resnet', 'mnasnet'
134+
]:
135+
model = CNN(ImageCaptionDataset.CLASSES, config.model_type)
132136

133137
loss_func = BCEWithLogitsLoss()
134138
optimiser = Adam(model.parameters())
@@ -433,7 +437,14 @@ def parse_args(args):
433437
'rcnn',
434438
'lstm',
435439
'tfidf',
436-
'cnn'
440+
'resnet',
441+
'alexnet',
442+
'vgg',
443+
'densenet',
444+
'googlenet',
445+
'resnext',
446+
'wide_resnet',
447+
'mnasnet'
437448
],
438449
type=str,
439450
required=False,

code/logs/20200529-021503.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
micro_f1: 0.702053, macro_f1: 0.491366, weighted_f1: 0.710688
2+
micro_f1: 0.710000, macro_f1: 0.475064, weighted_f1: 0.727771

0 commit comments

Comments
 (0)