-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
124 lines (88 loc) · 3.35 KB
/
utils.py
File metadata and controls
124 lines (88 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# General imports.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from ast import literal_eval
class Utils:
    '''
    Helper routines for building commutative cut trees (CCTs), computing
    probabilities of causation from boolean-like data columns, plotting
    graphs, and converting stringified arrays back to numpy arrays.
    '''

    def get_cct(self,
                n_cutpoints: int = 1,
                names: list = None,
                plot: bool = True) -> "nx.DiGraph":
        '''
        Generates the commutative cut tree associated with a given number
        of cutpoints. Optionally, the user can input node names.

        Parameters
        ----------
        n_cutpoints : number of cutpoints; the graph has n_cutpoints + 2 nodes.
        names : optional node labels, paired in order with the default
            integer nodes. Pairing uses zip, so if fewer names than nodes
            are given only the leading nodes are relabeled.
        plot : accepted for backward compatibility; currently ignored
            (no plotting is performed by this method).

        Returns
        -------
        A networkx DiGraph with an edge i -> j for every pair i < j.
        '''
        n_nodes = n_cutpoints + 2
        # Strictly upper-triangular ones: every node points to all later nodes.
        adj_cct = np.triu(np.ones((n_nodes, n_nodes)), k = 1)
        cct = nx.from_numpy_array(adj_cct, create_using = nx.DiGraph)
        if names is not None:
            cct = nx.relabel_nodes(cct, dict(zip(cct.nodes, names)))
        return cct

    def get_total_paths_cct(self, n: int) -> int:
        '''
        Return 2**(n-2), where n = total nodes in CCT.

        Raises
        ------
        ValueError : if n < 2, since 2**(n-2) would then be a fraction
            rather than a count.
        '''
        if n < 2:
            raise ValueError(f"n must be at least 2, got {n}")
        return 2**(n-2)

    def get_prc_direct(self,
                       df: pd.DataFrame,
                       x: str,
                       y: str,
                       y_do_x0: str,
                       y_do_x1: str) -> dict:
        '''
        Compute the probabilities of causation directly from observed and
        interventional data.

        Parameters
        ----------
        df : data frame holding the four columns named below; each column
            is cast to bool before use.
        x, y : names of the observed exposure and outcome columns.
        y_do_x0, y_do_x1 : names of the outcome columns under the
            interventions do(x = 0) and do(x = 1).

        Returns
        -------
        dict with keys
            'PN'  : mean of (not y_do_x0) over rows where x and y hold,
            'PS'  : mean of y_do_x1 over rows where neither x nor y holds,
            'PNS' : mean of (y_do_x1 and not y_do_x0) over all rows.
        An empty conditioning subset yields NaN for that entry.
        '''
        # Cast only the columns that are read instead of the whole frame.
        x_obs = df[x].astype("bool")
        y_obs = df[y].astype("bool")
        y0 = df[y_do_x0].astype("bool")
        y1 = df[y_do_x1].astype("bool")

        res = dict()
        res['PN'] = np.mean(~y0[x_obs & y_obs])
        res['PS'] = np.mean(y1[~x_obs & ~y_obs])
        # Same quantity as get_pns_direct.
        res['PNS'] = np.mean(y1 & ~y0)
        return res

    def get_pns_direct(self,
                       df: pd.DataFrame,
                       y_do_x0: str,
                       y_do_x1: str) -> float:
        '''
        Compute the PNS directly from interventional data.

        Returns the fraction of rows where y_do_x1 is true and y_do_x0 is
        false (both columns are cast to bool first).
        '''
        y0 = df[y_do_x0].astype("bool")
        y1 = df[y_do_x1].astype("bool")
        return np.mean(y1 & ~y0)

    def get_ate(self,
                df: pd.DataFrame,
                y_do_x1: str,
                y_do_x0: str) -> float:
        '''
        Return mean(df[y_do_x1]) - mean(df[y_do_x0]).

        NOTE: the column arguments are ordered (y_do_x1, y_do_x0) here,
        the reverse of get_prc_direct and get_pns_direct; pass them by
        keyword to avoid mix-ups.
        '''
        return df[y_do_x1].mean() - df[y_do_x0].mean()

    def plot_nx(self,
                adjacency_matrix: np.ndarray,
                labels: list,
                figsize: tuple = (10,10),
                dpi: int = 200,
                node_size: int = 800,
                arrow_size: int = 10):
        '''
        Plot graph in networkx from adjacency matrix.

        The graph is drawn as a DiGraph in a circular layout; labels[i] is
        used as the label of node i. The figure is shown and then closed.
        '''
        g = nx.from_numpy_array(adjacency_matrix, create_using = nx.DiGraph)
        plt.figure(figsize = figsize, dpi = dpi)
        nx.draw_circular(g,
                         node_size = node_size,
                         labels = dict(enumerate(labels)),
                         arrowsize = arrow_size,
                         node_color = "pink",
                         with_labels = True)
        plt.show()
        plt.close()

    def get_rae(self,
                true: np.ndarray,
                pred: np.ndarray) -> float:
        '''
        Return the relative absolute error |true - pred| / |true|.

        Works elementwise on numpy arrays and on plain scalars. The
        denominator uses |true| so the result is non-negative even when
        the reference value is negative. A zero reference value is not
        special-cased (scalars raise ZeroDivisionError, numpy arrays
        produce inf/nan).
        '''
        return abs(true - pred) / abs(true)

    def string_to_array(self,
                        array_string: str) -> np.ndarray:
        '''
        Convert adjacency matrices back to numpy arrays when imported
        dataframes automatically cast cell contents as strings.

        Accepts numpy-style text (whitespace-separated, possibly padded
        for alignment and split over several lines) as well as text that
        is already comma-separated.
        '''
        # Treat commas like whitespace, then collapse every whitespace run
        # (including newlines and alignment padding) to a single space.
        cleaned_string = ' '.join(array_string.replace(',', ' ').split())
        # Drop padding just inside brackets so no empty element is created.
        cleaned_string = cleaned_string.replace('[ ', '[').replace(' ]', ']')
        # Each remaining space separates two elements or two rows.
        cleaned_string = cleaned_string.replace(' ', ', ')
        new_list = literal_eval(cleaned_string)
        return np.array(new_list)