diff --git a/python/data_accessors/local_file_handlers/abstract_handler.py b/python/data_accessors/local_file_handlers/abstract_handler.py index b385a45..3b153ce 100644 --- a/python/data_accessors/local_file_handlers/abstract_handler.py +++ b/python/data_accessors/local_file_handlers/abstract_handler.py @@ -15,7 +15,11 @@ import abc import io -from typing import Any, Iterator, Mapping, Self, Sequence, Union +from typing import Any, Iterator, Mapping, Sequence, Union +try: + from typing import Self # type: ignore +except ImportError: + from typing_extensions import Self import numpy as np diff --git a/python/data_accessors/local_file_handlers/generic_dicom_handler.py b/python/data_accessors/local_file_handlers/generic_dicom_handler.py index 4e9b055..ae297f5 100644 --- a/python/data_accessors/local_file_handlers/generic_dicom_handler.py +++ b/python/data_accessors/local_file_handlers/generic_dicom_handler.py @@ -165,23 +165,9 @@ def apply(self, image: np.ndarray) -> np.ndarray: Returns: Windowed image as numpy array. """ - iinfo = np.iinfo(self._default_type) - # Actual range is center - half width to center + half width. - # Actual number of pixels is width + 1. - # See https://radiopaedia.org/articles/windowing-ct?lang=us - half_window_width = self.width // 2 - center = self.center - top_clip = center + half_window_width - bottom_clip = center - half_window_width - # Round prior to cast to minimize precision loss. - return np.round( - np.interp( - image.clip(bottom_clip, top_clip), - (bottom_clip, top_clip), - (0, iinfo.max), - ), - 0, - ).astype(iinfo) + return image_utils.window_accurate( + image, self.center, self.width, self._default_type + ) class RGBWindow(ImageTransform): @@ -661,9 +647,15 @@ def _process_buffered_mri_volume( return output_dtype = np.uint8 max_dtype_value = np.iinfo(output_dtype).max - min_val = np.min(images) - max_delta = np.max(images) - min_val - for image in images: + min_val = min(np.min(img) for img in images) + max_val = max(np.max(img) for img in images) + max_delta = max_val - min_val + + # Convert to list to allow popping if it's not already a mutable list. + # This helps free memory of processed slices. + images_list = list(images) + while images_list: + image = images_list.pop(0) image = image.astype(np.float64) if max_delta == 0: image[...] = max_dtype_value diff --git a/python/data_processing/image_utils.py b/python/data_processing/image_utils.py index 7792a33..c63c8ae 100644 --- a/python/data_processing/image_utils.py +++ b/python/data_processing/image_utils.py @@ -148,3 +148,40 @@ def window( (bottom_clip, top_clip), (0, iinfo.max), ).astype(iinfo) + + +def window_accurate( + image: np.ndarray, + window_center: int, + window_width: int, + dtype: Any = np.uint8, +) -> np.ndarray: + """Applies the Window operation accurately. + + This implementation addresses the bugs in the legacy `window` function: + 1) Corrects the window range to center +/- half width. + 2) Rounds to the nearest integer before casting to minimize precision loss. + + Args: + image: An image to be windowed, containing signed integer pixels. + window_center: The center of the window. + window_width: The width of the window. + dtype: Output data type (default: uint8). + + Returns: + Windowed image as a numpy array. + """ + iinfo = np.iinfo(dtype) + half_window_width = window_width // 2 + top_clip = window_center + half_window_width + bottom_clip = window_center - half_window_width + + # Round prior to cast to minimize precision loss. + return np.round( + np.interp( + image.clip(bottom_clip, top_clip), + (bottom_clip, top_clip), + (0, iinfo.max), + ), + 0, + ).astype(iinfo) diff --git a/python/data_processing/window_accurate_test.py b/python/data_processing/window_accurate_test.py new file mode 100644 index 0000000..06ac121 --- /dev/null +++ b/python/data_processing/window_accurate_test.py @@ -0,0 +1,58 @@ +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized +from data_processing import image_utils + +class TestWindowAccurate(parameterized.TestCase): + """Unit tests for `window_accurate()`.""" + + @parameterized.named_parameters( + ('ExactCenter', 2048, 2048, 4096, 32768), # Midpoint (ceil of 32767.5) + ('BottomEdge', 0, 2048, 4096, 0), # Bottom clip + ('TopEdge', 4096, 2048, 4096, 65535), # Top clip + ('BelowBottom', -100, 2048, 4096, 0), # Below bottom clip + ('AboveTop', 5000, 2048, 4096, 65535), # Above top clip + ) + def testStandardRange(self, input_value: int, center: int, width: int, expected: int): + """Tests standard 12-bit range with uint16 output.""" + actual = image_utils.window_accurate( + np.array([input_value], dtype=np.int16), + center, + width, + np.uint16 + ) + self.assertEqual(actual[0], expected) + + @parameterized.named_parameters( + ('BeforeLowest', 2045, 2048, 4, 0), # Center=2048, Width=4 -> [2046, 2050] + ('AtLowest', 2046, 2048, 4, 0), # At bottom clip + ('Midway', 2048, 2048, 4, 32768), # At center + ('AtHighest', 2050, 2048, 4, 65535), # At top clip + ('AfterHighest', 2051, 2048, 4, 65535), # After top clip + ) + def testSmallWindow(self, input_value: int, center: int, width: int, expected: int): + """Tests behavior with a very narrow window.""" + actual = image_utils.window_accurate( + np.array([input_value], dtype=np.int16), + center, + width, + np.uint16 + ) + self.assertEqual(actual[0], expected) + + def testRoundingCorrectness(self): + """Specifically tests that rounding to nearest integer is working.""" + # Window 100 to 200 (Center 150, Width 100) + # Norm to 0-255 (uint8) + image = np.array([125], dtype=np.int16) # Exactly 1/4 of the way: (125-100)/(200-100) * 255 = 0.25 * 255 = 63.75 + # Round(63.75) should be 64. + actual = image_utils.window_accurate(image, 150, 100, np.uint8) + self.assertEqual(actual[0], 64) + + image_low = np.array([124], dtype=np.int16) # (124-100)/100 * 255 = 0.24 * 255 = 61.2 + # Round(61.2) should be 61. + actual_low = image_utils.window_accurate(image_low, 150, 100, np.uint8) + self.assertEqual(actual_low[0], 61) + +if __name__ == '__main__': + absltest.main() diff --git a/python/serving/logging_lib/cloud_logging_client.py b/python/serving/logging_lib/cloud_logging_client.py index 1418513..33e0bec 100644 --- a/python/serving/logging_lib/cloud_logging_client.py +++ b/python/serving/logging_lib/cloud_logging_client.py @@ -386,7 +386,8 @@ def set_log_trace_key(key: str) -> None: # will seg-fault (python queue wait). This can be avoided, by stoping and # the background transport prior to forking and then restarting the transport # following the fork. -os.register_at_fork( - before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access - after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access -) +if hasattr(os, 'register_at_fork'): + os.register_at_fork( + before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access + after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access + ) diff --git a/python/serving/logging_lib/cloud_logging_client_instance.py b/python/serving/logging_lib/cloud_logging_client_instance.py index 133cfd2..38b4aac 100644 --- a/python/serving/logging_lib/cloud_logging_client_instance.py +++ b/python/serving/logging_lib/cloud_logging_client_instance.py @@ -829,7 +829,8 @@ def log_error_level(self, level: int) -> None: # will seg-fault (python queue wait). This can be avoided, by stopping the # background transport prior to forking and then restarting the transport # following the fork. -os.register_at_fork( - before=CloudLoggingClientInstance.fork_shutdown, # pylint: disable=protected-access - after_in_child=CloudLoggingClientInstance._init_fork_module_state, # pylint: disable=protected-access -) +if hasattr(os, 'register_at_fork'): + os.register_at_fork( + before=CloudLoggingClientInstance.fork_shutdown, # pylint: disable=protected-access + after_in_child=CloudLoggingClientInstance._init_fork_module_state, # pylint: disable=protected-access + ) diff --git a/python/serving/logging_lib/flags/secret_flag_utils.py b/python/serving/logging_lib/flags/secret_flag_utils.py index 801888f..b87a02e 100644 --- a/python/serving/logging_lib/flags/secret_flag_utils.py +++ b/python/serving/logging_lib/flags/secret_flag_utils.py @@ -216,4 +216,5 @@ def get_bool_secret_or_env( # state, e.g., acquired locks that will not release or references to invalid # state. _init_fork_module_state() -os.register_at_fork(after_in_child=_init_fork_module_state) +if hasattr(os, 'register_at_fork'): + os.register_at_fork(after_in_child=_init_fork_module_state)