From 1aacb80e002d016f0ea838eb0e2e7cfd5024c27d Mon Sep 17 00:00:00 2001 From: Warren Snipes Date: Tue, 1 Apr 2025 19:13:57 +0000 Subject: [PATCH 1/5] Add dataclass and pydantic support --- Cargo.toml | 4 + src/de.rs | 21 ++++- src/lib.rs | 3 + src/py_module_cache.rs | 95 +++++++++++++++++++++ tests/check_revertible.rs | 168 +++++++++++++++++++++++++++++++++++++- 5 files changed, 287 insertions(+), 4 deletions(-) create mode 100644 src/py_module_cache.rs diff --git a/Cargo.toml b/Cargo.toml index fe5bba8..870007c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0" [dependencies] pyo3 = "0.23.0" serde = "1.0.190" +once_cell = { version = ">=1.21", optional = true } [dev-dependencies] maplit = "1.0.2" @@ -20,4 +21,7 @@ serde = { version = "1.0.190", features = ["derive"] } serde_json = "1.0.108" [features] +default = ["dataclass_support", "pydantic_support"] abi3-py38 = ["pyo3/abi3-py38"] +dataclass_support = ["dep:once_cell"] +pydantic_support = ["dep:once_cell"] diff --git a/src/de.rs b/src/de.rs index e4f2981..0dda4f9 100644 --- a/src/de.rs +++ b/src/de.rs @@ -328,8 +328,27 @@ impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> { if self.0.is_instance_of::() { return visitor.visit_f64(self.0.extract()?); } + #[cfg(feature = "dataclass_support")] + if crate::py_module_cache::is_dataclass(self.0.py(), &self.0)? { + // Use dataclasses.asdict to get the dict representation of the object + // dataclasses.asdict(obj) + let dataclasses = PyModule::import(self.0.py(), "dataclasses")?; + let asdict = dataclasses.getattr("asdict")?; + let dict = asdict.call1((self.0,))?; + return visitor.visit_map(MapDeserializer::new(dict.downcast()?)); + } + #[cfg(feature = "pydantic_support")] + if crate::py_module_cache::is_pydantic_base_model(self.0.py(), &self.0)? { + // Use pydantic.BaseModel.model_dump() to get the dict representation of the object + // call model_dump() on the object + let model_dump = self.0.getattr("model_dump")?; + let dict = model_dump.call0()?; + return visitor.visit_map(MapDeserializer::new(dict.downcast()?)); + } if self.0.hasattr("__dict__")? { - return visitor.visit_map(MapDeserializer::new(self.0.getattr("__dict__")?.downcast()?)); + return visitor.visit_map(MapDeserializer::new( + self.0.getattr("__dict__")?.downcast()?, + )); } if self.0.is_none() { return visitor.visit_none(); diff --git a/src/lib.rs b/src/lib.rs index ebcc407..4497e39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,3 +18,6 @@ pub use ser::to_pyobject; #[cfg_attr(doc, doc = include_str!("../README.md"))] mod readme {} + +#[cfg(any(feature = "dataclass_support", feature = "pydantic_support"))] +mod py_module_cache; diff --git a/src/py_module_cache.rs b/src/py_module_cache.rs new file mode 100644 index 0000000..badba57 --- /dev/null +++ b/src/py_module_cache.rs @@ -0,0 +1,95 @@ +use once_cell::sync::OnceCell; +use pyo3::{types::*, Bound, IntoPyObject, Py, PyResult, Python}; + +// Individual OnceCell instances for each cached item +#[cfg(feature = "dataclass_support")] +static DATACLASSES_MODULE: OnceCell = OnceCell::new(); +#[cfg(feature = "dataclass_support")] +static IS_DATACLASS_FN: OnceCell = OnceCell::new(); + +#[cfg(feature = "pydantic_support")] +static PYDANTIC_MODULE: OnceCell = OnceCell::new(); +#[cfg(feature = "pydantic_support")] +static PYDANTIC_BASE_MODEL: OnceCell = OnceCell::new(); + +fn is_module_installed(py: Python, module_name: &str) -> PyResult { + match PyModule::import(py, module_name) { + Ok(_) => Ok(true), + Err(err) => { + if err.is_instance_of::(py) { + Ok(false) + } else { + Err(err) + } + } + } +} + +#[cfg(feature = "dataclass_support")] +pub fn is_dataclass(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + // Initialize the dataclasses module if needed + if DATACLASSES_MODULE.get().is_none() { + let dataclasses = PyModule::import(py, "dataclasses")?; + let _ = DATACLASSES_MODULE.set(dataclasses.into()); + } + + // Initialize the is_dataclass function if needed + let is_dataclass_fn = if let Some(fn_obj) = IS_DATACLASS_FN.get() { + fn_obj + } else { + let dataclasses = DATACLASSES_MODULE + .get() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Dataclasses module not initialized") + })? + .bind(py); + let is_dataclass_fn: Py = dataclasses + .getattr("is_dataclass")? + .into_pyobject(py)? + .into(); + // Safe to unwrap because we know the value is set + let _ = IS_DATACLASS_FN.set(is_dataclass_fn); + IS_DATACLASS_FN.get().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize is_dataclass function") + })? + }; + + // Execute the function + let result = is_dataclass_fn.bind(py).call1((obj,))?; + result.extract() +} +#[cfg(feature = "pydantic_support")] +pub fn is_pydantic_base_model(py: Python, obj: &Bound<'_, PyAny>) -> PyResult { + // First check if pydantic is installed + if !is_module_installed(py, "pydantic")? { + return Ok(false); + } + + // Initialize pydantic module if needed + if PYDANTIC_MODULE.get().is_none() { + let pydantic = PyModule::import(py, "pydantic")?; + // Safe to unwrap because we know the value is empty + let _ = PYDANTIC_MODULE.set(pydantic.into()); + } + + // Initialize BaseModel if needed + let base_model = if let Some(model) = PYDANTIC_BASE_MODEL.get() { + model + } else { + let pydantic = PYDANTIC_MODULE + .get() + .ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Pydantic module not initialized") + })? + .bind(py); + let base_model: Py = pydantic.getattr("BaseModel")?.into_pyobject(py)?.into(); + // Safe to unwrap because we know the value is empty + let _ = PYDANTIC_BASE_MODEL.set(base_model); + PYDANTIC_BASE_MODEL.get().ok_or_else(|| { + pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize BaseModel") + })? + }; + + // Check if object is instance of BaseModel + obj.is_instance(base_model.bind(py)) +} diff --git a/tests/check_revertible.rs b/tests/check_revertible.rs index bdda5f4..c23ca5f 100644 --- a/tests/check_revertible.rs +++ b/tests/check_revertible.rs @@ -1,4 +1,3 @@ -use std::{any::Any, collections::HashMap}; use maplit::hashmap; use pyo3::{ffi::c_str, prelude::*}; @@ -136,7 +135,6 @@ fn struct_variant() { }); } - #[test] fn check_python_object() { #[derive(Debug, PartialEq, Serialize, Deserialize)] @@ -163,7 +161,8 @@ class MyClass: // Create an instance of MyClass let my_python_class = py .eval( - c_str!(r#" + c_str!( + r#" MyClass("John", 30) "# ), @@ -182,3 +181,166 @@ MyClass("John", 30) assert_eq!(rust_version, python_version); }) } + +#[cfg(feature = "pydantic_support")] +#[test] +fn check_pydantic_object() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + name: String, + age: i32, + } + + Python::with_gil(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from pydantic import BaseModel +class MyClass(BaseModel): + name: str + age: int +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(name="John", age=30) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + name: "John".to_string(), + age: 30, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +} + +#[cfg(feature = "dataclass_support")] +#[test] +fn check_dataclass_object() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + name: String, + age: i32, + } + + Python::with_gil(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from dataclasses import dataclass +@dataclass +class MyClass: + name: str + age: int +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(name="John", age=30) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + name: "John".to_string(), + age: 30, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +} + +#[cfg(feature = "dataclass_support")] +#[test] +fn check_dataclass_object_nested() { + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClassNested { + name: String, + age: i32, + } + + #[derive(Debug, PartialEq, Serialize, Deserialize)] + struct MyClass { + my_class: MyClassNested, + } + + Python::with_gil(|py| { + // Create an instance of Python object + py.run( + c_str!( + r#" +from dataclasses import dataclass +@dataclass +class MyClassNested: + name: str + age: int + +@dataclass +class MyClass: + my_class: MyClassNested +"# + ), + None, + None, + ) + .unwrap(); + // Create an instance of MyClass + let my_python_class = py + .eval( + c_str!( + r#" +MyClass(my_class=MyClassNested(name="John", age=30)) +"# + ), + None, + None, + ) + .unwrap(); + + let my_rust_class = MyClass { + my_class: MyClassNested { + name: "John".to_string(), + age: 30, + }, + }; + let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap(); + println!("any: {:?}", any); + + let rust_version: MyClass = from_pyobject(my_python_class).unwrap(); + let python_version: MyClass = from_pyobject(any).unwrap(); + assert_eq!(rust_version, python_version); + }) +} From 96c7ac5429421f8c2fc9a8cdba00f6b76d22c11f Mon Sep 17 00:00:00 2001 From: WarrenS Date: Wed, 2 Apr 2025 07:45:02 -0400 Subject: [PATCH 2/5] Remove dataclass and pydantic support from defaults --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 870007c..9f17e05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ serde = { version = "1.0.190", features = ["derive"] } serde_json = "1.0.108" [features] -default = ["dataclass_support", "pydantic_support"] +default = [] abi3-py38 = ["pyo3/abi3-py38"] dataclass_support = ["dep:once_cell"] pydantic_support = ["dep:once_cell"] From 1cdf01c0b02e1126f43b8b1c868e7d094248f311 Mon Sep 17 00:00:00 2001 From: WarrenS Date: Wed, 2 Apr 2025 07:46:18 -0400 Subject: [PATCH 3/5] remove unnecessary comments --- src/de.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/de.rs b/src/de.rs index 0dda4f9..5e1affd 100644 --- a/src/de.rs +++ b/src/de.rs @@ -330,8 +330,7 @@ impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> { } #[cfg(feature = "dataclass_support")] if crate::py_module_cache::is_dataclass(self.0.py(), &self.0)? { - // Use dataclasses.asdict to get the dict representation of the object - // dataclasses.asdict(obj) + // Use dataclasses.asdict(obj) to get the dict representtion of the object let dataclasses = PyModule::import(self.0.py(), "dataclasses")?; let asdict = dataclasses.getattr("asdict")?; let dict = asdict.call1((self.0,))?; @@ -339,8 +338,7 @@ impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> { } #[cfg(feature = "pydantic_support")] if crate::py_module_cache::is_pydantic_base_model(self.0.py(), &self.0)? { - // Use pydantic.BaseModel.model_dump() to get the dict representation of the object - // call model_dump() on the object + // Use pydantic.BaseModel#model_dump() to get the dict representation of the object let model_dump = self.0.getattr("model_dump")?; let dict = model_dump.call0()?; return visitor.visit_map(MapDeserializer::new(dict.downcast()?)); From 61327595b84f1c18c14daff0efabd13ce5a0a138 Mon Sep 17 00:00:00 2001 From: Warren Snipes Date: Sat, 24 May 2025 20:29:56 +0000 Subject: [PATCH 4/5] Run rustfmt --- tests/check_revertible.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/check_revertible.rs b/tests/check_revertible.rs index c23ca5f..7961628 100644 --- a/tests/check_revertible.rs +++ b/tests/check_revertible.rs @@ -1,4 +1,3 @@ - use maplit::hashmap; use pyo3::{ffi::c_str, prelude::*}; use serde::{Deserialize, Serialize}; From 5c303f4fa51ca61b51afc3c492935ab635e8eae2 Mon Sep 17 00:00:00 2001 From: Toshiki Teramura Date: Fri, 17 Oct 2025 02:04:44 +0900 Subject: [PATCH 5/5] Fix deprecation warnings and dead code in pydantic/dataclass support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Replace deprecated pyo3::PyObject with Py - Add #[cfg(feature = "pydantic_support")] to is_module_installed to fix unused function warning - Replace deprecated Python::with_gil with Python::attach in tests 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- src/py_module_cache.rs | 9 +++++---- tests/check_revertible.rs | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/py_module_cache.rs b/src/py_module_cache.rs index badba57..ace1ecf 100644 --- a/src/py_module_cache.rs +++ b/src/py_module_cache.rs @@ -3,15 +3,16 @@ use pyo3::{types::*, Bound, IntoPyObject, Py, PyResult, Python}; // Individual OnceCell instances for each cached item #[cfg(feature = "dataclass_support")] -static DATACLASSES_MODULE: OnceCell = OnceCell::new(); +static DATACLASSES_MODULE: OnceCell> = OnceCell::new(); #[cfg(feature = "dataclass_support")] -static IS_DATACLASS_FN: OnceCell = OnceCell::new(); +static IS_DATACLASS_FN: OnceCell> = OnceCell::new(); #[cfg(feature = "pydantic_support")] -static PYDANTIC_MODULE: OnceCell = OnceCell::new(); +static PYDANTIC_MODULE: OnceCell> = OnceCell::new(); #[cfg(feature = "pydantic_support")] -static PYDANTIC_BASE_MODEL: OnceCell = OnceCell::new(); +static PYDANTIC_BASE_MODEL: OnceCell> = OnceCell::new(); +#[cfg(feature = "pydantic_support")] fn is_module_installed(py: Python, module_name: &str) -> PyResult { match PyModule::import(py, module_name) { Ok(_) => Ok(true), diff --git a/tests/check_revertible.rs b/tests/check_revertible.rs index 1a03d0f..2947d41 100644 --- a/tests/check_revertible.rs +++ b/tests/check_revertible.rs @@ -190,7 +190,7 @@ fn check_pydantic_object() { age: i32, } - Python::with_gil(|py| { + Python::attach(|py| { // Create an instance of Python object py.run( c_str!( @@ -240,7 +240,7 @@ fn check_dataclass_object() { age: i32, } - Python::with_gil(|py| { + Python::attach(|py| { // Create an instance of Python object py.run( c_str!( @@ -296,7 +296,7 @@ fn check_dataclass_object_nested() { my_class: MyClassNested, } - Python::with_gil(|py| { + Python::attach(|py| { // Create an instance of Python object py.run( c_str!(