diff --git a/Chapter11/Temporal_GraphML.ipynb b/Chapter11/Temporal_GraphML.ipynb new file mode 100644 index 0000000..2e5d13b --- /dev/null +++ b/Chapter11/Temporal_GraphML.ipynb @@ -0,0 +1,966 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "su-ySan88ru6" + }, + "source": [ + "# Temporal GraphML" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IJ0eLdE5lh3d" + }, + "source": [ + "In this notebook, we introduce representative machine learning approaches for temporal graphs and show how to implement them with publicly available frameworks." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UfEu4ta08v0e" + }, + "source": [ + "## Temporal Matrix Factorization\n", + "\n", + "The Temporal Matrix Factorization (TMF) model by Yu et al. (2017) is a method for temporal link prediction in dynamic networks. It combines matrix factorization with temporal dynamics to model how the links of a dynamic network evolve over time.\n", + "\n", + "We adopt the implementation provided in the publicly available library [OpenTLP](https://github.com/KuroginQin/OpenTLP). It follows an encoder-decoder architecture, where the encoder learns model parameters through matrix factorization, and the decoder generates predictions based on these parameters."
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AnL-blMhkR5a", + "outputId": "697a0d81-c5ca-43b9-dbc1-4496cc4d5ddb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'OpenTLP'...\n", + "remote: Enumerating objects: 147, done.\n", + "remote: Counting objects: 100% (32/32), done.\n", + "remote: Compressing objects: 100% (5/5), done.\n", + "remote: Total 147 (delta 27), reused 27 (delta 27), pack-reused 115 (from 1)\n", + "Receiving objects: 100% (147/147), 13.13 MiB | 19.66 MiB/s, done.\n", + "Resolving deltas: 100% (68/68), done.\n" + ] + } + ], + "source": [ + "# Download the OpenTLP repository\n", + "!git clone https://github.com/KuroginQin/OpenTLP.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sDiVZdas-d5_" + }, + "source": [ + "OpenTLP ships with several temporal graph datasets. Let's unzip them as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "j-UOMplglz0c", + "outputId": "d353dc92-d81b-439e-cdce-23ad530aa989" + }, + "outputs": [], + "source": [ + "# Unzip the sample graph\n", + "import zipfile\n", + "zipfile.ZipFile('OpenTLP/Python/data/data.zip').extractall('data')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QRriEYix-m2C" + }, + "source": [ + "For the TMF example, the Enron dataset is used as a case study.\n", + "The dataset consists of temporal snapshots of a graph with 184 nodes over 26 time points. Historical edge sequences (`edge_seq`) are loaded, representing graph snapshots at different timestamps.\n", + "\n", + "Let's now implement the TMF example. This code is adapted from the [OpenTLP examples](https://github.com/KuroginQin/OpenTLP/blob/main/Python/TMF_demo1.py).\n", + "\n", + "### 1.
Model Setup\n", + "- TMF is implemented with the following parameters:\n", + " - **Latent dimensionality of node embeddings** (`hid_dim = 64`).\n", + " - **Regularization coefficients** for model optimization (`alpha`, `beta`, and `theta`).\n", + " - **Learning rate** for gradient-based optimization.\n", + " - A **sliding window of historical snapshots** (`win_size = 5`) is used to predict the adjacency structure at the next time step.\n", + "\n", + "### 2. Training the TMF Model\n", + "- For each time step after the historical window (`win_size` to `num_snaps`):\n", + " - The model uses the last `win_size` adjacency matrices to learn a low-dimensional representation of the graph.\n", + " - The learned representation is used to predict the adjacency matrix for the current time step.\n", + "- The adjacency matrices are refined to ensure symmetry and zero diagonal elements.\n", + "\n", + "### 3. Evaluation\n", + "- The **Area Under the Curve (AUC)** score is computed to evaluate the quality of predictions against the ground truth adjacency matrix at each time step.\n", + "- The average AUC and standard deviation across all time steps are reported as metrics of model performance.\n", + "\n", + "## Results\n", + "The model iteratively predicts the graph structure for each snapshot and computes the corresponding AUC. This provides insight into the TMF's ability to generalize and learn temporal patterns from historical graph data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('OpenTLP/Python/')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GlafCysq_Nuq", + "outputId": "e3e307c5-fef0-438e-e0c9-7f25e2f8fd13" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Snapshot 5: AUC = 0.752615\n", + "Snapshot 6: AUC = 0.797639\n", + "Snapshot 7: AUC = 0.765659\n", + "Snapshot 8: AUC = 0.834835\n", + "Snapshot 9: AUC = 0.860552\n", + "Snapshot 10: AUC = 0.855638\n", + "Snapshot 11: AUC = 0.880935\n", + "Snapshot 12: AUC = 0.836431\n", + "Snapshot 13: AUC = 0.874893\n", + "Snapshot 14: AUC = 0.864066\n", + "Snapshot 15: AUC = 0.879439\n", + "Snapshot 16: AUC = 0.772748\n", + "Snapshot 17: AUC = 0.800971\n", + "Snapshot 18: AUC = 0.805514\n", + "Snapshot 19: AUC = 0.760164\n", + "Snapshot 20: AUC = 0.806832\n", + "Snapshot 21: AUC = 0.805647\n", + "Snapshot 22: AUC = 0.857036\n", + "Snapshot 23: AUC = 0.923815\n", + "Snapshot 24: AUC = 0.859344\n", + "Snapshot 25: AUC = 0.856243\n", + "Mean AUC: 0.831001\n", + "Standard Deviation of AUC: 0.046278\n" + ] + } + ], + "source": [ + "# Import necessary libraries and modules\n", + "import numpy as np\n", + "import torch\n", + "from TMF.TMF import TMF # Custom TMF implementation\n", + "from utils import get_adj_un, get_AUC # Utility functions for adjacency and evaluation\n", + "\n", + "# Check if GPU is available, otherwise use CPU\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "# ====================\n", + "# Dataset and model configuration\n", + "data_name = 'Enron' # Name of the dataset\n", + "num_nodes = 184 # Total number of nodes in the graph\n", + "num_snaps = 26 # Total number of snapshots (time points)\n", + "hid_dim = 64 # Dimensionality of the latent space\n", + "theta = 0.1 # 
Regularization parameter for model training\n", + "alpha = 0.01 # TMF-specific hyperparameter\n", + "beta = 0.01 # TMF-specific hyperparameter\n", + "\n", + "# Load edge sequences from dataset\n", + "edge_seq = np.load(f'data/{data_name}_edge_seq.npy', allow_pickle=True)\n", + "\n", + "# ====================\n", + "# Training hyperparameters\n", + "learn_rate = 1e-3 # Learning rate for the optimizer\n", + "win_size = 5 # Size of the historical window for snapshots\n", + "num_epochs = 200 # Number of training epochs\n", + "\n", + "# ====================\n", + "# Initialize a list to store AUC scores for each snapshot\n", + "AUC_list = []\n", + "\n", + "# Iterate through snapshots, starting after the initial window size\n", + "for tau in range(win_size, num_snaps):\n", + " # Ground truth edges for the current snapshot\n", + " edges = edge_seq[tau]\n", + " gnd = get_adj_un(edges, num_nodes) # Generate ground truth adjacency matrix\n", + "\n", + " # Collect adjacency matrices for historical snapshots within the window\n", + " adj_list = []\n", + " for t in range(tau - win_size, tau):\n", + " edges = edge_seq[t]\n", + " adj = get_adj_un(edges, num_nodes)\n", + " adj_tnr = torch.FloatTensor(adj).to(device)\n", + " adj_list.append(adj_tnr)\n", + "\n", + " # Initialize and train the TMF model\n", + " TMF_model = TMF(num_nodes,\n", + " hid_dim,\n", + " win_size,\n", + " num_epochs,\n", + " alpha,\n", + " beta,\n", + " theta,\n", + " learn_rate,\n", + " device)\n", + "\n", + " adj_est = TMF_model.TMF_fun(adj_list) # Predict adjacency matrix for the current snapshot\n", + "\n", + " # Convert predicted adjacency matrix to NumPy array if necessary\n", + " adj_est = adj_est.cpu().data.numpy() if torch.cuda.is_available() else adj_est.data.numpy()\n", + "\n", + " # Refine the predicted adjacency matrix\n", + " adj_est = (adj_est + adj_est.T) / 2 # Ensure symmetry\n", + " np.fill_diagonal(adj_est, 0) # Set diagonal elements to 0 (no self-loops)\n", + "\n", + " # Evaluate 
prediction quality using AUC metric\n", + " AUC = get_AUC(adj_est, gnd, num_nodes)\n", + " AUC_list.append(AUC)\n", + " print(f'Snapshot {tau}: AUC = {AUC:.6f}')\n", + "\n", + "# ====================\n", + "# Compute mean and standard deviation of AUC scores\n", + "AUC_mean = np.mean(AUC_list)\n", + "AUC_std = np.std(AUC_list, ddof=1)\n", + "\n", + "# Display overall results\n", + "print(f'Mean AUC: {AUC_mean:.6f}')\n", + "print(f'Standard Deviation of AUC: {AUC_std:.6f}')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "uKW8CUdMrYZX" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p37_4aR1vzrw" + }, + "source": [ + "## Temporal Random Walk\n", + "We present here a temporal random walk-based method called CTDNE, by Nguyen et al. (2018), for learning time-preserving embeddings.\n", + "\n", + "This code is adapted from [StellarGraph](https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/ctdne-link-prediction.ipynb#scrollTo=I2Vw-NfmeMU5)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4kC5hO9NJKKb" + }, + "source": [ + "In this example, we will again use Enron email data, this time in the version shipped with StellarGraph, which provides a dedicated class for loading it."
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LSHjNRIAJlym", + "outputId": "043f13c8-99ab-4040-8ff9-f0e1384efc6a" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-06-23 12:11:15.352421: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n", + "2025-06-23 12:11:15.352784: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.\n", + "2025-06-23 12:11:16.775450: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n", + "2025-06-23 12:11:16.775493: W tensorflow/stream_executor/cuda/cuda_driver.cc:269] failed call to cuInit: UNKNOWN ERROR (303)\n", + "2025-06-23 12:11:16.775520: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (08c0bdca2fee): /proc/driver/nvidia/version does not exist\n", + "2025-06-23 12:11:16.775759: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n", + "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "A dataset of edges that represent emails sent from one employee to another.There are 50572 edges, and each of them contains timestamp information. Edges refer to 151 unique node IDs in total.Ryan A. Rossi and Nesreen K. 
Ahmed “The Network Data Repository with Interactive Graph Analytics and Visualization” (2015)\n" + ] + } + ], + "source": [ + "from stellargraph.datasets import IAEnronEmployees\n", + "\n", + "dataset = IAEnronEmployees()\n", + "print(dataset.description)\n", + "full_graph, edges = dataset.load()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HMntKbp3Kngp" + }, + "source": [ + "In this example, we show how random walks can be obtained from temporal graphs, and how they can be used to generate network embeddings for a link prediction task.\n", + "\n", + "A full link prediction setup splits the edges into two parts (for brevity, the graph below is built from all edges):\n", + "\n", + "* the oldest edges are used to create the graph structure\n", + "* the most recent edges are the ones to predict; this part is further split at random into training and test sets." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "_ZtuSGtfLGNJ" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from stellargraph import StellarGraph\n", + "\n", + "# Create a StellarGraph instance, using the edge timestamps as edge weights\n", + "graph = StellarGraph(\n", + " nodes=pd.DataFrame(index=full_graph.nodes()),\n", + " edges=edges,\n", + " edge_weight_column=\"time\",\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "urrcUoAOL38I" + }, + "source": [ + "It's now time to run the temporal random walk algorithm:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uaPnd-BZL-RR", + "outputId": "89bb875f-e990-4f90-b430-a602bd6444c2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of temporal random walks: 1730\n" + ] + } + ], + "source": [ + "from stellargraph.data import TemporalRandomWalk\n", + "from gensim.models import Word2Vec\n", + "\n", + "num_walks_per_node =
10\n", + "walk_length = 80\n", + "context_window_size = 10\n", + "\n", + "num_cw = len(graph.nodes()) * num_walks_per_node * (walk_length - context_window_size + 1)\n", + "\n", + "temporal_rw = TemporalRandomWalk(graph)\n", + "temporal_walks = temporal_rw.run(\n", + " num_cw=num_cw,\n", + " cw_size=context_window_size,\n", + " max_walk_length=walk_length,\n", + " walk_bias=\"exponential\",\n", + ")\n", + "\n", + "print(\"Number of temporal random walks: {}\".format(len(temporal_walks)))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "IyRwcekhNAt9" + }, + "outputs": [], + "source": [ + "embedding_size = 128\n", + "temporal_model = Word2Vec(\n", + " temporal_walks,\n", + " vector_size=embedding_size, # \"size\" in older gensim versions\n", + " window=context_window_size,\n", + " min_count=0,\n", + " sg=1,\n", + " workers=2,\n", + " epochs=1, # \"iter\" in older gensim versions\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7ubW4bxyMsKV" + }, + "source": [ + "Let's visualize the embeddings:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 657 + }, + "id": "VLeUlo9wMpm7", + "outputId": "0ab52d41-7306-480f-81a3-b41aa85218f0" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import numpy as np\n", + "np.random.seed(5)\n", + "\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.manifold import TSNE\n", + "%matplotlib inline\n", + "\n", + "def plot_tsne(title, x, y=None):\n", + " tsne = TSNE(n_components=2)\n", + " x_t = tsne.fit_transform(x)\n", + "\n", + " plt.figure(figsize=(7, 7))\n", + " plt.title(title)\n", + " alpha = 0.7 if y is None else 0.5\n", + "\n", + " scatter = plt.scatter(x_t[:, 0], x_t[:, 1], c=y, cmap=\"jet\", alpha=alpha)\n", + " if y is not None:\n", + " plt.legend(*scatter.legend_elements(), loc=\"lower left\", title=\"Classes\")\n", + "\n", + "temporal_node_embeddings = temporal_model.wv.vectors\n", + "plot_tsne(\"TSNE visualisation of temporal node embeddings\", temporal_node_embeddings);" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dXukSMRqCcct" + }, + "source": [ + "You may want to use these embeddings for downstream tasks such as link prediction! To this aim, you can split the dataset in order to create \"future\" link examples. Check [the stellargraph repo](https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/link-prediction/ctdne-link-prediction.ipynb#scrollTo=I2Vw-NfmeMU5) for a full example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_SBxlsmfUkMx" + }, + "source": [ + "## Temporal Graph Neural Network\n", + "In this example, we will explore the implementation of Temporal Graph Networks (TGN) using PyTorch Geometric (PyG). TGNs are designed to handle dynamic graphs where interactions between nodes occur at different timestamps. We'll use the Wikipedia dataset from JODIE, where nodes represent users and articles, and edges represent user-article interactions." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VhcWnVDsVAOn" + }, + "source": [ + "First, let's set up our environment and load the data.\n", + "We use the JODIE Wikipedia dataset, which contains temporal interactions between users and articles.\n", + "\n", + "The `TemporalDataLoader` is specially designed for temporal graphs. The `neg_sampling_ratio=1.0` means for each positive edge, we sample one negative edge for training." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "WylvEgEsUPEK" + }, + "outputs": [], + "source": [ + "# Setup and Data Loading\n", + "import os.path as osp\n", + "import torch\n", + "from sklearn.metrics import average_precision_score, roc_auc_score\n", + "from torch.nn import Linear\n", + "from torch_geometric.datasets import JODIEDataset\n", + "from torch_geometric.loader import TemporalDataLoader\n", + "from torch_geometric.nn.models.tgn import LastNeighborLoader\n", + "\n", + "# Device configuration\n", + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "# Load Wikipedia dataset from JODIE\n", + "path = osp.join('data', 'JODIE')\n", + "dataset = JODIEDataset(path, name='wikipedia')\n", + "data = dataset[0]\n", + "data = data.to(device) # Move data to GPU if available\n", + "\n", + "# Split dataset into train, validation, and test sets\n", + "train_data, val_data, test_data = data.train_val_test_split(\n", + " val_ratio=0.15, test_ratio=0.15)\n", + "\n", + "# Create data loaders with negative sampling\n", + "train_loader = TemporalDataLoader(\n", + " train_data,\n", + " batch_size=200,\n", + " neg_sampling_ratio=1.0,\n", + ")\n", + "\n", + "val_loader = TemporalDataLoader(\n", + " val_data,\n", + " batch_size=200,\n", + " neg_sampling_ratio=1.0,\n", + ")\n", + "\n", + "test_loader = TemporalDataLoader(\n", + " test_data,\n", + " batch_size=200,\n", + " neg_sampling_ratio=1.0,\n", + ")\n", + "\n", + "neighbor_loader = 
LastNeighborLoader(data.num_nodes, size=10, device=device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "x3JPyEy3VPOn" + }, + "source": [ + "Let's now implement the key components of TGN.\n", + "* The memory module is a key component of TGN that maintains node states over time." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "sX2UdSRFVb8Y" + }, + "outputs": [], + "source": [ + "from torch_geometric.nn import TGNMemory, TransformerConv\n", + "from torch_geometric.nn.models.tgn import (\n", + " IdentityMessage,\n", + " LastAggregator,\n", + " LastNeighborLoader,\n", + ")\n", + "\n", + "memory_dim = 100\n", + "time_dim = 100\n", + "embedding_dim = 100\n", + "\n", + "memory = TGNMemory(\n", + " data.num_nodes, # Number of nodes in the graph\n", + " data.msg.size(-1), # Message dimension\n", + " memory_dim, # Memory dimension\n", + " time_dim, # Time encoding dimension\n", + " message_module=IdentityMessage(data.msg.size(-1), memory_dim, time_dim),\n", + " aggregator_module=LastAggregator(),\n", + ").to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "b-paMXo2VqOq" + }, + "source": [ + "* Together with the `TGNMemory`, we will also create a GNN for obtaining the embeddings. In this example, we will define a `GraphAttentionEmbedding` class, which uses the `TransformerConv` module (a message passing module implemented in PyTorch Geometric)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "x_MgWFMZVusK" + }, + "outputs": [], + "source": [ + "class GraphAttentionEmbedding(torch.nn.Module):\n", + " def __init__(self, in_channels, out_channels, msg_dim, time_enc):\n", + " super().__init__()\n", + " self.time_enc = time_enc\n", + " edge_dim = msg_dim + time_enc.out_channels\n", + " self.conv = TransformerConv(in_channels, out_channels // 2, heads=2,\n", + " dropout=0.1, edge_dim=edge_dim)\n", + "\n", + " def forward(self, x, last_update, edge_index, t, msg):\n", + " # Compute relative temporal encoding\n", + " rel_t = last_update[edge_index[0]] - t\n", + " rel_t_enc = self.time_enc(rel_t.to(x.dtype))\n", + " # Concatenate temporal and message features\n", + " edge_attr = torch.cat([rel_t_enc, msg], dim=-1)\n", + " return self.conv(x, edge_index, edge_attr)\n", + "\n", + "# Create the GNN\n", + "gnn = GraphAttentionEmbedding(\n", + " in_channels=memory_dim,\n", + " out_channels=embedding_dim,\n", + " msg_dim=data.msg.size(-1),\n", + " time_enc=memory.time_enc,\n", + ").to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jbRFwBECV0My" + }, + "source": [ + "The GraphAttentionEmbedding uses a transformer-based graph convolution that:\n", + "1. Encodes temporal information using relative timestamps\n", + "2. Combines temporal encodings with edge messages\n", + "3. 
Applies multi-head attention to compute node embeddings" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K0AapIP9V7FC" + }, + "source": [ + "* Finally, let's use a simple MLP that outputs a link score (a logit) for each node pair:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "id": "NH_0oS0OV_Sx" + }, + "outputs": [], + "source": [ + "class LinkPredictor(torch.nn.Module):\n", + " def __init__(self, in_channels):\n", + " super().__init__()\n", + " self.lin_src = Linear(in_channels, in_channels)\n", + " self.lin_dst = Linear(in_channels, in_channels)\n", + " self.lin_final = Linear(in_channels, 1)\n", + "\n", + " def forward(self, z_src, z_dst):\n", + " h = self.lin_src(z_src) + self.lin_dst(z_dst)\n", + " h = h.relu()\n", + " return self.lin_final(h)\n", + "\n", + "# Create the LinkPredictor object\n", + "link_pred = LinkPredictor(in_channels=embedding_dim).to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "poGFIqhKWD02" + }, + "source": [ + "### Training\n", + "A few important points about the training:\n", + "\n", + "1. Memory and neighbor states are reset at the start of each epoch.\n", + "2. For each batch, we first compute temporal neighborhoods using `neighbor_loader`.\n", + "3. Node embeddings are computed using the current memory state and graph attention.\n", + "4. The model predicts both positive and negative links.\n", + "5. After prediction, we update the memory with the true interactions.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "id": "IkpNhogDbN-K" + }, + "outputs": [], + "source": [ + "# Define the optimizer and the loss function\n", + "optimizer = torch.optim.Adam(set(memory.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.0001)\n", + "criterion = torch.nn.BCEWithLogitsLoss()\n", + "\n", + "# Helper vector to map global node indices to local ones.\n", + "assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "id": "6ddZznkJWEgP" + }, + "outputs": [], + "source": [ + "def train():\n", + " memory.train(); gnn.train(); link_pred.train() # test() switches the modules to eval mode\n", + " memory.reset_state()\n", + " neighbor_loader.reset_state()\n", + "\n", + " total_loss = 0\n", + " for batch in train_loader:\n", + " optimizer.zero_grad()\n", + " batch = batch.to(device)\n", + "\n", + " # Get temporal neighborhood\n", + " n_id, edge_index, e_id = neighbor_loader(batch.n_id)\n", + " assoc[n_id] = torch.arange(n_id.size(0), device=device)\n", + "\n", + " # Compute node embeddings\n", + " z, last_update = memory(n_id)\n", + " z = gnn(z, last_update, edge_index, data.t[e_id].to(device),\n", + " data.msg[e_id].to(device))\n", + "\n", + " # Predict positive and negative links\n", + " pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])\n", + " neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])\n", + "\n", + " # Compute binary cross entropy loss\n", + " loss = criterion(pos_out, torch.ones_like(pos_out))\n", + " loss += criterion(neg_out, torch.zeros_like(neg_out))\n", + "\n", + " # Update memory and graph structure\n", + " memory.update_state(batch.src, batch.dst, batch.t, batch.msg)\n", + " neighbor_loader.insert(batch.src, batch.dst)\n", + "\n", + " loss.backward()\n", + "
optimizer.step()\n", + " memory.detach()\n", + " total_loss += float(loss) * batch.num_events\n", + "\n", + " return total_loss / train_data.num_events" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0mAVcbdiVxjT" + }, + "source": [ + "Let's also implement a testing function for model evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "id": "Fg0UOoVEWdkJ" + }, + "outputs": [], + "source": [ + "@torch.no_grad()\n", + "def test(loader):\n", + " # Set all modules to evaluation mode\n", + " memory.eval()\n", + " gnn.eval()\n", + " link_pred.eval()\n", + "\n", + " # Set random seed for reproducible negative sampling\n", + " torch.manual_seed(12345)\n", + "\n", + " aps, aucs = [], []\n", + " for batch in loader:\n", + " batch = batch.to(device)\n", + "\n", + " # Get temporal neighborhood for current batch\n", + " n_id, edge_index, e_id = neighbor_loader(batch.n_id)\n", + " # Create mapping from global to local node indices\n", + " assoc[n_id] = torch.arange(n_id.size(0), device=device)\n", + "\n", + " # Get node embeddings from memory\n", + " z, last_update = memory(n_id)\n", + " # Update embeddings using graph attention\n", + " z = gnn(z, last_update, edge_index, data.t[e_id].to(device),\n", + " data.msg[e_id].to(device))\n", + "\n", + " # Predict on positive and negative edges\n", + " pos_out = link_pred(z[assoc[batch.src]], z[assoc[batch.dst]])\n", + " neg_out = link_pred(z[assoc[batch.src]], z[assoc[batch.neg_dst]])\n", + "\n", + " # Combine predictions and convert to probabilities\n", + " y_pred = torch.cat([pos_out, neg_out], dim=0).sigmoid().cpu()\n", + " # Create ground truth labels (1 for positive edges, 0 for negative)\n", + " y_true = torch.cat(\n", + " [torch.ones(pos_out.size(0)),\n", + " torch.zeros(neg_out.size(0))], dim=0)\n", + "\n", + " # Calculate metrics\n", + " aps.append(average_precision_score(y_true, y_pred))\n", + " aucs.append(roc_auc_score(y_true, y_pred))\n", + "\n", + " # Update memory 
and graph with ground truth interactions\n", + " memory.update_state(batch.src, batch.dst, batch.t, batch.msg)\n", + " neighbor_loader.insert(batch.src, batch.dst)\n", + "\n", + " # Return average metrics across all batches\n", + " return float(torch.tensor(aps).mean()), float(torch.tensor(aucs).mean())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sl-cqTNYWtYH" + }, + "source": [ + "We evaluate the model using two metrics:\n", + "\n", + "* Average Precision Score (AP): Measures the precision-recall trade-off\n", + "* Area Under ROC Curve (AUC): Measures the model's ability to distinguish between classes" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TCvdFDqkXtpN" + }, + "source": [ + "Finally, let's train and test the model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9I8tl_PeXpWh", + "outputId": "591f36a7-0f65-43c8-b833-9fd2943c8782" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 01, Loss: 1.1259\n", + "Val AP: 0.8647, Val AUC: 0.8768\n", + "Test AP: 0.8350, Test AUC: 0.8548\n", + "Epoch: 02, Loss: 0.9952\n", + "Val AP: 0.8293, Val AUC: 0.8436\n", + "Test AP: 0.8149, Test AUC: 0.8225\n", + "Epoch: 03, Loss: 0.9123\n", + "Val AP: 0.8679, Val AUC: 0.8697\n", + "Test AP: 0.8468, Test AUC: 0.8471\n", + "Epoch: 04, Loss: 0.8480\n", + "Val AP: 0.8882, Val AUC: 0.8845\n", + "Test AP: 0.8682, Test AUC: 0.8646\n", + "Epoch: 05, Loss: 0.7954\n", + "Val AP: 0.9040, Val AUC: 0.8983\n", + "Test AP: 0.8839, Test AUC: 0.8776\n", + "Epoch: 06, Loss: 0.7564\n", + "Val AP: 0.9130, Val AUC: 0.9057\n", + "Test AP: 0.8935, Test AUC: 0.8834\n", + "Epoch: 07, Loss: 0.7330\n", + "Val AP: 0.9173, Val AUC: 0.9098\n", + "Test AP: 0.9002, Test AUC: 0.8911\n" + ] + } + ], + "source": [ + "# Training and evaluation loop\n", + "for epoch in range(1, 51):\n", + " loss = train()\n", + " print(f'Epoch: 
{epoch:02d}, Loss: {loss:.4f}')\n", + "\n", + " # Evaluate on validation and test sets\n", + " val_ap, val_auc = test(val_loader)\n", + " test_ap, test_auc = test(test_loader)\n", + "\n", + " print(f'Val AP: {val_ap:.4f}, Val AUC: {val_auc:.4f}')\n", + " print(f'Test AP: {test_ap:.4f}, Test AUC: {test_auc:.4f}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xyT5TyO7XwAl" + }, + "source": [ + "The model reaches around 93.5% on the evaluation metrics for the Wikipedia dataset (the output above covers only the first seven epochs of the 50-epoch run). Notice that the performance differs slightly from the original TGN paper, as noted in the PyTorch Geometric repository.\n", + "\n", + "Here, a slightly different evaluation setup is used. Predictions within the same batch are made in parallel, meaning that interactions occurring later in the batch do not have access to any information about earlier interactions in the same batch. By contrast, the original TGN paper's code allows access to earlier interactions in the batch when sampling node neighborhoods for later interactions. While both methods are valid, the PyTorch Geometric maintainers, in collaboration with the authors of the paper, chose to present this version as it is more realistic and provides a better testing ground for future methodologies."
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZcdNbcCEQwnQ" + }, + "source": [ + "### Final notes\n", + "* The interested reader can take a look at [PyTorch Geometric Temporal](https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/root.html), a temporal graph neural network extension library for PyTorch Geometric.\n", + "* A DGL implementation of TGN can be found [here](https://github.com/ytchx1999/TGN-DGL)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0_1XwgXXRFYM" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "chap4", + "language": "python", + "name": "chap4" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.20" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/Chapter11/requirements.txt b/Chapter11/requirements.txt new file mode 100644 index 0000000..104a080 --- /dev/null +++ b/Chapter11/requirements.txt @@ -0,0 +1,132 @@ +absl-py==2.1.0 ; python_version >= "3.8" and python_version < "3.9" +aiohappyeyeballs==2.4.3 ; python_version >= "3.8" and python_version < "3.9" +aiohttp==3.10.10 ; python_version >= "3.8" and python_version < "3.9" +aiosignal==1.3.1 ; python_version >= "3.8" and python_version < "3.9" +annotated-types==0.7.0 ; python_version >= "3.8" and python_version < "3.9" +appnope==0.1.4 ; python_version >= "3.8" and python_version < "3.9" and (platform_system == "Darwin" or sys_platform == "darwin") +asttokens==2.4.1 ; python_version >= "3.8" and python_version < "3.9" +astunparse==1.6.3 ; python_version >= "3.8" and python_version < "3.9" +async-timeout==4.0.3 ; python_version >= "3.8" and python_version < "3.9" +attrs==24.2.0 ; python_version >= "3.8" and python_version < "3.9"
+backcall==0.2.0 ; python_version >= "3.8" and python_version < "3.9" +cachetools==5.5.0 ; python_version >= "3.8" and python_version < "3.9" +certifi==2024.8.30 ; python_version >= "3.8" and python_version < "3.9" +cffi==1.17.1 ; python_version >= "3.8" and python_version < "3.9" and implementation_name == "pypy" +chardet==5.2.0 ; python_version >= "3.8" and python_version < "3.9" +charset-normalizer==3.4.0 ; python_version >= "3.8" and python_version < "3.9" +colorama==0.4.6 ; python_version >= "3.8" and python_version < "3.9" and (sys_platform == "win32" or platform_system == "Windows") +comm==0.2.2 ; python_version >= "3.8" and python_version < "3.9" +cycler==0.12.1 ; python_version >= "3.8" and python_version < "3.9" +debugpy==1.8.7 ; python_version >= "3.8" and python_version < "3.9" +decorator==5.1.1 ; python_version >= "3.8" and python_version < "3.9" +dgl @ https://data.dgl.ai/wheels/torch-2.1/dgl-2.4.0-cp38-cp38-manylinux1_x86_64.whl ; python_version >= "3.8" and python_version < "3.9" +executing==2.1.0 ; python_version >= "3.8" and python_version < "3.9" +filelock==3.16.1 ; python_version >= "3.8" and python_version < "3.9" +flatbuffers==2.0.7 ; python_version >= "3.8" and python_version < "3.9" +frozenlist==1.4.1 ; python_version >= "3.8" and python_version < "3.9" +fsspec==2024.9.0 ; python_version >= "3.8" and python_version < "3.9" +gast==0.4.0 ; python_version >= "3.8" and python_version < "3.9" +gensim==4.3.3 ; python_version >= "3.8" and python_version < "3.9" +google-auth-oauthlib==1.0.0 ; python_version >= "3.8" and python_version < "3.9" +google-auth==2.35.0 ; python_version >= "3.8" and python_version < "3.9" +google-pasta==0.2.0 ; python_version >= "3.8" and python_version < "3.9" +grpcio==1.66.2 ; python_version >= "3.8" and python_version < "3.9" +h5py==3.11.0 ; python_version >= "3.8" and python_version < "3.9" +idna==3.10 ; python_version >= "3.8" and python_version < "3.9" +importlib-metadata==8.5.0 ; python_version >= "3.8" and 
python_version < "3.9" +ipykernel==6.29.5 ; python_version >= "3.8" and python_version < "3.9" +ipython==8.12.3 ; python_version >= "3.8" and python_version < "3.9" +jedi==0.19.1 ; python_version >= "3.8" and python_version < "3.9" +jinja2==3.1.4 ; python_version >= "3.8" and python_version < "3.9" +joblib==1.4.2 ; python_version >= "3.8" and python_version < "3.9" +jupyter-client==8.6.3 ; python_version >= "3.8" and python_version < "3.9" +jupyter-core==5.7.2 ; python_version >= "3.8" and python_version < "3.9" +keras-preprocessing==1.1.2 ; python_version >= "3.8" and python_version < "3.9" +keras==2.7.0 ; python_version >= "3.8" and python_version < "3.9" +kiwisolver==1.4.7 ; python_version >= "3.8" and python_version < "3.9" +libclang==18.1.1 ; python_version >= "3.8" and python_version < "3.9" +lightning-utilities==0.11.7 ; python_version >= "3.8" and python_version < "3.9" +markdown==3.7 ; python_version >= "3.8" and python_version < "3.9" +markupsafe==2.1.5 ; python_version >= "3.8" and python_version < "3.9" +matplotlib-inline==0.1.7 ; python_version >= "3.8" and python_version < "3.9" +matplotlib==3.2.2 ; python_version >= "3.8" and python_version < "3.9" +mpmath==1.3.0 ; python_version >= "3.8" and python_version < "3.9" +multidict==6.1.0 ; python_version >= "3.8" and python_version < "3.9" +nest-asyncio==1.6.0 ; python_version >= "3.8" and python_version < "3.9" +networkx==2.5 ; python_version >= "3.8" and python_version < "3.9" +neural-structured-learning==1.3.1 ; python_version >= "3.8" and python_version < "3.9" +numpy==1.21.6 ; python_version >= "3.8" and python_version < "3.9" +nvidia-cublas-cu12==12.1.3.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cuda-cupti-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cuda-nvrtc-cu12==12.1.105 ; platform_system == "Linux" and 
platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cuda-runtime-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cudnn-cu12==8.9.2.26 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cufft-cu12==11.0.2.54 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-curand-cu12==10.3.2.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cusolver-cu12==11.4.5.107 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-cusparse-cu12==12.1.0.106 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-nccl-cu12==2.18.1 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-nvjitlink-cu12==12.6.77 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +nvidia-nvtx-cu12==12.1.105 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +oauthlib==3.2.2 ; python_version >= "3.8" and python_version < "3.9" +opt-einsum==3.4.0 ; python_version >= "3.8" and python_version < "3.9" +packaging==24.1 ; python_version >= "3.8" and python_version < "3.9" +pandas==2.0.3 ; python_version >= "3.8" and python_version < "3.9" +parso==0.8.4 ; python_version >= "3.8" and python_version < "3.9" +pexpect==4.9.0 ; python_version >= "3.8" and python_version < "3.9" and sys_platform != "win32" +pickleshare==0.7.5 ; python_version >= "3.8" and python_version < "3.9" +pillow==10.4.0 ; python_version >= "3.8" and python_version < 
"3.9" +platformdirs==4.3.6 ; python_version >= "3.8" and python_version < "3.9" +prompt-toolkit==3.0.48 ; python_version >= "3.8" and python_version < "3.9" +propcache==0.2.0 ; python_version >= "3.8" and python_version < "3.9" +protobuf==3.20.3 ; python_version >= "3.8" and python_version < "3.9" +psutil==6.0.0 ; python_version >= "3.8" and python_version < "3.9" +ptyprocess==0.7.0 ; python_version >= "3.8" and python_version < "3.9" and sys_platform != "win32" +pure-eval==0.2.3 ; python_version >= "3.8" and python_version < "3.9" +pyasn1-modules==0.4.1 ; python_version >= "3.8" and python_version < "3.9" +pyasn1==0.6.1 ; python_version >= "3.8" and python_version < "3.9" +pycparser==2.22 ; python_version >= "3.8" and python_version < "3.9" and implementation_name == "pypy" +pydantic-core==2.23.4 ; python_version >= "3.8" and python_version < "3.9" +pydantic==2.9.2 ; python_version >= "3.8" and python_version < "3.9" +pygments==2.18.0 ; python_version >= "3.8" and python_version < "3.9" +pyparsing==3.1.4 ; python_version >= "3.8" and python_version < "3.9" +python-dateutil==2.9.0.post0 ; python_version >= "3.8" and python_version < "3.9" +pytz==2024.2 ; python_version >= "3.8" and python_version < "3.9" +pywin32==307 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.8" and python_version < "3.9" +pyyaml==6.0.2 ; python_version >= "3.8" and python_version < "3.9" +pyzmq==26.2.0 ; python_version >= "3.8" and python_version < "3.9" +requests-oauthlib==2.0.0 ; python_version >= "3.8" and python_version < "3.9" +requests==2.32.3 ; python_version >= "3.8" and python_version < "3.9" +rsa==4.9 ; python_version >= "3.8" and python_version < "3.9" +scikit-learn==1.3.2 ; python_version >= "3.8" and python_version < "3.9" +scipy==1.10.1 ; python_version >= "3.8" and python_version < "3.9" +setuptools==75.1.0 ; python_version >= "3.8" and python_version < "3.9" +six==1.16.0 ; python_version >= "3.8" and python_version < "3.9" 
+smart-open==7.0.5 ; python_version >= "3.8" and python_version < "3.9" +stack-data==0.6.3 ; python_version >= "3.8" and python_version < "3.9" +stellargraph==1.2.1 ; python_version >= "3.8" and python_version < "3.9" +sympy==1.13.3 ; python_version >= "3.8" and python_version < "3.9" +tensorboard-data-server==0.7.2 ; python_version >= "3.8" and python_version < "3.9" +tensorboard==2.14.0 ; python_version >= "3.8" and python_version < "3.9" +tensorflow-estimator==2.7.0 ; python_version >= "3.8" and python_version < "3.9" +tensorflow-io-gcs-filesystem==0.21.0 ; python_version >= "3.8" and python_version < "3.9" +tensorflow==2.7.2 ; python_version >= "3.8" and python_version < "3.9" +termcolor==2.4.0 ; python_version >= "3.8" and python_version < "3.9" +threadpoolctl==3.5.0 ; python_version >= "3.8" and python_version < "3.9" +torch-geometric==2.6.1 ; python_version >= "3.8" and python_version < "3.9" +torch==2.1.2 ; python_version >= "3.8" and python_version < "3.9" +torchmetrics==1.4.3 ; python_version >= "3.8" and python_version < "3.9" +torchvision==0.16.2 ; python_version >= "3.8" and python_version < "3.9" +tornado==6.4.1 ; python_version >= "3.8" and python_version < "3.9" +tqdm==4.66.5 ; python_version >= "3.8" and python_version < "3.9" +traitlets==5.14.3 ; python_version >= "3.8" and python_version < "3.9" +triton==2.1.0 ; platform_system == "Linux" and platform_machine == "x86_64" and python_version >= "3.8" and python_version < "3.9" +typing-extensions==4.12.2 ; python_version >= "3.8" and python_version < "3.9" +tzdata==2024.2 ; python_version >= "3.8" and python_version < "3.9" +urllib3==2.2.3 ; python_version >= "3.8" and python_version < "3.9" +wcwidth==0.2.13 ; python_version >= "3.8" and python_version < "3.9" +werkzeug==3.0.4 ; python_version >= "3.8" and python_version < "3.9" +wheel==0.44.0 ; python_version >= "3.8" and python_version < "3.9" +wrapt==1.16.0 ; python_version >= "3.8" and python_version < "3.9" +yarl==1.14.0 ; python_version >= 
"3.8" and python_version < "3.9" +zipp==3.20.2 ; python_version >= "3.8" and python_version < "3.9" diff --git a/docker/Dockerfile b/docker/Dockerfile index 3530d5e..d060efb 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -85,3 +85,9 @@ RUN ls -d -1 */ | grep -v -e Chapter10 | xargs rm -rf RUN conda create -n chap10 python=3.10 RUN conda run -n chap10 pip install -r Chapter10/requirements.txt RUN conda run -n chap10 python -m ipykernel install --name chap10 --user + +FROM base as chap11 +RUN ls -d -1 */ | grep -v -e Chapter11 | xargs rm -rf +RUN conda create -n chap11 python=3.8 +RUN conda run -n chap11 pip install -r Chapter11/requirements.txt +RUN conda run -n chap11 python -m ipykernel install --name chap11 --user