Skip to content
2 changes: 1 addition & 1 deletion src/twinkle/infra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ def new_init(self, *args, **kwargs):
# Passing an instance_id is recommended
instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
remote_group = kwargs.get('remote_group')
if remote_group is None:
if os.environ.get('WORKER_NAME') is None and remote_group is None:
logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class '
'does not need remote execution.')
# If we cannot trust_remote_code, no callables or types can be used.
Expand Down
13 changes: 7 additions & 6 deletions src/twinkle/server/tinker/common/compat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,14 @@ def get_template(self, adapter_name: str) -> Template:
def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]:
"""Convert raw logits to the expected output format with logprobs and elementwise_loss."""
from twinkle.utils.torch_utils import selective_log_softmax

device = logits.device if logits is not None else logps.device
results = []
for feature, logit in zip(inputs, logits):
if logits is None:
logits = [None] * len(inputs)
for idx, (feature, logit) in enumerate(zip(inputs, logits)):
# Ensure 1D shape and correct device to avoid dimension mismatch and device errors
labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(
logit.device) # shape (seq_len,)
weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(logit.device) # shape (seq_len,)
labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) # shape (seq_len,)
weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) # shape (seq_len,)

# Slice logits to match the sequence length of labels
# Labels are assumed to be already shifted/aligned with logits
Expand All @@ -138,7 +139,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps:
# Calculate log probs for all labels
token_log_probs = selective_log_softmax(feature_logits, labels)
else:
token_log_probs = logps[:seq_len, :]
token_log_probs = logps[idx, :seq_len]

# elementwise_loss: positive NLL loss (0.0 where masked)
elementwise_loss = -token_log_probs * weights
Expand Down
26 changes: 12 additions & 14 deletions src/twinkle/server/tinker/common/megatron_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,24 +74,22 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
loss_kwargs = kwargs.copy()
loss_kwargs.update(loss_values)
# Megatron forward_backward returns loss directly
loss = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)

# Get logits from outputs
optimizer_config = self.optimizer_group.get(adapter_name)
outputs = optimizer_config.outputs if optimizer_config else {}
outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
loss = outputs.get('loss', None)
logits_list = outputs.get('logits', [])
logps = outputs.get('logprobs', [])

logps = outputs.get('logps', [])
# When PP is enabled, only logits from the last stage are available
if not logits_list and not logps:
if logits_list is None and logps is None:
return [None, None]

# Process logits to match transformers output format
if isinstance(logits_list, torch.Tensor):
logits = logits_list.detach()
else:
# Concatenate logits from multiple microbatches
logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
logits = None
if logits_list is not None:
# Process logits to match transformers output format
if isinstance(logits_list, torch.Tensor):
logits = logits_list.detach()
else:
# Concatenate logits from multiple microbatches
logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
logps = logps.detach().cpu()
results = self._get_forward_output(inputs, logits, logps)

Expand Down
Loading