Empty file modified: build.sh (100644 → 100755)
162 changes: 61 additions & 101 deletions python/tvm/meta_schedule/XGBgradient.py
@@ -398,7 +398,6 @@ def __init__(
target=None,
task_name=None,
tmpdir=None,
warmup_iter=True,
):
"""
Initialize the DynamicGradientSearch object.
@@ -457,7 +456,6 @@ def __init__(

self.builder = Builder.create("local", max_workers=os.cpu_count(), timeout_sec=10.0)
self.builder_results = []
self.warmup_iter = warmup_iter

def get_sample_records(self, number):
"""
@@ -917,110 +915,72 @@ def XGB_gradient_search(self):
):
record, _, _ = self.dgd_search(record)
else:
# We apply gradient descent with the cost model.
if self.warmup_iter[0] < 1:
self.warmup_iter[0] += 1
tuning_records = self.get_sample_records(self.num_trials_per_iter)
if not tuning_records:
logging.debug("No valid initial samples found.")
return self.database

# Update the cost model
candidates = [record.as_measure_candidate() for record in tuning_records]
results = [
ms.runner.RunnerResult(run_secs=record.run_secs, error_msg=None)
for record in tuning_records
]

self.model.update(self.context, candidates, results)
costs = np.array(
[np.mean([v.value for v in res.run_secs]) for res in tuning_records]
total_sample_num = (
self.init_population_size
if total_sample_num > self.init_population_size
else total_sample_num
)

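# 1. Sample an initial population of candidate schedules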
raw_states = self.sample_init_population(
self.context.search_strategy, total_sample_num
)
candidate_states = remove_duplicates_and_measured_schedules(raw_states)
candidate_inputs = assemble_candidates(candidate_states)
# 2. Use the cost model to predict a score for each candidate
candidates_score = self.model.predict(self.context, candidate_inputs)
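# Keep the top-k candidates; argsort on the negated scores ranks them from highest to lowest predicted score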
topk = min(self.n_start, len(candidate_inputs))
topk_indices = np.argsort(-candidates_score)[:topk]
topk_states = [candidate_states[i] for i in topk_indices]

# 3. Build and measure the top-k candidates, then convert them into tuning records
builder_inputs = [
ms.builder.BuilderInput(state.mod, self.target) for state in topk_states
]
builder_results = self.builder.build(builder_inputs)
valid_indices = []
runner_inputs = []
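# Keep only candidates that built successfully, remembering their original indices so run results can be mapped back to the corresponding states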
for i, res in enumerate(builder_results):
if res.error_msg:
continue
runner_input = ms.runner.RunnerInput(
res.artifact_path,
device_type=self.target.kind.name,
args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
)
runner_inputs.append(runner_input)
valid_indices.append(i)
if not runner_inputs:
logging.debug("No valid candidates found after prediction.")
return self.database

self.measured_throughputs_.extend(1 / np.array(costs))

topk = min(self.n_start, len(tuning_records))
topk_indices = np.argsort(costs)[:topk]
topk_records = [tuning_records[i] for i in topk_indices]

# Use top-k as the budget for now; more options (e.g., an explicit number of trials) may be added later
for record in topk_records:
self.database.commit_tuning_record(record)
while (
record is not None
and self.count_total_measured < self.max_trials
and time.time() - self.start_time < self.max_tuning_time
):
record, _, _ = self.dgd_search(record)
# Run measurements in parallel across multiple processes
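# Each RunnerInput is wrapped in a single-element list so every parallel_runner_run call receives a batch of one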
runner_inputs_2d = list(map(lambda x: [x], runner_inputs))
with Pool(self.n_jobs) as pool:
run_secs_list = pool.map(parallel_runner_run, runner_inputs_2d)

else:
total_sample_num = (
self.init_population_size
if total_sample_num > self.init_population_size
else total_sample_num
)
raw_states = self.sample_init_population(
self.context.search_strategy, total_sample_num
topk_records = []
workload = ms.database.Workload(self.context.mod)
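# Convert each successful measurement back into a TuningRecord for this workload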
for idx, run_secs in zip(valid_indices, run_secs_list):
record = ms.database.TuningRecord(
trace=topk_states[idx].trace,
workload=workload,
run_secs=run_secs,
target=self.target,
args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
)
candidate_states = remove_duplicates_and_measured_schedules(raw_states)
candidate_inputs = assemble_candidates(candidate_states)
# 2. Use the cost model to predict a score for each candidate
candidates_score = self.model.predict(self.context, candidate_inputs)
topk = min(self.n_start, len(candidate_inputs))
topk_indices = np.argsort(-candidates_score)[:topk]
topk_states = [candidate_states[i] for i in topk_indices]

# 3. Build and measure the top-k candidates, then convert them into tuning records
builder_inputs = [
ms.builder.BuilderInput(state.mod, self.target) for state in topk_states
]
builder_results = self.builder.build(builder_inputs)
valid_indices = []
runner_inputs = []
for i, res in enumerate(builder_results):
if res.error_msg:
continue
runner_input = ms.runner.RunnerInput(
res.artifact_path,
device_type=self.target.kind.name,
args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
)
runner_inputs.append(runner_input)
valid_indices.append(i)
if not runner_inputs:
logging.debug("No valid candidates found after prediction.")
return self.database

# Run measurements in parallel across multiple processes
runner_inputs_2d = list(map(lambda x: [x], runner_inputs))
with Pool(self.n_jobs) as pool:
run_secs_list = pool.map(parallel_runner_run, runner_inputs_2d)

topk_records = []
workload = ms.database.Workload(self.context.mod)
for idx, run_secs in zip(valid_indices, run_secs_list):
record = ms.database.TuningRecord(
trace=topk_states[idx].trace,
workload=workload,
run_secs=run_secs,
target=self.target,
args_info=ms.arg_info.ArgInfo.from_prim_func(self.task),
)
topk_records.append(record)
topk_records.append(record)

self.builder_results.extend(builder_results)
self.count_total_measured += topk
# 4. Use each top-k record as a starting point and enter the dynamic gradient search loop
for record in topk_records:
self.database.commit_tuning_record(record)
while (
record is not None
and self.count_total_measured < self.max_trials
and time.time() - self.start_time < self.max_tuning_time
):
record, _, _ = self.dgd_search(record)
self.builder_results.extend(builder_results)
self.count_total_measured += topk
# 4. Use each top-k record as a starting point and enter the dynamic gradient search loop
for record in topk_records:
self.database.commit_tuning_record(record)
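# Follow the gradient search from this record until the trial or time budget is exhausted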
while (
record is not None
and self.count_total_measured < self.max_trials
and time.time() - self.start_time < self.max_tuning_time
):
record, _, _ = self.dgd_search(record)

for res in self.builder_results:
if res.artifact_path: