Skip to content

Commit bfc3cdb

Browse files
authored
[Fix] Update digit_version (open-mmlab#778)
* update digit_version * add unittest * fix import
1 parent 58f5dbc commit bfc3cdb

File tree

5 files changed

+66
-14
lines changed

5 files changed

+66
-14
lines changed

mmseg/__init__.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,52 @@
1+
import warnings
2+
13
import mmcv
4+
from packaging.version import parse
25

36
from .version import __version__, version_info
47

58
MMCV_MIN = '1.3.7'
69
MMCV_MAX = '1.4.0'
710

811

9-
def digit_version(version_str):
10-
digit_version = []
11-
for x in version_str.split('.'):
12-
if x.isdigit():
13-
digit_version.append(int(x))
14-
elif x.find('rc') != -1:
15-
patch_version = x.split('rc')
16-
digit_version.append(int(patch_version[0]) - 1)
17-
digit_version.append(int(patch_version[1]))
18-
return digit_version
12+
def digit_version(version_str: str, length: int = 4):
13+
"""Convert a version string into a tuple of integers.
14+
15+
This method is usually used for comparing two versions. For pre-release
16+
versions: alpha < beta < rc.
17+
18+
Args:
19+
version_str (str): The version string.
20+
length (int): The maximum number of version levels. Default: 4.
21+
22+
Returns:
23+
tuple[int]: The version info in digits (integers).
24+
"""
25+
version = parse(version_str)
26+
assert version.release, f'failed to parse version {version_str}'
27+
release = list(version.release)
28+
release = release[:length]
29+
if len(release) < length:
30+
release = release + [0] * (length - len(release))
31+
if version.is_prerelease:
32+
mapping = {'a': -3, 'b': -2, 'rc': -1}
33+
val = -4
34+
# version.pre can be None
35+
if version.pre:
36+
if version.pre[0] not in mapping:
37+
warnings.warn(f'unknown prerelease version {version.pre[0]}, '
38+
'version checking may go wrong')
39+
else:
40+
val = mapping[version.pre[0]]
41+
release.extend([val, version.pre[-1]])
42+
else:
43+
release.extend([val, 0])
44+
45+
elif version.is_postrelease:
46+
release.extend([1, version.post])
47+
else:
48+
release.extend([0, 0])
49+
return tuple(release)
1950

2051

2152
mmcv_min_version = digit_version(MMCV_MIN)
@@ -27,4 +58,4 @@ def digit_version(version_str):
2758
f'MMCV=={mmcv.__version__} is used but incompatible. ' \
2859
f'Please install mmcv>={mmcv_min_version}, <={mmcv_max_version}.'
2960

30-
__all__ = ['__version__', 'version_info']
61+
__all__ = ['__version__', 'version_info', 'digit_version']

mmseg/datasets/builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from mmcv.parallel import collate
99
from mmcv.runner import get_dist_info
10-
from mmcv.utils import Registry, build_from_cfg
10+
from mmcv.utils import Registry, build_from_cfg, digit_version
1111
from torch.utils.data import DataLoader, DistributedSampler
1212

1313
if platform.system() != 'Windows':
@@ -133,7 +133,7 @@ def build_dataloader(dataset,
133133
worker_init_fn, num_workers=num_workers, rank=rank,
134134
seed=seed) if seed is not None else None
135135

136-
if torch.__version__ >= '1.8.0':
136+
if digit_version(torch.__version__) >= digit_version('1.8.0'):
137137
data_loader = DataLoader(
138138
dataset,
139139
batch_size=batch_size,

requirements/runtime.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
matplotlib
22
numpy
3+
packaging
34
prettytable

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,6 @@ line_length = 79
88
multi_line_output = 0
99
known_standard_library = setuptools
1010
known_first_party = mmseg
11-
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,prettytable,pytest,scipy,seaborn,torch,ts
11+
known_third_party = PIL,cityscapesscripts,cv2,detail,matplotlib,mmcv,numpy,onnxruntime,oss2,packaging,prettytable,pytest,scipy,seaborn,torch,ts
1212
no_lines_before = STDLIB,LOCALFOLDER
1313
default_section = THIRDPARTY

tests/test_digit_version.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from mmseg import digit_version
2+
3+
4+
def test_digit_version():
5+
assert digit_version('0.2.16') == (0, 2, 16, 0, 0, 0)
6+
assert digit_version('1.2.3') == (1, 2, 3, 0, 0, 0)
7+
assert digit_version('1.2.3rc0') == (1, 2, 3, 0, -1, 0)
8+
assert digit_version('1.2.3rc1') == (1, 2, 3, 0, -1, 1)
9+
assert digit_version('1.0rc0') == (1, 0, 0, 0, -1, 0)
10+
assert digit_version('1.0') == digit_version('1.0.0')
11+
assert digit_version('1.5.0+cuda90_cudnn7.6.3_lms') == digit_version('1.5')
12+
assert digit_version('1.0.0dev') < digit_version('1.0.0a')
13+
assert digit_version('1.0.0a') < digit_version('1.0.0a1')
14+
assert digit_version('1.0.0a') < digit_version('1.0.0b')
15+
assert digit_version('1.0.0b') < digit_version('1.0.0rc')
16+
assert digit_version('1.0.0rc1') < digit_version('1.0.0')
17+
assert digit_version('1.0.0') < digit_version('1.0.0post')
18+
assert digit_version('1.0.0post') < digit_version('1.0.0post1')
19+
assert digit_version('v1') == (1, 0, 0, 0, 0, 0)
20+
assert digit_version('v1.1.5') == (1, 1, 5, 0, 0, 0)

0 commit comments

Comments
 (0)