Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .github/workflows/python-ci-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ jobs:
run: |
echo "Main branch version: ${{ steps.get_main_version.outputs.main_version }}"
echo "PR branch version: ${{ steps.get_pr_version.outputs.pr_version }}"
if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then
echo "Error: Version is the same as on the main branch"
exit 1
else
echo "Ok: Version is different from the main branch"
if [ "${{ github.event_name }}" == "pull_request" ]; then
if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then
echo "Error: Version is the same as on the main branch"
exit 1
else
echo "Ok: Version is different from the main branch"
fi
fi
3 changes: 0 additions & 3 deletions .github/workflows/python-publish.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

name: Publish-Main-PyPi

env:
Expand Down
2 changes: 1 addition & 1 deletion aplot/__config__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.2.0"
__version__ = "0.2.3"
3 changes: 3 additions & 0 deletions aplot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# flake8: noqa: F401

import matplotlib.patches as patches

from . import analysis, styles
from .__config__ import __version__
from .core import ax, axs, close, figure, figure_class, show, subplot, subplots
from .core.axes_class import AAxes as Axes
from .core.axes_list import AxesList
from .core.figure_class import AFigure as Figure

s = styles
Expand Down
7 changes: 5 additions & 2 deletions aplot/analysis/array_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing as _t

import numpy as np
import scipy

ArrayLike = _t.Union[np.ndarray, _t.List]

Expand Down Expand Up @@ -29,6 +28,8 @@ def argmin2d(
Tuple[int, int]: index_y, index_x, i.e. the min value is d[index_y, index_x]
"""
if filter_ and filter_ > 1:
import scipy

d = scipy.ndimage.uniform_filter(d, size=3, mode="nearest")

if x_mask is not None:
Expand Down Expand Up @@ -133,7 +134,9 @@ def array_from_span(
return res


def get_z(I: np.ndarray, Q: np.ndarray) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741
def get_z(
I: np.ndarray, Q: np.ndarray
) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741
min_len = min(len(I), len(Q))
return I[:min_len] + 1j * Q[:min_len]

Expand Down
9 changes: 7 additions & 2 deletions aplot/analysis/signal_analysis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing as _t

import numpy as np
import scipy


def find_h_symmetry_axis(data: np.ndarray) -> int:
Expand All @@ -13,6 +12,8 @@ def find_h_symmetry_axis(data: np.ndarray) -> int:
Returns:
(int): x index of the symmetry axis.
"""
import scipy

data = (data - np.mean(data)) / np.std(data)
# corr = scipy.signal.fftconvolve(
# data[:, : len(data[0]) // 2], data[:, ::-1], mode="full"
Expand All @@ -22,11 +23,15 @@ def find_h_symmetry_axis(data: np.ndarray) -> int:


def remove_background(data: np.ndarray, convolve_len: _t.Optional[int] = None):
import scipy

if convolve_len is None:
convolve_len = min(50, len(data) // 15)
data = (
data
- scipy.signal.convolve2d(data, np.ones((convolve_len, 1)), mode="same", boundary="symm")
- scipy.signal.convolve2d(
data, np.ones((convolve_len, 1)), mode="same", boundary="symm"
)
/ convolve_len
)
return data - data.mean(axis=1)[:, np.newaxis]
99 changes: 77 additions & 22 deletions aplot/core/axes_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

_T = _t.TypeVar("_T")
_R = _t.TypeVar("_R")
if _t.TYPE_CHECKING:
from .figure_class import AFigure

Expand Down Expand Up @@ -125,12 +126,29 @@
FILTER_KWARGS = {"hist2d", QuadMesh}


class ClassicReturnAxis:
def __init__(self, axes: "AAxes"):
self.axes = axes
self._previous_state = False

def __enter__(self):
self._previous_state = self.axes._classical_return
self.axes._classical_return = True
return self.axes

def __exit__(self, exc_type, exc_value, traceback):
self.axes._classical_return = self._previous_state
if exc_type is not None:
raise


class AAxes(
MplAxes,
_t.Generic[_T],
):
name = "AAxis" # Give a name for the matplotlib registry
_last_result = None
_classical_return = False
# _fit_result: FitResult | None = None
# __all__ = MplAxes.__all__ + ["fit", "last_result", "fit_result", "res", "set"]
# __dict__ = MplAxes.__dict__ ("fit", "last_result", "fit_result", "res", "set")
Expand Down Expand Up @@ -182,7 +200,7 @@ def __getattribute__(self, name: str):

def wrapper(*args, **kwargs):
result = func(*args, **kwargs)
if isinstance(result, (MplAxes, AAxes)):
if isinstance(result, (MplAxes, AAxes)) or self._classical_return:
return result
self._last_result = result
return self
Expand All @@ -209,34 +227,33 @@ def set( # type: ignore
"ylabel": ylabel,
}
)
super().set(**filter_none_types(kwargs))
return self
return super().set(**filter_none_types(kwargs))
# return self

def hist2d(
def hist2d( # type: ignore
self,
x,
y=None,
bins=10,
range=None, # pylint: disable=redefined-builtin
density=False,
weights=False,
cmin=None,
cmax=None,
*args,
**kwargs,
):
if y is None:
x = np.array(x)
x = x[:, 0]
y = x[:, 1]
return super().hist2d(x, y, bins, range, density, weights, cmin, cmax, **kwargs)
y = x[..., 1]
x = x[..., 0]
return super().hist2d(x, y, *args, **kwargs)

def hist(self, *args, **kwargs):
with ClassicReturnAxis(self):
return super().hist(*args, **kwargs)

def z_parametric(self, z, **kwargs):
self.plot(np.real(z), np.imag(z), **kwargs)
return self
return self.plot(np.real(z), np.imag(z), **kwargs)
# return self

def z_historograms(self, z, **kwargs):
self.hist2d(np.real(z), np.imag(z), **kwargs)
return self
def hist_z(self, z, **kwargs):
return self.hist2d(np.real(z), np.imag(z), **kwargs)
# return self

def imshow( # type: ignore
self,
Expand Down Expand Up @@ -298,10 +315,11 @@ def imshow( # type: ignore
raise ValueError("The figure is None cannot add colorbar")
cbar = fig.colorbar(im, cax=cax, orientation="vertical")
cbar.ax.set_ylabel(kwargs.get("bar_label", ""))
cbar.ax.set_rasterized(kwargs.get("bar_rasterized", kwargs.get("rasterized", False)))
else:
cbar = None

return self
return im
# return self

def pcolorfast( # type: ignore
self,
Expand Down Expand Up @@ -344,7 +362,8 @@ def pcolorfast( # type: ignore
if colorbar:
cbar = fig.colorbar(im, cax=cax, orientation="vertical")
cbar.ax.set_ylabel(kwargs.get("bar_label", ""))
return self
return im
# return self

def autoaxis(self, level: int = 0, func_name="plot") -> "AAxes":
variables = get_auto_args(level, func_name)
Expand All @@ -357,7 +376,14 @@ def tight_layout(self, *, pad=1.08, h_pad=None, w_pad=None, rect=None):
self.figure.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) # type: ignore
return self

def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=None, **kwargs):
def plot(
self,
*args,
keep_xlims: bool = False,
keep_ylims: bool = False,
axes=None,
**kwargs,
):
del axes
xlims = self.get_xlim() if keep_xlims else None
ylims = self.get_ylim() if keep_ylims else None
Expand All @@ -368,9 +394,38 @@ def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=N
self.set_ylim(*ylims)
return res

def axhline(self, y=0, xmin=0, xmax=1, **kwargs) -> "AAxes": # type: ignore
if isinstance(y, _t.Iterable):
return self.update_result(
[self.axhline(y_, xmin=xmin, xmax=xmax, **kwargs).res for y_ in y]
)
return self.update_result(super().axhline(y, xmin=xmin, xmax=xmax, **kwargs))

def axvline(self, x=0, ymin=0, ymax=1, **kwargs) -> "AAxes": # type: ignore
if isinstance(x, _t.Iterable):
return self.update_result(
[self.axvline(x_, ymin=ymin, ymax=ymax, **kwargs).res for x_ in x]
)
return self.update_result(super().axvline(x, ymin=ymin, ymax=ymax, **kwargs))

def __add__(self, other):
from .axes_list import AxesList

if isinstance(other, list):
return AxesList([self] + other) # type: ignore
return AxesList([self, other]) # type: ignore

def update_result(self, result: _R) -> "AAxes[_R]":
self._last_result = result
return self # type: ignore

def colorbar(self, label: _t.Optional[str] = None, *args, **kwargs):
c = self.res
assert c is not None
cbar = self.fig.colorbar(c, ax=self)
if label is not None:
cbar.set_label(label)
return self

def classic_return(self):
return ClassicReturnAxis(self)
Loading