diff --git a/examples/Waveforms+Polarities.py b/examples/Waveforms+Polarities.py old mode 100755 new mode 100644 index d059e694..455e65f2 --- a/examples/Waveforms+Polarities.py +++ b/examples/Waveforms+Polarities.py @@ -8,7 +8,7 @@ from mtuq.graphics import plot_data_greens2, plot_beachball, plot_polarities, plot_misfit_lune from mtuq.grid import FullMomentTensorGridSemiregular from mtuq.grid_search import grid_search -from mtuq.misfit import WaveformMisfit, PolarityMisfit +from mtuq.misfit import WaveformMisfit, PolarityMisfit, polarities_from_dict from mtuq.process_data import ProcessData from mtuq.util import fullpath, merge_dicts, save_json from mtuq.util.cap import parse_station_codes, Trapezoid @@ -108,7 +108,7 @@ "NSKI": +1, "PERI": +1, "SOLD": 0, - "TUPA": 1, + "TUPA": +1, } @@ -218,9 +218,7 @@ if comm.rank==0: print('Evaluating polarity misfit...\n') - polarities = np.zeros(len(stations)) - for _i, station in enumerate(stations): - polarities[_i] = polarities_dict[station.station] + polarities = polarities_from_dict(polarities_dict, stations) results_polarity = grid_search( polarities, greens_bw, polarity_misfit, origin, grid) diff --git a/mtuq/misfit/__init__.py b/mtuq/misfit/__init__.py index 087642d6..8819b9c1 100644 --- a/mtuq/misfit/__init__.py +++ b/mtuq/misfit/__init__.py @@ -1,7 +1,7 @@ from mtuq.misfit.waveform import WaveformMisfit -from mtuq.misfit.polarity import PolarityMisfit +from mtuq.misfit.polarity import PolarityMisfit, polarities_from_dict # # for backward compatibility diff --git a/mtuq/misfit/polarity.py b/mtuq/misfit/polarity.py index 5b3d7830..3fa9bd62 100644 --- a/mtuq/misfit/polarity.py +++ b/mtuq/misfit/polarity.py @@ -362,6 +362,26 @@ def _model_type(greens): return model_type +def polarities_from_dict(dict_polarity, stations): + """ + Converts a dictionary of polarities to a NumPy array based on the provided stations. + + Args: + dict_polarity (dict): Dictionary mapping station names to polarity values. + stations (list): List of station objects. + + Returns: + numpy.ndarray: NumPy array containing polarity values corresponding to the stations. + """ + + polarities = np.zeros(len(stations)) + for i, station in enumerate(stations): + station_name = station.station + if station_name in dict_polarity: + polarities[i] = dict_polarity[station_name] + else: + print(f'Station {station_name} not found in the dictionary') + return polarities def _check(greens, method): return diff --git a/mtuq/util/__init__.py b/mtuq/util/__init__.py index d283f9c0..ca3fa005 100644 --- a/mtuq/util/__init__.py +++ b/mtuq/util/__init__.py @@ -352,7 +352,6 @@ def dataarray_idxmax(da, warnings=True): da = da[0] return da.coords - def defaults(kwargs, defaults): for key in defaults: if key not in kwargs: diff --git a/setup/code_generator.py b/setup/code_generator.py index dd404203..f5a2695a 100755 --- a/setup/code_generator.py +++ b/setup/code_generator.py @@ -513,7 +513,7 @@ "NSKI": +1, "PERI": +1, "SOLD": 0, - "TUPA": 1, + "TUPA": +1, } @@ -1573,9 +1573,7 @@ if comm.rank==0: print('Evaluating polarity misfit...\\n') - polarities = np.zeros(len(stations)) - for _i, station in enumerate(stations): - polarities[_i] = polarities_dict[station.station] + polarities = polarities_from_dict(polarities_dict, stations) results_polarity = grid_search( polarities, greens_bw, polarity_misfit, origin, grid) @@ -1990,7 +1988,7 @@ def isclose(a, b, atol=1.e6, rtol=1.e-6): 'plot_beachball', 'plot_beachball, plot_polarities', 'from mtuq.misfit import Misfit', - 'from mtuq.misfit import WaveformMisfit, PolarityMisfit', + 'from mtuq.misfit import WaveformMisfit, PolarityMisfit, polarities_from_dict', )) file.write(Docstring_WaveformsPolarities) file.write(Paths_Syngine)