Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/boutdata/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,13 @@ def scalevar(var, factor, path="."):


def create(
averagelast=1, final=-1, path="data", output="./", informat="nc", outformat=None
averagelast=1,
averageZ=0.0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Default to None?

Suggested change
averageZ=0.0,
averageZ : None | float = None,

final=-1,
path="data",
output="./",
informat="nc",
outformat=None,
):
"""Create restart files from data (dmp) files.

Expand All @@ -475,12 +481,15 @@ def create(
File extension of original files (default: "nc")
outformat : str, optional
File extension of new files (default: use the same as `informat`)

averageZ : float, optional
Weight average in Z direction: 0 = no averaging, 1 = average in Z
e.g. averageZ = 0.99 will reduce Z fluctuations to 1% of original
"""

if outformat is None:
outformat = informat

averageZ = min([averageZ, 1.0])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard error if outside expected range?

Suggested change
averageZ = min([averageZ, 1.0])
if not (0 < averageZ < 1.0):
raise ValueError(f"`averageZ` (={averageZ}) must be in (0, 1)")

path = pathlib.Path(path)
output = pathlib.Path(output)

Expand Down Expand Up @@ -541,6 +550,13 @@ def create(
data[(final - averagelast) : final, :, :, :], axis=0
)

if averageZ > 0.0:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if averageZ > 0.0:
if averageZ is not None:

data_averaged = np.tile(
np.mean(data_slice, axis=-1)[..., np.newaxis],
(1, 1, data_slice.shape[-1]),
)
data_slice = averageZ * data_averaged + (1 - averageZ) * data_slice

print(data_slice.shape)
# This attribute results in the correct (x,y,z) dimension labels
data_slice.attributes["bout_type"] = "Field3D"
Expand Down
Loading