@@ -258,18 +258,18 @@ def generate_auc_comparison(df, output_file="lp_auc_comparison.pdf"):
258258 ], # Use color from specified palette
259259 )
260260
261- plt . title ( "AUC Comparison" , fontsize = 30 )
261+ # Removed plot title
262262 plt .xlabel ("Dataset (Countries)" , fontsize = 30 )
263263 plt .ylabel ("AUC" , fontsize = 30 )
264- plt .xticks (x_positions , datasets , rotation = 45 , fontsize = 30 )
264+ plt .xticks (x_positions , datasets , rotation = 0 , fontsize = 30 )
265265 plt .yticks (fontsize = 30 )
266266 plt .ylim (0 , 1.0 )
267267 plt .legend (
268- title = "Algorithms" ,
268+ # title="Algorithms",
269269 loc = "upper left" ,
270270 bbox_to_anchor = (1 , 1 ),
271271 fontsize = 25 ,
272- title_fontsize = 25 ,
272+ # title_fontsize=25,
273273 )
274274
275275 # Remove grid lines
@@ -337,7 +337,8 @@ def generate_train_time_comparison(df, output_file="lp_train_time_comparison.pdf
337337 if not dataset_row .empty and not pd .isna (
338338 dataset_row ["TrainTime" ].values [0 ]
339339 ):
340- train_time_values .append (dataset_row ["TrainTime" ].values [0 ])
340+ # Convert ms to s
341+ train_time_values .append (dataset_row ["TrainTime" ].values [0 ] / 1000 )
341342 else :
342343 train_time_values .append (0 )
343344
@@ -352,18 +353,17 @@ def generate_train_time_comparison(df, output_file="lp_train_time_comparison.pdf
352353 ], # Use color from specified palette
353354 )
354355
355- # Set chart title and labels
356- plt .title ("Train Time Comparison" , fontsize = 30 )
356+ # Removed plot title
357357 plt .xlabel ("Dataset (Countries)" , fontsize = 30 )
358- plt .ylabel ("Train Time (ms )" , fontsize = 28 )
359- plt .xticks (x_positions , datasets , rotation = 45 , fontsize = 30 )
358+ plt .ylabel ("Train Time (s )" , fontsize = 28 )
359+ plt .xticks (x_positions , datasets , rotation = 0 , fontsize = 30 )
360360 plt .yticks (fontsize = 28 )
361361 plt .legend (
362- title = "Algorithms" ,
362+ # title="Algorithms",
363363 loc = "upper left" ,
364364 bbox_to_anchor = (1 , 1 ),
365365 fontsize = 25 ,
366- title_fontsize = 25 ,
366+ # title_fontsize=25,
367367 )
368368
369369 # Remove grid lines
@@ -391,10 +391,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
391391 subset = ["Actual_Total_MB" , "Theoretical_Total_MB" ], how = "all"
392392 )
393393
394+ # Convert MB to GB for plotting
395+ df_filtered = df_filtered .copy ()
396+ df_filtered ["Theoretical_Total_GB" ] = df_filtered ["Theoretical_Total_MB" ] / 1024
397+ df_filtered ["Actual_Total_GB" ] = df_filtered ["Actual_Total_MB" ] / 1024
394398 # Create a grouped DataFrame
395399 comparison_data = (
396400 df_filtered .groupby (["Dataset" , "Algorithm" ])
397- .agg ({"Theoretical_Total_MB " : "mean" , "Actual_Total_MB " : "mean" })
401+ .agg ({"Theoretical_Total_GB " : "mean" , "Actual_Total_GB " : "mean" })
398402 .reset_index ()
399403 )
400404
@@ -425,14 +429,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
425429 for i , algo in enumerate (algorithms ):
426430 algo_data = comparison_data [comparison_data ["Algorithm" ] == algo ]
427431
428- # Actual values
432+ # Actual values (in GB)
429433 actual_values = []
430434 for dataset in datasets :
431435 dataset_row = algo_data [algo_data ["Dataset" ] == dataset ]
432436 if not dataset_row .empty and not pd .isna (
433- dataset_row ["Actual_Total_MB " ].values [0 ]
437+ dataset_row ["Actual_Total_GB " ].values [0 ]
434438 ):
435- actual_values .append (dataset_row ["Actual_Total_MB " ].values [0 ])
439+ actual_values .append (dataset_row ["Actual_Total_GB " ].values [0 ])
436440 else :
437441 actual_values .append (0 )
438442
@@ -446,14 +450,14 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
446450 )
447451 current_pos += 1
448452
449- # Theoretical values
453+ # Theoretical values (in GB)
450454 theoretical_values = []
451455 for dataset in datasets :
452456 dataset_row = algo_data [algo_data ["Dataset" ] == dataset ]
453457 if not dataset_row .empty and not pd .isna (
454- dataset_row ["Theoretical_Total_MB " ].values [0 ]
458+ dataset_row ["Theoretical_Total_GB " ].values [0 ]
455459 ):
456- theoretical_values .append (dataset_row ["Theoretical_Total_MB " ].values [0 ])
460+ theoretical_values .append (dataset_row ["Theoretical_Total_GB " ].values [0 ])
457461 else :
458462 theoretical_values .append (0 )
459463
@@ -467,18 +471,17 @@ def generate_comm_cost_comparison(df, output_file="lp_comm_cost_comparison.pdf")
467471 )
468472 current_pos += 1
469473
470- # Set chart title and labels
471- plt .title ("Communication Cost Comparison" , fontsize = 30 )
474+ # Removed plot title
472475 plt .xlabel ("Dataset (Countries)" , fontsize = 30 )
473- plt .ylabel ("Communication Cost (MB )" , fontsize = 28 )
474- plt .xticks (x_positions , datasets , rotation = 45 , fontsize = 30 )
476+ plt .ylabel ("Communication Cost (GB )" , fontsize = 28 )
477+ plt .xticks (x_positions , datasets , rotation = 0 , fontsize = 30 )
475478 plt .yticks (fontsize = 28 )
476479 plt .legend (
477- title = "Algorithms" ,
480+ # title="Algorithms",
478481 loc = "upper left" ,
479482 bbox_to_anchor = (1 , 1 ),
480- fontsize = 22 ,
481- title_fontsize = 25 ,
483+ fontsize = 18 ,
484+ # title_fontsize=25,
482485 )
483486
484487 # Remove grid lines
0 commit comments