Overview | Demo | Installation | Text Generation | Data Pipeline | Training
Fable is a compact storytelling language model implemented from scratch using JAX and Flax NNX.
It provides an end-to-end training pipeline, from text preparation and tokenisation to model training and autoregressive text generation.
Fable is designed to be small and fast: a fluent model can be trained on consumer GPUs in under an hour.
The included demo checkpoint was trained for ~2 hours on an RTX 4090 using the default configuration.
- Minimal GPT Architecture: lightweight decoder-only transformer (~800k parameters; see the rough estimate below this list).
- Simple Data Pipeline: deterministic download → clean → tokenise workflow for small text datasets.
- JIT Compilation: core steps compiled with `jax.jit`, achieving ~3 million tokens/sec throughput.
- Checkpointing: save and restore model state and hyperparameters in a single folder.
- Text Generation Tools: generate short stories with adjustable sampling temperature.
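For reference, the ~800k-parameter figure is consistent with a small decoder-only configuration along the lines of the back-of-the-envelope estimate below. The specific sizes here are illustrative assumptions, not necessarily the defaults in `config.py`.

```python
# Rough parameter-count estimate for a small character-level, decoder-only
# transformer. All sizes below are illustrative assumptions.
vocab_size = 128          # character-level vocabulary
d_model    = 128          # embedding / hidden width
n_layers   = 4            # decoder blocks
d_ff       = 4 * d_model  # MLP hidden width

embedding = vocab_size * d_model   # token embedding table (sinusoidal positions add no parameters)
attention = 4 * d_model * d_model  # Q, K, V and output projections
mlp       = 2 * d_model * d_ff     # two dense layers in the MLP block
per_layer = attention + mlp        # layer norms add only ~2 * d_model each
total     = embedding + n_layers * per_layer

print(f"~{total / 1e3:.0f}k parameters")  # ~803k with these illustrative settings
```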
The demo notebook walks through:
- Loading the pretrained demo checkpoint.
- Generating short stories from prompts.
- Exploring different temperature values.
- Optionally running the data pipeline and a brief training step.
The notebook runs entirely in a hosted Colab environment; no local setup required.
To install locally:
```bash
# Install latest release
pip install git+https://github.com/auxeno/fable

# Local development setup
git clone https://github.com/auxeno/fable.git
cd fable
pip install -e .
```

For GPU acceleration, install the JAX wheel matching your CUDA version, for example:

```bash
pip install --upgrade "jax[cuda13]"
```

Refer to the JAX installation guide for up-to-date instructions.
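To confirm that the GPU build is picked up, a quick check is:

```python
import jax

# Should list a CUDA device (e.g. CudaDevice(id=0)) rather than only CPU devices
print(jax.devices())
```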
To generate a story, simply write the first few words/lines and Fable will continue.
```python
from fable import generate_text

generate_text("Lily got a new puppy")  # Uses checkpoints/demo by default

# To sample your own run: generate_text("...", checkpoint="model_state")
```

```bash
fable-generate --start "Lily got a new puppy" --temperature 0.6

# Use --checkpoint model_state to load your own training run
```

Sampling temperature controls the balance between determinism and creativity (a minimal sketch of the mechanism follows this list):
- Low (β0.4): Predictable and faithful to training data.
- Medium (β0.6): More varied but occasionally incoherent.
- High (β0.8): Grammatically fluent but semantically unstable.
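Conceptually, temperature divides the model's logits before sampling: low values sharpen the next-token distribution, high values flatten it. A minimal sketch of that idea in JAX (not Fable's actual `generate.py`, and a hypothetical helper name):

```python
import jax
import jax.numpy as jnp

def sample_next_token(logits, temperature, key):
    """Sample one token id from logits scaled by temperature."""
    scaled = logits / temperature        # low T -> peaky, high T -> flat
    return jax.random.categorical(key, scaled)

key = jax.random.PRNGKey(0)
logits = jnp.array([2.0, 1.0, 0.5, -1.0])  # toy next-token scores
print(sample_next_token(logits, 0.4, key))
```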
generate_text("Lily got a new puppy", temperature=0.4)
"""
Lily got a new puppy and said, "I want to see it!" Her mom smiled and said,
"Okay, Lily. Let's go to the puppy and see if we can play with it."
Lily smiled and said, "Okay."
Lily felt better.
...
"""generate_text("Lily got a new puppy", temperature=0.6)
"""
Lily got a new puppy named Spot. He hit his ball out and fell on the floor.
He cried and benly, but Spot was too fit fast.
His friend cheered and then flew to the ball.
They pulled and tugged and tugged. They ran away.
Spot gave the ball a kiss and the ball back to their mommy.
...
"""generate_text("Lily got a new puppy", temperature=0.8)
"""
Lily got a new puppy and stopped pretending,
because other automobiles walked through the fall.
Aftenma's mushy man purred a sleepy scene for a while, although Weggin circle.
They barked and snuggled until Fridge was.
Fred finally paddle of his bedroom were even blacker!
Jeddy was so excited that he didn't want to give.
...
"""Fable includes a small data-preparation utility that downloads and processes the TinyStories dataset for training.
```bash
# Download, clean, and tokenise text data
fable-prepare-data --stage all

# Or run an individual stage (download, clean, tokenise)
fable-prepare-data --stage clean
```

This creates:

- `data/raw/`: raw TinyStories dataset `.txt` files.
- `data/clean/`: cleaned text filtered to supported characters.
- `data/tokenized/`: `int8` binary token buffers used for model training.
While TinyStories is used by default, the same pipeline can be adapted for other small-scale narrative datasets with minimal modification.
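As an illustration of the tokenise stage, mapping supported characters to compact `int8` buffers can be as simple as the sketch below. The vocabulary, helper names, and output file here are assumptions; Fable's actual `tokenize.py` and `tokenizer-config.json` may differ.

```python
import numpy as np

# Hypothetical character vocabulary; Fable's real one lives in tokenizer-config.json.
vocab = sorted(set("abcdefghijklmnopqrstuvwxyz .,!?\"'\n"))
char_to_id = {ch: i for i, ch in enumerate(vocab)}

def tokenise(text: str) -> np.ndarray:
    """Map each supported character to its id and store as int8."""
    ids = [char_to_id[ch] for ch in text.lower() if ch in char_to_id]
    return np.asarray(ids, dtype=np.int8)

tokens = tokenise("Lily got a new puppy")
tokens.tofile("story.bin")  # flat int8 buffer (illustrative file name)
```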
Once data is prepared, train a model from scratch using either Python or the CLI.
```python
from fable import save, train

model = train()
save(model)
```

```bash
fable-train --num-epochs 5 --batch-size 128 --learning-rate 3e-4
```

Training progress and checkpoints are saved automatically in `checkpoints/`.
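For intuition, the core of a JIT-compiled training step with Optax usually looks like the sketch below. Fable's actual `train.py` is built on Flax NNX and a GPT model, so the details differ; the loss and parameters here are toy placeholders.

```python
import jax
import jax.numpy as jnp
import optax

optimiser = optax.adamw(learning_rate=3e-4)

def loss_fn(params, batch):
    # Placeholder loss; the real model computes next-token cross-entropy.
    logits = batch["inputs"] @ params["w"]
    return optax.softmax_cross_entropy_with_integer_labels(
        logits, batch["targets"]
    ).mean()

@jax.jit
def train_step(params, opt_state, batch):
    """One gradient update, compiled once and reused every step."""
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    updates, opt_state = optimiser.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

# Toy data just to show the call pattern.
params = {"w": jnp.zeros((16, 8))}
opt_state = optimiser.init(params)
batch = {"inputs": jnp.ones((4, 16)), "targets": jnp.zeros(4, dtype=jnp.int32)}
params, opt_state, loss = train_step(params, opt_state, batch)
```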
```
fable/
├── checkpoint.py              # Save/load wrappers for NNX state trees
├── config.py                  # GPTConfig dataclass and defaults
├── data/
│   ├── pipeline.py            # Text data download/clean/tokenise commands
│   ├── tokenize.py            # Character-level tokenizer and vocabulary tools
│   └── tokenizer-config.json  # Vocabulary and end-of-text token definitions
├── evaluate.py                # Validation step shared across training and notebooks
├── generate.py                # Text generation helpers and CLI entry point
├── model/
│   ├── attention.py           # Multi-head self-attention
│   ├── dropout.py             # Lightweight stochastic dropout layer
│   ├── feed_forward.py        # GELU MLP block
│   ├── gpt.py                 # GPT model assembly and forward pass
│   ├── normalize.py           # Layer normalisation layer
│   ├── position.py            # Sinusoidal positional embeddings
│   └── transformer.py         # Pre-norm decoder block with dropout
├── train.py                   # JIT-compiled training loop with Optax optimizers
└── utils.py                   # TQDM helper wrappers
checkpoints/
└── demo/                      # Example checkpoint trained for ~2 hours on RTX 4090
```
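As one example of the components listed above, sinusoidal positional embeddings (the role of `position.py`) can be computed as in this sketch; the module's actual implementation may differ in layout and naming.

```python
import jax.numpy as jnp

def sinusoidal_embeddings(seq_len: int, d_model: int) -> jnp.ndarray:
    """Fixed (non-learned) positional embeddings of shape (seq_len, d_model)."""
    positions = jnp.arange(seq_len)[:, None]       # (seq_len, 1)
    dims = jnp.arange(0, d_model, 2)[None, :]      # (1, d_model / 2)
    angles = positions / (10000.0 ** (dims / d_model))
    emb = jnp.zeros((seq_len, d_model))
    emb = emb.at[:, 0::2].set(jnp.sin(angles))     # even dimensions: sine
    emb = emb.at[:, 1::2].set(jnp.cos(angles))     # odd dimensions: cosine
    return emb

print(sinusoidal_embeddings(seq_len=8, d_model=16).shape)  # (8, 16)
```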
If Fable supports your research or teaching, please cite:
```bibtex
@software{fable2025,
  title  = {Fable: A Compact Storytelling Language Model in JAX},
  author = {Alex Goddard},
  year   = {2025},
  url    = {https://github.com/auxeno/fable}
}
```

Released under the MIT License.
See licence for the full text.
Thanks to the creators of the TinyStories dataset (Eldan & Li),
and to the JAX and Flax contributors whose work made Fable possible.