diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 67ec7f8..d5140fe 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -3,8 +3,6 @@ name: CI on: push: branches: [main, dev] - pull_request: - branches: [main] jobs: test: diff --git a/cka/cka.py b/cka/cka.py index 3bbd327..b09c4de 100644 --- a/cka/cka.py +++ b/cka/cka.py @@ -188,7 +188,7 @@ def _compute_cka_matrix( ) -> torch.Tensor: denominator = torch.sqrt(torch.clamp(hsic_xx.unsqueeze(1) * hsic_yy.unsqueeze(0), min=0.0)) denominator = torch.where(denominator == 0, 1e-6, denominator) - return hsic_xy / denominator + return torch.clamp(hsic_xy / denominator, min=0.0, max=1.0) def _make_hook( self,