Skip to content
Merged

Dev #14

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Git
.git/
.gitignore

# Claude Code
.claude/
CLAUDE.md

# Virtual environments
.venv/
venv/
env/

# Python cache and build artifacts
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# Testing
tests/
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/

# IDE
.idea/
.vscode/
*.swp
*.swo
*~

# OS
.DS_Store
Thumbs.db

# Example outputs and images
*.png
*.jpg
*.jpeg
cka_checkpoint.pt

# Jupyter
.ipynb_checkpoints/
examples/

# mypy
.mypy_cache/

# ruff
.ruff_cache/

# Package manager
uv.lock

# Docker
Dockerfile
.dockerignore

# CI/CD
.github/
.python-version
33 changes: 33 additions & 0 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
name: CI

on:
push:
branches: [main, dev]
pull_request:
branches: [main]

jobs:
test:
name: Test (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
steps:
- name: Checkout repository
uses: actions/checkout@v6

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install pytest pytest-cov torchvision timm transformers

- name: Run tests
run: pytest tests/ -v --cov=cka --cov-report=term-missing
51 changes: 51 additions & 0 deletions .github/workflows/docker.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
name: Docker

on:
release:
types: [published]
workflow_dispatch:

env:
REGISTRY: ghcr.io
IMAGE_NAME: ryusudol/pytorch-cka

jobs:
build-and-push:
name: Build and Push Docker Image
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v6

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3

- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}

- name: Extract metadata (tags, labels)
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
type=raw,value=latest

- name: Build and push Docker image
uses: docker/build-push-action@v6
with:
context: .
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
13 changes: 3 additions & 10 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
# Reference repositories (for analysis only)
centered-kernel-alignment/
PyTorch-Model-Compare/

# Claude Code
.claude/

# Virtual environments
.venv/
venv/
Expand Down Expand Up @@ -59,8 +52,8 @@ cka_checkpoint.pt
# Jupyter
.ipynb_checkpoints/

# mypy
.mypy_cache/

.ruff_cache/
.claude/
CLAUDE.md
uv.lock
tests/
44 changes: 44 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
ARG PYTHON_VERSION=3.10
ARG PYTORCH_VERSION=2.0.0

# Build stage
FROM python:${PYTHON_VERSION}-slim AS builder

ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1

WORKDIR /app

RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
&& rm -rf /var/lib/apt/lists/*

# Create virtual environment
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

COPY pyproject.toml README.md LICENSE ./
COPY cka/ ./cka/

ARG PYTORCH_VERSION
RUN pip install torch==${PYTORCH_VERSION} --index-url https://download.pytorch.org/whl/cpu && \
pip install .

# Runtime stage
FROM python:${PYTHON_VERSION}-slim

ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1 \
PATH="/opt/venv/bin:$PATH"

WORKDIR /app

# Copy virtual environment from builder
COPY --from=builder /opt/venv /opt/venv

RUN useradd --create-home --shell /bin/bash cka
USER cka

CMD ["python", "-c", "from cka import CKA; print('pytorch-cka is ready! Import with: from cka import CKA')"]
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ pip install pytorch-cka
# Using uv
uv add pytorch-cka

# Using docker
docker pull ghcr.io/ryusudol/pytorch-cka

# From source
git clone https://github.com/ryusudol/Centered-Kernel-Alignment
cd pytorch-cka
Expand All @@ -36,7 +39,7 @@ uv sync # or: pip install -e .

```python
from torch.utils.data import DataLoader
from pytorch_cka import CKA
from cka import CKA

pretrained_model = ... # e.g. pretrained ResNet-18
fine_tuned_model = ... # e.g. fine-tuned ResNet-18
Expand Down Expand Up @@ -70,7 +73,7 @@ with cka:
**Heatmap**

```python
from pytorch_cka import plot_cka_heatmap
from cka import plot_cka_heatmap

fig, ax = plot_cka_heatmap(
cka_matrix,
Expand Down Expand Up @@ -100,7 +103,7 @@ fig, ax = plot_cka_heatmap(
**Trend Plot**

```python
from pytorch_cka import plot_cka_trend
from cka import plot_cka_trend

# Plot diagonal (self-similarity across layers)
diagonal = torch.diag(matrix)
Expand All @@ -126,7 +129,7 @@ fig, ax = plot_cka_trend(
<!-- **Side-by-Side Comparison**

```python
from pytorch_cka import plot_cka_comparison
from cka import plot_cka_comparison

fig, axes = plot_cka_comparison(
matrices=[matrix1, matrix2, matrix3],
Expand Down
2 changes: 1 addition & 1 deletion cka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
representations using Centered Kernel Alignment (CKA).

Example:
>>> from pytorch_cka import CKA
>>> from cka import CKA
>>>
>>> with CKA(model1, model2, model1_layers=["layer1", "layer2"]) as cka:
... matrix = cka.compare(dataloader)
Expand Down
15 changes: 11 additions & 4 deletions cka/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,12 +128,12 @@ def shorten_name(name: str, depth: int | None) -> str:

# Set tick labels
if layers1 is not None:
shortened = [shorten_name(l, layer_name_depth) for l in layers1]
shortened = [shorten_name(layer, layer_name_depth) for layer in layers1]
ax.set_yticks(range(n_layers1))
ax.set_yticklabels(shortened, fontsize=tick_fontsize)

if layers2 is not None:
shortened = [shorten_name(l, layer_name_depth) for l in layers2]
shortened = [shorten_name(layer, layer_name_depth) for layer in layers2]
ax.set_xticks(range(n_layers2))
ax.set_xticklabels(shortened, fontsize=tick_fontsize, rotation=45, ha="right")

Expand Down Expand Up @@ -197,7 +197,11 @@ def plot_cka_trend(
"""
# Normalize input to list of arrays
if isinstance(cka_values, (torch.Tensor, np.ndarray)):
arr = cka_values.detach().cpu().numpy() if isinstance(cka_values, torch.Tensor) else cka_values
arr = (
cka_values.detach().cpu().numpy()
if isinstance(cka_values, torch.Tensor)
else cka_values
)
if arr.ndim == 1:
cka_values = [arr]
else:
Expand Down Expand Up @@ -304,6 +308,7 @@ def plot_cka_trend_with_range(
Returns:
Tuple of (Figure, Axes).
"""

# Convert to numpy
def to_numpy(arr):
if isinstance(arr, torch.Tensor):
Expand Down Expand Up @@ -415,7 +420,7 @@ def plot_cka_comparison(
if figsize is None:
figsize = (5 * ncols, 4 * nrows)

fig, axes = plt.subplots(nrows, ncols, figsize=figsize, constrained_layout=True)
fig, axes = plt.subplots(nrows, ncols, figsize=figsize, constrained_layout=share_colorbar)
axes = np.atleast_2d(axes)

# Find global min/max for shared colorbar
Expand Down Expand Up @@ -458,6 +463,8 @@ def plot_cka_comparison(
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
fig.colorbar(sm, ax=axes, fraction=0.02, pad=0.02, label="CKA Similarity")
else:
fig.tight_layout()

if show:
plt.show()
Expand Down
7 changes: 1 addition & 6 deletions examples/basic_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,7 @@
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from pytorch_cka import CKA, plot_cka_heatmap, plot_cka_trend


# =============================================================================
# Define example models
# =============================================================================
from cka import CKA, plot_cka_heatmap, plot_cka_trend


class SimpleCNN(nn.Module):
Expand Down
Loading