From 35dfcfd181987c7d424afc232809b9b2918ec170 Mon Sep 17 00:00:00 2001 From: Alexey Shevtsov Date: Tue, 19 Oct 2021 18:36:31 +0300 Subject: [PATCH] added possibility to combine patches from the lower levels of decoder --- dpipe/predict/shape.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/dpipe/predict/shape.py b/dpipe/predict/shape.py index b0e89a9..b5d62f4 100644 --- a/dpipe/predict/shape.py +++ b/dpipe/predict/shape.py @@ -80,7 +80,7 @@ def wrapper(x, *args, **kwargs): def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None, - padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5, + padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5, scale_factor: int = 1, combiner: Type[PatchCombiner] = Average): """ Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine @@ -108,10 +108,11 @@ def wrapper(x, *args, **kwargs): x = pad_to_shape(x, new_shape, input_axis, padding_values, ratio) patches = pmap(predict, divide(x, local_size, local_stride, input_axis), *args, **kwargs) - prediction = combine(patches, extract(x.shape, input_axis), local_stride, axis, combiner=combiner) + prediction = combine(patches, np.divide(extract(x.shape, input_axis), scale_factor), + local_stride / scale_factor, axis, combiner=combiner) if valid: - prediction = crop_to_shape(prediction, shape, axis, ratio) + prediction = crop_to_shape(prediction, np.divide(shape, scale_factor), axis, ratio) return prediction return wrapper