From 6d636968d065cde519135518a0c32e2a8692db0c Mon Sep 17 00:00:00 2001 From: Vcholerae1 <165551159+Vcholerae1@users.noreply.github.com> Date: Thu, 4 Dec 2025 16:49:37 +0800 Subject: [PATCH] Update the Python backend type annotation to Literal["eager", "jit", "compile"] so that mismatches are caught during static type checking. --- src/deepwave/elastic.py | 8 ++++---- src/deepwave/scalar.py | 8 ++++---- src/deepwave/scalar_born.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/deepwave/elastic.py b/src/deepwave/elastic.py index 417be74..61ecce0 100644 --- a/src/deepwave/elastic.py +++ b/src/deepwave/elastic.py @@ -11,7 +11,7 @@ material parameters (Lam'e parameters and buoyancy) and source amplitudes. """ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Literal, Optional, Sequence, Tuple, Union, cast import torch @@ -241,7 +241,7 @@ def forward( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, ) -> Tuple[torch.Tensor, ...]: """Perform forward propagation/modelling. @@ -371,7 +371,7 @@ def elastic( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False ) -> Tuple[torch.Tensor, ...]: """Elastic wave propagation (functional interface). @@ -2518,7 +2518,7 @@ def elastic_python( def elastic_func( - python_backend: Union[bool, str], *args: Any + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, *args: Any ) -> Tuple[torch.Tensor, ...]: """A helper function to apply the ElasticForwardFunc. diff --git a/src/deepwave/scalar.py b/src/deepwave/scalar.py index c620500..87a7dd5 100644 --- a/src/deepwave/scalar.py +++ b/src/deepwave/scalar.py @@ -18,7 +18,7 @@ All outputs are differentiable with respect to float torch.Tensor inputs. """ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Literal, Optional, Sequence, Tuple, Union, cast import torch @@ -89,7 +89,7 @@ def forward( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, ) -> List[torch.Tensor]: """Performs forward propagation/modelling. @@ -157,7 +157,7 @@ def scalar( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False ) -> List[torch.Tensor]: """Scalar wave propagation (functional interface). @@ -1890,7 +1890,7 @@ def scalar_python( def scalar_func( - python_backend: Union[bool, str], *args: Any + python_backend: Union[Literal["eager", "jit", "compile"], bool], *args: Any ) -> Tuple[torch.Tensor, ...]: """Helper function to apply the ScalarForwardFunc. diff --git a/src/deepwave/scalar_born.py b/src/deepwave/scalar_born.py index 235579b..9fec59c 100644 --- a/src/deepwave/scalar_born.py +++ b/src/deepwave/scalar_born.py @@ -6,7 +6,7 @@ wavefield that uses 2 / v * scatter * dt^2 * wavefield as the source term. """ -from typing import Any, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, List, Literal, Optional, Sequence, Tuple, Union, cast import torch @@ -125,7 +125,7 @@ def forward( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, ) -> Tuple[torch.Tensor, ...]: """Perform forward propagation/modelling. @@ -215,7 +215,7 @@ def scalar_born( forward_callback: Optional[deepwave.common.Callback] = None, backward_callback: Optional[deepwave.common.Callback] = None, callback_frequency: int = 1, - python_backend: Union[bool, str] = False, + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, ) -> Tuple[torch.Tensor, ...]: """Scalar Born wave propagation (functional interface). @@ -1561,7 +1561,7 @@ def scalar_born_python( def scalar_born_func( - python_backend: Union[bool, str], + python_backend: Union[Literal["eager", "jit", "compile"], bool] = False, *args: Any, ) -> Tuple[torch.Tensor, ...]: """Helper function to apply the ScalarBornForwardFunc.