Skip to content

[BLAS] SYCL-Graph integration for native-command #669

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 60 additions & 16 deletions src/blas/backends/cublas/cublas_scope_handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,36 +32,80 @@ namespace cublas {
*/
thread_local cublas_handle CublasScopedContextHandler::handle_helper = cublas_handle{};

CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {}
CublasScopedContextHandler::CublasScopedContextHandler(sycl::interop_handle& ih) : ih(ih) {
// Initialize streamID member to a CUstream associated with the queue `ih`
// has been submitted to.
streamId = ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();

cublasHandle_t CublasScopedContextHandler::get_handle() {
// Initialize the `cublasHandle_t` member `nativeHandle`
CUdevice device = ih.get_native_device<sycl::backend::ext_oneapi_cuda>();
CUstream streamId = get_stream();
cublasStatus_t err;

auto it = handle_helper.cublas_handle_mapper_.find(device);
if (it != handle_helper.cublas_handle_mapper_.end()) {
cublasHandle_t nativeHandle = it->second;
// Use existing handle if one already exists for the device, but update
// the native stream.
nativeHandle = it->second;
cudaStream_t currentStreamId;
cublasStatus_t err;
CUBLAS_ERROR_FUNC(cublasGetStream, err, nativeHandle, &currentStreamId);
if (currentStreamId != streamId) {
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
}
return nativeHandle;
}

cublasHandle_t nativeHandle;
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);

auto insert_iter =
else {
// Create a new handle if one doesn't already exist for the device
cublasStatus_t err;
CUBLAS_ERROR_FUNC(cublasCreate, err, &nativeHandle);
CUBLAS_ERROR_FUNC(cublasSetStream, err, nativeHandle, streamId);
handle_helper.cublas_handle_mapper_.insert(std::make_pair(device, nativeHandle));
}
}

return nativeHandle;
void CublasScopedContextHandler::begin_recording_if_graph() {
// interop_handle graph methods only available from extension version 2
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
if (!ih.ext_codeplay_has_graph()) {
return;
}

CUresult err;
#if CUDA_VERSION >= 12030
// After CUDA 12.3 we can use cuStreamBeginCaptureToGraph to capture
// the stream directly in the native graph, rather than needing to
// instantiate the stream capture as a new graph.
auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
CUDA_ERROR_FUNC(cuStreamBeginCaptureToGraph, err, streamId, graph, nullptr, nullptr, 0,
CU_STREAM_CAPTURE_MODE_GLOBAL);
#else
CUDA_ERROR_FUNC(cuStreamBeginCapture, err, streamId, CU_STREAM_CAPTURE_MODE_GLOBAL);
#endif // CUDA_VERSION
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
}

CUstream CublasScopedContextHandler::get_stream() {
return ih.get_native_queue<sycl::backend::ext_oneapi_cuda>();
void CublasScopedContextHandler::end_recording_if_graph() {
// interop_handle graph methods only available from extension version 2
#if SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
if (!ih.ext_codeplay_has_graph()) {
return;
}

auto graph = ih.ext_codeplay_get_native_graph<sycl::backend::ext_oneapi_cuda>();
CUresult err;
#if CUDA_VERSION >= 12030
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &graph);
#else
// cuStreamEndCapture returns a new graph, if we overwrite
// "graph" it won't be picked up by the SYCL runtime, as
// "ext_codeplay_get_native_graph" returns a passed-by-value pointer.
CUgraph recorded_graph;
CUDA_ERROR_FUNC(cuStreamEndCapture, err, streamId, &recorded_graph);

// Add graph to native graph as a child node
// Need to return a node object for the node to be created,
// can't be nullptr.
CUgraphNode node;
CUDA_ERROR_FUNC(cuGraphAddChildGraphNode, err, &node, graph, nullptr, 0, recorded_graph);
#endif // CUDA_VERSION
#endif // SYCL_EXT_ONEAPI_ENQUEUE_NATIVE_COMMAND >= 2
}
} // namespace cublas
} // namespace blas
Expand Down
45 changes: 38 additions & 7 deletions src/blas/backends/cublas/cublas_scope_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,49 @@ the handle must be destroyed when the context goes out of scope. This will bind
class CublasScopedContextHandler {
sycl::interop_handle& ih;
static thread_local cublas_handle handle_helper;
CUstream get_stream();
cublasHandle_t nativeHandle;
// Cache the native CU stream when the `CublasScopedContextHandler`object
// is constructed. This avoids calling `get_native_queue(ih)` multiple
// times which isn't guaranteed to return the same CUstream handle each
// time. A scenario that causes problems when trying to start/end cuda
// stream recording to a graph.
CUstream streamId;

public:
/**
* @brief Constructor
* @detail Creates the cublasHandle_t by implicitly impose the advice
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
* per thread).
*/
CublasScopedContextHandler(sycl::interop_handle& ih);

/**
* @brief get_handle: creates the handle by implicitly impose the advice
* given by nvidia for creating a cublas_handle. (e.g. one cuStream per device
* per thread).
* @return cublasHandle_t a handle to construct cublas routines
*/
cublasHandle_t get_handle();
* @brief Start recording cuBlas calls to a graph.
* @detail Checks if the command-group associated with \p ih is being added
* to a graph, and if so, begin stream recording of the native CUDA stream
* associated with \p queue to the native cuda-graph object.
*/
void begin_recording_if_graph();

/**
* @brief End recording cuBlas calls to a graph.
* @detail Checks if the command-group associated with \p ih is being added
* to a graph, and if so, ends stream recording of the native CUDA stream
* associated with \p queue to the native cuda-graph object. Doing any
* extra work to ensure that stream recorded calls get added as nodes to
* the native graph object associated with \p ih.
* @param queue The sycl queue to end stream recording on native stream
* backing the queue.
*/
void end_recording_if_graph();

/// @brief Query the cuBLAS handle created on construction
/// @return cublasHandle_t a handle to construct cublas routines
cublasHandle_t get_handle() const {
return nativeHandle;
}

// This is a work-around function for reinterpret_casting the memory. This
// will be fixed when SYCL-2020 has been implemented for Pi backend.
template <typename T, typename U>
Expand Down
2 changes: 2 additions & 0 deletions src/blas/backends/cublas/cublas_task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ static inline void host_task_internal(H& cgh, F f) {
cgh.host_task([f](sycl::interop_handle ih) {
#endif
auto sc = CublasScopedContextHandler(ih);
sc.begin_recording_if_graph();
f(sc);
sc.end_recording_if_graph();
});
}
#endif
Expand Down
3 changes: 2 additions & 1 deletion tests/unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ set(blas_TEST_LIST
blas_level2
blas_level3
blas_batch
blas_extensions)
blas_extensions
blas_sycl_graph)

set(blas_TEST_LINK "")

Expand Down
1 change: 1 addition & 0 deletions tests/unit_tests/blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ add_subdirectory(level2)
add_subdirectory(level3)
add_subdirectory(batch)
add_subdirectory(extensions)
add_subdirectory(sycl-graph)
61 changes: 61 additions & 0 deletions tests/unit_tests/blas/sycl-graph/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#===============================================================================
# Copyright 2025 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
#
#
# SPDX-License-Identifier: Apache-2.0
#===============================================================================

# Build object from all test sources
set(SYCL_GRAPH_SOURCES)

set(SYCL_GRAPH_SOURCES_W_CBLAS "gemm_usm.cpp" "gemm_batch_usm.cpp")

if(CBLAS_FOUND)
list(APPEND SYCL_GRAPH_SOURCES ${SYCL_GRAPH_SOURCES_W_CBLAS})
endif()

if(BUILD_SHARED_LIBS)
add_library(blas_sycl_graph_rt OBJECT ${SYCL_GRAPH_SOURCES})
target_compile_options(blas_sycl_graph_rt PRIVATE -DCALL_RT_API -DNOMINMAX)
target_include_directories(blas_sycl_graph_rt
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include
PUBLIC ${PROJECT_SOURCE_DIR}/include
PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include
PUBLIC ${CMAKE_BINARY_DIR}/bin
$<$<BOOL:${CBLAS_FOUND}>:${CBLAS_INCLUDE}>
)
if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
add_sycl_to_target(TARGET blas_sycl_graph_rt SOURCES ${SYCL_GRAPH_SOURCES})
else()
target_link_libraries(blas_sycl_graph_rt PUBLIC ONEMATH::SYCL::SYCL)
endif()
endif()

add_library(blas_sycl_graph_ct OBJECT ${SYCL_GRAPH_SOURCES})
target_compile_options(blas_sycl_graph_ct PRIVATE -DNOMINMAX)
target_include_directories(blas_sycl_graph_ct
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include
PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include
PUBLIC ${PROJECT_SOURCE_DIR}/include
PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include
PUBLIC ${CMAKE_BINARY_DIR}/bin
$<$<BOOL:${CBLAS_FOUND}>:${CBLAS_INCLUDE}>
)
if (USE_ADD_SYCL_TO_TARGET_INTEGRATION)
add_sycl_to_target(TARGET blas_sycl_graph_ct SOURCES ${SYCL_GRAPH_SOURCES})
else()
target_link_libraries(blas_sycl_graph_ct PUBLIC ONEMATH::SYCL::SYCL)
endif()
Loading
Loading