Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

import abc
import io
from typing import Any, Iterator, Mapping, Self, Sequence, Union
from typing import Any, Iterator, Mapping, Sequence, Union
try:
from typing import Self # type: ignore
except ImportError:
from typing_extensions import Self

import numpy as np

Expand Down
32 changes: 12 additions & 20 deletions python/data_accessors/local_file_handlers/generic_dicom_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,23 +165,9 @@ def apply(self, image: np.ndarray) -> np.ndarray:
Returns:
Windowed image as numpy array.
"""
iinfo = np.iinfo(self._default_type)
# Actual range is center - half width to center + half width.
# Actual number of pixels is width + 1.
# See https://radiopaedia.org/articles/windowing-ct?lang=us
half_window_width = self.width // 2
center = self.center
top_clip = center + half_window_width
bottom_clip = center - half_window_width
# Round prior to cast to minimize precision loss.
return np.round(
np.interp(
image.clip(bottom_clip, top_clip),
(bottom_clip, top_clip),
(0, iinfo.max),
),
0,
).astype(iinfo)
return image_utils.window_accurate(
image, self.center, self.width, self._default_type
)


class RGBWindow(ImageTransform):
Expand Down Expand Up @@ -661,9 +647,15 @@ def _process_buffered_mri_volume(
return
output_dtype = np.uint8
max_dtype_value = np.iinfo(output_dtype).max
min_val = np.min(images)
max_delta = np.max(images) - min_val
for image in images:
min_val = min(np.min(img) for img in images)
max_val = max(np.max(img) for img in images)
max_delta = max_val - min_val

# Convert to list to allow popping if it's not already a mutable list.
# This helps free memory of processed slices.
images_list = list(images)
while images_list:
image = images_list.pop(0)
image = image.astype(np.float64)
if max_delta == 0:
image[...] = max_dtype_value
Expand Down
37 changes: 37 additions & 0 deletions python/data_processing/image_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,40 @@ def window(
(bottom_clip, top_clip),
(0, iinfo.max),
).astype(iinfo)


def window_accurate(
image: np.ndarray,
window_center: int,
window_width: int,
dtype: Any = np.uint8,
) -> np.ndarray:
"""Applies the Window operation accurately.

This implementation addresses the bugs in the legacy `window` function:
1) Corrects the window range to center +/- half width.
2) Rounds to the nearest integer before casting to minimize precision loss.

Args:
image: An image to be windowed, containing signed integer pixels.
window_center: The center of the window.
window_width: The width of the window.
dtype: Output data type (default: uint8).

Returns:
Windowed image as a numpy array.
"""
iinfo = np.iinfo(dtype)
half_window_width = window_width // 2
top_clip = window_center + half_window_width
bottom_clip = window_center - half_window_width

# Round prior to cast to minimize precision loss.
return np.round(
np.interp(
image.clip(bottom_clip, top_clip),
(bottom_clip, top_clip),
(0, iinfo.max),
),
0,
).astype(iinfo)
58 changes: 58 additions & 0 deletions python/data_processing/window_accurate_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
from data_processing import image_utils

class TestWindowAccurate(parameterized.TestCase):
"""Unit tests for `window_accurate()`."""

@parameterized.named_parameters(
('ExactCenter', 2048, 2048, 4096, 32768), # Midpoint (ceil of 32767.5)
('BottomEdge', 0, 2048, 4096, 0), # Bottom clip
('TopEdge', 4096, 2048, 4096, 65535), # Top clip
('BelowBottom', -100, 2048, 4096, 0), # Below bottom clip
('AboveTop', 5000, 2048, 4096, 65535), # Above top clip
)
def testStandardRange(self, input_value: int, center: int, width: int, expected: int):
"""Tests standard 12-bit range with uint16 output."""
actual = image_utils.window_accurate(
np.array([input_value], dtype=np.int16),
center,
width,
np.uint16
)
self.assertEqual(actual[0], expected)

@parameterized.named_parameters(
('BeforeLowest', 2045, 2048, 4, 0), # Center=2048, Width=4 -> [2046, 2050]
('AtLowest', 2046, 2048, 4, 0), # At bottom clip
('Midway', 2048, 2048, 4, 32768), # At center
('AtHighest', 2050, 2048, 4, 65535), # At top clip
('AfterHighest', 2051, 2048, 4, 65535), # After top clip
)
def testSmallWindow(self, input_value: int, center: int, width: int, expected: int):
"""Tests behavior with a very narrow window."""
actual = image_utils.window_accurate(
np.array([input_value], dtype=np.int16),
center,
width,
np.uint16
)
self.assertEqual(actual[0], expected)

def testRoundingCorrectness(self):
"""Specifically tests that rounding to nearest integer is working."""
# Window 100 to 200 (Center 150, Width 100)
# Norm to 0-255 (uint8)
image = np.array([125], dtype=np.int16) # Exactly 1/4 of the way: (125-100)/(200-100) * 255 = 0.25 * 255 = 63.75
# Round(63.75) should be 64.
actual = image_utils.window_accurate(image, 150, 100, np.uint8)
self.assertEqual(actual[0], 64)

image_low = np.array([124], dtype=np.int16) # (124-100)/100 * 255 = 0.24 * 255 = 61.2
# Round(61.2) should be 61.
actual_low = image_utils.window_accurate(image_low, 150, 100, np.uint8)
self.assertEqual(actual_low[0], 61)

if __name__ == '__main__':
absltest.main()
9 changes: 5 additions & 4 deletions python/serving/logging_lib/cloud_logging_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ def set_log_trace_key(key: str) -> None:
# will seg-fault (python queue wait). This can be avoided, by stoping and
# the background transport prior to forking and then restarting the transport
# following the fork.
os.register_at_fork(
before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access
after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access
)
if hasattr(os, 'register_at_fork'):
os.register_at_fork(
before=CloudLoggingClient._fork_shutdown, # pylint: disable=protected-access
after_in_child=CloudLoggingClient._init_fork_module_state, # pylint: disable=protected-access
)
9 changes: 5 additions & 4 deletions python/serving/logging_lib/cloud_logging_client_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,7 +829,8 @@ def log_error_level(self, level: int) -> None:
# will seg-fault (python queue wait). This can be avoided, by stopping the
# background transport prior to forking and then restarting the transport
# following the fork.
os.register_at_fork(
before=CloudLoggingClientInstance.fork_shutdown, # pylint: disable=protected-access
after_in_child=CloudLoggingClientInstance._init_fork_module_state, # pylint: disable=protected-access
)
if hasattr(os, 'register_at_fork'):
os.register_at_fork(
before=CloudLoggingClientInstance.fork_shutdown, # pylint: disable=protected-access
after_in_child=CloudLoggingClientInstance._init_fork_module_state, # pylint: disable=protected-access
)
3 changes: 2 additions & 1 deletion python/serving/logging_lib/flags/secret_flag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,5 @@ def get_bool_secret_or_env(
# state, e.g., acquired locks that will not release or references to invalid
# state.
_init_fork_module_state()
os.register_at_fork(after_in_child=_init_fork_module_state)
if hasattr(os, 'register_at_fork'):
os.register_at_fork(after_in_child=_init_fork_module_state)