Skip to content

Commit acf00e3

Browse files
committed
support from_sizes for tinker server
1 parent 0d81634 commit acf00e3

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

src/twinkle/server/tinker/model.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,10 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any],
100100
nproc_per_node=nproc_per_node,
101101
groups=[self.device_group],
102102
lazy_collect=False)
103-
self.device_mesh = DeviceMesh(**device_mesh)
103+
if 'mesh_dim_names' in device_mesh:
104+
self.device_mesh = DeviceMesh(**device_mesh)
105+
else:
106+
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
104107
self.use_megatron = use_megatron
105108
# Initialize model immediately - choose backend based on use_megatron
106109
if use_megatron:

src/twinkle/server/tinker/sampler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any],
9696
nproc_per_node=nproc_per_node,
9797
groups=[self.device_group],
9898
lazy_collect=False)
99-
self.device_mesh = DeviceMesh(**device_mesh)
99+
if 'mesh_dim_names' in device_mesh:
100+
self.device_mesh = DeviceMesh(**device_mesh)
101+
else:
102+
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
100103
self.sampler_type = sampler_type
101104

102105
# Initialize sampler based on type

0 commit comments

Comments
 (0)