Skip to content

Commit aec22dc

Browse files
committed
Adequating to changes in the base formats to use xarray datatrees
1 parent ac3d108 commit aec22dc

3 files changed

Lines changed: 10 additions & 14 deletions

File tree

src/astrohack/beamcut.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -148,25 +148,21 @@ def beamcut(
148148

149149
if executed_graph:
150150
logger.info("Finished processing")
151-
output_attr_file = "{name}/{ext}".format(
152-
name=beamcut_params["beamcut_name"], ext=".beamcut_input"
151+
add_caller_and_version_to_dict(beamcut_params, direct_call=True)
152+
beamcut_mds = AstrohackBeamcutFile.create_from_input_parameters(
153+
beamcut_params["beamcut_name"], beamcut_params
153154
)
154-
root = xr.DataTree(name="root")
155-
root.attrs.update(beamcut_params)
156-
add_caller_and_version_to_dict(root.attrs, direct_call=True)
157155

158156
for xdtree in graph_results:
159157
ant, ddi = xdtree.name.split("-")
160-
if ant in root.keys():
161-
ant = root.children[ant].update({ddi: xdtree})
158+
if ant in beamcut_mds.keys():
159+
ant = beamcut_mds[ant].update({ddi: xdtree})
162160
else:
163161
ant_tree = xr.DataTree(name=ant, children={ddi: xdtree})
164-
root = root.assign({ant: ant_tree})
162+
beamcut_mds[ant] = ant_tree
165163

166-
root.to_zarr(beamcut_params["beamcut_name"], mode="w", consolidated=True)
164+
beamcut_mds.write()
167165

168-
beamcut_mds = AstrohackBeamcutFile(beamcut_params["beamcut_name"])
169-
beamcut_mds.open()
170166
return beamcut_mds
171167
else:
172168
logger.warning("No data to process")

src/astrohack/utils/file.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def _get_ds_metadata(ds):
554554
elif isinstance(ds, xr.Dataset) or isinstance(ds, xr.DataTree):
555555
metadata = getattr(ds, "attrs")
556556
else:
557-
metadata = ds.xdt.attrs
557+
metadata = ds.root.attrs
558558
return metadata
559559

560560

src/astrohack/utils/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,9 @@ def compute_graph(
123123
"""
124124

125125
delayed_list = []
126-
if hasattr(looping_dict, "xdt"):
126+
if hasattr(looping_dict, "root"):
127127
_construct_xdtree_graph_recursively(
128-
xr_datatree=looping_dict.xdt,
128+
xr_datatree=looping_dict.root,
129129
chunk_function=chunk_function,
130130
param_dict=param_dict,
131131
delayed_list=delayed_list,

0 commit comments

Comments
 (0)