diff --git a/src/subscript/restartthinner/restartthinner.py b/src/subscript/restartthinner/restartthinner.py index 2de4fb7a6..5c4a5093f 100644 --- a/src/subscript/restartthinner/restartthinner.py +++ b/src/subscript/restartthinner/restartthinner.py @@ -2,18 +2,22 @@ import argparse import datetime -import glob +import logging import os import shutil -import sys +import subprocess import tempfile +from collections.abc import Iterator +from contextlib import contextmanager from pathlib import Path -import numpy -import pandas +import numpy as np +import pandas as pd from resdata.resfile import ResdataFile -from subscript import __version__ +from subscript import __version__, getLogger + +logger = getLogger(__name__) DESCRIPTION = """ Slice a subset of restart-dates from an E100 Restart file (UNRST) @@ -28,97 +32,110 @@ def find_resdata_app(toolname: str) -> str: - """Locate path of apps in resdata. - - These have varying suffixes due through the history of resdata Makefiles. + """Locate path of resdata apps, trying common suffixes (.x, .c.x, .cpp.x). - Depending on resdata-version, it has the .x or the .c.x suffix - We prefer .x. + Args: + toolname: Base name of the tool (e.g., 'rd_unpack') Returns: - String with path if found. + Full path to the executable. Raises: - IOError: if tool can't be found + OSError: If tool cannot be found in PATH. """ - extensions = [".x", ".c.x", ".cpp.x", ""] # Order matters. - candidates = [toolname + extension for extension in extensions] - for candidate in candidates: - for path in os.environ["PATH"].split(os.pathsep): - candidatepath = Path(path) / candidate - if candidatepath.exists(): - return str(candidatepath) - raise OSError(toolname + " not found in path, PATH=" + str(os.environ["PATH"])) - - -def date_slicer(slicedates: list, restartdates: list, restartindices: list) -> dict: - """Make a dict that maps a chosen restart date to a report index""" - slicedatemap = {} + for ext in [".x", ".c.x", ".cpp.x", ""]: # Order matters. + if path := shutil.which(toolname + ext): + return path + raise OSError(f"{toolname} not found in PATH") + + +def date_slicer( + slicedates: list[pd.Timestamp], + restartdates: list[datetime.datetime], + restartindices: list[int], +) -> list[int]: + """Make a list of report indices that match the input slicedates.""" + slicedatelist = [] for slicedate in slicedates: - daydistances = [ - abs((pandas.Timestamp(slicedate) - x).days) for x in restartdates - ] - slicedatemap[slicedate] = restartindices[daydistances.index(min(daydistances))] - return slicedatemap + daydistances = [abs((pd.Timestamp(slicedate) - x).days) for x in restartdates] + slicedatelist.append(restartindices[daydistances.index(min(daydistances))]) + return slicedatelist -def rd_repacker(rstfilename: str, slicerstindices: list, quiet: bool) -> None: - """ - Wrapper for ecl_unpack.x and ecl_pack.x utilities. These - utilities are from resdata. +@contextmanager +def _working_directory(path: Path) -> Iterator[None]: + original_cwd = Path.cwd() + try: + os.chdir(path) + yield + finally: + os.chdir(original_cwd) - First unpacking a UNRST file, then deleting dates the dont't want, then - pack the remainding files into a new UNRST file - This function will change working directory to the - location of the UNRST file, dump temporary files in there, and - modify the original filename. - """ - out = " >/dev/null" if quiet else "" - # Error early if resdata tools are not available - try: - find_resdata_app("rd_unpack") - find_resdata_app("rd_pack") - except OSError: - sys.exit( - "ERROR: rd_unpack.x and/or rd_pack.x not found.\n" - "These tools are required and must be installed separately" - ) - - # Take special care if the UNRST file we get in is not in current directory - cwd = os.getcwd() - rstfilepath = Path(rstfilename).parent - tempdir = None +def rd_repacker(rstfilename: str, slicerstindices: list[int], quiet: bool) -> None: + """Repack a UNRST file keeping only selected restart indices. - try: - os.chdir(Path(rstfilename).parent) - tempdir = tempfile.mkdtemp(dir=".") - os.rename( - os.path.basename(rstfilename), - os.path.join(tempdir, os.path.basename(rstfilename)), - ) - os.chdir(tempdir) - os.system( - find_resdata_app("rd_unpack") + " " + os.path.basename(rstfilename) + out - ) - unpackedfiles = glob.glob("*.X*") - for file in unpackedfiles: - if int(file.split(".X")[1]) not in slicerstindices: - os.remove(file) - os.system(find_resdata_app("rd_pack") + " *.X*" + out) - # We are inside the tmp directory, move file one step up: - os.rename( - os.path.join(os.getcwd(), os.path.basename(rstfilename)), - os.path.join(os.getcwd(), "../", os.path.basename(rstfilename)), - ) - finally: - os.chdir(cwd) - if tempdir is not None: - shutil.rmtree(rstfilepath / tempdir) + Uses rd_unpack and rd_pack utilities from resdata to unpack the UNRST file, + remove unwanted dates, and repack into a new UNRST file. + Args: + rstfilename: Path to the UNRST file. + slicerstindices: List of restart indices to keep. + quiet: If True, suppress subprocess output. + + Raises: + OSError: If rd_unpack or rd_pack tools are not found. + """ + rd_unpack = find_resdata_app("rd_unpack") + rd_pack = find_resdata_app("rd_pack") + + rstpath = Path(rstfilename) + rstdir = rstpath.parent or Path(".") + rstname = rstpath.name + + with _working_directory(rstdir): + tempdir = Path(tempfile.mkdtemp(dir=".")) + try: + # Move UNRST into temp directory and work there + shutil.move(rstname, tempdir / rstname) + + with _working_directory(tempdir): + subprocess.run( + [rd_unpack, rstname], + capture_output=quiet, + check=True, + ) + + for file in Path(".").glob("*.X*"): + index = int(file.suffix.lstrip(".X")) + if index not in slicerstindices: + file.unlink() + + remaining_files = sorted(Path(".").glob("*.X*")) + subprocess.run( + [rd_pack, *[str(f) for f in remaining_files]], + capture_output=quiet, + check=True, + ) + + # Move result back up + shutil.move(rstname, f"../{rstname}") + finally: + shutil.rmtree(tempdir) + + +def get_restart_indices(rstfilename: str) -> list[int]: + """Extract a list of restart indices for a filename. + + Args: + rstfilename: Path to the UNRST file. -def get_restart_indices(rstfilename: str) -> list: - """Extract a list of RST indices for a filename""" + Returns: + List of restart report indices. + + Raises: + FileNotFoundError: If the file does not exist. + """ if Path(rstfilename).exists(): # This function segfaults if file does not exist return ResdataFile.file_report_list(str(rstfilename)) @@ -132,8 +149,14 @@ def restartthinner( dryrun: bool = True, keep: bool = False, ) -> None: - """ - Thin an existing UNRST file to selected number of restarts. + """Thin an existing UNRST file to selected number of restarts. + + Args: + filename: Path to the UNRST file. + numberofslices: Number of restart dates to keep. + quiet: If True, suppress informational output. + dryrun: If True, only show what would be done without modifying files. + keep: If True, keep original file with .orig suffix. """ rst = ResdataFile(filename) restart_indices = get_restart_indices(filename) @@ -142,41 +165,39 @@ def restartthinner( ] if numberofslices > 1: - slicedates = pandas.DatetimeIndex( - numpy.linspace( - pandas.Timestamp(restart_dates[0]).value, - pandas.Timestamp(restart_dates[-1]).value, + slicedates = pd.DatetimeIndex( + np.linspace( + pd.Timestamp(restart_dates[0]).value, + pd.Timestamp(restart_dates[-1]).value, int(numberofslices), ) ).to_list() else: slicedates = [restart_dates[-1]] # Only return last date if only one is wanted - slicerstindices = list( - date_slicer(slicedates, restart_dates, restart_indices).values() - ) - slicerstindices.sort() - slicerstindices = list(set(slicerstindices)) # uniquify + slicerstindices = date_slicer(slicedates, restart_dates, restart_indices) + slicerstindices = sorted(set(slicerstindices)) # uniquify if not quiet: - print("Selected restarts:") - print("-----------------------") + logger.info("Selected restarts:") + logger.info("-----------------------") for idx, rstidx in enumerate(restart_indices): slicepresent = "X" if rstidx in slicerstindices else "" - print( - f"{rstidx:4d} " - f"{datetime.date.strftime(restart_dates[idx], '%Y-%m-%d')} " - f"{slicepresent}" + logger.info( + "%4d %s %s", + rstidx, + datetime.date.strftime(restart_dates[idx], "%Y-%m-%d"), + slicepresent, ) - print("-----------------------") + logger.info("-----------------------") + if not dryrun: if keep: backupname = filename + ".orig" - if not quiet: - print(f"Info: Backing up {filename} to {backupname}") + logger.info("Backing up %s to %s", filename, backupname) shutil.copyfile(filename, backupname) rd_repacker(filename, slicerstindices, quiet) - print(f"Written to {filename}") + logger.info("Written to %s", filename) def get_parser() -> argparse.ArgumentParser: @@ -186,7 +207,11 @@ def get_parser() -> argparse.ArgumentParser: ) parser.add_argument("UNRST", help="Name of UNRST file") parser.add_argument( - "-n", "--restarts", type=int, help="Number of restart dates wanted", default=0 + "-n", + "--restarts", + type=int, + help="Number of restart dates wanted", + required=True, ) parser.add_argument( "-d", @@ -218,13 +243,19 @@ def get_parser() -> argparse.ArgumentParser: def main() -> None: - """Endpoint for command line script""" + """Endpoint for command line script.""" parser = get_parser() args = parser.parse_args() + if args.restarts <= 0: - print("ERROR: Number of restarts must be a positive number") - sys.exit(1) - if args.UNRST.endswith("DATA"): - print("ERROR: Provide the UNRST file, not the DATA file") - sys.exit(1) + parser.error("Number of restarts must be a positive number") + if args.UNRST.endswith(".DATA"): + parser.error("Provide the UNRST file, not the DATA file") + if args.quiet: + logger.setLevel(logging.WARNING) + restartthinner(args.UNRST, args.restarts, args.quiet, args.dryrun, args.keep) + + +if __name__ == "__main__": + main() diff --git a/tests/test_restartthinner.py b/tests/test_restartthinner.py index 20e716250..64be80540 100644 --- a/tests/test_restartthinner.py +++ b/tests/test_restartthinner.py @@ -1,8 +1,12 @@ +import datetime +import logging import os import shutil import subprocess from pathlib import Path +from unittest.mock import patch +import pandas as pd import pytest from subscript.restartthinner import restartthinner @@ -12,11 +16,11 @@ UNRST_FNAME = "2_R001_REEK-0.UNRST" -def test_dryrun(tmp_path, mocker): +def test_dryrun(tmp_path, mocker, monkeypatch): """Test dry-run""" shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) - os.chdir(tmp_path) + monkeypatch.chdir(tmp_path) orig_rstindices = restartthinner.get_restart_indices(UNRST_FNAME) assert len(orig_rstindices) == 4 @@ -29,11 +33,11 @@ def test_dryrun(tmp_path, mocker): assert len(orig_rstindices) == len(restartthinner.get_restart_indices(UNRST_FNAME)) -def test_first_and_last(tmp_path, mocker): +def test_first_and_last(tmp_path, mocker, monkeypatch): """Ask for two restart points, this should give us the first and last.""" shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) - os.chdir(tmp_path) + monkeypatch.chdir(tmp_path) orig_rstindices = restartthinner.get_restart_indices(UNRST_FNAME) @@ -50,9 +54,9 @@ def test_first_and_last(tmp_path, mocker): assert len(restartthinner.get_restart_indices(UNRST_FNAME + ".orig")) == 4 -def test_subdirectory(tmp_path, mocker): +def test_subdirectory(tmp_path, mocker, monkeypatch): """Check that we can thin an UNRST file two directory levels down""" - os.chdir(tmp_path) + monkeypatch.chdir(tmp_path) subdir = Path("eclipse/model") subdir.mkdir(parents=True) @@ -79,14 +83,14 @@ def test_subdirectory(tmp_path, mocker): ) -def test_get_restart_indices_filenotfound(tmp_path): +def test_get_restart_indices_filenotfound(tmp_path, monkeypatch): """EclFile.file_report_list segfaults unless the code is careful""" with pytest.raises(FileNotFoundError, match="foo"): restartthinner.get_restart_indices("foo") with pytest.raises(FileNotFoundError, match="foo"): restartthinner.get_restart_indices(Path("foo")) - os.chdir(tmp_path) + monkeypatch.chdir(tmp_path) Path("FOO.UNRST").write_text("this is not an unrst file", encoding="utf8") with pytest.raises(TypeError, match="which is not a restart file"): restartthinner.get_restart_indices("FOO.UNRST") @@ -96,3 +100,123 @@ def test_get_restart_indices_filenotfound(tmp_path): def test_integration(): """Test that the endpoint is installed, and the binary tools are available""" assert subprocess.check_output(["restartthinner", "-h"]) + + +def test_single_restart_slice(tmp_path, mocker, monkeypatch): + """Test requesting only 1 restart (should return just the last date).""" + shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) + monkeypatch.chdir(tmp_path) + + orig_rstindices = restartthinner.get_restart_indices(UNRST_FNAME) + + mocker.patch("sys.argv", ["restartthinner", "-n", "1", UNRST_FNAME]) + restartthinner.main() + + new_rstindices = restartthinner.get_restart_indices(UNRST_FNAME) + assert len(new_rstindices) == 1 + assert new_rstindices[0] == orig_rstindices[-1] # Should be the last date + + +def test_negative_restarts_error(tmp_path, mocker, capsys, monkeypatch): + """Test that negative restart count gives an error.""" + shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) + monkeypatch.chdir(tmp_path) + + mocker.patch("sys.argv", ["restartthinner", "-n", "-1", UNRST_FNAME]) + with pytest.raises(SystemExit) as excinfo: + restartthinner.main() + + assert excinfo.value.code == 2 # argparse error exit code + captured = capsys.readouterr() + assert "positive number" in captured.err + + +def test_zero_restarts_error(tmp_path, mocker, capsys, monkeypatch): + """Test that zero restart count gives an error.""" + shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) + monkeypatch.chdir(tmp_path) + + mocker.patch("sys.argv", ["restartthinner", "-n", "0", UNRST_FNAME]) + with pytest.raises(SystemExit) as excinfo: + restartthinner.main() + + assert excinfo.value.code == 2 + captured = capsys.readouterr() + assert "positive number" in captured.err + + +def test_data_file_error(tmp_path, mocker, capsys, monkeypatch): + """Test that providing a DATA file instead of UNRST gives an error.""" + monkeypatch.chdir(tmp_path) + Path("TEST.DATA").touch() + + mocker.patch("sys.argv", ["restartthinner", "-n", "2", "TEST.DATA"]) + with pytest.raises(SystemExit) as excinfo: + restartthinner.main() + + assert excinfo.value.code == 2 + captured = capsys.readouterr() + assert "UNRST file" in captured.err + + +def test_find_resdata_app_not_found(): + """Test that OSError is raised when resdata tools are not in PATH.""" + with ( + patch.object(shutil, "which", return_value=None), + pytest.raises(OSError, match="nonexistent_tool not found in PATH"), + ): + restartthinner.find_resdata_app("nonexistent_tool") + + +def test_find_resdata_app_with_suffix(): + """Test that find_resdata_app tries different suffixes.""" + call_count = {"count": 0} + + def mock_which(name): + call_count["count"] += 1 + # Simulate finding tool with .c.x suffix on second try + if name == "rd_unpack.c.x": + return "/usr/bin/rd_unpack.c.x" + return None + + with patch.object(shutil, "which", side_effect=mock_which): + result = restartthinner.find_resdata_app("rd_unpack") + assert result == "/usr/bin/rd_unpack.c.x" + assert call_count["count"] == 2 # Tried .x first, then .c.x + + +def test_date_slicer(): + """Test date_slicer matches slicedates to nearest restart dates.""" + + # Create test data with 4 restart dates + restart_dates = [ + datetime.datetime(2020, 1, 1), + datetime.datetime(2020, 4, 1), + datetime.datetime(2020, 7, 1), + datetime.datetime(2020, 10, 1), + ] + restart_indices = [0, 1, 2, 3] + + # Slice dates that should match to indices 0, 2, 3 + slice_dates = [ + pd.Timestamp("2020-01-15"), # Closest to Jan 1 -> index 0 + pd.Timestamp("2020-06-15"), # Closest to Jul 1 -> index 2 + pd.Timestamp("2020-11-01"), # Closest to Oct 1 -> index 3 + ] + + result = restartthinner.date_slicer(slice_dates, restart_dates, restart_indices) + assert result == [0, 2, 3] + + +def test_quiet_mode(tmp_path, mocker, caplog, monkeypatch): + """Test that quiet mode suppresses log output.""" + shutil.copyfile(ECLDIR / UNRST_FNAME, tmp_path / UNRST_FNAME) + monkeypatch.chdir(tmp_path) + + # Run with quiet mode + with caplog.at_level(logging.INFO, logger="subscript"): + mocker.patch("sys.argv", ["restartthinner", "-q", "-n", "2", UNRST_FNAME]) + restartthinner.main() + + # Check that log output is empty + assert len(caplog.text) == 0