From 6dc32250402895cc05a6cab697c41a60d96a80de Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 5 Sep 2025 15:48:16 +0300 Subject: [PATCH 01/26] Draft safeguards compressor wrappers --- .../compressor/compressors/__init__.py | 10 ++++++ .../compressors/safeguards/__init__.py | 11 ++++++ .../compressors/safeguards/sperr.py | 22 ++++++++++++ .../compressor/compressors/safeguards/sz3.py | 31 +++++++++++++++++ .../compressor/compressors/safeguards/zero.py | 31 +++++++++++++++++ .../compressors/safeguards/zfp_round.py | 34 +++++++++++++++++++ 6 files changed, 139 insertions(+) create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/__init__.py create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/sperr.py create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/sz3.py create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/zero.py create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py diff --git a/src/climatebenchpress/compressor/compressors/__init__.py b/src/climatebenchpress/compressor/compressors/__init__.py index b523c75..31bcf95 100644 --- a/src/climatebenchpress/compressor/compressors/__init__.py +++ b/src/climatebenchpress/compressor/compressors/__init__.py @@ -2,6 +2,10 @@ "BitRound", "BitRoundPco", "Jpeg2000", + "SafeguardsSperr", + "SafeguardsSz3", + "SafeguardsZero", + "SafeguardsZfpRound", "Sperr", "StochRound", "StochRoundPco", @@ -15,6 +19,12 @@ from .bitround import BitRound from .bitround_pco import BitRoundPco from .jpeg2000 import Jpeg2000 +from .safeguards import ( + SafeguardsSperr, + SafeguardsSz3, + SafeguardsZero, + SafeguardsZfpRound, +) from .sperr import Sperr from .stochround import StochRound from .stochround_pco import StochRoundPco diff --git a/src/climatebenchpress/compressor/compressors/safeguards/__init__.py b/src/climatebenchpress/compressor/compressors/safeguards/__init__.py new file mode 100644 index 0000000..ffce923 --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/__init__.py @@ -0,0 +1,11 @@ +__all__ = [ + "SafeguardsSperr", + "SafeguardsSz3", + "SafeguardsZero", + "SafeguardsZfpRound", +] + +from .sperr import SafeguardsSperr +from .sz3 import SafeguardsSz3 +from .zero import SafeguardsZero +from .zfp_round import SafeguardsZfpRound diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py new file mode 100644 index 0000000..c49183c --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py @@ -0,0 +1,22 @@ +__all__ = ["SafeguardsSperr"] + +import numcodecs_safeguards +import numcodecs_wasm_sperr + +from ..abc import Compressor + + +class SafeguardsSperr(Compressor): + """Safeguarded SPERR compressor.""" + + name = "safeguards-sperr" + description = "Safeguards(SPERR)" + + @staticmethod + def abs_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sz3.py b/src/climatebenchpress/compressor/compressors/safeguards/sz3.py new file mode 100644 index 0000000..480992c --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/sz3.py @@ -0,0 +1,31 @@ +__all__ = ["SafeguardsSz3"] + +import numcodecs_safeguards +import numcodecs_wasm_sz3 + +from ..abc 
import Compressor + + +class SafeguardsSz3(Compressor): + """Safeguarded SZ3 compressor.""" + + name = "safeguards-sz3" + description = "Safeguards(SZ3)" + + @staticmethod + def abs_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=error_bound), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) + + @staticmethod + def rel_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_sz3.Sz3(eb_mode="rel", eb_rel=error_bound), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + ) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zero.py b/src/climatebenchpress/compressor/compressors/safeguards/zero.py new file mode 100644 index 0000000..d156c8a --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/zero.py @@ -0,0 +1,31 @@ +__all__ = ["SafeguardsZero"] + +import numcodecs_safeguards +import numcodecs_zero + +from ..abc import Compressor + + +class SafeguardsZero(Compressor): + """Safeguarded all-zero compressor.""" + + name = "safeguards-zero" + description = "Safeguards(0)" + + @staticmethod + def abs_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_zero.ZeroCodec(), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) + + @staticmethod + def rel_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_zero.ZeroCodec(), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + ) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py b/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py new file mode 100644 index 0000000..81b8caa --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py @@ -0,0 +1,34 @@ +__all__ = ["SafeguardsZfpRound"] + +import numcodecs_safeguards +import numcodecs_wasm_zfp + +from ..abc import Compressor + + +class SafeguardsZfpRound(Compressor): + """Safeguarded ZFP-ROUND compressor. + + This is an adjusted version of the ZFP compressor with an improved rounding mechanism + for the transform coefficients. + """ + + name = "safeguards-zfp-round" + description = "Safeguards(ZFP-ROUND)" + + # NOTE: + # ZFP mechanism for strictly supporting relative error bounds is to + # truncate the floating point bit representation and then use ZFP's lossless + # mode for compression. This is essentially equivalent to the BitRound + # compressors we are already implementing (with a difference what the lossless + # compression algorithm is). + # See https://zfp.readthedocs.io/en/release1.0.1/faq.html#q-relerr for more details. 
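# Aside: a minimal sketch of the keepbits <-> relative-error relationship that
# makes the bit-truncation approach described above "essentially equivalent"
# to BitRound. Keeping k explicit mantissa bits of an IEEE float bounds the
# pointwise relative rounding error by roughly 2**-(k + 1). The helper name
# below is purely illustrative; the benchmark's own compute_keepbits helper
# (used by the BitRound compressors) may differ in detail.
import math

def keepbits_for_rel_bound(rel_eb: float, mantissa_bits: int = 23) -> int:
    # smallest k with 2**-(k + 1) <= rel_eb, clamped to float32's 23 mantissa bits
    needed = math.ceil(-math.log2(rel_eb)) - 1
    return int(min(mantissa_bits, max(0, needed)))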
+ + @staticmethod + def abs_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) From 222618680c4881ebcf3417923bc41781f981dead Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Mon, 8 Sep 2025 10:57:26 +0300 Subject: [PATCH 02/26] Add a very simple, cheeky dSSIM safeguard --- .../compressor/compressors/__init__.py | 2 + .../compressors/safeguards/__init__.py | 2 + .../compressors/safeguards/zero_dssim.py | 79 +++++++++++++++++++ 3 files changed, 83 insertions(+) create mode 100644 src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py diff --git a/src/climatebenchpress/compressor/compressors/__init__.py b/src/climatebenchpress/compressor/compressors/__init__.py index 31bcf95..b86813c 100644 --- a/src/climatebenchpress/compressor/compressors/__init__.py +++ b/src/climatebenchpress/compressor/compressors/__init__.py @@ -5,6 +5,7 @@ "SafeguardsSperr", "SafeguardsSz3", "SafeguardsZero", + "SafeguardsZeroDssim", "SafeguardsZfpRound", "Sperr", "StochRound", @@ -23,6 +24,7 @@ SafeguardsSperr, SafeguardsSz3, SafeguardsZero, + SafeguardsZeroDssim, SafeguardsZfpRound, ) from .sperr import Sperr diff --git a/src/climatebenchpress/compressor/compressors/safeguards/__init__.py b/src/climatebenchpress/compressor/compressors/safeguards/__init__.py index ffce923..e75d448 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/__init__.py +++ b/src/climatebenchpress/compressor/compressors/safeguards/__init__.py @@ -2,10 +2,12 @@ "SafeguardsSperr", "SafeguardsSz3", "SafeguardsZero", + "SafeguardsZeroDssim", "SafeguardsZfpRound", ] from .sperr import SafeguardsSperr from .sz3 import SafeguardsSz3 from .zero import SafeguardsZero +from .zero_dssim import SafeguardsZeroDssim from .zfp_round import SafeguardsZfpRound diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py b/src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py new file mode 100644 index 0000000..d0d3c71 --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py @@ -0,0 +1,79 @@ +__all__ = ["SafeguardsZeroDssim"] + +import numcodecs_safeguards +import numcodecs_zero + +from ..abc import Compressor + + +class SafeguardsZeroDssim(Compressor): + """Safeguarded all-zero compressor that also safeguards the dSSIM score.""" + + name = "safeguards-zero-dssim" + description = "Safeguards(0, dSSIM)" + + @staticmethod + def abs_bound_codec(error_bound, **kwargs): + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_zero.ZeroCodec(), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + # guarantee that the global minimum and maximum are preserved, + # which simplifies the rescaling + dict(kind="sign", offset="$x_min"), + dict(kind="sign", offset="$x_max"), + dict( + kind="qoi_eb_pw", + qoi=""" + # we guarantee that + # min(data) = min(corrected) and + # max(data) = max(corrected) + # with the sign safeguards above + v["smin"] = c["$x_min"]; + v["smax"] = c["$x_max"]; + v["r"] = v["smax"] - v["smin"]; + + # re-scale to [0-1] and quantize to 256 bins + v["sc_a2"] = round_ties_even(((x - v["smin"]) / v["r"]) * 255) / 255; + + # force the quantized value to stay the same + return v["sc_a2"]; + """, + type="abs", + eb=0, + ), + ], + ) + + @staticmethod + def rel_bound_codec(error_bound, **kwargs): + return 
numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_zero.ZeroCodec(), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + # guarantee that the global minimum and maximum are preserved, + # which simplifies the rescaling + dict(kind="sign", offset="$x_min"), + dict(kind="sign", offset="$x_max"), + dict( + kind="qoi_eb_pw", + qoi=""" + # we guarantee that + # min(data) = min(corrected) and + # max(data) = max(corrected) + # with the sign safeguards above + v["smin"] = c["$x_min"]; + v["smax"] = c["$x_max"]; + v["r"] = v["smax"] - v["smin"]; + + # re-scale to [0-1] and quantize to 256 bins + v["sc_a2"] = round_ties_even(((x - v["smin"]) / v["r"]) * 255) / 255; + + # force the quantized value to stay the same + return v["sc_a2"]; + """, + type="abs", + eb=0, + ), + ], + ) From 93c903fcf5036e48fe32f427de65642d8f47e185 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 11 Dec 2025 10:33:59 +0200 Subject: [PATCH 03/26] Add numcodecs-safeguards PyPi dependency --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 596e327..c654e8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ dependencies = [ "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.10", "numcodecs-observers~=0.1.2", + "numcodecs-safeguards==0.1.0a1", "numcodecs-wasm~=0.2.2", "numcodecs-wasm-bit-round~=0.4.0", "numcodecs-wasm-fixed-offset-scale~=0.4.0", @@ -28,6 +29,7 @@ dependencies = [ "numcodecs-wasm-zfp~=0.6.0", "numcodecs-wasm-zfp-classic~=0.4.0", "numcodecs-wasm-zstd~=0.4.0", + "numcodecs-zero~=0.1.0", "pandas~=2.2", "scipy~=1.14", "seaborn~=0.13.2", From 6cddcd892d1724c47a153e91ea42076ce7fa7000 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 23 Jan 2026 15:21:05 +0200 Subject: [PATCH 04/26] safeguard (conservative) relative error bounds --- pyproject.toml | 8 ++--- .../compressor/compressors/abc.py | 4 +++ .../compressors/safeguards/sperr.py | 35 ++++++++++++++++++- .../compressors/safeguards/zfp_round.py | 16 +++++++++ 4 files changed, 58 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c654e8d..73bf9a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,9 +13,9 @@ dependencies = [ "matplotlib~=3.8", "netcdf4==1.7.3", "numcodecs>=0.13.0,<0.17", - "numcodecs-combinators[xarray]~=0.2.10", + "numcodecs-combinators[xarray]~=0.2.13", "numcodecs-observers~=0.1.2", - "numcodecs-safeguards==0.1.0a1", + "numcodecs-safeguards==0.1.0b1", "numcodecs-wasm~=0.2.2", "numcodecs-wasm-bit-round~=0.4.0", "numcodecs-wasm-fixed-offset-scale~=0.4.0", @@ -24,12 +24,12 @@ dependencies = [ "numcodecs-wasm-round~=0.5.0", "numcodecs-wasm-sperr~=0.2.0", "numcodecs-wasm-stochastic-rounding~=0.2.0", - "numcodecs-wasm-sz3~=0.7.0", + "numcodecs-wasm-sz3~=0.8.0", "numcodecs-wasm-tthresh~=0.3.0", "numcodecs-wasm-zfp~=0.6.0", "numcodecs-wasm-zfp-classic~=0.4.0", "numcodecs-wasm-zstd~=0.4.0", - "numcodecs-zero~=0.1.0", + "numcodecs-zero~=0.1.2", "pandas~=2.2", "scipy~=1.14", "seaborn~=0.13.2", diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index c4dacb8..9029a83 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -167,6 +167,8 @@ def build( dtype=dtypes[var], data_min=data_min[var], data_max=data_max[var], + data_abs_min=data_abs_min[var], + data_abs_max=data_abs_max[var], ) elif eb.rel_error is not None and cls.has_rel_error_impl: new_codecs[var] = 
partial( @@ -175,6 +177,8 @@ def build( dtype=dtypes[var], data_min=data_min[var], data_max=data_max[var], + data_abs_min=data_abs_min[var], + data_abs_max=data_abs_max[var], ) else: # This should never happen as we have already transformed the error bounds. diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py index c49183c..815290d 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py +++ b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py @@ -1,7 +1,10 @@ __all__ = ["SafeguardsSperr"] +import numcodecs import numcodecs_safeguards import numcodecs_wasm_sperr +import numpy as np +from numcodecs_combinators.stack import CodecStack from ..abc import Compressor @@ -15,8 +18,38 @@ class SafeguardsSperr(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): return numcodecs_safeguards.SafeguardsCodec( - codec=numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound), + codec=CodecStack( + NaNToZero(), + numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound), + ), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), ], ) + + @staticmethod + def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): + assert data_abs_min is not None, "data_abs_min must be provided" + + return numcodecs_safeguards.SafeguardsCodec( + codec=CodecStack( + NaNToZero(), + # conservative rel->abs error bound transformation, + # same as convert_rel_error_to_abs_error + # so that we can inform the safeguards of the rel bound + numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound * data_abs_min), + ), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + ) + + +class NaNToZero(numcodecs.abc.Codec): + codec_id = "nan-to-zero" + + def encode(self, buf): + return np.nan_to_num(buf, nan=0, posinf=np.inf, neginf=-np.inf) + + def decode(self, buf, out=None): + return numcodecs.compat.ndarray_copy(buf, out) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py b/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py index 81b8caa..de28f08 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py +++ b/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py @@ -32,3 +32,19 @@ def abs_bound_codec(error_bound, **kwargs): dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), ], ) + + @staticmethod + def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): + assert data_abs_min is not None, "data_abs_min must be provided" + + return numcodecs_safeguards.SafeguardsCodec( + # conservative rel->abs error bound transformation, + # same as convert_rel_error_to_abs_error + # so that we can inform the safeguards of the rel bound + codec=numcodecs_wasm_zfp.Zfp( + mode="fixed-accuracy", tolerance=error_bound * data_abs_min + ), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + ) From 916373f1bd2a559be4f00bda20a1f99208b2b153 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 24 Jan 2026 00:11:59 +0200 Subject: [PATCH 05/26] fix mypy --- src/climatebenchpress/compressor/compressors/abc.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index 9029a83..b72eeb9 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -88,6 +88,8 @@ def 
abs_bound_codec( dtype: Optional[np.dtype] = None, data_min: Optional[float] = None, data_max: Optional[float] = None, + data_abs_min: Optional[float] = None, + data_abs_max: Optional[float] = None, ) -> Codec: """Create a codec with an absolute error bound.""" pass @@ -100,6 +102,8 @@ def rel_bound_codec( dtype: Optional[np.dtype] = None, data_min: Optional[float] = None, data_max: Optional[float] = None, + data_abs_min: Optional[float] = None, + data_abs_max: Optional[float] = None, ) -> Codec: """Create a codec with a relative error bound.""" pass From cc764a291ecad04e416df26f386ac937f6d23bc9 Mon Sep 17 00:00:00 2001 From: Juniper Tyree <50025784+juntyr@users.noreply.github.com> Date: Sun, 25 Jan 2026 11:18:20 +0200 Subject: [PATCH 06/26] Update sz3 to v0.8.1 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 73bf9a3..2540959 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "numcodecs-wasm-round~=0.5.0", "numcodecs-wasm-sperr~=0.2.0", "numcodecs-wasm-stochastic-rounding~=0.2.0", - "numcodecs-wasm-sz3~=0.8.0", + "numcodecs-wasm-sz3~=0.8.1", "numcodecs-wasm-tthresh~=0.3.0", "numcodecs-wasm-zfp~=0.6.0", "numcodecs-wasm-zfp-classic~=0.4.0", From d879ca5a2b6c5d09d7da3b8e932ce256bbad9b4f Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 29 Jan 2026 08:27:19 +0200 Subject: [PATCH 07/26] Replace NaNs with the mean before SPERR --- .../compressor/compressors/safeguards/sperr.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py index 815290d..fd82304 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py +++ b/src/climatebenchpress/compressor/compressors/safeguards/sperr.py @@ -19,7 +19,7 @@ class SafeguardsSperr(Compressor): def abs_bound_codec(error_bound, **kwargs): return numcodecs_safeguards.SafeguardsCodec( codec=CodecStack( - NaNToZero(), + NaNToMean(), numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound), ), safeguards=[ @@ -33,7 +33,7 @@ def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): return numcodecs_safeguards.SafeguardsCodec( codec=CodecStack( - NaNToZero(), + NaNToMean(), # conservative rel->abs error bound transformation, # same as convert_rel_error_to_abs_error # so that we can inform the safeguards of the rel bound @@ -45,11 +45,14 @@ def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): ) -class NaNToZero(numcodecs.abc.Codec): - codec_id = "nan-to-zero" +# inspired by H5Z-SPERR's treatment of NaN values: +# https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5z-sperr.c#L464-L473 +# https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5zsperr_helper.cpp#L179-L212 +class NaNToMean(numcodecs.abc.Codec): + codec_id = "nan-to-mean" def encode(self, buf): - return np.nan_to_num(buf, nan=0, posinf=np.inf, neginf=-np.inf) + return np.nan_to_num(buf, nan=np.nanmean(buf), posinf=np.inf, neginf=-np.inf) def decode(self, buf, out=None): return numcodecs.compat.ndarray_copy(buf, out) From 4837f76a985c1c8fa7d29d65f2928e672f99924f Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 30 Jan 2026 21:51:54 +0200 Subject: [PATCH 08/26] some improvements --- pyproject.toml | 2 +- .../compressor/compressors/__init__.py | 22 +++++------ .../compressors/safeguarded/__init__.py | 13 +++++++ .../{safeguards => 
safeguarded}/sperr.py | 10 ++--- .../{safeguards => safeguarded}/sz3.py | 8 ++-- .../{safeguards => safeguarded}/zero.py | 8 ++-- .../{safeguards => safeguarded}/zero_dssim.py | 38 ++++++++++++------- .../{safeguards => safeguarded}/zfp_round.py | 16 +++++--- .../compressors/safeguards/__init__.py | 13 ------- 9 files changed, 72 insertions(+), 58 deletions(-) create mode 100644 src/climatebenchpress/compressor/compressors/safeguarded/__init__.py rename src/climatebenchpress/compressor/compressors/{safeguards => safeguarded}/sperr.py (90%) rename src/climatebenchpress/compressor/compressors/{safeguards => safeguarded}/sz3.py (85%) rename src/climatebenchpress/compressor/compressors/{safeguards => safeguarded}/zero.py (84%) rename src/climatebenchpress/compressor/compressors/{safeguards => safeguarded}/zero_dssim.py (65%) rename src/climatebenchpress/compressor/compressors/{safeguards => safeguarded}/zfp_round.py (78%) delete mode 100644 src/climatebenchpress/compressor/compressors/safeguards/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 2540959..7a48e7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ description = "Add your description here" readme = "README.md" requires-python = ">=3.12" dependencies = [ - "astropy~=7.0.1", + "astropy~=7.2.0", "cartopy~=0.24.1", "cf-xarray~=0.10", "cftime~=1.6.0", diff --git a/src/climatebenchpress/compressor/compressors/__init__.py b/src/climatebenchpress/compressor/compressors/__init__.py index b86813c..46638b3 100644 --- a/src/climatebenchpress/compressor/compressors/__init__.py +++ b/src/climatebenchpress/compressor/compressors/__init__.py @@ -2,11 +2,11 @@ "BitRound", "BitRoundPco", "Jpeg2000", - "SafeguardsSperr", - "SafeguardsSz3", - "SafeguardsZero", - "SafeguardsZeroDssim", - "SafeguardsZfpRound", + "SafeguardedSperr", + "SafeguardedSz3", + "SafeguardedZero", + "SafeguardedZeroDssim", + "SafeguardedZfpRound", "Sperr", "StochRound", "StochRoundPco", @@ -20,12 +20,12 @@ from .bitround import BitRound from .bitround_pco import BitRoundPco from .jpeg2000 import Jpeg2000 -from .safeguards import ( - SafeguardsSperr, - SafeguardsSz3, - SafeguardsZero, - SafeguardsZeroDssim, - SafeguardsZfpRound, +from .safeguarded import ( + SafeguardedSperr, + SafeguardedSz3, + SafeguardedZero, + SafeguardedZeroDssim, + SafeguardedZfpRound, ) from .sperr import Sperr from .stochround import StochRound diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py new file mode 100644 index 0000000..7660f38 --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py @@ -0,0 +1,13 @@ +__all__ = [ + "SafeguardedSperr", + "SafeguardedSz3", + "SafeguardedZero", + "SafeguardedZeroDssim", + "SafeguardedZfpRound", +] + +from .sperr import SafeguardedSperr +from .sz3 import SafeguardedSz3 +from .zero import SafeguardedZero +from .zero_dssim import SafeguardedZeroDssim +from .zfp_round import SafeguardedZfpRound diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py similarity index 90% rename from src/climatebenchpress/compressor/compressors/safeguards/sperr.py rename to src/climatebenchpress/compressor/compressors/safeguarded/sperr.py index fd82304..58b0a1c 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/sperr.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py @@ -1,4 +1,4 @@ -__all__ = 
["SafeguardsSperr"] +__all__ = ["SafeguardedSperr"] import numcodecs import numcodecs_safeguards @@ -9,11 +9,11 @@ from ..abc import Compressor -class SafeguardsSperr(Compressor): +class SafeguardedSperr(Compressor): """Safeguarded SPERR compressor.""" - name = "safeguards-sperr" - description = "Safeguards(SPERR)" + name = "safeguarded-sperr" + description = "Safeguarded(SPERR)" @staticmethod def abs_bound_codec(error_bound, **kwargs): @@ -49,7 +49,7 @@ def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5z-sperr.c#L464-L473 # https://github.com/NCAR/H5Z-SPERR/blob/72ebcb00e382886c229c5ef5a7e237fe451d5fb8/src/h5zsperr_helper.cpp#L179-L212 class NaNToMean(numcodecs.abc.Codec): - codec_id = "nan-to-mean" + codec_id = "nan-to-mean" # type: ignore def encode(self, buf): return np.nan_to_num(buf, nan=np.nanmean(buf), posinf=np.inf, neginf=-np.inf) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/sz3.py b/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py similarity index 85% rename from src/climatebenchpress/compressor/compressors/safeguards/sz3.py rename to src/climatebenchpress/compressor/compressors/safeguarded/sz3.py index 480992c..8d89e03 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/sz3.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py @@ -1,4 +1,4 @@ -__all__ = ["SafeguardsSz3"] +__all__ = ["SafeguardedSz3"] import numcodecs_safeguards import numcodecs_wasm_sz3 @@ -6,11 +6,11 @@ from ..abc import Compressor -class SafeguardsSz3(Compressor): +class SafeguardedSz3(Compressor): """Safeguarded SZ3 compressor.""" - name = "safeguards-sz3" - description = "Safeguards(SZ3)" + name = "safeguarded-sz3" + description = "Safeguarded(SZ3)" @staticmethod def abs_bound_codec(error_bound, **kwargs): diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zero.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero.py similarity index 84% rename from src/climatebenchpress/compressor/compressors/safeguards/zero.py rename to src/climatebenchpress/compressor/compressors/safeguarded/zero.py index d156c8a..c15919a 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/zero.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero.py @@ -1,4 +1,4 @@ -__all__ = ["SafeguardsZero"] +__all__ = ["SafeguardedZero"] import numcodecs_safeguards import numcodecs_zero @@ -6,11 +6,11 @@ from ..abc import Compressor -class SafeguardsZero(Compressor): +class SafeguardedZero(Compressor): """Safeguarded all-zero compressor.""" - name = "safeguards-zero" - description = "Safeguards(0)" + name = "safeguarded-zero" + description = "Safeguarded(0)" @staticmethod def abs_bound_codec(error_bound, **kwargs): diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py similarity index 65% rename from src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py rename to src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py index d0d3c71..cf1fefc 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/zero_dssim.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py @@ -1,4 +1,4 @@ -__all__ = ["SafeguardsZeroDssim"] +__all__ = ["SafeguardedZeroDssim"] import numcodecs_safeguards import numcodecs_zero @@ -6,22 +6,25 @@ from ..abc import Compressor 
-class SafeguardsZeroDssim(Compressor): +class SafeguardedZeroDssim(Compressor): """Safeguarded all-zero compressor that also safeguards the dSSIM score.""" - name = "safeguards-zero-dssim" - description = "Safeguards(0, dSSIM)" + name = "safeguarded-zero-dssim" + description = "Safeguarded(0, dSSIM)" @staticmethod - def abs_bound_codec(error_bound, **kwargs): + def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): + assert data_min is not None, "data_min must be provided" + assert data_max is not None, "data_max must be provided" + return numcodecs_safeguards.SafeguardsCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), # guarantee that the global minimum and maximum are preserved, # which simplifies the rescaling - dict(kind="sign", offset="$x_min"), - dict(kind="sign", offset="$x_max"), + dict(kind="sign", offset="x_min"), + dict(kind="sign", offset="x_max"), dict( kind="qoi_eb_pw", qoi=""" @@ -29,8 +32,8 @@ def abs_bound_codec(error_bound, **kwargs): # min(data) = min(corrected) and # max(data) = max(corrected) # with the sign safeguards above - v["smin"] = c["$x_min"]; - v["smax"] = c["$x_max"]; + v["smin"] = c["x_min"]; + v["smax"] = c["x_max"]; v["r"] = v["smax"] - v["smin"]; # re-scale to [0-1] and quantize to 256 bins @@ -43,18 +46,23 @@ def abs_bound_codec(error_bound, **kwargs): eb=0, ), ], + # use data_min instead of $x_min to allow for chunking + fixed_constants=dict(x_min=data_min, x_max=data_max), ) @staticmethod - def rel_bound_codec(error_bound, **kwargs): + def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): + assert data_min is not None, "data_min must be provided" + assert data_max is not None, "data_max must be provided" + return numcodecs_safeguards.SafeguardsCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), # guarantee that the global minimum and maximum are preserved, # which simplifies the rescaling - dict(kind="sign", offset="$x_min"), - dict(kind="sign", offset="$x_max"), + dict(kind="sign", offset="x_min"), + dict(kind="sign", offset="x_max"), dict( kind="qoi_eb_pw", qoi=""" @@ -62,8 +70,8 @@ def rel_bound_codec(error_bound, **kwargs): # min(data) = min(corrected) and # max(data) = max(corrected) # with the sign safeguards above - v["smin"] = c["$x_min"]; - v["smax"] = c["$x_max"]; + v["smin"] = c["x_min"]; + v["smax"] = c["x_max"]; v["r"] = v["smax"] - v["smin"]; # re-scale to [0-1] and quantize to 256 bins @@ -76,4 +84,6 @@ def rel_bound_codec(error_bound, **kwargs): eb=0, ), ], + # use data_min instead of $x_min to allow for chunking + fixed_constants=dict(x_min=data_min, x_max=data_max), ) diff --git a/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py b/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py similarity index 78% rename from src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py rename to src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py index de28f08..89f1641 100644 --- a/src/climatebenchpress/compressor/compressors/safeguards/zfp_round.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py @@ -1,4 +1,4 @@ -__all__ = ["SafeguardsZfpRound"] +__all__ = ["SafeguardedZfpRound"] import numcodecs_safeguards import numcodecs_wasm_zfp @@ -6,15 +6,15 @@ from ..abc import Compressor -class SafeguardsZfpRound(Compressor): +class SafeguardedZfpRound(Compressor): """Safeguarded ZFP-ROUND 
compressor. This is an adjusted version of the ZFP compressor with an improved rounding mechanism for the transform coefficients. """ - name = "safeguards-zfp-round" - description = "Safeguards(ZFP-ROUND)" + name = "safeguarded-zfp-round" + description = "Safeguarded(ZFP-ROUND)" # NOTE: # ZFP mechanism for strictly supporting relative error bounds is to @@ -27,7 +27,9 @@ class SafeguardsZfpRound(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): return numcodecs_safeguards.SafeguardsCodec( - codec=numcodecs_wasm_zfp.Zfp(mode="fixed-accuracy", tolerance=error_bound), + codec=numcodecs_wasm_zfp.Zfp( + mode="fixed-accuracy", tolerance=error_bound, non_finite="allow-unsafe" + ), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), ], @@ -42,7 +44,9 @@ def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): # same as convert_rel_error_to_abs_error # so that we can inform the safeguards of the rel bound codec=numcodecs_wasm_zfp.Zfp( - mode="fixed-accuracy", tolerance=error_bound * data_abs_min + mode="fixed-accuracy", + tolerance=error_bound * data_abs_min, + non_finite="allow-unsafe", ), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), diff --git a/src/climatebenchpress/compressor/compressors/safeguards/__init__.py b/src/climatebenchpress/compressor/compressors/safeguards/__init__.py deleted file mode 100644 index e75d448..0000000 --- a/src/climatebenchpress/compressor/compressors/safeguards/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -__all__ = [ - "SafeguardsSperr", - "SafeguardsSz3", - "SafeguardsZero", - "SafeguardsZeroDssim", - "SafeguardsZfpRound", -] - -from .sperr import SafeguardsSperr -from .sz3 import SafeguardsSz3 -from .zero import SafeguardsZero -from .zero_dssim import SafeguardsZeroDssim -from .zfp_round import SafeguardsZfpRound From f302b39fccd6b68a8f48316db2176603dd616dcf Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 31 Jan 2026 00:58:33 +0200 Subject: [PATCH 09/26] fix compression stats --- .../compressor/scripts/compress.py | 50 ++++++++++++------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/src/climatebenchpress/compressor/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py index 98b71e4..96fa51f 100644 --- a/src/climatebenchpress/compressor/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -199,24 +199,38 @@ def compress_decompress( ) as codec_: variables[v] = codec_.encode_decode_data_array(ds[v]).compute() - measurements[v] = { - "encoded_bytes": sum( - b.post for b in nbytes.encode_sizes[HashableCodec(codec[-1])] - ), - "decoded_bytes": sum( - b.post for b in nbytes.decode_sizes[HashableCodec(codec[0])] - ), - "encode_timing": sum(t for ts in timing.encode_times.values() for t in ts), - "decode_timing": sum(t for ts in timing.decode_times.values() for t in ts), - "encode_instructions": sum( - i for is_ in instructions.encode_instructions.values() for i in is_ - ) - or None, - "decode_instructions": sum( - i for is_ in instructions.decode_instructions.values() for i in is_ - ) - or None, - } + cs = [c._codec for c in codec_.__iter__()] + + measurements[v] = { + # bytes measurements: only look at the first and last codec in + # the top level stack, which gives the total encoded and + # decoded sizes + "encoded_bytes": sum( + b.post for b in nbytes.encode_sizes[HashableCodec(cs[-1])] + ), + "decoded_bytes": sum( + b.post for b in nbytes.decode_sizes[HashableCodec(cs[0])] + ), + # time measurements: only sum over the 
top level stack members + # to avoid double counting from nested codec combinators + "encode_timing": sum( + t for c in cs for t in timing.encode_times[HashableCodec(c)] + ), + "decode_timing": sum( + t for c in cs for t in timing.decode_times[HashableCodec(c)] + ), + # encode instructions: sum over all codecs since WASM + # instruction counts are currently not aggregated in codec + # combinators + "encode_instructions": sum( + i for is_ in instructions.encode_instructions.values() for i in is_ + ) + or None, + "decode_instructions": sum( + i for is_ in instructions.decode_instructions.values() for i in is_ + ) + or None, + } return xr.Dataset(variables, coords=ds.coords, attrs=ds.attrs), measurements From 8f54e8e9f5e8063caa3ac12abc19c0d4a4059900 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Sat, 31 Jan 2026 09:23:04 +0200 Subject: [PATCH 10/26] downgrade SZ3 to 0.7.0 --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 7a48e7f..3e4f33d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "numcodecs-wasm-round~=0.5.0", "numcodecs-wasm-sperr~=0.2.0", "numcodecs-wasm-stochastic-rounding~=0.2.0", - "numcodecs-wasm-sz3~=0.8.1", + "numcodecs-wasm-sz3~=0.7.0", "numcodecs-wasm-tthresh~=0.3.0", "numcodecs-wasm-zfp~=0.6.0", "numcodecs-wasm-zfp-classic~=0.4.0", From 3772370faa0b0a2be64e5f2297c3b43b01e589d0 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Mon, 2 Feb 2026 11:03:42 +0200 Subject: [PATCH 11/26] Add safeguarded BitRound+PCO --- .../compressor/compressors/__init__.py | 2 + .../compressors/safeguarded/__init__.py | 2 + .../compressors/safeguarded/bitround_pco.py | 66 +++++++++++++++++++ .../compressors/safeguarded/sperr.py | 2 + 4 files changed, 72 insertions(+) create mode 100644 src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py diff --git a/src/climatebenchpress/compressor/compressors/__init__.py b/src/climatebenchpress/compressor/compressors/__init__.py index 46638b3..63da691 100644 --- a/src/climatebenchpress/compressor/compressors/__init__.py +++ b/src/climatebenchpress/compressor/compressors/__init__.py @@ -2,6 +2,7 @@ "BitRound", "BitRoundPco", "Jpeg2000", + "SafeguardedBitRoundPco", "SafeguardedSperr", "SafeguardedSz3", "SafeguardedZero", @@ -21,6 +22,7 @@ from .bitround_pco import BitRoundPco from .jpeg2000 import Jpeg2000 from .safeguarded import ( + SafeguardedBitRoundPco, SafeguardedSperr, SafeguardedSz3, SafeguardedZero, diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py index 7660f38..2fb5669 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/__init__.py @@ -1,4 +1,5 @@ __all__ = [ + "SafeguardedBitRoundPco", "SafeguardedSperr", "SafeguardedSz3", "SafeguardedZero", @@ -6,6 +7,7 @@ "SafeguardedZfpRound", ] +from .bitround_pco import SafeguardedBitRoundPco from .sperr import SafeguardedSperr from .sz3 import SafeguardedSz3 from .zero import SafeguardedZero diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py b/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py new file mode 100644 index 0000000..b175fff --- /dev/null +++ b/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py @@ -0,0 +1,66 @@ +__all__ = ["SafeguardedBitRoundPco"] + + +import numcodecs_safeguards +import 
numcodecs_wasm_bit_round +import numcodecs_wasm_pco + +from ..abc import Compressor +from ..utils import compute_keepbits + + +class SafeguardedBitRoundPco(Compressor): + """Safeguarded Bit Rounding + PCodec compressor. + + This compressor first applies bit rounding to the data, which reduces the precision of the data + while preserving its overall structure. After that, it uses PCodec for further compression. + """ + + name = "safeguarded-bitround-pco" + description = "Safeguarded(Bit Rounding + PCodec)" + + @staticmethod + def abs_bound_codec(error_bound, *, dtype=None, data_abs_max=None, **kwargs): + assert dtype is not None, "dtype must be provided" + assert data_abs_max is not None, "data_abs_max must be provided" + + # conservative abs->rel error bound transformation, + # same as convert_abs_error_to_rel_error + # so that we can inform the safeguards of the abs bound + keepbits = compute_keepbits(dtype, error_bound / data_abs_max) + + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), + lossless=numcodecs_safeguards.lossless.Lossless( + for_codec=numcodecs_wasm_pco.Pco( + level=8, + mode="auto", + delta="auto", + paging="equal-pages-up-to", + ) + ), + safeguards=[ + dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), + ], + ) + + @staticmethod + def rel_bound_codec(error_bound, *, dtype=None, **kwargs): + assert dtype is not None, "dtype must be provided" + + keepbits = compute_keepbits(dtype, error_bound) + + return numcodecs_safeguards.SafeguardsCodec( + codec=numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), + lossless=numcodecs_safeguards.lossless.Lossless( + for_codec=numcodecs_wasm_pco.Pco( + level=8, + mode="auto", + delta="auto", + paging="equal-pages-up-to", + ) + ), + safeguards=[ + dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), + ], + ) diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py index 58b0a1c..7498f8c 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py @@ -1,6 +1,8 @@ __all__ = ["SafeguardedSperr"] import numcodecs +import numcodecs.abc +import numcodecs.compat import numcodecs_safeguards import numcodecs_wasm_sperr import numpy as np From a9b373e07ec79f42f1c763bf9b4fc5146e20eb0e Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Mon, 2 Feb 2026 11:21:59 +0200 Subject: [PATCH 12/26] Upgrade to numcodecs-safeguards==0.1.0b2 --- pyproject.toml | 2 +- .../compressor/compressors/safeguarded/bitround_pco.py | 4 ++-- .../compressor/compressors/safeguarded/sperr.py | 4 ++-- .../compressor/compressors/safeguarded/sz3.py | 4 ++-- .../compressor/compressors/safeguarded/zero.py | 4 ++-- .../compressor/compressors/safeguarded/zero_dssim.py | 4 ++-- .../compressor/compressors/safeguarded/zfp_round.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3e4f33d..ff7f490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "numcodecs>=0.13.0,<0.17", "numcodecs-combinators[xarray]~=0.2.13", "numcodecs-observers~=0.1.2", - "numcodecs-safeguards==0.1.0b1", + "numcodecs-safeguards==0.1.0b2", "numcodecs-wasm~=0.2.2", "numcodecs-wasm-bit-round~=0.4.0", "numcodecs-wasm-fixed-offset-scale~=0.4.0", diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py 
b/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py index b175fff..36cd540 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/bitround_pco.py @@ -29,7 +29,7 @@ def abs_bound_codec(error_bound, *, dtype=None, data_abs_max=None, **kwargs): # so that we can inform the safeguards of the abs bound keepbits = compute_keepbits(dtype, error_bound / data_abs_max) - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), lossless=numcodecs_safeguards.lossless.Lossless( for_codec=numcodecs_wasm_pco.Pco( @@ -50,7 +50,7 @@ def rel_bound_codec(error_bound, *, dtype=None, **kwargs): keepbits = compute_keepbits(dtype, error_bound) - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_wasm_bit_round.BitRound(keepbits=keepbits), lossless=numcodecs_safeguards.lossless.Lossless( for_codec=numcodecs_wasm_pco.Pco( diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py index 7498f8c..5bb2373 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/sperr.py @@ -19,7 +19,7 @@ class SafeguardedSperr(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=CodecStack( NaNToMean(), numcodecs_wasm_sperr.Sperr(mode="pwe", pwe=error_bound), @@ -33,7 +33,7 @@ def abs_bound_codec(error_bound, **kwargs): def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): assert data_abs_min is not None, "data_abs_min must be provided" - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=CodecStack( NaNToMean(), # conservative rel->abs error bound transformation, diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py b/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py index 8d89e03..609f1de 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/sz3.py @@ -14,7 +14,7 @@ class SafeguardedSz3(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_wasm_sz3.Sz3(eb_mode="abs", eb_abs=error_bound), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), @@ -23,7 +23,7 @@ def abs_bound_codec(error_bound, **kwargs): @staticmethod def rel_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_wasm_sz3.Sz3(eb_mode="rel", eb_rel=error_bound), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/zero.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero.py index c15919a..99e1790 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/zero.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero.py @@ -14,7 +14,7 @@ class SafeguardedZero(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + 
return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), @@ -23,7 +23,7 @@ def abs_bound_codec(error_bound, **kwargs): @staticmethod def rel_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py index cf1fefc..0e9b295 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py @@ -17,7 +17,7 @@ def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): assert data_min is not None, "data_min must be provided" assert data_max is not None, "data_max must be provided" - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), @@ -55,7 +55,7 @@ def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): assert data_min is not None, "data_min must be provided" assert data_max is not None, "data_max must be provided" - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py b/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py index 89f1641..d98f3d9 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zfp_round.py @@ -26,7 +26,7 @@ class SafeguardedZfpRound(Compressor): @staticmethod def abs_bound_codec(error_bound, **kwargs): - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_wasm_zfp.Zfp( mode="fixed-accuracy", tolerance=error_bound, non_finite="allow-unsafe" ), @@ -39,7 +39,7 @@ def abs_bound_codec(error_bound, **kwargs): def rel_bound_codec(error_bound, *, data_abs_min=None, **kwargs): assert data_abs_min is not None, "data_abs_min must be provided" - return numcodecs_safeguards.SafeguardsCodec( + return numcodecs_safeguards.SafeguardedCodec( # conservative rel->abs error bound transformation, # same as convert_rel_error_to_abs_error # so that we can inform the safeguards of the rel bound From 35b25919ac6d9515414ad50ffa1117a74e68e2b2 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Wed, 4 Feb 2026 12:01:19 +0200 Subject: [PATCH 13/26] fix Safeguarded(0, dSSIM) for multiple 2D slices --- .../compressor/compressors/abc.py | 16 +++++++++ .../compressors/safeguarded/zero_dssim.py | 22 ++++++------ .../compressor/plotting/plot_metrics.py | 5 +-- .../compressor/scripts/compress.py | 34 +++++++++++++++++-- 4 files changed, 63 insertions(+), 14 deletions(-) diff --git a/src/climatebenchpress/compressor/compressors/abc.py b/src/climatebenchpress/compressor/compressors/abc.py index b72eeb9..e20429f 100644 --- a/src/climatebenchpress/compressor/compressors/abc.py +++ b/src/climatebenchpress/compressor/compressors/abc.py @@ -90,6 +90,8 @@ def abs_bound_codec( data_max: Optional[float] = None, data_abs_min: 
Optional[float] = None, data_abs_max: Optional[float] = None, + data_min_2d: Optional[np.ndarray] = None, + data_max_2d: Optional[np.ndarray] = None, ) -> Codec: """Create a codec with an absolute error bound.""" pass @@ -104,6 +106,8 @@ def rel_bound_codec( data_max: Optional[float] = None, data_abs_min: Optional[float] = None, data_abs_max: Optional[float] = None, + data_min_2d: Optional[np.ndarray] = None, + data_max_2d: Optional[np.ndarray] = None, ) -> Codec: """Create a codec with a relative error bound.""" pass @@ -116,6 +120,8 @@ def build( data_abs_max: dict[VariableName, float], data_min: dict[VariableName, float], data_max: dict[VariableName, float], + data_min_2d: dict[VariableName, np.ndarray], + data_max_2d: dict[VariableName, np.ndarray], error_bounds: list[dict[VariableName, ErrorBound]], ) -> dict[VariantName, list[NamedPerVariableCodec]]: """ @@ -139,6 +145,12 @@ def build( Dict mapping from variable name to minimum value for the variable. data_max : dict[VariableName, float] Dict mapping from variable name to maximum value for the variable. + data_min_2d : dict[VariableName, np.ndarray] + Dict mapping from variable name to per-lat-lon-slice minimum value for the + variable. + data_max_2d : dict[VariableName, np.ndarray] + Dict mapping from variable name to per-lat-lon-slice maximum value for the + variable. error_bounds: list[ErrorBound] List of error bounds to use for the compressor. @@ -173,6 +185,8 @@ def build( data_max=data_max[var], data_abs_min=data_abs_min[var], data_abs_max=data_abs_max[var], + data_min_2d=data_min_2d[var], + data_max_2d=data_max_2d[var], ) elif eb.rel_error is not None and cls.has_rel_error_impl: new_codecs[var] = partial( @@ -183,6 +197,8 @@ def build( data_max=data_max[var], data_abs_min=data_abs_min[var], data_abs_max=data_abs_max[var], + data_min_2d=data_min_2d[var], + data_max_2d=data_max_2d[var], ) else: # This should never happen as we have already transformed the error bounds. 
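The new data_min_2d/data_max_2d inputs are reductions over the horizontal
dimensions only, taken with keepdims=True so that they still broadcast against
the full field; the compress.py change later in this patch computes them that
way. A minimal xarray sketch of the expected shape, with purely illustrative
dimension names:

    import numpy as np
    import xarray as xr

    # toy (time, lat, lon) field standing in for one dataset variable
    da = xr.DataArray(
        np.random.default_rng(0).random((4, 3, 5)), dims=("time", "lat", "lon")
    )

    # reduce over the horizontal axes but keep them, so the result has
    # shape (4, 1, 1) and broadcasts against the original array
    data_min_2d = da.min(dim=("lat", "lon"), keepdims=True).values
    data_max_2d = da.max(dim=("lat", "lon"), keepdims=True).values

    assert data_min_2d.shape == (4, 1, 1)
    assert np.all((da.values >= data_min_2d) & (da.values <= data_max_2d))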
diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py index 0e9b295..6f300ad 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py @@ -13,9 +13,9 @@ class SafeguardedZeroDssim(Compressor): description = "Safeguarded(0, dSSIM)" @staticmethod - def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): - assert data_min is not None, "data_min must be provided" - assert data_max is not None, "data_max must be provided" + def abs_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs): + assert data_min_2d is not None, "data_min_2d must be provided" + assert data_max_2d is not None, "data_max_2d must be provided" return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), @@ -46,14 +46,15 @@ def abs_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): eb=0, ), ], - # use data_min instead of $x_min to allow for chunking - fixed_constants=dict(x_min=data_min, x_max=data_max), + # use data_min_2d instead of $x_min since we need the minimum per + # 2d latitude-longitude slice + fixed_constants=dict(x_min=data_min_2d, x_max=data_max_2d), ) @staticmethod - def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): - assert data_min is not None, "data_min must be provided" - assert data_max is not None, "data_max must be provided" + def rel_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs): + assert data_min_2d is not None, "data_min_2d must be provided" + assert data_max_2d is not None, "data_max_2d must be provided" return numcodecs_safeguards.SafeguardedCodec( codec=numcodecs_zero.ZeroCodec(), @@ -84,6 +85,7 @@ def rel_bound_codec(error_bound, data_min=None, data_max=None, **kwargs): eb=0, ), ], - # use data_min instead of $x_min to allow for chunking - fixed_constants=dict(x_min=data_min, x_max=data_max), + # use data_min_2d instead of $x_min since we need the minimum per + # 2d latitude-longitude slice + fixed_constants=dict(x_min=data_min_2d, x_max=data_max_2d), ) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index b00d92c..408b098 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -207,8 +207,9 @@ def _normalize(data): # Normalize each variable by its mean and std normalized[new_col] = normalized.apply( - lambda x: (x[col] - mean_std[x["Variable"]][0]) - / mean_std[x["Variable"]][1], + lambda x: ( + (x[col] - mean_std[x["Variable"]][0]) / mean_std[x["Variable"]][1] + ), axis=1, ) diff --git a/src/climatebenchpress/compressor/scripts/compress.py b/src/climatebenchpress/compressor/scripts/compress.py index 96fa51f..df7d40f 100644 --- a/src/climatebenchpress/compressor/scripts/compress.py +++ b/src/climatebenchpress/compressor/scripts/compress.py @@ -87,6 +87,8 @@ def compress( ds_abs_maxs: dict[str, float] = dict() ds_mins: dict[str, float] = dict() ds_maxs: dict[str, float] = dict() + ds_min_2ds: dict[str, np.ndarray] = dict() + ds_max_2ds: dict[str, np.ndarray] = dict() for v in ds: vs: str = str(v) abs_vals = xr.ufuncs.abs(ds[v]) @@ -96,6 +98,16 @@ def compress( ds_abs_maxs[vs] = abs_vals.max().values.item() ds_mins[vs] = ds[v].min().values.item() ds_maxs[vs] = ds[v].max().values.item() + ds_min_2ds[vs] = ( + ds[v] + 
.min(dim=[ds[v].cf["Y"].name, ds[v].cf["X"].name], keepdims=True) + .values + ) + ds_max_2ds[vs] = ( + ds[v] + .max(dim=[ds[v].cf["Y"].name, ds[v].cf["X"].name], keepdims=True) + .values + ) if chunked: for v in ds: @@ -115,7 +127,14 @@ def compress( compressor_variants: dict[str, list[NamedPerVariableCodec]] = ( compressor.build( - ds_dtypes, ds_abs_mins, ds_abs_maxs, ds_mins, ds_maxs, error_bounds + ds_dtypes, + ds_abs_mins, + ds_abs_maxs, + ds_mins, + ds_maxs, + ds_min_2ds, + ds_max_2ds, + error_bounds, ) ) @@ -189,6 +208,15 @@ def compress_decompress( if not isinstance(codec, CodecStack): codec = CodecStack(codec) + # HACK: Safeguarded(0, dSSIM) requires the per-lat-lon-slice minimum + # and maximum + # for potentially-chunked data we should really use xarray-safeguards, + # but not using chunks also works (for now) + is_safeguarded_zero_dssim = ( + "# === pointwise dSSIM quantity of interest === #" + in json.dumps(codec.get_config()) + ) + with numcodecs_observers.observe( codec, observers=[ @@ -197,7 +225,9 @@ def compress_decompress( timing, ], ) as codec_: - variables[v] = codec_.encode_decode_data_array(ds[v]).compute() + variables[v] = codec_.encode_decode_data_array( + ds[v].compute() if is_safeguarded_zero_dssim else ds[v] + ).compute() cs = [c._codec for c in codec_.__iter__()] From de79c732ccf324ae6413059c73c99926585f7d3b Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Wed, 4 Feb 2026 12:04:33 +0200 Subject: [PATCH 14/26] small cleanup --- .../compressor/compressors/safeguarded/zero_dssim.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py index 6f300ad..ea483ba 100644 --- a/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py +++ b/src/climatebenchpress/compressor/compressors/safeguarded/zero_dssim.py @@ -21,13 +21,15 @@ def abs_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs): codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="abs", eb=error_bound, equal_nan=True), - # guarantee that the global minimum and maximum are preserved, - # which simplifies the rescaling + # guarantee that the per-latitude-longitude-slice minimum and + # maximum are preserved, which simplifies the rescaling dict(kind="sign", offset="x_min"), dict(kind="sign", offset="x_max"), dict( kind="qoi_eb_pw", qoi=""" + # === pointwise dSSIM quantity of interest === # + # we guarantee that # min(data) = min(corrected) and # max(data) = max(corrected) @@ -60,13 +62,15 @@ def rel_bound_codec(error_bound, data_min_2d=None, data_max_2d=None, **kwargs): codec=numcodecs_zero.ZeroCodec(), safeguards=[ dict(kind="eb", type="rel", eb=error_bound, equal_nan=True), - # guarantee that the global minimum and maximum are preserved, - # which simplifies the rescaling + # guarantee that the per-latitude-longitude-slice minimum and + # maximum are preserved, which simplifies the rescaling dict(kind="sign", offset="x_min"), dict(kind="sign", offset="x_max"), dict( kind="qoi_eb_pw", qoi=""" + # === pointwise dSSIM quantity of interest === # + # we guarantee that # min(data) = min(corrected) and # max(data) = max(corrected) From 102d5127a1ea163cdb36b2557b925507ee0118dc Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 13:28:05 +0200 Subject: [PATCH 15/26] Plotting hacks --- .../compressor/plotting/plot_metrics.py | 118 ++++++++++++------ 1 file changed, 79 insertions(+), 39 
deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 408b098..0385907 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -13,15 +13,21 @@ _COMPRESSOR2LINEINFO = [ ("jpeg2000", ("#EE7733", "-")), - ("sperr", ("#117733", ":")), - ("zfp-round", ("#DDAA33", "--")), + ("sperr", ("#117733", "-")), + ("zfp-round", ("#DDAA33", "-")), ("zfp", ("#EE3377", "--")), - ("sz3", ("#CC3311", "-.")), - ("bitround-pco", ("#0077BB", ":")), + ("sz3", ("#CC3311", "-")), + ("bitround-pco", ("#0077BB", "-")), ("bitround", ("#33BBEE", "-")), ("stochround-pco", ("#BBBBBB", "--")), ("stochround", ("#009988", "--")), ("tthresh", ("#882255", "-.")), + ("safeguarded-sperr", ("#117733", ":")), + ("safeguarded-zfp-round", ("#DDAA33", ":")), + ("safeguarded-sz3", ("#CC3311", ":")), + ("safeguarded-zero-dssim", ("#9467BD", "--")), + ("safeguarded-zero", ("#9467BD", ":")), + ("safeguarded-bitround-pco", ("#0077BB", ":")), ] @@ -44,6 +50,25 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), + ("safeguarded-sperr", "Safeguarded(SPERR)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP-ROUND)"), + ("safeguarded-sz3", "Safeguarded(SZ3)"), + ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), + ("safeguarded-zero", "Safeguarded(0)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound + PCO)"), +] + +_COMPRESSOR_ORDER = [ + "BitRound + PCO", + "Safeguarded(BitRound + PCO)", + "ZFP-ROUND", + "Safeguarded(ZFP-ROUND)", + "SZ3", + "Safeguarded(SZ3)", + "SPERR", + "Safeguarded(SPERR)", + "Safeguarded(0)", + "Safeguarded(0, dSSIM)", ] DISTORTION2LEGEND_NAME = { @@ -102,6 +127,7 @@ def plot_metrics( df = pd.read_csv(metrics_path / "all_results.csv") # Filter out excluded datasets and compressors + # bitround jpeg2000-conservative-abs stochround-conservative-abs stochround-pco-conservative-abs zfp-conservative-abs bitround-conservative-rel stochround-pco stochround zfp jpeg2000 df = df[~df["Compressor"].isin(exclude_compressor)] df = df[~df["Dataset"].isin(exclude_dataset)] is_tiny = df["Dataset"].str.endswith("-tiny") @@ -111,13 +137,13 @@ def plot_metrics( filter_chunked = is_chunked if chunked_datasets else ~is_chunked df = df[filter_chunked] - _plot_per_variable_metrics( - datasets=datasets, - compressed_datasets=compressed_datasets, - plots_path=plots_path, - all_results=df, - rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], - ) + # _plot_per_variable_metrics( + # datasets=datasets, + # compressed_datasets=compressed_datasets, + # plots_path=plots_path, + # all_results=df, + # rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], + # ) df = _rename_compressors(df) normalized_df = _normalize(df) @@ -419,7 +445,10 @@ def _plot_aggregated_rd_curve( # Exclude variables that are not relevant for the distortion metric. 
normalized_df = normalized_df[~normalized_df["Variable"].isin(exclude_vars)] - compressors = normalized_df["Compressor"].unique() + compressors = sorted( + normalized_df["Compressor"].unique(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) @@ -503,8 +532,8 @@ def _plot_aggregated_rd_curve( ) plt.legend( title="Compressor", - loc="upper right", - bbox_to_anchor=(0.8, 0.99), + loc="upper left", + # bbox_to_anchor=(0.8, 0.99), fontsize=12, title_fontsize=14, ) @@ -614,27 +643,32 @@ def _plot_instruction_count(df, outfile: None | Path = None): def _get_median_and_quantiles(df, encode_column, decode_column): - return df.groupby(["Compressor", "Error Bound Name"])[ - [encode_column, decode_column] - ].agg( - encode_median=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.5) - ), - encode_lower_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.25) - ), - encode_upper_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.75) - ), - decode_median=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.5) - ), - decode_lower_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.25) - ), - decode_upper_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.75) - ), + return ( + df.groupby(["Compressor", "Error Bound Name"])[[encode_column, decode_column]] + .agg( + encode_median=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.5) + ), + encode_lower_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.25) + ), + encode_upper_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.75) + ), + decode_median=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.5) + ), + decode_lower_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.25) + ), + decode_upper_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.75) + ), + ) + .sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) ) @@ -645,7 +679,10 @@ def _plot_grouped_df( # Bar width bar_width = 0.35 - compressors = grouped_df.index.levels[0].tolist() + compressors = sorted( + grouped_df.index.levels[0].tolist(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) x_labels = [_get_legend_name(c) for c in compressors] x_positions = range(len(x_labels)) @@ -653,7 +690,10 @@ def _plot_grouped_df( for i, error_bound in enumerate(error_bounds): ax = axes[i] - bound_data = grouped_df.xs(error_bound, level="Error Bound Name") + bound_data = grouped_df.xs(error_bound, level="Error Bound Name").sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) # Plot encode throughput ax.bar( @@ -720,11 +760,11 @@ def _plot_bound_violations(df, bound_names, outfile: None | Path = None): df_bound["Compressor"] = df_bound["Compressor"].map(_get_legend_name) pass_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) pass_fail = pass_fail.astype(np.float32) fraction_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Value)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) annotations = 
fraction_fail.map( lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" ) From c5e0d2d458222abac1a76ccbee7e4e9c13f85d90 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 15:02:05 +0200 Subject: [PATCH 16/26] some improvements --- .../compressor/plotting/plot_metrics.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 0385907..fbf4990 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -42,27 +42,27 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: _COMPRESSOR2LEGEND_NAME = [ ("jpeg2000", "JPEG2000"), ("sperr", "SPERR"), - ("zfp-round", "ZFP-ROUND"), + ("zfp-round", "ZFP"), ("zfp", "ZFP"), ("sz3", "SZ3"), - ("bitround-pco", "BitRound + PCO"), + ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), ("safeguarded-sperr", "Safeguarded(SPERR)"), - ("safeguarded-zfp-round", "Safeguarded(ZFP-ROUND)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP)"), ("safeguarded-sz3", "Safeguarded(SZ3)"), ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), ("safeguarded-zero", "Safeguarded(0)"), - ("safeguarded-bitround-pco", "Safeguarded(BitRound + PCO)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), ] _COMPRESSOR_ORDER = [ - "BitRound + PCO", - "Safeguarded(BitRound + PCO)", - "ZFP-ROUND", - "Safeguarded(ZFP-ROUND)", + "BitRound", + "Safeguarded(BitRound)", + "ZFP", + "Safeguarded(ZFP)", "SZ3", "Safeguarded(SZ3)", "SPERR", @@ -476,7 +476,13 @@ def _plot_aggregated_rd_curve( if remove_outliers: # SZ3 and JPEG2000 often give outlier values and violate the bounds. 
- exclude_compressors = ["sz3", "jpeg2000"] + exclude_compressors = [ + "sz3", + "jpeg2000", + "safeguarded-zero-dssim", + "safeguarded-zero", + "safeguarded-sz3", + ] filtered_agg = agg_distortion[ ~agg_distortion.index.get_level_values("Compressor").isin( exclude_compressors From adc83a309c5c9ec8a188c3b1340102986b391708 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 15:55:44 +0200 Subject: [PATCH 17/26] Adjust SZ3 name --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index fbf4990..2e87acf 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -44,7 +44,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("sperr", "SPERR"), ("zfp-round", "ZFP"), ("zfp", "ZFP"), - ("sz3", "SZ3"), + ("sz3", "SZ3[v3.2]"), ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), @@ -52,7 +52,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("tthresh", "TTHRESH"), ("safeguarded-sperr", "Safeguarded(SPERR)"), ("safeguarded-zfp-round", "Safeguarded(ZFP)"), - ("safeguarded-sz3", "Safeguarded(SZ3)"), + ("safeguarded-sz3", "Safeguarded(SZ3[v3.2])"), ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), ("safeguarded-zero", "Safeguarded(0)"), ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), @@ -63,8 +63,8 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: "Safeguarded(BitRound)", "ZFP", "Safeguarded(ZFP)", - "SZ3", - "Safeguarded(SZ3)", + "SZ3[v3.2]", + "Safeguarded(SZ3[v3.2])", "SPERR", "Safeguarded(SPERR)", "Safeguarded(0)", From fdc1d15309ea8447b33b7028b9c5e72fe735a202 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Wed, 4 Feb 2026 15:40:59 +0200 Subject: [PATCH 18/26] Draft safeguards scorecards --- scorecards.ipynb | 722 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 722 insertions(+) create mode 100644 scorecards.ipynb diff --git a/scorecards.ipynb b/scorecards.ipynb new file mode 100644 index 0000000..64bf6ee --- /dev/null +++ b/scorecards.ipynb @@ -0,0 +1,722 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e7ab252b", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "import matplotlib.patches as mpatches\n", + "from matplotlib.lines import Line2D\n", + "\n", + "from pathlib import Path\n", + "from climatebenchpress.compressor.plotting.plot_metrics import (\n", + " _rename_compressors, \n", + " _get_legend_name,\n", + " _normalize,\n", + " _get_lineinfo,\n", + " DISTORTION2LEGEND_NAME,\n", + " _COMPRESSOR_ORDER,\n", + " _savefig\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4ee15a6b", + "metadata": {}, + "source": [ + "# Process results" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2d242e7c", + "metadata": {}, + "outputs": [], + "source": [ + "results_file = \"metrics/all_results.csv\"\n", + "df = pd.read_csv(results_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ec124c29", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_matrix(\n", + " df: pd.DataFrame, \n", + " error_bound: str, \n", 
+ " metrics: list[str] = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "):\n", + " df_filtered = df[df['Error Bound Name'] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = df_filtered[\"Satisfies Bound (Value)\"] * 100 # Convert to percentage\n", + "\n", + " # Get unique variables and compressors\n", + " # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: \"/\".join(x), axis=1).unique())\n", + " dataset_variables = sorted(df_filtered['Variable'].unique())\n", + " compressors = sorted(df_filtered['Compressor'].unique(), key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)))\n", + "\n", + " column_labels = []\n", + " for metric in metrics:\n", + " for dataset_variable in dataset_variables:\n", + " column_labels.append(f\"{dataset_variable}\\n{metric}\")\n", + "\n", + " # Initialize the data matrix\n", + " data_matrix = np.full((len(compressors), len(column_labels)), np.nan)\n", + "\n", + " # Fill the matrix with data\n", + " for i, compressor in enumerate(compressors):\n", + " for j, metric in enumerate(metrics):\n", + " for k, dataset_variable in enumerate(dataset_variables):\n", + " # Get data for this compressor-variable combination\n", + " # dataset, variable = dataset_variable.split('/')\n", + " variable = dataset_variable\n", + " subset = df_filtered[\n", + " (df_filtered['Compressor'] == compressor) & \n", + " (df_filtered['Variable'] == variable) #&\n", + " # (df_filtered['Dataset'] == dataset)\n", + " ]\n", + " if subset.empty:\n", + " print(f\"No data for Compressor: {compressor}, Variable: {variable}\")\n", + " continue\n", + "\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variable in [\"ta\", \"tos\"]:\n", + " # These variables have large regions of NaN values which makes the \n", + " # DSSIM and Spectral Error values unreliable.\n", + " continue\n", + "\n", + "\n", + " col_idx = j * len(dataset_variables) + k\n", + " if metric in subset.columns:\n", + " values = subset[metric]\n", + " if len(values) == 1:\n", + " data_matrix[i, col_idx] = values.iloc[0]\n", + " \n", + " return data_matrix, compressors, dataset_variables" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2a399d30", + "metadata": {}, + "outputs": [], + "source": [ + "df = df[~df[\"Compressor\"].isin([\n", + " \"bitround\", \"jpeg2000-conservative-abs\", \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\", \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\", \"stochround-pco\", \"stochround\", \"zfp\", \"jpeg2000\",\n", + "])]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", + "df = _rename_compressors(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2d019b8d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, 
Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", 
+ "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n" + ] + } + ], + "source": [ + "metrics = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "scorecard_data = {}\n", + "for error_bound in [\"low\", \"mid\", \"high\"]:\n", + " scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)" + ] + }, + { + "cell_type": "markdown", + "id": "ae80d757", + "metadata": {}, + "source": [ + "# Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "871ae766", + "metadata": {}, + "outputs": [], + "source": [ + "METRICS2NAME = {\n", + " # \"Max Absolute Error\": \"MaxAE\",\n", + " \"MAE\": \"Mean Absolute Error\",\n", + " \"Spatial Relative Error (Value)\": \"SRE\",\n", + " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", + " \"Satisfies Bound (Value)\": r\"% of Pixels Exceeding Error Bound\",\n", + "}\n", + "\n", + "VARIABLE2NAME = {\n", + " \"10m_u_component_of_wind\": \"10u\",\n", + " \"10m_v_component_of_wind\": \"10v\",\n", + " \"mean_sea_level_pressure\": \"msl\",\n", + "}\n", + "\n", + "DATASET2PREFIX = {\n", + " \"era5-hurricane\": \"h-\",\n", + "}\n", + "\n", + "def get_variable_label(variable):\n", + " dataset, var_name = variable.split('/')\n", + " prefix = DATASET2PREFIX.get(dataset, \"\")\n", + " var_name = VARIABLE2NAME.get(var_name, var_name)\n", + " return f\"{prefix}{var_name}\"\n", + "\n", + "\n", + "def create_compression_scorecard(\n", + " data_matrix, \n", + " compressors, \n", + " variables, \n", + " metrics, \n", + " cbar=True,\n", + " ref_compressor='sz3', \n", + " higher_better_metrics=[\"DSSIM\", \"Compression Ratio [raw B / enc B]\"],\n", + " save_fn=None,\n", + " compare_against_0=False,\n", + " highlight_bigger_than_one=False\n", + "):\n", + " \"\"\"\n", + " Create a scorecard plot similar to the weather forecasting example\n", + " \n", + " Parameters:\n", + " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", + " - compressors: list of compressor names\n", + " - variables: list of variable names \n", + " - metrics: list of metric names\n", + " - ref_compressor: reference compressor for relative calculations\n", + " - save_fn: filename to save plot (optional)\n", + " \"\"\"\n", + " \n", + " # Calculate relative differences vs reference compressor\n", + " ref_idx = compressors.index(ref_compressor)\n", + " ref_values = data_matrix[ref_idx, :]\n", + " if compare_against_0:\n", + " ref_values = np.zeros_like(data_matrix[ref_idx, :])\n", + " \n", + " relative_matrix = np.full_like(data_matrix, 
np.nan)\n", + " if highlight_bigger_than_one:\n", + " relative_matrix = np.sign(data_matrix) * 101\n", + " for j in range(data_matrix.shape[1]):\n", + " if metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", + " # For bound satisfication lower is better (less number of pixels exceeding error bound).\n", + " relative_matrix[:, j] = -1 * relative_matrix[:, j]\n", + " else:\n", + " for i in range(len(compressors)):\n", + " for j in range(data_matrix.shape[1]):\n", + " if not np.isnan(data_matrix[i, j]) and not np.isnan(ref_values[j]):\n", + " ref_val = np.abs(ref_values[j])\n", + " if ref_val == 0.0:\n", + " ref_val = 1e-10 # Avoid division by zero\n", + " if metrics[j // len(variables)] in higher_better_metrics:\n", + " # Higher is better metrics\n", + " relative_matrix[i, j] = (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " elif metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", + " relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n", + " else:\n", + " relative_matrix[i, j] = (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + "\n", + " # Set up colormap - similar to original\n", + " reds = sns.color_palette('Reds', 6)\n", + " blues = sns.color_palette('Blues_r', 6)\n", + " cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n", + " # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]\n", + " # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]\n", + " cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n", + "\n", + " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend='both')\n", + " \n", + " # Calculate figure dimensions\n", + " ncompressors = len(compressors)\n", + " nvariables = len(variables)\n", + " nmetrics = len(metrics)\n", + " \n", + " panel_width = (2.5 / 5) * nvariables\n", + " label_width = 1.5 * panel_width\n", + " padding_right = 0.1\n", + " panel_height = panel_width / nvariables\n", + " \n", + " title_height = panel_height * 1.25\n", + " cbar_height = panel_height * 2\n", + " spacing_height = panel_height * 0.1\n", + " spacing_width = panel_height * 0.2\n", + " \n", + " total_width = label_width + nmetrics * panel_width + (nmetrics - 1) * spacing_width + padding_right\n", + " total_height = title_height + cbar_height + ncompressors * panel_height + (ncompressors - 1) * spacing_height\n", + " \n", + " # Create figure and gridspec\n", + " fig = plt.figure(figsize=(total_width, total_height))\n", + " gs = mpl.gridspec.GridSpec(\n", + " ncompressors, nmetrics,\n", + " figure=fig,\n", + " left=label_width / total_width,\n", + " right=1 - padding_right / total_width,\n", + " top=1 - (title_height / total_height),\n", + " bottom=cbar_height / total_height,\n", + " hspace=spacing_height / panel_height,\n", + " wspace=spacing_width / panel_width\n", + " )\n", + " \n", + " # Plot each panel\n", + " for row, compressor in enumerate(compressors):\n", + " for col, metric in enumerate(metrics):\n", + " ax = fig.add_subplot(gs[row, col])\n", + "\n", + " # Get data for this metric (all variables)\n", + " start_col = col * nvariables\n", + " end_col = start_col + nvariables\n", + "\n", + " rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n", + " abs_values = data_matrix[row, start_col:end_col]\n", + " \n", + " # Create heatmap\n", + " img = ax.imshow(rel_values, aspect='auto', cmap=cmap, norm=norm)\n", + " \n", + " # Customize axes\n", + " ax.set_xticks([])\n", + " ax.set_xticklabels([])\n", + " ax.set_yticks([])\n", + " ax.set_yticklabels([])\n", + " \n", 
+ " # Add white grid lines\n", + " for i in range(nvariables):\n", + " rect = mpl.patches.Rectangle(\n", + " (i - 0.5, -0.5), 1, 1,\n", + " linewidth=1, edgecolor='white', facecolor='none'\n", + " )\n", + " ax.add_patch(rect)\n", + " \n", + " # Add absolute values as text\n", + " for i, val in enumerate(abs_values):\n", + " # Ensure we don't have black text on dark background\n", + " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", + " fontsize = 10\n", + " # Format numbers appropriately\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\"ta\", \"tos\"]:\n", + " # These variables have large regions of NaN values which makes the \n", + " # DSSIM and Spectral Error values unreliable.\n", + " text = \"N/A\"\n", + " color = \"black\"\n", + " elif np.isnan(val):\n", + " text = \"Fail\"\n", + " color = \"black\"\n", + " elif abs(val) > 10_000:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " elif abs(val) > 10:\n", + " text = f\"{val:.0f}\"\n", + " elif abs(val) > 1:\n", + " text = f\"{val:.1f}\"\n", + " elif val == 1 and metric == \"DSSIM\":\n", + " text = \"1\"\n", + " elif val == 0:\n", + " text = \"0\"\n", + " elif abs(val) < 0.01:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " else:\n", + " text = f\"{val:.2f}\"\n", + " ax.text(\n", + " i, \n", + " 0, \n", + " text, \n", + " ha='center', \n", + " va='center', \n", + " fontsize=fontsize, \n", + " color=color\n", + " )\n", + "\n", + " # Add row labels (compressor names)\n", + " if col == 0:\n", + " ax.set_ylabel(_get_legend_name(compressor), rotation=0, ha='right', va='center',\n", + " labelpad=10, fontsize=14)\n", + " \n", + " # Add column titles (variable names)\n", + " if row == 0:\n", + " # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)\n", + " ax.set_title(METRICS2NAME.get(metric, metric), fontsize=16, pad=10)\n", + "\n", + " # Add metric labels at the top on the top row\n", + " if row == 0:\n", + " # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " # ax.set_xticks(range(nmetrics))\n", + " # ax.set_xticklabels(\n", + " # [METRICS2NAME.get(m, m) for m in metrics], \n", + " # rotation=45, \n", + " # ha='left', fontsize=8)\n", + " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " ax.set_xticks(range(nvariables))\n", + " ax.set_xticklabels(\n", + " [VARIABLE2NAME.get(v, v) for v in variables],\n", + " rotation=45,\n", + " ha='left', fontsize=12)\n", + " \n", + " # Style spines\n", + " for spine in ax.spines.values():\n", + " spine.set_color('0.7')\n", + " \n", + " # Add colorbar\n", + " if cbar and not highlight_bigger_than_one:\n", + " rel_cbar_height = cbar_height / total_height\n", + " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", + " cb = fig.colorbar(img, cax=cax, orientation='horizontal')\n", + " cb.ax.set_xticks(cb_levels)\n", + " if highlight_bigger_than_one:\n", + " cb.ax.set_xlabel('Better ← |non-chunked - chunked| → Worse', fontsize=16)\n", + " else:\n", + " cb.ax.set_xlabel(f'Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse', fontsize=16)\n", + " \n", + " if highlight_bigger_than_one:\n", + " chunking_handles = [\n", + " Line2D([], [], marker=\"s\", color=cmap(101), linestyle=\"None\", markersize=10,\n", + " label=\"Not Chunked Better\"),\n", + " Line2D([], [], marker=\"s\", color=cmap(-101), linestyle=\"None\", markersize=10,\n", + " label=\"Chunked Better\"),\n", + " ]\n", + "\n", + " ax.legend(\n", + " 
handles=chunking_handles,\n", + " loc=\"upper left\",\n", + " ncol=2,\n", + " bbox_to_anchor=(-0.5, -0.05),\n", + " fontsize=16\n", + " )\n", + "\n", + " plt.tight_layout()\n", + " \n", + " if save_fn:\n", + " plt.savefig(save_fn, dpi=300, bbox_inches='tight')\n", + " plt.close()\n", + " else:\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "678c927b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for mid bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for high bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + } + ], + "source": [ + "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound_name} bound...\")\n", + " # Split into two rows for better readability.\n", + " create_compression_scorecard(\n", + " data_matrix[:, :3*len(variables)], \n", + " compressors, \n", + " variables, \n", + " metrics[:3],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 3*len(variables):], \n", + " compressors, \n", + " variables, \n", + " metrics[3:],\n", + " ref_compressor=\"bitround-pco\",\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3afb646e", + "metadata": {}, + "source": [ + "## Two-Column Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b6fe5f55", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for mid bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for high bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + } + ], + "source": [ + "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound_name} bound...\")\n", + " # Split into two rows for better readability.\n", + " num_vars = len(variables)\n", + " create_compression_scorecard(\n", + " data_matrix[:, :2*num_vars], \n", + " compressors, \n", + " variables, \n", + " metrics[:2],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 2*num_vars:4*num_vars], \n", + " compressors, \n", + " variables, \n", + " metrics[2:4],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 4*num_vars:], \n", + " compressors, \n", + " variables, \n", + " metrics[4:],\n", + " ref_compressor=\"bitround-pco\",\n", + " 
save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c546f6b0-bb7b-4646-83bc-3f502bcf6a9f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From cdc9bd72460f048bb9d009227918dcc9a7cacb53 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 5 Feb 2026 14:46:41 +0200 Subject: [PATCH 19/26] some more tweaks --- scorecards.ipynb | 330 ++++++++++-------- .../compressor/plotting/plot_metrics.py | 23 +- 2 files changed, 203 insertions(+), 150 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index 64bf6ee..8beff56 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -18,13 +18,13 @@ "\n", "from pathlib import Path\n", "from climatebenchpress.compressor.plotting.plot_metrics import (\n", - " _rename_compressors, \n", + " _rename_compressors,\n", " _get_legend_name,\n", " _normalize,\n", " _get_lineinfo,\n", " DISTORTION2LEGEND_NAME,\n", " _COMPRESSOR_ORDER,\n", - " _savefig\n", + " _savefig,\n", ")" ] }, @@ -55,17 +55,29 @@ "outputs": [], "source": [ "def create_data_matrix(\n", - " df: pd.DataFrame, \n", - " error_bound: str, \n", - " metrics: list[str] = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + " df: pd.DataFrame,\n", + " error_bound: str,\n", + " metrics: list[str] = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + " ],\n", "):\n", - " df_filtered = df[df['Error Bound Name'] == error_bound].copy()\n", - " df_filtered[\"Satisfies Bound (Value)\"] = df_filtered[\"Satisfies Bound (Value)\"] * 100 # Convert to percentage\n", + " df_filtered = df[df[\"Error Bound Name\"] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = (\n", + " df_filtered[\"Satisfies Bound (Value)\"] * 100\n", + " ) # Convert to percentage\n", "\n", " # Get unique variables and compressors\n", " # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: \"/\".join(x), axis=1).unique())\n", - " dataset_variables = sorted(df_filtered['Variable'].unique())\n", - " compressors = sorted(df_filtered['Compressor'].unique(), key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)))\n", + " dataset_variables = sorted(df_filtered[\"Variable\"].unique())\n", + " compressors = sorted(\n", + " df_filtered[\"Compressor\"].unique(),\n", + " key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)),\n", + " )\n", "\n", " column_labels = []\n", " for metric in metrics:\n", @@ -83,8 +95,8 @@ " # dataset, variable = dataset_variable.split('/')\n", " variable = dataset_variable\n", " subset = df_filtered[\n", - " (df_filtered['Compressor'] == compressor) & \n", - " (df_filtered['Variable'] == variable) #&\n", + " (df_filtered[\"Compressor\"] == compressor)\n", + " & (df_filtered[\"Variable\"] == variable) # &\n", " # (df_filtered['Dataset'] == dataset)\n", " ]\n", " if subset.empty:\n", @@ -92,17 +104,16 @@ " continue\n", 
"\n", " if metric in [\"DSSIM\", \"Spectral Error\"] and variable in [\"ta\", \"tos\"]:\n", - " # These variables have large regions of NaN values which makes the \n", + " # These variables have large regions of NaN values which makes the\n", " # DSSIM and Spectral Error values unreliable.\n", " continue\n", "\n", - "\n", " col_idx = j * len(dataset_variables) + k\n", " if metric in subset.columns:\n", " values = subset[metric]\n", " if len(values) == 1:\n", " data_matrix[i, col_idx] = values.iloc[0]\n", - " \n", + "\n", " return data_matrix, compressors, dataset_variables" ] }, @@ -113,11 +124,22 @@ "metadata": {}, "outputs": [], "source": [ - "df = df[~df[\"Compressor\"].isin([\n", - " \"bitround\", \"jpeg2000-conservative-abs\", \"stochround-conservative-abs\",\n", - " \"stochround-pco-conservative-abs\", \"zfp-conservative-abs\",\n", - " \"bitround-conservative-rel\", \"stochround-pco\", \"stochround\", \"zfp\", \"jpeg2000\",\n", - "])]\n", + "df = df[\n", + " ~df[\"Compressor\"].isin(\n", + " [\n", + " \"bitround\",\n", + " \"jpeg2000-conservative-abs\",\n", + " \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\",\n", + " \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\",\n", + " \"stochround-pco\",\n", + " \"stochround\",\n", + " \"zfp\",\n", + " \"jpeg2000\",\n", + " ]\n", + " )\n", + "]\n", "df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", "df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", "df = _rename_compressors(df)" @@ -152,17 +174,11 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: sperr, Variable: pr\n", "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", @@ -182,17 +198,11 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: sperr, Variable: pr\n", "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", 
@@ -212,22 +222,23 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", - "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n" + "No data for Compressor: safeguarded-sperr, Variable: pr\n" ] } ], "source": [ - "metrics = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "metrics = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + "]\n", "scorecard_data = {}\n", "for error_bound in [\"low\", \"mid\", \"high\"]:\n", " scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)" @@ -249,11 +260,10 @@ "outputs": [], "source": [ "METRICS2NAME = {\n", - " # \"Max Absolute Error\": \"MaxAE\",\n", + " \"DSSIM\": \"dSSIM\",\n", " \"MAE\": \"Mean Absolute Error\",\n", - " \"Spatial Relative Error (Value)\": \"SRE\",\n", " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", - " \"Satisfies Bound (Value)\": r\"% of Pixels Exceeding Error Bound\",\n", + " \"Satisfies Bound (Value)\": r\"% of Pixels Violating Error Bound\",\n", "}\n", "\n", "VARIABLE2NAME = {\n", @@ -266,43 +276,44 @@ " \"era5-hurricane\": \"h-\",\n", "}\n", "\n", + "\n", "def get_variable_label(variable):\n", - " dataset, var_name = variable.split('/')\n", + " dataset, var_name = variable.split(\"/\")\n", " prefix = DATASET2PREFIX.get(dataset, \"\")\n", " var_name = VARIABLE2NAME.get(var_name, var_name)\n", " return f\"{prefix}{var_name}\"\n", "\n", "\n", "def create_compression_scorecard(\n", - " data_matrix, \n", - " compressors, \n", - " variables, \n", - " metrics, \n", + " data_matrix,\n", + " compressors,\n", + " variables,\n", + " metrics,\n", " cbar=True,\n", - " ref_compressor='sz3', \n", + " ref_compressor=\"sz3\",\n", " higher_better_metrics=[\"DSSIM\", \"Compression Ratio [raw B / enc B]\"],\n", " save_fn=None,\n", " compare_against_0=False,\n", - " highlight_bigger_than_one=False\n", + " highlight_bigger_than_one=False,\n", "):\n", " \"\"\"\n", " Create a scorecard plot similar to the weather forecasting example\n", - " \n", + "\n", " Parameters:\n", " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", " - compressors: list of compressor names\n", - " - variables: list of variable names \n", + " - variables: list of variable names\n", " - metrics: list of metric names\n", " - ref_compressor: reference compressor for relative calculations\n", " - save_fn: filename to save plot (optional)\n", " \"\"\"\n", - " \n", + "\n", " # Calculate relative differences vs reference compressor\n", " ref_idx = compressors.index(ref_compressor)\n", " ref_values = data_matrix[ref_idx, :]\n", " if compare_against_0:\n", 
" ref_values = np.zeros_like(data_matrix[ref_idx, :])\n", - " \n", + "\n", " relative_matrix = np.full_like(data_matrix, np.nan)\n", " if highlight_bigger_than_one:\n", " relative_matrix = np.sign(data_matrix) * 101\n", @@ -319,53 +330,68 @@ " ref_val = 1e-10 # Avoid division by zero\n", " if metrics[j // len(variables)] in higher_better_metrics:\n", " # Higher is better metrics\n", - " relative_matrix[i, j] = (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " relative_matrix[i, j] = (\n", + " (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " )\n", " elif metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", " relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n", " else:\n", - " relative_matrix[i, j] = (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " relative_matrix[i, j] = (\n", + " (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " )\n", "\n", " # Set up colormap - similar to original\n", - " reds = sns.color_palette('Reds', 6)\n", - " blues = sns.color_palette('Blues_r', 6)\n", + " reds = sns.color_palette(\"Reds\", 6)\n", + " blues = sns.color_palette(\"Blues_r\", 6)\n", " cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n", " # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]\n", " # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]\n", " cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n", "\n", - " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend='both')\n", - " \n", + " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend=\"both\")\n", + "\n", " # Calculate figure dimensions\n", " ncompressors = len(compressors)\n", " nvariables = len(variables)\n", " nmetrics = len(metrics)\n", - " \n", + "\n", " panel_width = (2.5 / 5) * nvariables\n", " label_width = 1.5 * panel_width\n", " padding_right = 0.1\n", " panel_height = panel_width / nvariables\n", - " \n", + "\n", " title_height = panel_height * 1.25\n", " cbar_height = panel_height * 2\n", " spacing_height = panel_height * 0.1\n", " spacing_width = panel_height * 0.2\n", - " \n", - " total_width = label_width + nmetrics * panel_width + (nmetrics - 1) * spacing_width + padding_right\n", - " total_height = title_height + cbar_height + ncompressors * panel_height + (ncompressors - 1) * spacing_height\n", - " \n", + "\n", + " total_width = (\n", + " label_width\n", + " + nmetrics * panel_width\n", + " + (nmetrics - 1) * spacing_width\n", + " + padding_right\n", + " )\n", + " total_height = (\n", + " title_height\n", + " + cbar_height\n", + " + ncompressors * panel_height\n", + " + (ncompressors - 1) * spacing_height\n", + " )\n", + "\n", " # Create figure and gridspec\n", " fig = plt.figure(figsize=(total_width, total_height))\n", " gs = mpl.gridspec.GridSpec(\n", - " ncompressors, nmetrics,\n", + " ncompressors,\n", + " nmetrics,\n", " figure=fig,\n", " left=label_width / total_width,\n", " right=1 - padding_right / total_width,\n", " top=1 - (title_height / total_height),\n", " bottom=cbar_height / total_height,\n", " hspace=spacing_height / panel_height,\n", - " wspace=spacing_width / panel_width\n", + " wspace=spacing_width / panel_width,\n", " )\n", - " \n", + "\n", " # Plot each panel\n", " for row, compressor in enumerate(compressors):\n", " for col, metric in enumerate(metrics):\n", @@ -377,32 +403,39 @@ "\n", " rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n", " abs_values = data_matrix[row, start_col:end_col]\n", - " \n", + "\n", " # Create heatmap\n", - " img = 
ax.imshow(rel_values, aspect='auto', cmap=cmap, norm=norm)\n", - " \n", + " img = ax.imshow(rel_values, aspect=\"auto\", cmap=cmap, norm=norm)\n", + "\n", " # Customize axes\n", " ax.set_xticks([])\n", " ax.set_xticklabels([])\n", " ax.set_yticks([])\n", " ax.set_yticklabels([])\n", - " \n", + "\n", " # Add white grid lines\n", " for i in range(nvariables):\n", " rect = mpl.patches.Rectangle(\n", - " (i - 0.5, -0.5), 1, 1,\n", - " linewidth=1, edgecolor='white', facecolor='none'\n", + " (i - 0.5, -0.5),\n", + " 1,\n", + " 1,\n", + " linewidth=1,\n", + " edgecolor=\"white\",\n", + " facecolor=\"none\",\n", " )\n", " ax.add_patch(rect)\n", - " \n", + "\n", " # Add absolute values as text\n", " for i, val in enumerate(abs_values):\n", " # Ensure we don't have black text on dark background\n", " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", " fontsize = 10\n", " # Format numbers appropriately\n", - " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\"ta\", \"tos\"]:\n", - " # These variables have large regions of NaN values which makes the \n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]:\n", + " # These variables have large regions of NaN values which makes the\n", " # DSSIM and Spectral Error values unreliable.\n", " text = \"N/A\"\n", " color = \"black\"\n", @@ -426,20 +459,20 @@ " else:\n", " text = f\"{val:.2f}\"\n", " ax.text(\n", - " i, \n", - " 0, \n", - " text, \n", - " ha='center', \n", - " va='center', \n", - " fontsize=fontsize, \n", - " color=color\n", + " i, 0, text, ha=\"center\", va=\"center\", fontsize=fontsize, color=color\n", " )\n", "\n", " # Add row labels (compressor names)\n", " if col == 0:\n", - " ax.set_ylabel(_get_legend_name(compressor), rotation=0, ha='right', va='center',\n", - " labelpad=10, fontsize=14)\n", - " \n", + " ax.set_ylabel(\n", + " _get_legend_name(compressor),\n", + " rotation=0,\n", + " ha=\"right\",\n", + " va=\"center\",\n", + " labelpad=10,\n", + " fontsize=14,\n", + " )\n", + "\n", " # Add column titles (variable names)\n", " if row == 0:\n", " # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)\n", @@ -450,37 +483,56 @@ " # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", " # ax.set_xticks(range(nmetrics))\n", " # ax.set_xticklabels(\n", - " # [METRICS2NAME.get(m, m) for m in metrics], \n", - " # rotation=45, \n", + " # [METRICS2NAME.get(m, m) for m in metrics],\n", + " # rotation=45,\n", " # ha='left', fontsize=8)\n", " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", " ax.set_xticks(range(nvariables))\n", " ax.set_xticklabels(\n", " [VARIABLE2NAME.get(v, v) for v in variables],\n", " rotation=45,\n", - " ha='left', fontsize=12)\n", - " \n", + " ha=\"left\",\n", + " fontsize=12,\n", + " )\n", + "\n", " # Style spines\n", " for spine in ax.spines.values():\n", - " spine.set_color('0.7')\n", - " \n", + " spine.set_color(\"0.7\")\n", + "\n", " # Add colorbar\n", " if cbar and not highlight_bigger_than_one:\n", " rel_cbar_height = cbar_height / total_height\n", " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", - " cb = fig.colorbar(img, cax=cax, orientation='horizontal')\n", + " cb = fig.colorbar(img, cax=cax, orientation=\"horizontal\")\n", " cb.ax.set_xticks(cb_levels)\n", " if highlight_bigger_than_one:\n", - " cb.ax.set_xlabel('Better ← |non-chunked - chunked| → Worse', fontsize=16)\n", + " cb.ax.set_xlabel(\"Better ← 
|non-chunked - chunked| → Worse\", fontsize=16)\n", " else:\n", - " cb.ax.set_xlabel(f'Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse', fontsize=16)\n", - " \n", + " cb.ax.set_xlabel(\n", + " f\"Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse\",\n", + " fontsize=16,\n", + " )\n", + "\n", " if highlight_bigger_than_one:\n", " chunking_handles = [\n", - " Line2D([], [], marker=\"s\", color=cmap(101), linestyle=\"None\", markersize=10,\n", - " label=\"Not Chunked Better\"),\n", - " Line2D([], [], marker=\"s\", color=cmap(-101), linestyle=\"None\", markersize=10,\n", - " label=\"Chunked Better\"),\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Not Chunked Better\",\n", + " ),\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(-101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Chunked Better\",\n", + " ),\n", " ]\n", "\n", " ax.legend(\n", @@ -488,13 +540,13 @@ " loc=\"upper left\",\n", " ncol=2,\n", " bbox_to_anchor=(-0.5, -0.05),\n", - " fontsize=16\n", + " fontsize=16,\n", " )\n", "\n", " plt.tight_layout()\n", - " \n", + "\n", " if save_fn:\n", - " plt.savefig(save_fn, dpi=300, bbox_inches='tight')\n", + " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", " plt.close()\n", " else:\n", " plt.show()" @@ -517,9 +569,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -534,9 +586,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -551,9 +603,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so 
results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -563,22 +615,22 @@ " print(f\"Creating scorecard for {bound_name} bound...\")\n", " # Split into two rows for better readability.\n", " create_compression_scorecard(\n", - " data_matrix[:, :3*len(variables)], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, : 3 * len(variables)],\n", + " compressors,\n", + " variables,\n", " metrics[:3],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\"\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 3*len(variables):], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 3 * len(variables) :],\n", + " compressors,\n", + " variables,\n", " metrics[3:],\n", " ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\"\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\",\n", " )" ] }, @@ -607,11 +659,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -626,11 +678,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " 
plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -645,11 +697,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -660,32 +712,32 @@ " # Split into two rows for better readability.\n", " num_vars = len(variables)\n", " create_compression_scorecard(\n", - " data_matrix[:, :2*num_vars], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, : 2 * num_vars],\n", + " compressors,\n", + " variables,\n", " metrics[:2],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 2*num_vars:4*num_vars], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 2 * num_vars : 4 * num_vars],\n", + " compressors,\n", + " variables,\n", " metrics[2:4],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 4*num_vars:], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 4 * num_vars :],\n", + " compressors,\n", + " variables,\n", " metrics[4:],\n", " ref_compressor=\"bitround-pco\",\n", - " 
save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\",\n", " )" ] }, diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 2e87acf..d15403e 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -73,7 +73,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: DISTORTION2LEGEND_NAME = { "Relative MAE": "Mean Absolute Error", - "Relative DSSIM": "DSSIM", + "Relative dSSIM": "dSSIM", "Relative MaxAbsError": "Max Absolute Error", "Spectral Error": "Spectral Error", } @@ -155,7 +155,7 @@ def plot_metrics( for metric in [ "Relative MAE", - "Relative DSSIM", + "Relative dSSIM", "Relative MaxAbsError", "Relative SpectralError", ]: @@ -214,7 +214,7 @@ def _normalize(data): normalize_vars = [ ("Compression Ratio [raw B / enc B]", "Relative CR"), ("MAE", "Relative MAE"), - ("DSSIM", "Relative DSSIM"), + ("DSSIM", "Relative dSSIM"), ("Max Absolute Error", "Relative MaxAbsError"), ("Spectral Error", "Relative SpectralError"), ] @@ -528,12 +528,12 @@ def _plot_aggregated_rd_curve( right=True, ) plt.xlabel( - r"Mean Normalized Compression Ratio ($\uparrow$)", + r"Mean Normalised Compression Ratio ($\uparrow$)", fontsize=16, ) metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) plt.ylabel( - rf"Mean Normalized {metric_name} ($\downarrow$)", + rf"Mean Normalised {metric_name} ($\downarrow$)", fontsize=16, ) plt.legend( @@ -545,7 +545,7 @@ def _plot_aggregated_rd_curve( ) arrow_color = "black" - if "DSSIM" in distortion_metric: + if "dSSIM" in distortion_metric: # Add an arrow pointing into the top right corner plt.annotate( "", @@ -573,7 +573,7 @@ def _plot_aggregated_rd_curve( ) # Correct the y-label to point upwards plt.ylabel( - rf"Mean Normalized {metric_name} ($\uparrow$)", + rf"Mean Normalised {metric_name} ($\uparrow$)", fontsize=16, ) else: @@ -601,7 +601,7 @@ def _plot_aggregated_rd_curve( ha="center", ) if ( - "DSSIM" in distortion_metric + "dSSIM" in distortion_metric or "MaxAbsError" in distortion_metric or "SpectralError" in distortion_metric ): @@ -736,13 +736,14 @@ def _plot_grouped_df( ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend(fontsize=14) + ax.legend(fontsize=14, loc="upper left") ax.set_ylabel(ylabel, fontsize=14) + if i == 1: ax.annotate( "Better", - xy=(0.1, 0.8), + xy=(0.1, 0.75), xycoords="axes fraction", - xytext=(0.1, 0.95), + xytext=(0.1, 0.9), textcoords="axes fraction", arrowprops=dict(arrowstyle="->", lw=3, color="black"), fontsize=14, From eaabb96f5a547fccc003eddf547950eac40d7f5e Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 5 Feb 2026 20:37:54 +0200 Subject: [PATCH 20/26] tweak throughput plots --- scorecards.ipynb | 7 ----- .../compressor/plotting/plot_metrics.py | 31 ++++++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index 8beff56..f8b7b34 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -12,19 +12,12 @@ "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import seaborn as sns\n", - "from matplotlib.colors import LinearSegmentedColormap\n", - "import matplotlib.patches as mpatches\n", "from matplotlib.lines import Line2D\n", "\n", - "from pathlib import Path\n", "from 
climatebenchpress.compressor.plotting.plot_metrics import (\n", " _rename_compressors,\n", " _get_legend_name,\n", - " _normalize,\n", - " _get_lineinfo,\n", - " DISTORTION2LEGEND_NAME,\n", " _COMPRESSOR_ORDER,\n", - " _savefig,\n", ")" ] }, diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index d15403e..780fbe9 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -614,24 +614,23 @@ def _plot_aggregated_rd_curve( def _plot_throughput(df, outfile: None | Path = None): - # Transform throughput measurements from raw B/s to s/MB for better comparison - # with instruction count measurements. encode_col = "Encode Throughput [raw B / s]" decode_col = "Decode Throughput [raw B / s]" new_df = df[["Compressor", "Error Bound Name", encode_col, decode_col]].copy() - transformed_encode_col = "Encode Throughput [s / MB]" - transformed_decode_col = "Decode Throughput [s / MB]" - new_df[transformed_encode_col] = 1e6 / new_df[encode_col] - new_df[transformed_decode_col] = 1e6 / new_df[decode_col] + transformed_encode_col = "Encode Throughput [MiB / s]" + transformed_decode_col = "Decode Throughput [MiB / s]" + new_df[transformed_encode_col] = new_df[encode_col] / (2**20) + new_df[transformed_decode_col] = new_df[decode_col] / (2**20) encode_col, decode_col = transformed_encode_col, transformed_decode_col grouped_df = _get_median_and_quantiles(new_df, encode_col, decode_col) _plot_grouped_df( grouped_df, title="", - ylabel="Throughput [s / MB]", + ylabel="Throughput [MiB / s]", logy=True, outfile=outfile, + up=True, ) @@ -645,6 +644,7 @@ def _plot_instruction_count(df, outfile: None | Path = None): ylabel="Instructions [# / raw B]", logy=True, outfile=outfile, + up=False, ) @@ -679,7 +679,12 @@ def _get_median_and_quantiles(df, encode_column, decode_column): def _plot_grouped_df( - grouped_df, title, ylabel, outfile: None | Path = None, logy=False + grouped_df, + title, + ylabel, + outfile: None | Path = None, + logy=False, + up=False, ): fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) @@ -736,16 +741,18 @@ def _plot_grouped_df( ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend(fontsize=14, loc="upper left") + ax.legend( + fontsize=14, loc="lower left" if up else "upper left", framealpha=0.9 + ) ax.set_ylabel(ylabel, fontsize=14) if i == 1: ax.annotate( "Better", - xy=(0.1, 0.75), + xy=(0.51, 0.75), xycoords="axes fraction", - xytext=(0.1, 0.9), + xytext=(0.51, 0.92), textcoords="axes fraction", - arrowprops=dict(arrowstyle="->", lw=3, color="black"), + arrowprops=dict(arrowstyle="<-" if up else "->", lw=3, color="black"), fontsize=14, ha="center", va="bottom", From 4e3ad92afefdeaace7a3767484d3c2ce42ae4ab0 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 06:38:49 +0200 Subject: [PATCH 21/26] Use Crash instead of Fail --- scorecards.ipynb | 125 ++++------------------------------------------- 1 file changed, 9 insertions(+), 116 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index f8b7b34..b409d06 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -433,7 +433,7 @@ " text = \"N/A\"\n", " color = \"black\"\n", " elif np.isnan(val):\n", - " text = \"Fail\"\n", + " text = \"Crash\"\n", " color = \"black\"\n", " elif abs(val) > 10_000:\n", " text = f\"{val:.1e}\"\n", @@ -562,9 +562,9 @@ "name": 
"stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -579,9 +579,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -596,9 +596,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -614,7 +614,7 @@ " metrics[:3],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\",\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", @@ -623,114 +623,7 @@ " variables,\n", " metrics[3:],\n", " ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\",\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "3afb646e", - "metadata": {}, - "source": [ - "## Two-Column Scorecard" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "b6fe5f55", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"Creating scorecard for low bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for mid bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for high bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - } - ], - "source": [ - "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", - " print(f\"Creating scorecard for {bound_name} bound...\")\n", - " # Split into two rows for better readability.\n", - " num_vars = len(variables)\n", - " create_compression_scorecard(\n", - " data_matrix[:, : 2 * num_vars],\n", - " compressors,\n", - " variables,\n", - " metrics[:2],\n", - " ref_compressor=\"bitround-pco\",\n", - " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\",\n", - " )\n", - "\n", - " create_compression_scorecard(\n", - " data_matrix[:, 2 * num_vars : 4 * num_vars],\n", - " compressors,\n", - " variables,\n", - " metrics[2:4],\n", - " ref_compressor=\"bitround-pco\",\n", - " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\",\n", - " )\n", - "\n", - " create_compression_scorecard(\n", - " data_matrix[:, 4 * num_vars :],\n", - " compressors,\n", - " variables,\n", - " metrics[4:],\n", - " 
ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\",\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row2.pdf\",\n", " )" ] }, From f7963e8586d6cf545c95745c678627969bc0db5d Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:02:53 +0200 Subject: [PATCH 22/26] add crash->crash arrows --- scorecards.ipynb | 77 ++++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 48 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index b409d06..d787693 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -407,13 +407,15 @@ " ax.set_yticklabels([])\n", "\n", " # Add white grid lines\n", - " for i in range(nvariables):\n", + " for i in range(1, nvariables):\n", " rect = mpl.patches.Rectangle(\n", " (i - 0.5, -0.5),\n", " 1,\n", " 1,\n", " linewidth=1,\n", - " edgecolor=\"white\",\n", + " edgecolor=\"lightgrey\"\n", + " if np.isnan(abs_values[i]) and np.isnan(abs_values[i - 1])\n", + " else \"white\",\n", " facecolor=\"none\",\n", " )\n", " ax.add_patch(rect)\n", @@ -455,6 +457,27 @@ " i, 0, text, ha=\"center\", va=\"center\", fontsize=fontsize, color=color\n", " )\n", "\n", + " if (\n", + " row > 0\n", + " and np.isnan(val)\n", + " and np.isnan(data_matrix[row - 1, col * nvariables + i])\n", + " and compressor == f\"safeguarded-{compressors[row - 1]}\"\n", + " and not (\n", + " metric in [\"DSSIM\", \"Spectral Error\"]\n", + " and variables[i]\n", + " in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]\n", + " )\n", + " ):\n", + " ax.annotate(\n", + " \"\",\n", + " xy=(i, -0.15),\n", + " xytext=(i, -0.9),\n", + " arrowprops=dict(arrowstyle=\"->\", lw=2, color=\"lightgrey\"),\n", + " )\n", + "\n", " # Add row labels (compressor names)\n", " if col == 0:\n", " ax.set_ylabel(\n", @@ -536,7 +559,7 @@ " fontsize=16,\n", " )\n", "\n", - " plt.tight_layout()\n", + " # plt.tight_layout()\n", "\n", " if save_fn:\n", " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", @@ -555,52 +578,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Creating scorecard for low bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for mid bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Creating scorecard for low bound...\n", + "Creating scorecard for mid bound...\n", "Creating scorecard for high bound...\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - 
"text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] } ], "source": [ @@ -630,7 +611,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c546f6b0-bb7b-4646-83bc-3f502bcf6a9f", + "id": "c2d8dea1-cd87-48d1-9d5b-8fe106183cbf", "metadata": {}, "outputs": [], "source": [] From 51d32246e15d06ec575a3c177158377148d07122 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:07:02 +0200 Subject: [PATCH 23/26] improve violation metric title --- scorecards.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index d787693..b95a393 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -256,7 +256,7 @@ " \"DSSIM\": \"dSSIM\",\n", " \"MAE\": \"Mean Absolute Error\",\n", " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", - " \"Satisfies Bound (Value)\": r\"% of Pixels Violating Error Bound\",\n", + " \"Satisfies Bound (Value)\": r\"% of Data Points Violating the Error Bound\",\n", "}\n", "\n", "VARIABLE2NAME = {\n", From 09ee78dd3752209b4d7ffee998af27c5a89145ad Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:16:24 +0200 Subject: [PATCH 24/26] highlight safeguards in throughput --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 780fbe9..3c5d8ca 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -716,7 +716,12 @@ def _plot_grouped_df( bound_data["encode_upper_quantile"], ], label="Encoding", + edgecolor="white", + linewidth=0, color=[_get_lineinfo(comp)[0] for comp in compressors], + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Plot decode throughput @@ -732,6 +737,9 @@ def _plot_grouped_df( edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Add labels and title From b031d13da521b191b8e47a7b12179ffb8a2d1bf7 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 10:23:36 +0200 Subject: [PATCH 25/26] change throughout legend labels --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 3c5d8ca..fc7cd7f 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -715,7 +715,7 @@ def _plot_grouped_df( bound_data["encode_lower_quantile"], bound_data["encode_upper_quantile"], ], - label="Encoding", + label="Compression", edgecolor="white", linewidth=0, color=[_get_lineinfo(comp)[0] for comp in compressors], @@ -733,7 +733,7 @@ def _plot_grouped_df( bound_data["decode_lower_quantile"], bound_data["decode_upper_quantile"], ], - label="Decoding", + 
label="Decompression", edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, From 35ba589f20301405a6ae773e5d8da109ab54b03d Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 10 Feb 2026 21:58:01 +0200 Subject: [PATCH 26/26] Improve dSSIM rd curves --- .../compressor/plotting/plot_metrics.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index fc7cd7f..1b5b1bb 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -146,7 +146,7 @@ def plot_metrics( # ) df = _rename_compressors(df) - normalized_df = _normalize(df) + normalized_df, normalized_mean_std = _normalize(df) _plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) @@ -164,6 +164,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", bound_names=bound_names, @@ -173,6 +174,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", @@ -224,6 +226,7 @@ def _normalize(data): dssim_unreliable = normalized["Variable"].isin(["ta", "tos"]) normalized.loc[dssim_unreliable, "DSSIM"] = np.nan + normalize_mean_std = dict() for col, new_col in normalize_vars: mean_std = dict() for var in variables: @@ -239,7 +242,9 @@ def _normalize(data): axis=1, ) - return normalized + normalize_mean_std[new_col] = mean_std + + return normalized, normalize_mean_std def _plot_per_variable_metrics( @@ -434,6 +439,7 @@ def _plot_aggregated_rd_curve( normalized_df, compression_metric, distortion_metric, + mean_std, outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], @@ -546,10 +552,27 @@ def _plot_aggregated_rd_curve( arrow_color = "black" if "dSSIM" in distortion_metric: + # Annotate dSSIM = 1, accounting for the normalization + dssim_one = getattr(np, f"nan{agg}")( + [(1 - ms[0]) / ms[1] for ms in mean_std.values()] + ) + plt.axhline(dssim_one, c="k", ls="--") + plt.text( + np.percentile(plt.xlim(), 63), + dssim_one, + "dSSIM = 1", + fontsize=16, + fontweight="bold", + color="black", + ha="center", + va="center", + bbox=dict(edgecolor="none", facecolor="w", alpha=0.85), + ) + # Add an arrow pointing into the top right corner plt.annotate( "", - xy=(0.95, 0.95), + xy=(0.95, 0.875 if remove_outliers else 0.9), xycoords="axes fraction", xytext=(-60, -50), textcoords="offset points", @@ -562,7 +585,7 @@ def _plot_aggregated_rd_curve( # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.92, + 0.845 if remove_outliers else 0.87, "Better", transform=plt.gca().transAxes, fontsize=16,