Skip to content

Commit cdb8342

Browse files
Fix logps (#86)
1 parent 7178b06 commit cdb8342

File tree

3 files changed

+20
-21
lines changed

3 files changed

+20
-21
lines changed

src/twinkle/infra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def new_init(self, *args, **kwargs):
442442
# Passing an instance_id is recommended
443443
instance_id = kwargs.pop('instance_id', '') + f'{caller_file}_{caller_line}'
444444
remote_group = kwargs.get('remote_group')
445-
if remote_group is None:
445+
if os.environ.get('WORKER_NAME') is None and remote_group is None:
446446
logger.info(f'⚠️ Using local initialization of class: {cls}, please make sure the class '
447447
'does not need remote execution.')
448448
# Without trust_remote_code, no callables or types can be used.

src/twinkle/server/tinker/common/compat_base.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,14 @@ def get_template(self, adapter_name: str) -> Template:
117117
def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps: torch.Tensor) -> List[dict]:
118118
"""Convert raw logits to the expected output format with logprobs and elementwise_loss."""
119119
from twinkle.utils.torch_utils import selective_log_softmax
120-
120+
device = logits.device if logits is not None else logps.device
121121
results = []
122-
for feature, logit in zip(inputs, logits):
122+
if logits is None:
123+
logits = [None] * len(inputs)
124+
for idx, (feature, logit) in enumerate(zip(inputs, logits)):
123125
# Ensure 1D shape and correct device to avoid dimension mismatch and device errors
124-
labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(
125-
logit.device) # shape (seq_len,)
126-
weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(logit.device) # shape (seq_len,)
126+
labels = feature.loss_fn_inputs['target_tokens'].to_torch().long().view(-1).to(device) # shape (seq_len,)
127+
weights = feature.loss_fn_inputs['weights'].to_torch().view(-1).to(device) # shape (seq_len,)
127128

128129
# Slice logits to match the sequence length of labels
129130
# Labels are assumed to be already shifted/aligned with logits
@@ -138,7 +139,7 @@ def _get_forward_output(inputs: List[types.Datum], logits: torch.Tensor, logps:
138139
# Calculate log probs for all labels
139140
token_log_probs = selective_log_softmax(feature_logits, labels)
140141
else:
141-
token_log_probs = logps[:seq_len, :]
142+
token_log_probs = logps[idx, :seq_len]
142143

143144
# elementwise_loss: positive NLL loss (0.0 where masked)
144145
elementwise_loss = -token_log_probs * weights

src/twinkle/server/tinker/common/megatron_model.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,22 @@ def forward_backward(self, *, inputs: List[types.Datum], adapter_name: str, loss
7474
loss_kwargs = kwargs.copy()
7575
loss_kwargs.update(loss_values)
7676
# Megatron forward_backward returns loss directly
77-
loss = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
78-
79-
# Get logits from outputs
80-
optimizer_config = self.optimizer_group.get(adapter_name)
81-
outputs = optimizer_config.outputs if optimizer_config else {}
77+
outputs = super().forward_backward(inputs=input_features, adapter_name=adapter_name, **loss_kwargs)
78+
loss = outputs.get('loss', None)
8279
logits_list = outputs.get('logits', [])
83-
logps = outputs.get('logprobs', [])
84-
80+
logps = outputs.get('logps', [])
8581
# When PP is enabled, only logits from the last stage are available
86-
if not logits_list and not logps:
82+
if logits_list is None and logps is None:
8783
return [None, None]
8884

89-
# Process logits to match transformers output format
90-
if isinstance(logits_list, torch.Tensor):
91-
logits = logits_list.detach()
92-
else:
93-
# Concatenate logits from multiple microbatches
94-
logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
85+
logits = None
86+
if logits_list is not None:
87+
# Process logits to match transformers output format
88+
if isinstance(logits_list, torch.Tensor):
89+
logits = logits_list.detach()
90+
else:
91+
# Concatenate logits from multiple microbatches
92+
logits = torch.cat([logit.detach() for logit in logits_list], dim=0)
9593
logps = logps.detach().cpu()
9694
results = self._get_forward_output(inputs, logits, logps)
9795

0 commit comments

Comments
 (0)