Skip to content

Commit 18081dc

Browse files
committed
[BLAS][generic] Fix compilation with AdaptiveCpp
The generic BLAS backend currently offers limited support for AdaptiveCpp, where: * complex data type is not supported * USM API is not supported Add the required protections to make the generic BLAS backend compile and run correctly in the capacity it offers with AdaptiveCpp. That is, make the buffer USM without complex data work fine. Throw the unimplemented exception for the unsupported features. This change relies on the update of the generic blas backend to version 0.2.0 in: uxlfoundation/generic-sycl-components#7
1 parent 09ede6e commit 18081dc

File tree

4 files changed

+28
-2
lines changed

4 files changed

+28
-2
lines changed

src/blas/backends/generic/CMakeLists.txt

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ endif()
141141
# If find_package doesn't work, download onemath_sycl_blas from Github. This is
142142
# intended to make oneMath easier to use.
143143
message(STATUS "Looking for oneMATH blas kernels")
144-
find_package(ONEMATH_SYCL_BLAS QUIET)
144+
find_package(ONEMATH_SYCL_BLAS 0.2.0 QUIET)
145145
if (NOT ONEMATH_SYCL_BLAS_FOUND)
146146
message(STATUS "Looking for onemath_sycl_blas for generic backend - could not find onemath_sycl_blas with ONEMATH_SYCL_BLAS_DIR")
147147
include(FetchContent)
@@ -150,7 +150,6 @@ if (NOT ONEMATH_SYCL_BLAS_FOUND)
150150
endif()
151151
# Following variable TUNING_TARGET will be used in generic blas internal configuration
152152
set(TUNING_TARGET ${GENERIC_BLAS_TUNING_TARGET})
153-
set(BLAS_ENABLE_COMPLEX ON)
154153
# Set the policy to forward variables to generic blas configure step
155154
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
156155
set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps")
@@ -161,13 +160,21 @@ if (NOT ONEMATH_SYCL_BLAS_FOUND)
161160
SOURCE_SUBDIR onemath/sycl/blas
162161
)
163162
FetchContent_MakeAvailable(onemath_sycl_blas)
163+
install(
164+
TARGETS onemath_sycl_blas
165+
EXPORT oneMathTargets
166+
)
164167
message(STATUS "Looking for onemath_sycl_blas - downloaded")
165168

166169
else()
167170
message(STATUS "Looking for oneMath blas kernels - found")
168171
add_library(onemath_sycl_blas ALIAS ONEMATH_SYCL_BLAS::onemath_sycl_blas)
169172
endif()
170173

174+
# Read cmake options exported by the onemath_sycl_blas project into oneMath variables
175+
set(ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX ${BLAS_ENABLE_COMPLEX} CACHE INTERNAL "Enable complex data support")
176+
set(ONEMATH_GENERIC_BLAS_ENABLE_USM ${BLAS_ENABLE_USM} CACHE INTERNAL "Enable USM API support")
177+
171178
set(SOURCES
172179
generic_level1_double.cpp generic_level1_float.cpp
173180
generic_level2_double.cpp generic_level2_float.cpp

src/blas/backends/generic/generic_common.hpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#define _GENERIC_BLAS_COMMON_HPP_
2222

2323
#include "onemath_sycl_blas.hpp"
24+
#include "oneapi/math/detail/config.hpp"
2425
#include "oneapi/math/types.hpp"
2526
#include "oneapi/math/exceptions.hpp"
2627

@@ -40,9 +41,11 @@ using handle_t = ::blas::SB_Handle;
4041
template <typename ElemT>
4142
using buffer_iterator_t = ::blas::BufferIterator<ElemT>;
4243

44+
#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
4345
// sycl complex data type (experimental)
4446
template <typename ElemT>
4547
using sycl_complex_t = sycl::ext::oneapi::experimental::complex<ElemT>;
48+
#endif
4649

4750
/** A trait for obtaining equivalent onemath_sycl_blas API types from oneMath API
4851
* types.
@@ -68,8 +71,10 @@ DEF_GENERIC_BLAS_TYPE(oneapi::math::transpose, char)
6871
DEF_GENERIC_BLAS_TYPE(oneapi::math::uplo, char)
6972
DEF_GENERIC_BLAS_TYPE(oneapi::math::side, char)
7073
DEF_GENERIC_BLAS_TYPE(oneapi::math::diag, char)
74+
#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
7175
DEF_GENERIC_BLAS_TYPE(std::complex<float>, sycl_complex_t<float>)
7276
DEF_GENERIC_BLAS_TYPE(std::complex<double>, sycl_complex_t<double>)
77+
#endif
7378
// Passthrough of onemath_sycl_blas arg types for more complex wrapping.
7479
DEF_GENERIC_BLAS_TYPE(::blas::gemm_batch_type_t, ::blas::gemm_batch_type_t)
7580

@@ -85,6 +90,7 @@ struct generic_type<ElemT*> {
8590
using type = ElemT*;
8691
};
8792

93+
#ifdef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
8894
// USM Complex
8995
template <typename ElemT>
9096
struct generic_type<std::complex<ElemT>*> {
@@ -95,6 +101,7 @@ template <typename ElemT>
95101
struct generic_type<const std::complex<ElemT>*> {
96102
using type = const sycl_complex_t<ElemT>*;
97103
};
104+
#endif
98105

99106
template <>
100107
struct generic_type<std::vector<sycl::event>> {
@@ -210,6 +217,10 @@ struct throw_if_unsupported_by_device {
210217
throw unimplemented("blas", "onemath_sycl_blas function"); \
211218
}
212219

220+
#ifndef ONEMATH_GENERIC_BLAS_ENABLE_USM
221+
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
222+
throw unimplemented("blas", "onemath_sycl_blas USM API", "- unsupported compiler");
223+
#else
213224
#define CALL_GENERIC_BLAS_USM_FN(genericFunc, ...) \
214225
if constexpr (is_column_major()) { \
215226
detail::throw_if_unsupported_by_device<double, sycl::aspect::fp64>{}( \
@@ -230,6 +241,7 @@ struct throw_if_unsupported_by_device {
230241
else { \
231242
throw unimplemented("blas", "onemath_sycl_blas function"); \
232243
}
244+
#endif
233245

234246
} // namespace generic
235247
} // namespace blas

src/blas/backends/generic/generic_level3.cxx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
3232
sycl::buffer<std::complex<real_t>, 1>& a, std::int64_t lda,
3333
sycl::buffer<std::complex<real_t>, 1>& b, std::int64_t ldb, std::complex<real_t> beta,
3434
sycl::buffer<std::complex<real_t>, 1>& c, std::int64_t ldc) {
35+
#ifndef ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
36+
throw unimplemented("blas", "onemath_sycl_blas gemm with complex data type",
37+
"- unsupported compiler");
38+
#else
3539
using sycl_complex_real_t = sycl::ext::oneapi::experimental::complex<real_t>;
3640
if (transa == oneapi::math::transpose::conjtrans ||
3741
transb == oneapi::math::transpose::conjtrans) {
@@ -62,6 +66,7 @@ void gemm(sycl::queue& queue, oneapi::math::transpose transa, oneapi::math::tran
6266
sycl::accessor<std::complex<real_t>, 1, sycl::access::mode::write> out_acc(c);
6367
sycl::accessor<sycl_complex_real_t, 1, sycl::access::mode::read> out_pb_acc(c_pb);
6468
queue.copy(out_pb_acc, out_acc);
69+
#endif
6570
}
6671

6772
void symm(sycl::queue& queue, oneapi::math::side left_right, oneapi::math::uplo upper_lower,

src/config.hpp.in

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_INTEL_CPU
3535
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_INTEL_GPU
3636
#cmakedefine ONEMATH_ENABLE_GENERIC_BLAS_BACKEND_NVIDIA_GPU
37+
#cmakedefine ONEMATH_GENERIC_BLAS_ENABLE_COMPLEX
38+
#cmakedefine ONEMATH_GENERIC_BLAS_ENABLE_USM
3739
#cmakedefine ONEMATH_ENABLE_PORTFFT_BACKEND
3840
#cmakedefine ONEMATH_ENABLE_ROCBLAS_BACKEND
3941
#cmakedefine ONEMATH_ENABLE_ROCFFT_BACKEND

0 commit comments

Comments
 (0)