Skip to main content

pyo3/types/
set.rs

1use crate::types::PyIterator;
2use crate::{
3    err::{self, PyErr, PyResult},
4    ffi_ptr_ext::FfiPtrExt,
5    instance::Bound,
6    py_result_ext::PyResultExt,
7};
8use crate::{ffi, Borrowed, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, Python};
9use std::ptr;
10
11/// Represents a Python `set`.
12///
13/// Values of this type are accessed via PyO3's smart pointers, e.g. as
14/// [`Py<PySet>`][crate::Py] or [`Bound<'py, PySet>`][Bound].
15///
16/// For APIs available on `set` objects, see the [`PySetMethods`] trait which is implemented for
17/// [`Bound<'py, PySet>`][Bound].
18#[repr(transparent)]
19pub struct PySet(PyAny);
20
// Interpreters other than PyPy and GraalPy: register `PySet` with its concrete object layout
// `ffi::PySetObject`. NOTE(review): the "subclassable" macro presumably enables Rust-defined
// classes to extend `set`; confirm against the `pyobject_subclassable_native_type!` definition.
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject);

// Native-type registration for the same interpreters. It wires up the layout type, the static
// `PySet_Type` type object, the Python module/name (`builtins`.`set`), and `PySet_Check` as the
// type-check function.
#[cfg(not(any(PyPy, GraalPy)))]
pyobject_native_type!(
    PySet,
    ffi::PySetObject,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    "builtins",
    "set",
    #checkfunction=ffi::PySet_Check
);

// PyPy and GraalPy: the same registration, but through the `_core` macro variant and without
// an object-layout type. NOTE(review): presumably because `PySetObject` is not usable on these
// interpreters; confirm.
#[cfg(any(PyPy, GraalPy))]
pyobject_native_type_core!(
    PySet,
    pyobject_native_static_type_object!(ffi::PySet_Type),
    "builtins",
    "set",
    #checkfunction=ffi::PySet_Check
);
42
43impl PySet {
44    /// Creates a new set with elements from the given slice.
45    ///
46    /// Returns an error if some element is not hashable.
47    #[inline]
48    pub fn new<'py, T>(
49        py: Python<'py>,
50        elements: impl IntoIterator<Item = T>,
51    ) -> PyResult<Bound<'py, PySet>>
52    where
53        T: IntoPyObject<'py>,
54    {
55        let set = Self::empty(py)?;
56        for e in elements {
57            set.add(e)?;
58        }
59        Ok(set)
60    }
61
62    /// Creates a new empty set.
63    pub fn empty(py: Python<'_>) -> PyResult<Bound<'_, PySet>> {
64        unsafe {
65            ffi::PySet_New(ptr::null_mut())
66                .assume_owned_or_err(py)
67                .cast_into_unchecked()
68        }
69    }
70}
71
72/// Implementation of functionality for [`PySet`].
73///
74/// These methods are defined for the `Bound<'py, PySet>` smart pointer, so to use method call
75/// syntax these methods are separated into a trait, because stable Rust does not yet support
76/// `arbitrary_self_types`.
77#[doc(alias = "PySet")]
78pub trait PySetMethods<'py>: crate::sealed::Sealed {
79    /// Removes all elements from the set.
80    fn clear(&self);
81
82    /// Returns the number of items in the set.
83    ///
84    /// This is equivalent to the Python expression `len(self)`.
85    fn len(&self) -> usize;
86
87    /// Checks if set is empty.
88    fn is_empty(&self) -> bool {
89        self.len() == 0
90    }
91
92    /// Determines if the set contains the specified key.
93    ///
94    /// This is equivalent to the Python expression `key in self`.
95    fn contains<K>(&self, key: K) -> PyResult<bool>
96    where
97        K: IntoPyObject<'py>;
98
99    /// Removes the element from the set if it is present.
100    ///
101    /// Returns `true` if the element was present in the set.
102    fn discard<K>(&self, key: K) -> PyResult<bool>
103    where
104        K: IntoPyObject<'py>;
105
106    /// Adds an element to the set.
107    fn add<K>(&self, key: K) -> PyResult<()>
108    where
109        K: IntoPyObject<'py>;
110
111    /// Removes and returns an arbitrary element from the set.
112    fn pop(&self) -> Option<Bound<'py, PyAny>>;
113
114    /// Returns an iterator of values in this set.
115    ///
116    /// # Panics
117    ///
118    /// If PyO3 detects that the set is mutated during iteration, it will panic.
119    fn iter(&self) -> BoundSetIterator<'py>;
120}
121
122impl<'py> PySetMethods<'py> for Bound<'py, PySet> {
123    #[inline]
124    fn clear(&self) {
125        unsafe {
126            ffi::PySet_Clear(self.as_ptr());
127        }
128    }
129
130    #[inline]
131    fn len(&self) -> usize {
132        unsafe { ffi::PySet_Size(self.as_ptr()) as usize }
133    }
134
135    fn contains<K>(&self, key: K) -> PyResult<bool>
136    where
137        K: IntoPyObject<'py>,
138    {
139        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
140            match unsafe { ffi::PySet_Contains(set.as_ptr(), key.as_ptr()) } {
141                1 => Ok(true),
142                0 => Ok(false),
143                _ => Err(PyErr::fetch(set.py())),
144            }
145        }
146
147        let py = self.py();
148        inner(
149            self,
150            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
151        )
152    }
153
154    fn discard<K>(&self, key: K) -> PyResult<bool>
155    where
156        K: IntoPyObject<'py>,
157    {
158        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
159            match unsafe { ffi::PySet_Discard(set.as_ptr(), key.as_ptr()) } {
160                1 => Ok(true),
161                0 => Ok(false),
162                _ => Err(PyErr::fetch(set.py())),
163            }
164        }
165
166        let py = self.py();
167        inner(
168            self,
169            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
170        )
171    }
172
173    fn add<K>(&self, key: K) -> PyResult<()>
174    where
175        K: IntoPyObject<'py>,
176    {
177        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<()> {
178            err::error_on_minusone(set.py(), unsafe {
179                ffi::PySet_Add(set.as_ptr(), key.as_ptr())
180            })
181        }
182
183        let py = self.py();
184        inner(
185            self,
186            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
187        )
188    }
189
190    fn pop(&self) -> Option<Bound<'py, PyAny>> {
191        let element = unsafe { ffi::PySet_Pop(self.as_ptr()).assume_owned_or_err(self.py()) };
192        element.ok()
193    }
194
195    fn iter(&self) -> BoundSetIterator<'py> {
196        BoundSetIterator::new(self.clone())
197    }
198}
199
200impl<'py> IntoIterator for Bound<'py, PySet> {
201    type Item = Bound<'py, PyAny>;
202    type IntoIter = BoundSetIterator<'py>;
203
204    /// Returns an iterator of values in this set.
205    ///
206    /// # Panics
207    ///
208    /// If PyO3 detects that the set is mutated during iteration, it will panic.
209    fn into_iter(self) -> Self::IntoIter {
210        BoundSetIterator::new(self)
211    }
212}
213
214impl<'py> IntoIterator for &Bound<'py, PySet> {
215    type Item = Bound<'py, PyAny>;
216    type IntoIter = BoundSetIterator<'py>;
217
218    /// Returns an iterator of values in this set.
219    ///
220    /// # Panics
221    ///
222    /// If PyO3 detects that the set is mutated during iteration, it will panic.
223    fn into_iter(self) -> Self::IntoIter {
224        self.iter()
225    }
226}
227
228/// PyO3 implementation of an iterator for a Python `set` object.
229pub struct BoundSetIterator<'py>(Bound<'py, PyIterator>);
230
231impl<'py> BoundSetIterator<'py> {
232    pub(super) fn new(set: Bound<'py, PySet>) -> Self {
233        Self(PyIterator::from_object(&set).expect("set should always be iterable"))
234    }
235}
236
237impl<'py> Iterator for BoundSetIterator<'py> {
238    type Item = Bound<'py, super::PyAny>;
239
240    /// Advances the iterator and returns the next value.
241    fn next(&mut self) -> Option<Self::Item> {
242        self.0
243            .next()
244            .map(|result| result.expect("set iteration should be infallible"))
245    }
246
247    fn size_hint(&self) -> (usize, Option<usize>) {
248        let len = ExactSizeIterator::len(self);
249        (len, Some(len))
250    }
251
252    #[inline]
253    fn count(self) -> usize
254    where
255        Self: Sized,
256    {
257        self.len()
258    }
259}
260
261impl ExactSizeIterator for BoundSetIterator<'_> {
262    fn len(&self) -> usize {
263        self.0.size_hint().0
264    }
265}
266
// Unit tests for `PySet`, `PySetMethods` and `BoundSetIterator`; each test runs while attached
// to the Python interpreter via `Python::attach`.
#[cfg(test)]
mod tests {
    use super::PySet;
    use crate::{
        conversion::IntoPyObject,
        types::{PyAnyMethods, PySetMethods},
        Python,
    };
    use std::collections::HashSet;

    /// `PySet::new` builds a set from Rust values and fails on an unhashable element.
    #[test]
    fn test_set_new() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());

            // NOTE(review): `&Vec<i32>` presumably converts to a Python `list`, which is
            // unhashable, so construction must fail; confirm.
            let v = vec![1];
            assert!(PySet::new(py, &[v]).is_err());
        });
    }

    /// `PySet::empty` yields a set of length zero.
    #[test]
    fn test_set_empty() {
        Python::attach(|py| {
            let set = PySet::empty(py).unwrap();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    /// Converting a Rust `HashSet` (by reference and by value) produces an object castable to
    /// `PySet` with the matching length.
    #[test]
    fn test_set_len() {
        Python::attach(|py| {
            let mut v = HashSet::<i32>::new();
            let ob = (&v).into_pyobject(py).unwrap();
            let set = ob.cast::<PySet>().unwrap();
            assert_eq!(0, set.len());
            v.insert(7);
            let ob = v.into_pyobject(py).unwrap();
            let set2 = ob.cast::<PySet>().unwrap();
            assert_eq!(1, set2.len());
        });
    }

    /// `clear` empties a non-empty set.
    #[test]
    fn test_set_clear() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());
            set.clear();
            assert_eq!(0, set.len());
        });
    }

    /// `contains` finds an element that was inserted at construction.
    #[test]
    fn test_set_contains() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(set.contains(1).unwrap());
        });
    }

    /// `discard` reports whether the element was present, only shrinks the set when it was,
    /// and errors on an unhashable key.
    #[test]
    fn test_set_discard() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(!set.discard(2).unwrap());
            assert_eq!(1, set.len());

            assert!(set.discard(1).unwrap());
            assert_eq!(0, set.len());
            assert!(!set.discard(1).unwrap());

            assert!(set.discard(vec![1, 2]).is_err());
        });
    }

    /// Adding an element that is already present succeeds.
    #[test]
    fn test_set_add() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2]).unwrap();
            set.add(1).unwrap(); // Add a duplicated element
            assert!(set.contains(1).unwrap());
        });
    }

    /// `pop` returns the only element, then `None` on the emptied set. Running Python code
    /// afterwards checks that the failed pop left no exception pending.
    #[test]
    fn test_set_pop() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let val = set.pop();
            assert!(val.is_some());
            let val2 = set.pop();
            assert!(val2.is_none());
            assert!(py
                .eval(c"print('Exception state should not be set.')", None, None)
                .is_ok());
        });
    }

    /// Iterating an owned `Bound<PySet>` (consuming `IntoIterator`) yields its elements.
    #[test]
    fn test_set_iter() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in set {
                assert_eq!(1i32, el.extract::<'_, i32>().unwrap());
            }
        });
    }

    /// Iterating `&Bound<PySet>` (borrowing `IntoIterator`) yields its elements.
    #[test]
    fn test_set_iter_bound() {
        // NOTE(review): redundant with the module-level `PyAnyMethods` import.
        use crate::types::any::PyAnyMethods;

        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in &set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
            }
        });
    }

    /// Growing the set while iterating it must make the iterator panic.
    #[test]
    #[should_panic]
    fn test_set_iter_mutation() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for _ in &set {
                let _ = set.add(42);
            }
        });
    }

    /// Intended to check that a same-length mutation during iteration panics.
    ///
    /// NOTE(review): `del_item` on a `set` presumably always fails, because sets do not support
    /// item deletion. Its error is ignored, so nothing is removed and the following `add` grows
    /// the set; the panic therefore comes from a size change. A real same-length mutation
    /// (e.g. `discard` then `add`) may not be detected at all, so confirm before "fixing" this.
    #[test]
    #[should_panic]
    fn test_set_iter_mutation_same_len() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for item in &set {
                let item: i32 = item.extract().unwrap();
                let _ = set.del_item(item);
                let _ = set.add(item + 10);
            }
        });
    }

    /// `len` and `size_hint` report the exact remaining count and decrease after `next`.
    #[test]
    fn test_set_iter_size_hint() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let mut iter = set.iter();

            // Exact size
            assert_eq!(iter.len(), 1);
            assert_eq!(iter.size_hint(), (1, Some(1)));
            iter.next();
            assert_eq!(iter.len(), 0);
            assert_eq!(iter.size_hint(), (0, Some(0)));
        });
    }

    /// The overridden `count` matches the number of elements.
    #[test]
    fn test_iter_count() {
        Python::attach(|py| {
            let set = PySet::new(py, vec![1, 2, 3]).unwrap();
            assert_eq!(set.iter().count(), 3);
        })
    }
}