diff --git a/Cargo.toml b/Cargo.toml index 8fd1914..a8b3c86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ keywords = ["serde", "pyo3", "python", "ffi"] license = "MIT OR Apache-2.0" [dependencies] +once_cell = { version = ">=1.21", optional = true } pyo3 = ">=0.26.0" serde = "1.0.228" @@ -20,4 +21,7 @@ serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" [features] +default = [] abi3-py38 = ["pyo3/abi3-py38"] +dataclass_support = ["dep:once_cell"] +pydantic_support = ["dep:once_cell"] diff --git a/src/de.rs b/src/de.rs index 8135cf8..227384b 100644 --- a/src/de.rs +++ b/src/de.rs @@ -328,6 +328,21 @@ impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> { if self.0.is_instance_of::() { return visitor.visit_f64(self.0.extract()?); } + #[cfg(feature = "dataclass_support")] + if crate::py_module_cache::is_dataclass(self.0.py(), &self.0)? { + // Use dataclasses.asdict(obj) to get the dict representtion of the object + let dataclasses = PyModule::import(self.0.py(), "dataclasses")?; + let asdict = dataclasses.getattr("asdict")?; + let dict = asdict.call1((self.0,))?; + return visitor.visit_map(MapDeserializer::new(dict.downcast()?)); + } + #[cfg(feature = "pydantic_support")] + if crate::py_module_cache::is_pydantic_base_model(self.0.py(), &self.0)? { + // Use pydantic.BaseModel#model_dump() to get the dict representation of the object + let model_dump = self.0.getattr("model_dump")?; + let dict = model_dump.call0()?; + return visitor.visit_map(MapDeserializer::new(dict.downcast()?)); + } if self.0.hasattr("__dict__")? { return visitor.visit_map(MapDeserializer::new( self.0.getattr("__dict__")?.downcast()?, diff --git a/src/lib.rs b/src/lib.rs index ebcc407..4497e39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,3 +18,6 @@ pub use ser::to_pyobject; #[cfg_attr(doc, doc = include_str!("../README.md"))] mod readme {} + +#[cfg(any(feature = "dataclass_support", feature = "pydantic_support"))] +mod py_module_cache; diff --git a/src/py_module_cache.rs b/src/py_module_cache.rs new file mode 100644 index 0000000..ace1ecf --- /dev/null +++ b/src/py_module_cache.rs @@ -0,0 +1,96 @@ +use once_cell::sync::OnceCell; +use pyo3::{types::*, Bound, IntoPyObject, Py, PyResult, Python}; + +// Individual OnceCell instances for each cached item +#[cfg(feature = "dataclass_support")] +static DATACLASSES_MODULE: OnceCell> = OnceCell::new(); +#[cfg(feature = "dataclass_support")] +static IS_DATACLASS_FN: OnceCell> = OnceCell::new(); + +#[cfg(feature = "pydantic_support")] +static PYDANTIC_MODULE: OnceCell> = OnceCell::new(); +#[cfg(feature = "pydantic_support")] +static PYDANTIC_BASE_MODEL: OnceCell> = OnceCell::new(); + +#[cfg(feature = "pydantic_support")] +fn is_module_installed(py: Python, module_name: &str) -> PyResult { + match PyModule::import(py, module_name) { + Ok(_) => Ok(true), + Err(err) => { + if err.is_instance_of::(py) { + Ok(false) + } else { + Err(err) + } + } + } +} + +#[cfg(feature = "dataclass_support")] +pub fn is_dataclass(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + // Initialize the dataclasses module if needed + if DATACLASSES_MODULE.get().is_none() { + let dataclasses = PyModule::import(py, "dataclasses")?; + let _ = DATACLASSES_MODULE.set(dataclasses.into()); + } + + // Initialize the is_dataclass function if needed + let is_dataclass_fn = if let Some(fn_obj) = IS_DATACLASS_FN.get() { + fn_obj + } else { + let dataclasses = DATACLASSES_MODULE + .get() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Dataclasses module not initialized") + })? + .bind(py); + let is_dataclass_fn: Py = dataclasses + .getattr("is_dataclass")? + .into_pyobject(py)? + .into(); + // Safe to unwrap because we know the value is set + let _ = IS_DATACLASS_FN.set(is_dataclass_fn); + IS_DATACLASS_FN.get().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize is_dataclass function") + })? + }; + + // Execute the function + let result = is_dataclass_fn.bind(py).call1((obj,))?; + result.extract() +} +#[cfg(feature = "pydantic_support")] +pub fn is_pydantic_base_model(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + // First check if pydantic is installed + if !is_module_installed(py, "pydantic")? { + return Ok(false); + } + + // Initialize pydantic module if needed + if PYDANTIC_MODULE.get().is_none() { + let pydantic = PyModule::import(py, "pydantic")?; + // Safe to unwrap because we know the value is empty + let _ = PYDANTIC_MODULE.set(pydantic.into()); + } + + // Initialize BaseModel if needed + let base_model = if let Some(model) = PYDANTIC_BASE_MODEL.get() { + model + } else { + let pydantic = PYDANTIC_MODULE + .get() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Pydantic module not initialized") + })? + .bind(py); + let base_model: Py = pydantic.getattr("BaseModel")?.into_pyobject(py)?.into(); + // Safe to unwrap because we know the value is empty + let _ = PYDANTIC_BASE_MODEL.set(base_model); + PYDANTIC_BASE_MODEL.get().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize BaseModel") + })? + }; + + // Check if object is instance of BaseModel + obj.is_instance(base_model.bind(py)) +} diff --git a/tests/check_revertible.rs b/tests/check_revertible.rs index c79835b..2947d41 100644 --- a/tests/check_revertible.rs +++ b/tests/check_revertible.rs @@ -180,3 +180,166 @@ MyClass("John", 30) assert_eq!(rust_version, python_version); }) } + +#[cfg(feature = "pydantic_support")] +#[test] +fn check_pydantic_object() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + name: String, + age: i32, + } + + Python::attach(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from pydantic import BaseModel +class MyClass(BaseModel): + name: str + age: int +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(name="John", age=30) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + name: "John".to_string(), + age: 30, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +} + +#[cfg(feature = "dataclass_support")] +#[test] +fn check_dataclass_object() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + name: String, + age: i32, + } + + Python::attach(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from dataclasses import dataclass +@dataclass +class MyClass: + name: str + age: int +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(name="John", age=30) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + name: "John".to_string(), + age: 30, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +} + +#[cfg(feature = "dataclass_support")] +#[test] +fn check_dataclass_object_nested() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClassNested { + name: String, + age: i32, + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + my_class: MyClassNested, + } + + Python::attach(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from dataclasses import dataclass +@dataclass +class MyClassNested: + name: str + age: int + +@dataclass +class MyClass: + my_class: MyClassNested +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(my_class=MyClassNested(name="John", age=30)) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + my_class: MyClassNested { + name: "John".to_string(), + age: 30, + }, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +}