From 602c7abf9d5696067cb7ff7cf04c04d630aa6f4b Mon Sep 17 00:00:00 2001 From: Simon Waters Date: Sat, 31 Jan 2026 22:46:12 +0000 Subject: [PATCH] [TT] Add support for common runtime args - All uniform params broadcast as common runtime args - Only persistent start,end passed as non-uniform core runtime args --- plugins/tenstorrent/tt_command.cpp | 15 +++++++++------ plugins/tenstorrent/tt_library.cpp | 6 ++++++ plugins/tenstorrent/tt_library.h | 1 + 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/plugins/tenstorrent/tt_command.cpp b/plugins/tenstorrent/tt_command.cpp index 07c2767..67a79dd 100644 --- a/plugins/tenstorrent/tt_command.cpp +++ b/plugins/tenstorrent/tt_command.cpp @@ -71,15 +71,18 @@ nxs_status TTCommand::runCommand(nxs_int stream, ttmd::MeshWorkload &workload, int persistent_grid_stride = std::max(1, total_grid_size / (int)cores.size()); NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Total grid size: ", total_grid_size, ", cores: ", cores.size(), ", persistent grid stride: ", persistent_grid_stride); + library->setupCommonRuntime(program, rt_args); + // set params int persistent_grid_idx = 0; for (const auto& core : cores) { - rt_args[numArgs] = persistent_grid_idx * persistent_grid_stride; - rt_args[numArgs+1] = persistent_grid_idx * persistent_grid_stride + persistent_grid_stride; - if (rt_args[numArgs+1] > total_grid_size) - rt_args[numArgs+1] = total_grid_size; - NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Launch params: grid_idx=", persistent_grid_idx, ", start=", rt_args[numArgs], ", end=", rt_args[numArgs+1]); - library->setupCoreRuntime(program, core, rt_args); + TTLibrary::RunTimeArgs core_rt_args; + core_rt_args[0] = persistent_grid_idx * persistent_grid_stride; + core_rt_args[1] = persistent_grid_idx * persistent_grid_stride + persistent_grid_stride; + if (core_rt_args[1] > total_grid_size) + core_rt_args[1] = total_grid_size; + NXSAPI_LOG(nexus::NXS_LOG_NOTE, "Launch params: grid_idx=", persistent_grid_idx, ", start=", core_rt_args[0], ", end=", core_rt_args[1]); + library->setupCoreRuntime(program, core, core_rt_args); persistent_grid_idx++; } diff --git a/plugins/tenstorrent/tt_library.cpp b/plugins/tenstorrent/tt_library.cpp index 3df30f7..b66f5f4 100644 --- a/plugins/tenstorrent/tt_library.cpp +++ b/plugins/tenstorrent/tt_library.cpp @@ -22,6 +22,12 @@ void TTLibrary::jitProgram(ttm::Program &program, const ttm::CoreRange &cores, c ttm::ComputeConfig{.math_fidelity = MathFidelity::HiFi4, .compile_args = compile_time_args}); } +void TTLibrary::setupCommonRuntime(ttm::Program &program, const RunTimeArgs &run_time_args) { + TT_CHECK(ttm::SetCommonRuntimeArgs, program, reader_kernel, run_time_args); + TT_CHECK(ttm::SetCommonRuntimeArgs, program, writer_kernel, run_time_args); + TT_CHECK(ttm::SetCommonRuntimeArgs, program, compute_kernel, run_time_args); +} + void TTLibrary::setupCoreRuntime(ttm::Program &program, const ttm::CoreCoord &core, const RunTimeArgs &run_time_args) { TT_CHECK(ttm::SetRuntimeArgs, program, reader_kernel, core, run_time_args); TT_CHECK(ttm::SetRuntimeArgs, program, writer_kernel, core, run_time_args); diff --git a/plugins/tenstorrent/tt_library.h b/plugins/tenstorrent/tt_library.h index 7204ac3..2fa3085 100644 --- a/plugins/tenstorrent/tt_library.h +++ b/plugins/tenstorrent/tt_library.h @@ -41,6 +41,7 @@ class TTLibrary { typedef std::array RunTimeArgs; void jitProgram(ttm::Program &program, const ttm::CoreRange &cores, const CompileTimeArgs &compile_time_args); + void setupCommonRuntime(ttm::Program &program, const RunTimeArgs &run_time_args); void setupCoreRuntime(ttm::Program &program, const ttm::CoreCoord &core, const RunTimeArgs &run_time_args); };