
Releases: datthinh1801/CKA-pytorch

v0.1.2

03 Jul 16:34


v0.1.1

Clean up redundant reference sub-repository.

v0.1.1

03 Jul 16:31


v0.1.1

✨ Implement CKA Matrix Visualization ✨

We're excited to announce a new way to analyze and compare your deep learning models. This release introduces the plot_cka function, which renders the CKA matrix between two models as a customizable heatmap.

What's New?

  • Intuitive CKA Matrix Plotting: Visualize the Centered Kernel Alignment (CKA) matrix between any two models as a heatmap, and see at a glance which layers learn similar representations and which diverge.
  • Rich Customization: The plot_cka function offers options for:
    • 🎨 Colormaps: Choose the palette used to color the heatmap.
    • 🏷️ Axis Labels & Annotations: Add context and detail to your plots.
    • 📐 Layout Control: Fine-tune the plot's layout and presentation.
    • 💾 Effortless Saving: Save plots to disk through dedicated saving parameters.
    • 🔍 Granular Display Options: Control tick-label visibility and display only one half of the heatmap for focused analysis.
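To make concrete what each heatmap cell represents, here is a minimal NumPy sketch of linear CKA (the formulation of Kornblith et al., 2019), computed for every pair of layers. It is an illustration under stated assumptions, not the code behind plot_cka or CKACalculator: the helpers `linear_cka` and `cka_matrix` are hypothetical, and each layer's activations are assumed to be already flattened to shape (n_examples, n_features) for the same inputs.

```python
# Illustrative sketch only; hypothetical helpers, not the cka-pytorch implementation.
import numpy as np


def linear_cka(x: np.ndarray, y: np.ndarray) -> float:
    """Linear CKA between activations x (n, p1) and y (n, p2) for the same n inputs."""
    # Center every feature so that constant offsets do not affect the score.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    # CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    self_x = np.linalg.norm(x.T @ x, ord="fro")
    self_y = np.linalg.norm(y.T @ y, ord="fro")
    return float(cross / (self_x * self_y))


def cka_matrix(layers_a: list, layers_b: list) -> np.ndarray:
    """Entry (i, j) is the CKA between layer i of model A and layer j of model B."""
    result = np.zeros((len(layers_a), len(layers_b)))
    for i, acts_a in enumerate(layers_a):
        for j, acts_b in enumerate(layers_b):
            result[i, j] = linear_cka(acts_a, acts_b)
    return result


# Usage with random stand-in activations (3 "layers", 50 inputs each).
rng = np.random.default_rng(0)
layers = [rng.normal(size=(50, width)) for width in (8, 16, 32)]
print(np.round(cka_matrix(layers, layers), 3))
```

Entry (i, j) of the result is what one cell of the heatmap would show. Comparing a model against itself gives ones on the diagonal, and linear CKA is unchanged by adding a constant to, or uniformly scaling, a layer's activations.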

This feature helps researchers and practitioners see, layer by layer, where two models' representations agree and where they diverge, informing architectural decisions and deepening the understanding of neural network representations. Dive in and explore the connections within your models!

v0.1.0

03 Jul 15:51


CKA-PyTorch

🚀 What's New

This is the initial release of cka-pytorch! This release focuses on providing a simple, efficient, and easy-to-use tool for researchers and developers to compare the representations learned by neural networks.


✨ Key Features

  • GPU Accelerated: Leverages the power of GPUs for significantly faster CKA calculations.

  • Memory Efficient: Computes CKA on the fly over mini-batches, so large intermediate feature representations never need to be cached.

  • Flexible and Easy to Use: A simple and intuitive API that can be used with any PyTorch models and dataloaders.
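The memory-efficient approach described in the Key Features above can be sketched with the minibatch CKA estimator of Nguyen, Raghu & Kornblith (2021): per-batch unbiased HSIC terms (Song et al., 2012) are summed as batches arrive, so activations from earlier batches never have to be kept. Whether cka-pytorch uses exactly this estimator is an assumption; `unbiased_hsic` and `MinibatchLinearCKA` are hypothetical names, and the sketch uses NumPy and a linear kernel for brevity.

```python
# Illustrative sketch of minibatch CKA; hypothetical names, not cka-pytorch's internals.
import numpy as np


def unbiased_hsic(k_x: np.ndarray, k_y: np.ndarray) -> float:
    """Unbiased HSIC estimator (Song et al., 2012) for n x n kernel matrices, n > 3."""
    n = k_x.shape[0]
    # The estimator works on kernel matrices with their diagonals set to zero.
    k_x = k_x.copy()
    k_y = k_y.copy()
    np.fill_diagonal(k_x, 0.0)
    np.fill_diagonal(k_y, 0.0)
    ones = np.ones(n)
    trace_term = np.trace(k_x @ k_y)
    sum_term = (ones @ k_x @ ones) * (ones @ k_y @ ones) / ((n - 1) * (n - 2))
    cross_term = 2.0 / (n - 2) * (ones @ k_x @ k_y @ ones)
    return float((trace_term + sum_term - cross_term) / (n * (n - 3)))


class MinibatchLinearCKA:
    """Running minibatch CKA for one pair of layers; only three scalars are stored."""

    def __init__(self) -> None:
        self.hsic_xy = 0.0
        self.hsic_xx = 0.0
        self.hsic_yy = 0.0

    def update(self, x: np.ndarray, y: np.ndarray) -> None:
        """x, y: activations of the two layers for the same batch, shape (batch, features)."""
        k_x = x @ x.T  # linear kernel for this batch only
        k_y = y @ y.T
        self.hsic_xy += unbiased_hsic(k_x, k_y)
        self.hsic_xx += unbiased_hsic(k_x, k_x)
        self.hsic_yy += unbiased_hsic(k_y, k_y)

    def compute(self) -> float:
        # Averaging over batches would divide numerator and denominator by the
        # same count, so plain sums give the same CKA value.
        return self.hsic_xy / float(np.sqrt(self.hsic_xx * self.hsic_yy))


# Usage: feed batches one at a time; earlier batches can be discarded.
rng = np.random.default_rng(0)
tracker = MinibatchLinearCKA()
for _ in range(5):
    x = rng.normal(size=(32, 64))
    y = x @ np.linalg.qr(rng.normal(size=(64, 64)))[0]  # rotated copy of x
    tracker.update(x, y)
print(round(tracker.compute(), 6))
```

Memory use stays constant per layer pair regardless of dataset size, because only the three running HSIC sums survive between batches. Each batch must contain more than three examples for the unbiased estimator to be defined.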


📦 Installation

You can install cka-pytorch using pip:

```bash
pip install cka-pytorch
```

Quick Start

Here's a quick example of how to use cka-pytorch to calculate the CKA matrix between two models:

```python
import torch
import seaborn as sns
import matplotlib.pyplot as plt

from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader, TensorDataset

from cka_pytorch.cka import CKACalculator


# 1. Define your models and dataloader (.cuda() requires a CUDA-capable GPU)
# `weights=` replaces the deprecated `pretrained=True` argument (torchvision >= 0.13);
# .eval() keeps BatchNorm statistics fixed while features are extracted.
model1 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()
model2 = resnet18(weights=ResNet18_Weights.DEFAULT).cuda().eval()  # Or a different model

# Create a dummy dataloader for demonstration
dummy_data = torch.randn(100, 3, 224, 224)
dummy_labels = torch.randint(0, 10, (100,))
dummy_dataset = TensorDataset(dummy_data, dummy_labels)
dataloader = DataLoader(dummy_dataset, batch_size=32)

# 2. Initialize the CKACalculator
calculator = CKACalculator(model1, model2, dataloader)

# 3. Calculate the CKA matrix
cka_matrix = calculator.calculate_cka_matrix()

print("CKA Matrix:")
print(cka_matrix)

# 4. Plot the CKA matrix as a heatmap
fig, ax = plt.subplots(figsize=(10, 10))
sns.heatmap(cka_matrix.cpu().numpy(), ax=ax)
plt.show()
```

🤝 Community

I welcome feedback and contributions! If you have questions or suggestions, or would like to contribute to the project, please open an issue on the GitHub repository (https://github.com/datthinh1801/CKA.pytorch/issues).