Skip to content

Commit 5068165

Browse files
committed
refined LP
1 parent e9dfbfd commit 5068165

4 files changed

Lines changed: 28 additions & 25 deletions

File tree

benchmark/figure/LP_comm_costs/extract_LP_log.py

Lines changed: 28 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -258,18 +258,18 @@ def generate_auc_comparison(df, output_file="lp_auc_comparison.pdf"):
258258
], # Use color from specified palette
259259
)
260260

261-
plt.title("AUC Comparison", fontsize=30)
261+
# Removed plot title
262262
plt.xlabel("Dataset (Countries)", fontsize=30)
263263
plt.ylabel("AUC", fontsize=30)
264-
plt.xticks(x_positions, datasets, rotation=45, fontsize=30)
264+
plt.xticks(x_positions, datasets, rotation=0, fontsize=30)
265265
plt.yticks(fontsize=30)
266266
plt.ylim(0, 1.0)
267267
plt.legend(
268-
title="Algorithms",
268+
# title="Algorithms",
269269
loc="upper left",
270270
bbox_to_anchor=(1, 1),
271271
fontsize=25,
272-
title_fontsize=25,
272+
#title_fontsize=25,
273273
)
274274

275275
# Remove grid lines
@@ -337,7 +337,8 @@ def generate_train_time_comparison(df, output_file="lp_train_time_comparison.pdf
337337
if not dataset_row.empty and not pd.isna(
338338
dataset_row["TrainTime"].values[0]
339339
):
340-
train_time_values.append(dataset_row["TrainTime"].values[0])
340+
# Convert ms to s
341+
train_time_values.append(dataset_row["TrainTime"].values[0] / 1000)
341342
else:
342343
train_time_values.append(0)
343344

@@ -352,18 +353,17 @@ def generate_train_time_comparison(df, output_file="lp_train_time_comparison.pdf
352353
], # Use color from specified palette
353354
)
354355

355-
# Set chart title and labels
356-
plt.title("Train Time Comparison", fontsize=30)
356+
# Removed plot title
357357
plt.xlabel("Dataset (Countries)", fontsize=30)
358-
plt.ylabel("Train Time (ms)", fontsize=28)
359-
plt.xticks(x_positions, datasets, rotation=45, fontsize=30)
358+
plt.ylabel("Train Time (s)", fontsize=28)
359+
plt.xticks(x_positions, datasets, rotation=0, fontsize=30)
360360
plt.yticks(fontsize=28)
361361
plt.legend(
362-
title="Algorithms",
362+
# title="Algorithms",
363363
loc="upper left",
364364
bbox_to_anchor=(1, 1),
365365
fontsize=25,
366-
title_fontsize=25,
366+
#title_fontsize=25,
367367
)
368368

369369
# Remove grid lines
@@ -391,10 +391,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
391391
subset=["Actual_Total_MB", "Theoretical_Total_MB"], how="all"
392392
)
393393

394+
# Convert MB to GB for plotting
395+
df_filtered = df_filtered.copy()
396+
df_filtered["Theoretical_Total_GB"] = df_filtered["Theoretical_Total_MB"] / 1024
397+
df_filtered["Actual_Total_GB"] = df_filtered["Actual_Total_MB"] / 1024
394398
# Create a grouped DataFrame
395399
comparison_data = (
396400
df_filtered.groupby(["Dataset", "Algorithm"])
397-
.agg({"Theoretical_Total_MB": "mean", "Actual_Total_MB": "mean"})
401+
.agg({"Theoretical_Total_GB": "mean", "Actual_Total_GB": "mean"})
398402
.reset_index()
399403
)
400404

@@ -425,14 +429,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
425429
for i, algo in enumerate(algorithms):
426430
algo_data = comparison_data[comparison_data["Algorithm"] == algo]
427431

428-
# Actual values
432+
# Actual values (in GB)
429433
actual_values = []
430434
for dataset in datasets:
431435
dataset_row = algo_data[algo_data["Dataset"] == dataset]
432436
if not dataset_row.empty and not pd.isna(
433-
dataset_row["Actual_Total_MB"].values[0]
437+
dataset_row["Actual_Total_GB"].values[0]
434438
):
435-
actual_values.append(dataset_row["Actual_Total_MB"].values[0])
439+
actual_values.append(dataset_row["Actual_Total_GB"].values[0])
436440
else:
437441
actual_values.append(0)
438442

@@ -446,14 +450,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
446450
)
447451
current_pos += 1
448452

449-
# Theoretical values
453+
# Theoretical values (in GB)
450454
theoretical_values = []
451455
for dataset in datasets:
452456
dataset_row = algo_data[algo_data["Dataset"] == dataset]
453457
if not dataset_row.empty and not pd.isna(
454-
dataset_row["Theoretical_Total_MB"].values[0]
458+
dataset_row["Theoretical_Total_GB"].values[0]
455459
):
456-
theoretical_values.append(dataset_row["Theoretical_Total_MB"].values[0])
460+
theoretical_values.append(dataset_row["Theoretical_Total_GB"].values[0])
457461
else:
458462
theoretical_values.append(0)
459463

@@ -467,18 +471,17 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
467471
)
468472
current_pos += 1
469473

470-
# Set chart title and labels
471-
plt.title("Communication Cost Comparison", fontsize=30)
474+
# Removed plot title
472475
plt.xlabel("Dataset (Countries)", fontsize=30)
473-
plt.ylabel("Communication Cost (MB)", fontsize=28)
474-
plt.xticks(x_positions, datasets, rotation=45, fontsize=30)
476+
plt.ylabel("Communication Cost (GB)", fontsize=28)
477+
plt.xticks(x_positions, datasets, rotation=0, fontsize=30)
475478
plt.yticks(fontsize=28)
476479
plt.legend(
477-
title="Algorithms",
480+
# title="Algorithms",
478481
loc="upper left",
479482
bbox_to_anchor=(1, 1),
480-
fontsize=22,
481-
title_fontsize=25,
483+
fontsize=18,
484+
#title_fontsize=25,
482485
)
483486

484487
# Remove grid lines
-1.67 KB
Binary file not shown.
-1.12 KB
Binary file not shown.
-1.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)