diff --git a/build.sh b/build.sh old mode 100644 new mode 100755 diff --git a/python/tvm/meta_schedule/XGBgradient.py b/python/tvm/meta_schedule/XGBgradient.py index ec2da4bcaf9b..9b8647155d83 100644 --- a/python/tvm/meta_schedule/XGBgradient.py +++ b/python/tvm/meta_schedule/XGBgradient.py @@ -398,7 +398,6 @@ def __init__( target=None, task_name=None, tmpdir=None, - warmup_iter=True, ): """ Initialize the DynamicGradientSearch object. @@ -457,7 +456,6 @@ def __init__( self.builder = Builder.create("local", max_workers=os.cpu_count(), timeout_sec=10.0) self.builder_results = [] - self.warmup_iter = warmup_iter def get_sample_records(self, number): """ @@ -917,110 +915,72 @@ def XGB_gradient_search(self): ): record, _, _ = self.dgd_search(record) else: - # We apply gradient decent with cost model. - if self.warmup_iter[0] < 1: - self.warmup_iter[0] += 1 - tuning_records = self.get_sample_records(self.num_trials_per_iter) - if not tuning_records: - logging.debug("No valid initial samples found.") - return self.database - - # Update the cost model - candidates = [record.as_measure_candidate() for record in tuning_records] - results = [ - ms.runner.RunnerResult(run_secs=record.run_secs, error_msg=None) - for record in tuning_records - ] - - self.model.update(self.context, candidates, results) - costs = np.array( - [np.mean([v.value for v in res.run_secs]) for res in tuning_records] + total_sample_num = ( + self.init_population_size + if total_sample_num > self.init_population_size + else total_sample_num + ) + + raw_states = self.sample_init_population( + self.context.search_strategy, total_sample_num + ) + candidate_states = remove_duplicates_and_measured_schedules(raw_states) + candidate_inputs = assemble_candidates(candidate_states) + # 2. 利用 cost model 对候选方案进行预测打分 + candidates_score = self.model.predict(self.context, candidate_inputs) + topk = min(self.n_start, len(candidate_inputs)) + topk_indices = np.argsort(-candidates_score)[:topk] + topk_states = [candidate_states[i] for i in topk_indices] + + # 3. 对 topk 候选进行实际的构建与测量,转换成 tuning record + builder_inputs = [ + ms.builder.BuilderInput(state.mod, self.target) for state in topk_states + ] + builder_results = self.builder.build(builder_inputs) + valid_indices = [] + runner_inputs = [] + for i, res in enumerate(builder_results): + if res.error_msg: + continue + runner_input = ms.runner.RunnerInput( + res.artifact_path, + device_type=self.target.kind.name, + args_info=ms.arg_info.ArgInfo.from_prim_func(self.task), ) + runner_inputs.append(runner_input) + valid_indices.append(i) + if not runner_inputs: + logging.debug("No valid candidates found after prediction.") + return self.database - self.measured_throughputs_.extend(1 / np.array(costs)) - - topk = min(self.n_start, len(tuning_records)) - topk_indices = np.argsort(costs)[:topk] - topk_records = [tuning_records[i] for i in topk_indices] - - # use topk as budget now, later will add more options like n trials - # budget - for record in topk_records: - self.database.commit_tuning_record(record) - while ( - record is not None - and self.count_total_measured < self.max_trials - and time.time() - self.start_time < self.max_tuning_time - ): - record, _, _ = self.dgd_search(record) + # 使用多进程并行跑测 + runner_inputs_2d = list(map(lambda x: [x], runner_inputs)) + with Pool(self.n_jobs) as pool: + run_secs_list = pool.map(parallel_runner_run, runner_inputs_2d) - else: - total_sample_num = ( - self.init_population_size - if total_sample_num > self.init_population_size - else total_sample_num - ) - - raw_states = self.sample_init_population( - self.context.search_strategy, total_sample_num + topk_records = [] + workload = ms.database.Workload(self.context.mod) + for idx, run_secs in zip(valid_indices, run_secs_list): + record = ms.database.TuningRecord( + trace=topk_states[idx].trace, + workload=workload, + run_secs=run_secs, + target=self.target, + args_info=ms.arg_info.ArgInfo.from_prim_func(self.task), ) - candidate_states = remove_duplicates_and_measured_schedules(raw_states) - candidate_inputs = assemble_candidates(candidate_states) - # 2. 利用 cost model 对候选方案进行预测打分 - candidates_score = self.model.predict(self.context, candidate_inputs) - topk = min(self.n_start, len(candidate_inputs)) - topk_indices = np.argsort(-candidates_score)[:topk] - topk_states = [candidate_states[i] for i in topk_indices] - - # 3. 对 topk 候选进行实际的构建与测量,转换成 tuning record - builder_inputs = [ - ms.builder.BuilderInput(state.mod, self.target) for state in topk_states - ] - builder_results = self.builder.build(builder_inputs) - valid_indices = [] - runner_inputs = [] - for i, res in enumerate(builder_results): - if res.error_msg: - continue - runner_input = ms.runner.RunnerInput( - res.artifact_path, - device_type=self.target.kind.name, - args_info=ms.arg_info.ArgInfo.from_prim_func(self.task), - ) - runner_inputs.append(runner_input) - valid_indices.append(i) - if not runner_inputs: - logging.debug("No valid candidates found after prediction.") - return self.database - - # 使用多进程并行跑测 - runner_inputs_2d = list(map(lambda x: [x], runner_inputs)) - with Pool(self.n_jobs) as pool: - run_secs_list = pool.map(parallel_runner_run, runner_inputs_2d) - - topk_records = [] - workload = ms.database.Workload(self.context.mod) - for idx, run_secs in zip(valid_indices, run_secs_list): - record = ms.database.TuningRecord( - trace=topk_states[idx].trace, - workload=workload, - run_secs=run_secs, - target=self.target, - args_info=ms.arg_info.ArgInfo.from_prim_func(self.task), - ) - topk_records.append(record) + topk_records.append(record) - self.builder_results.extend(builder_results) - self.count_total_measured += topk - # 4. 对每个 topk record 作为初始基点,进入动态梯度搜索流程 - for record in topk_records: - self.database.commit_tuning_record(record) - while ( - record is not None - and self.count_total_measured < self.max_trials - and time.time() - self.start_time < self.max_tuning_time - ): - record, _, _ = self.dgd_search(record) + self.builder_results.extend(builder_results) + self.count_total_measured += topk + # 4. 对每个 topk record 作为初始基点,进入动态梯度搜索流程 + for record in topk_records: + self.database.commit_tuning_record(record) + while ( + record is not None + and self.count_total_measured < self.max_trials + and time.time() - self.start_time < self.max_tuning_time + ): + record, _, _ = self.dgd_search(record) for res in self.builder_results: if res.artifact_path: