11# Copyright (c) ModelScope Contributors. All rights reserved.
22# Adapted from https://github.com/volcengine/verl/blob/main/verl/checkpoint_engine/base.py
3- import torch
43from abc import ABC , abstractmethod
5- from typing import Any , AsyncGenerator , Generator , TypedDict
4+ from typing import TYPE_CHECKING , Any , AsyncGenerator , Generator , TypedDict
5+
6+ if TYPE_CHECKING :
7+ import torch
68
79
810class TensorMeta (TypedDict ):
911 """Metadata for a tensor in the weight bucket."""
1012 name : str
11- shape : torch .Size
12- dtype : torch .dtype
13+ shape : 'torch.Size'
14+ dtype : 'torch.dtype'
1315 offset : int
1416
1517
@@ -99,7 +101,7 @@ def finalize(self):
99101 raise NotImplementedError
100102
101103 @abstractmethod
102- async def send_weights (self , weights : Generator [tuple [str , torch .Tensor ], None , None ]):
104+ async def send_weights (self , weights : Generator [tuple [str , 'torch.Tensor' ], None , None ]):
103105 """Send model weights to rollout workers.
104106
105107 This method streams weights in buckets to avoid memory issues with
@@ -112,7 +114,7 @@ async def send_weights(self, weights: Generator[tuple[str, torch.Tensor], None,
112114 raise NotImplementedError
113115
114116 @abstractmethod
115- async def receive_weights (self ) -> AsyncGenerator [tuple [str , torch .Tensor ], None ]:
117+ async def receive_weights (self ) -> AsyncGenerator [tuple [str , 'torch.Tensor' ], None ]:
116118 """Receive model weights from trainer.
117119
118120 This method receives weights in buckets and yields them as they
0 commit comments