Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/fluffyrocket/fluffyrocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class FluffyRocket(MiniRocketBase):
Sharpness parameter for the sigmoid function used to compute soft PPV.
num_features : int, default=10,000
max_dilations_per_kernel : int, default=32
learnable : bool, default=False
Whether the sharpness parameter is learnable.
random_state : int, default=None

Examples
Expand Down Expand Up @@ -51,10 +53,14 @@ def __init__(
sharpness=10.0,
num_features=10_000,
max_dilations_per_kernel=32,
learnable=False,
random_state=None,
):
super().__init__(num_features, max_dilations_per_kernel, random_state)
self.sharpness = sharpness
if learnable:
self.sharpness = torch.nn.Parameter(torch.tensor(sharpness))
else:
self.sharpness = sharpness

def ppv(self, x, biases):
return torch.sigmoid(self.sharpness * (x - biases)).mean(1)