@@ -24,40 +24,40 @@ def test_dp_rank_only(self):
2424
def test_tp_rank_only(self):
    """Verify rank properties for a mesh built with only ``tp_size`` set.

    Each global rank taken from the mesh is fed through a patched
    ``Platform.get_rank``; tp_rank must match the column position, dp_rank
    must be 0, and the pp/fsdp ranks must be ``None``.
    """
    tp_world = 4
    device_mesh = DeviceMesh.from_sizes(tp_size=tp_world)
    # from_sizes defaults dp_size to 1 and orders the dimensions as (dp, tp),
    # so the mesh is viewed as a single dp row with tp_world columns.
    rank_grid = device_mesh.mesh.reshape(1, tp_world)

    for tp_pos in range(tp_world):
        mocked_rank = int(rank_grid[0, tp_pos])
        rank_patch = patch.object(Platform, 'get_rank', return_value=mocked_rank)
        with rank_patch:
            assert device_mesh.tp_rank == tp_pos
            # dp defaults to 1, so dp_rank is always 0
            assert device_mesh.dp_rank == 0
            assert device_mesh.pp_rank is None
            assert device_mesh.fsdp_rank is None
3737
def test_pp_rank_only(self):
    """Verify rank properties for a mesh built with only ``pp_size`` set.

    Each global rank taken from the mesh is fed through a patched
    ``Platform.get_rank``; pp_rank must match the row position, dp_rank
    must be 0, and the tp/fsdp ranks must be ``None``.
    """
    pp_world = 4
    device_mesh = DeviceMesh.from_sizes(pp_size=pp_world)
    # from_sizes orders the dimensions as (pp, dp) with dp_size defaulting
    # to 1, so the mesh is viewed as pp_world rows of a single dp column.
    rank_grid = device_mesh.mesh.reshape(pp_world, 1)

    for pp_pos in range(pp_world):
        mocked_rank = int(rank_grid[pp_pos, 0])
        rank_patch = patch.object(Platform, 'get_rank', return_value=mocked_rank)
        with rank_patch:
            assert device_mesh.pp_rank == pp_pos
            # dp defaults to 1, so dp_rank is always 0
            assert device_mesh.dp_rank == 0
            assert device_mesh.tp_rank is None
            assert device_mesh.fsdp_rank is None
5050
def test_fsdp_rank_only(self):
    """Verify rank properties for a mesh built with only ``fsdp_size`` set.

    Each global rank taken from the mesh is fed through a patched
    ``Platform.get_rank``; fsdp_rank must match the row position, dp_rank
    must be 0, and the tp/pp ranks must be ``None``.
    """
    fsdp_world = 4
    device_mesh = DeviceMesh.from_sizes(fsdp_size=fsdp_world)
    # from_sizes orders the dimensions as (fsdp, dp) with dp_size defaulting
    # to 1, so the mesh is viewed as fsdp_world rows of a single dp column.
    rank_grid = device_mesh.mesh.reshape(fsdp_world, 1)

    for fsdp_pos in range(fsdp_world):
        mocked_rank = int(rank_grid[fsdp_pos, 0])
        rank_patch = patch.object(Platform, 'get_rank', return_value=mocked_rank)
        with rank_patch:
            assert device_mesh.fsdp_rank == fsdp_pos
            # dp defaults to 1, so dp_rank is always 0
            assert device_mesh.dp_rank == 0
            assert device_mesh.tp_rank is None
            assert device_mesh.pp_rank is None
6363
@@ -77,7 +77,7 @@ def test_dp_tp_combination(self):
7777
7878 def test_dp_fsdp_combination (self ):
7979 mesh = DeviceMesh .from_sizes (dp_size = 2 , fsdp_size = 4 )
80- # from_sizes 维度顺序是 (fsdp, dp)
80+ # from_sizes dimension order (fsdp, dp)
8181 mesh_array = mesh .mesh .reshape (4 , 2 )
8282
8383 for fsdp_idx in range (4 ):
@@ -91,7 +91,7 @@ def test_dp_fsdp_combination(self):
9191
9292 def test_tp_pp_combination (self ):
9393 mesh = DeviceMesh .from_sizes (tp_size = 2 , pp_size = 4 )
94- # from_sizes 维度顺序是 (pp, dp, tp),默认 dp_size=1
94+ # from_sizes dimension order (pp, dp, tp), default dp_size=1
9595 mesh_array = mesh .mesh .reshape (4 , 1 , 2 )
9696
9797 for pp_idx in range (4 ):
@@ -100,12 +100,12 @@ def test_tp_pp_combination(self):
100100 with patch .object (Platform , 'get_rank' , return_value = global_rank ):
101101 assert mesh .pp_rank == pp_idx
102102 assert mesh .tp_rank == tp_idx
103- assert mesh .dp_rank == 0 # dp 默认是 1,所以 dp_rank 总是 0
103+ assert mesh .dp_rank == 0 # dp default is 1, so dp_rank is always 0
104104 assert mesh .fsdp_rank is None
105105
106106 def test_dp_tp_pp_combination (self ):
107107 mesh = DeviceMesh .from_sizes (dp_size = 2 , tp_size = 2 , pp_size = 2 )
108- # from_sizes 维度顺序是 (pp, dp, tp)
108+ # from_sizes dimension order (pp, dp, tp)
109109 mesh_array = mesh .mesh .reshape (2 , 2 , 2 )
110110
111111 for pp_idx in range (2 ):
@@ -120,7 +120,7 @@ def test_dp_tp_pp_combination(self):
120120
121121 def test_dp_fsdp_tp_combination (self ):
122122 mesh = DeviceMesh .from_sizes (dp_size = 2 , fsdp_size = 2 , tp_size = 2 )
123- # from_sizes 维度顺序是 (fsdp, dp, tp)
123+ # from_sizes dimension order (fsdp, dp, tp)
124124 mesh_array = mesh .mesh .reshape (2 , 2 , 2 )
125125
126126 for fsdp_idx in range (2 ):
@@ -135,7 +135,7 @@ def test_dp_fsdp_tp_combination(self):
135135
136136 def test_all_dimensions_combination (self ):
137137 mesh = DeviceMesh .from_sizes (dp_size = 2 , fsdp_size = 2 , tp_size = 2 , pp_size = 2 )
138- # from_sizes 维度顺序是 (fsdp, pp, dp, tp)
138+ # from_sizes dimension order (fsdp, pp, dp, tp)
139139 mesh_array = mesh .mesh .reshape (2 , 2 , 2 , 2 )
140140
141141 for fsdp_idx in range (2 ):
@@ -197,13 +197,13 @@ def test_data_rank_with_fsdp_only(self):
197197
def test_data_rank_with_dp_fsdp(self):
    """Verify ``data_rank`` for a mesh with dp_size=2 and fsdp_size=3.

    For every (fsdp, dp) position the global rank from the mesh is fed
    through a patched ``Platform.get_rank`` and ``data_rank`` is compared
    against ``dp_rank * fsdp_world_size + fsdp_rank``.
    """
    dp_world = 2
    fsdp_world = 3
    device_mesh = DeviceMesh.from_sizes(dp_size=dp_world, fsdp_size=fsdp_world)
    # from_sizes orders the dimensions as (fsdp, dp)
    rank_grid = device_mesh.mesh.reshape(fsdp_world, dp_world)

    for fsdp_pos in range(fsdp_world):
        for dp_pos in range(dp_world):
            mocked_rank = int(rank_grid[fsdp_pos, dp_pos])
            rank_patch = patch.object(Platform, 'get_rank', return_value=mocked_rank)
            with rank_patch:
                # data_rank formula: dp_rank * fsdp_world_size + fsdp_rank
                wanted = dp_pos * fsdp_world + fsdp_pos
                assert device_mesh.data_rank == wanted
0 commit comments