diff --git a/README.md b/README.md index f119403..b3a42ba 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,8 @@ For data ordering, we devise **Folding Ordering (FO)** method, which addresses i ## 📢 News and Updates Done +- [x] 2026/02/28: 💥 The **Data Ordering** module is officially integrated into DELT, supporting various data organization strategies including **Folding**, **Shuffle**, **Sorting**, **Zig-zag**, **Segment**, **Stair**, and **Saw Ordering**. +- [x] 2026/01/05: 💥Our paper **"Demystifying Data Organization for Enhanced LLM Training"** was submitted to ACL ARR January 2026. - [x] 2025/06/28: 💥The [Arxiv paper](https://arxiv.org/abs/2506.21545) released. - [x] 2025/08/31: 💥The DELT code released for pre-training on general domain. @@ -137,7 +139,7 @@ bash data_selection/entry.sh $INPUT_DATA_PATH $OUTPUT_DATA_PATH $METHOD $CONFIG_
Data Ordering -Existing ordering method: **Folding Ordering (FO)** (`folding`), Shuffle (`shuffle`), and Sorting (`sorting`). +Existing ordering method: **Folding Ordering (FO)** (`folding`), Shuffle (`shuffle`), Sorting (`sorting`), Zig-zag Ordering(`zigzag`), Segment Ordering(`segment`), Stair Ordering(`str`), Saw Ordering(`saw`). ```bash bash data_ordering/entry.sh $INPUT_DATA_PATH $OUTPUT_DATA_PATH $METHOD $CONFIG_PATH diff --git a/data_ordering/config/folding.yaml b/data_ordering/config/folding.yaml index dd353d9..cb8c559 100644 --- a/data_ordering/config/folding.yaml +++ b/data_ordering/config/folding.yaml @@ -4,3 +4,5 @@ description: Config of folding method in data ordering. score_field: score folding_layer: 3 +window_size: 100000 +seed: 42 diff --git a/data_ordering/config/saw.yaml b/data_ordering/config/saw.yaml new file mode 100644 index 0000000..e45e860 --- /dev/null +++ b/data_ordering/config/saw.yaml @@ -0,0 +1,13 @@ +name: section +version: 1.0 +description: Config of section (hybrid sorting/folding) method in data ordering. + + +score_field: score +ascending: True +reverse_even_layers: True +folding_layer: 3 +num_sections: 3 +folding_ratio: 0.1 +window_size: 100000 +seed: 42 \ No newline at end of file diff --git a/data_ordering/config/segment.yaml b/data_ordering/config/segment.yaml new file mode 100644 index 0000000..a8b7016 --- /dev/null +++ b/data_ordering/config/segment.yaml @@ -0,0 +1,13 @@ +name: segment +version: 1.0 +description: Config of segment-based ordering method. + + +score_field: score + +x_pct: 30 +y_pct: 30 +front_is_high: false +back_is_high: true + +seed: 42 \ No newline at end of file diff --git a/data_ordering/config/shuffle.yaml b/data_ordering/config/shuffle.yaml index e01b190..e15d73d 100644 --- a/data_ordering/config/shuffle.yaml +++ b/data_ordering/config/shuffle.yaml @@ -3,4 +3,4 @@ version: 1.0 description: Config of shuffle method in data ordering. score_field: score -seed: 10 +seed: 42 diff --git a/data_ordering/config/sorting.yaml b/data_ordering/config/sorting.yaml index 03d903f..c8a8042 100644 --- a/data_ordering/config/sorting.yaml +++ b/data_ordering/config/sorting.yaml @@ -1,6 +1,10 @@ name: sorting -version: 1.0 -description: Config of sorting method in data ordering. +version: 1.1 +description: Config of sorting method with local window shuffling. score_field: score ascending: true +use_gumbel: true +temperature: 0 +window_size: 100000 +seed: 42 diff --git a/data_ordering/config/str.yaml b/data_ordering/config/str.yaml new file mode 100644 index 0000000..d7938e1 --- /dev/null +++ b/data_ordering/config/str.yaml @@ -0,0 +1,13 @@ +name: section +version: 1.0 +description: Config of section (hybrid sorting/folding) method in data ordering. + + +score_field: score +ascending: True +reverse_even_layers: False +folding_layer: 2 +num_sections: 3 +folding_ratio: 0.1 +window_size: 100000 +seed: 42 \ No newline at end of file diff --git a/data_ordering/config/zigzag.yaml b/data_ordering/config/zigzag.yaml new file mode 100644 index 0000000..3aaec7d --- /dev/null +++ b/data_ordering/config/zigzag.yaml @@ -0,0 +1,10 @@ +name: zigzag +version: 2.0 +description: Config of zigzag method in data ordering. + +score_field: score +zigzag_layer: 2 +use_gumbel: true +temperature: 1000 +seed: 42 +window_size: 100000 \ No newline at end of file diff --git a/data_ordering/entry.py b/data_ordering/entry.py index 5db62a7..047ec36 100644 --- a/data_ordering/entry.py +++ b/data_ordering/entry.py @@ -5,6 +5,10 @@ import shuffle import sorting import folding +import zigzag +import segment +import str +import saw from utils import load_yaml, load_jsonl, add_args, write_jsonl @@ -12,8 +16,8 @@ parser = argparse.ArgumentParser(description="Data ordering.") parser.add_argument("--input_data_path", type=str, help="Path to the input .jsonl file.") parser.add_argument("--output_data_path", type=str, help="Path to the output .jsonl file.") - parser.add_argument("--method", type=str, choices=["shuffle", "sorting", "folding"], default="folding", - help="Ordering method: 'shuffle', 'sorting', and 'folding'. Defaults to 'folding'.") + parser.add_argument("--method", type=str, choices=["shuffle", "sorting", "folding", "zigzag", "segment", "str", "saw"], default="folding", + help="Ordering method: 'shuffle', 'sorting', and 'folding','zigzag','segment','str','saw'. Defaults to 'folding'.") parser.add_argument("--config_path", type=str, default="./config/folding.yaml", help="Config file for additional parameters (YAML format).") args = parser.parse_args() @@ -33,10 +37,46 @@ if args.method == "sorting": out_data = sorting.order(in_data, args) print(f" Ascending: {args.ascending}") + print(f" Temperature: {args.temperature}") + print(f" Use gumbel: {args.use_gumbel}") + print(f" Window size: {args.window_size}") if args.method == "folding": out_data = folding.order(in_data, args) print(f" Folding layer: {args.folding_layer}") + print(f" Window size: {args.window_size}") + + if args.method == "zigzag": + out_data = zigzag.order(in_data, args) + print(f" Zigzag layer: {args.zigzag_layer}") + print(f" Temperature: {args.temperature}") + print(f" Use gumbel: {args.use_gumbel}") + print(f" Window size: {args.window_size}") + + if args.method == "segment": + out_data = segment.order(in_data, args) + print(f" Front percentage: {args.x_pct}%") + print(f" Back percentage: {args.y_pct}%") + print(f" Front is high: {args.front_is_high}") + print(f" Back is high: {args.back_is_high}") + if hasattr(args, 'seed'): + print(f" Random seed: {args.seed}") + + if args.method == "str": + out_data = str.order(in_data, args) + print(f" Global Ascending: {args.ascending}") + print(f" Num sections: {args.num_sections}") + print(f" Folding ratio: {args.folding_ratio}") + print(f" Folding layer (in section): {args.folding_layer}") + print(f" Window size: {args.window_size}") + + if args.method == "saw": + out_data = saw.order(in_data, args) + print(f" Global Ascending: {args.ascending}") + print(f" Num sections: {args.num_sections}") + print(f" Folding ratio: {args.folding_ratio}") + print(f" Folding layer (in section): {args.folding_layer}") + print(f" Window size: {args.window_size}") write_jsonl(args.output_data_path, out_data) diff --git a/data_ordering/folding.py b/data_ordering/folding.py index c28afed..ba56bad 100644 --- a/data_ordering/folding.py +++ b/data_ordering/folding.py @@ -1,12 +1,62 @@ +import numpy as np + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data + + def order(in_data, args): + """ + Folding Ordering:将输入数据按分数进行升序排列,排序后的序列按 folding_layer 进行取模分桶,依次提取每个桶中的元素并拼接,实现分数的跳跃式分布。 + 最后可选执行局部窗口打乱。 + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典 + args: 包含配置参数的对象 + - score_field: 分数字段名 + - folding_layer:折叠层数 + - window_size: 局部打乱窗口大小 (可选) + - seed: 随机种子 (可选) + + Returns: + list: 重排序后的数据列表 + """ score_field = args.score_field layers = args.folding_layer - # folding order. + window_size = getattr(args, "window_size", 0) + seed = getattr(args, "seed", 42) + sorted_data = sorted(in_data, key=lambda x: x[score_field], reverse=False) - + out_data = list() for l in range(layers): sub_data = [sorted_data[i] for i in range(len(sorted_data)) if i % layers == l] out_data.extend(sub_data) + + if window_size > 1: + out_data = window_based_shuffle(out_data, window_size, seed) + return out_data diff --git a/data_ordering/saw.py b/data_ordering/saw.py new file mode 100644 index 0000000..422a2ea --- /dev/null +++ b/data_ordering/saw.py @@ -0,0 +1,139 @@ +import numpy as np + + +def _apply_interleave_fold(data_segment, score_field, layers, reverse_even_layers=False): # <--- MODIFICATION: 增加新参数 + """ + 应用 "Folding" 逻辑 + + Args: + data_segment (list): 要处理的数据片段 + score_field (str): 分数字段 + layers (int): 交叉的层数 + reverse_even_layers (bool): 是否翻转偶数层(实现 Zigzag) + """ + if not data_segment: + return [] + + sorted_data = sorted(data_segment, key=lambda x: x[score_field], reverse=False) + + out_data = list() + for l in range(layers): + + sub_data = [sorted_data[i] for i in range(len(sorted_data)) if i % layers == l] + if reverse_even_layers and (l % 2 != 0): + sub_data.reverse() + out_data.extend(sub_data) + + return out_data + + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data + + +def order(in_data, args): + """ + Saw Ordering:先全局排序,然后在 K-1 个分割点应用局部折叠并反转奇数层,最后局部窗口打乱(可选) + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典。 + args: 包含配置参数的对象。 + - score_field: 分数字段名 + - ascending: 是否升序 + - reverse_even_layers:是否翻转偶数层(参数默认True) + - folding_layer: 局部折叠参数 (来自 'folding') + - num_section: 数据被分成的总折数(参数默认为3) + - folding_ratio: 在分割点处,向上和向下各取多少比例的数据进行折叠 (例如 0.10 表示各 10%) + - window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + - seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + + score_field = args.score_field + ascending = args.ascending + num_sections = args.num_sections + folding_ratio = args.folding_ratio + interleave_layers = args.folding_layer + reverse_even_layers = getattr(args, 'reverse_even_layers', False) + + window_size = getattr(args, "window_size", 0) + seed = getattr(args, "seed", 42) + + if ascending: + sorted_data = sorted(in_data, key=lambda x: x[score_field], reverse=False) + else: + sorted_data = sorted(in_data, key=lambda x: x[score_field], reverse=True) + N = len(sorted_data) + + if N == 0: + return sorted_data + + if num_sections > 1: + split_indices = [int(round(N * i / num_sections)) for i in range(1, num_sections)] + radius_items = int(round(N * folding_ratio)) + segments = [] + current_index = 0 + + for sp_index in split_indices: + fold_start = max(0, sp_index - radius_items) + fold_end = min(N, sp_index + radius_items) + fold_start = max(fold_start, current_index) + fold_end = max(fold_end, fold_start) + if fold_start > current_index: + segments.append((current_index, fold_start, 'stable')) + if fold_end > fold_start: + segments.append((fold_start, fold_end, 'fold')) + current_index = fold_end + + if current_index < N: + segments.append((current_index, N, 'stable')) + + out_data = list() + for start, end, segment_type in segments: + data_segment = sorted_data[start:end] + + if not data_segment: + continue + + if segment_type == 'stable': + out_data.extend(data_segment) + else: + folded_segment = _apply_interleave_fold( + data_segment, + score_field, + interleave_layers, + reverse_even_layers + ) + out_data.extend(folded_segment) + else: + out_data = sorted_data + + if window_size > 1: + out_data = window_based_shuffle(out_data, window_size, seed) + + return out_data \ No newline at end of file diff --git a/data_ordering/segment.py b/data_ordering/segment.py new file mode 100644 index 0000000..681afcc --- /dev/null +++ b/data_ordering/segment.py @@ -0,0 +1,118 @@ +import random +import numpy as np +import warnings + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data +z +def order(in_data, args): + """ + Segment Ordering:按分数重排序数据,将数据分为前、中、后三段,并分别打乱 + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典 + args: 包含配置参数的对象 + - score_field: 分数字段名 + - x_pct: 前段百分比 (0-100) + - y_pct: 后段百分比 (0-100) + - front_is_high: 前段是否取高分样本 + - back_is_high: 后段是否取高分样本 + - seed: 随机种子(可选) + + Returns: + list: 重排序后的数据列表 + """ + score_field = args.score_field + total_samples = len(in_data) + + if hasattr(args, 'seed'): + random.seed(args.seed) + np.random.seed(args.seed) + + sorted_data = sorted( + enumerate(in_data), + key=lambda x: (x[1][score_field], x[0]), + reverse=False + ) + sorted_data = [item[1] for item in sorted_data] + + n_front = int(np.floor(args.x_pct / 100 * total_samples)) + n_back = int(np.floor(args.y_pct / 100 * total_samples)) + + total_selected = n_front + n_back + if total_selected > total_samples: + + ratio = total_samples / total_selected + n_front = int(np.floor(n_front * ratio)) + n_back = int(np.floor(n_back * ratio)) + warnings.warn( + f"前段({args.x_pct}%)和后段({args.y_pct}%)的总和超过100%! " + f"已按比例缩减为前段{n_front}个、后段{n_back}个样本。" + ) + total_selected = n_front + n_back + + front = [] + back = [] + middle = [] + + if args.front_is_high == args.back_is_high: + + if args.front_is_high: + selected = sorted_data[-total_selected:] if total_selected > 0 else [] + middle = sorted_data[:-total_selected] if total_selected > 0 else sorted_data + else: + selected = sorted_data[:total_selected] if total_selected > 0 else [] + middle = sorted_data[total_selected:] if total_selected > 0 else sorted_data + + if selected: + random.shuffle(selected) + + front = selected[:n_front] + back = selected[n_front:total_selected] + else: + + if args.front_is_high: + front = sorted_data[-n_front:] if n_front > 0 else [] + remaining = sorted_data[:-n_front] if n_front > 0 else sorted_data + else: + front = sorted_data[:n_front] if n_front > 0 else [] + remaining = sorted_data[n_front:] if n_front > 0 else sorted_data + + if args.back_is_high: + back = remaining[-n_back:] if n_back > 0 else [] + middle = remaining[:-n_back] if n_back > 0 else remaining + else: + back = remaining[:n_back] if n_back > 0 else [] + middle = remaining[n_back:] if n_back > 0 else remaining + + + random.shuffle(front) + random.shuffle(middle) + random.shuffle(back) + out_data = front + middle + back + + return out_data \ No newline at end of file diff --git a/data_ordering/shuffle.py b/data_ordering/shuffle.py index c0bc5a3..53599c9 100644 --- a/data_ordering/shuffle.py +++ b/data_ordering/shuffle.py @@ -1,6 +1,18 @@ import random def order(in_data, args): + """ + 随机排序 + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典 + args: 包含配置参数的对象 + - score_field: 分数字段名 + - seed:随机种子(可选) + + Returns: + list: 重排序后的数据列表 + """ random.seed(args.seed) out_data = random.sample(in_data, len(in_data)) return out_data diff --git a/data_ordering/sorting.py b/data_ordering/sorting.py index 18b3142..d229a72 100644 --- a/data_ordering/sorting.py +++ b/data_ordering/sorting.py @@ -1,11 +1,103 @@ +import numpy as np + + +def gumbel_indices_sort(sorted_indices, tau=1.0, use_gumbel=False, seed=42, ascending=True): + """ + 对已经按真实分数排序后的索引序列进行“排名位置”的 Gumbel 扰动,再据此重排索引。 + + Args: + sorted_indices (list[int]): 已由真实分数排序得到的索引列表 + tau (float): 温度系数(控制扰动强度,越大越随机) + use_gumbel (bool): 是否启用 Gumbel 扰动 + seed (int): 随机种子 + ascending (bool): 排序方向,True=升序,False=降序 + + Returns: + list[int]: 加入 Gumbel 扰动后的新索引序列 + """ + idx = np.array(sorted_indices) + n = idx.shape[0] + + np.random.seed(seed) + if use_gumbel: + gumbel_noise = -np.log(-np.log(np.random.rand(n))) + perturbed_positions = np.arange(n, dtype=float) + gumbel_noise * tau + else: + perturbed_positions = np.arange(n, dtype=float) + + if ascending: + new_order = np.argsort(perturbed_positions) + else: + new_order = np.argsort(-perturbed_positions) + + return list(idx[new_order]) + + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data + def order(in_data, args): + """ + Sorting Ordering:先按真实分数确定性进行sorting排序 再在索引顺序上加入 Gumbel 噪声(可选),局部窗口打乱(可选) + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典。 + args: 包含配置参数的对象。 + - score_field: 分数字段名 + - ascending: 决定是否升序 + - use_gumbel (bool): 是否启用 Gumbel 扰动。False 时保持输入顺序不变 + - temperature: 温度系数(控制扰动强度,越大越随机) + - window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + - seed (int): 随机种子 + + Returns: + list: 重排序后的数据列表 + """ score_field = args.score_field + ascending = args.ascending + + + tau = getattr(args, "temperature", 1.0) + use_gumbel = getattr(args, "use_gumbel", False) + seed = getattr(args, "seed", 42) + window_size = getattr(args, "window_size", 0) + + + scores = np.array([item[score_field] for item in in_data]) + base_sorted_indices = list(np.argsort(scores)) + if not ascending: + base_sorted_indices = base_sorted_indices[::-1] + sorted_indices = gumbel_indices_sort(base_sorted_indices, tau, use_gumbel, seed, ascending) + sorted_data = [in_data[i] for i in sorted_indices] + - if args.ascending: - # ascending order. - out_data = sorted(in_data, key=lambda x: x[score_field], reverse=False) + if window_size > 1: + final_data = window_based_shuffle(sorted_data, window_size, seed) else: - # descending order. - out_data = sorted(in_data, key=lambda x: x[score_field], reverse=True) + final_data = sorted_data - return out_data + return final_data diff --git a/data_ordering/str.py b/data_ordering/str.py new file mode 100644 index 0000000..c3a353f --- /dev/null +++ b/data_ordering/str.py @@ -0,0 +1,139 @@ +import numpy as np +def _apply_interleave_fold(data_segment, score_field, layers, reverse_even_layers=False): # <--- MODIFICATION: 增加新参数 + """ + 应用 "Folding" 逻辑 + + Args: + data_segment (list): 要处理的数据片段 + score_field (str): 分数字段 + layers (int): 交叉的层数 + reverse_even_layers (bool): 是否翻转偶数层(实现 Zigzag) + """ + if not data_segment: + return [] + + + sorted_data = sorted(data_segment, key=lambda x: x[score_field], reverse=False) + + out_data = list() + for l in range(layers): + + sub_data = [sorted_data[i] for i in range(len(sorted_data)) if i % layers == l] + if reverse_even_layers and (l % 2 != 0): + sub_data.reverse() + out_data.extend(sub_data) + + return out_data + + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data + +def order(in_data, args): + """ + Stair Ordering:先全局排序,然后在 K-1 个分割点应用局部折叠,最后局部窗口打乱(可选) + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典。 + args: 包含配置参数的对象。 + - score_field: 分数字段名 + - ascending: 是否升序 + - reverse_even_layers:是否翻转偶数层(参数默认False) + - folding_layer: 局部折叠参数 (来自 'folding') + - num_section: 数据被分成的总折数(参数默认为2) + - folding_ratio: 在分割点处,向上和向下各取多少比例的数据进行折叠 (例如 0.10 表示各 10%) + - window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + - seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + + + score_field = args.score_field + ascending = args.ascending + num_sections = args.num_sections + folding_ratio = args.folding_ratio + interleave_layers = args.folding_layer + reverse_even_layers = getattr(args, 'reverse_even_layers', False) + + window_size = getattr(args, "window_size", 0) + seed = getattr(args, "seed", 42) + + if ascending: + sorted_data = sorted(in_data, key=lambda x: x[score_field], reverse=False) + else: + sorted_data = sorted(in_data, key=lambda x: x[score_field], reverse=True) + N = len(sorted_data) + + if N == 0: + return sorted_data + + if num_sections > 1: + split_indices = [int(round(N * i / num_sections)) for i in range(1, num_sections)] + radius_items = int(round(N * folding_ratio)) + segments = [] + current_index = 0 + + for sp_index in split_indices: + fold_start = max(0, sp_index - radius_items) + fold_end = min(N, sp_index + radius_items) + fold_start = max(fold_start, current_index) + fold_end = max(fold_end, fold_start) + if fold_start > current_index: + segments.append((current_index, fold_start, 'stable')) + if fold_end > fold_start: + segments.append((fold_start, fold_end, 'fold')) + current_index = fold_end + + if current_index < N: + segments.append((current_index, N, 'stable')) + + out_data = list() + for start, end, segment_type in segments: + data_segment = sorted_data[start:end] + + if not data_segment: + continue + + if segment_type == 'stable': + out_data.extend(data_segment) + else: + folded_segment = _apply_interleave_fold( + data_segment, + score_field, + interleave_layers, + reverse_even_layers + ) + out_data.extend(folded_segment) + else: + out_data = sorted_data + + if window_size > 1: + out_data = window_based_shuffle(out_data, window_size, seed) + + return out_data \ No newline at end of file diff --git a/data_ordering/zigzag.py b/data_ordering/zigzag.py new file mode 100644 index 0000000..1054ba7 --- /dev/null +++ b/data_ordering/zigzag.py @@ -0,0 +1,100 @@ +import math +import numpy as np + +def gumbel_indices_sort(sorted_indices, tau=1.0, use_gumbel=False, seed=42): + """ + 对已经按真实分数排序后的索引序列进行“排名位置”的 Gumbel 扰动,再据此重排索引 + 这消除了分数分布对随机性的影响,随机性仅由 tau 决定 + + Args: + sorted_indices (Sequence[int]): 已由真实分数升序排序得到的索引列表 + tau (float): 温度系数,控制扰动强度。越大越随机 + use_gumbel (bool): 是否启用 Gumbel 扰动。False 时保持输入顺序不变 + seed (int): 随机种子,保证可复现 + + Returns: + List[int]: 经过 Gumbel 扰动后的新索引序列 + """ + idx = np.array(sorted_indices) + n = idx.shape[0] + np.random.seed(seed) + if use_gumbel: + gumbel_noise = -np.log(-np.log(np.random.rand(n))) + perturbed_positions = np.arange(n, dtype=float) + gumbel_noise * tau + else: + perturbed_positions = np.arange(n, dtype=float) + + new_order = np.argsort(perturbed_positions) + return list(idx[new_order]) + +def window_based_shuffle(data, window_size, seed=42): + """ + Jittering Ordering:对列表进行局部窗口内的随机打乱,整体有序,局部无序 + + Args: + data:输入数据列表 + window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + seed:随机种子 + + Returns: + list: 重排序后的数据列表 + """ + if window_size <= 1: + return data + + n = len(data) + rng = np.random.RandomState(seed) + shuffled_final_data = [] + + for i in range(0, n, window_size): + + chunk = data[i: i + window_size] + rng.shuffle(chunk) + shuffled_final_data.extend(chunk) + + return shuffled_final_data + +def order(in_data, args): + """ + Zigzag Ordering:先按真实分数进行确定性排序得到 sorted_indices,再用 gumbel_indices_sort + 对 indices 的“排名位置”加 Gumbel 噪声(可选),局部窗口打乱(可选),最后执行Zigzag折叠逻辑 + + Args: + in_data (list): 输入数据列表,每个元素为带有分数的字典 + args: 包含配置参数的对象 + - score_field: 分数字段名 + - zigzag_layer: zigzag折叠层数 + - use_gumble: 是否启用gumble扰动 + - temperature: 温度系数,控制扰动强度。越大越随机 + - seed: 随机种子(可选) + - window_size:局部打乱窗口大小,如果为 0 或 1,则不进行局部打乱 + + Returns: + list[int]: 经过 Gumbel 扰动后的新索引序列 + """ + score_field = args.score_field + zigzag_layer = args.zigzag_layer + tau = args.temperature + use_gumbel = args.use_gumbel + seed = args.seed + window_size = getattr(args, "window_size", 0) + + scores = np.array([item[score_field] for item in in_data]) + base_sorted_indices = list(np.argsort(scores)) + sorted_indices = gumbel_indices_sort(base_sorted_indices, tau, use_gumbel, seed) + + r = 2 + out_data = [] + n = len(sorted_indices) + for l in range(zigzag_layer): + sub_data = [in_data[sorted_indices[i]] for i in range(n) if i % zigzag_layer == l] + sub_data_drop = [sub_data[i] for i in range(len(sub_data)) if i % r == 0][::-1] + sub_data_rise = [sub_data[i] for i in range(len(sub_data)) if i % r != 0] + out_data.extend(sub_data_drop) + out_data.extend(sub_data_rise) + + if window_size > 1: + out_data = window_based_shuffle(out_data, window_size, seed) + + return out_data +