Skip to main content

pyo3/conversions/std/
map.rs

1#[cfg(feature = "experimental-inspect")]
2use crate::inspect::types::TypeInfo;
3#[cfg(feature = "experimental-inspect")]
4use crate::inspect::{type_hint_subscript, PyStaticExpr};
5#[cfg(feature = "experimental-inspect")]
6use crate::type_object::PyTypeInfo;
7use crate::{
8    conversion::{FromPyObjectOwned, IntoPyObject},
9    instance::Bound,
10    types::{any::PyAnyMethods, dict::PyDictMethods, PyDict},
11    Borrowed, FromPyObject, PyAny, PyErr, Python,
12};
13use std::{cmp, collections, hash};
14
15impl<'py, K, V, H> IntoPyObject<'py> for collections::HashMap<K, V, H>
16where
17    K: IntoPyObject<'py> + cmp::Eq + hash::Hash,
18    V: IntoPyObject<'py>,
19    H: hash::BuildHasher,
20{
21    type Target = PyDict;
22    type Output = Bound<'py, Self::Target>;
23    type Error = PyErr;
24
25    #[cfg(feature = "experimental-inspect")]
26    const OUTPUT_TYPE: PyStaticExpr =
27        type_hint_subscript!(PyDict::TYPE_HINT, K::OUTPUT_TYPE, V::OUTPUT_TYPE);
28
29    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
30        let dict = PyDict::new(py);
31        for (k, v) in self {
32            dict.set_item(k, v)?;
33        }
34        Ok(dict)
35    }
36
37    #[cfg(feature = "experimental-inspect")]
38    fn type_output() -> TypeInfo {
39        TypeInfo::dict_of(K::type_output(), V::type_output())
40    }
41}
42
43impl<'a, 'py, K, V, H> IntoPyObject<'py> for &'a collections::HashMap<K, V, H>
44where
45    &'a K: IntoPyObject<'py> + cmp::Eq + hash::Hash,
46    &'a V: IntoPyObject<'py>,
47    H: hash::BuildHasher,
48{
49    type Target = PyDict;
50    type Output = Bound<'py, Self::Target>;
51    type Error = PyErr;
52
53    #[cfg(feature = "experimental-inspect")]
54    const OUTPUT_TYPE: PyStaticExpr =
55        type_hint_subscript!(PyDict::TYPE_HINT, <&K>::OUTPUT_TYPE, <&V>::OUTPUT_TYPE);
56
57    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
58        let dict = PyDict::new(py);
59        for (k, v) in self {
60            dict.set_item(k, v)?;
61        }
62        Ok(dict)
63    }
64
65    #[cfg(feature = "experimental-inspect")]
66    fn type_output() -> TypeInfo {
67        TypeInfo::dict_of(<&K>::type_output(), <&V>::type_output())
68    }
69}
70
71impl<'py, K, V> IntoPyObject<'py> for collections::BTreeMap<K, V>
72where
73    K: IntoPyObject<'py> + cmp::Eq,
74    V: IntoPyObject<'py>,
75{
76    type Target = PyDict;
77    type Output = Bound<'py, Self::Target>;
78    type Error = PyErr;
79
80    #[cfg(feature = "experimental-inspect")]
81    const OUTPUT_TYPE: PyStaticExpr =
82        type_hint_subscript!(PyDict::TYPE_HINT, K::OUTPUT_TYPE, V::OUTPUT_TYPE);
83
84    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
85        let dict = PyDict::new(py);
86        for (k, v) in self {
87            dict.set_item(k, v)?;
88        }
89        Ok(dict)
90    }
91
92    #[cfg(feature = "experimental-inspect")]
93    fn type_output() -> TypeInfo {
94        TypeInfo::dict_of(K::type_output(), V::type_output())
95    }
96}
97
98impl<'a, 'py, K, V> IntoPyObject<'py> for &'a collections::BTreeMap<K, V>
99where
100    &'a K: IntoPyObject<'py> + cmp::Eq,
101    &'a V: IntoPyObject<'py>,
102    K: 'a,
103    V: 'a,
104{
105    type Target = PyDict;
106    type Output = Bound<'py, Self::Target>;
107    type Error = PyErr;
108
109    #[cfg(feature = "experimental-inspect")]
110    const OUTPUT_TYPE: PyStaticExpr =
111        type_hint_subscript!(PyDict::TYPE_HINT, <&K>::OUTPUT_TYPE, <&V>::OUTPUT_TYPE);
112
113    fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
114        let dict = PyDict::new(py);
115        for (k, v) in self {
116            dict.set_item(k, v)?;
117        }
118        Ok(dict)
119    }
120
121    #[cfg(feature = "experimental-inspect")]
122    fn type_output() -> TypeInfo {
123        TypeInfo::dict_of(<&K>::type_output(), <&V>::type_output())
124    }
125}
126
127impl<'py, K, V, S> FromPyObject<'_, 'py> for collections::HashMap<K, V, S>
128where
129    K: FromPyObjectOwned<'py> + cmp::Eq + hash::Hash,
130    V: FromPyObjectOwned<'py>,
131    S: hash::BuildHasher + Default,
132{
133    type Error = PyErr;
134
135    #[cfg(feature = "experimental-inspect")]
136    const INPUT_TYPE: PyStaticExpr =
137        type_hint_subscript!(&PyDict::TYPE_HINT, K::INPUT_TYPE, V::INPUT_TYPE);
138
139    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
140        let dict = ob.cast::<PyDict>()?;
141        let mut ret = collections::HashMap::with_capacity_and_hasher(dict.len(), S::default());
142        for (k, v) in dict.iter() {
143            ret.insert(
144                k.extract().map_err(Into::into)?,
145                v.extract().map_err(Into::into)?,
146            );
147        }
148        Ok(ret)
149    }
150
151    #[cfg(feature = "experimental-inspect")]
152    fn type_input() -> TypeInfo {
153        TypeInfo::mapping_of(K::type_input(), V::type_input())
154    }
155}
156
157impl<'py, K, V> FromPyObject<'_, 'py> for collections::BTreeMap<K, V>
158where
159    K: FromPyObjectOwned<'py> + cmp::Ord,
160    V: FromPyObjectOwned<'py>,
161{
162    type Error = PyErr;
163
164    #[cfg(feature = "experimental-inspect")]
165    const INPUT_TYPE: PyStaticExpr =
166        type_hint_subscript!(PyDict::TYPE_HINT, K::INPUT_TYPE, V::INPUT_TYPE);
167
168    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, PyErr> {
169        let dict = ob.cast::<PyDict>()?;
170        let mut ret = collections::BTreeMap::new();
171        for (k, v) in dict.iter() {
172            ret.insert(
173                k.extract().map_err(Into::into)?,
174                v.extract().map_err(Into::into)?,
175            );
176        }
177        Ok(ret)
178    }
179
180    #[cfg(feature = "experimental-inspect")]
181    fn type_input() -> TypeInfo {
182        TypeInfo::mapping_of(K::type_input(), V::type_input())
183    }
184}
185
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::{BTreeMap, HashMap};

    /// Asserts that `py_map` holds exactly one entry mapping `1` to `1`.
    ///
    /// Uses `assert_eq!` so that a failure reports both the actual and the
    /// expected value.
    fn assert_single_entry(py_map: &Bound<'_, PyDict>) {
        assert_eq!(py_map.len(), 1);
        let value = py_map
            .get_item(1)
            .unwrap()
            .unwrap()
            .extract::<i32>()
            .unwrap();
        assert_eq!(value, 1);
    }

    /// A borrowed `HashMap` converts to a dict and extracts back to an equal map.
    #[test]
    fn test_hashmap_to_python() {
        Python::attach(|py| {
            let mut map = HashMap::<i32, i32>::new();
            map.insert(1, 1);

            let py_map = (&map).into_pyobject(py).unwrap();

            assert_single_entry(&py_map);
            assert_eq!(map, py_map.extract().unwrap());
        });
    }

    /// A borrowed `BTreeMap` converts to a dict and extracts back to an equal map.
    #[test]
    fn test_btreemap_to_python() {
        Python::attach(|py| {
            let mut map = BTreeMap::<i32, i32>::new();
            map.insert(1, 1);

            let py_map = (&map).into_pyobject(py).unwrap();

            assert_single_entry(&py_map);
            assert_eq!(map, py_map.extract().unwrap());
        });
    }

    /// An owned `HashMap` converts to a dict holding its single entry.
    #[test]
    fn test_hashmap_into_python() {
        Python::attach(|py| {
            let mut map = HashMap::<i32, i32>::new();
            map.insert(1, 1);

            let py_map = map.into_pyobject(py).unwrap();

            assert_single_entry(&py_map);
        });
    }

    /// An owned `BTreeMap` converts to a dict holding its single entry.
    #[test]
    fn test_btreemap_into_py() {
        Python::attach(|py| {
            let mut map = BTreeMap::<i32, i32>::new();
            map.insert(1, 1);

            let py_map = map.into_pyobject(py).unwrap();

            assert_single_entry(&py_map);
        });
    }
}