From e8522f2d83b2512ee37b169bc42aeeec3cedd2d5 Mon Sep 17 00:00:00 2001 From: Dhyey Mavani <82772894+DhyeyMavani2003@users.noreply.github.com> Date: Tue, 17 Mar 2026 18:52:00 -0700 Subject: [PATCH] Add experimental causal-robust Toto finetuning path --- README.md | 20 ++++ toto/model/attention.py | 88 ++++++++++++++++++ toto/model/backbone.py | 25 +++++ toto/model/lightning_module.py | 34 +++++++ toto/model/transformer.py | 44 +++++++++ toto/scripts/configs/finetune_config.yaml | 6 +- toto/scripts/finetune_toto.py | 4 + toto/test/model/causal_robustness_test.py | 106 ++++++++++++++++++++++ 8 files changed, 326 insertions(+), 1 deletion(-) create mode 100644 toto/test/model/causal_robustness_test.py diff --git a/README.md b/README.md index 530b194..29b9798 100644 --- a/README.md +++ b/README.md @@ -233,6 +233,26 @@ To customize the fine-tuning recipe, modify the base configuration in [finetune_ By default, the tutorial uses the `proenfo_gfc12` dataset from the [autogluon/fev_datasets](https://huggingface.co/datasets/autogluon/fev_datasets) collection. +#### Causal Robust Fine-Tuning (Experimental) + +Toto now supports an optional causal robustness regularizer inspired by recent causal-transformer stability work. +The regularizer adds a log-barrier-style penalty during training based on attention-weighted variance in time-wise attention layers. + +Enable it in `model` config: + +```yaml +model: + causal_robust_lambda: 0.02 + causal_robust_alpha: 0.0001 + causal_robust_eps: 0.000001 + causal_robust_max_penalty: 20.0 +``` + +- `causal_robust_lambda`: weight of the robustness penalty in total training loss. +- `causal_robust_alpha`: coupling term in `margin = 1 - alpha * variance_trace`. +- `causal_robust_eps`: numerical floor for the log-barrier margin. +- `causal_robust_max_penalty`: cap to avoid exploding penalties. 
+ #### Custom Datasets There are two ways to use custom datasets for fine-tuning: diff --git a/toto/model/attention.py b/toto/model/attention.py index a999612..ba33000 100644 --- a/toto/model/attention.py +++ b/toto/model/attention.py @@ -4,6 +4,7 @@ # Copyright 2025 Datadog, Inc. import logging +import math import warnings from enum import Enum from typing import TYPE_CHECKING, Optional, Union @@ -73,6 +74,84 @@ def __init__( if not hasattr(self, "attention_axis") or self.attention_axis not in (AttentionAxis.TIME, AttentionAxis.SPACE): raise ValueError("Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE.") + # Optional stability regularization inspired by causal-transformer margin/barrier ideas. + # Disabled by default to preserve legacy behavior and performance. + self.enable_causal_robustness = False + self.causal_robustness_alpha = 1e-4 + self.causal_robustness_eps = 1e-6 + self.causal_robustness_max_penalty = 20.0 + self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None + + def configure_causal_robustness( + self, + *, + enabled: bool, + alpha: float = 1e-4, + eps: float = 1e-6, + max_penalty: float = 20.0, + ) -> None: + """ + Configure optional causal robustness regularization on this attention layer. + + The regularizer computes a log-barrier-style penalty based on attention-weighted + value variance in time-wise attention: + margin = 1 - alpha * Tr(Var_attn[v]) + penalty = -log(clamp(margin, min=eps)) + """ + self.enable_causal_robustness = enabled + self.causal_robustness_alpha = float(alpha) + self.causal_robustness_eps = float(eps) + self.causal_robustness_max_penalty = float(max_penalty) + + def _to_stats_layout(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, ...]: + """ + Convert Q/K/V to a shared layout (N, H, T, D) for diagnostic calculations. 
+ """ + if self.use_memory_efficient_attention: + # (N, T, H, D) -> (N, H, T, D) + return q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3), v.permute(0, 2, 1, 3) + # Already (N, H, T, D) + return q, k, v + + def _compute_causal_robustness_penalty( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + seq_pos_offset: int, + ) -> torch.Tensor: + """ + Compute a differentiable stability penalty from time-wise attention statistics. + """ + q_stats, k_stats, v_stats = self._to_stats_layout(q, k, v) + + # Use float32 for numerical stability while keeping gradients. + q_stats = q_stats.float() + k_stats = k_stats.float() + v_stats = v_stats.float() + + scale = 1.0 / math.sqrt(self.head_dim) + scores = torch.einsum("nhtd,nhsd->nhts", q_stats, k_stats) * scale + + # Match the current attention behavior: strict causal masking on first pass, + # no additional within-chunk masking when decoding with a cache offset. + if seq_pos_offset == 0: + q_len = scores.shape[-2] + k_len = scores.shape[-1] + causal_mask = torch.tril(torch.ones((q_len, k_len), device=scores.device, dtype=torch.bool)) + scores = scores.masked_fill(~causal_mask, torch.finfo(scores.dtype).min) + + attn_weights = torch.softmax(scores, dim=-1) + + attn_mean = torch.einsum("nhts,nhsd->nhtd", attn_weights, v_stats) + attn_second_moment = torch.einsum("nhts,nhsd->nhtd", attn_weights, v_stats * v_stats) + attn_var_trace = (attn_second_moment - attn_mean * attn_mean).clamp_min(0.0).sum(dim=-1) + + margins = 1.0 - self.causal_robustness_alpha * attn_var_trace + barrier = -torch.log(margins.clamp_min(self.causal_robustness_eps)) + barrier = barrier.clamp_max(self.causal_robustness_max_penalty) + return barrier.mean().to(dtype=q.dtype) + def rearrange_inputs( self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"] ) -> Float[torch.Tensor, "... 
embed_dim"]: @@ -198,6 +277,7 @@ def forward( ] = None, kv_cache: Optional["KVCache"] = None, ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]: + self.latest_causal_robustness_penalty = None batch_size, variate, seq_len, _ = inputs.shape dropout = self.dropout if self.training else 0.0 @@ -206,6 +286,14 @@ def forward( q, k, v, seq_pos_offset = self.positional_embedding(q, k, v, kv_cache, layer_idx) + if self.enable_causal_robustness and self.training and self.attention_axis == AttentionAxis.TIME: + self.latest_causal_robustness_penalty = self._compute_causal_robustness_penalty( + q=q, + k=k, + v=v, + seq_pos_offset=seq_pos_offset, + ) + output = self.run_attention(attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate) output = self.rearrange_output(output, batch_size, variate, seq_len) diff --git a/toto/model/backbone.py b/toto/model/backbone.py index ba4ded9..3304bfa 100644 --- a/toto/model/backbone.py +++ b/toto/model/backbone.py @@ -94,9 +94,14 @@ def __init__( use_memory_efficient_attention: bool = True, stabilize_with_global: bool = True, scale_factor_exponent: float = 10.0, + enable_causal_robustness: bool = False, + causal_robustness_alpha: float = 1e-4, + causal_robustness_eps: float = 1e-6, + causal_robustness_max_penalty: float = 20.0, ): super().__init__() self.embed_dim = embed_dim + self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None # Attributes for variate-label fusion (initialized when enable_variate_labels is called) self.fusion: Optional[Fusion] = None self.num_prepended_tokens: int = 0 @@ -126,6 +131,10 @@ def __init__( spacewise_every_n_layers=spacewise_every_n_layers, spacewise_first=spacewise_first, use_memory_efficient_attention=self.use_memory_efficient_attention, + enable_causal_robustness=enable_causal_robustness, + causal_robustness_alpha=causal_robustness_alpha, + causal_robustness_eps=causal_robustness_eps, + causal_robustness_max_penalty=causal_robustness_max_penalty, fusion=self.fusion, ) 
self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size) @@ -216,6 +225,7 @@ def backbone( transformed: Float[torch.Tensor, "batch variates seq_len embed_dim"] = self.transformer( # type: ignore[assignment] embeddings, reduced_id_mask, kv_cache, variate_label_embeds=variate_label_embeds ) + self.latest_causal_robustness_penalty = self.transformer.latest_causal_robustness_penalty # Crop out the prepended tokens before unembedding added_tokens = transformed.shape[2] - original_seq_len if added_tokens > 0: @@ -298,3 +308,18 @@ def build_variate_label_embeds( exog_mask[:, -num_exogenous_variables:] = True # Select per-variate label: target label for genuine targets, exogenous label for EV channels return torch.where(exog_mask, exogenous_variate_label, target_variate_label) # (B, V, 1, D) + + def configure_causal_robustness( + self, + *, + enabled: bool, + alpha: float = 1e-4, + eps: float = 1e-6, + max_penalty: float = 20.0, + ) -> None: + self.transformer.configure_causal_robustness( + enabled=enabled, + alpha=alpha, + eps=eps, + max_penalty=max_penalty, + ) diff --git a/toto/model/lightning_module.py b/toto/model/lightning_module.py index 4962cee..3eef461 100644 --- a/toto/model/lightning_module.py +++ b/toto/model/lightning_module.py @@ -83,6 +83,10 @@ def __init__( min_lr: float = 1e-5, betas: tuple[float, float] = (0.9, 0.999), weight_decay: float = 0.01, + causal_robust_lambda: float = 0.0, + causal_robust_alpha: float = 1e-4, + causal_robust_eps: float = 1e-6, + causal_robust_max_penalty: float = 20.0, # backbone construction pretrained_backbone: TotoBackbone | None = None, add_exogenous_features: bool = False, @@ -109,6 +113,15 @@ def __init__( self.stable_steps = stable_steps self.decay_steps = decay_steps self.val_prediction_len = val_prediction_len + self.causal_robust_lambda = causal_robust_lambda + + if self.causal_robust_lambda > 0: + self.model.configure_causal_robustness( + enabled=True, + alpha=causal_robust_alpha, + eps=causal_robust_eps, + 
max_penalty=causal_robust_max_penalty, + ) # Loss setup (currently hard-wired to CombinedLoss) self.combined_loss = CombinedLoss() @@ -277,6 +290,27 @@ def _train_or_val_step(self, batch: CausalMaskedTimeseries, is_train: bool) -> t mean_total = total_sum / (valid_count + eps) prefix = "train" if is_train else "val" + if self.causal_robust_lambda > 0 and self.model.latest_causal_robustness_penalty is not None: + causal_penalty = self.model.latest_causal_robustness_penalty + weighted_causal_penalty = self.causal_robust_lambda * causal_penalty + mean_total = mean_total + weighted_causal_penalty + self.log( + f"{prefix}_causal_robustness_penalty", + causal_penalty.detach(), + prog_bar=False, + on_step=is_train, + on_epoch=True, + batch_size=batch.series.shape[0], + ) + self.log( + f"{prefix}_causal_robustness_penalty_weighted", + weighted_causal_penalty.detach(), + prog_bar=False, + on_step=is_train, + on_epoch=True, + batch_size=batch.series.shape[0], + ) + # Log loss for lightning progress bar and metrics self.log( f"{prefix}_loss", diff --git a/toto/model/transformer.py b/toto/model/transformer.py index a477e65..54f753c 100644 --- a/toto/model/transformer.py +++ b/toto/model/transformer.py @@ -68,6 +68,10 @@ def __init__( attention_axis: AttentionAxis = AttentionAxis.TIME, RMS_norm: bool = True, use_memory_efficient_attention: bool = True, + enable_causal_robustness: bool = False, + causal_robustness_alpha: float = 1e-4, + causal_robustness_eps: float = 1e-6, + causal_robustness_max_penalty: float = 20.0, ): super().__init__() self.embed_dim = embed_dim @@ -105,6 +109,13 @@ def __init__( else: raise ValueError("Invalid attention axis") + self.attention.configure_causal_robustness( + enabled=enable_causal_robustness, + alpha=causal_robustness_alpha, + eps=causal_robustness_eps, + max_penalty=causal_robustness_max_penalty, + ) + if XFORMERS_SWIGLU_AVAILABLE: self.mlp = torch.nn.Sequential( SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim), @@ -176,6 
+187,10 @@ def __init__( spacewise_every_n_layers: int, spacewise_first: bool, use_memory_efficient_attention: bool = True, + enable_causal_robustness: bool = False, + causal_robustness_alpha: float = 1e-4, + causal_robustness_eps: float = 1e-6, + causal_robustness_max_penalty: float = 20.0, *, fusion: Optional[Fusion] = None, ): @@ -193,6 +208,7 @@ def __init__( self.use_memory_efficient_attention = use_memory_efficient_attention self.fusion = fusion + self.latest_causal_robustness_penalty: Optional[torch.Tensor] = None self.layers = torch.nn.ModuleList( [ @@ -204,11 +220,31 @@ def __init__( rotary_emb=self.rotary_emb, attention_axis=attention_axes[i], use_memory_efficient_attention=self.use_memory_efficient_attention, + enable_causal_robustness=enable_causal_robustness, + causal_robustness_alpha=causal_robustness_alpha, + causal_robustness_eps=causal_robustness_eps, + causal_robustness_max_penalty=causal_robustness_max_penalty, ) for i in range(num_layers) ] ) + def configure_causal_robustness( + self, + *, + enabled: bool, + alpha: float = 1e-4, + eps: float = 1e-6, + max_penalty: float = 20.0, + ) -> None: + for layer in self.layers: + layer.attention.configure_causal_robustness( + enabled=enabled, + alpha=alpha, + eps=eps, + max_penalty=max_penalty, + ) + def _get_mask( self, num_heads: int, @@ -341,6 +377,7 @@ def forward( id_mask=id_mask, ) + causal_robustness_terms = [] for layer_idx, layer in enumerate(self.layers): inputs = layer( layer_idx, @@ -348,4 +385,11 @@ def forward( (timewise_attention_mask if layer.attention_axis == AttentionAxis.TIME else spacewise_attention_mask), kv_cache, ) + if layer.attention.latest_causal_robustness_penalty is not None: + causal_robustness_terms.append(layer.attention.latest_causal_robustness_penalty) + + if len(causal_robustness_terms) > 0: + self.latest_causal_robustness_penalty = torch.stack(causal_robustness_terms).mean() + else: + self.latest_causal_robustness_penalty = None return inputs diff --git 
a/toto/scripts/configs/finetune_config.yaml b/toto/scripts/configs/finetune_config.yaml index fc9c953..079adef 100644 --- a/toto/scripts/configs/finetune_config.yaml +++ b/toto/scripts/configs/finetune_config.yaml @@ -8,6 +8,10 @@ model: warmup_steps: 1000 # number of linear warmup steps (change for different training schedules) stable_steps: 200 # number of stable learning rate steps (change for different training schedules) decay_steps: 200 # number of exponential decay steps (change for different training schedules) + causal_robust_lambda: 0.0 # set >0 to enable causal robustness barrier regularization + causal_robust_alpha: 0.0001 # margin coupling coefficient in margin = 1 - alpha * attn_var_trace + causal_robust_eps: 0.000001 # numerical floor for the log-barrier margin + causal_robust_max_penalty: 20.0 # cap per-token barrier contribution for stability data: context_factor: 8 # context length = patch_size * context_factor change for different training context lengths @@ -36,4 +40,4 @@ checkpoint: logging: save_dir: lightning_logs - name: toto_finetuning # name of the experiment, will be used to save the logs \ No newline at end of file + name: toto_finetuning # name of the experiment, will be used to save the logs diff --git a/toto/scripts/finetune_toto.py b/toto/scripts/finetune_toto.py index be71d02..8cab4a6 100644 --- a/toto/scripts/finetune_toto.py +++ b/toto/scripts/finetune_toto.py @@ -50,6 +50,10 @@ def init_lightning(config: Dict[str, Any]) -> Tuple[TotoForFinetuning, int]: warmup_steps=int(mcfg.get("warmup_steps", 200)), lr=float(mcfg.get("lr", 1e-4)), min_lr=float(mcfg.get("min_lr", 1e-5)), + causal_robust_lambda=float(mcfg.get("causal_robust_lambda", 0.0)), + causal_robust_alpha=float(mcfg.get("causal_robust_alpha", 1e-4)), + causal_robust_eps=float(mcfg.get("causal_robust_eps", 1e-6)), + causal_robust_max_penalty=float(mcfg.get("causal_robust_max_penalty", 20.0)), add_exogenous_features=bool(dcfg.get("add_exogenous_features", False)), ) diff 
--git a/toto/test/model/causal_robustness_test.py b/toto/test/model/causal_robustness_test.py new file mode 100644 index 0000000..8d7a38c --- /dev/null +++ b/toto/test/model/causal_robustness_test.py @@ -0,0 +1,106 @@ +# Unless explicitly stated otherwise all files in this repository are licensed under the Apache-2.0 License. +# +# This product includes software developed at Datadog (https://www.datadoghq.com/) +# Copyright 2025 Datadog, Inc. + +import torch +import pytest + +try: + from toto.data.util.dataset import CausalMaskedTimeseries + from toto.model.backbone import TotoBackbone + from toto.model.lightning_module import TotoForFinetuning +except Exception as exc: # pragma: no cover - environment-specific import guard + pytest.skip(f"Skipping causal robustness tests due to unavailable dependencies: {exc}", allow_module_level=True) + + +def _make_backbone() -> TotoBackbone: + return TotoBackbone( + patch_size=4, + stride=4, + embed_dim=32, + num_layers=2, + num_heads=4, + mlp_hidden_dim=64, + dropout=0.0, + spacewise_every_n_layers=2, + spacewise_first=True, + use_memory_efficient_attention=False, + scaler_cls="", + output_distribution_classes=[""], + ) + + +def _make_inputs(batch: int = 2, variates: int = 3, timesteps: int = 16): + series = torch.randn(batch, variates, timesteps) + padding_mask = torch.ones_like(series, dtype=torch.bool) + id_mask = torch.zeros_like(series, dtype=torch.long) + return series, padding_mask, id_mask + + +def _make_causal_batch(batch: int = 2, variates: int = 3, timesteps: int = 16) -> CausalMaskedTimeseries: + series = torch.randn(batch, variates, timesteps) + padding_mask = torch.ones_like(series, dtype=torch.bool) + id_mask = torch.zeros_like(series, dtype=torch.long) + timestamp_seconds = torch.arange(timesteps).view(1, 1, timesteps).expand(batch, variates, timesteps).long() + time_interval_seconds = torch.ones(batch, variates, dtype=torch.long) + # Input and target lengths are both 12 so that model outputs and targets align. 
+ input_slice = slice(0, timesteps - 4) + target_slice = slice(4, timesteps) + return CausalMaskedTimeseries( + series=series, + padding_mask=padding_mask, + id_mask=id_mask, + timestamp_seconds=timestamp_seconds, + time_interval_seconds=time_interval_seconds, + input_slice=input_slice, + target_slice=target_slice, + num_exogenous_variables=0, + ) + + +def test_causal_robustness_penalty_disabled_by_default(): + model = _make_backbone().train() + series, padding_mask, id_mask = _make_inputs() + + _ = model(series, padding_mask, id_mask) + assert model.latest_causal_robustness_penalty is None + + +def test_causal_robustness_penalty_enabled_and_finite(): + model = _make_backbone().train() + model.configure_causal_robustness( + enabled=True, + alpha=1e-4, + eps=1e-6, + max_penalty=20.0, + ) + series, padding_mask, id_mask = _make_inputs() + + _ = model(series, padding_mask, id_mask) + penalty = model.latest_causal_robustness_penalty + + assert penalty is not None + assert torch.isfinite(penalty) + assert penalty.item() >= 0.0 + + +def test_lightning_step_adds_weighted_causal_robustness_penalty(): + module = TotoForFinetuning( + pretrained_backbone=_make_backbone(), + causal_robust_lambda=1.0, + causal_robust_alpha=1e-4, + causal_robust_eps=1e-6, + causal_robust_max_penalty=20.0, + ).train() + batch = _make_causal_batch() + + with torch.no_grad(): + loss_with_penalty = module._train_or_val_step(batch, is_train=True) + penalty = module.model.latest_causal_robustness_penalty + module.causal_robust_lambda = 0.0 + loss_without_penalty = module._train_or_val_step(batch, is_train=True) + + assert penalty is not None + assert penalty.item() >= 0.0 + assert loss_with_penalty >= loss_without_penalty