Merged
16 changes: 9 additions & 7 deletions README.md
@@ -4,10 +4,12 @@
[![Code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff)

This repo provides details regarding $\texttt{causalAssembly}$, a causal discovery benchmark data tool based on complex production data.
Theoretical details and information regarding construction are presented in the [paper](https://arxiv.org/abs/2306.10816):
Theoretical details and information regarding construction are presented in the [paper](https://proceedings.mlr.press/v236/gobler24a.html):

Göbler, K., Windisch, T., Drton, M., Pychynski, T., Roth, M., & Sonntag, S. (2024). causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery. In Proceedings of the Third Conference on Causal Learning and Reasoning (pp. 609–642). PMLR.

Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024,
## Authors

* [Konstantin Goebler](mailto:konstantin.goebler@de.bosch.com)
* [Steffen Sonntag](mailto:steffen.sonntag@de.bosch.com)

@@ -26,14 +28,13 @@ The package can be installed as follows

pip install causalAssembly

[comment]: <> (git+https://github.com/boschresearch/causalAssembly.git)

## <a name="using">How to use</a>

This is how $\texttt{causalAssembly}$'s functionality may be used. Be sure to read the [documentation](https://boschresearch.github.io/causalAssembly/) for more in-depth details regarding available functions and classes.

In case you want to train a distributional random forest yourself (see [how to semisynthesize](#how-to-semisynthesize)),
you need an R installation as well as the corresponding [drf](https://cran.r-project.org/web/packages/drf/index.html) R package.
you need an R installation.
Sampling was first proposed in [[2]](#2).

*Note*: For Windows users the python package [rpy2](https://github.com/rpy2/rpy2) might cause issues.
@@ -69,7 +70,9 @@ assembly_line.Station3.drf = fit_drf(assembly_line.Station3, data=assembly_line_
station3_sample = assembly_line.Station3.sample_from_drf(size=n_select)

```

### <a name="Interventional data">Interventional data</a>

In case you want to create interventional data, we currently support hard and soft interventions.
For soft interventions we use `sympy`'s `RandomSymbol` class. Essentially, soft interventions should
be declared by choosing your preferred random variable with associated distribution from [here](https://docs.sympy.org/latest/modules/stats.html#continuous-types). Simple examples include:
@@ -158,7 +161,6 @@ if nx.is_directed_acyclic_graph(s_graph):

### <a name="how-to-rand">How to generate random production DAGs</a>


The `ProductionLineGraph` class can further be used to generate completely random DAGs that follow an assembly line logic. Consider the following example:

```python
@@ -183,12 +185,11 @@ example_line.connect_cells(forward_probs= [.1])
example_line.show()

```
### <a name="how-to-fcm">How to generate FCMs</a>

### <a name="how-to-fcm">How to generate FCMs</a>

$\texttt{causalAssembly}$ also allows creating structural causal models (SCM) or synonymously functional causal models (FCM). In particular, we employ symbolic programming to allow for a seamless interplay between readability and performance. The `FCM` class is completely general and inherits no production data logic. See the example below for construction and usage.


```python

import numpy as np
@@ -279,6 +280,7 @@ Please feel free to contact one of the authors in case you wish to contribute.
| [matplotlib](https://github.com/matplotlib/matplotlib) | [Other](https://github.com/matplotlib/matplotlib/tree/main/LICENSE) | Dependency |
| [sympy](https://github.com/sympy/sympy) | [BSD-3-Clause License](https://github.com/sympy/sympy/blob/master/LICENSE) | Dependency |
| [rpy2](https://github.com/rpy2/rpy2) | [GNU General Public License v2.0](https://github.com/rpy2/rpy2/blob/master/LICENSE) | Dependency |

### Development dependency

| Name | License | Type |
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
1.2.0
1.2.1
58 changes: 34 additions & 24 deletions causalAssembly/drf_fitting.py
@@ -19,17 +19,26 @@
import numpy as np
import pandas as pd
import rpy2.robjects as ro
import rpy2.robjects.numpy2ri
from rpy2.robjects import pandas2ri
import rpy2.robjects.packages as rpackages
from rpy2.robjects import numpy2ri, pandas2ri
from rpy2.robjects.conversion import localconverter
from rpy2.robjects.packages import importr
from scipy.stats import gaussian_kde

from causalAssembly.dag import DAG
from causalAssembly.models_dag import ProcessCell, ProductionLineGraph

rpy2.robjects.numpy2ri.activate()
pandas2ri.activate()
# Converter for numpy + pandas instead of using deprecated activate()
R_CONVERTER = ro.default_converter + numpy2ri.converter + pandas2ri.converter

base_r_package = importr("base")
utils = importr("utils")

if not rpackages.isinstalled("drf"):
# select a mirror for R packages
utils.chooseCRANmirror(graphics=False)
utils.install_packages("drf", repos="https://cloud.r-project.org/")

drf_r_package = importr("drf")


@@ -41,22 +50,24 @@ class DRF:
"""

def __init__(self, **fit_params):
"""Initialize DRF object."""
"""Initialize the DRF object with fit parameters."""
self.fit_params = fit_params

def fit(self, X: pd.DataFrame, Y: pd.DataFrame | pd.Series):
"""Fit DRF in order to estimate conditional distribution P(Y|X=x).

Args:
X (pd.DataFrame): Conditioning set.
Y (pd.DataFrame): Variable of interest (can be vector-valued).
X (pd.DataFrame): Predictor variables.
Y (pd.DataFrame | pd.Series): Response variable(s).
"""
self.X_train = X
self.Y_train = Y

X_r = ro.conversion.py2rpy(X)
Y_r = ro.conversion.py2rpy(Y)
self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)
# Use localconverter
with localconverter(R_CONVERTER):
X_r = ro.conversion.py2rpy(X)
Y_r = ro.conversion.py2rpy(Y)
self.r_fit_object = drf_r_package.drf(X_r, Y_r, **self.fit_params)

def produce_sample(
self,
@@ -72,15 +83,18 @@ def produce_sample(
n (int, optional): Number of n-samples to draw. Defaults to 1.

Returns:
np.ndarray: New predicted samlpe of Y.
np.ndarray: New predicted sample of Y.
Copilot AI Aug 19, 2025

There's a spelling error in the comment. 'samlpe' should be 'sample'. This was correctly fixed in the code.
"""
newdata_r = ro.conversion.py2rpy(newdata)
r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)
with localconverter(R_CONVERTER):
newdata_r = ro.conversion.py2rpy(newdata)
r_output = drf_r_package.predict_drf(self.r_fit_object, newdata_r)

weights = base_r_package.as_matrix(r_output[0])
# Convert back to Python
weights = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[0]))
Y = ro.conversion.rpy2py(base_r_package.as_matrix(r_output[1]))

Y = pd.DataFrame(base_r_package.as_matrix(r_output[1]))
Y = Y.apply(pd.Series)
if not isinstance(Y, pd.DataFrame):
Y = pd.DataFrame(Y)

sample = np.zeros((newdata.shape[0], Y.shape[1], n))
for i in range(newdata.shape[0]):
@@ -98,17 +112,14 @@ def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
graph (ProductionLineGraph | ProcessCell | DAG): Graph to fit the DRF to.
data (pd.DataFrame): Columns of dataframe need to match name and order of the graph

Raises:
ValueError: Raises error if columns don't meet this requirement
Raises: ValueError: Raises error if columns don't meet this requirement

Returns:
(dict): dict of fitted DRFs.
Returns: (dict): dict of fitted DRFs.
"""
tempdata = data.copy()

if set(graph.nodes).issubset(tempdata.columns):
tempdata = tempdata[graph.nodes]

else:
raise ValueError("Data columns don't match node names.")
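
Note on the column check above: the docstring asks for columns that match the graph in name and order, while the visible code accepts any superset of the node names and then reorders by indexing with `graph.nodes`. A minimal standalone sketch of that behaviour follows. The helper name `select_node_columns` and the extra "missing" detail in the error message are illustrative assumptions and are not part of this PR.

```python
import pandas as pd


def select_node_columns(nodes, data: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of `data` restricted to `nodes`, in node order.

    Hypothetical helper mirroring the subset check shown in the diff.
    """
    nodes = list(nodes)
    missing = sorted(set(nodes) - set(data.columns))
    if missing:
        # Same error type as in the diff; the list of missing names is an addition.
        raise ValueError(f"Data columns don't match node names. Missing: {missing}")
    # Indexing with the node list both drops extra columns and fixes the order.
    return data[nodes].copy()


df = pd.DataFrame({"b": [1, 2], "a": [3, 4], "extra": [5, 6]})
ordered = select_node_columns(["a", "b"], df)
```

Here `ordered` holds only columns `a` and `b`, in that order, and `df` itself is left unchanged.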

@@ -118,9 +129,8 @@ def fit_drf(graph: ProductionLineGraph | ProcessCell | DAG, data: pd.DataFrame):
if not parents:
drf_dict[node] = gaussian_kde(tempdata[node].to_numpy())
elif parents:
drf_object = DRF(
min_node_size=15, num_trees=2000, splitting_rule="FourierMMD"
) # default setting as suggested in the paper
# default setting as suggested in the paper
drf_object = DRF(min_node_size=15, num_trees=2000, splitting_rule="FourierMMD")
X = tempdata[parents]
Y = tempdata[node]
drf_object.fit(X, Y)
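
Note on `produce_sample`: the body of the sampling loop is not visible in this diff, so the following is only a sketch of one common way to turn a weight matrix into samples. It assumes each row of `weights` is a non-negative weighting over the training rows and resamples training responses with those probabilities. The function name `sample_from_weights` and its `seed` parameter are hypothetical. The output shape `(rows, response_dim, n)` follows the `np.zeros(...)` call shown in the hunk.

```python
import numpy as np


def sample_from_weights(weights, y_train, n=1, seed=None):
    """Resample training responses according to per-row weights.

    weights: array of shape (m, n_train), one weight vector per query row.
    y_train: array of shape (n_train,) or (n_train, d).
    Returns an array of shape (m, d, n).
    """
    rng = np.random.default_rng(seed)
    weights = np.asarray(weights, dtype=float)
    y_train = np.asarray(y_train, dtype=float)
    if y_train.ndim == 1:
        y_train = y_train[:, None]  # treat a Series-like response as one column

    m, n_train = weights.shape
    sample = np.zeros((m, y_train.shape[1], n))
    for i in range(m):
        p = weights[i] / weights[i].sum()  # normalise to a probability vector
        idx = rng.choice(n_train, size=n, replace=True, p=p)
        sample[i] = y_train[idx].T  # (n, d) -> (d, n)
    return sample


# Degenerate weights make the outcome deterministic, which is handy for checking shapes.
demo = sample_from_weights([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]], [10.0, 20.0, 30.0], n=4, seed=0)
```

With these one-hot weights, every draw for the first query row is the first training response and every draw for the second row is the third one.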