Commit 1a5e5d8

Fix loading local datasets (#108)
1 parent 96d1354 commit 1a5e5d8

6 files changed: +39 -5 lines changed

docs/source_en/Components/Dataset/Dataset.md

Lines changed: 6 additions & 0 deletions
@@ -55,6 +55,12 @@ from twinkle.dataset import Dataset, DatasetMeta
 dataset = Dataset(DatasetMeta(dataset_id='my/custom/dataset.jsonl', data_slice=range(1500)))
 ```

+If you load from a local path, follow these rules:
+
+1. For a local dataset file, pass a single file path (preferably an absolute path, to avoid relative-path errors). A list of paths is not supported.
+2. For a local directory, make sure all files in it share the same data structure and the same file extension.
+3. Data loading uses the `datasets` library; see the supported file extensions [here](https://huggingface.co/docs/hub/datasets-libraries).
+
 2. Setting template

 The Template component is responsible for converting string/image multimodal raw data into model input tokens. The dataset can set a Template to complete the `encode` process.

docs/source_zh/组件/数据集/Dataset.md

Lines changed: 6 additions & 0 deletions
@@ -55,6 +55,12 @@ from twinkle.dataset import Dataset, DatasetMeta
 dataset = Dataset(DatasetMeta(dataset_id='my/custom/dataset.jsonl', data_slice=range(1500)))
 ```

+If you use a local path or a local file, follow these instructions:
+
+1. If you are using a local dataset file, pass a single file path (preferably an absolute path, to avoid relative-path errors). Passing a list is not supported.
+2. If you are using a local directory, make sure the files in it share the same data structure and the same file extension.
+3. Data loading uses the `datasets` library; see the supported extensions [here](https://huggingface.co/docs/hub/datasets-libraries).
+
 2. Setting the template

 The Template component is responsible for converting raw string/image multimodal data into model input tokens. A dataset can set a Template to complete the `encode` process.
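
The local-path rules in the documentation hunks above can be checked before a path is handed to the loader. The following is a minimal sketch, not part of this commit: `check_local_dataset_path` is a hypothetical helper name, it assumes hidden files (names starting with `.`) should be ignored, and it only verifies the extension rule, not that the files share the same data structure.

```python
import os
from collections import Counter


def check_local_dataset_path(path):
    """Hypothetical helper: validate a local dataset path against the documented rules.

    Returns the file extension shared by the data, without the leading dot.
    """
    # Rule 1: a single path is expected; a list of paths is not supported.
    if not isinstance(path, (str, os.PathLike)):
        raise TypeError('pass a single file or directory path, not a list')

    # Absolute paths avoid surprises caused by the current working directory.
    path = os.path.abspath(path)
    if not os.path.exists(path):
        raise FileNotFoundError(path)

    if os.path.isfile(path):
        return os.path.splitext(path)[1].lstrip('.')

    # Rule 2: every regular (non-hidden) file in the directory
    # should share one extension.
    extensions = Counter(
        os.path.splitext(name)[1].lstrip('.')
        for name in os.listdir(path)
        if not name.startswith('.') and os.path.isfile(os.path.join(path, name))
    )
    if not extensions:
        raise ValueError(f'no data files found in {path}')
    if len(extensions) > 1:
        raise ValueError(f'mixed file extensions in {path}: {sorted(extensions)}')
    return next(iter(extensions))
```

Calling the helper first turns a mixed-extension directory into an explicit `ValueError` instead of leaving the outcome to whichever file the loader happens to inspect.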

src/twinkle/dataset/base.py

Lines changed: 15 additions & 5 deletions
@@ -120,15 +120,25 @@ def _load_dataset(dataset_meta: DatasetMeta, **kwargs):
         if os.path.exists(dataset_id):
             streaming = kwargs.get('streaming', False)
             num_proc = kwargs.get('num_proc', 1)
-            ext = os.path.splitext(dataset_id)[1].lstrip('.')
-            file_type = {'jsonl': 'json', 'txt': 'text'}.get(ext) or ext
             if streaming:
                 kwargs = {'split': 'train', 'streaming': True}
             else:
                 kwargs = {'split': 'train', 'num_proc': num_proc}
-            if file_type == 'csv':
-                kwargs['na_filter'] = False
-            dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
+            if os.path.isdir(dataset_id):
+                folder_path = dataset_id
+                files = os.listdir(folder_path)
+                first_file = files[0] if files else None
+                ext = os.path.splitext(first_file)[1].lstrip('.')
+                file_type = {'jsonl': 'json', 'txt': 'text'}.get(ext) or ext
+                if file_type == 'csv':
+                    kwargs['na_filter'] = False
+                dataset = load_dataset(file_type, data_dir=dataset_id, **kwargs)
+            else:
+                ext = os.path.splitext(dataset_id)[1].lstrip('.')
+                file_type = {'jsonl': 'json', 'txt': 'text'}.get(ext) or ext
+                if file_type == 'csv':
+                    kwargs['na_filter'] = False
+                dataset = load_dataset(file_type, data_files=dataset_id, **kwargs)
         else:
             dataset = HubOperation.load_dataset(dataset_id, subset_name, split, **kwargs)
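
One edge case in the directory branch above: for an empty directory `first_file` is `None`, and `os.path.splitext(None)` raises a `TypeError`. The two branches also repeat the same extension-to-builder mapping. A possible consolidation is sketched below; `infer_file_type` and `_BUILDER_ALIASES` are hypothetical names that do not appear in the commit, and skipping hidden files and sorting the directory listing are assumptions of this sketch rather than behavior of the committed code.

```python
import os

# Same extension-to-builder mapping as in the diff above.
_BUILDER_ALIASES = {'jsonl': 'json', 'txt': 'text'}


def infer_file_type(dataset_id):
    """Hypothetical helper: pick the `datasets` builder name for a local path."""
    if os.path.isdir(dataset_id):
        # os.listdir returns entries in arbitrary order, so sort the names
        # to make the choice of the sample file deterministic.
        names = sorted(
            name for name in os.listdir(dataset_id)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(dataset_id, name))
        )
        if not names:
            raise ValueError(f'no data files found in directory: {dataset_id}')
        sample = names[0]
    else:
        sample = dataset_id

    ext = os.path.splitext(sample)[1].lstrip('.')
    if not ext:
        raise ValueError(f'cannot infer a file type from: {sample}')
    return _BUILDER_ALIASES.get(ext, ext)
```

With such a helper, both branches could compute `file_type` once and differ only in whether they pass `data_dir` or `data_files` to `load_dataset`.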

tests/dataset/test_data/1.lance

575 Bytes
Binary file not shown.

tests/dataset/test_loading.py

Lines changed: 12 additions & 0 deletions
@@ -41,6 +41,18 @@ def test_load_local_json(self):
         assert dataset[0]['text'] == 'Hello world'
         assert dataset[0]['label'] == 0

+    def test_load_local_lance(self):
+        """Test loading local Lance file"""
+        lance_path = str(TEST_DATA_DIR / '1.lance')
+        dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=lance_path))
+        assert len(dataset) == 2
+
+    def test_load_local_lance_dir(self):
+        """Test loading local Lance dir"""
+        lance_path = str(TEST_DATA_DIR / 'lance')
+        dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=lance_path))
+        assert len(dataset) == 2
+
     def test_load_local_jsonl(self):
         jsonl_path = str(TEST_DATA_DIR / 'test.jsonl')
         dataset = Dataset(dataset_meta=DatasetMeta(dataset_id=jsonl_path))
