Merged
5 changes: 4 additions & 1 deletion docs/tutorial/missingness_lingam.rst
@@ -35,7 +35,7 @@ where :math:`i\in\{1,\dots,n\}\mapsto k(i)` denotes a causal order, and the non-
The induced subgraph :math:`G[V_o \cup V_m \cup R]` follows a LiM model. The missingness mechanisms :math:`R_i \in R` follow a logistic model as for binary variables in LiM [4]_:

.. math::
x_i = \mathbf 1\llbracket\sum_{k(j)<k(i)} b_{ij} x_j + e_i > 0\rrbracket, \qquad e_i \sim \text{Logistic}(0,1)
x_i = \mathbf 1[\sum_{k(j)<k(i)} b_{ij} x_j + e_i > 0], \qquad e_i \sim \text{Logistic}(0,1)
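The thresholded form above is equivalent to a logistic regression: since the noise is standard logistic, the indicator equals 1 with probability :math:`\sigma(\sum_j b_{ij} x_j)`. The following is a minimal sketch of that equivalence using NumPy only. The function name ``sample_missingness_indicator``, the coefficients, and the Laplace-distributed parents are illustrative assumptions and are not part of the ``lingam`` package.

```python
import numpy as np


def sample_missingness_indicator(parents, coefs, rng):
    """Draw R = 1[parents @ coefs + e > 0] with e ~ Logistic(0, 1).

    Hypothetical helper for illustration; not a lingam function.
    """
    noise = rng.logistic(loc=0.0, scale=1.0, size=parents.shape[0])
    return (parents @ coefs + noise > 0).astype(int)


def sigmoid(z):
    """Standard logistic CDF."""
    return 1.0 / (1.0 + np.exp(-z))


rng = np.random.default_rng(0)
parents = rng.laplace(size=(200_000, 2))  # assumed non-Gaussian parent values
coefs = np.array([0.8, -0.5])             # assumed coefficients b_ij

r = sample_missingness_indicator(parents, coefs, rng)

# The empirical rate of R = 1 should be close to the mean logistic probability.
empirical_rate = r.mean()
expected_rate = sigmoid(parents @ coefs).mean()
print(empirical_rate, expected_rate)
```

The two printed values should agree up to sampling noise, which is why the coefficients of a missingness mechanism can be estimated with an ordinary logistic regression on its parents.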


Assumptions
@@ -50,6 +50,9 @@ The following assumptions are made to ensure identifiability:
Note that even if self-masking is not allowed, indirect self-masking is: a partially observed variable can be an indirect cause (an ancestor) of its own missingness mechanism.
Under these assumptions, m-LiNGAM guarantees identifiability of both the causal structure and parameters from observational data in the large-sample limit.

Example
^^^^^^^^^^^^^^^^^^

An example Python notebook demonstrating m-LiNGAM is available `here <https://github.com/cdt15/lingam/blob/master/examples/MissingnessLiNGAM.ipynb>`__.
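
As a rough, self-contained sketch of the setting the notebook studies, the snippet below generates data from the same four-variable linear model, masks ``Y`` through a logistic mechanism that depends on ``W``, and applies list-wise deletion. It uses only NumPy and pandas and does not call ``lingam``. The convention that ``Ry == True`` marks a missing value is an assumption, since the masking step is not shown in this diff.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(100)
n = 50_000

# Linear non-Gaussian model X -> Z -> Y -> W, plus X -> W (coefficients from the notebook).
X = rng.laplace(size=n, scale=1.0)
Z = 1.1 * X + rng.laplace(size=n, scale=1.0)
Y = 0.9 * Z + rng.laplace(size=n, scale=1.0)
W = 0.8 * X + 1.2 * Y + rng.laplace(size=n, scale=1.0)

# Missingness mechanism of Y: logistic model with W as its only parent.
Ry = (1.0 * W - 2 + rng.logistic(size=n, scale=1.0)) > 0

df = pd.DataFrame({"X": X, "Z": Z, "Y": Y, "W": W})
df.loc[Ry, "Y"] = np.nan  # assumption: Ry == True means Y is unobserved

# Fraction of rows with at least one missing value.
frac_missing = df.isna().any(axis=1).mean()

# List-wise deletion keeps only the fully observed rows.
complete = df.dropna()
print(frac_missing, complete.shape)
```

Because the retained rows are selected through ``W``, which is a descendant of ``Y``, estimates computed on ``complete`` alone can be biased. This is the situation in which the notebook compares list-wise deletion with m-LiNGAM.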

References
97 changes: 44 additions & 53 deletions examples/MissingnessLiNGAM.ipynb
@@ -1,44 +1,35 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "559e060a",
"metadata": {},
"source": [
"In this example, we need to import `numpy`, `pandas`, `graphviz`, and `IPython` in addition to `lingam`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 6,
"id": "cad9ee87",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['1.26.4', '2.2.3', '0.20.3', '8.2.0', '1.10.0']\n"
"['1.26.1', '2.1.2', '0.20.1', '1.12.2']\n"
]
}
],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import graphviz\n",
"import IPython\n",
"from IPython.display import display, HTML\n",
"import lingam\n",
"from IPython.display import display, HTML\n",
"\n",
"print([np.__version__, pd.__version__, graphviz.__version__, IPython.__version__, lingam.__version__])\n",
"print([np.__version__, pd.__version__, graphviz.__version__, lingam.__version__])\n",
"\n",
"np.set_printoptions(precision=3, suppress=True)\n",
"np.random.seed(100)"
]
},
{
"cell_type": "markdown",
"id": "dfd35512",
"id": "f841fa9f",
"metadata": {},
"source": [
"# Missingness-LiNGAM (m-LiNGAM)\n",
@@ -47,37 +38,38 @@
},
{
"cell_type": "markdown",
"id": "52cdf6e2",
"id": "9d1feb55",
"metadata": {},
"source": [
"First, we initialize a test grah and generate a dataset affected by missingness."
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 17,
"id": "0b70fb17",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Percentage of rows affected by missing data: 0.3558\n"
"Percentage of rows affected by missing data: 0.30718\n"
]
}
],
"source": [
"# Initialize the parameters\n",
"sample_size=5000\n",
"#Ex. 2 of MV-PC\n",
"sample_size=50000\n",
"logistic_scale=1.0\n",
"B_true = np.array([[0,0,0,0],[1.1,0,0,0],[0,0.9,0,0],[0.8,0,1.2,0]])\n",
"\n",
"# Generate the data\n",
"X = np.random.laplace(size=sample_size, scale=1.0)\n",
"Z = 1.1*X + np.random.laplace(size=sample_size, scale=1.0)\n",
"Y = 0.9*Z + np.random.laplace(size=sample_size, scale=1.0)\n",
"W = 0.8*X + 1.2*Y + np.random.laplace(size=sample_size, scale=1.0)\n",
"Ry = (0.5*W -1 + np.random.logistic(size=sample_size, scale=1.0))>0\n",
"Ry = (1.0*W -2 + np.random.logistic(size=sample_size, scale=logistic_scale))>0\n",
"\n",
"# Mask missing values\n",
"Ys = Y.copy()\n",
@@ -90,7 +82,7 @@
},
{
"cell_type": "markdown",
"id": "3d0fa9b7",
"id": "e86ad509",
"metadata": {},
"source": [
"Then, we run both `DirectLiNGAM` and `mLiNGAM` on the same dataset. Since DirectLiNGAM cannot handle missing values, we apply it to the dataset after removing all rows affected by missingness."
@@ -110,18 +102,10 @@
"ml.fit(ds);"
]
},
{
"cell_type": "markdown",
"id": "c94a91e2",
"metadata": {},
"source": [
"We can now compare the ground truth graph, the `mLiNGAM` output graph, and the list-wise deletion `DirectLiNGAM` graph."
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "7b6fc2e1",
"execution_count": 19,
"id": "a9172d95",
"metadata": {},
"outputs": [
{
@@ -204,7 +188,7 @@
"<title>3&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M126.42,-18C148.44,-18 179.63,-18 204.21,-18\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"204.16,-21.5 214.16,-18 204.16,-14.5 204.16,-21.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"156.69\" y=\"-21.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.5</text>\r\n",
"<text text-anchor=\"middle\" x=\"156.69\" y=\"-21.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.0</text>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n",
@@ -239,7 +223,7 @@
"<title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M42.27,-105.27C52.01,-115.01 64.81,-127.81 75.79,-138.79\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"73.14,-141.09 82.68,-145.68 78.09,-136.14 73.14,-141.09\"/>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-125.23\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.11</text>\r\n",
"<text text-anchor=\"middle\" x=\"50.41\" y=\"-125.23\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.1</text>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\">\r\n",
@@ -252,7 +236,7 @@
"<title>0&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M42.27,-74.73C52.01,-64.99 64.81,-52.19 75.79,-41.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"78.09,-43.86 82.68,-34.32 73.14,-38.91 78.09,-43.86\"/>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-44.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.82</text>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-44.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.76</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\">\r\n",
@@ -265,14 +249,14 @@
"<title>1&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M114.27,-146.73C124.01,-136.99 136.81,-124.19 147.79,-113.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"150.09,-115.86 154.68,-106.32 145.14,-110.91 150.09,-115.86\"/>\r\n",
"<text text-anchor=\"middle\" x=\"119.03\" y=\"-116.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.87</text>\r\n",
"<text text-anchor=\"middle\" x=\"119.03\" y=\"-116.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.86</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;3 -->\r\n",
"<g id=\"edge4\" class=\"edge\">\r\n",
"<title>2&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M155.73,-74.73C145.99,-64.99 133.19,-52.19 122.21,-41.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"124.86,-38.91 115.32,-34.32 119.91,-43.86 124.86,-38.91\"/>\r\n",
"<text text-anchor=\"middle\" x=\"130.34\" y=\"-61.17\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.2</text>\r\n",
"<text text-anchor=\"middle\" x=\"126.97\" y=\"-61.17\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.18</text>\r\n",
"</g>\r\n",
"<!-- 4 -->\r\n",
"<g id=\"node5\" class=\"node\">\r\n",
@@ -285,7 +269,7 @@
"<title>3&#45;&gt;4</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M126.42,-18C148.44,-18 179.63,-18 204.21,-18\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"204.16,-21.5 214.16,-18 204.16,-14.5 204.16,-21.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"153.32\" y=\"-21.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.49</text>\r\n",
"<text text-anchor=\"middle\" x=\"153.32\" y=\"-21.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.99</text>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n",
@@ -320,7 +304,20 @@
"<title>0&#45;&gt;1</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M42.27,-105.27C52.01,-115.01 64.81,-127.81 75.79,-138.79\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"73.14,-141.09 82.68,-145.68 78.09,-136.14 73.14,-141.09\"/>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-125.23\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.03</text>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-125.23\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.97</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\">\r\n",
"<title>2</title>\r\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"171\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\r\n",
"<text text-anchor=\"middle\" x=\"171\" y=\"-84.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">Y</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;2 -->\r\n",
"<g id=\"edge2\" class=\"edge\">\r\n",
"<title>0&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M54.42,-90C76.44,-90 107.63,-90 132.21,-90\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"132.16,-93.5 142.16,-90 132.16,-86.5 132.16,-93.5\"/>\r\n",
"<text text-anchor=\"middle\" x=\"79.07\" y=\"-93.2\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">&#45;0.06</text>\r\n",
"</g>\r\n",
"<!-- 3 -->\r\n",
"<g id=\"node4\" class=\"node\">\r\n",
@@ -329,31 +326,25 @@
"<text text-anchor=\"middle\" x=\"99\" y=\"-12.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">W</text>\r\n",
"</g>\r\n",
"<!-- 0&#45;&gt;3 -->\r\n",
"<g id=\"edge2\" class=\"edge\">\r\n",
"<g id=\"edge3\" class=\"edge\">\r\n",
"<title>0&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M42.27,-74.73C52.01,-64.99 64.81,-52.19 75.79,-41.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"78.09,-43.86 82.68,-34.32 73.14,-38.91 78.09,-43.86\"/>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-44.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.77</text>\r\n",
"</g>\r\n",
"<!-- 2 -->\r\n",
"<g id=\"node3\" class=\"node\">\r\n",
"<title>2</title>\r\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"171\" cy=\"-90\" rx=\"27\" ry=\"18\"/>\r\n",
"<text text-anchor=\"middle\" x=\"171\" y=\"-84.95\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">Y</text>\r\n",
"<text text-anchor=\"middle\" x=\"47.03\" y=\"-44.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.75</text>\r\n",
"</g>\r\n",
"<!-- 1&#45;&gt;2 -->\r\n",
"<g id=\"edge3\" class=\"edge\">\r\n",
"<g id=\"edge4\" class=\"edge\">\r\n",
"<title>1&#45;&gt;2</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M114.27,-146.73C124.01,-136.99 136.81,-124.19 147.79,-113.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"150.09,-115.86 154.68,-106.32 145.14,-110.91 150.09,-115.86\"/>\r\n",
"<text text-anchor=\"middle\" x=\"119.03\" y=\"-116.67\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">0.81</text>\r\n",
"</g>\r\n",
"<!-- 2&#45;&gt;3 -->\r\n",
"<g id=\"edge4\" class=\"edge\">\r\n",
"<g id=\"edge5\" class=\"edge\">\r\n",
"<title>2&#45;&gt;3</title>\r\n",
"<path fill=\"none\" stroke=\"black\" d=\"M155.73,-74.73C145.99,-64.99 133.19,-52.19 122.21,-41.21\"/>\r\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"124.86,-38.91 115.32,-34.32 119.91,-43.86 124.86,-38.91\"/>\r\n",
"<text text-anchor=\"middle\" x=\"126.97\" y=\"-61.17\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.13</text>\r\n",
"<text text-anchor=\"middle\" x=\"130.34\" y=\"-61.17\" font-family=\"Times New Roman,serif\" font-size=\"14.00\">1.1</text>\r\n",
"</g>\r\n",
"</g>\r\n",
"</svg>\r\n",
@@ -400,7 +391,7 @@
"# Add missingness mechanisms to m-LiNGAM's output\n",
"dotML.node('4', 'Ry', pos='3,0!')\n",
"dotTrue.node('4', 'Ry', pos='3,0!')\n",
"dotTrue.edge('3', '4', label=str(0.5))\n",
"dotTrue.edge('3', '4', label=str(1.0))\n",
"for k,missigness_parent in enumerate(ml._missingness_mechanisms_parents[2]):\n",
" dotML.edge(str(missigness_parent), '4', label=str(round(ml._missingness_mechanisms_coef[2][k+1], 2)))\n",
"\n",
@@ -432,7 +423,7 @@
},
{
"cell_type": "markdown",
"id": "57849ac4",
"id": "f08fd54b",
"metadata": {},
"source": [
"As expected, while naively applying `DirectLiNGAM` to the list-wise deleted dataset produce extraneous edges and biased estimates for the parameters, `mLiNGAM` is able to produce a more accurate estimate of both the graph structure and the parameters. Note that `mLiNGAM` also reconstructs the missingness mechanism that led to missingness, correctly identyfing that the data was missing at random with $W$ as parent of $R_Y$, the missingness mechanism corresponding to variable $Y$."
Expand All @@ -441,7 +432,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "base",
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
@@ -455,7 +446,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.12.0"
}
},
"nbformat": 4,