Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions libsvm/svm.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from ctypes import *
from os import path, listdir, pardir
import numpy as np

try:
import scipy
Expand Down Expand Up @@ -80,15 +81,15 @@ def gen_svm_nodearray(xi, feature_max=None, isKernel=False):

xi_shift = 0 # ensure correct indices of xi
if scipy and isinstance(xi, tuple) and len(xi) == 2 \
and isinstance(xi[0], scipy.ndarray) and isinstance(xi[1],
scipy.ndarray): # for a sparse vector
and isinstance(xi[0], np.ndarray) and isinstance(xi[1],
np.ndarray): # for a sparse vector
if not isKernel:
index_range = xi[0] + 1 # index starts from 1
else:
index_range = xi[0] # index starts from 0 for precomputed kernel
if feature_max:
index_range = index_range[scipy.where(index_range <= feature_max)]
elif scipy and isinstance(xi, scipy.ndarray):
elif scipy and isinstance(xi, np.ndarray):
if not isKernel:
xi_shift = 1
index_range = xi.nonzero()[0] + 1 # index starts from 1
Expand Down Expand Up @@ -120,8 +121,8 @@ def gen_svm_nodearray(xi, feature_max=None, isKernel=False):
ret[-1].index = -1

if scipy and isinstance(xi, tuple) and len(xi) == 2 \
and isinstance(xi[0], scipy.ndarray) and isinstance(xi[1],
scipy.ndarray): # for a sparse vector
and isinstance(xi[0], np.ndarray) and isinstance(xi[1],
np.ndarray): # for a sparse vector
for idx, j in enumerate(index_range):
ret[idx].index = j
ret[idx].value = (xi[1])[idx]
Expand Down Expand Up @@ -192,16 +193,16 @@ class svm_problem(Structure):
_fields_ = genFields(_names, _types)

def __init__(self, y, x, isKernel=False):
if (not isinstance(y, (list, tuple))) and (not (scipy and isinstance(y, scipy.ndarray))):
if (not isinstance(y, (list, tuple))) and (not (scipy and isinstance(y, np.ndarray))):
raise TypeError("type of y: {0} is not supported!".format(type(y)))

if isinstance(x, (list, tuple)):
if len(y) != len(x):
raise ValueError("len(y) != len(x)")
elif scipy != None and isinstance(x, (scipy.ndarray, sparse.spmatrix)):
elif scipy != None and isinstance(x, (np.ndarray, sparse.spmatrix)):
if len(y) != x.shape[0]:
raise ValueError("len(y) != len(x)")
if isinstance(x, scipy.ndarray):
if isinstance(x, np.ndarray):
x = scipy.ascontiguousarray(x) # enforce row-major
if isinstance(x, sparse.spmatrix):
x = x.tocsr()
Expand All @@ -223,7 +224,7 @@ def __init__(self, y, x, isKernel=False):
self.n = max_idx

self.y = (c_double * l)()
if scipy != None and isinstance(y, scipy.ndarray):
if scipy != None and isinstance(y, np.ndarray):
scipy.ctypeslib.as_array(self.y, (self.l,))[:] = y
else:
for i, yi in enumerate(y): self.y[i] = yi
Expand Down