From 9160d6d65b8f1ee83193cc92f27bfaf0b941bd01 Mon Sep 17 00:00:00 2001 From: Ninjani <48680156+Ninjani@users.noreply.github.com> Date: Wed, 3 Feb 2021 18:41:29 +0100 Subject: [PATCH] feat(moments): add density parameter to moment invariant calculation --- geometricus/geometricus.py | 35 +++- geometricus/moment_utility.py | 312 +++++++++++++++------------------- 2 files changed, 164 insertions(+), 183 deletions(-) diff --git a/geometricus/geometricus.py b/geometricus/geometricus.py index 457ed4c..217d28e 100644 --- a/geometricus/geometricus.py +++ b/geometricus/geometricus.py @@ -421,6 +421,8 @@ class MomentInvariants(Structure): """Filled with moment invariant values for each structural fragment""" moment_names: List[str] = None """Names of moments used""" + density: np.ndarray = None + """Density of each residue""" @classmethod def from_coordinates( @@ -431,7 +433,8 @@ def from_coordinates( split_type: SplitType = SplitType.KMER, split_size: int = 16, upsample_rate: int = 50, - moment_names: List[str] = ("O_3", "O_4", "O_5", "F") + moment_names: List[str] = ("O_3", "O_4", "O_5", "F"), + density: np.ndarray = None ): """ Construct MomentInvariants instance from a set of coordinates. @@ -448,8 +451,11 @@ def from_coordinates( split_type=split_type, split_size=split_size, upsample_rate=upsample_rate, - moment_names=moment_names + moment_names=moment_names, + density=density ) + if shape.density is None: + shape.density = np.ones(shape.coordinates.shape[0]) shape._split(split_type) return shape @@ -462,7 +468,8 @@ def from_prody_atomgroup( split_size: int = 16, selection: str = "calpha", upsample_rate: int = 50, - moment_names: List[str] = ("O_3", "O_4", "O_5", "F") + moment_names: List[str] = ("O_3", "O_4", "O_5", "F"), + density: np.ndarray = None ): """ Construct MomentInvariants instance from a ProDy AtomGroup object. @@ -485,8 +492,11 @@ def from_prody_atomgroup( split_type=split_type, split_size=split_size, upsample_rate=upsample_rate, - moment_names=moment_names + moment_names=moment_names, + density=density ) + if shape.density is None: + shape.density = np.ones(shape.coordinates.shape[0]) shape._split(split_type) return shape @@ -516,7 +526,8 @@ def from_pdb_file( split_size: int = 16, selection: str = "calpha", upsample_rate: int = 50, - moment_names: List[str] = ("O_3", "O_4", "O_5", "F") + moment_names: List[str] = ("O_3", "O_4", "O_5", "F"), + density: np.ndarray = None ): """ Construct MomentInvariants instance from a PDB file and optional chain. @@ -538,7 +549,8 @@ def from_pdb_file( split_size, selection=selection, upsample_rate=upsample_rate, - moment_names=moment_names + moment_names=moment_names, + density=density ) @classmethod @@ -550,7 +562,8 @@ def from_pdb_id( split_size: int = 16, selection: str = "calpha", upsample_rate: int = 50, - moment_names: List[str] = ("O_3", "O_4", "O_5", "F") + moment_names: List[str] = ("O_3", "O_4", "O_5", "F"), + density: np.ndarray = None ): """ Construct MomentInvariants instance from a PDB ID and optional chain (downloads the PDB file from RCSB). @@ -572,7 +585,8 @@ def from_pdb_id( split_size, selection=selection, upsample_rate=upsample_rate, - moment_names=moment_names + moment_names=moment_names, + density=density ) def _kmerize(self): @@ -618,7 +632,9 @@ def _allmerize(self): def _get_moments(self, split_indices): moments = np.zeros((len(split_indices), len(self.moment_names))) for i, indices in enumerate(split_indices): - moments[i] = get_moments_from_coordinates(self.coordinates[indices], self.moment_names) + moments[i] = get_moments_from_coordinates(self.coordinates[indices], + self.moment_names, + self.density[indices]) return split_indices, moments def _split_radius(self): @@ -655,6 +671,7 @@ def _split_radius_upsample(self): split_indices.append(kd_tree.getIndices()) moments = np.zeros((len(split_indices), len(self.moment_names))) for i, indices in enumerate(split_indices_upsample): + # TODO: add density here moments[i] = get_moments_from_coordinates(coordinates_upsample[indices], self.moment_names) return split_indices, moments diff --git a/geometricus/moment_utility.py b/geometricus/moment_utility.py index 098fe08..0c352cd 100644 --- a/geometricus/moment_utility.py +++ b/geometricus/moment_utility.py @@ -12,7 +12,9 @@ class MomentInfo: def get_moments_from_coordinates( - coordinates: np.ndarray, moment_names: ty.List[str] = ("O_3", "O_4", "O_5", "F") + coordinates: np.ndarray, + moment_names: ty.List[str] = ("O_3", "O_4", "O_5", "F"), + density: np.ndarray = None, ) -> ty.List[float]: """ Gets rotation-invariant moments for a set of coordinates @@ -23,18 +25,23 @@ def get_moments_from_coordinates( moment_names Which moments to calculate Choose from ['O_3', 'O_4', 'O_5', 'F', 'phi_2', 'phi_3', 'phi_4', 'phi_5', 'phi_6', 'phi_7', 'phi_8', 'phi_9', 'phi_10', 'phi_11', 'phi_12', 'phi_13'] - + density + assign a density to each residue/coordinate. 1 by default Returns ------- list of moments """ + if density is None: + density = np.ones(coordinates.shape[0]) + else: + assert density.shape[0] == coordinates.shape[0] moment_types: ty.List[MomentType] = [MomentType[m] for m in moment_names] all_moment_mu_types: ty.Set[ty.Tuple[int, int, int]] = set( m for moment_type in moment_types for m in moment_type.value.mu_arguments ) centroid = nb_mean_axis_0(coordinates) mus = { - (x, y, z): mu(float(x), float(y), float(z), coordinates, centroid) + (x, y, z): mu(float(x), float(y), float(z), coordinates, density, centroid) for (x, y, z) in all_moment_mu_types } moments = [ @@ -58,7 +65,7 @@ def nb_mean_axis_0(array: np.ndarray) -> np.ndarray: @nb.njit(cache=False) -def mu(p, q, r, coords, centroid): +def mu(p, q, r, coords, density, centroid): """ Central moment """ @@ -66,6 +73,7 @@ def mu(p, q, r, coords, centroid): ((coords[:, 0] - centroid[0]) ** p) * ((coords[:, 1] - centroid[1]) ** q) * ((coords[:, 2] - centroid[2]) ** r) + * density ) @@ -99,17 +107,8 @@ def O_5(mu_200, mu_020, mu_002, mu_110, mu_101, mu_011): @nb.njit def F( - mu_201, - mu_021, - mu_210, - mu_300, - mu_111, - mu_012, - mu_003, - mu_030, - mu_102, - mu_120, - ): + mu_201, mu_021, mu_210, mu_300, mu_111, mu_012, mu_003, mu_030, mu_102, mu_120, +): return ( mu_003 ** 2 + 6 * mu_012 ** 2 @@ -199,17 +198,8 @@ def phi_3(mu_020, mu_011, mu_110, mu_200, mu_002, mu_101): @nb.njit def phi_4( - mu_030, - mu_021, - mu_120, - mu_003, - mu_111, - mu_201, - mu_102, - mu_210, - mu_012, - mu_300, - ): + mu_030, mu_021, mu_120, mu_003, mu_111, mu_201, mu_102, mu_210, mu_012, mu_300, +): return ( mu_300 ** 2 + mu_030 ** 2 @@ -250,17 +240,8 @@ def phi_5(mu_030, mu_021, mu_120, mu_003, mu_201, mu_102, mu_210, mu_012, mu_300 @nb.njit def phi_6( - mu_030, - mu_021, - mu_120, - mu_003, - mu_111, - mu_201, - mu_102, - mu_210, - mu_012, - mu_300, - ): + mu_030, mu_021, mu_120, mu_003, mu_111, mu_201, mu_102, mu_210, mu_012, mu_300, +): return ( 1 * mu_300 ** 4 + 6 * mu_300 ** 2 * mu_210 ** 2 @@ -349,17 +330,8 @@ def phi_6( @nb.njit def phi_7( - mu_030, - mu_021, - mu_120, - mu_003, - mu_111, - mu_201, - mu_102, - mu_210, - mu_012, - mu_300, - ): + mu_030, mu_021, mu_120, mu_003, mu_111, mu_201, mu_102, mu_210, mu_012, mu_300, +): return ( 1 * mu_300 ** 4 + 1 * mu_300 ** 3 * mu_120 @@ -533,17 +505,8 @@ def phi_7( @nb.njit def phi_8( - mu_030, - mu_021, - mu_120, - mu_003, - mu_111, - mu_201, - mu_102, - mu_210, - mu_012, - mu_300, - ): + mu_030, mu_021, mu_120, mu_003, mu_111, mu_201, mu_102, mu_210, mu_012, mu_300, +): return ( 1 * mu_300 ** 4 + 2 * mu_300 ** 3 * mu_120 @@ -726,23 +689,23 @@ def phi_8( @nb.njit def phi_9( - mu_030, - mu_021, - mu_120, - mu_101, - mu_003, - mu_200, - mu_110, - mu_201, - mu_111, - mu_102, - mu_210, - mu_020, - mu_012, - mu_002, - mu_011, - mu_300, - ): + mu_030, + mu_021, + mu_120, + mu_101, + mu_003, + mu_200, + mu_110, + mu_201, + mu_111, + mu_102, + mu_210, + mu_020, + mu_012, + mu_002, + mu_011, + mu_300, +): return ( 1 * mu_200 * mu_300 ** 2 + 2 * mu_110 * mu_300 * mu_210 @@ -785,23 +748,23 @@ def phi_9( @nb.njit def phi_10( - mu_030, - mu_021, - mu_120, - mu_101, - mu_003, - mu_200, - mu_110, - mu_201, - mu_111, - mu_102, - mu_210, - mu_020, - mu_012, - mu_002, - mu_011, - mu_300, - ): + mu_030, + mu_021, + mu_120, + mu_101, + mu_003, + mu_200, + mu_110, + mu_201, + mu_111, + mu_102, + mu_210, + mu_020, + mu_012, + mu_002, + mu_011, + mu_300, +): return ( 1 * mu_200 * mu_300 ** 2 + 1 * mu_200 * mu_300 * mu_120 @@ -859,22 +822,22 @@ def phi_10( @nb.njit def phi_11( - mu_030, - mu_021, - mu_120, - mu_101, - mu_003, - mu_200, - mu_110, - mu_201, - mu_102, - mu_210, - mu_012, - mu_020, - mu_002, - mu_011, - mu_300, - ): + mu_030, + mu_021, + mu_120, + mu_101, + mu_003, + mu_200, + mu_110, + mu_201, + mu_102, + mu_210, + mu_012, + mu_020, + mu_002, + mu_011, + mu_300, +): return ( 1 * mu_200 * mu_300 ** 2 + 2 * mu_200 * mu_300 * mu_120 @@ -926,23 +889,23 @@ def phi_11( @nb.njit def phi_12( - mu_030, - mu_021, - mu_120, - mu_101, - mu_003, - mu_200, - mu_110, - mu_201, - mu_111, - mu_102, - mu_210, - mu_020, - mu_012, - mu_002, - mu_011, - mu_300, - ): + mu_030, + mu_021, + mu_120, + mu_101, + mu_003, + mu_200, + mu_110, + mu_201, + mu_111, + mu_102, + mu_210, + mu_020, + mu_012, + mu_002, + mu_011, + mu_300, +): return ( 1 * mu_200 ** 2 * mu_300 ** 2 + 4 * mu_200 * mu_110 * mu_300 * mu_210 @@ -1012,23 +975,23 @@ def phi_12( @nb.njit def phi_13( - mu_030, - mu_021, - mu_120, - mu_101, - mu_003, - mu_200, - mu_110, - mu_201, - mu_111, - mu_102, - mu_210, - mu_012, - mu_020, - mu_002, - mu_011, - mu_300, - ): + mu_030, + mu_021, + mu_120, + mu_101, + mu_003, + mu_200, + mu_110, + mu_201, + mu_111, + mu_102, + mu_210, + mu_012, + mu_020, + mu_002, + mu_011, + mu_300, +): return ( 1 * mu_200 ** 2 * mu_300 ** 2 + 2 * mu_200 * mu_110 * mu_300 * mu_210 @@ -1137,6 +1100,7 @@ class MomentType(Enum): Different rotation invariant moments (order 2 and order 3) Choose from ['O_3', 'O_4', 'O_5', 'F', 'phi_2', 'phi_3', 'phi_4', 'phi_5', 'phi_6', 'phi_7', 'phi_8', 'phi_9', 'phi_10', 'phi_11', 'phi_12', 'phi_13'] """ + O_3 = MomentInfo(O_3, [(2, 0, 0), (0, 2, 0), (0, 0, 2)]) O_4 = MomentInfo( O_4, [(2, 0, 0), (0, 2, 0), (0, 0, 2), (1, 1, 0), (1, 0, 1), (0, 1, 1),] @@ -1348,10 +1312,10 @@ def get_moments_from_coordinates(self, mus: ty.List[float]): return self.value.moment_function(*mus) -def alpha(index, coords, centroid): - mu_200 = mu(2.0, 0.0, 0.0, coords, centroid) - mu_020 = mu(0.0, 2.0, 0.0, coords, centroid) - mu_002 = mu(0.0, 0.0, 2.0, coords, centroid) +def alpha(index, coords, density, centroid): + mu_200 = mu(2.0, 0.0, 0.0, coords, density, centroid) + mu_020 = mu(0.0, 2.0, 0.0, coords, density, centroid) + mu_002 = mu(0.0, 0.0, 2.0, coords, density, centroid) if index == 1: return mu_002 - mu_020 @@ -1361,17 +1325,17 @@ def alpha(index, coords, centroid): return mu_200 - mu_002 -def beta(index, coords, centroid): - mu_003 = mu(0.0, 0.0, 3.0, coords, centroid) - mu_012 = mu(0.0, 1.0, 2.0, coords, centroid) - mu_021 = mu(0.0, 2.0, 1.0, coords, centroid) - mu_030 = mu(0.0, 3.0, 0.0, coords, centroid) - mu_102 = mu(1.0, 0.0, 2.0, coords, centroid) - mu_111 = mu(1.0, 1.0, 1.0, coords, centroid) - mu_210 = mu(2.0, 1.0, 0.0, coords, centroid) - mu_201 = mu(2.0, 0.0, 1.0, coords, centroid) - mu_120 = mu(1.0, 2.0, 0.0, coords, centroid) - mu_300 = mu(3.0, 0.0, 0.0, coords, centroid) +def beta(index, coords, density, centroid): + mu_003 = mu(0.0, 0.0, 3.0, coords, density, centroid) + mu_012 = mu(0.0, 1.0, 2.0, coords, density, centroid) + mu_021 = mu(0.0, 2.0, 1.0, coords, density, centroid) + mu_030 = mu(0.0, 3.0, 0.0, coords, density, centroid) + mu_102 = mu(1.0, 0.0, 2.0, coords, density, centroid) + mu_111 = mu(1.0, 1.0, 1.0, coords, density, centroid) + mu_210 = mu(2.0, 1.0, 0.0, coords, density, centroid) + mu_201 = mu(2.0, 0.0, 1.0, coords, density, centroid) + mu_120 = mu(1.0, 2.0, 0.0, coords, density, centroid) + mu_300 = mu(3.0, 0.0, 0.0, coords, density, centroid) if index == 1: return mu_021 - mu_201 @@ -1413,26 +1377,26 @@ def beta(index, coords, centroid): raise IndexError -def gamma(index, coords, centroid): - mu_022 = mu(0, 2, 2, coords, centroid) - mu_202 = mu(2, 0, 2, coords, centroid) - mu_220 = mu(2, 2, 0, coords, centroid) +def gamma(index, coords, density, centroid): + mu_022 = mu(0, 2, 2, coords, density, centroid) + mu_202 = mu(2, 0, 2, coords, density, centroid) + mu_220 = mu(2, 2, 0, coords, density, centroid) - mu_400 = mu(4, 0, 0, coords, centroid) - mu_040 = mu(0, 4, 0, coords, centroid) - mu_004 = mu(0, 0, 4, coords, centroid) + mu_400 = mu(4, 0, 0, coords, density, centroid) + mu_040 = mu(0, 4, 0, coords, density, centroid) + mu_004 = mu(0, 0, 4, coords, density, centroid) - mu_112 = mu(1, 1, 2, coords, centroid) - mu_121 = mu(1, 2, 1, coords, centroid) - mu_211 = mu(2, 1, 1, coords, centroid) + mu_112 = mu(1, 1, 2, coords, density, centroid) + mu_121 = mu(1, 2, 1, coords, density, centroid) + mu_211 = mu(2, 1, 1, coords, density, centroid) - mu_130 = mu(1, 3, 0, coords, centroid) - mu_103 = mu(1, 0, 3, coords, centroid) - mu_013 = mu(0, 1, 3, coords, centroid) + mu_130 = mu(1, 3, 0, coords, density, centroid) + mu_103 = mu(1, 0, 3, coords, density, centroid) + mu_013 = mu(0, 1, 3, coords, density, centroid) - mu_310 = mu(3, 1, 0, coords, centroid) - mu_301 = mu(3, 0, 1, coords, centroid) - mu_031 = mu(0, 3, 1, coords, centroid) + mu_310 = mu(3, 1, 0, coords, density, centroid) + mu_301 = mu(3, 0, 1, coords, density, centroid) + mu_031 = mu(0, 3, 1, coords, density, centroid) if index == 1: return mu_022 - mu_400