Commit 35de126

fix rpy2 deprecation error
Signed-off-by: kgoebler <Konstantin.Goebler@de.bosch.com>
1 parent 0d2e9d1 commit 35de126

2 files changed

Lines changed: 43 additions & 31 deletions

README.md

Lines changed: 9 additions & 7 deletions
@@ -4,10 +4,12 @@
 [![Code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff)

 This repo provides details regarding $\texttt{causalAssembly}$, a causal discovery benchmark data tool based on complex production data.
-Theoretical details and information regarding construction are presented in the [paper](https://arxiv.org/abs/2306.10816):
+Theoretical details and information regarding construction are presented in the [paper](https://proceedings.mlr.press/v236/gobler24a.html):
+
+Göbler, K., Windisch, T., Drton, M., Pychynski, T., Roth, M., & Sonntag, S. (2024). causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery. In Proceedings of the Third Conference on Causal Learning and Reasoning (pp. 609–642). PMLR.

-Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024,
 ## Authors
+
 * [Konstantin Goebler](mailto:konstantin.goebler@de.bosch.com)
 * [Steffen Sonntag](mailto:steffen.sonntag@de.bosch.com)

@@ -26,14 +28,13 @@ The package can be installed as follows

 pip install causalAssembly

-[comment]: <> (git+https://github.com/boschresearch/causalAssembly.git)

 ## <a name="using">How to use</a>

 This is how $\texttt{causalAssembly}$'s functionality may be used. Be sure to read the [documentation](https://boschresearch.github.io/causalAssembly/) for more in-depth details regarding available functions and classes.

 In case you want to train a distributional random forests yourself (see [how to semisynthetsize](#how-to-semisynthesize)),
-you need an R installation as well as the corresponding [drf](https://cran.r-project.org/web/packages/drf/index.html) R package.
+you need an R installation.
 Sampling has first been proposed in [[2]](#2).

 *Note*: For Windows users the python package [rpy2](https://github.com/rpy2/rpy2) might cause issues.
@@ -69,7 +70,9 @@ assembly_line.Station3.drf = fit_drf(assembly_line.Station3, data=assembly_line_
 station3_sample = assembly_line.Station3.sample_from_drf(size=n_select)

 ```
+
 ### <a name="Interventional data">Interventional data</a>
+
 In case you want to create interventional data, we currently support hard and soft interventions.
 For soft interventions we use `sympy`'s `RandomSymbol` class. Essentially, soft interventions should
 be declared by choosing your preferred random variable with associated distribution from [here](https://docs.sympy.org/latest/modules/stats.html#continuous-types). Simple examples include:
@@ -158,7 +161,6 @@ if nx.is_directed_acyclic_graph(s_graph):

 ### <a name="how-to-rand">How to generate random production DAGs</a>

-
 The `ProductionLineGraph` class can further be used to generate completely random DAGs that follow an assembly line logic. Consider the following example:

 ```python
@@ -183,12 +185,11 @@ example_line.connect_cells(forward_probs= [.1])
 example_line.show()

 ```
-### <a name="how-to-fcm">How to generate FCMs</a>

+### <a name="how-to-fcm">How to generate FCMs</a>

 $\texttt{causalAssembly}$ also allows creating structural causal models (SCM) or synonymously functional causal models (FCM). In particular, we employ symbolic programming to allow for a seamless interplay between readability and performance. The `FCM` class is completely general and inherits no production data logic. See the example below for construction and usage.

-
 ```python

 import numpy as np
@@ -279,6 +280,7 @@ Please feel free to contact one of the authors in case you wish to contribute.
 | [matplotlib](https://github.com/matplotlib/matplotlib) | [Other](https://github.com/matplotlib/matplotlib/tree/main/LICENSE) | Dependency |
 | [sympy](https://github.com/sympy/sympy) | [BSD-3-Clause License](https://github.com/sympy/sympy/blob/master/LICENSE) | Dependency |
 | [rpy2](https://github.com/rpy2/rpy2) | [GNU General Public License v2.0](https://github.com/rpy2/rpy2/blob/master/LICENSE) | Dependency |
+
 ### Development dependency

 | Name | License | Type |

causalAssembly/drf_fitting.py

Lines changed: 34 additions & 24 deletions
@@ -19,17 +19,26 @@
 import numpy as np
 import pandas as pd
 import rpy2.robjects as ro
-import rpy2.robjects.numpy2ri
-from rpy2.robjects import pandas2ri
+import rpy2.robjects.packages as rpackages
+from rpy2.robjects import numpy2ri, pandas2ri
+from rpy2.robjects.conversion import localconverter
 from rpy2.robjects.packages import importr
 from scipy.stats import gaussian_kde

 from causalAssembly.dag import DAG
 from causalAssembly.models_dag import ProcessCell, ProductionLineGraph

-rpy2.robjects.numpy2ri.activate()
-pandas2ri.activate()
+# Converter for numpy + pandas instead of using deprecated activate()
+R_CONVERTER = ro.default_converter + numpy2ri.converter + pandas2ri.converter
+
 base_r_package = importr("base")
+utils = importr("utils")
+
+if not rpackages.isinstalled("drf"):
+# select a mirror for R packages
+utils.chooseCRANmirror(ind=1)
+utils.install_packages("drf", repos="https://cloud.r-project.org/")
+
 drf_r_package = importr("drf")


@@ -41,22 +50,24 @@ class DRF:
 """

 def __init__(self, **fit_params):
-"""Initialize DRF object."""
+"""Initialize the DRF object with fit parameters."""
 self.fit_params = fit_params

 def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series):
 """Fit DRF in order to estimate conditional distribution P(Y|X=x).

 Args:
-X (pd.DataFrame): Conditioning set.
-Y (pd.DataFrame): Variable of interest (can be vector-valued).
+X (pd.DataFrame): Predictor variables.
+Y (pd.DataFrame | pd.Series): Response variable(s).
 """
 self.X_train = X
 self.Y_train = Y

-X_r = ro.conversion.py2rpy(X)
-Y_r = ro.conversion.py2rpy(Y)
-self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)
+# Use localconverter
+with localconverter(R_CONVERTER):
+X_r = ro.conversion.py2rpy(X)
+Y_r = ro.conversion.py2rpy(Y)
+self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)

 def produce_sample(
 self,
@@ -72,15 +83,18 @@ def produce_sample(
 n (int, optional): Number of n-samples to draw. Defaults to 1.

 Returns:
-np.ndarray: New predicted samlpe of Y.
+np.ndarray: New predicted sample of Y.
 """
-newdata_r = ro.conversion.py2rpy(newdata)
-r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)
+with localconverter(R_CONVERTER):
+newdata_r = ro.conversion.py2rpy(newdata)
+r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)

-weights = base_r_package.as_matrix(r_output[0])
+# Convert back to Python
+weights = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[0]))
+Y = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[1]))

-Y = pd.DataFrame(base_r_package.as_matrix(r_output[1]))
-Y = Y.apply(pd.Series)
+if not isinstance(Y, pd.DataFrame):
+Y = pd.DataFrame(Y)

 sample = np.zeros((newdata.shape[0], Y.shape[1], n))
 for i in range(newdata.shape[0]):
@@ -98,17 +112,14 @@ def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
 graph (ProductionLineGraph | ProcessCell | DAG): Graph to fit the DRF to.
 data (pd.DataFrame): Columns of dataframe need to match name and order of the graph

-Raises:
-ValueError: Raises error if columns don't meet this requirement
+Raises: ValueError: Raises error if columns don't meet this requirement

-Returns:
-(dict): dict of fitted DRFs.
+Returns: (dict): dict of fitted DRFs.
 """
 tempdata = data.copy()

 if set(graph.nodes).issubset(tempdata.columns):
 tempdata = tempdata[graph.nodes]
-
 else:
 raise ValueError("Data columns don't match node names.")

@@ -118,9 +129,8 @@ def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
 if not parents:
 drf_dict[node] = gaussian_kde(tempdata[node].to_numpy())
 elif parents:
-drf_object = DRF(
-min_node_size=15, num_trees=2000, splitting_rule="FourierMMD"
-) # default setting as suggested in the paper
+# default setting as suggested in the paper
+drf_object = DRF(min_node_size=15, num_trees=2000, splitting_rule="FourierMMD")
 X = tempdata[parents]
 Y = tempdata[node]
 drf_object.fit(X, Y)