
FlashAttention training is unstable #1217

@yash98

Description

Checks

  • This template is only for bug reports; usage problems go under 'Help Wanted'.
  • I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
  • I have searched for existing issues, including closed ones, and couldn't find a solution.
  • I am using English to submit this issue to facilitate community communication.

Environment Details

Running in the Docker container ghcr.io/swivid/f5-tts:main
Running on 8x H200 GPUs

Steps to Reproduce

  1. Set the training config and accelerate config as shown below.
  2. Start the Docker container and launch the training script the same way as described in the repo README.
  3. Run training on only the English subset of Emilia. During the 2nd epoch (after 159k steps) the loss starts to increase. Launch command:
     accelerate launch --mixed_precision=bf16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml

F5TTS_v1_Base.yaml

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}/${now:%Y-%m-%d}/${now:%H-%M-%S}


datasets:
  name: Emilia_EN_full  # dataset name
  batch_size_per_gpu: 57600  # 75/50*38400  # 7 GPUs, 7 * 57600 = 403200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if use frame-wise batch_size. we set 32 for small models, 64 for base models
  num_workers: 16


optim:
  epochs: 11
  learning_rate: 7.5e-5
  num_warmup_updates: 20000  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not


model:
  name: F5TTS_v1_Base  # model name
  tokenizer: char  # tokenizer type
  tokenizer_path: null  # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: True
    qk_norm: null  # null | rms_norm
    conv_layers: 4
    pe_attn_head: null
    attn_backend: flash_attn  # torch | flash_attn
    attn_mask_enabled: False
    checkpoint_activations: False  # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path


ckpts:
  logger: wandb  # wandb | tensorboard | null
  logger_run_name: ${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 50000  # save checkpoint per updates
  keep_last_n_checkpoints: 30  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 5000  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}
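
For context on the attn_backend setting above, here is a minimal sketch (not the actual F5-TTS code) of what such a switch typically selects between: flash_attn_func from the flash-attn package versus torch's scaled_dot_product_attention. The wrapper function below is an illustrative assumption; only the two library calls are real APIs.

import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func  # flash-attn package

def attention(q, k, v, backend="flash_attn", mask=None):
    # q, k, v: (batch, seq_len, n_heads, head_dim); flash_attn needs bf16/fp16 inputs
    if backend == "flash_attn":
        # flash_attn_func in this form takes no padding mask
        # (cf. attn_mask_enabled: False above)
        return flash_attn_func(q, k, v, causal=False)
    # torch SDPA expects (batch, n_heads, seq_len, head_dim) and accepts a mask
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.transpose(1, 2)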

/root/.cache/huggingface/accelerate/default_config.yaml

compute_environment: LOCAL_MACHINE
debug: true
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: 1,2,3,4,5,6,7
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 7
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
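
For context, a minimal self-contained sketch (assuming a generic Hugging Face Accelerate training loop, not the project's actual trainer) of how mixed_precision: bf16 here combines with max_grad_norm: 1.0 and learning_rate: 7.5e-5 from the training config; the model and data below are placeholders only.

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="bf16")             # mirrors the config above
model = torch.nn.Linear(16, 1)                                # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=7.5e-5)  # learning_rate from the config
loader = DataLoader(TensorDataset(torch.randn(32, 16), torch.randn(32, 1)), batch_size=8)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x).float(), y)
    accelerator.backward(loss)
    accelerator.clip_grad_norm_(model.parameters(), 1.0)      # max_grad_norm from the config
    optimizer.step()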

✔️ Expected Behavior

Loss keeps decreasing and converges

❌ Actual Behavior

Loss increases to a value even higher than when training started. Subsequent steps decrease the loss a bit, but it does not come back down to its earlier level.
