Skip to main content

pyo3/conversions/
ordered_float.rs

1#![cfg(feature = "ordered-float")]
2//! Conversions to and from [ordered-float](https://docs.rs/ordered-float) types.
3//! [`NotNan`]`<`[`f32`]`>` and [`NotNan`]`<`[`f64`]`>`.
4//! [`OrderedFloat`]`<`[`f32`]`>` and [`OrderedFloat`]`<`[`f64`]`>`.
5//!
//! This is useful for converting Python floats to and from native Rust types.
7//!
//! Take care when comparing sorted collections of float types between Python and Rust.
//! They will likely differ: every comparison involving NaN is false in Python, which leaves
//! the position of NaNs in a sorted Python sequence ambiguous.
10//
11//!
12//! To use this feature, add to your **`Cargo.toml`**:
13//!
14//! ```toml
15//! [dependencies]
16#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"),  "\", features = [\"ordered-float\"] }")]
17//! ordered-float = "5.0.0"
18//! ```
19//!
20//! # Example
21//!
22//! Rust code to create functions that add ordered floats:
23//!
24//! ```rust,no_run
25//! use ordered_float::{NotNan, OrderedFloat};
26//! use pyo3::prelude::*;
27//!
28//! #[pyfunction]
29//! fn add_not_nans(a: NotNan<f64>, b: NotNan<f64>) -> NotNan<f64> {
30//!     a + b
31//! }
32//!
33//! #[pyfunction]
34//! fn add_ordered_floats(a: OrderedFloat<f64>, b: OrderedFloat<f64>) -> OrderedFloat<f64> {
35//!     a + b
36//! }
37//!
38//! #[pymodule]
39//! fn my_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
40//!     m.add_function(wrap_pyfunction!(add_not_nans, m)?)?;
41//!     m.add_function(wrap_pyfunction!(add_ordered_floats, m)?)?;
42//!     Ok(())
43//! }
44//! ```
45//!
46//! Python code that validates the functionality:
47//! ```python
48//! from my_module import add_not_nans, add_ordered_floats
49//!
50//! assert add_not_nans(1.0,2.0) == 3.0
51//! assert add_ordered_floats(1.0,2.0) == 3.0
52//! ```
53
54use crate::conversion::IntoPyObject;
55use crate::exceptions::PyValueError;
56#[cfg(feature = "experimental-inspect")]
57use crate::inspect::PyStaticExpr;
58use crate::types::PyFloat;
59use crate::{Borrowed, Bound, FromPyObject, PyAny, Python};
60use ordered_float::{NotNan, OrderedFloat};
61use std::convert::Infallible;
62
// Implements the Python conversions for one `$outer<$inner>` combination
// (for example `NotNan<f64>`).
//
// * `$outer` - the ordered-float wrapper type (`OrderedFloat` or `NotNan`).
// * `$inner` - the primitive float being wrapped (`f32` or `f64`).
// * `$make`  - a callable that turns an extracted `$inner` into
//              `Result<$outer<$inner>, E>`, where `E` is the primitive float's
//              extraction error type. A wrapper that cannot hold every float
//              value rejects such values here.
//
// Conversion to Python unwraps the value and delegates to the primitive float's
// own `IntoPyObject` implementation, so it is infallible.
macro_rules! float_conversions {
    ($outer:ident, $inner:ty, $make:expr) => {
        impl<'py> IntoPyObject<'py> for $outer<$inner> {
            type Target = PyFloat;
            type Output = Bound<'py, Self::Target>;
            type Error = Infallible;

            #[cfg(feature = "experimental-inspect")]
            const OUTPUT_TYPE: PyStaticExpr = <$inner as IntoPyObject<'py>>::OUTPUT_TYPE;

            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                // Strip the wrapper; the primitive float knows how to become a `PyFloat`.
                let raw: $inner = self.into_inner();
                raw.into_pyobject(py)
            }
        }

        impl<'py> IntoPyObject<'py> for &$outer<$inner> {
            type Target = PyFloat;
            type Output = Bound<'py, Self::Target>;
            type Error = Infallible;

            #[cfg(feature = "experimental-inspect")]
            const OUTPUT_TYPE: PyStaticExpr = <$outer<$inner> as IntoPyObject<'py>>::OUTPUT_TYPE;

            #[inline]
            fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
                // The wrapper is `Copy`, so copy it out and reuse the owned conversion.
                let owned: $outer<$inner> = *self;
                owned.into_pyobject(py)
            }
        }

        impl<'a, 'py> FromPyObject<'a, 'py> for $outer<$inner> {
            type Error = <$inner as FromPyObject<'a, 'py>>::Error;

            #[cfg(feature = "experimental-inspect")]
            const INPUT_TYPE: PyStaticExpr = <$inner as FromPyObject<'a, 'py>>::INPUT_TYPE;

            fn extract(obj: Borrowed<'a, 'py, PyAny>) -> Result<Self, Self::Error> {
                // Extract as the primitive float first, then let `$make` validate/wrap it.
                let raw: $inner = obj.extract()?;
                $make(raw)
            }
        }
    };
}
105float_conversions!(OrderedFloat, f32, |val| Ok(OrderedFloat(val)));
106float_conversions!(OrderedFloat, f64, |val| Ok(OrderedFloat(val)));
107float_conversions!(NotNan, f32, |val| NotNan::new(val)
108    .map_err(|e| PyValueError::new_err(e.to_string())));
109float_conversions!(NotNan, f64, |val| NotNan::new(val)
110    .map_err(|e| PyValueError::new_err(e.to_string())));
111
#[cfg(test)]
mod test_ordered_float {
    use super::*;
    use crate::types::dict::IntoPyDict;
    use crate::types::PyAnyMethods;
    use std::ffi::CStr;
    use std::ffi::CString;

    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Runs `script` with `locals` bound as its local variables, panicking if it raises.
    fn py_run<'py>(py: Python<'py>, script: &CStr, locals: impl IntoPyDict<'py>) {
        py.run(script, None, Some(&locals.into_py_dict(py).unwrap()))
            .unwrap()
    }

    // Generates Rust -> Python -> Rust round-trip tests for `$wrapper<$float_type>`,
    // where `$constructor` builds the wrapper from a raw float:
    // * `$standard_test`: arbitrary values via proptest (not built on wasm32),
    // * `$wasm_test`: one fixed value (wasm32 only, where proptest is not used),
    // * `$infinity_test`: positive and negative infinity,
    // * `$zero_test`: +0.0 and -0.0, including a check that the sign survives.
    macro_rules! float_roundtrip_tests {
        (
            $wrapper:ident,
            $float_type:ty,
            $constructor:expr,
            $standard_test:ident,
            $wasm_test:ident,
            $infinity_test:ident,
            $zero_test:ident
        ) => {
            #[cfg(not(target_arch = "wasm32"))]
            proptest! {
                #[test]
                fn $standard_test(inner_f: $float_type) {
                    let f = $constructor(inner_f);

                    Python::attach(|py| {
                        let f_py: Bound<'_, PyFloat> = f.into_pyobject(py).unwrap();

                        py_run(py, &CString::new(format!(
                                "import math\nassert math.isclose(f_py, {})",
                                inner_f as f64 // Always interpret the literal rs float value as f64
                                               // so that it's comparable with the python float
                            )).unwrap(), [("f_py", &f_py)]);

                        let roundtripped_f: $wrapper<$float_type> = f_py.extract().unwrap();

                        assert_eq!(f, roundtripped_f);
                    })
                }
            }

            #[cfg(target_arch = "wasm32")]
            #[test]
            fn $wasm_test() {
                let inner_f = 10.0;
                let f = $constructor(inner_f);

                Python::attach(|py| {
                    let f_py: Bound<'_, PyFloat> = f.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        &CString::new(format!(
                            "import math\nassert math.isclose(f_py, {})",
                            inner_f as f64 // Always interpret the literal rs float value as f64
                                           // so that it's comparable with the python float
                        ))
                        .unwrap(),
                        [("f_py", &f_py)],
                    );

                    let roundtripped_f: $wrapper<$float_type> = f_py.extract().unwrap();

                    assert_eq!(f, roundtripped_f);
                })
            }

            #[test]
            fn $infinity_test() {
                let inner_pinf = <$float_type>::INFINITY;
                let pinf = $constructor(inner_pinf);

                let inner_ninf = <$float_type>::NEG_INFINITY;
                let ninf = $constructor(inner_ninf);

                Python::attach(|py| {
                    let pinf_py: Bound<'_, PyFloat> = pinf.into_pyobject(py).unwrap();
                    let ninf_py: Bound<'_, PyFloat> = ninf.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        c"\
                        assert pinf_py == float('inf')\n\
                        assert ninf_py == float('-inf')",
                        [("pinf_py", &pinf_py), ("ninf_py", &ninf_py)],
                    );

                    let roundtripped_pinf: $wrapper<$float_type> = pinf_py.extract().unwrap();
                    let roundtripped_ninf: $wrapper<$float_type> = ninf_py.extract().unwrap();

                    assert_eq!(pinf, roundtripped_pinf);
                    assert_eq!(ninf, roundtripped_ninf);
                })
            }

            #[test]
            fn $zero_test() {
                let inner_pzero: $float_type = 0.0;
                let pzero = $constructor(inner_pzero);

                let inner_nzero: $float_type = -0.0;
                let nzero = $constructor(inner_nzero);

                Python::attach(|py| {
                    let pzero_py: Bound<'_, PyFloat> = pzero.into_pyobject(py).unwrap();
                    let nzero_py: Bound<'_, PyFloat> = nzero.into_pyobject(py).unwrap();

                    // This python script verifies that the values are 0.0 in magnitude
                    // and that the signs are correct (+0.0 vs -0.0)
                    py_run(
                        py,
                        c"\
                        import math\n\
                        assert pzero_py == 0.0\n\
                        assert math.copysign(1.0, pzero_py) > 0.0\n\
                        assert nzero_py == 0.0\n\
                        assert math.copysign(1.0, nzero_py) < 0.0",
                        [("pzero_py", &pzero_py), ("nzero_py", &nzero_py)],
                    );

                    let roundtripped_pzero: $wrapper<$float_type> = pzero_py.extract().unwrap();
                    let roundtripped_nzero: $wrapper<$float_type> = nzero_py.extract().unwrap();

                    // +0.0 and -0.0 compare equal, so the sign is checked separately.
                    assert_eq!(pzero, roundtripped_pzero);
                    assert_eq!(roundtripped_pzero.signum(), 1.0);
                    assert_eq!(nzero, roundtripped_nzero);
                    assert_eq!(roundtripped_nzero.signum(), -1.0);
                })
            }
        };
    }
    float_roundtrip_tests!(
        OrderedFloat,
        f32,
        OrderedFloat,
        ordered_float_f32_standard,
        ordered_float_f32_wasm,
        ordered_float_f32_infinity,
        ordered_float_f32_zero
    );
    float_roundtrip_tests!(
        OrderedFloat,
        f64,
        OrderedFloat,
        ordered_float_f64_standard,
        ordered_float_f64_wasm,
        ordered_float_f64_infinity,
        ordered_float_f64_zero
    );
    float_roundtrip_tests!(
        NotNan,
        f32,
        |val| NotNan::new(val).unwrap(),
        not_nan_f32_standard,
        not_nan_f32_wasm,
        not_nan_f32_infinity,
        not_nan_f32_zero
    );
    float_roundtrip_tests!(
        NotNan,
        f64,
        |val| NotNan::new(val).unwrap(),
        not_nan_f64_standard,
        not_nan_f64_wasm,
        not_nan_f64_infinity,
        not_nan_f64_zero
    );

    // `OrderedFloat` may hold NaN: it must reach Python as NaN and round-trip to a
    // value that `OrderedFloat`'s equality considers equal to the original NaN.
    macro_rules! ordered_float_pynan_tests {
        ($test_name:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                let inner_nan: $float_type = <$float_type>::NAN;
                let nan = OrderedFloat(inner_nan);

                Python::attach(|py| {
                    let nan_py: Bound<'_, PyFloat> = nan.into_pyobject(py).unwrap();

                    py_run(
                        py,
                        c"import math\nassert math.isnan(nan_py)",
                        [("nan_py", &nan_py)],
                    );

                    let roundtripped_nan: OrderedFloat<$float_type> = nan_py.extract().unwrap();

                    assert_eq!(nan, roundtripped_nan);
                })
            }
        };
    }
    ordered_float_pynan_tests!(test_ordered_float_pynan_f32, f32);
    ordered_float_pynan_tests!(test_ordered_float_pynan_f64, f64);

    // Extracting a Python NaN into `NotNan` must fail with a `ValueError`,
    // which is the exception the conversion maps NaN rejection to.
    macro_rules! not_nan_pynan_tests {
        ($test_name:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                Python::attach(|py| {
                    let nan_py = py.eval(c"float('nan')", None, None).unwrap();

                    let err = nan_py
                        .extract::<NotNan<$float_type>>()
                        .expect_err("a Python NaN must not convert to NotNan");

                    assert!(err.is_instance_of::<PyValueError>(py));
                })
            }
        };
    }
    not_nan_pynan_tests!(test_not_nan_pynan_f32, f32);
    not_nan_pynan_tests!(test_not_nan_pynan_f64, f64);

    // Extracts `sys.float_info.max` (the largest Python float) into
    // `$wrapper<$float_type>`. Python floats are 64-bit, so a narrower target
    // overflows to infinity, while an f64 target must keep the finite value.
    macro_rules! py64_rs32 {
        ($test_name:ident, $wrapper:ident, $float_type:ty) => {
            #[test]
            fn $test_name() {
                Python::attach(|py| {
                    let py_64 = py
                        .import("sys")
                        .unwrap()
                        .getattr("float_info")
                        .unwrap()
                        .getattr("max")
                        .unwrap();
                    let rs = py_64.extract::<$wrapper<$float_type>>().unwrap();

                    let narrower_than_f64 =
                        std::mem::size_of::<$float_type>() < std::mem::size_of::<f64>();
                    assert_eq!(rs.is_infinite(), narrower_than_f64);
                })
            }
        };
    }
    py64_rs32!(ordered_float_f32, OrderedFloat, f32);
    py64_rs32!(ordered_float_f64, OrderedFloat, f64);
    py64_rs32!(not_nan_f32, NotNan, f32);
    py64_rs32!(not_nan_f64, NotNan, f64);
}