pyo3/conversions/
ordered_float.rs

1#![cfg(feature = "ordered-float")]
2//! Conversions to and from [ordered-float](https://docs.rs/ordered-float) types.
3//! [`NotNan`]`<`[`f32`]`>` and [`NotNan`]`<`[`f64`]`>`.
4//! [`OrderedFloat`]`<`[`f32`]`>` and [`OrderedFloat`]`<`[`f64`]`>`.
5//!
//! This is useful for converting Python's `float` to and from a native Rust type.
7//!
8//! Take care when comparing sorted collections of float types between Python and Rust.
9//! They will likely differ due to the ambiguous sort order of NaNs in Python.
10//
11//!
12//! To use this feature, add to your **`Cargo.toml`**:
13//!
14//! ```toml
15//! [dependencies]
16#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"ordered-float\"] }")]
17//! ordered-float = "5.0.0"
18//! ```
19//!
20//! # Example
21//!
22//! Rust code to create functions that add ordered floats:
23//!
24//! ```rust,no_run
25//! use ordered_float::{NotNan, OrderedFloat};
26//! use pyo3::prelude::*;
27//!
28//! #[pyfunction]
29//! fn add_not_nans(a: NotNan<f64>, b: NotNan<f64>) -> NotNan<f64> {
30//!     a + b
31//! }
32//!
33//! #[pyfunction]
34//! fn add_ordered_floats(a: OrderedFloat<f64>, b: OrderedFloat<f64>) -> OrderedFloat<f64> {
35//!     a + b
36//! }
37//!
38//! #[pymodule]
39//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
40//!     m.add_function(wrap_pyfunction!(add_not_nans, m)?)?;
41//!     m.add_function(wrap_pyfunction!(add_ordered_floats, m)?)?;
42//!     Ok(())
43//! }
44//! ```
45//!
46//! Python code that validates the functionality:
47//! ```python
48//! from my_module import add_not_nans, add_ordered_floats
49//!
//! assert add_not_nans(1.0, 2.0) == 3.0
//! assert add_ordered_floats(1.0, 2.0) == 3.0
52//! ```
53
54use crate::conversion::IntoPyObject;
55use crate::exceptions::PyValueError;
56use crate::types::PyFloat;
57use crate::{Borrowed, Bound, FromPyObject, PyAny, Python};
58use ordered_float::{NotNan, OrderedFloat};
59use std::convert::Infallible;
60
61macro_rules! float_conversions {
62    ($wrapper:ident, $float_type:ty, $constructor:expr) => {
63        impl<'a, 'py> FromPyObject<'a, 'py> for $wrapper<$float_type> {
64            type Error = <$float_type as FromPyObject<'a, 'py>>::Error;
65
66            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
67                let val: $float_type = obj.extract()?;
68                $constructor(val)
69            }
70        }
71
72        impl<'py> IntoPyObject<'py> for $wrapper<$float_type> {
73            type Target = PyFloat;
74            type Output = Bound<'py, Self::Target>;
75            type Error = Infallible;
76
77            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
78                self.into_inner().into_pyobject(py)
79            }
80        }
81
82        impl<'py> IntoPyObject<'py> for &$wrapper<$float_type> {
83            type Target = PyFloat;
84            type Output = Bound<'py, Self::Target>;
85            type Error = Infallible;
86
87            #[inline]
88            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
89                (*self).into_pyobject(py)
90            }
91        }
92    };
93}
94float_conversions!(OrderedFloat, f32, |val| Ok(OrderedFloat(val)));
95float_conversions!(OrderedFloat, f64, |val| Ok(OrderedFloat(val)));
96float_conversions!(NotNan, f32, |val| NotNan::new(val)
97    .map_err(|e| PyValueError::new_err(e.to_string())));
98float_conversions!(NotNan, f64, |val| NotNan::new(val)
99    .map_err(|e| PyValueError::new_err(e.to_string())));
100
#[cfg(test)]
mod test_ordered_float {
    use super::*;
    use crate::types::dict::IntoPyDict;
    use crate::types::PyAnyMethods;
    use std::ffi::CStr;
    use std::ffi::CString;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Runs `script` with `locals` as its local namespace, panicking if building the
    /// dict or executing the script fails (e.g. on a failed Python `assert`).
    fn py_run<'py>(py: Python<'py>, script: &CStr, locals: impl IntoPyDict<'py>) {
        py.run(script, None, Some(&locals.into_py_dict(py).unwrap()))
            .unwrap()
    }

    /// Generates the Rust -> Python -> Rust roundtrip tests for one wrapper/float pairing.
    ///
    /// * `$wrapper` / `$float_type` - the wrapper and primitive under test.
    /// * `$constructor` - builds a `$wrapper<$float_type>` from a primitive.
    /// * `$standard_test` - property test over generated floats (non-wasm targets only).
    /// * `$wasm_test` - single fixed-value replacement for the property test on wasm32.
    /// * `$infinity_test` - roundtrip of positive and negative infinity.
    /// * `$zero_test` - roundtrip of `+0.0` and `-0.0`, including the sign.
    macro_rules! float_roundtrip_tests {
        ($wrapper:ident, $float_type:ty, $constructor:expr, $standard_test:ident, $wasm_test:ident, $infinity_test:ident, $zero_test:ident) => {
            #[cfg(not(target_arch = "wasm32"))]
            proptest! {
                #[test]
                fn $standard_test(inner_f: $float_type) {
                    let f = $constructor(inner_f);

                    Python::attach(|py| {
                        let f_py: Bound<'_, PyFloat> = f.into_pyobject(py).unwrap();

                        py_run(
                            py,
                            &CString::new(format!(
                                "import math\nassert math.isclose(f_py, {})",
                                inner_f as f64 // Always interpret the literal rs float value as f64
                                               // so that it's comparable with the python float
                            ))
                            .unwrap(),
                            [("f_py", &f_py)],
                        );

                        let roundtripped_f: $wrapper<$float_type> = f_py.extract().unwrap();

                        assert_eq!(f, roundtripped_f);
                    })
                }
            }

            #[cfg(target_arch = "wasm32")]
            #[test]
            fn $wasm_test() {
                // Pin the literal to the float type under test, as the zero test does.
                let inner_f: $float_type = 10.0;
                let f = $constructor(inner_f);

                Python::attach(|py| {
                    let f_py: Bound<'_, PyFloat> = f.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        &CString::new(format!(
                            "import math\nassert math.isclose(f_py, {})",
                            inner_f as f64 // Always interpret the literal rs float value as f64
                                           // so that it's comparable with the python float
                        ))
                        .unwrap(),
                        [("f_py", &f_py)],
                    );

                    let roundtripped_f: $wrapper<$float_type> = f_py.extract().unwrap();

                    assert_eq!(f, roundtripped_f);
                })
            }

            #[test]
            fn $infinity_test() {
                let inner_pinf = <$float_type>::INFINITY;
                let pinf = $constructor(inner_pinf);

                let inner_ninf = <$float_type>::NEG_INFINITY;
                let ninf = $constructor(inner_ninf);

                Python::attach(|py| {
                    let pinf_py: Bound<'_, PyFloat> = pinf.into_pyobject(py).unwrap();
                    let ninf_py: Bound<'_, PyFloat> = ninf.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        c"\
                        assert pinf_py == float('inf')\n\
                        assert ninf_py == float('-inf')",
                        [("pinf_py", &pinf_py), ("ninf_py", &ninf_py)],
                    );

                    let roundtripped_pinf: $wrapper<$float_type> = pinf_py.extract().unwrap();
                    let roundtripped_ninf: $wrapper<$float_type> = ninf_py.extract().unwrap();

                    assert_eq!(pinf, roundtripped_pinf);
                    assert_eq!(ninf, roundtripped_ninf);
                })
            }

            #[test]
            fn $zero_test() {
                let inner_pzero: $float_type = 0.0;
                let pzero = $constructor(inner_pzero);

                let inner_nzero: $float_type = -0.0;
                let nzero = $constructor(inner_nzero);

                Python::attach(|py| {
                    let pzero_py: Bound<'_, PyFloat> = pzero.into_pyobject(py).unwrap();
                    let nzero_py: Bound<'_, PyFloat> = nzero.into_pyobject(py).unwrap();

                    // This python script verifies that the values are 0.0 in magnitude
                    // and that the signs are correct (+0.0 vs -0.0)
                    py_run(
                        py,
                        c"\
                        import math\n\
                        assert pzero_py == 0.0\n\
                        assert math.copysign(1.0, pzero_py) > 0.0\n\
                        assert nzero_py == 0.0\n\
                        assert math.copysign(1.0, nzero_py) < 0.0",
                        [("pzero_py", &pzero_py), ("nzero_py", &nzero_py)],
                    );

                    let roundtripped_pzero: $wrapper<$float_type> = pzero_py.extract().unwrap();
                    let roundtripped_nzero: $wrapper<$float_type> = nzero_py.extract().unwrap();

                    // `0.0 == -0.0`, so equality alone cannot catch a lost sign;
                    // the `signum` checks cover that.
                    assert_eq!(pzero, roundtripped_pzero);
                    assert_eq!(roundtripped_pzero.signum(), 1.0);
                    assert_eq!(nzero, roundtripped_nzero);
                    assert_eq!(roundtripped_nzero.signum(), -1.0);
                })
            }
        };
    }
    float_roundtrip_tests!(
        OrderedFloat,
        f32,
        OrderedFloat,
        ordered_float_f32_standard,
        ordered_float_f32_wasm,
        ordered_float_f32_infinity,
        ordered_float_f32_zero
    );
    float_roundtrip_tests!(
        OrderedFloat,
        f64,
        OrderedFloat,
        ordered_float_f64_standard,
        ordered_float_f64_wasm,
        ordered_float_f64_infinity,
        ordered_float_f64_zero
    );
    float_roundtrip_tests!(
        NotNan,
        f32,
        |val| NotNan::new(val).unwrap(),
        not_nan_f32_standard,
        not_nan_f32_wasm,
        not_nan_f32_infinity,
        not_nan_f32_zero
    );
    float_roundtrip_tests!(
        NotNan,
        f64,
        |val| NotNan::new(val).unwrap(),
        not_nan_f64_standard,
        not_nan_f64_wasm,
        not_nan_f64_infinity,
        not_nan_f64_zero
    );

    /// Generates a test that a NaN survives the `OrderedFloat` roundtrip through Python.
    macro_rules! ordered_float_pynan_tests {
        ($test_name:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                let inner_nan: $float_type = <$float_type>::NAN;
                let nan = OrderedFloat(inner_nan);

                Python::attach(|py| {
                    let nan_py: Bound<'_, PyFloat> = nan.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        c"import math\nassert math.isnan(nan_py)",
                        [("nan_py", &nan_py)],
                    );

                    let roundtripped_nan: OrderedFloat<$float_type> = nan_py.extract().unwrap();

                    assert_eq!(nan, roundtripped_nan);
                })
            }
        };
    }
    ordered_float_pynan_tests!(test_ordered_float_pynan_f32, f32);
    ordered_float_pynan_tests!(test_ordered_float_pynan_f64, f64);

    /// Generates a test that extracting a Python NaN as `NotNan` yields an error.
    macro_rules! not_nan_pynan_tests {
        ($test_name:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                Python::attach(|py| {
                    let nan_py = py.eval(c"float('nan')", None, None).unwrap();

                    let nan_rs: Result<NotNan<$float_type>, _> = nan_py.extract();

                    assert!(nan_rs.is_err());
                })
            }
        };
    }
    not_nan_pynan_tests!(test_not_nan_pynan_f32, f32);
    not_nan_pynan_tests!(test_not_nan_pynan_f64, f64);

    /// Generates a test that extracts Python's `sys.float_info.max` as
    /// `$wrapper<$float_type>`.
    ///
    /// The extraction honours `$float_type`, so the `f64` instantiations exercise the
    /// `f64` path instead of repeating the `f32` test.
    macro_rules! py64_rs32 {
        ($test_name:ident, $wrapper:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                Python::attach(|py| {
                    let py_64 = py
                        .import("sys")
                        .unwrap()
                        .getattr("float_info")
                        .unwrap()
                        .getattr("max")
                        .unwrap();
                    let extracted = py_64.extract::<$wrapper<$float_type>>().unwrap();
                    // The python f64 maximum is not representable in a rust f32, where
                    // the narrowing cast overflows to infinity; for f64 the cast is the
                    // identity and the value must come back unchanged.
                    let expected = f64::MAX as $float_type;
                    assert_eq!(extracted.into_inner(), expected);
                })
            }
        };
    }
    py64_rs32!(ordered_float_f32, OrderedFloat, f32);
    py64_rs32!(ordered_float_f64, OrderedFloat, f64);
    py64_rs32!(not_nan_f32, NotNan, f32);
    py64_rs32!(not_nan_f64, NotNan, f64);
}