pyo3/
coroutine.rs

1//! Python coroutine implementation, used notably when wrapping `async fn`
2//! with `#[pyfunction]`/`#[pymethods]`.
3use std::{
4    future::Future,
5    panic,
6    pin::Pin,
7    sync::Arc,
8    task::{Context, Poll, Waker},
9};
10
11use pyo3_macros::{pyclass, pymethods};
12
13use crate::{
14    coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
15    exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
16    panic::PanicException,
17    types::{string::PyStringMethods, PyIterator, PyString},
18    Bound, Py, PyAny, PyErr, PyResult, Python,
19};
20
21pub(crate) mod cancel;
22mod waker;
23
24pub use cancel::CancelHandle;
25
26const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
27
28/// Python coroutine wrapping a [`Future`].
29#[pyclass(crate = "crate")]
30pub struct Coroutine {
31    name: Option<Py<PyString>>,
32    qualname_prefix: Option<&'static str>,
33    throw_callback: Option<ThrowCallback>,
34    #[expect(clippy::type_complexity)]
35    future: Option<Pin<Box<dyn Future<Output = PyResult<Py<PyAny>>> + Send>>>,
36    waker: Option<Arc<AsyncioWaker>>,
37}
38
// SAFETY: `Coroutine` may be `Sync` even though its boxed future is only `Send`.
// A `&Coroutine` shared between threads can never reach the future: in this file every
// access to `self.future` (`poll`, `close`) goes through a `&mut self` receiver, and the
// only `&self` methods (`__name__`, `__qualname__`) read just `name` and `qualname_prefix`.
// Exclusive (`&mut`) access to a `Send` value is sound from any thread.
// Any future `&self` method that touches `future` would invalidate this argument.
// NOTE(review): this impl also vouches for `throw_callback` (`ThrowCallback`) and `waker`
// (`Arc<AsyncioWaker>`); presumably both are already `Sync` — confirm in `cancel.rs` /
// `waker.rs`.
unsafe impl Sync for Coroutine {}
42
43impl Coroutine {
44    ///  Wrap a future into a Python coroutine.
45    ///
46    /// Coroutine `send` polls the wrapped future, ignoring the value passed
47    /// (should always be `None` anyway).
48    ///
49    /// `Coroutine `throw` drop the wrapped future and reraise the exception passed
50    pub(crate) fn new<'py, F>(
51        name: Option<Bound<'py, PyString>>,
52        qualname_prefix: Option<&'static str>,
53        throw_callback: Option<ThrowCallback>,
54        future: F,
55    ) -> Self
56    where
57        F: Future<Output = Result<Py<PyAny>, PyErr>> + Send + 'static,
58    {
59        Self {
60            name: name.map(Bound::unbind),
61            qualname_prefix,
62            throw_callback,
63            future: Some(Box::pin(future)),
64            waker: None,
65        }
66    }
67
68    fn poll(&mut self, py: Python<'_>, throw: Option<Py<PyAny>>) -> PyResult<Py<PyAny>> {
69        // raise if the coroutine has already been run to completion
70        let future_rs = match self.future {
71            Some(ref mut fut) => fut,
72            None => return Err(PyRuntimeError::new_err(COROUTINE_REUSED_ERROR)),
73        };
74        // reraise thrown exception it
75        match (throw, &self.throw_callback) {
76            (Some(exc), Some(cb)) => cb.throw(exc),
77            (Some(exc), None) => {
78                self.close();
79                return Err(PyErr::from_value(exc.into_bound(py)));
80            }
81            (None, _) => {}
82        }
83        // create a new waker, or try to reset it in place
84        if let Some(waker) = self.waker.as_mut().and_then(Arc::get_mut) {
85            waker.reset();
86        } else {
87            self.waker = Some(Arc::new(AsyncioWaker::new()));
88        }
89        let waker = Waker::from(self.waker.clone().unwrap());
90        // poll the Rust future and forward its results if ready
91        // polling is UnwindSafe because the future is dropped in case of panic
92        let poll = || future_rs.as_mut().poll(&mut Context::from_waker(&waker));
93        match panic::catch_unwind(panic::AssertUnwindSafe(poll)) {
94            Ok(Poll::Ready(res)) => {
95                self.close();
96                return Err(PyStopIteration::new_err((res?,)));
97            }
98            Err(err) => {
99                self.close();
100                return Err(PanicException::from_panic_payload(err));
101            }
102            _ => {}
103        }
104        // otherwise, initialize the waker `asyncio.Future`
105        if let Some(future) = self.waker.as_ref().unwrap().initialize_future(py)? {
106            // `asyncio.Future` must be awaited; fortunately, it implements `__iter__ = __await__`
107            // and will yield itself if its result has not been set in polling above
108            if let Some(future) = PyIterator::from_object(future).unwrap().next() {
109                // future has not been leaked into Python for now, and Rust code can only call
110                // `set_result(None)` in `Wake` implementation, so it's safe to unwrap
111                return Ok(future.unwrap().into());
112            }
113        }
114        // if waker has been waken during future polling, this is roughly equivalent to
115        // `await asyncio.sleep(0)`, so just yield `None`.
116        Ok(py.None())
117    }
118}
119
120#[pymethods(crate = "crate")]
121impl Coroutine {
122    #[getter]
123    fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
124        match &self.name {
125            Some(name) => Ok(name.clone_ref(py)),
126            None => Err(PyAttributeError::new_err("__name__")),
127        }
128    }
129
130    #[getter]
131    fn __qualname__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyString>> {
132        match (&self.name, &self.qualname_prefix) {
133            (Some(name), Some(prefix)) => Ok(PyString::new(
134                py,
135                &format!("{}.{}", prefix, name.bind(py).to_cow()?),
136            )),
137            (Some(name), None) => Ok(name.bind(py).clone()),
138            (None, _) => Err(PyAttributeError::new_err("__qualname__")),
139        }
140    }
141
142    fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
143        self.poll(py, None)
144    }
145
146    fn throw(&mut self, py: Python<'_>, exc: Py<PyAny>) -> PyResult<Py<PyAny>> {
147        self.poll(py, Some(exc))
148    }
149
150    fn close(&mut self) {
151        // the Rust future is dropped, and the field set to `None`
152        // to indicate the coroutine has been run to completion
153        drop(self.future.take());
154    }
155
156    fn __await__(self_: Py<Self>) -> Py<Self> {
157        self_
158    }
159
160    fn __next__(&mut self, py: Python<'_>) -> PyResult<Py<PyAny>> {
161        self.poll(py, None)
162    }
163}