Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,26 @@ To customize the fine-tuning recipe, modify the base configuration in [finetune_

By default, the tutorial uses the `proenfo_gfc12` dataset from the [autogluon/fev_datasets](https://huggingface.co/datasets/autogluon/fev_datasets) collection.

#### Causal Robust Fine-Tuning (Experimental)

Toto now supports an optional causal robustness regularizer inspired by recent causal-transformer stability work.
The regularizer adds a log-barrier-style penalty during training based on attention-weighted variance in time-wise attention layers.

Enable it in the `model` section of your configuration:

```yaml
model:
causal_robust_lambda: 0.02
causal_robust_alpha: 0.0001
causal_robust_eps: 0.000001
causal_robust_max_penalty: 20.0
```

- `causal_robust_lambda`: weight of the robustness penalty in total training loss.
- `causal_robust_alpha`: coupling term in `margin = 1 - alpha * variance_trace`.
- `causal_robust_eps`: numerical floor for the log-barrier margin.
- `causal_robust_max_penalty`: cap to avoid exploding penalties.

#### Custom Datasets

There are two ways to use custom datasets for fine-tuning:
Expand Down
88 changes: 88 additions & 0 deletions toto/model/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Copyright 2025 Datadog, Inc.

import logging
import math
import warnings
from enum import Enum
from typing import TYPE_CHECKING, Optional, Union
Expand Down Expand Up @@ -73,6 +74,84 @@ def __init__(
if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE):
raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.")

# Optional stability regularization inspired by causal-transformer margin/barrier ideas.
# Disabled by default to preserve legacy behavior and performance.
self.enable_causal_robustness = False
self.causal_robustness_alpha = 1e-4
self.causal_robustness_eps = 1e-6
self.causal_robustness_max_penalty = 20.0
self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None

def configure_causal_robustness(
    self,
    *,
    enabled: bool,
    alpha: float = 1e-4,
    eps: float = 1e-6,
    max_penalty: float = 20.0,
) -> None:
    """
    Turn the optional causal-robustness regularizer on or off for this attention layer.

    When enabled, training computes a log-barrier penalty from the
    attention-weighted value variance of time-wise attention:
        margin  = 1 - alpha * Tr(Var_attn[v])
        penalty = -log(clamp(margin, min=eps))
    with each per-token contribution capped at ``max_penalty``.
    """
    self.enable_causal_robustness = enabled
    # Coerce to plain floats so config-sourced values (e.g. YAML ints or
    # strings) cannot leak non-float types into the penalty math.
    self.causal_robustness_alpha, self.causal_robustness_eps, self.causal_robustness_max_penalty = (
        float(alpha),
        float(eps),
        float(max_penalty),
    )

def _to_stats_layout(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, ...]:
"""
Convert Q/K/V to a shared layout (N, H, T, D) for diagnostic calculations.
"""
if self.use_memory_efficient_attention:
# (N, T, H, D) -> (N, H, T, D)
return q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3)
# Already (N, H, T, D)
return q, k, v

def _compute_causal_robustness_penalty(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
seq_pos_offset: int,
) -> torch.Tensor:
"""
Compute a differentiable stability penalty from time-wise attention statistics.
"""
q_stats, k_stats, v_stats = self._to_stats_layout(q, k, v)

# Use float32 for numerical stability while keeping gradients.
q_stats = q_stats.float()
k_stats = k_stats.float()
v_stats = v_stats.float()

scale = 1.0 / math.sqrt(self.head_dim)
scores = torch.einsum("nhtd,nhsd->nhts", q_stats, k_stats) * scale

# Match the current attention behavior: strict causal masking on first pass,
# no additional within-chunk masking when decoding with a cache offset.
if seq_pos_offset == 0:
q_len = scores.shape[-2]
k_len = scores.shape[-1]
causal_mask = torch.tril(torch.ones((q_len, k_len), device=scores.device, dtype=torch.bool))
scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min)

attn_weights = torch.softmax(scores, dim=-1)

attn_mean = torch.einsum("nhts,nhsd->nhtd", attn_weights, v_stats)
attn_second_moment = torch.einsum("nhts,nhsd->nhtd", attn_weights, v_stats * v_stats)
attn_var_trace = (attn_second_moment - attn_mean * attn_mean).clamp_min(0.0).sum(dim=-1)

margins = 1.0 - self.causal_robustness_alpha * attn_var_trace
barrier = -torch.log(margins.clamp_min(self.causal_robustness_eps))
barrier = barrier.clamp_max(self.causal_robustness_max_penalty)
return barrier.mean().to(dtype=q.dtype)

def rearrange_inputs(
self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"]
) -> Float[torch.Tensor, "... embed_dim"]:
Expand Down Expand Up @@ -198,6 +277,7 @@ def forward(
] = None,
kv_cache: Optional["KVCache"] = None,
) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
self.latest_causal_robustness_penalty = None
batch_size, variate, seq_len, _ = inputs.shape
dropout = self.dropout if self.training else 0.0

Expand All @@ -206,6 +286,14 @@ def forward(

q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx)

if self.enable_causal_robustness and self.training and self.attention_axis == AttentionAxis.TIME:
self.latest_causal_robustness_penalty = self._compute_causal_robustness_penalty(
q=q,
k=k,
v=v,
seq_pos_offset=seq_pos_offset,
)

output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate)

output = self.rearrange_output(output, batch_size, variate, seq_len)
Expand Down
25 changes: 25 additions & 0 deletions toto/model/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,14 @@ def __init__(
use_memory_efficient_attention: bool = True,
stabilize_with_global: bool = True,
scale_factor_exponent: float = 10.0,
enable_causal_robustness: bool = False,
causal_robustness_alpha: float = 1e-4,
causal_robustness_eps: float = 1e-6,
causal_robustness_max_penalty: float = 20.0,
):
super().__init__()
self.embed_dim = embed_dim
self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None
# Attributes for variate-label fusion (initialized when enable_variate_labels is called)
self.fusion: Optional[Fusion] = None
self.num_prepended_tokens: int = 0
Expand Down Expand Up @@ -126,6 +131,10 @@ def __init__(
spacewise_every_n_layers=spacewise_every_n_layers,
spacewise_first=spacewise_first,
use_memory_efficient_attention=self.use_memory_efficient_attention,
enable_causal_robustness=enable_causal_robustness,
causal_robustness_alpha=causal_robustness_alpha,
causal_robustness_eps=causal_robustness_eps,
causal_robustness_max_penalty=causal_robustness_max_penalty,
fusion=self.fusion,
)
self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)
Expand Down Expand Up @@ -216,6 +225,7 @@ def backbone(
transformed: Float[torch.Tensor, "batch variates seq_len embed_dim"] = self.transformer( # type: ignore[assignment]
embeddings, reduced_id_mask, kv_cache, variate_label_embeds=variate_label_embeds
)
self.latest_causal_robustness_penalty = self.transformer.latest_causal_robustness_penalty
# Crop out the prepended tokens before unembedding
added_tokens = transformed.shape[2] - original_seq_len
if added_tokens > 0:
Expand Down Expand Up @@ -298,3 +308,18 @@ def build_variate_label_embeds(
exog_mask[:, -num_exogenous_variables:] = True
# Select per-variate label: target label for genuine targets, exogenous label for EV channels
return torch.where(exog_mask, exogenous_variate_label, target_variate_label) # (B, V, 1, D)

def configure_causal_robustness(
    self,
    *,
    enabled: bool,
    alpha: float = 1e-4,
    eps: float = 1e-6,
    max_penalty: float = 20.0,
) -> None:
    """Propagate causal-robustness regularizer settings to the transformer stack."""
    # The backbone holds no regularizer state of its own; the transformer
    # fans the settings out to its attention layers.
    settings = dict(enabled=enabled, alpha=alpha, eps=eps, max_penalty=max_penalty)
    self.transformer.configure_causal_robustness(**settings)
34 changes: 34 additions & 0 deletions toto/model/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def __init__(
min_lr: float = 1e-5,
betas: tuple[float, float] = (0.9, 0.999),
weight_decay: float = 0.01,
causal_robust_lambda: float = 0.0,
causal_robust_alpha: float = 1e-4,
causal_robust_eps: float = 1e-6,
causal_robust_max_penalty: float = 20.0,
# backbone construction
pretrained_backbone: TotoBackbone | None = None,
add_exogenous_features: bool = False,
Expand All @@ -109,6 +113,15 @@ def __init__(
self.stable_steps = stable_steps
self.decay_steps = decay_steps
self.val_prediction_len = val_prediction_len
self.causal_robust_lambda = causal_robust_lambda

if self.causal_robust_lambda > 0:
self.model.configure_causal_robustness(
enabled=True,
alpha=causal_robust_alpha,
eps=causal_robust_eps,
max_penalty=causal_robust_max_penalty,
)

# Loss setup (currently hard-wired to CombinedLoss)
self.combined_loss = CombinedLoss()
Expand Down Expand Up @@ -277,6 +290,27 @@ def _train_or_val_step(self, batch: CausalMaskedTimeseries, is_train: bool) -> t
mean_total = total_sum / (valid_count + eps)
prefix = "train" if is_train else "val"

if self.causal_robust_lambda > 0 and self.model.latest_causal_robustness_penalty is not None:
causal_penalty = self.model.latest_causal_robustness_penalty
weighted_causal_penalty = self.causal_robust_lambda * causal_penalty
mean_total = mean_total + weighted_causal_penalty
self.log(
f"{prefix}_causal_robustness_penalty",
causal_penalty.detach(),
prog_bar=False,
on_step=is_train,
on_epoch=True,
batch_size=batch.series.shape[0],
)
self.log(
f"{prefix}_causal_robustness_penalty_weighted",
weighted_causal_penalty.detach(),
prog_bar=False,
on_step=is_train,
on_epoch=True,
batch_size=batch.series.shape[0],
)

# Log loss for lightning progress bar and metrics
self.log(
f"{prefix}_loss",
Expand Down
44 changes: 44 additions & 0 deletions toto/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def __init__(
attention_axis: AttentionAxis = AttentionAxis.TIME,
RMS_norm: bool = True,
use_memory_efficient_attention: bool = True,
enable_causal_robustness: bool = False,
causal_robustness_alpha: float = 1e-4,
causal_robustness_eps: float = 1e-6,
causal_robustness_max_penalty: float = 20.0,
):
super().__init__()
self.embed_dim = embed_dim
Expand Down Expand Up @@ -105,6 +109,13 @@ def __init__(
else:
raise ValueError("Invalid attention axis")

self.attention.configure_causal_robustness(
enabled=enable_causal_robustness,
alpha=causal_robustness_alpha,
eps=causal_robustness_eps,
max_penalty=causal_robustness_max_penalty,
)

if XFORMERS_SWIGLU_AVAILABLE:
self.mlp = torch.nn.Sequential(
SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim),
Expand Down Expand Up @@ -176,6 +187,10 @@ def __init__(
spacewise_every_n_layers: int,
spacewise_first: bool,
use_memory_efficient_attention: bool = True,
enable_causal_robustness: bool = False,
causal_robustness_alpha: float = 1e-4,
causal_robustness_eps: float = 1e-6,
causal_robustness_max_penalty: float = 20.0,
*,
fusion: Optional[Fusion] = None,
):
Expand All @@ -193,6 +208,7 @@ def __init__(

self.use_memory_efficient_attention = use_memory_efficient_attention
self.fusion = fusion
self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None

self.layers = torch.nn.ModuleList(
[
Expand All @@ -204,11 +220,31 @@ def __init__(
rotary_emb=self.rotary_emb,
attention_axis=attention_axes[i],
use_memory_efficient_attention=self.use_memory_efficient_attention,
enable_causal_robustness=enable_causal_robustness,
causal_robustness_alpha=causal_robustness_alpha,
causal_robustness_eps=causal_robustness_eps,
causal_robustness_max_penalty=causal_robustness_max_penalty,
)
for i in range(num_layers)
]
)

def configure_causal_robustness(
    self,
    *,
    enabled: bool,
    alpha: float = 1e-4,
    eps: float = 1e-6,
    max_penalty: float = 20.0,
) -> None:
    """Apply causal-robustness settings uniformly to every layer's attention module."""
    # All layers share one configuration so the stack-level penalty (mean over
    # contributing layers) averages comparable terms.
    for block in self.layers:
        block.attention.configure_causal_robustness(
            enabled=enabled,
            alpha=alpha,
            eps=eps,
            max_penalty=max_penalty,
        )

def _get_mask(
self,
num_heads: int,
Expand Down Expand Up @@ -341,11 +377,19 @@ def forward(
id_mask=id_mask,
)

causal_robustness_terms = []
for layer_idx, layer in enumerate(self.layers):
inputs = layer(
layer_idx,
inputs,
(timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask),
kv_cache,
)
if layer.attention.latest_causal_robustness_penalty is not None:
causal_robustness_terms.append(layer.attention.latest_causal_robustness_penalty)

if len(causal_robustness_terms) > 0:
self.latest_causal_robustness_penalty = torch.stack(causal_robustness_terms).mean()
else:
self.latest_causal_robustness_penalty = None
return inputs
6 changes: 5 additions & 1 deletion toto/scripts/configs/finetune_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ model:
warmup_steps: 1000 # number of linear warmup steps (change for different training schedules)
stable_steps: 200 # number of stable learning rate steps (change for different training schedules)
decay_steps: 200 # number of exponential decay steps (change for different training schedules)
causal_robust_lambda: 0.0 # set >0 to enable causal robustness barrier regularization
causal_robust_alpha: 0.0001 # margin coupling coefficient in margin = 1 - alpha * attn_var_trace
causal_robust_eps: 0.000001 # numerical floor for the log-barrier margin
causal_robust_max_penalty: 20.0 # cap per-token barrier contribution for stability

data:
context_factor: 8 # context length = patch_size * context_factor change for different training context lengths
Expand Down Expand Up @@ -36,4 +40,4 @@ checkpoint:

logging:
save_dir: lightning_logs
name: toto_finetuning # name of the experiment, will be used to save the logs
name: toto_finetuning # name of the experiment, will be used to save the logs
4 changes: 4 additions & 0 deletions toto/scripts/finetune_toto.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ def init_lightning(config: Dict[str, Any]) -> Tuple[TotoForFinetuning, int]:
warmup_steps=int(mcfg.get("warmup_steps", 200)),
lr=float(mcfg.get("lr", 1e-4)),
min_lr=float(mcfg.get("min_lr", 1e-5)),
causal_robust_lambda=float(mcfg.get("causal_robust_lambda", 0.0)),
causal_robust_alpha=float(mcfg.get("causal_robust_alpha", 1e-4)),
causal_robust_eps=float(mcfg.get("causal_robust_eps", 1e-6)),
causal_robust_max_penalty=float(mcfg.get("causal_robust_max_penalty", 20.0)),
add_exogenous_features=bool(dcfg.get("add_exogenous_features", False)),
)

Expand Down
Loading