From 4852761250ad3bb4d9f2447f2ac7490f4d72024a Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Wed, 14 Jan 2026 22:26:18 +0000 Subject: [PATCH 1/2] testing Pytest abort plugin --- conftest.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/conftest.py b/conftest.py index 72b4b598891c..617add8273b8 100644 --- a/conftest.py +++ b/conftest.py @@ -76,6 +76,9 @@ def pytest_collection() -> None: "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) +_USE_ROCM_ABORT_DETECTOR_PLUGIN = bool(os.environ.get("JAX_ROCM_LAST_RUNNING_FILE")) + + class ThreadSafeTestLogger: """Thread-safe logging for parallel test execution and abort detection""" def __init__(self): @@ -171,8 +174,7 @@ def clear_running_test(self, test_file): os.remove(log_file) -# Global logger instance -test_logger = ThreadSafeTestLogger() +test_logger = ThreadSafeTestLogger() if not _USE_ROCM_ABORT_DETECTOR_PLUGIN else None @pytest.hookimpl(hookwrapper=True) @@ -183,6 +185,10 @@ def pytest_runtest_protocol(item, nextitem): when the test completes successfully. If the test crashes, the file remains and can be detected by the test runner. """ + if _USE_ROCM_ABORT_DETECTOR_PLUGIN or test_logger is None: + outcome = yield + return outcome + test_file = test_logger.get_test_file_name(item.session) test_name = item.name nodeid = item.nodeid @@ -230,6 +236,9 @@ def pytest_sessionfinish(session, exitstatus): If a crash file still exists, it means a test crashed and the runner will detect it. We just report it here for visibility. """ + if _USE_ROCM_ABORT_DETECTOR_PLUGIN or test_logger is None: + return + test_file = test_logger.get_test_file_name(session) log_file = f"{test_logger.base_dir}/{test_file}_last_running.json" From 224b67b4783b714e803222107ec99b5d3a89a41f Mon Sep 17 00:00:00 2001 From: gulsumgudukbay Date: Thu, 15 Jan 2026 21:13:36 +0000 Subject: [PATCH 2/2] update conftest --- conftest.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/conftest.py b/conftest.py index 617add8273b8..0e587e7d314e 100644 --- a/conftest.py +++ b/conftest.py @@ -76,7 +76,10 @@ def pytest_collection() -> None: "CUDA_VISIBLE_DEVICES", str(xdist_worker_number % num_cuda_devices) ) -_USE_ROCM_ABORT_DETECTOR_PLUGIN = bool(os.environ.get("JAX_ROCM_LAST_RUNNING_FILE")) +_USE_ROCM_ABORT_DETECTOR_PLUGIN = bool( + os.environ.get("PYTEST_ABORT_LAST_RUNNING_FILE") + or os.environ.get("PYTEST_ABORT_LAST_RUNNING_DIR") +) class ThreadSafeTestLogger: