
Commit 2086e87

fix loss computation bug
1 parent f9a8f78 commit 2086e87

File tree

4 files changed: +29 -32 lines changed

cookbook/legacy/sft/single_controller_sp_packing.py

Lines changed: 3 additions & 1 deletion
@@ -8,6 +8,7 @@
 from twinkle.dataset import PackingDataset, DatasetMeta
 from twinkle.model import TransformersModel
 from twinkle.preprocessor import SelfCognitionProcessor
+from twinkle.processor import InputProcessor
 
 logger = get_logger()
 MODEL_ID = 'ms://Qwen/Qwen2.5-7B-Instruct'
@@ -82,9 +83,10 @@ def train():
         strategy="native_fsdp",
         remote_group="default",
     )
-
     lora_config = LoraConfig(target_modules="all-linear")
     model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1)
+    model.set_processor(InputProcessor, padding_free=True, adapter_name="default")
+    model.set_loss("CrossEntropyLoss", reduction="mean", adapter_name="default")
     model.set_optimizer("AdamW", lr=1e-4, adapter_name="default")
 
     loss_metric = 99.0
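
For context on the new model.set_loss("CrossEntropyLoss", reduction="mean", ...) call: for a token-level cross-entropy, mean reduction conventionally averages over valid (non-ignored) label positions, not over all positions. A minimal PyTorch sketch of that convention (the shapes, vocabulary size, and -100 ignore index are illustrative assumptions, not twinkle internals):

import torch
import torch.nn.functional as F

# Hypothetical shapes: 8 flattened token positions, vocabulary of 32.
logits = torch.randn(8, 32)
labels = torch.randint(0, 32, (8,))
labels[:3] = -100  # mark the first three positions as padding

# With reduction='mean', F.cross_entropy sums the per-token losses and
# divides by the count of labels != ignore_index (here 5), not by 8.
loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction='mean')
print(loss.item())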

cookbook/legacy/single_controller_sp.py

Lines changed: 14 additions & 29 deletions
@@ -46,27 +46,13 @@ def create_dataset(data_slice=None):
         "Template",
         model_id=MODEL_ID,
         truncation_strategy="left",
-        max_length=64,
+        max_length=256,
     )
     dataset.map(SelfCognitionProcessor("twinkle模型", "twinkle团队"))
     dataset.encode(batched=True)
     return dataset
 
 
-def eval(model: TransformersModel):
-    dataloader = DataLoader(
-        dataset=partial(create_dataset, data_slice=range(20)),
-        batch_size=4,
-        drop_last=True,
-        device_mesh=device_mesh,
-        remote_group="default",
-    )
-    for step, batch in enumerate(dataloader):
-        model.forward_only(inputs=batch, adapter_name="default")
-        model.calculate_loss(adapter_name="default")
-        metrics = model.calculate_metric(is_training=False, adapter_name="default")
-    return metrics()
-
 
 def train():
     dataloader = DataLoader(
@@ -87,21 +73,20 @@ def train():
     model.add_adapter_to_model("default", lora_config, gradient_accumulation_steps=1)
     model.set_optimizer("AdamW", lr=1e-4, adapter_name="default")
 
-    loss_metric = 99.0
     for step, batch in enumerate(dataloader):
-        if isinstance(batch, list) and len(batch) == 0:
-            continue
-        output = model.forward_backward(inputs=batch, adapter_name="default")
-        loss_value = output() if callable(output) else output
-        logger.info(f"step {step}, loss: {loss_value}")
-        model.clip_grad_and_step(adapter_name="default")
-        if step % 50 == 0 and step > 0:
-            metrics = eval(model)
-            logger.info(f"Current is step {step} of {len(dataloader)}, metric: {metrics}")
-            metrics["step"] = step
-            if loss_metric > metrics["loss"]:
-                model.save(f"checkpoint-{step}")
-                loss_metric = metrics["loss"]
+        model.forward_backward(inputs=batch, adapter_name='default')
+        model.clip_grad_and_step(adapter_name='default')
+        if step % 1 == 0:
+            metric = model.calculate_metric(is_training=True, adapter_name='default')
+            _metrics = {}
+            for key, value in metric.items():
+                try:
+                    value = float(value)
+                    _metrics[key] = value
+                except:
+                    pass
+            logger.info(f'Current is step {step} of {len(dataloader)}, metric: {metric}')
+        model.save(f'last-checkpoint', interval=1)
 
 
 if __name__ == "__main__":

src/twinkle/model/transformers/strategy/sequence_parallel.py

Lines changed: 11 additions & 2 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import math
+import os
 from functools import partial
 from types import SimpleNamespace
 from typing import Any, Dict, Optional, Tuple, Union
@@ -1004,7 +1005,11 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
             local_sum = loss
             global_sum = local_sum.detach().clone()
             dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
-            return global_sum + (local_sum - local_sum.detach())
+            out = global_sum + (local_sum - local_sum.detach())
+            if sequence_parallel.world_size > 1:
+                out_metric = out.detach() / sequence_parallel.world_size
+                return out_metric + (out - out.detach())
+            return out
         # Default to mean reduction.
         local_sum = loss * num_valid_tokens
         global_sum = local_sum.detach().clone()
@@ -1013,7 +1018,11 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
         dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
         if global_tokens.item() == 0:
             return loss
-        return (global_sum + (local_sum - local_sum.detach())) / global_tokens
+        out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
+        if sequence_parallel.world_size > 1:
+            out_metric = out.detach() / sequence_parallel.world_size
+            return out_metric + (out - out.detach())
+        return out
 
     def wrap_model(self, model, optimizer=None):
         self.initialize()
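
The reduce_loss change above carries the actual fix. Both branches use the straight-through pattern global + (local - local.detach()): the forward value equals the all-reduced global loss, while gradients flow only through the rank-local term. The commit additionally divides the detached (reported) value by sequence_parallel.world_size, scaling the logged loss down by the SP group size without touching the gradient path. A single-process sketch of the pattern (the fake all-reduce and the world size of 4 are assumptions for illustration):

import torch

WORLD_SIZE = 4  # assumed SP group size for this sketch

def fake_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for dist.all_reduce over the SP group: pretend every
    # rank contributed the same local sum.
    return t * WORLD_SIZE

x = torch.tensor(2.0, requires_grad=True)  # rank-local loss
local_sum = x
global_sum = fake_all_reduce(local_sum.detach().clone())  # no grad through the reduce

# Forward value is the global sum (8.0); backward flows only through local_sum.
out = global_sum + (local_sum - local_sum.detach())

# The fix: report the value scaled by 1/world_size, again via the detach
# trick so the gradient path is unchanged.
out_metric = out.detach() / WORLD_SIZE
fixed = out_metric + (out - out.detach())

fixed.backward()
print(fixed.item(), x.grad.item())  # 2.0 (not 8.0), gradient 1.0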

src/twinkle/model/transformers/transformers.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) ModelScope Contributors. All rights reserved.
 import contextlib
+import os
 import json
 import os
 import re
