diff --git a/Cargo.lock b/Cargo.lock index dbcbedf81..85cd3764a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6968,6 +6968,7 @@ dependencies = [ "tokenizers", "tokio", "tokio-util 0.7.16", + "toml 0.8.23", "tracing", "wandb", ] @@ -6982,6 +6983,7 @@ dependencies = [ "cfg_eval", "psyche-core", "serde", + "serde_json", "serde_with", "ts-rs", ] @@ -7073,6 +7075,7 @@ dependencies = [ name = "psyche-deserialize-zerocopy-wasm" version = "0.2.0" dependencies = [ + "psyche-coordinator", "psyche-core", "psyche-solana-coordinator", "serde", @@ -8106,6 +8109,7 @@ dependencies = [ "clap-markdown", "psyche-coordinator", "psyche-core", + "psyche-data-provider", "psyche-solana-authorizer", "psyche-solana-coordinator", "psyche-solana-rpc", diff --git a/architectures/centralized/client/src/app.rs b/architectures/centralized/client/src/app.rs index 8f6a9bcca..aab6e66b9 100644 --- a/architectures/centralized/client/src/app.rs +++ b/architectures/centralized/client/src/app.rs @@ -5,7 +5,8 @@ use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientM use psyche_client::HubUploadInfo; use psyche_client::UploadInfo; use psyche_client::{ - Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, + Client, ClientTUI, ClientTUIState, ModelExtraData, NC, RunInitConfig, TrainArgs, + read_identity_secret_key, }; use psyche_coordinator::{Coordinator, HealthChecks, model}; use psyche_metrics::ClientMetrics; @@ -110,6 +111,7 @@ pub async fn build_app( let hub_read_token = std::env::var("HF_TOKEN").ok(); let eval_tasks = p.eval_tasks()?; let checkpoint_config = p.checkpoint_config()?; + let model_extra_data_override: Option = p.model_extra_data_override()?; let wandb_info = p.wandb_info(format!( "{}-{}", p.run_id.clone(), @@ -153,6 +155,7 @@ pub async fn build_app( max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, + model_extra_data_override, }; let app = App { cancel, diff 
--git a/architectures/centralized/server/src/app.rs b/architectures/centralized/server/src/app.rs index 402146817..f2a4c7797 100644 --- a/architectures/centralized/server/src/app.rs +++ b/architectures/centralized/server/src/app.rs @@ -1,7 +1,7 @@ -use anyhow::{Result, anyhow, bail}; +use anyhow::{Result, bail}; use async_trait::async_trait; use psyche_centralized_shared::{ClientId, ClientToServerMessage, ServerToClientMessage}; -use psyche_coordinator::model::{self, Checkpoint, LLM, LLMTrainingDataLocation, Model}; +use psyche_coordinator::model::{self, Checkpoint, Model}; use psyche_coordinator::{ Client, ClientState, Coordinator, CoordinatorError, HealthChecks, Round, RunState, SOLANA_MAX_NUM_CLIENTS, TickResult, @@ -148,12 +148,18 @@ impl App { } } +fn default_data_server_port() -> u16 { + 9088 +} + #[derive(Serialize, Deserialize, Debug)] pub struct DataServerInfo { pub dir: PathBuf, pub token_size: TokenSize, pub seq_len: usize, pub shuffle_seed: [u8; 32], + #[serde(default = "default_data_server_port")] + pub port: u16, } impl App { @@ -176,72 +182,58 @@ impl App { debug!("potentially launching data server..."); - let training_data_server = match &coordinator.model { - Model::LLM(LLM { - data_location, - checkpoint, - .. - }) => { - if let LLMTrainingDataLocation::Server(url) = data_location { - match checkpoint { - Checkpoint::Hub(hub_repo) => { - let repo_id = String::from(&hub_repo.repo_id); - let revision = hub_repo.revision.map(|bytes| (&bytes).into()); - if revision.is_some() - || !tokio::fs::try_exists(PathBuf::from(repo_id.clone())) - .await - .unwrap_or_default() - { - download_model_repo_async(&repo_id, revision, None, None, None, true) - .await?; - } - } - Checkpoint::Ephemeral => { - bail!("Can't start up a run with an Ephemeral checkpoint.") - } - Checkpoint::Dummy(_) => { - // ok! 
- } - Checkpoint::P2P(_) | Checkpoint::P2PGcs(_) => { - bail!("Can't start up a run with a P2P checkpoint.") - } - Checkpoint::Gcs(gcs_repo) => { - let bucket: String = (&gcs_repo.bucket).into(); - let prefix: Option = - gcs_repo.prefix.map(|p| (&p).into()); - download_model_from_gcs_async(&bucket, prefix.as_deref()).await?; - } - } - - let server_addr: SocketAddr = String::from(url).parse().map_err(|e| { - anyhow!("Failed to parse training data server URL {:?}: {}", url, e) - })?; - let data_server_port = server_addr.port(); - let DataServerInfo { - dir, - seq_len, - shuffle_seed, - token_size - } = data_server_config.ok_or_else(|| anyhow!( - "Coordinator state requires we host training data, but no --data-config passed." - ))?; - - let local_data_provider = LocalDataProvider::new_from_directory( - dir, - token_size, - seq_len, - Shuffle::Seeded(shuffle_seed), - )?; - - let (tx, backend) = ChannelCoordinatorBackend::new(); - let data_server = - DataProviderTcpServer::start(local_data_provider, backend, data_server_port) + let training_data_server = if let Some(DataServerInfo { + dir, + seq_len, + shuffle_seed, + token_size, + port, + }) = data_server_config + { + // Download model if needed based on checkpoint type + let Model::LLM(llm) = &coordinator.model; + match &llm.checkpoint { + Checkpoint::Hub(hub_repo) => { + let repo_id = String::from(&hub_repo.repo_id); + let revision = hub_repo.revision.map(|bytes| (&bytes).into()); + if revision.is_some() + || !tokio::fs::try_exists(PathBuf::from(repo_id.clone())) + .await + .unwrap_or_default() + { + download_model_repo_async(&repo_id, revision, None, None, None, true) .await?; - Some((tx, data_server)) - } else { - None + } + } + Checkpoint::Ephemeral => { + bail!("Can't start up a run with an Ephemeral checkpoint.") + } + Checkpoint::Dummy(_) => { + // ok! 
+ } + Checkpoint::P2P(_) | Checkpoint::P2PDummy | Checkpoint::P2PGcs(_) => { + bail!("Can't start up a run with a P2P checkpoint.") + } + Checkpoint::Gcs(gcs_repo) => { + let bucket: String = (&gcs_repo.bucket).into(); + let prefix: Option = gcs_repo.prefix.map(|p| (&p).into()); + download_model_from_gcs_async(&bucket, prefix.as_deref()).await?; } } + + let local_data_provider = LocalDataProvider::new_from_directory( + dir, + token_size, + seq_len, + Shuffle::Seeded(shuffle_seed), + )?; + + let (tx, backend) = ChannelCoordinatorBackend::new(); + let data_server = + DataProviderTcpServer::start(local_data_provider, backend, port).await?; + Some((tx, data_server)) + } else { + None }; debug!("data server work done."); @@ -253,8 +245,7 @@ impl App { } else { (None, None) }; - let (cancel, tx_tui_state) = - maybe_start_render_loop(tabs)?; + let (cancel, tx_tui_state) = maybe_start_render_loop(tabs)?; let mut tick_interval = interval(Duration::from_millis(500)); tick_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); //important! 
@@ -293,7 +284,9 @@ impl App { withdraw_on_disconnect, pause, }) - }.instrument(info_span!("App::new")).await + } + .instrument(info_span!("App::new")) + .await } pub async fn run(&mut self) -> Result<()> { diff --git a/architectures/centralized/testing/tests/integration_tests.rs b/architectures/centralized/testing/tests/integration_tests.rs index 81943e827..656e6951a 100644 --- a/architectures/centralized/testing/tests/integration_tests.rs +++ b/architectures/centralized/testing/tests/integration_tests.rs @@ -9,10 +9,7 @@ use psyche_centralized_testing::{ spawn_clients_with_training_delay, }, }; -use psyche_coordinator::{ - RunState, - model::{Checkpoint, HubRepo}, -}; +use psyche_coordinator::{RunState, model::Checkpoint}; use tracing::info; #[test_log::test(tokio::test(flavor = "multi_thread"))] @@ -639,7 +636,7 @@ async fn client_join_in_training_and_get_model_using_p2p() { assert_with_retries( || server_handle.get_checkpoint(), - std::mem::discriminant(&Checkpoint::P2P(HubRepo::dummy())), + std::mem::discriminant(&Checkpoint::P2PDummy), ) .await; @@ -722,7 +719,7 @@ async fn two_clients_join_in_training_and_get_model_using_p2p() { assert_with_retries( || server_handle.get_checkpoint(), - std::mem::discriminant(&Checkpoint::P2P(HubRepo::dummy())), + std::mem::discriminant(&Checkpoint::P2PDummy), ) .await; diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index 36a529bbb..dfb242cb0 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -11,7 +11,8 @@ use anchor_client::{ }; use anyhow::{Result, anyhow}; use psyche_client::{ - Client, ClientTUI, ClientTUIState, NC, RunInitConfig, TrainArgs, read_identity_secret_key, + Client, ClientTUI, ClientTUIState, ModelExtraData, NC, RunInitConfig, TrainArgs, + read_identity_secret_key, }; use psyche_coordinator::{ClientState, Coordinator, CoordinatorError, RunState}; use 
psyche_core::sha256; @@ -90,6 +91,7 @@ pub async fn build_app( let eval_tasks = p.eval_tasks()?; let hub_read_token = std::env::var("HF_TOKEN").ok(); let checkpoint_config = p.checkpoint_config()?; + let model_extra_data_override: Option = p.model_extra_data_override()?; let solana_pubkey = wallet_keypair.pubkey(); let wandb_info = p.wandb_info(format!("{}-{solana_pubkey}", p.run_id))?; @@ -134,6 +136,7 @@ pub async fn build_app( max_concurrent_parameter_requests: p.max_concurrent_parameter_requests, device: p.device, sidecar_port: p.sidecar_port, + model_extra_data_override, }; let app = App { run_id: p.run_id.clone(), diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs index 547940b8f..d4f45c179 100644 --- a/architectures/decentralized/solana-client/src/main.rs +++ b/architectures/decentralized/solana-client/src/main.rs @@ -288,6 +288,9 @@ async fn async_main() -> Result<()> { Checkpoint::Ephemeral => { bail!("Can't predownload model with ephemeral checkpoint.") } + Checkpoint::P2PDummy => { + println!("P2PDummy checkpoint (for testing), nothing to predownload."); + } Checkpoint::Dummy(hub_repo) | Checkpoint::Hub(hub_repo) | Checkpoint::P2P(hub_repo) => { diff --git a/architectures/decentralized/solana-coordinator/Cargo.lock b/architectures/decentralized/solana-coordinator/Cargo.lock index 72a38df53..6a1fe747f 100644 --- a/architectures/decentralized/solana-coordinator/Cargo.lock +++ b/architectures/decentralized/solana-coordinator/Cargo.lock @@ -1610,6 +1610,7 @@ dependencies = [ "cfg_eval", "psyche-core", "serde", + "serde_json", "serde_with", "ts-rs", ] diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs index 8257cbbd1..87e1a26d0 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs +++ 
b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs @@ -1,19 +1,10 @@ use psyche_coordinator::Round; use psyche_coordinator::RunState; use psyche_coordinator::model::Checkpoint; -use psyche_coordinator::model::HttpTrainingDataLocation; -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model::LLMTrainingDataLocation; -use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::CosineLR; use psyche_core::FixedString; use psyche_core::FixedVec; -use psyche_core::LearningRateSchedule; -use psyche_core::OptimizerDefinition; -use psyche_core::Shuffle; use psyche_core::SmallBoolean; -use psyche_core::TokenSize; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_coordinator::coordinator_account_from_bytes; @@ -43,12 +34,11 @@ pub async fn run() { assert_eq!(coordinator.run_state, RunState::Uninitialized); assert_eq!(coordinator.run_state_start_unix_timestamp, 0); assert_eq!(coordinator.pending_pause, SmallBoolean::FALSE); - // Coordinator model + // Coordinator model (only on-chain fields) match coordinator.model { Model::LLM(llm) => { assert_eq!(llm.max_seq_len, 2048); assert_eq!(llm.cold_start_warmup_steps, 0); - assert_eq!(llm.architecture, LLMArchitecture::HfLlama); match llm.checkpoint { Checkpoint::Hub(hub) => { assert_eq!( @@ -59,57 +49,6 @@ pub async fn run() { }, _ => panic!("Expected Hub checkpoint"), }; - assert_eq!(llm.data_type, LLMTrainingDataType::Pretraining); - match llm.data_location { - LLMTrainingDataLocation::Http(http) => { - match http.location { - HttpTrainingDataLocation::Gcp { - bucket_name, - filter_directory, - } => { - assert_eq!( - bucket_name, - fixed_str("nous-pretraining-public-us") - ); - assert_eq!( - filter_directory, - fixed_str("fineweb-edu-tokenized-llama2") - ); - }, - _ => panic!("Expected Gcp data location"), - }; - assert_eq!(http.token_size_in_bytes, TokenSize::TwoBytes); - assert_eq!(http.shuffle, 
Shuffle::DontShuffle); - }, - _ => panic!("Expected Http data location"), - }; - match llm.lr_schedule { - LearningRateSchedule::Cosine(learning_rate) => { - assert_eq!( - learning_rate, - CosineLR::new(0.0004, 250, 0.0, 25000, 0.00004) - ); - }, - _ => panic!("Expected Constant LR schedule"), - }; - match llm.optimizer { - OptimizerDefinition::Distro { - clip_grad_norm, - weight_decay, - compression_decay, - compression_topk, - compression_chunk, - quantize_1bit, - } => { - assert_eq!(clip_grad_norm, Some(1.0)); - assert_eq!(weight_decay, None); - assert_eq!(compression_decay, 0.999); - assert_eq!(compression_topk, 2); - assert_eq!(compression_chunk, 64); - assert_eq!(quantize_1bit, false); - }, - _ => panic!("Expected Distro optimizer"), - } }, }; // Coordinator config diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index fc03202cf..45a235723 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -5,13 +5,7 @@ use psyche_coordinator::WitnessProof; use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model::LLMTrainingDataLocation; -use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::ConstantLR; -use psyche_core::LearningRateSchedule; -use psyche_core::OptimizerDefinition; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::ClientId; use psyche_solana_coordinator::CoordinatorAccount; @@ -111,20 +105,8 @@ pub async fn run() { waiting_for_members_extra_time: 3, }), Some(Model::LLM(LLM { - architecture: LLMArchitecture::HfLlama, checkpoint: 
Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, - data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), - lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), - optimizer: OptimizerDefinition::Distro { - clip_grad_norm: None, - compression_decay: 1.0, - compression_topk: 1, - compression_chunk: 1, - quantize_1bit: false, - weight_decay: None, - }, cold_start_warmup_steps: 0, })), None, // no explicit progress diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index f69bbf8c5..d8779d257 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -5,13 +5,7 @@ use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model::LLMTrainingDataLocation; -use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::ConstantLR; -use psyche_core::LearningRateSchedule; -use psyche_core::OptimizerDefinition; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::ClientId; use psyche_solana_coordinator::CoordinatorAccount; @@ -108,20 +102,8 @@ pub async fn run() { total_steps: 100, }), Some(Model::LLM(LLM { - architecture: LLMArchitecture::HfLlama, checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, - data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), - lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), - optimizer: OptimizerDefinition::Distro { - clip_grad_norm: None, - 
compression_decay: 1.0, - compression_topk: 1, - compression_chunk: 1, - quantize_1bit: false, - weight_decay: None, - }, cold_start_warmup_steps: 0, })), None, // no explicit progress diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs index e51ced2dd..3d185f7fd 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_update.rs @@ -3,13 +3,7 @@ use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model::LLMTrainingDataLocation; -use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::ConstantLR; -use psyche_core::LearningRateSchedule; -use psyche_core::OptimizerDefinition; use psyche_solana_coordinator::CoordinatorAccount; use psyche_solana_tooling::create_memnet_endpoint::create_memnet_endpoint; use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_run_create; @@ -55,20 +49,8 @@ pub async fn run() { as u8, }), model: Some(Model::LLM(LLM { - architecture: LLMArchitecture::HfLlama, checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, - data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), - lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), - optimizer: OptimizerDefinition::Distro { - clip_grad_norm: None, - compression_decay: 1.0, - compression_topk: 1, - compression_chunk: 1, - quantize_1bit: false, - weight_decay: None, - }, cold_start_warmup_steps: 0, })), progress: None, diff --git 
a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 014772e32..3875de2be 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -7,13 +7,7 @@ use psyche_coordinator::WAITING_FOR_MEMBERS_EXTRA_SECONDS; use psyche_coordinator::model::Checkpoint; use psyche_coordinator::model::HubRepo; use psyche_coordinator::model::LLM; -use psyche_coordinator::model::LLMArchitecture; -use psyche_coordinator::model::LLMTrainingDataLocation; -use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::ConstantLR; -use psyche_core::LearningRateSchedule; -use psyche_core::OptimizerDefinition; use psyche_solana_authorizer::logic::AuthorizationGranteeUpdateParams; use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::ClientId; @@ -230,22 +224,8 @@ pub async fn run() { waiting_for_members_extra_time: 3, }), model: Some(Model::LLM(LLM { - architecture: LLMArchitecture::HfLlama, checkpoint: Checkpoint::Dummy(HubRepo::dummy()), max_seq_len: 4096, - data_type: LLMTrainingDataType::Pretraining, - data_location: LLMTrainingDataLocation::default(), - lr_schedule: LearningRateSchedule::Constant( - ConstantLR::default(), - ), - optimizer: OptimizerDefinition::Distro { - clip_grad_norm: None, - compression_decay: 1.0, - compression_topk: 1, - compression_chunk: 1, - quantize_1bit: false, - weight_decay: None, - }, cold_start_warmup_steps: 0, })), progress: None, diff --git a/architectures/decentralized/testing/src/docker_setup.rs b/architectures/decentralized/testing/src/docker_setup.rs index 4f5a80dd9..6d0f09abd 100644 --- a/architectures/decentralized/testing/src/docker_setup.rs +++ 
b/architectures/decentralized/testing/src/docker_setup.rs @@ -4,7 +4,7 @@ use bollard::{ Config, CreateContainerOptions, KillContainerOptions, ListContainersOptions, RemoveContainerOptions, }, - models::DeviceRequest, + models::{DeviceRequest, Mount, MountTypeEnum}, secret::{ContainerSummary, HostConfig}, }; use psyche_core::IntegrationTestLogMarker; @@ -119,6 +119,24 @@ pub async fn spawn_new_client(docker_client: Arc) -> Result) -> Result, } impl TrainArgs { @@ -327,6 +333,32 @@ impl TrainArgs { .collect(); result } + + pub fn model_extra_data_override(&self) -> Result> { + let Some(path) = &self.model_extra_data_toml else { + return Ok(None); + }; + + let content = std::fs::read_to_string(path) + .with_context(|| format!("failed to read model extra data TOML file {:?}", path))?; + + let toml_value: toml::Value = toml::from_str(&content) + .with_context(|| format!("failed to parse TOML file {:?}", path))?; + + let model_extra_data_table = toml_value + .get("model_extra_data") + .ok_or_else(|| anyhow::anyhow!("missing [model_extra_data] section in {:?}", path))?; + + let config: ModelExtraData = + model_extra_data_table.clone().try_into().with_context(|| { + format!( + "failed to deserialize model_extra_data from TOML file {:?}", + path + ) + })?; + + Ok(Some(config)) + } } pub fn prepare_environment() { diff --git a/shared/client/src/lib.rs b/shared/client/src/lib.rs index bdad43e30..998795c35 100644 --- a/shared/client/src/lib.rs +++ b/shared/client/src/lib.rs @@ -9,8 +9,8 @@ pub use cli::{TrainArgs, prepare_environment, print_identity_keys, read_identity pub use client::Client; pub use protocol::{Broadcast, BroadcastType, Finished, NC, TrainingResult}; pub use state::{ - CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, RoundState, RunInitConfig, - RunInitConfigAndIO, UploadInfo, + CheckpointConfig, GcsUploadInfo, HubUploadInfo, InitRunError, ModelExtraData, RoundState, + RunInitConfig, RunInitConfigAndIO, UploadInfo, }; pub use tui::{ClientTUI, 
ClientTUIState}; diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 7a945f74b..0484e795b 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -1,7 +1,9 @@ use crate::{WandBInfo, fetch_data::DataFetcher}; +pub use psyche_coordinator::model_extra_data::ModelExtraData; use psyche_coordinator::{ Coordinator, HealthChecks, model::{self, HttpLLMTrainingDataLocation, LLMTrainingDataLocation}, + model_extra_data::{CONFIG_PREFIX, MODEL_CONFIG_FILENAME}, }; use psyche_core::{ Barrier, CancellableBarrier, IntegrationTestLogMarker, NodeIdentity, Shuffle, TokenSize, @@ -9,7 +11,8 @@ use psyche_core::{ use psyche_data_provider::{ DataProvider, DataProviderTcpClient, DownloadError, DummyDataProvider, PreprocessedDataProvider, Split, WeightedDataProvider, download_dataset_repo_async, - download_model_from_gcs_async, download_model_repo_async, + download_model_from_gcs_async, download_model_repo_async, fetch_json_from_gcs, + fetch_json_from_hub, http::{FileURLs, HttpDataProvider}, }; use psyche_metrics::ClientMetrics; @@ -75,6 +78,10 @@ pub struct RunInitConfig { pub dummy_training_delay_secs: Option, pub sidecar_port: Option, + + /// If provided, use this model extra data instead of fetching from GCS/Hub. + /// Only meant for testing/debugging. + pub model_extra_data_override: Option, } #[derive(Debug, Error)] @@ -199,11 +206,57 @@ impl RunInitConfigAndIO { + let bucket = gcs_repo.bucket.to_string(); + let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); + debug!("Fetching model extra data from gs://{}/{}", bucket, path); + fetch_json_from_gcs(&bucket, &path).await? 
+ } + model::Checkpoint::Hub(repo) | model::Checkpoint::P2P(repo) => { + let repo_id: String = (&repo.repo_id).into(); + let revision = repo.revision.map(|bytes| (&bytes).into()); + let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); + debug!("Fetching model extra data from Hub: {}/{}", repo_id, path); + match fetch_json_from_hub( + &repo_id, + revision, + &path, + init_config.hub_read_token.clone(), + ) + .await + { + Ok(config) => config, + Err(e) => { + error!( + "Failed to fetch model extra data from Hub ({}), using default: {}", + repo_id, e + ); + return Err(InitRunError::GcsModelLoad(e)); + } + } + } + // Dummy/Ephemeral checkpoints use default model extra data (for testing) + _ => ModelExtraData::default(), + } + }; + let hub_read_token = init_config.hub_read_token.clone(); let hub_max_concurrent_downloads = init_config.hub_max_concurrent_downloads; let data_future = async { - debug!("Setting up data provider from {:?}", llm.data_location); - let data_provider = match llm.data_location { + debug!( + "Setting up data provider from {:?}", + model_extra_data.data_location + ); + let data_provider = match model_extra_data.data_location { LLMTrainingDataLocation::Server(data_server) => DataProvider::Server( DataProviderTcpClient::connect( (&data_server).into(), @@ -268,7 +321,8 @@ impl RunInitConfigAndIO> = match &llm.architecture + let model_future: JoinHandle> = match &model_extra_data + .architecture { model::LLMArchitecture::HfLlama | model::LLMArchitecture::HfDeepseek @@ -317,6 +371,7 @@ impl RunInitConfigAndIO { let checkpoint = llm.checkpoint; @@ -374,7 +429,9 @@ impl RunInitConfigAndIO { + model::Checkpoint::P2P(_) + | model::Checkpoint::P2PDummy + | model::Checkpoint::P2PGcs(_) => { let (tx_model_config_response, rx_model_config_response) = oneshot::channel(); info!("Checkpoint is p2p, requesting model config over network"); @@ -387,7 +444,7 @@ impl RunInitConfigAndIO { AutoConfig::Llama(serde_json::from_str(&model_config)?) 
} @@ -408,7 +465,7 @@ impl RunInitConfigAndIO RunInitConfigAndIO 1 - && llm.architecture == model::LLMArchitecture::HfAuto + && model_extra_data.architecture == model::LLMArchitecture::HfAuto { 1 } else { @@ -489,7 +546,7 @@ impl RunInitConfigAndIO = - match llm.data_type { + match model_extra_data.data_type { model::LLMTrainingDataType::Finetuning => { #[cfg(feature = "parallelism")] { @@ -503,7 +560,9 @@ impl RunInitConfigAndIO None, }; - let raw_loaded_model_type: RawLoadedModelType = match llm.architecture { + let raw_loaded_model_type: RawLoadedModelType = match model_extra_data + .architecture + { model::LLMArchitecture::HfAuto | model::LLMArchitecture::Torchtitan => { #[cfg(feature = "python")] { @@ -513,7 +572,7 @@ impl RunInitConfigAndIO RunInitConfigAndIO RunInitConfigAndIO RunInitConfigAndIO RunInitConfigAndIO, data_parallel: None, }, - llm.lr_schedule, - llm.optimizer, + model_extra_data.lr_schedule, + model_extra_data.optimizer, init_config.micro_batch_size, init_config.optim_stats_every_n_steps, init_config.grad_accum_in_fp32, @@ -822,8 +881,8 @@ impl RunInitConfigAndIO RunInitConfigAndIO RunInitConfigAndIO quantize_1bit, + _ => false, + }; + let training = TrainingStepMetadata { data_fetcher, identity: init_config.identity, @@ -855,6 +919,7 @@ impl RunInitConfigAndIO { pub write_gradients_dir: Option, pub model_task_runner: ModelTaskRunner, + pub quantize_1bit: bool, } #[derive(Debug)] @@ -274,12 +275,7 @@ impl TrainingStepMetadata let cancel_training = cancel_training.clone(); let write_gradients_dir = self.write_gradients_dir.clone(); let tx_distro_result = self.tx_distro_result.clone(); - let quantize = match &state.model { - model::Model::LLM(llm) => match llm.optimizer { - OptimizerDefinition::Distro { quantize_1bit, .. 
} => quantize_1bit, - _ => false, - }, - }; + let quantize = self.quantize_1bit; let finished = finished.clone(); let TrainingDataForStep { diff --git a/shared/coordinator/Cargo.toml b/shared/coordinator/Cargo.toml index f7cdecc81..2ca5ec09a 100644 --- a/shared/coordinator/Cargo.toml +++ b/shared/coordinator/Cargo.toml @@ -9,6 +9,7 @@ async-trait.workspace = true anchor-lang.workspace = true bytemuck.workspace = true serde_with.workspace = true +serde_json.workspace = true serde.workspace = true cfg_eval = "0.1.2" ts-rs.workspace = true diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index c726655e2..7db9a8f6d 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -1,6 +1,6 @@ use crate::{ Commitment, Committee, CommitteeProof, CommitteeSelection, WitnessProof, - model::{Checkpoint, Model}, + model::{Checkpoint, HubRepo, Model}, }; use anchor_lang::{AnchorDeserialize, AnchorSerialize, InitSpace, prelude::borsh}; @@ -936,6 +936,7 @@ impl Coordinator { match llm.checkpoint { Checkpoint::P2P(hub_repo) => llm.checkpoint = Checkpoint::Hub(hub_repo), Checkpoint::P2PGcs(gcs_repo) => llm.checkpoint = Checkpoint::Gcs(gcs_repo), + Checkpoint::P2PDummy => llm.checkpoint = Checkpoint::Dummy(HubRepo::dummy()), _ => {} } } @@ -1065,9 +1066,8 @@ impl Coordinator { // we've completed an epoch, switch to P2P from now on let Model::LLM(llm) = &mut self.model; match llm.checkpoint { - Checkpoint::Hub(hub_repo) | Checkpoint::Dummy(hub_repo) => { - llm.checkpoint = Checkpoint::P2P(hub_repo) - } + Checkpoint::Hub(hub_repo) => llm.checkpoint = Checkpoint::P2P(hub_repo), + Checkpoint::Dummy(_) => llm.checkpoint = Checkpoint::P2PDummy, Checkpoint::Gcs(gcs_repo) => llm.checkpoint = Checkpoint::P2PGcs(gcs_repo), _ => {} } diff --git a/shared/coordinator/src/lib.rs b/shared/coordinator/src/lib.rs index bef26863e..aae450fb0 100644 --- a/shared/coordinator/src/lib.rs +++ b/shared/coordinator/src/lib.rs @@ -5,6 
+5,7 @@ mod committee_selection; mod coordinator; mod data_selection; pub mod model; +pub mod model_extra_data; pub use commitment::Commitment; pub use committee_selection::{ diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 3176f276e..5d3dbb958 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -5,10 +5,7 @@ use anchor_lang::{ prelude::{borsh, msg}, }; use bytemuck::{Zeroable, ZeroableInOption}; -use psyche_core::{ - ConstantLR, FixedString, FixedVec, LearningRateSchedule, OptimizerDefinition, Shuffle, - TokenSize, -}; +use psyche_core::{FixedString, FixedVec, Shuffle, TokenSize}; use serde::{Deserialize, Serialize}; use ts_rs::TS; @@ -190,24 +187,14 @@ pub enum HttpTrainingDataLocation { pub struct LLM { pub max_seq_len: u32, pub cold_start_warmup_steps: u32, - pub architecture: LLMArchitecture, pub checkpoint: Checkpoint, - pub data_type: LLMTrainingDataType, - pub data_location: LLMTrainingDataLocation, - pub lr_schedule: LearningRateSchedule, - pub optimizer: OptimizerDefinition, } impl LLM { pub fn dummy() -> Self { Self { - architecture: LLMArchitecture::HfLlama, checkpoint: Checkpoint::Dummy(HubRepo::dummy()), - data_location: LLMTrainingDataLocation::default(), - data_type: LLMTrainingDataType::Pretraining, - lr_schedule: LearningRateSchedule::Constant(ConstantLR::default()), max_seq_len: 2048, - optimizer: OptimizerDefinition::Dummy, cold_start_warmup_steps: 0, } } @@ -283,6 +270,8 @@ pub enum Checkpoint { Dummy(HubRepo), Hub(HubRepo), P2P(HubRepo), + /// P2P checkpoint that originated from a Dummy checkpoint (for testing) + P2PDummy, Gcs(GcsRepo), P2PGcs(GcsRepo), } @@ -290,7 +279,7 @@ pub enum Checkpoint { impl std::fmt::Display for Checkpoint { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Checkpoint::Dummy(_hub_repo) => write!(f, "Dummy"), + Checkpoint::Dummy(_) | Checkpoint::P2PDummy => write!(f, "Dummy"), Checkpoint::Ephemeral => write!(f, 
"Ephemeral"), Checkpoint::Hub(hub_repo) => write!(f, "{}", &hub_repo.repo_id), Checkpoint::P2P(hub_repo) => { @@ -313,30 +302,8 @@ impl Model { return false; } - let bad_data_location = match llm.data_location { - LLMTrainingDataLocation::Dummy => false, - LLMTrainingDataLocation::Server(url) => url.is_empty(), - LLMTrainingDataLocation::Local(_) => false, - LLMTrainingDataLocation::Http(HttpLLMTrainingDataLocation { - location, .. - }) => match location { - HttpTrainingDataLocation::SingleUrl(url) => url.is_empty(), - HttpTrainingDataLocation::NumberedFiles { - url_template, - num_files, - .. - } => url_template.is_empty() || num_files == 0, - HttpTrainingDataLocation::Gcp { bucket_name, .. } => bucket_name.is_empty(), - }, - LLMTrainingDataLocation::WeightedHttp(url) => url.is_empty(), - LLMTrainingDataLocation::Preprocessed(url) => url.is_empty(), - }; - if bad_data_location { - msg!("model check failed: bad LLM training data location."); - return false; - } let bad_checkpoint = match llm.checkpoint { - Checkpoint::Dummy(_hub_repo) => false, + Checkpoint::Dummy(_) | Checkpoint::P2PDummy => false, Checkpoint::Ephemeral => true, Checkpoint::Hub(hub_repo) => hub_repo.repo_id.is_empty(), Checkpoint::P2P(hub_repo) => hub_repo.repo_id.is_empty(), @@ -349,14 +316,7 @@ impl Model { msg!("model check failed: bad checkpoint"); return false; } - if !match llm.optimizer { - OptimizerDefinition::Dummy => false, - OptimizerDefinition::AdamW { .. } => true, - OptimizerDefinition::Distro { .. 
} => true, - } { - msg!("model check failed: bad optimizer"); - return false; - } + true } } diff --git a/shared/coordinator/src/model_extra_data.rs b/shared/coordinator/src/model_extra_data.rs new file mode 100644 index 000000000..23b68b0de --- /dev/null +++ b/shared/coordinator/src/model_extra_data.rs @@ -0,0 +1,106 @@ +use serde::{Deserialize, Serialize}; + +use crate::model::{LLMArchitecture, LLMTrainingDataLocation, LLMTrainingDataType}; +use psyche_core::{LearningRateSchedule, OptimizerDefinition}; + +/// Path within the bucket where config is stored +pub const CONFIG_PREFIX: &str = "config"; +/// Filename for the model config +pub const MODEL_CONFIG_FILENAME: &str = "model_config.json"; + +/// Extra model data that is stored off-chain and fetched by clients. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelExtraData { + #[serde(default)] + pub version: u32, + + pub architecture: LLMArchitecture, + + pub data_type: LLMTrainingDataType, + + pub data_location: LLMTrainingDataLocation, + + pub lr_schedule: LearningRateSchedule, + + pub optimizer: OptimizerDefinition, + + /// Optional run metadata + #[serde(default, skip_serializing_if = "Option::is_none")] + pub run_metadata: Option<RunMetadata>, +} + +impl Default for ModelExtraData { + fn default() -> Self { + Self { + version: 1, + architecture: LLMArchitecture::HfLlama, + data_type: LLMTrainingDataType::Pretraining, + data_location: LLMTrainingDataLocation::default(), + lr_schedule: LearningRateSchedule::Constant(psyche_core::ConstantLR::default()), + optimizer: OptimizerDefinition::Dummy, + run_metadata: None, + } + } +} + +/// Run metadata - display information about the run +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct RunMetadata { + #[serde(default)] + pub name: String, + + #[serde(default)] + pub description: String, + + #[serde(default)] + pub num_parameters: u64, + + #[serde(default)] + pub vocab_size: u64, + + #[serde(default)] + pub client_version: String, +} + +impl 
ModelExtraData { + pub fn to_json(&self) -> Result<String, serde_json::Error> { + serde_json::to_string_pretty(self) + } + + pub fn from_json(json: &str) -> Result<Self, serde_json::Error> { + serde_json::from_str(json) + } + + /// Validate the configuration + pub fn check(&self) -> bool { + let bad_data_location = match &self.data_location { + LLMTrainingDataLocation::Dummy => false, + LLMTrainingDataLocation::Server(url) => url.is_empty(), + LLMTrainingDataLocation::Local(_) => false, + LLMTrainingDataLocation::Http(http_loc) => { + use crate::model::HttpTrainingDataLocation; + match &http_loc.location { + HttpTrainingDataLocation::SingleUrl(url) => url.is_empty(), + HttpTrainingDataLocation::NumberedFiles { + url_template, + num_files, + .. + } => url_template.is_empty() || *num_files == 0, + HttpTrainingDataLocation::Gcp { bucket_name, .. } => bucket_name.is_empty(), + } + } + LLMTrainingDataLocation::WeightedHttp(url) => url.is_empty(), + LLMTrainingDataLocation::Preprocessed(url) => url.is_empty(), + }; + + if bad_data_location { + return false; + } + + match &self.optimizer { + OptimizerDefinition::Dummy => false, + OptimizerDefinition::AdamW { .. } => true, + OptimizerDefinition::Distro { .. 
} => true, + } + } +} diff --git a/shared/data-provider/src/errors.rs b/shared/data-provider/src/errors.rs index b84bc5f9a..d2afa5ce7 100644 --- a/shared/data-provider/src/errors.rs +++ b/shared/data-provider/src/errors.rs @@ -26,6 +26,9 @@ pub enum UploadError { #[error("GCS operation failed: {0}")] GcsStorage(#[from] google_cloud_storage::http::Error), + #[error("GCS error: {0}")] + Gcs(String), + // Common errors #[error("IO error: {0}")] Io(#[from] std::io::Error), @@ -45,6 +48,9 @@ pub enum DownloadError { #[error("GCS operation failed: {0}")] GcsStorage(#[from] google_cloud_storage::http::Error), + #[error("GCS error: {0}")] + Gcs(String), + #[error("IO error: {0}")] Io(#[from] std::io::Error), diff --git a/shared/data-provider/src/gcs.rs b/shared/data-provider/src/gcs.rs index 71f29e414..ddd31e0b3 100644 --- a/shared/data-provider/src/gcs.rs +++ b/shared/data-provider/src/gcs.rs @@ -323,6 +323,69 @@ pub fn download_model_from_gcs_sync( rt.block_on(download_model_from_gcs_async(bucket, prefix)) } +/// Fetch a JSON file from GCS and deserialize it. +/// Used for fetching external model configuration. +pub async fn fetch_json_from_gcs<T: serde::de::DeserializeOwned>( + bucket: &str, + object_path: &str, +) -> Result<T, DownloadError> { + // Use authenticated client if GOOGLE_APPLICATION_CREDENTIALS is set, otherwise anonymous + let config = if std::env::var("GOOGLE_APPLICATION_CREDENTIALS").is_ok() { + info!("Using authenticated GCS client for config fetch"); + ClientConfig::default().with_auth().await? + } else { + info!("Using anonymous GCS client for config fetch"); + ClientConfig::default().anonymous() + }; + let client = Client::new(config); + + info!("Fetching gs://{}/{}", bucket, object_path); + + let data = client + .download_object( + &GetObjectRequest { + bucket: bucket.to_owned(), + object: object_path.to_owned(), + ..Default::default() + }, + &Range::default(), + ) + .await?; + + serde_json::from_slice(&data).map_err(DownloadError::Json) +} + +/// Upload a JSON-serializable value to GCS. 
+pub async fn upload_json_to_gcs<T: serde::Serialize>( + bucket: &str, + object_path: &str, + value: &T, +) -> Result<(), UploadError> { + // Use authenticated client - must have credentials for upload + let config = ClientConfig::default().with_auth().await?; + let client = Client::new(config); + + let json = serde_json::to_string_pretty(value)?; + let data = json.into_bytes(); + + info!("Uploading JSON to gs://{}/{}", bucket, object_path); + + client + .upload_object( + &UploadObjectRequest { + bucket: bucket.to_owned(), + ..Default::default() + }, + data, + &UploadType::Simple(Media::new(object_path.to_owned())), + ) + .await?; + + info!("Uploaded JSON to gs://{}/{}", bucket, object_path); + + Ok(()) +} + pub async fn upload_to_gcs( gcs_info: GcsUploadInfo, manifest_metadata: GcsManifestMetadata, diff --git a/shared/data-provider/src/hub.rs b/shared/data-provider/src/hub.rs index 13a575b84..c38236460 100644 --- a/shared/data-provider/src/hub.rs +++ b/shared/data-provider/src/hub.rs @@ -1,4 +1,4 @@ -use crate::errors::UploadError; +use crate::errors::{DownloadError, UploadError}; use crate::hub::model::HubRepo; use hf_hub::{ Cache, Repo, RepoType, @@ -8,10 +8,11 @@ use hf_hub::{ }, }; use psyche_coordinator::model; +use psyche_coordinator::model_extra_data::ModelExtraData; use psyche_core::FixedString; use std::{path::PathBuf, time::Instant}; use tokio::sync::mpsc; -use tracing::{error, info}; +use tracing::{debug, error, info}; const MODEL_EXTENSIONS: [&str; 3] = [".safetensors", ".json", ".py"]; const DATASET_EXTENSIONS: [&str; 1] = [".parquet"]; @@ -193,6 +194,73 @@ pub fn download_dataset_repo_sync( ) } +/// Fetch a JSON file from HuggingFace and deserialize it. +/// Used for fetching external model configuration from Hub. 
+pub async fn fetch_json_from_hub<T: serde::de::DeserializeOwned>( + repo_id: &str, + revision: Option<String>, + filename: &str, + token: Option<String>, +) -> Result<T, DownloadError> { + let cache = Cache::default(); + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_cache_dir(cache.path().clone()) + .with_token(token.or(cache.token())) + .with_progress(false) + .build()?; + + let repo = match revision { + Some(rev) => Repo::with_revision(repo_id.to_string(), RepoType::Model, rev), + None => Repo::model(repo_id.to_string()), + }; + let api_repo = api.repo(repo); + + debug!("Fetching {} from {}", filename, repo_id); + + let file_path = api_repo.get(filename).await?; + let content = tokio::fs::read_to_string(&file_path).await?; + + serde_json::from_str(&content).map_err(DownloadError::Json) +} + +/// Upload model extra data to HuggingFace Hub. +pub async fn upload_model_extra_data_to_hub( + repo_id: &str, + filename: &str, + model_extra_data: &ModelExtraData, + token: Option<String>, + commit_message: Option<String>, +) -> Result<(), UploadError> { + let cache = Cache::default(); + let api = hf_hub::api::tokio::ApiBuilder::new() + .with_cache_dir(cache.path().clone()) + .with_token(token.or(cache.token())) + .with_progress(false) + .build()?; + + let repo = Repo::model(repo_id.to_string()); + let api_repo = api.repo(repo); + + let json = serde_json::to_string_pretty(model_extra_data)?; + let data = json.into_bytes(); + + info!("Uploading JSON to {}/{} on HuggingFace", repo_id, filename); + + api_repo + .upload_file( + UploadSource::Bytes(data), + filename, + commit_message.or_else(|| Some(format!("Upload {}", filename))), + None, + false, + ) + .await?; + + info!("Uploaded JSON to {}/{} on HuggingFace", repo_id, filename); + + Ok(()) +} + #[derive(Debug, Clone)] pub struct HubUploadInfo { pub hub_repo: String, diff --git a/shared/data-provider/src/lib.rs b/shared/data-provider/src/lib.rs index 0044d77d2..8899313e3 100644 --- a/shared/data-provider/src/lib.rs +++ b/shared/data-provider/src/lib.rs @@ -19,11 +19,13 @@ pub use 
errors::{DownloadError, UploadError}; pub use file_extensions::{DATA_FILE_EXTENSIONS, PARQUET_EXTENSION}; pub use gcs::{ GcsCheckpointManifest, GcsManifestMetadata, GcsUploadInfo, ManifestFileEntry, ManifestMetadata, - download_model_from_gcs_async, download_model_from_gcs_sync, upload_to_gcs, + download_model_from_gcs_async, download_model_from_gcs_sync, fetch_json_from_gcs, + upload_json_to_gcs, upload_to_gcs, }; pub use hub::{ HubUploadInfo, download_dataset_repo_async, download_dataset_repo_sync, - download_model_repo_async, download_model_repo_sync, upload_to_hub, + download_model_repo_async, download_model_repo_sync, fetch_json_from_hub, + upload_model_extra_data_to_hub, upload_to_hub, }; pub use local::LocalDataProvider; pub use parquet::record::{ListAccessor, MapAccessor, RowAccessor}; diff --git a/shared/watcher/src/tui.rs b/shared/watcher/src/tui.rs index 58bac18ba..b95272cca 100644 --- a/shared/watcher/src/tui.rs +++ b/shared/watcher/src/tui.rs @@ -43,13 +43,10 @@ impl psyche_tui::CustomWidget for CoordinatorTui { let vsplit = Layout::vertical(Constraint::from_fills([1, 1])).split(coord_split[1]); { Paragraph::new( - [ - format!("Data Source: {}", state.data_source), - format!("Model Checkpoint: {}", state.model_checkpoint), - ] - .into_iter() - .map(Line::from) - .collect::<Vec<Line>>(), + [format!("Model Checkpoint: {}", state.model_checkpoint)] + .into_iter() + .map(Line::from) + .collect::<Vec<Line>>(), ) .block(Block::bordered().title("Config")) .render(vsplit[0], buf); @@ -169,7 +166,6 @@ pub struct CoordinatorTuiState { pub run_state: TuiRunState, pub height: u32, pub clients: Vec<String>, - pub data_source: String, pub model_checkpoint: String, pub exited_clients: usize, pub pending_pause: bool, @@ -187,9 +183,6 @@ impl From<&Coordinator> for CoordinatorTuiState { .iter() .map(|c| format!("{:?}", c.id)) .collect(), - data_source: match &value.model { - Model::LLM(l) => format!("{:?}", l.data_type), - }, model_checkpoint: match &value.model { Model::LLM(l) => format!("{}", 
l.checkpoint), }, diff --git a/tools/rust-tools/preview-lr/src/main.rs b/tools/rust-tools/preview-lr/src/main.rs index ffea7ea00..51c52a3c6 100644 --- a/tools/rust-tools/preview-lr/src/main.rs +++ b/tools/rust-tools/preview-lr/src/main.rs @@ -1,6 +1,6 @@ use clap::Parser; use plotters::prelude::*; -use psyche_coordinator::{CoordinatorConfig, model::Model}; +use psyche_coordinator::{CoordinatorConfig, model_extra_data::ModelExtraData}; use serde::Deserialize; use std::path::PathBuf; @@ -28,7 +28,7 @@ enum Commands { #[derive(Deserialize)] struct Config { pub config: CoordinatorConfig, - pub model: Model, + pub model_extra_data: ModelExtraData, } fn main() -> anyhow::Result<()> { let args = Args::parse(); @@ -48,9 +48,8 @@ fn main() -> anyhow::Result<()> { let config: Config = toml::from_str(&std::fs::read_to_string(&config_path)?)?; - let Model::LLM(llm) = config.model; let steps = config.config.total_steps; - let lr = llm.lr_schedule; + let lr = config.model_extra_data.lr_schedule; let root = BitMapBackend::new("lr-plot.png", (steps.min(10_000), 1024)).into_drawing_area(); root.fill(&WHITE)?; diff --git a/tools/rust-tools/run-manager/Cargo.toml b/tools/rust-tools/run-manager/Cargo.toml index e47a1931f..c6782c9e1 100644 --- a/tools/rust-tools/run-manager/Cargo.toml +++ b/tools/rust-tools/run-manager/Cargo.toml @@ -22,6 +22,7 @@ psyche-solana-authorizer.workspace = true psyche-solana-treasurer.workspace = true psyche-coordinator.workspace = true psyche-core.workspace = true +psyche-data-provider.workspace = true anchor-client.workspace = true anchor-lang.workspace = true anchor-spl.workspace = true diff --git a/tools/rust-tools/run-manager/src/commands/run/update_config.rs b/tools/rust-tools/run-manager/src/commands/run/update_config.rs index 641577307..7356de704 100644 --- a/tools/rust-tools/run-manager/src/commands/run/update_config.rs +++ b/tools/rust-tools/run-manager/src/commands/run/update_config.rs @@ -7,9 +7,12 @@ use clap::Args; use psyche_coordinator::{ 
CoordinatorConfig, CoordinatorProgress, get_data_index_for_step, model::{Checkpoint, Model}, + model_extra_data::{CONFIG_PREFIX, MODEL_CONFIG_FILENAME, ModelExtraData}, }; +use psyche_data_provider::upload_json_to_gcs; use psyche_solana_treasurer::logic::RunUpdateParams; use serde::{Deserialize, Serialize}; +use tracing::info; use crate::{SolanaBackend, instructions}; @@ -40,6 +43,12 @@ pub struct CommandUpdateConfig { // end metadata #[clap(long, env)] pub client_version: Option<String>, + #[clap(long, default_value_t = false, hide = true)] + pub skip_upload_model_extra_data: bool, + + /// HuggingFace token for uploading to Hub repos (can also use HF_TOKEN env var) + #[clap(long, env = "HF_TOKEN")] + pub hub_token: Option<String>, } #[async_trait] @@ -56,6 +65,8 @@ impl Command for CommandUpdateConfig { num_parameters, vocab_size, client_version, + skip_upload_model_extra_data, + hub_token, } = self; let main_authority = backend.get_payer(); @@ -69,12 +80,13 @@ impl Command for CommandUpdateConfig { .get_coordinator_account(&coordinator_account) .await?; - let (config, mut model) = match config_path { + let (config, mut model, model_extra_data) = match config_path { Some(config_path) => { #[derive(Serialize, Deserialize)] struct State { pub config: CoordinatorConfig, pub model: Model, + pub model_extra_data: ModelExtraData, } let state: State = toml::from_str(std::str::from_utf8( &std::fs::read(&config_path).with_context(|| { @@ -83,9 +95,13 @@ )?) 
.with_context(|| format!("failed to parse config toml file {config_path:?}"))?; - (Some(state.config), Some(state.model)) + ( + Some(state.config), + Some(state.model), + Some(state.model_extra_data), + ) } - None => (None, None), + None => (None, None, None), }; model = if switch_to_hub { @@ -135,6 +151,53 @@ impl Command for CommandUpdateConfig { coordinator_account_state.state.coordinator.model = model; } + // Upload model extra data to GCS or hub repo depending on the model checkpoint + if !skip_upload_model_extra_data { + if let Some(model_extra_data) = model_extra_data { + let Model::LLM(llm) = &coordinator_account_state.state.coordinator.model; + match llm.checkpoint { + Checkpoint::Gcs(ref gcs_repo) | Checkpoint::P2PGcs(ref gcs_repo) => { + let bucket = gcs_repo.bucket.to_string(); + let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); + info!("Uploading model extra data to gs://{}/{}", bucket, path); + upload_json_to_gcs(&bucket, &path, &model_extra_data) + .await + .with_context(|| { + format!( + "failed to upload model extra data to gs://{}/{}", + bucket, path + ) + })?; + println!("Uploaded model extra data to gs://{}/{}", bucket, path); + } + Checkpoint::Hub(ref hub_repo) | Checkpoint::P2P(ref hub_repo) => { + let repo_id = hub_repo.repo_id.to_string(); + let path = format!("{}/{}", CONFIG_PREFIX, MODEL_CONFIG_FILENAME); + psyche_data_provider::upload_model_extra_data_to_hub( + &repo_id, + &path, + &model_extra_data, + hub_token.clone(), + None, + ) + .await + .with_context(|| { + format!( + "failed to upload model extra data to Hub repo {}/{}", + repo_id, path + ) + })?; + println!("Uploaded model extra data to Hub repo {}/{}", repo_id, path); + } + _ => { + println!( + "Warning: model_extra_data provided but checkpoint is not GCS- or Hub-based, skipping upload" + ); + } + } + } + } + let progress = restart_from_step.map(|step| CoordinatorProgress { epoch: coordinator_account_state.state.coordinator.progress.epoch, step, diff --git 
a/website/backend/src/dataStores/flatFileCoordinator.ts b/website/backend/src/dataStores/flatFileCoordinator.ts index 6bdc15bba..d2438530e 100644 --- a/website/backend/src/dataStores/flatFileCoordinator.ts +++ b/website/backend/src/dataStores/flatFileCoordinator.ts @@ -4,7 +4,6 @@ import { Model, PsycheCoordinator, RunMetadata, - lr_at_step, } from 'psyche-deserialize-zerocopy-wasm' import { RunSummary, @@ -291,14 +290,6 @@ export class FlatFileCoordinatorDataStore implements CoordinatorDataStore { lastRun.lastUpdated = eventTime lastRun.lastState = newState - const step = newState.coordinator.progress.step - if (step > (lastRun.observedLrByStep.at(-1)?.[0] ?? 0)) { - const lr = lr_at_step(newState.coordinator.model.LLM.lr_schedule, step) - if (isGoodNumber(lr)) { - lastRun.observedLrByStep.push([step, lr]) - } - } - if (configChanged) { lastRun.configChanges.push({ timestamp: eventTime, @@ -673,8 +664,6 @@ export class FlatFileCoordinatorDataStore implements CoordinatorDataStore { maxRoundTrainTime: Number(config.max_round_train_time), roundWitnessTime: Number(config.round_witness_time), warmupTime: Number(config.warmup_time), - - lrSchedule: c.coordinator.model.LLM.lr_schedule, }, } } @@ -746,7 +735,6 @@ function makeRunSummary( : undefined const summary: RunSummary = { - arch: c.model.LLM.architecture, id: c.run_id, index: index, isOnlyRunAtThisIndex, diff --git a/website/frontend/src/components/RunSummary.tsx b/website/frontend/src/components/RunSummary.tsx index 57a9af6e2..d72e69a44 100644 --- a/website/frontend/src/components/RunSummary.tsx +++ b/website/frontend/src/components/RunSummary.tsx @@ -55,7 +55,6 @@ const InfoChits = styled.div` export const RunSummaryCard = memo(function RunSummaryCard({ info: { id, - arch, description, name, size, @@ -93,7 +92,6 @@ export const RunSummaryCard = memo(function RunSummaryCard({ {size !== 0n && ( {formatNumber(Number(size), 2)} )} - {arch} {type} {formatNumber(Number(totalTokens), 2)} diff --git 
a/website/frontend/src/fakeData.ts b/website/frontend/src/fakeData.ts index eb3d062b5..37a7ce0b1 100644 --- a/website/frontend/src/fakeData.ts +++ b/website/frontend/src/fakeData.ts @@ -60,7 +60,6 @@ export const fakeRunSummaries: RunSummary[] = [ status: { type: 'paused' }, totalTokens: 100000n, size: 1000000000n, - arch: 'HfLlama', type: 'vision', pauseHistory: [], lastUpdate: { @@ -77,7 +76,6 @@ export const fakeRunSummaries: RunSummary[] = [ status: { type: 'active' }, totalTokens: 200000n, size: 2000000000n, - arch: 'HfLlama', type: 'text', pauseHistory: [], lastUpdate: { @@ -100,7 +98,6 @@ export const fakeRunSummaries: RunSummary[] = [ }, // 1 day ago totalTokens: 50000n, size: 500000000n, - arch: 'HfLlama', type: 'text', pauseHistory: [], lastUpdate: { @@ -117,7 +114,6 @@ export const fakeRunSummaries: RunSummary[] = [ status: { type: 'active' }, totalTokens: 100000n, size: 1000000000n, - arch: 'HfLlama', type: 'vision', pauseHistory: [], lastUpdate: { @@ -394,15 +390,6 @@ function makeFakeRunDataSeeded(seed = 1, step = 0, index = 0): RunData { roundWitnessTime: 2_000, minClients, epochTime, - lrSchedule: { - Cosine: { - base_lr: 4.0e-4, - warmup_steps: 500, - warmup_init_lr: 0.0, - total_steps: 25000, - final_lr: 4.0e-5, - }, - }, }, }, } diff --git a/website/frontend/src/routes/runs/$run.$index.tsx b/website/frontend/src/routes/runs/$run.$index.tsx index 943f2abf4..ebb6625cc 100644 --- a/website/frontend/src/routes/runs/$run.$index.tsx +++ b/website/frontend/src/routes/runs/$run.$index.tsx @@ -178,7 +178,6 @@ function RouteComponent() { {formatNumber(Number(info.size), 2)} - {info.arch} {info.type} diff --git a/website/shared/index.ts b/website/shared/index.ts index 0b19b75d0..b415443ef 100644 --- a/website/shared/index.ts +++ b/website/shared/index.ts @@ -11,8 +11,6 @@ type PsycheSolanaMiningPool = miningPoolTypes.PsycheSolanaMiningPool import type { GcsRepo, HubRepo, - LearningRateSchedule, - LLMArchitecture, RunState, } from 
'psyche-deserialize-zerocopy-wasm' @@ -73,7 +71,6 @@ export interface RunSummary { } size: bigint - arch: LLMArchitecture type: ModelType } @@ -131,8 +128,6 @@ export interface RunData { maxRoundTrainTime: number roundWitnessTime: number - - lrSchedule: LearningRateSchedule } } recentTxs: Array diff --git a/website/wasm/Cargo.toml b/website/wasm/Cargo.toml index 54c58a6c2..c9ff68400 100644 --- a/website/wasm/Cargo.toml +++ b/website/wasm/Cargo.toml @@ -8,6 +8,7 @@ crate-type = ["cdylib"] [dependencies] psyche-solana-coordinator = { path = "../../architectures/decentralized/solana-coordinator/programs/solana-coordinator" } +psyche-coordinator.workspace = true serde.workspace = true serde-wasm-bindgen = "0.6.5" wasm-bindgen = "=0.2.108" diff --git a/website/wasm/src/lib.rs b/website/wasm/src/lib.rs index 1beee51dd..f57cf73da 100644 --- a/website/wasm/src/lib.rs +++ b/website/wasm/src/lib.rs @@ -1,3 +1,4 @@ +use psyche_coordinator::model::LLMArchitecture; use psyche_core::LearningRateSchedule; use psyche_solana_coordinator::{ClientId, CoordinatorAccount, coordinator_account_from_bytes}; use serde::ser::Serialize; @@ -38,3 +39,14 @@ pub struct DummyCoordinatorAccount(CoordinatorAccount); #[derive(TS)] #[ts(export)] pub struct DummyClientId(ClientId); + +// Export types that are now in ModelExtraData but still needed by the website +#[allow(dead_code)] +#[derive(TS)] +#[ts(export)] +pub struct DummyLLMArchitecture(LLMArchitecture); + +#[allow(dead_code)] +#[derive(TS)] +#[ts(export)] +pub struct DummyLearningRateSchedule(LearningRateSchedule);