Skip to content

Commit b941858

Browse files
committed
fix
1 parent 3e4c283 commit b941858

File tree

2 files changed: +9 −4 lines changed

src/twinkle/model/megatron/megatron.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
from twinkle.template import Template
3535
from .strategy import MegatronStrategy
3636
from twinkle.utils import construct_class, exists
37-
from .args import get_args, set_args, TwinkleMegatronArgs
38-
from .model import get_megatron_model_meta, GPTBridge
3937
from twinkle.patch import Patch, apply_patch
4038

4139

@@ -173,6 +171,7 @@ def __init__(
173171
**kwargs,
174172
):
175173
requires('megatron_core')
174+
from .args import get_args, set_args, TwinkleMegatronArgs
176175
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
177176
nn.Module.__init__(self)
178177
from twinkle.patch.megatron_peft import MegatronPeft
@@ -240,6 +239,7 @@ def _create_megatron_model(
240239
load_weights: bool = True,
241240
**kwargs,
242241
) -> List[nn.Module]:
242+
from .args import get_args
243243
args = get_args()
244244
self.initialize(**kwargs)
245245

@@ -1002,6 +1002,7 @@ def _patch_adapter(self, adapter_name: str, config_or_dir: Union[PeftConfig, str
10021002
if isinstance(m, LoraLinear):
10031003
# just check
10041004
# TODO untested code
1005+
from .args import get_args
10051006
args = get_args()
10061007
from .tuners import LoraParallelLinear
10071008
assert args.is_multimodal and not isinstance(m, LoraParallelLinear)
@@ -1114,6 +1115,7 @@ def initialize(self, **kwargs) -> None:
11141115

11151116
from megatron.core import parallel_state
11161117
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
1118+
from .args import get_args
11171119
self._try_init_process_group()
11181120
args = get_args()
11191121
init_kwargs = {
@@ -1142,8 +1144,10 @@ def initialize(self, **kwargs) -> None:
11421144
self._initialized = True
11431145

11441146
@property
1145-
def _bridge(self) -> GPTBridge:
1147+
def _bridge(self) -> 'GPTBridge':
11461148
if not hasattr(self, '_bridge_instance'):
1149+
from .args import get_args
1150+
from .model import get_megatron_model_meta
11471151
args = get_args()
11481152
megatron_model_meta = get_megatron_model_meta(args.hf_model_type)
11491153
assert megatron_model_meta is not None, f'Model: {args.hf_model_type} is not supported.'
@@ -1181,6 +1185,7 @@ def send_weights(
11811185
# Trim any tensor whose dim-0 equals padded_vocab_size back to
11821186
# org_vocab_size — this is shape-based, not name-based, so it works
11831187
# regardless of the model architecture's naming convention.
1188+
from .args import get_args
11841189
args = get_args()
11851190
org_vocab_size = getattr(self.hf_config, 'vocab_size', args.padded_vocab_size)
11861191
_padded_vocab_size = args.padded_vocab_size

src/twinkle/model/megatron/multi_lora_megatron.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from twinkle.loss import Loss
1818
from twinkle.metric import Metric
1919
from twinkle.processor import InputProcessor
20-
from .args import TwinkleMegatronArgs, set_args
2120
from .megatron import MegatronModel
2221
from .strategy import MegatronStrategy
2322
from ..multi_lora import MultiLora
@@ -42,6 +41,7 @@ def __init__(self,
4241
requires('megatron_core')
4342
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
4443
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
44+
from .args import TwinkleMegatronArgs, set_args
4545
nn.Module.__init__(self)
4646
from twinkle.patch.megatron_peft import MegatronPeft
4747

0 commit comments

Comments (0)