Skip to content

Commit cda1e3b

Browse files
authored
Merge pull request #14 from ryusudol/dev
Dev
2 parents b03701f + 6d88391 commit cda1e3b

15 files changed

Lines changed: 1670 additions & 68 deletions

.dockerignore

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Git
2+
.git/
3+
.gitignore
4+
5+
# Claude Code
6+
.claude/
7+
CLAUDE.md
8+
9+
# Virtual environments
10+
.venv/
11+
venv/
12+
env/
13+
14+
# Python cache and build artifacts
15+
__pycache__/
16+
*.py[cod]
17+
*$py.class
18+
*.so
19+
.Python
20+
build/
21+
develop-eggs/
22+
dist/
23+
downloads/
24+
eggs/
25+
.eggs/
26+
lib/
27+
lib64/
28+
parts/
29+
sdist/
30+
var/
31+
wheels/
32+
*.egg-info/
33+
.installed.cfg
34+
*.egg
35+
36+
# Testing
37+
tests/
38+
.pytest_cache/
39+
.coverage
40+
htmlcov/
41+
.tox/
42+
.nox/
43+
44+
# IDE
45+
.idea/
46+
.vscode/
47+
*.swp
48+
*.swo
49+
*~
50+
51+
# OS
52+
.DS_Store
53+
Thumbs.db
54+
55+
# Example outputs and images
56+
*.png
57+
*.jpg
58+
*.jpeg
59+
cka_checkpoint.pt
60+
61+
# Jupyter
62+
.ipynb_checkpoints/
63+
examples/
64+
65+
# mypy
66+
.mypy_cache/
67+
68+
# ruff
69+
.ruff_cache/
70+
71+
# Package manager
72+
uv.lock
73+
74+
# Docker
75+
Dockerfile
76+
.dockerignore
77+
78+
# CI/CD
79+
.github/
80+
.python-version

.github/workflows/ci.yaml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
name: CI
2+
3+
on:
4+
push:
5+
branches: [main, dev]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
test:
11+
name: Test (Python ${{ matrix.python-version }})
12+
runs-on: ubuntu-latest
13+
strategy:
14+
fail-fast: false
15+
matrix:
16+
python-version: ["3.10", "3.11", "3.12", "3.13", "3.14"]
17+
steps:
18+
- name: Checkout repository
19+
uses: actions/checkout@v6
20+
21+
- name: Set up Python ${{ matrix.python-version }}
22+
uses: actions/setup-python@v6
23+
with:
24+
python-version: ${{ matrix.python-version }}
25+
26+
- name: Install dependencies
27+
run: |
28+
python -m pip install --upgrade pip
29+
pip install -e .
30+
pip install pytest pytest-cov torchvision timm transformers
31+
32+
- name: Run tests
33+
run: pytest tests/ -v --cov=cka --cov-report=term-missing

.github/workflows/docker.yaml

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
name: Docker
2+
3+
on:
4+
release:
5+
types: [published]
6+
workflow_dispatch:
7+
8+
env:
9+
REGISTRY: ghcr.io
10+
IMAGE_NAME: ryusudol/pytorch-cka
11+
12+
jobs:
13+
build-and-push:
14+
name: Build and Push Docker Image
15+
runs-on: ubuntu-latest
16+
permissions:
17+
contents: read
18+
packages: write
19+
steps:
20+
- name: Checkout repository
21+
uses: actions/checkout@v6
22+
23+
- name: Set up Docker Buildx
24+
uses: docker/setup-buildx-action@v3
25+
26+
- name: Log in to GitHub Container Registry
27+
uses: docker/login-action@v3
28+
with:
29+
registry: ${{ env.REGISTRY }}
30+
username: ${{ github.actor }}
31+
password: ${{ secrets.GITHUB_TOKEN }}
32+
33+
- name: Extract metadata (tags, labels)
34+
id: meta
35+
uses: docker/metadata-action@v5
36+
with:
37+
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
38+
tags: |
39+
type=semver,pattern={{version}}
40+
type=semver,pattern={{major}}.{{minor}}
41+
type=raw,value=latest
42+
43+
- name: Build and push Docker image
44+
uses: docker/build-push-action@v6
45+
with:
46+
context: .
47+
push: true
48+
tags: ${{ steps.meta.outputs.tags }}
49+
labels: ${{ steps.meta.outputs.labels }}
50+
cache-from: type=gha
51+
cache-to: type=gha,mode=max

.gitignore

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
# Reference repositories (for analysis only)
2-
centered-kernel-alignment/
3-
PyTorch-Model-Compare/
4-
5-
# Claude Code
6-
.claude/
7-
81
# Virtual environments
92
.venv/
103
venv/
@@ -59,8 +52,8 @@ cka_checkpoint.pt
5952
# Jupyter
6053
.ipynb_checkpoints/
6154

62-
# mypy
6355
.mypy_cache/
64-
56+
.ruff_cache/
57+
.claude/
58+
CLAUDE.md
6559
uv.lock
66-
tests/

Dockerfile

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
ARG PYTHON_VERSION=3.10
2+
ARG PYTORCH_VERSION=2.0.0
3+
4+
# Build stage
5+
FROM python:${PYTHON_VERSION}-slim AS builder
6+
7+
ENV PYTHONDONTWRITEBYTECODE=1 \
8+
PYTHONUNBUFFERED=1 \
9+
PIP_NO_CACHE_DIR=1 \
10+
PIP_DISABLE_PIP_VERSION_CHECK=1
11+
12+
WORKDIR /app
13+
14+
RUN apt-get update && apt-get install -y --no-install-recommends \
15+
build-essential \
16+
&& rm -rf /var/lib/apt/lists/*
17+
18+
# Create virtual environment
19+
RUN python -m venv /opt/venv
20+
ENV PATH="/opt/venv/bin:$PATH"
21+
22+
COPY pyproject.toml README.md LICENSE ./
23+
COPY cka/ ./cka/
24+
25+
ARG PYTORCH_VERSION
26+
RUN pip install torch==${PYTORCH_VERSION} --index-url https://download.pytorch.org/whl/cpu && \
27+
pip install .
28+
29+
# Runtime stage
30+
FROM python:${PYTHON_VERSION}-slim
31+
32+
ENV PYTHONDONTWRITEBYTECODE=1 \
33+
PYTHONUNBUFFERED=1 \
34+
PATH="/opt/venv/bin:$PATH"
35+
36+
WORKDIR /app
37+
38+
# Copy virtual environment from builder
39+
COPY --from=builder /opt/venv /opt/venv
40+
41+
RUN useradd --create-home --shell /bin/bash cka
42+
USER cka
43+
44+
CMD ["python", "-c", "from cka import CKA; print('pytorch-cka is ready! Import with: from cka import CKA')"]

README.md

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ pip install pytorch-cka
2424
# Using uv
2525
uv add pytorch-cka
2626

27+
# Using docker
28+
docker pull ghcr.io/ryusudol/pytorch-cka
29+
2730
# From source
2831
git clone https://github.com/ryusudol/Centered-Kernel-Alignment
2932
cd pytorch-cka
@@ -36,7 +39,7 @@ uv sync # or: pip install -e .
3639

3740
```python
3841
from torch.utils.data import DataLoader
39-
from pytorch_cka import CKA
42+
from cka import CKA
4043

4144
pretrained_model = ... # e.g. pretrained ResNet-18
4245
fine_tuned_model = ... # e.g. fine-tuned ResNet-18
@@ -70,7 +73,7 @@ with cka:
7073
**Heatmap**
7174

7275
```python
73-
from pytorch_cka import plot_cka_heatmap
76+
from cka import plot_cka_heatmap
7477

7578
fig, ax = plot_cka_heatmap(
7679
cka_matrix,
@@ -100,7 +103,7 @@ fig, ax = plot_cka_heatmap(
100103
**Trend Plot**
101104

102105
```python
103-
from pytorch_cka import plot_cka_trend
106+
from cka import plot_cka_trend
104107

105108
# Plot diagonal (self-similarity across layers)
106109
diagonal = torch.diag(matrix)
@@ -126,7 +129,7 @@ fig, ax = plot_cka_trend(
126129
<!-- **Side-by-Side Comparison**
127130
128131
```python
129-
from pytorch_cka import plot_cka_comparison
132+
from cka import plot_cka_comparison
130133
131134
fig, axes = plot_cka_comparison(
132135
matrices=[matrix1, matrix2, matrix3],

cka/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
representations using Centered Kernel Alignment (CKA).
55
66
Example:
7-
>>> from pytorch_cka import CKA
7+
>>> from cka import CKA
88
>>>
99
>>> with CKA(model1, model2, model1_layers=["layer1", "layer2"]) as cka:
1010
... matrix = cka.compare(dataloader)

cka/viz.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,12 @@ def shorten_name(name: str, depth: int | None) -> str:
128128

129129
# Set tick labels
130130
if layers1 is not None:
131-
shortened = [shorten_name(l, layer_name_depth) for l in layers1]
131+
shortened = [shorten_name(layer, layer_name_depth) for layer in layers1]
132132
ax.set_yticks(range(n_layers1))
133133
ax.set_yticklabels(shortened, fontsize=tick_fontsize)
134134

135135
if layers2 is not None:
136-
shortened = [shorten_name(l, layer_name_depth) for l in layers2]
136+
shortened = [shorten_name(layer, layer_name_depth) for layer in layers2]
137137
ax.set_xticks(range(n_layers2))
138138
ax.set_xticklabels(shortened, fontsize=tick_fontsize, rotation=45, ha="right")
139139

@@ -197,7 +197,11 @@ def plot_cka_trend(
197197
"""
198198
# Normalize input to list of arrays
199199
if isinstance(cka_values, (torch.Tensor, np.ndarray)):
200-
arr = cka_values.detach().cpu().numpy() if isinstance(cka_values, torch.Tensor) else cka_values
200+
arr = (
201+
cka_values.detach().cpu().numpy()
202+
if isinstance(cka_values, torch.Tensor)
203+
else cka_values
204+
)
201205
if arr.ndim == 1:
202206
cka_values = [arr]
203207
else:
@@ -304,6 +308,7 @@ def plot_cka_trend_with_range(
304308
Returns:
305309
Tuple of (Figure, Axes).
306310
"""
311+
307312
# Convert to numpy
308313
def to_numpy(arr):
309314
if isinstance(arr, torch.Tensor):
@@ -415,7 +420,7 @@ def plot_cka_comparison(
415420
if figsize is None:
416421
figsize = (5 * ncols, 4 * nrows)
417422

418-
fig, axes = plt.subplots(nrows, ncols, figsize=figsize, constrained_layout=True)
423+
fig, axes = plt.subplots(nrows, ncols, figsize=figsize, constrained_layout=share_colorbar)
419424
axes = np.atleast_2d(axes)
420425

421426
# Find global min/max for shared colorbar
@@ -458,6 +463,8 @@ def plot_cka_comparison(
458463
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
459464
sm.set_array([])
460465
fig.colorbar(sm, ax=axes, fraction=0.02, pad=0.02, label="CKA Similarity")
466+
else:
467+
fig.tight_layout()
461468

462469
if show:
463470
plt.show()

examples/basic_comparison.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,7 @@
1111
import torch.nn as nn
1212
from torch.utils.data import DataLoader, TensorDataset
1313

14-
from pytorch_cka import CKA, plot_cka_heatmap, plot_cka_trend
15-
16-
17-
# =============================================================================
18-
# Define example models
19-
# =============================================================================
14+
from cka import CKA, plot_cka_heatmap, plot_cka_trend
2015

2116

2217
class SimpleCNN(nn.Module):

0 commit comments

Comments
 (0)