diff --git a/src/twinkle/infra/__init__.py b/src/twinkle/infra/__init__.py
index 1e8f773f..ecbbca6a 100644
--- a/src/twinkle/infra/__init__.py
+++ b/src/twinkle/infra/__init__.py
@@ -442,7 +442,7 @@ def new_init(self, *args, **kwargs):
         # Pass an instance_id is recommended
         instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
         remote_group = kwargs.get('remote_group')
-        if remote_group is None:
+        if os.environ.get('WORKER_NAME') is None and remote_group is None:
             logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class '
                         'does not need remote execution.')
             # If cannot trust_remote_code, no callable and type can be used.
diff --git a/src/twinkle/server/tinker/common/compat_base.py b/src/twinkle/server/tinker/common/compat_base.py
index 160e303a..62e22ff6 100644
--- a/src/twinkle/server/tinker/common/compat_base.py
+++ b/src/twinkle/server/tinker/common/compat_base.py
@@ -117,13 +117,14 @@ def get_template(self, adapter_name: str) -> Template:
     def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]:
         """Convert raw logits to the expected output format with logprobs and elementwise_loss."""
         from twinkle.utils.torch_utils import selective_log_softmax
-
+        device = logits.device if logits is not None else logps.device
         results = []
-        for feature, logit in zip(inputs, logits):
+        if logits is None:
+            logits = [None] * len(inputs)
+        for idx, (feature, logit) in enumerate(zip(inputs, logits)):
             # Ensure 1D shape and correct device to avoid dimension mismatch and device errors
-            labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(
-                logit.device)  # shape (seq_len,)
-            weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(logit.device)  # shape (seq_len,)
+            labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device)  # shape (seq_len,)
+            weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device)  # shape (seq_len,)
 
             # Slice logits to match the sequence length of labels
             # Labels are assumed to be already shifted/aligned with logits
@@ -138,7 +139,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps:
                 # Calculate log probs for all labels
                 token_log_probs = selective_log_softmax(feature_logits, labels)
             else:
-                token_log_probs = logps[:seq_len, :]
+                token_log_probs = logps[idx, :seq_len]
 
             # elementwise_loss: positive NLL loss (0.0 where masked)
             elementwise_loss = -token_log_probs * weights
diff --git a/src/twinkle/server/tinker/common/megatron_model.py b/src/twinkle/server/tinker/common/megatron_model.py
index 2d54c950..ebd4df76 100644
--- a/src/twinkle/server/tinker/common/megatron_model.py
+++ b/src/twinkle/server/tinker/common/megatron_model.py
@@ -74,24 +74,22 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
         loss_kwargs = kwargs.copy()
         loss_kwargs.update(loss_values)
         # Megatron forward_backward returns loss directly
-        loss = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
-
-        # Get logits from outputs
-        optimizer_config = self.optimizer_group.get(adapter_name)
-        outputs = optimizer_config.outputs if optimizer_config else {}
+        outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
+        loss = outputs.get('loss', None)
         logits_list = outputs.get('logits', [])
-        logps = outputs.get('logprobs', [])
-
+        logps = outputs.get('logps', [])
         # When PP enabled, only logits from last stage are available
-        if not logits_list and not logps:
+        if logits_list is None and logps is None:
             return [None, None]
-        # Process logits to match transformers output format
-        if isinstance(logits_list, torch.Tensor):
-            logits = logits_list.detach()
-        else:
-            # Concatenate logits from multiple microbatches
-            logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
+        logits = None
+        if logits_list is not None:
+            # Process logits to match transformers output format
+            if isinstance(logits_list, torch.Tensor):
+                logits = logits_list.detach()
+            else:
+                # Concatenate logits from multiple microbatches
+                logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
         logps = logps.detach().cpu()
         results = self._get_forward_output(inputs, logits, logps)