diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp
index 366d67708..9ae426a53 100644
--- a/src/blas/backends/cublas/cublas_scope_handle.cpp
+++ b/src/blas/backends/cublas/cublas_scope_handle.cpp
@@ -32,36 +32,80 @@ namespace cublas {
 */
 thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

-CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}
+CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {
+    // Initialize the streamId member to a CUstream associated with the queue
+    // that `ih` has been submitted to.
+    streamId = ih.get_native_queue();

-cublasHandle_t CublasScopedContextHandler::get_handle() {
+    // Initialize the `cublasHandle_t` member `nativeHandle`
     CUdevice device = ih.get_native_device();
-    CUstream streamId = get_stream();
-    cublasStatus_t err;
-
     auto it = handle_helper.cublas_handle_mapper_.find(device);
     if (it != handle_helper.cublas_handle_mapper_.end()) {
-        cublasHandle_t nativeHandle = it->second;
+        // Use the existing handle if one already exists for the device, but
+        // update the native stream.
+        nativeHandle = it->second;
         cudaStream_t currentStreamId;
+        cublasStatus_t err;
         CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
         if (currentStreamId != streamId) {
             CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
         }
-        return nativeHandle;
     }
-
-    cublasHandle_t nativeHandle;
-    CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
-    CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
-
-    auto insert_iter =
+    else {
+        // Create a new handle if one doesn't already exist for the device
+        cublasStatus_t err;
+        CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
+        CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
         handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));
+    }
+}

-    return nativeHandle;
+void CublasScopedContextHandler::begin_recording_if_graph() {
+// interop_handle graph methods only available from extension version 2
+#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
+    if (!ih.ext_codeplay_has_graph()) {
+        return;
+    }
+
+    CUresult err;
+#if CUDA_VERSION >= 12030
+    // From CUDA 12.3 we can use cuStreamBeginCaptureToGraph to capture
+    // the stream directly into the native graph, rather than needing to
+    // instantiate the stream capture as a new graph.
+    auto graph = ih.ext_codeplay_get_native_graph();
+    CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr, nullptr, 0,
+                    CU_STREAM_CAPTURE_MODE_GLOBAL);
+#else
+    CUDA_ERROR_FUNC(cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
+#endif // CUDA_VERSION
+#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
 }

-CUstream CublasScopedContextHandler::get_stream() {
-    return ih.get_native_queue();
+void CublasScopedContextHandler::end_recording_if_graph() {
+// interop_handle graph methods only available from extension version 2
+#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
+    if (!ih.ext_codeplay_has_graph()) {
+        return;
+    }
+
+    auto graph = ih.ext_codeplay_get_native_graph();
+    CUresult err;
+#if CUDA_VERSION >= 12030
+    CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &graph);
+#else
+    // cuStreamEndCapture returns a new graph; if we overwrite "graph" it
+    // won't be picked up by the SYCL runtime, because
+    // "ext_codeplay_get_native_graph" returns a pointer passed by value.
+    CUgraph recorded_graph;
+    CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &recorded_graph);
+
+    // Add the recorded graph to the native graph as a child node.
+    // A node object needs to be returned for the node to be created;
+    // it can't be nullptr.
+    CUgraphNode node;
+    CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);
+#endif // CUDA_VERSION
+#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
 }

 } // namespace cublas
 } // namespace blas
diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp
index a8ca67e9e..51a6bd46e 100644
--- a/src/blas/backends/cublas/cublas_scope_handle.hpp
+++ b/src/blas/backends/cublas/cublas_scope_handle.hpp
@@ -63,18 +63,49 @@ the handle must be destroyed when the context goes out of scope. This will bind
 class CublasScopedContextHandler {
     sycl::interop_handle& ih;
     static thread_local cublas_handle handle_helper;
-    CUstream get_stream();
+    cublasHandle_t nativeHandle;
+    // Cache the native CUstream when the `CublasScopedContextHandler` object
+    // is constructed. This avoids calling `get_native_queue(ih)` multiple
+    // times, which isn't guaranteed to return the same CUstream handle each
+    // time, a scenario that causes problems when trying to start/end CUDA
+    // stream recording to a graph.
+    CUstream streamId;

 public:
+    /**
+     * @brief Constructor
+     * @detail Creates the cublasHandle_t by implicitly imposing the advice
+     * given by NVIDIA for creating a cublas_handle (e.g. one cuStream per device
+     * per thread).
+     */
     CublasScopedContextHandler(sycl::interop_handle& ih);

     /**
-     * @brief get_handle: creates the handle by implicitly impose the advice
-     * given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
-     * per thread).
-     * @return cublasHandle_t a handle to construct cublas routines
-     */
-    cublasHandle_t get_handle();
+     * @brief Start recording cuBLAS calls to a graph.
+     * @detail Checks if the command-group associated with \p ih is being added
+     * to a graph, and if so, begins stream recording of the native CUDA stream
+     * associated with the queue backing \p ih to the native cuda-graph object.
+     */
+    void begin_recording_if_graph();
+
+    /**
+     * @brief End recording cuBLAS calls to a graph.
+     * @detail Checks if the command-group associated with \p ih is being added
+     * to a graph, and if so, ends stream recording of the native CUDA stream
+     * associated with the queue backing \p ih to the native cuda-graph object,
+     * doing any extra work needed to ensure that stream-recorded calls get
+     * added as nodes to the native graph object associated with \p ih.
+     */
+    void end_recording_if_graph();
+
+    /// @brief Query the cuBLAS handle created on construction
+    /// @return cublasHandle_t a handle to construct cublas routines
+    cublasHandle_t get_handle() const {
+        return nativeHandle;
+    }
+
     // This is a work-around function for reinterpret_casting the memory. This
     // will be fixed when SYCL-2020 has been implemented for Pi backend.
template diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index 1b86543f9..c34e3f6b8 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, F f) { cgh.host_task([f](sycl::interop_handle ih) { #endif auto sc = CublasScopedContextHandler(ih); + sc.begin_recording_if_graph(); f(sc); + sc.end_recording_if_graph(); }); } #endif diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 7fd6d508f..d8c0b2560 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -54,7 +54,8 @@ set(blas_TEST_LIST blas_level2 blas_level3 blas_batch - blas_extensions) + blas_extensions + blas_sycl_graph) set(blas_TEST_LINK "") diff --git a/tests/unit_tests/blas/CMakeLists.txt b/tests/unit_tests/blas/CMakeLists.txt index c80d0043a..52d657e92 100644 --- a/tests/unit_tests/blas/CMakeLists.txt +++ b/tests/unit_tests/blas/CMakeLists.txt @@ -27,3 +27,4 @@ add_subdirectory(level2) add_subdirectory(level3) add_subdirectory(batch) add_subdirectory(extensions) +add_subdirectory(sycl-graph) diff --git a/tests/unit_tests/blas/sycl-graph/CMakeLists.txt b/tests/unit_tests/blas/sycl-graph/CMakeLists.txt new file mode 100644 index 000000000..be44a0f6b --- /dev/null +++ b/tests/unit_tests/blas/sycl-graph/CMakeLists.txt @@ -0,0 +1,61 @@ +#=============================================================================== +# Copyright 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. 
+# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# Build object from all test sources +set(SYCL_GRAPH_SOURCES) + +set(SYCL_GRAPH_SOURCES_W_CBLAS "gemm_usm.cpp" "gemm_batch_usm.cpp") + +if(CBLAS_FOUND) + list(APPEND SYCL_GRAPH_SOURCES ${SYCL_GRAPH_SOURCES_W_CBLAS}) +endif() + +if(BUILD_SHARED_LIBS) + add_library(blas_sycl_graph_rt OBJECT ${SYCL_GRAPH_SOURCES}) + target_compile_options(blas_sycl_graph_rt PRIVATE -DCALL_RT_API -DNOMINMAX) + target_include_directories(blas_sycl_graph_rt + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + $<$:${CBLAS_INCLUDE}> + ) + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET blas_sycl_graph_rt SOURCES ${SYCL_GRAPH_SOURCES}) + else() + target_link_libraries(blas_sycl_graph_rt PUBLIC ONEMATH::SYCL::SYCL) + endif() +endif() + +add_library(blas_sycl_graph_ct OBJECT ${SYCL_GRAPH_SOURCES}) +target_compile_options(blas_sycl_graph_ct PRIVATE -DNOMINMAX) +target_include_directories(blas_sycl_graph_ct + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + $<$:${CBLAS_INCLUDE}> +) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET blas_sycl_graph_ct SOURCES ${SYCL_GRAPH_SOURCES}) +else() + target_link_libraries(blas_sycl_graph_ct PUBLIC ONEMATH::SYCL::SYCL) +endif() diff --git a/tests/unit_tests/blas/sycl-graph/gemm_batch_usm.cpp b/tests/unit_tests/blas/sycl-graph/gemm_batch_usm.cpp new file mode 100644 index 000000000..807b66dbc --- /dev/null +++ b/tests/unit_tests/blas/sycl-graph/gemm_batch_usm.cpp @@ -0,0 +1,379 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "cblas.h" +#include "oneapi/math.hpp" +#include "oneapi/math/detail/config.hpp" +#include "allocator_helper.hpp" +#include "onemath_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; + +extern std::vector devices; + +namespace { + +#ifdef SYCL_EXT_ONEAPI_GRAPH +template +int test(device* dev, oneapi::math::layout layout, int64_t group_count, size_t graph_nodes) { + // Catch asynchronous exceptions. 
+ auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const& e) { + std::cout << "Caught asynchronous SYCL exception during GEMM_BATCH:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler, property::queue::in_order{}); + context cxt = main_queue.get_context(); + + // Prepare data. + auto uaint = usm_allocator(cxt, *dev); + std::vector m(uaint), n(uaint), k(uaint), lda(uaint), ldb(uaint), + ldc(uaint), group_size(uaint); + + auto uatranspose = usm_allocator(cxt, *dev); + std::vector transa(uatranspose), + transb(uatranspose); + + auto uaTs = usm_allocator(cxt, *dev); + std::vector alpha(uaTs), beta(uaTs); + + m.resize(group_count); + n.resize(group_count); + k.resize(group_count); + lda.resize(group_count); + ldb.resize(group_count); + ldc.resize(group_count); + group_size.resize(group_count); + transa.resize(group_count); + transb.resize(group_count); + alpha.resize(group_count); + beta.resize(group_count); + + int64_t total_batch_count = 0; + for (int64_t i = 0; i < group_count; i++) { + group_size[i] = 1 + std::rand() % 20; + m[i] = 1 + std::rand() % 500; + n[i] = 1 + std::rand() % 500; + k[i] = 1 + std::rand() % 500; + lda[i] = std::max(m[i], k[i]); + ldb[i] = std::max(n[i], k[i]); + ldc[i] = std::max(m[i], n[i]); + alpha[i] = rand_scalar(); + beta[i] = rand_scalar(); + transa[i] = (oneapi::math::transpose)(std::rand() % 2); + transb[i] = (oneapi::math::transpose)(std::rand() % 2); + total_batch_count += group_size[i]; + } + + auto uaTap = usm_allocator(cxt, *dev); + auto uaTbp = usm_allocator(cxt, *dev); + auto uaTcp = usm_allocator(cxt, *dev); + auto uaTsp = usm_allocator(cxt, *dev); + std::vector a_array(uaTap); + std::vector b_array(uaTbp); + std::vector c_array(uaTcp), c_cast_ref_array(uaTcp); + std::vector a_ref_array(uaTsp), b_ref_array(uaTsp), c_ref_array(uaTsp); + a_array.resize(total_batch_count); + b_array.resize(total_batch_count); + c_array.resize(total_batch_count); + a_ref_array.resize(total_batch_count); + b_ref_array.resize(total_batch_count); + c_cast_ref_array.resize(total_batch_count); + c_ref_array.resize(total_batch_count); + + size_t idx = 0; + int64_t size_a = 0, size_b = 0, size_c = 0; + for (int64_t i = 0; i < group_count; i++) { + switch (layout) { + case oneapi::math::layout::col_major: + size_a = lda[i] * ((transa[i] == oneapi::math::transpose::nontrans) ? k[i] : m[i]); + size_b = ldb[i] * ((transb[i] == oneapi::math::transpose::nontrans) ? n[i] : k[i]); + size_c = ldc[i] * n[i]; + break; + case oneapi::math::layout::row_major: + size_a = lda[i] * ((transa[i] == oneapi::math::transpose::nontrans) ? m[i] : k[i]); + size_b = ldb[i] * ((transb[i] == oneapi::math::transpose::nontrans) ? 
k[i] : n[i]); + size_c = ldc[i] * m[i]; + break; + default: break; + } + for (int64_t j = 0; j < group_size[i]; j++) { + a_array[idx] = (Ta*)oneapi::math::malloc_shared(64, sizeof(Ta) * size_a, *dev, cxt); + b_array[idx] = (Tb*)oneapi::math::malloc_shared(64, sizeof(Tb) * size_b, *dev, cxt); + c_array[idx] = (Tc*)oneapi::math::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + a_ref_array[idx] = (Ts*)oneapi::math::malloc_shared(64, sizeof(Ts) * size_a, *dev, cxt); + b_ref_array[idx] = (Ts*)oneapi::math::malloc_shared(64, sizeof(Ts) * size_b, *dev, cxt); + c_cast_ref_array[idx] = + (Tc*)oneapi::math::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + c_ref_array[idx] = (Ts*)oneapi::math::malloc_shared(64, sizeof(Ts) * size_c, *dev, cxt); + rand_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i]); + rand_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i]); + rand_matrix(c_array[idx], layout, oneapi::math::transpose::nontrans, m[i], n[i], + ldc[i]); + copy_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i], a_ref_array[idx]); + copy_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i], b_ref_array[idx]); + copy_matrix(c_array[idx], layout, oneapi::math::transpose::nontrans, m[i], n[i], ldc[i], + c_ref_array[idx]); + idx++; + } + } + + // Call reference GEMM_BATCH. + using fp_ref = typename ref_type_info::type; + int* m_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* n_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* k_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* lda_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* ldb_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* ldc_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + int* group_size_ref = (int*)oneapi::math::aligned_alloc(64, sizeof(int) * group_count); + + CBLAS_TRANSPOSE* transa_ref = + (CBLAS_TRANSPOSE*)oneapi::math::aligned_alloc(64, sizeof(CBLAS_TRANSPOSE) * group_count); + CBLAS_TRANSPOSE* transb_ref = + (CBLAS_TRANSPOSE*)oneapi::math::aligned_alloc(64, sizeof(CBLAS_TRANSPOSE) * group_count); + + if ((m_ref == NULL) || (n_ref == NULL) || (k_ref == NULL) || (lda_ref == NULL) || + (ldb_ref == NULL) || (ldc_ref == NULL) || (transa_ref == NULL) || (transb_ref == NULL) || + (group_size_ref == NULL)) { + std::cout << "Error cannot allocate input arrays\n"; + oneapi::math::aligned_free(m_ref); + oneapi::math::aligned_free(n_ref); + oneapi::math::aligned_free(k_ref); + oneapi::math::aligned_free(lda_ref); + oneapi::math::aligned_free(ldb_ref); + oneapi::math::aligned_free(ldc_ref); + oneapi::math::aligned_free(transa_ref); + oneapi::math::aligned_free(transb_ref); + oneapi::math::aligned_free(group_size_ref); + idx = 0; + for (int64_t i = 0; i < group_count; i++) { + for (int64_t j = 0; j < group_size[i]; j++) { + oneapi::math::free_shared(a_array[idx], cxt); + oneapi::math::free_shared(b_array[idx], cxt); + oneapi::math::free_shared(c_array[idx], cxt); + oneapi::math::free_shared(a_ref_array[idx], cxt); + oneapi::math::free_shared(b_ref_array[idx], cxt); + oneapi::math::free_shared(c_cast_ref_array[idx], cxt); + oneapi::math::free_shared(c_ref_array[idx], cxt); + idx++; + } + } + return false; + } + for (unsigned node = 0; node < graph_nodes; node++) { + idx = 0; + for (int64_t i = 0; i < group_count; i++) { + transa_ref[i] = convert_to_cblas_trans(transa[i]); + transb_ref[i] = convert_to_cblas_trans(transb[i]); + m_ref[i] = 
(int)m[i]; + n_ref[i] = (int)n[i]; + k_ref[i] = (int)k[i]; + lda_ref[i] = (int)lda[i]; + ldb_ref[i] = (int)ldb[i]; + ldc_ref[i] = (int)ldc[i]; + group_size_ref[i] = (int)group_size[i]; + for (int64_t j = 0; j < group_size_ref[i]; j++) { + ::gemm(convert_to_cblas_layout(layout), transa_ref[i], transb_ref[i], + (const int*)&m_ref[i], (const int*)&n_ref[i], (const int*)&k_ref[i], + (const fp_ref*)&alpha[i], (const fp_ref*)a_ref_array[idx], + (const int*)&lda_ref[i], (const fp_ref*)b_ref_array[idx], + (const int*)&ldb_ref[i], (const fp_ref*)&beta[i], (fp_ref*)c_ref_array[idx], + (const int*)&ldc_ref[i]); + idx++; + } + } + } + + // Being recording oneMath operations to a graph + namespace sycl_exp = sycl::ext::oneapi::experimental; + auto graph = sycl_exp::command_graph(main_queue); + graph.begin_recording(main_queue); + try { + for (unsigned node = 0; node < graph_nodes; node++) { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::math::layout::col_major: + oneapi::math::blas::column_major::gemm_batch( + main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], + (const Ta**)&a_array[0], &lda[0], (const Tb**)&b_array[0], &ldb[0], + &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0]); + break; + case oneapi::math::layout::row_major: + oneapi::math::blas::row_major::gemm_batch( + main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], + (const Ta**)&a_array[0], &lda[0], (const Tb**)&b_array[0], &ldb[0], + &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0]); + break; + default: break; + } +#else + switch (layout) { + case oneapi::math::layout::col_major: + TEST_RUN_BLAS_CT_SELECT( + main_queue, oneapi::math::blas::column_major::gemm_batch, &transa[0], + &transb[0], &m[0], &n[0], &k[0], &alpha[0], (const Ta**)&a_array[0], + &lda[0], (const Tb**)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], + group_count, &group_size[0]); + break; + case oneapi::math::layout::row_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::math::blas::row_major::gemm_batch, + &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], + (const Ta**)&a_array[0], &lda[0], + (const Ta**)&b_array[0], &ldb[0], &beta[0], &c_array[0], + &ldc[0], group_count, &group_size[0]); + break; + default: break; + } +#endif + } + } + catch (exception const& e) { + std::cout << "Caught synchronous SYCL exception during GEMM_BATCH:\n" + << e.what() << std::endl; + print_error_code(e); + } + catch (const oneapi::math::unimplemented& e) { + oneapi::math::aligned_free(m_ref); + oneapi::math::aligned_free(n_ref); + oneapi::math::aligned_free(k_ref); + oneapi::math::aligned_free(lda_ref); + oneapi::math::aligned_free(ldb_ref); + oneapi::math::aligned_free(ldc_ref); + oneapi::math::aligned_free(transa_ref); + oneapi::math::aligned_free(transb_ref); + oneapi::math::aligned_free(group_size_ref); + idx = 0; + for (int64_t i = 0; i < group_count; i++) { + for (int64_t j = 0; j < group_size[i]; j++) { + oneapi::math::free_shared(a_array[idx], cxt); + oneapi::math::free_shared(b_array[idx], cxt); + oneapi::math::free_shared(c_array[idx], cxt); + oneapi::math::free_shared(a_ref_array[idx], cxt); + oneapi::math::free_shared(b_ref_array[idx], cxt); + oneapi::math::free_shared(c_cast_ref_array[idx], cxt); + oneapi::math::free_shared(c_ref_array[idx], cxt); + idx++; + } + } + return test_skipped; + } + catch (const std::runtime_error& error) { + std::cout << "Error raised during execution of GEMM_BATCH:\n" << error.what() << std::endl; + } + + // End recording of sycl queue and create executable graph + 
graph.end_recording(main_queue); + auto exec_graph = graph.finalize(); + + // Submit graph to execute and wait for completion + main_queue.ext_oneapi_graph(exec_graph).wait_and_throw(); + + bool good = true; + idx = 0; + for (int64_t i = 0; i < group_count; i++) { + const int error_mag = 10 * k[i]; + for (int64_t j = 0; j < group_size[i]; j++) { + copy_matrix(c_ref_array[idx], layout, oneapi::math::transpose::nontrans, m[i], n[i], + ldc[i], c_cast_ref_array[idx]); + good = good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, + m[i], n[i], ldc[i], error_mag, std::cout); + idx++; + } + } + oneapi::math::aligned_free(m_ref); + oneapi::math::aligned_free(n_ref); + oneapi::math::aligned_free(k_ref); + oneapi::math::aligned_free(lda_ref); + oneapi::math::aligned_free(ldb_ref); + oneapi::math::aligned_free(ldc_ref); + oneapi::math::aligned_free(transa_ref); + oneapi::math::aligned_free(transb_ref); + oneapi::math::aligned_free(group_size_ref); + idx = 0; + for (int64_t i = 0; i < group_count; i++) { + for (int64_t j = 0; j < group_size[i]; j++) { + oneapi::math::free_shared(a_array[idx], cxt); + oneapi::math::free_shared(b_array[idx], cxt); + oneapi::math::free_shared(c_array[idx], cxt); + oneapi::math::free_shared(a_ref_array[idx], cxt); + oneapi::math::free_shared(b_ref_array[idx], cxt); + oneapi::math::free_shared(c_cast_ref_array[idx], cxt); + oneapi::math::free_shared(c_ref_array[idx], cxt); + idx++; + } + } + + return (int)good; +} +#else // ifdef SYCL_EXT_ONEAPI_GRAPH +template +int test(device*, oneapi::math::layout, int64_t, size_t) { + // Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph + return 1; +} +#endif + +struct GraphGemmBatchUsmTests + : public ::testing::TestWithParam> { + virtual void SetUp() override { + // Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph + CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam())); + } +}; + +TEST_P(GraphGemmBatchUsmTests, RealSinglePrecision) { + sycl::device* dev = std::get<0>(GetParam()); + oneapi::math::layout layout = std::get<1>(GetParam()); + const int64_t group_count = 5; + const unsigned graph_nodes = 3; + EXPECT_TRUEORSKIP((test(dev, layout, group_count, graph_nodes))); +} +INSTANTIATE_TEST_SUITE_P(GraphGemmBatchUsmTestSuite, GraphGemmBatchUsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::math::layout::col_major, + oneapi::math::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/sycl-graph/gemm_usm.cpp b/tests/unit_tests/blas/sycl-graph/gemm_usm.cpp new file mode 100644 index 000000000..30591cac7 --- /dev/null +++ b/tests/unit_tests/blas/sycl-graph/gemm_usm.cpp @@ -0,0 +1,202 @@ +/******************************************************************************* +* Copyright 2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. 
+*
+*
+* SPDX-License-Identifier: Apache-2.0
+*******************************************************************************/
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#if __has_include()
+#include 
+#else
+#include 
+#endif
+#include "cblas.h"
+#include "oneapi/math.hpp"
+#include "oneapi/math/detail/config.hpp"
+#include "onemath_blas_helper.hpp"
+#include "reference_blas_templates.hpp"
+#include "test_common.hpp"
+#include "test_helper.hpp"
+
+#include 
+
+using namespace sycl;
+
+extern std::vector devices;
+
+namespace {
+#ifdef SYCL_EXT_ONEAPI_GRAPH
+template 
+int test(device* dev, oneapi::math::layout layout, int m, int n, int k, int lda, int ldb, int ldc,
+         Tc alpha, Tc beta) {
+    // Catch asynchronous exceptions.
+    auto exception_handler = [](exception_list exceptions) {
+        for (std::exception_ptr const& e : exceptions) {
+            try {
+                std::rethrow_exception(e);
+            }
+            catch (exception const& e) {
+                std::cout << "Caught asynchronous SYCL exception during GEMM:\n"
+                          << e.what() << std::endl;
+                print_error_code(e);
+            }
+        }
+    };
+
+    // Test with 4 nodes in the graph, for each combination of matrix A & B being transposed
+    // or not.
+    constexpr size_t num_ops = 4;
+    std::array, num_ops> trans = {
+        std::make_pair(oneapi::math::transpose::nontrans, oneapi::math::transpose::nontrans),
+        std::make_pair(oneapi::math::transpose::nontrans, oneapi::math::transpose::trans),
+        std::make_pair(oneapi::math::transpose::trans, oneapi::math::transpose::nontrans),
+        std::make_pair(oneapi::math::transpose::trans, oneapi::math::transpose::trans),
+    };
+
+    queue main_queue(*dev, exception_handler, property::queue::in_order{});
+
+    // Prepare data. Have a single transposed and non-transposed matrix allocation that's
+    // reused across all nodes in the graph.
+    auto ua = usm_allocator(main_queue);
+    auto uc = usm_allocator(main_queue);
+    std::vector Trans(ua), NoTrans(ua);
+    rand_matrix(Trans, layout, oneapi::math::transpose::trans, m, k, lda);
+    rand_matrix(NoTrans, layout, oneapi::math::transpose::nontrans, m, k, lda);
+
+    // Create input/output matrix C that is a data dependency across nodes, and
+    // C_ref, used to verify its final value against the cblas reference.
+    std::vector C(ua);
+    rand_matrix(C, layout, oneapi::math::transpose::nontrans, m, n, ldc);
+    auto C_ref = C;
+
+    // Begin recording oneMath operations to a graph
+    namespace sycl_exp = sycl::ext::oneapi::experimental;
+    auto graph = sycl_exp::command_graph(main_queue);
+    graph.begin_recording(main_queue);
+    for (auto [transa, transb] : trans) {
+        // Assign transpose or non-transpose matrix to A & B
+        Ta* A = transa == oneapi::math::transpose::trans ? Trans.data() : NoTrans.data();
+        Ta* B = transb == oneapi::math::transpose::trans ?
Trans.data() : NoTrans.data(); + + // Calculate reference + const int m_ref = m, n_ref = n, k_ref = k; + const int lda_ref = lda, ldb_ref = ldb, ldc_ref = ldc; + + using ta_ref = typename ref_type_info::type; + using tc_ref = typename ref_type_info::type; + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), &m_ref, &n_ref, &k_ref, (tc_ref*)&alpha, (ta_ref*)A, + &lda_ref, (ta_ref*)B, &ldb_ref, (tc_ref*)&beta, (tc_ref*)C_ref.data(), &ldc_ref); + + // Submit oneMath gemm operation to main_queue in recording mode + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::math::layout::col_major: + oneapi::math::blas::column_major::gemm(main_queue, transa, transb, m, n, k, + alpha, A, lda, B, ldb, beta, C.data(), + ldc); + break; + case oneapi::math::layout::row_major: + oneapi::math::blas::row_major::gemm(main_queue, transa, transb, m, n, k, alpha, + A, lda, B, ldb, beta, C.data(), ldc); + break; + default: break; + } +#else + switch (layout) { + case oneapi::math::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::math::blas::column_major::gemm, + transa, transb, m, n, k, alpha, A, lda, B, ldb, beta, + C.data(), ldc); + break; + case oneapi::math::layout::row_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::math::blas::row_major::gemm, transa, + transb, m, n, k, alpha, A, lda, B, ldb, beta, C.data(), + ldc); + break; + default: break; + } +#endif + } + catch (exception const& e) { + std::cout << "Caught synchronous SYCL exception during GEMM:\n" + << e.what() << std::endl; + print_error_code(e); + } + catch (const oneapi::math::unimplemented& e) { + return test_skipped; + } + catch (const std::runtime_error& error) { + std::cout << "Error raised during execution of GEMM:\n" << error.what() << std::endl; + } + } + // End recording of sycl queue and create executable graph + graph.end_recording(main_queue); + auto exec_graph = graph.finalize(); + + // Submit graph to execute and wait for completion + main_queue.ext_oneapi_graph(exec_graph).wait_and_throw(); + + // Verify graph output against reference + bool good = check_equal_matrix(C, C_ref, layout, m, n, ldc, 10 * k, std::cout); + return (int)good; +} + +#else // ifdef SYCL_EXT_ONEAPI_GRAPH +template +int test(device*, oneapi::math::layout, int, int, int, int, int, int, Tc, Tc) { + // Stub test for SYCL compilers that don't define the sycl_ext_oneapi_graph extension + return 1; +} +#endif + +struct GraphGemmUsmTests + : public ::testing::TestWithParam> { + virtual void SetUp() override { + // Skip test if graph recording variant and device doesn't support sycl_ext_oneapi_graph + CHECK_GRAPH_ON_DEVICE(std::get<0>(GetParam())); + } +}; + +TEST_P(GraphGemmUsmTests, RealSinglePrecision) { + device* dev = std::get<0>(GetParam()); + oneapi::math::layout layout = std::get<1>(GetParam()); + + const int m(1); + const int n(2); + const int k(3); + const int lda(4); + const int ldb(5); + const int ldc(6); + const float alpha(2.0); + const float beta(3.0); + EXPECT_TRUEORSKIP((test(dev, layout, m, n, k, lda, ldb, ldc, alpha, beta))); +} + +INSTANTIATE_TEST_SUITE_P(GraphGemmUsmTestSuite, GraphGemmUsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::math::layout::col_major, + oneapi::math::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index e17bbd5a6..e27a069f4 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ 
b/tests/unit_tests/include/test_helper.hpp @@ -73,6 +73,14 @@ if (d->get_info().size() == 0) \ GTEST_SKIP() << "Double precision is not supported on the device" +#ifdef SYCL_EXT_ONEAPI_GRAPH +#define CHECK_GRAPH_ON_DEVICE(d) \ + if (!d->has(aspect::ext_oneapi_limited_graph)) \ + GTEST_SKIP() << "SYCL-Graph is not supported on the device" +#else +#define CHECK_GRAPH_ON_DEVICE(d) GTEST_SKIP() << "SYCL-Graph is not supported on the device" +#endif + #if defined(ONEMATH_ENABLE_MKLCPU_BACKEND) || defined(ONEMATH_ENABLE_NETLIB_BACKEND) || \ defined(ONEMATH_ENABLE_ARMPL_BACKEND) #ifdef ONEMATH_ENABLE_MKLCPU_BACKEND
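
For reference, below is a minimal, self-contained sketch (not part of the patch) of how an application would exercise the graph-recording path that this change enables and that the tests above cover. It assumes a device exposing aspect::ext_oneapi_limited_graph and, mirroring the tests, submits a oneMath USM gemm to an in-order queue while a sycl_ext_oneapi_graph command_graph is recording; the matrix size and fill values are illustrative only.

#include <cstdint>
#include <sycl/sycl.hpp>

#include "oneapi/math.hpp"

namespace sycl_exp = sycl::ext::oneapi::experimental;

int main() {
    sycl::queue q{ sycl::gpu_selector_v, sycl::property::queue::in_order{} };
    if (!q.get_device().has(sycl::aspect::ext_oneapi_limited_graph)) {
        return 0; // SYCL-Graph is not supported on this device
    }

    // Illustrative problem size and USM allocations (values are arbitrary).
    constexpr std::int64_t n = 64;
    float* a = sycl::malloc_shared<float>(n * n, q);
    float* b = sycl::malloc_shared<float>(n * n, q);
    float* c = sycl::malloc_shared<float>(n * n, q);
    for (std::int64_t i = 0; i < n * n; ++i) {
        a[i] = 1.0f;
        b[i] = 2.0f;
        c[i] = 0.0f;
    }

    // Record the oneMath call into a graph instead of executing it eagerly.
    // With the cuBLAS backend this goes through begin/end_recording_if_graph,
    // which stream-capture the underlying cuBLAS call into the native CUDA graph.
    sycl_exp::command_graph graph(q.get_context(), q.get_device());
    graph.begin_recording(q);
    oneapi::math::blas::column_major::gemm(q, oneapi::math::transpose::nontrans,
                                           oneapi::math::transpose::nontrans, n, n, n, 1.0f, a, n,
                                           b, n, 0.0f, c, n);
    graph.end_recording(q);

    // Finalize once; the executable graph can then be replayed many times.
    auto exec = graph.finalize();
    for (int iter = 0; iter < 10; ++iter) {
        q.ext_oneapi_graph(exec);
    }
    q.wait_and_throw();

    sycl::free(a, q);
    sycl::free(b, q);
    sycl::free(c, q);
    return 0;
}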