pyo3/err/
err_state.rs

1use std::{
2    cell::UnsafeCell,
3    sync::{Mutex, Once},
4    thread::ThreadId,
5};
6
7use crate::{
8    exceptions::{PyBaseException, PyTypeError},
9    ffi,
10    ffi_ptr_ext::FfiPtrExt,
11    types::{PyAnyMethods, PyTraceback, PyType},
12    Bound, Py, PyAny, PyErrArguments, PyTypeInfo, Python,
13};
14
/// Storage for the state of a Python error, which is either still lazy or already normalized.
///
/// The state starts out as whatever `PyErrStateInner` it was constructed from and is converted
/// to the normalized form at most once.
pub(crate) struct PyErrState {
    // Safety: can only hand out references when in the "normalized" state. Will never change
    // after normalization.
    //
    // Completed once `inner` holds the normalized form.
    normalized: Once,
    // Guard against re-entrancy when normalizing the exception state.
    //
    // Holds the id of the thread which runs the normalization, once it has started.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // The current state. Only `None` temporarily, while normalization is in progress.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}
23
// Safety: The inner value is protected by locking to ensure that only the normalized state is
// handed out as a reference.
//
// After construction, `inner` is only written from within `normalized.call_once`, and shared
// references into it are only created after that `Once` has completed.
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
30
31impl PyErrState {
32    pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
33        Self::from_inner(PyErrStateInner::Lazy(f))
34    }
35
36    pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
37        Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
38            PyErrStateLazyFnOutput {
39                ptype,
40                pvalue: args.arguments(py),
41            }
42        })))
43    }
44
45    pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
46        let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
47        // This state is already normalized, by completing the Once immediately we avoid
48        // reaching the `py.detach` in `make_normalized` which is less efficient
49        // and introduces a GIL switch which could deadlock.
50        // See https://github.com/PyO3/pyo3/issues/4764
51        state.normalized.call_once(|| {});
52        state
53    }
54
55    pub(crate) fn restore(self, py: Python<'_>) {
56        self.inner
57            .into_inner()
58            .expect("PyErr state should never be invalid outside of normalization")
59            .restore(py)
60    }
61
62    fn from_inner(inner: PyErrStateInner) -> Self {
63        Self {
64            normalized: Once::new(),
65            normalizing_thread: Mutex::new(None),
66            inner: UnsafeCell::new(Some(inner)),
67        }
68    }
69
70    #[inline]
71    pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
72        if self.normalized.is_completed() {
73            match unsafe {
74                // Safety: self.inner will never be written again once normalized.
75                &*self.inner.get()
76            } {
77                Some(PyErrStateInner::Normalized(n)) => return n,
78                _ => unreachable!(),
79            }
80        }
81
82        self.make_normalized(py)
83    }
84
85    #[cold]
86    fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
87        // This process is safe because:
88        // - Write happens only once, and then never will change again.
89        // - The `Once` ensure that only one thread will do the write.
90
91        // Guard against re-entrant normalization, because `Once` does not provide
92        // re-entrancy guarantees.
93        if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
94            assert!(
95                !(*thread == std::thread::current().id()),
96                "Re-entrant normalization of PyErrState detected"
97            );
98        }
99
100        // avoid deadlock of `.call_once` with the GIL
101        py.detach(|| {
102            self.normalized.call_once(|| {
103                self.normalizing_thread
104                    .lock()
105                    .unwrap()
106                    .replace(std::thread::current().id());
107
108                // Safety: no other thread can access the inner value while we are normalizing it.
109                let state = unsafe {
110                    (*self.inner.get())
111                        .take()
112                        .expect("Cannot normalize a PyErr while already normalizing it.")
113                };
114
115                let normalized_state =
116                    Python::attach(|py| PyErrStateInner::Normalized(state.normalize(py)));
117
118                // Safety: no other thread can access the inner value while we are normalizing it.
119                unsafe {
120                    *self.inner.get() = Some(normalized_state);
121                }
122            })
123        });
124
125        match unsafe {
126            // Safety: self.inner will never be written again once normalized.
127            &*self.inner.get()
128        } {
129            Some(PyErrStateInner::Normalized(n)) => n,
130            _ => unreachable!(),
131        }
132    }
133}
134
/// An error in normalized form, holding the exception instance as `pvalue`.
pub(crate) struct PyErrStateNormalized {
    // Only stored without the `Py_3_12` cfg; with it, the type is read from `pvalue` on demand.
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    // The exception instance.
    pub pvalue: Py<PyBaseException>,
    // Only stored without the `Py_3_12` cfg; with it, the traceback is read from `pvalue` on
    // demand.
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
142
143impl PyErrStateNormalized {
144    pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145        Self {
146            #[cfg(not(Py_3_12))]
147            ptype: pvalue.get_type().into(),
148            #[cfg(not(Py_3_12))]
149            ptraceback: unsafe {
150                ffi::PyException_GetTraceback(pvalue.as_ptr())
151                    .assume_owned_or_opt(pvalue.py())
152                    .map(|b| b.cast_into_unchecked().unbind())
153            },
154            pvalue: pvalue.into(),
155        }
156    }
157
158    #[cfg(not(Py_3_12))]
159    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
160        self.ptype.bind(py).clone()
161    }
162
163    #[cfg(Py_3_12)]
164    pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
165        self.pvalue.bind(py).get_type()
166    }
167
168    #[cfg(not(Py_3_12))]
169    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
170        self.ptraceback
171            .as_ref()
172            .map(|traceback| traceback.bind(py).clone())
173    }
174
175    #[cfg(Py_3_12)]
176    pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
177        unsafe {
178            ffi::PyException_GetTraceback(self.pvalue.as_ptr())
179                .assume_owned_or_opt(py)
180                .map(|b| b.cast_into_unchecked())
181        }
182    }
183
184    pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
185        #[cfg(Py_3_12)]
186        {
187            // Safety: PyErr_GetRaisedException can be called when attached to Python and
188            // returns either NULL or an owned reference.
189            unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
190                PyErrStateNormalized {
191                    // Safety: PyErr_GetRaisedException returns a valid exception type.
192                    pvalue: unsafe { pvalue.cast_into_unchecked() }.unbind(),
193                }
194            })
195        }
196
197        #[cfg(not(Py_3_12))]
198        {
199            let (ptype, pvalue, ptraceback) = unsafe {
200                let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
201                let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
202                let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
203
204                ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
205
206                // Ensure that the exception coming from the interpreter is normalized.
207                if !ptype.is_null() {
208                    ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
209                }
210
211                // Safety: PyErr_NormalizeException will have produced up to three owned
212                // references of the correct types.
213                (
214                    ptype
215                        .assume_owned_or_opt(py)
216                        .map(|b| b.cast_into_unchecked()),
217                    pvalue
218                        .assume_owned_or_opt(py)
219                        .map(|b| b.cast_into_unchecked()),
220                    ptraceback
221                        .assume_owned_or_opt(py)
222                        .map(|b| b.cast_into_unchecked()),
223                )
224            };
225
226            ptype.map(|ptype| PyErrStateNormalized {
227                ptype: ptype.unbind(),
228                pvalue: pvalue.expect("normalized exception value missing").unbind(),
229                ptraceback: ptraceback.map(Bound::unbind),
230            })
231        }
232    }
233
234    #[cfg(not(Py_3_12))]
235    unsafe fn from_normalized_ffi_tuple(
236        py: Python<'_>,
237        ptype: *mut ffi::PyObject,
238        pvalue: *mut ffi::PyObject,
239        ptraceback: *mut ffi::PyObject,
240    ) -> Self {
241        PyErrStateNormalized {
242            ptype: unsafe {
243                ptype
244                    .assume_owned_or_opt(py)
245                    .expect("Exception type missing")
246                    .cast_into_unchecked()
247            }
248            .unbind(),
249            pvalue: unsafe {
250                pvalue
251                    .assume_owned_or_opt(py)
252                    .expect("Exception value missing")
253                    .cast_into_unchecked()
254            }
255            .unbind(),
256            ptraceback: unsafe { ptraceback.assume_owned_or_opt(py) }
257                .map(|b| unsafe { b.cast_into_unchecked() }.unbind()),
258        }
259    }
260
261    pub fn clone_ref(&self, py: Python<'_>) -> Self {
262        Self {
263            #[cfg(not(Py_3_12))]
264            ptype: self.ptype.clone_ref(py),
265            pvalue: self.pvalue.clone_ref(py),
266            #[cfg(not(Py_3_12))]
267            ptraceback: self
268                .ptraceback
269                .as_ref()
270                .map(|ptraceback| ptraceback.clone_ref(py)),
271        }
272    }
273}
274
/// The result of calling a `PyErrStateLazyFn`: the pieces needed to raise the exception.
pub(crate) struct PyErrStateLazyFnOutput {
    // Object to raise as the exception type. `raise_lazy` raises a `TypeError` instead if this
    // is not an exception class.
    pub(crate) ptype: Py<PyAny>,
    // Object passed as the value to `PyErr_SetObject` when raising.
    pub(crate) pvalue: Py<PyAny>,
}
279
/// Deferred constructor for an exception: called with the `Python` token to produce the
/// exception type and value only when they are needed.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
282
/// The two forms in which an error state can be stored.
enum PyErrStateInner {
    // The exception is not created yet; the boxed function produces its type and value.
    Lazy(Box<PyErrStateLazyFn>),
    // The exception already exists in normalized form.
    Normalized(PyErrStateNormalized),
}
287
288impl PyErrStateInner {
289    fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
290        match self {
291            #[cfg(not(Py_3_12))]
292            PyErrStateInner::Lazy(lazy) => {
293                let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
294                unsafe {
295                    PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
296                }
297            }
298            #[cfg(Py_3_12)]
299            PyErrStateInner::Lazy(lazy) => {
300                // To keep the implementation simple, just write the exception into the interpreter,
301                // which will cause it to be normalized
302                raise_lazy(py, lazy);
303                PyErrStateNormalized::take(py)
304                    .expect("exception missing after writing to the interpreter")
305            }
306            PyErrStateInner::Normalized(normalized) => normalized,
307        }
308    }
309
310    #[cfg(not(Py_3_12))]
311    fn restore(self, py: Python<'_>) {
312        let (ptype, pvalue, ptraceback) = match self {
313            PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
314            PyErrStateInner::Normalized(PyErrStateNormalized {
315                ptype,
316                pvalue,
317                ptraceback,
318            }) => (
319                ptype.into_ptr(),
320                pvalue.into_ptr(),
321                ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
322            ),
323        };
324        unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
325    }
326
327    #[cfg(Py_3_12)]
328    fn restore(self, py: Python<'_>) {
329        match self {
330            PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
331            PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
332                ffi::PyErr_SetRaisedException(pvalue.into_ptr())
333            },
334        }
335    }
336}
337
338#[cfg(not(Py_3_12))]
339fn lazy_into_normalized_ffi_tuple(
340    py: Python<'_>,
341    lazy: Box<PyErrStateLazyFn>,
342) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
343    // To be consistent with 3.12 logic, go via raise_lazy, but also then normalize
344    // the resulting exception
345    raise_lazy(py, lazy);
346    let mut ptype = std::ptr::null_mut();
347    let mut pvalue = std::ptr::null_mut();
348    let mut ptraceback = std::ptr::null_mut();
349    unsafe {
350        ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
351        ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
352    }
353    (ptype, pvalue, ptraceback)
354}
355
356/// Raises a "lazy" exception state into the Python interpreter.
357///
358/// In principle this could be split in two; first a function to create an exception
359/// in a normalized state, and then a call to `PyErr_SetRaisedException` to raise it.
360///
361/// This would require either moving some logic from C to Rust, or requesting a new
362/// API in CPython.
363fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
364    let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
365    unsafe {
366        if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
367            ffi::PyErr_SetString(
368                PyTypeError::type_object_raw(py).cast(),
369                c"exceptions must derive from BaseException".as_ptr(),
370            )
371        } else {
372            ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
373        }
374    }
375}
376
#[cfg(test)]
mod tests {

    use crate::{
        exceptions::PyValueError, sync::PyOnceLock, Py, PyAny, PyErr, PyErrArguments, Python,
    };

    /// Normalizing an error whose arguments normalize the very same error must panic.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: PyOnceLock<PyErr> = PyOnceLock::new();

        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> Py<PyAny> {
                let err = ERR.get(py).expect("is set just below");
                // .value(py) triggers normalization
                err.value(py).clone().into()
            }
        }

        Python::attach(|py| {
            ERR.set(py, PyValueError::new_err(RecursiveArgs)).unwrap();
            let err = ERR.get(py).expect("is set just above");
            err.value(py);
        })
    }

    /// Many threads reading an error whose normalization detaches must all finish.
    #[test]
    #[cfg(not(target_arch = "wasm32"))] // We are building wasm Python with pthreads disabled
    fn test_no_deadlock_thread_switch() {
        static ERR: PyOnceLock<PyErr> = PyOnceLock::new();

        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> Py<PyAny> {
                // releasing the GIL potentially allows for other threads to deadlock
                // with the normalization going on here
                let pause = std::time::Duration::from_millis(10);
                py.detach(|| std::thread::sleep(pause));
                py.None()
            }
        }

        Python::attach(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

        // Let many threads attempt to read the normalized value at the same time.
        // All of them are spawned before the first one is joined.
        let workers: Vec<_> = (0..10)
            .map(|_| {
                std::thread::spawn(|| {
                    Python::attach(|py| {
                        let err = ERR.get(py).expect("is set just above");
                        err.value(py);
                    });
                })
            })
            .collect();

        workers
            .into_iter()
            .for_each(|worker| worker.join().unwrap());

        // We should never have deadlocked, and should be able to run
        // this assertion
        Python::attach(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py))
        });
    }
}