4 changes: 3 additions & 1 deletion .flake8
@@ -5,4 +5,6 @@ max-line-length = 88
extend-ignore =
E203,
D,
exclude = src/fluffyrocket/_minirocket.py
exclude =
.github,
src/fluffyrocket/_minirocket.py
6 changes: 5 additions & 1 deletion src/fluffyrocket/__init__.py
@@ -1,5 +1,9 @@
"""PyTorch implementation of MiniRocket with soft PPV."""

from .fluffyrocket import FluffyRocket
from .minirocket import MiniRocket

__all__ = ["MiniRocket"]
__all__ = [
"MiniRocket",
"FluffyRocket",
]
132 changes: 132 additions & 0 deletions src/fluffyrocket/base.py
@@ -0,0 +1,132 @@
"""Base class for MiniRocket-like transformers."""

import abc
from itertools import combinations

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Module

from ._minirocket import fit as _fit

_INDICES = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
_KERNELS = np.full((84, 1, 9), -1.0, dtype=np.float32)
_KERNELS[np.arange(84)[:, np.newaxis], 0, _INDICES] = 2.0

__all__ = [
"MiniRocketBase",
]


class MiniRocketBase(Module, abc.ABC):
"""Base class for MiniRocket-like transformers.

Subclasses should implement :meth:`ppv`.

Parameters
----------
num_features : int, default=10,000
max_dilations_per_kernel : int, default=32
random_state : int, default=None
"""

def __init__(
self,
num_features=10_000,
max_dilations_per_kernel=32,
random_state=None,
):
super().__init__()
self.num_features = num_features
self.max_dilations_per_kernel = max_dilations_per_kernel
self.random_state = random_state

def fit(self, X, y=None):
"""Fit dilation and biases.

Parameters
----------
X : array of shape ``(num_examples, num_channels, input_length)``.
Data type must be float32.
y : ignored argument for interface compatibility

Returns
-------
self
"""
_, num_channels, _ = X.shape
kernels = torch.from_numpy(_KERNELS).repeat(num_channels, 1, 1)
self.register_buffer("kernels", kernels)

(
num_channels_per_combination,
channel_indices,
dilations,
num_features_per_dilation,
biases,
) = _fit(X, self.num_features, self.max_dilations_per_kernel, self.random_state)

self.register_buffer(
"num_channels_per_combination",
torch.from_numpy(num_channels_per_combination),
)
self.register_buffer("channel_indices", torch.from_numpy(channel_indices))
self.register_buffer("dilations", torch.from_numpy(dilations))
self.register_buffer(
"num_features_per_dilation",
torch.from_numpy(num_features_per_dilation),
)
self.register_buffer("biases", torch.from_numpy(biases))

return self

def forward(self, x):
_, num_channels, _ = x.shape

features = []
feature_index_start = 0
combination_index = 0
num_channels_start = 0

for i in range(len(self.dilations)):
dilation = self.dilations[i].item()
padding = ((9 - 1) * dilation) // 2
num_features_this_dilation = self.num_features_per_dilation[i].item()

C = F.conv1d(
x, self.kernels, padding=padding, dilation=dilation, groups=num_channels
)
C = C.view(-1, num_channels, 84, C.shape[-1])

for j in range(84):
feature_index_end = feature_index_start + num_features_this_dilation
num_channels_this_combination = self.num_channels_per_combination[
combination_index
].item()
num_channels_end = num_channels_start + num_channels_this_combination
channels_this_combination = self.channel_indices[
num_channels_start:num_channels_end
]

C_sum = torch.sum(C[:, channels_this_combination, j, :], dim=1)

biases_this_kernel = self.biases[feature_index_start:feature_index_end]

if (i + j) % 2 == 0:
ppv = self.ppv(C_sum.unsqueeze(-1), biases_this_kernel)
else:
ppv = self.ppv(
C_sum[:, padding:-padding].unsqueeze(-1), biases_this_kernel
)
features.append(ppv)

feature_index_start = feature_index_end
combination_index += 1
num_channels_start = num_channels_end

return torch.cat(features, dim=1)

@abc.abstractmethod
def ppv(self, x, biases):
raise NotImplementedError
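
The module-level constants at the top of base.py encode the fixed MiniRocket kernel set: every length-9 kernel carries the weight 2 at one of the C(9, 3) = 84 three-position combinations and -1 elsewhere, so each kernel sums to zero. A minimal NumPy-only sketch, mirroring _INDICES and _KERNELS above purely for review purposes, that checks both properties:

from itertools import combinations

import numpy as np

indices = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
kernels = np.full((84, 1, 9), -1.0, dtype=np.float32)
kernels[np.arange(84)[:, np.newaxis], 0, indices] = 2.0

assert indices.shape == (84, 3)           # C(9, 3) == 84 combinations
assert np.allclose(kernels.sum(-1), 0.0)  # 3 * 2 + 6 * (-1) == 0 for every kernel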
60 changes: 60 additions & 0 deletions src/fluffyrocket/fluffyrocket.py
@@ -0,0 +1,60 @@
"""PyTorch implementation of MiniRocket with soft PPV."""

import torch

from .base import MiniRocketBase

__all__ = [
"FluffyRocket",
]


class FluffyRocket(MiniRocketBase):
"""PyTorch MiniRocket with soft PPV.

Parameters
----------
sharpness : float, default=10.0
Sharpness parameter for the sigmoid function used to compute soft PPV.
num_features : int, default=10,000
max_dilations_per_kernel : int, default=32
random_state : int, default=None

Examples
--------
>>> from aeon.datasets import load_unit_test
>>> import torch
>>> from fluffyrocket import FluffyRocket
>>> from fluffyrocket._minirocket import fit, transform
>>> X, _ = load_unit_test()
>>> X = X.astype("float32")
>>> trf_original = transform(X, fit(X, num_features=84, seed=42))

Small sharpness gives a smoother approximation of PPV, with better-behaved
gradients.

>>> fluffyrocket = FluffyRocket(1.0, num_features=84, random_state=42).fit(X)
>>> trf_torch = fluffyrocket(torch.from_numpy(X))
>>> torch.allclose(torch.from_numpy(trf_original), trf_torch)
False

Large sharpness approximates the hard PPV used in the original MiniRocket.

>>> fluffyrocket = FluffyRocket(1000, num_features=84, random_state=42).fit(X)
>>> trf_torch = fluffyrocket(torch.from_numpy(X))
>>> torch.allclose(torch.from_numpy(trf_original), trf_torch)
True
"""

def __init__(
self,
sharpness=10.0,
num_features=10_000,
max_dilations_per_kernel=32,
random_state=None,
):
super().__init__(num_features, max_dilations_per_kernel, random_state)
self.sharpness = sharpness

def ppv(self, x, biases):
return torch.sigmoid(self.sharpness * (x - biases)).mean(1)
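
A hedged sketch of why the soft PPV matters for training: the sigmoid surrogate used in FluffyRocket.ppv is differentiable with respect to its input, while the hard comparison in MiniRocket has zero gradient almost everywhere. The shapes and values below are illustrative assumptions, not taken from this PR:

import torch

sharpness = 10.0
x = torch.randn(4, 100, 1, requires_grad=True)  # (batch, time, 1), illustrative shape
biases = torch.linspace(-0.5, 0.5, 8)           # eight illustrative bias values

soft_ppv = torch.sigmoid(sharpness * (x - biases)).mean(1)  # FluffyRocket-style surrogate
hard_ppv = (x > biases).float().mean(1)                     # original hard PPV

soft_ppv.sum().backward()
assert x.grad is not None  # gradients flow through the soft PPV
# The hard PPV gives no useful gradient: d/dx of (x > b) is zero almost everywhere.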
115 changes: 4 additions & 111 deletions src/fluffyrocket/minirocket.py
@@ -1,25 +1,13 @@
"""PyTorch implementation of the original MiniRocket with hard PPV."""

from itertools import combinations

import numpy as np
import torch
import torch.nn.functional as F
from torch.nn import Module

from ._minirocket import fit as _fit
from .base import MiniRocketBase

__all__ = [
"MiniRocket",
]


_INDICES = np.array(list(combinations(np.arange(9), 3)), dtype=np.int32)
_KERNELS = np.full((84, 1, 9), -1.0, dtype=np.float32)
_KERNELS[np.arange(84)[:, np.newaxis], 0, _INDICES] = 2.0


class MiniRocket(Module):
class MiniRocket(MiniRocketBase):
"""PyTorch MiniRocket [1]_ with hard PPV.

This class aims to exactly reproduce transformation result from
@@ -53,100 +41,5 @@ class MiniRocket(Module):
True
"""

def __init__(
self,
num_features=10_000,
max_dilations_per_kernel=32,
random_state=None,
):
super().__init__()
self.num_features = num_features
self.max_dilations_per_kernel = max_dilations_per_kernel
self.random_state = random_state

def fit(self, X, y=None):
"""Fit dilation and biases.

Parameters
----------
X : array of shape ``(num_examples, num_channels, input_length)``.
Data type must be float32.
y : ignored argument for interface compatibility

Returns
-------
self
"""
_, num_channels, _ = X.shape
kernels = torch.from_numpy(_KERNELS).repeat(num_channels, 1, 1)
self.register_buffer("kernels", kernels)

(
num_channels_per_combination,
channel_indices,
dilations,
num_features_per_dilation,
biases,
) = _fit(X, self.num_features, self.max_dilations_per_kernel, self.random_state)

self.register_buffer(
"num_channels_per_combination",
torch.from_numpy(num_channels_per_combination),
)
self.register_buffer("channel_indices", torch.from_numpy(channel_indices))
self.register_buffer("dilations", torch.from_numpy(dilations))
self.register_buffer(
"num_features_per_dilation",
torch.from_numpy(num_features_per_dilation),
)
self.register_buffer("biases", torch.from_numpy(biases))

return self

def forward(self, x):
_, num_channels, _ = x.shape

features = []
feature_index_start = 0
combination_index = 0
num_channels_start = 0

for i in range(len(self.dilations)):
dilation = self.dilations[i].item()
padding = ((9 - 1) * dilation) // 2
num_features_this_dilation = self.num_features_per_dilation[i].item()

C = F.conv1d(
x, self.kernels, padding=padding, dilation=dilation, groups=num_channels
)
C = C.view(-1, num_channels, 84, C.shape[-1])

for j in range(84):
feature_index_end = feature_index_start + num_features_this_dilation
num_channels_this_combination = self.num_channels_per_combination[
combination_index
].item()
num_channels_end = num_channels_start + num_channels_this_combination
channels_this_combination = self.channel_indices[
num_channels_start:num_channels_end
]

C_sum = torch.sum(C[:, channels_this_combination, j, :], dim=1)

biases_this_kernel = self.biases[feature_index_start:feature_index_end]

if (i + j) % 2 == 0:
ppv = (C_sum.unsqueeze(-1) > biases_this_kernel).float().mean(1)
else:
ppv = (
(C_sum[:, padding:-padding].unsqueeze(-1) > biases_this_kernel)
.float()
.mean(1)
)
features.append(ppv)

feature_index_start = feature_index_end
combination_index += 1
num_channels_start = num_channels_end

return torch.cat(features, dim=1)
def ppv(self, x, biases):
return (x > biases).float().mean(1)
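
After this refactor the only code a subclass has to provide is the ppv hook; kernel construction, fit, and forward all live in MiniRocketBase. As a hedged illustration of the extension point (the HingePPV name and its clamp-based surrogate are hypothetical, not part of this PR):

import torch

from fluffyrocket.base import MiniRocketBase


class HingePPV(MiniRocketBase):
    """Hypothetical subclass: piecewise-linear PPV surrogate."""

    def ppv(self, x, biases):
        # Clamp a shifted margin into [0, 1] instead of applying a sigmoid.
        return torch.clamp(0.5 + (x - biases), 0.0, 1.0).mean(1)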