From c7de2ee8a3c1d707c36f0dd8811e7f75f5ab18b8 Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:19:57 +0800 Subject: [PATCH 1/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/cfm.py | 8 +++++++- src/f5_tts/model/utils.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 90679be0d..30450f986 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -26,6 +26,7 @@ list_str_to_idx, list_str_to_tensor, mask_from_frac_lengths, + get_epss_timesteps, ) @@ -96,6 +97,7 @@ def sample( duplicate_test=False, t_inter=0.1, edit_mask=None, + use_epss=True, ): self.eval() # raw wave @@ -190,7 +192,11 @@ def fn(t, x): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) - t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) + # use Empirically Pruned Step Sampling to imporve synthesis quality with small number of sampling steps + if t_start == 0 and use_epss: + t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) + else: + t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) if sway_sampling_coef is not None: t = t + sway_sampling_coef * (torch.cos(torch.pi / 2 * t) - 1 + t) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index 439184bc1..ea3a68e5b 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -189,3 +189,31 @@ def repetition_found(text, length=2, tolerance=10): if count > tolerance: return True return False + +def get_epss_timesteps(n, device, dtype): + t = [] + dt = 1 / 32 + if n == 5: + t = [0 * dt, 2 * dt, 4 * dt, 8 * dt, + 16* dt, 32* dt] + elif n == 6: + # t = [0 * dt, 2 * dt, 4 * dt, 8 * dt, + # 16* dt, 24* dt, 32* dt] + t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, + 16* dt, 32* dt] + elif n == 7: + t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, + 16* dt, 24* dt, 32* dt] + elif n == 10: + t = [0 * dt, 1 * dt, 2 * dt, 3 * dt, 4 * dt, 5 * dt, 6 * dt, 7 * dt, 8 * dt, + 12* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] + elif n == 12: + t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, + 10* dt, 12* dt, 14* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] + elif n == 16: + t = [0 * dt, 1 * dt, 2 * dt, 3 * dt, 4 * dt, 5 * dt, 6 * dt, 7 * dt, 8 * dt, + 10* dt, 12* dt, 14* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] + if len(t) == 0: + return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) + else: + return torch.tensor(t, device=device, dtype=dtype) \ No newline at end of file From af332ee065c86235631cb30a3702ad4a83aa7612 Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:22:51 +0800 Subject: [PATCH 2/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/cfm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 30450f986..12bf20ee9 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -192,7 +192,7 @@ def fn(t, x): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) - # use Empirically Pruned Step Sampling to imporve synthesis quality with small number of sampling steps + # use Empirically Pruned Step Sampling to improve synthesis quality with small number of sampling steps if t_start == 0 and use_epss: t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) else: From 21d63a17c6c3848368ddd12484dc0267208edb04 Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:36:59 +0800 Subject: [PATCH 3/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/utils.py | 38 ++++++++++++++------------------------ 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index ea3a68e5b..0e58c9f9f 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -191,29 +191,19 @@ def repetition_found(text, length=2, tolerance=10): return False def get_epss_timesteps(n, device, dtype): - t = [] dt = 1 / 32 - if n == 5: - t = [0 * dt, 2 * dt, 4 * dt, 8 * dt, - 16* dt, 32* dt] - elif n == 6: - # t = [0 * dt, 2 * dt, 4 * dt, 8 * dt, - # 16* dt, 24* dt, 32* dt] - t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, - 16* dt, 32* dt] - elif n == 7: - t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, - 16* dt, 24* dt, 32* dt] - elif n == 10: - t = [0 * dt, 1 * dt, 2 * dt, 3 * dt, 4 * dt, 5 * dt, 6 * dt, 7 * dt, 8 * dt, - 12* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] - elif n == 12: - t = [0 * dt, 2 * dt, 4 * dt, 6 * dt, 8 * dt, - 10* dt, 12* dt, 14* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] - elif n == 16: - t = [0 * dt, 1 * dt, 2 * dt, 3 * dt, 4 * dt, 5 * dt, 6 * dt, 7 * dt, 8 * dt, - 10* dt, 12* dt, 14* dt, 16* dt, 20* dt, 24* dt, 28* dt, 32* dt] - if len(t) == 0: + predefined_timesteps = { + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 1, 2, 3, 4, 5, 6, 7, 8, + 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, + 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, + 10, 12, 14, 16, 20, 24, 28, 32], + } + t = predefined_timesteps.get(n, []) + if not t: return torch.linspace(0, 1, n + 1, device=device, dtype=dtype) - else: - return torch.tensor(t, device=device, dtype=dtype) \ No newline at end of file + return dt * torch.tensor(t, device=device, dtype=dtype) From c2c250f9f1f3b5abd871cb9feb2b1bdb3a65e738 Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:41:00 +0800 Subject: [PATCH 4/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/utils.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index 0e58c9f9f..194cbed82 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -193,15 +193,12 @@ def repetition_found(text, length=2, tolerance=10): def get_epss_timesteps(n, device, dtype): dt = 1 / 32 predefined_timesteps = { - 5: [0, 2, 4, 8, 16, 32], - 6: [0, 2, 4, 6, 8, 16, 32], - 7: [0, 2, 4, 6, 8, 16, 24, 32], - 10: [0, 1, 2, 3, 4, 5, 6, 7, 8, - 12, 16, 20, 24, 28, 32], - 12: [0, 2, 4, 6, 8, 10, 12, 14, - 16, 20, 24, 28, 32], - 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, - 10, 12, 14, 16, 20, 24, 28, 32], + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], } t = predefined_timesteps.get(n, []) if not t: From b9d5360994de459fea54678955205b7c13103fbe Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:42:42 +0800 Subject: [PATCH 5/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/cfm.py | 2 +- src/f5_tts/model/utils.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 12bf20ee9..9c3a12ce1 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -22,11 +22,11 @@ from f5_tts.model.utils import ( default, exists, + get_epss_timesteps, lens_to_mask, list_str_to_idx, list_str_to_tensor, mask_from_frac_lengths, - get_epss_timesteps, ) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index 194cbed82..215215e5e 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -193,11 +193,11 @@ def repetition_found(text, length=2, tolerance=10): def get_epss_timesteps(n, device, dtype): dt = 1 / 32 predefined_timesteps = { - 5: [0, 2, 4, 8, 16, 32], - 6: [0, 2, 4, 6, 8, 16, 32], - 7: [0, 2, 4, 6, 8, 16, 24, 32], - 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], - 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], + 5: [0, 2, 4, 8, 16, 32], + 6: [0, 2, 4, 6, 8, 16, 32], + 7: [0, 2, 4, 6, 8, 16, 24, 32], + 10: [0, 2, 4, 6, 8, 12, 16, 20, 24, 28, 32], + 12: [0, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28, 32], 16: [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32], } t = predefined_timesteps.get(n, []) From e661b24d6e8b5bd809ac5060eb33857c198d8f15 Mon Sep 17 00:00:00 2001 From: Fast-F5-TTS <2942755472@qq.com> Date: Wed, 4 Jun 2025 19:45:05 +0800 Subject: [PATCH 6/7] update Empirically Pruned Step Sampling --- src/f5_tts/model/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index 215215e5e..eb37fc52a 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -190,6 +190,8 @@ def repetition_found(text, length=2, tolerance=10): return True return False + +# get the empirically pruned step for sampling def get_epss_timesteps(n, device, dtype): dt = 1 / 32 predefined_timesteps = { From 0aacfae2b37aa18a89d7e02d7e02201538dd637d Mon Sep 17 00:00:00 2001 From: SWivid Date: Wed, 4 Jun 2025 22:57:43 +0800 Subject: [PATCH 7/7] format --- src/f5_tts/model/cfm.py | 5 ++--- src/f5_tts/model/utils.py | 2 ++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/f5_tts/model/cfm.py b/src/f5_tts/model/cfm.py index 9c3a12ce1..15be1bb55 100644 --- a/src/f5_tts/model/cfm.py +++ b/src/f5_tts/model/cfm.py @@ -93,11 +93,11 @@ def sample( seed: int | None = None, max_duration=4096, vocoder: Callable[[float["b d n"]], float["b nw"]] | None = None, # noqa: F722 + use_epss=True, no_ref_audio=False, duplicate_test=False, t_inter=0.1, edit_mask=None, - use_epss=True, ): self.eval() # raw wave @@ -192,8 +192,7 @@ def fn(t, x): y0 = (1 - t_start) * y0 + t_start * test_cond steps = int(steps * (1 - t_start)) - # use Empirically Pruned Step Sampling to improve synthesis quality with small number of sampling steps - if t_start == 0 and use_epss: + if t_start == 0 and use_epss: # use Empirically Pruned Step Sampling for low NFE t = get_epss_timesteps(steps, device=self.device, dtype=step_cond.dtype) else: t = torch.linspace(t_start, 1, steps + 1, device=self.device, dtype=step_cond.dtype) diff --git a/src/f5_tts/model/utils.py b/src/f5_tts/model/utils.py index eb37fc52a..37d51784c 100644 --- a/src/f5_tts/model/utils.py +++ b/src/f5_tts/model/utils.py @@ -192,6 +192,8 @@ def repetition_found(text, length=2, tolerance=10): # get the empirically pruned step for sampling + + def get_epss_timesteps(n, device, dtype): dt = 1 / 32 predefined_timesteps = {