Skip to content

Commit d2c0ab7

Browse files
author
baijin.xh
committed
En
1 parent 5b0e132 commit d2c0ab7

File tree

11 files changed

+199
-195
lines changed

11 files changed

+199
-195
lines changed

tests/DeviceMesh/test_device_mesh.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,40 +24,40 @@ def test_dp_rank_only(self):
2424

2525
def test_tp_rank_only(self):
2626
mesh = DeviceMesh.from_sizes(tp_size=4)
27-
# from_sizes 默认 dp_size=1,维度顺序是 (dp, tp)
27+
# from_sizes default dp_size=1, dimension order (dp, tp)
2828
mesh_array = mesh.mesh.reshape(1, 4)
2929

3030
for tp_idx in range(4):
3131
global_rank = int(mesh_array[0, tp_idx])
3232
with patch.object(Platform, 'get_rank', return_value=global_rank):
3333
assert mesh.tp_rank == tp_idx
34-
assert mesh.dp_rank == 0 # dp 默认是 1,所以 dp_rank 总是 0
34+
assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
3535
assert mesh.pp_rank is None
3636
assert mesh.fsdp_rank is None
3737

3838
def test_pp_rank_only(self):
3939
mesh = DeviceMesh.from_sizes(pp_size=4)
40-
# from_sizes 维度顺序是 (pp, dp),默认 dp_size=1
40+
# from_sizes dimension order (pp, dp), default dp_size=1
4141
mesh_array = mesh.mesh.reshape(4, 1)
4242

4343
for pp_idx in range(4):
4444
global_rank = int(mesh_array[pp_idx, 0])
4545
with patch.object(Platform, 'get_rank', return_value=global_rank):
4646
assert mesh.pp_rank == pp_idx
47-
assert mesh.dp_rank == 0 # dp 默认是 1,所以 dp_rank 总是 0
47+
assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
4848
assert mesh.tp_rank is None
4949
assert mesh.fsdp_rank is None
5050

5151
def test_fsdp_rank_only(self):
5252
mesh = DeviceMesh.from_sizes(fsdp_size=4)
53-
# from_sizes 维度顺序是 (fsdp, dp),默认 dp_size=1
53+
# from_sizes dimension order (fsdp, dp), default dp_size=1
5454
mesh_array = mesh.mesh.reshape(4, 1)
5555

5656
for fsdp_idx in range(4):
5757
global_rank = int(mesh_array[fsdp_idx, 0])
5858
with patch.object(Platform, 'get_rank', return_value=global_rank):
5959
assert mesh.fsdp_rank == fsdp_idx
60-
assert mesh.dp_rank == 0 # dp 默认是 1,所以 dp_rank 总是 0
60+
assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
6161
assert mesh.tp_rank is None
6262
assert mesh.pp_rank is None
6363

@@ -77,7 +77,7 @@ def test_dp_tp_combination(self):
7777

7878
def test_dp_fsdp_combination(self):
7979
mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=4)
80-
# from_sizes 维度顺序是 (fsdp, dp)
80+
# from_sizes dimension order (fsdp, dp)
8181
mesh_array = mesh.mesh.reshape(4, 2)
8282

8383
for fsdp_idx in range(4):
@@ -91,7 +91,7 @@ def test_dp_fsdp_combination(self):
9191

9292
def test_tp_pp_combination(self):
9393
mesh = DeviceMesh.from_sizes(tp_size=2, pp_size=4)
94-
# from_sizes 维度顺序是 (pp, dp, tp),默认 dp_size=1
94+
# from_sizes dimension order (pp, dp, tp), default dp_size=1
9595
mesh_array = mesh.mesh.reshape(4, 1, 2)
9696

9797
for pp_idx in range(4):
@@ -100,12 +100,12 @@ def test_tp_pp_combination(self):
100100
with patch.object(Platform, 'get_rank', return_value=global_rank):
101101
assert mesh.pp_rank == pp_idx
102102
assert mesh.tp_rank == tp_idx
103-
assert mesh.dp_rank == 0 # dp 默认是 1,所以 dp_rank 总是 0
103+
assert mesh.dp_rank == 0 # dp default is 1, so dp_rank is always 0
104104
assert mesh.fsdp_rank is None
105105

106106
def test_dp_tp_pp_combination(self):
107107
mesh = DeviceMesh.from_sizes(dp_size=2, tp_size=2, pp_size=2)
108-
# from_sizes 维度顺序是 (pp, dp, tp)
108+
# from_sizes dimension order (pp, dp, tp)
109109
mesh_array = mesh.mesh.reshape(2, 2, 2)
110110

111111
for pp_idx in range(2):
@@ -120,7 +120,7 @@ def test_dp_tp_pp_combination(self):
120120

121121
def test_dp_fsdp_tp_combination(self):
122122
mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2, tp_size=2)
123-
# from_sizes 维度顺序是 (fsdp, dp, tp)
123+
# from_sizes dimension order (fsdp, dp, tp)
124124
mesh_array = mesh.mesh.reshape(2, 2, 2)
125125

126126
for fsdp_idx in range(2):
@@ -135,7 +135,7 @@ def test_dp_fsdp_tp_combination(self):
135135

136136
def test_all_dimensions_combination(self):
137137
mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=2, tp_size=2, pp_size=2)
138-
# from_sizes 维度顺序是 (fsdp, pp, dp, tp)
138+
# from_sizes dimension order (fsdp, pp, dp, tp)
139139
mesh_array = mesh.mesh.reshape(2, 2, 2, 2)
140140

141141
for fsdp_idx in range(2):
@@ -197,13 +197,13 @@ def test_data_rank_with_fsdp_only(self):
197197

198198
def test_data_rank_with_dp_fsdp(self):
199199
mesh = DeviceMesh.from_sizes(dp_size=2, fsdp_size=3)
200-
# from_sizes 维度顺序是 (fsdp, dp)
200+
# from_sizes dimension order (fsdp, dp)
201201
mesh_array = mesh.mesh.reshape(3, 2)
202202

203203
for fsdp_idx in range(3):
204204
for dp_idx in range(2):
205205
global_rank = int(mesh_array[fsdp_idx, dp_idx])
206206
with patch.object(Platform, 'get_rank', return_value=global_rank):
207-
# data_rank 的计算公式: dp_rank * fsdp_world_size + fsdp_rank
207+
# data_rank formula: dp_rank * fsdp_world_size + fsdp_rank
208208
expected_data_rank = dp_idx * 3 + fsdp_idx
209209
assert mesh.data_rank == expected_data_rank

tests/dataloader/test_dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def create_dataset():
5050

5151

5252
class TestDataCollator:
53-
"""测试data_collator(InputProcessor)功能"""
53+
"""Test data_collator (InputProcessor) functionality"""
5454

5555
@pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
5656
def test_dataloader_with_collator(self):

tests/dataset/test_lazy.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,25 +21,25 @@ def convert_to_messages(example):
2121
class TestLazyDataset:
2222

2323
def test_lazy_dataset_basic(self):
24-
# 基本功能测试
24+
# Basic functionality test
2525
csv_path = str(TEST_DATA_DIR / 'test.csv')
2626
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
2727

2828
assert len(dataset) == 4
29-
assert dataset.do_encode
30-
assert dataset.do_check
29+
assert not dataset.do_encode
30+
assert not dataset.do_check
3131

3232
item = dataset[0]
3333
assert 'text' in item
3434
assert item['text'] == 'Hello world'
3535

3636
def test_lazy_dataset_encode_flag(self):
37-
# 懒加载编码标志测试
37+
# Lazy encode flag test
3838
csv_path = str(TEST_DATA_DIR / 'test.csv')
3939
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
4040
dataset.map(convert_to_messages)
4141

42-
assert dataset.do_encode
42+
assert not dataset.do_encode
4343

4444
try:
4545
dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
@@ -48,12 +48,14 @@ def test_lazy_dataset_encode_flag(self):
4848

4949
dataset.encode()
5050

51-
assert dataset.do_encode
51+
# Lazy load: encode() only sets flag, actual encoding on access; raw dataset has no input_ids
5252
assert 'messages' in dataset.dataset[0]
5353
assert 'input_ids' not in dataset.dataset[0]
54+
item = dataset[0]
55+
assert 'input_ids' in item
5456

5557
def test_lazy_dataset_encode_on_access(self):
56-
# 懒加载编码执行测试
58+
# Lazy encode execution test
5759
csv_path = str(TEST_DATA_DIR / 'test.csv')
5860
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
5961
dataset.map(convert_to_messages)
@@ -71,12 +73,12 @@ def test_lazy_dataset_encode_on_access(self):
7173
assert len(item['input_ids']) > 0
7274

7375
def test_lazy_dataset_check_flag(self):
74-
# 懒加载检查标志测试,验证check()只设置标志,不实际执行检查
76+
# Lazy check flag test: check() only sets flag, does not execute check
7577
csv_path = str(TEST_DATA_DIR / 'test.csv')
7678
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
7779
dataset.map(convert_to_messages)
7880

79-
assert dataset.do_check
81+
assert not dataset.do_check
8082

8183
try:
8284
dataset.set_template('Template', model_id='ms://Qwen/Qwen2.5-0.5B-Instruct', max_length=128)
@@ -85,10 +87,12 @@ def test_lazy_dataset_check_flag(self):
8587

8688
dataset.check()
8789

88-
assert dataset.do_check
90+
# Lazy load: check() only sets flag, actual check on access
91+
item = dataset[0]
92+
assert item is not None
8993

9094
def test_lazy_dataset_check_on_access(self):
91-
# 懒加载检查执行测试,验证在访问数据时才执行检查
95+
# Lazy check execution test: check runs on data access
9296
csv_path = str(TEST_DATA_DIR / 'test.csv')
9397
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
9498
dataset.map(convert_to_messages)
@@ -105,7 +109,7 @@ def test_lazy_dataset_check_on_access(self):
105109
assert 'messages' in item or item is None
106110

107111
def test_lazy_dataset_encode_requires_template(self):
108-
# 编码要求模板测试,验证未设置模板时抛出异常
112+
# Encode requires template: raises when template not set
109113
csv_path = str(TEST_DATA_DIR / 'test.csv')
110114
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
111115

@@ -121,7 +125,7 @@ def test_lazy_dataset_check_requires_template(self):
121125

122126
@pytest.mark.skipif(SKIP_MODEL_DOWNLOAD, reason='Skipping tests that require model download')
123127
def test_lazy_dataset_no_split_strategy(self):
124-
# 编码不支持split策略测试,验证未设置模板时抛出异常
128+
# Encode does not support split strategy: raises when template not set
125129
csv_path = str(TEST_DATA_DIR / 'test.csv')
126130
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
127131
dataset.map(convert_to_messages)
@@ -136,7 +140,7 @@ def test_lazy_dataset_no_split_strategy(self):
136140
dataset.encode()
137141

138142
def test_lazy_dataset_multiple_items(self):
139-
# 验证多个数据项的懒加载编码
143+
# Lazy encode for multiple items
140144
csv_path = str(TEST_DATA_DIR / 'test.csv')
141145
dataset = LazyDataset(dataset_meta=DatasetMeta(dataset_id=csv_path))
142146
dataset.map(convert_to_messages)

0 commit comments

Comments
 (0)