Skip to content

Commit 75acb8b

Browse files
committed
pin coords
1 parent 1604893 commit 75acb8b

File tree

3 files changed

+19
-17
lines changed

3 files changed

+19
-17
lines changed

cuda_backend/kernel/utils.cu

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ void Handle::set_2D(size_t y, size_t x){
4747
checkCudaErrors(cudaMallocHost((void **)&pin_output,
4848
total_size * sizeof(float)));
4949

50-
checkCudaErrors(cudaMallocHost((void **)&coords,
50+
checkCudaErrors(cudaMalloc((void **)&coords,
51+
2 * total_size * sizeof(float)));
52+
checkCudaErrors(cudaMallocHost((void **)&pin_coords,
5153
2 * total_size * sizeof(float)));
5254

5355
dim3 threads(min(total_size, (long)512), 1, 1);
@@ -81,9 +83,11 @@ void Handle::set_3D(size_t z, size_t y, size_t x){
8183
checkCudaErrors(cudaMallocHost((void **)&pin_output,
8284
total_size * sizeof(float)));
8385

84-
checkCudaErrors(cudaMallocHost((void **)&coords,
86+
checkCudaErrors(cudaMalloc((void **)&coords,
8587
3 * total_size * sizeof(float)));
86-
88+
checkCudaErrors(cudaMallocHost((void **)&pin_coords,
89+
3 * total_size * sizeof(float)));
90+
8791
dim3 threads(min(total_size, (long)512), 1, 1);
8892
dim3 blocks(total_size/512 + 1, 1, 1);
8993
set_coords_3D<<<blocks, threads>>>(coords, dim_z, dim_y, dim_x);
@@ -110,19 +114,14 @@ void Handle::copy_output(float* ret){
110114
void Handle::check_coords(float* output){
111115
float* pin;
112116
if(is_3D){
113-
checkCudaErrors(cudaMallocHost((void **)&pin,
114-
3 * total_size * sizeof(float)));
115-
checkCudaErrors(cudaMemcpyAsync(pin, coords, 3 * total_size * sizeof(float),
117+
checkCudaErrors(cudaMemcpyAsync(pin_coords, coords, 3 * total_size * sizeof(float),
116118
cudaMemcpyDeviceToHost));
117-
memcpy(output, pin, 3 * total_size * sizeof(float));
119+
memcpy(output, pin_coords, 3 * total_size * sizeof(float));
118120
}
119121
else{
120-
checkCudaErrors(cudaMallocHost((void **)&pin,
121-
2 * total_size * sizeof(float)));
122-
checkCudaErrors(cudaMemcpyAsync(pin, coords, 2 * total_size * sizeof(float),
122+
checkCudaErrors(cudaMemcpyAsync(pin_coords, coords, 2 * total_size * sizeof(float),
123123
cudaMemcpyDeviceToHost));
124-
memcpy(output, pin, 2 * total_size * sizeof(float));
124+
memcpy(output, pin_coords, 2 * total_size * sizeof(float));
125125
}
126-
checkCudaErrors(cudaFreeHost(pin));
127126
}
128127

cuda_backend/kernel/utils.cuh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public:
2727
void copy_output(float* ret);
2828

2929
void check_coords(float* coords);
30-
30+
3131
~Handle(){
3232
checkCudaErrors(cudaFree(img));
3333
checkCudaErrors(cudaFree(output));
@@ -46,6 +46,7 @@ private:
4646
float* pin_output;
4747

4848
float* coords;
49+
float* pin_coords;
4950
};
5051

5152
}

python_augmentation.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,12 @@ def create_zero_centered_coordinate_mesh(shape):
2424
test = output == array_image
2525
print(test)
2626

27-
28-
coords = cuda_handle.check_coords()
29-
cor_coords = create_zero_centered_coordinate_mesh(array_image.shape)
30-
import ipdb; ipdb.set_trace()
27+
start = time.time()
28+
for i in range(Iters):
29+
coords = cuda_handle.check_coords()
30+
end = time.time()
31+
print("Get Coords for Shape:{} Cost {}ms".format(array_image.shape, \
32+
(end - start) * 1000 / Iters))
3133

3234
start = time.time()
3335
for i in range(Iters):

0 commit comments

Comments
 (0)