Skip to content

Move lapack_info_check inside of onemkl_cusolver_host_task #238

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: develop
Choose a base branch
from
Prev Previous commit
Next Next commit
Asynchronously free ipiv 32-bit memory
  • Loading branch information
aidan.belton committed Aug 16, 2022
commit c515f253d00843f538babbdfe8b9a3816020eaf3
18 changes: 18 additions & 0 deletions src/lapack/backends/cusolver/cusolver_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,24 @@ inline int *create_dev_info(int num_elements = 1) {
return reinterpret_cast<int *>(dev_info_d);
}

// Helper function for waiting on a vector of sycl events
inline void depends_on_events(sycl::handler &cgh, std::vector<sycl::event> &dependencies = {}) {
for (sycl::event &e : dependencies)
cgh.depends_on(e);
}

// Asynchronously frees sycl USM `ptr` after waiting on events `dependencies`
template <typename T>
inline sycl::event free_async(sycl::queue &queue, T *ptr,
std::vector<sycl::event> &dependencies = {}) {
sycl::event done = queue.submit([&](sycl::handler &cgh) {
depends_on_events(cgh, dependencies);

cgh.host_task([=](sycl::interop_handle ih) { sycl::free(ptr, queue); });
});
return done;
}

} // namespace cusolver
} // namespace lapack
} // namespace mkl
Expand Down
10 changes: 3 additions & 7 deletions src/lapack/backends/cusolver/cusolver_lapack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1367,9 +1367,8 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s
});
});

queue.wait();
free_async(queue, ipiv32, { done_casting });

free(ipiv32, queue);
return done_casting;
}

Expand Down Expand Up @@ -1449,9 +1448,7 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue,
});
});

queue.wait();

free(ipiv32, queue);
free_async(queue, ipiv32, { done });

return done;
}
Expand Down Expand Up @@ -2207,9 +2204,8 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue,
});
});

queue.wait();
free_async(queue, ipiv32, { done_casting });

free(ipiv32, queue);
return done_casting;
}

Expand Down