Skip to content

Commit adc7296

Browse files
committed
id_gpu
1 parent 6f606d0 commit adc7296

File tree

5 files changed

+16
-19
lines changed

5 files changed

+16
-19
lines changed

cuda_backend/kernel/test.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
extern "C" {
44

5-
Handle* init_2D_handle(size_t y, size_t x, int mode_type, float c_val){
6-
Handle *ret = new Handle(mode_type, c_val);
5+
Handle* init_2D_handle(size_t y, size_t x, int mode_type, float c_val, int id_gpu){
6+
Handle *ret = new Handle(mode_type, c_val, id_gpu);
77
ret->set_2D(y, x);
88
return ret;
99
}
1010

11-
Handle* init_3D_handle(size_t z, size_t y, size_t x, int mode_type, float c_val){
12-
Handle *ret = new Handle(mode_type, c_val);
11+
Handle* init_3D_handle(size_t z, size_t y, size_t x, int mode_type, float c_val, int id_gpu){
12+
Handle *ret = new Handle(mode_type, c_val, id_gpu);
1313
ret->set_3D(z, y, x);
1414
return ret;
1515
}

cuda_backend/kernel/utils.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,9 @@ extern "C"{
1818

1919
class Handle {
2020
public:
21-
Handle(int mode_type, float c_val) : batchsize(1), dim_x(1), dim_y(1),
21+
Handle(int mode_type, float c_val, int id_gpu) : batchsize(1), dim_x(1), dim_y(1),
2222
dim_z(1), mode_type(mode_type), c_val(c_val){
23+
checkCudaErrors(cudaSetDevice(id_gpu));
2324
checkCudaErrors(cudaStreamCreate(&stream));
2425
checkCudaErrors(curandCreateGenerator(&gen,
2526
CURAND_RNG_PSEUDO_DEFAULT));

cuda_backend/py_api.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
lib = CDLL(lib_dir + "/libcudaAugmentation.so", RTLD_GLOBAL)
88

99
init_2D = lib.init_2D_handle
10-
init_2D.argtypes = [c_int, c_int, c_int, c_float]
10+
init_2D.argtypes = [c_int, c_int, c_int, c_float, c_int]
1111
init_2D.restype = c_void_p
1212

1313
init_3D = lib.init_3D_handle
14-
init_3D.argtypes = [c_int, c_int, c_int, c_int, c_float]
14+
init_3D.argtypes = [c_int, c_int, c_int, c_int, c_float, c_int]
1515
init_3D.restype = c_void_p
1616

1717
l_i = lib.linear_interpolate
@@ -167,7 +167,7 @@ def defrom(self, handle):
167167
return None
168168

169169
class Handle(object):
170-
def __init__(self, shape, RGB=False, mode='constant', cval=0.0):
170+
def __init__(self, shape, RGB=False, mode='constant', cval=0.0, id_gpu=0):
171171
self.RGB = RGB
172172
self.shape = shape
173173
if self.RGB:
@@ -193,11 +193,11 @@ def __init__(self, shape, RGB=False, mode='constant', cval=0.0):
193193
raise ValueError
194194

195195
if(len(shape) == 2 or RGB):
196-
self.cuda_handle = init_2D(self.shape[0], self.shape[1], type_mode, float(cval))
196+
self.cuda_handle = init_2D(self.shape[0], self.shape[1], type_mode, float(cval), id_gpu)
197197
self.is_3D = False
198198
else:
199199
self.is_3D = True
200-
self.cuda_handle = init_3D(shape[0], shape[1], shape[2], type_mode, float(cval))
200+
self.cuda_handle = init_3D(shape[0], shape[1], shape[2], type_mode, float(cval), id_gpu)
201201

202202
def augment(self, img):
203203
if self.RGB:

deform.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
'''
2+
Reference to https://github.com/MIC-DKFZ/batchgenerators
3+
'''
14

25
import numpy as np
36
import SimpleITK as sitk
@@ -68,8 +71,8 @@ def spatial_augment(img, RGB=False, do_scale=True, scale=0.5, angle=0.75*np.pi,
6871
coords = create_zero_centered_coordinate_mesh(img.shape)
6972

7073
# coords = scale_coords(coords, scale)
71-
# coords = rotate_coords_3d(coords, angle, angle, angle)
72-
coords = elastic_deform_coordinates(coords, 500, 12)
74+
coords = rotate_coords_3d(coords, angle, angle, angle)
75+
# coords = elastic_deform_coordinates(coords, 500, 12)
7376

7477
for d in range(len(img.shape)):
7578
ctr = float(np.round(img.shape[d] / 2.))

python_augmentation.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,6 @@
1212

1313
np.set_printoptions(precision=3)
1414

15-
def create_zero_centered_coordinate_mesh(shape):
16-
tmp = tuple([np.arange(i) for i in shape])
17-
coords = np.array(np.meshgrid(*tmp, indexing='ij')).astype(float)
18-
for d in range(len(shape)):
19-
coords[d] -= ((np.array(shape).astype(float)) / 2.)[d]
20-
return coords
21-
2215
def check(correct, output):
2316
'''
2417
Unit Test Pass When less than 0.01 rate pixels loss ( > 0.001)

0 commit comments

Comments
 (0)