Skip to content

Commit d224f03

Browse files
authored
Merge pull request CSAILVision#188 from CSAILVision/hang
fix compatibility with torch1.1 and scipy1.3
2 parents 2765068 + 9056343 commit d224f03

File tree

6 files changed

+32
-9
lines changed

6 files changed

+32
-9
lines changed

dataset.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,29 @@
11
import os
22
import json
33
import torch
4+
import lib.utils.data as torchdata
45
import cv2
56
from torchvision import transforms
67
import numpy as np
7-
from scipy.misc import imresize
8+
import PIL
89

910

10-
class BaseDataset(torch.utils.data.Dataset):
11+
def imresize(im, size, interp='bilinear'):
12+
if interp == 'nearest':
13+
resample = PIL.Image.NEAREST
14+
elif interp == 'bilinear':
15+
resample = PIL.Image.BILINEAR
16+
elif interp == 'bicubic':
17+
resample = PIL.Image.BICUBIC
18+
else:
19+
raise Exception('resample method undefined!')
20+
21+
return np.array(
22+
PIL.Image.fromarray(im).resize((size[1], size[0]), resample)
23+
)
24+
25+
26+
class BaseDataset(torchdata.Dataset):
1127
def __init__(self, odgt, opt, **kwargs):
1228
# parse options
1329
self.imgSizes = opt.imgSizes

eval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, setup_logger
1616
from lib.nn import user_scattered_collate, async_copy_to
1717
from lib.utils import as_numpy
18+
import lib.utils.data as torchdata
1819
import cv2
1920
from tqdm import tqdm
2021

@@ -132,7 +133,7 @@ def main(cfg, gpu):
132133
cfg.DATASET.root_dataset,
133134
cfg.DATASET.list_val,
134135
cfg.DATASET)
135-
loader_val = torch.utils.data.DataLoader(
136+
loader_val = torchdata.DataLoader(
136137
dataset_val,
137138
batch_size=cfg.VAL.batch_size,
138139
shuffle=False,

eval_multipro.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from utils import AverageMeter, colorEncode, accuracy, intersectionAndUnion, parse_devices, setup_logger
1717
from lib.nn import user_scattered_collate, async_copy_to
1818
from lib.utils import as_numpy
19+
import lib.utils.data as torchdata
1920
import cv2
2021
from tqdm import tqdm
2122

@@ -93,7 +94,7 @@ def worker(cfg, gpu_id, start_idx, end_idx, result_queue):
9394
cfg.DATASET.list_val,
9495
cfg.DATASET,
9596
start_idx=start_idx, end_idx=end_idx)
96-
loader_val = torch.utils.data.DataLoader(
97+
loader_val = torchdata.DataLoader(
9798
dataset_val,
9899
batch_size=cfg.VAL.batch_size,
99100
shuffle=False,

lib/utils/data/dataloader.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import torch
22
import torch.multiprocessing as multiprocessing
3-
from torch._C import _set_worker_signal_handlers, _update_worker_pids, \
3+
from torch._C import _set_worker_signal_handlers, \
44
_remove_worker_pids, _error_if_any_worker_fails
5+
try:
6+
from torch._C import _set_worker_pids
7+
except:
8+
from torch._C import _update_worker_pids as _set_worker_pids
59
from .sampler import SequentialSampler, RandomSampler, BatchSampler
610
import signal
7-
import functools
811
import collections
912
import re
1013
import sys
@@ -235,7 +238,7 @@ def __init__(self, loader):
235238
w.daemon = True # ensure that the worker exits on process exit
236239
w.start()
237240

238-
_update_worker_pids(id(self), tuple(w.pid for w in self.workers))
241+
_set_worker_pids(id(self), tuple(w.pid for w in self.workers))
239242
_set_SIGCHLD_handler()
240243
self.worker_pids_set = True
241244

test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from utils import colorEncode, find_recursive, setup_logger
1515
from lib.nn import user_scattered_collate, async_copy_to
1616
from lib.utils import as_numpy
17+
import lib.utils.data as torchdata
1718
import cv2
1819
from tqdm import tqdm
1920
from config import cfg
@@ -115,7 +116,7 @@ def main(cfg, gpu):
115116
dataset_test = TestDataset(
116117
cfg.list_test,
117118
cfg.DATASET)
118-
loader_test = torch.utils.data.DataLoader(
119+
loader_test = torchdata.DataLoader(
119120
dataset_test,
120121
batch_size=cfg.TEST.batch_size,
121122
shuffle=False,

train.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from models import ModelBuilder, SegmentationModule
1515
from utils import AverageMeter, parse_devices, setup_logger
1616
from lib.nn import UserScatteredDataParallel, user_scattered_collate, patch_replication_callback
17+
import lib.utils.data as torchdata
1718

1819

1920
# train one epoch
@@ -168,7 +169,7 @@ def main(cfg, gpus):
168169
cfg.DATASET,
169170
batch_per_gpu=cfg.TRAIN.batch_size_per_gpu)
170171

171-
loader_train = torch.utils.data.DataLoader(
172+
loader_train = torchdata.DataLoader(
172173
dataset_train,
173174
batch_size=len(gpus), # we have modified data_parallel
174175
shuffle=False, # we do not use this param

0 commit comments

Comments
 (0)