# GemmaTS

Chronos Bolt with a Gemma decoder for time-series forecasting.

## Architecture

```
Input → Chronos Encoder    → Gemma     → Chronos Output Head → Quantile Predictions
        (scaling/patching)   (decoder)   (forecasting)
```

- Encoder: Chronos Bolt's encoder (instance norm → patching → embedding)
- Decoder: Gemma (replaces T5; see the sketch below)
- Output: Chronos Bolt's quantile prediction head
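The decoder swap can be pictured with a minimal sketch. All names below are illustrative, not the repo's actual API (the real implementation is `src/models/gemma_ts.py`); it assumes Gemma is loaded as a decoder-only trunk via `transformers` (with gated-model access already granted) and bridged to the Chronos hidden size with two linear projections:

```python
import torch.nn as nn
from transformers import AutoModel

class GemmaDecoderAdapter(nn.Module):
    """Hypothetical drop-in replacement for the T5 decoder: project Chronos
    hidden states into Gemma's width, run Gemma, and project back."""

    def __init__(self, d_chronos: int, gemma_name: str = "google/gemma-3-270m"):
        super().__init__()
        self.gemma = AutoModel.from_pretrained(gemma_name)   # decoder-only trunk
        d_gemma = self.gemma.config.hidden_size
        self.in_proj = nn.Linear(d_chronos, d_gemma)
        self.out_proj = nn.Linear(d_gemma, d_chronos)

    def forward(self, hidden):  # hidden: (B, T, d_chronos) from the Chronos encoder
        h = self.gemma(inputs_embeds=self.in_proj(hidden)).last_hidden_state
        return self.out_proj(h)  # back to Chronos width for the quantile head
```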
## Setup

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

- Set up Hugging Face authentication for gated models (Gemma); a token-loading sketch follows this list:

  ```bash
  # Copy the example env file
  cp .env.example .env

  # Add your HF token to .env
  # Get your token from: https://huggingface.co/settings/tokens
  # HF_TOKEN=your_token_here
  ```

- Request access to Gemma models:
  - Visit https://huggingface.co/google/gemma-3-270m
  - Click "Request access" and wait for approval

- Activate the environment (if using conda):

  ```bash
  conda activate weather
  export KMP_DUPLICATE_LIB_OK=TRUE  # Required for macOS
  ```
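If you need the token inside Python rather than the shell, here is a minimal sketch, assuming `python-dotenv` is available (this README does not confirm it is in `requirements.txt`):

```python
import os

from dotenv import load_dotenv  # assumes python-dotenv is installed

load_dotenv()                    # reads HF_TOKEN from .env in the working directory
token = os.environ["HF_TOKEN"]   # raises KeyError if the token was never set

# Pass it explicitly when downloading gated weights, e.g.:
# AutoModel.from_pretrained("google/gemma-3-270m", token=token)
```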
## Training

Run training as a Python module from the project root:

```bash
# Quick demo (20 epochs)
python -m src.scripts.train --config gemma_ts_demo

# Full training (100 epochs)
python -m src.scripts.train --config gemma_ts
```

Configs are in `src/configs/` (a hypothetical example follows the list):

- `gemma_ts_demo.py`: Quick demo (20 epochs)
- `gemma_ts.py`: Full training (100 epochs)
- `chronos_demo.py`: Chronos Bolt baseline demo
- `chronos.py`: Chronos Bolt baseline full training
- `patchtst.py`: PatchTST baseline
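Each config is a plain Python module. The field names below are guesses for illustration, not the repo's actual schema; the model parameters mirror the usage example further down, while `batch_size` and `learning_rate` are assumed values:

```python
# src/configs/gemma_ts.py -- illustrative only; the real fields may differ.
config = {
    "chronos_base": "amazon/chronos-bolt-tiny",   # encoder/head weights
    "gemma_model": "google/gemma-3-270m",         # decoder weights
    "context_length": 512,
    "prediction_length": 64,
    "epochs": 100,                                # full training run
    "batch_size": 64,                             # assumed value
    "learning_rate": 1e-4,                        # assumed value
}
```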
## Distributed Training

Uses `accelerate` for distributed training. First configure accelerate:

```bash
accelerate config
```

Then launch training:

```bash
accelerate launch src/scripts/train.py --config gemma_ts
```

For 4 GPUs with mixed precision (fp16):

```bash
accelerate launch --num_processes=4 --mixed_precision=fp16 src/scripts/train.py --config gemma_ts
```

## Usage

```python
from src.models.gemma_ts import create_gemma_ts
# Create model
model = create_gemma_ts(
    chronos_base="amazon/chronos-bolt-tiny",
    gemma_model="google/gemma-3-270m",
    context_length=512,
    prediction_length=64,
)
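
# Dummy inputs so the calls below run end to end. The shapes are
# assumptions inferred from context_length/prediction_length above,
# not confirmed by the repo.
import torch

context = torch.randn(8, 512)      # (B, context_length) past values
mask = torch.ones(8, 512)          # 1 where a past value is observed
target = torch.randn(8, 64)        # (B, prediction_length) future values
target_mask = torch.ones(8, 64)    # 1 where a future value is observed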
# Forward pass
outputs = model(context, mask, target, target_mask)
loss = outputs["loss"]
predictions = outputs["quantile_preds"] # (B, num_quantiles, pred_len)
# Autoregressive generation
generated = model.generate(context, steps=5, mask=mask)
```

## Project Structure

```
GemmaTS/
├── src/
│ ├── models/
│ │ ├── gemma_ts.py # Main model (inherits ChronosBolt)
│ │ ├── chronos_bolt.py # Chronos Bolt baseline
│ │ └── patchtst.py # PatchTST baseline
│ ├── configs/
│ │ ├── gemma_ts.py # Full GemmaTS training
│ │ ├── gemma_ts_demo.py # Quick GemmaTS demo
│ │ ├── chronos.py # Chronos Bolt full training
│ │ ├── chronos_demo.py # Chronos Bolt demo
│ │ └── patchtst.py # PatchTST config
│ ├── dataloader/
│ │ ├── data_loader.py # Dataset classes
│ │ └── data_factory.py # Data provider
│ ├── utils/
│ │ ├── metrics.py # MSE, MAE
│ │ └── seed.py # Reproducibility
│ └── scripts/
│ └── train.py # Training script
├── data/
│ ├── datasets/ # Dataset files
│ └── checkpoints/ # Saved models
├── requirements.txt
└── README.md
```
## Key Features

- Minimal code: inherits from Chronos Bolt and only overrides the decoder
- No duplication: reuses Chronos's scaling, patching, and output logic
- Quantile forecasting: outputs 9 quantiles [0.1, 0.2, ..., 0.9]
- Autoregressive generation: uses the median quantile for multi-step forecasting (see the sketch after this list)
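A minimal sketch of how median-quantile rollout can work. It is illustrative, not the repo's `model.generate`; the forward-call arguments and feedback loop are assumptions:

```python
import torch

@torch.no_grad()
def rollout(model, context, mask, steps=5):
    """Feed the 0.5-quantile forecast back into the context, `steps` times."""
    forecasts = []
    for _ in range(steps):
        out = model(context, mask)                      # single forward pass
        q = out["quantile_preds"]                       # (B, 9, pred_len)
        median = q[:, q.shape[1] // 2, :]               # index 4 of 9 -> 0.5 quantile
        forecasts.append(median)
        context = torch.cat([context, median], dim=-1)  # extend the history
        mask = torch.cat([mask, torch.ones_like(median)], dim=-1)
    return torch.cat(forecasts, dim=-1)                 # (B, steps * pred_len)
```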
## Outputs

- Training log: `data/results/train_log.tsv` (step, split, loss, mse, mae, smape)
- Checkpoints: `data/checkpoints/*.pt` (see the loading sketch below)
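To reload a saved model, a sketch assuming the `.pt` files hold a plain `state_dict` (both the filename and the format are assumptions, not confirmed by this README):

```python
import torch

# Hypothetical filename; pick an actual file from data/checkpoints/.
state = torch.load("data/checkpoints/gemma_ts.pt", map_location="cpu")
model.load_state_dict(state)  # assumes the file is a bare state_dict
model.eval()
```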