use pyo3::exceptions::PyStopIteration;
use pyo3::prelude::*;
// Python-visible class holding a single, already-available result.
// The companion `#[pymethods]` block hands that result out exactly once.
// (Plain `//` comments are used on purpose: `///` docs on a `#[pyclass]`
// become the Python-side `__doc__`.)
#[pyclass]
#[derive(Debug)]
pub(crate) struct IterAwaitable {
// `Some(..)` until the stored outcome has been taken, `None` afterwards.
// The inner `PyResult` lets the slot carry either a value or a Python error.
result: Option<PyResult<PyObject>>,
}
#[pymethods]
impl IterAwaitable {
    // Constructor exposed to Python: stores `result` as a successful,
    // not-yet-consumed outcome.
    #[new]
    fn new(result: PyObject) -> Self {
        Self {
            result: Some(Ok(result)),
        }
    }

    // `__await__` hands back the object itself; the object is its own iterator.
    fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
        pyself
    }

    // `__iter__` likewise returns the same object.
    fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
        pyself
    }

    // Takes the stored outcome (leaving the slot empty) and reports it:
    //   * stored value  -> raised as `StopIteration(value)`
    //   * stored error  -> raised as-is
    //   * slot already empty -> returns Python `None`
    fn __next__(&mut self, py: Python<'_>) -> PyResult<PyObject> {
        match self.result.take() {
            Some(Ok(value)) => Err(PyStopIteration::new_err(value)),
            Some(Err(error)) => Err(error),
            None => Ok(py.None()),
        }
    }
}
// Python-visible class holding a single, already-available result, plus a
// boolean attribute that Python code can read and write.
// (Plain `//` comments are used on purpose: `///` docs on a `#[pyclass]`
// become the Python-side `__doc__`.)
#[pyclass]
pub(crate) struct FutureAwaitable {
// Exposed to Python with a getter and setter under the attribute name
// `_asyncio_future_blocking`.
// NOTE(review): presumably mirrors the flag asyncio inspects on future-like
// objects — nothing in this file reads it, confirm against the callers.
#[pyo3(get, set, name = "_asyncio_future_blocking")]
py_block: bool,
// `Some(..)` until the stored outcome has been taken, `None` afterwards.
result: Option<PyResult<PyObject>>,
}
#[pymethods]
impl FutureAwaitable {
    // Constructor exposed to Python: stores `result` as a successful,
    // not-yet-consumed outcome and starts with the blocking flag cleared.
    #[new]
    fn new(result: PyObject) -> Self {
        FutureAwaitable {
            py_block: false,
            result: Some(Ok(result)),
        }
    }

    // `__await__` hands back the object itself; the object is its own iterator.
    fn __await__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
        pyself
    }

    // `__iter__` likewise returns the same object.
    fn __iter__(pyself: PyRef<'_, Self>) -> PyRef<'_, Self> {
        pyself
    }

    // Takes the stored outcome (leaving the slot empty) and reports it:
    //   * stored value  -> raised as `StopIteration(value)`
    //   * stored error  -> raised as-is
    //   * slot already empty -> returns the object itself
    fn __next__(mut pyself: PyRefMut<'_, Self>) -> PyResult<PyRefMut<'_, Self>> {
        // Move the outcome out first so the mutable borrow of `pyself` has
        // ended before `pyself` is returned by value below. A single `take()`
        // plus pattern matching replaces the former check-then-`unwrap()`.
        let outcome = pyself.result.take();
        match outcome {
            Some(Ok(value)) => Err(PyStopIteration::new_err(value)),
            Some(Err(error)) => Err(error),
            None => Ok(pyself),
        }
    }
}
// Module initialiser: registers `IterAwaitable` and then `FutureAwaitable`
// on `m`, stopping at the first registration that fails.
#[pymodule]
pub fn awaitable(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<IterAwaitable>()
        .and_then(|()| m.add_class::<FutureAwaitable>())
}