34 | 34 | from twinkle.template import Template |
35 | 35 | from .strategy import MegatronStrategy |
36 | 36 | from twinkle.utils import construct_class, exists |
37 | | -from .args import get_args, set_args, TwinkleMegatronArgs |
38 | | -from .model import get_megatron_model_meta, GPTBridge |
39 | 37 | from twinkle.patch import Patch, apply_patch |
40 | 38 | |
41 | 39 | |
@@ -173,6 +171,7 @@ def __init__( |
173 | 171 | **kwargs, |
174 | 172 | ): |
175 | 173 | requires('megatron_core') |
| 174 | + from .args import get_args, set_args, TwinkleMegatronArgs |
176 | 175 | os.environ['TOKENIZERS_PARALLELISM'] = 'true' |
177 | 176 | nn.Module.__init__(self) |
178 | 177 | from twinkle.patch.megatron_peft import MegatronPeft |
@@ -240,6 +239,7 @@ def _create_megatron_model( |
240 | 239 | load_weights: bool = True, |
241 | 240 | **kwargs, |
242 | 241 | ) -> List[nn.Module]: |
| 242 | + from .args import get_args |
243 | 243 | args = get_args() |
244 | 244 | self.initialize(**kwargs) |
245 | 245 | |
@@ -1002,6 +1002,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str |
1002 | 1002 | if isinstance(m, LoraLinear): |
1003 | 1003 | # just check |
1004 | 1004 | # TODO untested code |
| 1005 | + from .args import get_args |
1005 | 1006 | args = get_args() |
1006 | 1007 | from .tuners import LoraParallelLinear |
1007 | 1008 | assert args.is_multimodal and not isinstance(m, LoraParallelLinear) |
@@ -1114,6 +1115,7 @@ def initialize(self, **kwargs) -> None: |
1114 | 1115 | |
1115 | 1116 | from megatron.core import parallel_state |
1116 | 1117 | from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed |
| 1118 | + from .args import get_args |
1117 | 1119 | self._try_init_process_group() |
1118 | 1120 | args = get_args() |
1119 | 1121 | init_kwargs = { |
@@ -1142,8 +1144,10 @@ def initialize(self, **kwargs) -> None: |
1142 | 1144 | self._initialized = True |
1143 | 1145 | |
1144 | 1146 | @property |
1145 | | - def _bridge(self) -> GPTBridge: |
| 1147 | + def _bridge(self) -> 'GPTBridge': |
1146 | 1148 | if not hasattr(self, '_bridge_instance'): |
| 1149 | + from .args import get_args |
| 1150 | + from .model import get_megatron_model_meta |
1147 | 1151 | args = get_args() |
1148 | 1152 | megatron_model_meta = get_megatron_model_meta(args.hf_model_type) |
1149 | 1153 | assert megatron_model_meta is not None, f'Model: {args.hf_model_type} is not supported.' |
@@ -1181,6 +1185,7 @@ def send_weights( |
1181 | 1185 | # Trim any tensor whose dim-0 equals padded_vocab_size back to |
1182 | 1186 | # org_vocab_size — this is shape-based, not name-based, so it works |
1183 | 1187 | # regardless of the model architecture's naming convention. |
| 1188 | + from .args import get_args |
1184 | 1189 | args = get_args() |
1185 | 1190 | org_vocab_size = getattr(self.hf_config, 'vocab_size', args.padded_vocab_size) |
1186 | 1191 | _padded_vocab_size = args.padded_vocab_size |