diff --git a/labtech/runners/process.py b/labtech/runners/process.py index 2556a7f..7c461e4 100644 --- a/labtech/runners/process.py +++ b/labtech/runners/process.py @@ -315,11 +315,12 @@ class ProcessRunner(Runner, ABC): """Base class for Runner's based on Python multiprocessing.""" def __init__(self, *, context: LabContext, storage: Storage, max_workers: int | None): - self.process_event_queue = multiprocessing.Manager().Queue(-1) + mp_context = self._get_mp_context() + self.process_event_queue = mp_context.Manager().Queue(-1) self.process_monitor = ProcessMonitor(process_event_queue = self.process_event_queue) self.log_queue = multiprocessing.Manager().Queue(-1) self.executor = ProcessExecutor( - mp_context=self._get_mp_context(), + mp_context=mp_context, max_workers=max_workers, ) @@ -346,6 +347,8 @@ def _subprocess_func(*, task: Task, task_name: str, use_cache: bool, # in serial by the main process. logger.handlers = [] logger.addHandler(QueueHandler(log_queue)) + orig_stdout = sys.stdout + orig_stderr = sys.stderr # Ignore type errors for type of value used to override stdout and stderr sys.stdout = LoggerFileProxy(logger.info, 'Captured STDOUT:\n') # type: ignore[assignment] sys.stderr = LoggerFileProxy(logger.error, 'Captured STDERR:\n') # type: ignore[assignment] @@ -376,6 +379,10 @@ def _subprocess_func(*, task: Task, task_name: str, use_cache: bool, process_event_queue.put(ProcessEndEvent( task_name=task_name, )) + sys.stdout.flush() + sys.stderr.flush() + sys.stdout = orig_stdout + sys.stderr = orig_stderr def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None: future = self._submit_task( @@ -415,7 +422,7 @@ def stop(self) -> None: self.executor.stop() def close(self) -> None: - pass + self._consume_log_queue() def pending_task_count(self) -> int: return len(self.future_to_task) @@ -573,6 +580,7 @@ def _submit_task(self, executor: ProcessExecutor, task: Task, task_name: str, ) def close(self) -> None: + super().close() try: del _RUNNER_FORK_MEMORY[self.uuid] except KeyError: diff --git a/labtech/utils.py b/labtech/utils.py index 7676d16..cf7cb56 100644 --- a/labtech/utils.py +++ b/labtech/utils.py @@ -115,6 +115,7 @@ def write(self, buf): def flush(self): if self.bufs: self.logger_func('\n'.join([f'{self.prefix}{buf}' for buf in self.bufs])) + self.bufs = [] def ensure_dict_key_str(value, *, exception_type: type[Exception]) -> str: