From 8cadfe6a3b7af1fa5854d19e68cd1a45574c38f7 Mon Sep 17 00:00:00 2001 From: tison Date: Fri, 30 May 2025 18:07:59 +0800 Subject: [PATCH 1/3] feat: oneshot channel Signed-off-by: tison --- README.md | 2 + mea/src/lib.rs | 17 + mea/src/oneshot/mod.rs | 743 +++++++++++++++++++++++++++++++++++++++ mea/src/oneshot/tests.rs | 133 +++++++ 4 files changed, 895 insertions(+) create mode 100644 mea/src/oneshot/mod.rs create mode 100644 mea/src/oneshot/tests.rs diff --git a/README.md b/README.md index cf39866..a5cfe97 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ Mea (Make Easy Async) is a runtime-agnostic library providing essential synchron * [**Semaphore**](https://docs.rs/mea/*/mea/semaphore/struct.Semaphore.html): A synchronization primitive that controls access to a shared resource. * [**ShutdownSend & ShutdownRecv**](https://docs.rs/mea/*/mea/shutdown/): A composite synchronization primitive for managing shutdown signals. * [**WaitGroup**](https://docs.rs/mea/*/mea/waitgroup/struct.WaitGroup.html): A synchronization primitive that allows waiting for multiple tasks to complete. +* [**oneshot::channel**](https://docs.rs/mea/*/mea/oneshot/index.html): A one-shot channel for sending a single value between tasks. ## Installation @@ -68,6 +69,7 @@ This crate collects runtime-agnostic synchronization primitives from spare parts * **RwLock** is derived from `tokio::sync::RwLock`, but the `max_readers` can be any `usize` instead of `[0, u32::MAX >> 3]`. No blocking method is provided, since it can be easily implemented with block_on of any runtime. * **Semaphore** is derived from `tokio::sync::Semaphore`, without `close` method since it is quite tricky to use. And thus, this semaphore doesn't have the limitation of max permits. Besides, new methods like `forget_exact` are added to fit the specific use case. * **WaitGroup** is inspired by [`waitgroup-rs`](https://github.com/laizy/waitgroup-rs), with a different implementation based on the internal `CountdownState` primitive. It fixes the unsound issue as described [here](https://github.com/rust-lang/futures-rs/issues/2880#issuecomment-2333842804). +* **oneshot::channel** is derived from [`oneshot`](https://github.com/faern/oneshot), with significant simplifications since we need not support synchronized receiving functions. Other parts are written from scratch. diff --git a/mea/src/lib.rs b/mea/src/lib.rs index f67efbb..2622076 100644 --- a/mea/src/lib.rs +++ b/mea/src/lib.rs @@ -33,6 +33,7 @@ //! * [`ShutdownSend`] & [`ShutdownRecv`]: A composite synchronization primitive for managing //! shutdown signals //! * [`WaitGroup`]: A synchronization primitive that allows waiting for multiple tasks to complete +//! * [`oneshot`]: A one-shot channel for sending a single value between tasks. //! //! ## Runtime Agnostic //! @@ -62,6 +63,7 @@ pub mod barrier; pub mod condvar; pub mod latch; pub mod mutex; +pub mod oneshot; pub mod rwlock; pub mod semaphore; pub mod shutdown; @@ -83,12 +85,14 @@ mod tests { use crate::latch::Latch; use crate::mutex::Mutex; use crate::mutex::MutexGuard; + use crate::oneshot; use crate::rwlock::RwLock; use crate::rwlock::RwLockReadGuard; use crate::rwlock::RwLockWriteGuard; use crate::semaphore::Semaphore; use crate::shutdown::ShutdownRecv; use crate::shutdown::ShutdownSend; + use crate::waitgroup::Wait; use crate::waitgroup::WaitGroup; #[test] @@ -106,6 +110,15 @@ mod tests { do_assert_send_and_sync::>(); do_assert_send_and_sync::>(); do_assert_send_and_sync::>(); + do_assert_send_and_sync::>(); + do_assert_send_and_sync::>(); + } + + #[test] + fn assert_send() { + fn do_assert_send() {} + do_assert_send::>(); + do_assert_send::>(); } #[test] @@ -118,10 +131,14 @@ mod tests { do_assert_unpin::(); do_assert_unpin::(); do_assert_unpin::(); + do_assert_unpin::(); do_assert_unpin::>(); do_assert_unpin::>(); do_assert_unpin::>(); do_assert_unpin::>(); do_assert_unpin::>(); + do_assert_unpin::>(); + do_assert_unpin::>(); + do_assert_unpin::>(); } } diff --git a/mea/src/oneshot/mod.rs b/mea/src/oneshot/mod.rs new file mode 100644 index 0000000..8f9142e --- /dev/null +++ b/mea/src/oneshot/mod.rs @@ -0,0 +1,743 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This implementation is derived from the `oneshot` crate [1], with significant simplifications +// since mea needs not support synchronized receiving functions. +// +// [1] https://github.com/faern/oneshot/blob/25274e99/src/lib.rs + +//! A one-shot channel is used for sending a single message between +//! asynchronous tasks. The [`channel`] function is used to create a +//! [`Sender`] and [`Receiver`] handle pair that form the channel. +//! +//! The `Sender` handle is used by the producer to send the value. +//! The `Receiver` handle is used by the consumer to receive the value. +//! +//! Each handle can be used on separate tasks. +//! +//! Since the `send` method is not async, it can be used anywhere. This includes +//! sending between two runtimes, and using it from non-async code. +//! +//! # Examples +//! +//! ``` +//! # #[tokio::main] +//! # async fn main() { +//! use mea::oneshot; +//! +//! let (tx, rx) = oneshot::channel(); +//! +//! tokio::spawn(async move { +//! if let Err(_) = tx.send(3) { +//! println!("the receiver dropped"); +//! } +//! }); +//! +//! match rx.await { +//! Ok(v) => println!("got = {:?}", v), +//! Err(_) => println!("the sender dropped"), +//! } +//! # } +//! ``` +//! +//! If the sender is dropped without sending, the receiver will fail with +//! [`RecvError`]: +//! +//! ``` +//! # #[tokio::main] +//! # async fn main() { +//! use mea::oneshot; +//! +//! let (tx, rx) = oneshot::channel::(); +//! +//! tokio::spawn(async move { +//! drop(tx); +//! }); +//! +//! match rx.await { +//! Ok(_) => panic!("This doesn't happen"), +//! Err(_) => println!("the sender dropped"), +//! } +//! # } +//! ``` + +use std::cell::UnsafeCell; +use std::fmt; +use std::future::Future; +use std::future::IntoFuture; +use std::hint; +use std::marker::PhantomData; +use std::mem; +use std::mem::MaybeUninit; +use std::pin::Pin; +use std::ptr; +use std::ptr::NonNull; +use std::sync::atomic::fence; +use std::sync::atomic::AtomicU8; +use std::sync::atomic::Ordering; +use std::task::Context; +use std::task::Poll; +use std::task::Waker; + +#[cfg(test)] +mod tests; + +/// Creates a new oneshot channel and returns the two endpoints, [`Sender`] and [`Receiver`]. +pub fn channel() -> (Sender, Receiver) { + let channel_ptr = NonNull::from(Box::leak(Box::new(Channel::new()))); + let sender = Sender { + channel_ptr, + _invariant: PhantomData, + }; + let receiver = Receiver { channel_ptr }; + (sender, receiver) +} + +/// Sends a value to the associated [`Receiver`]. +#[derive(Debug)] +pub struct Sender { + channel_ptr: NonNull>, + _invariant: PhantomData T>, +} + +unsafe impl Send for Sender {} +unsafe impl Sync for Sender {} + +#[inline(always)] +fn sender_wake_up_receiver(channel: &Channel, state: u8) { + // ORDERING: Synchronizes with writing waker to memory, and prevents the + // taking of the waker from being ordered before this operation. + fence(Ordering::Acquire); + + // Take the waker, but critically do not awake it. If we awake it now, the + // receiving thread could still observe the AWAKING state and re-await, meaning + // that after we change to the MESSAGE state, it would remain waiting indefinitely + // or until a spurious wakeup. + // + // SAFETY: at this point we are in the AWAKING state, and the receiving thread + // does not access the waker while in this state, nor does it free the channel + // allocation in this state. + let waker = unsafe { channel.take_waker() }; + + // ORDERING: this ordering serves two-fold: it synchronizes with the acquire load + // in the receiving thread, ensuring that both our read of the waker and write of + // the message happen-before the taking of the message and freeing of the channel. + // Furthermore, we need acquire ordering to ensure awaking the receiver + // happens after the channel state is updated. + channel.state.swap(state, Ordering::AcqRel); + + // Note: it is possible that between the store above and this statement that + // the receiving thread is spuriously awakened, takes the message, and frees + // the channel allocation. However, we took ownership of the channel out of + // that allocation, and freeing the channel does not drop the waker since the + // waker is wrapped in MaybeUninit. Therefore, this data is valid regardless of + // whether the receiver has completed by this point. + waker.wake(); +} + +impl Sender { + /// Attempts to send a value on this channel, returning an error contains the message if it + /// could not be sent. + pub fn send(self, message: T) -> Result<(), SendError> { + let channel_ptr = self.channel_ptr; + + // Do not run the Drop implementation if send was called, any cleanup happens below. + mem::forget(self); + + // SAFETY: The channel exists on the heap for the entire duration of this method, and we + // only ever acquire shared references to it. Note that if the receiver disconnects it + // does not free the channel. + let channel = unsafe { channel_ptr.as_ref() }; + + // Write the message into the channel on the heap. + // + // SAFETY: The receiver only ever accesses this memory location if we are in the MESSAGE + // state, and since we are responsible for setting that state, we can guarantee that we have + // exclusive access to this memory location to perform this write. + unsafe { channel.write_message(message) }; + + // Update the state to signal there is a message on the channel: + // + // * EMPTY + 1 = MESSAGE + // * RECEIVING + 1 = AWAKING + // * DISCONNECTED + 1 = EMPTY (invalid), however this state is never observed + // + // ORDERING: we use release ordering to ensure writing the message is visible to the + // receiving thread. The EMPTY and DISCONNECTED branches do not observe any shared state, + // and thus we do not need an acquire ordering. The RECEIVING branch manages synchronization + // independent of this operation. + match channel.state.fetch_add(1, Ordering::Release) { + // The receiver is alive and has not started waiting. Send done. + EMPTY => Ok(()), + // The receiver is waiting. Wake it up so it can return the message. + RECEIVING => { + sender_wake_up_receiver(channel, MESSAGE); + Ok(()) + } + // The receiver was already dropped. The error is responsible for freeing the channel. + // + // SAFETY: since the receiver disconnected it will no longer access `channel_ptr`, so + // we can transfer exclusive ownership of the channel's resources to the error. + // Moreover, since we just placed the message in the channel, the channel contains a + // valid message. + DISCONNECTED => Err(SendError { channel_ptr }), + state => unreachable!("unexpected channel state: {}", state), + } + } + + /// Returns true if the associated [`Receiver`] has been dropped. + /// + /// If true is returned, a future call to send is guaranteed to return an error. + pub fn is_closed(&self) -> bool { + // SAFETY: The channel exists on the heap for the entire duration of this method, and we + // only ever acquire shared references to it. Note that if the receiver disconnects it + // does not free the channel. + let channel = unsafe { self.channel_ptr.as_ref() }; + + // ORDERING: We *chose* a Relaxed ordering here as it sufficient to enforce the method's + // contract: "if true is returned, a future call to send is guaranteed to return an error." + // + // Once true has been observed, it will remain true. However, if false is observed, + // the receiver might have just disconnected but this thread has not observed it yet. + matches!(channel.state.load(Ordering::Relaxed), DISCONNECTED) + } +} + +impl Drop for Sender { + fn drop(&mut self) { + // SAFETY: The receiver only ever frees the channel if we are in the MESSAGE or + // DISCONNECTED states. + // + // * If we are in the MESSAGE state, then we called mem::forget(self), so we should + // not be in this function call. + // * If we are in the DISCONNECTED state, then the receiver either received a MESSAGE + // so this statement is unreachable, or was dropped and observed that our side was still + // alive, and thus didn't free the channel. + let channel = unsafe { self.channel_ptr.as_ref() }; + + // Update the channel state to disconnected: + // + // * EMPTY ^ 001 = DISCONNECTED + // * RECEIVING ^ 001 = AWAKING + // * DISCONNECTED ^ 001 = EMPTY (invalid), but this state is never observed + // + // ORDERING: we need not release ordering here since there are no modifications we + // need to make visible to other thread, and the Err(RECEIVING) branch handles + // synchronization independent of this fetch_xor + match channel.state.fetch_xor(0b001, Ordering::Relaxed) { + // The receiver has not started waiting, nor is it dropped. + EMPTY => {} + // The receiver is waiting. Wake it up so it can detect that the channel disconnected. + RECEIVING => sender_wake_up_receiver(channel, DISCONNECTED), + // The receiver was already dropped. We are responsible for freeing the channel. + DISCONNECTED => { + // SAFETY: when the receiver switches the state to DISCONNECTED they have received + // the message or will no longer be trying to receive the message, and have + // observed that the sender is still alive, meaning that we are responsible for + // freeing the channel allocation. + unsafe { dealloc(self.channel_ptr) }; + } + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +/// Receives a value from the associated [`Sender`]. +#[derive(Debug)] +pub struct Receiver { + channel_ptr: NonNull>, +} + +unsafe impl Send for Receiver {} + +impl IntoFuture for Receiver { + type Output = Result; + + type IntoFuture = Recv; + + fn into_future(self) -> Self::IntoFuture { + let Receiver { channel_ptr } = self; + // Do not run our Drop implementation, since the receiver lives on as the new future. + mem::forget(self); + Recv { channel_ptr } + } +} + +impl Receiver { + /// Returns true if the associated [`Sender`] was dropped before sending a message. Or if + /// the message has already been received. + /// + /// If `true` is returned, all future calls to receive the message are guaranteed to return + /// [`RecvError`]. And future calls to this method is guaranteed to also return `true`. + pub fn is_closed(&self) -> bool { + // SAFETY: the existence of the `self` parameter serves as a certificate that the receiver + // is still alive, meaning that even if the sender was dropped then it would have observed + // the fact that we are still alive and left the responsibility of deallocating the + // channel to us, so `self.channel` is valid + let channel = unsafe { self.channel_ptr.as_ref() }; + + // ORDERING: We *chose* a Relaxed ordering here as it is sufficient to + // enforce the method's contract. + // + // Once true has been observed, it will remain true. However, if false is observed, + // the sender might have just disconnected but this thread has not observed it yet. + matches!(channel.state.load(Ordering::Relaxed), DISCONNECTED) + } + + /// Returns true if there is a message in the channel, ready to be received. + /// + /// If `true` is returned, the next call to receive the message is guaranteed to return + /// the message immediately. + pub fn has_message(&self) -> bool { + // SAFETY: the existence of the `self` parameter serves as a certificate that the receiver + // is still alive, meaning that even if the sender was dropped then it would have observed + // the fact that we are still alive and left the responsibility of deallocating the + // channel to us, so `self.channel` is valid + let channel = unsafe { self.channel_ptr.as_ref() }; + + // ORDERING: An acquire ordering is used to guarantee no subsequent loads is reordered + // before this one. This upholds the contract that if true is returned, the next call to + // receive the message is guaranteed to also observe the `MESSAGE` state and return the + // message immediately. + matches!(channel.state.load(Ordering::Acquire), MESSAGE) + } + + /// Checks if there is a message in the channel without blocking. Returns: + /// + /// * `Ok(Some(message))` if there was a message in the channel. + /// * `Ok(None)` if the [`Sender`] is alive, but has not yet sent a message. + /// * `Err(RecvError)` if the [`Sender`] was dropped before sending anything or if the message + /// has already been extracted by a previous `try_recv` call. + /// + /// If a message is returned, the channel is disconnected and any subsequent receive operation + /// using this receiver will return an [`RecvError`]. + pub fn try_recv(&self) -> Result, RecvError> { + // SAFETY: The channel will not be freed while this method is still running. + let channel = unsafe { self.channel_ptr.as_ref() }; + + // ORDERING: we use acquire ordering to synchronize with the store of the message. + match channel.state.load(Ordering::Acquire) { + EMPTY => Ok(None), + DISCONNECTED => Err(RecvError(())), + MESSAGE => { + // It is okay to break up the load and store since once we are in the MESSAGE state, + // the sender no longer modifies the state + // + // ORDERING: at this point the sender has done its job and is no longer active, so + // we need not make any side effects visible to it. + channel.state.store(DISCONNECTED, Ordering::Relaxed); + + // SAFETY: we are in the MESSAGE state so the message is present + Ok(Some(unsafe { channel.take_message() })) + } + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +impl Drop for Receiver { + fn drop(&mut self) { + // SAFETY: since the receiving side is still alive the sender would have observed that and + // left deallocating the channel allocation to us. + let channel = unsafe { self.channel_ptr.as_ref() }; + + // Set the channel state to disconnected and read what state the receiver was in. + match channel.state.swap(DISCONNECTED, Ordering::Acquire) { + // The sender has not sent anything, nor is it dropped. + EMPTY => {} + // The sender already sent something. We must drop it, and free the channel. + MESSAGE => { + unsafe { channel.drop_message() }; + unsafe { dealloc(self.channel_ptr) }; + } + // The sender was already dropped. We are responsible for freeing the channel. + DISCONNECTED => { + unsafe { dealloc(self.channel_ptr) }; + } + // NOTE: the receiver, unless transformed into a future, will never see the + // RECEIVING or AWAKING states, so we can ignore them here. + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +/// A future that completes when the message is sent from the associated [`Sender`], or the +/// [`Sender`] is dropped before sending a message. +#[derive(Debug)] +pub struct Recv { + channel_ptr: NonNull>, +} + +unsafe impl Send for Recv {} + +fn recv_awaken(channel: &Channel) -> Poll> { + loop { + hint::spin_loop(); + + // ORDERING: The load above has already synchronized with writing message. + match channel.state.load(Ordering::Relaxed) { + AWAKING => {} + DISCONNECTED => break Poll::Ready(Err(RecvError(()))), + MESSAGE => { + // ORDERING: the sender has been dropped, so this update only + // needs to be visible to us. + channel.state.store(DISCONNECTED, Ordering::Relaxed); + // SAFETY: We observed the MESSAGE state. + break Poll::Ready(Ok(unsafe { channel.take_message() })); + } + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +impl Future for Recv { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + // SAFETY: the existence of the `self` parameter serves as a certificate that the receiver + // is still alive, meaning that even if the sender was dropped then it would have observed + // the fact that we are still alive and left the responsibility of deallocating the + // channel to us, so `self.channel` is valid + let channel = unsafe { self.channel_ptr.as_ref() }; + + // ORDERING: we use acquire ordering to synchronize with the store of the message. + match channel.state.load(Ordering::Acquire) { + // The sender is alive but has not sent anything yet. + EMPTY => { + let waker = cx.waker().clone(); + // SAFETY: We can not be in the forbidden states, and no waker in the channel. + unsafe { channel.write_waker(waker) } + } + // The sender sent the message. + MESSAGE => { + // ORDERING: the sender has been dropped so this update only needs to be + // visible to us. + channel.state.store(DISCONNECTED, Ordering::Relaxed); + Poll::Ready(Ok(unsafe { channel.take_message() })) + } + // We were polled again while waiting for the sender. Replace the waker with the new + // one. + RECEIVING => { + // ORDERING: We use relaxed ordering on both success and failure since we have not + // written anything above that must be released, and the individual match arms + // handle any additional synchronization. + match channel.state.compare_exchange( + RECEIVING, + EMPTY, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + // We successfully changed the state back to EMPTY. + // + // This is the most likely branch to be taken, which is why we do not use any + // memory barriers in the compare_exchange above. + Ok(_) => { + let waker = cx.waker().clone(); + + // SAFETY: We wrote the waker in a previous call to poll. We do not need + // a memory barrier since the previous write here was by ourselves. + unsafe { channel.drop_waker() }; + + // SAFETY: We can not be in the forbidden states, and no waker in the + // channel. + unsafe { channel.write_waker(waker) } + } + // The sender sent the message while we prepared to replace the waker. + // We take the message and mark the channel disconnected. + // The sender has already taken the waker. + Err(MESSAGE) => { + // ORDERING: Synchronize with writing message. This branch is + // unlikely to be taken. + channel.state.swap(DISCONNECTED, Ordering::Acquire); + + // SAFETY: The state tells us the sender has initialized the message. + Poll::Ready(Ok(unsafe { channel.take_message() })) + } + // The sender is currently waking us up. + Err(AWAKING) => recv_awaken(channel), + // The sender was dropped before sending anything while we prepared to park. + // The sender has taken the waker already. + Err(DISCONNECTED) => Poll::Ready(Err(RecvError(()))), + Err(state) => unreachable!("unexpected channel state: {}", state), + } + } + // The sender has observed the RECEIVING state and is currently reading the waker from + // a previous poll. We need to loop here until we observe the MESSAGE or DISCONNECTED + // state. We busy loop here since we know the sender is done very soon. + AWAKING => recv_awaken(channel), + // The sender was dropped before sending anything. + DISCONNECTED => Poll::Ready(Err(RecvError(()))), + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +impl Drop for Recv { + fn drop(&mut self) { + // SAFETY: since the receiving side is still alive the sender would have observed that and + // left deallocating the channel allocation to us. + let channel = unsafe { self.channel_ptr.as_ref() }; + + // Set the channel state to disconnected and read what state the receiver was in. + match channel.state.swap(DISCONNECTED, Ordering::Acquire) { + // The sender has not sent anything, nor is it dropped. + EMPTY => {} + // The sender already sent something. We must drop it, and free the channel. + MESSAGE => { + unsafe { channel.drop_message() }; + unsafe { dealloc(self.channel_ptr) }; + } + // The receiver has been polled. We must drop the waker. + RECEIVING => { + unsafe { channel.drop_waker() }; + } + // The sender was already dropped. We are responsible for freeing the channel. + DISCONNECTED => { + // SAFETY: see safety comment at top of function. + unsafe { dealloc(self.channel_ptr) }; + } + // This receiver was previously polled, so the channel was in the RECEIVING state. + // But the sender has observed the RECEIVING state and is currently reading the waker + // to wake us up. We need to loop here until we observe the MESSAGE or DISCONNECTED + // state. We busy loop here since we know the sender is done very soon. + AWAKING => { + loop { + hint::spin_loop(); + + // ORDERING: The swap above has already synchronized with writing message. + match channel.state.load(Ordering::Relaxed) { + AWAKING => {} + DISCONNECTED => break, + MESSAGE => { + // SAFETY: we are in the message state so the message is initialized. + unsafe { channel.drop_message() }; + break; + } + state => unreachable!("unexpected channel state: {}", state), + } + } + unsafe { dealloc(self.channel_ptr) }; + } + state => unreachable!("unexpected channel state: {}", state), + } + } +} + +/// Internal channel data structure. +/// +/// The [`channel`] method allocates and puts one instance of this struct on the heap for each +/// oneshot channel instance. The struct holds: +/// +/// * The current state of the channel. +/// * The message in the channel. This memory is uninitialized until the message is sent. +/// * The waker instance for the task that is currently receiving on this channel. This memory is +/// uninitialized until the receiver starts receiving. +struct Channel { + state: AtomicU8, + message: UnsafeCell>, + waker: UnsafeCell>, +} + +impl Channel { + const fn new() -> Self { + Self { + state: AtomicU8::new(EMPTY), + message: UnsafeCell::new(MaybeUninit::uninit()), + waker: UnsafeCell::new(MaybeUninit::uninit()), + } + } + + #[inline(always)] + unsafe fn message(&self) -> &MaybeUninit { + &*self.message.get() + } + + #[inline(always)] + unsafe fn write_message(&self, message: T) { + let slot = &mut *self.message.get(); + slot.as_mut_ptr().write(message); + } + + #[inline(always)] + unsafe fn drop_message(&self) { + let slot = &mut *self.message.get(); + slot.assume_init_drop(); + } + + #[inline(always)] + unsafe fn take_message(&self) -> T { + ptr::read(self.message.get()).assume_init() + } + + /// # Safety + /// + /// * The `waker` field must not have a waker stored when calling this method. + /// * The `state` must not be in the RECEIVING state when calling this method. + unsafe fn write_waker(&self, waker: Waker) -> Poll> { + // Write the waker instance to the channel. + // + // SAFETY: we are not yet in the RECEIVING state, meaning that the sender will not + // try to access the waker until it sees the state set to RECEIVING below. + let slot = &mut *self.waker.get(); + slot.as_mut_ptr().write(waker); + + // ORDERING: we use release ordering on success so the sender can synchronize with + // our write of the waker. We use relaxed ordering on failure since the sender does + // not need to synchronize with our write and the individual match arms handle any + // additional synchronization + match self + .state + .compare_exchange(EMPTY, RECEIVING, Ordering::Release, Ordering::Relaxed) + { + // We stored our waker, now we return and let the sender wake us up. + Ok(_) => Poll::Pending, + // The sender sent the message while we prepared to await. + // We take the message and mark the channel disconnected. + Err(MESSAGE) => { + // ORDERING: Synchronize with writing message. This branch is unlikely to be taken, + // so it is likely more efficient to use a fence here instead of AcqRel ordering on + // the compare_exchange operation. + fence(Ordering::Acquire); + + // SAFETY: we started in the EMPTY state and the sender switched us to the + // MESSAGE state. This means that it did not take the waker, so we're + // responsible for dropping it. + self.drop_waker(); + + // ORDERING: sender does not exist, so this update only needs to be visible to us. + self.state.store(DISCONNECTED, Ordering::Relaxed); + + // SAFETY: The MESSAGE state tells us there is a correctly initialized message. + Poll::Ready(Ok(self.take_message())) + } + // The sender was dropped before sending anything while we prepared to await. + Err(DISCONNECTED) => { + // SAFETY: we started in the EMPTY state and the sender switched us to the + // DISCONNECTED state. This means that it did not take the waker, so we are + // responsible for dropping it. + self.drop_waker(); + Poll::Ready(Err(RecvError(()))) + } + Err(state) => unreachable!("unexpected channel state: {}", state), + } + } + + #[inline(always)] + unsafe fn drop_waker(&self) { + let slot = &mut *self.waker.get(); + slot.assume_init_drop(); + } + + #[inline(always)] + unsafe fn take_waker(&self) -> Waker { + ptr::read(self.waker.get()).assume_init() + } +} + +unsafe fn dealloc(channel: NonNull>) { + drop(Box::from_raw(channel.as_ptr())) +} + +/// An error returned when trying to send on a closed channel. Returned from +/// [`Sender::send`] if the corresponding [`Receiver`] has already been dropped. +/// +/// The message that could not be sent can be retrieved again with [`SendError::into_inner`]. +pub struct SendError { + channel_ptr: NonNull>, +} + +unsafe impl Send for SendError {} +unsafe impl Sync for SendError {} + +impl SendError { + /// Get a reference to the message that failed to be sent. + pub fn as_inner(&self) -> &T { + unsafe { self.channel_ptr.as_ref().message().assume_init_ref() } + } + + /// Consumes the error and returns the message that failed to be sent. + pub fn into_inner(self) -> T { + let channel_ptr = self.channel_ptr; + + // Do not run destructor if we consumed ourselves. Freeing happens below. + mem::forget(self); + + // SAFETY: we have ownership of the channel + let channel: &Channel = unsafe { channel_ptr.as_ref() }; + + // SAFETY: we know that the message is initialized according to the safety requirements of + // `new` + let message = unsafe { channel.take_message() }; + + // SAFETY: we own the channel + unsafe { dealloc(channel_ptr) }; + + message + } +} + +impl Drop for SendError { + fn drop(&mut self) { + // SAFETY: we have ownership of the channel and require that the message is initialized + // upon construction + unsafe { + self.channel_ptr.as_ref().drop_message(); + dealloc(self.channel_ptr); + } + } +} + +impl fmt::Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "sending on a closed channel".fmt(f) + } +} + +impl fmt::Debug for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "SendError<{}>(..)", stringify!(T)) + } +} + +impl std::error::Error for SendError {} + +/// An error returned when receiving the message. +/// +/// The receiving operation can only fail if the corresponding [`Sender`] was dropped +/// before sending any message, or if a message has already been received on the channel. +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct RecvError(()); + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + "receiving on a closed channel".fmt(f) + } +} + +impl std::error::Error for RecvError {} + +/// The initial channel state. Active while both endpoints are still alive, no message has been +/// sent, and the receiver is not receiving. +const EMPTY: u8 = 0b011; +/// A message has been sent to the channel, but the receiver has not yet read it. +const MESSAGE: u8 = 0b100; +/// No message has yet been sent on the channel, but the receiver future ([`Recv`]) is currently +/// receiving. +const RECEIVING: u8 = 0b000; +/// A message is sending to the channel, or the channel is closing. The receiver future ([`Recv`]) +/// is currently being awakened. +const AWAKING: u8 = 0b001; +/// The channel has been closed. This means that either the sender or receiver has been dropped, +/// or the message sent to the channel has already been received. +const DISCONNECTED: u8 = 0b010; diff --git a/mea/src/oneshot/tests.rs b/mea/src/oneshot/tests.rs new file mode 100644 index 0000000..5d1f2a1 --- /dev/null +++ b/mea/src/oneshot/tests.rs @@ -0,0 +1,133 @@ +// Copyright 2024 tison +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::future::IntoFuture; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use std::time::Duration; + +use crate::oneshot; + +struct DropCounterHandle(Arc); + +impl DropCounterHandle { + pub fn count(&self) -> usize { + self.0.load(Ordering::SeqCst) + } +} + +struct DropCounter { + drop_count: Arc, + value: Option, +} + +impl DropCounter { + fn new(value: T) -> (Self, DropCounterHandle) { + let drop_count = Arc::new(AtomicUsize::new(0)); + ( + Self { + drop_count: drop_count.clone(), + value: Some(value), + }, + DropCounterHandle(drop_count), + ) + } + + fn value(&self) -> &T { + self.value.as_ref().unwrap() + } +} + +impl Drop for DropCounter { + fn drop(&mut self) { + self.drop_count.fetch_add(1, Ordering::SeqCst); + } +} + +#[tokio::test] +async fn send_before_await() { + let (sender, receiver) = oneshot::channel(); + assert!(sender.send(19i128).is_ok()); + assert_eq!(receiver.await, Ok(19i128)); +} + +#[tokio::test] +async fn await_with_dropped_sender() { + let (sender, receiver) = oneshot::channel::(); + drop(sender); + receiver.await.unwrap_err(); +} + +#[tokio::test] +async fn await_before_send() { + let (sender, receiver) = oneshot::channel(); + let (message, counter) = DropCounter::new(79u128); + let t = tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(10)).await; + sender.send(message) + }); + let returned_message = receiver.await.unwrap(); + assert_eq!(counter.count(), 0); + assert_eq!(*returned_message.value(), 79u128); + drop(returned_message); + assert_eq!(counter.count(), 1); + t.await.unwrap().unwrap(); +} + +#[tokio::test] +async fn await_before_send_then_drop_sender() { + let (sender, receiver) = oneshot::channel::(); + let t = tokio::spawn(async { + tokio::time::sleep(Duration::from_millis(10)).await; + drop(sender); + }); + assert!(receiver.await.is_err()); + t.await.unwrap(); +} + +#[tokio::test] +async fn poll_receiver_then_drop_it() { + let (sender, receiver) = oneshot::channel::<()>(); + // This will poll the receiver and then give up after 100 ms. + tokio::time::timeout(Duration::from_millis(100), receiver) + .await + .unwrap_err(); + // Make sure the receiver has been dropped by the runtime. + assert!(sender.send(()).is_err()); +} + +#[tokio::test] +async fn recv_within_select() { + let (tx, rx) = oneshot::channel::<&'static str>(); + let mut interval = tokio::time::interval(Duration::from_secs(100)); + + let handle = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(1)).await; + tx.send("shut down").unwrap(); + }); + + let mut recv = rx.into_future(); + loop { + tokio::select! { + _ = interval.tick() => println!("another 100ms"), + msg = &mut recv => { + println!("Got message: {}", msg.unwrap()); + break; + } + } + } + + handle.await.unwrap(); +} From b9c062b37b670cccbff3922bfc9191e3f1e22dee Mon Sep 17 00:00:00 2001 From: tison Date: Fri, 30 May 2025 18:39:13 +0800 Subject: [PATCH 2/3] more tests Signed-off-by: tison --- mea/src/oneshot/tests.rs | 159 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 156 insertions(+), 3 deletions(-) diff --git a/mea/src/oneshot/tests.rs b/mea/src/oneshot/tests.rs index 5d1f2a1..3d7a81b 100644 --- a/mea/src/oneshot/tests.rs +++ b/mea/src/oneshot/tests.rs @@ -12,10 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::future::Future; use std::future::IntoFuture; +use std::mem; +use std::pin::Pin; +use std::sync::atomic::AtomicU32; use std::sync::atomic::AtomicUsize; use std::sync::atomic::Ordering; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::RawWaker; +use std::task::RawWakerVTable; +use std::task::Waker; use std::time::Duration; use crate::oneshot; @@ -111,17 +120,17 @@ async fn poll_receiver_then_drop_it() { #[tokio::test] async fn recv_within_select() { let (tx, rx) = oneshot::channel::<&'static str>(); - let mut interval = tokio::time::interval(Duration::from_secs(100)); + let mut interval = tokio::time::interval(Duration::from_millis(10)); let handle = tokio::spawn(async move { - tokio::time::sleep(Duration::from_secs(1)).await; + tokio::time::sleep(Duration::from_millis(100)).await; tx.send("shut down").unwrap(); }); let mut recv = rx.into_future(); loop { tokio::select! { - _ = interval.tick() => println!("another 100ms"), + _ = interval.tick() => println!("another 10ms"), msg = &mut recv => { println!("Got message: {}", msg.unwrap()); break; @@ -131,3 +140,147 @@ async fn recv_within_select() { handle.await.unwrap(); } + +#[derive(Default)] +pub struct WakerHandle { + clone_count: AtomicU32, + drop_count: AtomicU32, + wake_count: AtomicU32, +} + +impl WakerHandle { + pub fn clone_count(&self) -> u32 { + self.clone_count.load(Ordering::Relaxed) + } + + pub fn drop_count(&self) -> u32 { + self.drop_count.load(Ordering::Relaxed) + } + + pub fn wake_count(&self) -> u32 { + self.wake_count.load(Ordering::Relaxed) + } +} + +fn waker() -> (Waker, Arc) { + let waker_handle = Arc::new(WakerHandle::default()); + let waker_handle_ptr = Arc::into_raw(waker_handle.clone()); + let raw_waker = RawWaker::new(waker_handle_ptr as *const _, waker_vtable()); + (unsafe { Waker::from_raw(raw_waker) }, waker_handle) +} + +fn waker_vtable() -> &'static RawWakerVTable { + &RawWakerVTable::new(clone_raw, wake_raw, wake_by_ref_raw, drop_raw) +} + +unsafe fn clone_raw(data: *const ()) -> RawWaker { + let handle: Arc = Arc::from_raw(data as *const _); + handle.clone_count.fetch_add(1, Ordering::Relaxed); + mem::forget(handle.clone()); + mem::forget(handle); + RawWaker::new(data, waker_vtable()) +} + +unsafe fn wake_raw(data: *const ()) { + let handle: Arc = Arc::from_raw(data as *const _); + handle.wake_count.fetch_add(1, Ordering::Relaxed); + handle.drop_count.fetch_add(1, Ordering::Relaxed); +} + +unsafe fn wake_by_ref_raw(data: *const ()) { + let handle: Arc = Arc::from_raw(data as *const _); + handle.wake_count.fetch_add(1, Ordering::Relaxed); + mem::forget(handle) +} + +unsafe fn drop_raw(data: *const ()) { + let handle: Arc = Arc::from_raw(data as *const _); + handle.drop_count.fetch_add(1, Ordering::Relaxed); + drop(handle) +} + +#[test] +fn poll_then_send() { + let (sender, receiver) = oneshot::channel::(); + let mut receiver = receiver.into_future(); + + let (waker, waker_handle) = waker(); + let mut context = Context::from_waker(&waker); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending); + assert_eq!(waker_handle.clone_count(), 1); + assert_eq!(waker_handle.drop_count(), 0); + assert_eq!(waker_handle.wake_count(), 0); + + sender.send(1234).unwrap(); + assert_eq!(waker_handle.clone_count(), 1); + assert_eq!(waker_handle.drop_count(), 1); + assert_eq!(waker_handle.wake_count(), 1); + + assert_eq!( + Pin::new(&mut receiver).poll(&mut context), + Poll::Ready(Ok(1234)) + ); + assert_eq!(waker_handle.clone_count(), 1); + assert_eq!(waker_handle.drop_count(), 1); + assert_eq!(waker_handle.wake_count(), 1); +} + +#[test] +fn poll_with_different_wakers() { + let (sender, receiver) = oneshot::channel::(); + let mut receiver = receiver.into_future(); + + let (waker1, waker_handle1) = waker(); + let mut context1 = Context::from_waker(&waker1); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context1), Poll::Pending); + assert_eq!(waker_handle1.clone_count(), 1); + assert_eq!(waker_handle1.drop_count(), 0); + assert_eq!(waker_handle1.wake_count(), 0); + + let (waker2, waker_handle2) = waker(); + let mut context2 = Context::from_waker(&waker2); + + assert_eq!(Pin::new(&mut receiver).poll(&mut context2), Poll::Pending); + assert_eq!(waker_handle1.clone_count(), 1); + assert_eq!(waker_handle1.drop_count(), 1); + assert_eq!(waker_handle1.wake_count(), 0); + + assert_eq!(waker_handle2.clone_count(), 1); + assert_eq!(waker_handle2.drop_count(), 0); + assert_eq!(waker_handle2.wake_count(), 0); + + // Sending should cause the waker from the latest poll to be woken up + sender.send(1234).unwrap(); + assert_eq!(waker_handle1.clone_count(), 1); + assert_eq!(waker_handle1.drop_count(), 1); + assert_eq!(waker_handle1.wake_count(), 0); + + assert_eq!(waker_handle2.clone_count(), 1); + assert_eq!(waker_handle2.drop_count(), 1); + assert_eq!(waker_handle2.wake_count(), 1); +} + +#[test] +fn poll_then_drop_receiver_during_send() { + let (sender, receiver) = oneshot::channel::(); + let mut receiver = receiver.into_future(); + + let (waker, _waker_handle) = waker(); + let mut context = Context::from_waker(&waker); + + // Put the channel into the receiving state + assert_eq!(Pin::new(&mut receiver).poll(&mut context), Poll::Pending); + + // Spawn a separate thread that sends in parallel + let t = std::thread::spawn(move || { + let _ = sender.send(1234); + }); + + // Drop the receiver. + drop(receiver); + + // The send operation should also not have panicked + t.join().unwrap(); +} From 30430450a9d0114d2545affdfd8402794c4fd896 Mon Sep 17 00:00:00 2001 From: tison Date: Fri, 30 May 2025 19:22:59 +0800 Subject: [PATCH 3/3] simplify variance Signed-off-by: tison --- mea/src/oneshot/mod.rs | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/mea/src/oneshot/mod.rs b/mea/src/oneshot/mod.rs index 8f9142e..424b708 100644 --- a/mea/src/oneshot/mod.rs +++ b/mea/src/oneshot/mod.rs @@ -77,7 +77,6 @@ use std::fmt; use std::future::Future; use std::future::IntoFuture; use std::hint; -use std::marker::PhantomData; use std::mem; use std::mem::MaybeUninit; use std::pin::Pin; @@ -96,19 +95,13 @@ mod tests; /// Creates a new oneshot channel and returns the two endpoints, [`Sender`] and [`Receiver`]. pub fn channel() -> (Sender, Receiver) { let channel_ptr = NonNull::from(Box::leak(Box::new(Channel::new()))); - let sender = Sender { - channel_ptr, - _invariant: PhantomData, - }; - let receiver = Receiver { channel_ptr }; - (sender, receiver) + (Sender { channel_ptr }, Receiver { channel_ptr }) } /// Sends a value to the associated [`Receiver`]. #[derive(Debug)] pub struct Sender { channel_ptr: NonNull>, - _invariant: PhantomData T>, } unsafe impl Send for Sender {}