diff --git a/src/fluffyrocket/fluffyrocket.py b/src/fluffyrocket/fluffyrocket.py index 0841ee9..6eb035f 100644 --- a/src/fluffyrocket/fluffyrocket.py +++ b/src/fluffyrocket/fluffyrocket.py @@ -18,6 +18,8 @@ class FluffyRocket(MiniRocketBase): Sharpness parameter for the sigmoid function used to compute soft PPV. num_features : int, default=10,000 max_dilations_per_kernel : int, default=32 + learnable : bool, default=False + Whether the sharpness parameter is learnable. random_state : int, default=None Examples @@ -51,10 +53,14 @@ def __init__( sharpness=10.0, num_features=10_000, max_dilations_per_kernel=32, + learnable=False, random_state=None, ): super().__init__(num_features, max_dilations_per_kernel, random_state) - self.sharpness = sharpness + if learnable: + self.sharpness = torch.nn.Parameter(torch.tensor(sharpness)) + else: + self.sharpness = sharpness def ppv(self, x, biases): return torch.sigmoid(self.sharpness * (x - biases)).mean(1)