diff --git a/Cargo.lock b/Cargo.lock index 85f3b06da..dbcbedf81 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7015,6 +7015,7 @@ dependencies = [ "anyhow", "async-trait", "bytemuck", + "chrono", "clap", "futures", "google-cloud-storage", @@ -7101,6 +7102,7 @@ dependencies = [ "tokio-util 0.7.16", "torch-sys", "tracing", + "tracing-subscriber", ] [[package]] diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 79034e814..9a82f2ed1 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -2,6 +2,8 @@ use anyhow::{Error, Result}; use bytemuck::Zeroable; use hf_hub::Repo; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; +use psyche_client::HubUploadInfo; +use psyche_client::UploadInfo; use psyche_client::{ Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, }; @@ -29,7 +31,7 @@ pub type TabsData = ::Data; pub enum ToSend { Witness(Box), HealthCheck(HealthChecks), - Checkpoint(model::HubRepo), + Checkpoint(model::Checkpoint), } struct Backend { @@ -67,7 +69,7 @@ impl WatcherBackend for Backend { Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> { self.tx.send(ToSend::Checkpoint(checkpoint))?; Ok(()) } @@ -176,18 +178,19 @@ impl App { ) -> Result<()> { // sanity checks if let Some(checkpoint_config) = &state_options.checkpoint_config { - if let Some(hub_upload) = &checkpoint_config.hub_upload { + if let Some(UploadInfo::Hub(HubUploadInfo { + hub_repo, + hub_token, + })) = &checkpoint_config.upload_info + { let api = hf_hub::api::tokio::ApiBuilder::new() - .with_token(Some(hub_upload.hub_token.clone())) + .with_token(Some(hub_token.clone())) .build()?; - let repo_api = api.repo(Repo::new( - hub_upload.hub_repo.clone(), - hf_hub::RepoType::Model, - )); + let repo_api = api.repo(Repo::new(hub_repo.clone(), hf_hub::RepoType::Model)); if !repo_api.is_writable().await { anyhow::bail!( "Checkpoint upload repo {} is not writable with the passed API key.", - hub_upload.hub_repo + hub_repo ) } } diff --git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index bc034d1db..402146817 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -9,7 +9,8 @@ use psyche_coordinator::{ use psyche_core::{FixedVec, Shuffle, SizedIterator, TokenSize}; use psyche_data_provider::{ - DataProviderTcpServer, DataServerTui, LocalDataProvider, download_model_repo_async, + DataProviderTcpServer, DataServerTui, LocalDataProvider, download_model_from_gcs_async, + download_model_repo_async, }; use psyche_network::{ClientNotification, TcpServer}; use psyche_tui::{ @@ -80,7 +81,7 @@ impl psyche_watcher::Backend for ChannelCoordinatorBackend { bail!("Server does not send health checks"); } - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> Result<()> { bail!("Server does not send checkpoints"); } } @@ -201,9 +202,15 @@ impl App { Checkpoint::Dummy(_) => { // ok! 
                }
-                Checkpoint::P2P(_) => {
+                Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => {
                     bail!("Can't start up a run with a P2P checkpoint.")
                 }
+                Checkpoint::Gcs(gcs_repo) => {
+                    let bucket: String = (&gcs_repo.bucket).into();
+                    let prefix: Option<String> =
+                        gcs_repo.prefix.map(|p| (&p).into());
+                    download_model_from_gcs_async(&bucket, prefix.as_deref()).await?;
+                }
             }

             let server_addr: SocketAddr = String::from(url).parse().map_err(|e| {
diff --git a/architectures/centralized/shared/src/protocol.rs b/architectures/centralized/shared/src/protocol.rs
index c71c7524f..a96c64b8f 100644
--- a/architectures/centralized/shared/src/protocol.rs
+++ b/architectures/centralized/shared/src/protocol.rs
@@ -15,7 +15,7 @@ pub enum ClientToServerMessage {
     Join { run_id: String },
     Witness(Box),
     HealthCheck(HealthChecks),
-    Checkpoint(model::HubRepo),
+    Checkpoint(model::Checkpoint),
 }

 #[derive(Serialize, Deserialize, Debug, Clone)]
diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs
index 4c4849c19..99d234b10 100644
--- a/architectures/decentralized/solana-client/src/main.rs
+++ b/architectures/decentralized/solana-client/src/main.rs
@@ -283,34 +283,50 @@ async fn async_main() -> Result<()> {
                 bail!("Model is not an LLM, unsure how to predownload.");
             };

-            let checkpoint = match model_config.checkpoint {
+            match model_config.checkpoint {
                 Checkpoint::Ephemeral => {
                     bail!("Can't predownload model with ephemeral checkpoint.")
                 }
                 Checkpoint::Dummy(hub_repo)
                 | Checkpoint::Hub(hub_repo)
-                | Checkpoint::P2P(hub_repo) => hub_repo,
+                | Checkpoint::P2P(hub_repo) => {
+                    let repo_id = hub_repo.repo_id.to_string();
+                    let revision = hub_repo.revision.map(|s| s.to_string());
+                    println!(
+                        "Predownloading model {repo_id} revision {}",
+                        revision.as_ref().unwrap_or(&"main".to_string())
+                    );
+
+                    let hub_read_token = std::env::var("HF_TOKEN").ok();
+                    let cache_folder = None; // Uses HF_HOME env var
+
+                    psyche_data_provider::download_model_repo_async(
+                        &repo_id,
+                        revision,
+                        cache_folder,
+                        hub_read_token,
+                        Some(hub_max_concurrent_downloads),
+                        true,
+                    )
+                    .await?;
+                }
+                Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => {
+                    let bucket = gcs_repo.bucket.to_string();
+                    let prefix: Option<String> = gcs_repo.prefix.map(|p| p.to_string());
+                    println!(
+                        "Predownloading model from gs://{}/{}",
+                        bucket,
+                        prefix.as_deref().unwrap_or("")
+                    );
+
+                    psyche_data_provider::download_model_from_gcs_async(
+                        &bucket,
+                        prefix.as_deref(),
+                    )
+                    .await?;
+                }
             };
-            let repo_id = checkpoint.repo_id.to_string();
-            let revision = checkpoint.revision.map(|s| s.to_string());
-            println!(
-                "Predownloading model {repo_id} revision {}",
-                revision.as_ref().unwrap_or(&"main".to_string())
-            );
-
-            let hub_read_token = std::env::var("HF_TOKEN").ok();
-            let cache_folder = None; // Uses HF_HOME env var
-
-            psyche_data_provider::download_model_repo_async(
-                &repo_id,
-                revision,
-                cache_folder,
-                hub_read_token,
-                Some(hub_max_concurrent_downloads),
-                true,
-            )
-            .await?;
             println!("Model predownloaded successfully.");
         }
diff --git a/architectures/decentralized/solana-common/src/backend.rs b/architectures/decentralized/solana-common/src/backend.rs
index b33d7f4c6..47a2b7dc9 100644
--- a/architectures/decentralized/solana-common/src/backend.rs
+++ b/architectures/decentralized/solana-common/src/backend.rs
@@ -19,7 +19,8 @@ use anchor_client::{
 };
 use anyhow::{Context, Result, anyhow};
 use futures_util::StreamExt;
-use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks, model::HubRepo};
+use
psyche_coordinator::model::{self, Checkpoint}; +use psyche_coordinator::{CommitteeProof, Coordinator, HealthChecks}; use psyche_core::IntegrationTestLogMarker; use psyche_watcher::{Backend as WatcherBackend, OpportunisticData}; use solana_account_decoder_client_types::{UiAccount, UiAccountEncoding}; @@ -334,7 +335,7 @@ impl SolanaBackend { &self, coordinator_instance: Pubkey, coordinator_account: Pubkey, - repo: HubRepo, + repo: Checkpoint, ) { let user = self.get_payer(); let instruction = instructions::coordinator_checkpoint( @@ -604,7 +605,7 @@ impl WatcherBackend for SolanaBackendRunner Ok(()) } - async fn send_checkpoint(&mut self, checkpoint: HubRepo) -> Result<()> { + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()> { self.backend .send_checkpoint(self.instance, self.account, checkpoint); Ok(()) diff --git a/architectures/decentralized/solana-common/src/instructions.rs b/architectures/decentralized/solana-common/src/instructions.rs index 68e169008..b535d54f4 100644 --- a/architectures/decentralized/solana-common/src/instructions.rs +++ b/architectures/decentralized/solana-common/src/instructions.rs @@ -206,7 +206,7 @@ pub fn coordinator_checkpoint( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, - repo: psyche_coordinator::model::HubRepo, + repo: psyche_coordinator::model::Checkpoint, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index 029b40fad..8d9ddb977 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -10,7 +10,7 @@ use psyche_coordinator::RunState; use psyche_coordinator::SOLANA_MAX_STRING_LEN; use psyche_coordinator::TickResult; use psyche_coordinator::Witness; -use psyche_coordinator::model::HubRepo; +use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::Model; use psyche_core::FixedString; use psyche_core::SmallBoolean; @@ -389,7 +389,11 @@ impl CoordinatorInstanceState { self.tick() } - pub fn checkpoint(&mut self, payer: &Pubkey, repo: HubRepo) -> Result<()> { + pub fn checkpoint( + &mut self, + payer: &Pubkey, + repo: Checkpoint, + ) -> Result<()> { // O(n) on clients, reconsider let id = self.clients_state.find_signer(payer)?; let index = self diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 657424195..0a041a6e9 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -21,7 +21,7 @@ use psyche_coordinator::Witness; use psyche_coordinator::WitnessBloom; use psyche_coordinator::WitnessMetadata; use psyche_coordinator::WitnessProof; -use psyche_coordinator::model::{HubRepo, Model}; +use psyche_coordinator::model::Model; use psyche_core::MerkleRoot; use serde::Deserialize; use serde::Serialize; @@ -313,7 +313,7 @@ pub mod psyche_solana_coordinator { pub fn checkpoint( ctx: Context, - repo: HubRepo, + repo: psyche_coordinator::model::Checkpoint, ) -> Result<()> { let mut account = ctx.accounts.coordinator_account.load_mut()?; 
account.increment_nonce(); diff --git a/architectures/decentralized/solana-mining-pool/Cargo.lock b/architectures/decentralized/solana-mining-pool/Cargo.lock index 06fb31df5..225d03bf9 100644 --- a/architectures/decentralized/solana-mining-pool/Cargo.lock +++ b/architectures/decentralized/solana-mining-pool/Cargo.lock @@ -1389,7 +1389,7 @@ dependencies = [ [[package]] name = "psyche-solana-mining-pool" -version = "0.1.1" +version = "0.2.0" dependencies = [ "anchor-lang", "anchor-spl", diff --git a/config/solana-test/light-config-gcs.toml b/config/solana-test/light-config-gcs.toml new file mode 100644 index 000000000..d96fd468f --- /dev/null +++ b/config/solana-test/light-config-gcs.toml @@ -0,0 +1,44 @@ +[config] +warmup_time = 30 +cooldown_time = 30 +epoch_time = 60 +max_round_train_time = 15 +round_witness_time = 1 +min_clients = 1 +init_min_clients = 1 +verification_percent = 0 +witness_nodes = 0 +global_batch_size_start = 8 +global_batch_size_end = 8 +global_batch_size_warmup_tokens = 0 +total_steps = 25000 +waiting_for_members_extra_time = 3 + +[model.LLM] +architecture = "HfLlama" +data_type = "Pretraining" +max_seq_len = 2048 +cold_start_warmup_steps = 0 + +[model.LLM.checkpoint.Gcs] +bucket = "llama220minit" + +[model.LLM.data_location.Http] +token_size_in_bytes = "TwoBytes" +shuffle = "DontShuffle" +[model.LLM.data_location.Http.location.Gcp] +bucket_name = "nous-pretraining-public-us" +filter_directory = "fineweb-edu-tokenized-llama2" + +[model.LLM.lr_schedule.Cosine] +base_lr = 4.0e-4 +warmup_steps = 250 +warmup_init_lr = 0.0 +total_steps = 25000 +final_lr = 4.0e-5 +[model.LLM.optimizer.Distro] +clip_grad_norm = 1.0 +compression_decay = 0.999 +compression_chunk = 64 +compression_topk = 8 +quantize_1bit = true diff --git a/nix/lib.nix b/nix/lib.nix index b0f8107f4..bd55b28f1 100644 --- a/nix/lib.nix +++ b/nix/lib.nix @@ -38,6 +38,7 @@ let python312 pkg-config perl + cargo-nextest ]; buildInputs = diff --git a/psyche-book/src/SUMMARY.md b/psyche-book/src/SUMMARY.md index ace53fd84..3a767f171 100644 --- a/psyche-book/src/SUMMARY.md +++ b/psyche-book/src/SUMMARY.md @@ -14,6 +14,7 @@ - [General workflow](./explain/general-workflow.md) - [Data provider](./explain/data-provider.md) - [Model sharing](./explain/model-sharing.md) + - [GCS Checkpoints](./explain/gcs-checkpoints.md) - [Rewards](./explain/rewards.md) - [Glossary](./explain/glossary.md) diff --git a/psyche-book/src/explain/gcs-checkpoints.md b/psyche-book/src/explain/gcs-checkpoints.md new file mode 100644 index 000000000..e1f21e53f --- /dev/null +++ b/psyche-book/src/explain/gcs-checkpoints.md @@ -0,0 +1,213 @@ +# GCS Checkpoints + +Psyche supports uploading training checkpoints to Google Cloud Storage (GCS). This page covers how to set up a GCS bucket for checkpoints and explains the manifest system used to ensure reliable checkpoint downloads. + +## Setting Up a GCS Bucket + +### Prerequisites + +- A Google account +- A credit or debit card for billing purposes + +### Step 1: Go to Google Cloud Console + +1. Open your browser and go to [console.cloud.google.com](https://console.cloud.google.com) +2. Sign in with your Google account + +### Step 2: Create a Project + +A project is a container for all your Google Cloud resources. + +1. Click the project dropdown at the top of the page (it may say "Select a project" or show an existing project name) +2. Click **New Project** in the top right of the popup +3. Enter a **Project name** (e.g., `my-ml-checkpoints`) +4. 
Leave Organization as "No organization" (unless you have one)
5. Click **Create**
6. Wait a few seconds, then select your new project from the dropdown

### Step 3: Set Up Billing

Google Cloud requires a billing account to use any services, even free ones.

1. Click the hamburger menu (☰) in the top left
2. Go to **Billing**
3. Click **Link a billing account** (or **Create account** if you don't have one)

If creating a new billing account:

1. Click **Create billing account**
2. Choose your country
3. Enter your payment information (credit or debit card)
4. Click **Submit and enable billing**

Your project is now linked to your billing account.

> **Note:** Google Cloud has a free tier. Small buckets with light usage cost almost nothing. You can also set up budget alerts to avoid surprises.

### Step 4: Create a Storage Bucket

1. Click the hamburger menu (☰) in the top left
2. Scroll down and click **Cloud Storage** → **Buckets**
3. Once you are redirected to the Buckets workspace, click **Create** at the top
4. Enter a **globally unique** name (e.g., `yourname-ml-checkpoints-2025`)
5. Choose a location type:
   - **Region** (recommended): Cheapest option. Best for backups, large files, or when your compute runs in a single region. Pick one close to you (e.g., `us-central1`, `europe-west1`).
   - **Multi-region**: ~20-30% more expensive. Better availability and lower latency for users spread globally. Choose this if you're serving content worldwide and need fast access.
   - **Dual-region**: Most expensive. High availability between two specific regions. Rarely needed for most use cases.
6. Choose how to store your data: select the **Standard** storage class, which is best for frequently accessed data
7. Choose how to control access: select **Uniform** access control and leave "Enforce public access prevention" checked
8. Choose how to protect object data:
   - **Soft delete**: Leave default (7 days) — lets you recover accidentally deleted files
   - **Object versioning**: Turn **ON** — this is important for keeping a history of recent checkpoints, since it retains previous versions when files are overwritten. Select a maximum number of versions to store per object; this caps the bucket's storage so it doesn't grow without bound, so pick a value based on how many checkpoints you want to retain. Leave "Expire noncurrent versions after" blank so that old checkpoint versions are not deleted after a fixed amount of time.
9. Encryption: leave as **Google-managed encryption key** (default)
10. Click **Create**. If prompted, leave "enforce public access prevention" **on**

### Step 5: Verify Your Bucket

1. You should see your new bucket in the list
2. Click on the bucket name to open it
3. You can now upload files using the **Upload files** button

### Step 6: Grant Storage Access to Users

To let users push checkpoints to the bucket during a training run, grant them bucket-level permissions.

1. Go to **Cloud Storage** → **Buckets**
2. Click on your bucket name to open it
3. Click the **Permissions** tab
4. Click **Grant Access**
5. In the **New principals** field, enter the user's Google account email address (e.g., `someone@gmail.com`)
6. Click **Select a role** and choose **Cloud Storage** → **Storage Object User**. This allows reading, listing, creating, and overwriting objects, but not deleting them
7. Click **Save**
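If you prefer the command line, the same setup can be sketched with the gcloud CLI. This is a sketch under assumptions: the bucket name, region, and email address below are placeholders, and the commands mirror the console steps above:

```bash
# Create the bucket with uniform access control (bucket names are globally unique)
gcloud storage buckets create gs://yourname-ml-checkpoints-2025 \
    --location=us-central1 \
    --default-storage-class=STANDARD \
    --uniform-bucket-level-access

# Keep previous versions when checkpoint files are overwritten
gcloud storage buckets update gs://yourname-ml-checkpoints-2025 --versioning

# Grant a user the Storage Object User role on the bucket
gcloud storage buckets add-iam-policy-binding gs://yourname-ml-checkpoints-2025 \
    --member=user:someone@gmail.com \
    --role=roles/storage.objectUser
```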
The user can now authenticate using the gcloud CLI. If you don't have it installed, follow the [installation guide](https://cloud.google.com/sdk/docs/install).

```bash
gcloud auth application-default login
```

or

```bash
gcloud auth application-default login --scopes="https://www.googleapis.com/auth/cloud-platform"
```

### Useful Links

- [Google Cloud Console](https://console.cloud.google.com)
- [Cloud Storage Documentation](https://cloud.google.com/storage/docs)
- [Pricing Calculator](https://cloud.google.com/products/calculator)
- [Free Tier Details](https://cloud.google.com/free)

---

## Checkpoint Manifest

The `manifest.json` file is a metadata document uploaded to GCS alongside the checkpoint files. It serves as an atomic, versioned index of checkpoint files that enables reliable and efficient checkpoint downloads.

### File Location

```
gs://{bucket}/{prefix}/manifest.json
```

Or without a prefix:

```
gs://{bucket}/manifest.json
```

### Schema

```json
{
  "metadata": {
    "timestamp": "2024-01-15T10:30:00Z",
    "epoch": 5,
    "step": 12500,
    "run_id": "my-training-run"
  },
  "files": [
    {
      "filename": "model-00001-of-00004.safetensors",
      "generation": 1705312200123456,
      "size_bytes": 536870912
    },
    {
      "filename": "model-00002-of-00004.safetensors",
      "generation": 1705312205654321,
      "size_bytes": 536870912
    }
  ]
}
```

| Field                | Description                                 |
| -------------------- | ------------------------------------------- |
| `metadata.timestamp` | ISO 8601 timestamp of upload                |
| `metadata.epoch`     | Training epoch number                       |
| `metadata.step`      | Training step number                        |
| `metadata.run_id`    | Unique identifier for the training run      |
| `files[].filename`   | Name of the checkpoint file                 |
| `files[].generation` | GCS object generation number for versioning |
| `files[].size_bytes` | File size in bytes                          |
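For reference, these fields map one-to-one onto the serde types this change adds in `shared/data-provider/src/gcs.rs`:

```rust
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

/// Checkpoint manifest.json uploaded to GCS alongside safetensors files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GcsCheckpointManifest {
    pub metadata: ManifestMetadata,
    pub files: Vec<ManifestFileEntry>,
}

/// Checkpoint metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManifestMetadata {
    pub timestamp: DateTime<Utc>, // serialized as an ISO 8601 string
    pub epoch: u32,
    pub step: u32,
    pub run_id: String,
}

/// Single file entry in the manifest.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ManifestFileEntry {
    pub filename: String,
    pub generation: i64,
    pub size_bytes: u64,
}
```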
### Why Use a Manifest?

#### 1. Atomic Checkpoint Consistency

A bucket listing only reflects whatever objects exist at the moment it runs. When a checkpoint consists of multiple safetensors shards (e.g., `model-00001-of-00004.safetensors` through `model-00004-of-00004.safetensors`), a listing taken while an upload is in progress can return a mix of files from different checkpoint versions.

The manifest is uploaded **after** all safetensors files are successfully uploaded. It acts as an atomic marker indicating that all files for a checkpoint are available. Downloaders read the manifest first to get the exact list of files to fetch.

#### 2. GCS Object Versioning with Generation Numbers

GCS uses "generation numbers" to version objects. When a file is overwritten, it gets a new generation number. Without tracking generations, a client might download file A from checkpoint v1 and file B from checkpoint v2 if uploads overlap.

The manifest records the exact `generation` number for each file at upload time. During download, the client requests files with their specific generation numbers, ensuring all files belong to the same checkpoint version.

```rust
// Upload: Record generation after each file upload
manifest.files.push(ManifestFileEntry {
    filename: file_name.to_string(),
    generation: uploaded.generation, // GCS returns this after upload
    size_bytes: size,
});

// Download: Request specific generation
client.download_object(&GetObjectRequest {
    generation: Some(file_entry.generation), // Pin to exact version
    ..Default::default()
})
```

### Upload Flow

1. Client completes a training epoch and saves the checkpoint locally
2. `upload_to_gcs()` is called with the local file paths
3. For each `.safetensors` file:
   - Upload it to GCS
   - Capture the returned `generation` number
   - Add an entry to the manifest's `files` array
4. Serialize the manifest to JSON and upload it to `{prefix}/manifest.json`
5. Notify the coordinator via channel with `Checkpoint::Gcs(GcsRepo { bucket, prefix })`

### Download Flow

1. Client receives `Checkpoint::Gcs` from the coordinator
2. Fetch the `manifest.json` metadata to get its generation number
3. Download and parse the manifest JSON
4. Compute the cache directory: `step-{step}-{manifest_generation}`
5. Check whether all files exist in the cache
   - **Cache hit:** Return the cached file paths immediately
   - **Cache miss:** Download each file using its recorded generation number
6. Also download config files (`.json`, `.py`) that aren't in the manifest
7. Return the list of local file paths

### Files Tracked

The manifest only tracks `.safetensors` files (model weights). Config files (`.json`, `.py`) are downloaded separately via bucket listing because:

- They're small and change infrequently
- They may be shared across checkpoints
- They don't have the same consistency concerns as sharded model weights
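To inspect a live checkpoint by hand, the manifest and the object versions it pins can be read with the gcloud CLI (bucket and prefix below are placeholders):

```bash
# Print the current manifest
gcloud storage cat gs://my-checkpoint-bucket/my-run/manifest.json

# List every generation of a shard, including noncurrent versions
gcloud storage ls --all-versions gs://my-checkpoint-bucket/my-run/model-00001-of-00004.safetensors
```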
diff --git a/shared/client/src/cli.rs b/shared/client/src/cli.rs
index 268ea753a..a4ef145f0 100644
--- a/shared/client/src/cli.rs
+++ b/shared/client/src/cli.rs
@@ -1,7 +1,9 @@
-use crate::{CheckpointConfig, HubUploadInfo, WandBInfo};
+use crate::{CheckpointConfig, WandBInfo};
+use crate::UploadInfo;
 use anyhow::{Result, anyhow, bail};
 use clap::Args;
+use psyche_data_provider::{GcsUploadInfo, HubUploadInfo};
 use psyche_eval::tasktype_from_name;
 use psyche_modeling::Devices;
 use psyche_network::{DiscoveryMode, RelayKind, SecretKey};
@@ -146,6 +148,14 @@ pub struct TrainArgs {
     #[clap(long, env)]
     pub hub_repo: Option<String>,

+    /// Name of the GCS bucket containing model data and configuration.
+    #[clap(long, env)]
+    pub gcs_bucket: Option<String>,
+
+    /// Prefix within the GCS bucket for model data and configuration.
+    #[clap(long, env)]
+    pub gcs_prefix: Option<String>,
+
     #[clap(long, env, default_value_t = 3)]
     pub hub_max_concurrent_downloads: usize,

@@ -224,43 +234,72 @@ impl TrainArgs {
     pub fn checkpoint_config(&self) -> Result<Option<CheckpointConfig>> {
         let hub_read_token = std::env::var("HF_TOKEN").ok();
-        let checkpoint_upload_info = match (
-            &hub_read_token,
-            self.hub_repo.clone(),
-            self.checkpoint_dir.clone(),
-            self.delete_old_steps,
-            self.keep_steps,
-        ) {
-            (Some(token), Some(repo), Some(dir), delete_old_steps, keep_steps) => {
-                if keep_steps == 0 {
-                    bail!("keep_steps must be >= 1 for hub repository uploads (got {keep_steps})")
+
+        if self.hub_repo.is_some() && self.gcs_bucket.is_some() {
+            bail!("Use either GCS or HF hub for checkpoint uploads, not both.");
+        }
+
+        let checkpoint_dir = match &self.checkpoint_dir {
+            Some(dir) => dir,
+            None => {
+                if self.hub_repo.is_some() || self.gcs_bucket.is_some() {
+                    bail!(
+                        "--hub-repo or --gcs-bucket was set, but no --checkpoint-dir was passed!"
+                    );
                 }
-                Some(CheckpointConfig {
-                    checkpoint_dir: dir,
-                    hub_upload: Some(HubUploadInfo {
-                        hub_repo: repo,
-                        hub_token: token.to_string(),
-                    }),
-                    delete_old_steps,
-                    keep_steps,
-                })
-            }
-            (None, Some(_), Some(_), _, _) => {
-                bail!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
-            }
-            (_, Some(_), None, _, _) => {
-                bail!("--hub-repo was set, but no --checkpoint-dir was passed!"
+                return Ok(None);
             }
-            (_, None, Some(dir), delete_old_steps, keep_steps) => Some(CheckpointConfig {
-                checkpoint_dir: dir,
-                hub_upload: None,
-                delete_old_steps,
-                keep_steps,
-            }),
-            (_, None, _, _, _) => None,
         };

-        Ok(checkpoint_upload_info)
+        let upload_info = self.build_upload_info(&hub_read_token)?;
+
+        if upload_info.is_some() && self.keep_steps == 0 {
+            bail!(
+                "keep_steps must be >= 1 for checkpoint uploads (got {})",
+                self.keep_steps
+            );
+        }
+
+        Ok(Some(CheckpointConfig {
+            checkpoint_dir: checkpoint_dir.clone(),
+            upload_info,
+            delete_old_steps: self.delete_old_steps,
+            keep_steps: self.keep_steps,
+        }))
+    }
+
+    fn build_upload_info(&self, hub_token: &Option<String>) -> Result<Option<UploadInfo>> {
+        if let Some(repo) = &self.hub_repo {
+            return self.build_hub_upload_info(repo, hub_token);
+        }
+
+        if let Some(bucket) = &self.gcs_bucket {
+            return self.build_gcs_upload_info(bucket);
+        }
+
+        Ok(None)
+    }
+
+    fn build_hub_upload_info(
+        &self,
+        repo: &str,
+        token: &Option<String>,
+    ) -> Result<Option<UploadInfo>> {
+        let token = token.as_ref().ok_or_else(|| {
+            anyhow::anyhow!("hub-repo and checkpoint-dir set, but no HF_TOKEN env variable.")
+        })?;
+
+        Ok(Some(UploadInfo::Hub(HubUploadInfo {
+            hub_repo: repo.to_string(),
+            hub_token: token.to_string(),
+        })))
+    }
+
+    fn build_gcs_upload_info(&self, bucket: &str) -> Result<Option<UploadInfo>> {
+        Ok(Some(UploadInfo::Gcs(GcsUploadInfo {
+            gcs_bucket: bucket.to_string(),
+            gcs_prefix: self.gcs_prefix.clone(),
+        })))
     }

     pub fn eval_tasks(&self) -> Result> {
diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs
index b4ea80576..bdad43e30 100644
--- a/shared/client/src/lib.rs
+++ b/shared/client/src/lib.rs
@@ -9,7 +9,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity
 pub use client::Client;
 pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult};
 pub use state::{
-    CheckpointConfig, HubUploadInfo, InitRunError, RoundState, RunInitConfig, RunInitConfigAndIO,
+    CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig,
+    RunInitConfigAndIO, UploadInfo,
 };
 pub use tui::{ClientTUI, ClientTUIState};
diff --git a/shared/client/src/state/cooldown.rs b/shared/client/src/state/cooldown.rs
index da1b2569a..4a3ce6a82 100644
--- a/shared/client/src/state/cooldown.rs
+++ b/shared/client/src/state/cooldown.rs
@@ -1,14 +1,14 @@
-use crate::HubUploadInfo;
-
+use crate::UploadInfo;
 use psyche_coordinator::{
     Coordinator,
-    model::{self, HubRepo},
+    model::{self},
 };
-use psyche_core::{FixedString, NodeIdentity};
-use psyche_data_provider::{UploadModelError, upload_model_repo_async};
+use psyche_core::NodeIdentity;
+use psyche_data_provider::{GcsManifestMetadata, UploadError, upload_to_gcs, upload_to_hub};
+#[cfg(feature = "python")]
+use psyche_modeling::CausalLM;
 use psyche_modeling::{
-    CausalLM, SaveSafetensorsError, Trainer, TrainerThreadCommunicationError,
-    save_tensors_into_safetensors,
+    SaveSafetensorsError, Trainer, TrainerThreadCommunicationError, save_tensors_into_safetensors,
 };
 use std::{
     cmp::Reverse,
@@ -22,7 +22,7 @@ use tokio::{
     sync::{Mutex, mpsc},
     task::JoinHandle,
 };
-use tracing::{Instrument, error, info,
info_span, warn};
+use tracing::{Instrument, info, info_span, warn};

 use super::{
     CheckpointConfig,
@@ -42,7 +42,7 @@ pub enum CooldownError {
 }

 pub struct CooldownStepMetadata {
-    tx_checkpoint: mpsc::UnboundedSender<HubRepo>,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
     tx_model: mpsc::UnboundedSender>,
     checkpoint_info: Option<CheckpointConfig>,
     checkpoint_extra_files: Vec<PathBuf>,
@@ -59,7 +59,7 @@ pub struct CooldownStepMetadata {
 impl CooldownStepMetadata {
     pub fn new(
-        tx_checkpoint: mpsc::UnboundedSender<HubRepo>,
+        tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
         tx_model: mpsc::UnboundedSender>,
         checkpoint_info: Option<CheckpointConfig>,
         checkpoint_extra_files: Vec<PathBuf>,
@@ -93,8 +93,8 @@ pub enum CheckpointError {
     #[error("Writing extra file to disk failed: {0}")]
     WriteExtraFile(#[from] tokio::io::Error),

-    #[error("Couldn't upload model to huggingface: {0}")]
-    UploadError(#[from] UploadModelError),
+    #[error("Couldn't upload model to huggingface or GCS: {0}")]
+    UploadError(#[from] UploadError),

     #[error("Couldn't send checkpoint - channel closed")]
     SendCheckpoint,
@@ -137,6 +137,7 @@ impl CooldownStepMetadata {
         let step = state.progress.step - 1;
         let run_id = String::from(&state.run_id);
+        let epoch = state.progress.epoch as u32;
         let checkpoint_extra_files = self.checkpoint_extra_files.clone();
         let checkpoint_info = self.checkpoint_info.clone();
         let tx_checkpoint = self.tx_checkpoint.clone();
@@ -166,103 +167,51 @@ impl CooldownStepMetadata {
                 .send(variables_clone)
                 .map_err(|_| CheckpointError::SendCheckpoint)?;

-            let (variables, trainer) = if checkpoint_info.is_some() {
-                // convert from internal shape to serialized shape (e.g. torchtitan to hf)
-                tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
-                    .await
-                    .map_err(|_| CheckpointError::ExtractThreadCrashed)?
-            } else {
-                (variables, trainer)
+            // convert from internal shape to serialized shape (e.g. torchtitan to hf)
+            let (variables, trainer) = match trainer {
+                #[cfg(feature = "python")]
+                Trainer::PythonDistributed(_) => {
+                    info!("Converting distributed trainer variables for checkpointing...");
+                    tokio::task::spawn_blocking(|| (trainer.convert(Some(variables)), trainer))
+                        .await
+                        .map_err(|_| CheckpointError::ExtractThreadCrashed)?
+                }
+                _ => (variables, trainer),
             };

             trainers.push(trainer);
             let evals = model_task_runner.start(trainers);

             let Some(CheckpointConfig {
-                hub_upload,
+                upload_info,
                 checkpoint_dir,
                 delete_old_steps,
                 keep_steps,
             }) = checkpoint_info
             else {
-                // If there was no HF checkpointing configuration, return immediately
                 return Ok((evals, None));
             };

-            // Start the upload process of the updated model parameters in a separate task
             let upload_handle = tokio::task::spawn(async move {
                 let path = checkpoint_dir.join(format!("{run_id}-step{step}"));
-                info!("Saving to {}", path.display());
-                let mut local = tokio::task::spawn_blocking({
-                    let path = path.clone();
-                    move || save_tensors_into_safetensors(variables, path)
-                })
-                .await
-                .map_err(|_| CheckpointError::WriteThreadCrashed)??;
-
-                for extra in checkpoint_extra_files {
-                    let to = path.join(extra.file_name().unwrap());
-                    tokio::fs::copy(extra.clone(), to.clone())
-                        .await
-                        .map_err(CheckpointError::WriteExtraFile)?;
-                    local.push(to);
+                let local =
+                    save_checkpoint_locally(path, variables, checkpoint_extra_files).await?;
+
+                if let Some(upload_info) = upload_info {
+                    let manifest_metadata = GcsManifestMetadata {
+                        epoch,
+                        run_id: run_id.clone(),
+                    };
+                    upload_checkpoint(
+                        upload_info,
+                        manifest_metadata,
+                        local.clone(),
+                        step as u64,
+                        tx_checkpoint,
+                    )
+                    .await?;
                 }

-                let Some(HubUploadInfo {
-                    hub_repo,
-                    hub_token,
-                }) = hub_upload
-                else {
-                    cleanup_dirs(
-                        delete_queue,
-                        keep_steps,
-                        run_id,
-                        delete_old_steps,
-                        step,
-                        checkpoint_dir,
-                    )
-                    .await;
-                    return Ok::<(), CheckpointError>(());
-                };
-
-                info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");
-                let revision = match upload_model_repo_async(
-                    hub_repo.clone(),
-                    local,
-                    hub_token.clone(),
-                    Some(format!("step {step}")),
-                    None,
-                )
-                .await
-                {
-                    Ok(revision) => {
-                        info!(
-                            repo = hub_repo,
-                            revision = revision,
-                            "Upload to HuggingFace complete"
-                        );
-                        revision
-                    }
-                    Err(err) => {
-                        error!(repo = hub_repo, "Error uploading to HuggingFace: {err:#}");
-                        return Err(err.into());
-                    }
-                };
-
-                tx_checkpoint
-                    .send(HubRepo {
-                        repo_id: FixedString::from_str_truncated(&hub_repo),
-                        revision: Some(FixedString::from_str_truncated(&revision)),
-                    })
-                    .map_err(|_| CheckpointError::SendCheckpoint)?;
-
-                // we put the cleanup step at the end, so that if keep_steps == 0 the logic will still work
-                // we'll just delete the dir after we've uploaded it
-                // if we fail in any of the above steps we may wind up not queueing this dir for delete
-                // but that's probably better than risking having the dir deleted from under us
-                // for a relatively low priority disk cleanup task
-                // and this may actually be preferred anyway because if we failed to upload, we may want to keep
-                // the data around locally on disk
                 cleanup_dirs(
                     delete_queue,
                     keep_steps,
@@ -280,12 +229,56 @@ impl CooldownStepMetadata {
             }
             .instrument(info_span!("checkpointing")),
         );
+
         Ok(CooldownStep {
             checkpointing_and_evals,
         })
     }
 }

+async fn save_checkpoint_locally(
+    path: PathBuf,
+    variables: HashMap<String, Tensor>,
+    checkpoint_extra_files: Vec<PathBuf>,
+) -> Result<Vec<PathBuf>, CheckpointError> {
+    info!("Saving to {}", path.display());
+    let mut local = tokio::task::spawn_blocking({
+        let path = path.clone();
+        move || save_tensors_into_safetensors(variables, path)
+    })
+    .await
+    .map_err(|_| CheckpointError::WriteThreadCrashed)??;
+
+    for extra in checkpoint_extra_files {
+        let to = path.join(extra.file_name().unwrap());
+        tokio::fs::copy(extra.clone(), to.clone())
+            .await
+            .map_err(CheckpointError::WriteExtraFile)?;
+        local.push(to);
+    }
+
    Ok(local)
+}
+
+async fn upload_checkpoint(
+    upload_info: UploadInfo,
+    manifest_metadata: GcsManifestMetadata,
+    local: Vec<PathBuf>,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
+) -> Result<(), CheckpointError> {
+    match upload_info {
+        UploadInfo::Gcs(gcs_info) => {
+            upload_to_gcs(gcs_info, manifest_metadata, local, step, tx_checkpoint)
+                .await
+                .map_err(CheckpointError::UploadError)
+        }
+        UploadInfo::Hub(hub_info) => upload_to_hub(hub_info, local, step, tx_checkpoint)
+            .await
+            .map_err(CheckpointError::UploadError),
+    }
+}
+
 type CheckpointAndEvalsHandle = JoinHandle<
     Result<
         (
diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs
index f7326f382..7a945f74b 100644
--- a/shared/client/src/state/init.rs
+++ b/shared/client/src/state/init.rs
@@ -7,8 +7,9 @@ use psyche_core::{
     Barrier, CancellableBarrier, IntegrationTestLogMarker, NodeIdentity, Shuffle, TokenSize,
 };
 use psyche_data_provider::{
-    DataProvider, DataProviderTcpClient, DummyDataProvider, PreprocessedDataProvider, Split,
-    WeightedDataProvider, download_dataset_repo_async, download_model_repo_async,
+    DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider,
+    PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async,
+    download_model_from_gcs_async, download_model_repo_async,
     http::{FileURLs, HttpDataProvider},
 };
 use psyche_metrics::ClientMetrics;
@@ -90,6 +91,9 @@ pub enum InitRunError {
     #[error("failed to read HF model info: {0}")]
     HfModelLoad(#[from] hf_hub::api::tokio::ApiError),

+    #[error("failed to download model from GCS: {0}")]
+    GcsModelLoad(#[from] DownloadError),
+
     #[error("model loading thread crashed")]
     ModelLoadingThreadCrashed(JoinError),
@@ -151,7 +155,7 @@ pub struct RunInitConfigAndIO {
     pub tx_health_check: UnboundedSender>,
     pub tx_witness: UnboundedSender,
-    pub tx_checkpoint: UnboundedSender<model::HubRepo>,
+    pub tx_checkpoint: UnboundedSender<model::Checkpoint>,
     pub tx_model: UnboundedSender>,
     pub tx_parameters_req: UnboundedSender<(Vec, OneshotModelParameterSender)>,
     pub tx_config: UnboundedSender<(String, String)>,
@@ -311,7 +315,10 @@ impl RunInitConfigAndIO {
+                model::Checkpoint::Hub(_)
+                | model::Checkpoint::P2P(_)
+                | model::Checkpoint::P2PGcs(_)
+                | model::Checkpoint::Gcs(_) => {
                     let checkpoint = llm.checkpoint;
                     tokio::spawn(async move {
                         let (source, tokenizer, checkpoint_extra_files) = match checkpoint {
@@ -367,7 +374,7 @@ impl RunInitConfigAndIO {
+                            model::Checkpoint::P2P(_) | model::Checkpoint::P2PGcs(_) => {
                                 let (tx_model_config_response, rx_model_config_response) =
                                     oneshot::channel();
                                 info!("Checkpoint is p2p, requesting model config over network");
@@ -427,6 +434,39 @@ impl RunInitConfigAndIO {
+                                let bucket: String = (&gcs_repo.bucket).into();
+                                let prefix: Option<String> = gcs_repo.prefix.map(|p| (&p).into());
+
+                                info!(
+                                    "Downloading model from gs://{}/{}",
+                                    bucket,
+                                    prefix.as_deref().unwrap_or("")
+                                );
+
+                                let repo_files =
+                                    download_model_from_gcs_async(&bucket, prefix.as_deref())
+                                        .await?;
+
+                                let checkpoint_extra_files = repo_files
+                                    .iter()
+                                    .filter(|file| {
+                                        file.ends_with("config.json")
+                                            || file.ends_with("tokenizer.json")
+                                            || file.ends_with("tokenizer_config.json")
+                                            || file.ends_with("special_tokens_map.json")
+                                            || file.ends_with("generation_config.json")
+                                            || file.ends_with(".py")
+                                    })
+                                    .cloned()
+                                    .collect();
+                                let tokenizer = Arc::new(auto_tokenizer(&repo_files)?);
+                                (
+                                    PretrainedSource::::RepoFiles(repo_files),
+                                    tokenizer,
+                                    checkpoint_extra_files,
+                                )
+                            }
                             _ => unreachable!(),
                         };
diff --git a/shared/client/src/state/mod.rs
b/shared/client/src/state/mod.rs index e4d73d5b9..78e6cd1eb 100644 --- a/shared/client/src/state/mod.rs +++ b/shared/client/src/state/mod.rs @@ -14,6 +14,7 @@ mod warmup; mod witness; pub use init::{InitRunError, RunInitConfig, RunInitConfigAndIO}; +pub use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; pub use round_state::RoundState; pub use steps::{ApplyMessageOutcome, RunManager}; -pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, HubUploadInfo}; +pub use types::{CheckpointConfig, DistroBroadcastAndPayload, FinishedBroadcast, UploadInfo}; diff --git a/shared/client/src/state/types.rs b/shared/client/src/state/types.rs index 2edf22760..29734f1a0 100644 --- a/shared/client/src/state/types.rs +++ b/shared/client/src/state/types.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use psyche_coordinator::CommitteeProof; use psyche_core::{BatchId, MerkleRoot, NodeIdentity}; +use psyche_data_provider::{GcsUploadInfo, HubUploadInfo}; use psyche_modeling::DistroResult; use psyche_network::{BlobTicket, TransmittableDistroResult}; use tch::TchError; @@ -9,14 +10,14 @@ use thiserror::Error; use tokio::task::JoinHandle; #[derive(Debug, Clone)] -pub struct HubUploadInfo { - pub hub_repo: String, - pub hub_token: String, +pub enum UploadInfo { + Hub(HubUploadInfo), + Gcs(GcsUploadInfo), } #[derive(Debug, Clone)] pub struct CheckpointConfig { - pub hub_upload: Option, + pub upload_info: Option, pub checkpoint_dir: PathBuf, pub delete_old_steps: bool, pub keep_steps: u32, diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 78f054262..c726655e2 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, HubRepo, Model}, + model::{Checkpoint, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -597,22 +597,43 @@ impl Coordinator { &mut self, from: &T, index: u64, - hub_repo: HubRepo, + checkpoint_repo: Checkpoint, ) -> std::result::Result<(), CoordinatorError> { let index = index as usize; if index >= self.epoch_state.clients.len() || self.epoch_state.clients[index].id != *from { return Err(CoordinatorError::InvalidCommitteeProof); } - // TODO: In the case of more than one checkpointer, this will overwrite the hub repo - // with the last checkpointed one. We could instead have a vector of hub repos to have + + // TODO: In the case of more than one checkpointer, this will overwrite the checkpoint + // with the last checkpointed one. We could instead have a vector of checkpoints to have // more download options. 
-        match &mut self.model {
-            Model::LLM(llm) => match llm.checkpoint {
-                Checkpoint::P2P(_) => llm.checkpoint = Checkpoint::P2P(hub_repo),
-                Checkpoint::Hub(_) => llm.checkpoint = Checkpoint::Hub(hub_repo),
-                _ => {}
-            },
+        let Model::LLM(llm) = &mut self.model;
+        match (&llm.checkpoint, checkpoint_repo) {
+            // If current is P2P, wrap the new checkpoint in P2P
+            (Checkpoint::P2P(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2PGcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // If current is Hub, only accept Hub updates
+            (Checkpoint::Hub(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::Hub(hub_repo);
+            }
+            // If current is Gcs, only accept Gcs updates
+            (Checkpoint::Gcs(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::Gcs(gcs_repo);
+            }
+            (Checkpoint::P2PGcs(_), Checkpoint::Hub(hub_repo)) => {
+                llm.checkpoint = Checkpoint::P2P(hub_repo);
+            }
+            (Checkpoint::P2P(_), Checkpoint::Gcs(gcs_repo)) => {
+                llm.checkpoint = Checkpoint::P2PGcs(gcs_repo);
+            }
+            // Ignore other combinations
+            _ => {}
         }
+
         Ok(())
     }
@@ -912,8 +933,10 @@ impl Coordinator {
             .any(|client| pending_clients_unordered.contains(&client.id));
         if all_prev_clients_disconnected {
             let Model::LLM(llm) = &mut self.model;
-            if let Checkpoint::P2P(hub_repo) = llm.checkpoint {
-                llm.checkpoint = Checkpoint::Hub(hub_repo);
+            match llm.checkpoint {
+                Checkpoint::P2P(hub_repo) => llm.checkpoint = Checkpoint::Hub(hub_repo),
+                Checkpoint::P2PGcs(gcs_repo) => llm.checkpoint = Checkpoint::Gcs(gcs_repo),
+                _ => {}
             }
         }
@@ -1045,6 +1068,7 @@ impl Coordinator {
                 Checkpoint::Hub(hub_repo) | Checkpoint::Dummy(hub_repo) => {
                     llm.checkpoint = Checkpoint::P2P(hub_repo)
                 }
+                Checkpoint::Gcs(gcs_repo) => llm.checkpoint = Checkpoint::P2PGcs(gcs_repo),
                 _ => {}
             }
diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs
index 60d539f8a..3176f276e 100644
--- a/shared/coordinator/src/model.rs
+++ b/shared/coordinator/src/model.rs
@@ -239,6 +239,32 @@ impl HubRepo {
     }
 }

+#[derive(
+    Clone,
+    Debug,
+    Copy,
+    AnchorDeserialize,
+    AnchorSerialize,
+    InitSpace,
+    Serialize,
+    Deserialize,
+    PartialEq,
+    TS,
+)]
+pub struct GcsRepo {
+    pub bucket: FixedString<{ SOLANA_MAX_STRING_LEN }>,
+    pub prefix: Option<FixedString<{ SOLANA_MAX_STRING_LEN }>>,
+}
+
+impl GcsRepo {
+    pub fn dummy() -> Self {
+        Self {
+            bucket: FixedString::new(),
+            prefix: None,
+        }
+    }
+}
+
 #[derive(
     AnchorSerialize,
     AnchorDeserialize,
@@ -257,6 +283,8 @@ pub enum Checkpoint {
     Dummy(HubRepo),
     Hub(HubRepo),
     P2P(HubRepo),
+    Gcs(GcsRepo),
+    P2PGcs(GcsRepo),
 }

 impl std::fmt::Display for Checkpoint {
@@ -268,6 +296,10 @@ impl std::fmt::Display for Checkpoint {
             Checkpoint::P2P(hub_repo) => {
                 write!(f, "P2P - Hub repo: {}", &hub_repo.repo_id)
             }
+            Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => match &gcs_repo.prefix {
+                Some(prefix) => write!(f, "gs://{}/{}", &gcs_repo.bucket, prefix),
+                None => write!(f, "gs://{}", &gcs_repo.bucket),
+            },
         }
     }
 }
@@ -308,6 +340,9 @@ impl Model {
             Checkpoint::Ephemeral => true,
             Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(),
             Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(),
+            Checkpoint::Gcs(gcs_repo) | Checkpoint::P2PGcs(gcs_repo) => {
+                gcs_repo.bucket.is_empty()
+            }
         };

         if bad_checkpoint {
diff --git a/shared/data-provider/Cargo.toml b/shared/data-provider/Cargo.toml
index ea0e59791..7ad69e616 100644
--- a/shared/data-provider/Cargo.toml
+++ b/shared/data-provider/Cargo.toml
@@ -27,6 +27,8 @@ postcard.workspace = true
bytemuck.workspace = true reqwest = "0.12.12" google-cloud-storage = "0.24.0" +chrono = { version = "0.4", features = ["serde"] } +serde_json.workspace = true ts-rs.workspace = true rayon.workspace = true @@ -38,4 +40,3 @@ test-log.workspace = true clap.workspace = true tempfile = "3.15.0" static-web-server = { git = "https://github.com/arilotter/static-web-server", rev = "c91445427b56c5ddff0365d8ec116e3b567377ac" } # forked to add a channel for getting the port -serde_json.workspace = true diff --git a/shared/data-provider/examples/tcp.rs b/shared/data-provider/examples/tcp.rs index 13dcf8b44..67cc9ab9d 100644 --- a/shared/data-provider/examples/tcp.rs +++ b/shared/data-provider/examples/tcp.rs @@ -37,7 +37,7 @@ impl WatcherBackend for DummyBackend { bail!("Data provider does not send health check"); } - async fn send_checkpoint(&mut self, _checkpoint: model::HubRepo) -> anyhow::Result<()> { + async fn send_checkpoint(&mut self, _checkpoint: model::Checkpoint) -> anyhow::Result<()> { bail!("Data provider does not send checkpoints"); } } diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs new file mode 100644 index 000000000..b84bc5f9a --- /dev/null +++ b/shared/data-provider/src/errors.rs @@ -0,0 +1,53 @@ +use std::path::PathBuf; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum UploadError { + #[error("path {0} is not a file")] + NotAFile(PathBuf), + + #[error("file {0} doesn't have a valid utf-8 representation")] + InvalidFilename(PathBuf), + + #[error("failed to send checkpoint notification")] + SendCheckpoint, + + // Hub-specific errors + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("failed to commit files: {0}")] + Commit(#[from] hf_hub::api::tokio::CommitError), + + // GCS-specific errors + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + // Common errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} + +#[derive(Error, Debug)] +pub enum DownloadError { + #[error("failed to connect to HF hub: {0}")] + HfHub(#[from] hf_hub::api::tokio::ApiError), + + #[error("GCS authentication failed: {0}")] + GcsAuth(#[from] google_cloud_storage::client::google_cloud_auth::error::Error), + + #[error("GCS operation failed: {0}")] + GcsStorage(#[from] google_cloud_storage::http::Error), + + #[error("IO error: {0}")] + Io(#[from] std::io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), +} diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs new file mode 100644 index 000000000..71f29e414 --- /dev/null +++ b/shared/data-provider/src/gcs.rs @@ -0,0 +1,447 @@ +use crate::errors::{DownloadError, UploadError}; +use chrono::{DateTime, Utc}; +use google_cloud_storage::client::{Client, ClientConfig}; +use google_cloud_storage::http::objects::upload::Media; +use google_cloud_storage::http::objects::upload::UploadObjectRequest; +use google_cloud_storage::http::objects::upload::UploadType; +use google_cloud_storage::http::objects::{ + download::Range, get::GetObjectRequest, list::ListObjectsRequest, +}; +use psyche_coordinator::model::{self, GcsRepo}; +use psyche_core::FixedString; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use tokio::runtime::Runtime; +use tokio::sync::mpsc; 
+use tracing::info;
+
+/// Checkpoint manifest.json uploaded to GCS alongside safetensors files.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct GcsCheckpointManifest {
+    pub metadata: ManifestMetadata,
+    pub files: Vec<ManifestFileEntry>,
+}
+
+/// Checkpoint metadata.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ManifestMetadata {
+    pub timestamp: DateTime<Utc>,
+    pub epoch: u32,
+    pub step: u32,
+    pub run_id: String,
+}
+
+/// Single file entry in the manifest.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ManifestFileEntry {
+    pub filename: String,
+    pub generation: i64,
+    pub size_bytes: u64,
+}
+
+#[derive(Debug, Clone)]
+pub struct GcsUploadInfo {
+    pub gcs_bucket: String,
+    pub gcs_prefix: Option<String>,
+}
+
+#[derive(Debug, Clone)]
+pub struct GcsManifestMetadata {
+    pub epoch: u32,
+    pub run_id: String,
+}
+
+const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
+
+fn get_cache_base(bucket: &str) -> PathBuf {
+    // Use HF_HOME if set, otherwise fall back to ~/.cache
+    std::env::var("HF_HOME")
+        .map(PathBuf::from)
+        .unwrap_or_else(|_| {
+            std::env::var("HOME")
+                .map(|h| PathBuf::from(h).join(".cache"))
+                .unwrap_or_else(|_| PathBuf::from(".cache"))
+        })
+        .join("psyche")
+        .join("gcs")
+        .join(bucket)
+}
+
+fn get_cache_dir(
+    bucket: &str,
+    prefix: Option<&str>,
+    step: u32,
+    manifest_generation: i64,
+) -> PathBuf {
+    let base = get_cache_base(bucket);
+    let versioned_folder = format!("step-{}-{}", step, manifest_generation);
+
+    match prefix {
+        Some(p) => base.join(p.trim_end_matches('/')).join(versioned_folder),
+        None => base.join(versioned_folder),
+    }
+}
+
+fn get_cache_dir_no_manifest(bucket: &str, prefix: Option<&str>) -> PathBuf {
+    let base = get_cache_base(bucket);
+
+    match prefix {
+        Some(p) => base.join(p.trim_end_matches('/')).join("no_manifest"),
+        None => base.join("no_manifest"),
+    }
+}
+
+fn collect_cached_files(
+    cache_dir: &Path,
+    manifest: &GcsCheckpointManifest,
+) -> Option<Vec<PathBuf>> {
+    let mut files = Vec::new();
+    for file_entry in &manifest.files {
+        let path = cache_dir.join(&file_entry.filename);
+        if !path.exists() {
+            return None;
+        }
+        files.push(path);
+    }
+    Some(files)
+}
+
+pub async fn download_model_from_gcs_async(
+    bucket: &str,
+    prefix: Option<&str>,
+) -> Result<Vec<PathBuf>, DownloadError> {
+    // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous
+    let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
+        info!("Using authenticated GCS client");
+        ClientConfig::default().with_auth().await?
+    } else {
+        info!("Using anonymous GCS client");
+        ClientConfig::default().anonymous()
+    };
+    let client = Client::new(config);
+
+    let manifest_object_path = match prefix {
+        Some(p) => format!("{}/manifest.json", p),
+        None => "manifest.json".to_string(),
+    };
+
+    // Get manifest metadata to obtain generation number
+    let manifest_metadata = client
+        .get_object(&GetObjectRequest {
+            bucket: bucket.to_owned(),
+            object: manifest_object_path.clone(),
+            ..Default::default()
+        })
+        .await;
+
+    match manifest_metadata {
+        Ok(object_meta) => {
+            let manifest_generation = object_meta.generation;
+
+            // Download manifest content
+            let manifest_data = client
+                .download_object(
+                    &GetObjectRequest {
+                        bucket: bucket.to_owned(),
+                        object: manifest_object_path,
+                        ..Default::default()
+                    },
+                    &Range::default(),
+                )
+                .await?;
+
+            let manifest: GcsCheckpointManifest = serde_json::from_slice(&manifest_data)?;
+
+            info!(
+                "Found manifest: step {}, epoch {}, generation {}",
+                manifest.metadata.step, manifest.metadata.epoch, manifest_generation
+            );
+
+            // Build versioned cache path
+            let cache_dir =
+                get_cache_dir(bucket, prefix, manifest.metadata.step, manifest_generation);
+
+            // Check if all manifest files exist in cache
+            let mut files = if let Some(cached) = collect_cached_files(&cache_dir, &manifest) {
+                info!("Using cached checkpoint at {:?}", cache_dir);
+                cached
+            } else {
+                info!(
+                    "Model not found in cache, downloading checkpoint to {:?}",
+                    cache_dir
+                );
+                std::fs::create_dir_all(&cache_dir)?;
+                download_files_from_manifest(&client, bucket, prefix, &cache_dir, &manifest).await?
+            };
+            // Download config files (json, py) - skips if already cached
+            let config_files =
+                download_files_no_manifest(&client, bucket, prefix, &cache_dir, &[".json", ".py"])
+                    .await?;
+            files.extend(config_files);
+            Ok(files)
+        }
+        Err(_) => {
+            // Fallback for old checkpoints without manifest
+            info!("No manifest found, downloading model without manifest");
+            let cache_dir = get_cache_dir_no_manifest(bucket, prefix);
+            std::fs::create_dir_all(&cache_dir)?;
+            download_files_no_manifest(&client, bucket, prefix, &cache_dir, &MODEL_EXTENSIONS).await
+        }
+    }
+}
+
+async fn download_files_from_manifest(
+    client: &Client,
+    bucket: &str,
+    prefix: Option<&str>,
+    cache_dir: &Path,
+    manifest: &GcsCheckpointManifest,
+) -> Result<Vec<PathBuf>, DownloadError> {
+    let mut downloaded_files = Vec::new();
+
+    for file_entry in &manifest.files {
+        let object_name = match prefix {
+            Some(p) => format!("{}/{}", p, file_entry.filename),
+            None => file_entry.filename.clone(),
+        };
+        let local_path = cache_dir.join(&file_entry.filename);
+
+        if local_path.exists() {
+            info!("Using cached: {}", file_entry.filename);
+            downloaded_files.push(local_path);
+            continue;
+        }
+
+        info!(
+            "Downloading: gs://{}/{} (generation {})",
+            bucket, object_name, file_entry.generation
+        );
+
+        let data = client
+            .download_object(
+                &GetObjectRequest {
+                    bucket: bucket.to_owned(),
+                    object: object_name,
+                    generation: Some(file_entry.generation),
+                    ..Default::default()
+                },
+                &Range::default(),
+            )
+            .await?;
+
+        std::fs::write(&local_path, &data)?;
+        info!("Downloaded: {} ({} bytes)", file_entry.filename, data.len());
+        downloaded_files.push(local_path);
+    }
+
+    Ok(downloaded_files)
+}
+
+/// Download model files by listing the bucket. Skips files that already exist in cache.
+/// Used for initial model download (no manifest) and to fetch config files (json, py) after manifest download.
+async fn download_files_no_manifest(
+    client: &Client,
+    bucket: &str,
+    prefix: Option<&str>,
+    cache_dir: &Path,
+    extensions: &[&str],
+) -> Result<Vec<PathBuf>, DownloadError> {
+    let mut all_objects = vec![];
+    let mut page_token: Option<String> = None;
+
+    loop {
+        let results = client
+            .list_objects(&ListObjectsRequest {
+                bucket: bucket.to_owned(),
+                prefix: prefix.map(|s| s.to_owned()),
+                page_token: page_token.clone(),
+                ..Default::default()
+            })
+            .await?;
+
+        for obj in results.items.iter().flatten() {
+            if extensions.iter().any(|ext| obj.name.ends_with(ext)) {
+                all_objects.push(obj.name.clone());
+            }
+        }
+
+        match results.next_page_token {
+            Some(token) => page_token = Some(token),
+            None => break,
+        }
+    }
+
+    info!(
+        "Found {} files ({}) in gs://{}/{}",
+        all_objects.len(),
+        extensions.join(", "),
+        bucket,
+        prefix.unwrap_or("")
+    );
+
+    let mut downloaded_files = Vec::new();
+
+    for object_name in all_objects {
+        let filename = object_name.rsplit('/').next().unwrap_or(&object_name);
+        let local_path = cache_dir.join(filename);
+
+        if local_path.exists() {
+            info!("Using cached: {}", filename);
+            downloaded_files.push(local_path);
+            continue;
+        }
+
+        info!("Downloading: gs://{}/{}", bucket, object_name);
+
+        let data = client
+            .download_object(
+                &GetObjectRequest {
+                    bucket: bucket.to_owned(),
+                    object: object_name.clone(),
+                    ..Default::default()
+                },
+                &Range::default(),
+            )
+            .await?;
+
+        // Write to cache
+        std::fs::write(&local_path, &data)?;
+
+        info!("Downloaded: {} ({} bytes)", filename, data.len());
+
+        downloaded_files.push(local_path);
+    }
+
+    Ok(downloaded_files)
+}
+
+pub fn download_model_from_gcs_sync(
+    bucket: &str,
+    prefix: Option<&str>,
+) -> Result<Vec<PathBuf>, DownloadError> {
+    let rt = Runtime::new().map_err(DownloadError::Io)?;
+    rt.block_on(download_model_from_gcs_async(bucket, prefix))
+}
+
+pub async fn upload_to_gcs(
+    gcs_info: GcsUploadInfo,
+    manifest_metadata: GcsManifestMetadata,
+    local: Vec<PathBuf>,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
+) -> Result<(), UploadError> {
+    let GcsUploadInfo {
+        gcs_bucket,
+        gcs_prefix,
+    } = gcs_info;
+
+    let GcsManifestMetadata { epoch, run_id } = manifest_metadata;
+
+    info!(bucket = gcs_bucket, "Uploading checkpoint to GCS");
+
+    let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() {
+        info!("Using authenticated GCS client");
+        ClientConfig::default().with_auth().await?
+    } else {
+        info!("Using anonymous GCS client");
+        ClientConfig::default().anonymous()
+    };
+    let client = Client::new(config);
+
+    let mut manifest = GcsCheckpointManifest {
+        metadata: ManifestMetadata {
+            timestamp: Utc::now(),
+            epoch,
+            step: step as u32,
+            run_id,
+        },
+        files: Vec::new(),
+    };
+
+    for path in local {
+        let file_name = path
+            .file_name()
+            .ok_or_else(|| UploadError::NotAFile(path.clone()))?
+            .to_str()
+            .ok_or_else(|| UploadError::InvalidFilename(path.clone()))?;
+
+        // Only upload safetensors files
+        if !file_name.ends_with(".safetensors") {
+            continue;
+        }
+
+        let object_name = match &gcs_prefix {
+            Some(p) => format!("{}/{}", p, file_name),
+            None => file_name.to_string(),
+        };
+
+        let size = std::fs::metadata(&path)?.len();
+        let data = tokio::fs::read(&path).await?;
+
+        let upload_type = UploadType::Simple(Media::new(object_name.clone()));
+        let uploaded = client
+            .upload_object(
+                &UploadObjectRequest {
+                    bucket: gcs_bucket.clone(),
+                    ..Default::default()
+                },
+                data,
+                &upload_type,
+            )
+            .await?;
+
+        info!(
+            bucket = gcs_bucket,
+            object = object_name,
+            size = uploaded.size,
+            generation = uploaded.generation,
+            "Uploaded file to GCS"
+        );
+
+        manifest.files.push(ManifestFileEntry {
+            filename: file_name.to_string(),
+            generation: uploaded.generation,
+            size_bytes: size,
+        });
+    }
+
+    // Upload the manifest file
+    let manifest_path = match &gcs_prefix {
+        Some(p) => format!("{}/manifest.json", p),
+        None => "manifest.json".to_string(),
+    };
+    let manifest_json = serde_json::to_string_pretty(&manifest)?;
+
+    let upload_type = UploadType::Simple(Media::new(manifest_path.clone()));
+    client
+        .upload_object(
+            &UploadObjectRequest {
+                bucket: gcs_bucket.clone(),
+                ..Default::default()
+            },
+            manifest_json.into_bytes(),
+            &upload_type,
+        )
+        .await?;
+
+    info!(
+        bucket = gcs_bucket,
+        object = manifest_path,
+        "Uploaded manifest to GCS"
+    );
+
+    info!(
+        "Upload to GCS complete at gs://{}/{}",
+        gcs_bucket,
+        gcs_prefix.as_deref().unwrap_or("")
+    );
+
+    tx_checkpoint
+        .send(model::Checkpoint::Gcs(GcsRepo {
+            bucket: FixedString::from_str_truncated(&gcs_bucket),
+            prefix: gcs_prefix.map(|p| FixedString::from_str_truncated(&p)),
+        }))
+        .map_err(|_| UploadError::SendCheckpoint)?;
+
+    Ok(())
+}
diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs
index ca090a759..13a575b84 100644
--- a/shared/data-provider/src/hub.rs
+++ b/shared/data-provider/src/hub.rs
@@ -1,12 +1,16 @@
+use crate::errors::UploadError;
+use crate::hub::model::HubRepo;
 use hf_hub::{
     Cache, Repo, RepoType,
     api::{
         Siblings,
-        tokio::{ApiError, CommitError, UploadSource},
+        tokio::{ApiError, UploadSource},
     },
 };
+use psyche_coordinator::model;
+use psyche_core::FixedString;
 use std::{path::PathBuf, time::Instant};
-use thiserror::Error;
+use tokio::sync::mpsc;
 use tracing::{error, info};

 const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"];
@@ -189,42 +193,39 @@ pub fn download_dataset_repo_sync(
     )
 }

-#[derive(Error, Debug)]
-pub enum UploadModelError {
-    #[error("path {0} is not a file")]
-    NotAFile(PathBuf),
-
-    #[error("file {0} doesn't have a valid utf-8 representation")]
-    InvalidFilename(PathBuf),
+#[derive(Debug, Clone)]
+pub struct HubUploadInfo {
+    pub hub_repo: String,
+    pub hub_token: String,
+}

-    #[error("failed to connect to HF hub: {0}")]
-    HfHub(#[from] ApiError),
+pub async fn upload_to_hub(
+    hub_info: HubUploadInfo,
+    local: Vec<PathBuf>,
+    step: u64,
+    tx_checkpoint: mpsc::UnboundedSender<model::Checkpoint>,
+) -> Result<(), UploadError> {
+    let HubUploadInfo {
+        hub_repo,
+        hub_token,
+    } = hub_info;

-    #[error("failed to commit files: {0}")]
-    Commit(#[from] CommitError),
-}
+    info!(repo = hub_repo, "Uploading checkpoint to HuggingFace");

-pub async fn upload_model_repo_async(
-    repo_id: String,
-    files: Vec<PathBuf>,
-    token: String,
-    commit_message: Option<String>,
-    commit_description: Option<String>,
-) -> Result<String, UploadModelError> {
     let api = hf_hub::api::tokio::ApiBuilder::new()
-        .with_token(Some(token))
+        .with_token(Some(hub_token.clone()))
.with_token(Some(hub_token.clone())) .build()?; - let repo = Repo::model(repo_id.clone()); + let repo = Repo::model(hub_repo.clone()); let api_repo = api.repo(repo); - let files: Result, _> = files + let files: Result, _> = local .into_iter() .map(|path| { path.file_name() - .ok_or(UploadModelError::NotAFile(path.clone())) + .ok_or(UploadError::NotAFile(path.clone())) .and_then(|name| { name.to_str() - .ok_or(UploadModelError::InvalidFilename(path.clone())) + .ok_or(UploadError::InvalidFilename(path.clone())) .map(|s| s.to_string()) }) .map(|name| (path.into(), name)) @@ -233,32 +234,32 @@ pub async fn upload_model_repo_async( let files = files?; - let commit_info = match api_repo - .upload_files( - files, - commit_message.clone(), - commit_description.clone(), - false, - ) + let commit_info = api_repo + .upload_files(files, Some(format!("step {step}")), None, false) .await - { - Ok(info) => { - info!( - repo = repo_id, - oid = info.oid, - "Successfully uploaded files to HuggingFace" - ); - info - } - Err(e) => { + .map_err(|e| { error!( - repo = repo_id, + repo = hub_repo, error = ?e, - "Failed to upload files to HuggingFace. Full error details: {:#?}", - e + "Failed to upload files to HuggingFace" ); - return Err(e.into()); - } - }; - Ok(commit_info.oid) + e + })?; + + let revision = commit_info.oid; + + info!( + repo = hub_repo, + revision = revision, + "Upload to HuggingFace complete" + ); + + tx_checkpoint + .send(model::Checkpoint::Hub(HubRepo { + repo_id: FixedString::from_str_truncated(&hub_repo), + revision: Some(FixedString::from_str_truncated(&revision)), + })) + .map_err(|_| UploadError::SendCheckpoint)?; + + Ok(()) } diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 868bfa793..0044d77d2 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -1,7 +1,9 @@ mod data_provider; mod dataset; mod dummy; +mod errors; mod file_extensions; +mod gcs; pub mod http; mod hub; mod local; @@ -13,10 +15,15 @@ mod weighted; pub use data_provider::DataProvider; pub use dataset::{Dataset, Field, Row, Split}; pub use dummy::DummyDataProvider; +pub use errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; +pub use gcs::{ + GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata, + download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, +}; pub use hub::{ - UploadModelError, download_dataset_repo_async, download_dataset_repo_sync, - download_model_repo_async, download_model_repo_sync, upload_model_repo_async, + HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, + download_model_repo_async, download_model_repo_sync, upload_to_hub, }; pub use local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; diff --git a/shared/eval/Cargo.toml b/shared/eval/Cargo.toml index 1b1cfdaab..18b9d76bd 100644 --- a/shared/eval/Cargo.toml +++ b/shared/eval/Cargo.toml @@ -28,3 +28,4 @@ tokio-util.workspace = true [dev-dependencies] clap.workspace = true psyche-python-extension-impl.workspace = true +tracing-subscriber = "0.3" diff --git a/shared/eval/examples/evaluate.rs b/shared/eval/examples/evaluate.rs index c21dd26cf..b344abd0b 100644 --- a/shared/eval/examples/evaluate.rs +++ b/shared/eval/examples/evaluate.rs @@ -2,7 +2,7 @@ use anyhow::Result; use clap::Parser; use indicatif::{ProgressBar, ProgressStyle}; use psyche_core::RunningAverage; -use 
psyche_data_provider::download_model_repo_sync; +use psyche_data_provider::{download_model_from_gcs_sync, download_model_repo_sync}; use psyche_eval::{ ALL_TASK_NAMES, EvalTaskOptions, Task, progress_bar_template_with_task, tasktype_from_name, }; @@ -13,6 +13,7 @@ use std::sync::Arc; use std::thread::JoinHandle; use tch::{Device, Kind}; use tokenizers::Tokenizer; +use tracing::Level; #[derive(Parser, Debug, Clone)] struct Args { @@ -25,6 +26,14 @@ struct Args { #[arg(long)] hf_token: Option, + /// GCS bucket name (use instead of HuggingFace model) + #[arg(long)] + gcs_bucket: Option, + + /// GCS prefix/directory within the bucket + #[arg(long)] + gcs_prefix: Option, + #[arg(long, default_value_t = ALL_TASK_NAMES.join(","))] tasks: String, @@ -53,6 +62,7 @@ struct Args { } fn main() -> Result<()> { + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); let args = Args::parse(); let tasks: Result> = args .tasks @@ -75,13 +85,23 @@ fn main() -> Result<()> { } else { "".to_string() }; + let model_source = if let Some(ref bucket) = args.gcs_bucket { + format!( + "gs://{}/{}", + bucket, + args.gcs_prefix.as_deref().unwrap_or("") + ) + } else { + args.model.clone() + }; println!( - "Running tasks with model {}, seed: {}, DP={}{}", - args.model, args.seed, args.data_parallelism, limit_str + "\n\n Running tasks with model {}, seed: {}, DP={}{}", + model_source, args.seed, args.data_parallelism, limit_str ); for task in &tasks { println!(" - {}: {} few-shot examples", task, task.num_fewshot); } + println!("\n\n"); } if args.data_parallelism > 1 { @@ -107,7 +127,11 @@ fn main() -> Result<()> { } } - let repo = download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)?; + let repo = if let Some(ref bucket) = args.gcs_bucket { + download_model_from_gcs_sync(bucket, args.gcs_prefix.as_deref())? + } else { + download_model_repo_sync(&args.model, args.revision, None, args.hf_token, true)? 
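+        // Fallback path: this Hub download runs only when --gcs-bucket is
+        // absent, so the GCS flags and the HF flags never mix for one run.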
+ }; let tokenizer = auto_tokenizer(&repo)?; let (python, python_arch) = { diff --git a/shared/watcher/src/traits.rs b/shared/watcher/src/traits.rs index 1cd96a7ed..50bf317bb 100644 --- a/shared/watcher/src/traits.rs +++ b/shared/watcher/src/traits.rs @@ -27,5 +27,5 @@ pub trait Backend: Send + Sync { async fn wait_for_new_state(&mut self) -> Result>; async fn send_witness(&mut self, opportunistic_data: OpportunisticData) -> Result<()>; async fn send_health_check(&mut self, health_check: HealthChecks) -> Result<()>; - async fn send_checkpoint(&mut self, checkpoint: model::HubRepo) -> Result<()>; + async fn send_checkpoint(&mut self, checkpoint: model::Checkpoint) -> Result<()>; } diff --git a/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs b/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs index d1394f0a4..88a15536f 100644 --- a/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs +++ b/tools/rust-tools/run-manager/src/commands/run/checkpoint.rs @@ -46,7 +46,7 @@ impl Command for CommandCheckpoint { &coordinator_instance, &coordinator_account, &user, - repo, + psyche_coordinator::model::Checkpoint::Hub(repo), ); let signature = backend .send_and_retry("Checkpoint", &[instruction], &[]) diff --git a/website/backend/src/dataStores/flatFileCoordinator.ts b/website/backend/src/dataStores/flatFileCoordinator.ts index e0217f639..a9b88f07e 100644 --- a/website/backend/src/dataStores/flatFileCoordinator.ts +++ b/website/backend/src/dataStores/flatFileCoordinator.ts @@ -642,13 +642,15 @@ export class FlatFileCoordinatorDataStore implements CoordinatorDataStore { } satisfies RunRoundClient }) - const checkpoint = - (typeof c.coordinator.model.LLM.checkpoint === 'object' && - (('Hub' in c.coordinator.model.LLM.checkpoint && - c.coordinator.model.LLM.checkpoint.Hub) || - ('P2P' in c.coordinator.model.LLM.checkpoint && - c.coordinator.model.LLM.checkpoint.P2P))) || - null + const checkpoint = (() => { + const cp = c.coordinator.model.LLM.checkpoint + if (typeof cp !== 'object') return null + if ('Hub' in cp) return { Hub: cp.Hub } + if ('P2P' in cp) return { Hub: cp.P2P } + if ('Gcs' in cp) return { Gcs: cp.Gcs } + if ('P2PGcs' in cp) return { Gcs: cp.P2PGcs } + return null + })() const config = c.coordinator.config diff --git a/website/backend/src/index.ts b/website/backend/src/index.ts index dcc86d63f..07549a0e7 100644 --- a/website/backend/src/index.ts +++ b/website/backend/src/index.ts @@ -382,6 +382,22 @@ async function main() { } }) + fastify.get<{ + Querystring: { bucket: string; prefix?: string } + }>('/check-gcs-bucket', async (request) => { + const { bucket, prefix } = request.query + const path = prefix ? `${prefix}/` : '' + const url = `https://storage.googleapis.com/${bucket}/${path}manifest.json` + try { + const response = await fetch(url, { method: 'HEAD' }) + return { isValid: response.ok, description: response.statusText } + } catch (error) { + const errorMessage = + error instanceof Error ? error.message : 'Unknown error' + return { isValid: false, description: errorMessage } + } + }) + fastify.get('/status', async (_, res) => { const data = { commit: process.env.GITCOMMIT ?? 
'???', diff --git a/website/frontend/src/components/CheckpointButton.tsx b/website/frontend/src/components/CheckpointButton.tsx index af58d11ad..51ae0d797 100644 --- a/website/frontend/src/components/CheckpointButton.tsx +++ b/website/frontend/src/components/CheckpointButton.tsx @@ -2,32 +2,54 @@ import { useState, useEffect } from 'react' import { Button } from './Button.js' import HuggingfaceIcon from '../assets/icons/huggingface.svg?react' -import { fetchCheckpointStatus } from '../fetchRuns.js' +import LinkIcon from '../assets/icons/link.svg?react' +import { + fetchCheckpointStatus, + fetchGcsCheckpointStatus, +} from '../fetchRuns.js' +import type { GcsRepo, HubRepo } from 'shared' -export const CheckpointButton = ({ - checkpoint, -}: { - checkpoint: { repo_id: string; revision?: string | null } -}) => { +type CheckpointProps = { + checkpoint: { Hub: HubRepo } | { Gcs: GcsRepo } +} + +export const CheckpointButton = ({ checkpoint }: CheckpointProps) => { const [isValid, setIsValid] = useState(undefined) - useEffect(() => { - const parsedRepo = checkpoint.repo_id.split('/') + const isHub = 'Hub' in checkpoint + const isGcs = 'Gcs' in checkpoint - if (parsedRepo.length !== 2) { - setIsValid(false) - return - } - const [owner, repo] = parsedRepo + useEffect(() => { + if (isHub) { + const repoId = checkpoint.Hub.repo_id + const parsedRepo = repoId.split('/') - fetchCheckpointStatus(owner, repo, checkpoint.revision || undefined) - .then((data) => { - setIsValid(data.isValid) - }) - .catch(() => { + if (parsedRepo.length !== 2) { setIsValid(false) - }) - }, [checkpoint.repo_id, checkpoint.revision]) + return + } + const [owner, repo] = parsedRepo + + fetchCheckpointStatus(owner, repo, checkpoint.Hub.revision || undefined) + .then((data) => { + setIsValid(data.isValid) + }) + .catch(() => { + setIsValid(false) + }) + } else if (isGcs) { + const bucket = checkpoint.Gcs.bucket + const prefix = checkpoint.Gcs.prefix || undefined + + fetchGcsCheckpointStatus(bucket, prefix) + .then((data) => { + setIsValid(data.isValid) + }) + .catch(() => { + setIsValid(false) + }) + } + }, [checkpoint, isHub, isGcs]) if (isValid === undefined) { return null @@ -38,19 +60,46 @@ export const CheckpointButton = ({ return null } - return ( - - ) + if (isHub) { + const repoId = checkpoint.Hub.repo_id + const revision = checkpoint.Hub.revision + + return ( + + ) + } + + if (isGcs) { + const bucket = checkpoint.Gcs.bucket + const prefix = checkpoint.Gcs.prefix + + return ( + + ) + } + + return null } diff --git a/website/frontend/src/fakeData.ts b/website/frontend/src/fakeData.ts index 87df62f7e..eb3d062b5 100644 --- a/website/frontend/src/fakeData.ts +++ b/website/frontend/src/fakeData.ts @@ -382,8 +382,10 @@ function makeFakeRunDataSeeded(seed = 1, step = 0, index = 0): RunData { round, clients, checkpoint: { - repo_id: 'PsycheFoundation/Skibbler', - revision: null, + Hub: { + repo_id: 'PsycheFoundation/Skibbler', + revision: null, + }, }, config: { cooldownTime: 5_000, diff --git a/website/frontend/src/fetchRuns.ts b/website/frontend/src/fetchRuns.ts index 438d2d648..742867253 100644 --- a/website/frontend/src/fetchRuns.ts +++ b/website/frontend/src/fetchRuns.ts @@ -64,6 +64,14 @@ export async function fetchCheckpointStatus( return psycheJsonFetch(`check-checkpoint?${queryParams}`) } +export async function fetchGcsCheckpointStatus( + bucket: string, + prefix?: string +): Promise { + const queryParams = `bucket=${bucket}${prefix ? 
`&prefix=${prefix}` : ''}`
+  return psycheJsonFetch(`check-gcs-bucket?${queryParams}`)
+}
+
 interface DecodeState {
   buffer: string
   decoder: TextDecoder
diff --git a/website/shared/index.ts b/website/shared/index.ts
index 61f10b7a0..0b19b75d0 100644
--- a/website/shared/index.ts
+++ b/website/shared/index.ts
@@ -9,6 +9,7 @@ type PsycheSolanaCoordinator = coordinatorTypes.PsycheSolanaCoordinator
 type PsycheSolanaMiningPool = miningPoolTypes.PsycheSolanaMiningPool

 import type {
+  GcsRepo,
   HubRepo,
   LearningRateSchedule,
   LLMArchitecture,
@@ -118,7 +119,7 @@ export interface RunData {
   epochStartTime: Date
   clients: Array
-  checkpoint: HubRepo | null
+  checkpoint: { Hub: HubRepo } | { Gcs: GcsRepo } | null
   round: number
   config: {
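The `/check-gcs-bucket` route and `fetchGcsCheckpointStatus` above reduce checkpoint validity to one question: is `manifest.json` publicly readable at the canonical `storage.googleapis.com` URL? The same contract can be exercised from Rust-side tooling, parsing the manifest with the field names that `upload_to_gcs` writes earlier in this diff. A minimal sketch, assuming `reqwest` (with the `blocking` and `json` features) and `serde` — neither is added by this change set, and the struct shapes below are illustrative stand-ins rather than the crate's real `GcsCheckpointManifest`:

```rust
use anyhow::Result;
use serde::Deserialize;

// Field names mirror what `upload_to_gcs` writes into manifest.json
// (filename / generation / size_bytes); any extra fields, such as the
// manifest metadata, are ignored by serde's default behavior.
#[derive(Debug, Deserialize)]
struct ManifestFileEntry {
    filename: String,
    generation: i64,
    size_bytes: u64,
}

#[derive(Debug, Deserialize)]
struct CheckpointManifest {
    files: Vec<ManifestFileEntry>,
}

// Same probe as the /check-gcs-bucket route, extended to parse the body:
// a checkpoint is considered valid when its manifest is publicly readable.
fn probe_gcs_checkpoint(bucket: &str, prefix: Option<&str>) -> Result<CheckpointManifest> {
    let path = prefix.map(|p| format!("{p}/")).unwrap_or_default();
    let url = format!("https://storage.googleapis.com/{bucket}/{path}manifest.json");
    let manifest = reqwest::blocking::get(&url)?.error_for_status()?.json()?;
    Ok(manifest)
}
```

The backend route itself only issues an unauthenticated HEAD request, so parsing the body as above is a strict superset of its check; either way, a private bucket reports invalid even when the authenticated `download_model_from_gcs_*` paths could read it with credentials.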