diff --git a/prototorch/core/distances.py b/prototorch/core/distances.py index b3278fd..ae66afa 100644 --- a/prototorch/core/distances.py +++ b/prototorch/core/distances.py @@ -73,6 +73,19 @@ def omega_distance(x, y, omega): return distances +def ML_omega_distance(x, y, omegas, masks): + """Multi-Layer Omega distance.""" + x, y = (arr.view(arr.size(0), -1) for arr in (x, y)) + omegas = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)] + omega = omegas[0] + for _omega in omegas[1:]: + omega = omega @ _omega + projected_x = x @ omega + projected_y = y @ omega + distances = squared_euclidean_distance(projected_x, projected_y) + return distances + + def lomega_distance(x, y, omegas): r"""Localized Omega distance. diff --git a/prototorch/core/initializers.py b/prototorch/core/initializers.py index d6ae098..0df770a 100644 --- a/prototorch/core/initializers.py +++ b/prototorch/core/initializers.py @@ -17,6 +17,7 @@ # Components class AbstractComponentsInitializer(ABC): """Abstract class for all components initializers.""" + ... @@ -34,9 +35,9 @@ def generate(self, num_components: int = 0): """Ignore `num_components` and simply return `self.components`.""" provided_num_components = len(self.components) if provided_num_components != num_components: - wmsg = f"The number of components ({provided_num_components}) " \ - f"provided to {self.__class__.__name__} " \ - f"does not match the expected number ({num_components})." + wmsg = (f"The number of components ({provided_num_components}) " + f"provided to {self.__class__.__name__} " + f"does not match the expected number ({num_components}).") warnings.warn(wmsg) if not isinstance(self.components, torch.Tensor): wmsg = f"Converting components to {torch.Tensor}..." 
@@ -127,10 +128,12 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer): """ - def __init__(self, - data: torch.Tensor, - noise: float = 0.0, - transform: Callable = torch.nn.Identity()): + def __init__( + self, + data: torch.Tensor, + noise: float = 0.0, + transform: Callable = torch.nn.Identity(), + ): self.data = data self.noise = noise self.transform = transform @@ -429,6 +432,7 @@ def generate(self, distribution: Union[dict, list, tuple]): # Transforms class AbstractTransformInitializer(ABC): """Abstract class for all transform initializers.""" + ... @@ -486,11 +490,13 @@ def generate(self, in_dim: int, out_dim: int): class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer): """Abstract class for all data-aware linear transform initializers.""" - def __init__(self, - data: torch.Tensor, - noise: float = 0.0, - transform: Callable = torch.nn.Identity(), - out_dim_first: bool = False): + def __init__( + self, + data: torch.Tensor, + noise: float = 0.0, + transform: Callable = torch.nn.Identity(), + out_dim_first: bool = False, + ): super().__init__(out_dim_first) self.data = data self.noise = noise