Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions prototorch/core/distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ def omega_distance(x, y, omega):
return distances


def ML_omega_distance(x, y, omegas, masks):
    """Multi-Layer Omega distance.

    Each layer's omega matrix is elementwise-masked, the masked matrices
    are multiplied together into a single projection, and the squared
    Euclidean distance is computed in the projected space.

    Args:
        x: data tensor; flattened to shape ``(x.size(0), -1)``.
        y: prototype tensor; flattened to shape ``(y.size(0), -1)``.
        omegas: sequence of omega matrices, one per layer (at least one).
        masks: sequence of masks, elementwise-multiplied onto the
            corresponding omega. Extra entries in the longer of
            ``omegas``/``masks`` are ignored (``zip`` semantics).

    Returns:
        Pairwise squared Euclidean distances between the projected
        ``x`` and ``y``.
    """
    # NOTE: .view requires contiguous inputs — assumed contiguous here,
    # consistent with the other distance helpers in this module.
    x, y = (arr.view(arr.size(0), -1) for arr in (x, y))
    masked = [torch.mul(_omega, _mask) for _omega, _mask in zip(omegas, masks)]
    # Fold all masked omegas into one projection matrix. Starting from
    # masked[0] (instead of masked[0] @ masked[1]) avoids an IndexError
    # when only a single layer is provided.
    omega = masked[0]
    for _omega in masked[1:]:
        omega = omega @ _omega
    projected_x = x @ omega
    projected_y = y @ omega
    distances = squared_euclidean_distance(projected_x, projected_y)
    return distances


def lomega_distance(x, y, omegas):
r"""Localized Omega distance.

Expand Down
30 changes: 18 additions & 12 deletions prototorch/core/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# Components
class AbstractComponentsInitializer(ABC):
    """Abstract base class for all components initializers.

    Marker interface only: it carries no behavior of its own.
    Concrete subclasses presumably provide a ``generate`` method that
    returns component tensors — confirm against the subclasses below.
    """

    ...


Expand All @@ -34,9 +35,9 @@ def generate(self, num_components: int = 0):
"""Ignore `num_components` and simply return `self.components`."""
provided_num_components = len(self.components)
if provided_num_components != num_components:
wmsg = f"The number of components ({provided_num_components}) " \
f"provided to {self.__class__.__name__} " \
f"does not match the expected number ({num_components})."
wmsg = (f"The number of components ({provided_num_components}) "
f"provided to {self.__class__.__name__} "
f"does not match the expected number ({num_components}).")
warnings.warn(wmsg)
if not isinstance(self.components, torch.Tensor):
wmsg = f"Converting components to {torch.Tensor}..."
Expand Down Expand Up @@ -127,10 +128,12 @@ class AbstractDataAwareCompInitializer(AbstractComponentsInitializer):

"""

def __init__(self,
data: torch.Tensor,
noise: float = 0.0,
transform: Callable = torch.nn.Identity()):
def __init__(
    self,
    data: torch.Tensor,
    noise: float = 0.0,
    transform: Callable = torch.nn.Identity(),
):
    """Store the data-aware initialization settings.

    Args:
        data: tensor the initializer will draw components from.
        noise: noise amount; 0.0 disables it. How the noise is applied
            is decided by the subclass/generate step, not here.
        transform: callable applied during generation; defaults to the
            identity (no-op). NOTE(review): the default instance is
            shared across all constructions — harmless for the
            stateless ``nn.Identity``, but worth keeping in mind.
    """
    self.data = data
    self.noise = noise
    self.transform = transform
Expand Down Expand Up @@ -429,6 +432,7 @@ def generate(self, distribution: Union[dict, list, tuple]):
# Transforms
class AbstractTransformInitializer(ABC):
    """Abstract base class for all transform initializers.

    Marker interface only: it defines no methods or attributes itself;
    concrete transform initializers subclass it to signal their role.
    """

    ...


Expand Down Expand Up @@ -486,11 +490,13 @@ def generate(self, in_dim: int, out_dim: int):
class AbstractDataAwareLTInitializer(AbstractLinearTransformInitializer):
"""Abstract class for all data-aware linear transform initializers."""

def __init__(self,
data: torch.Tensor,
noise: float = 0.0,
transform: Callable = torch.nn.Identity(),
out_dim_first: bool = False):
def __init__(
self,
data: torch.Tensor,
noise: float = 0.0,
transform: Callable = torch.nn.Identity(),
out_dim_first: bool = False,
):
super().__init__(out_dim_first)
self.data = data
self.noise = noise
Expand Down