1use std::{
2 cell::UnsafeCell,
3 sync::{Mutex, Once},
4 thread::ThreadId,
5};
6
7use crate::{
8 exceptions::{PyBaseException, PyTypeError},
9 ffi,
10 ffi_ptr_ext::FfiPtrExt,
11 types::{PyAnyMethods, PyTraceback, PyType},
12 Bound, Py, PyAny, PyErrArguments, PyTypeInfo, Python,
13};
14
/// Error state that is either lazy (an exception still to be constructed by a closure) or
/// normalized (an instantiated Python exception), with one-time, thread-safe normalization.
pub(crate) struct PyErrState {
    // Completed once `inner` holds `PyErrStateInner::Normalized`. After that point `inner` is
    // never written again, so shared references into it may be handed out.
    normalized: Once,
    // Id of the thread currently running normalization; used to detect re-entrant
    // normalization, since re-entrant `call_once` on the same `Once` is not supported.
    normalizing_thread: Mutex<Option<ThreadId>>,
    // Taken (left as `None`) while normalization runs, then refilled with the normalized
    // state. It stays `None` if normalization panics part-way.
    inner: UnsafeCell<Option<PyErrStateInner>>,
}

// SAFETY: `inner` is only mutated inside `normalized.call_once`, which runs on at most one
// thread at a time, and shared references into it are only produced once the `Once` has
// completed, after which it is never written again. The lazy closure type is required to be
// `Send + Sync` (see `PyErrStateLazyFn`).
unsafe impl Send for PyErrState {}
unsafe impl Sync for PyErrState {}
// NOTE(review): presumably marks the state as safe to carry across `py.detach` on the
// nightly-only auto-trait path — confirm against `crate::marker::Ungil`.
#[cfg(feature = "nightly")]
unsafe impl crate::marker::Ungil for PyErrState {}
30
31impl PyErrState {
32 pub(crate) fn lazy(f: Box<PyErrStateLazyFn>) -> Self {
33 Self::from_inner(PyErrStateInner::Lazy(f))
34 }
35
36 pub(crate) fn lazy_arguments(ptype: Py<PyAny>, args: impl PyErrArguments + 'static) -> Self {
37 Self::from_inner(PyErrStateInner::Lazy(Box::new(move |py| {
38 PyErrStateLazyFnOutput {
39 ptype,
40 pvalue: args.arguments(py),
41 }
42 })))
43 }
44
45 pub(crate) fn normalized(normalized: PyErrStateNormalized) -> Self {
46 let state = Self::from_inner(PyErrStateInner::Normalized(normalized));
47 state.normalized.call_once(|| {});
52 state
53 }
54
55 pub(crate) fn restore(self, py: Python<'_>) {
56 self.inner
57 .into_inner()
58 .expect("PyErr state should never be invalid outside of normalization")
59 .restore(py)
60 }
61
62 fn from_inner(inner: PyErrStateInner) -> Self {
63 Self {
64 normalized: Once::new(),
65 normalizing_thread: Mutex::new(None),
66 inner: UnsafeCell::new(Some(inner)),
67 }
68 }
69
70 #[inline]
71 pub(crate) fn as_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
72 if self.normalized.is_completed() {
73 match unsafe {
74 &*self.inner.get()
76 } {
77 Some(PyErrStateInner::Normalized(n)) => return n,
78 _ => unreachable!(),
79 }
80 }
81
82 self.make_normalized(py)
83 }
84
85 #[cold]
86 fn make_normalized(&self, py: Python<'_>) -> &PyErrStateNormalized {
87 if let Some(thread) = self.normalizing_thread.lock().unwrap().as_ref() {
94 assert!(
95 !(*thread == std::thread::current().id()),
96 "Re-entrant normalization of PyErrState detected"
97 );
98 }
99
100 py.detach(|| {
102 self.normalized.call_once(|| {
103 self.normalizing_thread
104 .lock()
105 .unwrap()
106 .replace(std::thread::current().id());
107
108 let state = unsafe {
110 (*self.inner.get())
111 .take()
112 .expect("Cannot normalize a PyErr while already normalizing it.")
113 };
114
115 let normalized_state =
116 Python::attach(|py| PyErrStateInner::Normalized(state.normalize(py)));
117
118 unsafe {
120 *self.inner.get() = Some(normalized_state);
121 }
122 })
123 });
124
125 match unsafe {
126 &*self.inner.get()
128 } {
129 Some(PyErrStateInner::Normalized(n)) => n,
130 _ => unreachable!(),
131 }
132 }
133}
134
/// An exception that has been instantiated as a Python exception object.
///
/// Before Python 3.12 the type, value and traceback are stored separately (mirroring the
/// `PyErr_Fetch` triple); from 3.12 onwards only the exception value is kept and the type and
/// traceback are read from it on demand.
pub(crate) struct PyErrStateNormalized {
    // Exception type, captured alongside the value (pre-3.12 only).
    #[cfg(not(Py_3_12))]
    ptype: Py<PyType>,
    // The exception instance itself.
    pub pvalue: Py<PyBaseException>,
    // Traceback, if any, captured alongside the value (pre-3.12 only).
    #[cfg(not(Py_3_12))]
    ptraceback: Option<Py<PyTraceback>>,
}
142
143impl PyErrStateNormalized {
144 pub(crate) fn new(pvalue: Bound<'_, PyBaseException>) -> Self {
145 Self {
146 #[cfg(not(Py_3_12))]
147 ptype: pvalue.get_type().into(),
148 #[cfg(not(Py_3_12))]
149 ptraceback: unsafe {
150 ffi::PyException_GetTraceback(pvalue.as_ptr())
151 .assume_owned_or_opt(pvalue.py())
152 .map(|b| b.cast_into_unchecked().unbind())
153 },
154 pvalue: pvalue.into(),
155 }
156 }
157
158 #[cfg(not(Py_3_12))]
159 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
160 self.ptype.bind(py).clone()
161 }
162
163 #[cfg(Py_3_12)]
164 pub(crate) fn ptype<'py>(&self, py: Python<'py>) -> Bound<'py, PyType> {
165 self.pvalue.bind(py).get_type()
166 }
167
168 #[cfg(not(Py_3_12))]
169 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
170 self.ptraceback
171 .as_ref()
172 .map(|traceback| traceback.bind(py).clone())
173 }
174
175 #[cfg(Py_3_12)]
176 pub(crate) fn ptraceback<'py>(&self, py: Python<'py>) -> Option<Bound<'py, PyTraceback>> {
177 unsafe {
178 ffi::PyException_GetTraceback(self.pvalue.as_ptr())
179 .assume_owned_or_opt(py)
180 .map(|b| b.cast_into_unchecked())
181 }
182 }
183
184 pub(crate) fn take(py: Python<'_>) -> Option<PyErrStateNormalized> {
185 #[cfg(Py_3_12)]
186 {
187 unsafe { ffi::PyErr_GetRaisedException().assume_owned_or_opt(py) }.map(|pvalue| {
190 PyErrStateNormalized {
191 pvalue: unsafe { pvalue.cast_into_unchecked() }.unbind(),
193 }
194 })
195 }
196
197 #[cfg(not(Py_3_12))]
198 {
199 let (ptype, pvalue, ptraceback) = unsafe {
200 let mut ptype: *mut ffi::PyObject = std::ptr::null_mut();
201 let mut pvalue: *mut ffi::PyObject = std::ptr::null_mut();
202 let mut ptraceback: *mut ffi::PyObject = std::ptr::null_mut();
203
204 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
205
206 if !ptype.is_null() {
208 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
209 }
210
211 (
214 ptype
215 .assume_owned_or_opt(py)
216 .map(|b| b.cast_into_unchecked()),
217 pvalue
218 .assume_owned_or_opt(py)
219 .map(|b| b.cast_into_unchecked()),
220 ptraceback
221 .assume_owned_or_opt(py)
222 .map(|b| b.cast_into_unchecked()),
223 )
224 };
225
226 ptype.map(|ptype| PyErrStateNormalized {
227 ptype: ptype.unbind(),
228 pvalue: pvalue.expect("normalized exception value missing").unbind(),
229 ptraceback: ptraceback.map(Bound::unbind),
230 })
231 }
232 }
233
234 #[cfg(not(Py_3_12))]
235 unsafe fn from_normalized_ffi_tuple(
236 py: Python<'_>,
237 ptype: *mut ffi::PyObject,
238 pvalue: *mut ffi::PyObject,
239 ptraceback: *mut ffi::PyObject,
240 ) -> Self {
241 PyErrStateNormalized {
242 ptype: unsafe {
243 ptype
244 .assume_owned_or_opt(py)
245 .expect("Exception type missing")
246 .cast_into_unchecked()
247 }
248 .unbind(),
249 pvalue: unsafe {
250 pvalue
251 .assume_owned_or_opt(py)
252 .expect("Exception value missing")
253 .cast_into_unchecked()
254 }
255 .unbind(),
256 ptraceback: unsafe { ptraceback.assume_owned_or_opt(py) }
257 .map(|b| unsafe { b.cast_into_unchecked() }.unbind()),
258 }
259 }
260
261 pub fn clone_ref(&self, py: Python<'_>) -> Self {
262 Self {
263 #[cfg(not(Py_3_12))]
264 ptype: self.ptype.clone_ref(py),
265 pvalue: self.pvalue.clone_ref(py),
266 #[cfg(not(Py_3_12))]
267 ptraceback: self
268 .ptraceback
269 .as_ref()
270 .map(|ptraceback| ptraceback.clone_ref(py)),
271 }
272 }
273}
274
/// What a lazy error closure produces: the exception type and the value to raise it with.
pub(crate) struct PyErrStateLazyFnOutput {
    // Candidate exception type; `raise_lazy` checks it is an exception class and raises a
    // `TypeError` instead if it is not.
    pub(crate) ptype: Py<PyAny>,
    // Value passed to `PyErr_SetObject` together with `ptype`.
    pub(crate) pvalue: Py<PyAny>,
}
279
/// One-shot closure that builds the exception type and value once an interpreter token is
/// available. It must be `Send + Sync`, which `PyErrState`'s `Send`/`Sync` impls rely on.
pub(crate) type PyErrStateLazyFn =
    dyn for<'py> FnOnce(Python<'py>) -> PyErrStateLazyFnOutput + Send + Sync;
282
/// The two representations an error state can be in.
enum PyErrStateInner {
    // Not yet instantiated: the closure builds the type and value on demand.
    Lazy(Box<PyErrStateLazyFn>),
    // Already instantiated as a Python exception object.
    Normalized(PyErrStateNormalized),
}
287
288impl PyErrStateInner {
289 fn normalize(self, py: Python<'_>) -> PyErrStateNormalized {
290 match self {
291 #[cfg(not(Py_3_12))]
292 PyErrStateInner::Lazy(lazy) => {
293 let (ptype, pvalue, ptraceback) = lazy_into_normalized_ffi_tuple(py, lazy);
294 unsafe {
295 PyErrStateNormalized::from_normalized_ffi_tuple(py, ptype, pvalue, ptraceback)
296 }
297 }
298 #[cfg(Py_3_12)]
299 PyErrStateInner::Lazy(lazy) => {
300 raise_lazy(py, lazy);
303 PyErrStateNormalized::take(py)
304 .expect("exception missing after writing to the interpreter")
305 }
306 PyErrStateInner::Normalized(normalized) => normalized,
307 }
308 }
309
310 #[cfg(not(Py_3_12))]
311 fn restore(self, py: Python<'_>) {
312 let (ptype, pvalue, ptraceback) = match self {
313 PyErrStateInner::Lazy(lazy) => lazy_into_normalized_ffi_tuple(py, lazy),
314 PyErrStateInner::Normalized(PyErrStateNormalized {
315 ptype,
316 pvalue,
317 ptraceback,
318 }) => (
319 ptype.into_ptr(),
320 pvalue.into_ptr(),
321 ptraceback.map_or(std::ptr::null_mut(), Py::into_ptr),
322 ),
323 };
324 unsafe { ffi::PyErr_Restore(ptype, pvalue, ptraceback) }
325 }
326
327 #[cfg(Py_3_12)]
328 fn restore(self, py: Python<'_>) {
329 match self {
330 PyErrStateInner::Lazy(lazy) => raise_lazy(py, lazy),
331 PyErrStateInner::Normalized(PyErrStateNormalized { pvalue }) => unsafe {
332 ffi::PyErr_SetRaisedException(pvalue.into_ptr())
333 },
334 }
335 }
336}
337
338#[cfg(not(Py_3_12))]
339fn lazy_into_normalized_ffi_tuple(
340 py: Python<'_>,
341 lazy: Box<PyErrStateLazyFn>,
342) -> (*mut ffi::PyObject, *mut ffi::PyObject, *mut ffi::PyObject) {
343 raise_lazy(py, lazy);
346 let mut ptype = std::ptr::null_mut();
347 let mut pvalue = std::ptr::null_mut();
348 let mut ptraceback = std::ptr::null_mut();
349 unsafe {
350 ffi::PyErr_Fetch(&mut ptype, &mut pvalue, &mut ptraceback);
351 ffi::PyErr_NormalizeException(&mut ptype, &mut pvalue, &mut ptraceback);
352 }
353 (ptype, pvalue, ptraceback)
354}
355
356fn raise_lazy(py: Python<'_>, lazy: Box<PyErrStateLazyFn>) {
364 let PyErrStateLazyFnOutput { ptype, pvalue } = lazy(py);
365 unsafe {
366 if ffi::PyExceptionClass_Check(ptype.as_ptr()) == 0 {
367 ffi::PyErr_SetString(
368 PyTypeError::type_object_raw(py).cast(),
369 c"exceptions must derive from BaseException".as_ptr(),
370 )
371 } else {
372 ffi::PyErr_SetObject(ptype.as_ptr(), pvalue.as_ptr())
373 }
374 }
375}
376
#[cfg(test)]
mod tests {

    use crate::{
        exceptions::PyValueError, sync::PyOnceLock, Py, PyAny, PyErr, PyErrArguments, Python,
    };

    /// Normalizing an error whose lazy arguments ask for that same error's value must fail
    /// with a clear panic rather than hang.
    #[test]
    #[should_panic(expected = "Re-entrant normalization of PyErrState detected")]
    fn test_reentrant_normalization() {
        static ERR: PyOnceLock<PyErr> = PyOnceLock::new();

        struct RecursiveArgs;

        impl PyErrArguments for RecursiveArgs {
            fn arguments(self, py: Python<'_>) -> Py<PyAny> {
                // Requesting the value of the error currently being normalized re-enters
                // normalization on the same thread.
                let err = ERR.get(py).expect("is set just below");
                err.value(py).clone().into()
            }
        }

        Python::attach(|py| {
            let err = PyValueError::new_err(RecursiveArgs);
            ERR.set(py, err).unwrap();
            let _ = ERR.get(py).expect("is set just above").value(py);
        })
    }

    /// Several threads racing to normalize one lazy error, whose argument construction
    /// detaches from the interpreter, must all finish without deadlocking.
    #[test]
    #[cfg(not(target_arch = "wasm32"))] // spawns OS threads
    fn test_no_deadlock_thread_switch() {
        static ERR: PyOnceLock<PyErr> = PyOnceLock::new();

        struct GILSwitchArgs;

        impl PyErrArguments for GILSwitchArgs {
            fn arguments(self, py: Python<'_>) -> Py<PyAny> {
                // Detach while sleeping so other threads can attach and contend on the same
                // normalization.
                py.detach(|| std::thread::sleep(std::time::Duration::from_millis(10)));
                py.None()
            }
        }

        Python::attach(|py| ERR.set(py, PyValueError::new_err(GILSwitchArgs)).unwrap());

        let mut workers = Vec::with_capacity(10);
        for _ in 0..10 {
            workers.push(std::thread::spawn(|| {
                Python::attach(|py| {
                    ERR.get(py).expect("is set just above").value(py);
                });
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        Python::attach(|py| {
            let err = ERR.get(py).expect("is set above");
            assert!(err.is_instance_of::<PyValueError>(py));
        });
    }
}