@@ -56,7 +56,7 @@ async def create_session(
5656 body : types .CreateSessionRequest ,
5757 self : GatewayServer = Depends (self_fn ),
5858 ) -> types .CreateSessionResponse :
59- session_id = self .state .create_session (body .model_dump ())
59+ session_id = await self .state .create_session (body .model_dump ())
6060 return types .CreateSessionResponse (session_id = session_id )
6161
6262 @app .post ('/session_heartbeat' )
@@ -72,7 +72,7 @@ async def session_heartbeat(
7272 async def create_sampling_session (
7373 request : Request , body : types .CreateSamplingSessionRequest , self : GatewayServer = Depends (self_fn )
7474 ) -> types .CreateSamplingSessionResponse : # noqa: E125
75- sampling_session_id = self .state .create_sampling_session (body .model_dump ())
75+ sampling_session_id = await self .state .create_sampling_session (body .model_dump ())
7676 return types .CreateSamplingSessionResponse (sampling_session_id = sampling_session_id )
7777
7878 @app .post ('/retrieve_future' )
@@ -223,44 +223,44 @@ async def create_model(request: Request, body: types.CreateModelRequest,
223223
224224 @app .post ('/get_info' )
225225 async def get_info (request : Request , body : types .GetInfoRequest , self : GatewayServer = Depends (self_fn )) -> Any :
226- return await self .proxy .proxy_to_model (request , 'get_info' , self ._get_base_model (body .model_id ))
226+ return await self .proxy .proxy_to_model (request , 'get_info' , await self ._get_base_model (body .model_id ))
227227
228228 @app .post ('/unload_model' )
229229 async def unload_model (request : Request , body : types .UnloadModelRequest ,
230230 self : GatewayServer = Depends (self_fn )) -> Any :
231- return await self .proxy .proxy_to_model (request , 'unload_model' , self ._get_base_model (body .model_id ))
231+ return await self .proxy .proxy_to_model (request , 'unload_model' , await self ._get_base_model (body .model_id ))
232232
233233 @app .post ('/forward' )
234234 async def forward (request : Request , body : types .ForwardRequest , self : GatewayServer = Depends (self_fn )) -> Any :
235- return await self .proxy .proxy_to_model (request , 'forward' , self ._get_base_model (body .model_id ))
235+ return await self .proxy .proxy_to_model (request , 'forward' , await self ._get_base_model (body .model_id ))
236236
237237 @app .post ('/forward_backward' )
238238 async def forward_backward (request : Request ,
239239 body : types .ForwardBackwardRequest ,
240240 self : GatewayServer = Depends (self_fn )) -> Any :
241- return await self .proxy .proxy_to_model (request , 'forward_backward' , self ._get_base_model (body .model_id ))
241+ return await self .proxy .proxy_to_model (request , 'forward_backward' , await self ._get_base_model (body .model_id ))
242242
243243 @app .post ('/optim_step' )
244244 async def optim_step (request : Request , body : types .OptimStepRequest , self : GatewayServer = Depends (self_fn )) -> Any :
245- return await self .proxy .proxy_to_model (request , 'optim_step' , self ._get_base_model (body .model_id ))
245+ return await self .proxy .proxy_to_model (request , 'optim_step' , await self ._get_base_model (body .model_id ))
246246
247247 @app .post ('/save_weights' )
248248 async def save_weights (request : Request , body : types .SaveWeightsRequest ,
249249 self : GatewayServer = Depends (self_fn )) -> Any :
250- return await self .proxy .proxy_to_model (request , 'save_weights' , self ._get_base_model (body .model_id ))
250+ return await self .proxy .proxy_to_model (request , 'save_weights' , await self ._get_base_model (body .model_id ))
251251
252252 @app .post ('/load_weights' )
253253 async def load_weights (request : Request , body : types .LoadWeightsRequest ,
254254 self : GatewayServer = Depends (self_fn )) -> Any :
255- return await self .proxy .proxy_to_model (request , 'load_weights' , self ._get_base_model (body .model_id ))
255+ return await self .proxy .proxy_to_model (request , 'load_weights' , await self ._get_base_model (body .model_id ))
256256
257257 # --- Sampler Proxy Endpoints ---
258258
259259 @app .post ('/asample' )
260260 async def asample (request : Request , body : types .SampleRequest , self : GatewayServer = Depends (self_fn )) -> Any :
261261 base_model = body .base_model
262262 if not base_model and body .sampling_session_id :
263- session = self .state .get_sampling_session (body .sampling_session_id )
263+ session = await self .state .get_sampling_session (body .sampling_session_id )
264264 if session :
265265 base_model = session .get ('base_model' )
266266 return await self .proxy .proxy_to_sampler (request , 'asample' , base_model )
@@ -271,4 +271,5 @@ async def save_weights_for_sampler(
271271 body : types .SaveWeightsForSamplerRequest ,
272272 self : GatewayServer = Depends (self_fn ),
273273 ) -> Any :
274- return await self .proxy .proxy_to_model (request , 'save_weights_for_sampler' , self ._get_base_model (body .model_id ))
274+ return await self .proxy .proxy_to_model (request , 'save_weights_for_sampler' , await
275+ self ._get_base_model (body .model_id ))
0 commit comments