Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions src/egoallo/transforms/_se3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,40 @@
from ._so3 import SO3
from .utils import get_epsilon, register_lie_group

def are_se3(matrices: Tensor, tol=1e-6):
"""
Verifies if a tensor of shape (..., 4, 4) contains matrices that belong to the SE(3) group.

Args:
matrices (torch.Tensor): The input tensor of shape (..., 4, 4).
tol (float): Tolerance for floating-point comparisons.

Returns:
torch.Tensor: A boolean tensor of shape (...) indicating which matrices are in SE(3).
"""
# Check if the matrix has the correct shape
# Check if the last two dimensions are 4x4
if matrices.shape[-2:] != (4, 4):
raise ValueError("Input tensor must have dimensions (..., 4, 4).")

# Extract the rotation components
R = matrices[..., :3, :3]
bottom_row = matrices[..., 3, :]

# Check if the bottom row is [0, 0, 0, 1] for all matrices
bottom_row_check = torch.all(torch.isclose(bottom_row, torch.tensor([0.0, 0.0, 0.0, 1.0], device=matrices.device), atol=tol), dim=-1)

# Check if the 3x3 submatrix R is a valid rotation matrix (R in SO(3))
# 1. R^T * R = I (R is orthogonal)
identity_matrix = torch.eye(3, device=matrices.device)
r_transpose_r = torch.matmul(R.transpose(-1, -2), R)
orthogonality_check = torch.all(torch.isclose(r_transpose_r, identity_matrix, atol=tol), dim=(-2, -1))

# 2. det(R) = 1 (R is a proper rotation)
determinant_check = torch.isclose(torch.det(R), torch.tensor(1.0, device=matrices.device), atol=tol)

# Combine all checks
return bottom_row_check & orthogonality_check & determinant_check

def _skew(omega: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -85,9 +119,10 @@ def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3:

@classmethod
@override
def from_matrix(cls, matrix: Tensor) -> SE3:
assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4)
# Currently assumes bottom row is [0, 0, 0, 1].
def from_matrix(cls, matrix: Tensor, check=False) -> SE3:
# check if the matrix is indeed SE(3)
if check:
assert torch.all(are_se3(matrix)), "Input matrix is not in SE(3)"
return SE3.from_rotation_and_translation(
rotation=SO3.from_matrix(matrix[..., :3, :3]),
translation=matrix[..., :3, 3],
Expand Down