torchkm


A PyTorch-based package for solving kernel SVMs on GPU.

Table of contents

- Introduction
- Installation
- Quick start
- Usage
- Extensions to large-margin classifiers
- Getting help

Introduction

torchkm is a PyTorch-based library that trains kernel SVMs and other large-margin classifiers with exact leave-one-out cross-validation (LOOCV) error computation. Conventional SVM solvers often face scalability and efficiency challenges, especially on large datasets or when multiple cross-validation runs are required. torchkm computes the LOOCV error at the same cost as training a single SVM, and it gains further speed and scalability from CUDA-accelerated matrix operations. Benchmark experiments indicate that torchkm outperforms existing kernel SVM solvers in efficiency and speed. This document shows how to use the torchkm package to fit kernel SVMs.

For low-dimensional problems, or when non-linear decision boundaries or higher accuracy are required, the SVM can be formulated with the kernel method in a reproducing kernel Hilbert space (RKHS).

Given a random sample $\{y_i, x_i\}_{i=1}^n$, the kernel SVM can be formulated as a function estimation problem:

$$\hat{f} = \mathop{\arg\min}_{f \in \mathcal{H}_K} \; \frac{1}{n} \sum_{i=1}^{n} \left[ 1 - y_i f(x_i) \right]_+ \; + \; \lambda \, \| f \|_{\mathcal{H}_K}^2,$$

where $\| f \|_{\mathcal{H}_K}$ is the RKHS norm that acts as a regularizer, and $\lambda > 0$ is a tuning parameter.

According to the representer theorem for reproducing kernels (Wahba, 1990), the solution to our problem takes the form:

$$\hat{f}(x) = \hat{\alpha}_0 + \sum_{i=1}^{n} \hat{\alpha}_i K(x, x_i).$$

The coefficients $\hat{\alpha}^{SVM} = (\hat{\alpha}_0, \hat{\alpha}_1, \ldots, \hat{\alpha}_n)$ are obtained by solving the optimization problem:

$$\hat{\alpha}^{SVM} = \mathop{\arg\min}_{\alpha_0,\, \alpha} \; \frac{1}{n} \sum_{i=1}^{n} \left[ 1 - y_i \left( \alpha_0 + \mathbf{K}_i^{\top} \alpha \right) \right]_+ \; + \; \lambda\, \alpha^{\top} \mathbf{K} \alpha,$$

where $\mathbf{K}$ is the $n \times n$ kernel matrix with entries $\mathbf{K}_{ij} = K(x_i, x_j)$ and $\mathbf{K}_i$ denotes its $i$-th row.
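
To make the objective concrete, here is a minimal PyTorch sketch of the regularized hinge objective above; it is an illustrative reference, not the package's internal objfun.

import torch

# Minimal sketch of the kernel SVM objective (illustrative, not torchkm's
# internal implementation). K: n x n kernel matrix; y in {-1, +1};
# alpha0: intercept; alpha: n coefficients; lam: the tuning parameter lambda.
def ksvm_objective(K, y, alpha0, alpha, lam):
    f = alpha0 + K @ alpha                  # fitted values f(x_i)
    hinge = torch.clamp(1 - y * f, min=0)   # hinge loss [1 - y_i f(x_i)]_+
    penalty = lam * alpha @ (K @ alpha)     # RKHS penalty alpha' K alpha
    return hinge.mean() + penalty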

Installation

You can use pip to install this package.

pip install torchkm
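
The examples below pass device='cuda', so it is worth confirming first that your PyTorch build can see a GPU:

import torch

# The examples in this document assume a CUDA device is available.
print(torch.cuda.is_available())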

Quick start

Import necessary libraries and functions:

from torchkm.cvksvm import cvksvm
from torchkm.functions import *
import torch
import numpy

The usage is similar to scikit-learn (here Kmat, y_train, nlam, ulam, and nfolds are constructed as in the Usage section below):

model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, nfolds=nfolds, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()

Usage

Generate simulation data

The functions module provides a simulation data generator, data_gen, which draws samples from a mixture of Gaussian models. It also provides kernel operations such as rbf_kernel and kernelMult, and data-processing utilities such as standardize.

# Sample data
nn = 10000 # Number of samples
nm = 5    # Number of clusters per class
pp = 10   # Number of features
p1 = p2 = pp // 2    # Number of positive/negative centers
mu = 2.0  # Mean shift
ro = 3  # Standard deviation for normal distribution
sdn = 42  # Seed for reproducibility

nlam = 50
torch.manual_seed(sdn)
ulam = torch.logspace(3, -3, steps=nlam)

X_train, y_train, means_train = data_gen(nn, nm, pp, p1, p2, mu, ro, sdn)
X_test, y_test, means_test = data_gen(nn // 10, nm, pp, p1, p2, mu, ro, sdn)
X_train = standardize(X_train)
X_test = standardize(X_test)

sig = sigest(X_train)
Kmat = rbf_kernel(X_train, sig)
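
For reference, a Gaussian RBF kernel matrix has entries $K_{ij} = \exp(-\sigma \|x_i - x_j\|^2)$. A minimal sketch is below; the exact bandwidth parameterization used by torchkm's rbf_kernel and sigest may differ.

import torch

# Reference sketch of a Gaussian RBF kernel matrix; torchkm's rbf_kernel
# may parameterize the bandwidth differently (e.g., 1/(2*sigma^2) vs. sigma).
def rbf_kernel_ref(X, sigma):
    d2 = torch.cdist(X, X).pow(2)   # squared pairwise Euclidean distances
    return torch.exp(-sigma * d2)   # K_ij = exp(-sigma * ||x_i - x_j||^2)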

Basic operation

torchkm's main entry point is cvksvm, which tunes kernel SVMs quickly with GPU acceleration and can compute exact leave-one-out cross-validation (LOOCV) errors when requested.

model = cvksvm(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, nfolds=5, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
### Tune parameter
cv_mis = model.cv(model.pred, y_train).numpy()
best_ind = numpy.argmin(cv_mis)
best_ind
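
# Illustrative follow-up (not part of the package docs): ulam[best_ind] is
# the selected tuning parameter lambda, reused in the objective call below.
best_lam = ulam[best_ind]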

### Test accuracy and objective value
Kmat = Kmat.double()
alpmat = model.alpmat.to('cpu')
intcpt = alpmat[0,best_ind]
alp = alpmat[1:,best_ind]
ka = torch.mv(Kmat, alp)
aka = torch.dot(alp, ka)
obj_magic = model.objfun(intcpt, aka, ka, y_train, ulam[best_ind], nn)

Kmat_new = kernelMult(X_test, X_train, sig)
Kmat_new = Kmat_new.double()

result = torch.mv(Kmat_new, alpmat[1:,best_ind]) + alpmat[0, best_ind]

ypred = torch.where(result > 0, torch.tensor(1), torch.tensor(-1))

torch.mean((ypred == y_test).float())  # test accuracy

Probability estimation

### Platt scaling
from torchkm import platt  # module path assumed; adjust to where PlattScalerTorch lives

oof_f = torch.where(model.pred > 0, 1, -1).to(device='cpu')[:, best_ind]
scaler = platt.PlattScalerTorch(dtype=torch.double, device='cuda').fit(oof_f, y_train)

X_test_raw = torch.mv(Kmat_new, alpmat[1:,best_ind]) + alpmat[0, best_ind]

with torch.no_grad():
    p_platt = scaler.predict_proba(X_test_raw)[:, 1].cpu().numpy()
    y_test_np = torch.as_tensor(y_test).cpu().numpy()

# Reliability data for Platt
bc, mp, fp, cnt = scaler.reliability_curve(y_test_np, p_platt, n_bins=15)
ece_platt = scaler.expected_calibration_error(mp, fp, cnt)
brier_platt = scaler.brier_score(y_test_np, p_platt)

scaler.plot_calibration(bc, mp, fp, cnt, label=f"Platt (ECE={ece_platt:.3f}, Brier={brier_platt:.3f})")
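
Platt scaling fits a sigmoid $p(y = 1 \mid f) = 1 / (1 + \exp(Af + B))$ to the decision values. As a minimal illustration of the fitted transform (A and B stand in for the learned parameters; the names are hypothetical, not torchkm internals):

import torch

# Illustrative Platt transform; A and B denote the fitted sigmoid
# parameters (hypothetical names, not torchkm internals).
def platt_transform(f, A, B):
    return torch.sigmoid(-(A * f + B))  # p(y = 1 | f) = 1/(1 + exp(A*f + B))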

Real data

torchkm works well with scikit-learn datasets. Convert the data to torch.Tensor and recode the labels so that $y \in \{1, -1\}$.

from sklearn.datasets import make_moons

# Generate non-linear dataset
X, y = make_moons(n_samples=300, noise=0.2, random_state=42)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
y = 2 * y - 1

sig = sigest(X)
Kmat = rbf_kernel(X, sig)

model = cvksvm(Kmat=Kmat, y=y, nlam=nlam, ulam=ulam, nfolds=5, eps=1e-5, maxit=1000, gamma=1e-8, is_exact=0, device='cuda')
model.fit()
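
The cross-validation errors can then be used to pick the best $\lambda$, exactly as in the Usage section above:

# Same tuning workflow as before: pick the lambda with the smallest
# CV misclassification error.
cv_mis = model.cv(model.pred, y).numpy()
best_ind = numpy.argmin(cv_mis)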

Extensions to large-margin classifiers

torchkm also provides solvers for other large-margin classifiers:

  1. Kernel logistic regression

     model = cvklogit(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, eps=1e-5, maxit=1000, gamma=1e-8, device='cuda')
     model.fit()

  2. Kernel distance weighted discrimination

     model = cvkdwd(Kmat=Kmat, y=y_train, nlam=nlam, ulam=ulam, eps=1e-5, maxit=1000, gamma=1e-8, device='cuda')
     model.fit()

Getting help

For questions or suggestions, please contact: yikai-zhang@uiowa.edu
