pyo3/conversions/
bigdecimal.rs1#![cfg(feature = "bigdecimal")]
2#![doc = concat!("pyo3 = { version = \"", env!("CARGO_PKG_VERSION"), "\", features = [\"bigdecimal\"] }")]
13use std::str::FromStr;
53
54use crate::types::PyTuple;
55use crate::{
56 exceptions::PyValueError,
57 sync::PyOnceLock,
58 types::{PyAnyMethods, PyStringMethods, PyType},
59 Borrowed, Bound, FromPyObject, IntoPyObject, Py, PyAny, PyErr, PyResult, Python,
60};
61use bigdecimal::BigDecimal;
62use num_bigint::Sign;
63
64fn get_decimal_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
65 static DECIMAL_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
66 DECIMAL_CLS.import(py, "decimal", "Decimal")
67}
68
69fn get_invalid_operation_error_cls(py: Python<'_>) -> PyResult<&Bound<'_, PyType>> {
70 static INVALID_OPERATION_CLS: PyOnceLock<Py<PyType>> = PyOnceLock::new();
71 INVALID_OPERATION_CLS.import(py, "decimal", "InvalidOperation")
72}
73
74impl FromPyObject<'_, '_> for BigDecimal {
75 type Error = PyErr;
76
77 fn extract(obj: Borrowed<'_, '_, PyAny>) -> PyResult<Self> {
78 let py_str = &obj.str()?;
79 let rs_str = &py_str.to_cow()?;
80 BigDecimal::from_str(rs_str).map_err(|e| PyValueError::new_err(e.to_string()))
81 }
82}
83
84impl<'py> IntoPyObject<'py> for BigDecimal {
85 type Target = PyAny;
86
87 type Output = Bound<'py, Self::Target>;
88
89 type Error = PyErr;
90
91 fn into_pyobject(self, py: Python<'py>) -> Result<Self::Output, Self::Error> {
92 let cls = get_decimal_cls(py)?;
93 let (bigint, scale) = self.into_bigint_and_scale();
94 if scale == 0 {
95 return cls.call1((bigint,));
96 }
97 let exponent = scale.checked_neg().ok_or_else(|| {
98 get_invalid_operation_error_cls(py)
99 .map_or_else(|err| err, |cls| PyErr::from_type(cls.clone(), ()))
100 })?;
101 let (sign, digits) = bigint.to_radix_be(10);
102 let signed = matches!(sign, Sign::Minus).into_pyobject(py)?;
103 let digits = PyTuple::new(py, digits)?;
104
105 cls.call1(((signed, digits, exponent),))
106 }
107}
108
#[cfg(test)]
mod test_bigdecimal {
    use super::*;
    use crate::types::dict::PyDictMethods;
    use crate::types::PyDict;
    use std::ffi::CString;

    use bigdecimal::{One, Zero};
    #[cfg(not(target_arch = "wasm32"))]
    use proptest::prelude::*;

    /// Generates a test that converts the Rust value `$rs` to Python, asserts
    /// it compares equal to `decimal.Decimal($py)`, then extracts that Python
    /// `Decimal` back into a `BigDecimal` and compares it with `$rs`.
    macro_rules! convert_constants {
        ($name:ident, $rs:expr, $py:literal) => {
            #[test]
            fn $name() {
                Python::attach(|py| {
                    let rs_orig = $rs;
                    let rs_dec = rs_orig.clone().into_pyobject(py).unwrap();
                    let locals = PyDict::new(py);
                    locals.set_item("rs_dec", &rs_dec).unwrap();
                    py.run(
                        &CString::new(format!(
                            "import decimal\npy_dec = decimal.Decimal(\"{}\")\nassert py_dec == rs_dec",
                            $py
                        ))
                        .unwrap(),
                        None,
                        Some(&locals),
                    )
                    .unwrap();
                    let py_dec = locals.get_item("py_dec").unwrap().unwrap();
                    let py_result: BigDecimal = py_dec.extract().unwrap();
                    assert_eq!(rs_orig, py_result);
                })
            }
        };
    }

    convert_constants!(convert_zero, BigDecimal::zero(), "0");
    convert_constants!(convert_one, BigDecimal::one(), "1");
    convert_constants!(convert_neg_one, -BigDecimal::one(), "-1");
    convert_constants!(convert_two, BigDecimal::from(2), "2");
    convert_constants!(convert_ten, BigDecimal::from_str("10").unwrap(), "10");
    convert_constants!(
        convert_one_hundred_point_one,
        BigDecimal::from_str("100.1").unwrap(),
        "100.1"
    );
    convert_constants!(
        convert_one_thousand,
        BigDecimal::from_str("1000").unwrap(),
        "1000"
    );
    convert_constants!(
        convert_scientific,
        BigDecimal::from_str("1e10").unwrap(),
        "1e10"
    );
    // Negative fractional values exercise the `Sign::Minus` branch of the
    // tuple-based conversion, which the integer constants above never reach.
    convert_constants!(
        convert_neg_fraction,
        BigDecimal::from_str("-0.001").unwrap(),
        "-0.001"
    );
    convert_constants!(
        convert_trailing_zeros,
        BigDecimal::from_str("1.500").unwrap(),
        "1.500"
    );
    convert_constants!(
        convert_neg_scientific,
        BigDecimal::from_str("-2.5e-7").unwrap(),
        "-2.5e-7"
    );

    #[cfg(not(target_arch = "wasm32"))]
    proptest! {
        #[test]
        fn test_roundtrip(
            number in 0..28u32
        ) {
            let num = BigDecimal::from(number);
            Python::attach(|py| {
                let rs_dec = num.clone().into_pyobject(py).unwrap();
                let locals = PyDict::new(py);
                locals.set_item("rs_dec", &rs_dec).unwrap();
                py.run(
                    &CString::new(format!(
                        "import decimal\npy_dec = decimal.Decimal(\"{num}\")\nassert py_dec == rs_dec")).unwrap(),
                    None, Some(&locals)).unwrap();
                let roundtripped: BigDecimal = rs_dec.extract().unwrap();
                assert_eq!(num, roundtripped);
            })
        }

        #[test]
        fn test_integers(num in any::<i64>()) {
            Python::attach(|py| {
                let py_num = num.into_pyobject(py).unwrap();
                let roundtripped: BigDecimal = py_num.extract().unwrap();
                let rs_dec = BigDecimal::from(num);
                assert_eq!(rs_dec, roundtripped);
            })
        }
    }

    #[test]
    fn test_nan() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"NaN\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    #[test]
    fn test_infinity() {
        Python::attach(|py| {
            let locals = PyDict::new(py);
            py.run(
                c"import decimal\npy_dec = decimal.Decimal(\"Infinity\")",
                None,
                Some(&locals),
            )
            .unwrap();
            let py_dec = locals.get_item("py_dec").unwrap().unwrap();
            let roundtripped: Result<BigDecimal, PyErr> = py_dec.extract();
            assert!(roundtripped.is_err());
        })
    }

    /// The Rust -> Python conversion must keep exactly the digits and exponent
    /// that Python itself would produce for the same literal. The inputs cover
    /// the `scale == 0` path, negative scales (scientific notation), positive
    /// scales, trailing zeros and a negative sign.
    #[test]
    fn test_no_precision_loss() {
        Python::attach(|py| {
            for src in ["1e4", "100", "0.001", "-12.3400", "123456789.000000001"] {
                let expected = get_decimal_cls(py)
                    .unwrap()
                    .call1((src,))
                    .unwrap()
                    .call_method0("as_tuple")
                    .unwrap();
                let actual = src
                    .parse::<BigDecimal>()
                    .unwrap()
                    .into_pyobject(py)
                    .unwrap()
                    .call_method0("as_tuple")
                    .unwrap();

                assert!(
                    actual.eq(expected).unwrap(),
                    "sign/digits/exponent mismatch for {src}"
                );
            }
        });
    }

    /// A scale of `i64::MIN` has no representable negation, so the conversion
    /// must fail with `decimal.InvalidOperation` instead of overflowing.
    #[test]
    fn test_scale_overflow_raises_invalid_operation() {
        Python::attach(|py| {
            let value = BigDecimal::new(num_bigint::BigInt::from(1), i64::MIN);
            let err = value.into_pyobject(py).unwrap_err();
            let invalid_operation = get_invalid_operation_error_cls(py).unwrap();
            assert!(err.is_instance(py, invalid_operation.as_any()));
        })
    }
}