diff --git a/docs/user-guide/data-preparation.md b/docs/user-guide/data-preparation.md index 18da2d80fe1..45000390ff0 100644 --- a/docs/user-guide/data-preparation.md +++ b/docs/user-guide/data-preparation.md @@ -46,6 +46,48 @@ python tools/preprocess_data.py \ | `--workers` | Number of parallel workers for processing | | `--append-eod` | Add end-of-document token | +## Finding Optimal Number of Workers + +Use the `--find-optimal-num-workers` flag to find the number of workers that gives the best performance in terms of preprocessed documents per second. +The script will launch a few short data preprocessing runs with different numbers of workers to determine the fastest run with respect to the collected performance data. + +```bash +python tools/preprocess_data.py \ + --input data.jsonl \ + --output-prefix processed_data \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model /path/to/tokenizer.model \ + --workers 8 \ + --find-optimal-num-workers \ + --workers-to-check 4 8 16 32 \ + --performance-dir /path/to/save/perf/results \ + --max-documents 50000 +``` + +**Required arguments** + +| Argument | Description | +|----------|-------------| +| `--find-optimal-num-workers` | Activates the search for the optimal number of workers | +| `--workers-to-check` | List of worker counts to try | +| `--performance-dir` | Directory where performance results are saved | +| `--max-documents` | Number of documents to be preprocessed during each run | + +**Output example** + +```bash +----------------------------------- +Performance results (fastest → slowest): +1. 16 workers → avg. docs/s: 9606.6476 +2. 32 workers → avg. docs/s: 9275.3284 +3. 8 workers → avg. docs/s: 9151.9280 +4. 4 workers → avg. docs/s: 6391.3819 + +----------------------------------- +The most optimal num of workers is 16 with avg. preprocessed docs/s: 9606.6476. 
+----------------------------------- +``` + ## Output Files The preprocessing tool generates two files: diff --git a/tests/unit_tests/data/test_preprocess_data.py b/tests/unit_tests/data/test_preprocess_data.py index e6922ec3748..618d873a079 100644 --- a/tests/unit_tests/data/test_preprocess_data.py +++ b/tests/unit_tests/data/test_preprocess_data.py @@ -2,6 +2,7 @@ import json import os +import runpy import sys import tempfile @@ -201,6 +202,42 @@ def test_preprocess_data_gpt(): do_test_preprocess_data(temp_dir, extra_args=gpt_args) +def test_preprocess_data_gpt_optimal_workers(): + with tempfile.TemporaryDirectory() as temp_dir: + + # gpt specific args + gpt_args = [ + "--input", + "/opt/data/datasets/dclm/dclm.jsonl", + "--output-prefix", + f"{temp_dir}/optimal_workers", + "--tokenizer-type", + "GPT2BPETokenizer", + "--vocab-file", + "/opt/data/tokenizers/megatron/gpt2-vocab.json", + "--merge-file", + "/opt/data/tokenizers/megatron/gpt2-merges.txt", + "--append-eod", + "--workers", + "2", + "--log-interval", + "1", + "--find-optimal-num-workers", + "--workers-to-check", + "2", + "4", + "8", + "--performance-dir", + f"{temp_dir}/perf", + "--max-documents", + "1002", + ] + sys.argv = ["/opt/megatron-lm/tools/preprocess_data.py"] + gpt_args + runpy.run_path("/opt/megatron-lm/tools/preprocess_data.py", run_name="__main__") + + assert os.path.exists(f"{temp_dir}/perf") + + def bert_vocab(odir): if os.path.exists(__LOCAL_BERT_VOCAB): return __LOCAL_BERT_VOCAB @@ -237,3 +274,4 @@ def test_preprocess_data_bert(): if __name__ == "__main__": test_preprocess_data_gpt() test_preprocess_data_bert() + test_preprocess_data_gpt_optimal_workers() diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index c1f19f6be31..27ba3429116 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -11,9 +11,8 @@ import time import gzip import glob -import torch -import numpy as np import multiprocessing +import numpy as np try: import nltk from nltk.tokenize.punkt 
import PunktLanguageVars @@ -113,6 +112,7 @@ class Partition(object): def __init__(self, args, workers): self.args = args self.workers = workers + self.performance = {self.workers: []} def print_processing_stats(self, count, proc_start, total_bytes_processed): if count % self.args.log_interval == 0: @@ -122,6 +122,7 @@ def print_processing_stats(self, count, proc_start, total_bytes_processed): print(f"Processed {count} documents", f"({count/elapsed} docs/s, {mbs} MB/s).", file=sys.stderr) + self.performance[self.workers].append(count/elapsed) def split_sentences(self, file_name): input_file_name, output_file_name = file_name @@ -180,10 +181,19 @@ def process_json_file(self, file_name): total_bytes_processed = 0 print("Time to startup:", startup_end - startup_start) for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1): - total_bytes_processed += bytes_processed - for key in doc.keys(): - builders[key].add_document(doc[key], sentence_lens[key]) - self.print_processing_stats(i, proc_start, total_bytes_processed) + if self.args.find_optimal_num_workers and i > self.args.max_documents: + break + else: + total_bytes_processed += bytes_processed + for key in doc.keys(): + builders[key].add_document(doc[key], sentence_lens[key]) + self.print_processing_stats(i, proc_start, total_bytes_processed) + + # Save performance data (preprocessed docs/s) + if self.args.find_optimal_num_workers: + perf_file_path = os.path.join(self.args.performance_dir, f"{self.workers}_workers.json") + with open(perf_file_path, "w") as perf_file: + json.dump(self.performance, perf_file) fin.close() builders[key].finalize(output_idx_files[key]) @@ -216,6 +226,22 @@ def get_args(): help=('Number of worker processes to launch.' 'A good default for fast pre-processing ' 'is: (workers * partitions) = available CPU cores.')) + group.add_argument('--find-optimal-num-workers', action='store_true', + help=('Find optimal number of workers.' 
+ 'Script will run few small jobs with ' + 'different number of workers to define ' + 'optimal number of workers in terms of performance.')) + group.add_argument('--workers-to-check', nargs='+', type=int, default=[16, 32, 64], + help=('list of workers to run data processing with ' + 'to find optimal number of workers. ' + 'Works only when --find-optimal-num-workers is enabled. ')) + group.add_argument('--max-documents', type=int, default=100_000, + help=('Maximum number of documents to preprocess ' + 'to find optimal number of workers.' + 'Works only when --find-optimal-num-workers is enabled.')) + group.add_argument('--performance-dir', type=str, default=None, + help=('Path where to save performance results. ' + 'Works only when --find-optimal-num-workers is enabled.')) group.add_argument('--partitions', type=int, default=1, help='Number of file partitions') group.add_argument('--log-interval', type=int, default=1000, @@ -257,91 +283,159 @@ def check_files_exist(in_ss_out_names, key, num_partitions): return True +def find_optimal_num_workers(args): + """Parses saved .json files with perf. numbers and prints optimal number of workers""" + results = [] + + for filename in os.listdir(args.performance_dir): + if not filename.endswith(".json"): + continue + + filepath = os.path.join(args.performance_dir, filename) + + with open(filepath, "r") as f: + data = json.load(f) + + # each file assumed to contain a single {workers: [perf_list]} + for workers, perf_list in data.items(): + workers = int(workers) + avg_perf = np.mean(perf_list) + results.append((workers, avg_perf)) + + # sort by average performance (descending: fastest first) + results.sort(key=lambda x: x[1], reverse=True) + + print("\n-----------------------------------") + print("Performance results (fastest → slowest):") + for i, (workers, avg_perf) in enumerate(results): + print(f"{i+1}. {workers * args.partitions} workers → avg. 
docs/s: {avg_perf:.4f}") + + best_workers, best_perf = results[0] + + print("\n-----------------------------------") + print( + f"The most optimal num of workers is {best_workers * args.partitions} " + f"with avg. preprocessed docs/s: {best_perf:.4f}." + ) + print("-----------------------------------") + + def main(): args = get_args() - if args.split_sentences: - if nltk_available: - nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA")) + workers = args.workers_to_check if args.find_optimal_num_workers else [args.workers] + for num_workers in workers: + if num_workers % args.partitions != 0: + print( + f"Removing num_workers ({num_workers}) from workers list " + f"because it's not divisible by num_partitions ({args.partitions})" + ) + workers.remove(num_workers) + assert workers, "Please, provide valid number of workers which is divisible by number of partitions." + if args.find_optimal_num_workers: + assert args.performance_dir, "Directory where to save performance results should be specified." 
+ os.makedirs(args.performance_dir, exist_ok=True) + args.log_interval = 1000 + + for num_workers in workers: + print(f"Processing data with {num_workers} workers.") + if args.split_sentences: + if nltk_available: + nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA")) + else: + raise Exception( + "nltk library required for sentence splitting is not available.") + + in_ss_out_names = [] + if args.partitions == 1: + file_name, extension = os.path.splitext(args.input) + sentence_split_file = file_name + "_ss" + extension + file_names = { + 'partition': args.input, + 'sentence_split': sentence_split_file, + 'output_prefix': args.output_prefix} + in_ss_out_names.append(file_names) else: - raise Exception( - "nltk library required for sentence splitting is not available.") - - in_ss_out_names = [] - if args.partitions == 1: - file_name, extension = os.path.splitext(args.input) - sentence_split_file = file_name + "_ss" + extension - file_names = { - 'partition': args.input, - 'sentence_split': sentence_split_file, - 'output_prefix': args.output_prefix} - in_ss_out_names.append(file_names) - else: - in_file_names = glob.glob(args.input) - - # Count total number of lines across .jsonl files - if args.keep_sequential_samples: - total_sample_count = 0 - for filename in in_file_names: - with open(filename, "r") as fin: - for fc, _ in enumerate(fin): - pass - total_sample_count += (fc + 1) - partition_size = math.ceil(total_sample_count / args.partitions) - - # create .jsonl parition files - for idx in range(args.partitions): - in_ss_out_name = get_file_name(args, idx) - in_ss_out_names.append(in_ss_out_name) - - # check to see if paritions were already created - partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions) + in_file_names = glob.glob(args.input) + + # Count total number of lines across .jsonl files + if args.keep_sequential_samples: + total_sample_count = 0 + for filename in in_file_names: + with open(filename, 
"r") as fin: + for fc, _ in enumerate(fin): + pass + total_sample_count += (fc + 1) + partition_size = math.ceil(total_sample_count / args.partitions) + + # create .jsonl parition files + for idx in range(args.partitions): + in_ss_out_name = get_file_name(args, idx) + in_ss_out_names.append(in_ss_out_name) + + # check to see if paritions were already created + partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions) + + # check to see if paritions with split sentences already created + split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + + if not partitions_present and not split_sentences_present: + # populate .jsonl partition files from parent files + partitioned_input_files = [] + for idx in range(args.partitions): + partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w') + partitioned_input_files.append(partitioned_input_file) + + index = 0 + if args.keep_sequential_samples: line_count = 0 + for in_file_name in in_file_names: + # support for gzip files + if in_file_name.endswith(".gz"): + fin = gzip.open(in_file_name, 'rt') + else: + fin = open(in_file_name, 'r', encoding='utf-8') + + for line in fin: + partitioned_input_files[index].write(line) + if args.keep_sequential_samples: + line_count += 1 + if line_count % partition_size == 0: + index += 1 + else: + index = (index + 1)%args.partitions + + fin.close() + + for idx in range(args.partitions): + partitioned_input_files[idx].close() + + partition = Partition(args, num_workers//args.partitions) # check to see if paritions with split sentences already created split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) - if not partitions_present and not split_sentences_present: - # populate .jsonl partition files from parent files - partitioned_input_files = [] - for idx in range(args.partitions): - partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w') - 
partitioned_input_files.append(partitioned_input_file) - - index = 0 - if args.keep_sequential_samples: line_count = 0 - for in_file_name in in_file_names: - # support for gzip files - if in_file_name.endswith(".gz"): - fin = gzip.open(in_file_name, 'rt') - else: - fin = open(in_file_name, 'r', encoding='utf-8') - - for line in fin: - partitioned_input_files[index].write(line) - if args.keep_sequential_samples: - line_count += 1 - if line_count % partition_size == 0: - index += 1 - else: - index = (index + 1)%args.partitions + # split sentences in partition files + if args.split_sentences and not split_sentences_present: + processes = [] + for name in in_ss_out_names: + p = multiprocessing.Process(target=partition.split_sentences, + args=((name['partition'], name['sentence_split']),)) + p.start() + processes.append(p) - fin.close() + for p in processes: + p.join() - for idx in range(args.partitions): - partitioned_input_files[idx].close() + if args.partitions == 1: + continue - assert args.workers % args.partitions == 0 - partition = Partition(args, args.workers//args.partitions) - # check to see if paritions with split sentences already created - split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) - - # split sentences in partition files - if args.split_sentences and not split_sentences_present: + # encode partition files in parallel processes = [] + input_key = 'sentence_split' if args.split_sentences else 'partition' for name in in_ss_out_names: - p = multiprocessing.Process(target=partition.split_sentences, - args=((name['partition'], name['sentence_split']),)) + p = multiprocessing.Process(target=partition.process_json_file, + args=((name[input_key], name['output_prefix']),)) p.start() processes.append(p) @@ -349,51 +443,38 @@ def main(): p.join() if args.partitions == 1: - return - - - # encode partition files in parallel - processes = [] - input_key = 'sentence_split' if args.split_sentences else 'partition' - for 
name in in_ss_out_names: - p = multiprocessing.Process(target=partition.process_json_file, - args=((name[input_key], name['output_prefix']),)) - p.start() - processes.append(p) - - for p in processes: - p.join() - - if args.partitions == 1: - return - - # merge bin/idx partitions - level = "document" - if args.split_sentences: - level = "sentence" - - output_bin_files = {} - output_idx_files = {} - builders = {} - tokenizer = build_tokenizer(args) - - for key in args.json_keys: - output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, - key, level) - output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, - key, level) - builders[key] = indexed_dataset.IndexedDatasetBuilder( - output_bin_files[key], - dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), - ) + continue - for name in in_ss_out_names: - parition_output_prefix = name['output_prefix'] - full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, - key, level) - builders[key].add_index(full_partition_output_prefix) - builders[key].finalize(output_idx_files[key]) + # merge bin/idx partitions + level = "document" + if args.split_sentences: + level = "sentence" + + output_bin_files = {} + output_idx_files = {} + builders = {} + tokenizer = build_tokenizer(args) + + for key in args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, + key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, + key, level) + builders[key] = indexed_dataset.IndexedDatasetBuilder( + output_bin_files[key], + dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), + ) + + for name in in_ss_out_names: + parition_output_prefix = name['output_prefix'] + full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, + key, level) + builders[key].add_index(full_partition_output_prefix) + builders[key].finalize(output_idx_files[key]) + # Find the most optimal number of workers + if args.find_optimal_num_workers: + 
find_optimal_num_workers(args) if __name__ == '__main__':