Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "tabsim"
version = "0.0.11"
version = "0.0.12"
license = {file = "LICENSE"}
readme = "README.md"
# dynamic = ["version", "readme"]
Expand Down
114 changes: 57 additions & 57 deletions tabsim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,11 @@ def sigma_value(value: Union[str, float, int], obs: Observation):
raise ValueError()


def load_obs(obs_spec: dict) -> Observation:
def load_obs(sim_config: dict) -> Observation:

tel_ = obs_spec["telescope"]
obs_ = obs_spec["observation"]
dask_ = obs_spec["dask"]
tel_ = sim_config["telescope"]
obs_ = sim_config["observation"]
dask_ = sim_config["dask"]

def arange(start: float, delta: float, n: int):
x = da.arange(start, start + n * delta, delta)[:n]
Expand Down Expand Up @@ -342,14 +342,14 @@ def add_power_spectrum_sources(obs: Observation, ps_rand: dict) -> None:
obs.addAstro(I[:, None, :], ra, dec)


def add_astro_sources(obs: Observation, obs_spec: dict) -> None:
def add_astro_sources(obs: Observation, sim_config: dict) -> None:
"""Add astronomical sources from the simulation config file to the observation object.

Parameters
----------
obs : Observation
Observation object instance.
obs_spec : dict
sim_config : dict
Simulation config dictionary.
"""

Expand All @@ -358,7 +358,7 @@ def add_astro_sources(obs: Observation, obs_spec: dict) -> None:
"gauss": obs.addAstroGauss,
"exp": obs.addAstroExp,
}
ast_ = obs_spec["ast_sources"]
ast_ = sim_config["ast_sources"]

if ast_["pow_spec"]["random"]["type"]:
add_power_spectrum_sources(obs, ast_["pow_spec"]["random"])
Expand Down Expand Up @@ -467,9 +467,9 @@ def generate_spectra(spec_df: pd.DataFrame, freqs: Array, id_key: str) -> tuple:
return np.array(ids), da.atleast_2d(spectra)


def add_satellite_sources(obs: Observation, obs_spec: dict) -> None:
def add_satellite_sources(obs: Observation, sim_config: dict) -> None:

sat_ = obs_spec["rfi_sources"]["satellite"]
sat_ = sim_config["rfi_sources"]["satellite"]

# Circular path based Satellites
if len(sat_["sat_ids"]) > 0:
Expand Down Expand Up @@ -513,10 +513,10 @@ def add_satellite_sources(obs: Observation, obs_spec: dict) -> None:


def add_tle_satellite_sources(
obs: Observation, obs_spec: dict, spacetrack_path: str
obs: Observation, sim_config: dict, spacetrack_path: str
) -> None:

sat_ = obs_spec["rfi_sources"]["tle_satellite"]
sat_ = sim_config["rfi_sources"]["tle_satellite"]

# TLE path based Satellites
tle_cond = [
Expand Down Expand Up @@ -588,9 +588,9 @@ def add_tle_satellite_sources(
print("No NORAD IDs matching in 'norad_spec_model' file given.")


def add_stationary_sources(obs: Observation, obs_spec: dict) -> None:
def add_stationary_sources(obs: Observation, sim_config: dict) -> None:

stat_ = obs_spec["rfi_sources"]["stationary"]
stat_ = sim_config["rfi_sources"]["stationary"]
if len(stat_["loc_ids"]) > 0:
ids = stat_["loc_ids"]
path_ids = []
Expand Down Expand Up @@ -637,9 +637,9 @@ def add_stationary_sources(obs: Observation, obs_spec: dict) -> None:
print("No locations IDs matching in 'spec_model' file given.")


def add_gains(obs: Observation, obs_spec: dict) -> None:
def add_gains(obs: Observation, sim_config: dict) -> None:

gains_ = obs_spec["gains"]
gains_ = sim_config["gains"]
gain_offset = gains_["G0_mean"] != 1 or gains_["G0_std"] != 0
gain_var = gains_["Gt_std_amp"] != 0 or gains_["Gt_std_phase"] != 0
if gain_offset or gain_var:
Expand All @@ -659,11 +659,11 @@ def add_gains(obs: Observation, obs_spec: dict) -> None:
print("No gains added ...")


def plot_diagnostics(obs: Observation, obs_spec: dict, save_path: str) -> None:
def plot_diagnostics(obs: Observation, sim_config: dict, save_path: str) -> None:

from tabsim.plot import plot_uv, plot_src_alt, plot_angular_seps

diag_ = obs_spec["diagnostics"]
diag_ = sim_config["diagnostics"]

if diag_["src_alt"]:
print()
Expand Down Expand Up @@ -713,29 +713,29 @@ def write_to_zarr(obs: Observation, zarr_path: str, overwrite: bool) -> None:
print(f"zarr Write Time : {end - start}")


def save_data(obs: Observation, obs_spec: dict, zarr_path: str, ms_path: str) -> None:
def save_data(obs: Observation, sim_config: dict, zarr_path: str, ms_path: str) -> None:

if obs_spec["output"]["zarr"] or obs_spec["output"]["ms"]:
if sim_config["output"]["zarr"] or sim_config["output"]["ms"]:
print()
print("Calculating visibilities ...")
obs.calculate_vis(flags=obs_spec["output"]["flag_data"])
obs.calculate_vis(flags=sim_config["output"]["flag_data"])

overwrite = obs_spec["output"]["overwrite"]
overwrite = sim_config["output"]["overwrite"]

if obs_spec["output"]["zarr"] and obs_spec["output"]["ms"]:
if sim_config["output"]["zarr"] and sim_config["output"]["ms"]:
write_to_zarr(obs, zarr_path, overwrite)
xds = xr.open_zarr(zarr_path)
write_to_ms(xds, ms_path, overwrite)
print_signal_specs(
xds.vis_rfi.data, xds.vis_ast.data, xds.noise_data.data, xds.flags.data
)
elif obs_spec["output"]["zarr"]:
elif sim_config["output"]["zarr"]:
write_to_zarr(obs, zarr_path, overwrite)
xds = xr.open_zarr(zarr_path)
print_signal_specs(
xds.vis_rfi.data, xds.vis_ast.data, xds.noise_data.data, xds.flags.data
)
elif obs_spec["output"]["ms"]:
elif sim_config["output"]["ms"]:
write_to_ms(obs.dataset, ms_path, overwrite)
xds = xds_from_ms(ms_path)
print_signal_specs(
Expand All @@ -750,15 +750,15 @@ def save_data(obs: Observation, obs_spec: dict, zarr_path: str, ms_path: str) ->
)


def save_inputs(obs: Observation, obs_spec: dict, save_path: str) -> None:
def save_inputs(obs: Observation, sim_config: dict, save_path: str) -> None:

for key in ["enu_path", "itrf_path"]:
path = obs_spec["telescope"][key]
path = sim_config["telescope"][key]
if path is not None:
shutil.copy(path, save_path)

for key in obs_spec["ast_sources"].keys():
path = obs_spec["ast_sources"][key]["path"]
for key in sim_config["ast_sources"].keys():
path = sim_config["ast_sources"][key]["path"]
if path is not None:
shutil.copy(path, save_path)

Expand All @@ -772,14 +772,14 @@ def save_inputs(obs: Observation, obs_spec: dict, save_path: str) -> None:
"spec_model",
]
for key1, key2 in zip(key, subkey):
path = obs_spec["rfi_sources"][key1][key2]
path = sim_config["rfi_sources"][key1][key2]
if path is not None:
shutil.copy(path, save_path)

np.savetxt(os.path.join(save_path, "norad_ids.yaml"), obs.norad_ids, fmt="%i")

with open(os.path.join(save_path, "sim_config.yaml"), "w") as fp:
yaml.dump(obs_spec, fp)
yaml.dump(sim_config, fp)


def print_fringe_freq_sat(obs: Observation):
Expand Down Expand Up @@ -876,7 +876,7 @@ def check_telescope_defintion(tel_def: dict):


def run_sim_config(
obs_spec: Optional[dict] = None,
sim_config: Optional[dict] = None,
config_path: Optional[str] = None,
spacetrack_path: Optional[str] = None,
) -> Tuple[Observation, str]:
Expand All @@ -892,24 +892,24 @@ def run_sim_config(
print(datetime.now())

if config_path:
obs_spec = load_config(config_path, config_type="sim")
elif obs_spec is None:
raise ValueError("obs_spec or config_path must be defined.")
sim_config = load_config(config_path, config_type="sim")
elif sim_config is None:
raise ValueError("sim_config or config_path must be defined.")

rfi_def = get_rfi_definitions()
obs_spec["rfi_sources"] = deep_update(obs_spec["rfi_sources"], rfi_def)
sim_config["rfi_sources"] = deep_update(sim_config["rfi_sources"], rfi_def)

if not check_telescope_defintion(obs_spec["telescope"]):
tel_def = get_telescope_definitions(obs_spec["telescope"]["name"])
obs_spec["telescope"] = deep_update(obs_spec["telescope"], tel_def)
if not check_telescope_defintion(sim_config["telescope"]):
tel_def = get_telescope_definitions(sim_config["telescope"]["name"])
sim_config["telescope"] = deep_update(sim_config["telescope"], tel_def)

obs = load_obs(obs_spec)
add_astro_sources(obs, obs_spec)
add_satellite_sources(obs, obs_spec)
if obs_spec["rfi_sources"]["tle_satellite"]["max_n_sat"] != 0 and spacetrack_path:
add_tle_satellite_sources(obs, obs_spec, spacetrack_path)
add_stationary_sources(obs, obs_spec)
add_gains(obs, obs_spec)
obs = load_obs(sim_config)
add_astro_sources(obs, sim_config)
add_satellite_sources(obs, sim_config)
if sim_config["rfi_sources"]["tle_satellite"]["max_n_sat"] != 0 and spacetrack_path:
add_tle_satellite_sources(obs, sim_config, spacetrack_path)
add_stationary_sources(obs, sim_config)
add_gains(obs, sim_config)

# Change this for calculating all RFI sources
if obs.n_rfi_satellite > 0:
Expand All @@ -920,39 +920,39 @@ def run_sim_config(
print(obs)

obs_name = mk_obs_name(
obs_spec["output"]["prefix"], obs, obs_spec["output"]["suffix"]
sim_config["output"]["prefix"], obs, sim_config["output"]["suffix"]
)
save_path, zarr_path, ms_path = mk_obs_dir(
obs_spec["output"]["path"], obs_name, obs_spec["output"]["overwrite"]
sim_config["output"]["path"], obs_name, sim_config["output"]["overwrite"]
)

input_path = os.path.join(save_path, "input_data")

os.makedirs(input_path, exist_ok=True)
save_inputs(obs, obs_spec, input_path)
save_inputs(obs, sim_config, input_path)

print()
print(f"Writing data to : {save_path}")

plot_cond = np.array(
[
obs_spec["diagnostics"]["rfi_seps"],
obs_spec["diagnostics"]["src_alt"],
obs_spec["diagnostics"]["uv_cov"],
sim_config["diagnostics"]["rfi_seps"],
sim_config["diagnostics"]["src_alt"],
sim_config["diagnostics"]["uv_cov"],
]
)
if np.any(plot_cond):
plot_diagnostics(obs, obs_spec, save_path)
plot_diagnostics(obs, sim_config, save_path)
else:
print("\nNo diagnostic plots.")

save_data(obs, obs_spec, zarr_path, ms_path)
save_data(obs, sim_config, zarr_path, ms_path)

if obs_spec["output"]["accumulate_ms"] is not None:
if sim_config["output"]["accumulate_ms"] is not None:
from tabsim.write import add_to_ms

xds = xr.open_zarr(zarr_path)
add_to_ms(xds, obs_spec["output"]["accumulate_ms"])
add_to_ms(xds, sim_config["output"]["accumulate_ms"])

end = datetime.now()
print()
Expand All @@ -966,8 +966,8 @@ def run_sim_config(
sys.stdout = backup

if (
not obs_spec["output"]["keep_sim"]
and obs_spec["output"]["accumulate_ms"] is not None
not sim_config["output"]["keep_sim"]
and sim_config["output"]["accumulate_ms"] is not None
):
shutil.rmtree(save_path)
return obs, save_path
Expand Down
Loading
Loading