
Commit f18c185

load from github release
1 parent ff2109a commit f18c185

File tree: 4 files changed, +14 -3 lines changed

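The commit replaces checkpoints shipped in the repository with on-demand downloads via torch.hub.load_state_dict_from_url, which fetches a URL (here, GitHub release assets) once and caches the file locally, so later runs reuse the cached copy. A minimal sketch of that pattern, assuming the release asset is a plain state dict, as the backbone change below suggests:

import torch

# Download the release asset on first use; later calls reuse the cached copy
# under <torch hub dir>/checkpoints/ (by default ~/.cache/torch/hub/checkpoints/).
state_dict = torch.hub.load_state_dict_from_url(
    'https://github.com/leoluxxx/JamMa/releases/download/v0.1/convnextv2_nano_pretrain.ckpt',
    map_location='cpu',
    file_name='convnextv2_nano_pretrain.ckpt')

# Quick sanity check on what was downloaded before loading it into a model.
print(f"{len(state_dict)} tensors in checkpoint")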

scripts/reproduce_test/outdoor.sh

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ cd $PROJECT_DIR
 
 data_cfg_path="configs/data/megadepth_test_1500.py"
 main_cfg_path="configs/jamma/outdoor/test.py"
-ckpt_path='weight/jamma_weight.ckpt'
+ckpt_path='official' # your path or 'official' (load from github release)
 dump_dir="dump/jamma_outdoor"
 profiler_name="inference"
 n_nodes=1 # mannually keep this the same with --nodes
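Note that 'official' is not a filesystem path here: it is a sentinel value that the checkpoint-loading branch added to src/lightning/lightning_jamma.py (below) recognizes and resolves to a download from the GitHub release.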

src/jamma/backbone.py

Lines changed: 5 additions & 0 deletions

@@ -16,6 +16,11 @@ def __init__(self):
         self.cnn.stages[2] = None
         self.cnn.stages[3] = None
 
+        state_dict = torch.hub.load_state_dict_from_url(
+            'https://github.com/leoluxxx/JamMa/releases/download/v0.1/convnextv2_nano_pretrain.ckpt',
+            file_name='convnextv2_nano_pretrain.ckpt')
+        self.cnn.load_state_dict(state_dict, strict=True)
+
         self.lin_4 = nn.Conv2d(80, 128, 1)
         self.lin_8 = nn.Conv2d(160, 256, 1)
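Since the backbone now downloads its ConvNeXt V2 weights at construction time, the cache location may matter on shared or offline machines. A hedged sketch using torch.hub's cache controls; the directory name below is only an example:

import torch

# Point torch.hub's cache (and therefore load_state_dict_from_url downloads)
# at a project-local directory instead of the default ~/.cache/torch/hub.
torch.hub.set_dir('./.torch_hub')

# Constructing the backbone afterwards stores the checkpoint under
# ./.torch_hub/checkpoints/convnextv2_nano_pretrain.ckpt and reuses it on later runs.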

src/lightning/lightning_jamma.py

Lines changed: 8 additions & 2 deletions

@@ -39,8 +39,14 @@ def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
         self.matcher = JamMa(config=_config['jamma'], profiler=profiler)
         self.loss = Loss(_config)
 
-        if pretrained_ckpt:
-            state_dict = torch.load(pretrained_ckpt, map_location='cpu')
+        if pretrained_ckpt == 'official':
+            state_dict = torch.hub.load_state_dict_from_url(
+                'https://github.com/leoluxxx/JamMa/releases/download/v0.1/jamma.ckpt',
+                file_name='jamma.ckpt')['state_dict']
+            self.load_state_dict(state_dict, strict=True)
+            logger.info(f"Load Official JamMa Weight")
+        elif pretrained_ckpt:
+            state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
             self.load_state_dict(state_dict, strict=True)
             logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
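With this branch in place, the same pretrained_ckpt argument accepts either the 'official' sentinel or a local path. A usage sketch; the class name PL_JamMa and the config object are assumptions, since neither appears in this diff:

# Hypothetical names: PL_JamMa and config are not shown in this commit.
model = PL_JamMa(config, pretrained_ckpt='official')                 # downloads jamma.ckpt from the v0.1 release
model = PL_JamMa(config, pretrained_ckpt='weight/my_finetune.ckpt')  # loads ['state_dict'] from a local file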

weight/jamma_weight.ckpt

Binary file removed (21.6 MB).
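The bundled weight file is dropped because the same checkpoint is now fetched on demand from the v0.1 GitHub release (see the 'official' branch above).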
