Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/ndarray_str.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "awfutils-test",
"display_name": "awfutils",
"language": "python",
"name": "python3"
},
Expand All @@ -145,7 +145,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
"version": "3.12.10"
}
},
"nbformat": 4,
Expand Down
4 changes: 2 additions & 2 deletions experiments/typecheck_example.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools

from typecheck import typecheck
from awfutils import typecheck


def foo(x: int, y: float):
Expand All @@ -12,7 +12,7 @@ def foo(x: int, y: float):
foo(3, 1.3)


@functools.partial(typecheck, show_src=True)
@typecheck(show_src=True)
def foo(x: int, y: float):
z: int = x * y # Now it raises AssertionError: z not int
w: float = z * 3.2
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
array_api_compat
astpretty
prettyprinter
ml_collections
1 change: 1 addition & 0 deletions src/awfutils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
pt_sub,
pt_sum,
pt_print,
pt_print_aux,
)
from .print_utils import fn_name, class_name
from .typecheck import get_ast_for_function, typecheck
Expand Down
71 changes: 50 additions & 21 deletions src/awfutils/ndarray_str.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
import numpy as np
from array_api_compat import array_namespace
import re


def ndarray_str(x, tiny : int = 10, large : int = 100_000_000):
# Fixes for Array API
import array_api_compat


def _is_torch_array(xp, x):
    """
    Return True if namespace ``xp`` reports ``x`` as a torch tensor.

    The check is delegated to ``xp.is_torch_array`` when the namespace
    provides that helper.  Namespaces that do not provide it (``jax.numpy``
    is one, and we do not want to import jax just to ask) are treated as
    "not torch" instead of raising AttributeError.
    """
    # Probe for the helper rather than matching on the namespace's name, so
    # any namespace lacking it is handled, not only "jax.numpy".
    is_torch = getattr(xp, "is_torch_array", None)
    if is_torch is None:
        return False
    return bool(is_torch(x))


def _is_floating(xp, x):
    """Tell whether ``x`` holds floating-point data, as judged by namespace ``xp``."""
    # Torch tensors are asked through ``xp.is_floating_point``; everything
    # else goes through the NumPy-style dtype hierarchy check.
    torch_like = _is_torch_array(xp, x)
    return (
        xp.is_floating_point(x)
        if torch_like
        else xp.issubdtype(x.dtype, xp.floating)
    )


def _quantile(values, quantiles):
    """
    Compute "nearest" quantiles of ``values`` in its own array namespace.

    Args:
        values: array (torch tensor or NumPy-like) to take quantiles of.
        quantiles: sequence of probabilities, e.g. ``[0, 0.5, 1.0]``.

    Returns:
        An array of the requested quantiles, cast back to ``values.dtype``.
    """
    xp = array_namespace(values)

    if _is_torch_array(xp, values):
        dtype = values.dtype
        # NOTE(review): torch.quantile is expected to accept only
        # float32/float64 input.  float64 is kept as is (no precision loss);
        # every other dtype is moved to float32.  ``.to`` is used because
        # re-wrapping a tensor with ``tensor(...)`` emits a copy-construct
        # warning.
        if dtype in (xp.float32, xp.float64):
            work = values
        else:
            work = values.to(xp.float32)
        # Build q with the input's dtype and device so the call does not
        # depend on torch's global default dtype or on ``values`` being on CPU.
        q = xp.tensor(quantiles, dtype=work.dtype, device=work.device)
        return xp.quantile(work, q, interpolation="nearest").to(dtype=dtype)

    # NumPy (>= 1.22) and JAX spell this option ``method``; ``interpolation``
    # is the deprecated name.
    q = xp.asarray(quantiles)
    return xp.quantile(values, q, method="nearest").astype(values.dtype)


# End/Fixes for Array API


def ndarray_str(x, tiny: int = 10, large: int = 100_000_000):
"""
Nicely print an ndarray on one line.

Expand All @@ -17,28 +52,22 @@ def ndarray_str(x, tiny : int = 10, large : int = 100_000_000):
See the notebook <11-utils.ipynb> for more information.
"""

def size(x):
"""
Size, for torch or np
"""
return np.prod(x.shape)

if not hasattr(x, "__array__"):
return repr(x)

# Convert to numpy array - TODO: do this after we know it's small enough
x = np.array(x)
xp = array_namespace(x)

shape_str = "x".join(map(str, x.shape))
dtype_str = (
f"{x.dtype}".replace("float", "f").replace("uint", "u").replace("int", "i")
)
dtype_str = re.sub(r"^torch\.", r"", dtype_str)
type_str = f"{dtype_str}[{shape_str}]"

notes = ""
finite_vals = x[np.isfinite(x)]
all_finite = size(finite_vals) == size(x)
display_all = size(x) <= tiny
finite_vals = x[xp.isfinite(x)]
all_finite = xp.size(finite_vals) == xp.size(x)
display_all = xp.size(x) <= tiny

def disp(a, fmt):
if len(a.shape) > 1:
Expand All @@ -53,25 +82,25 @@ def disp(a, fmt):
head, tail = "[", "]"
else:
if not all_finite:
notes += f" #inf={np.isinf(x).sum()} #nan={np.isnan(x).sum()}"
notes += f" #inf={xp.isinf(x).sum()} #nan={xp.isnan(x).sum()}"

if size(x) < large:
if xp.size(x) < large:
quantiles = [0, 0.05, 0.25, 0.5, 0.75, 0.95, 1.0]
vals = np.quantile(finite_vals, quantiles, method="nearest")
vals = _quantile(finite_vals, quantiles)
head, tail = "Percentiles{", "}"
else:
# Too large to sort, just show min, median, max
vals = np.array(
[finite_vals.min(), np.median(finite_vals), finite_vals.max()],
vals = xp.array(
[finite_vals.min(), xp.median(finite_vals), finite_vals.max()],
dtype=x.dtype,
)
head, tail = "MinMedMax{", "}"

if np.issubdtype(x.dtype, np.floating):
if _is_floating(xp, x):
# scale down vals
max = np.abs(finite_vals).max() if size(finite_vals) else 0
max = xp.abs(finite_vals).max() if xp.size(finite_vals) else 0
if max > 0:
logmax = np.floor(np.log10(max))
logmax = xp.floor(xp.log10(max))
if -2 <= logmax <= 3:
logmax = 0
max_scale = 10**-logmax
Expand Down
116 changes: 96 additions & 20 deletions test/test_ndarray_str.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,123 @@
import pytest
import numpy as np
import torch
from types import SimpleNamespace

from awfutils import ndarray_str


def mx(dtype, *sz):
def np_mk(dtype, *sz):
    """NumPy fixture: the values 0..N-1 as ``dtype``, reshaped to ``sz``."""
    count = np.prod(sz)
    flat = np.arange(count, dtype=dtype)
    return flat.reshape(sz)


def np_mx(dtype, vals):
    """NumPy fixture: an array of ``dtype`` built from the literal ``vals``."""
    arr = np.array(vals, dtype=dtype)
    return arr


def np_zeros(dtype, shape):
    """NumPy fixture: an all-zero array of the given ``dtype`` and ``shape``."""
    return np.zeros(shape=shape, dtype=dtype)


import jax.numpy as jnp


def jax_mk(dtype, *sz):
    """JAX fixture: the values 0..N-1 as ``dtype``, reshaped to ``sz``."""
    count = np.prod(sz)
    flat = jnp.arange(count, dtype=dtype)
    return flat.reshape(sz)


def jax_mx(dtype, vals):
    """JAX fixture: an array of ``dtype`` built from the literal ``vals``."""
    arr = jnp.array(vals, dtype=dtype)
    return arr


def jax_zeros(dtype, shape):
    """JAX fixture: an all-zero array of the given ``dtype`` and ``shape``."""
    return jnp.zeros(shape=shape, dtype=dtype)


# Maps the NumPy scalar types used by the test fixtures to the equivalent
# torch dtype, so torch fixtures can be requested with NumPy dtype arguments.
numpy_to_torch_dtype_dict = {
    # ``np.bool_`` rather than ``np.bool``: the latter is missing on
    # NumPy 1.24-1.26, while ``np.bool_`` is available on every version.
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}


def torch_mk(dtype, *sz):
    """Torch fixture: the values 0..N-1 in the torch dtype matching ``dtype``, reshaped to ``sz``."""
    torch_dtype = numpy_to_torch_dtype_dict[dtype]
    count = np.prod(sz)
    flat = torch.arange(count, dtype=torch_dtype)
    return flat.reshape(sz)


def torch_mx(dtype, vals):
    """
    Torch fixture: a tensor built from the literal ``vals``.

    ``dtype`` is a NumPy scalar type; it is translated through
    ``numpy_to_torch_dtype_dict`` so the tensor gets the requested dtype
    instead of whatever torch infers from ``vals``.
    """
    return torch.tensor(vals, dtype=numpy_to_torch_dtype_dict[dtype])


def torch_zeros(dtype, shape):
    """Torch fixture: an all-zero tensor of ``shape`` in the torch dtype matching ``dtype``."""
    torch_dtype = numpy_to_torch_dtype_dict[dtype]
    return torch.zeros(shape, dtype=torch_dtype)


# One pytest parameter per array backend (NumPy, torch, JAX).  Each entry
# bundles that backend's fixture constructors under common names (mx, mk,
# zeros) so the parametrized tests below can build inputs via ``p.mx`` /
# ``p.mk`` / ``p.zeros``; ``id`` is the label pytest shows for the case.
platforms = [
pytest.param(
SimpleNamespace(mx=np_mx, mk=np_mk, zeros=np_zeros),
id="np",
),
pytest.param(
SimpleNamespace(mx=torch_mx, mk=torch_mk, zeros=torch_zeros),
id="torch",
),
pytest.param(
SimpleNamespace(mx=jax_mx, mk=jax_mk, zeros=jax_zeros),
id="jax",
),
]


def go(x, target):
act = ndarray_str(np.array(x))
act = ndarray_str(x)
print(act)
assert act == target


def test_ndarray_str():
@pytest.mark.parametrize("p", platforms)
def test_ndarray_str(p):
mx, mk = p.mx, p.mk
go(
[[1, 2, np.inf], [np.inf, np.nan, 1.1]],
"f64[2x3] [[1.000 2.000 inf], [inf nan 1.100]]",
mx(np.float32, [[1, 2, np.inf], [np.inf, np.nan, 1.1]]),
"f32[2x3] [[1.000 2.000 inf], [inf nan 1.100]]",
)
go(
mx(np.float64, 2, 3, 4),
"f64[2x3x4] Percentiles{0.000 1.000 6.000 12.000 17.000 22.000 23.000}",
mk(np.float32, 2, 3, 4),
"f32[2x3x4] Percentiles{0.000 1.000 6.000 11.000 17.000 22.000 23.000}",
)
go(
mx(np.float32, 2, 3),
mk(np.float32, 2, 3),
"f32[2x3] [[0.000 1.000 2.000], [3.000 4.000 5.000]]",
)

go(
[1, 2, 3, 4, 5, 5, 4, 3, 2, 1, np.inf, np.inf, np.nan],
"f64[13] Percentiles{1.000 1.000 2.000 3.000 4.000 5.000 5.000} #inf=2 #nan=1",
mx(np.float32, [1, 2, 3, 4, 5, 5, 4, 3, 2, 1, np.inf, np.inf, np.nan]),
"f32[13] Percentiles{1.000 1.000 2.000 3.000 4.000 5.000 5.000} #inf=2 #nan=1",
)
go(
mx(np.float32, [0.0, 0.0, np.nan]),
"f32[3] [0.0 0.0 nan]",
)
go([0.0, 0.0, np.nan], "f64[3] [0.0 0.0 nan]")


def test_zeros():
go(np.zeros([]), "f64[] [0.0]")
go(np.zeros(0), "f64[0] []")
go(np.zeros(1), "f64[1] [0.0]")
go(np.zeros(100), "f64[100] Zeros")
go(np.zeros((100, 100, 101), dtype=np.float16), "f16[100x100x101] Zeros")
@pytest.mark.parametrize("p", platforms)
def test_zeros(p):
    """All-zero inputs on each backend: 0-d, empty, one element, and large arrays."""
    cases = [
        (np.float32, [], "f32[] [0.0]"),
        (np.float32, 0, "f32[0] []"),
        (np.float32, 1, "f32[1] [0.0]"),
        (np.float32, 100, "f32[100] Zeros"),
        (np.float16, (100, 100, 101), "f16[100x100x101] Zeros"),
    ]
    # Build each input only when its turn comes, in the listed order.
    for dtype, shape, expected in cases:
        go(p.zeros(dtype, shape), expected)


def test_ints():
go(mx(np.int32, 2, 3), "i32[2x3] [[0 1 2], [3 4 5]]")
go(mx(np.int32, 20, 30), "i32[20x30] Percentiles{0 30 150 300 449 569 599}")
@pytest.mark.parametrize("p", platforms)
def test_ints(p):
    """Integer inputs on each backend: a small array shown in full, a larger one summarised."""
    make = p.mk
    cases = [
        ((2, 3), "i32[2x3] [[0 1 2], [3 4 5]]"),
        ((21, 31), "i32[21x31] Percentiles{0 32 162 325 488 618 650}"),
    ]
    for dims, expected in cases:
        go(make(np.int32, *dims), expected)