Skip to content

Commit b8a31da

Browse files
authored
download pretrain if missing (tusen-ai#254)
1 parent 2a5e650 commit b8a31da

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

detection_train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def train_net(config):
129129
elif pModel.from_scratch:
130130
arg_params, aux_params = dict(), dict()
131131
else:
132+
if not os.path.exists("%s-%04d.params" % (pretrain_prefix, pretrain_epoch)):
133+
from utils.download_pretrain import download
134+
download(pretrain_prefix, pretrain_epoch)
132135
arg_params, aux_params = load_checkpoint(pretrain_prefix, pretrain_epoch)
133136

134137
if pModel.process_weight is not None:

utils/download_pretrain.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import os
2+
import urllib.request
3+
4+
5+
def report(block_count, block_size, content_size):
6+
if block_count % (content_size // block_size // 5) == 1:
7+
print("Downloaded %.1f/100" % (block_size * block_count / content_size * 100))
8+
9+
10+
def download(prefix, epoch):
11+
base_name = prefix.replace("pretrain_model/", "") + "-%04d.params" % epoch
12+
save_name = "%s-%04d.params" % (prefix, epoch)
13+
base_url = os.environ.get("SIMPLEDET_BASE_URL", "https://1dv.alarge.space/")
14+
full_url = base_url + base_name
15+
16+
try:
17+
print("Downloading %s from %s" % (save_name, full_url))
18+
urllib.request.urlretrieve(full_url, save_name, report)
19+
except Exception as e:
20+
print("Fail to download %s. You can mannually download it from %s and put it to %s" % (base_name, full_url, save_name))
21+
os.remove(save_name)
22+
raise e

0 commit comments

Comments
 (0)