Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion labtech/lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import Counter, defaultdict
from enum import StrEnum
from pathlib import Path
from time import monotonic
from typing import TYPE_CHECKING

from tqdm.contrib.logging import logging_redirect_tqdm
Expand Down Expand Up @@ -223,7 +224,11 @@ def run(self, tasks: Sequence[Task]) -> dict[Task, Any]:
)
task_monitor.show()

last_monitor_update = monotonic()

def process_completed_tasks():
nonlocal last_monitor_update

# Wait up to a short delay before allowing the
# task monitor to update.
for task, res in runner.wait(timeout_seconds=0.5):
Expand All @@ -241,8 +246,10 @@ def process_completed_tasks():

runner.remove_results(tasks_with_removable_results)

if task_monitor is not None:
# Update task monitor at most every half second.
if task_monitor is not None and ((monotonic() - last_monitor_update) >= 0.5):
task_monitor.update()
last_monitor_update = monotonic()

redirected_loggers = [] if self.lab.notebook else [logger]
with logging_redirect_tqdm(loggers=redirected_loggers):
Expand Down
8 changes: 7 additions & 1 deletion labtech/runners/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dataclasses import dataclass
from logging.handlers import QueueHandler
from queue import Empty
from time import monotonic
from typing import TYPE_CHECKING, Generic, TypeVar, cast
from uuid import uuid4

Expand Down Expand Up @@ -152,6 +153,7 @@ class ProcessRunner(Runner, Generic[FutureT], ABC):
"""Runner based on Python multiprocessing."""

def __init__(self) -> None:
self.last_consume_log = monotonic()
self.log_queue = multiprocessing.Manager().Queue(-1)
self.task_event_queue = multiprocessing.Manager().Queue(-1)
self.process_monitor = ProcessMonitor(task_event_queue = self.task_event_queue)
Expand All @@ -178,7 +180,11 @@ def submit_task(self, task: Task, task_name: str, use_cache: bool) -> None:
self.future_to_task[future] = task

def wait(self, *, timeout_seconds: float | None) -> Iterator[tuple[Task, ResultMeta | BaseException]]:
self._consume_log_queue()
# Consume logs at most every half second.
if (monotonic() - self.last_consume_log) >= 0.5:
self._consume_log_queue()
self.last_consume_log = monotonic()

done = self._get_completed_futures(
futures=list(self.future_to_task.keys()),
timeout_seconds=timeout_seconds,
Expand Down