diff --git a/src/scheduler/types/Scheduler.py b/src/scheduler/types/Scheduler.py index 5545bd0..76fe80a 100644 --- a/src/scheduler/types/Scheduler.py +++ b/src/scheduler/types/Scheduler.py @@ -418,12 +418,12 @@ def bucket_stats(vals): avg_metric = self.set_schedule_values(best_schedule) return avg_metric - def best_sol_for_stem(self, stem, cost_benefits, cost_threshold, target_fps_options): + def best_sol_for_stem(self, stem, cost_benefits, cost_threshold, target_fps_options, mode): func_init, agg_func = self.func_init, self.agg_func dp = {} ops, updates = 0, 0 for i, app in enumerate(self.apps): - num_frozen_options = sorted(app["accuracies"].keys()) + num_frozen_options = self._get_num_frozen_options(app, mode) stem_ptr = 0 min_objective_by_budget = [] if i > 0: @@ -518,7 +518,7 @@ def stems_scheduler(self, cost_threshold, mode): if stem.cost > cost_threshold: continue num_stems_in_budget += 1 - best_stem_sol, (ops_, updates_) = self.best_sol_for_stem(stem, cost_benefits, cost_threshold, target_fps_options) + best_stem_sol, (ops_, updates_) = self.best_sol_for_stem(stem, cost_benefits, cost_threshold, target_fps_options, mode) ops += ops_ updates += updates_ if scheduler_util.sol_better(best_result, best_stem_sol):