From 0e15d4a7754e9dfbe77b2c5a571a6a34c2e3dcd1 Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Wed, 8 Mar 2023 22:11:20 +0000 Subject: [PATCH 1/2] Add option to allow newlines in captions The YFCC-15M descriptions can contain newlines in the caption, which causes pyarrow's CSV reader to error by default. This commit allows passing --newlines-in-captions True to img2dataset, which will tell pyarrow to allow newlines in CSV values. --- img2dataset/main.py | 2 ++ img2dataset/reader.py | 17 ++++++++++++++--- tests/test_reader.py | 1 + 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/img2dataset/main.py b/img2dataset/main.py index 11d9c9a..093213f 100644 --- a/img2dataset/main.py +++ b/img2dataset/main.py @@ -104,6 +104,7 @@ def download( max_shard_retry: int = 1, user_agent_token: Optional[str] = None, disallowed_header_directives: Optional[List[str]] = None, + newlines_in_captions: bool = False, ): """Download is the main entry point of img2dataset, it uses multiple processes and download multiple files""" if disallowed_header_directives is None: @@ -183,6 +184,7 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument number_sample_per_shard, done_shards, tmp_path, + newlines_in_captions, ) if output_format == "webdataset": diff --git a/img2dataset/reader.py b/img2dataset/reader.py index 0dec953..db234df 100644 --- a/img2dataset/reader.py +++ b/img2dataset/reader.py @@ -38,6 +38,7 @@ def __init__( number_sample_per_shard, done_shards, tmp_path, + newlines_in_captions, ) -> None: self.input_format = input_format self.url_col = url_col @@ -47,6 +48,7 @@ def __init__( self.save_additional_columns = save_additional_columns self.number_sample_per_shard = number_sample_per_shard self.done_shards = done_shards + self.newlines_in_captions = newlines_in_captions fs, url_path = fsspec.core.url_to_fs(url_list) self.fs = fs @@ -79,13 +81,22 @@ def _save_to_arrow(self, input_file, start_shard_id): if self.input_format in ["txt", "json", "csv", 
"tsv"]: with self.fs.open(input_file, mode="rb") as file: if self.input_format == "txt": - df = csv_pq.read_csv(file, read_options=csv_pq.ReadOptions(column_names=["url"])) + df = csv_pq.read_csv( + file, + read_options=csv_pq.ReadOptions(column_names=["url"]), + parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions), + ) elif self.input_format == "json": df = pa.Table.from_pandas(pd.read_json(file)) elif self.input_format == "csv": - df = csv_pq.read_csv(file) + df = csv_pq.read_csv( + file, parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions) + ) elif self.input_format == "tsv": - df = csv_pq.read_csv(file, parse_options=csv_pq.ParseOptions(delimiter="\t")) + df = csv_pq.read_csv( + file, + parse_options=csv_pq.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions), + ) else: raise ValueError(f"Unknown input format {self.input_format}") elif self.input_format == "tsv.gz": diff --git a/tests/test_reader.py b/tests/test_reader.py index 81225f6..24ff52d 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -49,6 +49,7 @@ def test_reader(input_format, tmp_path): number_sample_per_shard=batch_size, done_shards=done_shards, tmp_path=test_folder, + newlines_in_captions=False, ) if input_format == "txt": From a91c119cddc63ea83a2a06b4bd0930971a17fd86 Mon Sep 17 00:00:00 2001 From: Achal Dave Date: Mon, 20 May 2024 13:23:29 -0700 Subject: [PATCH 2/2] Fix typo: use csv_pa alias instead of csv_pq in reader --- img2dataset/reader.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/img2dataset/reader.py b/img2dataset/reader.py index 0daff7c..2546782 100644 --- a/img2dataset/reader.py +++ b/img2dataset/reader.py @@ -99,21 +99,21 @@ def _save_to_arrow(self, input_file, start_shard_id): compression = "gzip" with self.fs.open(input_file, encoding="utf-8", mode="rb", compression=compression) as file: if self.input_format in ["txt", "txt.gz"]: - df = csv_pq.read_csv( + df = csv_pa.read_csv( file, - 
read_options=csv_pq.ReadOptions(column_names=["url"]), - parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions), + read_options=csv_pa.ReadOptions(column_names=["url"]), + parse_options=csv_pa.ParseOptions(newlines_in_values=self.newlines_in_captions), ) elif self.input_format in ["json", "json.gz"]: df = pa.Table.from_pandas(pd.read_json(file)) elif self.input_format in ["csv", "csv.gz"]: - df = csv_pq.read_csv( - file, parse_options=csv_pq.ParseOptions(newlines_in_values=self.newlines_in_captions) + df = csv_pa.read_csv( + file, parse_options=csv_pa.ParseOptions(newlines_in_values=self.newlines_in_captions) ) elif self.input_format in ["tsv", "tsv.gz"]: - df = csv_pq.read_csv( + df = csv_pa.read_csv( file, - parse_options=csv_pq.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions), + parse_options=csv_pa.ParseOptions(delimiter="\t", newlines_in_values=self.newlines_in_captions), ) elif self.input_format in ["jsonl", "jsonl.gz"]: df = json_pa.read_json(file)