use std::{
future::Future,
panic,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use pyo3_macros::{pyclass, pymethods};
use crate::{
coroutine::{cancel::ThrowCallback, waker::AsyncioWaker},
exceptions::{PyAttributeError, PyRuntimeError, PyStopIteration},
panic::PanicException,
types::{string::PyStringMethods, PyIterator, PyString},
Bound, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyErr, PyObject, PyResult, Python,
};
pub(crate) mod cancel;
mod waker;
pub use cancel::CancelHandle;
/// Message of the `PyRuntimeError` returned by `Coroutine::poll` when the wrapped
/// future is no longer present (it already completed, panicked, or was closed).
const COROUTINE_REUSED_ERROR: &str = "cannot reuse already awaited coroutine";
// Python-visible coroutine object that drives a boxed Rust future.
#[pyclass(crate = "crate")]
pub struct Coroutine {
// Optional Python string returned by the `__name__` getter and used as the
// last component of `__qualname__`.
name: Option<Py<PyString>>,
// Optional prefix joined to `name` with a `.` by the `__qualname__` getter.
qualname_prefix: Option<&'static str>,
// When set, an exception thrown into the coroutine is handed to this callback
// and polling continues; when unset, a thrown exception closes the coroutine
// and is re-raised.
throw_callback: Option<ThrowCallback>,
// The wrapped future. Becomes `None` once the coroutine is closed, has
// completed, or has panicked; polling afterwards is an error.
future: Option<Pin<Box<dyn Future<Output = PyResult<PyObject>> + Send>>>,
// Waker handed to the future on each poll. Created on the first poll and
// reset in place on later polls when no other `Arc` clone is alive.
waker: Option<Arc<AsyncioWaker>>,
}
// SAFETY: the boxed future is only required to be `Send`, not `Sync`, but within
// this file it is only ever reached through `&mut self` methods (`poll`, `close`),
// so a shared `&Coroutine` never touches it concurrently.
// NOTE(review): this relies on every other field being `Sync` as well — confirm
// for `ThrowCallback` and `AsyncioWaker`, which are defined in other modules.
unsafe impl Sync for Coroutine {}
impl Coroutine {
    /// Wraps `future` into a new, not-yet-polled [`Coroutine`].
    ///
    /// * `name` / `qualname_prefix` — stored as-is and later exposed through the
    ///   `__name__` and `__qualname__` getters.
    /// * `throw_callback` — optional callback that receives exceptions thrown into
    ///   the coroutine (see [`Coroutine::poll`]).
    /// * `future` — the Rust future to drive. Its `Ok` value is converted into a
    ///   Python object and its `Err` value into a [`PyErr`] once it completes.
    pub(crate) fn new<'py, F, T, E>(
        name: Option<Bound<'py, PyString>>,
        qualname_prefix: Option<&'static str>,
        throw_callback: Option<ThrowCallback>,
        future: F,
    ) -> Self
    where
        F: Future<Output = Result<T, E>> + Send + 'static,
        T: IntoPyObject<'py>,
        E: Into<PyErr>,
    {
        // Adapt the user future so that it always resolves to `PyResult<PyObject>`.
        let adapted = async move {
            let value = future.await.map_err(Into::into)?;
            // SAFETY: within this file the boxed future is only polled from
            // `Coroutine::poll`, which holds a `Python` token, so the GIL is
            // held whenever execution reaches this point.
            let py = unsafe { Python::assume_gil_acquired() };
            value.into_py_any(py)
        };
        Self {
            name: name.map(Bound::unbind),
            qualname_prefix,
            throw_callback,
            future: Some(Box::pin(adapted)),
            waker: None,
        }
    }

    /// Advances the wrapped future by one step.
    ///
    /// `throw` carries an exception thrown into the coroutine, if any. With a
    /// throw callback installed the exception is forwarded to it and polling
    /// proceeds; without one the coroutine is closed and the exception re-raised.
    ///
    /// Returns `Ok` with the object to yield to the caller while the future is
    /// pending. Returns `Err` with:
    /// * `RuntimeError` if the future is already gone,
    /// * `StopIteration(value)` when the future completes successfully,
    /// * the future's own error when it completes with `Err`,
    /// * `PanicException` if polling the future panicked.
    fn poll(&mut self, py: Python<'_>, throw: Option<PyObject>) -> PyResult<PyObject> {
        // A missing future means the coroutine has already been consumed or closed.
        let future_rs = self
            .future
            .as_mut()
            .ok_or_else(|| PyRuntimeError::new_err(COROUTINE_REUSED_ERROR))?;

        // Deal with an exception thrown into the coroutine.
        if let Some(exc) = throw {
            match &self.throw_callback {
                Some(callback) => callback.throw(exc),
                None => {
                    self.close();
                    return Err(PyErr::from_value(exc.into_bound(py)));
                }
            }
        }

        // Reuse the previous waker when this is the only `Arc` handle to it,
        // otherwise start over with a fresh one.
        if let Some(exclusive) = self.waker.as_mut().and_then(Arc::get_mut) {
            exclusive.reset();
        } else {
            self.waker = Some(Arc::new(AsyncioWaker::new()));
        }
        let waker = Waker::from(Arc::clone(self.waker.as_ref().unwrap()));

        // Poll the future, catching panics. Asserting unwind safety is fine
        // because the future is dropped (via `close`) if the poll panicked.
        let mut cx = Context::from_waker(&waker);
        let polled = panic::catch_unwind(panic::AssertUnwindSafe(|| {
            future_rs.as_mut().poll(&mut cx)
        }));
        match polled {
            Ok(Poll::Pending) => {}
            Ok(Poll::Ready(result)) => {
                self.close();
                return Err(PyStopIteration::new_err((result?,)));
            }
            Err(payload) => {
                self.close();
                return Err(PanicException::from_panic_payload(payload));
            }
        }

        // The future is still pending: ask the waker for the Python object to
        // await on, and yield the first item produced by iterating it.
        // NOTE(review): `initialize_future` lives in the `waker` module; presumably
        // it returns `None` when the waker was already woken during the poll above.
        if let Some(awaitable) = self.waker.as_ref().unwrap().initialize_future(py)? {
            if let Some(yielded) = PyIterator::from_object(awaitable).unwrap().next() {
                return Ok(yielded.unwrap().into());
            }
        }
        // Nothing to wait on: yield `None` so the caller polls again.
        Ok(py.None())
    }
}
#[pymethods(crate = "crate")]
impl Coroutine {
#[getter]
fn __name__(&self, py: Python<'_>) -> PyResult<Py<PyString>> {
match &self.name {
Some(name) => Ok(name.clone_ref(py)),
None => Err(PyAttributeError::new_err("__name__")),
}
}
#[getter]
fn __qualname__<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyString>> {
match (&self.name, &self.qualname_prefix) {
(Some(name), Some(prefix)) => Ok(PyString::new(
py,
&format!("{}.{}", prefix, name.bind(py).to_cow()?),
)),
(Some(name), None) => Ok(name.bind(py).clone()),
(None, _) => Err(PyAttributeError::new_err("__qualname__")),
}
}
fn send(&mut self, py: Python<'_>, _value: &Bound<'_, PyAny>) -> PyResult<PyObject> {
self.poll(py, None)
}
fn throw(&mut self, py: Python<'_>, exc: PyObject) -> PyResult<PyObject> {
self.poll(py, Some(exc))
}
fn close(&mut self) {
drop(self.future.take());
}
fn __await__(self_: Py<Self>) -> Py<Self> {
self_
}
fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
self.poll(py, None)
}
}