Skip to content

Commit ce066de

Browse files
authored
Merge pull request #6 from boschresearch/feature/5-anomalous-dataset
Feature/5 anomalous dataset
2 parents aec9af9 + a4d5588 commit ce066de

4 files changed

Lines changed: 244 additions & 4 deletions

File tree

README.md

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
[![Code style: ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/format.json)](https://github.com/astral-sh/ruff)
55

66
This repo provides details regarding $\texttt{causalAssembly}$, a causal discovery benchmark data tool based on complex production data.
7-
Theoretical details and information regarding construction are presented in the paper:
7+
Theoretical details and information regarding construction are presented in the [paper](https://arxiv.org/abs/2306.10816):
88

9-
Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. (2023). causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery. arXiv preprint arXiv:2306.10816.
9+
Göbler, K., Windisch, T., Pychynski, T., Sonntag, S., Roth, M., & Drton, M. causalAssembly: Generating Realistic Production Data for Benchmarking Causal Discovery, to appear in Proceedings of the 3rd Conference on Causal Learning and Reasoning (CLeaR), 2024,
1010
## Authors
1111
* [Konstantin Goebler](mailto:konstantin.goebler@de.bosch.com)
1212
* [Steffen Sonntag](mailto:steffen.sonntag@de.bosch.com)
@@ -69,6 +69,55 @@ assembly_line.Station3.drf = fit_drf(assembly_line.Station3, data=assembly_line_
6969
station3_sample = assembly_line.Station3.sample_from_drf(size=n_select)
7070

7171
```
72+
### <a name="Interventional data">Interventional data</a>
73+
In case you want to create interventional data, we currently support hard and soft interventions.
74+
For soft interventions we use `sympy`'s `RandomSymbol` class. Essentially, soft interventions should
75+
be declared by choosing your preferred random variable with associated distribution from [here](https://docs.sympy.org/latest/modules/stats.html#continuous-types). Simple examples include:
76+
77+
```python
78+
from sympy.stats import Beta, Normal, Uniform
79+
80+
x = Beta("x", 1, 1)
81+
y = Normal("y", 0, 1)
82+
z = Uniform("z", 0, 1)
83+
84+
```
85+
86+
The following example is similar to the basic use example above where we now intervene on two nodes in the graph.
87+
88+
```python
89+
from sympy.stats import Beta
90+
91+
from causalAssembly.drf_fitting import fit_drf
92+
from causalAssembly.models_dag import ProductionLineGraph
93+
94+
seed = 2023
95+
n_select = 500
96+
97+
assembly_line_data = ProductionLineGraph.get_data()
98+
99+
# take subsample for demonstration purposes
100+
assembly_line_data = assembly_line_data.sample(n_select, random_state=seed, replace=False)
101+
102+
# load in ground truth
103+
assembly_line = ProductionLineGraph.get_ground_truth()
104+
105+
# fit drf and sample for entire line
106+
assembly_line.drf = fit_drf(assembly_line, data=assembly_line_data)
107+
108+
# intervene on two nodes in the assembly line
109+
assembly_line.intervene_on(
110+
nodes_values={"Station3_mp_41": 2, "Station4_mp_58": Beta("noise", 1, 1)}
111+
)
112+
113+
# sample from the corresponding interventional distribution
114+
my_int_df = assembly_line.sample_from_interventional_drf(size=5)
115+
116+
print(my_int_df[["Station3_mp_41", "Station4_mp_58"]])
117+
118+
```
119+
120+
Note that intervening does not alter any of the functionalities introduced above. The interevened upon DAGs are stored in `mutilated_dags`. When calling `sample_from_drf()` the ground truth DAG as described in the paper is used. To sample from the interventional distribution, you must use `sample_from_interventional_drf`.
72121

73122
### <a name="how-to-semisynthesize">How to semisynthesize</a>
74123

VERSION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.1.2
1+
1.1.3

causalAssembly/models_dag.py

Lines changed: 186 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030
from matplotlib.patches import BoxStyle, FancyBboxPatch
3131
from networkx.readwrite import json_graph
3232
from scipy.stats import gaussian_kde
33+
from sympy.stats import sample as sympy_sample
34+
from sympy.stats.rv import RandomSymbol
3335

3436
from causalAssembly.dag_utils import _bootstrap_sample, tuples_from_cartesian_product
3537
from causalAssembly.pdag import PDAG, dag2cpdag
@@ -79,6 +81,78 @@ def _sample_from_drf(
7981
return new_df[prod_object.nodes]
8082

8183

84+
def _interventional_sample_from_drf(
85+
prod_object: ProductionLineGraph,
86+
which_intervention: int | str = 0,
87+
size=10,
88+
smoothed: bool = True,
89+
) -> pd.DataFrame:
90+
if not prod_object.drf:
91+
raise ValueError("Nothing to sample from. Learn DRF first!")
92+
93+
if not prod_object.mutilated_dags:
94+
raise ValueError("No mutilated DAGs available. Please intervene first.")
95+
96+
if not prod_object.interventional_drf:
97+
raise ValueError(
98+
"No intervention values available. \
99+
Please verify your hard/soft interventions."
100+
)
101+
102+
if isinstance(which_intervention, int):
103+
intervention_replace_dict = list(prod_object.interventional_drf.values())[
104+
which_intervention
105+
]
106+
elif isinstance(which_intervention, str):
107+
intervention_replace_dict = prod_object.interventional_drf[which_intervention]
108+
109+
else:
110+
raise ValueError("Please specify which intervention you want to sample from.")
111+
112+
# for node, value in intervention_replace_dict.items():
113+
# prod_object.drf[node] = value
114+
115+
sample_dict = {}
116+
for node in prod_object.causal_order:
117+
if node in intervention_replace_dict:
118+
if isinstance(intervention_replace_dict[node], int):
119+
sample_dict[node] = np.repeat(a=intervention_replace_dict[node], repeats=size)
120+
121+
elif isinstance(intervention_replace_dict[node], RandomSymbol):
122+
sample_dict[node] = sympy_sample(
123+
expr=intervention_replace_dict[node], size=size, seed=prod_object.random_state
124+
)
125+
else:
126+
raise NotImplementedError(
127+
"Currently only hard and soft interventions are implemented"
128+
)
129+
continue
130+
131+
if isinstance(prod_object.drf[node], gaussian_kde):
132+
# Node has no parents, generate a sample using bootstrapping
133+
#
134+
if smoothed:
135+
sample_dict[node] = prod_object.drf[node].resample(
136+
size=size, seed=prod_object.random_state
137+
)[0]
138+
else:
139+
sample_dict[node] = _bootstrap_sample(
140+
rng=prod_object.random_state,
141+
data=prod_object.drf[node].dataset[0],
142+
size=size,
143+
)
144+
else:
145+
parents = prod_object.parents(of_node=node)
146+
new_data = pd.DataFrame({col: sample_dict[col] for col in parents})
147+
# new_data = pd.DataFrame(sample_dict[parents])
148+
forest = prod_object.drf[node]
149+
sample_dict[node] = forest.produce_sample(
150+
newdata=new_data, random_state=prod_object.random_state
151+
)
152+
new_df = pd.DataFrame(sample_dict)
153+
return new_df[prod_object.nodes]
154+
155+
82156
class ProcessCell:
83157
"""
84158
Representation of a single Production Line Cell
@@ -288,6 +362,19 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
288362
"""
289363
return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
290364

365+
def interventional_sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
366+
"""Draw from the trained DRF.
367+
368+
Args:
369+
size (int, optional): Number of samples to be drawn. Defaults to 10.
370+
smoothed (bool, optional): If set to true, marginal distributions will
371+
be sampled from smoothed bootstraps. Defaults to True.
372+
373+
Returns:
374+
pd.DataFrame: Data frame that follows the distribution implied by the ground truth.
375+
"""
376+
return _interventional_sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
377+
291378
def _generate_random_dag(self, n_nodes: int = 5, p: float = 0.1) -> nx.DiGraph:
292379
"""
293380
Creates a random DAG by
@@ -711,6 +798,8 @@ def __init__(self):
711798
self.cell_connector_edges = list()
712799
self.cell_order = list()
713800
self.drf: dict = dict()
801+
self.interventional_drf: dict = dict()
802+
self.__init_mutilated_dag()
714803

715804
@property
716805
def random_state(self):
@@ -722,6 +811,9 @@ def random_state(self, r: np.random.Generator):
722811
raise AssertionError("Specify numpy random number generator object!")
723812
self._random_state = r
724813

814+
def __init_mutilated_dag(self):
815+
self.mutilated_dags = dict()
816+
725817
@property
726818
def graph(self) -> nx.DiGraph:
727819
"""
@@ -1002,6 +1094,78 @@ def connect_across_cells_manually(self, edges: list[tuple]):
10021094
"""
10031095
self.cell_connector_edges.extend(edges)
10041096

1097+
def intervene_on(self, nodes_values: dict[str, RandomSymbol | float]):
1098+
"""Specify hard or soft intervention. If you want to intervene
1099+
upon more than one node provide a list of nodes to intervene on
1100+
and a list of corresponding values to set these nodes to.
1101+
(see example). The mutilated dag will automatically be
1102+
stored in `mutiliated_dags`.
1103+
1104+
Args:
1105+
nodes_values (dict[str, RandomSymbol | float]): either single real
1106+
number or sympy.stats.RandomSymbol. If you like to intervene on
1107+
more than one node, just provide more key-value pairs.
1108+
1109+
Raises:
1110+
AssertionError: If node(s) are not in the graph
1111+
"""
1112+
if not self.drf:
1113+
raise AssertionError("You need to train a drf first.")
1114+
drf_replace = {}
1115+
1116+
if not set(nodes_values.keys()).issubset(set(self.nodes)):
1117+
raise AssertionError(
1118+
"One or more nodes you want to intervene upon are not in the graph."
1119+
)
1120+
1121+
mutilated_dag = self.graph.copy()
1122+
1123+
for node, value in nodes_values.items():
1124+
old_incoming = self.parents(of_node=node)
1125+
edges_to_remove = [(old, node) for old in old_incoming]
1126+
mutilated_dag.remove_edges_from(edges_to_remove)
1127+
drf_replace[node] = value
1128+
1129+
self.mutilated_dags[
1130+
f"do({list(nodes_values.keys())})"
1131+
] = mutilated_dag # specifiying the same set twice will override
1132+
1133+
self.interventional_drf[f"do({list(nodes_values.keys())})"] = drf_replace
1134+
1135+
@property
1136+
def interventions(self) -> list:
1137+
"""Returns all interventions performed on the original graph
1138+
1139+
Returns:
1140+
list: list of intervened upon nodes in do(x) notation.
1141+
"""
1142+
return list(self.mutilated_dags.keys())
1143+
1144+
def interventional_amat(self, which_intervention: int | str) -> pd.DataFrame:
1145+
"""Returns the adjacency matrix of a chosen mutilated DAG.
1146+
1147+
Args:
1148+
which_intervention (int | str): Integer count of your chosen intervention or
1149+
literal string.
1150+
1151+
Raises:
1152+
ValueError: "The intervention you provide does not exist."
1153+
1154+
Returns:
1155+
pd.DataFrame: Adjacency matrix.
1156+
"""
1157+
if isinstance(which_intervention, str) and which_intervention not in self.interventions:
1158+
raise ValueError("The intervention you provide does not exist.")
1159+
1160+
if isinstance(which_intervention, int) and which_intervention > len(self.interventions):
1161+
raise ValueError("The intervention you index does not exist.")
1162+
1163+
if isinstance(which_intervention, int):
1164+
which_intervention = self.interventions[which_intervention]
1165+
1166+
mutilated_dag = self.mutilated_dags[which_intervention].copy()
1167+
return nx.to_pandas_adjacency(mutilated_dag, weight=None)
1168+
10051169
@classmethod
10061170
def get_ground_truth(cls) -> ProductionLineGraph:
10071171
"""Loads in the ground_truth as described in the paper:
@@ -1142,6 +1306,27 @@ def sample_from_drf(self, size=10, smoothed: bool = True) -> pd.DataFrame:
11421306
"""
11431307
return _sample_from_drf(prod_object=self, size=size, smoothed=smoothed)
11441308

1309+
def sample_from_interventional_drf(
1310+
self, which_intervention: str | int = 0, size=10, smoothed: bool = True
1311+
) -> pd.DataFrame:
1312+
"""Draw from the trained and intervened upon DRF.
1313+
1314+
Args:
1315+
size (int, optional): Number of samples to be drawn. Defaults to 10.
1316+
which_intervention (str | int): Which intervention to choose from.
1317+
Both the literal name (see the property `interventions`) and the index
1318+
are possible. Defaults to the first intervention.
1319+
smoothed (bool, optional): If set to true, marginal distributions will
1320+
be sampled from smoothed bootstraps. Defaults to True.
1321+
1322+
Returns:
1323+
pd.DataFrame: Data frame that follows the interventional distribution
1324+
implied by the ground truth.
1325+
"""
1326+
return _interventional_sample_from_drf(
1327+
prod_object=self, which_intervention=which_intervention, size=size, smoothed=smoothed
1328+
)
1329+
11451330
def hidden_nodes(self) -> list:
11461331
"""Returns list of nodes marked as hidden
11471332
@@ -1230,7 +1415,7 @@ def show(self, meta_description: list | None = None, fig_size: tuple = (15, 8)):
12301415
Raises:
12311416
AssertionError: Meta list entry needs to exist for each cell!
12321417
"""
1233-
fig, ax = plt.subplots(figsize=fig_size)
1418+
_, ax = plt.subplots(figsize=fig_size)
12341419

12351420
pos = {}
12361421

tests/test_models_dag.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121
import pandas as pd
2222
import pytest
23+
from sympy.stats import Beta
2324

2425
from causalAssembly.models_dag import NodeAttributes, ProcessCell, ProductionLineGraph
2526

@@ -603,3 +604,8 @@ def test_between_edges_adjacency_matrix(self):
603604
)
604605
assert between_amat.loc[pline.cell1.nodes, pline.cell1.nodes].sum().sum() == 0
605606
assert between_amat.loc[pline.cell2.nodes, pline.cell2.nodes].sum().sum() == 0
607+
608+
def test_interventional_drf_error(self):
609+
testline = ProductionLineGraph()
610+
with pytest.raises(ValueError):
611+
testline.sample_from_interventional_drf()

0 commit comments

Comments
 (0)