Skip to content
This repository was archived by the owner on Aug 18, 2020. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 21 additions & 5 deletions magnolia/python/analysis/bss_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ def evaluate(input_path, output_csv_file, target_stype=None, eval_sr=8000, num_s
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/lab41/in_sample_test.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/lab41/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/lab41/out_of_sample_test.csv'],

['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/large_lab41/in_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/in_sample_test.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/large_lab41/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/out_of_sample_test.csv'],

['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/chimera/in_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_in_sample_test.csv',
'mi'],
Expand All @@ -123,13 +125,27 @@ def evaluate(input_path, output_csv_file, target_stype=None, eval_sr=8000, num_s
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/chimera/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/dc_out_of_sample_test.csv',
'dc'],

['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/mask_sce/in_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_in_sample_test.csv',
'mi'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/mask_sce/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_out_of_sample_test.csv',
'mi'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/mask_sce/in_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_in_sample_test.csv',
'dc'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/mask_sce/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_out_of_sample_test.csv',
'dc'],

['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/snmf/in_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/snmf/in_sample_test.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/snmf/out_of_sample_test',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/snmf/out_of_sample_test.csv']
]

args = args[8:]
args = args[8:12]

# Parallel
#processes = []
Expand All @@ -142,9 +158,9 @@ def evaluate(input_path, output_csv_file, target_stype=None, eval_sr=8000, num_s

# Parallel
#pool = mp.Pool(processes=min(len(args), os.cpu_count() - 1))
pool = mp.Pool(processes=2)
pool.starmap(evaluate, args)
#pool = mp.Pool(processes=2)
#pool.starmap(evaluate, args)

# Sequential
#for arg in args:
# evaluate(*arg)
for arg in args:
evaluate(*arg)
117 changes: 69 additions & 48 deletions magnolia/python/analysis/comparison_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,20 @@ def format_dae_columns(df):
cols[5] = 'Input_SNR'
cols[6] = 'Input_SDR'
cols[7] = 'Output_SDR'

df.columns = cols


def load_dataframes(models):
for model in models:
if 'in_set' in model:
model['in_set_df'] = pd.read_csv(model['in_set'])
if model['name'] == 'DAE':
format_dae_columns(model['in_set_df'])
#if model['name'] == 'DAE':
# format_dae_columns(model['in_set_df'])
if 'out_of_set' in model:
model['out_of_set_df'] = pd.read_csv(model['out_of_set'])
if model['name'] == 'DAE':
format_dae_columns(model['out_of_set_df'])
#if model['name'] == 'DAE':
# format_dae_columns(model['out_of_set_df'])


def error_on_the_mean(x):
Expand All @@ -50,7 +50,7 @@ def make_sdr_delta_versus_noise_source_plot(models, df_base_name):
df_name = '{}_df'.format(df_base_name)
mean_multiindex_name = ('SDR_Improvement', 'mean')
eotm_multiindex_name = ('SDR_Improvement', 'error_on_the_mean')

all_groups = {}
all_colors = {}
all_names = []
Expand All @@ -68,25 +68,27 @@ def make_sdr_delta_versus_noise_source_plot(models, df_base_name):
all_label_df = all_groups[model['name']]
else:
all_label_df = all_label_df.merge(all_groups[model['name']], how='outer')

labels = all_label_df['Noise_Type'].unique()
n_groups = len(labels)
del all_label_df

# create plot
fig, ax = plt.subplots(figsize=(8, 6))
index = np.arange(n_groups)
bar_width = 0.25
plt.xlim(-0.5, n_groups + 0.5)
bar_width = (n_groups + 1)/(1.15*n_groups*len(models))
#bar_width = 0.15
opacity = 0.8

offset = 0
all_rects = []
for entry_name in all_names:
color = 'b'
groups = all_groups[entry_name]
if entry_name in all_colors:
color = all_colors[entry_name]

#male_means = groups[groups['Speaker_Sex'] == 'M']
#male_means = male_means[male_means['Noise_Type'] == labels].fillna(0)
male_means = groups[groups['Noise_Type'] == labels].fillna(0)
Expand All @@ -98,36 +100,38 @@ def make_sdr_delta_versus_noise_source_plot(models, df_base_name):
color=color,
label=entry_name,
yerr=male_errors)

offset += 1

for i in range(len(labels)):
labels[i] = labels[i].replace('_', ' ')
plt.xticks(index + (len(all_names)/2 - 0.5)*bar_width, labels)
for tick in ax.get_xticklabels():
tick.set_rotation(75)
tick.set_fontsize(12)
plt.xlabel('Noise Type')
plt.ylabel('SDR Improvement')
plt.ylabel('SDR Improvement (dB)')
plt.title('SDR Improvement Versus Noise Type', fontsize=20)
ax.xaxis.label.set_size(15)
ax.yaxis.label.set_size(15)
ylim = [-0.5, ax.get_ylim()[1]]

ylim = [-0.5, 1.3*ax.get_ylim()[1]]
#ylim[0] = -0.5
ax.set_ylim(ylim)
#plt.axis([0, 11, -.5, 16])
plt.legend(fontsize=12, edgecolor='black')
plt.legend(fontsize=12, edgecolor='black',
loc='upper center', ncol=3, mode='expand')
#ax.legend(bbox_to_anchor=(1.5, 1.5))
plt.tight_layout()

plt.savefig('{}_sdr_delta_versus_noise_type.pdf'.format(df_base_name), format='pdf')


def make_sdr_delta_versus_input_snr_plot(models, df_base_name, bins):
df_name = '{}_df'.format(df_base_name)
mean_multiindex_name = ('SDR_Improvement', 'mean')
eotm_multiindex_name = ('SDR_Improvement', 'error_on_the_mean')

all_groups = {}
all_colors = {}
all_names = []
Expand All @@ -145,25 +149,25 @@ def make_sdr_delta_versus_input_snr_plot(models, df_base_name, bins):
all_label_df = all_groups[model['name']]
else:
all_label_df = all_label_df.merge(all_groups[model['name']], how='outer')

labels = all_label_df['Input_SNR_Bin'].unique()
n_groups = len(labels)
del all_label_df

# create plot
fig, ax = plt.subplots(figsize=(8, 6))
index = np.arange(n_groups)
bar_width = 0.25
bar_width = (bins[-1] - bins[0])/(1.15*n_groups*len(models))
opacity = 0.8

offset = 0
all_rects = []
for entry_name in all_names:
color = 'b'
groups = all_groups[entry_name]
if entry_name in all_colors:
color = all_colors[entry_name]

#male_means = groups[groups['Speaker_Sex'] == 'M']
#male_means = male_means[male_means['Input_SNR_Bin'] == labels].fillna(0)
male_means = groups[groups['Input_SNR_Bin'] == labels].fillna(0)
Expand All @@ -175,9 +179,9 @@ def make_sdr_delta_versus_input_snr_plot(models, df_base_name, bins):
color=color,
label=entry_name,
yerr=male_errors)

offset += 1

print_labels = []
for i in range(len(labels)):
#print_labels.append('[{}, {})'.format(i - 5, i - 4))
Expand All @@ -187,19 +191,20 @@ def make_sdr_delta_versus_input_snr_plot(models, df_base_name, bins):
tick.set_rotation(0)
tick.set_fontsize(12)
#plt.xlabel('Input SNR Range')
plt.xlabel('Input SNR')
plt.ylabel('SDR Improvement')
plt.xlabel('Input SNR (dB)')
plt.ylabel('SDR Improvement (dB)')
plt.title('SDR Improvement Versus Input SNR', fontsize=20)
ax.xaxis.label.set_size(15)
ax.yaxis.label.set_size(15)
ylim = [-0.5, ax.get_ylim()[1]]

ylim = [-0.5, 1.3*ax.get_ylim()[1]]
#ylim[0] = -0.5
ax.set_ylim(ylim)
#plt.axis([0, 11, -.5, 16])
plt.legend(fontsize=12, edgecolor='black')
plt.legend(fontsize=12, edgecolor='black',
loc='upper center', ncol=3, mode='expand')
plt.tight_layout()

plt.savefig('{}_sdr_delta_versus_input_snr.pdf'.format(df_base_name), format='pdf')


Expand All @@ -211,39 +216,55 @@ def main():
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/snmf/out_of_sample_test_sdr_summary.csv',
'color': '#98C1D9'
},
#{
# 'name': 'DAE',
# 'out_of_set': '/data/fs4/home/pgamble/Magnolia/Denoising/Autoencoder/Final Results/eval_test_A.csv',
# 'color': '#E0FBFC'
#},
#{
# 'name': 'Chimera MI',
# 'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_in_sample_test_sdr_summary.csv',
# 'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_out_of_sample_test_sdr_summary.csv',
# 'color': '#3D5A80'
#},
{
'name': 'DC',#'Chimera DC',
'name': 'DAE',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/dae/ae_in_sample_test.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/dae/ae_out_of_sample_test.csv',
'color': '#828C51'#'#E0FBFC'
},
{
'name': 'DC + MI (MI)',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_in_sample_test_sdr_summary.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_out_of_sample_test_sdr_summary.csv',
'color': '#3D5A80'
},
{
'name': 'DC + MI (C)',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/dc_in_sample_test_sdr_summary.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/dc_out_of_sample_test_sdr_summary.csv',
'color': '#0C0A3E'
},
{
'name': 'SCE + MI (MI)',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_in_sample_test_sdr_summary.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_out_of_sample_test_sdr_summary.csv',
'color': '#CA054D'
},
{
'name': 'SCE + MI (C)',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_in_sample_test_sdr_summary.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_out_of_sample_test_sdr_summary.csv',
'color': '#393E41'
},
{
'name': 'SCE',
'in_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/in_sample_test_sdr_summary.csv',
'out_of_set': '/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/out_of_sample_test_sdr_summary.csv',
'color': '#A4303F'
},
]
# TODO: the input SNR range should be determined automatically
bins = np.linspace(-5, 5, 11)
bins[-1] = 1.02*bins[-1]

load_dataframes(models)

make_sdr_delta_versus_input_snr_plot(models, 'out_of_set', bins)
make_sdr_delta_versus_noise_source_plot(models, 'out_of_set')
#make_sdr_delta_versus_sex_plot(models, 'out_of_set')

make_sdr_delta_versus_input_snr_plot(models, 'in_set', bins)
make_sdr_delta_versus_noise_source_plot(models, 'in_set')


if __name__ == '__main__':
main()
main()
15 changes: 14 additions & 1 deletion magnolia/python/analysis/make_sdr_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,19 @@ def main():
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/out_of_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/out_of_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/large_lab41/out_of_sample_test_sdr_summary.csv'],

['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_in_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/in_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_in_sample_test_sdr_summary.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_out_of_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/out_of_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/mi_out_of_sample_test_sdr_summary.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_in_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/in_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_in_sample_test_sdr_summary.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_out_of_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/out_of_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/mask_sce/dc_out_of_sample_test_sdr_summary.csv'],
['/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_in_sample_test.csv',
'/local_data/magnolia/pipeline_data/date_2017_09_27_time_13_25/aux/in_sample_test_mixes.csv',
'/local_data/magnolia/experiment_data/date_2017_09_28_time_13_14/aux/evaluations/bss/chimera/mi_in_sample_test_sdr_summary.csv'],
Expand All @@ -77,7 +90,7 @@ def main():
# args = args[2:4]
#args = args[4:6]
#args = args[6:8]
args = args[8:]
#args = args[8:]
for arg in args:
make_nice_table(*arg)

Expand Down
Binary file not shown.
Binary file not shown.
Loading