Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions tests/unit/layer/test_precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,3 +348,21 @@ def test_change_scale_on_extend_shows_missing_scales(clear_caches_reset_mocks, m
)
with pytest.raises(RuntimeError, match=r"Missing scales:"):
info_spec.update_info(LAYER_X0_PATH, overwrite=False, keep_existing_scales=False)


def test_update_info_keyerror_includes_path(clear_caches_reset_mocks, mocker):
test_path = "gs://test/path"
info_spec = PrecomputedInfoSpec(
info_spec_params=InfoSpecParams.from_optional_reference(
reference_path=LAYER_X0_PATH,
scales=[[1, 1, 1]],
inherit_all_params=True,
)
)
mocker.patch.object(
precomputed,
"get_info",
return_value={"data_type": "uint8", "num_channels": 1}, # missing "scales"
)
with pytest.raises(KeyError, match=test_path):
info_spec.update_info(test_path, overwrite=False, keep_existing_scales=True)
9 changes: 6 additions & 3 deletions zetta_utils/layer/deprecated/precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,9 +315,12 @@ def update_info(self, path: str, on_info_exists: InfoExistsModes) -> bool:
"while `on_info_exists` is set to 'expect_same'"
)
if on_info_exists == "extend":
scales_unchanged = all(
e in new_info["scales"] for e in existing_info["scales"]
)
try:
scales_unchanged = all(
e in new_info["scales"] for e in existing_info["scales"]
)
except KeyError as exc:
raise KeyError(f"{exc} while validating info at path '{path}'") from exc
if not scales_unchanged:
raise RuntimeError(
f"Info created by the info_spec {self} is not a pure extension "
Expand Down
79 changes: 41 additions & 38 deletions zetta_utils/layer/precomputed.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,46 +348,49 @@ def update_info(self, path: str, overwrite: bool, keep_existing_scales: bool) ->

if new_info is not None:
if existing_info is not None:
if keep_existing_scales:
if (new_info["data_type"] != existing_info["data_type"]) or (
new_info["num_channels"] != existing_info["num_channels"]
):
raise RuntimeError(
"Attempting to keep existing scales while 'data_type' or "
"'num_channels' have changed in the info file. "
"Consider setting `keep_existing_scales` to False."
try:
if keep_existing_scales:
if (new_info["data_type"] != existing_info["data_type"]) or (
new_info["num_channels"] != existing_info["num_channels"]
):
raise RuntimeError(
"Attempting to keep existing scales while 'data_type' or "
"'num_channels' have changed in the info file. "
"Consider setting `keep_existing_scales` to False."
)
new_info["scales"] = _merge_and_sort_scales(
existing_info["scales"], new_info["scales"]
)
new_info["scales"] = _merge_and_sort_scales(
existing_info["scales"], new_info["scales"]
)

if not overwrite:
existing_scales_changed = any(
not (e in new_info["scales"]) for e in existing_info["scales"]
)
if existing_scales_changed:
missing_scales = [
e for e in existing_info["scales"] if e not in new_info["scales"]
]
raise RuntimeError(
f"New info is not a pure extension of the info existing at '{path}' "
"while `info_overwrite` is set to False. Some scales present "
f"in `{path}` would be overwritten.\n"
f"Missing scales: {missing_scales}"
)
existing_info_no_scales = copy.deepcopy(existing_info)
del existing_info_no_scales["scales"]
new_info_no_scales = copy.deepcopy(new_info)
del new_info_no_scales["scales"]
non_scales_changed = existing_info_no_scales != new_info_no_scales
if non_scales_changed:
diff = _get_info_diff(existing_info_no_scales, new_info_no_scales)
raise RuntimeError(
f"New info is not a pure extension of the info existing at '{path}' "
"while `info_overwrite` is set to False. Some non-scale keys "
f"in `{path}` would be overwritten.\n"
f"Differences:\n{diff}"

if not overwrite:
existing_scales_changed = any(
not (e in new_info["scales"]) for e in existing_info["scales"]
)
if existing_scales_changed:
missing_scales = [
e for e in existing_info["scales"] if e not in new_info["scales"]
]
raise RuntimeError(
f"New info is not a pure extension of the info existing "
f"at '{path}' while `info_overwrite` is set to False. "
f"Some scales in `{path}` would be overwritten.\n"
f"Missing scales: {missing_scales}"
)
existing_info_no_scales = copy.deepcopy(existing_info)
del existing_info_no_scales["scales"]
new_info_no_scales = copy.deepcopy(new_info)
del new_info_no_scales["scales"]
non_scales_changed = existing_info_no_scales != new_info_no_scales
if non_scales_changed:
diff = _get_info_diff(existing_info_no_scales, new_info_no_scales)
raise RuntimeError(
f"New info is not a pure extension of the info existing "
f"at '{path}' while `info_overwrite` is set to False. "
f"Some non-scale keys in `{path}` would be overwritten.\n"
f"Differences:\n{diff}"
)
except KeyError as exc:
raise KeyError(f"{exc} while validating info at path '{path}'") from exc

if existing_info != new_info:
_write_info(new_info, path)
Expand Down
Loading