Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions labtech/runners/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,11 +315,12 @@ class ProcessRunner(Runner, ABC):
"""Base class for Runner's based on Python multiprocessing."""

def __init__(self, *, context: LabContext, storage: Storage, max_workers: int | None):
self.process_event_queue = multiprocessing.Manager().Queue(-1)
mp_context = self._get_mp_context()
self.process_event_queue = mp_context.Manager().Queue(-1)
self.process_monitor = ProcessMonitor(process_event_queue = self.process_event_queue)
self.log_queue = multiprocessing.Manager().Queue(-1)
self.executor = ProcessExecutor(
mp_context=self._get_mp_context(),
mp_context=mp_context,
max_workers=max_workers,
)

Expand All @@ -346,6 +347,8 @@ def _subprocess_func(*, task: Task, task_name: str, use_cache: bool,
# in serial by the main process.
logger.handlers = []
logger.addHandler(QueueHandler(log_queue))
orig_stdout = sys.stdout
orig_stderr = sys.stderr
# Ignore type errors for type of value used to override stdout and stderr
sys.stdout = LoggerFileProxy(logger.info, 'Captured STDOUT:\n') # type: ignore[assignment]
sys.stderr = LoggerFileProxy(logger.error, 'Captured STDERR:\n') # type: ignore[assignment]
Expand Down Expand Up @@ -376,6 +379,10 @@ def _subprocess_func(*, task: Task, task_name: str, use_cache: bool,
process_event_queue.put(ProcessEndEvent(
task_name=task_name,
))
sys.stdout.flush()
sys.stderr.flush()
sys.stdout = orig_stdout
sys.stderr = orig_stderr

def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None:
future = self._submit_task(
Expand Down Expand Up @@ -415,7 +422,7 @@ def stop(self) -> None:
self.executor.stop()

def close(self) -> None:
pass
self._consume_log_queue()

def pending_task_count(self) -> int:
return len(self.future_to_task)
Expand Down Expand Up @@ -573,6 +580,7 @@ def _submit_task(self, executor: ProcessExecutor, task: Task, task_name: str,
)

def close(self) -> None:
super().close()
try:
del _RUNNER_FORK_MEMORY[self.uuid]
except KeyError:
Expand Down
1 change: 1 addition & 0 deletions labtech/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def write(self, buf):
def flush(self):
if self.bufs:
self.logger_func('\n'.join([f'{self.prefix}{buf}' for buf in self.bufs]))
self.bufs = []


def ensure_dict_key_str(value, *, exception_type: type[Exception]) -> str:
Expand Down