diff --git a/.codeclimate.yml b/.codeclimate.yml new file mode 100644 index 00000000..eac99d3b --- /dev/null +++ b/.codeclimate.yml @@ -0,0 +1,36 @@ + +engines: + # ... CONFIG CONTENT ... + pep8: + enabled: true + # ... CONFIG CONTENT ... + checks: + E731: # do not assign a lambda expression, use a def + enabled: false + E402: # module level import not at top of file + enabled: false + E501: # maximum line length + enabled: false + +checks: + argument-count: + enabled: false + complex-logic: + enabled: false + file-lines: + enabled: false + method-complexity: + enabled: false + method-count: + enabled: false + method-lines: + enabled: false + nested-control-flow: + enabled: false + return-statements: + enabled: false + similar-code: + enabled: false + identical-code: + enabled: true + diff --git a/.pep8speaks.yml b/.pep8speaks.yml new file mode 100644 index 00000000..7cc18aec --- /dev/null +++ b/.pep8speaks.yml @@ -0,0 +1,10 @@ + +scanner: + diff_only: True # If False, the entire file touched by the Pull Request is scanned for errors. If True, only the diff is scanned. + linter: pycodestyle # Other option is flake8 + +pycodestyle: # Same as scanner.linter value. Other option is flake8 + max-line-length: 121 # Default is 79 in PEP 8 + ignore: # Errors and warnings to ignore + - E731 # do not assign a lambda expression, use a def + - E402 # module level import not at top of file diff --git a/.travis.yml b/.travis.yml index 20a622fc..9c692249 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,7 +35,7 @@ install: # Command to run tests, e.g. python setup.py test script: # run unit tests - - py.test + - py.test --runslow # run integration tests - lettuce --no-cuda convergence - lettuce --no-cuda benchmark diff --git a/lettuce/__init__.py b/lettuce/__init__.py index 3d24d57b..b1e9b557 100644 --- a/lettuce/__init__.py +++ b/lettuce/__init__.py @@ -10,6 +10,7 @@ __version__ = get_versions()['version'] del get_versions + from lettuce.util import * from lettuce.unit import * from lettuce.lattices import * @@ -25,5 +26,7 @@ from lettuce.simulation import * from lettuce.force import * from lettuce.observables import * +from lettuce.symmetry import * +from lettuce.neural import * from lettuce.flows import * diff --git a/lettuce/boundary.py b/lettuce/boundary.py index 260451a2..20c550c7 100644 --- a/lettuce/boundary.py +++ b/lettuce/boundary.py @@ -117,9 +117,9 @@ def __call__(self, f): u = self.lattice.u(f) u_w = u[[slice(None)] + self.index] + 0.5 * (u[[slice(None)] + self.index] - u[[slice(None)] + self.neighbor]) f[[np.array(self.lattice.stencil.opposite)[self.velocities]] + self.index] = ( - - f[[self.velocities] + self.index] + self.w * self.lattice.rho(f)[[slice(None)] + self.index] * - (2 + torch.einsum(self.dims, self.lattice.e[self.velocities], u_w) ** 2 / self.lattice.cs ** 4 - - (torch.norm(u_w, dim=0) / self.lattice.cs) ** 2) + - f[[self.velocities] + self.index] + self.w * self.lattice.rho(f)[[slice(None)] + self.index] * + (2 + torch.einsum(self.dims, self.lattice.e[self.velocities], u_w) ** 2 / self.lattice.cs ** 4 + - (torch.norm(u_w, dim=0) / self.lattice.cs) ** 2) ) return f diff --git a/lettuce/collision.py b/lettuce/collision.py index 79dc6aec..3c693b60 100644 --- a/lettuce/collision.py +++ b/lettuce/collision.py @@ -3,17 +3,26 @@ """ import torch +import numpy as np from lettuce.equilibrium import QuadraticEquilibrium -from lettuce.util import LettuceException +from lettuce.moments import DEFAULT_TRANSFORM +from lettuce.util import LettuceCollisionNotDefined +from lettuce.stencils import D2Q9, D3Q27 __all__ = [ - "BGKCollision", "KBCCollision2D", "KBCCollision3D", "MRTCollision", "RegularizedCollision", - "SmagorinskyCollision", "TRTCollision", "BGKInitialization" + "Collision", + "BGKCollision", "KBCCollision2D", "KBCCollision3D", "MRTCollision", + "RegularizedCollision", "SmagorinskyCollision", "TRTCollision", "BGKInitialization", ] -class BGKCollision: +class Collision: + def __call__(self, f): + return NotImplemented + + +class BGKCollision(Collision): def __init__(self, lattice, tau, force=None): self.force = force self.lattice = lattice @@ -28,17 +37,27 @@ def __call__(self, f): return f - 1.0 / self.tau * (f - feq) + Si -class MRTCollision: +class MRTCollision(Collision): """Multiple relaxation time collision operator This is an MRT operator in the most general sense of the word. The transform does not have to be linear and can, e.g., be any moment or cumulant transform. """ - def __init__(self, lattice, transform, relaxation_parameters): + def __init__(self, lattice, relaxation_parameters, transform=None): self.lattice = lattice - self.transform = transform - self.relaxation_parameters = lattice.convert_to_tensor(relaxation_parameters) + if transform is None: + try: + self.transform = DEFAULT_TRANSFORM[lattice.stencil](lattice) + except KeyError: + raise LettuceCollisionNotDefined("No entry for stencil {lattice.stencil} in moments.DEFAULT_TRANSFORM") + else: + self.transform = transform + if isinstance(relaxation_parameters, float): + tau = relaxation_parameters + self.relaxation_parameters = lattice.convert_to_tensor(tau * np.ones(lattice.stencil.Q())) + else: + self.relaxation_parameters = lattice.convert_to_tensor(relaxation_parameters) def __call__(self, f): m = self.transform.transform(f) @@ -48,7 +67,7 @@ def __call__(self, f): return f -class TRTCollision: +class TRTCollision(Collision): """Two relaxation time collision model - standard implementation (cf. Krüger 2017) """ @@ -62,14 +81,14 @@ def __call__(self, f): u = self.lattice.u(f) feq = self.lattice.equilibrium(rho, u) f_diff_neq = ((f + f[self.lattice.stencil.opposite]) - (feq + feq[self.lattice.stencil.opposite])) / ( - 2.0 * self.tau_plus) + 2.0 * self.tau_plus) f_diff_neq += ((f - f[self.lattice.stencil.opposite]) - (feq - feq[self.lattice.stencil.opposite])) / ( - 2.0 * self.tau_minus) + 2.0 * self.tau_minus) f = f - f_diff_neq return f -class RegularizedCollision: +class RegularizedCollision(Collision): """Regularized LBM according to Jonas Latt and Bastien Chopard (2006)""" def __init__(self, lattice, tau): @@ -100,12 +119,13 @@ def __call__(self, f): return f -class KBCCollision2D: +class KBCCollision2D(Collision): """Entropic multi-relaxation time model according to Karlin et al. in two dimensions""" def __init__(self, lattice, tau): self.lattice = lattice - assert lattice.Q == 9, LettuceException("KBC2D only realized for D2Q9") + if not lattice.stencil == D2Q9: + raise LettuceCollisionNotDefined("This implementation only works for the D2Q9 stencil.") self.tau = tau self.beta = 1. / (2 * tau) @@ -178,12 +198,13 @@ def __call__(self, f): return f -class KBCCollision3D: +class KBCCollision3D(Collision): """Entropic multi-relaxation time-relaxation time model according to Karlin et al. in three dimensions""" def __init__(self, lattice, tau): self.lattice = lattice - assert lattice.Q == 27, LettuceException("KBC only realized for D3Q27") + if not lattice.stencil == D3Q27: + raise LettuceCollisionNotDefined("This implementation only works for the D3Q27 stencil.") self.tau = tau self.beta = 1. / (2 * tau) @@ -269,7 +290,7 @@ def __call__(self, f): return f -class SmagorinskyCollision: +class SmagorinskyCollision(Collision): """Smagorinsky large eddy simulation (LES) collision model with BGK operator.""" def __init__(self, lattice, tau, smagorinsky_constant=0.17, force=None): @@ -324,3 +345,4 @@ def __call__(self, f): mnew[self.momentum_indices] = rho * self.u f = self.moment_transformation.inverse_transform(mnew) return f + diff --git a/lettuce/lattices.py b/lettuce/lattices.py index 26376c54..a46e84b1 100644 --- a/lettuce/lattices.py +++ b/lettuce/lattices.py @@ -8,6 +8,7 @@ Its stencil is still accessible trough Lattice.stencil. """ +import copy import warnings import numpy as np import torch @@ -102,15 +103,16 @@ def mv(self, m, v): def einsum(self, equation, fields, **kwargs): """Einstein summation on local fields.""" input, output = equation.split("->") + out = copy.copy(output) inputs = input.split(",") + xyz = "xyz"[:self.D] for i, inp in enumerate(inputs): - if len(inp) == len(fields[i].shape): + if "..." in inp: + raise LettuceException("... not allowed in lattice.einsum") + elif len(inp) == len(fields[i].shape): pass - elif len(inp) == len(fields[i].shape) - self.D: - inputs[i] += "..." - if not output.endswith("..."): - output += "..." else: - raise LettuceException("Bad dimension.") - equation = ",".join(inputs) + "->" + output + inputs[i] = f"...{inp}{xyz}" + out = f"...{output}{xyz}" + equation = ",".join(inputs) + "->" + out return torch.einsum(equation, fields, **kwargs) diff --git a/lettuce/moments.py b/lettuce/moments.py index 7ea95470..8e1c3239 100644 --- a/lettuce/moments.py +++ b/lettuce/moments.py @@ -11,7 +11,7 @@ __all__ = [ "moment_tensor", "get_default_moment_transform", "Moments", "Transform", "D1Q3Transform", - "D2Q9Lallemand", "D2Q9Dellar", "D3Q27Hermite" + "D2Q9Lallemand", "D2Q9Dellar", "D3Q27Hermite", "DEFAULT_TRANSFORM" ] _ALL_STENCILS = get_subclasses(Stencil, module=lettuce) @@ -24,15 +24,6 @@ def moment_tensor(e, multiindex): return np.prod(np.power(e, multiindex[..., None, :]), axis=-1) -def get_default_moment_transform(lattice): - if lattice.stencil == D1Q3: - return D1Q3Transform(lattice) - if lattice.stencil == D2Q9: - return D2Q9Lallemand(lattice) - else: - raise LettuceException(f"No default moment transform for lattice {lattice}.") - - class Moments: def __init__(self, lattice): self.rho = moment_tensor(lattice.e, lattice.convert_to_tensor(np.zeros(lattice.D))) @@ -54,10 +45,10 @@ def __getitem__(self, moment_names): return [self.names.index(name) for name in moment_names] def transform(self, f): - return f + return self.lattice.einsum("ij,j->i", (self.matrix, f)) def inverse_transform(self, m): - return m + return self.lattice.einsum("ij,j->i", (self.inverse, m)) def equilibrium(self, m): """A very inefficient and basic implementation of the equilibrium moments. @@ -83,7 +74,7 @@ class D1Q3Transform(Transform): [0, 1 / 2, 1 / 2], [0, -1 / 2, 1 / 2] ]) - names = ["rho", "j", "e"] + names = ["rho", "j_x", "e_xx"] supported_stencils = [D1Q3] def __init__(self, lattice): @@ -91,12 +82,6 @@ def __init__(self, lattice): self.matrix = self.lattice.convert_to_tensor(self.matrix) self.inverse = self.lattice.convert_to_tensor(self.inverse) - def transform(self, f): - return self.lattice.mv(self.matrix, f) - - def inverse_transform(self, m): - return self.lattice.mv(self.inverse, m) - # def equilibrium(self, m): # # TODO # raise NotImplementedError @@ -125,7 +110,7 @@ class D2Q9Dellar(Transform): [1 / 36, -1 / 12, -1 / 12, 1 / 54, 1 / 36, 1 / 54, 1 / 36, -1 / 24, -1 / 24], [1 / 36, 1 / 12, -1 / 12, 1 / 54, -1 / 36, 1 / 54, 1 / 36, 1 / 24, -1 / 24]] ) - names = ['rho', 'jx', 'jy', 'Pi_xx', 'Pi_xy', 'PI_yy', 'N', 'Jx', 'Jy'] + names = ['rho', 'jx', 'jy', 'Pi_xx', 'Pi_xy', 'PI_yy', 'N_xxyy', 'J_xxy', 'J_xyy'] supported_stencils = [D2Q9] def __init__(self, lattice): @@ -135,14 +120,7 @@ def __init__(self, lattice): self.matrix = self.lattice.convert_to_tensor(self.matrix) self.inverse = self.lattice.convert_to_tensor(self.inverse) - def transform(self, f): - return self.lattice.mv(self.matrix, f) - - def inverse_transform(self, m): - return self.lattice.mv(self.inverse, m) - def equilibrium(self, m): - warnings.warn("I am not 100% sure if this equilibrium is correct.", ExperimentalWarning) meq = torch.zeros_like(m) rho = m[0] jx = m[1] @@ -192,12 +170,6 @@ def __init__(self, lattice): self.matrix = self.lattice.convert_to_tensor(self.matrix) self.inverse = self.lattice.convert_to_tensor(self.inverse) - def transform(self, f): - return self.lattice.mv(self.matrix, f) - - def inverse_transform(self, m): - return self.lattice.mv(self.inverse, m) - def equilibrium(self, m): """From Lallemand and Luo""" warnings.warn("I am not 100% sure if this equilibrium is correct.", ExperimentalWarning) @@ -414,12 +386,6 @@ def __init__(self, lattice): self.matrix = self.lattice.convert_to_tensor(self.matrix) self.inverse = self.lattice.convert_to_tensor(self.inverse) - def transform(self, f): - return self.lattice.mv(self.matrix, f) - - def inverse_transform(self, m): - return self.lattice.mv(self.inverse, m) - def equilibrium(self, m): meq = torch.zeros_like(m) rho = m[0] @@ -454,3 +420,18 @@ def equilibrium(self, m): meq[25] = jx * jy * jy * jz * jz / rho ** 4 meq[26] = jx * jy * jx * jz * jy * jz / rho ** 5 return meq + + +DEFAULT_TRANSFORM = { + D1Q3: D1Q3Transform, + D2Q9: D2Q9Dellar, + D3Q27: D3Q27Hermite +} + + +def get_default_moment_transform(lattice): + try: + transform_class = DEFAULT_TRANSFORM[lattice.stencil] + except KeyError: + raise LettuceException(f"No default moment transform for lattice {lattice}.") + return transform_class(lattice) diff --git a/lettuce/neural.py b/lettuce/neural.py new file mode 100644 index 00000000..4abdb904 --- /dev/null +++ b/lettuce/neural.py @@ -0,0 +1,234 @@ +import string + +import torch +import numpy as np + +from lettuce.symmetry import SymmetryGroup + +__all__= ["GConv", "GConvPermutation", "EquivariantNet", "EquivariantNeuralCollision"] + + +class GConv(torch.nn.Module): + """Group Convolution Layer. + Linear layer without bias that is equivariant with respect to a given symmetry group. + + Parameters + ---------- + in_channels : int + number of input channels + out_channels : int + number of input channels + group_action : torch.Tensor + Tensor of shape (group_order, dim, dim) that defines the group representation in GL(n). + For the LBM: M * P_g * M^{-1} gives the moment action for the g-th permutation of fs. + inverse_group_action : torch.Tensor + Tensor of shape (group_order, dim, dim) that defines the inverse group representation in GL(n). + in_indices : np.ndarray + Index array. Indices of the input tensors that are convolved. + out_indices : np.ndarray + Index array. Indices of the output tensors of the convolution. + feature_dim : int + Dimensions that contains the features (in the LBM, the Q-dimension). + channel_dim : int + Dimension that contains the channels. + """ + def __init__( + self, + in_channels, + out_channels, + group_actions, + inverse_group_actions, + in_indices=None, + out_indices=None, + feature_dim=1, + channel_dim=0 + ): + super().__init__() + self.dim = group_actions.shape[1] + self.in_indices = np.arange(self.dim) if in_indices is None else in_indices + self.out_indices = np.arange(self.dim) if out_indices is None else out_indices + self.actions = group_actions + self.inverse_actions = inverse_group_actions + self.kernels = torch.nn.Parameter(torch.randn([out_channels, in_channels, self.dim, self.dim])) + assert feature_dim != channel_dim + self.feature_dim = feature_dim + self.channel_dim = channel_dim + + def forward(self, m): + in_to_out = self._in_to_out() + m_indices = string.ascii_lowercase[:len(m.shape)] + m_indices = m_indices[:self.channel_dim] + "u" + m_indices[self.channel_dim+1:] + m_indices = m_indices[:self.feature_dim] + "w" + m_indices[self.feature_dim+1:] + out_indices = m_indices.replace("u", "v").replace("w", "x") + return torch.einsum(f"vuxw,{m_indices}->{out_indices}", in_to_out, m) + + def _in_to_out(self): + return torch.einsum( + "gij,cdjk,gkl->cdil", + self.actions[:, self.out_indices, :], + self.kernels, + self.inverse_actions[:, :, self.in_indices] + ) + + +class GConvPermutation(GConv): + """Group Convolution Layer based on permutations as group actions + + See GConv. The only difference are in the following parameters. + + Parameters + ---------- + group_action : np.ndarray + Index tensor of shape (group_order, dim) that defines the permutations. + inverse_group_action : np.ndarray + Index tensor of shape (group_order, dim) that defines the permutations. + + """ + def __init__( + self, + in_channels, + out_channels, + group_actions, + inverse_group_actions=None, + in_indices=None, + out_indices=None, + feature_dim=1, + channel_dim=0, + ): + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + group_actions=group_actions, + inverse_group_actions=inverse_group_actions, + in_indices=in_indices, + out_indices=out_indices, + feature_dim=feature_dim, + channel_dim=channel_dim + ) + + def _in_to_out(self): + return ( + self.kernels[:, :, self.actions[:, self.out_indices], :][ + ..., self.inverse_actions[:, self.in_indices] + ].sum(4).sum(2) + ) + + +class EquivariantNet(torch.nn.Module): + """Render net equivariant by summing over all group representations. + + Parameters + ---------- + net : torch.nn.Module + The net that is wrapped to be equivariant. + group_actions : torch.Tensor + see GConv + inverse_group_actions : torch.Tensor + see GConv + in_indices : np.ndarray + see GConv + out_indices : np.ndarray + see GConv + """ + def __init__( + self, + net, + group_actions, + inverse_group_actions=None, + in_indices=None, + out_indices=None + ): + super().__init__() + self.dim = group_actions.shape[1] + self.in_indices = np.arange(self.dim) if in_indices is None else in_indices + self.out_indices = np.arange(self.dim) if out_indices is None else out_indices + self.actions = group_actions + self.inverse_actions = inverse_group_actions + self.net = net + + def forward(self, x): + x_in_group = torch.einsum( + "gij,...j->g...i", + self.inverse_actions[:, self.in_indices, :][:, :, self.in_indices], + x + ) + out_group = self.net(x_in_group) + out = torch.einsum( + "gij,g...j->...i", + self.actions[:, self.out_indices, :][:, :, self.out_indices], + out_group + ) + return out + + +class EquivariantNeuralCollision(torch.nn.Module): + """An MRT model that is equivariant under the lattice symmetries by relaxing all moments of the same + order with the same rate. + + Parameters + ---------- + lower_tau : float + The default relaxation parameter operating on lower-order moments. + tau_net : torch.nn.Module + ... + moment_transform : Transform + The moment transformation. + in_indices : np.ndarray + Indices of the moments that the learned relaxation rates are conditined on. + If None, use all moments with order <= 2. + out_indices : np.ndarray + Indices of the moments that relaxation rates are learned for. + If None, use all moments with order > 2. + """ + def __init__(self, lower_tau, tau_net, moment_transform, in_indices=None, out_indices=None): + super().__init__() + self.trafo = moment_transform + self.lattice = moment_transform.lattice + self.tau = lower_tau + # infer moment order from moment name + self.moment_order = np.array([sum(name.count(x) for x in "xyz") for name in moment_transform.names]) + self.last_taus = None + # symmetries; wrap tau net equivariant + symmetry_group = SymmetryGroup(moment_transform.lattice.stencil) + self.in_indices = np.where(self.moment_order <= 2)[0] if in_indices is None else in_indices + self.out_indices = np.where(self.moment_order > 2)[0] if out_indices is None else out_indices + self.net = EquivariantNet( + tau_net, + symmetry_group.moment_action(moment_transform), + symmetry_group.inverse_moment_action(moment_transform), + in_indices=self.in_indices, + out_indices=self.out_indices + ) + self.net.to(dtype=self.lattice.dtype, device=self.lattice.device) + + @staticmethod + def gt_half(a): + """transform into a value > 0.5""" + result = 1.5 + torch.nn.ELU()(a) + assert (result >= 0.5).all() + return result + + def _compute_relaxation_parameters(self, m): + # move Q-axis to the back + q_dim = len(m.shape) - 1 - self.lattice.D + m = m.moveaxis(q_dim, len(m.shape)-1) + # default taus + taus = self.tau * torch.ones_like(m) + # compute higher-order taus from lower-order ones through neural network + tau = self.net(m[..., self.in_indices]) + # move Q-axis in front of grid axes + # render tau > 0.5 + tau = self.gt_half(tau) + taus[..., self.out_indices] = tau + taus = taus.moveaxis(len(tau.shape) - 1, q_dim) + return taus + + def forward(self, f): + m = self.trafo.transform(f) + taus = self._compute_relaxation_parameters(m) + self.last_taus = taus + meq = self.trafo.equilibrium(m) + m_postcollision = m - 1. / taus * (m - meq) + return self.trafo.inverse_transform(m_postcollision) + + diff --git a/lettuce/observables.py b/lettuce/observables.py index 46f88546..ff5cc792 100644 --- a/lettuce/observables.py +++ b/lettuce/observables.py @@ -82,8 +82,8 @@ def __init__(self, lattice, flow): self.wavenumbers = torch.arange(int(torch.max(wavenorms))) self.wavemask = ( - (wavenorms[..., None] > self.wavenumbers.to(dtype=lattice.dtype, device=lattice.device) - 0.5) & - (wavenorms[..., None] <= self.wavenumbers.to(dtype=lattice.dtype, device=lattice.device) + 0.5) + (wavenorms[..., None] > self.wavenumbers.to(dtype=lattice.dtype, device=lattice.device) - 0.5) & + (wavenorms[..., None] <= self.wavenumbers.to(dtype=lattice.dtype, device=lattice.device) + 0.5) ) def __call__(self, f): diff --git a/lettuce/symmetry.py b/lettuce/symmetry.py new file mode 100644 index 00000000..31242771 --- /dev/null +++ b/lettuce/symmetry.py @@ -0,0 +1,265 @@ +"""Lattice Symmetries""" + +import numpy as np + + +__all__ = [ + "is_symmetry", "are_symmetries_equal", "Symmetry", + "SymmetryGroup", "InverseSymmetry", "ChainedSymmetry", "Identity", + "Rotation90", "Reflection" +] + + +def is_symmetry(operation, stencil): + "whether the operation leaves the stencil invariant" + original_e = set(tuple(e) for e in stencil.e) + new_e = set(tuple(operation.forward(e)) for e in stencil.e) + reverse_e = set(tuple(operation.inverse(e)) for e in stencil.e) + return original_e == new_e and original_e == reverse_e + + +def are_symmetries_equal(symmetry1, symmetry2, stencil): + return np.allclose(symmetry1.forward(stencil.e), symmetry2.forward(stencil.e)) + + +class Symmetry: + """Abstract base class for symmetry operations.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return NotImplemented + + def inverse(self, x): + return NotImplemented + + def permutation(self, stencil): + assert is_symmetry(self, stencil) + other = self.forward(stencil.e) + return np.concatenate([np.where((ei == stencil.e).all(axis=-1))[0] for ei in other]) + + +class ChainedSymmetry(Symmetry): + """Stitch multiple symmetries together.""" + + def __init__(self, *symmetries): + super().__init__() + self.symmetries = [] + # unfold chains + for i, symmetry in enumerate(symmetries): + if isinstance(symmetry, ChainedSymmetry): + self.symmetries.extend([*symmetry]) + else: + self.symmetries.append(symmetry) + self.symmetries = tuple(self.symmetries) + + def forward(self, x): + for s in self.symmetries: + x = s.forward(x) + return x + + def inverse(self, x): + for s in reversed(self.symmetries): + x = s.inverse(x) + return x + + def __iter__(self): + return self.symmetries.__iter__() + + def __len__(self): + return self.symmetries.__len__() + + def __repr__(self): + return f"" + + +class InverseSymmetry(Symmetry): + """Inverse of a symmetry operation""" + + def __init__(self, delegate): + super().__init__() + self.delegate = delegate + + def forward(self, x): + return self.delegate.inverse(x) + + def inverse(self, x): + return self.delegate.forward(x) + + def __repr__(self): + return f"" + + +class Reflection(Symmetry): + """Reflection along one dimension.""" + + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + y = x.copy() + y[..., self.dim] *= -1 + return y + + def inverse(self, x): + y = x.copy() + y[..., self.dim] *= -1 + return y + + def __repr__(self): + return f"" + + +class Rotation90(Symmetry): + """Counterclockwise rotation by 90 degrees.""" + + def __init__(self, *dims): + super().__init__() + self.dims = dims + self.mat = np.array( + [ + [np.cos(np.pi / 2), -np.sin(np.pi / 2)], + [np.sin(np.pi / 2), np.cos(np.pi / 2)] + ], dtype=int + ) + + def forward(self, x): + y = x.copy() + y[..., [self.dims[0], self.dims[1]]] = np.einsum( + "ij,...i->...j", + self.mat, + y[..., [self.dims[0], self.dims[1]]] + ) + return y + + def inverse(self, x): + y = x.copy() + y[..., [self.dims[0], self.dims[1]]] = np.einsum( + "ji,...i->...j", + self.mat, + y[..., [self.dims[0], self.dims[1]]] + ) + return y + + def __repr__(self): + return f"" + + +class Identity(Symmetry): + """The identity.""" + + def forward(self, x): + return x + + def inverse(self, x): + return x + + def __repr__(self): + return "" + + +class SymmetryGroup(list): + """ + Lattice symmetry group. + """ + + def __init__(self, stencil): + super().__init__() + self.stencil = stencil + candidates = self._make_candidates(stencil.D()) + new_symmetries = {Identity()} + while len(new_symmetries) > 0: + for n in new_symmetries: + if n not in self: + self.append(n) + new_symmetries = self._new_symmetries(candidates) + self._multiplication_table = None + self._inverse_table = None + + @property + def permutations(self): + return np.stack([symmetry.permutation(self.stencil) for symmetry in self]) + + @property + def inverse_permutations(self): + return np.stack([InverseSymmetry(symmetry).permutation(self.stencil) for symmetry in self]) + + @property + def multiplication_table(self): + if self._multiplication_table is None: + m = np.zeros([len(self), len(self)], dtype=int) + for i, symmetry1 in enumerate(self): + for j, symmetry2 in enumerate(self): + m[i,j] = self.index(ChainedSymmetry(symmetry1, symmetry2)) + self._multiplication_table = m + return self._multiplication_table + + @property + def inverse_table(self): + if self._inverse_table is None: + inv = np.zeros([len(self)], dtype=int) + for i, symmetry in enumerate(self): + inv[i] = self.index(InverseSymmetry(symmetry)) + self._inverse_table = inv + return self._inverse_table + + def moment_action(self, moment_transform): + return (moment_transform.matrix[:, self.permutations] @ moment_transform.inverse).swapaxes(0, 1) + + def inverse_moment_action(self, moment_transform): + return (moment_transform.matrix[:, self.inverse_permutations] @ moment_transform.inverse).swapaxes(0, 1) + + def _new_symmetries(self, candidates): + result = [] + for c in candidates: + for s in self: + proposed = self._chain_symmetries(s, c) + if proposed not in self: + result.append(proposed) + return result + + def __contains__(self, symmetry): + for elem in self: + if are_symmetries_equal(symmetry, elem, self.stencil): + return True + return False + + def index(self, symmetry): + for i, elem in enumerate(self): + if are_symmetries_equal(symmetry, elem, self.stencil): + return i + raise KeyError("symmetry not in symmetry group") + + def __eq__(self, other): + assert super.__eq__(other) and self.stencil == other.stencil + + @staticmethod + def _make_candidates(dim): + candidates = [] + for i in range(dim): + for j in range(i + 1, dim): + candidates.append(Rotation90(i, j)) + for i in range(dim): + candidates.append(Reflection(i)) + # if for some reason we cannot reach all elements by 90-degree rotations + for i in range(dim): + for j in range(i + 1, dim): + # 180 degree rotations + candidates.append(ChainedSymmetry(Rotation90(i, j), Rotation90(i, j))) + # inverse rotations + candidates.append(InverseSymmetry(Rotation90(i, j))) + return candidates + + def _chain_symmetries(self, *symmetries): + symmetries = [ + s for s in symmetries + if not are_symmetries_equal(s, Identity(), self.stencil) + ] + if len(symmetries) == 0: + return Identity() + elif len(symmetries) == 1: + return symmetries[0] + else: + return ChainedSymmetry(*symmetries) diff --git a/lettuce/unit.py b/lettuce/unit.py index 6be89db3..663607cb 100644 --- a/lettuce/unit.py +++ b/lettuce/unit.py @@ -105,15 +105,15 @@ def convert_length_to_lu(self, length_pu): def convert_energy_to_pu(self, energy_lu): """Energy is defined here in units of [density * velocity**2]""" return ( - energy_lu * (self.characteristic_density_pu * self.characteristic_velocity_pu ** 2) - / (self.characteristic_density_lu * self.characteristic_velocity_lu ** 2) + energy_lu * (self.characteristic_density_pu * self.characteristic_velocity_pu ** 2) + / (self.characteristic_density_lu * self.characteristic_velocity_lu ** 2) ) def convert_energy_to_lu(self, energy_pu): """Energy is defined here in units of [density * velocity**2]""" return ( - energy_pu * (self.characteristic_density_lu * self.characteristic_velocity_lu ** 2) - / (self.characteristic_density_pu * self.characteristic_velocity_pu ** 2) + energy_pu * (self.characteristic_density_lu * self.characteristic_velocity_lu ** 2) + / (self.characteristic_density_pu * self.characteristic_velocity_pu ** 2) ) def convert_incompressible_energy_to_pu(self, energy_lu): diff --git a/lettuce/util.py b/lettuce/util.py index d5b66d00..ce00e86f 100644 --- a/lettuce/util.py +++ b/lettuce/util.py @@ -6,7 +6,8 @@ import torch __all__ = [ - "LettuceException", "LettuceWarning", "InefficientCodeWarning", "ExperimentalWarning", + "LettuceException", "LettuceCollisionNotDefined", "LettuceInvalidNetworkOutput", + "LettuceWarning", "InefficientCodeWarning", "ExperimentalWarning", "get_subclasses", "torch_gradient", "torch_jacobi", "grid_fine_to_coarse", "pressure_poisson" ] @@ -15,6 +16,14 @@ class LettuceException(Exception): pass +class LettuceCollisionNotDefined(Exception): + pass + + +class LettuceInvalidNetworkOutput(Exception): + pass + + class LettuceWarning(UserWarning): pass @@ -86,12 +95,12 @@ def torch_gradient(f, dx=1, order=2): out = torch.cat(dim * [f[None, ...]]) for i in range(dim): out[i, ...] = ( - weight[0] * f.roll(shifts=shift[i][0], dims=dims) + - weight[1] * f.roll(shifts=shift[i][1], dims=dims) + - weight[2] * f.roll(shifts=shift[i][2], dims=dims) + - weight[3] * f.roll(shifts=shift[i][3], dims=dims) + - weight[4] * f.roll(shifts=shift[i][4], dims=dims) + - weight[5] * f.roll(shifts=shift[i][5], dims=dims) + weight[0] * f.roll(shifts=shift[i][0], dims=dims) + + weight[1] * f.roll(shifts=shift[i][1], dims=dims) + + weight[2] * f.roll(shifts=shift[i][2], dims=dims) + + weight[3] * f.roll(shifts=shift[i][3], dims=dims) + + weight[4] * f.roll(shifts=shift[i][4], dims=dims) + + weight[5] * f.roll(shifts=shift[i][5], dims=dims) ) * torch.tensor(1.0 / dx, dtype=f.dtype, device=f.device) return out diff --git a/setup.cfg b/setup.cfg index bbbcffda..2b363f7a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,9 +8,6 @@ exclude = docs # Define setup.py command aliases here test = pytest -[tool:pytest] -collect_ignore = ['setup.py'] - # See the docstring in versioneer.py for instructions. Note that you must # re-run 'versioneer.py setup' after changing this section, and commit the diff --git a/tests/conftest.py b/tests/conftest.py index 7982e606..31c81c42 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,13 +7,38 @@ import torch from lettuce import ( - stencils, Stencil, get_subclasses, Transform, Lattice, moments + stencils, Stencil, get_subclasses, Transform, Lattice, moments, collision, Collision, + SymmetryGroup ) STENCILS = list(get_subclasses(Stencil, stencils)) TRANSFORMS = list(get_subclasses(Transform, moments)) +COLLISION_MODELS = list(get_subclasses(Collision, collision)) +# == slow tests +def pytest_addoption(parser): + parser.addoption( + "--runslow", action="store_true", default=False, help="run slow tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "slow: mark test as slow to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runslow"): + # --runslow given in cli: do not skip slow tests + return + skip_slow = pytest.mark.skip(reason="need --runslow option to run") + for item in items: + if "slow" in item.keywords: + item.add_marker(skip_slow) + + +# === fixtures + @pytest.fixture( params=["cpu", pytest.param( "cuda:0", marks=pytest.mark.skipif( @@ -73,3 +98,16 @@ def f_transform(request, f_all_lattices): return f, Transform(lattice) else: pytest.skip("Stencil not supported for this transform.") + + +@pytest.fixture(params=COLLISION_MODELS) +def Collision(request): + """Run a test for all stencils.""" + return request.param + + +@pytest.fixture(params=STENCILS, scope="session", autouse=True) +def symmetry_group(request): + group = SymmetryGroup(request.param) + return group + diff --git a/tests/test_collision.py b/tests/test_collision.py index ecb11fd0..2354b5c8 100644 --- a/tests/test_collision.py +++ b/tests/test_collision.py @@ -8,28 +8,24 @@ from lettuce import * -@pytest.mark.parametrize("Collision", [BGKCollision, KBCCollision2D, KBCCollision3D, TRTCollision, RegularizedCollision, - SmagorinskyCollision]) def test_collision_conserves_mass(Collision, f_all_lattices): f, lattice = f_all_lattices - if ((Collision == KBCCollision2D and lattice.stencil != D2Q9) or ( - (Collision == KBCCollision3D and lattice.stencil != D3Q27))): + try: + collision = Collision(lattice, 0.51) + except LettuceCollisionNotDefined: pytest.skip() f_old = copy(f) - collision = Collision(lattice, 0.51) f = collision(f) assert lattice.rho(f).cpu().numpy() == pytest.approx(lattice.rho(f_old).cpu().numpy()) -@pytest.mark.parametrize("Collision", [BGKCollision, KBCCollision2D, KBCCollision3D, TRTCollision, RegularizedCollision, - SmagorinskyCollision]) def test_collision_conserves_momentum(Collision, f_all_lattices): f, lattice = f_all_lattices - if ((Collision == KBCCollision2D and lattice.stencil != D2Q9) or ( - (Collision == KBCCollision3D and lattice.stencil != D3Q27))): + try: + collision = Collision(lattice, 0.51) + except LettuceCollisionNotDefined: pytest.skip() f_old = copy(f) - collision = Collision(lattice, 0.51) f = collision(f) assert lattice.j(f).cpu().numpy() == pytest.approx(lattice.j(f_old).cpu().numpy(), abs=1e-5) @@ -43,22 +39,22 @@ def test_collision_fixpoint_2x(Collision, f_all_lattices): assert f.cpu().numpy() == pytest.approx(f_old.cpu().numpy(), abs=1e-5) -@pytest.mark.parametrize("Collision", - [BGKCollision, TRTCollision, KBCCollision2D, KBCCollision3D, RegularizedCollision]) def test_collision_relaxes_shear_moments(Collision, f_all_lattices): """checks whether the collision models relax the shear moments according to the prescribed relaxation time""" f, lattice = f_all_lattices - if ((Collision == KBCCollision2D and lattice.stencil != D2Q9) or ( - (Collision == KBCCollision3D and lattice.stencil != D3Q27))): + tau = 0.6 + try: + collision = Collision(lattice, tau) + except LettuceCollisionNotDefined: pytest.skip() + if Collision == SmagorinskyCollision: + pytest.skip("Introduces additional viscosity.") rho = lattice.rho(f) u = lattice.u(f) feq = lattice.equilibrium(rho, u) shear_pre = lattice.shear_tensor(f) shear_eq_pre = lattice.shear_tensor(feq) - tau = 0.6 - coll = Collision(lattice, tau) - f_post = coll(f) + f_post = collision(f) shear_post = lattice.shear_tensor(f_post) assert shear_post.cpu().numpy() == pytest.approx((shear_pre - 1 / tau * (shear_pre - shear_eq_pre)).cpu().numpy(), abs=1e-5) @@ -72,7 +68,10 @@ def test_collision_optimizes_pseudo_entropy(Collision, f_all_lattices): (Collision == KBCCollision3D and lattice.stencil != D3Q27))): pytest.skip() tau = 0.5003 - coll_kbc = Collision(lattice, tau) + try: + coll_kbc = Collision(lattice, tau) + except LettuceCollisionNotDefined: + pytest.skip() coll_bgk = BGKCollision(lattice, tau) f_kbc = coll_kbc(f) f_bgk = coll_bgk(f) @@ -88,7 +87,6 @@ def test_collision_fixpoint_2x_MRT(Transform, dtype_device): np.random.seed(1) # arbitrary, but deterministic f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) f_old = copy(f) - collision = MRTCollision(lattice, Transform(lattice), np.array([0.5] * 9)) + collision = MRTCollision(lattice, np.array([0.5] * 9), Transform(lattice)) f = collision(collision(f)) - print(f.cpu().numpy(), f_old.cpu().numpy()) assert f.cpu().numpy() == pytest.approx(f_old.cpu().numpy(), abs=1e-5) diff --git a/tests/test_neural.py b/tests/test_neural.py new file mode 100644 index 00000000..aa7330a1 --- /dev/null +++ b/tests/test_neural.py @@ -0,0 +1,148 @@ + +import pytest +import numpy as np +from lettuce import ( + EquivariantNeuralCollision, Lattice, + get_default_moment_transform, LettuceException, + D1Q3, D2Q9, D2Q9Dellar, LettuceInvalidNetworkOutput, + GConv, GConvPermutation, EquivariantNet +) +import torch + + +@pytest.mark.parametrize("out_channels", [1, 3]) +@pytest.mark.parametrize("in_channels", [1, 3]) +def test_gconv_equivariance(symmetry_group, dtype_device, in_channels, out_channels): + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + try: + moments = get_default_moment_transform(lattice) + except LettuceException: + pytest.skip(f"No moment transform for {lattice.stencil}") + f = lattice.convert_to_tensor(np.random.random([in_channels, lattice.Q] + [3] * lattice.D)) + m = moments.transform(f) + conv = GConv( + in_channels, + out_channels, + symmetry_group.moment_action(moments), + symmetry_group.inverse_moment_action(moments) + ) + conv.to(dtype=dtype, device=device) + + def apply(rep, p): + xyz = "xyz"[:lattice.D] + return torch.einsum(f"ij, fj{xyz}->fj{xyz}", rep, p) + + for action in symmetry_group.moment_action(moments): + assert torch.allclose( + conv(apply(action, m)), + apply(action, conv(m)), + atol=1e-3 if dtype == torch.float32 else 1e-5 + ) + + +def test_gconv_permutation_equivariance(symmetry_group, dtype_device): + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + f = lattice.convert_to_tensor(np.random.random([1] + [lattice.Q] + [3] * lattice.D)) + conv = GConvPermutation(1, 1, symmetry_group.permutations, symmetry_group.inverse_permutations) + conv.to(dtype=dtype, device=device) + for p in symmetry_group.permutations: + assert torch.allclose( + conv(f[:, p]), + conv(f)[:, p], + atol=1e-3 if dtype == torch.float32 else 1e-5 + ) + + +def test_equivariant_net(symmetry_group, dtype_device): + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + try: + moments = get_default_moment_transform(lattice) + except LettuceException: + pytest.skip(f"No moment transform for {lattice.stencil}") + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + m = moments.transform(f) + m = m.moveaxis(0, lattice.D) + net = torch.nn.Sequential(torch.nn.Linear(lattice.Q, 23), torch.nn.ReLU(), torch.nn.Linear(23, lattice.Q)) + equi = EquivariantNet( + net=net, + group_actions=symmetry_group.moment_action(moments), + inverse_group_actions=symmetry_group.inverse_moment_action(moments), + ) + equi.to(dtype=dtype, device=device) + + def apply(rep, p): + xyz = "xyz"[:lattice.D] + return torch.einsum(f"ij, {xyz}j->{xyz}i", rep, p) + + for action in symmetry_group.moment_action(moments): + assert torch.allclose( + equi(apply(action, m)), + apply(action, equi(m)), + atol=1e-3 if dtype == torch.float32 else 1e-5 + ) + + +def test_equivariant_net_selected_moments(symmetry_group, dtype_device): + if symmetry_group.stencil == D1Q3: + pytest.skip("Too few moments in D1Q3") + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + try: + moments = get_default_moment_transform(lattice) + except LettuceException: + pytest.skip(f"No moment transform for {lattice.stencil}") + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + m = moments.transform(f) + m = m.moveaxis(0, lattice.D) + moment_orders = np.array([sum(name.count(x) for x in "xyz") for name in moments.names]) + in_indices = np.where(moment_orders <= 2)[0] + out_indices = np.where(moment_orders > 2)[0] + net = torch.nn.Sequential( + torch.nn.Linear(len(in_indices), 23), + torch.nn.ReLU(), + torch.nn.Linear(23, len(out_indices)) + ) + equi = EquivariantNet( + net=net, + group_actions=symmetry_group.moment_action(moments), + inverse_group_actions=symmetry_group.inverse_moment_action(moments), + in_indices=in_indices, + out_indices=out_indices + ) + equi.to(dtype=dtype, device=device) + + def apply(rep, p): + xyz = "xyz"[:lattice.D] + return torch.einsum(f"ij, {xyz}j->{xyz}i", rep, p) + + for action in symmetry_group.moment_action(moments): + assert torch.allclose( + equi(apply(action[in_indices][..., in_indices], m[..., in_indices])), + apply(action[out_indices][..., out_indices], equi(m[..., in_indices])), + atol=1e-3 if dtype == torch.float32 else 1e-5 + ) + + +@pytest.mark.slow +def test_equivariant_mrt(symmetry_group, dtype_device): + if symmetry_group.stencil == D1Q3: + pytest.skip("Too few moments in D1Q3") + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + try: + moment_transform = get_default_moment_transform(lattice) + except LettuceException: + pytest.skip() + moment_orders = np.array([sum(name.count(x) for x in "xyz") for name in moment_transform.names]) + net = torch.nn.Linear((moment_orders <= 2).sum(), (moment_orders > 2).sum()) + collision = EquivariantNeuralCollision(0.7, net, moment_transform) + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + for p in symmetry_group.permutations: + assert torch.allclose( + collision(f[p]), + collision(f)[p], + atol=1e-2 if dtype == torch.float32 else 1e-4 + ) diff --git a/tests/test_symmetry.py b/tests/test_symmetry.py new file mode 100644 index 00000000..6b3d0681 --- /dev/null +++ b/tests/test_symmetry.py @@ -0,0 +1,193 @@ +import pytest +import torch +import numpy as np +from lettuce.symmetry import * +from lettuce.stencils import D1Q3, D2Q9, D3Q19, D3Q27 +from lettuce.lattices import Lattice +from lettuce.collision import MRTCollision +from lettuce.util import LettuceCollisionNotDefined +from lettuce.moments import DEFAULT_TRANSFORM + + +pytestmark = pytest.mark.slow + + +def test_four_rotations(stencil): + for i in range(stencil.D()): + for j in range(i + 1, stencil.D()): + assert are_symmetries_equal( + ChainedSymmetry(*([Rotation90(i, j)] * 4)), + Identity(), + stencil=stencil + ) + assert not are_symmetries_equal( + ChainedSymmetry(*([Rotation90(i, j)] * 3)), + Identity(), + stencil=stencil + ) + + +def test_two_reflections(stencil): + for i in range(stencil.D()): + assert are_symmetries_equal( + ChainedSymmetry(*([Reflection(i)] * 2)), + Identity(), + stencil=stencil + ) + assert not are_symmetries_equal( + ChainedSymmetry(*([Reflection(i)] * 3)), + Identity(), + stencil=stencil + ) + + +def test_reflection_by_rotations(stencil): + for i in range(stencil.D()): + for j in range(i + 1, stencil.D()): + assert are_symmetries_equal( + ChainedSymmetry(Rotation90(i, j), Rotation90(i, j)), + ChainedSymmetry(Reflection(i), Reflection(j)), + stencil=stencil + ) + + +def test_inverse(stencil): + for i in range(stencil.D()): + for j in range(i + 1, stencil.D()): + assert are_symmetries_equal( + ChainedSymmetry(Rotation90(i, j), InverseSymmetry(Rotation90(i, j))), + Identity(), + stencil=stencil + ) + assert are_symmetries_equal( + ChainedSymmetry(Reflection(i), InverseSymmetry(Reflection(i))), + Identity(), + stencil=stencil + ) + + +def test_symmetry_group(symmetry_group): + group = symmetry_group + n_symmetries = {D1Q3: 2, D2Q9: 8, D3Q19: 48, D3Q27: 48}[group.stencil] + assert len(group) == n_symmetries + assert group.permutations.shape == (n_symmetries, group.stencil.Q()) + + # assert that it's a group: + # contains the identity + assert Identity() in group + for s1 in group: + for s2 in group: + if s1 != s2: + # elements are unique + assert not are_symmetries_equal(s1, s2, group.stencil) + # contains the inverse + assert InverseSymmetry(s1) in group + + +def test_permutations(symmetry_group): + for g in symmetry_group: + # symmetry operation is equal to the corresponding permutation + assert np.allclose( + g.forward(symmetry_group.stencil.e), + symmetry_group.stencil.e[g.permutation(symmetry_group.stencil), ...] + ) + # inverse permutation of permutation is identity + perm1 = g.permutation(symmetry_group.stencil) + perm2 = InverseSymmetry(g).permutation(symmetry_group.stencil) + assert np.allclose(perm1[perm2], np.arange(symmetry_group.stencil.Q())) + assert np.allclose(perm2[perm1], np.arange(symmetry_group.stencil.Q())) + + +def test_inverse_permutations(symmetry_group): + for p, pinv in zip(symmetry_group.permutations, symmetry_group.inverse_permutations): + assert (p[pinv] == np.arange(symmetry_group.stencil.Q())).all() + assert (pinv[p] == np.arange(symmetry_group.stencil.Q())).all() + + +def test_moment_representations(symmetry_group): + try: + transform = DEFAULT_TRANSFORM[symmetry_group.stencil] + except KeyError: + pytest.skip("No default transform for this stencil") + rep = symmetry_group.moment_action(transform) + irep = symmetry_group.inverse_moment_action(transform) + # test if this is a representation + # group op = matrix multiply + for i, symmetry in enumerate(symmetry_group): + for j, symmetry2 in enumerate(symmetry_group): + ji = symmetry_group.index(ChainedSymmetry(symmetry, symmetry2)) + assert np.allclose(rep[j] @ rep[i], rep[ji]) + # inverse group op = inverse matrix + for forward, inverse in zip(rep, irep): + assert np.allclose(forward @ inverse, np.eye(symmetry_group.stencil.Q())) + assert np.allclose(inverse @ forward, np.eye(symmetry_group.stencil.Q())) + + +def test_feq_equivariance(symmetry_group, dtype_device): + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + feq = lambda x: lattice.equilibrium(lattice.rho(x), lattice.u(x)) + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + for g in symmetry_group: + assert torch.allclose( + feq(f[g.permutation(symmetry_group.stencil)]), + feq(f)[g.permutation(symmetry_group.stencil)], + ) + + +def test_collision_equivariance(symmetry_group, dtype_device, Collision): + """Test whether all collision models obey the lattice symmetries.""" + dtype, device = dtype_device + lattice = Lattice(symmetry_group.stencil, dtype=dtype, device=device) + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + try: + collision = Collision(lattice, 0.51) + except LettuceCollisionNotDefined: + pytest.skip() + f_post = collision(f.clone()) + for permutation in symmetry_group.permutations: + f_post_after_g = collision(f.clone()[permutation]) + assert torch.allclose( + f_post_after_g, + f_post[permutation], + atol=2e-5 if dtype == torch.float32 else 1e-7 + ), f"{(f_post_after_g - f_post[permutation]).norm()}" + + +def test_non_equivariant_mrt(dtype_device): + dtype, device = dtype_device + stencil = D2Q9 + lattice = Lattice(stencil, dtype=dtype, device=device) + symmetry_group = SymmetryGroup(D2Q9) + # non-equivariant choice of relaxation parameters + collision = MRTCollision(lattice, torch.arange(9.0)) + f = lattice.convert_to_tensor(np.random.random([lattice.Q] + [3] * lattice.D)) + f_post = collision(f.clone()) + is_equivariant = True + for permutation in symmetry_group.permutations: + f_post_after_g = collision(f.clone()[permutation]) + are_equal = torch.allclose( + f_post_after_g, + f_post[permutation], + atol=2e-5 if dtype == torch.float32 else 1e-7 + ) + if not are_equal: + is_equivariant = False + assert not is_equivariant + + +def test_multiplication_and_inverse_table(symmetry_group): + for i, symmetry1 in enumerate(symmetry_group): + assert are_symmetries_equal( + symmetry_group[symmetry_group.inverse_table[i]], + InverseSymmetry(symmetry1), + symmetry_group.stencil + ) + for j, symmetry2 in enumerate(symmetry_group): + assert are_symmetries_equal( + symmetry_group[symmetry_group.multiplication_table[i, j]], + ChainedSymmetry(symmetry1, symmetry2), + symmetry_group.stencil + ) + +