diff --git a/tests/unit/layer/test_precomputed.py b/tests/unit/layer/test_precomputed.py index e022eb6cc..0abc7957d 100644 --- a/tests/unit/layer/test_precomputed.py +++ b/tests/unit/layer/test_precomputed.py @@ -348,3 +348,21 @@ def test_change_scale_on_extend_shows_missing_scales(clear_caches_reset_mocks, m ) with pytest.raises(RuntimeError, match=r"Missing scales:"): info_spec.update_info(LAYER_X0_PATH, overwrite=False, keep_existing_scales=False) + + +def test_update_info_keyerror_includes_path(clear_caches_reset_mocks, mocker): + test_path = "gs://test/path" + info_spec = PrecomputedInfoSpec( + info_spec_params=InfoSpecParams.from_optional_reference( + reference_path=LAYER_X0_PATH, + scales=[[1, 1, 1]], + inherit_all_params=True, + ) + ) + mocker.patch.object( + precomputed, + "get_info", + return_value={"data_type": "uint8", "num_channels": 1}, # missing "scales" + ) + with pytest.raises(KeyError, match=test_path): + info_spec.update_info(test_path, overwrite=False, keep_existing_scales=True) diff --git a/zetta_utils/layer/deprecated/precomputed.py b/zetta_utils/layer/deprecated/precomputed.py index acd327317..36884d1c1 100644 --- a/zetta_utils/layer/deprecated/precomputed.py +++ b/zetta_utils/layer/deprecated/precomputed.py @@ -315,9 +315,12 @@ def update_info(self, path: str, on_info_exists: InfoExistsModes) -> bool: "while `on_info_exists` is set to 'expect_same'" ) if on_info_exists == "extend": - scales_unchanged = all( - e in new_info["scales"] for e in existing_info["scales"] - ) + try: + scales_unchanged = all( + e in new_info["scales"] for e in existing_info["scales"] + ) + except KeyError as exc: + raise KeyError(f"{exc} while validating info at path '{path}'") from exc if not scales_unchanged: raise RuntimeError( f"Info created by the info_spec {self} is not a pure extension " diff --git a/zetta_utils/layer/precomputed.py b/zetta_utils/layer/precomputed.py index 38f7f788c..3bc4129a9 100644 --- a/zetta_utils/layer/precomputed.py +++ b/zetta_utils/layer/precomputed.py @@ -348,46 +348,49 @@ def update_info(self, path: str, overwrite: bool, keep_existing_scales: bool) -> if new_info is not None: if existing_info is not None: - if keep_existing_scales: - if (new_info["data_type"] != existing_info["data_type"]) or ( - new_info["num_channels"] != existing_info["num_channels"] - ): - raise RuntimeError( - "Attempting to keep existing scales while 'data_type' or " - "'num_channels' have changed in the info file. " - "Consider setting `keep_existing_scales` to False." + try: + if keep_existing_scales: + if (new_info["data_type"] != existing_info["data_type"]) or ( + new_info["num_channels"] != existing_info["num_channels"] + ): + raise RuntimeError( + "Attempting to keep existing scales while 'data_type' or " + "'num_channels' have changed in the info file. " + "Consider setting `keep_existing_scales` to False." + ) + new_info["scales"] = _merge_and_sort_scales( + existing_info["scales"], new_info["scales"] ) - new_info["scales"] = _merge_and_sort_scales( - existing_info["scales"], new_info["scales"] - ) - - if not overwrite: - existing_scales_changed = any( - not (e in new_info["scales"]) for e in existing_info["scales"] - ) - if existing_scales_changed: - missing_scales = [ - e for e in existing_info["scales"] if e not in new_info["scales"] - ] - raise RuntimeError( - f"New info is not a pure extension of the info existing at '{path}' " - "while `info_overwrite` is set to False. Some scales present " - f"in `{path}` would be overwritten.\n" - f"Missing scales: {missing_scales}" - ) - existing_info_no_scales = copy.deepcopy(existing_info) - del existing_info_no_scales["scales"] - new_info_no_scales = copy.deepcopy(new_info) - del new_info_no_scales["scales"] - non_scales_changed = existing_info_no_scales != new_info_no_scales - if non_scales_changed: - diff = _get_info_diff(existing_info_no_scales, new_info_no_scales) - raise RuntimeError( - f"New info is not a pure extension of the info existing at '{path}' " - "while `info_overwrite` is set to False. Some non-scale keys " - f"in `{path}` would be overwritten.\n" - f"Differences:\n{diff}" + + if not overwrite: + existing_scales_changed = any( + not (e in new_info["scales"]) for e in existing_info["scales"] ) + if existing_scales_changed: + missing_scales = [ + e for e in existing_info["scales"] if e not in new_info["scales"] + ] + raise RuntimeError( + f"New info is not a pure extension of the info existing " + f"at '{path}' while `info_overwrite` is set to False. " + f"Some scales in `{path}` would be overwritten.\n" + f"Missing scales: {missing_scales}" + ) + existing_info_no_scales = copy.deepcopy(existing_info) + del existing_info_no_scales["scales"] + new_info_no_scales = copy.deepcopy(new_info) + del new_info_no_scales["scales"] + non_scales_changed = existing_info_no_scales != new_info_no_scales + if non_scales_changed: + diff = _get_info_diff(existing_info_no_scales, new_info_no_scales) + raise RuntimeError( + f"New info is not a pure extension of the info existing " + f"at '{path}' while `info_overwrite` is set to False. " + f"Some non-scale keys in `{path}` would be overwritten.\n" + f"Differences:\n{diff}" + ) + except KeyError as exc: + raise KeyError(f"{exc} while validating info at path '{path}'") from exc if existing_info != new_info: _write_info(new_info, path)