diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index 3ac06c412b6..32b700e4798 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -89,7 +89,9 @@ def encode(self, json_line): ids = {} lens = {} for key in self.args.json_keys: - text = data[key] + text = data.get(key) + if text is None: + continue if isinstance(text, list): sentences = text else: @@ -184,7 +186,8 @@ def process_json_file(self, file_name): self.print_processing_stats(i, proc_start, total_bytes_processed) fin.close() - builders[key].finalize(output_idx_files[key]) + for key in self.args.json_keys: + builders[key].finalize(output_idx_files[key]) def get_args(): @@ -273,6 +276,7 @@ def main(): 'output_prefix': args.output_prefix} in_ss_out_names.append(file_names) else: + assert args.workers % args.partitions == 0 in_file_names = glob.glob(args.input) # Count total number of lines across .jsonl files @@ -326,7 +330,6 @@ def main(): for idx in range(args.partitions): partitioned_input_files[idx].close() - assert args.workers % args.partitions == 0 partition = Partition(args, args.workers//args.partitions) # check to see if paritions with split sentences already created