
RuntimeError: FlashAttention only support fp16 and bf16 data type when running training with flash_attn, log_samples: true & mixed_precision=bf16 #1201

@yash-malviya-tih

Description

Checks

  • This template is only for bug reports; usage problems go with 'Help Wanted'.
  • I have thoroughly reviewed the project documentation but couldn't find information to solve my problem.
  • I have searched for existing issues, including closed ones, and couldn't find a solution.
  • I am using English to submit this issue to facilitate community communication.

Environment Details

I am running in the Docker container ghcr.io/swivid/f5-tts:main, without changing any existing dependency versions.

Steps to Reproduce

Set /root/.cache/huggingface/accelerate/default_config.yaml as follows:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: 'NO'
downcast_bf16: 'no'
enable_cpu_affinity: false
gpu_ids: '0'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
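
For reference, a roughly equivalent single-GPU config can be written non-interactively with the accelerate CLI (assuming a recent accelerate version):

accelerate config default --mixed_precision bf16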

Set the training config (F5TTS_v1_Base.yaml) as follows:

hydra:
  run:
    dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}_frame_flash_attn/${now:%Y-%m-%d}/${now:%H-%M-%S}


datasets:
  name: Emilia_EN  # dataset name
  batch_size_per_gpu: 38400 # 8 GPUs, 8 * 38400 = 307200
  batch_size_type: frame  # frame | sample
  max_samples: 64  # max sequences per batch if using frame-wise batch_size; we set 32 for small models, 64 for base models
  num_workers: 16


optim:
  epochs: 1
  learning_rate: 7.5e-5
  num_warmup_updates: 0  # warmup updates
  grad_accumulation_steps: 1  # note: updates = steps / grad_accumulation_steps
  max_grad_norm: 1.0  # gradient clipping
  bnb_optimizer: False  # use bnb 8bit AdamW optimizer or not


model:
  name: F5TTS_v1_Base  # model name
  tokenizer: char  # tokenizer type
  tokenizer_path: null  # if 'custom' tokenizer, define the path you want to use (should be vocab.txt)
  backbone: DiT
  arch:
    dim: 1024
    depth: 22
    heads: 16
    ff_mult: 2
    text_dim: 512
    text_mask_padding: True
    qk_norm: null  # null | rms_norm
    conv_layers: 4
    pe_attn_head: null
    attn_backend: flash_attn  # torch | flash_attn
    attn_mask_enabled: False
    checkpoint_activations: False  # recompute activations and save memory for extra compute
  mel_spec:
    target_sample_rate: 24000
    n_mel_channels: 100
    hop_length: 256
    win_length: 1024
    n_fft: 1024
    mel_spec_type: vocos  # vocos | bigvgan
  vocoder:
    is_local: False  # use local offline ckpt or not
    local_path: null  # local vocoder path


ckpts:
  logger: wandb  # wandb | tensorboard | null
  log_samples: True  # infer random sample per save checkpoint. wip, normal to fail with extra long samples
  save_per_updates: 100  # save checkpoint per updates
  keep_last_n_checkpoints: 3  # -1 to keep all, 0 to not save intermediate, > 0 to keep last N checkpoints
  last_per_updates: 50  # save last checkpoint per updates
  save_dir: ckpts/${model.name}_${model.mel_spec.mel_spec_type}_${model.tokenizer}_${datasets.name}_frame_flash_attn

Then run training with:

CUDA_VISIBLE_DEVICES=0 accelerate launch --mixed_precision=bf16 src/f5_tts/train/train.py --config-name F5TTS_v1_Base.yaml

✔️ Expected Behavior

Training should complete without raising an exception.

❌ Actual Behavior

Training fails with the following error:

Traceback (most recent call last):
  File "/workspace/personal/F5-TTS/src/f5_tts/train/train.py", line 69, in main
    trainer.train(
  File "/workspace/personal/F5-TTS/src/f5_tts/model/trainer.py", line 418, in train
    generated, _ = self.accelerator.unwrap_model(self.model).sample(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/cfm.py", line 217, in sample
    trajectory = odeint(fn, y0, t, **self.odeint_kwargs)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchdiffeq/_impl/odeint.py", line 80, in odeint
    solution = solver.integrate(t)
               ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchdiffeq/_impl/solvers.py", line 114, in integrate
    dy, f0 = self._step_func(self.func, t0, dt, t1, y0)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchdiffeq/_impl/fixed_grid.py", line 10, in _step_func
    f0 = func(t0, y0, perturb=Perturb.NEXT if self.perturb else Perturb.NONE)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl       
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torchdiffeq/_impl/misc.py", line 197, in forward
    return self.base_func(t, y)
           ^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/cfm.py", line 180, in fn
    pred_cfg = self.transformer(
               ^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/backbones/dit.py", line 307, in forward
    x = block(x, t, mask=mask, rope=rope)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/modules.py", line 685, in forward
    attn_output = self.attn(x=norm, mask=mask, rope=rope)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/modules.py", line 432, in forward
    return self.processor(self, x, mask=mask, rope=rope)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/personal/F5-TTS/src/f5_tts/model/modules.py", line 527, in __call__
    x = flash_attn_func(query, key, value, dropout_p=0.0, causal=False)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 1196, in flash_attn_func
    return FlashAttnFunc.apply(
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/autograd/function.py", line 574, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 834, in forward
    out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 1061, in __call__
    return self_._op(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_library/custom_ops.py", line 236, in backend_impl
    result = self._backend_fns[device_type](*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/flash_attn/flash_attn_interface.py", line 91, in _flash_attn_forward
    out, softmax_lse, S_dmask, rng_state = flash_attn_gpu.fwd(
                                           ^^^^^^^^^^^^^^^^^^^
RuntimeError: FlashAttention only support fp16 and bf16 data type
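
For what it's worth, the error reproduces outside the trainer with a plain fp32 call into flash-attn (a minimal sketch against the container's flash-attn build; tensor shapes are illustrative):

import torch
from flash_attn import flash_attn_func

# (batch, seqlen, nheads, headdim); torch.randn defaults to fp32
q = torch.randn(1, 64, 16, 64, device="cuda")
# Raises: RuntimeError: FlashAttention only support fp16 and bf16 data type
flash_attn_func(q, q, q, dropout_p=0.0, causal=False)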

The error does not occur when I set log_samples: False.
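
My guess from the traceback: the training step runs under accelerate's bf16 autocast, but log_samples calls sample() on the unwrapped model outside that context, so query/key/value reach flash_attn_func as fp32. A possible workaround sketch (hypothetical, not a confirmed fix; the wrapper name is mine) around the call site at modules.py line 527:

import torch
from flash_attn import flash_attn_func

def flash_attn_autocast_safe(query, key, value, dropout_p=0.0, causal=False):
    """Cast fp32 q/k/v down to bf16 for the FlashAttention kernel
    (which only accepts fp16/bf16) and cast the output back."""
    orig_dtype = query.dtype
    if orig_dtype not in (torch.float16, torch.bfloat16):
        query, key, value = (t.to(torch.bfloat16) for t in (query, key, value))
    out = flash_attn_func(query, key, value, dropout_p=dropout_p, causal=causal)
    return out.to(orig_dtype)

Wrapping the sample() call in trainer.py in torch.autocast(device_type="cuda", dtype=torch.bfloat16) might work as well, since it would keep the inference path in the same dtype as training.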
