diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 7cc7352..3731892 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -92,9 +92,11 @@ jobs: run: | echo "Main branch version: ${{ steps.get_main_version.outputs.main_version }}" echo "PR branch version: ${{ steps.get_pr_version.outputs.pr_version }}" - if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then - echo "Error: Version is the same as on the main branch" - exit 1 - else - echo "Ok: Version is different from the main branch" + if [ "${{ github.event_name }}" == "pull_request" ]; then + if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then + echo "Error: Version is the same as on the main branch" + exit 1 + else + echo "Ok: Version is different from the main branch" + fi fi diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index d15daef..4fc8e95 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,6 +1,3 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - name: Publish-Main-PyPi env: diff --git a/aplot/__config__.py b/aplot/__config__.py index d3ec452..d31c31e 100644 --- a/aplot/__config__.py +++ b/aplot/__config__.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.2.3" diff --git a/aplot/__init__.py b/aplot/__init__.py index 0f499e3..bbd0f1c 100644 --- a/aplot/__init__.py +++ b/aplot/__init__.py @@ -1,9 +1,12 @@ # flake8: noqa: F401 +import matplotlib.patches as patches + from . import analysis, styles from .__config__ import __version__ from .core import ax, axs, close, figure, figure_class, show, subplot, subplots from .core.axes_class import AAxes as Axes +from .core.axes_list import AxesList from .core.figure_class import AFigure as Figure s = styles diff --git a/aplot/analysis/array_manipulation.py b/aplot/analysis/array_manipulation.py index 0da10ef..0cdad94 100644 --- a/aplot/analysis/array_manipulation.py +++ b/aplot/analysis/array_manipulation.py @@ -1,7 +1,6 @@ import typing as _t import numpy as np -import scipy ArrayLike = _t.Union[np.ndarray, _t.List] @@ -29,6 +28,8 @@ def argmin2d( Tuple[int, int]: index_y, index_x, i.e. the min value is d[index_y, index_x] """ if filter_ and filter_ > 1: + import scipy + d = scipy.ndimage.uniform_filter(d, size=3, mode="nearest") if x_mask is not None: @@ -133,7 +134,9 @@ def array_from_span( return res -def get_z(I: np.ndarray, Q: np.ndarray) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741 +def get_z( + I: np.ndarray, Q: np.ndarray +) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741 min_len = min(len(I), len(Q)) return I[:min_len] + 1j * Q[:min_len] diff --git a/aplot/analysis/signal_analysis.py b/aplot/analysis/signal_analysis.py index f3d8e8f..f22b9a0 100644 --- a/aplot/analysis/signal_analysis.py +++ b/aplot/analysis/signal_analysis.py @@ -1,7 +1,6 @@ import typing as _t import numpy as np -import scipy def find_h_symmetry_axis(data: np.ndarray) -> int: @@ -13,6 +12,8 @@ def find_h_symmetry_axis(data: np.ndarray) -> int: Returns: (int): x index of the symmetry axis. """ + import scipy + data = (data - np.mean(data)) / np.std(data) # corr = scipy.signal.fftconvolve( # data[:, : len(data[0]) // 2], data[:, ::-1], mode="full" @@ -22,11 +23,15 @@ def find_h_symmetry_axis(data: np.ndarray) -> int: def remove_background(data: np.ndarray, convolve_len: _t.Optional[int] = None): + import scipy + if convolve_len is None: convolve_len = min(50, len(data) // 15) data = ( data - - scipy.signal.convolve2d(data, np.ones((convolve_len, 1)), mode="same", boundary="symm") + - scipy.signal.convolve2d( + data, np.ones((convolve_len, 1)), mode="same", boundary="symm" + ) / convolve_len ) return data - data.mean(axis=1)[:, np.newaxis] diff --git a/aplot/core/axes_class.py b/aplot/core/axes_class.py index c90530d..3821302 100644 --- a/aplot/core/axes_class.py +++ b/aplot/core/axes_class.py @@ -19,6 +19,7 @@ ) _T = _t.TypeVar("_T") +_R = _t.TypeVar("_R") if _t.TYPE_CHECKING: from .figure_class import AFigure @@ -125,12 +126,29 @@ FILTER_KWARGS = {"hist2d", QuadMesh} +class ClassicReturnAxis: + def __init__(self, axes: "AAxes"): + self.axes = axes + self._previous_state = False + + def __enter__(self): + self._previous_state = self.axes._classical_return + self.axes._classical_return = True + return self.axes + + def __exit__(self, exc_type, exc_value, traceback): + self.axes._classical_return = self._previous_state + if exc_type is not None: + raise + + class AAxes( MplAxes, _t.Generic[_T], ): name = "AAxis" # Give a name for the matplotlib registry _last_result = None + _classical_return = False # _fit_result: FitResult | None = None # __all__ = MplAxes.__all__ + ["fit", "last_result", "fit_result", "res", "set"] # __dict__ = MplAxes.__dict__ ("fit", "last_result", "fit_result", "res", "set") @@ -182,7 +200,7 @@ def __getattribute__(self, name: str): def wrapper(*args, **kwargs): result = func(*args, **kwargs) - if isinstance(result, (MplAxes, AAxes)): + if isinstance(result, (MplAxes, AAxes)) or self._classical_return: return result self._last_result = result return self @@ -209,34 +227,33 @@ def set( # type: ignore "ylabel": ylabel, } ) - super().set(**filter_none_types(kwargs)) - return self + return super().set(**filter_none_types(kwargs)) + # return self - def hist2d( + def hist2d( # type: ignore self, x, y=None, - bins=10, - range=None, # pylint: disable=redefined-builtin - density=False, - weights=False, - cmin=None, - cmax=None, + *args, **kwargs, ): if y is None: x = np.array(x) - x = x[:, 0] - y = x[:, 1] - return super().hist2d(x, y, bins, range, density, weights, cmin, cmax, **kwargs) + y = x[..., 1] + x = x[..., 0] + return super().hist2d(x, y, *args, **kwargs) + + def hist(self, *args, **kwargs): + with ClassicReturnAxis(self): + return super().hist(*args, **kwargs) def z_parametric(self, z, **kwargs): - self.plot(np.real(z), np.imag(z), **kwargs) - return self + return self.plot(np.real(z), np.imag(z), **kwargs) + # return self - def z_historograms(self, z, **kwargs): - self.hist2d(np.real(z), np.imag(z), **kwargs) - return self + def hist_z(self, z, **kwargs): + return self.hist2d(np.real(z), np.imag(z), **kwargs) + # return self def imshow( # type: ignore self, @@ -298,10 +315,11 @@ def imshow( # type: ignore raise ValueError("The figure is None cannot add colorbar") cbar = fig.colorbar(im, cax=cax, orientation="vertical") cbar.ax.set_ylabel(kwargs.get("bar_label", "")) + cbar.ax.set_rasterized(kwargs.get("bar_rasterized", kwargs.get("rasterized", False))) else: cbar = None - - return self + return im + # return self def pcolorfast( # type: ignore self, @@ -344,7 +362,8 @@ def pcolorfast( # type: ignore if colorbar: cbar = fig.colorbar(im, cax=cax, orientation="vertical") cbar.ax.set_ylabel(kwargs.get("bar_label", "")) - return self + return im + # return self def autoaxis(self, level: int = 0, func_name="plot") -> "AAxes": variables = get_auto_args(level, func_name) @@ -357,7 +376,14 @@ def tight_layout(self, *, pad=1.08, h_pad=None, w_pad=None, rect=None): self.figure.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) # type: ignore return self - def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=None, **kwargs): + def plot( + self, + *args, + keep_xlims: bool = False, + keep_ylims: bool = False, + axes=None, + **kwargs, + ): del axes xlims = self.get_xlim() if keep_xlims else None ylims = self.get_ylim() if keep_ylims else None @@ -368,9 +394,38 @@ def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=N self.set_ylim(*ylims) return res + def axhline(self, y=0, xmin=0, xmax=1, **kwargs) -> "AAxes": # type: ignore + if isinstance(y, _t.Iterable): + return self.update_result( + [self.axhline(y_, xmin=xmin, xmax=xmax, **kwargs).res for y_ in y] + ) + return self.update_result(super().axhline(y, xmin=xmin, xmax=xmax, **kwargs)) + + def axvline(self, x=0, ymin=0, ymax=1, **kwargs) -> "AAxes": # type: ignore + if isinstance(x, _t.Iterable): + return self.update_result( + [self.axvline(x_, ymin=ymin, ymax=ymax, **kwargs).res for x_ in x] + ) + return self.update_result(super().axvline(x, ymin=ymin, ymax=ymax, **kwargs)) + def __add__(self, other): from .axes_list import AxesList if isinstance(other, list): return AxesList([self] + other) # type: ignore return AxesList([self, other]) # type: ignore + + def update_result(self, result: _R) -> "AAxes[_R]": + self._last_result = result + return self # type: ignore + + def colorbar(self, label: _t.Optional[str] = None, *args, **kwargs): + c = self.res + assert c is not None + cbar = self.fig.colorbar(c, ax=self) + if label is not None: + cbar.set_label(label) + return self + + def classic_return(self): + return ClassicReturnAxis(self) diff --git a/aplot/core/axes_class.pyi b/aplot/core/axes_class.pyi index aa02cd7..2b54d8f 100644 --- a/aplot/core/axes_class.pyi +++ b/aplot/core/axes_class.pyi @@ -1,6 +1,16 @@ # flake8: noqa: E302, E704 import datetime -from typing import Callable, Generic, Literal, Sequence, TypeVar, Union, overload +from typing import ( + Callable, + Generic, + Iterable, + List, + Literal, + Sequence, + TypeVar, + Union, + overload, +) import numpy as np from matplotlib.artist import Artist @@ -73,11 +83,36 @@ class AAxes(MplAxes, Generic[_T]): pad: float = ..., *, y: float = ..., - **kwargs + **kwargs, ) -> "AAxes": ... - def legend(self, *args, **kwargs) -> "AAxes": ... # type: ignore + def legend( + self, + *args, + loc: Union[ + int, + Literal[ + "best", + "upper right", + "upper left", + "lower left", + "lower right", + "right", + "center left", + "center right", + "lower center", + "upper center", + "center", + ], + ] = "best", + **kwargs, + ) -> "AAxes": ... # type: ignore def inset_axes( - self, bounds: Sequence[float], *, transform: Transform = ..., zorder: float = ..., **kwargs + self, + bounds: Sequence[float], + *, + transform: Transform = ..., + zorder: float = ..., + **kwargs, ) -> "AAxes": ... def indicate_inset( # type: ignore self, @@ -89,7 +124,7 @@ class AAxes(MplAxes, Generic[_T]): edgecolor: Color = ..., alpha: float = ..., zorder: float = ..., - **kwargs + **kwargs, ) -> "AAxes": ... def indicate_inset_zoom(self, inset_ax: _Axes, **kwargs) -> "AAxes[Rectangle]": ... # type: ignore def secondary_xaxis( # type: ignore @@ -97,14 +132,14 @@ class AAxes(MplAxes, Generic[_T]): location: Literal["top", "bottom", "left", "right"] | float, *, functions=..., - **kwargs + **kwargs, ) -> "AAxes[SecondaryAxis]": ... def secondary_yaxis( # type: ignore self, location: Literal["top", "bottom", "left", "right"] | float, *, functions=..., - **kwargs + **kwargs, ) -> "AAxes[SecondaryAxis]": ... def text(self, x: float, y: float, s: str, fontdict: dict = ..., **kwargs) -> "AAxes[Text]": ... # type: ignore def annotate( # type: ignore @@ -116,21 +151,35 @@ class AAxes(MplAxes, Generic[_T]): textcoords: str | Artist | Transform | Callable = ..., arrowprops: dict = ..., annotation_clip: bool | None = ..., - **kwargs + **kwargs, ) -> "AAxes[Annotation]": ... + @overload def axhline( # type: ignore - self, y: float = 0, xmin: float = 0, xmax: float = 1, **kwargs + self, y: float, xmin: float = 0, xmax: float = 1, **kwargs ) -> "AAxes[Line2D]": ... + @overload + def axhline( # type: ignore + self, y: Iterable[float], xmin: float = 0, xmax: float = 1, **kwargs + ) -> "AAxes[List[Line2D]]": ... + def axhline( # type: ignore + self, y=0, xmin: float = 0, xmax: float = 1, **kwargs + ) -> "Union[AAxes[List[Line2D]], AAxes[Line2D]]": ... def axvline( # type: ignore self, x: float = ..., ymin: float = ..., ymax: float = ..., **kwargs ) -> "AAxes[Line2D]": ... + def axvline( # type: ignore + self, x: Iterable[float] = ..., ymin: float = ..., ymax: float = ..., **kwargs + ) -> "AAxes[List[Line2D]]": ... + def axvline( # type: ignore + self, x: float = ..., ymin: float = ..., ymax: float = ..., **kwargs + ) -> "Union[AAxes[List[Line2D]], AAxes[Line2D]]": ... def axline( # type: ignore self, xy1: tuple[float, float], xy2: tuple[float, float] = ..., *, slope: float = ..., - **kwargs + **kwargs, ) -> "AAxes[Line2D]": ... def axhspan( # type: ignore self, ymin: float, ymax: float, xmin: float = ..., xmax: float = ..., **kwargs @@ -146,7 +195,7 @@ class AAxes(MplAxes, Generic[_T]): colors: list[Color] = ..., linestyles: Literal["solid", "dashed", "dashdot", "dotted"] = ..., label: str = ..., - **kwargs + **kwargs, ) -> "AAxes[LineCollection]": ... def vlines( # type: ignore self, @@ -156,7 +205,7 @@ class AAxes(MplAxes, Generic[_T]): colors: list[Color] = ..., linestyles: Literal["solid", "dashed", "dashdot", "dotted"] = ..., label: str = ..., - **kwargs + **kwargs, ) -> "AAxes[LineCollection]": ... def eventplot( # type: ignore self, @@ -167,7 +216,7 @@ class AAxes(MplAxes, Generic[_T]): linewidths: float | ArrayLike = ..., colors: Color | list[Color] = ..., linestyles: str | tuple | list = ..., - **kwargs + **kwargs, ) -> "AAxes[list[EventCollection]]": ... def plot(self, *args, scalex=..., scaley=..., data=..., **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore def plot_date( # type: ignore @@ -178,7 +227,7 @@ class AAxes(MplAxes, Generic[_T]): tz: datetime.tzinfo = ..., xdate: bool = ..., ydate: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[list[Line2D]]": ... def loglog(self, *args, **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore def semilogx(self, *args, **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore @@ -192,7 +241,7 @@ class AAxes(MplAxes, Generic[_T]): detrend: Callable = ..., usevlines: bool = True, maxlags: int = 10, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, bool, int]]": ... def step( # type: ignore self, @@ -201,7 +250,7 @@ class AAxes(MplAxes, Generic[_T]): *args, where: Literal["pre", "post", "mid"] = ..., data=..., - **kwargs + **kwargs, ) -> "AAxes[list[Line2D]]": ... def bar( # type: ignore self, @@ -211,7 +260,7 @@ class AAxes(MplAxes, Generic[_T]): bottom: float | ArrayLike = ..., *, align: Literal["center", "edge"] = "center", - **kwargs + **kwargs, ) -> "AAxes[BarContainer]": ... def barh( # type: ignore self, @@ -221,7 +270,7 @@ class AAxes(MplAxes, Generic[_T]): left: float | ArrayLike = ..., *, align: Literal["center", "edge"] = "center", - **kwargs + **kwargs, ) -> "AAxes[BarContainer]": ... def bar_label( # type: ignore self, @@ -231,10 +280,13 @@ class AAxes(MplAxes, Generic[_T]): fmt: str = "%g", label_type: Literal["edge", "center"] = "edge", padding: float = 0, - **kwargs + **kwargs, ) -> "AAxes[list[Text]]": ... def broken_barh( # type: ignore - self, xranges: Sequence[tuple[float, float]], yrange: tuple[float, float], **kwargs + self, + xranges: Sequence[tuple[float, float]], + yrange: tuple[float, float], + **kwargs, ) -> "AAxes[BrokenBarHCollection]": ... def stem( # type: ignore self, @@ -245,7 +297,7 @@ class AAxes(MplAxes, Generic[_T]): bottom: float = 0, label: str | None = None, use_line_collection: bool = True, - orientation: str = "verical" + orientation: str = "verical", ) -> "AAxes[StemContainer]": ... def pie( # type: ignore self, @@ -266,7 +318,7 @@ class AAxes(MplAxes, Generic[_T]): frame: bool = False, rotatelabels: bool = False, *, - normalize: bool = True + normalize: bool = True, ) -> "AAxes[tuple[list[Wedge], list[Text], list[Text]]]": ... def errorbar( # type: ignore self, @@ -285,7 +337,7 @@ class AAxes(MplAxes, Generic[_T]): xuplims: bool = False, errorevery: int = 1, capthick: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[ErrorbarContainer]": ... def boxplot( # type: ignore self, @@ -356,7 +408,7 @@ class AAxes(MplAxes, Generic[_T]): *, edgecolors: Color = ..., plotnonfinite: bool = False, - **kwargs + **kwargs, ) -> "AAxes[PathCollection]": ... def hexbin( # type: ignore self, @@ -378,7 +430,7 @@ class AAxes(MplAxes, Generic[_T]): reduce_C_function=..., mincnt: int | None = None, marginals: bool = False, - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def arrow(self, x: float, y: float, dx: float, dy: float, **kwargs) -> "AAxes[FancyArrow]": ... # type: ignore def quiverkey( # type: ignore @@ -395,7 +447,7 @@ class AAxes(MplAxes, Generic[_T]): where: ArrayLike = ..., interpolate: bool = ..., step: Literal["pre", "post", "mid"] = ..., - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def fill_betweenx( # type: ignore self, @@ -405,7 +457,7 @@ class AAxes(MplAxes, Generic[_T]): where: ArrayLike = ..., step: Literal["pre", "post", "mid"] = ..., interpolate: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def imshow( # type: ignore self, @@ -425,7 +477,7 @@ class AAxes(MplAxes, Generic[_T]): filterrad: float = 4, resample: bool = ..., url: str = ..., - **kwargs + **kwargs, ) -> "AAxes[AxesImage]": ... def pcolor( # type: ignore self, @@ -436,7 +488,7 @@ class AAxes(MplAxes, Generic[_T]): cmap: str | Colormap = ..., vmin: float | None = None, vmax: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[Collection]": ... def pcolormesh( # type: ignore self, @@ -448,7 +500,7 @@ class AAxes(MplAxes, Generic[_T]): vmax: float | None = None, shading: Literal["flat", "nearest", "gouraud", "auto"] = ..., antialiased=..., - **kwargs + **kwargs, ) -> "AAxes[QuadMesh]": ... def pcolorfast( # type: ignore self, @@ -458,7 +510,7 @@ class AAxes(MplAxes, Generic[_T]): cmap: str | Colormap = ..., vmin: float | None = None, vmax: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[tuple[AxesImage, PcolorImage, QuadMesh]]": ... def contour(self, *args, **kwargs) -> "AAxes[QuadContourSet]": ... # type: ignore def contourf(self, *args, **kwargs) -> "AAxes[QuadContourSet]": ... # type: ignore @@ -481,7 +533,7 @@ class AAxes(MplAxes, Generic[_T]): color: Color | None = ..., label: str | None = ..., stacked: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[list[list[float]], list[float], BarContainer | list]]": ... @overload def hist( # type: ignore @@ -501,7 +553,7 @@ class AAxes(MplAxes, Generic[_T]): color: Color | None = None, label: str | None = None, stacked: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[list[float], list[float], BarContainer | list]]": ... def stairs( # type: ignore self, @@ -511,7 +563,7 @@ class AAxes(MplAxes, Generic[_T]): orientation: Literal["vertical", "horizontal"] = "vertical", baseline: float | ArrayLike | None = 0, fill: bool = False, - **kwargs + **kwargs, ) -> "AAxes[StepPatch]": ... def hist2d( # type: ignore self, @@ -523,8 +575,10 @@ class AAxes(MplAxes, Generic[_T]): weights=..., cmin: float | None = None, cmax: float | None = None, - **kwargs - ) -> "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]": ... + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... def psd( # type: ignore self, x: Sequence, @@ -538,7 +592,7 @@ class AAxes(MplAxes, Generic[_T]): sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., return_line: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def csd( # type: ignore self, @@ -554,7 +608,7 @@ class AAxes(MplAxes, Generic[_T]): sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., return_line: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def magnitude_spectrum( # type: ignore self, @@ -565,7 +619,7 @@ class AAxes(MplAxes, Generic[_T]): pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., scale: Literal["default", "linear", "dB"] = "linear", - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def angle_spectrum( # type: ignore self, @@ -575,7 +629,7 @@ class AAxes(MplAxes, Generic[_T]): window: Callable | np.ndarray = ..., pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def phase_spectrum( # type: ignore self, @@ -585,7 +639,7 @@ class AAxes(MplAxes, Generic[_T]): window: Callable | np.ndarray = ..., pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def cohere( # type: ignore self, @@ -600,7 +654,7 @@ class AAxes(MplAxes, Generic[_T]): pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray]]": ... def specgram( # type: ignore self, @@ -620,7 +674,7 @@ class AAxes(MplAxes, Generic[_T]): scale: Literal["default", "linear", "dB"] = "dB", vmin=..., vmax=..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, AxesImage]]": ... def spy( # type: ignore self, @@ -630,7 +684,7 @@ class AAxes(MplAxes, Generic[_T]): markersize=..., aspect: Literal["equal", "auto", None] | float = "equal", origin: Literal["upper", "lower"] = ..., - **kwargs + **kwargs, ) -> "AAxes[AxesImage | Line2D]": ... # type: ignore def matshow(self, Z: ArrayLike, **kwargs) -> "AAxes[AxesImage]": ... # type: ignore def violinplot( # type: ignore @@ -720,7 +774,7 @@ class AAxes(MplAxes, Generic[_T]): visible: bool | None = ..., which: Literal["major", "minor", "both"] = ..., axis: Literal["both", "x", "y"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def ticklabel_format( # type: ignore self, @@ -730,7 +784,7 @@ class AAxes(MplAxes, Generic[_T]): scilimits=..., useOffset: bool | float = ..., useLocale: bool = ..., - useMathText: bool = ... + useMathText: bool = ..., ) -> "AAxes[None]": ... def locator_params( # type: ignore self, axis: Literal["both", "x", "y"] = ..., tight: bool | None = ..., **kwargs @@ -745,7 +799,7 @@ class AAxes(MplAxes, Generic[_T]): labelpad: float = ..., *, loc: Literal["left", "center", "right"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def invert_xaxis(self) -> "AAxes[None]": ... # type: ignore def set_xbound(self, lower: float | None = ..., upper: float | None = ...) -> "AAxes[None]": ... # type: ignore @@ -757,7 +811,7 @@ class AAxes(MplAxes, Generic[_T]): emit: bool = ..., auto: bool | None = ..., xmin: float = ..., - xmax: float = ... + xmax: float = ..., ) -> "AAxes[tuple[float, float]]": ... @overload def set_xlim( # type: ignore @@ -768,7 +822,7 @@ class AAxes(MplAxes, Generic[_T]): auto: bool | None = ..., *, xmin: float = ..., - xmax: float = ... + xmax: float = ..., ) -> "AAxes[tuple[float, float]]": ... def set_xscale(self, value: ..., **kwargs) -> "AAxes[None]": ... # type: ignore def set_ylabel( # type: ignore @@ -778,7 +832,7 @@ class AAxes(MplAxes, Generic[_T]): labelpad: float = ..., *, loc: Literal["bottom", "center", "top"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def invert_yaxis(self) -> "AAxes[None]": ... # type: ignore def set_ybound(self, lower: float | None = ..., upper: float | None = ...) -> "AAxes[None]": ... # type: ignore @@ -790,7 +844,7 @@ class AAxes(MplAxes, Generic[_T]): auto: bool | None = ..., *, ymin: float = ..., - ymax: float = ... + ymax: float = ..., ) -> "AAxes[None]": ... def set_yscale( # type: ignore self, value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, **kwargs @@ -819,8 +873,27 @@ class AAxes(MplAxes, Generic[_T]): pad: float = ..., h_pad: float = ..., w_pad: float = ..., - rect: Sequence[float] = ... + rect: Sequence[float] = ..., ) -> _S: ... def __add__(self, other) -> "AxesList": ... - def set_xticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... - def set_yticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... + def set_xticks( + self, ticks: ArrayLike, labels: ArrayLike | None = None + ) -> "AAxes[None]": ... + def set_yticks( + self, ticks: ArrayLike, labels: ArrayLike | None = None + ) -> "AAxes[None]": ... + def colorbar(self: _Axes, *args, **kwargs) -> _Axes: ... + def hist_z( # type: ignore + self, + z, + bins: None | int | ArrayLike = ..., + range=..., + density: bool = False, + weights=..., + cmin: float | None = None, + cmax: float | None = None, + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... + def classic_return(self): ... diff --git a/aplot/core/axes_list.py b/aplot/core/axes_list.py index dd5485c..9307eac 100644 --- a/aplot/core/axes_list.py +++ b/aplot/core/axes_list.py @@ -3,7 +3,7 @@ import numpy as np from .axes_class import AAxes -from .utils import filter_set_kwargs, pop_from_dict +from .utils import filter_set_kwargs, pop_from_dict, get_edge_points # from matplotlib import pyplot as plt @@ -112,6 +112,8 @@ def plot_z_2d( else: raise ValueError("Plot_format should be either bode or real_imag") + x = get_edge_points(x) + y = get_edge_points(y) kwargs_without_xlabel = pop_from_dict(kwargs, "xlabel") self[0].pcolorfast(x=x, y=y, data=data1, **kwargs_without_xlabel) self[1].pcolorfast(x=x, y=y, data=data2, **kwargs) @@ -204,4 +206,24 @@ def __getitem__(self, key: _t.Union[int, _t.Tuple[int, ...]]): # type: ignore if not isinstance(res, AxesList): return AxesList(res) return res - return super().__getitem__(key) + res = super().__getitem__(key) + if not isinstance(res, (AAxes, AxesList)): + return AxesList(res) + return res + + def flat(self): + res = [] + for ax in self: + if isinstance(ax, AxesList): + res.extend(ax.flat()) + else: + res.append(ax) + return AxesList(res) + + def hist_z(self, z, **kwargs): + if len(z) == len(self): + for i, ax in enumerate(self): + ax.hist_z(z[i], **kwargs) + return self + self.map(lambda ax: ax.hist_z(z, **kwargs)) + return self diff --git a/aplot/core/axes_list.pyi b/aplot/core/axes_list.pyi index b0cdd07..eb4c003 100644 --- a/aplot/core/axes_list.pyi +++ b/aplot/core/axes_list.pyi @@ -134,7 +134,12 @@ class AxesList(List[_T]): **kwargs, ) -> _S: ... def axhspan( # type: ignore - self: _S, ymin: float, ymax: float, xmin: float = ..., xmax: float = ..., **kwargs + self: _S, + ymin: float, + ymax: float, + xmin: float = ..., + xmax: float = ..., + **kwargs, ) -> _S: ... def axvspan( # type: ignore self: _S, xmin: float, xmax: float, ymin: float = 0, ymax: float = 1, **kwargs @@ -235,7 +240,10 @@ class AxesList(List[_T]): **kwargs, ) -> _S: ... def broken_barh( # type: ignore - self: _S, xranges: Sequence[tuple[float, float]], yrange: tuple[float, float], **kwargs + self: _S, + xranges: Sequence[tuple[float, float]], + yrange: tuple[float, float], + **kwargs, ) -> _S: ... def stem( # type: ignore self: _S, @@ -734,7 +742,10 @@ class AxesList(List[_T]): useMathText: bool = ..., ) -> _S: ... def locator_params( # type: ignore - self: _S, axis: Literal["both", "x", "y"] = ..., tight: bool | None = ..., **kwargs + self: _S, + axis: Literal["both", "x", "y"] = ..., + tight: bool | None = ..., + **kwargs, ) -> _S: ... def tick_params(self: _S, axis: Literal["x", "y", "both"] = ..., **kwargs) -> _S: ... # type: ignore def set_axis_off(self: _S) -> _S: ... # type: ignore @@ -794,7 +805,9 @@ class AxesList(List[_T]): ymax: float = ..., ) -> _S: ... def set_yscale( # type: ignore - self: _S, value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, **kwargs + self: _S, + value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, + **kwargs, ) -> _S: ... def minorticks_on(self: _S) -> _S: ... # type: ignore def minorticks_off(self: _S) -> _S: ... # type: ignore @@ -841,7 +854,21 @@ class AxesList(List[_T]): ) -> _S: ... def map(self: _S, func: Callable[[AAxes], Any]) -> _S: ... def suptitle(self: _S, title: str) -> _S: ... - def __getitem__( + def __getitem__( # type: ignore self, key: Union[int, Tuple[Union[int, slice], ...], slice], ) -> _T: ... # type: ignore + def flat(self) -> "AxesList[AAxes]": ... + def hist_z( # type: ignore + self, + z, + bins: None | int | ArrayLike = ..., + range=..., + density: bool = False, + weights=..., + cmin: float | None = None, + cmax: float | None = None, + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index c283558..158933f 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -1,19 +1,34 @@ -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, overload +from copy import copy +from typing import ( + TYPE_CHECKING, + Any, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import matplotlib.pyplot as plt +import numpy as np from matplotlib.figure import Figure as MplFigure if TYPE_CHECKING: - from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D from matplotlib.projections.polar import PolarAxes as MplPolarAxes + from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D from .axes_class import AAxes +from .axes_list import AxesList _T = TypeVar("_T") +_F = TypeVar("_F", bound="AFigure") +_T = TypeVar("_T") -class AFigure(MplFigure): +class AFigure(MplFigure): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -25,7 +40,7 @@ def add_subplot(self, *args, **kwargs) -> AAxes: # type: ignore # Ensuring that the custom axes class is used if "projection" not in kwargs and "polar" not in kwargs: kwargs.update({"axes_class": AAxes}) - return super().add_subplot(*args, **kwargs) + return super().add_subplot(*args, **kwargs) # type: ignore def savefig(self, fname: Any, *, transparent=None, **kwargs): # type: ignore super().savefig(fname, transparent=transparent, **kwargs) @@ -53,3 +68,176 @@ def add_axes(self, *args, **kwargs): # type: ignore def show(self): # type: ignore plt.show(self) + return self + + def close(self): # type: ignore + plt.close(self) + return self + + @property + def axes(self) -> "AxesList[AAxes]": # type: ignore + return AxesList(self._axstack.as_list()) # type: ignore + + def label_axes( + self: _F, + labels: Union[ # type: ignore + Literal["vertical", "horizontal"], List[Optional[Union[str, int]]] + ] = "horizontal", + *, + axes: Optional["AxesList"] = None, + label_position: Union[ + Tuple[float, float], List[Optional[Tuple[float, float]]] + ] = ( + 0.02, + 0.95, + ), + fontsize: Optional[Union[int, float, List[Union[float, int]]]] = None, + capitalize: bool = False, + label_titles: Optional[List[str]] = None, + **kwargs, + ) -> _F: + """Label the axes of the figure. + + Args: + labels (Union[Literal["vertical", "horizontal"], List[str]], optional): + - "vertical": Label the axes vertically first, then horizontally. + - "horizontal": Label the axes horizontally first, then vertically. + - List[str | int | None]: List of labels to use for each axes. + if None, the axes will not be labeled. + if int, the axes will be labeled with the corresponding alphabet. + Defaults to "horizontal". + axes (Optional["AxesList"], optional): _description_. Defaults to None. + label_position (Union[Tuple[float, float], List[Tuple[float, float]]], optional): + The (x, y) position of the label. Defaults to (0.02, 0.95). + If a list of tuples is provided, each tuple will be used for each axes. + fontsize (Optional[Union[int, float, List[Union[float, int]]]], optional): + Fontsize of the labels. Defaults to None. + If a list of font sizes is provided, each font size will be used for each axes. + capitalize (bool, optional): Capitalize the labels. Defaults to False. + **kwargs: Additional keyword arguments to pass to the text function. + + Raises: + NotImplementedError: Raised if labels is set to "vertical" + + Returns: + _F: Figure itself + + Example: + ``` + import numpy as np + import aplot as ap + + x = np.linspace(0, 10, 100) + y = np.sin(x) + ax = ap.axs(2, 2, figsize=(10, 8)) + ax.plot(x, y) + + ax.fig.label_axes().show() + ``` + """ + if axes is None: + axes = self.axes + axes_list = filter_secondary_axes(axes.flat()) + + if isinstance(labels, list): + if len(labels) != len(axes_list): + raise ValueError( + "Length of labels should be equal to the number of axes" + ) + + labels: List[Optional[str]] = [ + ( + ( + f"({chr(65+((int(i)-1) % len(axes_list))).lower()})" + if isinstance(i, int) + else str(i) + ) + if i is not None + else None + ) + for i in labels + ] + if capitalize: + labels = [(label.upper() if label else label) for label in labels] + + if label_titles is not None: + for i, label in enumerate(label_titles): + if labels[i] is not None: + labels[i] = f"{labels[i]} {label}" + + elif labels == "horizontal": + labels = [f"({chr(65+i)})".lower() for i in range(len(axes_list))] + if capitalize: + labels = [(label.upper() if label else label) for label in labels] + if label_titles is not None: + for i, label in enumerate(label_titles): + if labels[i] is not None: + labels[i] = f"{labels[i]} {label}" + + elif labels == "vertical": + raise NotImplementedError("Vertical labels not yet implemented") + + if len(label_position) == 2 and isinstance(label_position[0], (int, float)): + label_position_each = False + else: + label_position_each = True + for i, (ax, label) in enumerate(zip(axes_list, labels)): + if label is None or ( + label_position_each is True and label_position[i] is None + ): + continue + text_kwargs = copy(kwargs) + x_pos: float = ( + label_position[i][0] if label_position_each else label_position[0] + ) # type: ignore + y_pos: float = ( + label_position[i][1] if label_position_each else label_position[1] + ) # type: ignore + if fontsize is not None: + fs = fontsize[i] if isinstance(fontsize, (list, tuple)) else fontsize + text_kwargs.setdefault("fontsize", fs) + text_kwargs.setdefault("transform", ax.transAxes) + text_kwargs.setdefault("va", "top") + + getattr(ax, "text2D", ax.text)( + x_pos, + y_pos, + label, + **text_kwargs, + ) + return self + + +def detect_minor_axes(ax: "AAxes") -> bool: + """Detect if the axes are minor axes. + + Args: + axes (AxesList[AAxes]): List of axes + + Returns: + bool: True if the axes are minor axes + """ + if hasattr(ax, "_colorbar"): + return True + return False + + +def filter_secondary_axes(axes: "List[AAxes]") -> "List[AAxes]": + """Detect and remove if the axes are secondary axes. + + Args: + axes (List[AAxes]): List of axes + + Returns: + List[AAxes]: List of secondary axes + """ + axes = [ax for ax in axes if not detect_minor_axes(ax)] + axes_list: "List[AAxes]" = [] + for ax1 in axes: + for ax2 in axes_list: + if np.isclose(ax1.get_position().bounds, ax2.get_position().bounds).all(): + break + else: + axes_list.append(ax1) + + return axes_list diff --git a/aplot/core/utils.py b/aplot/core/utils.py index 7872d90..c79a251 100644 --- a/aplot/core/utils.py +++ b/aplot/core/utils.py @@ -1,5 +1,7 @@ import typing as _t +import numpy as np + from .typing import NoneType if _t.TYPE_CHECKING: @@ -95,3 +97,16 @@ def pop_from_dict(data, keys: _t.Union[str, _t.Tuple[str, ...]]): if key in main: main.pop(key) return main + + +def get_center_points(x): + return x[:-1] + np.diff(x) / 2 + + +def get_edge_points(m): + m = np.asarray(m) + edges = np.empty(m.size + 1) + edges[1:-1] = (m[:-1] + m[1:]) / 2 # inner edges + edges[0] = m[0] - (m[1] - m[0]) / 2 # extrapolate left edge + edges[-1] = m[-1] + (m[-1] - m[-2]) / 2 # extrapolate right edge + return edges diff --git a/aplot/styles/__init__.py b/aplot/styles/__init__.py index 16fac4c..3c409df 100644 --- a/aplot/styles/__init__.py +++ b/aplot/styles/__init__.py @@ -1,65 +1,72 @@ from ..code_utils import LabelDict -colors = [ - "#0066cc", - "#ffcc00", - "#ff7400", - "#962fbf", - "#8b5a2b", - "#d62976", - "#b8a7ea", - "#ed5555", - "#1da2d8", -] +# colors = [ +# "#0066cc", +# "#ffcc00", +# "#ff7400", +# "#962fbf", +# "#8b5a2b", +# "#d62976", +# "#b8a7ea", +# "#ed5555", +# "#1da2d8", +# ] + DATA = LabelDict( { "markerfacecolor": "none", - "markeredgecolor": colors[0], - "marker": "h", + # "markeredgecolor": colors[0], + "marker": "o", "linestyle": "none", } ) -FIT = LabelDict({"color": colors[1], "linewidth": 2, "label": "fit"}) -GUESS = LabelDict({"color": colors[2], "linewidth": 2, "label": "guess", "alpha": 0.6}) +FIT = LabelDict({"linewidth": 2, "label": "fit"}) # "color": colors[1], +GUESS = LabelDict( + {"linewidth": 2, "label": "guess", "alpha": 0.6} +) # "color": colors[2], -VOLT_TIME = LabelDict( - { - "xlabel": "Voltage, V", - "ylabel": r"Time, $\mu$s", - } -) -IQquadrature = LabelDict( - { - "xlabel": "I quadrature", - "ylabel": "Q quadrature", - "aspect": "equal", - } -) +# VOLT_TIME = LabelDict( +# { +# "xlabel": "Voltage, V", +# "ylabel": r"Time, $\mu$s", +# } +# ) +# IQquadrature = LabelDict( +# { +# "xlabel": "I quadrature", +# "ylabel": "Q quadrature", +# "aspect": "equal", +# } +# ) -DRIVE_FREQ = "Drive frequency (Hz)" -DRIVE_FREQ_GHz = "Drive frequency (GHz)" -READOUT_FREQ = "Readout IF frequency (Hz)" -READOUT_PHASE = "Readout phase (rad)" -LEFT = "Left" -BIAS_VOLTAGE = "Bias voltage (V)" +# DRIVE_FREQ = "Drive frequency (Hz)" +# DRIVE_FREQ_GHz = "Drive frequency (GHz)" +# READOUT_FREQ = "Readout IF frequency (Hz)" +# READOUT_PHASE = "Readout phase (rad)" +# LEFT = "Left" +# BIAS_VOLTAGE = "Bias voltage (V)" -TWO_TONE = LabelDict({"xlabel": BIAS_VOLTAGE, "ylabel": DRIVE_FREQ}, GHz={"ylabel": DRIVE_FREQ_GHz}) -CHEVRON = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) -AMPLITUDE_TIME = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) -RAMSEY = LabelDict( - { - "label": "ramsey", - "xlabel": "Pulse duration (ns)", - "ylabel": "Readout quadrature", - } -) -REIM = LabelDict( - { - "xlabel": "Re(z)", - "ylabel": "Im(z)", - } -) +# TWO_TONE = LabelDict( +# {"xlabel": BIAS_VOLTAGE, "ylabel": DRIVE_FREQ}, GHz={"ylabel": DRIVE_FREQ_GHz} +# ) +# CHEVRON = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) +# AMPLITUDE_TIME = LabelDict( +# {"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"} +# ) +# RAMSEY = LabelDict( +# { +# "label": "ramsey", +# "xlabel": "Pulse duration (ns)", +# "ylabel": "Readout quadrature", +# } +# ) +# REIM = LabelDict( +# { +# "xlabel": "Re(z)", +# "ylabel": "Im(z)", +# } +# ) aspect_equal = LabelDict({"aspect": "equal"}) diff --git a/requirements.txt b/requirements.txt index 4a2a08b..806f221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ numpy -scipy matplotlib \ No newline at end of file diff --git a/setup.py b/setup.py index 72659f7..236a935 100644 --- a/setup.py +++ b/setup.py @@ -35,5 +35,5 @@ def get_version() -> str: "Operating System :: OS Independent", ], python_requires=">=3.8", - install_requires=["numpy", "scipy", "matplotlib"], + install_requires=["numpy", "matplotlib"], )