From a671783e42c81d95b7c290e25e3c5411445430e2 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 14:57:41 +0200 Subject: [PATCH 01/15] fix typing for py3.7 --- aplot/core/axes_class.py | 12 +++++++----- aplot/core/axes_class.pyi | 3 +++ aplot/core/front.pyi | 12 ++++++------ aplot/core/utils.py | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/aplot/core/axes_class.py b/aplot/core/axes_class.py index 275eade..c90530d 100644 --- a/aplot/core/axes_class.py +++ b/aplot/core/axes_class.py @@ -193,10 +193,10 @@ def wrapper(*args, **kwargs): def set( # type: ignore self, - title: str | None | NoneType = noneType, - xlabel: str | None | NoneType = noneType, - ylabel: str | None | NoneType = noneType, - grid: bool | None = None, + title: _t.Optional[_t.Union[str, NoneType]] = noneType, + xlabel: _t.Optional[_t.Union[str, NoneType]] = noneType, + ylabel: _t.Optional[_t.Union[str, NoneType]] = noneType, + grid: _t.Optional[bool] = None, **kwargs, ): kwargs = filter_set_kwargs(AAxes, **kwargs) @@ -355,8 +355,10 @@ def autoaxis(self, level: int = 0, func_name="plot") -> "AAxes": def tight_layout(self, *, pad=1.08, h_pad=None, w_pad=None, rect=None): self.figure.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) # type: ignore + return self - def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, **kwargs): + def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=None, **kwargs): + del axes xlims = self.get_xlim() if keep_xlims else None ylims = self.get_ylim() if keep_ylims else None res = super().plot(*args, **kwargs) diff --git a/aplot/core/axes_class.pyi b/aplot/core/axes_class.pyi index 66e7883..aa02cd7 100644 --- a/aplot/core/axes_class.pyi +++ b/aplot/core/axes_class.pyi @@ -64,6 +64,7 @@ class AAxes(MplAxes, Generic[_T]): cursor_to_use: Cursors = ... figure: AFigure = ... # type: ignore fig: AFigure = ... + patch: Rectangle = ... def set_title( # type: ignore self, label: str, @@ -821,3 +822,5 @@ class AAxes(MplAxes, Generic[_T]): rect: Sequence[float] = ... ) -> _S: ... def __add__(self, other) -> "AxesList": ... + def set_xticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... + def set_yticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... diff --git a/aplot/core/front.pyi b/aplot/core/front.pyi index 27afa16..d4bebd1 100644 --- a/aplot/core/front.pyi +++ b/aplot/core/front.pyi @@ -18,7 +18,7 @@ def figure( FigureClass=AFigure, clear: bool = False, **kwargs -): ... +) -> AFigure: ... @overload def subplots() -> _t.Tuple[AFigure, AAxes[None]]: ... @overload @@ -82,7 +82,7 @@ def subplots( **fig_kw ): ... @overload -def axs() -> AAxes[None]: ... +def axs(**kwargs) -> AAxes[None]: ... @overload def axs(nrows: _t.Literal[1], ncols: _t.Literal[1], **kwargs) -> AAxes[None]: ... # type: ignore @overload @@ -92,13 +92,13 @@ def axs(nrows: int, ncols: _t.Literal[1], **kwargs) -> AxesList[AAxes[None]]: .. @overload def axs(nrows: int, ncols: int, **kwargs) -> AxesList[AxesList[AAxes[None]]]: ... @overload -def axs(nrows: _S) -> _S: ... +def axs(nrows: _S, **kwargs) -> _S: ... @overload -def axs(nrows: List[AAxes]) -> AxesList[AAxes[None]]: ... +def axs(nrows: List[AAxes], **kwargs) -> AxesList[AAxes[None]]: ... @overload -def axs(nrows: List[List[AAxes]]) -> AxesList[AxesList[AAxes[None]]]: ... +def axs(nrows: List[List[AAxes]], **kwargs) -> AxesList[AxesList[AAxes[None]]]: ... @overload -def axs(nrows: _t.Union[List[int], _t.Tuple[int]]) -> AxesList[AAxes[None]]: ... +def axs(nrows: _t.Union[List[int], _t.Tuple[int]], **kwargs) -> AxesList[AAxes[None]]: ... def axs(nrows: int = 1, ncols: int = 1, **kwargs): ... def subplot(*args, **kwargs) -> AAxes: ... def ax(*args) -> AAxes: ... diff --git a/aplot/core/utils.py b/aplot/core/utils.py index f0a95e0..7872d90 100644 --- a/aplot/core/utils.py +++ b/aplot/core/utils.py @@ -31,7 +31,7 @@ def filter_none_types(kwargs: dict) -> dict: return {k: v for k, v in kwargs.items() if not isinstance(v, NoneType)} -def filter_none(data: dict | None = None, **kwargs) -> dict: +def filter_none(data: _t.Optional[dict] = None, **kwargs) -> dict: if data is not None: kwargs.update(data) return {k: v for k, v in kwargs.items() if v is not None} From d5bb539c76b2e0d11cfd23d60e9e9af1e7d272a0 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 17:57:16 +0200 Subject: [PATCH 02/15] update workflow names --- .github/workflows/python-flake8.yml | 2 +- .github/workflows/python-tests.yml | 2 +- aplot/core/front.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/python-flake8.yml b/.github/workflows/python-flake8.yml index 598655c..e48326b 100644 --- a/.github/workflows/python-flake8.yml +++ b/.github/workflows/python-flake8.yml @@ -13,7 +13,7 @@ permissions: contents: read jobs: - build: + build-flake8: runs-on: ubuntu-latest steps: diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml index 91cc9d0..c262767 100644 --- a/.github/workflows/python-tests.yml +++ b/.github/workflows/python-tests.yml @@ -13,7 +13,7 @@ permissions: contents: read jobs: - build: + run-tests: runs-on: ubuntu-latest env: TestingOn: GitHub diff --git a/aplot/core/front.py b/aplot/core/front.py index 03fc411..c7eef70 100644 --- a/aplot/core/front.py +++ b/aplot/core/front.py @@ -73,7 +73,7 @@ def subplots( def axs( - nrows: int | AAxes | _t.List[AAxes] | _t.List[int] | "AxesList" = 1, + nrows: _t.Union[int, AAxes, _t.List[AAxes], _t.List[int], "AxesList"] = 1, ncols: int = 1, /, **kwargs, From a8eebc83bc1b7e3f8a0185b23ff8ac555d8ae452 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 18:25:47 +0200 Subject: [PATCH 03/15] update workflow to have less files --- .github/workflows/python-ci-main.yml | 96 ++++++++++++++++++++++++++ .github/workflows/python-flake8.yml | 32 --------- .github/workflows/python-publish.yml | 40 +++++++++-- .github/workflows/python-test-pypi.yml | 44 ------------ .github/workflows/python-tests.yml | 44 ------------ 5 files changed, 132 insertions(+), 124 deletions(-) create mode 100644 .github/workflows/python-ci-main.yml delete mode 100644 .github/workflows/python-flake8.yml delete mode 100644 .github/workflows/python-test-pypi.yml delete mode 100644 .github/workflows/python-tests.yml diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml new file mode 100644 index 0000000..309b769 --- /dev/null +++ b/.github/workflows/python-ci-main.yml @@ -0,0 +1,96 @@ +# This workflow will install Python dependencies, run tests and lint with a single version of Python +# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions + +name: CI on main + +env: + package-name: aplot + +on: + push: + branches: ['main'] + pull_request: + branches: ['main'] + +permissions: + contents: read + +jobs: + run-tests: + runs-on: ubuntu-latest + env: + TestingOn: GitHub + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v3 + with: + python-version: '3.8' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pytest pytest-cov + pip install -e . + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Test with pytest + run: | + pytest --cov=${{ env.package-name }} --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + with: + token: ${{secrets.CODECOV_TOKEN}} + file: ./coverage.xml + flags: unittests + name: unit-tests-coverage + fail_ci_if_error: false + + run-flake8: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + - name: Set up Python 3.10 + uses: actions/setup-python@v3 + with: + python-version: '3.10' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install flake8 + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Lint with flake8 + run: | + flake8 . --count --max-complexity=10 --max-line-length=127 --ignore="E731, E741, E203, E265, E226, C901, W504, W503, E704" + + version-check: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Get main branch version + id: get_main_version + run: | + git fetch origin main + main_version=$(git show origin/main: ${{env.package-name}}/__config__.py | grep '__version__' | cut -d'"' -f2) + echo "Main branch version: $main_version" + echo "::set-output name=main_version::$main_version" + + - name: Get PR branch version + id: get_pr_version + run: | + pr_version=$(grep '__version__' ${{env.package-name}}/__config__.py | cut -d'"' -f2) + echo "PR branch version: $pr_version" + echo "::set-output name=pr_version::$pr_version" + + - name: Compare versions + run: | + if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then + echo "Error: Version is the same as on the main branch" + exit 1 + else + echo "Ok: Version is different from the main branch" + fi diff --git a/.github/workflows/python-flake8.yml b/.github/workflows/python-flake8.yml deleted file mode 100644 index e48326b..0000000 --- a/.github/workflows/python-flake8.yml +++ /dev/null @@ -1,32 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Flake8 - -on: - push: - branches: ['main'] - pull_request: - branches: ['main'] - -permissions: - contents: read - -jobs: - build-flake8: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.10 - uses: actions/setup-python@v3 - with: - python-version: '3.10' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install flake8 - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Lint with flake8 - run: | - flake8 . --count --max-complexity=10 --max-line-length=127 --ignore="E731, E741, E203, E265, E226, C901, W504, W503, E704" diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 65e5e7d..d15daef 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,14 +1,17 @@ # This workflow will upload a Python Package using Twine when a release is created # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries -name: UploadPyPi +name: Publish-Main-PyPi + +env: + package-name: aplot on: - release: - types: [published] + push: + branches: ['main'] jobs: - publish: + publish-pypi: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -27,3 +30,32 @@ jobs: run: | python setup.py sdist bdist_wheel twine upload dist/* + + # Test package that was just published + test-pypi: + needs: publish-pypi + runs-on: ubuntu-latest + env: + TestingOn: GitHub + + strategy: + fail-fast: false + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install flake8 pytest + python -m pip install ${{ env.package-name }} + + - name: Test with pytest + run: | + rm -R ${{ env.package-name }} + pytest . diff --git a/.github/workflows/python-test-pypi.yml b/.github/workflows/python-test-pypi.yml deleted file mode 100644 index 1d982b9..0000000 --- a/.github/workflows/python-test-pypi.yml +++ /dev/null @@ -1,44 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a variety of Python versions -# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python - -name: TestPyPi - -on: - workflow_dispatch: - inputs: - environment: - type: environment - default: DEV - required: true - workflow_run: - workflows: ['UploadPyPi'] - types: - - completed - -jobs: - build: - runs-on: ubuntu-latest - env: - TestingOn: GitHub - - strategy: - fail-fast: false - matrix: - python-version: ['3.8', '3.9', '3.10'] - - steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v3 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install flake8 pytest - python -m pip install aplot - - - name: Test with pytest - run: | - rm -R aplot - pytest . diff --git a/.github/workflows/python-tests.yml b/.github/workflows/python-tests.yml deleted file mode 100644 index c262767..0000000 --- a/.github/workflows/python-tests.yml +++ /dev/null @@ -1,44 +0,0 @@ -# This workflow will install Python dependencies, run tests and lint with a single version of Python -# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions - -name: Tests - -on: - push: - branches: ['main'] - pull_request: - branches: ['main'] - -permissions: - contents: read - -jobs: - run-tests: - runs-on: ubuntu-latest - env: - TestingOn: GitHub - - steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.8 - uses: actions/setup-python@v3 - with: - python-version: '3.8' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install pytest pytest-cov - pip install -e . - if [ -f requirements.txt ]; then pip install -r requirements.txt; fi - - name: Test with pytest - run: | - pytest --cov=aplot --cov-report=xml - - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v1 - with: - token: ${{secrets.CODECOV_TOKEN}} - file: ./coverage.xml - flags: unittests - name: unit-tests-coverage - fail_ci_if_error: false From a9ebafb49cf98f9a7c784076ba03818bcdaca812 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 18:35:34 +0200 Subject: [PATCH 04/15] only build docs workflow on PR --- .github/workflows/docs.yml | 7 ++++++- .github/workflows/python-ci-main.yml | 9 ++++++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 96e0e66..5799280 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -23,4 +23,9 @@ jobs: restore-keys: | mkdocs- - run: pip install -r requirements-docs.txt - - run: mkdocs gh-deploy --force + - run: | + if [ "${{ github.event_name }}" == "push" ]; then + mkdocs gh-deploy --force + else + mkdocs build --strict + fi diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 309b769..9130805 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -74,20 +74,23 @@ jobs: - name: Get main branch version id: get_main_version run: | - git fetch origin main + git fetch main + git show origin/main: ${{env.package-name}}/__config__.py main_version=$(git show origin/main: ${{env.package-name}}/__config__.py | grep '__version__' | cut -d'"' -f2) echo "Main branch version: $main_version" - echo "::set-output name=main_version::$main_version" + echo "main_version=$main_version" >> "$GITHUB_OUTPUT" - name: Get PR branch version id: get_pr_version run: | pr_version=$(grep '__version__' ${{env.package-name}}/__config__.py | cut -d'"' -f2) echo "PR branch version: $pr_version" - echo "::set-output name=pr_version::$pr_version" + echo "pr_version=$pr_version" >> "$GITHUB_OUTPUT" - name: Compare versions run: | + echo "Main branch version: ${{ steps.get_main_version.outputs.main_version }}" + echo "PR branch version: ${{ steps.get_pr_version.outputs.pr_version }}" if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then echo "Error: Version is the same as on the main branch" exit 1 From e137f6ec2f6e8e1d1cf0d0c8701c5f68e83365c9 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 18:37:59 +0200 Subject: [PATCH 05/15] fix compare workflow --- .github/workflows/python-ci-main.yml | 2 +- mkdocs.yml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 9130805..9244899 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -74,7 +74,7 @@ jobs: - name: Get main branch version id: get_main_version run: | - git fetch main + git fetch origin main git show origin/main: ${{env.package-name}}/__config__.py main_version=$(git show origin/main: ${{env.package-name}}/__config__.py | grep '__version__' | cut -d'"' -f2) echo "Main branch version: $main_version" diff --git a/mkdocs.yml b/mkdocs.yml index 0558921..e97e108 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -60,7 +60,6 @@ plugins: separate_signature: true unwrap_annotated: true merge_init_into_class: true - watch: aplot/ nav: - Getting Started: From 3bb300266ed650913240123251257b80a5773118 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 18:45:34 +0200 Subject: [PATCH 06/15] fix workflows --- .github/workflows/docs.yml | 2 +- .github/workflows/python-ci-main.yml | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5799280..c2670b0 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -27,5 +27,5 @@ jobs: if [ "${{ github.event_name }}" == "push" ]; then mkdocs gh-deploy --force else - mkdocs build --strict + mkdocs build fi diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 9244899..52e7528 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -69,16 +69,7 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v2 - - - name: Get main branch version - id: get_main_version - run: | - git fetch origin main - git show origin/main: ${{env.package-name}}/__config__.py - main_version=$(git show origin/main: ${{env.package-name}}/__config__.py | grep '__version__' | cut -d'"' -f2) - echo "Main branch version: $main_version" - echo "main_version=$main_version" >> "$GITHUB_OUTPUT" + uses: actions/checkout@v3 - name: Get PR branch version id: get_pr_version @@ -87,6 +78,17 @@ jobs: echo "PR branch version: $pr_version" echo "pr_version=$pr_version" >> "$GITHUB_OUTPUT" + - name: Get main branch version + id: get_main_version + run: | + git fetch origin + git checkout origin/main -- ${{env.package-name}}/__config__.py + echo "Config file:" + git show origin/main: ${{env.package-name}}/__config__.py + main_version=$(grep '__version__' __config__.py | cut -d'"' -f2) + echo "Main branch version: $main_version" + echo "main_version=$main_version" >> "$GITHUB_OUTPUT" + - name: Compare versions run: | echo "Main branch version: ${{ steps.get_main_version.outputs.main_version }}" From 7ceaee71e7e3db6de9c2ba83f23a3503a5320d36 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 18:51:13 +0200 Subject: [PATCH 07/15] fix workflows --- .github/workflows/docs.yml | 6 ++++-- .github/workflows/python-ci-main.yml | 5 ++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index c2670b0..b2f087f 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -15,7 +15,8 @@ jobs: - uses: actions/setup-python@v4 with: python-version: 3.x - - run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV + - name: Restore cache + run: echo "cache_id=$(date --utc '+%V')" >> $GITHUB_ENV - uses: actions/cache@v3 with: key: mkdocs-${{ env.cache_id }} @@ -23,7 +24,8 @@ jobs: restore-keys: | mkdocs- - run: pip install -r requirements-docs.txt - - run: | + - name: Build and deploy if push + run: | if [ "${{ github.event_name }}" == "push" ]; then mkdocs gh-deploy --force else diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 52e7528..7cc7352 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -82,10 +82,9 @@ jobs: id: get_main_version run: | git fetch origin - git checkout origin/main -- ${{env.package-name}}/__config__.py echo "Config file:" - git show origin/main: ${{env.package-name}}/__config__.py - main_version=$(grep '__version__' __config__.py | cut -d'"' -f2) + git show origin/main:${{env.package-name}}/__config__.py + main_version=$(git show origin/main:${{env.package-name}}/__config__.py | grep '__version__' | cut -d'"' -f2) echo "Main branch version: $main_version" echo "main_version=$main_version" >> "$GITHUB_OUTPUT" From 75143a08a6d4ea8253c055574416044c375628e2 Mon Sep 17 00:00:00 2001 From: Kirill <4cd87a@gmail.com> Date: Sun, 30 Jun 2024 19:03:41 +0200 Subject: [PATCH 08/15] add add_axes to Figure --- aplot/__config__.py | 2 +- aplot/core/axes_list.py | 11 +++++++--- aplot/core/axes_list.pyi | 2 +- aplot/core/figure_class.py | 41 ++++++++++++++++++++++++++++++++++++-- 4 files changed, 49 insertions(+), 7 deletions(-) diff --git a/aplot/__config__.py b/aplot/__config__.py index 3dc1f76..d3ec452 100644 --- a/aplot/__config__.py +++ b/aplot/__config__.py @@ -1 +1 @@ -__version__ = "0.1.0" +__version__ = "0.2.0" diff --git a/aplot/core/axes_list.py b/aplot/core/axes_list.py index d9c3304..dd5485c 100644 --- a/aplot/core/axes_list.py +++ b/aplot/core/axes_list.py @@ -30,7 +30,12 @@ def set(self, **kwargs): # return self.__getitem__(item % len(self))[item // len(self)] # return super().__getitem__(item) - def plot(self, x, data, *args, scalex: bool = True, scaley: bool = True, **kwargs): + def plot(self, x, data, *args, axes=None, **kwargs): + if axes is not None: + ax = self[axes] + ax.plot(x, data, *args, **kwargs) + return self + if len(x) != len(data): if len(data) != len(self): raise ValueError( @@ -171,6 +176,7 @@ def __getattr__(self, key): def mapping(*args, **kwargs): for ax in self: getattr(ax, key)(*args, **kwargs) + return self return mapping # return super().__getattribute__(key) @@ -198,5 +204,4 @@ def __getitem__(self, key: _t.Union[int, _t.Tuple[int, ...]]): # type: ignore if not isinstance(res, AxesList): return AxesList(res) return res - - return super(AxesList, self).__getitem__(key) + return super().__getitem__(key) diff --git a/aplot/core/axes_list.pyi b/aplot/core/axes_list.pyi index ad82ecd..b0cdd07 100644 --- a/aplot/core/axes_list.pyi +++ b/aplot/core/axes_list.pyi @@ -25,7 +25,7 @@ from matplotlib.figure import Figure from matplotlib.image import AxesImage from matplotlib.lines import Line2D from matplotlib.markers import MarkerStyle -from matplotlib.patches import Patch +from matplotlib.patches import Patch, Rectangle from matplotlib.quiver import Quiver from matplotlib.scale import ScaleBase from matplotlib.spines import Spines diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index 23bba1a..c283558 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -1,7 +1,16 @@ +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, overload + +import matplotlib.pyplot as plt from matplotlib.figure import Figure as MplFigure +if TYPE_CHECKING: + from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D + from matplotlib.projections.polar import PolarAxes as MplPolarAxes + from .axes_class import AAxes +_T = TypeVar("_T") + class AFigure(MplFigure): @@ -12,7 +21,35 @@ def __init__(self, *args, **kwargs): def custom_draw_method(self): print("Custom drawing behavior here") - def add_subplot(self, *args, **kwargs): + def add_subplot(self, *args, **kwargs) -> AAxes: # type: ignore # Ensuring that the custom axes class is used - kwargs.update({"axes_class": AAxes}) + if "projection" not in kwargs and "polar" not in kwargs: + kwargs.update({"axes_class": AAxes}) return super().add_subplot(*args, **kwargs) + + def savefig(self, fname: Any, *, transparent=None, **kwargs): # type: ignore + super().savefig(fname, transparent=transparent, **kwargs) + return self + + def tight_layout(self, *args, **kwargs): # type: ignore + super().tight_layout(*args, **kwargs) + return self + + @overload + def add_axes(self, rect, projection: Literal["3d"]) -> "MplAxes3D": ... # type: ignore + @overload + def add_axes(self, rect, projection: Literal["polar"]) -> "MplPolarAxes": ... # type: ignore + @overload + def add_axes(self, rect, polar: Literal[True]) -> "MplPolarAxes": ... # type: ignore + @overload + def add_axes(self, rect, projection: Optional[str], polar: bool) -> AAxes: ... # type: ignore + @overload + def add_axes(self, ax: _T) -> _T: ... # type: ignore + + def add_axes(self, *args, **kwargs): # type: ignore + if "projection" not in kwargs and "polar" not in kwargs: + kwargs.update({"axes_class": AAxes}) + return super().add_axes(*args, **kwargs) # type: ignore + + def show(self): # type: ignore + plt.show(self) From 35efc4310bfd8e951bea3142c0613ddef030dd7b Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Sun, 11 Aug 2024 17:30:00 +0200 Subject: [PATCH 09/15] add label_axes on figure --- .github/workflows/python-ci-main.yml | 12 +++++----- .github/workflows/python-publish.yml | 3 --- aplot/__init__.py | 2 ++ aplot/core/axes_list.py | 9 ++++++++ aplot/core/axes_list.pyi | 22 +++++++++++++++---- aplot/core/figure_class.py | 33 +++++++++++++++++++++++++++- 6 files changed, 68 insertions(+), 13 deletions(-) diff --git a/.github/workflows/python-ci-main.yml b/.github/workflows/python-ci-main.yml index 7cc7352..3731892 100644 --- a/.github/workflows/python-ci-main.yml +++ b/.github/workflows/python-ci-main.yml @@ -92,9 +92,11 @@ jobs: run: | echo "Main branch version: ${{ steps.get_main_version.outputs.main_version }}" echo "PR branch version: ${{ steps.get_pr_version.outputs.pr_version }}" - if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then - echo "Error: Version is the same as on the main branch" - exit 1 - else - echo "Ok: Version is different from the main branch" + if [ "${{ github.event_name }}" == "pull_request" ]; then + if [ "${{ steps.get_main_version.outputs.main_version }}" = "${{ steps.get_pr_version.outputs.pr_version }}" ]; then + echo "Error: Version is the same as on the main branch" + exit 1 + else + echo "Ok: Version is different from the main branch" + fi fi diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index d15daef..4fc8e95 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,6 +1,3 @@ -# This workflow will upload a Python Package using Twine when a release is created -# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries - name: Publish-Main-PyPi env: diff --git a/aplot/__init__.py b/aplot/__init__.py index 0f499e3..095e259 100644 --- a/aplot/__init__.py +++ b/aplot/__init__.py @@ -1,5 +1,7 @@ # flake8: noqa: F401 +import matplotlib.patches as patches + from . import analysis, styles from .__config__ import __version__ from .core import ax, axs, close, figure, figure_class, show, subplot, subplots diff --git a/aplot/core/axes_list.py b/aplot/core/axes_list.py index dd5485c..4cd1940 100644 --- a/aplot/core/axes_list.py +++ b/aplot/core/axes_list.py @@ -205,3 +205,12 @@ def __getitem__(self, key: _t.Union[int, _t.Tuple[int, ...]]): # type: ignore return AxesList(res) return res return super().__getitem__(key) + + def flat(self): + res = [] + for ax in self: + if isinstance(ax, AxesList): + res.extend(ax.flat()) + else: + res.append(ax) + return AxesList(res) diff --git a/aplot/core/axes_list.pyi b/aplot/core/axes_list.pyi index b0cdd07..1e3c1e9 100644 --- a/aplot/core/axes_list.pyi +++ b/aplot/core/axes_list.pyi @@ -134,7 +134,12 @@ class AxesList(List[_T]): **kwargs, ) -> _S: ... def axhspan( # type: ignore - self: _S, ymin: float, ymax: float, xmin: float = ..., xmax: float = ..., **kwargs + self: _S, + ymin: float, + ymax: float, + xmin: float = ..., + xmax: float = ..., + **kwargs, ) -> _S: ... def axvspan( # type: ignore self: _S, xmin: float, xmax: float, ymin: float = 0, ymax: float = 1, **kwargs @@ -235,7 +240,10 @@ class AxesList(List[_T]): **kwargs, ) -> _S: ... def broken_barh( # type: ignore - self: _S, xranges: Sequence[tuple[float, float]], yrange: tuple[float, float], **kwargs + self: _S, + xranges: Sequence[tuple[float, float]], + yrange: tuple[float, float], + **kwargs, ) -> _S: ... def stem( # type: ignore self: _S, @@ -734,7 +742,10 @@ class AxesList(List[_T]): useMathText: bool = ..., ) -> _S: ... def locator_params( # type: ignore - self: _S, axis: Literal["both", "x", "y"] = ..., tight: bool | None = ..., **kwargs + self: _S, + axis: Literal["both", "x", "y"] = ..., + tight: bool | None = ..., + **kwargs, ) -> _S: ... def tick_params(self: _S, axis: Literal["x", "y", "both"] = ..., **kwargs) -> _S: ... # type: ignore def set_axis_off(self: _S) -> _S: ... # type: ignore @@ -794,7 +805,9 @@ class AxesList(List[_T]): ymax: float = ..., ) -> _S: ... def set_yscale( # type: ignore - self: _S, value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, **kwargs + self: _S, + value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, + **kwargs, ) -> _S: ... def minorticks_on(self: _S) -> _S: ... # type: ignore def minorticks_off(self: _S) -> _S: ... # type: ignore @@ -845,3 +858,4 @@ class AxesList(List[_T]): self, key: Union[int, Tuple[Union[int, slice], ...], slice], ) -> _T: ... # type: ignore + def flat(self) -> "AxesList[AAxes]": ... diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index c283558..bcb34da 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar, overload +from typing import TYPE_CHECKING, Any, List, Literal, Optional, TypeVar, Union, overload import matplotlib.pyplot as plt from matplotlib.figure import Figure as MplFigure @@ -8,8 +8,10 @@ from matplotlib.projections.polar import PolarAxes as MplPolarAxes from .axes_class import AAxes +from .axes_list import AxesList _T = TypeVar("_T") +_F = TypeVar("_F", bound="AFigure") class AFigure(MplFigure): @@ -53,3 +55,32 @@ def add_axes(self, *args, **kwargs): # type: ignore def show(self): # type: ignore plt.show(self) + + @property + def axes(self) -> "AxesList[AAxes]": # type: ignore + return AxesList(self._axstack.as_list()) # type: ignore + + def label_axes( + self: _F, + labels: Union[Literal["vertical", "horizontal"], List[str]] = "horizontal", + *, + axes: Optional["AxesList"] = None, + ) -> _F: + if axes is None: + axes = self.axes + axes_list = axes.flat() + if labels == "horizontal": + labels = [f"({chr(65+i)})" for i in range(len(axes_list))] + elif labels == "vertical": + raise NotImplementedError("Vertical labels not yet implemented") + for ax, label in zip(axes_list, labels): + ax.text( + 0.02, + 0.95, + label, + transform=ax.transAxes, + fontsize=14, + va="top", + ) + + return self From ce9cbdbc5f458475bce66814282d957ae0cbe0a4 Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Wed, 30 Oct 2024 09:56:29 +0100 Subject: [PATCH 10/15] add label_axes func on a figure. clean the styles --- aplot/__init__.py | 1 + aplot/core/axes_class.py | 63 ++++++++++--- aplot/core/axes_class.pyi | 180 ++++++++++++++++++++++++++----------- aplot/core/axes_list.py | 12 ++- aplot/core/axes_list.pyi | 15 +++- aplot/core/figure_class.py | 115 +++++++++++++++++++++--- aplot/styles/__init__.py | 107 +++++++++++----------- 7 files changed, 364 insertions(+), 129 deletions(-) diff --git a/aplot/__init__.py b/aplot/__init__.py index 095e259..bbd0f1c 100644 --- a/aplot/__init__.py +++ b/aplot/__init__.py @@ -6,6 +6,7 @@ from .__config__ import __version__ from .core import ax, axs, close, figure, figure_class, show, subplot, subplots from .core.axes_class import AAxes as Axes +from .core.axes_list import AxesList from .core.figure_class import AFigure as Figure s = styles diff --git a/aplot/core/axes_class.py b/aplot/core/axes_class.py index c90530d..ebb2c2c 100644 --- a/aplot/core/axes_class.py +++ b/aplot/core/axes_class.py @@ -19,6 +19,7 @@ ) _T = _t.TypeVar("_T") +_R = _t.TypeVar("_R") if _t.TYPE_CHECKING: from .figure_class import AFigure @@ -212,29 +213,24 @@ def set( # type: ignore super().set(**filter_none_types(kwargs)) return self - def hist2d( + def hist2d( # type: ignore self, x, y=None, - bins=10, - range=None, # pylint: disable=redefined-builtin - density=False, - weights=False, - cmin=None, - cmax=None, + *args, **kwargs, ): if y is None: x = np.array(x) x = x[:, 0] y = x[:, 1] - return super().hist2d(x, y, bins, range, density, weights, cmin, cmax, **kwargs) + return super().hist2d(x, y, *args, **kwargs) def z_parametric(self, z, **kwargs): self.plot(np.real(z), np.imag(z), **kwargs) return self - def z_historograms(self, z, **kwargs): + def hist_z(self, z, **kwargs): self.hist2d(np.real(z), np.imag(z), **kwargs) return self @@ -261,8 +257,14 @@ def imshow( # type: ignore colorbar: bool = True, **kwargs, ): - if x is not None and y is not None and (len(data) != len(y) or len(data[0]) != len(x)): - raise ValueError(f"Wrong shapes. {len(data)} != {len(y)} or {len(data[0])} != {len(x)}") + if ( + x is not None + and y is not None + and (len(data) != len(y) or len(data[0]) != len(x)) + ): + raise ValueError( + f"Wrong shapes. {len(data)} != {len(y)} or {len(data[0])} != {len(x)}" + ) imshow_kwargs: dict = imshow_kwds(x, y) imshow_kwargs.update( @@ -298,6 +300,7 @@ def imshow( # type: ignore raise ValueError("The figure is None cannot add colorbar") cbar = fig.colorbar(im, cax=cax, orientation="vertical") cbar.ax.set_ylabel(kwargs.get("bar_label", "")) + cbar.ax.set_rasterized(False) else: cbar = None @@ -318,7 +321,9 @@ def pcolorfast( # type: ignore elif len(args) == 3: x, y, data = args elif len(args) > 0: - raise ValueError(f"Wrong number of arguments: {len(args)}. Should be 0, 1 or 3.") + raise ValueError( + f"Wrong number of arguments: {len(args)}. Should be 0, 1 or 3." + ) if data is None: raise ValueError("Data should be provided") @@ -357,7 +362,14 @@ def tight_layout(self, *, pad=1.08, h_pad=None, w_pad=None, rect=None): self.figure.tight_layout(pad=pad, h_pad=h_pad, w_pad=w_pad, rect=rect) # type: ignore return self - def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=None, **kwargs): + def plot( + self, + *args, + keep_xlims: bool = False, + keep_ylims: bool = False, + axes=None, + **kwargs, + ): del axes xlims = self.get_xlim() if keep_xlims else None ylims = self.get_ylim() if keep_ylims else None @@ -368,9 +380,34 @@ def plot(self, *args, keep_xlims: bool = False, keep_ylims: bool = False, axes=N self.set_ylim(*ylims) return res + def axhline(self, y=0, xmin=0, xmax=1, **kwargs) -> "AAxes": # type: ignore + if isinstance(y, _t.Iterable): + return self.update_result( + [self.axhline(y_, xmin=xmin, xmax=xmax, **kwargs).res for y_ in y] + ) + return self.update_result(super().axhline(y, xmin=xmin, xmax=xmax, **kwargs)) + + def axvline(self, x=0, ymin=0, ymax=1, **kwargs) -> "AAxes": # type: ignore + if isinstance(x, _t.Iterable): + return self.update_result( + [self.axvline(x_, ymin=ymin, ymax=ymax, **kwargs).res for x_ in x] + ) + return self.update_result(super().axvline(x, ymin=ymin, ymax=ymax, **kwargs)) + def __add__(self, other): from .axes_list import AxesList if isinstance(other, list): return AxesList([self] + other) # type: ignore return AxesList([self, other]) # type: ignore + + def update_result(self, result: _R) -> "AAxes[_R]": + self._last_result = result + return self # type: ignore + + def colorbar(self, label: _t.Optional[str] = None, *args, **kwargs): + c = self.res + cbar = self.fig.colorbar(c, ax=self) + if label is not None: + cbar.set_label(label) + return self diff --git a/aplot/core/axes_class.pyi b/aplot/core/axes_class.pyi index aa02cd7..c119c84 100644 --- a/aplot/core/axes_class.pyi +++ b/aplot/core/axes_class.pyi @@ -1,6 +1,16 @@ # flake8: noqa: E302, E704 import datetime -from typing import Callable, Generic, Literal, Sequence, TypeVar, Union, overload +from typing import ( + Callable, + Generic, + Iterable, + List, + Literal, + Sequence, + TypeVar, + Union, + overload, +) import numpy as np from matplotlib.artist import Artist @@ -73,11 +83,36 @@ class AAxes(MplAxes, Generic[_T]): pad: float = ..., *, y: float = ..., - **kwargs + **kwargs, ) -> "AAxes": ... - def legend(self, *args, **kwargs) -> "AAxes": ... # type: ignore + def legend( + self, + *args, + loc: Union[ + int, + Literal[ + "best", + "upper right", + "upper left", + "lower left", + "lower right", + "right", + "center left", + "center right", + "lower center", + "upper center", + "center", + ], + ] = "best", + **kwargs, + ) -> "AAxes": ... # type: ignore def inset_axes( - self, bounds: Sequence[float], *, transform: Transform = ..., zorder: float = ..., **kwargs + self, + bounds: Sequence[float], + *, + transform: Transform = ..., + zorder: float = ..., + **kwargs, ) -> "AAxes": ... def indicate_inset( # type: ignore self, @@ -89,7 +124,7 @@ class AAxes(MplAxes, Generic[_T]): edgecolor: Color = ..., alpha: float = ..., zorder: float = ..., - **kwargs + **kwargs, ) -> "AAxes": ... def indicate_inset_zoom(self, inset_ax: _Axes, **kwargs) -> "AAxes[Rectangle]": ... # type: ignore def secondary_xaxis( # type: ignore @@ -97,14 +132,14 @@ class AAxes(MplAxes, Generic[_T]): location: Literal["top", "bottom", "left", "right"] | float, *, functions=..., - **kwargs + **kwargs, ) -> "AAxes[SecondaryAxis]": ... def secondary_yaxis( # type: ignore self, location: Literal["top", "bottom", "left", "right"] | float, *, functions=..., - **kwargs + **kwargs, ) -> "AAxes[SecondaryAxis]": ... def text(self, x: float, y: float, s: str, fontdict: dict = ..., **kwargs) -> "AAxes[Text]": ... # type: ignore def annotate( # type: ignore @@ -116,21 +151,35 @@ class AAxes(MplAxes, Generic[_T]): textcoords: str | Artist | Transform | Callable = ..., arrowprops: dict = ..., annotation_clip: bool | None = ..., - **kwargs + **kwargs, ) -> "AAxes[Annotation]": ... + @overload def axhline( # type: ignore - self, y: float = 0, xmin: float = 0, xmax: float = 1, **kwargs + self, y: float, xmin: float = 0, xmax: float = 1, **kwargs ) -> "AAxes[Line2D]": ... + @overload + def axhline( # type: ignore + self, y: Iterable[float], xmin: float = 0, xmax: float = 1, **kwargs + ) -> "AAxes[List[Line2D]]": ... + def axhline( # type: ignore + self, y=0, xmin: float = 0, xmax: float = 1, **kwargs + ) -> "Union[AAxes[List[Line2D]], AAxes[Line2D]]": ... def axvline( # type: ignore self, x: float = ..., ymin: float = ..., ymax: float = ..., **kwargs ) -> "AAxes[Line2D]": ... + def axvline( # type: ignore + self, x: Iterable[float] = ..., ymin: float = ..., ymax: float = ..., **kwargs + ) -> "AAxes[List[Line2D]]": ... + def axvline( # type: ignore + self, x: float = ..., ymin: float = ..., ymax: float = ..., **kwargs + ) -> "Union[AAxes[List[Line2D]], AAxes[Line2D]]": ... def axline( # type: ignore self, xy1: tuple[float, float], xy2: tuple[float, float] = ..., *, slope: float = ..., - **kwargs + **kwargs, ) -> "AAxes[Line2D]": ... def axhspan( # type: ignore self, ymin: float, ymax: float, xmin: float = ..., xmax: float = ..., **kwargs @@ -146,7 +195,7 @@ class AAxes(MplAxes, Generic[_T]): colors: list[Color] = ..., linestyles: Literal["solid", "dashed", "dashdot", "dotted"] = ..., label: str = ..., - **kwargs + **kwargs, ) -> "AAxes[LineCollection]": ... def vlines( # type: ignore self, @@ -156,7 +205,7 @@ class AAxes(MplAxes, Generic[_T]): colors: list[Color] = ..., linestyles: Literal["solid", "dashed", "dashdot", "dotted"] = ..., label: str = ..., - **kwargs + **kwargs, ) -> "AAxes[LineCollection]": ... def eventplot( # type: ignore self, @@ -167,7 +216,7 @@ class AAxes(MplAxes, Generic[_T]): linewidths: float | ArrayLike = ..., colors: Color | list[Color] = ..., linestyles: str | tuple | list = ..., - **kwargs + **kwargs, ) -> "AAxes[list[EventCollection]]": ... def plot(self, *args, scalex=..., scaley=..., data=..., **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore def plot_date( # type: ignore @@ -178,7 +227,7 @@ class AAxes(MplAxes, Generic[_T]): tz: datetime.tzinfo = ..., xdate: bool = ..., ydate: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[list[Line2D]]": ... def loglog(self, *args, **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore def semilogx(self, *args, **kwargs) -> "AAxes[list[Line2D]]": ... # type: ignore @@ -192,7 +241,7 @@ class AAxes(MplAxes, Generic[_T]): detrend: Callable = ..., usevlines: bool = True, maxlags: int = 10, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, bool, int]]": ... def step( # type: ignore self, @@ -201,7 +250,7 @@ class AAxes(MplAxes, Generic[_T]): *args, where: Literal["pre", "post", "mid"] = ..., data=..., - **kwargs + **kwargs, ) -> "AAxes[list[Line2D]]": ... def bar( # type: ignore self, @@ -211,7 +260,7 @@ class AAxes(MplAxes, Generic[_T]): bottom: float | ArrayLike = ..., *, align: Literal["center", "edge"] = "center", - **kwargs + **kwargs, ) -> "AAxes[BarContainer]": ... def barh( # type: ignore self, @@ -221,7 +270,7 @@ class AAxes(MplAxes, Generic[_T]): left: float | ArrayLike = ..., *, align: Literal["center", "edge"] = "center", - **kwargs + **kwargs, ) -> "AAxes[BarContainer]": ... def bar_label( # type: ignore self, @@ -231,10 +280,13 @@ class AAxes(MplAxes, Generic[_T]): fmt: str = "%g", label_type: Literal["edge", "center"] = "edge", padding: float = 0, - **kwargs + **kwargs, ) -> "AAxes[list[Text]]": ... def broken_barh( # type: ignore - self, xranges: Sequence[tuple[float, float]], yrange: tuple[float, float], **kwargs + self, + xranges: Sequence[tuple[float, float]], + yrange: tuple[float, float], + **kwargs, ) -> "AAxes[BrokenBarHCollection]": ... def stem( # type: ignore self, @@ -245,7 +297,7 @@ class AAxes(MplAxes, Generic[_T]): bottom: float = 0, label: str | None = None, use_line_collection: bool = True, - orientation: str = "verical" + orientation: str = "verical", ) -> "AAxes[StemContainer]": ... def pie( # type: ignore self, @@ -266,7 +318,7 @@ class AAxes(MplAxes, Generic[_T]): frame: bool = False, rotatelabels: bool = False, *, - normalize: bool = True + normalize: bool = True, ) -> "AAxes[tuple[list[Wedge], list[Text], list[Text]]]": ... def errorbar( # type: ignore self, @@ -285,7 +337,7 @@ class AAxes(MplAxes, Generic[_T]): xuplims: bool = False, errorevery: int = 1, capthick: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[ErrorbarContainer]": ... def boxplot( # type: ignore self, @@ -356,7 +408,7 @@ class AAxes(MplAxes, Generic[_T]): *, edgecolors: Color = ..., plotnonfinite: bool = False, - **kwargs + **kwargs, ) -> "AAxes[PathCollection]": ... def hexbin( # type: ignore self, @@ -378,7 +430,7 @@ class AAxes(MplAxes, Generic[_T]): reduce_C_function=..., mincnt: int | None = None, marginals: bool = False, - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def arrow(self, x: float, y: float, dx: float, dy: float, **kwargs) -> "AAxes[FancyArrow]": ... # type: ignore def quiverkey( # type: ignore @@ -395,7 +447,7 @@ class AAxes(MplAxes, Generic[_T]): where: ArrayLike = ..., interpolate: bool = ..., step: Literal["pre", "post", "mid"] = ..., - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def fill_betweenx( # type: ignore self, @@ -405,7 +457,7 @@ class AAxes(MplAxes, Generic[_T]): where: ArrayLike = ..., step: Literal["pre", "post", "mid"] = ..., interpolate: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[PolyCollection]": ... def imshow( # type: ignore self, @@ -425,7 +477,7 @@ class AAxes(MplAxes, Generic[_T]): filterrad: float = 4, resample: bool = ..., url: str = ..., - **kwargs + **kwargs, ) -> "AAxes[AxesImage]": ... def pcolor( # type: ignore self, @@ -436,7 +488,7 @@ class AAxes(MplAxes, Generic[_T]): cmap: str | Colormap = ..., vmin: float | None = None, vmax: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[Collection]": ... def pcolormesh( # type: ignore self, @@ -448,7 +500,7 @@ class AAxes(MplAxes, Generic[_T]): vmax: float | None = None, shading: Literal["flat", "nearest", "gouraud", "auto"] = ..., antialiased=..., - **kwargs + **kwargs, ) -> "AAxes[QuadMesh]": ... def pcolorfast( # type: ignore self, @@ -458,7 +510,7 @@ class AAxes(MplAxes, Generic[_T]): cmap: str | Colormap = ..., vmin: float | None = None, vmax: float | None = None, - **kwargs + **kwargs, ) -> "AAxes[tuple[AxesImage, PcolorImage, QuadMesh]]": ... def contour(self, *args, **kwargs) -> "AAxes[QuadContourSet]": ... # type: ignore def contourf(self, *args, **kwargs) -> "AAxes[QuadContourSet]": ... # type: ignore @@ -481,7 +533,7 @@ class AAxes(MplAxes, Generic[_T]): color: Color | None = ..., label: str | None = ..., stacked: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[list[list[float]], list[float], BarContainer | list]]": ... @overload def hist( # type: ignore @@ -501,7 +553,7 @@ class AAxes(MplAxes, Generic[_T]): color: Color | None = None, label: str | None = None, stacked: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[list[float], list[float], BarContainer | list]]": ... def stairs( # type: ignore self, @@ -511,7 +563,7 @@ class AAxes(MplAxes, Generic[_T]): orientation: Literal["vertical", "horizontal"] = "vertical", baseline: float | ArrayLike | None = 0, fill: bool = False, - **kwargs + **kwargs, ) -> "AAxes[StepPatch]": ... def hist2d( # type: ignore self, @@ -523,8 +575,10 @@ class AAxes(MplAxes, Generic[_T]): weights=..., cmin: float | None = None, cmax: float | None = None, - **kwargs - ) -> "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]": ... + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... def psd( # type: ignore self, x: Sequence, @@ -538,7 +592,7 @@ class AAxes(MplAxes, Generic[_T]): sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., return_line: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def csd( # type: ignore self, @@ -554,7 +608,7 @@ class AAxes(MplAxes, Generic[_T]): sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., return_line: bool = False, - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def magnitude_spectrum( # type: ignore self, @@ -565,7 +619,7 @@ class AAxes(MplAxes, Generic[_T]): pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., scale: Literal["default", "linear", "dB"] = "linear", - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def angle_spectrum( # type: ignore self, @@ -575,7 +629,7 @@ class AAxes(MplAxes, Generic[_T]): window: Callable | np.ndarray = ..., pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def phase_spectrum( # type: ignore self, @@ -585,7 +639,7 @@ class AAxes(MplAxes, Generic[_T]): window: Callable | np.ndarray = ..., pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, Line2D]]": ... def cohere( # type: ignore self, @@ -600,7 +654,7 @@ class AAxes(MplAxes, Generic[_T]): pad_to: int = ..., sides: Literal["default", "onesided", "twosided"] = ..., scale_by_freq: bool = ..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray]]": ... def specgram( # type: ignore self, @@ -620,7 +674,7 @@ class AAxes(MplAxes, Generic[_T]): scale: Literal["default", "linear", "dB"] = "dB", vmin=..., vmax=..., - **kwargs + **kwargs, ) -> "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, AxesImage]]": ... def spy( # type: ignore self, @@ -630,7 +684,7 @@ class AAxes(MplAxes, Generic[_T]): markersize=..., aspect: Literal["equal", "auto", None] | float = "equal", origin: Literal["upper", "lower"] = ..., - **kwargs + **kwargs, ) -> "AAxes[AxesImage | Line2D]": ... # type: ignore def matshow(self, Z: ArrayLike, **kwargs) -> "AAxes[AxesImage]": ... # type: ignore def violinplot( # type: ignore @@ -720,7 +774,7 @@ class AAxes(MplAxes, Generic[_T]): visible: bool | None = ..., which: Literal["major", "minor", "both"] = ..., axis: Literal["both", "x", "y"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def ticklabel_format( # type: ignore self, @@ -730,7 +784,7 @@ class AAxes(MplAxes, Generic[_T]): scilimits=..., useOffset: bool | float = ..., useLocale: bool = ..., - useMathText: bool = ... + useMathText: bool = ..., ) -> "AAxes[None]": ... def locator_params( # type: ignore self, axis: Literal["both", "x", "y"] = ..., tight: bool | None = ..., **kwargs @@ -745,7 +799,7 @@ class AAxes(MplAxes, Generic[_T]): labelpad: float = ..., *, loc: Literal["left", "center", "right"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def invert_xaxis(self) -> "AAxes[None]": ... # type: ignore def set_xbound(self, lower: float | None = ..., upper: float | None = ...) -> "AAxes[None]": ... # type: ignore @@ -757,7 +811,7 @@ class AAxes(MplAxes, Generic[_T]): emit: bool = ..., auto: bool | None = ..., xmin: float = ..., - xmax: float = ... + xmax: float = ..., ) -> "AAxes[tuple[float, float]]": ... @overload def set_xlim( # type: ignore @@ -768,7 +822,7 @@ class AAxes(MplAxes, Generic[_T]): auto: bool | None = ..., *, xmin: float = ..., - xmax: float = ... + xmax: float = ..., ) -> "AAxes[tuple[float, float]]": ... def set_xscale(self, value: ..., **kwargs) -> "AAxes[None]": ... # type: ignore def set_ylabel( # type: ignore @@ -778,7 +832,7 @@ class AAxes(MplAxes, Generic[_T]): labelpad: float = ..., *, loc: Literal["bottom", "center", "top"] = ..., - **kwargs + **kwargs, ) -> "AAxes[None]": ... def invert_yaxis(self) -> "AAxes[None]": ... # type: ignore def set_ybound(self, lower: float | None = ..., upper: float | None = ...) -> "AAxes[None]": ... # type: ignore @@ -790,7 +844,7 @@ class AAxes(MplAxes, Generic[_T]): auto: bool | None = ..., *, ymin: float = ..., - ymax: float = ... + ymax: float = ..., ) -> "AAxes[None]": ... def set_yscale( # type: ignore self, value: Literal["linear", "log", "symlog", "logit"] | ScaleBase, **kwargs @@ -819,8 +873,26 @@ class AAxes(MplAxes, Generic[_T]): pad: float = ..., h_pad: float = ..., w_pad: float = ..., - rect: Sequence[float] = ... + rect: Sequence[float] = ..., ) -> _S: ... def __add__(self, other) -> "AxesList": ... - def set_xticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... - def set_yticks(self, ticks: ArrayLike, labels: ArrayLike | None = None) -> "AAxes[None]": ... + def set_xticks( + self, ticks: ArrayLike, labels: ArrayLike | None = None + ) -> "AAxes[None]": ... + def set_yticks( + self, ticks: ArrayLike, labels: ArrayLike | None = None + ) -> "AAxes[None]": ... + def colorbar(self: _Axes, *args, **kwargs) -> _Axes: ... + def hist_z( # type: ignore + self, + z, + bins: None | int | ArrayLike = ..., + range=..., + density: bool = False, + weights=..., + cmin: float | None = None, + cmax: float | None = None, + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... diff --git a/aplot/core/axes_list.py b/aplot/core/axes_list.py index 17bcfe0..f48a118 100644 --- a/aplot/core/axes_list.py +++ b/aplot/core/axes_list.py @@ -204,7 +204,10 @@ def __getitem__(self, key: _t.Union[int, _t.Tuple[int, ...]]): # type: ignore if not isinstance(res, AxesList): return AxesList(res) return res - return super().__getitem__(key) + res = super().__getitem__(key) + if not isinstance(res, (AAxes, AxesList)): + return AxesList(res) + return res def flat(self): res = [] @@ -215,3 +218,10 @@ def flat(self): res.append(ax) return AxesList(res) + def hist_z(self, z, **kwargs): + if len(z) == len(self): + for i, ax in enumerate(self): + ax.hist_z(z[i], **kwargs) + return self + self.map(lambda ax: ax.hist_z(z, **kwargs)) + return self diff --git a/aplot/core/axes_list.pyi b/aplot/core/axes_list.pyi index 1e3c1e9..eb4c003 100644 --- a/aplot/core/axes_list.pyi +++ b/aplot/core/axes_list.pyi @@ -854,8 +854,21 @@ class AxesList(List[_T]): ) -> _S: ... def map(self: _S, func: Callable[[AAxes], Any]) -> _S: ... def suptitle(self: _S, title: str) -> _S: ... - def __getitem__( + def __getitem__( # type: ignore self, key: Union[int, Tuple[Union[int, slice], ...], slice], ) -> _T: ... # type: ignore def flat(self) -> "AxesList[AAxes]": ... + def hist_z( # type: ignore + self, + z, + bins: None | int | ArrayLike = ..., + range=..., + density: bool = False, + weights=..., + cmin: float | None = None, + cmax: float | None = None, + **kwargs, + ) -> ( + "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" + ): ... diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index af47eff..d045ec0 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -1,12 +1,23 @@ -from typing import TYPE_CHECKING, Any, List, Literal, Optional, TypeVar, Union, overload +from copy import copy +from typing import ( + TYPE_CHECKING, + Any, + List, + Literal, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import matplotlib.pyplot as plt from matplotlib.figure import Figure as MplFigure if TYPE_CHECKING: - from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D from matplotlib.projections.polar import PolarAxes as MplPolarAxes + from mpl_toolkits.mplot3d.axes3d import Axes3D as MplAxes3D from .axes_class import AAxes from .axes_list import AxesList @@ -64,25 +75,109 @@ def axes(self) -> "AxesList[AAxes]": # type: ignore def label_axes( self: _F, - labels: Union[Literal["vertical", "horizontal"], List[str]] = "horizontal", + labels: Union[ + Literal["vertical", "horizontal"], List[str], List[int] + ] = "horizontal", *, axes: Optional["AxesList"] = None, + label_position: Union[Tuple[float, float], List[Tuple[float, float]]] = ( + 0.02, + 0.95, + ), + fontsize: Optional[Union[int, float, List[Union[float, int]]]] = None, + capitalize: bool = False, + **kwargs, ) -> _F: + """Label the axes of the figure. + + Args: + labels (Union[Literal["vertical", "horizontal"], List[str]], optional): + - "vertical": Label the axes vertically first, then horizontally. + - "horizontal": Label the axes horizontally first, then vertically. + - List[str]: List of labels to use for each axes. + - List[int]: Order of axes in which to label. + For example, [2, 0, 1] will label the third axes first, then the first, and finally the second. + Defaults to "horizontal". + axes (Optional["AxesList"], optional): _description_. Defaults to None. + label_position (Union[Tuple[float, float], List[Tuple[float, float]]], optional): + The (x, y) position of the label. Defaults to (0.02, 0.95). + If a list of tuples is provided, each tuple will be used for each axes. + fontsize (Optional[Union[int, float, List[Union[float, int]]]], optional): + Fontsize of the labels. Defaults to None. + If a list of font sizes is provided, each font size will be used for each axes. + capitalize (bool, optional): Capitalize the labels. Defaults to False. + **kwargs: Additional keyword arguments to pass to the text function. + + Raises: + NotImplementedError: Raised if labels is set to "vertical" + + Returns: + _F: Figure itself + + Example: + ``` + import numpy as np + import aplot as ap + + x = np.linspace(0, 10, 100) + y = np.sin(x) + ax = ap.axs(2, 2, figsize=(10, 8)) + ax.plot(x, y) + + ax.fig.label_axes().show() + ``` + """ if axes is None: axes = self.axes axes_list = axes.flat() - if labels == "horizontal": + axes_list = [ax for ax in axes_list if not detect_minor_axes(ax)] + if isinstance(labels, list): + if len(labels) != len(axes_list): + raise ValueError( + "Length of labels should be equal to the number of axes" + ) + if isinstance(labels[0], int): + labels = [f"({chr(65+((int(i)-1)%len(axes_list)))})" for i in labels] + elif labels == "horizontal": labels = [f"({chr(65+i)})" for i in range(len(axes_list))] elif labels == "vertical": raise NotImplementedError("Vertical labels not yet implemented") - for ax, label in zip(axes_list, labels): + + if len(label_position) == 2 and isinstance(label_position[0], (int, float)): + label_position_each = False + else: + label_position_each = True + for i, (ax, label) in enumerate(zip(axes_list, labels)): + text_kwargs = copy(kwargs) + x_pos: float = label_position[i][0] if label_position_each else label_position[0] # type: ignore + y_pos: float = label_position[i][1] if label_position_each else label_position[1] # type: ignore + if fontsize is not None: + fs = fontsize[i] if isinstance(fontsize, (list, tuple)) else fontsize + text_kwargs.setdefault("fontsize", fs) + text_kwargs.setdefault("transform", ax.transAxes) + text_kwargs.setdefault("va", "top") + + label = str(label).upper() if capitalize else str(label).lower() + ax.text( - 0.02, - 0.95, + x_pos, + y_pos, label, - transform=ax.transAxes, - fontsize=14, - va="top", + **text_kwargs, ) return self + + +def detect_minor_axes(ax: "AAxes") -> bool: + """Detect if the axes are minor axes. + + Args: + axes (AxesList[AAxes]): List of axes + + Returns: + bool: True if the axes are minor axes + """ + if hasattr(ax, "_colorbar"): + return True + return False diff --git a/aplot/styles/__init__.py b/aplot/styles/__init__.py index 16fac4c..3c409df 100644 --- a/aplot/styles/__init__.py +++ b/aplot/styles/__init__.py @@ -1,65 +1,72 @@ from ..code_utils import LabelDict -colors = [ - "#0066cc", - "#ffcc00", - "#ff7400", - "#962fbf", - "#8b5a2b", - "#d62976", - "#b8a7ea", - "#ed5555", - "#1da2d8", -] +# colors = [ +# "#0066cc", +# "#ffcc00", +# "#ff7400", +# "#962fbf", +# "#8b5a2b", +# "#d62976", +# "#b8a7ea", +# "#ed5555", +# "#1da2d8", +# ] + DATA = LabelDict( { "markerfacecolor": "none", - "markeredgecolor": colors[0], - "marker": "h", + # "markeredgecolor": colors[0], + "marker": "o", "linestyle": "none", } ) -FIT = LabelDict({"color": colors[1], "linewidth": 2, "label": "fit"}) -GUESS = LabelDict({"color": colors[2], "linewidth": 2, "label": "guess", "alpha": 0.6}) +FIT = LabelDict({"linewidth": 2, "label": "fit"}) # "color": colors[1], +GUESS = LabelDict( + {"linewidth": 2, "label": "guess", "alpha": 0.6} +) # "color": colors[2], -VOLT_TIME = LabelDict( - { - "xlabel": "Voltage, V", - "ylabel": r"Time, $\mu$s", - } -) -IQquadrature = LabelDict( - { - "xlabel": "I quadrature", - "ylabel": "Q quadrature", - "aspect": "equal", - } -) +# VOLT_TIME = LabelDict( +# { +# "xlabel": "Voltage, V", +# "ylabel": r"Time, $\mu$s", +# } +# ) +# IQquadrature = LabelDict( +# { +# "xlabel": "I quadrature", +# "ylabel": "Q quadrature", +# "aspect": "equal", +# } +# ) -DRIVE_FREQ = "Drive frequency (Hz)" -DRIVE_FREQ_GHz = "Drive frequency (GHz)" -READOUT_FREQ = "Readout IF frequency (Hz)" -READOUT_PHASE = "Readout phase (rad)" -LEFT = "Left" -BIAS_VOLTAGE = "Bias voltage (V)" +# DRIVE_FREQ = "Drive frequency (Hz)" +# DRIVE_FREQ_GHz = "Drive frequency (GHz)" +# READOUT_FREQ = "Readout IF frequency (Hz)" +# READOUT_PHASE = "Readout phase (rad)" +# LEFT = "Left" +# BIAS_VOLTAGE = "Bias voltage (V)" -TWO_TONE = LabelDict({"xlabel": BIAS_VOLTAGE, "ylabel": DRIVE_FREQ}, GHz={"ylabel": DRIVE_FREQ_GHz}) -CHEVRON = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) -AMPLITUDE_TIME = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) -RAMSEY = LabelDict( - { - "label": "ramsey", - "xlabel": "Pulse duration (ns)", - "ylabel": "Readout quadrature", - } -) -REIM = LabelDict( - { - "xlabel": "Re(z)", - "ylabel": "Im(z)", - } -) +# TWO_TONE = LabelDict( +# {"xlabel": BIAS_VOLTAGE, "ylabel": DRIVE_FREQ}, GHz={"ylabel": DRIVE_FREQ_GHz} +# ) +# CHEVRON = LabelDict({"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"}) +# AMPLITUDE_TIME = LabelDict( +# {"xlabel": "Pulse duration (ns)", "ylabel": "Frequency (MHz)"} +# ) +# RAMSEY = LabelDict( +# { +# "label": "ramsey", +# "xlabel": "Pulse duration (ns)", +# "ylabel": "Readout quadrature", +# } +# ) +# REIM = LabelDict( +# { +# "xlabel": "Re(z)", +# "ylabel": "Im(z)", +# } +# ) aspect_equal = LabelDict({"aspect": "equal"}) From bb14af1d265bbac86281b789c6baf0d9ea9e39c5 Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Wed, 30 Oct 2024 10:26:36 +0100 Subject: [PATCH 11/15] fix version --- aplot/__config__.py | 2 +- aplot/core/figure_class.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aplot/__config__.py b/aplot/__config__.py index d3ec452..3ced358 100644 --- a/aplot/__config__.py +++ b/aplot/__config__.py @@ -1 +1 @@ -__version__ = "0.2.0" +__version__ = "0.2.1" diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index d045ec0..9baaf05 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -11,7 +11,6 @@ overload, ) - import matplotlib.pyplot as plt from matplotlib.figure import Figure as MplFigure @@ -69,6 +68,7 @@ def add_axes(self, *args, **kwargs): # type: ignore def show(self): # type: ignore plt.show(self) + @property def axes(self) -> "AxesList[AAxes]": # type: ignore return AxesList(self._axstack.as_list()) # type: ignore From c9a084326f4e54c4c31f66186d98a29f38e0dbbd Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Mon, 30 Dec 2024 18:17:24 +0100 Subject: [PATCH 12/15] feat: classical return. fix: hist --- .github/workflows/python-publish.yml | 2 - aplot/__config__.py | 2 +- aplot/core/axes_class.py | 50 +++++++++++++++++++------ aplot/core/axes_class.pyi | 1 + aplot/core/figure_class.py | 55 ++++++++++++++++++++++------ 5 files changed, 84 insertions(+), 26 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 97e6289..4fc8e95 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -1,7 +1,5 @@ name: Publish-Main-PyPi - -name: Publish-Main-PyPi env: package-name: aplot diff --git a/aplot/__config__.py b/aplot/__config__.py index 3ced358..b5fdc75 100644 --- a/aplot/__config__.py +++ b/aplot/__config__.py @@ -1 +1 @@ -__version__ = "0.2.1" +__version__ = "0.2.2" diff --git a/aplot/core/axes_class.py b/aplot/core/axes_class.py index ebb2c2c..06de7ad 100644 --- a/aplot/core/axes_class.py +++ b/aplot/core/axes_class.py @@ -126,12 +126,29 @@ FILTER_KWARGS = {"hist2d", QuadMesh} +class ClassicReturnAxis: + def __init__(self, axes: "AAxes"): + self.axes = axes + self._previous_state = False + + def __enter__(self): + self._previous_state = self.axes._classical_return + self.axes._classical_return = True + return self.axes + + def __exit__(self, exc_type, exc_value, traceback): + self.axes._classical_return = self._previous_state + if exc_type is not None: + raise + + class AAxes( MplAxes, _t.Generic[_T], ): name = "AAxis" # Give a name for the matplotlib registry _last_result = None + _classical_return = False # _fit_result: FitResult | None = None # __all__ = MplAxes.__all__ + ["fit", "last_result", "fit_result", "res", "set"] # __dict__ = MplAxes.__dict__ ("fit", "last_result", "fit_result", "res", "set") @@ -183,7 +200,7 @@ def __getattribute__(self, name: str): def wrapper(*args, **kwargs): result = func(*args, **kwargs) - if isinstance(result, (MplAxes, AAxes)): + if isinstance(result, (MplAxes, AAxes)) or self._classical_return: return result self._last_result = result return self @@ -210,8 +227,8 @@ def set( # type: ignore "ylabel": ylabel, } ) - super().set(**filter_none_types(kwargs)) - return self + return super().set(**filter_none_types(kwargs)) + # return self def hist2d( # type: ignore self, @@ -222,17 +239,21 @@ def hist2d( # type: ignore ): if y is None: x = np.array(x) - x = x[:, 0] - y = x[:, 1] + y = x[..., 1] + x = x[..., 0] return super().hist2d(x, y, *args, **kwargs) + def hist(self, *args, **kwargs): + with ClassicReturnAxis(self): + return super().hist(*args, **kwargs) + def z_parametric(self, z, **kwargs): - self.plot(np.real(z), np.imag(z), **kwargs) - return self + return self.plot(np.real(z), np.imag(z), **kwargs) + # return self def hist_z(self, z, **kwargs): - self.hist2d(np.real(z), np.imag(z), **kwargs) - return self + return self.hist2d(np.real(z), np.imag(z), **kwargs) + # return self def imshow( # type: ignore self, @@ -303,8 +324,8 @@ def imshow( # type: ignore cbar.ax.set_rasterized(False) else: cbar = None - - return self + return im + # return self def pcolorfast( # type: ignore self, @@ -349,7 +370,8 @@ def pcolorfast( # type: ignore if colorbar: cbar = fig.colorbar(im, cax=cax, orientation="vertical") cbar.ax.set_ylabel(kwargs.get("bar_label", "")) - return self + return im + # return self def autoaxis(self, level: int = 0, func_name="plot") -> "AAxes": variables = get_auto_args(level, func_name) @@ -407,7 +429,11 @@ def update_result(self, result: _R) -> "AAxes[_R]": def colorbar(self, label: _t.Optional[str] = None, *args, **kwargs): c = self.res + assert c is not None cbar = self.fig.colorbar(c, ax=self) if label is not None: cbar.set_label(label) return self + + def classic_return(self): + return ClassicReturnAxis(self) diff --git a/aplot/core/axes_class.pyi b/aplot/core/axes_class.pyi index c119c84..2b54d8f 100644 --- a/aplot/core/axes_class.pyi +++ b/aplot/core/axes_class.pyi @@ -896,3 +896,4 @@ class AAxes(MplAxes, Generic[_T]): ) -> ( "AAxes[tuple[np.ndarray, np.ndarray, np.ndarray, tuple[float, float] | None]]" ): ... + def classic_return(self): ... diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index 9baaf05..c8b3cf7 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -12,6 +12,7 @@ ) import matplotlib.pyplot as plt +import numpy as np from matplotlib.figure import Figure as MplFigure if TYPE_CHECKING: @@ -76,7 +77,7 @@ def axes(self) -> "AxesList[AAxes]": # type: ignore def label_axes( self: _F, labels: Union[ - Literal["vertical", "horizontal"], List[str], List[int] + Literal["vertical", "horizontal"], List[Optional[Union[str, int]]] ] = "horizontal", *, axes: Optional["AxesList"] = None, @@ -94,9 +95,9 @@ def label_axes( labels (Union[Literal["vertical", "horizontal"], List[str]], optional): - "vertical": Label the axes vertically first, then horizontally. - "horizontal": Label the axes horizontally first, then vertically. - - List[str]: List of labels to use for each axes. - - List[int]: Order of axes in which to label. - For example, [2, 0, 1] will label the third axes first, then the first, and finally the second. + - List[str | int | None]: List of labels to use for each axes. + if None, the axes will not be labeled. + if int, the axes will be labeled with the corresponding alphabet. Defaults to "horizontal". axes (Optional["AxesList"], optional): _description_. Defaults to None. label_position (Union[Tuple[float, float], List[Tuple[float, float]]], optional): @@ -129,15 +130,26 @@ def label_axes( """ if axes is None: axes = self.axes - axes_list = axes.flat() - axes_list = [ax for ax in axes_list if not detect_minor_axes(ax)] + axes_list = filter_secondary_axes(axes.flat()) + if isinstance(labels, list): if len(labels) != len(axes_list): raise ValueError( "Length of labels should be equal to the number of axes" ) - if isinstance(labels[0], int): - labels = [f"({chr(65+((int(i)-1)%len(axes_list)))})" for i in labels] + + labels = [ + ( + ( + f"({chr(65+((int(i)-1) % len(axes_list)))})" + if isinstance(i, int) + else str(i) + ) + if i is not None + else None + ) + for i in labels + ] elif labels == "horizontal": labels = [f"({chr(65+i)})" for i in range(len(axes_list))] elif labels == "vertical": @@ -148,6 +160,8 @@ def label_axes( else: label_position_each = True for i, (ax, label) in enumerate(zip(axes_list, labels)): + if label is None: + continue text_kwargs = copy(kwargs) x_pos: float = label_position[i][0] if label_position_each else label_position[0] # type: ignore y_pos: float = label_position[i][1] if label_position_each else label_position[1] # type: ignore @@ -158,14 +172,12 @@ def label_axes( text_kwargs.setdefault("va", "top") label = str(label).upper() if capitalize else str(label).lower() - - ax.text( + getattr(ax, "text2D", ax.text)( x_pos, y_pos, label, **text_kwargs, ) - return self @@ -181,3 +193,24 @@ def detect_minor_axes(ax: "AAxes") -> bool: if hasattr(ax, "_colorbar"): return True return False + + +def filter_secondary_axes(axes: "List[AAxes]") -> "List[AAxes]": + """Detect and remove if the axes are secondary axes. + + Args: + axes (List[AAxes]): List of axes + + Returns: + List[AAxes]: List of secondary axes + """ + axes = [ax for ax in axes if not detect_minor_axes(ax)] + axes_list: "List[AAxes]" = [] + for ax1 in axes: + for ax2 in axes_list: + if np.isclose(ax1.get_position().bounds, ax2.get_position().bounds).all(): + break + else: + axes_list.append(ax1) + + return axes_list From 64605b2d3d27e4215710ca64e3c68b47dd07b63c Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Fri, 10 Jan 2025 13:16:32 +0100 Subject: [PATCH 13/15] remove scipy from requirements. leave lazy import --- aplot/__config__.py | 2 +- aplot/analysis/array_manipulation.py | 7 +++++-- aplot/analysis/signal_analysis.py | 9 +++++++-- aplot/core/figure_class.py | 27 +++++++++++++++++++++------ requirements.txt | 1 - setup.py | 2 +- 6 files changed, 35 insertions(+), 13 deletions(-) diff --git a/aplot/__config__.py b/aplot/__config__.py index b5fdc75..d31c31e 100644 --- a/aplot/__config__.py +++ b/aplot/__config__.py @@ -1 +1 @@ -__version__ = "0.2.2" +__version__ = "0.2.3" diff --git a/aplot/analysis/array_manipulation.py b/aplot/analysis/array_manipulation.py index 0da10ef..0cdad94 100644 --- a/aplot/analysis/array_manipulation.py +++ b/aplot/analysis/array_manipulation.py @@ -1,7 +1,6 @@ import typing as _t import numpy as np -import scipy ArrayLike = _t.Union[np.ndarray, _t.List] @@ -29,6 +28,8 @@ def argmin2d( Tuple[int, int]: index_y, index_x, i.e. the min value is d[index_y, index_x] """ if filter_ and filter_ > 1: + import scipy + d = scipy.ndimage.uniform_filter(d, size=3, mode="nearest") if x_mask is not None: @@ -133,7 +134,9 @@ def array_from_span( return res -def get_z(I: np.ndarray, Q: np.ndarray) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741 +def get_z( + I: np.ndarray, Q: np.ndarray +) -> np.ndarray: # pylint: disable=invalid-name # noqa: E741 min_len = min(len(I), len(Q)) return I[:min_len] + 1j * Q[:min_len] diff --git a/aplot/analysis/signal_analysis.py b/aplot/analysis/signal_analysis.py index f3d8e8f..f22b9a0 100644 --- a/aplot/analysis/signal_analysis.py +++ b/aplot/analysis/signal_analysis.py @@ -1,7 +1,6 @@ import typing as _t import numpy as np -import scipy def find_h_symmetry_axis(data: np.ndarray) -> int: @@ -13,6 +12,8 @@ def find_h_symmetry_axis(data: np.ndarray) -> int: Returns: (int): x index of the symmetry axis. """ + import scipy + data = (data - np.mean(data)) / np.std(data) # corr = scipy.signal.fftconvolve( # data[:, : len(data[0]) // 2], data[:, ::-1], mode="full" @@ -22,11 +23,15 @@ def find_h_symmetry_axis(data: np.ndarray) -> int: def remove_background(data: np.ndarray, convolve_len: _t.Optional[int] = None): + import scipy + if convolve_len is None: convolve_len = min(50, len(data) // 15) data = ( data - - scipy.signal.convolve2d(data, np.ones((convolve_len, 1)), mode="same", boundary="symm") + - scipy.signal.convolve2d( + data, np.ones((convolve_len, 1)), mode="same", boundary="symm" + ) / convolve_len ) return data - data.mean(axis=1)[:, np.newaxis] diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index c8b3cf7..7081cd7 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -41,7 +41,7 @@ def add_subplot(self, *args, **kwargs) -> AAxes: # type: ignore # Ensuring that the custom axes class is used if "projection" not in kwargs and "polar" not in kwargs: kwargs.update({"axes_class": AAxes}) - return super().add_subplot(*args, **kwargs) + return super().add_subplot(*args, **kwargs) # type: ignore def savefig(self, fname: Any, *, transparent=None, **kwargs): # type: ignore super().savefig(fname, transparent=transparent, **kwargs) @@ -76,7 +76,7 @@ def axes(self) -> "AxesList[AAxes]": # type: ignore def label_axes( self: _F, - labels: Union[ + labels: Union[ # type: ignore Literal["vertical", "horizontal"], List[Optional[Union[str, int]]] ] = "horizontal", *, @@ -87,6 +87,7 @@ def label_axes( ), fontsize: Optional[Union[int, float, List[Union[float, int]]]] = None, capitalize: bool = False, + label_titles: Optional[List[str]] = None, **kwargs, ) -> _F: """Label the axes of the figure. @@ -138,10 +139,10 @@ def label_axes( "Length of labels should be equal to the number of axes" ) - labels = [ + labels: List[Optional[str]] = [ ( ( - f"({chr(65+((int(i)-1) % len(axes_list)))})" + f"({chr(65+((int(i)-1) % len(axes_list))).lower()})" if isinstance(i, int) else str(i) ) @@ -150,8 +151,23 @@ def label_axes( ) for i in labels ] + if capitalize: + labels = [(label.upper() if label else label) for label in labels] + + if label_titles is not None: + for i, label in enumerate(label_titles): + if labels[i] is not None: + labels[i] = f"{labels[i]} {label}" + elif labels == "horizontal": - labels = [f"({chr(65+i)})" for i in range(len(axes_list))] + labels = [f"({chr(65+i)})".lower() for i in range(len(axes_list))] + if capitalize: + labels = [(label.upper() if label else label) for label in labels] + if label_titles is not None: + for i, label in enumerate(label_titles): + if labels[i] is not None: + labels[i] = f"{labels[i]} {label}" + elif labels == "vertical": raise NotImplementedError("Vertical labels not yet implemented") @@ -171,7 +187,6 @@ def label_axes( text_kwargs.setdefault("transform", ax.transAxes) text_kwargs.setdefault("va", "top") - label = str(label).upper() if capitalize else str(label).lower() getattr(ax, "text2D", ax.text)( x_pos, y_pos, diff --git a/requirements.txt b/requirements.txt index 4a2a08b..806f221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,2 @@ numpy -scipy matplotlib \ No newline at end of file diff --git a/setup.py b/setup.py index 72659f7..236a935 100644 --- a/setup.py +++ b/setup.py @@ -35,5 +35,5 @@ def get_version() -> str: "Operating System :: OS Independent", ], python_requires=">=3.8", - install_requires=["numpy", "scipy", "matplotlib"], + install_requires=["numpy", "matplotlib"], ) From 9d1833c3ddaa413cbf40c0562448f3dad8de35ca Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Tue, 14 Jan 2025 19:33:31 +0100 Subject: [PATCH 14/15] label_position list can be None --- aplot/core/figure_class.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index 7081cd7..a4eee38 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -81,7 +81,9 @@ def label_axes( ] = "horizontal", *, axes: Optional["AxesList"] = None, - label_position: Union[Tuple[float, float], List[Tuple[float, float]]] = ( + label_position: Union[ + Tuple[float, float], List[Optional[Tuple[float, float]]] + ] = ( 0.02, 0.95, ), @@ -176,7 +178,9 @@ def label_axes( else: label_position_each = True for i, (ax, label) in enumerate(zip(axes_list, labels)): - if label is None: + if label is None or ( + label_position_each is True and label_position[i] is None + ): continue text_kwargs = copy(kwargs) x_pos: float = label_position[i][0] if label_position_each else label_position[0] # type: ignore From ddf42ad154677ea5d42fdb883ac48f8e7a863a27 Mon Sep 17 00:00:00 2001 From: kyrylo-gr Date: Thu, 25 Sep 2025 22:08:42 +0200 Subject: [PATCH 15/15] minor fixes --- aplot/core/axes_class.py | 16 ++++------------ aplot/core/axes_list.py | 4 +++- aplot/core/figure_class.py | 14 +++++++++++--- aplot/core/utils.py | 15 +++++++++++++++ 4 files changed, 33 insertions(+), 16 deletions(-) diff --git a/aplot/core/axes_class.py b/aplot/core/axes_class.py index 06de7ad..3821302 100644 --- a/aplot/core/axes_class.py +++ b/aplot/core/axes_class.py @@ -278,14 +278,8 @@ def imshow( # type: ignore colorbar: bool = True, **kwargs, ): - if ( - x is not None - and y is not None - and (len(data) != len(y) or len(data[0]) != len(x)) - ): - raise ValueError( - f"Wrong shapes. {len(data)} != {len(y)} or {len(data[0])} != {len(x)}" - ) + if x is not None and y is not None and (len(data) != len(y) or len(data[0]) != len(x)): + raise ValueError(f"Wrong shapes. {len(data)} != {len(y)} or {len(data[0])} != {len(x)}") imshow_kwargs: dict = imshow_kwds(x, y) imshow_kwargs.update( @@ -321,7 +315,7 @@ def imshow( # type: ignore raise ValueError("The figure is None cannot add colorbar") cbar = fig.colorbar(im, cax=cax, orientation="vertical") cbar.ax.set_ylabel(kwargs.get("bar_label", "")) - cbar.ax.set_rasterized(False) + cbar.ax.set_rasterized(kwargs.get("bar_rasterized", kwargs.get("rasterized", False))) else: cbar = None return im @@ -342,9 +336,7 @@ def pcolorfast( # type: ignore elif len(args) == 3: x, y, data = args elif len(args) > 0: - raise ValueError( - f"Wrong number of arguments: {len(args)}. Should be 0, 1 or 3." - ) + raise ValueError(f"Wrong number of arguments: {len(args)}. Should be 0, 1 or 3.") if data is None: raise ValueError("Data should be provided") diff --git a/aplot/core/axes_list.py b/aplot/core/axes_list.py index f48a118..9307eac 100644 --- a/aplot/core/axes_list.py +++ b/aplot/core/axes_list.py @@ -3,7 +3,7 @@ import numpy as np from .axes_class import AAxes -from .utils import filter_set_kwargs, pop_from_dict +from .utils import filter_set_kwargs, pop_from_dict, get_edge_points # from matplotlib import pyplot as plt @@ -112,6 +112,8 @@ def plot_z_2d( else: raise ValueError("Plot_format should be either bode or real_imag") + x = get_edge_points(x) + y = get_edge_points(y) kwargs_without_xlabel = pop_from_dict(kwargs, "xlabel") self[0].pcolorfast(x=x, y=y, data=data1, **kwargs_without_xlabel) self[1].pcolorfast(x=x, y=y, data=data2, **kwargs) diff --git a/aplot/core/figure_class.py b/aplot/core/figure_class.py index a4eee38..158933f 100644 --- a/aplot/core/figure_class.py +++ b/aplot/core/figure_class.py @@ -29,7 +29,6 @@ class AFigure(MplFigure): - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -69,6 +68,11 @@ def add_axes(self, *args, **kwargs): # type: ignore def show(self): # type: ignore plt.show(self) + return self + + def close(self): # type: ignore + plt.close(self) + return self @property def axes(self) -> "AxesList[AAxes]": # type: ignore @@ -183,8 +187,12 @@ def label_axes( ): continue text_kwargs = copy(kwargs) - x_pos: float = label_position[i][0] if label_position_each else label_position[0] # type: ignore - y_pos: float = label_position[i][1] if label_position_each else label_position[1] # type: ignore + x_pos: float = ( + label_position[i][0] if label_position_each else label_position[0] + ) # type: ignore + y_pos: float = ( + label_position[i][1] if label_position_each else label_position[1] + ) # type: ignore if fontsize is not None: fs = fontsize[i] if isinstance(fontsize, (list, tuple)) else fontsize text_kwargs.setdefault("fontsize", fs) diff --git a/aplot/core/utils.py b/aplot/core/utils.py index 7872d90..c79a251 100644 --- a/aplot/core/utils.py +++ b/aplot/core/utils.py @@ -1,5 +1,7 @@ import typing as _t +import numpy as np + from .typing import NoneType if _t.TYPE_CHECKING: @@ -95,3 +97,16 @@ def pop_from_dict(data, keys: _t.Union[str, _t.Tuple[str, ...]]): if key in main: main.pop(key) return main + + +def get_center_points(x): + return x[:-1] + np.diff(x) / 2 + + +def get_edge_points(m): + m = np.asarray(m) + edges = np.empty(m.size + 1) + edges[1:-1] = (m[:-1] + m[1:]) / 2 # inner edges + edges[0] = m[0] - (m[1] - m[0]) / 2 # extrapolate left edge + edges[-1] = m[-1] + (m[-1] - m[-2]) / 2 # extrapolate right edge + return edges