Skip to content

Commit 95d474e

Browse files
committed
wip
1 parent 946810a commit 95d474e

File tree

5 files changed

+28
-6
lines changed

5 files changed

+28
-6
lines changed

src/twinkle/infra/_ray/resource_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,19 @@ def __init__(self, nproc_per_node: int, ncpu_proc_per_node: int, groups: List[De
137137
if self.node_ranks.count(0) > 1:
138138
self.node_ranks = list(range(len(self.placement_groups)))
139139

140+
self.visible_devices = []
141+
142+
@ray.remote
143+
def get_visible_devices():
144+
return os.environ.get(Platform.get_platform(group.device_type).visible_device_env())
145+
146+
if self.placement_groups:
147+
self.visible_devices = ray.get([
148+
get_visible_devices.options(placement_group=pg).remote() for pg in self.placement_groups
149+
])
150+
151+
breakpoint()
152+
140153
self.node2pg: Dict[int, PlacementGroup] = {}
141154
# Map actual node indices to placement groups
142155
# For GPU/NPU groups, node indices start from self.min_node_idx

src/twinkle/server/tinker/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,17 @@ def __init__(self,
100100
else:
101101
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
102102
self.use_megatron = use_megatron
103+
replica_context = serve.get_replica_context()
104+
replica_id = replica_context.replica_id.unique_id
103105
# Initialize model immediately - choose backend based on use_megatron
104106
if use_megatron:
105107
from .common.megatron_model import TwinkleCompatMegatronModel
106108
self.model = TwinkleCompatMegatronModel(
107-
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
109+
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
108110
else:
109111
from .common.transformers_model import TwinkleCompatTransformersModel
110112
self.model = TwinkleCompatTransformersModel(
111-
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
113+
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
112114
self.base_model = model_id
113115
self.state: ServerStateProxy = get_server_state()
114116

src/twinkle/server/tinker/sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,8 @@ def __init__(self,
102102
else:
103103
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
104104
self.sampler_type = sampler_type
105+
replica_context = serve.get_replica_context()
106+
replica_id = replica_context.replica_id.unique_id
105107

106108
# Initialize sampler based on type
107109
if sampler_type == 'vllm':
@@ -112,6 +114,7 @@ def __init__(self,
112114
engine_args=sampler_kwargs,
113115
device_mesh=self.device_mesh,
114116
remote_group=self.device_group.name,
117+
instance_id=replica_id,
115118
**{
116119
k: v
117120
for k, v in kwargs.items() if k not in ['engine_args']

src/twinkle/server/twinkle/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,16 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
171171
self.device_mesh = DeviceMesh(**device_mesh)
172172
else:
173173
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
174+
replica_context = serve.get_replica_context()
175+
replica_id = replica_context.replica_id.unique_id
174176
if use_megatron:
175177
from twinkle.model import MultiLoraMegatronModel
176178
self.model = MultiLoraMegatronModel(
177-
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
179+
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
178180
else:
179181
from twinkle.model import MultiLoraTransformersModel
180182
self.model = MultiLoraTransformersModel(
181-
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
183+
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, instance_id=replica_id, **kwargs)
182184

183185
# Initialize state before adapter manager (mixin needs self.state)
184186
self.state: ServerStateProxy = get_server_state()

src/twinkle/server/twinkle/sampler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def __init__(self,
152152
else:
153153
self.device_mesh = DeviceMesh.from_sizes(**device_mesh)
154154
self.sampler_type = sampler_type
155-
155+
replica_context = serve.get_replica_context()
156+
replica_id = replica_context.replica_id.unique_id
156157
# Initialize sampler based on type
157158
if sampler_type == 'vllm':
158159
from twinkle.sampler import vLLMSampler
@@ -162,14 +163,15 @@ def __init__(self,
162163
engine_args=sampler_kwargs,
163164
device_mesh=self.device_mesh,
164165
remote_group=self.device_group.name,
166+
instance_id=replica_id,
165167
**{
166168
k: v
167169
for k, v in kwargs.items() if k not in ['engine_args']
168170
})
169171
else:
170172
from twinkle.sampler import TorchSampler
171173
self.sampler = TorchSampler(
172-
model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, **kwargs)
174+
model_id=model_id, device_mesh=self.device_mesh, instance_id=replica_id, remote_group=self.device_group.name, **kwargs)
173175

174176
# Initialize state and adapter manager
175177
self.state: ServerStateProxy = get_server_state()

0 commit comments

Comments
 (0)