diff --git a/HISTORY.rst b/HISTORY.rst index 31bd8bc9..10f97c3a 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -7,6 +7,9 @@ Unreleased Changes References to the old name and website domain have been updated to reflect this change. https://github.com/natcap/pygeoprocessing/issues/458 +* Handling the case where a floating-point raster passed to + ``pygeoprocessing.reclassify_raster`` may have a NaN nodata value. + https://github.com/natcap/pygeoprocessing/issues/454 * Updating pyproject.toml to use the standard ``license-files`` key and replacing the license-related Trove classifier with the approved SPDX string. https://github.com/natcap/pygeoprocessing/issues/466 diff --git a/src/pygeoprocessing/geoprocessing.py b/src/pygeoprocessing/geoprocessing.py index 5af4bef7..95132f34 100644 --- a/src/pygeoprocessing/geoprocessing.py +++ b/src/pygeoprocessing/geoprocessing.py @@ -37,7 +37,8 @@ from .geoprocessing_core import DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS from .geoprocessing_core import DEFAULT_OSR_AXIS_MAPPING_STRATEGY from .geoprocessing_core import INT8_CREATION_OPTIONS -from .utils import GDALUseExceptions, gdal_use_exceptions +from .utils import gdal_use_exceptions +from .utils import GDALUseExceptions # This is used to efficiently pass data to the raster stats worker if available if sys.version_info >= (3, 8): @@ -2408,7 +2409,7 @@ def reclassify_raster( nodata_dest_value = target_nodata if nodata is not None: for key, val in value_map.items(): - if numpy.isclose(key, nodata): + if numpy.isclose(key, nodata, equal_nan=True): nodata_dest_value = val del value_map_copy[key] break @@ -2435,7 +2436,7 @@ def _map_dataset_to_value_op(original_values): if nodata is None: valid_mask = numpy.full(original_values.shape, True) else: - valid_mask = ~numpy.isclose(original_values, nodata) + valid_mask = ~array_equals_nodata(original_values, nodata) out_array[~valid_mask] = nodata_dest_value if values_required: diff --git a/tests/test_geoprocessing.py b/tests/test_geoprocessing.py index 75a4522b..aee982f0 100644 --- a/tests/test_geoprocessing.py +++ b/tests/test_geoprocessing.py @@ -32,10 +32,10 @@ from pygeoprocessing.geoprocessing_core import DEFAULT_CREATION_OPTIONS from pygeoprocessing.geoprocessing_core import \ DEFAULT_GTIFF_CREATION_TUPLE_OPTIONS -from pygeoprocessing.utils import gdal_use_exceptions from pygeoprocessing.geoprocessing_core import INT8_CREATION_OPTIONS from pygeoprocessing.geoprocessing_core import \ INT8_GTIFF_CREATION_TUPLE_OPTIONS +from pygeoprocessing.utils import gdal_use_exceptions _DEFAULT_ORIGIN = (444720, 3751320) _DEFAULT_PIXEL_SIZE = (30, -30) @@ -429,6 +429,33 @@ def test_reclassify_raster_reclass_nodata_ambiguity(self): actual_message = str(cm.exception) self.assertIn(expected_message, actual_message) + def test_reclassify_raster_reclass_nan_nodata(self): + """PGP.geoprocessing: test reclassifying nan nodata value.""" + n_pixels = 9 + pixel_matrix = numpy.ones((n_pixels, n_pixels), numpy.float32) + test_value = 0.5 + pixel_matrix[:] = test_value + nodata = numpy.nan + pixel_matrix[0,0] = nodata + pixel_matrix[5,7] = nodata + raster_path = os.path.join(self.workspace_dir, 'raster.tif') + target_path = os.path.join(self.workspace_dir, 'target.tif') + _array_to_raster( + pixel_matrix, nodata, raster_path) + + value_map = { + test_value: 0, + nodata: 1, + } + target_nodata = -1 + pygeoprocessing.reclassify_raster( + (raster_path, 1), value_map, target_path, gdal.GDT_Float32, + target_nodata, values_required=True) + target_info = pygeoprocessing.get_raster_info(target_path) + target_array = pygeoprocessing.raster_to_numpy_array(target_path) + self.assertAlmostEqual(numpy.sum(target_array), 2) + self.assertAlmostEqual(target_info['nodata'][0], target_nodata) + def test_reproject_vector(self): """PGP.geoprocessing: test reproject vector.""" # Create polygon shapefile to reproject