diff --git a/.flake8 b/.flake8
index c48a979..fe3a3f1 100644
--- a/.flake8
+++ b/.flake8
@@ -5,4 +5,6 @@ max-line-length = 88
 extend-ignore =
     E203,
     D,
-exclude = src/fluffyrocket/_minirocket.py
+exclude =
+    .github,
+    src/fluffyrocket/_minirocket.py
diff --git a/src/fluffyrocket/__init__.py b/src/fluffyrocket/__init__.py
index 13a231d..5fba596 100644
--- a/src/fluffyrocket/__init__.py
+++ b/src/fluffyrocket/__init__.py
@@ -1,5 +1,9 @@
 """PyTorch implementation of MiniRocket with soft PPV."""
 
+from .fluffyrocket import FluffyRocket
 from .minirocket import MiniRocket
 
-__all__ = ["MiniRocket"]
+__all__ = [
+    "MiniRocket",
+    "FluffyRocket",
+]
diff --git a/src/fluffyrocket/base.py b/src/fluffyrocket/base.py
new file mode 100644
index 0000000..ca30ee6
--- /dev/null
+++ b/src/fluffyrocket/base.py
@@ -0,0 +1,132 @@
+"""Base class for MiniRocket-like transformers."""
+
+import abc
+from itertools import combinations
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.nn import Module
+
+from ._minirocket import fit as _fit
+
+_INDICES = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
+_KERNELS = np.full((84, 1, 9), -1.0, dtype=np.float32)
+_KERNELS[np.arange(84)[:, np.newaxis], 0, _INDICES] = 2.0
+
+__all__ = [
+    "MiniRocketBase",
+]
+
+
+class MiniRocketBase(Module, abc.ABC):
+    """Base class for MiniRocket-like transformers.
+
+    Subclasses must implement :meth:`ppv`.
+
+    Parameters
+    ----------
+    num_features : int, default=10,000
+    max_dilations_per_kernel : int, default=32
+    random_state : int, default=None
+    """
+
+    def __init__(
+        self,
+        num_features=10_000,
+        max_dilations_per_kernel=32,
+        random_state=None,
+    ):
+        super().__init__()
+        self.num_features = num_features
+        self.max_dilations_per_kernel = max_dilations_per_kernel
+        self.random_state = random_state
+
+    def fit(self, X, y=None):
+        """Fit dilations and biases.
+
+        Parameters
+        ----------
+        X : array of shape ``(num_examples, num_channels, input_length)``.
+            Data type must be float32.
+        y : ignored argument for interface compatibility
+
+        Returns
+        -------
+        self
+        """
+        _, num_channels, _ = X.shape
+        kernels = torch.from_numpy(_KERNELS).repeat(num_channels, 1, 1)
+        self.register_buffer("kernels", kernels)
+
+        (
+            num_channels_per_combination,
+            channel_indices,
+            dilations,
+            num_features_per_dilation,
+            biases,
+        ) = _fit(X, self.num_features, self.max_dilations_per_kernel, self.random_state)
+
+        self.register_buffer(
+            "num_channels_per_combination",
+            torch.from_numpy(num_channels_per_combination),
+        )
+        self.register_buffer("channel_indices", torch.from_numpy(channel_indices))
+        self.register_buffer("dilations", torch.from_numpy(dilations))
+        self.register_buffer(
+            "num_features_per_dilation",
+            torch.from_numpy(num_features_per_dilation),
+        )
+        self.register_buffer("biases", torch.from_numpy(biases))
+
+        return self
+
+    def forward(self, x):
+        _, num_channels, _ = x.shape
+
+        features = []
+        feature_index_start = 0
+        combination_index = 0
+        num_channels_start = 0
+
+        for i in range(len(self.dilations)):
+            dilation = self.dilations[i].item()
+            padding = ((9 - 1) * dilation) // 2
+            num_features_this_dilation = self.num_features_per_dilation[i].item()
+
+            C = F.conv1d(
+                x, self.kernels, padding=padding, dilation=dilation, groups=num_channels
+            )
+            C = C.view(-1, num_channels, 84, C.shape[-1])
+
+            for j in range(84):
+                feature_index_end = feature_index_start + num_features_this_dilation
+                num_channels_this_combination = self.num_channels_per_combination[
+                    combination_index
+                ].item()
+                num_channels_end = num_channels_start + num_channels_this_combination
+                channels_this_combination = self.channel_indices[
+                    num_channels_start:num_channels_end
+                ]
+
+                C_sum = torch.sum(C[:, channels_this_combination, j, :], dim=1)
+
+                biases_this_kernel = self.biases[feature_index_start:feature_index_end]
+
+                if (i + j) % 2 == 0:
+                    ppv = self.ppv(C_sum.unsqueeze(-1), biases_this_kernel)
+                else:
+                    ppv = self.ppv(
+                        C_sum[:, padding:-padding].unsqueeze(-1), biases_this_kernel
+                    )
+                features.append(ppv)
+
+                feature_index_start = feature_index_end
+                combination_index += 1
+                num_channels_start = num_channels_end
+
+        return torch.cat(features, dim=1)
+
+    @abc.abstractmethod
+    def ppv(self, x, biases):
+        raise NotImplementedError
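The 84 fixed kernels built by `_INDICES`/`_KERNELS` above enumerate every way of placing three weights of 2 among nine positions of -1 (C(9, 3) = 84), so each kernel sums to zero. A standalone sketch (not part of the patch) that rebuilds the constants and checks this:

    import numpy as np
    from itertools import combinations

    # Same construction as _INDICES/_KERNELS in base.py.
    indices = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
    kernels = np.full((84, 1, 9), -1.0, dtype=np.float32)
    kernels[np.arange(84)[:, np.newaxis], 0, indices] = 2.0

    assert kernels.shape == (84, 1, 9)
    # Three weights of 2 and six of -1: every kernel sums to zero, so the
    # convolution is invariant to constant offsets of the input.
    assert np.allclose(kernels.sum(axis=-1), 0.0)

`fit` registers these kernels (tiled per channel) and the `_fit`-derived dilations and biases via `register_buffer`, so they move with the module across devices and are saved in `state_dict`.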
diff --git a/src/fluffyrocket/fluffyrocket.py b/src/fluffyrocket/fluffyrocket.py
new file mode 100644
index 0000000..0841ee9
--- /dev/null
+++ b/src/fluffyrocket/fluffyrocket.py
@@ -0,0 +1,60 @@
+"""PyTorch implementation of MiniRocket with soft PPV."""
+
+import torch
+
+from .base import MiniRocketBase
+
+__all__ = [
+    "FluffyRocket",
+]
+
+
+class FluffyRocket(MiniRocketBase):
+    """PyTorch MiniRocket with soft PPV.
+
+    Parameters
+    ----------
+    sharpness : float, default=10.0
+        Sharpness parameter for the sigmoid function used to compute soft PPV.
+    num_features : int, default=10,000
+    max_dilations_per_kernel : int, default=32
+    random_state : int, default=None
+
+    Examples
+    --------
+    >>> from aeon.datasets import load_unit_test
+    >>> import torch
+    >>> from fluffyrocket import FluffyRocket
+    >>> from fluffyrocket._minirocket import fit, transform
+    >>> X, _ = load_unit_test()
+    >>> X = X.astype("float32")
+    >>> trf_original = transform(X, fit(X, num_features=84, seed=42))
+
+    A small sharpness gives a smoother approximation of PPV, with
+    better-behaved gradients.
+
+    >>> fluffyrocket = FluffyRocket(1.0, num_features=84, random_state=42).fit(X)
+    >>> trf_torch = fluffyrocket(torch.from_numpy(X))
+    >>> torch.allclose(torch.from_numpy(trf_original), trf_torch)
+    False
+
+    A large sharpness approximates the hard PPV used in the original
+    MiniRocket.
+
+    >>> fluffyrocket = FluffyRocket(1000, num_features=84, random_state=42).fit(X)
+    >>> trf_torch = fluffyrocket(torch.from_numpy(X))
+    >>> torch.allclose(torch.from_numpy(trf_original), trf_torch)
+    True
+    """
+
+    def __init__(
+        self,
+        sharpness=10.0,
+        num_features=10_000,
+        max_dilations_per_kernel=32,
+        random_state=None,
+    ):
+        super().__init__(num_features, max_dilations_per_kernel, random_state)
+        self.sharpness = sharpness
+
+    def ppv(self, x, biases):
+        return torch.sigmoid(self.sharpness * (x - biases)).mean(1)
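`FluffyRocket.ppv` replaces the indicator `(x > biases)` with `sigmoid(sharpness * (x - biases))`; as `sharpness` grows, the sigmoid approaches the step function and recovers the hard PPV. A minimal sketch (tensor values made up for illustration) of the limit and of the gradients that motivate the soft variant:

    import torch

    # Hypothetical activations shaped (batch, length, 1), matching what
    # MiniRocketBase.forward passes into ppv().
    x = torch.tensor([[[-1.0], [0.5], [2.0]]])
    biases = torch.tensor([0.0])

    hard = (x > biases).float().mean(1)  # 2 of 3 values exceed the bias: 2/3
    for sharpness in (1.0, 10.0, 1000.0):
        soft = torch.sigmoid(sharpness * (x - biases)).mean(1)
        print(sharpness, soft.item(), hard.item())
    # Prints roughly 0.59, then 0.664, then 0.6667: the soft PPV converges
    # to the hard PPV of 2/3 as sharpness grows.

    # Unlike the hard indicator, the soft PPV carries gradients:
    x.requires_grad_(True)
    torch.sigmoid(10.0 * (x - biases)).mean(1).sum().backward()
    assert x.grad is not None

This is also why the doctest with sharpness 1000 reproduces the reference transform (`torch.allclose` is True) while sharpness 1.0 does not.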
diff --git a/src/fluffyrocket/minirocket.py b/src/fluffyrocket/minirocket.py
index 97acd8a..42cfcc4 100644
--- a/src/fluffyrocket/minirocket.py
+++ b/src/fluffyrocket/minirocket.py
@@ -1,25 +1,13 @@
 """PyTorch implementation of the original MiniRocket with hard PPV."""
 
-from itertools import combinations
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch.nn import Module
-
-from ._minirocket import fit as _fit
+from .base import MiniRocketBase
 
 __all__ = [
     "MiniRocket",
 ]
 
-_INDICES = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
-_KERNELS = np.full((84, 1, 9), -1.0, dtype=np.float32)
-_KERNELS[np.arange(84)[:, np.newaxis], 0, _INDICES] = 2.0
-
-
-class MiniRocket(Module):
+class MiniRocket(MiniRocketBase):
     """PyTorch MiniRocket [1]_ with hard PPV.
 
     This class aims to exactly reproduce transformation result from
@@ -53,100 +41,5 @@ class MiniRocket(Module):
     True
     """
 
-    def __init__(
-        self,
-        num_features=10_000,
-        max_dilations_per_kernel=32,
-        random_state=None,
-    ):
-        super().__init__()
-        self.num_features = num_features
-        self.max_dilations_per_kernel = max_dilations_per_kernel
-        self.random_state = random_state
-
-    def fit(self, X, y=None):
-        """Fit dilation and biases.
-
-        Parameters
-        ----------
-        X : array of shape ``(num_examples, num_channels, input_length)``.
-            Data type must be float32.
-        y : ignored argument for interface compatibility
-
-        Returns
-        -------
-        self
-        """
-        _, num_channels, _ = X.shape
-        kernels = torch.from_numpy(_KERNELS).repeat(num_channels, 1, 1)
-        self.register_buffer("kernels", kernels)
-
-        (
-            num_channels_per_combination,
-            channel_indices,
-            dilations,
-            num_features_per_dilation,
-            biases,
-        ) = _fit(X, self.num_features, self.max_dilations_per_kernel, self.random_state)
-
-        self.register_buffer(
-            "num_channels_per_combination",
-            torch.from_numpy(num_channels_per_combination),
-        )
-        self.register_buffer("channel_indices", torch.from_numpy(channel_indices))
-        self.register_buffer("dilations", torch.from_numpy(dilations))
-        self.register_buffer(
-            "num_features_per_dilation",
-            torch.from_numpy(num_features_per_dilation),
-        )
-        self.register_buffer("biases", torch.from_numpy(biases))
-
-        return self
-
-    def forward(self, x):
-        _, num_channels, _ = x.shape
-
-        features = []
-        feature_index_start = 0
-        combination_index = 0
-        num_channels_start = 0
-
-        for i in range(len(self.dilations)):
-            dilation = self.dilations[i].item()
-            padding = ((9 - 1) * dilation) // 2
-            num_features_this_dilation = self.num_features_per_dilation[i].item()
-
-            C = F.conv1d(
-                x, self.kernels, padding=padding, dilation=dilation, groups=num_channels
-            )
-            C = C.view(-1, num_channels, 84, C.shape[-1])
-
-            for j in range(84):
-                feature_index_end = feature_index_start + num_features_this_dilation
-                num_channels_this_combination = self.num_channels_per_combination[
-                    combination_index
-                ].item()
-                num_channels_end = num_channels_start + num_channels_this_combination
-                channels_this_combination = self.channel_indices[
-                    num_channels_start:num_channels_end
-                ]
-
-                C_sum = torch.sum(C[:, channels_this_combination, j, :], dim=1)
-
-                biases_this_kernel = self.biases[feature_index_start:feature_index_end]
-
-                if (i + j) % 2 == 0:
-                    ppv = (C_sum.unsqueeze(-1) > biases_this_kernel).float().mean(1)
-                else:
-                    ppv = (
-                        (C_sum[:, padding:-padding].unsqueeze(-1) > biases_this_kernel)
-                        .float()
-                        .mean(1)
-                    )
-                features.append(ppv)
-
-                feature_index_start = feature_index_end
-                combination_index += 1
-                num_channels_start = num_channels_end
-
-        return torch.cat(features, dim=1)
+    def ppv(self, x, biases):
+        return (x > biases).float().mean(1)
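With `ppv` as the single point of variation, other PPV surrogates drop in with a few lines. A hypothetical sketch (`ClampedRocket` is illustrative only, not part of this patch):

    import torch

    from fluffyrocket.base import MiniRocketBase


    class ClampedRocket(MiniRocketBase):
        """Hypothetical subclass: piecewise-linear surrogate for PPV."""

        def ppv(self, x, biases):
            # Linear ramp from 0 to 1 over a unit-wide band around each
            # bias; differentiable almost everywhere, unlike the hard
            # indicator in MiniRocket.ppv.
            return torch.clamp(x - biases + 0.5, 0.0, 1.0).mean(1)

Fitting and transforming work as in the doctests above: `ClampedRocket(num_features=84).fit(X)` followed by a forward pass on a float32 tensor.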