Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .dev_scripts/ci_container_test.sh
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
install_twinkle_with_kernels() {
pip install ".[kernels]" -i https://mirrors.aliyun.com/pypi/simple/ || pip install ".[kernels]"
}

if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
# pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
git config --global --add safe.directory /twinkle
Expand Down Expand Up @@ -25,8 +29,9 @@ if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
pip install optimum

# test with install
pip install .
install_twinkle_with_kernels
else
install_twinkle_with_kernels
echo "Running case in release image, run case directly!"
fi
# remove torch_extensions folder to avoid ci hang.
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ transformers = [
"torch>=2.6.0,<3.0.0",
"torchvision",
]
kernels = ["kernels"]
megatron = ["megatron-core>=0.12.0", "transformer-engine[pytorch]"]
vllm = ["vllm>=0.11"]
ray = ["ray[serve]"]
Expand Down
12 changes: 10 additions & 2 deletions src/twinkle/kernel/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,16 @@ def impl(*args, **kwargs):
from kernels._versions import select_revision_or_version
from kernels.utils import get_kernel
assert repo_id is not None
resolved = select_revision_or_version(repo_id, revision, version)
kernel = get_kernel(repo_id, revision=resolved)
# kernels API changed across versions; use keyword args for modern API
# and fall back to repo_id-only for older variants.
try:
resolved = select_revision_or_version(repo_id, revision=revision, version=version)
except TypeError:
resolved = select_revision_or_version(repo_id)
try:
kernel = get_kernel(repo_id, revision=resolved)
except TypeError:
kernel = get_kernel(repo_id, resolved)
func = getattr(kernel, func_name, None)
if func is None:
raise AttributeError(f'Kernel repo {repo_id} does not export {func_name}.')
Expand Down
83 changes: 59 additions & 24 deletions src/twinkle/model/transformers/strategy/sequence_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def backward(ctx, *grad_output):
# Split grads back to local sequence chunk.
_grad = grad_output[0]
if sequence_parallel.world_size > 1 and sequence_parallel._sp_group is not None:
# Gather replicates the sequence dimension across SP ranks. Scale once here
# so downstream FSDP avg does not shrink this path by an extra SP factor.
_grad = _grad * sequence_parallel.world_size
_grad = sequence_parallel.split(_grad, dim=ctx.gather_idx, position_ids=ctx.position_ids).contiguous()
return _grad, None, None, None

Expand Down Expand Up @@ -785,7 +788,8 @@ def pad_and_split_inputs(self,
# - In next-token-aligned labels, this appears at labels[b-1]
boundary_starts = (real_position_ids == 0)
prev = torch.zeros_like(boundary_starts, dtype=torch.bool)
prev[..., 1:] = boundary_starts[..., :-1]
# Mask token b-1 when boundary starts at b.
prev[..., :-1] = boundary_starts[..., 1:]
labels = labels.clone()
labels[prev] = -100
# Also avoid any potential wrap-around supervision at the end of the concatenated stream.
Expand Down Expand Up @@ -867,6 +871,7 @@ class SequenceParallelConfig:
ulysses_size: Optional[int] = None
gather_logits: bool = True
loss_reduction: str = 'mean'
compensate_fsdp_avg: bool = False


def _get_ulysses_size(device_mesh, sp_config: Optional[Dict[str, Any]] = None) -> int:
Expand Down Expand Up @@ -969,34 +974,64 @@ def reduce_loss(self, loss: torch.Tensor, labels: Optional[torch.Tensor], ignore
return loss
if labels is None or sequence_parallel._sp_group is None:
return loss
# Compute full-sequence loss in forward, but keep backward local to this rank.
# Compute global loss via autograd-aware all-reduce.
reduction = str(self.sp_config.get('loss_reduction', 'mean')).lower()
if reduction == 'none':
raise ValueError("SequenceParallelStrategy.reduce_loss only supports reduction='sum' or 'mean'. "
'Please aggregate per-token losses before calling reduce_loss.')
num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
compensate_fsdp_avg = bool(self.sp_config.get('compensate_fsdp_avg', False))
compensate_factor = float(self.ulysses_size if compensate_fsdp_avg else 1.0)
sum_metric_scale = float(self.ulysses_size)

class _ReduceSequenceParallelLoss(torch.autograd.Function):
    """Global token-weighted mean of per-rank mean losses across the SP group.

    Forward: converts this rank's mean loss back to a sum (mean * valid tokens),
    all-reduces both the loss sums and the valid-token counts over the
    sequence-parallel group, and returns global_sum / global_tokens.
    Backward: routes the incoming gradient back to the local mean, scaled by
    this rank's share of valid tokens (and optionally by an SP compensation
    factor captured from the enclosing scope).

    NOTE(review): relies on closure variables from the enclosing ``reduce_loss``
    (``sequence_parallel``, ``compensate_factor``) — verify against that scope.
    """

    @staticmethod
    def forward(ctx, local_mean: torch.Tensor, num_valid_tokens: torch.Tensor) -> torch.Tensor:
        # Recover this rank's loss *sum* from its mean so sums can be
        # all-reduced and re-normalized by the global token count.
        local_tokens = num_valid_tokens.detach().clone()
        local_sum = local_mean * local_tokens
        if local_tokens.item() == 0:
            # 0 tokens => local_mean may be NaN/inf (0/0); sanitize so the
            # all-reduce below does not poison other ranks' sums.
            local_sum = torch.nan_to_num(local_sum)
        # Detached copies: the collectives are outside autograd on purpose;
        # the gradient path is defined explicitly in backward().
        global_sum = local_sum.detach().clone()
        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
        global_tokens = num_valid_tokens.detach().clone()
        dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
        ctx.save_for_backward(local_tokens, global_tokens)
        if global_tokens.item() == 0:
            # No valid tokens anywhere in the group: return the (sanitized)
            # local sum rather than dividing by zero.
            return local_sum
        return global_sum / global_tokens

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        local_tokens, global_tokens = ctx.saved_tensors
        if global_tokens.item() == 0:
            # Degenerate forward path: no meaningful gradient to propagate.
            return torch.zeros_like(grad_output), None
        # d(global_mean)/d(local_mean) = local_tokens / global_tokens.
        # compensate_factor (closure) optionally undoes an outer DP/FSDP
        # gradient average that would otherwise shrink this path by the
        # SP world size.
        grad_local_mean = grad_output * (local_tokens / global_tokens) * compensate_factor
        return grad_local_mean, None

class _ReduceSequenceParallelSum(torch.autograd.Function):
    """All-reduce a per-rank loss sum over the SP group, rescaled for logging.

    Forward returns the group-wide sum divided by ``sum_metric_scale`` (the SP
    world size, captured from the enclosing scope) so the reported value lines
    up with non-SP 'sum' semantics when an outer collector averages across
    ranks. Backward deliberately passes the gradient through unchanged so the
    forward-side rescaling affects only the logged metric, not training.

    NOTE(review): relies on closure variables from the enclosing ``reduce_loss``
    (``sequence_parallel``, ``sum_metric_scale``) — verify against that scope.
    """

    @staticmethod
    def forward(ctx, local_sum: torch.Tensor) -> torch.Tensor:
        # Stash the scale on ctx; the collective runs on a detached copy
        # because the gradient path is defined explicitly in backward().
        ctx.sum_metric_scale = sum_metric_scale
        global_sum = local_sum.detach().clone()
        dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
        # Keep logging/metric value aligned with non-SP sum semantics under
        # outer collect='mean' by removing one SP replication factor.
        return global_sum / ctx.sum_metric_scale

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Keep training gradient scale unchanged; forward-side scaling is for
        # logging/metric alignment under outer collect='mean'.
        return grad_output

if reduction == 'sum':
local_sum = loss
global_sum = local_sum.detach().clone()
dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
out = global_sum + (local_sum - local_sum.detach())
if sequence_parallel.world_size > 1:
out_metric = out.detach() / sequence_parallel.world_size
return out_metric + (out - out.detach())
return out
# Default to mean reduction.
local_sum = loss * num_valid_tokens
global_sum = local_sum.detach().clone()
dist.all_reduce(global_sum, group=sequence_parallel._sp_group)
global_tokens = num_valid_tokens.detach().clone()
dist.all_reduce(global_tokens, group=sequence_parallel._sp_group)
if global_tokens.item() == 0:
return loss
out = (global_sum + (local_sum - local_sum.detach())) / global_tokens
if sequence_parallel.world_size > 1:
out_metric = out.detach() / sequence_parallel.world_size
return out_metric + (out - out.detach())
return out
return _ReduceSequenceParallelSum.apply(loss)

# Default to mean reduction: `loss` is local mean.
num_valid_tokens = (labels != ignore_index).sum().to(loss.device)
return _ReduceSequenceParallelLoss.apply(loss, num_valid_tokens)

def wrap_model(self, model, optimizer=None):
self.initialize()
Expand Down
15 changes: 10 additions & 5 deletions src/twinkle/model/transformers/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,15 @@ def _ensure_sp_strategy(self) -> None:
return
from .strategy.sequence_parallel import SequenceParallelStrategy

sp_config = {}
# When data-parallel gradient averaging runs across SP shards (native FSDP or
# accelerate DDP/FSDP paths), compensate SP loss backward to keep gradient scale.
if isinstance(self.strategy, (NativeFSDPStrategy, AccelerateStrategy)) and self.device_mesh is not None:
if (self.device_mesh.ulysses_size or 1) > 1 and (self.device_mesh.data_world_size or 1) > 1:
sp_config['compensate_fsdp_avg'] = True
self.sp_strategy = SequenceParallelStrategy(
self.device_mesh,
{},
sp_config,
model=self.model,
tokenizer_id=self.tokenizer_id,
)
Expand Down Expand Up @@ -434,10 +440,9 @@ def calculate_loss(self, **kwargs):
optimizer_config = self.optimizer_group[adapter_name]
optimizer_config.num_tokens += counts.item()
if self.sp_strategy is not None and 'labels' in inputs:
if 'loss_reduction' not in self.sp_strategy.sp_config:
reduction = getattr(loss_instance, 'reduction', None)
if reduction is not None:
self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
reduction = getattr(loss_instance, 'reduction', None)
if reduction is not None:
self.sp_strategy.sp_config['loss_reduction'] = str(reduction)
loss_value = self.sp_strategy.reduce_loss(loss_value, inputs['labels'])
optimizer_config.loss_value += loss_value
outputs['loss'] = optimizer_config.loss_value
Expand Down
Loading