Skip to content
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ keywords = ["serde", "pyo3", "python", "ffi"]
license = "MIT OR Apache-2.0"

[dependencies]
once_cell = { version = ">=1.21", optional = true }
pyo3 = ">=0.26.0"
serde = "1.0.228"

Expand All @@ -20,4 +21,7 @@ serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0.145"

[features]
default = []
abi3-py38 = ["pyo3/abi3-py38"]
dataclass_support = ["dep:once_cell"]
pydantic_support = ["dep:once_cell"]
15 changes: 15 additions & 0 deletions src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,21 @@ impl<'de> de::Deserializer<'de> for PyAnyDeserializer<'_> {
if self.0.is_instance_of::<PyFloat>() {
return visitor.visit_f64(self.0.extract()?);
}
#[cfg(feature = "dataclass_support")]
if crate::py_module_cache::is_dataclass(self.0.py(), &self.0)? {
// Use dataclasses.asdict(obj) to get the dict representtion of the object
let dataclasses = PyModule::import(self.0.py(), "dataclasses")?;
let asdict = dataclasses.getattr("asdict")?;
let dict = asdict.call1((self.0,))?;
return visitor.visit_map(MapDeserializer::new(dict.downcast()?));
}
#[cfg(feature = "pydantic_support")]
if crate::py_module_cache::is_pydantic_base_model(self.0.py(), &self.0)? {
// Use pydantic.BaseModel#model_dump() to get the dict representation of the object
let model_dump = self.0.getattr("model_dump")?;
let dict = model_dump.call0()?;
return visitor.visit_map(MapDeserializer::new(dict.downcast()?));
}
if self.0.hasattr("__dict__")? {
return visitor.visit_map(MapDeserializer::new(
self.0.getattr("__dict__")?.downcast()?,
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,6 @@ pub use ser::to_pyobject;

#[cfg_attr(doc, doc = include_str!("../README.md"))]
mod readme {}

#[cfg(any(feature = "dataclass_support", feature = "pydantic_support"))]
mod py_module_cache;
96 changes: 96 additions & 0 deletions src/py_module_cache.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use once_cell::sync::OnceCell;
use pyo3::{types::*, Bound, IntoPyObject, Py, PyResult, Python};

// Individual OnceCell instances for each cached item
#[cfg(feature = "dataclass_support")]
static DATACLASSES_MODULE: OnceCell<Py<PyAny>> = OnceCell::new();
#[cfg(feature = "dataclass_support")]
static IS_DATACLASS_FN: OnceCell<Py<PyAny>> = OnceCell::new();

#[cfg(feature = "pydantic_support")]
static PYDANTIC_MODULE: OnceCell<Py<PyAny>> = OnceCell::new();
#[cfg(feature = "pydantic_support")]
static PYDANTIC_BASE_MODEL: OnceCell<Py<PyAny>> = OnceCell::new();

#[cfg(feature = "pydantic_support")]
fn is_module_installed(py: Python, module_name: &str) -> PyResult<bool> {
match PyModule::import(py, module_name) {
Ok(_) => Ok(true),
Err(err) => {
if err.is_instance_of::<pyo3::exceptions::PyModuleNotFoundError>(py) {
Ok(false)
} else {
Err(err)
}
}
}
}

#[cfg(feature = "dataclass_support")]
pub fn is_dataclass(py: Python, obj: &Bound<'_, PyAny>) -> PyResult<bool> {
// Initialize the dataclasses module if needed
if DATACLASSES_MODULE.get().is_none() {
let dataclasses = PyModule::import(py, "dataclasses")?;
let _ = DATACLASSES_MODULE.set(dataclasses.into());
}

// Initialize the is_dataclass function if needed
let is_dataclass_fn = if let Some(fn_obj) = IS_DATACLASS_FN.get() {
fn_obj
} else {
let dataclasses = DATACLASSES_MODULE
.get()
.ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Dataclasses module not initialized")
})?
.bind(py);
let is_dataclass_fn: Py<PyAny> = dataclasses
.getattr("is_dataclass")?
.into_pyobject(py)?
.into();
// Safe to unwrap because we know the value is set
let _ = IS_DATACLASS_FN.set(is_dataclass_fn);
IS_DATACLASS_FN.get().ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize is_dataclass function")
})?
};

// Execute the function
let result = is_dataclass_fn.bind(py).call1((obj,))?;
result.extract()
}
#[cfg(feature = "pydantic_support")]
pub fn is_pydantic_base_model(py: Python, obj: &Bound<'_, PyAny>) -> PyResult<bool> {
// First check if pydantic is installed
if !is_module_installed(py, "pydantic")? {
return Ok(false);
}

// Initialize pydantic module if needed
if PYDANTIC_MODULE.get().is_none() {
let pydantic = PyModule::import(py, "pydantic")?;
// Safe to unwrap because we know the value is empty
let _ = PYDANTIC_MODULE.set(pydantic.into());
}

// Initialize BaseModel if needed
let base_model = if let Some(model) = PYDANTIC_BASE_MODEL.get() {
model
} else {
let pydantic = PYDANTIC_MODULE
.get()
.ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Pydantic module not initialized")
})?
.bind(py);
let base_model: Py<PyAny> = pydantic.getattr("BaseModel")?.into_pyobject(py)?.into();
// Safe to unwrap because we know the value is empty
let _ = PYDANTIC_BASE_MODEL.set(base_model);
PYDANTIC_BASE_MODEL.get().ok_or_else(|| {
pyo3::exceptions::PyRuntimeError::new_err("Failed to initialize BaseModel")
})?
};

// Check if object is instance of BaseModel
obj.is_instance(base_model.bind(py))
}
163 changes: 163 additions & 0 deletions tests/check_revertible.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,166 @@ MyClass("John", 30)
assert_eq!(rust_version, python_version);
})
}

#[cfg(feature = "pydantic_support")]
#[test]
fn check_pydantic_object() {
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct MyClass {
name: String,
age: i32,
}

Python::attach(|py| {
// Create an instance of Python object
py.run(
c_str!(
r#"
from pydantic import BaseModel
class MyClass(BaseModel):
name: str
age: int
"#
),
None,
None,
)
.unwrap();
// Create an instance of MyClass
let my_python_class = py
.eval(
c_str!(
r#"
MyClass(name="John", age=30)
"#
),
None,
None,
)
.unwrap();

let my_rust_class = MyClass {
name: "John".to_string(),
age: 30,
};
let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap();
println!("any: {:?}", any);

let rust_version: MyClass = from_pyobject(my_python_class).unwrap();
let python_version: MyClass = from_pyobject(any).unwrap();
assert_eq!(rust_version, python_version);
})
}

#[cfg(feature = "dataclass_support")]
#[test]
fn check_dataclass_object() {
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct MyClass {
name: String,
age: i32,
}

Python::attach(|py| {
// Create an instance of Python object
py.run(
c_str!(
r#"
from dataclasses import dataclass
@dataclass
class MyClass:
name: str
age: int
"#
),
None,
None,
)
.unwrap();
// Create an instance of MyClass
let my_python_class = py
.eval(
c_str!(
r#"
MyClass(name="John", age=30)
"#
),
None,
None,
)
.unwrap();

let my_rust_class = MyClass {
name: "John".to_string(),
age: 30,
};
let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap();
println!("any: {:?}", any);

let rust_version: MyClass = from_pyobject(my_python_class).unwrap();
let python_version: MyClass = from_pyobject(any).unwrap();
assert_eq!(rust_version, python_version);
})
}

#[cfg(feature = "dataclass_support")]
#[test]
fn check_dataclass_object_nested() {
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct MyClassNested {
name: String,
age: i32,
}

#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct MyClass {
my_class: MyClassNested,
}

Python::attach(|py| {
// Create an instance of Python object
py.run(
c_str!(
r#"
from dataclasses import dataclass
@dataclass
class MyClassNested:
name: str
age: int

@dataclass
class MyClass:
my_class: MyClassNested
"#
),
None,
None,
)
.unwrap();
// Create an instance of MyClass
let my_python_class = py
.eval(
c_str!(
r#"
MyClass(my_class=MyClassNested(name="John", age=30))
"#
),
None,
None,
)
.unwrap();

let my_rust_class = MyClass {
my_class: MyClassNested {
name: "John".to_string(),
age: 30,
},
};
let any: Bound<'_, PyAny> = to_pyobject(py, &my_rust_class).unwrap();
println!("any: {:?}", any);

let rust_version: MyClass = from_pyobject(my_python_class).unwrap();
let python_version: MyClass = from_pyobject(any).unwrap();
assert_eq!(rust_version, python_version);
})
}
Loading