1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325
use futures_sink::Sink;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{fmt, mem};
use tokio::sync::mpsc::OwnedPermit;
use tokio::sync::mpsc::Sender;
use super::ReusableBoxFuture;
/// Error produced by a `PollSender` once it can no longer send because the channel is closed.
#[derive(Debug)]
pub struct PollSendError<T>(Option<T>);

impl<T> PollSendError<T> {
    /// Unwraps this error, yielding the value it carries, if any.
    ///
    /// An error raised by `start_send`/`send_item` carries the item the caller tried to send.
    /// An error raised anywhere else carries nothing, so `None` is returned.
    pub fn into_inner(self) -> Option<T> {
        let PollSendError(carried) = self;
        carried
    }
}

impl<T> fmt::Display for PollSendError<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("channel closed")
    }
}

// `Error` needs `Debug`, and the derived `Debug` needs `T: Debug`.
impl<T: fmt::Debug> std::error::Error for PollSendError<T> {}
/// Internal state machine driving a `PollSender` through reserve → send.
#[derive(Debug)]
enum State<T> {
    /// No reservation in progress; holds the sender that the next `poll_reserve` hands to the
    /// acquire future.
    Idle(Sender<T>),
    /// A permit is being acquired by the future stored in `PollSender::acquire`. The sender
    /// was moved into that future, so this variant carries nothing.
    Acquiring,
    /// A slot is reserved. The permit is consumed by `send_item` or released by `abort_send`.
    ReadyToSend(OwnedPermit<T>),
    /// No further sends are possible. Reached after `close`, or once `poll_reserve` observed
    /// that the channel is closed.
    Closed,
}
/// A wrapper around [`mpsc::Sender`] that can be polled.
///
/// Reserve a slot with `poll_reserve`, then hand an item to that slot with `send_item`.
///
/// [`mpsc::Sender`]: tokio::sync::mpsc::Sender
#[derive(Debug)]
pub struct PollSender<T> {
    /// Spare sender handle. It is exposed by `get_ref`, cloned to rebuild `State::Idle` after an
    /// aborted acquire and when cloning, and set to `None` by `close`.
    sender: Option<Sender<T>>,
    /// Current position in the reserve/send state machine.
    state: State<T>,
    /// Reusable boxed future that acquires a permit while in `State::Acquiring`.
    acquire: PollSenderFuture<T>,
}
// Builds the future that waits for a permit from the underlying channel, guaranteeing that the
// next send has capacity.
//
// Both the real case (`Some`) and the placeholder case (`None`) go through this single async fn.
// That way every future stored in the `ReusableBoxFuture` has one concrete type, and therefore
// one size and alignment.
async fn make_acquire_future<T>(
    data: Option<Sender<T>>,
) -> Result<OwnedPermit<T>, PollSendError<T>> {
    // The `None` placeholder is only installed to drop an in-flight acquire or to initialise
    // storage. The `Acquiring` state is entered only after a `Some` future is set, so the
    // placeholder is never polled.
    let sender = match data {
        Some(sender) => sender,
        None => unreachable!("this future should not be pollable in this state"),
    };
    // A failed reservation means the channel is closed. No item exists yet to hand back.
    match sender.reserve_owned().await {
        Ok(permit) => Ok(permit),
        Err(_) => Err(PollSendError(None)),
    }
}
/// Boxed, reusable storage for the permit-acquiring future produced by `make_acquire_future`.
type InnerFuture<'a, T> = ReusableBoxFuture<'a, Result<OwnedPermit<T>, PollSendError<T>>>;
/// Owner of the permit-acquiring future used by `PollSender`.
///
/// The inner future is stored as `'static` even though `T` itself need not be `'static`.
/// The `unsafe` lifetime casts in `new` and `set` rely on the future never being moved out of
/// this wrapper.
#[derive(Debug)]
// TODO: This should be replaced with a type_alias_impl_trait to eliminate `'static` and all the transmutes
struct PollSenderFuture<T>(InnerFuture<'static, T>);
impl<T> PollSenderFuture<T> {
/// Create with an empty inner future with no `Send` bound.
fn empty() -> Self {
// We don't use `make_acquire_future` here because our relaxed bounds on `T` are not
// compatible with the transitive bounds required by `Sender<T>`.
Self(ReusableBoxFuture::new(async { unreachable!() }))
}
}
impl<T: Send> PollSenderFuture<T> {
    /// Create with an empty inner future.
    ///
    /// The initial future is the `make_acquire_future(None)` placeholder, which panics if
    /// polled. It must be replaced via `set(Some(..))` before the first `poll`.
    fn new() -> Self {
        let v = InnerFuture::new(make_acquire_future(None));
        // SAFETY: This is safe because `make_acquire_future(None)` is actually `'static`: it holds
        // no `Sender<T>` and borrows nothing. Only the type-level lifetime is changed.
        Self(unsafe { mem::transmute::<InnerFuture<'_, T>, InnerFuture<'static, T>>(v) })
    }

    /// Poll the inner future.
    ///
    /// Panics (via `unreachable!`) if the inner future is still the `None` placeholder.
    fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<OwnedPermit<T>, PollSendError<T>>> {
        self.0.poll(cx)
    }

    /// Replace the inner future with `make_acquire_future(sender)`.
    ///
    /// The previous future is dropped. Passing `None` therefore cancels an in-flight acquire,
    /// installing the never-to-be-polled placeholder in its place.
    fn set(&mut self, sender: Option<Sender<T>>) {
        // View the `'static`-typed storage with a shorter lifetime so a future whose type
        // mentions `T` can be stored in it.
        let inner: *mut InnerFuture<'static, T> = &mut self.0;
        let inner: *mut InnerFuture<'_, T> = inner.cast();
        // SAFETY: The `make_acquire_future(sender)` future must not exist after the type `T`
        // becomes invalid, and this casts away the type-level lifetime check for that. However, the
        // inner future is never moved out of this `PollSenderFuture<T>`, so the future will not
        // live longer than the `PollSenderFuture<T>` lives. A `PollSenderFuture<T>` is guaranteed
        // to not exist after the type `T` becomes invalid, because it is annotated with a `T`, so
        // this is ok.
        let inner = unsafe { &mut *inner };
        inner.set(make_acquire_future(sender));
    }
}
impl<T: Send> PollSender<T> {
    /// Creates a new `PollSender`.
    ///
    /// A clone of `sender` is kept for `get_ref` and for rebuilding the idle state. The original
    /// becomes the sender of the initial idle state.
    pub fn new(sender: Sender<T>) -> Self {
        Self {
            sender: Some(sender.clone()),
            state: State::Idle(sender),
            acquire: PollSenderFuture::new(),
        }
    }

    /// Moves the current state out, leaving `State::Closed` behind as a placeholder.
    ///
    /// Callers store the real next state back into `self.state` afterwards.
    fn take_state(&mut self) -> State<T> {
        mem::replace(&mut self.state, State::Closed)
    }

    /// Attempts to prepare the sender to receive a value.
    ///
    /// This method must be called and return `Poll::Ready(Ok(()))` prior to each call to
    /// `send_item`.
    ///
    /// This method returns `Poll::Ready` once the underlying channel is ready to receive a value,
    /// by reserving a slot in the channel for the item to be sent. If this method returns
    /// `Poll::Pending`, the current task is registered to be notified (via
    /// `cx.waker().wake_by_ref()`) when `poll_reserve` should be called again.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned. This is a permanent state.
    pub fn poll_reserve(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), PollSendError<T>>> {
        // Each pass yields an optional result plus the state to store. A `None` result (only
        // produced from `Idle`) means "state advanced, go around again".
        loop {
            let (result, next_state) = match self.take_state() {
                State::Idle(sender) => {
                    // Start trying to acquire a permit to reserve a slot for our send, and
                    // immediately loop back around to poll it the first time.
                    self.acquire.set(Some(sender));
                    (None, State::Acquiring)
                }
                State::Acquiring => match self.acquire.poll(cx) {
                    // Channel has capacity.
                    Poll::Ready(Ok(permit)) => {
                        (Some(Poll::Ready(Ok(()))), State::ReadyToSend(permit))
                    }
                    // Channel is closed.
                    Poll::Ready(Err(e)) => (Some(Poll::Ready(Err(e))), State::Closed),
                    // Channel doesn't have capacity yet, so we need to wait.
                    Poll::Pending => (Some(Poll::Pending), State::Acquiring),
                },
                // We're closed, either by choice or because the underlying sender was closed.
                s @ State::Closed => (Some(Poll::Ready(Err(PollSendError(None)))), s),
                // We're already ready to send an item.
                s @ State::ReadyToSend(_) => (Some(Poll::Ready(Ok(()))), s),
            };
            self.state = next_state;
            if let Some(result) = result {
                return result;
            }
        }
    }

    /// Sends an item to the channel.
    ///
    /// Before calling `send_item`, `poll_reserve` must be called with a successful return
    /// value of `Poll::Ready(Ok(()))`.
    ///
    /// # Errors
    ///
    /// If the channel is closed, an error will be returned. This is a permanent state.
    ///
    /// # Panics
    ///
    /// If `poll_reserve` was not successfully called prior to calling `send_item`, then this method
    /// will panic. Because the state has already been taken at that point, the `PollSender` is
    /// left in the closed state after such a panic.
    #[track_caller]
    pub fn send_item(&mut self, value: T) -> Result<(), PollSendError<T>> {
        let (result, next_state) = match self.take_state() {
            State::Idle(_) | State::Acquiring => {
                panic!("`send_item` called without first calling `poll_reserve`")
            }
            // We have a permit to send our item, so go ahead, which gets us our sender back.
            State::ReadyToSend(permit) => (Ok(()), State::Idle(permit.send(value))),
            // We're closed, either by choice or because the underlying sender was closed.
            State::Closed => (Err(PollSendError(Some(value))), State::Closed),
        };
        // Handle deferred closing if `close` was called between `poll_reserve` and `send_item`.
        self.state = if self.sender.is_some() {
            next_state
        } else {
            State::Closed
        };
        result
    }

    /// Checks whether this sender has been closed.
    ///
    /// Returns `true` after `close`, or once `poll_reserve` has observed that the channel is
    /// closed. The underlying channel that this sender was wrapping may still be open.
    pub fn is_closed(&self) -> bool {
        matches!(self.state, State::Closed) || self.sender.is_none()
    }

    /// Gets a reference to the `Sender` of the underlying channel.
    ///
    /// If `PollSender` has been closed with `close`, `None` is returned. The underlying channel
    /// that this sender was wrapping may still be open. If the closure was instead detected by
    /// `poll_reserve`, `is_closed` returns `true` while this still returns `Some`.
    pub fn get_ref(&self) -> Option<&Sender<T>> {
        self.sender.as_ref()
    }

    /// Closes this sender.
    ///
    /// No more messages will be able to be sent from this sender, but the underlying channel will
    /// remain open until all senders have dropped, or until the [`Receiver`] closes the channel.
    ///
    /// If a slot was previously reserved by calling `poll_reserve`, then a final call can be made
    /// to `send_item` in order to consume the reserved slot. After that, no further sends will be
    /// possible. If you do not intend to send another item, you can release the reserved slot back
    /// to the underlying sender by calling [`abort_send`].
    ///
    /// [`abort_send`]: crate::sync::PollSender::abort_send
    /// [`Receiver`]: tokio::sync::mpsc::Receiver
    pub fn close(&mut self) {
        // Mark ourselves officially closed by dropping our main sender.
        self.sender = None;
        // If we're already idle, closed, or we haven't yet reserved a slot, we can quickly
        // transition to the closed state. Otherwise, leave the existing permit in place for the
        // caller if they want to complete the send.
        match self.state {
            State::Idle(_) => self.state = State::Closed,
            State::Acquiring => {
                // Dropping the in-flight acquire future also drops the sender it owns.
                self.acquire.set(None);
                self.state = State::Closed;
            }
            _ => {}
        }
    }

    /// Aborts the current in-progress send, if any.
    ///
    /// Returns `true` if a send was aborted. If the sender was closed prior to calling
    /// `abort_send`, then the sender will remain in the closed state, otherwise the sender will be
    /// ready to attempt another send.
    pub fn abort_send(&mut self) -> bool {
        // We may have been closed in the meantime, after a call to `poll_reserve` already
        // succeeded. We'll check if `self.sender` is `None` to see if we should transition to the
        // closed state when we actually abort a send, rather than resetting ourselves back to idle.
        let (result, next_state) = match self.take_state() {
            // We're currently trying to reserve a slot to send into.
            State::Acquiring => {
                // Replacing the future drops the in-flight one.
                self.acquire.set(None);
                // If we haven't closed yet, we have to clone our stored sender since we have no way
                // to get it back from the acquire future we just dropped.
                let state = match self.sender.clone() {
                    Some(sender) => State::Idle(sender),
                    None => State::Closed,
                };
                (true, state)
            }
            // We got the permit. If we haven't closed yet, get the sender back.
            State::ReadyToSend(permit) => {
                let state = if self.sender.is_some() {
                    State::Idle(permit.release())
                } else {
                    // Dropping the permit here gives the reserved slot back.
                    State::Closed
                };
                (true, state)
            }
            // `Idle` and `Closed`: nothing in flight to abort, so restore the state unchanged.
            s => (false, s),
        };
        self.state = next_state;
        result
    }
}
impl<T> Clone for PollSender<T> {
/// Clones this `PollSender`.
///
/// The resulting `PollSender` will have an initial state identical to calling `PollSender::new`.
fn clone(&self) -> PollSender<T> {
let (sender, state) = match self.sender.clone() {
Some(sender) => (Some(sender.clone()), State::Idle(sender)),
None => (None, State::Closed),
};
Self {
sender,
state,
acquire: PollSenderFuture::empty(),
}
}
}
/// `Sink` adapter over the inherent `PollSender` API.
///
/// Only `T: Send` is required, matching the inherent `impl<T: Send> PollSender<T>` that every
/// method here delegates to. `PollSenderFuture` erases the future's lifetime precisely so that
/// `T: 'static` is not needed.
impl<T: Send> Sink<T> for PollSender<T> {
    type Error = PollSendError<T>;

    /// Reserves capacity for one item; delegates to `PollSender::poll_reserve`.
    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).poll_reserve(cx)
    }

    /// Always ready. `start_send` hands the item to the channel immediately through the
    /// reserved permit, so nothing is ever buffered in this wrapper.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Sends `item` using the slot reserved by `poll_ready`; delegates to
    /// `PollSender::send_item`, including its panic when no slot was reserved.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        Pin::into_inner(self).send_item(item)
    }

    /// Closes this sender via `PollSender::close` and completes immediately. The underlying
    /// channel itself may remain open.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Pin::into_inner(self).close();
        Poll::Ready(Ok(()))
    }
}