From d25021bb4ad268dd4b42fe03f9c199a0c57b8769 Mon Sep 17 00:00:00 2001 From: JLBegin Date: Tue, 22 Apr 2025 20:38:13 -0400 Subject: [PATCH 1/3] define single interface for 3D viewer backends --- .../rayscattering/display/utils/__init__.py | 2 - .../rayscattering/display/viewer.py | 8 +-- .../scene/tests/viewer/testMayaviViewer.py | 13 +++-- pytissueoptics/scene/viewer/__init__.py | 3 +- .../scene/viewer/abstract3DViewer.py | 57 +++++++++++++++++++ .../scene/viewer/mayavi/__init__.py | 2 - .../scene/viewer/mayavi/mayaviViewer.py | 24 ++++++-- .../viewer/mayavi/mayaviVolumeSlicer.py} | 6 +- .../scene/viewer/{mayavi => }/viewPoint.py | 2 + 9 files changed, 93 insertions(+), 24 deletions(-) create mode 100644 pytissueoptics/scene/viewer/abstract3DViewer.py rename pytissueoptics/{rayscattering/display/utils/volumeSlicer.py => scene/viewer/mayavi/mayaviVolumeSlicer.py} (97%) rename pytissueoptics/scene/viewer/{mayavi => }/viewPoint.py (93%) diff --git a/pytissueoptics/rayscattering/display/utils/__init__.py b/pytissueoptics/rayscattering/display/utils/__init__.py index ae6087cb..8aa886cc 100644 --- a/pytissueoptics/rayscattering/display/utils/__init__.py +++ b/pytissueoptics/rayscattering/display/utils/__init__.py @@ -1,10 +1,8 @@ from .direction import DEFAULT_X_VIEW_DIRECTIONS, DEFAULT_Y_VIEW_DIRECTIONS, DEFAULT_Z_VIEW_DIRECTIONS, Direction -from .volumeSlicer import VolumeSlicer __all__ = [ "DEFAULT_X_VIEW_DIRECTIONS", "DEFAULT_Y_VIEW_DIRECTIONS", "DEFAULT_Z_VIEW_DIRECTIONS", "Direction", - "VolumeSlicer", ] diff --git a/pytissueoptics/rayscattering/display/viewer.py b/pytissueoptics/rayscattering/display/viewer.py index 1f21280c..bd84c758 100644 --- a/pytissueoptics/rayscattering/display/viewer.py +++ b/pytissueoptics/rayscattering/display/viewer.py @@ -116,7 +116,8 @@ def show3D( utils.warn("Package 'mayavi' is not available. Please install it to use 3D visualizations.") return - self._viewer3D = MayaviViewer(viewPointStyle=ViewPointStyle.OPTICS) + self._viewer3D = MayaviViewer() + self._viewer3D.setViewPointStyle(ViewPointStyle.OPTICS) if visibility == Visibility.AUTO: visibility = Visibility.DEFAULT_3D if self._logger.has3D else Visibility.DEFAULT_2D @@ -182,10 +183,7 @@ def show3DVolumeSlicer( if logScale: hist = utils.logNorm(hist) - from pytissueoptics.rayscattering.display.utils.volumeSlicer import VolumeSlicer - - slicer = VolumeSlicer(hist, interpolate=interpolate) - slicer.show() + self._viewer3D.showVolumeSlicer(hist, interpolate=interpolate) def show2D(self, view: View2D = None, viewIndex: int = None, logScale: bool = True, colormap: str = "viridis"): self._logger.showView(view=view, viewIndex=viewIndex, logScale=logScale, colormap=colormap) diff --git a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py b/pytissueoptics/scene/tests/viewer/testMayaviViewer.py index 1d750b29..3cf91c03 100644 --- a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py +++ b/pytissueoptics/scene/tests/viewer/testMayaviViewer.py @@ -5,12 +5,12 @@ import numpy as np -from pytissueoptics import Logger +from pytissueoptics import Logger, ViewPointStyle from pytissueoptics.scene.geometry import Vector from pytissueoptics.scene.scene import Scene from pytissueoptics.scene.solids import Cuboid, Ellipsoid, Sphere from pytissueoptics.scene.tests import SHOW_VISUAL_TESTS, compareVisuals -from pytissueoptics.scene.viewer.mayavi import MayaviViewer, ViewPointStyle +from pytissueoptics.scene.viewer.mayavi import MayaviViewer TEST_IMAGES_DIR = os.path.join(os.path.dirname(__file__), "testImages") @@ -36,12 +36,14 @@ def testWhenAddLogger_shouldDrawAllLoggerComponents(self): self._assertViewerDisplays("logger_natural") def testGivenOpticsViewPoint_shouldDisplayFromOpticsViewPoint(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.OPTICS) + self.viewer = MayaviViewer() + self.viewer.setViewPointStyle(ViewPointStyle.OPTICS) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_optics") def testGivenNaturalFrontViewPoint_shouldDisplayFromNaturalFrontViewPoint(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.NATURAL_FRONT) + self.viewer = MayaviViewer() + self.viewer.setViewPointStyle(ViewPointStyle.NATURAL_FRONT) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_natural_front") @@ -51,7 +53,8 @@ def testWhenAddSpecialTestSphere_shouldDrawCorrectly(self): self._assertViewerDisplays("sphere_normals") def testWhenAddImages_shouldDraw2DImagesCorrectly(self): - self.viewer = MayaviViewer(viewPointStyle=ViewPointStyle.NATURAL) + self.viewer = MayaviViewer() + self.viewer.setViewPointStyle(ViewPointStyle.NATURAL) testImage = np.zeros((5, 5)) testImage[4, 4] = 1 for axis in range(3): diff --git a/pytissueoptics/scene/viewer/__init__.py b/pytissueoptics/scene/viewer/__init__.py index a3de150a..cc198448 100644 --- a/pytissueoptics/scene/viewer/__init__.py +++ b/pytissueoptics/scene/viewer/__init__.py @@ -1,4 +1,5 @@ from .displayable import Displayable -from .mayavi import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from .mayavi import MAYAVI_AVAILABLE, MayaviViewer +from .viewPoint import ViewPointStyle __all__ = ["Displayable", "MAYAVI_AVAILABLE", "MayaviViewer", "ViewPointStyle"] diff --git a/pytissueoptics/scene/viewer/abstract3DViewer.py b/pytissueoptics/scene/viewer/abstract3DViewer.py new file mode 100644 index 00000000..73204f1d --- /dev/null +++ b/pytissueoptics/scene/viewer/abstract3DViewer.py @@ -0,0 +1,57 @@ +from abc import abstractmethod + +import numpy as np + +from pytissueoptics.scene.solids import Solid +from .viewPoint import ViewPointStyle + + +class Abstract3DViewer: + @abstractmethod + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): ... + + @abstractmethod + def add( + self, + *solids: Solid, + representation="wireframe", + lineWidth=0.25, + showNormals=False, + normalLength=0.3, + colormap="viridis", + reverseColormap=False, + colorWithPosition=False, + opacity=1, + **kwargs, + ): ... + + @abstractmethod + def addDataPoints( + self, + dataPoints: np.ndarray, + colormap="rainbow", + reverseColormap=False, + scale=0.15, + scaleWithValue=True, + asSpheres=True, + ): + """'dataPoints' has to be of shape (n, 4) where the second axis is (value, x, y, z).""" + ... + + @abstractmethod + def addImage( + self, + image: np.ndarray, + size: tuple = None, + minCorner: tuple = (0, 0), + axis: int = 2, + position: float = 0, + colormap: str = "viridis", + ): ... + + @staticmethod + @abstractmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap: str = "viridis", interpolate=False, **kwargs): ... + + @abstractmethod + def show(self): ... diff --git a/pytissueoptics/scene/viewer/mayavi/__init__.py b/pytissueoptics/scene/viewer/mayavi/__init__.py index 217e9cc6..03718e73 100644 --- a/pytissueoptics/scene/viewer/mayavi/__init__.py +++ b/pytissueoptics/scene/viewer/mayavi/__init__.py @@ -1,7 +1,6 @@ from .mayaviSolid import MayaviObject, MayaviSolid from .mayaviTriangleMesh import MayaviTriangleMesh from .mayaviViewer import MAYAVI_AVAILABLE, MayaviViewer -from .viewPoint import ViewPointStyle __all__ = [ "MayaviObject", @@ -9,5 +8,4 @@ "MayaviTriangleMesh", "MayaviViewer", "MAYAVI_AVAILABLE", - "ViewPointStyle", ] diff --git a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py b/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py index 78ca9013..70742b3e 100644 --- a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py +++ b/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py @@ -2,7 +2,9 @@ from pytissueoptics.scene.geometry import BoundingBox from pytissueoptics.scene.logger import Logger -from pytissueoptics.scene.viewer.mayavi.viewPoint import ViewPointFactory, ViewPointStyle + +from ..abstract3DViewer import Abstract3DViewer +from ..viewPoint import ViewPointFactory, ViewPointStyle try: from mayavi import mlab @@ -16,17 +18,20 @@ from .mayaviSolid import MayaviSolid -class MayaviViewer: - def __init__(self, viewPointStyle=ViewPointStyle.NATURAL): +class MayaviViewer(Abstract3DViewer): + def __init__(self): self._scenes = { "DefaultScene": { "figureParameters": {"bgColor": (0.11, 0.11, 0.11), "fgColor": (0.9, 0.9, 0.9)}, "Solids": [], } } - self._viewPoint = ViewPointFactory().create(viewPointStyle) + self._viewPoint = ViewPointFactory().create(ViewPointStyle.NATURAL) self.clear() + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): + self._viewPoint = ViewPointFactory().create(viewPointStyle) + def add( self, *solids: "Solid", @@ -88,8 +93,8 @@ def addPoints(points: np.ndarray, colormap="rainbow", reverseColormap=False, sca s = mlab.points3d(x, y, z, mode=mode, scale_factor=scale, scale_mode="none", colormap=colormap) s.module_manager.scalar_lut_manager.reverse_lut = reverseColormap - @staticmethod def addDataPoints( + self, dataPoints: np.ndarray, colormap="rainbow", reverseColormap=False, @@ -118,8 +123,8 @@ def addSegments(segments: np.ndarray, colormap="rainbow", reverseColormap=False) s = mlab.plot3d(x, y, z, tube_radius=None, line_width=1, colormap=colormap) s.module_manager.scalar_lut_manager.reverse_lut = reverseColormap - @staticmethod def addImage( + self, image: np.ndarray, size: tuple = None, minCorner: tuple = (0, 0), @@ -203,3 +208,10 @@ def addBBox(self, bbox: BoundingBox, lineWidth=0.25, color=(1, 1, 1), opacity=1. opacity=0, ) mlab.outline(s, line_width=lineWidth, color=color, opacity=opacity, **kwargs) + + @staticmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap="viridis", interpolate=False, **kwargs): + from .mayaviVolumeSlicer import MayaviVolumeSlicer + + slicer = MayaviVolumeSlicer(hist3D, colormap=colormap, interpolate=interpolate, **kwargs) + slicer.show() diff --git a/pytissueoptics/rayscattering/display/utils/volumeSlicer.py b/pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py similarity index 97% rename from pytissueoptics/rayscattering/display/utils/volumeSlicer.py rename to pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py index aa2f30ec..648834a1 100644 --- a/pytissueoptics/rayscattering/display/utils/volumeSlicer.py +++ b/pytissueoptics/scene/viewer/mayavi/mayaviVolumeSlicer.py @@ -35,7 +35,7 @@ VIEW = e -class VolumeSlicer(HasTraits): +class MayaviVolumeSlicer(HasTraits): # The data to plot data = Array() @@ -67,7 +67,7 @@ def __init__(self, hist3D: np.ndarray, colormap: str = "viridis", interpolate=Fa if isinstance(VIEW, Exception): raise VIEW - super(VolumeSlicer, self).__init__(data=hist3D, **traits) + super(MayaviVolumeSlicer, self).__init__(data=hist3D, **traits) # Force the creation of the image_plane_widgets: for ipw in (self.ipw_3d_x, self.ipw_3d_y, self.ipw_3d_z): @@ -184,5 +184,5 @@ def display_scene_z(self): x, y, z = np.ogrid[-5:5:64j, -5:5:64j, -5:5:64j] data = np.sin(3 * x) / x + 0.05 * z**2 + np.cos(3 * y) - m = VolumeSlicer(data) + m = MayaviVolumeSlicer(data) m.show() diff --git a/pytissueoptics/scene/viewer/mayavi/viewPoint.py b/pytissueoptics/scene/viewer/viewPoint.py similarity index 93% rename from pytissueoptics/scene/viewer/mayavi/viewPoint.py rename to pytissueoptics/scene/viewer/viewPoint.py index 010161ff..4728b967 100644 --- a/pytissueoptics/scene/viewer/mayavi/viewPoint.py +++ b/pytissueoptics/scene/viewer/viewPoint.py @@ -28,6 +28,8 @@ def create(self, viewPointStyle: ViewPointStyle): return self.getNaturalViewPoint() elif viewPointStyle == ViewPointStyle.NATURAL_FRONT: return self.getNaturalFrontViewPoint() + else: + raise ValueError(f"Invalid viewpoint style: {viewPointStyle}") @staticmethod def getOpticsViewPoint(): From 3337a9ed10f92a71c82d35659fd77f32d4f9e940 Mon Sep 17 00:00:00 2001 From: JLBegin Date: Tue, 22 Apr 2025 21:54:42 -0400 Subject: [PATCH 2/3] wrap 3D viewer backend under provider safe mayavi import, null backend option to bypass mayavi, and option for other backends in the future. removed mentions and imports of mayavi in unrelated tests. --- pytissueoptics/__init__.py | 3 +- pytissueoptics/examples/scene/example0.py | 4 +- pytissueoptics/examples/scene/example1.py | 4 +- pytissueoptics/examples/scene/example2.py | 6 +- pytissueoptics/examples/scene/example3.py | 4 +- pytissueoptics/examples/scene/example4.py | 4 +- .../rayscattering/display/viewer.py | 14 +- .../rayscattering/opencl/__init__.py | 3 +- .../rayscattering/scatteringScene.py | 4 +- pytissueoptics/rayscattering/source.py | 6 +- .../rayscattering/tests/display/testViewer.py | 139 +++++++----------- .../tests/testScatteringScene.py | 6 +- pytissueoptics/scene/__init__.py | 5 +- pytissueoptics/scene/scene/scene.py | 5 +- .../benchmarkIntersectionFinder.py | 8 +- ...tMayaviViewer.py => testMayavi3DViewer.py} | 10 +- pytissueoptics/scene/viewer/__init__.py | 5 +- .../scene/viewer/abstract3DViewer.py | 1 + pytissueoptics/scene/viewer/displayable.py | 4 +- .../scene/viewer/mayavi/__init__.py | 3 - .../{mayaviViewer.py => mayavi3DViewer.py} | 16 +- pytissueoptics/scene/viewer/null3DViewer.py | 56 +++++++ pytissueoptics/scene/viewer/provider.py | 30 ++++ 23 files changed, 186 insertions(+), 154 deletions(-) rename pytissueoptics/scene/tests/viewer/{testMayaviViewer.py => testMayavi3DViewer.py} (94%) rename pytissueoptics/scene/viewer/mayavi/{mayaviViewer.py => mayavi3DViewer.py} (96%) create mode 100644 pytissueoptics/scene/viewer/null3DViewer.py create mode 100644 pytissueoptics/scene/viewer/provider.py diff --git a/pytissueoptics/__init__.py b/pytissueoptics/__init__.py index b90fc8af..5921717f 100644 --- a/pytissueoptics/__init__.py +++ b/pytissueoptics/__init__.py @@ -1,5 +1,6 @@ from importlib.metadata import version -from .scene import * # noqa: F403 + from .rayscattering import * # noqa: F403 +from .scene import * # noqa: F403 __version__ = version("pytissueoptics") diff --git a/pytissueoptics/examples/scene/example0.py b/pytissueoptics/examples/scene/example0.py index 7bf1cc1d..fae1ba52 100644 --- a/pytissueoptics/examples/scene/example0.py +++ b/pytissueoptics/examples/scene/example0.py @@ -10,13 +10,13 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, Ellipsoid, MayaviViewer, Sphere, Vector + from pytissueoptics.scene import Cuboid, Ellipsoid, Sphere, Vector, get3DViewer cuboid = Cuboid(a=1, b=3, c=1, position=Vector(1, 0, 0)) sphere = Sphere(radius=0.5, position=Vector(0, 0, 0)) ellipsoid = Ellipsoid(a=1.5, b=1, c=1, position=Vector(-2, 0, 0)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(cuboid, sphere, ellipsoid, representation="surface") viewer.show() diff --git a/pytissueoptics/examples/scene/example1.py b/pytissueoptics/examples/scene/example1.py index bf5b2f42..fa43ee49 100644 --- a/pytissueoptics/examples/scene/example1.py +++ b/pytissueoptics/examples/scene/example1.py @@ -8,7 +8,7 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, MayaviViewer, Vector + from pytissueoptics.scene import Cuboid, Vector, get3DViewer centerCube = Cuboid(a=1, b=1, c=1, position=Vector(0, 0, 0)) topCube = Cuboid(a=1, b=1, c=1, position=Vector(0, 2, 0)) @@ -17,7 +17,7 @@ def exampleCode(): bottomCube.translateTo(Vector(0, -2, 0)) centerCube.rotate(0, 30, 30) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(centerCube, topCube, bottomCube, representation="surface") viewer.show() diff --git a/pytissueoptics/examples/scene/example2.py b/pytissueoptics/examples/scene/example2.py index 28b866a4..b58415ee 100644 --- a/pytissueoptics/examples/scene/example2.py +++ b/pytissueoptics/examples/scene/example2.py @@ -7,20 +7,20 @@ def exampleCode(): - from pytissueoptics.scene import Cuboid, MayaviViewer, Vector + from pytissueoptics.scene import Cuboid, Vector, get3DViewer cuboid1 = Cuboid(1, 1, 1, position=Vector(2, 0, 0)) cuboid2 = Cuboid(2, 1, 1, position=Vector(0, 2, 0)) cuboid3 = Cuboid(3, 1, 1, position=Vector(0, 0, 2)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(cuboid1, cuboid2, cuboid3, representation="wireframe", lineWidth=5) viewer.show() cuboidStack = cuboid1.stack(cuboid2, onSurface="right") cuboidStack = cuboidStack.stack(cuboid3, onSurface="top") - viewer.clear() + viewer = get3DViewer() viewer.add(cuboidStack, representation="wireframe", lineWidth=5) viewer.show() diff --git a/pytissueoptics/examples/scene/example3.py b/pytissueoptics/examples/scene/example3.py index b61e0182..cf4b93f2 100644 --- a/pytissueoptics/examples/scene/example3.py +++ b/pytissueoptics/examples/scene/example3.py @@ -6,11 +6,11 @@ def exampleCode(): - from pytissueoptics.scene import MayaviViewer, loadSolid + from pytissueoptics.scene import get3DViewer, loadSolid solid = loadSolid("pytissueoptics/examples/scene/droid.obj") - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(solid, representation="surface", showNormals=True, normalLength=0.2) viewer.show() diff --git a/pytissueoptics/examples/scene/example4.py b/pytissueoptics/examples/scene/example4.py index 47fd42e7..0011fa72 100644 --- a/pytissueoptics/examples/scene/example4.py +++ b/pytissueoptics/examples/scene/example4.py @@ -7,13 +7,13 @@ def exampleCode(): from pytissueoptics.scene import ( - MayaviViewer, PlanoConcaveLens, PlanoConvexLens, RefractiveMaterial, SymmetricLens, ThickLens, Vector, + get3DViewer, ) material = RefractiveMaterial(refractiveIndex=1.44) @@ -22,7 +22,7 @@ def exampleCode(): lens2 = PlanoConvexLens(f=-60, diameter=25.4, thickness=4, material=material, position=Vector(0, 0, 20)) lens4 = PlanoConcaveLens(f=60, diameter=25.4, thickness=4, material=material, position=Vector(0, 0, 30)) - viewer = MayaviViewer() + viewer = get3DViewer() viewer.add(lens1, lens2, lens3, lens4, representation="surface", colormap="viridis", showNormals=False) viewer.show() diff --git a/pytissueoptics/rayscattering/display/viewer.py b/pytissueoptics/rayscattering/display/viewer.py index bd84c758..614ca2e4 100644 --- a/pytissueoptics/rayscattering/display/viewer.py +++ b/pytissueoptics/rayscattering/display/viewer.py @@ -12,7 +12,7 @@ from pytissueoptics.rayscattering.scatteringScene import ScatteringScene from pytissueoptics.rayscattering.source import Source from pytissueoptics.rayscattering.statistics import Stats -from pytissueoptics.scene import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from pytissueoptics.scene import ViewPointStyle, get3DViewer class Visibility(Flag): @@ -112,11 +112,7 @@ def show3D( viewsLogScale: bool = True, viewsColormap: str = "viridis", ): - if not MAYAVI_AVAILABLE: - utils.warn("Package 'mayavi' is not available. Please install it to use 3D visualizations.") - return - - self._viewer3D = MayaviViewer() + self._viewer3D = get3DViewer() self._viewer3D.setViewPointStyle(ViewPointStyle.OPTICS) if visibility == Visibility.AUTO: @@ -147,10 +143,6 @@ def show3DVolumeSlicer( interpolate: bool = False, limits: Tuple[tuple, tuple, tuple] = None, ): - if not MAYAVI_AVAILABLE: - utils.warn("ERROR: Package 'mayavi' is not available. Please install it to use 3D visualizations.") - return - if not self._logger.has3D: utils.warn("ERROR: Cannot show 3D volume slicer without 3D data.") return @@ -183,7 +175,7 @@ def show3DVolumeSlicer( if logScale: hist = utils.logNorm(hist) - self._viewer3D.showVolumeSlicer(hist, interpolate=interpolate) + get3DViewer().showVolumeSlicer(hist, interpolate=interpolate) def show2D(self, view: View2D = None, viewIndex: int = None, logScale: bool = True, colormap: str = "viridis"): self._logger.showView(view=view, viewIndex=viewIndex, logScale=logScale, colormap=colormap) diff --git a/pytissueoptics/rayscattering/opencl/__init__.py b/pytissueoptics/rayscattering/opencl/__init__.py index e923c2fa..2d9606c5 100644 --- a/pytissueoptics/rayscattering/opencl/__init__.py +++ b/pytissueoptics/rayscattering/opencl/__init__.py @@ -1,6 +1,7 @@ +import os + from pytissueoptics.rayscattering.opencl.config.CLConfig import OPENCL_AVAILABLE, WEIGHT_THRESHOLD, CLConfig, warnings from pytissueoptics.rayscattering.opencl.config.IPPTable import IPPTable -import os OPENCL_OK = True OPENCL_DISABLED = os.environ.get("PTO_DISABLE_OPENCL", "0") == "1" diff --git a/pytissueoptics/rayscattering/scatteringScene.py b/pytissueoptics/rayscattering/scatteringScene.py index 7d223610..2cb05cbe 100644 --- a/pytissueoptics/rayscattering/scatteringScene.py +++ b/pytissueoptics/rayscattering/scatteringScene.py @@ -3,7 +3,7 @@ import numpy as np from pytissueoptics.rayscattering.materials import ScatteringMaterial -from pytissueoptics.scene import MayaviViewer, Scene, Vector +from pytissueoptics.scene import Scene, Vector, get3DViewer from pytissueoptics.scene.solids import Solid from pytissueoptics.scene.viewer.displayable import Displayable @@ -22,7 +22,7 @@ def add(self, solid: Solid, position: Vector = None): super().add(solid, position) def show(self, source: Displayable = None, opacity=0.8, colormap="cool", **kwargs): - viewer = MayaviViewer() + viewer = get3DViewer() self.addToViewer(viewer, opacity=opacity, colormap=colormap, **kwargs) if source: source.addToViewer(viewer) diff --git a/pytissueoptics/rayscattering/source.py b/pytissueoptics/rayscattering/source.py index 042f299f..59c0c1f1 100644 --- a/pytissueoptics/rayscattering/source.py +++ b/pytissueoptics/rayscattering/source.py @@ -18,7 +18,7 @@ from pytissueoptics.scene.solids.cone import Cone from pytissueoptics.scene.solids.cylinder import Cylinder from pytissueoptics.scene.utils import progressBar -from pytissueoptics.scene.viewer import Displayable, MayaviViewer +from pytissueoptics.scene.viewer import Abstract3DViewer, Displayable class Source(Displayable): @@ -184,7 +184,7 @@ def photons(self): def getPhotonCount(self) -> int: return self._N - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="Wistia", opacity=1.0, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="Wistia", opacity=1.0, **kwargs): sphere = Sphere(radius=self.displaySize / 2, position=self._position) viewer.add(sphere, representation=representation, colormap=colormap, opacity=opacity, **kwargs) @@ -227,7 +227,7 @@ def getInitialPositionsAndDirections(self) -> Tuple[np.ndarray, np.ndarray]: directions = self._getInitialDirections() return positions, directions - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="Wistia", opacity=1, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="Wistia", opacity=1, **kwargs): baseHeight = 0.5 * self.displaySize baseCenter = self._position + self._direction * baseHeight / 2 base = Cylinder(radius=self.displaySize / 8, length=baseHeight, position=baseCenter) diff --git a/pytissueoptics/rayscattering/tests/display/testViewer.py b/pytissueoptics/rayscattering/tests/display/testViewer.py index 6087c52c..4115411a 100644 --- a/pytissueoptics/rayscattering/tests/display/testViewer.py +++ b/pytissueoptics/rayscattering/tests/display/testViewer.py @@ -1,8 +1,9 @@ import unittest -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy as np from mockito import ANY, mock, verify, when +from scene.viewer import Abstract3DViewer from pytissueoptics import Direction, View2DProjectionX, ViewGroup from pytissueoptics.rayscattering.display.profiles import Profile1D, ProfileFactory @@ -15,12 +16,6 @@ from pytissueoptics.scene.logger import Logger -def patchMayaviRender(func): - for module in ["show", "gcf", "figure", "clf", "triangular_mesh", "points3d"]: - func = patch("mayavi.mlab." + module)(func) - return func - - class TestViewer(unittest.TestCase): def setUp(self): self.scene = mock(ScatteringScene) @@ -34,6 +29,11 @@ def setUp(self): self.logger.info = {"photonCount": 0, "sourceSolidLabel": None} self.viewer = Viewer(self.scene, self.source, self.logger) + self.mock3DViewer = Mock(spec=Abstract3DViewer) + p = patch("pytissueoptics.rayscattering.display.viewer.get3DViewer", return_value=self.mock3DViewer) + self.addCleanup(p.stop) + p.start() + def testGivenViewerWithBaseLogger_shouldRaiseException(self): with self.assertRaises(AssertionError): self.viewer = Viewer(self.scene, self.source, Logger()) @@ -101,32 +101,17 @@ def testWhenShow2DAllViewsWithViewGroup_shouldOnlyShowViewsOfThisGroup(self): verify(self.logger, times=1).showView(view=None, viewIndex=1, logScale=True, colormap=ANY(str)) verify(self.logger, times=0).showView(view=None, viewIndex=0, logScale=True, colormap=ANY(str)) - def testWhenShow3DWithoutMayaviInstalled_shouldWarnAndIgnore(self): - from pytissueoptics.rayscattering.display import viewer - - viewer.MAYAVI_AVAILABLE = False - with self.assertWarns(UserWarning): - self.viewer.show3D() - - @patchMayaviRender - def testWhenShow3DWithScene_shouldDisplayScene(self, mockShow, *args): + def testWhenShow3DWithScene_shouldDisplayScene(self): self.viewer.show3D(visibility=Visibility.SCENE) - verify(self.scene, times=1).addToViewer(...) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - def testWhenShow3DWithSource_shouldDisplaySource(self, mockShow, *args): + def testWhenShow3DWithSource_shouldDisplaySource(self): self.viewer.show3D(visibility=Visibility.SOURCE) - verify(self.source, times=1).addToViewer(...) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfaceLeaving( - self, mockAddDataPoints, mockShow, *args - ): + def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfaceLeaving(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 0, 0, 0]]) @@ -136,17 +121,15 @@ def testWhenShow3DWithDefaultPointCloud_shouldDisplayPointCloudOfSolidsAndSurfac self.viewer.show3D(visibility=Visibility.POINT_CLOUD) - mockAddDataPoints.assert_called() - addedSolidPoints = mockAddDataPoints.call_args_list[0][0][0] - addedSurfacePoints = mockAddDataPoints.call_args_list[1][0][0] + self.mock3DViewer.addDataPoints.assert_called() + addedSolidPoints = self.mock3DViewer.addDataPoints.call_args_list[0][0][0] + addedSurfacePoints = self.mock3DViewer.addDataPoints.call_args_list[1][0][0] self.assertTrue(np.array_equal(addedSolidPoints, aPointCloud.solidPoints)) self.assertTrue(np.array_equal(addedSurfacePoints, aPointCloud.leavingSurfacePoints)) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self, mockAddDataPoints, mockShow, *args): + def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud() when(mockPointCloudFactory).getPointCloud(...).thenReturn(aPointCloud) @@ -154,12 +137,10 @@ def testGivenNoData_whenShow3DWithPointCloud_shouldNotDisplayPointCloud(self, mo self.viewer.show3D(visibility=Visibility.POINT_CLOUD) - mockAddDataPoints.assert_not_called() - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_not_called() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self, mockAddDataPoints, mockShow, *args): + def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 0, 0, 0]]) @@ -169,15 +150,13 @@ def testWhenShow3DWithSurfacePointCloud_shouldOnlyDisplaySurfacePoints(self, moc self.viewer.show3D(visibility=Visibility.POINT_CLOUD, pointCloudStyle=PointCloudStyle(showSolidPoints=False)) - mockAddDataPoints.assert_called_once() - self.assertTrue(np.array_equal(mockAddDataPoints.call_args[0][0], aPointCloud.leavingSurfacePoints)) - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_called_once() + self.assertTrue( + np.array_equal(self.mock3DViewer.addDataPoints.call_args[0][0], aPointCloud.leavingSurfacePoints) + ) + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addDataPoints") - def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurfacePoints( - self, mockAddDataPoints, mockShow, *args - ): + def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurfacePoints(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud( solidPoints=np.array([[0.5, 0, 0, 0]]), surfacePoints=np.array([[1, 0, 0, 0], [-1, 1, 1, 1]]) @@ -192,30 +171,26 @@ def testWhenShow3DWithEnteringSurfacePointCloud_shouldOnlyDisplayEnteringSurface ), ) - mockAddDataPoints.assert_called_once() - self.assertTrue(np.array_equal(mockAddDataPoints.call_args[0][0], aPointCloud.enteringSurfacePointsPositive)) - mockShow.assert_called_once() + self.mock3DViewer.addDataPoints.assert_called_once() + self.assertTrue( + np.array_equal(self.mock3DViewer.addDataPoints.call_args[0][0], aPointCloud.enteringSurfacePointsPositive) + ) + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testWhenShow3DWithViews_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self, mockAddImage, mockShow, *args): + def testWhenShow3DWithViews_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self): self._givenLoggerWithXSceneView() sceneView = self.logger.views[0] self.viewer.show3D(visibility=Visibility.VIEWS) - mockAddImage.assert_called_once() - addedImage = mockAddImage.call_args[0][0] + self.mock3DViewer.addImage.assert_called_once() + addedImage = self.mock3DViewer.addImage.call_args[0][0] self.assertTrue(np.array_equal(sceneView.getImageDataWithDefaultAlignment(), addedImage)) - displayedPosition = mockAddImage.call_args[0][4] + displayedPosition = self.mock3DViewer.addImage.call_args[0][4] self.assertEqual(-2.1, displayedPosition) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay( - self, mockAddImage, mockShow, *args - ): + def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay(self): self._givenLoggerWithXSceneView() sceneView = self.logger.views[0] theViewIndex = 9 @@ -223,16 +198,14 @@ def testWhenShow3DWithViewsIndexList_shouldAdd2DImageOfTheseViewsInThe3DDisplay( self.viewer.show3D(visibility=Visibility.VIEWS, viewsVisibility=[theViewIndex]) - mockAddImage.assert_called_once() - addedImage = mockAddImage.call_args[0][0] + self.mock3DViewer.addImage.assert_called_once() + addedImage = self.mock3DViewer.addImage.call_args[0][0] self.assertTrue(np.array_equal(sceneView.getImageDataWithDefaultAlignment(), addedImage)) - displayedPosition = mockAddImage.call_args[0][4] + displayedPosition = self.mock3DViewer.addImage.call_args[0][4] self.assertEqual(-2.1, displayedPosition) - mockShow.assert_called_once() + self.mock3DViewer.show.assert_called_once() - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self, mockAddImage, mockShow, *args): + def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self): mockPointCloudFactory = mock(PointCloudFactory) aPointCloud = PointCloud() when(mockPointCloudFactory).getPointCloud(...).thenReturn(aPointCloud) @@ -243,14 +216,10 @@ def testGiven3DLogger_whenShow3DDefault_shouldDisplayEverythingExceptViews(self, verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=1).getPointCloud(...) - mockAddImage.assert_not_called() - mockShow.assert_called_once() - - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud( - self, mockAddImage, mockShow, *args - ): + self.mock3DViewer.addImage.assert_not_called() + self.mock3DViewer.show.assert_called_once() + + def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud(self): self._givenLoggerWithXSceneView() self.logger.has3D = False @@ -263,14 +232,10 @@ def testGiven2DLogger_whenShow3DDefault_shouldDisplayEverythingExceptPointCloud( verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=0).getPointCloud(...) - mockAddImage.assert_called() - mockShow.assert_called_once() - - @patchMayaviRender - @patch("pytissueoptics.scene.viewer.MayaviViewer.addImage") - def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefault2D( - self, mockAddImage, mockShow, *args - ): + self.mock3DViewer.addImage.assert_called() + self.mock3DViewer.show.assert_called_once() + + def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefault2D(self): self._givenLoggerWithXSceneView() self.logger.has3D = False @@ -284,8 +249,8 @@ def testGiven2DLogger_whenShow3DWithDefault3DVisibility_shouldWarnAndDisplayDefa verify(self.source, times=1).addToViewer(...) verify(self.scene, times=1).addToViewer(...) verify(mockPointCloudFactory, times=0).getPointCloud(...) - mockAddImage.assert_called() - mockShow.assert_called_once() + self.mock3DViewer.addImage.assert_called() + self.mock3DViewer.show.assert_called_once() def _givenLoggerWithXSceneView(self): sceneView = View2DProjectionX() diff --git a/pytissueoptics/rayscattering/tests/testScatteringScene.py b/pytissueoptics/rayscattering/tests/testScatteringScene.py index 6c758fa5..b4bee288 100644 --- a/pytissueoptics/rayscattering/tests/testScatteringScene.py +++ b/pytissueoptics/rayscattering/tests/testScatteringScene.py @@ -3,11 +3,11 @@ from unittest.mock import patch from mockito import mock, verify, when +from scene.viewer import Abstract3DViewer from pytissueoptics.rayscattering.materials import ScatteringMaterial from pytissueoptics.rayscattering.scatteringScene import ScatteringScene from pytissueoptics.scene.solids import Cuboid -from pytissueoptics.scene.viewer import MayaviViewer def patchMayaviShow(func): @@ -29,7 +29,7 @@ def testWhenAddingASolidWithNoScatteringMaterialDefined_shouldRaiseException(sel def testWhenAddToViewer_shouldAddAllSolidsToViewer(self): scene = ScatteringScene([Cuboid(1, 1, 1, material=ScatteringMaterial())]) - viewer = mock(MayaviViewer) + viewer = mock(Abstract3DViewer) when(viewer).add(...).thenReturn() scene.addToViewer(viewer) @@ -37,7 +37,7 @@ def testWhenAddToViewer_shouldAddAllSolidsToViewer(self): verify(viewer).add(*scene.solids, ...) @patchMayaviShow - def testWhenShow_shouldShowInsideMayaviViewer(self, mockShow, *args): + def testWhenShow_shouldShowInside3DViewer(self, mockShow, *args): scene = ScatteringScene([Cuboid(1, 1, 1, material=ScatteringMaterial())]) scene.show() diff --git a/pytissueoptics/scene/__init__.py b/pytissueoptics/scene/__init__.py index 4962affa..c3edf9b6 100644 --- a/pytissueoptics/scene/__init__.py +++ b/pytissueoptics/scene/__init__.py @@ -15,7 +15,7 @@ SymmetricLens, ThickLens, ) -from .viewer import MAYAVI_AVAILABLE, MayaviViewer, ViewPointStyle +from .viewer import ViewPointStyle, get3DViewer __all__ = [ "Cuboid", @@ -28,8 +28,7 @@ "Scene", "Loader", "loadSolid", - "MayaviViewer", - "MAYAVI_AVAILABLE", + "get3DViewer", "ViewPointStyle", "Logger", "InteractionKey", diff --git a/pytissueoptics/scene/scene/scene.py b/pytissueoptics/scene/scene/scene.py index 3811e41e..838088c4 100644 --- a/pytissueoptics/scene/scene/scene.py +++ b/pytissueoptics/scene/scene/scene.py @@ -4,8 +4,7 @@ from pytissueoptics.scene.geometry import INTERFACE_KEY, BoundingBox, Environment, Polygon, Vector from pytissueoptics.scene.solids import Solid -from pytissueoptics.scene.viewer.displayable import Displayable -from pytissueoptics.scene.viewer.mayavi import MayaviViewer +from pytissueoptics.scene.viewer import Abstract3DViewer, Displayable class Scene(Displayable): @@ -33,7 +32,7 @@ def add(self, solid: Solid, position: Vector = None): def solids(self): return self._solids - def addToViewer(self, viewer: MayaviViewer, representation="surface", colormap="bone", opacity=0.1, **kwargs): + def addToViewer(self, viewer: Abstract3DViewer, representation="surface", colormap="bone", opacity=0.1, **kwargs): viewer.add(*self.solids, representation=representation, colormap=colormap, opacity=opacity, **kwargs) def getWorldEnvironment(self) -> Environment: diff --git a/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py b/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py index f85428c6..413ae6db 100644 --- a/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py +++ b/pytissueoptics/scene/tests/intersection/benchmarkIntersectionFinder.py @@ -33,7 +33,7 @@ from pytissueoptics.scene.tree.treeConstructor.binary.noSplitOneAxisConstructor import NoSplitOneAxisConstructor from pytissueoptics.scene.tree.treeConstructor.binary.noSplitThreeAxesConstructor import NoSplitThreeAxesConstructor from pytissueoptics.scene.tree.treeConstructor.binary.splitTreeAxesConstructor import SplitThreeAxesConstructor -from pytissueoptics.scene.viewer import MayaviViewer +from pytissueoptics.scene.viewer import get3DViewer pandas.set_option("display.max_columns", 20) pandas.set_option("display.width", 1200) @@ -153,7 +153,7 @@ def _runValidationReference( f" - {str('REFERENCE'):^10}" ) if display: - viewer = MayaviViewer() + viewer = get3DViewer() viewer.addLogger(logger) viewer.show() return missedRays @@ -192,7 +192,7 @@ def _runValidationForConstructor( f" - {partition.getAverageDepth():^10.2f} - {missedRays:^10} - {missedRays == referenceMissed:^10}", ) if display and missedRays != referenceMissed: - viewer = MayaviViewer() + viewer = get3DViewer() viewer.addLogger(logger) viewer.show() @@ -297,7 +297,7 @@ def displayStats(self): def displayBenchmarkTreeResults( self, objectsDisplay: bool = True, scenes: List[Scene] = None, objectsOpacity: float = 0.5 ): - viewer = MayaviViewer() + viewer = get3DViewer() if scenes is None: scenes = self.scenes for j, scene in enumerate(scenes): diff --git a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py similarity index 94% rename from pytissueoptics/scene/tests/viewer/testMayaviViewer.py rename to pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py index 3cf91c03..c412cd67 100644 --- a/pytissueoptics/scene/tests/viewer/testMayaviViewer.py +++ b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py @@ -10,7 +10,6 @@ from pytissueoptics.scene.scene import Scene from pytissueoptics.scene.solids import Cuboid, Ellipsoid, Sphere from pytissueoptics.scene.tests import SHOW_VISUAL_TESTS, compareVisuals -from pytissueoptics.scene.viewer.mayavi import MayaviViewer TEST_IMAGES_DIR = os.path.join(os.path.dirname(__file__), "testImages") @@ -26,9 +25,11 @@ def patchMayaviShow(func): @unittest.skipIf( not SHOW_VISUAL_TESTS, "Visual tests are disabled. Set scene.tests.SHOW_VISUAL_TESTS to True to enable them." ) -class TestMayaviViewer(unittest.TestCase): +class TestMayavi3DViewer(unittest.TestCase): def setUp(self): - self.viewer = MayaviViewer() + from scene.viewer.mayavi.mayavi3DViewer import Mayavi3DViewer + + self.viewer = Mayavi3DViewer() def testWhenAddLogger_shouldDrawAllLoggerComponents(self): logger = self._getTestLogger() @@ -36,13 +37,11 @@ def testWhenAddLogger_shouldDrawAllLoggerComponents(self): self._assertViewerDisplays("logger_natural") def testGivenOpticsViewPoint_shouldDisplayFromOpticsViewPoint(self): - self.viewer = MayaviViewer() self.viewer.setViewPointStyle(ViewPointStyle.OPTICS) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_optics") def testGivenNaturalFrontViewPoint_shouldDisplayFromNaturalFrontViewPoint(self): - self.viewer = MayaviViewer() self.viewer.setViewPointStyle(ViewPointStyle.NATURAL_FRONT) self.viewer.add(self._getSimpleSolid()) self._assertViewerDisplays("solid_natural_front") @@ -53,7 +52,6 @@ def testWhenAddSpecialTestSphere_shouldDrawCorrectly(self): self._assertViewerDisplays("sphere_normals") def testWhenAddImages_shouldDraw2DImagesCorrectly(self): - self.viewer = MayaviViewer() self.viewer.setViewPointStyle(ViewPointStyle.NATURAL) testImage = np.zeros((5, 5)) testImage[4, 4] = 1 diff --git a/pytissueoptics/scene/viewer/__init__.py b/pytissueoptics/scene/viewer/__init__.py index cc198448..808efdb6 100644 --- a/pytissueoptics/scene/viewer/__init__.py +++ b/pytissueoptics/scene/viewer/__init__.py @@ -1,5 +1,6 @@ +from .abstract3DViewer import Abstract3DViewer from .displayable import Displayable -from .mayavi import MAYAVI_AVAILABLE, MayaviViewer +from .provider import get3DViewer from .viewPoint import ViewPointStyle -__all__ = ["Displayable", "MAYAVI_AVAILABLE", "MayaviViewer", "ViewPointStyle"] +__all__ = ["Displayable", "get3DViewer", "Abstract3DViewer", "ViewPointStyle"] diff --git a/pytissueoptics/scene/viewer/abstract3DViewer.py b/pytissueoptics/scene/viewer/abstract3DViewer.py index 73204f1d..3611d787 100644 --- a/pytissueoptics/scene/viewer/abstract3DViewer.py +++ b/pytissueoptics/scene/viewer/abstract3DViewer.py @@ -3,6 +3,7 @@ import numpy as np from pytissueoptics.scene.solids import Solid + from .viewPoint import ViewPointStyle diff --git a/pytissueoptics/scene/viewer/displayable.py b/pytissueoptics/scene/viewer/displayable.py index 6d2544e2..7d9e82ed 100644 --- a/pytissueoptics/scene/viewer/displayable.py +++ b/pytissueoptics/scene/viewer/displayable.py @@ -1,6 +1,6 @@ from abc import abstractmethod -from pytissueoptics.scene.viewer.mayavi import MayaviViewer +from .provider import get3DViewer class Displayable: @@ -9,6 +9,6 @@ def addToViewer(self, viewer, **kwargs): pass def show(self, **kwargs): - viewer = MayaviViewer() + viewer = get3DViewer() self.addToViewer(viewer, **kwargs) viewer.show() diff --git a/pytissueoptics/scene/viewer/mayavi/__init__.py b/pytissueoptics/scene/viewer/mayavi/__init__.py index 03718e73..3177da28 100644 --- a/pytissueoptics/scene/viewer/mayavi/__init__.py +++ b/pytissueoptics/scene/viewer/mayavi/__init__.py @@ -1,11 +1,8 @@ from .mayaviSolid import MayaviObject, MayaviSolid from .mayaviTriangleMesh import MayaviTriangleMesh -from .mayaviViewer import MAYAVI_AVAILABLE, MayaviViewer __all__ = [ "MayaviObject", "MayaviSolid", "MayaviTriangleMesh", - "MayaviViewer", - "MAYAVI_AVAILABLE", ] diff --git a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py b/pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py similarity index 96% rename from pytissueoptics/scene/viewer/mayavi/mayaviViewer.py rename to pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py index 70742b3e..22da6b02 100644 --- a/pytissueoptics/scene/viewer/mayavi/mayaviViewer.py +++ b/pytissueoptics/scene/viewer/mayavi/mayavi3DViewer.py @@ -1,24 +1,16 @@ import numpy as np +from mayavi import mlab from pytissueoptics.scene.geometry import BoundingBox from pytissueoptics.scene.logger import Logger - -from ..abstract3DViewer import Abstract3DViewer -from ..viewPoint import ViewPointFactory, ViewPointStyle - -try: - from mayavi import mlab - - MAYAVI_AVAILABLE = True -except ImportError: - MAYAVI_AVAILABLE = False - from pytissueoptics.scene.solids import Solid +from pytissueoptics.scene.viewer import Abstract3DViewer, ViewPointStyle +from pytissueoptics.scene.viewer.viewPoint import ViewPointFactory from .mayaviSolid import MayaviSolid -class MayaviViewer(Abstract3DViewer): +class Mayavi3DViewer(Abstract3DViewer): def __init__(self): self._scenes = { "DefaultScene": { diff --git a/pytissueoptics/scene/viewer/null3DViewer.py b/pytissueoptics/scene/viewer/null3DViewer.py new file mode 100644 index 00000000..274ac000 --- /dev/null +++ b/pytissueoptics/scene/viewer/null3DViewer.py @@ -0,0 +1,56 @@ +import warnings + +import numpy as np +from scene.solids import Solid +from scene.viewer.abstract3DViewer import Abstract3DViewer + +from pytissueoptics import ViewPointStyle + + +class Null3DViewer(Abstract3DViewer): + def setViewPointStyle(self, viewPointStyle: ViewPointStyle): + pass + + def add( + self, + *solids: Solid, + representation="wireframe", + lineWidth=0.25, + showNormals=False, + normalLength=0.3, + colormap="viridis", + reverseColormap=False, + colorWithPosition=False, + opacity=1, + **kwargs, + ): + pass + + def addDataPoints( + self, + dataPoints: np.ndarray, + colormap="rainbow", + reverseColormap=False, + scale=0.15, + scaleWithValue=True, + asSpheres=True, + ): + pass + + def addImage( + self, + image: np.ndarray, + size: tuple = None, + minCorner: tuple = (0, 0), + axis: int = 2, + position: float = 0, + colormap: str = "viridis", + ): + pass + + @staticmethod + def showVolumeSlicer(hist3D: np.ndarray, colormap: str = "viridis", interpolate=False, **kwargs): + warnings.warn("Attempting to show a volume slicer with a Null3DViewer. No action will be taken.") + + def show(self): + warnings.warn("Attempting to show a Null3DViewer. No action will be taken.") diff --git a/pytissueoptics/scene/viewer/provider.py b/pytissueoptics/scene/viewer/provider.py new file mode 100644 index 00000000..80c651e1 --- /dev/null +++ b/pytissueoptics/scene/viewer/provider.py @@ -0,0 +1,30 @@ +import os +import warnings + +from .abstract3DViewer import Abstract3DViewer + +AVAILABLE_BACKENDS = ("mayavi", "null") + + +def get3DViewer() -> Abstract3DViewer: + backend = os.environ.get("PTO_3D_BACKEND", "mayavi").lower() + if backend == "mayavi": + try: + from .mayavi.mayavi3DViewer import Mayavi3DViewer + + return Mayavi3DViewer() + except Exception as e: + warnings.warn( + "Mayavi is not available. Falling back to a null 3D viewer. Fix the following error to use the Mayavi " + "backend or select another backend by setting the PTO_3D_BACKEND environment variable (available " + f"backends: {AVAILABLE_BACKENDS}). \n{e}" + ) + from .null3DViewer import Null3DViewer + + return Null3DViewer() + elif backend == "null": + from .null3DViewer import Null3DViewer + + return Null3DViewer() + else: + raise ValueError(f"Invalid backend '{backend}'. Available backends: {AVAILABLE_BACKENDS}") From 5f6a0b392ffb6fb616ca31713f2a27d5426cff3c Mon Sep 17 00:00:00 2001 From: JLBegin Date: Tue, 22 Apr 2025 22:05:00 -0400 Subject: [PATCH 3/3] fix imports --- pytissueoptics/rayscattering/tests/display/testViewer.py | 2 +- pytissueoptics/rayscattering/tests/testScatteringScene.py | 2 +- pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytissueoptics/rayscattering/tests/display/testViewer.py b/pytissueoptics/rayscattering/tests/display/testViewer.py index 4115411a..468c4531 100644 --- a/pytissueoptics/rayscattering/tests/display/testViewer.py +++ b/pytissueoptics/rayscattering/tests/display/testViewer.py @@ -3,7 +3,6 @@ import numpy as np from mockito import ANY, mock, verify, when -from scene.viewer import Abstract3DViewer from pytissueoptics import Direction, View2DProjectionX, ViewGroup from pytissueoptics.rayscattering.display.profiles import Profile1D, ProfileFactory @@ -14,6 +13,7 @@ from pytissueoptics.rayscattering.source import Source from pytissueoptics.scene.geometry import BoundingBox from pytissueoptics.scene.logger import Logger +from pytissueoptics.scene.viewer import Abstract3DViewer class TestViewer(unittest.TestCase): diff --git a/pytissueoptics/rayscattering/tests/testScatteringScene.py b/pytissueoptics/rayscattering/tests/testScatteringScene.py index b4bee288..e2a9a4e4 100644 --- a/pytissueoptics/rayscattering/tests/testScatteringScene.py +++ b/pytissueoptics/rayscattering/tests/testScatteringScene.py @@ -3,11 +3,11 @@ from unittest.mock import patch from mockito import mock, verify, when -from scene.viewer import Abstract3DViewer from pytissueoptics.rayscattering.materials import ScatteringMaterial from pytissueoptics.rayscattering.scatteringScene import ScatteringScene from pytissueoptics.scene.solids import Cuboid +from pytissueoptics.scene.viewer import Abstract3DViewer def patchMayaviShow(func): diff --git a/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py index c412cd67..908f524c 100644 --- a/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py +++ b/pytissueoptics/scene/tests/viewer/testMayavi3DViewer.py @@ -27,7 +27,7 @@ def patchMayaviShow(func): ) class TestMayavi3DViewer(unittest.TestCase): def setUp(self): - from scene.viewer.mayavi.mayavi3DViewer import Mayavi3DViewer + from pytissueoptics.scene.viewer.mayavi.mayavi3DViewer import Mayavi3DViewer self.viewer = Mayavi3DViewer()