From a9be5c2b0e22a117c1b753dca0381391237f4206 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 3 Jan 2023 15:15:12 -0600 Subject: [PATCH 001/155] `tpm.ProxyMetaclass`: Change list to generator expression --- pyphi/tpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index ae24e131d..6b257130b 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -126,7 +126,7 @@ def proxy(self): # Go through all the attribute strings in the wrapped array type. for name in dir(cls.__wraps__): # Filter special attributes, rest will be handled by `__getattr__()` - if any([not name.startswith("__"), name in ignore, name in dct]): + if any((not name.startswith("__"), name in ignore, name in dct)): continue # Create function for `name` and bind to future instances of `cls`. From d0000b108d78100b2e64979eb704e68aefcce89f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 3 Jan 2023 17:04:02 -0600 Subject: [PATCH 002/155] `tpm`: Improve documentation --- pyphi/tpm.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 6b257130b..152cba952 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -24,7 +24,7 @@ class ProxyMetaclass(type): The CPython interpreter resolves double-underscore attributes (e.g., the method definitions of mathematical operators) by looking up in the class' static methods, not in the instance methods. 
This makes it impossible to - intercept calls to them when an instance's __getattr__() is implicitly + intercept calls to them when an instance's ``__getattr__()`` is implicitly invoked, which in turn means there are only two options to wrap the special methods of the array inside our custom objects (in order to perform arithmetic operations with the TPM while also casting the result to our @@ -134,7 +134,7 @@ def proxy(self): class Wrapper(metaclass=ProxyMetaclass): - """Proxy to the array inside PyPhi's custom TPM class.""" + """Proxy to the array inside PyPhi's custom ExplicitTPM class.""" __wraps__ = None @@ -160,7 +160,20 @@ def __init__(self): class ExplicitTPM(data_structures.ArrayLike): - """An explicit network TPM in multidimensional form.""" + """An explicit network TPM in multidimensional form. + + Args: + tpm (np.array): The transition probability matrix of the |Network|. + + Keyword Args: + validate (bool): Whether to check the shape and content of the input + array for correctness. + + Attributes: + _VALUE_ATTR (str): The key of the attribute holding the TPM array value. + __wraps__ (type): The class of the array referenced by ``_VALUE_ATTR``. + __closures__ (frozenset): np.ndarray method names proxied by this class. + """ _VALUE_ATTR = "_tpm" @@ -170,9 +183,6 @@ class ExplicitTPM(data_structures.ArrayLike): # TODO(tpm) remove pending ArrayLike refactor # Casting semantics: values belonging to our custom TPM class should # remain closed under the following methods: - - # TODO attributes data, real and imag return arrays that should also be - # cast, even though they are not callable. 
__closures__ = frozenset( { "argpartition", From 39796cac278ff002e325b0ee01f2b6c7d94cd5cb Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 3 Jan 2023 17:29:36 -0600 Subject: [PATCH 003/155] Move TPM documentation to `tpm.ExplicitTPM` class --- pyphi/network.py | 25 +------------------------ pyphi/tpm.py | 26 +++++++++++++++++++++++++- 2 files changed, 26 insertions(+), 25 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index e9455be47..a750398d5 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -21,25 +21,7 @@ class Network: Args: tpm (np.ndarray): The transition probability matrix of the network. - - The TPM can be provided in any of three forms: **state-by-state**, - **state-by-node**, or **multidimensional state-by-node** form. - In the state-by-node forms, row indices must follow the - little-endian convention (see :ref:`little-endian-convention`). In - state-by-state form, column indices must also follow the - little-endian convention. - - If the TPM is given in state-by-node form, it can be either - 2-dimensional, so that ``tpm[i]`` gives the probabilities of each - node being ON if the previous state is encoded by |i| according to - the little-endian convention, or in multidimensional form, so that - ``tpm[(0, 0, 1)]`` gives the probabilities of each node being ON if - the previous state is |N_0 = 0, N_1 = 0, N_2 = 1|. - - The shape of the 2-dimensional form of a state-by-node TPM must be - ``(s, n)``, and the shape of the multidimensional form of the TPM - must be ``[2] * n + [n]``, where ``s`` is the number of states and - ``n`` is the number of nodes in the network. + See :func:`pyphi.tpm.ExplicitTPM`. Keyword Args: cm (np.ndarray): A square binary adjacency matrix indicating the @@ -49,11 +31,6 @@ class Network: is connected to every node (including itself)**. node_labels (tuple[str] or |NodeLabels|): Human-readable labels for each node in the network. 
- - Example: - In a 3-node network, ``the_network.tpm[(0, 0, 1)]`` gives the - transition probabilities for each node at |t| given that state at |t-1| - was |N_0 = 0, N_1 = 0, N_2 = 1|. """ # TODO make tpm also optional when implementing logical network definition diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 152cba952..d1d771779 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -165,10 +165,34 @@ class ExplicitTPM(data_structures.ArrayLike): Args: tpm (np.array): The transition probability matrix of the |Network|. + The TPM can be provided in any of three forms: **state-by-state**, + **state-by-node**, or **multidimensional state-by-node** form. + In the state-by-node forms, row indices must follow the + little-endian convention (see :ref:`little-endian-convention`). In + state-by-state form, column indices must also follow the + little-endian convention. + + If the TPM is given in state-by-node form, it can be either + 2-dimensional, so that ``tpm[i]`` gives the probabilities of each + node being ON if the previous state is encoded by |i| according to + the little-endian convention, or in multidimensional form, so that + ``tpm[(0, 0, 1)]`` gives the probabilities of each node being ON if + the previous state is |N_0 = 0, N_1 = 0, N_2 = 1|. + + The shape of the 2-dimensional form of a state-by-node TPM must be + ``(s, n)``, and the shape of the multidimensional form of the TPM + must be ``[2] * n + [n]``, where ``s`` is the number of states and + ``n`` is the number of nodes in the network. + Keyword Args: validate (bool): Whether to check the shape and content of the input array for correctness. + Example: + In a 3-node network, ``tpm[(0, 0, 1)]`` gives the + transition probabilities for each node at |t| given that state at |t-1| + was |N_0 = 0, N_1 = 0, N_2 = 1|. + Attributes: _VALUE_ATTR (str): The key of the attribute holding the TPM array value. __wraps__ (type): The class of the array referenced by ``_VALUE_ATTR``. 
@@ -248,7 +272,7 @@ def __init__(self, tpm, validate=False): @property def tpm(self): - """Return the underlying `tpm` object.""" + """np.ndarray: The underlying `tpm` object.""" return self._tpm def validate(self, check_independence=True): From a357031fb230878811f98cab188c6cb7cc340881 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 5 Jan 2023 17:34:45 -0600 Subject: [PATCH 004/155] `tpm.ExplicitTPM`: Remove superfluous call to super constructor --- pyphi/tpm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index d1d771779..5d0794f24 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -261,7 +261,6 @@ def __getattr__(self, name): def __init__(self, tpm, validate=False): self._tpm = np.array(tpm) - super().__init__() if validate: self.validate(check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE) From b15924a4749554d03147808b7b19e1626418d047 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 5 Jan 2023 17:36:57 -0600 Subject: [PATCH 005/155] `tpm.ExplicitTPM`: Preserve column alignment in string representation. --- pyphi/tpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 5d0794f24..cabffa906 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -554,7 +554,7 @@ def __str__(self): return self.__repr__() def __repr__(self): - return "ExplicitTPM({})".format(self._tpm) + return "ExplicitTPM(\n{}\n)".format(self._tpm) def __hash__(self): return self._hash From 9c349e274d6a7e82220893e85b3ea9850a59af99 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 6 Jan 2023 11:57:59 -0600 Subject: [PATCH 006/155] `subsystem`: Complete `ExplicitTPM.enforce` cleanup Remove inconsequential assignments introduced in bfd62c and 99ad3f. 
--- pyphi/subsystem.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index beab3075d..7f8379b43 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -86,7 +86,6 @@ def __init__( ): # The network this subsystem belongs to. validate.is_network(network) - network._tpm = network.tpm self.network = network self.node_labels = network.node_labels @@ -392,7 +391,6 @@ def _single_node_effect_repertoire( # pylint: disable=missing-docstring purview_node = self._index2node[purview_node_index] # Condition on the state of the purview inputs that are in the mechanism - purview_node.tpm = purview_node.tpm tpm = purview_node.tpm.condition_tpm(condition) # TODO(4.0) remove reference to TPM # Marginalize-out the inputs that aren't in the mechanism. From 91772ac577291cceec5c959b6c5232bfd962e550 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 6 Jan 2023 18:02:25 -0600 Subject: [PATCH 007/155] Refactor `Node` to use xarray DataArray (accessor namespace for now) --- pyphi/actual.py | 2 +- pyphi/data_structures/array_like.py | 3 +- pyphi/macro.py | 14 ++-- pyphi/node.py | 116 ++++++++++++++++++---------- pyphi/subsystem.py | 8 +- pyphi/tpm.py | 2 +- setup.py | 1 + test/test_node.py | 38 ++++----- 8 files changed, 111 insertions(+), 73 deletions(-) diff --git a/pyphi/actual.py b/pyphi/actual.py index 428c38763..f21cea504 100644 --- a/pyphi/actual.py +++ b/pyphi/actual.py @@ -153,7 +153,7 @@ def __init__( self.cause_system.state = after_state for node in self.cause_system.nodes: - node.state = after_state[node.index] + node.pyphi.state = after_state[node.pyphi.index] # Validate the cause system # The state of the effect system does not need to be reachable diff --git a/pyphi/data_structures/array_like.py b/pyphi/data_structures/array_like.py index 3df0cf2e0..3e8fe33d0 100644 --- a/pyphi/data_structures/array_like.py +++ b/pyphi/data_structures/array_like.py @@ -12,9 +12,10 @@ class ArrayLike(NDArrayOperatorsMixin): # TODO(tpm) 
populate this list _TYPE_CLOSED_FUNCTIONS = ( + np.all, np.concatenate, + np.expand_dims, np.stack, - np.all, np.sum, ) diff --git a/pyphi/macro.py b/pyphi/macro.py index 26eac942d..fef5153d7 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -77,9 +77,9 @@ def run_tpm(system, steps, blackbox): # boxes. node_tpms = [] for node in system.nodes: - node_tpm = node.tpm_on - for input_node in node.inputs: - if not blackbox.in_same_box(node.index, input_node): + node_tpm = node.pyphi.tpm_on + for input_node in node.pyphi.inputs: + if not blackbox.in_same_box(node.pyphi.index, input_node): if input_node in blackbox.output_indices: node_tpm = node_tpm.marginalize_out([input_node]) @@ -237,7 +237,7 @@ def _squeeze(system): nodes = generate_nodes(tpm, cm, state, node_indices) # Re-calcuate the tpm based on the results of the cut - tpm = rebuild_system_tpm(node.tpm_on for node in nodes) + tpm = rebuild_system_tpm(node.pyphi.tpm_on for node in nodes) return SystemAttrs(tpm, cm, node_indices, state) @@ -247,9 +247,9 @@ def _blackbox_partial_noise(blackbox, system): # Noise inputs from non-output elements hidden in other boxes node_tpms = [] for node in system.nodes: - node_tpm = node.tpm_on - for input_node in node.inputs: - if blackbox.hidden_from(input_node, node.index): + node_tpm = node.pyphi.tpm_on + for input_node in node.pyphi.inputs: + if blackbox.hidden_from(input_node, node.pyphi.index): node_tpm = node_tpm.marginalize_out([input_node]) node_tpms.append(node_tpm) diff --git a/pyphi/node.py b/pyphi/node.py index 50ed200c8..175fd0cbb 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -10,6 +10,7 @@ import functools import numpy as np +import xarray as xr from . import utils from .connectivity import get_inputs_from_cm, get_outputs_from_cm @@ -17,10 +18,10 @@ from .tpm import ExplicitTPM -# TODO extend to nonbinary nodes -@functools.total_ordering -class Node: - """A node in a subsystem. 
+def node(tpm, cm, index, state=None, node_labels=None): + + """ + Instantiate a DataArray node TPM. Args: tpm (ExplicitTPM): The TPM of the subsystem. @@ -28,6 +29,67 @@ class Node: index (int): The node's index in the network. state (int): The state of this node. node_labels (|NodeLabels|): Labels for these nodes. + """ + + # Get indices of the inputs. + inputs = frozenset(get_inputs_from_cm(index, cm)) + + # Generate the node's TPM. + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We begin by getting the part of the subsystem's TPM that gives just + # the state of this node. This part is still indexed by network state, + # but its last dimension will be gone, since now there's just a single + # scalar value (this node's state) rather than a state-vector for all + # the network nodes. + tpm_on = tpm[..., index] + + # TODO extend to nonbinary nodes + # Marginalize out non-input nodes. + + # TODO use names rather than indices + non_inputs = set(tpm.tpm_indices()) - inputs + tpm_on = tpm_on.marginalize_out(non_inputs).tpm + + # Get the TPM that gives the probability of the node being off, rather + # than on. + tpm_off = 1 - tpm_on + + # Combine the on- and off-TPM so that the first dimension is indexed by + # the state of the node's inputs at t, and the last dimension is + # indexed by the node's state at t+1. This representation makes it easy + # to condition on the node state. 
+ tpm = ExplicitTPM( + np.stack([tpm_off, tpm_on], axis=-1) + ) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + state_space = ["OFF", "ON"] + singleton_state_space = ["_marginalized_"] + + coordinates = [ + state_space if dim == 2 else singleton_state_space for dim in tpm.shape + ] + + dimensions = tuple(node_labels) + ("Pr",) + + return xr.DataArray( + name = node_labels[index] if node_labels else str(index), + data = tpm, + dims = dimensions, + coords = coordinates, + attrs = { + "cm": cm, + "index": index, + "state": state, + "node_labels": node_labels, + } + ) + +# TODO extend to nonbinary nodes +@xr.register_dataarray_accessor("pyphi") +@functools.total_ordering +class Node: + """A node in a Network. Attributes: tpm (ExplicitTPM): The node TPM is an array with shape ``(2,)*(n + 1)``, @@ -42,55 +104,29 @@ class Node: probabilities that the node will be 'ON'. """ - def __init__(self, tpm, cm, index, state, node_labels): + def __init__(self, dataarray): # This node's index in the list of nodes. - self.index = index + self.index = dataarray.attrs["index"] # State of this node. - self.state = state + self.state = dataarray.attrs["state"] # Node labels used in the system - self.node_labels = node_labels + self.node_labels = dataarray.attrs["node_labels"] + + # Network connectivity matrix. + cm = dataarray.attrs["cm"] # Get indices of the inputs. self._inputs = frozenset(get_inputs_from_cm(self.index, cm)) self._outputs = frozenset(get_outputs_from_cm(self.index, cm)) - # Generate the node's TPM. - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We begin by getting the part of the subsystem's TPM that gives just - # the state of this node. This part is still indexed by network state, - # but its last dimension will be gone, since now there's just a single - # scalar value (this node's state) rather than a state-vector for all - # the network nodes. 
- tpm_on = tpm[..., self.index] - - # TODO extend to nonbinary nodes - # Marginalize out non-input nodes that are in the subsystem, since the - # external nodes have already been dealt with as boundary conditions in - # the subsystem's TPM. - - # TODO use names rather than indices - non_inputs = set(tpm.tpm_indices()) - self._inputs - tpm_on = tpm_on.marginalize_out(non_inputs).tpm - - # Get the TPM that gives the probability of the node being off, rather - # than on. - tpm_off = 1 - tpm_on - - # Combine the on- and off-TPM so that the first dimension is indexed by - # the state of the node's inputs at t, and the last dimension is - # indexed by the node's state at t+1. This representation makes it easy - # to condition on the node state. - self.tpm = ExplicitTPM( - np.stack([tpm_off, tpm_on], axis=-1), - ) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + self.tpm = dataarray.data # Only compute the hash once. self._hash = hash( - (index, hash(self.tpm), self.state, self._inputs, self._outputs) + (self.index, hash(self.tpm), self.state, self._inputs, self._outputs) ) @property @@ -178,7 +214,7 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None): node_state = utils.state_of(indices, network_state) return tuple( - Node(tpm, cm, index, state, node_labels) + node(tpm, cm, index, state, node_labels) for index, state in zip(indices, node_state) ) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 7f8379b43..7070fc937 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -328,10 +328,10 @@ def _single_node_cause_repertoire(self, mechanism_node_index, purview): mechanism_node = self._index2node[mechanism_node_index] # We're conditioning on this node's state, so take the TPM for the node # being in that state. - tpm = mechanism_node.tpm[..., mechanism_node.state] + tpm = mechanism_node.pyphi.tpm[..., mechanism_node.state] # Marginalize-out all parents of this mechanism node that aren't in the # purview. 
- return tpm.marginalize_out((mechanism_node.inputs - purview)).tpm + return tpm.marginalize_out((mechanism_node.pyphi.inputs - purview)).tpm # TODO extend to nonbinary nodes @cache.method("_repertoire_cache", Direction.CAUSE) @@ -391,10 +391,10 @@ def _single_node_effect_repertoire( # pylint: disable=missing-docstring purview_node = self._index2node[purview_node_index] # Condition on the state of the purview inputs that are in the mechanism - tpm = purview_node.tpm.condition_tpm(condition) + tpm = purview_node.pyphi.tpm.condition_tpm(condition) # TODO(4.0) remove reference to TPM # Marginalize-out the inputs that aren't in the mechanism. - nonmechanism_inputs = purview_node.inputs - set(condition) + nonmechanism_inputs = purview_node.pyphi.inputs - set(condition) tpm = tpm.marginalize_out(nonmechanism_inputs) # Reshape so that the distribution is over next states. return tpm.reshape( diff --git a/pyphi/tpm.py b/pyphi/tpm.py index cabffa906..01f3a1e4b 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -565,7 +565,7 @@ def reconstitute_tpm(subsystem): # The last axis of the node TPMs correponds to ON or OFF probabilities # (used in the conditioning step when calculating the repertoires); we want # ON probabilities. 
- node_tpms = [node.tpm.tpm[..., 1] for node in subsystem.nodes] + node_tpms = [node.pyphi.tpm[..., 1] for node in subsystem.nodes] # Remove the singleton dimensions corresponding to external nodes node_tpms = [tpm.squeeze(axis=subsystem.external_indices) for tpm in node_tpms] # We add a new singleton axis at the end so that we can use diff --git a/setup.py b/setup.py index 1c81f6092..f6a4f6080 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "tblib >=1.3.2", "toolz >=0.9.0", "tqdm >=4.20.0", + "xarray >=2022.12.0" ] setup( diff --git a/test/test_node.py b/test/test_node.py index c357d5520..19a99072f 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -4,7 +4,7 @@ import numpy as np -from pyphi.node import Node, expand_node_tpm, generate_nodes +from pyphi.node import node, expand_node_tpm, generate_nodes from pyphi.subsystem import Subsystem from pyphi.tpm import ExplicitTPM @@ -32,26 +32,26 @@ def test_node_init_tpm(s): answer = [ExplicitTPM(tpm) for tpm in answer] # fmt: on for node in s.nodes: - assert node.tpm.array_equal(answer[node.index]) + assert node.pyphi.tpm.array_equal(answer[node.pyphi.index]) def test_node_init_inputs(s): answer = [s.node_indices[1:], s.node_indices[2:3], s.node_indices[:2]] for node in s.nodes: - assert set(node.inputs) == set(answer[node.index]) + assert set(node.pyphi.inputs) == set(answer[node.pyphi.index]) def test_node_eq(s): - assert s.nodes[1] == Node(s.tpm, s.cm, 1, 0, "B") + assert s.nodes[1] == node(s.tpm, s.cm, 1, 0, "B") def test_node_neq_by_index(s): - assert s.nodes[0] != Node(s.tpm, s.cm, 1, 0, "B") + assert s.nodes[0] != node(s.tpm, s.cm, 1, 0, "B") def test_node_neq_by_state(s): other_s = Subsystem(s.network, (1, 1, 1), s.node_indices) - assert other_s.nodes[1] != Node(s.tpm, s.cm, 1, 0, "B") + assert other_s.nodes[1] != node(s.tpm, s.cm, 1, 0, "B") def test_repr(s): @@ -94,10 +94,10 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[0].tpm.array_equal(node0_tpm) - assert nodes[0].inputs == 
set([1, 2]) - assert nodes[0].outputs == set([2]) - assert nodes[0].label == "A" + assert nodes[0].pyphi.tpm.array_equal(node0_tpm) + assert nodes[0].pyphi.inputs == set([1, 2]) + assert nodes[0].pyphi.outputs == set([2]) + assert nodes[0].pyphi.label == "A" # fmt: off node1_tpm = ExplicitTPM( @@ -107,10 +107,10 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[1].tpm.array_equal(node1_tpm) - assert nodes[1].inputs == set([2]) - assert nodes[1].outputs == set([0, 2]) - assert nodes[1].label == "B" + assert nodes[1].pyphi.tpm.array_equal(node1_tpm) + assert nodes[1].pyphi.inputs == set([2]) + assert nodes[1].pyphi.outputs == set([0, 2]) + assert nodes[1].pyphi.label == "B" # fmt: off node2_tpm = ExplicitTPM( @@ -122,12 +122,12 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[2].tpm.array_equal(node2_tpm) - assert nodes[2].inputs == set([0, 1]) - assert nodes[2].outputs == set([0, 1]) - assert nodes[2].label == "C" + assert nodes[2].pyphi.tpm.array_equal(node2_tpm) + assert nodes[2].pyphi.inputs == set([0, 1]) + assert nodes[2].pyphi.outputs == set([0, 1]) + assert nodes[2].pyphi.label == "C" def test_generate_nodes_default_labels(s): nodes = generate_nodes(s.tpm, s.cm, s.state, s.node_indices) - assert [n.label for n in nodes] == ["n0", "n1", "n2"] + assert [n.pyphi.label for n in nodes] == ["n0", "n1", "n2"] From db0a5055e1a56247081fef497dbc6a8c4f162166 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 10 Jan 2023 17:12:55 -0600 Subject: [PATCH 008/155] `Node`: Add support for multivalued and heterogeneous units --- docs/examples/magic_cut.rst | 12 +-- pyphi/macro.py | 9 +- pyphi/node.py | 186 +++++++++++++++++++++++++----------- test/test_macro.py | 2 +- 4 files changed, 145 insertions(+), 64 deletions(-) diff --git a/docs/examples/magic_cut.rst b/docs/examples/magic_cut.rst index 8575c5d92..b70b7ae54 100644 --- a/docs/examples/magic_cut.rst +++ b/docs/examples/magic_cut.rst @@ -118,15 +118,15 @@ into existence. 
>>> C = (2,) >>> AB = (0, 1) -The cut applied to the subsystem severs the connections going to |C| from -either |A| or |B|. In this circumstance, knowing the state of |A| or |B| does -not tell us anything about the state of |C|; only the previous state of |C| can -tell us about the next state of |C|. ``C_node.tpm_on`` gives us the probability -of |C| being ON in the next state, while ``C_node.tpm_off`` would give us the +The cut applied to the subsystem severs the connections going to |C| from either +|A| or |B|. In this circumstance, knowing the state of |A| or |B| does not tell +us anything about the state of |C|; only the previous state of |C| can tell us +about the next state of |C|. ``C_node.tpm[..., 1]`` gives us the probability of +|C| being ON in the next state, while ``C_node.tpm[..., 0]`` would give us the probability of |C| being OFF. >>> C_node = cut_subsystem.indices2nodes(C)[0] - >>> C_node.tpm_on.flatten() + >>> C_node.tpm[..., 1].flatten() array([0.5 , 0.75]) This states that |C| has a 50% chance of being ON in the next state if it diff --git a/pyphi/macro.py b/pyphi/macro.py index fef5153d7..b491e8225 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -77,7 +77,8 @@ def run_tpm(system, steps, blackbox): # boxes. node_tpms = [] for node in system.nodes: - node_tpm = node.pyphi.tpm_on + # TODO: nonbinary nodes. + node_tpm = node.pyphi.tpm[..., 1] for input_node in node.pyphi.inputs: if not blackbox.in_same_box(node.pyphi.index, input_node): if input_node in blackbox.output_indices: @@ -237,7 +238,8 @@ def _squeeze(system): nodes = generate_nodes(tpm, cm, state, node_indices) # Re-calcuate the tpm based on the results of the cut - tpm = rebuild_system_tpm(node.pyphi.tpm_on for node in nodes) + # TODO: nonbinary nodes. 
+ tpm = rebuild_system_tpm(node.pyphi.tpm[..., 1] for node in nodes) return SystemAttrs(tpm, cm, node_indices, state) @@ -247,7 +249,8 @@ def _blackbox_partial_noise(blackbox, system): # Noise inputs from non-output elements hidden in other boxes node_tpms = [] for node in system.nodes: - node_tpm = node.pyphi.tpm_on + # TODO: nonbinary nodes. + node_tpm = node.pyphi.tpm[..., 1] for input_node in node.pyphi.inputs: if blackbox.hidden_from(input_node, node.pyphi.index): node_tpm = node_tpm.marginalize_out([input_node]) diff --git a/pyphi/node.py b/pyphi/node.py index 175fd0cbb..2d5d4e7d0 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -18,21 +18,40 @@ from .tpm import ExplicitTPM -def node(tpm, cm, index, state=None, node_labels=None): +def node(tpm, cm, index, state=None, network_state_space=None, node_labels=None): """ Instantiate a DataArray node TPM. Args: - tpm (ExplicitTPM): The TPM of the subsystem. + tpm (pyphi.tpm.ExplicitTPM): The TPM of the subsystem. cm (np.ndarray): The CM of the subsystem. index (int): The node's index in the network. - state (int): The state of this node. - node_labels (|NodeLabels|): Labels for these nodes. + state (Optional[int]): The state of this node. + network_state_space (Optional[list[list[Union[int|str]]]]): Labels for + the state space of each node in the network. If ``None``, states + will be automatically labeled using zero-based integer indices. + node_labels (Optional[|NodeLabels|]): Labels for these nodes. """ - # Get indices of the inputs. + if network_state_space is None: + network_state_space = [list(range(dim)) for dim in tpm.shape[:-1]] + + # TODO: Move to validate.py. + else: + network_state_space_shape = tuple(map(len, network_state_space)) + + if network_state_space_shape != tpm.shape[:-1]: + raise ValueError( + "Mismatch between the network's TPM shape and the provided " + "state space labels." + ) + + # Get indices of the inputs and outputs. 
inputs = frozenset(get_inputs_from_cm(index, cm)) + outputs = frozenset(get_outputs_from_cm(index, cm)) + + # TODO This section is specific to explicit (and binary) network `tpm`s. # Generate the node's TPM. # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -43,7 +62,6 @@ def node(tpm, cm, index, state=None, node_labels=None): # the network nodes. tpm_on = tpm[..., index] - # TODO extend to nonbinary nodes # Marginalize out non-input nodes. # TODO use names rather than indices @@ -63,14 +81,25 @@ def node(tpm, cm, index, state=None, node_labels=None): ) # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - state_space = ["OFF", "ON"] + # Generate DataArray structure for this node + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The names of the n nodes in the network (whose state we can condition this + # node's TPM on), plus the last dimension with the probability ("Pr") for + # each possible state of this node in the next timestep. + dimensions = tuple(node_labels) + ("Pr",) + + # For each dimension, compute the relevant state labels (coordinates in + # xarray terminology) from the perspective of this node and its direct + # inputs. 
singleton_state_space = ["_marginalized_"] coordinates = [ - state_space if dim == 2 else singleton_state_space for dim in tpm.shape + network_state_space[node] if tpm.shape[node] > 1 + else singleton_state_space + for node in range(len(network_state_space)) ] - dimensions = tuple(node_labels) + ("Pr",) + coordinates.append(network_state_space[index]) return xr.DataArray( name = node_labels[index] if node_labels else str(index), @@ -78,81 +107,128 @@ def node(tpm, cm, index, state=None, node_labels=None): dims = dimensions, coords = coordinates, attrs = { - "cm": cm, "index": index, - "state": state, "node_labels": node_labels, + "inputs": inputs, + "outputs": outputs, + "state_space": coordinates[-1], + "state": state, } ) -# TODO extend to nonbinary nodes @xr.register_dataarray_accessor("pyphi") @functools.total_ordering class Node: """A node in a Network. + Args: + dataarray (xr.DataArray): + Attributes: - tpm (ExplicitTPM): The node TPM is an array with shape ``(2,)*(n + 1)``, - where ``n`` is the size of the |Network|. The first ``n`` - dimensions correspond to each node in the system. Dimensions - corresponding to nodes that provide input to this node are of size - 2, while those that do not correspond to inputs are of size 1, so - that the TPM has |2^m x 2| elements where |m| is the number of - inputs. The last dimension corresponds to the state of the node in - the next timestep, so that ``node.tpm[..., 0]`` gives probabilities - that the node will be 'OFF' and ``node.tpm[..., 1]`` gives - probabilities that the node will be 'ON'. + index (int): + label (str): + tpm (pyphi.tpm.ExplicitTPM): The node TPM is an array with shape + |n + 1| dimensions, where ``n`` is the size of the |Network|. The + first ``n`` dimensions correspond to each node in the + system. Dimensions corresponding to nodes that provide input to this + node are of size > 1, while those that do not correspond to inputs are + of size 1. 
The last dimension corresponds to the state of the + node in the next timestep, so that ``node.tpm[..., 0]`` gives + probabilities that the node will be 'OFF' and ``node.tpm[..., 1]`` + gives probabilities that the node will be 'ON'. + inputs (frozenset): + outputs (frozenset): + state_space (tuple[Union[int|str]]): + state (Optional[Union[int|str]]): """ def __init__(self, dataarray): - - # This node's index in the list of nodes. - self.index = dataarray.attrs["index"] - - # State of this node. - self.state = dataarray.attrs["state"] + self._index = dataarray.attrs["index"] # Node labels used in the system - self.node_labels = dataarray.attrs["node_labels"] + self._node_labels = dataarray.attrs["node_labels"] + + self._inputs = dataarray.attrs["inputs"] + self._outputs = dataarray.attrs["outputs"] - # Network connectivity matrix. - cm = dataarray.attrs["cm"] + self._tpm = dataarray.data - # Get indices of the inputs. - self._inputs = frozenset(get_inputs_from_cm(self.index, cm)) - self._outputs = frozenset(get_outputs_from_cm(self.index, cm)) + self.state_space = dataarray.attrs["state_space"] - self.tpm = dataarray.data + # (Optional) current state of this node. + self.state = dataarray.attrs["state"] # Only compute the hash once. 
self._hash = hash( - (self.index, hash(self.tpm), self.state, self._inputs, self._outputs) + ( + self.index, + hash(self.tpm), + self._inputs, + self._outputs, + self.state_space, + self.state + ) ) @property - def tpm_off(self): - """The TPM of this node containing only the 'OFF' probabilities.""" - return self.tpm[..., 0] + def index(self): + """int: The node's index in the network.""" + return self._index @property - def tpm_on(self): - """The TPM of this node containing only the 'ON' probabilities.""" - return self.tpm[..., 1] + def label(self): + """str: The textual label for this node.""" + return self.node_labels[self.index] + + @property + def tpm(self): + """pyphi.tpm.ExplicitTPM: The TPM of this node.""" + return self._tpm @property def inputs(self): - """The set of nodes with connections to this node.""" + """frozenset: The set of nodes with connections to this node.""" return self._inputs @property def outputs(self): - """The set of nodes this node has connections to.""" + """frozenset: The set of nodes this node has connections to.""" return self._outputs @property - def label(self): - """The textual label for this node.""" - return self.node_labels[self.index] + def state_space(self): + """tuple[Union[int|str]]: The space of states this node can inhabit.""" + return self._state_space + + @state_space.setter + def state_space(self, value): + state_space = tuple(value) + + if len(set(state_space)) < len(state_space): + raise ValueError( + "Invalid node state space tuple. Repeated states are ambiguous." + ) + + if len(state_space) < 2: + raise ValueError( + "Invalid node state space with less than 2 states." + ) + + self._state_space = state_space + + @property + def state(self): + """Optional[Union[int|str]]: The current state of this node.""" + return self._state + + @state.setter + def state(self, value): + if value not in self.state_space: + raise ValueError( + f"Invalid node state. Possible states are {self.state_space}." 
+ ) + + self._state = value def __repr__(self): return self.label @@ -163,19 +239,21 @@ def __str__(self): def __eq__(self, other): """Return whether this node equals the other object. - Two nodes are equal if they belong to the same subsystem and have the - same index (their TPMs must be the same in that case, so this method - doesn't need to check TPM equality). + Two nodes are equal if they have the same index, the same + inputs and outputs, the same TPM, the same state_space and the + same state. Labels are for display only, so two equal nodes may have different labels. + """ return ( self.index == other.index and self.tpm.array_equal(other.tpm) - and self.state == other.state and self.inputs == other.inputs and self.outputs == other.outputs + and self.state_space == other.state_space + and self.state == other.state ) def __ne__(self, other): @@ -206,7 +284,7 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None): node_labels (|NodeLabels|): Textual labels for each node. Returns: - tuple[Node]: The nodes of the system. + tuple[xr.DataArray]: The nodes of the system. """ if node_labels is None: node_labels = NodeLabels(None, indices) @@ -214,11 +292,11 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None): node_state = utils.state_of(indices, network_state) return tuple( - node(tpm, cm, index, state, node_labels) + node(tpm, cm, index, state=state, node_labels=node_labels) for index, state in zip(indices, node_state) ) - +# TODO: nonbinary nodes def expand_node_tpm(tpm): """Broadcast a node TPM over the full network. 
diff --git a/test/test_macro.py b/test/test_macro.py index 1d63985a9..a38ff8b78 100644 --- a/test/test_macro.py +++ b/test/test_macro.py @@ -302,7 +302,7 @@ def test_rebuild_system_tpm(s): # fmt: on assert macro.rebuild_system_tpm(node_tpms).array_equal(answer) - node_tpms = [node.tpm_on for node in s.nodes] + node_tpms = [node.tpm[..., 1] for node in s.nodes] assert macro.rebuild_system_tpm(node_tpms).array_equal(s.tpm) From b77efbb1962eb7063ae138a3081f364675f7e45b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 11 Jan 2023 13:36:57 -0600 Subject: [PATCH 009/155] `Node`: Fix docstring --- pyphi/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/node.py b/pyphi/node.py index 2d5d4e7d0..ef21740c7 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -127,7 +127,7 @@ class Node: Attributes: index (int): label (str): - tpm (pyphi.tpm.ExplicitTPM): The node TPM is an array with shape + tpm (pyphi.tpm.ExplicitTPM): The node TPM is an array with |n + 1| dimensions, where ``n`` is the size of the |Network|. The first ``n`` dimensions correspond to each node in the system. Dimensions corresponding to nodes that provide input to this From 890ed6212a210b94fcd8e9b4f8ad785a5038acf3 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 11 Jan 2023 13:40:14 -0600 Subject: [PATCH 010/155] `Node.node`: Singleton coords aren't necessarily due to marginalization --- pyphi/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/node.py b/pyphi/node.py index ef21740c7..e1ca54be1 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -91,7 +91,7 @@ def node(tpm, cm, index, state=None, network_state_space=None, node_labels=None) # For each dimension, compute the relevant state labels (coordinates in # xarray terminology) from the perspective of this node and its direct # inputs. 
- singleton_state_space = ["_marginalized_"] + singleton_state_space = ["_singleton_"] coordinates = [ network_state_space[node] if tpm.shape[node] > 1 From f31a38a3f15dfccbf94a15f6937515acac4f9e13 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 12 Jan 2023 17:12:29 -0600 Subject: [PATCH 011/155] `tpm`: Add abstract class `ImplicitTPM` and decorate for `xr.Dataset` --- pyphi/tpm.py | 96 ++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 90 insertions(+), 6 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 01f3a1e4b..f187ecad2 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -10,6 +10,7 @@ from typing import Mapping, Set import numpy as np +import xarray as xr from . import config, convert, data_structures, exceptions from .constants import OFF, ON @@ -536,12 +537,6 @@ def permute_nodes(self, permutation): self._tpm.transpose(dimension_permutation)[..., list(permutation)], ) - def __getitem__(self, i): - item = self._tpm[i] - if isinstance(item, type(self._tpm)): - item = type(self)(item) - return item - def array_equal(self, o: object): """Return whether this TPM equals the other object. @@ -550,6 +545,12 @@ def array_equal(self, o: object): """ return isinstance(o, type(self)) and np.array_equal(self._tpm, o._tpm) + def __getitem__(self, i): + item = self._tpm[i] + if isinstance(item, type(self._tpm)): + item = type(self)(item) + return item + def __str__(self): return self.__repr__() @@ -560,6 +561,89 @@ def __hash__(self): return self._hash +def implicit_tpm(nodes, validate=False): + + """Instantiate an implicit network TPM Dataset.""" + + return xr.Dataset( + data_vars = {node.name: node for node in nodes} + ) + +@xr.register_dataset_accessor("pyphi") +class ImplicitTPM: + + """An implicit network TPM containing |Node| TPMs in multidimensional form. 
+ + Args: + dataset (xr.Dataset): + + Attributes: + """ + + def validate(self, check_independence): + raise NotImplementedError + + def _validate_probabilities(self): + raise NotImplementedError + + def _validate_shape(self, check_independence=True): + raise NotImplementedError + + def to_multidimensional_state_by_node(self): + raise NotImplementedError + + def conditionally_independent(self): + raise NotImplementedError + + def condition_tpm(self, condition: Mapping[int, int]): + raise NotImplementedError + + def marginalize_out(self, node_indices): + raise NotImplementedError + + def is_deterministic(self): + raise NotImplementedError + + def is_state_by_state(self): + """Return ``True`` if ``tpm`` is in state-by-state form, otherwise + ``False``. + """ + return False + + def subtpm(self, fixed_nodes, state): + raise NotImplementedError + + def expand_tpm(self): + raise NotImplementedError + + def infer_edge(self, a, b, contexts): + raise NotImplementedError + + def infer_cm(self): + raise NotImplementedError + + def tpm_indices(self): + raise NotImplementedError + + def print(self): + raise NotImplementedError + + def permute_nodes(self, permutation): + raise NotImplementedError + + def __getitem__(self, i): + raise NotImplementedError + + def __str__(self): + raise NotImplementedError + + def __repr__(self): + raise NotImplementedError + + def __hash__(self): + raise NotImplementedError + + def reconstitute_tpm(subsystem): """Reconstitute the TPM of a subsystem using the individual node TPMs.""" # The last axis of the node TPMs correponds to ON or OFF probabilities From 5c8c62628e09aa393116b2b0bb30fa19730a09c7 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 13 Jan 2023 17:36:30 -0600 Subject: [PATCH 012/155] Refactor `Network` and `Node` for future implicit TPM support --- pyphi/constants.py | 3 + pyphi/macro.py | 49 +++++++-- pyphi/network.py | 107 ++++++++++++++++--- pyphi/node.py | 259 +++++++++++++++++++++++++-------------------- pyphi/subsystem.py | 
15 ++- test/test_node.py | 13 ++- 6 files changed, 307 insertions(+), 139 deletions(-) diff --git a/pyphi/constants.py b/pyphi/constants.py index ea26b536b..103d501ee 100644 --- a/pyphi/constants.py +++ b/pyphi/constants.py @@ -21,3 +21,6 @@ # Probability value below which we issue a warning about precision. # TODO(4.0) TPM_WARNING_THRESHOLD = 1e-10 + +# State space used for singleton dimensions in per-node TPMs. +SINGLETON_STATE = ("_",) diff --git a/pyphi/macro.py b/pyphi/macro.py index b491e8225..c90dd54ec 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -16,7 +16,7 @@ from . import compute, config, convert, distribution, utils, validate from .exceptions import ConditionallyDependentError, StateUnreachableError from .labels import NodeLabels -from .network import irreducible_purviews +from .network import irreducible_purviews, build_state_space from .node import expand_node_tpm, generate_nodes from .subsystem import Subsystem from .tpm import ExplicitTPM @@ -97,7 +97,11 @@ def run_tpm(system, steps, blackbox): return ExplicitTPM(convert.state_by_state2state_by_node(tpm), validate=True) -class SystemAttrs(namedtuple("SystemAttrs", ["tpm", "cm", "node_indices", "state"])): +class SystemAttrs( + namedtuple( + "SystemAttrs", ["tpm", "cm", "node_indices", "state", "state_space"] + ) +): """An immutable container that holds all the attributes of a subsystem. 
Versions of this object are passed down the steps of the micro-to-macro @@ -114,12 +118,23 @@ def node_labels(self): @property def nodes(self): return generate_nodes( - self.tpm, self.cm, self.state, self.node_indices, self.node_labels + self.tpm, + self.cm, + self.state_space, + self.node_indices, + network_state=self.state, + node_labels=self.node_labels ) @staticmethod def pack(system): - return SystemAttrs(system.tpm, system.cm, system.node_indices, system.state) + return SystemAttrs( + system.tpm, + system.cm, + system.node_indices, + system.state, + system.state_space, + ) def apply(self, system): system.tpm = self.tpm @@ -128,6 +143,7 @@ def apply(self, system): system.node_labels = self.node_labels system.nodes = self.nodes system.state = self.state + system.state_space = self.state_space class MacroSubsystem(Subsystem): @@ -233,15 +249,23 @@ def _squeeze(system): state = utils.state_of(internal_indices, system.state) + state_space = build_state_space(system.tpm[:-1], system.state_space) + # Re-index the subsystem nodes with the external nodes removed node_indices = reindex(internal_indices) - nodes = generate_nodes(tpm, cm, state, node_indices) + nodes = generate_nodes( + tpm, + cm, + state_space, + node_indices, + network_state=state + ) # Re-calcuate the tpm based on the results of the cut # TODO: nonbinary nodes. tpm = rebuild_system_tpm(node.pyphi.tpm[..., 1] for node in nodes) - return SystemAttrs(tpm, cm, node_indices, state) + return SystemAttrs(tpm, cm, node_indices, state, state_space) @staticmethod def _blackbox_partial_noise(blackbox, system): @@ -272,7 +296,13 @@ def _blackbox_time(time_scale, blackbox, system): n = len(system.node_indices) cm = np.ones((n, n)) - return SystemAttrs(tpm, cm, system.node_indices, system.state) + return SystemAttrs( + tpm, + cm, + system.node_indices, + system.state, + system.state_space + ) def _blackbox_space(self, blackbox, system): """Blackbox the TPM and CM in space. 
@@ -290,7 +320,8 @@ def _blackbox_space(self, blackbox, system): assert blackbox.output_indices == tpm.tpm_indices() - tpm = remove_singleton_dimensions(tpm) + new_tpm = remove_singleton_dimensions(tpm) + state_space = build_state_space(tpm[:-1], system.state_space) n = len(blackbox) cm = np.zeros((n, n)) for i, j in itertools.product(range(n), repeat=2): @@ -303,7 +334,7 @@ def _blackbox_space(self, blackbox, system): state = blackbox.macro_state(system.state) node_indices = blackbox.macro_indices - return SystemAttrs(tpm, cm, node_indices, state) + return SystemAttrs(new_tpm, cm, node_indices, state, state_space) @staticmethod def _coarsegrain_space(coarse_grain, is_cut, system): diff --git a/pyphi/network.py b/pyphi/network.py index a750398d5..6212cac3d 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -9,9 +9,12 @@ import numpy as np +from typing import Iterable, Optional, Union + from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels -from .tpm import ExplicitTPM +from .node import generate_nodes +from .tpm import ExplicitTPM, ImplicitTPM, implicit_tpm class Network: @@ -31,27 +34,61 @@ class Network: is connected to every node (including itself)**. node_labels (tuple[str] or |NodeLabels|): Human-readable labels for each node in the network. + state_space (Optional[tuple[tuple[Union[int|str]]]]): + Labels for the state space of each node in the network. If ``None``, + states will be automatically labeled using a zero-based integer + index per node. """ # TODO make tpm also optional when implementing logical network definition - def __init__(self, tpm, cm=None, node_labels=None, purview_cache=None): + def __init__( + self, + tpm, + cm=None, + node_labels=None, + state_space=None, + purview_cache=None + ): + self._cm, self._cm_hash = self._build_cm(cm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + # Initialize _tpm according to argument type. 
- if isinstance(tpm, ExplicitTPM): + + if isinstance(tpm, (np.ndarray, ExplicitTPM)): + # Validate tpm even if an ExplicitTPM was provided. ExplicitTPM + # accepts instantiation from either another object of its class or + # np.ndarray, so the following achieves validation in general. + tpm = ExplicitTPM(tpm, validate=True) + + self._state_space, _ = build_state_space( + tpm.shape[:-1], state_space + ) + + nodes = generate_nodes( + tpm, + self._cm, + self._state_space, + self._node_indices, + node_labels=self._node_labels + ) + + self._tpm = nodes + + elif isinstance(tpm, ImplicitTPM): self._tpm = tpm - elif isinstance(tpm, np.ndarray): - self._tpm = ExplicitTPM(tpm, validate=True) + elif isinstance(tpm, dict): # From JSON. - self._tpm = ExplicitTPM(tpm["_tpm"], validate=True) + self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) + else: raise TypeError(f"Invalid tpm of type {type(tpm)}.") - self._cm, self._cm_hash = self._build_cm(cm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) self.purview_cache = purview_cache or cache.PurviewCache() - validate.network(self) + # TODO + # validate.network(self) @property def tpm(self): @@ -99,11 +136,18 @@ def size(self): """int: The number of nodes in the network.""" return len(self) - # TODO extend to nonbinary nodes + @property + def state_space(self): + """tuple[tuple[Union[int|str]]: Labels for the state space of each node. 
+ """ + return self._state_space + @property def num_states(self): """int: The number of possible states of the network.""" - return 2 ** self.size + return np.prod( + [len(node_states) for node_states in self._state_space] + ) @property def node_indices(self): @@ -138,10 +182,12 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - return self.tpm.shape[-1] + return self.cm.shape[0] def __repr__(self): - return "Network({}, cm={})".format(self.tpm, self.cm) + # TODO implement a cleaner repr, similar to analyses objects, + # distinctions, etc. + return "Network(\n{},\ncm=\n{}\n)".format(self.tpm, self.cm) def __str__(self): return self.__repr__() @@ -179,6 +225,39 @@ def from_json(cls, json_dict): return Network(**json_dict) +def build_state_space( + nodes_shape: Iterable[int], + state_space: Optional[Iterable[Iterable[Union[int|str]]]] = None, +) -> tuple[tuple[tuple[Union[int|str]]], int]: + """Format the passed state space labels or construct defaults if none. + + Arguments: + nodes_shape (Iterable[int]): The first |n| components in the shape of + a network's multidimensional TPM, where |n| is the number of nodes. + + Keyword Args: + state_space (Optional[Iterable[Iterable[Union[int|str]]]]): The + network's state space labels as provided by the user. + + Returns: + tuple[tuple[tuple[Union[int|str]]], int]: State space for the network of + interest and its hash. + """ + if state_space is None: + state_space = tuple(tuple(range(dim)) for dim in nodes_shape) + else: + # Enforce tuple. + state_space = tuple(map(tuple, state_space)) + # Filter out states of singleton dimensions. + shape_state_map = zip(nodes_shape, state_space) + state_space = tuple( + node_states for dim, node_states in shape_state_map + if dim > 1 + ) + + return (state_space, hash(state_space)) + + def irreducible_purviews(cm, direction, mechanism, purviews): """Return all purviews which are irreducible for the mechanism. 
diff --git a/pyphi/node.py b/pyphi/node.py index e1ca54be1..680d3231c 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -9,113 +9,17 @@ import functools +from typing import Optional, Union + import numpy as np import xarray as xr from . import utils from .connectivity import get_inputs_from_cm, get_outputs_from_cm +from .constants import SINGLETON_STATE from .labels import NodeLabels from .tpm import ExplicitTPM - -def node(tpm, cm, index, state=None, network_state_space=None, node_labels=None): - - """ - Instantiate a DataArray node TPM. - - Args: - tpm (pyphi.tpm.ExplicitTPM): The TPM of the subsystem. - cm (np.ndarray): The CM of the subsystem. - index (int): The node's index in the network. - state (Optional[int]): The state of this node. - network_state_space (Optional[list[list[Union[int|str]]]]): Labels for - the state space of each node in the network. If ``None``, states - will be automatically labeled using zero-based integer indices. - node_labels (Optional[|NodeLabels|]): Labels for these nodes. - """ - - if network_state_space is None: - network_state_space = [list(range(dim)) for dim in tpm.shape[:-1]] - - # TODO: Move to validate.py. - else: - network_state_space_shape = tuple(map(len, network_state_space)) - - if network_state_space_shape != tpm.shape[:-1]: - raise ValueError( - "Mismatch between the network's TPM shape and the provided " - "state space labels." - ) - - # Get indices of the inputs and outputs. - inputs = frozenset(get_inputs_from_cm(index, cm)) - outputs = frozenset(get_outputs_from_cm(index, cm)) - - # TODO This section is specific to explicit (and binary) network `tpm`s. - - # Generate the node's TPM. - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We begin by getting the part of the subsystem's TPM that gives just - # the state of this node. 
This part is still indexed by network state, - # but its last dimension will be gone, since now there's just a single - # scalar value (this node's state) rather than a state-vector for all - # the network nodes. - tpm_on = tpm[..., index] - - # Marginalize out non-input nodes. - - # TODO use names rather than indices - non_inputs = set(tpm.tpm_indices()) - inputs - tpm_on = tpm_on.marginalize_out(non_inputs).tpm - - # Get the TPM that gives the probability of the node being off, rather - # than on. - tpm_off = 1 - tpm_on - - # Combine the on- and off-TPM so that the first dimension is indexed by - # the state of the node's inputs at t, and the last dimension is - # indexed by the node's state at t+1. This representation makes it easy - # to condition on the node state. - tpm = ExplicitTPM( - np.stack([tpm_off, tpm_on], axis=-1) - ) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - # Generate DataArray structure for this node - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # The names of the n nodes in the network (whose state we can condition this - # node's TPM on), plus the last dimension with the probability ("Pr") for - # each possible state of this node in the next timestep. - dimensions = tuple(node_labels) + ("Pr",) - - # For each dimension, compute the relevant state labels (coordinates in - # xarray terminology) from the perspective of this node and its direct - # inputs. 
- singleton_state_space = ["_singleton_"] - - coordinates = [ - network_state_space[node] if tpm.shape[node] > 1 - else singleton_state_space - for node in range(len(network_state_space)) - ] - - coordinates.append(network_state_space[index]) - - return xr.DataArray( - name = node_labels[index] if node_labels else str(index), - data = tpm, - dims = dimensions, - coords = coordinates, - attrs = { - "index": index, - "node_labels": node_labels, - "inputs": inputs, - "outputs": outputs, - "state_space": coordinates[-1], - "state": state, - } - ) - @xr.register_dataarray_accessor("pyphi") @functools.total_ordering class Node: @@ -127,7 +31,7 @@ class Node: Attributes: index (int): label (str): - tpm (pyphi.tpm.ExplicitTPM): The node TPM is an array with + tpm (|ExplicitTPM|): The node TPM is an array with |n + 1| dimensions, where ``n`` is the size of the |Network|. The first ``n`` dimensions correspond to each node in the system. Dimensions corresponding to nodes that provide input to this @@ -142,7 +46,7 @@ class Node: state (Optional[Union[int|str]]): """ - def __init__(self, dataarray): + def __init__(self, dataarray: xr.DataArray): self._index = dataarray.attrs["index"] # Node labels used in the system @@ -182,7 +86,7 @@ def label(self): @property def tpm(self): - """pyphi.tpm.ExplicitTPM: The TPM of this node.""" + """|ExplicitTPM|: The TPM of this node.""" return self._tpm @property @@ -271,16 +175,103 @@ def to_json(self): return self.index -def generate_nodes(tpm, cm, network_state, indices, node_labels=None): - """Generate |Node| objects for a subsystem. +def node( + tpm: ExplicitTPM, + cm: np.ndarray, + network_state_space: tuple[tuple[Union[int|str]]], + index: int, + state: Optional[Union[int|str]] = None, + node_labels: Optional[NodeLabels] = None +) -> xr.DataArray: + """ + Instantiate a node TPM DataArray. Args: - tpm (ExplicitTPM): The system's TPM - cm (np.ndarray): The corresponding CM. - network_state (tuple): The state of the network. 
+ tpm (|ExplicitTPM|): The TPM of this node. + cm (np.ndarray): The CM of the network. + network_state_space (tuple[tuple[Union[int|str]]]): Labels for the state + space of each node in the network. + index (int): The node's index in the network. + + Keyword Args: + state (Optional[Union[int|str]]): The state of this node. + node_labels (Optional[|NodeLabels|]): Labels for these nodes. + + Returns: + xr.DataArray: The node in question. + """ + + if node_labels is None: + indices = tuple(range(cm.shape[0])) + node_labels = NodeLabels(None, indices) + + # Get indices of the inputs and outputs. + inputs = frozenset(get_inputs_from_cm(index, cm)) + outputs = frozenset(get_outputs_from_cm(index, cm)) + + # Generate DataArray structure for this node + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # The names of the n nodes in the network (whose state we can condition this + # node's TPM on), plus the last dimension with the probability ("Pr") for + # each possible state of this node in the next timestep. + + # Note that xr.Dataset disallows shared names between xr.DataArray elements + # in data_vars (node names) and dimension names. Therefore we prepend + # "input_" to avoid the conflict. + dimensions = ["input_" + label for label in node_labels] + ["Pr"] + + # For each dimension, compute the relevant state labels (coordinates in + # xarray terminology) from the perspective of this node and its direct + # inputs. + singleton_state_space = list(SINGLETON_STATE) + + coordinates = [ + list(network_state_space[node]) if tpm.shape[node] > 1 + else singleton_state_space + for node in range(len(network_state_space)) + ] + + # Append coordinates for the last dimension ("Pr"). + coordinates.append(list(network_state_space[index])) + + # TODO(tpm) implement np.result_type() in + # data_structures.array_like.__array_function__ to avoid converting with + # np.asarray(). 
+ return xr.DataArray( + name = node_labels[index], + data = np.asarray(tpm), + dims = dimensions, + coords = coordinates, + attrs = { + "index": index, + "node_labels": node_labels, + "inputs": inputs, + "outputs": outputs, + "state_space": coordinates[-1], + "state": state, + } + ) + + +def generate_nodes( + tpm: ExplicitTPM, + cm: np.ndarray, + state_space: tuple[tuple[Union[int|str]]], + indices: tuple[int], + network_state: Optional[tuple[Union[int|str]]] = None, + node_labels: Optional[NodeLabels] = None +) -> tuple[xr.DataArray]: + """Generate |Node| objects out of a binary network |ExplicitTPM|. + + Args: + tpm (|ExplicitTPM|): The system's TPM. + cm (np.ndarray): The CM of the network. + state_space (tuple[tuple[Union[int|str]]]): Labels for the state + space of each node in the network. indices (tuple[int]): Indices to generate nodes for. Keyword Args: + network_state (Optional[tuple[Union[int|str]]]): The state of the network. node_labels (|NodeLabels|): Textual labels for each node. Returns: @@ -289,19 +280,63 @@ def generate_nodes(tpm, cm, network_state, indices, node_labels=None): if node_labels is None: node_labels = NodeLabels(None, indices) + if network_state is None: + network_state = (None,) * cm.shape[0] + node_state = utils.state_of(indices, network_state) - return tuple( - node(tpm, cm, index, state=state, node_labels=node_labels) - for index, state in zip(indices, node_state) - ) + nodes = [] + + for index, state in zip(indices, node_state): + # Generate the node's TPM. + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We begin by getting the part of the subsystem's TPM that gives just + # the state of this node. This part is still indexed by network state, + # but its last dimension will be gone, since now there's just a single + # scalar value (this node's state) rather than a state-vector for all + # the network nodes. + tpm_on = tpm[..., index] + + # Marginalize out non-input nodes. 
+ + # TODO use names rather than indices + inputs = frozenset(get_inputs_from_cm(index, cm)) + non_inputs = set(tpm.tpm_indices()) - inputs + tpm_on = tpm_on.marginalize_out(non_inputs).tpm + + # Get the TPM that gives the probability of the node being off, rather + # than on. + tpm_off = 1 - tpm_on + + # Combine the on- and off-TPM so that the first dimension is indexed by + # the state of the node's inputs at t, and the last dimension is + # indexed by the node's state at t+1. This representation makes it easy + # to condition on the node state. + node_tpm = ExplicitTPM( + np.stack([tpm_off, tpm_on], axis=-1) + ) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + nodes.append( + node( + node_tpm, + cm, + state_space, + index, + state=state, + node_labels=node_labels + ) + ) + + return tuple(nodes) + # TODO: nonbinary nodes def expand_node_tpm(tpm): """Broadcast a node TPM over the full network. Args: - tpm (ExplicitTPM): The node TPM to expand. + tpm (|ExplicitTPM|): The node TPM to expand. This is different from broadcasting the TPM of a full system since the last dimension (containing the state of the node) contains only the probability diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 7070fc937..0885c0602 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -29,7 +29,7 @@ _null_ria, ) from .models.mechanism import StateSpecification -from .network import irreducible_purviews +from .network import irreducible_purviews, build_state_space from .node import generate_nodes from .partition import mip_partitions from .repertoire import forward_repertoire, unconstrained_forward_repertoire @@ -114,6 +114,12 @@ def __init__( # The TPM for just the nodes in the subsystem. self.proper_tpm = self.tpm.squeeze()[..., list(self.node_indices)] + # The state space of the nodes in the candidate system. 
+ self.proper_state_space = build_state_space( + self.tpm[:-1], + self.network.state_space + ) + # The unidirectional cut applied for phi evaluation self.cut = ( cut if cut is not None else NullCut(self.node_indices, self.node_labels) @@ -140,7 +146,12 @@ def __init__( ) self.nodes = generate_nodes( - self.tpm, self.cm, self.state, self.node_indices, self.node_labels + self.tpm, + self.cm, + self.proper_state_space, + self.node_indices, + network_state=self.state, + node_labels=self.node_labels ) validate.subsystem(self) diff --git a/test/test_node.py b/test/test_node.py index 19a99072f..abdd3d8ad 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -82,7 +82,14 @@ def test_expand_tpm(): def test_generate_nodes(s): - nodes = generate_nodes(s.tpm, s.cm, s.state, s.node_indices, s.node_labels) + nodes = generate_nodes( + s.tpm, + s.cm, + s.state_space, + s.node_indices, + network_state=s.state, + node_labels=s.node_labels + ) # fmt: off node0_tpm = ExplicitTPM( @@ -129,5 +136,7 @@ def test_generate_nodes(s): def test_generate_nodes_default_labels(s): - nodes = generate_nodes(s.tpm, s.cm, s.state, s.node_indices) + nodes = generate_nodes( + s.tpm, s.cm, s.state_space, s.node_indices, network_state=s.state + ) assert [n.pyphi.label for n in nodes] == ["n0", "n1", "n2"] From 7e43f6f6e46768c145c5282684aab644917f5fcb Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 16 Jan 2023 15:29:18 -0600 Subject: [PATCH 013/155] Send duplicate state space filtering to `utils.build_state_space()` --- pyphi/macro.py | 7 +++--- pyphi/network.py | 36 +---------------------------- pyphi/node.py | 56 ++++++++++++++++++++++------------------------ pyphi/subsystem.py | 6 ++--- pyphi/utils.py | 51 ++++++++++++++++++++++++++++++++++++++++- 5 files changed, 85 insertions(+), 71 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index c90dd54ec..60b69cb1f 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -16,10 +16,11 @@ from . 
import compute, config, convert, distribution, utils, validate from .exceptions import ConditionallyDependentError, StateUnreachableError from .labels import NodeLabels -from .network import irreducible_purviews, build_state_space +from .network import irreducible_purviews from .node import expand_node_tpm, generate_nodes from .subsystem import Subsystem from .tpm import ExplicitTPM +from .utils import build_state_space # Create a logger for this module. log = logging.getLogger(__name__) @@ -249,7 +250,7 @@ def _squeeze(system): state = utils.state_of(internal_indices, system.state) - state_space = build_state_space(system.tpm[:-1], system.state_space) + state_space, _ = build_state_space(system.tpm[:-1], system.state_space) # Re-index the subsystem nodes with the external nodes removed node_indices = reindex(internal_indices) @@ -321,7 +322,7 @@ def _blackbox_space(self, blackbox, system): assert blackbox.output_indices == tpm.tpm_indices() new_tpm = remove_singleton_dimensions(tpm) - state_space = build_state_space(tpm[:-1], system.state_space) + state_space, _ = build_state_space(tpm[:-1], system.state_space) n = len(blackbox) cm = np.zeros((n, n)) for i, j in itertools.product(range(n), repeat=2): diff --git a/pyphi/network.py b/pyphi/network.py index 6212cac3d..89999b2a4 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -9,12 +9,11 @@ import numpy as np -from typing import Iterable, Optional, Union - from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels from .node import generate_nodes from .tpm import ExplicitTPM, ImplicitTPM, implicit_tpm +from .utils import build_state_space class Network: @@ -225,39 +224,6 @@ def from_json(cls, json_dict): return Network(**json_dict) -def build_state_space( - nodes_shape: Iterable[int], - state_space: Optional[Iterable[Iterable[Union[int|str]]]] = None, -) -> tuple[tuple[tuple[Union[int|str]]], int]: - """Format the passed state space labels or construct defaults if none. 
- - Arguments: - nodes_shape (Iterable[int]): The first |n| components in the shape of - a network's multidimensional TPM, where |n| is the number of nodes. - - Keyword Args: - state_space (Optional[Iterable[Iterable[Union[int|str]]]]): The - network's state space labels as provided by the user. - - Returns: - tuple[tuple[tuple[Union[int|str]]], int]: State space for the network of - interest and its hash. - """ - if state_space is None: - state_space = tuple(tuple(range(dim)) for dim in nodes_shape) - else: - # Enforce tuple. - state_space = tuple(map(tuple, state_space)) - # Filter out states of singleton dimensions. - shape_state_map = zip(nodes_shape, state_space) - state_space = tuple( - node_states for dim, node_states in shape_state_map - if dim > 1 - ) - - return (state_space, hash(state_space)) - - def irreducible_purviews(cm, direction, mechanism, purviews): """Return all purviews which are irreducible for the mechanism. diff --git a/pyphi/node.py b/pyphi/node.py index 680d3231c..20b278aec 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -9,16 +9,16 @@ import functools -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np import xarray as xr -from . import utils from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .constants import SINGLETON_STATE from .labels import NodeLabels from .tpm import ExplicitTPM +from .utils import build_state_space, state_of @xr.register_dataarray_accessor("pyphi") @functools.total_ordering @@ -42,7 +42,7 @@ class Node: gives probabilities that the node will be 'ON'. 
inputs (frozenset): outputs (frozenset): - state_space (tuple[Union[int|str]]): + state_space (Tuple[Union[int|str]]): state (Optional[Union[int|str]]): """ @@ -101,7 +101,7 @@ def outputs(self): @property def state_space(self): - """tuple[Union[int|str]]: The space of states this node can inhabit.""" + """Tuple[Union[int|str]]: The space of states this node can inhabit.""" return self._state_space @state_space.setter @@ -178,7 +178,7 @@ def to_json(self): def node( tpm: ExplicitTPM, cm: np.ndarray, - network_state_space: tuple[tuple[Union[int|str]]], + network_state_space: Tuple[Tuple[Union[int|str]]], index: int, state: Optional[Union[int|str]] = None, node_labels: Optional[NodeLabels] = None @@ -189,7 +189,7 @@ def node( Args: tpm (|ExplicitTPM|): The TPM of this node. cm (np.ndarray): The CM of the network. - network_state_space (tuple[tuple[Union[int|str]]]): Labels for the state + network_state_space (Tuple[Tuple[Union[int|str]]]): Labels for the state space of each node in the network. index (int): The node's index in the network. @@ -215,24 +215,22 @@ def node( # node's TPM on), plus the last dimension with the probability ("Pr") for # each possible state of this node in the next timestep. - # Note that xr.Dataset disallows shared names between xr.DataArray elements - # in data_vars (node names) and dimension names. Therefore we prepend - # "input_" to avoid the conflict. + # data_vars (xr.DataArray node names) and dimension names share the same + # dictionary-like namespace in xr.Dataset. Prepend constant "input_" string + # to avoid the conflict. dimensions = ["input_" + label for label in node_labels] + ["Pr"] # For each dimension, compute the relevant state labels (coordinates in # xarray terminology) from the perspective of this node and its direct # inputs. 
- singleton_state_space = list(SINGLETON_STATE) - - coordinates = [ - list(network_state_space[node]) if tpm.shape[node] > 1 - else singleton_state_space - for node in range(len(network_state_space)) - ] + state_space, _ = build_state_space( + tpm.shape[:-1], + network_state_space, + SINGLETON_STATE + ) + node_state_space = network_state_space[index] - # Append coordinates for the last dimension ("Pr"). - coordinates.append(list(network_state_space[index])) + coordinates = [*state_space, node_state_space] # TODO(tpm) implement np.result_type() in # data_structures.array_like.__array_function__ to avoid converting with @@ -241,13 +239,13 @@ def node( name = node_labels[index], data = np.asarray(tpm), dims = dimensions, - coords = coordinates, + coords = list(map(list, coordinates)), attrs = { "index": index, "node_labels": node_labels, "inputs": inputs, "outputs": outputs, - "state_space": coordinates[-1], + "state_space": node_state_space, "state": state, } ) @@ -256,26 +254,26 @@ def node( def generate_nodes( tpm: ExplicitTPM, cm: np.ndarray, - state_space: tuple[tuple[Union[int|str]]], - indices: tuple[int], - network_state: Optional[tuple[Union[int|str]]] = None, + state_space: Tuple[Tuple[Union[int|str]]], + indices: Tuple[int], + network_state: Optional[Tuple[Union[int|str]]] = None, node_labels: Optional[NodeLabels] = None -) -> tuple[xr.DataArray]: +) -> Tuple[xr.DataArray]: """Generate |Node| objects out of a binary network |ExplicitTPM|. Args: tpm (|ExplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. - state_space (tuple[tuple[Union[int|str]]]): Labels for the state + state_space (Tuple[Tuple[Union[int|str]]]): Labels for the state space of each node in the network. - indices (tuple[int]): Indices to generate nodes for. + indices (Tuple[int]): Indices to generate nodes for. Keyword Args: - network_state (Optional[tuple[Union[int|str]]]): The state of the network. 
+ network_state (Optional[Tuple[Union[int|str]]]): The state of the network. node_labels (|NodeLabels|): Textual labels for each node. Returns: - tuple[xr.DataArray]: The nodes of the system. + Tuple[xr.DataArray]: The nodes of the system. """ if node_labels is None: node_labels = NodeLabels(None, indices) @@ -283,7 +281,7 @@ def generate_nodes( if network_state is None: network_state = (None,) * cm.shape[0] - node_state = utils.state_of(indices, network_state) + node_state = state_of(indices, network_state) nodes = [] diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 0885c0602..bc5d13873 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -29,12 +29,12 @@ _null_ria, ) from .models.mechanism import StateSpecification -from .network import irreducible_purviews, build_state_space +from .network import irreducible_purviews from .node import generate_nodes from .partition import mip_partitions from .repertoire import forward_repertoire, unconstrained_forward_repertoire from .tpm import ExplicitTPM -from .utils import state_of +from .utils import build_state_space, state_of log = logging.getLogger(__name__) @@ -115,7 +115,7 @@ def __init__( self.proper_tpm = self.tpm.squeeze()[..., list(self.node_indices)] # The state space of the nodes in the candidate system. 
- self.proper_state_space = build_state_space( + self.proper_state_space, _ = build_state_space( self.tpm[:-1], self.network.state_space ) diff --git a/pyphi/utils.py b/pyphi/utils.py index d7b22a9d9..5130a72e9 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -12,7 +12,7 @@ import operator import os from itertools import chain, combinations, product -from typing import Tuple +from typing import Iterable, Optional, Union, Tuple import numpy as np from scipy.special import comb @@ -79,6 +79,55 @@ def state_of_subsystem_nodes(node_indices, nodes, subsystem_state): return state_of([node_indices.index(n) for n in nodes], subsystem_state) +def build_state_space( + nodes_shape: Iterable[int], + state_space: Optional[Iterable[Iterable[Union[int|str]]]] = None, + singleton_state_space: Optional[Iterable[Union[int|str]]] = None, +) -> Tuple[Tuple[Tuple[Union[int|str]]], int]: + """Format the passed state space labels or construct defaults if none. + + Arguments: + nodes_shape (Iterable[int]): The first |n| components in the shape of + a multidimensional |ExplicitTPM|, where |n| is the number of nodes + in the network. + + Keyword Args: + state_space (Optional[Iterable[Iterable[Union[int|str]]]]): The + network's state space labels as provided by the user. + singleton_state_space (Optional[Iterable[Union[int|str]]]): The label to + be used for singleton dimensions. If ``None``, singleton dimensions + will be discarded. + + Returns: + Tuple[Tuple[Tuple[Union[int|str]]], int]: State space for the network of + interest and its hash. + """ + if state_space is None: + state_space = tuple(tuple(range(dim)) for dim in nodes_shape) + else: + # Enforce tuples. + state_space = map(tuple, state_space) + + # Filter out states of singleton dimensions. 
+ shape_state_map = zip(nodes_shape, state_space) + + if singleton_state_space is None: + state_space = tuple( + node_states + for dim, node_states in shape_state_map + if dim > 1 + ) + + else: + state_space = tuple( + node_states if dim > 1 else singleton_state_space + for dim, node_states in shape_state_map + ) + + return (state_space, hash(state_space)) + + +# TODO: nonbinary states def all_states(n, big_endian=False): """Return all binary states for a system. From b3687c79c7a6656c6743f963d1911e2d18b4e655 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 16 Jan 2023 16:25:06 -0600 Subject: [PATCH 014/155] `Network.__init__`: Avoid transient storage of nodes beyond the TPM --- pyphi/network.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 89999b2a4..6d3ac10ca 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -64,16 +64,16 @@ def __init__( tpm.shape[:-1], state_space ) - nodes = generate_nodes( - tpm, - self._cm, - self._state_space, - self._node_indices, - node_labels=self._node_labels + self._tpm = implicit_tpm( + generate_nodes( + tpm, + self._cm, + self._state_space, + self._node_indices, + node_labels=self._node_labels + ) ) - self._tpm = nodes - elif isinstance(tpm, ImplicitTPM): self._tpm = tpm From 34e7032522cc73e53866633908bb0dfa4306df9b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 16 Jan 2023 16:29:09 -0600 Subject: [PATCH 015/155] `Network.to_json` Include state_space attribute. 
--- pyphi/network.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pyphi/network.py b/pyphi/network.py index 6d3ac10ca..e222151f0 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -215,6 +215,7 @@ def to_json(self): "cm": self.cm, "size": self.size, "node_labels": self.node_labels, + "state_space": self.state_space, } @classmethod From 0dd60365306fcf8b5c8c44cc09a93d385d19c1be Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 16 Jan 2023 16:34:43 -0600 Subject: [PATCH 016/155] `Network.__len__`: Go back to counting nodes in the TPM (now a Dataset). --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index e222151f0..2f81b9169 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -181,7 +181,7 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - return self.cm.shape[0] + return len(self.tpm) def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, From ed3bce3a39be71b57b60357644a1cb476fa8bb18 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Tue, 17 Jan 2023 13:29:28 -0600 Subject: [PATCH 017/155] build_state_space(): Fix Union declaration --- pyphi/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyphi/utils.py b/pyphi/utils.py index 5130a72e9..3a1d75eac 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -81,9 +81,9 @@ def state_of_subsystem_nodes(node_indices, nodes, subsystem_state): def build_state_space( nodes_shape: Iterable[int], - state_space: Optional[Iterable[Iterable[Union[int|str]]]] = None, - singleton_state_space: Optional[Iterable[Union[int|str]]] = None, -) -> Tuple[Tuple[Tuple[Union[int|str]]], int]: + state_space: Optional[Iterable[Iterable[Union[int, str]]]] = None, + singleton_state_space: Optional[Iterable[Union[int, str]]] = None, +) -> Tuple[Tuple[Tuple[Union[int, str]]], int]: """Format the passed state space labels or construct 
defaults if none. Arguments: From 3c9760cd6ba04e6d36f33f9303f83ec82c521d12 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Tue, 17 Jan 2023 13:36:59 -0600 Subject: [PATCH 018/155] Fix widespread Union declaration typo --- pyphi/network.py | 4 ++-- pyphi/node.py | 24 ++++++++++++------------ pyphi/utils.py | 6 +++--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 89999b2a4..c1b140bec 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -33,7 +33,7 @@ class Network: is connected to every node (including itself)**. node_labels (tuple[str] or |NodeLabels|): Human-readable labels for each node in the network. - state_space (Optional[tuple[tuple[Union[int|str]]]]): + state_space (Optional[tuple[tuple[Union[int, str]]]]): Labels for the state space of each node in the network. If ``None``, states will be automatically labeled using a zero-based integer index per node. @@ -137,7 +137,7 @@ def size(self): @property def state_space(self): - """tuple[tuple[Union[int|str]]: Labels for the state space of each node. + """tuple[tuple[Union[int, str]]: Labels for the state space of each node. """ return self._state_space diff --git a/pyphi/node.py b/pyphi/node.py index 20b278aec..5d3055d80 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -42,8 +42,8 @@ class Node: gives probabilities that the node will be 'ON'. 
inputs (frozenset): outputs (frozenset): - state_space (Tuple[Union[int|str]]): - state (Optional[Union[int|str]]): + state_space (Tuple[Union[int, str]]): + state (Optional[Union[int, str]]): """ def __init__(self, dataarray: xr.DataArray): @@ -101,7 +101,7 @@ def outputs(self): @property def state_space(self): - """Tuple[Union[int|str]]: The space of states this node can inhabit.""" + """Tuple[Union[int, str]]: The space of states this node can inhabit.""" return self._state_space @state_space.setter @@ -122,7 +122,7 @@ def state_space(self, value): @property def state(self): - """Optional[Union[int|str]]: The current state of this node.""" + """Optional[Union[int, str]]: The current state of this node.""" return self._state @state.setter @@ -178,9 +178,9 @@ def to_json(self): def node( tpm: ExplicitTPM, cm: np.ndarray, - network_state_space: Tuple[Tuple[Union[int|str]]], + network_state_space: Tuple[Tuple[Union[int, str]]], index: int, - state: Optional[Union[int|str]] = None, + state: Optional[Union[int, str]] = None, node_labels: Optional[NodeLabels] = None ) -> xr.DataArray: """ @@ -189,12 +189,12 @@ def node( Args: tpm (|ExplicitTPM|): The TPM of this node. cm (np.ndarray): The CM of the network. - network_state_space (Tuple[Tuple[Union[int|str]]]): Labels for the state + network_state_space (Tuple[Tuple[Union[int, str]]]): Labels for the state space of each node in the network. index (int): The node's index in the network. Keyword Args: - state (Optional[Union[int|str]]): The state of this node. + state (Optional[Union[int, str]]): The state of this node. node_labels (Optional[|NodeLabels|]): Labels for these nodes. 
Returns: @@ -254,9 +254,9 @@ def node( def generate_nodes( tpm: ExplicitTPM, cm: np.ndarray, - state_space: Tuple[Tuple[Union[int|str]]], + state_space: Tuple[Tuple[Union[int, str]]], indices: Tuple[int], - network_state: Optional[Tuple[Union[int|str]]] = None, + network_state: Optional[Tuple[Union[int, str]]] = None, node_labels: Optional[NodeLabels] = None ) -> Tuple[xr.DataArray]: """Generate |Node| objects out of a binary network |ExplicitTPM|. @@ -264,12 +264,12 @@ def generate_nodes( Args: tpm (|ExplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. - state_space (Tuple[Tuple[Union[int|str]]]): Labels for the state + state_space (Tuple[Tuple[Union[int, str]]]): Labels for the state space of each node in the network. indices (Tuple[int]): Indices to generate nodes for. Keyword Args: - network_state (Optional[Tuple[Union[int|str]]]): The state of the network. + network_state (Optional[Tuple[Union[int, str]]]): The state of the network. node_labels (|NodeLabels|): Textual labels for each node. Returns: diff --git a/pyphi/utils.py b/pyphi/utils.py index 3a1d75eac..b56a04e32 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -92,14 +92,14 @@ def build_state_space( in the network. Keyword Args: - state_space (Optional[Iterable[Iterable[Union[int|str]]]]): The + state_space (Optional[Iterable[Iterable[Union[int, str]]]]): The network's state space labels as provided by the user. - singleton_state_space (Optional[Iterable[Union[int|str]]]): The label to + singleton_state_space (Optional[Iterable[Union[int, str]]]): The label to be used for singleton dimensions. If ``None``, singleton dimensions will be discarded. Returns: - Tuple[Tuple[Tuple[Union[int|str]]], int]: State space for the network of + Tuple[Tuple[Tuple[Union[int, str]]], int]: State space for the network of interest and its hash. 
""" if state_space is None: From 5c1f1b5d816936dfbfa9920bdb6075d0033e9b7f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 17 Jan 2023 15:56:58 -0600 Subject: [PATCH 019/155] Squeeze singleton dimensions in `Node` --- pyphi/node.py | 27 +++++++++++++++++---------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 20b278aec..c25b60f4f 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -201,24 +201,30 @@ def node( xr.DataArray: The node in question. """ - if node_labels is None: - indices = tuple(range(cm.shape[0])) - node_labels = NodeLabels(None, indices) - # Get indices of the inputs and outputs. inputs = frozenset(get_inputs_from_cm(index, cm)) outputs = frozenset(get_outputs_from_cm(index, cm)) # Generate DataArray structure for this node # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # The names of the n nodes in the network (whose state we can condition this - # node's TPM on), plus the last dimension with the probability ("Pr") for - # each possible state of this node in the next timestep. + # Dimensions are the names of this node's parents (whose state we + # can condition this node's TPM on), plus the last dimension with + # the probability ("Pr") for each possible state of this node in + # the next timestep. # data_vars (xr.DataArray node names) and dimension names share the same # dictionary-like namespace in xr.Dataset. Prepend constant "input_" string # to avoid the conflict. 
- dimensions = ["input_" + label for label in node_labels] + ["Pr"] + if node_labels is None: + indices = tuple(range(cm.shape[0])) + node_labels = NodeLabels(None, indices) + + parent_node_labels = tuple( + label for dim, label in zip(tpm.shape[:-1], node_labels) + if dim > 1 + ) + + dimensions = ["input_" + label for label in parent_node_labels] + ["Pr"] # For each dimension, compute the relevant state labels (coordinates in # xarray terminology) from the perspective of this node and its direct @@ -226,8 +232,9 @@ def node( state_space, _ = build_state_space( tpm.shape[:-1], network_state_space, - SINGLETON_STATE + singleton_state_space = None, ) + node_state_space = network_state_space[index] coordinates = [*state_space, node_state_space] @@ -237,7 +244,7 @@ def node( # np.asarray(). return xr.DataArray( name = node_labels[index], - data = np.asarray(tpm), + data = np.asarray(tpm.squeeze()), dims = dimensions, coords = list(map(list, coordinates)), attrs = { From 5351aecb5c6739945af83d3cb67533b341da92cc Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 13:53:39 -0600 Subject: [PATCH 020/155] Implement ImplicitTPM validation --- pyphi/data_structures/array_like.py | 2 + pyphi/network.py | 11 +-- pyphi/node.py | 8 +-- pyphi/tpm.py | 107 +++++++++++++++++++++++++--- pyphi/validate.py | 2 +- 5 files changed, 112 insertions(+), 18 deletions(-) diff --git a/pyphi/data_structures/array_like.py b/pyphi/data_structures/array_like.py index 3e8fe33d0..b290f47f5 100644 --- a/pyphi/data_structures/array_like.py +++ b/pyphi/data_structures/array_like.py @@ -13,10 +13,12 @@ class ArrayLike(NDArrayOperatorsMixin): # TODO(tpm) populate this list _TYPE_CLOSED_FUNCTIONS = ( np.all, + np.any, np.concatenate, np.expand_dims, np.stack, np.sum, + np.result_type, ) # Holds the underlying array diff --git a/pyphi/network.py b/pyphi/network.py index 2f81b9169..62a223806 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -57,7 +57,8 @@ def __init__( if isinstance(tpm, 
(np.ndarray, ExplicitTPM)): # Validate tpm even if an ExplicitTPM was provided. ExplicitTPM # accepts instantiation from either another object of its class or - # np.ndarray, so the following achieves validation in general. + # np.ndarray, so the following achieves validation in general (and + # conversion to multidimensional form, as a side effect). tpm = ExplicitTPM(tpm, validate=True) self._state_space, _ = build_state_space( @@ -86,8 +87,7 @@ def __init__( self.purview_cache = purview_cache or cache.PurviewCache() - # TODO - # validate.network(self) + validate.network(self) @property def tpm(self): @@ -181,7 +181,10 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - return len(self.tpm) + try: + return len(self.tpm) + except AttributeError: + return self._cm.shape[0] def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, diff --git a/pyphi/node.py b/pyphi/node.py index c25b60f4f..419b4e2f7 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -212,9 +212,6 @@ def node( # the probability ("Pr") for each possible state of this node in # the next timestep. - # data_vars (xr.DataArray node names) and dimension names share the same - # dictionary-like namespace in xr.Dataset. Prepend constant "input_" string - # to avoid the conflict. if node_labels is None: indices = tuple(range(cm.shape[0])) node_labels = NodeLabels(None, indices) @@ -224,6 +221,9 @@ def node( if dim > 1 ) + # data_vars (xr.DataArray node names) and dimension names share the same + # dictionary-like namespace in xr.Dataset. Prepend constant "input_" string + # to avoid the conflict. dimensions = ["input_" + label for label in parent_node_labels] + ["Pr"] # For each dimension, compute the relevant state labels (coordinates in @@ -244,7 +244,7 @@ def node( # np.asarray(). 
return xr.DataArray( name = node_labels[index], - data = np.asarray(tpm.squeeze()), + data = tpm.squeeze(), dims = dimensions, coords = list(map(list, coordinates)), attrs = { diff --git a/pyphi/tpm.py b/pyphi/tpm.py index f187ecad2..b576c68ab 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -18,6 +18,76 @@ from .utils import all_states, np_hash, np_immutable +class TPM: + """TPM interface for derived classes.""" + + _ERROR_MSG_PROBABILITY_IMAGE = ( + "Invalid TPM: probabilities must be in the interval [0, 1]." + ) + + _ERROR_MSG_PROBABILITY_SUM = "Invalid TPM: probabilities must sum to 1." + + def validate(self, check_independence=True): + raise NotImplementedError + + def _validate_probabilities(self): + raise NotImplementedError + + def _validate_shape(self, check_independence=True): + raise NotImplementedError + + def to_multidimensional_state_by_node(self): + raise NotImplementedError + + def conditionally_independent(self): + raise NotImplementedError + + def condition_tpm(self, condition: Mapping[int, int]): + raise NotImplementedError + + def marginalize_out(self, node_indices): + raise NotImplementedError + + def is_deterministic(self): + raise NotImplementedError + + def is_state_by_state(self): + raise NotImplementedError + + def subtpm(self, fixed_nodes, state): + raise NotImplementedError + + def expand_tpm(self): + raise NotImplementedError + + def infer_edge(self, a, b, contexts): + raise NotImplementedError + + def infer_cm(self): + raise NotImplementedError + + def tpm_indices(self): + raise NotImplementedError + + def print(self): + raise NotImplementedError + + def permute_nodes(self, permutation): + raise NotImplementedError + + def __getitem__(self, i): + raise NotImplementedError + + def __str__(self): + raise NotImplementedError + + def __repr__(self): + raise NotImplementedError + + def __hash__(self): + raise NotImplementedError + + # TODO(tpm) remove pending ArrayLike refactor class ProxyMetaclass(type): """A metaclass to create wrappers 
for the TPM array's special attributes. @@ -159,7 +229,7 @@ def __init__(self): raise ValueError(f"Wrapped object must be of type {self.__wraps__}") -class ExplicitTPM(data_structures.ArrayLike): +class ExplicitTPM(data_structures.ArrayLike, TPM): """An explicit network TPM in multidimensional form. @@ -284,11 +354,9 @@ def validate(self, check_independence=True): def _validate_probabilities(self): """Check that the probabilities in a TPM are valid.""" if (self._tpm < 0.0).any() or (self._tpm > 1.0).any(): - raise ValueError( - "Invalid TPM: probabilities must be in the interval [0, 1]." - ) + raise ValueError(self._ERROR_MSG_PROBABILITY_IMAGE) if self.is_state_by_state() and np.any(np.sum(self._tpm, axis=1) != 1.0): - raise ValueError("Invalid TPM: probabilities must sum to 1.") + raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) return True def _validate_shape(self, check_independence=True): @@ -569,8 +637,9 @@ def implicit_tpm(nodes, validate=False): data_vars = {node.name: node for node in nodes} ) + @xr.register_dataset_accessor("pyphi") -class ImplicitTPM: +class ImplicitTPM(TPM): """An implicit network TPM containing |Node| TPMs in multidimensional form. @@ -580,11 +649,31 @@ class ImplicitTPM: Attributes: """ - def validate(self, check_independence): - raise NotImplementedError + def __init__(self, dataset: xr.Dataset): + self._tpm = dataset + + def validate(self, check_independence=True): + """Validate this TPM.""" + return self._validate_probabilities() def _validate_probabilities(self): - raise NotImplementedError + """Check that the probabilities in a TPM are valid.""" + # An implicitTPM contains valid probabilities if individual node TPMs + # are valid. + if any( + (np.asarray(node_tpm).sum(axis=-1) != 1.0).any() + for node_tpm in self._tpm.data_vars.values() + ): + raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) + + # Leverage method in ExplicitTPM to distribute validation of + # TPM image within [0, 1]. 
+ if all( + node.data._validate_probabilities() + for node in self._tpm.data_vars.values() + ): + return True + def _validate_shape(self, check_independence=True): raise NotImplementedError diff --git a/pyphi/validate.py b/pyphi/validate.py index 14bbe5d70..38d5782db 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -62,7 +62,7 @@ def network(n): Checks the TPM and connectivity matrix. """ - n.tpm.validate() + n.tpm.pyphi.validate() connectivity_matrix(n.cm) if n.cm.shape[0] != n.size: raise ValueError( From 955fe6910517489ec98177c683aaa8a3c00740ff Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 15:14:01 -0600 Subject: [PATCH 021/155] Update `Network` `__eq__()` method. --- pyphi/network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index dea76ef29..093166209 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -78,6 +78,7 @@ def __init__( elif isinstance(tpm, ImplicitTPM): self._tpm = tpm + # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): # From JSON. self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) @@ -201,13 +202,14 @@ def __eq__(self, other): """ return ( isinstance(other, Network) - and self.tpm.array_equal(other.tpm) + and self.tpm.equals(other.tpm) and np.array_equal(self.cm, other.cm) ) def __ne__(self, other): return not self.__eq__(other) + # TODO(tpm): Immutability in xarray. def __hash__(self): return hash((hash(self.tpm), self._cm_hash)) From 03f6a043bdb1a237b21c9f3fa4b1fced316f404e Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 15:21:05 -0600 Subject: [PATCH 022/155] `Network`: revert to a `__repr__` amenable to JSON serialization. 
--- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 093166209..f92cb2582 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -190,7 +190,7 @@ def __len__(self): def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, # distinctions, etc. - return "Network(\n{},\ncm=\n{}\n)".format(self.tpm, self.cm) + return "Network({}, cm={})".format(self.tpm, self.cm) def __str__(self): return self.__repr__() From d192cbf7be4ab74a44f1c96d03a47a6630acbbd3 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 15:33:53 -0600 Subject: [PATCH 023/155] Miscellaneous bug fixes in `Node` attribute setters and getters. --- pyphi/node.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 4f7115fae..35614ce53 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -82,7 +82,7 @@ def index(self): @property def label(self): """str: The textual label for this node.""" - return self.node_labels[self.index] + return self._node_labels[self.index] @property def tpm(self): @@ -127,7 +127,7 @@ def state(self): @state.setter def state(self, value): - if value not in self.state_space: + if value not in (*self.state_space, None): raise ValueError( f"Invalid node state. Possible states are {self.state_space}." ) From 60e67ec0aa4714f858b1c64fdc0412885f4fe7cf Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 17:38:50 -0600 Subject: [PATCH 024/155] `tpm.ImplicitTPM`: Implement condition_tpm() --- pyphi/tpm.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index b576c68ab..0d247c295 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -684,8 +684,36 @@ def to_multidimensional_state_by_node(self): def conditionally_independent(self): raise NotImplementedError + # TODO accept label-state mapping as argument. 
Current solution + # relies on correct node order in data_vars. + # + # TODO(tpm) Refactor codebase to call tpm[Mapping] directly? def condition_tpm(self, condition: Mapping[int, int]): - raise NotImplementedError + """Return a TPM conditioned on the given fixed node indices, whose + states are fixed according to the given state-tuple. + + The dimensions of the new TPM that correspond to the fixed nodes are + collapsed onto their state, making those dimensions singletons suitable + for broadcasting. The number of dimensions of the conditioned TPM will + be the same as the unconditioned TPM. + + Args: + condition (dict[int, int]): A mapping from node indices to the state + to condition on for that node. + + Returns: + TPM: A conditioned TPM with the same number of dimensions, with + singleton dimensions for nodes in a fixed state. + """ + node_dimensions = ["input_" + dim for dim in self._tpm.data_vars.keys()] + + condition = { + node_dimensions[node_index]: state + for node_index, state in condition.items() + } + + # TODO: broadcasting + return self._tpm[condition] def marginalize_out(self, node_indices): raise NotImplementedError From 6301380dc1e03086dbb1392d53c0721da18d8cb1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 18 Jan 2023 17:40:29 -0600 Subject: [PATCH 025/155] Minor import and docstring cleanup in `subsystem` --- pyphi/subsystem.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index bc5d13873..8a2c24770 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -33,7 +33,6 @@ from .node import generate_nodes from .partition import mip_partitions from .repertoire import forward_repertoire, unconstrained_forward_repertoire -from .tpm import ExplicitTPM from .utils import build_state_space, state_of log = logging.getLogger(__name__) @@ -61,7 +60,7 @@ class Subsystem: Attributes: network (Network): The network the subsystem belongs to. 
- tpm (pyphi.tpm.ExplicitTPM): The TPM conditioned on the state + tpm (ImplicitTPM): The TPM conditioned on the state of the external nodes. cm (np.ndarray): The connectivity matrix after applying the cut. state (tuple[int]): The state of the network. From f4b7d743d78988ae58aa4536caebb26bd702613a Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 26 Jan 2023 23:02:38 -0600 Subject: [PATCH 026/155] Change state space structure to use a dictionary. Thus allowing xarray DataArrays to have anonymous singleton dimensions. --- pyphi/constants.py | 3 - pyphi/data_structures/array_like.py | 1 + pyphi/macro.py | 2 +- pyphi/network.py | 8 ++- pyphi/node.py | 98 ++++++++++++----------------- pyphi/state_space.py | 89 ++++++++++++++++++++++++++ pyphi/subsystem.py | 25 ++------ pyphi/utils.py | 50 +-------------- 8 files changed, 142 insertions(+), 134 deletions(-) create mode 100644 pyphi/state_space.py diff --git a/pyphi/constants.py b/pyphi/constants.py index 103d501ee..ea26b536b 100644 --- a/pyphi/constants.py +++ b/pyphi/constants.py @@ -21,6 +21,3 @@ # Probability value below which we issue a warning about precision. # TODO(4.0) TPM_WARNING_THRESHOLD = 1e-10 - -# State space used for singleton dimensions in per-node TPMs. -SINGLETON_STATE = ("_",) diff --git a/pyphi/data_structures/array_like.py b/pyphi/data_structures/array_like.py index b290f47f5..b7fa518c8 100644 --- a/pyphi/data_structures/array_like.py +++ b/pyphi/data_structures/array_like.py @@ -19,6 +19,7 @@ class ArrayLike(NDArrayOperatorsMixin): np.stack, np.sum, np.result_type, + np.broadcast_to, ) # Holds the underlying array diff --git a/pyphi/macro.py b/pyphi/macro.py index 60b69cb1f..78132525c 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -20,7 +20,7 @@ from .node import expand_node_tpm, generate_nodes from .subsystem import Subsystem from .tpm import ExplicitTPM -from .utils import build_state_space +from .state_space import build_state_space # Create a logger for this module. 
log = logging.getLogger(__name__) diff --git a/pyphi/network.py b/pyphi/network.py index f92cb2582..e47d4e004 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -13,7 +13,7 @@ from .labels import NodeLabels from .node import generate_nodes from .tpm import ExplicitTPM, ImplicitTPM, implicit_tpm -from .utils import build_state_space +from .state_space import build_state_space class Network: @@ -62,7 +62,9 @@ def __init__( tpm = ExplicitTPM(tpm, validate=True) self._state_space, _ = build_state_space( - tpm.shape[:-1], state_space + self._node_labels, + tpm.shape[:-1], + state_space ) self._tpm = implicit_tpm( @@ -84,7 +86,7 @@ def __init__( self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) else: - raise TypeError(f"Invalid tpm of type {type(tpm)}.") + raise TypeError(f"Invalid TPM of type {type(tpm)}.") self.purview_cache = purview_cache or cache.PurviewCache() diff --git a/pyphi/node.py b/pyphi/node.py index 35614ce53..7ec287a66 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -9,16 +9,16 @@ import functools -from typing import Optional, Tuple, Union +from typing import Mapping, Optional, Tuple, Union import numpy as np import xarray as xr from .connectivity import get_inputs_from_cm, get_outputs_from_cm -from .constants import SINGLETON_STATE from .labels import NodeLabels +from .state_space import build_state_space, dimension_labels from .tpm import ExplicitTPM -from .utils import build_state_space, state_of +from .utils import state_of @xr.register_dataarray_accessor("pyphi") @functools.total_ordering @@ -31,15 +31,15 @@ class Node: Attributes: index (int): label (str): - tpm (|ExplicitTPM|): The node TPM is an array with - |n + 1| dimensions, where ``n`` is the size of the |Network|. The - first ``n`` dimensions correspond to each node in the - system. Dimensions corresponding to nodes that provide input to this - node are of size > 1, while those that do not correspond to inputs are - of size 1. 
The last dimension corresponds to the state of the - node in the next timestep, so that ``node.tpm[..., 0]`` gives - probabilities that the node will be 'OFF' and ``node.tpm[..., 1]`` - gives probabilities that the node will be 'ON'. + tpm (|ExplicitTPM|): The node TPM is an array with |n + 1| dimensions, + where ``n`` is the size of the |Network|. The first ``n`` dimensions + correspond to each node in the system. Dimensions corresponding to + nodes that provide input to this node are of size > 1, while those + that do not correspond to inputs are of size 1. The last dimension + corresponds to the state of the node in the next timestep, so that + ``node.tpm[..., 0]`` gives probabilities that the node will be 'OFF' + and ``node.tpm[..., 1]`` gives probabilities that the node will be + 'ON'. inputs (frozenset): outputs (frozenset): state_space (Tuple[Union[int, str]]): @@ -178,7 +178,7 @@ def to_json(self): def node( tpm: ExplicitTPM, cm: np.ndarray, - network_state_space: Tuple[Tuple[Union[int, str]]], + network_state_space: Mapping[str, Tuple[Union[int, str]]], index: int, state: Optional[Union[int, str]] = None, node_labels: Optional[NodeLabels] = None @@ -189,64 +189,46 @@ def node( Args: tpm (|ExplicitTPM|): The TPM of this node. cm (np.ndarray): The CM of the network. - network_state_space (Tuple[Tuple[Union[int, str]]]): Labels for the state - space of each node in the network. + network_state_space (Mapping[str, Tuple[Union[int, str]]]): + Labels for the state space of each node in the network. index (int): The node's index in the network. Keyword Args: state (Optional[Union[int, str]]): The state of this node. - node_labels (Optional[|NodeLabels|]): Labels for these nodes. + node_labels (Iterable[str]): Textual labels for each node in the network. Returns: xr.DataArray: The node in question. """ - - # Get indices of the inputs and outputs. 
- inputs = frozenset(get_inputs_from_cm(index, cm)) - outputs = frozenset(get_outputs_from_cm(index, cm)) - # Generate DataArray structure for this node # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Dimensions are the names of this node's parents (whose state we - # can condition this node's TPM on), plus the last dimension with - # the probability ("Pr") for each possible state of this node in - # the next timestep. - - if node_labels is None: - indices = tuple(range(cm.shape[0])) - node_labels = NodeLabels(None, indices) - - parent_node_labels = tuple( - label for dim, label in zip(tpm.shape[:-1], node_labels) - if dim > 1 - ) - - # data_vars (xr.DataArray node names) and dimension names share the same - # dictionary-like namespace in xr.Dataset. Prepend constant "input_" string - # to avoid the conflict. - dimensions = ["input_" + label for label in parent_node_labels] + ["Pr"] - - # For each dimension, compute the relevant state labels (coordinates in - # xarray terminology) from the perspective of this node and its direct - # inputs. - state_space, _ = build_state_space( + # Dimensions are the names of this node's parents (whose state this node's + # TPM can be conditioned on), plus the last dimension with the probability + # for each possible state of this node in the next timestep. + dimensions = dimension_labels(node_labels) + + # Compute the relevant state labels (coordinates in xarray terminology) from + # the perspective of this node and its direct inputs. + new_network_state_space, _ = build_state_space( + node_labels, tpm.shape[:-1], - network_state_space, + network_state_space.values(), singleton_state_space = None, ) - node_state_space = network_state_space[index] + node_state_space = {dimensions[-1]: network_state_space[dimensions[index]]} + + coordinates = {**new_network_state_space, **node_state_space} - coordinates = [*state_space, node_state_space] + # Get indices of the inputs and outputs. 
+ inputs = frozenset(get_inputs_from_cm(index, cm)) + outputs = frozenset(get_outputs_from_cm(index, cm)) - # TODO(tpm) implement np.result_type() in - # data_structures.array_like.__array_function__ to avoid converting with - # np.asarray(). return xr.DataArray( name = node_labels[index], - data = tpm.squeeze(), + data = tpm, dims = dimensions, - coords = list(map(list, coordinates)), + coords = coordinates, attrs = { "index": index, "node_labels": node_labels, @@ -261,7 +243,7 @@ def node( def generate_nodes( tpm: ExplicitTPM, cm: np.ndarray, - state_space: Tuple[Tuple[Union[int, str]]], + state_space: Mapping[str, Tuple[Union[int, str]]], indices: Tuple[int], network_state: Optional[Tuple[Union[int, str]]] = None, node_labels: Optional[NodeLabels] = None @@ -271,20 +253,18 @@ def generate_nodes( Args: tpm (|ExplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. - state_space (Tuple[Tuple[Union[int, str]]]): Labels for the state - space of each node in the network. + state_space (Mapping[str, Tuple[Union[int, str]]]): Labels + for the state space of each node in the network. indices (Tuple[int]): Indices to generate nodes for. Keyword Args: - network_state (Optional[Tuple[Union[int, str]]]): The state of the network. + network_state (Optional[Tuple[Union[int, str]]]): The state of + the network. node_labels (|NodeLabels|): Textual labels for each node. Returns: Tuple[xr.DataArray]: The nodes of the system. """ - if node_labels is None: - node_labels = NodeLabels(None, indices) - if network_state is None: network_state = (None,) * cm.shape[0] diff --git a/pyphi/state_space.py b/pyphi/state_space.py new file mode 100644 index 000000000..1eef7249e --- /dev/null +++ b/pyphi/state_space.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# state_space.py + +""" +Constants and utility functions for dealing with the state space of a |Network|. 
+""" + +from typing import Iterable, List, Optional, Union, Tuple + +from .data_structures import FrozenMap + + +PARENT_DIMENSION_PREFIX = "input_" +PROBABILITY_DIMENSION_LABEL = "Pr" + + +def dimension_labels(node_labels: Iterable[str]) -> List[str]: + """Generate labels for each dimension in the |ImplicitTPM|. + + data_vars (xr.DataArray node names) and dimension names share the + same dictionary-like namespace in xr.Dataset. Prepend constant + string to avoid the conflict. + + Arguments: + node_labels (Iterable[str]): Textual labels for each node in the network. + + Returns: + List[str]: Textual labels for each dimension in a multidimensional TPM. + """ + return [PARENT_DIMENSION_PREFIX + label for label in node_labels] + [ + PROBABILITY_DIMENSION_LABEL + ] + + +def build_state_space( + node_labels: Iterable[str], + nodes_shape: Iterable[int], + node_states: Optional[Iterable[Iterable[Union[int, str]]]] = None, + singleton_state_space: Optional[Iterable[Union[int, str]]] = None, +) -> Tuple[FrozenMap[str, Tuple[Union[int, str]]], int]: + """Format the passed state space labels or construct defaults if none. + + Arguments: + node_labels (Iterable[str]): Textual labels for each node in the network. + nodes_shape (Iterable[int]): The first |n| components in the shape of + a multidimensional |ExplicitTPM|, where |n| is the number of nodes + in the network. + + Keyword Args: + node_states (Optional[Iterable[Iterable[Union[int, str]]]]): The + network's state space labels as provided by the user. + singleton_state_space (Optional[Iterable[Union[int, str]]]): The label + to be used for singleton dimensions. If ``None``, singleton + dimensions will be discarded. + + Returns: + Tuple[Tuple[Tuple[Union[int, str]]], int]: State space for the network + of interest and its hash. + """ + if node_states is None: + node_states = [tuple(range(dim)) for dim in nodes_shape] + else: + node_states = [tuple(n) for n in node_states] + + # labels-to-states map. 
+ state_space = zip(dimension_labels(node_labels), node_states) + + # Filter out states of singleton dimensions. + shape_state_map = zip(nodes_shape, state_space) + + if singleton_state_space is None: + state_space = dict( + node_states + for dim, node_states in shape_state_map + if dim > 1 + ) + + else: + state_space = dict( + node_states if dim > 1 else singleton_state_space + for dim, node_states in shape_state_map + ) + + state_space = FrozenMap(state_space) + state_space_hash = hash(state_space) + state_space = FrozenMap({k: list(v) for k,v in state_space.items()}) + + return (state_space, state_space_hash) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 8a2c24770..310ddf3b8 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -30,10 +30,9 @@ ) from .models.mechanism import StateSpecification from .network import irreducible_purviews -from .node import generate_nodes from .partition import mip_partitions from .repertoire import forward_repertoire, unconstrained_forward_repertoire -from .utils import build_state_space, state_of +from .utils import state_of log = logging.getLogger(__name__) @@ -109,15 +108,10 @@ def __init__( # Get the TPM conditioned on the state of the external nodes. external_state = utils.state_of(self.external_indices, self.state) background_conditions = dict(zip(self.external_indices, external_state)) - self.tpm = self.network.tpm.condition_tpm(background_conditions) + self.tpm = self.network.tpm.pyphi.condition_tpm(background_conditions) # The TPM for just the nodes in the subsystem. - self.proper_tpm = self.tpm.squeeze()[..., list(self.node_indices)] - - # The state space of the nodes in the candidate system. 
- self.proper_state_space, _ = build_state_space( - self.tpm[:-1], - self.network.state_space - ) + labels_in_subsystem = self.node_labels.indices2labels(self.node_indices) + self.proper_tpm = self.tpm[list(labels_in_subsystem)] # The unidirectional cut applied for phi evaluation self.cut = ( @@ -144,14 +138,7 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = generate_nodes( - self.tpm, - self.cm, - self.proper_state_space, - self.node_indices, - network_state=self.state, - node_labels=self.node_labels - ) + self.nodes = tuple(self.tpm.data_vars.values()) validate.subsystem(self) @@ -167,7 +154,7 @@ def nodes(self, value): """ # pylint: disable=attribute-defined-outside-init self._nodes = value - self._index2node = {node.index: node for node in self._nodes} + self._index2node = {node.pyphi.index: node for node in self._nodes} @property def proper_state(self): diff --git a/pyphi/utils.py b/pyphi/utils.py index b56a04e32..80f32e1d6 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -12,7 +12,7 @@ import operator import os from itertools import chain, combinations, product -from typing import Iterable, Optional, Union, Tuple +from typing import Tuple import numpy as np from scipy.special import comb @@ -79,54 +79,6 @@ def state_of_subsystem_nodes(node_indices, nodes, subsystem_state): return state_of([node_indices.index(n) for n in nodes], subsystem_state) -def build_state_space( - nodes_shape: Iterable[int], - state_space: Optional[Iterable[Iterable[Union[int, str]]]] = None, - singleton_state_space: Optional[Iterable[Union[int, str]]] = None, -) -> Tuple[Tuple[Tuple[Union[int, str]]], int]: - """Format the passed state space labels or construct defaults if none. - - Arguments: - nodes_shape (Iterable[int]): The first |n| components in the shape of - a multidimensional |ExplicitTPM|, where |n| is the number of nodes - in the network. 
- - Keyword Args: - state_space (Optional[Iterable[Iterable[Union[int, str]]]]): The - network's state space labels as provided by the user. - singleton_state_space (Optional[Iterable[Union[int, str]]]): The label to - be used for singleton dimensions. If ``None``, singleton dimensions - will be discarded. - - Returns: - Tuple[Tuple[Tuple[Union[int, str]]], int]: State space for the network of - interest and its hash. - """ - if state_space is None: - state_space = tuple(tuple(range(dim)) for dim in nodes_shape) - else: - # Enforce tuples. - state_space = map(tuple, state_space) - - # Filter out states of singleton dimensions. - shape_state_map = zip(nodes_shape, state_space) - - if singleton_state_space is None: - state_space = tuple( - node_states - for dim, node_states in shape_state_map - if dim > 1 - ) - - else: - state_space = tuple( - node_states if dim > 1 else singleton_state_space - for dim, node_states in shape_state_map - ) - - return (state_space, hash(state_space)) - - # TODO: nonbinary states def all_states(n, big_endian=False): """Return all binary states for a system. From 1687293bed973ee8d899154e8efac8422bbd4fba Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Feb 2023 19:22:57 -0600 Subject: [PATCH 027/155] Revert to using dummy singleton dimensions in multidimensional node TPMs It turns out that, while allowed on a DataArray-level, nameless singleton dimensions cannot be aligned at the Dataset level. 
--- pyphi/network.py | 2 +- pyphi/node.py | 10 +++++----- pyphi/state_space.py | 12 ++++++------ 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index e47d4e004..2b123f9bd 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -90,7 +90,7 @@ def __init__( self.purview_cache = purview_cache or cache.PurviewCache() - validate.network(self) + # validate.network(self) @property def tpm(self): diff --git a/pyphi/node.py b/pyphi/node.py index 7ec287a66..bd54629c1 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -213,12 +213,12 @@ def node( node_labels, tpm.shape[:-1], network_state_space.values(), - singleton_state_space = None, + singleton_state_space = ("_",), ) - node_state_space = {dimensions[-1]: network_state_space[dimensions[index]]} + node_state_space = network_state_space[dimensions[index]] - coordinates = {**new_network_state_space, **node_state_space} + coordinates = {**new_network_state_space, dimensions[-1]: node_state_space} # Get indices of the inputs and outputs. inputs = frozenset(get_inputs_from_cm(index, cm)) @@ -226,7 +226,7 @@ def node( return xr.DataArray( name = node_labels[index], - data = tpm, + data = np.asarray(tpm), dims = dimensions, coords = coordinates, attrs = { @@ -234,7 +234,7 @@ def node( "node_labels": node_labels, "inputs": inputs, "outputs": outputs, - "state_space": node_state_space, + "state_space": tuple(node_state_space), "state": state, } ) diff --git a/pyphi/state_space.py b/pyphi/state_space.py index 1eef7249e..96e5dcf92 100644 --- a/pyphi/state_space.py +++ b/pyphi/state_space.py @@ -55,7 +55,7 @@ def build_state_space( dimensions will be discarded. Returns: - Tuple[Tuple[Tuple[Union[int, str]]], int]: State space for the network + Tuple[FrozenMap[str, Tuple[Union[int, str]]], int]: State space for the network of interest and its hash. 
""" if node_states is None: @@ -70,17 +70,17 @@ def build_state_space( shape_state_map = zip(nodes_shape, state_space) if singleton_state_space is None: - state_space = dict( + state_space = { node_states for dim, node_states in shape_state_map if dim > 1 - ) + } else: - state_space = dict( - node_states if dim > 1 else singleton_state_space + state_space = { + node_states if dim > 1 else (node_states[0], singleton_state_space) for dim, node_states in shape_state_map - ) + } state_space = FrozenMap(state_space) state_space_hash = hash(state_space) From 719cb7c4d03b34b56d5b13918165265ecfee624e Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 2 Feb 2023 10:42:26 -0600 Subject: [PATCH 028/155] `Node._hash`: temporary workaround when the TPM array is mutable. --- pyphi/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/node.py b/pyphi/node.py index bd54629c1..ec2bea0ad 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -66,7 +66,7 @@ def __init__(self, dataarray: xr.DataArray): self._hash = hash( ( self.index, - hash(self.tpm), + hash(ExplicitTPM(self.tpm)), self._inputs, self._outputs, self.state_space, From db89127f37a1f9e26995ac34a0ff9fe57357422b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 17 Feb 2023 13:52:35 -0600 Subject: [PATCH 029/155] Implement function to properly unalign DataArray's --- pyphi/node.py | 56 ++++++++++++++++++++++++++++++++++++++------ pyphi/state_space.py | 33 ++++++++++++++++++-------- pyphi/subsystem.py | 10 +++++--- pyphi/tpm.py | 6 +++-- 4 files changed, 83 insertions(+), 22 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index ec2bea0ad..a241172fb 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -16,7 +16,13 @@ from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .labels import NodeLabels -from .state_space import build_state_space, dimension_labels +from .state_space import ( + dimension_labels, + input_dimension_label, + build_state_space, + PROBABILITY_DIMENSION, + 
SINGLETON_COORDINATE, +) from .tpm import ExplicitTPM from .utils import state_of @@ -55,6 +61,7 @@ def __init__(self, dataarray: xr.DataArray): self._inputs = dataarray.attrs["inputs"] self._outputs = dataarray.attrs["outputs"] + self._dataarray = dataarray self._tpm = dataarray.data self.state_space = dataarray.attrs["state_space"] @@ -106,19 +113,19 @@ def state_space(self): @state_space.setter def state_space(self, value): - state_space = tuple(value) + _state_space = tuple(value) - if len(set(state_space)) < len(state_space): + if len(set(_state_space)) < len(_state_space): raise ValueError( "Invalid node state space tuple. Repeated states are ambiguous." ) - if len(state_space) < 2: + if len(_state_space) < 2: raise ValueError( "Invalid node state space with less than 2 states." ) - self._state_space = state_space + self._state_space = _state_space @property def state(self): @@ -134,6 +141,42 @@ def state(self, value): self._state = value + def streamline(self): + """Remove superfluous coordinates from an unaligned |Node| TPM. + + Returns: + xr.DataArray: The |Node| TPM re-represented without coordinates + introduced during xr.Dataset alignment. 
+ """ + node_labels = self._node_labels + + node_indices = frozenset(range(len(node_labels))) + inputs = self.inputs + noninputs = node_indices - inputs + + input_dims = [input_dimension_label(node_labels[i]) for i in inputs] + noninput_dims = [input_dimension_label(node_labels[i]) for i in noninputs] + + new_input_coords = { + dim: [ + coord for coord in self._dataarray.coords[dim].data + if coord != SINGLETON_COORDINATE + ] + for dim in input_dims + } + new_noninput_coords = { + dim: [SINGLETON_COORDINATE] for dim in noninput_dims + } + probability_coords = list(self._dataarray.coords[PROBABILITY_DIMENSION].data) + + new_coords = { + **new_input_coords, + **new_noninput_coords, + PROBABILITY_DIMENSION: probability_coords, + } + + return self._dataarray.reindex(new_coords) + def __repr__(self): return self.label @@ -149,7 +192,6 @@ def __eq__(self, other): Labels are for display only, so two equal nodes may have different labels. - """ return ( self.index == other.index @@ -213,7 +255,7 @@ def node( node_labels, tpm.shape[:-1], network_state_space.values(), - singleton_state_space = ("_",), + singleton_state_space = (SINGLETON_COORDINATE,), ) node_state_space = network_state_space[dimensions[index]] diff --git a/pyphi/state_space.py b/pyphi/state_space.py index 96e5dcf92..ecd3b7b31 100644 --- a/pyphi/state_space.py +++ b/pyphi/state_space.py @@ -11,26 +11,39 @@ from .data_structures import FrozenMap -PARENT_DIMENSION_PREFIX = "input_" -PROBABILITY_DIMENSION_LABEL = "Pr" +INPUT_DIMENSION_PREFIX = "input_" +PROBABILITY_DIMENSION = "Pr" +SINGLETON_COORDINATE = "_" -def dimension_labels(node_labels: Iterable[str]) -> List[str]: - """Generate labels for each dimension in the |ImplicitTPM|. +def input_dimension_label(node_label: str) -> str: + """Generate label for an input dimension in the |ImplicitTPM|. data_vars (xr.DataArray node names) and dimension names share the same dictionary-like namespace in xr.Dataset. Prepend constant string to avoid the conflict. 
- Arguments: + Args: + node_label (str): Textual label for a node in the network. + + Returns: + str: Textual label for the same dimension in the multidimensional TPM. + """ + return INPUT_DIMENSION_PREFIX + str(node_label) + +def dimension_labels(node_labels: Iterable[str]) -> List[str]: + """Generate labels for each dimension in the |ImplicitTPM|. + + Args: node_labels (Iterable[str]): Textual labels for each node in the network. Returns: - List[str]: Textual labels for each dimension in a multidimensional TPM. + List[str]: Textual labels for each dimension in the multidimensional TPM. """ - return [PARENT_DIMENSION_PREFIX + label for label in node_labels] + [ - PROBABILITY_DIMENSION_LABEL - ] + return ( + list(map(input_dimension_label, node_labels)) + + [PROBABILITY_DIMENSION] + ) def build_state_space( @@ -41,7 +54,7 @@ def build_state_space( ) -> Tuple[FrozenMap[str, Tuple[Union[int, str]]], int]: """Format the passed state space labels or construct defaults if none. - Arguments: + Args: node_labels (Iterable[str]): Textual labels for each node in the network. 
nodes_shape (Iterable[int]): The first |n| components in the shape of a multidimensional |ExplicitTPM|, where |n| is the number of nodes diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 310ddf3b8..554f395a6 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -138,9 +138,11 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = tuple(self.tpm.data_vars.values()) + self.nodes = tuple( + node.pyphi.streamline() for node in self.tpm.data_vars.values() + ) - validate.subsystem(self) + # validate.subsystem(self) @property def nodes(self): @@ -154,7 +156,9 @@ def nodes(self, value): """ # pylint: disable=attribute-defined-outside-init self._nodes = value - self._index2node = {node.pyphi.index: node for node in self._nodes} + self._index2node = { + node.pyphi.index: node.pyphi.streamline() for node in self._nodes + } @property def proper_state(self): diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 0d247c295..88281edbc 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -658,8 +658,10 @@ def validate(self, check_independence=True): def _validate_probabilities(self): """Check that the probabilities in a TPM are valid.""" - # An implicitTPM contains valid probabilities if individual node TPMs - # are valid. + # An implicit TPM contains valid probabilities if and only if + # individual node TPMs contain valid probabilities, for every node. + + # Validate that probabilities sum to 1. if any( (np.asarray(node_tpm).sum(axis=-1) != 1.0).any() for node_tpm in self._tpm.data_vars.values() From 819b34fa3357bdc21c2472d8def84d3a0d89df74 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Feb 2023 11:35:44 -0600 Subject: [PATCH 030/155] `Node.streamline()`: use more consistent variable declarations. 
--- pyphi/node.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index a241172fb..66215aeb7 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -167,12 +167,16 @@ def streamline(self): new_noninput_coords = { dim: [SINGLETON_COORDINATE] for dim in noninput_dims } - probability_coords = list(self._dataarray.coords[PROBABILITY_DIMENSION].data) + probability_coords = { + PROBABILITY_DIMENSION: list( + self._dataarray.coords[PROBABILITY_DIMENSION].data + ) + } new_coords = { **new_input_coords, **new_noninput_coords, - PROBABILITY_DIMENSION: probability_coords, + **probability_coords, } return self._dataarray.reindex(new_coords) From fd35bcc6c804d935a22507350ef2d020c41bd052 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Feb 2023 13:00:54 -0600 Subject: [PATCH 031/155] Avoid xarray Dataset altogether There is no good way of having single-node TPMs in a single Dataset without blowing up memory usage (due to alignment of non-shared singleton dimensions): https://github.com/pydata/xarray/issues/1471 Instead we will implement our own indexing operation to distribute it across the sequence of DataArray nodes. As a bonus, we can implement positional indexing, which xarray doesn't support at the Dataset level. --- pyphi/network.py | 4 ++-- pyphi/node.py | 2 +- pyphi/tpm.py | 40 ++++++++++++++++------------------------ 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 2b123f9bd..2096f0344 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -12,7 +12,7 @@ from . 
import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels from .node import generate_nodes -from .tpm import ExplicitTPM, ImplicitTPM, implicit_tpm +from .tpm import ExplicitTPM, ImplicitTPM from .state_space import build_state_space @@ -67,7 +67,7 @@ def __init__( state_space ) - self._tpm = implicit_tpm( + self._tpm = ImplicitTPM( generate_nodes( tpm, self._cm, diff --git a/pyphi/node.py b/pyphi/node.py index 66215aeb7..1d86cf30a 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -272,7 +272,7 @@ def node( return xr.DataArray( name = node_labels[index], - data = np.asarray(tpm), + data = tpm, dims = dimensions, coords = coordinates, attrs = { diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 88281edbc..440361a58 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -7,7 +7,7 @@ """ from itertools import chain -from typing import Mapping, Set +from typing import Mapping, Set, Tuple import numpy as np import xarray as xr @@ -75,9 +75,6 @@ def print(self): def permute_nodes(self, permutation): raise NotImplementedError - def __getitem__(self, i): - raise NotImplementedError - def __str__(self): raise NotImplementedError @@ -629,16 +626,6 @@ def __hash__(self): return self._hash -def implicit_tpm(nodes, validate=False): - - """Instantiate an implicit network TPM Dataset.""" - - return xr.Dataset( - data_vars = {node.name: node for node in nodes} - ) - - -@xr.register_dataset_accessor("pyphi") class ImplicitTPM(TPM): """An implicit network TPM containing |Node| TPMs in multidimensional form. 
@@ -649,8 +636,13 @@ class ImplicitTPM(TPM): Attributes: """ - def __init__(self, dataset: xr.Dataset): - self._tpm = dataset + def __init__(self, nodes: Tuple[xr.DataArray]): + self._nodes = nodes + + @property + def nodes(self): + """Tuple[xr.DataArray]: The node TPMs in this ImplicitTPM""" + return self._nodes def validate(self, check_independence=True): """Validate this TPM.""" @@ -664,15 +656,15 @@ def _validate_probabilities(self): # Validate that probabilities sum to 1. if any( (np.asarray(node_tpm).sum(axis=-1) != 1.0).any() - for node_tpm in self._tpm.data_vars.values() + for node_tpm in self._nodes ): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) # Leverage method in ExplicitTPM to distribute validation of # TPM image within [0, 1]. if all( - node.data._validate_probabilities() - for node in self._tpm.data_vars.values() + node_tpm.data._validate_probabilities() + for node_tpm in self._nodes ): return True @@ -707,15 +699,15 @@ def condition_tpm(self, condition: Mapping[int, int]): TPM: A conditioned TPM with the same number of dimensions, with singleton dimensions for nodes in a fixed state. 
""" - node_dimensions = ["input_" + dim for dim in self._tpm.data_vars.keys()] + node_dimensions = ["input_" + node.label for node in self.nodes] - condition = { + conditioning_index = { node_dimensions[node_index]: state for node_index, state in condition.items() } # TODO: broadcasting - return self._tpm[condition] + return self[conditioning_index] def marginalize_out(self, node_indices): raise NotImplementedError @@ -754,10 +746,10 @@ def __getitem__(self, i): raise NotImplementedError def __str__(self): - raise NotImplementedError + return self.__repr__() def __repr__(self): - raise NotImplementedError + return "ImplicitTPM({})".format(self._nodes) def __hash__(self): raise NotImplementedError From ebe1c3a29b15c330c7a7260b0a1033c6ca5cff63 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Feb 2023 15:30:24 -0600 Subject: [PATCH 032/155] Network.__init__(): Handle network as tpm argument --- pyphi/network.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 2096f0344..cd1405f4a 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -7,11 +7,13 @@ context of all |small_phi| and |big_phi| computation. """ +from typing import Iterable import numpy as np +import xarray as xr from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels -from .node import generate_nodes +from .node import Node, generate_nodes from .tpm import ExplicitTPM, ImplicitTPM from .state_space import build_state_space @@ -85,6 +87,29 @@ def __init__( # From JSON. 
self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) + elif isinstance(tpm, Iterable): + invalid = [i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM))] + + if invalid: + raise TypeError(f"Invalid set of nodes containing type {', type '.join(str(i) for i in invalid)}.") + + tpm = tuple(ExplicitTPM(node_tpm, validate=True) for node_tpm in tpm) + + shapes = [node.shape for node in tpm] + + if not all(len(shape) == len(shapes[0]) for shape in shapes): + raise ValueError("Provided set of nodes contains varying number of dimensions.") + + network_tpm_shape = [max(shape[i] for shape in shapes) for i in range(len(shapes[0]))] + + self.state_space, _ = build_state_space( + self._node_labels, + network_tpm_shape, + state_space + ) + + self._tpm = ImplicitTPM(tpm) + else: raise TypeError(f"Invalid TPM of type {type(tpm)}.") From db8ce13b5ea024c39e6144d771b44c45818b4ec6 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Feb 2023 15:32:43 -0600 Subject: [PATCH 033/155] Fix error message --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index cd1405f4a..8f36b06db 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -91,7 +91,7 @@ def __init__( invalid = [i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM))] if invalid: - raise TypeError(f"Invalid set of nodes containing type {', type '.join(str(i) for i in invalid)}.") + raise TypeError(f"Invalid set of nodes containing {', '.join(str(i) for i in invalid)}.") tpm = tuple(ExplicitTPM(node_tpm, validate=True) for node_tpm in tpm) From ded144b59e840fb89cd480a81db0976d25b1e8a8 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Feb 2023 17:34:09 -0600 Subject: [PATCH 034/155] `Network`: Fix `build_cm()` when None is provided. 
--- pyphi/network.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 8f36b06db..86e2f195b 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -50,7 +50,7 @@ def __init__( state_space=None, purview_cache=None ): - self._cm, self._cm_hash = self._build_cm(cm) + self._cm, self._cm_hash = self._build_cm(cm, tpm) self._node_indices = tuple(range(self.size)) self._node_labels = NodeLabels(node_labels, self._node_indices) @@ -89,25 +89,25 @@ def __init__( elif isinstance(tpm, Iterable): invalid = [i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM))] - + if invalid: raise TypeError(f"Invalid set of nodes containing {', '.join(str(i) for i in invalid)}.") - - tpm = tuple(ExplicitTPM(node_tpm, validate=True) for node_tpm in tpm) - + + tpm = tuple(ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm) + shapes = [node.shape for node in tpm] - + if not all(len(shape) == len(shapes[0]) for shape in shapes): raise ValueError("Provided set of nodes contains varying number of dimensions.") - + network_tpm_shape = [max(shape[i] for shape in shapes) for i in range(len(shapes[0]))] - - self.state_space, _ = build_state_space( + + self._state_space, _ = build_state_space( self._node_labels, network_tpm_shape, state_space ) - + self._tpm = ImplicitTPM(tpm) else: @@ -134,13 +134,18 @@ def cm(self): """ return self._cm - def _build_cm(self, cm): + def _build_cm(self, cm, tpm): """Convert the passed CM to the proper format, or construct the unitary CM if none was provided. """ if cm is None: + try: + size = tpm.shape[-1] + except AttributeError: + size = len(tpm) + # Assume all are connected. - cm = np.ones((self.size, self.size)) + cm = np.ones((size, size)) else: cm = np.array(cm) @@ -165,7 +170,7 @@ def size(self): @property def state_space(self): - """tuple[tuple[Union[int, str]]: Labels for the state space of each node. 
+ """tuple[tuple[Union[int, str]]]: Labels for the state space of each node. """ return self._state_space @@ -209,10 +214,7 @@ def potential_purviews(self, direction, mechanism): def __len__(self): """int: The number of nodes in the network.""" - try: - return len(self.tpm) - except AttributeError: - return self._cm.shape[0] + return self._cm.shape[0] def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, From db677d17b51ac82adfda4662e5a9a5c0e76775ce Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Feb 2023 17:41:26 -0600 Subject: [PATCH 035/155] `Network.__init__()`: maximum line length. --- pyphi/network.py | 36 +++++++++++++++++++++++------------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 86e2f195b..9afab347a 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -79,28 +79,30 @@ def __init__( ) ) - elif isinstance(tpm, ImplicitTPM): - self._tpm = tpm - - # FIXME(TPM) initialization from JSON - elif isinstance(tpm, dict): - # From JSON. - self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) - elif isinstance(tpm, Iterable): - invalid = [i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM))] + invalid = [ + i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM)) + ] if invalid: - raise TypeError(f"Invalid set of nodes containing {', '.join(str(i) for i in invalid)}.") + raise TypeError("Invalid set of nodes containing {}.".format( + ', '.join(str(i) for i in invalid) + )) - tpm = tuple(ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm) + tpm = tuple( + ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm + ) shapes = [node.shape for node in tpm] if not all(len(shape) == len(shapes[0]) for shape in shapes): - raise ValueError("Provided set of nodes contains varying number of dimensions.") + raise ValueError( + "Provided set of nodes contains varying number of dimensions." 
+ ) - network_tpm_shape = [max(shape[i] for shape in shapes) for i in range(len(shapes[0]))] + network_tpm_shape = [ + max(shape[i] for shape in shapes) for i in range(len(shapes[0])) + ] self._state_space, _ = build_state_space( self._node_labels, @@ -110,6 +112,14 @@ def __init__( self._tpm = ImplicitTPM(tpm) + elif isinstance(tpm, ImplicitTPM): + self._tpm = tpm + + # FIXME(TPM) initialization from JSON + elif isinstance(tpm, dict): + # From JSON. + self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) + else: raise TypeError(f"Invalid TPM of type {type(tpm)}.") From a6c3b3a6c698bb683ea8110be3400089623712a9 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Feb 2023 22:38:35 -0600 Subject: [PATCH 036/155] `Network.__init__()`: Pass xarray nodes to the `ImplicitTPM` constructor. --- pyphi/network.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 9afab347a..9647d3602 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -13,7 +13,7 @@ from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels -from .node import Node, generate_nodes +from .node import generate_nodes, node from .tpm import ExplicitTPM, ImplicitTPM from .state_space import build_state_space @@ -97,7 +97,7 @@ def __init__( if not all(len(shape) == len(shapes[0]) for shape in shapes): raise ValueError( - "Provided set of nodes contains varying number of dimensions." + "The provided node TPMs contain varying number of dimensions." 
) network_tpm_shape = [ @@ -110,7 +110,18 @@ def __init__( state_space ) - self._tpm = ImplicitTPM(tpm) + self._tpm = ImplicitTPM( + tuple( + node( + node_tpm, + self._cm, + self._state_space, + index, + node_labels=self._node_labels + ) + for index, node_tpm in zip(self._node_indices, tpm) + ) + ) elif isinstance(tpm, ImplicitTPM): self._tpm = tpm From 3ca0284f60c66584647d399bbcde4ca0d3cc4b81 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 21 Feb 2023 17:14:46 -0600 Subject: [PATCH 037/155] Implement proper indexing for the `ImplicitTPM` --- pyphi/network.py | 2 +- pyphi/node.py | 39 +++++++++++++++++++++++++++++++++++++++ pyphi/tpm.py | 19 ++++++++++++++++--- 3 files changed, 56 insertions(+), 4 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 9647d3602..f67d9e561 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -129,7 +129,7 @@ def __init__( # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): # From JSON. - self._tpm = ImplicitTPM(tpm["_tpm"], validate=True) + self._tpm = ImplicitTPM(tpm["_tpm"]) else: raise TypeError(f"Invalid TPM of type {type(tpm)}.") diff --git a/pyphi/node.py b/pyphi/node.py index 1d86cf30a..7e6d74b33 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -141,6 +141,45 @@ def state(self, value): self._state = value + def project_index(self, index): + """Convert absolute TPM index to a valid index relative to this node.""" + + # Supported index coordinates (in the right dimension order) respective + # to this node, to be used as an AND mask, with 0 being + # SINGLETON_COORDINATE. + + # TODO(tpm) make a Node attribute? (similar to `state_space`). 
+ support = { + dim: tuple(self._dataarray.coords[dim].values) + for dim in self._dataarray.dims + } + + if isinstance(index, dict): + projected_index = { + key: value if support[key] != (SINGLETON_COORDINATE,) + else SINGLETON_COORDINATE + for key, value in index.items() + } + + print(projected_index) + return projected_index + + # Assume regular index otherwise. + if not isinstance(index, tuple): + # Index is a single int, slice, ellipsis or intra-dimension list. + index = (index,) + + index_support_map = zip(index, support.values()) + + projected_index = tuple( + i if support != (SINGLETON_COORDINATE,) + else slice(None) + for i, support in index_support_map + ) + + print(projected_index) + return projected_index + def streamline(self): """Remove superfluous coordinates from an unaligned |Node| TPM. diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 440361a58..8df7ba431 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -742,14 +742,27 @@ def print(self): def permute_nodes(self, permutation): raise NotImplementedError - def __getitem__(self, i): - raise NotImplementedError + def __getitem__(self, index): + if isinstance(index, (int, slice, type(...), tuple)): + return ImplicitTPM( + tuple( + node[node.pyphi.project_index(index)] + for node in self.nodes + ) + ) + if isinstance(index, dict): + return ImplicitTPM( + tuple( + node.loc[node.pyphi.project_index(index)] + for node in self.nodes + ) + ) def __str__(self): return self.__repr__() def __repr__(self): - return "ImplicitTPM({})".format(self._nodes) + return "ImplicitTPM({})".format(self.nodes) def __hash__(self): raise NotImplementedError From 83c64bbdf30a674531083c345cde38f24463384b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 11:39:05 -0600 Subject: [PATCH 038/155] `Node.project_index()`: remove leftover print statements. 
--- pyphi/node.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 7e6d74b33..c8dd229f6 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -161,7 +161,6 @@ def project_index(self, index): for key, value in index.items() } - print(projected_index) return projected_index # Assume regular index otherwise. @@ -177,7 +176,6 @@ def project_index(self, index): for i, support in index_support_map ) - print(projected_index) return projected_index def streamline(self): From 88989c35720cb1aff8e963b31f1df1a84cb27893 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 12:42:30 -0600 Subject: [PATCH 039/155] `tpm`: miscellaneous cleanup. --- pyphi/tpm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 8df7ba431..f92cd995f 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -12,12 +12,11 @@ import numpy as np import xarray as xr -from . import config, convert, data_structures, exceptions +from . import config, convert, data_structures, exceptions, state_space from .constants import OFF, ON from .data_structures import FrozenMap from .utils import all_states, np_hash, np_immutable - class TPM: """TPM interface for derived classes.""" @@ -655,7 +654,7 @@ def _validate_probabilities(self): # Validate that probabilities sum to 1. if any( - (np.asarray(node_tpm).sum(axis=-1) != 1.0).any() + (node_tpm.data.sum(axis=-1) != 1.0).any() for node_tpm in self._nodes ): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) @@ -699,7 +698,10 @@ def condition_tpm(self, condition: Mapping[int, int]): TPM: A conditioned TPM with the same number of dimensions, with singleton dimensions for nodes in a fixed state. 
""" - node_dimensions = ["input_" + node.label for node in self.nodes] + node_dimensions = [ + state_space.INPUT_DIMENSION_PREFIX + node.label + for node in self.nodes + ] conditioning_index = { node_dimensions[node_index]: state From 47e09f3fb1f4bad3ec7653eef0a47cf4d927e735 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 12:44:12 -0600 Subject: [PATCH 040/155] `tpm.ImplicitTPM`: implement __len__ method. --- pyphi/tpm.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index f92cd995f..0b2902817 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -760,6 +760,10 @@ def __getitem__(self, index): ) ) + def __len__(self): + """int: The number of nodes in the TPM.""" + return len(self._nodes) + def __str__(self): return self.__repr__() From 150b8c335a5d8bcb32e755d074fb27a2ccdc6b97 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 13:14:57 -0600 Subject: [PATCH 041/155] `Node.project_index`: make positional indexing consistent with name indexing. 
Preservation of singleton dimensions must be accounted for by tpm.condition_tpm --- pyphi/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/node.py b/pyphi/node.py index c8dd229f6..072418f85 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -172,7 +172,7 @@ def project_index(self, index): projected_index = tuple( i if support != (SINGLETON_COORDINATE,) - else slice(None) + else 0 for i, support in index_support_map ) From 8d898797fbe3596d92d78655b074929c6a4be9ce Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 13:23:51 -0600 Subject: [PATCH 042/155] `tpm.ImplicitTPM: implement ndim property` --- pyphi/tpm.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 0b2902817..7bbaa61d6 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -643,6 +643,11 @@ def nodes(self): """Tuple[xr.DataArray]: The node TPMs in this ImplicitTPM""" return self._nodes + @property + def ndim(self): + """int: The number of dimensions of the TPM.""" + return len(self) + 1 + def validate(self, check_independence=True): """Validate this TPM.""" return self._validate_probabilities() @@ -700,7 +705,7 @@ def condition_tpm(self, condition: Mapping[int, int]): """ node_dimensions = [ state_space.INPUT_DIMENSION_PREFIX + node.label - for node in self.nodes + for node in self._nodes ] conditioning_index = { From a615250a795224645f77b1eb38f993c72923fd02 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Feb 2023 14:48:46 -0600 Subject: [PATCH 043/155] Modularize code to compute the shape of an `ImplicitTPM` --- pyphi/network.py | 10 +--------- pyphi/tpm.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 10 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index f67d9e561..d7c854a7a 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -94,15 +94,7 @@ def __init__( ) shapes = [node.shape for node in tpm] - - if not all(len(shape) == len(shapes[0]) for shape in 
shapes): - raise ValueError( - "The provided node TPMs contain varying number of dimensions." - ) - - network_tpm_shape = [ - max(shape[i] for shape in shapes) for i in range(len(shapes[0])) - ] + network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) self._state_space, _ = build_state_space( self._node_labels, diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 7bbaa61d6..60d7d7040 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -7,7 +7,7 @@ """ from itertools import chain -from typing import Mapping, Set, Tuple +from typing import Iterable, Mapping, Set, Tuple import numpy as np import xarray as xr @@ -648,6 +648,37 @@ def ndim(self): """int: The number of dimensions of the TPM.""" return len(self) + 1 + @property + def shape(self): + """Tuple[int]: The size or number of coordinates in each dimension.""" + shapes = [node.shape for node in self._nodes] + return self._node_shapes_to_shape(shapes) + + @staticmethod + def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: + """Infer the shape of the equivalent multidimensional |ExplicitTPM|. + + Args: + shapes (Iterable[Iterable[int]]): The shapes of the individual node + TPMs in the network, ordered by node index. + + Returns: + Tuple[int]: The inferred shape of the equivalent TPM. + """ + # This should recompute the network TPM shape from individual node + # shapes, as opposed to measuring the size of the state space. + + if not all(len(shape) == len(shapes[0]) for shape in shapes): + raise ValueError( + "The provided shapes contain varying number of dimensions." 
+ ) + + network_tpm_shape = tuple( + max(shape[i] for shape in shapes) for i in range(len(shapes[0])) + ) + + return network_tpm_shape + def validate(self, check_independence=True): """Validate this TPM.""" return self._validate_probabilities() From 92b216d4ff05a64afd1f7433b5e86eec16f716bc Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Feb 2023 14:51:36 -0600 Subject: [PATCH 044/155] `tpm`: refactor `ImplicitTPM.condition_tpm()` --- pyphi/node.py | 16 +++++++++++----- pyphi/tpm.py | 40 ++++++++++++++++++---------------------- 2 files changed, 29 insertions(+), 27 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 072418f85..e038c892e 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -141,7 +141,7 @@ def state(self, value): self._state = value - def project_index(self, index): + def project_index(self, index, preserve_singletons=False): """Convert absolute TPM index to a valid index relative to this node.""" # Supported index coordinates (in the right dimension order) respective @@ -155,24 +155,30 @@ def project_index(self, index): } if isinstance(index, dict): + singleton_coordinate = ( + [SINGLETON_COORDINATE] if preserve_singletons + else SINGLETON_COORDINATE + ) projected_index = { key: value if support[key] != (SINGLETON_COORDINATE,) - else SINGLETON_COORDINATE + else singleton_coordinate for key, value in index.items() } return projected_index # Assume regular index otherwise. + if not isinstance(index, tuple): - # Index is a single int, slice, ellipsis or intra-dimension list. + # Index is a single int, slice, ellipsis, etc. Make it + # amenable to zip(). 
index = (index,) index_support_map = zip(index, support.values()) - + singleton_coordinate = [0] if preserve_singletons else 0 projected_index = tuple( i if support != (SINGLETON_COORDINATE,) - else 0 + else singleton_coordinate for i, support in index_support_map ) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 60d7d7040..656cfb36a 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -41,7 +41,7 @@ def to_multidimensional_state_by_node(self): def conditionally_independent(self): raise NotImplementedError - def condition_tpm(self, condition: Mapping[int, int]): + def condition_tpm(self, condition): raise NotImplementedError def marginalize_out(self, node_indices): @@ -465,12 +465,12 @@ def condition_tpm(self, condition: Mapping[int, int]): conditioning_indices = tuple(chain.from_iterable(conditioning_indices)) # Obtain the actual conditioned TPM by indexing with the conditioning # indices. - tpm = self._tpm[conditioning_indices] + tpm = self[conditioning_indices] # Create new TPM object of the same type as self. # self.tpm has already been validated and converted to multidimensional # state-by-node form. Further validation would be problematic for # singleton dimensions. - return type(self)(tpm) + return tpm def marginalize_out(self, node_indices): """Marginalize out nodes from this TPM. @@ -636,7 +636,7 @@ class ImplicitTPM(TPM): """ def __init__(self, nodes: Tuple[xr.DataArray]): - self._nodes = nodes + self._nodes = tuple(nodes) @property def nodes(self): @@ -713,10 +713,7 @@ def to_multidimensional_state_by_node(self): def conditionally_independent(self): raise NotImplementedError - # TODO accept label-state mapping as argument. Current solution - # relies on correct node order in data_vars. - # - # TODO(tpm) Refactor codebase to call tpm[Mapping] directly? + # TODO(tpm) accept node labels and state labels in the map. 
def condition_tpm(self, condition: Mapping[int, int]): """Return a TPM conditioned on the given fixed node indices, whose states are fixed according to the given state-tuple. @@ -734,18 +731,17 @@ def condition_tpm(self, condition: Mapping[int, int]): TPM: A conditioned TPM with the same number of dimensions, with singleton dimensions for nodes in a fixed state. """ - node_dimensions = [ - state_space.INPUT_DIMENSION_PREFIX + node.label - for node in self._nodes - ] - - conditioning_index = { - node_dimensions[node_index]: state - for node_index, state in condition.items() - } + sorted_index = [state_i for i, state_i in sorted(condition.items())] + + # Wrapping index elements in a list is the xarray equivalent + # of inserting a numpy.newaxis, which preserves the singleton even + # after selection of a single state. + conditioning_indices = tuple( + (state_i if isinstance(state_i, list) else [state_i]) + for state_i in sorted_index + ) - # TODO: broadcasting - return self[conditioning_index] + return self.__getitem__(conditioning_indices, preserve_singletons=True) def marginalize_out(self, node_indices): raise NotImplementedError @@ -780,18 +776,18 @@ def print(self): def permute_nodes(self, permutation): raise NotImplementedError - def __getitem__(self, index): + def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): return ImplicitTPM( tuple( - node[node.pyphi.project_index(index)] + node[node.pyphi.project_index(index, **kwargs)] for node in self.nodes ) ) if isinstance(index, dict): return ImplicitTPM( tuple( - node.loc[node.pyphi.project_index(index)] + node.loc[node.pyphi.project_index(index, **kwargs)] for node in self.nodes ) ) From 102691da00d03a8d13789e8aaa1b9c14a26a2334 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Feb 2023 16:32:21 -0600 Subject: [PATCH 045/155] Handle potential KeyError if user passes nonexisting dimension labels. 
--- pyphi/node.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index e038c892e..68a39f8ae 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -145,10 +145,10 @@ def project_index(self, index, preserve_singletons=False): """Convert absolute TPM index to a valid index relative to this node.""" # Supported index coordinates (in the right dimension order) respective - # to this node, to be used as an AND mask, with 0 being - # SINGLETON_COORDINATE. + # to this node, to be used like an AND mask, with 0 being + # `singleton_coordinate`. - # TODO(tpm) make a Node attribute? (similar to `state_space`). + # TODO(tpm) make this a Node attribute? (similar to `state_space`). support = { dim: tuple(self._dataarray.coords[dim].values) for dim in self._dataarray.dims @@ -159,11 +159,17 @@ def project_index(self, index, preserve_singletons=False): [SINGLETON_COORDINATE] if preserve_singletons else SINGLETON_COORDINATE ) - projected_index = { - key: value if support[key] != (SINGLETON_COORDINATE,) - else singleton_coordinate - for key, value in index.items() - } + try: + projected_index = { + key: value if support[key] != (SINGLETON_COORDINATE,) + else singleton_coordinate + for key, value in index.items() + } + except KeyError as e: + raise ValueError( + "Dimension {} does not exist. Expected one or more of: " + "{}.".format(e, self._dataarray.dims) + ) return projected_index From 2efa780727f16224ac5ba2452fa21b4b870b7bd2 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Feb 2023 16:55:53 -0600 Subject: [PATCH 046/155] `Subsystem`: remove unused property. --- pyphi/subsystem.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 554f395a6..6d245eb11 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -108,10 +108,7 @@ def __init__( # Get the TPM conditioned on the state of the external nodes. 
external_state = utils.state_of(self.external_indices, self.state) background_conditions = dict(zip(self.external_indices, external_state)) - self.tpm = self.network.tpm.pyphi.condition_tpm(background_conditions) - # The TPM for just the nodes in the subsystem. - labels_in_subsystem = self.node_labels.indices2labels(self.node_indices) - self.proper_tpm = self.tpm[list(labels_in_subsystem)] + self.tpm = self.network.tpm.condition_tpm(background_conditions) # The unidirectional cut applied for phi evaluation self.cut = ( From cb6f289f8af549026bf9842838f40a8408f9ed67 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Feb 2023 14:58:47 -0600 Subject: [PATCH 047/155] Fix `ImplicitTPM.condition_tpm` and generalize argument --- pyphi/node.py | 21 ++++++++++++++------- pyphi/subsystem.py | 8 ++------ pyphi/tpm.py | 11 ++++------- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 68a39f8ae..3a0a70644 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -147,28 +147,35 @@ def project_index(self, index, preserve_singletons=False): # Supported index coordinates (in the right dimension order) respective # to this node, to be used like an AND mask, with 0 being # `singleton_coordinate`. - + dimensions = self._dataarray.dims + coordinates = self._dataarray.coords # TODO(tpm) make this a Node attribute? (similar to `state_space`). - support = { - dim: tuple(self._dataarray.coords[dim].values) - for dim in self._dataarray.dims - } + support = {dim: tuple(coordinates[dim].values) for dim in dimensions} if isinstance(index, dict): singleton_coordinate = ( [SINGLETON_COORDINATE] if preserve_singletons else SINGLETON_COORDINATE ) + try: + # Convert potential int dimension indices to common currency of + # string dimension labels. 
+ keys = [ + k if isinstance(k, str) else dimensions[k] + for k in index.keys() + ] + projected_index = { key: value if support[key] != (SINGLETON_COORDINATE,) else singleton_coordinate - for key, value in index.items() + for key, value in zip(keys, index.values()) } + except KeyError as e: raise ValueError( "Dimension {} does not exist. Expected one or more of: " - "{}.".format(e, self._dataarray.dims) + "{}.".format(e, dimensions) ) return projected_index diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 6d245eb11..72cf6320d 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -135,9 +135,7 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = tuple( - node.pyphi.streamline() for node in self.tpm.data_vars.values() - ) + self.nodes = self.tpm.nodes # validate.subsystem(self) @@ -153,9 +151,7 @@ def nodes(self, value): """ # pylint: disable=attribute-defined-outside-init self._nodes = value - self._index2node = { - node.pyphi.index: node.pyphi.streamline() for node in self._nodes - } + self._index2node = {node.pyphi.index: node for node in self._nodes} @property def proper_state(self): diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 656cfb36a..778f6665f 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -703,7 +703,6 @@ def _validate_probabilities(self): ): return True - def _validate_shape(self, check_independence=True): raise NotImplementedError @@ -731,15 +730,13 @@ def condition_tpm(self, condition: Mapping[int, int]): TPM: A conditioned TPM with the same number of dimensions, with singleton dimensions for nodes in a fixed state. """ - sorted_index = [state_i for i, state_i in sorted(condition.items())] - # Wrapping index elements in a list is the xarray equivalent # of inserting a numpy.newaxis, which preserves the singleton even # after selection of a single state. 
- conditioning_indices = tuple( - (state_i if isinstance(state_i, list) else [state_i]) - for state_i in sorted_index - ) + conditioning_indices = { + i: (state_i if isinstance(state_i, list) else [state_i]) + for i, state_i in condition.items() + } return self.__getitem__(conditioning_indices, preserve_singletons=True) From a3deedf3da9d0c346e6f6abecd6d7438c3e91cec Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Feb 2023 16:07:26 -0600 Subject: [PATCH 048/155] Filter `.nodes` attribute based on membership to `Subsystem.node_indices` --- pyphi/subsystem.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 72cf6320d..56267a831 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -135,7 +135,10 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = self.tpm.nodes + self.nodes = tuple( + node.pyphi for i, node in enumerate(self.tpm.nodes) + if i in self.node_indices + ) # validate.subsystem(self) From d0c4787fe903367fd7475a70da2e1bf5b2569321 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Feb 2023 17:40:14 -0600 Subject: [PATCH 049/155] Call `Node` directly and `DataArray` through callback-like reference to parent --- pyphi/actual.py | 2 +- pyphi/macro.py | 14 +++++++------- pyphi/network.py | 2 +- pyphi/node.py | 12 ++++++++++-- pyphi/subsystem.py | 12 ++++++------ pyphi/tpm.py | 17 ++++++++++------- pyphi/validate.py | 2 +- test/test_node.py | 30 +++++++++++++++--------------- 8 files changed, 51 insertions(+), 40 deletions(-) diff --git a/pyphi/actual.py b/pyphi/actual.py index f21cea504..428c38763 100644 --- a/pyphi/actual.py +++ b/pyphi/actual.py @@ -153,7 +153,7 @@ def __init__( self.cause_system.state = after_state for node in self.cause_system.nodes: - node.pyphi.state = after_state[node.pyphi.index] + node.state = after_state[node.index] # Validate the cause system # The state of the effect system does not need to be reachable diff 
--git a/pyphi/macro.py b/pyphi/macro.py index 78132525c..fc29bab21 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -79,9 +79,9 @@ def run_tpm(system, steps, blackbox): node_tpms = [] for node in system.nodes: # TODO: nonbinary nodes. - node_tpm = node.pyphi.tpm[..., 1] - for input_node in node.pyphi.inputs: - if not blackbox.in_same_box(node.pyphi.index, input_node): + node_tpm = node.tpm[..., 1] + for input_node in node.inputs: + if not blackbox.in_same_box(node.index, input_node): if input_node in blackbox.output_indices: node_tpm = node_tpm.marginalize_out([input_node]) @@ -264,7 +264,7 @@ def _squeeze(system): # Re-calcuate the tpm based on the results of the cut # TODO: nonbinary nodes. - tpm = rebuild_system_tpm(node.pyphi.tpm[..., 1] for node in nodes) + tpm = rebuild_system_tpm(node.tpm[..., 1] for node in nodes) return SystemAttrs(tpm, cm, node_indices, state, state_space) @@ -275,9 +275,9 @@ def _blackbox_partial_noise(blackbox, system): node_tpms = [] for node in system.nodes: # TODO: nonbinary nodes. 
- node_tpm = node.pyphi.tpm[..., 1] - for input_node in node.pyphi.inputs: - if blackbox.hidden_from(input_node, node.pyphi.index): + node_tpm = node.tpm[..., 1] + for input_node in node.inputs: + if blackbox.hidden_from(input_node, node.index): node_tpm = node_tpm.marginalize_out([input_node]) node_tpms.append(node_tpm) diff --git a/pyphi/network.py b/pyphi/network.py index d7c854a7a..689d062c4 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -110,7 +110,7 @@ def __init__( self._state_space, index, node_labels=self._node_labels - ) + ).pyphi for index, node_tpm in zip(self._node_indices, tpm) ) ) diff --git a/pyphi/node.py b/pyphi/node.py index 3a0a70644..cf14e293e 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -62,7 +62,7 @@ def __init__(self, dataarray: xr.DataArray): self._outputs = dataarray.attrs["outputs"] self._dataarray = dataarray - self._tpm = dataarray.data + self._tpm = self._dataarray self.state_space = dataarray.attrs["state_space"] @@ -91,6 +91,11 @@ def label(self): """str: The textual label for this node.""" return self._node_labels[self.index] + @property + def dataarray(self): + """|xr.DataArray|: The xarray DataArray for this node.""" + return self._dataarray + @property def tpm(self): """|ExplicitTPM|: The TPM of this node.""" @@ -197,6 +202,9 @@ def project_index(self, index, preserve_singletons=False): return projected_index + def __getitem__(self, index): + return self._dataarray[index].pyphi + def streamline(self): """Remove superfluous coordinates from an unaligned |Node| TPM. 
@@ -412,7 +420,7 @@ def generate_nodes( index, state=state, node_labels=node_labels - ) + ).pyphi ) return tuple(nodes) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 56267a831..92b4a8f66 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -136,7 +136,7 @@ def __init__( ) self.nodes = tuple( - node.pyphi for i, node in enumerate(self.tpm.nodes) + node for i, node in enumerate(self.tpm.nodes) if i in self.node_indices ) @@ -154,7 +154,7 @@ def nodes(self, value): """ # pylint: disable=attribute-defined-outside-init self._nodes = value - self._index2node = {node.pyphi.index: node for node in self._nodes} + self._index2node = {node.index: node for node in self._nodes} @property def proper_state(self): @@ -325,10 +325,10 @@ def _single_node_cause_repertoire(self, mechanism_node_index, purview): mechanism_node = self._index2node[mechanism_node_index] # We're conditioning on this node's state, so take the TPM for the node # being in that state. - tpm = mechanism_node.pyphi.tpm[..., mechanism_node.state] + tpm = mechanism_node.tpm[..., mechanism_node.state] # Marginalize-out all parents of this mechanism node that aren't in the # purview. - return tpm.marginalize_out((mechanism_node.pyphi.inputs - purview)).tpm + return tpm.marginalize_out((mechanism_node.inputs - purview)).tpm # TODO extend to nonbinary nodes @cache.method("_repertoire_cache", Direction.CAUSE) @@ -388,10 +388,10 @@ def _single_node_effect_repertoire( # pylint: disable=missing-docstring purview_node = self._index2node[purview_node_index] # Condition on the state of the purview inputs that are in the mechanism - tpm = purview_node.pyphi.tpm.condition_tpm(condition) + tpm = purview_node.tpm.condition_tpm(condition) # TODO(4.0) remove reference to TPM # Marginalize-out the inputs that aren't in the mechanism. 
- nonmechanism_inputs = purview_node.pyphi.inputs - set(condition) + nonmechanism_inputs = purview_node.inputs - set(condition) tpm = tpm.marginalize_out(nonmechanism_inputs) # Reshape so that the distribution is over next states. return tpm.reshape( diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 778f6665f..d7eea7da9 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -10,9 +10,8 @@ from typing import Iterable, Mapping, Set, Tuple import numpy as np -import xarray as xr -from . import config, convert, data_structures, exceptions, state_space +from . import config, convert, data_structures, exceptions from .constants import OFF, ON from .data_structures import FrozenMap from .utils import all_states, np_hash, np_immutable @@ -635,7 +634,10 @@ class ImplicitTPM(TPM): Attributes: """ - def __init__(self, nodes: Tuple[xr.DataArray]): + def __init__(self, nodes): + """Args: + nodes (pyphi.node.Node) + """ self._nodes = tuple(nodes) @property @@ -651,7 +653,7 @@ def ndim(self): @property def shape(self): """Tuple[int]: The size or number of coordinates in each dimension.""" - shapes = [node.shape for node in self._nodes] + shapes = [node.tpm.shape for node in self._nodes] return self._node_shapes_to_shape(shapes) @staticmethod @@ -777,17 +779,18 @@ def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): return ImplicitTPM( tuple( - node[node.pyphi.project_index(index, **kwargs)] + node.dataarray[node.project_index(index, **kwargs)].pyphi for node in self.nodes ) ) if isinstance(index, dict): return ImplicitTPM( tuple( - node.loc[node.pyphi.project_index(index, **kwargs)] + node.dataarray.loc[node.project_index(index, **kwargs)].pyphi for node in self.nodes ) ) + raise TypeError(f"Invalid index {index} of type {type(index)}.") def __len__(self): """int: The number of nodes in the TPM.""" @@ -808,7 +811,7 @@ def reconstitute_tpm(subsystem): # The last axis of the node TPMs correponds to ON or OFF probabilities # (used in the conditioning 
step when calculating the repertoires); we want # ON probabilities. - node_tpms = [node.pyphi.tpm[..., 1] for node in subsystem.nodes] + node_tpms = [node.tpm[..., 1] for node in subsystem.nodes] # Remove the singleton dimensions corresponding to external nodes node_tpms = [tpm.squeeze(axis=subsystem.external_indices) for tpm in node_tpms] # We add a new singleton axis at the end so that we can use diff --git a/pyphi/validate.py b/pyphi/validate.py index 38d5782db..14bbe5d70 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -62,7 +62,7 @@ def network(n): Checks the TPM and connectivity matrix. """ - n.tpm.pyphi.validate() + n.tpm.validate() connectivity_matrix(n.cm) if n.cm.shape[0] != n.size: raise ValueError( diff --git a/test/test_node.py b/test/test_node.py index abdd3d8ad..fbcf0df34 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -32,13 +32,13 @@ def test_node_init_tpm(s): answer = [ExplicitTPM(tpm) for tpm in answer] # fmt: on for node in s.nodes: - assert node.pyphi.tpm.array_equal(answer[node.pyphi.index]) + assert node.tpm.array_equal(answer[node.index]) def test_node_init_inputs(s): answer = [s.node_indices[1:], s.node_indices[2:3], s.node_indices[:2]] for node in s.nodes: - assert set(node.pyphi.inputs) == set(answer[node.pyphi.index]) + assert set(node.inputs) == set(answer[node.index]) def test_node_eq(s): @@ -101,10 +101,10 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[0].pyphi.tpm.array_equal(node0_tpm) - assert nodes[0].pyphi.inputs == set([1, 2]) - assert nodes[0].pyphi.outputs == set([2]) - assert nodes[0].pyphi.label == "A" + assert nodes[0].tpm.array_equal(node0_tpm) + assert nodes[0].inputs == set([1, 2]) + assert nodes[0].outputs == set([2]) + assert nodes[0].label == "A" # fmt: off node1_tpm = ExplicitTPM( @@ -114,10 +114,10 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[1].pyphi.tpm.array_equal(node1_tpm) - assert nodes[1].pyphi.inputs == set([2]) - assert nodes[1].pyphi.outputs == set([0, 2]) - 
assert nodes[1].pyphi.label == "B" + assert nodes[1].tpm.array_equal(node1_tpm) + assert nodes[1].inputs == set([2]) + assert nodes[1].outputs == set([0, 2]) + assert nodes[1].label == "B" # fmt: off node2_tpm = ExplicitTPM( @@ -129,14 +129,14 @@ def test_generate_nodes(s): ]) ) # fmt: on - assert nodes[2].pyphi.tpm.array_equal(node2_tpm) - assert nodes[2].pyphi.inputs == set([0, 1]) - assert nodes[2].pyphi.outputs == set([0, 1]) - assert nodes[2].pyphi.label == "C" + assert nodes[2].tpm.array_equal(node2_tpm) + assert nodes[2].inputs == set([0, 1]) + assert nodes[2].outputs == set([0, 1]) + assert nodes[2].label == "C" def test_generate_nodes_default_labels(s): nodes = generate_nodes( s.tpm, s.cm, s.state_space, s.node_indices, network_state=s.state ) - assert [n.pyphi.label for n in nodes] == ["n0", "n1", "n2"] + assert [n.label for n in nodes] == ["n0", "n1", "n2"] From 2630552b555c76e706e5bb648d7d8e6f7d9b1de0 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 13:41:04 -0600 Subject: [PATCH 050/155] Network.__init__(): Verify nodes against CM --- pyphi/network.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 8f36b06db..574f2f8e2 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -9,7 +9,6 @@ from typing import Iterable import numpy as np -import xarray as xr from . 
import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels @@ -99,7 +98,12 @@ def __init__( if not all(len(shape) == len(shapes[0]) for shape in shapes): raise ValueError("Provided set of nodes contains varying number of dimensions.") - + + for i, shape in enumerate(shapes): + for j, val in enumerate(self.cm[i]): + if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): + raise ValueError(f"Node shape {shape[j]} does not correspond to connectivity matrix at index [{i}][{j}].") + network_tpm_shape = [max(shape[i] for shape in shapes) for i in range(len(shapes[0]))] self.state_space, _ = build_state_space( From 8d27f2c82d10314035ad34cd900e13a4fbf2b89c Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 13:44:16 -0600 Subject: [PATCH 051/155] Fix bad merge --- pyphi/network.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index e27fde75f..890cc8f12 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -94,15 +94,12 @@ def __init__( shapes = [node.shape for node in tpm] - if not all(len(shape) == len(shapes[0]) for shape in shapes): - raise ValueError("Provided set of nodes contains varying number of dimensions.") - for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): raise ValueError(f"Node shape {shape[j]} does not correspond to connectivity matrix at index [{i}][{j}].") - network_tpm_shape = [max(shape[i] for shape in shapes) for i in range(len(shapes[0]))] + network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) self.state_space, _ = build_state_space( self._node_labels, From 074dcfbb7e016f78494bddc016318ad6746f3710 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 14:09:41 -0600 Subject: [PATCH 052/155] Network: Include attrs in repr, remove redundancy --- pyphi/network.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/pyphi/network.py b/pyphi/network.py index 890cc8f12..e51cf91c1 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -237,10 +237,9 @@ def __len__(self): def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, # distinctions, etc. - return "Network({}, cm={})".format(self.tpm, self.cm) - - def __str__(self): - return self.__repr__() + return "Network({}, cm={}, node_labels={}, state_space={}, purview_cache={})".format( + self.tpm, self.cm, self.node_labels, self.state_space, self.purview_cache + ) def __eq__(self, other): """Return whether this network equals the other object. From eb5fac61cb06fb3fd8576c0452aeca0c9994e867 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 14:28:07 -0600 Subject: [PATCH 053/155] Network.__init__(): Exclude cache --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index e51cf91c1..988df0606 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -237,7 +237,7 @@ def __len__(self): def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, # distinctions, etc. 
- return "Network({}, cm={}, node_labels={}, state_space={}, purview_cache={})".format( + return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={},\npurview_cache={}\n)".format( self.tpm, self.cm, self.node_labels, self.state_space, self.purview_cache ) From 0621de50eabdf4d8be302906a431914158ef0453 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 14:43:53 -0600 Subject: [PATCH 054/155] Network.__init__(): Properly grab CM columns --- pyphi/network.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 988df0606..32cd5d384 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -95,7 +95,7 @@ def __init__( shapes = [node.shape for node in tpm] for i, shape in enumerate(shapes): - for j, val in enumerate(self.cm[i]): + for j, val in enumerate(self.cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): raise ValueError(f"Node shape {shape[j]} does not correspond to connectivity matrix at index [{i}][{j}].") @@ -237,8 +237,8 @@ def __len__(self): def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, # distinctions, etc. 
- return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={},\npurview_cache={}\n)".format( - self.tpm, self.cm, self.node_labels, self.state_space, self.purview_cache + return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={}\n)".format( + self.tpm, self.cm, self.node_labels, self.state_space ) def __eq__(self, other): From d700e87854020e5098c4ec6cf90dde3bc83fb7dd Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 14:46:02 -0600 Subject: [PATCH 055/155] Update error message --- pyphi/network.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 32cd5d384..f8f6dd8fb 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -97,7 +97,9 @@ def __init__( for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError(f"Node shape {shape[j]} does not correspond to connectivity matrix at index [{i}][{j}].") + raise ValueError( + f"Node shape {shape[j]} does not correspond to connectivity matrix." + ) network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) From 05165c5e4c32227036bc2a0c767de25181bc8493 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 14:52:42 -0600 Subject: [PATCH 056/155] Fix operator, error message --- pyphi/network.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index f8f6dd8fb..eddb96371 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -97,9 +97,7 @@ def __init__( for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError( - f"Node shape {shape[j]} does not correspond to connectivity matrix." 
- ) + raise ValueError(f"Node shape of {shape[j]} does not correspond to connectivity matrix.") network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) From 63ba053f7dd36c9859b8dfa540a066ed9d1223c8 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 15:09:45 -0600 Subject: [PATCH 057/155] Fix _state_space assignment --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index eddb96371..ebbe39027 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -101,7 +101,7 @@ def __init__( network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) - self.state_space, _ = build_state_space( + self._state_space, _ = build_state_space( self._node_labels, network_tpm_shape, state_space From 9da3677cddc3d403cf7cf25e1fb9dccb08324636 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 27 Feb 2023 15:52:58 -0600 Subject: [PATCH 058/155] Fix bug in `tpm.ImplicitTPM.shape` --- pyphi/tpm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index d7eea7da9..68741aa86 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -675,11 +675,12 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: "The provided shapes contain varying number of dimensions." 
) - network_tpm_shape = tuple( - max(shape[i] for shape in shapes) for i in range(len(shapes[0])) + number_of_nodes = len(shapes) + shape_from_inputs = tuple( + max(shape[i] for shape in shapes) for i in range(number_of_nodes) ) - return network_tpm_shape + return shape_from_inputs + (number_of_nodes,) def validate(self, check_independence=True): """Validate this TPM.""" From 3c32d3463c117d3fe2ebad56ab88788766a87c84 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 16:04:24 -0600 Subject: [PATCH 059/155] Network._build_cm(): Account for multiple nodes --- pyphi/network.py | 34 +++++++++++++++++++++++----------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index ebbe39027..d6b12f342 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -49,10 +49,6 @@ def __init__( state_space=None, purview_cache=None ): - self._cm, self._cm_hash = self._build_cm(cm, tpm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) - # Initialize _tpm according to argument type. if isinstance(tpm, (np.ndarray, ExplicitTPM)): @@ -61,6 +57,10 @@ def __init__( # np.ndarray, so the following achieves validation in general (and # converstion to multidimensional form, as a side effect). 
tpm = ExplicitTPM(tpm, validate=True) + self._cm, self._cm_hash = self._build_cm(cm, tpm) + + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) self._state_space, _ = build_state_space( self._node_labels, @@ -94,13 +94,13 @@ def __init__( shapes = [node.shape for node in tpm] - for i, shape in enumerate(shapes): - for j, val in enumerate(self.cm[..., i]): - if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError(f"Node shape of {shape[j]} does not correspond to connectivity matrix.") + self._cm, self._cm_hash = self._build_cm(cm, tpm, shapes) network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + self._state_space, _ = build_state_space( self._node_labels, network_tpm_shape, @@ -152,7 +152,7 @@ def cm(self): """ return self._cm - def _build_cm(self, cm, tpm): + def _build_cm(self, cm, tpm, shapes=None): """Convert the passed CM to the proper format, or construct the unitary CM if none was provided. """ @@ -162,9 +162,21 @@ def _build_cm(self, cm, tpm): except AttributeError: size = len(tpm) - # Assume all are connected. - cm = np.ones((size, size)) + if shapes is None: + # Assume all are connected. 
+ cm = np.ones((size, size)) + else: + cm = np.zeros((len(shapes), len(shapes)), dtype=int) + + for i, shape in enumerate(shapes): + for j in range(len(shapes)): + cm[j][i] = 0 if shape[j] == 1 else 1 else: + for i, shape in enumerate(shapes): + for j, val in enumerate(self.cm[..., i]): + if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): + raise ValueError(f"Node shape of {shape[j]} does not correspond to connectivity matrix.") + cm = np.array(cm) utils.np_immutable(cm) From 6e3fd4867ed22b48ad9a2d58655691b747cb91ac Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 16:05:59 -0600 Subject: [PATCH 060/155] Network.__init__(): Assign attributes in all cases --- pyphi/network.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyphi/network.py b/pyphi/network.py index d6b12f342..971ca8d99 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -122,11 +122,17 @@ def __init__( elif isinstance(tpm, ImplicitTPM): self._tpm = tpm + self._cm, self._cm_hash = self._build_cm(cm, tpm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): # From JSON. 
self._tpm = ImplicitTPM(tpm["_tpm"]) + self._cm, self._cm_hash = self._build_cm(cm, tpm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) else: raise TypeError(f"Invalid TPM of type {type(tpm)}.") From 694b681af8f34e7ac9bc18bfafcada9bfa0bcf90 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 27 Feb 2023 16:15:44 -0600 Subject: [PATCH 061/155] Fix redundancy --- pyphi/network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 971ca8d99..f2afbef45 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -176,7 +176,8 @@ def _build_cm(self, cm, tpm, shapes=None): for i, shape in enumerate(shapes): for j in range(len(shapes)): - cm[j][i] = 0 if shape[j] == 1 else 1 + if shape[j] != 1: + cm[j][i] = 1 else: for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[..., i]): From 95e7b31a76dbe6b78a3354da6b13984b3812eda6 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 13:19:30 -0600 Subject: [PATCH 062/155] network.py: format source code --- pyphi/network.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index f2afbef45..848dbce5a 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -58,7 +58,7 @@ def __init__( # converstion to multidimensional form, as a side effect). 
tpm = ExplicitTPM(tpm, validate=True) self._cm, self._cm_hash = self._build_cm(cm, tpm) - + self._node_indices = tuple(range(self.size)) self._node_labels = NodeLabels(node_labels, self._node_indices) @@ -93,14 +93,14 @@ def __init__( ) shapes = [node.shape for node in tpm] - + self._cm, self._cm_hash = self._build_cm(cm, tpm, shapes) network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) - + self._node_indices = tuple(range(self.size)) self._node_labels = NodeLabels(node_labels, self._node_indices) - + self._state_space, _ = build_state_space( self._node_labels, network_tpm_shape, @@ -173,7 +173,7 @@ def _build_cm(self, cm, tpm, shapes=None): cm = np.ones((size, size)) else: cm = np.zeros((len(shapes), len(shapes)), dtype=int) - + for i, shape in enumerate(shapes): for j in range(len(shapes)): if shape[j] != 1: @@ -182,8 +182,11 @@ def _build_cm(self, cm, tpm, shapes=None): for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError(f"Node shape of {shape[j]} does not correspond to connectivity matrix.") - + raise ValueError( + "Node TPM {} of shape {} does not match the " + "connectivity matrix.".format(i, shape) + ) + cm = np.array(cm) utils.np_immutable(cm) From 8425577ac54f3fbcb4693140a949cbf02c5ddf3c Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 13:38:20 -0600 Subject: [PATCH 063/155] Fix bug in `Network._build_cm()` for explicit TPMs. 
--- pyphi/network.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 848dbce5a..3c1007b7a 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -179,6 +179,12 @@ def _build_cm(self, cm, tpm, shapes=None): if shape[j] != 1: cm[j][i] = 1 else: + cm = np.array(cm) + utils.np_immutable(cm) + + if shapes is None: + return (cm, utils.np_hash(cm)) + for i, shape in enumerate(shapes): for j, val in enumerate(self.cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): @@ -187,10 +193,6 @@ def _build_cm(self, cm, tpm, shapes=None): "connectivity matrix.".format(i, shape) ) - cm = np.array(cm) - - utils.np_immutable(cm) - return (cm, utils.np_hash(cm)) @property From fa535ad21c7a26f621829a2c832fe469846b96fe Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 14:03:38 -0600 Subject: [PATCH 064/155] Cosmetic refactoring of `Network._build_cm()` --- pyphi/network.py | 54 +++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 3c1007b7a..69e14646d 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -160,38 +160,46 @@ def cm(self): def _build_cm(self, cm, tpm, shapes=None): """Convert the passed CM to the proper format, or construct the - unitary CM if none was provided. + unitary CM if none was provided (explicit TPM), or infer from node TPMs. """ if cm is None: try: - size = tpm.shape[-1] + network_size = tpm.shape[-1] except AttributeError: - size = len(tpm) - - if shapes is None: - # Assume all are connected. - cm = np.ones((size, size)) - else: - cm = np.zeros((len(shapes), len(shapes)), dtype=int) - - for i, shape in enumerate(shapes): - for j in range(len(shapes)): - if shape[j] != 1: - cm[j][i] = 1 - else: - cm = np.array(cm) - utils.np_immutable(cm) + network_size = len(tpm) + # Explicit TPM without connectivity matrix: assume all are connected. 
if shapes is None: + cm = np.ones((network_size, network_size), dtype=int) + utils.np_immutable(cm) return (cm, utils.np_hash(cm)) + # ImplicitTPM without connectivity matrix: infer from node TPMs. + cm = np.zeros((network_size, network_size), dtype=int) + for i, shape in enumerate(shapes): - for j, val in enumerate(self.cm[..., i]): - if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError( - "Node TPM {} of shape {} does not match the " - "connectivity matrix.".format(i, shape) - ) + for j in range(len(shapes)): + if shape[j] != 1: + cm[j][i] = 1 + + utils.np_immutable(cm) + return (cm, utils.np_hash(cm)) + + cm = np.array(cm) + utils.np_immutable(cm) + + # Explicit TPM with connectivity matrix: return. + if shapes is None: + return (cm, utils.np_hash(cm)) + + # ImplicitTPM with connectivity matrix: validate against node TPM shapes. + for i, shape in enumerate(shapes): + for j, val in enumerate(self._cm[..., i]): + if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): + raise ValueError( + "Node TPM {} of shape {} does not match the connectivity " + " matrix.".format(i, shape) + ) return (cm, utils.np_hash(cm)) From 75e5fe184a5df4244668f2cdf00722a656676cee Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 14:04:48 -0600 Subject: [PATCH 065/155] Don't reference `self.cm` before it has been set. --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 69e14646d..1e93d26c6 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -194,7 +194,7 @@ def _build_cm(self, cm, tpm, shapes=None): # ImplicitTPM with connectivity matrix: validate against node TPM shapes. 
for i, shape in enumerate(shapes): - for j, val in enumerate(self._cm[..., i]): + for j, val in enumerate(cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): raise ValueError( "Node TPM {} of shape {} does not match the connectivity " From 1a6d7ecd483e8838e8cad21882835b8ed41af1b7 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:11:33 -0600 Subject: [PATCH 066/155] _build_cm(): Should use cm param, not self.cm --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index f2afbef45..5ef384a2f 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -180,7 +180,7 @@ def _build_cm(self, cm, tpm, shapes=None): cm[j][i] = 1 else: for i, shape in enumerate(shapes): - for j, val in enumerate(self.cm[..., i]): + for j, val in enumerate(cm[..., i]): if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): raise ValueError(f"Node shape of {shape[j]} does not correspond to connectivity matrix.") From f2be2d422f017e355172ed10d75784612683f048 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:17:35 -0600 Subject: [PATCH 067/155] test_network(): Add test_build_cm_implicit_tpm() --- test/test_network.py | 80 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/test/test_network.py b/test/test_network.py index c23cf06db..1bb0e41d3 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -61,21 +61,21 @@ def test_potential_purviews(s): def test_node_labels(standard): labels = ("A", "B", "C") - network = Network(standard.tpm.tpm, node_labels=labels) + network = Network(standard.tpm, node_labels=labels) assert network.node_labels.labels == labels labels = ("A", "B") # Too few labels with pytest.raises(ValueError): - Network(standard.tpm.tpm, node_labels=labels) + Network(standard.tpm, node_labels=labels) # Auto-generated labels - network = Network(standard.tpm.tpm, node_labels=None) + 
network = Network(standard.tpm, node_labels=None) assert network.node_labels.labels == ("n0", "n1", "n2") def test_num_states(standard): assert standard.num_states == 8 - + def test_repr(standard): print(repr(standard)) @@ -91,3 +91,75 @@ def test_len(standard): def test_size(standard): assert standard.size == 3 + + +def test_build_cm_implicit_tpm(): + # no CM + tpm = [ + np.array([ + [ + [ + [0., 1.], + [1., 0.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ] + ] + ]), + np.array([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [1., 0.] + ] + ], + [ + [ + [0., 1.], + [0., 1.] + ], + [ + [0., 1.], + [1., 0.] + ] + ] + ]), + np.array([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [0., 1.] + ] + ] + ]) + ] + cm = np.array([ + [1, 1, 0], + [0, 1, 1], + [1, 1, 1] + ]) + network = Network(tpm) + assert((network.cm == cm).all()) + # correct CM + network = Network(tpm, cm) + assert((network.cm == cm).all()) + # incorrect CM + cm = np.array([ + [1, 0, 0], + [1, 1, 0], + [1, 1, 1] + ]) + with pytest.raises(ValueError): + network = Network(tpm, cm) From ad75e770cc1fe0196adfadb323867cc66d48f94a Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:28:00 -0600 Subject: [PATCH 068/155] test_network: Add test_build_cm_explicit_tpm() --- test/test_network.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_network.py b/test/test_network.py index 1bb0e41d3..c3eb291cf 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -91,6 +91,35 @@ def test_len(standard): def test_size(standard): assert standard.size == 3 + + +def test_build_cm_explicit_tpm(): + # no CM + tpm = np.array([ + [0, 0, 0], + [0, 0, 1], + [1, 0, 1], + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + [1, 1, 1], + [1, 1, 0] + ]) + cm = np.array([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1] + ]) + network = Network(tpm) + assert((network.cm == cm).all()) + # provided CM + cm = np.array([ + [0, 1, 1], + [1, 1, 0], + [1, 1, 1] + ]) + network = Network(tpm, 
cm) + assert((network.cm == cm).all()) def test_build_cm_implicit_tpm(): From bac25b7e10649b35bc92d30938d232a649215e27 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:28:22 -0600 Subject: [PATCH 069/155] DictCache: allow for reconstruction with __repr__ --- pyphi/cache.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pyphi/cache.py b/pyphi/cache.py index ed7481efd..94c88848a 100644 --- a/pyphi/cache.py +++ b/pyphi/cache.py @@ -170,10 +170,10 @@ class DictCache: Intended to be used as an object-level cache of method results. """ - def __init__(self): - self.cache = {} - self.hits = 0 - self.misses = 0 + def __init__(self, cache=None, hits=0, misses=0): + self.cache = dict() if cache is None else cache + self.hits = hits + self.misses = misses def clear(self): self.cache = {} @@ -215,6 +215,11 @@ def key(self, *args, _prefix=None, **kwargs): if kwargs: raise NotImplementedError("kwarg cache keys not implemented") return (_prefix,) + tuple(args) + + def __repr__(self): + return "{}(cache={}, hits={}, misses={})".format( + type(self).__name__, self.cache, self.hits, self.misses + ) def redis_init(db): From 7c8904945d1b291d07a3465760e9f6760e37f2ab Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:39:56 -0600 Subject: [PATCH 070/155] test_build_cm_explicit_tpm(): Use np.ones() --- test/test_network.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/test/test_network.py b/test/test_network.py index c3eb291cf..0ab934534 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -105,11 +105,7 @@ def test_build_cm_explicit_tpm(): [1, 1, 1], [1, 1, 0] ]) - cm = np.array([ - [1, 1, 1], - [1, 1, 1], - [1, 1, 1] - ]) + cm = np.ones((3, 3), dtype=int) network = Network(tpm) assert((network.cm == cm).all()) # provided CM From 902c90caabeba733c8b9094ab8783bdf752efc86 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 14:46:07 -0600 Subject: [PATCH 071/155] Don't 
assume network state_space dict is ordered. --- pyphi/network.py | 2 +- pyphi/node.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 1e93d26c6..3094a5d17 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -103,7 +103,7 @@ def __init__( self._state_space, _ = build_state_space( self._node_labels, - network_tpm_shape, + network_tpm_shape[:-1], state_space ) diff --git a/pyphi/node.py b/pyphi/node.py index cf14e293e..fbfb28bf1 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -322,7 +322,7 @@ def node( new_network_state_space, _ = build_state_space( node_labels, tpm.shape[:-1], - network_state_space.values(), + [network_state_space[dim] for dim in dimensions[:-1]], singleton_state_space = (SINGLETON_COORDINATE,), ) From cbe92559776ae1e1a120f7b6b00284e3967608a2 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:47:19 -0600 Subject: [PATCH 072/155] test_build_cm(): Combine into one test --- test/test_network.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/test/test_network.py b/test/test_network.py index 0ab934534..6a5630967 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -93,8 +93,8 @@ def test_size(standard): assert standard.size == 3 -def test_build_cm_explicit_tpm(): - # no CM +def test_build_cm(): + # ExplicitTPM, no CM tpm = np.array([ [0, 0, 0], [0, 0, 1], @@ -108,7 +108,7 @@ def test_build_cm_explicit_tpm(): cm = np.ones((3, 3), dtype=int) network = Network(tpm) assert((network.cm == cm).all()) - # provided CM + # ExplicitTPM, provided CM cm = np.array([ [0, 1, 1], [1, 1, 0], @@ -116,10 +116,7 @@ def test_build_cm_explicit_tpm(): ]) network = Network(tpm, cm) assert((network.cm == cm).all()) - - -def test_build_cm_implicit_tpm(): - # no CM + # ImplicitTPM, no CM tpm = [ np.array([ [ @@ -177,10 +174,10 @@ def test_build_cm_implicit_tpm(): ]) network = Network(tpm) assert((network.cm == cm).all()) - # correct CM + # 
ImplicitTPM, correct CM network = Network(tpm, cm) assert((network.cm == cm).all()) - # incorrect CM + # ImplicitTPM, incorrect CM cm = np.array([ [1, 0, 0], [1, 1, 0], From f959a5d0bf9fe7619d4888c0e3185b17ee794ba9 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 14:57:19 -0600 Subject: [PATCH 073/155] Network.__repr__(): Only use dict for state_space --- pyphi/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/network.py b/pyphi/network.py index 3094a5d17..fbf2e80a4 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -270,7 +270,7 @@ def __repr__(self): # TODO implement a cleaner repr, similar to analyses objects, # distinctions, etc. return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={}\n)".format( - self.tpm, self.cm, self.node_labels, self.state_space + self.tpm, self.cm, self.node_labels, self.state_space._dict ) def __eq__(self, other): From 434cd2903d8326f940ebb872f3805db3a8de4672 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 1 Mar 2023 16:10:23 -0600 Subject: [PATCH 074/155] Set the state of the `subsystem` nodes. 
--- pyphi/subsystem.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 92b4a8f66..01b44eb84 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -140,6 +140,9 @@ def __init__( if i in self.node_indices ) + for node, node_state in zip(self.nodes, self.state): + node.state = node_state + # validate.subsystem(self) @property From aa8e67fcf1d9c8c37df4fba33c80327d65eed0d6 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 1 Mar 2023 16:47:18 -0600 Subject: [PATCH 075/155] test_network: Add test_init_with_explicit_tpm() --- test/test_network.py | 91 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/test/test_network.py b/test/test_network.py index 6a5630967..39263d1e7 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -3,10 +3,12 @@ # test_network.py import numpy as np +import xarray as xr import pytest from pyphi import Direction, config, exceptions from pyphi.network import Network +from pyphi.tpm import ExplicitTPM, ImplicitTPM @pytest.fixture() @@ -93,6 +95,95 @@ def test_size(standard): assert standard.size == 3 +def test_network_init_with_explicit_tpm(): + tpm = ExplicitTPM([ + [0, 0, 0], + [0, 0, 1], + [1, 0, 1], + [1, 0, 0], + [1, 1, 0], + [1, 1, 1], + [1, 1, 1], + [1, 1, 0] + ], validate=True) + + network = Network(tpm) + + assert type(network.tpm) == ImplicitTPM + + expected_nodes = ( + xr.DataArray([ + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [0., 1.], + [0., 1.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [0., 1.], + [0., 1.] + ] + ] + ]), + xr.DataArray([ + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [1., 0.], + [0., 1.] + ] + ], + [ + [ + [1., 0.], + [0., 1.] + ], + [ + [1., 0.], + [0., 1.] + ] + ] + ]), + xr.DataArray([ + [ + [ + [1., 0.], + [1., 0.] + ], + [ + [0., 1.], + [0., 1.] + ] + ], + [ + [ + [0., 1.], + [0., 1.] + ], + [ + [1., 0.], + [1., 0.] 
+ ] + ] + ]) + ) + + for i, node in enumerate(network.tpm.nodes): + assert (node.dataarray.values == expected_nodes[i].values).all() + + def test_build_cm(): # ExplicitTPM, no CM tpm = np.array([ From 8a1c3fd3194f0aa1f64277ee5102d5cc13a09a0f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 2 Mar 2023 13:41:07 -0600 Subject: [PATCH 076/155] Change semantics of Node.tpm for direct access to ExplicitTPM methods. --- pyphi/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/node.py b/pyphi/node.py index fbfb28bf1..96e08c8ed 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -62,7 +62,7 @@ def __init__(self, dataarray: xr.DataArray): self._outputs = dataarray.attrs["outputs"] self._dataarray = dataarray - self._tpm = self._dataarray + self._tpm = self._dataarray.data self.state_space = dataarray.attrs["state_space"] From 50faed6f95cf09afb070ab2d1fa3b9b2ed944aea Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 2 Mar 2023 15:28:59 -0600 Subject: [PATCH 077/155] Avoid reuse of subsystem nodes after cut --- pyphi/network.py | 4 ++-- pyphi/subsystem.py | 22 ++++++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index fbf2e80a4..874c74190 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -12,7 +12,7 @@ from . 
import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels -from .node import generate_nodes, node +from .node import generate_nodes, node as Node from .tpm import ExplicitTPM, ImplicitTPM from .state_space import build_state_space @@ -109,7 +109,7 @@ def __init__( self._tpm = ImplicitTPM( tuple( - node( + Node( node_tpm, self._cm, self._state_space, diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 01b44eb84..31589fb70 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -28,6 +28,7 @@ RepertoireIrreducibilityAnalysis, _null_ria, ) +from .node import node as Node from .models.mechanism import StateSpecification from .network import irreducible_purviews from .partition import mip_partitions @@ -135,10 +136,23 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - self.nodes = tuple( - node for i, node in enumerate(self.tpm.nodes) - if i in self.node_indices - ) + if cut: + self.nodes = tuple( + Node( + node.tpm, + self.cm, + self.network.state_space, + i, + node_labels=self.node_labels + ).pyphi + for i, node in enumerate(self.tpm.nodes) + if i in self.node_indices + ) + else: + self.nodes = tuple( + node for i, node in enumerate(self.tpm.nodes) + if i in self.node_indices + ) for node, node_state in zip(self.nodes, self.state): node.state = node_state From 9fd3b11a65dedb5e488fd7b1bf6a340c293c202f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 2 Mar 2023 17:41:09 -0600 Subject: [PATCH 078/155] Guarantee marginalization of non_inputs when generating `Node`s. --- pyphi/node.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 96e08c8ed..69c079146 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -312,6 +312,15 @@ def node( """ # Generate DataArray structure for this node # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # Get indices of the inputs and outputs. 
+ inputs = frozenset(get_inputs_from_cm(index, cm)) + outputs = frozenset(get_outputs_from_cm(index, cm)) + + # Marginalize out non-input nodes. + non_inputs = set(tpm.tpm_indices()) - inputs + tpm = tpm.marginalize_out(non_inputs) + # Dimensions are the names of this node's parents (whose state this node's # TPM can be conditioned on), plus the last dimension with the probability # for each possible state of this node in the next timestep. @@ -319,10 +328,11 @@ def node( # Compute the relevant state labels (coordinates in xarray terminology) from # the perspective of this node and its direct inputs. + node_states = [network_state_space[dim] for dim in dimensions[:-1]] new_network_state_space, _ = build_state_space( node_labels, tpm.shape[:-1], - [network_state_space[dim] for dim in dimensions[:-1]], + node_states, singleton_state_space = (SINGLETON_COORDINATE,), ) @@ -330,10 +340,6 @@ def node( coordinates = {**new_network_state_space, dimensions[-1]: node_state_space} - # Get indices of the inputs and outputs. - inputs = frozenset(get_inputs_from_cm(index, cm)) - outputs = frozenset(get_outputs_from_cm(index, cm)) - return xr.DataArray( name = node_labels[index], data = tpm, @@ -383,8 +389,6 @@ def generate_nodes( nodes = [] for index, state in zip(indices, node_state): - # Generate the node's TPM. - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # We begin by getting the part of the subsystem's TPM that gives just # the state of this node. This part is still indexed by network state, # but its last dimension will be gone, since now there's just a single @@ -392,13 +396,6 @@ def generate_nodes( # the network nodes. tpm_on = tpm[..., index] - # Marginalize out non-input nodes. 
- - # TODO use names rather than indices - inputs = frozenset(get_inputs_from_cm(index, cm)) - non_inputs = set(tpm.tpm_indices()) - inputs - tpm_on = tpm_on.marginalize_out(non_inputs).tpm - # Get the TPM that gives the probability of the node being off, rather # than on. tpm_off = 1 - tpm_on @@ -408,9 +405,8 @@ def generate_nodes( # indexed by the node's state at t+1. This representation makes it easy # to condition on the node state. node_tpm = ExplicitTPM( - np.stack([tpm_off, tpm_on], axis=-1) + np.stack([np.asarray(tpm_off), np.asarray(tpm_on)], axis=-1) ) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nodes.append( node( From 3011435240d2827e3d1d37c7e0cd2d9f7384eb70 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 2 Mar 2023 18:10:28 -0600 Subject: [PATCH 079/155] Inherit `TPM.tpm_indices()` and make it agnostic to state cardinality --- pyphi/subsystem.py | 1 + pyphi/tpm.py | 11 ++--------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 31589fb70..26603ae9c 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -148,6 +148,7 @@ def __init__( for i, node in enumerate(self.tpm.nodes) if i in self.node_indices ) + # TODO(tpm): Does memory optimization justify maintaining the `else`? 
else: self.nodes = tuple( node for i, node in enumerate(self.tpm.nodes) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 68741aa86..7f2d3a146 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -65,7 +65,8 @@ def infer_cm(self): raise NotImplementedError def tpm_indices(self): - raise NotImplementedError + """Return the indices of nodes in the TPM.""" + return tuple(np.where(np.array(self.shape[:-1]) != 1)[0]) def print(self): raise NotImplementedError @@ -578,11 +579,6 @@ def infer_cm(self): cm[a][b] = self.infer_edge(a, b, all_contexts) return cm - def tpm_indices(self): - """Return the indices of nodes in the TPM.""" - # TODO This currently assumes binary elements (2) - return tuple(np.where(np.array(self.shape[:-1]) == 2)[0]) - def print(self): tpm = convert.to_multidimensional(self._tpm) for state in all_states(tpm.shape[-1]): @@ -767,9 +763,6 @@ def infer_edge(self, a, b, contexts): def infer_cm(self): raise NotImplementedError - def tpm_indices(self): - raise NotImplementedError - def print(self): raise NotImplementedError From dfdde3ab2bb294b43a504eecae43f6252e11c93c Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 3 Mar 2023 18:01:17 -0600 Subject: [PATCH 080/155] Rehabilitate validate.network() --- pyphi/network.py | 18 +++------------- pyphi/tpm.py | 53 ++++++++++++++++++++++++++++++++--------------- pyphi/validate.py | 2 +- 3 files changed, 40 insertions(+), 33 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 874c74190..5d8f4edfe 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -139,7 +139,7 @@ def __init__( self.purview_cache = purview_cache or cache.PurviewCache() - # validate.network(self) + validate.network(self) @property def tpm(self): @@ -185,22 +185,10 @@ def _build_cm(self, cm, tpm, shapes=None): utils.np_immutable(cm) return (cm, utils.np_hash(cm)) + # Explicit TPM with connectivity matrix: return. + # ImplicitTPM with connectivity matrix: return (validate later). 
cm = np.array(cm) utils.np_immutable(cm) - - # Explicit TPM with connectivity matrix: return. - if shapes is None: - return (cm, utils.np_hash(cm)) - - # ImplicitTPM with connectivity matrix: validate against node TPM shapes. - for i, shape in enumerate(shapes): - for j, val in enumerate(cm[..., i]): - if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError( - "Node TPM {} of shape {} does not match the connectivity " - " matrix.".format(i, shape) - ) - return (cm, utils.np_hash(cm)) @property diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 7f2d3a146..1c544a552 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -25,13 +25,7 @@ class TPM: _ERROR_MSG_PROBABILITY_SUM = "Invalid TPM: probabilities must sum to 1." - def validate(self, check_independence=True): - raise NotImplementedError - - def _validate_probabilities(self): - raise NotImplementedError - - def _validate_shape(self, check_independence=True): + def validate(self, cm, check_independence=True): raise NotImplementedError def to_multidimensional_state_by_node(self): @@ -341,7 +335,7 @@ def tpm(self): """np.ndarray: The underlying `tpm` object.""" return self._tpm - def validate(self, check_independence=True): + def validate(self, cm=None, check_independence=True): """Validate this TPM.""" return self._validate_probabilities() and self._validate_shape( check_independence @@ -644,7 +638,7 @@ def nodes(self): @property def ndim(self): """int: The number of dimensions of the TPM.""" - return len(self) + 1 + return len(self.shape) @property def shape(self): @@ -678,9 +672,9 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: return shape_from_inputs + (number_of_nodes,) - def validate(self, check_independence=True): + def validate(self, cm=None, check_independence=True): """Validate this TPM.""" - return self._validate_probabilities() + return self._validate_probabilities() and self._validate_shape(cm) def _validate_probabilities(self): """Check that the 
probabilities in a TPM are valid.""" @@ -689,21 +683,46 @@ def _validate_probabilities(self): # Validate that probabilities sum to 1. if any( - (node_tpm.data.sum(axis=-1) != 1.0).any() - for node_tpm in self._nodes + (node.tpm.sum(axis=-1) != 1.0).any() + for node in self._nodes ): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) # Leverage method in ExplicitTPM to distribute validation of # TPM image within [0, 1]. if all( - node_tpm.data._validate_probabilities() - for node_tpm in self._nodes + node.tpm._validate_probabilities() + for node in self._nodes ): return True - def _validate_shape(self, check_independence=True): - raise NotImplementedError + def _validate_shape(self, cm): + """Validate this TPM's shape. + + The shapes of the individual node TPMs in multidimensional form are + validated against the connectivity matrix specification. Additionally, + the inferred shape of the implicit network TPM must be in + multidimensional state-by-node form, nonbinary and heterogeneous units + supported. + """ + # Validate individual node TPM shapes. + shapes = shapes = [node.tpm.shape for node in self.nodes] + + for i, shape in enumerate(shapes): + for j, val in enumerate(cm[..., i]): + if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): + raise ValueError( + "Node TPM {} of shape {} does not match the connectivity " + " matrix.".format(i, shape) + ) + + # Validate whole network's shape. + N = len(self.nodes) + if N + 1 != self.ndim: + raise ValueError( + "Invalid TPM shape: {} nodes were provided, but their shapes" + "suggest a {}-node network.".format(N, self.ndim - 1) + ) def to_multidimensional_state_by_node(self): raise NotImplementedError diff --git a/pyphi/validate.py b/pyphi/validate.py index 14bbe5d70..4e7cec168 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -62,7 +62,7 @@ def network(n): Checks the TPM and connectivity matrix. 
""" - n.tpm.validate() + n.tpm.validate(cm=n.cm) connectivity_matrix(n.cm) if n.cm.shape[0] != n.size: raise ValueError( From 37da9b7fbf08efb77cb0fed917bbfe9c35d9a759 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 7 Mar 2023 14:46:22 -0600 Subject: [PATCH 081/155] `test_network.py`: add fixture for making implicit TPMs --- test/test_network.py | 54 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 49 insertions(+), 5 deletions(-) diff --git a/test/test_network.py b/test/test_network.py index 39263d1e7..3baf9ab05 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -2,6 +2,8 @@ # -*- coding: utf-8 -*- # test_network.py +import random + import numpy as np import xarray as xr import pytest @@ -18,6 +20,48 @@ def network(): return Network(tpm) +@pytest.fixture() +def implicit_tpm(size, degree, node_states, seed=1337, deterministic_units=False): + rng = random.Random(seed) + + def random_deterministic_repertoire(): + repertoire = rng.sample([1] + (node_states - 1) * [0], node_states) + return repertoire + + def random_repertoire(deterministic_units): + if deterministic_units: + return random_deterministic_repertoire() + + repertoire = np.array([rng.uniform(0, 1) for s in range(node_states)]) + # Normalize using L1 (probabilities accross node_states must sum to 1) + repertoire = repertoire / repertoire.sum() + + return ( + repertoire if repertoire.sum() == 1.0 + else random_deterministic_repertoire() + ) + + tpm = [] + + for node_index in range(size): + # Generate |node_states| pseudo-probabilities for each combination of + # parent states at t - 1. + node_tpm = [ + random_repertoire(deterministic_units) + for j in range(node_states ** degree) + ] + # Select |degree| nodes at random as parents to this node, then reshape + # node TPM to multidimensional form. 
+ node_shape = np.ones(size, dtype=int) + parents = rng.sample(range(size), degree) + node_shape[parents] = node_states + node_tpm = np.array(node_tpm).reshape(tuple(node_shape) + (node_states,)) + + tpm.append(node_tpm) + + return tpm + + def test_network_init_validation(network): with pytest.raises(ValueError): # Totally wrong shape @@ -77,7 +121,7 @@ def test_node_labels(standard): def test_num_states(standard): assert standard.num_states == 8 - + def test_repr(standard): print(repr(standard)) @@ -93,8 +137,8 @@ def test_len(standard): def test_size(standard): assert standard.size == 3 - - + + def test_network_init_with_explicit_tpm(): tpm = ExplicitTPM([ [0, 0, 0], @@ -182,8 +226,8 @@ def test_network_init_with_explicit_tpm(): for i, node in enumerate(network.tpm.nodes): assert (node.dataarray.values == expected_nodes[i].values).all() - - + + def test_build_cm(): # ExplicitTPM, no CM tpm = np.array([ From 1671e08825a122194c04e5b66cfd306b6948d807 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 7 Mar 2023 17:44:32 -0600 Subject: [PATCH 082/155] Infer network shape from nodes' last dimension. As opposed to taking the max at each position in the shape tuples, which not only is prone to disagreements between nodes, but also infers the wrong state cardinality for nodes without parents (the overt background conditions). 
--- pyphi/tpm.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 1c544a552..4121af991 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -666,11 +666,22 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: ) number_of_nodes = len(shapes) - shape_from_inputs = tuple( - max(shape[i] for shape in shapes) for i in range(number_of_nodes) + states_per_node = tuple(shape[-1] for shape in shapes) + + dimensions_from_shapes = tuple( + set(shape[node_index] for shape in shapes) + for node_index in range(number_of_nodes) ) - return shape_from_inputs + (number_of_nodes,) + for node_index in range(number_of_nodes): + valid_cardinalities = {1, max(dimensions_from_shapes[node_index])} + if dimensions_from_shapes[node_index] != valid_cardinalities: + raise ValueError( + "The provided shapes disagree on the number of states of " + "node {}.".format(node_index) + ) + + return states_per_node + (number_of_nodes,) def validate(self, cm=None, check_independence=True): """Validate this TPM.""" From 68e4808a723c1f08322dc818c97881107d8edfa3 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 8 Mar 2023 13:36:41 -0600 Subject: [PATCH 083/155] Deduplicate code and turn into `ImplicitTPM.shapes` attribute --- pyphi/tpm.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 4121af991..7fd6761aa 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -643,9 +643,14 @@ def ndim(self): @property def shape(self): """Tuple[int]: The size or number of coordinates in each dimension.""" - shapes = [node.tpm.shape for node in self._nodes] + shapes = self.shapes return self._node_shapes_to_shape(shapes) + @property + def shapes(self): + """Tuple[Tuple[int]]: The shapes of each node TPM in this TPM.""" + return [node.tpm.shape for node in self._nodes] + @staticmethod def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: """Infer the shape of 
the equivalent multidimensional |ExplicitTPM|. @@ -717,7 +722,7 @@ def _validate_shape(self, cm): supported. """ # Validate individual node TPM shapes. - shapes = shapes = [node.tpm.shape for node in self.nodes] + shapes = self.shapes for i, shape in enumerate(shapes): for j, val in enumerate(cm[..., i]): From 38740ae85557aec7d785bf434ca86bb3afd6d149 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 8 Mar 2023 14:22:46 -0600 Subject: [PATCH 084/155] Add equals method to ImplicitTPM --- pyphi/tpm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 7fd6761aa..87d99a8a1 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -804,6 +804,9 @@ def print(self): def permute_nodes(self, permutation): raise NotImplementedError + def equals(self, o: object): + return isinstance(o, type(self)) and self.nodes == o.nodes + def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): return ImplicitTPM( From 1c79fd07c611185cc540fe5d4074e938ba2363f3 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 8 Mar 2023 14:44:58 -0600 Subject: [PATCH 085/155] Fix bug in ImplicitTPM._node_shapes_to_shape() --- pyphi/tpm.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 7fd6761aa..043bac6ca 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -673,14 +673,22 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: number_of_nodes = len(shapes) states_per_node = tuple(shape[-1] for shape in shapes) + # Check consistency of shapes across nodes. 
+ dimensions_from_shapes = tuple( - set(shape[node_index] for shape in shapes) + tuple(set(shape[node_index] for shape in shapes)) for node_index in range(number_of_nodes) ) for node_index in range(number_of_nodes): - valid_cardinalities = {1, max(dimensions_from_shapes[node_index])} - if dimensions_from_shapes[node_index] != valid_cardinalities: + # Valid cardinalities for a dimension can be either {1, s_i != 1} + # when a node provides input to some nodes but not others, or + # {s_i != 1} if it provides input to all other nodes. + valid_cardinalities = { + (max(dimensions_from_shapes[node_index]), 1), + (max(dimensions_from_shapes[node_index]),) + } + if dimensions_from_shapes[node_index] not in valid_cardinalities: raise ValueError( "The provided shapes disagree on the number of states of " "node {}.".format(node_index) From 9c773e54ba4b3070920271e35045a3170134f643 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 8 Mar 2023 15:13:00 -0600 Subject: [PATCH 086/155] Ibidem --- pyphi/tpm.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 154fe8332..2dd1c82ab 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -676,7 +676,7 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: # Check consistency of shapes across nodes. dimensions_from_shapes = tuple( - tuple(set(shape[node_index] for shape in shapes)) + set(shape[node_index] for shape in shapes) for node_index in range(number_of_nodes) ) @@ -684,11 +684,14 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: # Valid cardinalities for a dimension can be either {1, s_i != 1} # when a node provides input to some nodes but not others, or # {s_i != 1} if it provides input to all other nodes. 
- valid_cardinalities = { - (max(dimensions_from_shapes[node_index]), 1), - (max(dimensions_from_shapes[node_index]),) - } - if dimensions_from_shapes[node_index] not in valid_cardinalities: + valid_cardinalities = ( + {max(dimensions_from_shapes[node_index]), 1}, + {max(dimensions_from_shapes[node_index])} + ) + if not any( + dimensions_from_shapes[node_index] == cardinality + for cardinality in valid_cardinalities + ): raise ValueError( "The provided shapes disagree on the number of states of " "node {}.".format(node_index) From faa98cfaab4967b3df68fbaa637a89d606a2db0d Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 8 Mar 2023 15:18:16 -0600 Subject: [PATCH 087/155] Add state space reroute --- pyphi/subsystem.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 26603ae9c..059696ef1 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -227,6 +227,10 @@ def cut_node_labels(self): def tpm_size(self): """int: The number of nodes in the TPM.""" return self.tpm.shape[-1] + + @property + def state_space(self): + return self.network.state_space def cache_info(self): """Report repertoire cache statistics.""" From 569535803ada39ad2de2fa74bd4717d7e7cb2b85 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 8 Mar 2023 15:21:26 -0600 Subject: [PATCH 088/155] Include state_space in node declarations --- test/test_node.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_node.py b/test/test_node.py index fbcf0df34..36df97f15 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -42,16 +42,16 @@ def test_node_init_inputs(s): def test_node_eq(s): - assert s.nodes[1] == node(s.tpm, s.cm, 1, 0, "B") + assert s.nodes[1] == node(s.tpm, s.cm, s.state_space, 1, 0, "B") def test_node_neq_by_index(s): - assert s.nodes[0] != node(s.tpm, s.cm, 1, 0, "B") + assert s.nodes[0] != node(s.tpm, s.cm, s.state_space, 1, 0, "B") def test_node_neq_by_state(s): other_s = Subsystem(s.network, (1, 
1, 1), s.node_indices) - assert other_s.nodes[1] != node(s.tpm, s.cm, 1, 0, "B") + assert other_s.nodes[1] != node(s.tpm, s.cm, s.state_space, 1, 0, "B") def test_repr(s): From 59f82f0043c31dd710e5c8c63b6b52e4e5db627c Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 8 Mar 2023 15:59:08 -0600 Subject: [PATCH 089/155] Use `hasattr` instead of try-catch --- pyphi/network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 5d8f4edfe..64d6c66fa 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -163,9 +163,9 @@ def _build_cm(self, cm, tpm, shapes=None): unitary CM if none was provided (explicit TPM), or infer from node TPMs. """ if cm is None: - try: + if hasattr(tpm, "shape"): network_size = tpm.shape[-1] - except AttributeError: + else: network_size = len(tpm) # Explicit TPM without connectivity matrix: assume all are connected. From 81b1e181cd664a3706e36b536c83a3e686b107db Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 8 Mar 2023 16:52:11 -0600 Subject: [PATCH 090/155] Do not use `tpm.tpm` --- pyphi/validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 4e7cec168..1082909f0 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -106,7 +106,7 @@ def state_reachable(subsystem): # reached from some state. # First we take the submatrix of the conditioned TPM that corresponds to # the nodes that are actually in the subsystem... - tpm = subsystem.tpm.tpm[..., subsystem.node_indices] + tpm = subsystem.tpm[..., subsystem.node_indices] # Then we do the subtraction and test. 
test = tpm - np.array(subsystem.proper_state) if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): From 4cc590963ce1aefa96803327eb7ad1fc38afe512 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 10 Mar 2023 16:37:15 -0600 Subject: [PATCH 091/155] `tpm.py`: extend subtpm() and reconstitute_tpm() --- pyphi/tpm.py | 40 ++++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 2dd1c82ab..c0d3de59f 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -46,8 +46,12 @@ def is_deterministic(self): def is_state_by_state(self): raise NotImplementedError - def subtpm(self, fixed_nodes, state): - raise NotImplementedError + def _subtpm(self, fixed_nodes, state): + N = self.shape[-1] + free_nodes = sorted(set(range(N)) - set(fixed_nodes)) + condition = FrozenMap(zip(fixed_nodes, state)) + conditioned = self.condition_tpm(condition) + return conditioned, free_nodes def expand_tpm(self): raise NotImplementedError @@ -508,18 +512,16 @@ def subtpm(self, fixed_nodes, state): Examples: >>> from pyphi import examples >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF - >>> examples.grid3_network().tpm.subtpm((0,), (0,)) - ExplicitTPM([[[[0.02931223 0.04742587] + >>> reconstitute_tpm(examples.grid3_network().tpm).subtpm((0,), (0,)) + ExplicitTPM( + [[[[0.02931223 0.04742587] [0.07585818 0.88079708]] [[0.81757448 0.11920292] - [0.92414182 0.95257413]]]]) + [0.92414182 0.95257413]]]] + ) """ - N = self._tpm.shape[-1] - free_nodes = sorted(set(range(N)) - set(fixed_nodes)) - condition = FrozenMap(zip(fixed_nodes, state)) - conditioned = self.condition_tpm(condition) - # TODO test indicing behavior on xr.DataArray + conditioned, free_nodes = self._subtpm(fixed_nodes, state) return conditioned[..., free_nodes] def expand_tpm(self): @@ -798,7 +800,10 @@ def is_state_by_state(self): return False def subtpm(self, fixed_nodes, state): - raise NotImplementedError + conditioned, free_nodes = 
self._subtpm(fixed_nodes, state) + return type(self)( + tuple(node for node in conditioned.nodes if node.index in free_nodes) + ) def expand_tpm(self): raise NotImplementedError @@ -854,9 +859,16 @@ def reconstitute_tpm(subsystem): # The last axis of the node TPMs correponds to ON or OFF probabilities # (used in the conditioning step when calculating the repertoires); we want # ON probabilities. - node_tpms = [node.tpm[..., 1] for node in subsystem.nodes] + + # TODO nonbinary nodes + node_tpms = [np.asarray(node.tpm)[..., 1] for node in subsystem.nodes] + + external_indices = () + if hasattr(subsystem, "external_indices"): + external_indices = subsystem.external_indices + # Remove the singleton dimensions corresponding to external nodes - node_tpms = [tpm.squeeze(axis=subsystem.external_indices) for tpm in node_tpms] + node_tpms = [tpm.squeeze(axis=external_indices) for tpm in node_tpms] # We add a new singleton axis at the end so that we can use # pyphi.tpm.expand_tpm, which expects a state-by-node TPM (where the last # axis corresponds to nodes.) @@ -868,7 +880,7 @@ def reconstitute_tpm(subsystem): ] # We concatenate the node TPMs along a new axis to get a multidimensional # state-by-node TPM (where the last axis corresponds to nodes). - return np.concatenate(node_tpms, axis=-1) + return ExplicitTPM(np.concatenate(node_tpms, axis=-1)) # TODO(tpm) remove pending ArrayLike refactor From 432a0a7fb2c645f8baf2f363187917f8e6f15263 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 14 Mar 2023 13:38:41 -0500 Subject: [PATCH 092/155] test_tpm.py: Convert to np.ndarray when comparing matrices. 
--- test/test_tpm.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_tpm.py b/test/test_tpm.py index 396a40d9c..8a8e3cf76 100644 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -81,6 +81,7 @@ def test_getattr(): assert expected.array_equal(expected) + def test_is_state_by_state(): # State-by-state tpm = ExplicitTPM(np.ones((8, 8))) @@ -146,7 +147,10 @@ def test_infer_cm(rule152): def test_reconstitute_tpm(standard, s_complete, rule152, noised): # Check subsystem and network TPM are the same when the subsystem is the # whole network - assert np.array_equal(reconstitute_tpm(s_complete), standard.tpm.tpm) + assert np.array_equal( + np.asarray(reconstitute_tpm(s_complete)), + np.asarray(reconstitute_tpm(standard.tpm)) + ) # Regression tests # fmt: off @@ -162,7 +166,7 @@ def test_reconstitute_tpm(standard, s_complete, rule152, noised): ]) # fmt: on subsystem = Subsystem(rule152, (0,) * 5, (0, 1, 2)) - assert np.array_equal(answer, reconstitute_tpm(subsystem)) + assert np.array_equal(answer, np.asarray(reconstitute_tpm(subsystem))) subsystem = Subsystem(noised, (0, 0, 0), (0, 1)) # fmt: off @@ -173,4 +177,4 @@ def test_reconstitute_tpm(standard, s_complete, rule152, noised): [1. , 0. 
]], ]) # fmt: on - assert np.array_equal(answer, reconstitute_tpm(subsystem)) + assert np.array_equal(answer, np.asarray(reconstitute_tpm(subsystem))) From 552c4a6eb050af7fe9d9920954ab8753a7fa4888 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 14 Mar 2023 14:16:59 -0500 Subject: [PATCH 093/155] tpm: inherit `infer_cm()` and `infer_edge` from parent class --- pyphi/tpm.py | 103 +++++++++++++++++++++++++-------------------------- 1 file changed, 50 insertions(+), 53 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index c0d3de59f..8b1b0b42b 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -57,10 +57,48 @@ def expand_tpm(self): raise NotImplementedError def infer_edge(self, a, b, contexts): - raise NotImplementedError + """Infer the presence or absence of an edge from node A to node B. + + Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We + call the state of |A'| the context |C| of |A|. There is an edge from |A| + to |B| if there exists any context |C(A)| such that + |Pr(B | C(A), A=0) != Pr(B | C(A), A=1)|. + + Args: + a (int): The index of the putative source node. + b (int): The index of the putative sink node. + contexts (tuple[tuple[int]]): The tuple of states of ``a`` + Returns: + bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. + """ + + def a_in_context(context): + """Given a context C(A), return the states of the full system with A + OFF and ON, respectively. 
+ """ + a_off = context[:a] + OFF + context[a:] + a_on = context[:a] + ON + context[a:] + return (a_off, a_on) + + def a_affects_b_in_context(tpm, context): + """Return ``True`` if A has an effect on B, given a context.""" + a_off, a_on = a_in_context(context) + return tpm[a_off][b] != tpm[a_on][b] + + tpm = self.to_multidimensional_state_by_node() + return any(a_affects_b_in_context(tpm, context) for context in contexts) def infer_cm(self): - raise NotImplementedError + """Infer the connectivity matrix associated with a state-by-node TPM in + multidimensional form. + """ + tpm = self.to_multidimensional_state_by_node() + network_size = tpm.shape[-1] + all_contexts = tuple(all_states(network_size - 1)) + cm = np.empty((network_size, network_size), dtype=int) + for a, b in np.ndindex(cm.shape): + cm[a][b] = self.infer_edge(a, b, all_contexts) + return cm def tpm_indices(self): """Return the indices of nodes in the TPM.""" @@ -531,50 +569,6 @@ def expand_tpm(self): unconstrained = np.ones([2] * (self._tpm.ndim - 1) + [self._tpm.shape[-1]]) return type(self)(self._tpm * unconstrained) - def infer_edge(self, a, b, contexts): - """Infer the presence or absence of an edge from node A to node B. - - Let |S| be the set of all nodes in a network. Let |A' = S - {A}|. We - call the state of |A'| the context |C| of |A|. There is an edge from |A| - to |B| if there exists any context |C(A)| such that - |Pr(B | C(A), A=0) != Pr(B | C(A), A=1)|. - - Args: - a (int): The index of the putative source node. - b (int): The index of the putative sink node. - contexts (tuple[tuple[int]]): The tuple of states of ``a`` - Returns: - bool: ``True`` if the edge |A -> B| exists, ``False`` otherwise. - """ - - def a_in_context(context): - """Given a context C(A), return the states of the full system with A - OFF and ON, respectively. 
- """ - a_off = context[:a] + OFF + context[a:] - a_on = context[:a] + ON + context[a:] - return (a_off, a_on) - - def a_affects_b_in_context(tpm, context): - """Return ``True`` if A has an effect on B, given a context.""" - a_off, a_on = a_in_context(context) - return tpm[a_off][b] != tpm[a_on][b] - - tpm = self.to_multidimensional_state_by_node() - return any(a_affects_b_in_context(tpm, context) for context in contexts) - - def infer_cm(self): - """Infer the connectivity matrix associated with a state-by-node TPM in - multidimensional form. - """ - tpm = self.to_multidimensional_state_by_node() - network_size = tpm.shape[-1] - all_contexts = tuple(all_states(network_size - 1)) - cm = np.empty((network_size, network_size), dtype=int) - for a, b in np.ndindex(cm.shape): - cm[a][b] = self.infer_edge(a, b, all_contexts) - return cm - def print(self): tpm = convert.to_multidimensional(self._tpm) for state in all_states(tpm.shape[-1]): @@ -754,7 +748,16 @@ def _validate_shape(self, cm): ) def to_multidimensional_state_by_node(self): - raise NotImplementedError + """Return the current TPM re-represented in multidimensional + state-by-node form. + + See the PyPhi documentation on :ref:`tpm-conventions` for more + information. + + Returns: + np.ndarray: The TPM in multidimensional state-by-node format. 
+ """ + return reconstitute_tpm(self) def conditionally_independent(self): raise NotImplementedError @@ -808,12 +811,6 @@ def subtpm(self, fixed_nodes, state): def expand_tpm(self): raise NotImplementedError - def infer_edge(self, a, b, contexts): - raise NotImplementedError - - def infer_cm(self): - raise NotImplementedError - def print(self): raise NotImplementedError From 6f218901a4480f1d6507b33a2081d40ff3b116d8 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 14 Mar 2023 17:07:39 -0500 Subject: [PATCH 094/155] Implement `marginalize_out()` for `ImplicitTPM` --- pyphi/node.py | 46 +++++++++++++++++++++++++--------------------- pyphi/tpm.py | 34 ++++++++++++++++++++++++++++------ test/test_tpm.py | 16 ++++++++-------- 3 files changed, 61 insertions(+), 35 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 69c079146..f37bf413a 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -14,8 +14,8 @@ import numpy as np import xarray as xr +import pyphi.tpm from .connectivity import get_inputs_from_cm, get_outputs_from_cm -from .labels import NodeLabels from .state_space import ( dimension_labels, input_dimension_label, @@ -23,7 +23,6 @@ PROBABILITY_DIMENSION, SINGLETON_COORDINATE, ) -from .tpm import ExplicitTPM from .utils import state_of @xr.register_dataarray_accessor("pyphi") @@ -35,8 +34,10 @@ class Node: dataarray (xr.DataArray): Attributes: - index (int): - label (str): + index (int): The node's index in the network. + label (str): The textual label for this node. + node_labels (Tuple[str]): The textual labels for the nodes in the network. + dataarray (xr.DataArray): the xarray DataArray for this node. tpm (|ExplicitTPM|): The node TPM is an array with |n + 1| dimensions, where ``n`` is the size of the |Network|. The first ``n`` dimensions correspond to each node in the system. 
Dimensions corresponding to @@ -46,10 +47,11 @@ class Node: ``node.tpm[..., 0]`` gives probabilities that the node will be 'OFF' and ``node.tpm[..., 1]`` gives probabilities that the node will be 'ON'. - inputs (frozenset): - outputs (frozenset): - state_space (Tuple[Union[int, str]]): - state (Optional[Union[int, str]]): + inputs (frozenset): The set of nodes with connections to this node. + outputs (frozenset): The set of nodes this node has connections to. + state_space (Tuple[Union[int, str]]): The space of states this node can + inhabit. + state (Optional[Union[int, str]]): The current state of this node. """ def __init__(self, dataarray: xr.DataArray): @@ -73,7 +75,7 @@ def __init__(self, dataarray: xr.DataArray): self._hash = hash( ( self.index, - hash(ExplicitTPM(self.tpm)), + hash(pyphi.tpm.ExplicitTPM(self.tpm)), self._inputs, self._outputs, self.state_space, @@ -286,12 +288,12 @@ def to_json(self): def node( - tpm: ExplicitTPM, + tpm, cm: np.ndarray, network_state_space: Mapping[str, Tuple[Union[int, str]]], index: int, state: Optional[Union[int, str]] = None, - node_labels: Optional[NodeLabels] = None + node_labels: Optional[Tuple[str]] = None ) -> xr.DataArray: """ Instantiate a node TPM DataArray. @@ -305,7 +307,7 @@ def node( Keyword Args: state (Optional[Union[int, str]]): The state of this node. - node_labels (Iterable[str]): Textual labels for each node in the network. + node_labels (Tuple[str]): Textual labels for each node in the network. Returns: xr.DataArray: The node in question. @@ -329,7 +331,7 @@ def node( # Compute the relevant state labels (coordinates in xarray terminology) from # the perspective of this node and its direct inputs. 
node_states = [network_state_space[dim] for dim in dimensions[:-1]] - new_network_state_space, _ = build_state_space( + input_coordinates, _ = build_state_space( node_labels, tpm.shape[:-1], node_states, @@ -338,7 +340,7 @@ def node( node_state_space = network_state_space[dimensions[index]] - coordinates = {**new_network_state_space, dimensions[-1]: node_state_space} + coordinates = {**input_coordinates, dimensions[-1]: node_state_space} return xr.DataArray( name = node_labels[index], @@ -348,26 +350,28 @@ def node( attrs = { "index": index, "node_labels": node_labels, + "cm": cm, "inputs": inputs, "outputs": outputs, "state_space": tuple(node_state_space), "state": state, + "network_state_space": network_state_space } ) def generate_nodes( - tpm: ExplicitTPM, + network_tpm, cm: np.ndarray, state_space: Mapping[str, Tuple[Union[int, str]]], indices: Tuple[int], network_state: Optional[Tuple[Union[int, str]]] = None, - node_labels: Optional[NodeLabels] = None + node_labels: Optional[Tuple[str]] = None ) -> Tuple[xr.DataArray]: """Generate |Node| objects out of a binary network |ExplicitTPM|. Args: - tpm (|ExplicitTPM|): The system's TPM. + network_tpm (|ExplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. state_space (Mapping[str, Tuple[Union[int, str]]]): Labels for the state space of each node in the network. @@ -376,7 +380,7 @@ def generate_nodes( Keyword Args: network_state (Optional[Tuple[Union[int, str]]]): The state of the network. - node_labels (|NodeLabels|): Textual labels for each node. + node_labels (Optional[Tuple[str]]): Textual labels for each node. Returns: Tuple[xr.DataArray]: The nodes of the system. @@ -394,7 +398,7 @@ def generate_nodes( # but its last dimension will be gone, since now there's just a single # scalar value (this node's state) rather than a state-vector for all # the network nodes. 
- tpm_on = tpm[..., index] + tpm_on = network_tpm[..., index] # Get the TPM that gives the probability of the node being off, rather # than on. @@ -404,7 +408,7 @@ def generate_nodes( # the state of the node's inputs at t, and the last dimension is # indexed by the node's state at t+1. This representation makes it easy # to condition on the node state. - node_tpm = ExplicitTPM( + node_tpm = pyphi.tpm.ExplicitTPM( np.stack([np.asarray(tpm_off), np.asarray(tpm_on)], axis=-1) ) @@ -433,5 +437,5 @@ def expand_node_tpm(tpm): dimension (containing the state of the node) contains only the probability of *this* node being on, rather than the probabilities for each node. """ - uc = ExplicitTPM(np.ones([2 for node in tpm.shape])) + uc = pyphi.tpm.ExplicitTPM(np.ones([2 for node in tpm.shape])) return uc * tpm diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 8b1b0b42b..2ec74fb36 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -14,6 +14,7 @@ from . import config, convert, data_structures, exceptions from .constants import OFF, ON from .data_structures import FrozenMap +from .node import node as Node from .utils import all_states, np_hash, np_immutable class TPM: @@ -518,13 +519,12 @@ def marginalize_out(self, node_indices): ExplicitTPM: A TPM with the same number of dimensions, with the nodes marginalized out. """ - tpm = self._tpm.sum(tuple(node_indices), keepdims=True) / ( + tpm = self.sum(tuple(node_indices), keepdims=True) / ( np.array(self.shape)[list(node_indices)].prod() ) - # Return new TPM object of the same type as self. - # self._tpm has already been validated and converted to multidimensional - # state-by-node form. Further validation would be problematic for - # singleton dimensions. + # Return new TPM object of the same type as self. Assume self had + # already been validated and converted formatted. Further validation + # would be problematic for singleton dimensions. 
return type(self)(tpm) def is_deterministic(self): @@ -791,7 +791,29 @@ def condition_tpm(self, condition: Mapping[int, int]): return self.__getitem__(conditioning_indices, preserve_singletons=True) def marginalize_out(self, node_indices): - raise NotImplementedError + """Marginalize out nodes from this TPM. + + Args: + node_indices (list[int]): The indices of nodes to be marginalized out. + + Returns: + ImplicitTPM: A TPM with the same number of dimensions, with the nodes + marginalized out. + """ + # Leverage ExplicitTPM.marginalize_out() to distribute operation to + # individual nodes, then assemble into a new ImplicitTPM. + return type(self)( + tuple( + Node( + node.tpm.marginalize_out(node_indices), + node.dataarray.attrs["cm"], + node.dataarray.attrs["network_state_space"], + node.index, + node_labels=node.dataarray.attrs["node_labels"], + ).pyphi + for node in self.nodes + ) + ) def is_deterministic(self): raise NotImplementedError diff --git a/test/test_tpm.py b/test/test_tpm.py index 8a8e3cf76..b73235e3a 100644 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -116,28 +116,28 @@ def test_expand_tpm(): def test_marginalize_out(s): marginalized_distribution = s.tpm.marginalize_out([0]) # fmt: off - answer = ExplicitTPM( - np.array([ + answer = np.array([ [[[0.0, 0.0, 0.5], [1.0, 1.0, 0.5]], [[1.0, 0.0, 0.5], [1.0, 1.0, 0.5]]], ]) - ) # fmt: on - assert marginalized_distribution.array_equal(answer) + assert np.array_equal( + np.asarray(reconstitute_tpm(marginalized_distribution)), answer + ) marginalized_distribution = s.tpm.marginalize_out([0, 1]) # fmt: off - answer = ExplicitTPM( - np.array([ + answer = np.array([ [[[0.5, 0.0, 0.5], [1.0, 1.0, 0.5]]], ]) - ) # fmt: on - assert marginalized_distribution.array_equal(answer) + assert np.array_equal( + np.asarray(reconstitute_tpm(marginalized_distribution)), answer + ) def test_infer_cm(rule152): From 53d4e153cb28a7e977b37dc3fd63a16947d47b69 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 15 Mar 2023 
18:10:38 -0500 Subject: [PATCH 095/155] Fix reconstitute_tpm so that it considers singletons --- pyphi/tpm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 2ec74fb36..738b2b88e 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -666,20 +666,20 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: "The provided shapes contain varying number of dimensions." ) - number_of_nodes = len(shapes) + N = len(shapes) states_per_node = tuple(shape[-1] for shape in shapes) # Check consistency of shapes across nodes. dimensions_from_shapes = tuple( set(shape[node_index] for shape in shapes) - for node_index in range(number_of_nodes) + for node_index in range(N) ) - for node_index in range(number_of_nodes): - # Valid cardinalities for a dimension can be either {1, s_i != 1} - # when a node provides input to some nodes but not others, or - # {s_i != 1} if it provides input to all other nodes. + for node_index in range(N): + # Valid state cardinalities along a dimension can be either: + # {1, s_i}, s_i != 1 iff node provides input to only some nodes, + # {s_i}, s_i != 1 iff node provides input to all nodes. valid_cardinalities = ( {max(dimensions_from_shapes[node_index]), 1}, {max(dimensions_from_shapes[node_index])} @@ -693,7 +693,7 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: "node {}.".format(node_index) ) - return states_per_node + (number_of_nodes,) + return states_per_node + (N,) def validate(self, cm=None, check_independence=True): """Validate this TPM.""" @@ -894,8 +894,10 @@ def reconstitute_tpm(subsystem): node_tpms = [np.expand_dims(tpm, -1) for tpm in node_tpms] # Now we expand the node TPMs to the full state space, so we can combine # them all (this uses the maximum entropy distribution). 
+ shapes = tuple(tpm.shape[:-1] for tpm in node_tpms) + network_shape = tuple(max(dim) for dim in zip(*shapes)) node_tpms = [ - tpm * np.ones([2] * (tpm.ndim - 1) + [tpm.shape[-1]]) for tpm in node_tpms + tpm * np.ones(network_shape + (1,)) for tpm in node_tpms ] # We concatenate the node TPMs along a new axis to get a multidimensional # state-by-node TPM (where the last axis corresponds to nodes). From 2a71b04db1adf988de32d24065c925ede125f9d1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 16 Mar 2023 14:36:39 -0500 Subject: [PATCH 096/155] Implement ImplicitTPM.squeeze --- pyphi/tpm.py | 50 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 45 insertions(+), 5 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 738b2b88e..946495e07 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -7,7 +7,7 @@ """ from itertools import chain -from typing import Iterable, Mapping, Set, Tuple +from typing import Iterable, Mapping, Optional, Set, Tuple import numpy as np @@ -642,13 +642,21 @@ def shape(self): shapes = self.shapes return self._node_shapes_to_shape(shapes) + @property + def _reconstituted_shape(self): + shapes = self.shapes + return self._node_shapes_to_shape(shapes, reconstituted=True) + @property def shapes(self): """Tuple[Tuple[int]]: The shapes of each node TPM in this TPM.""" return [node.tpm.shape for node in self._nodes] @staticmethod - def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: + def _node_shapes_to_shape( + shapes: Iterable[Iterable[int]], + reconstituted: Optional[bool]=None + ) -> Tuple[int]: """Infer the shape of the equivalent multidimensional |ExplicitTPM|. 
Args: @@ -667,7 +675,10 @@ def _node_shapes_to_shape(shapes: Iterable[Iterable[int]]) -> Tuple[int]: ) N = len(shapes) - states_per_node = tuple(shape[-1] for shape in shapes) + if reconstituted: + states_per_node = tuple(max(dim) for dim in zip(*shapes)) + else: + states_per_node = tuple(shape[-1] for shape in shapes) # Check consistency of shapes across nodes. @@ -842,16 +853,45 @@ def permute_nodes(self, permutation): def equals(self, o: object): return isinstance(o, type(self)) and self.nodes == o.nodes + def squeeze(self, axis=None): + """Wrapper around numpy.squeeze.""" + # If axis is None, all axis should be considered. + if axis is None: + axis = set(range(len(self))) + else: + axis = set(axis) if isinstance(axis, Iterable) else set([axis]) + + # Subtract non-singleton dimensions from `axis`, including fake + # singletons (dimensions that are singletons only for a proper subset of + # the nodes), since those should not be squeezed from the ImplicitTPM. + shape = self._reconstituted_shape + nonsingletons = set(np.where(np.array(shape) != 1)[0]) + axis = tuple(axis - nonsingletons) + + # Leverage ExplicitTPM.squeeze to distribute squeezing to every node. 
+ return type(self)( + tuple( + Node( + node.tpm.squeeze(axis=axis), + node.dataarray.attrs["cm"], + node.dataarray.attrs["network_state_space"], + node.index, + node_labels=node.dataarray.attrs["node_labels"], + ).pyphi + for node in self.nodes + ) + ) + def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): - return ImplicitTPM( + return type(self)( tuple( node.dataarray[node.project_index(index, **kwargs)].pyphi for node in self.nodes ) ) if isinstance(index, dict): - return ImplicitTPM( + return type(self)( tuple( node.dataarray.loc[node.project_index(index, **kwargs)].pyphi for node in self.nodes From a49c2e978e325b5951ce395a83489d8a62db189f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 16 Mar 2023 15:47:16 -0500 Subject: [PATCH 097/155] Refactor `remove_singleton_dimensions()` --- pyphi/macro.py | 23 ++--------------- pyphi/tpm.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 70 insertions(+), 22 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index fc29bab21..758cbaea1 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -48,25 +48,6 @@ def rebuild_system_tpm(node_tpms): return ExplicitTPM(tpm, validate=True) -# TODO This should be a method of the TPM class in tpm.py -def remove_singleton_dimensions(tpm): - """Remove singleton dimensions from the TPM. - - Singleton dimensions are created by conditioning on a set of elements. - This removes those elements from the TPM, leaving a TPM that only - describes the non-conditioned elements. - - Note that indices used in the original TPM must be reindexed for the - smaller TPM. - """ - # Don't squeeze out the final dimension (which contains the probability) - # for networks with one element. - if tpm.ndim <= 2: - return tpm - - return tpm.squeeze()[..., tpm.tpm_indices()] - - def run_tpm(system, steps, blackbox): """Iterate the TPM for the given number of timesteps. 
@@ -241,7 +222,7 @@ def _squeeze(system): internal_indices = system.tpm.tpm_indices() - tpm = remove_singleton_dimensions(system.tpm) + tpm = system.tpm.remove_singleton_dimensions() # The connectivity matrix is the network's connectivity matrix, with # cut applied, with all connections to/from external nodes severed, @@ -321,7 +302,7 @@ def _blackbox_space(self, blackbox, system): assert blackbox.output_indices == tpm.tpm_indices() - new_tpm = remove_singleton_dimensions(tpm) + new_tpm = tpm.remove_singleton_dimensions() state_space, _ = build_state_space(tpm[:-1], system.state_space) n = len(blackbox) cm = np.zeros((n, n)) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 946495e07..15e0245f8 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -47,6 +47,9 @@ def is_deterministic(self): def is_state_by_state(self): raise NotImplementedError + def remove_singleton_dimensions(self): + raise NotImplementedError + def _subtpm(self, fixed_nodes, state): N = self.shape[-1] free_nodes = sorted(set(range(N)) - set(fixed_nodes)) @@ -537,6 +540,23 @@ def is_state_by_state(self): """ return self.ndim == 2 and self.shape[0] == self.shape[1] + def remove_singleton_dimensions(self): + """Remove singleton dimensions from the TPM. + + Singleton dimensions are created by conditioning on a set of elements. + This removes those elements from the TPM, leaving a TPM that only + describes the non-conditioned elements. + + Note that indices used in the original TPM must be reindexed for the + smaller TPM. + """ + # Don't squeeze out the final dimension (which contains the probability) + # for networks with one element. + if self.ndim <= 2: + return self + + return self.squeeze()[..., self.tpm_indices()] + def subtpm(self, fixed_nodes, state): """Return the TPM for a subset of nodes, conditioned on other nodes. @@ -835,7 +855,53 @@ def is_state_by_state(self): """ return False + def remove_singleton_dimensions(self): + """Remove singleton dimensions from the TPM. 
+ + Singleton dimensions are created by conditioning on a set of elements. + This removes those elements from the TPM, leaving a TPM that only + describes the non-conditioned elements. + + Note that indices used in the original TPM must be reindexed for the + smaller TPM. + """ + # Don't squeeze out the final dimension (which contains the probability) + # for networks with one element. + if self.ndim <= 2: + return self + + shape = self._reconstituted_shape + singletons = set(np.where(np.array(shape) == 1)[0]) + + return type(self)( + tuple( + node for node in self.squeeze().nodes + if node.index in singletons + ) + ) + def subtpm(self, fixed_nodes, state): + """Return the TPM for a subset of nodes, conditioned on other nodes. + + Arguments: + fixed_nodes (tuple[int]): The nodes to select. + state (tuple[int]): The state of the fixed nodes. + + Returns: + ExplicitTPM: The TPM of just the subsystem of the free nodes. + + Examples: + >>> from pyphi import examples + >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF + >>> reconstitute_tpm(examples.grid3_network().tpm.subtpm((0,), (0,))) + ExplicitTPM( + [[[[0.02931223 0.04742587] + [0.07585818 0.88079708]] + + [[0.81757448 0.11920292] + [0.92414182 0.95257413]]]] + ) + """ conditioned, free_nodes = self._subtpm(fixed_nodes, state) return type(self)( tuple(node for node in conditioned.nodes if node.index in free_nodes) @@ -863,7 +929,8 @@ def squeeze(self, axis=None): # Subtract non-singleton dimensions from `axis`, including fake # singletons (dimensions that are singletons only for a proper subset of - # the nodes), since those should not be squeezed from the ImplicitTPM. + # the nodes), since those should not be squeezed, not even within + # individual node TPMs. 
shape = self._reconstituted_shape nonsingletons = set(np.where(np.array(shape) != 1)[0]) axis = tuple(axis - nonsingletons) From 62a821efea4cf9b014c482bfe4e219d80272b3f6 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 17 Mar 2023 00:03:36 -0500 Subject: [PATCH 098/155] Sync macro.py with latest TPM code --- pyphi/macro.py | 12 +++++++----- test/test_macro.py | 6 +++--- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index 758cbaea1..2f2278478 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -231,14 +231,12 @@ def _squeeze(system): state = utils.state_of(internal_indices, system.state) - state_space, _ = build_state_space(system.tpm[:-1], system.state_space) - # Re-index the subsystem nodes with the external nodes removed node_indices = reindex(internal_indices) nodes = generate_nodes( tpm, cm, - state_space, + system.state_space, node_indices, network_state=state ) @@ -247,7 +245,7 @@ def _squeeze(system): # TODO: nonbinary nodes. tpm = rebuild_system_tpm(node.tpm[..., 1] for node in nodes) - return SystemAttrs(tpm, cm, node_indices, state, state_space) + return SystemAttrs(tpm, cm, node_indices, state, system.state_space) @staticmethod def _blackbox_partial_noise(blackbox, system): @@ -303,7 +301,7 @@ def _blackbox_space(self, blackbox, system): assert blackbox.output_indices == tpm.tpm_indices() new_tpm = tpm.remove_singleton_dimensions() - state_space, _ = build_state_space(tpm[:-1], system.state_space) + n = len(blackbox) cm = np.zeros((n, n)) for i, j in itertools.product(range(n), repeat=2): @@ -315,6 +313,10 @@ def _blackbox_space(self, blackbox, system): state = blackbox.macro_state(system.state) node_indices = blackbox.macro_indices + state_space, _ = build_state_space( + NodeLabels(None, node_indices), + tpm[:-1] + ) return SystemAttrs(new_tpm, cm, node_indices, state, state_space) diff --git a/test/test_macro.py b/test/test_macro.py index a38ff8b78..4aee3763f 100644 --- a/test/test_macro.py +++ 
b/test/test_macro.py @@ -317,7 +317,7 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (0,) - assert macro.remove_singleton_dimensions(tpm).array_equal(tpm) + assert tpm.remove_singleton_dimensions().array_equal(tpm) # fmt: off tpm = ExplicitTPM( @@ -334,7 +334,7 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (1,) - assert macro.remove_singleton_dimensions(tpm).array_equal(answer) + assert tpm.remove_singleton_dimensions().array_equal(answer) # fmt: off tpm = ExplicitTPM( @@ -355,7 +355,7 @@ def test_remove_singleton_dimensions(): ) # fmt: on assert tpm.tpm_indices() == (0, 2) - assert macro.remove_singleton_dimensions(tpm).array_equal(answer) + assert tpm.remove_singleton_dimensions().array_equal(answer) def test_pack_attrs(s): From 8bce05d2251cd16cbf44e869fd2145e8b5bdf34f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sat, 18 Mar 2023 22:38:03 -0500 Subject: [PATCH 099/155] Fix bug in auxilary method _node_shapes_to_shape --- pyphi/tpm.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 15e0245f8..05544605d 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -696,7 +696,7 @@ def _node_shapes_to_shape( N = len(shapes) if reconstituted: - states_per_node = tuple(max(dim) for dim in zip(*shapes)) + states_per_node = tuple(max(dim) for dim in zip(*shapes))[:-1] else: states_per_node = tuple(shape[-1] for shape in shapes) @@ -870,13 +870,16 @@ def remove_singleton_dimensions(self): if self.ndim <= 2: return self + # Find the set of singleton dimensions for this TPM. shape = self._reconstituted_shape singletons = set(np.where(np.array(shape) == 1)[0]) + # Squeeze out singleton dimensions and return a new TPM with + # the surviving nodes. 
return type(self)( tuple( node for node in self.squeeze().nodes - if node.index in singletons + if node.index not in singletons ) ) From c2cc0566ba3e8193a09249d9644400bbeb4a2cb1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Mar 2023 10:34:42 -0500 Subject: [PATCH 100/155] Make node_labels parameter mandatory when building a state space --- pyphi/macro.py | 35 ++++++++++++++++++++++++----------- pyphi/network.py | 2 +- pyphi/node.py | 8 ++++---- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index 2f2278478..b32e1a74d 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -19,7 +19,7 @@ from .network import irreducible_purviews from .node import expand_node_tpm, generate_nodes from .subsystem import Subsystem -from .tpm import ExplicitTPM +from .tpm import ExplicitTPM, reconstitute_tpm from .state_space import build_state_space # Create a logger for this module. @@ -81,7 +81,7 @@ def run_tpm(system, steps, blackbox): class SystemAttrs( namedtuple( - "SystemAttrs", ["tpm", "cm", "node_indices", "state", "state_space"] + "SystemAttrs", ["tpm", "cm", "node_indices", "state"] ) ): """An immutable container that holds all the attributes of a subsystem. 
@@ -97,6 +97,15 @@ def node_labels(self): labels = list("m{}".format(i) for i in self.node_indices) return NodeLabels(labels, self.node_indices) + @property + def state_space(self): + state_space, _ = build_state_space( + self.node_labels, + self.tpm.shape[:-1], + node_states=None, + ) + return state_space + @property def nodes(self): return generate_nodes( @@ -104,8 +113,8 @@ def nodes(self): self.cm, self.state_space, self.node_indices, + self.node_labels, network_state=self.state, - node_labels=self.node_labels ) @staticmethod @@ -115,7 +124,6 @@ def pack(system): system.cm, system.node_indices, system.state, - system.state_space, ) def apply(self, system): @@ -125,7 +133,6 @@ def apply(self, system): system.node_labels = self.node_labels system.nodes = self.nodes system.state = self.state - system.state_space = self.state_space class MacroSubsystem(Subsystem): @@ -233,11 +240,18 @@ def _squeeze(system): # Re-index the subsystem nodes with the external nodes removed node_indices = reindex(internal_indices) + node_labels = NodeLabels(None, node_indices) + state_space, _ = build_state_space( + node_labels, + tpm.shape[:-1], + ) + nodes = generate_nodes( - tpm, + reconstitute_tpm(tpm), cm, - system.state_space, + state_space, node_indices, + node_labels, network_state=state ) @@ -245,7 +259,7 @@ def _squeeze(system): # TODO: nonbinary nodes. 
tpm = rebuild_system_tpm(node.tpm[..., 1] for node in nodes) - return SystemAttrs(tpm, cm, node_indices, state, system.state_space) + return SystemAttrs(tpm, cm, node_indices, state) @staticmethod def _blackbox_partial_noise(blackbox, system): @@ -281,7 +295,6 @@ def _blackbox_time(time_scale, blackbox, system): cm, system.node_indices, system.state, - system.state_space ) def _blackbox_space(self, blackbox, system): @@ -315,10 +328,10 @@ def _blackbox_space(self, blackbox, system): node_indices = blackbox.macro_indices state_space, _ = build_state_space( NodeLabels(None, node_indices), - tpm[:-1] + tpm.shape[:-1] ) - return SystemAttrs(new_tpm, cm, node_indices, state, state_space) + return SystemAttrs(new_tpm, cm, node_indices, state) @staticmethod def _coarsegrain_space(coarse_grain, is_cut, system): diff --git a/pyphi/network.py b/pyphi/network.py index 64d6c66fa..6411332fd 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -74,7 +74,7 @@ def __init__( self._cm, self._state_space, self._node_indices, - node_labels=self._node_labels + self._node_labels ) ) diff --git a/pyphi/node.py b/pyphi/node.py index f37bf413a..f653f8042 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -292,8 +292,8 @@ def node( cm: np.ndarray, network_state_space: Mapping[str, Tuple[Union[int, str]]], index: int, + node_labels: Tuple[str], state: Optional[Union[int, str]] = None, - node_labels: Optional[Tuple[str]] = None ) -> xr.DataArray: """ Instantiate a node TPM DataArray. @@ -304,10 +304,10 @@ def node( network_state_space (Mapping[str, Tuple[Union[int, str]]]): Labels for the state space of each node in the network. index (int): The node's index in the network. + node_labels (Tuple[str]): Textual labels for each node in the network. Keyword Args: state (Optional[Union[int, str]]): The state of this node. - node_labels (Tuple[str]): Textual labels for each node in the network. Returns: xr.DataArray: The node in question. 
@@ -365,8 +365,8 @@ def generate_nodes( cm: np.ndarray, state_space: Mapping[str, Tuple[Union[int, str]]], indices: Tuple[int], + node_labels: Tuple[str], network_state: Optional[Tuple[Union[int, str]]] = None, - node_labels: Optional[Tuple[str]] = None ) -> Tuple[xr.DataArray]: """Generate |Node| objects out of a binary network |ExplicitTPM|. @@ -376,11 +376,11 @@ def generate_nodes( state_space (Mapping[str, Tuple[Union[int, str]]]): Labels for the state space of each node in the network. indices (Tuple[int]): Indices to generate nodes for. + node_labels (Optional[Tuple[str]]): Textual labels for each node. Keyword Args: network_state (Optional[Tuple[Union[int, str]]]): The state of the network. - node_labels (Optional[Tuple[str]]): Textual labels for each node. Returns: Tuple[xr.DataArray]: The nodes of the system. From 57c51e4fb03d6300b4cf2889d6946f5f2a0fe7a8 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Mar 2023 12:23:43 -0500 Subject: [PATCH 101/155] macro.py: cast TPM type if necessary --- pyphi/macro.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index b32e1a74d..ceb498463 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -18,9 +18,11 @@ from .labels import NodeLabels from .network import irreducible_purviews from .node import expand_node_tpm, generate_nodes -from .subsystem import Subsystem -from .tpm import ExplicitTPM, reconstitute_tpm from .state_space import build_state_space +from .subsystem import Subsystem + +# TODO(tpm) use ImplicitTPM type consistently throughout module +from .tpm import ExplicitTPM, ImplicitTPM, reconstitute_tpm # Create a logger for this module. 
log = logging.getLogger(__name__) @@ -108,8 +110,13 @@ def state_space(self): @property def nodes(self): + tpm = self.tpm + + if isinstance(tpm, ImplicitTPM): + tpm = reconstitute_tpm(tpm) + return generate_nodes( - self.tpm, + tpm, self.cm, self.state_space, self.node_indices, @@ -246,8 +253,11 @@ def _squeeze(system): tpm.shape[:-1], ) + if isinstance(tpm, ImplicitTPM): + tpm = reconstitute_tpm(tpm) + nodes = generate_nodes( - reconstitute_tpm(tpm), + tpm, cm, state_space, node_indices, From e26c55538588d0aade1bc1ceecfe30d03946fd93 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Mar 2023 14:33:09 -0500 Subject: [PATCH 102/155] test_node: Convert ImplicitTPMs for generate_nodes --- test/test_node.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/test/test_node.py b/test/test_node.py index 36df97f15..926ccab21 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -6,7 +6,7 @@ from pyphi.node import node, expand_node_tpm, generate_nodes from pyphi.subsystem import Subsystem -from pyphi.tpm import ExplicitTPM +from pyphi.tpm import ExplicitTPM, reconstitute_tpm def test_node_init_tpm(s): @@ -42,7 +42,8 @@ def test_node_init_inputs(s): def test_node_eq(s): - assert s.nodes[1] == node(s.tpm, s.cm, s.state_space, 1, 0, "B") + expected = node(s.tpm, s.cm, s.state_space, 1, 0, "B") + assert s.nodes[1] == expected def test_node_neq_by_index(s): @@ -83,7 +84,7 @@ def test_expand_tpm(): def test_generate_nodes(s): nodes = generate_nodes( - s.tpm, + reconstitute_tpm(s.tpm), s.cm, s.state_space, s.node_indices, @@ -137,6 +138,12 @@ def test_generate_nodes(s): def test_generate_nodes_default_labels(s): nodes = generate_nodes( - s.tpm, s.cm, s.state_space, s.node_indices, network_state=s.state + reconstitute_tpm(s.tpm), + s.cm, + s.state_space, + s.node_indices, + network_state=s.state, + node_labels=s.node_labels ) - assert [n.label for n in nodes] == ["n0", "n1", "n2"] + + assert [n.label for n in nodes] == ["A", "B", "C"] 
From a721f4fb8cb5a79402df40ad6fd91c051b786e71 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Mar 2023 14:38:12 -0500 Subject: [PATCH 103/155] MacroSubsystem.nodes(): Reconstitute TPM --- pyphi/macro.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index 2f2278478..faff7b98b 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -19,7 +19,7 @@ from .network import irreducible_purviews from .node import expand_node_tpm, generate_nodes from .subsystem import Subsystem -from .tpm import ExplicitTPM +from .tpm import ExplicitTPM, reconstitute_tpm from .state_space import build_state_space # Create a logger for this module. @@ -100,7 +100,7 @@ def node_labels(self): @property def nodes(self): return generate_nodes( - self.tpm, + reconstitute_tpm(self.tpm), self.cm, self.state_space, self.node_indices, From cbfa0c2e99860c88f45b4a120498d0629e689656 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Mar 2023 14:42:07 -0500 Subject: [PATCH 104/155] test_macro.test_pack_attrs(): Fix TPM comparison --- test/test_macro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_macro.py b/test/test_macro.py index 4aee3763f..48fa9ce73 100644 --- a/test/test_macro.py +++ b/test/test_macro.py @@ -360,7 +360,7 @@ def test_remove_singleton_dimensions(): def test_pack_attrs(s): attrs = macro.SystemAttrs.pack(s) - assert attrs.tpm.array_equal(s.tpm) + assert attrs.tpm == s.tpm assert np.array_equal(attrs.cm, s.cm) assert attrs.node_indices == s.node_indices assert attrs.state == s.state From a8aa4ae0b09980c101a26a2ca97bf08067c1d263 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Mar 2023 14:48:58 -0500 Subject: [PATCH 105/155] generate_nodes(): Handle Implicit network TPMs --- pyphi/macro.py | 3 --- pyphi/node.py | 9 ++++++--- test/test_node.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index ceb498463..ce7afd35e 100644 
--- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -253,9 +253,6 @@ def _squeeze(system): tpm.shape[:-1], ) - if isinstance(tpm, ImplicitTPM): - tpm = reconstitute_tpm(tpm) - nodes = generate_nodes( tpm, cm, diff --git a/pyphi/node.py b/pyphi/node.py index f653f8042..4dd111f7f 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -14,7 +14,7 @@ import numpy as np import xarray as xr -import pyphi.tpm +from pyphi.tpm import ImplicitTPM, reconstitute_tpm from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .state_space import ( dimension_labels, @@ -368,10 +368,10 @@ def generate_nodes( node_labels: Tuple[str], network_state: Optional[Tuple[Union[int, str]]] = None, ) -> Tuple[xr.DataArray]: - """Generate |Node| objects out of a binary network |ExplicitTPM|. + """Generate |Node| objects out of a binary network |TPM|. Args: - network_tpm (|ExplicitTPM|): The system's TPM. + network_tpm (|ExplicitTPM, ImplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. state_space (Mapping[str, Tuple[Union[int, str]]]): Labels for the state space of each node in the network. @@ -385,6 +385,9 @@ def generate_nodes( Returns: Tuple[xr.DataArray]: The nodes of the system. 
""" + if isinstance(network_tpm, ImplicitTPM): + network_tpm = reconstitute_tpm(network_tpm) + if network_state is None: network_state = (None,) * cm.shape[0] diff --git a/test/test_node.py b/test/test_node.py index 926ccab21..723c87464 100644 --- a/test/test_node.py +++ b/test/test_node.py @@ -84,7 +84,7 @@ def test_expand_tpm(): def test_generate_nodes(s): nodes = generate_nodes( - reconstitute_tpm(s.tpm), + s.tpm, s.cm, s.state_space, s.node_indices, @@ -138,7 +138,7 @@ def test_generate_nodes(s): def test_generate_nodes_default_labels(s): nodes = generate_nodes( - reconstitute_tpm(s.tpm), + s.tpm, s.cm, s.state_space, s.node_indices, From 057c52cb4417e6a0f2f13acfdafed7b9b3082f37 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Mon, 20 Mar 2023 15:01:56 -0500 Subject: [PATCH 106/155] node.py: Avoid circular import --- pyphi/node.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 4dd111f7f..0be88c9bf 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -14,7 +14,8 @@ import numpy as np import xarray as xr -from pyphi.tpm import ImplicitTPM, reconstitute_tpm +import pyphi.tpm + from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .state_space import ( dimension_labels, @@ -385,8 +386,8 @@ def generate_nodes( Returns: Tuple[xr.DataArray]: The nodes of the system. 
""" - if isinstance(network_tpm, ImplicitTPM): - network_tpm = reconstitute_tpm(network_tpm) + if isinstance(network_tpm, pyphi.tpm.ImplicitTPM): + network_tpm = pyphi.tpm.reconstitute_tpm(network_tpm) if network_state is None: network_state = (None,) * cm.shape[0] From 1e22c86dcf018842d727c6d7b70ab3a615c5fff0 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Mar 2023 15:23:08 -0500 Subject: [PATCH 107/155] Catch calls to ImplicitTPM.array_equals and route accordingly --- pyphi/network.py | 2 +- pyphi/tpm.py | 8 ++++++++ test/test_subsystem.py | 6 +++--- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 6411332fd..f39094144 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -268,7 +268,7 @@ def __eq__(self, other): """ return ( isinstance(other, Network) - and self.tpm.equals(other.tpm) + and self.tpm.array_equal(other.tpm) and np.array_equal(self.cm, other.cm) ) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 05544605d..53352c9b5 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -920,8 +920,16 @@ def permute_nodes(self, permutation): raise NotImplementedError def equals(self, o: object): + """Return whether this TPM equals the other object. + + Two TPMs are equal if they are instances of the same class + and their tuple of node TPMs are equal. + """ return isinstance(o, type(self)) and self.nodes == o.nodes + def array_equal(self, o: object): + return self.equals(o) + def squeeze(self, axis=None): """Wrapper around numpy.squeeze.""" # If axis is None, all axis should be considered. 
diff --git a/test/test_subsystem.py b/test/test_subsystem.py index 50c03b8fa..230b740e5 100644 --- a/test/test_subsystem.py +++ b/test/test_subsystem.py @@ -125,8 +125,8 @@ def test_apply_cut(s): assert s.network == cut_s.network assert s.state == cut_s.state assert s.node_indices == cut_s.node_indices - assert np.array_equal(cut_s.tpm.tpm, s.tpm.tpm) - assert np.array_equal(cut_s.cm, cut.apply_cut(s.cm)) + assert s.tpm.array_equal(cut_s.tpm) + assert np.array_equal(cut.apply_cut(s.cm), cut_s.cm) def test_cut_indices(s, subsys_n1n2): @@ -148,7 +148,7 @@ def test_cut_node_labels(s): def test_specify_elements_with_labels(standard): - network = Network(standard.tpm.tpm, node_labels=("A", "B", "C")) + network = Network(standard.tpm, node_labels=("A", "B", "C")) subsystem = Subsystem(network, (0, 0, 0), ("B", "C")) assert subsystem.node_indices == (1, 2) assert tuple(node.label for node in subsystem.nodes) == ("B", "C") From 317c560488148808f15a9f7affda410fc7a17792 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Mar 2023 15:25:59 -0500 Subject: [PATCH 108/155] Subsystem: set node TPM state early --- pyphi/subsystem.py | 40 ++++++++++++++++++---------------------- 1 file changed, 18 insertions(+), 22 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 059696ef1..43a141f96 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -136,27 +136,23 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - if cut: - self.nodes = tuple( - Node( - node.tpm, - self.cm, - self.network.state_space, - i, - node_labels=self.node_labels - ).pyphi - for i, node in enumerate(self.tpm.nodes) - if i in self.node_indices - ) - # TODO(tpm): Does memory optimization justify maintaining the `else`? - else: - self.nodes = tuple( - node for i, node in enumerate(self.tpm.nodes) - if i in self.node_indices - ) - - for node, node_state in zip(self.nodes, self.state): - node.state = node_state + # Set the state of the |Node|s. 
+ for tpm_node, node_state in zip(self.tpm.nodes, self.state): + tpm_node.state = node_state + + # Generate |Node|s for this subsystem and this particular cut to the cm. + self.nodes = tuple( + Node( + node.tpm, + self.cm, + self.network.state_space, + i, + self.node_labels, + state=node.state + ).pyphi + for i, node in enumerate(self.tpm.nodes) + if i in self.node_indices + ) # validate.subsystem(self) @@ -227,7 +223,7 @@ def cut_node_labels(self): def tpm_size(self): """int: The number of nodes in the TPM.""" return self.tpm.shape[-1] - + @property def state_space(self): return self.network.state_space From 02981cf146c1dfef8780be653fc04587ffcee7a1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 20 Mar 2023 15:59:08 -0500 Subject: [PATCH 109/155] (Cleanup) Exclude ExplicitTPM from top-level namespace --- pyphi/__init__.py | 1 - test/test_macro.py | 3 ++- test/test_macro_blackbox.py | 3 ++- test/test_tpm.py | 4 ++-- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pyphi/__init__.py b/pyphi/__init__.py index b92f5c455..2a45b9b5f 100644 --- a/pyphi/__init__.py +++ b/pyphi/__init__.py @@ -77,7 +77,6 @@ from .direction import Direction from .network import Network from .subsystem import Subsystem -from .tpm import ExplicitTPM _skip_import = ["visualize"] diff --git a/test/test_macro.py b/test/test_macro.py index 48fa9ce73..bbc4919a8 100644 --- a/test/test_macro.py +++ b/test/test_macro.py @@ -5,8 +5,9 @@ import numpy as np import pytest -from pyphi import convert, macro, ExplicitTPM +from pyphi import convert, macro from pyphi.exceptions import ConditionallyDependentError +from pyphi.tpm import ExplicitTPM # flake8: noqa diff --git a/test/test_macro_blackbox.py b/test/test_macro_blackbox.py index 1d9f71d85..d61aad0c8 100644 --- a/test/test_macro_blackbox.py +++ b/test/test_macro_blackbox.py @@ -1,7 +1,8 @@ import numpy as np import pytest -from pyphi import Network, compute, config, convert, ExplicitTPM, macro, models, utils +from pyphi import 
Network, compute, config, convert, macro, models, utils +from pyphi.tpm import ExplicitTPM # TODO: move these to examples.py diff --git a/test/test_tpm.py b/test/test_tpm.py index b73235e3a..95746619b 100644 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -6,8 +6,8 @@ import pickle import pytest -from pyphi import Subsystem, ExplicitTPM -from pyphi.tpm import reconstitute_tpm +from pyphi import Subsystem +from pyphi.tpm import ExplicitTPM, reconstitute_tpm @pytest.mark.parametrize( "tpm", From 861ef9d32e83a0188dc062f52ea9b3260641cacc Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 21 Mar 2023 15:32:36 -0500 Subject: [PATCH 110/155] TPM: allow `tpm_indices` to discard singletons --- pyphi/macro.py | 2 +- pyphi/tpm.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index ce7afd35e..9b48b701b 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -232,7 +232,7 @@ def _squeeze(system): Reindexes the subsystem so that the nodes are ``0..n`` where ``n`` is the number of internal indices in the system. 
""" - assert system.node_indices == system.tpm.tpm_indices() + assert system.node_indices == system.tpm.tpm_indices(reconstituted=True) internal_indices = system.tpm.tpm_indices() diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 53352c9b5..c70fbb26b 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -104,9 +104,10 @@ def infer_cm(self): cm[a][b] = self.infer_edge(a, b, all_contexts) return cm - def tpm_indices(self): + def tpm_indices(self, reconstituted=False): """Return the indices of nodes in the TPM.""" - return tuple(np.where(np.array(self.shape[:-1]) != 1)[0]) + shape = self._reconstituted_shape if reconstituted else self.shape + return tuple(np.where(np.array(shape[:-1]) != 1)[0]) def print(self): raise NotImplementedError From 7a8961234c11bde017024c533036a844034b5c44 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 21 Mar 2023 15:41:53 -0500 Subject: [PATCH 111/155] macro: tidy remaining references to reconstitute_tpm --- pyphi/macro.py | 9 ++------- pyphi/node.py | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index 9b48b701b..e02784e64 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -22,7 +22,7 @@ from .subsystem import Subsystem # TODO(tpm) use ImplicitTPM type consistently throughout module -from .tpm import ExplicitTPM, ImplicitTPM, reconstitute_tpm +from .tpm import ExplicitTPM # Create a logger for this module. 
log = logging.getLogger(__name__) @@ -110,13 +110,8 @@ def state_space(self): @property def nodes(self): - tpm = self.tpm - - if isinstance(tpm, ImplicitTPM): - tpm = reconstitute_tpm(tpm) - return generate_nodes( - tpm, + self.tpm, self.cm, self.state_space, self.node_indices, diff --git a/pyphi/node.py b/pyphi/node.py index 0be88c9bf..007669b15 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -388,7 +388,7 @@ def generate_nodes( """ if isinstance(network_tpm, pyphi.tpm.ImplicitTPM): network_tpm = pyphi.tpm.reconstitute_tpm(network_tpm) - + if network_state is None: network_state = (None,) * cm.shape[0] From 336ff75272a563253cc57fbc1fad61a466404932 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Mar 2023 12:47:57 -0500 Subject: [PATCH 112/155] Harmonize {MacroSubsystem._,ImplicitTPM.}squeeze and fix test_blackbox_external --- pyphi/macro.py | 5 +++-- pyphi/node.py | 17 ++++++++++------- pyphi/state_space.py | 2 +- pyphi/tpm.py | 36 +++++++++++++++++++++++++----------- test/test_macro_subsystem.py | 5 ++--- 5 files changed, 41 insertions(+), 24 deletions(-) diff --git a/pyphi/macro.py b/pyphi/macro.py index e02784e64..871c479ed 100644 --- a/pyphi/macro.py +++ b/pyphi/macro.py @@ -229,10 +229,11 @@ def _squeeze(system): """ assert system.node_indices == system.tpm.tpm_indices(reconstituted=True) - internal_indices = system.tpm.tpm_indices() - + internal_indices = system.tpm.tpm_indices(reconstituted=True) tpm = system.tpm.remove_singleton_dimensions() + # TODO(tpm): deduplicate commonalities with tpm.ImplicitTPM.squeeze. + # The connectivity matrix is the network's connectivity matrix, with # cut applied, with all connections to/from external nodes severed, # shrunk to the size of the internal nodes. 
diff --git a/pyphi/node.py b/pyphi/node.py index 007669b15..f33d2453e 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -9,11 +9,14 @@ import functools -from typing import Mapping, Optional, Tuple, Union +from typing import Iterable, Mapping, Optional, Tuple, Union import numpy as np import xarray as xr +# TODO rework circular dependency between node.py and tpm.py, instead +# of importing all of pyphi.tpm and relying on late binding of pyphi.tpm. +# to avoid the circular import error. import pyphi.tpm from .connectivity import get_inputs_from_cm, get_outputs_from_cm @@ -152,12 +155,12 @@ def state(self, value): def project_index(self, index, preserve_singletons=False): """Convert absolute TPM index to a valid index relative to this node.""" - # Supported index coordinates (in the right dimension order) respective - # to this node, to be used like an AND mask, with 0 being - # `singleton_coordinate`. + # Supported index coordinates (in the right dimension order) + # respective to this node, to be used like an AND mask, with + # `singleton_coordinate` acting like 0. dimensions = self._dataarray.dims coordinates = self._dataarray.coords - # TODO(tpm) make this a Node attribute? (similar to `state_space`). + support = {dim: tuple(coordinates[dim].values) for dim in dimensions} if isinstance(index, dict): @@ -293,7 +296,7 @@ def node( cm: np.ndarray, network_state_space: Mapping[str, Tuple[Union[int, str]]], index: int, - node_labels: Tuple[str], + node_labels: Iterable[str], state: Optional[Union[int, str]] = None, ) -> xr.DataArray: """ @@ -305,7 +308,7 @@ def node( network_state_space (Mapping[str, Tuple[Union[int, str]]]): Labels for the state space of each node in the network. index (int): The node's index in the network. - node_labels (Tuple[str]): Textual labels for each node in the network. + node_labels (Iterable[str]): Textual labels for each node in the network. Keyword Args: state (Optional[Union[int, str]]): The state of this node. 
diff --git a/pyphi/state_space.py b/pyphi/state_space.py index ecd3b7b31..88add190f 100644 --- a/pyphi/state_space.py +++ b/pyphi/state_space.py @@ -11,7 +11,7 @@ from .data_structures import FrozenMap -INPUT_DIMENSION_PREFIX = "input_" +INPUT_DIMENSION_PREFIX = "" PROBABILITY_DIMENSION = "Pr" SINGLETON_COORDINATE = "_" diff --git a/pyphi/tpm.py b/pyphi/tpm.py index c70fbb26b..aef14aabc 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -12,6 +12,7 @@ import numpy as np from . import config, convert, data_structures, exceptions +from .connectivity import subadjacency from .constants import OFF, ON from .data_structures import FrozenMap from .node import node as Node @@ -878,10 +879,7 @@ def remove_singleton_dimensions(self): # Squeeze out singleton dimensions and return a new TPM with # the surviving nodes. return type(self)( - tuple( - node for node in self.squeeze().nodes - if node.index not in singletons - ) + tuple(node for node in self.squeeze().nodes) ) def subtpm(self, fixed_nodes, state): @@ -944,20 +942,36 @@ def squeeze(self, axis=None): # the nodes), since those should not be squeezed, not even within # individual node TPMs. shape = self._reconstituted_shape - nonsingletons = set(np.where(np.array(shape) != 1)[0]) - axis = tuple(axis - nonsingletons) + nonsingletons = tuple(np.where(np.array(shape) > 1)[0]) + axis = tuple(axis - set(nonsingletons)) + + # From now on, we will only care about the first n-1 dimensions (parents). + if shape[-1] > 1: + nonsingletons = nonsingletons[:-1] + + # Recompute connectivity matrix and subset of node labels. + # TODO(tpm) deduplicate commonalities with macro.MacroSubsystem._squeeze. 
+ some_node = self.nodes[0] + + new_cm = subadjacency(some_node.dataarray.attrs["cm"], nonsingletons) + + new_node_indices = iter(range(len(nonsingletons))) + new_node_labels = tuple(some_node._node_labels[n] for n in nonsingletons) + + state_space = some_node.dataarray.attrs["network_state_space"] + new_state_space = {n: state_space[n] for n in new_node_labels} # Leverage ExplicitTPM.squeeze to distribute squeezing to every node. return type(self)( tuple( Node( node.tpm.squeeze(axis=axis), - node.dataarray.attrs["cm"], - node.dataarray.attrs["network_state_space"], - node.index, - node_labels=node.dataarray.attrs["node_labels"], + new_cm, + new_state_space, + next(new_node_indices), + new_node_labels, ).pyphi - for node in self.nodes + for node in self.nodes if node.index in nonsingletons ) ) diff --git a/test/test_macro_subsystem.py b/test/test_macro_subsystem.py index a9732894b..792a170d0 100644 --- a/test/test_macro_subsystem.py +++ b/test/test_macro_subsystem.py @@ -6,7 +6,6 @@ import pyphi from pyphi import convert, macro, models, timescale, config -from pyphi.tpm import ExplicitTPM from pyphi.convert import state_by_node2state_by_state as sbn2sbs from pyphi.convert import state_by_state2state_by_node as sbs2sbn @@ -280,7 +279,7 @@ def test_blackbox(s): ms = macro.MacroSubsystem( s.network, s.state, s.node_indices, blackbox=macro.Blackbox(((0, 1, 2),), (1,)) ) - assert np.array_equal(ms.tpm.tpm, np.array([[0.5], [0.5]])) + assert np.array_equal(np.asarray(ms.tpm), np.array([[0.5], [0.5]])) assert np.array_equal(ms.cm, np.array([[1]])) assert ms.node_indices == (0,) assert ms.state == (0,) @@ -291,7 +290,7 @@ def test_blackbox_external(s): ms = macro.MacroSubsystem( s.network, s.state, (1, 2), blackbox=macro.Blackbox(((1, 2),), (1,)) ) - assert np.array_equal(ms.tpm.tpm, np.array([[0.5], [0.5]])) + assert np.array_equal(np.asarray(ms.tpm), np.array([[0.5], [0.5]])) assert np.array_equal(ms.cm, np.array([[1]])) assert ms.node_indices == (0,) assert ms.state == 
(0,) From 7d8e21a2d9887b8c81eeeef79b87f60013ad783f Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 22 Mar 2023 14:35:08 -0500 Subject: [PATCH 113/155] validate.state_reachable(): Reconstitute TPM --- pyphi/validate.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 1082909f0..f3709a7d1 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -10,7 +10,7 @@ from . import config, exceptions from .direction import Direction -from .tpm import ExplicitTPM +from .tpm import ExplicitTPM, reconstitute_tpm from .models.mechanism import MaximallyIrreducibleCauseOrEffect # pylint: disable=redefined-outer-name @@ -106,7 +106,7 @@ def state_reachable(subsystem): # reached from some state. # First we take the submatrix of the conditioned TPM that corresponds to # the nodes that are actually in the subsystem... - tpm = subsystem.tpm[..., subsystem.node_indices] + tpm = reconstitute_tpm(subsystem.tpm)[..., subsystem.node_indices] # Then we do the subtraction and test. 
test = tpm - np.array(subsystem.proper_state) if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): From f19439a24de3dbb31be96d30bfcb2f7fffe98aa3 Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 22 Mar 2023 16:13:37 -0500 Subject: [PATCH 114/155] Implement `ImplicitTPM.__hash__()` --- pyphi/tpm.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index aef14aabc..825d19f59 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -1001,9 +1001,10 @@ def __str__(self): def __repr__(self): return "ImplicitTPM({})".format(self.nodes) - + def __hash__(self): - raise NotImplementedError + return hash(tuple(hash(node for node in self.nodes))) + def reconstitute_tpm(subsystem): From ebb68eb04e3a8e5bd1a36c8e126d99cfeed4bb7f Mon Sep 17 00:00:00 2001 From: David Viggiano Date: Wed, 22 Mar 2023 16:19:37 -0500 Subject: [PATCH 115/155] ImplicitTPM.__hash__(): Fix syntax --- pyphi/tpm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 825d19f59..65afc14c9 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -1003,7 +1003,7 @@ def __repr__(self): return "ImplicitTPM({})".format(self.nodes) def __hash__(self): - return hash(tuple(hash(node for node in self.nodes))) + return hash(tuple(hash(node) for node in self.nodes)) From 6b8f36cf30ca6bbc9c9a99daa8fbf8785a62a98f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Mar 2023 17:21:25 -0500 Subject: [PATCH 116/155] validate.state_reachable: only cast if necessary --- pyphi/tpm.py | 4 ++-- pyphi/validate.py | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 65afc14c9..49a5d60ce 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -1001,14 +1001,14 @@ def __str__(self): def __repr__(self): return "ImplicitTPM({})".format(self.nodes) - + def __hash__(self): return hash(tuple(hash(node) for node in self.nodes)) def reconstitute_tpm(subsystem): - """Reconstitute the TPM of a 
subsystem using the individual node TPMs.""" + """Reconstitute the ExplicitTPM of a subsystem using individual node TPMs.""" # The last axis of the node TPMs correponds to ON or OFF probabilities # (used in the conditioning step when calculating the repertoires); we want # ON probabilities. diff --git a/pyphi/validate.py b/pyphi/validate.py index f3709a7d1..267eee847 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -10,7 +10,7 @@ from . import config, exceptions from .direction import Direction -from .tpm import ExplicitTPM, reconstitute_tpm +from .tpm import ImplicitTPM, reconstitute_tpm from .models.mechanism import MaximallyIrreducibleCauseOrEffect # pylint: disable=redefined-outer-name @@ -101,12 +101,18 @@ def state_length(state, size): def state_reachable(subsystem): """Return whether a state can be reached according to the network's TPM.""" + # TODO(tpm) Change consumers of this function, so that only ImplicitTPMs + # are passed. + tpm = ( + reconstitute_tpm(subsystem.tpm) if isinstance(subsystem.tpm, ImplicitTPM) + else subsystem.tpm + ) # If there is a row `r` in the TPM such that all entries of `r - state` are # between -1 and 1, then the given state has a nonzero probability of being # reached from some state. # First we take the submatrix of the conditioned TPM that corresponds to # the nodes that are actually in the subsystem... - tpm = reconstitute_tpm(subsystem.tpm)[..., subsystem.node_indices] + tpm = tpm[..., subsystem.node_indices] # Then we do the subtraction and test. 
test = tpm - np.array(subsystem.proper_state) if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): From 52d0fccf6703e5b887ba0b67ecf80db5426cfd4e Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Mar 2023 17:52:49 -0500 Subject: [PATCH 117/155] Re-enable validation at Subsystem constructor --- pyphi/subsystem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 43a141f96..5e33161e0 100644 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -154,7 +154,7 @@ def __init__( if i in self.node_indices ) - # validate.subsystem(self) + validate.subsystem(self) @property def nodes(self): From 273ca3e8bb2d89068bd321d2b8396c3f9f187c3b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Mar 2023 17:57:12 -0500 Subject: [PATCH 118/155] test_validate: Don't refer to `tpm.tpm` --- test/test_validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_validate.py b/test/test_validate.py index 130128c9d..31ae1eeab 100644 --- a/test/test_validate.py +++ b/test/test_validate.py @@ -83,7 +83,7 @@ def test_validate_connectivity_matrix_not_binary(): def test_validate_network_wrong_cm_size(s): with pytest.raises(ValueError): - Network(s.network.tpm.tpm, np.ones(16).reshape(4, 4)) + Network(s.network.tpm, np.ones(16).reshape(4, 4)) def test_validate_is_network(s): From d2b3feca1e4260142ec081dec9edc677efa1b943 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Mar 2023 11:58:48 -0500 Subject: [PATCH 119/155] Refactor _validate_probabilities (and use utils.eq) --- pyphi/distribution.py | 16 ++++++++++++++++ pyphi/tpm.py | 34 +++++++++++++++++++++++++++------- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/pyphi/distribution.py b/pyphi/distribution.py index bbb82882b..214e55cb7 100644 --- a/pyphi/distribution.py +++ b/pyphi/distribution.py @@ -9,6 +9,22 @@ import numpy as np from .cache import cache +from .utils import eq + + +def is_unitary(a): + """Whether the 
distribution satisfies the second axiom of probability theory. + + This uses utils.eq and config.PRECISION to compare floats up to a tolerance. + + Args: + a (np.ndarray): The array to verify for unit measure. + + Returns: + bool: Whether the sum of entries in ``a`` is close enough to 1. + """ + measure = a.ravel().sum() + return eq(measure, 1.0) def normalize(a): diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 49a5d60ce..9248e39e8 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -11,12 +11,12 @@ import numpy as np -from . import config, convert, data_structures, exceptions +from . import config, convert, distribution, data_structures, exceptions from .connectivity import subadjacency from .constants import OFF, ON from .data_structures import FrozenMap from .node import node as Node -from .utils import all_states, np_hash, np_immutable +from .utils import all_states, eq, np_hash, np_immutable class TPM: """TPM interface for derived classes.""" @@ -393,10 +393,27 @@ def _validate_probabilities(self): """Check that the probabilities in a TPM are valid.""" if (self._tpm < 0.0).any() or (self._tpm > 1.0).any(): raise ValueError(self._ERROR_MSG_PROBABILITY_IMAGE) - if self.is_state_by_state() and np.any(np.sum(self._tpm, axis=1) != 1.0): + + # Validate that probabilities sum to 1. + if not self.is_unitary(): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) + return True + def is_unitary(self, implicit_tpm=False): + """Whether the TPM satisfies the second axiom of probability theory.""" + if implicit_tpm: + measures = self.sum(axis=-1).ravel() + return all(eq(p, 1.0) for p in measures) + + if not self.is_state_by_state(): + tpm = convert.state_by_node2state_by_state(self) + else: + tpm = self + + distributions = (d for d in tpm) + return all(distribution.is_unitary(d) for d in distributions) + def _validate_shape(self, check_independence=True): """Validate this TPM's shape. 
@@ -738,10 +755,7 @@ def _validate_probabilities(self): # individual node TPMs contain valid probabilities, for every node. # Validate that probabilities sum to 1. - if any( - (node.tpm.sum(axis=-1) != 1.0).any() - for node in self._nodes - ): + if not self.is_unitary(): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) # Leverage method in ExplicitTPM to distribute validation of @@ -752,6 +766,12 @@ def _validate_probabilities(self): ): return True + def is_unitary(self): + """Whether the TPM satisfies the second axiom of probability theory.""" + return all( + node.tpm.is_unitary(implicit_tpm=True) for node in self._nodes + ) + def _validate_shape(self, cm): """Validate this TPM's shape. From 0281d42a685dfc591303b7f37addba7173fa9f0d Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Mar 2023 12:08:55 -0500 Subject: [PATCH 120/155] `distribution.is_unitary`: Remove superfluous flattening --- pyphi/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/distribution.py b/pyphi/distribution.py index 214e55cb7..ab746632e 100644 --- a/pyphi/distribution.py +++ b/pyphi/distribution.py @@ -23,7 +23,7 @@ def is_unitary(a): Returns: bool: Whether the sum of entries in ``a`` is close enough to 1. """ - measure = a.ravel().sum() + measure = a.sum() return eq(measure, 1.0) From 6c79407780070dd26667e3e4abf164997001a492 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Mar 2023 17:27:06 -0500 Subject: [PATCH 121/155] Make sure node shapes are validated if CM is passed --- pyphi/network.py | 11 ++++++++--- pyphi/tpm.py | 43 ++++++++----------------------------------- pyphi/validate.py | 43 ++++++++++++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 53 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index f39094144..b6c92e566 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -40,7 +40,6 @@ class Network: index per node. 
""" - # TODO make tpm also optional when implementing logical network definition def __init__( self, tpm, @@ -185,10 +184,16 @@ def _build_cm(self, cm, tpm, shapes=None): utils.np_immutable(cm) return (cm, utils.np_hash(cm)) - # Explicit TPM with connectivity matrix: return. - # ImplicitTPM with connectivity matrix: return (validate later). cm = np.array(cm) utils.np_immutable(cm) + + # Explicit TPM with connectivity matrix: return. + if shapes is None: + return (cm, utils.np_hash(cm)) + + # ImplicitTPM with connectivity matrix: validate against node shapes. + validate.shapes(shapes, cm) + return (cm, utils.np_hash(cm)) @property diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 9248e39e8..629a3c0a0 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -27,7 +27,7 @@ class TPM: _ERROR_MSG_PROBABILITY_SUM = "Invalid TPM: probabilities must sum to 1." - def validate(self, cm, check_independence=True): + def validate(self, check_independence=True): raise NotImplementedError def to_multidimensional_state_by_node(self): @@ -383,7 +383,7 @@ def tpm(self): """np.ndarray: The underlying `tpm` object.""" return self._tpm - def validate(self, cm=None, check_independence=True): + def validate(self, check_independence=True): """Validate this TPM.""" return self._validate_probabilities() and self._validate_shape( check_independence @@ -745,9 +745,9 @@ def _node_shapes_to_shape( return states_per_node + (N,) - def validate(self, cm=None, check_independence=True): + def validate(self, check_independence=True): """Validate this TPM.""" - return self._validate_probabilities() and self._validate_shape(cm) + return self._validate_probabilities() and self._validate_shape() def _validate_probabilities(self): """Check that the probabilities in a TPM are valid.""" @@ -772,27 +772,13 @@ def is_unitary(self): node.tpm.is_unitary(implicit_tpm=True) for node in self._nodes ) - def _validate_shape(self, cm): + def _validate_shape(self): """Validate this TPM's shape. 
- The shapes of the individual node TPMs in multidimensional form are - validated against the connectivity matrix specification. Additionally, - the inferred shape of the implicit network TPM must be in + The inferred shape of the implicit network TPM must be in multidimensional state-by-node form, nonbinary and heterogeneous units supported. """ - # Validate individual node TPM shapes. - shapes = self.shapes - - for i, shape in enumerate(shapes): - for j, val in enumerate(cm[..., i]): - if (val == 0 and shape[j] != 1) or (val != 0 and shape[j] == 1): - raise ValueError( - "Node TPM {} of shape {} does not match the connectivity " - " matrix.".format(i, shape) - ) - - # Validate whole network's shape. N = len(self.nodes) if N + 1 != self.ndim: raise ValueError( @@ -800,6 +786,8 @@ def _validate_shape(self, cm): "suggest a {}-node network.".format(N, self.ndim - 1) ) + return True + def to_multidimensional_state_by_node(self): """Return the current TPM re-represented in multidimensional state-by-node form. @@ -812,9 +800,6 @@ def to_multidimensional_state_by_node(self): """ return reconstitute_tpm(self) - def conditionally_independent(self): - raise NotImplementedError - # TODO(tpm) accept node labels and state labels in the map. def condition_tpm(self, condition: Mapping[int, int]): """Return a TPM conditioned on the given fixed node indices, whose @@ -868,9 +853,6 @@ def marginalize_out(self, node_indices): ) ) - def is_deterministic(self): - raise NotImplementedError - def is_state_by_state(self): """Return ``True`` if ``tpm`` is in state-by-state form, otherwise ``False``. @@ -929,15 +911,6 @@ def subtpm(self, fixed_nodes, state): tuple(node for node in conditioned.nodes if node.index in free_nodes) ) - def expand_tpm(self): - raise NotImplementedError - - def print(self): - raise NotImplementedError - - def permute_nodes(self, permutation): - raise NotImplementedError - def equals(self, o: object): """Return whether this TPM equals the other object. 
diff --git a/pyphi/validate.py b/pyphi/validate.py index 267eee847..e85b55106 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -32,20 +32,6 @@ def direction(direction, allow_bi=False): return True -def connectivity_matrix(cm): - """Validate the given connectivity matrix.""" - # Special case for empty matrices. - if cm.size == 0: - return True - if cm.ndim != 2: - raise ValueError("Connectivity matrix must be 2-dimensional.") - if cm.shape[0] != cm.shape[1]: - raise ValueError("Connectivity matrix must be square.") - if not np.all(np.logical_or(cm == 1, cm == 0)): - raise ValueError("Connectivity matrix must contain only binary " "values.") - return True - - def node_labels(node_labels, node_indices): """Validate that there is a label for each node.""" if len(node_labels) != len(node_indices): @@ -62,8 +48,9 @@ def network(n): Checks the TPM and connectivity matrix. """ - n.tpm.validate(cm=n.cm) + n.tpm.validate() connectivity_matrix(n.cm) + shapes(n.tpm.shapes, n.cm) if n.cm.shape[0] != n.size: raise ValueError( "Connectivity matrix must be NxN, where N is the " @@ -72,6 +59,32 @@ def network(n): return True +def connectivity_matrix(cm): + """Validate the given connectivity matrix.""" + # Special case for empty matrices. 
+ if cm.size == 0: + return True + if cm.ndim != 2: + raise ValueError("Connectivity matrix must be 2-dimensional.") + if cm.shape[0] != cm.shape[1]: + raise ValueError("Connectivity matrix must be square.") + if not np.all(np.logical_or(cm == 1, cm == 0)): + raise ValueError("Connectivity matrix must contain only binary " "values.") + return True + + +def shapes(shapes, cm): + """Validate consistency between node TPM shapes and a user-provided cm.""" + for i, shape in enumerate(shapes): + for j, con in enumerate(cm[..., i]): + if (con == 0 and shape[j] != 1) or (con != 0 and shape[j] == 1): + raise ValueError( + "Node TPM {} of shape {} does not match the connectivity " + "matrix.".format(i, shape) + ) + return True + + def is_network(network): """Validate that the argument is a |Network|.""" from . import Network From f268f4d5314941430ca29bebf289fcac7f5a09ef Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Mar 2023 17:47:24 -0500 Subject: [PATCH 122/155] tpm: Remove old ProxyMetaclass code --- pyphi/tpm.py | 145 +-------------------------------------------------- 1 file changed, 2 insertions(+), 143 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 629a3c0a0..ae0df1664 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -126,147 +126,6 @@ def __hash__(self): raise NotImplementedError -# TODO(tpm) remove pending ArrayLike refactor -class ProxyMetaclass(type): - """A metaclass to create wrappers for the TPM array's special attributes. - - The CPython interpreter resolves double-underscore attributes (e.g., the - method definitions of mathematical operators) by looking up in the class' - static methods, not in the instance methods. 
This makes it impossible to - intercept calls to them when an instance's ``__getattr__()`` is implicitly - invoked, which in turn means there are only two options to wrap the special - methods of the array inside our custom objects (in order to perform - arithmetic operations with the TPM while also casting the result to our - custom class type): - - 1. Manually "overload" all the necessary methods. - 2. Use this metaclass to introspect the underlying array - and automatically overload methods in our custom TPM class definition. - """ - - def __init__(cls, type_name, bases, dct): - - # Casting semantics: values belonging to our custom TPM class should - # remain closed under the following methods: - __closures__ = frozenset( - { - # 1-ary - "__abs__", - "__copy__", - "__invert__", - "__neg__", - "__pos__", - # 2-ary - "__add__", - "__iadd__", - "__radd__", - "__sub__", - "__isub__", - "__rsub__", - "__mul__", - "__imul__", - "__rmul__", - "__matmul__", - "__imatmul__", - "__rmatmul__", - "__truediv__", - "__itruediv__", - "__rtruediv__", - "__floordiv__", - "__ifloordiv__", - "__rfloordiv__", - "__mod__", - "__imod__", - "__rmod__", - "__and__", - "__iand__", - "__rand__", - "__lshift__", - "__ilshift__", - "__irshift__", - "__rlshift__", - "__rrshift__", - "__rshift__", - "__ior__", - "__or__", - "__ror__", - "__xor__", - "__ixor__", - "__rxor__", - "__eq__", - "__ne__", - "__ge__", - "__gt__", - "__lt__", - "__le__", - "__deepcopy__", - # 3-ary - "__pow__", - "__ipow__", - "__rpow__", - # 2-ary, 2-valued - "__divmod__", - "__rdivmod__", - } - ) - - def make_proxy(name): - """Returns a function that acts as a proxy for the given method name. - - Args: - name (str): The name of the method to introspect in self._tpm. - - Returns: - function: The wrapping function. 
- """ - - def proxy(self): - return _new_attribute(name, __closures__, self._tpm) - - return proxy - - type.__init__(cls, type_name, bases, dct) - - if not cls.__wraps__: - return - - ignore = cls.__ignore__ - - # Go through all the attribute strings in the wrapped array type. - for name in dir(cls.__wraps__): - # Filter special attributes, rest will be handled by `__getattr__()` - if any((not name.startswith("__"), name in ignore, name in dct)): - continue - - # Create function for `name` and bind to future instances of `cls`. - setattr(cls, name, property(make_proxy(name))) - - -class Wrapper(metaclass=ProxyMetaclass): - """Proxy to the array inside PyPhi's custom ExplicitTPM class.""" - - __wraps__ = None - - __ignore__ = frozenset( - { - "__class__", - "__mro__", - "__new__", - "__init__", - "__setattr__", - "__getattr__", - "__getattribute__", - } - ) - - def __init__(self): - if self.__wraps__ is None: - raise TypeError("Base class Wrapper may not be instantiated.") - - if not isinstance(self._tpm, self.__wraps__): - raise ValueError(f"Wrapped object must be of type {self.__wraps__}") - - class ExplicitTPM(data_structures.ArrayLike, TPM): """An explicit network TPM in multidimensional form. @@ -1035,7 +894,7 @@ def reconstitute_tpm(subsystem): def _new_attribute( name: str, closures: Set[str], - tpm: ExplicitTPM.__wraps__, + tpm: np.ndarray, cls=ExplicitTPM ) -> object: """Helper function to return adequate proxy attributes for TPM arrays. @@ -1068,7 +927,7 @@ def overriding_attribute(*args, **kwargs): # Test type of result and cast (or not) accordingly. # Array. - if isinstance(result, cls.__wraps__): + if isinstance(result, np.ndarray): return cls(result) # Multivalued "functions" returning a tuple (__divmod__()). 
From 3525ae7ae01cb167d2faf92314522e8648ad6020 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sat, 25 Mar 2023 16:56:45 -0500 Subject: [PATCH 123/155] Remove `validate.node_states`, deprecated by `Node.state` setter --- pyphi/validate.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index e85b55106..51dcc8c62 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -95,12 +95,6 @@ def is_network(network): ) -def node_states(state): - """Check that the state contains only zeros and ones.""" - if not all(n in (0, 1) for n in state): - raise ValueError("Invalid state: states must consist of only zeros and ones.") - - def state_length(state, size): """Check that the state is the given size.""" if len(state) != size: @@ -145,7 +139,6 @@ def subsystem(s): Checks its state and cut. """ - node_states(s.state) # cut(s.cut, s.cut_indices) if config.VALIDATE_SUBSYSTEM_STATES: state_reachable(s) From 2e958f93b404d5be93bfea5b11c2da84aaced916 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sat, 25 Mar 2023 18:56:24 -0500 Subject: [PATCH 124/155] Add support for state labels in validate.state_reachable --- pyphi/validate.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 51dcc8c62..4c0989b91 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -120,8 +120,18 @@ def state_reachable(subsystem): # First we take the submatrix of the conditioned TPM that corresponds to # the nodes that are actually in the subsystem... tpm = tpm[..., subsystem.node_indices] + # Make sure the state is translated in terms of integer indices. + # TODO(tpm) Simplify conversion with a state_space class? + state_space = [ + node.state_space for node in subsystem.nodes + if node.index in subsystem.node_indices + ] + state = np.array([ + state_space[node].index(state) + for node, state in enumerate(subsystem.proper_state) + ]) # Then we do the subtraction and test. 
- test = tpm - np.array(subsystem.proper_state) + test = tpm - state if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): raise exceptions.StateUnreachableError(subsystem.state) From 45646f2da7a757ce766b54b67b8d46206c6b1a3b Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sat, 25 Mar 2023 19:09:10 -0500 Subject: [PATCH 125/155] Fix `Network` creation from existing `ImplicitTPM` --- pyphi/network.py | 10 +++++++--- test/test_subsystem.py | 8 +++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index b6c92e566..7a3d52d1f 100644 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -95,11 +95,10 @@ def __init__( self._cm, self._cm_hash = self._build_cm(cm, tpm, shapes) - network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) - self._node_indices = tuple(range(self.size)) self._node_labels = NodeLabels(node_labels, self._node_indices) + network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) self._state_space, _ = build_state_space( self._node_labels, network_tpm_shape[:-1], @@ -121,9 +120,14 @@ def __init__( elif isinstance(tpm, ImplicitTPM): self._tpm = tpm - self._cm, self._cm_hash = self._build_cm(cm, tpm) + self._cm, self._cm_hash = self._build_cm(cm, self._tpm) self._node_indices = tuple(range(self.size)) self._node_labels = NodeLabels(node_labels, self._node_indices) + self._state_space, _ = build_state_space( + self._node_labels, + self._tpm.shape[:-1], + state_space + ) # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): diff --git a/test/test_subsystem.py b/test/test_subsystem.py index 230b740e5..a277080ce 100644 --- a/test/test_subsystem.py +++ b/test/test_subsystem.py @@ -148,7 +148,13 @@ def test_cut_node_labels(s): def test_specify_elements_with_labels(standard): - network = Network(standard.tpm, node_labels=("A", "B", "C")) + cm = np.array([ + [0, 0, 1], + [1, 0, 1], + [1, 1, 0] + ]) + print(standard.tpm) + network = Network(standard.tpm, cm, node_labels=("A", "B", "C")) 
subsystem = Subsystem(network, (0, 0, 0), ("B", "C")) assert subsystem.node_indices == (1, 2) assert tuple(node.label for node in subsystem.nodes) == ("B", "C") From 818164ffdbcffb396c4448d426e424ccabd6f0e1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sun, 26 Mar 2023 01:14:11 -0500 Subject: [PATCH 126/155] Remove Node.streamline --- pyphi/node.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index f33d2453e..f2224bf9b 100644 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -211,46 +211,6 @@ def project_index(self, index, preserve_singletons=False): def __getitem__(self, index): return self._dataarray[index].pyphi - def streamline(self): - """Remove superfluous coordinates from an unaligned |Node| TPM. - - Returns: - xr.DataArray: The |Node| TPM re-represented without coordinates - introduced during xr.Dataset alignment. - """ - node_labels = self._node_labels - - node_indices = frozenset(range(len(node_labels))) - inputs = self.inputs - noninputs = node_indices - inputs - - input_dims = [input_dimension_label(node_labels[i]) for i in inputs] - noninput_dims = [input_dimension_label(node_labels[i]) for i in noninputs] - - new_input_coords = { - dim: [ - coord for coord in self._dataarray.coords[dim].data - if coord != SINGLETON_COORDINATE - ] - for dim in input_dims - } - new_noninput_coords = { - dim: [SINGLETON_COORDINATE] for dim in noninput_dims - } - probability_coords = { - PROBABILITY_DIMENSION: list( - self._dataarray.coords[PROBABILITY_DIMENSION].data - ) - } - - new_coords = { - **new_input_coords, - **new_noninput_coords, - **probability_coords, - } - - return self._dataarray.reindex(new_coords) - def __repr__(self): return self.label From b7d5537c3a2ad45d5e9caf06a244b3c37134ce27 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sun, 26 Mar 2023 13:26:10 -0500 Subject: [PATCH 127/155] Refactor implicit_tpm fixture --- test/test_network.py | 44 
-------------------------------------------- test/test_tpm.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 44 deletions(-) diff --git a/test/test_network.py b/test/test_network.py index 3baf9ab05..891f9451e 100644 --- a/test/test_network.py +++ b/test/test_network.py @@ -2,8 +2,6 @@ # -*- coding: utf-8 -*- # test_network.py -import random - import numpy as np import xarray as xr import pytest @@ -20,48 +18,6 @@ def network(): return Network(tpm) -@pytest.fixture() -def implicit_tpm(size, degree, node_states, seed=1337, deterministic_units=False): - rng = random.Random(seed) - - def random_deterministic_repertoire(): - repertoire = rng.sample([1] + (node_states - 1) * [0], node_states) - return repertoire - - def random_repertoire(deterministic_units): - if deterministic_units: - return random_deterministic_repertoire() - - repertoire = np.array([rng.uniform(0, 1) for s in range(node_states)]) - # Normalize using L1 (probabilities accross node_states must sum to 1) - repertoire = repertoire / repertoire.sum() - - return ( - repertoire if repertoire.sum() == 1.0 - else random_deterministic_repertoire() - ) - - tpm = [] - - for node_index in range(size): - # Generate |node_states| pseudo-probabilities for each combination of - # parent states at t - 1. - node_tpm = [ - random_repertoire(deterministic_units) - for j in range(node_states ** degree) - ] - # Select |degree| nodes at random as parents to this node, then reshape - # node TPM to multidimensional form. 
- node_shape = np.ones(size, dtype=int) - parents = rng.sample(range(size), degree) - node_shape[parents] = node_states - node_tpm = np.array(node_tpm).reshape(tuple(node_shape) + (node_states,)) - - tpm.append(node_tpm) - - return tpm - - def test_network_init_validation(network): with pytest.raises(ValueError): # Totally wrong shape diff --git a/test/test_tpm.py b/test/test_tpm.py index 95746619b..7c3082db8 100644 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -5,10 +5,53 @@ import numpy as np import pickle import pytest +import random from pyphi import Subsystem +from pyphi.distribution import normalize from pyphi.tpm import ExplicitTPM, reconstitute_tpm + +@pytest.fixture() +def implicit_tpm(size, degree, node_states, seed=1337, deterministic=False): + rng = random.Random(seed) + + def random_deterministic_repertoire(): + """Assign all probability to a single purview state at random.""" + repertoire = rng.sample([1] + (node_states - 1) * [0], node_states) + return repertoire + + def random_repertoire(deterministic): + if deterministic: + return random_deterministic_repertoire() + + repertoire = np.array([rng.uniform(0, 1) for s in range(node_states)]) + # Normalize using L1 metric. + return normalize(repertoire) + + tpm = [] + + for node_index in range(size): + # Generate |node_states| repertoires for each combination of parent + # states at t - 1. + node_tpm = [ + random_repertoire(deterministic) + for j in range(node_states ** degree) + ] + + # Select |degree| nodes at random as parents to this node, then reshape + # node TPM to multidimensional form. 
+ node_shape = np.ones(size, dtype=int) + parents = rng.sample(range(size), degree) + node_shape[parents] = node_states + + node_tpm = np.array(node_tpm).reshape(tuple(node_shape) + (node_states,)) + + tpm.append(node_tpm) + + return tpm + + @pytest.mark.parametrize( "tpm", [ExplicitTPM(np.random.rand(42)), ExplicitTPM(np.arange(42))] From 7f04bac36ff1c18b5584f7f7654c3a4da178c110 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 10 Apr 2023 17:39:06 -0500 Subject: [PATCH 128/155] Disable state reachability validation until reimplemented in implicit way --- pyphi/validate.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 4c0989b91..593807f21 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -6,6 +6,8 @@ Methods for validating arguments. """ +from warnings import warn + import numpy as np from . import config, exceptions @@ -151,7 +153,9 @@ def subsystem(s): """ # cut(s.cut, s.cut_indices) if config.VALIDATE_SUBSYSTEM_STATES: - state_reachable(s) + # TODO(tpm) Reimplement in a way that never reconstitutes the full TPM. + # state_reachable(s) + warn("Validation of state reachability didn't take place.") return True From 89bad27c78df8b363d4076a072f0c1df1f55e306 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 21 Nov 2023 17:17:16 -0600 Subject: [PATCH 129/155] tpm: Simplify subtpm method --- pyphi/tpm.py | 92 ++++++++++++++++++---------------------------------- 1 file changed, 32 insertions(+), 60 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 5b111dea8..1263828c0 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -56,16 +56,43 @@ def remove_singleton_dimensions(self): def expand_tpm(self): raise NotImplementedError - def subtpm(fixed_nodes, state): - raise NotImplementedError + def subtpm(self, fixed_nodes, state): + """Return the TPM for a subset of nodes, conditioned on other nodes. 
- def _subtpm(self, fixed_nodes, state): - """Helper method shared by subtpm().""" + Arguments: + fixed_nodes (tuple[int]): The nodes to select. + state (tuple[int]): The state of the fixed nodes. + + Returns: + ExplicitTPM: The TPM of just the subsystem of the free nodes. + + Examples: + >>> from pyphi import examples + >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF + >>> reconstitute_tpm(examples.grid3_network().tpm).subtpm((0,), (0,)) + ExplicitTPM( + [[[[0.02931223 0.04742587] + [0.07585818 0.88079708]] + + [[0.81757448 0.11920292] + [0.92414182 0.95257413]]]] + ) + """ N = self.shape[-1] free_nodes = sorted(set(range(N)) - set(fixed_nodes)) condition = FrozenMap(zip(fixed_nodes, state)) conditioned_tpm = self.condition_tpm(condition) - return conditioned_tpm, free_nodes + + if isinstance(self, ExplicitTPM): + return conditioned_tpm[..., free_nodes] + + return type(self)( + tuple( + node for node in conditioned_tpm.nodes + if node.index in free_nodes + ) + ) + def infer_edge(self, a, b, contexts): """Infer the presence or absence of an edge from node A to node B. @@ -451,31 +478,6 @@ def remove_singleton_dimensions(self): return self.squeeze()[..., self.tpm_indices()] - def subtpm(self, fixed_nodes, state): - """Return the TPM for a subset of nodes, conditioned on other nodes. - - Arguments: - fixed_nodes (tuple[int]): The nodes to select. - state (tuple[int]): The state of the fixed nodes. - - Returns: - ExplicitTPM: The TPM of just the subsystem of the free nodes. 
- - Examples: - >>> from pyphi import examples - >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF - >>> reconstitute_tpm(examples.grid3_network().tpm).subtpm((0,), (0,)) - ExplicitTPM( - [[[[0.02931223 0.04742587] - [0.07585818 0.88079708]] - - [[0.81757448 0.11920292] - [0.92414182 0.95257413]]]] - ) - """ - conditioned_tpm, free_nodes = self._subtpm(fixed_nodes, state) - return conditioned_tpm[..., free_nodes] - def expand_tpm(self): """Broadcast a state-by-node TPM so that singleton dimensions are expanded over the full network. @@ -759,36 +761,6 @@ def remove_singleton_dimensions(self): tuple(node for node in self.squeeze().nodes) ) - def subtpm(self, fixed_nodes, state): - """Return the TPM for a subset of nodes, conditioned on other nodes. - - Arguments: - fixed_nodes (tuple[int]): The nodes to select. - state (tuple[int]): The state of the fixed nodes. - - Returns: - ExplicitTPM: The TPM of just the subsystem of the free nodes. - - Examples: - >>> from pyphi import examples - >>> # Get the TPM for nodes only 1 and 2, conditioned on node 0 = OFF - >>> reconstitute_tpm(examples.grid3_network().tpm.subtpm((0,), (0,))) - ExplicitTPM( - [[[[0.02931223 0.04742587] - [0.07585818 0.88079708]] - - [[0.81757448 0.11920292] - [0.92414182 0.95257413]]]] - ) - """ - conditioned_tpm, free_nodes = self._subtpm(fixed_nodes, state) - return type(self)( - tuple( - node for node in conditioned_tpm.nodes - if node.index in free_nodes - ) - ) - def equals(self, o: object): """Return whether this TPM equals the other object. 
From 20808fe0a3a7a6591579748927a1738b85eb7d9f Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 22 Nov 2023 17:53:05 -0600 Subject: [PATCH 130/155] Refactor TPM validation --- pyphi/distribution.py | 15 --------------- pyphi/tpm.py | 41 ++++++++++++++++------------------------- pyphi/validate.py | 1 - 3 files changed, 16 insertions(+), 41 deletions(-) mode change 100644 => 100755 pyphi/distribution.py diff --git a/pyphi/distribution.py b/pyphi/distribution.py old mode 100644 new mode 100755 index ab746632e..08cf037ad --- a/pyphi/distribution.py +++ b/pyphi/distribution.py @@ -12,21 +12,6 @@ from .utils import eq -def is_unitary(a): - """Whether the distribution satisfies the second axiom of probability theory. - - This uses utils.eq and config.PRECISION to compare floats up to a tolerance. - - Args: - a (np.ndarray): The array to verify for unit measure. - - Returns: - bool: Whether the sum of entries in ``a`` is close enough to 1. - """ - measure = a.sum() - return eq(measure, 1.0) - - def normalize(a): """Normalize a distribution. 
diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 1263828c0..1d7ded400 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -267,7 +267,10 @@ def __init__(self, tpm, validate=False): self._tpm = np.array(tpm) if validate: - self.validate(check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE) + self.validate( + check_independence=config.VALIDATE_CONDITIONAL_INDEPENDENCE, + network_tpm=True + ) self._tpm = self.to_multidimensional_state_by_node() self._tpm = np_immutable(self._tpm) @@ -278,36 +281,33 @@ def tpm(self): """np.ndarray: The underlying `tpm` object.""" return self._tpm - def validate(self, check_independence=True): + def validate(self, check_independence=True, network_tpm=False): """Validate this TPM.""" - return self._validate_probabilities() and self._validate_shape( + return self._validate_probabilities(network_tpm) and self._validate_shape( check_independence ) - def _validate_probabilities(self): + def _validate_probabilities(self, network_tpm=False): """Check that the probabilities in a TPM are valid.""" + # Validate TPM image is within [0, 1] (first axiom of probability). if (self._tpm < 0.0).any() or (self._tpm > 1.0).any(): raise ValueError(self._ERROR_MSG_PROBABILITY_IMAGE) # Validate that probabilities sum to 1. - if not self.is_unitary(): + if not self.is_unitary(network_tpm): raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) return True - def is_unitary(self, implicit_tpm=False): + def is_unitary(self, network_tpm=False): """Whether the TPM satisfies the second axiom of probability theory.""" - if implicit_tpm: - measures = self.sum(axis=-1).ravel() - return all(eq(p, 1.0) for p in measures) - - if not self.is_state_by_state(): + tpm = self + if network_tpm and not tpm.is_state_by_state(): tpm = convert.state_by_node2state_by_state(self) - else: - tpm = self - distributions = (d for d in tpm) - return all(distribution.is_unitary(d) for d in distributions) + # Marginalize last dimension, then check that all integrals are close to 1. 
+ measures_over_current_states = tpm.sum(axis=-1).ravel() + return all(eq(p, 1.0) for p in measures_over_current_states) def _validate_shape(self, check_independence=True): """Validate this TPM's shape. @@ -630,13 +630,6 @@ def _validate_probabilities(self): """Check that the probabilities in a TPM are valid.""" # An implicit TPM contains valid probabilities if and only if # individual node TPMs contain valid probabilities, for every node. - - # Validate that probabilities sum to 1. - if not self.is_unitary(): - raise ValueError(self._ERROR_MSG_PROBABILITY_SUM) - - # Leverage method in ExplicitTPM to distribute validation of - # TPM image within [0, 1]. if all( node.tpm._validate_probabilities() for node in self._nodes @@ -645,9 +638,7 @@ def _validate_probabilities(self): def is_unitary(self): """Whether the TPM satisfies the second axiom of probability theory.""" - return all( - node.tpm.is_unitary(implicit_tpm=True) for node in self._nodes - ) + return all(node.tpm.is_unitary() for node in self._nodes) def _validate_shape(self): """Validate this TPM's shape. diff --git a/pyphi/validate.py b/pyphi/validate.py index 008c38646..30545c2fa 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -51,7 +51,6 @@ def node_labels(node_labels, node_indices): if len(node_labels) != len(set(node_labels)): raise ValueError("Labels {0} must be unique.".format(node_labels)) - def network(n): """Validate a |Network|. From e687f2653706c47f216a480d8c7be307542bc784 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 23 Nov 2023 19:53:56 -0600 Subject: [PATCH 131/155] subsystem: Fix proper_tpm --- pyphi/subsystem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index f17ba1bab..9bf9b42ef 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -110,7 +110,7 @@ def __init__( else: self.tpm = self.network.tpm.condition_tpm(background_conditions) # The TPM for just the nodes in the subsystem. 
- self.proper_tpm = self.tpm.squeeze()[..., list(self.node_indices)] + self.proper_tpm = self.tpm.squeeze() # The unidirectional cut applied for phi evaluation self.cut = ( From 0454ad764ed7232e2e45aab1e3d2508a7b5d29da Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Nov 2023 13:53:54 -0600 Subject: [PATCH 132/155] Tidy documentation --- pyphi/network.py | 6 ++---- pyphi/node.py | 2 -- pyphi/tpm.py | 21 +++++++++++++++++++-- 3 files changed, 21 insertions(+), 8 deletions(-) mode change 100644 => 100755 pyphi/network.py diff --git a/pyphi/network.py b/pyphi/network.py old mode 100644 new mode 100755 index 550ab46ae..9915343ac --- a/pyphi/network.py +++ b/pyphi/network.py @@ -51,11 +51,9 @@ def __init__( # Initialize _tpm according to argument type. if isinstance(tpm, (np.ndarray, ExplicitTPM)): - # Validate tpm even if an ExplicitTPM was provided. ExplicitTPM - # accepts instantiation from either another object of its class or - # np.ndarray, so the following achieves validation in general (and - # converstion to multidimensional form, as a side effect). + # Validate TPM and convert to state-by-node multidimensional format. 
tpm = ExplicitTPM(tpm, validate=True) + self._cm, self._cm_hash = self._build_cm(cm, tpm) self._node_indices = tuple(range(self.size)) diff --git a/pyphi/node.py b/pyphi/node.py index 0f2589856..594fe0677 100755 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -22,9 +22,7 @@ from .connectivity import get_inputs_from_cm, get_outputs_from_cm from .state_space import ( dimension_labels, - input_dimension_label, build_state_space, - PROBABILITY_DIMENSION, SINGLETON_COORDINATE, ) from .utils import state_of diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 1d7ded400..b2e33f27f 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -300,7 +300,19 @@ def _validate_probabilities(self, network_tpm=False): return True def is_unitary(self, network_tpm=False): - """Whether the TPM satisfies the second axiom of probability theory.""" + """Whether the TPM satisfies the second axiom of probability theory. + + A TPM is unitary if and only if for every current state of the system, + the probability distribution over next states conditioned on the current + state sums to 1 (up to |config.PRECISION|). + + Keyword Args: + network_tpm (bool): Whether ``self`` is an old-style system TPM + instead of a node TPM. + + Returns: + bool: + """ tpm = self if network_tpm and not tpm.is_state_by_state(): tpm = convert.state_by_node2state_by_state(self) @@ -637,7 +649,12 @@ def _validate_probabilities(self): return True def is_unitary(self): - """Whether the TPM satisfies the second axiom of probability theory.""" + """Whether the TPM satisfies the second axiom of probability theory. + + A TPM is unitary if and only if for every current state of the system, + the probability distribution over next states conditioned on the current + state sums to 1 (up to |config.PRECISION|). 
+ """ return all(node.tpm.is_unitary() for node in self._nodes) def _validate_shape(self): From 460e9c8ddb4ae73f4aeac18c61a99cd189627070 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Nov 2023 14:04:01 -0600 Subject: [PATCH 133/155] node.py: Improve code formatting --- pyphi/node.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/pyphi/node.py b/pyphi/node.py index 594fe0677..341a2713d 100755 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -27,6 +27,7 @@ ) from .utils import state_of + @xr.register_dataarray_accessor("pyphi") @functools.total_ordering class Node: @@ -185,7 +186,7 @@ def project_index(self, index, preserve_singletons=False): raise ValueError( "Dimension {} does not exist. Expected one or more of: " "{}.".format(e, dimensions) - ) + ) from e return projected_index @@ -226,12 +227,12 @@ def __eq__(self, other): labels. """ return ( - self.index == other.index - and self.tpm.array_equal(other.tpm) - and self.inputs == other.inputs - and self.outputs == other.outputs - and self.state_space == other.state_space - and self.state == other.state + self.index == other.index and + self.tpm.array_equal(other.tpm) and + self.inputs == other.inputs and + self.outputs == other.outputs and + self.state_space == other.state_space and + self.state == other.state ) def __ne__(self, other): @@ -297,7 +298,7 @@ def node( node_labels, tpm.shape[:-1], node_states, - singleton_state_space = (SINGLETON_COORDINATE,), + singleton_state_space=(SINGLETON_COORDINATE,), ) node_state_space = network_state_space[dimensions[index]] @@ -305,11 +306,11 @@ def node( coordinates = {**input_coordinates, dimensions[-1]: node_state_space} return xr.DataArray( - name = node_labels[index], - data = tpm, - dims = dimensions, - coords = coordinates, - attrs = { + name=node_labels[index], + data=tpm, + dims=dimensions, + coords=coordinates, + attrs={ "index": index, "node_labels": node_labels, "cm": cm, From 
0f8bafcde05a942519f8fb36541ed8a1bc0cff73 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 24 Nov 2023 16:11:18 -0600 Subject: [PATCH 134/155] Rename DataArray accessor to something more explicit --- pyphi/network.py | 2 +- pyphi/node.py | 8 ++++---- pyphi/subsystem.py | 2 +- pyphi/tpm.py | 13 +++++++++---- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index 9915343ac..a53cc1df2 100755 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -111,7 +111,7 @@ def __init__( self._state_space, index, node_labels=self._node_labels - ).pyphi + ).pyphi_accessor for index, node_tpm in zip(self._node_indices, tpm) ) ) diff --git a/pyphi/node.py b/pyphi/node.py index 341a2713d..e04c0d0cf 100755 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -28,7 +28,7 @@ from .utils import state_of -@xr.register_dataarray_accessor("pyphi") +@xr.register_dataarray_accessor("pyphi_accessor") @functools.total_ordering class Node: """A node in a Network. @@ -208,7 +208,7 @@ def project_index(self, index, preserve_singletons=False): return projected_index def __getitem__(self, index): - return self._dataarray[index].pyphi + return self._dataarray[index].pyphi_accessor def __repr__(self): return self.label @@ -384,9 +384,9 @@ def generate_nodes( cm, state_space, index, + node_labels=node_labels, state=state, - node_labels=node_labels - ).pyphi + ).pyphi_accessor ) return tuple(nodes) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 9bf9b42ef..db21fb2d4 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -147,7 +147,7 @@ def __init__( i, self.node_labels, state=node.state - ).pyphi + ).pyphi_accessor for i, node in enumerate(self.tpm.nodes) if i in self.node_indices ) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index b2e33f27f..0f2f8f4c5 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -559,6 +559,11 @@ def nodes(self): """Tuple[xr.DataArray]: The node TPMs in this ImplicitTPM""" return self._nodes + @property + def tpm(self): + 
"""Tuple[np.ndarray]: Verbose representation of all node TPMs.""" + return tuple(node.tpm for node in self._nodes) + @property def ndim(self): """int: The number of dimensions of the TPM.""" @@ -733,7 +738,7 @@ def marginalize_out(self, node_indices): node.dataarray.attrs["network_state_space"], node.index, node_labels=node.dataarray.attrs["node_labels"], - ).pyphi + ).pyphi_accessor for node in self.nodes ) ) @@ -821,7 +826,7 @@ def squeeze(self, axis=None): new_state_space, next(new_node_indices), new_node_labels, - ).pyphi + ).pyphi_accessor for node in self.nodes if node.index in nonsingletons ) ) @@ -830,14 +835,14 @@ def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): return type(self)( tuple( - node.dataarray[node.project_index(index, **kwargs)].pyphi + node.dataarray[node.project_index(index, **kwargs)].pyphi_accessor for node in self.nodes ) ) if isinstance(index, dict): return type(self)( tuple( - node.dataarray.loc[node.project_index(index, **kwargs)].pyphi + node.dataarray.loc[node.project_index(index, **kwargs)].pyphi_accessor for node in self.nodes ) ) From 74ac75f46cf58f7df2983a661a7397ba2d33a74c Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 5 Dec 2023 12:45:23 -0600 Subject: [PATCH 135/155] validate: Disable state reachability warning for now --- pyphi/validate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 30545c2fa..a3fdabcd7 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -161,7 +161,8 @@ def subsystem(s): if config.VALIDATE_SUBSYSTEM_STATES: # TODO(tpm) Reimplement in a way that never reconstitutes the full TPM. 
# state_reachable(s) - warn("Validation of state reachability didn't take place.") + # warn("Validation of state reachability didn't take place.") + pass return True From 6593f218f61f519ca1a3f400f0b50a8f0fedb8eb Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 12 Dec 2023 16:02:11 -0600 Subject: [PATCH 136/155] tpm: add `number_of_units` getter in the ImplicitTPM case --- pyphi/tpm.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 0f2f8f4c5..8e217f093 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -281,6 +281,13 @@ def tpm(self): """np.ndarray: The underlying `tpm` object.""" return self._tpm + @property + def number_of_units(self): + if self.is_state_by_state(): + # Assumes binary nodes + return int(math.log2(self._tpm.shape[1])) + return self._tpm.shape[-1] + def validate(self, check_independence=True, network_tpm=False): """Validate this TPM.""" return self._validate_probabilities(network_tpm) and self._validate_shape( @@ -366,13 +373,6 @@ def _validate_shape(self, check_independence=True): ) return True - @property - def number_of_units(self): - if self.is_state_by_state(): - # Assumes binary nodes - return int(math.log2(self._tpm.shape[1])) - return self._tpm.shape[-1] - def to_multidimensional_state_by_node(self): """Return the current TPM re-represented in multidimensional state-by-node form. 
@@ -564,6 +564,10 @@ def tpm(self): """Tuple[np.ndarray]: Verbose representation of all node TPMs.""" return tuple(node.tpm for node in self._nodes) + @property + def number_of_units(self): + return self.ndim - 1 + @property def ndim(self): """int: The number of dimensions of the TPM.""" From ac5fc864c3e8cb53a978cb0252d8f07412b17158 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 21 Dec 2023 15:18:32 -0600 Subject: [PATCH 137/155] ImplicitTPM.number_of_units: more efficient implementation --- pyphi/tpm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 8e217f093..2823db2d3 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -566,7 +566,7 @@ def tpm(self): @property def number_of_units(self): - return self.ndim - 1 + return len(self.nodes) @property def ndim(self): @@ -592,7 +592,7 @@ def shapes(self): @staticmethod def _node_shapes_to_shape( shapes: Iterable[Iterable[int]], - reconstituted: Optional[bool]=None + reconstituted: Optional[bool] = None ) -> Tuple[int]: """Infer the shape of the equivalent multidimensional |ExplicitTPM|. From 685cb7f4dac5258e5893f01c6a88eed69dee2ac6 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 26 Dec 2023 14:03:26 -0600 Subject: [PATCH 138/155] test_tpm: validate arguments in `implicit_tpm` test --- test/test_tpm.py | 8 ++++++++ 1 file changed, 8 insertions(+) mode change 100644 => 100755 test/test_tpm.py diff --git a/test/test_tpm.py b/test/test_tpm.py old mode 100644 new mode 100755 index 7c3082db8..4726f82b9 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -14,6 +14,14 @@ @pytest.fixture() def implicit_tpm(size, degree, node_states, seed=1337, deterministic=False): + if degree > size: + raise ValueError( + f"The number of parrents of each node (degree={degree}) cannot be" + f"smaller than the size of the network ({size})." 
+ ) + if node_states < 2: + raise ValueError("Nodes must have at least 2 node_states.") + rng = random.Random(seed) def random_deterministic_repertoire(): From e4f1757a8dfb42c507e014f09b58b58a2d5fbc93 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 28 Dec 2023 13:42:20 -0600 Subject: [PATCH 139/155] tpm: implement `probability_of_current_state()` and `backward_tpm()` for ImplicitTPMs --- pyphi/tpm.py | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 2823db2d3..8035ab158 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -7,6 +7,7 @@ """ import math +import functools from itertools import chain from typing import Iterable, Mapping, Optional, Set, Tuple @@ -954,6 +955,40 @@ def overriding_attribute(*args, **kwargs): return overriding_attribute +def probability_of_current_state2( + tpm: ImplicitTPM, + current_state: tuple[int] +) -> tuple[ExplicitTPM]: + """Return the probability of the current state as a distribution over previous states. + + Output format is similar to an |ImplicitTPM|, however the last dimension + only contains the probability for the current state. + + Arguments: + tpm (ImplicitTPM): The TPM of the |Network|. + current_state (tuple[int]): The current state. + Returns: + tuple[ExplicitTPM]: Node-marginal distributions of the current state. + """ + if not len(current_state) == tpm.number_of_units: + raise ValueError( + f"current_state must have length {tpm.number_of_units}" + f"for state-by-node TPM of shape {tpm.shape}" + ) + + nodes = [] + for node in tpm.nodes: + i = node.index + state = [current_state[i]] + + # DataArray indexing: keep last dimension by wrapping index inside list. 
+ pr_current_state = node.dataarray[..., [state]].data + normalization = np.sum(pr_current_state) + nodes.append(pr_current_state / normalization) + + return tuple(nodes) + + def probability_of_current_state(sbn_tpm, current_state): """Return the probability of the current state as a distribution over previous states. @@ -975,6 +1010,59 @@ def probability_of_current_state(sbn_tpm, current_state): return state_probabilities.prod(axis=-1, keepdims=True) +def backward_tpm2( + forward_tpm: ImplicitTPM, + current_state: tuple[int], + system_indices: Iterable[int], +) -> ImplicitTPM: + """Compute the backward TPM for a given network state.""" + all_indices = tuple(range(forward_tpm.number_of_units)) + system_indices = tuple(sorted(system_indices)) + background_indices = tuple(sorted(set(all_indices) - set(system_indices))) + if not set(system_indices).issubset(set(all_indices)): + raise ValueError( + "system_indices must be a subset of `range(forward_tpm.number_of_units))`" + ) + # p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_nodes = probability_of_current_state2(forward_tpm, current_state) + # TODO Avoid computing the full joint probability. E.g., find uninformative + # dimensions after each product and propagate their dismissal. 
+ pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( + axis=tuple(system_indices), keepdims=True + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # ——————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + pr_current_state_given_only_background_normalized = ( + pr_current_state_given_only_background / np.sum(pr_current_state) + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + backward_tpm = tuple( + (node_tpm * pr_current_state_given_only_background_normalized).sum( + axis=background_indices, keepdims=True + ) + for node_tpm in forward_tpm.tpm + ) + + reference_node = forward_tpm.nodes[0].dataarray + return ImplicitTPM( + tuple( + Node( + backward_node_tpm, + reference_node.attrs["cm"], + reference_node.attrs["network_state_space"], + i, + reference_node.attrs["node_labels"], + ).pyphi_accessor + for i, backward_node_tpm in enumerate(backward_tpm) + ) + ) + + def backward_tpm( forward_tpm: ExplicitTPM, current_state: tuple[int], From 1014cecc6d035482d409c8fd89bdb985ae9b9640 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 28 Dec 2023 14:24:35 -0600 Subject: [PATCH 140/155] Refactor `backward_tpm as instance methods` --- pyphi/subsystem.py | 3 +- pyphi/tpm.py | 278 ++++++++++++++++++++++----------------------- 2 files changed, 137 insertions(+), 144 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index db21fb2d4..197a4c535 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -33,7 +33,6 @@ from .network import irreducible_purviews from .partition import mip_partitions from .repertoire import forward_repertoire, unconstrained_forward_repertoire -from .tpm import backward_tpm as _backward_tpm from .utils import state_of log = logging.getLogger(__name__) @@ -106,7 +105,7 @@ def 
__init__( background_conditions = dict(zip(self.external_indices, external_state)) self.backward_tpm = backward_tpm if self.backward_tpm: - self.tpm = _backward_tpm(self.network.tpm, state, self.node_indices) + self.tpm = self.network.tpm.backward_tpm(state, self.node_indices) else: self.tpm = self.network.tpm.condition_tpm(background_conditions) # The TPM for just the nodes in the subsystem. diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 8035ab158..90572c066 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -150,6 +150,9 @@ def print(self): def permute_nodes(self, permutation): raise NotImplementedError + def backward_tpm(self, current_state, system_indices): + raise NotImplementedError + def __str__(self): raise NotImplementedError @@ -515,6 +518,59 @@ def permute_nodes(self, permutation): self._tpm.transpose(dimension_permutation)[..., list(permutation)], ) + def _probability_of_current_state(self, current_state): + """Return the probability of the current state as a distribution over previous states. + + Arguments: + current_state (tuple[int]): The current state. 
+ """ + state_probabilities = np.empty(self.shape) + if not len(current_state) == self.shape[-1]: + raise ValueError( + f"current_state must have length {self.shape[-1]}" + f"for state-by-node TPM of shape {self.shape}" + ) + for i in range(self.shape[-1]): + # TODO extend to nonbinary nodes + state_probabilities[..., i] = ( + self[..., i] if current_state[i] else (1 - self[..., i]) + ) + return state_probabilities.prod(axis=-1, keepdims=True) + + def backward_tpm( + self, + current_state: tuple[int], + system_indices: Iterable[int], + remove_background: bool = False, + ): + """Compute the backward TPM for a given network state.""" + all_indices = tuple(range(self.number_of_units)) + system_indices = tuple(sorted(system_indices)) + background_indices = tuple(sorted(set(all_indices) - set(system_indices))) + if not set(system_indices).issubset(set(all_indices)): + raise ValueError( + "system_indices must be a subset of `range(self.number_of_units))`" + ) + + # p(u_t | s_{t–1}, w_{t–1}) + pr_current_state = self._probability_of_current_state(current_state) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( + axis=tuple(system_indices), keepdims=True + ) + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + normalization = np.sum(pr_current_state) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + backward_tpm = ( + self * pr_current_state_given_only_background / normalization + ).sum(axis=background_indices, keepdims=True) + if remove_background: + # Remove background units from last dimension of the state-by-node TPM + backward_tpm = backward_tpm[..., list(system_indices)] + return ExplicitTPM(backward_tpm) + def array_equal(self, o: object): """Return whether this TPM equals the other object. 
@@ -779,6 +835,86 @@ def remove_singleton_dimensions(self): tuple(node for node in self.squeeze().nodes) ) + def _probability_of_current_state( + self, + current_state: tuple[int] + ) -> tuple[ExplicitTPM]: + """Return the probability of the current state as a distribution over previous states. + + Output format is similar to an |ImplicitTPM|, however the last dimension + only contains the probability for the current state. + + Arguments: + current_state (tuple[int]): The current state. + Returns: + tuple[ExplicitTPM]: Node-marginal distributions of the current state. + """ + if not len(current_state) == self.number_of_units: + raise ValueError( + f"current_state must have length {self.number_of_units}" + f"for state-by-node TPM of shape {self.shape}" + ) + nodes = [] + for node in self.nodes: + i = node.index + state = current_state[i] + # DataArray indexing: keep last dimension by wrapping index in list. + pr_current_state = node.dataarray[..., [state]].data + normalization = np.sum(pr_current_state) + nodes.append(pr_current_state / normalization) + return tuple(nodes) + + def backward_tpm( + self, + current_state: tuple[int], + system_indices: Iterable[int], + ): + """Compute the backward TPM for a given network state.""" + all_indices = tuple(range(self.number_of_units)) + system_indices = tuple(sorted(system_indices)) + background_indices = tuple(sorted(set(all_indices) - set(system_indices))) + if not set(system_indices).issubset(set(all_indices)): + raise ValueError( + "system_indices must be a subset of `range(self.number_of_units))`" + ) + # p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_nodes = self._probability_of_current_state(current_state) + # TODO Avoid computing the full joint probability. Find uninformative + # dimensions after each product and propagate their dismissal. 
+ pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( + axis=tuple(system_indices), keepdims=True + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # ————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + pr_current_state_given_only_background_normalized = ( + pr_current_state_given_only_background / np.sum(pr_current_state) + ) + # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ————————————————————————————————————— + # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) + backward_tpm = tuple( + (node_tpm * pr_current_state_given_only_background_normalized).sum( + axis=background_indices, keepdims=True + ) + for node_tpm in self.tpm + ) + reference_node = self.nodes[0].dataarray + return ImplicitTPM( + tuple( + Node( + backward_node_tpm, + reference_node.attrs["cm"], + reference_node.attrs["network_state_space"], + i, + reference_node.attrs["node_labels"], + ).pyphi_accessor + for i, backward_node_tpm in enumerate(backward_tpm) + ) + ) + def equals(self, o: object): """Return whether this TPM equals the other object. @@ -954,145 +1090,3 @@ def overriding_attribute(*args, **kwargs): return overriding_attribute - -def probability_of_current_state2( - tpm: ImplicitTPM, - current_state: tuple[int] -) -> tuple[ExplicitTPM]: - """Return the probability of the current state as a distribution over previous states. - - Output format is similar to an |ImplicitTPM|, however the last dimension - only contains the probability for the current state. - - Arguments: - tpm (ImplicitTPM): The TPM of the |Network|. - current_state (tuple[int]): The current state. - Returns: - tuple[ExplicitTPM]: Node-marginal distributions of the current state. 
- """ - if not len(current_state) == tpm.number_of_units: - raise ValueError( - f"current_state must have length {tpm.number_of_units}" - f"for state-by-node TPM of shape {tpm.shape}" - ) - - nodes = [] - for node in tpm.nodes: - i = node.index - state = [current_state[i]] - - # DataArray indexing: keep last dimension by wrapping index inside list. - pr_current_state = node.dataarray[..., [state]].data - normalization = np.sum(pr_current_state) - nodes.append(pr_current_state / normalization) - - return tuple(nodes) - - -def probability_of_current_state(sbn_tpm, current_state): - """Return the probability of the current state as a distribution over previous states. - - Arguments: - sbn_tpm (ExplicitTPM): State-by-node TPM. - current_state (tuple[int]): The current state. - """ - state_probabilities = np.empty(sbn_tpm.shape) - if not len(current_state) == sbn_tpm.shape[-1]: - raise ValueError( - f"current_state must have length {sbn_tpm.shape[-1]}" - f"for state-by-node TPM of shape {sbn_tpm.shape}" - ) - for i in range(sbn_tpm.shape[-1]): - # TODO extend to nonbinary nodes - state_probabilities[..., i] = ( - sbn_tpm[..., i] if current_state[i] else (1 - sbn_tpm[..., i]) - ) - return state_probabilities.prod(axis=-1, keepdims=True) - - -def backward_tpm2( - forward_tpm: ImplicitTPM, - current_state: tuple[int], - system_indices: Iterable[int], -) -> ImplicitTPM: - """Compute the backward TPM for a given network state.""" - all_indices = tuple(range(forward_tpm.number_of_units)) - system_indices = tuple(sorted(system_indices)) - background_indices = tuple(sorted(set(all_indices) - set(system_indices))) - if not set(system_indices).issubset(set(all_indices)): - raise ValueError( - "system_indices must be a subset of `range(forward_tpm.number_of_units))`" - ) - # p(u_t | s_{t–1}, w_{t–1}) - pr_current_state_nodes = probability_of_current_state2(forward_tpm, current_state) - # TODO Avoid computing the full joint probability. 
E.g., find uninformative - # dimensions after each product and propagate their dismissal. - pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - pr_current_state_given_only_background = pr_current_state.sum( - axis=tuple(system_indices), keepdims=True - ) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - # ——————————————————————————————————————— - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - pr_current_state_given_only_background_normalized = ( - pr_current_state_given_only_background / np.sum(pr_current_state) - ) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - backward_tpm = tuple( - (node_tpm * pr_current_state_given_only_background_normalized).sum( - axis=background_indices, keepdims=True - ) - for node_tpm in forward_tpm.tpm - ) - - reference_node = forward_tpm.nodes[0].dataarray - return ImplicitTPM( - tuple( - Node( - backward_node_tpm, - reference_node.attrs["cm"], - reference_node.attrs["network_state_space"], - i, - reference_node.attrs["node_labels"], - ).pyphi_accessor - for i, backward_node_tpm in enumerate(backward_tpm) - ) - ) - - -def backward_tpm( - forward_tpm: ExplicitTPM, - current_state: tuple[int], - system_indices: Iterable[int], - remove_background: bool = False, -) -> ExplicitTPM: - """Compute the backward TPM for a given network state.""" - all_indices = tuple(range(forward_tpm.number_of_units)) - system_indices = tuple(sorted(system_indices)) - background_indices = tuple(sorted(set(all_indices) - set(system_indices))) - if not set(system_indices).issubset(set(all_indices)): - raise ValueError( - "system_indices must be a subset of `range(forward_tpm.number_of_units))`" - ) - - # p(u_t | s_{t–1}, w_{t–1}) - pr_current_state = probability_of_current_state(forward_tpm, current_state) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - pr_current_state_given_only_background = 
pr_current_state.sum( - axis=tuple(system_indices), keepdims=True - ) - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - normalization = np.sum(pr_current_state) - # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) - # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— - # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - backward_tpm = ( - forward_tpm * pr_current_state_given_only_background / normalization - ).sum(axis=background_indices, keepdims=True) - if remove_background: - # Remove background units from last dimension of the state-by-node TPM - backward_tpm = backward_tpm[..., list(system_indices)] - return ExplicitTPM(backward_tpm) From 1d5939a686395b72d099c03b57b418a31b151ebd Mon Sep 17 00:00:00 2001 From: Isaac David Date: Thu, 28 Dec 2023 14:44:43 -0600 Subject: [PATCH 141/155] Rename DataArray accessor to something more expressive --- pyphi/network.py | 2 +- pyphi/node.py | 6 +++--- pyphi/subsystem.py | 2 +- pyphi/tpm.py | 12 ++++++------ 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyphi/network.py b/pyphi/network.py index a53cc1df2..facbb5ecd 100755 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -111,7 +111,7 @@ def __init__( self._state_space, index, node_labels=self._node_labels - ).pyphi_accessor + ).node for index, node_tpm in zip(self._node_indices, tpm) ) ) diff --git a/pyphi/node.py b/pyphi/node.py index e04c0d0cf..e7c58bba7 100755 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -28,7 +28,7 @@ from .utils import state_of -@xr.register_dataarray_accessor("pyphi_accessor") +@xr.register_dataarray_accessor("node") @functools.total_ordering class Node: """A node in a Network. 
@@ -208,7 +208,7 @@ def project_index(self, index, preserve_singletons=False): return projected_index def __getitem__(self, index): - return self._dataarray[index].pyphi_accessor + return self._dataarray[index].node def __repr__(self): return self.label @@ -386,7 +386,7 @@ def generate_nodes( index, node_labels=node_labels, state=state, - ).pyphi_accessor + ).node ) return tuple(nodes) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 197a4c535..06c7efb22 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -146,7 +146,7 @@ def __init__( i, self.node_labels, state=node.state - ).pyphi_accessor + ).node for i, node in enumerate(self.tpm.nodes) if i in self.node_indices ) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 90572c066..1e35435a1 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -799,7 +799,7 @@ def marginalize_out(self, node_indices): node.dataarray.attrs["network_state_space"], node.index, node_labels=node.dataarray.attrs["node_labels"], - ).pyphi_accessor + ).node for node in self.nodes ) ) @@ -839,7 +839,7 @@ def _probability_of_current_state( self, current_state: tuple[int] ) -> tuple[ExplicitTPM]: - """Return the probability of the current state as a distribution over previous states. + """Return probability of current state as distribution over previous states. Output format is similar to an |ImplicitTPM|, however the last dimension only contains the probability for the current state. 
@@ -910,7 +910,7 @@ def backward_tpm( reference_node.attrs["network_state_space"], i, reference_node.attrs["node_labels"], - ).pyphi_accessor + ).node for i, backward_node_tpm in enumerate(backward_tpm) ) ) @@ -967,7 +967,7 @@ def squeeze(self, axis=None): new_state_space, next(new_node_indices), new_node_labels, - ).pyphi_accessor + ).node for node in self.nodes if node.index in nonsingletons ) ) @@ -976,14 +976,14 @@ def __getitem__(self, index, **kwargs): if isinstance(index, (int, slice, type(...), tuple)): return type(self)( tuple( - node.dataarray[node.project_index(index, **kwargs)].pyphi_accessor + node.dataarray[node.project_index(index, **kwargs)].node for node in self.nodes ) ) if isinstance(index, dict): return type(self)( tuple( - node.dataarray.loc[node.project_index(index, **kwargs)].pyphi_accessor + node.dataarray.loc[node.project_index(index, **kwargs)].node for node in self.nodes ) ) From f69f7a421534b8b959e418206b6d3f625977034c Mon Sep 17 00:00:00 2001 From: Isaac David Date: Fri, 29 Dec 2023 16:41:15 -0600 Subject: [PATCH 142/155] test_tpm: Add `test_backward_tpm` --- test/test_tpm.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 57 insertions(+), 2 deletions(-) diff --git a/test/test_tpm.py b/test/test_tpm.py index 4726f82b9..035bca758 100755 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -7,7 +7,8 @@ import pytest import random -from pyphi import Subsystem +from pyphi import Network, Subsystem +from pyphi.convert import to_md from pyphi.distribution import normalize from pyphi.tpm import ExplicitTPM, reconstitute_tpm @@ -16,7 +17,7 @@ def implicit_tpm(size, degree, node_states, seed=1337, deterministic=False): if degree > size: raise ValueError( - f"The number of parrents of each node (degree={degree}) cannot be" + f"The number of parents of each node (degree={degree}) cannot be" f"smaller than the size of the network ({size})." 
) if node_states < 2: @@ -195,6 +196,60 @@ def test_infer_cm(rule152): assert np.array_equal(rule152.tpm.infer_cm(), rule152.cm) +def test_backward_tpm(): + # fmt: off + cm = np.array([ + [1, 1, 0,], + [0, 1, 1,], + [1, 1, 1,], + ]) + + tpm = np.array([ + [1, 0, 0], + [0, 1, 0], + [1, 1, 1], + [0, 1, 1], + [0, 0, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + ]) + + # fmt: on + explicit_tpm = ExplicitTPM(to_md(tpm)) + implicit_tpm = Network(explicit_tpm, cm).tpm + + state = (1, 0, 0) + + # Backward TPM of full network must equal forward TPM. + subsystem_indices = (0, 1, 2) + backward = explicit_tpm.backward_tpm(state, subsystem_indices) + assert backward.array_equal(explicit_tpm) + backward = reconstitute_tpm( + implicit_tpm.backward_tpm(state, subsystem_indices) + ) + assert backward.array_equal(explicit_tpm) + + # Backward TPM of proper subsystem. + # fmt: off + answer = ExplicitTPM( + np.array( + [[[[1, 0, 0,]], + [[1, 1, 1,]]], + [[[0, 1, 0,]], + [[0, 1, 1,]]]], + ) + ) + # fmt: on + subsystem_indices = (0, 1) + backward = explicit_tpm.backward_tpm(state, subsystem_indices) + assert backward.array_equal(answer) + backward = reconstitute_tpm( + implicit_tpm.backward_tpm(state, subsystem_indices) + ) + assert backward.array_equal(answer) + + def test_reconstitute_tpm(standard, s_complete, rule152, noised): # Check subsystem and network TPM are the same when the subsystem is the # whole network From a357ee4cdf80d20c9a9f22bf880326d9b0877e2a Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 2 Jan 2024 21:24:31 -0600 Subject: [PATCH 143/155] Move Fig 8 counter network to `examples` --- pyphi/examples.py | 25 +++++++++++++++++++++++++ test/test_tpm.py | 30 ++++++++---------------------- 2 files changed, 33 insertions(+), 22 deletions(-) mode change 100644 => 100755 pyphi/examples.py diff --git a/pyphi/examples.py b/pyphi/examples.py old mode 100644 new mode 100755 index 7bef5c02b..1094baab2 --- a/pyphi/examples.py +++ b/pyphi/examples.py @@ -1459,3 +1459,28 @@ def 
get_net(mech_func, weights, mu=None, si=None, exp=None, th=None, l=None, k=N print(transition) account = actual.account(transition) print(account) + +@register_example +def functionally_equivalent(): + """The 2nd deterministic system from Figure 8 of the IIT 4.0 paper: + Functionally equivalent networks with different Φ-structures. + """ + node_labels = ("A", "B", "C") + # fmt: off + cm = np.array([ + [1, 1, 0,], + [0, 1, 1,], + [1, 1, 1,], + ]) + tpm = np.array([ + [1, 0, 0], + [0, 1, 0], + [1, 1, 1], + [0, 1, 1], + [0, 0, 0], + [1, 1, 0], + [0, 0, 1], + [1, 0, 1], + ]) + # fmt: on + return Network(tpm, cm=cm, node_labels=node_labels) diff --git a/test/test_tpm.py b/test/test_tpm.py index 035bca758..a70a1265a 100755 --- a/test/test_tpm.py +++ b/test/test_tpm.py @@ -7,7 +7,7 @@ import pytest import random -from pyphi import Network, Subsystem +from pyphi import examples, Network, Subsystem from pyphi.convert import to_md from pyphi.distribution import normalize from pyphi.tpm import ExplicitTPM, reconstitute_tpm @@ -197,34 +197,18 @@ def test_infer_cm(rule152): def test_backward_tpm(): - # fmt: off - cm = np.array([ - [1, 1, 0,], - [0, 1, 1,], - [1, 1, 1,], - ]) - - tpm = np.array([ - [1, 0, 0], - [0, 1, 0], - [1, 1, 1], - [0, 1, 1], - [0, 0, 0], - [1, 1, 0], - [0, 0, 1], - [1, 0, 1], - ]) - - # fmt: on - explicit_tpm = ExplicitTPM(to_md(tpm)) - implicit_tpm = Network(explicit_tpm, cm).tpm + network = examples.functionally_equivalent() + implicit_tpm = network.tpm + explicit_tpm = reconstitute_tpm(network.tpm) state = (1, 0, 0) # Backward TPM of full network must equal forward TPM. 
subsystem_indices = (0, 1, 2) + backward = explicit_tpm.backward_tpm(state, subsystem_indices) assert backward.array_equal(explicit_tpm) + backward = reconstitute_tpm( implicit_tpm.backward_tpm(state, subsystem_indices) ) @@ -242,8 +226,10 @@ def test_backward_tpm(): ) # fmt: on subsystem_indices = (0, 1) + backward = explicit_tpm.backward_tpm(state, subsystem_indices) assert backward.array_equal(answer) + backward = reconstitute_tpm( implicit_tpm.backward_tpm(state, subsystem_indices) ) From 1740585bd3d92f70b7a37d6d4ef39b11ea24d38e Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 2 Jan 2024 22:48:04 -0600 Subject: [PATCH 144/155] `utils`: Add `equivalent_states` generator --- pyphi/utils.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/pyphi/utils.py b/pyphi/utils.py index ff8679d82..23b3884d4 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -64,6 +64,41 @@ def all_states(n, big_endian=False): yield state[::-1] # Convert to little-endian ordering +def equivalent_states(state, mask, subsystem): + """Generate equivalence class of states given irrelevant dimensions. + + Arguments: + state (Iterable[int]): Some state in the equivalence class. + mask (Iterable[int]): State mask with 1's representing irrelevant dimensions. + subsystem (|Subsystem|): The subsystem of interest. + + Yields: + Iterable[tuple[int]]: A generator for the equivalence class of states. 
+ + Examples: + >>> import numpy as np + >>> from pyphi import Network, Subsystem + >>> network = Network(np.ones((16, 4))) + >>> subsystem = Subsystem(network, (1, 1, 1, 1)) + >>> state = (1, 1, 1, 1) + >>> mask = (2, 1, 1, 2) + >>> list(equivalent_states(state, mask, subsystem)) + [(1, 0, 0, 1), (1, 0, 1, 1), (1, 1, 0, 1), (1, 1, 1, 1)] + """ + indices_needing_expansion = { + i: subsystem.tpm.shape[i] for i in subsystem.node_indices + if mask[i] == 1 + } + locally_expanded_states = product( + *[range(states) for i, states in indices_needing_expansion.items()] + ) + expanded_indices = list(indices_needing_expansion.keys()) + state = np.array(state) + for s in locally_expanded_states: + state[expanded_indices] = s + yield tuple(state) + + def np_immutable(a): """Make a NumPy array immutable.""" a.flags.writeable = False From 53b56c3281743da2ab1480ddb48b64783dd935dd Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 3 Jan 2024 10:20:01 -0600 Subject: [PATCH 145/155] Rework state reachability validation to avoid computing joint prob. --- pyphi/convert.py | 0 pyphi/tpm.py | 19 ++++++---------- pyphi/utils.py | 24 +++++++++++--------- pyphi/validate.py | 57 ++++++++++++++++++++++------------------------- 4 files changed, 47 insertions(+), 53 deletions(-) mode change 100644 => 100755 pyphi/convert.py diff --git a/pyphi/convert.py b/pyphi/convert.py old mode 100644 new mode 100755 diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 1e35435a1..943f152e5 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -442,12 +442,7 @@ def condition_tpm(self, condition: Mapping[int, int]): conditioning_indices = tuple(chain.from_iterable(conditioning_indices)) # Obtain the actual conditioned TPM by indexing with the conditioning # indices. - tpm = self[conditioning_indices] - # Create new TPM object of the same type as self. - # self.tpm has already been validated and converted to multidimensional - # state-by-node form. Further validation would be problematic for - # singleton dimensions. 
- return tpm + return self[conditioning_indices] def marginalize_out(self, node_indices): """Marginalize out nodes from this TPM. @@ -518,7 +513,7 @@ def permute_nodes(self, permutation): self._tpm.transpose(dimension_permutation)[..., list(permutation)], ) - def _probability_of_current_state(self, current_state): + def probability_of_current_state(self, current_state): """Return the probability of the current state as a distribution over previous states. Arguments: @@ -527,7 +522,7 @@ def _probability_of_current_state(self, current_state): state_probabilities = np.empty(self.shape) if not len(current_state) == self.shape[-1]: raise ValueError( - f"current_state must have length {self.shape[-1]}" + f"current_state must have length {self.shape[-1]} " f"for state-by-node TPM of shape {self.shape}" ) for i in range(self.shape[-1]): @@ -553,7 +548,7 @@ def backward_tpm( ) # p(u_t | s_{t–1}, w_{t–1}) - pr_current_state = self._probability_of_current_state(current_state) + pr_current_state = self.probability_of_current_state(current_state) # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) pr_current_state_given_only_background = pr_current_state.sum( axis=tuple(system_indices), keepdims=True @@ -835,7 +830,7 @@ def remove_singleton_dimensions(self): tuple(node for node in self.squeeze().nodes) ) - def _probability_of_current_state( + def probability_of_current_state( self, current_state: tuple[int] ) -> tuple[ExplicitTPM]: @@ -851,7 +846,7 @@ def _probability_of_current_state( """ if not len(current_state) == self.number_of_units: raise ValueError( - f"current_state must have length {self.number_of_units}" + f"current_state must have length {self.number_of_units} " f"for state-by-node TPM of shape {self.shape}" ) nodes = [] @@ -878,7 +873,7 @@ def backward_tpm( "system_indices must be a subset of `range(self.number_of_units))`" ) # p(u_t | s_{t–1}, w_{t–1}) - pr_current_state_nodes = self._probability_of_current_state(current_state) + pr_current_state_nodes = 
self.probability_of_current_state(current_state) # TODO Avoid computing the full joint probability. Find uninformative # dimensions after each product and propagate their dismissal. pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) diff --git a/pyphi/utils.py b/pyphi/utils.py index 23b3884d4..bdb130120 100644 --- a/pyphi/utils.py +++ b/pyphi/utils.py @@ -64,30 +64,32 @@ def all_states(n, big_endian=False): yield state[::-1] # Convert to little-endian ordering -def equivalent_states(state, mask, subsystem): - """Generate equivalence class of states given irrelevant dimensions. +def equivalent_states(state, mask, state_space_shape): + """Generate the equivalence class of some state, given irrelevant dimensions. Arguments: state (Iterable[int]): Some state in the equivalence class. mask (Iterable[int]): State mask with 1's representing irrelevant dimensions. - subsystem (|Subsystem|): The subsystem of interest. + state_space_shape (Iterable[int]): The cardinalities of each dimension + in the state space. Yields: Iterable[tuple[int]]: A generator for the equivalence class of states. 
Examples: - >>> import numpy as np - >>> from pyphi import Network, Subsystem - >>> network = Network(np.ones((16, 4))) - >>> subsystem = Subsystem(network, (1, 1, 1, 1)) >>> state = (1, 1, 1, 1) >>> mask = (2, 1, 1, 2) - >>> list(equivalent_states(state, mask, subsystem)) - [(1, 0, 0, 1), (1, 0, 1, 1), (1, 1, 0, 1), (1, 1, 1, 1)] + >>> state_space_shape = (2, 2, 3, 3) + >>> list(equivalent_states(state, mask, state_space_shape)) + [(1, 0, 0, 1), (1, 0, 1, 1), (1, 0, 2, 1), (1, 1, 0, 1), (1, 1, 1, 1), (1, 1, 2, 1)] """ + n = len(state) + if any(n != len(arg) for arg in [mask, state_space_shape]): + raise ValueError(f"Expected mask and state_space_shape of size {n}.") + indices_needing_expansion = { - i: subsystem.tpm.shape[i] for i in subsystem.node_indices - if mask[i] == 1 + i: state_space_shape[i] for i, mask in enumerate(mask) + if mask == 1 } locally_expanded_states = product( *[range(states) for i, states in indices_needing_expansion.items()] diff --git a/pyphi/validate.py b/pyphi/validate.py index a3fdabcd7..ed3f88c92 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -4,7 +4,7 @@ Methods for validating arguments. """ -from warnings import warn +from itertools import product import numpy as np @@ -12,6 +12,7 @@ from .conf import config from .direction import Direction from .tpm import ImplicitTPM, reconstitute_tpm +from .utils import equivalent_states # pylint: disable=redefined-outer-name @@ -115,32 +116,31 @@ def state_length(state, size): def state_reachable(subsystem): - """Return whether a state can be reached according to the network's TPM.""" - # TODO(tpm) Change consumers of this function, so that only ImplicitTPMs - # are passed. - tpm = ( - reconstitute_tpm(subsystem.tpm) if isinstance(subsystem.tpm, ImplicitTPM) - else subsystem.tpm - ) - # If there is a row `r` in the TPM such that all entries of `r - state` are - # between -1 and 1, then the given state has a nonzero probability of being - # reached from some state. 
- # First we take the submatrix of the conditioned TPM that corresponds to - # the nodes that are actually in the subsystem... - tpm = tpm[..., subsystem.node_indices] - # Make sure the state is translated in terms of integer indices. - # TODO(tpm) Simplify conversion with a state_space class? - state_space = [ - node.state_space for node in subsystem.nodes - if node.index in subsystem.node_indices + """Raise exception if state cannot be reached according to subsystem's TPM.""" + # A state s is reachable by Subsystem S if and only if there is at least + # one state s_{t-1} with nonzero probability of transitioning to s: + # ∃ s_{t-1} : p(S=s | s_{t-1}, w_{t-1}) > 0 + + # Obtain p(S=s | W_{t-1}=w_{t-1}) as node marginals (i.e. implicitly). + p = subsystem.proper_tpm.probability_of_current_state(subsystem.proper_state) + + # Avoid computing the joint distribution. For each node n, find the set of + # coordinates s_{t-1} for which p_n > 0. The intersection of all such sets + # is the set of previous states leading to the current state. + past_states = [ + set( + equivalent + for equivalent in equivalent_states( + subsystem.proper_state, + node_p.shape[:-1], + subsystem.proper_tpm.shape[:-1] + ) + for state in np.argwhere(np.asarray(node_p) > 0) + ) + for node_p in p ] - state = np.array([ - state_space[node].index(state) - for node, state in enumerate(subsystem.proper_state) - ]) - # Then we do the subtraction and test. - test = tpm - state - if not np.any(np.logical_and(-1 < test, test < 1).all(-1)): + + if not set.intersection(*past_states): raise exceptions.StateUnreachableError(subsystem.state) @@ -159,10 +159,7 @@ def subsystem(s): """ # cut(s.cut, s.cut_indices) if config.VALIDATE_SUBSYSTEM_STATES: - # TODO(tpm) Reimplement in a way that never reconstitutes the full TPM. 
- # state_reachable(s) - # warn("Validation of state reachability didn't take place.") - pass + state_reachable(s) return True From 64331cca55ec4bd046d77a879be751840626e4aa Mon Sep 17 00:00:00 2001 From: Will Mayner Date: Wed, 27 Mar 2024 08:41:02 -0500 Subject: [PATCH 146/155] Only compute potential purviews if not overriden In `subsystem.find_mice`, computing potential purviews can be very expensive in some situations, and if the user has provided a short iterable of purviews, then computing the potential purviews is not worth it. So, we simply use the user-provided purviews directly, allowing the user to decide whether to filter out reducible purviews. --- pyphi/subsystem.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 06c7efb22..502f9993b 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -145,7 +145,7 @@ def __init__( self.network.state_space, i, self.node_labels, - state=node.state + state=node.state, ).node for i, node in enumerate(self.tpm.nodes) if i in self.node_indices @@ -1091,7 +1091,8 @@ def find_mice(self, direction, mechanism, purviews=None, **kwargs): Returns: MaximallyIrreducibleCauseOrEffect: The |MIC| or |MIE|. 
""" - purviews = self.potential_purviews(direction, mechanism, purviews) + if purviews is None: + purviews = self.potential_purviews(direction, mechanism, purviews) if direction == Direction.CAUSE: mice_class = MaximallyIrreducibleCause From d44b76041b9b9adef9711fe98f81049fd038dc57 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Sat, 25 May 2024 09:52:22 -0500 Subject: [PATCH 147/155] `validate.state_reachable`: Shortcircuit intersection as soon as possible --- pyphi/validate.py | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index ed3f88c92..d414d856d 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -119,29 +119,30 @@ def state_reachable(subsystem): """Raise exception if state cannot be reached according to subsystem's TPM.""" # A state s is reachable by Subsystem S if and only if there is at least # one state s_{t-1} with nonzero probability of transitioning to s: - # ∃ s_{t-1} : p(S=s | s_{t-1}, w_{t-1}) > 0 + # ∃ s_{t-1} : p(s | s_{t-1}, w_{t-1}) > 0 - # Obtain p(S=s | W_{t-1}=w_{t-1}) as node marginals (i.e. implicitly). + # Obtain p(s | w_{t-1}) as node marginals (i.e. implicitly). p = subsystem.proper_tpm.probability_of_current_state(subsystem.proper_state) # Avoid computing the joint distribution. For each node n, find the set of # coordinates s_{t-1} for which p_n > 0. The intersection of all such sets # is the set of previous states leading to the current state. - past_states = [ - set( - equivalent - for equivalent in equivalent_states( - subsystem.proper_state, - node_p.shape[:-1], - subsystem.proper_tpm.shape[:-1] - ) - for state in np.argwhere(np.asarray(node_p) > 0) + + def past_states(p_node): + return set( + tuple(state) for state in np.argwhere(np.asarray(p_node) > 0) ) - for node_p in p - ] - if not set.intersection(*past_states): - raise exceptions.StateUnreachableError(subsystem.state) + # Initial value. 
+ intersection = past_states(p[0]) + + for p_node in p[1:]: + intersection = set.intersection(intersection, past_states(p_node)) + + # Shortcircuit evaluation of intersection as soon as a + # pairwise intersection is empty. + if not intersection: + raise exceptions.StateUnreachableError(subsystem.state) def cut(cut, node_indices): From 8083a716aecd95aa9c312ce37fee350ce26c1622 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 27 May 2024 12:58:08 -0500 Subject: [PATCH 148/155] `validate.state_reachable`: Don't expand equivalence classes, use symbolic intersection. --- pyphi/validate.py | 48 +++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 8 deletions(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index d414d856d..8d0f0e850 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -4,15 +4,12 @@ Methods for validating arguments. """ -from itertools import product - import numpy as np from . import exceptions from .conf import config from .direction import Direction -from .tpm import ImplicitTPM, reconstitute_tpm -from .utils import equivalent_states +from .tpm import ImplicitTPM # pylint: disable=redefined-outer-name @@ -129,16 +126,51 @@ def state_reachable(subsystem): # is the set of previous states leading to the current state. def past_states(p_node): - return set( - tuple(state) for state in np.argwhere(np.asarray(p_node) > 0) + # Find s_{t-1} such that p_node > 0. + states = list(np.argwhere(np.asarray(p_node) > 0)) + # Remove last dimension (probability of current state). + states = [state[:-1] for state in states] + # If node TPM shape at certain parent contains a 1, then + # there's no dependency on that parent. Substitute '0' state + # with placeholder to encode equivalent states. + states = [ + tuple( + '?' 
if p_node.shape[i] == 1 else s + for i, s in enumerate(state) + ) + for state in states + ] + return set(states) + + def _states_intersection(state1, state2): + restricted_state = [] + for s1_i, s2_i in zip(state1, state2): + if s1_i == s2_i: + restricted_state.append(s1_i) + elif s1_i == '?': + restricted_state.append(s2_i) + elif s2_i == '?': + restricted_state.append(s1_i) + else: + return None + return tuple(restricted_state) + + def states_intersection(states1, states2): + # For each unordered pair {s1, s2} in the Cartesian product of + # the two state sets, check if s1 and s2 refer to the same + # state or a (sub)class of equivalent states. If so, that's a + # member of the intersection of states1 and states2. + intersection = set( + equivalent_state for s1 in states1 for s2 in states2 + if (equivalent_state := _states_intersection(s1, s2)) ) + return intersection # Initial value. intersection = past_states(p[0]) for p_node in p[1:]: - intersection = set.intersection(intersection, past_states(p_node)) - + intersection = states_intersection(intersection, past_states(p_node)) # Shortcircuit evaluation of intersection as soon as a # pairwise intersection is empty. 
if not intersection: From fa9a3e43a7ebcf5b1c8aea7c0426e6173a8632b2 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 28 May 2024 12:06:50 -0500 Subject: [PATCH 149/155] validate.state_reachable: parallelize intersection between symbolic equivalence classes --- pyphi/conf.py | 14 ++++++++++++++ pyphi/validate.py | 34 +++++++++++++++++++++++----------- pyphi_config.yml | 7 +++++++ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/pyphi/conf.py b/pyphi/conf.py index d29e6d7f2..987bdff15 100644 --- a/pyphi/conf.py +++ b/pyphi/conf.py @@ -538,6 +538,20 @@ def always_zero(a, b): """, ) + PARALLEL_STATE_REACHABILITY_EVALUATION = Option( + dict( + parallel=True, + sequential_threshold=2*10, + chunksize=2**12, + progress=True, + ), + type=Mapping, + doc=""" + Controls parallel evaluation of subsystem state reachability. + + Only applies if VALIDATE_SUBSYSTEM_STATES = True.""", + ) + NUMBER_OF_CORES = Option( -1, type=int, diff --git a/pyphi/validate.py b/pyphi/validate.py index 8d0f0e850..2a985886e 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -4,9 +4,11 @@ Methods for validating arguments. """ +from itertools import product import numpy as np -from . import exceptions +from . import conf, exceptions +from .compute.parallel import MapReduce from .conf import config from .direction import Direction from .tpm import ImplicitTPM @@ -132,24 +134,24 @@ def past_states(p_node): states = [state[:-1] for state in states] # If node TPM shape at certain parent contains a 1, then # there's no dependency on that parent. Substitute '0' state - # with placeholder to encode equivalent states. + # with placeholder -1 to encode equivalent states. states = [ tuple( - '?' 
if p_node.shape[i] == 1 else s + -1 if p_node.shape[i] == 1 else s for i, s in enumerate(state) ) for state in states ] return set(states) - def _states_intersection(state1, state2): + def _states_intersection(state_pair): restricted_state = [] - for s1_i, s2_i in zip(state1, state2): + for s1_i, s2_i in zip(*state_pair): if s1_i == s2_i: restricted_state.append(s1_i) - elif s1_i == '?': + elif s1_i == -1: restricted_state.append(s2_i) - elif s2_i == '?': + elif s2_i == -1: restricted_state.append(s1_i) else: return None @@ -160,11 +162,21 @@ def states_intersection(states1, states2): # the two state sets, check if s1 and s2 refer to the same # state or a (sub)class of equivalent states. If so, that's a # member of the intersection of states1 and states2. - intersection = set( - equivalent_state for s1 in states1 for s2 in states2 - if (equivalent_state := _states_intersection(s1, s2)) + state_pairs = set( + tuple(sorted(pair)) for pair in product(states1, states2) ) - return intersection + parallel_kwargs = conf.parallel_kwargs( + config.PARALLEL_STATE_REACHABILITY_EVALUATION + ) + intersection = MapReduce( + _states_intersection, + state_pairs, + total=len(state_pairs), + desc="Validating state reachability", + **parallel_kwargs, + ).run() + # Cast to set and filter out None from members of intersection. + return set(state for state in intersection if state) # Initial value. intersection = past_states(p[0]) diff --git a/pyphi_config.yml b/pyphi_config.yml index cec2b1769..1a63f36dd 100644 --- a/pyphi_config.yml +++ b/pyphi_config.yml @@ -84,6 +84,13 @@ PARALLEL_RELATION_EVALUATION: parallel: true sequential_threshold: 1024 progress: true +# Controls parallel evaluation of subsystem state reachability. +# Only applies if VALIDATE_SUBSYSTEM_STATES = True. +PARALLEL_STATE_REACHABILITY_EVALUATION: + chunksize: 4096 + parallel: true + sequential_threshold: 1024 + progress: true # Controls the partitioning scheme for distinctions. 
PARTITION_TYPE: ALL # Controls the numerical precision with which to compare phi values against each From a08fc2b948d14ea399036bd4b6dbb556560dcd4e Mon Sep 17 00:00:00 2001 From: Isaac David Date: Tue, 28 May 2024 18:35:40 -0500 Subject: [PATCH 150/155] `validate`: Refactor scope and add documentation/tests. --- pyphi/validate.py | 139 ++++++++++++++++++++++++++++------------------ 1 file changed, 85 insertions(+), 54 deletions(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 2a985886e..8e9972375 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -114,6 +114,88 @@ def state_length(state, size): return True +def _past_states(p_node): + """Find set of states which could have led to the current state of a node. + + The state of irrelevant dimensions, nodes which don't output to this + node, is represented with -1 to encode a whole equivalence class. + + Arguments: + p_node (np.ndarray): Node TPM conditioned on the current subsystem state. + See also :func:`pyphi.tpm.ImplicitTPM.probability_of_current_state`. + + Returns: + set: Set of past states with nonzero probability of transitioning. + """ + # Find s_{t-1} such that p_node > 0. + states = list(np.argwhere(np.asarray(p_node) > 0)) + # Remove last dimension (probability of current state). + states = [state[:-1] for state in states] + # If node TPM shape at certain parent contains a 1, then + # there's no dependency on that parent. Substitute '0' state + # with placeholder -1 to encode equivalent states. + states = [ + tuple(-1 if p_node.shape[i] == 1 else s for i, s in enumerate(state)) + for state in states + ] + return set(states) + + +def _states_intersection(states1, states2): + """Efficient symbolic intersection between two sets of states. + + Arguments: + states1 (set[tuple[int]]): First set of states or equivalence classes. + states2 (set[tuple[int]]): Second set of states or equivalence classes. + + Returns: + set[tuple[int]]: The intersection between the two sets. 
+ + Examples: + >>> states1 = {(1, 0, -1), (1, 1, -1)} + >>> states2 = {(1, 0, 0), (1, 1, 1)} + >>> sorted(list(_states_intersection(states1, states2))) + [(1, 0, 0), (1, 1, 1)] + + >>> states1 = {(1, -1, -1)} + >>> states2 = {(1, 0, -1), (1, 1, -1)} + >>> sorted(list(_states_intersection(states1, states2))) + [(1, 0, -1), (1, 1, -1)] + """ + def find_intersection(state_pair): + # For each unordered pair |{state1, state2}| in the Cartesian product of + # the two sets, check if |state1| and |state2| have a non-empty + # (sub)class in common. If so, that is a member of the intersection. + subclass = [] + for i, j in zip(*state_pair): + if i == j: + subclass.append(i) + elif i == -1: + subclass.append(j) + elif j == -1: + subclass.append(i) + else: + return None + return tuple(subclass) + + # Obtain Cartesian product and discard permutations of the same pair. + state_pairs = set( + tuple(sorted(pair)) for pair in product(states1, states2) + ) + parallel_kwargs = conf.parallel_kwargs( + config.PARALLEL_STATE_REACHABILITY_EVALUATION + ) + intersection = MapReduce( + find_intersection, + state_pairs, + total=len(state_pairs), + desc="Validating state reachability", + **parallel_kwargs, + ).run() + # Cast to set and filter out None's from members of the intersection. + return set(state for state in intersection if state) + + def state_reachable(subsystem): """Raise exception if state cannot be reached according to subsystem's TPM.""" # A state s is reachable by Subsystem S if and only if there is at least @@ -127,64 +209,13 @@ def state_reachable(subsystem): # coordinates s_{t-1} for which p_n > 0. The intersection of all such sets # is the set of previous states leading to the current state. - def past_states(p_node): - # Find s_{t-1} such that p_node > 0. - states = list(np.argwhere(np.asarray(p_node) > 0)) - # Remove last dimension (probability of current state). 
- states = [state[:-1] for state in states] - # If node TPM shape at certain parent contains a 1, then - # there's no dependency on that parent. Substitute '0' state - # with placeholder -1 to encode equivalent states. - states = [ - tuple( - -1 if p_node.shape[i] == 1 else s - for i, s in enumerate(state) - ) - for state in states - ] - return set(states) - - def _states_intersection(state_pair): - restricted_state = [] - for s1_i, s2_i in zip(*state_pair): - if s1_i == s2_i: - restricted_state.append(s1_i) - elif s1_i == -1: - restricted_state.append(s2_i) - elif s2_i == -1: - restricted_state.append(s1_i) - else: - return None - return tuple(restricted_state) - - def states_intersection(states1, states2): - # For each unordered pair {s1, s2} in the Cartesian product of - # the two state sets, check if s1 and s2 refer to the same - # state or a (sub)class of equivalent states. If so, that's a - # member of the intersection of states1 and states2. - state_pairs = set( - tuple(sorted(pair)) for pair in product(states1, states2) - ) - parallel_kwargs = conf.parallel_kwargs( - config.PARALLEL_STATE_REACHABILITY_EVALUATION - ) - intersection = MapReduce( - _states_intersection, - state_pairs, - total=len(state_pairs), - desc="Validating state reachability", - **parallel_kwargs, - ).run() - # Cast to set and filter out None from members of intersection. - return set(state for state in intersection if state) - # Initial value. - intersection = past_states(p[0]) + intersection = _past_states(p[0]) for p_node in p[1:]: - intersection = states_intersection(intersection, past_states(p_node)) + intersection = _states_intersection(intersection, _past_states(p_node)) # Shortcircuit evaluation of intersection as soon as a - # pairwise intersection is empty. + # 2-ary intersection is empty. 
if not intersection: raise exceptions.StateUnreachableError(subsystem.state) From 4b93ca1af04c2850f1898b87c6db888de5c918ed Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 29 May 2024 13:58:57 -0500 Subject: [PATCH 151/155] Improve state validation when creating subsystems --- pyphi/subsystem.py | 2 +- pyphi/validate.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 502f9993b..96ad06e19 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -86,7 +86,7 @@ def __init__( # (for JSON serialization). self.node_indices = self.node_labels.coerce_to_indices(nodes) - validate.state_length(state, self.network.size) + validate.state(state, self.network.size, self.network.tpm.shape[:-1]) # The state of the network. self.state = tuple(state) diff --git a/pyphi/validate.py b/pyphi/validate.py index 8e9972375..8a8cd0981 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -99,7 +99,7 @@ def is_network(network): if not isinstance(network, Network): raise ValueError( - "Input must be a Network (perhaps you passed a Subsystem instead?" + "Input must be a Network (perhaps you passed a Subsystem instead?)" ) @@ -114,6 +114,35 @@ def state_length(state, size): return True +def state_type(state): + """Check that the state only contains integers.""" + if any(not isinstance(s, int) for s in state): + raise TypeError( + f"Invalid state {state}: each entry must be of int type." + ) + return True + + +def state_value(state, shape): + """Check that each entry in the state falls within the right range.""" + if any( + s not in range(cardinality) + for s, cardinality in zip(state, shape) + ): + raise ValueError( + f"Invalid state {state}: entries must be within zero and {shape}." 
+ ) + return True + + +def state(state, size, shape): + """Check that the state is of the correct length, type and value.""" + return ( + state_length(state, size) and + state_type(state) and + state_value(state, shape) + ) + def _past_states(p_node): """Find set of states which could have led to the current state of a node. From 34010f639d01e888634a8f8abd2323c8bae1dcac Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 29 May 2024 17:41:36 -0500 Subject: [PATCH 152/155] `validate.state_value`: Fix off-by-one error --- pyphi/validate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyphi/validate.py b/pyphi/validate.py index 8a8cd0981..3316bbc8c 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -130,7 +130,8 @@ def state_value(state, shape): for s, cardinality in zip(state, shape) ): raise ValueError( - f"Invalid state {state}: entries must be within zero and {shape}." + f"Invalid state {state}: entries must be within zero and " + f"{tuple((np.array(shape) - 1).tolist())}." ) return True From ab4c202f6f3224a467906eaabf89783708583cd1 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 29 May 2024 18:22:15 -0500 Subject: [PATCH 153/155] Remove parallelization of `state_reachable` for now --- pyphi/conf.py | 14 -------------- pyphi/validate.py | 25 ++++++++----------------- pyphi_config.yml | 7 ------- 3 files changed, 8 insertions(+), 38 deletions(-) diff --git a/pyphi/conf.py b/pyphi/conf.py index 987bdff15..d29e6d7f2 100644 --- a/pyphi/conf.py +++ b/pyphi/conf.py @@ -538,20 +538,6 @@ def always_zero(a, b): """, ) - PARALLEL_STATE_REACHABILITY_EVALUATION = Option( - dict( - parallel=True, - sequential_threshold=2*10, - chunksize=2**12, - progress=True, - ), - type=Mapping, - doc=""" - Controls parallel evaluation of subsystem state reachability. 
- - Only applies if VALIDATE_SUBSYSTEM_STATES = True.""", - ) - NUMBER_OF_CORES = Option( -1, type=int, diff --git a/pyphi/validate.py b/pyphi/validate.py index 3316bbc8c..4d81d5d22 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -182,8 +182,8 @@ def _states_intersection(states1, states2): set[tuple[int]]: The intersection between the two sets. Examples: - >>> states1 = {(1, 0, -1), (1, 1, -1)} - >>> states2 = {(1, 0, 0), (1, 1, 1)} + >>> states1 = {(1, 0, -1), (1, 1, 1)} + >>> states2 = {(1, 0, 0), (1, 1, 1), (0, 0, 0)} >>> sorted(list(_states_intersection(states1, states2))) [(1, 0, 0), (1, 1, 1)] @@ -208,22 +208,13 @@ def find_intersection(state_pair): return None return tuple(subclass) - # Obtain Cartesian product and discard permutations of the same pair. - state_pairs = set( - tuple(sorted(pair)) for pair in product(states1, states2) + # Lazy generator of the Cartesian product. + state_pairs = product(states1, states2) + # Find 2-ary intersections, filter out None's on the fly and return that set. + return set( + intersection for pair in state_pairs + if (intersection := find_intersection(pair)) ) - parallel_kwargs = conf.parallel_kwargs( - config.PARALLEL_STATE_REACHABILITY_EVALUATION - ) - intersection = MapReduce( - find_intersection, - state_pairs, - total=len(state_pairs), - desc="Validating state reachability", - **parallel_kwargs, - ).run() - # Cast to set and filter out None's from members of the intersection. - return set(state for state in intersection if state) def state_reachable(subsystem): diff --git a/pyphi_config.yml b/pyphi_config.yml index 1a63f36dd..cec2b1769 100644 --- a/pyphi_config.yml +++ b/pyphi_config.yml @@ -84,13 +84,6 @@ PARALLEL_RELATION_EVALUATION: parallel: true sequential_threshold: 1024 progress: true -# Controls parallel evaluation of subsystem state reachability. -# Only applies if VALIDATE_SUBSYSTEM_STATES = True. 
-PARALLEL_STATE_REACHABILITY_EVALUATION: - chunksize: 4096 - parallel: true - sequential_threshold: 1024 - progress: true # Controls the partitioning scheme for distinctions. PARTITION_TYPE: ALL # Controls the numerical precision with which to compare phi values against each From e51c49625a898c14fa41140e62c500184febbbcd Mon Sep 17 00:00:00 2001 From: Isaac David Date: Mon, 4 Nov 2024 19:26:57 -0600 Subject: [PATCH 154/155] Validate normalization factor of cause TPM, throw error if 0. --- pyphi/tpm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyphi/tpm.py b/pyphi/tpm.py index cfa0272c8..1fdd6e85f 100644 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -652,6 +652,8 @@ def backward_tpm( ) # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) normalization = np.sum(pr_current_state) + if normalization == 0.0: + raise exceptions.StateUnreachableError(current_state) # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ——————————————————————————————————————— # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) From 65e8f564c780dd73b251c81c617314bb0d1dbae0 Mon Sep 17 00:00:00 2001 From: Isaac David Date: Wed, 5 Nov 2025 14:59:47 -0600 Subject: [PATCH 155/155] Initial support for substrates of arbitrary size --- pyphi/data_structures/array_like.py | 6 +- pyphi/network.py | 180 +++++---------- pyphi/node.py | 285 ++++++++---------------- pyphi/state_space.py | 3 +- pyphi/subsystem.py | 69 +++--- pyphi/tpm.py | 328 +++++++++++++++++----------- pyphi/validate.py | 8 +- 7 files changed, 399 insertions(+), 480 deletions(-) diff --git a/pyphi/data_structures/array_like.py b/pyphi/data_structures/array_like.py index 775f2bc93..4f9e67583 100644 --- a/pyphi/data_structures/array_like.py +++ b/pyphi/data_structures/array_like.py @@ -14,12 +14,14 @@ class ArrayLike(NDArrayOperatorsMixin): _TYPE_CLOSED_FUNCTIONS = ( np.all, np.any, + np.broadcast_to, np.concatenate, np.expand_dims, + np.result_type, np.stack, np.sum, - np.result_type, - np.broadcast_to, + np.where, + np.zeros_like, ) # Holds 
the underlying array diff --git a/pyphi/network.py b/pyphi/network.py index 8fbad61a5..9caaa33f4 100755 --- a/pyphi/network.py +++ b/pyphi/network.py @@ -5,14 +5,13 @@ |big_phi| computation. """ -from typing import Iterable +from typing import Any, Dict, Optional, Sequence, Union import numpy as np from . import cache, connectivity, jsonify, utils, validate from .labels import NodeLabels from .node import generate_nodes, generate_node from .tpm import ExplicitTPM, ImplicitTPM -from .state_space import build_state_space class Network: @@ -21,8 +20,23 @@ class Network: Represents the network under analysis and holds auxilary data about it. Args: - tpm (np.ndarray): The transition probability matrix of the network. - See :func:`pyphi.tpm.ExplicitTPM`. + tpm (np.ndarray or ExplicitTPM or Sequence[np.ndarray] or ImplicitTPM): + The transition probability matrix of the network. + + If a single numpy.ndarray or |ExplicitTPM| is provided, pyphi + assumes it is an old-style TPM for the whole network, and it will be + converted to an |ImplicitTPM|. + + If an |ImplicitTPM| or a list of numpy.ndarray is provided, it must + contain one TPM per node, and their order should match those of + ``cm`` and ``node_labels``. For node |j|, the number of dimensions + of its TPM must be |inputs(j) + 1| and its shape must be + |(s_1, s_2, ... , s_i, s_j)| (also in order), where |inputs(j)| is + the number of nodes that are direct inputs of |j| and |s_i| is the + number of states for node |i|. In other words ``tpm_j[0, 1, 2, 3]`` + stands for |Pr(j_{t+1}=3 | a_{t}=0, j_{t}=1, z_{t}=2)|. + + See :ref:`tpm-conventions:`. Keyword Args: cm (np.ndarray): A square binary adjacency matrix indicating the @@ -30,50 +44,38 @@ class Network: that node |i| is connected to node |j| (see :ref:`cm-conventions`). **If no connectivity matrix is given, PyPhi assumes that every node is connected to every node (including itself)**. 
- node_labels (tuple[str] or |NodeLabels|): Human-readable labels for + node_labels (Sequence[str] or |NodeLabels|): Human-readable labels for each node in the network. - state_space (Optional[tuple[tuple[Union[int, str]]]]): - Labels for the state space of each node in the network. If ``None``, - states will be automatically labeled using a zero-based integer - index per node. + """ def __init__( self, - tpm, - cm=None, - node_labels=None, - state_space=None, - purview_cache=None + tpm: Union[np.ndarray, ExplicitTPM, Sequence, ImplicitTPM, Dict[str, Any]], + cm: Optional[np.ndarray] = None, + node_labels: Optional[Sequence[str]] = None, + purview_cache: Optional[cache.PurviewCache] = None, ): + self._cm, self._cm_hash = self._build_cm(cm) + self._node_indices = tuple(range(self.size)) + self._node_labels = NodeLabels(node_labels, self._node_indices) + self.purview_cache = purview_cache or cache.PurviewCache() + # Initialize _tpm according to argument type. if isinstance(tpm, (np.ndarray, ExplicitTPM)): - # Validate TPM and convert to state-by-node multidimensional format. + # Old-style TPM: validate and convert to state-by-node format first. tpm = ExplicitTPM(tpm, validate=True) - - self._cm, self._cm_hash = self._build_cm(cm, tpm) - - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) - - self._state_space, _ = build_state_space( - self._node_labels, - tpm.shape[:-1], - state_space - ) - - self._tpm = ImplicitTPM( - generate_nodes( - tpm, - self._cm, - self._state_space, - self._node_indices, - self._node_labels - ) + nodes = generate_nodes( + tpm, + self._cm, + self._node_indices, + self._node_labels ) + self._tpm = ImplicitTPM(nodes) - elif isinstance(tpm, Iterable): + elif isinstance(tpm, Sequence): + # Individual node TPMs were provided, format into an ImplicitTPM. 
invalid = [ i for i in tpm if not isinstance(i, (np.ndarray, ExplicitTPM)) ] @@ -83,66 +85,30 @@ def __init__( ', '.join(str(i) for i in invalid) )) - tpm = tuple( - ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm - ) - - shapes = [node.shape for node in tpm] - - self._cm, self._cm_hash = self._build_cm(cm, tpm, shapes) + tpm = [ExplicitTPM(node_tpm, validate=False) for node_tpm in tpm] - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) - - network_tpm_shape = ImplicitTPM._node_shapes_to_shape(shapes) - self._state_space, _ = build_state_space( - self._node_labels, - network_tpm_shape[:-1], - state_space - ) - - self._tpm = ImplicitTPM( - tuple( - generate_node( - node_tpm, - self._cm, - self._state_space, - index, - node_labels=self._node_labels - ) - for index, node_tpm in zip(self._node_indices, tpm) - ) + nodes = tuple( + generate_node(node_tpm, self._cm, index, self._node_labels) + for index, node_tpm in zip(self._node_indices, tpm) ) + self._tpm = ImplicitTPM(nodes) elif isinstance(tpm, ImplicitTPM): self._tpm = tpm - self._cm, self._cm_hash = self._build_cm(cm, self._tpm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) - self._state_space, _ = build_state_space( - self._node_labels, - self._tpm.shape[:-1], - state_space - ) # FIXME(TPM) initialization from JSON elif isinstance(tpm, dict): # From JSON. 
self._tpm = ImplicitTPM(tpm["_tpm"]) - self._cm, self._cm_hash = self._build_cm(cm, tpm) - self._node_indices = tuple(range(self.size)) - self._node_labels = NodeLabels(node_labels, self._node_indices) else: raise TypeError(f"Invalid TPM of type {type(tpm)}.") - self.purview_cache = purview_cache or cache.PurviewCache() - validate.network(self) @property def tpm(self): - """pyphi.tpm.ExplicitTPM: The TPM object which contains this + """ExplicitTPM: The TPM object which contains this network's transition probability matrix, in multidimensional form. """ @@ -157,43 +123,18 @@ def cm(self): """ return self._cm - def _build_cm(self, cm, tpm, shapes=None): + def _build_cm(self, cm): """Convert the passed CM to the proper format, or construct the - unitary CM if none was provided (explicit TPM), or infer from node TPMs. + unitary CM if none was provided. """ if cm is None: - if hasattr(tpm, "shape"): - network_size = tpm.shape[-1] - else: - network_size = len(tpm) - - # Explicit TPM without connectivity matrix: assume all are connected. - if shapes is None: - cm = np.ones((network_size, network_size), dtype=int) - utils.np_immutable(cm) - return (cm, utils.np_hash(cm)) - - # ImplicitTPM without connectivity matrix: infer from node TPMs. - cm = np.zeros((network_size, network_size), dtype=int) - - for i, shape in enumerate(shapes): - for j in range(len(shapes)): - if shape[j] != 1: - cm[j][i] = 1 - - utils.np_immutable(cm) - return (cm, utils.np_hash(cm)) + # Assume all are connected. + cm = np.ones((self.size, self.size)) + else: + cm = np.array(cm) - cm = np.array(cm) utils.np_immutable(cm) - # Explicit TPM with connectivity matrix: return. - if shapes is None: - return (cm, utils.np_hash(cm)) - - # ImplicitTPM with connectivity matrix: validate against node shapes. 
- validate.shapes(shapes, cm) - return (cm, utils.np_hash(cm)) @property @@ -204,25 +145,17 @@ def connectivity_matrix(self): @property def causally_significant_nodes(self): """See :func:`pyphi.connectivity.causally_significant_nodes`.""" - return connectivity.causally_significant_nodes(self.cm) + return connectivity.causally_significant_nodes(self._cm) @property def size(self): """int: The number of nodes in the network.""" return len(self) - @property - def state_space(self): - """tuple[tuple[Union[int, str]]]: Labels for the state space of each node. - """ - return self._state_space - @property def num_states(self): """int: The number of possible states of the network.""" - return np.prod( - [len(node_states) for node_states in self._state_space] - ) + return np.prod(self._tpm.shape) @property def node_indices(self): @@ -260,10 +193,9 @@ def __len__(self): return self._cm.shape[0] def __repr__(self): - # TODO implement a cleaner repr, similar to analyses objects, - # distinctions, etc. - return "Network(\n{},\ncm={},\nnode_labels={},\nstate_space={}\n)".format( - self.tpm, self.cm, self.node_labels, self.state_space._dict + cm = str(self.cm).replace('\n', '\n ') + return "Network(\n {},\n cm={},\n node_labels={}\n)".format( + self.tpm, cm, self.node_labels.labels ) def __eq__(self, other): @@ -280,7 +212,6 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) - # TODO(tpm): Immutability in xarray. 
def __hash__(self): return hash((hash(self.tpm), self._cm_hash)) @@ -291,7 +222,6 @@ def to_json(self): "cm": self.cm, "size": self.size, "node_labels": self.node_labels, - "state_space": self.state_space, } @classmethod diff --git a/pyphi/node.py b/pyphi/node.py index 1b44b60f2..6ab9f39df 100755 --- a/pyphi/node.py +++ b/pyphi/node.py @@ -1,9 +1,10 @@ -"""Represents a node in a network.""" # node.py +"""Represents a node in a network.""" + import functools -from typing import Iterable, Mapping, Optional, Tuple, Union +from typing import Any, Mapping, Optional, Tuple import numpy as np import xarray as xr @@ -12,13 +13,8 @@ # of importing all of pyphi.tpm and relying on late binding of pyphi.tpm. # to avoid the circular import error. import pyphi.tpm - from .connectivity import get_inputs_from_cm, get_outputs_from_cm -from .state_space import ( - dimension_labels, - build_state_space, - SINGLETON_COORDINATE, -) +from .labels import NodeLabels from .utils import state_of @@ -36,24 +32,21 @@ class Node: Attributes: index (int): The node's index in the network. label (str): The textual label for this node. - node_labels (Tuple[str]): The textual labels for the nodes in the network. cause_dataarray (xr.DataArray): the xarray DataArray for the cause TPM. effect_dataarray (xr.DataArray): the xarray DataArray for the effect TPM. cause_tpm (|ExplicitTPM|), - effect_tpm (|ExplicitTPM|): The node TPM is an array with |n + 1| dimensions, - where ``n`` is the size of the |Network|. The first ``n`` dimensions - correspond to each node in the system. Dimensions corresponding to - nodes that provide input to this node are of size > 1, while those - that do not correspond to inputs are of size 1. The last dimension - encodes the state of the node in the next timestep, so that - ``node.tpm[..., 0]`` gives probabilities that the node will be 'OFF' - and ``node.tpm[..., 1]`` gives probabilities that the node will be - 'ON'. 
+ effect_tpm (|ExplicitTPM|): The node TPM is an array with |i + 1| + dimensions, The first ``i`` dimensions correspond to the inputs to + the |Node|, and are of size > 1 (the possible states of the + input). The last dimension encodes the state of the node in the next + timestep, so that ``node.tpm[..., 0]`` gives probabilities that the + node will be 'OFF' and ``node.tpm[..., 1]`` gives probabilities that + the node will be 'ON'. inputs (frozenset): The set of nodes which send connections to this node. outputs (frozenset): The set of nodes this node sends connections to. - state_space (Tuple[Union[int, str]]): The space of states this node can inhabit. - state (Optional[Union[int, str]]): The current state of this node. + state (int): The current state of this node. + shape (Tuple[int]): The expanded shape of this node's TPM. """ def __init__( @@ -69,18 +62,14 @@ def __init__( self._inputs = effect_dataarray.attrs["inputs"] self._outputs = effect_dataarray.attrs["outputs"] - if cause_dataarray is None: - self._cause_dataarray = None - self._cause_tpm = None - else: - self._cause_dataarray = cause_dataarray - self._cause_tpm = cause_dataarray.data + self._cause_dataarray = cause_dataarray + self._cause_tpm = ( + self._cause_dataarray.data if cause_dataarray is not None else None + ) self._effect_dataarray = effect_dataarray self._effect_tpm = self._effect_dataarray.data - self.state_space = effect_dataarray.attrs["state_space"] - # (Optional) current state of this node. 
self.state = effect_dataarray.attrs["state"] @@ -92,7 +81,6 @@ def __init__( hash(pyphi.tpm.ExplicitTPM(self.effect_tpm)), self._inputs, self._outputs, - self.state_space, self.state ) ) @@ -137,99 +125,43 @@ def outputs(self): """frozenset: The set of nodes this node has connections to.""" return self._outputs - @property - def state_space(self): - """Tuple[Union[int, str]]: The space of states this node can inhabit.""" - return self._state_space - - @state_space.setter - def state_space(self, value): - _state_space = tuple(value) - - if len(set(_state_space)) < len(_state_space): - raise ValueError( - "Invalid node state space tuple. Repeated states are ambiguous." - ) - - if len(_state_space) < 2: - raise ValueError( - "Invalid node state space with less than 2 states." - ) - - self._state_space = _state_space - @property def state(self): - """Optional[Union[int, str]]: The current state of this node.""" + """Optional[int]: The current state of this node.""" return self._state @state.setter def state(self, value): - if value not in (*self.state_space, None): + state_space = self.effect_dataarray.coords["Pr"].data + if value not in (*state_space, None): raise ValueError( - f"Invalid node state. Possible states are {self.state_space}." + f"Invalid node state. Possible states are {state_space}." ) self._state = value - def project_index(self, index, preserve_singletons=False): - """Convert absolute TPM index to a valid index relative to this node.""" - - # Supported index coordinates (in the right dimension order) - # respective to this node, to be used like an AND mask, with - # `singleton_coordinate` acting like 0. + @property + def shape(self): + """Tuple[int]: The expanded shape of this node's TPM.""" + squeezed_shape = self.effect_tpm.shape + # A full shape prototype with as many dims as network nodes + 1. 
+ shape = np.ones(len(self._node_labels) + 1, dtype=int) + shape[[*self._inputs, -1]] = squeezed_shape + return tuple(shape) + + def project_index(self, index: Mapping[str, Any]) -> Mapping[str, Any]: + """Convert absolute |ImplicitTPM| index to a valid one relative to this + node. + + Args: + index (Any): The index as provided by ImplicitTPM.__getitem__(). + + Returns (Mapping): The dictionary-style index but devoid of dimensions + missing from this node. + """ dimensions = self._effect_dataarray.dims - coordinates = self._effect_dataarray.coords - - support = {dim: tuple(coordinates[dim].values) for dim in dimensions} - - if isinstance(index, dict): - singleton_coordinate = ( - [SINGLETON_COORDINATE] if preserve_singletons - else SINGLETON_COORDINATE - ) - - try: - # Convert potential int dimension indices to common currency of - # string dimension labels. - keys = [ - k if isinstance(k, str) else dimensions[k] - for k in index.keys() - ] - - projected_index = { - key: value if support[key] != (SINGLETON_COORDINATE,) - else singleton_coordinate - for key, value in zip(keys, index.values()) - } - - except KeyError as e: - raise ValueError( - "Dimension {} does not exist. Expected one or more of: " - "{}.".format(e, dimensions) - ) from e - - return projected_index - - # Assume regular index otherwise. - - if not isinstance(index, tuple): - # Index is a single int, slice, ellipsis, etc. Make it - # amenable to zip(). 
- index = (index,) - - index_support_map = zip(index, support.values()) - singleton_coordinate = [0] if preserve_singletons else 0 - projected_index = tuple( - i if support != (SINGLETON_COORDINATE,) - else singleton_coordinate - for i, support in index_support_map - ) - - return projected_index - - # def __getitem__(self, index): - # return self._dataarray[index].node + new_index = {dim: idx for dim, idx in index.items() if dim in dimensions} + return new_index def __repr__(self): return self.label @@ -240,12 +172,12 @@ def __str__(self): def __eq__(self, other): """Return whether this node equals the other object. - Two nodes are equal if they have the same index, the same - inputs and outputs, the same TPMs, the same state_space and the - same state. + Two nodes are equal if they have the same index, the same inputs and + outputs, the same TPMs and the same state. Labels are for display only, so two equal nodes may have different labels. + """ return ( self.index == other.index and @@ -253,7 +185,6 @@ def __eq__(self, other): self.effect_tpm.array_equal(other.tpm) and self.inputs == other.inputs and self.outputs == other.outputs and - self.state_space == other.state_space and self.state == other.state ) @@ -275,11 +206,11 @@ def to_json(self): def generate_node( effect_tpm: pyphi.tpm.ExplicitTPM, cm: np.ndarray, - network_state_space: Mapping[str, Tuple[Union[int, str]]], index: int, - node_labels: Iterable[str], + node_labels: NodeLabels, cause_tpm: Optional[pyphi.tpm.ExplicitTPM] = None, - state: Optional[Union[int, str]] = None, + state: Optional[int] = None, + uncut_cm: Optional[np.ndarray] = None, ) -> xr.DataArray: """ Instantiate a node TPM DataArray. @@ -287,64 +218,53 @@ def generate_node( Args: effect_tpm (ExplicitTPM): The effect TPM of this node. cm (np.ndarray): The CM of the network. - network_state_space (Mapping[str, Tuple[Union[int, str]]]): - Labels for the state space of each node in the network. index (int): The node's index in the network. 
- node_labels (Iterable[str]): Textual labels for each node in the network. + node_labels (NodeLabels): Textual labels for each node in the network. Keyword Args: - cause_tpm (ExplicitTPM): The cause TPM of this node. - state (Optional[Union[int, str]]): The state of this node. + cause_tpm (Optional[ExplicitTPM]): The cause TPM of this node. + state (Optional[int]): The state of this node. + uncut_cm (Optional[np.ndarray]): The original CM of the network. Returns: xr.DataArray: The node in question. """ # Get indices of the inputs and outputs. - inputs = frozenset(get_inputs_from_cm(index, cm)) - outputs = frozenset(get_outputs_from_cm(index, cm)) - - # Marginalize out non-input nodes. - effect_non_inputs = set(effect_tpm.tpm_indices()) - inputs - effect_tpm = effect_tpm.marginalize_out(effect_non_inputs) - - if cause_tpm is not None: - cause_non_inputs = set(cause_tpm.tpm_indices()) - inputs - cause_tpm = cause_tpm.marginalize_out(cause_non_inputs) - - # Dimensions are the names of this node's parents (whose state this node's - # TPM can be conditioned on), plus the last dimension with the probability - # for each possible state of this node in the next timestep. - dimensions = dimension_labels(node_labels) - - # Compute the relevant state labels (coordinates in xarray terminology) from - # the perspective of this node and its direct inputs. - node_states = [network_state_space[dim] for dim in dimensions[:-1]] - input_coordinates, _ = build_state_space( - node_labels, - effect_tpm.shape[:-1], - node_states, - singleton_state_space=(SINGLETON_COORDINATE,), - ) - - node_state_space = network_state_space[dimensions[index]] - - coordinates = {**input_coordinates, dimensions[-1]: node_state_space} + inputs = get_inputs_from_cm(index, cm) + outputs = get_outputs_from_cm(index, cm) + + if uncut_cm is not None: + # Marginalize out non-input nodes (required by cut Subsystems). 
+ original_inputs = get_inputs_from_cm(index, uncut_cm) + cut_inputs = set(original_inputs) - set(inputs) + cut_indices = np.where( + np.isin(original_inputs, list(cut_inputs)) + )[0] + effect_tpm = effect_tpm.marginalize_out(cut_indices.tolist()).squeeze() + cause_tpm = cause_tpm.marginalize_out(cut_indices.tolist()).squeeze() + + # Dimensions are the names of this node's inputs plus the last dimension + # with the probability for each state of this node in the next timestep. + dimensions = node_labels.indices2labels(inputs) + ("Pr",) + + # The possible states for each dimension. + coordinates = tuple(range(dim) for dim in effect_tpm.shape) + + attributes = { + "index": index, + "node_labels": node_labels, + "cm": cm, + "inputs": frozenset(inputs), + "outputs": frozenset(outputs), + "state": state, + } cause_dataarray = xr.DataArray( name=node_labels[index], data=cause_tpm, dims=dimensions, coords=coordinates, - attrs={ - "index": index, - "node_labels": node_labels, - "cm": cm, - "inputs": inputs, - "outputs": outputs, - "state_space": tuple(node_state_space), - "state": state, - "network_state_space": network_state_space - } + attrs=attributes, ) if cause_tpm is not None else None effect_dataarray = xr.DataArray( @@ -352,16 +272,7 @@ def generate_node( data=effect_tpm, dims=dimensions, coords=coordinates, - attrs={ - "index": index, - "node_labels": node_labels, - "cm": cm, - "inputs": inputs, - "outputs": outputs, - "state_space": tuple(node_state_space), - "state": state, - "network_state_space": network_state_space - } + attrs=attributes, ) return Node(effect_dataarray, cause_dataarray) @@ -370,24 +281,20 @@ def generate_node( def generate_nodes( network_tpm, cm: np.ndarray, - state_space: Mapping[str, Tuple[Union[int, str]]], indices: Tuple[int], - node_labels: Tuple[str], - network_state: Optional[Tuple[Union[int, str]]] = None, + node_labels: NodeLabels, + network_state: Optional[Tuple[int]] = None, ) -> Tuple[xr.DataArray]: - """Generate |Node| objects 
out of a binary network |TPM|. + """Generate |Node| objects out of a binary network |ExplicitTPM|. Args: network_tpm (|ExplicitTPM, ImplicitTPM|): The system's TPM. cm (np.ndarray): The CM of the network. - state_space (Mapping[str, Tuple[Union[int, str]]]): Labels - for the state space of each node in the network. indices (Tuple[int]): Indices to generate nodes for. - node_labels (Optional[Tuple[str]]): Textual labels for each node. + node_labels (NodeLabels): Textual labels for each node. Keyword Args: - network_state (Optional[Tuple[Union[int, str]]]): The state of - the network. + network_state (Optional[Tuple[int, str]]): The state of the network. Returns: Tuple[xr.DataArray]: The nodes of the system. @@ -422,17 +329,13 @@ def generate_nodes( np.stack([np.asarray(tpm_off), np.asarray(tpm_on)], axis=-1) ) - nodes.append( - generate_node( - node_tpm, - cm, - state_space, - index, - node_labels, - cause_tpm=None, - state=state, - ) - ) + # Marginalize out non-input nodes (network_tpm is |ExplicitTPM|). + inputs = get_inputs_from_cm(index, cm) + non_inputs = frozenset(node_tpm.tpm_indices()) - frozenset(inputs) + node_tpm = node_tpm.marginalize_out(non_inputs).squeeze() + + node = generate_node(node_tpm, cm, index, node_labels, state=state) + nodes.append(node) return tuple(nodes) diff --git a/pyphi/state_space.py b/pyphi/state_space.py index 88add190f..5676efc81 100644 --- a/pyphi/state_space.py +++ b/pyphi/state_space.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- # state_space.py """ @@ -31,6 +29,7 @@ def input_dimension_label(node_label: str) -> str: """ return INPUT_DIMENSION_PREFIX + str(node_label) + def dimension_labels(node_labels: Iterable[str]) -> List[str]: """Generate labels for each dimension in the |ImplicitTPM|. 
diff --git a/pyphi/subsystem.py b/pyphi/subsystem.py index 4f7f189d9..0a392543e 100755 --- a/pyphi/subsystem.py +++ b/pyphi/subsystem.py @@ -3,7 +3,7 @@ import functools import logging -from typing import Iterable, Tuple +from typing import Iterable, Tuple, Optional, Sequence, Union import numpy as np from numpy.typing import ArrayLike @@ -20,6 +20,7 @@ from .metrics.distribution import repertoire_distance as _repertoire_distance from .models import ( Concept, + Cut, MaximallyIrreducibleCause, MaximallyIrreducibleEffect, NullCut, @@ -27,6 +28,7 @@ _null_ria, CauseEffectStructure, ) +from .network import Network from .node import generate_node from .models.mechanism import ShortCircuitConditions, StateSpecification from .network import irreducible_purviews @@ -41,7 +43,7 @@ class Subsystem: Args: network (Network): The network the subsystem belongs to. - state (tuple[int]): The state of the network. + state (Sequence[int]): The state of the network. Keyword Args: nodes (tuple[int] or tuple[str]): The nodes of the network which are in @@ -65,16 +67,16 @@ class Subsystem: def __init__( self, - network, - state, - nodes=None, - cut=None, + network: Network, + state: Sequence[int], + nodes: Optional[Union[Sequence[int], Sequence[str]]] = None, + cut: Optional[Cut] = None, # TODO(4.0): refactor repertoire caches - repertoire_cache=None, - single_node_repertoire_cache=None, - forward_repertoire_cache=None, - unconstrained_forward_repertoire_cache=None, - _external_indices=None, + repertoire_cache: Optional[cache.DictCache] = None, + single_node_repertoire_cache: Optional[cache.DictCache] = None, + forward_repertoire_cache: Optional[cache.DictCache] = None, + unconstrained_forward_repertoire_cache: Optional[cache.DictCache] = None, + _external_indices: Optional[Sequence[int]] = None, ): # The network this subsystem belongs to. validate.is_network(network) @@ -103,8 +105,16 @@ def __init__( # Get the TPMs conditioned on the state of the external nodes. 
external_state = utils.state_of(self.external_indices, self.state) background_conditions = dict(zip(self.external_indices, external_state)) - self.cause_tpm = self.network.tpm.backward_tpm(state, self.node_indices) + subsystem_labels = self.node_labels.indices2labels(self.node_indices) + self.effect_tpm = self.network.tpm.condition_tpm(background_conditions) + self.cause_tpm = self.network.tpm.backward_tpm(state, subsystem_labels) + + # Set the state of the |Node|s. + nodes_zip = zip(self.effect_tpm.nodes, self.cause_tpm.nodes, self.state) + for effect_node, cause_node, node_state in nodes_zip: + effect_node.state = node_state + cause_node.state = node_state # The TPMs for just the nodes in the subsystem. self.proper_effect_tpm = self.effect_tpm.squeeze() @@ -132,29 +142,26 @@ def __init__( unconstrained_forward_repertoire_cache or cache.DictCache() ) - # Set the state of the |Node|s. - nodes_zip = zip(self.effect_tpm.nodes, self.cause_tpm.nodes, self.state) - for effect_node, cause_node, node_state in nodes_zip: - effect_node.state = node_state - cause_node.state = node_state - # Generate |Node|s for this subsystem and this particular cut to the cm. 
- nodes_enumerate = enumerate(zip(self.cause_tpm.nodes, self.effect_tpm.nodes)) + nodes_enumerate = enumerate( + zip(self.cause_tpm.nodes, self.effect_tpm.nodes) + ) + self.nodes = tuple( generate_node( node[Direction.EFFECT].effect_tpm, self.cm, - self.network.state_space, i, self.node_labels, cause_tpm=node[Direction.CAUSE].effect_tpm, state=node[Direction.EFFECT].state, + uncut_cm=network.cm if cut is not None else None, ) for i, node in nodes_enumerate if i in self.node_indices ) - validate.subsystem(self) + # validate.subsystem(self) @property def nodes(self): @@ -227,10 +234,6 @@ def tpm_size(self): raise ValueError("cause and effect TPM sizes should be the same") return self.effect_tpm.shape[-1] - @property - def state_space(self): - return self.network.state_space - def cache_info(self): """Report repertoire cache statistics.""" return { @@ -407,21 +410,23 @@ def _single_node_effect_repertoire( # pylint: disable=missing-docstring purview_node = self._index2node[purview_node_index] # Condition on the state of the purview inputs that are in the mechanism + inputs_condition = { + k: v for k, v in condition.items() + if k in purview_node.inputs + } if direction == Direction.CAUSE: - tpm = purview_node.cause_tpm.condition_tpm(condition) + tpm = purview_node.cause_tpm.condition_tpm(inputs_condition) elif direction == Direction.EFFECT: - tpm = purview_node.effect_tpm.condition_tpm(condition) + tpm = purview_node.effect_tpm.condition_tpm(inputs_condition) else: return validate.direction(direction) - # TODO(4.0) remove reference to TPM # Marginalize-out the inputs that aren't in the mechanism. nonmechanism_inputs = purview_node.inputs - set(condition) tpm = tpm.marginalize_out(nonmechanism_inputs) # Reshape so that the distribution is over next states. 
- return tpm.reshape( - repertoire_shape(self.network.node_indices, (purview_node_index,)) - ).tpm + shape = repertoire_shape(self.network.node_indices, (purview_node_index,)) + return np.asarray(tpm.reshape(shape)) @cache.method("_repertoire_cache", Direction.EFFECT) def _effect_repertoire( @@ -992,7 +997,7 @@ def intrinsic_information( mechanism: Tuple[int], purview: Tuple[int], repertoire_distance: str = None, - states: Iterable[Iterable[int]] = None, + states: Iterable[Sequence[int]] = None, ): repertoire_distance = fallback( repertoire_distance, config.REPERTOIRE_DISTANCE_INFORMATION diff --git a/pyphi/tpm.py b/pyphi/tpm.py index 01419ce9e..4eb28a0d2 100755 --- a/pyphi/tpm.py +++ b/pyphi/tpm.py @@ -4,18 +4,20 @@ import math import functools from itertools import chain -from typing import Iterable, Mapping, Optional, Set, Tuple +from typing import Any, Iterable, Mapping, Optional, Set, Sequence, Tuple import numpy as np -from . import convert, distribution, data_structures, exceptions +import pyphi.node +from . import convert, data_structures, exceptions from .connectivity import subadjacency from .conf import config from .constants import OFF, ON from .data_structures import FrozenMap -import pyphi.node +from .labels import NodeLabels from .utils import all_states, eq, np_hash, np_immutable + class TPM: """TPM interface for derived classes.""" @@ -82,12 +84,10 @@ def subtpm(self, fixed_nodes, state): if isinstance(self, ExplicitTPM): return conditioned_tpm[..., free_nodes] - return type(self)( - tuple( - node for node in conditioned_tpm.nodes - if node.index in free_nodes - ) + nodes = tuple( + node for node in conditioned_tpm.nodes if node.index in free_nodes ) + return type(self)(nodes) def infer_edge(self, a, b, contexts): """Infer the presence or absence of an edge from node A to node B. 
@@ -427,8 +427,9 @@ def condition_tpm(self, condition: Mapping[int, int]): """ # Assumes multidimensional form conditioning_indices = [[slice(None)]] * (self.ndim - 1) - for i, state_i in condition.items(): + for i, state_i in enumerate(condition.values()): # Ignore dimensions that are already singletons + # TODO (tpm): remove check, singleton dims are no longer expected. if self.shape[i] != 1: # Preserve singleton dimensions in output array with `np.newaxis` conditioning_indices[i] = [state_i, np.newaxis] @@ -451,9 +452,6 @@ def marginalize_out(self, node_indices): tpm = self.sum(tuple(node_indices), keepdims=True) / ( np.array(self.shape)[list(node_indices)].prod() ) - # Return new TPM object of the same type as self. Assume self had - # already been validated and converted formatted. Further validation - # would be problematic for singleton dimensions. return type(self)(tpm) def is_deterministic(self): @@ -508,10 +506,12 @@ def permute_nodes(self, permutation): ) def probability_of_current_state(self, current_state): - """Return the probability of the current state as a distribution over previous states. + """Return the probability of the current state as a distribution over + previous states. Arguments: current_state (tuple[int]): The current state. 
+ """ state_probabilities = np.empty(self.shape) if not len(current_state) == self.shape[-1]: @@ -529,7 +529,7 @@ def probability_of_current_state(self, current_state): def backward_tpm( self, current_state: tuple[int], - system_indices: Iterable[int], + system_indices: Sequence[int], remove_background: bool = False, ): """Compute the backward TPM for a given network state.""" @@ -619,7 +619,7 @@ def number_of_units(self): @property def ndim(self): """int: The number of dimensions of the TPM.""" - return len(self.shape) + return len(self.nodes) + 1 @property def shape(self): @@ -635,37 +635,48 @@ def _reconstituted_shape(self): @property def shapes(self): """Tuple[Tuple[int]]: The shapes of each node TPM in this TPM.""" - return [node.effect_tpm.shape for node in self._nodes] + return tuple(node.shape for node in self.nodes) + + @property + def node_labels(self): + """tuple[str]: The labels of nodes in the network.""" + return tuple(node.label for node in self._nodes) + + @property + def node_indices(self): + """tuple[int]: The labels of nodes in the network.""" + return tuple(node.index for node in self._nodes) @staticmethod def _node_shapes_to_shape( - shapes: Iterable[Iterable[int]], - reconstituted: Optional[bool] = None + shapes: Sequence[Sequence[int]], + reconstituted: bool = False, ) -> Tuple[int]: """Infer the shape of the equivalent multidimensional |ExplicitTPM|. Args: - shapes (Iterable[Iterable[int]]): The shapes of the individual node + shapes (Sequence[Sequence[int]]): The shapes of the individual node TPMs in the network, ordered by node index. + reconstituted (Optional[bool]): If True, the number of states per + node will be based on the actual data of the equivalent + reconstituted ExplicitTPM, as opposed to the number of states + reported by the node itself. Returns: Tuple[int]: The inferred shape of the equivalent TPM. 
""" - # This should recompute the network TPM shape from individual node - # shapes, as opposed to measuring the size of the state space. - if not all(len(shape) == len(shapes[0]) for shape in shapes): raise ValueError( "The provided shapes contain varying number of dimensions." ) - N = len(shapes) if reconstituted: states_per_node = tuple(max(dim) for dim in zip(*shapes))[:-1] else: states_per_node = tuple(shape[-1] for shape in shapes) # Check consistency of shapes across nodes. + N = len(shapes) dimensions_from_shapes = tuple( set(shape[node_index] for shape in shapes) @@ -717,7 +728,7 @@ def is_unitary(self): def _validate_shape(self): """Validate this TPM's shape. - The inferred shape of the implicit network TPM must be in + The inferred shape of the network ImplicitTPM must be in multidimensional state-by-node form, nonbinary and heterogeneous units supported. """ @@ -742,7 +753,7 @@ def to_multidimensional_state_by_node(self): """ return reconstitute_tpm(self) - # TODO(tpm) accept node labels and state labels in the map. + # TODO(tpm) accept node labels in the map. def condition_tpm(self, condition: Mapping[int, int]): """Return a TPM conditioned on the given fixed node indices, whose states are fixed according to the given state-tuple. @@ -761,14 +772,14 @@ def condition_tpm(self, condition: Mapping[int, int]): singleton dimensions for nodes in a fixed state. """ # Wrapping index elements in a list is the xarray equivalent - # of inserting a numpy.newaxis, which preserves the singleton even + # of inserting a numpy.newaxis, which preserves the dimension even # after selection of a single state. conditioning_indices = { i: (state_i if isinstance(state_i, list) else [state_i]) for i, state_i in condition.items() } - return self.__getitem__(conditioning_indices, preserve_singletons=True) + return self.__getitem__(conditioning_indices) def marginalize_out(self, node_indices): """Marginalize out nodes from this TPM. 
@@ -782,18 +793,16 @@ def marginalize_out(self, node_indices): """ # Leverage ExplicitTPM.marginalize_out() to distribute operation to # individual nodes, then assemble into a new ImplicitTPM. - return type(self)( - tuple( - pyphi.node.generate_node( - node.effect_tpm.marginalize_out(node_indices), - node.effect_dataarray.attrs["cm"], - node.effect_dataarray.attrs["network_state_space"], - node.index, - node.effect_dataarray.attrs["node_labels"], - ) - for node in self.nodes + new_nodes = tuple( + pyphi.node.generate_node( + node.effect_tpm.marginalize_out(node_indices), + node.effect_dataarray.attrs["cm"], + node.index, + node.effect_dataarray.attrs["node_labels"], ) + for node in self.nodes ) + return type(self)(new_nodes) def is_state_by_state(self): """Return ``True`` if ``tpm`` is in state-by-state form, otherwise @@ -826,6 +835,8 @@ def remove_singleton_dimensions(self): tuple(node for node in self.squeeze().nodes) ) + # TODO(tpm): Divide by 0: Return 0 probability if normalization factor is 0. + # State reachability (no cause) is handled by validate.subsytem. def probability_of_current_state( self, current_state: tuple[int] @@ -836,9 +847,13 @@ def probability_of_current_state( only contains the probability for the current state. Arguments: - current_state (tuple[int]): The current state. + current_state (tuple[int]): The current state. + Returns: - tuple[ExplicitTPM]: Node-marginal distributions of the current state. + tuple[ExplicitTPM]: Node-marginal distributions of the current state. + + Raises: + ValueError: If lengths of `current_state` and `self` don't match. """ if not len(current_state) == self.number_of_units: raise ValueError( @@ -847,66 +862,75 @@ def probability_of_current_state( ) nodes = [] for node in self.nodes: - i = node.index - state = current_state[i] + state = current_state[node.index] # DataArray indexing: keep last dimension by wrapping index in list. 
- pr_current_state = node.effect_dataarray[..., [state]].data - normalization = np.sum(pr_current_state) + pr_current_state = node.effect_dataarray[..., state] + normalization = np.sum(pr_current_state.data) if normalization == 0.0: raise exceptions.StateUnreachableError(current_state) nodes.append(pr_current_state / normalization) return tuple(nodes) + # TODO(tpm): Divide by 0: Return 0 probability if normalization factor is 0. + # State reachability (no cause) is handled by validate.subsytem. def backward_tpm( self, - current_state: tuple[int], - system_indices: Iterable[int], - ): - """Compute the backward TPM for a given network state.""" - all_indices = tuple(range(self.number_of_units)) - system_indices = tuple(sorted(system_indices)) - background_indices = tuple(sorted(set(all_indices) - set(system_indices))) - if not set(system_indices).issubset(set(all_indices)): + current_state: Sequence[int], + subsystem_labels: Iterable[str], + ) -> 'ImplicitTPM': + """ + Compute the cause TPM for a given network state and subsystem. + + Args: + current_state (Sequence[int]): The current state of the network. + subsystem_labels (Iterable[str]): Labels of the nodes in the subsystem. + + Returns: + ImplicitTPM: The cause TPM for the specified subsystem. + + Raises: + ValueError: If `subsystem_labels` is not a subset of network. + """ + if not set(subsystem_labels).issubset(set(self.node_labels)): raise ValueError( "system_indices must be a subset of `range(self.number_of_units))`" ) # p(u_t | s_{t–1}, w_{t–1}) pr_current_state_nodes = self.probability_of_current_state(current_state) - # TODO Avoid computing the full joint probability. Find uninformative - # dimensions after each product and propagate their dismissal. 
- pr_current_state = functools.reduce(np.multiply, pr_current_state_nodes) + pr_current_state = functools.reduce( + lambda x, y: x * y, pr_current_state_nodes + ) + # pr_current_state.data = np.asarray(pr_current_state.data) # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) + pr_current_state_given_only_background = pr_current_state.sum( - axis=tuple(system_indices), keepdims=True + dim=subsystem_labels ) # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) # ————————————————————————————————————— # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) pr_current_state_given_only_background_normalized = ( - pr_current_state_given_only_background / np.sum(pr_current_state) + pr_current_state_given_only_background / pr_current_state.sum() ) # Σ_{s_{t–1}} p(u_t | s_{t–1}, w_{t–1}) # Σ_{w_{t–1}} p(s_{i,t} | s_{t–1}, w_{t–1}) ————————————————————————————————————— # Σ_{u'_{t–1}} p(u_t | u'_{t–1}) - backward_tpm = tuple( - (node_tpm * pr_current_state_given_only_background_normalized).sum( - axis=background_indices, keepdims=True - ) - for node_tpm in self.tpm - ) - reference_node = self.nodes[0].effect_dataarray - return ImplicitTPM( - tuple( - pyphi.node.generate_node( - backward_node_tpm, - reference_node.attrs["cm"], - reference_node.attrs["network_state_space"], - i, - reference_node.attrs["node_labels"], - ) - for i, backward_node_tpm in enumerate(backward_tpm) + background_dimensions = set(self.node_labels) - set(subsystem_labels) + backward_nodes = [] + for node in self.nodes: + node_tpm = node.effect_dataarray + unwanted_dimensions = background_dimensions - set(node_tpm.dims) + node_tpm = node_tpm * pr_current_state_given_only_background_normalized + node_tpm = node_tpm.sum(dim=background_dimensions, keepdims=True) + node = pyphi.node.generate_node( + node_tpm.squeeze(dim=unwanted_dimensions), + node.effect_dataarray.attrs["cm"], + node.effect_dataarray.attrs["index"], + node.effect_dataarray.attrs["node_labels"], ) - ) + backward_nodes.append(node) + + return ImplicitTPM(backward_nodes) def equals(self, o: 
object): """Return whether this TPM equals the other object. @@ -919,72 +943,127 @@ def equals(self, o: object): def array_equal(self, o: object): return self.equals(o) - def squeeze(self, axis=None): - """Wrapper around numpy.squeeze.""" - # If axis is None, all axis should be considered. - if axis is None: - axis = set(range(len(self))) + def squeeze(self, dims: Optional[Iterable[str]] = None) -> 'ImplicitTPM': + """Remove axes of length one from the ImplicitTPM. + + Args: + dims (Optional[Iterable[str]]): Node labels to squeeze. If None, all + dimensions are considered. + + Returns: + ImplicitTPM: A new TPM with squeezed dimensions. + + Raise: + ValueError: If any specified dimension is invalid. + """ + # If dims is None, all dimensions should be considered. + if dims is None: + dims = set(self.node_labels) else: - axis = set(axis) if isinstance(axis, Iterable) else set([axis]) + dims = set(dims) + if not dims.issubset(self.node_labels): + invalid = dims - self.node_labels + raise ValueError(f"Invalid dimensions: {invalid}") - # Subtract non-singleton dimensions from `axis`, including fake - # singletons (dimensions that are singletons only for a proper subset of - # the nodes), since those should not be squeezed, not even within - # individual node TPMs. - shape = self._reconstituted_shape - nonsingletons = tuple(np.where(np.array(shape) > 1)[0]) - axis = tuple(axis - set(nonsingletons)) + # Reconstruct non-singleton dimensions for the whole ImplicitTPM. - # From now on, we will only care about the first n-1 dimensions (parents). - if shape[-1] > 1: - nonsingletons = nonsingletons[:-1] + shape = self._reconstituted_shape[:-1] # Parents are first n-1 dims. + nonsingletons = np.where(np.array(shape) > 1)[0] # Recompute connectivity matrix and subset of node labels. # TODO(tpm) deduplicate commonalities with macro.MacroSubsystem._squeeze. 
- some_node = self.nodes[0] - - new_cm = subadjacency(some_node.effect_dataarray.attrs["cm"], nonsingletons) + new_cm = subadjacency( + self.nodes[0].effect_dataarray.attrs["cm"], + nonsingletons, + ) + # Convert to textual labels. + nonsingletons = tuple( + node.label for node in self.nodes if node.index in nonsingletons + ) + new_node_labels = NodeLabels(nonsingletons, range(len(nonsingletons))) new_node_indices = iter(range(len(nonsingletons))) - new_node_labels = tuple(some_node._node_labels[n] for n in nonsingletons) - state_space = some_node.effect_dataarray.attrs["network_state_space"] - new_state_space = {n: state_space[n] for n in new_node_labels} + # Subtract non-singleton dimensions from `dims`. + dims_to_squeeze = dims - set(nonsingletons) + if not dims_to_squeeze: + return self # TODO(tpm): return copy? - # Leverage ExplicitTPM.squeeze to distribute squeezing to every node. - return type(self)( - tuple( - pyphi.node.generate_node( - node.effect_tpm.squeeze(axis=axis), + # Distribute squeezing to every node. + new_nodes = [] + for node in self.nodes: + if node.label in nonsingletons: + dims_map = dict(zip(dims_to_squeeze, dims_to_squeeze)) + node_dims_to_squeeze = set(node.project_index(dims_map).keys()) + new_tpm = node.effect_dataarray.squeeze( + dim=node_dims_to_squeeze, drop=True, + ) + new_node = pyphi.node.generate_node( + new_tpm.data, new_cm, - new_state_space, next(new_node_indices), new_node_labels, ) - for node in self.nodes if node.index in nonsingletons - ) - ) + new_nodes.append(new_node) - def __getitem__(self, index, **kwargs): - if isinstance(index, (int, slice, type(...), tuple)): - return type(self)( - tuple( - # The nodes in an ImplicitTPM only have "effect" - # node TPMs, even if ImplicitTPM is a cause TPM. - node.effect_dataarray[node.project_index(index, **kwargs)].node - for node in self.nodes - ) - ) + return type(self)(new_nodes) + + def __getitem__(self, index): + # Convert to common currency of named dimensions. 
if isinstance(index, dict): - return type(self)( - tuple( - # The nodes in an ImplicitTPM only have "effect" - # node TPMs, even if ImplicitTPM is a cause TPM. - node.effect_dataarray.loc[node.project_index(index, **kwargs)].node - for node in self.nodes - ) - ) - raise TypeError(f"Invalid index {index} of type {type(index)}.") + index2label = dict(zip(self.node_indices, self.node_labels)) + index = {index2label[k]: v for k, v in index.items()} + elif isinstance(index, (int, slice, type(...), tuple)): + index = self._index_to_dict(index) + else: + raise TypeError(f"Invalid index {index} of type {type(index)}.") + + new_node_tpms = [ + node.effect_dataarray.loc[node.project_index(index)].node + for node in self._nodes + ] + + return type(self)(new_node_tpms) + + def _index_to_dict(self, index: Any) -> Mapping[str, Any]: + """Convert a NumPy-style index into a dictionary which maps dimension + names to their corresponding indices. + + Args: + index (Any): The index as received by __getitem__. + + Returns (Mapping): A dictionary where keys are dimension names + and values are indices. + + """ + ndim = self.ndim + + # Convert single int, slice or Ellipsis to a tuple. + if not isinstance(index, tuple): + index = (index,) + + # Replace Ellipsis with the appropriate number of slice(None) objects. + ellipsis_count = index.count(Ellipsis) + + if ellipsis_count > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + if ellipsis_count == 1: + # Calculate how many dimensions are replaced by the Ellipsis. + n_missing = ndim - (len(index) - 1) + # Replace Ellipsis with slice(None) for each missing dimension. + index = [] + for idx in index: + if idx is Ellipsis: + index.extend([slice(None)] * n_missing) + else: + index.append(idx) + + if len(index) > ndim: + raise IndexError("too many indices for array.") + + # Create the dictionary. 
+ node_labels = [node.label for node in self.nodes] + return dict(zip(node_labels, index, strict=False)) def __len__(self): """int: The number of nodes in the TPM.""" @@ -1000,7 +1079,6 @@ def __hash__(self): return hash(tuple(hash(node) for node in self.nodes)) - def reconstitute_tpm(subsystem): """Reconstitute the ExplicitTPM of a subsystem using individual node TPMs.""" # The last axis of the node TPMs correponds to ON or OFF probabilities diff --git a/pyphi/validate.py b/pyphi/validate.py index 2d86d0c3d..132d47af4 100644 --- a/pyphi/validate.py +++ b/pyphi/validate.py @@ -148,8 +148,8 @@ def state(state, size, shape): def _past_states(p_node): """Find set of states which could have led to the current state of a node. - The state of irrelevant dimensions, nodes which don't output to this - node, is represented with -1 to encode a whole equivalence class. + The state of irrelevant dimensions (nodes which don't output to this + node) is represented with -1 to encode a whole equivalence class. Arguments: p_node (np.ndarray): Node TPM conditioned on the current subsystem state. @@ -225,7 +225,9 @@ def state_reachable(subsystem): # ∃ s_{t-1} : p(s | s_{t-1}, w_{t-1}) > 0 # Obtain p(s | w_{t-1}) as node marginals (i.e. implicitly). - p = subsystem.proper_effect_tpm.probability_of_current_state(subsystem.proper_state) + p = subsystem.proper_effect_tpm.probability_of_current_state( + subsystem.proper_state + ) # Avoid computing the joint distribution. For each node n, find the set of # coordinates s_{t-1} for which p_n > 0. The intersection of all such sets