Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion cpp/kiss_matcher/core/kiss_matcher/KISSMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
#include <kiss_matcher/KISSMatcher.hpp>

namespace kiss_matcher {
KISSMatcher::KISSMatcher(const float &voxel_size) { config_ = KISSMatcherConfig(voxel_size); }
KISSMatcher::KISSMatcher(const float &voxel_size) {
config_ = KISSMatcherConfig(voxel_size);
reset();
}

KISSMatcher::KISSMatcher(const KISSMatcherConfig &config) {
config_ = config;
Expand Down
82 changes: 73 additions & 9 deletions python/tests/test_kiss_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
Basic tests for KISS-Matcher Python bindings.
"""
import unittest
import numpy as np

import kiss_matcher as km
import numpy as np


class TestKISSMatcher(unittest.TestCase):
Expand All @@ -13,19 +14,19 @@ class TestKISSMatcher(unittest.TestCase):
def test_import(self):
"""Test that kiss_matcher can be imported."""
# This test should pass if the module imports correctly
self.assertTrue(hasattr(km, 'KISSMatcher'))
self.assertTrue(hasattr(km, 'KISSMatcherConfig'))
self.assertTrue(hasattr(km, 'RegistrationSolution'))
self.assertTrue(hasattr(km, "KISSMatcher"))
self.assertTrue(hasattr(km, "KISSMatcherConfig"))
self.assertTrue(hasattr(km, "RegistrationSolution"))

def test_config_creation(self):
"""Test KISSMatcherConfig creation."""
config = km.KISSMatcherConfig()
self.assertIsInstance(config.voxel_size, float)
self.assertEqual(config.voxel_size, 0.3) # default value
self.assertAlmostEqual(config.voxel_size, 0.3, places=5) # default value

# Test custom config
custom_config = km.KISSMatcherConfig(voxel_size=0.5)
self.assertEqual(custom_config.voxel_size, 0.5)
self.assertAlmostEqual(custom_config.voxel_size, 0.5, places=5)

def test_matcher_creation(self):
"""Test KISSMatcher creation."""
Expand Down Expand Up @@ -66,9 +67,72 @@ def test_basic_functionality(self):

def test_version_attribute(self):
"""Test that version attribute exists."""
self.assertTrue(hasattr(km, '__version__'))
self.assertTrue(hasattr(km, "__version__"))
self.assertIsInstance(km.__version__, str)

def test_solve_with_prematched_points(self):
"""Test solve() method with pre-matched point correspondences."""
# Create a matcher
matcher = km.KISSMatcher(0.3)

# Create source points
src_points = np.random.rand(3, 50).astype(np.float64)

# Create target points (rotation + translation of source)
angle = np.pi / 6 # 30 degrees
R_true = np.array(
[
[np.cos(angle), -np.sin(angle), 0],
[np.sin(angle), np.cos(angle), 0],
[0, 0, 1],
]
)
t_true = np.array([1.0, 2.0, 0.5])
tgt_points = R_true @ src_points + t_true[:, np.newaxis]

# Test solve method (assumes correspondences are already established)
result = matcher.solve(src_points, tgt_points)

# Check that result is valid
self.assertIsInstance(result, km.RegistrationSolution)
self.assertTrue(hasattr(result, "rotation"))
self.assertTrue(hasattr(result, "translation"))

def test_prune_and_solve_with_matched_points(self):
"""Test prune_and_solve() method with pre-matched point correspondences."""
# Create a matcher
matcher = km.KISSMatcher(0.3)

# Create source points (as list of vectors)
src_points = [np.random.rand(3).astype(np.float32) for _ in range(100)]

# Create target points with some transformation
angle = np.pi / 4 # 45 degrees
R_true = np.array(
[
[np.cos(angle), -np.sin(angle), 0],
[np.sin(angle), np.cos(angle), 0],
[0, 0, 1],
],
dtype=np.float32,
)
t_true = np.array([0.5, 1.0, 0.0], dtype=np.float32)

tgt_points = [(R_true @ p + t_true) for p in src_points]

# Add some outliers (10% outliers)
num_outliers = 10
for i in range(num_outliers):
tgt_points[i] = np.random.rand(3).astype(np.float32) * 10

# Test prune_and_solve method
result = matcher.prune_and_solve(src_points, tgt_points)

# Check that result is valid
self.assertIsInstance(result, km.RegistrationSolution)
self.assertTrue(hasattr(result, "rotation"))
self.assertTrue(hasattr(result, "translation"))


if __name__ == '__main__':
unittest.main()
if __name__ == "__main__":
unittest.main()
Loading