diff --git a/pytools/persistent_dict.py b/pytools/persistent_dict.py
index cafd34b3..5641d125 100644
--- a/pytools/persistent_dict.py
+++ b/pytools/persistent_dict.py
@@ -291,6 +291,14 @@ def update_for_bool(self, key_hash: Hash, key: bool) -> None:
         key_hash.update(str(key).encode("utf8"))
 
     def update_for_float(self, key_hash: Hash, key: float) -> None:
+        import math
+        if math.isnan(key):
+            # Also applies to np.nan, float("nan")
+            warn("Encountered a NaN while hashing. Since NaNs compare unequal "
+                 "to themselves, the resulting key can not be retrieved from a "
+                 "PersistentDict and will lead to a collision error on retrieval.",
+                 stacklevel=1)
+
         key_hash.update(key.hex().encode("utf8"))
 
     def update_for_complex(self, key_hash: Hash, key: float) -> None:
@@ -661,6 +669,12 @@ class WriteOncePersistentDict(_PersistentDictBase[K, V]):
     and if they occur, almost always due to a bug in the hash key generation
     code (:class:`KeyBuilder`).
 
+    .. warning::
+
+        Since NaNs compare unequal to themselves, keys that include NaNs can
+        not be retrieved from a :class:`WriteOncePersistentDict` and will lead to a
+        :exc:`NoSuchEntryCollisionError` on retrieval.
+
     .. automethod:: __init__
     .. automethod:: __getitem__
     .. automethod:: __setitem__
@@ -774,6 +788,12 @@ class PersistentDict(_PersistentDictBase[K, V]):
     and if they occur, almost always due to a bug in the hash key generation
     code (:class:`KeyBuilder`).
 
+    .. warning::
+
+        Since NaNs compare unequal to themselves, keys that include NaNs can
+        not be retrieved from a :class:`PersistentDict` and will lead to a
+        :exc:`NoSuchEntryCollisionError` on retrieval.
+
     .. automethod:: __init__
     .. automethod:: __getitem__
     .. automethod:: __setitem__
diff --git a/pytools/test/test_persistent_dict.py b/pytools/test/test_persistent_dict.py
index cd805062..33cfb33c 100644
--- a/pytools/test/test_persistent_dict.py
+++ b/pytools/test/test_persistent_dict.py
@@ -1048,6 +1048,40 @@ def test_concurrency_threads() -> None:
 # }}}
 
 
+def test_nan_keys() -> None:
+    # test for https://github.com/inducer/pytools/issues/287
+
+    try:
+        tmpdir = tempfile.mkdtemp()
+        keyb = KeyBuilder()
+        pdict: PersistentDict[float, int] = PersistentDict("pytools-test",
+                                            container_dir=tmpdir,
+                                            safe_sync=False,
+                                            key_builder=keyb)
+
+        import math
+
+        nan_values = [math.nan, float("nan")]
+
+        try:
+            import numpy as np
+            nan_values.append(np.nan)
+        except ImportError:
+            pass
+
+        for nan_value in nan_values:
+            assert nan_value != nan_value
+            assert keyb(nan_value) == keyb(nan_value)
+
+            pdict[nan_value] = 42
+
+            with (pytest.warns(CollisionWarning),
+                  pytest.raises(NoSuchEntryCollisionError)):
+                pdict[math.nan]
+    finally:
+        shutil.rmtree(tmpdir)
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])