Skip to content

Commit ed31c70

Browse files
committed
Async methods now takes Receiver
1 parent f322771 commit ed31c70

File tree

4 files changed

+90
-19
lines changed

4 files changed

+90
-19
lines changed

pyo3-derive-backend/src/defs.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,18 +156,21 @@ pub const ASYNC: Proto = Proto {
156156
slot_table: "pyo3::ffi::PyAsyncMethods",
157157
set_slot_table: "set_async_methods",
158158
methods: &[
159-
MethodProto::Unary {
159+
MethodProto::UnaryS {
160160
name: "__await__",
161+
arg: "Receiver",
161162
pyres: true,
162163
proto: "pyo3::class::pyasync::PyAsyncAwaitProtocol",
163164
},
164-
MethodProto::Unary {
165+
MethodProto::UnaryS {
165166
name: "__aiter__",
167+
arg: "Receiver",
166168
pyres: true,
167169
proto: "pyo3::class::pyasync::PyAsyncAiterProtocol",
168170
},
169-
MethodProto::Unary {
171+
MethodProto::UnaryS {
170172
name: "__anext__",
173+
arg: "Receiver",
171174
pyres: true,
172175
proto: "pyo3::class::pyasync::PyAsyncAnextProtocol",
173176
},

src/class/macros.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,9 @@ macro_rules! py_unarys_func {
3434
{
3535
$crate::callback_body!(py, {
3636
let slf = py.from_borrowed_ptr::<$crate::PyCell<T>>(slf);
37-
let borrow = <T::Receiver>::try_from_pycell(slf)
38-
.map_err(|e| e.into())?;
37+
let borrow =
38+
<T::Receiver as $crate::derive_utils::TryFromPyCell<_>>::try_from_pycell(slf)
39+
.map_err(|e| e.into())?;
3940

4041
$class::$f(borrow).into()$(.map($conv))?
4142
})

src/class/pyasync.rs

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
//! [PEP-0492](https://www.python.org/dev/peps/pep-0492/)
99
//!
1010
11+
use crate::derive_utils::TryFromPyCell;
1112
use crate::err::PyResult;
1213
use crate::{ffi, PyClass, PyObject};
1314

@@ -16,21 +17,21 @@ use crate::{ffi, PyClass, PyObject};
1617
/// Each method in this trait corresponds to Python async/await implementation.
1718
#[allow(unused_variables)]
1819
pub trait PyAsyncProtocol<'p>: PyClass {
19-
fn __await__(&'p self) -> Self::Result
20+
fn __await__(slf: Self::Receiver) -> Self::Result
2021
where
2122
Self: PyAsyncAwaitProtocol<'p>,
2223
{
2324
unimplemented!()
2425
}
2526

26-
fn __aiter__(&'p self) -> Self::Result
27+
fn __aiter__(slf: Self::Receiver) -> Self::Result
2728
where
2829
Self: PyAsyncAiterProtocol<'p>,
2930
{
3031
unimplemented!()
3132
}
3233

33-
fn __anext__(&'p mut self) -> Self::Result
34+
fn __anext__(slf: Self::Receiver) -> Self::Result
3435
where
3536
Self: PyAsyncAnextProtocol<'p>,
3637
{
@@ -58,16 +59,19 @@ pub trait PyAsyncProtocol<'p>: PyClass {
5859
}
5960

6061
pub trait PyAsyncAwaitProtocol<'p>: PyAsyncProtocol<'p> {
62+
type Receiver: TryFromPyCell<'p, Self>;
6163
type Success: crate::IntoPy<PyObject>;
6264
type Result: Into<PyResult<Self::Success>>;
6365
}
6466

6567
pub trait PyAsyncAiterProtocol<'p>: PyAsyncProtocol<'p> {
68+
type Receiver: TryFromPyCell<'p, Self>;
6669
type Success: crate::IntoPy<PyObject>;
6770
type Result: Into<PyResult<Self::Success>>;
6871
}
6972

7073
pub trait PyAsyncAnextProtocol<'p>: PyAsyncProtocol<'p> {
74+
type Receiver: TryFromPyCell<'p, Self>;
7175
type Success: crate::IntoPy<PyObject>;
7276
type Result: Into<PyResult<Option<Self::Success>>>;
7377
}
@@ -90,13 +94,13 @@ impl ffi::PyAsyncMethods {
9094
where
9195
T: for<'p> PyAsyncAwaitProtocol<'p>,
9296
{
93-
self.am_await = py_unary_func!(PyAsyncAwaitProtocol, T::__await__);
97+
self.am_await = py_unarys_func!(PyAsyncAwaitProtocol, T::__await__);
9498
}
9599
pub fn set_aiter<T>(&mut self)
96100
where
97101
T: for<'p> PyAsyncAiterProtocol<'p>,
98102
{
99-
self.am_aiter = py_unary_func!(PyAsyncAiterProtocol, T::__aiter__);
103+
self.am_aiter = py_unarys_func!(PyAsyncAiterProtocol, T::__aiter__);
100104
}
101105
pub fn set_anext<T>(&mut self)
102106
where
@@ -123,7 +127,9 @@ mod anext {
123127
fn convert(self, py: Python) -> PyResult<*mut ffi::PyObject> {
124128
match self.0 {
125129
Some(val) => Ok(val.into_py(py).into_ptr()),
126-
None => Err(crate::exceptions::StopAsyncIteration::py_err(())),
130+
None => Err(crate::exceptions::StopAsyncIteration::py_err(
131+
"Task Completed",
132+
)),
127133
}
128134
}
129135
}
@@ -133,12 +139,6 @@ mod anext {
133139
where
134140
T: for<'p> PyAsyncAnextProtocol<'p>,
135141
{
136-
py_unary_func!(
137-
PyAsyncAnextProtocol,
138-
T::__anext__,
139-
call_mut,
140-
*mut crate::ffi::PyObject,
141-
IterANextOutput
142-
)
142+
py_unarys_func!(PyAsyncAnextProtocol, T::__anext__, IterANextOutput)
143143
}
144144
}

tests/test_dunder.rs

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use pyo3::class::{
2-
PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol, PySequenceProtocol,
2+
PyAsyncProtocol, PyContextProtocol, PyIterProtocol, PyMappingProtocol, PyObjectProtocol,
3+
PySequenceProtocol,
34
};
45
use pyo3::exceptions::{IndexError, ValueError};
56
use pyo3::prelude::*;
@@ -552,3 +553,69 @@ fn getattr_doesnt_override_member() {
552553
py_assert!(py, inst, "inst.data == 4");
553554
py_assert!(py, inst, "inst.a == 8");
554555
}
556+
557+
/// Wraps a Python future and yield it once.
558+
#[pyclass]
559+
struct OnceFuture {
560+
future: PyObject,
561+
polled: bool,
562+
}
563+
564+
#[pymethods]
565+
impl OnceFuture {
566+
#[new]
567+
fn new(future: PyObject) -> Self {
568+
OnceFuture {
569+
future,
570+
polled: false,
571+
}
572+
}
573+
}
574+
575+
#[pyproto]
576+
impl PyAsyncProtocol for OnceFuture {
577+
fn __await__(slf: PyRef<Self>) -> PyResult<Py<Self>> {
578+
Ok(slf.into())
579+
}
580+
}
581+
582+
#[pyproto]
583+
impl PyIterProtocol for OnceFuture {
584+
fn __iter__(slf: PyRef<Self>) -> PyResult<Py<Self>> {
585+
Ok(slf.into())
586+
}
587+
fn __next__(mut slf: PyRefMut<Self>) -> PyResult<Option<PyObject>> {
588+
if !slf.polled {
589+
slf.polled = true;
590+
Ok(Some(slf.future.clone()))
591+
} else {
592+
Ok(None)
593+
}
594+
}
595+
}
596+
597+
#[test]
598+
fn test_await() {
599+
let gil = Python::acquire_gil();
600+
let py = gil.python();
601+
let once = py.get_type::<OnceFuture>();
602+
let source = pyo3::indoc::indoc!(
603+
r#"
604+
import asyncio
605+
async def main():
606+
res = await Once(await asyncio.sleep(0.1))
607+
return res
608+
loop = asyncio.get_event_loop()
609+
assert loop.run_until_complete(main()) is None
610+
loop.close()
611+
"#
612+
);
613+
let globals = [("Once", once)].into_py_dict(py);
614+
py.run(source, Some(globals), None)
615+
.map_err(|e| {
616+
e.print(py);
617+
py.run("import sys; sys.stderr.flush()", None, None)
618+
.unwrap();
619+
})
620+
.unwrap();
621+
}

0 commit comments

Comments
 (0)