diff --git a/src/PartSeg_smfish/__init__.py b/src/PartSeg_smfish/__init__.py index 44e7504..74e43ad 100644 --- a/src/PartSeg_smfish/__init__.py +++ b/src/PartSeg_smfish/__init__.py @@ -22,6 +22,10 @@ def register(): segmentation.LayerRangeThresholdFlow, RegisterEnum.roi_mask_segmentation_algorithm, ) + register_fun( + segmentation.ThresholdFlowAlgorithmWithDilation, + RegisterEnum.roi_mask_segmentation_algorithm, + ) register_fun(measurement.ComponentType, RegisterEnum.analysis_measurement) if getattr(sys, "frozen", False): diff --git a/src/PartSeg_smfish/segmentation.py b/src/PartSeg_smfish/segmentation.py index 0e559c1..3e26fcd 100644 --- a/src/PartSeg_smfish/segmentation.py +++ b/src/PartSeg_smfish/segmentation.py @@ -30,6 +30,8 @@ from PartSegCore.segmentation.segmentation_algorithm import ( CellFromNucleusFlow, StackAlgorithm, + ThresholdFlowAlgorithm, + ThresholdFlowAlgorithmParameters, ) from PartSegCore.segmentation.threshold import ( BaseThreshold, @@ -626,3 +628,56 @@ def maximum_projection( }, ) ) + + +class ThresholdFlowAlgorithmParametersWithDilation( + ThresholdFlowAlgorithmParameters +): + dilation_radius: int = Field( + 0, + title="Dilation radius", + description="Radius of dilation applied to the mask", + ge=0, + le=100, + ) + + +class ThresholdFlowAlgorithmWithDilation(ThresholdFlowAlgorithm): + new_parameters: ThresholdFlowAlgorithmParametersWithDilation + __argument_class__ = ThresholdFlowAlgorithmParametersWithDilation + + @classmethod + def get_name(cls) -> str: + return "Threshold Flow with Dilation" + + @classmethod + def get_steps_num(cls) -> int: + return ThresholdFlowAlgorithm.get_steps_num() + 1 + + def calculation_run( + self, report_fun: Callable[[str, int], None] + ) -> ROIExtractionResult: + res = super().calculation_run(report_fun) + if self.new_parameters.dilation_radius == 0: + return res + + report_fun("Dilate ROI", self.get_steps_num() - 1) + + from PartSegCore.image_operations import dilate + + rad = self.new_parameters.dilation_radius + roi = dilate(res.roi, [rad, rad], True) + + roi[res.roi > 0] = res.roi[res.roi > 0] + + res2 = ROIExtractionResult( + roi=roi, + parameters=self.get_segmentation_profile(), + additional_layers=deepcopy(res.additional_layers), + ) + + res2.additional_layers["dilated roi"] = AdditionalLayerDescription( + res.roi, "labels", "Dilated ROI" + ) + report_fun("Calculation done", self.get_steps_num()) + return res2