Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions glrm/glrm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@
class GLRM(object):

def __init__(self, A, loss, regX, regY, k, missing_list = None, converge = None, scale=True):

self.scale = scale
# Turn everything in to lists / convert to correct dimensions
if not isinstance(A, list): A = [A]
if not isinstance(loss, list): loss = [loss]
if not isinstance(regY, list): regY = [regY]
if len(regY) == 1 and len(regY) < len(loss):
if len(regY) == 1 and len(regY) < len(loss):
regY = [copy(regY[0]) for _ in range(len(loss))]
if missing_list and not isinstance(missing_list[0], list): missing_list = [missing_list]

Expand All @@ -30,8 +30,8 @@ def __init__(self, A, loss, regX, regY, k, missing_list = None, converge = None,

# initialize cvxpy problems
self._initialize_probs(A, k, missing_list, regX, regY)


def factors(self):
# return X, Y as matrices (not lists of sub matrices)
return self.X, hstack(self.Y)
Expand All @@ -45,7 +45,7 @@ def predict(self):
return hstack([L.decode(self.X.dot(yj)) for Aj, yj, L in zip(self.A, self.Y, self.L)])

def fit(self, max_iters=100, eps=1e-2, use_indirect=False, warm_start=False):

Xv, Yp, pX = self.probX
Xp, Yv, pY = self.probY
self.converge.reset()
Expand All @@ -57,7 +57,7 @@ def fit(self, max_iters=100, eps=1e-2, use_indirect=False, warm_start=False):
Xp.value[:,:-1] = copy(Xv.value)

# can parallelize this
for ypj, yvj, pyj in zip(Yp, Yv, pY):
for ypj, yvj, pyj in zip(Yp, Yv, pY):
objY = pyj.solve(solver=cp.SCS, eps=eps, max_iters=max_iters,
use_indirect=use_indirect, warm_start=warm_start)
ypj.value = copy(yvj.value)
Expand All @@ -67,7 +67,7 @@ def fit(self, max_iters=100, eps=1e-2, use_indirect=False, warm_start=False):
return self.X, self.Y

def _initialize_probs(self, A, k, missing_list, regX, regY):

# useful parameters
m = A[0].shape[0]
ns = [a.shape[1] for a in A]
Expand Down Expand Up @@ -107,7 +107,7 @@ def _initialize_A(self, A, missing_list):

# compute stdev for entries that are not missing
for ni, sv, mu, ai, missing, L in zip(ns, stdev, mean, A, missing_list, self.L):

# collect non-missing terms
for j in range(ni):
elems = array([ai[i,j] for i in range(m) if (i,j) not in missing])
Expand All @@ -123,7 +123,7 @@ def _initialize_A(self, A, missing_list):

# zero-out missing entries (for XY initialization)
for (i,j) in missing: bi[i,j], mask[i,j] = 0, 0

B.append(bi) # save
masks.append(mask)
offsets.append(offset)
Expand Down Expand Up @@ -156,14 +156,14 @@ def _initialize_XY(self, B, k, missing_list):
ns = cumsum([bj.shape[1] for bj in B])
if len(ns) == 1: Y0 = [Y0]
else: Y0 = split(Y0, ns, 1)

return X0, Y0

def _finalize_XY(self, Xv, Yv):
""" Multiply by std, offset by mean """
m, k = Xv.shape.size
m, k = Xv.size
self.X = asarray(hstack((Xv.value, ones((m,1)))))
self.Y = [asarray(yj.value)*tile(mask[0,:],(k+1,1)) \
for yj, mask in zip(Yv, self.masks)]
for offset, Y in zip(self.offsets, self.Y): Y[-1,:] += offset[0,:]

3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@
url="http://github.com/cehorn/GLRM/",
license="MIT",
install_requires=[ "numpy >= 1.8",
"scipy >= 0.13"]
"scipy >= 0.13",
"cvxpy >= 0.4.10"]
)