From b3945cf74f88c845c5f416da0cccc7aa96eaee24 Mon Sep 17 00:00:00 2001 From: Yingtian Tang Date: Wed, 13 Sep 2023 17:40:26 +0200 Subject: [PATCH 1/2] add gitignore; extend gather indexes to signle coord --- .gitignore | 20 ++++++++++++++++++++ brainio/assemblies.py | 26 ++++++++++++++++++++++++-- 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3521f7b --- /dev/null +++ b/.gitignore @@ -0,0 +1,20 @@ +*.egg* +*.pyc +.idea +*~ +*.egg-info +*.DS_Store +.*swp +._* +.ipynb_checkpoints +rnn_searches/scripts/logs/ +rnn_searches/scripts/ +rnn_searches/neural_fitting/model* +rnn_searches/neural_fitting/10Lv9res_lr01_bs256_1/* +*.aux +*.out +*.lof +*.log +*.toc +*.bbl +*.dvi diff --git a/brainio/assemblies.py b/brainio/assemblies.py index 13a19b4..a58d647 100644 --- a/brainio/assemblies.py +++ b/brainio/assemblies.py @@ -431,14 +431,36 @@ def get_levels(assembly): def gather_indexes(assembly): - """This is only necessary as long as xarray cannot persist MultiIndex to netCDF. """ + """This is only necessary as long as xarray cannot persist MultiIndex to netCDF.""" coords_d = {} for dim in assembly.dims: - coord_names = list(get_metadata(assembly, dims=(dim,), names_only=True, include_indexes=False, include_levels=False)) + coord_names = list( + get_metadata( + assembly, + dims=(dim,), + names_only=True, + include_indexes=False, + include_levels=False, + ) + ) if coord_names: coords_d[dim] = coord_names + + # fix single-coord-single-dim + to_stack = {} + for dim, coords in coords_d.items(): + if len(coords) == 1 and len(list(get_metadata(assembly, dims=(dim,)))) == 1: + to_stack[dim] = coords[0] + if coords_d: assembly = assembly.set_index(append=True, **coords_d) + + if to_stack: + # single-coord stacking trick + for dim, coord in to_stack.items(): + assembly = assembly.rename({dim: coord}) + assembly = assembly.stack({dim: [coord]}) + return assembly From 69d9adbd53118532552080adab2e80ce8865e95e Mon Sep 17 00:00:00 2001 From: YingtianDt Date: Tue, 4 Jun 2024 13:30:30 +0200 Subject: [PATCH 2/2] update tests --- .gitignore | 18 ++---------------- tests/test_assemblies.py | 4 ++-- tests/test_packaging.py | 2 ++ 3 files changed, 6 insertions(+), 18 deletions(-) diff --git a/.gitignore b/.gitignore index 3521f7b..c3f1f42 100644 --- a/.gitignore +++ b/.gitignore @@ -1,20 +1,6 @@ -*.egg* *.pyc -.idea -*~ *.egg-info *.DS_Store .*swp -._* -.ipynb_checkpoints -rnn_searches/scripts/logs/ -rnn_searches/scripts/ -rnn_searches/neural_fitting/model* -rnn_searches/neural_fitting/10Lv9res_lr01_bs256_1/* -*.aux -*.out -*.lof -*.log -*.toc -*.bbl -*.dvi +__pycache__ +.pytest_cache \ No newline at end of file diff --git a/tests/test_assemblies.py b/tests/test_assemblies.py index eaeb626..b5215eb 100644 --- a/tests/test_assemblies.py +++ b/tests/test_assemblies.py @@ -67,7 +67,7 @@ def test_get_levels(): }, dims=['a', 'b'] ) - assert get_levels(assy) == ["up", "down"] + assert get_levels(assy) == ["up", "down", "sideways"] class TestSubclassing: @@ -115,7 +115,7 @@ def test_reset_index(self): dims=['a', 'b'] ) da = DataArray(assy) - da = da.reset_index(["up", "down"]) + da = da.reset_index(["up", "down", "sideways"]) assert get_levels(da) == [] def test_repr(self): diff --git a/tests/test_packaging.py b/tests/test_packaging.py index d746e35..e963e8a 100644 --- a/tests/test_packaging.py +++ b/tests/test_packaging.py @@ -57,6 +57,8 @@ def test_reset_index_levels(): ) assert assy["a"].variable.level_names == ["up", "down"] assy = assy.reset_index(["up", "down"]) + assert get_levels(assy) == ["sideways"] + assy = assy.reset_index(["sideways"]) assert get_levels(assy) == []