diff --git a/src/oet/calculator/uma.py b/src/oet/calculator/uma.py index 9637c61..c326156 100755 --- a/src/oet/calculator/uma.py +++ b/src/oet/calculator/uma.py @@ -303,19 +303,23 @@ def calc( or not isinstance(cache_dir, str) ): raise RuntimeError("Problems handling input parameters.") - # Check if we have the respective models stored locally in cache - # If so, switch to offline mode - if self.check_for_model_files(basemodel=basemodel, cache_dir=cache_dir): - print("Model parameters found in cache. Switching to offline mode.") + # Check if the model files are available + model_files_available = self.check_for_model_files(basemodel=basemodel, cache_dir=cache_dir) + # If they are available, switch to offline mode. + if model_files_available: self.switch_to_offline_mode() - # If set by the user, switch also to offline mode. - # This is a fallback, if online communication must be prevented. - if offline_mode: + print("Model files detected. Switching to offline mode.") + # If they are not available, but the user requested the offline mode to prevent online communication, + # print a warning as this will likely cause subsequent errors. + elif offline_mode: self.switch_to_offline_mode() - if self.check_for_model_files(basemodel=basemodel, cache_dir=cache_dir): - print( - "WARNING: No model files were detected. This might lead to subsequent errors." + # Check if the model files are locally available. If not, subsequent errors will occur + # as they cannot be downloaded. + print( + "WARNING: Offline mode selected, but no model files were detected. " + "This will likely cause subsequent errors." ) + # setup calculator if not already set # this is important as usage on a server would otherwise cause # initialization with every call so that nothing is gained diff --git a/src/oet/server_client/server.py b/src/oet/server_client/server.py index 8d3c2d6..13e7b8d 100755 --- a/src/oet/server_client/server.py +++ b/src/oet/server_client/server.py @@ -31,6 +31,7 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence from concurrent.futures import BrokenExecutor, ProcessPoolExecutor +import multiprocessing as mp from contextlib import redirect_stdout from pathlib import Path from types import FrameType @@ -623,8 +624,12 @@ def main() -> None: # Create workers workers = args.nthreads + # Use spawn start method so CUDA can be initialized safely in worker processes + mp_ctx = mp.get_context("spawn") # Initialize the ProcessPool - executor = ProcessPoolExecutor(max_workers=workers, initializer=worker_initializer) + executor = ProcessPoolExecutor( + max_workers=workers, initializer=worker_initializer, mp_context=mp_ctx + ) # Make a CalculatorClass getting the hooks on calculators argument parsing # Info on calculator type is store in the object for client requests