diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..851be7a --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,50 @@ +name: CI + +on: + workflow_dispatch: + push: + branches: [main, master] + pull_request: + branches: [main, master] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest + pip install -e . + pip install -r requirements.txt + + - name: Detect GPU availability + id: detect-gpu + shell: bash + run: | + HAS_GPU=$(python -c "import torch; print('true' if torch.cuda.is_available() else 'false')") + echo "has_gpu=$HAS_GPU" >> "$GITHUB_OUTPUT" + echo "GPU available: $HAS_GPU" + + - name: Pre-cache model checkpoint + if: steps.detect-gpu.outputs.has_gpu == 'true' + run: | + mkdir -p ~/.cache/torch/hub/checkpoints + if [ ! -f ~/.cache/torch/hub/checkpoints/sharp_2572gikvuh.pt ]; then + curl -L -o ~/.cache/torch/hub/checkpoints/sharp_2572gikvuh.pt https://ml-site.cdn-apple.com/models/sharp/sharp_2572gikvuh.pt + fi + + - name: Run tests + run: | + pytest diff --git a/.gitignore b/.gitignore index 49aa19b..8f1c51f 100644 --- a/.gitignore +++ b/.gitignore @@ -164,3 +164,4 @@ cython_debug/ .DS_STORE *.pt .aider* +.vscode/ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/data/example.jpg b/tests/data/example.jpg new file mode 100644 index 0000000..37179e8 Binary files /dev/null and b/tests/data/example.jpg differ diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..18ace9a --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,11 @@ +def test_import(): + """Test that the package can be imported.""" + import sharp + assert sharp is not None + + +def test_cli_import(): + """Test that CLI modules can be imported.""" + from sharp import cli + assert cli is not None + diff --git a/tests/test_gaussians.py b/tests/test_gaussians.py new file mode 100644 index 0000000..ce97e8a --- /dev/null +++ b/tests/test_gaussians.py @@ -0,0 +1,73 @@ +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from sharp.utils.gaussians import Gaussians3D, SceneMetaData, save_ply + + +def test_gaussians3d_creation(): + """Test creating a Gaussians3D object.""" + num_gaussians = 100 + gaussians = Gaussians3D( + mean_vectors=torch.randn(1, num_gaussians, 3), + singular_values=torch.rand(1, num_gaussians, 3), + quaternions=torch.randn(1, num_gaussians, 4), + colors=torch.rand(1, num_gaussians, 3), + opacities=torch.rand(1, num_gaussians), + ) + + assert gaussians.mean_vectors.shape == (1, num_gaussians, 3) + assert gaussians.singular_values.shape == (1, num_gaussians, 3) + assert gaussians.quaternions.shape == (1, num_gaussians, 4) + assert gaussians.colors.shape == (1, num_gaussians, 3) + assert gaussians.opacities.shape == (1, num_gaussians) + + +def test_gaussians3d_to_device(): + """Test moving Gaussians3D to a device.""" + num_gaussians = 50 + gaussians = Gaussians3D( + mean_vectors=torch.randn(1, num_gaussians, 3), + singular_values=torch.rand(1, num_gaussians, 3), + quaternions=torch.randn(1, num_gaussians, 4), + colors=torch.rand(1, num_gaussians, 3), + opacities=torch.rand(1, num_gaussians), + ) + + gaussians_cpu = gaussians.to(torch.device("cpu")) + assert gaussians_cpu.mean_vectors.device.type == "cpu" + assert gaussians_cpu.singular_values.device.type == "cpu" + + +def test_scene_metadata(): + """Test creating SceneMetaData.""" + metadata = SceneMetaData( + focal_length_px=1000.0, + resolution_px=(1920, 1080), + color_space="linearRGB", + ) + + assert metadata.focal_length_px == 1000.0 + assert metadata.resolution_px == (1920, 1080) + assert metadata.color_space == "linearRGB" + + +def test_save_ply(tmp_path): + """Test saving Gaussians to PLY file.""" + num_gaussians = 10 + gaussians = Gaussians3D( + mean_vectors=torch.randn(1, num_gaussians, 3), + singular_values=torch.rand(1, num_gaussians, 3), + quaternions=torch.randn(1, num_gaussians, 4), + colors=torch.rand(1, num_gaussians, 3), + opacities=torch.rand(1, num_gaussians), + ) + + output_path = tmp_path / "test_gaussians.ply" + save_ply(gaussians, f_px=1000.0, image_shape=(1920, 1080), path=output_path) + + assert output_path.exists() + assert output_path.stat().st_size > 0 + diff --git a/tests/test_io.py b/tests/test_io.py new file mode 100644 index 0000000..1b14f54 --- /dev/null +++ b/tests/test_io.py @@ -0,0 +1,69 @@ +import tempfile +from pathlib import Path + +import numpy as np +import pytest +from PIL import Image + +from sharp.utils import io + + +def test_load_rgb(tmp_path): + """Test loading an RGB image.""" + test_image_path = tmp_path / "test.jpg" + img_array = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + img = Image.fromarray(img_array) + img.save(test_image_path) + + image, icc_profile, f_px = io.load_rgb(test_image_path) + + assert image.shape == (100, 100, 3) + assert image.dtype == np.uint8 + assert f_px > 0 + assert isinstance(icc_profile, (list, type(None))) + + +def test_get_supported_image_extensions(): + """Test getting supported image extensions.""" + extensions = io.get_supported_image_extensions() + assert isinstance(extensions, list) + assert len(extensions) > 0 + assert ".jpg" in extensions or ".JPG" in extensions + assert ".png" in extensions or ".PNG" in extensions + + +def test_get_supported_video_extensions(): + """Test getting supported video extensions.""" + extensions = io.get_supported_video_extensions() + assert isinstance(extensions, list) + assert ".mp4" in extensions or ".MP4" in extensions + + +def test_save_image(tmp_path): + """Test saving an image.""" + test_image_path = tmp_path / "test_output.jpg" + img_array = np.random.randint(0, 255, (50, 50, 3), dtype=np.uint8) + + io.save_image(img_array, test_image_path) + + assert test_image_path.exists() + loaded_image, _, _ = io.load_rgb(test_image_path) + assert loaded_image.shape == (50, 50, 3) + + +def test_convert_focallength(): + """Test focal length conversion.""" + f_px = io.convert_focallength(1920, 1080, 30.0) + assert f_px > 0 + assert isinstance(f_px, float) + + +def test_load_example_image(): + """Test loading the example image from tests/data directory.""" + example_path = Path(__file__).parent / "data" / "example.jpg" + if example_path.exists(): + image, _, f_px = io.load_rgb(example_path) + assert image.ndim == 3 + assert image.shape[2] == 3 + assert f_px > 0 + diff --git a/tests/test_predict.py b/tests/test_predict.py new file mode 100644 index 0000000..236410b --- /dev/null +++ b/tests/test_predict.py @@ -0,0 +1,75 @@ +import numpy as np +import pytest +import torch +from pathlib import Path + +from sharp.cli.predict import predict_image +from sharp.models import PredictorParams, create_predictor +from sharp.utils.gaussians import Gaussians3D + + +@pytest.mark.skipif( + not torch.cuda.is_available() and not torch.mps.is_available(), + reason="Requires CUDA or MPS for model inference", +) +def test_predict_image_with_model(tmp_path): + """Test predict_image function with a real model checkpoint.""" + example_path = Path(__file__).parent / "data" / "example.jpg" + if not example_path.exists(): + pytest.skip("Test image not found") + + from sharp.utils import io + + image, _, f_px = io.load_rgb(example_path) + + device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.mps.is_available() else "cpu") + + # Use the pre-cached model checkpoint + checkpoint_path = Path.home() / ".cache" / "torch" / "hub" / "checkpoints" / "sharp_2572gikvuh.pt" + if not checkpoint_path.exists(): + pytest.skip("Model checkpoint not found in cache") + + try: + predictor = create_predictor(PredictorParams(checkpoint_path=str(checkpoint_path))) + predictor.eval() + predictor.to(device) + + gaussians = predict_image(predictor, image, f_px, device) + + assert isinstance(gaussians, Gaussians3D) + assert gaussians.mean_vectors.shape[0] == 1 + assert gaussians.mean_vectors.shape[2] == 3 + assert gaussians.colors.shape[2] == 3 + assert gaussians.opacities.shape[1] > 0 + except Exception as e: + pytest.skip(f"Model inference failed (likely missing checkpoint): {e}") + + +def test_predict_image_signature(): + """Test that predict_image function has correct signature.""" + import inspect + + sig = inspect.signature(predict_image) + params = sig.parameters + + assert "predictor" in params + assert "image" in params + assert "f_px" in params + assert "device" in params + + +def test_create_predictor(): + """Test creating a predictor model.""" + params = PredictorParams() + predictor = create_predictor(params) + + assert predictor is not None + assert hasattr(predictor, "eval") + assert hasattr(predictor, "to") + + +def test_predictor_params(): + """Test PredictorParams creation.""" + params = PredictorParams() + assert params is not None +