Skip to content

Commit acff833

Browse files
committed
Update sample with bool operations
1 parent f032beb commit acff833

File tree

2 files changed

+70
-2
lines changed

2 files changed

+70
-2
lines changed

samples/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ foreach( sample ${SAMPLES_CPP} )
8080
add_executable( ${sample} ${sample}.cpp)
8181
target_link_libraries( ${sample} clSPARSE ${CMAKE_DL_LIBS} ${OPENCL_LIBRARIES} )
8282
if( CMAKE_COMPILER_IS_GNUCXX OR ( CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) )
83-
target_compile_options( ${sample} PUBLIC -std=c++11 )
83+
target_compile_options( ${sample} PUBLIC -std=c++17 )
8484
endif( )
8585
set_target_properties( ${sample} PROPERTIES VERSION ${clSPARSE_VERSION} )
8686
set_target_properties( ${sample} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/staging" )

samples/my-sample.cpp

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <iostream>
2323
#include <vector>
24+
#include <assert.h>
2425

2526
#define CL_HPP_ENABLE_EXCEPTIONS
2627
#define CL_HPP_MINIMUM_OPENCL_VERSION BUILD_CLVERSION
@@ -224,8 +225,24 @@ int main (int argc, char* argv[])
224225
// status = clsparseBoolScsrSpGemm(&A, &B, &C, createResult.control );
225226
status = clsparseBoolScsrElemAdd(&A, &B, &C, createResult.control );
226227

227-
std::vector<int> csrColIndC_h(C.num_nonzeros, 0);
228+
std::vector<int> csrRowPtrC_h((C.num_rows + 1), 0);
228229
int run_status = clEnqueueReadBuffer(queue(),
230+
C.row_pointer,
231+
1,
232+
0,
233+
(C.num_rows + 1)*sizeof(cl_int),
234+
csrRowPtrC_h.data(),
235+
0,
236+
0,
237+
0);
238+
for (auto i = csrRowPtrC_h.begin(); i != csrRowPtrC_h.end(); ++i)
239+
{
240+
std::cout << *i << ' ';
241+
}
242+
std::cout << std::endl;
243+
244+
std::vector<int> csrColIndC_h(C.num_nonzeros, 0);
245+
run_status = clEnqueueReadBuffer(queue(),
229246
C.col_indices,
230247
1,
231248
0,
@@ -241,10 +258,61 @@ int main (int argc, char* argv[])
241258
std::cout << std::endl;
242259

243260

261+
// CPU ADDITION
262+
263+
assert(A.num_rows == B.num_rows);
264+
265+
clsparseIdx_t* row_ptr_A = (clsparseIdx_t*)malloc((A.num_rows + 1) * sizeof(clsparseIdx_t));
266+
clsparseIdx_t* cols_A = (clsparseIdx_t*)malloc(A.num_nonzeros * sizeof(clsparseIdx_t));
267+
clsparseIdx_t* row_ptr_B = (clsparseIdx_t*)malloc((B.num_rows + 1) * sizeof(clsparseIdx_t));
268+
clsparseIdx_t* cols_B = (clsparseIdx_t*)malloc(B.num_nonzeros * sizeof(clsparseIdx_t));
244269

270+
clEnqueueReadBuffer(queue(), A.row_pointer, CL_TRUE, 0, (A.num_rows + 1) * sizeof(clsparseIdx_t),
271+
row_ptr_A, 0, NULL, NULL);
272+
clEnqueueReadBuffer(queue(), A.col_indices, CL_TRUE, 0, A.num_nonzeros * sizeof(clsparseIdx_t),
273+
cols_A, 0, NULL, NULL);
245274

275+
clEnqueueReadBuffer(queue(), B.row_pointer, CL_TRUE, 0, (B.num_rows + 1) * sizeof(clsparseIdx_t),
276+
row_ptr_B, 0, NULL, NULL);
277+
clEnqueueReadBuffer(queue(), B.col_indices, CL_TRUE, 0, B.num_nonzeros * sizeof(clsparseIdx_t),
278+
cols_B, 0, NULL, NULL);
279+
280+
std::vector<int> row_ptr_C;
281+
std::vector<int> cols_C;
282+
283+
row_ptr_C.push_back(0);
284+
for (int i = 1; i <= A.num_rows; i++)
285+
{
286+
int start_A = row_ptr_A[i - 1];
287+
int end_A = row_ptr_A[i];
288+
int start_B = row_ptr_B[i - 1];
289+
int end_B = row_ptr_B[i];
290+
291+
std::vector<int> dst;
292+
std::merge(cols_A + start_A, cols_A + end_A, cols_B + start_B, cols_B + end_B, std::back_inserter(dst));
293+
dst.erase(std::unique(dst.begin(), dst.end()), dst.end());
294+
295+
row_ptr_C.push_back(row_ptr_C[i - 1] + dst.size());
296+
cols_C.insert(cols_C.end(), dst.begin(), dst.end());
297+
dst.clear();
298+
}
299+
300+
for (auto i = row_ptr_C.begin(); i != row_ptr_C.end(); ++i)
301+
{
302+
std::cout << *i << ' ';
303+
}
304+
std::cout << std::endl;
305+
306+
for (auto i = cols_C.begin(); i != cols_C.end(); ++i)
307+
{
308+
std::cout << *i << ' ';
309+
}
310+
std::cout << std::endl;
246311

312+
// VERIFY RESULTS
247313

314+
assert(csrRowPtrC_h == row_ptr_C);
315+
assert(csrColIndC_h == cols_C);
248316

249317

250318
if (status != clsparseSuccess)

0 commit comments

Comments
 (0)