diff --git a/src/reax/metrics/utils.py b/src/reax/metrics/utils.py index e91be0f..7249924 100644 --- a/src/reax/metrics/utils.py +++ b/src/reax/metrics/utils.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import TYPE_CHECKING, ClassVar, Optional, Protocol, TypeVar +from typing import TYPE_CHECKING, ClassVar, Protocol, TypeVar import beartype import clu.internal.utils @@ -20,7 +20,7 @@ __all__ = tuple() M = TypeVar("M", bound=Metric) -OptionalMask = Optional["reax.types.ArrayMask"] +OptionalMask = types.ArrayMask | None class ReduceFn(Protocol):