Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/python-wheels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ jobs:
- uses: Jimver/cuda-toolkit@v0.2.21
id: cuda-toolkit
with:
cuda: "12.6.0"
# TODO CUDA build matrix
cuda: "11.8.0"

- run: |
echo "Installed cuda version is: ${{steps.cuda-toolkit.outputs.cuda}}"
Expand Down
8 changes: 4 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ members = [
"server",
]
resolver = "2"
package.version = "0.2.3"
package.version = "0.3.0"

[profile.release-with-debug]
inherits = "release"
Expand Down
8 changes: 4 additions & 4 deletions fish_speech_core/lib/lm/dual_ar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,8 @@ impl Attention {
.reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
// TODO: Fix op to handle bsz > 1
#[cfg(feature = "cuda")]
let key_states = match bsz {
1 => repeat_kv(&key_states, n_rep)?,
let key_states = match (bsz, key_states.device()) {
(1, &Device::Cuda(_)) => repeat_kv(&key_states, n_rep)?,
_ => key_states
.unsqueeze(2)?
.expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
Expand All @@ -348,8 +348,8 @@ impl Attention {
.expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
.reshape((bsz, self.n_local_heads * n_rep, kv_seqlen, self.head_dim))?;
#[cfg(feature = "cuda")]
let value_states = match bsz {
1 => repeat_kv(&value_states, n_rep)?,
let value_states = match (bsz, value_states.device()) {
(1, &Device::Cuda(_)) => repeat_kv(&value_states, n_rep)?,
_ => value_states
.unsqueeze(2)?
.expand((bsz, self.n_local_heads, n_rep, kv_seqlen, self.head_dim))?
Expand Down
3 changes: 2 additions & 1 deletion fish_speech_python/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ OS + hardware:
- Linux:
- CPU: x86_64, glibc 2.34+
- Example: Ubuntu 22.04 IS supported, Ubuntu 20.04 IS NOT supported
- GPU: Nvidia CUDA 12+ with compute capability >= 8.0 (RTX 30 series+, A100 series+)
- GPU: Nvidia CUDA 11.8+ with compute capability >= 8.0 (RTX 30 series+, A100 series+)
- Example: 2080 Ti is NOT supported (Turing)
- Example: RX5700 XT is NOT supported (AMD)
- NOTE: We don't currently have a CUDA build matrix, so the prebuilt wheels are compiled against CUDA 11.8. They should remain compatible with newer CUDA versions, but if you need full optimization for your toolkit version, please use the Rust runtime instead.
- macOS (M1+, 14.0+ (Sonoma))

Windows and AMD hardware will never be supported, so don't ask.
Expand Down
2 changes: 1 addition & 1 deletion fish_speech_python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
name = "fish_speech_rs"
# Handled by maturin
description = "High-performance speech synthesis"
version = "0.2.3"
version = "0.3.0"
readme = "README.md"
requires-python = ">=3.9"
dependencies = [
Expand Down
15 changes: 10 additions & 5 deletions fish_speech_python/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@ impl FireflyCodec {
.map_err(|_| PyException::new_err(format!("Unsupported model version: {}", version)))?;
let codec_type = WhichCodec::from_model(model_type);

let device = get_device(device)?;
let dtype = match dtype {
"f32" => DType::F32,
"bf16" => DType::BF16,
d => return Err(PyException::new_err(format!("Unsupported dtype: {}", d))),
let dtype = match (dtype, device) {
("bf16", "cuda") | ("bf16", "metal") => DType::BF16,
("f32", _) => DType::F32,
(d, _) => {
return Err(PyException::new_err(format!(
"Unsupported dtype on device {}: {}",
device, d
)))
}
};
let device = get_device(device)?;

let vb = match model_type {
WhichModel::Fish1_2 => VarBuilder::from_pth(
Expand Down
14 changes: 9 additions & 5 deletions fish_speech_python/src/lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ impl LM {
let model_type = get_version(version)
.map_err(|_| PyException::new_err(format!("Unsupported model version: {}", version)))?;

// TODO Modularization
let dtype = match dtype {
"f32" => DType::F32,
"bf16" => DType::BF16,
d => return Err(PyException::new_err(format!("Unsupported dtype: {}", d))),
let dtype = match (dtype, device) {
("bf16", "cuda") | ("bf16", "metal") => DType::BF16,
("f32", _) => DType::F32,
(d, _) => {
return Err(PyException::new_err(format!(
"Unsupported dtype on device {}: {}",
device, d
)))
}
};
let device = get_device(device)?;
let vb = match model_type {
Expand Down