From ae94f8ec9413228e3b62d18bc911e9f2b1e6ba6c Mon Sep 17 00:00:00 2001 From: jorivesga Date: Mon, 17 Nov 2025 06:05:52 +0200 Subject: [PATCH 1/6] Refm life science SwinUNETR Inference (#459) Add helm charts for the REFM life-science SwinUNETR inference service usecase --- .pre-commit-config.yaml | 8 +- .../examples/demo_inference_service.ipynb | 376 ++++++++++++++++++ .../examples/utils.py | 262 ++++++++++++ .../helm/Chart.yaml | 4 + .../helm/README.md | 68 ++++ .../helm/mount/README.md | 3 + .../helm/mount/data_utils.py | 39 ++ .../helm/mount/entrypoint.sh | 11 + .../helm/mount/inference_service.py | 223 +++++++++++ .../helm/mount/requirements.txt | 11 + .../helm/mount/swinunetr.py | 121 ++++++ .../helm/mount/swinunetr_configuration.py | 94 +++++ .../helm/overrides/kaiwo/kaiwo-enable.yaml | 3 + .../helm/templates/_helpers.tpl | 72 ++++ .../helm/templates/configmap.yaml | 11 + .../helm/templates/deployment.yaml | 96 +++++ .../helm/templates/service.yaml | 22 + .../helm/values.schema.json | 221 ++++++++++ .../helm/values.yaml | 55 +++ 19 files changed, 1696 insertions(+), 4 deletions(-) create mode 100644 workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb create mode 100644 workloads/dev-lifescience-swinunetr-inference/examples/utils.py create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/README.md create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/requirements.txt create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr.py 
create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr_configuration.py create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json create mode 100644 workloads/dev-lifescience-swinunetr-inference/helm/values.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39675cd..8240e42 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,26 +17,26 @@ repos: hooks: - id: black language_version: python3.12 - args: ["--config=pyproject.toml"] + args: [ "--config=pyproject.toml" ] - repo: https://github.com/pycqa/flake8 rev: 7.2.0 hooks: - id: flake8 - args: ["--config=.flake8"] + args: [ "--config=.flake8" ] - repo: https://github.com/pycqa/isort rev: 6.0.1 hooks: - id: isort name: isort (python) - args: ["--settings-path=pyproject.toml"] + args: [ "--settings-path=pyproject.toml" ] - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.16.0 hooks: - id: mypy - args: ["--config-file=pyproject.toml"] + args: [ "--config-file=pyproject.toml", "--install-types", "--non-interactive" ] exclude: kaiwo|mount language_version: python3.12 additional_dependencies: diff --git a/workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb b/workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb new file mode 100644 index 0000000..c3b15d0 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb @@ -0,0 +1,376 @@ +{ 
+ "cells": [ + { + "cell_type": "markdown", + "id": "df57a5ef-2a09-472b-ab6c-ddbbb6e7f7ce", + "metadata": {}, + "source": [ + "# Demo Inference Service\n", + "\n", + "Demo notebook for sending a prediction request to the deployed inference service, and visualizing the results." + ] + }, + { + "cell_type": "markdown", + "id": "96505e68-1c04-4d77-bc40-cd24bc323933", + "metadata": {}, + "source": [ + "## Requirements" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e18a2528-56e0-4552-a796-a6552892fd55", + "metadata": {}, + "outputs": [], + "source": "! pip install numpy==1.26.4 einops==0.4.1 scipy==1.10.1 connected-components-3d monai[nibabel,pillow,ignite,tqdm,pydicom]==1.5.0 synapseclient" + }, + { + "cell_type": "markdown", + "id": "edf47a9e-51f2-44e6-817d-3b1ff9f35e52", + "metadata": {}, + "source": [ + "## Import utility functions" + ] + }, + { + "cell_type": "code", + "id": "1d7f577c-abe3-4e39-9b04-52029eee283b", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:51:09.815819Z", + "start_time": "2025-08-04T07:50:53.040426Z" + } + }, + "source": "from utils import plot_input_scan, send_prediction_request, plot_results, plot_results_overlap", + "outputs": [], + "execution_count": 1 + }, + { + "cell_type": "markdown", + "id": "e449c4c2-bf4f-4e3f-88d6-9c89029dfc0c", + "metadata": {}, + "source": [ + "## Dataset\n", + "\n", + "Download a scan image from the [Multi-Atlas Labeling Beyond the Cranial Vault - Workshop and Challenge](https://www.synapse.org/Synapse:syn3193805/wiki/89480) dataset.\n", + "\n", + "You must create a Synapse account and get a Personal Access Token (PAT) to access the dataset.\n", + "\n", + "```\n", + "pip install synapseclient\n", + "\n", + "synapse get syn3553734\n", + "\n", + "unzip Abdomen.zip\n", + "```\n", + "\n", + "The data contains labels for 13 different organs. 
Check this link for additional details on the dataset: [Abdomen dataset](https://www.synapse.org/Synapse:syn3193805/wiki/217789).\n", + "\n", + "- (1) spleen\n", + "- (2) right kidney\n", + "- (3) left kidney\n", + "- (4) gallbladder\n", + "- (5) esophagus\n", + "- (6) liver\n", + "- (7) stomach\n", + "- (8) aorta\n", + "- (9) inferior vena cava\n", + "- (10) portal vein and splenic vein\n", + "- (11) pancreas\n", + "- (12) right adrenal gland\n", + "- (13) left adrenal gland" + ] + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Port forwarding to inference service", + "id": "50ff6cfe-1691-4f9a-8406-918d929b3a96" + }, + { + "cell_type": "markdown", + "id": "fd9d0267-fb50-466e-94ed-47689e98f078", + "metadata": {}, + "source": [ + "1. Open your terminal and run the following command:\n", + " \n", + " `kubectl port-forward -n <namespace> service/dev-lifescience-swinunetr-<release-name> 8000:80`\n", + "\n", + "2. Keep the Terminal Open: The port-forward command runs in the foreground. You must keep this terminal window open for the connection to remain active." 
+ ] + }, + { + "cell_type": "markdown", + "id": "f1aa4370-b34e-4957-88a2-f58e221c7188", + "metadata": {}, + "source": [ + "## Inference Parameters" + ] + }, + { + "cell_type": "code", + "id": "b4bc3921-c81e-4fa1-829f-d8674d9f793e", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:51:09.825745Z", + "start_time": "2025-08-04T07:51:09.821312Z" + } + }, + "source": [ + "# Path to the image file to send for prediction.\n", + "image_path = \"./Abdomen/RawData/Training/img/img0001.nii.gz\"\n", + "\n", + "# Ground truth label file for the image.\n", + "label_path = \"./Abdomen/RawData/Training/label/label0001.nii.gz\"\n", + "\n", + "# Filename to save the output prediction\n", + "output_path = \"client_prediction.nii.gz\"\n", + "\n", + "# URL of the prediction inference service endpoint\n", + "service_url = \"http://localhost:8000/predict/\"" + ], + "outputs": [], + "execution_count": 2 + }, + { + "cell_type": "markdown", + "id": "bf2b1839-c8b0-48f1-8575-8465946c087d", + "metadata": {}, + "source": [ + "## Input view" + ] + }, + { + "cell_type": "code", + "id": "28ecbd0b-a38b-4868-bad7-b6c6af735365", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:51:50.219603Z", + "start_time": "2025-08-04T07:51:45.747510Z" + } + }, + "source": [ + "plot_input_scan(input_path=image_path, num_slices_to_plot=3)" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:51:45,765 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:51:49,784 - INFO - Input data shape: torch.Size([229, 229, 220])\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 3 + }, + { + "cell_type": "markdown", + "id": "d3526241-feb0-4e9a-83c8-30267538f1bd", + "metadata": {}, + "source": [ + "## Send prediction request\n", + "\n", + "Send a CT scan to the inference service for prediction:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "6351f3f9-304a-48c5-8900-62ab27fd46e7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:utils:Attempting to send ../../data/inputs_nifti/LUNG1-001_input.nii.gz to http://localhost:8001/predict/\n", + "INFO:utils:Status Code: 200\n", + "INFO:utils:Success! Prediction saved to client_prediction.nii.gz\n" + ] + } + ], + "source": [ + "send_prediction_request(image_path=image_path, output_path=output_path, server_url=service_url)" + ] + }, + { + "cell_type": "markdown", + "id": "175b4943-b71a-421a-b821-7986fa9e6c82", + "metadata": {}, + "source": [ + "## Visualize results" + ] + }, + { + "cell_type": "code", + "id": "c6988e85-a0c3-4de0-a1c7-287fb43d13b2", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:53:00.108362Z", + "start_time": "2025-08-04T07:52:08.500029Z" + } + }, + "source": [ + "channel_idx = 6\n", + "\n", + "plot_results(image_path, label_path, output_path, channel_idx=channel_idx, num_slices_to_plot=3)" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:52:08,516 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:52:10,516 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input data shape: torch.Size([1, 229, 229, 220])\n", + "Label data shape: torch.Size([14, 229, 
229, 220])\n", + "Prediction data shape: torch.Size([14, 229, 229, 220])\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 4 + }, + { + "cell_type": "code", + "id": "dc1cc7f2-3d49-40e0-a0d8-19dae6ed4673", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:53:41.881454Z", + "start_time": "2025-08-04T07:53:00.124815Z" + } + }, + "source": [ + "slice_to_plot = 110\n", + "\n", + "plot_results_overlap(\n", + " image_path, label_path, output_path, channel_idx=channel_idx, slice_to_plot=slice_to_plot\n", + ")" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:53:00,148 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:53:02,367 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 5 + }, + { + "cell_type": "code", + "id": "e01e3447-02cf-4d3d-b9f0-9b32dd6a0ec7", + "metadata": { + "ExecuteTime": { + "end_time": "2025-08-04T07:54:25.689698Z", + "start_time": "2025-08-04T07:53:41.990150Z" + } + }, + "source": [ + "channel_idx = 6\n", + "slice_to_plot = 165\n", + "\n", + "plot_results_overlap(\n", + " image_path, label_path, output_path, channel_idx=channel_idx, slice_to_plot=slice_to_plot\n", + ")" + ], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:53:42,010 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n", + "2025-08-04 10:53:44,120 - INFO - pixdim[0] (qfac) should be 1 (default) or -1; setting qfac to 1\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/png": "" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 6 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "194cacfaaa52c3ca" + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/workloads/dev-lifescience-swinunetr-inference/examples/utils.py b/workloads/dev-lifescience-swinunetr-inference/examples/utils.py new file mode 100644 index 0000000..05eccc1 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/examples/utils.py @@ -0,0 +1,262 @@ +import logging +import os +from typing import List + +import matplotlib.pyplot as plt +import numpy as np +import requests +from monai.transforms import AsDiscreted, Compose, EnsureChannelFirstd, LoadImaged, Spacingd + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +LOGGER = logging.getLogger(__name__) + + +def send_prediction_request( + image_path: str, output_path: str = "prediction_output.nii.gz", server_url: str = "http://localhost:8000/predict/" +): + """ + Sends an image to the prediction server and saves the result. + + Args: + image_path (str): Path to the input image file (.nii, .nii.gz, or .npy). + output_path (str): Path to save the received prediction file. + server_url (str): URL of the inference service. 
+ """ + if not os.path.exists(image_path): + LOGGER.error(f"Error: Input image file not found at {image_path}") + return + + LOGGER.info(f"Attempting to send {image_path} to {server_url}") + try: + with open(image_path, "rb") as f: + files = {"file": (os.path.basename(image_path), f)} + response = requests.post(server_url, files=files, timeout=120) # 120-second timeout + + LOGGER.info(f"Status Code: {response.status_code}") + + if response.status_code == 200: + with open(output_path, "wb") as out_f: + out_f.write(response.content) + LOGGER.info(f"Success! Prediction saved to {output_path}") + else: + LOGGER.error("Error: Request failed.") + try: + error_detail = response.json() + LOGGER.error(f"Server Error Detail: {error_detail}") + except requests.exceptions.JSONDecodeError: + LOGGER.error(f"Server Error (non-JSON): {response.text}") + + except requests.exceptions.ConnectionError: + LOGGER.error(f"Error: Could not connect to the server at {server_url}. Is it running?") + except requests.exceptions.Timeout: + LOGGER.error("Error: The request timed out.") + except Exception as e: + LOGGER.error(f"An unexpected error occurred: {e}") + + +def load_and_transform_input(keys: List[str], data_dict: dict): + image_transforms = Compose( + [ + LoadImaged(keys=keys), + EnsureChannelFirstd(keys="input", channel_dim="no_channel"), + Spacingd(keys=keys, pixdim=(1.5, 1.5, 2), mode=("bilinear",)), + ] + ) + return image_transforms(data_dict) + + +def load_and_transform(keys: List[str], data_dict: dict): + """ + Loads and transforms data, converting the label to one-hot format. 
+ """ + num_classes = 14 + + image_transforms = Compose( + [ + LoadImaged(keys=keys), + EnsureChannelFirstd(keys=["input", "label"], channel_dim="no_channel"), + AsDiscreted(keys="label", to_onehot=num_classes), + Spacingd(keys=keys, pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest", "nearest")), + ] + ) + return image_transforms(data_dict) + + +def plot_results( + input_path_str: str, + label_path_str: str, + pred_path_str: str, + channel_idx: int, + num_slices_to_plot: int = 4, +): + """ + Plots slices of the input, prediction, and optionally label NIfTI images. + + Args: + input_path_str (str): Path to the original input NIfTI file. + label_path_str (str): Path to the label NIfTI file. + pred_path_str (str): Path to the prediction NIfTI file. + channel_idx (int): Channel/Label to extract from the multichannel label file. + num_slices_to_plot (int): Number of slices to display. + """ + try: + data_dict = {"input": input_path_str, "pred": pred_path_str, "label": label_path_str} + keys = ["input", "label", "pred"] + + processed_dict = load_and_transform(keys=keys, data_dict=data_dict) + + input_data = processed_dict["input"].squeeze() + pred_data = processed_dict["pred"][channel_idx, ...].squeeze() + label_data = processed_dict["label"][channel_idx, ...].squeeze() + + num_cols = 3 + depth = input_data.shape[2] + + slice_indices = np.linspace(depth // 4, 3 * depth // 4, num_slices_to_plot, dtype=int) + slice_indices = np.clip(np.unique(slice_indices), 0, depth - 1) + + fig, axes = plt.subplots(len(slice_indices), num_cols, figsize=(num_cols * 3, len(slice_indices) * 3)) + if len(slice_indices) == 1: + axes = np.array([axes]).reshape(1, -1) + + for i, slice_idx in enumerate(slice_indices): + axes[i, 0].imshow(np.rot90(input_data[:, :, slice_idx]), cmap="gray") + axes[i, 0].set_title(f"Input - Slice {slice_idx}") + axes[i, 0].axis("off") + + axes[i, 1].imshow(np.rot90(label_data[:, :, slice_idx]), cmap="viridis") + axes[i, 1].set_title(f"Label - Slice {slice_idx}") 
+ axes[i, 1].axis("off") + + axes[i, 2].imshow(np.rot90(pred_data[:, :, slice_idx]), cmap="viridis") + axes[i, 2].set_title(f"Prediction - Slice {slice_idx}") + axes[i, 2].axis("off") + + fig.suptitle(f"Input vs. Prediction (vs. Label) - {os.path.basename(input_path_str)}", fontsize=16) + plt.tight_layout(rect=[0, 0, 1, 0.96]) + plt.show() + plt.close(fig) + except Exception as e: + LOGGER.error(f"An error occurred during plotting: {e}", exc_info=True) + + +def plot_results_overlap( + input_path_str: str, + label_path_str: str, + pred_path_str: str, + channel_idx: int, + slice_to_plot: int, +): + """ + Plots a specific slice showing overlaps of input, prediction, and target label. + + This function generates a 1x3 plot for a single specified slice: + 1. Input slice with the target mask superimposed. + 2. Input slice with the prediction mask superimposed. + 3. Input slice with both target and prediction masks superimposed. + + Args: + input_path_str (str): Path to the original input NIfTI file. + label_path_str (str): Path to the label NIfTI file. + pred_path_str (str): Path to the prediction NIfTI file. + channel_idx (int): Labels to extract from the multichannel label file. + slice_to_plot (int): The specific slice index to visualize. 
+ """ + try: + data_dict = {"input": input_path_str, "pred": pred_path_str, "label": label_path_str} + keys = ["input", "label", "pred"] + + processed_dict = load_and_transform(keys=keys, data_dict=data_dict) + + input_data = processed_dict["input"].squeeze() + pred_data = processed_dict["pred"][channel_idx, ...].squeeze() + label_data = processed_dict["label"][channel_idx, ...].squeeze() + + # Validate slice index + if not (0 <= slice_to_plot < input_data.shape[2]): + raise ValueError(f"Slice index {slice_to_plot} is out of bounds for depth {input_data.shape[2]}.") + + input_slice = np.rot90(input_data[:, :, slice_to_plot]) + pred_slice = np.rot90(pred_data[:, :, slice_to_plot]) + label_slice = np.rot90(label_data[:, :, slice_to_plot]) + + # Use a masked array to only show the "on" pixels of the masks + pred_mask = np.ma.masked_where(pred_slice == 0, pred_slice) + label_mask = np.ma.masked_where(label_slice == 0, label_slice) + + fig, axes = plt.subplots(1, 3, figsize=(12, 4)) + + # Plot 1: Input with Target + axes[0].imshow(input_slice, cmap="gray") + axes[0].imshow(label_mask, cmap="autumn", alpha=0.5) # autumn is yellow-red + axes[0].set_title(f"Input + Target (Slice {slice_to_plot})") + axes[0].axis("off") + + # Plot 2: Input with Prediction + axes[1].imshow(input_slice, cmap="gray") + axes[1].imshow(pred_mask, cmap="cool", alpha=0.5) # cool is cyan-magenta + axes[1].set_title(f"Input + Prediction (Slice {slice_to_plot})") + axes[1].axis("off") + + # Plot 3: Input with Target and Prediction + axes[2].imshow(input_slice, cmap="gray") + axes[2].imshow(label_mask, cmap="autumn", alpha=0.7) + axes[2].imshow(pred_mask, cmap="cool", alpha=0.4) + axes[2].set_title("Input + Target (Red) + Pred (Blue)") + axes[2].axis("off") + + # --- 4. 
Finalize and Show --- + fig.suptitle(f"Overlap Visualization - {os.path.basename(input_path_str)}", fontsize=16) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.show() + plt.close(fig) + except Exception as e: + LOGGER.error(f"An error occurred during overlap plotting: {e}", exc_info=True) + + +def plot_input_scan(input_path: str, num_slices_to_plot: int = 4): + """ + Plots several axial slices of a single NIfTI input scan. + + Args: + input_path (str): Path to the original input NIfTI file. + num_slices_to_plot (int): Number of slices to display. + """ + try: + # 1. Load only the input data + data_dict = {"input": input_path} + processed_dict = load_and_transform_input(keys=["input"], data_dict=data_dict) + input_data = processed_dict["input"].squeeze() + LOGGER.info(f"Input data shape: {input_data.shape}") + depth = input_data.shape[2] + + # 2. Select representative slices to plot + slice_indices = np.linspace(depth // 4, 3 * depth // 4, num_slices_to_plot, dtype=int) + slice_indices = np.clip(np.unique(slice_indices), 0, depth - 1) + + if len(slice_indices) == 0: + LOGGER.error("Could not determine valid slice indices to plot.") + return + + # 3. 
Create the plot grid + num_cols = int(np.ceil(np.sqrt(len(slice_indices)))) + num_rows = int(np.ceil(len(slice_indices) / num_cols)) + + fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3)) + # Flatten the axes array to make it easy to iterate over + axes = axes.flatten() + for i, slice_idx in enumerate(slice_indices): + ax = axes[i] + ax.imshow(np.rot90(input_data[:, :, slice_idx]), cmap="gray") + ax.set_title(f"Slice {slice_idx}") + ax.axis("off") + + # Turn off any unused subplots in the grid + for i in range(len(slice_indices), len(axes)): + axes[i].axis("off") + + fig.suptitle(f"Input Scan: {os.path.basename(input_path)}", fontsize=16) + plt.tight_layout(rect=[0, 0, 1, 0.96]) + plt.show() + except Exception as e: + LOGGER.error(f"An error occurred during plotting: {e}", exc_info=True) diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml new file mode 100644 index 0000000..0218469 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml @@ -0,0 +1,4 @@ +apiVersion: v2 +name: dev-lifescience-swinunetr-inference +description: A Helm chart for SwinUNETR inference +version: 0.0.1 diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/README.md b/workloads/dev-lifescience-swinunetr-inference/helm/README.md new file mode 100644 index 0000000..6ae0348 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/README.md @@ -0,0 +1,68 @@ +# Life Science - SwinUNETR Inference + +This Helm Chart deploys a [SwinUNETR](https://arxiv.org/abs/2201.01266) model as an Inference Service, for multiorgan segmentation of 3D CT scans. + +## SwinUNETR model + +[SwinUNETR](https://arxiv.org/abs/2201.01266) is a deep learning architecture designed for medical image segmentation, particularly in 3D volumetric data +such as CT or MRI scans with the aim to detect tumors in the images. 
It combines the strengths of two powerful +models: 1. Swin Transformer - a hierarchical vision transformer that captures long-range dependencies and contextual +information efficiently and 2. UNETR (UNet with Transformers) - a transformer-based encoder-decoder architecture +tailored for medical image segmentation. + +Model weights are loaded from [HuggingFace](https://huggingface.co/darragh/swinunetr-btcv-base). + +Check out the [demo_inference_service.ipynb](../examples/demo_inference_service.ipynb) example to see how it works. + +## Data + +The training data are 3D CT scans from the [BTCV challenge dataset](https://www.synapse.org/#!Synapse:syn3193805/wiki/217752). The target segmentation includes 13 abdominal organs: + +1. Spleen +2. Right Kidney +3. Left Kidney +4. Gallbladder +5. Esophagus +6. Liver +7. Stomach +8. Aorta +9. IVC +10. Portal and Splenic Veins +11. Pancreas +12. Right adrenal gland +13. Left adrenal gland + + +## Prerequisites + +Ensure the following prerequisites are met before deploying any workloads: + +1. **Helm**: Install `helm`. Refer to the [Helm documentation](https://helm.sh/) for instructions. + +## Deploying the Workload + +It is recommended to use `helm template` and pipe the result to `kubectl apply`, rather than using `helm install`. Generally, a command looks as follows: + +```bash +helm template [your-release-name] ./helm | kubectl apply -f - +``` + +The chart provides three main ways to deploy models, detailed below. + +## User Input Values + +Refer to the `values.yaml` file for the user input values you can provide, along with instructions. + +## Interacting with Deployed Model + +### Verify Deployment + +Check the deployment status: + +```bash +kubectl get deployment +``` + +### Send prediction request + +Follow the [demo_inference_service.ipynb](../examples/demo_inference_service.ipynb) notebook to see how to use the inference service. 
diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md b/workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md new file mode 100644 index 0000000..75734b3 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md @@ -0,0 +1,3 @@ +Files in this directory are mounted to the workload at `/workload/mount`. + +**Note:** Subdirectories and binary files are not supported. diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py b/workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py new file mode 100644 index 0000000..976f9c4 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py @@ -0,0 +1,39 @@ +import logging + +from monai import transforms + +LOGGER = logging.getLogger(__name__) + +IMAGE_DATA = "image" +SPACING_MODE = "bilinear" +DEVICE = "cuda" + + +def get_transforms(args): + """Returns the transforms to be applied to the input data.""" + inference_transforms = transforms.Compose( + [ + transforms.LoadImaged(keys=[IMAGE_DATA]), + transforms.EnsureChannelFirstd(keys=[IMAGE_DATA], channel_dim="no_channel"), + transforms.Spacingd( + keys=[IMAGE_DATA], pixdim=(args.space_x, args.space_y, args.space_z), mode=SPACING_MODE + ), + transforms.ScaleIntensityRanged( + keys=[IMAGE_DATA], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True + ), + ] + ) + return inference_transforms + + +def get_post_transforms_inverter(forward_transforms_obj): + """Transform to revert all transforms previously applied.""" + return transforms.Invertd( + keys=IMAGE_DATA, + transform=forward_transforms_obj, + orig_keys=IMAGE_DATA, + orig_meta_keys=f"{IMAGE_DATA}_meta_dict", + nearest_interp=True, + to_tensor=True, + device=DEVICE, + ) diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh b/workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh new file mode 100644 index 
0000000..b3ec4a7 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh @@ -0,0 +1,11 @@ +#!/bin/bash + +# Exit immediately if a command exits with a non-zero status +set -e + +# Change directory to the location of this script +cd "$(dirname "$0")" + +pip install --no-cache-dir -r requirements.txt + +python inference_service.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py b/workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py new file mode 100644 index 0000000..d788664 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py @@ -0,0 +1,223 @@ +import argparse +import logging +import os +import tempfile +import traceback +from contextlib import asynccontextmanager, nullcontext + +import nibabel as nib +import numpy as np +import torch +import uvicorn +from data_utils import DEVICE, IMAGE_DATA, get_post_transforms_inverter, get_transforms +from fastapi import FastAPI, File, HTTPException, UploadFile +from fastapi.responses import FileResponse +from monai import transforms +from swinunetr import SwinUnetrModelForInference +from torch.cuda.amp import autocast + +LOGGER = logging.getLogger(__name__) + +# Read configuration from environment variables with fallbacks +ARGS = argparse.Namespace( + hf_model=os.environ.get("HF_MODEL", "darragh/swinunetr-btcv-base"), + roi_x=int(os.environ.get("ROI_X", "96")), + roi_y=int(os.environ.get("ROI_Y", "96")), + roi_z=int(os.environ.get("ROI_Z", "96")), + space_x=float(os.environ.get("SPACE_X", "1.5")), + space_y=float(os.environ.get("SPACE_Y", "1.5")), + space_z=float(os.environ.get("SPACE_Z", "2.0")), + a_min=float(os.environ.get("A_MIN", "-175.0")), + a_max=float(os.environ.get("A_MAX", "250.0")), + b_min=float(os.environ.get("B_MIN", "0.0")), + b_max=float(os.environ.get("B_MAX", "1.0")), + infer_overlap=float(os.environ.get("INFER_OVERLAP", "0.5")), + 
compile=os.environ.get("COMPILE", "false").lower() == "true", + compile_mode=os.environ.get("COMPILE_MODE", "max-autotune"), + autocast=os.environ.get("AUTOCAST", "false").lower() == "true", +) + +MODEL: SwinUnetrModelForInference | None = None +IMAGE_TRANSFORMS: transforms.Compose | None = None +POST_TRANSFORMS_INVERTER: transforms.Compose | None = None +AUTOCAST_CONTEXT = nullcontext() + + +def load_model(args): + """Loads the pre-trained model.""" + global MODEL, IMAGE_TRANSFORMS, POST_TRANSFORMS_INVERTER, AUTOCAST_CONTEXT + MODEL = SwinUnetrModelForInference.from_pretrained(args.hf_model) + + if args.compile: + LOGGER.info(f"Compiling model to {args.compile_mode}") + MODEL = torch.compile(MODEL, mode=args.compile_mode, dynamic=False) + LOGGER.info("Model compiled successfully.") + + # Set the model to evaluation mode + MODEL.eval() + MODEL.to(DEVICE) + LOGGER.info(f"Model loaded on {DEVICE}.") + + # Configure autocast context once during model loading + if args.autocast: + LOGGER.info("Configuring autocast for inference") + AUTOCAST_CONTEXT = autocast() + else: + AUTOCAST_CONTEXT = nullcontext() + + IMAGE_TRANSFORMS = get_transforms(args=args) + POST_TRANSFORMS_INVERTER = get_post_transforms_inverter(IMAGE_TRANSFORMS) + LOGGER.info("Image transform pipeline initialized.") + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Load model on server startup.""" + LOGGER.info("Server startup: Loading model...") + global MODEL, IMAGE_TRANSFORMS, POST_TRANSFORMS_INVERTER, AUTOCAST_CONTEXT + try: + load_model(ARGS) + except Exception as e: + LOGGER.error(f"Error loading model during startup: {e}") + LOGGER.error(f"Traceback: \n{traceback.format_exc()}") + yield + + MODEL = None + IMAGE_TRANSFORMS = None + POST_TRANSFORMS_INVERTER = None + AUTOCAST_CONTEXT = nullcontext() + + +app = FastAPI(title="SwinUNETR Inference Service", lifespan=lifespan) + + +@app.get("/health", status_code=200) +async def health_check(): + """ + Simple health check endpoint. 
@app.post("/predict/", response_class=FileResponse)
async def predict(file: UploadFile = File(...)):
    """
    Accepts an image file (NIfTI or .npy), performs inference,
    and returns the segmentation mask as a NIfTI file.

    Raises:
        HTTPException: 503 when the model or the transforms are not loaded,
            400 when the upload has no supported extension
            (.nii, .nii.gz, .npy), and 500 for any failure while the upload
            is being processed.
    """
    # Local import: only this endpoint needs it, to delete the result file
    # once the response has been sent.
    from fastapi import BackgroundTasks

    if MODEL is None or IMAGE_TRANSFORMS is None:
        raise HTTPException(
            status_code=503, detail="Model not loaded. Server might be starting or encountered an error."
        )

    # Validate the upload before any temporary file exists. The filename may be
    # missing altogether, which is treated like an unsupported extension.
    original_filename = file.filename or ""
    file_suffix = next(
        (suffix for suffix in (".nii.gz", ".nii", ".npy") if original_filename.endswith(suffix)),
        None,
    )
    if file_suffix is None:
        raise HTTPException(status_code=400, detail="Unsupported file type. Use .nii, .nii.gz, or .npy")

    # Both paths start out as None so the cleanup code below is safe no matter
    # where a failure happens.
    tmp_file_path = None
    response_file_path = None

    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=file_suffix) as tmp_file:
            tmp_file_path = tmp_file.name
            tmp_file.write(await file.read())

        LOGGER.info(f"Temporary file saved at: {tmp_file_path} for original: {original_filename}")

        # Cheap early sanity check that the upload is readable. Only the shape
        # is needed here; the voxel data itself is read by the transform
        # pipeline, which receives the file path below.
        if file_suffix == ".npy":
            input_shape = np.load(tmp_file_path, mmap_mode="r").shape
            LOGGER.warning(
                "Using identity affine for .npy input. Inverse transform might not perfectly restore original physical space if original .npy had specific spacing."
            )
        else:
            input_shape = nib.load(tmp_file_path).shape
        LOGGER.info(f"Input volume shape: {input_shape}")

        # Prepare data dictionary for MONAI transforms
        data_dict = {IMAGE_DATA: tmp_file_path}
        # Apply transforms
        val_input_transformed_dict = IMAGE_TRANSFORMS(data_dict)

        val_inputs = val_input_transformed_dict[IMAGE_DATA]
        # Add the batch dimension expected by the model and move to the device
        val_inputs = val_inputs.unsqueeze(0).to(DEVICE)
        LOGGER.info(f"Shape transformed inputs: {val_inputs.shape}")

        LOGGER.info("Predicting...")
        with torch.inference_mode(), AUTOCAST_CONTEXT:
            logits = MODEL.forward(
                inputs=val_inputs,
                roi_size=(ARGS.roi_x, ARGS.roi_y, ARGS.roi_z),
                sw_batch_size=4,
                overlap=ARGS.infer_overlap,
                mode="gaussian",
            )
        LOGGER.info(f"Shape of prediction before inversion: {logits.shape}")

        # Post-processing and transforms inversion (batch dimension dropped)
        invert_dict = {
            IMAGE_DATA: logits[0],
        }
        inverted = POST_TRANSFORMS_INVERTER(invert_dict)
        inverted_logits = inverted[IMAGE_DATA]
        LOGGER.info(f"Shape of prediction after inverse transforms: {inverted_logits.shape}")

        probs = torch.sigmoid(inverted_logits)
        LOGGER.info(f"Max/Min probs: {probs.max()} / {probs.min()}")
        seg = (probs > 0.5).float()
        prediction_np = seg.cpu().numpy().squeeze()
        LOGGER.info(f"Final prediction shape: {prediction_np.shape}")

        # Save prediction to a NIfTI file
        output_affine = inverted[IMAGE_DATA].affine.cpu().numpy()
        pred_nifti = nib.Nifti1Image(prediction_np, output_affine)

        # Reserve a path first so it is known to the cleanup code even if
        # writing the NIfTI file fails.
        with tempfile.NamedTemporaryFile(delete=False, suffix="_prediction.nii.gz") as pred_output_file:
            response_file_path = pred_output_file.name
        nib.save(pred_nifti, response_file_path)

        # The file must outlive this function because it is streamed to the
        # client afterwards; remove it once the response has been sent.
        cleanup = BackgroundTasks()
        cleanup.add_task(os.unlink, response_file_path)

        return FileResponse(
            response_file_path,
            media_type="application/gzip",
            filename="prediction.nii.gz",
            background=cleanup,
        )

    except HTTPException:
        # Already a well-formed HTTP error; pass it through unchanged.
        raise
    except Exception as e:
        # No response will be sent, so the output file (if any) is garbage.
        if response_file_path is not None and os.path.exists(response_file_path):
            os.unlink(response_file_path)
        LOGGER.error(f"Error during prediction: {e}")
        LOGGER.error(f"Traceback: \n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") from e

    finally:
        # The uploaded copy is never needed after this function returns.
        if tmp_file_path is not None and os.path.exists(tmp_file_path):
            os.unlink(tmp_file_path)
            LOGGER.info(f"Temporary input file {tmp_file_path} unlinked.")
class SwinUnetrPreTrainedModel(PreTrainedModel):
    """
    Abstract base that ties the SwinUNETR models to ``SwinUnetrConfig`` and
    provides the weight initialization used by the pretrained-model machinery.
    """

    config_class = SwinUnetrConfig
    base_model_prefix = "swinunetr"

    def _init_weights(self, module):
        """Initialize the weights of a single submodule."""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=0.02)
            linear_bias = module.bias
            if linear_bias is not None:
                linear_bias.data.zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


class SwinUnetrModelForInference(SwinUnetrPreTrainedModel):
    """
    Swin UNETR based on: "Hatamizadeh et al.,
    Swin UNETR: Swin Transformers for Semantic Segmentation of Brain Tumors in MRI Images
    "
    Source : https://docs.monai.io/en/stable/_modules/monai/networks/nets/swin_unetr.html

    The MONAI ``SwinUNETR`` network is built from the configuration and kept
    in ``self.model``; ``forward`` runs it through sliding-window inference.
    """

    # Configuration attributes forwarded one-to-one as SwinUNETR keyword arguments.
    _NETWORK_ARGUMENTS = (
        "in_channels",
        "out_channels",
        "depths",
        "num_heads",
        "feature_size",
        "norm_name",
        "drop_rate",
        "attn_drop_rate",
        "dropout_path_rate",
        "normalize",
        "use_checkpoint",
        "spatial_dims",
    )

    def __init__(self, config, add_pooling_layer=True):
        """Build the network described by ``config``.

        Args:
            config: configuration object providing the SwinUNETR hyper-parameters.
            add_pooling_layer: accepted for signature compatibility; not used here.
        """
        super().__init__(config)

        self.config = config

        network_kwargs = {name: getattr(config, name) for name in self._NETWORK_ARGUMENTS}
        self.model = SwinUNETR(**network_kwargs)

        self.init_weights()

    def forward(
        self,
        inputs: torch.Tensor,
        roi_size: Union[Sequence[int], int],
        sw_batch_size: int,
        overlap: float = 0.25,
        mode: Union[BlendMode, str] = BlendMode.CONSTANT,
    ):
        r"""
        Run sliding window inference on `inputs`, using the wrapped network as the predictor.

        All arguments are handed to ``monai.inferers.sliding_window_inference``;
        see its documentation for the exact semantics.

        Args:
            inputs: input image to be processed (assuming NCHW[D]).
            roi_size: the spatial window size for inferences.
            sw_batch_size: the batch size to run window slices.
            overlap: amount of overlap between scans. Defaults to ``0.25``.
            mode: {``"constant"``, ``"gaussian"``}
                How to blend output of overlapping windows. Defaults to ``"constant"``.

        Returns:
            Whatever ``sliding_window_inference`` returns for the wrapped network.

        Note:
            - input must be channel-first and have a batch dim.
        """
        return sliding_window_inference(
            inputs,
            roi_size,
            sw_batch_size,
            predictor=self.model,
            overlap=overlap,
            mode=mode,
        )
+ + Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model + outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information. + + + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + feature_size: dimension of network feature size. + depths: number of layers in each stage. + num_heads: number of attention heads. + norm_name: feature normalization type and arguments. + drop_rate: dropout rate. + attn_drop_rate: attention dropout rate. + dropout_path_rate: drop path rate. + normalize: normalize output intermediate features in each stage. + use_checkpoint: use gradient checkpointing for reduced memory usage. + spatial_dims: number of spatial dims. + + Examples:: + + >>> TBD + """ + + model_type = "swin" + + def __init__( + self, + architecture="SwinUNETR", + img_size=96, + in_channels=1, + out_channels=14, + depths=(2, 2, 2, 2), + num_heads=(3, 6, 12, 24), + feature_size=12, + norm_name="instance", + drop_rate=0.0, + attn_drop_rate=0.0, + dropout_path_rate=0.0, + normalize=True, + use_checkpoint=False, + spatial_dims=3, + **kwargs, + ): + super().__init__( + architecture=architecture, + img_size=img_size, + in_channels=in_channels, + out_channels=out_channels, + depths=depths, + num_heads=num_heads, + feature_size=feature_size, + norm_name=norm_name, + drop_rate=drop_rate, + attn_drop_rate=attn_drop_rate, + dropout_path_rate=dropout_path_rate, + normalize=normalize, + use_checkpoint=use_checkpoint, + spatial_dims=spatial_dims, + **kwargs, + ) diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml new file mode 100644 index 0000000..e6d278a --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml @@ -0,0 +1,3 @@ +# kaiwo settings (if enabled, use kaiwo CRDs 
to have kaiwo operator manage the workload) +kaiwo: + enabled: true diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl b/workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl new file mode 100644 index 0000000..d9d5d8f --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl @@ -0,0 +1,72 @@ +# Release name helper +{{- define "release.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" -}} +{{- end -}} + +# Release fullname helper +{{- define "release.fullname" -}} +{{- $currentTime := now | date "20060102-1504" -}} +{{- if .Values.fullnameOverride -}} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" -}} +{{- else -}} +{{- if ne .Release.Name "release-name" -}} +{{- include "release.name" . }}-{{ .Release.Name | trunc 63 | trimSuffix "-" -}} +{{- else -}} +{{- include "release.name" . }}-{{ $currentTime | lower | trunc 63 | trimSuffix "-" -}} +{{- end -}} +{{- end -}} +{{- end -}} + +# Container resources helper +{{- define "container.resources" -}} +requests: + {{- if .Values.gpus }} + amd.com/gpu: "{{ .Values.gpus }}" + {{- end }} + {{- if .Values.ephemeral_storage }} + ephemeral-storage: "{{ .Values.ephemeral_storage }}" + {{- end }} +limits: + {{- if .Values.gpus }} + amd.com/gpu: "{{ .Values.gpus }}" + {{- end }} + {{- if .Values.ephemeral_storage }} + ephemeral-storage: "{{ .Values.ephemeral_storage }}" + {{- end }} +{{- end -}} + +# Container volume mounts helper +{{- define "container.volumeMounts" -}} +- mountPath: /workload/mount + name: workload-mount +- mountPath: /dev/shm + name: dshm +{{- end -}} + +# Container volumes helper +{{- define "container.volumes" -}} +- name: dshm + emptyDir: + medium: Memory + sizeLimit: {{ .Values.storage.dshm.sizeLimit }} +- configMap: + name: {{ include "release.fullname" . 
}} + defaultMode: 0755 + name: workload-mount +{{- end -}} + +# Container environment variables helper +{{- define "container.env" -}} +{{- range $key, $value := .Values.env_vars }} +{{- if (kindIs "map" $value) }} +- name: {{ $key }} + valueFrom: + secretKeyRef: + name: {{ $value.name }} + key: {{ $value.key }} +{{- else }} +- name: {{ $key }} + value: {{ $value | quote }} +{{- end }} +{{- end }} +{{- end -}} diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml new file mode 100644 index 0000000..db5a6c7 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml @@ -0,0 +1,11 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "release.fullname" . }} +data: +{{- $files := .Files }} +{{- range $path, $_ := .Files.Glob "mount/*" }} + {{ $key := $path | trimPrefix "mount/" }} + {{- $key }}: | +{{ $files.Get $path | indent 4 }} +{{- end }} diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml new file mode 100644 index 0000000..c0a803f --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml @@ -0,0 +1,96 @@ +{{- define "deployment" -}} +apiVersion: apps/v1 +kind: Deployment +metadata: + name: {{ include "release.fullname" . }} + labels: + app: {{ include "release.fullname" . }} + {{- range $key, $value := .Values.metadata.labels }} + {{ $key }}: {{ $value | quote }} + {{- end }} +spec: + replicas: 1 + selector: + matchLabels: + app: {{ include "release.fullname" . }} + template: + metadata: + labels: + app: {{ include "release.fullname" . 
}} + spec: + {{- if .Values.nodeSelector }} + nodeSelector: + {{- .Values.nodeSelector | toYaml | nindent 8 }} + {{- end }} + {{- if .Values.imagePullSecrets }} + imagePullSecrets: + {{- range .Values.imagePullSecrets }} + - name: {{ . }} + {{- end }} + {{- end }} + containers: + - name: {{ .Chart.Name }} + {{- if .Values.entrypoint }} + command: ["sh", "-c"] + args: + - | + {{- .Values.entrypoint | nindent 12 }} + {{- end }} + {{- if .Values.env_vars }} + env: + {{- include "container.env" . | nindent 12 }} + {{- end }} + image: {{ .Values.image | quote}} + imagePullPolicy: {{ default "Always" .Values.imagePullPolicy | quote }} + ports: + {{- range $key, $value := .Values.deployment.ports }} + - name: {{ $key }} + containerPort: {{ $value }} + {{- end }} + {{- if .Values.livenessProbe }} + livenessProbe: + {{- .Values.livenessProbe | toYaml | nindent 12 -}} + {{- end }} + {{- if .Values.readinessProbe }} + readinessProbe: + {{- .Values.readinessProbe | toYaml | nindent 12 -}} + {{- end }} + {{- if .Values.startupProbe }} + startupProbe: + {{- .Values.startupProbe | toYaml | nindent 12 -}} + {{- end }} + resources: + {{- include "container.resources" . | nindent 12 }} + volumeMounts: + {{- include "container.volumeMounts" . | nindent 12 }} + volumes: + {{- include "container.volumes" . | nindent 8 }} +{{- end -}} + +{{- define "deployment_stripped" -}} +{{- $deployment := include "deployment" . | fromYaml }} +{{- $ := unset $deployment "metadata" }} +{{- $ := unset $deployment.spec.template "metadata" }} +{{- $deployment | toYaml }} +{{- end -}} + +{{- define "deployment_wrapped_with_kaiwoservice" -}} +apiVersion: kaiwo.silogen.ai/v1alpha1 +kind: KaiwoService +metadata: + name: {{ include "release.fullname" . }} + labels: + app: {{ include "release.fullname" . }} + {{- range $key, $value := .Values.metadata.labels }} + {{ $key }}: {{ $value | quote }} + {{- end }} +spec: + deployment: + {{- include "deployment" . 
| nindent 4 }} +{{- end -}} + +{{- if .Values.kaiwo.enabled -}} +{{- include "deployment_wrapped_with_kaiwoservice" . }} +{{- else -}} +{{- include "deployment" . }} +{{- end -}} diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml new file mode 100644 index 0000000..e4968e9 --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml @@ -0,0 +1,22 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "release.fullname" . }} + labels: + app: {{ include "release.fullname" . }} +spec: + type: ClusterIP + ports: + {{ range $name, $port := .Values.deployment.ports }} + {{- if ne $name "http" }} + - name: {{ $name }} + port: {{ $port }} + targetPort: {{ $port }} + {{- else -}} + - name: {{ $name }} + port: 80 + targetPort: {{ $port }} + {{- end }} + {{- end }} + selector: + app: {{ include "release.fullname" . }} diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json b/workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json new file mode 100644 index 0000000..cf2637c --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json @@ -0,0 +1,221 @@ +{ + "$schema": "http://json-schema.org/draft-07/schema#", + "type": "object", + "properties": { + "metadata": { + "type": "object", + "description": "Metadata for the deployment", + "properties": { + "labels": { + "type": "object", + "description": "Labels to apply to the deployment", + "additionalProperties": { + "type": "string" + } + } + }, + "required": ["labels"] + }, + "image": { + "type": "string", + "description": "Docker image to use for the deployment" + }, + "imagePullPolicy": { + "type": "string", + "description": "Image pull policy", + "enum": ["Always", "IfNotPresent", "Never"] + }, + "imagePullSecrets": { + "type": "array", + "description": "Image pull secrets for private registries" + }, + 
"entrypoint": { + "type": "string", + "description": "Entrypoint for the container" + }, + "gpus": { + "type": "integer", + "description": "Number of GPUs to allocate", + "minimum": 1 + }, + "ephemeral_storage": { + "type": "string", + "description": "Ephemeral storage space written with as Gi", + "minimum": 1 + }, + "env_vars": { + "type": "object", + "description": "Environment variables for the container", + "properties": { + "HF_MODEL": { + "type": "string", + "description": "Hugging Face model name" + }, + "ROI_X": { + "type": "integer", + "description": "Region of interest X dimension" + }, + "ROI_Y": { + "type": "integer", + "description": "Region of interest Y dimension" + }, + "ROI_Z": { + "type": "integer", + "description": "Region of interest Z dimension" + }, + "SPACE_X": { + "type": "number", + "description": "Spacing in X dimension" + }, + "SPACE_Y": { + "type": "number", + "description": "Spacing in Y dimension" + }, + "SPACE_Z": { + "type": "number", + "description": "Spacing in Z dimension" + }, + "A_MIN": { + "type": "number", + "description": "Minimum value for normalization" + }, + "A_MAX": { + "type": "number", + "description": "Maximum value for normalization" + }, + "B_MIN": { + "type": "number", + "description": "Minimum value for normalization range" + }, + "B_MAX": { + "type": "number", + "description": "Maximum value for normalization range" + }, + "INFER_OVERLAP": { + "type": "number", + "description": "Inference overlap value" + }, + "COMPILE": { + "type": "string", + "description": "Whether to compile the model", + "enum": ["true", "false"] + }, + "COMPILE_MODE": { + "type": "string", + "description": "Compilation mode for the model" + }, + "AUTOCAST": { + "type": "string", + "description": "Whether to use autocast", + "enum": ["true", "false"] + } + }, + "additionalProperties": { + "type": "string" + } + }, + "vllm_engine_args": { + "type": "object", + "description": "Arguments for the vllm engine", + "additionalProperties": { + 
"type": "string" + } + }, + "storage": { + "type": "object", + "description": "Storage configuration", + "properties": { + "dshm": { + "type": "object", + "description": "Shared memory configuration", + "properties": { + "sizeLimit": { + "type": "string", + "description": "Size limit for shared memory" + } + }, + "required": ["sizeLimit"] + } + }, + "required": ["dshm"] + }, + "volumes": { + "type": "array", + "description": "Custom volumes to mount from secrets or configmaps.", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "mountPath": { + "type": "string" + }, + "secret": { + "type": "object", + "properties": { + "secretName": { + "type": "string" + } + }, + "required": ["secretName"] + } + }, + "required": ["name", "mountPath"] + } + }, + "deployment": { + "type": "object", + "description": "Deployment configuration", + "properties": { + "ports": { + "type": "object", + "description": "Ports for the deployment", + "properties": { + "http": { + "type": "integer", + "description": "HTTP port for the deployment" + } + }, + "required": ["http"] + } + }, + "required": ["ports"] + }, + "nodeSelector": { + "type": "object", + "properties": { + "dev": { + "type": "string", + "description": "If true, use the dev node selector" + } + } + }, + "kaiwo": { + "type": "object", + "properties": { + "enabled": { + "type": "boolean", + "description": "If true, use Kaiwo CRDs to have Kaiwo operator manage the workload" + } + } + }, + "startupProbe": { + "type": ["object"], + "additionalProperties": true, + "description": "Startup probe configuration for the container" + }, + "livenessProbe": { + "type": ["object"], + "additionalProperties": true, + "description": "Liveness probe configuration for the container" + }, + "readinessProbe": { + "type": ["object"], + "additionalProperties": true, + "description": "Readiness probe configuration for the container" + } + }, + "required": ["metadata", "image", "imagePullPolicy", "gpus", 
"ephemeral_storage", "env_vars", "storage", "deployment", "kaiwo"], + "additionalProperties": false +} diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/values.yaml b/workloads/dev-lifescience-swinunetr-inference/helm/values.yaml new file mode 100644 index 0000000..567840c --- /dev/null +++ b/workloads/dev-lifescience-swinunetr-inference/helm/values.yaml @@ -0,0 +1,55 @@ +metadata: + labels: {} + +image: rocm/pytorch:rocm7.0_ubuntu24.04_py3.12_pytorch_release_2.6.0 +imagePullPolicy: Always + +gpus: 1 +ephemeral_storage: 128Gi + +entrypoint: | + bash /workload/mount/entrypoint.sh + +env_vars: + HF_MODEL: "darragh/swinunetr-btcv-base" + ROI_X: 96 + ROI_Y: 96 + ROI_Z: 96 + SPACE_X: 1.5 + SPACE_Y: 1.5 + SPACE_Z: 2.0 + A_MIN: -175.0 + A_MAX: 250.0 + B_MIN: 0.0 + B_MAX: 1.0 + INFER_OVERLAP: 0.5 + COMPILE: "false" + COMPILE_MODE: "max-autotune" + AUTOCAST: "false" + +storage: + dshm: + sizeLimit: 32Gi + +deployment: + ports: + http: 8000 + +startupProbe: + httpGet: + path: /health + port: http + failureThreshold: 60 + periodSeconds: 10 +livenessProbe: + httpGet: + path: /health + port: http +readinessProbe: + httpGet: + path: /health + port: http + +# kaiwo settings (if enabled, use kaiwo CRDs to have kaiwo operator manage the workload) +kaiwo: + enabled: false From 26705671adf035588f56e1693de3cda2a413fba5 Mon Sep 17 00:00:00 2001 From: aivanni <4340981+aivanni@users.noreply.github.com> Date: Mon, 17 Nov 2025 12:13:16 +0200 Subject: [PATCH 2/6] Make fixes to Megatron-LM checkpoint processing and minor improvements (#463) * make fixes to checkpoint processing and minor improvements * Update readme of the multinode Megatron pretrain * update tutorials with debug info * change path to trained model in tutorial 04 megatron inference wl override * fix path to set-env-vars.sh * fix token alignment in inference * precommit hooks * update inference patch * fix patch for inference --- ...-deliver-resources-and-run-megatron-cpt.md | 14 ++ 
...a70b-and-run-megatron-cpt-with-tp8-ddp2.md | 14 ++ .../helm/mount/Megatron-LM-inference.patch | 237 ++++++++++++++++++ .../megatron-lm-inference-llama3-1-70b.patch | 17 -- .../overrides/tutorial-04-llama-3-1-70b.yaml | 2 +- .../helm/templates/deployment.yaml | 2 +- .../helm/mount/Megatron-LM.patch | 103 +++++++- .../helm/templates/conversion-job.yaml | 7 +- .../helm/values.yaml | 6 + .../helm/README.md | 15 +- .../helm/mount/Megatron-LM.patch | 143 +++++++++++ .../helm/mount/ray_entrypoint.py | 4 + .../helm/mount/Megatron-LM.patch | 110 +++++++- .../helm/mount/train-cpt.sh | 3 +- 14 files changed, 639 insertions(+), 38 deletions(-) create mode 100644 workloads/llm-inference-megatron-lm/helm/mount/Megatron-LM-inference.patch delete mode 100644 workloads/llm-inference-megatron-lm/helm/mount/megatron-lm-inference-llama3-1-70b.patch create mode 100644 workloads/llm-pretraining-megatron-lm-ray/helm/mount/Megatron-LM.patch diff --git a/docs/tutorials/tutorial-03-deliver-resources-and-run-megatron-cpt.md b/docs/tutorials/tutorial-03-deliver-resources-and-run-megatron-cpt.md index bff6a10..236c190 100644 --- a/docs/tutorials/tutorial-03-deliver-resources-and-run-megatron-cpt.md +++ b/docs/tutorials/tutorial-03-deliver-resources-and-run-megatron-cpt.md @@ -63,6 +63,20 @@ helm template workloads/llm-pretraining-megatron-lm-ray/helm \ | kubectl apply -f - ``` +It is important to note that service account used by the rayjob must have `get rayjob` and `patch configmap | pvc` permissions in order to run garbage collection script from [https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/mount/gc.sh](https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/mount/gc.sh) successfully. If this requirement is not satisfied it will manifest by failing to start the ray cluster. 
The head pod of the cluster will have the `Init:Error` status because the init container that runs the `gc.sh` script fails with an error similar to: + +```bash +Error from server (Forbidden): rayjobs.ray.io is forbidden: User "system:serviceaccount:examplenamespace:default" cannot get resource "rayjobs" in API group "ray.io" in the namespace "examplenamespace" +``` + +To quickly work around this issue while waiting for the permissions to be set up, one can comment out this line in [https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/templates/ray_job.yaml](https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/templates/ray_job.yaml#L69): + +``` +bash /local_resources/mount/gc.sh{{- if and .Values.kaiwo.storageEnabled .Values.kaiwo.enabled}} --skip-pvc{{- end }} {{ include "release.fullname" . }} +``` + +If automatic garbage collection was disabled this way, the resources of the workload, such as the `PVC` and `ConfigMap`, should be deleted manually with `kubectl delete` at the end of the run.
+ ### 2.4 Run inference workload with the final checkpoint (2.3) and query it using sample prompts on Llama-3.1-8B In order to perform inference with the just trained Llama-3.1-8B model and verify it's quality, follow the steps: diff --git a/docs/tutorials/tutorial-04-deliver-llama70b-and-run-megatron-cpt-with-tp8-ddp2.md b/docs/tutorials/tutorial-04-deliver-llama70b-and-run-megatron-cpt-with-tp8-ddp2.md index adecc24..481fb0b 100644 --- a/docs/tutorials/tutorial-04-deliver-llama70b-and-run-megatron-cpt-with-tp8-ddp2.md +++ b/docs/tutorials/tutorial-04-deliver-llama70b-and-run-megatron-cpt-with-tp8-ddp2.md @@ -62,6 +62,20 @@ helm template workloads/llm-pretraining-megatron-lm-ray/helm \ | kubectl apply -f - ``` +It is important to note that service account used by the rayjob must have `get rayjob` and `patch configmap | pvc` permissions in order to run garbage collection script from [https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/mount/gc.sh](https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/mount/gc.sh) successfully. If this requirement is not satisfied it will manifest by failing to start the ray cluster. 
The head pod of the cluster will have the `Init:Error` status because the init container that runs the `gc.sh` script fails with an error similar to: + +```bash +Error from server (Forbidden): rayjobs.ray.io is forbidden: User "system:serviceaccount:examplenamespace:default" cannot get resource "rayjobs" in API group "ray.io" in the namespace "examplenamespace" +``` + +To quickly work around this issue while waiting for the permissions to be set up, one can comment out this line in [https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/templates/ray_job.yaml](https://github.com/silogen/ai-workloads/blob/main/workloads/llm-pretraining-megatron-lm-ray/helm/templates/ray_job.yaml#L69): + +``` +bash /local_resources/mount/gc.sh{{- if and .Values.kaiwo.storageEnabled .Values.kaiwo.enabled}} --skip-pvc{{- end }} {{ include "release.fullname" . }} +``` + +If automatic garbage collection was disabled this way, the resources of the workload, such as the `PVC` and `ConfigMap`, should be deleted manually with `kubectl delete` at the end of the run.
+ ### 2.4 Run inference workload with the final checkpoint (2.3) and query it using sample prompts on Llama-3.1-70B diff --git a/workloads/llm-inference-megatron-lm/helm/mount/Megatron-LM-inference.patch b/workloads/llm-inference-megatron-lm/helm/mount/Megatron-LM-inference.patch new file mode 100644 index 0000000..433d845 --- /dev/null +++ b/workloads/llm-inference-megatron-lm/helm/mount/Megatron-LM-inference.patch @@ -0,0 +1,237 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py +index 47ab4d11..0c1b868d 100644 +--- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py ++++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py +@@ -20,6 +20,15 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte + from torch.distributed.checkpoint.storage import WriteResult + from torch.futures import Future + ++try: ++ # This PR https://github.com/pytorch/pytorch/pull/143359 introduced breaking change to saving checkpoints ++ # in torch_dist format. This is a workaround to fix the issue. 
++ from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms ++ from functools import partial ++ _write_item = partial(_write_item, _StorageWriterTransforms()) ++except ImportError: ++ pass ++ + logger = logging.getLogger(__name__) + + WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file +diff --git a/megatron/inference/endpoints/completions.py b/megatron/inference/endpoints/completions.py +index 32dbc5dc..ab187a93 100644 +--- a/megatron/inference/endpoints/completions.py ++++ b/megatron/inference/endpoints/completions.py +@@ -144,12 +144,17 @@ class MegatronCompletions(Resource): + for batch_idx, (prompt_plus_generation, prompt) in enumerate( + zip(prompts_plus_generations, prompts) + ): ++ prompt_tokens = tok.tokenize(prompt) ++ prompt_token_count = len(prompt_tokens) ++ prompt_reconstructed = tok.detokenize(tokens[batch_idx][:prompt_token_count]) ++ + tok_offsets = tok.offsets(tokens[batch_idx], prompt_plus_generation) + if echo: +- str_trunc_start_idx, tok_idx_start = 0, 0 ++ str_trunc_start_idx, tok_idx_start, tok_idx_start_offsets = 0, 0, 0 + else: +- str_trunc_start_idx = len(prompt) +- tok_idx_start = np.searchsorted(tok_offsets, len(prompt)) ++ str_trunc_start_idx = len(prompt_reconstructed) ++ tok_idx_start_offsets = np.searchsorted(tok_offsets, str_trunc_start_idx) ++ tok_idx_start = prompt_token_count + + # truncate the generation at the first stop token + trunc_idxs = [ +@@ -161,21 +166,21 @@ class MegatronCompletions(Resource): + truncated_generation = prompt_plus_generation[str_trunc_start_idx:str_trunc_end_idx] + + # TODO(sasatheesh): handle cases where truncated_generation is not a full token +- tok_idx_end = np.searchsorted(tok_offsets, len(truncated_generation)) ++ tok_idx_end = np.searchsorted(tok_offsets, str_trunc_end_idx) + +- truncated_generation_logprobs = output_log_probs[batch_idx][tok_idx_start:tok_idx_end] ++ truncated_generation_logprobs = 
output_log_probs[batch_idx][max(tok_idx_start-1,0):tok_idx_end-1] + truncated_generation_tokens = tokens[batch_idx][tok_idx_start:tok_idx_end] + truncated_generation_topk_logprobs = ret_topk_logprobs[batch_idx][ + tok_idx_start:tok_idx_end + ] +- truncated_generation_tok_offsets = tok_offsets[tok_idx_start:tok_idx_end] ++ truncated_generation_tok_offsets = tok_offsets[tok_idx_start_offsets:tok_idx_end] + + results.append( + { + "index": batch_idx, + "text": truncated_generation, + "logprobs": { +- "token_logprobs": [None] + truncated_generation_logprobs, ++ "token_logprobs": truncated_generation_logprobs, + "tokens": [tok.detokenize([tk]) for tk in truncated_generation_tokens], + "text_offset": truncated_generation_tok_offsets, + "top_logprobs": truncated_generation_topk_logprobs, +diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py +index 5ac0747e..69a3dd75 100644 +--- a/megatron/training/arguments.py ++++ b/megatron/training/arguments.py +@@ -710,9 +710,9 @@ def validate_args(args, defaults={}): + if args.num_experts is not None: + assert args.spec is None, "Model Spec must be None when using MoEs" + +- if args.tensor_model_parallel_size > 1: +- assert args.sequence_parallel, \ +- "When using MoE and tensor parallelism, sequence parallelism must be used." ++ #if args.tensor_model_parallel_size > 1: ++ # assert args.sequence_parallel, \ ++ # "When using MoE and tensor parallelism, sequence parallelism must be used." 
+ + if args.moe_ffn_hidden_size is None: + args.moe_ffn_hidden_size = args.ffn_hidden_size +diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py +index 92813050..dd771395 100644 +--- a/megatron/training/checkpointing.py ++++ b/megatron/training/checkpointing.py +@@ -970,6 +970,9 @@ def load_args_from_checkpoint( + _set_arg('rotary_base', force=True) + _set_arg('rotary_percent', force=True) + _set_arg('rotary_interleaved', force=True) ++ _set_arg('rotary_seq_len_interpolation_factor', force=True) ++ _set_arg('use_rope_scaling', force=True) ++ _set_arg('norm_epsilon', force=True) + _set_arg('add_bias_linear', force=True) + _set_arg('add_qkv_bias', force=True) + _set_arg('squared_relu', force=True) +diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py +index 11e7645d..366c7a7b 100644 +--- a/megatron/training/tokenizer/tokenizer.py ++++ b/megatron/training/tokenizer/tokenizer.py +@@ -182,15 +182,12 @@ class _HuggingFaceTokenizer(MegatronTokenizer): + return self._tokenizer.decode(token_ids, **kwargs) + + def offsets(self, ids: list[int], text: str) -> list[int]: +- retok_ids: "transformers.BatchEncoding" = self._tokenizer(text) +- offsets, next_start_idx = [], 0 +- for i in range(len(ids)): +- span = retok_ids.token_to_chars(i) +- if span is not None: +- offsets.append(span.start) +- next_start_idx = span.end +- else: +- offsets.append(next_start_idx) ++ tokens = self._tokenizer.convert_ids_to_tokens(ids) ++ offsets = [] ++ current = 0 ++ for t in tokens: ++ offsets.append(current) ++ current += len(t) + return offsets + + @property +diff --git a/pretrain_gpt.py b/pretrain_gpt.py +index d31c0954..a850624a 100644 +--- a/pretrain_gpt.py ++++ b/pretrain_gpt.py +@@ -125,7 +125,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- 
rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model +diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py +index 935613b1..4a2297f6 100644 +--- a/tools/checkpoint/convert.py ++++ b/tools/checkpoint/convert.py +@@ -151,4 +151,8 @@ def main(): + + + if __name__ == '__main__': ++ try: ++ mp.set_start_method('spawn') ++ except RuntimeError: ++ pass + main() +diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py +index b6697964..c96fec45 100644 +--- a/tools/checkpoint/loader_llama_mistral.py ++++ b/tools/checkpoint/loader_llama_mistral.py +@@ -320,6 +320,13 @@ def load_args_from_checkpoint(args): + args.padded_vocab_size = model_args["vocab_size"] + args.ffn_hidden_size = model_args["intermediate_size"] + ++ if "rope_theta" in model_args: ++ args.rotary_base = int(model_args["rope_theta"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("type", "") == "linear" and "factor" in model_args["rope_scaling"]: ++ args.rotary_seq_len_interpolation_factor = int(model_args["rope_scaling"]["factor"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("rope_type", "") == "llama3": ++ args.use_rope_scaling = True ++ + if "num_key_value_heads" in model_args: + args.group_query_attention = True + args.num_query_groups = model_args["num_key_value_heads"] +@@ -457,6 +464,7 @@ def _load_checkpoint(queue, args): + '--no-save-rng', + '--mock-data', # To pass the "blend data checks" in arguments.py + '--no-initialization', ++ '--no-gradient-accumulation-fusion', + '--load', args.load_dir, + '--no-one-logger', + ] +@@ -560,6 +568,10 @@ def _load_checkpoint(queue, args): + md.checkpoint_args = margs + md.consumed_train_samples = 0 + md.consumed_valid_samples = 0 ++ md.norm_epsilon = margs.norm_epsilon ++ md.rotary_base = margs.rotary_base ++ 
md.rotary_seq_len_interpolation_factor = margs.rotary_seq_len_interpolation_factor ++ md.use_rope_scaling = margs.use_rope_scaling + + margs.model_size = args.model_size + +diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py +index 2caf26a9..83c7951b 100644 +--- a/tools/checkpoint/saver_mcore.py ++++ b/tools/checkpoint/saver_mcore.py +@@ -137,6 +137,15 @@ def save_checkpoint(queue, args): + '--no-one-logger', + ] + ++ if md.norm_epsilon: ++ sys.argv.extend(['--norm-epsilon', str(md.norm_epsilon)]) ++ if md.rotary_base: ++ sys.argv.extend(['--rotary-base', str(md.rotary_base)]) ++ if md.rotary_seq_len_interpolation_factor: ++ sys.argv.extend(['--rotary-seq-len-interpolation-factor', str(md.rotary_seq_len_interpolation_factor)]) ++ if md.use_rope_scaling: ++ sys.argv.append('--use-rope-scaling') ++ + if md.make_vocab_size_divisible_by is not None: + sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)]) + if md.params_dtype == torch.float16: +@@ -188,8 +197,8 @@ def save_checkpoint(queue, args): + margs.apply_query_key_layer_scaling = md.checkpoint_args.apply_query_key_layer_scaling + + # Sequence parallel is required if use both tensor-parallel and Moe. 
+- if margs.num_experts is not None and args.target_tensor_parallel_size is not None: +- if margs.num_experts > 1 and args.target_tensor_parallel_size > 1: ++ if args.target_tensor_parallel_size is not None: ++ if args.target_tensor_parallel_size > 1: + margs.sequence_parallel = True + + validate_args(margs) +diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py +index e5b3f08a..fd6688e2 100644 +--- a/tools/run_text_generation_server.py ++++ b/tools/run_text_generation_server.py +@@ -84,7 +84,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/workloads/llm-inference-megatron-lm/helm/mount/megatron-lm-inference-llama3-1-70b.patch b/workloads/llm-inference-megatron-lm/helm/mount/megatron-lm-inference-llama3-1-70b.patch deleted file mode 100644 index a72cb3a..0000000 --- a/workloads/llm-inference-megatron-lm/helm/mount/megatron-lm-inference-llama3-1-70b.patch +++ /dev/null @@ -1,17 +0,0 @@ -diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py -index 5ac0747e..69a3dd75 100644 ---- a/megatron/training/arguments.py -+++ b/megatron/training/arguments.py -@@ -710,9 +710,9 @@ def validate_args(args, defaults={}): - if args.num_experts is not None: - assert args.spec is None, "Model Spec must be None when using MoEs" - -- if args.tensor_model_parallel_size > 1: -- assert args.sequence_parallel, \ -- "When using MoE and tensor parallelism, sequence parallelism must be used." -+ #if args.tensor_model_parallel_size > 1: -+ # assert args.sequence_parallel, \ -+ # "When using MoE and tensor parallelism, sequence parallelism must be used." 
- - if args.moe_ffn_hidden_size is None: - args.moe_ffn_hidden_size = args.ffn_hidden_size \ No newline at end of file diff --git a/workloads/llm-inference-megatron-lm/helm/overrides/tutorial-04-llama-3-1-70b.yaml b/workloads/llm-inference-megatron-lm/helm/overrides/tutorial-04-llama-3-1-70b.yaml index 0bb1037..42a5bf7 100644 --- a/workloads/llm-inference-megatron-lm/helm/overrides/tutorial-04-llama-3-1-70b.yaml +++ b/workloads/llm-inference-megatron-lm/helm/overrides/tutorial-04-llama-3-1-70b.yaml @@ -7,7 +7,7 @@ imagePullPolicy: Always remoteTokenizerPath: default-bucket/models/meta-llama/Llama-3.1-70B/ # Path to the checkpoint for Llama 3.1 70B -remoteModelPath: default-bucket/megatron-models/meta-llama/Llama-3.1-70B/ +remoteModelPath: default-bucket/experiments/megatron-lm/llama-3.1-70b-cpt-test/ envVars: BUCKET_STORAGE_HOST: http://minio.minio-tenant-default.svc.cluster.local:80 diff --git a/workloads/llm-inference-megatron-lm/helm/templates/deployment.yaml b/workloads/llm-inference-megatron-lm/helm/templates/deployment.yaml index f67f869..18224af 100644 --- a/workloads/llm-inference-megatron-lm/helm/templates/deployment.yaml +++ b/workloads/llm-inference-megatron-lm/helm/templates/deployment.yaml @@ -29,7 +29,7 @@ spec: - | bash /workload/mount/download_files.sh {{ .Values.remoteModelPath }} {{ .Values.remoteTokenizerPath | trimSuffix "/" }} git checkout fd6f0d11 - git apply /workload/mount/megatron-lm-inference-llama3-1-70b.patch + git apply /workload/mount/Megatron-LM-inference.patch echo "Patch applied successfully" bash /workload/mount/run_megatron.sh ports: diff --git a/workloads/llm-megatron-ckpt-conversion/helm/mount/Megatron-LM.patch b/workloads/llm-megatron-ckpt-conversion/helm/mount/Megatron-LM.patch index cf956ea..e9bfa47 100644 --- a/workloads/llm-megatron-ckpt-conversion/helm/mount/Megatron-LM.patch +++ b/workloads/llm-megatron-ckpt-conversion/helm/mount/Megatron-LM.patch @@ -5,7 +5,7 @@ index 47ab4d11..0c1b868d 100644 @@ -20,6 +20,15 @@ from 
torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte from torch.distributed.checkpoint.storage import WriteResult from torch.futures import Future - + +try: + # This PR https://github.com/pytorch/pytorch/pull/143359 introduced breaking change to saving checkpoints + # in torch_dist format. This is a workaround to fix the issue. @@ -16,15 +16,43 @@ index 47ab4d11..0c1b868d 100644 + pass + logger = logging.getLogger(__name__) - + WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file +diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py +index 92813050..dd771395 100644 +--- a/megatron/training/checkpointing.py ++++ b/megatron/training/checkpointing.py +@@ -970,6 +970,9 @@ def load_args_from_checkpoint( + _set_arg('rotary_base', force=True) + _set_arg('rotary_percent', force=True) + _set_arg('rotary_interleaved', force=True) ++ _set_arg('rotary_seq_len_interpolation_factor', force=True) ++ _set_arg('use_rope_scaling', force=True) ++ _set_arg('norm_epsilon', force=True) + _set_arg('add_bias_linear', force=True) + _set_arg('add_qkv_bias', force=True) + _set_arg('squared_relu', force=True) +diff --git a/pretrain_gpt.py b/pretrain_gpt.py +index d31c0954..a850624a 100644 +--- a/pretrain_gpt.py ++++ b/pretrain_gpt.py +@@ -125,7 +125,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index 935613b1..4a2297f6 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -151,4 +151,8 @@ def main(): - - + + if __name__ == '__main__': + try: + mp.set_start_method('spawn') @@ -32,10 
+60,24 @@ index 935613b1..4a2297f6 100644 + pass main() diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py -index b6697964..054ab941 100644 +index b6697964..c96fec45 100644 --- a/tools/checkpoint/loader_llama_mistral.py +++ b/tools/checkpoint/loader_llama_mistral.py -@@ -457,6 +457,7 @@ def _load_checkpoint(queue, args): +@@ -320,6 +320,13 @@ def load_args_from_checkpoint(args): + args.padded_vocab_size = model_args["vocab_size"] + args.ffn_hidden_size = model_args["intermediate_size"] + ++ if "rope_theta" in model_args: ++ args.rotary_base = int(model_args["rope_theta"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("type", "") == "linear" and "factor" in model_args["rope_scaling"]: ++ args.rotary_seq_len_interpolation_factor = int(model_args["rope_scaling"]["factor"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("rope_type", "") == "llama3": ++ args.use_rope_scaling = True ++ + if "num_key_value_heads" in model_args: + args.group_query_attention = True + args.num_query_groups = model_args["num_key_value_heads"] +@@ -457,6 +464,7 @@ def _load_checkpoint(queue, args): '--no-save-rng', '--mock-data', # To pass the "blend data checks" in arguments.py '--no-initialization', @@ -43,18 +85,59 @@ index b6697964..054ab941 100644 '--load', args.load_dir, '--no-one-logger', ] +@@ -560,6 +568,10 @@ def _load_checkpoint(queue, args): + md.checkpoint_args = margs + md.consumed_train_samples = 0 + md.consumed_valid_samples = 0 ++ md.norm_epsilon = margs.norm_epsilon ++ md.rotary_base = margs.rotary_base ++ md.rotary_seq_len_interpolation_factor = margs.rotary_seq_len_interpolation_factor ++ md.use_rope_scaling = margs.use_rope_scaling + + margs.model_size = args.model_size + diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py -index 2caf26a9..0bfe2a8a 100644 +index 2caf26a9..83c7951b 100644 --- a/tools/checkpoint/saver_mcore.py +++ 
b/tools/checkpoint/saver_mcore.py -@@ -188,8 +188,8 @@ def save_checkpoint(queue, args): +@@ -137,6 +137,15 @@ def save_checkpoint(queue, args): + '--no-one-logger', + ] + ++ if md.norm_epsilon: ++ sys.argv.extend(['--norm-epsilon', str(md.norm_epsilon)]) ++ if md.rotary_base: ++ sys.argv.extend(['--rotary-base', str(md.rotary_base)]) ++ if md.rotary_seq_len_interpolation_factor: ++ sys.argv.extend(['--rotary-seq-len-interpolation-factor', str(md.rotary_seq_len_interpolation_factor)]) ++ if md.use_rope_scaling: ++ sys.argv.append('--use-rope-scaling') ++ + if md.make_vocab_size_divisible_by is not None: + sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)]) + if md.params_dtype == torch.float16: +@@ -188,8 +197,8 @@ def save_checkpoint(queue, args): margs.apply_query_key_layer_scaling = md.checkpoint_args.apply_query_key_layer_scaling - + # Sequence parallel is required if use both tensor-parallel and Moe. - if margs.num_experts is not None and args.target_tensor_parallel_size is not None: - if margs.num_experts > 1 and args.target_tensor_parallel_size > 1: + if args.target_tensor_parallel_size is not None: + if args.target_tensor_parallel_size > 1: margs.sequence_parallel = True - + validate_args(margs) +diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py +index e5b3f08a..fd6688e2 100644 +--- a/tools/run_text_generation_server.py ++++ b/tools/run_text_generation_server.py +@@ -84,7 +84,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/workloads/llm-megatron-ckpt-conversion/helm/templates/conversion-job.yaml 
b/workloads/llm-megatron-ckpt-conversion/helm/templates/conversion-job.yaml index b6ca364..2a0f39e 100644 --- a/workloads/llm-megatron-ckpt-conversion/helm/templates/conversion-job.yaml +++ b/workloads/llm-megatron-ckpt-conversion/helm/templates/conversion-job.yaml @@ -130,10 +130,15 @@ spec: --loader {{ .loader }} \ --saver {{ .saver }} \ --target-tensor-parallel-size {{ .tensorParallel }} \ + --target-pipeline-parallel-size {{ .pipelineParallel }} \ + --target-expert-parallel-size {{ .expertParallel }} \ --checkpoint-type hf \ --load-dir /local-resources/sourcemodel \ --save-dir /local-resources/checkpoints \ - --tokenizer-model /local-resources/sourcemodel + --tokenizer-model /local-resources/sourcemodel \ + {{- range $.Values.additionalArgs }} + {{ . }} \ + {{- end }} echo "Checkpoint conversion completed" diff --git a/workloads/llm-megatron-ckpt-conversion/helm/values.yaml b/workloads/llm-megatron-ckpt-conversion/helm/values.yaml index c65b9c8..7f98c62 100644 --- a/workloads/llm-megatron-ckpt-conversion/helm/values.yaml +++ b/workloads/llm-megatron-ckpt-conversion/helm/values.yaml @@ -42,3 +42,9 @@ conversionArgs: loader: "llama_mistral" # Model loader: llama_mistral, megatron, etc. saver: "mcore" tensorParallel: 1 # 1 for 8B, 8 for 70B + pipelineParallel: 1 + expertParallel: 1 + +# Additional args to pass to conversion script, e.g. 
--bf16 or --fp16 for specific precision +additionalArgs: +- "--bf16" diff --git a/workloads/llm-pretraining-megatron-lm-ray/helm/README.md b/workloads/llm-pretraining-megatron-lm-ray/helm/README.md index 408b19e..577d9d2 100644 --- a/workloads/llm-pretraining-megatron-lm-ray/helm/README.md +++ b/workloads/llm-pretraining-megatron-lm-ray/helm/README.md @@ -8,7 +8,6 @@ To generate manifests and print them in standard output using the default `value helm template workloads/llm-pretraining-megatron-lm-ray/helm ``` - This will generate a kubernetes manifest with a RayJob, a ConfigMap and a PersistentVolumeClaim resources in the user's active namespace. To override the default values, a specific file can be passed using `--values` flag @@ -48,6 +47,20 @@ Some assumptions for running the pretraining jobs are as follows: The initial mo key: minio-secret-key ``` +Additionally service account used by the rayjob must have `get rayjob` and `patch configmap | pvc` permissions in order to run garbage collection script from [helm/mount/gc.sh](./mount/gc.sh) successfully. If this requirement is not satisfied it will manifest by failing to start the ray cluster. The head pod of the cluster will have `Init:Error` status because init container that runs `gc.sh` script fails with the error similar to + +```bash +Error from server (Forbidden): rayjobs.ray.io is forbidden: User "system:serviceaccount:examplenamespace:default" cannot get resource "rayjobs" in API group "ray.io" in the namespace "examplenamespace" +``` + +To quickly overcome this issue while waiting for permissions setup one can comment out this line in [./templates/ray_job.yaml](./templates/ray_job.yaml#L69) + +``` +bash /local_resources/mount/gc.sh{{- if and .Values.kaiwo.storageEnabled .Values.kaiwo.enabled}} --skip-pvc{{- end }} {{ include "release.fullname" . }} +``` + ## Cleanup Note that this chart, when run with `kubectl apply`, will create RayJob, PersistentVolumeClaim and ConfigMap objects. 
After the RayJob has finished, there is a 3600-second grace period to remove the RayJob object from the namespace. ConfigMap and PersistentVolumeClaim are attached to the lifecycle of the RayJob at the start of the workload and cleaned up automatically. However, if there is an issue during start up of the workload, there can be a situation, when ConfigMap and PersistentVolumeClaim are created but are not owned by the RayJob. In this case ConfigMap and PersistentVolumeClaim resources should be cleaned up manually using `kubectl delete` command. + +If automatic garbage collection was disabled then resources of the workload should be deleted manually using `kubectl delete` commands. diff --git a/workloads/llm-pretraining-megatron-lm-ray/helm/mount/Megatron-LM.patch b/workloads/llm-pretraining-megatron-lm-ray/helm/mount/Megatron-LM.patch new file mode 100644 index 0000000..e9bfa47 --- /dev/null +++ b/workloads/llm-pretraining-megatron-lm-ray/helm/mount/Megatron-LM.patch @@ -0,0 +1,143 @@ +diff --git a/megatron/core/dist_checkpointing/strategies/filesystem_async.py b/megatron/core/dist_checkpointing/strategies/filesystem_async.py +index 47ab4d11..0c1b868d 100644 +--- a/megatron/core/dist_checkpointing/strategies/filesystem_async.py ++++ b/megatron/core/dist_checkpointing/strategies/filesystem_async.py +@@ -20,6 +20,15 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte + from torch.distributed.checkpoint.storage import WriteResult + from torch.futures import Future + ++try: ++ # This PR https://github.com/pytorch/pytorch/pull/143359 introduced breaking change to saving checkpoints ++ # in torch_dist format. This is a workaround to fix the issue. 
++ from torch.distributed.checkpoint.filesystem import _StorageWriterTransforms ++ from functools import partial ++ _write_item = partial(_write_item, _StorageWriterTransforms()) ++except ImportError: ++ pass ++ + logger = logging.getLogger(__name__) + + WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file +diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py +index 92813050..dd771395 100644 +--- a/megatron/training/checkpointing.py ++++ b/megatron/training/checkpointing.py +@@ -970,6 +970,9 @@ def load_args_from_checkpoint( + _set_arg('rotary_base', force=True) + _set_arg('rotary_percent', force=True) + _set_arg('rotary_interleaved', force=True) ++ _set_arg('rotary_seq_len_interpolation_factor', force=True) ++ _set_arg('use_rope_scaling', force=True) ++ _set_arg('norm_epsilon', force=True) + _set_arg('add_bias_linear', force=True) + _set_arg('add_qkv_bias', force=True) + _set_arg('squared_relu', force=True) +diff --git a/pretrain_gpt.py b/pretrain_gpt.py +index d31c0954..a850624a 100644 +--- a/pretrain_gpt.py ++++ b/pretrain_gpt.py +@@ -125,7 +125,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model +diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py +index 935613b1..4a2297f6 100644 +--- a/tools/checkpoint/convert.py ++++ b/tools/checkpoint/convert.py +@@ -151,4 +151,8 @@ def main(): + + + if __name__ == '__main__': ++ try: ++ mp.set_start_method('spawn') ++ except RuntimeError: ++ pass + main() +diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py +index b6697964..c96fec45 100644 +--- 
a/tools/checkpoint/loader_llama_mistral.py ++++ b/tools/checkpoint/loader_llama_mistral.py +@@ -320,6 +320,13 @@ def load_args_from_checkpoint(args): + args.padded_vocab_size = model_args["vocab_size"] + args.ffn_hidden_size = model_args["intermediate_size"] + ++ if "rope_theta" in model_args: ++ args.rotary_base = int(model_args["rope_theta"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("type", "") == "linear" and "factor" in model_args["rope_scaling"]: ++ args.rotary_seq_len_interpolation_factor = int(model_args["rope_scaling"]["factor"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("rope_type", "") == "llama3": ++ args.use_rope_scaling = True ++ + if "num_key_value_heads" in model_args: + args.group_query_attention = True + args.num_query_groups = model_args["num_key_value_heads"] +@@ -457,6 +464,7 @@ def _load_checkpoint(queue, args): + '--no-save-rng', + '--mock-data', # To pass the "blend data checks" in arguments.py + '--no-initialization', ++ '--no-gradient-accumulation-fusion', + '--load', args.load_dir, + '--no-one-logger', + ] +@@ -560,6 +568,10 @@ def _load_checkpoint(queue, args): + md.checkpoint_args = margs + md.consumed_train_samples = 0 + md.consumed_valid_samples = 0 ++ md.norm_epsilon = margs.norm_epsilon ++ md.rotary_base = margs.rotary_base ++ md.rotary_seq_len_interpolation_factor = margs.rotary_seq_len_interpolation_factor ++ md.use_rope_scaling = margs.use_rope_scaling + + margs.model_size = args.model_size + +diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py +index 2caf26a9..83c7951b 100644 +--- a/tools/checkpoint/saver_mcore.py ++++ b/tools/checkpoint/saver_mcore.py +@@ -137,6 +137,15 @@ def save_checkpoint(queue, args): + '--no-one-logger', + ] + ++ if md.norm_epsilon: ++ sys.argv.extend(['--norm-epsilon', str(md.norm_epsilon)]) ++ if md.rotary_base: ++ sys.argv.extend(['--rotary-base', str(md.rotary_base)]) ++ if md.rotary_seq_len_interpolation_factor: ++ 
sys.argv.extend(['--rotary-seq-len-interpolation-factor', str(md.rotary_seq_len_interpolation_factor)]) ++ if md.use_rope_scaling: ++ sys.argv.append('--use-rope-scaling') ++ + if md.make_vocab_size_divisible_by is not None: + sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)]) + if md.params_dtype == torch.float16: +@@ -188,8 +197,8 @@ def save_checkpoint(queue, args): + margs.apply_query_key_layer_scaling = md.checkpoint_args.apply_query_key_layer_scaling + + # Sequence parallel is required if use both tensor-parallel and Moe. +- if margs.num_experts is not None and args.target_tensor_parallel_size is not None: +- if margs.num_experts > 1 and args.target_tensor_parallel_size > 1: ++ if args.target_tensor_parallel_size is not None: ++ if args.target_tensor_parallel_size > 1: + margs.sequence_parallel = True + + validate_args(margs) +diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py +index e5b3f08a..fd6688e2 100644 +--- a/tools/run_text_generation_server.py ++++ b/tools/run_text_generation_server.py +@@ -84,7 +84,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/workloads/llm-pretraining-megatron-lm-ray/helm/mount/ray_entrypoint.py b/workloads/llm-pretraining-megatron-lm-ray/helm/mount/ray_entrypoint.py index b9fb2db..d2c0fa6 100644 --- a/workloads/llm-pretraining-megatron-lm-ray/helm/mount/ray_entrypoint.py +++ b/workloads/llm-pretraining-megatron-lm-ray/helm/mount/ray_entrypoint.py @@ -54,6 +54,10 @@ def setup_environment(self): """ Sets up the environment (device, env vars) required for Megatron initialization which will happen inside the pretrain 
function. """ + import subprocess + + subprocess.run("cd /workspace/Megatron-LM; git apply /local_resources/mount/Megatron-LM.patch;", shell=True) + # Synchronize GPU Visibility Environment Variables hip_visible_devices = os.environ.get("HIP_VISIBLE_DEVICES") print( diff --git a/workloads/llm-pretraining-megatron-lm/helm/mount/Megatron-LM.patch b/workloads/llm-pretraining-megatron-lm/helm/mount/Megatron-LM.patch index 3074cb5..e9bfa47 100644 --- a/workloads/llm-pretraining-megatron-lm/helm/mount/Megatron-LM.patch +++ b/workloads/llm-pretraining-megatron-lm/helm/mount/Megatron-LM.patch @@ -5,7 +5,7 @@ index 47ab4d11..0c1b868d 100644 @@ -20,6 +20,15 @@ from torch.distributed.checkpoint.planner import SavePlan, SavePlanner, WriteIte from torch.distributed.checkpoint.storage import WriteResult from torch.futures import Future - + +try: + # This PR https://github.com/pytorch/pytorch/pull/143359 introduced breaking change to saving checkpoints + # in torch_dist format. This is a workaround to fix the issue. 
@@ -16,15 +16,43 @@ index 47ab4d11..0c1b868d 100644 + pass + logger = logging.getLogger(__name__) - + WriteBucket = Tuple[Path, str, Tuple[list, list]] # represents writes to a single file +diff --git a/megatron/training/checkpointing.py b/megatron/training/checkpointing.py +index 92813050..dd771395 100644 +--- a/megatron/training/checkpointing.py ++++ b/megatron/training/checkpointing.py +@@ -970,6 +970,9 @@ def load_args_from_checkpoint( + _set_arg('rotary_base', force=True) + _set_arg('rotary_percent', force=True) + _set_arg('rotary_interleaved', force=True) ++ _set_arg('rotary_seq_len_interpolation_factor', force=True) ++ _set_arg('use_rope_scaling', force=True) ++ _set_arg('norm_epsilon', force=True) + _set_arg('add_bias_linear', force=True) + _set_arg('add_qkv_bias', force=True) + _set_arg('squared_relu', force=True) +diff --git a/pretrain_gpt.py b/pretrain_gpt.py +index d31c0954..a850624a 100644 +--- a/pretrain_gpt.py ++++ b/pretrain_gpt.py +@@ -125,7 +125,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/tools/checkpoint/convert.py b/tools/checkpoint/convert.py index 935613b1..4a2297f6 100644 --- a/tools/checkpoint/convert.py +++ b/tools/checkpoint/convert.py @@ -151,4 +151,8 @@ def main(): - - + + if __name__ == '__main__': + try: + mp.set_start_method('spawn') @@ -32,10 +60,24 @@ index 935613b1..4a2297f6 100644 + pass main() diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py -index b6697964..054ab941 100644 +index b6697964..c96fec45 100644 --- a/tools/checkpoint/loader_llama_mistral.py +++ b/tools/checkpoint/loader_llama_mistral.py -@@ -457,6 +457,7 @@ def 
_load_checkpoint(queue, args): +@@ -320,6 +320,13 @@ def load_args_from_checkpoint(args): + args.padded_vocab_size = model_args["vocab_size"] + args.ffn_hidden_size = model_args["intermediate_size"] + ++ if "rope_theta" in model_args: ++ args.rotary_base = int(model_args["rope_theta"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("type", "") == "linear" and "factor" in model_args["rope_scaling"]: ++ args.rotary_seq_len_interpolation_factor = int(model_args["rope_scaling"]["factor"]) ++ if "rope_scaling" in model_args and model_args["rope_scaling"].get("rope_type", "") == "llama3": ++ args.use_rope_scaling = True ++ + if "num_key_value_heads" in model_args: + args.group_query_attention = True + args.num_query_groups = model_args["num_key_value_heads"] +@@ -457,6 +464,7 @@ def _load_checkpoint(queue, args): '--no-save-rng', '--mock-data', # To pass the "blend data checks" in arguments.py '--no-initialization', @@ -43,3 +85,59 @@ index b6697964..054ab941 100644 '--load', args.load_dir, '--no-one-logger', ] +@@ -560,6 +568,10 @@ def _load_checkpoint(queue, args): + md.checkpoint_args = margs + md.consumed_train_samples = 0 + md.consumed_valid_samples = 0 ++ md.norm_epsilon = margs.norm_epsilon ++ md.rotary_base = margs.rotary_base ++ md.rotary_seq_len_interpolation_factor = margs.rotary_seq_len_interpolation_factor ++ md.use_rope_scaling = margs.use_rope_scaling + + margs.model_size = args.model_size + +diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py +index 2caf26a9..83c7951b 100644 +--- a/tools/checkpoint/saver_mcore.py ++++ b/tools/checkpoint/saver_mcore.py +@@ -137,6 +137,15 @@ def save_checkpoint(queue, args): + '--no-one-logger', + ] + ++ if md.norm_epsilon: ++ sys.argv.extend(['--norm-epsilon', str(md.norm_epsilon)]) ++ if md.rotary_base: ++ sys.argv.extend(['--rotary-base', str(md.rotary_base)]) ++ if md.rotary_seq_len_interpolation_factor: ++ sys.argv.extend(['--rotary-seq-len-interpolation-factor', 
str(md.rotary_seq_len_interpolation_factor)]) ++ if md.use_rope_scaling: ++ sys.argv.append('--use-rope-scaling') ++ + if md.make_vocab_size_divisible_by is not None: + sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)]) + if md.params_dtype == torch.float16: +@@ -188,8 +197,8 @@ def save_checkpoint(queue, args): + margs.apply_query_key_layer_scaling = md.checkpoint_args.apply_query_key_layer_scaling + + # Sequence parallel is required if use both tensor-parallel and Moe. +- if margs.num_experts is not None and args.target_tensor_parallel_size is not None: +- if margs.num_experts > 1 and args.target_tensor_parallel_size > 1: ++ if args.target_tensor_parallel_size is not None: ++ if args.target_tensor_parallel_size > 1: + margs.sequence_parallel = True + + validate_args(margs) +diff --git a/tools/run_text_generation_server.py b/tools/run_text_generation_server.py +index e5b3f08a..fd6688e2 100644 +--- a/tools/run_text_generation_server.py ++++ b/tools/run_text_generation_server.py +@@ -84,7 +84,8 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, +- rope_scaling=args.use_rope_scaling ++ rope_scaling=args.use_rope_scaling, ++ seq_len_interpolation_factor=args.rotary_seq_len_interpolation_factor + ) + + return model diff --git a/workloads/llm-pretraining-megatron-lm/helm/mount/train-cpt.sh b/workloads/llm-pretraining-megatron-lm/helm/mount/train-cpt.sh index f2e6969..de38114 100644 --- a/workloads/llm-pretraining-megatron-lm/helm/mount/train-cpt.sh +++ b/workloads/llm-pretraining-megatron-lm/helm/mount/train-cpt.sh @@ -11,7 +11,8 @@ ################################################################################# # Source the environment variables from the separate script -source ./set-env-vars.sh # path in container +DIR="$(cd -P "$(dirname "${BASH_SOURCE[0]}")" && pwd)" 
+source "$DIR/set-env-vars.sh" # path in container TIME_STAMP=$(date +"%Y-%m-%d_%H-%M-%S") EXP_NAME="${EXP_NAME:-perf}" From 0eebd2113e96291356b469b9d40cf8c4b9aa2fbb Mon Sep 17 00:00:00 2001 From: Jussi Elo Date: Wed, 26 Nov 2025 09:29:34 +0200 Subject: [PATCH 3/6] Update the Docs workflow with contemporary setup (#465) --- .github/workflows/docs-file-copy.yml | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/docs-file-copy.yml b/.github/workflows/docs-file-copy.yml index eb432eb..fca2b37 100644 --- a/.github/workflows/docs-file-copy.yml +++ b/.github/workflows/docs-file-copy.yml @@ -1,6 +1,8 @@ name: Copy workload documentation to public docs repo -# We rsync the ai-workloads documentation to a temp clone of the public docs repo -# and commit and push the changes to the main branch of the public docs repo. Purpose is to keep the Docs repo (consolidated SiloGen docs) updated with ai-workloads repository changes. +# We rsync the docs and workloads directories to the public documentation. +# We commit and push the changes to the develop branch of the public docs +# repository. Purpose is to keep the consolidated EAI documentation updated +# with changes from the contributing repositories. 
on: push: @@ -16,20 +18,22 @@ jobs: if: github.repository == 'silogen/ai-workloads' runs-on: ubuntu-latest steps: - - name: Checkout core repo + - name: Checkout the repo uses: actions/checkout@v4 - - name: Push to public docs repo + - name: Push to external docs develop branch run: | git config --global user.name 'GitHub Actions' git config --global user.email 'actions@github.com' git clone https://x-access-token:${{ secrets.DOCS_REPO_TOKEN }}@github.com/silogen/ai-workloads.git source_docs - git clone https://x-access-token:${{ secrets.DOCS_REPO_TOKEN }}@github.com/silogen/docs.git target_silogen_docs - cd target_silogen_docs + git clone https://x-access-token:${{ secrets.DOCS_REPO_TOKEN }}@github.com/silogen/AMDEnterpriseAISuiteDocs.git target_amd_docs + cd target_amd_docs + rsync -av --delete --exclude='.git' ../source_docs/docs docs/ai-workloads-docs rsync -av --delete --exclude='.git' ../source_docs/workloads docs/ai-workloads-manifests + git add . git diff --staged --quiet || git commit -m "Update external docs from ai-workloads repo" - git push origin main + git push origin develop env: DOCS_REPO_TOKEN: ${{ secrets.DOCS_REPO_TOKEN }} From e56e1d87166f38f4f291c7239d33ecbfee932260 Mon Sep 17 00:00:00 2001 From: jorivesga Date: Wed, 26 Nov 2025 10:22:59 +0200 Subject: [PATCH 4/6] Rename swinunetr training and inference workloads (#464) * rename swinunetr training and inference workloads to benchmark-lifescience-... 
for consistency with the other lifescience workloads (reinvent, semaflow) * remove folder committed by mistake --- .../examples/demo_inference_service.ipynb | 0 .../examples/utils.py | 0 .../helm/Chart.yaml | 0 .../helm/README.md | 0 .../helm/mount/README.md | 0 .../helm/mount/data_utils.py | 0 .../helm/mount/entrypoint.sh | 0 .../helm/mount/inference_service.py | 0 .../helm/mount/requirements.txt | 0 .../helm/mount/swinunetr.py | 0 .../helm/mount/swinunetr_configuration.py | 0 .../helm/overrides/kaiwo/kaiwo-enable.yaml | 0 .../helm/templates/_helpers.tpl | 0 .../helm/templates/configmap.yaml | 0 .../helm/templates/deployment.yaml | 0 .../helm/templates/service.yaml | 0 .../helm/values.schema.json | 0 .../helm/values.yaml | 0 .../helm/Chart.yaml | 0 .../helm/README.md | 0 .../helm/mount/README.md | 0 .../helm/mount/data_utils.py | 0 .../helm/mount/entrypoint.sh | 0 .../helm/mount/lr_scheduler.py | 0 .../helm/mount/main.py | 0 .../helm/mount/requirements.txt | 0 .../helm/mount/trainer.py | 0 .../helm/mount/utils.py | 0 .../helm/overrides/kaiwo/kaiwo-enable.yaml | 0 .../helm/templates/_helpers.tpl | 0 .../helm/templates/configmap.yaml | 0 .../helm/templates/job.yaml | 0 .../helm/values.schema.json | 0 .../helm/values.yaml | 0 34 files changed, 0 insertions(+), 0 deletions(-) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/examples/demo_inference_service.ipynb (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/examples/utils.py (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/Chart.yaml (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/README.md (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/README.md (100%) rename workloads/{dev-lifescience-swinunetr-inference => 
benchmark-lifescience-swinunetr-inference}/helm/mount/data_utils.py (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/entrypoint.sh (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/inference_service.py (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/requirements.txt (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/swinunetr.py (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/mount/swinunetr_configuration.py (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/overrides/kaiwo/kaiwo-enable.yaml (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/templates/_helpers.tpl (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/templates/configmap.yaml (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/templates/deployment.yaml (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/templates/service.yaml (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/values.schema.json (100%) rename workloads/{dev-lifescience-swinunetr-inference => benchmark-lifescience-swinunetr-inference}/helm/values.yaml (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/Chart.yaml (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/README.md (100%) rename workloads/{dev-lifescience-swinunetr-training => 
benchmark-lifescience-swinunetr-training}/helm/mount/README.md (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/data_utils.py (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/entrypoint.sh (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/lr_scheduler.py (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/main.py (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/requirements.txt (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/trainer.py (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/mount/utils.py (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/overrides/kaiwo/kaiwo-enable.yaml (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/templates/_helpers.tpl (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/templates/configmap.yaml (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/templates/job.yaml (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/values.schema.json (100%) rename workloads/{dev-lifescience-swinunetr-training => benchmark-lifescience-swinunetr-training}/helm/values.yaml (100%) diff --git a/workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb b/workloads/benchmark-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb similarity index 100% rename from 
workloads/dev-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb rename to workloads/benchmark-lifescience-swinunetr-inference/examples/demo_inference_service.ipynb diff --git a/workloads/dev-lifescience-swinunetr-inference/examples/utils.py b/workloads/benchmark-lifescience-swinunetr-inference/examples/utils.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/examples/utils.py rename to workloads/benchmark-lifescience-swinunetr-inference/examples/utils.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/Chart.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/Chart.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/Chart.yaml diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/README.md b/workloads/benchmark-lifescience-swinunetr-inference/helm/README.md similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/README.md rename to workloads/benchmark-lifescience-swinunetr-inference/helm/README.md diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/README.md similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/README.md rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/README.md diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/data_utils.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/data_utils.py rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/data_utils.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh 
b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/entrypoint.sh similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/entrypoint.sh rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/entrypoint.sh diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/inference_service.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/inference_service.py rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/inference_service.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/requirements.txt b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/requirements.txt similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/requirements.txt rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/requirements.txt diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr.py b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/swinunetr.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr.py rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/swinunetr.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr_configuration.py b/workloads/benchmark-lifescience-swinunetr-inference/helm/mount/swinunetr_configuration.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/mount/swinunetr_configuration.py rename to workloads/benchmark-lifescience-swinunetr-inference/helm/mount/swinunetr_configuration.py diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml similarity index 100% rename 
from workloads/dev-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/overrides/kaiwo/kaiwo-enable.yaml diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl b/workloads/benchmark-lifescience-swinunetr-inference/helm/templates/_helpers.tpl similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/templates/_helpers.tpl rename to workloads/benchmark-lifescience-swinunetr-inference/helm/templates/_helpers.tpl diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/templates/configmap.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/templates/configmap.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/templates/configmap.yaml diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/templates/deployment.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/templates/deployment.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/templates/deployment.yaml diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/templates/service.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/templates/service.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/templates/service.yaml diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json b/workloads/benchmark-lifescience-swinunetr-inference/helm/values.schema.json similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/values.schema.json rename to 
workloads/benchmark-lifescience-swinunetr-inference/helm/values.schema.json diff --git a/workloads/dev-lifescience-swinunetr-inference/helm/values.yaml b/workloads/benchmark-lifescience-swinunetr-inference/helm/values.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-inference/helm/values.yaml rename to workloads/benchmark-lifescience-swinunetr-inference/helm/values.yaml diff --git a/workloads/dev-lifescience-swinunetr-training/helm/Chart.yaml b/workloads/benchmark-lifescience-swinunetr-training/helm/Chart.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/Chart.yaml rename to workloads/benchmark-lifescience-swinunetr-training/helm/Chart.yaml diff --git a/workloads/dev-lifescience-swinunetr-training/helm/README.md b/workloads/benchmark-lifescience-swinunetr-training/helm/README.md similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/README.md rename to workloads/benchmark-lifescience-swinunetr-training/helm/README.md diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/README.md b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/README.md similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/README.md rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/README.md diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/data_utils.py b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/data_utils.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/data_utils.py rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/data_utils.py diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/entrypoint.sh b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/entrypoint.sh similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/entrypoint.sh rename to 
workloads/benchmark-lifescience-swinunetr-training/helm/mount/entrypoint.sh diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/lr_scheduler.py b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/lr_scheduler.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/lr_scheduler.py rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/lr_scheduler.py diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/main.py b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/main.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/main.py rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/main.py diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/requirements.txt b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/requirements.txt similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/requirements.txt rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/requirements.txt diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/trainer.py b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/trainer.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/trainer.py rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/trainer.py diff --git a/workloads/dev-lifescience-swinunetr-training/helm/mount/utils.py b/workloads/benchmark-lifescience-swinunetr-training/helm/mount/utils.py similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/mount/utils.py rename to workloads/benchmark-lifescience-swinunetr-training/helm/mount/utils.py diff --git a/workloads/dev-lifescience-swinunetr-training/helm/overrides/kaiwo/kaiwo-enable.yaml b/workloads/benchmark-lifescience-swinunetr-training/helm/overrides/kaiwo/kaiwo-enable.yaml similarity 
index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/overrides/kaiwo/kaiwo-enable.yaml rename to workloads/benchmark-lifescience-swinunetr-training/helm/overrides/kaiwo/kaiwo-enable.yaml diff --git a/workloads/dev-lifescience-swinunetr-training/helm/templates/_helpers.tpl b/workloads/benchmark-lifescience-swinunetr-training/helm/templates/_helpers.tpl similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/templates/_helpers.tpl rename to workloads/benchmark-lifescience-swinunetr-training/helm/templates/_helpers.tpl diff --git a/workloads/dev-lifescience-swinunetr-training/helm/templates/configmap.yaml b/workloads/benchmark-lifescience-swinunetr-training/helm/templates/configmap.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/templates/configmap.yaml rename to workloads/benchmark-lifescience-swinunetr-training/helm/templates/configmap.yaml diff --git a/workloads/dev-lifescience-swinunetr-training/helm/templates/job.yaml b/workloads/benchmark-lifescience-swinunetr-training/helm/templates/job.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/templates/job.yaml rename to workloads/benchmark-lifescience-swinunetr-training/helm/templates/job.yaml diff --git a/workloads/dev-lifescience-swinunetr-training/helm/values.schema.json b/workloads/benchmark-lifescience-swinunetr-training/helm/values.schema.json similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/values.schema.json rename to workloads/benchmark-lifescience-swinunetr-training/helm/values.schema.json diff --git a/workloads/dev-lifescience-swinunetr-training/helm/values.yaml b/workloads/benchmark-lifescience-swinunetr-training/helm/values.yaml similarity index 100% rename from workloads/dev-lifescience-swinunetr-training/helm/values.yaml rename to workloads/benchmark-lifescience-swinunetr-training/helm/values.yaml From df65f04ef9f38ee3fce7f51df06febd08a5b1903 
Mon Sep 17 00:00:00 2001 From: Bo Date: Tue, 2 Dec 2025 11:03:16 +0100 Subject: [PATCH 5/6] Update JupyterLab to ROCm 7.0.2 (#468) * Update JupyterLab container image to rocm7.0.2 and adjust related configurations * Fix quotes in HTTP probe paths for consistency * Add failureThreshold and periodSeconds to startupProbe configuration * Update startupProbe configuration --- workloads/dev-workspace-jupyterlab/helm/README.md | 2 +- .../helm/overrides/dev-center/signature.yaml | 2 +- workloads/dev-workspace-jupyterlab/helm/values.yaml | 10 ++++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/workloads/dev-workspace-jupyterlab/helm/README.md b/workloads/dev-workspace-jupyterlab/helm/README.md index 15b7aad..de1fca1 100644 --- a/workloads/dev-workspace-jupyterlab/helm/README.md +++ b/workloads/dev-workspace-jupyterlab/helm/README.md @@ -8,7 +8,7 @@ You can configure the following parameters in the `values.yaml` file or override | Parameter | Description | Default | |------------------------|-----------------------------------------------------------------------------|-------------------------------------------------------------------------| -| `image` | Container image repository and tag | `rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0` | +| `image` | Container image repository and tag | `rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0` | | `imagePullPolicy` | Image pull policy | `Always` | | `imagePullSecrets` | List of image pull secrets for private registries | `[]` | | `gpus` | Number of GPUs to allocate (set to 0 for CPU-only mode) | `1` | diff --git a/workloads/dev-workspace-jupyterlab/helm/overrides/dev-center/signature.yaml b/workloads/dev-workspace-jupyterlab/helm/overrides/dev-center/signature.yaml index 3516a77..527d61d 100644 --- a/workloads/dev-workspace-jupyterlab/helm/overrides/dev-center/signature.yaml +++ b/workloads/dev-workspace-jupyterlab/helm/overrides/dev-center/signature.yaml @@ -1,4 +1,4 @@ -image: 
rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0 +image: rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 gpus: 1 memory_per_gpu: 64 # Gi cpu_per_gpu: 4 diff --git a/workloads/dev-workspace-jupyterlab/helm/values.yaml b/workloads/dev-workspace-jupyterlab/helm/values.yaml index 126401e..c9eb60e 100644 --- a/workloads/dev-workspace-jupyterlab/helm/values.yaml +++ b/workloads/dev-workspace-jupyterlab/helm/values.yaml @@ -9,7 +9,7 @@ metadata: user_id: user workload_id: # defaults to the release name -image: rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0 +image: rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 imagePullPolicy: Always imagePullSecrets: [] gpus: 1 @@ -46,15 +46,17 @@ deployment: # ref: https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/ startupProbe: httpGet: - path: "{{ include \"httpRoute.baseUrl\" . }}/api/status" + path: '{{ include "httpRoute.baseUrl" . }}/api/status' port: http + failureThreshold: 60 + periodSeconds: 10 livenessProbe: httpGet: - path: "{{ include \"httpRoute.baseUrl\" . }}/api/status" + path: '{{ include "httpRoute.baseUrl" . }}/api/status' port: http readinessProbe: httpGet: - path: "{{ include \"httpRoute.baseUrl\" . }}/api/status" + path: '{{ include "httpRoute.baseUrl" . 
}}/api/status' port: http entrypoint: | From f8bb079ea240ea036ad86cf1f5c1de3893192e00 Mon Sep 17 00:00:00 2001 From: Bo Date: Tue, 2 Dec 2025 15:42:36 +0100 Subject: [PATCH 6/6] Update VS Code to ROCm 7.0.2 (#470) * EAI-470: update vscode dev workspace image to ROCm 7.0.2 * Update README and dev-center signature files * Update AWS S3 VSCode extension to version 1.8.5 * Update workloads/dev-workspace-vscode/helm/values.yaml Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update VSCode default settings, disable AI chat by default --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- workloads/dev-workspace-vscode/helm/README.md | 2 +- .../helm/mount/default_settings.json | 3 ++- .../helm/overrides/dev-center/signature.yaml | 2 +- workloads/dev-workspace-vscode/helm/values.yaml | 15 ++++++++------- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/workloads/dev-workspace-vscode/helm/README.md b/workloads/dev-workspace-vscode/helm/README.md index 2538ffa..8466229 100644 --- a/workloads/dev-workspace-vscode/helm/README.md +++ b/workloads/dev-workspace-vscode/helm/README.md @@ -28,7 +28,7 @@ You can configure the following parameters in the `values.yaml` file or override | Parameter | Description | Default | |------------------------|-----------------------------------------------------------------------------|-------------------------------------------------------------------------| -| `image` | Container image repository and tag | `rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0` | +| `image` | Container image repository and tag | `rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0` | | `imagePullPolicy` | Image pull policy | `Always` | | `imagePullSecrets` | List of image pull secrets for private registries | `[]` | | `gpus` | Number of GPUs to allocate (set to 0 for CPU-only mode) | `1` | diff --git a/workloads/dev-workspace-vscode/helm/mount/default_settings.json 
b/workloads/dev-workspace-vscode/helm/mount/default_settings.json index d66e5c5..d9d6f31 100644 --- a/workloads/dev-workspace-vscode/helm/mount/default_settings.json +++ b/workloads/dev-workspace-vscode/helm/mount/default_settings.json @@ -1,3 +1,4 @@ { - "workbench.colorTheme": "Default Dark Modern" + "workbench.colorTheme": "Default Dark Modern", + "chat.disableAIFeatures": true } diff --git a/workloads/dev-workspace-vscode/helm/overrides/dev-center/signature.yaml b/workloads/dev-workspace-vscode/helm/overrides/dev-center/signature.yaml index 3516a77..527d61d 100644 --- a/workloads/dev-workspace-vscode/helm/overrides/dev-center/signature.yaml +++ b/workloads/dev-workspace-vscode/helm/overrides/dev-center/signature.yaml @@ -1,4 +1,4 @@ -image: rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0 +image: rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 gpus: 1 memory_per_gpu: 64 # Gi cpu_per_gpu: 4 diff --git a/workloads/dev-workspace-vscode/helm/values.yaml b/workloads/dev-workspace-vscode/helm/values.yaml index 59a07bf..75d76a4 100644 --- a/workloads/dev-workspace-vscode/helm/values.yaml +++ b/workloads/dev-workspace-vscode/helm/values.yaml @@ -12,7 +12,7 @@ pvc_annotations: pvc.silogen.ai/user-pvc-storage-class-name: "multinode" pvc.silogen.ai/user-pvc-uid: "{{ .Values.metadata.user_id }}" -image: rocm/pytorch:rocm6.4_ubuntu24.04_py3.12_pytorch_release_2.6.0 +image: rocm/pytorch:rocm7.0.2_ubuntu24.04_py3.12_pytorch_release_2.8.0 imagePullPolicy: Always imagePullSecrets: [] gpus: 1 @@ -50,7 +50,8 @@ startupProbe: httpGet: path: /healthz port: http - failureThreshold: 20 + failureThreshold: 60 + periodSeconds: 10 livenessProbe: httpGet: path: /healthz @@ -65,12 +66,12 @@ entrypoint: | curl -fsSL https://code-server.dev/install.sh | sh # Set up persistent VSCode configuration directory using environment variable - VSCODE_USER_DIR="$VSCODE_CONFIG_DIR/User" + VSCODE_USER_DIR="$VSCODE_CONFIG_DIR/code-server/User" # Create directory structure if 
it doesn't exist mkdir -p "$VSCODE_USER_DIR" - mkdir -p "$VSCODE_CONFIG_DIR/extensions" - mkdir -p "$VSCODE_CONFIG_DIR/logs" + mkdir -p "$VSCODE_CONFIG_DIR/code-server/extensions" + mkdir -p "$VSCODE_CONFIG_DIR/code-server/logs" # Copy default settings only if user settings don't exist (preserve user customizations) if [ ! -f "$VSCODE_USER_DIR/settings.json" ]; then @@ -82,8 +83,8 @@ entrypoint: | export XDG_CONFIG_HOME="$VSCODE_CONFIG_DIR" # Install extensions (these will be stored persistently) - curl -L -o /tmp/aws-s3-vscode-extension-1.8.4.vsix https://github.com/necatiarslan/aws-s3/raw/refs/heads/main/vsix/aws-s3-vscode-extension-1.8.4.vsix && - code-server --install-extension /tmp/aws-s3-vscode-extension-1.8.4.vsix + curl -L -o /tmp/aws-s3-vscode-extension-1.8.5.vsix https://github.com/necatiarslan/aws-s3/raw/refs/heads/main/vsix/aws-s3-vscode-extension-1.8.5.vsix && + code-server --install-extension /tmp/aws-s3-vscode-extension-1.8.5.vsix code-server --install-extension ms-python.python code-server --install-extension GitHub.vscode-pull-request-github code-server --install-extension ms-kubernetes-tools.vscode-kubernetes-tools