From cf8a7fdf012b4149d153f45c98519d619f564ee5 Mon Sep 17 00:00:00 2001
From: ShouyangDong
Date: Mon, 23 Jun 2025 19:23:43 +0800
Subject: [PATCH] add Simulated Annealing

---
 python/tvm/meta_schedule/XGBgradient.py | 207 +++++++++++++++++-------
 1 file changed, 152 insertions(+), 55 deletions(-)

diff --git a/python/tvm/meta_schedule/XGBgradient.py b/python/tvm/meta_schedule/XGBgradient.py
index ec2da4bcaf9b..d66eb8d004eb 100644
--- a/python/tvm/meta_schedule/XGBgradient.py
+++ b/python/tvm/meta_schedule/XGBgradient.py
@@ -81,7 +81,11 @@ def has_sample_instruction(traces) -> bool:
         If a sample instruction was found
     """
     # Function could potentially be moved to tvm.schedule.Trace
-    sample_instructions = ["SampleComputeLocation", "SampleCategorical", "SamplePerfectTile"]
+    sample_instructions = [
+        "SampleComputeLocation",
+        "SampleCategorical",
+        "SamplePerfectTile",
+    ]
 
     for trace in traces:
         for inst in trace.insts:
@@ -94,7 +98,7 @@ def count_splits(n, k, m):
     """
    Count the number of ways to split the integer n into k positive integers
    whose product equals n and whose last factor is at most m.
-
+
    Parameters:
    n: the positive integer to decompose
    k: the number of factors
@@ -136,7 +140,9 @@ def get_sample_number(traces):
         elif inst.kind.name == "SamplePerfectTile":
             title_num = inst.attrs[0].value
             inter_most = inst.attrs[1].value
-            extend_list = [extend.value for extend in trace.decisions[inst]]
+            extend_list = [
+                extend.value for extend in trace.decisions[inst]
+            ]
             extend = np.prod(extend_list)
             choices = count_splits(extend, title_num, inter_most)
             sample_num *= choices
@@ -254,7 +260,9 @@ def modify_multi_level_tiling_node(self, new_coordinates):
                 new_tile_part = new_coordinates[coord_idx : coord_idx + length]
                 original_prod = prod(tile_config)
                 new_prod = prod(new_tile_part)
-                last_element = original_prod // new_prod if new_prod != 0 else 0
+                last_element = (
+                    original_prod // new_prod if new_prod != 0 else 0
+                )
                 configs[cfg_idx][1] = list(new_tile_part) + [last_element]
                 coord_idx += length
             elif schedule_name == "SampleCategorical":
@@ -347,7 +355,10 @@ def assemble_candidates(picks):
         The list of MeasureCandidates
     """
     return [
-        ms.MeasureCandidate(sch, ms.arg_info.ArgInfo.from_entry_func(sch.mod, remove_preproc=True))
+        ms.MeasureCandidate(
+            sch,
+            ms.arg_info.ArgInfo.from_entry_func(sch.mod, remove_preproc=True),
+        )
         for sch in picks
     ]
@@ -438,7 +449,9 @@ def __init__(
             ),
             task_name=task_name,
         )
-        self.design_spaces = [space.trace for space in self.context.generate_design_space()]
+        self.design_spaces = [
+            space.trace for space in self.context.generate_design_space()
+        ]
 
         self.num_trials_per_iter = num_trials_per_iter
         self.max_trials = max_trials
@@ -446,7 +459,9 @@ def __init__(
         self.sample_init_population = tvm.get_global_func(
             "meta_schedule.SearchStrategyEvolutionarySearchSampleInitPopulation"
         )
-        self.database = Database.create("json", work_dir=self.tmpdir, module_equality="structural")
+        self.database = Database.create(
+            "json", work_dir=self.tmpdir, module_equality="structural"
+        )
         self.context.pre_tuning(
             max_trials=self.max_trials,
             num_trials_per_iter=self.num_trials_per_iter,
@@ -455,7 +470,9 @@ def __init__(
             cost_model=self.model,
         )
 
-        self.builder = Builder.create("local", max_workers=os.cpu_count(), timeout_sec=10.0)
+        self.builder = Builder.create(
+            "local", max_workers=os.cpu_count(), timeout_sec=10.0
+        )
         self.builder_results = []
 
         self.warmup_iter = warmup_iter
@@ -471,9 +488,13 @@ def get_sample_records(self, number):
         """
         logging.debug("Sampling Init Population")
 
-        raw_states = self.sample_init_population(self.context.search_strategy, number)
+        raw_states = self.sample_init_population(
+            self.context.search_strategy, number
+        )
         states = remove_duplicates_and_measured_schedules(raw_states)
-        builder_inputs = [ms.builder.BuilderInput(state.mod, self.target) for state in states]
+        builder_inputs = [
+            ms.builder.BuilderInput(state.mod, self.target) for state in states
+        ]
         builder_results = self.builder.build(builder_inputs)
         tuning_records = []
         workload = ms.database.Workload(self.context.mod)
@@ -490,7 +511,7 @@ def get_sample_records(self, number):
                 )
                 runner_inputs.append(runner_input)
                 valid_indices.append(i)
-
+
         if not runner_inputs:
             return
@@ -511,7 +532,11 @@ def process_result(item):
             return record
 
         with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
-            tuning_records = list(executor.map(process_result, zip(valid_indices, runner_results)))
+            tuning_records = list(
+                executor.map(
+                    process_result, zip(valid_indices, runner_results)
+                )
+            )
 
         self.count_total_measured += len(raw_states)
         self.builder_results.extend(builder_results)
@@ -575,19 +600,25 @@ def dgd_search(self, record):
 
             # get all 3 hops and predict/sort by scores
             candidate_inputs.append(record.as_measure_candidate())
-            candidate_scores = self.model.predict(self.context, candidate_inputs)
-            new_base, tmp_measured_inputs, tmp_measured_results = self.dgd_move(
-                base_result,
-                candidate_scores[0],
-                candidate_scores[1:],
-                all_neighbors,
+            candidate_scores = self.model.predict(
+                self.context, candidate_inputs
+            )
+            new_base, tmp_measured_inputs, tmp_measured_results = (
+                self.dgd_move(
+                    base_result,
+                    candidate_scores[0],
+                    candidate_scores[1:],
+                    all_neighbors,
+                )
             )
 
             if (
                 self.count_total_measured >= self.max_trials
                 or time.time() - self.start_time >= self.max_tuning_time
             ):
-                self.model.update(self.context, tmp_measured_inputs, tmp_measured_results)
+                self.model.update(
+                    self.context, tmp_measured_inputs, tmp_measured_results
+                )
                 return new_base, measured_inputs, measured_results
 
             measured_inputs.extend(tmp_measured_inputs)
@@ -617,17 +648,23 @@ def dgd_move(
         base_cost = np.mean([v.value for v in base_result])
         self.measured_throughputs_.append(1 / base_cost)
         # Filter scores and get the indices of scores that meet the threshold
-        filtered_indices = np.where(np.array(candidate_scores) >= score_threshold)[0]
+        filtered_indices = np.where(
+            np.array(candidate_scores) >= score_threshold
+        )[0]
         # Sort the filtered indices based on their scores in descending order
-        sorted_indices = filtered_indices[np.argsort(-candidate_scores[filtered_indices])]
+        sorted_indices = filtered_indices[
+            np.argsort(-candidate_scores[filtered_indices])
+        ]
         next_base = None
         measured_inputs = []
         measured_results = []
         index_slide = 0
-
+
         while index_slide < len(sorted_indices) and not next_base:
-            slide_window_indices = sorted_indices[index_slide : index_slide + self.slide_window_size]
+            slide_window_indices = sorted_indices[
+                index_slide : index_slide + self.slide_window_size
+            ]
             selected_records = [records[i] for i in slide_window_indices]
             selected_candidate_inputs = [
                 record.as_measure_candidate() for record in selected_records
             ]
@@ -694,17 +731,25 @@ def process_update_record(item):
                     logging.debug("Early stop in current window")
                     break
 
-            selected_candidate_inputs = [selected_candidate_inputs[idx] for idx in valid_indices]
-            slide_window_results = [ms.runner.RunnerResult(rs, None) for rs in run_secs_list]
+            selected_candidate_inputs = [
+                selected_candidate_inputs[idx] for idx in valid_indices
+            ]
+            slide_window_results = [
+                ms.runner.RunnerResult(rs, None) for rs in run_secs_list
+            ]
 
             # Process update_records in parallel
             with ThreadPoolExecutor(max_workers=self.n_jobs) as executor:
                 update_records = list(
-                    executor.map(process_update_record, zip(valid_indices, run_secs_list))
+                    executor.map(
+                        process_update_record,
+                        zip(valid_indices, run_secs_list),
+                    )
                 )
 
             # break after self.max_trials measurements
             if (
-                self.count_total_measured + len(slide_window_inputs) >= self.max_trials
+                self.count_total_measured + len(slide_window_inputs)
+                >= self.max_trials
                 or time.time() - self.start_time >= self.max_tuning_time
             ):
                 tmp_size = min(
@@ -730,15 +775,30 @@ def process_update_record(item):
             sorted_idx = np.argsort(slide_window_costs)
 
             # find a better cost to move, add to visited, and avoid re-visit
+            # A worse solution may also be accepted (simulated annealing)
            for idx in sorted_idx:
-                if (
-                    slide_window_costs[idx] < base_cost
-                    and slide_window_inputs[idx] not in self.visited
-                ):
+                cost = slide_window_costs[idx]
+                candidate_input = slide_window_inputs[idx]
+
+                if candidate_input in self.visited:
+                    continue
+
+                accept = False
+                if cost < base_cost:
+                    accept = True  # Greedily accept an improvement
+                else:
+                    # Simulated-annealing probability of accepting a worse solution
+                    delta = cost - base_cost
+                    prob = np.exp(-delta / 0.02)
+                    if np.random.rand() < prob:
+                        accept = True
+                        logging.debug(
+                            f"[SA] Accepted worse candidate: Δ={delta:.6f}, prob={prob:.4f}"
+                        )
+
+                if accept:
                     next_base = update_records[idx]
-                    logging.debug("Found a better base candidate")
-                    # add to visited
-                    self.visited.add(slide_window_inputs[idx])
+                    self.visited.add(candidate_input)
                     break
             index_slide += self.slide_window_size
         return next_base, measured_inputs, measured_results
@@ -806,7 +866,9 @@ def _process_single_task(self, task):
             return [processor.record]
         return None
 
-    def _apply_changes(self, new_coordinates, indices, changes, record_trace, sample_category):
+    def _apply_changes(
+        self, new_coordinates, indices, changes, record_trace, sample_category
+    ):
         coord_idx = 0
         for counter, config in enumerate(record_trace[0]):
             schedule_name = config[0]
@@ -828,7 +890,9 @@ def _apply_changes(self, new_coordinates, indices, changes, record_trace, sample
                 ):
                     return False
 
-                product_of_dims = prod(new_coordinates[coord_idx : coord_idx + length])
+                product_of_dims = prod(
+                    new_coordinates[coord_idx : coord_idx + length]
+                )
                 if product_of_dims > dim_len or dim_len % product_of_dims != 0:
                     return False
@@ -861,7 +925,9 @@
 
         return True
 
-    def _update_coordinates(self, new_coordinates, indices, changes, coord_idx, length, factors):
+    def _update_coordinates(
+        self, new_coordinates, indices, changes, coord_idx, length, factors
+    ):
         for i, change in enumerate(changes):
             idx = indices[i]
             if coord_idx <= idx < coord_idx + length:
@@ -877,7 +943,9 @@
             # TODO: add cuda constraints
         return True
 
-    def _update_coordinate_with_factor(self, new_coordinates, idx, change, factors):
+    def _update_coordinate_with_factor(
+        self, new_coordinates, idx, change, factors
+    ):
         factor_index = factors.index(new_coordinates[idx])
         new_factor_index = factor_index + change
         if 0 <= new_factor_index < len(factors):
@@ -902,9 +970,13 @@ def XGB_gradient_search(self):
             if not tuning_records:
                 print("No valid initial samples found.")
             else:
-                candidates = [record.as_measure_candidate() for record in tuning_records]
+                candidates = [
+                    record.as_measure_candidate() for record in tuning_records
+                ]
                 results = [
-                    ms.runner.RunnerResult(run_secs=record.run_secs, error_msg=None)
+                    ms.runner.RunnerResult(
+                        run_secs=record.run_secs, error_msg=None
+                    )
                    for record in tuning_records
                 ]
                 self.model.update(self.context, candidates, results)
@@ -913,28 +985,38 @@ def XGB_gradient_search(self):
                 while (
                     record is not None
                     and self.count_total_measured < self.max_trials
-                    and time.time() - self.start_time < self.max_tuning_time
+                    and time.time() - self.start_time
+                    < self.max_tuning_time
                 ):
                     record, _, _ = self.dgd_search(record)
         else:
             # We apply gradient descent with the cost model.
             if self.warmup_iter[0] < 1:
                 self.warmup_iter[0] += 1
-                tuning_records = self.get_sample_records(self.num_trials_per_iter)
+                tuning_records = self.get_sample_records(
+                    self.num_trials_per_iter
+                )
                 if not tuning_records:
                     logging.debug("No valid initial samples found.")
                     return self.database
 
                 # Update the cost model
-                candidates = [record.as_measure_candidate() for record in tuning_records]
+                candidates = [
+                    record.as_measure_candidate() for record in tuning_records
+                ]
                 results = [
-                    ms.runner.RunnerResult(run_secs=record.run_secs, error_msg=None)
+                    ms.runner.RunnerResult(
+                        run_secs=record.run_secs, error_msg=None
+                    )
                     for record in tuning_records
                 ]
                 self.model.update(self.context, candidates, results)
 
                 costs = np.array(
-                    [np.mean([v.value for v in res.run_secs]) for res in tuning_records]
+                    [
+                        np.mean([v.value for v in res.run_secs])
+                        for res in tuning_records
+                    ]
                 )
                 self.measured_throughputs_.extend(1 / np.array(costs))
@@ -950,7 +1032,8 @@ def XGB_gradient_search(self):
                 while (
                     record is not None
                     and self.count_total_measured < self.max_trials
-                    and time.time() - self.start_time < self.max_tuning_time
+                    and time.time() - self.start_time
+                    < self.max_tuning_time
                 ):
                     record, _, _ = self.dgd_search(record)
@@ -960,21 +1043,26 @@ def XGB_gradient_search(self):
                     if total_sample_num > self.init_population_size
                     else total_sample_num
                 )
-
+
                 raw_states = self.sample_init_population(
                     self.context.search_strategy, total_sample_num
                 )
-                candidate_states = remove_duplicates_and_measured_schedules(raw_states)
+                candidate_states = remove_duplicates_and_measured_schedules(
+                    raw_states
+                )
                 candidate_inputs = assemble_candidates(candidate_states)
                 # 2. Use the cost model to score the candidate schedules
-                candidates_score = self.model.predict(self.context, candidate_inputs)
+                candidates_score = self.model.predict(
+                    self.context, candidate_inputs
+                )
                 topk = min(self.n_start, len(candidate_inputs))
                 topk_indices = np.argsort(-candidates_score)[:topk]
                 topk_states = [candidate_states[i] for i in topk_indices]
                 # 3. Build and measure the top-k candidates and turn them into tuning records
                 builder_inputs = [
-                    ms.builder.BuilderInput(state.mod, self.target) for state in topk_states
+                    ms.builder.BuilderInput(state.mod, self.target)
+                    for state in topk_states
                 ]
                 builder_results = self.builder.build(builder_inputs)
                 valid_indices = []
@@ -985,18 +1073,24 @@ def XGB_gradient_search(self):
                         runner_input = ms.runner.RunnerInput(
                             res.artifact_path,
                             device_type=self.target.kind.name,
-                            args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
+                            args_info=ms.arg_info.ArgInfo.from_prim_func(
+                                self.task
+                            ),
                         )
                         runner_inputs.append(runner_input)
                         valid_indices.append(i)
 
                 if not runner_inputs:
-                    logging.debug("No valid candidates found after prediction.")
+                    logging.debug(
+                        "No valid candidates found after prediction."
+                    )
                     return self.database
 
                 # Run the measurements in parallel across multiple processes
                 runner_inputs_2d = list(map(lambda x: [x], runner_inputs))
                 with Pool(self.n_jobs) as pool:
-                    run_secs_list = pool.map(parallel_runner_run, runner_inputs_2d)
+                    run_secs_list = pool.map(
+                        parallel_runner_run, runner_inputs_2d
+                    )
 
                 topk_records = []
                 workload = ms.database.Workload(self.context.mod)
@@ -1006,7 +1100,9 @@ def XGB_gradient_search(self):
                         workload=workload,
                         run_secs=run_secs,
                         target=self.target,
-                        args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
+                        args_info=ms.arg_info.ArgInfo.from_prim_func(
+                            self.task
+                        ),
                     )
                     topk_records.append(record)
@@ -1018,7 +1114,8 @@ def XGB_gradient_search(self):
                 while (
                     record is not None
                     and self.count_total_measured < self.max_trials
-                    and time.time() - self.start_time < self.max_tuning_time
+                    and time.time() - self.start_time
+                    < self.max_tuning_time
                 ):
                     record, _, _ = self.dgd_search(record)
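
For reference, the core of the change above is a Metropolis-style acceptance rule with a fixed temperature of 0.02. A minimal standalone sketch of that rule follows; the helper name accept_candidate and the constant SA_TEMPERATURE are illustrative choices here, not names from the patch:

    import math
    import random

    # Fixed temperature, mirroring the hard-coded 0.02 in dgd_move() above.
    SA_TEMPERATURE = 0.02


    def accept_candidate(candidate_cost: float, base_cost: float) -> bool:
        """Always accept an improvement; accept a regression of size delta
        with probability exp(-delta / T), as in simulated annealing."""
        if candidate_cost < base_cost:
            return True  # greedy move, same as the pre-patch behavior
        delta = candidate_cost - base_cost
        return random.random() < math.exp(-delta / SA_TEMPERATURE)

With T = 0.02, a regression of delta = 0.02 is accepted with probability exp(-1) ≈ 0.37, while one ten times larger is accepted with probability exp(-10) ≈ 4.5e-5, so the search can escape shallow local minima without wandering far from good schedules.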