Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions popcor/examples/rotation_lifter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""RotationAveraging class for rotation averaging and synchronization problems."""

import warnings
from typing import Any, Dict, List, Tuple

import numpy as np
Expand All @@ -8,6 +9,7 @@
from scipy.spatial.transform import Rotation

from popcor.base_lifters import StateLifter
from popcor.utils.plotting_tools import LINESTYLES

METHOD: str = "CG"
SOLVER_KWARGS: dict = dict(
Expand Down Expand Up @@ -174,8 +176,16 @@ def get_theta(self, x: np.ndarray) -> np.ndarray:
return C_flat.reshape((self.d, self.n_rot * self.d), order="F")
elif self.level == "bm":
R0 = x[: self.d, : self.d].T

I = R0.T @ R0
scale_R = I[0, 0]
np.testing.assert_allclose(np.diag(I), scale_R)
warnings.warn(f"R0 is scaled by {scale_R}", UserWarning)

Ri = np.array(x[self.d : (self.n_rot + 1) * self.d, : self.d]).T
Ri_world = R0.T @ Ri
np.testing.assert_allclose(np.diag(Ri.T @ Ri), scale_R)
Ri_world = R0.T @ Ri / scale_R
np.testing.assert_allclose(np.diag(Ri_world.T @ Ri_world), 1.0)
return Ri_world
else:
raise ValueError(f"Unknown level {self.level} for RotationLifter")
Expand Down Expand Up @@ -448,14 +458,15 @@ def plot(self, estimates: Dict[str, np.ndarray] = {}) -> Tuple[Any, Any]:
)
label = None

linestyles = itertools.cycle(["--", "-.", ":"])
linestyles = itertools.cycle(LINESTYLES)
for label, theta in estimates.items():
ls = next(linestyles)
for i in range(self.n_rot):
plot_frame(
ax=ax,
theta=theta[:, i * self.d : (i + 1) * self.d],
label=label,
ls=next(linestyles),
ls=ls,
scale=1.0,
marker="",
r_wc_w=np.hstack([i * 2.0] + [0.0] * (self.d - 1)), # type: ignore
Expand Down
13 changes: 13 additions & 0 deletions popcor/utils/plotting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,19 @@

from popcor.utils.geometry import get_C_r_from_theta

LINESTYLES = [
"-", # solid
"--", # dashed
"-.", # dash-dot
":", # dotted
(0, (5, 2)), # long dash, short gap
(0, (1, 1)), # very fine dashed (dot style)
(0, (3, 5, 1, 5)), # dash, gap, dot, gap
(0, (5, 10)), # widely spaced dashes
(0, (3, 1, 1, 1, 1, 1)), # dash-dot-dot
(0, (2, 4, 6, 4)), # mixed dash lengths
]


def import_plt():
import shutil
Expand Down
28 changes: 20 additions & 8 deletions tests/test_rotation_lifter.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,17 @@ def test_solve_local():
print("done")


def test_solve_sdp():
@pytest.mark.parametrize("level", ["no", "bm"])
def test_solve_sdp(level):
"""Solve the SDP relaxation and compare the recovered rotations to ground truth."""
d = 3
n_abs = 1
n_rot = 4
n_rot = 5
estimates = {}
for noise in [0.0, 0.2]:
for noise in [0, 1e-3, 0.2]:
np.random.seed(2)
lifter = RotationLifter(
d=d, n_abs=n_abs, n_rot=n_rot, sparsity="chain", level="no"
d=d, n_abs=n_abs, n_rot=n_rot, sparsity="chain", level=level
)

y = lifter.simulate_y(noise=noise)
Expand All @@ -160,13 +161,23 @@ def test_solve_sdp():

X, info = solve_sdp(Q, constraints, verbose=False)

x, info_rank = rank_project(X, p=1)
if lifter.level == "no":
x, info_rank = rank_project(X, p=1)
theta_sdp = lifter.get_theta(x)
else:
X, info_rank = rank_project(X, p=lifter.d)
theta_sdp = lifter.get_theta(X)

print(f"EVR at noise {noise:.2f}: {info_rank['EVR']:.2e}")
if noise <= 1:
assert info_rank["EVR"] > 1e8
theta_sdp = lifter.get_theta(x)

error = lifter.get_error(theta_sdp)
if noise == 0.0:
np.testing.assert_allclose(theta_sdp, lifter.theta, atol=1e-5)
np.testing.assert_allclose(theta_sdp, lifter.theta, atol=1e-8)
assert error < 1e-10
elif noise < 1e-2:
assert error < 1e-2

estimates.update({"init gt": lifter.theta, f"SDP noise {noise:.2f}": theta_sdp})

Expand All @@ -178,7 +189,8 @@ def test_solve_sdp():


if __name__ == "__main__":
test_solve_sdp()
test_solve_sdp(level="no")
test_solve_sdp(level="bm")
test_solve_local()

test_measurements(d=2, level="no")
Expand Down