From 4f5ed6b24e51d9a549857d7c2d5f0ce95c00da1e Mon Sep 17 00:00:00 2001 From: Jisoo Song Date: Mon, 1 Dec 2025 16:01:48 +0900 Subject: [PATCH] feat: implement FluffyRocket Make sharpness learnable --- src/fluffyrocket/fluffyrocket.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/fluffyrocket/fluffyrocket.py b/src/fluffyrocket/fluffyrocket.py index 0841ee9..6eb035f 100644 --- a/src/fluffyrocket/fluffyrocket.py +++ b/src/fluffyrocket/fluffyrocket.py @@ -18,6 +18,8 @@ class FluffyRocket(MiniRocketBase): Sharpness parameter for the sigmoid function used to compute soft PPV. num_features : int, default=10,000 max_dilations_per_kernel : int, default=32 + learnable : bool, default=False + Whether the sharpness parameter is learnable. random_state : int, default=None Examples @@ -51,10 +53,14 @@ def __init__( sharpness=10.0, num_features=10_000, max_dilations_per_kernel=32, + learnable=False, random_state=None, ): super().__init__(num_features, max_dilations_per_kernel, random_state) - self.sharpness = sharpness + if learnable: + self.sharpness = torch.nn.Parameter(torch.tensor(sharpness)) + else: + self.sharpness = sharpness def ppv(self, x, biases): return torch.sigmoid(self.sharpness * (x - biases)).mean(1)