diff --git a/linx/reactions.py b/linx/reactions.py index ce2102e..f3f29cd 100644 --- a/linx/reactions.py +++ b/linx/reactions.py @@ -149,7 +149,8 @@ def __init__( self.expsigma_vec = jax.device_put( self.expsigma_vec, device=gpus[0] ) - except: + except (RuntimeError, IndexError): + # No GPU available or no GPU devices found - data stays on CPU pass elif frwrd_rate_param_func is not None: diff --git a/linx/thermo.py b/linx/thermo.py index 85cff9d..161c888 100644 --- a/linx/thermo.py +++ b/linx/thermo.py @@ -671,7 +671,8 @@ def p_massive_MB(T, mu, m, g): f_numu_ann_tab = device_put( f_numu_ann_tab, device=gpus[0] ) -except: +except (RuntimeError, IndexError): + # No GPU available or no GPU devices found - data stays on CPU pass diff --git a/linx/weak_rates.py b/linx/weak_rates.py index 71940c7..e187ce1 100644 --- a/linx/weak_rates.py +++ b/linx/weak_rates.py @@ -102,7 +102,8 @@ def __init__(self, self.L_nTOpCCRTh_res = jax.device_put( self.L_nTOpCCRTh_res, device=gpus[0] ) - except: + except (RuntimeError, IndexError): + # No GPU available or no GPU devices found - data stays on CPU pass self.T_pTOn_thermal_interval, self.L_pTOnCCRTh_res = np.loadtxt( @@ -119,7 +120,8 @@ def __init__(self, self.L_pTOnCCRTh_res = jax.device_put( self.L_pTOnCCRTh_res, device=gpus[0] ) - except: + except (RuntimeError, IndexError): + # No GPU available or no GPU devices found - data stays on CPU pass else: