From 1679fccb088b4b07f90a45cba5e5f2d6bc53035d Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Sun, 28 Dec 2025 16:32:09 -0300
Subject: [PATCH 01/24] wip

---
 shared/data-provider/src/gcs.rs  | 129 +++++++++++++++++++++++++++++++
 shared/data-provider/src/lib.rs  |   2 +
 shared/eval/examples/evaluate.rs |  23 +++++-
 3 files changed, 151 insertions(+), 3 deletions(-)
 create mode 100644 shared/data-provider/src/gcs.rs

diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
new file mode 100644
index 000000000..d5c7b39f7
--- /dev/null
+++ b/shared/data-provider/src/gcs.rs
@@ -0,0 +1,129 @@
+use google_cloud_storage::http::objects::{download::Range, get::GetObjectRequest, list::ListObjectsRequest};
+use std::path::PathBuf;
+use tokio::runtime::Runtime;
+use tracing::info;
+
+const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
+
+fn check_model_extension(filename: &str) -> bool {
+    MODEL_EXTENSIONS.iter().any(|ext| filename.ends_with(ext))
+}
+
+fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf {
+    let base = std::env::var("PSYCHE_CACHE_DIR")
+        .map(PathBuf::from)
+        .unwrap_or_else(|_| {
+            std::env::var("HOME")
+                .map(|h| PathBuf::from(h).join(".cache"))
+                .unwrap_or_else(|_| PathBuf::from(".cache"))
+        })
+        .join("psyche")
+        .join("gcs")
+        .join(bucket);
+
+    match prefix {
+        Some(p) => base.join(p.trim_end_matches('/')),
+        None => base,
+    }
+}
+
+async fn download_model_from_gcs_async(
+    bucket: &str,
+    prefix: Option<&str>,
+    cache_dir: Option<PathBuf>,
+    progress_bar: bool,
+) -> Vec<PathBuf> {
+    let config = google_cloud_storage::client::ClientConfig::default().anonymous();
+    let client = google_cloud_storage::client::Client::new(config);
+
+    // List all objects in the bucket with optional prefix
+    let mut all_objects = vec![];
+    let mut next_page_token: Option<Option<String>> = Some(None);
+
+    while let Some(maybe_next_page_token) = next_page_token {
+        let results = client
+            .list_objects(&ListObjectsRequest {
+                bucket: bucket.to_owned(),
+                prefix: prefix.map(|s| s.to_owned()),
+                page_token: maybe_next_page_token,
+                ..Default::default()
+            })
+            .await
+            .unwrap();
+
+        for obj in results.items.iter().flatten() {
+            if check_model_extension(&obj.name) {
+                all_objects.push(obj.name.clone());
+            }
+        }
+
+        next_page_token = results.next_page_token.map(Some);
+    }
+
+    if progress_bar {
+        info!("Found {} model files in gs://{}/{}", all_objects.len(), bucket, prefix.unwrap_or(""));
+    }
+
+    // Determine cache directory
+    let cache_dir = cache_dir.unwrap_or_else(|| get_cache_dir(bucket, prefix));
+    std::fs::create_dir_all(&cache_dir).unwrap();
+
+    let mut downloaded_files = Vec::new();
+
+    for object_name in all_objects {
+        // Get just the filename (strip prefix if present)
+        let filename = object_name
+            .rsplit('/')
+            .next()
+            .unwrap_or(&object_name);
+
+        let local_path = cache_dir.join(filename);
+
+        // Skip if already cached
+        if local_path.exists() {
+            if progress_bar {
+                info!("Using cached: {}", filename);
+            }
+            downloaded_files.push(local_path);
+            continue;
+        }
+
+        if progress_bar {
+            info!("Downloading: {}", object_name);
+        }
+
+        // Download the object
+        let data = client
+            .download_object(
+                &GetObjectRequest {
+                    bucket: bucket.to_owned(),
+                    object: object_name.clone(),
+                    ..Default::default()
+                },
+                &Range::default(),
+            )
+            .await
+            .unwrap();
+
+        // Write to cache
+        std::fs::write(&local_path, &data).unwrap();
+
+        if progress_bar {
+            info!("Downloaded: {} ({} bytes)", filename, data.len());
+        }
+
+        downloaded_files.push(local_path);
+    }
+
+    downloaded_files
+}
+
+pub fn download_model_from_gcs_sync(
+    bucket: &str,
+    prefix: Option<&str>,
+    cache_dir: Option<PathBuf>,
+    progress_bar: bool,
+) -> Vec<PathBuf> {
+    let rt = Runtime::new().unwrap();
+    rt.block_on(download_model_from_gcs_async(bucket, prefix, cache_dir, progress_bar))
+}
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 868bfa793..27b43b3a6 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -2,6 +2,7 @@ mod data_provider;
 mod dataset;
 mod dummy;
 mod file_extensions;
+mod gcs;
 pub mod http;
 mod hub;
 mod local;
@@ -14,6 +15,7 @@ pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
+pub use gcs::download_model_from_gcs_sync;
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index c21dd26cf..c867c7ae2 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use clap::Parser;
 use indicatif::{ProgressBar, ProgressStyle};
 use psyche_core::RunningAverage;
-use psyche_data_provider::download_model_repo_sync;
+use psyche_data_provider::{download_model_from_gcs_sync, download_model_repo_sync};
 use psyche_eval::{
     ALL_TASK_NAMES, EvalTaskOptions, Task, progress_bar_template_with_task, tasktype_from_name,
 };
@@ -25,6 +25,14 @@ struct Args {
     #[arg(long)]
     hf_token: Option<String>,
 
+    /// GCS bucket name (use instead of HuggingFace model)
+    #[arg(long)]
+    gcs_bucket: Option<String>,
+
+    /// GCS prefix/directory within the bucket
+    #[arg(long)]
+    gcs_prefix: Option<String>,
+
     #[arg(long, default_value_t = ALL_TASK_NAMES.join(","))]
     tasks: String,
 
@@ -75,9 +83,14 @@ fn main() -> Result<()> {
     } else {
         "".to_string()
     };
+    let model_source = if let Some(ref bucket) = args.gcs_bucket {
+        format!("gs://{}/{}", bucket, args.gcs_prefix.as_deref().unwrap_or(""))
+    } else {
+        args.model.clone()
+    };
     println!(
         "Running tasks with model {}, seed: {}, DP={}{}",
-        args.model, args.seed, args.data_parallelism, limit_str
+        model_source, args.seed, args.data_parallelism, limit_str
     );
     for task in &tasks {
         println!("  - {}: {} few-shot examples", task, task.num_fewshot);
     }
@@ -107,7 +120,11 @@ fn main() -> Result<()> {
         }
     }
 
-    let repo = download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?;
+    let repo = if let Some(ref bucket) = args.gcs_bucket {
+        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref(), None, true)
+    } else {
+        download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?
+    };
 
     let tokenizer = auto_tokenizer(&repo)?;
 
     let (python, python_arch) = {
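For orientation, this is how the new helper is meant to be called at this point in the series — a minimal sketch, assuming a hypothetical publicly readable bucket ("my-bucket" and "my-model/" are placeholders, not values from the patch; this revision always builds an anonymous GCS client):

    use std::path::PathBuf;
    use psyche_data_provider::download_model_from_gcs_sync;

    fn main() {
        let files: Vec<PathBuf> = download_model_from_gcs_sync(
            "my-bucket",       // bucket to list
            Some("my-model/"), // optional object prefix inside the bucket
            None,              // cache_dir: None falls back to get_cache_dir()
            true,              // progress_bar: emit tracing info! logs
        );
        for file in &files {
            println!("cached: {}", file.display());
        }
    }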
 shared/coordinator/src/model.rs             | 32 +++++++++++++++
 shared/data-provider/src/gcs.rs             |  2 +-
 shared/data-provider/src/lib.rs             |  2 +-
 5 files changed, 77 insertions(+), 4 deletions(-)

diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index bc034d1db..eace5f780 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -204,6 +204,7 @@ impl App {
                 Checkpoint::P2P(_) => {
                     bail!("Can't start up a run with a P2P checkpoint.")
                 }
+                Checkpoint::Gcs(gcs_repo) => todo!(),
             }
 
             let server_addr: SocketAddr = String::from(url).parse().map_err(|e| {
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 05afb13b5..e6942e4d3 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -6,7 +6,8 @@ use psyche_coordinator::{
 use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize};
 use psyche_data_provider::{
     DataProvider, DataProviderTcpClient, DummyDataProvider, PreprocessedDataProvider, Split,
-    WeightedDataProvider, download_dataset_repo_async, download_model_repo_async,
+    WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async,
+    download_model_repo_async,
     http::{FileURLs, HttpDataProvider},
 };
 use psyche_metrics::ClientMetrics;
@@ -310,7 +311,9 @@ impl RunInitConfigAndIO {
-                model::Checkpoint::Hub(_) | model::Checkpoint::P2P(_) => {
+                model::Checkpoint::Hub(_)
+                | model::Checkpoint::P2P(_)
+                | model::Checkpoint::Gcs(_) => {
                     let checkpoint = llm.checkpoint;
                     tokio::spawn(async move {
                         let (source, tokenizer, checkpoint_extra_files) = match checkpoint {
@@ -427,6 +430,43 @@ impl RunInitConfigAndIO {
+                            model::Checkpoint::Gcs(gcs_repo) => {
+                                let bucket: String = (&gcs_repo.bucket).into();
+                                let prefix: Option<String> = gcs_repo.prefix.map(|p| (&p).into());
+
+                                info!(
+                                    "Downloading model from gs://{}/{}",
+                                    bucket,
+                                    prefix.as_deref().unwrap_or("")
+                                );
+
+                                let repo_files = download_model_from_gcs_async(
+                                    &bucket,
+                                    prefix.as_deref(),
+                                    None,
+                                    true,
+                                )
+                                .await;
+
+                                let checkpoint_extra_files = repo_files
+                                    .iter()
+                                    .filter(|file| {
+                                        file.ends_with("config.json")
+                                            || file.ends_with("tokenizer.json")
+                                            || file.ends_with("tokenizer_config.json")
+                                            || file.ends_with("special_tokens_map.json")
+                                            || file.ends_with("generation_config.json")
+                                            || file.ends_with(".py")
+                                    })
+                                    .cloned()
+                                    .collect();
+                                let tokenizer = Arc::new(auto_tokenizer(&repo_files)?);
+                                (
+                                    PretrainedSource::::RepoFiles(repo_files),
+                                    tokenizer,
+                                    checkpoint_extra_files,
+                                )
+                            }
                             _ => unreachable!(),
                         };
diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 083f3ce31..0bd07dfa7 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -238,6 +238,32 @@ impl HubRepo {
     }
 }
 
+#[derive(
+    Clone,
+    Debug,
+    Copy,
+    AnchorDeserialize,
+    AnchorSerialize,
+    InitSpace,
+    Serialize,
+    Deserialize,
+    PartialEq,
+    TS,
+)]
+pub struct GcsRepo {
+    pub bucket: FixedString<{ SOLANA_MAX_STRING_LEN }>,
+    pub prefix: Option>,
+}
+
+impl GcsRepo {
+    pub fn dummy() -> Self {
+        Self {
+            bucket: FixedString::new(),
+            prefix: None,
+        }
+    }
+}
+
 #[derive(
     AnchorSerialize,
     AnchorDeserialize,
@@ -256,6 +282,7 @@ pub enum Checkpoint {
     Dummy(HubRepo),
     Hub(HubRepo),
     P2P(HubRepo),
+    Gcs(GcsRepo),
 }
 
 impl std::fmt::Display for Checkpoint {
@@ -267,6 +294,10 @@ impl std::fmt::Display for Checkpoint {
             Checkpoint::P2P(hub_repo) => {
                 write!(f, "P2P - Hub repo: {}", &hub_repo.repo_id)
             }
+            Checkpoint::Gcs(gcs_repo) => match &gcs_repo.prefix {
+                Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix),
+                None => write!(f, "gs://{}", &gcs_repo.bucket),
+            },
         }
     }
 }
@@ -307,6 +338,7 @@ impl Model {
             Checkpoint::Ephemeral => true,
             Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(),
             Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(),
+            Checkpoint::Gcs(gcs_repo) => gcs_repo.bucket.is_empty(),
         };
 
         if bad_checkpoint {
diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
index dacab92e9..3f3b3f1b2 100644
--- a/shared/data-provider/src/gcs.rs
+++ b/shared/data-provider/src/gcs.rs
@@ -30,7 +30,7 @@ fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf {
     }
 }
 
-async fn download_model_from_gcs_async(
+pub async fn download_model_from_gcs_async(
     bucket: &str,
     prefix: Option<&str>,
     cache_dir: Option<PathBuf>,
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 27b43b3a6..8e0091eaf 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -15,7 +15,7 @@ pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
-pub use gcs::download_model_from_gcs_sync;
+pub use gcs::{download_model_from_gcs_async, download_model_from_gcs_sync};
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
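The Display impl above fixes the canonical textual form as gs://<bucket>[/<prefix>]. A small standalone sketch of the inverse direction (not part of the patch; parse_gs_url is a hypothetical helper) shows how such a string splits back into the two fields:

    /// Hypothetical inverse of the Display impl above: split a
    /// "gs://bucket/prefix" string into (bucket, optional prefix).
    fn parse_gs_url(url: &str) -> Option<(&str, Option<&str>)> {
        let rest = url.strip_prefix("gs://")?;
        match rest.split_once('/') {
            // "gs://bucket/some/prefix" => bucket + prefix
            Some((bucket, prefix)) if !prefix.is_empty() => Some((bucket, Some(prefix))),
            // "gs://bucket/" => bucket only
            Some((bucket, _)) => Some((bucket, None)),
            None => Some((rest, None)),
        }
    }

    #[test]
    fn round_trips() {
        assert_eq!(parse_gs_url("gs://llama220minit"), Some(("llama220minit", None)));
        assert_eq!(
            parse_gs_url("gs://bucket/models/llama"),
            Some(("bucket", Some("models/llama")))
        );
    }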
+++ b/architectures/decentralized/solana-client/src/command/can_join.rs
@@ -106,34 +106,50 @@ pub async fn command_can_join_execute(
         bail!("model is not an LLM, unsure how to predownload.");
     };
 
-    let checkpoint = match model.checkpoint {
+    match model.checkpoint {
         Checkpoint::Ephemeral => {
             bail!("Can't predownload model with ephemeral checkpoint.")
         }
         Checkpoint::Dummy(hub_repo) | Checkpoint::Hub(hub_repo) | Checkpoint::P2P(hub_repo) => {
-            hub_repo
+            let repo_id = hub_repo.repo_id.to_string();
+            let revision = hub_repo.revision.map(|s| s.to_string());
+            println!(
+                "Predownloading model {repo_id} revision {}",
+                revision.as_ref().unwrap_or(&"main".to_string())
+            );
+            let hub_read_token = std::env::var("HF_TOKEN").ok();
+
+            // If you pass None as a cache folder, it'll use the env var `HF_HOME`.
+            let cache_folder = None;
+
+            psyche_data_provider::download_model_repo_async(
+                &repo_id,
+                revision,
+                cache_folder,
+                hub_read_token,
+                Some(hub_max_concurrent_downloads),
+                true,
+            )
+            .await?;
+        }
+        Checkpoint::Gcs(gcs_repo) => {
+            let bucket = gcs_repo.bucket.to_string();
+            let prefix: Option<String> = gcs_repo.prefix.map(|p| p.to_string());
+            println!(
+                "Predownloading model from gs://{}/{}",
+                bucket,
+                prefix.as_deref().unwrap_or("")
+            );
+
+            psyche_data_provider::download_model_from_gcs_async(
+                &bucket,
+                prefix.as_deref(),
+                None,
+                true,
+            )
+            .await;
         }
     };
-    let repo_id = checkpoint.repo_id.to_string();
-    let revision = checkpoint.revision.map(|s| s.to_string());
-    println!(
-        "Predownloading model {repo_id} revision {}",
-        revision.as_ref().unwrap_or(&"main".to_string())
-    );
-    let hub_read_token = std::env::var("HF_TOKEN").ok();
-
-    // If you pass None as a cache folder, it'll use the env var `HF_HOME`.
-    let cache_folder = None;
-
-    psyche_data_provider::download_model_repo_async(
-        &repo_id,
-        revision,
-        cache_folder,
-        hub_read_token,
-        Some(hub_max_concurrent_downloads),
-        true,
-    )
-    .await?;
 
     println!("Model predownloaded successfully.")
 }
From c1f0e1f7789d456c8abb019f0e1c4ba12770c82e Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Wed, 7 Jan 2026 11:38:05 -0800
Subject: [PATCH 06/24] gcs centralized version

---
 architectures/centralized/server/src/app.rs | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index eace5f780..ac1a86202 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -9,7 +9,8 @@ use psyche_coordinator::{
 use psyche_core::{FixedVec, Shuffle, SizedIterator, TokenSize};
 use psyche_data_provider::{
-    DataProviderTcpServer, DataServerTui, LocalDataProvider, download_model_repo_async,
+    DataProviderTcpServer, DataServerTui, LocalDataProvider, download_model_from_gcs_async,
+    download_model_repo_async,
 };
 use psyche_network::{ClientNotification, TcpServer};
 use psyche_tui::{
@@ -204,7 +205,18 @@ impl App {
             Checkpoint::P2P(_) => {
                 bail!("Can't start up a run with a P2P checkpoint.")
             }
-            Checkpoint::Gcs(gcs_repo) => todo!(),
+            Checkpoint::Gcs(gcs_repo) => {
+                let bucket: String = (&gcs_repo.bucket).into();
+                let prefix: Option<String> =
+                    gcs_repo.prefix.map(|p| (&p).into());
+                download_model_from_gcs_async(
+                    &bucket,
+                    prefix.as_deref(),
+                    None,
+                    true,
+                )
+                .await;
+            }
         }
 
         let server_addr: SocketAddr = String::from(url).parse().map_err(|e| {
From 6ddfe52799ded31cfc1694ec02c2bfa86690ab03 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Wed, 7 Jan 2026 12:00:44 -0800
Subject: [PATCH 07/24] remove fn download_model_from_gcs_sync

---
 shared/data-provider/src/gcs.rs  | 16 ----------------
 shared/data-provider/src/lib.rs  |  2 +-
 shared/eval/examples/evaluate.rs | 12 ++++++++++--
 3 files changed, 11 insertions(+), 19 deletions(-)

diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
index 3f3b3f1b2..09300edbd 100644
--- a/shared/data-provider/src/gcs.rs
+++ b/shared/data-provider/src/gcs.rs
@@ -3,7 +3,6 @@ use google_cloud_storage::http::objects::{
     download::Range, get::GetObjectRequest, list::ListObjectsRequest,
 };
 use std::path::PathBuf;
-use tokio::runtime::Runtime;
 use tracing::info;
 
 const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
@@ -132,18 +131,3 @@ pub async fn download_model_from_gcs_async(
 
     downloaded_files
 }
-
-pub fn download_model_from_gcs_sync(
-    bucket: &str,
-    prefix: Option<&str>,
-    cache_dir: Option<PathBuf>,
-    progress_bar: bool,
-) -> Vec<PathBuf> {
-    let rt = Runtime::new().unwrap();
-    rt.block_on(download_model_from_gcs_async(
-        bucket,
-        prefix,
-        cache_dir,
-        progress_bar,
-    ))
-}
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 8e0091eaf..b9a80d68a 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -15,7 +15,7 @@ pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
-pub use gcs::{download_model_from_gcs_async, download_model_from_gcs_sync};
+pub use gcs::download_model_from_gcs_async;
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index e4cd1ebaa..9652609b1 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use clap::Parser;
 use indicatif::{ProgressBar, ProgressStyle};
 use psyche_core::RunningAverage;
-use psyche_data_provider::{download_model_from_gcs_sync, download_model_repo_sync};
+use psyche_data_provider::{download_model_from_gcs_async, download_model_repo_sync};
 use psyche_eval::{
     ALL_TASK_NAMES, EvalTaskOptions, Task, progress_bar_template_with_task, tasktype_from_name,
 };
@@ -125,7 +125,15 @@ fn main() -> Result<()> {
     }
 
     let repo = if let Some(ref bucket) = args.gcs_bucket {
-        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref(), None, true)
+        // download model from GCS sync
+        tokio::runtime::Runtime::new()
+            .unwrap()
+            .block_on(download_model_from_gcs_async(
+                bucket,
+                args.gcs_prefix.as_deref(),
+                None,
+                true,
+            ))
     } else {
         download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?
     };
From c0aa10f0638bc86aaf098d0eeb10e728da5d40a7 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Wed, 7 Jan 2026 12:10:33 -0800
Subject: [PATCH 08/24] Revert "remove fn download_model_from_gcs_sync"

This reverts commit 6ddfe52799ded31cfc1694ec02c2bfa86690ab03.

---
 shared/data-provider/src/gcs.rs  | 16 ++++++++++++++++
 shared/data-provider/src/lib.rs  |  2 +-
 shared/eval/examples/evaluate.rs | 12 ++----------
 3 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
index 09300edbd..3f3b3f1b2 100644
--- a/shared/data-provider/src/gcs.rs
+++ b/shared/data-provider/src/gcs.rs
@@ -3,6 +3,7 @@ use google_cloud_storage::http::objects::{
     download::Range, get::GetObjectRequest, list::ListObjectsRequest,
 };
 use std::path::PathBuf;
+use tokio::runtime::Runtime;
 use tracing::info;
 
 const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
@@ -131,3 +132,18 @@ pub async fn download_model_from_gcs_async(
 
     downloaded_files
 }
+
+pub fn download_model_from_gcs_sync(
+    bucket: &str,
+    prefix: Option<&str>,
+    cache_dir: Option<PathBuf>,
+    progress_bar: bool,
+) -> Vec<PathBuf> {
+    let rt = Runtime::new().unwrap();
+    rt.block_on(download_model_from_gcs_async(
+        bucket,
+        prefix,
+        cache_dir,
+        progress_bar,
+    ))
+}
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index b9a80d68a..8e0091eaf 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -15,7 +15,7 @@ pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
-pub use gcs::download_model_from_gcs_async;
+pub use gcs::{download_model_from_gcs_async, download_model_from_gcs_sync};
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index 9652609b1..e4cd1ebaa 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -2,7 +2,7 @@ use anyhow::Result;
 use clap::Parser;
 use indicatif::{ProgressBar, ProgressStyle};
 use psyche_core::RunningAverage;
-use psyche_data_provider::{download_model_from_gcs_async, download_model_repo_sync};
+use psyche_data_provider::{download_model_from_gcs_sync, download_model_repo_sync};
 use psyche_eval::{
     ALL_TASK_NAMES, EvalTaskOptions, Task, progress_bar_template_with_task, tasktype_from_name,
 };
@@ -125,15 +125,7 @@ fn main() -> Result<()> {
     }
 
     let repo = if let Some(ref bucket) = args.gcs_bucket {
-        // download model from GCS sync
-        tokio::runtime::Runtime::new()
-            .unwrap()
-            .block_on(download_model_from_gcs_async(
-                bucket,
-                args.gcs_prefix.as_deref(),
-                None,
-                true,
-            ))
+        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref(), None, true)
     } else {
         download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?
     };
From 3b707ac37d39fa5e06d63bf53f7ee1df40cabfc0 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Wed, 7 Jan 2026 12:37:32 -0800
Subject: [PATCH 09/24] handle errors

---
 architectures/centralized/server/src/app.rs       |  2 +-
 .../solana-client/src/command/can_join.rs         |  2 +-
 shared/client/src/state/init.rs                   |  9 +++--
 shared/data-provider/src/gcs.rs                   | 33 ++++++++++++-------
 shared/data-provider/src/lib.rs                   |  2 +-
 shared/eval/examples/evaluate.rs                  |  2 +-
 6 files changed, 32 insertions(+), 18 deletions(-)

diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index ac1a86202..6ad9337d2 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -215,7 +215,7 @@ impl App {
                     None,
                     true,
                 )
-                .await;
+                .await?;
             }
         }
 
diff --git a/architectures/decentralized/solana-client/src/command/can_join.rs b/architectures/decentralized/solana-client/src/command/can_join.rs
index ccb2b00b5..442b1aa7b 100644
--- a/architectures/decentralized/solana-client/src/command/can_join.rs
+++ b/architectures/decentralized/solana-client/src/command/can_join.rs
@@ -147,7 +147,7 @@ pub async fn command_can_join_execute(
                 None,
                 true,
             )
-            .await;
+            .await?;
         }
     };
 
     println!("Model predownloaded successfully.")
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index e6942e4d3..18fafe807 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -5,8 +5,8 @@ use psyche_coordinator::{
 };
 use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize};
 use psyche_data_provider::{
-    DataProvider, DataProviderTcpClient, DummyDataProvider, PreprocessedDataProvider, Split,
-    WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async,
+    DataProvider, DataProviderTcpClient, DummyDataProvider, GcsError, PreprocessedDataProvider,
+    Split, WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async,
     download_model_repo_async,
     http::{FileURLs, HttpDataProvider},
 };
@@ -90,6 +90,9 @@ pub enum InitRunError {
     #[error("failed to read HF model info: {0}")]
     HfModelLoad(#[from] hf_hub::api::tokio::ApiError),
 
+    #[error("failed to download model from GCS: {0}")]
+    GcsModelLoad(#[from] GcsError),
+
     #[error("model loading thread crashed")]
     ModelLoadingThreadCrashed(JoinError),
 
@@ -446,7 +449,7 @@ impl RunInitConfigAndIO bool {
@@ -35,13 +48,13 @@ pub async fn download_model_from_gcs_async(
     prefix: Option<&str>,
     cache_dir: Option<PathBuf>,
     progress_bar: bool,
-) -> Vec<PathBuf> {
+) -> Result<Vec<PathBuf>, GcsError> {
     // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous
     let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
         if progress_bar {
             info!("Using authenticated GCS client");
         }
-        ClientConfig::default().with_auth().await.unwrap()
+        ClientConfig::default().with_auth().await?
     } else {
         if progress_bar {
             info!("Using anonymous GCS client");
         }
@@ -62,8 +75,7 @@ pub async fn download_model_from_gcs_async(
                 page_token: maybe_next_page_token,
                 ..Default::default()
             })
-            .await
-            .unwrap();
+            .await?;
 
         for obj in results.items.iter().flatten() {
             if check_model_extension(&obj.name) {
@@ -85,7 +97,7 @@ pub async fn download_model_from_gcs_async(
 
     // Determine cache directory
     let cache_dir = cache_dir.unwrap_or_else(|| get_cache_dir(bucket, prefix));
-    std::fs::create_dir_all(&cache_dir).unwrap();
+    std::fs::create_dir_all(&cache_dir)?;
 
     let mut downloaded_files = Vec::new();
 
@@ -118,11 +130,10 @@ pub async fn download_model_from_gcs_async(
                 },
                 &Range::default(),
             )
-            .await
-            .unwrap();
+            .await?;
 
         // Write to cache
-        std::fs::write(&local_path, &data).unwrap();
+        std::fs::write(&local_path, &data)?;
 
         if progress_bar {
             info!("Downloaded: {} ({} bytes)", filename, data.len());
@@ -131,7 +142,7 @@ pub async fn download_model_from_gcs_async(
         downloaded_files.push(local_path);
     }
 
-    downloaded_files
+    Ok(downloaded_files)
 }
 
 pub fn download_model_from_gcs_sync(
@@ -139,8 +150,8 @@ pub fn download_model_from_gcs_sync(
     prefix: Option<&str>,
     cache_dir: Option<PathBuf>,
     progress_bar: bool,
-) -> Vec<PathBuf> {
-    let rt = Runtime::new().unwrap();
+) -> Result<Vec<PathBuf>, GcsError> {
+    let rt = Runtime::new().map_err(GcsError::Io)?;
     rt.block_on(download_model_from_gcs_async(
         bucket,
         prefix,
diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs
index 8e0091eaf..7b97fe101 100644
--- a/shared/data-provider/src/lib.rs
+++ b/shared/data-provider/src/lib.rs
@@ -15,7 +15,7 @@ pub use data_provider::DataProvider;
 pub use dataset::{Dataset, Field, Row, Split};
 pub use dummy::DummyDataProvider;
 pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION};
-pub use gcs::{download_model_from_gcs_async, download_model_from_gcs_sync};
+pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_sync};
 pub use hub::{
     UploadModelError, download_dataset_repo_async, download_dataset_repo_sync,
     download_model_repo_async, download_model_repo_sync, upload_model_repo_async,
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index e4cd1ebaa..927a635c7 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -125,7 +125,7 @@ fn main() -> Result<()> {
     }
 
     let repo = if let Some(ref bucket) = args.gcs_bucket {
-        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref(), None, true)
+        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref(), None, true)?
     } else {
         download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?
     };
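The hunk that adds the GcsError enum itself was garbled above, so its exact definition is not recoverable from this document. A plausible thiserror-based shape, inferred only from the `?` operators and the `GcsError::Io` constructor used in the diff (the exact variant names and payload paths are assumptions):

    use thiserror::Error;

    // Sketch of what the patch's GcsError likely looks like; payload types
    // are inferred, not copied from the (lost) hunk.
    #[derive(Debug, Error)]
    pub enum GcsError {
        // `list_objects` / `download_object` return the crate's HTTP error.
        #[error("GCS request failed: {0}")]
        Http(#[from] google_cloud_storage::http::Error),

        // `ClientConfig::with_auth()` can fail while loading credentials;
        // the exact error path in google-cloud-storage is an assumption here.
        #[error("GCS auth failed: {0}")]
        Auth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error),

        // Cache-directory creation, file writes, and `Runtime::new()`.
        #[error("io error: {0}")]
        Io(#[from] std::io::Error),
    }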
index 2ca006479..d683b8451 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -209,12 +209,7 @@ impl App {
                 let bucket: String = (&gcs_repo.bucket).into();
                 let prefix: Option<String> =
                     gcs_repo.prefix.map(|p| (&p).into());
-                download_model_from_gcs_async(
-                    &bucket,
-                    prefix.as_deref(),
-                    None,
-                )
-                .await?;
+                download_model_from_gcs_async(&bucket, prefix.as_deref()).await?;
             }
 
diff --git a/architectures/decentralized/solana-client/src/command/can_join.rs b/architectures/decentralized/solana-client/src/command/can_join.rs
index be7a4c91a..7c4b889c3 100644
--- a/architectures/decentralized/solana-client/src/command/can_join.rs
+++ b/architectures/decentralized/solana-client/src/command/can_join.rs
@@ -144,13 +144,8 @@ pub async fn command_can_join_execute(
                 prefix.as_deref().unwrap_or("")
             );
 
-            psyche_data_provider::download_model_from_gcs_async(
-                &bucket,
-                prefix.as_deref(),
-                None,
-                true,
-            )
-            .await?;
+            psyche_data_provider::download_model_from_gcs_async(&bucket, prefix.as_deref())
+                .await?;
         }
     };
 
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 18fafe807..80f592864 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -443,13 +443,9 @@ impl RunInitConfigAndIO) -> PathBuf {
 pub async fn download_model_from_gcs_async(
     bucket: &str,
     prefix: Option<&str>,
-    cache_dir: Option<PathBuf>,
 ) -> Result<Vec<PathBuf>, GcsError> {
@@ -88,8 +87,7 @@ pub async fn download_model_from_gcs_async(
         prefix.unwrap_or("")
     );
 
-    // Determine cache directory
-    let cache_dir = cache_dir.unwrap_or_else(|| get_cache_dir(bucket, prefix));
+    let cache_dir = get_cache_dir(bucket, prefix);
     std::fs::create_dir_all(&cache_dir)?;
 
     let mut downloaded_files = Vec::new();
@@ -135,9 +133,7 @@ pub async fn download_model_from_gcs_async(
 pub fn download_model_from_gcs_sync(
     bucket: &str,
     prefix: Option<&str>,
-    cache_dir: Option<PathBuf>,
-    progress_bar: bool,
 ) -> Result<Vec<PathBuf>, GcsError> {
     let rt = Runtime::new().map_err(GcsError::Io)?;
-    rt.block_on(download_model_from_gcs_async(bucket, prefix, cache_dir))
+    rt.block_on(download_model_from_gcs_async(bucket, prefix))
 }
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index 927a635c7..c8c8e5875 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -125,7 +125,7 @@ fn main() -> Result<()> {
     }
 
     let repo = if let Some(ref bucket) = args.gcs_bucket {
-        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref())?
+        download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref())?
     } else {
         download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?
     };
From 1bf03ab69cfdd6d1827f92870877b2d2590d73bd Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Wed, 7 Jan 2026 13:55:17 -0800
Subject: [PATCH 12/24] refactor loop

---
 shared/data-provider/src/gcs.rs | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
index 5bc406a9b..96d3def38 100644
--- a/shared/data-provider/src/gcs.rs
+++ b/shared/data-provider/src/gcs.rs
@@ -59,14 +59,14 @@ pub async fn download_model_from_gcs_async(
     // List all objects in the bucket with optional prefix
     let mut all_objects = vec![];
-    let mut next_page_token: Option<Option<String>> = Some(None);
+    let mut page_token: Option<String> = None;
 
-    while let Some(maybe_next_page_token) = next_page_token {
+    loop {
         let results = client
             .list_objects(&ListObjectsRequest {
                 bucket: bucket.to_owned(),
                 prefix: prefix.map(|s| s.to_owned()),
-                page_token: maybe_next_page_token,
+                page_token: page_token.clone(),
                 ..Default::default()
             })
             .await?;
@@ -77,7 +77,10 @@ pub async fn download_model_from_gcs_async(
             }
         }
 
-        next_page_token = results.next_page_token.map(Some);
+        match results.next_page_token {
+            Some(token) => page_token = Some(token),
+            None => break,
+        }
     }
 
From 39fd70e93b5cd7a7fb12fe849dd87525275c7ee1 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Thu, 8 Jan 2026 08:53:28 -0800
Subject: [PATCH 13/24] rm custom PSYCHE_CACHE_DIR

---
 shared/data-provider/src/gcs.rs | 10 +++-------
 1 file changed, 3 insertions(+), 7 deletions(-)

diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs
index 96d3def38..cb6848c1d 100644
--- a/shared/data-provider/src/gcs.rs
+++ b/shared/data-provider/src/gcs.rs
@@ -26,13 +26,9 @@ fn check_model_extension(filename: &str) -> bool {
 }
 
 fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf {
-    let base = std::env::var("PSYCHE_CACHE_DIR")
-        .map(PathBuf::from)
-        .unwrap_or_else(|_| {
-            std::env::var("HOME")
-                .map(|h| PathBuf::from(h).join(".cache"))
-                .unwrap_or_else(|_| PathBuf::from(".cache"))
-        })
+    let base = std::env::var("HOME")
+        .map(|h| PathBuf::from(h).join(".cache"))
+        .unwrap_or_else(|_| PathBuf::from(".cache"))
         .join("psyche")
         .join("gcs")
         .join(bucket);
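As an illustration of the now-fixed cache layout (this snippet is a standalone copy of the patched get_cache_dir logic, not code from the repository):

    use std::path::PathBuf;

    // Mirrors the patched `get_cache_dir`: ~/.cache/psyche/gcs/<bucket>/<prefix>
    fn cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf {
        let base = std::env::var("HOME")
            .map(|h| PathBuf::from(h).join(".cache"))
            .unwrap_or_else(|_| PathBuf::from(".cache"))
            .join("psyche")
            .join("gcs")
            .join(bucket);
        match prefix {
            Some(p) => base.join(p.trim_end_matches('/')),
            None => base,
        }
    }

    fn main() {
        // With HOME=/home/me this prints /home/me/.cache/psyche/gcs/llama220minit/ckpt-100
        println!("{}", cache_dir("llama220minit", Some("ckpt-100/")).display());
    }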
From 712f7cb00b3dba2b6a84eb1da8393417e4b220d5 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Thu, 8 Jan 2026 09:37:29 -0800
Subject: [PATCH 14/24] GcsRepo prefix: Option>

---
 shared/coordinator/src/model.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 0bd07dfa7..974f6349c 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -252,7 +252,7 @@ impl HubRepo {
 )]
 pub struct GcsRepo {
     pub bucket: FixedString<{ SOLANA_MAX_STRING_LEN }>,
-    pub prefix: Option>,
+    pub prefix: Option>,
 }
 
 impl GcsRepo {
     pub fn dummy() -> Self {
From 2296e3faf73cb635b4f84e0d05540ac62e6300c5 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Thu, 8 Jan 2026 14:13:48 -0800
Subject: [PATCH 15/24] implement P2PGCS checkpoint

---
 architectures/centralized/server/src/app.rs             | 2 +-
 .../solana-client/src/command/can_join.rs               | 2 +-
 shared/client/src/state/init.rs                         | 3 ++-
 shared/coordinator/src/coordinator.rs                   | 7 +++++--
 shared/coordinator/src/model.rs                         | 7 +++++--
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index d683b8451..6707b49bb 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -202,7 +202,7 @@ impl App {
             Checkpoint::Dummy(_) => {
                 // ok!
             }
-            Checkpoint::P2P(_) => {
+            Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => {
                 bail!("Can't start up a run with a P2P checkpoint.")
             }
             Checkpoint::Gcs(gcs_repo) => {
diff --git a/architectures/decentralized/solana-client/src/command/can_join.rs b/architectures/decentralized/solana-client/src/command/can_join.rs
index 7c4b889c3..ee1da6d30 100644
--- a/architectures/decentralized/solana-client/src/command/can_join.rs
+++ b/architectures/decentralized/solana-client/src/command/can_join.rs
@@ -135,7 +135,7 @@ pub async fn command_can_join_execute(
             )
             .await?;
         }
-        Checkpoint::Gcs(gcs_repo) => {
+        Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => {
             let bucket = gcs_repo.bucket.to_string();
             let prefix: Option<String> = gcs_repo.prefix.map(|p| p.to_string());
             println!(
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index 80f592864..54073d3b2 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -316,6 +316,7 @@ impl RunInitConfigAndIO {
                 model::Checkpoint::Hub(_)
                 | model::Checkpoint::P2P(_)
+                | model::Checkpoint::P2PGcs(_)
                 | model::Checkpoint::Gcs(_) => {
                     let checkpoint = llm.checkpoint;
                     tokio::spawn(async move {
@@ -372,7 +373,7 @@ impl RunInitConfigAndIO {
-                            model::Checkpoint::P2P(_) => {
+                            model::Checkpoint::P2P(_) | model::Checkpoint::P2PGcs(_) => {
                                 let (tx_model_config_response, rx_model_config_response) =
                                     oneshot::channel();
                                 info!("Checkpoint is p2p, requesting model config over network");
diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs
index a9ac66c85..0a3a9975d 100644
--- a/shared/coordinator/src/coordinator.rs
+++ b/shared/coordinator/src/coordinator.rs
@@ -912,8 +912,10 @@ impl Coordinator {
             .any(|client| pending_clients_unordered.contains(&client.id));
         if all_prev_clients_disconnected {
             let Model::LLM(llm) = &mut self.model;
-            if let Checkpoint::P2P(hub_repo) = llm.checkpoint {
-                llm.checkpoint = Checkpoint::Hub(hub_repo);
+            match llm.checkpoint {
+                Checkpoint::P2P(hub_repo) => llm.checkpoint = Checkpoint::Hub(hub_repo),
+                Checkpoint::P2PGcs(gcs_repo) => llm.checkpoint = Checkpoint::Gcs(gcs_repo),
+                _ => {}
             }
         }
 
@@ -1045,6 +1047,7 @@ impl Coordinator {
             Checkpoint::Hub(hub_repo) | Checkpoint::Dummy(hub_repo) => {
                 llm.checkpoint = Checkpoint::P2P(hub_repo)
             }
+            Checkpoint::Gcs(gcs_repo) => llm.checkpoint = Checkpoint::P2PGcs(gcs_repo),
             _ => {}
         }
 
diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 974f6349c..87cdf5456 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -283,6 +283,7 @@ pub enum Checkpoint {
     Hub(HubRepo),
     P2P(HubRepo),
     Gcs(GcsRepo),
+    P2PGcs(GcsRepo),
 }
 
 impl std::fmt::Display for Checkpoint {
@@ -294,7 +295,7 @@ impl std::fmt::Display for Checkpoint {
             Checkpoint::P2P(hub_repo) => {
                 write!(f, "P2P - Hub repo: {}", &hub_repo.repo_id)
             }
-            Checkpoint::Gcs(gcs_repo) => match &gcs_repo.prefix {
+            Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix {
                 Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix),
                 None => write!(f, "gs://{}", &gcs_repo.bucket),
             },
@@ -338,7 +339,9 @@ impl Model {
             Checkpoint::Ephemeral => true,
             Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(),
             Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(),
-            Checkpoint::Gcs(gcs_repo) => gcs_repo.bucket.is_empty(),
+            Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => {
+                gcs_repo.bucket.is_empty()
+            }
         };
 
         if bad_checkpoint {
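The two coordinator.rs hunks above define a round-trip between the "static" and P2P forms of each checkpoint source. A condensed sketch of that state machine (illustrative only; to_p2p and to_static are hypothetical free functions, not names from the patch):

    use psyche_coordinator::model::Checkpoint;

    // When an epoch starts serving over P2P, the static source is remembered
    // inside the P2P variant so it can be restored later.
    fn to_p2p(c: Checkpoint) -> Checkpoint {
        match c {
            Checkpoint::Hub(repo) | Checkpoint::Dummy(repo) => Checkpoint::P2P(repo),
            Checkpoint::Gcs(repo) => Checkpoint::P2PGcs(repo),
            other => other,
        }
    }

    // If every previous client disconnects, fall back to the original source.
    fn to_static(c: Checkpoint) -> Checkpoint {
        match c {
            Checkpoint::P2P(repo) => Checkpoint::Hub(repo),
            Checkpoint::P2PGcs(repo) => Checkpoint::Gcs(repo),
            other => other,
        }
    }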
From d1324e087bd4563ca7bcbac92a92f16aa9b4af86 Mon Sep 17 00:00:00 2001
From: Pedro Fontana
Date: Thu, 8 Jan 2026 14:45:35 -0800
Subject: [PATCH 16/24] Add tracing to evaluate crate

---
 Cargo.lock                       | 1 +
 shared/eval/Cargo.toml           | 1 +
 shared/eval/examples/evaluate.rs | 5 ++++-
 3 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/Cargo.lock b/Cargo.lock
index 7553c18b5..554aaaad8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7101,6 +7101,7 @@ dependencies = [
  "tokio-util 0.7.16",
  "torch-sys",
  "tracing",
+ "tracing-subscriber",
 ]
 
 [[package]]
diff --git a/shared/eval/Cargo.toml b/shared/eval/Cargo.toml
index 1b1cfdaab..18b9d76bd 100644
--- a/shared/eval/Cargo.toml
+++ b/shared/eval/Cargo.toml
@@ -28,3 +28,4 @@ tokio-util.workspace = true
 [dev-dependencies]
 clap.workspace = true
 psyche-python-extension-impl.workspace = true
+tracing-subscriber = "0.3"
diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs
index c8c8e5875..b344abd0b 100644
--- a/shared/eval/examples/evaluate.rs
+++ b/shared/eval/examples/evaluate.rs
@@ -13,6 +13,7 @@ use std::sync::Arc;
 use std::thread::JoinHandle;
 use tch::{Device, Kind};
 use tokenizers::Tokenizer;
+use tracing::Level;
 
 #[derive(Parser, Debug, Clone)]
 struct Args {
@@ -61,6 +62,7 @@ struct Args {
 }
 
 fn main() -> Result<()> {
+    tracing_subscriber::fmt().with_max_level(Level::INFO).init();
     let args = Args::parse();
     let tasks: Result<Vec<Task>> = args
         .tasks
@@ -93,12 +95,13 @@ fn main() -> Result<()> {
         args.model.clone()
     };
     println!(
-        "Running tasks with model {}, seed: {}, DP={}{}",
+        "\n\n Running tasks with model {}, seed: {}, DP={}{}",
         model_source, args.seed, args.data_parallelism, limit_str
     );
     for task in &tasks {
         println!("  - {}: {} few-shot examples", task, task.num_fewshot);
     }
+    println!("\n\n");
     }
 
     if args.data_parallelism > 1 {
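A common alternative to the hard-coded INFO level above is an env-filter-driven setup — a sketch only, assuming tracing-subscriber's "env-filter" feature is enabled (it is not enabled by the Cargo.toml line in this patch):

    use tracing_subscriber::EnvFilter;

    fn main() {
        // Respect RUST_LOG when set, default to "info" otherwise.
        tracing_subscriber::fmt()
            .with_env_filter(
                EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
            )
            .init();
    }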
From 037cc0859c34800d1cd9c7e7bf249f52b68b Mon Sep 17 00:00:00 2001
From: Nacho Avecilla
Date: Wed, 14 Jan 2026 15:19:02 -0300
Subject: [PATCH 17/24] Model checkpoint upload with GCS (#476)

Co-authored-by: Pedro Fontana
---
 Cargo.lock                                    |   1 +
 architectures/centralized/client/src/app.rs   |  21 +-
 architectures/centralized/server/src/app.rs   |   2 +-
 .../centralized/shared/src/protocol.rs        |   2 +-
 .../solana-client/src/backend.rs              |   7 +-
 .../solana-client/src/command/checkpoint.rs   |   3 +-
 .../solana-client/src/instructions.rs         |   2 +-
 .../solana-coordinator/src/instance_state.rs  |   8 +-
 .../programs/solana-coordinator/src/lib.rs    |   4 +-
 nix/lib.nix                                   |   1 +
 shared/client/src/cli.rs                      | 107 +++--
 shared/client/src/lib.rs                      |   3 +-
 shared/client/src/state/cooldown.rs           | 173 ++++----
 shared/client/src/state/init.rs               |  10 +-
 shared/client/src/state/mod.rs                |   3 +-
 shared/client/src/state/types.rs              |   9 +-
 shared/coordinator/src/coordinator.rs         |  41 +-
 shared/data-provider/Cargo.toml               |   3 +-
 shared/data-provider/src/errors.rs            |  53 +++
 shared/data-provider/src/gcs.rs               | 376 ++++++++++++++++--
 shared/data-provider/src/hub.rs               | 101 ++---
 shared/data-provider/src/lib.rs               |  11 +-
 shared/watcher/src/traits.rs                  |   2 +-
 23 files changed, 689 insertions(+), 254 deletions(-)
 create mode 100644 shared/data-provider/src/errors.rs

diff --git a/Cargo.lock b/Cargo.lock
index a9eae255c..3792ad852 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -7015,6 +7015,7 @@ dependencies = [
 "anyhow",
 "async-trait",
 "bytemuck",
+ "chrono",
 "clap",
 "futures",
 "google-cloud-storage",
diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs
index 3560e0a7a..8f6a9bcca 100644
--- a/architectures/centralized/client/src/app.rs
+++ b/architectures/centralized/client/src/app.rs
@@ -2,6 +2,8 @@ use anyhow::{Error, Result};
 use bytemuck::Zeroable;
 use hf_hub::Repo;
 use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage};
+use psyche_client::HubUploadInfo;
+use psyche_client::UploadInfo;
 use psyche_client::{
     Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key,
 };
@@ -29,7 +31,7 @@ pub type TabsData = ::Data;
 pub enum ToSend {
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }
 
 struct Backend {
@@ -67,7 +69,7 @@ impl WatcherBackend for Backend {
         Ok(())
     }
 
-    async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.tx.send(ToSend::Checkpoint(checkpoint))?;
         Ok(())
     }
@@ -173,18 +175,19 @@ impl App {
     ) -> Result<()> {
         // sanity checks
         if let Some(checkpoint_config) = &state_options.checkpoint_config {
-            if let Some(hub_upload) = &checkpoint_config.hub_upload {
+            if let Some(UploadInfo::Hub(HubUploadInfo {
+                hub_repo,
+                hub_token,
+            })) = &checkpoint_config.upload_info
+            {
                 let api = hf_hub::api::tokio::ApiBuilder::new()
-                    .with_token(Some(hub_upload.hub_token.clone()))
+                    .with_token(Some(hub_token.clone()))
                     .build()?;
-                let repo_api = api.repo(Repo::new(
-                    hub_upload.hub_repo.clone(),
-                    hf_hub::RepoType::Model,
-                ));
+                let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model));
                 if !repo_api.is_writable().await {
                     anyhow::bail!(
                         "Checkpoint upload repo {} is not writable with the passed API key.",
-                        hub_upload.hub_repo
+                        hub_repo
                     )
                 }
             }
diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs
index 6707b49bb..402146817 100644
--- a/architectures/centralized/server/src/app.rs
+++ b/architectures/centralized/server/src/app.rs
@@ -81,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend {
         bail!("Server does not send health checks");
     }
 
-    async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> {
         bail!("Server does not send checkpoints");
     }
 }
diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs
index c71c7524f..a96c64b8f 100644
--- a/architectures/centralized/shared/src/protocol.rs
+++ b/architectures/centralized/shared/src/protocol.rs
@@ -15,7 +15,7 @@ pub enum ClientToServerMessage {
     Join { run_id: String },
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }
 
 #[derive(Serialize, Deserialize, Debug, Clone)]
diff --git a/architectures/decentralized/solana-client/src/backend.rs b/architectures/decentralized/solana-client/src/backend.rs
index 439c880a6..24205a334 100644
--- a/architectures/decentralized/solana-client/src/backend.rs
+++ b/architectures/decentralized/solana-client/src/backend.rs
@@ -20,7 +20,8 @@ use anchor_client::{
 use anyhow::{Context, Result, anyhow};
 use futures_util::StreamExt;
 use psyche_client::IntegrationTestLogMarker;
-use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
+use psyche_coordinator::model::{self, Checkpoint};
+use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks};
 use psyche_watcher::{Backend as WatcherBackend, OpportunisticData};
 use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding};
 use solana_transaction_status_client_types::UiTransactionEncoding;
@@ -333,7 +334,7 @@ impl SolanaBackend {
         &self,
         coordinator_instance: Pubkey,
         coordinator_account: Pubkey,
-        repo: HubRepo,
+        repo: Checkpoint,
     ) {
         let user = self.get_payer();
         let instruction = instructions::coordinator_checkpoint(
@@ -603,7 +604,7 @@ impl WatcherBackend for SolanaBackendRunner
         Ok(())
     }
 
-    async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> {
+    async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> {
         self.backend
             .send_checkpoint(self.instance, self.account, checkpoint);
         Ok(())
diff --git a/architectures/decentralized/solana-client/src/command/checkpoint.rs b/architectures/decentralized/solana-client/src/command/checkpoint.rs
index 73adc2105..3ee098640 100644
--- a/architectures/decentralized/solana-client/src/command/checkpoint.rs
+++ b/architectures/decentralized/solana-client/src/command/checkpoint.rs
@@ -1,5 +1,6 @@
 use anyhow::Result;
 use clap::Args;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::HubRepo;
 use psyche_core::FixedString;
 
@@ -45,7 +46,7 @@ pub async fn command_checkpoint_execute(
         &coordinator_instance,
         &coordinator_account,
         &user,
-        repo,
+        Checkpoint::Hub(repo),
     );
     let signature = backend
         .send_and_retry("Checkpoint", &[instruction], &[])
diff --git a/architectures/decentralized/solana-client/src/instructions.rs b/architectures/decentralized/solana-client/src/instructions.rs
index 68e169008..b535d54f4 100644
--- a/architectures/decentralized/solana-client/src/instructions.rs
+++ b/architectures/decentralized/solana-client/src/instructions.rs
@@ -206,7 +206,7 @@ pub fn coordinator_checkpoint(
     coordinator_instance: &Pubkey,
     coordinator_account: &Pubkey,
     user: &Pubkey,
-    repo: psyche_coordinator::model::HubRepo,
+    repo: psyche_coordinator::model::Checkpoint,
 ) -> Instruction {
     anchor_instruction(
         psyche_solana_coordinator::ID,
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
index 029b40fad..8d9ddb977 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs
@@ -10,7 +10,7 @@ use psyche_coordinator::RunState;
 use psyche_coordinator::SOLANA_MAX_STRING_LEN;
 use psyche_coordinator::TickResult;
 use psyche_coordinator::Witness;
-use psyche_coordinator::model::HubRepo;
+use psyche_coordinator::model::Checkpoint;
 use psyche_coordinator::model::Model;
 use psyche_core::FixedString;
 use psyche_core::SmallBoolean;
@@ -389,7 +389,11 @@ impl CoordinatorInstanceState {
         self.tick()
     }
 
-    pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> {
+    pub fn checkpoint(
+        &mut self,
+        payer: &Pubkey,
+        repo: Checkpoint,
+    ) -> Result<()> {
         // O(n) on clients, reconsider
         let id = self.clients_state.find_signer(payer)?;
         let index = self
diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
index 657424195..0a041a6e9 100644
--- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
+++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs
@@ -21,7 +21,7 @@ use psyche_coordinator::Witness;
 use psyche_coordinator::WitnessBloom;
 use psyche_coordinator::WitnessMetadata;
 use psyche_coordinator::WitnessProof;
-use psyche_coordinator::model::{HubRepo, Model};
+use psyche_coordinator::model::Model;
 use psyche_core::MerkleRoot;
 use serde::Deserialize;
 use serde::Serialize;
@@ -313,7 +313,7 @@ pub mod psyche_solana_coordinator {
 
     pub fn checkpoint(
         ctx: Context,
-        repo: HubRepo,
+        repo: psyche_coordinator::model::Checkpoint,
     ) -> Result<()> {
         let mut account = ctx.accounts.coordinator_account.load_mut()?;
         account.increment_nonce();
diff --git a/nix/lib.nix b/nix/lib.nix
index 882f17040..37ab6d844 100644
--- a/nix/lib.nix
+++ b/nix/lib.nix
@@ -36,6 +36,7 @@ let
       python312
       pkg-config
      perl
+      cargo-nextest
     ];
 
   buildInputs =
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 268ea753a..a4ef145f0 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -1,7 +1,9 @@
-use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};
+use crate::{CheckpointConfig, WandBInfo};
+use crate::UploadInfo;
 use anyhow::{Result, anyhow, bail};
 use clap::Args;
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_eval::tasktype_from_name;
 use psyche_modeling::Devices;
 use psyche_network::{DiscoveryMode, RelayKind, SecretKey};
@@ -146,6 +148,14 @@ pub struct TrainArgs {
     #[clap(long, env)]
     pub hub_repo: Option<String>,
 
+    /// Name of the GCS bucket containing model data and configuration.
+    #[clap(long, env)]
+    pub gcs_bucket: Option<String>,
+
+    /// Prefix within the GCS bucket for model data and configuration.
+    #[clap(long, env)]
+    pub gcs_prefix: Option<String>,
+
     #[clap(long, env, default_value_t = 3)]
     pub hub_max_concurrent_downloads: usize,
 
@@ -224,43 +234,72 @@ impl TrainArgs {
     pub fn checkpoint_config(&self) -> Result<Option<CheckpointConfig>> {
         let hub_read_token = std::env::var("HF_TOKEN").ok();
-        let checkpoint_upload_info = match (
-            &hub_read_token,
-            self.hub_repo.clone(),
-            self.checkpoint_dir.clone(),
-            self.delete_old_steps,
-            self.keep_steps,
-        ) {
-            (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")
+
+        if self.hub_repo.is_some() && self.gcs_bucket.is_some() {
+            bail!("Use either GCS or HF hub for checkpoint uploads, not both.");
+        }
+
+        let checkpoint_dir = match &self.checkpoint_dir {
+            Some(dir) => dir,
+            None => {
+                if self.hub_repo.is_some() || self.gcs_bucket.is_some() {
+                    bail!(
+                        "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!"
+                    );
                 }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(HubUploadInfo {
-                        hub_repo: repo,
-                        hub_token: token.to_string(),
-                    }),
-                    delete_old_steps,
-                    keep_steps,
-                })
-            }
-            (None, Some(_), Some(_), _, _) => {
-                bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
-            }
-            (_, Some(_), None, _, _) => {
-                bail!("--hub-repo was set, but no --checkpoint-dir was passed!")
+                return Ok(None);
             }
-            (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
-                checkpoint_dir: dir,
-                hub_upload: None,
-                delete_old_steps,
-                keep_steps,
-            }),
-            (_, None, _, _, _) => None,
         };
 
-        Ok(checkpoint_upload_info)
+        let upload_info = self.build_upload_info(&hub_read_token)?;
+
+        if upload_info.is_some() && self.keep_steps == 0 {
+            bail!(
+                "keep_steps must be >= 1 for checkpoint uploads (got {})",
+                self.keep_steps
+            );
+        }
+
+        Ok(Some(CheckpointConfig {
+            checkpoint_dir: checkpoint_dir.clone(),
+            upload_info,
+            delete_old_steps: self.delete_old_steps,
+            keep_steps: self.keep_steps,
+        }))
+    }
+
+    fn build_upload_info(&self, hub_token: &Option<String>) -> Result<Option<UploadInfo>> {
+        if let Some(repo) = &self.hub_repo {
+            return self.build_hub_upload_info(repo, hub_token);
+        }
+
+        if let Some(bucket) = &self.gcs_bucket {
+            return self.build_gcs_upload_info(bucket);
+        }
+
+        Ok(None)
+    }
+
+    fn build_hub_upload_info(
+        &self,
+        repo: &str,
+        token: &Option<String>,
+    ) -> Result<Option<UploadInfo>> {
+        let token = token.as_ref().ok_or_else(|| {
+            anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
+        })?;
+
+        Ok(Some(UploadInfo::Hub(HubUploadInfo {
+            hub_repo: repo.to_string(),
+            hub_token: token.to_string(),
+        })))
+    }
+
+    fn build_gcs_upload_info(&self, bucket: &str) -> Result<Option<UploadInfo>> {
+        Ok(Some(UploadInfo::Gcs(GcsUploadInfo {
+            gcs_bucket: bucket.to_string(),
+            gcs_prefix: self.gcs_prefix.clone(),
+        })))
     }
 
     pub fn eval_tasks(&self) -> Result> {
CooldownStepMetadata { - tx_checkpoint: mpsc::UnboundedSender, + tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: Option, checkpoint_extra_files: Vec, @@ -59,7 +59,7 @@ pub struct CooldownStepMetadata { impl CooldownStepMetadata { pub fn new( - tx_checkpoint: mpsc::UnboundedSender, + tx_checkpoint: mpsc::UnboundedSender, tx_model: mpsc::UnboundedSender>, checkpoint_info: Option, checkpoint_extra_files: Vec, @@ -93,8 +93,8 @@ pub enum CheckpointError { #[error("Writing extra file to disk failed: {0}")] WriteExtraFile(#[from] tokio::io::Error), - #[error("Couldn't upload model to huggingface: {0}")] - UploadError(#[from] UploadModelError), + #[error("Couldn't upload model to huggingface or GCS: {0}")] + UploadError(#[from] UploadError), #[error("Couldn't send checkpoint - channel closed")] SendCheckpoint, @@ -137,6 +137,7 @@ impl CooldownStepMetadata { let step = state.progress.step - 1; let run_id = String::from(&state.run_id); + let epoch = state.progress.epoch as u32; let checkpoint_extra_files = self.checkpoint_extra_files.clone(); let checkpoint_info = self.checkpoint_info.clone(); let tx_checkpoint = self.tx_checkpoint.clone(); @@ -165,103 +166,51 @@ impl CooldownStepMetadata { .send(variables_clone) .map_err(|_| CheckpointError::SendCheckpoint)?; - let (variables, trainer) = if checkpoint_info.is_some() { - // convert from internal shape to serialized shape (e.g. torchtitan to hf) - tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) - .await - .map_err(|_| CheckpointError::ExtractThreadCrashed)? - } else { - (variables, trainer) + // convert from internal shape to serialized shape (e.g. torchtitan to hf) + let (variables, trainer) = match trainer { + #[cfg(feature = "python")] + Trainer::PythonDistributed(_) => { + info!("Converting distributed trainer variables for checkpointing..."); + tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer)) + .await + .map_err(|_| CheckpointError::ExtractThreadCrashed)? 
+ } + _ => (variables, trainer), }; trainers.push(trainer); let evals = model_task_runner.start(trainers); let Some(CheckpointConfig { - hub_upload, + upload_info, checkpoint_dir, delete_old_steps, keep_steps, }) = checkpoint_info else { - // If there was no HF checkpointing configuration, return immediately return Ok((evals, None)); }; - // Start the upload process of the updated model parameters in a separate task let upload_handle = tokio::task::spawn(async move { let path = checkpoint_dir.join(format!("{run_id}-step{step}")); - info!("Saving to {}", path.display()); - let mut local = tokio::task::spawn_blocking({ - let path = path.clone(); - move || save_tensors_into_safetensors(variables, path) - }) - .await - .map_err(|_| CheckpointError::WriteThreadCrashed)??; - - for extra in checkpoint_extra_files { - let to = path.join(extra.file_name().unwrap()); - tokio::fs::copy(extra.clone(), to.clone()) - .await - .map_err(CheckpointError::WriteExtraFile)?; - local.push(to); + let local = + save_checkpoint_locally(path, variables, checkpoint_extra_files).await?; + + if let Some(upload_info) = upload_info { + let manifest_metadata = GcsManifestMetadata { + epoch, + run_id: run_id.clone(), + }; + upload_checkpoint( + upload_info, + manifest_metadata, + local.clone(), + step as u64, + tx_checkpoint, + ) + .await?; } - let Some(HubUploadInfo { - hub_repo, - hub_token, - }) = hub_upload - else { - cleanup_dirs( - delete_queue, - keep_steps, - run_id, - delete_old_steps, - step, - checkpoint_dir, - ) - .await; - return Ok::<(), CheckpointError>(()); - }; - - info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); - let revision = match upload_model_repo_async( - hub_repo.clone(), - local, - hub_token.clone(), - Some(format!("step {step}")), - None, - ) - .await - { - Ok(revision) => { - info!( - repo = hub_repo, - revision = revision, - "Upload to HuggingFace complete" - ); - revision - } - Err(err) => { - error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}"); - return Err(err.into()); - } - }; - - tx_checkpoint - .send(HubRepo { - repo_id: FixedString::from_str_truncated(&hub_repo), - revision: Some(FixedString::from_str_truncated(&revision)), - }) - .map_err(|_| CheckpointError::SendCheckpoint)?; - - // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work - // we'll just delete the dir after we've uploaded it - // if we fail in any of the above steps we may wind up not queueing this dir for delete - // but that's probably better than risking having the dir deleted from under us - // for a relatively low priority disk cleanup task - // and this may actually be preferred anyway because if we failed to upload, we may want to keep - // the data around locally on disk cleanup_dirs( delete_queue, keep_steps, @@ -279,12 +228,56 @@ impl CooldownStepMetadata { } .instrument(info_span!("checkpointing")), ); + Ok(CooldownStep { checkpointing_and_evals, }) } } +async fn save_checkpoint_locally( + path: PathBuf, + variables: HashMap, + checkpoint_extra_files: Vec, +) -> Result, CheckpointError> { + info!("Saving to {}", path.display()); + let mut local = tokio::task::spawn_blocking({ + let path = path.clone(); + move || save_tensors_into_safetensors(variables, path) + }) + .await + .map_err(|_| CheckpointError::WriteThreadCrashed)??; + + for extra in checkpoint_extra_files { + let to = path.join(extra.file_name().unwrap()); + tokio::fs::copy(extra.clone(), to.clone()) + .await + .map_err(CheckpointError::WriteExtraFile)?; + local.push(to); + } + + 
Ok(local) +} + +async fn upload_checkpoint( + upload_info: UploadInfo, + manifest_metadata: GcsManifestMetadata, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), CheckpointError> { + match upload_info { + UploadInfo::Gcs(gcs_info) => { + upload_to_gcs(gcs_info, manifest_metadata, local, step, tx_checkpoint) + .await + .map_err(CheckpointError::UploadError) + } + UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint) + .await + .map_err(CheckpointError::UploadError), + } +} + type CheckpointAndEvalsHandle = JoinHandle< Result< ( diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 160af0ec1..a1fcb532d 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -5,9 +5,9 @@ use psyche_coordinator::{ }; use psyche_core::{Barrier, CancellableBarrier, NodeIdentity, Shuffle, TokenSize}; use psyche_data_provider::{ - DataProvider, DataProviderTcpClient, DummyDataProvider, GcsError, PreprocessedDataProvider, - Split, WeightedDataProvider, download_dataset_repo_async, download_model_from_gcs_async, - download_model_repo_async, + DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider, + PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async, + download_model_from_gcs_async, download_model_repo_async, http::{FileURLs, HttpDataProvider}, }; use psyche_metrics::ClientMetrics; @@ -90,7 +90,7 @@ pub enum InitRunError { HfModelLoad(#[from] hf_hub::api::tokio::ApiError), #[error("failed to download model from GCS: {0}")] - GcsModelLoad(#[from] GcsError), + GcsModelLoad(#[from] DownloadError), #[error("model loading thread crashed")] ModelLoadingThreadCrashed(JoinError), @@ -153,7 +153,7 @@ pub struct RunInitConfigAndIO { pub tx_health_check: UnboundedSender>, pub tx_witness: UnboundedSender, - pub tx_checkpoint: UnboundedSender, + pub tx_checkpoint: UnboundedSender, pub tx_model: UnboundedSender>, pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>, pub tx_config: UnboundedSender<(String, String)>, diff --git a/shared/client/src/state/mod.rs b/shared/client/src/state/mod.rs index e4d73d5b9..78e6cd1eb 100644 --- a/shared/client/src/state/mod.rs +++ b/shared/client/src/state/mod.rs @@ -14,6 +14,7 @@ mod warmup; mod witness; pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO}; +pub use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; pub use round_state::RoundState; pub use steps::{ApplyMessageOutcome, RunManager}; -pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, HubUploadInfo}; +pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, UploadInfo}; diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 2edf22760..29734f1a0 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use psyche_coordinator::CommitteeProof; use psyche_core::{BatchId, MerkleRoot, NodeIdentity}; +use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_modeling::DistroResult; use psyche_network::{BlobTicket, TransmittableDistroResult}; use tch::TchError; @@ -9,14 +10,14 @@ use thiserror::Error; use tokio::task::JoinHandle; #[derive(Debug, Clone)] -pub struct HubUploadInfo { - pub hub_repo: String, - pub hub_token: String, +pub enum UploadInfo { + Hub(HubUploadInfo), + Gcs(GcsUploadInfo), } #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub hub_upload: Option, + 
pub upload_info: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 0a3a9975d..5b10bfcbd 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, HubRepo, Model}, + model::{Checkpoint, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -596,22 +596,43 @@ impl Coordinator { &mut self, from: &T, index: u64, - hub_repo: HubRepo, + checkpoint_repo: Checkpoint, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { return Err(CoordinatorError::InvalidCommitteeProof); } - // TODO: In the case of more than one checkpointer, this will overwrite the hub repo - // with the last checkpointed one. We could instead have a vector of hub repos to have + + // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint + // with the last checkpointed one. We could instead have a vector of checkpoints to have // more download options. - match &mut self.model { - Model::LLM(llm) => match llm.checkpoint { - Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo), - Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo), - _ => {} - }, + let Model::LLM(llm) = &mut self.model; + match (&llm.checkpoint, checkpoint_repo) { + // If current is P2P, wrap the new checkpoint in P2P + (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::P2P(hub_repo); + } + (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); + } + // If current is Hub, only accept Hub updates + (Checkpoint::Hub(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::Hub(hub_repo); + } + // If current is Gcs, only accept Gcs updates + (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::Gcs(gcs_repo); + } + (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => { + llm.checkpoint = Checkpoint::P2P(hub_repo); + } + (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => { + llm.checkpoint = Checkpoint::P2PGcs(gcs_repo); + } + // Ignore other combinations + _ => {} } + Ok(()) } diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml index ea0e59791..7ad69e616 100644 --- a/shared/data-provider/Cargo.toml +++ b/shared/data-provider/Cargo.toml @@ -27,6 +27,8 @@ postcard.workspace = true bytemuck.workspace = true reqwest = "0.12.12" google-cloud-storage = "0.24.0" +chrono = { version = "0.4", features = ["serde"] } +serde_json.workspace = true ts-rs.workspace = true rayon.workspace = true @@ -38,4 +40,3 @@ test-log.workspace = true clap.workspace = true tempfile = "3.15.0" static-web-server = { git = "https://github.com/arilotter/static-web-server", rev = "c91445427b56c5ddff0365d8ec116e3b567377ac" } # forked to add a channel for getting the port -serde_json.workspace = true diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs new file mode 100644 index 000000000..b84bc5f9a --- /dev/null +++ b/shared/data-provider/src/errors.rs @@ -0,0 +1,53 @@ +use std::path::PathBuf; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum UploadError { + #[error("path {0} is not a file")] + NotAFile(PathBuf), + + 
#[error("file {0} doesn't have a valid utf-8 representation")] + InvalidFilename(PathBuf), + + #[error("failed to send checkpoint notification")] + SendCheckpoint, + + // Hub-specific errors + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("failed to commit files: {0}")] + Commit(#[from] hf_hub::api::tokio::CommitError), + + // GCS-specific errors + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + // Common errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +#[derive(Error, Debug)] +pub enum DownloadError { + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index cb6848c1d..b79a850b1 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -1,48 +1,115 @@ +use crate::errors::{DownloadError, UploadError}; +use chrono::{DateTime, Utc}; use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::Media; +use google_cloud_storage::http::objects::upload::UploadObjectRequest; +use google_cloud_storage::http::objects::upload::UploadType; use google_cloud_storage::http::objects::{ download::Range, get::GetObjectRequest, list::ListObjectsRequest, }; -use std::path::PathBuf; -use thiserror::Error; +use psyche_coordinator::model::{self, GcsRepo}; +use psyche_core::FixedString; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; use tokio::runtime::Runtime; +use tokio::sync::mpsc; use tracing::info; -#[derive(Debug, Error)] -pub enum GcsError { - #[error("GCS authentication failed: {0}")] - Auth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), +/// Checkpoint manifest.json uploaded to GCS alongside safetensors files. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GcsCheckpointManifest { + pub metadata: ManifestMetadata, + pub files: Vec, +} - #[error("GCS operation failed: {0}")] - Storage(#[from] google_cloud_storage::http::Error), +/// Checkpoint metadata. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ManifestMetadata { + pub timestamp: DateTime, + pub epoch: u32, + pub step: u32, + pub run_id: String, +} - #[error("IO error: {0}")] - Io(#[from] std::io::Error), +/// Single file entry in the manifest. 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ManifestFileEntry {
+    pub filename: String,
+    pub generation: i64,
+    pub size_bytes: u64,
+}
 
-const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
+#[derive(Debug, Clone)]
+pub struct GcsUploadInfo {
+    pub gcs_bucket: String,
+    pub gcs_prefix: Option<String>,
+}
 
-fn check_model_extension(filename: &str) -> bool {
-    MODEL_EXTENSIONS.iter().any(|ext| filename.ends_with(ext))
+#[derive(Debug, Clone)]
+pub struct GcsManifestMetadata {
+    pub epoch: u32,
+    pub run_id: String,
 }
 
-fn get_cache_dir(bucket: &str, prefix: Option<&str>) -> PathBuf {
-    let base = std::env::var("HOME")
-        .map(|h| PathBuf::from(h).join(".cache"))
-        .unwrap_or_else(|_| PathBuf::from(".cache"))
+const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
+
+fn get_cache_base(bucket: &str) -> PathBuf {
+    // Use HF_HOME if set, otherwise fall back to ~/.cache
+    std::env::var("HF_HOME")
+        .map(PathBuf::from)
+        .unwrap_or_else(|_| {
+            std::env::var("HOME")
+                .map(|h| PathBuf::from(h).join(".cache"))
+                .unwrap_or_else(|_| PathBuf::from(".cache"))
+        })
         .join("psyche")
        .join("gcs")
-        .join(bucket);
+        .join(bucket)
+}
+
+fn get_cache_dir(
+    bucket: &str,
+    prefix: Option<&str>,
+    step: u32,
+    manifest_generation: i64,
+) -> PathBuf {
+    let base = get_cache_base(bucket);
+    let versioned_folder = format!("step-{}-{}", step, manifest_generation);
 
     match prefix {
-        Some(p) => base.join(p.trim_end_matches('/')),
-        None => base,
+        Some(p) => base.join(p.trim_end_matches('/')).join(versioned_folder),
+        None => base.join(versioned_folder),
     }
 }
 
+fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf {
+    let base = get_cache_base(bucket);
+
+    match prefix {
+        Some(p) => base.join(p.trim_end_matches('/')).join("no_manifest"),
+        None => base.join("no_manifest"),
+    }
+}
+
+fn collect_cached_files(
+    cache_dir: &Path,
+    manifest: &GcsCheckpointManifest,
+) -> Option<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    for file_entry in &manifest.files {
+        let path = cache_dir.join(&file_entry.filename);
+        if !path.exists() {
+            return None;
+        }
+        files.push(path);
+    }
+    Some(files)
+}
+
 pub async fn download_model_from_gcs_async(
     bucket: &str,
     prefix: Option<&str>,
-) -> Result<Vec<PathBuf>, GcsError> {
+) -> Result<Vec<PathBuf>, DownloadError> {
     // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous
     let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
         info!("Using authenticated GCS client");
@@ -53,7 +120,131 @@ pub async fn download_model_from_gcs_async(
     };
     let client = Client::new(config);
 
-    // List all objects in the bucket with optional prefix
+    let manifest_object_path = match prefix {
+        Some(p) => format!("{}/manifest.json", p),
+        None => "manifest.json".to_string(),
+    };
+
+    // Get manifest metadata to obtain generation number
+    let manifest_metadata = client
+        .get_object(&GetObjectRequest {
+            bucket: bucket.to_owned(),
+            object: manifest_object_path.clone(),
+            ..Default::default()
+        })
+        .await;
+
+    match manifest_metadata {
+        Ok(object_meta) => {
+            let manifest_generation = object_meta.generation;
+
+            // Download manifest content
+            let manifest_data = client
+                .download_object(
+                    &GetObjectRequest {
+                        bucket: bucket.to_owned(),
+                        object: manifest_object_path,
+                        ..Default::default()
+                    },
+                    &Range::default(),
+                )
+                .await?;
+
+            let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?;
+
+            info!(
+                "Found manifest: step {}, epoch {}, generation {}",
+                manifest.metadata.step, manifest.metadata.epoch,
manifest_generation + ); + + // Build versioned cache path + let cache_dir = + get_cache_dir(bucket, prefix, manifest.metadata.step, manifest_generation); + + // Check if all manifest files exist in cache + let mut files = if let Some(cached) = collect_cached_files(&cache_dir, &manifest) { + info!("Using cached checkpoint at {:?}", cache_dir); + cached + } else { + info!( + "Model not found in cache, downloading checkpoint to {:?}", + cache_dir + ); + std::fs::create_dir_all(&cache_dir)?; + download_files_from_manifest(&client, bucket, prefix, &cache_dir, &manifest).await? + }; + // Download config files (json, py) - skips if already cached + let config_files = + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &[".json", ".py"]) + .await?; + files.extend(config_files); + Ok(files) + } + Err(_) => { + // Fallback for old checkpoints without manifest + info!("No manifest found, downloading model without manifest"); + let cache_dir = get_cache_dir_no_manifest(bucket, prefix); + std::fs::create_dir_all(&cache_dir)?; + download_files_no_manifest(&client, bucket, prefix, &cache_dir, &MODEL_EXTENSIONS).await + } + } +} + +async fn download_files_from_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + manifest: &GcsCheckpointManifest, +) -> Result, DownloadError> { + let mut downloaded_files = Vec::new(); + + for file_entry in &manifest.files { + let object_name = match prefix { + Some(p) => format!("{}/{}", p, file_entry.filename), + None => file_entry.filename.clone(), + }; + let local_path = cache_dir.join(&file_entry.filename); + + if local_path.exists() { + info!("Using cached: {}", file_entry.filename); + downloaded_files.push(local_path); + continue; + } + + info!( + "Downloading: gs://{}/{} (generation {})", + bucket, object_name, file_entry.generation + ); + + let data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: object_name, + ..Default::default() + }, + &Range::default(), + ) + .await?; + + std::fs::write(&local_path, &data)?; + info!("Downloaded: {} ({} bytes)", file_entry.filename, data.len()); + downloaded_files.push(local_path); + } + + Ok(downloaded_files) +} + +/// Download model files by listing the bucket. Skips files that already exist in cache. +/// Used for initial model download (no manifest) and to fetch config files (json, py) after manifest download. 
+async fn download_files_no_manifest( + client: &Client, + bucket: &str, + prefix: Option<&str>, + cache_dir: &Path, + extensions: &[&str], +) -> Result, DownloadError> { let mut all_objects = vec![]; let mut page_token: Option = None; @@ -68,7 +259,7 @@ pub async fn download_model_from_gcs_async( .await?; for obj in results.items.iter().flatten() { - if check_model_extension(&obj.name) { + if extensions.iter().any(|ext| obj.name.ends_with(ext)) { all_objects.push(obj.name.clone()); } } @@ -80,33 +271,27 @@ pub async fn download_model_from_gcs_async( } info!( - "Found {} model files in gs://{}/{}", + "Found {} files ({}) in gs://{}/{}", all_objects.len(), + extensions.join(", "), bucket, prefix.unwrap_or("") ); - let cache_dir = get_cache_dir(bucket, prefix); - std::fs::create_dir_all(&cache_dir)?; - let mut downloaded_files = Vec::new(); for object_name in all_objects { - // Get just the filename (strip prefix if present) let filename = object_name.rsplit('/').next().unwrap_or(&object_name); - let local_path = cache_dir.join(filename); - // Skip if already cached if local_path.exists() { info!("Using cached: {}", filename); downloaded_files.push(local_path); continue; } - info!("Downloading: {}", object_name); + info!("Downloading: gs://{}/{}", bucket, object_name); - // Download the object let data = client .download_object( &GetObjectRequest { @@ -132,7 +317,130 @@ pub async fn download_model_from_gcs_async( pub fn download_model_from_gcs_sync( bucket: &str, prefix: Option<&str>, -) -> Result, GcsError> { - let rt = Runtime::new().map_err(GcsError::Io)?; +) -> Result, DownloadError> { + let rt = Runtime::new().map_err(DownloadError::Io)?; rt.block_on(download_model_from_gcs_async(bucket, prefix)) } + +pub async fn upload_to_gcs( + gcs_info: GcsUploadInfo, + manifest_metadata: GcsManifestMetadata, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let GcsUploadInfo { + gcs_bucket, + gcs_prefix, + } = gcs_info; + + let GcsManifestMetadata { epoch, run_id } = manifest_metadata; + + info!(bucket = gcs_bucket, "Uploading checkpoint to GCS"); + + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client"); + ClientConfig::default().with_auth().await? + } else { + info!("Using anonymous GCS client"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + let mut manifest = GcsCheckpointManifest { + metadata: ManifestMetadata { + timestamp: Utc::now(), + epoch, + step: step as u32, + run_id, + }, + files: Vec::new(), + }; + + for path in local { + let file_name = path + .file_name() + .ok_or_else(|| UploadError::NotAFile(path.clone()))? 
+ .to_str() + .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?; + + // Only upload safetensors files + if !file_name.ends_with(".safetensors") { + continue; + } + + let object_name = match &gcs_prefix { + Some(p) => format!("{}/{}", p, file_name), + None => file_name.to_string(), + }; + + let size = std::fs::metadata(&path)?.len(); + let data = tokio::fs::read(&path).await?; + + let upload_type = UploadType::Simple(Media::new(object_name.clone())); + let uploaded = client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + data, + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = object_name, + size = uploaded.size, + generation = uploaded.generation, + "Uploaded file to GCS" + ); + + manifest.files.push(ManifestFileEntry { + filename: file_name.to_string(), + generation: uploaded.generation, + size_bytes: size, + }); + } + + // Upload the manifest file + let manifest_path = match &gcs_prefix { + Some(p) => format!("{}/manifest.json", p), + None => "manifest.json".to_string(), + }; + let manifest_json = serde_json::to_string_pretty(&manifest)?; + + let upload_type = UploadType::Simple(Media::new(manifest_path.clone())); + client + .upload_object( + &UploadObjectRequest { + bucket: gcs_bucket.clone(), + ..Default::default() + }, + manifest_json.into_bytes(), + &upload_type, + ) + .await?; + + info!( + bucket = gcs_bucket, + object = manifest_path, + "Uploaded manifest to GCS" + ); + + info!( + "Upload to GCS complete at gs://{}/{}", + gcs_bucket, + gcs_prefix.as_deref().unwrap_or("") + ); + + tx_checkpoint + .send(model::Checkpoint::Gcs(GcsRepo { + bucket: FixedString::from_str_truncated(&gcs_bucket), + prefix: gcs_prefix.map(|p| FixedString::from_str_truncated(&p)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) +} diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index ca090a759..13a575b84 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,12 +1,16 @@ +use crate::errors::UploadError; +use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, api::{ Siblings, - tokio::{ApiError, CommitError, UploadSource}, + tokio::{ApiError, UploadSource}, }, }; +use psyche_coordinator::model; +use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; -use thiserror::Error; +use tokio::sync::mpsc; use tracing::{error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; @@ -189,42 +193,39 @@ pub fn download_dataset_repo_sync( ) } -#[derive(Error, Debug)] -pub enum UploadModelError { - #[error("path {0} is not a file")] - NotAFile(PathBuf), - - #[error("file {0} doesn't have a valid utf-8 representation")] - InvalidFilename(PathBuf), +#[derive(Debug, Clone)] +pub struct HubUploadInfo { + pub hub_repo: String, + pub hub_token: String, +} - #[error("failed to connect to HF hub: {0}")] - HfHub(#[from] ApiError), +pub async fn upload_to_hub( + hub_info: HubUploadInfo, + local: Vec, + step: u64, + tx_checkpoint: mpsc::UnboundedSender, +) -> Result<(), UploadError> { + let HubUploadInfo { + hub_repo, + hub_token, + } = hub_info; - #[error("failed to commit files: {0}")] - Commit(#[from] CommitError), -} + info!(repo = hub_repo, "Uploading checkpoint to HuggingFace"); -pub async fn upload_model_repo_async( - repo_id: String, - files: Vec, - token: String, - commit_message: Option, - commit_description: Option, -) -> Result { let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(token)) + 
.with_token(Some(hub_token.clone())) .build()?; - let repo = Repo::model(repo_id.clone()); + let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result, _> = files + let files: Result, _> = local .into_iter() .map(|path| { path.file_name() - .ok_or(UploadModelError::NotAFile(path.clone())) + .ok_or(UploadError::NotAFile(path.clone())) .and_then(|name| { name.to_str() - .ok_or(UploadModelError::InvalidFilename(path.clone())) + .ok_or(UploadError::InvalidFilename(path.clone())) .map(|s| s.to_string()) }) .map(|name| (path.into(), name)) @@ -233,32 +234,32 @@ pub async fn upload_model_repo_async( let files = files?; - let commit_info = match api_repo - .upload_files( - files, - commit_message.clone(), - commit_description.clone(), - false, - ) + let commit_info = api_repo + .upload_files(files, Some(format!("step {step}")), None, false) .await - { - Ok(info) => { - info!( - repo = repo_id, - oid = info.oid, - "Successfully uploaded files to HuggingFace" - ); - info - } - Err(e) => { + .map_err(|e| { error!( - repo = repo_id, + repo = hub_repo, error = ?e, - "Failed to upload files to HuggingFace. Full error details: {:#?}", - e + "Failed to upload files to HuggingFace" ); - return Err(e.into()); - } - }; - Ok(commit_info.oid) + e + })?; + + let revision = commit_info.oid; + + info!( + repo = hub_repo, + revision = revision, + "Upload to HuggingFace complete" + ); + + tx_checkpoint + .send(model::Checkpoint::Hub(HubRepo { + repo_id: FixedString::from_str_truncated(&hub_repo), + revision: Some(FixedString::from_str_truncated(&revision)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) } diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 7b97fe101..0044d77d2 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -1,6 +1,7 @@ mod data_provider; mod dataset; mod dummy; +mod errors; mod file_extensions; mod gcs; pub mod http; @@ -14,11 +15,15 @@ mod weighted; pub use data_provider::DataProvider; pub use dataset::{Dataset, Field, Row, Split}; pub use dummy::DummyDataProvider; +pub use errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; -pub use gcs::{GcsError, download_model_from_gcs_async, download_model_from_gcs_sync}; +pub use gcs::{ + GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata, + download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, +}; pub use hub::{ - UploadModelError, download_dataset_repo_async, download_dataset_repo_sync, - download_model_repo_async, download_model_repo_sync, upload_model_repo_async, + HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, + download_model_repo_async, download_model_repo_sync, upload_to_hub, }; pub use local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 1cd96a7ed..50bf317bb 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -27,5 +27,5 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>; } From 
b4db519f87dce74ed350cac77b839b74e330dad7 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Thu, 15 Jan 2026 08:16:21 -0800 Subject: [PATCH 18/24] generation: Some(file_entry.generation) --- shared/data-provider/src/gcs.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index b79a850b1..71f29e414 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -222,6 +222,7 @@ async fn download_files_from_manifest( &GetObjectRequest { bucket: bucket.to_owned(), object: object_name, + generation: Some(file_entry.generation), ..Default::default() }, &Range::default(), From e2d87ef2635517b0052d98df433c189cbfcc02c5 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Fri, 16 Jan 2026 07:26:24 -0800 Subject: [PATCH 19/24] fix tcp send_checkpoint --- shared/data-provider/examples/tcp.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 13dcf8b44..67cc9ab9d 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -37,7 +37,7 @@ impl WatcherBackend for DummyBackend { bail!("Data provider does not send health check"); } - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> anyhow::Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> anyhow::Result<()> { bail!("Data provider does not send checkpoints"); } } From 45864cac9155dcac2b1c50e433f055c259e8e707 Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Fri, 16 Jan 2026 08:17:06 -0800 Subject: [PATCH 20/24] HF_TOKEN: secrets.HF_TOKEN --- .github/workflows/solana-integration-test-base.yml | 2 ++ .github/workflows/solana-integration-test-run.yml | 1 + 2 files changed, 3 insertions(+) diff --git a/.github/workflows/solana-integration-test-base.yml b/.github/workflows/solana-integration-test-base.yml index fca75de51..2849cdb89 100644 --- a/.github/workflows/solana-integration-test-base.yml +++ b/.github/workflows/solana-integration-test-base.yml @@ -126,5 +126,7 @@ jobs: # Step 4: Run Integration Tests - name: Run integration test + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | nix develop --command cargo test --release -p psyche-decentralized-testing --test integration_tests -- --nocapture "${{ inputs.test-name }}" diff --git a/.github/workflows/solana-integration-test-run.yml b/.github/workflows/solana-integration-test-run.yml index 7e540d73b..25ae27075 100644 --- a/.github/workflows/solana-integration-test-run.yml +++ b/.github/workflows/solana-integration-test-run.yml @@ -35,3 +35,4 @@ jobs: uses: ./.github/workflows/solana-integration-test-base.yml with: test-name: ${{ matrix.test-name }} + secrets: inherit From 916d3ae2e0a84e56aa4716170b711614f1d90918 Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Wed, 28 Jan 2026 12:22:36 -0800 Subject: [PATCH 21/24] add button for GCS checkpoint if bucket is public --- .../solana-authorizer/Cargo.lock | 2 +- .../solana-coordinator/Cargo.lock | 8 +- .../solana-mining-pool/Cargo.lock | 2 +- .../src/dataStores/flatFileCoordinator.ts | 16 +-- website/backend/src/index.ts | 16 +++ .../src/components/CheckpointButton.tsx | 119 ++++++++++++------ website/frontend/src/fetchRuns.ts | 8 ++ website/shared/index.ts | 3 +- 8 files changed, 125 insertions(+), 49 deletions(-) diff --git a/architectures/decentralized/solana-authorizer/Cargo.lock b/architectures/decentralized/solana-authorizer/Cargo.lock index c1a36b8a6..7eb386dad 100644 --- 
a/architectures/decentralized/solana-authorizer/Cargo.lock +++ b/architectures/decentralized/solana-authorizer/Cargo.lock @@ -1389,7 +1389,7 @@ dependencies = [ [[package]] name = "psyche-solana-authorizer" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anchor-lang", "anchor-spl", diff --git a/architectures/decentralized/solana-coordinator/Cargo.lock b/architectures/decentralized/solana-coordinator/Cargo.lock index 22d64fadc..72a38df53 100644 --- a/architectures/decentralized/solana-coordinator/Cargo.lock +++ b/architectures/decentralized/solana-coordinator/Cargo.lock @@ -1602,7 +1602,7 @@ dependencies = [ [[package]] name = "psyche-coordinator" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anchor-lang", "async-trait", @@ -1616,7 +1616,7 @@ dependencies = [ [[package]] name = "psyche-core" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anchor-lang", "anchor-lang-idl", @@ -1635,7 +1635,7 @@ dependencies = [ [[package]] name = "psyche-solana-authorizer" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anchor-lang", "anchor-spl", @@ -1643,7 +1643,7 @@ dependencies = [ [[package]] name = "psyche-solana-coordinator" -version = "0.1.0" +version = "0.2.0" dependencies = [ "anchor-lang", "bytemuck", diff --git a/architectures/decentralized/solana-mining-pool/Cargo.lock b/architectures/decentralized/solana-mining-pool/Cargo.lock index 06fb31df5..225d03bf9 100644 --- a/architectures/decentralized/solana-mining-pool/Cargo.lock +++ b/architectures/decentralized/solana-mining-pool/Cargo.lock @@ -1389,7 +1389,7 @@ dependencies = [ [[package]] name = "psyche-solana-mining-pool" -version = "0.1.1" +version = "0.2.0" dependencies = [ "anchor-lang", "anchor-spl", diff --git a/website/backend/src/dataStores/flatFileCoordinator.ts b/website/backend/src/dataStores/flatFileCoordinator.ts index 30c84813a..6bdc15bba 100644 --- a/website/backend/src/dataStores/flatFileCoordinator.ts +++ b/website/backend/src/dataStores/flatFileCoordinator.ts @@ -642,13 +642,15 @@ export class FlatFileCoordinatorDataStore implements CoordinatorDataStore { } satisfies RunRoundClient }) - const checkpoint = - (typeof c.coordinator.model.LLM.checkpoint === 'object' && - (('Hub' in c.coordinator.model.LLM.checkpoint && - c.coordinator.model.LLM.checkpoint.Hub) || - ('P2P' in c.coordinator.model.LLM.checkpoint && - c.coordinator.model.LLM.checkpoint.P2P))) || - null + const checkpoint = (() => { + const cp = c.coordinator.model.LLM.checkpoint + if (typeof cp !== 'object') return null + if ('Hub' in cp) return { Hub: cp.Hub } + if ('P2P' in cp) return { Hub: cp.P2P } + if ('Gcs' in cp) return { Gcs: cp.Gcs } + if ('P2PGcs' in cp) return { Gcs: cp.P2PGcs } + return null + })() const config = c.coordinator.config state = { diff --git a/website/backend/src/index.ts b/website/backend/src/index.ts index dcc86d63f..07549a0e7 100644 --- a/website/backend/src/index.ts +++ b/website/backend/src/index.ts @@ -382,6 +382,22 @@ async function main() { } }) + fastify.get<{ + Querystring: { bucket: string; prefix?: string } + }>('/check-gcs-bucket', async (request) => { + const { bucket, prefix } = request.query + const path = prefix ? `${prefix}/` : '' + const url = `https://storage.googleapis.com/${bucket}/${path}manifest.json` + try { + const response = await fetch(url, { method: 'HEAD' }) + return { isValid: response.ok, description: response.statusText } + } catch (error) { + const errorMessage = + error instanceof Error ? 
error.message : 'Unknown error' + return { isValid: false, description: errorMessage } + } + }) + fastify.get('/status', async (_, res) => { const data = { commit: process.env.GITCOMMIT ?? '???', diff --git a/website/frontend/src/components/CheckpointButton.tsx b/website/frontend/src/components/CheckpointButton.tsx index af58d11ad..51ae0d797 100644 --- a/website/frontend/src/components/CheckpointButton.tsx +++ b/website/frontend/src/components/CheckpointButton.tsx @@ -2,32 +2,54 @@ import { useState, useEffect } from 'react' import { Button } from './Button.js' import HuggingfaceIcon from '../assets/icons/huggingface.svg?react' -import { fetchCheckpointStatus } from '../fetchRuns.js' +import LinkIcon from '../assets/icons/link.svg?react' +import { + fetchCheckpointStatus, + fetchGcsCheckpointStatus, +} from '../fetchRuns.js' +import type { GcsRepo, HubRepo } from 'shared' -export const CheckpointButton = ({ - checkpoint, -}: { - checkpoint: { repo_id: string; revision?: string | null } -}) => { +type CheckpointProps = { + checkpoint: { Hub: HubRepo } | { Gcs: GcsRepo } +} + +export const CheckpointButton = ({ checkpoint }: CheckpointProps) => { const [isValid, setIsValid] = useState(undefined) - useEffect(() => { - const parsedRepo = checkpoint.repo_id.split('/') + const isHub = 'Hub' in checkpoint + const isGcs = 'Gcs' in checkpoint - if (parsedRepo.length !== 2) { - setIsValid(false) - return - } - const [owner, repo] = parsedRepo + useEffect(() => { + if (isHub) { + const repoId = checkpoint.Hub.repo_id + const parsedRepo = repoId.split('/') - fetchCheckpointStatus(owner, repo, checkpoint.revision || undefined) - .then((data) => { - setIsValid(data.isValid) - }) - .catch(() => { + if (parsedRepo.length !== 2) { setIsValid(false) - }) - }, [checkpoint.repo_id, checkpoint.revision]) + return + } + const [owner, repo] = parsedRepo + + fetchCheckpointStatus(owner, repo, checkpoint.Hub.revision || undefined) + .then((data) => { + setIsValid(data.isValid) + }) + .catch(() => { + setIsValid(false) + }) + } else if (isGcs) { + const bucket = checkpoint.Gcs.bucket + const prefix = checkpoint.Gcs.prefix || undefined + + fetchGcsCheckpointStatus(bucket, prefix) + .then((data) => { + setIsValid(data.isValid) + }) + .catch(() => { + setIsValid(false) + }) + } + }, [checkpoint, isHub, isGcs]) if (isValid === undefined) { return null @@ -38,19 +60,46 @@ export const CheckpointButton = ({ return null } - return ( - - ) + if (isHub) { + const repoId = checkpoint.Hub.repo_id + const revision = checkpoint.Hub.revision + + return ( + + ) + } + + if (isGcs) { + const bucket = checkpoint.Gcs.bucket + const prefix = checkpoint.Gcs.prefix + + return ( + + ) + } + + return null } diff --git a/website/frontend/src/fetchRuns.ts b/website/frontend/src/fetchRuns.ts index 438d2d648..742867253 100644 --- a/website/frontend/src/fetchRuns.ts +++ b/website/frontend/src/fetchRuns.ts @@ -64,6 +64,14 @@ export async function fetchCheckpointStatus( return psycheJsonFetch(`check-checkpoint?${queryParams}`) } +export async function fetchGcsCheckpointStatus( + bucket: string, + prefix?: string +): Promise { + const queryParams = `bucket=${bucket}${prefix ? 
`&prefix=${prefix}` : ''}` + return psycheJsonFetch(`check-gcs-bucket?${queryParams}`) +} + interface DecodeState { buffer: string decoder: TextDecoder diff --git a/website/shared/index.ts b/website/shared/index.ts index 61f10b7a0..0b19b75d0 100644 --- a/website/shared/index.ts +++ b/website/shared/index.ts @@ -9,6 +9,7 @@ type PsycheSolanaCoordinator = coordinatorTypes.PsycheSolanaCoordinator type PsycheSolanaMiningPool = miningPoolTypes.PsycheSolanaMiningPool import type { + GcsRepo, HubRepo, LearningRateSchedule, LLMArchitecture, @@ -118,7 +119,7 @@ export interface RunData { epochStartTime: Date clients: Array - checkpoint: HubRepo | null + checkpoint: { Hub: HubRepo } | { Gcs: GcsRepo } | null round: number config: { From 2d1e5272d34f68082b4b438bbd4bccc04e2fe29a Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Wed, 28 Jan 2026 14:14:40 -0800 Subject: [PATCH 22/24] fix fakedata.ts --- website/frontend/src/fakeData.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/website/frontend/src/fakeData.ts b/website/frontend/src/fakeData.ts index 87df62f7e..eb3d062b5 100644 --- a/website/frontend/src/fakeData.ts +++ b/website/frontend/src/fakeData.ts @@ -382,8 +382,10 @@ function makeFakeRunDataSeeded(seed = 1, step = 0, index = 0): RunData { round, clients, checkpoint: { - repo_id: 'PsycheFoundation/Skibbler', - revision: null, + Hub: { + repo_id: 'PsycheFoundation/Skibbler', + revision: null, + }, }, config: { cooldownTime: 5_000, From 699c4fc379620b45e21975599d49090ac566d5ae Mon Sep 17 00:00:00 2001 From: Pedro Fontana Date: Mon, 2 Feb 2026 07:57:54 -0800 Subject: [PATCH 23/24] clippy --- shared/client/src/state/cooldown.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs index f756343f9..4a3ce6a82 100644 --- a/shared/client/src/state/cooldown.rs +++ b/shared/client/src/state/cooldown.rs @@ -22,7 +22,7 @@ use tokio::{ sync::{Mutex, mpsc}, task::JoinHandle, }; -use tracing::{Instrument, error, info, info_span, warn}; +use tracing::{Instrument, info, info_span, warn}; use super::{ CheckpointConfig, From 39d5a4b254b08bbfc0b25775eb3482665a19ff2d Mon Sep 17 00:00:00 2001 From: Mariano Nicolini Date: Mon, 2 Feb 2026 18:54:10 -0300 Subject: [PATCH 24/24] add GCS documentation for bucket creation and manifest file explanation --- psyche-book/src/SUMMARY.md | 1 + psyche-book/src/explain/gcs-checkpoints.md | 213 +++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 psyche-book/src/explain/gcs-checkpoints.md diff --git a/psyche-book/src/SUMMARY.md b/psyche-book/src/SUMMARY.md index ace53fd84..3a767f171 100644 --- a/psyche-book/src/SUMMARY.md +++ b/psyche-book/src/SUMMARY.md @@ -14,6 +14,7 @@ - [General workflow](./explain/general-workflow.md) - [Data provider](./explain/data-provider.md) - [Model sharing](./explain/model-sharing.md) + - [GCS Checkpoints](./explain/gcs-checkpoints.md) - [Rewards](./explain/rewards.md) - [Glossary](./explain/glossary.md) diff --git a/psyche-book/src/explain/gcs-checkpoints.md b/psyche-book/src/explain/gcs-checkpoints.md new file mode 100644 index 000000000..e1f21e53f --- /dev/null +++ b/psyche-book/src/explain/gcs-checkpoints.md @@ -0,0 +1,213 @@ +# GCS Checkpoints + +Psyche supports uploading training checkpoints to Google Cloud Storage (GCS). This page covers how to set up a GCS bucket for checkpoints and explains the manifest system used to ensure reliable checkpoint downloads. 
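+
+On the client side, GCS checkpointing is selected with the `--gcs-bucket` and `--gcs-prefix` flags (also settable via environment variables). They are mutually exclusive with `--hub-repo`, require `--checkpoint-dir`, and `keep_steps` must be at least 1 when any upload target is configured. As a rough sketch, assuming the client crate is named `psyche_client` (the types below are the ones introduced alongside this page; the field values are placeholders), the flags resolve to a configuration like:
+
+```rust
+// Crate name `psyche_client` is assumed for this sketch; the types are
+// re-exported from shared/client/src/lib.rs and shared/data-provider.
+use psyche_client::{CheckpointConfig, GcsUploadInfo, UploadInfo};
+use std::path::PathBuf;
+
+fn example_config() -> CheckpointConfig {
+    CheckpointConfig {
+        // Local directory where safetensors shards are written first.
+        checkpoint_dir: PathBuf::from("/tmp/checkpoints"),
+        // GCS destination; the prefix within the bucket is optional.
+        upload_info: Some(UploadInfo::Gcs(GcsUploadInfo {
+            gcs_bucket: "yourname-ml-checkpoints-2025".to_string(),
+            gcs_prefix: Some("my-training-run".to_string()),
+        })),
+        // Prune older local step directories, keeping the last two.
+        delete_old_steps: true,
+        keep_steps: 2,
+    }
+}
+```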
+
+## Setting Up a GCS Bucket
+
+### Prerequisites
+
+- A Google account
+- A credit or debit card for billing purposes
+
+### Step 1: Go to Google Cloud Console
+
+1. Open your browser and go to [console.cloud.google.com](https://console.cloud.google.com)
+2. Sign in with your Google account
+
+### Step 2: Create a Project
+
+A project is a container for all your Google Cloud resources.
+
+1. Click the project dropdown at the top of the page (it may say "Select a project" or show an existing project name)
+2. Click **New Project** in the top right of the popup
+3. Enter a **Project name** (e.g., `my-ml-checkpoints`)
+4. Leave Organization as "No organization" (unless you have one)
+5. Click **Create**
+6. Wait a few seconds, then select your new project from the dropdown
+
+### Step 3: Set Up Billing
+
+Google Cloud requires a billing account to use any services, even free ones.
+
+1. Click the hamburger menu (☰) in the top left
+2. Go to **Billing**
+3. Click **Link a billing account** (or **Create account** if you don't have one)
+
+If creating a new billing account:
+
+1. Click **Create billing account**
+2. Choose your country
+3. Enter your payment information (credit or debit card)
+4. Click **Submit and enable billing**
+
+Your project is now linked to your billing account.
+
+> **Note:** Google Cloud has a free tier. Small buckets with light usage cost almost nothing. You can also set up budget alerts to avoid surprises.
+
+### Step 4: Create a Storage Bucket
+
+1. Click the hamburger menu (☰) in the top left
+2. Scroll down and click **Cloud Storage** → **Buckets**
+3. Once you are redirected to the Buckets workspace, click **Create** at the top
+4. Enter a **globally unique** name (e.g., `yourname-ml-checkpoints-2025`)
+5. Choose a location type:
+   - **Region** (recommended): Cheapest option. Best for backups, large files, or when your compute runs in a single region. Pick one close to you (e.g., `us-central1`, `europe-west1`).
+   - **Multi-region**: ~20-30% more expensive. Better availability and lower latency for users spread globally. Choose this if you're serving content worldwide and need fast access.
+   - **Dual-region**: Most expensive. High availability between two specific regions. Rarely needed for most use cases.
+6. Choose a storage class for your data - select **Standard**. This is best for frequently accessed data
+7. Choose how to control access to objects - select **Uniform** access control and leave "Enforce public access prevention" checked
+8. Choose how to protect object data:
+   - **Soft delete**: Leave the default (7 days); it lets you recover accidentally deleted files
+   - **Object versioning**: Turn **ON**. This is what gives you a history of recent checkpoints: GCS keeps previous versions when files are overwritten. Set a maximum number of versions to store per object so the bucket's storage doesn't grow without bound, sizing the limit to how many checkpoint versions you want to keep. Leave "Expire noncurrent versions after" blank so old checkpoint versions are not deleted simply because of their age.
+9. Encryption - leave as **Google-managed encryption key** (default)
+10. Click **Create**. If prompted, leave "enforce public access prevention" **on**
+
+### Step 5: Verify Your Bucket
+
+1. You should see your new bucket in the list
+2. Click on the bucket name to open it
+3. You can now upload files using the **Upload files** button
+
+### Step 6: Grant Storage Access to Users
+
+To let users push checkpoints to the bucket during a training run, grant them bucket-level permissions.
+
+1. Go to **Cloud Storage** → **Buckets**
+2. Click on your bucket name to open it
+3. Click the **Permissions** tab
+4. Click **Grant Access**
+5. In the **New principals** field, enter the user's email address (e.g., `someone@gmail.com`)
+6. Click **Select a role** and choose **Cloud Storage** → **Storage Object User**. This grants read, list, create, and overwrite access to objects, but not delete
+7. Click **Save**
+
+The user can now authenticate using the gcloud CLI. If you don't have it installed, follow the [installation guide](https://cloud.google.com/sdk/docs/install).
+
+```bash
+gcloud auth application-default login
+```
+
+or
+
+```bash
+gcloud auth application-default login --scopes="https://www.googleapis.com/auth/cloud-platform"
+```
+
+### Useful Links
+
+- [Google Cloud Console](https://console.cloud.google.com)
+- [Cloud Storage Documentation](https://cloud.google.com/storage/docs)
+- [Pricing Calculator](https://cloud.google.com/products/calculator)
+- [Free Tier Details](https://cloud.google.com/free)
+
+---
+
+## Checkpoint Manifest
+
+The `manifest.json` file is a metadata document uploaded to GCS alongside checkpoint files. It serves as an atomic, versioned index of checkpoint files that enables reliable and efficient checkpoint downloads.
+
+### File Location
+
+```
+gs://{bucket}/{prefix}/manifest.json
+```
+
+Or without prefix:
+
+```
+gs://{bucket}/manifest.json
+```
+
+### Schema
+
+```json
+{
+  "metadata": {
+    "timestamp": "2024-01-15T10:30:00Z",
+    "epoch": 5,
+    "step": 12500,
+    "run_id": "my-training-run"
+  },
+  "files": [
+    {
+      "filename": "model-00001-of-00004.safetensors",
+      "generation": 1705312200123456,
+      "size_bytes": 536870912
+    },
+    {
+      "filename": "model-00002-of-00004.safetensors",
+      "generation": 1705312205654321,
+      "size_bytes": 536870912
+    }
+  ]
+}
+```
+
+| Field                | Description                                 |
+| -------------------- | ------------------------------------------- |
+| `metadata.timestamp` | ISO 8601 timestamp of upload                |
+| `metadata.epoch`     | Training epoch number                       |
+| `metadata.step`      | Training step number                        |
+| `metadata.run_id`    | Unique identifier for the training run      |
+| `files[].filename`   | Name of the checkpoint file                 |
+| `files[].generation` | GCS object generation number for versioning |
+| `files[].size_bytes` | File size in bytes                          |
+
+### Why Use a Manifest?
+
+#### 1. Atomic Checkpoint Consistency
+
+Listing a bucket is not atomic with respect to a multi-file upload. When a checkpoint consists of multiple safetensors shards (e.g., `model-00001-of-00004.safetensors` through `model-00004-of-00004.safetensors`), a listing taken while an upload is in progress can return a mix of files from different checkpoint versions.
+
+The manifest is uploaded **after** all safetensors files are successfully uploaded. It acts as an atomic marker indicating that all files for a checkpoint are available. Downloaders read the manifest first to get the exact list of files to fetch.
+
+#### 2. GCS Object Versioning with Generation Numbers
+
+GCS uses "generation numbers" to version objects. When a file is overwritten, it gets a new generation number. Without tracking generations, a client might download file A from checkpoint v1 and file B from checkpoint v2 if uploads overlap.
+
+The manifest records the exact `generation` number for each file at upload time.
During download, the client requests files with their specific generation numbers, ensuring all files belong to the same checkpoint version.
+
+```rust
+// Upload: Record generation after each file upload
+manifest.files.push(ManifestFileEntry {
+    filename: file_name.to_string(),
+    generation: uploaded.generation, // GCS returns this after upload
+    size_bytes: size,
+});
+
+// Download: Request specific generation
+client.download_object(&GetObjectRequest {
+    generation: Some(file_entry.generation), // Pin to exact version
+    ..
+})
+```
+
+### Upload Flow
+
+1. Client completes a training epoch and saves checkpoint locally
+2. `upload_to_gcs()` is called with local file paths
+3. For each `.safetensors` file:
+   - Upload to GCS
+   - Capture the returned `generation` number
+   - Add entry to manifest's `files` array
+4. Serialize manifest to JSON and upload to `{prefix}/manifest.json`
+5. Notify coordinator via channel with `Checkpoint::Gcs(GcsRepo { bucket, prefix })`
+
+### Download Flow
+
+1. Client receives `Checkpoint::Gcs` from coordinator
+2. Fetch `manifest.json` metadata to get its generation number
+3. Download and parse manifest JSON
+4. Compute cache directory: `step-{step}-{manifest_generation}`
+5. Check if all files exist in cache (a standalone sketch of this check appears below)
+   - **Cache hit:** Return cached file paths immediately
+   - **Cache miss:** Download each file using its recorded generation number
+6. Also download config files (`.json`, `.py`) that aren't in the manifest
+7. Return list of local file paths
+
+### Files Tracked
+
+The manifest tracks only `.safetensors` files (model weights). Config files (`.json`, `.py`) are downloaded separately via bucket listing because:
+
+- They're small and change infrequently
+- They may be shared across checkpoints
+- They don't have the same consistency concerns as sharded model weights
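+
+### Verifying a Cached Checkpoint
+
+To make the cache-hit check in the download flow concrete, here is a minimal standalone sketch. It mirrors the manifest schema with local serde types rather than importing the real `GcsCheckpointManifest` from `shared/data-provider/src/gcs.rs`, and it additionally compares file sizes; the in-tree `collect_cached_files` only checks that each file exists. The `manifest.json` path and the size comparison are illustrative assumptions, not the shipped behavior.
+
+```rust
+use serde::Deserialize;
+use std::path::Path;
+
+// Local mirror of the manifest schema documented above. The real types
+// live in shared/data-provider/src/gcs.rs; the timestamp is kept as a
+// plain string here so the sketch needs only serde and serde_json.
+#[derive(Debug, Deserialize)]
+struct Manifest {
+    metadata: Metadata,
+    files: Vec<FileEntry>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Metadata {
+    timestamp: String,
+    epoch: u32,
+    step: u32,
+    run_id: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct FileEntry {
+    filename: String,
+    generation: i64,
+    size_bytes: u64,
+}
+
+// All-or-nothing check: every file listed in the manifest must exist in
+// `cache_dir` and (an extra sanity check added for this sketch) match
+// its recorded size.
+fn checkpoint_is_complete(cache_dir: &Path, manifest: &Manifest) -> bool {
+    manifest.files.iter().all(|entry| {
+        cache_dir
+            .join(&entry.filename)
+            .metadata()
+            .map(|meta| meta.len() == entry.size_bytes)
+            .unwrap_or(false)
+    })
+}
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // Assumes manifest.json was already fetched into the current
+    // directory, e.g. with
+    // `gcloud storage cp gs://BUCKET/PREFIX/manifest.json .`
+    let manifest: Manifest = serde_json::from_slice(&std::fs::read("manifest.json")?)?;
+    println!(
+        "run {} step {} epoch {}: {} files, written {}",
+        manifest.metadata.run_id,
+        manifest.metadata.step,
+        manifest.metadata.epoch,
+        manifest.files.len(),
+        manifest.metadata.timestamp,
+    );
+    println!(
+        "checkpoint complete in cwd: {}",
+        checkpoint_is_complete(Path::new("."), &manifest)
+    );
+    Ok(())
+}
+```
+
+Because the cache directory name embeds both the step and the manifest's own GCS generation (`step-{step}-{manifest_generation}`), two different uploads of the same step land in different cache directories and can never be confused with each other.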