From 09b4aba7fe57869ad9d0d5edafc2944a6804550d Mon Sep 17 00:00:00 2001 From: Andrew Fitzgibbon Date: Wed, 16 Jul 2025 22:04:56 +0000 Subject: [PATCH] AAC crashing on import under pytest @py_assrt --- docs/ndarray_str.ipynb | 4 +- experiments/typecheck_example.py | 4 +- requirements.txt | 1 + src/awfutils/__init__.py | 1 + src/awfutils/ndarray_str.py | 71 +++++++++++++------ test/test_ndarray_str.py | 116 +++++++++++++++++++++++++------ 6 files changed, 152 insertions(+), 45 deletions(-) diff --git a/docs/ndarray_str.ipynb b/docs/ndarray_str.ipynb index 55884e0..30f31eb 100644 --- a/docs/ndarray_str.ipynb +++ b/docs/ndarray_str.ipynb @@ -131,7 +131,7 @@ ], "metadata": { "kernelspec": { - "display_name": "awfutils-test", + "display_name": "awfutils", "language": "python", "name": "python3" }, @@ -145,7 +145,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.13.0" + "version": "3.12.10" } }, "nbformat": 4, diff --git a/experiments/typecheck_example.py b/experiments/typecheck_example.py index bbcb2e7..eaca677 100644 --- a/experiments/typecheck_example.py +++ b/experiments/typecheck_example.py @@ -1,6 +1,6 @@ import functools -from typecheck import typecheck +from awfutils import typecheck def foo(x: int, y: float): @@ -12,7 +12,7 @@ def foo(x: int, y: float): foo(3, 1.3) -@functools.partial(typecheck, show_src=True) +@typecheck(show_src=True) def foo(x: int, y: float): z: int = x * y # Now it raises AssertionError: z not int w: float = z * 3.2 diff --git a/requirements.txt b/requirements.txt index 05f3f19..06a4092 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +array_api_compat astpretty prettyprinter ml_collections diff --git a/src/awfutils/__init__.py b/src/awfutils/__init__.py index b5ac5b4..55cdd43 100644 --- a/src/awfutils/__init__.py +++ b/src/awfutils/__init__.py @@ -10,6 +10,7 @@ pt_sub, pt_sum, pt_print, + pt_print_aux, ) from .print_utils import fn_name, class_name from .typecheck import get_ast_for_function, typecheck diff --git a/src/awfutils/ndarray_str.py b/src/awfutils/ndarray_str.py index 30fc3e7..171a929 100644 --- a/src/awfutils/ndarray_str.py +++ b/src/awfutils/ndarray_str.py @@ -1,7 +1,42 @@ -import numpy as np +from array_api_compat import array_namespace +import re -def ndarray_str(x, tiny : int = 10, large : int = 100_000_000): +# Fixes for Array API +import array_api_compat + + +def _is_torch_array(xp, x): + # String check for jax so we don't need to import it + return xp.__name__ != "jax.numpy" and xp.is_torch_array(x) + + +def _is_floating(xp, x): + if _is_torch_array(xp, x): + return xp.is_floating_point(x) + + return xp.issubdtype(x.dtype, xp.floating) + + +def _quantile(values, quantiles): + xp = array_namespace(values) + + if _is_torch_array(xp, values): + dtype = values.dtype + if dtype != xp.float32: + values = xp.tensor(values, dtype=xp.float32) + qs = xp.quantile(values, xp.tensor(quantiles), interpolation="nearest") + return qs.to(dtype=dtype) + else: + return xp.quantile(values, xp.array(quantiles), interpolation="nearest").astype( + values.dtype + ) + + +# End/Fixes for Array API + + +def ndarray_str(x, tiny: int = 10, large: int = 100_000_000): """ Nicely print an ndarray on one line. @@ -17,28 +52,22 @@ def ndarray_str(x, tiny : int = 10, large : int = 100_000_000): See the notebook <11-utils.ipynb> for more information. """ - def size(x): - """ - Size, for torch or np - """ - return np.prod(x.shape) - if not hasattr(x, "__array__"): return repr(x) - # Convert to numpy array - TODO: do this after we know it's small enough - x = np.array(x) + xp = array_namespace(x) shape_str = "x".join(map(str, x.shape)) dtype_str = ( f"{x.dtype}".replace("float", "f").replace("uint", "u").replace("int", "i") ) + dtype_str = re.sub(r"^torch\.", r"", dtype_str) type_str = f"{dtype_str}[{shape_str}]" notes = "" - finite_vals = x[np.isfinite(x)] - all_finite = size(finite_vals) == size(x) - display_all = size(x) <= tiny + finite_vals = x[xp.isfinite(x)] + all_finite = xp.size(finite_vals) == xp.size(x) + display_all = xp.size(x) <= tiny def disp(a, fmt): if len(a.shape) > 1: @@ -53,25 +82,25 @@ def disp(a, fmt): head, tail = "[", "]" else: if not all_finite: - notes += f" #inf={np.isinf(x).sum()} #nan={np.isnan(x).sum()}" + notes += f" #inf={xp.isinf(x).sum()} #nan={xp.isnan(x).sum()}" - if size(x) < large: + if xp.size(x) < large: quantiles = [0, 0.05, 0.25, 0.5, 0.75, 0.95, 1.0] - vals = np.quantile(finite_vals, quantiles, method="nearest") + vals = _quantile(finite_vals, quantiles) head, tail = "Percentiles{", "}" else: # Too large to sort, just show min, median, max - vals = np.array( - [finite_vals.min(), np.median(finite_vals), finite_vals.max()], + vals = xp.array( + [finite_vals.min(), xp.median(finite_vals), finite_vals.max()], dtype=x.dtype, ) head, tail = "MinMedMax{", "}" - if np.issubdtype(x.dtype, np.floating): + if _is_floating(xp, x): # scale down vals - max = np.abs(finite_vals).max() if size(finite_vals) else 0 + max = xp.abs(finite_vals).max() if xp.size(finite_vals) else 0 if max > 0: - logmax = np.floor(np.log10(max)) + logmax = xp.floor(xp.log10(max)) if -2 <= logmax <= 3: logmax = 0 max_scale = 10**-logmax diff --git a/test/test_ndarray_str.py b/test/test_ndarray_str.py index 487ed3c..0359116 100644 --- a/test/test_ndarray_str.py +++ b/test/test_ndarray_str.py @@ -1,47 +1,123 @@ +import pytest import numpy as np +import torch +from types import SimpleNamespace from awfutils import ndarray_str -def mx(dtype, *sz): +def np_mk(dtype, *sz): return np.arange(np.prod(sz), dtype=dtype).reshape(sz) +def np_mx(dtype, vals): + return np.array(vals, dtype=dtype) + + +def np_zeros(dtype, shape): + return np.zeros(shape, dtype=dtype) + + +import jax.numpy as jnp + + +def jax_mk(dtype, *sz): + return jnp.arange(np.prod(sz), dtype=dtype).reshape(sz) + + +def jax_mx(dtype, vals): + return jnp.array(vals, dtype=dtype) + + +def jax_zeros(dtype, shape): + return jnp.zeros(shape, dtype=dtype) + + +numpy_to_torch_dtype_dict = { + np.bool: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float64, + np.complex64: torch.complex64, + np.complex128: torch.complex128, +} + + +def torch_mk(dtype, *sz): + return torch.arange(np.prod(sz), dtype=numpy_to_torch_dtype_dict[dtype]).reshape(sz) + + +def torch_mx(dtype, vals): + return torch.tensor(vals) + + +def torch_zeros(dtype, shape): + return torch.zeros(shape, dtype=numpy_to_torch_dtype_dict[dtype]) + + +platforms = [ + pytest.param( + SimpleNamespace(mx=np_mx, mk=np_mk, zeros=np_zeros), + id="np", + ), + pytest.param( + SimpleNamespace(mx=torch_mx, mk=torch_mk, zeros=torch_zeros), + id="torch", + ), + pytest.param( + SimpleNamespace(mx=jax_mx, mk=jax_mk, zeros=jax_zeros), + id="jax", + ), +] + + def go(x, target): - act = ndarray_str(np.array(x)) + act = ndarray_str(x) print(act) assert act == target -def test_ndarray_str(): +@pytest.mark.parametrize("p", platforms) +def test_ndarray_str(p): + mx, mk = p.mx, p.mk go( - [[1, 2, np.inf], [np.inf, np.nan, 1.1]], - "f64[2x3] [[1.000 2.000 inf], [inf nan 1.100]]", + mx(np.float32, [[1, 2, np.inf], [np.inf, np.nan, 1.1]]), + "f32[2x3] [[1.000 2.000 inf], [inf nan 1.100]]", ) go( - mx(np.float64, 2, 3, 4), - "f64[2x3x4] Percentiles{0.000 1.000 6.000 12.000 17.000 22.000 23.000}", + mk(np.float32, 2, 3, 4), + "f32[2x3x4] Percentiles{0.000 1.000 6.000 11.000 17.000 22.000 23.000}", ) go( - mx(np.float32, 2, 3), + mk(np.float32, 2, 3), "f32[2x3] [[0.000 1.000 2.000], [3.000 4.000 5.000]]", ) go( - [1, 2, 3, 4, 5, 5, 4, 3, 2, 1, np.inf, np.inf, np.nan], - "f64[13] Percentiles{1.000 1.000 2.000 3.000 4.000 5.000 5.000} #inf=2 #nan=1", + mx(np.float32, [1, 2, 3, 4, 5, 5, 4, 3, 2, 1, np.inf, np.inf, np.nan]), + "f32[13] Percentiles{1.000 1.000 2.000 3.000 4.000 5.000 5.000} #inf=2 #nan=1", + ) + go( + mx(np.float32, [0.0, 0.0, np.nan]), + "f32[3] [0.0 0.0 nan]", ) - go([0.0, 0.0, np.nan], "f64[3] [0.0 0.0 nan]") -def test_zeros(): - go(np.zeros([]), "f64[] [0.0]") - go(np.zeros(0), "f64[0] []") - go(np.zeros(1), "f64[1] [0.0]") - go(np.zeros(100), "f64[100] Zeros") - go(np.zeros((100, 100, 101), dtype=np.float16), "f16[100x100x101] Zeros") +@pytest.mark.parametrize("p", platforms) +def test_zeros(p): + go(p.zeros(np.float32, []), "f32[] [0.0]") + go(p.zeros(np.float32, 0), "f32[0] []") + go(p.zeros(np.float32, 1), "f32[1] [0.0]") + go(p.zeros(np.float32, 100), "f32[100] Zeros") + go(p.zeros(np.float16, (100, 100, 101)), "f16[100x100x101] Zeros") -def test_ints(): - go(mx(np.int32, 2, 3), "i32[2x3] [[0 1 2], [3 4 5]]") - go(mx(np.int32, 20, 30), "i32[20x30] Percentiles{0 30 150 300 449 569 599}") +@pytest.mark.parametrize("p", platforms) +def test_ints(p): + go(p.mk(np.int32, 2, 3), "i32[2x3] [[0 1 2], [3 4 5]]") + go(p.mk(np.int32, 21, 31), "i32[21x31] Percentiles{0 32 162 325 488 618 650}")