diff --git a/python/tvm/meta_schedule/XGBgradient.py b/python/tvm/meta_schedule/XGBgradient.py index 9b8647155d83..66fe0b8cfbf6 100644 --- a/python/tvm/meta_schedule/XGBgradient.py +++ b/python/tvm/meta_schedule/XGBgradient.py @@ -49,18 +49,20 @@ def get_index(array: list, value: int): return index_map.get(value, -1) -def get_factors(n): +def get_factors(n, alignment=16): """ - Return the factors of a given number n as a sorted list. + Return the factors align to the aligmentof a given number n as a sorted list. """ factors = [] sqrt_n = isqrt(n) for i in range(1, sqrt_n + 1): if n % i == 0: - factors.append(i) + if i == 1 0r i % alignment == 0: + factors.append(i) j = n // i if j != i: - factors.append(j) + if j == 1 or j % alignment == 0: + factors.append(j) factors.sort() return factors @@ -456,6 +458,7 @@ def __init__( self.builder = Builder.create("local", max_workers=os.cpu_count(), timeout_sec=10.0) self.builder_results = [] + self.alignment = 16 def get_sample_records(self, number): """ @@ -814,7 +817,7 @@ def _apply_changes(self, new_coordinates, indices, changes, record_trace, sample tile_config = tile_configs[tile_idx][1] length = len(tile_config) - 1 dim_len = prod(tile_config) - factors = get_factors(dim_len) + factors = get_factors(dim_len, alignment=self.alignment) if not self._update_coordinates( new_coordinates,