Skip to main content

pyo3/conversions/
num_complex.rs

1#![cfg(feature = "num-complex")]
2
3//! Conversions to and from [num-complex](https://docs.rs/num-complex)'s
4//! [`Complex`]`<`[`f32`]`>` and [`Complex`]`<`[`f64`]`>`.
5//!
6//! num-complex's [`Complex`] supports more operations than PyO3's [`PyComplex`]
7//! and can be used with the rest of the Rust ecosystem.
8//!
9//! # Setup
10//!
11//! To use this feature, add this to your **`Cargo.toml`**:
12//!
13//! ```toml
14//! [dependencies]
15//! # change * to the latest versions
16//! num-complex = "*"
17#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"num-complex\"] }")]
18//! ```
19//!
20//! Note that you must use compatible versions of num-complex and PyO3.
21//! The required num-complex version may vary based on the version of PyO3.
22//!
23//! # Examples
24//!
25//! Using [num-complex](https://docs.rs/num-complex) and [nalgebra](https://docs.rs/nalgebra)
26//! to create a pyfunction that calculates the eigenvalues of a 2x2 matrix.
27//! ```ignore
28//! # // not tested because nalgebra isn't supported on msrv
29//! # // please file an issue if it breaks!
30//! use nalgebra::base::{dimension::Const, Matrix};
31//! use num_complex::Complex;
32//! use pyo3::prelude::*;
33//!
34//! type T = Complex<f64>;
35//!
36//! #[pyfunction]
37//! fn get_eigenvalues(m11: T, m12: T, m21: T, m22: T) -> Vec<T> {
38//!     let mat = Matrix::<T, Const<2>, Const<2>, _>::new(m11, m12, m21, m22);
39//!
40//!     match mat.eigenvalues() {
41//!         Some(e) => e.data.as_slice().to_vec(),
42//!         None => vec![],
43//!     }
44//! }
45//!
46//! #[pymodule]
47//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
48//!     m.add_function(wrap_pyfunction!(get_eigenvalues, m)?)?;
49//!     Ok(())
50//! }
51//! # // test
52//! # use assert_approx_eq::assert_approx_eq;
53//! # use nalgebra::ComplexField;
54//! # use pyo3::types::PyComplex;
55//! #
56//! # fn main() -> PyResult<()> {
57//! #     Python::attach(|py| -> PyResult<()> {
58//! #         let module = PyModule::new(py, "my_module")?;
59//! #
60//! #         module.add_function(&wrap_pyfunction!(get_eigenvalues, module)?)?;
61//! #
62//! #         let m11 = PyComplex::from_doubles(py, 0_f64, -1_f64);
63//! #         let m12 = PyComplex::from_doubles(py, 1_f64, 0_f64);
64//! #         let m21 = PyComplex::from_doubles(py, 2_f64, -1_f64);
65//! #         let m22 = PyComplex::from_doubles(py, -1_f64, 0_f64);
66//! #
67//! #         let result = module
68//! #             .getattr("get_eigenvalues")?
69//! #             .call1((m11, m12, m21, m22))?;
70//! #         println!("eigenvalues: {:?}", result);
71//! #
72//! #         let result = result.extract::<Vec<T>>()?;
73//! #         let e0 = result[0];
74//! #         let e1 = result[1];
75//! #
76//! #         assert_approx_eq!(e0, Complex::new(1_f64, -1_f64));
77//! #         assert_approx_eq!(e1, Complex::new(-2_f64, 0_f64));
78//! #
79//! #         Ok(())
80//! #     })
81//! # }
82//! ```
83//!
84//! Python code:
85//! ```python
86//! from my_module import get_eigenvalues
87//!
88//! m11 = complex(0,-1)
89//! m12 = complex(1,0)
90//! m21 = complex(2,-1)
91//! m22 = complex(-1,0)
92//!
93//! result = get_eigenvalues(m11,m12,m21,m22)
94//! assert result == [complex(1,-1), complex(-2,0)]
95//! ```
96#[cfg(feature = "experimental-inspect")]
97use crate::inspect::PyStaticExpr;
98#[cfg(feature = "experimental-inspect")]
99use crate::type_hint_identifier;
100use crate::{
101    ffi, ffi_ptr_ext::FfiPtrExt, types::PyComplex, Borrowed, Bound, FromPyObject, PyAny, PyErr,
102    Python,
103};
104use num_complex::Complex;
105use std::ffi::c_double;
106
107impl PyComplex {
108    /// Creates a new Python `PyComplex` object from `num_complex`'s [`Complex`].
109    pub fn from_complex_bound<F: Into<c_double>>(
110        py: Python<'_>,
111        complex: Complex<F>,
112    ) -> Bound<'_, PyComplex> {
113        unsafe {
114            ffi::PyComplex_FromDoubles(complex.re.into(), complex.im.into())
115                .assume_owned(py)
116                .cast_into_unchecked()
117        }
118    }
119}
120
121macro_rules! complex_conversion {
122    ($float: ty) => {
123        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
124        impl<'py> crate::conversion::IntoPyObject<'py> for Complex<$float> {
125            type Target = PyComplex;
126            type Output = Bound<'py, Self::Target>;
127            type Error = std::convert::Infallible;
128
129            #[cfg(feature = "experimental-inspect")]
130            const OUTPUT_TYPE: PyStaticExpr = type_hint_identifier!("builtins", "complex");
131
132            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
133                unsafe {
134                    Ok(
135                        ffi::PyComplex_FromDoubles(self.re as c_double, self.im as c_double)
136                            .assume_owned(py)
137                            .cast_into_unchecked(),
138                    )
139                }
140            }
141        }
142
143        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
144        impl<'py> crate::conversion::IntoPyObject<'py> for &Complex<$float> {
145            type Target = PyComplex;
146            type Output = Bound<'py, Self::Target>;
147            type Error = std::convert::Infallible;
148
149            #[cfg(feature = "experimental-inspect")]
150            const OUTPUT_TYPE: PyStaticExpr = <Complex<$float>>::OUTPUT_TYPE;
151
152            #[inline]
153            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
154                (*self).into_pyobject(py)
155            }
156        }
157
158        #[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
159        impl FromPyObject<'_, '_> for Complex<$float> {
160            type Error = PyErr;
161
162            #[cfg(feature = "experimental-inspect")]
163            const INPUT_TYPE: PyStaticExpr = type_hint_identifier!("builtins", "complex");
164
165            fn extract(obj: Borrowed<'_, '_, PyAny>) -> Result<Complex<$float>, Self::Error> {
166                #[cfg(not(any(Py_LIMITED_API, PyPy)))]
167                unsafe {
168                    let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
169                    if val.real == -1.0 {
170                        if let Some(err) = PyErr::take(obj.py()) {
171                            return Err(err);
172                        }
173                    }
174                    Ok(Complex::new(val.real as $float, val.imag as $float))
175                }
176
177                #[cfg(any(Py_LIMITED_API, PyPy))]
178                unsafe {
179                    use $crate::types::any::PyAnyMethods;
180                    let complex;
181                    let obj = if obj.is_instance_of::<PyComplex>() {
182                        obj
183                    } else if let Some(method) =
184                        obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
185                    {
186                        complex = method.call0()?;
187                        complex.as_borrowed()
188                    } else {
189                        // `obj` might still implement `__float__` or `__index__`, which will be
190                        // handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
191                        // errors if those methods don't exist / raise exceptions.
192                        obj
193                    };
194                    let ptr = obj.as_ptr();
195                    let real = ffi::PyComplex_RealAsDouble(ptr);
196                    if real == -1.0 {
197                        if let Some(err) = PyErr::take(obj.py()) {
198                            return Err(err);
199                        }
200                    }
201                    let imag = ffi::PyComplex_ImagAsDouble(ptr);
202                    Ok(Complex::new(real as $float, imag as $float))
203                }
204            }
205        }
206    };
207}
208complex_conversion!(f32);
209complex_conversion!(f64);
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::generate_unique_module_name;
    use crate::types::PyAnyMethods as _;
    use crate::types::{complex::PyComplexMethods, PyModule};
    use crate::IntoPyObject;

    /// Compiles `source` into a fresh Python module with a unique name.
    fn load_module<'py>(py: Python<'py>, source: &std::ffi::CStr) -> Bound<'py, PyModule> {
        PyModule::from_code(
            py,
            source,
            c"test.py",
            &generate_unique_module_name("test"),
        )
        .unwrap()
    }

    /// Instantiates `class_name` from `module` without arguments and extracts the
    /// instance as a `Complex<f64>`.
    fn extract_instance(module: &Bound<'_, PyModule>, class_name: &str) -> Complex<f64> {
        let instance = module.getattr(class_name).unwrap().call0().unwrap();
        instance.extract::<Complex<f64>>().unwrap()
    }

    /// `from_complex_bound` preserves both components.
    #[test]
    fn from_complex() {
        Python::attach(|py| {
            let value = Complex::new(3.0, 1.2);
            let py_complex = PyComplex::from_complex_bound(py, value);
            assert_eq!(py_complex.real(), 3.0);
            assert_eq!(py_complex.imag(), 1.2);
        });
    }

    /// A value converted to Python and back compares equal to the original.
    #[test]
    fn to_from_complex() {
        Python::attach(|py| {
            let original = Complex::new(3.0f64, 1.2);
            let py_obj = original.into_pyobject(py).unwrap();
            let round_tripped = py_obj.extract::<Complex<f64>>().unwrap();
            assert_eq!(round_tripped, original);
        });
    }

    /// Extracting from an object that is not number-like fails.
    #[test]
    fn from_complex_err() {
        Python::attach(|py| {
            let not_a_number = vec![1i32].into_pyobject(py).unwrap();
            let outcome = not_a_number.extract::<Complex<f64>>();
            assert!(outcome.is_err());
        });
    }

    /// `__complex__`, `__float__` and `__index__` defined directly on the class are honoured.
    #[test]
    fn from_python_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                cr#"
class A:
    def __complex__(self): return 3.0+1.2j
class B:
    def __float__(self): return 3.0
class C:
    def __index__(self): return 3
                "#,
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            // Before Python 3.8, `__index__` wasn't tried by `float`/`complex`.
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    /// The same magic methods are honoured when they come from a base class.
    #[test]
    fn from_python_inherited_magic() {
        Python::attach(|py| {
            let module = load_module(
                py,
                cr#"
class First: pass
class ComplexMixin:
    def __complex__(self): return 3.0+1.2j
class FloatMixin:
    def __float__(self): return 3.0
class IndexMixin:
    def __index__(self): return 3
class A(First, ComplexMixin): pass
class B(First, FloatMixin): pass
class C(First, IndexMixin): pass
                "#,
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
            assert_eq!(extract_instance(&module, "B"), Complex::new(3.0, 0.0));
            #[cfg(Py_3_8)]
            {
                assert_eq!(extract_instance(&module, "C"), Complex::new(3.0, 0.0));
            }
        })
    }

    #[test]
    fn from_python_noncallable_descriptor_magic() {
        // Functions and lambdas implement the descriptor protocol in a way that makes
        // `type(inst).attr(inst)` equivalent to `inst.attr()` for methods, but this isn't the only
        // way the descriptor protocol might be implemented.
        Python::attach(|py| {
            let module = load_module(
                py,
                cr#"
class A:
    @property
    def __complex__(self):
        return lambda: 3.0+1.2j
                "#,
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }

    #[test]
    fn from_python_nondescriptor_magic() {
        // Magic methods don't need to implement the descriptor protocol, if they're callable.
        Python::attach(|py| {
            let module = load_module(
                py,
                cr#"
class MyComplex:
    def __call__(self): return 3.0+1.2j
class A:
    __complex__ = MyComplex()
                "#,
            );
            assert_eq!(extract_instance(&module, "A"), Complex::new(3.0, 1.2));
        })
    }
}