Skip to main content

pyo3/types/
set.rs

1use crate::types::PyIterator;
2use crate::{
3    err::{self, PyErr, PyResult},
4    ffi_ptr_ext::FfiPtrExt,
5    instance::Bound,
6    py_result_ext::PyResultExt,
7};
8use crate::{ffi, Borrowed, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, Python};
9#[cfg(RustPython)]
10use crate::{
11    sync::PyOnceLock,
12    types::{PyType, PyTypeMethods},
13    Py,
14};
15use core::ptr;
16
/// Represents a Python `set`.
///
/// Values of this type are accessed via PyO3's smart pointers, e.g. as
/// [`Py<PySet>`][crate::Py] or [`Bound<'py, PySet>`][Bound].
///
/// For APIs available on `set` objects, see the [`PySetMethods`] trait which is implemented for
/// [`Bound<'py, PySet>`][Bound].
// `repr(transparent)` gives `PySet` exactly the same memory layout as the wrapped `PyAny`.
#[repr(transparent)]
pub struct PySet(PyAny);
26
// Everywhere except PyPy/GraalPy, `PySet` is associated with the `ffi::PySetObject` layout
// so it can act as a native base type (presumably for `#[pyclass(extends = PySet)]`; the
// macro is defined elsewhere — confirm there).
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject);

// Default targets (not PyPy, GraalPy or RustPython): the type object is the static
// `ffi::PySet_Type`, and type checks go through `ffi::PySet_Check`.
#[cfg(all(not(any(PyPy, GraalPy)), not(RustPython)))]
pyobject_native_type!(
    PySet,
    ffi::PySetObject,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    "builtins",
    "set",
    #checkfunction=ffi::PySet_Check
);

// RustPython: no static type object is used; `builtins.set` is imported once and cached in
// a `PyOnceLock`. The `unwrap` panics if that import fails.
#[cfg(all(not(any(PyPy, GraalPy)), RustPython))]
pyobject_native_type!(
    PySet,
    ffi::PySetObject,
    |py| {
        static TYPE: PyOnceLock<Py<PyType>> = PyOnceLock::new();
        TYPE.import(py, "builtins", "set").unwrap().as_type_ptr()
    },
    "builtins",
    "set",
    #checkfunction=ffi::PySet_Check
);

// PyPy and GraalPy: only the core native-type machinery is registered, without
// associating the `ffi::PySetObject` layout.
#[cfg(any(PyPy, GraalPy))]
pyobject_native_type_core!(
    PySet,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    "builtins",
    "set",
    #checkfunction=ffi::PySet_Check
);
61
62impl PySet {
63    /// Creates a new set with elements from the given slice.
64    ///
65    /// Returns an error if some element is not hashable.
66    #[inline]
67    pub fn new<'py, T>(
68        py: Python<'py>,
69        elements: impl IntoIterator<Item = T>,
70    ) -> PyResult<Bound<'py, PySet>>
71    where
72        T: IntoPyObject<'py>,
73    {
74        let set = Self::empty(py)?;
75        for e in elements {
76            set.add(e)?;
77        }
78        Ok(set)
79    }
80
81    /// Creates a new empty set.
82    pub fn empty(py: Python<'_>) -> PyResult<Bound<'_, PySet>> {
83        unsafe {
84            ffi::PySet_New(ptr::null_mut())
85                .assume_owned_or_err(py)
86                .cast_into_unchecked()
87        }
88    }
89}
90
91/// Implementation of functionality for [`PySet`].
92///
93/// These methods are defined for the `Bound<'py, PySet>` smart pointer, so to use method call
94/// syntax these methods are separated into a trait, because stable Rust does not yet support
95/// `arbitrary_self_types`.
96#[doc(alias = "PySet")]
97pub trait PySetMethods<'py>: crate::sealed::Sealed {
98    /// Removes all elements from the set.
99    fn clear(&self);
100
101    /// Returns the number of items in the set.
102    ///
103    /// This is equivalent to the Python expression `len(self)`.
104    fn len(&self) -> usize;
105
106    /// Checks if set is empty.
107    fn is_empty(&self) -> bool {
108        self.len() == 0
109    }
110
111    /// Determines if the set contains the specified key.
112    ///
113    /// This is equivalent to the Python expression `key in self`.
114    fn contains<K>(&self, key: K) -> PyResult<bool>
115    where
116        K: IntoPyObject<'py>;
117
118    /// Removes the element from the set if it is present.
119    ///
120    /// Returns `true` if the element was present in the set.
121    fn discard<K>(&self, key: K) -> PyResult<bool>
122    where
123        K: IntoPyObject<'py>;
124
125    /// Adds an element to the set.
126    fn add<K>(&self, key: K) -> PyResult<()>
127    where
128        K: IntoPyObject<'py>;
129
130    /// Removes and returns an arbitrary element from the set.
131    fn pop(&self) -> Option<Bound<'py, PyAny>>;
132
133    /// Returns an iterator of values in this set.
134    ///
135    /// # Panics
136    ///
137    /// If PyO3 detects that the set is mutated during iteration, it will panic.
138    fn iter(&self) -> BoundSetIterator<'py>;
139}
140
141impl<'py> PySetMethods<'py> for Bound<'py, PySet> {
142    #[inline]
143    fn clear(&self) {
144        unsafe {
145            ffi::PySet_Clear(self.as_ptr());
146        }
147    }
148
149    #[inline]
150    fn len(&self) -> usize {
151        unsafe { ffi::PySet_Size(self.as_ptr()) as usize }
152    }
153
154    fn contains<K>(&self, key: K) -> PyResult<bool>
155    where
156        K: IntoPyObject<'py>,
157    {
158        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
159            match unsafe { ffi::PySet_Contains(set.as_ptr(), key.as_ptr()) } {
160                1 => Ok(true),
161                0 => Ok(false),
162                _ => Err(PyErr::fetch(set.py())),
163            }
164        }
165
166        let py = self.py();
167        inner(
168            self,
169            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
170        )
171    }
172
173    fn discard<K>(&self, key: K) -> PyResult<bool>
174    where
175        K: IntoPyObject<'py>,
176    {
177        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
178            match unsafe { ffi::PySet_Discard(set.as_ptr(), key.as_ptr()) } {
179                1 => Ok(true),
180                0 => Ok(false),
181                _ => Err(PyErr::fetch(set.py())),
182            }
183        }
184
185        let py = self.py();
186        inner(
187            self,
188            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
189        )
190    }
191
192    fn add<K>(&self, key: K) -> PyResult<()>
193    where
194        K: IntoPyObject<'py>,
195    {
196        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<()> {
197            err::error_on_minusone(set.py(), unsafe {
198                ffi::PySet_Add(set.as_ptr(), key.as_ptr())
199            })
200        }
201
202        let py = self.py();
203        inner(
204            self,
205            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
206        )
207    }
208
209    fn pop(&self) -> Option<Bound<'py, PyAny>> {
210        let element = unsafe { ffi::PySet_Pop(self.as_ptr()).assume_owned_or_err(self.py()) };
211        element.ok()
212    }
213
214    fn iter(&self) -> BoundSetIterator<'py> {
215        BoundSetIterator::new(self.clone())
216    }
217}
218
219impl<'py> IntoIterator for Bound<'py, PySet> {
220    type Item = Bound<'py, PyAny>;
221    type IntoIter = BoundSetIterator<'py>;
222
223    /// Returns an iterator of values in this set.
224    ///
225    /// # Panics
226    ///
227    /// If PyO3 detects that the set is mutated during iteration, it will panic.
228    fn into_iter(self) -> Self::IntoIter {
229        BoundSetIterator::new(self)
230    }
231}
232
233impl<'py> IntoIterator for &Bound<'py, PySet> {
234    type Item = Bound<'py, PyAny>;
235    type IntoIter = BoundSetIterator<'py>;
236
237    /// Returns an iterator of values in this set.
238    ///
239    /// # Panics
240    ///
241    /// If PyO3 detects that the set is mutated during iteration, it will panic.
242    fn into_iter(self) -> Self::IntoIter {
243        self.iter()
244    }
245}
246
247/// PyO3 implementation of an iterator for a Python `set` object.
248pub struct BoundSetIterator<'py>(Bound<'py, PyIterator>);
249
250impl<'py> BoundSetIterator<'py> {
251    pub(super) fn new(set: Bound<'py, PySet>) -> Self {
252        Self(PyIterator::from_object(&set).expect("set should always be iterable"))
253    }
254}
255
256impl<'py> Iterator for BoundSetIterator<'py> {
257    type Item = Bound<'py, super::PyAny>;
258
259    /// Advances the iterator and returns the next value.
260    fn next(&mut self) -> Option<Self::Item> {
261        self.0
262            .next()
263            .map(|result| result.expect("set iteration should be infallible"))
264    }
265
266    fn size_hint(&self) -> (usize, Option<usize>) {
267        let len = ExactSizeIterator::len(self);
268        (len, Some(len))
269    }
270
271    #[inline]
272    fn count(self) -> usize
273    where
274        Self: Sized,
275    {
276        self.len()
277    }
278}
279
280impl ExactSizeIterator for BoundSetIterator<'_> {
281    fn len(&self) -> usize {
282        self.0.size_hint().0
283    }
284}
285
// Unit tests for `PySet` and `PySetMethods`; each test runs with the interpreter attached.
#[cfg(test)]
mod tests {
    use super::PySet;
    use crate::platform::HashSet;
    use crate::{
        conversion::IntoPyObject,
        types::{PyAnyMethods, PySetMethods},
        Python,
    };

    // `PySet::new` accepts an iterable of convertible elements; an element that is presumably
    // converted to a Python list (unhashable) must make construction fail.
    #[test]
    fn test_set_new() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());

            let v = vec![1];
            assert!(PySet::new(py, &[v]).is_err());
        });
    }

    // `PySet::empty` yields a set with no elements.
    #[test]
    fn test_set_empty() {
        Python::attach(|py| {
            let set = PySet::empty(py).unwrap();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    // Converting a Rust `HashSet` (by reference and by value) produces a Python set whose
    // length matches the Rust collection.
    #[test]
    fn test_set_len() {
        Python::attach(|py| {
            let mut v = HashSet::<i32>::new();
            let ob = (&v).into_pyobject(py).unwrap();
            let set = ob.cast::<PySet>().unwrap();
            assert_eq!(0, set.len());
            v.insert(7);
            let ob = v.into_pyobject(py).unwrap();
            let set2 = ob.cast::<PySet>().unwrap();
            assert_eq!(1, set2.len());
        });
    }

    // `clear` removes every element.
    #[test]
    fn test_set_clear() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());
            set.clear();
            assert_eq!(0, set.len());
        });
    }

    // `contains` finds an element that was inserted at construction.
    #[test]
    fn test_set_contains() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(set.contains(1).unwrap());
        });
    }

    // `discard` reports whether the element was present, is a no-op for missing elements,
    // and errors on an unhashable key.
    #[test]
    fn test_set_discard() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(!set.discard(2).unwrap());
            assert_eq!(1, set.len());

            assert!(set.discard(1).unwrap());
            assert_eq!(0, set.len());
            assert!(!set.discard(1).unwrap());

            assert!(set.discard(vec![1, 2]).is_err());
        });
    }

    // Adding an element that is already present succeeds.
    #[test]
    fn test_set_add() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2]).unwrap();
            set.add(1).unwrap(); // Add a duplicated element
            assert!(set.contains(1).unwrap());
        });
    }

    // `pop` returns `Some` while elements remain and `None` once empty, without leaving a
    // pending Python exception (checked by evaluating a Python expression afterwards).
    #[test]
    fn test_set_pop() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let val = set.pop();
            assert!(val.is_some());
            let val2 = set.pop();
            assert!(val2.is_none());
            assert!(py
                .eval(c"print('Exception state should not be set.')", None, None)
                .is_ok());
        });
    }

    // Iterating an owned `Bound<PySet>` yields its elements.
    #[test]
    fn test_set_iter() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in set {
                assert_eq!(1i32, el.extract::<'_, i32>().unwrap());
            }
        });
    }

    // Iterating a borrowed `&Bound<PySet>` yields its elements.
    #[test]
    fn test_set_iter_bound() {
        use crate::types::any::PyAnyMethods;

        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in &set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
            }
        });
    }

    // Growing the set while iterating over it must panic.
    #[test]
    #[should_panic]
    fn test_set_iter_mutation() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for _ in &set {
                let _ = set.add(42);
            }
        });
    }

    // Judging by its name, this test is meant to check that mutation is detected even when
    // the set's length stays the same.
    // NOTE(review): Python `set` objects do not support item deletion, so `del_item`
    // presumably returns an error here (ignored by `let _`), and the set actually grows on
    // each iteration. Confirm whether this still exercises the same-length case.
    #[test]
    #[should_panic]
    fn test_set_iter_mutation_same_len() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for item in &set {
                let item: i32 = item.extract().unwrap();
                let _ = set.del_item(item);
                let _ = set.add(item + 10);
            }
        });
    }

    // The iterator reports an exact remaining length that shrinks as items are consumed.
    #[test]
    fn test_set_iter_size_hint() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let mut iter = set.iter();

            // Exact size
            assert_eq!(iter.len(), 1);
            assert_eq!(iter.size_hint(), (1, Some(1)));
            iter.next();
            assert_eq!(iter.len(), 0);
            assert_eq!(iter.size_hint(), (0, Some(0)));
        });
    }

    // `count` on a fresh iterator equals the number of elements in the set.
    #[test]
    fn test_iter_count() {
        Python::attach(|py| {
            let set = PySet::new(py, vec![1, 2, 3]).unwrap();
            assert_eq!(set.iter().count(), 3);
        })
    }
}