diff --git a/common/common.cpp b/common/common.cpp
index 1dcc235eac0..99fa55fefde 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -154,18 +154,37 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            params.n_threads = std::stoi(argv[i]);
-            if (params.n_threads <= 0) {
-                params.n_threads = std::thread::hardware_concurrency();
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([,/]+)"};
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            params.n_threads.resize(split_arg.size());
+            for (size_t i = 0; i < split_arg.size(); ++i) {
+                params.n_threads[i] = std::stoi(split_arg[i]);
+                if (params.n_threads[i] <= 0) {
+                    params.n_threads[i] = std::thread::hardware_concurrency();
+                }
             }
+
         } else if (arg == "-tb" || arg == "--threads-batch") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params.n_threads_batch = std::stoi(argv[i]);
-            if (params.n_threads_batch <= 0) {
-                params.n_threads_batch = std::thread::hardware_concurrency();
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([,/]+)"};
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            params.n_threads_batch.resize(split_arg.size());
+            for (size_t i = 0; i < split_arg.size(); ++i) {
+                params.n_threads_batch[i] = std::stoi(split_arg[i]);
+                if (params.n_threads_batch[i] <= 0) {
+                    params.n_threads_batch[i] = std::thread::hardware_concurrency();
+                }
             }
         } else if (arg == "-p" || arg == "--prompt") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -429,6 +448,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.p_split = std::stof(argv[i]);
+        } else if (arg == "--p-recovery" || arg == "-pr") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.p_recovery = std::stof(argv[i]);
+        } else if (arg == "--p-decay" || arg == "-pd") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.p_decay = std::stof(argv[i]);
         } else if (arg == "-m" || arg == "--model") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
@@ -540,6 +571,30 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
 #else
             fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
 #endif
+        } else if (arg == "--mpi-layer-split") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([\/]+)"};
+            const std::regex inner_regex{R"([,]+)"};
+
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            params.mpi_layer_split.resize(split_arg.size());
+            for (size_t i = 0; i < split_arg.size(); ++i) {
+                std::sregex_token_iterator it_inner{split_arg[i].begin(), split_arg[i].end(), inner_regex, -1};
+                std::vector<std::string> split_arg_inner{it_inner, {}};
+                params.mpi_layer_split[i].resize(split_arg_inner.size());
+                for (size_t j = 0; j < split_arg_inner.size(); ++j) {
+                    params.mpi_layer_split[i][j] = std::stof(split_arg_inner[j]);
+                }
+            }
+
+
         } else if (arg == "--tensor-split" || arg == "-ts") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
@@ -742,7 +797,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("                        (can be specified more than once for multiple prompts).\n");
     printf("  --color               colorise output to distinguish prompt and user input from generations\n");
     printf("  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
-    printf("  -t N, --threads N     number of threads to use during generation (default: %d)\n", params.n_threads);
+    printf("  -t N, --threads N     number of threads to use during generation (default: %d)\n", params.n_threads[0]);
     printf("  -tb N, --threads-batch N\n");
     printf("                        number of threads to use during batch and prompt processing (default: same as --threads)\n");
     printf("  -p PROMPT, --prompt PROMPT\n");
@@ -811,6 +866,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  -ns N, --sequences N  number of sequences to decode (default: %d)\n", params.n_sequences);
     printf("  -pa N, --p-accept N   speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
     printf("  -ps N, --p-split N    speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
+    printf("  -pr N, --p-recovery N PipeInfer probability recovery (default: %.1f)\n", (double)params.p_recovery);
+    printf("  -pd N, --p-decay N    PipeInfer probability decay (default: %.1f)\n", (double)params.p_decay);
     printf("  -cb, --cont-batching  enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  --mmproj MMPROJ_FILE  path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
     printf("  --image IMAGE_FILE    path to an image file.
use with multimodal models\n"); @@ -836,6 +893,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n"); printf(" Not recommended since this is both slower and uses more VRAM.\n"); #endif // GGML_USE_CUBLAS +#endif +#ifdef GGML_USE_MPI + printf(" --mpi-layer-split N percentiles to split the layers by across nodes\n"); #endif printf(" --verbose-prompt print prompt before generation\n"); printf(" -dkvc, --dump-kv-cache\n"); @@ -859,9 +919,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { std::string get_system_info(const gpt_params & params) { std::ostringstream os; - os << "system_info: n_threads = " << params.n_threads; - if (params.n_threads_batch != -1) { - os << " (n_threads_batch = " << params.n_threads_batch << ")"; + os << "system_info: n_threads = " << params.n_threads[0]; + if (params.n_threads_batch[0] != -1) { + os << " (n_threads_batch = " << params.n_threads_batch[0] << ")"; } os << " / " << std::thread::hardware_concurrency() << " | " << llama_print_system_info(); @@ -909,8 +969,8 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.n_ctx = params.n_ctx; cparams.n_batch = params.n_batch; - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + cparams.n_threads = params.n_threads[0]; + cparams.n_threads_batch = params.n_threads_batch[0] == -1 ? params.n_threads[0] : params.n_threads_batch[0]; cparams.mul_mat_q = params.mul_mat_q; cparams.seed = params.seed; cparams.f16_kv = params.memory_f16; @@ -944,12 +1004,15 @@ void llama_batch_add( for (size_t i = 0; i < seq_ids.size(); ++i) { batch.seq_id[batch.n_tokens][i] = seq_ids[i]; } - batch.logits [batch.n_tokens] = logits; + if (batch.logits) { + batch.logits[batch.n_tokens] = logits; + } batch.n_tokens++; } std::tuple llama_init_from_gpt_params(gpt_params & params) { + int32_t n_threads = params.n_threads[0]; auto mparams = llama_model_params_from_gpt_params(params); llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams); @@ -967,6 +1030,16 @@ std::tuple llama_init_from_gpt_par return std::make_tuple(nullptr, nullptr); } +#ifdef GGML_USE_MPI + int node_id = llama_node_id(lctx); + n_threads = (node_id >= params.n_threads.size()) ? get_num_physical_cores() : params.n_threads[node_id]; + int32_t n_threads_batch = (node_id >= params.n_threads_batch.size()) ? -1 : params.n_threads_batch[node_id]; + + params.n_threads[0] = n_threads; // So we can treat index 0 as what our n_threads is elsewhere + params.n_threads_batch[0] = n_threads_batch; + llama_set_n_threads(lctx, n_threads, (n_threads_batch > 0) ? n_threads_batch : get_num_physical_cores()); +#endif + for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { const std::string& lora_adapter = std::get<0>(params.lora_adapter[i]); float lora_scale = std::get<1>(params.lora_adapter[i]); @@ -976,7 +1049,7 @@ std::tuple llama_init_from_gpt_par ((i > 0) || params.lora_base.empty()) ? NULL : params.lora_base.c_str(), - params.n_threads); + n_threads); if (err != 0) { fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); llama_free(lctx); @@ -992,9 +1065,16 @@ std::tuple llama_init_from_gpt_par { LOG("warming up the model with an empty run\n"); +#ifndef GGML_USE_MPI + // When using MPI, llama_decode() enters into an infinite loop + // on non-head nodes. 
Thus, we only want to warmup the model here + // if we aren't using MPI. + // FIXME have a way to terminate the infinite loop so we can warmup the model + // in MPI mode std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); llama_kv_cache_clear(lctx); +#endif llama_reset_timings(lctx); } @@ -1384,7 +1464,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); - fprintf(stream, "threads: %d # default: %d\n", params.n_threads, std::thread::hardware_concurrency()); + fprintf(stream, "threads: %d # default: %d\n", params.n_threads[0], std::thread::hardware_concurrency()); fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); @@ -1419,49 +1499,4 @@ void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) { printf("\n=== Done dumping\n"); } -void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { - static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; - - printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n", - view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx); - - std::unordered_map seqs; - llama_kv_cache_view_cell * c_curr = view.cells; - llama_seq_id * cs_curr = view.cells_sequences; - - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { - for (int j = 0; j < view.n_max_seq; j++) { - if (cs_curr[j] < 0) { continue; } - if (seqs.find(cs_curr[j]) == seqs.end()) { - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - seqs[cs_curr[j]] = seqs.size(); - } - } - if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } - } - - printf("=== Sequence legend: "); - for (const auto & it : seqs) { - printf("%zu=%d, ", it.second, it.first); - } - printf("'+'=other sequence ids"); - c_curr = view.cells; - cs_curr = view.cells_sequences; - for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { - if (i % row_size == 0) { - printf("\n%5d: ", i); - } - for (int j = 0; j < view.n_max_seq; j++) { - if (cs_curr[j] >= 0) { - const auto & it = seqs.find(cs_curr[j]); - putchar(it != seqs.end() ? 
int(slot_chars[it->second]) : '+'); - } else { - putchar('.'); - } - } - putchar(' '); - } - - printf("\n=== Done dumping\n"); -} diff --git a/common/common.h b/common/common.h index 2f6fe48ab53..6053bb0d2ea 100644 --- a/common/common.h +++ b/common/common.h @@ -45,8 +45,8 @@ int32_t get_num_physical_cores(); struct gpt_params { uint32_t seed = -1; // RNG seed - int32_t n_threads = get_num_physical_cores(); - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + std::vector n_threads = {get_num_physical_cores()}; + std::vector n_threads_batch = {-1}; // number of threads to use for batch processing (-1 = use n_threads) int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) @@ -57,9 +57,12 @@ struct gpt_params { int32_t n_sequences = 1; // number of sequences to decode float p_accept = 0.5f; // speculative decoding accept probability float p_split = 0.1f; // speculative decoding split probability + float p_recovery = 0.0f; // Cumulative probability that p_accept and p_split are increased by per-iteration. + float p_decay = 0.0f; // Cumulative probability that p_accept and p_split are decreased by per-iteration when drafting stops due to p_accept. int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + std::vector> mpi_layer_split = {{1.0}}; // list of percentages of the total number of layers float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_beams = 0; // if non-zero then use beam search of given width. float rope_freq_base = 0.0f; // RoPE base frequency @@ -227,5 +230,4 @@ void dump_non_result_info_yaml( // Dump the KV cache view with the number of sequences per cell. void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80); -// Dump the KV cache view showing individual sequences in each cell (long output). 
-void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40); + diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 31ec8cade19..842a5e2a465 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -188,6 +188,8 @@ int main(int argc, char ** argv) { return 1; } + llama_split_layers_weighted(ctx, params.mpi_layer_split[0].data(), params.mpi_layer_split[0].size()); + const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); LOG("n_ctx: %d\n", n_ctx); @@ -233,13 +235,24 @@ int main(int argc, char ** argv) { LOG("add_bos: %d\n", add_bos); std::vector embd_inp; - + int n_past = 0; if (params.interactive_first || params.instruct || params.chatml || !params.prompt.empty() || session_tokens.empty()) { LOG("tokenize the prompt\n"); if (params.chatml) { params.prompt = "<|im_start|>system\n" + params.prompt + "<|im_end|>"; } embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + llama_batch batch = llama_batch_init(n_ctx, 0, 1); + + for (int i = 0; i < embd_inp.size()-1; i++) { + llama_batch_add(batch, embd_inp[i], i, {0}, true); + + } + llama_decode(ctx, batch); + llama_token last = embd_inp.back(); + n_past = embd_inp.size()-2; + embd_inp.clear(); + embd_inp.push_back(last); } else { LOG("use session tokens\n"); embd_inp = session_tokens; @@ -456,7 +469,7 @@ int main(int argc, char ** argv) { bool input_echo = true; bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size(); - int n_past = 0; + int n_remain = params.n_predict; int n_consumed = 0; int n_session_consumed = 0; @@ -474,6 +487,12 @@ int main(int argc, char ** argv) { struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + long ttft = ggml_time_ms(); + std::vector inter_token_times; + int64_t itt_start; + bool first_token = false; + bool has_run_first_token = false; + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (!embd.empty()) { @@ -595,7 +614,11 @@ int main(int argc, char ** argv) { LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + llama_batch batch = llama_batch_init(n_eval, 0, 1); + for (int j = 0; j < n_eval; j++) { + llama_batch_add(batch, embd[i+j], n_past+j, {0}, true); + } + if (llama_decode(ctx, batch)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } @@ -623,6 +646,18 @@ int main(int argc, char ** argv) { LOG("saved session to %s\n", path_session.c_str()); } + if (has_run_first_token) { + if (first_token) { + ttft = ggml_time_ms() - ttft; + LOG("\nTTFT: %ld\n", ttft); + first_token = false; + } else { + inter_token_times.push_back(ggml_time_ms() - itt_start); + } + + itt_start = ggml_time_ms(); + } + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); llama_sampling_accept(ctx_sampling, ctx, id, true); @@ -638,6 +673,13 @@ int main(int argc, char ** argv) { --n_remain; LOG("n_remain: %d\n", n_remain); + + if (!has_run_first_token && (int) embd_inp.size() <= n_consumed) { + + has_run_first_token = true; + first_token = true; + ttft = ggml_time_ms(); + } } else { // some user input remains from prompt or interaction, forward it to processing LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); @@ -850,6 +892,16 @@ int main(int argc, char ** argv) { llama_print_timings(ctx); write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + uint64_t avg_itt = 0; + for (auto 
latency : inter_token_times) { + avg_itt += latency; + } + + avg_itt = avg_itt / inter_token_times.size(); + + LOG_TEE("Average inter-token latency: %ld microseconds\n", avg_itt); + LOG_TEE("Time-to-first-token: %ld microseconds\n", ttft); + if (ctx_guidance) { llama_free(ctx_guidance); } llama_free(ctx); llama_free_model(model); diff --git a/examples/mpi/CMakeLists.txt b/examples/mpi/CMakeLists.txt new file mode 100644 index 00000000000..07d83b61d99 --- /dev/null +++ b/examples/mpi/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET mpi) +add_executable(${TARGET} mpi.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_11) +if(TARGET BUILD_INFO) + add_dependencies(${TARGET} BUILD_INFO) +endif() diff --git a/examples/mpi/README.md b/examples/mpi/README.md new file mode 100644 index 00000000000..4b934b0edbc --- /dev/null +++ b/examples/mpi/README.md @@ -0,0 +1,60 @@ +# llama.cpp/example/mpi + +This example program allows you to use various LLaMA language models in an easy and efficient way across an MPI cluster. +It is specifically designed to work with the [llama.cpp](https://github.com/ggerganov/llama.cpp) project, which provides a plain C/C++ implementation with optional 4-bit quantization support for faster, lower memory inference, and is optimized for desktop CPUs. This program can be used to perform various inference tasks with LLaMA models, including generating text based on user-provided prompts and chat-like interactions with reverse prompts. + +## Table of Contents + +1. [Quick Start](#quick-start) +2. [Common Options](#common-options) + +## Quick Start + +To get started right away, write the following to a file on each node, making sure to use the correct path for the model you have: +```bash +--mpi-layer-split 0.8,0.2 -t 4 -m ~/llm-local/codellama-7b.Q3_K_M.gguf --color -c 512 --temp 0.0 --repeat_penalty 1.0 -n 128 -p "double fast_inverse_square_root(double x" +``` + +Each node may have different options, currently they must have the same number of arguments to the mpi-layer-split option and the same +model path, but that will eventually be synchronized from the head node. + +Next, write the hostsfile on the head node. Make sure there is only one slot on each node. + +Finally, run the following command on the head node to start the program across the cluster: + +#### Unix-based systems (Linux, macOS, etc.): + +```bash +mpirun -hostfile hostsfile -mca orte_keep_fqdn_hostnames t --bind-to none ./mpi options.txt +``` + +Where `hostsfile` is the file containing the cluster hostname configuration and `options.txt` is the path +where each node can find its own options. Storing the model on a network filesystem has not yet been +tested and optimized for. + +#### Windows: +Not supported currently. + +For an interactive experience, try this command: + +#### Unix-based systems (Linux, macOS, etc.): + +```bash +./main -m models/7B/ggml-model.bin -n -1 --color -r "User:" --in-prefix " " \ +'User: Hi +AI: Hello. I am an AI chatbot. Would you like to talk? +User: Sure! +AI: What would you like to talk about? +User:' +``` + +## Common Options + +In this section, we cover the most commonly used options for running the `mpi` program with the LLaMA models: + +- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). 
+- `-i, --interactive`: Run the program in interactive mode, allowing you to provide input directly and receive real-time responses. +- `-ins, --instruct`: Run the program in instruction mode, which is particularly useful when working with Alpaca models. +- `-n N, --n-predict N`: Set the number of tokens to predict when generating text. Adjusting this value can influence the length of the generated text. +- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. +- `--mpi-layer-split`: Set the percentage of layers to distribute to each node. Must have the same number of arguments as the number of nodes in the cluster. Only the layer split percentages passed to the head node are used, they are scattered to all other nodes in the cluster. diff --git a/examples/mpi/mpi.cpp b/examples/mpi/mpi.cpp new file mode 100644 index 00000000000..b4944099eaa --- /dev/null +++ b/examples/mpi/mpi.cpp @@ -0,0 +1,876 @@ +#include "common.h" + +#include "console.h" +#include "llama.h" +#include "build-info.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO add Windows support +#include + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) +#include +#include +#elif defined (_WIN32) +#define WIN32_LEAN_AND_MEAN +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#include +#endif + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +static llama_context ** g_ctx; +static llama_model ** g_model; +static gpt_params * g_params; +static std::vector * g_input_tokens; +static std::ostringstream * g_output_ss; +static std::vector * g_output_tokens; +static bool is_interacting = false; + + +static void write_logfile( + const llama_context * ctx, const gpt_params & params, const llama_model * model, + const std::vector & input_tokens, const std::string & output, + const std::vector & output_tokens +) { + if (params.logdir.empty()) { + return; + } + + const std::string timestamp = get_sortable_timestamp(); + + const bool success = create_directory_with_parents(params.logdir); + if (!success) { + fprintf(stderr, "%s: warning: failed to create logdir %s, cannot write logfile\n", + __func__, params.logdir.c_str()); + return; + } + + const std::string logfile_path = params.logdir + timestamp + ".yml"; + FILE * logfile = fopen(logfile_path.c_str(), "w"); + + if (logfile == NULL) { + fprintf(stderr, "%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); + return; + } + + fprintf(logfile, "binary: main\n"); + char model_desc[128]; + llama_model_desc(model, model_desc, sizeof(model_desc)); + dump_non_result_info_yaml(logfile, params, ctx, timestamp, input_tokens, model_desc); + + fprintf(logfile, "\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "# Generation Results #\n"); + fprintf(logfile, "######################\n"); + fprintf(logfile, "\n"); + + dump_string_yaml_multiline(logfile, "output", output.c_str()); + dump_vector_int_yaml(logfile, "output_tokens", output_tokens); + + llama_dump_timing_info_yaml(logfile, ctx); + fclose(logfile); +} + +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) +static void sigint_handler(int signo) { + if (signo == SIGINT) { + if (!is_interacting) { + is_interacting = true; + } else { + console::cleanup(); + printf("\n"); + 
llama_print_timings(*g_ctx); + write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); + _exit(130); + } + } +} +#endif + +int main(int argc, char ** argv) { + gpt_params params; + g_params = ¶ms; + + if (argc > 2) { + fprintf(stderr, "Must only have one argument, the file to read options from.\n"); + return 2; + } + + // Manually add the path used to launch this program to the + // options + std::string rawOptions = argv[0]; + rawOptions += ' '; + std::ifstream optionsFile(argv[1]); + if (optionsFile.is_open()) { + // Read in the options file, appending to the launch path + std::ostringstream buf; + buf << optionsFile.rdbuf(); + rawOptions += buf.str(); + optionsFile.close(); + + } else { + fprintf(stderr, "Cannot open options file at path %s\n", argv[1]); + return 3; + } + + // wordexp doesn't work right if there's a trailing newline, so strip it + rawOptions.erase(rawOptions.find_last_not_of(" \t\n\r\f\v") + 1); + + wordexp_t splitOptions; + wordexp(rawOptions.c_str(), &splitOptions, WRDE_NOCMD); + + // Now we can parse like normal, but using the loaded options instead of the passed argv + if (gpt_params_parse(splitOptions.we_wordc, splitOptions.we_wordv, params) == false) { + wordfree(&splitOptions); + return 1; + } + wordfree(&splitOptions); + llama_sampling_params & sparams = params.sparams; + +#ifndef LOG_DISABLE_LOGS + log_set_target(log_filename_generator("main", "log")); + LOG_TEE("Log start\n"); + log_dump_cmdline(argc, argv); +#endif // LOG_DISABLE_LOGS + + // TODO: Dump params ? + //LOG("Params perplexity: %s\n", LOG_TOSTR(params.perplexity)); + + // save choice to use color for later + // (note for later: this is a slightly awkward choice) + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); + + if (params.logits_all) { + printf("\n************\n"); + printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.embedding) { + printf("\n************\n"); + printf("%s: please use the 'embedding' tool for embedding calculations\n", __func__); + printf("************\n\n"); + + return 0; + } + + if (params.n_ctx != 0 && params.n_ctx < 8) { + LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; + } + + if (params.rope_freq_base != 0.0) { + LOG_TEE("%s: warning: changing RoPE frequency base to %g.\n", __func__, params.rope_freq_base); + } + + if (params.rope_freq_scale != 0.0) { + LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale); + } + + LOG_TEE("%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); + LOG_TEE("%s: built with %s for %s\n", __func__, BUILD_COMPILER, BUILD_TARGET); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + LOG_TEE("%s: seed = %u\n", __func__, params.seed); + + std::mt19937 rng(params.seed); + if (params.random_prompt) { + params.prompt = gpt_random_prompt(rng); + } + + LOG("%s: llama backend init\n", __func__); + llama_backend_init(params.numa); + + llama_model * model; + llama_context * ctx; + llama_context * ctx_guidance = NULL; + g_model = &model; + g_ctx = &ctx; + + // load the model and apply lora adapter, if any + LOG("%s: load the model and apply lora adapter, if any\n", __func__); + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (sparams.cfg_scale > 1.f) { + struct llama_context_params lparams = 
llama_context_params_from_gpt_params(params); + ctx_guidance = llama_new_context_with_model(model, lparams); + } + + if (model == NULL) { + LOG_TEE("%s: error: unable to load model\n", __func__); + return 1; + } + + const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx = llama_n_ctx(ctx); + LOG("n_ctx: %d\n", n_ctx); + + if (n_ctx > n_ctx_train) { + LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n", + __func__, n_ctx_train, n_ctx); + } + + // print system information + { + LOG_TEE("\n"); + LOG_TEE("%s\n", get_system_info(params).c_str()); + } + + llama_split_layers_weighted(ctx, params.mpi_layer_split.data(), params.mpi_layer_split.size()); + + std::string path_session = params.path_prompt_cache; + std::vector session_tokens; + + if (!path_session.empty()) { + LOG_TEE("%s: attempting to load saved session from '%s'\n", __func__, path_session.c_str()); + + // fopen to check for existing session + FILE * fp = std::fopen(path_session.c_str(), "rb"); + if (fp != NULL) { + std::fclose(fp); + + session_tokens.resize(n_ctx); + size_t n_token_count_out = 0; + if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) { + LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str()); + return 1; + } + session_tokens.resize(n_token_count_out); + llama_set_rng_seed(ctx, params.seed); + + LOG_TEE("%s: loaded a session with prompt size of %d tokens\n", __func__, (int) session_tokens.size()); + } else { + LOG_TEE("%s: session file does not exist, will create\n", __func__); + } + } + + const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; + LOG("add_bos: %d\n", add_bos); + + std::vector embd_inp; + + if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { + LOG("tokenize the prompt\n"); + embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + } else { + LOG("use session tokens\n"); + embd_inp = session_tokens; + } + + LOG("prompt: \"%s\"\n", log_tostr(params.prompt)); + LOG("tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + + // Should not run without any tokens + if (embd_inp.empty()) { + embd_inp.push_back(llama_token_bos(model)); + LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str()); + } + + // Tokenize negative prompt + std::vector guidance_inp; + int guidance_offset = 0; + int original_prompt_len = 0; + if (ctx_guidance) { + LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); + + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true); + LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str()); + + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); + LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str()); + + original_prompt_len = original_inp.size(); + guidance_offset = (int)guidance_inp.size() - original_prompt_len; + LOG("original_prompt_len: %s", log_tostr(original_prompt_len)); + LOG("guidance_offset: %s", log_tostr(guidance_offset)); + } + + if ((int) embd_inp.size() > n_ctx - 4) { + LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4); + return 1; + } + + // debug message about similarity of saved session, if applicable + size_t n_matching_session_tokens = 0; + if (!session_tokens.empty()) { + 
for (llama_token id : session_tokens) { + if (n_matching_session_tokens >= embd_inp.size() || id != embd_inp[n_matching_session_tokens]) { + break; + } + n_matching_session_tokens++; + } + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { + LOG_TEE("%s: using full prompt from session file\n", __func__); + } else if (n_matching_session_tokens >= embd_inp.size()) { + LOG_TEE("%s: session file has exact match for prompt!\n", __func__); + } else if (n_matching_session_tokens < (embd_inp.size() / 2)) { + LOG_TEE("%s: warning: session file has low similarity to prompt (%zu / %zu tokens); will mostly be reevaluated\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } else { + LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", + __func__, n_matching_session_tokens, embd_inp.size()); + } + + // remove any "future" tokens that we might have inherited from the previous session + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); + } + + LOGLN( + "recalculate the cached logits (check): embd_inp.empty() %s, n_matching_session_tokens %zu, embd_inp.size() %zu, session_tokens.size() %zu, embd_inp.size() %zu", + log_tostr(embd_inp.empty()), n_matching_session_tokens, embd_inp.size(), session_tokens.size(), embd_inp.size()); + + // if we will use the cache for the full prompt without reaching the end of the cache, force + // reevaluation of the last token token to recalculate the cached logits + if (!embd_inp.empty() && n_matching_session_tokens == embd_inp.size() && session_tokens.size() > embd_inp.size()) { + LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); + + session_tokens.resize(embd_inp.size() - 1); + } + + // number of tokens to keep when resetting context + if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct) { + params.n_keep = (int)embd_inp.size(); + } + + // prefix & suffix for instruct mode + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); + + LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx).c_str()); + LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx).c_str()); + + // in instruct mode, we inject a prefix and a suffix to each input by the user + if (params.instruct) { + params.interactive_first = true; + params.antiprompt.push_back("### Instruction:\n\n"); + } + + // enable interactive mode if interactive start is specified + if (params.interactive_first) { + params.interactive = true; + } + + if (params.verbose_prompt) { + LOG_TEE("\n"); + LOG_TEE("%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + LOG_TEE("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + for (int i = 0; i < (int) embd_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + + if (ctx_guidance) { + LOG_TEE("\n"); + LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str()); + LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); + for (int i = 0; i < (int) guidance_inp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str()); + } + } + + if (params.n_keep > 0) { + LOG_TEE("%s: static prompt based on n_keep: '", __func__); + for (int i = 0; i < params.n_keep; i++) { + LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); + } + 
LOG_TEE("'\n"); + } + LOG_TEE("\n"); + } + + if (params.interactive) { +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = sigint_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + LOG_TEE("%s: interactive mode on.\n", __func__); + + if (!params.antiprompt.empty()) { + for (const auto & antiprompt : params.antiprompt) { + LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, antiprompt, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + + if (params.input_prefix_bos) { + LOG_TEE("Input prefix with BOS\n"); + } + + if (!params.input_prefix.empty()) { + LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + + if (!params.input_suffix.empty()) { + LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } + } + } + LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str()); + LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); + LOG_TEE("\n\n"); + + if (params.interactive) { + const char *control_message; + if (params.multiline_input) { + control_message = " - To return control to LLaMa, end your input with '\\'.\n" + " - To return control without starting a new line, end your input with '/'.\n"; + } else { + control_message = " - Press Return to return control to LLaMa.\n" + " - To return control without starting a new line, end your input with '/'.\n" + " - If you want to submit another line, end your input with '\\'.\n"; + } + LOG_TEE("== Running in interactive mode. 
==\n"); +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) + LOG_TEE( " - Press Ctrl+C to interject at any time.\n"); +#endif + LOG_TEE( "%s\n", control_message); + + is_interacting = params.interactive_first; + } + + bool is_antiprompt = false; + bool input_echo = true; + bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < embd_inp.size(); + + int n_past = 0; + int n_remain = params.n_predict; + int n_consumed = 0; + int n_session_consumed = 0; + int n_past_guidance = 0; + + std::vector input_tokens; g_input_tokens = &input_tokens; + std::vector output_tokens; g_output_tokens = &output_tokens; + std::ostringstream output_ss; g_output_ss = &output_ss; + + // the first thing we will do is to output the prompt, so set color accordingly + console::set_display(console::prompt); + + std::vector embd; + std::vector embd_guidance; + + struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams); + + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { + // predict + if (!embd.empty()) { + // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via + // --prompt or --file which uses the same value. + int max_embd_size = n_ctx - 4; + + // Ensure the input doesn't exceed the context size by truncating embd if necessary. + if ((int) embd.size() > max_embd_size) { + const int skipped_tokens = (int) embd.size() - max_embd_size; + embd.resize(max_embd_size); + + console::set_display(console::error); + printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + console::set_display(console::reset); + fflush(stdout); + } + + // infinite text generation via context swapping + // if we run out of context: + // - take the n_keep first tokens from the original prompt (via n_past) + // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches + if (n_past + (int) embd.size() + std::max(0, guidance_offset) > n_ctx) { + if (params.n_predict == -2) { + LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict); + break; + } + + const int n_left = n_past - params.n_keep - 1; + const int n_discard = n_left/2; + + LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", + n_past, n_left, n_ctx, params.n_keep, n_discard); + + llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + + n_past -= n_discard; + + if (ctx_guidance) { + n_past_guidance -= n_discard; + } + + LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); + + LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + LOG("clear session path\n"); + path_session.clear(); + } + + // try to reuse a matching prefix from the loaded session instead of re-eval (via n_past) + if (n_session_consumed < (int) session_tokens.size()) { + size_t i = 0; + for ( ; i < embd.size(); i++) { + if (embd[i] != session_tokens[n_session_consumed]) { + session_tokens.resize(n_session_consumed); + break; + } + + n_past++; + n_session_consumed++; + + if (n_session_consumed >= (int) session_tokens.size()) { + ++i; + break; + } + } + if (i > 0) { + embd.erase(embd.begin(), embd.begin() + i); + } + } + + // evaluate tokens in batches + // embd is typically prepared beforehand to fit within a batch, but not always + if (ctx_guidance) { + int input_size = 0; + llama_token * input_buf = NULL; + + if (n_past_guidance 
< (int) guidance_inp.size()) { + // Guidance context should have the same data with these modifications: + // + // * Replace the initial prompt + // * Shift everything by guidance_offset + embd_guidance = guidance_inp; + if (embd.begin() + original_prompt_len < embd.end()) { + embd_guidance.insert( + embd_guidance.end(), + embd.begin() + original_prompt_len, + embd.end() + ); + } + + input_buf = embd_guidance.data(); + input_size = embd_guidance.size(); + + LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str()); + } else { + input_buf = embd.data(); + input_size = embd.size(); + } + + for (int i = 0; i < input_size; i += params.n_batch) { + int n_eval = std::min(input_size - i, params.n_batch); + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past_guidance += n_eval; + } + } + + for (int i = 0; i < (int) embd.size(); i += params.n_batch) { + int n_eval = (int) embd.size() - i; + if (n_eval > params.n_batch) { + n_eval = params.n_batch; + } + + LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str()); + + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0))) { + LOG_TEE("%s : failed to eval\n", __func__); + return 1; + } + + n_past += n_eval; + + LOG("n_past = %d\n", n_past); + } + + if (!embd.empty() && !path_session.empty()) { + session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); + n_session_consumed = session_tokens.size(); + } + } + + embd.clear(); + embd_guidance.clear(); + + if ((int) embd_inp.size() <= n_consumed && !is_interacting) { + // optionally save the session on first sample (for faster prompt loading next time) + if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) { + need_to_save_session = false; + llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + + LOG("saved session to %s\n", path_session.c_str()); + } + + const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance); + + llama_sampling_accept(ctx_sampling, ctx, id, true); + + LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str()); + + embd.push_back(id); + + // echo this to console + input_echo = true; + + // decrement remaining sampling budget + --n_remain; + + LOG("n_remain: %d\n", n_remain); + } else { + // some user input remains from prompt or interaction, forward it to processing + LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed); + while ((int) embd_inp.size() > n_consumed) { + embd.push_back(embd_inp[n_consumed]); + + // push the prompt in the sampling context in order to apply repetition penalties later + // for the prompt, we don't apply grammar rules + llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false); + + ++n_consumed; + if ((int) embd.size() >= params.n_batch) { + break; + } + } + } + + // display text + if (input_echo) { + for (auto id : embd) { + const std::string token_str = llama_token_to_piece(ctx, id); + printf("%s", token_str.c_str()); + + if (embd.size() > 1) { + input_tokens.push_back(id); + } else { + output_tokens.push_back(id); + output_ss << token_str; + } + } + fflush(stdout); + } + // reset color to default if there is no pending user input + if (input_echo && (int) embd_inp.size() == n_consumed) { + console::set_display(console::reset); + } + + // if not currently processing queued inputs; + if ((int) embd_inp.size() <= n_consumed) { + // check 
for reverse prompt in the last n_prev tokens + if (!params.antiprompt.empty()) { + const int n_prev = 32; + const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev); + + is_antiprompt = false; + // Check if each of the reverse prompts appears at the end of the output. + // If we're not running interactively, the reverse prompt might be tokenized with some following characters + // so we'll compensate for that by widening the search window a bit. + for (std::string & antiprompt : params.antiprompt) { + size_t extra_padding = params.interactive ? 0 : 2; + size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding) + ? last_output.length() - static_cast(antiprompt.length() + extra_padding) + : 0; + + if (last_output.find(antiprompt, search_start_pos) != std::string::npos) { + if (params.interactive) { + is_interacting = true; + } + is_antiprompt = true; + break; + } + } + + if (is_antiprompt) { + LOG("found antiprompt: %s\n", last_output.c_str()); + } + } + + // deal with end of text token in interactive mode + if (llama_sampling_last(ctx_sampling) == llama_token_eos(model)) { + LOG("found EOS token\n"); + + if (params.interactive) { + if (!params.antiprompt.empty()) { + // tokenize and inject first reverse prompt + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true); + embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); + is_antiprompt = true; + } + + is_interacting = true; + printf("\n"); + } else if (params.instruct) { + is_interacting = true; + } + } + + if (n_past > 0 && is_interacting) { + LOG("waiting for user input\n"); + + if (params.instruct) { + printf("\n> "); + } + + if (params.input_prefix_bos) { + LOG("adding input prefix BOS token\n"); + embd_inp.push_back(llama_token_bos(model)); + } + + std::string buffer; + if (!params.input_prefix.empty()) { + LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); + printf("%s", params.input_prefix.c_str()); + } + + // color user input only + console::set_display(console::user_input); + + std::string line; + bool another_line = true; + do { + another_line = console::readline(line, params.multiline_input); + buffer += line; + } while (another_line); + + // done taking input, reset color + console::set_display(console::reset); + + // Add tokens to embd only if the input buffer is non-empty + // Entering a empty line lets the user pass control back + if (buffer.length() > 1) { + // append input suffix if any + if (!params.input_suffix.empty()) { + LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); + printf("%s", params.input_suffix.c_str()); + } + + LOG("buffer: '%s'\n", buffer.c_str()); + + const size_t original_size = embd_inp.size(); + + // instruct mode: insert instruction prefix + if (params.instruct && !is_antiprompt) { + LOG("inserting instruction prefix\n"); + n_consumed = embd_inp.size(); + embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); + } + if (params.escape) { + process_escapes(buffer); + } + + const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); + const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); + LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str()); + + embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end()); + embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + 
embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); + + // instruct mode: insert response suffix + if (params.instruct) { + LOG("inserting instruction suffix\n"); + embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); + } + + for (size_t i = original_size; i < embd_inp.size(); ++i) { + const llama_token token = embd_inp[i]; + output_tokens.push_back(token); + output_ss << llama_token_to_piece(ctx, token); + } + + n_remain -= line_inp.size(); + LOG("n_remain: %d\n", n_remain); + } else { + LOG("empty line, passing control back\n"); + } + + input_echo = false; // do not echo this again + } + + if (n_past > 0) { + if (is_interacting) { + llama_sampling_reset(ctx_sampling); + } + is_interacting = false; + } + } + + // end of text token + if (!embd.empty() && embd.back() == llama_token_eos(model) && !(params.instruct || params.interactive)) { + LOG_TEE(" [end of text]\n"); + break; + } + + // In interactive mode, respect the maximum number of tokens and drop back to user input when reached. + // We skip this logic when n_predict == -1 (infinite) or -2 (stop at context size). + if (params.interactive && n_remain <= 0 && params.n_predict >= 0) { + n_remain = params.n_predict; + is_interacting = true; + } + } + + if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) { + LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str()); + llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); + } + + llama_print_timings(ctx); + write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); + + if (ctx_guidance) { llama_free(ctx_guidance); } + llama_free(ctx); + llama_free_model(model); + + llama_sampling_free(ctx_sampling); + llama_backend_free(); + +#ifndef LOG_DISABLE_LOGS + LOG_TEE("Log end\n"); +#endif // LOG_DISABLE_LOGS + + return 0; +} diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index ace755c51d8..591b3f21fdf 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,9 @@ #include #include #include +#include +#include + #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 @@ -18,14 +21,63 @@ struct seq_draft { std::vector i_batch_tgt; std::vector tokens; + std::vector prefix_tokens; struct llama_sampling_context * ctx_sampling; }; +struct seq_async_run { + struct ggml_cgraph * cgraph; + llama_batch batch; + std::vector drafts; + int n_past_tgt; + int prefix_n_past_tgt; + int n_past_dft; + int i_dft; + int s_keep; + int seq_offset; + int n_past_max; + bool speculative; + bool canceled; +}; + + +void check_for_cancel(llama_context *ctx_tgt, int n_past_tgt, std::deque &tgt_cgraphs, + std::vector &generated, int n_seq_dft); + +void begin_async_run(const llama_sampling_params& sparams, int n_seq_dft, + llama_context *ctx_tgt, int max_seq, + int n_past_dft, const std::vector &drafts, + std::deque &tgt_cgraphs, + int32_t &batch_id, int &n_past, llama_kv_cache_view &kvc_view, + bool is_spec, llama_batch batch, int n_past_max, int prefix_n_past, int seq_offset); + +bool start_async_spec_run(const gpt_params ¶ms, llama_context *ctx_tgt, llama_context *ctx_dft, + std::deque &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict, + int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling, + std::deque &tgt_cgraphs, const seq_async_run ¤t_run, + int &spec_past_tgt, int &spec_past_dft, int first_run, int 
orig_offset, int32_t &batch_id, + llama_batch &batch_dft, int &n_drafted, std::vector &drafts, llama_token &id, + llama_kv_cache_view &kvc, float p_adjust, int &n_reject); + +void begin_non_spec_run(const gpt_params ¶ms, int n_seq_dft, llama_context *ctx, int max_seq, + const std::vector &drafts, llama_token id, int32_t &batch_id, int &n_past, int n_past_dft, + std::deque &dft_cgraphs, llama_kv_cache_view &kvc_view); + +void +run_speculation_loop(const gpt_params ¶ms, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft, + const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft, + llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run, + std::deque &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted, + std::vector &drafts, std::deque &tgt_cgraphs, + seq_async_run ¤t_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected); + +float calc_p_adjust(const gpt_params ¶ms, int iter, int n_reject); + int main(int argc, char ** argv) { gpt_params params; - if (gpt_params_parse(argc, argv, params) == false) { + if (!gpt_params_parse(argc, argv, params)) { return 1; } @@ -62,11 +114,40 @@ int main(int argc, char ** argv) { params.logits_all = true; std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); + llama_split_comm(ctx_tgt, (llama_node_id(ctx_tgt) == 0 || llama_node_id(ctx_tgt) == params.mpi_layer_split[0].size()) ? 0 : -1); + llama_swap_comm(ctx_tgt); + + llama_split_comm(ctx_tgt, (llama_node_id(ctx_tgt) < params.mpi_layer_split[0].size()) ? 0 : -1); +// printf("Size of first split: %lu, element: %f\n", params.mpi_layer_split[0].size(), params.mpi_layer_split[0][0]); + // load the draft model params.model = params.model_draft; params.n_gpu_layers = params.n_gpu_layers_draft; std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params); + llama_split_comm(ctx_dft, (llama_node_id(ctx_dft) == 0 || llama_node_id(ctx_dft) == params.mpi_layer_split[0].size()) ? 0 : -1); + llama_swap_comm(ctx_dft); + + llama_split_comm(ctx_dft, (llama_node_id(ctx_dft) >= params.mpi_layer_split[0].size()) ? 
0 : -1); + +// printf("Size of second split: %lu, element: %f\n", params.mpi_layer_split[1].size(), params.mpi_layer_split[1][0]); + + + llama_split_layers_weighted(ctx_tgt, params.mpi_layer_split[0].data(), params.mpi_layer_split[0].size()); + llama_split_layers_weighted(ctx_dft, params.mpi_layer_split[1].data(), params.mpi_layer_split[1].size()); + + std::deque free_sequence_offsets; + const int n_simul_seqs = 1000; + const int max_seq = n_simul_seqs * n_seq_dft + 1; + for (int i = 0; i < n_simul_seqs; i++) { + free_sequence_offsets.push_back(i*n_seq_dft + 1); + } + + { + LOG_TEE("\n"); + LOG_TEE("%s\n", get_system_info(params).c_str()); + } + { const int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_dft = llama_n_vocab(model_dft); @@ -81,17 +162,17 @@ int main(int argc, char ** argv) { return 1; } - for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(model_tgt, i); - const char * token_text_dft = llama_token_get_text(model_dft, i); - if (std::strcmp(token_text_tgt, token_text_dft) != 0) { - fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); - fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i, - llama_token_to_piece(ctx_tgt, i).c_str(), - llama_token_to_piece(ctx_dft, i).c_str()); - return 1; - } - } +// for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { +// const char * token_text_tgt = llama_token_get_text(model_tgt, i); +// const char * token_text_dft = llama_token_get_text(model_dft, i); +// if (std::strcmp(token_text_tgt, token_text_dft) != 0) { +// fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__); +// fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i, +// llama_token_to_piece(ctx_tgt, i).c_str(), +// llama_token_to_piece(ctx_dft, i).c_str()); +// return 1; +// } +// } } @@ -121,20 +202,47 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n\n"); - for (auto id : inp) { - fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str()); + if (llama_node_id(ctx_tgt) == 0) { + for (auto id : inp) { + fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str()); + } } + + fflush(stderr); const int n_input = inp.size(); const auto t_enc_start = ggml_time_us(); + int32_t batch_id = 0; + + llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, max_seq+1); + llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, max_seq+1); + llama_batch batch_tgt_async = llama_batch_init(params.n_ctx, 0, max_seq+1); + + batch_dft.batch_id = batch_id; + batch_tgt.batch_id = batch_id; + batch_tgt_async.batch_id = batch_id; + + std::vector seq_ids; + for (int i = 0; i <= max_seq; i++) { + seq_ids.emplace_back(i); + } + + for (int i = 0; i < n_input-1; i++) { + llama_batch_add(batch_dft, inp[i], i, seq_ids, true); + llama_batch_add(batch_tgt, inp[i], i, seq_ids, true); + } + llama_decode(ctx_tgt, batch_tgt); + llama_batch_clear(batch_tgt); + llama_batch_add(batch_dft, inp.back(), n_input-1, seq_ids, true); + llama_batch_add(batch_tgt, inp.back(), n_input-1, seq_ids, true); + // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0)); + llama_decode(ctx_tgt, batch_tgt); + 
llama_decode(ctx_dft, batch_dft); const auto t_enc_end = ggml_time_us(); @@ -148,8 +256,9 @@ int main(int argc, char ** argv) { int n_drafted = 0; int n_accept = 0; - int n_past_tgt = inp.size(); - int n_past_dft = inp.size(); + + int n_past_tgt = n_input; + int n_past_dft = n_input; // used to determine end of generation bool has_eos = false; @@ -167,8 +276,11 @@ int main(int argc, char ** argv) { drafts[s].ctx_sampling = llama_sampling_init(params.sparams); } - llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft); + + + + std::deque dft_cgraphs; + std::deque tgt_cgraphs; const auto t_dec_start = ggml_time_us(); @@ -176,46 +288,236 @@ int main(int argc, char ** argv) { drafts[0].i_batch_tgt.resize(1); drafts[0].i_batch_tgt[0] = 0; + seq_async_run current_run; + + current_run.n_past_tgt = n_past_tgt - 1; + current_run.n_past_max = n_past_tgt; + current_run.n_past_dft = n_past_dft - 1; + current_run.seq_offset = free_sequence_offsets.front(); + struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx_tgt, max_seq+1); + struct llama_kv_cache_view kvc_view_dft = llama_kv_cache_view_init(ctx_dft, max_seq+1); + std::vector generated = inp; + + int spec_past_tgt = n_past_tgt; + int spec_past_dft = n_past_dft; + + long ttft = ggml_time_us(); + std::vector inter_token_times; + int64_t itt_start; + bool first_token = false; + bool has_run_first_token = false; + + bool first_run = true; + llama_token id; + + int n_rejected = 0; + while (true) { + + + int i_dft = 0; + int s_keep = 0; + + + bool is_waiting = llama_mpi_iprobe(ctx_tgt); + llama_swap_comm(ctx_tgt); + llama_sync_token(ctx_tgt, reinterpret_cast(&is_waiting), 0); + llama_swap_comm(ctx_tgt); + + if (!tgt_cgraphs.empty() && is_waiting) { + check_for_cancel(ctx_tgt, n_past_tgt, tgt_cgraphs, generated, n_seq_dft); + + struct seq_async_run run = tgt_cgraphs.back(); + LOG("Finishing async decode, is spec = %d, old seq_offset = %d, new seq offset = %d, batch id = %d\n", run.speculative, current_run.seq_offset, run.seq_offset, run.batch.batch_id); + struct ggml_cgraph * cgraph = run.cgraph; + + + + LOG("Checking run, last generated: %d, first draft: %d\n", generated.back(), run.drafts[run.s_keep].tokens[0]); +// if(run.n_past_max >= n_past_tgt && (!run_speculative || (n_past_tgt-current_run.n_past_tgt >= 0 && generated.at(generated.size() - (n_past_tgt-current_run.n_past_tgt+1)) == drafts[s_keep].tokens[0]))) { + + if(!run.canceled) { + + drafts = run.drafts; + current_run.speculative = run.speculative; + current_run.n_past_max = run.n_past_max; + current_run.n_past_tgt = run.n_past_tgt; + current_run.n_past_dft = run.n_past_dft; + current_run.seq_offset = run.seq_offset; + s_keep = run.s_keep; + + //drafts[0].tokens.erase(drafts[0].tokens.begin()); + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { + continue; + } + + drafts[s].tokens.erase(drafts[s].tokens.begin()); + + } + + } else { +// if (llama_node_id(ctx_tgt) == 0) { +// printf("\nFinishing canceled async run, spec: %d, batch id: %d, batch: %s\n", run.speculative, run.batch.batch_id, LOG_BATCH_TOSTR_PRETTY(ctx_tgt, run.batch).c_str()); +// } +// FIXME Main bottleneck because when finishing a canceled run, we're forced to wait until a correct run +// is finished instead of jumping back to speculation + llama_finish_async_decode(*ctx_tgt, run.batch, cgraph); + tgt_cgraphs.pop_back(); + if (run.speculative) { +// if(llama_node_id(ctx_tgt) == 0) { +// fprintf(stderr, "\nRun was canceled, 
pushing seq offset %d to free seq offsets\n", +// run.seq_offset); +// fflush(stderr); +// } + free_sequence_offsets.push_back(run.seq_offset); +// if(llama_node_id(ctx_tgt) == 0) { +// +// fprintf(stderr, "\nDone pushing seq offset %d to free seq offsets\n", run.seq_offset); +// fflush(stderr); +// } + } +// fprintf(stderr, "Incorrect starting token\n"); + continue; + } + + +// if (llama_node_id(ctx_tgt) == 0) { +// printf("\nFinishing async run, spec: %d, batch id: %d, batch: %s\n", run.speculative, run.batch.batch_id, LOG_BATCH_TOSTR_PRETTY(ctx_tgt, run.batch).c_str()); +// } + llama_finish_async_decode(*ctx_tgt, run.batch, cgraph); + tgt_cgraphs.pop_back(); + + spec_past_tgt = n_past_tgt; + spec_past_dft = n_past_dft; + + first_run = true; + + } else if (!tgt_cgraphs.empty()) { + run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, max_seq, batch_tgt, n_predict, n_past_tgt, + n_past_dft, ctx_sampling, + spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft, + n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected); + continue; + } + + + if (llama_node_id(ctx_tgt) == 0) { +// llama_kv_cache_view_update(ctx_tgt, &kvc_view); +// LOG("Beginning sampling, tgt cache layout:\n%s", dump_kv_cache_view_seqs(kvc_view, 1).c_str()); + LOG("n_past_tgt: %d, current_run.n_past_tgt: %d, current_run.n_past_max: %d\n", n_past_tgt, current_run.n_past_tgt, current_run.n_past_max); + } else { +// llama_kv_cache_view_update(ctx_dft, &kvc_view_dft); +// LOG("Beginning sampling, dft cache layout:\n%s", dump_kv_cache_view_seqs(kvc_view_dft, 1).c_str()); + LOG("n_past_dft: %d, current_run.n_past_dft: %d, current_run.n_past_max: %d\n", n_past_dft, current_run.n_past_dft, current_run.n_past_max); + } // print current draft sequences + bool any_active = false; for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; } + any_active = true; const auto & tokens = drafts[s].tokens; LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); } + LOG("Any active drafts: %d\n", any_active); + + + bool any_match = false; + + std::string token_str; - int i_dft = 0; - int s_keep = 0; - while (true) { - LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); + int old_n_past_tgt = n_past_tgt; + int old_n_past_dft = n_past_dft; + + + std::deque keeps(seq_ids.begin(), seq_ids.end()); + keeps.erase(std::find(keeps.begin(), keeps.end(),s_keep)); + keeps.push_front(s_keep); + while (!keeps.empty()) { + + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d, current_run.n_past_tgt = %3d, n_past_tgt = %3d, seq_offset = %d, keeps[0] = %d\n", s_keep, i_dft, drafts[keeps[0]].i_batch_tgt[i_dft], current_run.n_past_tgt, n_past_tgt, current_run.seq_offset, keeps[0]); + // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); + id = llama_sampling_sample(ctx_sampling, ctx_tgt, nullptr, drafts[keeps[0]].i_batch_tgt[i_dft]); + token_str = llama_token_to_piece(ctx_tgt, id); + // Swap to pipeline roots + llama_swap_comm(ctx_tgt); + LOG("Swapped comm to pipeline roots, id %d\n", llama_node_id(ctx_tgt)); - llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); + llama_sync_token(ctx_tgt, &id, 0); - //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str()); + LOG("Sampling index: %d\n", drafts[keeps[0]].i_batch_tgt[i_dft]); + + + LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, 
ctx_sampling->prev).c_str()); + + + LOG("Sampled token: %d ('%s'), n_past_tgt: %d, current_run.n_past_tgt + i_dft: %d, drafts[keeps[0]].i_batch_tgt[i_dft]: %d\n", id, token_str.c_str(), n_past_tgt, current_run.n_past_tgt + i_dft, drafts[keeps[0]].i_batch_tgt[i_dft]); + + + if (current_run.n_past_tgt + i_dft == n_past_tgt-1) { + any_match = true; + ++n_predict; + if (current_run.speculative) { + n_accept++; + } + + if (has_run_first_token) { + if (first_token) { + ttft = ggml_time_us() - ttft; + LOG("\nTTFT: %ld\n", ttft); + first_token = false; + } else { + inter_token_times.push_back(ggml_time_us() - itt_start); + } + + itt_start = ggml_time_us(); + } + llama_sampling_accept(ctx_sampling, ctx_tgt, id, true); + + // Root of WORLD + LOG("Accepting token %d ('%s'), n_past_tgt: %d\n", id, token_str.c_str(), n_past_tgt); + generated.push_back(id); + if (llama_node_id(ctx_tgt) == 0) { + if (!params.use_color) { + printf("%s", token_str.c_str()); +// fprintf(stderr, "%s", token_str.c_str()); + fflush(stdout); +// fflush(stderr); + } + + } + } + + // Switch back to target pipeline only + llama_swap_comm(ctx_tgt); + LOG("Swapped comm to target only, id %d\n", llama_node_id(ctx_tgt)); - const std::string token_str = llama_token_to_piece(ctx_tgt, id); - printf("%s", token_str.c_str()); - fflush(stdout); if (id == llama_token_eos(model_tgt)) { has_eos = true; } - ++n_predict; + + + + + if (!current_run.speculative) { + break; + } + // check if the target token matches any of the drafts { bool matches = false; - + keeps.clear(); for (int s = 0; s < n_seq_dft; ++s) { if (!drafts[s].active) { continue; @@ -223,219 +525,175 @@ int main(int argc, char ** argv) { if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); - - s_keep = s; matches = true; + keeps.push_back(s); + s_keep = keeps[0]; } else { drafts[s].active = false; } } if (matches) { - ++n_accept; - ++n_past_tgt; - ++n_past_dft; + if (current_run.n_past_tgt + i_dft == n_past_tgt-1) { + ++n_accept; + ++n_past_tgt; + ++n_past_dft; + } ++i_dft; - - continue; + if (params.use_color) { + // Color token according to its origin sequence + printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str()); + fflush(stdout); + } + if (current_run.speculative && current_run.n_past_tgt + i_dft < n_past_tgt) { + continue; + } } } - LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); - - // TODO: simplify - { - LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); - - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + if (params.use_color) { + printf("%s", token_str.c_str()); } + fflush(stdout); - for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].active = false; - drafts[s].tokens.clear(); - drafts[s].i_batch_tgt.clear(); - } - // note: will be erased after the speculation phase - drafts[0].tokens.push_back(id); - drafts[0].i_batch_tgt.push_back(0); + } - llama_batch_clear(batch_dft); - llama_batch_add (batch_dft, id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); - // LOG("dft batch: 
%s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode (ctx_dft, batch_dft); - ++n_past_dft; + if (llama_node_id(ctx_tgt) < 0) { + LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); - break; } - if (n_predict > params.n_predict || has_eos) { - break; + if (!any_match) { + if (current_run.speculative) { +// fprintf(stderr, "\nNo match, pushing seq offset %d to free seq offsets\n", current_run.seq_offset); +// fflush(stderr); + free_sequence_offsets.push_back(current_run.seq_offset); + } +// fprintf(stderr, "No match\n"); + continue; } - llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + n_rejected = 0; - int n_seq_cur = 1; - int n_past_cur = n_past_dft; + check_for_cancel(ctx_tgt, n_past_tgt, tgt_cgraphs, generated, n_seq_dft); - for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].active = false; - drafts[s].drafting = false; - } - drafts[0].active = true; - drafts[0].drafting = true; - drafts[0].i_batch_dft = 0; - llama_batch_clear(batch_tgt); - llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + // TODO: simplify + if (current_run.speculative){ + LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d, current_run.n_past_tgt = %d, current_run.n_past_dft = %d\n", s_keep+current_run.seq_offset, n_past_tgt, n_past_dft, current_run.n_past_tgt, current_run.n_past_dft); - // sample n_draft tokens from the draft model using tree-based sampling - for (int i = 0; i < n_draft; ++i) { - batch_dft.n_tokens = 0; + for (int i = 0; i < n_seq_dft; i++) { - for (int s = 0; s < n_seq_dft; ++s) { - drafts[s].skip = false; - } + llama_kv_cache_seq_rm (ctx_tgt, i+current_run.seq_offset, n_past_tgt, -1); + llama_kv_cache_seq_rm (ctx_dft, i+current_run.seq_offset, n_past_dft, -1); - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].drafting || drafts[s].skip) { - continue; - } - llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft); + } - const auto & cur_p = drafts[s].ctx_sampling->cur; + llama_kv_cache_seq_rm (ctx_tgt, 0, current_run.n_past_tgt+1, n_past_tgt); + llama_kv_cache_seq_rm (ctx_dft, 0, current_run.n_past_dft+1, n_past_dft); - for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) { - LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); - } + llama_kv_cache_seq_cp (ctx_tgt, s_keep+current_run.seq_offset, 0, current_run.n_past_tgt+1, n_past_tgt); + llama_kv_cache_seq_cp (ctx_dft, s_keep+current_run.seq_offset, 0, current_run.n_past_dft+1, n_past_dft); - if (cur_p[0].p < p_accept) { - LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, p_accept); - drafts[s].drafting = false; - continue; - } - std::vector sa(1, s); + for (int i = 1; i <= max_seq; i++) { - // attempt to split the branch if the probability is high enough - for (int f = 1; f < 8; ++f) { - if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { - LOG("splitting seq %3d into %3d\n", s, n_seq_cur); + llama_kv_cache_seq_rm(ctx_tgt, i, current_run.n_past_tgt+1, n_past_tgt); + llama_kv_cache_seq_rm(ctx_dft, i, current_run.n_past_dft+1, n_past_dft); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_kv_cache_seq_cp(ctx_tgt, 0, i, current_run.n_past_tgt+1, n_past_tgt); + llama_kv_cache_seq_cp(ctx_dft, 0, i, current_run.n_past_dft+1, n_past_dft); - // all previous tokens from 
this branch are now also part of the new branch - for (int t = 0; t < batch_tgt.n_tokens; ++t) { - for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { - if (batch_tgt.seq_id[t][p] == s) { - batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; - batch_tgt.n_seq_id[t]++; - break; - } - } - } - // copy the draft state - drafts[n_seq_cur].active = true; - drafts[n_seq_cur].drafting = true; - drafts[n_seq_cur].skip = true; + } - drafts[n_seq_cur].tokens = drafts[s].tokens; - drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft; - drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; - llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); - sa.push_back(n_seq_cur); - n_seq_cur++; - } else { - break; - } - } + if (llama_node_id(ctx_tgt) == 0) { +// llama_kv_cache_view_update(ctx_tgt, &kvc_view); +// LOG("Done keeping sequence, new tgt cache layout:\n%s", dump_kv_cache_view_seqs(kvc_view, 1).c_str()); + } else { +// llama_kv_cache_view_update(ctx_dft, &kvc_view_dft); +// LOG("Done keeping sequence, new dft cache layout:\n%s", dump_kv_cache_view_seqs(kvc_view_dft, 1).c_str()); + } - // add drafted token for each sequence - for (int is = 0; is < (int) sa.size(); ++is) { - const llama_token id = cur_p[is].id; - const int s = sa[is]; - llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); - drafts[s].tokens.push_back(id); + } - // add unique drafted tokens to the target batch - drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + begin_non_spec_run(params, n_seq_dft, ctx_tgt, max_seq, drafts, id, batch_id, n_past_tgt, n_past_dft, tgt_cgraphs, + kvc_view); - llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + begin_non_spec_run(params, n_seq_dft, ctx_dft, max_seq, drafts, id, batch_id, n_past_dft, n_past_dft, dft_cgraphs, + kvc_view_dft); - // add the token to the batch for batched decoding with the draft model - drafts[s].i_batch_dft = batch_dft.n_tokens; + if (!has_run_first_token) { - llama_batch_add(batch_dft, id, n_past_cur, { s }, true); + has_run_first_token = true; + first_token = true; + } - if (batch_tgt.n_tokens > n_draft) { - drafts[s].drafting = false; - } - } - } + seq_async_run dft_run = dft_cgraphs.back(); + dft_cgraphs.pop_back(); + llama_finish_async_decode(*ctx_dft, dft_run.batch, dft_run.cgraph); - // no sequence is drafting anymore - if (batch_dft.n_tokens == 0) { - break; - } - // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch_dft); - ++n_past_cur; - ++n_drafted; + spec_past_tgt = n_past_tgt; + spec_past_dft = n_past_dft; - if (batch_tgt.n_tokens > n_draft) { - break; + + if (!current_run.speculative) { + if (free_sequence_offsets.empty()) { + continue; } + current_run.seq_offset = free_sequence_offsets.front(); +// if (llama_node_id(ctx_tgt) == 0) { +// fprintf(stderr, "Popping %d from seq offsets for spec run\n", current_run.seq_offset); +// fflush(stderr); +// } + free_sequence_offsets.pop_front(); } - // evaluate the target model on the drafted tokens - { - llama_kv_cache_seq_keep(ctx_tgt, 0); - for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); - } - // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); - ++n_past_tgt; - } +// bool is_waiting = false; - // the first token is always proposed by the traget model before the speculation loop so we erase it here - for (int s = 0; s < n_seq_dft; ++s) { - if (!drafts[s].active) { - continue; - } + run_speculation_loop(params, p_accept, ctx_tgt, ctx_dft, 
max_seq, batch_tgt, n_predict, n_past_tgt, n_past_dft, + ctx_sampling, + spec_past_tgt, spec_past_dft, first_run, free_sequence_offsets, batch_id, batch_dft, + n_drafted, drafts, tgt_cgraphs, current_run, kvc_view_dft, id, n_rejected); - drafts[s].tokens.erase(drafts[s].tokens.begin()); + + if (n_predict > params.n_predict || has_eos) { + break; } + + + + } auto t_dec_end = ggml_time_us(); + uint64_t avg_itt = 0; + for (auto latency : inter_token_times) { + avg_itt += latency; + } + + avg_itt = avg_itt / inter_token_times.size(); + LOG_TEE("\n\n"); LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f)); LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); - + LOG_TEE("Average inter-token latency: %f seconds\n", avg_itt / 1e6f); + LOG_TEE("Time-to-first-token: %f seconds\n", ttft / 1e6f); + + LOG_TEE("\n"); LOG_TEE("n_draft = %d\n", n_draft); LOG_TEE("n_predict = %d\n", n_predict); @@ -454,7 +712,17 @@ int main(int argc, char ** argv) { llama_sampling_free(drafts[s].ctx_sampling); } - llama_batch_free(batch_dft); + if (llama_node_id(ctx_tgt) == 0) { + for (size_t i = tgt_cgraphs.size() - 1; i >= 0; i--) { + const auto &run = tgt_cgraphs[i]; + llama_finish_async_decode(*ctx_tgt, run.batch, run.cgraph); + } + } + + for (size_t i = dft_cgraphs.size()-1; i >= 0; i--) { + const auto& run = dft_cgraphs[i]; + llama_finish_async_decode(*ctx_dft, run.batch, run.cgraph); + } llama_free(ctx_tgt); llama_free_model(model_tgt); @@ -468,3 +736,599 @@ int main(int argc, char ** argv) { return 0; } + +void +run_speculation_loop(const gpt_params ¶ms, const float p_accept, llama_context *ctx_tgt, llama_context *ctx_dft, + const int max_seq, llama_batch &batch_tgt, int n_predict, int n_past_tgt, int n_past_dft, + llama_sampling_context *ctx_sampling, int &spec_past_tgt, int &spec_past_dft, bool &first_run, + std::deque &free_sequence_offsets, int32_t &batch_id, llama_batch &batch_dft, int &n_drafted, + std::vector &drafts, std::deque &tgt_cgraphs, + seq_async_run ¤t_run, llama_kv_cache_view &kvc_view_dft, llama_token &id, int &n_rejected) { + bool is_waiting = llama_mpi_iprobe(ctx_tgt); + llama_swap_comm(ctx_tgt); + llama_sync_token(ctx_tgt, reinterpret_cast(&is_waiting), 0); + llama_swap_comm(ctx_tgt); + + + if (is_waiting) { +// fprintf(stderr, "\nIs waiting, pushing seq offset %d to free seq offsets\n", current_run.seq_offset); +// fflush(stderr); + free_sequence_offsets.push_back(current_run.seq_offset); + } + int iter = 0; + float p_adjust = calc_p_adjust(params, iter, n_rejected); + while((!is_waiting && p_accept + (p_adjust = calc_p_adjust(params, iter, n_rejected)) < 1.0)) { + + + + + + int orig_offset = current_run.seq_offset; + bool should_run_spec = true; + std::deque checked_offsets; + do { + should_run_spec = true; + for (const auto &r: tgt_cgraphs) { + if (r.seq_offset == current_run.seq_offset && r.speculative) { + checked_offsets.push_back(current_run.seq_offset); + + should_run_spec = false; + if (!free_sequence_offsets.empty()) { + current_run.seq_offset = free_sequence_offsets.front(); + free_sequence_offsets.pop_front(); + + } + break; + } + } + } while (!should_run_spec && !free_sequence_offsets.empty()); + + if (!should_run_spec) { + LOG("Ending spec because no available offsets\n"); + break; + } +// if (llama_node_id(ctx_tgt) == 0) { +// fprintf(stderr, "\nErasing seq offset %d 
from free seq offsets\n", current_run.seq_offset); +// fflush(stderr); +// } + auto it = std::find(free_sequence_offsets.begin(), free_sequence_offsets.end(), current_run.seq_offset); + if (it != free_sequence_offsets.end()) { + free_sequence_offsets.erase(it); + } + + + if (start_async_spec_run(params, ctx_tgt, ctx_dft, free_sequence_offsets, max_seq, + batch_tgt, n_predict, n_past_tgt, n_past_dft, ctx_sampling, + tgt_cgraphs, + current_run, spec_past_tgt, spec_past_dft, first_run, orig_offset, + batch_id, batch_dft, n_drafted, drafts, id, kvc_view_dft, p_adjust, n_rejected)) { + LOG("Ending spec run because returned true\n"); + break; + } + + is_waiting = llama_mpi_iprobe(ctx_tgt); + llama_swap_comm(ctx_tgt); + llama_sync_token(ctx_tgt, reinterpret_cast(&is_waiting), 0); + llama_swap_comm(ctx_tgt); + first_run = false; + + iter++; +// break; + + } +} + +float calc_p_adjust(const gpt_params ¶ms, int iter, int n_reject) { + return iter * params.p_recovery - std::max(n_reject * params.p_decay, 0.0f); +} + +void begin_non_spec_run(const gpt_params ¶ms, const int n_seq_dft, llama_context *ctx, const int max_seq, + const std::vector &drafts, llama_token id, int32_t &batch_id, int &n_past, + int n_past_dft, + std::deque &dft_cgraphs, llama_kv_cache_view &kvc_view) { + + std::vector non_spec_drafts = std::vector(n_seq_dft); + for (int s = 0; s < n_seq_dft; ++s) { + non_spec_drafts[s].ctx_sampling = llama_sampling_init(params.sparams); + llama_sampling_cp(drafts[s].ctx_sampling, drafts[s].ctx_sampling); + non_spec_drafts[s].i_batch_tgt = std::vector(1,0); + non_spec_drafts[s].i_batch_dft = drafts[s].i_batch_dft; + non_spec_drafts[s].tokens = std::vector(1, id); + non_spec_drafts[s].active = drafts[s].active; + non_spec_drafts[s].drafting = drafts[s].drafting; + non_spec_drafts[s].skip = drafts[s].skip; + non_spec_drafts[s].prefix_tokens = std::vector(0); + } + + llama_batch async_batch = llama_batch_init(params.n_ctx, 0, max_seq + 1); + + llama_batch_clear(async_batch); + + llama_batch_add(async_batch, id, n_past, {0}, true); + + begin_async_run(params.sparams, n_seq_dft, ctx, max_seq, n_past_dft, + non_spec_drafts, dft_cgraphs, batch_id, n_past, kvc_view, false, async_batch, n_past+1, n_past, 0); + + n_past++; + +} + +bool start_async_spec_run(const gpt_params ¶ms, llama_context *ctx_tgt, llama_context *ctx_dft, + std::deque &free_sequence_offsets, int max_seq, llama_batch &batch_tgt, int n_predict, + int prefix_n_past, int n_past_dft, llama_sampling_context *ctx_sampling, + std::deque &tgt_cgraphs, const seq_async_run ¤t_run, + int &spec_past_tgt, int &spec_past_dft, int first_run, int orig_offset, int32_t &batch_id, + llama_batch &batch_dft, int &n_drafted, std::vector &drafts, llama_token &id, + llama_kv_cache_view &kvc, float p_adjust, int &n_reject) { + LOG("Doing speculative run, seq_offset = %d, spec_past_tgt = %d, spec_past_dft = %d, prefix_n_past = %d, n_past_dft = %d\n", + current_run.seq_offset, spec_past_tgt, spec_past_dft, prefix_n_past, n_past_dft); + + for (int i = 0; i < params.n_parallel; i++) { + + llama_kv_cache_seq_rm(ctx_tgt, i + current_run.seq_offset, (first_run) ? prefix_n_past : prefix_n_past - 1, -1); + llama_kv_cache_seq_rm(ctx_dft, i + current_run.seq_offset, (first_run) ? n_past_dft : n_past_dft - 1, -1); + + LOG("Copying tgt sequence %d to %d from positions %d to %d\n", (first_run) ? 0 : orig_offset, + i + current_run.seq_offset, prefix_n_past, spec_past_tgt); + + llama_kv_cache_seq_cp(ctx_tgt, (first_run) ? 0 : orig_offset, i + current_run.seq_offset, (first_run) ? 
prefix_n_past : prefix_n_past - 1, + spec_past_tgt+1); + llama_kv_cache_seq_cp(ctx_dft, (first_run) ? 0 : orig_offset, i + current_run.seq_offset, (first_run) ? n_past_dft : n_past_dft - 1, + spec_past_dft+1); + + + } + + + llama_batch_clear(batch_tgt); + + for (int s = 0; s < params.n_parallel; ++s) { + drafts[s].active = false; + if (!first_run) { + if (!drafts[s].tokens.empty()) { + drafts[s].prefix_tokens.insert(drafts[s].prefix_tokens.end(), drafts[s].tokens.begin(), + drafts[s].tokens.end()); + } + } else { + drafts[s].prefix_tokens.clear(); + } + drafts[s].tokens.clear(); + drafts[s].i_batch_tgt.clear(); + } + // note: will be erased after the speculation phase + drafts[0].tokens.push_back(id); + + + llama_batch_clear(batch_dft); + + + if (llama_node_id(ctx_dft) == 0) { +// llama_kv_cache_view_update(ctx_dft, &kvc); +// LOG("Draft KV cache view:\n%s\n", dump_kv_cache_view_seqs(kvc, 1).c_str()); + } + + + llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling); + + int n_seq_cur = 0; + int max_ran_seq = 0; + int n_past_cur = spec_past_dft; + + for (int s = 0; s < params.n_parallel; ++s) { + drafts[s].skip = true; + drafts[s].active = false; + drafts[s].drafting = false; + } + + + drafts[0].active = true; + drafts[0].drafting = true; + drafts[0].skip = false; + + drafts[0].i_batch_dft = 0; + + + // sample n_draft tokens from the draft model using tree-based sampling + for (int i = 0; i < params.n_draft; ++i) { + batch_dft.n_tokens = 0; + + + + for (int s = 0; s <= max_ran_seq; ++s) { + if (!drafts[s].drafting || drafts[s].skip) { + continue; + } + + + + // Swap back to pipeline roots + llama_swap_comm(ctx_dft); + LOG("Swapped comm to pipeline roots, id %d\n", llama_node_id(ctx_dft)); + + llama_sync_token(ctx_dft, &(drafts[s].i_batch_dft), 1); + + llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, nullptr, drafts[s].i_batch_dft); + + auto &cur_p = drafts[s].ctx_sampling->cur; + + llama_sync_token_data(ctx_dft, cur_p.data(), 1); + // TODO investigate potential bottleneck + for (int k = 1; k < 8; ++k) { + llama_sync_token_data(ctx_dft, &(cur_p[k]), 1); + } + + // Back to draft pipeline only + llama_swap_comm(ctx_dft); + LOG("Swapped comm to draft only, id %d\n", llama_node_id(ctx_dft)); + + + if (llama_node_id(ctx_dft) >= 0) { + for (int k = 0; k < std::min(params.n_parallel, (int) cur_p.size()); ++k) { + LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, s+current_run.seq_offset, i+spec_past_dft, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); + } + } + + + if (cur_p[0].p < params.p_accept + p_adjust) { + LOG("stopping drafting for seq %3d, probability too low: %.3f < %.3f\n", s, cur_p[0].p, + params.p_accept); + drafts[s].drafting = false; + continue; + } + + + std::vector sa(1, s); + + // attempt to split the branch if the probability is high enough + for (int f = 1; f < 8; ++f) { + if (n_seq_cur < params.n_parallel - 1 && cur_p[f].p > params.p_split + p_adjust) { + n_seq_cur++; + LOG("splitting seq %3d into %3d\n", s, n_seq_cur); + + + LOG("Removing dft sequence %d from positions %d to %d\n", n_seq_cur + current_run.seq_offset, n_past_dft, n_past_cur); + + llama_kv_cache_seq_rm(ctx_dft, n_seq_cur + current_run.seq_offset, n_past_dft, n_past_cur); + + LOG("Copying dft sequence %d to %d from positions %d to %d\n", s + current_run.seq_offset, n_seq_cur + current_run.seq_offset, n_past_dft, n_past_cur); + + llama_kv_cache_seq_cp(ctx_dft, s + current_run.seq_offset, n_seq_cur + current_run.seq_offset, n_past_dft, n_past_cur); + + 
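+                    // The drafting checks above use a threshold shifted by
+                    // p_adjust = iter * p_recovery - max(n_reject * p_decay, 0):
+                    // a sequence stops drafting once cur_p[0].p < p_accept + p_adjust,
+                    // and only splits when a runner-up exceeds p_split + p_adjust.
+                    // Worked example with illustrative values p_accept = 0.5,
+                    // p_recovery = 0.2, p_decay = 0.1, iter = 2, n_reject = 1:
+                    //   p_adjust       = 2*0.2 - max(1*0.1, 0) = 0.3
+                    //   stop threshold = 0.5 + 0.3 = 0.8
+                    // so deeper speculative iterations demand more confident drafts,
+                    // and run_speculation_loop exits entirely once
+                    // p_accept + p_adjust reaches 1.0.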
// all previous tokens from this branch are now also part of the new branch + for (int t = 0; t < batch_tgt.n_tokens; ++t) { + for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { + if (batch_tgt.seq_id[t][p] == s + current_run.seq_offset) { + batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur + current_run.seq_offset; + batch_tgt.n_seq_id[t]++; + break; + } + } + } + + + // copy the draft state + drafts[n_seq_cur].active = true; + drafts[n_seq_cur].drafting = true; + drafts[n_seq_cur].skip = false; + + drafts[n_seq_cur].tokens = drafts[s].tokens; + drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt; + + llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling); + + sa.push_back(n_seq_cur); + + + } else { + break; + } + } + + // add drafted token for each sequence + // TODO commenting this out fixes async + for (int is = 0; is < (int) sa.size(); ++is) { + const llama_token id = cur_p[is].id; + + const int s = sa[is]; + + llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true); + + drafts[s].tokens.push_back(id); + + // add unique drafted tokens to the target batch + + drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + + LOG("Adding drafted token %d to tgt, sequence %d, position %d, i_batch_tgt %d\n", id, + s + current_run.seq_offset, spec_past_tgt + i, batch_tgt.n_tokens); + llama_batch_add(batch_tgt, id, spec_past_tgt + i, {s + current_run.seq_offset}, true); + + // add the token to the batch for batched decoding with the draft model + drafts[s].i_batch_dft = batch_dft.n_tokens; + + LOG("Adding drafted token %d to dft, sequence %d, position %d\n", id, s + current_run.seq_offset, n_past_cur); + + llama_batch_add(batch_dft, id, n_past_cur, {s + current_run.seq_offset}, true); + + if (batch_tgt.n_tokens > params.n_draft) { + drafts[s].drafting = false; + } + } + } + + // no sequence is drafting anymore + if (batch_dft.n_tokens == 0) { + break; + } + + // evaluate the drafted tokens on the draft model + LOG("Running synchronous draft decode while still drafting\n"); + LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); + llama_decode(ctx_dft, batch_dft); + ++n_past_cur; + ++n_drafted; + + max_ran_seq = n_seq_cur; + + llama_batch_clear(batch_dft); + + if (batch_tgt.n_tokens > params.n_draft) { + break; + } + } + + // no sequence is drafting anymore + if (batch_dft.n_tokens != 0) { + // evaluate the drafted tokens on the draft model + LOG("Running synchronous draft decode when no seqs drafting\n"); + LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); + llama_decode(ctx_dft, batch_dft); + + } + + + + + + + + // evaluate the target model on the drafted tokens + { +// llama_kv_cache_seq_keep(ctx_tgt, 0); // Needed to get to "Here's the code:" + + + + + + if (batch_tgt.n_tokens == 0) { +// fprintf(stderr, "\nNo tgt tokens, pushing seq offset %d to free seq offsets\n", current_run.seq_offset); +// fflush(stderr); + free_sequence_offsets.push_back(current_run.seq_offset); + n_reject++; + return true; + } + +// bool is_waiting = llama_mpi_iprobe(ctx_tgt); +// llama_swap_comm(ctx_tgt); +// llama_sync_token(ctx_tgt, reinterpret_cast(&is_waiting), 0); +// llama_swap_comm(ctx_tgt); +// +// if (is_waiting) { +// free_sequence_offsets.push_back(current_run.seq_offset); +// return true; +// } + + size_t max_draft_tokens = 0; + + for (int s = 0; s < params.n_parallel; ++s) { + if (!drafts[s].active) { + continue; + } + + drafts[s].tokens.erase(drafts[s].tokens.begin()); + max_draft_tokens = std::max(max_draft_tokens, 
drafts[s].tokens.size()); + //drafts[s].tokens.erase(drafts[s].tokens.begin()); + } + + if (first_run) { + ++n_drafted; + } + + begin_async_run(params.sparams, params.n_parallel, ctx_tgt, max_seq, n_past_dft, drafts, tgt_cgraphs, + batch_id, spec_past_tgt, kvc, true, batch_tgt, spec_past_tgt + drafts[0].tokens.size(), prefix_n_past, current_run.seq_offset); + + spec_past_tgt += drafts[0].tokens.size(); + spec_past_dft += drafts[0].tokens.size(); + id = drafts[0].tokens.back(); + first_run = false; + +// LOG("Beginning tgt spec run, run.prefix_n_past=%d, run.prefix_n_past_tgt=%d, run.n_past_dft=%d, run.n_past_max=%d, new spec_past_tgt=%d, new spec_past_dft=%d, new id=%d\n", +// run.prefix_n_past, run.prefix_n_past_tgt, run.n_past_dft, run.n_past_max, spec_past_tgt, spec_past_dft, id +// ); + + } + + return false; + + +} + +void begin_async_run(const llama_sampling_params& sparams, const int n_seq_dft, + llama_context *ctx_tgt, const int max_seq, + int n_past_dft, const std::vector &drafts, + std::deque &tgt_cgraphs, + int32_t &batch_id, int &n_past, llama_kv_cache_view &kvc_view, + const bool is_spec, llama_batch batch, const int n_past_max, const int prefix_n_past, const int seq_offset) { + batch_id++; + + + LOG("Beginning async decode, batch id = %d\n", batch_id); + + + + + + // batch_tgt.n_tokens = 1 + + + struct seq_async_run run; + run.seq_offset = seq_offset; + run.batch = llama_batch_init(1028, 0, max_seq); + run.batch.batch_id = batch_id; + run.batch.n_tokens = batch.n_tokens; + for (int i = 0; i < batch.n_tokens; i++) { + run.batch.n_seq_id[i] = batch.n_seq_id[i]; + int cur_n_seqs = 0; + for (int j = 0; j < run.batch.n_seq_id[i]; j++) { + run.batch.seq_id[i][j] = batch.seq_id[i][j]; + } + run.batch.token[i] = batch.token[i]; + run.batch.pos[i] = batch.pos[i]; + run.batch.logits[i] = batch.logits[i]; + } + run.batch.batch_id = batch_id; + run.canceled = false; + run.s_keep = 0; +// if (!free_sequence_offsets.empty()) { +// run.seq_offset = free_sequence_offsets.front(); +// printf("Popping %d from seq offsets\n", run.seq_offset); +// free_sequence_offsets.pop_front(); +// } else if(!tgt_cgraphs.empty()){ +// printf("Getting offset from head of tgt cgraphs\n"); +// run.seq_offset = tgt_cgraphs.front().seq_offset; +// } else { +// printf("NO FREE OFFSETS AND NO TGT CGRAPHS\n"); +// } + + + + LOG("target async batch: %s\n, batch_id = %d\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, run.batch).c_str(), + batch_id); + + run.drafts = std::vector(n_seq_dft); + for (int s = 0; s < n_seq_dft; ++s) { + run.drafts[s].ctx_sampling = llama_sampling_init(sparams); + llama_sampling_cp(drafts[s].ctx_sampling, run.drafts[s].ctx_sampling); + run.drafts[s].i_batch_tgt = drafts[s].i_batch_tgt; + run.drafts[s].i_batch_dft = drafts[s].i_batch_dft; + run.drafts[s].tokens = drafts[s].tokens; + run.drafts[s].active = drafts[s].active; + run.drafts[s].drafting = drafts[s].drafting; + run.drafts[s].skip = drafts[s].skip; + run.drafts[s].prefix_tokens = drafts[s].prefix_tokens; + } + run.n_past_tgt = n_past; + run.prefix_n_past_tgt = prefix_n_past; + run.n_past_max = n_past_max; + run.n_past_dft = n_past_dft; + run.speculative = is_spec; + + if (!is_spec) { + for (int i = 0; i <= max_seq; i++) { + llama_kv_cache_seq_rm(ctx_tgt, i, n_past, n_past + 1); + } + } else { + for (int i = 0; i < n_seq_dft; i++) { + llama_kv_cache_seq_rm(ctx_tgt, i+seq_offset, n_past, n_past + 1); + } + } + run.cgraph = llama_start_async_decode(*ctx_tgt, run.batch); + tgt_cgraphs.push_front(run); + + if (!is_spec) { + for (int i = 1; i <= 
max_seq; i++) { + llama_kv_cache_seq_cp(ctx_tgt, 0, i, n_past, n_past + 1); + } + } + + if (llama_node_id(ctx_tgt) == 0) { +// llama_kv_cache_view_update(ctx_tgt, &kvc_view); +// LOG("Done running non-spec, cache view:\n%s", dump_kv_cache_view_seqs(kvc_view, 1).c_str()); +// printf("\nBeginning async run, batch id: %d, batch: %s\n", run.batch.batch_id, LOG_BATCH_TOSTR_PRETTY(ctx_tgt, run.batch).c_str()); + } +} + +void check_for_cancel(llama_context *ctx_tgt, int n_past_tgt, std::deque &tgt_cgraphs, + std::vector &generated, const int n_seq_dft) { + std::vector canceled_batches; + for (auto &run : tgt_cgraphs) { + if(!run.canceled) { + bool correct_prefix = true; + + if (run.speculative && n_past_tgt >= run.prefix_n_past_tgt) { + for (int draft_id = n_seq_dft - 1; draft_id >= 0; draft_id--) { + if (!run.drafts[draft_id].tokens.empty()) { + correct_prefix = true; + } else { + continue; + } + size_t draft_index = 0; + int prev_token = -1; + int prev_gen_token = -1; + std::vector concat_tokens = run.drafts[draft_id].prefix_tokens; + concat_tokens.insert(concat_tokens.end(), run.drafts[draft_id].tokens.begin(), + run.drafts[draft_id].tokens.end()); + + + LOG("Prefix tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, run.drafts[draft_id].prefix_tokens).c_str()); + + LOG("Concat tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, concat_tokens).c_str()); + + + size_t index = run.prefix_n_past_tgt + draft_index; + LOG("Looping over run starting at gen index %zu, draft index %zu, prefix_n_past_tgt %d, n_past_tgt %d, generated size %zu\n", + index, draft_index, run.prefix_n_past_tgt, n_past_tgt, generated.size()); + while (index < generated.size() && draft_index < concat_tokens.size() && + generated.size() > (size_t) run.prefix_n_past_tgt) { + LOG("Checking draft at index %zu and generated index %zu\n", draft_index, index); + if (generated.at(index) != concat_tokens[draft_index]) { + LOG("Found non-matching prefix at generated index %zu, draft index %zu, gen token %d, draft token %d, prev draft token %d, prev gen token %d\n", + index, draft_index, generated.at(index), concat_tokens[draft_index], prev_token, + prev_gen_token); + correct_prefix = false; + break; + } + prev_token = concat_tokens[draft_index]; + prev_gen_token = generated[index]; + draft_index++; + index = run.prefix_n_past_tgt + draft_index; + } + if (correct_prefix) { + run.s_keep = draft_id; + } + } + } + + + if (run.n_past_max < n_past_tgt || !correct_prefix) { + LOG("Cancelling batch ID %d, run.npast_max %d, run.n_past_tgt %d, n_past_tgt %d, run_speculative %d, tokens[0] %d, generated: %d, generated index: %zu\n", + run.batch.batch_id, run.n_past_max, run.n_past_tgt, n_past_tgt, run.speculative, + run.drafts[0].tokens[0], (n_past_tgt < run.n_past_tgt) ? 
-1 : generated.at( + generated.size() - (n_past_tgt - run.n_past_tgt + 1)), + generated.size() - (n_past_tgt - run.n_past_tgt + 1)); + + if (run.speculative) { + // TODO put these in a vector so they are transmitted in a burst + canceled_batches.push_back(run.batch.batch_id); + for (int i = 0; i < n_seq_dft; i++) { + +// llama_kv_cache_seq_rm (ctx_tgt, i+run.seq_offset, run.n_past_tgt, -1); + + + } + } + run.canceled = true; +//// } +// +// if (run_speculative) { +// free_sequence_offsets.push_back(seq_offset); +// } + } + } + } + + if (!canceled_batches.empty()) { + llama_cancel_run(ctx_tgt, canceled_batches.data(), canceled_batches.size()); + } +} diff --git a/ggml-mpi.c b/ggml-mpi.c index ae176d70758..b1e9b9b047c 100644 --- a/ggml-mpi.c +++ b/ggml-mpi.c @@ -6,6 +6,7 @@ #include #include +#include #define MIN(a, b) ((a) < (b) ? (a) : (b)) @@ -14,10 +15,49 @@ struct ggml_mpi_context { int rank; int size; + MPI_Comm comm; + int layer_start; + int layer_end; + MPI_Status status; + MPI_Request asyncSendRequest; + struct ggml_tensor * duped_send_tensor; + MPI_Request asyncRecvRequest; + struct ggml_tensor * duped_recv_tensor; + bool asyncSendWaiting; + bool asyncRecvWaiting; + struct ggml_cgraph * cgraph; + bool async; + bool running_decode; + bool res; + bool embed; + void* send_buffer; + int trans_id; + int recv_trans_id; }; +int ggml_mpi_recv_trans_id(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->recv_trans_id; +} + +int ggml_mpi_trans_id(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->trans_id; +} + +void ggml_mpi_inc_trans_id(struct ggml_mpi_context * ctx_mpi) { + ctx_mpi->trans_id++; +} + +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +); + void ggml_mpi_backend_init(void) { - MPI_Init(NULL, NULL); + int ret; + MPI_Init_thread(NULL, NULL, MPI_THREAD_FUNNELED, &ret); } void ggml_mpi_backend_free(void) { @@ -29,31 +69,355 @@ struct ggml_mpi_context * ggml_mpi_init(void) { MPI_Comm_rank(MPI_COMM_WORLD, &ctx->rank); MPI_Comm_size(MPI_COMM_WORLD, &ctx->size); + ctx->comm = MPI_COMM_WORLD; + ctx->asyncSendWaiting = false; + ctx->asyncRecvWaiting = false; + ctx->running_decode = false; + ctx->async = false; + const int buffer_size = 128*1024*1024; + ctx->send_buffer = calloc(1, buffer_size); // 128MB buffer + MPI_Buffer_attach(ctx->send_buffer, buffer_size); return ctx; } +struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key) { + if (color < 0) { + color = MPI_UNDEFINED; + } + struct ggml_mpi_context * newCtx = calloc(1, sizeof(struct ggml_mpi_context)); + MPI_Comm_split(ctx->comm, color, key, &newCtx->comm); + if(newCtx->comm == MPI_COMM_NULL) { + newCtx->rank = -1; + newCtx->size = -1; + return newCtx; + } + MPI_Comm_rank(newCtx->comm, &newCtx->rank); + MPI_Comm_size(newCtx->comm, &newCtx->size); + return newCtx; +} + void ggml_mpi_free(struct ggml_mpi_context * ctx) { + if(ctx->comm == MPI_COMM_NULL) { + return; + } + + if (ctx->comm == NULL) { + return; + } + + ggml_mpi_sync_pipelined(ctx, NULL, 0, MPI_INT8_T, GGML_MPI_SHUTDOWN); + int buffer_size = 128*1024*1024; + MPI_Buffer_detach(&ctx->send_buffer, &buffer_size); + MPI_Comm_free(&(ctx->comm)); free(ctx); } +bool ggml_mpi_is_decoding(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->running_decode; +} + +struct ggml_cgraph * ggml_mpi_get_cgraph(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->cgraph; +} + +void ggml_mpi_set_cgraph(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * 
cgraph) { + ctx_mpi->cgraph = cgraph; +} + int ggml_mpi_rank(struct ggml_mpi_context * ctx) { return ctx->rank; } -void ggml_mpi_eval_init( +size_t ggml_mpi_size(struct ggml_mpi_context * ctx) { + return ctx->size; +} + +void ggml_mpi_barrier(struct ggml_mpi_context * ctx_mpi) { + MPI_Barrier(ctx_mpi->comm); +} + +void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag) { + MPI_Probe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &(ctx_mpi->status)); +} + +int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return 0; + } + + int ret; + MPI_Iprobe((src >= 0) ? src : MPI_ANY_SOURCE, (tag >= 0) ? tag : MPI_ANY_TAG, ctx_mpi->comm, &ret, &(ctx_mpi->status)); + return ret; +} + +int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi) { + return ctx_mpi->status.MPI_TAG; +} + +int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi) { + int32_t count; + MPI_Get_count(&ctx_mpi->status, MPI_INT32_T, &count); + return count; +} + +int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi) { + return (ctx_mpi->rank + 1) % ctx_mpi->size; +} + +int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi) { + int temp = (ctx_mpi->rank - 1); + return (temp >= 0) ? temp : ctx_mpi->size - 1; +} + +void ggml_mpi_sync_pipelined_recv( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + +} + + +void ggml_mpi_sync_pipelined( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag + ) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + + //printf("Rank %d sync pipelined\n", ctx_mpi->rank); + + + if (ctx_mpi->rank != 0) { + MPI_Recv(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + } + if(ctx_mpi->rank < ctx_mpi->size - 1) { + const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm); + GGML_ASSERT(retval == MPI_SUCCESS); + + } +} + +void ggml_mpi_sync_pipelined_back( + struct ggml_mpi_context * ctx_mpi, + void * val, + int count, + MPI_Datatype datatype, + int tag +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + + //printf("Rank %d sync pipelined\n", ctx_mpi->rank); + + + if (ctx_mpi->rank != 0) { + MPI_Recv(val, count, datatype, ggml_mpi_next_node(ctx_mpi), tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + } + if(ctx_mpi->rank != 1) { + const int retval = MPI_Bsend(val, count, datatype, ggml_mpi_prev_node(ctx_mpi), tag, ctx_mpi->comm); + GGML_ASSERT(retval == MPI_SUCCESS); + + } +} + +bool ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int32_t * n_tokens, + int32_t ** tokens, + int32_t ** pos, + int32_t ** n_seq_ids, + int32_t *** seq_id, + int8_t ** logits, + int32_t * batch_id, + bool receive_only) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return false; + } + int32_t old_n_tokens = *n_tokens; + + + ggml_mpi_sync_pipelined(ctx_mpi, batch_id, 1, MPI_INT, GGML_MPI_BATCH_ID); + + + ggml_mpi_sync_pipelined(ctx_mpi, n_tokens, 1, MPI_INT, GGML_MPI_N_TOKENS); + int8_t* temp_logits = (int8_t*) calloc(*n_tokens, sizeof(int8_t)); + + if (ctx_mpi->rank == 0 && *logits != NULL) { + ggml_mpi_sync_pipelined(ctx_mpi, *logits, *n_tokens, MPI_INT8_T, GGML_MPI_BATCH_LOGITS); + } else { + ggml_mpi_sync_pipelined(ctx_mpi, temp_logits, *n_tokens, MPI_INT8_T, 
GGML_MPI_BATCH_LOGITS); + } + + + + + + + + if (ctx_mpi->rank != 0) { + bool should_set_batch_logits = false; + for (int i = 0; i < *n_tokens; i++) { + if (temp_logits[i]) { + should_set_batch_logits = true; + break; + } + } + if (should_set_batch_logits) { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + *logits = temp_logits; + } else { + if (*logits != NULL) { + free(*logits); + *logits = NULL; + } + free(temp_logits); + } + } else { + free(temp_logits); + } + + // For now, we assume that the pos, seq_ids, tokens, etc have been + // pre-allocated for the largest possible sizes, even on worker nodes. + //if (old_n_tokens != *n_tokens) { + // *pos = realloc(*pos, *n_tokens * sizeof(int32_t)); + // *n_seq_ids = realloc(*n_seq_ids, *n_tokens * sizeof(int32_t )); + // *tokens = realloc(*tokens, *n_tokens * sizeof(int32_t )); + //} + + ggml_mpi_sync_pipelined(ctx_mpi, *tokens, *n_tokens, MPI_INT32_T, GGML_MPI_TOKENS); + + + ggml_mpi_sync_pipelined(ctx_mpi, *n_seq_ids, *n_tokens, MPI_INT32_T, GGML_MPI_N_SEQ_IDS); + + // We need to know the total number of sequence + // ids, so we count them all up + int32_t total_n_seq_ids = 0; + for (int32_t i = 0; i < *n_tokens; i++) { + total_n_seq_ids += (*n_seq_ids)[i]; + } + + // MPI can't chase the pointers for multidimensional arrays, so we flatten them first + // for transit + int32_t * flattened_seq_ids = calloc(total_n_seq_ids, sizeof(int32_t)); + + int32_t current_index = 0; + + // Only rank 0 needs to flatten since the others don't have the real seq_id + if (ctx_mpi->rank == 0) { + for (int32_t i = 0; i < *n_tokens; i++) { + for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { + flattened_seq_ids[current_index] = (*seq_id)[i][j]; + current_index++; + } + } + } + + + + ggml_mpi_sync_pipelined(ctx_mpi, *pos, *n_tokens, MPI_INT32_T, GGML_MPI_POS); + ggml_mpi_sync_pipelined(ctx_mpi, flattened_seq_ids, total_n_seq_ids, MPI_INT32_T, GGML_MPI_SEQ_IDS); + + current_index = 0; + for (int32_t i = 0; i < *n_tokens; i++) { + for (int32_t j = 0; j < (*n_seq_ids)[i]; j++) { + (*seq_id)[i][j] = flattened_seq_ids[current_index]; + current_index++; + } + + } + free(flattened_seq_ids); + + return true; +} + +void ggml_mpi_sync_ints_pipelined( struct ggml_mpi_context * ctx_mpi, - int * n_tokens, - int * n_past, - int * n_threads) { - UNUSED(ctx_mpi); + int32_t * vals, + int count, + int tag +) { + ggml_mpi_sync_pipelined(ctx_mpi, vals, count, MPI_INT32_T, tag); + int old_trans = ctx_mpi->trans_id; + ggml_mpi_sync_pipelined(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID); + ctx_mpi->recv_trans_id = ctx_mpi->trans_id; + ctx_mpi->trans_id = old_trans; +} - // synchronize the worker node parameters with the root node - MPI_Barrier(MPI_COMM_WORLD); +void ggml_mpi_sync_ints_pipelined_back( + struct ggml_mpi_context * ctx_mpi, + int32_t * vals, + int count, + int tag +) { + ggml_mpi_sync_pipelined_back(ctx_mpi, vals, count, MPI_INT32_T, tag); +// int old_trans = ctx_mpi->trans_id; +// ggml_mpi_sync_pipelined_back(ctx_mpi, &ctx_mpi->trans_id, 1, MPI_INT32_T, GGML_MPI_TRANS_ID); +// ctx_mpi->recv_trans_id = ctx_mpi->trans_id; +// ctx_mpi->trans_id = old_trans; +} + +void ggml_mpi_synch_int( + struct ggml_mpi_context * ctx_mpi, + int32_t * val, + int root +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } +// printf("Rank %d sync int\n", ctx_mpi->rank); + MPI_Bcast(val, 1, MPI_INT32_T, root, ctx_mpi->comm); +} - MPI_Bcast(n_tokens, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(n_past, 1, MPI_INT, 0, MPI_COMM_WORLD); - MPI_Bcast(n_threads, 1, 
MPI_INT, 0, MPI_COMM_WORLD); +void ggml_mpi_synch_float( + struct ggml_mpi_context * ctx_mpi, + float * val, + int root +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } +// printf("Rank %d sync float\n", ctx_mpi->rank); + MPI_Bcast(val, 1, MPI_FLOAT, root, ctx_mpi->comm); +} + +void ggml_mpi_recv_float_array( + struct ggml_mpi_context * ctx_mpi, + float * val, + int arr_size, + int src, + int tag +) { +// printf("Rank %d recv float array, count=%d\n", ctx_mpi->rank, arr_size); + int ret = MPI_Recv(val, arr_size, MPI_FLOAT, src, tag, ctx_mpi->comm, MPI_STATUS_IGNORE); + GGML_ASSERT(ret == MPI_SUCCESS); +} + +void ggml_mpi_send_float_array_async( + struct ggml_mpi_context * ctx_mpi, + float * val, + int arr_size, + int dest, + int tag +) { +// printf("Rank %d send float array async, count=%d, val==null: %d\n", ctx_mpi->rank, arr_size, val == NULL); + int ret = MPI_Bsend(val, arr_size, MPI_FLOAT, dest, tag, ctx_mpi->comm); + GGML_ASSERT(ret == MPI_SUCCESS); } static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { @@ -73,7 +437,40 @@ static int ggml_graph_get_node_idx(struct ggml_cgraph * gf, const char * name) { return -1; } -static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) { +struct ggml_tensor * ggml_mpi_dup_tensor(struct ggml_tensor * t) { + struct ggml_tensor * duped = malloc(sizeof(struct ggml_tensor)); + for (int i = 0; i < 4; i++) { + duped->ne[i] = t->ne[i]; + } + size_t data_size = ggml_element_size(t) * ggml_nelements(t); + duped->data = malloc(data_size); + memcpy(duped->data, t->data, data_size); + return duped; +} + +static void ggml_mpi_tensor_send(struct ggml_mpi_context * ctx_mpi, struct ggml_tensor * t, int mpi_rank_dst) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + +// printf("\nSending tensor of size %zu from node %d to node %d", ggml_nelements(t), ctx_mpi->rank, mpi_rank_dst); +// printf("Rank %d tensor send\n", ctx_mpi->rank); + MPI_Datatype mpi_type; + + switch (t->type) { + case GGML_TYPE_I32: mpi_type = MPI_INT32_T; break; + case GGML_TYPE_F32: mpi_type = MPI_FLOAT; break; + default: GGML_ASSERT(false && "not implemented"); + } + + const int retval = MPI_Bsend(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, GGML_MPI_TRANSFER_TENSORS, ctx_mpi->comm); + GGML_ASSERT(retval == MPI_SUCCESS); +} + +static void ggml_mpi_tensor_recv(struct ggml_mpi_context * ctx_mpi, struct ggml_tensor * t, int mpi_rank_src) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } MPI_Datatype mpi_type; switch (t->type) { @@ -82,11 +479,32 @@ static void ggml_mpi_tensor_send(struct ggml_tensor * t, int mpi_rank_dst) { default: GGML_ASSERT(false && "not implemented"); } - const int retval = MPI_Send(t->data, ggml_nelements(t), mpi_type, mpi_rank_dst, 0, MPI_COMM_WORLD); +// printf("\nReceiving tensor of size %zu, at node %d, from node %d", ggml_nelements(t), ctx_mpi->rank, mpi_rank_src); + + + const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, GGML_MPI_TRANSFER_TENSORS, ctx_mpi->comm, MPI_STATUS_IGNORE); GGML_ASSERT(retval == MPI_SUCCESS); } -static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { +void ggml_mpi_wait_recv(struct ggml_mpi_context * ctx_mpi) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + if (ctx_mpi->asyncRecvWaiting) { + MPI_Wait(&(ctx_mpi->asyncRecvRequest), MPI_STATUS_IGNORE); + ctx_mpi->asyncRecvWaiting = false; + } +} + +struct ggml_tensor * ggml_mpi_async_received_tensor(struct ggml_mpi_context * ctx_mpi) { + ggml_mpi_wait_recv(ctx_mpi); + 
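+    // Blocks in ggml_mpi_wait_recv() above until any outstanding MPI_Irecv has
+    // completed before handing back the duplicated tensor, presumably so callers
+    // can post the receive early, keep computing, and only pay the wait once the
+    // data is actually needed.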
return ctx_mpi->duped_recv_tensor; +} + +static void ggml_mpi_async_tensor_recv(struct ggml_mpi_context * ctx_mpi, struct ggml_tensor * t, int mpi_rank_src) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } MPI_Datatype mpi_type; switch (t->type) { @@ -95,17 +513,85 @@ static void ggml_mpi_tensor_recv(struct ggml_tensor * t, int mpi_rank_src) { default: GGML_ASSERT(false && "not implemented"); } - MPI_Status status; UNUSED(status); + ggml_mpi_wait_recv(ctx_mpi); + ctx_mpi->asyncRecvWaiting = true; + const int retval = MPI_Irecv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, GGML_MPI_TRANSFER_TENSORS, ctx_mpi->comm, &(ctx_mpi->asyncRecvRequest)); - const int retval = MPI_Recv(t->data, ggml_nelements(t), mpi_type, mpi_rank_src, MPI_ANY_TAG, MPI_COMM_WORLD, &status); GGML_ASSERT(retval == MPI_SUCCESS); } +uint16_t** ggml_mpi_split_range( + struct ggml_mpi_context * ctx_mpi, + uint16_t start, + uint16_t end, + float node_weights[] +) { + // Splits the range given by start and end + // over the available nodes. This implementation + // assumes that node 0 handles the final part of the range + // while node 1 handles the beginning, to form a ring pipeline + + // Only node 0 deals with the device splits, other nodes + // get the splits from the scatter layers operation + + if (ctx_mpi->comm == MPI_COMM_NULL || ctx_mpi->rank != 0) { + return NULL; + } + + uint16_t range_length = end - start + 1; + uint16_t ** ranges = (uint16_t**) malloc(sizeof(uint16_t*) * ctx_mpi->size); + for (int i = 0; i < ctx_mpi->size; i++) { + ranges[i] = (uint16_t*) malloc(sizeof(uint16_t) * 2); + } + uint16_t next_layer = 0; + for (int i=0; i < ctx_mpi->size-1; i++) { + ranges[i][0] = next_layer; + ranges[i][1] = MIN(end, ranges[i][0] + (node_weights[i] * range_length) + start); + next_layer = ranges[i][1]; + } + + ranges[ctx_mpi->size-1][0] = next_layer; +// ranges[ctx_mpi->size-1][1] = MIN(end, next_layer + (node_weights[ctx_mpi->size-1] * range_length) + start); + ranges[ctx_mpi->size-1][1] = end; + + return ranges; + +} + +void ggml_mpi_scatter_layers( + struct ggml_mpi_context * ctx_mpi, + uint16_t ** layer_ranges +) { + if(ctx_mpi->comm == MPI_COMM_NULL) { + return; + } + + // Layer ranges is a 2d array with the first dimension + // having a length of the number of nodes and the second + // dimension having a length of 2. The inner arrays contain + // the start and end layer ID for a node. 
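+    // Example (illustrative numbers): with 3 ranks and layer_ranges =
+    // {{0, 16}, {16, 32}, {32, 40}}, the flattened buffer becomes
+    // [0, 16, 16, 32, 32, 40]; MPI_Scatter then delivers two uint16_t per rank,
+    // so rank r keeps [flattened[2*r], flattened[2*r + 1]] as its
+    // (layer_start, layer_end).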
+ uint16_t flattened_ranges[ctx_mpi->size * 2]; + + if (layer_ranges != NULL) { + for (int i = 0; i < ctx_mpi->size * 2; i += 2) { + flattened_ranges[i] = layer_ranges[i/2][0]; + flattened_ranges[i + 1] = layer_ranges[i/2][1]; + } + } + + uint16_t received_range[2]; + MPI_Scatter(flattened_ranges, 2, MPI_UINT16_T, received_range, 2, MPI_UINT16_T, 0, ctx_mpi->comm); + ctx_mpi->layer_start = received_range[0]; + ctx_mpi->layer_end = received_range[1]; + fprintf(stderr, "Ranges for rank %d: [%d, %d]\n", ctx_mpi->rank, ctx_mpi->layer_start, ctx_mpi->layer_end); +} + + // TODO: there are many improvements that can be done to this implementation -void ggml_mpi_graph_compute_pre( +void ggml_mpi_graph_creation_post( struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * gf, - int n_layers) { + const int n_layers) { const int mpi_rank = ctx_mpi->rank; const int mpi_size = ctx_mpi->size; @@ -123,6 +609,8 @@ void ggml_mpi_graph_compute_pre( GGML_ASSERT(inp0 == gf->nodes[0]); +// printf("Rank %d creation post\n", mpi_rank); + // distribute the compute graph into slices across the MPI nodes // // the main node (0) processes the last layers + the remainder of the compute graph @@ -134,83 +622,100 @@ void ggml_mpi_graph_compute_pre( // node n-1: [(n-2) * n_per_node, (n-1) * n_per_node) // node 0: [(n-1) * n_per_node, n_nodes) // + + + if (mpi_rank > 0) { - if (mpi_rank == 1) { - // the first node (1) receives the input tokens from the main node (0) - ggml_mpi_tensor_recv(inp_tokens, 0); - } else { - // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) - ggml_mpi_tensor_recv(inp0, mpi_rank - 1); - } - } else if (mpi_size > 1) { - // node 0 sends the input tokens to node 1 - ggml_mpi_tensor_send(inp_tokens, 1); + // recv input data for each node into the "inp0" tensor (i.e. the first node in the compute graph) + ggml_mpi_tensor_recv(ctx_mpi, inp0, mpi_rank - 1); - // recv the output data from the last node - ggml_mpi_tensor_recv(inp0, mpi_size - 1); + } else if (mpi_size > 1) { + // node 0 processes the inputs and then sends to node 1 } - { - const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; + //const int n_per_node = (n_layers + (mpi_size - 1)) / mpi_size; - const int mpi_idx = mpi_rank > 0 ? mpi_rank - 1 : mpi_size - 1; + //const int il0 = (mpi_idx + 0) * n_per_node; + //const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); + int il0 = ctx_mpi->layer_start; + int il1 = MIN(n_layers, ctx_mpi->layer_end); - const int il0 = (mpi_idx + 0) * n_per_node; - const int il1 = MIN(n_layers, (mpi_idx + 1) * n_per_node); + char name_l0[GGML_MAX_NAME]; + char name_l1[GGML_MAX_NAME]; - char name_l0[GGML_MAX_NAME]; - char name_l1[GGML_MAX_NAME]; + snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); + snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); - snprintf(name_l0, sizeof(name_l0), "layer_inp_%d", il0); - snprintf(name_l1, sizeof(name_l1), "layer_inp_%d", il1); + const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); + const int idx_l1 = mpi_rank == mpi_size - 1 ? gf->n_nodes : ggml_graph_get_node_idx(gf, name_l1) + 1; - const int idx_l0 = ggml_graph_get_node_idx(gf, name_l0); - const int idx_l1 = mpi_rank > 0 ? 
ggml_graph_get_node_idx(gf, name_l1) + 1 : gf->n_nodes; + if (idx_l0 < 0 || idx_l1 < 0) { + fprintf(stderr, "%s: layer input nodes not found\n", __func__); + return; + } - if (idx_l0 < 0 || idx_l1 < 0) { - fprintf(stderr, "%s: layer input nodes not found\n", __func__); - return; + // attach the input data to all nodes that need it + // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) + for (int i = idx_l0; i < idx_l1; i++) { + if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) { + gf->nodes[i]->src[0] = inp0; } - - // attach the input data to all nodes that need it - // TODO: not great - should be able to do this without modifying the compute graph (see next TODO below) - for (int i = idx_l0; i < idx_l1; i++) { - if (gf->nodes[i]->src[0] == gf->nodes[idx_l0]) { - gf->nodes[i]->src[0] = inp0; - } - if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) { - gf->nodes[i]->src[1] = inp0; - } + if (gf->nodes[i]->src[1] == gf->nodes[idx_l0]) { + gf->nodes[i]->src[1] = inp0; } + } - // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph - for (int i = 1; i < idx_l1 - idx_l0; i++) { - gf->nodes[i] = gf->nodes[idx_l0 + i]; - gf->grads[i] = gf->grads[idx_l0 + i]; - } + // TODO: instead of rearranging the nodes, we should be able to execute a subset of the compute graph + for (int i = 1; i < idx_l1 - idx_l0; i++) { + gf->nodes[i] = gf->nodes[idx_l0 + i]; + } - // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node - if (mpi_idx != 0) { - gf->nodes[0]->op = GGML_OP_NONE; - } + // the first node performs the "get_rows" operation, the rest of the nodes get the data from the previous node + if (mpi_rank != 0 && mpi_size > 1) { + gf->nodes[0]->op = GGML_OP_NONE; + } + + gf->n_nodes = idx_l1 - idx_l0; + + +} + +bool ggml_mpi_graph_compute_pre(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * gf) { + if (ctx_mpi->comm == MPI_COMM_NULL) { + return false; + } + +// printf("Rank %d compute pre\n", ctx_mpi->rank); + + const int mpi_rank = ctx_mpi->rank; + const int mpi_size = ctx_mpi->size; - gf->n_nodes = idx_l1 - idx_l0; + struct ggml_tensor * inp_tokens = ggml_graph_get_tensor(gf, "inp_tokens"); + if (inp_tokens == NULL) { + fprintf(stderr, "%s: tensor 'inp_tokens' not found\n", __func__); + return false; + } - //fprintf(stderr, "%s: node %d: processing %d nodes [%d, %d)\n", __func__, mpi_rank, gf->n_nodes, il0, il1); + struct ggml_tensor * inp0 = ggml_graph_get_tensor(gf, "layer_inp_0"); + if (inp0 == NULL) { + fprintf(stderr, "%s: tensor 'inp0' not found\n", __func__); + return false; } + + GGML_ASSERT(inp0 == gf->nodes[0]); + + return true; } void ggml_mpi_graph_compute_post( struct ggml_mpi_context * ctx_mpi, - struct ggml_cgraph * gf, - int n_layers) { - UNUSED(n_layers); + struct ggml_cgraph * gf) { const int mpi_rank = ctx_mpi->rank; - const int mpi_size = ctx_mpi->size; +// printf("Rank %d compute post\n", mpi_rank); // send the output data to the next node - if (mpi_rank > 0) { - ggml_mpi_tensor_send(gf->nodes[gf->n_nodes - 1], (mpi_rank + 1) % mpi_size); + if (mpi_rank < ctx_mpi->size - 1) { + ggml_mpi_tensor_send(ctx_mpi, gf->nodes[gf->n_nodes - 1], ggml_mpi_next_node(ctx_mpi)); } } diff --git a/ggml-mpi.h b/ggml-mpi.h index eda119d4498..ca2365862c1 100644 --- a/ggml-mpi.h +++ b/ggml-mpi.h @@ -1,4 +1,7 @@ #pragma once +#include +#include +#include struct ggml_context; struct ggml_tensor; @@ -8,31 +11,294 @@ struct ggml_cgraph; extern "C" { #endif 
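+// The constants below are MPI message tags: ggml_mpi_sync_pipelined() and the
+// tensor transfer helpers stamp every message with one of them, so a node that
+// probes an incoming message can distinguish control traffic (decode requests,
+// KV-cache operations, shutdown, run cancellation) from batch metadata and
+// tensor payloads. A worker-side sketch of how they are meant to be consumed
+// (illustrative only; the actual dispatch loop is not shown in this diff):
+//
+//     while (running) {
+//         ggml_mpi_probe(ctx, -1, -1);                    // any source, any tag
+//         switch (ggml_mpi_status_tag(ctx)) {
+//             case GGML_MPI_DECODE:   /* sync the batch and run our slice */ break;
+//             case GGML_MPI_KV_SEQ_RM:/* mirror the KV-cache edit         */ break;
+//             case GGML_MPI_SHUTDOWN: running = false;                      break;
+//         }
+//     }
+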
+#define GGML_MPI_DECODE 0 + +#define GGML_MPI_KV_CLEAR 1 + +#define GGML_MPI_KV_SEQ_RM 2 + +#define GGML_MPI_KV_SEQ_CP 3 + +#define GGML_MPI_KV_SEQ_KEEP 4 + +#define GGML_MPI_KV_SEQ_SHIFT 5 + +#define GGML_MPI_SHUTDOWN 6 + +#define GGML_MPI_TRANSFER_TENSORS 7 + +#define GGML_MPI_SYNC_LOGITS 8 + +#define GGML_MPI_CANCEL_RUN 9 + +#define GGML_MPI_KV_SEQ_CP_BACK 10 + +#define GGML_MPI_TRANS_ID 11 + +#define GGML_MPI_BATCH_ID 12 + +#define GGML_MPI_N_TOKENS 13 + +#define GGML_MPI_TOKENS 14 + +#define GGML_MPI_N_SEQ_IDS 15 + +#define GGML_MPI_SEQ_IDS 16 + +#define GGML_MPI_POS 17 + +#define GGML_MPI_BEGIN_TRANSACTION 18 + +#define GGML_MPI_MAX_N_SEQ 19 + +#define GGML_MPI_BATCH_LOGITS 20 + +/** + * The context used for MPI operations, + * a program may make use of more than one + * context but must always have at least one. + * + * The context stores required information like the + * node rank and a communicator to use for MPI operations. + * A context is guaranteed to be internally consistent, + * meaning that a context's stored rank is valid within + * the context's communicator. + */ struct ggml_mpi_context; + +int ggml_mpi_trans_id(struct ggml_mpi_context * ctx_mpi); + +int ggml_mpi_recv_trans_id(struct ggml_mpi_context * ctx_mpi); + +void ggml_mpi_inc_trans_id(struct ggml_mpi_context * ctx_mpi); + +/** + * Initialize the MPI library and the GGML MPI backend. + * Calling more than once during the lifetime of the program + * leads to undefined behavior. This function must be called before + * any MPI operations. + */ void ggml_mpi_backend_init(void); + +bool ggml_mpi_is_decoding(struct ggml_mpi_context * ctx_mpi); + +int ggml_mpi_status_count_int32(struct ggml_mpi_context * ctx_mpi); + +void ggml_mpi_graph_creation_post(struct ggml_mpi_context * ctx_mpi, struct ggml_cgraph * cgraph, int n_layers); + +void ggml_mpi_wait_recv(struct ggml_mpi_context * ctx_mpi); + +/** + * Frees the MPI backend, must be called only once at termination + * of the program. No MPI operations may be completed after calling this function, + * and attempting to do so will lead to undefined behavior. + */ void ggml_mpi_backend_free(void); +/** + * Construct a new MPI context using the MPI_WORLD + * communicator. This is useful only to create the + * initial context, as calling multiple times + * will only create effective copies of the same data. + * + * @return A context for us in the global communicator. + */ struct ggml_mpi_context * ggml_mpi_init(void); + +/** + * Create a new context by splitting the given context's + * communicator, creating a "sub-communicator." This is a collective + * operation and must be performed by all nodes within the same communicator. + * The color and key have the same meaning as in MPI_Comm_split(), i.e. + * the color is used to determine the sub-communicator this node will belong to, + * and the key is the relative rank of this node in the new communicator. + * + * An example: if a node passes a color of 1, and a different node passes a color of 2, + * the nodes will belong to two different sub-communicators. If two nodes pass the same + * color, then their ranks will be ordered by the order of their keys. If they pass the same + * key, then the tie will be broken by the nodes' ranks in the old communicator. + * + * The communicator used by the given context remains entirely valid, so it is advisable + * to store both the old and new contexts. This allows an application to + * select at runtime which communicator to perform MPI operations with. 
An example + * would be to segregate the nodes into multiple domains categorized by the functions + * they perform, and use the original context to broadcast to all nodes in the cluster. + * + * @param ctx The context containing the communicator to split. + * @param color The sub-communicator that this node will belong to. + * @param key The relative rank of this node in the new communicator. + * @return A new context with all values referencing the newly-created communicator. + */ +struct ggml_mpi_context * ggml_mpi_split_comm(struct ggml_mpi_context * ctx, int color, int key); + +void ggml_mpi_barrier(struct ggml_mpi_context * ctx); + +int ggml_mpi_next_node(struct ggml_mpi_context * ctx_mpi); + +int ggml_mpi_prev_node(struct ggml_mpi_context * ctx_mpi); + +void ggml_mpi_sync_ints_pipelined( + struct ggml_mpi_context * ctx_mpi, + int32_t * vals, + int count, + int tag +); + +void ggml_mpi_sync_ints_pipelined_back( + struct ggml_mpi_context * ctx_mpi, + int32_t * vals, + int count, + int tag +); +// clear = 1, rm = 2, cp = 3, keep = 4, seq_shift = 5 +void ggml_mpi_probe(struct ggml_mpi_context * ctx_mpi, int src, int tag); +int ggml_mpi_status_tag(struct ggml_mpi_context * ctx_mpi); + +int ggml_mpi_iprobe(struct ggml_mpi_context * ctx_mpi, int src, int tag); + +/** + * Frees the given context, including the communicator. No MPI + * operations besides ggml_mpi_backend_free(void) should be executed after + * running this function. + * + * @param ctx The context to free. + */ void ggml_mpi_free(struct ggml_mpi_context * ctx); +/** + * Get the rank of this node in the given context's communicator. + * + * @param ctx The context whose communicator is used to determine the rank. + * @return The rank of this node. + */ int ggml_mpi_rank(struct ggml_mpi_context * ctx); -void ggml_mpi_eval_init( - struct ggml_mpi_context * ctx_mpi, - int * n_tokens, - int * n_past, - int * n_threads); +/** + * Get the number of nodes that are a part of + * the communicator referenced by the given context. + * + * @param ctx The context containing the communicator used for this size check. + * @return The number of nodes that are a part of the given context's communicator. + */ +size_t ggml_mpi_size(struct ggml_mpi_context * ctx); + +/** + * Synchronize needed information among the nodes + * to prepare for running an evaluation iteration. + * This is a collective operation and all nodes must + * call this function. It will block until all + * nodes have entered it, to prevent any desync + * between nodes. + * + * @param ctx_mpi The context in which to prepare for evaluation. + * @param n_tokens A pointer to the token count, which will be synchronized after this function. + * @param pos A pointer to the pos array, which will be synchronized after this function. + * @param n_seq_ids A pointer to the n_seq_ids array, which will be synchronized after this function. + * @param seq_id A pointer to the seq_id 2D array, which will be synchronized after this function. + * @param logits A pointer to the logits array, which is unused currently since only node 0 needs them.
+ */ +bool ggml_mpi_eval_init( + struct ggml_mpi_context * ctx_mpi, + int32_t * n_tokens, + int32_t ** tokens, + int32_t ** pos, + int32_t ** n_seq_ids, + int32_t *** seq_id, + int8_t ** logits, + int32_t * batch_id, + bool receive_only); + +void ggml_mpi_synch_int( + struct ggml_mpi_context * ctx_mpi, + int32_t * val, + int root + ); + +void ggml_mpi_synch_float( + struct ggml_mpi_context * ctx_mpi, + float * val, + int root +); + +void ggml_mpi_recv_float_array( + struct ggml_mpi_context * ctx_mpi, + float * val, + int arr_size, + int src, + int tag +); + +void ggml_mpi_send_float_array_async( + struct ggml_mpi_context * ctx_mpi, + float * val, + int arr_size, + int dest, + int tag +); + +/** + * Split a range across all nodes within the given + * context, weighting the allocations by the given weights. + * The dimensions of the returned 2d array are (number of nodes in the context, 2). + * The first element in the inner array is the starting point of the range allocated + * to the node indicated by the index into the outer array, + * and the second element is the end point of the allocated range, inclusive. + * + * @param ctx_mpi The context used to determine the number of nodes + * to split the range across. + * @param start The starting point of the range. + * @param end The end point of the range, inclusive. + * @param node_weights How to weight the allocations across the nodes, + * must sum to 1.0. + * @return A 2d array; the first dimension is the number of nodes in the context + * and the second dimension is 2. + */ +uint16_t** ggml_mpi_split_range( + struct ggml_mpi_context * ctx_mpi, + uint16_t start, + uint16_t end, + float node_weights[] +); + +/** + * Scatter the layer ranges across all nodes + * in the given context. This is a collective operation + * and must be called by all nodes that are within the same + * communicator. The given layer ranges must be in the same + * format as created by ggml_mpi_split_range(). + * + * @param ctx_mpi The context to scatter the layers across. + * @param layer_ranges The pre-split ranges to scatter to the nodes. + */ +void ggml_mpi_scatter_layers( + struct ggml_mpi_context * ctx_mpi, + uint16_t ** layer_ranges +); -void ggml_mpi_graph_compute_pre( +/** + * Check that this node should take part in computing the given graph and that + * the required input tensors are present. The graph has already been restricted + * to this node's allocated layers by ggml_mpi_graph_creation_post(). + * + * @param ctx_mpi The context containing the allocated layer range. + * @param gf The compute graph that is about to be computed. + * @return True if this node should compute the graph, false otherwise. + */ +bool ggml_mpi_graph_compute_pre( struct ggml_mpi_context * ctx_mpi, - struct ggml_cgraph * gf, - int n_layers); + struct ggml_cgraph * gf); +/** + * Sends the output tensor to the next node for processing + * of later layers. + * + * @param ctx_mpi The context to use for MPI operations. + * @param gf The graph used in the computations.
+ */ void ggml_mpi_graph_compute_post( struct ggml_mpi_context * ctx_mpi, - struct ggml_cgraph * gf, - int n_layers); + struct ggml_cgraph * gf); #ifdef __cplusplus } diff --git a/ggml.c b/ggml.c index f92292b39c6..6b1275cf3ee 100644 --- a/ggml.c +++ b/ggml.c @@ -362,13 +362,13 @@ int64_t ggml_time_us(void) { void ggml_time_init(void) {} int64_t ggml_time_ms(void) { struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts); return (int64_t)ts.tv_sec*1000 + (int64_t)ts.tv_nsec/1000000; } int64_t ggml_time_us(void) { struct timespec ts; - clock_gettime(CLOCK_MONOTONIC, &ts); + clock_gettime(CLOCK_THREAD_CPUTIME_ID, &ts); return (int64_t)ts.tv_sec*1000000 + (int64_t)ts.tv_nsec/1000; } #endif @@ -15718,10 +15718,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { int node_n = -1; while (true) { - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - state->shared->node_n += 1; - return (thread_ret_t) GGML_EXIT_ABORTED; - } +// if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { +// state->shared->node_n += 1; +// return (thread_ret_t) GGML_EXIT_ABORTED; +// } if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { // all other threads are finished and spinning // do finalize and init here so we don't have synchronize again @@ -15745,6 +15745,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { // distribute new work or execute it direct if 1T while (++node_n < cgraph->n_nodes) { + if (cplan->abort_callback && cplan->abort_callback(state->ith, cplan->abort_callback_data)) { + break; + } GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -15777,9 +15780,6 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { break; } - if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { - break; - } } atomic_store(&state->shared->n_active, n_threads); @@ -15798,16 +15798,23 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { node_n = atomic_load(&state->shared->node_n); if (node_n != last) break; + if (cplan->abort_callback && cplan->abort_callback(state->ith, cplan->abort_callback_data)) { + break; + } }; } // check if we should stop if (node_n >= cgraph->n_nodes) break; + if (cplan->abort_callback && cplan->abort_callback(state->ith, cplan->abort_callback_data)) { + break; + } /* COMPUTE */ struct ggml_tensor * node = cgraph->nodes[node_n]; const int n_tasks = ggml_get_n_tasks(node, n_threads); + struct ggml_compute_params params = { /*.type =*/ GGML_TASK_COMPUTE, /*.ith =*/ state->ith, @@ -15819,6 +15826,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { if (state->ith < n_tasks) { ggml_compute_forward(¶ms, node); } + + + } return GGML_EXIT_SUCCESS; @@ -16036,6 +16046,10 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { }; struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); + if (cplan->abort_callback && cplan->abort_callback(0, cplan->abort_callback_data)) { + return GGML_EXIT_SUCCESS; + } + // create thread pool if (n_threads > 1) { for (int j = 1; j < n_threads; ++j) { diff --git a/ggml.h b/ggml.h index f2fce0f22d3..e39cdd8a859 100644 --- a/ggml.h +++ b/ggml.h @@ -531,7 +531,7 @@ extern "C" { int n_threads; // abort ggml_graph_compute when true - bool (*abort_callback)(void * data); + bool (*abort_callback)(int ith, void * data); void * abort_callback_data; }; 
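The documentation for ggml_mpi_split_comm() above describes the color/key semantics in prose; a short sketch of the head/worker split it alludes to might look like the following. Only functions declared in ggml-mpi.h are used; the world_ctx/domain_ctx names and the choice of colors are illustrative, not part of the patch.

```cpp
#include <cstdio>

#include "ggml-mpi.h"

int main(void) {
    ggml_mpi_backend_init();

    // Context covering every node in the cluster.
    struct ggml_mpi_context * world_ctx = ggml_mpi_init();
    const int world_rank = ggml_mpi_rank(world_ctx);

    // Color 0: the "head" domain (rank 0 only); color 1: the "worker" domain.
    // Reusing the world rank as the key keeps the workers in their original
    // relative order inside the new communicator.
    const int color = world_rank == 0 ? 0 : 1;
    struct ggml_mpi_context * domain_ctx = ggml_mpi_split_comm(world_ctx, color, world_rank);

    printf("world rank %d -> domain rank %d\n", world_rank, ggml_mpi_rank(domain_ctx));

    // The original context stays valid, so cluster-wide operations can still
    // be issued alongside the per-domain communicator.
    ggml_mpi_barrier(world_ctx);

    ggml_mpi_free(domain_ctx);
    ggml_mpi_free(world_ctx);
    ggml_mpi_backend_free();
    return 0;
}
```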
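The ggml.h change above passes the calling thread's index to the abort callback, which is what lets the llama.cpp side poll MPI for a cancellation from thread 0 only while the remaining threads just observe a shared flag. A minimal sketch of wiring the new signature into a compute plan, with a hypothetical cancellation flag standing in for the MPI probe:

```cpp
#include <atomic>
#include <cstdint>
#include <vector>

#include "ggml.h"

// Hypothetical flag that some other part of the program (e.g. an MPI probe
// loop) may set to request cancellation.
static std::atomic_bool g_cancel_requested{false};

// New signature from this patch: the first argument is the thread index, so
// only thread 0 performs the potentially expensive check; the other threads
// simply read the shared result.
static bool abort_cb(int ith, void * data) {
    auto * aborted = static_cast<std::atomic_bool *>(data);
    if (ith == 0 && g_cancel_requested.load()) {
        aborted->store(true);
    }
    return aborted->load();
}

static void compute_with_abort(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);

    std::vector<uint8_t> work;
    if (plan.work_size > 0) {
        work.resize(plan.work_size);
        plan.work_data = work.data();
    }

    std::atomic_bool aborted{false};
    plan.abort_callback      = abort_cb;
    plan.abort_callback_data = &aborted;

    // Returns early once abort_cb reports true while walking the graph nodes.
    ggml_graph_compute(gf, &plan);
}
```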
diff --git a/llama.cpp b/llama.cpp index f2b5967d791..47d897fb15f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -18,6 +18,8 @@ #endif #ifdef GGML_USE_MPI # include "ggml-mpi.h" +#include "common/log.h" + #endif #ifndef QK_K # ifdef GGML_QKK_64 @@ -76,6 +78,8 @@ #include #include #include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -662,9 +666,23 @@ static std::string gguf_kv_to_str(struct gguf_context * ctx_gguf, int i) { // ggml helpers // -static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads) { +static std::function abort_callback_function; + +static bool ab_callback(int ithread, void * data) { + if (abort_callback_function != nullptr) { + return abort_callback_function(ithread, data); + } + return false; +} + +static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * graph, int n_threads, std::function callback = nullptr, void * abort_data = nullptr) { struct ggml_cplan plan = ggml_graph_plan(graph, n_threads); + abort_callback_function = std::move(callback); + + plan.abort_callback = &ab_callback; + plan.abort_callback_data = abort_data; + if (plan.work_size > 0) { buf.resize(plan.work_size); plan.work_data = buf.data(); @@ -850,9 +868,14 @@ struct llama_mmap { int flags = MAP_SHARED; // prefetch/readahead impairs performance on NUMA systems if (numa) { prefetch = 0; } + +#ifdef GGML_USE_MPI + prefetch = 0; +#endif #ifdef __linux__ if (prefetch) { flags |= MAP_POPULATE; } #endif + addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); if (addr == MAP_FAILED) { throw std::runtime_error(format("mmap failed: %s", strerror(errno))); @@ -1488,6 +1511,8 @@ struct llama_context { #ifdef GGML_USE_MPI ggml_mpi_context * ctx_mpi = NULL; + ggml_mpi_context * ctx_mpi_orig = NULL; + std::unordered_map canceled_batches; #endif }; @@ -1649,13 +1674,13 @@ static void llama_kv_cache_seq_rm( if (p1 < 0) p1 = std::numeric_limits::max(); for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if ((cache.cells[i].pos >= p0 || cache.cells[i].pos < 0) && cache.cells[i].pos < p1) { if (seq_id < 0) { cache.cells[i].seq_id.clear(); } else if (cache.cells[i].has_seq_id(seq_id)) { cache.cells[i].seq_id.erase(seq_id); } else { - continue; +// continue; } if (cache.cells[i].seq_id.empty()) { // keep count of the number of used cells @@ -3873,6 +3898,7 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; // norm @@ -3983,6 +4009,7 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -4103,6 +4130,7 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * attn_norm; attn_norm = llm_build_norm(ctx0, inpL, hparams, @@ -4227,6 +4255,7 @@ struct llm_build_context { cb(inpL, "inpL", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -4323,6 +4352,7 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * residual = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -4525,6 +4555,7 @@ struct 
llm_build_context { cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * inpSA = inpL; cur = llm_build_norm(ctx0, inpL, hparams, @@ -4622,6 +4653,7 @@ struct llm_build_context { cb(inpL, "inp_norm", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI cur = llm_build_norm(ctx0, inpL, hparams, model.layers[il].attn_norm, model.layers[il].attn_norm_b, @@ -4710,6 +4742,7 @@ struct llm_build_context { cb(KQ_mask, "KQ_mask", -1); for (int il = 0; il < n_layer; ++il) { + ggml_format_name(inpL, "layer_inp_%d", il); //MPI struct ggml_tensor * attn_norm; attn_norm = llm_build_norm(ctx0, inpL, hparams, @@ -5177,10 +5210,25 @@ static struct ggml_cgraph * llama_build_graph( for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { const llama_pos pos = batch.pos[j]; - const llama_seq_id seq_id = batch.seq_id[j][0]; + for (int i = 0; i < n_kv; ++i) { - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + + const int n_seq_id = batch.n_seq_id[j]; + + bool has_seq_id = false; + + for (int seq_index = 0; seq_index < n_seq_id; seq_index++) { + llama_seq_id seq_id = batch.seq_id[j][seq_index]; +// printf("Seq id %d in index %d, n_seq_id %d\n", seq_id, seq_index, n_seq_id); + + has_seq_id = lctx.kv_self.cells[i].has_seq_id(seq_id); + if (has_seq_id) { + break; + } + } + + if (!has_seq_id || lctx.kv_self.cells[i].pos > pos) { data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; } } @@ -5406,23 +5454,34 @@ static struct ggml_cgraph * llama_build_graph( return result; } -// decode a batch of tokens by evaluating the transformer -// -// - lctx: llama context -// - batch: batch to evaluate -// -// return 0 on success -// return positive int on warning -// return negative int on error -// -static int llama_decode_internal( - llama_context & lctx, - llama_batch batch) { - const uint32_t n_tokens = batch.n_tokens; +bool llama_mpi_iprobe(struct llama_context * lctx) { + return ggml_mpi_iprobe(lctx->ctx_mpi, ggml_mpi_prev_node(lctx->ctx_mpi), GGML_MPI_SYNC_LOGITS); +} + +static struct ggml_cgraph * llama_decode_internal_phased( + llama_context & lctx, + llama_batch & batch, + uint8_t phase, + ggml_cgraph * cgraph) { + if (ggml_mpi_rank(lctx.ctx_mpi) < 0) { + return nullptr; + } + if (phase == 0) { + if (ggml_mpi_rank(lctx.ctx_mpi) == 0 && ggml_mpi_size(lctx.ctx_mpi) > 1) { + int transaction_type = GGML_MPI_DECODE; + ggml_mpi_sync_ints_pipelined(lctx.ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + ggml_mpi_sync_ints_pipelined(lctx.ctx_mpi, &batch.batch_id, 1, GGML_MPI_BATCH_ID); + + ggml_mpi_sync_ints_pipelined(lctx.ctx_mpi, &batch.n_tokens, 1, GGML_MPI_N_TOKENS); + + ggml_mpi_sync_ints_pipelined(lctx.ctx_mpi, &batch.max_n_seq, 1, GGML_MPI_MAX_N_SEQ); + } + uint32_t n_tokens = batch.n_tokens; if (n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0", __func__); - return -1; + return nullptr; } const auto & model = lctx.model; @@ -5434,15 +5493,9 @@ static int llama_decode_internal( GGML_ASSERT(n_tokens <= n_batch); int n_threads = n_tokens == 1 ? 
cparams.n_threads : cparams.n_threads_batch; - GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT - const int64_t t_start_us = ggml_time_us(); -#ifdef GGML_USE_MPI - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); -#endif + const int64_t t_start_us = ggml_time_us(); GGML_ASSERT(n_threads > 0); @@ -5476,7 +5529,7 @@ static int llama_decode_internal( seq_id_arr.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { n_seq_id[i] = 1; - seq_id[i].resize(1); + seq_id[i].resize(batch.max_n_seq); seq_id[i][0] = batch.all_seq_id; seq_id_arr[i] = seq_id[i].data(); } @@ -5485,70 +5538,103 @@ static int llama_decode_internal( batch.seq_id = seq_id_arr.data(); } - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } + if (phase == 0) { - if (!llama_kv_cache_find_slot(kv_self, batch)) { - return 1; - } + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (kv_self.head > kv_self.used + 2*n_tokens) { + kv_self.head = 0; + } - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? - kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); +#ifdef GGML_USE_MPI + // TODO: needs fix after #3228 + if (!ggml_mpi_eval_init(lctx.ctx_mpi, &(batch.n_tokens), &(batch.token), &(batch.pos), &(batch.n_seq_id), + &(batch.seq_id), &(batch.logits), &(batch.batch_id), false)) { + return nullptr; + } + n_tokens = batch.n_tokens; +#endif - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); + GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT + if (!llama_kv_cache_find_slot(kv_self, batch)) { + printf("Cannot find cache slot\n"); + return nullptr; + } + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? 
+ kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); - ggml_allocr_reset(lctx.alloc); + //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); - ggml_cgraph * gf = llama_build_graph(lctx, batch); + ggml_allocr_reset(lctx.alloc); + ggml_cgraph * gf = llama_build_graph(lctx, batch); - ggml_allocr_alloc_graph(lctx.alloc, gf); + ggml_allocr_alloc_graph(lctx.alloc, gf); - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + struct ggml_tensor *res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor *embeddings = gf->nodes[gf->n_nodes - 2]; - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); +#ifdef GGML_USE_MPI + const int64_t n_layer = hparams.n_layer; + ggml_mpi_graph_creation_post(lctx.ctx_mpi, gf, n_layer); +#endif #ifdef GGML_USE_CUBLAS - for (int i = 0; i < gf->n_leafs; i++) { - ggml_tensor * node = gf->leafs[i]; - if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); - ggml_cuda_copy_to_device(node); + for (int i = 0; i < gf->n_leafs; i++) { + ggml_tensor * node = gf->leafs[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + ggml_cuda_copy_to_device(node); + } } - } - for (int i = 0; i < gf->n_nodes; i++) { - ggml_tensor * node = gf->nodes[i]; - if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { - ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + for (int i = 0; i < gf->n_nodes; i++) { + ggml_tensor * node = gf->nodes[i]; + if (node->backend == GGML_BACKEND_GPU && node->extra == NULL) { + ggml_cuda_assign_scratch_offset(node, (char*)node->data - (char *) lctx.buf_alloc.data); + } } - } - // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed - if (!lctx.embedding.empty()) { - embeddings->backend = GGML_BACKEND_CPU; - } - res->backend = GGML_BACKEND_CPU; + // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed + if (!lctx.embedding.empty()) { + embeddings->backend = GGML_BACKEND_CPU; + } + res->backend = GGML_BACKEND_CPU; #endif - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); +#ifdef GGML_USE_MPI + if (ggml_mpi_iprobe(lctx.ctx_mpi, ggml_mpi_next_node(lctx.ctx_mpi), GGML_MPI_CANCEL_RUN)) { + int count = ggml_mpi_status_count_int32(lctx.ctx_mpi); +// printf("Received async cancel run\n"); + { + std::vector canceled(count, -1); + llama_cancel_run(&lctx, canceled.data(), canceled.size()); - // for big prompts, if BLAS is enabled, it is better to use only one thread - // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance - // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well - // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering - // with the BLAS calls. 
need a better solution - if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { - n_threads = std::min(4, n_threads); - } + } + } + if (!ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf)) { + return nullptr; + } + auto it = lctx.canceled_batches.find(batch.batch_id); + if (it == lctx.canceled_batches.end() || !lctx.canceled_batches[batch.batch_id]) { +#endif + + // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + + // for big prompts, if BLAS is enabled, it is better to use only one thread + // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance + // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well + // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering + // with the BLAS calls. need a better solution + if (n_tokens >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) { + n_threads = std::min(4, n_threads); + } // If all tensors can be run on the GPU then using more than 1 thread is detrimental. const bool full_offload_supported = @@ -5560,108 +5646,226 @@ static int llama_decode_internal( model.arch == LLM_ARCH_STARCODER || model.arch == LLM_ARCH_STABLELM; - const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; - if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { - n_threads = 1; - } - -#if GGML_USE_MPI - const int64_t n_layer = hparams.n_layer; - ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); -#endif + const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; + if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { + n_threads = 1; + } #ifdef GGML_USE_METAL - if (lctx.ctx_metal) { - ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, gf); - } else { - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); - } + if (lctx.ctx_metal) { + ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); + ggml_metal_graph_compute(lctx.ctx_metal, gf); + } else { + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + } #else - ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); + auto abort_callback = [&lctx, &batch](int ithread, void * data) -> bool { + if (data != nullptr && *((std::atomic_bool*)data)) { +// printf("\nAborting because already have aborted\n"); + return true; + } + if (ithread == 0 && ggml_mpi_iprobe(lctx.ctx_mpi, ggml_mpi_next_node(lctx.ctx_mpi), GGML_MPI_CANCEL_RUN)) { + int count = ggml_mpi_status_count_int32(lctx.ctx_mpi); +// printf("\nReceived async cancel run, count of %d\n", count); + { + std::vector canceled(count, -1); + llama_cancel_run(&lctx, canceled.data(), canceled.size()); + + } + auto it = lctx.canceled_batches.find(batch.batch_id); + if (it != lctx.canceled_batches.end() && lctx.canceled_batches[batch.batch_id]) { + if (data != nullptr) { + *((std::atomic_bool *) data) = true; + } + return true; + } + } + return false; + }; + + auto* aborted = new std::atomic_bool(false); + + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads, abort_callback, aborted); + + delete aborted; #endif +// update the kv ring buffer #if GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); -#endif + } - // update the kv ring buffer - { - if (kv_self.has_shift) { - kv_self.has_shift = false; - for (uint32_t i = 0; i < kv_self.size; ++i) { - kv_self.cells[i].delta = 0; + { + if 
(kv_self.has_shift) { + kv_self.has_shift = false; + for (uint32_t i = 0; i < kv_self.size; ++i) { + kv_self.cells[i].delta = 0; + } } - } - kv_self.head += n_tokens; + kv_self.head += n_tokens; - // Ensure kv cache head points to a valid index. - if (kv_self.head >= kv_self.size) { - kv_self.head = 0; + // Ensure kv cache head points to a valid index. + if (kv_self.head >= kv_self.size) { + kv_self.head = 0; + } } - } + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf); +#endif + + #ifdef GGML_PERF - // print timing information per ggml operation (for debugging purposes) - // requires GGML_PERF to be defined - ggml_graph_print(gf); + // print timing information per ggml operation (for debugging purposes) + // requires GGML_PERF to be defined + ggml_graph_print(gf); #endif - // plot the computation graph in dot format (for debugging purposes) - //if (n_past%100 == 0) { - // ggml_graph_dump_dot(gf, NULL, "llama.dot"); - //} + return gf; - // extract logits - // TODO: do not compute and extract logits if only embeddings are needed - // need to update the graphs to skip "result_output" - { + } else if (phase == 1) { + ggml_cgraph * gf = cgraph; + struct ggml_tensor *res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor *embeddings = gf->nodes[gf->n_nodes - 2]; + + // Resize logits auto & logits_out = lctx.logits; + { - if (batch.logits) { - logits_out.resize(n_vocab * n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - if (batch.logits[i] == 0) { - continue; + + if (batch.logits || lctx.logits_all) { + logits_out.resize(n_vocab * n_tokens); + } else { + logits_out.resize(n_vocab); + } + } + +#ifdef GGML_USE_MPI + if (ggml_mpi_size(lctx.ctx_mpi) > 1 && ggml_mpi_rank(lctx.ctx_mpi) == 0) { + // TODO print logits array for comparison + ggml_mpi_recv_float_array(lctx.ctx_mpi, logits_out.data(), (batch.logits || lctx.logits_all) ? 
n_vocab * n_tokens : n_vocab, ggml_mpi_size(lctx.ctx_mpi) - 1, GGML_MPI_SYNC_LOGITS); +// printf("\nReceived %zu logits, logits_out.size = %zu\n", n_vocab * n_tokens, logits_out.size()); +// printf("batch: %s\n", LOG_BATCH_TOSTR_PRETTY(&lctx, batch).c_str()); +// for (auto logit : logits_out) { +// printf("%f, ", logit); +// } +// printf("]\n"); + } + + if (ggml_mpi_rank(lctx.ctx_mpi) == ggml_mpi_size(lctx.ctx_mpi) - 1) { + +#endif + + auto * net_output = (float *) ggml_get_data(res); + + + // plot the computation graph in dot format (for debugging purposes) + //if (n_past%100 == 0) { + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); + //} + + // extract logits + // TODO: do not compute and extract logits if only embeddings are needed + // need to update the graphs to skip "result_output" + { + + if (batch.logits) { + for (uint32_t i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + memcpy(logits_out.data() + (n_vocab*i), net_output + (n_vocab*i), sizeof(float)*n_vocab); } - memcpy(logits_out.data() + (n_vocab*i), (float *) ggml_get_data(res) + (n_vocab*i), sizeof(float)*n_vocab); + } else if (lctx.logits_all) { + memcpy(logits_out.data(), net_output, sizeof(float)*n_vocab*n_tokens); + } else { + memcpy(logits_out.data(), net_output + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } - } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); - memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*n_tokens); - } else { - logits_out.resize(n_vocab); - memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(n_tokens - 1)), sizeof(float)*n_vocab); } - } - // extract embeddings - if (!lctx.embedding.empty()) { - auto & embedding_out = lctx.embedding; + // extract embeddings + if (!lctx.embedding.empty()) { + auto & embedding_out = lctx.embedding; - embedding_out.resize(n_embd); - memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); - } + embedding_out.resize(n_embd); + memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(n_tokens - 1)), sizeof(float)*n_embd); + } + +#ifdef GGML_USE_MPI + } + if (ggml_mpi_size(lctx.ctx_mpi) > 1 && ggml_mpi_rank(lctx.ctx_mpi) == ggml_mpi_size(lctx.ctx_mpi) - 1) { +// printf("\nSent %zu logits, logits_out.size = %zu\nbatch: %s\n", n_vocab * n_tokens, logits_out.size(), LOG_BATCH_TOSTR_PRETTY(&lctx, batch).c_str()); + ggml_mpi_send_float_array_async(lctx.ctx_mpi, logits_out.data(), (batch.logits || lctx.logits_all) ? 
n_vocab * n_tokens : n_vocab, 0, GGML_MPI_SYNC_LOGITS); +// llama_kv_cache_view view = llama_kv_cache_view_init(&lctx, 21); +// llama_kv_cache_view_update(&lctx, &view); +// printf("Cache view:\n%s\n", dump_kv_cache_view_seqs(view, 1).c_str()); + } +#endif + + // measure the performance only for the single-token evals + if (n_tokens == 1) { + lctx.t_eval_us += ggml_time_us() - t_start_us; + lctx.n_eval++; + } + else if (n_tokens > 1) { + lctx.t_p_eval_us += ggml_time_us() - t_start_us; + lctx.n_p_eval += n_tokens; + } + + // get a more accurate load time, upon first eval + // TODO: fix this + if (!lctx.has_evaluated_once) { + lctx.t_load_us = ggml_time_us() - lctx.t_start_us; + lctx.has_evaluated_once = true; + } + return cgraph; - // measure the performance only for the single-token evals - if (n_tokens == 1) { - lctx.t_eval_us += ggml_time_us() - t_start_us; - lctx.n_eval++; } - else if (n_tokens > 1) { - lctx.t_p_eval_us += ggml_time_us() - t_start_us; - lctx.n_p_eval += n_tokens; + return nullptr; +} + +// decode a batch of tokens by evaluating the transformer +// +// - lctx: llama context +// - batch: batch to evaluate +// +// return 0 on success +// return positive int on warning +// return negative int on error +// +static int llama_decode_internal( + llama_context & lctx, + llama_batch & batch) { + struct ggml_cgraph * gf = llama_decode_internal_phased(lctx, batch, 0, nullptr); + if (gf != nullptr) { + return llama_decode_internal_phased(lctx, batch, 1, gf) == nullptr; + } else { + //printf("Graph is null\n"); + return -1; } +} + +struct ggml_cgraph * llama_start_async_decode( + llama_context & lctx, + llama_batch batch) { + return llama_decode_internal_phased(lctx, batch, 0, nullptr); + +} + +int llama_finish_async_decode( + struct llama_context & lctx, + struct llama_batch batch, + struct ggml_cgraph * cgraph) { + + int ret; + if (cgraph != nullptr) { - // get a more accurate load time, upon first eval - // TODO: fix this - if (!lctx.has_evaluated_once) { - lctx.t_load_us = ggml_time_us() - lctx.t_start_us; - lctx.has_evaluated_once = true; + ret = llama_decode_internal_phased(lctx, batch, 1, cgraph) != nullptr; + } else { + ret = -1; } - return 0; + return ret; + } // @@ -8412,6 +8616,14 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { return result; } +int llama_node_id(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + return ggml_mpi_rank(ctx->ctx_mpi); + +#endif + return 0; +} + int llama_max_devices(void) { return LLAMA_MAX_DEVICES; } @@ -8675,22 +8887,58 @@ struct llama_context * llama_new_context_with_model( #ifdef GGML_USE_MPI ctx->ctx_mpi = ggml_mpi_init(); + ctx->ctx_mpi_orig = ctx->ctx_mpi; + - if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - // TODO: needs fix after #3228 - GGML_ASSERT(false && "not implemented"); - //const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); - //while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; - llama_backend_free(); - exit(1); - } #endif return ctx; } +void llama_sync_token(struct llama_context * ctx, llama_token * token, int root) { +#ifdef GGML_USE_MPI + ggml_mpi_synch_int(ctx->ctx_mpi, token, root); +#endif +} + +void llama_sync_token_data(struct llama_context * ctx, llama_token_data * data, int root) { +#ifdef GGML_USE_MPI + ggml_mpi_synch_int(ctx->ctx_mpi, &(data->id), root); + ggml_mpi_synch_float(ctx->ctx_mpi, &(data->logit), root); + ggml_mpi_synch_float(ctx->ctx_mpi, &(data->p), root); +#endif +} 
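llama_decode_internal() is now split into a phase 0 that synchronizes the batch, computes this node's slice of the graph and forwards the activations, and a phase 1 that collects the logits; llama_start_async_decode() and llama_finish_async_decode() expose the two phases so a rank-0 driver can overlap work with the rest of the pipeline. A rough sketch of one way a driver might use them (the do_other_work callback is a placeholder and error handling is omitted):

```cpp
#include "llama.h"

// Rough sketch of a rank-0 driver overlapping a decode with other work.
static int decode_overlapped(llama_context * ctx, llama_batch batch,
                             void (*do_other_work)(void)) {
    // Phase 0: sync the batch to the worker ranks, compute this node's slice
    // of the graph and forward the activations down the pipeline.
    ggml_cgraph * gf = llama_start_async_decode(*ctx, batch);
    if (gf == nullptr) {
        return -1;
    }

    // While the later pipeline stages are busy, rank 0 is free to do something
    // useful, polling for the logits coming back from the last stage.
    while (!llama_mpi_iprobe(ctx)) {
        do_other_work();
    }

    // Phase 1: receive the logits and store them in the context, just as a
    // blocking llama_decode() would have done.
    return llama_finish_async_decode(*ctx, batch, gf);
}
```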
+ +void llama_swap_comm(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + ggml_mpi_context * temp = ctx->ctx_mpi; + ctx->ctx_mpi = ctx->ctx_mpi_orig; + ctx->ctx_mpi_orig = temp; +#endif +} + +void llama_split_comm(struct llama_context * ctx, int color) { +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_split_comm(ctx->ctx_mpi, color, ggml_mpi_rank(ctx->ctx_mpi)); +#endif +} + +void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->ctx_mpi) == 0 && ggml_mpi_size(ctx->ctx_mpi) != num_weights) { + GGML_ASSERT(false && "Must have same number of split percentages as devices"); + } + uint16_t** ranges = ggml_mpi_split_range(ctx->ctx_mpi, 0, ctx->model.hparams.n_layer - 1, device_weights); + ggml_mpi_scatter_layers(ctx->ctx_mpi, ranges); + free(ranges); +#endif +} + void llama_free(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + ggml_mpi_free(ctx->ctx_mpi); + ggml_mpi_free(ctx->ctx_mpi_orig); +#endif delete ctx; } @@ -8844,6 +9092,68 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } } +std::string dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) { + static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + std::stringstream dumped; + + dumped << "=== Dumping KV cache. total cells " + << view.n_cells + << ", max sequences per cell " + << view.n_max_seq + <<", populated cells " + << view.used_cells + << ", total tokens in cache " + << view.token_count + << ", largest empty slot=" + << view.max_contiguous + << "@ " + << view.max_contiguous_idx + << '\n'; + + std::unordered_map seqs; + llama_kv_cache_view_cell * c_curr = view.cells; + llama_seq_id * cs_curr = view.cells_sequences; + + for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { + for (int j = 0; j < view.n_max_seq; j++) { + if (cs_curr[j] < 0) { continue; } + if (seqs.find(cs_curr[j]) == seqs.end()) { + if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } + seqs[cs_curr[j]] = cs_curr[j]; + } + } + if (seqs.size() + 1 >= sizeof(slot_chars)) { break; } + } + + dumped << "=== Sequence legend: "; + for (const auto & it : seqs) { + dumped << slot_chars[it.second] << "=" << it.first << ", "; + } + dumped << "'+'=other sequence ids"; + + c_curr = view.cells; + cs_curr = view.cells_sequences; + for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) { + if (i % row_size == 0) { + dumped << "\n" << i << ": "; + } + for (int j = 0; j < view.n_max_seq; j++) { + if (cs_curr[j] >= 0) { + const auto & it = seqs.find(cs_curr[j]); + dumped << ((it != seqs.end()) ? 
slot_chars[it->second] : '+'); + } else { + dumped << '.'; + } + } + dumped << " (" << c_curr->pos << ") "; + //putchar(' '); + } + + dumped << "\n=== Done dumping\n"; + return dumped.str(); +} + void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { view->n_cells = int32_t(ctx->kv_self.size); @@ -8923,25 +9233,121 @@ int llama_get_kv_cache_used_cells(const struct llama_context * ctx) { } void llama_kv_cache_clear(struct llama_context * ctx) { +#ifdef GGML_USE_MPI + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, NULL, 0, 1); +#endif llama_kv_cache_clear(ctx->kv_self); } void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->ctx_mpi) == 0 && ggml_mpi_size(ctx->ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_RM; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + int32_t vals[3] = {seq_id, p0, p1}; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, vals, 3, GGML_MPI_KV_SEQ_RM); + seq_id = vals[0]; + p0 = vals[1]; + p1 = vals[2]; +// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1 && ggml_mpi_size(ctx->ctx_mpi) > 1) { +// printf("\nRemoving sequence %d from %d to %d\n", seq_id, p0, p1); +// } +#endif llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); } +void llama_kv_cache_seq_cp_sync_bi(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + + int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1}; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP); + ggml_mpi_sync_ints_pipelined_back(ctx->ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP_BACK); + ggml_mpi_inc_trans_id(ctx->ctx_mpi); + seq_id_src = vals[0]; + seq_id_dst = vals[1]; + p0 = vals[2]; + p1 = vals[3]; +#endif + if (seq_id_src == seq_id_dst) { + return; + } + + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->ctx_mpi) == 0 && ggml_mpi_size(ctx->ctx_mpi) > 1) { + int transaction_type = GGML_MPI_KV_SEQ_CP; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, &transaction_type, 1, GGML_MPI_BEGIN_TRANSACTION); + } + + int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1}; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP); + if(ggml_mpi_recv_trans_id(ctx->ctx_mpi) < ggml_mpi_trans_id(ctx->ctx_mpi)) { +// return; + } + ggml_mpi_inc_trans_id(ctx->ctx_mpi); + seq_id_src = vals[0]; + seq_id_dst = vals[1]; + p0 = vals[2]; + p1 = vals[3]; +// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1 && ggml_mpi_size(ctx->ctx_mpi) > 1) { +// printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1); +// } +#endif + if (seq_id_src == seq_id_dst) { + return; + } + + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_seq_cp_back(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { +#ifdef GGML_USE_MPI + int32_t vals[4] = {seq_id_src, seq_id_dst, p0, p1}; + ggml_mpi_sync_ints_pipelined_back(ctx->ctx_mpi, vals, 4, GGML_MPI_KV_SEQ_CP_BACK); + if(ggml_mpi_recv_trans_id(ctx->ctx_mpi) < ggml_mpi_trans_id(ctx->ctx_mpi)) { +// 
return; + } + ggml_mpi_inc_trans_id(ctx->ctx_mpi); + seq_id_src = vals[0]; + seq_id_dst = vals[1]; + p0 = vals[2]; + p1 = vals[3]; + +// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1 && ggml_mpi_size(ctx->ctx_mpi) > 1) { +// printf("\nCopying sequence %d to sequence %d from %d to %d\n", seq_id_src, seq_id_dst, p0, p1); +// } + +#endif if (seq_id_src == seq_id_dst) { return; } + llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +#ifdef GGML_USE_MPI + int32_t vals[1] = {seq_id}; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, vals, 1, 4); + seq_id = vals[0]; +#endif llama_kv_cache_seq_keep(ctx->kv_self, seq_id); } void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { +#ifdef GGML_USE_MPI + int32_t vals[4] = {seq_id, p0, p1, delta}; + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, vals, 4, 5); + seq_id = vals[0]; + p0 = vals[1]; + p1 = vals[2]; + delta = vals[3]; +#endif + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } @@ -9374,9 +9780,26 @@ int llama_eval( llama_token * tokens, int32_t n_tokens, int n_past) { - llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); - const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); + +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const int n_ctx = llama_n_ctx(ctx); + std::vector tmp(n_ctx, llama_token_bos(&(ctx->model))); + llama_batch tmp_batch = llama_batch_get_one(tmp.data(), tmp.size(), n_past, 0); + do { + //ggml_mpi_synch_int(ctx->ctx_mpi, &n_past); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); + } while (llama_decode_internal(*ctx, tmp_batch) >= 0); + llama_backend_free(); + exit(1); + } +#endif + + llama_batch tmp_batch = llama_batch_get_one(tokens, n_tokens, n_past, 0); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); + const int ret = llama_decode_internal(*ctx, tmp_batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } @@ -9391,7 +9814,7 @@ int llama_eval_embd( int n_past) { llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); - llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, 0,}; const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { @@ -9422,11 +9845,12 @@ struct llama_batch llama_batch_get_one( /*all_pos_0 =*/ pos_0, /*all_pos_1 =*/ 1, /*all_seq_id =*/ seq_id, + 0, 1 }; } struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_seq_max) { - llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, }; + llama_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, 0, n_seq_max}; if (embd) { batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd); @@ -9445,7 +9869,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens, int32_t embd, int32_t n_se return batch; } -void llama_batch_free(struct llama_batch batch) { +void llama_batch_free(struct llama_batch & batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); if (batch.pos) free(batch.pos); @@ -9457,13 +9881,128 @@ void llama_batch_free(struct llama_batch batch) { free(batch.seq_id); } if (batch.logits) free(batch.logits); + + 
batch.token = nullptr; + batch.embd = nullptr; + batch.pos = nullptr; + batch.n_seq_id = nullptr; + batch.seq_id = nullptr; + batch.logits = nullptr; +} + +#ifdef GGML_USE_MPI + +int llama_process_mpi_transaction( + struct llama_context * ctx, + struct llama_batch & batch, + int tag) { +// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) { +// printf("\nBeginning transaction type %d\n", tag); +// } + + switch (tag) { + case GGML_MPI_DECODE: +// llama_batch_free(batch); + return llama_decode_internal(*ctx, batch); + break; + case GGML_MPI_KV_CLEAR: + llama_kv_cache_clear(ctx); + break; + case GGML_MPI_KV_SEQ_RM: + llama_kv_cache_seq_rm(ctx, 1, -1, -1); + break; + case GGML_MPI_KV_SEQ_CP: + llama_kv_cache_seq_cp(ctx, 0, 0, 0, 0); + break; + case GGML_MPI_KV_SEQ_CP_BACK: + llama_kv_cache_seq_cp_back(ctx, 0, 0, 0, 0); + break; + case GGML_MPI_KV_SEQ_KEEP: + llama_kv_cache_seq_keep(ctx, 0); + break; + case GGML_MPI_KV_SEQ_SHIFT: + llama_kv_cache_seq_shift(ctx, 0, 0, 0, 0); + break; + default: + printf("Unknown operation, exiting\n"); + exit(1); + break; + } + return 0; +} + +int llama_process_mpi_worker( + struct llama_context * ctx, + struct llama_batch & batch) { + ggml_mpi_probe(ctx->ctx_mpi, -1, -1); + int tag = ggml_mpi_status_tag(ctx->ctx_mpi); + int32_t count; + int32_t trans_type; +// if (ggml_mpi_rank(ctx->ctx_mpi) == ggml_mpi_size(ctx->ctx_mpi) - 1) { +// printf("\nReceived command %d\n", tag); +// } + switch (tag) { + case GGML_MPI_BEGIN_TRANSACTION: + + ggml_mpi_sync_ints_pipelined(ctx->ctx_mpi, &trans_type, 1, GGML_MPI_BEGIN_TRANSACTION); + return llama_process_mpi_transaction(ctx, batch, trans_type); + break; + case GGML_MPI_SHUTDOWN: + llama_free(ctx); + llama_backend_free(); + exit(0); + break; + case GGML_MPI_CANCEL_RUN: + count = ggml_mpi_status_count_int32(ctx->ctx_mpi); +// printf("Received cancel run\n"); + { + std::vector canceled(count, -1); + llama_cancel_run(ctx, canceled.data(), canceled.size()); + + } + break; + default: + printf("Unknown operation, exiting\n"); + exit(1); + break; + } + return 0; +} + +#endif + +void llama_cancel_run(struct llama_context * ctx, int32_t * canceled, int count) { + ggml_mpi_sync_ints_pipelined_back(ctx->ctx_mpi, canceled, count, GGML_MPI_CANCEL_RUN); + for (int i = 0; i < count; i++) { + int32_t run_id = canceled[i]; + auto it = ctx->canceled_batches.find(run_id); + if (it != ctx->canceled_batches.end()) { + it->second = true; + } else { + ctx->canceled_batches.insert({run_id, true}); +// ctx->canceled_batches[run_id] = true; + } + } } int llama_decode( struct llama_context * ctx, struct llama_batch batch) { + +#ifdef GGML_USE_MPI + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const int n_ctx = llama_n_ctx(ctx); + std::vector tmp(n_ctx, llama_token_bos(&(ctx->model))); + while (llama_process_mpi_worker(ctx, batch) >= 0){} + llama_backend_free(); + exit(1); + } else if (ggml_mpi_rank(ctx->ctx_mpi) < 0) { + return 0; + } +#endif const int ret = llama_decode_internal(*ctx, batch); - if (ret < 0) { + if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } diff --git a/llama.h b/llama.h index 89cb6198e84..7ccbdf03064 100644 --- a/llama.h +++ b/llama.h @@ -12,6 +12,7 @@ #include #include #include +#include #ifdef LLAMA_SHARED # if defined(_WIN32) && !defined(__MINGW32__) @@ -156,6 +157,8 @@ extern "C" { llama_pos all_pos_0; // used if pos == NULL llama_pos all_pos_1; // used if pos == NULL llama_seq_id all_seq_id; // 
used if seq_id == NULL + int32_t batch_id; + int32_t max_n_seq; } llama_batch; struct llama_model_params { @@ -273,6 +276,23 @@ extern "C" { const char * path_model, struct llama_model_params params); + LLAMA_API void llama_split_layers_weighted(struct llama_context * ctx, float device_weights[], size_t num_weights); + + LLAMA_API void llama_swap_comm(struct llama_context * ctx); + + LLAMA_API void llama_sync_token(struct llama_context * ctx, llama_token * token, int root); + + LLAMA_API struct ggml_cgraph * llama_start_async_decode(struct llama_context & lctx, + struct llama_batch batch); + + LLAMA_API int llama_finish_async_decode(struct llama_context & lctx, + struct llama_batch batch, + struct ggml_cgraph * cgraph); + + LLAMA_API void llama_sync_token_data(struct llama_context * ctx, llama_token_data * data, int root); + + LLAMA_API void llama_split_comm(struct llama_context * ctx, int color); + LLAMA_API void llama_free_model(struct llama_model * model); LLAMA_API struct llama_context * llama_new_context_with_model( @@ -284,6 +304,14 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); + // Get the ID of this compute node, usually 0 + // unless running MPI, in which case it is the rank of the node + LLAMA_API int llama_node_id(struct llama_context * ctx); + + LLAMA_API bool llama_mpi_iprobe(struct llama_context * lctx); + + LLAMA_API void llama_cancel_run(struct llama_context * ctx, int32_t * canceled, int count); + LLAMA_API int llama_max_devices (void); LLAMA_API bool llama_mmap_supported (void); LLAMA_API bool llama_mlock_supported(void); @@ -441,6 +469,20 @@ extern "C" { llama_pos p0, llama_pos p1); +LLAMA_API void llama_kv_cache_seq_cp_back( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + +LLAMA_API void llama_kv_cache_seq_cp_sync_bi( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1); + // Removes all tokens that do not belong to the specified sequence LLAMA_API void llama_kv_cache_seq_keep( struct llama_context * ctx, @@ -540,7 +582,7 @@ extern "C" { int32_t n_seq_max); // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch batch); + LLAMA_API void llama_batch_free(struct llama_batch & batch); // Positive return values does not mean a fatal error, but rather a warning. // 0 - success @@ -625,7 +667,10 @@ extern "C" { llama_token token, char * buf, int length); - +extern "C++" { +// Dump the KV cache view showing individual sequences in each cell (long output). +std::string dump_kv_cache_view_seqs(const llama_kv_cache_view &view, int row_size = 40); +} // // Grammar //
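Taken together, the new llama.h entry points keep the driver-side surface small: every rank loads the model, scatters the layer split once, and then calls llama_decode(); ranks above 0 never return from that call because they drop into the MPI worker loop, while rank 0 keeps driving decodes as usual. A minimal sketch under those assumptions (the model path, the 50/50 two-rank split and the single-token batch are placeholders):

```cpp
#include <cstdio>
#include <vector>

#include "llama.h"

int main(void) {
    llama_backend_init(/*numa =*/ false);

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("model.gguf", mparams); // placeholder path

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // Collective call: every rank participates; rank 0 scatters the per-rank
    // layer ranges derived from the weights (here: an even split of 2 ranks).
    std::vector<float> weights = {0.5f, 0.5f}; // placeholder split
    llama_split_layers_weighted(ctx, weights.data(), weights.size());

    // Build a trivial batch; real code would tokenize a prompt here.
    llama_token bos = llama_token_bos(model);
    llama_batch batch = llama_batch_get_one(&bos, 1, 0, 0);

    // On ranks > 0 this call never returns: they enter the MPI worker loop
    // and service transactions started by rank 0.
    llama_decode(ctx, batch);

    if (llama_node_id(ctx) == 0) {
        const float * logits = llama_get_logits(ctx);
        printf("first logit: %f\n", logits[0]);
    }

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```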