diff --git a/pytissueoptics/__init__.py b/pytissueoptics/__init__.py index b90fc8af..5921717f 100644 --- a/pytissueoptics/__init__.py +++ b/pytissueoptics/__init__.py @@ -1,5 +1,6 @@ from importlib.metadata import version -from .scene import * # noqa: F403 + from .rayscattering import * # noqa: F403 +from .scene import * # noqa: F403 __version__ = version("pytissueoptics") diff --git a/pytissueoptics/examples/scene/example0.py b/pytissueoptics/examples/scene/example0.py index 7bf1cc1d..fae1ba52 100644 --- a/pytissueoptics/examples/scene/example0.py +++ b/pytissueoptics/examples/scene/example0.py @@ -10,13 +10,13 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, Ellipsoid, MayaviViewer, Sphere, Vector + from pytissueoptics.scene import Cuboid, Ellipsoid, Sphere, Vector, get3DViewer cuboid = Cuboid(a=1, b=3, c=1, position=Vector(1, 0, 0)) sphere = Sphere(radius=0.5, position=Vector(0, 0, 0)) ellipsoid = Ellipsoid(a=1.5, b=1, c=1, position=Vector(-2, 0, 0)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(cuboid, sphere, ellipsoid, representation="surface") viewer.show() diff --git a/pytissueoptics/examples/scene/example1.py b/pytissueoptics/examples/scene/example1.py index bf5b2f42..fa43ee49 100644 --- a/pytissueoptics/examples/scene/example1.py +++ b/pytissueoptics/examples/scene/example1.py @@ -8,7 +8,7 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, MayaviViewer, Vector + from pytissueoptics.scene import Cuboid, Vector, get3DViewer centerCube = Cuboid(a=1, b=1, c=1, position=Vector(0, 0, 0)) topCube = Cuboid(a=1, b=1, c=1, position=Vector(0, 2, 0)) @@ -17,7 +17,7 @@ def exampleCode(): bottomCube.translateTo(Vector(0, -2, 0)) centerCube.rotate(0, 30, 30) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(centerCube, topCube, bottomCube, representation="surface") viewer.show() diff --git a/pytissueoptics/examples/scene/example2.py b/pytissueoptics/examples/scene/example2.py index 28b866a4..b58415ee 100644 --- a/pytissueoptics/examples/scene/example2.py +++ b/pytissueoptics/examples/scene/example2.py @@ -7,20 +7,20 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, MayaviViewer, Vector + from pytissueoptics.scene import Cuboid, Vector, get3DViewer cuboid1 = Cuboid(1, 1, 1, position=Vector(2, 0, 0)) cuboid2 = Cuboid(2, 1, 1, position=Vector(0, 2, 0)) cuboid3 = Cuboid(3, 1, 1, position=Vector(0, 0, 2)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(cuboid1, cuboid2, cuboid3, representation="wireframe", lineWidth=5) viewer.show() cuboidStack = cuboid1.stack(cuboid2, onSurface="right") cuboidStack = cuboidStack.stack(cuboid3, onSurface="top") - viewer.clear() + viewer = get3DViewer() viewer.add(cuboidStack, representation="wireframe", lineWidth=5) viewer.show() diff --git a/pytissueoptics/examples/scene/example3.py b/pytissueoptics/examples/scene/example3.py index b61e0182..cf4b93f2 100644 --- a/pytissueoptics/examples/scene/example3.py +++ b/pytissueoptics/examples/scene/example3.py @@ -6,11 +6,11 @@ def exampleCode(): - from pytissueoptics.scene import MayaviViewer, loadSolid + from pytissueoptics.scene import get3DViewer, loadSolid solid = loadSolid("pytissueoptics/examples/scene/droid.obj") - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(solid, representation="surface", showNormals=True, normalLength=0.2) viewer.show() diff --git a/pytissueoptics/examples/scene/example4.py b/pytissueoptics/examples/scene/example4.py index 47fd42e7..0011fa72 100644 --- a/pytissueoptics/examples/scene/example4.py +++ b/pytissueoptics/examples/scene/example4.py @@ -7,13 +7,13 @@ def exampleCode(): from pytissueoptics.scene import ( - MayaviViewer, PlanoConcaveLens, PlanoConvexLens, RefractiveMaterial, SymmetricLens, ThickLens, Vector, + get3DViewer, ) material = RefractiveMaterial(refractiveIndex=1.44) @@ -22,7 +22,7 @@ def exampleCode(): lens2 = PlanoConvexLens(f=-60, diameter=25.4, thickness=4, material=material, position=Vector(0, 0, 20)) lens4 = PlanoConcaveLens(f=60, diameter=25.4, thickness=4, material=material, position=Vector(0, 0, 30)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(lens1, lens2, lens3, lens4, representation="surface", colormap="viridis", showNormals=False) viewer.show() diff --git a/pytissueoptics/rayscattering/display/utils/__init__.py b/pytissueoptics/rayscattering/display/utils/__init__.py index ae6087cb..8aa886cc 100644 --- a/pytissueoptics/rayscattering/display/utils/__init__.py +++ b/pytissueoptics/rayscattering/display/utils/__init__.py @@ -1,10 +1,8 @@ from .direction import DEFAULT_X_VIEW_DIRECTIONS, DEFAULT_Y_VIEW_DIRECTIONS, DEFAULT_Z_VIEW_DIRECTIONS, Direction -from .volumeSlicer import VolumeSlicer __all__ = [ "DEFAULT_X_VIEW_DIRECTIONS", "DEFAULT_Y_VIEW_DIRECTIONS", "DEFAULT_Z_VIEW_DIRECTIONS", "Direction", - "VolumeSlicer", ] diff --git a/pytissueoptics/rayscattering/display/viewer.py b/pytissueoptics/rayscattering/display/viewer.py index 1f21280c..614ca2e4 100644 --- a/pytissueoptics/rayscattering/display/viewer.py +++ b/pytissueoptics/rayscattering/display/viewer.py @@ -12,7 +12,7 @@ from pytissueoptics.rayscattering.scatteringScene import ScatteringScene from pytissueoptics.rayscattering.source import Source from pytissueoptics.rayscattering.statistics import Stats -from pytissueoptics.scene import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from pytissueoptics.scene import ViewPointStyle, get3DViewer class Visibility(Flag): @@ -112,11 +112,8 @@ def show3D( viewsLogScale: bool = True, viewsColormap: str = "viridis", ): - if not MAYAVI_AVAILABLE: - utils.warn("Package 'mayavi' is not available. Please install it to use 3D visualizations.") - return - - self._viewer3D = MayaviViewer(viewPointStyle=ViewPointStyle.OPTICS) + self._viewer3D = get3DViewer() + self._viewer3D.setViewPointStyle(ViewPointStyle.OPTICS) if visibility == Visibility.AUTO: visibility = Visibility.DEFAULT_3D if self._logger.has3D else Visibility.DEFAULT_2D @@ -146,10 +143,6 @@ def show3DVolumeSlicer( interpolate: bool = False, limits: Tuple[tuple, tuple, tuple] = None, ): - if not MAYAVI_AVAILABLE: - utils.warn("ERROR: Package 'mayavi' is not available. Please install it to use 3D visualizations.") - return - if not self._logger.has3D: utils.warn("ERROR: Cannot show 3D volume slicer without 3D data.") return @@ -182,10 +175,7 @@ def show3DVolumeSlicer( if logScale: hist = utils.logNorm(hist) - from pytissueoptics.rayscattering.display.utils.volumeSlicer import VolumeSlicer - - slicer = VolumeSlicer(hist, interpolate=interpolate) - slicer.show() + get3DViewer().showVolumeSlicer(hist, interpolate=interpolate) def show2D(self, view: View2D = None, viewIndex: int = None, logScale: bool = True, colormap: str = "viridis"): self._logger.showView(view=view, viewIndex=viewIndex, logScale=logScale, colormap=colormap) diff --git a/pytissueoptics/rayscattering/opencl/__init__.py b/pytissueoptics/rayscattering/opencl/__init__.py index e923c2fa..2d9606c5 100644 --- a/pytissueoptics/rayscattering/opencl/__init__.py +++ b/pytissueoptics/rayscattering/opencl/__init__.py @@ -1,6 +1,7 @@ +import os + from pytissueoptics.rayscattering.opencl.config.CLConfig import OPENCL_AVAILABLE, WEIGHT_THRESHOLD, CLConfig, warnings from pytissueoptics.rayscattering.opencl.config.IPPTable import IPPTable -import os OPENCL_OK = True OPENCL_DISABLED = os.environ.get("PTO_DISABLE_OPENCL", "0") == "1" diff --git a/pytissueoptics/rayscattering/scatteringScene.py b/pytissueoptics/rayscattering/scatteringScene.py index 7d223610..2cb05cbe 100644 --- a/pytissueoptics/rayscattering/scatteringScene.py +++ b/pytissueoptics/rayscattering/scatteringScene.py @@ -3,7 +3,7 @@ import numpy as np from pytissueoptics.rayscattering.materials import ScatteringMaterial -from pytissueoptics.scene import MayaviViewer, Scene, Vector +from pytissueoptics.scene import Scene, Vector, get3DViewer from pytissueoptics.scene.solids import Solid from pytissueoptics.scene.viewer.displayable import Displayable @@ -22,7 +22,7 @@ def add(self, solid: Solid, position: Vector = None): super().add(solid, position) def show(self, source: Displayable = None, opacity=0.8, colormap="cool", **kwargs): - viewer = MayaviViewer() + viewer = get3DViewer() self.addToViewer(viewer, opacity=opacity, colormap=colormap, **kwargs) if source: source.addToViewer(viewer) diff --git a/pytissueoptics/rayscattering/source.py b/pytissueoptics/rayscattering/source.py index 042f299f..59c0c1f1 100644 --- a/pytissueoptics/rayscattering/source.py +++ b/pytissueoptics/rayscattering/source.py @@ -18,7 +18,7 @@ from pytissueoptics.scene.solids.cone import Cone from pytissueoptics.scene.solids.cylinder import Cylinder from pytissueoptics.scene.utils import progressBar -from pytissueoptics.scene.viewer import Displayable, MayaviViewer +from pytissueoptics.scene.viewer import Abstract3DViewer, Displayable class Source(Displayable): @@ -184,7 +184,7 @@ def photons(self): def getPhotonCount(self) -> int: return self._N - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="Wistia", opacity=1.0, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="Wistia", opacity=1.0, **kwargs): sphere = Sphere(radius=self.displaySize / 2, position=self._position) viewer.add(sphere, representation=representation, colormap=colormap, opacity=opacity, **kwargs) @@ -227,7 +227,7 @@ def getInitialPositionsAndDirections(self) -> Tuple[np.ndarray, np.ndarray]: directions = self._getInitialDirections() return positions, directions - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="Wistia", opacity=1, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="Wistia", opacity=1, **kwargs): baseHeight = 0.5 * self.displaySize baseCenter = self._position + self._direction * baseHeight / 2 base = Cylinder(radius=self.displaySize / 8, length=baseHeight, position=baseCenter) diff --git a/pytissueoptics/rayscattering/tests/display/testViewer.py b/pytissueoptics/rayscattering/tests/display/testViewer.py index 6087c52c..468c4531 100644 --- a/pytissueoptics/rayscattering/tests/display/testViewer.py +++ b/pytissueoptics/rayscattering/tests/display/testViewer.py @@ -1,5 +1,5 @@ import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy as np from mockito import ANY, mock, verify, when @@ -13,12 +13,7 @@ from pytissueoptics.rayscattering.source import Source from pytissueoptics.scene.geometry import BoundingBox from pytissueoptics.scene.logger import Logger - - -def patchMayaviRender(func): - for module in ["show", "gcf", "figure", "clf", "triangular_mesh", "points3d"]: - func = patch("mayavi.mlab." + module)(func) - return func +from pytissueoptics.scene.viewer import Abstract3DViewer class TestViewer(unittest.TestCase): @@ -34,6 +29,11 @@ def setUp(self): self.logger.info = {"photonCount": 0, "sourceSolidLabel": None} self.viewer = Viewer(self.scene, self.source, self.logger) + self.mock3DViewer = Mock(spec=Abstract3DViewer) + p = patch("pytissueoptics.rayscattering.display.viewer.get3DViewer", return_value=self.mock3DViewer) + self.addCleanup(p.stop) + p.start() + def testGivenViewerWithBaseLogger_shouldRaiseException(self): with self.assertRaises(AssertionError): self.viewer = Viewer(self.scene, self.source, Logger()) @@ -101,32 +101,17 @@ def testWhenShow2DAllViewsWithViewGroup_shouldOnlyShowViewsOfThisGroup(self): verify(self.logger, times=1).showView(view=None, viewIndex=1, logScale=True, colormap=ANY(str)) verify(self.logger, times=0).showView(view=None, viewIndex=0, logScale=True, colormap=ANY(str)) - def testWhenShow3DWithoutMayaviInstalled_shouldWarnAndIgnore(self): - from pytissueoptics.rayscattering.display import viewer - - viewer.MAYAVI_AVAILABLE = False - with self.assertWarns(UserWarning): - self.viewer.show3D() - - @patchMayaviRender - def testWhenShow3DWithScene_shouldDisplayScene(self, mockShow, *args): + def testWhenShow3DWithScene_shouldDisplayScene(self): self.viewer.show3D(visibility=Visibility.SCENE) - verify(self.scene, times=1).addToViewer(...) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - def testWhenShow3DWithSource_shouldDisplaySource(self, mockShow, *args): + def testWhenShow3DWithSource_shouldDisplaySource(self): self.viewer.show3D(visibility=Visibility.SOURCE) - verify(self.source, times=1).addToViewer(...) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfaceLeaving( - self, mockAddDataPoints, mockShow, *args - ): + def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfaceLeaving(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 0, 0, 0]]) @@ -136,17 +121,15 @@ def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfac self.viewer.show3D(visibility=Visibility.POINT_CLOUD) - mockAddDataPoints.assert_called() - addedSolidPoints = mockAddDataPoints.call_args_list[0][0][0] - addedSurfacePoints = mockAddDataPoints.call_args_list[1][0][0] + self.mock3DViewer.addDataPoints.assert_called() + addedSolidPoints = self.mock3DViewer.addDataPoints.call_args_list[0][0][0] + addedSurfacePoints = self.mock3DViewer.addDataPoints.call_args_list[1][0][0] self.assertTrue(np.array_equal(addedSolidPoints, aPointCloud.solidPoints)) self.assertTrue(np.array_equal(addedSurfacePoints, aPointCloud.leavingSurfacePoints)) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self, mockAddDataPoints, mockShow, *args): + def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud() when(mockPointCloudFactory).getPointCloud(...).thenReturn(aPointCloud) @@ -154,12 +137,10 @@ def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self, mo self.viewer.show3D(visibility=Visibility.POINT_CLOUD) - mockAddDataPoints.assert_not_called() - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_not_called() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self, mockAddDataPoints, mockShow, *args): + def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 0, 0, 0]]) @@ -169,15 +150,13 @@ def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self, moc self.viewer.show3D(visibility=Visibility.POINT_CLOUD, pointCloudStyle=PointCloudStyle(showSolidPoints=False)) - mockAddDataPoints.assert_called_once() - self.assertTrue(np.array_equal(mockAddDataPoints.call_args[0][0], aPointCloud.leavingSurfacePoints)) - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_called_once() + self.assertTrue( + np.array_equal(self.mock3DViewer.addDataPoints.call_args[0][0], aPointCloud.leavingSurfacePoints) + ) + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurfacePoints( - self, mockAddDataPoints, mockShow, *args - ): + def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurfacePoints(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 1, 1, 1]]) @@ -192,30 +171,26 @@ def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurface ), ) - mockAddDataPoints.assert_called_once() - self.assertTrue(np.array_equal(mockAddDataPoints.call_args[0][0], aPointCloud.enteringSurfacePointsPositive)) - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_called_once() + self.assertTrue( + np.array_equal(self.mock3DViewer.addDataPoints.call_args[0][0], aPointCloud.enteringSurfacePointsPositive) + ) + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testWhenShow3DWithViews_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self, mockAddImage, mockShow, *args): + def testWhenShow3DWithViews_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self): self._givenLoggerWithXSceneView() sceneView = self.logger.views[0] self.viewer.show3D(visibility=Visibility.VIEWS) - mockAddImage.assert_called_once() - addedImage = mockAddImage.call_args[0][0] + self.mock3DViewer.addImage.assert_called_once() + addedImage = self.mock3DViewer.addImage.call_args[0][0] self.assertTrue(np.array_equal(sceneView.getImageDataWithDefaultAlignment(), addedImage)) - displayedPosition = mockAddImage.call_args[0][4] + displayedPosition = self.mock3DViewer.addImage.call_args[0][4] self.assertEqual(-2.1, displayedPosition) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay( - self, mockAddImage, mockShow, *args - ): + def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self): self._givenLoggerWithXSceneView() sceneView = self.logger.views[0] theViewIndex = 9 @@ -223,16 +198,14 @@ def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay( self.viewer.show3D(visibility=Visibility.VIEWS, viewsVisibility=[theViewIndex]) - mockAddImage.assert_called_once() - addedImage = mockAddImage.call_args[0][0] + self.mock3DViewer.addImage.assert_called_once() + addedImage = self.mock3DViewer.addImage.call_args[0][0] self.assertTrue(np.array_equal(sceneView.getImageDataWithDefaultAlignment(), addedImage)) - displayedPosition = mockAddImage.call_args[0][4] + displayedPosition = self.mock3DViewer.addImage.call_args[0][4] self.assertEqual(-2.1, displayedPosition) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self, mockAddImage, mockShow, *args): + def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud() when(mockPointCloudFactory).getPointCloud(...).thenReturn(aPointCloud) @@ -243,14 +216,10 @@ def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self, verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=1).getPointCloud(...) - mockAddImage.assert_not_called() - mockShow.assert_called_once() - - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud( - self, mockAddImage, mockShow, *args - ): + self.mock3DViewer.addImage.assert_not_called() + self.mock3DViewer.show.assert_called_once() + + def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud(self): self._givenLoggerWithXSceneView() self.logger.has3D = False @@ -263,14 +232,10 @@ def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud( verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=0).getPointCloud(...) - mockAddImage.assert_called() - mockShow.assert_called_once() - - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefault2D( - self, mockAddImage, mockShow, *args - ): + self.mock3DViewer.addImage.assert_called() + self.mock3DViewer.show.assert_called_once() + + def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefault2D(self): self._givenLoggerWithXSceneView() self.logger.has3D = False @@ -284,8 +249,8 @@ def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefa verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=0).getPointCloud(...) - mockAddImage.assert_called() - mockShow.assert_called_once() + self.mock3DViewer.addImage.assert_called() + self.mock3DViewer.show.assert_called_once() def _givenLoggerWithXSceneView(self): sceneView = View2DProjectionX() diff --git a/pytissueoptics/rayscattering/tests/testScatteringScene.py b/pytissueoptics/rayscattering/tests/testScatteringScene.py index 6c758fa5..e2a9a4e4 100644 --- a/pytissueoptics/rayscattering/tests/testScatteringScene.py +++ b/pytissueoptics/rayscattering/tests/testScatteringScene.py @@ -7,7 +7,7 @@ from pytissueoptics.rayscattering.materials import ScatteringMaterial from pytissueoptics.rayscattering.scatteringScene import ScatteringScene from pytissueoptics.scene.solids import Cuboid -from pytissueoptics.scene.viewer import MayaviViewer +from pytissueoptics.scene.viewer import Abstract3DViewer def patchMayaviShow(func): @@ -29,7 +29,7 @@ def testWhenAddingASolidWithNoScatteringMaterialDefined_shouldRaiseException(sel def testWhenAddToViewer_shouldAddAllSolidsToViewer(self): scene = ScatteringScene([Cuboid(1, 1, 1, material=ScatteringMaterial())]) - viewer = mock(MayaviViewer) + viewer = mock(Abstract3DViewer) when(viewer).add(...).thenReturn() scene.addToViewer(viewer) @@ -37,7 +37,7 @@ def testWhenAddToViewer_shouldAddAllSolidsToViewer(self): verify(viewer).add(*scene.solids, ...) @patchMayaviShow - def testWhenShow_shouldShowInsideMayaviViewer(self, mockShow, *args): + def testWhenShow_shouldShowInside3DViewer(self, mockShow, *args): scene = ScatteringScene([Cuboid(1, 1, 1, material=ScatteringMaterial())]) scene.show() diff --git a/pytissueoptics/scene/__init__.py b/pytissueoptics/scene/__init__.py index 4962affa..c3edf9b6 100644 --- a/pytissueoptics/scene/__init__.py +++ b/pytissueoptics/scene/__init__.py @@ -15,7 +15,7 @@ SymmetricLens, ThickLens, ) -from .viewer import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from .viewer import ViewPointStyle, get3DViewer __all__ = [ "Cuboid", @@ -28,8 +28,7 @@ "Scene", "Loader", "loadSolid", - "MayaviViewer", - "MAYAVI_AVAILABLE", + "get3DViewer", "ViewPointStyle", "Logger", "InteractionKey", diff --git a/pytissueoptics/scene/scene/scene.py b/pytissueoptics/scene/scene/scene.py index 3811e41e..838088c4 100644 --- a/pytissueoptics/scene/scene/scene.py +++ b/pytissueoptics/scene/scene/scene.py @@ -4,8 +4,7 @@ from pytissueoptics.scene.geometry import INTERFACE_KEY, BoundingBox, Environment, Polygon, Vector from pytissueoptics.scene.solids import Solid -from pytissueoptics.scene.viewer.displayable import Displayable -from pytissueoptics.scene.viewer.mayavi import MayaviViewer +from pytissueoptics.scene.viewer import Abstract3DViewer, Displayable class Scene(Displayable): @@ -33,7 +32,7 @@ def add(self, solid: Solid, position: Vector = None): def solids(self): return self._solids - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="bone", opacity=0.1, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="bone", opacity=0.1, **kwargs): viewer.add(*self.solids, representation=representation, colormap=colormap, opacity=opacity, **kwargs) def getWorldEnvironment(self) -> Environment: diff --git a/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py b/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py index f85428c6..413ae6db 100644 --- a/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py +++ b/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py @@ -33,7 +33,7 @@ from pytissueoptics.scene.tree.treeConstructor.binary.noSplitOneAxisConstructor import NoSplitOneAxisConstructor from pytissueoptics.scene.tree.treeConstructor.binary.noSplitThreeAxesConstructor import NoSplitThreeAxesConstructor from pytissueoptics.scene.tree.treeConstructor.binary.splitTreeAxesConstructor import SplitThreeAxesConstructor -from pytissueoptics.scene.viewer import MayaviViewer +from pytissueoptics.scene.viewer import get3DViewer pandas.set_option("display.max_columns", 20) pandas.set_option("display.width", 1200) @@ -153,7 +153,7 @@ def _runValidationReference( f" - {str('REFERENCE'):^10}" ) if display: - viewer = MayaviViewer() + viewer = get3DViewer() viewer.addLogger(logger) viewer.show() return missedRays @@ -192,7 +192,7 @@ def _runValidationForConstructor( f" - {partition.getAverageDepth():^10.2f} - {missedRays:^10} - {missedRays == referenceMissed:^10}", ) if display and missedRays != referenceMissed: - viewer = MayaviViewer() + viewer = get3DViewer() viewer.addLogger(logger) viewer.show() @@ -297,7 +297,7 @@ def displayStats(self): def displayBenchmarkTreeResults( self, objectsDisplay: bool = True, scenes: List[Scene] = None, objectsOpacity: float = 0.5 ): - viewer = MayaviViewer() + viewer = get3DViewer() if scenes is None: scenes = self.scenes for j, scene in enumerate(scenes): diff --git a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py similarity index 89% rename from pytissueoptics/scene/tests/viewer/testMayaviViewer.py rename to pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py index 1d750b29..908f524c 100644 --- a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py +++ b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py @@ -5,12 +5,11 @@ import numpy as np -from pytissueoptics import Logger +from pytissueoptics import Logger, ViewPointStyle from pytissueoptics.scene.geometry import Vector from pytissueoptics.scene.scene import Scene from pytissueoptics.scene.solids import Cuboid, Ellipsoid, Sphere from pytissueoptics.scene.tests import SHOW_VISUAL_TESTS, compareVisuals -from pytissueoptics.scene.viewer.mayavi import MayaviViewer, ViewPointStyle TEST_IMAGES_DIR = os.path.join(os.path.dirname(__file__), "testImages") @@ -26,9 +25,11 @@ def patchMayaviShow(func): @unittest.skipIf( not SHOW_VISUAL_TESTS, "Visual tests are disabled. Set scene.tests.SHOW_VISUAL_TESTS to True to enable them." ) -class TestMayaviViewer(unittest.TestCase): +class TestMayavi3DViewer(unittest.TestCase): def setUp(self): - self.viewer = MayaviViewer() + from pytissueoptics.scene.viewer.mayavi.mayavi3DViewer import Mayavi3DViewer + + self.viewer = Mayavi3DViewer() def testWhenAddLogger_shouldDrawAllLoggerComponents(self): logger = self._getTestLogger() @@ -36,12 +37,12 @@ def testWhenAddLogger_shouldDrawAllLoggerComponents(self): self._assertViewerDisplays("logger_natural") def testGivenOpticsViewPoint_shouldDisplayFromOpticsViewPoint(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.OPTICS) + self.viewer.setViewPointStyle(ViewPointStyle.OPTICS) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_optics") def testGivenNaturalFrontViewPoint_shouldDisplayFromNaturalFrontViewPoint(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.NATURAL_FRONT) + self.viewer.setViewPointStyle(ViewPointStyle.NATURAL_FRONT) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_natural_front") @@ -51,7 +52,7 @@ def testWhenAddSpecialTestSphere_shouldDrawCorrectly(self): self._assertViewerDisplays("sphere_normals") def testWhenAddImages_shouldDraw2DImagesCorrectly(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.NATURAL) + self.viewer.setViewPointStyle(ViewPointStyle.NATURAL) testImage = np.zeros((5, 5)) testImage[4, 4] = 1 for axis in range(3): diff --git a/pytissueoptics/scene/viewer/__init__.py b/pytissueoptics/scene/viewer/__init__.py index a3de150a..808efdb6 100644 --- a/pytissueoptics/scene/viewer/__init__.py +++ b/pytissueoptics/scene/viewer/__init__.py @@ -1,4 +1,6 @@ +from .abstract3DViewer import Abstract3DViewer from .displayable import Displayable -from .mayavi import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from .provider import get3DViewer +from .viewPoint import ViewPointStyle -__all__ = ["Displayable", "MAYAVI_AVAILABLE", "MayaviViewer", "ViewPointStyle"] +__all__ = ["Displayable", "get3DViewer", "Abstract3DViewer", "ViewPointStyle"] diff --git a/pytissueoptics/scene/viewer/abstract3DViewer.py b/pytissueoptics/scene/viewer/abstract3DViewer.py new file mode 100644 index 00000000..3611d787 --- /dev/null +++ b/pytissueoptics/scene/viewer/abstract3DViewer.py @@ -0,0 +1,58 @@ +from abc import abstractmethod + +import numpy as np + +from pytissueoptics.scene.solids import Solid + +from .viewPoint import ViewPointStyle + + +class Abstract3DViewer: + @abstractmethod + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): ... + + @abstractmethod + def add( + self, + *solids: Solid, + representation="wireframe", + lineWidth=0.25, + showNormals=False, + normalLength=0.3, + colormap="viridis", + reverseColormap=False, + colorWithPosition=False, + opacity=1, + **kwargs, + ): ... + + @abstractmethod + def addDataPoints( + self, + dataPoints: np.ndarray, + colormap="rainbow", + reverseColormap=False, + scale=0.15, + scaleWithValue=True, + asSpheres=True, + ): + """'dataPoints' has to be of shape (n, 4) where the second axis is (value, x, y, z).""" + ... + + @abstractmethod + def addImage( + self, + image: np.ndarray, + size: tuple = None, + minCorner: tuple = (0, 0), + axis: int = 2, + position: float = 0, + colormap: str = "viridis", + ): ... + + @staticmethod + @abstractmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap: str = "viridis", interpolate=False, **kwargs): ... + + @abstractmethod + def show(self): ... diff --git a/pytissueoptics/scene/viewer/displayable.py b/pytissueoptics/scene/viewer/displayable.py index 6d2544e2..7d9e82ed 100644 --- a/pytissueoptics/scene/viewer/displayable.py +++ b/pytissueoptics/scene/viewer/displayable.py @@ -1,6 +1,6 @@ from abc import abstractmethod -from pytissueoptics.scene.viewer.mayavi import MayaviViewer +from .provider import get3DViewer class Displayable: @@ -9,6 +9,6 @@ def addToViewer(self, viewer, **kwargs): pass def show(self, **kwargs): - viewer = MayaviViewer() + viewer = get3DViewer() self.addToViewer(viewer, **kwargs) viewer.show() diff --git a/pytissueoptics/scene/viewer/mayavi/__init__.py b/pytissueoptics/scene/viewer/mayavi/__init__.py index 217e9cc6..3177da28 100644 --- a/pytissueoptics/scene/viewer/mayavi/__init__.py +++ b/pytissueoptics/scene/viewer/mayavi/__init__.py @@ -1,13 +1,8 @@ from .mayaviSolid import MayaviObject, MayaviSolid from .mayaviTriangleMesh import MayaviTriangleMesh -from .mayaviViewer import MAYAVI_AVAILABLE, MayaviViewer -from .viewPoint import ViewPointStyle __all__ = [ "MayaviObject", "MayaviSolid", "MayaviTriangleMesh", - "MayaviViewer", - "MAYAVI_AVAILABLE", - "ViewPointStyle", ] diff --git a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py b/pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py similarity index 91% rename from pytissueoptics/scene/viewer/mayavi/mayaviViewer.py rename to pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py index 78ca9013..22da6b02 100644 --- a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py +++ b/pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py @@ -1,32 +1,29 @@ import numpy as np +from mayavi import mlab from pytissueoptics.scene.geometry import BoundingBox from pytissueoptics.scene.logger import Logger -from pytissueoptics.scene.viewer.mayavi.viewPoint import ViewPointFactory, ViewPointStyle - -try: - from mayavi import mlab - - MAYAVI_AVAILABLE = True -except ImportError: - MAYAVI_AVAILABLE = False - from pytissueoptics.scene.solids import Solid +from pytissueoptics.scene.viewer import Abstract3DViewer, ViewPointStyle +from pytissueoptics.scene.viewer.viewPoint import ViewPointFactory from .mayaviSolid import MayaviSolid -class MayaviViewer: - def __init__(self, viewPointStyle=ViewPointStyle.NATURAL): +class Mayavi3DViewer(Abstract3DViewer): + def __init__(self): self._scenes = { "DefaultScene": { "figureParameters": {"bgColor": (0.11, 0.11, 0.11), "fgColor": (0.9, 0.9, 0.9)}, "Solids": [], } } - self._viewPoint = ViewPointFactory().create(viewPointStyle) + self._viewPoint = ViewPointFactory().create(ViewPointStyle.NATURAL) self.clear() + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): + self._viewPoint = ViewPointFactory().create(viewPointStyle) + def add( self, *solids: "Solid", @@ -88,8 +85,8 @@ def addPoints(points: np.ndarray, colormap="rainbow", reverseColormap=False, sca s = mlab.points3d(x, y, z, mode=mode, scale_factor=scale, scale_mode="none", colormap=colormap) s.module_manager.scalar_lut_manager.reverse_lut = reverseColormap - @staticmethod def addDataPoints( + self, dataPoints: np.ndarray, colormap="rainbow", reverseColormap=False, @@ -118,8 +115,8 @@ def addSegments(segments: np.ndarray, colormap="rainbow", reverseColormap=False) s = mlab.plot3d(x, y, z, tube_radius=None, line_width=1, colormap=colormap) s.module_manager.scalar_lut_manager.reverse_lut = reverseColormap - @staticmethod def addImage( + self, image: np.ndarray, size: tuple = None, minCorner: tuple = (0, 0), @@ -203,3 +200,10 @@ def addBBox(self, bbox: BoundingBox, lineWidth=0.25, color=(1, 1, 1), opacity=1. opacity=0, ) mlab.outline(s, line_width=lineWidth, color=color, opacity=opacity, **kwargs) + + @staticmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap="viridis", interpolate=False, **kwargs): + from .mayaviVolumeSlicer import MayaviVolumeSlicer + + slicer = MayaviVolumeSlicer(hist3D, colormap=colormap, interpolate=interpolate, **kwargs) + slicer.show() diff --git a/pytissueoptics/rayscattering/display/utils/volumeSlicer.py b/pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py similarity index 97% rename from pytissueoptics/rayscattering/display/utils/volumeSlicer.py rename to pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py index aa2f30ec..648834a1 100644 --- a/pytissueoptics/rayscattering/display/utils/volumeSlicer.py +++ b/pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py @@ -35,7 +35,7 @@ VIEW = e -class VolumeSlicer(HasTraits): +class MayaviVolumeSlicer(HasTraits): # The data to plot data = Array() @@ -67,7 +67,7 @@ def __init__(self, hist3D: np.ndarray, colormap: str = "viridis", interpolate=Fa if isinstance(VIEW, Exception): raise VIEW - super(VolumeSlicer, self).__init__(data=hist3D, **traits) + super(MayaviVolumeSlicer, self).__init__(data=hist3D, **traits) # Force the creation of the image_plane_widgets: for ipw in (self.ipw_3d_x, self.ipw_3d_y, self.ipw_3d_z): @@ -184,5 +184,5 @@ def display_scene_z(self): x, y, z = np.ogrid[-5:5:64j, -5:5:64j, -5:5:64j] data = np.sin(3 * x) / x + 0.05 * z**2 + np.cos(3 * y) - m = VolumeSlicer(data) + m = MayaviVolumeSlicer(data) m.show() diff --git a/pytissueoptics/scene/viewer/null3DViewer.py b/pytissueoptics/scene/viewer/null3DViewer.py new file mode 100644 index 00000000..274ac000 --- /dev/null +++ b/pytissueoptics/scene/viewer/null3DViewer.py @@ -0,0 +1,56 @@ +import warnings + +import numpy as np +from scene.solids import Solid +from scene.viewer.abstract3DViewer import Abstract3DViewer + +from pytissueoptics import ViewPointStyle + + +class Null3DViewer(Abstract3DViewer): + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): + pass + + def add( + self, + *solids: Solid, + representation="wireframe", + lineWidth=0.25, + showNormals=False, + normalLength=0.3, + colormap="viridis", + reverseColormap=False, + colorWithPosition=False, + opacity=1, + **kwargs, + ): + pass + + def addDataPoints( + self, + dataPoints: np.ndarray, + colormap="rainbow", + reverseColormap=False, + scale=0.15, + scaleWithValue=True, + asSpheres=True, + ): + pass + + def addImage( + self, + image: np.ndarray, + size: tuple = None, + minCorner: tuple = (0, 0), + axis: int = 2, + position: float = 0, + colormap: str = "viridis", + ): + pass + + @staticmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap: str = "viridis", interpolate=False, **kwargs): + warnings.warn("Attempting to show a volume slicer with a Null3DViewer. No action will be taken.") + + def show(self): + warnings.warn("Attempting to show a Null3DViewer. No action will be taken.") diff --git a/pytissueoptics/scene/viewer/provider.py b/pytissueoptics/scene/viewer/provider.py new file mode 100644 index 00000000..80c651e1 --- /dev/null +++ b/pytissueoptics/scene/viewer/provider.py @@ -0,0 +1,30 @@ +import os +import warnings + +from .abstract3DViewer import Abstract3DViewer + +AVAILABLE_BACKENDS = ("mayavi", "null") + + +def get3DViewer() -> Abstract3DViewer: + backend = os.environ.get("PTO_3D_BACKEND", "mayavi").lower() + if backend == "mayavi": + try: + from .mayavi.mayavi3DViewer import Mayavi3DViewer + + return Mayavi3DViewer() + except Exception as e: + warnings.warn( + "Mayavi is not available. Falling back to a null 3D viewer. Fix the following error to use the Mayavi " + "backend or select another backend by setting the PTO_3D_BACKEND environment variable (available " + f"backends: {AVAILABLE_BACKENDS}). \n{e}" + ) + from .null3DViewer import Null3DViewer + + return Null3DViewer() + elif backend == "null": + from .null3DViewer import Null3DViewer + + return Null3DViewer() + else: + raise ValueError(f"Invalid backend '{backend}'. Available backends: {AVAILABLE_BACKENDS}") diff --git a/pytissueoptics/scene/viewer/mayavi/viewPoint.py b/pytissueoptics/scene/viewer/viewPoint.py similarity index 93% rename from pytissueoptics/scene/viewer/mayavi/viewPoint.py rename to pytissueoptics/scene/viewer/viewPoint.py index 010161ff..4728b967 100644 --- a/pytissueoptics/scene/viewer/mayavi/viewPoint.py +++ b/pytissueoptics/scene/viewer/viewPoint.py @@ -28,6 +28,8 @@ def create(self, viewPointStyle: ViewPointStyle): return self.getNaturalViewPoint() elif viewPointStyle == ViewPointStyle.NATURAL_FRONT: return self.getNaturalFrontViewPoint() + else: + raise ValueError(f"Invalid viewpoint style: {viewPointStyle}") @staticmethod def getOpticsViewPoint():