Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions cert_tools/linalg_tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -84,7 +85,7 @@ def project_so3(X):
return U @ Vh


def rank_project(X, p=1, tolerance=1e-10):
def rank_project(X, p=1, tolerance=1e-10) -> tuple[np.ndarray, dict]:
"""Project matrix X to matrix of rank p."""
try:
assert la.issymmetric(X, atol=tolerance)
Expand Down Expand Up @@ -113,7 +114,46 @@ def rank_project(X, p=1, tolerance=1e-10):
"error eigs": np.sum(np.abs(E[p:])),
"EVR": abs(E[p - 1] / E[p]), # largest over second-largest
}
return x, info
return np.asarray(x), info


def extract_lower_rank_solution(X, A_list, tol=1e-8, max_trials=10):
"""Given a psd matrix X of rank r, returns a matrix X' of rank r-1"""
r_start = 0
for k in range(max_trials):
E, U = np.linalg.eigh(X)
r = int(np.sum(E > tol))
if r == 1:
return X, {"success": True, "rank": 1, "iter": k}
elif r == r_start - 1:
return X, {"success": True, "rank": r, "iter": k}
if r_start == 0:
r_start = r
V = U[:, -r:] @ np.diag(np.sqrt(E[-r:]))
np.testing.assert_allclose(X, V @ V.T, atol=tol)

AV = np.vstack([svec(V.T @ Ai @ V) for Ai in A_list]) # type: ignore

# below, the rows of bi have the nullspace vectors.
# AV @ bi = 0
b, info_null = get_nullspace(AV, tolerance=tol)
if b.shape[0] == 0:
return X, {"success": True, "rank": r, "iter": k}
break

# chose one of the vectors
idx = np.random.choice(b.shape[0])
bi = b[idx]

assert np.all(AV @ bi < tol)
Delta = smat(bi)

lambdas = np.abs(np.linalg.eigvalsh(Delta))
# maximum magnitude eigenvalue
alpha = -1 / np.max(lambdas)
X = V @ (np.eye(r) + alpha * Delta) @ V.T
warnings.warn(f"Did not find a lower-rank solution in {max_trials} trials.")
return X, {"success": True, "rank": r, "iter": k}


def find_dependent_columns(A_sparse, tolerance=1e-10, verbose=False, debug=False):
Expand Down
12 changes: 10 additions & 2 deletions cert_tools/sdp_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,21 @@ def solve_low_rank_sdp(
if x_cand is not None:
sol_input["x0"] = x_cand.reshape((-1, 1))
r = S(**sol_input)
Y_opt = r["x"]

# Get solver info
stats = S.stats()
success = stats.get("success", False)
info = {"solver_stats": stats, "success": success}

# Reshape and generate SDP solution
Y_opt = r["x"]
Y_opt = np.array(Y_opt).reshape((n, rank), order="F")
X_opt = Y_opt @ Y_opt.T

# Get cost
scale, offset = adjust
cost = np.trace(Q @ X_opt) * scale + offset

# Construct certificate
mults = np.array(r["lam_g"])
H = Q
Expand All @@ -186,7 +194,7 @@ def solve_low_rank_sdp(
H = H + A * mults[i, 0]

# Return
info = {"X": X_opt, "H": H, "cost": cost}
info = {"X": X_opt, "H": H, "cost": cost, "success": success}
return Y_opt, info


Expand Down