Skip to main content

pyo3/conversions/std/
array.rs

1use crate::conversion::{FromPyObjectOwned, FromPyObjectSequence, IntoPyObject};
2#[cfg(feature = "experimental-inspect")]
3use crate::inspect::{type_hint_subscript, PyStaticExpr};
4use crate::types::any::PyAnyMethods;
5use crate::types::PySequence;
6use crate::{err::CastError, ffi, FromPyObject, PyAny, PyResult, PyTypeInfo, Python};
7use crate::{exceptions, Borrowed, Bound, PyErr};
8
9impl<'py, T, const N: usize> IntoPyObject<'py> for [T; N]
10where
11    T: IntoPyObject<'py>,
12{
13    type Target = PyAny;
14    type Output = Bound<'py, Self::Target>;
15    type Error = PyErr;
16
17    #[cfg(feature = "experimental-inspect")]
18    const OUTPUT_TYPE: PyStaticExpr = T::SEQUENCE_OUTPUT_TYPE;
19
20    /// Turns [`[u8; N]`](std::array) into [`PyBytes`], all other `T`s will be turned into a [`PyList`]
21    ///
22    /// [`PyBytes`]: crate::types::PyBytes
23    /// [`PyList`]: crate::types::PyList
24    #[inline]
25    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
26        T::owned_sequence_into_pyobject(self, py, crate::conversion::private::Token)
27    }
28}
29
30impl<'a, 'py, T, const N: usize> IntoPyObject<'py> for &'a [T; N]
31where
32    &'a T: IntoPyObject<'py>,
33{
34    type Target = PyAny;
35    type Output = Bound<'py, Self::Target>;
36    type Error = PyErr;
37
38    #[cfg(feature = "experimental-inspect")]
39    const OUTPUT_TYPE: PyStaticExpr = <&[T]>::OUTPUT_TYPE;
40
41    #[inline]
42    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
43        self.as_slice().into_pyobject(py)
44    }
45}
46
47impl<'py, T, const N: usize> FromPyObject<'_, 'py> for [T; N]
48where
49    T: FromPyObjectOwned<'py>,
50{
51    type Error = PyErr;
52
53    #[cfg(feature = "experimental-inspect")]
54    const INPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySequence::TYPE_HINT, T::INPUT_TYPE);
55
56    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> PyResult<Self> {
57        if let Some(extractor) = T::sequence_extractor(obj, crate::conversion::private::Token) {
58            return extractor.to_array();
59        }
60
61        create_array_from_obj(obj)
62    }
63}
64
65fn create_array_from_obj<'py, T, const N: usize>(obj: Borrowed<'_, 'py, PyAny>) -> PyResult<[T; N]>
66where
67    T: FromPyObjectOwned<'py>,
68{
69    // Types that pass `PySequence_Check` usually implement enough of the sequence protocol
70    // to support this function and if not, we will only fail extraction safely.
71    if unsafe { ffi::PySequence_Check(obj.as_ptr()) } == 0 {
72        return Err(CastError::new(obj, PySequence::type_object(obj.py()).into_any()).into());
73    }
74
75    let seq_len = obj.len()?;
76    if seq_len != N {
77        return Err(invalid_sequence_length(N, seq_len));
78    }
79    array_try_from_fn(|idx| {
80        obj.get_item(idx)
81            .and_then(|any| any.extract().map_err(Into::into))
82    })
83}
84
// TODO use std::array::try_from_fn, if that stabilises:
// (https://github.com/rust-lang/rust/issues/89379)
/// Builds `[T; N]` by calling `cb(0)`, `cb(1)`, …, `cb(N - 1)` in order.
///
/// Returns the first `Err` produced by `cb` without calling it again. Elements built
/// before the failure are dropped, both on an `Err` return and if `cb` panics, so
/// nothing is leaked and nothing is dropped twice.
fn array_try_from_fn<E, F, T, const N: usize>(mut cb: F) -> Result<[T; N], E>
where
    F: FnMut(usize) -> Result<T, E>,
{
    // Drops the first `initialized` elements starting at `dst` when it goes out of scope.
    // It is disarmed with `mem::forget` once the whole array has been written.
    struct ArrayGuard<T, const N: usize> {
        dst: *mut T,
        initialized: usize,
    }

    impl<T, const N: usize> Drop for ArrayGuard<T, N> {
        fn drop(&mut self) {
            debug_assert!(self.initialized <= N);
            let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, self.initialized);
            // SAFETY: exactly the first `initialized` slots hold valid `T`s that have not
            // been moved out. The guard is forgotten on success, so this runs at most once.
            unsafe {
                core::ptr::drop_in_place(initialized_part);
            }
        }
    }

    // [MaybeUninit<T>; N] would be "nicer" but is actually difficult to create - there are nightly
    // APIs which would make this easier.
    let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
    // Derive ONE pointer and share it between the guard and the writes below. Taking a
    // second `as_mut_ptr()` would reborrow `array` again and (under Stacked/Tree
    // Borrows) invalidate the guard's copy before its cleanup path uses it.
    // `[T; N]` is laid out as N consecutive `T`s, so this points at element 0.
    let first: *mut T = array.as_mut_ptr().cast::<T>();
    let mut guard: ArrayGuard<T, N> = ArrayGuard {
        dst: first,
        initialized: 0,
    };
    for i in 0..N {
        // The callback runs outside `unsafe`. On `Err` (or panic) the guard drops 0..i.
        let value = cb(i)?;
        // SAFETY: `i < N`, so `first.add(i)` stays inside `array`, and slot `i` is still
        // uninitialised, so writing does not overwrite a live value.
        unsafe { first.add(i).write(value) };
        guard.initialized += 1;
    }
    core::mem::forget(guard);
    // SAFETY: the loop above wrote all N elements.
    Ok(unsafe { array.assume_init() })
}
126
127pub(crate) fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
128    exceptions::PyValueError::new_err(format!(
129        "expected a sequence of length {expected} (got {actual})"
130    ))
131}
132
#[cfg(test)]
mod tests {
    #[cfg(panic = "unwind")]
    use std::{
        panic,
        sync::atomic::{AtomicUsize, Ordering},
    };

    use crate::{
        conversion::IntoPyObject,
        types::{any::PyAnyMethods, PyBytes, PyBytesMethods},
    };
    use crate::{types::PyList, PyResult, Python};

    // A panic inside the callback must drop exactly the elements already built.
    #[test]
    #[cfg(panic = "unwind")]
    fn array_try_from_fn() {
        static DROP_COUNTER: AtomicUsize = AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let _ = catch_unwind_silent(move || {
            let _: Result<[CountDrop; 4], ()> = super::array_try_from_fn(|idx| {
                #[expect(clippy::manual_assert, reason = "testing panic during array creation")]
                if idx == 2 {
                    panic!("peek a boo");
                }
                Ok(CountDrop)
            });
        });
        assert_eq!(DROP_COUNTER.load(Ordering::SeqCst), 2);
    }

    // Non-panicking counterpart of the test above: an `Err` from the callback stops
    // construction immediately and drops exactly the elements already built.
    #[test]
    fn array_try_from_fn_error_drops_prefix() {
        static DROP_COUNTER: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        struct CountDrop;
        impl Drop for CountDrop {
            fn drop(&mut self) {
                DROP_COUNTER.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            }
        }
        let mut calls = 0;
        let result: Result<[CountDrop; 4], usize> = super::array_try_from_fn(|idx| {
            calls += 1;
            if idx == 2 {
                Err(idx)
            } else {
                Ok(CountDrop)
            }
        });
        assert!(matches!(result, Err(2)));
        assert_eq!(calls, 3);
        assert_eq!(DROP_COUNTER.load(std::sync::atomic::Ordering::SeqCst), 2);
    }

    #[test]
    fn array_try_from_fn_success() {
        let squares: Result<[usize; 4], ()> = super::array_try_from_fn(|idx| Ok(idx * idx));
        assert_eq!(squares, Ok([0, 1, 4, 9]));

        let empty: Result<[usize; 0], ()> = super::array_try_from_fn(|_| unreachable!());
        assert_eq!(empty, Ok([]));
    }

    #[test]
    fn test_extract_bytes_to_array() {
        Python::attach(|py| {
            let v: [u8; 33] = py
                .eval(c"b'abcabcabcabcabcabcabcabcabcabcabc'", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(&v, b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_bytes_wrong_length() {
        Python::attach(|py| {
            let v: PyResult<[u8; 3]> = py.eval(c"b'abcdefg'", None, None).unwrap().extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_extract_bytearray_to_array() {
        Python::attach(|py| {
            let v: [u8; 33] = py
                .eval(
                    c"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
                    None,
                    None,
                )
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(&v, b"abcabcabcabcabcabcabcabcabcabcabc");
        })
    }

    #[test]
    fn test_extract_small_bytearray_to_array() {
        Python::attach(|py| {
            let v: [u8; 3] = py
                .eval(c"bytearray(b'abc')", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(&v, b"abc");
        });
    }

    // Exercises the generic `create_array_from_obj` path with non-byte element types.
    #[test]
    fn test_extract_list_and_tuple_to_array() {
        Python::attach(|py| {
            let from_list: [i32; 3] = py
                .eval(c"[1, 2, 3]", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(from_list, [1, 2, 3]);

            let from_tuple: [i32; 3] = py
                .eval(c"(4, 5, 6)", None, None)
                .unwrap()
                .extract()
                .unwrap();
            assert_eq!(from_tuple, [4, 5, 6]);
        });
    }

    #[test]
    fn test_extract_list_wrong_length() {
        Python::attach(|py| {
            let v: PyResult<[i32; 3]> = py.eval(c"[1, 2]", None, None).unwrap().extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 2)"
            );
        });
    }

    #[test]
    fn test_extract_list_with_unconvertible_element() {
        Python::attach(|py| {
            let v: PyResult<[i32; 3]> = py
                .eval(c"[1, 2, 'three']", None, None)
                .unwrap()
                .extract();
            assert!(v.is_err());
        });
    }

    #[test]
    fn test_into_pyobject_array_conversion() {
        Python::attach(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pyobject = array.into_pyobject(py).unwrap();
            let pylist = pyobject.cast::<PyList>().unwrap();
            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    #[test]
    fn test_extract_invalid_sequence_length() {
        Python::attach(|py| {
            let v: PyResult<[u8; 3]> = py
                .eval(c"bytearray(b'abcdefg')", None, None)
                .unwrap()
                .extract();
            assert_eq!(
                v.unwrap_err().to_string(),
                "ValueError: expected a sequence of length 3 (got 7)"
            );
        })
    }

    #[test]
    fn test_intopyobject_array_conversion() {
        Python::attach(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pylist = array
                .into_pyobject(py)
                .unwrap()
                .cast_into::<PyList>()
                .unwrap();

            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
        });
    }

    // Covers the borrowed `&[T; N]` impl, which delegates to `&[T]`.
    #[test]
    fn test_intopyobject_array_ref_conversion() {
        Python::attach(|py| {
            let array: [f32; 4] = [0.0, -16.0, 16.0, 42.0];
            let pylist = (&array)
                .into_pyobject(py)
                .unwrap()
                .cast_into::<PyList>()
                .unwrap();

            assert_eq!(pylist.get_item(0).unwrap().extract::<f32>().unwrap(), 0.0);
            assert_eq!(pylist.get_item(1).unwrap().extract::<f32>().unwrap(), -16.0);
            assert_eq!(pylist.get_item(2).unwrap().extract::<f32>().unwrap(), 16.0);
            assert_eq!(pylist.get_item(3).unwrap().extract::<f32>().unwrap(), 42.0);
            // The borrowed conversion must leave the Rust array usable and unchanged.
            assert_eq!(array, [0.0, -16.0, 16.0, 42.0]);
        });
    }

    #[test]
    fn test_array_intopyobject_impl() {
        Python::attach(|py| {
            let bytes: [u8; 6] = *b"foobar";
            let obj = bytes.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyBytes>());
            let obj = obj.cast_into::<PyBytes>().unwrap();
            assert_eq!(obj.as_bytes(), &bytes);

            let nums: [u16; 4] = [0, 1, 2, 3];
            let obj = nums.into_pyobject(py).unwrap();
            assert!(obj.is_instance_of::<PyList>());
        });
    }

    #[test]
    fn test_extract_non_iterable_to_array() {
        Python::attach(|py| {
            let v = py.eval(c"42", None, None).unwrap();
            v.extract::<i32>().unwrap();
            v.extract::<[i32; 1]>().unwrap_err();
        });
    }

    #[cfg(feature = "macros")]
    #[test]
    fn test_pyclass_intopy_array_conversion() {
        #[crate::pyclass(crate = "crate")]
        struct Foo;

        Python::attach(|py| {
            let array: [Foo; 8] = [Foo, Foo, Foo, Foo, Foo, Foo, Foo, Foo];
            let list = array
                .into_pyobject(py)
                .unwrap()
                .cast_into::<PyList>()
                .unwrap();
            let _bound = list.get_item(4).unwrap().cast::<Foo>().unwrap();
        });
    }

    // https://stackoverflow.com/a/59211505
    #[cfg(panic = "unwind")]
    fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
    where
        F: FnOnce() -> R + panic::UnwindSafe,
    {
        let prev_hook = panic::take_hook();
        panic::set_hook(Box::new(|_| {}));
        let result = panic::catch_unwind(f);
        panic::set_hook(prev_hook);
        result
    }
}