Skip to content

Commit 6fdd70a

Browse files
committed
Translate Done
1 parent 66e8d7b commit 6fdd70a

File tree

7 files changed

+99
-5
lines changed

7 files changed

+99
-5
lines changed

cuda_backend/kernel/spatial_deform.cu

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,37 @@ __global__ void flip_3D(float* coords,
8888
}
8989
}
9090
}
91+
92+
__global__ void translate_3D(float* coords,
93+
size_t dim_z,
94+
size_t dim_y,
95+
size_t dim_x,
96+
float seg_z,
97+
float seg_y,
98+
float seg_x){
99+
size_t index = blockIdx.x * blockDim.x + threadIdx.x;
100+
size_t total = dim_x * dim_y * dim_z;
101+
if(index < total){
102+
coords[index] += seg_z;
103+
coords[index + total] += seg_y;
104+
coords[index + total * 2] += seg_x;
105+
__syncthreads();
106+
}
107+
}
108+
109+
__global__ void translate_2D(float* coords,
110+
size_t dim_y,
111+
size_t dim_x,
112+
float seg_y,
113+
float seg_x){
114+
size_t index = blockIdx.x * blockDim.x + threadIdx.x;
115+
size_t total = dim_x * dim_y;
116+
if(index < total){
117+
coords[index] += seg_y;
118+
coords[index + total] += seg_x;
119+
__syncthreads();
120+
}
121+
}
122+
123+
124+

cuda_backend/kernel/spatial_deform.cuh

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,4 +26,19 @@ __global__ void flip_3D(float* coords,
2626
int do_y,
2727
int do_x);
2828

29+
30+
__global__ void translate_3D(float* coords,
31+
size_t dim_z,
32+
size_t dim_y,
33+
size_t dim_x,
34+
float seg_z,
35+
float seg_y,
36+
float seg_x);
37+
38+
__global__ void translate_2D(float* coords,
39+
size_t dim_y,
40+
size_t dim_x,
41+
float seg_y,
42+
float seg_x);
43+
2944
#endif

cuda_backend/kernel/test.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ void cu_flip(Handle* cuda_handle, int do_x, int do_y, int do_z){
3838
cuda_handle->flip(do_x, do_y, do_z);
3939
}
4040

41+
void cu_translate(Handle* cuda_handle, float seg_x, float seg_y, float seg_z){
42+
cuda_handle->translate(seg_x, seg_y, seg_z);
43+
}
44+
4145
void endding_flag(Handle* cuda_handle){
4246
cuda_handle->recenter();
4347
}

cuda_backend/kernel/utils.cu

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,22 @@ void Handle::flip(int do_x, int do_y, int do_z){
8484
}
8585
}
8686

87+
void Handle::translate(float seg_x, float seg_y, float seg_z){
88+
if(is_3D){
89+
dim3 threads(min(total_size, (long)512), 1, 1);
90+
dim3 blocks(total_size/512 + 1, 1, 1);
91+
translate_3D<<<blocks, threads, 0, stream>>>(coords, dim_z, dim_y, dim_x,
92+
seg_z, seg_y, seg_x);
93+
checkCudaErrors(cudaStreamSynchronize(stream));
94+
}
95+
else{
96+
dim3 threads(min(total_size, (long)512), 1, 1);
97+
dim3 blocks(total_size/512 + 1, 1, 1);
98+
translate_2D<<<blocks, threads, 0, stream>>>(coords, dim_y, dim_x, seg_y, seg_x);
99+
checkCudaErrors(cudaStreamSynchronize(stream));
100+
}
101+
}
102+
87103
void Handle::set_3D(size_t z, size_t y, size_t x){
88104
is_3D = true;
89105
dim_x = x;

cuda_backend/kernel/utils.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ public:
4040

4141
void flip(int do_x, int do_y, int do_z=0);
4242

43+
void translate(float seg_x=0, float seg_y=0, float seg_z=0);
44+
4345
~Handle(){
4446
checkCudaErrors(cudaFree(img));
4547
checkCudaErrors(cudaFree(output));

cuda_backend/py_api.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
flip = lib.cu_flip
3030
flip.argtypes = [c_void_p, c_int, c_int, c_int]
3131

32+
translate = lib.cu_translate
33+
translate.argtypes = [c_void_p, c_float, c_float, c_float]
34+
3235
class Spatial_Deform(object):
3336
def __init__(self, prob=1.0):
3437
self.prob = prob
@@ -71,6 +74,21 @@ def defrom(self, handle):
7174
else:
7275
return None
7376

77+
class Translate(Spatial_Deform):
78+
def __init__(self, seg_x=0.0, seg_y=0.0, seg_z=0.0, prob=1.0):
79+
Spatial_Deform.__init__(self, prob)
80+
self.label = 'Translate'
81+
self.seg_x = seg_x
82+
self.seg_y = seg_y
83+
self.seg_z = seg_z
84+
85+
def defrom(self, handle):
86+
if np.random.uniform() < self.prob:
87+
translate(handle, self.seg_x, self.seg_y, self.seg_z)
88+
return self.label
89+
else:
90+
return None
91+
7492
class End_Flag(Spatial_Deform):
7593
def __init__(self, prob=1.0):
7694
Spatial_Deform.__init__(self, prob)
@@ -106,8 +124,8 @@ def augment(self, img):
106124
output = np.ones(img.shape).astype(np.float32)
107125
labels = self.deform_coords()
108126

109-
# check coords
110-
self.get_coords()
127+
# # check coords
128+
# self.get_coords()
111129

112130
if not self.RGB:
113131
l_i(self.cuda_handle, output, img, 1)
@@ -126,6 +144,9 @@ def scale(self, sc, prob=1.0):
126144
def flip(self, do_x=False, do_y=False, do_z=False, prob=1.0):
127145
self.deform_list.append(Flip(do_x, do_y, do_z, prob))
128146

147+
def translate(self, seg_x=0.0, seg_y=0.0, seg_z=0.0, prob=1.0):
148+
self.deform_list.append(Translate(seg_x, seg_y, seg_z, prob))
149+
129150
def end_flag(self):
130151
self.deform_list.append(End_Flag())
131152

python_augmentation.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def test_3D():
3535

3636
cuda_handle = Handle(array_image.shape)
3737
# cuda_handle.scale(0.5)
38-
cuda_handle.flip(do_y=True, do_x=True, do_z=True)
38+
# cuda_handle.flip(do_y=True, do_x=True, do_z=True)
39+
cuda_handle.translate(100, 100, 20)
3940
cuda_handle.end_flag()
4041

4142
correct_ret = deform.spatial_augment(array_image)
@@ -68,7 +69,8 @@ def test_2D():
6869

6970
cuda_handle = Handle(array_image.shape, RGB=True)
7071
# cuda_handle.scale(0.5)
71-
cuda_handle.flip(do_y=True)
72+
# cuda_handle.flip(do_y=True)
73+
cuda_handle.translate(400, 400)
7274
cuda_handle.end_flag()
7375

7476
if len(array_image.shape) == 2:
@@ -115,5 +117,5 @@ def test_2D():
115117
(end - start) * 1000 / Iters_CPU))
116118

117119
if __name__ == "__main__":
118-
test_3D()
120+
# test_3D()
119121
test_2D()

0 commit comments

Comments
 (0)