Skip to content

feat: add support for compio runtime #1364

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 74 additions & 100 deletions .github/workflows/ci.yaml

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,5 @@ exclude = [
"examples/raft-kv-memstore-opendal-snapshot-data",
"examples/raft-kv-rocksdb",
"rt-monoio",
"rt-compio"
]
6 changes: 6 additions & 0 deletions openraft/src/instant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,9 @@ impl Instant for tokio::time::Instant {
tokio::time::Instant::now()
}
}

impl Instant for std::time::Instant {
fn now() -> Self {
std::time::Instant::now()
}
}
24 changes: 24 additions & 0 deletions rt-compio/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
[package]
name = "openraft-rt-compio"
description = "compio AsyncRuntime support for Openraft"
documentation = "https://docs.rs/openraft-rt-compio"
readme = "README.md"
version = "0.10.0"
edition = "2021"
authors = ["Databend Authors <[email protected]>"]
categories = ["algorithms", "asynchronous", "data-structures"]
homepage = "https://github.com/databendlabs/openraft"
keywords = ["consensus", "raft"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/databendlabs/openraft"

[dependencies]
openraft = { path = "../openraft", version = "0.10.0", default-features = false, features = ["singlethreaded"] }
compio = { version = "0.14.0", features = ["runtime", "time"] }
tokio = { version = "1.22", features = ["sync"], default-features = false }
rand = "0.9"
futures = "0.3"
pin-project-lite = "0.2.16"

[dev-dependencies]
compio = { version = "0.14.0", features = ["macros"] }
5 changes: 5 additions & 0 deletions rt-compio/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# openraft-rt-compio

compio [`AsyncRuntime`][rt_link] support for Openraft.

[rt_link]: https://docs.rs/openraft/latest/openraft/async_runtime/trait.AsyncRuntime.html
158 changes: 158 additions & 0 deletions rt-compio/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
use std::any::Any;
use std::fmt::Debug;
use std::fmt::Display;
use std::fmt::Error;
use std::fmt::Formatter;
use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;

pub use compio;
pub use futures;
use futures::FutureExt;
pub use openraft;
use openraft::AsyncRuntime;
use openraft::OptionalSend;
pub use rand;
use rand::rngs::ThreadRng;

mod mpsc;
mod mpsc_unbounded;
mod mutex;
mod oneshot;
mod watch;

#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CompioRuntime;

#[derive(Debug)]
pub struct CompioJoinError(#[allow(dead_code)] Box<dyn Any + Send>);

impl Display for CompioJoinError {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
write!(f, "Spawned task panicked")
}
}

pub struct CompioJoinHandle<T>(Option<compio::runtime::JoinHandle<T>>);

impl<T> Drop for CompioJoinHandle<T> {
fn drop(&mut self) {
let Some(j) = self.0.take() else {
return;
};
j.detach();
}
}

impl<T> Future for CompioJoinHandle<T> {
type Output = Result<T, CompioJoinError>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();
let task = this.0.as_mut().expect("Task has been cancelled");
match task.poll_unpin(cx) {
Poll::Ready(Ok(v)) => Poll::Ready(Ok(v)),
Poll::Ready(Err(e)) => Poll::Ready(Err(CompioJoinError(e))),
Poll::Pending => Poll::Pending,
}
}
}

pub type BoxedFuture<T> = Pin<Box<dyn Future<Output = T>>>;

pin_project_lite::pin_project! {
pub struct CompioTimeout<F> {
#[pin]
future: F,
delay: BoxedFuture<()>
}
}

impl<F: Future> Future for CompioTimeout<F> {
type Output = Result<F::Output, compio::time::Elapsed>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
match this.delay.poll_unpin(cx) {
Poll::Ready(()) => {
// The delay has elapsed, so we return an error.
Poll::Ready(Err(compio::time::Elapsed))
}
Poll::Pending => {
// The delay has not yet elapsed, so we poll the future.
match this.future.poll(cx) {
Poll::Ready(v) => Poll::Ready(Ok(v)),
Poll::Pending => Poll::Pending,
}
}
}
}
}

impl AsyncRuntime for CompioRuntime {
type JoinError = CompioJoinError;
type JoinHandle<T: OptionalSend + 'static> = CompioJoinHandle<T>;
type Sleep = BoxedFuture<()>;
type Instant = std::time::Instant;
type TimeoutError = compio::time::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = CompioTimeout<T>;
type ThreadLocalRng = ThreadRng;
type Mpsc = mpsc::CompioMpsc;
type MpscUnbounded = mpsc_unbounded::TokioMpscUnbounded;
type Watch = watch::TokioWatch;
type Oneshot = oneshot::FuturesOneshot;
type Mutex<T: OptionalSend + 'static> = mutex::TokioMutex<T>;

fn spawn<T>(fut: T) -> Self::JoinHandle<T::Output>
where
T: Future + OptionalSend + 'static,
T::Output: OptionalSend + 'static,
{
CompioJoinHandle(Some(compio::runtime::spawn(fut)))
}

fn sleep(duration: std::time::Duration) -> Self::Sleep {
Box::pin(compio::time::sleep(duration))
}

fn sleep_until(deadline: Self::Instant) -> Self::Sleep {
Box::pin(compio::time::sleep_until(deadline))
}

fn timeout<R, F: Future<Output = R> + OptionalSend>(
duration: std::time::Duration,
future: F,
) -> Self::Timeout<R, F> {
let delay = Box::pin(compio::time::sleep(duration));
CompioTimeout { future, delay }
}

fn timeout_at<R, F: Future<Output = R> + OptionalSend>(deadline: Self::Instant, future: F) -> Self::Timeout<R, F> {
let delay = Box::pin(compio::time::sleep_until(deadline));
CompioTimeout { future, delay }
}

fn is_panic(_: &Self::JoinError) -> bool {
// Task only returns `JoinError` if the spawned future panics.
true
}

fn thread_rng() -> Self::ThreadLocalRng {
rand::rng()
}
}

#[cfg(test)]
mod tests {
use openraft::testing::runtime::Suite;

use super::*;

#[test]
fn test_compio_rt() {
let rt = compio::runtime::Runtime::new().unwrap();
rt.block_on(Suite::<CompioRuntime>::test_all());
}
}
87 changes: 87 additions & 0 deletions rt-compio/src/mpsc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::future::Future;

use futures::TryFutureExt;
use openraft::async_runtime::Mpsc;
use openraft::async_runtime::MpscReceiver;
use openraft::async_runtime::MpscSender;
use openraft::async_runtime::MpscWeakSender;
use openraft::async_runtime::SendError;
use openraft::async_runtime::TryRecvError;
use openraft::OptionalSend;
use tokio::sync::mpsc as tokio_mpsc;

pub struct CompioMpsc;

pub struct CompioMpscSender<T>(tokio_mpsc::Sender<T>);

impl<T> Clone for CompioMpscSender<T> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

pub struct CompioMpscReceiver<T>(tokio_mpsc::Receiver<T>);

pub struct CompioMpscWeakSender<T>(tokio_mpsc::WeakSender<T>);

impl<T> Clone for CompioMpscWeakSender<T> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl Mpsc for CompioMpsc {
type Sender<T: OptionalSend> = CompioMpscSender<T>;
type Receiver<T: OptionalSend> = CompioMpscReceiver<T>;
type WeakSender<T: OptionalSend> = CompioMpscWeakSender<T>;

#[inline]
fn channel<T: OptionalSend>(buffer: usize) -> (Self::Sender<T>, Self::Receiver<T>) {
let (tx, rx) = tokio_mpsc::channel(buffer);
let tx_wrapper = CompioMpscSender(tx);
let rx_wrapper = CompioMpscReceiver(rx);

(tx_wrapper, rx_wrapper)
}
}

impl<T> MpscSender<CompioMpsc, T> for CompioMpscSender<T>
where T: OptionalSend
{
#[inline]
fn send(&self, msg: T) -> impl Future<Output = Result<(), SendError<T>>> {
self.0.send(msg).map_err(|e| SendError(e.0))
}

#[inline]
fn downgrade(&self) -> <CompioMpsc as Mpsc>::WeakSender<T> {
let inner = self.0.downgrade();
CompioMpscWeakSender(inner)
}
}

impl<T> MpscReceiver<T> for CompioMpscReceiver<T> {
#[inline]
fn recv(&mut self) -> impl Future<Output = Option<T>> {
self.0.recv()
}

#[inline]
fn try_recv(&mut self) -> Result<T, TryRecvError> {
self.0.try_recv().map_err(|e| match e {
tokio_mpsc::error::TryRecvError::Empty => TryRecvError::Empty,
tokio_mpsc::error::TryRecvError::Disconnected => TryRecvError::Disconnected,
})
}
}

impl<T> MpscWeakSender<CompioMpsc, T> for CompioMpscWeakSender<T>
where T: OptionalSend
{
#[inline]
fn upgrade(&self) -> Option<<CompioMpsc as Mpsc>::Sender<T>> {
self.0.upgrade().map(CompioMpscSender)
}
}
81 changes: 81 additions & 0 deletions rt-compio/src/mpsc_unbounded.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//! Unbounded MPSC channel wrapper types and their trait impl.

use openraft::type_config::async_runtime::mpsc_unbounded;
use openraft::OptionalSend;
use tokio::sync::mpsc as tokio_mpsc;

pub struct TokioMpscUnbounded;

pub struct TokioMpscUnboundedSender<T>(tokio_mpsc::UnboundedSender<T>);

impl<T> Clone for TokioMpscUnboundedSender<T> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

pub struct TokioMpscUnboundedReceiver<T>(tokio_mpsc::UnboundedReceiver<T>);

pub struct TokioMpscUnboundedWeakSender<T>(tokio_mpsc::WeakUnboundedSender<T>);

impl<T> Clone for TokioMpscUnboundedWeakSender<T> {
#[inline]
fn clone(&self) -> Self {
Self(self.0.clone())
}
}

impl mpsc_unbounded::MpscUnbounded for TokioMpscUnbounded {
type Sender<T: OptionalSend> = TokioMpscUnboundedSender<T>;
type Receiver<T: OptionalSend> = TokioMpscUnboundedReceiver<T>;
type WeakSender<T: OptionalSend> = TokioMpscUnboundedWeakSender<T>;

#[inline]
fn channel<T: OptionalSend>() -> (Self::Sender<T>, Self::Receiver<T>) {
let (tx, rx) = tokio_mpsc::unbounded_channel();
let tx_wrapper = TokioMpscUnboundedSender(tx);
let rx_wrapper = TokioMpscUnboundedReceiver(rx);

(tx_wrapper, rx_wrapper)
}
}

impl<T> mpsc_unbounded::MpscUnboundedSender<TokioMpscUnbounded, T> for TokioMpscUnboundedSender<T>
where T: OptionalSend
{
#[inline]
fn send(&self, msg: T) -> Result<(), mpsc_unbounded::SendError<T>> {
self.0.send(msg).map_err(|e| mpsc_unbounded::SendError(e.0))
}

#[inline]
fn downgrade(&self) -> <TokioMpscUnbounded as mpsc_unbounded::MpscUnbounded>::WeakSender<T> {
let inner = self.0.downgrade();
TokioMpscUnboundedWeakSender(inner)
}
}

impl<T> mpsc_unbounded::MpscUnboundedReceiver<T> for TokioMpscUnboundedReceiver<T> {
#[inline]
async fn recv(&mut self) -> Option<T> {
self.0.recv().await
}

#[inline]
fn try_recv(&mut self) -> Result<T, mpsc_unbounded::TryRecvError> {
self.0.try_recv().map_err(|e| match e {
tokio_mpsc::error::TryRecvError::Empty => mpsc_unbounded::TryRecvError::Empty,
tokio_mpsc::error::TryRecvError::Disconnected => mpsc_unbounded::TryRecvError::Disconnected,
})
}
}

impl<T> mpsc_unbounded::MpscUnboundedWeakSender<TokioMpscUnbounded, T> for TokioMpscUnboundedWeakSender<T>
where T: OptionalSend
{
#[inline]
fn upgrade(&self) -> Option<<TokioMpscUnbounded as mpsc_unbounded::MpscUnbounded>::Sender<T>> {
self.0.upgrade().map(TokioMpscUnboundedSender)
}
}
22 changes: 22 additions & 0 deletions rt-compio/src/mutex.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use std::future::Future;

use openraft::type_config::async_runtime::mutex;
use openraft::OptionalSend;

pub struct TokioMutex<T>(tokio::sync::Mutex<T>);

impl<T> mutex::Mutex<T> for TokioMutex<T>
where T: OptionalSend + 'static
{
type Guard<'a> = tokio::sync::MutexGuard<'a, T>;

#[inline]
fn new(value: T) -> Self {
TokioMutex(tokio::sync::Mutex::new(value))
}

#[inline]
fn lock(&self) -> impl Future<Output = Self::Guard<'_>> + OptionalSend {
self.0.lock()
}
}
Loading
Loading