Skip to main content

pyo3/conversions/std/
set.rs

1use std::{cmp, collections, hash};
2
3#[cfg(feature = "experimental-inspect")]
4use crate::inspect::types::TypeInfo;
5#[cfg(feature = "experimental-inspect")]
6use crate::inspect::{type_hint_subscript, PyStaticExpr};
7#[cfg(feature = "experimental-inspect")]
8use crate::type_object::PyTypeInfo;
9use crate::{
10    conversion::{FromPyObjectOwned, IntoPyObject},
11    types::{
12        any::PyAnyMethods, frozenset::PyFrozenSetMethods, set::PySetMethods, PyFrozenSet, PySet,
13    },
14    Borrowed, Bound, FromPyObject, PyAny, PyErr, Python,
15};
16
17impl<'py, K, S> IntoPyObject<'py> for collections::HashSet<K, S>
18where
19    K: IntoPyObject<'py> + Eq + hash::Hash,
20    S: hash::BuildHasher + Default,
21{
22    type Target = PySet;
23    type Output = Bound<'py, Self::Target>;
24    type Error = PyErr;
25
26    #[cfg(feature = "experimental-inspect")]
27    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::OUTPUT_TYPE);
28
29    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
30        PySet::new(py, self)
31    }
32
33    #[cfg(feature = "experimental-inspect")]
34    fn type_output() -> TypeInfo {
35        TypeInfo::set_of(K::type_output())
36    }
37}
38
39impl<'a, 'py, K, H> IntoPyObject<'py> for &'a collections::HashSet<K, H>
40where
41    &'a K: IntoPyObject<'py> + Eq + hash::Hash,
42    H: hash::BuildHasher,
43{
44    type Target = PySet;
45    type Output = Bound<'py, Self::Target>;
46    type Error = PyErr;
47    #[cfg(feature = "experimental-inspect")]
48    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, <&K>::OUTPUT_TYPE);
49
50    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
51        PySet::new(py, self)
52    }
53
54    #[cfg(feature = "experimental-inspect")]
55    fn type_output() -> TypeInfo {
56        TypeInfo::set_of(<&K>::type_output())
57    }
58}
59
60impl<'py, K, S> FromPyObject<'_, 'py> for collections::HashSet<K, S>
61where
62    K: FromPyObjectOwned<'py> + cmp::Eq + hash::Hash,
63    S: hash::BuildHasher + Default,
64{
65    type Error = PyErr;
66
67    #[cfg(feature = "experimental-inspect")]
68    const INPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::INPUT_TYPE);
69
70    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
71        match ob.cast::<PySet>() {
72            Ok(set) => set
73                .iter()
74                .map(|any| any.extract().map_err(Into::into))
75                .collect(),
76            Err(err) => {
77                if let Ok(frozen_set) = ob.cast::<PyFrozenSet>() {
78                    frozen_set
79                        .iter()
80                        .map(|any| any.extract().map_err(Into::into))
81                        .collect()
82                } else {
83                    Err(PyErr::from(err))
84                }
85            }
86        }
87    }
88
89    #[cfg(feature = "experimental-inspect")]
90    fn type_input() -> TypeInfo {
91        TypeInfo::set_of(K::type_input())
92    }
93}
94
95impl<'py, K> IntoPyObject<'py> for collections::BTreeSet<K>
96where
97    K: IntoPyObject<'py> + cmp::Ord,
98{
99    type Target = PySet;
100    type Output = Bound<'py, Self::Target>;
101    type Error = PyErr;
102
103    #[cfg(feature = "experimental-inspect")]
104    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::OUTPUT_TYPE);
105
106    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
107        PySet::new(py, self)
108    }
109
110    #[cfg(feature = "experimental-inspect")]
111    fn type_output() -> TypeInfo {
112        TypeInfo::set_of(K::type_output())
113    }
114}
115
116impl<'a, 'py, K> IntoPyObject<'py> for &'a collections::BTreeSet<K>
117where
118    &'a K: IntoPyObject<'py> + cmp::Ord,
119    K: 'a,
120{
121    type Target = PySet;
122    type Output = Bound<'py, Self::Target>;
123    type Error = PyErr;
124
125    #[cfg(feature = "experimental-inspect")]
126    const OUTPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, <&K>::OUTPUT_TYPE);
127
128    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
129        PySet::new(py, self)
130    }
131
132    #[cfg(feature = "experimental-inspect")]
133    fn type_output() -> TypeInfo {
134        TypeInfo::set_of(<&K>::type_output())
135    }
136}
137
138impl<'py, K> FromPyObject<'_, 'py> for collections::BTreeSet<K>
139where
140    K: FromPyObjectOwned<'py> + cmp::Ord,
141{
142    type Error = PyErr;
143
144    #[cfg(feature = "experimental-inspect")]
145    const INPUT_TYPE: PyStaticExpr = type_hint_subscript!(PySet::TYPE_HINT, K::INPUT_TYPE);
146
147    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
148        match ob.cast::<PySet>() {
149            Ok(set) => set
150                .iter()
151                .map(|any| any.extract().map_err(Into::into))
152                .collect(),
153            Err(err) => {
154                if let Ok(frozen_set) = ob.cast::<PyFrozenSet>() {
155                    frozen_set
156                        .iter()
157                        .map(|any| any.extract().map_err(Into::into))
158                        .collect()
159                } else {
160                    Err(PyErr::from(err))
161                }
162            }
163        }
164    }
165
166    #[cfg(feature = "experimental-inspect")]
167    fn type_input() -> TypeInfo {
168        TypeInfo::set_of(K::type_input())
169    }
170}
171
#[cfg(test)]
mod tests {
    use crate::types::{any::PyAnyMethods, PyFrozenSet, PySet};
    use crate::{IntoPyObject, Python};
    use std::collections::{BTreeSet, HashSet};

    /// Both a Python `set` and a `frozenset` extract into a `HashSet`.
    #[test]
    fn test_extract_hashset() {
        Python::attach(|py| {
            let expected: HashSet<usize> = HashSet::from([1, 2, 3, 4, 5]);

            let py_set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let from_set: HashSet<usize> = py_set.extract().unwrap();
            assert_eq!(from_set, expected);

            let py_frozen_set = PyFrozenSet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let from_frozen_set: HashSet<usize> = py_frozen_set.extract().unwrap();
            assert_eq!(from_frozen_set, expected);
        });
    }

    /// Both a Python `set` and a `frozenset` extract into a `BTreeSet`.
    #[test]
    fn test_extract_btreeset() {
        Python::attach(|py| {
            let expected: BTreeSet<usize> = BTreeSet::from([1, 2, 3, 4, 5]);

            let py_set = PySet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let from_set: BTreeSet<usize> = py_set.extract().unwrap();
            assert_eq!(from_set, expected);

            let py_frozen_set = PyFrozenSet::new(py, [1, 2, 3, 4, 5]).unwrap();
            let from_frozen_set: BTreeSet<usize> = py_frozen_set.extract().unwrap();
            assert_eq!(from_frozen_set, expected);
        });
    }

    /// Converting a borrowed Rust set to Python and extracting it again
    /// yields an equal set.
    #[test]
    fn test_set_into_pyobject() {
        Python::attach(|py| {
            let bt: BTreeSet<u64> = (1..=5).collect();
            let hs: HashSet<u64> = (1..=5).collect();

            let py_from_bt = (&bt).into_pyobject(py).unwrap();
            let py_from_hs = (&hs).into_pyobject(py).unwrap();

            assert_eq!(bt, py_from_bt.extract().unwrap());
            assert_eq!(hs, py_from_hs.extract().unwrap());
        });
    }
}