pyo3/types/set.rs

1use crate::types::PyIterator;
2use crate::{
3    err::{self, PyErr, PyResult},
4    ffi_ptr_ext::FfiPtrExt,
5    instance::Bound,
6    py_result_ext::PyResultExt,
7};
8use crate::{ffi, Borrowed, BoundObject, IntoPyObject, IntoPyObjectExt, PyAny, Python};
9use std::ptr;
10
11/// Represents a Python `set`.
12///
13/// Values of this type are accessed via PyO3's smart pointers, e.g. as
14/// [`Py<PySet>`][crate::Py] or [`Bound<'py, PySet>`][Bound].
15///
16/// For APIs available on `set` objects, see the [`PySetMethods`] trait which is implemented for
17/// [`Bound<'py, PySet>`][Bound].
18#[repr(transparent)]
19pub struct PySet(PyAny);
20
21#[cfg(not(any(PyPy, GraalPy)))]
22pyobject_subclassable_native_type!(PySet, crate::ffi::PySetObject);
23
24#[cfg(not(any(PyPy, GraalPy)))]
25pyobject_native_type!(
26    PySet,
27    ffi::PySetObject,
28    pyobject_native_static_type_object!(ffi::PySet_Type),
29    "builtins",
30    "set",
31    #checkfunction=ffi::PySet_Check
32);
33
34#[cfg(any(PyPy, GraalPy))]
35pyobject_native_type_core!(
36    PySet,
37    pyobject_native_static_type_object!(ffi::PySet_Type),
38    "builtins",
39    "set",
40    #checkfunction=ffi::PySet_Check
41);
42
43impl PySet {
44    /// Creates a new set with elements from the given slice.
45    ///
46    /// Returns an error if some element is not hashable.
47    #[inline]
48    pub fn new<'py, T>(
49        py: Python<'py>,
50        elements: impl IntoIterator<Item = T>,
51    ) -> PyResult<Bound<'py, PySet>>
52    where
53        T: IntoPyObject<'py>,
54    {
55        try_new_from_iter(py, elements)
56    }
57
58    /// Creates a new empty set.
59    pub fn empty(py: Python<'_>) -> PyResult<Bound<'_, PySet>> {
60        unsafe {
61            ffi::PySet_New(ptr::null_mut())
62                .assume_owned_or_err(py)
63                .cast_into_unchecked()
64        }
65    }
66}
67
68/// Implementation of functionality for [`PySet`].
69///
70/// These methods are defined for the `Bound<'py, PySet>` smart pointer, so to use method call
71/// syntax these methods are separated into a trait, because stable Rust does not yet support
72/// `arbitrary_self_types`.
73#[doc(alias = "PySet")]
74pub trait PySetMethods<'py>: crate::sealed::Sealed {
75    /// Removes all elements from the set.
76    fn clear(&self);
77
78    /// Returns the number of items in the set.
79    ///
80    /// This is equivalent to the Python expression `len(self)`.
81    fn len(&self) -> usize;
82
83    /// Checks if set is empty.
84    fn is_empty(&self) -> bool {
85        self.len() == 0
86    }
87
88    /// Determines if the set contains the specified key.
89    ///
90    /// This is equivalent to the Python expression `key in self`.
91    fn contains<K>(&self, key: K) -> PyResult<bool>
92    where
93        K: IntoPyObject<'py>;
94
95    /// Removes the element from the set if it is present.
96    ///
97    /// Returns `true` if the element was present in the set.
98    fn discard<K>(&self, key: K) -> PyResult<bool>
99    where
100        K: IntoPyObject<'py>;
101
102    /// Adds an element to the set.
103    fn add<K>(&self, key: K) -> PyResult<()>
104    where
105        K: IntoPyObject<'py>;
106
107    /// Removes and returns an arbitrary element from the set.
108    fn pop(&self) -> Option<Bound<'py, PyAny>>;
109
110    /// Returns an iterator of values in this set.
111    ///
112    /// # Panics
113    ///
114    /// If PyO3 detects that the set is mutated during iteration, it will panic.
115    fn iter(&self) -> BoundSetIterator<'py>;
116}
117
118impl<'py> PySetMethods<'py> for Bound<'py, PySet> {
119    #[inline]
120    fn clear(&self) {
121        unsafe {
122            ffi::PySet_Clear(self.as_ptr());
123        }
124    }
125
126    #[inline]
127    fn len(&self) -> usize {
128        unsafe { ffi::PySet_Size(self.as_ptr()) as usize }
129    }
130
131    fn contains<K>(&self, key: K) -> PyResult<bool>
132    where
133        K: IntoPyObject<'py>,
134    {
135        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
136            match unsafe { ffi::PySet_Contains(set.as_ptr(), key.as_ptr()) } {
137                1 => Ok(true),
138                0 => Ok(false),
139                _ => Err(PyErr::fetch(set.py())),
140            }
141        }
142
143        let py = self.py();
144        inner(
145            self,
146            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
147        )
148    }
149
150    fn discard<K>(&self, key: K) -> PyResult<bool>
151    where
152        K: IntoPyObject<'py>,
153    {
154        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<bool> {
155            match unsafe { ffi::PySet_Discard(set.as_ptr(), key.as_ptr()) } {
156                1 => Ok(true),
157                0 => Ok(false),
158                _ => Err(PyErr::fetch(set.py())),
159            }
160        }
161
162        let py = self.py();
163        inner(
164            self,
165            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
166        )
167    }
168
169    fn add<K>(&self, key: K) -> PyResult<()>
170    where
171        K: IntoPyObject<'py>,
172    {
173        fn inner(set: &Bound<'_, PySet>, key: Borrowed<'_, '_, PyAny>) -> PyResult<()> {
174            err::error_on_minusone(set.py(), unsafe {
175                ffi::PySet_Add(set.as_ptr(), key.as_ptr())
176            })
177        }
178
179        let py = self.py();
180        inner(
181            self,
182            key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
183        )
184    }
185
186    fn pop(&self) -> Option<Bound<'py, PyAny>> {
187        let element = unsafe { ffi::PySet_Pop(self.as_ptr()).assume_owned_or_err(self.py()) };
188        element.ok()
189    }
190
191    fn iter(&self) -> BoundSetIterator<'py> {
192        BoundSetIterator::new(self.clone())
193    }
194}
195
196impl<'py> IntoIterator for Bound<'py, PySet> {
197    type Item = Bound<'py, PyAny>;
198    type IntoIter = BoundSetIterator<'py>;
199
200    /// Returns an iterator of values in this set.
201    ///
202    /// # Panics
203    ///
204    /// If PyO3 detects that the set is mutated during iteration, it will panic.
205    fn into_iter(self) -> Self::IntoIter {
206        BoundSetIterator::new(self)
207    }
208}
209
210impl<'py> IntoIterator for &Bound<'py, PySet> {
211    type Item = Bound<'py, PyAny>;
212    type IntoIter = BoundSetIterator<'py>;
213
214    /// Returns an iterator of values in this set.
215    ///
216    /// # Panics
217    ///
218    /// If PyO3 detects that the set is mutated during iteration, it will panic.
219    fn into_iter(self) -> Self::IntoIter {
220        self.iter()
221    }
222}
223
224/// PyO3 implementation of an iterator for a Python `set` object.
225pub struct BoundSetIterator<'p> {
226    it: Bound<'p, PyIterator>,
227    // Remaining elements in the set. This is fine to store because
228    // Python will error if the set changes size during iteration.
229    remaining: usize,
230}
231
232impl<'py> BoundSetIterator<'py> {
233    pub(super) fn new(set: Bound<'py, PySet>) -> Self {
234        Self {
235            it: PyIterator::from_object(&set).unwrap(),
236            remaining: set.len(),
237        }
238    }
239}
240
241impl<'py> Iterator for BoundSetIterator<'py> {
242    type Item = Bound<'py, super::PyAny>;
243
244    /// Advances the iterator and returns the next value.
245    fn next(&mut self) -> Option<Self::Item> {
246        self.remaining = self.remaining.saturating_sub(1);
247        self.it.next().map(Result::unwrap)
248    }
249
250    fn size_hint(&self) -> (usize, Option<usize>) {
251        (self.remaining, Some(self.remaining))
252    }
253
254    #[inline]
255    fn count(self) -> usize
256    where
257        Self: Sized,
258    {
259        self.len()
260    }
261}
262
263impl ExactSizeIterator for BoundSetIterator<'_> {
264    fn len(&self) -> usize {
265        self.remaining
266    }
267}
268
269#[inline]
270pub(crate) fn try_new_from_iter<'py, T>(
271    py: Python<'py>,
272    elements: impl IntoIterator<Item = T>,
273) -> PyResult<Bound<'py, PySet>>
274where
275    T: IntoPyObject<'py>,
276{
277    let set = unsafe {
278        // We create the `Bound` pointer because its Drop cleans up the set if
279        // user code errors or panics.
280        ffi::PySet_New(std::ptr::null_mut())
281            .assume_owned_or_err(py)?
282            .cast_into_unchecked()
283    };
284    let ptr = set.as_ptr();
285
286    elements.into_iter().try_for_each(|element| {
287        let obj = element.into_pyobject_or_pyerr(py)?;
288        err::error_on_minusone(py, unsafe { ffi::PySet_Add(ptr, obj.as_ptr()) })
289    })?;
290
291    Ok(set)
292}
293
#[cfg(test)]
mod tests {
    use super::PySet;
    use crate::{
        conversion::IntoPyObject,
        types::{PyAnyMethods, PySetMethods},
        Python,
    };
    use std::collections::HashSet;

    #[test]
    fn test_set_new() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());

            // A `Vec` converts to a Python `list`, which is unhashable, so insertion fails.
            let v = vec![1];
            assert!(PySet::new(py, &[v]).is_err());
        });
    }

    #[test]
    fn test_set_new_from_empty_iter() {
        Python::attach(|py| {
            let set = PySet::new(py, Vec::<i32>::new()).unwrap();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    #[test]
    fn test_set_empty() {
        Python::attach(|py| {
            let set = PySet::empty(py).unwrap();
            assert_eq!(0, set.len());
            assert!(set.is_empty());
        });
    }

    #[test]
    fn test_set_len() {
        Python::attach(|py| {
            let mut v = HashSet::<i32>::new();
            let ob = (&v).into_pyobject(py).unwrap();
            let set = ob.cast::<PySet>().unwrap();
            assert_eq!(0, set.len());
            v.insert(7);
            let ob = v.into_pyobject(py).unwrap();
            let set2 = ob.cast::<PySet>().unwrap();
            assert_eq!(1, set2.len());
        });
    }

    #[test]
    fn test_set_clear() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert_eq!(1, set.len());
            set.clear();
            assert_eq!(0, set.len());
        });
    }

    #[test]
    fn test_set_contains() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(set.contains(1).unwrap());
        });
    }

    #[test]
    fn test_set_discard() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            assert!(!set.discard(2).unwrap());
            assert_eq!(1, set.len());

            assert!(set.discard(1).unwrap());
            assert_eq!(0, set.len());
            assert!(!set.discard(1).unwrap());

            // Unhashable keys surface as errors rather than `false`.
            assert!(set.discard(vec![1, 2]).is_err());
        });
    }

    #[test]
    fn test_set_add() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2]).unwrap();
            set.add(1).unwrap(); // Add a duplicated element
            assert!(set.contains(1).unwrap());
        });
    }

    #[test]
    fn test_set_pop() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let val = set.pop();
            assert!(val.is_some());
            let val2 = set.pop();
            assert!(val2.is_none());
            // Popping from the empty set must not leave a pending Python exception behind.
            assert!(py
                .eval(c"print('Exception state should not be set.')", None, None)
                .is_ok());
        });
    }

    #[test]
    fn test_set_iter() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
            }
        });
    }

    #[test]
    fn test_set_iter_bound() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();

            for el in &set {
                assert_eq!(1i32, el.extract::<i32>().unwrap());
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_set_iter_mutation() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for _ in &set {
                let _ = set.add(42);
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_set_iter_mutation_same_len() {
        Python::attach(|py| {
            let set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();

            for item in &set {
                let item: i32 = item.extract().unwrap();
                // NOTE(review): Python sets do not support item deletion, so `del_item`
                // presumably fails here (its error is ignored) and the panic likely comes
                // from `add` growing the set. Confirm whether this test really exercises a
                // same-length mutation, as its name suggests.
                let _ = set.del_item(item);
                let _ = set.add(item + 10);
            }
        });
    }

    #[test]
    fn test_set_iter_size_hint() {
        Python::attach(|py| {
            let set = PySet::new(py, [1]).unwrap();
            let mut iter = set.iter();

            // Exact size
            assert_eq!(iter.len(), 1);
            assert_eq!(iter.size_hint(), (1, Some(1)));
            iter.next();
            assert_eq!(iter.len(), 0);
            assert_eq!(iter.size_hint(), (0, Some(0)));

            // Exhausted: further calls yield nothing and the count stays at zero.
            assert!(iter.next().is_none());
            assert_eq!(iter.len(), 0);
            assert_eq!(iter.size_hint(), (0, Some(0)));
        });
    }

    #[test]
    fn test_iter_count() {
        Python::attach(|py| {
            let set = PySet::new(py, vec![1, 2, 3]).unwrap();
            assert_eq!(set.iter().count(), 3);
        })
    }
}