diff --git a/src/egoallo/transforms/_se3.py b/src/egoallo/transforms/_se3.py index 104fbbf..0faa3a4 100644 --- a/src/egoallo/transforms/_se3.py +++ b/src/egoallo/transforms/_se3.py @@ -11,6 +11,40 @@ from ._so3 import SO3 from .utils import get_epsilon, register_lie_group +def are_se3(matrices: Tensor, tol=1e-6): + """ + Verifies if a tensor of shape (..., 4, 4) contains matrices that belong to the SE(3) group. + + Args: + matrices (torch.Tensor): The input tensor of shape (..., 4, 4). + tol (float): Tolerance for floating-point comparisons. + + Returns: + torch.Tensor: A boolean tensor of shape (...) indicating which matrices are in SE(3). + """ + # Check if the matrix has the correct shape + # Check if the last two dimensions are 4x4 + if matrices.shape[-2:] != (4, 4): + raise ValueError("Input tensor must have dimensions (..., 4, 4).") + + # Extract the rotation components + R = matrices[..., :3, :3] + bottom_row = matrices[..., 3, :] + + # Check if the bottom row is [0, 0, 0, 1] for all matrices + bottom_row_check = torch.all(torch.isclose(bottom_row, torch.tensor([0.0, 0.0, 0.0, 1.0], device=matrices.device), atol=tol), dim=-1) + + # Check if the 3x3 submatrix R is a valid rotation matrix (R in SO(3)) + # 1. R^T * R = I (R is orthogonal) + identity_matrix = torch.eye(3, device=matrices.device) + r_transpose_r = torch.matmul(R.transpose(-1, -2), R) + orthogonality_check = torch.all(torch.isclose(r_transpose_r, identity_matrix, atol=tol), dim=(-2, -1)) + + # 2. det(R) = 1 (R is a proper rotation) + determinant_check = torch.isclose(torch.det(R), torch.tensor(1.0, device=matrices.device), atol=tol) + + # Combine all checks + return bottom_row_check & orthogonality_check & determinant_check def _skew(omega: Tensor) -> Tensor: """ @@ -85,9 +119,10 @@ def identity(cls, device: Union[torch.device, str], dtype: torch.dtype) -> SE3: @classmethod @override - def from_matrix(cls, matrix: Tensor) -> SE3: - assert matrix.shape[-2:] == (4, 4) or matrix.shape[-2:] == (3, 4) - # Currently assumes bottom row is [0, 0, 0, 1]. + def from_matrix(cls, matrix: Tensor, check=False) -> SE3: + # check if the matrix is indeed SE(3) + if check: + assert torch.all(are_se3(matrix)), "Input matrix is not in SE(3)" return SE3.from_rotation_and_translation( rotation=SO3.from_matrix(matrix[..., :3, :3]), translation=matrix[..., :3, 3],