From 571e814ddbe8400f8f5ab11e77c06e30e3a314e6 Mon Sep 17 00:00:00 2001 From: vdeltatto Date: Wed, 7 Jan 2026 19:55:32 +0100 Subject: [PATCH 1/5] simplify edge-drawing step --- dadapy/causal_graph.py | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/dadapy/causal_graph.py b/dadapy/causal_graph.py index 309775f8..404fc0d2 100644 --- a/dadapy/causal_graph.py +++ b/dadapy/causal_graph.py @@ -789,29 +789,33 @@ def community_graph_visualization( # draw edges for community_effect_idx, order_idx in keys: - if order_idx > 0: - # for each putative effect community at order >=1... - for community_effect in communities_orders[order_idx]: - community_name_effect = community_names[tuple(community_effect)] - # ...loop over all putative causal communities at order -1 - previous_order = order_idx - 1 + if order_idx == 0: + continue + # for each putative effect community at order >=1... + for community_effect in communities_orders[order_idx]: + community_name_effect = community_names[tuple(community_effect)] + # ...loop over all putative causal communities at previous orders + for previous_order in range(0, order_idx): for community_cause in communities_orders[previous_order]: community_name_cause = community_names[ tuple(community_cause) ] - # ...loop over all variables in each putative causal community - for variable_cause in community_cause: - # ...and draw an edge if at least a link is found - if adj_matrix[variable_cause, community_effect].any(): - G.add_edges_from( - [ - ( - str(community_name_cause), - str(community_name_effect), - ) - ] - ) - break + # ...and draw an edge if at least a link is found + if adj_matrix[ + np.ix_(community_cause, community_effect) + ].any(): + G.add_edges_from( + [ + ( + str(community_name_cause), + str(community_name_effect), + ) + ] + ) + + # delete edges in presence of indirect paths + G = nx.transitive_reduction(G) + # show graph options = { "node_color": "gray", From 
d7f00675b00aadbf6fbbd61840f90d595736d822 Mon Sep 17 00:00:00 2001 From: vdeltatto Date: Wed, 7 Jan 2026 20:01:49 +0100 Subject: [PATCH 2/5] add variable names in community graph --- dadapy/causal_graph.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/dadapy/causal_graph.py b/dadapy/causal_graph.py index 404fc0d2..6a284a05 100644 --- a/dadapy/causal_graph.py +++ b/dadapy/causal_graph.py @@ -624,6 +624,7 @@ def community_graph_visualization( adj_matrix, type="community", savefig_name=None, + variable_names=None, **kwargs, ): """Shows a visual representation of the dynamical communities on a graph. @@ -641,6 +642,8 @@ def community_graph_visualization( with different colors in a graph with all the original D variables in the time series. savefig_name (str): path at which the picture of the final graph is saved in pdf format. If None (default), the figure is not saved. + variable_names (np.array(str)): array of shape (D,) containing the names of the D variables. + Used only if type="community", to show the variable names in the printout of each community. **kwargs: customizable arguments used by the networkx library. If type="all-variable", these include: 'scale','k1' and 'k2', 'cmap', 'width' and 'arrowsize'. If type="community", the possible arguments are: 'node_color', 'node_size', 'width', 'arrowstyle', 'arrowsize'. 
@@ -776,9 +779,14 @@ def community_graph_visualization( for community, key in zip(community_names, keys): community_name = community_names[tuple(community)] G.add_node(str(community_name)) - print( - f"Community {community_name} ({len(community)} variables, level {key[1]}): {community}" - ) + if variable_names is None: + print( + f"Community {community_name} ({len(community)} variables, level {key[1]}): {community}" + ) + else: + print( + f"Community {community_name} ({len(community)} variables, level {key[1]}): {variable_names[list(community)]}" + ) # dictionary with keys: (order_idx) and values: list of communities at that order (list of list) communities_orders = {key[1]: [] for key in keys} From 2cf710420b9003898ee94b3ddee8b04617281f5c Mon Sep 17 00:00:00 2001 From: vdeltatto Date: Wed, 7 Jan 2026 20:22:08 +0100 Subject: [PATCH 3/5] macOS-13 -> macOS-14 in tests --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cf03a21a..7f8a872c 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,7 +15,7 @@ jobs: strategy: matrix: include: - - {os: "macOS-13", python-version: "3.10"} + - {os: "macOS-14", python-version: "3.10"} - {os: "ubuntu-22.04", python-version: "3.12"} - {os: "ubuntu-22.04", python-version: "3.11"} - {os: "ubuntu-22.04", python-version: "3.10"} From c3af923e31e3c1e0fbd889f02fb1731651aa5a0c Mon Sep 17 00:00:00 2001 From: vdeltatto Date: Thu, 8 Jan 2026 20:15:08 +0100 Subject: [PATCH 4/5] extend community names to arbitrary number of communities --- dadapy/causal_graph.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/dadapy/causal_graph.py b/dadapy/causal_graph.py index 6a284a05..a25ba5c8 100644 --- a/dadapy/causal_graph.py +++ b/dadapy/causal_graph.py @@ -20,6 +20,7 @@ jax.config.update('jax_platform_name', 'gpu') # set 'cpu' or 'gpu' """ +import itertools import string import warnings 
@@ -31,6 +32,14 @@ from dadapy import DiffImbalance +def symbol_generator(): + """Generates the alphanumeric strings used to label the dynamical communities.""" + yield from string.ascii_uppercase + for i in itertools.count(1): + for c in string.ascii_uppercase: + yield f"{c}{i}" + + class CausalGraph(DiffImbalance): """Constructs a community causal graph where variables are grouped into single nodes. @@ -66,6 +75,8 @@ def __init__( self.standardize = standardize self.num_variables, self.periods = self._check_and_initialize_args(periods) self.seed = seed + + # outputs self.imbs_training = None self.weights_training = None self.weights_final = None @@ -766,7 +777,9 @@ def community_graph_visualization( G = nx.DiGraph() keys = list(community_dictionary.keys()) values = list(community_dictionary.values()) - alphabet_string = list(string.ascii_uppercase) + alphabet_string = list( + itertools.islice(symbol_generator(), adj_matrix.shape[0]) + ) community_names = { tuple(community): alphabet_string[i] for i, community in enumerate(values) From f236ed8af6fa39f25d74dc02d67cfdbf239a1195 Mon Sep 17 00:00:00 2001 From: vdeltatto Date: Fri, 9 Jan 2026 10:45:14 +0100 Subject: [PATCH 5/5] add method for graph refinement step --- dadapy/causal_graph.py | 601 ++++++++++++++++++++++++ examples/notebook_on_causal_graph.ipynb | 7 +- 2 files changed, 605 insertions(+), 3 deletions(-) diff --git a/dadapy/causal_graph.py b/dadapy/causal_graph.py index a25ba5c8..b4e3a8a0 100644 --- a/dadapy/causal_graph.py +++ b/dadapy/causal_graph.py @@ -85,6 +85,13 @@ def __init__( self.adj_matrix = None self.community_dictionary = None + # graph refinement (direct vs indirect links between communities) + self.weights_final_refine = None + self.communities_and_lags_refine = None + self.imbs_training_refine = None + self.imbs_final_refine = None + self.errors_final_refine = None + def _check_and_initialize_args(self, periods): """Checks input arguments to constructor of CausalGraph object.""" 
num_variables = None @@ -792,6 +799,7 @@ def community_graph_visualization( for community, key in zip(community_names, keys): community_name = community_names[tuple(community)] G.add_node(str(community_name)) + print("Conversion labels - communities:") if variable_names is None: print( f"Community {community_name} ({len(community)} variables, level {key[1]}): {community}" @@ -853,3 +861,596 @@ def community_graph_visualization( # return networkx object return G + + def find_direct_links_communities( + self, + adj_matrix, + community_graph, + community_dictionary, + num_samples, + time_lags, + embedding_dim_present=1, + embedding_dim_future=1, + embedding_time=1, + num_epochs=200, + batches_per_epoch=1, + l1_strength=0.0, + point_adapt_lambda=False, + k_init=1, + k_final=1, + lambda_factor=0.1, + optimizer_name="sgd", + learning_rate=1e-2, + learning_rate_decay=None, + num_points_rows=None, + compute_imb_final=False, + compute_error=False, + ratio_rows_columns=1, + discard_close_ind=None, + ): + """Implements the refinement step to distinguish direct and indirect links between nonconsecutive communities. + + Whenever a pattern A->C->B is found in the community causal graph, the loss + DII(w * [A(0), B(tau-1), C(tau-1),..., B(tau-E), C(tau-E)] -> B(tau)) + is optimized for all input time lags 'tau'. All variables within the same community are scaled by the same weight. + The number of previous time steps E included in the optimization is given by the argument 'embedding_dim_present'. + + Arguments 'num_samples', 'time_lags', 'embedding_dim_present', 'embedding_dim_future' and 'embedding_time' + are read only when data are provided to the CausalGraph object through the argument 'time_series'. + Arguments 'compute_error', 'ratio_rows_columns' and 'discard_close_ind' are only read when 'compute_imb_final' + is set to True. + + Args: + adj_matrix (np.ndarray(float)): binary matrix of shape (D,D) defining the links of a directed + graph with D nodes. 
+ community_graph (networkx.DiGraph): output of method 'community_graph_visualization', with option + type="community". + community_dictionary (dict): dictionary with pairs (comm_id, level) as keys, and lists containing + the indices of the variables in each community as values. + num_samples (int): number of samples harvested from the full time series, interpreted as + independent initial conditions of the same dynamical process. + time_lags (list(int), np.ndarray(int)): tested time lags between 'present' and 'future'. + embedding_dim_present (int): dimension of the time-delay embedding vectors built in the present + space (t=0, t=-1, ...). Default is 1, which means the time-delay embeddings are not employed. + embedding_dim_future (int): dimension of the time-delay embedding vectors built in the space of + the target variable (t=tau, t=tau-1, ...). Default is 1. + embedding_time (int): lag between consecutive samples in the time-delay embedding vectors of each + variable. Default is 1. + target_variables (str or list(int), np.array(int)): list or np.array of the target variables + defining the distance space in the future. Default is "all", for which the optimization is + iterated over all variables as target. + save_weights (bool): whether to save or not the weights during training, rather than only the final + weights. If True, weights are saved in the attribute 'weights_training' of the CausalGraph object, + which is an array of shape (n_target_variables, n_time_lags, num_epochs+1, num_variables). + Default is False. + num_epochs (int): number of training epochs. Default is 200. + batches_per_epoch (int): number of minibatches; must be a divisor of n_points. Each weight update is + carried out by computing the DII gradient over n_points / batches_per_epoch points. Default is 1, + which means that the gradient is computed over all the available points (batch GD). + seed (int): seed of JAX random generator, default is 0. 
Different seeds determine different mini-batch + partitions. + l1_strength (float): strength of the L1 regularization (LASSO) term. Default is 0. + point_adapt_lambda (bool): whether to use a global smoothing parameter lambda for the c_ij coefficients + in the DII (if False), or a different parameter for each point (if True). Default is False. + k_init (int): initial rank of neighbors used to set lambda. Ranks are defined starting from 1. If + batches_per_epoch > 1, neighbors are recomputed within each mini-batch. Default is 1. + k_final (int): final rank of neighbors used to set lambda. If batches_per_epoch > 1, neighbors are + recomputed within each mini-batch. Default is 1. + lambda_factor (float): factor defining the scale of lambda. Default is 0.1. + params_init (np.array(float), jnp.array(float)): array of shape (n_features_A,) containing the initial + values of the scaling weights to be optimized. If None, params_init is set to [0.1, 0.1, ..., 0.1]. + optimizer_name (str): name of the optimizer, calling the Optax library. Possible choices are 'sgd' + (default), 'adam' and 'adamw'. See https://optax.readthedocs.io/en/latest/api/optimizers.html for + additional details. + learning_rate (float): value of the learning rate. Default is 1e-2. + learning_rate_decay (str): schedule to damp the learning rate to zero starting from the value provided + with the attribute learning_rate. The available schedules are: cosine decay ("cos"), exponential + decay ("exp"; the initial learning rate is halved every 10 steps), or constant learning rate (None). + Default is None (constant learning rate). + num_points_rows (int): number of points sampled from the rows of rank and distance matrices. In case of large + datasets, choosing num_points_rows < n_points can significantly speed up the training. The default is + None, for which num_points_rows == n_points. 
+ compute_imb_final (bool): whether to compute the final DII over the full data set, using the options + specified by 'compute_error', 'ratio_rows_columns' and 'discard_close_ind'. Default is False, for + which those arguments are ignored. + compute_error (bool): whether to compute the final DII and its error by sampling different points along + rows and columns of the distance matrix. If False, the final DII is computed using the same points + along rows and columns, which does not allow for an error estimation. Default is False. + ratio_rows_columns (float): only read when compute_error is True; defines the ratio between the number + of points along rows (nrows) and along columns (ncolumns) of distance and rank matrices, in two groups + randomly sampled. In general, nrows and ncolumns are determined by solving the equations + nrows / ncolumns = ratio_rows_columns, + nrows + ncolumns = n_total_points. + Default is 1, which means that both groups have n_points / 2 elements. + discard_close_ind (int): given any point i, defines the "close" points (following the time ordering + along axis=0 of 'time_series' or 'coords_present') that are known to be significantly correlated with i. + If compute_error is True, "time-correlated" points are excluded by subsampling the data along axis=0 + with stride discard_close_ind + 1. If compute_error is False, distances between each point i and points + within the time window [i-discard_close_ind, i+discard_close_ind] are discarded. Default is None; with 0, + no distances between points close in time are discarded. + + Returns: + weights_final (dict): dictionary containing the final optimization weights for each pair of communities + linked through indirect paths in the community causal graph. 
The keys are tuples (community_name_cause, + community_name_effect), while the values are np.arrays of shape (n_time_lags, n_weights), where n_weights + is equal to 1 + E + E (1 weight for the cause community, E weights for the effect community, and E weights + for the mediator communities), and E=embedding_dim_present. + communities_and_lags (dict): dictionary containing as keys the tuples (community_name_cause, community_name_effect), + and as values a two-element list [communities, lags] with the ordered community labels and their corresponding lags. + imbs_training (dict): dictionary containing as keys the tuples (community_name_cause, community_name_effect), and + as values the DIIs during the trainings. + imbs_final (dict): dictionary containing as keys the tuples (community_name_cause, community_name_effect), and as + values the final DIIs. + errors_final (dict): dictionary containing as keys the tuples (community_name_cause, community_name_effect), and as + values the errors of the final DIIs. + """ + + def find_mediators(graph, node_start, node_end): + all_paths = list( + nx.all_simple_paths(graph, source=node_start, target=node_end) + ) + if not all_paths: + return set() + + # Get intermediates (excluding A and B) for each path + intermediates_per_path = [set(path[1:-1]) for path in all_paths] + + # Take the union of the intermediates over all paths + mediators = list(set.union(*intermediates_per_path)) + return mediators + + keys = list(community_dictionary.keys()) + values = list(community_dictionary.values()) + alphabet_string = list( + itertools.islice(symbol_generator(), adj_matrix.shape[0]) + ) + community_names = { + tuple(community): alphabet_string[i] for i, community in enumerate(values) + } + from_names_to_communities = { + community_names[key]: key for key in community_names.keys() + } + + # dictionary with keys: (order_idx) and values: list of communities at that order (list of list) + communities_orders = {key[1]: [] for key in keys} + for community_idx, order_idx in keys: + 
communities_orders[order_idx].append( + community_dictionary[community_idx, order_idx] + ) + + # print names of communities + print("Conversion labels - communities:") + for community, key in zip(community_names, keys): + community_name = community_names[tuple(community)] + print( + f"- Community {community_name} ({len(community)} variables, level {key[1]}): {community}" + ) + + # initialize output variables + imbs_training = {} + weights_final = {} + imbs_final = {} + errors_final = {} + communities_and_lags = {} + + ############# identify all pairs of indirectly linked communities, and mediator communities ############# + for community_effect_idx, order_idx in keys: + if order_idx < 2: + continue + # for each putative effect community at order >=2... + for community_effect in communities_orders[order_idx]: + community_name_effect = community_names[tuple(community_effect)] + + # ...find all its ancestor communities... + effect_ancestors_names = set( + nx.ancestors(community_graph, community_name_effect) + ) + effect_ancestors_sets = [ + from_names_to_communities[effect_ancestor_name] + for effect_ancestor_name in effect_ancestors_names + ] + + # ...and take each ancestor (not a parent!) community as putative cause... + for community_name_cause in list( + effect_ancestors_names.difference( + list(community_graph.predecessors(community_name_effect)) + ) + ): + community_cause = list( + from_names_to_communities[community_name_cause] + ) + + # skip the test if there are no links in the adjacency matrix + if ( + adj_matrix is not None + and ( + adj_matrix[community_cause, :][:, community_effect] == 0 + ).all() + ): + print( + f"Communities {community_name_cause}->{community_name_effect}: not linked according to adjacency matrix, test skipped." 
+ ) + continue + + # find mediating communities between cause and effect + mediator_names = find_mediators( + community_graph, community_name_cause, community_name_effect + ) + mediator_sets = [ + from_names_to_communities[mediator_name] + for mediator_name in mediator_names + ] + mediator_vars = list(set().union(*mediator_sets)) + + # initialize output variables + nvars = 1 + embedding_dim_present + embedding_dim_present + imbs_training[ + community_name_cause, community_name_effect + ] = np.zeros((len(time_lags), num_epochs + 1)) + weights_final[ + community_name_cause, community_name_effect + ] = np.zeros((len(time_lags), nvars)) + communities_ordered = np.concatenate( + ( + [community_name_cause], # don't repeat (single slice) + np.tile( + [community_name_effect], reps=embedding_dim_present + ), # Repeat E (=max_lag) times + np.tile( + mediator_names, reps=embedding_dim_present + ), # Repeat E (=max_lag) times + ) + ) + lags_ordered = np.concatenate( + ( + ["t=0"], + [ + f"t=tau-{lag}" + for lag in np.arange(1, embedding_dim_present + 1) + ], + [ + f"t=tau-{lag}" + for lag in np.arange(1, embedding_dim_present + 1) + ], + ) + ) + communities_and_lags[ + community_name_cause, community_name_effect + ] = [communities_ordered, lags_ordered] + if compute_imb_final: + imbs_final[ + community_name_cause, community_name_effect + ] = np.zeros(len(time_lags)) + if compute_error: + errors_final[ + community_name_cause, community_name_effect + ] = np.zeros(len(time_lags)) + + ########### compute DII((cause(t=0), effect(t=tau-1), ... , mediator(t=tau-1), ...) -> effect(t=tau)) + coords_present = None + variables_t0 = community_cause + if self.time_series is not None: + assert num_samples <= self.time_series.shape[0] - max( + time_lags + ), ( + f"Error: cannot extract {num_samples} samples from {self.time_series.shape[0]} initial " + + f"samples, if the maximum time lag is {np.max(time_lags)}.\nChoose a smaller value of " + + f"num_samples." 
+ ) + + t0s = np.linspace( + (max(embedding_dim_present, embedding_dim_future) - 1) + * embedding_time, # select times defining the ensemble of trajectories + self.time_series.shape[0] - max(time_lags) - 1, + num_samples, + dtype=int, + ) + indices_present = +t0s # no embedding for causal community! + coords_present = self.time_series[:, variables_t0][ + indices_present + ] # has shape (num_samples, n_variables_t0) + elif self.coords_present is not None: + if num_samples is not None: + warninings.warn( + f"Argument 'num_samples' will be ignored, as you already provided the independent " + + f"initial conditions through arguments 'coords_present' and 'coords_future'.\n " + + f"To suppress this warning, set 'num_samples' to None." + ) + if time_lags is not None: + warninings.warn( + f"Argument 'time_lags' will be ignored, as the samples at different time lags t=tau " + + f"are already read from the last dimension of 'coords_future'.\n " + + f"To suppress this warning, set 'time_lags' to None." + ) + num_samples = self.coords_present.shape[0] + time_lags = np.arange(1, self.coords_future.shape[2] + 1) + coords_present = self.coords_present[:, variables_t0] + else: + print( + "To call this method, provide either a time series or directly the present and future samples " + + f"while initializing the CausalGraph class." + ) + + ############# LOOP OVER TAU ############# + for j_tau, tau in enumerate(time_lags): + indices_future = ( # for space B: effect(t=tau) + np.array( + [ + t0s - embedding_time * i + for i in range(embedding_dim_future) + ] + ) + + tau + ) + indices_cond1 = ( # for conditioning on effect(t=tau-1), effect(t=tau-2), ... + np.array( + [ + t0s - embedding_time * i + for i in range(embedding_dim_present) + ] + ) + + tau + - 1 + ) + indices_cond2 = ( # for conditioning on mediators(t=tau-1), effect(t=tau-2), ... 
+ np.array( + [ + t0s - embedding_time * i + for i in range(embedding_dim_present) + ] + ) + + tau + - 1 + ) + + if self.time_series is not None: + coords_future = self.time_series[:, community_effect][ + indices_future + ] # has shape (embedding_dim_future, num_samples, n_variables_effect_community) + coords_future = np.transpose( + coords_future, axes=[1, 2, 0] + ) # convert to shape (num_samples, n_variables_effect_community, embedding_dim_future) + coords_future = coords_future.reshape( + ( + num_samples, + len(community_effect) * embedding_dim_future, + ) + ) + + coords_cond_effect = self.time_series[:, community_effect][ + indices_cond1 + ] # has shape (embedding_dim_present, num_samples, n_variables_effect_community) + coords_cond_effect = np.transpose( + coords_cond_effect, axes=[1, 2, 0] + ) # convert to shape (num_samples, n_variables_effect_community, embedding_dim_present) + coords_cond_effect = coords_cond_effect.reshape( + ( + num_samples, + len(community_effect) * embedding_dim_present, + ) + ) + + coords_cond_mediators = self.time_series[:, mediator_vars][ + indices_cond2 + ] # has shape (embedding_dim_present+1, num_samples, n_variables_mediators) + coords_cond_mediators = np.transpose( + coords_cond_mediators, axes=[1, 2, 0] + ) # convert to shape (num_samples, n_variables_mediators, embedding_dim_present+1) + coords_cond_mediators = coords_cond_mediators.reshape( + ( + num_samples, + len(mediator_vars) * (embedding_dim_present), + ) + ) + else: + coords_future = self.coords_future[ + :, community_effect, j_tau + ] + coords_cond = self.coords_future[ + :, mediator_vars, j_tau - 1 : jtau + 1 + ] + + coords_A = np.concatenate( + ( + coords_present, + coords_cond_effect, + coords_cond_mediators, + ), + axis=1, + ) + variables_A = np.concatenate( + ( + community_cause, # don't repeat (single slice) + np.tile( + community_effect, reps=embedding_dim_present + ), # Repeat E (=max_lag) times + np.tile( + mediator_vars, reps=embedding_dim_present + ), # 
Repeat E (=max_lag) times + ) + ) + params_groups = np.concatenate( + ( + [len(community_cause)], # don't repeat (single slice) + np.tile( + [len(community_effect)], + reps=embedding_dim_present, + ), # Repeat E (=max_lag) times + np.tile( + [len(set) for set in mediator_sets], + reps=embedding_dim_present, + ), # Repeat E (=max_lag) times + ) + ) + + dii = DiffImbalance( + data_A=coords_A, + data_B=coords_future, + periods_A=None + if self.periods is None + else self.periods[variables_A], + periods_B=None + if self.periods is None + else self.periods[community_effect], + seed=self.seed, + num_epochs=num_epochs, + batches_per_epoch=batches_per_epoch, + l1_strength=l1_strength, + point_adapt_lambda=point_adapt_lambda, + k_init=k_init, + k_final=k_final, + lambda_factor=lambda_factor, + params_init=None, + params_groups=params_groups, + optimizer_name=optimizer_name, + learning_rate=learning_rate, + learning_rate_decay=learning_rate_decay, + num_points_rows=num_points_rows, + ) + ( + weights_temp, + imbs_training[community_name_cause, community_name_effect][ + j_tau + ], + ) = dii.train( + bar_label=f"Communities {community_name_cause}->{community_name_effect}, tau={tau}" + ) + + # compute final DII and its error + if compute_imb_final: + imb, err = dii.return_final_dii( + compute_error=compute_error, + ratio_rows_columns=ratio_rows_columns, + seed=self.seed, + discard_close_ind=discard_close_ind, + ) + imbs_final[community_name_cause, community_name_effect][ + j_tau + ] = imb + if compute_error: + errors_final[ + community_name_cause, community_name_effect + ][j_tau] = dii.error_final + + # save weights + weights_final[community_name_cause, community_name_effect][ + j_tau + ] = weights_temp[-1] + + self.weights_final_refine = weights_final + self.communities_and_lags = communities_and_lags + self.imbs_training_refine = imbs_training + self.imbs_final_refine = imbs_final + self.errors_final_refine = errors_final + return ( + weights_final, + communities_and_lags, + 
imbs_training, + imbs_final, + errors_final, + ) + + def community_graph_refinement( + self, + community_graph, + community_dictionary, + weights_refine, + communities_and_lags, + variable_names, + threshold=1e-1, + savefig_name=None, + **kwargs, + ): + """Shows a visual representation of the dynamical communities on a graph, after the refinement step. + + This function makes use of the library networkx (https://networkx.org/documentation/stable/index.html) + + Args: + community_graph (networkx.DiGraph): output of method 'community_graph_visualization', with option + type="community". + community_dictionary (dict): dictionary with pairs (comm_id, level) as keys and lists containing + the indices of the variables in each dynamical community as values. + weights_refine (dict): output weights of method 'find_direct_links_communities'. + communities_and_lags (dict): output communities and lags of method 'find_direct_links_communities'. + variable_names (np.array(str)): array of shape (D,) containing the names of the D variables. + threshold (float): weight threshold above which a direct link between two communities is drawn. + savefig_name (str): path at which the picture of the final graph is saved in pdf format. If + None (default), the figure is not saved. + **kwargs: customizable arguments used by the networkx library. If type="all-variable", these + include: 'scale','k1' and 'k2', 'cmap', 'width' and 'arrowsize'. If type="community", the + possible arguments are: 'node_color', 'node_size', 'width', 'arrowstyle', 'arrowsize'. + + Returns: + G (nx.diGraph object): refined community causal graph. 
+ """ + assert ( + community_dictionary is not None + ), "Provide as intput the community dictionary computed with the method find_communities" + + # construct graph + G = community_graph.copy() + + # read community dictionary and extract conversion + keys = list(community_dictionary.keys()) + values = list(community_dictionary.values()) + alphabet_string = list(itertools.islice(symbol_generator(), len(keys))) + community_names = { + tuple(community): alphabet_string[i] for i, community in enumerate(values) + } + from_names_to_communities = { + community_names[key]: key for key in community_names.keys() + } + + # print conversion labels - communities + print("Conversion labels - communities:") + for community, key in zip(community_names, keys): + community_name = community_names[tuple(community)] + if variable_names is None: + print( + f"Community {community_name} ({len(community)} variables, level {key[1]}): {community}" + ) + else: + print( + f"Community {community_name} ({len(community)} variables, level {key[1]}): {variable_names[list(community)]}" + ) + + # extract pairs of communities tested for direct vs indirect links + pairs_cause_effect = list(weights_refine.keys()) + + # loop over such pairs and connect them when at least one weight of causal community > threshold + for community_name_cause, community_name_effect in pairs_cause_effect: + mask_variables_cause = ( + communities_and_lags[community_name_cause, community_name_effect][0] + == community_name_cause + ) + + max_weights_refine = np.max( + weights_refine[community_name_cause, community_name_effect], + axis=0, # axis of lag tau + )[mask_variables_cause] + if (max_weights_refine > threshold).any(): + G.add_edges_from( + [ + ( + str(community_name_cause), + str(community_name_effect), + ) + ] + ) + + # show graph + options = { + "node_color": "gray", + "node_size": 3000, + "width": 3, + "arrowstyle": "-|>", + "arrowsize": 12, + } + options.update(kwargs) + nx.draw_circular(G, arrows=True, 
with_labels=True, **options) + if savefig_name is not None: + plt.savefig(savefig_name, dpi=300, bbox_inches="tight") + plt.show() + + # return networkx object + return G diff --git a/examples/notebook_on_causal_graph.ipynb b/examples/notebook_on_causal_graph.ipynb index 8d170ccf..70f3c812 100644 --- a/examples/notebook_on_causal_graph.ipynb +++ b/examples/notebook_on_causal_graph.ipynb @@ -52,9 +52,10 @@ "Given a set of $D$ dynamical variables $\\{ x^\\alpha(t)\\}_{\\alpha = 1}^D$, the algorithm aims at splitting the dynamical variables in groups with different levels of autonomy. We call $\\mathcal{S}^\\beta = \\{x^\\beta\\}$ an autonomous subset if each (direct or indirect) cause of any variable $x^\\beta \\in \\mathcal{S}^\\beta$ also belongs to $\\mathcal{S}^\\beta$.\n", "\n", "In the first step of the algorithm, $\\forall \\alpha=1,...,D$ we minimize\n", - "\\begin{equation}\n", + "\n", + "$$\n", " DII(\\boldsymbol{w}\\odot \\boldsymbol{x}(t=0) \\rightarrow x^\\alpha(t=\\tau))\\,.\n", - "\\end{equation}\n", + "$$\n", "Here, $\\boldsymbol{w}$ is a vector of $D$ parameters weighting each dynamical variable, and $\\odot$ denotes the element-wise product. The test is repeated by scanning several values of $\\tau$, choosing $\\tau$ in a range where $x^\\alpha(t=\\tau)$ has not significantly decorrelated from $x^\\alpha(t=0)$.\n", "\n", "This first part of the algorithm is carried out by the method 'optimize_present_to_future', which returns the final weights for each optimization and the corresponding DII over all the training epochs.\n", @@ -592,7 +593,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, the communities identified at the previous step can be represented as single nodes in a causal graph, where a link is drawn only between communities $G^\\alpha$ and $G^\\beta$ identified at consecutive levels if at least a variable in $G^\\alpha$ is linked to a variable in $G^\\beta$ in the original all-variable graph. 
The following method draws the community causal graph and returns an object of the networkx.DiGraph class." + "Finally, the communities identified at the previous step can be represented as single nodes in a causal graph, where a link is drawn only between communities $G^\\alpha$ and $G^\\beta$ if at least one variable in $G^\\alpha$ is linked to a variable in $G^\\beta$ in the original all-variable graph. A direct arrow between two communities that are also connected by an indirect path ($G^\\alpha\\rightarrow ... \\rightarrow G^\\beta$) is omitted. The following method draws the community causal graph and returns an object of the networkx.DiGraph class." ] }, {