Skip to content

Commit 78fdf65

Browse files
[Feature] add auto resume (open-mmlab#1172)
* [Feature] add auto resume
* Update mmseg/utils/find_latest_checkpoint.py
  Co-authored-by: Miao Zheng <[email protected]>
* Update mmseg/utils/find_latest_checkpoint.py
  Co-authored-by: Miao Zheng <[email protected]>
* modify docstring
* Update mmseg/utils/find_latest_checkpoint.py
  Co-authored-by: Miao Zheng <[email protected]>
* add copyright
  Co-authored-by: Miao Zheng <[email protected]>
1 parent af9ccd3 commit 78fdf65

File tree

5 files changed

+93
-2
lines changed

5 files changed

+93
-2
lines changed

mmseg/apis/train.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from mmseg.core import DistEvalHook, EvalHook
1313
from mmseg.datasets import build_dataloader, build_dataset
14-
from mmseg.utils import get_root_logger
14+
from mmseg.utils import find_latest_checkpoint, get_root_logger
1515

1616

1717
def init_random_seed(seed=None, device='cuda'):
@@ -160,6 +160,10 @@ def train_segmentor(model,
160160
hook = build_from_cfg(hook_cfg, HOOKS)
161161
runner.register_hook(hook, priority=priority)
162162

163+
if cfg.resume_from is None and cfg.get('auto_resume'):
164+
resume_from = find_latest_checkpoint(cfg.work_dir)
165+
if resume_from is not None:
166+
cfg.resume_from = resume_from
163167
if cfg.resume_from:
164168
runner.resume(cfg.resume_from)
165169
elif cfg.load_from:

mmseg/utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .collect_env import collect_env
33
from .logger import get_root_logger
4+
from .misc import find_latest_checkpoint
45

5-
__all__ = ['get_root_logger', 'collect_env']
6+
__all__ = ['get_root_logger', 'collect_env', 'find_latest_checkpoint']

mmseg/utils/misc.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import glob
3+
import os.path as osp
4+
import warnings
5+
6+
7+
def find_latest_checkpoint(path, suffix='pth'):
    """Find the latest checkpoint inside a directory.

    Used when training is resumed automatically. Modified from
    https://github.com/open-mmlab/mmdetection/blob/dev-v2.20.0/mmdet/utils/misc.py

    If ``latest.{suffix}`` exists in ``path`` it is returned directly.
    Otherwise every ``*.{suffix}`` file is inspected and the one whose name
    ends with the largest integer (the ``xx`` in ``iter_xx.pth`` or
    ``epoch_xx.pth``) is returned. Files whose name does not end with an
    integer are ignored.

    Args:
        path (str): The path to find checkpoints.
        suffix (str): File extension for the checkpoint. Defaults to pth.

    Returns:
        latest_path(str | None): File path of the latest checkpoint, or
            None when the directory is missing or holds no usable
            checkpoint.
    """
    if not osp.exists(path):
        warnings.warn("The path of the checkpoints doesn't exist.")
        return None

    latest_file = osp.join(path, f'latest.{suffix}')
    if osp.exists(latest_file):
        return latest_file

    checkpoints = glob.glob(osp.join(path, f'*.{suffix}'))
    if not checkpoints:
        warnings.warn('There are no checkpoints in the path')
        return None

    latest = -1
    latest_path = None
    for checkpoint in checkpoints:
        # `count` is iteration number, as checkpoints are saved as
        # 'iter_xx.pth' or 'epoch_xx.pth' and xx is iteration number.
        tail = osp.basename(checkpoint).split('_')[-1].split('.')[0]
        try:
            count = int(tail)
        except ValueError:
            # Not a numbered checkpoint (e.g. 'best.pth'); it carries no
            # ordering information, so it cannot be the latest one.
            continue
        # Compare the parsed numbers only: path length is not a reliable
        # proxy once names are zero-padded or use different prefixes.
        if count > latest:
            latest = count
            latest_path = checkpoint

    if latest_path is None:
        warnings.warn('There are no checkpoints in the path')
    return latest_path

tests/test_utils/test_misc.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
import os.path as osp
3+
import tempfile
4+
5+
from mmseg.utils import find_latest_checkpoint
6+
7+
8+
def test_find_latest_checkpoint():
    """Exercise each lookup branch of ``find_latest_checkpoint``."""
    with tempfile.TemporaryDirectory() as workdir:
        # an existing directory that holds no checkpoint at all
        assert find_latest_checkpoint(workdir) is None

        # a directory that does not exist
        missing_dir = osp.join(workdir, 'none')
        assert find_latest_checkpoint(missing_dir) is None

    # latest.pth is returned whenever it is present
    with tempfile.TemporaryDirectory() as workdir:
        latest_file = osp.join(workdir, 'latest.pth')
        with open(latest_file, 'w') as f:
            f.write('latest')
        assert find_latest_checkpoint(workdir) == latest_file

    # iteration-based checkpoints: the highest iteration is selected
    with tempfile.TemporaryDirectory() as workdir:
        for step in range(1600, 160001, 1600):
            filename = f'iter_{step}.pth'
            with open(osp.join(workdir, filename), 'w') as f:
                f.write(filename)
        found = find_latest_checkpoint(workdir)
        assert found == osp.join(workdir, 'iter_160000.pth')

    # epoch-based checkpoints: the highest epoch is selected
    with tempfile.TemporaryDirectory() as workdir:
        for epoch in range(1, 21):
            filename = f'epoch_{epoch}.pth'
            with open(osp.join(workdir, filename), 'w') as f:
                f.write(filename)
        found = find_latest_checkpoint(workdir)
        assert found == osp.join(workdir, 'epoch_20.pth')

tools/train.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def parse_args():
7575
default='none',
7676
help='job launcher')
7777
parser.add_argument('--local_rank', type=int, default=0)
78+
parser.add_argument(
79+
'--auto-resume',
80+
action='store_true',
81+
help='resume from the latest checkpoint automatically.')
7882
args = parser.parse_args()
7983
if 'LOCAL_RANK' not in os.environ:
8084
os.environ['LOCAL_RANK'] = str(args.local_rank)
@@ -118,6 +122,7 @@ def main():
118122
cfg.gpu_ids = args.gpu_ids
119123
else:
120124
cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
125+
cfg.auto_resume = args.auto_resume
121126

122127
# init distributed env first, since logger depends on the dist info.
123128
if args.launcher == 'none':

0 commit comments

Comments
 (0)