Skip to content

Commit ed91ca1

Browse files
committed
Add bool csr element-wise addition algo
1 parent e8e5dbf commit ed91ca1

File tree

1 file changed

+343
-0
lines changed

1 file changed

+343
-0
lines changed
Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
#include <iostream>
2+
#include <inttypes.h>
3+
#include <vector>
4+
5+
#define CL_HPP_ENABLE_EXCEPTIONS
6+
#define CL_HPP_MINIMUM_OPENCL_VERSION BUILD_CLVERSION
7+
#define CL_HPP_TARGET_OPENCL_VERSION BUILD_CLVERSION
8+
#include <CL/cl2.hpp>
9+
10+
#include "clSPARSE.h"
11+
#include "clSPARSE-error.h"
12+
13+
#include "include/clSPARSE-private.hpp"
14+
#include "internal/clsparse-control.hpp"
15+
#include "internal/kernel-cache.hpp"
16+
#include "internal/kernel-wrap.hpp"
17+
#include "internal/clsparse-internal.hpp"
18+
19+
#define GROUPSIZE_256 256
20+
#define WG_SIZE 64
21+
#define TUPLE_QUEUE 6
22+
#define NUM_SEGMENTS 128
23+
//#define WARPSIZE_NV_2HEAP 64
24+
#define value_type float
25+
#define index_type int
26+
#define MERGEPATH_LOCAL 0
27+
#define MERGEPATH_LOCAL_L2 1
28+
#define MERGEPATH_GLOBAL 2
29+
#define MERGELIST_INITSIZE 256
30+
#define BHSPARSE_SUCCESS 0
31+
32+
int run_bool_scan(
33+
cl_mem csrRowPtrA,
34+
cl_mem csrColIndA,
35+
cl_mem csrRowPtrB,
36+
cl_mem csrColIndB,
37+
cl_mem csrRowPtrCt_d,
38+
uint32_t &total_sum,
39+
int m,
40+
cl::Context context,
41+
const clsparseControl control)
42+
{
43+
cl_int cl_status;
44+
45+
uint array_size = m + 1;
46+
uint work_group_size = uint32_t(256);
47+
uint block_size = work_group_size;
48+
uint32_t a_size = (array_size + block_size - 1) / block_size; // max to save first roots
49+
uint32_t b_size = (a_size + block_size - 1) / block_size; // max to save second roots
50+
std::cout << "sizes " << array_size << ' ' << a_size << ' ' << b_size << std::endl;
51+
52+
cl_mem a_gpu = ::clCreateBuffer(context(), CL_MEM_READ_WRITE, sizeof(uint32_t) * a_size, NULL, &cl_status);
53+
cl_mem b_gpu = ::clCreateBuffer(context(), CL_MEM_READ_WRITE, sizeof(uint32_t) * b_size, NULL, &cl_status);
54+
cl_mem total_sum_gpu = ::clCreateBuffer(context(), CL_MEM_READ_WRITE, sizeof(uint32_t), NULL, &cl_status);
55+
56+
cl::LocalSpaceArg local_array = cl::Local(sizeof(uint32_t) * block_size);
57+
58+
const std::string params = std::string() +
59+
"-DINDEX_TYPE=" + OclTypeTraits<cl_int>::type
60+
+ " -DVALUE_TYPE=" + OclTypeTraits<cl_int>::type;
61+
62+
cl::Kernel kernel_scan = KernelCache::get(control->queue, "bool_csradd_scan", "scan_blelloch", params);
63+
cl::Kernel kernel_update = KernelCache::get(control->queue, "bool_csradd_scan", "update_pref_sum", params);
64+
65+
size_t szLocalWorkSize[1];
66+
size_t szGlobalWorkSize[1];
67+
68+
szLocalWorkSize[0] = work_group_size;
69+
szGlobalWorkSize[0] = (array_size + work_group_size - 1) / work_group_size * work_group_size;
70+
71+
cl::NDRange local(szLocalWorkSize[0]);
72+
cl::NDRange global(szGlobalWorkSize[0]);
73+
74+
KernelWrap kWrapper_scan(kernel_scan);
75+
KernelWrap kWrapper_update(kernel_update);
76+
77+
uint leaf_size = 1;
78+
79+
kWrapper_scan << a_gpu << csrRowPtrCt_d << local_array << total_sum_gpu << array_size;
80+
81+
cl_status = kWrapper_scan.run(control, global, local);
82+
83+
if (cl_status != CL_SUCCESS)
84+
{
85+
return clsparseInvalidKernelExecution;
86+
}
87+
88+
89+
uint32_t outer = (array_size + block_size - 1) / block_size;
90+
91+
cl_mem *a_gpu_ptr = &a_gpu;
92+
cl_mem *b_gpu_ptr = &b_gpu;
93+
94+
unsigned int *a_size_ptr = &a_size;
95+
unsigned int *b_size_ptr = &b_size;
96+
97+
clEnqueueReadBuffer(control->queue(), total_sum_gpu, CL_TRUE, 0, sizeof(uint32_t), &total_sum, 0, NULL, NULL);
98+
99+
std::cout << "INNER TOTAL SUM: " << total_sum << std::endl;
100+
101+
while (outer > 1) {
102+
leaf_size *= block_size;
103+
104+
std::cout << "META: " << (outer + work_group_size - 1) / work_group_size * work_group_size << std::endl;
105+
size_t rec_szLocalWorkSize[1];
106+
size_t rec_szGlobalWorkSize[1];
107+
108+
rec_szLocalWorkSize[0] = work_group_size;
109+
rec_szGlobalWorkSize[0] = (outer + work_group_size - 1) / work_group_size * work_group_size;
110+
111+
cl::NDRange rec_local(rec_szLocalWorkSize[0]);
112+
cl::NDRange rec_global(rec_szGlobalWorkSize[0]);
113+
114+
std::cout << "scan " << std::endl;
115+
116+
kWrapper_scan.reset();
117+
kWrapper_scan << *b_gpu_ptr << *a_gpu_ptr << local_array << total_sum_gpu << outer;
118+
119+
cl_status = kWrapper_scan.run(control, rec_global, rec_local);
120+
121+
if (cl_status != CL_SUCCESS)
122+
{
123+
return clsparseInvalidKernelExecution;
124+
}
125+
126+
std::cout << "update " << std::endl;
127+
128+
kWrapper_update.reset();
129+
kWrapper_update << csrRowPtrCt_d << *a_gpu_ptr << array_size << leaf_size;
130+
131+
cl_status = kWrapper_update.run(control, global, local);
132+
133+
if (cl_status != CL_SUCCESS)
134+
{
135+
return clsparseInvalidKernelExecution;
136+
}
137+
138+
outer = (outer + block_size - 1) / block_size;
139+
std::swap(a_gpu_ptr, b_gpu_ptr);
140+
std::swap(a_size_ptr, b_size_ptr);
141+
}
142+
143+
clEnqueueReadBuffer(control->queue(), total_sum_gpu, CL_TRUE, 0, sizeof(uint32_t), &total_sum, 0, NULL, NULL);
144+
}
145+
146+
int run_bool_merge_count(
147+
cl_mem csrRowPtrA,
148+
cl_mem csrColIndA,
149+
cl_mem csrRowPtrB,
150+
cl_mem csrColIndB,
151+
cl_mem csrRowPtrCt_d,
152+
int m,
153+
const clsparseControl control
154+
)
155+
{
156+
const std::string params = std::string() +
157+
"-DINDEX_TYPE=" + OclTypeTraits<cl_int>::type
158+
+ " -DVALUE_TYPE=" + OclTypeTraits<cl_int>::type;
159+
160+
cl::Kernel kernel = KernelCache::get(control->queue, "bool_csradd_merge", "merge_count", params);
161+
162+
size_t szLocalWorkSize[1];
163+
size_t szGlobalWorkSize[1];
164+
165+
int num_threads = WG_SIZE;
166+
size_t num_blocks = m;
167+
168+
szLocalWorkSize[0] = num_threads;
169+
szGlobalWorkSize[0] = num_blocks * szLocalWorkSize[0];
170+
171+
cl::NDRange local(szLocalWorkSize[0]);
172+
cl::NDRange global(szGlobalWorkSize[0]);
173+
174+
KernelWrap kWrapper(kernel);
175+
176+
kWrapper << csrRowPtrA << csrColIndA << csrRowPtrB << csrColIndB << csrRowPtrCt_d;
177+
178+
cl_int cl_status_2 = kWrapper.run(control, global, local);
179+
180+
if (cl_status_2 != CL_SUCCESS)
181+
{
182+
return clsparseInvalidKernelExecution;
183+
}
184+
}
185+
186+
int run_bool_merge_fill(
187+
cl_mem csrRowPtrA,
188+
cl_mem csrColIndA,
189+
cl_mem csrRowPtrB,
190+
cl_mem csrColIndB,
191+
cl_mem csrRowPtrCt_d,
192+
cl_mem csrColIndC,
193+
int m,
194+
int total_sum,
195+
const clsparseControl control
196+
)
197+
{
198+
const std::string params = std::string() +
199+
"-DINDEX_TYPE=" + OclTypeTraits<cl_int>::type
200+
+ " -DVALUE_TYPE=" + OclTypeTraits<cl_int>::type;
201+
202+
cl::Kernel kernel = KernelCache::get(control->queue, "bool_csradd_merge", "merge_fill", params);
203+
204+
size_t szLocalWorkSize[1];
205+
size_t szGlobalWorkSize[1];
206+
207+
szLocalWorkSize[0] = WG_SIZE;
208+
szGlobalWorkSize[0] = m * WG_SIZE;
209+
210+
printf("local %d global %d \n", szLocalWorkSize[0], szGlobalWorkSize[0]);
211+
212+
cl::NDRange local(szLocalWorkSize[0]);
213+
cl::NDRange global(szGlobalWorkSize[0]);
214+
215+
KernelWrap kWrapper(kernel);
216+
217+
kWrapper << csrRowPtrA << csrColIndA << csrRowPtrB << csrColIndB << csrRowPtrCt_d << csrColIndC;
218+
219+
cl_int cl_status_2 = kWrapper.run(control, global, local);
220+
221+
if (cl_status_2 != CL_SUCCESS)
222+
{
223+
return clsparseInvalidKernelExecution;
224+
}
225+
}
226+
227+
CLSPARSE_EXPORT clsparseStatus
228+
clsparseBoolScsrElemAdd(
229+
const clsparseBoolCsrMatrix* sparseMatA,
230+
const clsparseBoolCsrMatrix* sparseMatB,
231+
clsparseBoolCsrMatrix* sparseMatC,
232+
const clsparseControl control )
233+
{
234+
cl_int cl_status;
235+
236+
if (!clsparseInitialized)
237+
{
238+
return clsparseNotInitialized;
239+
}
240+
241+
if (control == nullptr)
242+
{
243+
return clsparseInvalidControlObject;
244+
}
245+
246+
const clsparseBoolCsrMatrixPrivate* A = static_cast<const clsparseBoolCsrMatrixPrivate*>(sparseMatA);
247+
const clsparseBoolCsrMatrixPrivate* B = static_cast<const clsparseBoolCsrMatrixPrivate*>(sparseMatB);
248+
clsparseBoolCsrMatrixPrivate* C = static_cast<clsparseBoolCsrMatrixPrivate*>(sparseMatC);
249+
250+
// outer init
251+
cl_mem csrRowPtrA = A->row_pointer;
252+
cl_mem csrColIndA = A->col_indices;
253+
254+
//int is important here, since kernel receives only bootstrapint, not size_t
255+
int m = A->num_rows;
256+
int k1 = A->num_cols;
257+
int k2 = B->num_rows;
258+
int n = B->num_cols;
259+
int nnzA = A->num_nonzeros;
260+
int nnzB = B->num_nonzeros;
261+
262+
if(k1 != k2)
263+
{
264+
std::cerr << "A.n and B.m don't match!" << std::endl;
265+
return clsparseInvalidKernelExecution;
266+
}
267+
268+
cl_mem csrRowPtrB = B->row_pointer;
269+
cl_mem csrColIndB = B->col_indices;
270+
271+
int pattern = 0;
272+
273+
cl::Context cxt = control->getContext();
274+
275+
cl_mem csrRowPtrCt_d = ::clCreateBuffer(cxt(), CL_MEM_READ_WRITE, (m + 1) * sizeof( cl_int ), NULL, &cl_status );
276+
std::vector<int> csrRowPtrCt_h(m + 1, 0);
277+
278+
clEnqueueFillBuffer(control->queue(), csrRowPtrCt_d, &pattern, sizeof(cl_int), 0, (m + 1)*sizeof(cl_int), 0, NULL, NULL);
279+
280+
std::cout << "mergecount " << std::endl;
281+
run_bool_merge_count(csrRowPtrA, csrColIndA, csrRowPtrB, csrColIndB, csrRowPtrCt_d, m, control);
282+
283+
int run_status = clEnqueueReadBuffer(control->queue(),
284+
csrRowPtrCt_d,
285+
1,
286+
0,
287+
(m + 1)*sizeof(cl_int),
288+
csrRowPtrCt_h.data(),
289+
0,
290+
0,
291+
0);
292+
// for (auto i = csrRowPtrCt_h.begin(); i != csrRowPtrCt_h.end(); ++i)
293+
// {
294+
// std::cout << *i << ' ';
295+
// }
296+
// std::cout << std::endl;
297+
298+
uint32_t total_sum = 0;
299+
std::cout << "scan " << std::endl;
300+
run_bool_scan(csrRowPtrA, csrColIndA, csrRowPtrB, csrColIndB, csrRowPtrCt_d, total_sum, m, cxt, control);
301+
std::cout << "TOTAL " << total_sum << std::endl;
302+
run_status = clEnqueueReadBuffer(control->queue(),
303+
csrRowPtrCt_d,
304+
1,
305+
0,
306+
(m + 1)*sizeof(cl_int),
307+
csrRowPtrCt_h.data(),
308+
0,
309+
0,
310+
0);
311+
// for (auto i = csrRowPtrCt_h.begin(); i != csrRowPtrCt_h.end(); ++i)
312+
// {
313+
// std::cout << *i << ' ';
314+
// }
315+
// std::cout << std::endl;
316+
317+
std::cout << "mergefill " << std::endl;
318+
std::vector<int> csrColIndC_h(total_sum, 0);
319+
cl_mem csrColIndC = ::clCreateBuffer( cxt(), CL_MEM_READ_WRITE, total_sum * sizeof( cl_int ), NULL, &cl_status );
320+
321+
run_bool_merge_fill(csrRowPtrA, csrColIndA, csrRowPtrB, csrColIndB, csrRowPtrCt_d, csrColIndC, m, total_sum, control);
322+
323+
run_status = clEnqueueReadBuffer(control->queue(),
324+
csrColIndC,
325+
1,
326+
0,
327+
total_sum*sizeof(cl_int),
328+
csrColIndC_h.data(),
329+
0,
330+
0,
331+
0);
332+
333+
C->num_rows = m;
334+
C->num_cols = n;
335+
C->num_nonzeros = total_sum;
336+
C->row_pointer = csrRowPtrCt_d;
337+
C->col_indices = csrColIndC;
338+
// for (auto i = csrColIndC_h.begin(); i != csrColIndC_h.end(); ++i)
339+
// {
340+
// std::cout << *i << ' ';
341+
// }
342+
// std::cout << std::endl;
343+
}

0 commit comments

Comments
 (0)