1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use crate::{Py, PyAny, PyObject};
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};

/// State shared between a [`CancelHandle`] and the [`ThrowCallback`]s created from it.
#[derive(Debug)]
#[derive(Default)]
struct Inner {
    /// Exception thrown into the coroutine that has not been retrieved yet.
    exception: Option<PyObject>,
    /// Waker registered by the last poll that found no pending exception.
    waker: Option<Waker>,
}

/// Helper used to wait for and retrieve the exception thrown into a
/// [`Coroutine`](super::Coroutine).
///
/// Only the last exception thrown can be retrieved: an exception thrown before the
/// previous one has been retrieved replaces it.
#[derive(Debug, Default)]
pub struct CancelHandle(Arc<Mutex<Inner>>);

impl CancelHandle {
    /// Create a new `CancelHandle` with no pending exception.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns whether an exception has been thrown into the associated coroutine and
    /// has not been retrieved yet.
    ///
    /// Retrieving the exception with [`poll_cancelled`](Self::poll_cancelled) or
    /// [`cancelled`](Self::cancelled) clears it, so this returns `false` again afterwards
    /// until another exception is thrown.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn is_cancelled(&self) -> bool {
        self.0.lock().unwrap().exception.is_some()
    }

    /// Poll to retrieve the exception thrown in the associated coroutine.
    ///
    /// If an exception is pending, it is removed from the handle and returned as
    /// `Poll::Ready`. Otherwise the waker from `cx` is registered (replacing any
    /// previously registered one) so that the next thrown exception wakes this task,
    /// and `Poll::Pending` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn poll_cancelled(&mut self, cx: &mut Context<'_>) -> Poll<PyObject> {
        let mut inner = self.0.lock().unwrap();
        if let Some(exc) = inner.exception.take() {
            return Poll::Ready(exc);
        }
        // Skip the clone when the already-registered waker would wake the same task.
        if let Some(ref waker) = inner.waker {
            if cx.waker().will_wake(waker) {
                return Poll::Pending;
            }
        }
        inner.waker = Some(cx.waker().clone());
        Poll::Pending
    }

    /// Wait until an exception is thrown into the associated coroutine and return it.
    ///
    /// Resolves immediately if an exception is already pending. The exception is
    /// consumed, exactly as with [`poll_cancelled`](Self::poll_cancelled).
    pub async fn cancelled(&mut self) -> PyObject {
        Cancelled(self).await
    }

    /// Create a [`ThrowCallback`] sharing this handle's state, through which exceptions
    /// thrown into the coroutine are delivered to this handle.
    #[doc(hidden)]
    pub fn throw_callback(&self) -> ThrowCallback {
        ThrowCallback(Arc::clone(&self.0))
    }
}

// Hand-written future standing in for `poll_fn`, which the MSRV does not provide.
/// Future awaited by [`CancelHandle::cancelled`]; it resolves to the next exception
/// retrieved through [`CancelHandle::poll_cancelled`].
struct Cancelled<'a>(&'a mut CancelHandle);

impl Future for Cancelled<'_> {
    type Output = PyObject;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // The only field is a `&mut CancelHandle`, which is `Unpin`, so unpinning is sound.
        let this = self.get_mut();
        this.0.poll_cancelled(cx)
    }
}

/// Delivers exceptions thrown into a coroutine to the [`CancelHandle`] it was created
/// from (see [`CancelHandle::throw_callback`]); both share the same state.
#[doc(hidden)]
pub struct ThrowCallback(Arc<Mutex<Inner>>);

impl ThrowCallback {
    /// Store `exc` as the pending exception, replacing any exception not yet retrieved,
    /// and wake the task registered on the handle, if any.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    pub(super) fn throw(&self, exc: Py<PyAny>) {
        let waker = {
            let mut inner = self.0.lock().unwrap();
            inner.exception = Some(exc);
            inner.waker.take()
        };
        // Wake only after the guard is dropped: a waker that polls the task inline would
        // otherwise re-enter `poll_cancelled` and deadlock on this non-reentrant mutex,
        // and a panicking waker would poison the mutex for every later lock.
        if let Some(waker) = waker {
            waker.wake();
        }
    }
}